1. 程式人生 > >openpose訓練程式碼(二)

openpose訓練程式碼(二)

在上一篇openpose訓練程式碼(一) 中講到cpm_data_transformer,其實這個檔案才是包含資料處理核心程式碼的檔案,在上一篇部落格提到了Transform_nv函式,我們先來看看Transform_nv函式:

// Entry-point overload: validates the datum/blob shapes, then forwards the
// raw CPU buffers to the pointer-based Transform_nv worker that does the
// actual decoding, augmentation and label generation.
//
// Args:
//   datum:             serialized sample (6 channels: BGR image, meta page,
//                      mask_miss, mask_all).
//   transformed_data:  output image blob (6 channels expected by the net).
//   transformed_label: output label blob; must have the same num as data.
//   cnt:               sample counter forwarded to the worker.
template<typename Dtype> void CPMDataTransformer<Dtype>::Transform_nv(const Datum& datum, Blob<Dtype>* transformed_data, Blob<Dtype>* transformed_label, int cnt) {
  const int datum_channels = datum.channels();
  const int im_channels = transformed_data->channels();
  const int im_num = transformed_data->num();
  const int lb_num = transformed_label->num();

  // Shape sanity checks: 6 input channels, matching batch sizes.
  CHECK_EQ(datum_channels, 6);
  CHECK_EQ(im_channels, 6);
  CHECK_EQ(im_num, lb_num);
  CHECK_GE(im_num, 1);

  Dtype* transformed_data_pointer = transformed_data->mutable_cpu_data();
  Dtype* transformed_label_pointer = transformed_label->mutable_cpu_data();

  CPUTimer timer;
  timer.Start();
  Transform_nv(datum, transformed_data_pointer, transformed_label_pointer, cnt); //call function 1
  VLOG(2) << "Transform_nv: " << timer.MicroSeconds() / 1000.0 << " ms";
}

這個函式主要就是得到lmdb的一些引數,比如datum_channels,im_channels 等,轉而呼叫Transform_nv函式

// Worker overload: decodes the datum into image + masks, reads the meta
// annotations, applies augmentation, normalizes pixels and writes the
// data/label maps. (Body elided in this excerpt; the fragments that follow
// walk through it.)
template<typename Dtype> void CPMDataTransformer<Dtype>::Transform_nv(const Datum& datum, Dtype* transformed_data, Dtype* transformed_label, int cnt) {
  ...
}

data是lmdb的首地址,datum_channels,datum_height,datum_width 分別是之前python程式碼確定的每頁的尺寸,mask_miss 初始化為全1矩陣、mask_all 初始化為全0矩陣(見下方程式碼),為後續所用做準備。

  // Raw byte buffer of the datum. Each "page" is datum_height x datum_width:
  // pages 0-2 hold the BGR image, page 3 the meta data, page 4 mask_miss and
  // (mode 6 only) page 5 mask_all.
  const string& data = datum.data();
  const int datum_channels = datum.channels();
  const int datum_height = datum.height();
  const int datum_width = datum.width();
  // To do: make this a parameter in caffe.proto
  //const int mode = 5; //related to datum.channels();
  const int mode = 5;

  //const int crop_size = param_.crop_size();
  //const Dtype scale = param_.scale();
  //const bool do_mirror = param_.mirror() && Rand(2);
  //const bool has_mean_file = param_.has_mean_file();
  // Empty data string means the sample was stored in float_data instead.
  const bool has_uint8 = data.size() > 0;
  //const bool has_mean_values = mean_values_.size() > 0;
  // Final network input size, taken from the transform_param in the prototxt.
  int crop_x = param_.crop_size_x();
  int crop_y = param_.crop_size_y();

  CHECK_GT(datum_channels, 0);
  //CHECK_GE(datum_height, crop_size);
  //CHECK_GE(datum_width, crop_size);
  CPUTimer timer1;
  timer1.Start();
  //before any transformation, get the image from datum
  Mat img = Mat::zeros(datum_height, datum_width, CV_8UC3);
  Mat mask_all, mask_miss;
  if(mode >= 5){
    // mask_miss defaults to 1 (= keep) everywhere; overwritten below.
    mask_miss = Mat::ones(datum_height, datum_width, CV_8UC1);
  }
  if(mode == 6){
    // mask_all defaults to 0; only filled when a 6th page exists.
    mask_all = Mat::zeros(datum_height, datum_width, CV_8UC1);
  }

