1. 程式人生 > >快速入門深度學習(2)遷移學習

快速入門深度學習(2)遷移學習

咱們繼續入門課程系列,這次是關於遷移學習(Transfer Learning)的故事。

    這次咱們要“學習”一把了,針對特定的任務構造自己的分類器了。這次咱們仍然使用AlexNet的網路結構(誰讓它經典呢),訓練這個網路讓它為咱們服務。

    在正式Coding之前,首先了解下什麼是遷移學習。所謂的遷移學習就是指在深度學習中,把一個學習好的深度網路,稍加改造變成自己特有網路的意思,至於這樣做的道理,咱們這裡先不深入探討,只要先記住遷移學習有個很大的好處,就是網路收斂速度快。

實驗準備

Matlab2017b或者更新的版本,AlexNet。

資料準備:為了實驗的一致性,使用Matlab計算機視覺工具箱自帶的資料。

開始程式設計

載入資料

unzip('MerchData.zip');

imds = imageDatastore('MerchData','IncludeSubfolders',true,'LabelSource','foldernames');

[imdsTrain,idmsValidation] =splitEachLabel(imds,0.7,'randomized');

unzip函式的意思是解壓壓縮檔案。執行這一句之後可以看到在當前目錄下多了一個資料夾:


這個資料夾裡面就是本次實驗所使用的資料。為了更方便地組織該資料,我們使用imageDatastore函式來構造一個數據結構,用以管理資料。執行上面一句之後得到來一個imageDatastore資料結構,我們進入當前的工作空間對其進行觀察。

可以看到待使用的資料,被一個數據結構進行了組織,並且使用資料夾的名稱作為了類標籤。我們隨機選擇16個影象用 的方式進行顯示。

numImages= numel(imds.Labels);%統計總數

idx =randperm(numImages,16);  %隨機選擇

figure

for i = 1:16

    subplot(4,4,i)

    I = readimage(imds,idx(i));

    imshow(I)

end

可以看到

我們接下來把影象分為測試集(30%)和訓練集(70%):

[imdsTrain,idmsValidation]= splitEachLabel(imds,0.7,'randomized'

);

資料準備完畢了。

載入AlexNet網路

由於我們這一章講的是遷移學習,所以接下來需要載入已經訓練好的alexnet網路。關於如何載入請參看前一章。

net =alexnet;

修改網路

由於咱們這次只需要識別5個類,所以需要對AlexNet網路進行修改以適應當前的問題。我們這次主要對其進行如下修改:修改全連線層的輸出數量,從原來的1000變為5,其餘保持不變。首先提取出前面的層數,然後使用fullyConnectedLayer構造全連線層,最後完成整個網路的構建。

layersTransfer= net.Layers(1:end-3);

layers =[

    layersTransfer

    fullyConnectedLayer(5,'WeightLearnRateFactor',20,'BiasLearnRateFactor',20)

    softmaxLayer

    classificationLayer];

最後的結果layer就是我們需要的網路結構,此時網路還未經訓練。

訓練網路

訓練網路在Matlab中是一件非常簡單的事情,我們只需要配置好訓練引數就好了:

options = trainingOptions('sgdm',...
'MiniBatchSize',10, ...
'MaxEpochs',6, ...
'InitialLearnRate',1e-4, ...
'ValidationData',idmsValidation, ...
'ValidationFrequency',3, ...
'ValidationPatience',Inf, ...
'Verbose',false, ...
'Plots','training-progress');

關於訓練的引數,咱們以後再詳細介紹,這裡需要了解的一點就是,由於神經網路引數眾多,而且是一個典型的非凸優化問題,所以,訓練的引數選擇相當重要。

netTransfer = trainNetwork(imdsTrain,layers,options);

執行完上面一句就可以得到netTransfer作為遷移網路。

驗證網路

我們使用驗證集去測試神經網路的有效性:

YPred = classify(netTransfer,idmsValidation);
accuracy = mean(YPred == idmsValidation.Labels)

結果表明我們的訓練出來的神經網路具有良好的泛化性。

總結

從上面的程式設計過程中,可以發現Matlab神經網路工具箱已經幫助我們做好了很多工作,我們只需要去設計網路即可,然後訓練即可,把廣大程式設計師從無邊無際的codeing中解放出來。