1. 程式人生 > >Deeplearning4j 實戰 (10):遷移學習--ImageNet比賽預訓練網路VGG16分類花卉圖片

Deeplearning4j 實戰 (10):遷移學習--ImageNet比賽預訓練網路VGG16分類花卉圖片

在最新發布的Deeplearning4j 0.8.0的版本中,開始支援深度神經網路的遷移學習模型。嚴格來說,這種遷移的方式是一種模型遷移。在具體操作的時候,將一個預先訓練好的神經網路,用新資料集重新訓練網路中的一小部分,從而完成在新資料集上建立的演算法任務,即完成了神經網路的遷移學習。在給出具體的案例之前,先簡單討論下遷移學習的相關內容。

遷移學習是被認為可以解決標註資料不足的情況下訓練模型的問題。舉監督學習的例子,大量標註資料的收集是訓練模型的必要條件。如果標註資料不足或質量不高,那模型的泛化能力會大大下降,原因就在於標註資料的缺失將無法刻畫資料在特徵空間的分佈情況,不準確的分佈自然難以擬合測試資料,預測就不準確了。但很多時候,標註資料本身確實很難大量獲得,那麼是否有其他辦法來解決這種情況下模型泛化能力的問題呢?其中一個可行的方法就是利用遷移學習。試想一下,如果兩個任務的情形比較類似,比如,一個是做各種貓的圖片分類,另一個是做各種獵豹的圖片分類,由於這兩種動物都是貓科動物,很多特徵比如眼睛、鬍鬚、牙齒等都比較相似,所以可以考慮用已經訓練好的分類各種貓的模型,在用少量獵豹的資料重新訓練網路的部分引數後,來分類獵豹。如果用的是卷積神經網路的話,比如,AlexNet,VGG等,那麼前幾層的卷積+池化層就可以認為是提取貓科動物共同特徵的,後幾層全連階層用於提取各自任務中的不同特徵來分類。當然這樣的描述未必準確,但遷移學習的思想可以從這個例子中做些類比。

那麼遷移學習有什麼缺點呢?其實任務遷移的假設、遷移的效果都很難保證。比如,很多情況下,兩個任務之間是否相似,是否可以遷移,這個在理論上比較難界定。還有就是,負遷移的情況時有發生。也就是說,遷移以後演算法效果反而變差了。這些問題有一些研究成果,但在實際生產環境中,還是比較難解決。

更多的遷移學習的資料,可以參考楊強教授的個人主頁:http://www.cse.ust.hk/~qyang/

tutorial/survey性質的文章就可以參考楊強教授的文章:《A Survey on Transfer Learning》

下面就在最新版本Deeplearning4j的基礎上,給出一個遷移學習的例子作為入門之用。例子的主要內容是將ImageNet資料集訓練的分類模型VGG16,遷移到幾種花卉圖片的分類問題中。ImageNet資料集中共有~1000類的圖片集,涵蓋了動物、植物、物品等圖片。這裡,VGG16的模型是事先用Keras訓練好的。我們要做的事情,就是在該模型的基礎上,用新任務中的花卉圖片重新訓練網路的一小部分,從而遷移到新的任務上。這裡,對重新訓練網路的一部分做些解釋:

1. 可以將網路中的若干層神經元連線權重重新訓練

2. 可以將網路中的若干層直接移除,新增新的網路層,從而重新訓練這個新新增的網路層

這就是基於神經網路的遷移學習的一些可行的做法。下面就結合之前說的這兩種策略,給出基於Deeplearning4j的實現程式碼:

首先給出第一種策略的程式碼實現邏輯:

        FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
								.learningRate(3e-5)
							        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
							        .updater(Updater.ADAM)
							        .seed(seed)
							        .build();

        //Construct a new model with the intended architecture and print summary
        ComputationGraph vgg16Transfer = new TransferLearning.GraphBuilder(vgg16)
                                                .fineTuneConfiguration(fineTuneConf)
						.setFeatureExtractor(featureExtractionLayer) //the specified layer and below are "frozen"
						.removeVertexKeepConnections("predictions") //replace the functionality of the final vertex
						.addLayer("predictions",new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
			                        .nIn(4096).nOut(numClasses)
					        .weightInit(WeightInit.DISTRIBUTION)
			                       .dist(new NormalDistribution(0,0.2*(2.0/(4096+numClasses)))) //This weight init dist gave better results than Xavier
					       .activation(Activation.SOFTMAX).build(), "fc2")
					       .build();
        System.out.println(vgg16Transfer.summary());
這裡簡單解釋下程式碼。FineTuneConfiguration是定義重新訓練的一些引數,和訓練整個網路的引數類似。TransferLearning是遷移學習的主要類。removeVertexKeepConnections的作用是保證網路結構,但是那一層的網路權重要重新訓練。setFeatureExtractor的作用是做遷移學習的時候,凍結部分網路引數。之後打印出網路的結構就一目瞭然了:


