如何用Python實現神奇切圖演算法seam carving?

我們把時鐘撥到 11 年前,2007 年,在第 34 屆 SIGGRAPH 2007 數字圖形學年會上,以色列的兩位教授 Shai Avidan 和 Ariel Shamir 展示了一種新的縮放裁剪影象方法,他們稱之為 Seam Carving for Content-Aware Image Resizing,也就是我們後來所說的“接縫剪裁”(Seam Carving)演算法。
這個演算法能實現什麼效果呢?
這項技術能計算出影象上的“關鍵部分”和“不重要區域”,從而使得隨意改變一個影象的高寬比但“不會讓影象內容變得扭曲變形”。

簡單的說,利用這個技術我們可以在縮放時固定圖片中特定區域的大小,或者可以在縮小時讓特定的區塊被周圍影象縫合消除,並且因為“seam carving”的縫補演算法,你可以 讓圖片縮放後仍然維持整體的完整性 。
舉實際應用的例子來說,利用 Seam Carving 演算法我們可以將原本窄鏡頭的夕陽照片,修改成廣角鏡頭的夕陽照片, 且照片中心的太陽不會因為圖片拉寬而變形 ;或者我們可以將原本中間隔著距離的兩人合照, 修改成靠在一起的合照,且圖片也不會因為修改變形 。這是一個很有趣也讓人覺得很厲害的技術,是你從沒有玩過的船新版本切圖工具。
接縫剪裁演算法這種很新穎的技術,能讓我們在沒有損失影象中重要內容的情況下裁切影象。因此它又常被稱為“內容感知”裁剪或影象重定向。
到底這種演算法有多奇妙?我們看下面這個圖:

使用接縫剪裁演算法,我們可以把它變成這樣:

可以看到,圖片中的大部分重要內容比如小船都完整的儲存了下來。演算法移除了一些岩石以及湖水(所以我們看到圖中的小船離得更近了)。這就是接縫剪裁演算法的神奇之處, 它能在調整影象大小本身的同時,也能保留影象中最重要最突出的內容 。如果我們在切圖時,既想獲得合適的影象大小,也想保留影象的完整內容,使用傳統的切圖方法幾乎無法做到。而使用接縫剪裁演算法就能實現二者兼得。
關於演算法的核心原理,在原論文中解釋的非常清楚了,網上也有很多解析文章,這裡不再贅述。在本文我(作者Karthik Karanth——譯者注)就以上面所舉的例子為素材,重點講講 如何用Python基本實現接縫剪裁演算法 。
演算法論文地址:
ofollow,noindex">http://graphics.cs.cmu.edu/courses/15-463/2007_fall/hw/proj2/imret.pdf
工作過程概覽
在接縫裁剪(seam carving)演算法中,縫隙(seam)就是指從左到右或從上到下的連續畫素,它們橫向或縱向穿過整個影象。
因此,為了執行縫隙拼接,我們需要兩個重要的輸入:
- 1.原始圖片:我們想要調整大小的圖片。
- 2.能量圖(energy map): 我們從原始影象匯出的能量圖。
能量圖應該代表影象的最顯著的區域。通常,我們使用梯度幅度,熵圖或顯著圖表示。
演算法工作過程如下所示:
- 為每個畫素分配一個能量值
- 找到能量值最小的畫素的八連通路徑
- 刪除路徑中的所有畫素
- 重複前面1-3步,直到刪除的行/列數量達到理想狀態
在本文,我們會假設只想裁切影象的寬度,也就是隻刪除列。但是同樣的方法也能用於刪除行。
下面是我們需要匯入的環境依賴:
import sys import numpy as np from imageio import imread, imwrite from scipy.ndimage.filters import convolve # tqdm並非必需,但能為提供很美觀的進度條,方便我們檢視進度 from tqdm import trange
能量圖
第一步是為每個畫素計算出一個能量值。原作者在論文中定義了很多不同的能量函式供我們使用,我們使用最基本的那個:

那麼這到底是啥意思呢?I 指影象,那麼這個方程告訴我們的是,對於影象中的每個畫素,每個通道,我們執行如下操作:
- 找到X軸中的偏導數
- 找到Y軸中的偏導數
- 將它們的絕對值相加。
這就會成為該畫素的能量值。那麼問題來了“怎麼計算影象中的導數?”計算影象的導數有很多種方法,我們這裡使用 sobel 濾波器。這是一種卷積核心,可以在影象的每個通道上執行。這裡是影象兩個不同方向上的濾波器:

我們從直覺上可以認為,第一個濾波器會用其頂部值的差將每個畫素替換為其在底部的值。第二個濾波器會用其左邊值和右邊值的差替換每個畫素。這樣就能捕捉 3X3 區域畫素的整體趨勢。實際上,這種方法和邊緣檢測演算法高度相關。
計算能量圖就比較簡單了:
def calc_energy(img): filter_du = np.array([ [1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0], ]) # 這會將它從2D濾波轉換為3D濾波器 # 為每個通道:R,G,B複製相同的濾波器 filter_du = np.stack([filter_du] * 3, axis=2) filter_dv = np.array([ [1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0], ]) # 這會將它從2D濾波轉換為3D濾波器 # 為每個通道:R,G,B複製相同的濾波器 filter_dv = np.stack([filter_dv] * 3, axis=2) img = img.astype('float32') convolved = np.absolute(convolve(img, filter_du)) + np.absolute(convolve(img, filter_dv)) # 我們將紅,綠,藍通道中的能量相加 energy_map = convolved.sum(axis=2) return energy_map
我們將能量圖視覺化:

