基于 Spring Data Redis 实现分布式定时任务组件

适用场景

当服务以集群模式部署时, 如果需要定时执行一段逻辑, 并且要求同一时刻只有一个实例执行(即多个实例之间不重复执行), 就可以使用此组件

例: 每10秒执行一次, 将数据库中超时未支付的订单进行关闭处理

 

最终成果展示

有两种使用方式, 如下

package com.idanchuang.example.spring.cloud.consumer.task;

import com.idanchuang.component.redis.task.AbstractTimedTask;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;

/**
 * Usage #1: extend {@link AbstractTimedTask} and annotate the subclass with {@link Component}.
 *
 * @author yjy
 * Created at 2022/1/6 12:51 AM
 */
@Component
public class MyTimedTask1 extends AbstractTimedTask {

    /** Run period; the base class interprets it in seconds unless getTimeUnit() is overridden. */
    private static final long PERIOD = 3L;

    public MyTimedTask1(StringRedisTemplate redisTemplate) {
        super(redisTemplate);
    }

    /** Runs once every {@link #PERIOD} seconds. */
    @Override
    public long getInterval() {
        return PERIOD;
    }

    /** Prints a marker line on every run. */
    @Override
    public void runOnce() {
        final String message = "run task1";
        System.out.println(message);
    }

}

-----------------

package com.idanchuang.example.spring.cloud.consumer.task;

import com.idanchuang.component.redis.annotation.RedisTask;
import org.springframework.stereotype.Component;

/**
 * Usage #2: put {@link RedisTask} on a public, parameterless method of any bean.
 *
 * @author yjy
 * Created at 2022/1/6 12:51 AM
 */
@Component
public class MyTimedTask3 {

    /** Line printed on every run. */
    private static final String MESSAGE = "run task3";

    /**
     * Scheduled every 3 time units (seconds by default) by the RedisTask machinery.
     */
    @RedisTask(interval = 3L)
    public void doJob() {
        System.out.println(MESSAGE);
    }

}

 

源码展示

组件代码十分简单。首先定义接口 TimedTask, 如下

package com.idanchuang.component.redis.task;

import java.util.concurrent.TimeUnit;

/**
 * A periodic job that should be executed by only one instance of a cluster at a time.
 *
 * <p>Usage: implement this interface and register the implementation as a Spring bean.
 *
 * @author yjy
 * Created at 2022/1/5 11:58 PM
 */
public interface TimedTask {

    /**
     * Name of this task; every task must have its own unique name.
     *
     * @return the task name
     */
    String getTaskName();

    /**
     * Delay before the first run, expressed in {@link #getTimeUnit()}.
     *
     * @return the initial delay
     */
    long getInitialDelay();

    /**
     * Period between two runs, expressed in {@link #getTimeUnit()}.
     *
     * @return the run period
     */
    long getInterval();

    /**
     * Unit in which {@link #getInitialDelay()} and {@link #getInterval()} are expressed.
     *
     * @return the time unit
     */
    TimeUnit getTimeUnit();

    /**
     * Starts scheduling this task.
     */
    void start();

    /**
     * Executes the task logic a single time.
     */
    void runOnce();

    /**
     * Stops this task and releases whatever it holds.
     */
    void destroy();

}

接下来基于 Redis 实现 TimedTask 接口, 得到抽象类 AbstractTimedTask。它在每个调度周期通过 Lua 脚本原子地抢占(或续期)任务对应的 Redis key, 只有持有该 key 的实例才会执行任务逻辑:

package com.idanchuang.component.redis.task;

import com.idanchuang.component.core.util.InetUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;