從圖片中我們看出,setFeatureExtractor設定的引數凍結了包括fc2往下的所有網路層,即這部分網路連線的權重引數在遷移學習訓練的過程中保持不變,即所謂的frozen(圖片中綠框圈出的部分)。而removeVertexKeepConnections的設定保證了fc2這一層和predictions的連線保持不變,即keep connection。因此,可訓練的引數的數量,就是4096*5+5=20485個。其中,5是最後分類的花卉的品種數量,也是bias的數量。以上就是Deeplearning4j提供的第一種遷移學習的策略,保持新新增層和之前網路層的連線的前提下,根據任務要求重新訓練部分網路引數。下面介紹第二種遷移學習的策略:整個移除網路層,不保留該層和之前的連線。程式碼如下:

	        FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
	            .activation(Activation.LEAKYRELU)
	            .weightInit(WeightInit.RELU)
	            .learningRate(5e-5)
	            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
	            .updater(Updater.NESTEROVS)
	            .dropOut(0.5)
	            .seed(seed)
	            .build();

	        //Construct a new model with the intended architecture and print summary
	        //  Note: This architecture is constructed with the primary intent of demonstrating use of the transfer learning API,
	        //        secondary to what might give better results
	        ComputationGraph vgg16Transfer = new TransferLearning.GraphBuilder(vgg16)
	            .fineTuneConfiguration(fineTuneConf)
	            .setFeatureExtractor(featureExtractionLayer) //"block5_pool" and below are frozen
	            .nOutReplace("fc2",1024, WeightInit.XAVIER) //modify nOut of the "fc2" vertex
	            .removeVertexAndConnections("predictions") //remove the final vertex and it's connections
	            .addLayer("fc3",new DenseLayer.Builder().activation(Activation.TANH).nIn(1024).nOut(256).build(),"fc2") //add in a new dense layer
	            .addLayer("newpredictions",new OutputLayer
	                                        .Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
	                                        .activation(Activation.SOFTMAX)
	                                        .nIn(256)
	                                        .nOut(numClasses)
	                                        .build(),"fc3") //add in a final output dense layer,
	                                                        // note that learning related configurations applied on a new layer here will be honored
	                                                        // In other words - these will override the finetune confs.
	                                                        // For eg. activation function will be softmax not RELU
	            .setOutputs("newpredictions") //since we removed the output vertex and it's connections we need to specify outputs for the graph
	            .build();
	        log.info(vgg16Transfer.summary());

從圖中可以看出,可第一種策略不同的是,這一種策略添加了fc3這一層,並且與fc2連線的權重重新訓練,相當於完全移除了fc2這一層。因此可以重新訓練的引數的數量也就變成了最後三層全連階層+輸出層。

以上即為Deeplearning4j現在已經實現了的在深度學習基礎上對遷移學習的支援。

在這裡,我們的任務是將ImageNet比賽資料訓練好的VGG16模型遷移到5中花卉的訓練問題上。這五種花卉的訓練資料集的地址: http://download.tensorflow.org/example_images/flower_photos.tgz

模型的匯入,直接呼叫Deeplearning4j中Keras模型的匯入介面就行了:

        ComputationGraph vgg16 = KerasModelImport.importKerasModelAndWeights("/home/wangongxi/transferlearning/VGG16.json", "/home/wangongxi/transferlearning/vgg16_weights_th_dim_ordering_th_kernels.h5", false);

VGG16模型的下載地址為:


https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_th_dim_ordering_th_kernels.h5

https://raw.githubusercontent.com/deeplearning4j/dl4j-examples/f9da30063c1636e1de515f2ac514e9a45c1b32cd/dl4j-examples/src/main/resources/trainedModels/VGG16.json

最後,再給出圖片讀取和模型訓練以及評估的程式碼邏輯。這一部分和之前部落格的內容非常相似,所以就不再做過多的解釋了。

        FlowerDataSetIterator.setup(batchSize,trainPerc);
        DataSetIterator trainIter = FlowerDataSetIterator.trainIterator();
        DataSetIterator testIter = FlowerDataSetIterator.testIterator();

        Evaluation eval;
        eval = vgg16Transfer.evaluate(testIter);
        System.out.println("Eval stats BEFORE fit.....");
        System.out.println(eval.stats() + "\n");
        testIter.reset();
        System.out.println("Start Training");
        //
        final int numEpoch = Integer.parseInt(args[0]);
        for( int i = 0; i < numEpoch; ++i ){
        	vgg16Transfer.fit(trainIter);
        	System.out.println("Evaluate model at epoch "+ i + " ....");
            eval = vgg16Transfer.evaluate(testIter);
            System.out.println(eval.stats());
            testIter.reset();
        }
        System.out.println("Model build complete");

FlowerDataSetIterator是可以讀取這些花卉圖片的的迭代器包裝類。trainPerc是訓練資料和測試資料的比例。後面就是模型訓練和評估的程式碼了。我們直接來看下50輪左右訓練的結果:

策略一:


策略二:


可以看到,兩種遷移的策略最後可以達到的準確率也就在85%左右,不算非常高。不過這裡面原因也是多方面的,比如這些花本身就比較容易混淆(至少我本人不太善於辨別這些花卉),還有就是我們這些遷移學習的工作都放在的全連線層上面,如果適當重新訓練下卷積層+池化層,也許效果還會更好,這些個工作都留待後續去完成。最後,還要說明的一點是,本次遷移學習的模型訓練全部在GPU上完成,使用的公司的單機4卡的K80機器完成的。我只用了K80單核心來完成,並沒有配置並行訓練。訓練的時長大概在1天左右。GPU的程式碼這裡就不給出了,和我之前專門寫的一篇關於Deeplearning4j+GPU的部落格裡的內容是類似的。

最後做下簡單點的總結。這裡主要講了遷移學習在Deeplearning4j中的應用。從根本上來將,神經網路的遷移學習主要在於固定某些層已經事先訓練好的引數,然後,利用新的資料重新訓練部分新的網路連線權重。由於不是訓練整個網路,因此訓練的引數數量大大減少了。當然遷移學習不一定是基於神經網路的,其他傳統模型經過適當改造也可以適應遷移學習的要求。最後需要再次指出的是,遷移訓練本身的效果一般不會比用訓練資料重新訓練整個網路來得好,尤其是在兩個演算法任務不相似的情況下,這個情況更容易出現。但在沒有條件訓練大網路的情況下,用遷移學習的思想調優部分引數還是非常有價值的!