
Unet Project Walkthrough (6): Image Patching and Recombination / Data Alignment / Converting Network Output to Images


1. Training Data

1.1 Random patch extraction from the training images and ground-truth masks

Main code:

# The training set is small, so the model is trained on randomly extracted patches
def get_data_training(DRIVE_train_imgs_original,  # path to the training images
                      DRIVE_train_groudTruth,     # path to the ground-truth images
                      patch_height,
                      patch_width,
                      N_subimgs,
                      inside_FOV):
    train_imgs_original = load_hdf5(DRIVE_train_imgs_original)
    train_masks = load_hdf5(DRIVE_train_groudTruth)
    #visualize(group_images(train_imgs_original[0:20,:,:,:],5),'imgs_train').show()

    train_imgs = my_PreProc(train_imgs_original) # image preprocessing (normalization, etc.)
    train_masks = train_masks/255.

    train_imgs = train_imgs[:,:,9:574,:]   # crop the images to 565x565
    train_masks = train_masks[:,:,9:574,:] # crop the masks to 565x565
    data_consistency_check(train_imgs,train_masks) # consistency check between training images and ground truth
    assert(np.min(train_masks)==0 and np.max(train_masks)==1) # the ground truth has two classes: 0 and 1

    print ("\n train images/masks shape:")
    print (train_imgs.shape)
    print ("train images range (min-max): " +str(np.min(train_imgs)) +' - '+str(np.max(train_imgs)))
    print ("train masks are within 0-1\n")

    # randomly extract training patches from the full images
    patches_imgs_train, patches_masks_train = extract_random(train_imgs, train_masks,
                                                             patch_height, patch_width,
                                                             N_subimgs, inside_FOV)
    data_consistency_check(patches_imgs_train, patches_masks_train) # consistency check between image patches and ground-truth patches

    print ("\n train PATCHES images/masks shape:")
    print (patches_imgs_train.shape)
    print ("train PATCHES images range (min-max): " +
           str(np.min(patches_imgs_train)) +' - '+str(np.max(patches_imgs_train)))

    return patches_imgs_train, patches_masks_train
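
For reference, a minimal usage sketch; the HDF5 paths and parameter values below are illustrative placeholders, not values taken from the project's configuration file:

# Hypothetical call (paths and parameter values are placeholders for illustration)
patches_imgs_train, patches_masks_train = get_data_training(
    DRIVE_train_imgs_original='DRIVE_datasets_training_testing/DRIVE_dataset_imgs_train.hdf5',
    DRIVE_train_groudTruth='DRIVE_datasets_training_testing/DRIVE_dataset_groundTruth_train.hdf5',
    patch_height=48,
    patch_width=48,
    N_subimgs=190000,
    inside_FOV=False)
# patches_imgs_train.shape  == (190000, 1, 48, 48)
# patches_masks_train.shape == (190000, 1, 48, 48)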

Random patch extraction:

# randomly extract patches from the training images
def extract_random(full_imgs,full_masks, patch_h,patch_w, N_patches, inside=True):
    if (N_patches%full_imgs.shape[0] != 0): # check that the patches divide evenly over the images
        print("N_patches: please enter a multiple of 20")
        exit()
    assert (len(full_imgs.shape)==4 and len(full_masks.shape)==4)  # check tensor rank
    assert (full_imgs.shape[1]==1 or full_imgs.shape[1]==3)  # check the channel dimension
    assert (full_masks.shape[1]==1)   # check the channel dimension
    assert (full_imgs.shape[2] == full_masks.shape[2] and full_imgs.shape[3] == full_masks.shape[3]) # check the spatial dimensions
    patches = np.empty((N_patches,full_imgs.shape[1],patch_h,patch_w)) # all training image patches
    patches_masks = np.empty((N_patches,full_masks.shape[1],patch_h,patch_w)) # all ground-truth patches
    img_h = full_imgs.shape[2]
    img_w = full_imgs.shape[3]

    patch_per_img = int(N_patches/full_imgs.shape[0])  # number of patches extracted per image
    print ("patches per full image: " +str(patch_per_img))
    iter_tot = 0   # counter for the total number of patches
    for i in range(full_imgs.shape[0]):  # iterate over every image
        k=0 # patch counter for the current image
        while k <patch_per_img:
            x_center = random.randint(0+int(patch_w/2),img_w-int(patch_w/2)) # valid range for the patch centre
            y_center = random.randint(0+int(patch_h/2),img_h-int(patch_h/2))

            if inside==True:
                if is_patch_inside_FOV(x_center,y_center,img_w,img_h,patch_h)==False:
                    continue

            patch = full_imgs[i,:,y_center-int(patch_h/2):y_center+int(patch_h/2),
                                  x_center-int(patch_w/2):x_center+int(patch_w/2)]
            patch_mask = full_masks[i,:,y_center-int(patch_h/2):y_center+int(patch_h/2),
                                        x_center-int(patch_w/2):x_center+int(patch_w/2)]
            patches[iter_tot]=patch # shape=[N_patches, channels, patch_h, patch_w]
            patches_masks[iter_tot]=patch_mask # shape=[N_patches, 1, patch_h, patch_w]
            iter_tot +=1   # total patch counter
            k+=1  # per-image patch counter
    return patches, patches_masks
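
The helper is_patch_inside_FOV used above is not listed in this post. As a rough sketch of the idea, assuming the DRIVE field of view is approximately a circle of radius 270 px centred on the image (the radius is an assumption, not stated here): a patch centre is accepted only if it lies deep enough inside that circle to contain the whole square patch.

# Sketch only: FOV check, assuming a circular field of view of radius ~270 px
def is_patch_inside_FOV(x, y, img_w, img_h, patch_h):
    x_ = x - int(img_w/2)   # shift the origin to the image centre
    y_ = y - int(img_h/2)
    # largest distance from the centre at which the full square patch still fits in the FOV
    R_inside = 270 - int(patch_h * np.sqrt(2.0) / 2.0)
    radius = np.sqrt(x_*x_ + y_*y_)
    return radius < R_inside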

Data consistency check function:

# consistency check between the training images and the ground-truth images
def data_consistency_check(imgs,masks):
    assert(len(imgs.shape)==len(masks.shape))
    assert(imgs.shape[0]==masks.shape[0])
    assert(imgs.shape[2]==masks.shape[2])
    assert(imgs.shape[3]==masks.shape[3])
    assert(masks.shape[1]==1)
    assert(imgs.shape[1]==1 or imgs.shape[1]==3)

1.2 Converting the training ground truth into the Unet output format

# convert the ground-truth masks into the model's output format
def masks_Unet(masks): # size=[Npatches, 1, patch_height, patch_width]
    assert (len(masks.shape)==4)
    assert (masks.shape[1]==1 )
    im_h = masks.shape[2]
    im_w = masks.shape[3]
    masks = np.reshape(masks,(masks.shape[0],im_h*im_w)) # flatten: one prediction per pixel
    new_masks = np.empty((masks.shape[0],im_h*im_w,2)) # two-class output
    for i in range(masks.shape[0]):
        for j in range(im_h*im_w):
            if  masks[i,j] == 0:
                new_masks[i,j,0]=1 # background channel (inverted ground truth)
                new_masks[i,j,1]=0 # vessel channel (ground truth)
            else:
                new_masks[i,j,0]=0
                new_masks[i,j,1]=1
    return new_masks
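
As a quick sanity check of the encoding, a toy example (values chosen purely for illustration, numpy imported as np as in the listings above):

# Toy example (illustrative values only): one 1x2x2 mask becomes a (1, 4, 2) one-hot tensor
toy = np.array([[[[0., 1.],
                  [1., 0.]]]])   # shape (1, 1, 2, 2)
encoded = masks_Unet(toy)        # shape (1, 4, 2)
# encoded[0] == [[1, 0],   pixel 0: background
#                [0, 1],   pixel 1: vessel
#                [0, 1],   pixel 2: vessel
#                [1, 0]]   pixel 3: background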

2. Converting the network output into image patches

# network output, size=[Npatches, patch_height*patch_width, 2]
def pred_to_imgs(pred, patch_height, patch_width, mode="original"):
    assert (len(pred.shape)==3)
    assert (pred.shape[2]==2 )  # confirm two-class output
    pred_images = np.empty((pred.shape[0],pred.shape[1]))  #(Npatches,height*width)
    if mode=="original": # raw probability output
        for i in range(pred.shape[0]):
            for pix in range(pred.shape[1]):
                pred_images[i,pix]=pred[i,pix,1] # pred[:, :, 0] is the background output, pred[:, :, 1] is the segmentation (vessel) output
    elif mode=="threshold": # thresholded probability output
        for i in range(pred.shape[0]):
            for pix in range(pred.shape[1]):
                if pred[i,pix,1]>=0.5:
                    pred_images[i,pix]=1
                else:
                    pred_images[i,pix]=0
    else:
        print ("mode " +str(mode) +" not recognized, it can be 'original' or 'threshold'")
        exit()
    # reshape to (Npatches, 1, height, width)
    pred_images = np.reshape(pred_images,(pred_images.shape[0],1, patch_height, patch_width))
    return pred_images
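
A quick shape check with random data (the softmax output here is faked purely to exercise the function):

# Illustrative shape check: 100 predicted 48x48 patches (random values, not real predictions)
pred = np.random.rand(100, 48*48, 2)
pred = pred / pred.sum(axis=2, keepdims=True)   # make the two class scores sum to 1
out = pred_to_imgs(pred, 48, 48, mode="threshold")
print(out.shape)   # (100, 1, 48, 48), values in {0, 1}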

3. Ordered patching of the test images and recombining the predicted patches into full images

3.1 Patching the test images

def get_data_testing_overlap(DRIVE_test_imgs_original,
                             DRIVE_test_groudTruth,
                             Imgs_to_test, # 20
                             patch_height,
                             patch_width,
                             stride_height,
                             stride_width):
    test_imgs_original = load_hdf5(DRIVE_test_imgs_original)
    test_masks = load_hdf5(DRIVE_test_groudTruth)

    test_imgs = my_PreProc(test_imgs_original)
    test_masks = test_masks/255.

    test_imgs = test_imgs[0:Imgs_to_test,:,:,:]
    test_masks = test_masks[0:Imgs_to_test,:,:,:]

    test_imgs = paint_border_overlap(test_imgs, patch_height, # pad the images so they divide exactly into patches
                                     patch_width, stride_height, stride_width)
    assert(np.max(test_masks)==1  and np.min(test_masks)==0)

    print ("\n test images shape:")
    print (test_imgs.shape)
    print ("\n test mask shape:")
    print (test_masks.shape)
    print ("test images range (min-max): " +str(np.min(test_imgs)) +' - '+str(np.max(test_imgs)))

    # extract patches in order, so the full images can be reconstructed afterwards (the author uses an overlap strategy)
    patches_imgs_test = extract_ordered_overlap(test_imgs,patch_height,patch_width,stride_height,stride_width)
    print ("\n test PATCHES images shape:")
    print (patches_imgs_test.shape)
    print ("test PATCHES images range (min-max): " +
           str(np.min(patches_imgs_test)) +' - '+str(np.max(patches_imgs_test)))

    return patches_imgs_test, test_imgs.shape[2], test_imgs.shape[3], test_masks # padded (new) height and width
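
A minimal usage sketch; the paths and parameter values are placeholders chosen for illustration:

# Hypothetical call (paths and parameter values are placeholders)
patches_imgs_test, new_height, new_width, masks_test = get_data_testing_overlap(
    DRIVE_test_imgs_original='DRIVE_datasets_training_testing/DRIVE_dataset_imgs_test.hdf5',
    DRIVE_test_groudTruth='DRIVE_datasets_training_testing/DRIVE_dataset_groundTruth_test.hdf5',
    Imgs_to_test=20,
    patch_height=48,
    patch_width=48,
    stride_height=10,
    stride_width=10)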

Padding the original images:

def paint_border_overlap(full_imgs, patch_h, patch_w, stride_h, stride_w):
    assert (len(full_imgs.shape)==4)  #4D arrays
    assert (full_imgs.shape[1]==1 or full_imgs.shape[1]==3)  #check the channel is 1 or 3
    img_h = full_imgs.shape[2]  #height of the full image
    img_w = full_imgs.shape[3] #width of the full image
    leftover_h = (img_h-patch_h)%stride_h  #leftover on the h dim
    leftover_w = (img_w-patch_w)%stride_w  #leftover on the w dim
    if (leftover_h != 0):  #change dimension of img_h
        tmp_full_imgs = np.zeros((full_imgs.shape[0],full_imgs.shape[1],img_h+(stride_h-leftover_h),img_w))
        tmp_full_imgs[0:full_imgs.shape[0],0:full_imgs.shape[1],0:img_h,0:img_w] = full_imgs
        full_imgs = tmp_full_imgs
    if (leftover_w != 0):   #change dimension of img_w
        tmp_full_imgs = np.zeros((full_imgs.shape[0],full_imgs.shape[1],full_imgs.shape[2],img_w+(stride_w - leftover_w)))
        tmp_full_imgs[0:full_imgs.shape[0],0:full_imgs.shape[1],0:full_imgs.shape[2],0:img_w] = full_imgs
        full_imgs = tmp_full_imgs
    return full_imgs
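
Worked example of the padding arithmetic, assuming 584x565 DRIVE test images, 48x48 patches and a 10x10 stride (the patch size and stride are assumptions; the resulting 588x568 size matches the padded dimensions quoted in section 3.2):

# Padding arithmetic under the assumed settings (584x565 images, 48x48 patches, 10x10 stride)
img_h, img_w = 584, 565
patch_h, patch_w = 48, 48
stride_h, stride_w = 10, 10
leftover_h = (img_h - patch_h) % stride_h   # 6  -> pad the height by 10 - 6 = 4
leftover_w = (img_w - patch_w) % stride_w   # 7  -> pad the width  by 10 - 7 = 3
new_h = img_h + (stride_h - leftover_h)     # 588
new_w = img_w + (stride_w - leftover_w)     # 568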

Extracting image patches in order:

# sample patches from the padded images in order
def extract_ordered_overlap(full_imgs, patch_h, patch_w,stride_h,stride_w):
    assert (len(full_imgs.shape)==4)
    assert (full_imgs.shape[1]==1 or full_imgs.shape[1]==3)
    img_h = full_imgs.shape[2]
    img_w = full_imgs.shape[3]
    assert ((img_h-patch_h)%stride_h==0 and (img_w-patch_w)%stride_w==0)
    N_patches_img = ((img_h-patch_h)//stride_h+1)*((img_w-patch_w)//stride_w+1)  # number of patches extracted from each image
    N_patches_tot = N_patches_img*full_imgs.shape[0] # total number of patches over the test set
    patches = np.empty((N_patches_tot,full_imgs.shape[1],patch_h,patch_w))
    iter_tot = 0   
    for i in range(full_imgs.shape[0]):  
        for h in range((img_h-patch_h)//stride_h+1):
            for w in range((img_w-patch_w)//stride_w+1):
                patch = full_imgs[i,:,h*stride_h:(h*stride_h)+patch_h,w*stride_w:(w*stride_w)+patch_w]
                patches[iter_tot]=patch
                iter_tot +=1   #total
    assert (iter_tot==N_patches_tot)
    return patches 
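
Under the same assumed settings, each padded 588x568 image yields 55 x 53 patches:

# Patch count for one padded 588x568 image (assumed settings as above)
N_patches_h = (588 - 48)//10 + 1            # 55 patches along the height
N_patches_w = (568 - 48)//10 + 1            # 53 patches along the width
N_patches_img = N_patches_h * N_patches_w   # 2915 patches per image
# for 20 test images: 20 * 2915 = 58300 patches in total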

3.2 Recombining the image patches

# [Npatches, 1, patch_h, patch_w]  img_h=new_height[588] img_w=new_width[568] stride=[10,10]
def recompone_overlap(preds, img_h, img_w, stride_h, stride_w):
    assert (len(preds.shape)==4)  # check tensor rank
    assert (preds.shape[1]==1 or preds.shape[1]==3)
    patch_h = preds.shape[2]
    patch_w = preds.shape[3]
    N_patches_h = (img_h-patch_h)//stride_h+1 # number of patches along img_h
    N_patches_w = (img_w-patch_w)//stride_w+1 # number of patches along img_w
    N_patches_img = N_patches_h * N_patches_w # number of patches per image
    assert (preds.shape[0]%N_patches_img==0)
    N_full_imgs = preds.shape[0]//N_patches_img # number of full images
    full_prob = np.zeros((N_full_imgs,preds.shape[1],img_h,img_w))
    full_sum = np.zeros((N_full_imgs,preds.shape[1],img_h,img_w))

    k = 0 # iterate over all patches
    for i in range(N_full_imgs):
        for h in range((img_h-patch_h)//stride_h+1):
            for w in range((img_w-patch_w)//stride_w+1):
                full_prob[i,:,h*stride_h:(h*stride_h)+patch_h,w*stride_w:(w*stride_w)+patch_w]+=preds[k]
                full_sum[i,:,h*stride_h:(h*stride_h)+patch_h,w*stride_w:(w*stride_w)+patch_w]+=1
                k+=1
    assert(k==preds.shape[0])
    assert(np.min(full_sum)>=1.0)
    final_avg = full_prob/full_sum # accumulated probability / accumulation count: average over the overlapping patches
    print(final_avg.shape)
    assert(np.max(final_avg)<=1.0)
    assert(np.min(final_avg)>=0.0)
    return final_avg
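
Note that the recombined images still have the padded size (588x568 in the assumed example above), so a final crop back to the original dimensions is needed. A minimal sketch, assuming new_height/new_width hold the padded size and full_img_height/full_img_width the original one (these variable names are hypothetical):

# Sketch only: crop the recombined predictions back to the original image size
pred_imgs = recompone_overlap(predictions, new_height, new_width, stride_height, stride_width)
pred_imgs = pred_imgs[:, :, 0:full_img_height, 0:full_img_width]   # drop the padded border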