1,MyRequestWrapper
对Request进行包装:请求体的输入流只能读取一次,若拦截器先读取了json入参的输入流,controller就获取不到参数了。包装类在构造时把请求体缓存到内存中,之后每次调用getInputStream()/getReader()都从缓存重新读取。
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.Charset;
/**
 * Request wrapper that buffers the request body in memory so it can be read more than once.
 *
 * <p>A servlet request's input stream can normally be consumed only once. This wrapper reads the
 * whole body eagerly in its constructor, byte for byte, and serves every {@link #getInputStream()}
 * and {@link #getReader()} call from that buffer. An interceptor can therefore inspect the body
 * and the controller can still bind it afterwards.
 *
 * <p>The entire body is held in memory, so avoid wrapping requests with very large bodies
 * (e.g. file uploads).
 */
public class MyRequestWrapper extends HttpServletRequestWrapper {

    /** Charset used by {@link #getReader()} to decode the buffered body. */
    private static final Charset UTF_8 = Charset.forName("UTF-8");

    /** Chunk size used while copying the original input stream. */
    private static final int BUFFER_SIZE = 4096;

    /** The raw, unmodified request body (line breaks and bytes preserved). */
    private final byte[] body;

    /**
     * Wraps {@code request}, reading its entire body into memory.
     *
     * @param request the request to wrap; its input stream is fully consumed and closed
     * @throws IOException if reading the request body fails
     */
    public MyRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        body = readBody(request);
    }

    /**
     * Returns a reader over the buffered body, decoded as UTF-8.
     */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream(), UTF_8));
    }

    /**
     * Returns a fresh stream over the buffered body; each call starts from the first byte.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(body);
        return new ServletInputStream() {
            @Override
            public int read() {
                return byteArrayInputStream.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                // Bulk read avoids one virtual call per byte when converters read the body.
                return byteArrayInputStream.read(b, off, len);
            }

            @Override
            public boolean isFinished() {
                return byteArrayInputStream.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The data is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking (async) reading is not supported by this wrapper; the listener is ignored.
            }
        };
    }

    /**
     * Reads the complete body of {@code request} as raw bytes and closes its input stream.
     *
     * @param request the request whose body is read
     * @return the body bytes, never {@code null} (empty when there is no body)
     * @throws IOException if reading from the request input stream fails
     */
    private static byte[] readBody(ServletRequest request) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        try (InputStream in = request.getInputStream()) {
            byte[] buffer = new byte[BUFFER_SIZE];
            int n;
            while ((n = in.read(buffer)) != -1) {
                out.write(buffer, 0, n);
            }
        }
        return out.toByteArray();
    }

    /**
     * Returns a copy of the buffered request body.
     *
     * @return the raw body bytes; modifying the returned array does not affect this wrapper
     */
    public byte[] getBody() {
        return body.clone();
    }
}
RequestWrapperFilter
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
/**
 * Servlet filter that wraps each HTTP request in a {@link MyRequestWrapper}. The request body can
 * then be read by an interceptor and still be read again by the controller.
 *
 * <p>Multipart and URL-encoded form requests are passed through unwrapped. Buffering their body
 * would drain the input stream before the servlet container parses request parameters or parts
 * from it, leaving {@code getParameter()} and file uploads empty. Non-HTTP requests are passed
 * through unchanged as well.
 */
