1. 程式人生 > >Python 應用caffe模型進行分類(caffe介面)

Python 應用caffe模型進行分類(caffe介面)

遍歷一個資料夾下的所有圖片,逐張進行預測,並依預測結果將圖片複製到相應的資料夾


import caffe
#import lmdb
import numpy as np
import cv2
from caffe.proto import caffe_pb2
import os
import sys

# Select Caffe's GPU mode before the network is created further below.
caffe.set_mode_gpu()


def dirlist(path, allfile):
    """Recursively collect the paths of all non-directory entries under *path*.

    Entries are visited in ``os.listdir`` order and a sub-directory is
    descended into as soon as it is encountered.  Every path found is
    appended to *allfile* in place, and that same list object is returned.
    """
    for entry in os.listdir(path):
        full = os.path.join(path, entry)
        if not os.path.isdir(full):
            allfile.append(full)
            continue
        # Directory: let the recursive call extend the shared accumulator.
        dirlist(full, allfile)
    return allfile

# sys.setrecursionlimit(1000000)


def is_bgr_img(img):
    """Return True when *img* has a three-dimensional ``shape``.

    Returns False when *img* has no ``shape`` attribute (for example ``None``),
    when ``shape`` is not a sized sequence, or when it does not have exactly
    three entries (for example a 2-D single-channel array).  The previous
    implementation raised ``ValueError``/``TypeError`` in the latter cases
    instead of answering False.

    Only the number of dimensions is checked, not the size of the last axis.
    """
    shape = getattr(img, 'shape', None)
    try:
        return len(shape) == 3
    except TypeError:
        # shape is None or otherwise has no length.
        return False


# Load the Caffe model and build the input pre-processor.
root = 'D:/stomach_raw_data/deepid/'  # root directory holding the model files
deploy = root + 'deploy_all.prototxt'  # deploy (network definition) file
caffe_model = root + 'id_128_net_iter_1695000.caffemodel'  # trained caffemodel weights
labels_filename = root + 'labels.txt'  # class-name file (maps numeric labels back to class names)
# Instantiate the network in TEST phase from the definition and the weights.
net = caffe.Net(deploy, caffe_model, caffe.TEST)
# The transformer takes its target shape from the 'data_1' input blob.
# Per the original author this is (N, C, H, W), e.g. (1, 3, 28, 28) -- TODO confirm against the deploy file.
transformer = caffe.io.Transformer({'data': net.blobs['data_1'].data.shape}) 
# Reorder the axes as (2, 0, 1): the last axis of the input image becomes the
# first one (H x W x C -> C x H x W).  This does NOT swap colour channels.
transformer.set_transpose('data', (2, 0, 1)) 
# The steps below are disabled.  Original author's note: no mean was
# subtracted during training, so none is subtracted here either.
# transformer.set_mean('data', np.load(mean_file).mean(1).mean(1))   
# transformer.set_raw_scale('data', 255)    # rescale to the [0, 255] range
# transformer.set_channel_swap('data', (2,1,0))   # swap channels, RGB -> BGR
# Tab-delimited label names, loaded as strings (not referenced again in the
# visible part of this script).
labels = np.loadtxt(labels_filename, str, delimiter='\t')
# Output sub-folder names, indexed by the predicted class index.
dirs = ['0_CA', '1_FV', '2_GB', '3_GA', '4_SV', '5_PY', '6_OT','7_IV']


# Classify every image found under the input tree and write a copy of it into
# a per-class folder below `path` (or into a fallback folder when the top
# probability does not exceed the confidence threshold).
imgnames = dirlist('D:\\2D', [])
path ='D:/sto_img_1695000/'

# Debug output for the first file: the part of its parent directory name
# before the first '_', then the full path.  Guarded so that an empty input
# tree no longer raises IndexError.
if imgnames:
    temp = imgnames[0]
    print(os.path.basename(os.path.dirname(temp)).split('_')[0])
    print(temp)

t = 0  # accumulated forward-pass time, in milliseconds
processed = 0  # number of images classified so far
for imgname in imgnames:
    image = cv2.imread(imgname)
    if image is None:
        # cv2.imread returns None when the file cannot be read or decoded.
        # NOTE(review): as in the original script, the unreadable file is
        # deleted from the input tree -- confirm this is really intended.
        print(imgname)
        os.remove(imgname)
        continue

    net.blobs['data_1'].data[...] = transformer.preprocess('data', image)
    t1 = cv2.getTickCount()
    net.forward()
    t += (cv2.getTickCount() - t1) * 1000 / cv2.getTickFrequency()
    processed += 1

    # Highest-scoring class of the 'softmax' blob for the single input image.
    prob = net.blobs['softmax'].data[0].flatten()
    order = prob.argmax()
    prob_max = prob[order]
    print('max = %f,class = %d,all = %d\n'%(prob_max,order,processed))

    # Confident predictions go to their class folder; everything else goes to
    # the fallback folder.  'unkown' (sic) is kept as-is so that existing
    # output trees remain valid.
    subdir = dirs[order] if prob_max > 0.70 else 'unkown'
    imgpath = path + subdir
    if not os.path.exists(imgpath):
        # makedirs (not mkdir) so a missing output root is created as well.
        os.makedirs(imgpath)
    cv2.imwrite(imgpath+'/'+os.path.basename(imgname), image)

    cv2.imshow('cv2', image)
    # Mask to the low byte: some platforms return key codes with high bits set.
    k = cv2.waitKey(1) & 0xFF
    if k == 27:  # Esc: stop processing
        break
    if k == 32:  # Space: pause until any key is pressed
        cv2.waitKey()

cv2.destroyAllWindows()

# Report the timing that was accumulated during the loop.
if processed:
    print('images = %d, mean forward time = %f ms' % (processed, t / processed))