之前的token验证一直觉得很烂,今天做下优化,老项目就不贴出来了。

第一步

首先将过滤器注册到Bean容器,拦截相关请求。本来是想通过实现ApplicationContextAware接口的setApplicationContext方法去获取Spring的上下文,但是容器启动时报错,才发现WebMvcConfigurationSupport本身已经实现了ApplicationContextAware接口,所以这里直接调用getApplicationContext()取用;new Filter时要将applicationContext通过构造方法传入,以便过滤器在init时使用。

package com.mbuyy.config;

import com.mbuyy.common.CommonInterceptor;
import com.mbuyy.filter.LoginFilter;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.converter.StringHttpMessageConverter;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurationSupport;

import java.nio.charset.Charset;
import java.util.List;

@Configuration
public class WebConfig extends WebMvcConfigurationSupport {

    /**
     * String converter that writes {@code String} response bodies as UTF-8.
     *
     * @return a UTF-8 {@link StringHttpMessageConverter}
     */
    @Bean
    public HttpMessageConverter<String> responseBodyConverter() {
        return new StringHttpMessageConverter(Charset.forName("UTF-8"));
    }

    /**
     * Puts the UTF-8 string converter in front of the default converters.
     *
     * <p>Because this class extends WebMvcConfigurationSupport, Spring Boot's MVC
     * auto-configuration backs off and HttpMessageConverter beans are no longer added to
     * Spring MVC automatically. Without this override the bean above would never be used,
     * and String bodies would use the default StringHttpMessageConverter charset instead.
     *
     * @param converters the converter list being built by WebMvcConfigurationSupport
     */
    @Override
    protected void extendMessageConverters(List<HttpMessageConverter<?>> converters) {
        // responseBodyConverter() goes through the @Configuration proxy, so this is the singleton bean
        converters.add(0, responseBodyConverter());
        super.extendMessageConverters(converters);
    }

    /**
     * Registers CommonInterceptor for every path.
     *
     * @param registry the MVC interceptor registry
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(new CommonInterceptor()).addPathPatterns("/**");
        super.addInterceptors(registry);
    }

    /**
     * Step 1: registers LoginFilter for the payment, merchant and mobile endpoints.
     *
     * <p>The filter receives the ApplicationContext that WebMvcConfigurationSupport already
     * holds (it implements ApplicationContextAware), so the filter can look up its
     * controllers in init().
     *
     * @return the filter registration
     */
    @Bean
    public FilterRegistrationBean filterRegist() {
        FilterRegistrationBean frBean = new FilterRegistrationBean();
        frBean.setFilter(new LoginFilter(this.getApplicationContext()));
        frBean.addUrlPatterns("/payment/*", "/merchant/*", "/mobile/*");
        return frBean;
    }
}

第二步

编写过滤器LoginFilter(通过构造器注入ApplicationContext)

package com.mbuyy.filter;

import com.mbuyy.annotation.TokenIgnoreUtils;
import com.mbuyy.annotation.TokenVerify;
import com.mbuyy.merchant.controller.*;
import com.mbuyy.mobile.controller.*;
import com.mbuyy.util.JWTUtils;
import com.mbuyy.util.RedisUtils;
import com.mbuyy.util.StringUtils;
import com.mbuyy.util.ValidateUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.ApplicationContext;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.util.Enumeration;
import java.util.Map;
import java.util.Objects;
import java.util.regex.Pattern;

import static com.mbuyy.constants.Constants.*;

/**
 * Token authentication filter.
 *
 * <p>For every request it adds CORS headers and reads the token from the {@code token_merchant}
 * header (merchant side) or the {@code token} header (customer side). It stores the token type in
 * the {@code TOKEN_TYPE} request attribute. When a token is present, it verifies the token and
 * stores its id in the {@code TOKEN_ID} attribute. A request without a token is only let through
 * when TokenIgnoreUtils reports that its handler does not need one. Any failure is answered with
 * a JSON body {@code {"status": 401, "msg": "..."}} and the chain is not continued.
 */
public class LoginFilter implements Filter {
    private static final Logger logger = LoggerFactory.getLogger(LoginFilter.class);

    /** Merchant-side token header; it wins over the customer header whenever it is non-empty. */
    private static final String MERCHANT_TOKEN_HEADER = "token_merchant";

