1. 程式人生 > >如何用多執行緒實現歸併排序

如何用多執行緒實現歸併排序

等我有時間了,一定要把《演算法導論》啃完,這本書的印刷質量實在太好了,滑稽。

之前聽吳恩達老大說過Python裡面的Numpy包的矩陣運算就是多執行緒的,所以能做到的情況下儘量用矩陣運算代替迴圈,這樣能大大加快運算的速度。

為了提高速度,如果不涉及外部資源讀取的話,要提高執行速度就要做到平行計算,依賴於處理器的數量;如果需要等待耗時的外部資源讀取,就可以通過併發邊讀邊運算。

演算法導論有一章節提到了並行迴圈,多執行緒矩陣乘法和多執行緒歸併排序,方法都是講一個大的計算過程分成幾個獨立的小部分,各個部分讓單獨的執行緒去計算。

排序裡面講問題分解的典型的就有快排和歸併,接下來看一下怎麼寫多執行緒的。

多執行緒歸併排序

直接點的思考方式,歸併排序先要把一個數據分成兩個,然後這兩個分別歸併排序,拍完了把兩個歸併到一起,典型的遞迴。

那麼我們直接點,先把陣列分割好,然後開兩個執行緒,一個執行緒給一個,等著兩個執行緒都搞定了,在把兩個結果合併起來。或者你覺得兩個執行緒每個要處理的還是太長了,那就在這兩個執行緒裡面再把拿到的陣列分割了,各自再開兩個。嘗試一下

先看下單執行緒的版本,做下測試

import java.util.Random;

public class Main {
    public static void main(String[] args) {
        int length = 1000;
        int[] data = (new Data(length)).getData();
        printArr(data);
        System.out.println();
        mergeSort(data);
        printArr(data);
    }

    //遞迴
    private static void mergeSort(int[] nums,int[] tmp,int left,int right){
        if(left<right){
            int center = (left+right)/2;
            mergeSort(nums,tmp,left,center);
            mergeSort(nums,tmp,center+1,right);
            merge(nums,tmp,left,center+1,right);
        }
    }

    //合併
    private static void merge(int[] nums,int[] tmp,int leftPos, int rightPos, int rightEnd){
        int leftEnd = rightPos-1;
        int tmpPos = leftPos;
        int numElements = rightEnd - leftPos + 1;
    
        while(leftPos<=leftEnd&&rightPos<=rightEnd){
            if(nums[leftPos]<nums[rightPos])
                tmp[tmpPos++]=nums[leftPos++];
            else 
                tmp[tmpPos++]=nums[rightPos++];
        }
        while(leftPos<=leftEnd)
            tmp[tmpPos++]=nums[leftPos++];
        
        while(rightPos<=rightEnd)
            tmp[tmpPos++]=nums[rightPos++];
    
        for(int i = 0;i<numElements;i++,rightEnd--)
            nums[rightEnd]=tmp[rightEnd];
    }
    public static void mergeSort(int[] nums){
        int[] tmp = new int[nums.length];
        mergeSort(nums,tmp,0,nums.length-1);
    }
    
    //列印
    public static void printArr(int[] arr) {
        for(int i : arr){
            System.out.print(i+" ");
        }
    }


}


/**
 * 產生隨機資料
 */
class Data{
    int length;
    int[] data;

    public Data(int length){
        this.length = length;
        data = new int[length];
    }

    public int[] getData(){

        Random random = new Random(System.currentTimeMillis());
        for(int i=0;i<length;i++){
            data[i]=random.nextInt(2*length);
        }
        return data;
    }


}

可以看到演算法是能正常執行的

按上面思路的多執行緒版本呢?用 兩個執行緒試驗了下

只修改了main函式,加入了一個verify用作驗證排序是不是OK的,不能人眼看吧

import java.util.Random;
import java.util.concurrent.CountDownLatch;

public class Main {
    public static void main (String[] args) throws InterruptedException {
        int length = 1000;
        int[] data = (new Data(length)).getData();
        printArr(data);
        System.out.println();
        // mergeSort(data);
        //在這裡修改
        int center = data.length/2;

        int[] tmp = new int[data.length];
        CountDownLatch latch = new CountDownLatch(2);//CountDownLatch能夠使一個執行緒在等待另            
                                                    //外一些執行緒完成各自工作之後,再繼續執行
        new Thread(new Runnable(){
        
            @Override
            public void run() {
                mergeSort(data,tmp,0,center);
                latch.countDown();
            }
        }).start();

        new Thread(new Runnable(){
        
            @Override
            public void run() {
                mergeSort(data,tmp,center+1,data.length-1);
                latch.countDown();
            }
        }).start();

        latch.await();

        merge(data, tmp, 0, center+1, data.length-1);

        printArr(data);
        System.out.println();
        verify(data);
    }

