不知道ZSet(有序集合)的看官们,可以翻阅我的上一篇文章:

小白也能看懂的REDIS教学基础篇——朋友面试被SKIPLIST跳跃表拦住了

 

       书接上回,话说我朋友小A童鞋,终于面世通过加入了一家公司。这个公司待遇比较丰厚,而且离小A住的地方也比较近,最让小A中意还是有个肯带他的大佬。小A对这份工作非常满意。时间一天一天过去,某个周末,小A来找我家吃蹭饭.在饭桌上小A给我分享了他上星期的一次事故经历。

        上个星期,他们公司出了比较严重的一个事故,一个导出报表的后台服务拖垮了报表数据服务,导致很多查询该服务的业务都受到了牵连。主要原因是因为导出的报表数据比较多,导致导出时间比较漫长。前端也未针对导出按钮做防重试限制。多个运营人员多次点击了导出按钮。加上后台服务配置的重试机制把这个流量放大了好几倍,最后拖垮了整个报表数据服务。老大让小A出个集群限流的方案防止下次在出现这类问题。

        这下小A慌了,单机限流好搞,使用 Hystrix 框架注解一加就完事。或者使用 Sentinel 在 Sentinel Dashboard 后台配置一下就完事。集群限流要怎么弄?小A苦思冥想了一个上午也没整出来,最后只能求助大佬帮助。

 

小A:大佬,老大让我出一个集群限流的方案,我以前对这个不熟,网站找的一堆,都是重复相同的感觉不靠谱,能教教我怎么弄吗?

大佬:莫慌莫慌,这件事情其实不难。我先考考你,限流的算法有哪些?

小A:... 我想想

小A:有了,常见的限流算法有以下三种:滑动窗口算法,令牌桶算法,漏桶算法。

大佬:对头,那你觉得单机限流和集群限流有什么区别呢?

小A:emmm...

大佬:你可以从集群和单机程序本身的区别去想想。

小A:我知道了,单机限流限流数据存在单机上,只能一个机器用。而集群是分布式多机器的,要让多个机器共享同一份限流数据,才能保证多机器的限流。

大佬:很好,那你还记得面试时候我问你的Zset(Sorted Sets)吗?用它就能很简单的实现一个滑动时间窗哦。

小A:大佬快教我!

大佬:那话不多说,现在进入正题。

 

如果不知道Zset数据结构的,可以先去看看我的这篇文章

小白也能看懂的REDIS教学基础篇——朋友面试被SKIPLIST跳跃表拦住了

 

首先来看一段实现滑动时间窗的lua代码(这是实现Redis滑动时间窗限流的核心代码)

 

-- 参数:
-- nowTime 当前时间
-- windowTime窗口时间
-- maxCount最大次数
-- expiredWindowTime 已经过期的窗口时间
-- value 请求标记
local nowTime = tonumber(ARGV[1]);
local windowTime = tonumber(ARGV[2]);
local maxCount = tonumber(ARGV[3]);
local expiredWindowTime = tonumber(ARGV[4])
local value = ARGV[5];
local key = KEYS[1];
-- 获取当前窗口的请求标志个数
local count = redis.call('ZCARD', key)
-- 比较当前已经请求的数量是否大于窗口最大请求数
if count >= maxCount then
    -- 如果大于最大请求数
    -- 删除过期的请求标志 释放窗口空间 等同于滑动时间窗口向前滑动
    redis.call('ZREMRANGEBYSCORE', key, 0, expiredWindowTime)
    -- 再次获取当前窗口的请求标志个数
    local count = redis.call('ZCARD', key)
    -- 延长过期时间
    redis.call('PEXPIRE', key, windowTime + 1000)
    -- 比较释放后的大小 是否小于窗口最大请求数
    if count < maxCount then
         -- 插入当前访问的访问标记
        redis.call('ZADD', key, nowTime, value)
        -- 返回200代表成功
        return 200
    else
        -- 返回500代表失败
        return 500
    end
else
     -- 插入当前访问的访问标记
    redis.call('ZADD', key, nowTime, value)
    -- 延长过期时间
    redis.call('PEXPIRE', key, windowTime + 1000)
    -- 返回200代表成功
    return 200
end

 

 

 

Redis接收到请求,开始执行 lua 脚本。根据 Key 找到对应的Zset也就是时间窗。获取当前时间窗内请求标记的数量,如果小于最大窗口允许访问最大次数,直接插入最新的请求标记 设置标记的score = 1642403014820 ms 。延长该窗口的过期时间,并返回成功。如下图所示:

 

 

 

 