讀取原始圖片資料儲存在rgb中,以及讀取mask_miss 和 mask_all,如下:
offset = img.rows * img.cols,為指標偏移量,和python檔案一一對應。

  // One page of the datum = offset bytes; dindex addresses pixel (i, j) of
  // page c via c*offset + i*cols + j, matching the python packing order.
  int offset = img.rows * img.cols;
  int dindex;
  Dtype d_element;
  for (int i = 0; i < img.rows; ++i) {
    for (int j = 0; j < img.cols; ++j) {
      // Copy the three image pages into the BGR pixel.
      Vec3b& rgb = img.at<Vec3b>(i, j);
      for(int c = 0; c < 3; c++){
        dindex = c*offset + i*img.cols + j;
        if (has_uint8)
          d_element = static_cast<Dtype>(static_cast<uint8_t>(data[dindex]));
        else
          d_element = datum.float_data(dindex);
        rgb[c] = d_element;
      }

      if(mode >= 5){
        // Page 4: mask_miss. Values are expected to be 0 or 255; warn on
        // anything else (sanity check on the python-side encoding).
        dindex = 4*offset + i*img.cols + j;
        if (has_uint8)
          d_element = static_cast<Dtype>(static_cast<uint8_t>(data[dindex]));
        else
          d_element = datum.float_data(dindex);
        if (round(d_element/255)!=1 && round(d_element/255)!=0){
          cout << d_element << " " << round(d_element/255) << endl;
        }
        mask_miss.at<uchar>(i, j) = d_element; //round(d_element/255);
      }

      if(mode == 6){
        // Page 5: mask_all (only present in 6-channel datums).
        dindex = 5*offset + i*img.cols + j;
        if (has_uint8)
          d_element = static_cast<Dtype>(static_cast<uint8_t>(data[dindex]));
        else
          d_element = datum.float_data(dindex);
        mask_all.at<uchar>(i, j) = d_element;
      }
    }
  }
  VLOG(2) << "  rgb[:] = datum: " << timer1.MicroSeconds()/1000.0 << " ms";
  timer1.Start();
  VLOG(2) << "  rgb[:] = datum: " << timer1.MicroSeconds()/1000.0 << " ms";
  timer1.Start();

接下來開始讀meta檔案,就是儲存的關鍵點和尺寸資訊,其中關鍵的是ReadMetaData函式,這個函式就是完全按照python寫入格式來讀的,所以,一定要理清楚python程式碼的邏輯,不然,這裡很容易混亂,同時,這裡有一個小的技巧就是轉換了關鍵點的順序,TransformMetaJoints函式實現這一功能,其實就是為了和MPII資料集對應,我的理解是方便transfer 權重,程式碼如下:

  //color, contract
  // Optional contrast-limited adaptive histogram equalization.
  if(param_.do_clahe())
    clahe(img, clahe_tileSize, clahe_clipLimit);
  if(param_.gray() == 1){
    // Gray augmentation: drop color but keep the 3-channel layout the
    // network expects.
    cv::cvtColor(img, img, CV_BGR2GRAY);
    cv::cvtColor(img, img, CV_GRAY2BGR);
  }
  VLOG(2) << "  color: " << timer1.MicroSeconds()/1000.0 << " ms";
  timer1.Start();

  // Meta annotations live in page 3 of the datum (offset3 bytes in), one
  // row being offset1 bytes. ReadMetaData must mirror the python writer
  // byte-for-byte.
  int offset3 = 3 * offset;
  int offset1 = datum_width;
  int stride = param_.stride();
  ReadMetaData(meta, data, offset3, offset1);
  if(param_.transform_body_joint()) // we expect to transform body joints, and not to transform hand joints
    TransformMetaJoints(meta);

  VLOG(2) << "  ReadMeta+MetaJoints: " << timer1.MicroSeconds()/1000.0 << " ms";

