多线程限流工具类-Semaphore

Semaphore介绍

Semaphore(信号量)是JAVA多线程中的一个工具类,它可以通过指定参数来控制执行线程数量,一般用于限流访问某个资源时使用。

Semaphore使用示例

需求场景:用一个核心线程数为6最大线程数为20的线程池执行任务,但是要求最多只能同时运行3个线程

代码:

public class demo {

    //创建线程池,核心线程数:6;最大线程数:20;时间:5;时间单位:秒;阻塞队列:ArrayBlockingQueue,最大容量为10;线程工厂:默认;拒绝策略:默认
    static ThreadPoolExecutor poolExecutor = new ThreadPoolExecutor(6, 20, 5, TimeUnit.SECONDS, new ArrayBlockingQueue<>(10));

    public static void main(String[] args) throws InterruptedException {
        Semaphore semaphore = new Semaphore(3);//指定线程数量
        for (int i = 0; i < 10; i++) {
            poolExecutor.execute(new Runnable() {
                @Override
                public void run() {
                    try {
                        semaphore.acquire();
                        System.out.println(Thread.currentThread().getName() + " start...");
                        Thread.sleep(2000);
                        //用来表明当前线程结束
                        System.out.println(Thread.currentThread().getName() + " end...");
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    } finally {
                        semaphore.release();
                    }
                }
            });
        }
        poolExecutor.shutdown();
    }
}

(结果分析)从输出结果可以看出:最多只能同时开启3个线程(创建Semaphore时指定线程数量),只有等最先开启的三个线程中的某个结束了才会开启新的线程,但同时运行的总量始终保持在3个以内!

pool-1-thread-1 start...
pool-1-thread-2 start...
pool-1-thread-3 start...
pool-1-thread-2 end...
pool-1-thread-3 end...
pool-1-thread-4 start...
pool-1-thread-1 end...
pool-1-thread-5 start...
pool-1-thread-6 start...
pool-1-thread-5 end...
pool-1-thread-6 end...
pool-1-thread-4 end...
pool-1-thread-6 start...
pool-1-thread-3 start...
pool-1-thread-2 start...
pool-1-thread-6 end...
pool-1-thread-2 end...
pool-1-thread-3 end...
pool-1-thread-1 start...
pool-1-thread-1 end...

Process finished with exit code 0

Semaphore实现原理

源码:

public class Semaphore implements java.io.Serializable {
    private static final long serialVersionUID = -3222578661600680210L;

    //继承AQS的内部类
    private final Sync sync;

    abstract static class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 1192457210091910933L;

        //构造函数,传入的参数为信号量permits
        Sync(int permits) {
            setState(permits);
        }

        //获取信号量
        final int getPermits() {
            return getState();
        }

        //以非公平的方式尝试获取信号量
        final int nonfairTryAcquireShared(int acquires) {
            //自旋
            for (; ; ) {
                //当前信号量
                int available = getState();
                //获取acquires个信号量后的剩余信号量
                int remaining = available - acquires;
                //如果剩余信号量小于0(获取失败),或者成功把剩余信号量更新为当前信号量(获取成功)都会退出自旋并返回剩余信号量
                if (remaining < 0 || compareAndSetState(available, remaining))
                    return remaining;
            }
        }

        //尝试释放信号量
        protected final boolean tryReleaseShared(int releases) {
            for (; ; ) {
                //当前信号量
                int current = getState();
                //下个信号量,即当前信号量+释放的信号量(线程运行结束将信号量还给Semaphore,所以相加)
                int next = current + releases;
                //如果下个信号量小于当前信号量则有越界的情况,报错!
                if (next < current) // overflow
                    throw new Error("Maximum permit count exceeded");
                //如果没问题就用CAS更新当前信号量,并结束自旋
                if (compareAndSetState(current, next))
                    return true;
            }
        }

        //减少信号量
        final void reducePermits(int reductions) {
            for (; ; ) {
                int current = getState();
                //下个信号量为当前信号量-减少信号量
                int next = current - reductions;
                //如果没有那么多可减少的信号则抛出异常
                if (next > current) // underflow
                    throw new Error("Permit count underflow");
                //如果没问题就更新信号量并结束自旋
                if (compareAndSetState(current, next))
                    return;
            }
        }

