1. 程式人生 > >caffe 原始碼分析【一】: Blob類

caffe 原始碼分析【一】: Blob類

  • Blob類的:    

//標頭檔案: include\caffe\blob.hpp
//cpp檔案: src\caffe\blob.cpp
//cu檔案: src/caffe/blob.cu
//定義某layer的輸入blobs
const vector<Blob<Dtype> *> bottom;

//定義某layer輸出blobs
const vector<Blob<Dtype> *> top;

//獲取blob中不可變的資料指標
const Dtype* bottom_data = bottom[0]->cpu_data();

//獲取blob中可變資料指標
Dtype* top_data = top[0]->mutable_cpu_data();

//獲取blob中不可變的梯度指標
const Dtype * top_diff = top[0]->cpu_diff();

//獲取blob中可變的梯度指標
Dtype * bottom_diff = bottom[0]->bottom_diff();

//獲取blob中資料單元的數量, 等於 Batch * C * H * W
const int count = bottom[0]->count();

//獲取BatchSize大小
const int num = bottom[0]->num();

//獲取通道數
const int channels = bottom[0]->channels();

//獲取圖片高度
const int height = bottom[0]->height();

//獲取圖片寬度
const int width = bottom[0]->width();

//獲取指定維度的大小的通用方法,可以使用賦值index
num = bottom[0]->shape(0)       #第一個維度大小,通常為bath size
width = bottom[0]->shape(-1)    #最後一個維度大小,通常為width
const vector<int> &bottom_shape = bottom[0]->shape();

//得到有多少個維度(axes num)
const axes_num = bottom[0]->num_axes();


//讀取具體的資料
const datum = bottom[0]->data_at(0,0,0,0); //獲取batch 0 channel 0 height 0 width 0的資料值
const diff_datum = bottom[0]->diff_at(0,0,0,0)  //獲取batch 0 channel 0 height 0 width 0的梯度    
vector<int> vector_index;
for(int i=0; i<bottom[0]->num_axes(); i++)
    vector_index.push_back(0);
datum = bottom[0]->data_at(vector_index);
diff_datum = bottom[0]->diff_at(vector_index);

/*
 * 修改blob的尺寸
 * 記憶體不夠時會重新分配記憶體,存在多餘的記憶體則不會釋放
 * layer::reshape()後需要執行Net:Farward()或者Net:Reshape()調整整個網路結構,之後才可以呼叫 
 * Net:Backward(),
 *
*/
bottom[0]->Reshape(1,2,3,4);

/*
 *Update()是更新網路中引數設計的blob函式, 引數= 引數+ alpha*引數梯度
 *計算公式為: Y = alpha * X + Y
 *    Y= bottom[0]->mutable_cpu_data();
 *    X= bottom[0]->cpu_diff();
 *    alpha為梯度下降演算法的超引數
*/

BLOB操作例項:

#include<vector>
#include<iostream>
#include<caffe/blob.hpp>
#include<caffe/util/io.hpp>

using namespace caffe;
using namespace std;
void print_blob(Blob<float> *a)
{
    for(int u = 0;u<a->num();u++)
        for(int v = 0;v<a->channels();v++)
            for(int w=0;w<a->height();w++)
                for(int x = 0;x<a->width();x++)
                    //輸出blob的值
                    cout<<"a["<<u<<"]["<<v<<"]["<<w<<"]["<<x<<"]="<<a->data_at(u,v,w,x)                
                      <<endl;
}

  int main(void)
  {
          Blob<float> a;
          BlobProto bp;
          cout<<"size:"<<a.shape_string()<<endl;
           
          a.Reshape(1,2,3,4);
          cout<<"after:"<<a.shape_string()<<endl;

          float *p=a.mutable_cpu_data();
          float *q=a.mutable_cpu_diff();
          for(int i = 0;i<a.count();i++){
                  p[i]=i;
                  q[i]=a.count() - 1 - i;
          }
          //更新blob的資料,主要用於更新網路層引數
          a.Update();//diff data combine
          print_blob(&a);
          //計算blob data的L1範數
          cout<<"ASUM="<<a.asum_data()<<endl;
          //計算blob data的L2範數
          cout<<"SUMSQ="<<a.sumsq_data()<<endl;
          
          //儲存網路引數資料
          a.ToProto(&bp,true);    //生成BlobProto物件
          WriteProtoToBinaryFile(bp,"a.blob");//寫檔案
          
          //從檔案中讀取網路引數
          BlobProto bp2;
          ReadProtoFromBinaryFileOrDie("a.blob",&bp2);
          Blob<float> b;
          b.FromProto(bp2,true);
          print_blob(&b);

          return 0;
  }

其中blobProto的定義如下:

//src/caffe/proto/caffe.proto檔案重定義
message BlobProto {
  optional BlobShape shape = 7;
  repeated float data = 5 [packed = true];
  repeated float diff = 6 [packed = true];
  repeated double double_data = 8 [packed = true];
  repeated double double_diff = 9 [packed = true];

  // 4D dimensions -- deprecated.  Use "shape" instead.
  optional int32 num = 1 [default = 0];
  optional int32 channels = 2 [default = 0];
  optional int32 height = 3 [default = 0];
  optional int32 width = 4 [default = 0];
}