1. 程式人生 > >java 併發之CountDownLatch

java 併發之CountDownLatch

概述

CountDownLatch是一個用來控制併發的很常見的工具,它允許一個或者多個執行緒等待其他的執行緒執行到某一操作,比如說需要去解析一個excel的資料,為了更快的解析則每個sheet都使用一個執行緒去進行解析,但是最後的彙總資料的工作則需要等待每個sheet的解析工作完成之後才能進行,這就可以使用CountDownLatch。

使用

例子:
這裡有三個執行緒(main,thread1,thread2),其中main執行緒將呼叫countDownLatch的await方法去等待另外兩個執行緒的某個操作的結束(呼叫countDownLatch的countDown方法)。

public class CountDownLatchDemo {

    public static void main(String[] args) throws InterruptedException{

        CountDownLatch countDownLatch = new CountDownLatch(2){
            @Override
            public void await() throws InterruptedException {
                super.await();
                System.out
.println(Thread.currentThread().getName() + " count down is ok"); } }; Thread thread1 = new Thread(new Runnable() { @Override public void run() { //do something try { Thread.sleep(1000); } catch
(InterruptedException e) { e.printStackTrace(); } System.out.println(Thread.currentThread().getName() + " is done"); countDownLatch.countDown(); } }, "thread1"); Thread thread2 = new Thread(new Runnable() { @Override public void run() { //do something try { Thread.sleep(2000); } catch (InterruptedException e) { e.printStackTrace(); } System.out.println(Thread.currentThread().getName() + " is done"); countDownLatch.countDown(); } }, "thread2"); thread1.start(); thread2.start(); countDownLatch.await(); } }

這裡的CountDownLatch的建構函式中使用的int型變數的意思是需要等待多少個操作 的完成。這裡是2所以需要等到呼叫了兩次countDown()方法之後主執行緒的await()方法才會返回。這意味著如果我們錯誤的估計了需要等待的操作的個數或者在某個應該呼叫countDown()方法的地方忘記了呼叫那麼將意味著await()方法將永遠的阻塞下去。

實現原理

CountDownLatch類實際上是使用計數器的方式去控制的,不難想象當我們初始化CountDownLatch的時候傳入了一個int變數這個時候在類的內部初始化一個int的變數,每當我們呼叫countDownt()方法的時候就使得這個變數的值減1,而對於await()方法則去判斷這個int的變數的值是否為0,是則表示所有的操作都已經完成,否則繼續等待。
實際上如果瞭解AQS的話應該很容易想到可以使用AQS的共享式獲取同步狀態的方式來完成這個功能。而CountDownLatch實際上也就是這麼做的。
這裡寫圖片描述
從結構上來看CountDownLatch的實現還是很簡單的,通過很常見的繼承AQS的方式來完成自己的同步器。

CountDownLatch的同步器實現:

private static final class Sync extends AbstractQueuedSynchronizer {
    private static final long serialVersionUID = 4982264981922014374L;
    //初始化state
    Sync(int count) {
        setState(count);
    }

    int getCount() {
        return getState();
    }
    //嘗試獲取同步狀態
    //只有當同步狀態為0的時候返回大於0的數1
    //同步狀態不為0則返回-1
    protected int tryAcquireShared(int acquires) {
        return (getState() == 0) ? 1 : -1;
    }
    //自旋+CAS的方式釋放同步狀態
    protected boolean tryReleaseShared(int releases) {
        // Decrement count; signal when transition to zero
        for (;;) {
            int c = getState();
            if (c == 0)
                return false;
            int nextc = c-1;
            if (compareAndSetState(c, nextc))
                return nextc == 0;
        }
    }
}

比較關鍵的地方是tryAquireShared()方法的實現,因為在父類的AQS中aquireShared()方法在呼叫tryAquireShared()方法的時候的判斷依據是返回值是否大於零。


public final void acquireShared(int arg) {
    if (tryAcquireShared(arg) < 0)
        //失敗則進入等待佇列
        doAcquireShared(arg);
}

同步器的實現相對都比較簡單,主要思路和上面基本一致。

CountDownLatch的主要方法(本身程式碼量就很少就直接貼了)


public class CountDownLatch {
    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        Sync(int count) {
            setState(count);
        }

        int getCount() {
            return getState();
        }

        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

    private final Sync sync;

    //初始化一個同步器
    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }
    //呼叫同步器的acquireSharedInterruptibly方法
    //並且是響應中斷的
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    //呼叫同步器的releaseShared方法去讓state減1
    public void countDown() {
        sync.releaseShared(1);
    }
    //獲取剩餘的count
    public long getCount() {
        return sync.getCount();
    }

    public String toString() {
        return super.toString() + "[Count = " + sync.getCount() + "]";
    }
}