@Slf4j
@Component
@WebFilter(urlPatterns = "/*", filterName = "requestWrapperFilter")
public class RequestWrapperFilter implements Filter {

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        // debug rather than info: this runs for every single request
        log.debug("enter requestWrapperFilter doFilter");
        if (!(request instanceof HttpServletRequest) || isFormOrMultipart(request)) {
            chain.doFilter(request, response);
            return;
        }
        MyRequestWrapper requestWrapper = new MyRequestWrapper((HttpServletRequest) request);
        chain.doFilter(requestWrapper, response);
    }

    /**
     * Returns true when the Content-Type is multipart or URL-encoded form data, i.e. a body the
     * container parses into parameters/parts itself; false when no Content-Type is present.
     */
    private static boolean isFormOrMultipart(ServletRequest request) {
        String contentType = request.getContentType();
        return contentType != null
                && (startsWithIgnoreCase(contentType, "multipart/")
                    || startsWithIgnoreCase(contentType, "application/x-www-form-urlencoded"));
    }

    /** Locale-independent, case-insensitive prefix test. */
    private static boolean startsWithIgnoreCase(String s, String prefix) {
        return s.regionMatches(true, 0, prefix, 0, prefix.length());
    }
}
2,拦截器
这块需要说明一下:
1,如果是GET请求,验签思路是:取域名后的接口路径加上参数的QueryString(即URI?QueryString),计算md5(appId+appSecret+URI?QueryString),拼接顺序可以自己约定。这样就不需要让客户端再对参数进行字典排序了:排序的目的只是保证服务端和客户端拼出的参数串顺序一致,而服务端直接使用客户端发来的原始QueryString,两端顺序天然一致。
2,如果是POST请求,若传递的是json类型,拦截器里直接取出json请求体bodyJson,计算md5(appId+appSecret+bodyJson)(若传了timestamp,再把timestamp拼接在末尾);若不是传json,而是form表单传参,此时客户端和服务端都需要对参数进行字典排序,从而保证两端拼出的参数串是一致的。
3,为了提升安全性:
- 可以添加timestamp参数(格式为yyyy-MM-dd HH:mm:ss),服务端计算该时间与服务器当前时间的差值,超过约定的有效期(如5分钟)即判定sign已失效。
- 还可以添加nonce随机字符串参数,让客户端生成一个uuid字符串传过来,服务端在redis里以该字符串为key执行incr,并把过期时间设置得略大于sign的有效期;拦截器判断如果incr后的值>1,则判定为重复调用并拒绝。这样一个sign就只能用一次,从而防止重放攻击。
import cn.hutool.core.date.DatePattern;
import cn.hutool.core.date.DateTime;
import cn.hutool.core.date.DateUnit;
import cn.hutool.core.date.DateUtil;
import cn.hutool.core.map.MapUtil;
import cn.hutool.json.JSON;
import cn.hutool.json.JSONUtil;
import cn.xxx.data.common.anno.CheckSign;
import cn.xxx.data.common.constants.SignConstant;
import cn.xxx.data.common.enums.ResultEnum;
import cn.xxx.data.common.vo.Result;
import com.google.common.collect.Maps;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.codec.digest.DigestUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.*;
/**
 * Verifies the request signature for handler methods annotated with {@link CheckSign}.
 *
 * <p>The client sends {@code appId}, {@code sign} and, optionally, {@code timestamp} as request
 * headers. The expected signature is the upper-case MD5 hex digest of:
 * <ul>
 *   <li>GET / HEAD: {@code appId + appSecret + requestURI [+ "?" + queryString]}</li>
 *   <li>every other method: {@code appId + appSecret + body [+ timestamp]}</li>
 * </ul>
 * When a timestamp is supplied it must match {@code DatePattern.NORM_DATETIME_PATTERN} and lie
 * within {@link SignConstant#expireTimeSeconds} of the server clock.
 *
 * <p>NOTE(review): the GET/HEAD signature does not cover the timestamp. For those requests the
 * timestamp header can be replaced without invalidating the signature. Closing this requires a
 * client-side protocol change. Likewise, query parameters of non-GET requests are not signed.
 */
@Slf4j
@Component
public class SignVerifyInterceptor extends HandlerInterceptorAdapter {

    private static final Charset UTF_8 = Charset.forName("UTF-8");

