1. 程式人生 > 基於深度學習的CT影象肺結節自動檢測技術六—模型預測

基於深度學習的CT影象肺結節自動檢測技術六—模型預測

# Model-prediction utilities

from chapter4 import get_unet
from chapter5 import get_3dnnnet, stack_2dcube_to_3darray, prepare_image_for_net3D, MEAN_PIXEL_VALUE
import glob
import cv2
import numpy as np
from scipy import ndimage
from scipy.ndimage.measurements import center_of_mass
from skimage import morphology

CHANNEL_COUNT = 1
_3DCNN_WEIGHTS = './model/3dcnn.hd5'
UNET_WEIGHTS = './model/unet.hd5'
THRESHOLD = 2
BATCH_SIZE = 1
# File that unet_predict() writes and the __main__ demo reads back.
UNET_RESULT_PATH = './temp_dir/chapter6/unet_result.png'


def unet_candidate_dicom(unet_result_path):
    """Return the centre point of every blob found in a U-Net result image.

    The image is read as grayscale, pixels below ``THRESHOLD`` are cleared,
    the remaining mask is dilated with a radius-1 disk, and the centre of
    mass of each connected component is computed.

    Args:
        unet_result_path: Path of the image to analyse.

    Returns:
        A list of ``[y_px, x_px]`` integer pairs (row first, column second),
        one per connected component; empty when no pixel passes the
        threshold.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``unet_result_path``.
    """
    image_t = cv2.imread(unet_result_path, cv2.IMREAD_GRAYSCALE)
    if image_t is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(
            'cannot read image: {}'.format(unet_result_path))

    # Thresholding: anything below THRESHOLD is background.
    image_t[image_t < THRESHOLD] = 0
    mask = image_t > 0

    # Dilation. The footprint is passed positionally because its keyword
    # name differs between scikit-image releases (selem -> footprint).
    dilated = morphology.binary_dilation(mask, morphology.disk(1))

    label_im, nb_labels = ndimage.label(dilated)
    centers = []
    for label in range(1, nb_labels + 1):
        blob = (label_im == label).astype(np.uint8)
        mass = ndimage.center_of_mass(blob)
        centers.append([int(round(mass[0])), int(round(mass[1]))])
    return centers


def prepare_image_for_net(img):
    """Scale an image to [0, 1] floats and shape it for the network.

    Args:
        img: NumPy array of pixel values; the division by 255 assumes an
            8-bit intensity range.

    Returns:
        A new float64 array. A 3-D input keeps its shape; any other input
        is reshaped to ``(1, img.shape[-2], img.shape[-1], 1)``, i.e. a 2-D
        image gains a leading and a trailing axis of length 1.
    """
    # np.float was removed in NumPy 1.24; float64 is the dtype it aliased.
    img = img.astype(np.float64)
    img /= 255.
    if img.ndim == 3:
        # Already three-dimensional: the shape is kept as it is.
        return img
    return img.reshape(1, img.shape[-2], img.shape[-1], 1)


def unet_predict(imagepath):
    """Run the U-Net on one image and save the prediction as a PNG.

    The model is built with ``get_unet()`` and loaded from
    ``UNET_WEIGHTS``. Each prediction is scaled by 255, cast to ``uint8``
    and written to ``UNET_RESULT_PATH``. The number of predictions is
    printed.

    Args:
        imagepath: Path of the grayscale input image.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``imagepath``.
        OSError: If the result image cannot be written.
    """
    model = get_unet()
    model.load_weights(UNET_WEIGHTS)

    # Read the PNG and get it ready for prediction.
    img = cv2.imread(imagepath, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError('cannot read image: {}'.format(imagepath))
    images3d = prepare_image_for_net(img)

    y_pred = model.predict(images3d, batch_size=BATCH_SIZE)
    print(len(y_pred))
    for y in y_pred:
        # The reshape drops the trailing axis, so it only succeeds when that
        # axis has length 1 (a single-channel prediction).
        result = (y * 255.).reshape((y.shape[0], y.shape[1]))
        # Every prediction goes to the same file; with the single input
        # image used here there is exactly one.
        if not cv2.imwrite(UNET_RESULT_PATH, result.astype(np.uint8)):
            raise OSError(
                'cannot write image: {}'.format(UNET_RESULT_PATH))


def _3dcnn_predict(imagepath):
    """Run the 3D CNN on one tiled cube image and print the raw result.

    Args:
        imagepath: Path of the image handed to ``stack_2dcube_to_3darray``.
    """
    # NOTE(review): 4, 8, 32 are presumably the tile layout and cube size
    # expected by stack_2dcube_to_3darray -- confirm against chapter5.
    cube_image = stack_2dcube_to_3darray(imagepath, 4, 8, 32)
    img3d = prepare_image_for_net3D(cube_image, MEAN_PIXEL_VALUE)
    model = get_3dnnnet(load_weight_path=_3DCNN_WEIGHTS)
    result = model.predict(img3d, batch_size=BATCH_SIZE, verbose=1)
    print('3dcnn result: ', result)


if __name__ == "__main__":
    unet_predict('./data/chapter6/unet_input_img.png')
    centers = unet_candidate_dicom(UNET_RESULT_PATH)
    print('y, x', centers)
    _3dcnn_predict('./data/chapter6/true_positive_nodules.png')
    _3dcnn_predict('./data/chapter6/false_positive_nodules.png')