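[AI Tutorial] Getting Started with TensorFlow: Simple Linear Model
Introduction
This article walks through the basic TensorFlow workflow using a simple linear model.
Dataset: MNIST
Tools: TensorFlow 1.9.0 + Python 3.6.3
Method: simple linear model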
1、import
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from sklearn.metrics import confusion_matrix
# Print TensorFlow's version.
print(tf.__version__)
2、Load Data
# The MNIST data-set is about 12 MB
# and will be downloaded automatically if it is not located in the given path.
from tensorflow.examples.tutorials.mnist import input_data
data = input_data.read_data_sets("data/MNIST/", one_hot=True)
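The MNIST data-set has now been loaded. It consists of 70,000 images and their associated labels (i.e. the classes of the images). The data-set is split into 3 mutually exclusive subsets; in this article we only use the training-set and the test-set.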
print("Size of:")
print("- Training-set:\t\t{}".format(len(data.train.labels)))
print("- Test-set:\t\t{}".format(len(data.test.labels)))
print("- Validation-set:\t{}".format(len(data.validation.labels)))
The output is:
Size of:
- Training-set:     55000
- Test-set:         10000
- Validation-set:   5000
3、One-Hot Encoding
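The data-set has been loaded with so-called one-hot encoding. This means the labels have been converted from a single number to a vector whose length equals the number of classes. All elements of the vector are zero except the i-th element, which is 1 and means that the class is i. For example, the one-hot encoded labels for the first 5 images in the test-set are: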
data.test.labels[0:5, :]
The output is:
array([[ 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [ 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [ 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [ 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [ 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])
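To make the encoding concrete, here is a minimal plain-Python sketch (no NumPy) of turning an integer class label into a one-hot vector. `one_hot` is a hypothetical helper for illustration only; `input_data.read_data_sets(..., one_hot=True)` already performs this conversion when loading the data.

```python
def one_hot(label, num_classes=10):
    """Hypothetical helper: all zeros except 1.0 at position `label`."""
    if not 0 <= label < num_classes:
        raise ValueError("label must be in [0, num_classes)")
    vector = [0.0] * num_classes
    vector[label] = 1.0
    return vector

# The five rows printed above encode the classes 7, 2, 1, 0 and 4.
for label in [7, 2, 1, 0, 4]:
    print(one_hot(label))
```

Each printed list has a single 1.0, at the index of the class, which matches the rows of the array above.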
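For various comparisons and performance measures we also need the classes as single numbers, so we convert each one-hot encoded vector to a single number by taking the index of its highest element. Note that `class` is a keyword in Python, so we use the name `cls` instead.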
data.test.cls = np.array([label.argmax() for label in data.test.labels])
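We can now see the classes of the first five images in the test-set. Compare them with the one-hot encoded vectors above. For example, the class of the first image is 7, which corresponds to a one-hot encoded vector whose elements are all zero except the element at index 7.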
data.test.cls[0:5]
The output is:
array([7, 2, 1, 0, 4], dtype=int64)
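The argmax step can be sketched in plain Python as below, using the five one-hot rows shown earlier. The hand-written `argmax` is for illustration only; with NumPy the whole conversion can also be done in one vectorized call, `data.test.labels.argmax(axis=1)`, which gives the same result as the list comprehension.

```python
def argmax(values):
    """Index of the largest element (the first one on ties, like np.argmax)."""
    best_index = 0
    for i, v in enumerate(values):
        if v > values[best_index]:
            best_index = i
    return best_index

# The first five one-hot test labels from the array printed above.
one_hot_labels = [
    [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
    [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
    [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
    [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
    [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
]
classes = [argmax(label) for label in one_hot_labels]
print(classes)  # [7, 2, 1, 0, 4]
```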
4、Data dimensions
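The data dimensions are used in several places in the code below. In computer programming it is generally better to use variables and constants than to hard-code a specific number every time it is used, because then the number only has to be changed in one place.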
# We know that MNIST images are 28 pixels in each dimension.
img_size = 28
# Images are stored in one-dimensional arrays of this length.
img_size_flat = img_size * img_size
# Tuple with height and width of images used to reshape arrays.
img_shape = (img_size, img_size)
# Number of classes, one class for each of 10 digits.
num_classes = 10
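Why `img_size_flat` and `img_shape` are interchangeable: a flattened image stores the 28 rows one after another, so pixel (row, col) sits at position `row * img_size + col`. The plain-Python sketch below assumes this row-major layout, which is also NumPy's default order for `reshape`; `flat_index` and `unflatten` are hypothetical helpers written only to show the mapping.

```python
img_size = 28
img_size_flat = img_size * img_size  # 784

def flat_index(row, col, size=img_size):
    """Position of pixel (row, col) in a row-major flattened image."""
    return row * size + col

def unflatten(flat_pixels, size=img_size):
    """Split a flat list of size*size pixels into `size` rows, like reshape((size, size))."""
    if len(flat_pixels) != size * size:
        raise ValueError("expected %d pixels" % (size * size))
    return [flat_pixels[r * size:(r + 1) * size] for r in range(size)]

# Stand-in image whose pixel values are their own flat positions.
flat = list(range(img_size_flat))
image = unflatten(flat)
print(img_size_flat)                  # 784
print(len(image), len(image[0]))      # 28 28
print(image[1][0], flat_index(1, 0))  # 28 28
```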
5、Helper-function for plotting images
This function plots 9 images in a 3x3 grid and writes the true and predicted classes below each image.
def plot_images(images, cls_true, cls_pred=None):
    assert len(images) == len(cls_true) == 9
    # Create figure with 3x3 sub-plots.
    fig, axes = plt.subplots(3, 3)
    fig.subplots_adjust(hspace=0.3, wspace=0.3)
    for i, ax in enumerate(axes.flat):
        # Plot image.
        ax.imshow(images[i].reshape(img_shape), cmap='binary')
        # Show true and predicted classes.
        if cls_pred is None:
            xlabel = "True: {0}".format(cls_true[i])
        else:
            xlabel = "True: {0}, Pred: {1}".format(cls_true[i], cls_pred[i])
        ax.set_xlabel(xlabel)
        # Remove ticks from the plot.
        ax.set_xticks([])
        ax.set_yticks([])
    # Ensure the plot is shown correctly with multiple plots
    # in a single Notebook cell.
    plt.show()
Plot a few images to check that the function above works correctly.
# Get the first images from the test-set.
images = data.test.images[0:9]
# Get the true classes for those images.
cls_true = data.test.cls[0:9]
# Plot the images and labels using our helper-function above.
plot_images(images=images, cls_true=cls_true)
This displays the first nine test images in a 3x3 grid, each labelled with its true class.
6、Placeholder variables
Placeholder variables serve as the input to the graph; we can feed them different values each time we run the graph.
# Define the placeholder variable for the input images.
# None means that the tensor may hold an arbitrary number of images
# with each image being a vector of length img_size_flat.
x = tf.placeholder(tf.float32, [None, img_size_flat])
# Define the placeholder variable for the true labels.
y_true = tf.placeholder(tf.float32, [None, num_classes])
# Define the placeholder variable for the true class.
y_true_cls = tf.placeholder(tf.int64, [None])
7、Variables to be optimized
# The first variable that must be optimized is called weights
# and is defined here as a TensorFlow variable that is initialized with zeros
# and whose shape is [img_size_flat, num_classes],
# so it is a 2-dimensional tensor (or matrix)
# with img_size_flat rows and num_classes columns.
weights = tf.Variable(tf.zeros([img_size_flat, num_classes]))
# The second variable that must be optimized is called biases
# and is defined as a 1-dimensional tensor (or vector) of length num_classes.
biases = tf.Variable(tf.zeros([num_classes]))
8、Model
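# This simple mathematical model multiplies the images
# in the placeholder variable x with the weights and then adds the biases.
# The result has shape [num_images, num_classes].
logits = tf.matmul(x, weights) + biases
# Softmax turns each row of logits into probabilities that sum to 1.
y_pred = tf.nn.softmax(logits)
# The predicted class is the index of the largest probability in each row.
y_pred_cls = tf.argmax(y_pred, axis=1)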
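The three lines of the model can be traced by hand for a single image. The sketch below is plain Python with toy sizes (4 "pixels" and 3 classes instead of 784 and 10) and made-up weight values; it mirrors `tf.matmul(x, weights) + biases`, `tf.nn.softmax` and `tf.argmax(..., axis=1)` for one row, and is not TensorFlow's implementation.

```python
import math

# Toy sizes and made-up values, for illustration only.
x = [1.0, 0.0, 0.5, 0.0]           # one flattened input image
weights = [[0.2, -0.1, 0.0],       # shape [num_pixels][num_classes]
           [0.0, 0.3, 0.1],
           [0.4, 0.0, -0.2],
           [0.0, 0.0, 0.0]]
biases = [0.0, 0.1, 0.0]           # shape [num_classes]

def linear(x, weights, biases):
    """logits[j] = sum_i x[i] * weights[i][j] + biases[j] (one row of x @ W + b)."""
    return [sum(x[i] * weights[i][j] for i in range(len(x))) + biases[j]
            for j in range(len(biases))]

def softmax(logits):
    """Exponentiate and normalize; subtracting the max logit avoids overflow."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def argmax(values):
    """Index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=lambda j: values[j])

logits = linear(x, weights, biases)   # approximately [0.4, 0.0, -0.1]
probs = softmax(logits)
pred_cls = argmax(probs)
print(pred_cls)  # 0
```

In the real graph every row of the `[None, 784]` input goes through this at once. `weights` holds 784 x 10 = 7,840 values and `biases` holds 10, so the model has 7,850 trainable parameters.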
9、Cost-function to be optimized
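# Cross-entropy between the softmax of the logits and the true labels, one value per image.
# (Newer TF 1.x releases warn that this function is deprecated in favour of
# tf.nn.softmax_cross_entropy_with_logits_v2.)
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                        labels=y_true)
# Average the per-image cross-entropies into a single scalar cost.
cost = tf.reduce_mean(cross_entropy)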
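As a rough plain-Python sketch of what this cost computes: for each image the cross-entropy is -sum_j y_j * log(softmax(logits)_j), and `tf.reduce_mean` averages it over the batch. The helper names below are hypothetical, and the log-sum-exp formulation is a standard numerically stable way to write it, not necessarily how TensorFlow implements it internally.

```python
import math

def log_softmax(logits):
    """log(softmax(logits)) computed stably with the log-sum-exp trick."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum_exp for z in logits]

def softmax_cross_entropy(logits, labels):
    """-sum_j labels[j] * log(softmax(logits)[j]) for one example."""
    return -sum(y * lp for y, lp in zip(labels, log_softmax(logits)))

def mean_cost(batch_logits, batch_labels):
    """Average of the per-example cross-entropies, like tf.reduce_mean(cross_entropy)."""
    losses = [softmax_cross_entropy(z, y) for z, y in zip(batch_logits, batch_labels)]
    return sum(losses) / len(losses)

# With all-zero weights and biases every logit is 0, so each of the 10 classes
# gets probability 0.1 and the cost is ln(10), whatever the true label is.
label_7 = [0.0] * 10
label_7[7] = 1.0
print(round(mean_cost([[0.0] * 10], [label_7]), 4))  # 2.3026
```

Because `weights` and `biases` are initialized with zeros in section 7, ln(10) is also the cost of the model before any optimization.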
10、Optimization method
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(cost)
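`GradientDescentOptimizer(learning_rate=0.5).minimize(cost)` computes the gradient of `cost` with respect to `weights` and `biases` by automatic differentiation and moves each variable a step of size 0.5 against it. For this particular model the gradient has a known closed form: the derivative of the mean cross-entropy with respect to each logit is (softmax - one-hot label) / batch_size. The plain-Python sketch below hand-codes that step on a made-up two-class toy problem; it illustrates the update rule and is not TensorFlow's implementation.

```python
import math

def logits_for(x, weights, biases):
    """One row of x @ W + b."""
    return [sum(x[i] * weights[i][j] for i in range(len(x))) + biases[j]
            for j in range(len(biases))]

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def mean_cost(batch_x, batch_y, weights, biases):
    """Mean softmax cross-entropy over the batch (one-hot labels)."""
    total = 0.0
    for x, y in zip(batch_x, batch_y):
        p = softmax(logits_for(x, weights, biases))
        total -= sum(yj * math.log(pj) for yj, pj in zip(y, p) if yj > 0)
    return total / len(batch_x)

def sgd_step(batch_x, batch_y, weights, biases, learning_rate=0.5):
    """One gradient-descent step on mean_cost, updating weights and biases in place."""
    n = len(batch_x)
    grad_w = [[0.0] * len(biases) for _ in weights]
    grad_b = [0.0] * len(biases)
    for x, y in zip(batch_x, batch_y):
        p = softmax(logits_for(x, weights, biases))
        for j in range(len(biases)):
            delta = (p[j] - y[j]) / n  # d(mean cost) / d(logit_j) for this example
            grad_b[j] += delta
            for i in range(len(weights)):
                grad_w[i][j] += x[i] * delta
    # Apply the update only after the full-batch gradient has been accumulated.
    for i in range(len(weights)):
        for j in range(len(biases)):
            weights[i][j] -= learning_rate * grad_w[i][j]
    for j in range(len(biases)):
        biases[j] -= learning_rate * grad_b[j]

# Made-up toy problem: 2 inputs, 2 classes, zero-initialized like the tutorial.
batch_x = [[1.0, 0.0], [0.0, 1.0]]
batch_y = [[1.0, 0.0], [0.0, 1.0]]
weights = [[0.0, 0.0], [0.0, 0.0]]
biases = [0.0, 0.0]
before = mean_cost(batch_x, batch_y, weights, biases)
for _ in range(10):
    sgd_step(batch_x, batch_y, weights, biases, learning_rate=0.5)
after = mean_cost(batch_x, batch_y, weights, biases)
print(round(before, 4), after < before)  # 0.6931 True
```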
11、Performance measures
correct_prediction = tf.equal(y_pred_cls, y_true_cls)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
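`tf.equal` yields one boolean per image, and casting to float turns True/False into 1.0/0.0, so their mean is the fraction of correctly classified images. A plain-Python sketch of the same idea, with made-up class lists and a hypothetical `accuracy` helper:

```python
def accuracy(pred_cls, true_cls):
    """Fraction of positions where the predicted class equals the true class."""
    if len(pred_cls) != len(true_cls) or not pred_cls:
        raise ValueError("need two non-empty lists of equal length")
    correct = [p == t for p, t in zip(pred_cls, true_cls)]  # like tf.equal
    return sum(float(c) for c in correct) / len(correct)    # like cast + reduce_mean

print(accuracy([7, 2, 1, 0, 9], [7, 2, 1, 0, 4]))  # 0.8
```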
12、TensorFlow Run
session = tf.Session()
session.run(tf.global_variables_initializer())
batch_size = 100
def optimize(num_iterations):
    for i in range(num_iterations):
        # Get a batch of training examples.
        # x_batch now holds a batch of images and
        # y_true_batch are the true labels for those images.
        x_batch, y_true_batch = data.train.next_batch(batch_size)
        # Put the batch into a dict with the proper names
        # for placeholder variables in the TensorFlow graph.
        # Note that the placeholder for y_true_cls is not set
        # because it is not used during training.
        feed_dict_train = {x: x_batch,
                           y_true: y_true_batch}
        # Run the optimizer using this batch of training data.
        # TensorFlow assigns the variables in feed_dict_train
        # to the placeholder variables and then runs the optimizer.
        session.run(optimizer, feed_dict=feed_dict_train)
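Each iteration of `optimize` trains on only `batch_size` = 100 images fetched with `data.train.next_batch`. The generator below is a hypothetical plain-Python stand-in that illustrates the general idea of mini-batching (shuffle the example order every epoch, then hand out consecutive slices); it is an assumption for illustration, not the actual implementation of `next_batch`, whose handling of the end of an epoch may differ.

```python
import random

def minibatches(images, labels, batch_size, seed=0):
    """Yield (x_batch, y_batch) pairs forever, reshuffling the order every epoch.

    Examples left over at the end of an epoch (fewer than batch_size) are
    skipped; this is a simplification chosen for the sketch.
    """
    if len(images) != len(labels):
        raise ValueError("images and labels must have the same length")
    if not 0 < batch_size <= len(images):
        raise ValueError("batch_size must be between 1 and the number of examples")
    rng = random.Random(seed)
    order = list(range(len(images)))
    while True:
        rng.shuffle(order)  # new random order at the start of each epoch
        for start in range(0, len(order) - batch_size + 1, batch_size):
            idx = order[start:start + batch_size]
            # Images and labels use the same indices, so the pairs stay aligned.
            yield [images[i] for i in idx], [labels[i] for i in idx]

# Made-up data: 10 one-pixel "images" whose label is the pixel value mod 2.
images = [[float(i)] for i in range(10)]
labels = [i % 2 for i in range(10)]
batches = minibatches(images, labels, batch_size=4)
for step in range(3):
    x_batch, y_batch = next(batches)
    print(step, len(x_batch), len(y_batch))
```

Shuffling once per epoch means every example is seen once per pass over the data, while consecutive batches still look like random samples.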
13、Helper-functions to show performance
feed_dict_test = {x: data.test.images,
                  y_true: data.test.labels,
                  y_true_cls: data.test.cls}
def print_accuracy():
    # Use TensorFlow to compute the accuracy.
    acc = session.run(accuracy, feed_dict=feed_dict_test)
    # Print the accuracy.
    print("Accuracy on test-set: {0:.1%}".format(acc))
def print_confusion_matrix():
    # Get the true classifications for the test-set.
    cls_true = data.test.cls
    # Get the predicted classifications for the test-set.
    cls_pred = session.run(y_pred_cls, feed_dict=feed_dict_test)
    # Get the confusion matrix using sklearn.
    cm = confusion_matrix(y_true=cls_true,
                          y_pred=cls_pred)
    # Print the confusion matrix as text.
    print(cm)
    # Plot the confusion matrix as an image.
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    # Make various adjustments to the plot.
    plt.tight_layout()
    plt.colorbar()
    tick_marks = np.arange(num_classes)
    plt.xticks(tick_marks, range(num_classes))
    plt.yticks(tick_marks, range(num_classes))
    plt.xlabel('Predicted')
    plt.ylabel('True')
    # Ensure the plot is shown correctly with multiple plots
    # in a single Notebook cell.
    plt.show()
def plot_example_errors():
    # Use TensorFlow to get a list of boolean values
    # whether each test-image has been correctly classified,
    # and a list for the predicted class of each image.
    correct, cls_pred = session.run([correct_prediction, y_pred_cls],
                                    feed_dict=feed_dict_test)
    # Negate the boolean array.
    incorrect = np.logical_not(correct)
    # Get the images from the test-set that have been
    # incorrectly classified.
    images = data.test.images[incorrect]
    # Get the predicted classes for those images.
    cls_pred = cls_pred[incorrect]
    # Get the true classes for those images.
    cls_true = data.test.cls[incorrect]
    # Plot the first 9 images.
    plot_images(images=images[0:9],
                cls_true=cls_true[0:9],
                cls_pred=cls_pred[0:9])
def plot_weights():
    # Get the values for the weights from the TensorFlow variable.
    w = session.run(weights)
    # Get the lowest and highest values for the weights.
    # This is used to correct the colour intensity across
    # the images so they can be compared with each other.
    w_min = np.min(w)
    w_max = np.max(w)
    # Create figure with 3x4 sub-plots,
    # where the last 2 sub-plots are unused.
    fig, axes = plt.subplots(3, 4)
    fig.subplots_adjust(hspace=0.3, wspace=0.3)
    for i, ax in enumerate(axes.flat):
        # Only use the weights for the first 10 sub-plots.
        if i < 10:
            # Get the weights for the i'th digit and reshape it.
            # Note that w.shape == (img_size_flat, 10)
            image = w[:, i].reshape(img_shape)
            # Set the label for the sub-plot.
            ax.set_xlabel("Weights: {0}".format(i))
            # Plot the image.
            ax.imshow(image, vmin=w_min, vmax=w_max, cmap='seismic')
        # Remove ticks from each sub-plot.
        ax.set_xticks([])
        ax.set_yticks([])
    # Ensure the plot is shown correctly with multiple plots
    # in a single Notebook cell.
    plt.show()
14、Performance before any optimization
print_accuracy()
plot_example_errors()
15、Performance after 1 optimization iteration
optimize(num_iterations=1)
print_accuracy()
plot_example_errors()
plot_weights()
16、Performance after 10 optimization iterations
# We have already performed 1 iteration.
optimize(num_iterations=9)
print_accuracy()
plot_example_errors()
plot_weights()
17、Performance after 1000 optimization iterations
# We have already performed 10 iterations.
optimize(num_iterations=990)
print_accuracy()
plot_example_errors()
plot_weights()
print_confusion_matrix()
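Reading the confusion matrix: scikit-learn's `confusion_matrix` puts the true class on the rows and the predicted class on the columns, which is why the plot labels the y-axis "True" and the x-axis "Predicted". The diagonal counts correct predictions and every off-diagonal cell counts one kind of mistake. A minimal plain-Python sketch with made-up classes; unlike scikit-learn, which infers the label set from the data unless a `labels` argument is given, this hypothetical helper always returns a `num_classes` x `num_classes` matrix.

```python
def confusion_matrix_counts(true_cls, pred_cls, num_classes=10):
    """cm[i][j] = number of examples whose true class is i and predicted class is j."""
    if len(true_cls) != len(pred_cls):
        raise ValueError("true_cls and pred_cls must have the same length")
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(true_cls, pred_cls):
        cm[t][p] += 1
    return cm

true_cls = [7, 2, 1, 0, 4, 1]
pred_cls = [7, 2, 1, 0, 9, 7]
cm = confusion_matrix_counts(true_cls, pred_cls)
print(cm[4][9], cm[1][7], cm[1][1])  # 1 1 1
```

Here a 4 mistaken for a 9 and a 1 mistaken for a 7 each add one count off the diagonal, while the four correct predictions sit on it.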
Edited by: 張永輝 (Zhang Yonghui)