1. 程式人生 > >tensorflow 中,修改張量tensor特定元素的值

tensorflow 中,修改張量tensor特定元素的值

tensorflow中:

                   constant tensor不能直接賦值,否則會報錯:

               TypeError: 'Tensor' object does not support item assignment

       Variable

tensor不能為某個特定元素賦值,只能為整個變數tensor全部賦值。

                                        賦值命令為:new_state = tf.assign(state,new_tensor)

 為了解決張量元素賦值問題,上網檢視解決方法,主要是針對一維張量的,主要是用one_hot來實現,如文章:張量元素值修改,

本文參照其思想,實現2D張量元素替換,如果想實現更高維度的張量的元素賦值,可以在此基礎上再進行修改。

指令碼如下,親測可行,如有紕漏,歡迎大家拍磚指教。

#-*-coding:utf-8-*-
import tensorflow as tf
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

def tensor_expand(tensor_Input,Num):
    '''
    張量自我複製擴充套件,將Num個tensor_Input串聯起來,生成新的張量,
    新的張量的shape=[tensor_Input.shape,Num]
    :param tensor_Input:
    :param Num:
    :return:
    '''
    tensor_Input = tf.expand_dims(tensor_Input,axis=0)
    tensor_Output = tensor_Input
    for i in range(Num-1):
        tensor_Output= tf.concat([tensor_Output,tensor_Input],axis=0)
    return tensor_Output

def get_one_hot_matrix(height,width,position):
    '''
    生成一個 one_hot矩陣,shape=【height*width】,在position處的元素為1,其餘元素為0
    :param height:
    :param width:
    :param position: 格式為【h_Index,w_Index】,h_Index,w_Index為int格式
    :return:
    '''
    col_length = height
    row_length = width
    col_one_position = position[0]
    row_one_position = position[1]
    rows_num = height
    cols_num = width

    single_row_one_hot = tf.one_hot(row_one_position, row_length, dtype=tf.float32)
    single_col_one_hot = tf.one_hot(col_one_position, col_length, dtype=tf.float32)

    one_hot_rows = tensor_expand(single_row_one_hot, rows_num)
    one_hot_cols = tensor_expand(single_col_one_hot, cols_num)
    one_hot_cols = tf.transpose(one_hot_cols)

    one_hot_matrx = one_hot_rows * one_hot_cols
    return one_hot_matrx

def tensor_assign_2D(tensor_input,position,value):
    '''
    給 2D tensor的特定位置元素賦值
    :param tensor_input: 輸入的2D tensor,目前只支援2D
    :param position: 被賦值的張量元素的座標位置,=【h_index,w_index】
    :param value:
    :return:
    '''
    shape = tensor_input.get_shape().as_list()
    height = shape[0]
    width = shape[1]
    h_index = position[0]
    w_index = position[1]
    one_hot_matrix = get_one_hot_matrix(height, width, position)

    new_tensor = tensor_input - tensor_input[h_index,w_index]*one_hot_matrix +one_hot_matrix*value

    return new_tensor

if __name__=="__main__":
    ##test
    tensor_input = tf.constant([i for i in range(20)],tf.float32)
    tensor_input = tf.reshape(tensor_input,[4,5])
    new_tensor = tensor_assign_2D(tensor_input,[2,3],100)

    print(new_tensor.eval())

本文原創,如需轉載,請註明出處:https://blog.csdn.net/Strive_For_Future/article/details/82426015