Spring Cloud网关使用GlobalFilter过滤器做参数加密传输

package com.tsl.serviceingateway.sso;

import com.alibaba.fastjson.JSONObject;
import com.tsl.serviceingateway.sso.util.AESUtil;
import com.tsl.serviceingateway.sso.util.Sm4Util;
import io.netty.buffer.ByteBufAllocator;
import lombok.SneakyThrows;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.stereotype.Component;
import org.springframework.util.MultiValueMap;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.io.ByteArrayOutputStream;
import java.io.UnsupportedEncodingException;
import java.net.URI;
import java.net.URLDecoder;
import java.nio.CharBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;

/**
 * 参数加密传输
 *
 * @author: lll
 * @date: 2021年08月27日 14:08:20
 */
@Component
public class ParamsEncryptionFilter implements GlobalFilter, Ordered {


    /** Shared logger for this filter; {@code final} so it cannot be reassigned at runtime. */
    static final Logger logger = LoggerFactory.getLogger(ParamsEncryptionFilter.class);

    /**
     * Decrypts encrypted request parameters before the request is routed downstream.
     * <p>
     * Only requests carrying an {@code encryptionType} header are modified:
     * <ul>
     *   <li>POST / PUT: the {@code params} value of the body (JSON or form) is decrypted and the
     *       decrypted text replaces the whole request body.</li>
     *   <li>GET / DELETE: the {@code params} query parameter is decrypted and the decrypted text
     *       replaces the whole query string.</li>
     * </ul>
     * When the parameters cannot be extracted or decrypted (unknown encryption type, unsupported or
     * missing content type, missing {@code params}, decryption error), the request is forwarded
     * with its original parameters. Other HTTP methods are forwarded unchanged.
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerHttpRequest serverHttpRequest = exchange.getRequest();
        String encryptionType = serverHttpRequest.getHeaders().getFirst("encryptionType");
        if (encryptionType == null) {
            return chain.filter(exchange);
        }
        HttpMethod method = serverHttpRequest.getMethod();
        if (method == HttpMethod.POST || method == HttpMethod.PUT) {
            return decryptBodyAndForward(exchange, chain, encryptionType);
        }
        if (method == HttpMethod.GET || method == HttpMethod.DELETE) {
            return decryptQueryAndForward(exchange, chain, encryptionType);
        }
        return chain.filter(exchange);
    }

    /**
     * POST/PUT branch: buffers the complete body, decrypts its {@code params} value and forwards
     * the plaintext as the new body.
     * <p>
     * The request body can only be consumed once. Every path therefore forwards a request with a
     * re-attached body: the plaintext on success, the original bytes on fallback.
     */
    private Mono<Void> decryptBodyAndForward(ServerWebExchange exchange, GatewayFilterChain chain,
                                             String encryptionType) {
        ServerHttpRequest serverHttpRequest = exchange.getRequest();
        MediaType mediaType = serverHttpRequest.getHeaders().getContentType();
        // join() waits for the whole body (all chunks) without blocking the event loop,
        // instead of reading whatever happened to be available at subscription time
        return DataBufferUtils.join(serverHttpRequest.getBody())
                .map(dataBuffer -> {
                    try {
                        byte[] bytes = new byte[dataBuffer.readableByteCount()];
                        dataBuffer.read(bytes);
                        return bytes;
                    } finally {
                        DataBufferUtils.release(dataBuffer);
                    }
                })
                // join() completes empty for an empty body; the request must still be forwarded
                .defaultIfEmpty(new byte[0])
                .flatMap(originalBytes -> {
                    String plainParams = decryptBody(originalBytes, mediaType, encryptionType);
                    byte[] newBody = plainParams == null
                            ? originalBytes
                            : plainParams.getBytes(StandardCharsets.UTF_8);
                    ServerHttpRequest request = withBody(exchange, newBody);
                    return chain.filter(exchange.mutate().request(request).build());
                });
    }

