Redis+lua脚本配合AOP限流

限流

Redis脚本限流

脚本配合切面注解

定义注解:

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RateLimiter
{
    /**
     * 限流key
     */
    public String key() default CacheConstants.RATE_LIMIT_KEY;

    /**
     * 限流时间,单位秒
     */
    public int time() default 60;

    /**
     * 限流次数
     */
    public int count() default 100;

    /**
     * 限流类型
     */
    public LimitType limitType() default LimitType.DEFAULT;
}

定义切面:

@Aspect
@Component
public class RateLimiterAspect
{
    private static final Logger log = LoggerFactory.getLogger(RateLimiterAspect.class);

    private RedisTemplate<Object, Object> redisTemplate;

    private RedisScript<Long> limitScript;

    @Autowired
    public void setRedisTemplate1(RedisTemplate<Object, Object> redisTemplate)
    {
        this.redisTemplate = redisTemplate;
    }

    @Autowired
    public void setLimitScript(RedisScript<Long> limitScript)
    {
        this.limitScript = limitScript;
    }

    @Before("@annotation(rateLimiter)")
    public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable
    {
        int time = rateLimiter.time();
        int count = rateLimiter.count();

        String combineKey = getCombineKey(rateLimiter, point);
        List<Object> keys = Collections.singletonList(combineKey);
        try
        {
            Long number = redisTemplate.execute(limitScript, keys, count, time);
            if (StringUtils.isNull(number) || number.intValue() > count)
            {
                throw new ServiceException("访问过于频繁,请稍候再试");
            }
            log.info("限制请求'{}',当前请求'{}',缓存key'{}'", count, number.intValue(), combineKey);
        }
        catch (ServiceException e)
        {
            throw e;
        }
        catch (Exception e)
        {
            throw new RuntimeException("服务器限流异常,请稍候再试");
        }
    }

    public String getCombineKey(RateLimiter rateLimiter, JoinPoint point)
    {
        StringBuffer stringBuffer = new StringBuffer(rateLimiter.key());
        if (rateLimiter.limitType() == LimitType.IP)
        {
            stringBuffer.append(IpUtils.getIpAddr()).append("-");
        }
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        Class<?> targetClass = method.getDeclaringClass();
        stringBuffer.append(targetClass.getName()).append("-").append(method.getName());
        return stringBuffer.toString();
    }
}

在Redis配置类注入Redis脚本:

@Configuration
@EnableCaching
public class RedisConfig extends CachingConfigurerSupport
{
    @Bean
    @SuppressWarnings(value = { "unchecked", "rawtypes" })
    public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory connectionFactory)
    {
        RedisTemplate<Object, Object> template = new RedisTemplate<>();
        template.setConnectionFactory(connectionFactory);

        FastJson2JsonRedisSerializer serializer = new FastJson2JsonRedisSerializer(Object.class);

        // 使用StringRedisSerializer来序列化和反序列化redis的key值
        template.setKeySerializer(new StringRedisSerializer());
        template.setValueSerializer(serializer);

        // Hash的key也采用StringRedisSerializer的序列化方式
        template.setHashKeySerializer(new StringRedisSerializer());
        template.setHashValueSerializer(serializer);

        template.afterPropertiesSet();
        return template;
    }

    @Bean
    public DefaultRedisScript<Long> limitScript()
    {
        DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>();
        redisScript.setScriptText(limitScriptText());
        redisScript.setResultType(Long.class);
        return redisScript;
    }

    /**
     * 限流脚本
     */
    private String limitScriptText()
    {
        return "local key = KEYS[1]\n" +
                "local count = tonumber(ARGV[1])\n" +
                "local time = tonumber(ARGV[2])\n" +
                "local current = redis.call('get', key);\n" +
                "if current and tonumber(current) > count then\n" +
                "    return tonumber(current);\n" +
                "end\n" +
                "current = redis.call('incr', key)\n" +
                "if tonumber(current) == 1 then\n" +
                "    redis.call('expire', key, time)\n" +
                "end\n" +
                "return tonumber(current);";
    }
}

