使用 AOP 注解方式在 Controller 接口上实现请求限流
一、核心逻辑
package com.simple.common.aop;
import com.simple.common.model.ErrorCode;
import com.simple.common.model.OperationRateLimit;
import com.simple.common.model.ServiceException;
import com.simple.common.model.OperationRateLimit.VoidOperatorKeyGetter;
import com.simple.common.util.JsonUtil;
import com.google.common.base.Joiner;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.lang3.StringUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.springframework.aop.support.AopUtils;
/**
 * Core of the annotation-driven operation rate limiter.
 *
 * <p>For a method annotated with {@link OperationRateLimit}, an operation key is built from the
 * operator key, the target class name and the method name. The key is written to the
 * {@link KeyValueStore} with a time-to-live equal to the cool-down; while the key exists, calls
 * that resolve to the same key are rejected with a {@link ServiceException}. An optional
 * in-process Guava cache is consulted first so repeated calls can be rejected without asking
 * the store.
 */
public class OperationRateLimitAspect {

    private static final int CONCURRENCY_LEVEL = 8;
    private static final Joiner joiner = Joiner.on("|").useForNull("");

    private final Logger errorLogger;
    private final OperatorKeyGetterRouter router;
    private final KeyValueStore keyValueStore;
    private final CacheManager cacheManager = new CacheManager();
    private final AtomicBoolean enableLocalCache = new AtomicBoolean(true);

    /**
     * @param errorLogger   logger that receives rejection and failure messages; must not be null
     * @param router        resolves the {@link OperatorKeyGetter} used for a call; must not be null
     * @param keyValueStore store that holds the cool-down keys; must not be null
     * @throws NullPointerException if any argument is null
     */
    public OperationRateLimitAspect(Logger errorLogger, OperatorKeyGetterRouter router, KeyValueStore keyValueStore) {
        Objects.requireNonNull(errorLogger, "error logger can not be null");
        Objects.requireNonNull(router, "operator key getter router can not be null");
        Objects.requireNonNull(keyValueStore, "key-value store can not be null");
        this.errorLogger = errorLogger;
        this.router = router;
        this.keyValueStore = keyValueStore;
    }

    /**
     * Applies the rate limit declared on the intercepted method and, if allowed, proceeds.
     *
     * <p>Methods without {@link OperationRateLimit}, or with a non-positive cool-down, are
     * invoked directly.
     *
     * @param joinPoint the intercepted call
     * @return whatever the intercepted method returns
     * @throws ServiceException if the operator key is blank or the call is rejected
     * @throws Throwable        anything thrown by the intercepted method or by the store
     */
    public Object operationRateLimit(ProceedingJoinPoint joinPoint) throws Throwable {
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Class<?> targetClass = AopUtils.getTargetClass(joinPoint.getTarget());
        // Resolve against the unwrapped target class so the annotation declared on the user's
        // method is found even when the target object is itself a proxy subclass.
        Method method = AopUtils.getMostSpecificMethod(signature.getMethod(), targetClass);
        OperationRateLimit annotation = method.getAnnotation(OperationRateLimit.class);
        if (annotation == null) {
            return this.invoke(joinPoint);
        }
        long coolDownMillis = annotation.coolDownMillis();
        if (coolDownMillis <= 0L) {
            return this.invoke(joinPoint);
        }

        String className = targetClass.getName();
        String methodName = method.getName();
        String operatorKey = this.getOperatorKey(joinPoint, annotation, className, methodName);
        String operationKey = joiner.join(operatorKey, className, methodName);

        // Read the switch once so a concurrent enable/disable cannot change the decision mid-call.
        Cache<String, Long> localCache =
                this.enableLocalCache.get() ? this.cacheManager.getCache(coolDownMillis) : null;
        if (localCache != null && localCache.getIfPresent(operationKey) != null) {
            return this.reject(className, methodName, operatorKey);
        }
        if (this.keyValueStore.exists(operationKey)) {
            return this.reject(className, methodName, operatorKey);
        }

        long now = System.currentTimeMillis();
        if (!this.keyValueStore.setIfAbsent(operationKey, String.valueOf(now), coolDownMillis)) {
            return this.reject(className, methodName, operatorKey);
        }
        // Mirror the key locally only after the store accepted it; otherwise a store failure
        // would leave a local entry that rejects calls although no operation was started.
        if (localCache != null) {
            localCache.put(operationKey, now);
        }

        if (!annotation.resetCoolDownAfterInvoked()) {
            return this.invoke(joinPoint);
        }

        Object result;
        try {
            result = this.invoke(joinPoint);
        } catch (Throwable bizFailure) {
            // The operation is over even though it failed: release the cool-down so the caller
            // can retry, and never let a clean-up error hide the business failure.
            try {
                this.resetCoolDown(localCache, operationKey);
            } catch (Exception resetFailure) {
                bizFailure.addSuppressed(resetFailure);
            }
            throw bizFailure;
        }

        try {
            this.resetCoolDown(localCache, operationKey);
        } catch (Exception resetFailure) {
            this.errorLogger.error(
                    joiner.join("deleteOperationKeyFailedButBizSucceed",
                            "bizResult=" + (result == null ? "null" : JsonUtil.serialize(result))),
                    resetFailure);
            throw resetFailure;
        }
        return result;
    }

