1. 程式人生 > >【caffe學習筆記之7】caffe-matlab/python訓練LeNet模型並應用於mnist資料集(2)

【caffe學習筆記之7】caffe-matlab/python訓練LeNet模型並應用於mnist資料集(2)

【案例介紹】

LeNet網路模型是一個用來識別手寫數字的最經典的卷積神經網路,是Yann LeCun在1998年設計並提出的,是早期卷積神經網路中最有代表性的實驗系統之一,其論文是CNN領域第一篇經典之作。本篇部落格詳細介紹基於Matlab、Python訓練lenet手寫模型的案例,作為前幾次caffe深度學習框架的階段性總結。

【生成均值檔案】

接上回,mnist資料集生成leveldb資料庫之後,需要計算圖片均值

在train_leveldb資料夾同級建立mean資料夾,然後在當前目錄下開啟doc介面,輸入以下命令:

compute_image_mean train_leveldb mean/mean.binaryproto --backend leveldb
然後,在mean資料夾下生成mean.binaryproto檔案

【訓練LeNet網路】
訓練網路,有3種方法:

(1)使用可執行程式caffe.exe訓練,在命令提示符下執行caffe.exe train命令,參考之前的帖子:

(2)利用Matlab介面訓練網路,使用solver.solve()  命令,參考之前的帖子

(3)利用Python介面訓練網路,本節進行相關內容的介紹

首先,修改lenet_solver.prototxt與lenet_train_test.prototxt兩個檔案的內容,主要是資料庫的路徑、型別以及求解模式,CPU/GPU

然後開啟python,執行以下指令碼:

import caffe
caffe_root = 'D:/caffe-master/caffe-master/'

import os
os.chdir(caffe_root)
print os.getcwd()

solver = caffe.SGDSolver('./examples/mnist/lenet_solver.prototxt')
solver.solve()

在D:\caffe-master\caffe-master\examples\mnist路徑下生成以下4個檔案:


【均值檔案格式轉換】

使用Caffe的C++介面進行操作時,需要的影象均值檔案是pb格式,例如常見的均值檔名為mean.binaryproto;但在使用Python介面進行操作時,需要的影象均值檔案是numpy格式,例如mean.npy。所以在跨語言進行操作時,需要將mean.binaryproto轉換成mean.npy,轉換程式碼如下:

import caffe
caffe_root = 'D:/caffe-master/caffe-master/'

import os
os.chdir(caffe_root)
print os.getcwd()

import numpy as np  

#%%

MEAN_PROTO_PATH = "./examples/mnist/data/mean/mean.binaryproto"               
MEAN_NPY_PATH = "./examples/mnist/data/mean/mean.npy"                         

blob = caffe.proto.caffe_pb2.BlobProto()           # 建立protobuf blob
data = open(MEAN_PROTO_PATH, 'rb' ).read()         # 讀入mean.binaryproto檔案內容
blob.ParseFromString(data)                         # 解析檔案內容到blob

array = np.array(caffe.io.blobproto_to_array(blob))# 將blob中的均值轉換成numpy格式,array的shape (mean_number,channel, hight, width)
mean_npy = array[0]                                # 一個array中可以有多組均值存在,故需要通過下標選擇其中一組均值
np.save(MEAN_NPY_PATH ,mean_npy)

執行後,在mean資料夾下生成mean.npy檔案:


【CPU實現圖片分類】

執行以下python命令:

# -*- coding: utf-8 -*-

import caffe
caffe_root = 'D:/caffe-master/caffe-master/'

import os
os.chdir(caffe_root)
print os.getcwd()

import numpy as np  
import matplotlib.pyplot as plt
# %%
# Set Caffe to CPU mode and load the net from disk.
caffe.set_mode_cpu()

model_def = caffe_root + 'examples/mnist/lenet.prototxt'  #注意!
model_weights = caffe_root + 'examples/mnist/lenet_iter_5000.caffemodel'

net = caffe.Net(model_def,      # defines the structure of the model
                model_weights,  # contains the trained weights
                caffe.TEST)     # use test mode (e.g., don't perform dropout)

# load the mean ImageNet image (as distributed with Caffe) for subtraction
mu = np.load(caffe_root + 'examples/mnist/data/mean/mean.npy')
mu = mu.mean(1).mean(1)  # average over pixels to obtain the mean (BGR) pixel values 

# %%
# Load an image (that comes with Caffe) and perform the preprocessing we've set up.
# create transformer for the input called 'data'
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})

transformer.set_transpose('data', (2,0,1))  # move image channels to outermost dimension
transformer.set_mean('data', mu)            # subtract the dataset-mean value in each channel
transformer.set_raw_scale('data', 255)      # rescale from [0, 1] to [0, 255]
transformer.set_channel_swap('data', (2,1,0))  # swap channels from RGB to BGR

image = caffe.io.load_image(caffe_root + 'examples/mnist/data/test/TestImage_17.bmp')
transformed_image = transformer.preprocess('data', image)
plt.imshow(image)


# %%
# copy the image data into the memory allocated for the net
net.blobs['data'].data[...] = transformed_image

### perform classification
output = net.forward()
output_prob = output['prob'][0]  # the output probability vector for the first image in the batch
print 'predicted class is:', output_prob.argmax()

結果如下圖所示:


程式中有個地方需要注意:

model_def = caffe_root + 'examples/mnist/lenet.prototxt'  #注意!

lenet.prototxt檔案需要修改一個地方:

input_param { shape: { dim: 64 dim: 1 dim: 28 dim: 28 } }

需要改成

input_param { shape: { dim: 64 dim: 3 dim: 28 dim: 28 } }

這是因為手寫圖片雖然是黑白圖片,但是上篇帖子在資料轉換時,圖片已轉換為RGB3通道格式