    /** Customer-side token header. */
    private static final String CUSTOMER_TOKEN_HEADER = "token";

    /**
     * JVM system property that enables the numeric "universal token" shortcut for local testing,
     * e.g. {@code -Dmbuyy.token.universal=true}. It is off by default, so the bypass cannot reach
     * production by accident.
     */
    private static final String UNIVERSAL_TOKEN_PROPERTY = "mbuyy.token.universal";

    /** An optionally signed run of digits, i.e. a universal test token such as {@code "42"}. */
    private static final Pattern NUMERIC_TOKEN = Pattern.compile("^[-\\+]?[\\d]*$");

    private TokenIgnoreUtils tokenIgnoreUtils;
    private ApplicationContext applicationContext;

    /**
     * Creates a filter without a Spring context. {@link #init} needs a context to find the
     * {@code @TokenVerify} controllers, so use {@link #LoginFilter(ApplicationContext)} instead.
     */
    public LoginFilter() {
    }

    /**
     * Creates a filter that looks up its controllers in the given Spring context.
     *
     * @param applicationContext the Spring ApplicationContext, injected through the constructor
     */
    public LoginFilter(ApplicationContext applicationContext) {
        this.applicationContext = applicationContext;
    }

    @Override
    public void destroy() {
        // nothing to release
    }

    /**
     * Registers every bean annotated with {@code @TokenVerify} with a fresh TokenIgnoreUtils.
     *
     * @param arg0 unused
     * @throws IllegalStateException if the filter was created without an ApplicationContext
     */
    @Override
    public void init(FilterConfig arg0) {
        if (this.applicationContext == null) {
            throw new IllegalStateException(
                    "LoginFilter needs an ApplicationContext; create it with new LoginFilter(applicationContext)");
        }
        tokenIgnoreUtils = new TokenIgnoreUtils();

        Map<String, Object> beansWithAnnotationMap = this.applicationContext.getBeansWithAnnotation(TokenVerify.class);
        for (Map.Entry<String, Object> entry : beansWithAnnotationMap.entrySet()) {
            logger.info(entry.getKey());
            tokenIgnoreUtils.registerController(userClassOf(entry.getValue()));
        }
        // TODO: the account-status check (deleted/disabled users) via userPunishService is not implemented
    }

    /**
     * Returns the class a controller bean was written as.
     *
     * <p>An AOP-proxied controller is a CGLIB subclass named {@code Xxx$$...}. Its overriding methods
     * carry none of the {@code @RequestMapping}/{@code @TokenIgnore} annotations, so the superclass
     * is used in that case. This is the same rule as Spring's ClassUtils.getUserClass.
     */
    private static Class<?> userClassOf(Object bean) {
        Class<?> type = bean.getClass();
        Class<?> parent = type.getSuperclass();
        if (type.getName().contains("$$") && parent != null && parent != Object.class) {
            return parent;
        }
        return type;
    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        HttpServletRequest req = (HttpServletRequest) request;
        HttpServletResponse rep = (HttpServletResponse) response;

        setCrossDomain(req, rep);

        String requestURI = req.getRequestURI();

        // 0 = customer token, 1 = merchant token
        int tokenType = getTokenType(req);
        String token = getToken(req);

        req.setAttribute(TOKEN_TYPE, tokenType);

        Integer universalId = parseUniversalToken(token);
        if (universalId != null) {
            req.setAttribute(TOKEN_ID, universalId);
            chain.doFilter(request, response);
            return;
        }

        try {
            verify(req, token, tokenType, requestURI);
        } catch (Exception e) {
            logger.warn("Token verification failed for {}: {}", requestURI, e.getMessage(), e);
            failed(rep, e.getMessage());
            return;
        }

        chain.doFilter(request, response);
    }

    /**
     * Returns the id carried by a numeric "universal" test token. Returns null when the shortcut is
     * disabled, when the token is not numeric, or when it does not fit in an int.
     */
    private Integer parseUniversalToken(String token) {
        if (!Boolean.getBoolean(UNIVERSAL_TOKEN_PROPERTY)) {
            return null;
        }
        if (!ValidateUtils.isNotEmpty(token) || !NUMERIC_TOKEN.matcher(token).matches()) {
            return null;
        }
        Integer id = parseIntOrNull(token);
        if (id != null) {
            logger.warn("Universal test token accepted for id {}; never enable {} in production",
                    id, UNIVERSAL_TOKEN_PROPERTY);
        }
        return id;
    }

