1. 程式人生 > >找出無序陣列中最小的k個數(top k問題)

找出無序陣列中最小的k個數(top k問題)

給定一個無序的整型陣列arr,找到其中最小的k個數

該題是網際網路面試中十分高頻的一道題,如果用普通的排序演算法,排序之後自然可以得到最小的k個數,但時間複雜度高達O(NlogN),且普通的排序演算法均屬於內部排序,需要一次性將全部資料裝入記憶體,對於求解海量資料的top k問題是無能為力的。

針對海量資料的top k問題,這裡實現了一種時間複雜度為O(Nlogk)的有效演算法:初始時一次性從檔案中讀取k個數據,並建立一個有k個數的最大堆,代表目前選出的最小的k個數。然後從檔案中一個一個的讀取剩餘資料,如果讀取的資料比堆頂元素小,則把堆頂元素替換成當前的數,然後從堆頂向下重新進行堆調整;否則不進行任何操作,繼續讀取下一個資料。直到檔案中的所有資料讀取完畢,堆中的k個數就是海量資料中最小的k個數(如果是找最大的k個數,則使用最小堆)。具體過程請參看如下程式碼:

public class FindKMinNums {

    /**
     * 維護一個有k個數的最大堆,代表目前選出的最小的k個數
     *
     * @param read 實際場景中,read提供的資料需要從檔案中讀取,這裡為了方便用陣列表示
     * @param k
     * @return
     */
    public static int[] getKMinsByHeap(int[] read, int k) {
        if (k < 1 || k > read.length) {
            return read;
        }
        int[] kHeap = new int[k];
        for (int i = 0; i < k; i++) {   // 初始時一次性從檔案中讀取k個數據
            kHeap[i] = read[i];
        }
        buildHeap(kHeap, k);            // 建堆,時間複雜度O(k)
        for (int i = k; i < read.length; i++) { // 從檔案中一個一個的讀取剩餘資料
            if (read[i] < kHeap[0]) {
                kHeap[0] = read[i];
                heapify(kHeap, 0, k);   // 從堆頂開始向下進行調整,時間複雜度O(logk)
            }
        }
        return kHeap;
    }

    /**
     * 建堆函式
     *
     * @param arr
     * @param n
     */
    public static void buildHeap(int[] arr, int n) {
        for (int i = n / 2 - 1; i >= 0; i--) {
            heapify(arr, i, n);
        }
    }

    /**
     * 從arr[i]向下進行堆調整
     *
     * @param arr
     * @param i
     * @param heapSize
     */
    public static void heapify(int[] arr, int i, int heapSize) {
        int leftChild = 2 * i + 1;
        int rightChild = 2 * i + 2;
        int max = i;
        if (leftChild < heapSize && arr[leftChild] > arr[max]) {
            max = leftChild;
        }
        if (rightChild < heapSize && arr[rightChild] > arr[max]) {
            max = rightChild;
        }
        if (max != i) {
            swap(arr, i, max);
            heapify(arr, max, heapSize);  // 堆結構發生了變化,繼續向下進行堆調整
        }
    }

    public static void swap(int[] arr, int i, int j) {
        int tmp = arr[i];
        arr[i] = arr[j];
        arr[j] = tmp;
    }

    public static void printArray(int[] arr) {
        for (int i = 0; i <= arr.length; i++) {
            System.out.print(arr[i] + " ");
        }
        System.out.println();
    }

    public static void main(String[] args) {
        int[] arr = {6, 9, 1, 3, 1, 2, 2, 5, 6, 1, 3, 5, 9, 7, 2, 5, 6, 1, 9};
        // sorted : { 1, 1, 1, 1, 2, 2, 2, 3, 3, 5, 5, 5, 6, 6, 6, 7, 9, 9, 9 }
        printArray(getKMinsByHeap(arr, 10));
    }
}

對於從海量資料(N)中找出TOP K,這種演算法僅需一次性將k個數裝入記憶體,其餘資料從檔案一個一個讀即可,所以它是針對海量資料TOP K問題最為有效的演算法


對於非海量資料的情況,還有一種時間複雜度僅為O(N)的經典演算法 —— BFPRT演算法,該演算法於1973年由Blum、Floyd、Pratt、Rivest和Tarjan聯合發明,其中蘊含的深刻思想改變了世界。

