1. 程式人生 > tensorflow 自己手動實現的線性迴歸

tensorflow 自己手動實現的線性迴歸

tensorflow 自己手動實現的線性迴歸

#!/usr/bin/python
# -*- coding:utf-8 -*-

import tensorflow as tf
import os

# Command-line flags. Each DEFINE_* call takes (name, default value, help text).
# Help texts, kept verbatim: max_iter = "number of iterations",
# model_dir = "model path", summary_dir = "graph path",
# checkpoint_dir = "model path".
# NOTE(review): despite its name, checkpoint_dir defaults to a *file* path;
# mylineregression() only uses it in an os.path.exists() check.
tf.app.flags.DEFINE_integer("max_iter", 100, "迭代次數")
tf.app.flags.DEFINE_string("model_dir", "./tmp/ckpt/model", "模型路徑")
tf.app.flags.DEFINE_string("summary_dir", "./tmp/test/", "graph路徑")
tf.app.flags.DEFINE_string("checkpoint_dir", "./tmp/ckpt/checkpoint", "模型路徑")

# Flag-values object; individual values are read as FLAGS.<name>.
FLAGS = tf.app.flags.FLAGS


def mylineregression():
    """Fit y = 0.7 * x + 0.8 with gradient descent and log progress.

    Builds a TensorFlow 1.x graph with these parts:
      * x: a batch of 100 samples from N(0, 1). It is a random op, so every
        ``sess.run`` draws a new batch.
      * y: noise-free targets, y = 0.7 * x + 0.8.
      * Trainable parameters: ``weight`` (1, 1) and scalar ``bias``.
      * Loss: mean squared error, minimised with plain gradient descent
        (learning rate 0.1).

    Side effects:
      * prints the current weight/bias before every training step;
      * writes the graph and the ``losses``/``weight``/``bias`` summaries to
        ``FLAGS.summary_dir``;
      * if ``FLAGS.checkpoint_dir`` exists on disk, restores all variables
        from ``FLAGS.model_dir`` before training.

    Runs ``FLAGS.max_iter`` training steps. Returns None.
    """
    with tf.variable_scope("data"):
        # Synthetic training data.
        x = tf.random_normal([100, 1], 0.0, 1.0)
        y = tf.multiply(x, [[0.7]]) + 0.8
    with tf.variable_scope("model"):
        # Variable names stay auto-generated so existing checkpoints restore.
        weight = tf.Variable(tf.random_normal([1, 1], 0.0, 1.0))
        bias = tf.Variable(0.0)
        # x is (100, 1) and weight is (1, 1), so the element-wise product
        # broadcasts to the same values as tf.matmul(x, weight).
        y_predict = tf.multiply(x, weight) + bias
    with tf.variable_scope("loss"):
        # Mean squared error.
        loss = tf.reduce_mean(tf.square(y - y_predict))
    with tf.variable_scope("optimizer"):
        train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

    init_value = tf.global_variables_initializer()
    saver = tf.train.Saver()

    tf.summary.scalar("losses", loss)
    tf.summary.histogram("weight", weight)
    tf.summary.histogram("bias", bias)
    merged = tf.summary.merge_all()

    with tf.Session() as sess:
        sess.run(init_value)
        # A single writer serves the whole run. It records the graph once,
        # and closing it in `finally` flushes every pending event.
        filwriter = tf.summary.FileWriter(FLAGS.summary_dir, graph=sess.graph)
        try:
            # Load a saved model, overwriting the freshly initialized values.
            # The default checkpoint_dir points at the "checkpoint" state
            # file that saver.save() writes next to the model files.
            if os.path.exists(FLAGS.checkpoint_dir):
                saver.restore(sess, FLAGS.model_dir)
            for i in range(FLAGS.max_iter):
                # weight evaluates to a (1, 1) array. Index out the scalar so
                # "%f" does not depend on NumPy's deprecated size-1-array
                # to float conversion.
                weight_value, bias_value = sess.run([weight, bias])
                print("第%d次訓練引數weight:%f,bias:%f"
                      % (i, weight_value[0][0], bias_value))
                summary = sess.run(merged)
                filwriter.add_summary(summary, i)
                sess.run(train_op)
        finally:
            filwriter.close()
        # Saving is disabled. Uncomment to persist the trained variables, so
        # that the restore branch above finds them on the next run:
        # saver.save(sess, FLAGS.model_dir)


if __name__ == "__main__":
    # Script entry point: print a start-up banner, then run the training demo.
    print("hello")
    mylineregression()