程式人生 > 影像上取樣

影像上取樣

記錄常用的函式,以免遺忘。

def upsample(x, scale=2, features=64, activation=tf.nn.relu):
    """Upsample ``x`` by ``scale`` with sub-pixel (pixel-shuffle) convolutions.

    A 3x3 convolution first maps ``x`` to ``features`` channels. For ``scale``
    2 or 3, one more 3x3 convolution produces ``3 * scale**2`` channels, which
    ``PS(..., color=True)`` rearranges into 3 channels that are ``scale`` times
    larger in each spatial dimension. ``scale`` 4 is done as two successive x2
    steps (convolution to 12 channels followed by ``PS`` with r=2, twice).

    Args:
        x: input tensor passed to ``slim.conv2d`` (presumably NHWC -- confirm
            against callers).
        scale: upsampling factor; one of 2, 3 or 4.
        features: number of filters in the first 3x3 convolution.
        activation: activation function used by every convolution.

    Returns:
        The upsampled tensor with 3 channels.

    Raises:
        ValueError: if ``scale`` is not 2, 3 or 4. This is an explicit check
            rather than ``assert``, so it still fires under ``python -O``.
    """
    if scale not in (2, 3, 4):
        raise ValueError("scale must be 2, 3 or 4, got %r" % (scale,))
    x = slim.conv2d(x, features, [3, 3], activation_fn=activation)
    # x4 is realised as two x2 sub-pixel steps; x2 and x3 use a single step.
    # The layer creation order matches the original per-scale branches.
    step_ratios = [2, 2] if scale == 4 else [scale]
    for r in step_ratios:
        # 3 output channels, each needing r*r sub-pixel feature maps
        # (e.g. 12 for r=2, 27 for r=3); the kernel is 3x3.
        x = slim.conv2d(x, 3 * r * r, [3, 3], activation_fn=activation)
        x = PS(x, r, color=True)
    return x

def PS(X, r, color=False):
    """Pixel-shuffle ``X`` by factor ``r`` using ``_phase_shift``.

    With ``color=False`` the whole tensor is handed to ``_phase_shift``.
    With ``color=True`` axis 3 is cut into 3 equal slices (e.g. a
    10x50x50x12 tensor becomes three 10x50x50x4 tensors). Each slice is
    phase-shifted on its own, and the results are concatenated again along
    axis 3.

    Args:
        X: rank-4 tensor. For ``color=True`` axis 3 must split into 3 groups
            of ``r*r`` entries each; otherwise it must hold ``r*r`` entries.
        r: integer upscaling factor.
        color: whether to treat axis 3 as 3 independent groups.

    Returns:
        The pixel-shuffled tensor.
    """
    if not color:
        return _phase_shift(X, r)
    # tf.split(value, num_or_size_splits, axis): 3 equal slices along axis 3.
    channel_groups = tf.split(X, 3, 3)
    shifted_groups = [_phase_shift(group, r) for group in channel_groups]
    return tf.concat(shifted_groups, 3)

def _phase_shift(I, r):
	"""Pixel-shuffle a rank-4 tensor whose last axis has r*r entries into one output channel.

	Args:
		I: rank-4 tensor; its static shape is unpacked as (bsize, a, b, c).
			c must equal r*r for the reshape below to succeed. a and b are used
			as Python split counts, so they presumably have to be statically
			known (not None) -- TODO confirm for dynamic spatial sizes.
		r: integer upscaling factor.

	Returns:
		Tensor of shape (bsize, a*r, b*r, 1). Tracing the ops below gives
		out[n, i*r + p, j*r + q, 0] = I[n, i, j, q*r + p].

	NOTE(review): because of the transpose, the channel index inside each
	r x r output block is column-major (q*r + p). The numpy helpers
	`shuffle` / `my_anti_shuffle` in this file index row-major (p*r + q),
	as does tf.nn.depth_to_space, so this op is not their exact inverse.
	Presumably trained weights depend on this ordering -- do not change it
	without retraining.
	"""
	bsize, a, b, c = I.get_shape().as_list()  # static shape, e.g. bsize=10, a=50, b=50, c=4
	bsize = tf.shape(I)[0]  # dynamic batch size, so a None static batch dim still works
	X = tf.reshape(I, (bsize, a, b, r, r))  # X[n, i, j, u, v] = I[n, i, j, u*r + v]
	X = tf.transpose(X, (0, 1, 2, 4, 3))  # bsize, a, b, r, r (last two axes swapped)
	X = tf.split(X, a, 1)  # a * [bsize, 1, b, r, r]
	# tf.squeeze(x, axis=1) drops the size-1 axis left by tf.split; with an
	# explicit axis it raises an error if that axis is not of size 1.
	X = tf.concat([tf.squeeze(x, axis=1) for x in X],2)  # bsize, b, a*r, r
	X = tf.split(X, b, 1)  # b * [bsize, 1, a*r, r]
	X = tf.concat([tf.squeeze(x, axis=1) for x in X],2)  # bsize, a*r, b*r
	return tf.reshape(X, (bsize, a*r, b*r, 1))  # append a trailing size-1 channel axis

