自定义注解 校验入参请求体

目前javax.validation.constraints提供了自带的校验注解,校验不通过时会统一抛出ValidationException。一般系统框架会通过@RestControllerAdvice实现全局异常处理器,统一管理异常的返回码;但拦截到ValidationException时,异常对象里只有错误描述,没有具体的错误码(例如 throw new ValidationException("Unable to get available provider resolvers.", re);)。因此,如果想要自定义错误返回码,或者需要对返回体做一些定制化修改,可以自己实现"注解 + AOP切面",完成对请求参数对象中各字段的校验。

1、实现一个注解,作为校验参数的入口

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Marker annotation placed on a controller method parameter to request validation of the
 * argument object before the method body runs.
 *
 * <p>{@code ParamValidAop} looks for this annotation on each parameter and, when present,
 * hands the argument to {@code ParameterCheckUtils.paramCheckParent}. Runtime retention is
 * required because the aspect discovers the annotation through reflection.
 */
@Retention(value = RetentionPolicy.RUNTIME)
@Target(value = {ElementType.PARAMETER})
public @interface ParamCheck {
    // Intentionally empty: the annotation's presence is the only signal.
}

2、实现一个AOP切面,用于拦截校验注解的逻辑

import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.*;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.stereotype.Component;

import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;

@Component
@Aspect
public class ParamValidAop {

    /**
     * Pointcut covering every method of classes annotated with {@code @RestController}
     * (Spring AOP only intercepts method executions, so this means all controller methods).
     */
    @Pointcut("within(@org.springframework.web.bind.annotation.RestController *)")
    public void checkParam() {
    }

    /**
     * Empty placeholder before-advice on {@link #checkParam()}; it currently does nothing.
     *
     * @param joinPoint the intercepted join point (unused)
     */
    @Before("checkParam()")
    public void doBefore(JoinPoint joinPoint) {
    }

    /**
     * Validates every argument whose parameter is annotated with {@link ParamCheck} by passing it to
     * {@code ParameterCheckUtils.paramCheckParent}, then invokes the target method with the same
     * arguments. Any exception thrown by the check propagates, so the target method is not called.
     *
     * @param pjp the intercepted controller invocation
     * @return whatever the target method returns
     * @throws Throwable anything thrown by the validation or by the target method
     */
    @Around("checkParam()")
    public Object doAround(ProceedingJoinPoint pjp) throws Throwable {
        // Spring AOP's within(...) only matches method executions, so this is a MethodSignature.
        Method method = ((MethodSignature) pjp.getSignature()).getMethod();
        // One annotation array per parameter, because a parameter may carry several annotations.
        Annotation[][] parameterAnnotations = method.getParameterAnnotations();
        // Fetch the arguments once and reuse them for both validation and proceed().
        Object[] args = pjp.getArgs();
        for (int i = 0; i < parameterAnnotations.length && i < args.length; i++) {
            if (hasParamCheck(parameterAnnotations[i])) {
                ParameterCheckUtils.paramCheckParent(args[i]);
            }
        }
        return pjp.proceed(args);
    }

    /**
     * Empty placeholder after-returning advice on {@link #checkParam()}; it currently does nothing.
     * Note that after-returning advice can observe but not replace the returned value; any
     * post-processing of the result would have to be done in the around advice.
     *
     * @param joinPoint the intercepted join point (unused)
     */
    @AfterReturning("checkParam()")
    public void doAfterReturning(JoinPoint joinPoint) {
    }

    /** Returns true when one of the given parameter annotations is {@link ParamCheck}. */
    private static boolean hasParamCheck(Annotation[] annotations) {
        for (Annotation annotation : annotations) {
            if (annotation instanceof ParamCheck) {
                return true;
            }
        }
        return false;
    }
}

其中ParameterCheckUtils.paramCheckParent(obj);用于对入参的具体校验

3、实现参数校验工具类ParameterCheckUtils,对入参对象中带有校验注解的字段进行具体校验

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonSetter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

@Slf4j
public class ParameterCheckUtils {

