
Using CountDownLatch and CyclicBarrier, and How They Are Implemented

I. Using CountDownLatch and CyclicBarrier

CountDownLatch:

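Definition: CountDownLatch is a thread-synchronization aid that lets one thread block until other threads have finished the tasks they are working on. A notable property is that it does not require the threads calling countDown to wait for the count to reach 0 before they proceed; it only prevents threads from proceeding past an await until all of them are able to pass.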

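Usage: Initialize a CountDownLatch with a given count. Each call to countDown() decrements the count by 1, and calls to await() block until the count reaches 0. Once it reaches 0, all waiting threads are released at once, and any later call to await() returns immediately.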
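A minimal sketch of these rules using only the public CountDownLatch API (LatchBasicsDemo is a hypothetical name): countDown() returns immediately even while the count is still positive, a timed await gives up while the count is above 0, and once the count reaches 0 every await returns at once.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

// Hypothetical demo class, not part of the article's code.
class LatchBasicsDemo {
    static String run() {
        CountDownLatch latch = new CountDownLatch(2);
        latch.countDown();                 // returns at once; the caller never waits in countDown()
        long remaining = latch.getCount(); // 1
        try {
            // Count is still 1 and nobody else counts down: wait 50 ms, then give up (false).
            boolean beforeZero = latch.await(50, TimeUnit.MILLISECONDS);
            latch.countDown();             // 1 -> 0: every waiting thread is released
            // Count is 0: await succeeds immediately, now and on every later call.
            boolean afterZero = latch.await(50, TimeUnit.MILLISECONDS);
            latch.await();                 // the untimed form returns at once too
            return "remaining=" + remaining + " beforeZero=" + beforeZero + " afterZero=" + afterZero;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(run()); // prints "remaining=1 beforeZero=false afterZero=true"
    }
}
```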

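Scenarios: (1) A CountDownLatch initialized with a count of 1 serves as a simple on/off latch, or gate: every thread that calls await waits at the gate until some thread opens it by calling countDown(). (2) A CountDownLatch initialized with N (N >= 1) can make one thread wait until N threads have completed some action, or until some action has been completed N times.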
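Both scenarios can be sketched together in a few lines. This is an illustrative sketch, not the article's code: GateAndDoneLatch and runWorkers are hypothetical names, and incrementing a counter stands in for real work.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical demo class illustrating scenarios (1) and (2) together.
class GateAndDoneLatch {
    /** Starts n workers behind a start gate; returns how many finished once all n are done. */
    static int runWorkers(int n) {
        CountDownLatch startGate = new CountDownLatch(1); // scenario (1): a simple on/off gate
        CountDownLatch done = new CountDownLatch(n);      // scenario (2): wait for n completions
        AtomicInteger finished = new AtomicInteger();
        for (int i = 0; i < n; i++) {
            new Thread(() -> {
                try {
                    startGate.await();          // every worker waits at the gate
                    finished.incrementAndGet(); // stand-in for real work
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    done.countDown();           // signal completion even if the work failed
                }
            }).start();
        }
        startGate.countDown(); // open the gate: all waiting workers proceed
        try {
            done.await();      // block until all n workers have counted down
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return finished.get();
    }

    public static void main(String[] args) {
        System.out.println(runWorkers(5) + " workers finished");
    }
}
```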

P.S. The count of a CountDownLatch cannot be reset. If you need the count to reset, consider using a CyclicBarrier.
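A short sketch of the one-shot behavior (LatchIsOneShot is a hypothetical name): once the count reaches 0, extra countDown() calls are silently ignored, and CountDownLatch offers no method to restore the original count.

```java
import java.util.concurrent.CountDownLatch;

// Hypothetical demo class: a CountDownLatch is single-use.
class LatchIsOneShot {
    static long countAfterExtraCountDowns() {
        CountDownLatch latch = new CountDownLatch(2);
        latch.countDown();
        latch.countDown();       // count is now 0: the latch is open for good
        latch.countDown();       // no-op: the count never drops below 0
        return latch.getCount(); // 0; there is no method to put the count back to 2
    }

    public static void main(String[] args) {
        System.out.println(countAfterExtraCountDowns()); // prints 0
        // The only way to "reset" is to create a new latch, e.g. new CountDownLatch(2).
    }
}
```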

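Practice: The code below uses 10 threads, each computing the sum of one group of numbers. The 10 threads are required to start computing at logically the same moment (they cannot truly start simultaneously; without enough CPU cores the computation is not actually parallel), and none of them may return until every one of the 10 threads has finished its computation.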

MyCalculator.java:

package simple.demo;

import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;

/**
 * @author jianying.wcj
 * @date 2013-8-2
 */
public class MyCalculator implements Callable<Integer> {
    /**
     * Start switch
     */
    private final CountDownLatch startSwitch;
    /**
     * Stop switch
     */
    private final CountDownLatch stopSwitch;
    /**
     * Index of the group to sum (group k covers (k - 1) * 10 + 1 .. k * 10)
     */
    private final int groupNum;

    /**
     * Constructor
     */
    public MyCalculator(CountDownLatch startSwitch, CountDownLatch stopSwitch, Integer groupNum) {
        this.startSwitch = startSwitch;
        this.stopSwitch = stopSwitch;
        this.groupNum = groupNum;
    }

    @Override
    public Integer call() throws Exception {
        // Wait until the main thread opens the start switch
        startSwitch.await();
        int res = compute();
        System.out.println(Thread.currentThread().getName() + " is ok wait other thread...");
        // Report that this group is done, then wait until every group is done
        stopSwitch.countDown();
        stopSwitch.await();
        System.out.println(Thread.currentThread().getName() + " is stop! the group" + groupNum + " temp result is sum=" + res);
        return res;
    }

    /**
     * Sums the numbers of this group
     * @return the sum of (groupNum - 1) * 10 + 1 .. groupNum * 10
     */
    public int compute() {
        int sum = 0;
        for (int i = (groupNum - 1) * 10 + 1; i <= groupNum * 10; i++) {
            sum += i;
        }
        return sum;
    }
}

MyTest.java:

package simple.demo;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class MyTest {

    private final int groupNum = 10;
    /**
     * Start and stop switches
     */
    private final CountDownLatch startSwitch = new CountDownLatch(1);

    private final CountDownLatch stopSwitch = new CountDownLatch(groupNum);
    /**
     * Thread pool
     */
    private final ExecutorService service = Executors.newFixedThreadPool(groupNum);
    /**
     * Holds the computation results
     */
    private final List<Future<Integer>> result = new ArrayList<Future<Integer>>();

    /**
     * Starts groupNum threads to compute the sums
     */
    public void init() {
        for (int i = 1; i <= groupNum; i++) {
            result.add(service.submit(new MyCalculator(startSwitch, stopSwitch, i)));
        }
        System.out.println("init is ok!");
    }

    public void printRes() throws InterruptedException, ExecutionException {
        int sum = 0;
        for (Future<Integer> f : result) {
            sum += f.get();
        }
        System.out.println("the result is " + sum);
    }

    public void start() {
        // Open the start switch: every MyCalculator blocked in startSwitch.await() proceeds
        this.startSwitch.countDown();
    }

    public void stop() throws InterruptedException {
        // Wait until all groups have counted down, then shut the pool down
        this.stopSwitch.await();
        this.service.shutdown();
    }

    public static void main(String[] args) throws IOException, InterruptedException, ExecutionException {
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));

        MyTest myTest = new MyTest();
        myTest.init();
        System.out.println("please enter start command....");

        reader.readLine();
        myTest.start();
        myTest.stop();

        myTest.printRes();
    }
}

Output:

init is ok!
please enter start command....

pool-1-thread-1 is ok wait other thread...
pool-1-thread-2 is ok wait other thread...
pool-1-thread-3 is ok wait other thread...
pool-1-thread-4 is ok wait other thread...
pool-1-thread-6 is ok wait other thread...
pool-1-thread-5 is ok wait other thread...
pool-1-thread-8 is ok wait other thread...
pool-1-thread-7 is ok wait other thread...
pool-1-thread-9 is ok wait other thread...
pool-1-thread-10 is ok wait other thread...
pool-1-thread-10 is stop! the group10 temp result is sum=955
pool-1-thread-1 is stop! the group1 temp result is sum=55
pool-1-thread-2 is stop! the group2 temp result is sum=155
pool-1-thread-3 is stop! the group3 temp result is sum=255
pool-1-thread-4 is stop! the group4 temp result is sum=355
pool-1-thread-6 is stop! the group6 temp result is sum=555
pool-1-thread-5 is stop! the group5 temp result is sum=455
pool-1-thread-8 is stop! the group8 temp result is sum=755
pool-1-thread-7 is stop! the group7 temp result is sum=655
pool-1-thread-9 is stop! the group9 temp result is sum=855
the result is 5050
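One caveat with MyCalculator.call(): if compute() threw, that thread would never call stopSwitch.countDown(), so the other nine threads (and the main thread in stop()) would wait forever. A hedged sketch of a more defensive variant moves countDown() into a finally block; SafeCountDown, allReleased and the failAt parameter are hypothetical names, and groupNum * 10 only stands in for the real compute().

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

// Hypothetical variant of the MyCalculator pattern; groupNum * 10 stands in for compute().
class SafeCountDown {
    /**
     * Runs n tasks that meet at a shared stop latch; the task for group failAt throws.
     * Returns true if the latch still reached 0, i.e. nobody was left waiting.
     */
    static boolean allReleased(int n, int failAt) {
        CountDownLatch stopSwitch = new CountDownLatch(n);
        ExecutorService service = Executors.newFixedThreadPool(n);
        try {
            for (int i = 1; i <= n; i++) {
                final int groupNum = i;
                service.submit(() -> {
                    int res;
                    try {
                        if (groupNum == failAt) {
                            throw new IllegalStateException("group " + groupNum + " failed");
                        }
                        res = groupNum * 10;    // placeholder for the real computation
                    } finally {
                        stopSwitch.countDown(); // runs on success and on failure
                    }
                    stopSwitch.await();         // wait for the other groups
                    return res;
                });
            }
            // Bounded wait so a bug here cannot hang the caller forever.
            return stopSwitch.await(5, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        } finally {
            service.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println("all released: " + allReleased(10, 3));
    }
}
```

The failing task's exception is not lost: it is captured in that task's Future and would surface from Future.get() as an ExecutionException.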

CyclicBarrier:

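Definition: CyclicBarrier is a synchronization aid that allows a set of threads to all wait for each other to reach a common barrier point, after which they all continue together. A notable feature is that CyclicBarrier supports an optional Runnable command that runs once per barrier point, after the last thread in the group arrives but before any thread is released. This barrier action is useful for updating shared state before any of the participating threads continue.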
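The barrier action can be sketched as follows, assuming the same grouping as the practice example below (group k sums (k - 1) * 10 + 1 through k * 10). BarrierActionDemo is a hypothetical name. Each thread writes only its own slot of partial, and the action, run once by the last thread to arrive, merges them; CyclicBarrier's documented memory-consistency guarantee (actions before await happen-before the barrier action) is what makes reading partial there safe.

```java
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical demo: n threads each sum one group of 10 numbers; the barrier action merges them.
class BarrierActionDemo {
    static String run(int n) {
        int[] partial = new int[n];  // slot i is written only by thread i
        int[] total = new int[1];
        AtomicInteger actionRuns = new AtomicInteger();
        CyclicBarrier barrier = new CyclicBarrier(n, () -> {
            // Runs once per trip, in the last thread to arrive, before any party is released.
            actionRuns.incrementAndGet();
            int sum = 0;
            for (int p : partial) sum += p;
            total[0] = sum;
        });
        Thread[] threads = new Thread[n];
        for (int i = 0; i < n; i++) {
            final int groupNum = i + 1;
            threads[i] = new Thread(() -> {
                int sum = 0;
                for (int k = (groupNum - 1) * 10 + 1; k <= groupNum * 10; k++) sum += k;
                partial[groupNum - 1] = sum;
                try {
                    barrier.await(); // writes above happen-before the barrier action
                } catch (InterruptedException | BrokenBarrierException e) {
                    throw new IllegalStateException(e);
                }
            });
            threads[i].start();
        }
        try {
            for (Thread t : threads) t.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return "total=" + total[0] + " actionRuns=" + actionRuns.get();
    }

    public static void main(String[] args) {
        System.out.println(run(10)); // prints "total=5050 actionRuns=1"
    }
}
```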

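Usage: Initialize a CyclicBarrier with a count N. Each call to await blocks the calling thread and records one more arrival; when the N-th thread arrives, all blocked threads are woken. The barrier then resets itself (this is what "cyclic" refers to), so later calls to await block again until another N threads have arrived.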
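A sketch of that reuse (BarrierReuseDemo and tripsAfter are hypothetical names): two parties meet at the same barrier several times, and a barrier action that counts trips shows that the barrier resets after each round instead of staying open.

```java
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical demo: the same CyclicBarrier is tripped once per round.
class BarrierReuseDemo {
    /** Two parties meet at one barrier `rounds` times; returns how many times it tripped. */
    static int tripsAfter(int rounds) {
        AtomicInteger trips = new AtomicInteger();
        CyclicBarrier barrier = new CyclicBarrier(2, trips::incrementAndGet);
        Runnable party = () -> {
            try {
                for (int r = 0; r < rounds; r++) {
                    barrier.await(); // blocks again every round until the other party arrives
                }
            } catch (InterruptedException | BrokenBarrierException e) {
                throw new IllegalStateException(e);
            }
        };
        Thread other = new Thread(party);
        other.start();
        party.run(); // the calling thread acts as the second party
        try {
            other.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return trips.get();
    }

    public static void main(String[] args) {
        System.out.println("trips = " + tripsAfter(3)); // prints "trips = 3"
    }
}
```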

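Scenario: With a CyclicBarrier initialized to N, each of N threads calls await at some point in its work. This guarantees that none of the N threads proceeds past await until all N have reached it, after which they all continue together.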

Practice: The same task as in the CountDownLatch example above:

MyCalculator.java:

package simple.demo;

import java.util.concurrent.Callable;
import java.util.concurrent.CyclicBarrier;

public class MyCalculator implements Callable<Integer> {
    /**
     * Start switch
     */
    private final CyclicBarrier startSwitch;
    /**
     * Stop switch
     */
    private final CyclicBarrier stopSwitch;
    /**
     * Index of the group to sum (group k covers (k - 1) * 10 + 1 .. k * 10)
     */
    private final int groupNum;

    /**
     * Constructor
     */
    public MyCalculator(CyclicBarrier startSwitch, CyclicBarrier stopSwitch, Integer groupNum) {
        this.startSwitch = startSwitch;
        this.stopSwitch = stopSwitch;
        this.groupNum = groupNum;
    }

    @Override
    public Integer call() throws Exception {
        // Wait until the other workers and the main thread have all reached the start barrier
        startSwitch.await();
        int res = compute();
        System.out.println(Thread.currentThread().getName() + " is ok wait other thread...");
        // Wait until every worker has finished computing
        stopSwitch.await();
        System.out.println(Thread.currentThread().getName() + " is stop! the group" + groupNum + " temp result is sum=" + res);
        return res;
    }

    /**
     * Sums the numbers of this group
     * @return the sum of (groupNum - 1) * 10 + 1 .. groupNum * 10
     */
    public int compute() {
        int sum = 0;
        for (int i = (groupNum - 1) * 10 + 1; i <= groupNum * 10; i++) {
            sum += i;
        }
        return sum;
    }
}

MyTest.java:

package simple.demo;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class MyTest {

    private final int groupNum = 10;
    /**
     * Start and stop switches.
     * The start barrier has groupNum + 1 parties: the groupNum workers plus the main thread,
     * whose arrival in start() releases everyone. Only the workers meet at the stop barrier.
     */
    private final CyclicBarrier startSwitch = new CyclicBarrier(groupNum + 1);

    private final CyclicBarrier stopSwitch = new CyclicBarrier(groupNum);
    /**
     * Thread pool
     */
    private final ExecutorService service = Executors.newFixedThreadPool(groupNum);
    /**
     * Holds the computation results
     */
    private final List<Future<Integer>> result = new ArrayList<Future<Integer>>();

    /**
     * Starts groupNum threads to compute the sums
     */
    public void init() {
        for (int i = 1; i <= groupNum; i++) {
            result.add(service.submit(new MyCalculator(startSwitch, stopSwitch, i)));
        }
        System.out.println("init is ok!");
    }

    public void printRes() throws InterruptedException, ExecutionException {
        int sum = 0;
        for (Future<Integer> f : result) {
            sum += f.get();
        }
        System.out.println("the result is " + sum);
    }

    public void start() throws InterruptedException, BrokenBarrierException {
        // The main thread is the last of the groupNum + 1 parties to arrive
        this.startSwitch.await();
    }

    public void stop() throws InterruptedException {
        // No waiting needed here: printRes() blocks on each Future until its task completes
        this.service.shutdown();
    }

    public static void main(String[] args) throws IOException, InterruptedException, ExecutionException, BrokenBarrierException {
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));

        MyTest myTest = new MyTest();
        myTest.init();
        System.out.println("please enter start command....");

        reader.readLine();

        myTest.start();
        myTest.stop();

        myTest.printRes();
    }
}

Output:

init is ok!
please enter start command....

pool-1-thread-1 is ok wait other thread...
pool-1-thread-2 is ok wait other thread...
pool-1-thread-3 is ok wait other thread...
pool-1-thread-4 is ok wait other thread...
pool-1-thread-5 is ok wait other thread...
pool-1-thread-6 is ok wait other thread...
pool-1-thread-7 is ok wait other thread...
pool-1-thread-8 is ok wait other thread...
pool-1-thread-9 is ok wait other thread...
pool-1-thread-10 is ok wait other thread...
pool-1-thread-10 is stop! the group10 temp result is sum=955
pool-1-thread-1 is stop! the group1 temp result is sum=55
pool-1-thread-2 is stop! the group2 temp result is sum=155
pool-1-thread-3 is stop! the group3 temp result is sum=255
pool-1-thread-5 is stop! the group5 temp result is sum=455
pool-1-thread-6 is stop! the group6 temp result is sum=555
pool-1-thread-4 is stop! the group4 temp result is sum=355
pool-1-thread-8 is stop! the group8 temp result is sum=755
pool-1-thread-7 is stop! the group7 temp result is sum=655
pool-1-thread-9 is stop! the group9 temp result is sum=855
the result is 5050

II. How CountDownLatch and CyclicBarrier Are Implemented

The class diagram of CountDownLatch is shown below:



CountDownLatch is built on AQS (AbstractQueuedSynchronizer): it defines an inner class, Sync, which extends AQS. The key source code is as follows.
The await method:

 /**
 * Causes the current thread to wait until the latch has counted down to
 * zero, unless the thread is {@linkplain Thread#interrupt interrupted}.
 *
 * <p>If the current count is zero then this method returns immediately.
 *
 * <p>If the current count is greater than zero then the current
 * thread becomes disabled for thread scheduling purposes and lies
 * dormant until one of two things happen:
 * <ul>
 * <li>The count reaches zero due to invocations of the
 * {@link #countDown} method; or
 * <li>Some other thread {@linkplain Thread#interrupt interrupts}
 * the current thread.
 * </ul>
 *
 * <p>If the current thread:
 * <ul>
 * <li>has its interrupted status set on entry to this method; or
 * <li>is {@linkplain Thread#interrupt interrupted} while waiting,
 * </ul>
 * then {@link InterruptedException} is thrown and the current thread's
 * interrupted status is cleared.
 *
 * @throws InterruptedException if the current thread is interrupted
 *         while waiting
 */
public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}
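The interrupt rule in the Javadoc above can be checked directly. In this sketch (LatchInterruptDemo is a hypothetical name), the interrupted status is set before calling await(); because acquireSharedInterruptibly checks the flag first, await() throws at once and the flag ends up cleared.

```java
import java.util.concurrent.CountDownLatch;

// Hypothetical demo of the "interrupted status set on entry" case.
class LatchInterruptDemo {
    static String awaitWhileInterrupted() {
        CountDownLatch latch = new CountDownLatch(1);
        Thread.currentThread().interrupt(); // set the interrupt flag before calling await
        boolean threw = false;
        try {
            latch.await();                  // sees the flag and throws instead of blocking
        } catch (InterruptedException e) {
            threw = true;
        }
        // The flag has been cleared by the time InterruptedException is thrown.
        return "threw=" + threw + " stillInterrupted=" + Thread.currentThread().isInterrupted();
    }

    public static void main(String[] args) {
        System.out.println(awaitWhileInterrupted()); // prints "threw=true stillInterrupted=false"
    }
}
```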

The countDown method:



/**
 * Decrements the count of the latch, releasing all waiting threads if
 * the count reaches zero.
 *
 * <p>If the current count is greater than zero then it is decremented.
 * If the new count is zero then all waiting threads are re-enabled for
 * thread scheduling purposes.
 *
 * <p>If the current count equals zero then nothing happens.
 */
public void countDown() {
     sync.releaseShared(1);
}

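Above are the definitions of CountDownLatch's two key methods, await and countDown; the Javadoc comments explain what each does. Internally, CountDownLatch simply reuses the state field of AQS. What state means is up to each synchronizer: ReentrantLock uses it as the reentrant hold count, while CountDownLatch uses it as the remaining count. The inner class Sync overrides AQS's tryAcquireShared, which CountDownLatch defines as: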

public int tryAcquireShared(int acquires) {
    return getState() == 0? 1 : -1;
}

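The initial value of state is the count passed to the CountDownLatch constructor. When Sync calls AQS's acquireSharedInterruptibly, AQS checks the result of tryAcquireShared(int acquires): a negative result (the count has not yet reached 0) means the acquire failed, and the thread is queued and parked. The AQS method that queues and parks the thread is: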

/**
 * Acquires in shared interruptible mode.
 * @param arg the acquire argument
 */
private void doAcquireSharedInterruptibly(int arg)
    throws InterruptedException {
    final Node node = addWaiter(Node.SHARED);
    try {
        for (;;) {
            final Node p = node.predecessor();
            if (p == head) {
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    return;
                }
            }
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt())
                break;
        }
    } catch (RuntimeException ex) {
        cancelAcquire(node);
        throw ex;
    }
    // Arrive here only if interrupted
    cancelAcquire(node);
    throw new InterruptedException();
}

When countDown is called, it calls AQS's releaseShared, which in turn calls tryReleaseShared, overridden by CountDownLatch's inner class Sync. Its definition is:

public boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    for (;;) {
        int c = getState();
        if (c == 0)
            return false;
        int nextc = c-1;
        if (compareAndSetState(c, nextc))
            return nextc == 0;
    }
}

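So each call decrements state by 1 (using a CAS retry loop) until it reaches 0. Only the call that moves state from 1 to 0 returns true, which tells AQS to wake the queued threads; once state is 0, further calls return false and change nothing. That is all for CountDownLatch.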
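To make the division of labor concrete, here is a hedged sketch of a tiny latch built the same way: a private Sync that extends AbstractQueuedSynchronizer and overrides the two methods shown above, while AQS does all the queuing and parking. MiniLatch and demo() are hypothetical names; this is a re-creation for illustration, not the JDK class.

```java
import java.util.concurrent.locks.AbstractQueuedSynchronizer;

// Hypothetical re-creation of CountDownLatch's Sync idea; not the JDK class itself.
class MiniLatch {
    private static final class Sync extends AbstractQueuedSynchronizer {
        Sync(int count) { setState(count); } // state = remaining count

        int getCount() { return getState(); }

        @Override
        protected int tryAcquireShared(int acquires) {
            // A non-negative result lets the caller through; -1 makes AQS queue and park it.
            return getState() == 0 ? 1 : -1;
        }

        @Override
        protected boolean tryReleaseShared(int releases) {
            for (;;) {                                // CAS retry loop
                int c = getState();
                if (c == 0) return false;             // already open: nothing to release
                int next = c - 1;
                if (compareAndSetState(c, next)) {
                    return next == 0;                 // true => AQS unparks queued threads
                }
            }
        }
    }

    private final Sync sync;

    MiniLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        sync = new Sync(count);
    }

    void await() throws InterruptedException { sync.acquireSharedInterruptibly(1); }

    void countDown() { sync.releaseShared(1); }

    long getCount() { return sync.getCount(); }

    /** One waiter, two countDowns: true if it was held at 1 and released at 0. */
    static boolean demo() {
        MiniLatch latch = new MiniLatch(2);
        Thread waiter = new Thread(() -> {
            try {
                latch.await();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });
        waiter.start();
        while (!latch.sync.hasQueuedThreads()) {
            Thread.yield();                           // wait until the waiter is in the AQS queue
        }
        latch.countDown();                            // 2 -> 1: the waiter stays parked
        boolean heldAtOne = waiter.isAlive();
        latch.countDown();                            // 1 -> 0: the waiter is released
        try {
            waiter.join(5_000);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
        return heldAtOne && !waiter.isAlive() && latch.getCount() == 0;
    }

    public static void main(String[] args) {
        System.out.println("works as a latch: " + demo());
    }
}
```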

The class diagram of CyclicBarrier is shown below:



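CyclicBarrier is implemented with ReentrantLock, and ReentrantLock is built on AQS, so in the end CyclicBarrier also rests on AQS. Internally, CyclicBarrier uses a Condition obtained from its ReentrantLock to park the threads waiting at the barrier and to wake them. The key source code is as follows.
The await method: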

/**
 * Waits until all {@linkplain #getParties parties} have invoked
 * <tt>await</tt> on this barrier.
 *
 * <p>If the current thread is not the last to arrive then it is
 * disabled for thread scheduling purposes and lies dormant until
 * one of the following things happens:
 * <ul>
 * <li>The last thread arrives; or
 * <li>Some other thread {@linkplain Thread#interrupt interrupts}
 * the current thread; or
 * <li>Some other thread {@linkplain Thread#interrupt interrupts}
 * one of the other waiting threads; or
 * <li>Some other thread times out while waiting for barrier; or
 * <li>Some other thread invokes {@link #reset} on this barrier.
 * </ul>
 *
 * <p>If the current thread:
 * <ul>
 * <li>has its interrupted status set on entry to this method; or
 * <li>is {@linkplain Thread#interrupt interrupted} while waiting
 * </ul>
 * then {@link InterruptedException} is thrown and the current thread's
 * interrupted status is cleared.
 *
 * <p>If the barrier is {@link #reset} while any thread is waiting,
 * or if the barrier {@linkplain #isBroken is broken} when
 * <tt>await</tt> is invoked, or while any thread is waiting, then
 * {@link BrokenBarrierException} is thrown.
 *
 * <p>If any thread is {@linkplain Thread#interrupt interrupted} while waiting,
 * then all other waiting threads will throw
 * {@link BrokenBarrierException} and the barrier is placed in the broken
 * state.
 *
 * <p>If the current thread is the last thread to arrive, and a
 * non-null barrier action was supplied in the constructor, then the
 * current thread runs the action before allowing the other threads to
 * continue.
 * If an exception occurs during the barrier action then that exception
 * will be propagated in the current thread and the barrier is placed in
 * the broken state.
 *
 * @return the arrival index of the current thread, where index
 * <tt>{@link #getParties()} - 1</tt> indicates the first
 * to arrive and zero indicates the last to arrive
 * @throws InterruptedException if the current thread was interrupted
 * while waiting
 * @throws BrokenBarrierException if <em>another</em> thread was
 * interrupted or timed out while the current thread was
 * waiting, or the barrier was reset, or the barrier was
 * broken when {@code await} was called, or the barrier
 * action (if present) failed due an exception.
 */
public int await() throws InterruptedException, BrokenBarrierException {
    try {
      return dowait(false, 0L);
    } catch (TimeoutException toe) {
      throw new Error(toe); // cannot happen;
    }
}
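The broken-barrier rules in that Javadoc can be exercised from a single thread. In this sketch (BrokenBarrierDemo is a hypothetical name), a timed await on a two-party barrier times out, which breaks the barrier; the next await then fails immediately with BrokenBarrierException until reset() installs a fresh generation.

```java
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

// Hypothetical demo of the broken-barrier rules described in the Javadoc.
class BrokenBarrierDemo {
    static String run() {
        CyclicBarrier barrier = new CyclicBarrier(2);
        boolean timedOut = false;
        try {
            barrier.await(50, TimeUnit.MILLISECONDS); // the second party never arrives
        } catch (TimeoutException e) {
            timedOut = true;                          // the barrier is broken before this is thrown
        } catch (InterruptedException | BrokenBarrierException e) {
            throw new IllegalStateException(e);
        }
        boolean brokenAfterTimeout = barrier.isBroken();
        boolean failedFast = false;
        try {
            barrier.await(50, TimeUnit.MILLISECONDS); // generation is broken: fails at once
        } catch (BrokenBarrierException e) {
            failedFast = true;
        } catch (InterruptedException | TimeoutException e) {
            throw new IllegalStateException(e);
        }
        barrier.reset();                              // start a fresh, unbroken generation
        return "timedOut=" + timedOut + " broken=" + brokenAfterTimeout
                + " failedFast=" + failedFast + " brokenAfterReset=" + barrier.isBroken();
    }

    public static void main(String[] args) {
        // prints "timedOut=true broken=true failedFast=true brokenAfterReset=false"
        System.out.println(run());
    }
}
```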

The private dowait method:

/**
 * Main barrier code, covering the various policies.
 */
private int dowait(boolean timed, long nanos)
    throws InterruptedException, BrokenBarrierException,
           TimeoutException {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        final Generation g = generation;

        if (g.broken)
            throw new BrokenBarrierException();

        if (Thread.interrupted()) {
            breakBarrier();
            throw new InterruptedException();
        }

        int index = --count;
        if (index == 0) {  // tripped
            boolean ranAction = false;
            try {
                final Runnable command = barrierCommand;
                if (command != null)
                    command.run();
                ranAction = true;
                nextGeneration();
                return 0;
            } finally {
                if (!ranAction)
                    breakBarrier();
            }
        }

        // loop until tripped, broken, interrupted, or timed out
        for (;;) {
            try {
                if (!timed)
                    trip.await();
                else if (nanos > 0L)
                    nanos = trip.awaitNanos(nanos);
            } catch (InterruptedException ie) {
                if (g == generation && ! g.broken) {
                    breakBarrier();
                    throw ie;
                } else {
                    // We're about to finish waiting even if we had not
                    // been interrupted, so this interrupt is deemed to
                    // "belong" to subsequent execution.
                    Thread.currentThread().interrupt();
                }
            }

            if (g.broken)
                throw new BrokenBarrierException();

            if (g != generation)
                return index;

            if (timed && nanos <= 0L) {
                breakBarrier();
                throw new TimeoutException();
            }
        }
    } finally {
        lock.unlock();
    }
}

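From dowait we can see that each call decrements count by 1 and stores the result in index. When index reaches 0, the last thread to arrive runs the barrier command (if one was supplied) and then calls nextGeneration(), not breakBarrier(): nextGeneration() calls trip.signalAll() to wake every waiting thread, resets count to parties, and installs a new Generation, which is what makes the barrier reusable. The woken threads then see g != generation and return their arrival index. breakBarrier() is called only on the failure paths: when the barrier command throws, when a thread is interrupted before or while waiting, or when a timed wait expires. Its implementation is:

/**
 * Sets current barrier generation as broken and wakes up everyone.
 * Called only while holding lock.
 */
private void breakBarrier() {
    generation.broken = true;
    count = parties;
    trip.signalAll();
}

breakBarrier() also calls trip.signalAll() to wake all waiting threads (trip is defined as Condition trip = lock.newCondition()), but because it first marks the generation as broken, the woken threads throw BrokenBarrierException instead of returning normally. In short, CyclicBarrier is a straightforward use of the exclusive lock ReentrantLock and its Condition.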
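To see the generation mechanism in isolation, here is a hedged, stripped-down sketch of the same idea: a ReentrantLock, one Condition named trip as in the JDK code, and a Generation object that is replaced each time the barrier trips. MiniBarrier and sumOfIndices are hypothetical names. The sketch deliberately omits the barrier action, timeouts, reset() and the broken state; in particular, if a waiting thread were interrupted here, count would stay decremented for a thread that has already left, which is exactly the situation the real dowait handles with breakBarrier().

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

// Hypothetical, stripped-down barrier: no barrier action, timeout, reset, or broken state.
class MiniBarrier {
    private static final class Generation { } // a new object per round; only identity matters

    private final ReentrantLock lock = new ReentrantLock();
    private final Condition trip = lock.newCondition();
    private final int parties;
    private int count;                           // parties still missing this round (guarded by lock)
    private Generation generation = new Generation();

    MiniBarrier(int parties) {
        if (parties <= 0) throw new IllegalArgumentException("parties <= 0");
        this.parties = parties;
        this.count = parties;
    }

    /** Like CyclicBarrier.await: returns parties - 1 for the first arrival and 0 for the last. */
    int await() throws InterruptedException {
        lock.lock();
        try {
            final Generation g = generation;
            int index = --count;
            if (index == 0) {                    // last arrival trips the barrier
                trip.signalAll();                // same steps as nextGeneration()
                count = parties;
                generation = new Generation();
                return 0;
            }
            while (g == generation) {            // re-check guards against spurious wakeups
                trip.await();                    // releases the lock while waiting
            }
            return index;
        } finally {
            lock.unlock();
        }
    }

    /** Runs `parties` threads for `rounds` rounds and sums every arrival index they get back. */
    static int sumOfIndices(int parties, int rounds) {
        MiniBarrier barrier = new MiniBarrier(parties);
        AtomicInteger sum = new AtomicInteger();
        Thread[] threads = new Thread[parties];
        for (int i = 0; i < parties; i++) {
            threads[i] = new Thread(() -> {
                try {
                    for (int r = 0; r < rounds; r++) sum.addAndGet(barrier.await());
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
            threads[i].start();
        }
        try {
            for (Thread t : threads) t.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return sum.get();
    }

    public static void main(String[] args) {
        // Each round hands out indices 2, 1, 0, so two rounds sum to 6.
        System.out.println(sumOfIndices(3, 2));
    }
}
```

Because every round hands out each index from parties - 1 down to 0 exactly once, the sum is deterministic even though the arrival order is not.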