    /** Removes the operation key from the local cache (when in use) and from the store. */
    private void resetCoolDown(Cache<String, Long> localCache, String operationKey) {
        if (localCache != null) {
            localCache.invalidate(operationKey);
        }
        this.keyValueStore.delete(operationKey);
    }

    /**
     * Resolves the operator key for the call.
     *
     * <p>The router receives {@code null} as getter class when the annotation still carries the
     * {@code VoidOperatorKeyGetter} placeholder.
     *
     * @throws ServiceException if the resolved key is blank
     */
    private String getOperatorKey(ProceedingJoinPoint joinPoint, OperationRateLimit annotation,
                                  String className, String methodName) {
        Class<? extends OperatorKeyGetter> getterClass = annotation.operatorKeyGetter();
        OperatorKeyGetter getter =
                this.router.route(joinPoint, getterClass == VoidOperatorKeyGetter.class ? null : getterClass);
        String operatorKey = getter.getOperatorKey(joinPoint);
        if (StringUtils.isBlank(operatorKey)) {
            this.printErrorLog("blankOperatorKey", className, methodName, operatorKey);
            throw new ServiceException(ErrorCode.COMMON_ERROR, "get operator key failed, blank operator key");
        }
        return operatorKey;
    }

    /** Proceeds with the intercepted call. */
    private Object invoke(ProceedingJoinPoint joinPoint) throws Throwable {
        return joinPoint.proceed();
    }

    /** Turns the in-process cache on for subsequent calls. */
    public void enableLocalCache() {
        this.enableLocalCache.set(true);
    }

    /** Turns the in-process cache off for subsequent calls; only the store is consulted. */
    public void disableLocalCache() {
        this.enableLocalCache.set(false);
    }

    /** Writes a {@code |}-joined error line containing the description, method and operator key. */
    private void printErrorLog(String desc, String className, String methodName, String operatorKey) {
        this.errorLogger.error(
                joiner.join(desc, "method=" + className + "#" + methodName, "operatorKey=" + operatorKey));
    }

    /**
     * Logs the rejection and throws; the return type only exists so callers can write
     * {@code return reject(...)}.
     */
    private Object reject(String className, String methodName, String operatorKey) {
        this.printErrorLog("operationLimit", className, methodName, operatorKey);
        throw new ServiceException(ErrorCode.COMMON_OPERATION_LIMIT, "操作过于频繁,请稍后重试");
    }

    /** Storage abstraction for cool-down keys. */
    public interface KeyValueStore {

        /** @return {@code true} if {@code key} is currently present */
        boolean exists(String key);

        /**
         * Stores {@code value} under {@code key} only if the key is absent.
         *
         * @param expireMillis the cool-down in milliseconds, passed through by the aspect
         * @return {@code true} if the value was written
         */
        boolean setIfAbsent(String key, String value, long expireMillis);

        /** Removes {@code key}; the aspect ignores the returned flag. */
        boolean delete(String key);
    }

    /** Produces the key that identifies who is performing the intercepted operation. */
    public interface OperatorKeyGetter {

        String getOperatorKey(ProceedingJoinPoint joinPoint);
    }

    /** Chooses the {@link OperatorKeyGetter} for a call. */
    public interface OperatorKeyGetterRouter {

        /**
         * @param getterClass the getter class named on the annotation, or {@code null} when the
         *                    annotation left the default placeholder
         */
        OperatorKeyGetter route(ProceedingJoinPoint joinPoint, Class<? extends OperatorKeyGetter> getterClass);
    }

    /**
     * Holds one Guava cache per distinct cool-down value, since the expiry is fixed when a
     * cache is built.
     */
    private class CacheManager {

        private final Map<Long, Cache<String, Long>> cacheMap;

        private CacheManager() {
            this.cacheMap = new ConcurrentHashMap<>(64, 0.75F, CONCURRENCY_LEVEL);
        }

