DispNet中Caffe自定義層解讀(一)—— CustomData
阿新 • • 發佈:2018-11-06
DispNet中Caffe自定義層解讀(一)—— CustomData
這一系列博文記錄了博主在學習DispNet過程中遇到的自定義Caffe層的筆記。這一部分是CustomData層,其主要功能是:讀取資料庫中的LMDB型別資料,並將其隨機排布後存入top。更新於2018.10.25。
文章目錄
呼叫方式
layer { name: "CustomData1" type: "CustomData" top: "blob0" top: "blob1" top: "blob2" include { phase: TRAIN } data_param { source: "path/to/your/data/lmdb" batch_size: 32 backend: LMDB rand_permute: true rand_permute_seed: 77 slice_point: 3 slice_point: 6 encoding: UINT8 encoding: UINT8 encoding: UINT16FLOW verbose: true } }
custom_data_layer.hpp
定義了LMDB資料型別的幾個變數:
// LMDB
MDB_env* mdb_env_;
MDB_dbi mdb_dbi_;
MDB_txn* mdb_txn_;
MDB_cursor* mdb_cursor_;
MDB_val mdb_key_, mdb_value_;
小知識:
- 什麼是控制代碼? 簡單來說,如果從一個數可以“拎出”很多東西,那麼這個數就是控制代碼。
- MDB_env*中的*號代表什麼: 表示指向前面那個型別的指標,可以理解為mdb_env這個變數裡放的地址,誰賦值給這個變數,就放的誰的地址。(注:解釋來自C++大牛@
- MDB_env: 為資料庫環境(database environment)定義的一個不透明結構體。官網解釋:Opaque structure for a database environmen。更多官網解釋看這裡。
- MDB_dbi: DB環境下的個人資料庫(individual database)的控制代碼。官網解釋:A handle for an individual database in the DB environment。
- MDB_txn: 為一個事務控制代碼(transaction handle)定義一個不透明結構體(Opaque structure)。官網解釋:Opaque structure for a transaction handle。更多官網解釋看這裡。
- MDB_cursor: 為巡航一個數據庫定義的不透明結構體。官網解釋:Opaque structure for navigating through a database。更多官網解釋看這裡。
- MDB_val: 用於將keys和資料傳入、傳出資料庫的語類結構(generic structure)。
定義變數用於宣告執行緒ID:
pthread_t thread_;
定義智慧指標:
shared_ptr<Blob<Dtype> > prefetch_label_;
vector<shared_ptr<Blob<Dtype> > > prefetch_data_blobs_;
還有其他Caffe中通用的函式及變數定義,此處不作贅述。
custom_data_layer.cpp
Forward_cpu
用函式JoinPrefetchThread();
將執行緒joint在一起,並檢查是否成功。如果失敗,返回"Pthread joining failed."
。
注:整合用到函式pthread_join
,用來等待一個執行緒的結束,執行緒間同步的操作。標頭檔案 : #include <pthread.h>。具體描述看這裡。
將cpu中的資料複製到top中:
for (int i = 0; i <= slice_point_.size(); ++i) {
// Copy the data
caffe_copy(prefetch_data_blobs_[i]->count(), prefetch_data_blobs_[i]->cpu_data(), top[i]->mutable_cpu_data());
}
其中,prefetch_data_blobs_[i]->cpu_data()
為被複制的源資料,top[i]->mutable_cpu_data()
為複製到的目標資料。
如果output_labels(標頭檔案中定義的bool型別變數)為真,地址移動到整個slice的下一個地址(label_topblob = slice_point_.size() + 1;
),將prefetch_label_
中的內容複製給top[label_topblob]
。
隨機排布輸入的影象。
iter_++;
if (this->layer_param_.data_param().rand_permute() && this->layer_param_.data_param().permute_every_iter()) { //如果層引數中設定了rand_permute(true)和permute_every_iter
if (iter_ % this->layer_param_.data_param().permute_every_iter() == 0) { //如果permute_every_iter為0
generateRandomPermutation(-1, this->layer_param_.data_param().block_size()); //呼叫generateRandomPermutation函式,引數為-1和層引數中的block_size
if (this->layer_param_.data_param().verbose()) { //如果需要,將新的排布順序顯示出來。
printf("Re-permuting at iteration %d. Permutation:\n", iter_);
for(int j = 0; j < permutation_vector_.size(); j++) {
printf("%d ",permutation_vector_.at(j));
}
printf("\n");
}
}
}
生成一個新的執行緒。
// Start a new prefetch thread
CreatePrefetchThread();
generateRandomPermutation
在Forward_cpu中,這個函式的輸入為-1和層引數block_size的值。此時函式的功能是:將所有輸入的影象重新隨機排布。
template <typename Dtype>
void CustomDataLayer<Dtype>::generateRandomPermutation(int seed, int block_size) {
if (seed > 0) //如果seed大於0,根據seed初始化隨機數發生器。(srand函式是隨機數發生器的初始化函式)
std::srand (unsigned(seed));
if (block_size > 0) { //如果block_size大於0,
int num_blocks = (permutation_vector_.size() + block_size - 1) / block_size; // equal to ceil(size / block_size)
for (int b=0; b < num_blocks; ++b) {
int n1 = b * block_size;
int n2 = std::min((b+1)*block_size, static_cast<int>(permutation_vector_.size()));
std::random_shuffle(permutation_vector_.begin() + n1, permutation_vector_.begin() + n2 -1);
}
} else { //否則,將permutation_vector_中的數隨機排列。變數定義在hpp中:std::vector<int> permutation_vector_;
std::random_shuffle(permutation_vector_.begin(), permutation_vector_.end());
}
}
CreatePrefetchThread
template <typename Dtype>
void CustomDataLayer<Dtype>::CreatePrefetchThread() {
const bool prefetch_needs_rand = (this->phase_ == TRAIN) &&
(this->layer_param_.data_param().mirror() ||
this->layer_param_.data_param().crop_size());
if (prefetch_needs_rand) {
const unsigned int prefetch_rng_seed = caffe_rng_rand();
prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed));
} else {
prefetch_rng_.reset();
}
// Create the thread.
CHECK(!pthread_create(&thread_, NULL, CustomDataLayerPrefetch<Dtype>,
static_cast<void*>(this))) << "Pthread execution failed.";
}
DecodeData
template <typename Dtype>
void DecodeData(Dtype*& ptr,Datum& datum,const vector<int>& slice_points,const vector<int>& encoding)
{
int width=datum.width();
int height=datum.height();
int channels=datum.channels();
int count=width*height*channels;
ptr=new Dtype[count];
if(datum.float_data_size())
{
CHECK_EQ(encoding.size(),0) << "Encoded layers must be stored as uint8 in LMDB.";
for(int i=0; i<count; i++)
ptr[i]=datum.float_data(i);
return;
}
const unsigned char* srcptr=(const unsigned char*)datum.data().c_str();
Dtype* destptr=ptr;
int channel_start = -1; //inclusive
int channel_end = 0; //non-inclusive (end will become start in next slice)
for(int slice = 0; slice <= slice_points.size(); slice++)
{
channel_start = channel_end;
if(slice == slice_points.size())
channel_end = channels;
else
channel_end = slice_points[slice];
int channel_count=channel_end-channel_start;
int format;
if(encoding.size()<=slice)
format=DataParameter_CHANNELENCODING_UINT8;
else
format=encoding[slice];
// LOG(INFO) << "Slice " << slice << "(" << channel_start << "," << channel_end << ") has format " << ((int)format);
switch(format)
{
case DataParameter_CHANNELENCODING_UINT8:
for(int c=0; c<channel_count; c++)
for(int y=0; y<height; y++)
for(int x=0; x<width; x++)
*(destptr++)=static_cast<Dtype>(*(srcptr++));
break;
case DataParameter_CHANNELENCODING_UINT16FLOW:
for(int c=0; c<channel_count; c++)
for(int y=0; y<height; y++)
for(int x=0; x<width; x++)
{
short v;
*((unsigned char*)&v)=*(srcptr++);
*((unsigned char*)&v+1)=*(srcptr++);
Dtype value;
if(v==std::numeric_limits<short>::max()) {
value = std::numeric_limits<Dtype>::signaling_NaN();
} else {
value = ((Dtype)v)/32.0;
}
*(destptr++)=value;
}
break;
case DataParameter_CHANNELENCODING_BOOL1:
{
int j=0;
for(int i=0; i<(width*height-1)/8+1; i++)
{
unsigned char data=*(srcptr++);
for(int k=0; k<8; k++)
{
float value=(data&(1<<k))==(1<<k);
if(j<width*height)
*(destptr++)=value?1.0:0;
j++;
}
}
}
break;
default:
LOG(FATAL) << "Invalid format for slice " << slice;
break;
}
}
// LOG(INFO) << destptr << " " << ptr;
assert(destptr==ptr+count);
}