脚本解释:

在Redis中的execute方法最终还是通过调用eval方法执行的,可以通过源码查看:

image

在Redis客户端使用eval命令:

image

将换行符去掉之后的Redis脚本是这样的:

		local key = KEYS[1] 
        local count = tonumber(ARGV[1]) 
        local time = tonumber(ARGV[2]) 
        local current = redis.call('get', key); 
        if current and tonumber(current) > count then 
            return tonumber(current); 
        end 
        current = redis.call('incr', key) 
        if tonumber(current) == 1 then 
            redis.call('expire', key, time) 
        end 
        return tonumber(current);

根据传递的参数去理解lua脚本: Long number = redisTemplate.execute(limitScript, keys, count, time);

首先local key = KEYS[1] 获取的是我们存在Redis中的key键;local count = tonumber(ARGV[1])拿到第一个参数;local current = redis.call('get', key); 调用Redis的get方法获取值,当前已经存在多少请求存在Redis内存数据库中;

判断如果当前大于指定的count,返回一个Number,根据数据去决定抛出异常;否则说明没有达到限流阈值,调用Redis的incr方法将键值自增1;如果值为1,将键设置过期时间


我感觉下面这个脚本更好,使用zremrangebyscore命令移除滑动窗口之外的数据

-- 获取zset的key
local key = KEYS[1]

-- 脚本传入的限流大小
local limit = tonumber(ARGV[1])

-- 脚本传入的限流起始时间戳
local start = tonumber(ARGV[2])

-- 脚本传入的限流当前时间戳
local now = tonumber(ARGV[3])

-- 脚本传入的限流当前时间戳
local uuid = ARGV[4]

-- 获取当前流量总数
local count = tonumber(redis.call('zcount',key, start, now))

--是否超出限流值
if count + 1 >limit then
    return false
-- 不需要限流
else
    -- 添加当前访问时间戳到zset
    redis.call('zadd', key, now, uuid)
    -- 移除时间区间以外不用的数据,不然会导致zset过大
    redis.call('zremrangebyscore',key, 0, start)
    return true
end

根据文件路径注入脚本

@Configuration
public class rateConfig {

    @Bean
    public RedisScript<Boolean> loadRedisScript(){
        DefaultRedisScript<Boolean> redisScript = new DefaultRedisScript<>();
        //lua脚本路径
        redisScript.setLocation(new ClassPathResource("luaScript/rateLimit.lua"));
        //lua脚本返回值
        redisScript.setResultType(java.lang.Boolean.class);
        return redisScript;
    }

}

切面:

@Aspect
@Component
public class ratelimitAspect {

    @Autowired
    private StringRedisTemplate stringRedisTemplate;

    @Autowired
    private RedisScript<Boolean> rateLimitScript;


    @Pointcut("@annotation(com.dk.aspect.rateLimit)")
    public void pointCut(){}

    @Before("pointCut() && @annotation(rateLimit)")
    public void before(rateLimit rateLimit) throws Throwable {
        //注解上的参数信息
        int limit = rateLimit.limit();
        String name = rateLimit.rateName();
        //当前时间戳
        long now = System.currentTimeMillis();
        //调用lua脚本获取限流结果
        Boolean isAccess = stringRedisTemplate.execute(
                //lua限流脚本
                rateLimitScript,
                //限流资源名称
                Collections.singletonList(name),
                //限流大小
                String.valueOf(limit),
                //限流窗口的左区间
                String.valueOf(now - 1000),
                //限流窗口的左区间
                String.valueOf(now),
                //id值,保证zset集合里面不重复,不然会覆盖
                UUID.randomUUID().toString()
        );

        if (!isAccess){
            throw new rateLimitException();
        }
    }
}

滑动窗口算法

package test.slidewindowlimit;

import redis.clients.jedis.Jedis;
import redis.clients.jedis.Pipeline;
import redis.clients.jedis.Response;

