基於畫素清晰度的影象融合演算法(Python實現)
阿新 • • 發佈:2019-01-02
# -*- coding: utf-8 -*-
"""Image fusion driven by a per-pixel clarity (contrast) measure.

Notes translated from the original author:

Source: Jing Zhongliang, Xiao Gang, Li Zhenhua, "Image Fusion - Theory and
Analysis", p. 85: fusion rule based on pixel clarity.

1. A Laplacian pyramid or a wavelet transform splits each image into a
   high-frequency and a low-frequency matrix.
2. A window is opened around each pixel; the clarity of that pixel is
   ``((high / low) ** 2).sum()`` over every point of the window.
3. The main remaining weakness appears to be the low-frequency rule.
4. The fused high-frequency pixel is taken from whichever source image has
   the larger clarity at that position.
5. The author's first optimisation cut the run time from 2 min 44 s to
   9 s 305 ms (measured by the author, not re-measured here).

Addendum: the book suggests a 10*10 window, 3 DWT levels and 2 Laplacian
pyramid levels.
"""
import datetime
from math import log

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pywt
from PIL import Image

# Module-level settings, used the way C code would use macros.
halfWindowSize = 9
src1_path = 'F:\\Python\\try\\BasicImageOperation\\disk1.jpg'
src2_path = 'F:\\Python\\try\\BasicImageOperation\\disk2.jpg'


def _loadGray(path):
    """Read the image at *path* and return it as a 2-D grayscale array."""
    # The context manager closes the underlying file once the pixels are copied.
    with Image.open(path) as img:
        return np.array(img.convert('L'))


def imgOpen(img_src1, img_src2):
    """Load both source images as grayscale arrays.

    Args:
        img_src1: path of the first image.
        img_src2: path of the second image.

    Returns:
        Tuple ``(array1, array2)`` of 2-D arrays, one per image.
    """
    return _loadGray(img_src1), _loadGray(img_src2)


def _requireSameShape(apple, orange):
    """Raise ValueError unless both source images have the same shape."""
    if apple.shape != orange.shape:
        raise ValueError(
            'source images must have the same shape, got %s and %s'
            % (apple.shape, orange.shape))


# Strict size conversion.
def _sameSize(img_std, img_cvt):
    """Return *img_cvt* resampled to the shape of *img_std*.

    Args:
        img_std: 2-D array whose shape is the target.
        img_cvt: 2-D array to resize.

    Returns:
        *img_cvt* itself when the shapes already match, otherwise a resized
        copy produced by PIL.
    """
    if img_cvt.shape == img_std.shape:
        return img_cvt
    rows, cols = img_std.shape
    # PIL sizes are (width, height) == (cols, rows), and resize() returns a
    # new image instead of modifying the receiver, so the result must be kept.
    resized = Image.fromarray(img_cvt).resize((cols, rows))
    return np.array(resized)


def _windowSum(values, half):
    """Sum *values* over a clipped window around every pixel.

    The window of pixel (i, j) covers rows ``[i - half, i + half)`` and
    columns ``[j - half, j + half)``, clipped to the array bounds.  A
    summed-area table makes each window sum four lookups instead of a loop.
    """
    rows, cols = values.shape
    # table[r, c] holds the sum of values[:r, :c]; row/column 0 stay zero.
    table = np.zeros((rows + 1, cols + 1))
    table[1:, 1:] = values.cumsum(axis=0).cumsum(axis=1)

    rowIdx = np.arange(rows)
    colIdx = np.arange(cols)
    up = np.maximum(rowIdx - half, 0)
    down = np.minimum(rowIdx + half, rows)
    left = np.maximum(colIdx - half, 0)
    right = np.minimum(colIdx + half, cols)

    return (table[np.ix_(down, right)] - table[np.ix_(up, right)]
            - table[np.ix_(down, left)] + table[np.ix_(up, left)])


