基于 Spring Data Redis 实现分布式定时任务组件
适用场景
当服务以集群模式部署时, 如果需要定时执行一段逻辑, 且不希望多个实例重复执行, 就可以使用此组件
例: 每10秒执行一次, 将数据库中超时未支付的订单进行关闭处理
最终成果展示
有两种使用方式, 如下
package com.idanchuang.example.spring.cloud.consumer.task; import com.idanchuang.component.redis.task.AbstractTimedTask; import org.springframework.data.redis.core.StringRedisTemplate; import org.springframework.stereotype.Component; /** * * 使用方式1: 继承AbstractTimedTask, 并添加Component注解 * @author yjy * Created at 2022/1/6 12:51 上午 */ @Component public class MyTimedTask1 extends AbstractTimedTask { public MyTimedTask1(StringRedisTemplate redisTemplate) { super(redisTemplate); } @Override public long getInterval() { // 每3秒执行一次 return 3; } @Override public void runOnce() { System.out.println("run task1"); } }
-----------------
package com.idanchuang.example.spring.cloud.consumer.task;

import com.idanchuang.component.redis.annotation.RedisTask;
import org.springframework.stereotype.Component;

/**
 * Usage option 2: annotate a public method of any Spring bean with {@link RedisTask}.
 *
 * @author yjy
 * Created at 2022/1/6 12:51 AM
 */
@Component
public class MyTimedTask3 {

    /**
     * Demo task body, scheduled every 3 units of the annotation's default time unit (seconds).
     */
    @RedisTask(interval = 3)
    public void doJob() {
        String message = "run task3";
        System.out.println(message);
    }
}
源码展示
组件代码十分简单, 首先定义一个接口 TimedTask, 如下
package com.idanchuang.component.redis.task;

import java.util.concurrent.TimeUnit;

/**
 * A periodic task intended to run on only one instance of a clustered service at a time.
 *
 * <p>Usage: implement this interface and declare the implementation as a Spring bean.
 *
 * @author yjy
 * Created at 2022/1/5 11:58 PM
 */
public interface TimedTask {

    // ---- lifecycle ----

    /** Begins periodic scheduling of this task. */
    void start();

    /** Tears the task down. */
    void destroy();

    // ---- scheduling configuration ----

    /**
     * Name of the task; it must be unique across all tasks.
     *
     * @return the task name
     */
    String getTaskName();

    /**
     * Delay before the first execution; by default there is no delay.
     *
     * @return the initial delay
     */
    long getInitialDelay();

    /**
     * Execution period, in seconds by default; override {@link #getTimeUnit()} to change the unit.
     *
     * @return how often the task runs
     */
    long getInterval();

    /**
     * Unit of the execution period.
     *
     * @return the time unit
     */
    TimeUnit getTimeUnit();

    // ---- work ----

    /** Executes the task logic once. */
    void runOnce();
}
接下来, 我们基于 Redis 实现 TimedTask 接口, 得到抽象类 AbstractTimedTask (通过 Lua 脚本原子地抢占并续期执行权)
package com.idanchuang.component.redis.task; import com.idanchuang.component.core.util.InetUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.data.redis.core.StringRedisTemplate; import org.springframework.data.redis.core.script.DefaultRedisScript; import java.lang.management.ManagementFactory; import java.util.ArrayList; import java.util.List; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.TimeUnit; /** * 分布式定时任务抽象实现 (基于Redis lua脚本进行执行权抢占与持有) * * @author yjy * Created at 2022/1/20 10:03 上午 */ public abstract class AbstractTimedTask implements TimedTask { private static final Logger logger = LoggerFactory.getLogger(TimedTask.class); /** 本机ip */ protected static final String LOCAL_IP = InetUtil.getLocalIp(); /** jvm进程id */ protected static final String PROCESSOR_ID = ManagementFactory.getRuntimeMXBean().getName().split("@")[0]; /** 任务执行调度线程 */ protected final ScheduledExecutorService scheduler = new ScheduledThreadPoolExecutor(1, (r) -> new Thread(r, "redis-timed-task")); protected final StringRedisTemplate redisTemplate; protected AbstractTimedTask(StringRedisTemplate redisTemplate) { this.redisTemplate = redisTemplate; } /** * 开始任务 */ @Override public void start() { scheduler.scheduleAtFixedRate(() -> { try { if (tryRequire()) { this.runOnce(); } } catch (Exception e) { logger.error("TimedTask error", e); } }, getInitialDelay(), getInterval(), getTimeUnit()); logger.info("RedisTimedTask running >>> {}", this.getTaskName()); } @Override public String getTaskName() { return this.getClass().getName(); } @Override public TimeUnit getTimeUnit() { return TimeUnit.SECONDS; } @Override public long getInitialDelay() { return 0L; } @Override public void destroy() { // 停止调度线程 scheduler.shutdown(); // 释放占有的redis key tryRelease(); } /** * 检查/续命执行权 * @return .. 
*/ private boolean tryRequire() { String key = getKey(); String instanceId = LOCAL_IP + ":" + PROCESSOR_ID; long interval = getTimeUnit().toSeconds(getInterval()); long ttl = (long) (interval * 1.5D); if (interval == ttl) { // 至少多占1s ttl = interval + 1; } String script = "local instanceId,ttl=ARGV[1],ARGV[2] "; script += "if redis.call('EXISTS', KEYS[1])==1 " + "then local curValue = redis.call('GET', KEYS[1]) " + "if curValue==instanceId " + "then redis.call('EXPIRE', KEYS[1], ttl) " + "return true " + "else " + "return false " + "end " + "else " + "redis.call('SET', KEYS[1], instanceId) " + "redis.call('EXPIRE', KEYS[1], ttl) " + "return true " + "end"; DefaultRedisScript<Boolean> redisScript = new DefaultRedisScript<>(); redisScript.setResultType(Boolean.class); redisScript.setScriptText(script); List<String> keys = new ArrayList<>(); keys.add(key); return this.redisTemplate.execute(redisScript, keys, instanceId, String.valueOf(ttl)); } /** * 释放执行权 */ private void tryRelease() { String key = getKey(); String instanceId = LOCAL_IP + ":" + PROCESSOR_ID; String script = "local instanceId=ARGV[1] "; script += "if redis.call('EXISTS', KEYS[1])==1 " + "then local curValue = redis.call('GET', KEYS[1]) " + "if curValue==instanceId " + "then redis.call('DEL', KEYS[1]) " + "return true " + "else return false " + "end " + "else return false " + "end"; DefaultRedisScript<Boolean> redisScript = new DefaultRedisScript<>(); redisScript.setResultType(Boolean.class); redisScript.setScriptText(script); List<String> keys = new ArrayList<>(); keys.add(key); this.redisTemplate.execute(redisScript, keys, instanceId); } /** * 任务对应的redis key * @return .. */ protected String getKey() { return "TimedTask:" + getTaskName(); } }
然后定义一个注解 RedisTask, 用于直接在 Bean 的方法上声明定时任务
package com.idanchuang.component.redis.annotation;

import java.lang.annotation.*;
import java.util.concurrent.TimeUnit;

/**
 * Declares a no-argument bean method as a Redis-coordinated timed task.
 *
 * <p>Note: {@code @Inherited} has no effect here, since that meta-annotation only applies to
 * annotations placed on classes.
 *
 * @author yjy
 * @date 2020/9/3 18:05
 **/
@Inherited
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RedisTask {

    /**
     * Unit used for both {@link #interval()} and {@link #initialDelay()}.
     *
     * @return the time unit
     */
    TimeUnit timeUnit() default TimeUnit.SECONDS;

    /**
     * Delay before the first run, expressed in {@link #timeUnit()}.
     *
     * @return the initial delay
     */
    long initialDelay() default 0L;

    /**
     * Execution period, expressed in {@link #timeUnit()}.
     *
     * @return the period
     */
    long interval();

    /**
     * Task name; when left empty, a name is derived from the bean class and the method name.
     *
     * @return the task name
     */
    String value() default "";
}
最后是一个管理器 TimeTaskManager, 负责在应用启动完成后收集并启动所有任务, 并在应用关闭时销毁它们
package com.idanchuang.component.redis.task; import com.idanchuang.component.redis.annotation.RedisTask; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.DisposableBean; import org.springframework.boot.context.event.ApplicationReadyEvent; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationListener; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.data.redis.core.StringRedisTemplate; import org.springframework.stereotype.Component; import org.springframework.util.StringUtils; import java.lang.reflect.Method; import java.util.Collection; import java.util.List; import java.util.concurrent.TimeUnit; /** * @author yjy * Created at 2022/1/6 12:41 上午 */ @Component public class TimeTaskManager implements ApplicationListener<ApplicationReadyEvent>, DisposableBean { private static final Logger logger = LoggerFactory.getLogger(TimeTaskManager.class); private volatile boolean running = false; private final StringRedisTemplate redisTemplate; private final List<TimedTask> tasks; public TimeTaskManager(StringRedisTemplate redisTemplate, List<TimedTask> tasks) { this.redisTemplate = redisTemplate; this.tasks = tasks; } @Override public void onApplicationEvent(ApplicationReadyEvent event) { synchronized (TimeTaskManager.class) { if (running) { return; } running = true; } // 将注解的任务加入到任务集合 registerAnnotationTask(event.getApplicationContext()); // 启动所有任务 tasks.forEach(TimedTask::start); } @Override public void destroy() throws Exception { tasks.forEach(TimedTask::destroy); } /** * 找到所有带RedisTask注解的方法, 并注册到任务集合 * @param applicationContext .. 
*/ private void registerAnnotationTask(ApplicationContext applicationContext) { Collection<Object> beans = applicationContext.getBeansOfType(Object.class).values(); for (Object bean : beans) { Method[] methods = bean.getClass().getDeclaredMethods(); for (Method method : methods) { RedisTask task = AnnotationUtils.findAnnotation(method, RedisTask.class); if (task != null) { int parameterCount = method.getParameterCount(); if (parameterCount > 0) { throw new IllegalStateException("RedisTask method parameter count: " + parameterCount + ", except: 0"); } TimedTask timedTask = new AnnoTimedTask(task, method, bean, this.redisTemplate); tasks.add(timedTask); logger.info("registerAnnotationTask > taskName: {}", timedTask.getTaskName()); } } } } /** * 注解的任务封装类 */ private static class AnnoTimedTask extends AbstractTimedTask { private final RedisTask redisTask; private final Method method; private final Object bean; public AnnoTimedTask(RedisTask redisTask, Method method, Object bean, StringRedisTemplate redisTemplate) { super(redisTemplate); this.redisTask = redisTask; this.method = method; this.bean = bean; } @Override public String getTaskName() { String taskName = redisTask.value(); if (StringUtils.isEmpty(taskName)) { return bean.getClass().getName() + ":" + method.getName(); } return taskName; } @Override public long getInterval() { return redisTask.interval(); } @Override public long getInitialDelay() { return redisTask.initialDelay(); } @Override public TimeUnit getTimeUnit() { return redisTask.timeUnit(); } @Override public void runOnce() { try { method.invoke(bean); } catch (Exception e) { logger.error("Run Redis TimedTask failed", e); } } } }
------
其中获取本机ip还用了一个工具类 InetUtil, 如下
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
package com.idanchuang.component.core.util;

import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.NetworkInterface;
import java.net.SocketException;
import java.util.Enumeration;

/**
 * Network helper utilities.
 *
 * @author yjy
 * Created at 2021/3/13 6:25 PM
 */
public class InetUtil {

    /**
     * Returns this host's internal-network IPv4 address.
     *
     * <p>Interfaces that are down, loopback, point-to-point, or named {@code docker*} are skipped,
     * as are loopback and link-local (169.254.x.x) addresses. A site-local (private) address is
     * preferred; otherwise the first remaining IPv4 address is used; if none is found,
     * {@code "127.0.0.1"} is returned.
     *
     * @return dotted-quad IPv4 address, never null
     * @throws IllegalStateException if the network interfaces cannot be queried
     */
    public static String getLocalIp() {
        try {
            String fallback = null;
            Enumeration<NetworkInterface> nifs = NetworkInterface.getNetworkInterfaces();
            // On JDK 8 getNetworkInterfaces() may return null when no interface exists.
            while (nifs != null && nifs.hasMoreElements()) {
                NetworkInterface nif = nifs.nextElement();
                if (!nif.isUp() || nif.isPointToPoint() || nif.isLoopback()) {
                    continue;
                }
                if (nif.getName().startsWith("docker")) {
                    continue;
                }
                Enumeration<InetAddress> addresses = nif.getInetAddresses();
                while (addresses.hasMoreElements()) {
                    InetAddress addr = addresses.nextElement();
                    if (!(addr instanceof Inet4Address) || addr.isLoopbackAddress() || addr.isLinkLocalAddress()) {
                        continue;
                    }
                    if (addr.isSiteLocalAddress()) {
                        return addr.getHostAddress();
                    }
                    if (fallback == null) {
                        fallback = addr.getHostAddress();
                    }
                }
            }
            return fallback != null ? fallback : "127.0.0.1";
        } catch (SocketException e) {
            throw new IllegalStateException(e);
        }
    }
}