1. 程式人生 > >Caffemodel資料結構解析與Protocol Buffer技術詳解(C++例項)

Caffemodel資料結構解析與Protocol Buffer技術詳解(C++例項)

Caffe中,資料的讀取、運算、儲存都是採用Google Protocol Buffer來進行的,所以首先來較為詳細的介紹下Protocol Buffer(PB)。

PB是一種輕便、高效的結構化資料儲存格式,可以用於結構化資料序列化,很適合做資料儲存或 RPC 資料交換格式。它可用於通訊協議、資料儲存等領域的語言無關、平臺無關、可擴充套件的序列化結構資料格式。是一種效率和相容性都很優秀的二進位制資料傳輸格式,目前提供了 C++、Java、Python 三種語言的 API。Caffe採用的是C++和Python的API。

接下來,我用一個簡單的例子來說明一下。

使用PB和 C++ 編寫一個十分簡單的例子程式。該程式由兩部分組成。第一部分被稱為Writer,第二部分叫做Reader。Writer 負責將一些結構化的資料寫入一個磁碟檔案,Reader則負責從該磁碟檔案中讀取結構化資料並列印到螢幕上。準備用於演示的結構化資料是HelloWorld,它包含兩個基本資料:

ID,為一個整數型別的資料;

Str,這是一個字串。

首先我們需要編寫一個proto檔案,定義我們程式中需要處理的結構化資料,Caffe是定義在caffe.proto檔案中。在PB的術語中,結構化資料被稱為 Message。proto檔案非常類似java或C語言的資料定義。程式碼清單 1 顯示了例子應用中的proto檔案內容。

