Spring Boot + Redis + HandlerInterceptor 实现接口防刷
说明
代码摘自其他开发者的项目,项目地址:https://github.com/wangzaiplus/springboot/tree/wxw
在此记录下来,以备不时之需
创建注解
/**
 * Marks a controller handler method whose access must be rate-limited.
 *
 * <p>The annotation is retained at runtime so it can be read reflectively from the
 * handler {@code Method}; at most {@link #maxCount()} accesses are allowed within a
 * window of {@link #seconds()} seconds.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface AccessLimit {

    /** Maximum number of accesses allowed within one window. */
    int maxCount();

    /** Length of the limiting window, in seconds. */
    int seconds();
}
配置拦截器
实现 HandlerInterceptor 接口,该接口包含三个方法:
- preHandle(HttpServletRequest request, HttpServletResponse response, Object handler):在执行 handler 之前调用
- postHandle(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, Object o, ModelAndView modelAndView):在业务处理完成之后、渲染视图之前调用
- afterCompletion(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, Object o, Exception e):在整个请求处理完成之后调用
此处只需要实现 preHandle(Spring 5 起这三个方法都是 default 方法,只需重写用到的那一个)
/**
* 接口防刷限流拦截器
*/
public class AccessLimitInterceptor implements HandlerInterceptor {
@Autowired
private JedisUtil jedisUtil;
@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) {
//判断是否Handler方法
if (!(handler instanceof HandlerMethod)) {
return true;
}
HandlerMethod handlerMethod = (HandlerMethod) handler;
Method method = handlerMethod.getMethod();
//获取该方法上的AccessLimit注解
AccessLimit annotation = method.getAnnotation(AccessLimit.class);
if (annotation != null) {
check(annotation, request);
}
return true;
}
//检查是否可访问
private void check(AccessLimit annotation, HttpServletRequest request) {
int maxCount = annotation.maxCount();
int seconds = annotation.seconds();
StringBuilder sb = new StringBuilder();
//按规则拼接字符串,查看redis是否包含该key
sb.append(Constant.Redis.ACCESS_LIMIT_PREFIX).append(IpUtil.getIpAddress(request)).append(request.getRequestURI());
String key = sb.toString();
Boolean exists = jedisUtil.exists(key);
if (!exists) {
jedisUtil.set(key, String.valueOf(1), seconds);
} else {
int count = Integer.valueOf(jedisUtil.get(key));
if (count < maxCount) {
Long ttl = jedisUtil.ttl(key);
if (ttl <= 0) {
jedisUtil.set(key, String.valueOf(1), seconds);
} else {
jedisUtil.set(key, String.valueOf(++count), ttl.intValue());
}
} else {
throw new ServiceException(ResponseCode.ACCESS_LIMIT.getMsg());
}
}
}
工具类
JedisUtil:
/**
 * Thin wrapper around a {@link JedisPool}: every method borrows a connection, runs a
 * single command, returns the connection to the pool, and turns any failure into a
 * logged error plus a {@code null} return value.
 *
 * <p>All commands are issued with {@code String} keys so that every method encodes a
 * key identically. Earlier {@code key.getBytes()} calls used the platform default
 * charset and could address a different key than {@code set}/{@code get} for
 * non-ASCII keys.
 */
@Component
@Slf4j
public class JedisUtil {

    // Optional bean: if no JedisPool is configured, getJedis() fails and every public
    // method logs the error and returns null.
    @Autowired(required = false)
    private JedisPool jedisPool;

    private Jedis getJedis() {
        return jedisPool.getResource();
    }

    /**
     * Sets {@code key} to {@code value} without an expiry.
     *
     * @param key   the Redis key
     * @param value the value to store
     * @return the Redis status reply, or {@code null} if the command failed
     */
    public String set(String key, String value) {
        try (Jedis jedis = getJedis()) {
            return jedis.set(key, value);
        } catch (Exception e) {
            log.error("set key: {} value: {} error", key, value, e);
            return null;
        }
    }

    /**
     * Sets {@code key} to {@code value} with an expiry (Redis SETEX).
     *
     * @param key        the Redis key
     * @param value      the value to store
     * @param expireTime time to live, in seconds
     * @return the Redis status reply, or {@code null} if the command failed
     */
    public String set(String key, String value, int expireTime) {
        try (Jedis jedis = getJedis()) {
            return jedis.setex(key, expireTime, value);
        } catch (Exception e) {
            log.error("set key:{} value:{} expireTime:{} error", key, value, expireTime, e);
            return null;
        }
    }

