
A Simple Linear Regression Model in TensorFlow


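#!/usr/local/bin/python3
##ljj [1]
##linear regression model

import numpy as np
import tensorflow as tf  # uses the TensorFlow 1.x graph/session API
import matplotlib.pyplot as plt

# Training data: 12 (x, y) samples
x_ = np.array([11, 14, 22, 29, 32, 40, 44, 55, 59, 60, 69, 77], dtype=np.float32)
y_res = np.array([123, 135, 155, 167, 177, 189, 200, 240, 250, 255, 277, 298], dtype=np.float32)

# Trainable weight and bias, both initialised to 1
w = tf.Variable(tf.ones([1]), dtype=tf.float32)
b = tf.Variable(tf.ones([1]), dtype=tf.float32)

# Placeholders for one input value and its observed target
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)

# Model prediction and mean-squared-error loss against the fed target
y_predict = w * x + b
loss = tf.reduce_mean(tf.square(y - y_predict))
train = tf.train.AdamOptimizer(0.7).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # One pass over the data, one Adam step per sample
    for i in range(len(x_)):
        sess.run(train, feed_dict={x: x_[i], y: y_res[i]})
    # Read the parameters once training has finished
    w_, b_ = sess.run([w, b])
    print(w_, b_)

# Scatter the data and draw the fitted line
plt.plot(x_, y_res, '.')
plt.plot(x_, x_ * w_ + b_, '-')
plt.show()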
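The call `tf.train.AdamOptimizer(0.7).minimize(loss)` differentiates the mean-squared-error loss with respect to `w` and `b` and moves both parameters downhill each time `train` is run. The NumPy sketch below illustrates that mechanism without TensorFlow. It is an illustration under stated assumptions, not the author's method: it swaps Adam for plain full-batch gradient descent, standardizes `x` first (a step the script above does not take) so that one learning rate suits both parameters, and the function name `fit_linear_gd`, the learning rate 0.1 and the 500 epochs are hypothetical choices.

```python
import numpy as np

# Same data as the TensorFlow script
x = np.array([11, 14, 22, 29, 32, 40, 44, 55, 59, 60, 69, 77], dtype=np.float64)
y = np.array([123, 135, 155, 167, 177, 189, 200, 240, 250, 255, 277, 298], dtype=np.float64)


def fit_linear_gd(x, y, lr=0.1, epochs=500):
    """Hypothetical helper: full-batch gradient descent (not Adam) on
    mean((w*x + b - y)**2). x is standardized first; the fitted
    parameters are then mapped back to the raw x scale."""
    mu, sigma = x.mean(), x.std()
    xs = (x - mu) / sigma
    w_s, b_s = 1.0, 1.0  # same starting point as tf.ones([1])
    for _ in range(epochs):
        residual = (w_s * xs + b_s) - y
        # Gradients of the mean squared residual w.r.t. w_s and b_s
        grad_w = 2.0 * np.mean(residual * xs)
        grad_b = 2.0 * np.mean(residual)
        w_s -= lr * grad_w
        b_s -= lr * grad_b
    # Undo the standardization: w_s*(x - mu)/sigma + b_s == w*x + b
    w = w_s / sigma
    b = b_s - w_s * mu / sigma
    return w, b


w, b = fit_linear_gd(x, y)
print(w, b)
```

Because the standardized feature has zero mean and unit variance, the two gradients decouple and every step shrinks the remaining error in each parameter by the same constant factor. On the raw `x` values (up to 77) plain gradient descent would need a much smaller learning rate and many more iterations.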

Host environment: MacBook Pro, TensorFlow 1.4

Output:

ljjdeMBP:linear_regression lingjiajun$ ./linear_regression.py
/usr/local/Cellar/python3/3.6.2/Frameworks/Python.framework/Versions/3.6/lib/python3.6/importlib/_bootstrap.py:205: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6
  return f(*args, **kwds)
2018-04-27 22:20:05.963003: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
[ 4.78998518] [ 5.67698431]

The two values printed above are the fitted weight and bias, respectively.
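For comparison, the line that actually minimizes the squared error on these 12 points has a closed form: `w = S_xy / S_xx` and `b = mean(y) - w * mean(x)`, where `S_xy` is the sum of `(x_i - mean(x)) * (y_i - mean(y))` and `S_xx` is the sum of `(x_i - mean(x))**2`. The minimal NumPy sketch below (the helper name `least_squares_line` is hypothetical) evaluates it. It gives roughly w = 2.64 and b = 92.78, far from the printed 4.79 and 5.68, so the trained values above should not be read as the best-fit line.

```python
import numpy as np

x = np.array([11, 14, 22, 29, 32, 40, 44, 55, 59, 60, 69, 77], dtype=np.float64)
y = np.array([123, 135, 155, 167, 177, 189, 200, 240, 250, 255, 277, 298], dtype=np.float64)


def least_squares_line(x, y):
    """Hypothetical helper: closed-form ordinary least squares for y ~ w*x + b."""
    x_mean, y_mean = x.mean(), y.mean()
    # Slope = S_xy / S_xx, intercept passes the line through the means
    w = np.sum((x - x_mean) * (y - y_mean)) / np.sum((x - x_mean) ** 2)
    b = y_mean - w * x_mean
    return w, b


w, b = least_squares_line(x, y)
print(f"w = {w:.2f}, b = {b:.2f}")  # → w = 2.64, b = 92.78
```

`np.polyfit(x, y, 1)` returns the same slope and intercept and is a shorter alternative.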