    /** Returns the merchant token when it is non-empty, otherwise the customer token (may be null). */
    private String getToken(HttpServletRequest req) {
        String merchantToken = req.getHeader(MERCHANT_TOKEN_HEADER);
        return ValidateUtils.isNotEmpty(merchantToken) ? merchantToken : req.getHeader(CUSTOMER_TOKEN_HEADER);
    }

    /**
     * Returns which side the request comes from: 1 for a non-empty merchant token, otherwise 0.
     *
     * <p>This uses exactly the same precedence as {@link #getToken}, so the type always describes
     * the token that is actually verified. {@code getHeader} is used because it is case-insensitive
     * and, unlike walking the header-name enumeration, cannot run past the last header.
     */
    private int getTokenType(HttpServletRequest req) {
        return ValidateUtils.isNotEmpty(req.getHeader(MERCHANT_TOKEN_HEADER)) ? 1 : 0;
    }

    /**
     * Verifies the token if one was sent. If none was sent, fails only when the endpoint needs one.
     *
     * @throws Exception if verification fails, or if the token is missing on a protected endpoint
     */
    private void verify(HttpServletRequest req, String token, int tokenType, String requestURI) throws Exception {
        if (ValidateUtils.isNotEmpty(token)) {
            // A token that is sent is always checked, even on endpoints that do not require one
            parseToken(req, token, tokenType);
        } else if (tokenIgnoreUtils.startCheck(requestURI)) {
            throw new RuntimeException("请输入验证Token");
        }
    }

    /**
     * Sets the request encoding and permissive CORS headers.
     *
     * @throws UnsupportedEncodingException never in practice ("utf-8" is always supported)
     */
    private void setCrossDomain(HttpServletRequest req, HttpServletResponse rep) throws UnsupportedEncodingException {
        req.setCharacterEncoding("utf-8");

        // "*" allows every origin
        rep.setHeader("Access-Control-Allow-Origin", "*");
        rep.setHeader("Access-Control-Allow-Methods", "POST, GET, OPTIONS, DELETE");
        rep.setHeader("Access-Control-Max-Age", "3600");
        rep.setHeader("Access-Control-Allow-Headers", "*");
    }

    /**
     * Verifies the JWT, checks it against Redis for its token type, and stores the token id in the
     * {@code TOKEN_ID} request attribute.
     *
     * @throws Exception if the JWT is invalid, has no numeric token id, or does not match Redis
     */
    private void parseToken(HttpServletRequest req, String token, int tokenType) throws Exception {
        // An invalid JWT is rejected outright
        Map<String, String> tokenMap;
        try {
            tokenMap = JWTUtils.getInstance().verifyToken(token);
        } catch (JWTUtils.TokenException e) {
            throw new RuntimeException("鉴权异常", e);
        }

        // Integer.parseInt would throw NumberFormatException on a missing or non-numeric claim;
        // map that to the dedicated error instead
        Integer tokenId = parseIntOrNull(tokenMap == null ? null : tokenMap.get(TOKEN_ID));
        if (tokenId == null) {
            throw new RuntimeException("TokenId获取失败");
        }

        if (tokenType == 1) {
            parseMerchantToken(tokenId, token);
        } else {
            parseCustomerToken(tokenId, token);
        }

        req.setAttribute(TOKEN_ID, tokenId);
        logger.info("token_id = {}", tokenId);
    }

    /** Checks a customer token against the Redis entry {@code TOKEN + tokenId}. */
    private void parseCustomerToken(Integer tokenId, String token) throws Exception {
        // The token itself is a credential and must not be logged
        logger.info("解析客户端token, tokenId = {}", tokenId);
        matchRedisToken(TOKEN + tokenId, token, "redis match error", "Token Error");
    }

    /** Checks a merchant token against the Redis entry {@code TOKEN_SHOP + tokenId}. */
    private void parseMerchantToken(Integer tokenId, String token) throws Exception {
        logger.info("解析商户端token, tokenId = {}", tokenId);
        matchRedisToken(TOKEN_SHOP + tokenId, token, "匹配异常", "Token错误");
    }

