深度學習框架tensorflow學習與應用8(tensorboard網路結構)
阿新 • • 發佈:2018-12-29
#載入資料集 mnist = input_data.read_data_sets("MNIST_data",one_hot=True) #如果沒有就下載,然後以獨熱碼的形式載入,有的話就不下載 #每個批次的大小 batch_size =100 #計算一共有多少個批次 n_batch = mnist.train.num_examples// batch_size with tf.name_scope('input'): #定義兩個placeholder x = tf.placeholder(tf.float32,[None,784],name='x-input') y = tf.placeholder(tf.float32,[None,10],name='y-input') #建立一個簡單的神經網路 with tf.name_scope('layer'): with tf.name_scope('weights'): W = tf.Variable(tf.zeros([784,10]),name='W') with tf.name_scope('biases'): b = tf.Variable(tf.zeros([10]),name="b") with tf.name_scope('wx_plus_b'): wx_plus_b = tf.matmul(x,W)+b with tf.name_scope('softmax'): prediction = tf.nn.softmax(wx_plus_b) with tf.name_scope('lose'): #二次代價函式 loss = tf.reduce_mean(tf.square(y - prediction)) with tf.name_scope('train'): #使用梯度下降法 train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) #初始化變數 init = tf.global_variables_initializer() with tf.name_scope('accuracy'): with tf.name_scope('correct_prefdiction'): #結果存放在一個布林型列表中 correct_prediction =tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一維張量中最大的的值所在的位置 with tf.name_scope('accuracy'): #求準確率 accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) with tf.Session() as sess: sess.run(init) wirter = tf.summary.FileWriter('logs/',sess.graph) for epoch in range(1): for batch in range(n_batch): batch_xs,batch_ys = mnist.train.next_batch(batch_size) sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys}) acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}) #acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}) print("Tter"+str(epoch)+",Testing Accuracy"+str(acc))
實線:資料傳輸
粗細:表示的是兩個節點之間傳輸的標量維度。
使用以下程式碼可以看多更多的點:
for epoch in range(51): for batch in range(n_batch): batch_xs, batch_ys = mnist.train.next_batch(batch_size) # sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys}) summary, _ = sess.run([merged, train_step], feed_dict={x: batch_xs, y: batch_ys}) wirter.add_summary(summary, epoch) acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}) print("Tter" + str(epoch) + ",Testing Accuracy" + str(acc))