    //遞迴
    private static void mergeSort(int[] nums,int[] tmp,int left,int right){
        if(left<right){
            int center = (left+right)/2;
            mergeSort(nums,tmp,left,center);
            mergeSort(nums,tmp,center+1,right);
            merge(nums,tmp,left,center+1,right);
        }
    }

    //合併
    private static void merge(int[] nums,int[] tmp,int leftPos, int rightPos, int rightEnd){
        int leftEnd = rightPos-1;
        int tmpPos = leftPos;
        int numElements = rightEnd - leftPos + 1;
    
        while(leftPos<=leftEnd&&rightPos<=rightEnd){
            if(nums[leftPos]<nums[rightPos])
                tmp[tmpPos++]=nums[leftPos++];
            else 
                tmp[tmpPos++]=nums[rightPos++];
        }
        while(leftPos<=leftEnd)
            tmp[tmpPos++]=nums[leftPos++];
        
        while(rightPos<=rightEnd)
            tmp[tmpPos++]=nums[rightPos++];
    
        for(int i = 0;i<numElements;i++,rightEnd--)
            nums[rightEnd]=tmp[rightEnd];
    }
    public static void mergeSort(int[] nums){
        int[] tmp = new int[nums.length];
        mergeSort(nums,tmp,0,nums.length-1);
    }
    
    //列印
    public static void printArr(int[] arr) {
        for(int i : arr){
            System.out.print(i+" ");
        }
    }
    
    public static void verify(int[] nums) {
        for(int i=0;i<nums.length-1;i++){
            if(nums[i]>nums[i+1]){
                System.out.println("排序失敗");
                return;
            } 

        }
        System.out.println("排序成功");
    }
    


}


/**
 * 產生隨機資料
 */
class Data{
    int length;
    int[] data;

    public Data(int length){
        this.length = length;
        data = new int[length];
    }

    public int[] getData(){

        Random random = new Random(System.currentTimeMillis());
        for(int i=0;i<length;i++){
            data[i]=random.nextInt(2*length);
        }
        return data;
    }

}

結果是OK的

上面是按自己的構思開啟的執行緒。

其實Java本身提供了更好的解決方案,就是Fork/Join框架, 貼一下這個框架的介紹

使用Fork/Join 我們需要知道兩個類:

  • ForkJoinTask:我們要使用ForkJoin框架,必須首先建立一個ForkJoin任務。它提供在任務中執行fork()和join()操作的機制,通常情況下我們不需要直接繼承ForkJoinTask類,而只需要繼承它的子類,Fork/Join框架提供了以下兩個子類:
    • RecursiveAction:用於沒有返回結果的任務。
    • RecursiveTask :用於有返回結果的任務。
  • ForkJoinPool :ForkJoinTask需要通過ForkJoinPool來執行,任務分割出的子任務會新增到當前工作執行緒所維護的雙端佇列中,進入佇列的頭部。當一個工作執行緒的佇列裡暫時沒有任務時,它會隨機從其他工作執行緒的佇列的尾部獲取一個任務。

下面看下如何用這個框架實現多執行緒歸併排序

import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.TimeUnit;

public class Main {
    public static void main (String[] args) throws InterruptedException {
        int length = 1000;
        int[] data = (new Data(length)).getData();
        printArr(data);
        System.out.println();
        // mergeSort(data);
        //在這裡修改
        // int center = data.length/2;

        int[] tmp = new int[data.length];
        // CountDownLatch latch = new CountDownLatch(2);//CountDownLatch能夠使一個執行緒在等待另外一些執行緒完成各自工作之後,再繼續執行
        // new Thread(new Runnable(){
        
        //     @Override
        //     public void run() {
        //         mergeSort(data,tmp,0,center);
        //         latch.countDown();
        //     }
        // }).start();

        // new Thread(new Runnable(){
        
        //     @Override
        //     public void run() {
        //         mergeSort(data,tmp,center+1,data.length-1);
        //         latch.countDown();
        //     }
        // }).start();

        // latch.await();

        // merge(data, tmp, 0, center+1, data.length-1);

        //Fork/Join 從這裡開始
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        Main.mergeTask task = new Main.mergeTask(data, tmp, 0, data.length-1);//建立任務
        forkJoinPool.execute(task);//執行任務
        forkJoinPool.awaitTermination(2, TimeUnit.SECONDS);//阻塞當前執行緒直到pool中的任務都完成了

        printArr(data);
        System.out.println();
        verify(data);

    }