讀取到原始資料後,接下來做的就是資料增廣,原始程式碼主要做了如下幾種資料增廣:scale、rotate、crop、flip;具體實現如下,每做一個都是疊加在原來的基礎上,這裡在做資料增廣的時候,用到了原圖scale的資訊:

  //Start transforming
  // img_aug is the final crop_y x crop_x augmented image fed to the net.
  Mat img_aug = Mat::zeros(crop_y, crop_x, CV_8UC3);
  Mat mask_miss_aug, mask_all_aug ;
  //Mat mask_miss_aug = Mat::zeros(crop_y, crop_x, CV_8UC1);
  //Mat mask_all_aug = Mat::zeros(crop_y, crop_x, CV_8UC1);
  Mat img_temp, img_temp2, img_temp3; //size determined by scale
  VLOG(2) << "   input size (" << img.cols << ", " << img.rows << ")"; 
  // We only do random transform as augmentation when training.
  // Each augmentation_* call transforms image, masks AND the joint
  // coordinates in `meta` together, and records its random choice in `as`
  // (an AugmentSelection declared earlier — outside this excerpt).
  if (phase_ == TRAIN) {
    // 1) random scale (uses the annotated person scale from meta)
    as.scale = augmentation_scale(img, img_temp, mask_miss, mask_all, meta, mode);
    //LOG(INFO) << meta.joint_self.joints.size();
    //LOG(INFO) << meta.joint_self.joints[0];
    // 2) random rotation
    as.degree = augmentation_rotate(img_temp, img_temp2, mask_miss, mask_all, meta, mode);
    //LOG(INFO) << meta.joint_self.joints.size();
    //LOG(INFO) << meta.joint_self.joints[0];
    if(0 && param_.visualize()) 
      visualize(img_temp2, meta, as);
    // 3) crop/pad to crop_x x crop_y; also produces the *_aug masks
    as.crop = augmentation_croppad(img_temp2, img_temp3, mask_miss, mask_miss_aug, mask_all, mask_all_aug, meta, mode);
    //LOG(INFO) << meta.joint_self.joints.size();
    //LOG(INFO) << meta.joint_self.joints[0];
    if(0 && param_.visualize()) 
      visualize(img_temp3, meta, as);
    // 4) random horizontal flip
    as.flip = augmentation_flip(img_temp3, img_aug, mask_miss_aug, mask_all_aug, meta, mode);
    //LOG(INFO) << meta.joint_self.joints.size();
    //LOG(INFO) << meta.joint_self.joints[0];
    if(param_.visualize()) 
      visualize(img_aug, meta, as);

    // imshow("img_aug", img_aug);
    // Mat label_map = mask_miss_aug;
    // applyColorMap(label_map, label_map, COLORMAP_JET);
    // addWeighted(label_map, 0.5, img_aug, 0.5, 0.0, label_map);
    // imshow("mask_miss_aug", label_map);

    // Downsample the masks to label resolution (1/stride), matching the
    // network's output grid.
    if (mode > 4){
      resize(mask_miss_aug, mask_miss_aug, Size(), 1.0/stride, 1.0/stride, INTER_CUBIC);
    }
    if (mode > 5){
      resize(mask_all_aug, mask_all_aug, Size(), 1.0/stride, 1.0/stride, INTER_CUBIC);
    }
  }
  else {
    // Test phase: no augmentation, identity transform parameters.
    img_aug = img.clone();
    as.scale = 1;
    as.crop = Size();
    as.flip = 0;
    as.degree = 0;
  }
  VLOG(2) << "  Aug: " << timer1.MicroSeconds()/1000.0 << " ms";
  timer1.Start();