BFPRT演算法解決了這樣一個問題:在時間複雜度O(N)內,從無序的陣列中找到第k小的數。顯而易見的是,如果我們找到了第k小的數,那麼想求arr中最小的k個數,只需再遍歷一遍陣列,把小於第k小的數都蒐集起來,再把不足部分用第k小的數補全即可。

BFPRT演算法是如何找到第k小的數?以下是BFPRT演算法的過程,假設BFPRT演算法的函式是int select(int[] arr, int k),該函式的功能為在arr中找到第k小的數,然後返回該數。select(arr, k)的過程如下:

  1. 將arr中的n個元素劃分成 n/5 組,每組5個元素,如果最後的組不夠5個元素,那麼最後剩下的元素為一組(n%5 個元素)。時間複雜度O(1)

  2. 對每個組進行排序,比如選擇簡單的插入排序,只針對每個組最多5個元素之間的組內排序,組與組之間不排序。時間複雜度 N/5O(1)

  3. 找到每個組的中位數,如果元素個數為偶數可以找下中位數,讓這些中位陣列成一個新的陣列,記為mArr。時間複雜度O(N/5)

  4. 遞迴呼叫select(mArr, mArr.length / 2),意義是找到mArr這個陣列的中位數x,即中位數的中位數。時間複雜度T(N/2)

  5. 根據上面得到的x劃分整個arr陣列(partition過程),劃分的過程為:在arr中,比x小的都在x左邊,比x大的都在x右邊,x在中間。時間複雜度O(N)

  6. 假設劃分完成後,x在arr中的位置記為i,關於i與k的相對大小,有如下三種情況:

    1. 如果 i = k,說明x為整個陣列中第k小的數,直接返回。時間複雜度O(1)
    2. 如果 i < k,說明x處在第k小的數左邊,應該在x的右邊繼續尋找,所以遞迴呼叫select函式,在右半區尋找第k-i小的數。時間複雜度不超過T(7/10N + 6)
    3. 如果 i > k,說明x處在第k小的數右邊,應該在x的左邊繼續尋找,所以遞迴呼叫select函式,在左半區尋找第k小的數。時間複雜度同樣不超過T(7/10N + 6)

上述過程的程式碼實現如下:

public class FindKMinNums {

    /**
     * 先用BFPRT演算法求出第k小的數,再遍歷一遍陣列才能求出最小的k個數,時間複雜度O(N)
     * 需要將所有資料一次性裝入記憶體,適用於非海量資料的情況
     *
     * @param arr
     * @param k
     * @return
     */
    public static int[] getKMins(int[] arr, int k) {
        if (k < 1 || k > arr.length) {
            return arr;
        }
        int kthMin = getKthMinByBFPRT(arr, k);  // 使用BFPRT演算法求得第k小的數,O(N)
        int[] kMins = new int[k];               // 下面遍歷一遍陣列,利用第k小的數找到最小的k個數,O(N)
        int index = 0;
        for (int i = 0; i < arr.length; i++) {
            if (arr[i] < kthMin) {              // 小於第k小的數,必然屬於最小的k個數
                kMins[index++] = arr[i];
            }
        }
        while (index < k) {
            kMins[index++] = kthMin;            // 不足部分用第k小的數補全
        }
        return kMins;
    }

    /**
     * 使用BFPRT演算法求第k小的數
     *
     * @param arr
     * @param k
     * @return
     */
    public static int getKthMinByBFPRT(int[] arr, int k) {
        int[] arrCopy = copyArray(arr); // 在得到第k小的數之後還要遍歷一遍原陣列,所以並不直接操作原陣列
        return select(arrCopy, 0, arrCopy.length - 1, k - 1);   // 第k小的數,即排好序後下標為k-1的數
    }

    /**
     * 拷貝陣列
     *
     * @param arr
     * @return
     */
    public static int[] copyArray(int[] arr) {
        int[] arrCopy = new int[arr.length];
        for (int i = 0; i < arrCopy.length; i++) {
            arrCopy[i] = arr[i];
        }
        return arrCopy;
    }

