007-尋找第k小元素-分治法-《演算法設計技巧與分析》M.H.A學習筆記
在n個元素的陣列中查詢第k小的元素。Θ(n)
顯然先排序的話,複雜度為O(nlogn)。
但我們還有一個很漂亮的Θ(n)的演算法。
先說一下分治法的閾值:我們有一種吊炸天的分治演算法,可以用很好的效率求解出某個問題,分治演算法當然在達到一個非常小的規模時,會能直接或用很簡單的方法得出結論,但是,其實,問題規模在達到某個閾值的時候,用直接樸素的方法解決這個規模的問題的效率,已經比繼續分治的演算法高了。這個時候,我們在這個閾值就開始選擇樸素的方法才是最明智的選擇。
在尋找第k小元素的分治演算法中,這個閾值是44。為什麼是44看下面分析。
基本思路:
(1) 當規模小於閾值時,直接用排序演算法返回結果。
(2) 當n大於閾值時,把n個元素劃分為5個元素一組的n/5組,排除剩餘元素(不會有影響,這裡只是為了求中項mm),分別排序,然後挑出每一組元素的中間值,再在所有的中間值中,遞迴呼叫本演算法,挑出中間值mm。
(3) 把元素劃分為A1、A2、A3三組,分別包含小於、等於、大於mm的元素。
(4)分三種情況:
若A1的元素數量大於等於K,即第K個元素在第一組內:在A1中遞迴查詢第k小元素。
若A1、A2元素個數之和大於等於K,即中項mm為第K個元素:返回mm
否則,第K個元素在第三組:在A3中遞迴尋找第(k-|A1、A2元素數量之和|)小元素。
虛擬碼:
-
輸入 n 個元素的陣列 A[1...n] 和整數 k,1 ≤ k ≤ n
- 輸出 A 中的第 k 小元素
- 演算法描述 select(A, low, high, k)
- 1. n ← high - low + 1----(Θ(1))
- 2. if n < 44 then 將 A 排序 return (A[k])----(Θ(1))
- 3. 令 q = ⌊n/5⌋。將 A 分成 q 組,每組5個元素。如果5不整除 n ,則排除剩餘的元素。----(Θ(n))
- 4. 將 q 組中的每一組單獨排序,找出中項。所有中項的集合為 M。----(Θ(n))
-
5. mm ← select(M, 1, q, ⌈q/2⌉) { mm 為中項集合的中項 } ----T(n/5)
- 6. 將 A[low...high] 分成三組----(Θ(n))
- A1 = { a | a < mm }
- A2 = { a | a = mm }
- A3 = { a | a > mm }
- 7. case
- |A1| ≥ k : return select(A1, 1, |A1|, k)
- |A1| + |A2| ≥ k : return mm
- |A1| + |A2| < k : return select(A3, 1, |A3|, k - |A1| - |A2|)
- 8. end case
演算法分析:
第1-6步的複雜度都很容易理解,我們著重討論第7步的演算法複雜度。
上圖是處理到第5步後的元素,從左到右按各組中項升序排列,每組5個元素從下到上按升序排列。
我們需要知道的是第7步時候問題的規模,即A1、A3這兩個陣列的規模。
上圖中我們可以看到W區的元素都是小於或等於mm的,令A1’表示小於或等於mm的元素的集合,顯然W會是A1’的子集,即A1’的元素數量大於等於W的元素數量。
於是我們有下面這個式子:
A3的數量=n-A1’的數量,於是我們可以等到下面的式子:
由對稱性,可得:
至此,我們知道A1、A3的上界是0.7n+1.2,步驟7耗費的時間是T(0.7n+1.2)。
到這裡還沒說到44閾值的由來,好,要開始說了。
我們希望去掉1.2這個常數,於是引入底函式幫忙:
即
這條式子什麼時候成立呢?解不等式可得n>=44。
閾值44誕生了!!!
現在我們還有了演算法執行時間的遞推式:
可以算出來T(n)=Θ(n)。
對於求中項的題目也是同樣的解法,就是找第(n+1)/2個元素(奇數)和第n/2、n/2+1個元素(偶數)。
需要注意,這個演算法的常數倍數(比如c)都是很大的。
對於這個問題,還存在一個具有Θ(n)期望執行時間和較小常數倍數的隨機選擇演算法,請多關注這個專欄,有機會再介紹(挖坑)。
Java程式碼:
貼一個找到的Java程式碼,C++程式碼以後再寫一個補上(再挖坑):
- publicstaticint select(int[] A, int k){
- return selectDo(A, 0, A.length-1, k);
- }
- privatestaticint selectDo(int[] A, int low, int high, int k){
- //select k min number
- int p = high - low + 1;
- if(p < 44){
- Arrays.sort(A, low, high+1);
- return A[low+k];
- }
- //A divided into q groups, each group 5 elements, and sort them
- int q = p/5;
- int[] M = newint[q];
- for(int i = 0; i < q; i ++){
- Arrays.sort(A, low + 5*i, low + 5*i + 5);
- M[i] = A[low+5*i+2];
- }
- //select mid in M
- int mid = selectDo(A, 0, q-1, (q-1)/2);
- //A divided into 3 groups
- int[] A1 = newint[p];
- int[] A2 = newint[p];
- int[] A3 = newint[p];
- int count1, count2, count3;
- count1 = count2 = count3 = 0;
- for(int i = low; i <= high; i ++){
- if(A[i] < mid)
- A1[count1++] = A[i];
- elseif(A[i] == mid)
- A2[count2++] = A[i];
- else
- A3[count3++] = A[i];
- }
- if(count1 >= k)
- return selectDo(A1, 0, count1-1, k);
- if(count1 + count2 >= k)
- return mid;
- return selectDo(A3, 0, count3-1, k-count1-count2);
- }