    /**
     * Sets {@code key} to {@code value} only if the key does not exist yet (Redis SETNX).
     *
     * @param key   the Redis key
     * @param value the value to store
     * @return the SETNX reply, or {@code null} if the command failed
     */
    public Long setnx(String key, String value) {
        try (Jedis jedis = getJedis()) {
            return jedis.setnx(key, value);
        } catch (Exception e) {
            log.error("set key:{} value:{} error", key, value, e);
            return null;
        }
    }

    /**
     * Reads the value stored at {@code key}.
     *
     * @param key the Redis key
     * @return the stored value, or {@code null} if absent or the command failed
     */
    public String get(String key) {
        try (Jedis jedis = getJedis()) {
            return jedis.get(key);
        } catch (Exception e) {
            log.error("get key:{} error", key, e);
            return null;
        }
    }

    /**
     * Deletes {@code key}.
     *
     * @param key the Redis key
     * @return the DEL reply (number of keys removed), or {@code null} if the command failed
     */
    public Long del(String key) {
        try (Jedis jedis = getJedis()) {
            return jedis.del(key);
        } catch (Exception e) {
            log.error("del key:{} error", key, e);
            return null;
        }
    }

    /**
     * Checks whether {@code key} exists.
     *
     * @param key the Redis key
     * @return whether the key exists, or {@code null} if the command failed — callers
     *         must not unbox the result blindly
     */
    public Boolean exists(String key) {
        try (Jedis jedis = getJedis()) {
            return jedis.exists(key);
        } catch (Exception e) {
            log.error("exists key:{} error", key, e);
            return null;
        }
    }

    /**
     * Sets a time to live on {@code key}.
     *
     * @param key        the Redis key
     * @param expireTime time to live, in seconds
     * @return the EXPIRE reply, or {@code null} if the command failed
     */
    public Long expire(String key, int expireTime) {
        try (Jedis jedis = getJedis()) {
            return jedis.expire(key, expireTime);
        } catch (Exception e) {
            log.error("expire key:{} error", key, e);
            return null;
        }
    }

    /**
     * Returns the remaining time to live of {@code key}.
     *
     * @param key the Redis key
     * @return the TTL reply in seconds (negative values are Redis' "no expiry" /
     *         "no such key" markers), or {@code null} if the command failed
     */
    public Long ttl(String key) {
        try (Jedis jedis = getJedis()) {
            return jedis.ttl(key);
        } catch (Exception e) {
            log.error("ttl key:{} error", key, e);
            return null;
        }
    }
}
IpUtil:
/**
 * Helpers for determining the client address of an HTTP request.
 */
public class IpUtil {

    /**
     * Headers consulted, in order, before falling back to the socket peer address.
     *
     * <p>NOTE(review): these headers are supplied by the client or by intermediate
     * proxies and can be forged. When the result is used as a rate-limit key, a client
     * can bypass the limit by rotating the header value unless a trusted proxy
     * overwrites it. Confirm the deployment sits behind such a proxy.
     */
    private static final String[] IP_HEADERS = {
        "x-forwarded-for",
        "Proxy-Client-IP",
        "WL-Proxy-Client-IP",
        "HTTP_CLIENT_IP",
        "HTTP_X_FORWARDED_FOR"
    };

    /**
     * Returns the client address of {@code request}.
     *
     * <p>The proxy headers in {@link #IP_HEADERS} are checked first. A header may hold a
     * comma-separated chain ("client, proxy1, proxy2"); its first entry that is neither
     * blank nor {@code "unknown"} is returned. If no header yields an address, the
     * remote address of the connection is returned.
     *
     * @param request the incoming request
     * @return the client address as a string
     */
    public static String getIpAddress(HttpServletRequest request) {
        for (String header : IP_HEADERS) {
            String ip = firstKnownAddress(request.getHeader(header));
            if (ip != null) {
                return ip;
            }
        }
        return request.getRemoteAddr();
    }

    /** Returns the first non-blank, non-"unknown" entry of a comma-separated header, or {@code null}. */
    private static String firstKnownAddress(String headerValue) {
        if (headerValue == null) {
            return null;
        }
        for (String candidate : headerValue.split(",")) {
            String ip = candidate.trim();
            if (!ip.isEmpty() && !"unknown".equalsIgnoreCase(ip)) {
                return ip;
            }
        }
        return null;
    }
}