
MatConvNet Convolutional Neural Networks (4): Training with Your Own Data

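After trying out the pretrained networks available for download from the MatConvNet website, I recently trained my own network that recognizes red apples on fruit trees. I'll start with a picture of the results. The source code is at https://github.com/YunpengZhai/MATCONVNET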

Update 10/21/2016: the sliding-window code is now on GitHub as well (the files whose names end in **slide).


Below I share what I learned.

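Before reading on, you should have read the official MatConvNet documentation (matconvnet-manual), or at least have a basic understanding of machine-learning concepts and of how convolutional neural networks work.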

Now, on to the main topic.

Building your own network involves three parts:

1. Prepare the data.

2. Design the network architecture.

3. Set the parameters and train the network on the data.

I. Preparing the data

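The data are stored on disk as shown in the figure below: one subfolder per class inside the data directory, each holding that class's images.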


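Next, the images are read in, resized to a common size, split into a training set and a held-out set (which doubles as the validation set during training), and mean-subtracted; the resulting imdb structure is then saved to disk as a .mat file.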
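The loader relies on a simple convention: every subfolder of the data directory is one class, and its position in the sorted folder list becomes the class label. The minimal sketch below builds a throwaway folder tree and derives the labels while explicitly skipping `.`, `..` and plain files, instead of assuming that `.` and `..` are always the first two entries returned by `dir`. The folder names are hypothetical examples, not the author's.

```matlab
% Build a throwaway data directory with one subfolder per class.
% The folder names below are hypothetical examples.
root = tempname;
mkdir(root);
names = {'0_noapple', '1_apple', '2_apples', '3_apples'};
for k = 1:numel(names)
    mkdir(fullfile(root, names{k}));
end
% A stray file at the top level must not become a class.
fclose(fopen(fullfile(root, 'readme.txt'), 'w'));

% Keep only real subfolders, dropping '.' and '..' by name.
entries = dir(root);
isClass = [entries.isdir] & ~ismember({entries.name}, {'.', '..'});
classes = sort({entries(isClass).name});

% Label k corresponds to classes{k}.
for k = 1:numel(classes)
    fprintf('label %d -> %s\n', k, classes{k});
end

% Clean up the throwaway tree.
delete(fullfile(root, 'readme.txt'));
for k = 1:numel(names)
    rmdir(fullfile(root, names{k}));
end
rmdir(root);
```

Naming the folders with a leading digit, as in this sketch, is one way to make the alphabetical order, and therefore the label numbering, predictable.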

%cnn_setup_data.m

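function imdb = cnn_setup_data(datadir)
% Load one subfolder per class from datadir, resize every image to
% inputSize, split each class 80/20 into training (set = 1) and
% held-out (set = 3) images, and subtract the mean image.
inputSize = [64, 64];
trainratio = 0.8;

% Class folders: skip '.', '..' and any plain files
subdir = dir(datadir);
subdir = subdir([subdir.isdir] & ~ismember({subdir.name}, {'.', '..'}));

% Start from an empty 64x64x3x0 array so that end+1 along the
% 4th dimension is 1 for the first image
imdb.images.data = zeros(inputSize(1), inputSize(2), 3, 0, 'single');
imdb.images.labels = [];
imdb.images.set = [];
imdb.meta.sets = {'train', 'val', 'test'};
imdb.meta.classes = {};
image_counter = 0;
for i = 1:length(subdir)
    imdb.meta.classes{i} = subdir(i).name;
    imgfiles = dir(fullfile(datadir, subdir(i).name));
    imgfiles = imgfiles(~[imgfiles.isdir]);   % files only
    imgpercategory_count = length(imgfiles);
    disp([i imgpercategory_count]);           % class label, number of images
    image_counter = image_counter + imgpercategory_count;
    for j = 1:imgpercategory_count
        img = imread(fullfile(datadir, subdir(i).name, imgfiles(j).name));
        if size(img, 3) == 1
            img = repmat(img, [1 1 3]);       % grayscale -> 3 channels
        end
        img = imresize(img, inputSize(1:2));
        imdb.images.data(:,:,:,end+1) = single(img);
        imdb.images.labels(end+1) = i;
        % First 80% of each class (directory order) for training;
        % the rest is held out and later used as the validation set
        if j <= imgpercategory_count * trainratio
            imdb.images.set(end+1) = 1;
        else
            imdb.images.set(end+1) = 3;
        end
    end
end

% Subtract the per-pixel mean image and keep it for test time
dataMean = mean(imdb.images.data, 4);
imdb.images.data = single(bsxfun(@minus, imdb.images.data, dataMean));
imdb.images.data_mean = single(dataMean);
end
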
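Two steps of `cnn_setup_data` are worth seeing in isolation: the per-class split in directory order and the mean image subtracted from every sample. The following is a small sketch on random data (10 made-up 64×64 RGB images of a single class); it uses `<=` in the split test so that exactly 80% of the 10 images land in the training set.

```matlab
% Synthetic stand-in for one class: 10 RGB images of 64x64 pixels.
numImages = 10;
data = single(rand(64, 64, 3, numImages) * 255);
trainratio = 0.8;

% Per-class split: image k (1-based, directory order) goes to the
% training set (1) if k <= trainratio * numImages, otherwise to set 3.
k = 1:numImages;
setIdx = 3 * ones(1, numImages);
setIdx(k <= trainratio * numImages) = 1;

% Per-pixel mean over the 4th (sample) dimension gives a 64x64x3 image,
% which bsxfun subtracts from every sample.
dataMean = mean(data, 4);
data = bsxfun(@minus, data, dataMean);
```

After the subtraction the per-pixel mean over all samples is (numerically) zero. At test time the same `dataMean` has to be subtracted from every new image, which is why it is stored in `imdb.images.data_mean`.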
II. Initializing the network

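This part designs each layer of the network: the layer types, their dimensions, normalization (here, batch normalization), and some training parameters.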

%cnn_mnist_init.m

function net = cnn_mnist_init(varargin)
% CNN_MNIST_INIT  Initialize a LeNet-like CNN (adapted from the MNIST example)
% for 64x64x3 input images and 4 output classes.
opts.batchNormalization = true ;
opts.networkType = 'simplenn' ;
opts = vl_argparse(opts, varargin) ;

rng('default');
rng(0) ;

f=1/100 ;   % scale of the random initial weights
net.layers = {} ;
% conv 5x5, 3 -> 20 channels: 64x64x3 -> 60x60x20
net.layers{end+1} = struct('type', 'conv', ...
                           'weights', {{f*randn(5,5,3,20, 'single'), zeros(1, 20, 'single')}}, ...
                           'stride', 1, ...
                           'pad', 0) ;
% max pool 2x2, stride 2: 60x60x20 -> 30x30x20
net.layers{end+1} = struct('type', 'pool', ...
                           'method', 'max', ...
                           'pool', [2 2], ...
                           'stride', 2, ...
                           'pad', 0) ;
% conv 10x10, 20 -> 50 channels: 30x30x20 -> 21x21x50
net.layers{end+1} = struct('type', 'conv', ...
                           'weights', {{f*randn(10,10,20,50, 'single'),zeros(1,50,'single')}}, ...
                           'stride', 1, ...
                           'pad', 0) ;
% max pool 2x2, stride 2: 21x21x50 -> 10x10x50
net.layers{end+1} = struct('type', 'pool', ...
                           'method', 'max', ...
                           'pool', [2 2], ...
                           'stride', 2, ...
                           'pad', 0) ;
% conv 10x10, 50 -> 500 channels: 10x10x50 -> 1x1x500
net.layers{end+1} = struct('type', 'conv', ...
                           'weights', {{f*randn(10,10,50,500, 'single'),  zeros(1,500,'single')}}, ...
                           'stride', 1, ...
                           'pad', 0) ;
net.layers{end+1} = struct('type', 'relu') ;
% conv 1x1, 500 -> 4 channels: one score per class
net.layers{end+1} = struct('type', 'conv', ...
                           'weights', {{f*randn(1,1,500,4, 'single'), zeros(1,4,'single')}}, ...
                           'stride', 1, ...
                           'pad', 0) ;
net.layers{end+1} = struct('type', 'softmaxloss') ;

% optionally switch to batch normalization: a bnorm layer after each of
% the first three conv layers (indices count the layers already inserted)
if opts.batchNormalization
  net = insertBnorm(net, 1) ;
  net = insertBnorm(net, 4) ;
  net = insertBnorm(net, 7) ;
end

% Meta parameters
net.meta.inputSize = [64 64] ;
net.meta.trainOpts.learningRate = 0.0005 ;
net.meta.trainOpts.numEpochs = 30 ;
net.meta.trainOpts.batchSize = 200 ;

% Fill in default values
net = vl_simplenn_tidy(net) ;

% Switch to DagNN if requested
switch lower(opts.networkType)
  case 'simplenn'
    % done
  case 'dagnn'
    net = dagnn.DagNN.fromSimpleNN(net, 'canonicalNames', true) ;
    net.addLayer('top1err', dagnn.Loss('loss', 'classerror'), ...
      {'prediction', 'label'}, 'error') ;
    net.addLayer('top5err', dagnn.Loss('loss', 'topkerror', ...
      'opts', {'topk', 5}), {'prediction', 'label'}, 'top5err') ;
  otherwise
    assert(false) ;
end

% --------------------------------------------------------------------
function net = insertBnorm(net, l)
% --------------------------------------------------------------------
% Insert a batch-normalization layer right after layer l (a conv layer)
assert(isfield(net.layers{l}, 'weights'));
ndim = size(net.layers{l}.weights{1}, 4);
layer = struct('type', 'bnorm', ...
               'weights', {{ones(ndim, 1, 'single'), zeros(ndim, 1, 'single')}}, ...
               'learningRate', [1 1 0.05], ...
               'weightDecay', [0 0]) ;
net.layers{l}.biases = [] ;
net.layers = horzcat(net.layers(1:l), layer, net.layers(l+1:end)) ;

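Because `insertBnorm` splices a new layer into the list, the indices 1, 4 and 7 refer to positions in the list as it looks after the previous insertions. The sketch below replays the same splice on a plain cell array of layer type names (a stand-in for `net.layers`, not MatConvNet's actual layer structs) to show the resulting order.

```matlab
% Layer types of the plain network, in order (mirrors cnn_mnist_init).
layers = {'conv', 'pool', 'conv', 'pool', 'conv', 'relu', 'conv', 'softmaxloss'};

% Same splice as insertBnorm: put a 'bnorm' right after position l.
insertAfter = @(c, l) horzcat(c(1:l), {'bnorm'}, c(l+1:end));

% Each index points at a conv layer in the list as it stands after
% the previous insertions.
layers = insertAfter(layers, 1);
layers = insertAfter(layers, 4);
layers = insertAfter(layers, 7);
disp(strjoin(layers, ' -> '));
```

With batch normalization the network therefore has 11 layers, and the final 1×1 conv layer is layer 10. Since `vl_simplenn` stores the output of layer i in `res(i+1).x`, the class scores end up in `res(11).x`, the same entry as `res(end-1).x`.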
The resulting network structure:
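The spatial sizes can be checked by hand with the usual output-size formula for convolution and pooling, floor((in + 2·pad - kernel) / stride) + 1, where symmetric padding is assumed (every layer here uses pad = 0). A short sketch that walks the 64×64 input through the conv and pool layers:

```matlab
% Output height/width of a conv or pool layer with symmetric padding.
outSize = @(in, kernel, stride, pad) floor((in + 2*pad - kernel) / stride) + 1;

s = 64;                    % input: 64x64x3
s = outSize(s, 5, 1, 0);   % conv 5x5, 20 filters    -> 60x60x20
s = outSize(s, 2, 2, 0);   % max pool 2x2, stride 2  -> 30x30x20
s = outSize(s, 10, 1, 0);  % conv 10x10, 50 filters  -> 21x21x50
s = outSize(s, 2, 2, 0);   % max pool 2x2, stride 2  -> 10x10x50
s = outSize(s, 10, 1, 0);  % conv 10x10, 500 filters -> 1x1x500
s = outSize(s, 1, 1, 0);   % conv 1x1, 4 filters     -> 1x1x4
fprintf('final spatial size: %dx%d\n', s, s);
```

The 10×10 kernel of the third conv layer exactly covers the 10×10 map left after the second pooling layer, so the network ends in a 1×1×4 score vector. If you change `inputSize`, rerun this check and adjust the kernel sizes so the output is still 1×1.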


III. Training the network

%cnn_mnist.m

function [net, info] = cnn_mnist(varargin)
%CNN_MNIST  Train the CNN on your own images (adapted from the MatConvNet MNIST demo)

run(fullfile(fileparts(mfilename('fullpath')),...
  '..', '..', 'matlab', 'vl_setupnn.m')) ;

opts.batchNormalization = false ;
opts.networkType = 'simplenn' ;
[opts, varargin] = vl_argparse(opts, varargin) ;

sfx = opts.networkType ;
if opts.batchNormalization, sfx = [sfx '-bnorm'] ; end
% Data folder with one subfolder per class (the author's local path; change it)
datadir='E:\學習\機器學習\matconvnet-1.0-beta20\photos\multi-label';
opts.expDir = fullfile(vl_rootnn, 'data', ['mnist-zyp-' sfx]) ;
[opts, varargin] = vl_argparse(opts, varargin) ;

opts.dataDir = fullfile(vl_rootnn, 'data', 'mnist') ;
opts.imdbPath = fullfile(opts.expDir, 'imdb.mat');
opts.train = struct() ;
opts = vl_argparse(opts, varargin) ;
if ~isfield(opts.train, 'gpus'), opts.train.gpus = []; end;

% --------------------------------------------------------------------
%                                                         Prepare data
% --------------------------------------------------------------------

net = cnn_mnist_init('batchNormalization', opts.batchNormalization, ...
                     'networkType', opts.networkType) ;

% Build imdb.mat on the first run and reuse it afterwards
% (delete it when the data change so that it is rebuilt)
if exist(opts.imdbPath, 'file')
  imdb = load(opts.imdbPath) ;
else
  imdb=cnn_setup_data(datadir);
  mkdir(opts.expDir) ;
  save(opts.imdbPath, '-struct', 'imdb') ;
end

% One name per class, in label order (the class folder names)
net.meta.classes.name = imdb.meta.classes ;

% --------------------------------------------------------------------
%                                                                Train
% --------------------------------------------------------------------

switch opts.networkType
  case 'simplenn', trainfn = @cnn_train ;
  case 'dagnn', trainfn = @cnn_train_dag ;
end

% Images with set == 1 are used for training (the cnn_train default);
% the held-out images (set == 3) are passed explicitly as the validation set
[net, info] = trainfn(net, imdb, getBatch(opts), ...
  'expDir', opts.expDir, ...
  net.meta.trainOpts, ...
  opts.train, ...
  'val', find(imdb.images.set == 3)) ;

% Keep the mean image with the network: test images need the same preprocessing
net.meta.data_mean = imdb.images.data_mean;
% Dummy label so that the softmaxloss layer can still run at test time
if strcmp(opts.networkType, 'simplenn')
  net.layers{end}.class = 1 ;
end

% --------------------------------------------------------------------
function fn = getBatch(opts)
% --------------------------------------------------------------------
switch lower(opts.networkType)
  case 'simplenn'
    fn = @(x,y) getSimpleNNBatch(x,y) ;
  case 'dagnn'
    bopts = struct('numGpus', numel(opts.train.gpus)) ;
    fn = @(x,y) getDagNNBatch(bopts,x,y) ;
end

% --------------------------------------------------------------------
function [images, labels] = getSimpleNNBatch(imdb, batch)
% --------------------------------------------------------------------
images = imdb.images.data(:,:,:,batch) ;
labels = imdb.images.labels(1,batch) ;

% --------------------------------------------------------------------
function inputs = getDagNNBatch(opts, imdb, batch)
% --------------------------------------------------------------------
images = imdb.images.data(:,:,:,batch) ;
labels = imdb.images.labels(1,batch) ;
if opts.numGpus > 0
  images = gpuArray(images) ;
end
inputs = {'input', images, 'label', labels} ;
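The training call hands `cnn_train` the set indices and a batch function. The sketch below is a simplified picture of how they are used, on a toy `imdb` with the same fields (10 random images, made-up labels and sets): indices are split by `set`, and each mini-batch is a slice along the 4th dimension, as in `getSimpleNNBatch`. The real `cnn_train` also shuffles the training order and runs the network on each batch; both are left out here, and the batch size of 4 is an arbitrary assumption.

```matlab
% Toy imdb with the same fields as the one built by cnn_setup_data.
imdb.images.data = single(rand(64, 64, 3, 10));
imdb.images.labels = [1 1 1 2 2 2 3 3 4 4];
imdb.images.set = [1 1 3 1 1 3 1 3 1 3];

% Training and validation indices, as passed to the trainer.
trainIdx = find(imdb.images.set == 1);
valIdx = find(imdb.images.set == 3);

% Same slicing as getSimpleNNBatch: pick samples along the 4th dimension.
getBatch = @(imdb, batch) imdb.images.data(:, :, :, batch);

batchSize = 4;
for t = 1:batchSize:numel(trainIdx)
    batch = trainIdx(t : min(t + batchSize - 1, numel(trainIdx)));
    images = getBatch(imdb, batch);
    labels = imdb.images.labels(1, batch);
    fprintf('batch of %d images, labels %s\n', size(images, 4), mat2str(labels));
end
```

The labels are indexed with the same `batch` vector as the images, which is what keeps each image paired with its label.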

IV. Application: a test program
% Train once (uncomment on the first run); afterwards net_bn stays in the workspace
%[net_bn, info_bn] = cnn_mnist('batchNormalization', true);
% imdb.mat holds the mean image used in training
% (the same mean is also stored in net_bn.meta.data_mean)
load('E:\學習\機器學習\matconvnet-1.0-beta20\data\mnist-zyp-simplenn-bnorm\imdb.mat');
im=imread('E:\學習\機器學習\matconvnet-1.0-beta20\photos\QQ截圖20160922172145.png');
im=imresize(im,[64 64]);
imshow(im);
% Same preprocessing as in training: single precision, minus the mean image
im = single(im);
im = im - images.data_mean;
res = vl_simplenn(net_bn, im,[],[],...
                      'accumulate', 0, ...
                      'mode', 'test', ...
                      'backPropDepth', inf, ...
                      'sync', 0, ...
                      'cudnn', 1) ;
% res(i+1).x is the output of layer i: res(end).x is the loss,
% res(end-1).x holds the 4 class scores (res(11) with batch normalization)
scores = squeeze(res(end-1).x);
[bestScore, best] = max(scores);
switch best
    case 1
        title('Result: not an apple');
    case 2
        title('Result: 1 apple');
    case 3
        title('Result: 2 apples');
    case 4
        title('Result: 3 apples');
end

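The `switch` on `best` can also be written as a lookup table, and applying a softmax to the raw scores (the same softmax that `softmaxloss` applies during training) turns them into class probabilities that show how confident the prediction is. A sketch on a made-up score vector; the order of the hypothetical `classNames` list is an assumption and has to match the alphabetical order of the class folders used in training.

```matlab
% Class names in label order (assumed to match the sorted folder names).
classNames = {'not an apple', '1 apple', '2 apples', '3 apples'};

% Stand-in for res(end-1).x: a 1x1x4 score array from the last conv layer.
scores = reshape(single([0.3 2.1 -0.5 0.9]), 1, 1, 4);

s = squeeze(scores);                  % 1x1x4 -> 4x1 column
[bestScore, best] = max(s);

% Softmax, shifted by the maximum score for numerical stability.
e = exp(s - max(s));
p = e / sum(e);
fprintf('Result: %s (p = %.2f)\n', classNames{best}, p(best));
```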
A quick test:



PS: Partway through writing this, I got too lazy to keep adding comments.

Combined with a sliding window, the results look like this:
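The author's sliding-window code is on GitHub; what follows is an independent minimal sketch of the idea, not that code. It cuts a larger image into 64×64 windows (the network's input size) on a regular grid and stacks them along the 4th dimension. The random 200×300 scene and the stride of 32 pixels are arbitrary assumptions. Each window would then be mean-subtracted and classified like the single image in the test program, and windows scored as apples can be marked on the original image.

```matlab
% Hypothetical scene: a 200x300 RGB image with random pixels.
scene = single(rand(200, 300, 3) * 255);
win = 64;      % window size, equal to the network input size
stride = 32;   % step between neighboring windows (assumed value)

% Top-left corners of all windows that fit completely inside the image.
rowStarts = 1:stride:(size(scene, 1) - win + 1);
colStarts = 1:stride:(size(scene, 2) - win + 1);

patches = zeros(win, win, 3, numel(rowStarts) * numel(colStarts), 'single');
n = 0;
for r = rowStarts
    for c = colStarts
        n = n + 1;
        patches(:, :, :, n) = scene(r:r+win-1, c:c+win-1, :);
    end
end
fprintf('%d windows of %dx%d\n', n, win, win);
```

A smaller stride gives denser coverage at the cost of more windows to classify; keeping the top-left corners (`rowStarts`, `colStarts`) makes it easy to map each window's prediction back onto the scene.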