    @Autowired
    private ClientAppProperties clientAppProperties;

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        if (!(handler instanceof HandlerMethod)) {
            return true;
        }
        HandlerMethod handlerMethod = (HandlerMethod) handler;
        CheckSign checkSign = handlerMethod.getMethodAnnotation(CheckSign.class);
        if (Objects.isNull(checkSign)) {
            return true;
        }
        String appId = request.getHeader(SignConstant.APP_ID);
        String sign = request.getHeader(SignConstant.SIGNATURE_STR);
        String timestamp = request.getHeader(SignConstant.TIME_STAMP);
        log.info("signVerify request param,appId:{},sign:{},timestamp:{}", appId, sign, timestamp);
        if (StringUtils.isBlank(appId)) {
            log.warn("signVerify: appId is blank");
            response(response, Result.error(ResultEnum.PARAM_ERROR.getCode(), SignConstant.LOST_PARAM_APPID));
            return false;
        }
        String appSecret = getAppSecret(appId);
        if (StringUtils.isBlank(appSecret)) {
            log.warn("signVerify:appSecret is blank,appId not exist ");
            response(response, Result.error(ResultEnum.PARAM_ERROR.getCode(), SignConstant.NO_EXIST_APPID));
            return false;
        }
        if (StringUtils.isBlank(sign)) {
            log.warn("signVerify:sign is blank ");
            response(response, Result.error(ResultEnum.PARAM_ERROR.getCode(), SignConstant.LOST_PARAM_SIGN));
            return false;
        }
        if (StringUtils.isNotBlank(timestamp)) {
            // Parse separately from the timeout check so that a failure while writing the
            // timeout response is not misreported as a timestamp format error.
            DateTime reqDateTime;
            try {
                reqDateTime = DateUtil.parse(timestamp, DatePattern.NORM_DATETIME_PATTERN);
            } catch (Exception e) {
                reqDateTime = null;
            }
            if (reqDateTime == null) {
                log.warn("signverify ,解析timestamp失败,入参 timestamp:{}", timestamp);
                response(response, Result.error(ResultEnum.PARAM_ERROR.getCode(), SignConstant.TIMESTAMP_FMT_ERROR));
                return false;
            }
            if (signIsTimeout(reqDateTime)) {
                log.warn("signverify ,sign过期,超时时间(秒数):{}", SignConstant.expireTimeSeconds);
                response(response, Result.error(ResultEnum.PARAM_ERROR.getCode(), SignConstant.SIGN_TIMEOUT));
                return false;
            }
        }

        String method = request.getMethod();
        String signStr;
        if ("GET".equals(method) || "HEAD".equals(method)) {
            signStr = appId + appSecret + getGetSignString(request);
        } else {
            // Every other method (POST, PUT, PATCH, DELETE, ...) is verified against the body;
            // previously only POST was checked and the rest were let through unverified.
            String bodyJson = readBody(request);
            log.debug("signVerifyInterceptor bodyJson:{}", bodyJson);
            signStr = appId + appSecret + bodyJson;
            if (StringUtils.isNotBlank(timestamp)) {
                signStr += timestamp;
            }
        }
        String expectedSign = DigestUtils.md5Hex(signStr).toUpperCase(Locale.ROOT);
        // Never log signStr itself: it contains the app secret.
        log.debug("signverify expected sign:{}", expectedSign);
        if (!signMatches(expectedSign, sign)) {
            log.warn("signverify: sign mismatch, appId:{}", appId);
            response(response, Result.error(ResultEnum.SIGN_AUTH_FAIL));
            return false;
        }
        return true;
    }

    /**
     * Reads the request body as UTF-8. It reuses the buffer of an existing
     * {@link MyRequestWrapper}, or wraps the request otherwise.
     */
    private static String readBody(HttpServletRequest request) throws IOException {
        MyRequestWrapper wrapper = request instanceof MyRequestWrapper
                ? (MyRequestWrapper) request
                : new MyRequestWrapper(request);
        return new String(wrapper.getBody(), UTF_8);
    }

    /**
     * Case-insensitive comparison of the expected and supplied signatures in constant time, so
     * response timing does not reveal how many leading characters matched.
     */
    private static boolean signMatches(String expectedUpperHex, String supplied) {
        byte[] expected = expectedUpperHex.getBytes(UTF_8);
        byte[] actual = supplied.toUpperCase(Locale.ROOT).getBytes(UTF_8);
        return java.security.MessageDigest.isEqual(expected, actual);
    }

    /** Returns whether {@code reqDateTime} is further than the allowed window from now (either direction). */
    private boolean signIsTimeout(DateTime reqDateTime) {
        return DateUtil.between(reqDateTime.toJdkDate(), new Date(), DateUnit.SECOND) > SignConstant.expireTimeSeconds;
    }

    /**
     * Builds the GET signing payload: the request URI, followed by {@code "?" + queryString}
     * when a non-blank query string is present.
     */
    public String getGetSignString(HttpServletRequest request) {
        String requestURI = request.getRequestURI();
        String queryString = request.getQueryString();
        if (StringUtils.isBlank(queryString)) {
            return requestURI;
        } else {
            return StringUtils.join(requestURI, "?", queryString);
        }
    }

    /** Looks up the configured secret for {@code appId}; {@code null} when the app is unknown. */
    private String getAppSecret(String appId) {
        return clientAppProperties.getApps().get(appId);
    }

    /**
     * Writes {@code result} as a UTF-8 JSON response body.
     *
     * @throws IOException if writing to the response fails
     */
    public static void response(HttpServletResponse response, Result<?> result) throws IOException {
        response.setContentType("application/json");
        response.setCharacterEncoding("utf-8");
        response.getWriter().write(JSONUtil.toJsonStr(result));
        response.getWriter().flush();
    }
}
// public static String getSortedString(HttpServletRequest request) {
// Map<String, String[]> parameterMap = request.getParameterMap();
//
// SortedMap<String, String[]> sortMap = new TreeMap<>();
// StringBuffer sbf = new StringBuffer();
//
// if(MapUtil.isNotEmpty(parameterMap)){
// Iterator<Map.Entry<String, String[]>> iter = parameterMap.entrySet().iterator();
// while(iter.hasNext()){
// Map.Entry<String, String[]> entry = iter.next();
// String key = entry.getKey();
// String[] value = entry.getValue();
// if (StringUtils.equalsAny(key,SignConstant.APP_ID,SignConstant.SIGNATURE_STR)) {
// continue;
// }
// sortMap.put(key, value);
// }
//
// }
// Iterator<Map.Entry<String, String[]>> iter = sortMap.entrySet().iterator();
// while (iter.hasNext()) {
// Map.Entry<String, String[]> entry = iter.next();
// String key = entry.getKey();
// String[] value = entry.getValue();
// sbf.append(key + "=" + value + "&");
// }
// String sbfString = sbf.toString();
// System.out.println("排序后的字符串:" + sbfString.substring(0, sbfString.length() - 1));
// return sbfString.substring(0, sbfString.length() - 1);
// }
拦截器MvcConfig
/**
 * Spring MVC configuration that registers {@link SignVerifyInterceptor}.
 *
 * <p>No include or exclude path patterns are configured, so the interceptor is registered for
 * every mapped request.
 */