    /**
     * Fails unless Redis holds {@code key} and its value equals {@code token}.
     *
     * @param missingMsg message used when the key does not exist
     * @param mismatchMsg message used when the stored value differs from the token
     */
    private static void matchRedisToken(String key, String token, String missingMsg, String mismatchMsg) throws Exception {
        if (!RedisUtils.exists(key)) {
            throw new RuntimeException(missingMsg);
        }
        String redisToken = RedisUtils.get(key);
        // NOTE(review): a stored value of "-1" accepts any token; presumably a maintenance wildcard, confirm
        if (!Objects.equals(token, redisToken) && !"-1".equals(redisToken)) {
            throw new RuntimeException(mismatchMsg);
        }
    }

    /** Parses a decimal int, or returns null when the value is null, malformed, or out of range. */
    private static Integer parseIntOrNull(String value) {
        if (value == null) {
            return null;
        }
        try {
            return Integer.valueOf(value);
        } catch (NumberFormatException e) {
            return null;
        }
    }

    /**
     * Writes the JSON failure body.
     *
     * <p>The charset is set explicitly because getWriter() otherwise defaults to ISO-8859-1, which
     * would turn the Chinese messages into '?'. NOTE(review): the HTTP status is left at 200, as
     * before; presumably clients read the JSON "status" field (and CORS preflights rely on the 2xx
     * status). Confirm before switching to a real 401.
     */
    private void failed(HttpServletResponse rep, String msg) throws IOException {
        rep.setContentType("application/json;charset=UTF-8");
        PrintWriter w = rep.getWriter();
        w.write("{\"status\": 401,\"msg\": \"" + escapeJson(msg) + "\"}");
        w.flush();
        w.close();
    }

    /** Escapes a value for use inside a JSON string literal; null becomes the text "null". */
    private static String escapeJson(String value) {
        String text = String.valueOf(value);
        StringBuilder out = new StringBuilder(text.length() + 8);
        for (int i = 0; i < text.length(); i++) {
            char c = text.charAt(i);
            switch (c) {
                case '"':
                    out.append("\\\"");
                    break;
                case '\\':
                    out.append("\\\\");
                    break;
                case '\n':
                    out.append("\\n");
                    break;
                case '\r':
                    out.append("\\r");
                    break;
                case '\t':
                    out.append("\\t");
                    break;
                default:
                    if (c < 0x20) {
                        out.append(String.format("\\u%04x", (int) c));
                    } else {
                        out.append(c);
                    }
            }
        }
        return out.toString();
    }
}

工具类TokenIgnoreUtils:根据注解(TokenIgnore)判断请求是否需要Token

/**
 * Decides, from controller annotations, whether a request URI may be served without a token.
 *
 * <p>Controllers are registered with {@link #registerController}. {@link #startCheck} then looks
 * for a handler method annotated with {@code @RequestMapping} whose path matches the request, and
 * reports whether that method lacks {@code @TokenIgnore}.
 */
public class TokenIgnoreUtils {
    /** Registered controller classes, scanned in registration order. */
    private final List<Class<?>> mControllerClass;

    public TokenIgnoreUtils() {
        mControllerClass = new ArrayList<>();
    }

    /**
     * Adds a controller class to the set scanned by {@link #startCheck}.
     *
     * @param clazz the controller class. Pass the user class, not a CGLIB proxy subclass, or its
     *              method annotations will not be visible.
     */
    public void registerController(Class<?> clazz) {
        mControllerClass.add(clazz);
    }

