1. 程式人生 > >BFPRT 演算法 (TOP-K 問題)——本質就是在利用分組中位數的中位數來找到較快排更合適的pivot元素

BFPRT 演算法 (TOP-K 問題)——本質就是在利用分組中位數的中位數來找到較快排更合適的pivot元素

先說快排最壞情況下的時間複雜度為n^2。

正常情況:

 

 

最壞的情況下,待排序的記錄序列正序或逆序,每次劃分只能得到一個比上一次劃分少一個記錄的子序列,(另一個子序列為空)。此時,必須經過n-1次遞迴呼叫才能把所有記錄定位,而且第i趟劃分需要經過n-i次比較才能找個才能找到第i個記錄的位置,因此時間複雜度為

  所以BFPRT本質上是在尋找正確的pivot元素!!!避免這種最壞情況出現。    

在BFPTR演算法中,僅僅是改變了快速排序Partion中的pivot值的選取,在快速排序中,我們始終選擇第一個元素或者最後一個元素作為pivot,而在BFPTR演算法中,每次選擇五分中位數的中位數作為pivot,這樣做的目的就是使得劃分比較合理,從而避免了最壞情況的發生。演算法步驟如下:

1. 將 n 個元素劃為 \lfloor n/5\rfloor 組,每組5個,至多隻有一組由 n\bmod5 個元素組成。 
2. 尋找這 \lceil n/5\rceil 個組中每一個組的中位數,這個過程可以用插入排序。 
3. 對步驟2中的 \lceil n/5\rceil 箇中位數,重複步驟1和步驟2,遞迴下去,直到剩下一個數字。
4. 最終剩下的數字即為pivot,把大於它的數全放左邊,小於等於它的數全放右邊。 
5. 判斷pivot的位置與k的大小,有選擇的對左邊或右邊遞迴。

 

下面為程式碼實現,其所求為前 k 小的數

#include <iostream>
#include <algorithm>

using namespace std;

int InsertSort(int array[], int left, int right);
int GetPivotIndex(int array[], int left, int right);
int Partition(int array[], int left, int right, int pivot_index);
int BFPRT(int array[], int left, int right, int k);

int main()
{
    int k = 8; // 1 <= k <= array.size
    int array[20] = { 11,9,10,1,13,8,15,0,16,2,17,5,14,3,6,18,12,7,19,4 };

    cout << "原陣列:";
    for (int i = 0; i < 20; i++)
        cout << array[i] << " ";
    cout << endl;

    // 因為是以 k 為劃分,所以還可以求出第 k 小值
    cout << "第 " << k << " 小值為:" << array[BFPRT(array, 0, 19, k)] << endl;

    cout << "變換後的陣列:";
    for (int i = 0; i < 20; i++)
        cout << array[i] << " ";
    cout << endl;

    return 0;
}

/**
 * 對陣列 array[left, right] 進行插入排序,並返回 [left, right]
 * 的中位數。
 */
int InsertSort(int array[], int left, int right)
{
    int temp;
    int j;

    for (int i = left + 1; i <= right; i++)
    {
        temp = array[i];
        j = i - 1;
        while (j >= left && array[j] > temp)
            array[j + 1] = array[j--];
        array[j + 1] = temp;
    }

    return ((right - left) >> 1) + left;
}

/**
 * 陣列 array[left, right] 每五個元素作為一組,並計算每組的中位數,
 * 最後返回這些中位數的中位數下標(即主元下標)。
 *
 * @attention 末尾返回語句最後一個引數多加一個 1 的作用其實就是向上取整的意思,
 * 這樣可以始終保持 k 大於 0。
 */
int GetPivotIndex(int array[], int left, int right)
{
    if (right - left < 5)
        return InsertSort(array, left, right);

    int sub_right = left - 1;

    // 每五個作為一組,求出中位數,並把這些中位數全部依次移動到陣列左邊
    for (int i = left; i + 4 <= right; i += 5)
    {
        int index = InsertSort(array, i, i + 4);
        swap(array[++sub_right], array[index]);
    }

    // 利用 BFPRT 得到這些中位數的中位數下標(即主元下標)
    return BFPRT(array, left, sub_right, ((sub_right - left + 1) >> 1) + 1);
}

/**
 * 利用主元下標 pivot_index 進行對陣列 array[left, right] 劃分,並返回
 * 劃分後的分界線下標。
 */
int Partition(int array[], int left, int right, int pivot_index)
{
    swap(array[pivot_index], array[right]); // 把主元放置於末尾

    int partition_index = left; // 跟蹤劃分的分界線
    for (int i = left; i < right; i++)
    {
        if (array[i] < array[right])
        {
            swap(array[partition_index++], array[i]); // 比主元小的都放在左側
        }
    }

    swap(array[partition_index], array[right]); // 最後把主元換回來

    return partition_index;
}

/**
 * 返回陣列 array[left, right] 的第 k 小數的下標
 */
int BFPRT(int array[], int left, int right, int k)
{
    int pivot_index = GetPivotIndex(array, left, right); // 得到中位數的中位數下標(即主元下標)
    int partition_index = Partition(array, left, right, pivot_index); // 進行劃分,返回劃分邊界
    int num = partition_index - left + 1;

    if (num == k)
        return partition_index;
    else if (num > k)
        return BFPRT(array, left, partition_index - 1, k);
    else
        return BFPRT(array, partition_index + 1, right, k - num);
}

  

 

執行如下:

原陣列:11 9 10 1 13 8 15 0 16 2 17 5 14 3 6 18 12 7 19 4
第 8 小值為:7
變換後的陣列:4 0 1 3 2 5 6 7 8 9 10 12 13 14 17 15 16 11 18 19

 

 

效能分析:

劃分時以5個元素為一組求取中位數,共得到n/5箇中位數,再遞迴求取中位數,複雜度為T(n/5)。

得到的中位數x作為主元進行劃分,在n/5箇中位數中,主元x大於其中1/2*n/5=n/10的中位數,而每個中位數在其本來的5個數的小組中又大於或等於其中的3個數,所以主元x至少大於所有數中的n/10*3=3/10*n個。同理,主元x至少小於所有數中的3/10*n個

。即劃分之後,任意一邊的長度至少為3/10,在最壞情況下,每次選擇都選到了7/10的那一部分,則遞迴的複雜度為T(7/10*n)。

在每5個數求中位數和劃分的函式中,進行若干個次線性的掃描,其時間複雜度為c*n,其中c為常數。其總的時間複雜度滿足 T(n) <= T(n/5) + T(7/10*n) + c * n。

我們假設T(n)=x*n,其中x不一定是常數(比如x可以為n的倍數,則對應的T(n)=O(n^2))。則有 x*n <= x*n/5 + x*7/10*n + c*n,得到 x<=10*c。於是可以知道x與n無關,T(n)<=10*c*n,為線性時間複雜度演算法。而這又是最壞情況下的分析,故BFPRT可以在最壞情況下以線性時間求得n個數中的第k個數。