1. 程式人生 > >Tensorflow學習筆記:多輸入線性迴歸神經網路

Tensorflow學習筆記:多輸入線性迴歸神經網路

def myregression():
	"""
	實現多維輸入的線性迴歸神經網路
	假設輸入為x = [a,b,c,d,e,f,g],正確答案為:y_true
	則y_true = x1 * a + x2 * b + x3 * c + x4 * d + x5 * e + x6 * f + x7 * g + 1
		其中x取值為[0.1,0.2,0.3,0.4,0.5,0.6,0.7]
	"""
	#1、生成輸入與標準答案資料集,x是一個[100,7]矩陣,應該乘以[7,1]的矩陣得到[100,1]的答案矩陣
	x = tf.random_normal([100,7],mean = 1.75,stddev = 0.5,name = 'input_data_x')
	y_true = tf.matmul(x,[[0.1],[0.2],[0.3],[0.4],[0.5],[0.6],[0.7]]) + 1

	#2、生成線性迴歸引數,計算預測結果
	w = tf.Variable(tf.random_normal([7,1],mean = 2,stddev=0.2,))
	b = tf.Variable(0.0)

	y_predict = tf.matmul(x,w) + b

	#3、計算loss
	loss = tf.reduce_mean(tf.square(y_predict - y_true))

	#4、反向傳播優化引數
	train_op = tf.train.GradientDescentOptimizer(0.04).minimize(loss)

	init_op = tf.global_variables_initializer()

	with tf.Session() as sess:
		sess.run(init_op)
		for i in range(100000):
			sess.run(train_op)
			print(i,end=' ')
			tmp = w.eval()
			for j in tmp:
				print(j,end=' ')
			print(b.eval())

if __name__ == "__main__":
	import tensorflow as tf
	myregression()

新手程式碼,如有不當,請多指正

運算結果:在10000次訓練之後,引數很好的向正確值收斂