# encoding: utf-8
import sys
reload(sys)
sys.setdefaultencoding('utf-8')

import numpy as np
from keras.datasets import mnist
import gc

from keras.models import Sequential, Model
from keras.layers import Input, Dense, Dropout, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.applications.vgg16 import VGG16
from keras.optimizers import SGD


import cv2
import h5py as h5py 
import numpy as np


# The network's output layer has 10 nodes, so each target digit 0-9 is
# converted to a one-hot encoded vector.
def tran_y(y, num_classes=10):
    """Return the one-hot encoding of the class label ``y``.

    Args:
        y: Integer class index. It is used directly as a NumPy index into
            the result, so a negative value counts from the end.
        num_classes: Length of the returned vector. Defaults to 10, one
            slot per digit.

    Returns:
        A 1-D NumPy array of length ``num_classes`` (NumPy's default float
        dtype) that is 0 everywhere except for a 1 at position ``y``.

    Raises:
        IndexError: If ``y`` is out of bounds for ``num_classes``.
    """
    y_ohe = np.zeros(num_classes)
    y_ohe[y] = 1
    return y_ohe


# Side length, in pixels, of the square images fed to the network. On more
# capable hardware (e.g. 32 GB+ of RAM and a GPU with 8 GB+ of video memory)
# this value can be increased. VGG requires at least 48 pixels.
ishape = 48


def _prepare_images(images, size):
    """Convert grayscale images into a float32 batch of 3-channel images.

    Each image is resized to ``size`` x ``size``, expanded from one channel
    to three with ``cv2.COLOR_GRAY2BGR``, and the stacked batch is divided
    by 255.0.

    Args:
        images: Iterable of 2-D grayscale image arrays (presumably 8-bit, as
            loaded from MNIST, so the result lies in [0, 1]).
        size: Target width and height in pixels.

    Returns:
        A float32 array with one resized 3-channel image per input image.
    """
    resized = [
        cv2.cvtColor(cv2.resize(img, (size, size)), cv2.COLOR_GRAY2BGR)
        for img in images
    ]
    # Build the batch in one step rather than concatenating one
    # single-image array per sample.
    batch = np.array(resized, dtype='float32')
    batch /= 255.0
    return batch


(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Train and test images go through the identical preprocessing pipeline.
X_train = _prepare_images(X_train, ishape)
X_test = _prepare_images(X_test, ishape)

# One-hot encode the integer labels; float32 matches the image dtype.
y_train_ohe = np.array([tran_y(label) for label in y_train]).astype('float32')
y_test_ohe = np.array([tran_y(label) for label in y_test]).astype('float32')

# Single-argument parenthesized form prints the same text on Python 2 and 3.
print(X_train.shape)

# VGG16 transfer learning: the ImageNet-pretrained convolutional base is
# frozen and only a newly added dense classifier head is trained on MNIST.

# `time` is used for the elapsed-time report at the end of this section but
# is not imported in the file header, so it is imported here.
import time

# Start of the timed section. No earlier start timestamp exists in this
# script, so the elapsed time reported below covers everything from here on
# (model construction, training and evaluation).
time1 = time.time()

# It often takes several collection passes to fully reclaim memory. If this
# is not enough, restart the interpreter and run the model section on its own.
for _ in range(10):
    gc.collect()

# The network input must match the prepared images exactly, so the shape is
# taken from the training data instead of being hard-coded (a 224x224 input
# would be rejected by fit() when given the smaller resized images).
input_shape = X_train.shape[1:]
ishape = input_shape[0]
model_vgg = VGG16(include_top=False, weights='imagenet', input_shape=input_shape)

# Freeze every pretrained layer so that training only updates the new head.
for layer in model_vgg.layers:
    layer.trainable = False

# New classifier head: two 4096-unit ReLU layers, dropout, and a 10-way
# softmax output (one unit per digit).
model = Flatten()(model_vgg.output)
model = Dense(4096, activation='relu', name='fc1')(model)
model = Dense(4096, activation='relu', name='fc2')(model)
model = Dropout(0.5)(model)
model = Dense(10, activation='softmax', name='prediction')(model)
model_vgg_mnist_pretrain = Model(model_vgg.input, model, name='vgg16_pretrain')

# summary() prints the table itself; wrapping it in print would also print
# its None return value.
model_vgg_mnist_pretrain.summary()

# Only the dense head's parameters are trainable; the summary above reports
# the exact trainable / non-trainable counts.
sgd = SGD(lr=0.05, decay=1e-5)
model_vgg_mnist_pretrain.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

model_vgg_mnist_pretrain.fit(X_train, y_train_ohe, validation_data=(X_test, y_test_ohe), epochs=10, batch_size=64)

# Evaluate the trained model on the test set.
scores = model_vgg_mnist_pretrain.evaluate(X_test, y_test_ohe, verbose=0)

# Print the evaluation result (loss followed by the compiled metrics).
print(scores)

time2 = time.time()
print(u'ok,結束!')
print(u'總共耗時:' + str(time2 - time1) + 's')