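Multithreading Components: CountDownLatch

Preface

Java has quite a few multithreading components. A few days ago I read through the source code of CountDownLatch, so today I will analyze CountDownLatch in detail.

Example, Code, and Notes

Before the analysis, let us look at a fairly simple use case.

Example

Suppose we have a job with the following steps:

  1. Three people must first do the preparation;
  2. once the preparation is done, two people carry out the computation;
  3. when the computation is finished, a supervisor outputs the result;
  4. assume everyone is in place from the very start.

The code is as follows: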

package thread.demo.codox.net;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

public class CountDown {

    // ID generator
    private static final AtomicInteger idGenerator = new AtomicInteger(1);

    // Main thread: does the main setup and drives the workflow.
    public static void main(String[] args) throws InterruptedException {
        // The workflow starts with a preparation phase handled by 3 threads.
        // After preparation, the Worker threads do the processing; once they
        // are done, the main thread prints.

        // prepLatch opens when preparation is done; workLatch opens when the workers are done.
        CountDownLatch prepLatch = new CountDownLatch(3);
        CountDownLatch workLatch = new CountDownLatch(2);

        // Create the thread pool; preparation tasks and worker tasks all go into it.
        ExecutorService es = Executors.newFixedThreadPool(10);
        es.execute(new PrepareWorker(CountDown.idGenerator.incrementAndGet(), prepLatch));
        es.execute(new PrepareWorker(CountDown.idGenerator.incrementAndGet(), prepLatch));
        es.execute(new PrepareWorker(CountDown.idGenerator.incrementAndGet(), prepLatch));

        es.execute(new Worker(CountDown.idGenerator.incrementAndGet(), prepLatch, workLatch));
        es.execute(new Worker(CountDown.idGenerator.incrementAndGet(), prepLatch, workLatch));

        workLatch.await();
        System.out.println("Main thread end!");
        es.shutdown();
    }

    // Tasks are handed to the pool, so they implement Runnable instead of extending Thread.
    static class PrepareWorker implements Runnable {
        private final CountDownLatch startLatch;
        private final int id;

        public PrepareWorker(int id, CountDownLatch latch) {
            this.id = id;
            this.startLatch = latch;
        }

        @Override
        public void run() {
            System.out.println("Thread " + id + " prepare started!");

            try {
                Thread.sleep(2000); // simulate the preparation work
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.println("Thread " + id + " prepare end!");
            startLatch.countDown();
        }
    }

    static class Worker implements Runnable {
        private final int id;
        private final CountDownLatch latch;
        private final CountDownLatch preLatch;

        public Worker(int id, CountDownLatch preLatch, CountDownLatch latch) {
            this.id = id;
            this.preLatch = preLatch;
            this.latch = latch;
        }

        @Override
        public void run() {
            try {
                preLatch.await(); // wait until all preparation tasks have finished
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.println("Worker " + id + " workout!");
            latch.countDown();
        }
    }
}
/* OUTPUT:
Thread 3 prepare started!
Thread 4 prepare started!
Thread 2 prepare started!
Thread 4 prepare end!
Thread 2 prepare end!
Thread 3 prepare end!
Worker 5 workout!
Worker 6 workout!
Main thread end!
*/

Notes

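/**
 * The only constructor; it takes a non-negative count. Once countDown() has brought the
 * count down to 0, calls blocked in await() return. Initializing with a negative number
 * throws IllegalArgumentException.
 */
CountDownLatch(int count)
// Blocks until countDown() brings the count to 0; throws InterruptedException if interrupted
void await()
// Blocks until the count reaches 0 or the timeout elapses; returns false on timeout
boolean await(long timeout, TimeUnit unit)
// Each call decrements the count set by the constructor by 1, down to 0; at that point threads blocked in await() return.
void countDown()
// Returns the current count. (Internally an int is used, and the constructor also takes an int.)
long getCount()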
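
As a small illustration of the API summary above, the following sketch exercises the timed await(), the no-op behaviour of countDown() at zero, and the rejection of a negative count. The class name TimedAwaitDemo and its helper methods are made up for this example; only the CountDownLatch calls come from the JDK.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

final class TimedAwaitDemo {

    // Returns { result of await() while count is 1, result of await() once count is 0 }.
    static boolean[] run() {
        CountDownLatch latch = new CountDownLatch(2);
        try {
            latch.countDown();                                      // count: 2 -> 1
            boolean first = latch.await(50, TimeUnit.MILLISECONDS); // count is still 1, so this times out
            latch.countDown();                                      // count: 1 -> 0
            latch.countDown();                                      // already 0: nothing happens
            boolean second = latch.await(50, TimeUnit.MILLISECONDS); // count is 0, returns at once
            return new boolean[] { first, second };
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
    }

    // True if the constructor refuses a negative count.
    static boolean rejectsNegativeCount() {
        try {
            new CountDownLatch(-1);
            return false;
        } catch (IllegalArgumentException expected) {
            return true;
        }
    }

    public static void main(String[] args) {
        boolean[] r = run();
        System.out.println("await before zero: " + r[0]);
        System.out.println("await at zero: " + r[1]);
        System.out.println("negative count rejected: " + rejectsNegativeCount());
    }
}
```

The timed variant is handy in tests and shutdown paths, where waiting forever on a latch that is never counted down would hang the program.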

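Implementation Analysis

Structure

Let us look at the constructor of CountDownLatch. Internally it is implemented with a synchronizer that extends AbstractQueuedSynchronizer (hereafter AQS) and overrides two AQS methods: int tryAcquireShared(int acquires) and boolean tryReleaseShared(int releases).

Of these, tryReleaseShared(int) is used by CountDownLatch's countDown(), and tryAcquireShared(int) is used by CountDownLatch's await().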

public CountDownLatch(int count) {
    if (count < 0) throw new IllegalArgumentException("count < 0");
    this.sync = new Sync(count);
}

private static final class Sync extends AbstractQueuedSynchronizer {
    private static final long serialVersionUID = 4982264981922014374L;

    Sync(int count) {
        setState(count);
    }

    int getCount() {
        return getState();
    }

    protected int tryAcquireShared(int acquires) {
        return (getState() == 0) ? 1 : -1;
    }

    protected boolean tryReleaseShared(int releases) {
        // Decrement count; signal when transition to zero
        for (;;) {
            int c = getState();
            if (c == 0)
                return false;
            int nextc = c-1;
            if (compareAndSetState(c, nextc))
                return nextc == 0;
        }
    }
}

private final Sync sync;
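
To make the division of labour between AQS and its subclass more concrete, here is a minimal sketch of a one-shot gate built the same way: the subclass only decides what the state means, and AQS does the queuing and blocking. OneShotLatch, open() and demo() are names invented for this illustration (the idea is similar in spirit to the boolean latch shown in the AQS class documentation); this is not how CountDownLatch itself is written.

```java
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;

final class OneShotLatch {

    // State 0 means closed, state 1 means open.
    private static final class Sync extends AbstractQueuedSynchronizer {
        boolean isOpen() {
            return getState() != 0;
        }

        @Override
        protected int tryAcquireShared(int ignored) {
            return isOpen() ? 1 : -1;   // >= 0 lets the caller through, < 0 makes it wait
        }

        @Override
        protected boolean tryReleaseShared(int ignored) {
            setState(1);                // open the gate
            return true;                // true tells AQS to wake the waiting threads
        }
    }

    private final Sync sync = new Sync();

    public void open() {
        sync.releaseShared(1);
    }

    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    // Starts one waiter, opens the gate, and reports whether the waiter got through.
    static boolean demo() {
        OneShotLatch latch = new OneShotLatch();
        AtomicBoolean passed = new AtomicBoolean(false);
        Thread waiter = new Thread(() -> {
            try {
                latch.await();
                passed.set(true);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });
        waiter.start();
        latch.open();
        try {
            waiter.join(2000);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return passed.get();
    }

    public static void main(String[] args) {
        System.out.println("waiter released: " + demo());
    }
}
```

CountDownLatch follows the same shape; the difference is that its state is a counter and tryReleaseShared only reports true on the transition to zero.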

Method Analysis

countDown

The implementation of CountDownLatch's countDown():

// ID: snippet-1-1
public void countDown() {
    sync.releaseShared(1); // -> snippet-1-2
}

sync.releaseShared(int) is inherited by the Sync class from the abstract class AQS:

// ID: snippet-1-2
// Within AbstractQueuedSynchronizer
public final boolean releaseShared(int arg) {
    if (tryReleaseShared(arg)) { // <- the AQS version simply throws; the override in Sync is what actually runs
        // -> snippet-1-3
        doReleaseShared();
        return true;
    }
    return false;
}
protected boolean tryReleaseShared(int arg) {
    throw new UnsupportedOperationException();
}

The version in Sync:

// ID: snippet-1-3
// Within CountDownLatch.Sync
protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    for (;;) {
        int c = getState();
        if (c == 0)
            return false; // already zero: back to snippet-1-2:9 (return false)
        int nextc = c-1;
        if (compareAndSetState(c, nextc))
            return nextc == 0; // if true, back to snippet-1-2:6 (doReleaseShared())
    }
}

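The method first reads the state and computes the value after subtracting 1 (that is, after the countDown). It then performs the check of the original value and the write of the new value as a single atomic operation. If the original value has changed in the meantime, the operation fails and the loop retries; otherwise the method returns whether the count is 0 after the decrement. This is optimistic locking (fail and retry) rather than pessimistic locking in the style of synchronized, and it tends to perform well in multithreaded code.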
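
The same read, compute, compare-and-set loop can be tried out without AQS. The sketch below uses AtomicInteger in place of the AQS state field; CasCountdown and countZeroTransitions are hypothetical names for this illustration. However many threads race, only one call should observe the transition to zero, which is exactly the property CountDownLatch relies on to wake its waiters once.

```java
import java.util.concurrent.atomic.AtomicInteger;

final class CasCountdown {
    private final AtomicInteger state;

    CasCountdown(int count) {
        this.state = new AtomicInteger(count);
    }

    // Mirrors the structure of tryReleaseShared: true only for the call that moves the count to 0.
    boolean countDown() {
        for (;;) {
            int c = state.get();
            if (c == 0)
                return false;               // already at zero, nothing to do
            int next = c - 1;
            if (state.compareAndSet(c, next))
                return next == 0;           // CAS succeeded
            // CAS failed: another thread changed the value, so read again and retry
        }
    }

    // Runs many racing countDown() calls and counts how many of them saw the transition to zero.
    static int countZeroTransitions(int initial, int threads, int callsPerThread) {
        CasCountdown latch = new CasCountdown(initial);
        AtomicInteger transitions = new AtomicInteger();
        Thread[] workers = new Thread[threads];
        for (int i = 0; i < threads; i++) {
            workers[i] = new Thread(() -> {
                for (int j = 0; j < callsPerThread; j++) {
                    if (latch.countDown())
                        transitions.incrementAndGet();
                }
            });
            workers[i].start();
        }
        for (Thread t : workers) {
            try {
                t.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
        return transitions.get();
    }

    public static void main(String[] args) {
        // 8 threads x 500 calls is more than the initial count of 1000
        System.out.println("zero transitions: " + countZeroTransitions(1000, 8, 500));
    }
}
```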

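Many of Java's concurrency classes use compareAndSet (hereafter CAS) operations. In the JDK version shown here, compareAndSetState(int, int) calls Unsafe.compareAndSwapInt(Object, long, int, int). Implementing this method requires CPU support; on x86 it is based on the CMPXCHG instruction.

Snippet-1-3 is the process of reading state and then trying to set it. The state is the value set when Sync is constructed, and it is stored in the AQS class as a field marked volatile. volatile prevents stale reads: a write by one thread is visible to subsequent reads by other threads. Like CAS, volatile also relies on support from the CPU.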

After tryReleaseShared completes, if the count has dropped to 0, control returns to snippet-1-2:6 and doReleaseShared runs to wake the other threads waiting on the CountDownLatch:

// ID: snippet-1-4
private void doReleaseShared() {
    for (;;) {
        Node h = head;
        if (h != null && h != tail) {
            int ws = h.waitStatus;
            if (ws == Node.SIGNAL) {
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                    continue;            // loop to recheck cases
                unparkSuccessor(h);
            }
            else if (ws == 0 &&
                     !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                continue;                // loop on failed CAS
        }
        if (h == head)                   // loop if head changed
            break;
    }
}

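The core data structure of AQS is a modified CLH spin-lock queue. It is a doubly linked queue in which each Node's prev and next fields point to the preceding and following nodes. A node holds the waiting thread and its status information.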
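
Seen from the outside, the effect of doReleaseShared plus the wait queue is that every thread parked on the latch is released once the count reaches zero, not just the first one. The following sketch checks that behaviour using only the public CountDownLatch API; ReleaseAllWaiters and its run method are names chosen for this example.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

final class ReleaseAllWaiters {

    // Starts the given number of threads waiting on one gate, opens it, and reports
    // whether all of them got past await() within two seconds.
    static boolean run(int waiters) {
        CountDownLatch gate = new CountDownLatch(1);
        CountDownLatch passed = new CountDownLatch(waiters);
        for (int i = 0; i < waiters; i++) {
            Thread t = new Thread(() -> {
                try {
                    gate.await();        // each waiter queues up here in shared mode
                    passed.countDown();  // record that this waiter was released
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
            t.setDaemon(true);           // do not keep the JVM alive if something goes wrong
            t.start();
        }
        gate.countDown();                // count 1 -> 0: wakes every queued waiter
        try {
            return passed.await(2, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }

    public static void main(String[] args) {
        System.out.println("all waiters released: " + run(5));
    }
}
```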

await()

Wherever a thread needs to block, it calls CountDownLatch's await() and waits there for the countDown calls to finish. The blocking goes through AQS's acquireSharedInterruptibly(1):

public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}

That AQS method looks like this:

public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
    if (tryAcquireShared(arg) < 0)         // try to acquire directly first
        doAcquireSharedInterruptibly(arg); // result < 0: enqueue and wait (retry loop with parking) in this method
}

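The tryAcquireShared(int) called here is a protected method in AQS, meaning subclasses are expected to implement it. In this case it is implemented in CountDownLatch: it checks the state and returns 1 if the state is 0, otherwise -1.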
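
Two consequences of acquireSharedInterruptibly are visible through the public API: a thread blocked in await() gets an InterruptedException when it is interrupted, and await() returns immediately when the count is already 0 because tryAcquireShared succeeds on the first try. The sketch below demonstrates both; AwaitInterruptDemo and its method names are made up for this example.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;

final class AwaitInterruptDemo {

    // True if a thread waiting on a latch that never opens sees InterruptedException.
    static boolean interruptedWhileWaiting() {
        CountDownLatch neverOpened = new CountDownLatch(1);
        AtomicBoolean sawInterrupt = new AtomicBoolean(false);
        Thread waiter = new Thread(() -> {
            try {
                neverOpened.await();
            } catch (InterruptedException e) {
                sawInterrupt.set(true);
            }
        });
        waiter.start();
        // Works whether the interrupt lands before or during await(): the interrupt
        // flag is checked on entry, and a parked thread is woken by the interrupt.
        waiter.interrupt();
        try {
            waiter.join(2000);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return sawInterrupt.get();
    }

    // True if await() on a latch whose count is already 0 returns without blocking.
    static boolean returnsImmediatelyAtZero() {
        CountDownLatch alreadyOpen = new CountDownLatch(0);
        try {
            alreadyOpen.await();
            return true;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }

    public static void main(String[] args) {
        System.out.println("interrupted while waiting: " + interruptedWhileWaiting());
        System.out.println("returns immediately at zero: " + returnsImmediatelyAtZero());
    }
}
```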

Next, let us look at AQS's doAcquireSharedInterruptibly(int):

private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
    final Node node = addWaiter(Node.SHARED); // safely append a shared-mode node to the tail of the queue
    boolean failed = true;
    try {
        for (;;) {
            final Node p = node.predecessor();
            if (p == head) { // predecessor is head, so this node is first in line to acquire
                int r = tryAcquireShared(arg); // try to acquire
                if (r >= 0) {
                    setHeadAndPropagate(node, r); // acquired: make this node the new head and propagate the release
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
            }
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt())
                throw new InterruptedException();
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}
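
The blocking step in the loop above, parkAndCheckInterrupt(), rests on LockSupport.park, and the wake-up on the release side rests on LockSupport.unpark. The sketch below shows that pair in isolation, under the simplifying assumption of a single waiter and a plain boolean condition; ParkDemo and its fields are hypothetical names for this illustration.

```java
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.LockSupport;

final class ParkDemo {

    // Parks a thread until a condition is set, then wakes it with unpark.
    static boolean run() {
        AtomicBoolean ready = new AtomicBoolean(false);
        AtomicBoolean resumed = new AtomicBoolean(false);
        Thread waiter = new Thread(() -> {
            // park() may return without a matching unpark, so the condition is
            // re-checked in a loop, much like the for (;;) loop in AQS.
            while (!ready.get()) {
                LockSupport.park();
            }
            resumed.set(true);
        });
        waiter.start();
        ready.set(true);            // publish the condition first
        LockSupport.unpark(waiter); // then wake the waiter; if it has not parked yet, its next park() returns at once
        try {
            waiter.join(2000);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return resumed.get();
    }

    public static void main(String[] args) {
        System.out.println("waiter resumed: " + run());
    }
}
```

Setting the condition before calling unpark matters: the waiter re-checks the condition after every return from park, so the order guarantees it sees the new value.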

The queue's head and tail are lazily initialized. Let us look at addWaiter(Node):

private Node addWaiter(Node mode) {
    Node node = new Node(Thread.currentThread(), mode);
    // Try the fast path of enq; backup to full enq on failure
    Node pred = tail;
    if (pred != null) { // tail is not null: try to append the new node to the end of the queue
        node.prev = pred;
        if (compareAndSetTail(pred, node)) {
            pred.next = node;
            return node;
        }
    }
    enq(node); // tail is null (head and tail get initialized) or the CAS failed: enqueue with retries
    return node;
}
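
To show the enqueue idea (lazy initialization of a dummy head, then a CAS on tail with retry) in a form that can be run on its own, here is a simplified sketch. It assumes AtomicReference fields instead of the field-level CAS that AQS uses, and it leaves out wait status, cancellation and dequeuing entirely; SimpleWaitQueue and all of its members are invented for this illustration.

```java
import java.util.concurrent.atomic.AtomicReference;

final class SimpleWaitQueue {

    static final class Node {
        final String name;
        volatile Node prev;
        volatile Node next;

        Node(String name) {
            this.name = name;
        }
    }

    private final AtomicReference<Node> head = new AtomicReference<>();
    private final AtomicReference<Node> tail = new AtomicReference<>();

    // Appends a node to the tail, creating the dummy head on first use.
    Node enqueue(String name) {
        Node node = new Node(name);
        for (;;) {
            Node t = tail.get();
            if (t == null) {
                // Lazy initialization: exactly one thread installs the dummy head.
                Node dummy = new Node("dummy");
                if (head.compareAndSet(null, dummy))
                    tail.set(dummy);
            } else {
                node.prev = t;                     // link backwards before publishing
                if (tail.compareAndSet(t, node)) { // publish as the new tail
                    t.next = node;                 // forward link is set afterwards
                    return node;
                }
                // CAS failed: another thread appended first, so retry with the new tail
            }
        }
    }

    // Counts the real nodes by walking backwards from tail; prev links are always complete.
    int size() {
        int n = 0;
        Node h = head.get();
        for (Node p = tail.get(); p != null && p != h; p = p.prev)
            n++;
        return n;
    }

    // Enqueues from several threads at once and returns the resulting queue size.
    static int concurrentSize(int threads, int perThread) {
        SimpleWaitQueue queue = new SimpleWaitQueue();
        Thread[] workers = new Thread[threads];
        for (int i = 0; i < threads; i++) {
            final int id = i;
            workers[i] = new Thread(() -> {
                for (int j = 0; j < perThread; j++)
                    queue.enqueue("t" + id + "-" + j);
            });
            workers[i].start();
        }
        for (Thread t : workers) {
            try {
                t.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
        return queue.size();
    }

    public static void main(String[] args) {
        System.out.println("queue size: " + concurrentSize(4, 250));
    }
}
```

Because prev is assigned before the CAS on tail and next only after it, a backward walk from tail never misses a node, while a forward walk from head might briefly see a missing next link.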

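Further Reading

  • CLH locks, MCS locks, spin locks, ticket (queued) spin locks
  • How CAS is implemented and the CPU instructions behind it
  • Extending AbstractQueuedSynchronizer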