程式人生 > 基於畫素清晰度的影像融合演算法(Python實現)

基於畫素清晰度的影像融合演算法(Python實現)

# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import cv2
from math import log
from PIL import Image
import datetime
import pywt

# Module-wide settings, kept as plain globals and used the way C code would use macros.
# Window half-size read by getContrastImg when it sums the per-pixel contrast.
halfWindowSize = 9
# Default source image paths used by runTest (raw strings, same values as before).
src1_path = r'F:\Python\try\BasicImageOperation\disk1.jpg'
src2_path = r'F:\Python\try\BasicImageOperation\disk2.jpg'

'''
來自敬忠良、肖剛、李振華《圖像融合——理論與分析》P85:基於畫素清晰度的融合規則
1. 用 Laplace 金字塔或小波變換,將影像分解成高頻部分和低頻部分兩個影像矩陣
2. 以某個畫素點為中心開窗,該畫素點的清晰度定義為視窗內所有點 (高頻/低頻)**2 之和
3. 目前感覺主要的問題在於低頻部分的融合規則
4. 高頻部分逐點取清晰度較大的那幅圖的高頻畫素值
5. 演算法優化後,執行時間由原來的 2 min 44 s 縮短為 9 s 305 ms
補充:書上建議開窗大小 10*10、DWT 取 3 層、Laplace 金字塔取 2 層(本程式目前使用 halfWindowSize=9、DWT 4 層、Laplace 金字塔 1 層)
'''

def imgOpen(img_src1,img_src2):
    """Load two image files as 8-bit grayscale numpy arrays.

    Args:
        img_src1: path of the first source image.
        img_src2: path of the second source image.

    Returns:
        (appleArray, orangeArray): the two images converted through PIL mode 'L'.
    """
    # ``with`` closes the file handles that Image.open keeps open, instead of leaving
    # them to garbage collection.
    with Image.open(img_src1) as apple, Image.open(img_src2) as orange:
        appleArray = np.array(apple.convert('L'))
        orangeArray = np.array(orange.convert('L'))
    return appleArray, orangeArray

# Resize one array to exactly the shape of another.
def _sameSize(img_std,img_cvt):
    """Return ``img_cvt`` resized (through PIL) to the (rows, cols) shape of ``img_std``.

    Args:
        img_std: reference array; its first two dimensions give the target shape.
        img_cvt: array to resize. It goes through PIL.Image.fromarray, so its dtype must
            be one PIL can represent.

    Returns:
        A numpy array converted back from the resized PIL image.
    """
    rows, cols = img_std.shape[:2]
    picture = Image.fromarray(img_cvt)
    # PIL sizes are (width, height) == (cols, rows), and resize() returns a new image
    # rather than resizing in place. The old code passed (rows, cols) and discarded the
    # result, so the array came back unresized.
    resized = picture.resize((cols, rows))
    return np.array(resized)

# Wavelet-domain fusion. The original author noted that a high decomposition level yields
# very small coefficient arrays, and that arrays of different sizes ran out of bounds in
# the contrast step. Each level is now trimmed to its detail shape, so no resizing is needed.
def getWaveImg(apple,orange,wavelet='haar',level=4):
    """Fuse two grayscale images with a multi-level 2-D wavelet decomposition.

    Low frequency: weighted average of the two coarsest approximations, with weights
    from getVarianceWeight. High frequency: from the coarsest level to the finest, each
    detail coefficient is taken from the source with the larger local contrast
    (getContrastImg; ties go to ``orange``). The result is then reconstructed one level
    at a time with pywt.idwt2.

    Args:
        apple: first source image; presumably a 2-D array of the same shape as
            ``orange`` (TODO confirm for callers other than runTest).
        orange: second source image.
        wavelet: wavelet name passed to pywt (default 'haar', as before).
        level: number of decomposition levels (default 4, as before).

    Returns:
        The fused image as a float array, cropped to ``apple``'s rows and columns.
    """
    appleWave = pywt.wavedec2(apple, wavelet, level=level)
    orangeWave = pywt.wavedec2(orange, wavelet, level=level)
    lowApple = appleWave[0]
    lowOrange = orangeWave[0]
    # Low frequency: weighted average of the two approximations.
    lowAppleWeight, lowOrangeWeight = getVarianceWeight(lowApple, lowOrange)
    lowFusion = lowAppleWeight * lowApple + lowOrangeWeight * lowOrange
    # High frequency: wavedec2 lists detail tuples from the coarsest level to the finest.
    for appleDetails, orangeDetails in zip(appleWave[1:], orangeWave[1:]):
        rows, cols = appleDetails[0].shape
        # If the previous level had an odd size, idwt2 returns one row/column more than
        # this level's details. Trim it the same way pywt.waverec2 does; the slices are
        # no-ops when the shapes already agree.
        lowFusion = lowFusion[:rows, :cols]
        lowApple = lowApple[:rows, :cols]
        lowOrange = lowOrange[:rows, :cols]
        fusedDetails = []
        for highApple, highOrange in zip(appleDetails, orangeDetails):
            contrastApple = getContrastImg(lowApple, highApple)
            contrastOrange = getContrastImg(lowOrange, highOrange)
            # Keep the sharper source's coefficient; ties go to orange.
            fusedDetails.append(np.where(contrastApple > contrastOrange, highApple, highOrange))
        lowFusion = pywt.idwt2((lowFusion, tuple(fusedDetails)), wavelet)
        # Finer levels measure contrast against the fused approximation for both sources.
        lowApple = lowOrange = lowFusion
    # An odd-sized input reconstructs one pixel too large; crop back to the input size.
    srcRows, srcCols = np.shape(apple)[:2]
    return lowFusion[:srcRows, :srcCols]