@Configuration
public class WebMvcConfig implements WebMvcConfigurer {

    @Autowired
    private SignVerifyInterceptor signVerifyInterceptor;

    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        // Registered without path patterns: applies to all mapped requests.
        registry.addInterceptor(this.signVerifyInterceptor);
    }
}
客户端应用配置
client:
apps:
xxx: ooo
配置类
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Configuration;
import java.util.HashMap;
import java.util.Map;
/**
 * Registered client applications, bound from the {@code client.apps} configuration map.
 */
@Data
@Configuration
@ConfigurationProperties(prefix = "client")
public class ClientAppProperties {

    /** Client appId -> signing secret; starts empty and is filled from configuration. */
    private Map<String, String> apps = new HashMap<>();
}
相关常量
/**
 * Constants for request-signature verification: request header names, the signature validity
 * window, and the error messages returned to clients.
 */
public class SignConstant {

    // ---- Request header names ----

    /** Header carrying the caller's application id. */
    public static final String APP_ID = "appId";

    /** Header carrying the request signature. */
    public static final String SIGNATURE_STR = "sign";

    /** Header carrying the request timestamp. */
    public static final String TIME_STAMP = "timestamp";

    // ---- Validity window ----

    /** How long a signature stays valid, in seconds (300 s = 5 minutes). */
    public static final long expireTimeSeconds = 5L * 60;

    // ---- Error messages ----

    /** Message for a missing appId. */
    public static final String LOST_PARAM_APPID = "appId不能为空";

    /** Message for a missing sign. */
    public static final String LOST_PARAM_SIGN = "sign不能为空";

    /** Message for a missing timestamp. */
    public static final String LOST_PARAM_TIME_STAMP = "timestamp不能为空";

    /** Message for a timestamp that cannot be parsed. */
    public static final String TIMESTAMP_FMT_ERROR = "timestamp参数格式错误";

    /** Message for an expired signature; the text hard-codes the 5-minute window above. */
    public static final String SIGN_TIMEOUT = "sign失效,有效期5分钟";

    /** Message for a signature mismatch. */
    public static final String SIGN_FAIL = "验签失败";

    /** Message for an appId that is not configured. */
    public static final String NO_EXIST_APPID = "非法appId";
}
注解
import java.lang.annotation.*;
/**
 * Marks a handler method whose requests must carry a valid signature.
 *
 * <p>Applicable to methods only and retained at runtime, so it can be read reflectively from the
 * method while handling a request.
 */
@Documented
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface CheckSign {
}