Controlling Concurrency for Asynchronous Calls in Java Projects

Scenario

The project is built on Spring Boot and uses the default embedded Tomcat web container. Tomcat assigns one thread to handle each HTTP request.

The dependency tree of pom.xml shows that spring-boot-starter-web transitively brings in tomcat-embed-core.

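The Tomcat thread pool can be configured in application.yml. max-threads is the property name before Spring Boot 2.3; Spring Boot 2.3 and later use server.tomcat.threads.max instead: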

server:
  tomcat:
    max-threads: 500

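Beyond the request threads, two kinds of multithreaded, asynchronous execution come up regularly in day-to-day development:

  1. The service starts one or more custom business thread pools of its own to process work concurrently and asynchronously.
  2. The middleware or third-party service being called executes asynchronously itself.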

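In the first case, the developer defines the thread pool. That means the core/maximum thread counts, the blocking queue, the rejection policy and other parameters are all configurable.
Setting them from an assessment of actual demand and resource usage keeps the service stable, robust and controllable.
For example, cap the thread count so the server is not put under so much pressure that performance suffers, and keep the queue bounded to avoid OOM.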
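Here is a minimal sketch of the kind of pool the first case describes, using only JDK classes. It has three properties:

- the thread count is capped at a maximum;
- the queue is a bounded ArrayBlockingQueue;
- the rejection policy is CallerRunsPolicy, so overload slows the submitter down instead of dropping work.

The sizes (2/4/10) and the "biz-pool-" thread-name prefix are arbitrary illustrative values, not recommendations.

```java
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class BoundedPoolDemo {

    // Builds a pool whose thread count and queue length are both capped.
    // The numbers passed in are illustrative; real values come from load testing.
    public static ThreadPoolExecutor newBoundedPool(int core, int max, int queueCapacity) {
        AtomicInteger seq = new AtomicInteger();
        return new ThreadPoolExecutor(
                core, max,
                60L, TimeUnit.SECONDS,                      // idle non-core threads exit after 60s
                new ArrayBlockingQueue<>(queueCapacity),    // bounded queue: backlog cannot grow without limit
                r -> new Thread(r, "biz-pool-" + seq.incrementAndGet()),
                new ThreadPoolExecutor.CallerRunsPolicy()); // when saturated, the submitting thread runs the task
    }

    // Submits `taskCount` short tasks and returns {tasks completed, largest thread count seen}.
    public static int[] runTasks(int taskCount) {
        ThreadPoolExecutor pool = newBoundedPool(2, 4, 10);
        AtomicInteger done = new AtomicInteger();
        for (int i = 0; i < taskCount; i++) {
            pool.execute(() -> {
                try {
                    TimeUnit.MILLISECONDS.sleep(5);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
                done.incrementAndGet();
            });
        }
        pool.shutdown();
        try {
            pool.awaitTermination(10, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return new int[]{done.get(), pool.getLargestPoolSize()};
    }

    public static void main(String[] args) {
        int[] result = runTasks(50);
        System.out.println("completed=" + result[0] + ", largestPoolSize=" + result[1]);
    }
}
```

All 50 tasks still complete, because the ones the pool cannot accept run on the caller's thread. The pool itself never grows beyond 4 threads.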

In the second case, the threads and the asynchrony live outside the service, so developers often overlook stability and controllability.

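Take Elasticsearch as an example. The RestHighLevelClient client provides several asynchronous methods, such as updateByQueryAsync and bulkAsync.
These methods do not block the current thread, so the API can return quickly and they are genuinely convenient to call. But that convenience hides an uncontrolled "risk".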

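Example:
A search service has a product index containing product, store, inventory, price and other information.
When a product's information changes, the product-related fields in the index must be updated.
A product may be listed in many stores, so multiple documents in the index have to be updated in bulk.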

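Suppose the product service publishes product-change messages through an MQ queue, and the search service calls updateByQueryAsync to bulk-update the index when it receives a message.
Because updateByQueryAsync executes asynchronously, the call returns almost immediately inside the message listener. Message consumers are usually multithreaded as well.
Each message is therefore handled very quickly, and a burst of messages triggers a large number of updateByQueryAsync calls. Elasticsearch processes these asynchronous calls,
and since it has its own internal thread pools and queues, the following problems may arise: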

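  1. The Elasticsearch servers receive a large volume of bulk updates. This can cause write conflicts on index documents and raises CPU/memory/IO/GC pressure on the servers.
  2. The heavier load on Elasticsearch can disrupt other, non-write workloads; for example, product search responses become slow or time out.
  3. Once Elasticsearch's internal threads and queues are full, tasks are rejected with an exception. Some updates then fail, and search results become incorrect.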
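To make problem 3 concrete, here is a small JDK-only simulation. A ThreadPoolExecutor with 2 workers, a queue of 5 and AbortPolicy plays the role of the callee's internal pool. It is a stand-in, not Elasticsearch's actual thread pool configuration.

- Firing 20 requests without waiting overflows the pool, and some requests are rejected.
- Keeping at most 2 requests in flight on the caller side, with each slot freed by a completion callback, avoids the rejections.

```java
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

public class OverloadDemo {

    // Stand-in for the callee's internal pool: 2 workers, a queue of 5, and
    // AbortPolicy, so work beyond that capacity is rejected with an exception.
    static ThreadPoolExecutor newCalleePool() {
        return new ThreadPoolExecutor(2, 2, 0L, TimeUnit.MILLISECONDS,
                new ArrayBlockingQueue<>(5), new ThreadPoolExecutor.AbortPolicy());
    }

    // Simulated request handling time on the callee side.
    static void handleRequest() {
        try {
            TimeUnit.MILLISECONDS.sleep(50);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    static void drain(ThreadPoolExecutor pool) {
        pool.shutdown();
        try {
            pool.awaitTermination(10, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    // Caller fires every request at once and never waits; returns how many were rejected.
    public static int fireAndForget(int requests) {
        ThreadPoolExecutor callee = newCalleePool();
        int rejected = 0;
        for (int i = 0; i < requests; i++) {
            try {
                callee.execute(OverloadDemo::handleRequest);
            } catch (RejectedExecutionException e) {
                rejected++;
            }
        }
        drain(callee);
        return rejected;
    }

    // Caller keeps at most `limit` requests in flight and waits for one to finish
    // before sending the next; returns how many were rejected.
    public static int callerLimited(int requests, int limit) {
        ThreadPoolExecutor callee = newCalleePool();
        Semaphore inFlight = new Semaphore(limit);
        int rejected = 0;
        for (int i = 0; i < requests; i++) {
            inFlight.acquireUninterruptibly();
            try {
                callee.execute(() -> {
                    try {
                        handleRequest();
                    } finally {
                        inFlight.release(); // completion callback frees a slot
                    }
                });
            } catch (RejectedExecutionException e) {
                inFlight.release();
                rejected++;
            }
        }
        drain(callee);
        return rejected;
    }

    public static void main(String[] args) {
        System.out.println("fire-and-forget rejected: " + fireAndForget(20));
        System.out.println("caller-limited rejected: " + callerLimited(20, 2));
    }
}
```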

Tuning Elasticsearch's write thread pool is one option; that optimizes things from the callee's side.

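Another approach: when the middleware or third-party interface being called is itself asynchronous, the caller takes active control.
For example, it limits the call rate and the number of concurrent calls, keeping the availability and stability of the whole system in mind, instead of simply firing off calls and forgetting about them.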

Approach

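So how can this be controlled?
Once the caller issues a call, the middleware or third party handles it, possibly with many threads in parallel. Each call occupies one thread there, and that thread becomes free again once the call finishes.
The caller can therefore limit how many calls it has outstanding at the same time. With a limit of 50, for example, the callee never has more than 50 of this caller's calls executing asynchronously at once.
To do this, the caller needs the outcome of every asynchronous call it issues. Only then does it know that a call has finished and that a new one may be started.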

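Consider RestHighLevelClient#updateByQueryAsync(UpdateByQueryRequest updateByQueryRequest, RequestOptions options, ActionListener<BulkByScrollResponse> listener).
Its third parameter is an ActionListener. The completion-callback logic can be added in its onResponse method, and it should also be added in onFailure, so that a failed call is counted as finished too.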
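A call can fail as well as succeed. If the completion callback is not also invoked from onFailure, two things go wrong:

- the CountDownLatch never reaches zero, so method 1 blocks forever;
- the Semaphore permit is never returned, so method 2 permanently loses a slot.

Below is a minimal sketch of a wrapper that guarantees the callback runs exactly once. Listener is a hypothetical stand-in mirroring the two methods of ActionListener, declared here so the sketch compiles without the Elasticsearch client. With the real client, one would presumably implement ActionListener<BulkByScrollResponse> the same way; that has not been checked against a specific client version.

```java
import java.util.concurrent.atomic.AtomicBoolean;

public class ReleasingListenerDemo {

    // Hypothetical interface with the same shape as Elasticsearch's ActionListener<Response>
    // (onResponse / onFailure); declared here so the sketch compiles on its own.
    public interface Listener<R> {
        void onResponse(R response);

        void onFailure(Exception e);
    }

    // Wraps a business listener so that `done` (e.g. latch.countDown() or semaphore.release())
    // runs exactly once: on success, on failure, or if the business callback throws.
    public static <R> Listener<R> releasing(Listener<R> delegate, Runnable done) {
        AtomicBoolean fired = new AtomicBoolean(false);
        Runnable once = () -> {
            if (fired.compareAndSet(false, true)) {
                done.run();
            }
        };
        return new Listener<R>() {
            @Override
            public void onResponse(R response) {
                try {
                    delegate.onResponse(response);
                } finally {
                    once.run();
                }
            }

            @Override
            public void onFailure(Exception e) {
                try {
                    delegate.onFailure(e);
                } finally {
                    once.run();
                }
            }
        };
    }
}
```

The AtomicBoolean guard also protects against a listener that is, unexpectedly, notified twice. Without it, a Semaphore would gain an extra permit and the limit would silently grow.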

Method 1: control with CountDownLatch

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.builder.ToStringBuilder;
import org.apache.commons.lang3.builder.ToStringStyle;
import org.springframework.util.Assert;

import java.io.Serializable;
import java.util.Collection;
import java.util.Iterator;
import java.util.concurrent.CountDownLatch;

/**
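 * Asynchronous task executor that limits execution by batch
 * <p>
 * 1. The tasks themselves must execute asynchronously
 * 2. This executor limits how many of them run concurrently, so a flood of concurrent calls does not overload the system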
 *
 * @author cdfive
 */
@Slf4j
public class LimitBatchAsyncTaskExecutor {

    // number of tasks per batch
    private int batch;

    // trace id
    private String traceId;

    public LimitBatchAsyncTaskExecutor(int batch) {
        this(batch, CommonUtil.getTraceId()); // CommonUtil: project-specific trace-id helper, not shown in this post
    }

    public LimitBatchAsyncTaskExecutor(int batch, String traceId) {
        Assert.isTrue(batch > 0, "batch must be greater than 0");
        this.batch = batch;
        this.traceId = traceId;
    }

    public <T> void executeTasks(Collection<T> tasks, AsyncTaskExecutor<T> asyncTaskExecutor) {
        Assert.isTrue(tasks != null && tasks.size() > 0, "tasks can't be empty");
        this.executeTasks(tasks.iterator(), tasks.size(), asyncTaskExecutor);
    }

    public <T> void executeTasks(Iterator<T> tasks, int total, AsyncTaskExecutor<T> asyncTaskExecutor) {
        Assert.notNull(tasks, "tasks can't be null");
        Assert.isTrue(total > 0, "total must be greater than 0");
        Assert.notNull(asyncTaskExecutor, "asyncTaskExecutor can't be null");

        log.info(traceId + ",LimitBatchAsyncTaskExecutor executeTasks start,total={},batch={}", total, batch);
        long totalStart = System.currentTimeMillis();
        long batchStart = System.currentTimeMillis();

        CountDownLatch latch = null;
        Runnable callback = null;
        int index = 0;
        int batchIndex = 0;
        int batchTotal = (total / batch) + (total % batch == 0 ? 0 : 1);
        while (tasks.hasNext()) {
            T task = tasks.next();
            index++;
            if (latch == null) {
                batchStart = System.currentTimeMillis();
                batchIndex++;
                latch = new CountDownLatch((batchIndex < batchTotal) ? batch : (total - index + 1));
                CountDownLatch finalLatch = latch;
                callback = new Runnable() {
                    @Override
                    public void run() {
                        finalLatch.countDown();
                    }
                };
            }

            Context context = new Context(index, total, batchIndex, batchTotal);
            log.info(traceId + ",LimitBatchAsyncTaskExecutor executeTask start,batch={},context={}", batch, context);
            long start = System.currentTimeMillis();
            Runnable finalCallback = callback;
            asyncTaskExecutor.executeTask(task, () -> {
                log.info(traceId + ",LimitBatchAsyncTaskExecutor executeTask done,cost={}ms,batch={},context={}", (System.currentTimeMillis() - start), batch, context);
                finalCallback.run();
            }, context);

            if (index % batch == 0 || index == total) {
                try {
                    latch.await();
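                } catch (InterruptedException e) {
                    // Restore the interrupt flag instead of swallowing it, and stop submitting new tasks
                    Thread.currentThread().interrupt();
                    log.error(traceId + ",LimitBatchAsyncTaskExecutor await interrupted, remaining tasks skipped", e);
                    return;
                }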

                latch = null;
                callback = null;
                log.info(traceId + ",LimitBatchAsyncTaskExecutor batch done,cost={}ms,index=({}/{}),batchIndex=({}/{})"
                        , (System.currentTimeMillis() - batchStart), index, total, batchIndex, batchTotal);
            }
        }

        log.info(traceId + ",LimitBatchAsyncTaskExecutor executeTasks success,total cost={}ms,batch={}", (System.currentTimeMillis() - totalStart), batch);
    }

    public int getBatch() {
        return batch;
    }

    public String getTraceId() {
        return traceId;
    }

    /**
     * Asynchronous task executor
     */
    public static interface AsyncTaskExecutor<T> {

        /**
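         * Executes a task. The task itself must run asynchronously and invoke callback when it finishes (on success and on failure)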
         */
        void executeTask(T task, Runnable callback, Context context);
    }

    /**
     * Context
     */
    @NoArgsConstructor
    @AllArgsConstructor
    @Data
    public static class Context implements Serializable {

        private static final long serialVersionUID = 2916376548269524746L;

        // index of this task (1-based)
        private int index;

        // total number of tasks
        private int total;

        // index of this batch (1-based)
        private int batchIndex;

        // total number of batches
        private int batchTotal;

        @Override
        public String toString() {
            return ToStringBuilder.reflectionToString(this, ToStringStyle.JSON_STYLE);
        }
    }

    @Override
    public String toString() {
        return "LimitBatchAsyncTaskExecutor{" +
                "batch=" + batch +
                ", traceId='" + traceId + '\'' +
                '}';
    }
}

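The batch field of LimitBatchAsyncTaskExecutor is the number of tasks per batch. Using a CountDownLatch and a callback interface,
it caps how many asynchronous tasks run at the same time within each batch. Once every task in a batch has finished, the next batch starts.
Note that the last batch may contain fewer than batch tasks.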
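The class above depends on Lombok, Spring's Assert and a CommonUtil helper. Below is a stripped-down, JDK-only sketch of the same batch-and-wait idea; it is not the author's code.

- A cached thread pool stands in for the asynchronous callee.
- The sketch records the peak number of tasks in flight, so the limit can be checked.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class BatchLatchSketch {

    // Runs `total` simulated async tasks in batches of `batch`, blocking on a
    // CountDownLatch until the whole batch has called back.
    // Returns {tasks completed, highest number of tasks observed running at once}.
    public static int[] run(int total, int batch) {
        ExecutorService remote = Executors.newCachedThreadPool(); // stands in for the async callee
        AtomicInteger running = new AtomicInteger();
        AtomicInteger maxRunning = new AtomicInteger();
        AtomicInteger completed = new AtomicInteger();
        try {
            for (int start = 0; start < total; start += batch) {
                int size = Math.min(batch, total - start); // the last batch may be smaller
                CountDownLatch latch = new CountDownLatch(size);
                for (int i = 0; i < size; i++) {
                    remote.execute(() -> {
                        maxRunning.accumulateAndGet(running.incrementAndGet(), Math::max);
                        try {
                            TimeUnit.MILLISECONDS.sleep(5); // simulated remote work
                        } catch (InterruptedException e) {
                            Thread.currentThread().interrupt();
                        } finally {
                            running.decrementAndGet();
                            completed.incrementAndGet();
                            latch.countDown(); // the completion callback
                        }
                    });
                }
                latch.await(); // nothing from the next batch starts until this batch is done
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } finally {
            remote.shutdown();
        }
        return new int[]{completed.get(), maxRunning.get()};
    }

    public static void main(String[] args) {
        int[] r = run(23, 5);
        System.out.println("completed=" + r[0] + ", maxRunning=" + r[1]);
    }
}
```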

Method 2: control with Semaphore

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.builder.ToStringBuilder;
import org.apache.commons.lang3.builder.ToStringStyle;
import org.springframework.util.Assert;

import java.io.Serializable;
import java.util.Collection;
import java.util.Iterator;
import java.util.concurrent.Semaphore;

/**
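 * Asynchronous task executor that limits the degree of concurrency
 * <p>
 * 1. The tasks themselves must execute asynchronously
 * 2. This executor limits how many of them run concurrently, so a flood of concurrent calls does not overload the system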
 *
 * @author cdfive
 */
@Slf4j
public class LimitConcurrencyAsyncTaskExecutor {

    // maximum number of tasks in flight at the same time
    private int concurrency;

    // trace id
    private String traceId;

    public LimitConcurrencyAsyncTaskExecutor(int concurrency) {
        this(concurrency, CommonUtil.getTraceId()); // CommonUtil: project-specific trace-id helper, not shown in this post
    }

    public LimitConcurrencyAsyncTaskExecutor(int concurrency, String traceId) {
        Assert.isTrue(concurrency > 0, "concurrency must be greater than 0");
        this.concurrency = concurrency;
        this.traceId = traceId;
    }

    public <T> void executeTasks(Collection<T> tasks, AsyncTaskExecutor<T> asyncTaskExecutor) {
        Assert.isTrue(tasks != null && tasks.size() > 0, "tasks can't be empty");
        this.executeTasks(tasks.iterator(), tasks.size(), asyncTaskExecutor);
    }

    public <T> void executeTasks(Iterator<T> tasks, int total, AsyncTaskExecutor<T> asyncTaskExecutor) {
        Assert.notNull(tasks, "tasks can't be null");
        Assert.isTrue(total > 0, "total must be greater than 0");
        Assert.notNull(asyncTaskExecutor, "asyncTaskExecutor can't be null");

        log.info(traceId + ",LimitConcurrencyAsyncTaskExecutor executeTasks start,total={},concurrency={}", total, concurrency);
        long totalStart = System.currentTimeMillis();

        Semaphore semaphore = new Semaphore(concurrency);
        Runnable callback = new Runnable() {
            @Override
            public void run() {
                semaphore.release();
            }
        };
        int index = 0;
        while (tasks.hasNext()) {
            T task = tasks.next();
            index++;

            semaphore.acquireUninterruptibly();
            Context context = new Context(index, total);
            log.info(traceId + ",LimitConcurrencyAsyncTaskExecutor executeTask start,concurrency={},context={}", concurrency, context);
            long start = System.currentTimeMillis();

            asyncTaskExecutor.executeTask(task, () -> {
                log.info(traceId + ",LimitConcurrencyAsyncTaskExecutor executeTask done,cost={}ms,concurrency={},context={}", (System.currentTimeMillis() - start), concurrency, context);
                callback.run();
            }, context);
        }

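        // Wait for the last in-flight tasks: all permits come back only after every task has called back
        semaphore.acquireUninterruptibly(concurrency);
        semaphore.release(concurrency);

        log.info(traceId + ",LimitConcurrencyAsyncTaskExecutor executeTasks success,total cost={}ms,concurrency={}", (System.currentTimeMillis() - totalStart), concurrency);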
    }

    public int getConcurrency() {
        return concurrency;
    }

    public String getTraceId() {
        return traceId;
    }

    /**
     * Asynchronous task executor
     */
    public static interface AsyncTaskExecutor<T> {

        /**
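         * Executes a task. The task itself must run asynchronously and invoke callback when it finishes (on success and on failure)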
         */
        void executeTask(T task, Runnable callback, Context context);
    }

    /**
     * Context
     */
    @NoArgsConstructor
    @AllArgsConstructor
    @Data
    public static class Context implements Serializable {

        private static final long serialVersionUID = 2916376548269524746L;

        // index of this task (1-based)
        private int index;

        // total number of tasks
        private int total;

        @Override
        public String toString() {
            return ToStringBuilder.reflectionToString(this, ToStringStyle.JSON_STYLE);
        }
    }

    @Override
    public String toString() {
        return "LimitConcurrencyAsyncTaskExecutor{" +
                "concurrency=" + concurrency +
                ", traceId='" + traceId + '\'' +
                '}';
    }
}

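The concurrency field of LimitConcurrencyAsyncTaskExecutor is the number of tasks allowed to execute asynchronously at the same time.
Using a Semaphore and a callback interface, it caps how many asynchronous tasks are executing at once: up to concurrency tasks run in parallel,
and as soon as any one of them finishes and frees its slot, the next task is started immediately.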
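Here is a matching JDK-only sketch of the sliding-window idea; again, it is not the author's class. It differs from the class above in one way: after the loop it acquires all permits. That only succeeds once the last tasks have called back, so the method returns only when everything has finished.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class SemaphoreWindowSketch {

    // Keeps at most `concurrency` simulated async tasks in flight; a new one starts as soon as
    // any running task calls back. Returns {tasks completed, peak number running at once}.
    public static int[] run(int total, int concurrency) {
        ExecutorService remote = Executors.newCachedThreadPool(); // stands in for the async callee
        Semaphore permits = new Semaphore(concurrency);
        AtomicInteger running = new AtomicInteger();
        AtomicInteger maxRunning = new AtomicInteger();
        AtomicInteger completed = new AtomicInteger();
        try {
            for (int i = 0; i < total; i++) {
                permits.acquireUninterruptibly(); // blocks while `concurrency` tasks are in flight
                remote.execute(() -> {
                    maxRunning.accumulateAndGet(running.incrementAndGet(), Math::max);
                    try {
                        TimeUnit.MILLISECONDS.sleep(5); // simulated remote work
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                    } finally {
                        running.decrementAndGet();
                        completed.incrementAndGet();
                        permits.release(); // the completion callback frees one slot
                    }
                });
            }
            // Getting every permit back means the last tasks have also called back.
            permits.acquireUninterruptibly(concurrency);
            permits.release(concurrency);
        } finally {
            remote.shutdown();
        }
        return new int[]{completed.get(), maxRunning.get()};
    }

    public static void main(String[] args) {
        int[] r = run(23, 5);
        System.out.println("completed=" + r[0] + ", maxRunning=" + r[1]);
    }
}
```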

Comparing and testing the two methods

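Both LimitBatchAsyncTaskExecutor and LimitConcurrencyAsyncTaskExecutor limit the number of asynchronous tasks executing concurrently.
The difference is in how slots are reused. LimitBatchAsyncTaskExecutor runs n tasks per batch and must wait for every task in the batch to finish before starting the next batch. Each batch therefore lasts as long as its slowest task, and the slots that finish early sit idle in the meantime.
LimitConcurrencyAsyncTaskExecutor allows at most n tasks at once and starts the next task the moment any one of them finishes.
By comparison, it wastes no time or resources and performs better.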
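The timing difference can be reasoned about without any threads. The model below is idealized and ignores scheduling overhead:

- in the batch strategy, each batch lasts as long as its slowest task;
- in the sliding window, each task, in submission order, takes whichever slot frees up first.

The durations are made-up numbers chosen to show the effect.

```java
import java.util.PriorityQueue;

public class MakespanModel {

    // Batch strategy: batches run one after another; each lasts as long as its slowest task.
    public static long batchMakespan(long[] durations, int n) {
        long total = 0;
        for (int start = 0; start < durations.length; start += n) {
            long slowest = 0;
            for (int i = start; i < Math.min(start + n, durations.length); i++) {
                slowest = Math.max(slowest, durations[i]);
            }
            total += slowest;
        }
        return total;
    }

    // Sliding window: n slots; each task, in order, takes the slot that frees up earliest.
    public static long windowMakespan(long[] durations, int n) {
        PriorityQueue<Long> slotFreeAt = new PriorityQueue<>();
        for (int i = 0; i < n; i++) {
            slotFreeAt.add(0L);
        }
        long end = 0;
        for (long d : durations) {
            long finish = slotFreeAt.poll() + d;
            slotFreeAt.add(finish);
            end = Math.max(end, finish);
        }
        return end;
    }

    public static void main(String[] args) {
        long[] durations = {100, 10, 10, 100, 10, 10}; // ms, illustrative
        System.out.println("batch:  " + batchMakespan(durations, 2) + " ms");  // 210 ms
        System.out.println("window: " + windowMakespan(durations, 2) + " ms"); // 120 ms
    }
}
```

With two slots, the batch strategy takes 100 + 100 + 10 = 210 ms, because each 100 ms task keeps its 10 ms partner's slot idle. The sliding window finishes in 120 ms.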

A test class to verify this:

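// JUnit 4 assumed (org.junit.Test); with JUnit 5 use org.junit.jupiter.api.Test
import org.junit.Test;

import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**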
 * @author cdfive
 */
public class LimitAsyncTaskExecutorTest {

    @Test
    public void testLimitBatchAsyncTaskExecutor() {
        long start = System.currentTimeMillis();
        int batch = 5;
        int total = 1000;
        CountDownLatch latch = new CountDownLatch(total);

        LimitBatchAsyncTaskExecutor limitBatchAsyncTaskExecutor = new LimitBatchAsyncTaskExecutor(batch);

        List<String> codes = IntStream.range(1, 1 + total).mapToObj(i -> String.valueOf(i)).collect(Collectors.toList());

        LimitBatchAsyncTaskExecutor.AsyncTaskExecutor<String> asyncTaskExecutor = new LimitBatchAsyncTaskExecutor.AsyncTaskExecutor<String>() {
            @Override
            public void executeTask(String code, Runnable callback, LimitBatchAsyncTaskExecutor.Context context) {
                new Thread(new Runnable() {
                    @Override
                    public void run() {
                        try {
                            TimeUnit.MILLISECONDS.sleep(ThreadLocalRandom.current().nextInt(180, 200));
                        } catch (InterruptedException e) {
                            e.printStackTrace();
                        }

                        System.err.println(limitBatchAsyncTaskExecutor.getTraceId() + "," + Thread.currentThread().getName() + "=>code=" + code + ",context=" + context);

                        callback.run();

                        latch.countDown();
                    }
                }).start();
            }
        };

        limitBatchAsyncTaskExecutor.executeTasks(codes, asyncTaskExecutor);

        try {
            latch.await();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }

        // more than 40s
        System.out.println("total done,cost=" + (System.currentTimeMillis() - start) + "ms");
    }
    
    @Test
    public void testLimitConcurrencyAsyncTaskExecutor() {
        long start = System.currentTimeMillis();
        int concurrency = 5;
        int total = 1000;
        CountDownLatch latch = new CountDownLatch(total);

        LimitConcurrencyAsyncTaskExecutor limitConcurrencyAsyncTaskExecutor = new LimitConcurrencyAsyncTaskExecutor(concurrency);

        List<String> codes = IntStream.range(1, 1 + total).mapToObj(i -> String.valueOf(i)).collect(Collectors.toList());

        LimitConcurrencyAsyncTaskExecutor.AsyncTaskExecutor<String> asyncTaskExecutor = new LimitConcurrencyAsyncTaskExecutor.AsyncTaskExecutor<String>() {
            @Override
            public void executeTask(String code, Runnable callback, LimitConcurrencyAsyncTaskExecutor.Context context) {
                new Thread(new Runnable() {
                    @Override
                    public void run() {
                        try {
                            TimeUnit.MILLISECONDS.sleep(ThreadLocalRandom.current().nextInt(180, 200));
                        } catch (InterruptedException e) {
                            e.printStackTrace();
                        }

                        System.err.println(limitConcurrencyAsyncTaskExecutor.getTraceId() + "," + Thread.currentThread().getName() + "=>code=" + code + ",context=" + context);

                        callback.run();

                        latch.countDown();
                    }
                }).start();
            }
        };

        limitConcurrencyAsyncTaskExecutor.executeTasks(codes, asyncTaskExecutor);

        try {
            latch.await();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }

        // less than 40s
        System.out.println("total done,cost=" + (System.currentTimeMillis() - start) + "ms");
    }
}

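1000 tasks, each running for 180-200 ms, with the limit set to 5 for both executors.
Estimate: 1000 * 0.2s / 5 = 40s, so with 5 concurrent asynchronous tasks the run should take roughly 40s.
Results: LimitBatchAsyncTaskExecutor took slightly more than 40 seconds, and LimitConcurrencyAsyncTaskExecutor slightly less.
Method 2, LimitConcurrencyAsyncTaskExecutor, performs better because it leaves no resources idle. Each batch in method 1 takes as long as its slowest task, close to the 200 ms upper bound, whereas method 2's total time tracks the average task duration.
Choose between them according to your requirements.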

Summary

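  • Multithreading and asynchrony are common and important in development; used in the right places, they improve processing efficiency

  • For custom thread pools, set the parameters based on requirements, hardware and performance assessments; consider making the parameters dynamically configurable or creating thread pools dynamically

  • When the middleware or third-party interface being called executes asynchronously itself, consider the resources and capacity of the called side, and apply limits and control from the caller's side

  • When the callee is asynchronous, the caller can limit concurrency using the concurrency utilities in java.util.concurrent together with asynchronous callbacks

  • During design and development, think about keeping risks under control across callers and callees, middleware and third parties, and pay attention to system stability and robustness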
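For the second point, ThreadPoolExecutor already supports changing its core and maximum sizes at runtime, through setCorePoolSize and setMaximumPoolSize. Below is a minimal sketch; wiring it to a configuration source is assumed and not shown.

The only subtlety is ordering. On JDK 9 and later, setCorePoolSize rejects a core size larger than the current maximum, and setMaximumPoolSize rejects a maximum below the current core size. So the maximum is raised first when growing, and the core is lowered first when shrinking.

```java
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

public class ResizablePool {

    // Applies new core/max sizes to a running pool, e.g. from a config-change listener
    // (that wiring is assumed and not shown). Core must never exceed max at any moment,
    // so grow max before core, and shrink core before max.
    public static void resize(ThreadPoolExecutor pool, int newCore, int newMax) {
        if (newCore < 0 || newMax <= 0 || newCore > newMax) {
            throw new IllegalArgumentException("invalid sizes: core=" + newCore + ", max=" + newMax);
        }
        if (newMax >= pool.getMaximumPoolSize()) {
            pool.setMaximumPoolSize(newMax);
            pool.setCorePoolSize(newCore);
        } else {
            pool.setCorePoolSize(newCore);
            pool.setMaximumPoolSize(newMax);
        }
    }

    public static void main(String[] args) {
        ThreadPoolExecutor pool = new ThreadPoolExecutor(2, 4, 60L, TimeUnit.SECONDS,
                new ArrayBlockingQueue<>(100));
        resize(pool, 8, 16);
        System.out.println(pool.getCorePoolSize() + "/" + pool.getMaximumPoolSize());
        resize(pool, 1, 2);
        System.out.println(pool.getCorePoolSize() + "/" + pool.getMaximumPoolSize());
        pool.shutdown();
    }
}
```

The queue capacity of an ArrayBlockingQueue is fixed at construction, so this sketch only adjusts thread counts.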

posted @ 2024-06-19 21:36  cdfive