資料增廣過後就是歸一化,和準備label檔案,有一點不同的地方就是負責背景關鍵點的那一個label使用的是mask_miss資訊,同時,把輸入歸一化到 [-0.5, 0.5] 具體如下:

  // Normalize pixels from [0, 255] to roughly [-0.5, 0.5] and write them
  // channel-planar (BGR) into the data blob.
  // NOTE(review): `offset` was computed as img.rows*img.cols before
  // augmentation; for this indexing to be correct it must equal
  // img_aug.rows*img_aug.cols here — confirm it is recomputed in the
  // elided code when the original image size differs from the crop size.
  for (int i = 0; i < img_aug.rows; ++i) {
    for (int j = 0; j < img_aug.cols; ++j) {
      Vec3b& rgb = img_aug.at<Vec3b>(i, j);
      transformed_data[0*offset + i*img_aug.cols + j] = (rgb[0] - 128)/256.0;
      transformed_data[1*offset + i*img_aug.cols + j] = (rgb[1] - 128)/256.0;
      transformed_data[2*offset + i*img_aug.cols + j] = (rgb[2] - 128)/256.0;
    }
  }

  // label size is image size/ stride
  // Fill the first np+1 label channels with per-pixel loss weights from
  // mask_miss_aug (0 = ignore in the loss, 1 = train on it).
  // grid_x/grid_y/np/channelOffset come from the elided code above.
  if (mode > 4){
    for (int g_y = 0; g_y < grid_y; g_y++){
      for (int g_x = 0; g_x < grid_x; g_x++){
        for (int i = 0; i < np; i++){
          float weight = float(mask_miss_aug.at<uchar>(g_y, g_x)) /255; //mask_miss_aug.at<uchar>(i, j); 
          // isVisible == 3 means the joint is not labeled at all; leave the
          // channel's weight untouched in that case.
          if (meta.joint_self.isVisible[i] != 3){
            transformed_label[i*channelOffset + g_y*grid_x + g_x] = weight;
          }
        }  
        // background channel
        if(mode == 5){
          // Background channel reuses the mask_miss weights.
          transformed_label[np*channelOffset + g_y*grid_x + g_x] = float(mask_miss_aug.at<uchar>(g_y, g_x)) /255;
        }
        if(mode > 5){
          // Mode 6: background weight is 1 everywhere, and mask_all is
          // stored in the extra final channel.
          transformed_label[np*channelOffset + g_y*grid_x + g_x] = 1;
          transformed_label[(2*np+1)*channelOffset + g_y*grid_x + g_x] = float(mask_all_aug.at<uchar>(g_y, g_x)) /255;
        }
      }
    }
  }  

做完上面的工作,把圖片資料準備好,背景關鍵點準備好,就剩下其它關鍵點和PAF的label了,主要是在generateLabelMap函式中完成。

  //putGaussianMaps(transformed_data + 3*offset, meta.objpos, 1, img_aug.cols, img_aug.rows, param_.sigma_center());
  //LOG(INFO) << "image transformation done!";
  // Generate the remaining ground truth: joint heat maps and PAF vector
  // fields (see generateLabelMap below).
  generateLabelMap(transformed_label, img_aug, meta);

  VLOG(2) << "  putGauss+genLabel: " << timer1.MicroSeconds()/1000.0 << " ms";
  //starts to visualize everything (transformed_data in 4 ch, label) fed into conv1
  //if(param_.visualize()){
    //dumpEverything(transformed_data, transformed_label, meta);
  //}

具體的,我們來看一下generateLabelMap函式,大概的說來,主要就是做兩件事,其一是在每個關鍵點部位放置高斯響應,其二就是在有連線的關鍵點之間放vector,更具體的細節,可以去查閱原始碼,這裡不再做更為詳細的說明:

