tensorflow 中,修改張量tensor特定元素的值
阿新 • • 發佈:2018-11-10
tensorflow中:
constant tensor不能直接賦值,否則會報錯:
TypeError:
'Tensor'
object
does
not
support item assignment
Variable
賦值命令為:new_state = tf.assign(state,new_tensor)
為了解決張量元素賦值問題,上網檢視解決方法,主要是針對一維張量的,主要是用one_hot來實現,如文章:張量元素值修改,
本文參照其思想,實現2D張量元素替換,如果想實現更高維度的張量的元素賦值,可以在此基礎上再進行修改。
指令碼如下,親測可行,如有紕漏,歡迎大家拍磚指教。
#-*-coding:utf-8-*- import tensorflow as tf sess = tf.InteractiveSession() sess.run(tf.global_variables_initializer()) def tensor_expand(tensor_Input,Num): ''' 張量自我複製擴充套件,將Num個tensor_Input串聯起來,生成新的張量, 新的張量的shape=[tensor_Input.shape,Num] :param tensor_Input: :param Num: :return: ''' tensor_Input = tf.expand_dims(tensor_Input,axis=0) tensor_Output = tensor_Input for i in range(Num-1): tensor_Output= tf.concat([tensor_Output,tensor_Input],axis=0) return tensor_Output def get_one_hot_matrix(height,width,position): ''' 生成一個 one_hot矩陣,shape=【height*width】,在position處的元素為1,其餘元素為0 :param height: :param width: :param position: 格式為【h_Index,w_Index】,h_Index,w_Index為int格式 :return: ''' col_length = height row_length = width col_one_position = position[0] row_one_position = position[1] rows_num = height cols_num = width single_row_one_hot = tf.one_hot(row_one_position, row_length, dtype=tf.float32) single_col_one_hot = tf.one_hot(col_one_position, col_length, dtype=tf.float32) one_hot_rows = tensor_expand(single_row_one_hot, rows_num) one_hot_cols = tensor_expand(single_col_one_hot, cols_num) one_hot_cols = tf.transpose(one_hot_cols) one_hot_matrx = one_hot_rows * one_hot_cols return one_hot_matrx def tensor_assign_2D(tensor_input,position,value): ''' 給 2D tensor的特定位置元素賦值 :param tensor_input: 輸入的2D tensor,目前只支援2D :param position: 被賦值的張量元素的座標位置,=【h_index,w_index】 :param value: :return: ''' shape = tensor_input.get_shape().as_list() height = shape[0] width = shape[1] h_index = position[0] w_index = position[1] one_hot_matrix = get_one_hot_matrix(height, width, position) new_tensor = tensor_input - tensor_input[h_index,w_index]*one_hot_matrix +one_hot_matrix*value return new_tensor if __name__=="__main__": ##test tensor_input = tf.constant([i for i in range(20)],tf.float32) tensor_input = tf.reshape(tensor_input,[4,5]) new_tensor = tensor_assign_2D(tensor_input,[2,3],100) print(new_tensor.eval())
本文原創,如需轉載,請註明出處:https://blog.csdn.net/Strive_For_Future/article/details/82426015