1. 程式人生 > caffe Python API 之Inference

caffe Python API 之Inference

# 以SSD的檢測測試為例 (Example: running inference for SSD object detection)
def detetion(image_dir, weight, deploy, resolution=300, conf_threshold=None,
             target_dir="det_results"):
    """Run SSD detection on every file in a directory and save annotated copies.

    Each entry of ``image_dir`` is loaded via ``cut(path, resolution)``
    (defined elsewhere; presumably returns a ``resolution`` x ``resolution``
    BGR image -- TODO confirm), pushed through the network, and every row of
    the ``detection_out`` blob is drawn as a green rectangle on that image.
    The annotated image is written to ``target_dir`` under the same file name.

    Args:
        image_dir: Directory whose entries are all treated as input images.
        weight: Path to the trained weights (``.caffemodel``).
        deploy: Path to the deploy network definition (``.prototxt``).
        resolution: Side length in pixels of the square network input; the
            detection coordinates are multiplied by this value.
        conf_threshold: If not None, detections whose confidence is below
            this value are skipped. The default (None) draws every detection.
        target_dir: Output directory; created (including parents) if missing.
    """
    caffe.set_mode_gpu()
    # pycaffe's Net takes the network definition first and the weights second.
    net = caffe.Net(deploy, weight, caffe.TEST)

    # Reshape the input blob BEFORE building the Transformer: the Transformer
    # records the blob shape at construction and resizes every input to it,
    # so it must see the final (1, 3, resolution, resolution) shape. The shape
    # is identical for every image, so this is done once, not per image.
    net.blobs['data'].reshape(1, 3, resolution, resolution)
    transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
    transformer.set_transpose('data', (2, 0, 1))  # HWC -> CHW
    transformer.set_mean('data', np.array([104, 117, 123]))  # per-channel mean pixel

    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    for image in os.listdir(image_dir):
        image_path = os.path.join(image_dir, image)
        target_path = os.path.join(target_dir, image)
        croped = cut(image_path, resolution)
        net.blobs['data'].data[...] = transformer.preprocess('data', croped)

        start = time.time()
        net.forward()
        end = time.time()
        # Keep fractional seconds; truncating to int reported 0 for fast passes.
        print("Forward time is {:.4f} s.".format(end - start))

        # Columns 1..6 are label, conf, xmin, ymin, xmax, ymax (column 0 is
        # presumably the image index). Flatten to 2-D explicitly instead of
        # np.squeeze, which collapses a single-row result to 1-D and makes
        # the loop below iterate over scalars.
        out_put = net.blobs["detection_out"].data
        out_put = out_put.reshape(-1, out_put.shape[-1])
        for box in out_put:
            # NOTE(review): a negative label is presumably a "no detection"
            # placeholder row emitted by the detection layer -- skip it.
            if box[1] < 0:
                continue
            conf = box[2]
            if conf_threshold is not None and conf < conf_threshold:
                continue
            # Coordinates are scaled by ``resolution``, i.e. treated as
            # fractions of the input side; negatives are clamped to 0.
            xmin = int(box[3] * resolution) if box[3] > 0 else 0
            ymin = int(box[4] * resolution) if box[4] > 0 else 0
            xmax = int(box[5] * resolution) if box[5] > 0 else 0
            ymax = int(box[6] * resolution) if box[6] > 0 else 0
            cv2.rectangle(croped, (xmin, ymin), (xmax, ymax), (0, 255, 0), 1)
        cv2.imwrite(target_path, croped)
        print(target_path)