【JUC源码解析】CyclicBarrier
简介
CyclicBarrier,一个同步器,允许多个线程相互等待,直到达到一个公共屏障点。
概述
CyclicBarrier支持一个可选的 Runnable 命令,在一组线程中的最后一个线程到达之后,释放所有线程之前,该命令只在屏障点运行一次。
应用
描述
有一个矩阵,每一行数据交给一个线程去处理,处理内容是,将这行数据的每一个值相加,结果存入第一个元素中,每个线程处理完成后,会在屏障点相互等待,直到最后一个线程也到达屏障点,最后将各个线程处理的数据汇总,具体是将每一行汇总的数据(存在每行的第一个元素中的数据)再相加,结果存在第一行中的第一个元素。
具体如下代码所示。
代码
1 public class Solver { 2 final int N; 3 final float[][] data; // 待处理的数据 4 final CyclicBarrier barrier; // 屏障 5 6 class Worker implements Runnable { // 工作者 7 int myRow; 8 9 Worker(int row) { 10 myRow = row; 11 } 12 13 public void run() { 14 System.out.println("Matrix[" + myRow + "]数据准备处理"); 15 processRow(myRow); // 处理每一行数据,把各个数据相加,并且存入第一个元素 16 System.out.println("Matrix[" + myRow + "]数据处理完成"); 17 try { 18 barrier.await(); // 处理完成后,在此等待,直到最后一个线程也完成任务 19 } catch (InterruptedException ex) { 20 return; 21 } catch (BrokenBarrierException ex) { 22 return; 23 } 24 } 25 } 26 27 public Solver(float[][] matrix) throws InterruptedException { 28 data = matrix; 29 N = matrix.length; 30 31 System.out.println("开始处理数据,最初矩阵如下所示"); 32 displayData(); // 数据展示 33 34 Runnable barrierAction = new Runnable() { // 操作完成后,由最后一个到达屏障点的线程执行此操作,整体只执行一次 35 public void run() { // 数据汇总,把第一列每行的和再相加,结果存入第一行第一个元素 36 float tmp = 0.0f; 37 for (int i = 0; i < N; i++) { 38 tmp += data[i][0]; 39 } 40 41 data[0][0] = tmp; 42 System.out.println("Matrix[0]数据汇总完成,结果是" + tmp + ",存在了Matrix[0][0]"); 43 } 44 }; 45 barrier = new CyclicBarrier(N, barrierAction); 46 47 List<Thread> threads = new ArrayList<Thread>(N); 48 for (int i = 0; i < N; i++) { 49 Thread thread = new Thread(new Worker(i)); 50 threads.add(thread); 51 thread.start(); // 启动各个工作线程 52 } 53 54 for (Thread thread : threads) { // 等待所有线程执行完成 55 thread.join(); 56 } 57 System.out.println("所以数据都已经处理完成,最终矩阵如下所示"); 58 displayData(); 59 } 60 61 private void processRow(int myRow) { // 处理每一行数据,把各个数据相加,并且存入第一个元素 62 float[] row = data[myRow]; 63 float tmp = 0.0f; 64 int length = row.length; 65 for (int i = 0; i < length; i++) { 66 tmp += row[i]; 67 } 68 69 String msg = Arrays.toString(row); 70 row[0] = tmp; 71 System.out.println("Matrix[" + myRow + "]数据" + msg + "的和是:" + tmp + ", 且存入了Matrix[" + myRow + "][0]"); 72 } 73 74 private void displayData() { // 数据展示 75 TableInfo info = new TableInfo(N + 1); // TableInfo见下面代码,仅仅为了展示数据,可忽略,也可以用Arrays.toString(array); 76 77 String[] header = new String[N + 1]; 78 header[0] = String.valueOf(" "); 79 for (int i = 0; i < N; i++) { 80 header[i + 1] = String.valueOf(i); 81 } 82 info.addHeader(header); 83 for (int i = 0; i < N; i++) { 84 String[] tmp = new String[N + 1]; 85 tmp[0] = String.valueOf(i); 86 for (int j = 0; j < data[i].length; j++) { 87 tmp[j + 1] = String.valueOf(data[i][j]); 88 } 89 info.addRecode(tmp); 90 } 91 System.out.println(info.getInfo()); 92 } 93 94 public static void main(String[] args) throws InterruptedException { 95 Random r = new Random(47); 96 float[][] matrix = new float[5][5]; 97 for (int i = 0; i < 5; i++) { 98 for (int j = 0; j < 5; j++) { 99 matrix[i][j] = r.nextFloat(); 100 } 101 } 102 new Solver(matrix); 103 } 104 } 105 106 class TableInfo { 107 private int count; 108 private int[] maxLens; 109 private String[] columns; 110 private List<String[]> records; 111 private StringBuilder builder; 112 113 public TableInfo(int count) { 114 this.count = count; 115 maxLens = new int[count]; 116 columns = new String[count]; 117 builder = new StringBuilder(); 118 records = new ArrayList<String[]>(); 119 } 120 121 private boolean isValid(String... args) { 122 return null == args || args.length == 0; 123 } 124 125 private boolean isOutOfIndex(String... args) { 126 return args.length > count; 127 } 128 129 public boolean addHeader(String... record) { 130 if (isValid(record) || isOutOfIndex(record)) { 131 return false; 132 } 133 134 copy(columns, record); 135 copy(maxLens, record); 136 137 return false; 138 } 139 140 public boolean addRecode(String... record) { 141 if (isValid(record) || isOutOfIndex(record)) { 142 return false; 143 } 144 145 copy(maxLens, record); 146 147 records.add(record); 148 149 return false; 150 } 151 152 public String getInfo() { 153 buildLine(); 154 buildHeader(); 155 buildLine(); 156 buildBody(); 157 buildLine(); 158 159 return builder.toString(); 160 } 161 162 private void buildHeader() { 163 builder.append("|"); 164 for (int i = 0; i < columns.length; i++) { 165 builder.append(columns[i]); 166 builder.append(getEmpty(maxLens[i] - columns[i].length())); 167 builder.append("|"); 168 } 169 builder.append("\r\n"); 170 } 171 172 private void buildBody() { 173 for (String[] recode : records) { 174 builder.append("|"); 175 for (int i = 0; i < recode.length; i++) { 176 builder.append(recode[i]); 177 builder.append(getEmpty(maxLens[i] - recode[i].length())); 178 builder.append("|"); 179 } 180 builder.append("\r\n"); 181 } 182 } 183 184 private void buildLine() { 185 builder.append("+"); 186 for (int i = 0; i < maxLens.length; i++) { 187 builder.append(getLine(maxLens[i])); 188 builder.append("+"); 189 } 190 builder.append("\r\n"); 191 } 192 193 private String getLine(int count) { 194 StringBuilder builder = new StringBuilder(); 195 for (int i = 0; i < count; i++) { 196 builder.append("-"); 197 } 198 199 return builder.toString(); 200 201 } 202 203 private String getEmpty(int count) { 204 StringBuilder builder = new StringBuilder(); 205 for (int i = 0; i < count; i++) { 206 builder.append(" "); 207 } 208 209 return builder.toString(); 210 } 211 212 private void copy(String[] args1, String... args2) { 213 for (int i = 0; i < args2.length; i++) { 214 args1[i] = args2[i]; 215 } 216 } 217 218 private void copy(int[] args1, String... args2) { 219 for (int i = 0; i < args2.length; i++) { 220 args1[i] = swap(args1[i], args2[i].length()); 221 } 222 } 223 224 private int swap(int max, int length) { 225 if (max > length) { 226 return max; 227 } 228 229 return length; 230 } 231
输出
开始处理数据,最初矩阵如下所示 +----+----------+----------+----------+----------+----------+ | |0 |1 |2 |3 |4 | +----+----------+----------+----------+----------+----------+ |0 |0.72711575|0.39982635|0.5309454 |0.0534122 |0.16020656| |1 |0.57799757|0.18847865|0.4170137 |0.51660204|0.73734957| |2 |0.2678662 |0.9510573 |0.261361 |0.11435455|0.05086732| |3 |0.5466897 |0.8037155 |0.20143336|0.76206654|0.55373144| |4 |0.5304296 |0.15709275|0.5295954 |0.39661872|0.48718303| +----+----------+----------+----------+----------+----------+ Matrix[0]数据准备处理 Matrix[1]数据准备处理 Matrix[0]数据[0.72711575, 0.39982635, 0.5309454, 0.0534122, 0.16020656]的和是:1.8715063, 且存入了Matrix[0][0] Matrix[0]数据处理完成 Matrix[2]数据准备处理 Matrix[2]数据[0.2678662, 0.9510573, 0.261361, 0.11435455, 0.05086732]的和是:1.6455064, 且存入了Matrix[2][0] Matrix[2]数据处理完成 Matrix[1]数据[0.57799757, 0.18847865, 0.4170137, 0.51660204, 0.73734957]的和是:2.4374416, 且存入了Matrix[1][0] Matrix[1]数据处理完成 Matrix[4]数据准备处理 Matrix[3]数据准备处理 Matrix[3]数据[0.5466897, 0.8037155, 0.20143336, 0.76206654, 0.55373144]的和是:2.8676367, 且存入了Matrix[3][0] Matrix[3]数据处理完成 Matrix[4]数据[0.5304296, 0.15709275, 0.5295954, 0.39661872, 0.48718303]的和是:2.1009195, 且存入了Matrix[4][0] Matrix[4]数据处理完成 Matrix[0]数据汇总完成,结果是10.923011,存在了Matrix[0][0] 所以数据都已经处理完成,最终矩阵如下所示 +----+---------+----------+----------+----------+----------+ | |0 |1 |2 |3 |4 | +----+---------+----------+----------+----------+----------+ |0 |10.923011|0.39982635|0.5309454 |0.0534122 |0.16020656| |1 |2.4374416|0.18847865|0.4170137 |0.51660204|0.73734957| |2 |1.6455064|0.9510573 |0.261361 |0.11435455|0.05086732| |3 |2.8676367|0.8037155 |0.20143336|0.76206654|0.55373144| |4 |2.1009195|0.15709275|0.5295954 |0.39661872|0.48718303| +----+---------+----------+----------+----------+----------+
源码解析
关于静态内部类Generation,其属性broken描述CyclicBarrier是否被打破。每次到达屏障点,或者重置时,都会新生下一代(创建Generation实例,记录下一批线程的屏障状态)。
一组线程,执行一批任务,调用CyclicBarrier的await方法,其内部调用的是Condition的await方法,实质上是调用LockSupport.park方法,入队等待。也就是说,这组线程一调用await方法,就会再此阻塞,而最后一个线程调用await方法时,不在走Condition的await的代码分支,而是先执行barrierCommand任务,然后调用nextGeneration方法,在该方法内部,会调用Condition的signalAll方法,即是唤醒阻塞在Condition上的线程,重置count,并重新创建Generation实例,一遍下次使用。
reset方法,会先打破原有的屏障,即是Generation的broken设置为true,提前调用Condition的signalAll方法,释放阻塞着的线程。并且接着调用nextGeneration,开启一个新的Generation实例。
问题:为什么要多次创建Generation实例呢?重用一个实例不是更好吗?毕竟只有一个broken属性,加一个set方法就好了?
回答:之所以创建新的Generation,是因为,每个Generation实例对应一批冲向屏障的线程。假如,只有一个Generation实例,想像这样一个场景,某个线程调用了reset方法,本意是使另外一组线程能重新使用CyclicBarrier,并唤醒与它同一批阻塞在Generation上的线程,被唤醒的老一批线程需要记录屏障打破记录(reset方法会打破屏障状态,即是broken为true),如果是同一个Generation对象,且broken为true,而新线程又要求broken为false,因为是全新的,对它们来说,屏障没有被打破。因此,无法满足。
以下是源码:
属性
1 private static class Generation { // 分代,对应每一批冲向屏障的线程 2 boolean broken = false; // 记录屏障是否被打破 3 } 4 5 private final ReentrantLock lock = new ReentrantLock(); // 可重入锁 6 private final Condition trip = lock.newCondition(); // 条件 7 private final int parties; // 记录线程数量 8 private final Runnable barrierCommand; // 最后一个到达屏障的线程需要执行的任务 9 private Generation generation = new Generation(); // 初代 10 11 private int count; // 还未到达屏障的线程个数,会再次重置为parties
创建下一代
1 private void nextGeneration() { // 创建下一代 2 trip.signalAll(); // 唤醒阻塞在该代屏障上的线程 3 count = parties; // 重置count 4 generation = new Generation(); // 创建新的Generation实例 5 }
打破屏障
1 private void breakBarrier() { // 打破屏障 2 generation.broken = true; // 设置broken 3 count = parties; // 重置count 4 trip.signalAll(); // 唤醒阻塞的线程 5 }
关键的dowait
1 private int dowait(boolean timed, long nanos) 2 throws InterruptedException, BrokenBarrierException, TimeoutException { 3 final ReentrantLock lock = this.lock; // 可重入锁 4 lock.lock(); // 加锁 5 try { 6 final Generation g = generation; // 对应该代线程的分代器,记录屏障的打破状态 7 8 if (g.broken) // 如果被打破,抛出异常 9 throw new BrokenBarrierException(); 10 11 if (Thread.interrupted()) { // 如果线程中断了 12 breakBarrier(); // 打破屏障,并抛出异常 13 throw new InterruptedException(); 14 } 15 16 int index = --count; // 来一个线程,count减1 17 if (index == 0) { // 即将通过屏障 18 boolean ranAction = false; // 记录是否正常完成 19 try { 20 final Runnable command = barrierCommand; // 屏障点任务,由最后一个到达的线程负责执行 21 if (command != null) 22 command.run(); // 执行任务 23 ranAction = true; // 设置为正常完成 24 nextGeneration(); // 创建新的Generation对象 25 return 0; // 返回 26 } finally { 27 if (!ranAction) // 如果没有正常完成,打破屏障 28 breakBarrier(); 29 } 30 } 31 32 for (;;) { 33 try { 34 if (!timed) // 如果没有设置超时 35 trip.await(); // 调用await方法 36 else if (nanos > 0L) // 否则,调用awaitNanos方法 37 nanos = trip.awaitNanos(nanos); 38 } catch (InterruptedException ie) { // 如果是中断唤醒的,则看是否换代 39 if (g == generation && !g.broken) { // 如果还是同一代,并且屏障没有被打破,那么打破屏障,并抛出异常 40 breakBarrier(); 41 throw ie; 42 } else { // 如果换代了,或者屏障已经打破了,什么都不作,仅仅重新设置中断标记 43 Thread.currentThread().interrupt(); 44 } 45 } 46 47 if (g.broken) // 如果是正常唤醒的,并且屏障已经打破,抛出异常 48 throw new BrokenBarrierException(); 49 50 if (g != generation) // 超时,换代了,返回未到达屏障的线程数目 51 return index; 52 53 if (timed && nanos <= 0L) { // 超时, 没换代 54 breakBarrier(); // 打破屏障 55 throw new TimeoutException(); // 抛出超时异常 56 } 57 } 58 } finally { 59 lock.unlock(); // 解锁 60 } 61 }
行文至此结束。
尊重他人的劳动,转载请注明出处:http://www.cnblogs.com/aniao/p/aniao_cyclicbarrier.html