        /** Returns the cache whose entries expire {@code coolDownMillis} after being written. */
        public Cache<String, Long> getCache(long coolDownMillis) {
            return this.cacheMap.computeIfAbsent(coolDownMillis, cd -> {
                int size = this.cacheMap.size();
                if (size > 128) {
                    // Many distinct cool-down values means many caches; surface it in the log.
                    OperationRateLimitAspect.this.errorLogger.warn(
                            joiner.join("tooManyOperationRateLimitCache", "currentCacheCount=" + size));
                }
                return CacheBuilder.newBuilder()
                        .concurrencyLevel(CONCURRENCY_LEVEL)
                        .initialCapacity(64)
                        .expireAfterWrite(cd, TimeUnit.MILLISECONDS)
                        .<String, Long>build();
            });
        }
    }
}
二、注解定义
package com.simple.common.model; import com.simple.common.aop.OperationRateLimitAspect.OperatorKeyGetter; import java.lang.annotation.Documented; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; import org.aspectj.lang.ProceedingJoinPoint; @Documented @Retention(RetentionPolicy.RUNTIME) @Target({ElementType.METHOD}) public @interface OperationRateLimit { long coolDownMillis() default 1000L; boolean resetCoolDownAfterInvoked() default false; Class<? extends OperatorKeyGetter> operatorKeyGetter() default OperationRateLimit.VoidOperatorKeyGetter.class; public static final class VoidOperatorKeyGetter implements OperatorKeyGetter { public VoidOperatorKeyGetter() { } public String getOperatorKey(ProceedingJoinPoint joinPoint) { throw new UnsupportedOperationException(); } } }
三、拦截器
package com.simple.work.web.aop.aspect; import com.simple.common.aop.OperationRateLimitAspect; import com.simple.work.biz.middleware.redis.RedisService; import com.simple.work.common.constant.Loggers; import com.simple.work.web.context.WebContextService; import org.aspectj.lang.ProceedingJoinPoint; import org.aspectj.lang.annotation.Around; import org.aspectj.lang.annotation.Aspect; import org.slf4j.Logger; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; import org.springframework.core.annotation.Order; import org.springframework.stereotype.Component; import javax.annotation.PostConstruct; import javax.validation.constraints.NotBlank; import java.util.concurrent.TimeUnit; /** * zz * 2021/4/13 */ @Aspect @Order(9) @Component public class OperationRateAspect { private static final Logger errorLogger = Loggers.ERROR_LOGGER; private OperationRateLimitAspect aspect; @Autowired private ApplicationContext applicationContext; @Autowired private WebContextService webContextService; @Autowired private RedisService redisService; @PostConstruct public void init() { final OperationRateLimitAspect.OperatorKeyGetter defaultOperatorKeyGetter = joinPoint -> webContextService.getAuthUserKey(); final OperationRateLimitAspect.OperatorKeyGetterRouter router = (joinPoint, getterClass) -> getterClass == null ? 
defaultOperatorKeyGetter : applicationContext.getBean(getterClass); final OperationRateLimitAspect.KeyValueStore keyValueStore = new OperationRateLimitAspect.KeyValueStore() { @Override public boolean exists(@NotBlank String key) { return redisService.exists(key); } @Override public boolean setIfAbsent(@NotBlank String key, @NotBlank String value, long expireMillis) { Boolean result = redisService.setIfAbsent(key, value, expireMillis, TimeUnit.MILLISECONDS); return Boolean.TRUE.equals(result); } @Override public boolean delete(@NotBlank String key) { Boolean result = redisService.delete(key); return Boolean.TRUE.equals(result); } }; aspect = new OperationRateLimitAspect(errorLogger, router, keyValueStore); aspect.enableLocalCache(); } @Around("@annotation(com.flyfish.common.model.OperationRateLimit)") public Object operationRateLimit(ProceedingJoinPoint joinPoint) throws Throwable { return aspect.operationRateLimit(joinPoint); } }
四、使用
/**
 * Usage example: a list-query endpoint limited to one in-flight call per operator, with a
 * 30-second cool-down that is reset after the call has been invoked.
 */
@ApiOperation("列表查询接口")
@OperationRateLimit(coolDownMillis = 30000, resetCoolDownAfterInvoked = true)
@PostMapping("/list")
public WebResult<Void> list(
        @RequestBody @Validated @ApiParam("查询参数") BookingExportSearchRequest request) {
    // Placeholder body; the example only demonstrates where the annotation goes.
    return null;
}