caffe 原始碼分析【一】: Blob類
阿新 • • 發佈:2018-11-10
-
Blob類的:
//標頭檔案: include\caffe\blob.hpp //cpp檔案: src\caffe\blob.cpp //cu檔案: src/caffe/blob.cu //定義某layer的輸入blobs const vector<Blob<Dtype> *> bottom; //定義某layer輸出blobs const vector<Blob<Dtype> *> top; //獲取blob中不可變的資料指標 const Dtype* bottom_data = bottom[0]->cpu_data(); //獲取blob中可變資料指標 Dtype* top_data = top[0]->mutable_cpu_data(); //獲取blob中不可變的梯度指標 const Dtype * top_diff = top[0]->cpu_diff(); //獲取blob中可變的梯度指標 Dtype * bottom_diff = bottom[0]->bottom_diff(); //獲取blob中資料單元的數量, 等於 Batch * C * H * W const int count = bottom[0]->count(); //獲取BatchSize大小 const int num = bottom[0]->num(); //獲取通道數 const int channels = bottom[0]->channels(); //獲取圖片高度 const int height = bottom[0]->height(); //獲取圖片寬度 const int width = bottom[0]->width(); //獲取指定維度的大小的通用方法,可以使用賦值index num = bottom[0]->shape(0) #第一個維度大小,通常為bath size width = bottom[0]->shape(-1) #最後一個維度大小,通常為width const vector<int> &bottom_shape = bottom[0]->shape(); //得到有多少個維度(axes num) const axes_num = bottom[0]->num_axes(); //讀取具體的資料 const datum = bottom[0]->data_at(0,0,0,0); //獲取batch 0 channel 0 height 0 width 0的資料值 const diff_datum = bottom[0]->diff_at(0,0,0,0) //獲取batch 0 channel 0 height 0 width 0的梯度 vector<int> vector_index; for(int i=0; i<bottom[0]->num_axes(); i++) vector_index.push_back(0); datum = bottom[0]->data_at(vector_index); diff_datum = bottom[0]->diff_at(vector_index); /* * 修改blob的尺寸 * 記憶體不夠時會重新分配記憶體,存在多餘的記憶體則不會釋放 * layer::reshape()後需要執行Net:Farward()或者Net:Reshape()調整整個網路結構,之後才可以呼叫 * Net:Backward(), * */ bottom[0]->Reshape(1,2,3,4); /* *Update()是更新網路中引數設計的blob函式, 引數= 引數+ alpha*引數梯度 *計算公式為: Y = alpha * X + Y * Y= bottom[0]->mutable_cpu_data(); * X= bottom[0]->cpu_diff(); * alpha為梯度下降演算法的超引數 */
BLOB操作例項:
#include<vector> #include<iostream> #include<caffe/blob.hpp> #include<caffe/util/io.hpp> using namespace caffe; using namespace std; void print_blob(Blob<float> *a) { for(int u = 0;u<a->num();u++) for(int v = 0;v<a->channels();v++) for(int w=0;w<a->height();w++) for(int x = 0;x<a->width();x++) //輸出blob的值 cout<<"a["<<u<<"]["<<v<<"]["<<w<<"]["<<x<<"]="<<a->data_at(u,v,w,x) <<endl; } int main(void) { Blob<float> a; BlobProto bp; cout<<"size:"<<a.shape_string()<<endl; a.Reshape(1,2,3,4); cout<<"after:"<<a.shape_string()<<endl; float *p=a.mutable_cpu_data(); float *q=a.mutable_cpu_diff(); for(int i = 0;i<a.count();i++){ p[i]=i; q[i]=a.count() - 1 - i; } //更新blob的資料,主要用於更新網路層引數 a.Update();//diff data combine print_blob(&a); //計算blob data的L1範數 cout<<"ASUM="<<a.asum_data()<<endl; //計算blob data的L2範數 cout<<"SUMSQ="<<a.sumsq_data()<<endl; //儲存網路引數資料 a.ToProto(&bp,true); //生成BlobProto物件 WriteProtoToBinaryFile(bp,"a.blob");//寫檔案 //從檔案中讀取網路引數 BlobProto bp2; ReadProtoFromBinaryFileOrDie("a.blob",&bp2); Blob<float> b; b.FromProto(bp2,true); print_blob(&b); return 0; }
其中blobProto的定義如下:
//src/caffe/proto/caffe.proto檔案重定義 message BlobProto { optional BlobShape shape = 7; repeated float data = 5 [packed = true]; repeated float diff = 6 [packed = true]; repeated double double_data = 8 [packed = true]; repeated double double_diff = 9 [packed = true]; // 4D dimensions -- deprecated. Use "shape" instead. optional int32 num = 1 [default = 0]; optional int32 channels = 2 [default = 0]; optional int32 height = 3 [default = 0]; optional int32 width = 4 [default = 0]; }