template<typename Dtype>
void CPMDataTransformer<Dtype>::generateLabelMap(Dtype* transformed_label, Mat& img_aug, MetaData meta) {
  // Generates the ground-truth half of the label blob for one sample:
  // a Gaussian confidence map per joint, a part-affinity field (PAF) pair
  // per limb, and a background channel, all at 1/stride resolution.
  //
  // Channel layout (np = 56 for COCO, 43 for MPII):
  //   [0 .. np]            per-channel loss masks, already written by the caller
  //   [np+1 ..]            PAF channel pairs (x then y), one pair per limb
  //   [.. 2*np]            joint Gaussian heat maps
  //   [2*np+1]             background = max(0, 1 - max over the heat maps)
  //
  // Args:
  //   transformed_label: output buffer of 2*(np+1) * grid_y * grid_x values.
  //   img_aug:           augmented image; provides the output size and the
  //                      backdrop for the optional visualization dump.
  //   meta:              joints/visibility for the main person and all others.
  int rezX = img_aug.cols;
  int rezY = img_aug.rows;
  int stride = param_.stride();
  int grid_x = rezX / stride;   // labels are produced on a 1/stride grid
  int grid_y = rezY / stride;
  int channelOffset = grid_y * grid_x;
  int mode = 5; // TO DO: make this as a parameter

  // Zero the channels this function owns (np+1 .. 2*np+1). The first np+1
  // mask channels were filled by the caller; in mode 6 the last channel
  // already holds mask_all and must not be cleared.
  for (int g_y = 0; g_y < grid_y; g_y++){
    for (int g_x = 0; g_x < grid_x; g_x++){
      for (int i = np+1; i < 2*(np+1); i++){
        if (mode == 6 && i == (2*np + 1))
          continue;
        transformed_label[i*channelOffset + g_y*grid_x + g_x] = 0;
      }
    }
  }

  if (np == 56){ // COCO: 18 joints, 19 limbs
    // Joint heat maps: accumulate a Gaussian at each labeled joint for the
    // main person and every other annotated person.
    // isVisible semantics: 0/1 = labeled, values > 1 = missing/unlabeled.
    for (int i = 0; i < 18; i++){
      Point2f center = meta.joint_self.joints[i];
      if(meta.joint_self.isVisible[i] <= 1){
        putGaussianMaps(transformed_label + (i+np+39)*channelOffset, center, param_.stride(), 
                        grid_x, grid_y, param_.sigma()); //self
      }
      for(int j = 0; j < meta.numOtherPeople; j++){ //for every other person
        Point2f center = meta.joint_others[j].joints[i];
        if(meta.joint_others[j].isVisible[i] <= 1){
          putGaussianMaps(transformed_label + (i+np+39)*channelOffset, center, param_.stride(), 
                          grid_x, grid_y, param_.sigma());
        }
      }
    }

    // Limb endpoints as 1-based joint indices: limb i connects joint
    // mid_1[i] to joint mid_2[i] (hence the -1 below).
    int mid_1[19] = {2, 9,  10, 2,  12, 13, 2, 3, 4, 3,  2, 6, 7, 6,  2, 1,  1,  15, 16};
    int mid_2[19] = {9, 10, 11, 12, 13, 14, 3, 4, 5, 17, 6, 7, 8, 18, 1, 15, 16, 17, 18};
    int thre = 1; // half-width (in grid cells) of the affinity band

    for(int i=0;i<19;i++){
      // `count` tracks how many people contributed a vector to each cell so
      // putVecMaps can average overlapping limbs.
      Mat count = Mat::zeros(grid_y, grid_x, CV_8UC1);
      Joints jo = meta.joint_self;
      if(jo.isVisible[mid_1[i]-1]<=1 && jo.isVisible[mid_2[i]-1]<=1){
        //putVecPeaks
        putVecMaps(transformed_label + (np+ 1+ 2*i)*channelOffset, transformed_label + (np+ 2+ 2*i)*channelOffset, 
                  count, jo.joints[mid_1[i]-1], jo.joints[mid_2[i]-1], param_.stride(), grid_x, grid_y, param_.sigma(), thre); //self
      }

      for(int j = 0; j < meta.numOtherPeople; j++){ //for every other person
        Joints jo2 = meta.joint_others[j];
        if(jo2.isVisible[mid_1[i]-1]<=1 && jo2.isVisible[mid_2[i]-1]<=1){
          //putVecPeaks
          putVecMaps(transformed_label + (np+ 1+ 2*i)*channelOffset, transformed_label + (np+ 2+ 2*i)*channelOffset, 
                  count, jo2.joints[mid_1[i]-1], jo2.joints[mid_2[i]-1], param_.stride(), grid_x, grid_y, param_.sigma(), thre); //self
        }
      }
    }

    //put background channel
    // Background = complement of the strongest heat-map response, clamped
    // to [0, 1]. Heat maps occupy channels np+39 .. np+56.
    for (int g_y = 0; g_y < grid_y; g_y++){
      for (int g_x = 0; g_x < grid_x; g_x++){
        float maximum = 0;
        //second background channel
        for (int i = np+39; i < np+57; i++){
          maximum = (maximum > transformed_label[i*channelOffset + g_y*grid_x + g_x]) ? maximum : transformed_label[i*channelOffset + g_y*grid_x + g_x];
        }
        transformed_label[(2*np+1)*channelOffset + g_y*grid_x + g_x] = max(1.0-maximum, 0.0);
      }
    }
    //LOG(INFO) << "background put";
  }

  else if (np == 43){ // MPII: 15 joints, 14 limbs
    // Same scheme as the COCO branch; heat maps start at channel np+29.
    for (int i = 0; i < 15; i++){
      Point2f center = meta.joint_self.joints[i];
      if(meta.joint_self.isVisible[i] <= 1){
        putGaussianMaps(transformed_label + (i+np+29)*channelOffset, center, param_.stride(), 
                        grid_x, grid_y, param_.sigma()); //self
      }
      for(int j = 0; j < meta.numOtherPeople; j++){ //for every other person
        Point2f center = meta.joint_others[j].joints[i];
        if(meta.joint_others[j].isVisible[i] <= 1){
          putGaussianMaps(transformed_label + (i+np+29)*channelOffset, center, param_.stride(), 
                          grid_x, grid_y, param_.sigma());
        }
      }
    }

    // NOTE: unlike the COCO branch, these limb endpoint indices are 0-based.
    int mid_1[14] = {0, 1, 2, 3, 1, 5, 6, 1, 14, 8, 9,  14, 11, 12};
    int mid_2[14] = {1, 2, 3, 4, 5, 6, 7, 14, 8, 9, 10, 11, 12, 13};
    int thre = 1;

    for(int i=0;i<14;i++){
      Mat count = Mat::zeros(grid_y, grid_x, CV_8UC1);
      Joints jo = meta.joint_self;
      if(jo.isVisible[mid_1[i]]<=1 && jo.isVisible[mid_2[i]]<=1){
        //putVecPeaks
        putVecMaps(transformed_label + (np+ 1+ 2*i)*channelOffset, transformed_label + (np+ 2+ 2*i)*channelOffset, 
                  count, jo.joints[mid_1[i]], jo.joints[mid_2[i]], param_.stride(), grid_x, grid_y, param_.sigma(), thre); //self
      }

      for(int j = 0; j < meta.numOtherPeople; j++){ //for every other person
        Joints jo2 = meta.joint_others[j];
        if(jo2.isVisible[mid_1[i]]<=1 && jo2.isVisible[mid_2[i]]<=1){
          //putVecPeaks
          putVecMaps(transformed_label + (np+ 1+ 2*i)*channelOffset, transformed_label + (np+ 2+ 2*i)*channelOffset, 
                  count, jo2.joints[mid_1[i]], jo2.joints[mid_2[i]], param_.stride(), grid_x, grid_y, param_.sigma(), thre); //self
        }
      }
    }

    //put background channel
    for (int g_y = 0; g_y < grid_y; g_y++){
      for (int g_x = 0; g_x < grid_x; g_x++){
        float maximum = 0;
        //second background channel
        for (int i = np+29; i < np+44; i++){
          maximum = (maximum > transformed_label[i*channelOffset + g_y*grid_x + g_x]) ? maximum : transformed_label[i*channelOffset + g_y*grid_x + g_x];
        }
        transformed_label[(2*np+1)*channelOffset + g_y*grid_x + g_x] = max(1.0-maximum, 0.0);
      }
    }
    //LOG(INFO) << "background put";
  }

  //visualize
  // Debug dump: write every label channel as a JET-colored overlay image.
  if(1 && param_.visualize()){
    Mat label_map;
    for(int i = 0; i < 2*(np+1); i++){      
      label_map = Mat::zeros(grid_y, grid_x, CV_8UC1);
      //int MPI_index = MPI_to_ours[i];
      //Point2f center = meta.joint_self.joints[MPI_index];
      for (int g_y = 0; g_y < grid_y; g_y++){
        //printf("\n");
        for (int g_x = 0; g_x < grid_x; g_x++){
          label_map.at<uchar>(g_y,g_x) = (int)(transformed_label[i*channelOffset + g_y*grid_x + g_x]*255);
          //printf("%f ", transformed_label_entry[g_y*grid_x + g_x]*255);
        }
      }
      resize(label_map, label_map, Size(), stride, stride, INTER_LINEAR);
      applyColorMap(label_map, label_map, COLORMAP_JET);
      addWeighted(label_map, 0.5, img_aug, 0.5, 0.0, label_map);

      //center = center * (1.0/(float)param_.stride());
      //circle(label_map, center, 3, CV_RGB(255,0,255), -1);
      char imagename [100];
      // snprintf (not sprintf) guards against overflow of the fixed buffer
      // for very large write_number values.
      snprintf(imagename, sizeof(imagename), "augment_%04d_label_part_%02d.jpg", meta.write_number, i);
      //LOG(INFO) << "filename is " << imagename;
      imwrite(imagename, label_map);
    }

  }
}