Springboot基于Guava+自定义注解实现IP或自定义key限流 升级版

Springboot基于Guava+自定义注解实现IP或自定义key限流 升级版

2020年5月17日 凌晨 有人恶意刷接口,刚喝完酒回来 大晚上的给我搞事情。。。。

之前版本《Springboot基于Guava+自定义注解实现限流功能》是对访问某个接口的所有请求做总的QPS限制。如果我们想针对某一个用户或IP地址单独限制其访问接口的QPS,从而只限制恶意请求接口的人而不影响正常用户的访问,就需要按IP或自定义key分别限流,本文给出这一升级版的实现。

实现步骤

1、添加POM依赖
<!-- Spring MVC web starter: needed for the REST controller and servlet request access -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-web</artifactId>
</dependency>

<!-- AOP support: needed for the @Aspect based rate-limit interceptor -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-aop</artifactId>
</dependency>

<!-- Lombok: used for the @Slf4j logger annotation on the aspect -->
<dependency>
    <groupId>org.projectlombok</groupId>
    <artifactId>lombok</artifactId>
    <optional>true</optional>
</dependency>

<!-- Guava: provides RateLimiter and the LoadingCache that stores one limiter per key -->
<!-- NOTE(review): 29.0-jre dates from 2020; check whether a newer Guava release should be used -->
<dependency>
    <groupId>com.google.guava</groupId>
    <artifactId>guava</artifactId>
    <version>29.0-jre</version>
