A First Look at TensorFlow: Learning the MNIST Dataset

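MNIST is a dataset of handwritten digits from 0 to 9. It is commonly used as a test case in machine learning, roughly the field's equivalent of Hello World.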

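This program first downloads the data from the web, then trains with gradient descent to minimize the cross-entropy over the samples, and finally reports the accuracy of the trained model.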

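Definition of cross-entropy:

$$H_{y'}(y) = -\sum_i y'_i \log(y_i)$$

Here y is the probability distribution we predict and y' is the actual distribution.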

This quantity measures the gap between what the model has learned and the actual labels.
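To make the definition concrete, here is a minimal plain-Python sketch of the cross-entropy between a one-hot label and two candidate predictions. The function name `cross_entropy` and the three-class example values are assumptions made for illustration; they are not part of the original program.

```python
import math

def cross_entropy(y_true, y_pred):
    """Return -sum_i y'_i * log(y_i) for two discrete distributions."""
    # Terms with y'_i == 0 contribute nothing, so skip them to avoid log(0).
    return -sum(t * math.log(p) for t, p in zip(y_true, y_pred) if t > 0)

# Hypothetical 3-class example: the true class is index 1 (one-hot label).
label = [0.0, 1.0, 0.0]
confident = [0.1, 0.8, 0.1]   # most of the probability mass on the correct class
uncertain = [0.4, 0.2, 0.4]   # little probability mass on the correct class

loss_confident = cross_entropy(label, confident)  # equals -log(0.8)
loss_uncertain = cross_entropy(label, uncertain)  # equals -log(0.2)
print(loss_confident, loss_uncertain)
```

With a one-hot label the sum collapses to the negative log of the probability assigned to the correct class, so the more confident correct prediction gets the smaller loss.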

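import tensorflow.examples.tutorials.mnist.input_data as input_data
import tensorflow as tf

# initialize: download MNIST if needed and build the model graph (TensorFlow 1.x API)
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
x = tf.placeholder(tf.float32, [None, 784])   # flattened 28x28 input images
W = tf.Variable(tf.zeros([784, 10]))          # weights, one column per digit class
b = tf.Variable(tf.zeros([10]))               # per-class bias
y = tf.nn.softmax(tf.matmul(x, W) + b)        # predicted class probabilities
y_ = tf.placeholder(tf.float32, [None, 10])   # one-hot true labels
# cross-entropy, summed over all classes and all examples in the batch
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

# train: 1000 gradient-descent steps, each on a batch of 100 images
# (tf.initialize_all_variables is deprecated in favor of tf.global_variables_initializer)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

# predict: compare the most probable class with the true class on the test set
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))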


The program starts by setting up the graph with placeholders and Variables.

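A placeholder is generally used to feed data into the graph, while a Variable generally holds a value that is updated during learning, such as W and b here.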

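Next, a Session is started and the model is trained. Each of the 1000 training steps uses a batch of 100 randomly selected training images.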
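The training loop can be illustrated without TensorFlow. The following is a rough plain-Python sketch of mini-batch gradient descent on a softmax classifier, run on a small synthetic two-class dataset rather than MNIST. All names (`softmax`, `predict`, `batch_loss`, `sgd_step`), the toy data, the batch size of 10, and the learning rate of 0.05 are assumptions for illustration; this is not how `GradientDescentOptimizer` is implemented internally, only the same kind of update written out by hand.

```python
import math
import random

def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def predict(W, b, x):
    # logits_k = sum_j x_j * W[j][k] + b[k]
    logits = [sum(x[j] * W[j][k] for j in range(len(x))) + b[k]
              for k in range(len(b))]
    return softmax(logits)

def batch_loss(W, b, batch):
    # Summed cross-entropy; labels are class indices instead of one-hot vectors.
    return -sum(math.log(predict(W, b, x)[label]) for x, label in batch)

def sgd_step(W, b, batch, lr):
    """One gradient-descent update on the summed cross-entropy of a batch."""
    gW = [[0.0] * len(b) for _ in W]
    gb = [0.0] * len(b)
    for x, label in batch:
        p = predict(W, b, x)
        for k in range(len(b)):
            # For softmax + cross-entropy, d(loss)/d(logit_k) = p_k - y'_k.
            d = p[k] - (1.0 if k == label else 0.0)
            gb[k] += d
            for j in range(len(x)):
                gW[j][k] += x[j] * d
    for j in range(len(W)):
        for k in range(len(b)):
            W[j][k] -= lr * gW[j][k]
    for k in range(len(b)):
        b[k] -= lr * gb[k]

# Toy data (hypothetical): class 0 clusters near (0, 0), class 1 near (1, 1).
rng = random.Random(0)
data = [([rng.gauss(c, 0.2), rng.gauss(c, 0.2)], c)
        for c in (0, 1) for _ in range(50)]

W = [[0.0, 0.0], [0.0, 0.0]]  # zero-initialized weights
b = [0.0, 0.0]                # zero-initialized biases
loss_before = batch_loss(W, b, data)
for step in range(200):
    batch = rng.sample(data, 10)    # draw a random mini-batch
    sgd_step(W, b, batch, lr=0.05)
loss_after = batch_loss(W, b, data)
print(loss_before, loss_after)
```

Each step only looks at a small batch, which keeps the cost per update low while still moving the parameters in a direction that reduces the loss on average.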

The final run reached an accuracy of 91.49% on the test set.
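The accuracy figure comes from comparing the most probable predicted class with the true class for every test image and averaging the matches. A minimal plain-Python sketch of that calculation follows; the helper names `argmax` and `accuracy` and the four sample predictions are hypothetical, chosen only to mirror the `tf.argmax`, `tf.equal`, and `tf.reduce_mean` steps in the program.

```python
def argmax(values):
    """Index of the largest element (the first one wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])

def accuracy(predictions, labels):
    """Fraction of rows whose predicted class matches the one-hot label."""
    correct = [argmax(p) == argmax(t) for p, t in zip(predictions, labels)]
    return sum(correct) / len(correct)

# Hypothetical predictions for four samples over three classes.
preds = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6], [0.2, 0.5, 0.3], [0.4, 0.35, 0.25]]
truth = [[1, 0, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0]]
print(accuracy(preds, truth))  # 3 of the 4 samples are classified correctly
```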

/home/jcole/PycharmProjects/MNIST-HelloWorld/venv/bin/python /home/jcole/PycharmProjects/MNIST-HelloWorld/venv/MNIST.py
WARNING:tensorflow:From /home/jcole/PycharmProjects/MNIST-HelloWorld/venv/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
WARNING:tensorflow:From /home/jcole/PycharmProjects/MNIST-HelloWorld/venv/MNIST.py:4: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /home/jcole/PycharmProjects/MNIST-HelloWorld/venv/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /home/jcole/PycharmProjects/MNIST-HelloWorld/venv/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:219: retry.<locals>.wrap.<locals>.wrapped_fn (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please use urllib or similar directly.
Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting MNIST_data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From /home/jcole/PycharmProjects/MNIST-HelloWorld/venv/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
WARNING:tensorflow:From /home/jcole/PycharmProjects/MNIST-HelloWorld/venv/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
WARNING:tensorflow:From /home/jcole/PycharmProjects/MNIST-HelloWorld/venv/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
WARNING:tensorflow:From /home/jcole/PycharmProjects/MNIST-HelloWorld/venv/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /home/jcole/PycharmProjects/MNIST-HelloWorld/venv/lib/python3.5/site-packages/tensorflow/python/util/tf_should_use.py:118: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.
2018-04-23 16:27:25.917321: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
0.9149

Process finished with exit code 0