1. 程式人生 > >tensorflow 神經網路解決二分類問題

tensorflow 神經網路解決二分類問題

import tensorflow as tf from numpy.random import RandomState # 定義訓練資料batch大小 batch_size = 8 # 定義神經網路的引數 w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1)) w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1)) # 在shape的一個維度上使用None可以方便使用不大的batch大小,在訓練時需要把資料分 # 成比較小的batch,但是在測試時,可以一次性使用全部的資料。當資料集比較小時這樣比較 # 方便測試,但資料集比價大時,將大量資料放入一個batch可能會導致記憶體溢位。 x = tf.placeholder(tf.float32, shape=(None, 2), name='x-input') y_ = tf.placeholder(tf.float32, shape=(None, 1), name='y-input') # 定義神經網路前向傳播的過程 a = tf.matmul(x, w1) y = tf.matmul(a, w2) # 定義損失函式和反向傳播演算法 cross_entropy = -tf.reduce_mean(     y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))) train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy) # 通過隨機數生成一個模型資料集 rdm = RandomState(1) dataset_size = 128 X = rdm.rand(dataset_size, 2) # 定義規則來給出樣本的標籤。在這裡所有x1+x2<1的樣例都被認為是正樣本(比如零件合格), # 而其他為負樣本(比如零件不合格)。和Tensorflow遊樂場中的表示法不大一樣的地方是, # 這裡使用0來表示負樣本,1來表示正樣本。大部分解決分類問題的神經網路都會採用0和1的表示方法。 Y = [[int(x1+x2 < 1)] for (x1, x2) in X ] # 建立一個會話來執行Tensorflow程式 with tf.Session() as sess:     init_op = tf.initialize_all_variables() # 初始化變數     sess.run(init_op)     print sess.run(w1)     print sess.run(w2)     '''在訓練之前神經網路引數的值為:     ''' # 設定訓練的輪數 STEPS = 5000 for i in range(STEPS):     # 每次選取batch_size個樣本進行訓練。     start = (i * batch_size) % dataset_size     end = min(start+batch_size, dataset_size)     # 通過選取的樣本訓練神經網路並更新引數。     sess.run(train_step,              feed_dict={x: X[start:end], y_: Y[start:end]})     if i % 1000 == 0:         #每隔一段時間計算在所有資料上的交叉熵並輸出。         total_cross_entropy = sess.run(             cross_entropy, feed_dict={x: X, y_: Y})         print ("After %d training step(s), cross entropy on all data is % g" % (i, total_cross_entropy))         '''         輸出結果:         '''     print sess.run(w1)     print sess.run(w2)     '''     在訓練之後神經網路引數的值:     ''' 上面的程式實現了訓練神經網路的全部過程。從中可以總結出訓練神經網路的過程可以分為以下3步: 1、定義神經網路的結構和前向傳播的過程。 2、定義損失函式以及選擇反向傳播優化的演算法。 3、生成會話(tf.Session)並且在訓練資料上反覆執行反向傳播優化演算法。 無論神經網路的結構如何變化,這3個步驟是不變的