    // 只校验当前类包含的对象,用于openFeign中face管控接口,不需要传orgId
    public static void paramCheck(Object obj) {
        if (obj == null) {
            log.error("obj can not be null!");
            throw new RuleEngineException(RuleEngineCode.PARAM_CANNOT_NULL);
        }
        Class<?> clazz = obj.getClass();
        List<Field> fieldList = new ArrayList<Field>(Arrays.asList(clazz.getDeclaredFields()));
        commonCheckMethod(obj, fieldList);
    }

    // 会校验当前类和他的所有父类包含的对象,用于OpenApi接口
    public static void paramCheckParent(Object obj) {
        if (obj == null) {
            log.error("obj can not be null!");
            throw new RuleEngineException(RuleEngineCode.PARAM_CANNOT_NULL);
        }
        Class<?> clazz = obj.getClass();
        List<Field> fieldList = new ArrayList<>();

        //子类的变量只扫描一次,否则获取超类会重复扫描
        boolean isScanFirstSonLayer = false;
        while (!Object.class.equals(clazz)) {
            Field[] firstSonLayerClazzFields = clazz.getDeclaredFields();
            if(!isScanFirstSonLayer){
                for(Field firstSonLayerClazzField : firstSonLayerClazzFields){
                    firstSonLayerClazzField.setAccessible(true);
                    List<Field> firstSonLayerFieldList = new ArrayList<>();
                    Object paramObj;
                    try {
                        Class<?> firstSonLayerClazz = Class.forName(firstSonLayerClazzField.getType().getName());
                        firstSonLayerFieldList.addAll(Arrays.asList(firstSonLayerClazz.getDeclaredFields()));
                        paramObj = firstSonLayerClazzField.get(obj);
                    } catch (IllegalAccessException e) {
                        log.error("", e);
                        return;
                    } catch (ClassNotFoundException e) {
                        log.error("", e);
                        return;
                    }
                    commonCheckMethod(paramObj, firstSonLayerFieldList);
                }
                isScanFirstSonLayer = true;
            }
            fieldList.addAll(Arrays.asList(clazz.getDeclaredFields()));
            clazz = clazz.getSuperclass();
        }
        commonCheckMethod(obj, fieldList);
    }

    private static void commonCheckMethod(Object obj, List<Field> fieldList) {
        int atLeastOneCount = 0;
        boolean isAtLeastOne = false;
        for (Field field : fieldList) {
            field.setAccessible(true);
            ParameterAttr attr = field.getAnnotation(ParameterAttr.class);
            if (attr == null) {
                continue;
            }
            Object paramObj = null;
            try {
                paramObj = field.get(obj);
            } catch (IllegalAccessException e) {
                log.error("", e);
                return;
            }

            if (attr.isNecessary()) {
                checkNotNull(obj, field, paramObj);
            }
            if (attr.lengthMaxLimit() > 0 && paramObj != null && StringUtils.isNotBlank(paramObj.toString())) {
                if (paramObj.toString().length() > attr.lengthMaxLimit()) {
                    log.error("class={} field={} length is too long! limit={} but length={}", obj.getClass().getName(), field.getName(), attr.lengthMaxLimit(), paramObj.toString().length());
                    String name = getName(field);
                    throw new RuleEngineException(RuleEngineCode.ILLEGAL_REQUEST_PARAMETER.getCode(), String.format(RuleEngineCode.ILLEGAL_REQUEST_PARAMETER.getMessage(), name));
                }
            }
            if (attr.lengthMinLimit() > 0 && paramObj != null && StringUtils.isNotBlank(paramObj.toString())) {
                if (paramObj.toString().length() < attr.lengthMinLimit()) {
                    log.error("class={} field={} length is too short! limit={} but length={}", obj.getClass().getName(), field.getName(), attr.lengthMinLimit(), paramObj.toString().length());
                    String name = getName(field);
                    throw new RuleEngineException(RuleEngineCode.ILLEGAL_REQUEST_PARAMETER.getCode(), String.format(RuleEngineCode.ILLEGAL_REQUEST_PARAMETER.getMessage(), name));
                }
            }
            if(attr.isAtLeastOne()){
                isAtLeastOne = true;
                if (paramObj != null && !StringUtils.isBlank(paramObj.toString())) {
                    atLeastOneCount++;
                }
            }
        }
        if(isAtLeastOne && atLeastOneCount < 1){
            log.error("parameter is at least one!");
            throw new RuleEngineException(RuleEngineCode.ILLEGAL_REQUEST_PARAMETER);
        }
        log.info("commonCheck success!");
    }

