【caffe學習筆記之4】利用MATLAB介面執行cifar資料集
【前期準備工作】
1. 確保模型訓練成功,生成模型檔案:cifar10_quick_iter_4000.caffemodel及均值檔案:mean.binaryproto。注意,此處一定是生成caffemodel格式的模型檔案,而非.h5模型檔案,否則會導致Matlab執行崩潰。如何生成caffemodel檔案請參考上篇帖子。
也可以利用Matlab生成cifar10_quick_iter_4000.caffemodel,方法是進入caffe根目錄,例如我的電腦上為D:\caffe-master\caffe-master,然後在matlab中執行以下命令,即可對模型進行訓練:
solver = caffe.Solver('./examples/cifar10/cifar10_quick_solver.prototxt'); solver.solve()
2. 在caffe-master\matlab路徑下新建cifar資料夾用於案例除錯
3. 拷貝classification_demo.m檔案到cifar資料夾下,並更名為classification_cifar.m
【基於mean.binaryproto檔案生成.mat 檔案】
在matlab command line中輸入以下命令,對mean.binaryproto檔案進行轉換:
於是在matlab/cifar資料夾下生成了image_mean.mat檔案mean_file = 'D:\caffe-master\caffe-master\examples\cifar10\test\mean.binaryproto'; image_mean = caffe_('read_mean', mean_file); save 'D:\caffe-master\caffe-master\matlab\cifar\image_mean.mat' image_mean
【對classification_cifar.m檔案進行修改】
1. 修改dir路徑、model路徑和weight路徑:
2. 修改prepare.image()函式
修改後的classification_cifar.m檔案程式碼:
function [scores, maxlabel] = classification_cifar(im, use_gpu) % Add caffe/matlab to you Matlab search PATH to use matcaffe if exist('../+caffe', 'dir') addpath('..'); else error('Please run this demo from caffe/matlab/demo'); end % Set caffe mode if exist('use_gpu', 'var') && use_gpu caffe.set_mode_gpu(); gpu_id = 0; % we will use the first gpu in this demo caffe.set_device(gpu_id); else caffe.set_mode_cpu(); end % Initialize the network using BVLC CaffeNet for image classification % Weights (parameter) file needs to be downloaded from Model Zoo. model_dir = '../../examples/cifar10/'; net_model = [model_dir 'cifar10_quick.prototxt']; net_weights = [model_dir 'cifar10_quick_iter_4000.caffemodel']; phase = 'test'; % run with phase test (so that dropout isn't applied) if ~exist(net_weights, 'file') error('Please download CaffeNet from Model Zoo before you run this demo'); end % Initialize a network net = caffe.Net(net_model, net_weights, phase); if nargin < 1 % For demo purposes we will use the cat image fprintf('using caffe/examples/images/cat.jpg as input image\n'); im = imread('../../examples/images/cat.jpg'); end % prepare oversampled input % input_data is Height x Width x Channel x Num tic; input_data = {prepare_image(im)}; toc; % do forward pass to get scores % scores are now Channels x Num, where Channels == 1000 tic; % The net forward function. It takes in a cell array of N-D arrays % (where N == 4 here) containing data of input blob(s) and outputs a cell % array containing data from output blob(s) scores = net.forward(input_data); toc; scores = scores{1}; scores = mean(scores, 2); % take average scores over 10 crops [~, maxlabel] = max(scores); % call caffe.reset_all() to reset caffe caffe.reset_all(); % ------------------------------------------------------------------------ function im_data = prepare_image(im) % ------------------------------------------------------------------------ % caffe/matlab/+caffe/imagenet/ilsvrc_2012_mean.mat contains mean_data that % is already in W x H x C with BGR channels d = load('D:\caffe-master\caffe-master\matlab\cifar\image_mean.mat'); mean_data = d.mean_data; IMAGE_DIM = 32; % Convert an image returned by Matlab's imread to im_data in caffe's data % format: W x H x C with BGR channels im_data = im(:, :, [3, 2, 1]); % permute channels from RGB to BGR im_data = permute(im_data, [2, 1, 3]); % flip width and height im_data = single(im_data); % convert from uint8 to single im_data = imresize(im_data, [IMAGE_DIM IMAGE_DIM], 'bilinear'); % resize im_data im_data = im_data - mean_data; % subtract mean_data (already in W x H x C, BGR)
【模型測試】
編寫test.m檔案,用於模型測試,test.m檔案程式碼:
clear;clc
im = imread('D:\caffe-master\caffe-master\examples\images\cat.jpg');
[scores, maxlabel] = classification_cifar(im,0)
index = importdata('synset_words.txt');
name = index(maxlabel);
figure;imshow(im);
str=strcat('分類結果:',name,' 得分:',num2str(max(scores)));
title(str);
使用上述命令完成模型測試,並對貓做出了正確分類:
【檔案下載】
上述資料夾中的4個檔案:classification.m、test.m、image_mean.mat、synset_words.txt打包下載地址:
訓練的cifar10_quick_iter_4000.caffemodel檔案下載地址: