Tensorflow學習筆記:多輸入線性迴歸神經網路
阿新 • • 發佈:2019-01-02
def myregression(): """ 實現多維輸入的線性迴歸神經網路 假設輸入為x = [a,b,c,d,e,f,g],正確答案為:y_true 則y_true = x1 * a + x2 * b + x3 * c + x4 * d + x5 * e + x6 * f + x7 * g + 1 其中x取值為[0.1,0.2,0.3,0.4,0.5,0.6,0.7] """ #1、生成輸入與標準答案資料集,x是一個[100,7]矩陣,應該乘以[7,1]的矩陣得到[100,1]的答案矩陣 x = tf.random_normal([100,7],mean = 1.75,stddev = 0.5,name = 'input_data_x') y_true = tf.matmul(x,[[0.1],[0.2],[0.3],[0.4],[0.5],[0.6],[0.7]]) + 1 #2、生成線性迴歸引數,計算預測結果 w = tf.Variable(tf.random_normal([7,1],mean = 2,stddev=0.2,)) b = tf.Variable(0.0) y_predict = tf.matmul(x,w) + b #3、計算loss loss = tf.reduce_mean(tf.square(y_predict - y_true)) #4、反向傳播優化引數 train_op = tf.train.GradientDescentOptimizer(0.04).minimize(loss) init_op = tf.global_variables_initializer() with tf.Session() as sess: sess.run(init_op) for i in range(100000): sess.run(train_op) print(i,end=' ') tmp = w.eval() for j in tmp: print(j,end=' ') print(b.eval()) if __name__ == "__main__": import tensorflow as tf myregression()
新手程式碼,如有不當,請多指正
運算結果:在10000次訓練之後,引數很好的向正確值收斂