    //遞迴
    private static void mergeSort(int[] nums,int[] tmp,int left,int right){
        if(left<right){
            int center = (left+right)/2;
            mergeSort(nums,tmp,left,center);
            mergeSort(nums,tmp,center+1,right);
            merge(nums,tmp,left,center+1,right);
        }
    }

    //合併
    private static void merge(int[] nums,int[] tmp,int leftPos, int rightPos, int rightEnd){
        int leftEnd = rightPos-1;
        int tmpPos = leftPos;
        int numElements = rightEnd - leftPos + 1;
    
        while(leftPos<=leftEnd&&rightPos<=rightEnd){
            if(nums[leftPos]<nums[rightPos])
                tmp[tmpPos++]=nums[leftPos++];
            else 
                tmp[tmpPos++]=nums[rightPos++];
        }
        while(leftPos<=leftEnd)
            tmp[tmpPos++]=nums[leftPos++];
        
        while(rightPos<=rightEnd)
            tmp[tmpPos++]=nums[rightPos++];
    
        for(int i = 0;i<numElements;i++,rightEnd--)
            nums[rightEnd]=tmp[rightEnd];
    }
    public static void mergeSort(int[] nums){
        int[] tmp = new int[nums.length];
        mergeSort(nums,tmp,0,nums.length-1);
    }
    
    //列印
    public static void printArr(int[] arr) {
        for(int i : arr){
            System.out.print(i+" ");
        }
    }
    
    public static void verify(int[] nums) {
        for(int i=0;i<nums.length-1;i++){
            if(nums[i]>nums[i+1]){
                System.out.println("排序失敗");
                return;
            } 

        }
        System.out.println("排序成功");
    }
    

    static class mergeTask extends RecursiveAction {
        private static final int THRESHOLD = 2;//設定任務大小閾值
        private int start;
        private int end;
        private int[] data;
        private int[] tmp;
    
        public mergeTask(int[] data, int[] tmp, int start, int end){
            this.data = data;
            this.tmp = tmp;
            this.start = start;
            this.end = end;
        }
    
        @Override
        protected void compute(){
            if((end - start)<=THRESHOLD){
                mergeSort(data,tmp,start,end);
            }else{
                int center = (start + end)/2;
                Main.mergeTask leftTask = new Main.mergeTask(data, tmp, start, center);
                Main.mergeTask rightTask = new Main.mergeTask(data, tmp, center+1, end);

                leftTask.fork();
                rightTask.fork();

                leftTask.join();
                rightTask.join();

                merge(data, tmp, start, center+1, end);

            }
        }
    }

}


/**
 * 產生隨機資料
 */
class Data{
    int length;
    int[] data;

    public Data(int length){
        this.length = length;
        data = new int[length];
    }

    public int[] getData(){

        Random random = new Random(System.currentTimeMillis());
        for(int i=0;i<length;i++){
            data[i]=random.nextInt(2*length);
        }
        return data;
    }

}

結果也是OK的

以上都沒有涉及到鎖,雖然操作的是共享的陣列,但是被讀寫的區域是被隔離開的。

也是在演算法導論上瞟到多執行緒演算法這麼一章,順藤摸瓜才知道有Fork/Join 這個東西,要學的東西真的多。

搞完這個我又聯想到之前看過的一道演算法題:

在大量的資料中,尋找最大的k個數,或者是出現次數最多的k個數據,比如說這個資料有10個G,放在一個大檔案中,電腦記憶體4G。

解題思路就是先把這個檔案分塊,為了確保相同的資料在一個塊中,通過計算Hash值來分塊,相同Hash 放到一個塊中。比如每分100個塊,這樣平均一個塊就在100M左右,對每個塊分別載入記憶體找最大的前K個數或者出現最多的前K個數據,最後比較這100*K個數據來得到結果。

怎麼用多執行緒求解?