1. 程式人生 > >007-尋找第k小元素-分治法-《演算法設計技巧與分析》M.H.A學習筆記

007-尋找第k小元素-分治法-《演算法設計技巧與分析》M.H.A學習筆記

n個元素的陣列中查詢第k小的元素。Θ(n

顯然先排序的話,複雜度為Onlogn)。

但我們還有一個很漂亮的Θ(n)的演算法。

先說一下分治法的閾值:我們有一種吊炸天的分治演算法,可以用很好的效率求解出某個問題,分治演算法當然在達到一個非常小的規模時,會能直接或用很簡單的方法得出結論,但是,其實,問題規模在達到某個閾值的時候,用直接樸素的方法解決這個規模的問題的效率,已經比繼續分治的演算法高了。這個時候,我們在這個閾值就開始選擇樸素的方法才是最明智的選擇。

在尋找第k小元素的分治演算法中,這個閾值是44。為什麼是44看下面分析。

基本思路:

(1) 當規模小於閾值時,直接用排序演算法返回結果。

(2) n大於閾值時,把n元素劃分為5個元素一組的n/5組,排除剩餘元素(不會有影響,這裡只是為了求中項mm,分別排序,然後挑出每一組元素的中間值,再在所有的中間值中,遞迴呼叫本演算法,挑出中間值mm

(3) 把元素劃分為A1A2A3三組,分別包含小於、等於、大於mm的元素。

(4)分三種情況:

若A1的元素數量大於等於K,即第K個元素在第一組內:在A1中遞迴查詢第k小元素。

若A1A2元素個數之和大於等於K,即中項mm為第K個元素:返回mm

否則,第K個元素在第三組:在A3中遞迴尋找第(k-|A1A2元素數量之和|)小元素。

虛擬碼:

  1. 輸入  n 個元素的陣列 A[1...n] 和整數 k,1 ≤ k ≤ n  
  2. 輸出  A 中的第 k 小元素  
  3. 演算法描述 select(A, low, high, k)  
  4. 1. n ← high - low + 1----(Θ(1))  
  5. 2. if  n < 44 then 將 A 排序 return (A[k])----(Θ(1))  
  6. 3. 令 q =  ⌊n/5⌋。將 A 分成 q 組,每組5個元素。如果5不整除 n ,則排除剩餘的元素。----(Θ(n))  
  7. 4. 將 q 組中的每一組單獨排序,找出中項。所有中項的集合為 M。----(Θ(n))  
  8. 5. mm ← select(M, 1, q,  ⌈q/2⌉)   { mm 為中項集合的中項 } ----T(n/5)  
  9. 6. 將 A[low...high] 分成三組----(Θ(n))  
  10.     A1 = { a | a < mm }  
  11.     A2 = { a | a = mm }  
  12.     A3 = { a | a > mm }  
  13. 7. case  
  14.     |A1| ≥ k : return select(A1, 1, |A1|, k)  
  15.     |A1| + |A2| ≥ k : return mm  
  16.     |A1| + |A2| < k : return select(A3, 1, |A3|, k - |A1| - |A2|)  
  17. 8. end case  


演算法分析:

1-6步的複雜度都很容易理解,我們著重討論第7步的演算法複雜度。

 

上圖是處理到第5步後的元素,從左到右按各組中項升序排列,每組5個元素從下到上按升序排列。

我們需要知道的是第7步時候問題的規模,即A1A3這兩個陣列的規模。

上圖中我們可以看到W區的元素都是小於或等於mm的,令A1’表示小於或等於mm的元素的集合,顯然W會是A1’的子集,即A1’的元素數量大於等於W的元素數量。

於是我們有下面這個式子:

 

A3的數量=n-A1’的數量,於是我們可以等到下面的式子:

 

由對稱性,可得:

 

至此,我們知道A1A3的上界是0.7n+1.2,步驟7耗費的時間是T0.7n+1.2)。

到這裡還沒說到44閾值的由來,好,要開始說了。

我們希望去掉1.2這個常數,於是引入底函式幫忙:

 

 

這條式子什麼時候成立呢?解不等式可得n>=44

閾值44誕生了!!!

現在我們還有了演算法執行時間的遞推式:

 

可以算出來T(n)=Θ(n)

對於求中項的題目也是同樣的解法,就是找第(n+1/2個元素(奇數)和第n/2n/2+1個元素(偶數)。

需要注意,這個演算法的常數倍數(比如c)都是很大的。

對於這個問題,還存在一個具有Θ(n)期望執行時間和較小常數倍數的隨機選擇演算法,請多關注這個專欄,有機會再介紹(挖坑)。

Java程式碼:

貼一個找到的Java程式碼,C++程式碼以後再寫一個補上(再挖坑):

  1. publicstaticint select(int[] A, int k){  
  2.         return selectDo(A, 0, A.length-1, k);  
  3.     }  
  4.     privatestaticint selectDo(int[] A, int low, int high, int k){  
  5.         //select k min number
  6.         int p = high - low + 1;  
  7.         if(p < 44){  
  8.             Arrays.sort(A, low, high+1);  
  9.             return A[low+k];  
  10.         }  
  11.         //A divided into q groups, each group 5 elements, and sort them
  12.         int q = p/5;  
  13.         int[] M = newint[q];  
  14.         for(int i = 0; i < q; i ++){  
  15.             Arrays.sort(A, low + 5*i, low + 5*i + 5);  
  16.             M[i] = A[low+5*i+2];  
  17.         }  
  18.         //select mid in M
  19.         int mid = selectDo(A, 0, q-1, (q-1)/2);  
  20.         //A divided into 3 groups
  21.         int[] A1 = newint[p];  
  22.         int[] A2 = newint[p];  
  23.         int[] A3 = newint[p];  
  24.         int count1, count2, count3;  
  25.         count1 = count2 = count3 = 0;  
  26.         for(int i = low; i <= high; i ++){  
  27.             if(A[i] < mid)  
  28.                 A1[count1++] = A[i];  
  29.             elseif(A[i] == mid)  
  30.                 A2[count2++] = A[i];  
  31.             else
  32.                 A3[count3++] = A[i];  
  33.         }  
  34.         if(count1 >= k)  
  35.             return selectDo(A1, 0, count1-1, k);  
  36.         if(count1 + count2 >= k)  
  37.             return mid;  
  38.         return selectDo(A3, 0, count3-1, k-count1-count2);  
  39.     }