    /**
     * Extracts and decrypts the {@code params} value of a POST/PUT body.
     *
     * @return the decrypted text, or {@code null} when the body must be forwarded unchanged
     */
    private String decryptBody(byte[] bodyBytes, MediaType mediaType, String encryptionType) {
        if (bodyBytes.length == 0 || mediaType == null) {
            return null;
        }
        String bodyStr = new String(bodyBytes, StandardCharsets.UTF_8);
        logger.info("PUT OR POST bodyStr: " + bodyStr);
        try {
            String cipherParams = getCiphertextByMediaType(bodyStr, mediaType);
            logger.info("PUT OR POST ciphertext  parameters: " + cipherParams);
            // "-1" means unsupported content type / no params: nothing to decrypt
            if (cipherParams == null || "-1".equals(cipherParams)) {
                return null;
            }
            String plainParams = decodeParamsBytype(encryptionType, cipherParams);
            // "-1" means unknown encryption type: leave the request as it is
            if (plainParams == null || "-1".equals(plainParams)) {
                return null;
            }
            // plaintext may contain credentials, so it is kept out of INFO logs
            logger.debug("PUT OR POST plaintext parameters: " + plainParams);
            return plainParams;
        } catch (Exception e) {
            // same policy as the GET/DELETE branch: forward the original parameters
            logger.warn("PUT OR POST parameter decryption failed, forwarding original body", e);
            return null;
        }
    }

    /**
     * Returns a decorated copy of the exchange's request whose body is {@code body} and whose
     * Content-Length header equals the exact byte count of {@code body}.
     */
    private ServerHttpRequest withBody(ServerWebExchange exchange, byte[] body) {
        HttpHeaders headers = new HttpHeaders();
        headers.putAll(exchange.getRequest().getHeaders());
        headers.remove(HttpHeaders.CONTENT_LENGTH);
        // the body is fully buffered now; a fixed length must not be combined with chunked encoding
        headers.remove(HttpHeaders.TRANSFER_ENCODING);
        headers.setContentLength(body.length);
        return new ServerHttpRequestDecorator(exchange.getRequest()) {
            @Override
            public HttpHeaders getHeaders() {
                return headers;
            }

            @Override
            public Flux<DataBuffer> getBody() {
                // wrap lazily so each subscription gets its own (not yet released) buffer
                return Flux.defer(() -> Flux.just(exchange.getResponse().bufferFactory().wrap(body)));
            }
        };
    }

    /**
     * GET/DELETE branch: decrypts the {@code params} query parameter and replaces the whole query
     * string with the plaintext. Forwards the original request when there is nothing to decrypt or
     * decryption fails.
     */
    private Mono<Void> decryptQueryAndForward(ServerWebExchange exchange, GatewayFilterChain chain,
                                              String encryptionType) {
        ServerHttpRequest serverHttpRequest = exchange.getRequest();
        URI uri = serverHttpRequest.getURI();
        String params = serverHttpRequest.getQueryParams().getFirst("params");
        if (params == null) {
            return chain.filter(exchange);
        }
        logger.info("GET OR DELETE ciphertext  parameters: " + params);
        ServerHttpRequest request;
        try {
            String plainParams = decodeParamsBytype(encryptionType, params);
            // "-1" means unknown encryption type: leave the request as it is
            if (plainParams == null || "-1".equals(plainParams)) {
                return chain.filter(exchange);
            }
            // plaintext may contain credentials, so it is kept out of INFO logs
            logger.debug("GET OR DELETE plaintext parameters: " + plainParams);
            URI plainUri = new URI(uri.getScheme(), uri.getAuthority(),
                    uri.getPath(), plainParams, uri.getFragment());
            request = serverHttpRequest.mutate().uri(plainUri).build();
        } catch (Exception e) {
            logger.warn("GET OR DELETE parameter decryption failed, forwarding original request", e);
            return chain.filter(exchange);
        }
        // chain.filter is called outside the try so a failure there is never retried with the original exchange
        logger.debug("get OR delete plaintext request.getQueryParams(): " + request.getQueryParams());
        return chain.filter(exchange.mutate().request(request).build());
    }