        //清空信号量
        final int drainPermits() {
            for (; ; ) {
                int current = getState();
                if (current == 0 || compareAndSetState(current, 0))
                    return current;
            }
        }
    }

    //非公平
    static final class NonfairSync extends Sync {
        private static final long serialVersionUID = -2694183684443567898L;

        NonfairSync(int permits) {
            super(permits);
        }

        protected int tryAcquireShared(int acquires) {
            return nonfairTryAcquireShared(acquires);
        }
    }

    //公平
    static final class FairSync extends Sync {
        private static final long serialVersionUID = 2014338818796000944L;

        FairSync(int permits) {
            super(permits);
        }

        //公平的方式尝试获取信号量
        protected int tryAcquireShared(int acquires) {
            for (; ; ) {
                //如果队列当前节点已经有任务则结束自旋
                if (hasQueuedPredecessors())
                    return -1;
                //当前信号量
                int available = getState();
                //剩余信号量=当先信号量-获取信号量
                int remaining = available - acquires;
                //如果剩余信号量小于0(获取失败),或者成功把剩余信号量更新为当前信号量(获取成功)都会退出自旋并返回剩余信号量
                if (remaining < 0 || compareAndSetState(available, remaining))
                    return remaining;
            }
        }
    }

    //构造方法,默认以非公平的方式设置信号量
    public Semaphore(int permits) {
        sync = new NonfairSync(permits);
    }

    //构造方法,自定义以公平还是非公平的方式设置信号量
    public Semaphore(int permits, boolean fair) {
        sync = fair ? new FairSync(permits) : new NonfairSync(permits);
    }

    //获取信号量,如果当前线程已中止(interrupted)就抛出异常
    public void acquire() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    //获取信号量,无论当前线程是否中止都尝试获取
    public void acquireUninterruptibly() {
        sync.acquireShared(1);
    }

    //尝试获取信号量
    public boolean tryAcquire() {
        return sync.nonfairTryAcquireShared(1) >= 0;
    }

    //尝试获取信号量并设置超时时间
    public boolean tryAcquire(long timeout, TimeUnit unit) throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }

    //释放信号量
    public void release() {
        sync.releaseShared(1);
    }

    //获取指定数量的信号量
    public void acquire(int permits) throws InterruptedException {
        if (permits < 0) throw new IllegalArgumentException();
        sync.acquireSharedInterruptibly(permits);
    }

    //获取指定数量的信号量
    public void acquireUninterruptibly(int permits) {
        if (permits < 0) throw new IllegalArgumentException();
        sync.acquireShared(permits);
    }

    //尝试获取指定数量的信号量
    public boolean tryAcquire(int permits) {
        if (permits < 0) throw new IllegalArgumentException();
        return sync.nonfairTryAcquireShared(permits) >= 0;
    }

    //尝试获取指定数量的信号量,并设置超时时间
    public boolean tryAcquire(int permits, long timeout, TimeUnit unit)
            throws InterruptedException {
        if (permits < 0) throw new IllegalArgumentException();
        return sync.tryAcquireSharedNanos(permits, unit.toNanos(timeout));
    }

    //释放指定数量的信号量
    public void release(int permits) {
        if (permits < 0) throw new IllegalArgumentException();
        sync.releaseShared(permits);
    }

    //获取当前信号数量
    public int availablePermits() {
        return sync.getPermits();
    }

    //清空信号量
    public int drainPermits() {
        return sync.drainPermits();
    }

    //减少指定数量的信号量
    protected void reducePermits(int reduction) {
        if (reduction < 0) throw new IllegalArgumentException();
        sync.reducePermits(reduction);
    }

    //判断是否为公平
    public boolean isFair() {
        return sync instanceof FairSync;
    }

    //判断是否有队列(有阻塞线程时才会产生队列,即判断是否有阻塞线程)
    public final boolean hasQueuedThreads() {
        return sync.hasQueuedThreads();
    }

    //获取阻塞线程数量
    public final int getQueueLength() {
        return sync.getQueueLength();
    }

    //获取阻塞线程并封装成集合返回
    protected Collection<Thread> getQueuedThreads() {
        return sync.getQueuedThreads();
    }


    public String toString() {
        return super.toString() + "[Permits = " + sync.getPermits() + "]";
    }
}

核心方法:

1、acquire()

    //获取信号量,如果当前线程已中止(interrupted)就抛出异常
    public void acquire() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        //如果当前线程已经终止则抛出异常
        if (Thread.interrupted())
            throw new InterruptedException();
        //尝试获取信号量(调用内部类Sync的方法)
        if (tryAcquireShared(arg) < 0)
            //获取信号量失败时会将当前线程封装成node加入到阻塞队列中
            doAcquireSharedInterruptibly(arg);
    }

2、release()

    //释放信号量
    public void release() {
        sync.releaseShared(1);
    }
    public final boolean releaseShared(int arg) {
        //调用内部类Sync尝试释放信号量
        if (tryReleaseShared(arg)) {
            //释放成功后唤醒阻塞队列的next节点
            doReleaseShared();
            return true;
        }
        return false;
    }

 

posted @ 2024-03-03 16:09  请别耽误我写BUG  阅读(118)  评论(0编辑  收藏  举报