# The original author notes that the number of wavelet levels must not be
# too high: very small coefficient matrices could not be resized, and
# matrices of different sizes broke the contrast computation.
def getWaveImg(apple, orange, level=4):
    """Fuse two grayscale images with a Haar wavelet decomposition.

    The coarsest approximations are blended with the weights returned by
    ``getVarianceWeight``.  At every detail level each coefficient is taken
    from the image whose clarity (see ``getContrastImg``) is larger, then
    one inverse transform step produces the approximation used by the next
    finer level.

    Args:
        apple: first image, 2-D array.
        orange: second image, 2-D array with the same shape.
        level: number of decomposition levels (default 4, as before).

    Returns:
        2-D float array holding the fused image.

    Raises:
        ValueError: if the two images do not have the same shape.
    """
    _requireSameShape(apple, orange)
    appleWave = pywt.wavedec2(apple, 'haar', level=level)
    orangeWave = pywt.wavedec2(orange, 'haar', level=level)
    lowApple, lowOrange = appleWave[0], orangeWave[0]

    # Low frequency: weighted blend of the two coarsest approximations.
    appleWeight, orangeWeight = getVarianceWeight(lowApple, lowOrange)
    lowFusion = appleWeight * lowApple + orangeWeight * lowOrange

    # High frequency: walk from the coarsest detail level to the finest.
    for appleDetails, orangeDetails in zip(appleWave[1:], orangeWave[1:]):
        rows, cols = appleDetails[0].shape
        # For odd sizes the previous inverse step yields one extra row or
        # column; crop so the approximation lines up with this level.
        lowFusion = lowFusion[:rows, :cols]
        lowApple = lowApple[:rows, :cols]
        lowOrange = lowOrange[:rows, :cols]

        fusedDetails = []
        for highApple, highOrange in zip(appleDetails, orangeDetails):
            contrastApple = getContrastImg(lowApple, highApple)
            contrastOrange = getContrastImg(lowOrange, highOrange)
            fusedDetails.append(
                np.where(contrastApple > contrastOrange, highApple, highOrange))

        lowFusion = pywt.idwt2((lowFusion, tuple(fusedDetails)), 'haar')
        # From here on both images share the fused approximation.
        lowApple = lowOrange = lowFusion
    return lowFusion


# Build a one-level Laplacian pyramid.
def getLaplacePyr(img):
    """Split *img* into a low-frequency and a high-frequency band.

    Args:
        img: 2-D grayscale array.

    Returns:
        Tuple ``(lowFreq, highFreq)`` of float64 arrays with the shape of
        *img*, such that ``lowFreq + highFreq`` equals *img*.
    """
    # Work in floating point: the high band is signed, and a saturating
    # uint8 subtraction would clip every negative value to 0.
    base = np.asarray(img, dtype=np.float64)
    rows, cols = base.shape
    # dstsize pins the upsampled size to the input size, so both bands have
    # identical shapes even when a dimension is odd.
    lowFreq = cv2.pyrUp(cv2.pyrDown(base), dstsize=(cols, rows))
    highFreq = base - lowFreq
    return lowFreq, highFreq


# Reference implementation of the contrast formula for a single window.
# getContrastImg no longer calls it; it is kept to show the formula.
def _getContrastValue(highWin, lowWin):
    """Return the sum of ``(high / low) ** 2`` over one window."""
    ratio = np.asarray(highWin, dtype=np.float64) / lowWin
    return (ratio ** 2).sum()


# Compute (high / low) ** 2 for every pixel first, then sum it per window.
def getContrastImg(low, high, halfWindow=None):
    """Return the clarity map of a low/high frequency pair.

    The clarity of pixel (i, j) is the sum of ``(high / low) ** 2`` over
    rows ``[i - halfWindow, i + halfWindow)`` and columns
    ``[j - halfWindow, j + halfWindow)``, clipped to the image.

    Args:
        low: low-frequency band, 2-D array.
        high: high-frequency band, 2-D array; its shape defines the output.
        halfWindow: half of the window side; defaults to the module-level
            ``halfWindowSize``.

    Returns:
        2-D float array with the shape of *high*.
    """
    if halfWindow is None:
        halfWindow = halfWindowSize
    high = np.asarray(high, dtype=np.float64)
    if low.shape != high.shape:
        low = _sameSize(high, low)
    low = np.asarray(low, dtype=np.float64)

    # Pixels whose low-frequency value is 0 contribute 0 instead of inf/nan,
    # which would otherwise poison every window that contains them.
    ratio = np.zeros(high.shape)
    np.divide(high, low, out=ratio, where=(low != 0))
    return _windowSum(ratio ** 2, halfWindow)