import java.io.IOException;

/**
 * <p>
 * 通过zset实现滑动窗口算法限流
 * </p>
 */
public class SimpleSlidingWindowByZSet {

    private Jedis jedis;

    public SimpleSlidingWindowByZSet(Jedis jedis) {
        this.jedis = jedis;
    }

    /**
     * 判断行为是否被允许
     * 针对用户的某个行为,多少秒内限访问多少次
     * 例如:用户user1在60秒内只能访问某个接口的10次
     * 实现思路:使用zset,将时间戳作为score,请求次数作为value,然后根据时间戳进行排序
     *
     * @param userId    用户id
     * @param actionKey 行为key
     * @param period    限流周期
     * @param maxCount  最大请求次数(滑动窗口大小)
     * @return
     */
    public boolean isActionAllowed(String userId, String actionKey, int period, int maxCount) throws IOException {
        String key = this.key(userId, actionKey);
        long ts = System.currentTimeMillis();
        Pipeline pipe = jedis.pipelined();
        pipe.multi();
        // zadd(key, score, value)
        pipe.zadd(key, ts, String.valueOf(ts));
        // 移除滑动窗口之外的数据
        pipe.zremrangeByScore(key, 0, ts - (period * 1000));
        Response<Long> count = pipe.zcard(key);// zcard用于计算集合中的数量
        // 设置行为的过期时间,如果数据为冷数据,zset将会删除以此节省内存空间
        pipe.expire(key, period);
        pipe.exec();
        pipe.close();
        return count.get() <= maxCount;
    }


    /**
     * 限流key
     *
     * @param userId
     * @param actionKey
     * @return
     */
    public String key(String userId, String actionKey) {
        return String.format("limit:%s:%s", userId, actionKey);
    }

}

测试:

public class TestSimpleSlidingWindowByZSet {

    public static void main(String[] args) throws IOException {
        Jedis jedis = new Jedis("127.0.0.1", 6379);
        SimpleSlidingWindowByZSet slidingWindow = new SimpleSlidingWindowByZSet(jedis);
        for (int i = 1; i <= 15; i++) {
            boolean actionAllowed = slidingWindow.isActionAllowed("liziba", "view", 60, 100);

            System.out.println("第" + i + "次操作" + (actionAllowed ? "成功" : "失败"));
        }

        jedis.close();
    }

}

扩展

pipeline 只是把多个redis指令一起发出去,redis并没有保证这些指定的执行是原子的;multi相当于一个redis的transaction的,保证整个操作的原子性,避免由于中途出错而导致最后产生的数据不一致。通过测试得知,pipeline方式执行效率要比其他方式高10倍左右的速度,启用multi写入要比没有开启慢一点

标记一个事务块的开始。 事务块内的多条命令会按照先后顺序被放进一个队列当中,最后由 EXEC 命令原子性(atomic)地执行。

而pipeline:客户端将执行的命令写入到缓冲中,最后由exec命令一次性发送给redis执行返回。

1.muti 2.执行命令 3.exec

需要注意的是 pipeline中的命令并不是原子性执行的也就是说管道中的命令到达 Redis服务器的时候可能会被其他的命令穿插

image

image

Redis事务失败的原因是加入了一条错误的命令incr hash 3;而不是incr noexistkey;因为Redis中不存在该key时,会默认从0开始,加1就得到1,而hash不是一个可自增的字段,所以是错误命令,导致整个事务失败进行回滚,之后使用get name操作得到的值仍然是之前的13

  1. multi不允许嵌套
  2. multi事务执行报错了不会继续执行,回滚
    使用discard可以丢弃事务,在事务开启前添加watch命令观察变量,如果变量在事务中被改变了事务回滚

Redis分布式锁—SETNX+Lua脚本实现篇 - niceyoo - 博客园 (cnblogs.com)

posted @   好滴都  阅读(68)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
点击右上角即可分享
微信分享提示