如果获取到当前时间窗内请求标记的数量,大于或者等于窗口最大允许请求数。如下图所示,获取到当前时间窗内请求标记的数量为6,大于窗口最大允许请求数5。

 

 

 

 

则先根据时间窗大小删除窗口中已经过期的请求。当前请求的 score = 1642403014820 ms 时间窗口大小的 10000ms。那过期时间就是1642403014820 - 10000 = 1642403004820;那就删除 score < 1642403004820 的节点。

删除完成后,再次获取当前窗口中请求标记数量,可以看到当前数量为1小于窗口最大请求数。插入最新的请求标记 score = 1642403014820 ms 。延长该窗口的过期时间,并返回成功。


大佬:现在是否明白了一些?

小A:明白了,简单总结一下就是,利用Redis中现成的数据结构ZSet(有序集合)来做时间窗口。集合中的排序值为请求发生时的时间戳。在请求发生时,

统计时间窗口中总的请求数,如果总请求数小于窗口允许最大请求数,就插入一个请求标记,也就相当于窗口中请求数加一。如果总请求数大于或者等于窗口允许最大请求数,则需要删除过期的统计,以便释放足够的空间。删除的方式就是先计算出窗口的前边界,也就是已经失效的最大时间。根据这个时间戳,然后利用Zset原生的删除方法 ZREMRANGEBYSCORE key min max 删除小于最大失效时间的请求标记,其实这里的删除过期数据也就等同于滑动时间窗口向前滑动。删除完成后再次统计下窗口中剩余的请求数是否大于或者等于窗口最大请求数,如果大于就直接返回失败,告诉客户端拒绝该请求。如果小于就插入当前请求的请求标记 score 为当前请求的请求时间戳。至此完成了一次限流请求。可是我不明白为什么要使用 lua 脚本呢?这不是增加了维护成本吗?

大佬:不错。基本原理就是这样的,至于为什么使用 lua 脚本。那为了原子的执行多个命令和限流的判断逻辑。防止在你执行删除或者获取总数的命令时,其他人也在执行导致数据不准确,从而使限流失败。

小A:嗯嗯。明白了。

大佬:不过这个限流也存在不足。比如需要设置一个10秒内允许访问100万次的请求,它就不合适,因为这样窗口中会有100万个请求几率,会消耗大量的内存空间。切记!不要盲目使用,要根据自己业务量来综合考量。

小A:好的,大佬,记住了!


 

大佬:了解完这个,接下来我们来看一下JAVA代码怎么写。

首先我们是通过 Spring AOP 和 标记注解 @CurrentLimiting 来实现限流方案的。

    @GetMapping("getId")
    @CurrentLimiting(value = "getId",
            // ErrorCallback 是错误回调 callback 是 bean Name callbackClass 是实现类 Class
            errorCallback = @ErrorCallback(callback = "redisCurrentLimitingDegradeCallbackImpl", callbackClass = RedisCurrentLimitingErrorCallbackImpl.class)
,
            // DegradeCallback 是降级回调 callback 是 bean Name callbackClass 是实现类 Class
            degradeCallback = @DegradeCallback(callback = "redisCurrentLimitingDegradeCallbackImpl", callbackClass = RedisCurrentLimitingDegradeCallbackImpl.class))
    public Integer getId(){
        return 1;
    }

 

下面介绍AOP切面类:


package com.raiden.redis.current.limiter.aop;
​
import com.raiden.redis.current.limiter.RedisCurrentLimiter;
import com.raiden.redis.current.limiter.annotation.CurrentLimiting;
import com.raiden.redis.current.limiter.annotation.DegradeCallback;
import com.raiden.redis.current.limiter.annotation.ErrorCallback;
import com.raiden.redis.current.limiter.callbock.RedisCurrentLimitingDegradeCallback;
import com.raiden.redis.current.limiter.chain.ErrorCallbackChain;
import com.raiden.redis.current.limiter.info.RedisCurrentLimiterInfo;
import com.raiden.redis.current.limiter.properties.RedisCurrentLimiterProperties;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.context.ApplicationContext;
import org.springframework.core.annotation.AnnotationUtils;
​
import java.lang.reflect.Method;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
​
​
/**
 * @创建人:Raiden
 * @Descriotion:
 * @Date:Created in 23:51 2020/8/27
 * @Modified By:
 * 限流AOP切面类
 */