    /**
     * Extracts the encrypted {@code params} value from a request body according to its content type.
     * <ul>
     *   <li>JSON ({@code application/json}, with or without parameters such as charset): the
     *       {@code params} field of the JSON object.</li>
     *   <li>Form ({@code application/x-www-form-urlencoded} or {@code multipart/form-data}): the
     *       {@code params} key parsed with {@link #urlSplit(String)}, URL-decoded as UTF-8.
     *       NOTE(review): a real multipart payload is not {@code key=value&...} text, so it will
     *       normally yield no {@code params} value.</li>
     * </ul>
     *
     * @param bodyStr   request body text
     * @param mediaType request content type, may be {@code null}
     * @return the ciphertext, or {@code "-1"} when the content type is missing or unsupported, or
     *         the body carries no {@code params} value
     * @throws UnsupportedEncodingException declared by {@link URLDecoder#decode(String, String)};
     *         not expected for UTF-8
     */
    private String getCiphertextByMediaType(String bodyStr, MediaType mediaType) throws UnsupportedEncodingException {
        if (mediaType == null) {
            return "-1";
        }
        // includes() compares type/subtype only, so "application/json;charset=UTF-8" and
        // "multipart/form-data;boundary=..." still match; equals() also compares parameters
        if (MediaType.APPLICATION_JSON.includes(mediaType)) {
            JSONObject bodyJson = JSONObject.parseObject(bodyStr);
            String params = bodyJson == null ? null : bodyJson.getString("params");
            return params == null ? "-1" : params;
        }
        if (MediaType.APPLICATION_FORM_URLENCODED.includes(mediaType)
                || MediaType.MULTIPART_FORM_DATA.includes(mediaType)) {
            String encoded = urlSplit(bodyStr).get("params");
            return encoded == null ? "-1" : URLDecoder.decode(encoded, "UTF-8");
        }
        return "-1";
    }

    /**
     * Returns this filter's position in the gateway filter chain: {@code -2}.
     * <p>
     * The decryption filter is meant to run before the filters that read request parameters;
     * otherwise those filters fail when they read the parameters. Lower values run earlier, so a
     * filter with an order below {@code -2} still runs before this one. When adding filters,
     * adjust their order accordingly.
     */
    @Override
    public int getOrder() {
        final int decryptionFilterOrder = -2;
        return decryptionFilterOrder;
    }

    /**
     * Decrypts {@code params} according to the encryption type code.
     * <ul>
     *   <li>{@code "BA"}: Base64; the decoded bytes are interpreted as UTF-8</li>
     *   <li>{@code "AE"}: {@code AESUtil.aesDecrypt}</li>
     *   <li>{@code "SM"}: {@code Sm4Util.decryptEcb} (SM4, ECB mode)</li>
     * </ul>
     *
     * @param type   encryption type code (value of the {@code encryptionType} header)
     * @param params ciphertext; {@code null} is treated like an unknown type
     * @return the plaintext, or {@code "-1"} for an unknown type or {@code null} ciphertext
     * @throws Exception whatever the underlying decoder throws (e.g. IllegalArgumentException for
     *                   invalid Base64)
     */
    private String decodeParamsBytype(String type, String params) throws Exception {
        if (params == null) {
            return "-1";
        }
        if ("BA".equals(type)) {
            // explicit UTF-8: the platform default charset may differ between hosts
            return new String(Base64.getDecoder().decode(params), StandardCharsets.UTF_8);
        } else if ("AE".equals(type)) {
            return AESUtil.aesDecrypt(params);
        } else if ("SM".equals(type)) {
            return Sm4Util.decryptEcb(params);
        } else {
            // unknown type: callers treat "-1" as "leave the request unchanged"
            return "-1";
        }
    }

    /**
     * Reads the request body as a UTF-8 string.
     * <p>
     * The body is subscribed synchronously, and the result is returned as soon as
     * {@code subscribe} returns. Only chunks emitted during the subscription are captured; if no
     * chunk has been emitted by then, this returns {@code null}. Captured chunks are concatenated
     * before decoding, so bodies spread over several buffers and multi-byte characters split
     * across buffers are handled correctly.
     * NOTE(review): a non-blocking, complete alternative is to compose
     * {@code DataBufferUtils.join(body)} into the filter's reactive chain.
     *
     * @param serverHttpRequest request whose body is read (and consumed)
     * @return the body text, or {@code null} if nothing was received synchronously
     */
    private String resolveBodyFromRequest(ServerHttpRequest serverHttpRequest) {
        Flux<DataBuffer> body = serverHttpRequest.getBody();

        AtomicReference<ByteArrayOutputStream> bodyRef = new AtomicReference<>();
        body.subscribe(buffer -> {
            try {
                byte[] chunk = new byte[buffer.readableByteCount()];
                buffer.read(chunk);
                // append instead of overwrite: the body may arrive in several buffers
                bodyRef.updateAndGet(out -> out != null ? out : new ByteArrayOutputStream())
                        .write(chunk, 0, chunk.length);
            } finally {
                DataBufferUtils.release(buffer);
            }
        });
        ByteArrayOutputStream received = bodyRef.get();
        return received == null ? null : new String(received.toByteArray(), StandardCharsets.UTF_8);
    }