    /**
     * Reports whether {@code reqUrl} requires a token.
     *
     * <p>Only methods carrying {@code @RequestMapping} itself are considered. Methods mapped only
     * with {@code @GetMapping}, {@code @PostMapping} etc. are not seen, so their requests always
     * require a token. Path variables and wildcards are not interpreted.
     *
     * <p>Normal mode joins every class-level path with every method-level path, normalizing
     * slashes ({@code "user" + "list"} and {@code "/user/" + "/list"} both become
     * {@code "/user/list"}). A request matches when it ends with the joined path. The joined path
     * starts with '/', so a match always falls on a segment boundary, and a context-path prefix in
     * the URI is tolerated.
     *
     * <p>Test mode ({@code ConfigConstants.TOKEN_UTILS_TEST}) uses a different rule. A method matches
     * when the URI contains its first method-level path and the lower-cased controller simple name
     * contains the first URI segment (e.g. {@code /user/list} and a controller named like
     * {@code *User*}). This only works when the controller can be recognized from the first segment.
     *
     * @param reqUrl the request URI (may include the context path)
     * @return {@code false} if the first matching method is annotated with {@code @TokenIgnore};
     *         {@code true} otherwise, including when nothing matches or {@code reqUrl} is null
     */
    public boolean startCheck(String reqUrl) {
        if (reqUrl == null) {
            return true;
        }
        for (Class<?> clazz : mControllerClass) {
            String[] classPaths = pathsOf(clazz.getAnnotation(RequestMapping.class));
            for (Method method : clazz.getMethods()) {
                RequestMapping requestMapping = method.getAnnotation(RequestMapping.class);
                if (requestMapping == null) {
                    continue;
                }
                boolean matched = ConfigConstants.TOKEN_UTILS_TEST
                        ? matchesByControllerName(clazz, requestMapping, reqUrl)
                        : matchesJoinedPath(classPaths, pathsOf(requestMapping), reqUrl);
                if (matched) {
                    return method.getAnnotation(TokenIgnore.class) == null;
                }
            }
        }
        return true;
    }

    /**
     * Returns the paths declared by a mapping: {@code value()}, else {@code path()}. Returns a single
     * "" when the mapping is absent or declares no path.
     */
    private static String[] pathsOf(RequestMapping mapping) {
        if (mapping == null) {
            return new String[] {""};
        }
        String[] paths = mapping.value();
        if (paths.length == 0) {
            paths = mapping.path();
        }
        return paths.length == 0 ? new String[] {""} : paths;
    }

    /** Normal-mode rule: the URI ends with some class-path + method-path combination. */
    private static boolean matchesJoinedPath(String[] classPaths, String[] methodPaths, String reqUrl) {
        for (String classPath : classPaths) {
            for (String methodPath : methodPaths) {
                String joined = joinPaths(classPath, methodPath);
                // An empty mapping is a suffix of every URI; never treat it as a match
                if (!joined.isEmpty() && reqUrl.endsWith(joined)) {
                    return true;
                }
            }
        }
        return false;
    }

    /** Test-mode rule: the URI contains the first method path and names the controller in its first segment. */
    private static boolean matchesByControllerName(Class<?> clazz, RequestMapping mapping, String reqUrl) {
        String[] methodPaths = mapping.value();
        if (methodPaths.length == 0 || !reqUrl.contains(methodPaths[0])) {
            return false;
        }
        String[] segments = reqUrl.split("/");
        return segments.length > 1 && clazz.getSimpleName().toLowerCase().contains(segments[1]);
    }

    /** Joins two mapping paths as "/a/b", dropping duplicated and trailing slashes at the join points. */
    private static String joinPaths(String classPath, String methodPath) {
        StringBuilder joined = new StringBuilder();
        appendSegment(joined, classPath);
        appendSegment(joined, methodPath);
        return joined.toString();
    }

    /** Appends "/" + segment with its leading/trailing slashes stripped; appends nothing if that leaves it empty. */
    private static void appendSegment(StringBuilder target, String segment) {
        int start = 0;
        int end = segment.length();
        while (start < end && segment.charAt(start) == '/') {
            start++;
        }
        while (end > start && segment.charAt(end - 1) == '/') {
            end--;
        }
        if (start < end) {
            target.append('/').append(segment, start, end);
        }
    }
}

 

第三步

获取当前登录用户信息

// TOKEN_ID is set by LoginFilter once the token has been verified; it is null when the request sent no token (e.g. a @TokenIgnore endpoint) or was not routed through the filter
Integer currentUserId = (Integer) request.getAttribute(TOKEN_ID);

笔者还很菜,要走的路还很长,如有不妥之处,还请联系小编不吝赐教

 

posted on 2020-12-16 11:51  浅灰色的记忆  阅读(1338)  评论(0)  编辑  收藏  举报