程式人生 > 深度學習之 Python 指令碼訓練 Keras MNIST 數字識別模型

深度學習之Python 指令碼訓練keras mnist 數字識別模型

本指令碼用於訓練 Keras 的 MNIST 數字識別模型。訓練部分以前發過了,今天把預測也實現了。

# Larger CNN for the MNIST Dataset
# 2.Negative dimension size caused by subtracting 5 from 1 for 'conv2d_4/convolution' (op: 'Conv2D') with input shapes
# 3.UserWarning: Update your `Conv2D` call to the Keras 2 API: http://blog.csdn.net/johinieli/article/details/69222956
# 4.Error when checking input: expected conv2d_1_input to have shape (None, 28, 28, 1) but got array with shape (60000, 1, 28, 28)

# talk to wumi,you good .

# python 3.5.4
# keras.__version__  : 2.0.6
# tensorflow 1.2.1
# theano 0.10.0beta1

# good blog
# http://blog.csdn.net/shizhengxin123/article/details/72383728
# http://www.360doc.com/content/17/0415/12/1489589_645772879.shtml

# recommended alternative framework: http://tflearn.org/examples/

import numpy
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import Flatten
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import MaxPooling2D
from keras.utils import np_utils
import matplotlib.pyplot as plt
from keras.constraints import maxnorm
from keras.optimizers import SGD
from keras.preprocessing import image
import skimage.io



# Fix NumPy's random seed for reproducibility.
# NOTE(review): only NumPy is seeded here; whether TensorFlow-side randomness
# (weight initialisation, dropout) is also repeatable depends on the Keras/TF
# versions in use -- verify if exact reproducibility matters.
seed = 7
numpy.random.seed(seed)

# Load MNIST: 28x28 grayscale digit images (pixel values 0-255) and their
# integer labels, already split into training and test sets.
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Preview the first training image. plt.show() typically blocks the script
# until the figure window is closed.
plt.subplot(221)
plt.imshow(X_train[0], cmap=plt.get_cmap('gray'))
plt.show()

# Reshape to channels-last [samples][height][width][channels], matching the
# model's input_shape=(28, 28, 1) under the TensorFlow backend.
# The channels-first layout (samples, 1, 28, 28) is what produced error 4 in the
# header ("expected conv2d_1_input to have shape (None, 28, 28, 1)").
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1).astype('float32')

# Normalize inputs from 0-255 to 0-1.
X_train = X_train / 255
X_test = X_test / 255

# One-hot encode the labels. Both splits are encoded with the same width
# (largest label seen in either split + 1) so the two encodings and
# num_classes always agree, even if one split happened to lack the highest
# digit. For MNIST this is 10.
num_classes = int(max(y_train.max(), y_test.max())) + 1
y_train = np_utils.to_categorical(y_train, num_classes)
y_test = np_utils.to_categorical(y_test, num_classes)


###raw
# define the larger model
def larger_model(input_shape=(28, 28, 1), n_classes=None):
    """Build and compile the "larger" CNN used for MNIST digit classification.

    Architecture:
        Conv2D(30, 5x5, relu) -> MaxPool(2x2) -> Dropout(0.4)
        -> Conv2D(15, 3x3, relu) -> MaxPool(2x2) -> Dropout(0.4)
        -> Flatten -> Dense(128, relu) -> Dropout(0.4)
        -> Dense(50, relu) -> Dropout(0.4) -> Dense(n_classes, softmax)

    Args:
        input_shape: shape of a single input image in channels-last order.
            Defaults to (28, 28, 1), a single-channel 28x28 MNIST image.
        n_classes: number of output classes. Defaults to None, which falls
            back to the module-level ``num_classes`` (the original behaviour).

    Returns:
        A compiled keras ``Sequential`` model using categorical cross-entropy
        loss, the Adam optimizer and an accuracy metric.
    """
    if n_classes is None:
        # Backward-compatible default: use the width of the one-hot labels
        # computed at module level.
        n_classes = num_classes

    model = Sequential()
    # 'valid' padding shrinks a 28x28 input to 24x24 with the 5x5 kernel.
    model.add(Conv2D(30, (5, 5), padding='valid', input_shape=input_shape, activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.4))
    model.add(Conv2D(15, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.4))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.4))
    model.add(Dense(50, activation='relu'))
    model.add(Dropout(0.4))
    model.add(Dense(n_classes, activation='softmax'))

    # optimizer: how the weights are updated; loss: the objective minimised
    # during training (categorical cross-entropy matches the one-hot labels).
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model


# Build the CNN and train it. The test split is passed as validation data so
# the per-epoch log also reports held-out loss/accuracy.
# fit() returns a History object whose .history attribute records how the loss
# and the other metrics (including validation metrics, when validation data is
# given) change per epoch; that return value is not kept here.
model = larger_model()
model.fit(X_train, y_train,
          batch_size=200,
          epochs=10,  # 200 epochs (Keras 1 style `nb_epoch=200`) took far too long
          validation_data=(X_test, y_test),
          verbose=2)

# Final evaluation: scores[1] is the accuracy metric configured at compile
# time, reported below as an error percentage.
scores = model.evaluate(X_test, y_test, verbose=0)
error_rate = 100 - scores[1] * 100
print("Large CNN Error: %.2f%%" % error_rate)

# Write the trained model to a single HDF5 file, then release it.
model.save(r'D:\works\jetBrians\PycharmProjects\tryPicture\my_model.h5')
del model

# reload the model
# keras.models.load_model returns a compiled model
# identical to the previous one
# (Sequential has no load_model method; use the module-level function instead)
# from keras.models import load_model
# model = load_model('D:\\works\\jetBrians\\PycharmProjects\\tryPicture\\my_model.h5')

# https://gist.github.com/ageitgey/a40dded08e82e59724c70da23786bbf0

# write a number in a picture
# predict numbers

#image_path = './lena.jpg'
# method 1
# load pic
#img = image.load_img(image_path, target_size=(28, 28))
# handle pic
#x = image.img_to_array(img)
#x = numpy.expand_dims(x, axis=0)
#x = preprocess_input(x)

# method2
#img2 = skimage.io.imread(image_path, as_grey=True)
#skimage.io.imshow(img2)
#plt.show()
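# NOTE: this reshape only works if the image is already 28x28 pixels; the pixels
# should also be scaled to 0-1 like the training data before predicting.
#img2 = numpy.reshape(img2, (1, 28, 28, 1)).astype('float32')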
# 對數字進行預測
#https://baijiahao.baidu.com/s?id=1574962680356106&wfr=spider&for=pc
# NOTE: `model` must be reloaded first -- it is deleted right after saving.
#predict = model.predict(img2, verbose=0)
#result = model.predict_classes(img2, verbose=0)
#print(predict[0])
#print(result[0])

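# Warning printed at startup: "The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations."
# It is informational only: the prebuilt TensorFlow binary does not use these CPU
# extensions, so CPU training runs slower than a locally optimized build would.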


原來資料樣式 




=================訓練log

D:\applications\Anaconda3\python.exe D:/works/jetBrians/PycharmProjects/tryPicture/trainModel/TrainModel.py
Using TensorFlow backend.
Train on 60000 samples, validate on 10000 samples
Epoch 1/10
62s - loss: 0.8830 - acc: 0.7027 - val_loss: 0.1566 - val_acc: 0.9545
Epoch 2/10
56s - loss: 0.3130 - acc: 0.9078 - val_loss: 0.0955 - val_acc: 0.9712
Epoch 3/10
61s - loss: 0.2342 - acc: 0.9340 - val_loss: 0.0737 - val_acc: 0.9763
Epoch 4/10
58s - loss: 0.1924 - acc: 0.9458 - val_loss: 0.0643 - val_acc: 0.9802
Epoch 5/10
60s - loss: 0.1678 - acc: 0.9534 - val_loss: 0.0541 - val_acc: 0.9848
Epoch 6/10
53s - loss: 0.1541 - acc: 0.9578 - val_loss: 0.0468 - val_acc: 0.9849
Epoch 7/10
53s - loss: 0.1396 - acc: 0.9617 - val_loss: 0.0464 - val_acc: 0.9852
Epoch 8/10
55s - loss: 0.1303 - acc: 0.9647 - val_loss: 0.0422 - val_acc: 0.9871
Epoch 9/10
52s - loss: 0.1276 - acc: 0.9656 - val_loss: 0.0398 - val_acc: 0.9871
Epoch 10/10
53s - loss: 0.1156 - acc: 0.9680 - val_loss: 0.0370 - val_acc: 0.9876
Large CNN Error: 1.24%

Process finished with exit code 0