很明顯,具有最小變分的區域,比如天空、靜止的水域,都有非常低的能量(較暗的區域)。在我們執行接縫裁剪演算法時,被移除的線條會在緊密關聯影象中這些區域的同時,試圖儲存影象中具有高能量的部分(較亮的區域)。
找到能量值最小的縫隙
我們的下一個目標是找到從影象頂部到底部之間具有最小能量值的路徑。這條線必須是八連通的線:意味著線條上的每個畫素必須在邊緣或拐角處彼此相連。比如,下圖的紅線就是我們要找的縫隙:

那麼我們是怎麼發現這條線的?很明顯(明顯??),這個問題可以很好的轉化為動態規劃概念!

我們建立一個稱為 M 的 2D 陣列,儲存該畫素上可見的最小能量值。如果你不熟悉動態規劃,這裡大概就是說 M[i,j] 會在影象中這個點包含最小能量,同時考慮影象頂部到底部之間所有可能經過這個點的縫隙。所以,需要從影象頂部遍歷至影象底部的最小能量值會出現在 M 的最後一行。我們需要從這裡回溯,找到在該縫隙中出現的畫素列,因此我們會使用這些值和 2D 陣列,呼叫 backtrack。
def minimum_seam(img): r, c, _ = img.shape energy_map = calc_energy(img) M = energy_map.copy() backtrack = np.zeros_like(M, dtype=np.int) for i in range(1, r): for j in range(0, c): # 處理影象的左側邊緣,確保我們不會索引-1 if j == 0: idx = np.argmin(M[i - 1, j:j + 2]) backtrack[i, j] = idx + j min_energy = M[i - 1, idx + j] else: idx = np.argmin(M[i - 1, j - 1:j + 2]) backtrack[i, j] = idx + j - 1 min_energy = M[i - 1, idx + j - 1] M[i, j] += min_energy return M, backtrack
從具有最小能量值的縫隙中刪除畫素
然後我們移除具有最小能量值的縫隙,返回一個新影象:
def carve_column(img): r, c, _ = img.shape M, backtrack = minimum_seam(img) # 建立一個(r,c)矩陣,填充值為True # 後面會從值為False的影象中移除所有畫素 mask = np.ones((r, c), dtype=np.bool) # 找到M的最後一行中的最小元素的位置 j = np.argmin(M[-1]) for i in reversed(range(r)): # 標記出需要刪除的畫素 mask[i, j] = False j = backtrack[i, j] # 因為影象有3個通道,我們將蒙版轉換為3D mask = np.stack([mask] * 3, axis=2) # 刪除蒙版中所有標記為False的畫素, # 將其大小重新調整為新影象的維度 img = img[mask].reshape((r, c - 1, 3)) return img
在每一列重複此項操作
到了這裡我們已經打好了所有的地基!現在,我們反覆執行 carve_column 函式,直到刪除了理想數量的列。我們建立一個 crop_c 函式,它會將影象和一個比例因子作為輸入。如果影象維度為(300,600),我們想把它縮減為(150,600),我們需要輸入 0.5 作為引數 scale_c 的值。
def crop_c(img, scale_c): r, c, _ = img.shape new_c = int(scale_c * c) for i in trange(c - new_c): # use range if you don't want to use tqdm img = carve_column(img) return img
彙總資訊
我們可以新增一個主函式,從如下命令列呼叫該函式:
def main(): scale = float(sys.argv[1]) in_filename = sys.argv[2] out_filename = sys.argv[3] img = imread(in_filename) out = crop_c(img, scale) imwrite(out_filename, out) if __name__ == '__main__': main()
然後用如下程式碼執行:
python carver.py 0.5 image.jpg cropped.jpg
現在,cropped.jpg 應該包含如下一張圖:
<figure style="margin: 1em 0px; color: rgb(26, 26, 26); font-family: -apple-system, BlinkMacSystemFont, "Helvetica Neue", "PingFang SC", "Microsoft YaHei", "Source Han Sans SC", "Noto Sans CJK SC", "WenQuanYi Micro Hei", sans-serif; font-size: medium; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; font-weight: 400; letter-spacing: normal; orphans: 2; text-align: start; text-indent: 0px; text-transform: none; white-space: normal; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; background-color: rgb(255, 255, 255); text-decoration-style: initial; text-decoration-color: initial;">