# Compute the weight ratio used to blend the low-frequency bands.
def getVarianceWeight(apple, orange):
    """Return blending weights proportional to each image's spread.

    The spread is the standard deviation reported by ``cv2.meanStdDev``.

    Args:
        apple: first 2-D array.
        orange: second 2-D array.

    Returns:
        Tuple ``(appleWeight, orangeWeight)`` of floats that sum to 1.  Both
        are 0.5 when the two standard deviations are zero.
    """
    # meanStdDev returns (mean, stddev); stddev is a 1x1 array for a
    # single-channel input, so pull out the scalar explicitly.
    appleStd = float(cv2.meanStdDev(apple)[1][0, 0])
    orangeStd = float(cv2.meanStdDev(orange)[1][0, 0])
    total = appleStd + orangeStd
    if total == 0:
        # Two constant images: there is nothing to prefer, split evenly.
        return 0.5, 0.5
    return appleStd / total, orangeStd / total


# Returns the fused image matrix.
def getPyrFusion(apple, orange):
    """Fuse two grayscale images with a one-level Laplacian pyramid.

    Args:
        apple: first image, 2-D array.
        orange: second image, 2-D array with the same shape.

    Returns:
        2-D float array holding the fused image.

    Raises:
        ValueError: if the two images do not have the same shape.
    """
    _requireSameShape(apple, orange)
    lowApple, highApple = getLaplacePyr(apple)
    lowOrange, highOrange = getLaplacePyr(orange)
    contrastApple = getContrastImg(lowApple, highApple)
    contrastOrange = getContrastImg(lowOrange, highOrange)

    # Low frequency: keep the smaller of the two values at each pixel.  A
    # blend weighted by getVarianceWeight is the alternative rule the
    # original author experimented with.
    lowFusion = np.where(lowApple < lowOrange, lowApple, lowOrange)
    # High frequency: keep the pixel from the image with the larger clarity.
    highFusion = np.where(contrastApple > contrastOrange, highApple, highOrange)
    # Reconstruction: the two bands simply add back together.
    return highFusion + lowFusion
# Plotting helpers.
def _showPanel(position, image, title):
    """Draw *image* in gray scale at subplot *position* with *title*, axes hidden."""
    plt.subplot(position)
    plt.imshow(image, cmap='gray')
    plt.title(title)
    plt.axis('off')


def getPlot(apple, orange, result):
    """Show the two source images and one fused result side by side."""
    panels = (
        (131, apple, 'src1'),
        (132, orange, 'src2'),
        (133, result, 'result'),
    )
    for position, image, title in panels:
        _showPanel(position, image, title)
    plt.show()


# Four-panel figure, so both fusion results can be compared at once.
def cmpPlot(apple, orange, wave, pyr):
    """Show both sources plus the wavelet and Laplacian-pyramid results."""
    panels = (
        (221, apple, 'SRC1'),
        (222, orange, 'SRC2'),
        (223, wave, 'WAVELET'),
        (224, pyr, 'LAPLACE PYR'),
    )
    for position, image, title in panels:
        _showPanel(position, image, title)
    plt.show()


def runTest(src1=src1_path, src2=src2_path, isplot=True):
    """Run both fusion methods on two image files and report the run time.

    Args:
        src1: path of the first source image.
        src2: path of the second source image.
        isplot: when true, display the comparison figure before returning.

    Returns:
        Tuple ``(waveResult, pyrResult)`` with the two fused images.
    """
    apple, orange = imgOpen(src1, src2)

    started = datetime.datetime.now()
    print(started)
    waveResult = getWaveImg(apple, orange)
    pyrResult = getPyrFusion(apple, orange)
    finished = datetime.datetime.now()
    print(finished)
    elapsed = finished - started
    print('Runtime: ' + str(elapsed))

    if not isplot:
        return waveResult, pyrResult
    cmpPlot(apple, orange, waveResult, pyrResult)
    return waveResult, pyrResult


if __name__ == '__main__':
    runTest()
該寫的都寫在註釋裡了