import java.lang.management.ManagementFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
 * Base implementation of {@link TimedTask} that uses Redis so that only one instance of a
 * cluster runs the task.
 *
 * <p>On every tick of the local scheduler, a Lua script atomically claims the task key when no
 * instance holds it, or renews its expiry when this instance already holds it. Only the holder
 * runs {@link #runOnce()}. Once the holder stops renewing (crash or shutdown), the key expires
 * and another instance can take over on one of its next ticks.
 *
 * <p>Subclasses implement {@link #getInterval()} and {@link #runOnce()}. Defaults: the task name
 * is the fully-qualified class name, the time unit is seconds, and there is no initial delay.
 *
 * @author yjy
 * Created at 2022/1/20 10:03 AM
 */
public abstract class AbstractTimedTask implements TimedTask {

    // Logger category intentionally kept as TimedTask so existing logging configuration still applies.
    private static final Logger logger = LoggerFactory.getLogger(TimedTask.class);

    /** IP address of this host. */
    protected static final String LOCAL_IP = InetUtil.getLocalIp();
    /** Process id of this JVM (the part of the runtime MXBean name before '@'). */
    protected static final String PROCESSOR_ID = ManagementFactory.getRuntimeMXBean().getName().split("@")[0];

    /** Value written into the task key by the instance that currently holds the task. */
    private static final String INSTANCE_ID = LOCAL_IP + ":" + PROCESSOR_ID;

    /** Upper bound on how long {@link #destroy()} waits for an in-flight run to finish. */
    private static final long DESTROY_AWAIT_SECONDS = 5L;

    /**
     * KEYS[1] = task key, ARGV[1] = instance id, ARGV[2] = TTL in seconds.
     * Claims the key if it is free, or renews it if ARGV[1] already holds it.
     * Returns 1 when the caller holds the key afterwards, 0 otherwise (integers rather than Lua
     * booleans, because Lua false is converted to a nil reply).
     * Shared single instance so the script SHA1 is computed once instead of on every tick.
     */
    private static final DefaultRedisScript<Boolean> ACQUIRE_SCRIPT = newBooleanScript(
            "local cur = redis.call('GET', KEYS[1]) "
                    + "if cur == false then "
                    + "redis.call('SET', KEYS[1], ARGV[1], 'EX', ARGV[2]) "
                    + "return 1 "
                    + "end "
                    + "if cur == ARGV[1] then "
                    + "redis.call('EXPIRE', KEYS[1], ARGV[2]) "
                    + "return 1 "
                    + "end "
                    + "return 0");

    /**
     * KEYS[1] = task key, ARGV[1] = instance id.
     * Deletes the key only if ARGV[1] still holds it; returns 1 if deleted, 0 otherwise.
     */
    private static final DefaultRedisScript<Boolean> RELEASE_SCRIPT = newBooleanScript(
            "if redis.call('GET', KEYS[1]) == ARGV[1] then "
                    + "redis.call('DEL', KEYS[1]) "
                    + "return 1 "
                    + "end "
                    + "return 0");

    /** Single scheduler thread that drives this task's ticks. */
    protected final ScheduledExecutorService scheduler = new ScheduledThreadPoolExecutor(1,
            (r) -> new Thread(r, "redis-timed-task"));

    protected final StringRedisTemplate redisTemplate;

    protected AbstractTimedTask(StringRedisTemplate redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    /**
     * Schedules the task at a fixed rate. Each tick first tries to claim/renew the Redis key and
     * only calls {@link #runOnce()} when this instance holds it. Exceptions are logged so that
     * they do not cancel the periodic schedule.
     */
    @Override
    public void start() {
        scheduler.scheduleAtFixedRate(() -> {
            try {
                if (tryAcquire()) {
                    this.runOnce();
                }
            } catch (Exception e) {
                logger.error("TimedTask error", e);
            }
        }, getInitialDelay(), getInterval(), getTimeUnit());
        logger.info("RedisTimedTask running >>> {}", this.getTaskName());
    }

    @Override
    public String getTaskName() {
        return this.getClass().getName();
    }

    @Override
    public TimeUnit getTimeUnit() {
        return TimeUnit.SECONDS;
    }

    @Override
    public long getInitialDelay() {
        return 0L;
    }

    /**
     * Stops the scheduler and, once no run is in flight, releases the Redis key if this
     * instance holds it.
     *
     * <p>If a run is still executing after {@value #DESTROY_AWAIT_SECONDS}s, the key is kept. That
     * prevents another instance from starting a concurrent run. The key then expires on its own,
     * because renewals have stopped. Release failures are logged rather than thrown.
     */
    @Override
    public void destroy() {
        // No new ticks are started; periodic tasks are cancelled by shutdown.
        scheduler.shutdown();
        boolean idle;
        try {
            idle = scheduler.awaitTermination(DESTROY_AWAIT_SECONDS, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            idle = false;
        }
        if (!idle) {
            logger.warn("RedisTimedTask still running, leaving key to expire >>> {}", this.getTaskName());
            return;
        }
        try {
            tryRelease();
        } catch (RuntimeException e) {
            logger.warn("RedisTimedTask release failed >>> {}", this.getTaskName(), e);
        }
    }

    /**
     * Claims or renews this instance's ownership of the task key.
     *
     * <p>The key TTL is 1.5x the interval in whole seconds, and always at least one second more
     * than the interval. The holder therefore normally renews the key before it expires.
     *
     * @return true if this instance holds the task and should run it now
     */
    private boolean tryAcquire() {
        long interval = getTimeUnit().toSeconds(getInterval());
        // Same values as floor(interval * 1.5), bumped to interval + 1 when that would not exceed interval.
        long ttl = Math.max(interval + 1, interval * 3 / 2);
        Boolean held = this.redisTemplate.execute(ACQUIRE_SCRIPT, singleKey(), INSTANCE_ID, String.valueOf(ttl));
        // The Spring API allows a null result; treat it as "not held" instead of failing on unboxing.
        return Boolean.TRUE.equals(held);
    }

    /**
     * Deletes the task key if, and only if, this instance still holds it.
     */
    private void tryRelease() {
        this.redisTemplate.execute(RELEASE_SCRIPT, singleKey(), INSTANCE_ID);
    }

    /**
     * @return a fresh single-element key list containing {@link #getKey()}
     */
    private List<String> singleKey() {
        List<String> keys = new ArrayList<>(1);
        keys.add(getKey());
        return keys;
    }

    /**
     * Builds a script whose reply is converted to {@link Boolean}.
     *
     * @param text Lua source
     * @return the configured script
     */
    private static DefaultRedisScript<Boolean> newBooleanScript(String text) {
        DefaultRedisScript<Boolean> script = new DefaultRedisScript<>();
        script.setResultType(Boolean.class);
        script.setScriptText(text);
        return script;
    }

    /**
     * Redis key used to elect the executor of this task.
     *
     * @return "TimedTask:" followed by the task name
     */
    protected String getKey() {
        return "TimedTask:" + getTaskName();
    }

}

接着定义注解 RedisTask, 用于在 Bean 的方法上声明定时任务:

package com.idanchuang.component.redis.annotation;

import java.lang.annotation.*;
import java.util.concurrent.TimeUnit;

/**
 * Marks a bean method to be run periodically by only one instance of the cluster at a time.
 *
 * <p>The annotated method must take no parameters. Exclusive execution is coordinated through a
 * Redis key derived from the task name.
 *
 * <p>{@code @Inherited} is deliberately absent: it only affects annotations on classes and has no
 * effect on method annotations.
 *
 * @author yjy
 * @date 2020/9/3 18:05
 **/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface RedisTask {

    /**
     * @return task name, unique per task; when empty, a name is derived from the bean's class
     *         name and the method name
     */
    String value() default "";

    /**
     * @return period between runs, expressed in {@link #timeUnit()}
     */
    long interval();

    /**
     * @return delay before the first run, expressed in {@link #timeUnit()}; no delay by default
     */
    long initialDelay() default 0L;

    /**
     * @return unit of {@link #interval()} and {@link #initialDelay()}; seconds by default
     */
    TimeUnit timeUnit() default TimeUnit.SECONDS;

}

最后是管理类 TimeTaskManager: 应用启动完成(ApplicationReadyEvent)后扫描并启动所有任务, 应用关闭时销毁这些任务:

package com.idanchuang.component.redis.task;

import com.idanchuang.component.redis.annotation.RedisTask;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.boot.context.event.ApplicationReadyEvent;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationListener;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;
import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.TimeUnit;

/**
 * Starts every {@link TimedTask} once the application is ready and destroys them on shutdown.
 *
 * <p>Besides the {@link TimedTask} beans injected by Spring, parameterless methods annotated
 * with {@link RedisTask} on singleton beans are wrapped into tasks and started as well.
 *
 * @author yjy
 * Created at 2022/1/6 12:41 AM
 */
@Component
public class TimeTaskManager implements ApplicationListener<ApplicationReadyEvent>, DisposableBean {

    private static final Logger logger = LoggerFactory.getLogger(TimeTaskManager.class);
    /** Guards against starting tasks twice when ApplicationReadyEvent is delivered more than once. */
    private volatile boolean running = false;
    private final StringRedisTemplate redisTemplate;
    /** Own mutable copy; annotation-based tasks are appended to it at startup. */
    private final List<TimedTask> tasks;

    public TimeTaskManager(StringRedisTemplate redisTemplate, List<TimedTask> tasks) {
        this.redisTemplate = redisTemplate;
        // The injected list belongs to the container and is not guaranteed to be mutable.
        this.tasks = tasks == null ? new ArrayList<>() : new ArrayList<>(tasks);
    }

    /**
     * On the first ApplicationReadyEvent, registers annotation-based tasks and starts all tasks.
     */
    @Override
    public void onApplicationEvent(ApplicationReadyEvent event) {
        synchronized (TimeTaskManager.class) {
            if (running) {
                return;
            }
            running = true;
        }
        // Add annotation-declared tasks to the task list
        registerAnnotationTask(event.getApplicationContext());
        // Start every task
        tasks.forEach(TimedTask::start);
    }

    /**
     * Destroys every task. A failure in one task is logged so the remaining tasks are still
     * destroyed.
     */
    @Override
    public void destroy() throws Exception {
        for (TimedTask task : tasks) {
            try {
                task.destroy();
            } catch (Exception e) {
                logger.error("destroy TimedTask failed > taskName: {}", task.getTaskName(), e);
            }
        }
    }

    /**
     * Finds every method annotated with {@link RedisTask} on singleton beans and registers it as
     * a task.
     *
     * @param applicationContext context whose beans are scanned
     * @throws IllegalStateException if an annotated method declares parameters or a non-positive interval
     */
    private void registerAnnotationTask(ApplicationContext applicationContext) {
        // Singletons only: including non-singletons would instantiate prototype beans and would
        // fail on request/session-scoped beans, whose scope is not active at startup.
        Collection<Object> beans = applicationContext.getBeansOfType(Object.class, false, true).values();
        for (Object bean : beans) {
            for (Method method : bean.getClass().getDeclaredMethods()) {
                // findAnnotation also searches overridden methods, so annotations survive CGLIB proxies.
                RedisTask task = AnnotationUtils.findAnnotation(method, RedisTask.class);
                if (task == null) {
                    continue;
                }
                int parameterCount = method.getParameterCount();
                if (parameterCount > 0) {
                    throw new IllegalStateException("RedisTask method parameter count: "
                            + parameterCount + ", expected: 0");
                }
                if (task.interval() <= 0) {
                    // Fail fast here instead of inside scheduleAtFixedRate at start time.
                    throw new IllegalStateException("RedisTask interval must be > 0, got: "
                            + task.interval() + " on " + method);
                }
                TimedTask timedTask = new AnnoTimedTask(task, method, bean, this.redisTemplate);
                tasks.add(timedTask);
                logger.info("registerAnnotationTask > taskName: {}", timedTask.getTaskName());
            }
        }
    }

    /**
     * Adapts a {@link RedisTask}-annotated bean method to {@link TimedTask}.
     */
    private static class AnnoTimedTask extends AbstractTimedTask {

        private final RedisTask redisTask;
        private final Method method;
        private final Object bean;

        public AnnoTimedTask(RedisTask redisTask, Method method, Object bean, StringRedisTemplate redisTemplate) {
            super(redisTemplate);
            this.redisTask = redisTask;
            this.method = method;
            this.bean = bean;
        }

        /**
         * @return the annotation's value, or "userClassName:methodName" when it is empty
         */
        @Override
        public String getTaskName() {
            String taskName = redisTask.value();
            if (StringUtils.hasLength(taskName)) {
                return taskName;
            }
            // Use the user class rather than the runtime class: a CGLIB proxy's generated class
            // name may differ between JVMs. Each instance would then use its own Redis key.
            return ClassUtils.getUserClass(bean).getName() + ":" + method.getName();
        }

        @Override
        public long getInterval() {
            return redisTask.interval();
        }

        @Override
        public long getInitialDelay() {
            return redisTask.initialDelay();
        }

        @Override
        public TimeUnit getTimeUnit() {
            return redisTask.timeUnit();
        }

        /**
         * Invokes the annotated method on the bean; failures are logged, not rethrown.
         */
        @Override
        public void runOnce() {
            try {
                method.invoke(bean);
            } catch (InvocationTargetException e) {
                // Log what the task method itself threw, not the reflective wrapper.
                logger.error("Run Redis TimedTask failed", e.getCause());
            } catch (Exception e) {
                logger.error("Run Redis TimedTask failed", e);
            }
        }
    }

}

------

其中获取本机 IP 使用了工具类 InetUtil, 如下

package com.idanchuang.component.core.util;

import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.NetworkInterface;
import java.net.SocketException;
import java.util.Enumeration;

/**
 * Network helpers.
 *
 * @author yjy
 * Created at 2021/3/13 6:25 PM
 */
public class InetUtil {

    /** Fallback returned when no suitable IPv4 address is found. */
    private static final String LOOPBACK_IP = "127.0.0.1";

    /**
     * Returns the first usable IPv4 address of this host.
     *
     * <p>The following interfaces are skipped: those that are down, loopback, point-to-point, or
     * whose name starts with {@code docker}. Loopback and link-local (169.254.x.x) addresses are
     * skipped too.
     *
     * @return a dotted-quad IPv4 address, or {@code "127.0.0.1"} if none qualifies
     * @throws IllegalStateException if the network interfaces cannot be queried
     */
    public static String getLocalIp() {
        try {
            Enumeration<NetworkInterface> nifs = NetworkInterface.getNetworkInterfaces();
            // Older JDKs (e.g. 8) return null when the host has no interfaces.
            if (nifs == null) {
                return LOOPBACK_IP;
            }
            while (nifs.hasMoreElements()) {
                NetworkInterface nif = nifs.nextElement();
                if (!isCandidate(nif)) {
                    continue;
                }
                Enumeration<InetAddress> addresses = nif.getInetAddresses();
                while (addresses.hasMoreElements()) {
                    InetAddress addr = addresses.nextElement();
                    if (addr instanceof Inet4Address
                            && !addr.isLoopbackAddress()
                            && !addr.isLinkLocalAddress()) {
                        return addr.getHostAddress();
                    }
                }
            }
            return LOOPBACK_IP;
        } catch (SocketException e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * @return true if addresses of this interface may be used as the host IP
     */
    private static boolean isCandidate(NetworkInterface nif) throws SocketException {
        return nif.isUp()
                && !nif.isLoopback()
                && !nif.isPointToPoint()
                && !nif.getName().startsWith("docker");
    }

}
InetUtil

 

posted @ 2022-01-23 21:35  EEEEET  阅读(185)  评论(0)  编辑  收藏  举报