Java并发之Semaphore源码解析(一)
Semaphore
前情提要:在学习本章前,需要先了解笔者先前讲解过的ReentrantLock源码解析,ReentrantLock源码解析里介绍的方法有很多是本章的铺垫。下面,我们进入本章正题Semaphore。
从概念上来讲,信号量(Semaphore)会维护一组许可证用于限制线程对资源的访问,当我们有一资源允许线程并发访问,但我们希望能限制访问量,就可以用信号量对访问线程进行限制。当线程要访问资源时,要先调用信号量的acquire方法获取访问许可证,当线程访问完毕后,调用信号量的release归还许可证。使用信号量我们可以服务做限流,尤其像淘宝天猫这样平时访问量就很大的电商大户,在双十一的时候更要评估其服务能承受的访问量并对其做限流,避免因为访问量过大导致服务宕机。然而,Semaphore内部实际上并没有维护一组许可证对象,而是维护一个数字作为许可证数量,如果线程要获取许可证,则会根据线程请求的许可证数量扣减内部的维护的数量,如果足够扣除则线程获取许可证成功,否则线程必须陷入阻塞,直到信号量内部的许可证数量足够。
我们来看下面的代码,假设OrderService是一个远程服务,我们预估这个服务能承受的并发量是5000,访问一次远程服务需要获取一个许可证,执行methodA()的业务只需要请求一次远程服务,所以调用semaphore.acquire()默认获取一个许可证。执行methodB()的业务需要向远程服务并发发送两次请求,所以这里acquire(int permits)的参数我们传2,以保证不管是执行methodA()还是methodB(),远程服务的并发量不超过5000。
当我们的业务不再对远程服务的访问,需要归还许可证,methodA()原先只请求一个许可证,这里调用release()对信号量内部的许可证数量+1即可。methodB()请求两个许可证,所以这里要调用release(int permits)归还两个。假设我们的服务里同时有4999个线程已经在执行methodA()方法,有一个线程要执行methodB()方法,可以知道许可证数量是不够的,信号量维护的许可证数量为5000,但线程如果要同时执行需要5001个许可证,所以要执行methodB()的线程会陷入阻塞,直到信号量内部的许可证数量足够扣除,才会获取需要的许可证数量,然后访问远程服务。
public class OrderService { private Semaphore semaphore = new Semaphore(5000); public void methodA() { try { semaphore.acquire(); //methodA body } catch (InterruptedException e) { e.printStackTrace(); } finally { semaphore.release(); } } public void methodB() { try { semaphore.acquire(2); //methodB body } catch (InterruptedException e) { e.printStackTrace(); } finally { semaphore.release(2); } } }
如果是许可证为1的信号量可以把它当做互斥锁,这时信号量只有两个状态:0或者1,我们把1代表锁未被占用,0代表锁被占用。如果是用这种方式将信号量当做互斥锁我们可以用一个线程来获取锁,而另一个线程来释放锁,比如下面的<1>处和<2>处分别在不同的线程加锁和释放锁。某种程度上来说这一做法可以避免死锁,与传统java.util.concurrent.locks.Lock的实现会有很大的不同,传统的Lock实现,比如:ReentrantLock会要求解锁的线程必须要是原先加锁的线程,否则会抛出异常。
public static void main(String[] args) { Semaphore semaphore = new Semaphore(1); new Thread(() -> { try { semaphore.acquire();//<1> System.out.println(Thread.currentThread().getName() + "获取独占锁"); } catch (InterruptedException e) { e.printStackTrace(); } }, "线程1").start(); try { Thread.sleep(100); } catch (InterruptedException e) { e.printStackTrace(); } new Thread(() -> { semaphore.release();//<2> System.out.println(Thread.currentThread().getName() + "释放独占锁"); }, "线程2").start(); try { Thread.sleep(100); } catch (InterruptedException e) { e.printStackTrace(); } }
当信号量的许可证数量为0时,如果还有线程请求获取许可证,信号量会将线程放入一个队列,然后挂起线程,直到有许可证被归还,信号量会尝试唤醒队列中等待许可证最长时间的线程。所以信号量就分为公平(FairSync)和非公平(NonfairSync)两种模式。在公平模式下,如果有线程要获取信号量的许可证时,会先判断信号量维护的等待队列中是否已经有线程,如果有的话则乖乖入队,没有才尝试请求许可证;而非公平模式则是直接请求许可证,不管队列中是否已有线程在等待信号量的许可证。
而下面的代码也印证了笔者之前所说的,信号量本身并不会去维护一个许可证对象的集合,当我们把许可证数量传给信号量的构造函数时,最终会由静态内部类Sync调用其父类AQS的setState(permits)方法将许可证赋值给AQS内部的字段state,由这个字段决定信号量有多少个许可证,请求许可证的线程能否成功。
public class Semaphore implements java.io.Serializable { private final Sync sync; abstract static class Sync extends AbstractQueuedSynchronizer { //... Sync(int permits) { setState(permits); } //... } static final class NonfairSync extends Sync {//非公平 NonfairSync(int permits) { super(permits); } //... } static final class FairSync extends Sync {//公平 //... FairSync(int permits) { super(permits); } //... } public Semaphore(int permits) { sync = new NonfairSync(permits); } public Semaphore(int permits, boolean fair) { sync = fair ? new FairSync(permits) : new NonfairSync(permits); } //... }
从上面节选的代码来看,官方更推荐使用非公平的信号量,因为根据许可证数量创建信号量默认使用的非公平信号量,而相比于公平信号量,非公平信号量有更高的吞吐量。因此笔者先介绍非公平信号量,再介绍公平信号量。
我们先来看看acquire()和acquire(int permits) 这两个方法,可以看到不管我们是请求一个许可证,还是请求多个许可证,本质上都是调用Sync.
acquireSharedInterruptibly(int arg)方法。如果大家观察静态内部类Sync的代码可以发现:Sync并没有实现acquireSharedInterruptibly(int arg)方法,而是其父类AQS实现了此方法。
public class Semaphore implements java.io.Serializable { //... private final Sync sync; abstract static class Sync extends AbstractQueuedSynchronizer { //... } public void acquire() throws InterruptedException { sync.acquireSharedInterruptibly(1); } //... public void acquire(int permits) throws InterruptedException { if (permits < 0) throw new IllegalArgumentException(); sync.acquireSharedInterruptibly(permits); } //... }
于是我们追溯到AQS实现的acquireSharedInterruptibly(int arg)方法,这个方法的实现其实并不难,先判断当前线程是否有中断标记,有的话则直接抛出中断异常InterruptedException,之后调用tryAcquireShared(int arg)尝试获取许可证,AQS本身并没有实现tryAcquireShared(int arg)方法,而是交由子类去实现的,才有了子类来决定是直接尝试获取许可证,还是先判断信号量的等待队列是否有线程正在等待许可证,有的话则排队,没有则尝试请求。
public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable { //... public final void acquireSharedInterruptibly(int arg) throws InterruptedException { if (Thread.interrupted()) throw new InterruptedException(); if (tryAcquireShared(arg) < 0) doAcquireSharedInterruptibly(arg); } //... protected int tryAcquireShared(int arg) { throw new UnsupportedOperationException(); } //... }
所以我们来看看非公平锁实现的tryAcquireShared(int arg)方法,在非公平锁的ryAcquireShared(int arg)方法中会调用到Sync类实现的nonfairTryAcquireShared(int acquires)方法,这个方法会先获取当前信号量剩余的许可证数量available,然后减去请求的数量(available - acquires)得到剩余许可证数量remaining,如果remaining大于0代表信号量现有的许可证数量是允许分配调用线程请求的许可证数量,是允许分配的,所以<1>处的条件为false,会进行<2>处的CAS扣减,如果能扣减成功,则返回剩余许可证数量,返回的remaining如果大于等于0,则代表扣减成功,如果小于0代表请求失败,表示信号量现有的许可证数量不足调用线程所需。
public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable { //... abstract static class Sync extends AbstractQueuedSynchronizer { final int nonfairTryAcquireShared(int acquires) { for (;;) { int available = getState(); int remaining = available - acquires; if (remaining < 0 ||//<1> compareAndSetState(available, remaining))//<2> return remaining; } } } //... static final class NonfairSync extends Sync { private static final long serialVersionUID = -2694183684443567898L; NonfairSync(int permits) { super(permits); } protected int tryAcquireShared(int acquires) { return nonfairTryAcquireShared(acquires); } } //... }
如果在<1>处执行tryAcquireShared(arg)尝试获取许可证失败,则会调用<2>处的方法将当前线程挂起。
public final void acquireSharedInterruptibly(int arg) throws InterruptedException { if (Thread.interrupted()) throw new InterruptedException(); if (tryAcquireShared(arg) < 0)//<1> doAcquireSharedInterruptibly(arg);//<2> }
那么我们来看看如果调用tryAcquireShared(arg)请求许可证失败后,doAcquireSharedInterruptibly(int arg)里面完成的逻辑。如果有看过笔者前一章ReentrantLock源码解析的朋友在看到这个方法应该会觉非常熟悉,这里会先调用<1>处的addWaiter(Node mode)方法将当前请求许可证的线程封装成一个Node节点并入队,这里我们也首次看到使用Node.SHARED的地方,如果一个节点Node的nextWaiter指向的是静态常量Node.SHARED,则代表这个节点是一个共享节点,换句话说这个节点的线程可以和其他同为共享节点的线程共享资源。
当线程作为节点入队后,判断节点的前驱节点是否是头节点,如果是头节点则话则进入<2>处的分支,这里会再次调用tryAcquireShared(arg)请求许可证,之前说过如果tryAcquireShared(arg)返回的结果大于等于0代表请求许可证成功,否则请求失败。如果请求失败的话,之后的流程大家想必都清楚了,会先执行shouldParkAfterFailedAcquire(p, node)判断前驱节点p的等待状态是否为SIGNAL(-1),如果为SIGNAL则直接返回true,调用parkAndCheckInterrupt()阻塞当前线程,如果前驱节点p的等待状态为0,会先用CAS的方式修改为SIGNAL,然后再下一次循环中阻塞当前线程。
public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable { //... private void doAcquireSharedInterruptibly(int arg) throws InterruptedException { final Node node = addWaiter(Node.SHARED);//<1> try { for (;;) { final Node p = node.predecessor(); if (p == head) {//<2> int r = tryAcquireShared(arg); if (r >= 0) { setHeadAndPropagate(node, r); p.next = null; // help GC return; } } if (shouldParkAfterFailedAcquire(p, node) && parkAndCheckInterrupt()) throw new InterruptedException(); } } catch (Throwable t) { cancelAcquire(node); throw t; } } //... private Node addWaiter(Node mode) { Node node = new Node(mode); for (;;) { Node oldTail = tail; if (oldTail != null) { node.setPrevRelaxed(oldTail); if (compareAndSetTail(oldTail, node)) { oldTail.next = node; return node; } } else { initializeSyncQueue(); } } } //... static final class Node { static final Node SHARED = new Node(); static final Node EXCLUSIVE = null; static final int CANCELLED = 1; //... static final int PROPAGATE = -3; volatile int waitStatus; volatile Node prev; volatile Node next; volatile Thread thread; Node nextWaiter; final boolean isShared() { return nextWaiter == SHARED; } //... Node(Node nextWaiter) { this.nextWaiter = nextWaiter; THREAD.set(this, Thread.currentThread()); } } //... }
上面的流程是当前线程没有请求许可证成功而陷入阻塞的情况,那么如果是线程进入等待队列后又获取到许可证呢?即:执行完下面<1>处的代码确定线程对应的节点入队,在<2>处判断节点的前驱节点是头节点,进入<2>处的分支后执行<3>处的tryAcquireShared(arg)方法成功获取到许可证此时返回的r>=0,进入<4>处的分支,那么在setHeadAndPropagate(Node node, int propagate)方法中又会做什么呢?
首先会保留一个原始头节点head的引用,其次替换头节点为当前节点。如果原先返回r(propagate)大于0,代表当前线程在请求完许可证后,信号量还有剩余许可证,于是<5>处的分值一定成立,因为propagate大于0,这里会判断当前节点的下一个节点next是否是共享模式,是的话则调用doReleaseShared()方法唤醒当前节点的下一个节点。但如果传入的propagate等于0,还有另外几个条件可以尝试通知当前节点的后继节点,只要条件(h == null || h.waitStatus < 0 || (h = head) == null || h.waitStatus < 0)成立,且当前节点的下个节点仍为共享节点,则可以唤醒后继节点申请许可证。那么怎么来理解条件(h == null || h.waitStatus < 0 || (h = head) == null || h.waitStatus < 0)呢?
首先我们可以先忽略这4个条件里面其中的两个条件,h == null和(h = head) == null都不可能成立,h是原始头节点,只要有节点入队,头节点不可能为null,其次判断head也不可能为null,因为头节点已经是当前节点,就笔者看来这两个判断是防止空指针异常的标准写法,只是预防空指针不代表会发生空指针异常。所以我们只要关注h.waitStatus < 0或者head.waitStatus < 0两个条件,其中一个成立,就可以进入<5>处的分支。那么,又如何来理解h.waitStatus < 0和head.waitStatus < 0两个条件呢?
我们先来回忆下shouldParkAfterFailedAcquire(Node pred, Node node)这个方法,这个方法接收一个前驱节点和当前节点,把前驱节点的等待状态改为SIGNAL(-1),代表前驱节点pred的下一个节点node等待唤醒。于是我们能够明白如果head.waitStatus < 0代表当前节点的下一个节点等待唤醒,如果下一个节点的模式是共享节点,就会尝试调用doReleaseShared()方法唤醒下一个节点尝试申请许可证,即便目前传入的信号量剩余许可证数量propagate为0,因为可能存从唤醒到申请许可证的期间,已经有别的线程归还了许可证,这样做可以提高整体的吞吐量,即便下一个线程被唤醒后没有可申请的许可证数量,要做的也无非是重新阻塞线程。需要注意的是:如果队列中有n个节点,唤醒后继节点这个操作不一定会从头节点一直传播到尾节点,即便前n-1个节点的等待状态(waitStatus)都为SIGNAL,最后一个因为是尾节点,它没有下一个等待唤醒的节点,所以等待状态为0。要知道当前节点能唤醒下一个节点的前提条件,首先是前驱节点为头节点,其次当前节点的线程申请到许可证,才有资格尝试唤醒下一个节点,如果节点被唤醒后,虽然前驱节点是头节点,却没有多余的许可证可以申请,无法将头节点替换成当前节点,就会重新陷入阻塞,也就不会尝试唤醒下一个节点。
public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable { //... private void doAcquireSharedInterruptibly(int arg) throws InterruptedException { final Node node = addWaiter(Node.SHARED);//<1> try { for (;;) { final Node p = node.predecessor(); if (p == head) {//<2> int r = tryAcquireShared(arg);//<3> if (r >= 0) {//<4> setHeadAndPropagate(node, r); p.next = null; // help GC return; } } if (shouldParkAfterFailedAcquire(p, node) && parkAndCheckInterrupt()) throw new InterruptedException(); } } catch (Throwable t) { cancelAcquire(node); throw t; } } //... private void setHeadAndPropagate(Node node, int propagate) { Node h = head; // Record old head for check below setHead(node); if (propagate > 0 || h == null || h.waitStatus < 0 || (h = head) == null || h.waitStatus < 0) {//<5> Node s = node.next; if (s == null || s.isShared()) doReleaseShared(); } } //... private void setHead(Node node) { head = node; node.thread = null; node.prev = null; } //... }
下面该来讲解当h.waitStatus < 0,事实上h.waitStatus < 0这一判断并非必要,但可以提高吞吐量。
这里笔者预先介绍一个知识点,当线程调用release()或release(int permits)方法向信号量(Semaphore)归还许可证时后,会再调用doReleaseShared()方法唤醒信号量等待队列中被阻塞的线程起来申请许可证,这里如果判断头节点的等待状态为SIGNAL,则表明头节点的后继节点陷入阻塞,如果能用CAS的方式修改头节点的等待状态成功,则调用unparkSuccessor(h)唤醒被阻塞的后继节点起来申请许可证。被阻塞的线程唤醒后如果能申请到许可证,会先把头节点替换成当前节点,并根据条件判断是否要调用doReleaseShared()唤醒下一个后继节点,如果申请许可证失败则执行两次shouldParkAfterFailedAcquire(Node pred, Node node)后重新挂起当前线程。
那么h.waitStatus < 0这一判断是如何来提高吞吐量呢?举个例子:有一信号量许可证为2,并已经分配给线程1和线程2,此时信号量的许可证数量为0。线程3和线程4想要请求许可证只能先入队等待,线程3和线程4的对应节点是N3和N4,队列中的节点排序为:head->N3->N4。假设N3在入队之后,线程1就归还了许可证,此时N3判断它的前驱节点是头节点,继而申请到许可证,因此N3不会调用shouldParkAfterFailedAcquire(Node pred, Node node)改变原先头节点的等待状态。线程1在归还许可证后,调用doReleaseShared(),假定N3入队的时候队列为空,head是调用initializeSyncQueue()方法初始化完成的。所以head的等待状态为0,在<3>处会用CAS的方式修改head的等待状态为PROPAGATE(-3)。于是线程3在执行setHeadAndPropagate(Node node, int propagate)的时候,将头节点指向N3,假定此时线程4虽然入队,但尚未修改前驱节点N3的等待状态为SIGNAL,所以((h = head) == null || h.waitStatus < 0)为false,但原先头节点的等待状态小于0,这里还是会进入<1>处的分支,判断N4是共享节点,调用doReleaseShared()唤醒线程4。
如果线程4正在执行,且有别的线程调用LockSupport.unpark(Thread thread)唤醒线程4,线程4在第一次执行LockSupport.park(Object blocker)并不会陷入阻塞,会退出parkAndCheckInterrupt()方法后又重新申请许可证,如果申请失败,再次调用parkAndCheckInterrupt()执行LockSupport.park(Object blocker)才会被阻塞,相当于线程4多了一次申请许可证的机会。也许在线程4第一次执行LockSupport.park(Object blocker)却没陷入阻塞的时候,线程2就归还了许可证,在新一轮的循环时线程4就直接申请到许可证。
如果线程4被阻塞,此时线程2归还了许可证却还没来得及调用doReleaseShared(),线程3先进入<1>处的分支调用了doReleaseShared(),线程4会被唤醒起来申请许可证,相当于有两个线程争相唤醒线程4。由此可见,如果头节点的等待状态为0,修改其等待状态为PROPAGATE,并在<1>处加上判断原先头节点的等待状态,可以提高吞吐量。
当然,执行h.compareAndSetWaitStatus(0, Node.PROPAGATE)存在失败的情况,比如原先判定头节点的等待状态为0,在执行<3>处代码之前,头节点的后继节点修改前驱节点的等待状态为SIGNAL,此时CAS修改头节点的等待状态为PROPAGATE失败,会重新执行一次循环,此时会进入<2>处唤醒后继节点,于是后继节点就又多了一次申请许可证的机会。
public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable { //... static final class Node { //... static final int SIGNAL = -1; static final int PROPAGATE = -3; //... } //... private void setHeadAndPropagate(Node node, int propagate) { Node h = head; // Record old head for check below setHead(node); if (propagate > 0 || h == null || h.waitStatus < 0 ||//<1> (h = head) == null || h.waitStatus < 0) { Node s = node.next; if (s == null || s.isShared()) doReleaseShared(); } } //... private void doReleaseShared() { for (;;) { Node h = head; if (h != null && h != tail) { int ws = h.waitStatus; if (ws == Node.SIGNAL) {//<2> if (!h.compareAndSetWaitStatus(Node.SIGNAL, 0)) continue; // loop to recheck cases unparkSuccessor(h); } else if (ws == 0 && !h.compareAndSetWaitStatus(0, Node.PROPAGATE))//<3> continue; // loop on failed CAS } if (h == head) // loop if head changed break; } } //... }