多线程限流工具类-Semaphore
Semaphore介绍
Semaphore(信号量)是JAVA多线程中的一个工具类,它可以通过指定参数来控制执行线程数量,一般用于限流访问某个资源时使用。
Semaphore使用示例
需求场景:用一个核心线程数为6,最大线程数为20的线程池执行任务,但是要求最多只能同时运行3个线程
代码:
public class demo { //创建线程池,核心线程数:6;最大线程数:20;时间:5;时间单位:秒;阻塞队列:ArrayBlockingQueue,最大容量为10;线程工厂:默认;拒绝策略:默认 static ThreadPoolExecutor poolExecutor = new ThreadPoolExecutor(6, 20, 5, TimeUnit.SECONDS, new ArrayBlockingQueue<>(10)); public static void main(String[] args) throws InterruptedException { Semaphore semaphore = new Semaphore(3);//指定线程数量 for (int i = 0; i < 10; i++) { poolExecutor.execute(new Runnable() { @Override public void run() { try { semaphore.acquire(); System.out.println(Thread.currentThread().getName() + " start..."); Thread.sleep(2000); //用来表明当前线程结束 System.out.println(Thread.currentThread().getName() + " end..."); } catch (InterruptedException e) { e.printStackTrace(); } finally { semaphore.release(); } } }); } poolExecutor.shutdown(); } }
(结果分析)从输出结果可以看出:最多只能同时开启3个线程(创建Semaphore时指定线程数量),只有等最先开启的三个线程中的某个结束了才会开启新的线程,但同时运行的总量始终保持在3个以内!
pool-1-thread-1 start... pool-1-thread-2 start... pool-1-thread-3 start... pool-1-thread-2 end... pool-1-thread-3 end... pool-1-thread-4 start... pool-1-thread-1 end... pool-1-thread-5 start... pool-1-thread-6 start... pool-1-thread-5 end... pool-1-thread-6 end... pool-1-thread-4 end... pool-1-thread-6 start... pool-1-thread-3 start... pool-1-thread-2 start... pool-1-thread-6 end... pool-1-thread-2 end... pool-1-thread-3 end... pool-1-thread-1 start... pool-1-thread-1 end... Process finished with exit code 0
Semaphore实现原理
源码:
public class Semaphore implements java.io.Serializable { private static final long serialVersionUID = -3222578661600680210L; //继承AQS的内部类 private final Sync sync; abstract static class Sync extends AbstractQueuedSynchronizer { private static final long serialVersionUID = 1192457210091910933L; //构造函数,传入的参数为信号量permits Sync(int permits) { setState(permits); } //获取信号量 final int getPermits() { return getState(); } //以非公平的方式尝试获取信号量 final int nonfairTryAcquireShared(int acquires) { //自旋 for (; ; ) { //当前信号量 int available = getState(); //获取acquires个信号量后的剩余信号量 int remaining = available - acquires; //如果剩余信号量小于0(获取失败),或者成功把剩余信号量更新为当前信号量(获取成功)都会退出自旋并返回剩余信号量 if (remaining < 0 || compareAndSetState(available, remaining)) return remaining; } } //尝试释放信号量 protected final boolean tryReleaseShared(int releases) { for (; ; ) { //当前信号量 int current = getState(); //下个信号量,即当前信号量+释放的信号量(线程运行结束将信号量还给Semaphore,所以相加) int next = current + releases; //如果下个信号量小于当前信号量则有越界的情况,报错! if (next < current) // overflow throw new Error("Maximum permit count exceeded"); //如果没问题就用CAS更新当前信号量,并结束自旋 if (compareAndSetState(current, next)) return true; } } //减少信号量 final void reducePermits(int reductions) { for (; ; ) { int current = getState(); //下个信号量为当前信号量-减少信号量 int next = current - reductions; //如果没有那么多可减少的信号则抛出异常 if (next > current) // underflow throw new Error("Permit count underflow"); //如果没问题就更新信号量并结束自旋 if (compareAndSetState(current, next)) return; } } //清空信号量 final int drainPermits() { for (; ; ) { int current = getState(); if (current == 0 || compareAndSetState(current, 0)) return current; } } } //非公平 static final class NonfairSync extends Sync { private static final long serialVersionUID = -2694183684443567898L; NonfairSync(int permits) { super(permits); } protected int tryAcquireShared(int acquires) { return nonfairTryAcquireShared(acquires); } } //公平 static final class FairSync extends Sync { private static final long serialVersionUID = 2014338818796000944L; FairSync(int permits) { super(permits); } //公平的方式尝试获取信号量 protected int tryAcquireShared(int acquires) { for (; ; ) { //如果队列当前节点已经有任务则结束自旋 if (hasQueuedPredecessors()) return -1; //当前信号量 int available = getState(); //剩余信号量=当先信号量-获取信号量 int remaining = available - acquires; //如果剩余信号量小于0(获取失败),或者成功把剩余信号量更新为当前信号量(获取成功)都会退出自旋并返回剩余信号量 if (remaining < 0 || compareAndSetState(available, remaining)) return remaining; } } } //构造方法,默认以非公平的方式设置信号量 public Semaphore(int permits) { sync = new NonfairSync(permits); } //构造方法,自定义以公平还是非公平的方式设置信号量 public Semaphore(int permits, boolean fair) { sync = fair ? new FairSync(permits) : new NonfairSync(permits); } //获取信号量,如果当前线程已中止(interrupted)就抛出异常 public void acquire() throws InterruptedException { sync.acquireSharedInterruptibly(1); } //获取信号量,无论当前线程是否中止都尝试获取 public void acquireUninterruptibly() { sync.acquireShared(1); } //尝试获取信号量 public boolean tryAcquire() { return sync.nonfairTryAcquireShared(1) >= 0; } //尝试获取信号量并设置超时时间 public boolean tryAcquire(long timeout, TimeUnit unit) throws InterruptedException { return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout)); } //释放信号量 public void release() { sync.releaseShared(1); } //获取指定数量的信号量 public void acquire(int permits) throws InterruptedException { if (permits < 0) throw new IllegalArgumentException(); sync.acquireSharedInterruptibly(permits); } //获取指定数量的信号量 public void acquireUninterruptibly(int permits) { if (permits < 0) throw new IllegalArgumentException(); sync.acquireShared(permits); } //尝试获取指定数量的信号量 public boolean tryAcquire(int permits) { if (permits < 0) throw new IllegalArgumentException(); return sync.nonfairTryAcquireShared(permits) >= 0; } //尝试获取指定数量的信号量,并设置超时时间 public boolean tryAcquire(int permits, long timeout, TimeUnit unit) throws InterruptedException { if (permits < 0) throw new IllegalArgumentException(); return sync.tryAcquireSharedNanos(permits, unit.toNanos(timeout)); } //释放指定数量的信号量 public void release(int permits) { if (permits < 0) throw new IllegalArgumentException(); sync.releaseShared(permits); } //获取当前信号数量 public int availablePermits() { return sync.getPermits(); } //清空信号量 public int drainPermits() { return sync.drainPermits(); } //减少指定数量的信号量 protected void reducePermits(int reduction) { if (reduction < 0) throw new IllegalArgumentException(); sync.reducePermits(reduction); } //判断是否为公平 public boolean isFair() { return sync instanceof FairSync; } //判断是否有队列(有阻塞线程时才会产生队列,即判断是否有阻塞线程) public final boolean hasQueuedThreads() { return sync.hasQueuedThreads(); } //获取阻塞线程数量 public final int getQueueLength() { return sync.getQueueLength(); } //获取阻塞线程并封装成集合返回 protected Collection<Thread> getQueuedThreads() { return sync.getQueuedThreads(); } public String toString() { return super.toString() + "[Permits = " + sync.getPermits() + "]"; } }
核心方法:
1、acquire()
//获取信号量,如果当前线程已中止(interrupted)就抛出异常 public void acquire() throws InterruptedException { sync.acquireSharedInterruptibly(1); } public final void acquireSharedInterruptibly(int arg) throws InterruptedException { //如果当前线程已经终止则抛出异常 if (Thread.interrupted()) throw new InterruptedException(); //尝试获取信号量(调用内部类Sync的方法) if (tryAcquireShared(arg) < 0) //获取信号量失败时会将当前线程封装成node加入到阻塞队列中 doAcquireSharedInterruptibly(arg); }
2、release()
//释放信号量 public void release() { sync.releaseShared(1); } public final boolean releaseShared(int arg) { //调用内部类Sync尝试释放信号量 if (tryReleaseShared(arg)) { //释放成功后唤醒阻塞队列的next节点 doReleaseShared(); return true; } return false; }