    private static void checkNotNull(Object obj, Field field, Object paramObj) {
        if (paramObj == null || StringUtils.isBlank(paramObj.toString())) {
            log.error("class={} field={} can not be null!", obj.getClass().getName(), field.getName());
            String name = getName(field);
            throw new RuleEngineException(RuleEngineCode.PARAMETER_IS_MISSING.getCode(), String.format(RuleEngineCode.PARAMETER_IS_MISSING.getMessage(), name));
        }
    }

    private static String getName(Field field) {
        String name = field.getName();
        JsonSetter jsonName = field.getAnnotation(JsonSetter.class);
        if (jsonName != null) {
            return jsonName.value();
        }
        JsonProperty jsonPropertyName = field.getAnnotation(JsonProperty.class);
        if (jsonPropertyName != null) {
            return jsonPropertyName.value();
        }
        return name;
    }
}

4、ParameterAttr是添加在变量上的注解,可以自定义校验方式

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Declares validation constraints for a field of a request object.
 *
 * <p>The constraints are read reflectively by {@code ParameterCheckUtils}, which is why the
 * annotation is retained at runtime. All attributes default to "no constraint".
 */
@Retention(value = RetentionPolicy.RUNTIME)
@Target(value = {ElementType.FIELD})
public @interface ParameterAttr {

    /**
     * Whether the field is mandatory. When {@code true}, a null value, or a value whose
     * {@code toString()} is blank, is rejected. Defaults to {@code false} (optional).
     *
     * @return {@code true} if the field must be present and non-blank
     */
    boolean isNecessary() default false;

    /**
     * Maximum allowed length of the value's {@code toString()}. A value of 0 or less disables the
     * check, and blank values are not length-checked.
     *
     * @return the maximum length, or 0 for no limit
     */
    int lengthMaxLimit() default 0;

    /**
     * Minimum required length of the value's {@code toString()}. A value of 0 or less disables the
     * check, and blank values are not length-checked.
     *
     * @return the minimum length, or 0 for no limit
     */
    int lengthMinLimit() default 0;

    /**
     * Marks the field as a member of an "at least one" group. Among the fields of the checked
     * object that set this flag, at least one must hold a non-blank value.
     *
     * @return {@code true} if the field belongs to the at-least-one group
     */
    boolean isAtLeastOne() default false;

}

5、在controller需要校验的参数体前添加校验注解

 /**
  * Example endpoint whose request body is validated by the {@code @ParamCheck} aspect before this
  * method runs. The enclosing class must be annotated with {@code @RestController}, because the
  * aspect's pointcut only matches methods of such classes.
  *
  * @param testDto the validated request body
  * @return the query result
  */
 @ApiOperation(value = "查询测试接口")
 @PostMapping(value = "/test/v1/get", produces = "application/json;charset=UTF-8")
 public GetResDto getRule(@RequestBody @ParamCheck TestDto testDto) {
     // TODO: query the rule described by testDto and return it.
     throw new UnsupportedOperationException("getRule is not implemented yet");
 }

 

6、在需要校验的参数体的具体变量前添加校验注解

/**
 * Example request body for the validation demo.
 *
 * <p>Both fields are in the {@code isAtLeastOne} group, so at least one of {@code ruleId} and
 * {@code ruleKey} must be non-blank. When {@code ruleKey} is non-blank, its length must be between
 * 4 and 32 characters.
 */
@Data
public class TestDto {

    /** Rule identifier; member of the at-least-one group. */
    @ParameterAttr(isAtLeastOne = true)
    private String ruleId;

    /** Rule key; member of the at-least-one group, length 4..32 when present. */
    @ParameterAttr(lengthMaxLimit = 32, lengthMinLimit = 4, isAtLeastOne = true)
    private String ruleKey;

}

 

posted @ 2021-07-05 20:26  cosmocosmo  阅读(671)  评论(0编辑  收藏  举报