Using CountDownLatch and CyclicBarrier, and How They Are Implemented

Posted on 2013-08-03 21:35 by 冰天雪域

I. Using CountDownLatch and CyclicBarrier

CountDownLatch:

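Definition: CountDownLatch is a synchronization aid that lets one or more threads block until a set of tasks being performed by other threads has completed. A useful property is that it does not require the threads calling countDown() to wait for the count to reach 0 before proceeding; it only prevents threads from proceeding past an await() until all of them can pass.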

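Usage: Initialize a CountDownLatch with a given count. Each call to countDown() decrements the count by 1. Until the count reaches 0, calls to await() block. Once it reaches 0, all waiting threads are released at once, and any later call to await() returns immediately.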
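A minimal sketch of the usage just described, assuming Java 8+ (for lambdas); `LatchUsageDemo` and `runWorkers` are hypothetical names for illustration. Each worker thread counts the latch down once, and the caller's `await()` returns only after all of them have done so.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

class LatchUsageDemo {
    // Starts `workers` threads; each performs a trivial "task" and then calls
    // countDown(). The calling thread blocks in await() until the count is 0.
    static int runWorkers(int workers) {
        CountDownLatch done = new CountDownLatch(workers);
        AtomicInteger finished = new AtomicInteger();
        for (int i = 0; i < workers; i++) {
            new Thread(() -> {
                finished.incrementAndGet(); // the hypothetical task
                done.countDown();           // count - 1
            }).start();
        }
        try {
            done.await();                   // blocks until the count reaches 0
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        // Actions before countDown() happen-before await() returns,
        // so every increment is visible here.
        return finished.get();
    }

    public static void main(String[] args) {
        System.out.println(runWorkers(5)); // prints 5
    }
}
```

With a count of 0 the latch starts out open, so `runWorkers(0)` returns 0 without blocking.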

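Scenarios: (1) With the count set to 1, a CountDownLatch serves as a simple on/off latch, or gate: every thread that calls await() waits at the gate until some thread opens it by calling countDown(). (2) A CountDownLatch initialized with N (N >= 1) can make a thread wait until N threads have completed some action, or until some action has been completed N times.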
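Both scenarios can be combined in one small sketch, assuming Java 8+; `StartGateDemo` and `run` are hypothetical names. A latch with count 1 is the gate from scenario (1), and a latch with count N is the "wait for N events" latch from scenario (2).

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

class StartGateDemo {
    // Returns {threads past the gate before it opened, threads past it afterwards}.
    static int[] run(int threads) {
        CountDownLatch gate = new CountDownLatch(1);          // scenario (1): on/off gate
        CountDownLatch passed = new CountDownLatch(threads);  // scenario (2): wait for N events
        AtomicInteger through = new AtomicInteger();
        for (int i = 0; i < threads; i++) {
            new Thread(() -> {
                try {
                    gate.await();                // every thread waits at the gate
                    through.incrementAndGet();
                    passed.countDown();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            }).start();
        }
        try {
            // Wait briefly: nobody can pass yet, so this times out.
            passed.await(100, TimeUnit.MILLISECONDS);
            int before = through.get();
            gate.countDown();                    // one countDown() opens the gate for all
            passed.await();                      // wait until all N threads have passed
            return new int[] { before, through.get() };
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        int[] r = run(4);
        System.out.println(r[0] + " " + r[1]); // prints "0 4"
    }
}
```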

Note: the count of a CountDownLatch cannot be reset. If you need a count that resets, consider using CyclicBarrier.
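A short sketch of what "cannot be reset" means in practice, assuming Java 8+; `LatchNoResetDemo` is a hypothetical name. Extra countDown() calls after the count hits 0 are ignored, and the latch stays open for good.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

class LatchNoResetDemo {
    // Counts a latch down past zero and reports what happens.
    static String run() {
        CountDownLatch latch = new CountDownLatch(2);
        latch.countDown();
        latch.countDown();          // count is now 0: the latch is open for good
        latch.countDown();          // no effect: the count never goes below 0
        boolean open;
        try {
            open = latch.await(0, TimeUnit.MILLISECONDS); // true immediately at count 0
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        // There is no method to restore the count to 2; a new latch is needed.
        return "count=" + latch.getCount() + ", open=" + open;
    }

    public static void main(String[] args) {
        System.out.println(run()); // prints "count=0, open=true"
    }
}
```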

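Practice: The code below has 10 threads each sum one group of numbers. The 10 threads should logically start computing at the same time (they cannot truly run simultaneously when there are not enough CPU cores for parallel execution), and no thread may return until all 10 threads have finished computing.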

MyCalculator.java:

package simple.demo;

import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;

/**
 * @author jianying.wcj
 * @date 2013-8-2
 */
public class MyCalculator implements Callable<Integer> {
    /**
     * Start switch: every task waits on it before computing
     */
    private CountDownLatch startSwitch;
    /**
     * Stop switch: each task counts it down once when its sum is ready
     */
    private CountDownLatch stopSwitch;
    /**
     * Index (1-based) of the group of numbers this task sums
     */
    private int groupNum;
    /**
     * Constructor
     */
    public MyCalculator(CountDownLatch startSwitch, CountDownLatch stopSwitch, int groupNum) {
        this.startSwitch = startSwitch;
        this.stopSwitch = stopSwitch;
        this.groupNum = groupNum;
    }

    @Override
    public Integer call() throws Exception {
        startSwitch.await();     // wait until the main thread opens the start gate
        int res = compute();
        System.out.println(Thread.currentThread().getName() + " is ok wait other thread...");
        stopSwitch.countDown();  // report that this group is done
        stopSwitch.await();      // do not return until every group is done
        System.out.println(Thread.currentThread().getName() + " is stop! the group" + groupNum + " temp result is sum=" + res);
        return res;
    }
    /**
     * Sums the integers (groupNum - 1) * 10 + 1 through groupNum * 10
     * @return the sum of this group
     */
    public int compute() {
        int sum = 0;
        for (int i = (groupNum - 1) * 10 + 1; i <= groupNum * 10; i++) {
            sum += i;
        }
        return sum;
    }
}

MyTest.java:

package simple.demo;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class MyTest {

    private int groupNum = 10;
    /**
     * Start and stop switches
     */
    private CountDownLatch startSwitch = new CountDownLatch(1);

    private CountDownLatch stopSwitch = new CountDownLatch(groupNum);
    /**
     * Thread pool (needs groupNum threads, because all tasks block on the latches at the same time)
     */
    private ExecutorService service = Executors.newFixedThreadPool(groupNum);
    /**
     * Holds the computed results
     */
    private List<Future<Integer>> result = new ArrayList<Future<Integer>>();
    /**
     * Submits groupNum tasks that compute the partial sums
     */
    public void init() {
        for (int i = 1; i <= groupNum; i++) {
            result.add(service.submit(new MyCalculator(startSwitch, stopSwitch, i)));
        }
        System.out.println("init is ok!");
    }

    public void printRes() throws InterruptedException, ExecutionException {
        int sum = 0;
        for (Future<Integer> f : result) {
            sum += f.get();
        }
        System.out.println("the result is " + sum);
    }

    public void start() {
        this.startSwitch.countDown();   // open the start gate for all tasks
    }

    public void stop() throws InterruptedException {
        this.stopSwitch.await();        // wait until every task has finished computing
        this.service.shutdown();
    }

    public static void main(String[] args) throws IOException, InterruptedException, ExecutionException {
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));

        MyTest myTest = new MyTest();
        myTest.init();
        System.out.println("please enter start command....");

        reader.readLine();              // press Enter to start
        myTest.start();
        myTest.stop();

        myTest.printRes();
    }
}

Output:

init is ok!
please enter start command....

pool-1-thread-1 is ok wait other thread...
pool-1-thread-2 is ok wait other thread...
pool-1-thread-3 is ok wait other thread...
pool-1-thread-4 is ok wait other thread...
pool-1-thread-6 is ok wait other thread...
pool-1-thread-5 is ok wait other thread...
pool-1-thread-8 is ok wait other thread...
pool-1-thread-7 is ok wait other thread...
pool-1-thread-9 is ok wait other thread...
pool-1-thread-10 is ok wait other thread...
pool-1-thread-10 is stop! the group10 temp result is sum=955
pool-1-thread-1 is stop! the group1 temp result is sum=55
pool-1-thread-2 is stop! the group2 temp result is sum=155
pool-1-thread-3 is stop! the group3 temp result is sum=255
pool-1-thread-4 is stop! the group4 temp result is sum=355
pool-1-thread-6 is stop! the group6 temp result is sum=555
pool-1-thread-5 is stop! the group5 temp result is sum=455
pool-1-thread-8 is stop! the group8 temp result is sum=755
pool-1-thread-7 is stop! the group7 temp result is sum=655
pool-1-thread-9 is stop! the group9 temp result is sum=855
the result is 5050

CyclicBarrier:

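Definition: CyclicBarrier is a synchronization aid that lets a set of threads all wait for each other to reach a common barrier point, after which they all continue together. One feature is that CyclicBarrier supports an optional Runnable command that runs once per barrier point, after the last thread in the group arrives but before any thread is released. This barrier action is useful for updating shared state before any of the parties continue.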
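A minimal sketch of the barrier action, assuming Java 8+; `BarrierActionDemo` and its partial-sum workload are hypothetical. Each party stores a partial result and then waits. The action passed to the `CyclicBarrier(int, Runnable)` constructor runs exactly once, in the last thread to arrive, and combines the partials before anyone is released.

```java
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.atomic.AtomicInteger;

class BarrierActionDemo {
    // Returns {total computed by the barrier action, number of times the action ran}.
    static int[] run(int parties) {
        int[] partial = new int[parties];
        int[] total = new int[1];
        AtomicInteger actionRuns = new AtomicInteger();
        // Runs in the last thread to arrive, before any waiting thread is released.
        CyclicBarrier barrier = new CyclicBarrier(parties, () -> {
            actionRuns.incrementAndGet();
            int sum = 0;
            for (int p : partial) {
                sum += p;
            }
            total[0] = sum;
        });
        Thread[] workers = new Thread[parties];
        for (int i = 0; i < parties; i++) {
            final int id = i;
            workers[i] = new Thread(() -> {
                partial[id] = id + 1;            // this party's share of the work
                try {
                    barrier.await();
                } catch (InterruptedException | BrokenBarrierException e) {
                    throw new IllegalStateException(e);
                }
            });
            workers[i].start();
        }
        try {
            for (Thread t : workers) {
                t.join();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return new int[] { total[0], actionRuns.get() };
    }

    public static void main(String[] args) {
        int[] r = run(4);
        System.out.println(r[0] + " " + r[1]); // prints "10 1"
    }
}
```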

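Usage: Initialize a CyclicBarrier with N parties. Each call to await() blocks the calling thread and increases the number of arrived threads by 1 (conceptually starting from 0; the implementation actually counts down from N). When the N-th thread arrives, all blocked threads are woken. The barrier then resets itself, so later calls to await() block again until another N threads arrive. This reuse is what makes the barrier "cyclic" and is the key difference from CountDownLatch.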
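The reuse can be observed with a small sketch, assuming Java 8+; `BarrierReuseDemo` is a hypothetical name. The same barrier is passed `rounds` times, and a barrier action counts how often it trips.

```java
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.atomic.AtomicInteger;

class BarrierReuseDemo {
    // Each of `parties` threads calls await() `rounds` times on the same barrier.
    // Returns how many times the barrier tripped.
    static int trips(int parties, int rounds) {
        AtomicInteger trips = new AtomicInteger();
        CyclicBarrier barrier = new CyclicBarrier(parties, () -> trips.incrementAndGet());
        Thread[] workers = new Thread[parties];
        for (int i = 0; i < parties; i++) {
            workers[i] = new Thread(() -> {
                try {
                    for (int r = 0; r < rounds; r++) {
                        barrier.await();   // blocks again in every round
                    }
                } catch (InterruptedException | BrokenBarrierException e) {
                    throw new IllegalStateException(e);
                }
            });
            workers[i].start();
        }
        try {
            for (Thread t : workers) {
                t.join();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return trips.get();
    }

    public static void main(String[] args) {
        System.out.println(trips(3, 5)); // prints 5: one trip per round
    }
}
```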

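Scenario: With a CyclicBarrier initialized with N, each of N threads calls await(), so none of them continues until all N have reached their await() call; then they all proceed together.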

Practice: the same exercise as the CountDownLatch example above:

MyCalculator.java:

package simple.demo;

import java.util.concurrent.Callable;
import java.util.concurrent.CyclicBarrier;

public class MyCalculator implements Callable<Integer> {
    /**
     * Start switch: shared by all tasks and the main thread
     */
    private CyclicBarrier startSwitch;
    /**
     * Stop switch: shared by all tasks
     */
    private CyclicBarrier stopSwitch;
    /**
     * Index (1-based) of the group of numbers this task sums
     */
    private int groupNum;
    /**
     * Constructor
     */
    public MyCalculator(CyclicBarrier startSwitch, CyclicBarrier stopSwitch, int groupNum) {
        this.startSwitch = startSwitch;
        this.stopSwitch = stopSwitch;
        this.groupNum = groupNum;
    }

    @Override
    public Integer call() throws Exception {
        startSwitch.await();     // wait until all tasks and the main thread have arrived
        int res = compute();
        System.out.println(Thread.currentThread().getName() + " is ok wait other thread...");
        stopSwitch.await();      // do not return until every task has finished computing
        System.out.println(Thread.currentThread().getName() + " is stop! the group" + groupNum + " temp result is sum=" + res);
        return res;
    }
    /**
     * Sums the integers (groupNum - 1) * 10 + 1 through groupNum * 10
     * @return the sum of this group
     */
    public int compute() {
        int sum = 0;
        for (int i = (groupNum - 1) * 10 + 1; i <= groupNum * 10; i++) {
            sum += i;
        }
        return sum;
    }
}

MyTest.java:

package simple.demo;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class MyTest {

    private int groupNum = 10;
    /**
     * Start and stop switches. The start barrier has groupNum + 1 parties:
     * the groupNum tasks plus the main thread.
     */
    private CyclicBarrier startSwitch = new CyclicBarrier(groupNum + 1);

    private CyclicBarrier stopSwitch = new CyclicBarrier(groupNum);
    /**
     * Thread pool (needs groupNum threads, because all tasks wait at the barriers at the same time)
     */
    private ExecutorService service = Executors.newFixedThreadPool(groupNum);
    /**
     * Holds the computed results
     */
    private List<Future<Integer>> result = new ArrayList<Future<Integer>>();
    /**
     * Submits groupNum tasks that compute the partial sums
     */
    public void init() {
        for (int i = 1; i <= groupNum; i++) {
            result.add(service.submit(new MyCalculator(startSwitch, stopSwitch, i)));
        }
        System.out.println("init is ok!");
    }

    public void printRes() throws InterruptedException, ExecutionException {
        int sum = 0;
        for (Future<Integer> f : result) {
            sum += f.get();             // blocks until that task has returned
        }
        System.out.println("the result is " + sum);
    }

    public void start() throws InterruptedException, BrokenBarrierException {
        this.startSwitch.await();       // trips once the main thread and all tasks have arrived
    }

    public void stop() {
        this.service.shutdown();        // already submitted tasks still run to completion
    }

    public static void main(String[] args) throws IOException, InterruptedException, ExecutionException, BrokenBarrierException {
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));

        MyTest myTest = new MyTest();
        myTest.init();
        System.out.println("please enter start command....");

        reader.readLine();              // press Enter to start

        myTest.start();
        myTest.stop();

        myTest.printRes();
    }
}

Output:

init is ok!
please enter start command....

pool-1-thread-1 is ok wait other thread...
pool-1-thread-2 is ok wait other thread...
pool-1-thread-3 is ok wait other thread...
pool-1-thread-4 is ok wait other thread...
pool-1-thread-5 is ok wait other thread...
pool-1-thread-6 is ok wait other thread...
pool-1-thread-7 is ok wait other thread...
pool-1-thread-8 is ok wait other thread...
pool-1-thread-9 is ok wait other thread...
pool-1-thread-10 is ok wait other thread...
pool-1-thread-10 is stop! the group10 temp result is sum=955
pool-1-thread-1 is stop! the group1 temp result is sum=55
pool-1-thread-2 is stop! the group2 temp result is sum=155
pool-1-thread-3 is stop! the group3 temp result is sum=255
pool-1-thread-5 is stop! the group5 temp result is sum=455
pool-1-thread-6 is stop! the group6 temp result is sum=555
pool-1-thread-4 is stop! the group4 temp result is sum=355
pool-1-thread-8 is stop! the group8 temp result is sum=755
pool-1-thread-7 is stop! the group7 temp result is sum=655
pool-1-thread-9 is stop! the group9 temp result is sum=855
the result is 5050

II. How CountDownLatch and CyclicBarrier Are Implemented

[Figure: CountDownLatch class diagram]
CountDownLatch is built on AQS (AbstractQueuedSynchronizer): it defines an inner class Sync, and Sync extends AQS. The key source code follows.
await method:

 /**
 * Causes the current thread to wait until the latch has counted down to
 * zero, unless the thread is {@linkplain Thread#interrupt interrupted}.
 *
 * <p>If the current count is zero then this method returns immediately.
 *
 * <p>If the current count is greater than zero then the current
 * thread becomes disabled for thread scheduling purposes and lies
 * dormant until one of two things happen:
 * <ul>
 * <li>The count reaches zero due to invocations of the
 * {@link #countDown} method; or
 * <li>Some other thread {@linkplain Thread#interrupt interrupts}
 * the current thread.
 * </ul>
 *
 * <p>If the current thread:
 * <ul>
 * <li>has its interrupted status set on entry to this method; or
 * <li>is {@linkplain Thread#interrupt interrupted} while waiting,
 * </ul>
 * then {@link InterruptedException} is thrown and the current thread's
 * interrupted status is cleared.
 *
 * @throws InterruptedException if the current thread is interrupted
 *         while waiting
 */
public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}

countDown method:
/**
 * Decrements the count of the latch, releasing all waiting threads if
 * the count reaches zero.
 *
 * <p>If the current count is greater than zero then it is decremented.
 * If the new count is zero then all waiting threads are re-enabled for
 * thread scheduling purposes.
 *
 * <p>If the current count equals zero then nothing happens.
 */
public void countDown() {
     sync.releaseShared(1);
}

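These are the definitions of CountDownLatch's two key methods, await and countDown; the Javadoc comments explain what each one does. Internally, CountDownLatch simply reuses the AQS state field (which ReentrantLock, for example, uses as its reentrant hold count) to store the remaining count. CountDownLatch's inner class Sync overrides AQS's tryAcquireShared as follows: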

public int tryAcquireShared(int acquires) {
    return getState() == 0? 1 : -1;
}

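The initial value of state is the count passed to the CountDownLatch constructor. When Sync calls AQS's acquireSharedInterruptibly, AQS checks whether tryAcquireShared(int acquires) returns a negative value. If it does (the count has not yet reached 0), the thread is queued and parked. The AQS method that queues and parks the thread is: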

/**
 * Acquires in shared interruptible mode.
 * @param arg the acquire argument
 */
private void doAcquireSharedInterruptibly(int arg)
    throws InterruptedException {
    final Node node = addWaiter(Node.SHARED);
    try {
        for (;;) {
            final Node p = node.predecessor();
            if (p == head) {
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    return;
                }
            }
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt())
                break;
        }
    } catch (RuntimeException ex) {
        cancelAcquire(node);
        throw ex;
    }
    // Arrive here only if interrupted
    cancelAcquire(node);
    throw new InterruptedException();
}

When countDown calls sync.releaseShared(1), AQS invokes tryReleaseShared, which CountDownLatch's inner class Sync overrides as follows:

public boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    for (;;) {
        int c = getState();
        if (c == 0)
            return false;
        int nextc = c-1;
        if (compareAndSetState(c, nextc))
            return nextc == 0;
    }
}

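So each call decrements state by 1 until it reaches 0; once it is 0, further calls have no effect. The call that brings the count to 0 returns true, which tells AQS to wake the threads parked in await. That is all for CountDownLatch.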
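To see that these two overrides are all a latch needs, here is a minimal re-creation on top of AQS, assuming Java 8+. `MiniLatch` and its `demo` method are hypothetical; the Sync logic mirrors the source above (with `protected` overrides, as in later JDKs). It is a sketch, not a replacement for CountDownLatch.

```java
import java.util.concurrent.locks.AbstractQueuedSynchronizer;

class MiniLatch {
    // Same idea as CountDownLatch.Sync: AQS state holds the remaining count.
    private static final class Sync extends AbstractQueuedSynchronizer {
        Sync(int count) {
            setState(count);
        }

        int count() {
            return getState();
        }

        @Override
        protected int tryAcquireShared(int acquires) {
            // A negative result makes AQS queue and park the caller.
            return getState() == 0 ? 1 : -1;
        }

        @Override
        protected boolean tryReleaseShared(int releases) {
            for (;;) {
                int c = getState();
                if (c == 0) {
                    return false;            // already 0: nothing to release
                }
                int next = c - 1;
                if (compareAndSetState(c, next)) {
                    return next == 0;        // true: AQS wakes the queued waiters
                }
            }
        }
    }

    private final Sync sync;

    MiniLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        sync = new Sync(count);
    }

    void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    void countDown() {
        sync.releaseShared(1);
    }

    long getCount() {
        return sync.count();
    }

    // Demo: one waiter blocks until `n` other threads have counted down.
    static boolean demo(int n) {
        MiniLatch latch = new MiniLatch(n);
        boolean[] released = new boolean[1];
        Thread waiter = new Thread(() -> {
            try {
                latch.await();
                released[0] = true;
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });
        waiter.start();
        for (int i = 0; i < n; i++) {
            new Thread(latch::countDown).start();
        }
        try {
            waiter.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return released[0] && latch.getCount() == 0;
    }

    public static void main(String[] args) {
        System.out.println(demo(3)); // prints true
    }
}
```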

[Figure: CyclicBarrier class diagram]
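CyclicBarrier is implemented on top of ReentrantLock, and ReentrantLock is built on AQS, so ultimately CyclicBarrier also rests on AQS. CyclicBarrier uses a Condition of its ReentrantLock to wake the threads waiting at the barrier. The key source code follows.
await method: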

/**
 * Waits until all {@linkplain #getParties parties} have invoked
 * <tt>await</tt> on this barrier.
 *
 * <p>If the current thread is not the last to arrive then it is
 * disabled for thread scheduling purposes and lies dormant until
 * one of the following things happens:
 * <ul>
 * <li>The last thread arrives; or
 * <li>Some other thread {@linkplain Thread#interrupt interrupts}
 * the current thread; or
 * <li>Some other thread {@linkplain Thread#interrupt interrupts}
 * one of the other waiting threads; or
 * <li>Some other thread times out while waiting for barrier; or
 * <li>Some other thread invokes {@link #reset} on this barrier.
 * </ul>
 *
 * <p>If the current thread:
 * <ul>
 * <li>has its interrupted status set on entry to this method; or
 * <li>is {@linkplain Thread#interrupt interrupted} while waiting
 * </ul>
 * then {@link InterruptedException} is thrown and the current thread's
 * interrupted status is cleared.
 *
 * <p>If the barrier is {@link #reset} while any thread is waiting,
 * or if the barrier {@linkplain #isBroken is broken} when
 * <tt>await</tt> is invoked, or while any thread is waiting, then
 * {@link BrokenBarrierException} is thrown.
 *
 * <p>If any thread is {@linkplain Thread#interrupt interrupted} while waiting,
 * then all other waiting threads will throw
 * {@link BrokenBarrierException} and the barrier is placed in the broken
 * state.
 *
 * <p>If the current thread is the last thread to arrive, and a
 * non-null barrier action was supplied in the constructor, then the
 * current thread runs the action before allowing the other threads to
 * continue.
 * If an exception occurs during the barrier action then that exception
 * will be propagated in the current thread and the barrier is placed in
 * the broken state.
 *
 * @return the arrival index of the current thread, where index
 * <tt>{@link #getParties()} - 1</tt> indicates the first
 * to arrive and zero indicates the last to arrive
 * @throws InterruptedException if the current thread was interrupted
 * while waiting
 * @throws BrokenBarrierException if <em>another</em> thread was
 * interrupted or timed out while the current thread was
 * waiting, or the barrier was reset, or the barrier was
 * broken when {@code await} was called, or the barrier
 * action (if present) failed due an exception.
 */
public int await() throws InterruptedException, BrokenBarrierException {
    try {
      return dowait(false, 0L);
    } catch (TimeoutException toe) {
      throw new Error(toe); // cannot happen;
    }
}
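The Javadoc above says that a timeout breaks the barrier, that await() on a broken barrier throws BrokenBarrierException, and that reset() recovers it. A single-threaded sketch of that path, assuming Java 8+ (`BrokenBarrierDemo` is a hypothetical name), uses the timed `await(long, TimeUnit)` overload on a two-party barrier whose second party never arrives.

```java
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

class BrokenBarrierDemo {
    // Walks through the "broken" path of dowait without any other threads.
    static String run() {
        CyclicBarrier barrier = new CyclicBarrier(2);   // the second party never arrives
        StringBuilder log = new StringBuilder();
        try {
            barrier.await(50, TimeUnit.MILLISECONDS);
            log.append("tripped;");                     // not reached: nobody else arrives
        } catch (TimeoutException e) {
            log.append("timeout;");                     // dowait called breakBarrier()
        } catch (BrokenBarrierException e) {
            log.append("unexpected-broken;");
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        log.append("broken=").append(barrier.isBroken()).append(';');
        try {
            barrier.await();                            // generation is broken: throws at once
            log.append("tripped;");
        } catch (BrokenBarrierException e) {
            log.append("broken-exception;");
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        barrier.reset();                                // starts a fresh, unbroken generation
        log.append("broken=").append(barrier.isBroken());
        return log.toString();
    }

    public static void main(String[] args) {
        System.out.println(run());
        // prints "timeout;broken=true;broken-exception;broken=false"
    }
}
```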

The private dowait method:

/**
 * Main barrier code, covering the various policies.
 */
private int dowait(boolean timed, long nanos)
    throws InterruptedException, BrokenBarrierException,
           TimeoutException {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        final Generation g = generation;

        if (g.broken)
            throw new BrokenBarrierException();

        if (Thread.interrupted()) {
            breakBarrier();
            throw new InterruptedException();
        }

        int index = --count;
        if (index == 0) {  // tripped
            boolean ranAction = false;
            try {
                final Runnable command = barrierCommand;
                if (command != null)
                    command.run();
                ranAction = true;
                nextGeneration();
                return 0;
            } finally {
                if (!ranAction)
                    breakBarrier();
            }
        }

        // loop until tripped, broken, interrupted, or timed out
        for (;;) {
            try {
                if (!timed)
                    trip.await();
                else if (nanos > 0L)
                    nanos = trip.awaitNanos(nanos);
            } catch (InterruptedException ie) {
                if (g == generation && ! g.broken) {
                    breakBarrier();
                    throw ie;
                } else {
                    // We're about to finish waiting even if we had not
                    // been interrupted, so this interrupt is deemed to
                    // "belong" to subsequent execution.
                    Thread.currentThread().interrupt();
                }
            }

            if (g.broken)
                throw new BrokenBarrierException();

            if (g != generation)
                return index;

            if (timed && nanos <= 0L) {
                breakBarrier();
                throw new TimeoutException();
            }
        }
    } finally {
        lock.unlock();
    }
}

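As dowait shows, each call decrements count by 1, and the new value is the caller's arrival index. When it reaches 0, the last thread to arrive runs the barrier action (if one was supplied) and then calls nextGeneration(), which wakes all waiting threads, resets count to parties, and starts a new Generation. If instead the barrier action throws, or a waiting thread is interrupted or times out, dowait calls breakBarrier(), which is implemented as: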

 /**
 * Sets current barrier generation as broken and wakes up everyone.
 * Called only while holding lock.
 */
private void breakBarrier() {
   generation.broken = true;
   count = parties;
   trip.signalAll();
}

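It calls trip.signalAll() to wake all waiting threads (trip is defined as Condition trip = lock.newCondition()); nextGeneration(), used when the barrier trips normally, wakes the waiters the same way. After breakBarrier(), the woken threads see g.broken and throw BrokenBarrierException. In short, CyclicBarrier is a straightforward use of the exclusive lock ReentrantLock and its Condition.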
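The core mechanism (a ReentrantLock, one Condition, a count and a Generation object) fits in a short class. This is a simplified sketch assuming Java 8+. `MiniBarrier` is hypothetical and deliberately omits the barrier action, timeouts and the broken state, so it is not a substitute for CyclicBarrier.

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

class MiniBarrier {
    // Each trip gets a fresh Generation object; waiters compare by identity.
    private static final class Generation { }

    private final ReentrantLock lock = new ReentrantLock();
    private final Condition trip = lock.newCondition();
    private final int parties;
    private int count;                          // parties still to arrive in this generation
    private Generation generation = new Generation();

    MiniBarrier(int parties) {
        if (parties <= 0) throw new IllegalArgumentException("parties <= 0");
        this.parties = parties;
        this.count = parties;
    }

    // Simplified: an interrupted waiter just propagates the exception and leaves
    // this generation incomplete; CyclicBarrier handles that case with breakBarrier().
    int await() throws InterruptedException {
        lock.lock();
        try {
            Generation g = generation;
            int index = --count;
            if (index == 0) {                   // last to arrive: trip the barrier
                trip.signalAll();               // wake everyone waiting on the Condition
                count = parties;                // reset for the next round
                generation = new Generation();
                return 0;
            }
            while (g == generation) {           // re-check guards against spurious wakeups
                trip.await();                   // releases the lock while waiting
            }
            return index;
        } finally {
            lock.unlock();
        }
    }

    // Demo: `parties` threads pass the barrier `rounds` times; returns how many
    // await() calls reported arrival index 0 (exactly one per trip).
    static int demo(int parties, int rounds) {
        MiniBarrier barrier = new MiniBarrier(parties);
        AtomicInteger lastArrivals = new AtomicInteger();
        Thread[] workers = new Thread[parties];
        for (int i = 0; i < parties; i++) {
            workers[i] = new Thread(() -> {
                try {
                    for (int r = 0; r < rounds; r++) {
                        if (barrier.await() == 0) {
                            lastArrivals.incrementAndGet();
                        }
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
            workers[i].start();
        }
        try {
            for (Thread t : workers) {
                t.join();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return lastArrivals.get();
    }

    public static void main(String[] args) {
        System.out.println(demo(3, 4)); // prints 4
    }
}
```

Comparing generations by identity, as CyclicBarrier does, lets a thread that was woken by one trip tell that trip apart from the next round. A plain `count == parties` check cannot do that, because the count has already been reset.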

Copyright © 2024 冰天雪域