1. 程式人生 > >caffe下用AlexNet模型提取影象特徵並從指定層輸出特徵向量

caffe下用AlexNet模型提取影象特徵並從指定層輸出特徵向量

  1. 選擇需要提取特徵的影象,並將其路徑匯入txt

    ./example/_temp

    
    # Create the scratch directory (-p: do not fail when it already exists)
    mkdir -p examples/_temp

    # Write the absolute path of every image file into a list file.
    # find prints each match by default, so no per-file `echo` is needed;
    # quoting keeps a working directory that contains spaces intact.
    find "$(pwd)/examples/images" -type f > examples/_temp/temp.txt

    # Every list entry needs a class label after the path, so append a
    # placeholder label 0 to each line
    sed 's/$/ 0/' examples/_temp/temp.txt > examples/_temp/file_list.txt
  2. 下載bvlc_reference_caffenet.caffemodel(可用 ./scripts/download_model_binary.py models/bvlc_reference_caffenet 下載),並製作網路結構檔案

    ./models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel
    ./examples/_temp/imagenet_val.prototxt

    # Run Caffe's ILSVRC12 helper script.
    # NOTE(review): presumably this fetches the auxiliary ImageNet data the
    # network definition refers to, not the .caffemodel itself — confirm.
    ./data/ilsvrc12/get_ilsvrc_aux.sh
    
    # Copy the network structure file into the scratch directory
    
    cp examples/feature_extraction/imagenet_val.prototxt examples/_temp
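  3. 使用extract_features.bin提取特徵,並以lmdb格式儲存。執行引數為extract_features.bin $MODEL $PROTOTXT $LAYER $LMDB_OUTPUT_PATH $NUM_MINI_BATCHES $DB_TYPE $MODE(第五個引數是要執行的mini-batch數量,每個batch的大小由prototxt中的batch_size決定;$DB_TYPE為lmdb或leveldb,$MODE為CPU或GPU)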

    ./models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel
    ./examples/_temp/imagenet_val.prototxt
    ./examples/_temp/features

    # Extract the fc7 blob using the pretrained weights and the network
    # definition, and store the result in lmdb format under
    # examples/_temp/features; the last argument selects GPU mode.
    # NOTE(review): per Caffe's extract_features usage string the "10" is the
    # number of mini-batches to run, not the batch size — confirm for your build.
    ./build/tools/extract_features.bin models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel examples/_temp/imagenet_val.prototxt fc7 examples/_temp/features 10 lmdb GPU
  4. 將特徵轉化為.mat檔案
    安裝CAFFE的python依賴庫,並使用以下兩個輔助檔案把lmdb轉換為mat。

    1. ./feat_helper_pb2.py

      
      # Generated by the protocol buffer compiler.  DO NOT EDIT!
      
      
      from google.protobuf import descriptor
      from google.protobuf import message
      from google.protobuf import reflection
      from google.protobuf import descriptor_pb2
      
      
      # @@protoc_insertion_point(imports)
      
      
      # File-level descriptor for datum.proto (package feat_extract).
      # serialized_pb is the protoc-emitted serialized form of the .proto file;
      # it must stay byte-exact because _DATUM's serialized_start/serialized_end
      # are offsets into it.
      DESCRIPTOR = descriptor.FileDescriptor(
        name='datum.proto',
        package='feat_extract',
        serialized_pb='\n\x0b\x64\x61tum.proto\x12\x0c\x66\x65\x61t_extract\"i\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02')
      
      
      # Descriptor of the feat_extract.Datum message: six fields numbered 1-6.
      # The numeric type/cpp_type/label codes follow protobuf's FieldDescriptor
      # constants (type 5 = int32, 12 = bytes, 2 = float; label 1 = optional,
      # 3 = repeated).
      _DATUM = descriptor.Descriptor(
        name='Datum',
        full_name='feat_extract.Datum',
        filename=None,
        file=DESCRIPTOR,
        containing_type=None,
        fields=[
          # Field 1: optional int32 `channels`
          descriptor.FieldDescriptor(
            name='channels', full_name='feat_extract.Datum.channels', index=0,
            number=1, type=5, cpp_type=1, label=1,
            has_default_value=False, default_value=0,
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            options=None),
          # Field 2: optional int32 `height`
          descriptor.FieldDescriptor(
            name='height', full_name='feat_extract.Datum.height', index=1,
            number=2, type=5, cpp_type=1, label=1,
            has_default_value=False, default_value=0,
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            options=None),
          # Field 3: optional int32 `width`
          descriptor.FieldDescriptor(
            name='width', full_name='feat_extract.Datum.width', index=2,
            number=3, type=5, cpp_type=1, label=1,
            has_default_value=False, default_value=0,
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            options=None),
          # Field 4: optional bytes `data`
          descriptor.FieldDescriptor(
            name='data', full_name='feat_extract.Datum.data', index=3,
            number=4, type=12, cpp_type=9, label=1,
            has_default_value=False, default_value="",
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            options=None),
          # Field 5: optional int32 `label`
          descriptor.FieldDescriptor(
            name='label', full_name='feat_extract.Datum.label', index=4,
            number=5, type=5, cpp_type=1, label=1,
            has_default_value=False, default_value=0,
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            options=None),
          # Field 6: repeated float `float_data` (read by lmdb2mat.py as the
          # feature vector of one record)
          descriptor.FieldDescriptor(
            name='float_data', full_name='feat_extract.Datum.float_data', index=5,
            number=6, type=2, cpp_type=6, label=3,
            has_default_value=False, default_value=[],
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            options=None),
        ],
        extensions=[
        ],
        nested_types=[],
        enum_types=[
        ],
        options=None,
        is_extendable=False,
        extension_ranges=[],
        # Byte range [29, 134) of the Datum message inside
        # DESCRIPTOR's serialized_pb string.
        serialized_start=29,
        serialized_end=134,
      )

      # Register the message descriptor on the file descriptor under its name.
      DESCRIPTOR.message_types_by_name['Datum'] = _DATUM
      
      # Message class for feat_extract.Datum; its behaviour is supplied by the
      # reflection metaclass from the _DATUM descriptor.
      # NOTE: `__metaclass__` as a class attribute is honoured only by Python 2;
      # Python 3 ignores it, so this generated module targets Python 2.
      class Datum(message.Message):
        __metaclass__ = reflection.GeneratedProtocolMessageType
        DESCRIPTOR = _DATUM
      
        # @@protoc_insertion_point(class_scope:feat_extract.Datum)
      
      # @@protoc_insertion_point(module_scope)
      
    2. ./lmdb2mat.py

      import lmdb
      import feat_helper_pb2
      import numpy as np
      import scipy.io as sio
      import time
      
      def main(argv):
          """Dump feature vectors stored in an LMDB database to a .mat file.

          Reads the first ``batch_num * batch_size`` records of the database,
          parses each one as a ``feat_helper_pb2.Datum`` and stores its
          ``float_data`` as one row of a matrix that is saved under the key
          ``'feats'``.

          Args:
              argv: Argument list in ``sys.argv`` layout:
                  ``[prog, lmdb_path, batch_num, batch_size, feat_dim, mat_path]``.

          Raises:
              SystemExit: If fewer than five arguments follow the program name.
              IndexError: If the database holds fewer records than requested.
          """
          if len(argv) < 6:
              raise SystemExit(
                  "usage: %s LMDB_PATH BATCH_NUM BATCH_SIZE FEAT_DIM MAT_PATH"
                  % (argv[0] if argv else "lmdb2mat.py"))

          # Use the argument list that was passed in rather than a global `sys`,
          # so the function also works when the module is imported.
          lmdb_name = argv[1]
          # print(...) with a single argument behaves the same on Python 2 and 3.
          print("%s" % lmdb_name)
          batch_num = int(argv[2])
          batch_size = int(argv[3])
          feat_dim = int(argv[4])
          mat_name = argv[5]
          window_num = batch_num * batch_size

          start = time.time()
          datum = feat_helper_pb2.Datum()
          ft = np.zeros((window_num, feat_dim))

          # Open read-only so a wrong path raises instead of silently creating
          # an empty database, and always release the environment afterwards.
          db = lmdb.open(lmdb_name, readonly=True)
          filled = 0
          try:
              with db.begin() as txn:
                  # Parse records as they are read; stop once enough rows are
                  # filled instead of loading the whole database into memory.
                  for _key, value in txn.cursor().iternext_nodup():
                      if filled == window_num:
                          break
                      datum.ParseFromString(value)
                      ft[filled, :] = datum.float_data
                      filled += 1
          finally:
              db.close()

          if filled < window_num:
              raise IndexError(
                  "%s holds only %d records, but batch_num * batch_size = %d"
                  % (lmdb_name, filled, window_num))

          print('time 1: %f' % (time.time() - start))
          sio.savemat(mat_name, {'feats': ft})
          print('time 2: %f' % (time.time() - start))
          print('done!')
      
      # Script entry point: runs only when the file is executed directly.
      if __name__ == '__main__':
          # `sys` is imported here rather than at the top of the file, so the
          # module-level name `sys` exists only when the file runs as a script.
          import sys
          main(sys.argv)
    3. 執行bash輸出.mat檔案

      
      #!/usr/bin/env sh

      # LMDB directory to read; this is the output path that was given to
      # extract_features.bin (examples/_temp/features)
      LMDB=./examples/_temp/features
      BATCHNUM=1
      BATCHSIZE=10

      # Feature length of the extracted layer; values for other layers:
      # DIM=290400 # conv1
      # DIM=43264 # conv5
      DIM=4096
      OUT=./examples/_temp/features_fc7.mat # path of the .mat file to write

      # Quote the expansions so paths containing spaces survive word splitting
      python ./lmdb2mat.py "$LMDB" "$BATCHNUM" "$BATCHSIZE" "$DIM" "$OUT"