# One-level Laplacian pyramid decomposition.
def getLaplacePyr(img):
    """Split an image into a low band and a signed high (detail) band.

    Args:
        img: source image array (e.g. the uint8 arrays returned by imgOpen).

    Returns:
        (lowFreq, highFreq): float64 arrays with the same shape as ``img``.
        lowFreq is pyrUp(pyrDown(img)) and highFreq = img - lowFreq, so
        lowFreq + highFreq reproduces ``img``.
    """
    # Work in float64. With uint8, cv2.subtract saturates negative detail to 0, which
    # threw away every pixel darker than its blurred neighbourhood.
    firstLevel = np.asarray(img, dtype=np.float64)
    rows, cols = firstLevel.shape[:2]
    secondLevel = cv2.pyrDown(firstLevel)
    # dstsize is (width, height). It keeps the expanded image at the input size even for
    # odd dimensions, where a plain pyrUp would return 2 * ceil(n / 2) pixels.
    lowFreq = cv2.pyrUp(secondLevel, dstsize=(cols, rows))
    highFreq = firstLevel - lowFreq
    return lowFreq, highFreq

# 計算對比度,優化後不需要這個函數了,扔在這裡看看公式就行
def _getContrastValue(highWin,lowWin):
    row,col = highWin.shape
    contrastValue = 0.00
    for i in xrange(row):
        for j in xrange(col):
            contrastValue += (float(highWin[i,j])/lowWin[i,j])**2
    return contrastValue

# Per-pixel local contrast: compute (high / low) ** 2 once for the whole image, then sum
# it over a window around every pixel with vectorised numpy slicing.
def getContrastImg(low,high,halfWindow=None):
    """Return the local-contrast (sharpness) map of a high-frequency band.

    Each output pixel (i, j) is the sum of ``(high / low) ** 2`` over rows
    ``i - halfWindow .. i + halfWindow`` and columns ``j - halfWindow .. j + halfWindow``,
    clipped at the image border. The window is therefore centred on the pixel; the old
    code stopped one row/column short on the far side.

    Args:
        low: low-frequency image. It is resized with _sameSize when its shape differs
            from ``high``.
        high: high-frequency image; the result has this shape.
        halfWindow: window half-size. None (the default) reads the module-level
            halfWindowSize at call time, as before.

    Returns:
        A float64 array with ``high``'s shape. Where ``low`` is 0 the ratio is inf or
        nan (numpy emits a RuntimeWarning), and every window sum containing such a value
        is inf or nan too.

    Raises:
        ValueError: if halfWindow is negative.
    """
    if halfWindow is None:
        halfWindow = halfWindowSize
    if halfWindow < 0:
        raise ValueError('halfWindow must be non-negative')
    if low.shape != high.shape:
        low = _sameSize(high, low)
    # Size the output from high. The old code read low.shape *before* resizing low.
    rows, cols = high.shape
    # Force float division: integer (e.g. uint8) inputs would otherwise floor-divide
    # under Python 2 and wrap around when squared.
    ratio = np.asarray(high, dtype=np.float64) / np.asarray(low, dtype=np.float64)
    contrastVal = ratio ** 2
    # Zero padding makes every window full-sized without adding anything, so a padded
    # window sums exactly what the border-clipped window did. The window is summed
    # separably: first over rows, then over columns.
    span = 2 * halfWindow + 1
    padded = np.pad(contrastVal, halfWindow, mode='constant')
    vertical = sum(padded[k:k + rows, :] for k in range(span))
    return sum(vertical[:, k:k + cols] for k in range(span))

