分布式限流 - 令牌桶

redis+lua方案

单机版使用guava

lua脚本,参照spring cloud gateway

src/main/resources/request_rate_limiter.lua

-- 令牌桶在redis中的key值
local tokens_key = KEYS[1]
-- 该令牌桶上一次刷新的时间对应的key的值
local timestamp_key = KEYS[2]
-- 令牌单位时间填充速率
local rate = tonumber(ARGV[1])
-- 令牌桶容量
local capacity = tonumber(ARGV[2])
-- 当前时间
local now = tonumber(ARGV[3])
-- 请求需要的令牌数
local requested = tonumber(ARGV[4])
-- 令牌桶容量/令牌填充速率=令牌桶填满所需的时间
local fill_time = capacity/rate
-- 令牌过期时间 填充时间*2
local ttl = math.floor(fill_time*2)
-- 获取上一次令牌桶剩余的令牌数
local last_tokens = tonumber(redis.call("get", tokens_key))
-- 如果没有获取到,可能是令牌桶是新的,之前不存在该令牌桶,或者该令牌桶已经好久没有使用
-- 过期了,这里需要对令牌桶进行初始化,初始情况,令牌桶是满的
if last_tokens == nil then
  last_tokens = capacity
end
-- 获取上一次刷新的时间,如果没有,或者已经过期,那么初始化为0
local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
  last_refreshed = 0
end
-- 计算上一次刷新时间和本次刷新时间的时间差
local delta = math.max(0, now-last_refreshed)
-- delta*rate = 这个时间差可以填充的令牌数,
-- 令牌桶中先存在的令牌数 = 填充令牌数+令牌桶中原有的令牌数
-- 因为令牌桶有容量,所以如果计算的值大于令牌桶容量,那么以令牌桶容量为准
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
-- 判断令牌桶中的令牌数是都满本次请求需要的令牌数,如果不满足,说明被限流了
local allowed = filled_tokens >= requested

-- 这里声明了两个变量,一个是新的令牌数,一个是是否被限流,0代表限流,1代表没有限流
local new_tokens = filled_tokens
local allowed_num = 0
-- 如果没有被限流,即,filled_tokens >= requested,
-- 新的令牌数=刚刚计算好的令牌桶中存在的令牌数减掉本次需要使用的令牌数
-- 并设置限流结果为未限流
if allowed then
  new_tokens = filled_tokens - requested
  allowed_num = 1
end
-- 存储本次操作后,令牌桶中的令牌数以及本次刷新时间
if ttl > 0 then
  redis.call("setex", tokens_key, ttl, new_tokens)
  redis.call("setex", timestamp_key, ttl, now)
end
-- 返回是否被限流标志以及令牌桶剩余令牌数
return { allowed_num, new_tokens }

加载到java环境

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.scripting.support.ResourceScriptSource;

import java.util.List;

/**
 * @author YaLong
 * @date 2022/9/15 18:07
 */
@Configuration
public class RedisLuaScript {

    @Bean
    @SuppressWarnings("all")
    public RedisScript redisRateLimiterLua() {
        DefaultRedisScript redisRequestRateLimiterScript = new DefaultRedisScript<>();
        redisRequestRateLimiterScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("request_rate_limiter.lua")));
        //返回结果类型要和lua脚本的对应
        redisRequestRateLimiterScript.setResultType(List.class);
        return redisRequestRateLimiterScript;
    }

}

构造一个易用的实体类


import lombok.Data;

import java.time.Instant;

/**
 * @author YaLong
 */
@Data
public class RedisRateLimiter {

    public RedisRateLimiter(String key, double replenishRate, int burstCapacity) {
        this.key = key;
        this.replenishRate = replenishRate;
        this.burstCapacity = burstCapacity;
    }

    /**
     * 限流的键
     */
    String key;

    /**
     * 每秒生产令牌数量
     */
    double replenishRate;

    /**
     * 令牌桶容量
     */
    int burstCapacity;

    /**
     * 每次请求消耗令牌数量
     */
    double requestedTokens = 1;

    /**
     * 注意每次都要调用,不可使用同一个返回值
     */
    public String[] toArgs() {
        return new String[]{
                String.valueOf(this.replenishRate),
                String.valueOf(this.burstCapacity),
                String.valueOf(Instant.now().getEpochSecond()),
                String.valueOf(this.requestedTokens)
        };
    }
}

创建一个工具类

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;

import java.util.Arrays;
import java.util.List;

/**
 * @author YaLong
 * @date 2022/9/15 18:12
 */
@Component
public class RedisRateLimiterUtil {
    private static RedisScript redisLuaScript;
    private static StringRedisTemplate redisTemplate;

    @Autowired
    public void setRedisLuaScript(RedisScript redisLuaScript) {
        RedisRateLimiterUtil.redisLuaScript = redisLuaScript;
    }

    @Autowired
    public void setRedisTemplate(StringRedisTemplate redisTemplate) {
        RedisRateLimiterUtil.redisTemplate = redisTemplate;
    }

    private static List<String> getKeys(String id) {
        // use `{}` around keys to use Redis Key hash tags
        // this allows for using redis cluster

        // Make a unique key per user.
        String prefix = "request_rate_limiter.{" + id;

        // You need two Redis keys for Token Bucket.
        String tokenKey = prefix + "}.tokens";
        String timestampKey = prefix + "}.timestamp";
        return Arrays.asList(tokenKey, timestampKey);
    }

    /**
     * 获取频率
     *
     * @param limiter 参数
     * @return 第一个参数:0->超过频率 1-> 未超过频率
     * 第二个参数:剩余令牌数量
     */
    @SuppressWarnings("all")
    public static List<Long> getRate(RedisRateLimiter limiter) {
        String[] args = limiter.toArgs();
        List<String> keys = getKeys(limiter.getKey());
        return (List<Long>) RedisRateLimiterUtil.redisTemplate.execute(RedisRateLimiterUtil.redisLuaScript, keys, args);
    }


}

使用

/**
 * @author YaLong
 * @date 2022/9/15 16:56
 */
@RunWith(SpringRunner.class)
@SpringBootTest(classes = MaiKeSpApplication.class)
public class Test {

    private void aa() {
        RedisRateLimiter redisRateLimiter = new RedisRateLimiter("fafsa", 0.5, 5);
        List<Long> results = RedisRateLimiterUtil.getRate(redisRateLimiter);
        if (results.get(0) == 0L) {
            System.out.println("Too Many Requests!");
        }else {
            Long tokensLeft = results.get(1);
            System.out.println("tokens left ->" + tokensLeft);
        }

    }

    @org.junit.Test
    public void tt() throws InterruptedException {
        while (true) {
            aa();
            System.out.println("-------------");
            TimeUnit.MILLISECONDS.sleep(1000);
        }

    }

}

posted @ 2022-09-15 19:38  rm-rf*  阅读(286)  评论(0编辑  收藏  举报