1. 程式人生 > >最小堆解決 Top K 問題

最小堆解決 Top K 問題

Top K 問題指從一組資料中選出最大的K個數。常見的例子有:熱門搜尋前10,最常聽的20首歌等。

對於這類問題,可能我們會首先想到先對這組資料進行排序,再選取前K個數。雖然這能解決問題,但效率不高,因為我們只需要部分有序,它卻對整體進行了排序。最小堆是解決Top K 問題的一個好的方法(如果我們需要選出K個最小的數,用的是最大堆)。

Top K 實現步驟

最小堆也叫小根堆,實際上是一個完全二叉樹,它的子結點的值總是大於等於它的父節點。關於最小堆的構造與調整可以參考這篇文章:Java優先順序佇列

對於 Top K 問題,我們只需要維持一個大小為K的最小堆。

比如我們現在要選取陣列A中最大的10個數,過程如下:

  • 用A[0]-A[9]建立一個最小堆
  • 對於 A[i](i > 9),如果 A[i] 大於最小堆的堆頂,將堆頂替換為 A[i],替換後調整堆使得符合最小堆的特徵
  • 迴圈進行第二個步驟直至遍歷完陣列

Java 示例

最小堆的構建與調整

public class MinHeap<T> {
    private Object[] queue;
    private int size;

    public MinHeap() {
        queue = new Object[11];
    }

    public MinHeap(int capacity)
{ queue = new Object[capacity]; } public boolean offer(T t) { int k = size; if (size == 0) queue[0] = t; size++; moveUp(k, t); return true; } public void moveUp(int k, T t) { Comparable<? super T> key = (Comparable<
? super T>) t; while (k > 0) { int parent = (k - 1) >>> 1; Object e = queue[parent]; if (key.compareTo((T) e) > 0) break; queue[k] = e; k = parent; } queue[k] = key; } public T poll() { if (size == 0) return null; int s = --size; T result = (T) queue[0]; T end = (T) queue[s]; queue[s] = null; if (s != 0) moveDown(0, end); return result; } public void moveDown(int k, T end) { Comparable<? super T> key = (Comparable<? super T>) end; int half = size >>> 1; while (k < half) { int left = (k << 1) + 1; int right = left + 1; Object c = queue[left]; if (right < size && ((Comparable<? super T>) c).compareTo((T) queue[right]) > 0) { c = queue[left = right]; } if (key.compareTo((T) c) <= 0) break; queue[k] = c; k = left; } queue[k] = key; } boolean setHead(T t) { queue[0] = t; return true; } public T peek() { return size == 0 ? null : (T) queue[0]; } }

從陣列中選取前K個最大的數

public class TopK {
    private static Random random = new Random();

    public static int[] factory(int n) {
        int[] data = new int[n];
        for (int i = 0; i < n; i++)
            data[i] = random.nextInt(100);
        return data;
    }

    public void find(int[] array, int n) {
        MinHeap minHeap = new MinHeap(n);
        for (int i = 0; i < n; i++) {
            minHeap.offer(array[i]);
        }
        for (int j = n; j < array.length; j++) {
            if (array[j] > (int) minHeap.peek()) {
                minHeap.setHead(array[j]);
                minHeap.moveDown(0, array[j]);
            }
        }
        System.out.print("[");
        for (int t = 0; t < n - 1; t++)
            System.out.print(minHeap.poll() + ", ");
        System.out.println(minHeap.poll() + "]");
    }

    public static void main(String[] args) {
        int[] data = factory(11);
        System.out.println(Arrays.toString(data));
        TopK topK = new TopK();
        topK.find(data,10);
    }
}

輸出結果:

41, 34, 39, 58, 37, 9, 70, 18, 97, 75, 92]
[18, 34, 37, 39, 41, 58, 70, 75, 92, 97]

堆排序

堆排序藉助了最小堆或最大堆的特性,它的時間複雜度為 O(nlogn),空間複雜度為 O(1)。堆排序是一種原地排序,一般比快速排序慢,但它佔用的空間少,因此在對佔用空間有要求或求解類似 Top K 問題時,可以考慮採用。

注意,堆排序與快速排序都是不穩定的演算法。

Java 堆排序示例:

public class HeapSort {

    /**
     *從下往上建堆
     * @param array
     */
    public static void buildHeap(int[] array) {
        for (int t = array.length / 2; t >= 0; t--) {
            heapify(array, array.length, t);
        }
    }

    /**
     * @param array
     * @param size  待排序陣列長度
     * @param t     從t位置向下堆化
     */
    public static void heapify(int[] array, int size, int t) {
        int half = size >>> 1;
        int temp = array[t];
        while (t < half) {
            int left = (t << 1) + 1;
            int right = left + 1;
            int min = array[left];
            if (right < size && min > array[right])
                min = array[left = right];
            if (temp < min)
                break;
            array[t] = min;
            t = left;
        }
        array[t] = temp;
    }

    public static void sort(int[] array) {
        for (int i = array.length - 1; i > 0; i--) {
            int temp = array[0];
            array[0] = array[i];
            array[i] = temp;
            heapify(array, i, 0);
        }
    }

    public static Random random = new Random();

    public static int[] factory(int i) {
        int[] array = new int[i];
        for (int t = 0; t < i; t++) {
            array[t] = random.nextInt(100);
        }
        return array;
    }

    public static void main(String[] args) {
        int[] array = factory(20);
        System.out.println("初始陣列:" + Arrays.toString(array));
        buildHeap(array);
        System.out.println("堆化陣列:" + Arrays.toString(array));
        sort(array);
        System.out.println("堆排序後陣列:" + Arrays.toString(array));
    }
}

輸出結果:

初始陣列:[96, 77, 19, 14, 12, 91, 43, 36, 56, 21, 91, 37, 21, 48, 16, 14, 4, 37, 83, 39]
堆化陣列:[4, 12, 16, 14, 19, 37, 21, 14, 37, 21, 91, 91, 43, 77, 48, 96, 36, 56, 83, 39]
堆排序後陣列:[96, 91, 91, 83, 77, 56, 48, 43, 39, 37, 37, 36, 21, 21, 19, 16, 14, 14, 12, 4]

參考連結: