程式人生 > Tensorflow:Android 呼叫 Tensorflow Mobile 版本 API(1)- 訓練一個網路

Tensorflow:Android呼叫Tensorflow Mobile版本API(1)-訓練一個網路

在這裡,我們訓練一個神經網路來擬合函數 $y = x^2 + 1$(訓練樣本的 $x$ 取自區間 $[0, 2)$)。
程式碼如下:

# y = x^2 + 1

import tensorflow as tf
import numpy as np
import random

def get_batch(size=128):
    """Draw a random mini-batch of samples from y = x^2 + 1.

    Each x is drawn as ``random.random() * 2``, i.e. uniformly from [0, 2).

    Args:
        size: number of samples to draw (a non-positive value yields
            empty arrays).

    Returns:
        A pair ``(xs, ys)`` of 1-D numpy arrays of length ``size``, where
        ``ys[i] == xs[i] * xs[i] + 1``.
    """
    # Inputs are drawn in sample order, so a given random.seed() always
    # reproduces the same batch.
    inputs = [random.random() * 2 for _ in range(size)]
    targets = [value * value + 1 for value in inputs]
    return np.array(inputs), np.array(targets)




# Graph inputs. 'input' here and 'output' in my_dnn() are explicit node
# names, presumably what the Android (TF Mobile) client feeds/fetches --
# keep them stable.
X = tf.placeholder(tf.float32, [None, 1], name='input')
Y = tf.placeholder(tf.float32, [None, 1])


def _dense(inputs, in_dim, out_dim, activation=None, name=None):
    """Build one fully connected layer: ``activation(inputs @ w + b)``.

    Weights and biases are initialised from N(0, 1). The weight variable is
    created before the bias variable. That order determines the
    auto-generated variable names ('Variable', 'Variable_1', ...) stored in
    checkpoints.

    Args:
        inputs: 2-D tensor of shape [batch, in_dim].
        in_dim: input width.
        out_dim: output width.
        activation: optional element-wise op applied after the bias add.
        name: optional name for the bias_add op.

    Returns:
        The layer's output tensor, shape [batch, out_dim].
    """
    w = tf.Variable(tf.random_normal(shape=[in_dim, out_dim], mean=0.0, stddev=1))
    b = tf.Variable(tf.random_normal([out_dim]))
    out = tf.nn.bias_add(tf.matmul(inputs, w), b, name=name)
    if activation is not None:
        out = activation(out)
    return out


def my_dnn(hidden_sizes=(256, 256, 256)):
    """Build the regression MLP on top of the placeholder ``X``.

    Architecture: 1 -> hidden_sizes... (ReLU) -> 1 (linear). The final
    bias_add op is named 'output'.

    Args:
        hidden_sizes: widths of the ReLU hidden layers. The default
            reproduces the 1-256-256-256-1 network.

    Returns:
        The 'output' tensor, shape [batch, 1].
    """
    net = tf.reshape(X, shape=[-1, 1])
    in_dim = 1
    for width in hidden_sizes:
        net = _dense(net, in_dim, width, activation=tf.nn.relu)
        in_dim = width
    return _dense(net, in_dim, 1, name='output')


def train(max_steps=None, loss_threshold=0.0001, batch_size=64,
          learning_rate=0.001, checkpoint_prefix="./1.model"):
    """Train the MLP with Adam on MSE until the batch loss is small enough.

    Every step draws a fresh batch from get_batch() and prints the loss.
    When training stops, a checkpoint is written with ``global_step=step``.

    Each call adds a new copy of the network to the default graph. Run
    train() and eval() in separate processes; otherwise the auto-generated
    variable names will likely not match the checkpoint.

    Args:
        max_steps: optional cap on training steps. ``None`` (default) loops
            until ``loss_threshold`` is reached, which may never happen.
        loss_threshold: stop once a batch's MSE falls below this value.
        batch_size: samples per training batch.
        learning_rate: Adam learning rate.
        checkpoint_prefix: path prefix passed to ``Saver.save``.
    """
    out = my_dnn()
    loss = tf.reduce_mean(tf.square(Y - out))
    train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # initialize_all_variables() is deprecated in favour of this.
        sess.run(tf.global_variables_initializer())
        step = 0
        while max_steps is None or step < max_steps:
            batch_x, batch_y = get_batch(batch_size)
            feed = {X: batch_x.reshape([-1, 1]), Y: batch_y.reshape([-1, 1])}
            _, loss_ = sess.run([train_op, loss], feed_dict=feed)
            print(loss_)
            if loss_ < loss_threshold:
                break
            step += 1
        saver.save(sess, checkpoint_prefix, global_step=step)


def eval(num_samples=100, checkpoint_dir='.'):
    """Restore the latest checkpoint and print predictions next to x^2 + 1.

    The name shadows the builtin ``eval``. It is kept so existing callers
    still work.

    Args:
        num_samples: number of random x values in [0, 2) to evaluate.
        checkpoint_dir: directory searched by ``tf.train.latest_checkpoint``.

    Raises:
        ValueError: if no checkpoint is found in ``checkpoint_dir``.
    """
    out = my_dnn()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        ckpt = tf.train.latest_checkpoint(checkpoint_dir)
        if ckpt is None:
            raise ValueError("no checkpoint found in %r; run train() first"
                             % (checkpoint_dir,))
        saver.restore(sess, ckpt)
        for _ in range(num_samples):
            x = random.random() * 2
            y = sess.run(out, feed_dict={X: np.array([[x]])})
            # sess.run returns a [1, 1] array. Take the scalar explicitly,
            # because formatting an ndim>0 array with %f is deprecated in NumPy.
            print("x=%.5f 正確的y=%.5f 預測的 y=%.5f" % (x, x * x + 1, y.item()))


if __name__ == '__main__':
    import sys

    # Select the mode on the command line: `python <script> train` or
    # `python <script> eval`. Run them in separate processes (see train()).
    mode = sys.argv[1] if len(sys.argv) > 1 else ''
    if mode == 'train':
        train()
    elif mode == 'eval':
        eval()
    else:
        print('usage: python %s train|eval' % sys.argv[0])