    /**
     * Copies the UTF-8 encoding of {@code value} into a newly allocated buffer created by a
     * {@link NettyDataBufferFactory} over {@code ByteBufAllocator.DEFAULT}.
     *
     * @param value text to copy
     * @return a buffer whose readable bytes are the UTF-8 encoding of {@code value}
     */
    private DataBuffer stringBuffer(String value) {
        byte[] utf8Bytes = value.getBytes(StandardCharsets.UTF_8);
        return new NettyDataBufferFactory(ByteBufAllocator.DEFAULT)
                .allocateBuffer(utf8Bytes.length)
                .write(utf8Bytes);
    }

    /**
     * Parses the key/value pairs of a URL query string or form body.
     * For example, {@code "Action=del&id=123"} yields {@code Action -> del} and {@code id -> 123}.
     * <p>
     * Only the first {@code '='} of a pair separates key from value, so values that themselves
     * contain {@code '='} (such as unencoded Base64 padding) are kept intact. A key without a value
     * is stored with an empty string. Empty segments (e.g. from {@code "&&"}) are skipped. Keys and
     * values are returned raw, without URL-decoding.
     *
     * @param params query string or form body; may be {@code null}
     * @return mutable map of parameter names to raw values; empty when {@code params} is {@code null}
     */
    private static Map<String, String> urlSplit(String params) {
        Map<String, String> mapRequest = new HashMap<>();
        if (params == null) {
            return mapRequest;
        }
        for (String pair : params.split("&")) {
            // limit 2: split only on the first '=' so the value may itself contain '='
            String[] keyValue = pair.split("=", 2);
            if (keyValue.length > 1) {
                mapRequest.put(keyValue[0], keyValue[1]);
            } else if (!keyValue[0].isEmpty()) {
                // key without a value: stored with an empty string
                mapRequest.put(keyValue[0], "");
            }
        }
        return mapRequest;
    }

    /**
     * Manual demo: round-trips a sample parameter string through each supported scheme
     * (Base64, {@code AESUtil}, {@code Sm4Util}) and logs the encoded and decoded values.
     * Strings are converted with UTF-8 explicitly rather than the platform default charset.
     */
    public static void main(String[] args) throws Exception {
        String params = "accountName=superadmin&password=123123";
        // Base64
        String encodeParams = Base64.getEncoder().encodeToString(params.getBytes(StandardCharsets.UTF_8));
        logger.info("BASE64 encodeParams: " + encodeParams);
        String decoderParams = new String(Base64.getDecoder().decode(encodeParams), StandardCharsets.UTF_8);
        logger.info("BASE64 decoderParams: " + decoderParams);

        // AES
        String aesEncodeParams = AESUtil.aesEncrypt(params);
        logger.info("aesEncodeParams: " + aesEncodeParams);
        String aesdecodeParams = AESUtil.aesDecrypt(aesEncodeParams);
        logger.info("aesdecodeParams: " + aesdecodeParams);

        // SM4 (Chinese national standard cipher), ECB mode
        String sm4EncodeParams = Sm4Util.encryptEcb(params);
        logger.info("sm4EncodeParams: " + sm4EncodeParams);
        String sm4DecodeParams = Sm4Util.decryptEcb(sm4EncodeParams);
        logger.info("sm4DecodeParams: " + sm4DecodeParams);
    }

 

posted @ 2021-09-15 15:01  CodingPanda  阅读(1334)  评论(0)  编辑  收藏  举报