    /**
     * 在陣列arr的下標範圍[begin, end]內,找到排序後位於整個arr陣列下標為index的數
     *
     * @param arr
     * @param begin
     * @param end
     * @param index
     * @return
     */
    public static int select(int[] arr, int begin, int end, int index) {
        if (begin == end) {
            return arr[begin];
        }
        int pivot = medianOfMedians(arr, begin, end);   // 核心操作:中位數的中位數作為基準
        int[] pivotRange = partition(arr, begin, end, pivot);   // 拿到分割槽後中區的範圍
        if (index >= pivotRange[0] && index <= pivotRange[1]) { // 命中
            return arr[index];
        } else if (index < pivotRange[0]) {
            return select(arr, begin, pivotRange[0] - 1, index);
        } else {
            return select(arr, pivotRange[1] + 1, end, index);
        }
    }

    /**
     * 選基準
     *
     * @param arr
     * @param begin
     * @param end
     * @return
     */
    public static int medianOfMedians(int[] arr, int begin, int end) {
        int num = end - begin + 1;
        int offset = num % 5 == 0 ? 0 : 1;      // 5個成一組,不滿5個的自己成一組
        int[] mArr = new int[num / 5 + offset]; // 每組的中位數取出構成中位數陣列mArr
        for (int i = 0; i < mArr.length; i++) {
            int beginI = begin + i * 5;
            int endI = beginI + 4;
            mArr[i] = getMedian(arr, beginI, Math.min(endI, end));
        }
        // 求中位數陣列mArr的中位數,作為基準返回
        return select(mArr, 0, mArr.length - 1, mArr.length / 2);
    }

    /**
     * 在陣列arr的下標範圍[begin, end]內,找中位數,如果元素個數為偶數則找下中位數
     *
     * @param arr
     * @param begin
     * @param end
     * @return
     */
    public static int getMedian(int[] arr, int begin, int end) {
        insertionSort(arr, begin, end);
        int sum = begin + end;
        int mid = (sum / 2) + (sum % 2);
        return arr[mid];
    }

    /**
     * 這裡僅用於對一組5個數進行插入排序,時間複雜度O(1)
     *
     * @param arr
     * @param begin
     * @param end
     */
    public static void insertionSort(int[] arr, int begin, int end) {
        for (int i = begin + 1; i <= end; i++) {
            int get = arr[i];
            int j = i - 1;
            while (j >= begin && arr[j] > get) {
                arr[j + 1] = arr[j];
                j--;
            }
            arr[j + 1] = get;
        }
    }

    /**
     * 優化後的快排partition操作
     *
     * @param arr
     * @param begin
     * @param end
     * @param pivot
     * @return 返回劃分後等於基準的元素下標範圍
     */
    public static int[] partition(int[] arr, int begin, int end, int pivot) {
        int small = begin - 1;     // 小區最後一個元素下標
        int big = end + 1;         // 大區第一個元素下標
        int cur = begin;
        while (cur < big) {
            if (arr[cur] < pivot) {
                swap(arr, ++small, cur++);
            } else if (arr[cur] > pivot) {
                swap(arr, --big, cur);
            } else {
                cur++;
            }
        }
        int[] range = new int[2];
        range[0] = small + 1;      // 中區第一個元素下標
        range[1] = big - 1;        // 中區最後一個元素下標
        return range;
    }

    public static void swap(int[] arr, int i, int j) {
        int tmp = arr[i];
        arr[i] = arr[j];
        arr[j] = tmp;
    }

    public static void printArray(int[] arr) {
        for (int i = 0; i < arr.length; i++) {
            System.out.print(arr[i] + " ");
        }
        System.out.println();
    }

    public static void main(String[] args) {
        int[] arr = {6, 9, 1, 3, 1, 2, 2, 5, 6, 1, 3, 5, 9, 7, 2, 5, 6, 1, 9};
        // sorted : { 1, 1, 1, 1, 2, 2, 2, 3, 3, 5, 5, 5, 6, 6, 6, 7, 9, 9, 9 }
        printArray(getKMins(arr, 10));
    }
}

關於BFPRT演算法為什麼在時間複雜度上可以做到穩定的O(N),可以參考《程式設計師程式碼面試指南》P339或《演算法導論》9.3節內容,這裡