清單 1. proto 檔案
 1 package lm; 
 2 
 3 message helloworld 
 4 
 5  { 
 6 
 7     required int32     id = 1;   // ID    
 8 
 9     required string
str = 2; // str 10 11 optional int32 opt = 3; // optional field 12 13 }
View Code

一個比較好的習慣是認真對待proto檔案的檔名。比如將命名規則定於如下: packageName.MessageName.proto

在上例中,package名字叫做 lm,定義了一個訊息helloworld,該訊息有三個成員,型別為int32的id,另一個為型別為string的成員str。optional是一個可選的成員,即訊息中可以不包含該成員,required表明是必須包含該成員。一般在定義中會出現如下三個欄位屬性:

對於required的欄位而言,初值是必須要提供的,否則欄位的便是未初始化的。 在Debug模式的buffer庫下編譯的話,序列化話的時候可能會失敗,而且在反序列化的時候對於該欄位的解析會總是失敗的。所以,對於修飾符為required的欄位,請在序列化的時候務必給予初始化。

對於optional的欄位而言,如果未進行初始化,那麼一個預設值將賦予該欄位,當然也可以指定預設值。

對於repeated的欄位而言,該欄位可以重複多個,谷歌提供的這個 addressbook例子便有個很好的該修飾符的應用場景,即每個人可能有多個電話號碼。在高階語言裡面,我們可以通過陣列來實現,而在proto定義檔案中可以使用repeated來修飾,從而達到相同目的。當然,出現0次也是包含在內的。

寫好proto檔案之後就可以用PB編譯器(protoc)將該檔案編譯成目標語言了。本例中我們將使用C++。假設proto檔案存放在 $SRC_DIR 下面,您也想把生成的檔案放在同一個目錄下,則可以使用如下命令:

1  protoc -I=$SRC_DIR --cpp_out=$DST_DIR $SRC_DIR/addressbook.proto
View Code

命令將生成兩個檔案:

lm.helloworld.pb.h, 定義了C++ 類的標頭檔案;

lm.helloworld.pb.cc,C++類的實現檔案。

在生成的標頭檔案中,定義了一個 C++ 類 helloworld,後面的 Writer 和 Reader 將使用這個類來對訊息進行操作。諸如對訊息的成員進行賦值,將訊息序列化等等都有相應的方法。

如前所述,Writer將把一個結構化資料寫入磁碟,以便其他人來讀取。假如我們不使用 PB,其實也有許多的選擇。一個可能的方法是將資料轉換為字串,然後將字串寫入磁碟。轉換為字串的方法可以使用 sprintf(),這非常簡單。數字 123 可以變成字串”123”。這樣做似乎沒有什麼不妥,但是仔細考慮一下就會發現,這樣的做法對寫Reader的那個人的要求比較高,Reader 的作者必須瞭解Writer 的細節。比如”123”可以是單個數字 123,但也可以是三個數字 1、2 和 3等等。這麼說來,我們還必須讓Writer定義一種分隔符一樣的字元,以便Reader可以正確讀取。但分隔符也許還會引起其他的什麼問題。最後我們發現一個簡單的Helloworld 也需要寫許多處理訊息格式的程式碼。

如果使用 PB,那麼這些細節就可以不需要應用程式來考慮了。使用PB,Writer 的工作很簡單,需要處理的結構化資料由 .proto 檔案描述,經過上一節中的編譯過程後,該資料化結構對應了一個 C++ 的類,並定義在 lm.helloworld.pb.h 中。對於本例,類名為lm::helloworld。

Writer 需要include該標頭檔案,然後便可以使用這個類了。現在,在Writer程式碼中,將要存入磁碟的結構化資料由一個lm::helloworld類的物件表示,它提供了一系列的 get/set 函式用來修改和讀取結構化資料中的資料成員,或者叫field。

當我們需要將該結構化資料儲存到磁碟上時,類 lm::helloworld 已經提供相應的方法來把一個複雜的資料變成一個位元組序列,我們可以將這個位元組序列寫入磁碟。

對於想要讀取這個資料的程式來說,也只需要使用類 lm::helloworld 的相應反序列化方法來將這個位元組序列重新轉換會結構化資料。這同我們開始時那個“123”的想法類似,不過PB想的遠遠比我們那個粗糙的字串轉換要全面,因此,我們可以放心將這類事情交給PB吧。程式清單 2 演示了 Writer 的主要程式碼。

清單 2. Writer 的主要程式碼
 1  #include "lm.helloworld.pb.h"
 2 
 3  4 
 5  int main(void) 
 6 
 7  { 
 8 
 9   lm::helloworld msg1; 
10 
11   msg1.set_id(101);          //設定id
12 
13   msg1.set_str(“hello”);   //設定str
14 
15   // 向磁碟中寫入資料流fstream 
16 
17   fstream output("./log", ios::out | ios::trunc | ios::binary);  
18 
19   if (!msg1.SerializeToOstream(&output)) { 
20 
21       cerr << "Failed to write msg." << endl; 
22 
23       return -1; 
24 
25   }         
26 
27   return 0; 
28 
29  }
View Code

Msg1 是一個helloworld類的物件,set_id()用來設定id的值。SerializeToOstream將物件序列化後寫入一個fstream流。我們可以寫出Reader程式碼,程式清單3列出了 reader 的主要程式碼。

清單 3. Reader的主要程式碼
 1 #include "lm.helloworld.pb.h" 
 2 
 3  4 
 5  void ListMsg(const lm::helloworld & msg) { 
 6 
 7   cout << msg.id() << endl; 
 8 
 9   cout << msg.str() << endl; 
10 
11  } 
12 
13  int main(int argc, char* argv[]) { 
14 
15   lm::helloworld msg1; 
16 
17   { 
18 
19     fstream input("./log", ios::in | ios::binary); 
20 
21     if (!msg1.ParseFromIstream(&input)) { 
22 
23       cerr << "Failed to parse address book." << endl; 
24 
25       return -1; 
26 
27     } 
28 
29   } 
30 
31   ListMsg(msg1); 
32 
33 34 
35  }
View Code

同樣,Reader 宣告類helloworld的物件msg1,然後利用ParseFromIstream從一個fstream流中讀取資訊並反序列化。此後,ListMsg中採用get方法讀取訊息的內部資訊,並進行列印輸出操作。

執行Writer和Reader的結果如下:

 >writer 
 >reader 
 101 
 Hello

Reader 讀取檔案 log 中的序列化資訊並列印到螢幕上。這個例子本身並無意義,但只要稍加修改就可以將它變成更加有用的程式。比如將磁碟替換為網路 socket,那麼就可以實現基於網路的資料交換任務。而儲存和交換正是PB最有效的應用領域。

到這裡為止,我們只給出了一個簡單的沒有任何用處的例子。在實際應用中,人們往往需要定義更加複雜的 Message。我們用“複雜”這個詞,不僅僅是指從個數上說有更多的 fields 或者更多型別的 fields,而是指更加複雜的資料結構:巢狀 Message,Caffe.proto檔案中定義了大量的巢狀Message。使得Message的表達能力增強很多。程式碼清單 4 給出一個巢狀 Message 的例子。

清單 4. 巢狀 Message 的例子
 1  message Person {
 2   required string name = 1;
 3   required int32 id = 2;        // Unique ID number for this person.
 4   optional string email = 3;
 5   enum PhoneType {
 6     MOBILE = 0;
 7     HOME = 1;
 8     WORK = 2;
 9   }
10  
11   message PhoneNumber {
12     required string number = 1;
13     optional PhoneType type = 2 [default = HOME];
14   }
15   repeated PhoneNumber phone = 4;
16  }
View Code

在 Message Person 中,定義了巢狀訊息 PhoneNumber,並用來定義 Person 訊息中的 phone 域。這使得人們可以定義更加複雜的資料結構。

在Caffe中也是類似於上例中的Writer和Reader去讀寫PB資料的。接下來,具體說明下Caffe中是如何儲存Caffemodel的。在Caffe主目錄下的solver.cpp檔案中的一段程式碼展示了Caffe是如何儲存Caffemodel的,程式碼清單5如下:

清單 5. Caffemodel儲存程式碼
 1 template <typename Dtype>
 2 
 3 void Solver<Dtype>::Snapshot() {
 4 
 5   NetParameter net_param;    // NetParameter為網路引數類
 6 
 7   // 為了中間結果,也會寫入梯度值
 8 
 9   net_->ToProto(&net_param, param_.snapshot_diff());
10 
11   string filename(param_.snapshot_prefix());
12 
13   string model_filename, snapshot_filename;
14 
15   const int kBufferSize = 20;
16 
17   char iter_str_buffer[kBufferSize];
18 
19   // 每訓練完1次,iter_就加1 
20 
21 snprintf(iter_str_buffer, kBufferSize, "_iter_%d", iter_ + 1);
22 
23   filename += iter_str_buffer;
24 
25   model_filename = filename + ".caffemodel"; //XX_iter_YY.caffemodel
26 
27   LOG(INFO) << "Snapshotting to " << model_filename;
28 
29   // 向磁碟寫入網路引數
30 
31   WriteProtoToBinaryFile(net_param, model_filename.c_str());
32 
33   SolverState state;
34 
35   SnapshotSolverState(&state);
36 
37   state.set_iter(iter_ + 1);    //set
38 
39   state.set_learned_net(model_filename);
40 
41   state.set_current_step(current_step_);
42 
43   snapshot_filename = filename + ".solverstate";
44 
45   LOG(INFO) << "Snapshotting solver state to " << snapshot_filename;
46 
47   // 向磁碟寫入網路state
48 
49   WriteProtoToBinaryFile(state, snapshot_filename.c_str());
50 
51 }
View Code

在清單5程式碼中,我們可以看到,其實Caffemodel儲存的資料也就是網路引數net_param的PB,Caffe可以儲存每一次訓練完成後的網路引數,我們可以通過XX.prototxt檔案來進行引數設定。在這裡的 WriteProtoToBinaryFile函式與之前HelloWorld例子中的Writer函式類似,在這就不在貼出。那麼我們只要弄清楚NetParameter類的組成,也就明白了Caffemodel的具體資料構成。在caffe.proto這個檔案中定義了NetParameter類,如程式碼清單6所示。

清單6. Caffemodel儲存程式碼
  1 message NetParameter {
  2 
  3   optional string name = 1;   // 網路名稱
  4 
  5   repeated string input = 3;  // 網路輸入input blobs
  6 
  7   repeated BlobShape input_shape = 8; // The shape of the input blobs
  8 
  9   // 輸入維度blobs,4維(num, channels, height and width)
 10 
 11   repeated int32 input_dim = 4;
 12 
 13   // 網路是否強制每層進行反饋操作開關
 14 
 15 // 如果設定為False,則會根據網路結構和學習率自動確定是否進行反饋操作
 16 
 17   optional bool force_backward = 5 [default = false];
 18 
 19 // 網路的state,部分網路層依賴,部分不依賴,需要看具體網路
 20 
 21   optional NetState state = 6;
 22 
 23   // 是否列印debug log
 24 
 25   optional bool debug_info = 7 [default = false];
 26 
 27   // 網路層引數,Field Number 為100,所以網路層引數在最後
 28 
 29   repeated LayerParameter layer = 100; 
 30 
 31   // 棄用: 用 'layer' 代替
 32 
 33   repeated V1LayerParameter layers = 2;
 34 
 35 }
 36 
 37 // Specifies the shape (dimensions) of a Blob.
 38 
 39 message BlobShape {
 40 
 41   repeated int64 dim = 1 [packed = true];
 42 
 43 }
 44 
 45 message BlobProto {
 46 
 47   optional BlobShape shape = 7;
 48 
 49   repeated float data = 5 [packed = true];
 50 
 51   repeated float diff = 6 [packed = true];
 52 
 53   optional int32 num = 1 [default = 0];
 54 
 55   optional int32 channels = 2 [default = 0];
 56 
 57   optional int32 height = 3 [default = 0];
 58 
 59   optional int32 width = 4 [default = 0];
 60 
 61 }
 62 
 63  
 64 
 65 // The BlobProtoVector is simply a way to pass multiple blobproto instances
 66 
 67 around.
 68 
 69 message BlobProtoVector {
 70 
 71   repeated BlobProto blobs = 1;
 72 
 73 }
 74 
 75 message NetState {
 76 
 77   optional Phase phase = 1 [default = TEST];
 78 
 79   optional int32 level = 2 [default = 0];
 80 
 81   repeated string stage = 3;
 82 
 83 }
 84 
 85 message LayerParameter {
 86 
 87   optional string name = 1;   // the layer name
 88 
 89   optional string type = 2;   // the layer type
 90 
 91   repeated string bottom = 3; // the name of each bottom blob
 92 
 93   repeated string top = 4;    // the name of each top blob
 94 
 95   // The train/test phase for computation.
 96 
 97   optional Phase phase = 10;
 98 
 99   // Loss weight值:float
100 
101   // 每一層為每一個top blob都分配了一個預設值,通常是0或1
102 
103   repeated float loss_weight = 5;
104 
105   // 指定的學習引數
106 
107   repeated ParamSpec param = 6;
108 
109   // The blobs containing the numeric parameters of the layer.
110 
111   repeated BlobProto blobs = 7;
112 
113   // included/excluded.
114 
115   repeated NetStateRule include = 8;
116 
117   repeated NetStateRule exclude = 9;
118 
119   // Parameters for data pre-processing.
120 
121   optional TransformationParameter transform_param = 100;
122 
123   // Parameters shared by loss layers.
124 
125   optional LossParameter loss_param = 101;
126 
127   // 各種型別層引數
128 
129   optional AccuracyParameter accuracy_param = 102;
130 
131   optional ArgMaxParameter argmax_param = 103;
132 
133   optional ConcatParameter concat_param = 104;
134 
135   optional ContrastiveLossParameter contrastive_loss_param = 105;
136 
137   optional ConvolutionParameter convolution_param = 106;
138 
139   optional DataParameter data_param = 107;
140 
141   optional DropoutParameter dropout_param = 108;
142 
143   optional DummyDataParameter dummy_data_param = 109;
144 
145   optional EltwiseParameter eltwise_param = 110;
146 
147   optional ExpParameter exp_param = 111;
148 
149   optional HDF5DataParameter hdf5_data_param = 112;
150 
151   optional HDF5OutputParameter hdf5_output_param = 113;
152 
153   optional HingeLossParameter hinge_loss_param = 114;
154 
155   optional ImageDataParameter image_data_param = 115;
156 
157   optional InfogainLossParameter infogain_loss_param = 116;
158 
159   optional InnerProductParameter inner_product_param = 117;
160 
161   optional LRNParameter lrn_param = 118;
162 
163   optional MemoryDataParameter memory_data_param = 119;
164 
165   optional MVNParameter mvn_param = 120;
166 
167   optional PoolingParameter pooling_param = 121;
168 
169   optional PowerParameter power_param = 122;
170 
171   optional PythonParameter python_param = 130;
172 
173   optional ReLUParameter relu_param = 123;
174 
175   optional SigmoidParameter sigmoid_param = 124;
176 
177   optional SoftmaxParameter softmax_param = 125;
178 
179   optional SliceParameter slice_param = 126;
180 
181   optional TanHParameter tanh_param = 127;
182 
183   optional ThresholdParameter threshold_param = 128;
184 
185   optional WindowDataParameter window_data_param = 129;
186 
187 }
View Code

那麼接下來的一段程式碼來演示如何解析Caffemodel,我解析用的model為MNIST手寫庫訓練後的model,Lenet_iter_10000.caffemodel。

清單7. Caffemodel解析程式碼
 1  #include <stdio.h>
 2  #include <string.h>
 3  #include <fstream>
 4  #include <iostream>
 5  #include "proto/caffe.pb.h"
 6 
 7  using namespace std;
 8  using namespace caffe;
 9 
10  int main(int argc, char* argv[]) 
11  { 
12 
13   caffe::NetParameter msg; 
14 
15   fstream input("lenet_iter_10000.caffemodel", ios::in | ios::binary); 
16   if (!msg.ParseFromIstream(&input)) 
17   { 
18     cerr << "Failed to parse address book." << endl; 
19     return -1; 
20   } 
21   printf("length = %d\n", length);
22   printf("Repeated Size = %d\n", msg.layer_size());
23 
24   ::google::protobuf::RepeatedPtrField< LayerParameter >* layer = msg.mutable_layer();
25   ::google::protobuf::RepeatedPtrField< LayerParameter >::iterator it = layer->begin();
26   for (; it != layer->end(); ++it)
27   {
28     cout << it->name() << endl;
29     cout << it->type() << endl;
30     cout << it->convolution_param().weight_filler().max() << endl;
31   } 
32 
33   return 0;
34  }

View Code

這篇Blog仍然是以Google的官方文件為主線,程式碼例項則完全取自於我們正在開發的一個Demo專案,通過前一段時間的嘗試,感覺這種結合的方式比較有利於培訓和內部的技術交流。還是那句話,沒有最好的,只有