@Aspect
public class RedisCurrentLimitingAspect {
​
    private Map<String, RedisCurrentLimiterInfo> config;
    private ApplicationContext context;
​
    private ConcurrentHashMap<Method, ErrorCallbackChain> errorCallbackChainCache;
​
    private ConcurrentHashMap<Method, RedisCurrentLimitingDegradeCallback> degradeCallbackCache;
​
    public RedisCurrentLimitingAspect(ApplicationContext context, RedisCurrentLimiterProperties properties){
        this.context = context;
        this.config = properties.getConfig();
        this.errorCallbackChainCache = new ConcurrentHashMap<>();
        this.degradeCallbackCache = new ConcurrentHashMap<>();
    }
​
    @Pointcut("@annotation(com.raiden.redis.current.limiter.annotation.CurrentLimiting) || @within(com.raiden.redis.current.limiter.annotation.CurrentLimiting)")
    public void intercept(){}
​
    @Around("intercept()")
    public Object currentLimitingHandle(ProceedingJoinPoint joinPoint) throws Throwable{
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method method = signature.getMethod();
        CurrentLimiting annotation = AnnotationUtils.findAnnotation(method, CurrentLimiting.class);
        if (annotation == null){
            annotation = method.getDeclaringClass().getAnnotation(CurrentLimiting.class);
        }
        String path = annotation.value();
        //如果没有配置 资源 直接 放过
        //如果没有找到限流配置 也放过
        RedisCurrentLimiterInfo info;
        if (path != null && !path.isEmpty() && (info = config.get(path)) != null){
            try {
                //查看是否需要限流
                boolean allowAccess = RedisCurrentLimiter.isAllowAccess(path, info.getWindowTime(), info.getWindowTimeUnit(), info.getMaxCount());
                if (allowAccess){
                    return joinPoint.proceed();
                }else {
                    //获取降级处理器
                    RedisCurrentLimitingDegradeCallback currentLimitingDegradeCallback = degradeCallbackCache.get(method);
                    if (currentLimitingDegradeCallback == null){
                        degradeCallbackCache.putIfAbsent(method, getRedisCurrentLimitingDegradeCallback(annotation));
                    }
                    currentLimitingDegradeCallback = degradeCallbackCache.get(method);
                    //调用降级回调
                    return currentLimitingDegradeCallback.callback();
                }
            }catch (Throwable e){
                //如果报错 交给 错误回调
                ErrorCallbackChain errorCallbackChain = errorCallbackChainCache.get(method);
                if (errorCallbackChain == null){
                    ErrorCallback[] errorCallbacks = annotation.errorCallback();
                    if (errorCallbacks.length == 0){
                        throw e;
                    }
                    //放入错误回调缓存
                    errorCallbackChainCache.putIfAbsent(method, new ErrorCallbackChain(errorCallbacks, context));
                }
                errorCallbackChain = errorCallbackChainCache.get(method);
                return errorCallbackChain.execute(e);
            }
        }
        return joinPoint.proceed();
    }
​
    private RedisCurrentLimitingDegradeCallback getRedisCurrentLimitingDegradeCallback(CurrentLimiting annotation) throws IllegalAccessException, InstantiationException {
        DegradeCallback degradeCallback = annotation.degradeCallback();
        String callback = degradeCallback.callback();
        if (callback == null || callback.isEmpty()){
            return degradeCallback.callbackClass().newInstance();
        }else {
            return context.getBean(degradeCallback.callback(), degradeCallback.callbackClass());
        }
    }
}
RedisCurrentLimiter Redis时间窗限流执行类:

package com.raiden.redis.current.limiter;
​
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.scripting.support.ResourceScriptSource;
​
import java.net.Inet4Address;
import java.net.UnknownHostException;
import java.util.Arrays;
import java.util.concurrent.TimeUnit;
​
​
/**
 * @创建人:Raiden
 * @Descriotion:
 * @Date:Created in 23:51 2020/8/27
 * @Modified By:
 */