</dependency>
2、定义注解
package com.example.guavalimit.limit;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Marks a method as rate limited.
 *
 * <p>The rate-limit aspect intercepts methods carrying this annotation and builds a
 * limiter key from {@link #name()}, the per-request key selected by
 * {@link #limitKeyType()}, and {@link #perSecond()}.
 *
 * @author jie.zhao
 * @since 2020/5/17
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface LxRateLimit {

    /**
     * Resource name; used as the first segment of the limiter key.
     */
    String name() default "默认资源";

    /**
     * Permitted requests per second for a single key. Defaults to 3.
     */
    double perSecond() default 3.0;

    /**
     * How the per-request key is chosen.
     *
     * <p>With {@link LimitKeyTypeEnum#CUSTOM} the request must carry a
     * {@code String limitKeyValue} parameter holding the business-unique code.
     */
    LimitKeyTypeEnum limitKeyType() default LimitKeyTypeEnum.IPADDR;

}
3、定义切面
package com.example.guavalimit.limit;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.util.concurrent.RateLimiter;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
import java.util.concurrent.TimeUnit;

/**
 * Rate-limiting aspect that keeps one Guava {@link RateLimiter} per cache key.
 *
 * <p>Every method annotated with {@link LxRateLimit} is intercepted. A key is derived from
 * the annotation and the current HTTP request, the matching limiter is looked up (or
 * created) in the cache, and the call is rejected with {@link LimitAccessException} when
 * no permit is immediately available.
 *
 * @author jie.zhao
 * @since 2020/5/17
 */
@Slf4j
@Aspect
@Component
public class LxRateLimitAspect {

    /**
     * Limiter cache, keyed by the string produced by
     * {@link LxRateLimitUtil#generateCacheKey(Method, HttpServletRequest)}.
     *
     * <ul>
     *   <li>{@code maximumSize}: at most 1000 limiters are retained.</li>
     *   <li>{@code expireAfterWrite}: a limiter is dropped one day after it was created.</li>
     * </ul>
     *
     * The loader reads the permitted rate back out of the key itself.
     */
    private static final LoadingCache<String, RateLimiter> limitCaches = CacheBuilder.newBuilder()
            .maximumSize(1000)
            .expireAfterWrite(1, TimeUnit.DAYS)
            .build(new CacheLoader<String, RateLimiter>() {
                @Override
                public RateLimiter load(String key) throws Exception {
                    double perSecond = LxRateLimitUtil.getCacheKeyPerSecond(key);
                    return RateLimiter.create(perSecond);
                }
            });

    /**
     * Pointcut matching every method annotated with {@link LxRateLimit}.
     *
     * <p>A package-based alternative would be an {@code execution(...)} expression such as
     * {@code execution(public * com.ycn.springcloud.*.*(..))}.
     */
    @Pointcut("@annotation(com.example.guavalimit.limit.LxRateLimit)")
    public void pointcut() {
    }

    /**
     * Applies the rate limit before delegating to the intercepted method.
     *
     * @param point the intercepted invocation
     * @return whatever the intercepted method returns
     * @throws LimitAccessException if the limiter for this request's key has no permit available
     * @throws Throwable anything thrown by the intercepted method or by the limiter lookup
     */
    @Around("pointcut()")
    public Object around(ProceedingJoinPoint point) throws Throwable {
        log.info("限流拦截到了{}方法...", point.getSignature().getName());
        Method method = ((MethodSignature) point.getSignature()).getMethod();
        if (!method.isAnnotationPresent(LxRateLimit.class)) {
            return point.proceed();
        }

        // getRequestAttributes() returns null when no request is bound to the current thread
        // (e.g. the method was invoked from a scheduled job). The limiter key is derived from
        // the HTTP request, so without one there is nothing to key on: let the call through
        // instead of failing with a NullPointerException / ClassCastException.
        Object attributes = RequestContextHolder.getRequestAttributes();
        if (!(attributes instanceof ServletRequestAttributes)) {
            log.warn("No servlet request bound to the current thread; rate limit skipped for {}",
                    method.getName());
            return point.proceed();
        }
        HttpServletRequest request = ((ServletRequestAttributes) attributes).getRequest();

        String cacheKey = LxRateLimitUtil.generateCacheKey(method, request);
        RateLimiter limiter = limitCaches.get(cacheKey);
        // tryAcquire() does not wait: the request is rejected right away when no permit is free.
        if (!limiter.tryAcquire()) {
            throw new LimitAccessException("【限流】这位小同志的手速太快了");
        }
        return point.proceed();
    }
}

4、枚举
package com.example.guavalimit.limit;

/**
 * Strategies for choosing the per-request key that a rate limit is applied to.
 *
 * @author jie.zhao
 * @since 2020/5/17
 */
public enum LimitKeyTypeEnum {

    /** Limit by the caller's IP address. */
    IPADDR("IPADDR", "根据Ip地址来限制"),

    /**
     * Limit by a business-unique code; the request must carry a
     * {@code String limitKeyValue} parameter.
     */
    CUSTOM("CUSTOM", "自定义根据业务唯一码来限制,需要在请求参数中添加 String limitKeyValue");

    // Enum constants are shared singletons, so their state is kept immutable.
    private final String keyType;
    private final String desc;

    LimitKeyTypeEnum(String keyType, String desc) {
        this.keyType = keyType;
        this.desc = desc;
    }

    /**
     * @return the key type code of this constant
     */
    public String getKeyType() {
        return keyType;
    }

    /**
     * @return the human-readable description of this constant
     */
    public String getDesc() {
        return desc;
    }
}

5、工具类
package com.example.guavalimit.limit;

import org.springframework.util.StringUtils;

import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
import java.net.InetAddress;
import java.net.UnknownHostException;

/**
 * Helpers for building and parsing rate-limit cache keys and for resolving the client IP.
 *
 * <p>Cache key layout: {@code resourceName:businessKey:perSecond}.
 *
 * @author jie.zhao
 * @since 2020/5/17
 */
public class LxRateLimitUtil {

    /**
     * Builds the cache key for a rate-limited method and the current request.
     *
     * <p>Layout: {@code resourceName:businessKey:perSecond}, where the business key is the
     * client IP for {@link LimitKeyTypeEnum#IPADDR} and the {@code limitKeyValue} request
     * parameter for {@link LimitKeyTypeEnum#CUSTOM}.
     *
     * @param method  the intercepted method; must carry {@link LxRateLimit}
     * @param request the current HTTP request
     * @return the cache key
     * @throws LimitAccessException if the key type is CUSTOM and the {@code limitKeyValue}
     *                              parameter is missing or empty
     */
    public static String generateCacheKey(Method method, HttpServletRequest request) {
        LxRateLimit lxRateLimit = method.getAnnotation(LxRateLimit.class);
        StringBuilder cacheKey = new StringBuilder(lxRateLimit.name()).append(':');
        switch (lxRateLimit.limitKeyType()) {
            case IPADDR:
                cacheKey.append(getIpAddr(request)).append(':');
                break;
            case CUSTOM:
                String limitKeyValue = request.getParameter("limitKeyValue");
                if (limitKeyValue == null || limitKeyValue.isEmpty()) {
                    throw new LimitAccessException("【" + method.getName() + "】自定义业务Key缺少参数String limitKeyValue,或者参数为空");
                }
                cacheKey.append(limitKeyValue).append(':');
                break;
            default:
                // Unknown key type: fall back to one shared limiter per resource.
                break;
        }
        cacheKey.append(lxRateLimit.perSecond());
        return cacheKey.toString();
    }

    /**
     * Extracts the permitted requests-per-second from a cache key.
     *
     * <p>The rate is always the last segment of the key. It is located from the right
     * because the earlier segments may themselves contain {@code ':'} (an IPv6 address,
     * or a resource name / custom key with a colon in it).
     *
     * @param cacheKey a key produced by {@link #generateCacheKey(Method, HttpServletRequest)}
     * @return the rate encoded in the key
     * @throws IllegalArgumentException if the key has no {@code ':'} separator or its last
     *                                  segment is not a number
     */
    public static double getCacheKeyPerSecond(String cacheKey) {
        int separator = cacheKey.lastIndexOf(':');
        if (separator < 0) {
            throw new IllegalArgumentException("Malformed rate-limit cache key: " + cacheKey);
        }
        return Double.parseDouble(cacheKey.substring(separator + 1));
    }

    /**
     * Resolves the client IP address of a request.
     *
     * <p>The proxy headers {@code x-forwarded-for}, {@code Proxy-Client-IP} and
     * {@code WL-Proxy-Client-IP} are consulted in that order before falling back to
     * {@link HttpServletRequest#getRemoteAddr()}.
     *
     * <p>NOTE(review): these headers are supplied by the client unless a trusted proxy
     * overwrites them, so a caller can vary them to obtain a fresh limiter on each request.
     * Confirm the deployment sits behind a proxy that sanitizes them.
     *
     * @param request the current HTTP request
     * @return the resolved IP address, or {@code null} if the container reports none
     */
    public static String getIpAddr(HttpServletRequest request) {
        String ip = request.getHeader("x-forwarded-for");
        if (isUnknown(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (isUnknown(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (isUnknown(ip)) {
            ip = request.getRemoteAddr();
            if ("127.0.0.1".equals(ip)) {
                // Loopback request: report the address configured for this host instead.
                try {
                    ip = InetAddress.getLocalHost().getHostAddress();
                } catch (UnknownHostException ignored) {
                    // Local host name cannot be resolved; keeping 127.0.0.1 is still a usable key.
                }
            }
        }
        // When the request passed through several proxies the header is a comma-separated
        // list whose first entry is the originating client. Split regardless of total length:
        // a list such as "1.2.3.4,5.6.7.8" is only 15 characters long.
        if (ip != null) {
            int comma = ip.indexOf(',');
            if (comma > 0) {
                ip = ip.substring(0, comma);
            }
            ip = ip.trim();
        }
        if ("0:0:0:0:0:0:0:1".equals(ip)) {
            ip = "127.0.0.1";
        }
        return ip;
    }

    /** Returns true when a header value carries no usable address. */
    private static boolean isUnknown(String ip) {
        return ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip);
    }
}

6、自定义异常
package com.example.guavalimit.limit;

/**
 * Unchecked exception raised when a request is rejected by the rate limiter, or when the
 * information needed to build the limiter key is missing from the request.
 *
 * @author jie.zhao
 * @since 2019/8/7
 */
public class LimitAccessException extends RuntimeException {

    private static final long serialVersionUID = -3608667856397125671L;

    /**
     * Creates the exception with a detail message.
     *
     * @param reason text describing why the request was rejected
     */
    public LimitAccessException(final String reason) {
        super(reason);
    }
}

7、测试controller
package com.example.guavalimit.controller;

import com.example.guavalimit.limit.LimitKeyTypeEnum;
import com.example.guavalimit.limit.LxRateLimit;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

/**
 * Demo endpoints exercising both rate-limit key types, each capped at one request per second.
 */
@RestController
public class TestController {

    /** Body returned by every endpoint when the request is allowed through. */
    private static final String OK_BODY = "SUCCESS";

    /**
     * Endpoint limited per client IP address.
     *
     * @return the success marker
     */
    @LxRateLimit(limitKeyType = LimitKeyTypeEnum.IPADDR, perSecond = 1)
    @GetMapping("/test1")
    public String test1() {
        return OK_BODY;
    }

    /**
     * Endpoint limited per business key.
     *
     * @param limitKeyValue business-unique code that the limit is keyed on
     * @return the success marker
     */
    @LxRateLimit(limitKeyType = LimitKeyTypeEnum.CUSTOM, perSecond = 1)
    @GetMapping("/test2")
    public String test2(String limitKeyValue) {
        return OK_BODY;
    }
}

posted @ 2020-05-17 15:56  趙小傑  阅读(1846)  评论(0)  编辑  收藏  举报