# Blending weights from each image's spread (used for the wavelet low band).
def getVarianceWeight(apple,orange):
    """Return blending weights proportional to each image's standard deviation.

    Despite the name, the spread measure is the population standard deviation (the value
    cv2.meanStdDev reports, which the original code used), not the variance.

    Args:
        apple: first image / coefficient array.
        orange: second image / coefficient array.

    Returns:
        (appleWeight, orangeWeight): Python floats that sum to 1. If both inputs are
        constant (both deviations 0), the weights are 0.5 each instead of NaN.
    """
    # Plain floats replace the (1, 1) arrays cv2.meanStdDev returned. Calling float()
    # on those arrays is deprecated in NumPy, and it left the weights as arrays.
    appleStd = float(np.std(np.asarray(apple, dtype=np.float64)))
    orangeStd = float(np.std(np.asarray(orange, dtype=np.float64)))
    total = appleStd + orangeStd
    if total == 0:
        return 0.5, 0.5
    return appleStd / total, orangeStd / total

# Laplacian-pyramid fusion; returns the fused image matrix.
def getPyrFusion(apple,orange):
    """Fuse two grayscale images using a one-level Laplacian pyramid.

    The low bands are merged by taking the per-pixel minimum. For the high bands, each
    pixel keeps the detail of the source with the larger local contrast
    (getContrastImg), with ties going to ``orange``. The fused image is the sum of the
    two fused bands.

    Args:
        apple: first source image; presumably a 2-D grayscale array as produced by
            imgOpen, the same shape as ``orange`` (TODO confirm for other callers).
        orange: second source image.

    Returns:
        A float64 array: fused high band + fused low band.
    """
    lowApple, highApple = getLaplacePyr(apple)
    lowOrange, highOrange = getLaplacePyr(orange)
    contrastApple = getContrastImg(lowApple, highApple)
    contrastOrange = getContrastImg(lowOrange, highOrange)
    # Low band: per-pixel minimum, vectorised instead of a Python double loop
    # (xrange does not exist on Python 3). A weighted average using getVarianceWeight
    # was tried here by the original author and left disabled.
    lowFusion = np.where(lowApple < lowOrange, lowApple, lowOrange).astype(np.float64)
    # High band: keep the coefficient from the sharper source at each pixel.
    highFusion = np.where(contrastApple > contrastOrange, highApple, highOrange).astype(np.float64)
    # Reconstruct. For float64 inputs this equals cv2.add, which only saturates integers.
    return highFusion + lowFusion

# Plotting helper: both sources and one result in a single row.
def getPlot(apple, orange, result):
    """Display src1, src2 and the fused result side by side as grayscale images."""
    panels = ((apple, 'src1'), (orange, 'src2'), (result, 'result'))
    # Subplot codes 131, 132, 133 = one row, three columns, panels 1..3.
    for position, (image, label) in enumerate(panels, 131):
        plt.subplot(position)
        plt.imshow(image, cmap='gray')
        plt.title(label)
        plt.axis('off')
    plt.show()

# Plot four images at once so the two fusion methods can be compared directly.
def cmpPlot(apple, orange, wave, pyr):
    """Show both sources and both fusion results in a 2x2 grayscale grid."""
    panels = [
        (apple, 'SRC1'),
        (orange, 'SRC2'),
        (wave, 'WAVELET'),
        (pyr, 'LAPLACE PYR'),
    ]
    # Subplot codes 221..224 = two rows, two columns, filled row by row.
    for offset, (picture, caption) in enumerate(panels):
        plt.subplot(221 + offset)
        plt.imshow(picture, cmap='gray')
        plt.title(caption)
        plt.axis('off')
    plt.show()

def runTest(src1=src1_path, src2=src2_path, isplot=True):
    """Run both fusion methods on two images, print timing, and optionally plot.

    Args:
        src1: first image path (defaults to the module-level src1_path).
        src2: second image path (defaults to the module-level src2_path).
        isplot: when true, show the sources and both results with cmpPlot.

    Returns:
        (waveResult, pyrResult) from getWaveImg and getPyrFusion.
    """
    apple, orange = imgOpen(src1, src2)
    # Only the two fusion calls are timed; image loading happens before the clock starts.
    started = datetime.datetime.now()
    print(started)
    fused = (getWaveImg(apple, orange), getPyrFusion(apple, orange))
    finished = datetime.datetime.now()
    print(finished)
    print('Runtime: ' + str(finished - started))
    if isplot:
        cmpPlot(apple, orange, *fused)
    return fused

# Script entry point: fuse the default image pair and plot the comparison.
if __name__ == '__main__':
    runTest()

該寫的都寫在註釋裡了