public final class RedisCurrentLimiter {
​
    private static final String CURRENT_LIMITER = "CurrentLimiter:";;
​
    private static String ip;;
​
    private static RedisTemplate redis;
​
    private static ResourceScriptSource resourceScriptSource;
​
    protected static void init(RedisTemplate redisTemplate){
        if (redisTemplate == null){
            throw new NullPointerException("The parameter cannot be null");
        }
        try {
            ip = Inet4Address.getLocalHost().getHostAddress().replaceAll("\\.", "");
        } catch (UnknownHostException e) {
            throw new RuntimeException(e);
        }
        redis = redisTemplate;
        //lua文件存放在resources目录下的redis文件夹内
        resourceScriptSource = new ResourceScriptSource(new ClassPathResource("redis/redis-current-limiter.lua"));
    }
​
    public static boolean isAllowAccess(String path, int windowTime, TimeUnit windowTimeUnit, int maxCount){
        if (redis == null){
            throw new NullPointerException("Redis is not initialized !");
        }
        if (path == null || path.isEmpty()){
            throw new IllegalArgumentException("The path parameter cannot be empty !");
        }
        //获取 key
        final String key = new StringBuffer(CURRENT_LIMITER).append(path).toString();
        //获取当前时间戳
        long now = System.currentTimeMillis();
        //获取窗口时间 并转换为 毫秒
        long windowTimeMillis = windowTimeUnit.toMillis(windowTime);
        //调用lua脚本并执行
        DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>();
        //设置返回类型是Long
        redisScript.setResultType(Long.class);
        //设置 lua 脚本源代码
        redisScript.setScriptSource(resourceScriptSource);
        //执行 lua 脚本
        Long result = (Long) redis.execute(redisScript, Arrays.asList(key), now, windowTimeMillis, maxCount, now - windowTimeMillis, createValue(now));
        //获取到返回值
        return result.intValue() == 200;
    }
​
    private static String createValue(long now){
        return new StringBuilder(ip).append(now).append(Math.random() * 100).toString();
    }
}
RedisCurrentLimiterConfiguration Redis滑动时间窗限流配置类:

package com.raiden.redis.current.limiter.config;
​
import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import com.fasterxml.jackson.module.paramnames.ParameterNamesModule;
import com.raiden.redis.current.limiter.RedisCurrentLimiterInit;
import com.raiden.redis.current.limiter.aop.RedisCurrentLimitingAspect;
import com.raiden.redis.current.limiter.common.PublicString;
import com.raiden.redis.current.limiter.properties.RedisCurrentLimiterProperties;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Primary;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.RedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
​
/**
 * @创建人:Raiden
 * @Descriotion:
 * @Date:Created in 23:51 2020/8/27
 * @Modified By:
 */
@EnableConfigurationProperties(RedisCurrentLimiterProperties.class)
// 判断 配置中是否 有 redis.current-limiter.enabled = true 才加载
@ConditionalOnProperty(
        name = {"redis.current-limiter.enabled"}
)
public class RedisCurrentLimiterConfiguration {
​
    /**
     * AOP 切面配置
     * @param properties
     * @param context
     * @return
     */
    @Bean
    public RedisCurrentLimitingAspect redisCurrentLimitingAspect(RedisCurrentLimiterProperties properties, ApplicationContext context){
        return new RedisCurrentLimitingAspect(context, properties);
    }
​
    /**
     * RedisTemplate配置
     * @param redisConnectionFactory
     * @return
     */
    @Bean(PublicString.REDIS_CURRENT_LIMITER_REDIS_TEMPLATE)
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
        // 设置序列化
        Jackson2JsonRedisSerializer<Object> jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer<Object>(
                Object.class);
        ObjectMapper om = new ObjectMapper()
                .registerModule(new ParameterNamesModule())
                .registerModule(new Jdk8Module())
                .registerModule(new JavaTimeModule());
        om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        om.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);
        jackson2JsonRedisSerializer.setObjectMapper(om);
        // 配置redisTemplate
        RedisTemplate<String, Object> redisTemplate = new RedisTemplate<String, Object>();
        redisTemplate.setConnectionFactory(redisConnectionFactory);
        RedisSerializer stringSerializer = new StringRedisSerializer();
        // key序列化
        redisTemplate.setKeySerializer(stringSerializer);
        // value序列化
        redisTemplate.setValueSerializer(jackson2JsonRedisSerializer);
        // Hash key序列化
        redisTemplate.setHashKeySerializer(stringSerializer);
        // Hash value序列化
        redisTemplate.setHashValueSerializer(jackson2JsonRedisSerializer);
        redisTemplate.afterPropertiesSet();
        return redisTemplate;
    }
​
    /**
     * 生成 Redis 滑动时间窗限流器 初始化类
     * @param redisTemplate
     * @return
     */
    @Bean
    @ConditionalOnBean(name = PublicString.REDIS_CURRENT_LIMITER_REDIS_TEMPLATE)
    public RedisCurrentLimiterInit redisCurrentLimiterInit(RedisTemplate<String, Object> redisTemplate){
        return new RedisCurrentLimiterInit(redisTemplate);
    }
}

大佬:好了,Java代码实战也带你看过了,现在你学废了吗?如果觉得好的话请点个赞哟。

 

代码github地址:

https://github.com/RaidenXin/redis-current-limiter.git

有需要的同学可以拉下来看看。