def my_anti_shuffle(input_image, ratio):
    """Space-to-depth: fold each ``ratio`` x ``ratio`` pixel block into the channel axis.

    This is the inverse of ``shuffle``::

        out[i, j, c*ratio*ratio + x*ratio + y] == input_image[i*ratio + x, j*ratio + y, c]

    Args:
        input_image: array-like of shape (H, W, C); H and W must be multiples
            of ``ratio``.
        ratio: positive integer block size.

    Returns:
        A new ``np.ndarray`` of shape (H // ratio, W // ratio, C * ratio**2).
        It has the same dtype as ``input_image`` (previously the result was
        always uint8, which silently truncated float or wide-integer images).

    Raises:
        ValueError: if ``ratio`` < 1, or if H or W is not divisible by ``ratio``.
    """
    image = np.asarray(input_image)
    ori_height, ori_width, ori_channels = image.shape
    if ratio < 1:
        raise ValueError("ratio must be a positive integer, got %r" % (ratio,))
    if ori_height % ratio != 0 or ori_width % ratio != 0:
        raise ValueError(
            "height and width must be divisible by ratio (got %dx%d, ratio=%d)"
            % (ori_height, ori_width, ratio))
    height = ori_height // ratio
    width = ori_width // ratio
    # (H, W, C) -> (h, x, w, y, C): split each spatial axis into block index
    # and in-block offset.
    blocks = image.reshape(height, ratio, width, ratio, ori_channels)
    # -> (h, w, C, x, y) so the flattened channel is c*ratio*ratio + x*ratio + y.
    # copy() materialises a fresh C-contiguous array, so the result never
    # aliases input_image (even when ratio == 1).
    blocks = blocks.transpose(0, 2, 4, 1, 3).copy()
    return blocks.reshape(height, width, ori_channels * ratio * ratio)

def shuffle(input_image, ratio):
    """Depth-to-space: unfold groups of ``ratio*ratio`` channels into pixel blocks.

    This is the inverse of ``my_anti_shuffle``::

        out[i, j, k] == input_image[i // ratio, j // ratio,
                                    k*ratio*ratio + (i % ratio)*ratio + (j % ratio)]

    The previous implementation used ``/`` for sizes and indices. Under
    Python 3 that produces floats and fails in ``np.zeros``, ``range`` and
    array indexing. This version uses integer arithmetic and a vectorised
    reshape/transpose instead of a per-pixel Python loop.

    Args:
        input_image: array-like of shape (H, W, C * ratio**2).
        ratio: positive integer block size.

    Returns:
        A new ``np.ndarray`` of shape (H * ratio, W * ratio, C) with the same
        dtype as ``input_image``.

    Raises:
        ValueError: if ``ratio`` < 1, or if the channel count is not divisible
            by ``ratio**2``.
    """
    image = np.asarray(input_image)
    in_height, in_width, in_channels = image.shape
    if ratio < 1:
        raise ValueError("ratio must be a positive integer, got %r" % (ratio,))
    block = ratio * ratio
    if in_channels % block != 0:
        raise ValueError(
            "channel count %d must be divisible by ratio**2 (%d)"
            % (in_channels, block))
    channels = in_channels // block
    # (h, w, C*r*r) -> (h, w, C, p, q), where channel index = k*r*r + p*r + q.
    blocks = image.reshape(in_height, in_width, channels, ratio, ratio)
    # -> (h, p, w, q, C) so flattened rows are i*r + p and columns j*r + q.
    # copy() yields a fresh C-contiguous array that never aliases the input.
    blocks = blocks.transpose(0, 3, 1, 4, 2).copy()
    return blocks.reshape(in_height * ratio, in_width * ratio, channels)