多线程实现生产消费模式
本文通过多线程的方式实现生产-消费模式,使用时需注意两点:1. 只适用于生产方法和消费方法位于同一个类中的场景;2. 只适用于单一任务的生产和消费。
这里的测试类以 XXL-JOB 分布式定时任务调度平台为例。
代码
生产和消费上下文对象:
import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; /** * 生产、消费线程上下文 * * @param <E> */ public class Context<E> { private static final Logger log = LoggerFactory.getLogger(Context.class); private final LinkedBlockingQueue<E> consumptionQueue = new LinkedBlockingQueue<E>(2500); private volatile ThreadState producersThreadState;// 生产线程状态 private volatile ThreadState consumersThreadState;// 消费线程状态 /** * 将指定元素插入到此队列的尾部 * <p> * 如有必要(队列空间已满且消费线程未停止运行),则等待空间变得可用 * </P> * * @param e * @return true:插入成功,false:插入失败(消费线程已停止运行) * @throws Exception */ public boolean offerDataToConsumptionQueue(E e) throws Exception { setProducersThreadState(ThreadState.RUNNING); if (ThreadState.DEAD == this.getConsumersThreadState()) {// 如果消费线程停止了,不再生产数据 return false; } while (true) { if (consumptionQueue.offer(e, 2, TimeUnit.SECONDS)) { return true; } // 添加元素失败,很有可能是队列已满,再次检查消费线程是否工作中 if (ThreadState.DEAD == this.getConsumersThreadState()) {// 如果消费线程停止了,不再生产数据 return false; } } } /** * 获取并移除此队列的头, * <p> * 如果此队列为空且生产线程已停止,则返回 null * </P> * * @return * @throws Exception */ public E pollDataFromConsumptionQueue() throws Exception { setConsumersThreadState(ThreadState.RUNNING); while (true) { E e = consumptionQueue.poll(20, TimeUnit.MILLISECONDS); if (e != null) { return e; } // 没有从队列里获取到元素,并且生产线程已停止,则返回null if (ThreadState.DEAD == this.getProducersThreadState()) { return null; } log.debug("demand exceeds supply(供不应求,需生产数据)..."); Thread.sleep(50); } } /** * 获取队列的大小 * * @return */ int getConsumptionQueueSize() { return consumptionQueue.size(); } /** * 获取生产者线程的状态 * * @return */ ThreadState getProducersThreadState() { return producersThreadState; } /** * 设置生产者线程的状态 * * @param producersThreadState */ void setProducersThreadState(ThreadState producersThreadState) { this.producersThreadState = producersThreadState; } /** * 获取消费者线程的状态 * * @return */ ThreadState getConsumersThreadState() { return 
consumersThreadState; } /** * 设置消费者线程的状态 * * @param consumersThreadState */ void setConsumersThreadState(ThreadState consumersThreadState) { this.consumersThreadState = consumersThreadState; } }
生产和消费协调者类:
import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.concurrent.*; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; /** * 生产与消费协调者 */ public class Coordinator { private static final Logger log = LoggerFactory.getLogger(Coordinator.class); private final Lock lock = new ReentrantLock(); private final Condition enabledConsumers = lock.newCondition(); private volatile boolean isEnabledForConsumers; private final Context<?> context; private boolean isWaitingToFinish;// 是否等待生产及消费完成 private int consumersMaxTotal;// 最大消费线程数 public Coordinator(Context<?> context, int consumersMaxTotal) { this(context, consumersMaxTotal, true); } public Coordinator(Context<?> context, int consumersMaxTotal, boolean isWaitingToFinish) { this.context = context; this.consumersMaxTotal = consumersMaxTotal; this.isWaitingToFinish = isWaitingToFinish; } /** * 启动生产、消费 (适用于生产函数、消费函数在一个类里实现且只有一对生产、消费组合,并且方法入参列表简单) * <p> * 这个方法才是生产者和消费者的主启动方法 * </P> * * @param simpleTemplate 生产、消费简易模板 */ public void start(SimpleTemplate<?> simpleTemplate) throws Exception { if (context.getConsumersThreadState() != null || context.getProducersThreadState() != null) { return; } ProducersThreadUnit producersThreadUnit = new ProducersThreadUnit(simpleTemplate, "production", context); ConsumersThreadUnit consumersThreadUnit = new ConsumersThreadUnit(simpleTemplate, "consumption", context); if (context.getConsumersThreadState() != ThreadState.NEW || context.getProducersThreadState() != ThreadState.NEW) { return; } try { long startTime = System.currentTimeMillis(); Thread startProducersThread = this.startProducers(producersThreadUnit); Thread startConsumersThread = this.startConsumers(consumersThreadUnit); if (!this.isWaitingToFinish) { return; } startProducersThread.join(); if (startConsumersThread != null) { startConsumersThread.join(); } log.info(String.format("processing is completed... 
man-hour(millisecond)=[%s]", System.currentTimeMillis() - startTime)); } catch (Exception e) { log.error("start worker error...", e); throw e; } } /** * 启动生产 * * @param producersThreadUnit * @return * @throws Exception */ private Thread startProducers(ProducersThreadUnit producersThreadUnit) throws Exception { Thread thread = new Thread(producersThreadUnit); thread.start(); return thread; } /** * 启动消费 * * @param consumersThreadUnit * @return * @throws Exception */ private Thread startConsumers(ConsumersThreadUnit consumersThreadUnit) throws Exception { lock.lock(); try { log.info("wating for producers..."); while (!isEnabledForConsumers) { // 等待生产(造成当前线程在接到信号、被中断或到达指定等待时间之前一直处于等待状态),假定可能发生虚假唤醒(这并非是因为等待超时),因此总是在一个循环中等待 enabledConsumers.await(5, TimeUnit.SECONDS);// 间隔检查,防止意外情况下线程没能被成功唤醒(机率小之又小,导致线程无限挂起) } if (context.getConsumptionQueueSize() == 0) { return null; } log.info("start consumers before..."); Thread thread = new Thread(consumersThreadUnit); thread.start(); return thread; } catch (Exception e) { log.error("start consumers error...", e); throw e; } finally { lock.unlock(); } } /** * 生产线程 */ public class ProducersThreadUnit implements Runnable { private Object targetObject; private String targetMethodName; private Object[] targetMethodParameters; private ExecutorService executorService = Executors.newFixedThreadPool(1); /** * 构件函数 * * @param targetObject * @param targetMethodName * @param targetMethodParameters */ public ProducersThreadUnit(Object targetObject, String targetMethodName, Object... 
targetMethodParameters) { this.targetObject = targetObject; this.targetMethodName = targetMethodName; this.targetMethodParameters = targetMethodParameters; context.setProducersThreadState(ThreadState.NEW); } @Override public void run() { try { executorService.execute(new RunnableThreadUnit(targetObject, targetMethodName, targetMethodParameters)); context.setProducersThreadState(ThreadState.RUNNABLE); executorService.shutdown(); // 阻塞线程,直到生产中(消费队列不为空)或者停止生产 while (!executorService.isTerminated() && context.getConsumptionQueueSize() == 0) { Thread.sleep(20); } log.info("production the end or products have been delivered,ready to inform consumers..."); this.wakeConsumers(); log.info("wait until the production is complete..."); while (!executorService.isTerminated()) { // 等待生产完毕 Thread.sleep(200); } } catch (Exception e) { log.error(String.format("production error... targetObject=[%s],targetMethodName=[%s],targetMethodParameters=[%s]", targetObject, targetMethodName, targetMethodParameters), e); if (!executorService.isShutdown()) { executorService.shutdown(); } } finally { log.info("production the end..."); context.setProducersThreadState(ThreadState.DEAD); isEnabledForConsumers = true;// 无论在何种情况下,必须确保能够结束挂起中的消费者线程 } } /** * 向消费者发送信号 */ private void wakeConsumers() { isEnabledForConsumers = true;// 即使唤醒消费者线程失败,也可以使用该句柄结束挂起中的消费者线程 lock.lock(); try { enabledConsumers.signal(); } catch (Exception e) { log.error("inform to consumers error...", e); } finally { lock.unlock(); } } } /** * 消费线程 */ public class ConsumersThreadUnit implements Runnable { private Object targetObject; private String targetMethodName; private Object[] targetMethodParameters; public ConsumersThreadUnit(Object targetObject, String targetMethodName, Object... 
targetMethodParameters) { this.targetObject = targetObject; this.targetMethodName = targetMethodName; this.targetMethodParameters = targetMethodParameters; context.setConsumersThreadState(ThreadState.NEW); } @Override public void run() { ThreadPoolExecutor threadPoolExecutor = null; int concurrencyMaxTotal = Coordinator.this.consumersMaxTotal; try { threadPoolExecutor = new ThreadPoolExecutor(0, concurrencyMaxTotal, 60L, TimeUnit.SECONDS, new SynchronousQueue<Runnable>()); while (concurrencyMaxTotal > 0) { if (threadPoolExecutor.getPoolSize() > context.getConsumptionQueueSize()) { if (ThreadState.DEAD == context.getProducersThreadState()) { break;// 无须再提交新任务 } else { Thread.sleep(50); continue;// 再次检查是否有必要提交新任务 } } RunnableThreadUnit consumers = new RunnableThreadUnit(targetObject, targetMethodName, targetMethodParameters); threadPoolExecutor.execute(consumers); context.setConsumersThreadState(ThreadState.RUNNABLE); log.info("submit consumption task..."); concurrencyMaxTotal--; } threadPoolExecutor.shutdown(); while (!threadPoolExecutor.isTerminated()) { // 等待消费完毕 Thread.sleep(100); } } catch (Exception e) { log.error(String.format("consumption error... targetObject=[%s],targetMethodName=[%s],targetMethodParameters=[%s]", targetObject, targetMethodName, targetMethodParameters), e); if (threadPoolExecutor != null && !threadPoolExecutor.isShutdown()) { threadPoolExecutor.shutdown(); } } finally { log.info("consumption the end..."); context.setConsumersThreadState(ThreadState.DEAD); } } } }
线程处理公用类:
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

/**
 * Runnable that reflectively invokes a public method on a target object.
 * <p>
 * The method is first looked up by the exact runtime classes of the arguments. If that fails, or
 * any argument is null, it falls back to the first public method with the same name whose
 * parameters accept the arguments. Nulls are accepted by non-primitive parameters; boxed values
 * are accepted by the matching primitive parameter. Any failure is logged, not thrown.
 * </p>
 */
public class RunnableThreadUnit implements Runnable {
    private static final Logger logger = LoggerFactory.getLogger(RunnableThreadUnit.class);

    private final Object object;
    private final String methodName;
    private final Object[] methodParameters;

    /**
     * @param object           target instance; must not be null
     * @param methodName       name of a public method of {@code object}; must not be blank
     * @param methodParameters arguments for the method; the array itself must not be null
     * @throws IllegalArgumentException if any of the preconditions above is violated
     */
    public RunnableThreadUnit(Object object, String methodName, Object... methodParameters) {
        if (object == null || StringUtils.isBlank(methodName) || methodParameters == null) {
            throw new IllegalArgumentException("init runnable thread unit error...");
        }
        this.object = object;
        this.methodName = methodName;
        this.methodParameters = methodParameters;
    }

    @Override
    public void run() {
        try {
            Method method = resolveMethod();
            method.invoke(object, methodParameters);
        } catch (InvocationTargetException e) {
            // Log what the target method itself threw rather than the reflective wrapper.
            Throwable cause = e.getCause() != null ? e.getCause() : e;
            logger.error(String.format("execute runnable thread unit error... service=[%s],invokeMethodName=[%s]", object, methodName), cause);
        } catch (Exception e) {
            logger.error(String.format("execute runnable thread unit error... service=[%s],invokeMethodName=[%s]", object, methodName), e);
        }
    }

    /**
     * Finds the public method to invoke. The exact-type lookup is tried first so overload
     * selection stays the same as a plain {@code getMethod} whenever that succeeds.
     */
    private Method resolveMethod() throws NoSuchMethodException {
        Class<?> targetClass = object.getClass();
        Class<?>[] argumentTypes = new Class<?>[methodParameters.length];
        boolean hasNullArgument = false;
        for (int i = 0; i < methodParameters.length; i++) {
            if (methodParameters[i] == null) {
                hasNullArgument = true;
                break;
            }
            argumentTypes[i] = methodParameters[i].getClass();
        }
        if (!hasNullArgument) {
            try {
                return targetClass.getMethod(methodName, argumentTypes);
            } catch (NoSuchMethodException ignored) {
                // e.g. an argument is a subclass of the declared type; fall back to the scan below
            }
        }
        // NOTE: getMethods() order is unspecified; with several compatible overloads the choice
        // among them is arbitrary.
        for (Method candidate : targetClass.getMethods()) {
            if (candidate.getName().equals(methodName) && accepts(candidate.getParameterTypes(), methodParameters)) {
                return candidate;
            }
        }
        throw new NoSuchMethodException(targetClass.getName() + "." + methodName
                + " accepting " + methodParameters.length + " argument(s)");
    }

    /** Returns whether each argument can be passed to the parameter at the same position. */
    private static boolean accepts(Class<?>[] parameterTypes, Object[] arguments) {
        if (parameterTypes.length != arguments.length) {
            return false;
        }
        for (int i = 0; i < parameterTypes.length; i++) {
            Class<?> type = parameterTypes[i];
            Object argument = arguments[i];
            if (argument == null) {
                if (type.isPrimitive()) {
                    return false; // null cannot be unboxed
                }
            } else if (!wrap(type).isInstance(argument)) {
                return false;
            }
        }
        return true;
    }

    /** Maps a primitive type to its wrapper class; other types are returned unchanged. */
    private static Class<?> wrap(Class<?> type) {
        if (!type.isPrimitive()) {
            return type;
        }
        if (type == int.class) {
            return Integer.class;
        }
        if (type == long.class) {
            return Long.class;
        }
        if (type == boolean.class) {
            return Boolean.class;
        }
        if (type == double.class) {
            return Double.class;
        }
        if (type == float.class) {
            return Float.class;
        }
        if (type == char.class) {
            return Character.class;
        }
        if (type == byte.class) {
            return Byte.class;
        }
        if (type == short.class) {
            return Short.class;
        }
        return type; // void.class, which never appears as a parameter type
    }
}
模板接口类
/**
 * Minimal producer/consumer template.
 * <p>
 * A single class supplies both the production step and the consumption step of one task, and
 * each step receives nothing but the shared {@link Context}.
 * </p>
 *
 * @param <C_E> type of the elements exchanged through the context
 */
public interface SimpleTemplate<C_E> {

    /**
     * Produces data into the given context.
     *
     * @param context shared producer/consumer context
     * @throws Exception if production fails
     */
    void production(Context<C_E> context) throws Exception;

    /**
     * Consumes data from the given context.
     *
     * @param context shared producer/consumer context
     * @throws Exception if consumption fails
     */
    void consumption(Context<C_E> context) throws Exception;
}
线程状态枚举类
/**
 * Lifecycle state of one side (producers or consumers) of a producer/consumer run.
 */
enum ThreadState {
    /** The side's unit has been created but not started. */
    NEW,
    /** Work for the side has been submitted. */
    RUNNABLE,
    /** The side is actively offering or polling data. */
    RUNNING,
    /** The side has finished; the other side uses this to stop waiting. */
    DEAD,
    /** Declared for completeness; not assigned anywhere in the code shown. */
    BLOCKED
}
测试类
import com.alibaba.fastjson.JSONObject; import com.credithc.channel.dao.entity.ChannelCollisionEntity; import com.credithc.channel.xxljob.common.JobConstants; import com.credithc.channel.xxljob.concurrent.Context; import com.credithc.channel.xxljob.concurrent.Coordinator; import com.credithc.channel.xxljob.handlers.PerfectIdentityInfoForcollisionHandler; import com.xxl.job.core.biz.model.ReturnT; import com.xxl.job.core.handler.IJobHandler; import com.xxl.job.core.handler.annotation.JobHandler; import com.xxl.job.core.log.XxlJobLogger; import org.apache.commons.lang.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; import java.text.SimpleDateFormat; import java.util.Date; /** * 完善撞库表中身份证信息定时任务 * * @author kuangxiang * @date 2020/8/24 16:15 */ @Component @JobHandler("perfectIdentityInfoForCollisionJob") public class PerfectIdentityInfoForCollisionJob extends IJobHandler { private Logger logger = LoggerFactory.getLogger(this.getClass()); @Autowired private PerfectIdentityInfoForcollisionHandler perfectIdentityInfoForcollisionHandler; @Override public ReturnT<String> execute(String param) throws Exception { long startTimeStamp = System.currentTimeMillis(); XxlJobLogger.log("【完善撞库表中身份证信息】【定时任务】,开始,param:" + param + ",startTimeStamp:" + startTimeStamp); logger.info("【完善撞库表中身份证信息】【定时任务】,开始,param:{},startTimeStamp:{}", param, startTimeStamp); if (!checkParams(param, startTimeStamp)) { return ReturnT.FAIL; } //生产消费模式启动方法 new Coordinator(new Context<ChannelCollisionEntity>(), JobConstants.CONSUMERS_MAX_TOTAL).start(perfectIdentityInfoForcollisionHandler); double executTime = (System.currentTimeMillis() - startTimeStamp) / 1000; XxlJobLogger.log("【完善撞库表中身份证信息】【定时任务】,结束,startTimeStamp:" + startTimeStamp + ",executTime:" + executTime); logger.info("【完善撞库表中身份证信息】【定时任务】,结束,startTimeStamp:{},executTime:{}秒", startTimeStamp, executTime); return 
ReturnT.SUCCESS; } /** * 参数校验 * * @param param 定时任务参数 * @param startTimeStamp 时间戳 * @return */ private boolean checkParams(String param, long startTimeStamp) { try { JSONObject jsonObject = JSONObject.parseObject(param); if (jsonObject == null || jsonObject.getInteger("step") == null) { XxlJobLogger.log("【完善撞库表中身份证信息】【定时任务】,参数校验失败"); logger.info("【完善撞库表中身份证信息】【定时任务】,参数校验失败,param:{},startTimeStamp:{}", param, startTimeStamp); return false; } SimpleDateFormat simpleDateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); Date beginDate = null; String beginDateStr = jsonObject.getString("beginDateStr"); if (StringUtils.isNotBlank(beginDateStr)) { beginDate = simpleDateFormat.parse(beginDateStr); } Date endDate = null; String endDateStr = jsonObject.getString("endDateStr"); if (StringUtils.isNotBlank(endDateStr)) { endDate = simpleDateFormat.parse(endDateStr); } perfectIdentityInfoForcollisionHandler.setStep(jsonObject.getInteger("step")); perfectIdentityInfoForcollisionHandler.setBeginDate(beginDate); perfectIdentityInfoForcollisionHandler.setEndDate(endDate); } catch (Exception e) { XxlJobLogger.log("【完善撞库表中身份证信息】【定时任务】,参数校验异常"); logger.error("【完善撞库表中身份证信息】【定时任务】,参数校验异常,param:" + param, e); return false; } return true; } }