TensorFlow (4): MNIST Classification with a CNN

from tensorflow.examples.tutorials.mnist import input_data  # packaged version of the tutorial's input_data module (TF 1.x)
import tensorflow as tf
import numpy as np

mnist = input_data.read_data_sets('data/', one_hot=True)

trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
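
# Images come as flattened 784-dim vectors and labels as one-hot 10-dim vectors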

# Reshape to NHWC: -1 leaves the batch dimension unspecified; 28x28 pixels, 1 channel
trX = trX.reshape(-1, 28, 28, 1)
teX = teX.reshape(-1, 28, 28, 1)

X = tf.placeholder("float", [None, 28, 28, 1])
Y = tf.placeholder("float", [None, 10])

# 3 conv layers, 3 pooling layers, 1 fully connected layer, 1 output layer
def init_weights(shape):
    return tf.Variable(tf.random_normal(shape, stddev=0.01))
w = init_weights([3, 3, 1, 32])
w2 = init_weights([3, 3, 32, 64])
w3 = init_weights([3, 3, 64, 128])
w4 = init_weights([128 * 4 *4, 625])
w_o = init_weights([625, 10])
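
# Note: each 2x2 max-pool with 'SAME' padding halves the spatial size, rounding
# up: 28 -> 14 -> 7 -> 4. The third pool therefore leaves a 4x4 map with 128
# channels, so the flattened input to w4 has 128 * 4 * 4 = 2048 units.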

# Define the model
# X: input data; w..w_o: weights; p_keep_conv, p_keep_hidden: fraction of units kept by dropout
def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
    # First conv + max-pool layer, then dropout
    l1a = tf.nn.relu(tf.nn.conv2d(X, w, strides=[1, 1, 1, 1], padding='SAME'))
    l1 = tf.nn.max_pool(l1a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    l1 = tf.nn.dropout(l1, p_keep_conv)
    
    l2a = tf.nn.relu(tf.nn.conv2d(l1, w2, strides=[1, 1, 1, 1], padding='SAME'))
    l2 = tf.nn.max_pool(l2a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    l2 = tf.nn.dropout(l2, p_keep_conv)
    
    l3a = tf.nn.relu(tf.nn.conv2d(l2, w3, strides=[1, 1, 1, 1], padding='SAME'))
    l3 = tf.nn.max_pool(l3a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    l3 = tf.reshape(l3, [-1, w4.get_shape().as_list()[0]])
    l3 = tf.nn.dropout(l3, p_keep_conv)
    
    # Fully connected layer
    l4 = tf.nn.relu(tf.matmul(l3, w4))
    l4 = tf.nn.dropout(l4, p_keep_hidden)
    
    # Output layer
    pyx = tf.matmul(l4, w_o)
    return pyx

p_keep_conv = tf.placeholder('float')
p_keep_hidden = tf.placeholder('float')
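# Keep probabilities are placeholders so dropout stays active during training
# (0.8 for conv layers, 0.5 for the hidden layer) and is disabled (1.0) at test time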
py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden)

# Define the loss (softmax cross-entropy, averaged over the batch) and the optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1)  # predicted class = index of the largest logit

batch_size = 128  # mini-batch size for training
test_size = 256   # number of random test images scored after each epoch

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    
    for i in range(100):
        training_batch = zip(range(0, len(trX), batch_size),
                            range(batch_size, len(trX)+1, batch_size))
        for start, end in training_batch:
            sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],
                                         p_keep_conv: 0.8, p_keep_hidden: 0.5})
        test_indices = np.arange(len(teX))
        np.random.shuffle(test_indices)
        test_indices = test_indices[0:test_size]
        
        print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
                         sess.run(predict_op, feed_dict={X: teX[test_indices],
                                                         p_keep_conv: 1.0,
                                                         p_keep_hidden: 1.0})))
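
Because each epoch scores only a random 256-image subset, the printed accuracies are noisy (hence the occasional exact 1.0). As a minimal sketch, the full test set can be scored in chunks at the end of the with-block, reusing the session, placeholders, and test_size from above:

    # Score the entire test set in chunks for a stable accuracy estimate
    correct = 0
    for start in range(0, len(teX), test_size):
        end = start + test_size
        preds = sess.run(predict_op, feed_dict={X: teX[start:end],
                                                p_keep_conv: 1.0,
                                                p_keep_hidden: 1.0})
        correct += np.sum(preds == np.argmax(teY[start:end], axis=1))
    print('full test accuracy:', correct / len(teX))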

Results:

0 0.93359375
1 0.9765625
2 0.9765625
3 0.9921875
4 0.984375
5 0.9921875
6 1.0
7 0.98828125
8 0.984375
9 0.98046875
10 0.99609375
11 0.98828125
12 1.0
13 0.98828125
14 0.9921875
15 0.99609375
16 0.9921875
17 0.99609375
18 0.99609375
19 0.98828125
20 0.9921875
21 1.0
22 0.9921875
23 0.9921875
24 0.98828125
25 0.99609375
26 0.99609375
27 0.98046875
28 0.98828125
29 0.9921875
30 1.0
31 1.0
32 0.99609375
33 0.98828125
34 0.984375
35 1.0
36 0.984375
37 0.99609375
38 1.0
39 0.99609375
40 0.9921875
41 0.97265625
42 1.0
43 0.9921875
44 0.99609375
45 0.984375
46 1.0
47 1.0
48 0.98828125
49 0.9765625
50 0.9921875
51 1.0
52 0.98828125
53 0.98828125
54 0.9921875
55 0.99609375
56 1.0
57 0.99609375
58 1.0
59 0.9921875
60 0.99609375
61 0.98828125
62 0.9921875
63 0.9921875
64 0.9921875
65 0.98828125
66 0.99609375
67 0.99609375
68 0.984375
69 1.0
70 0.98828125
71 0.98828125
72 0.99609375
73 1.0
74 1.0
75 0.99609375
76 0.98828125
77 0.9921875
78 0.98828125
79 0.9921875
80 1.0
81 0.99609375
82 0.99609375
83 0.98828125
84 0.984375
85 0.98828125
86 0.99609375
87 0.99609375
88 0.9921875
89 0.99609375
90 0.98828125
91 0.99609375
92 1.0
93 0.9921875
94 0.98046875
95 1.0
96 0.99609375
97 0.984375
98 0.9921875
99 0.99609375