image
</figure>
這樣我們就用 Python 實現了接縫剪裁演算法!
那麼行呢?
很簡單,只需旋轉一下影象,執行 crop_c 就 ok 了!
def crop_r(img, scale_r): img = np.rot90(img, 1, (0, 1)) img = crop_c(img, scale_r) img = np.rot90(img, 3, (0, 1)) return img
將如下內容新增至主函式,現在我們也能剪裁行了!
def main(): if len(sys.argv) != 5: print('usage: carver.py <r/c> <scale> <image_in> <image_out>', file=sys.stderr) sys.exit(1) which_axis = sys.argv[1] scale = float(sys.argv[2]) in_filename = sys.argv[3] out_filename = sys.argv[4] img = imread(in_filename) if which_axis == 'r': out = crop_r(img, scale) elif which_axis == 'c': out = crop_c(img, scale) else: print('usage: carver.py <r/c> <scale> <image_in> <image_out>', file=sys.stderr) sys.exit(1) imwrite(out_filename, out)
以如下程式碼執行:
python carver.py r 0.5 image2.jpg cropped.jpg
這時我們就能將下面這張正方形照片:

轉換為廣角鏡頭的矩形照片,而且完整的保留了原圖的重要內容:

結語
希望本文能幫助你更好的理解接縫裁剪演算法,以及用 Python 實現它。我現在正在研究怎麼改進這種演算法,讓它執行的更快一些。一個比較簡單的改動會是利用計算出的影象中的同一縫隙,去除影象的多個縫隙。我自己試驗了幾次,發現這樣能使演算法執行的更快,每次迭代時去除的縫隙數量越多,演算法就越快,不過影象質量會有明顯的損失。另一個優化方式是在 GPU 上計算能量圖。
小編整理了一些有深度的Python教程和參考資料,加入Python學習交流群【 784758214 】群內有安裝包和學習視訊資料,零基礎,進階,實戰免費的線上直播免費課程,希望可以幫助你快速瞭解Python。“程式設計是門手藝活”。什麼意思?得練啊。
點選: 加入
以下是完整的程式:
#!/usr/bin/env python """ Usage: python carver.py <r/c> <scale> <image_in> <image_out> Copyright 2018 Karthik Karanth, MIT License """ import sys from tqdm import trange import numpy as np from imageio import imread, imwrite from scipy.ndimage.filters import convolve def calc_energy(img): filter_du = np.array([ [1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0], ]) # 這會將它從2D濾波轉換為3D濾波器 # 為每個通道:R,G,B複製相同的濾波器 filter_du = np.stack([filter_du] * 3, axis=2) filter_dv = np.array([ [1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0], ]) # 這會將它從2D濾波轉換為3D濾波器 # 為每個通道:R,G,B複製相同的濾波器 filter_dv = np.stack([filter_dv] * 3, axis=2) img = img.astype('float32') convolved = np.absolute(convolve(img, filter_du)) + np.absolute(convolve(img, filter_dv)) # 我們計算紅,綠,藍通道中的能量值之和 energy_map = convolved.sum(axis=2) return energy_map def crop_c(img, scale_c): r, c, _ = img.shape new_c = int(scale_c * c) for i in trange(c - new_c): img = carve_column(img) return img def crop_r(img, scale_r): img = np.rot90(img, 1, (0, 1)) img = crop_c(img, scale_r) img = np.rot90(img, 3, (0, 1)) return img def carve_column(img): r, c, _ = img.shape M, backtrack = minimum_seam(img) mask = np.ones((r, c), dtype=np.bool) j = np.argmin(M[-1]) for i in reversed(range(r)): mask[i, j] = False j = backtrack[i, j] mask = np.stack([mask] * 3, axis=2) img = img[mask].reshape((r, c - 1, 3)) return img def minimum_seam(img): r, c, _ = img.shape energy_map = calc_energy(img) M = energy_map.copy() backtrack = np.zeros_like(M, dtype=np.int) for i in range(1, r): for j in range(0, c): # 處理影象的左側邊緣,確保我們不會索引-1 if j == 0: idx = np.argmin(M[i-1, j:j + 2]) backtrack[i, j] = idx + j min_energy = M[i-1, idx + j] else: idx = np.argmin(M[i - 1, j - 1:j + 2]) backtrack[i, j] = idx + j - 1 min_energy = M[i - 1, idx + j - 1] M[i, j] += min_energy return M, backtrack def main(): if len(sys.argv) != 5: print('usage: carver.py <r/c> <scale> <image_in> <image_out>', file=sys.stderr) sys.exit(1) which_axis = sys.argv[1] scale = float(sys.argv[2]) in_filename = sys.argv[3] out_filename = sys.argv[4] img = imread(in_filename) if which_axis == 'r': out = crop_r(img, scale) elif which_axis == 'c': out = crop_c(img, scale) else: print('usage: carver.py <r/c> <scale> <image_in> <image_out>', file=sys.stderr) sys.exit(1) imwrite(out_filename, out) if __name__ == '__main__': main()
小編整理了一些有深度的Python教程和參考資料,加入Python學習交流群【 784758214 】群內有安裝包和學習視訊資料,零基礎,進階,實戰免費的線上直播免費課程,希望可以幫助你快速瞭解Python。“程式設計是門手藝活”。什麼意思?得練啊。
點選: 加入