SqlFilter

  badregex: (\b(exec|execute|insert|delete|update|drop|chr|mid|master|truncate|char|declare|sitename|xp_cmdshell|create|alter|grant|union)\b)

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.huawei.it.gts.ls.common.domain.LSResponse;
import com.huawei.it.gts.ls.constants.GlobalErrorCode;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.util.MultiValueMap;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
* @since 2020年03月25日
*/
@Slf4j
@Data
@Component
@ConditionalOnProperty(prefix = "security.config.sql.filter", name = "enabled", havingValue = "true")
public class SqlFilter implements GlobalFilter, Ordered {

// 过滤掉的sql关键字,可以手动添加
@Value("${security.config.sql.badregex}")
private String badStr;

@Value("${security.config.sql.whiteUrl}")
private String whiteUrl;//白名单接口资源

private final String prefix = "[/\\w]+/v[0-9]+[\\.0-9]+";//URL验证正则表达式前缀部分,如:/api_gateway/uams_msa/v0.1
private final String replaceRegex = "\\{\\w+\\}";//权限资源URL正则匹配模式,如:将"{id}",替换掉java正则表达式
private final String replacement = "[\\\\w-%.]+";//java正则表达式

/**
* @author wwx798540
* @description 使用spring cloud gateway 重写filter
* @date 2020/4/9
* @param: [exchange, chain]
* @return reactor.core.publisher.Mono<java.lang.Void>
**/
@SuppressWarnings("unchecked")
@Override
public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
ServerHttpRequest request = exchange.getRequest();
ServerHttpResponse response = exchange.getResponse();

HttpHeaders headers = request.getHeaders();

Map<String, Object> objectMap = new HashMap<>();

MultiValueMap<String, String> queryParams = request.getQueryParams();
Set<String> paramsSet = queryParams.keySet();
for (String param : paramsSet) {
objectMap.put(param + "Param", request.getQueryParams().getFirst(param));
}
String contentType = headers.getFirst("Content-Type");
// 只考虑contentType是json格式的
if (contentType!=null && contentType.contains("application/json")) {
Object cachedRequestBodyObject = exchange.getAttribute("cachedRequestBodyObject");
if (null != cachedRequestBodyObject) {
String body = cachedRequestBodyObject.toString();
if (body.startsWith("[")) {
List<Map<String, Object>> mapList = (List<Map<String, Object>>) cachedRequestBodyObject;
String jsonStr = JSON.toJSONString(mapList);
JSONArray jsonArray = JSONArray.parseArray(jsonStr);
objectMap.put("jsonArray", jsonArray);
} else {
Map<String, Object> requestBodylist = (Map<String, Object>) cachedRequestBodyObject;
objectMap.putAll(requestBodylist);
}
}
}

// 获取URI
String uri = request.getPath().toString();
log.info("SqlFilter Request URI:{}", uri);

// 从配置中获取白名单,白名单中接口不校验sql关键字
List<String> whiteUrlList = Arrays.asList(whiteUrl.split(","));
log.info("SqlFilter WhiteUrlList:{}", whiteUrlList.toString());
if (CollectionUtils.isNotEmpty(whiteUrlList)) {
for (String whiteUrlStr : whiteUrlList){
String[] splitUrls = whiteUrlStr.split(":");// 将http
// request请求方法和URL分割;如:GET:/apps/v0.1/action
String newUrl = splitUrls[1].replaceAll(replaceRegex, replacement) + "$";
if (!StringUtils.startsWith(newUrl, "/")) {
newUrl = "/" + newUrl;
}
String regex = prefix + newUrl;// 将正则前缀部分+后续部分拼接成完整的URI校验正则表达式
Matcher matcher = Pattern.compile(regex).matcher(uri);
if (matcher.matches() && StringUtils.equalsIgnoreCase(splitUrls[0], Objects.requireNonNull(request.getMethod()).toString())) {
// 通过交给下一个过滤器
return chain.filter(exchange);
} else {
//有sql关键字,返回400
String sqlKeywords = sqlValidate(objectMap.toString());
if (StringUtils.isNotEmpty(sqlKeywords)) {
log.info("-----请求参数中含有sql关键字:{}", sqlKeywords);
return outResponse(response, sqlKeywords);
} else {
return chain.filter(exchange);
}
}
}
} else {
//有sql关键字,返回400
String sqlKeywords = sqlValidate(objectMap.toString());
if (StringUtils.isNotEmpty(sqlKeywords)) {
log.info("-----请求参数中含有sql关键字:{}", sqlKeywords);
return outResponse(response, sqlKeywords);
} else {
return chain.filter(exchange);
}
}
return chain.filter(exchange);
}

private Mono<Void> outResponse(ServerHttpResponse response, String sqlKeywords) {
try {
LSResponse<Object> lsResponse = new LSResponse<>();
lsResponse.setFaild(GlobalErrorCode.E1001_016);
lsResponse.setStatus(Integer.parseInt(GlobalErrorCode.E1001_016));
lsResponse.setMsg("请求参数中含有sql关键字:" + sqlKeywords);
// 对象转JSON转byte数组
ObjectMapper objectMapper = new ObjectMapper();
byte[] data = objectMapper.writeValueAsBytes(lsResponse);
// 写入响应头
DataBuffer buffer = response.bufferFactory().wrap(data);
response.setStatusCode(HttpStatus.BAD_REQUEST);
response.getHeaders().add(HttpHeaders.CONTENT_TYPE, "application/json;charset=UTF-8");
return response.writeWith(Mono.just(buffer));
} catch (JsonProcessingException e) {
log.error("lsResponse.parse.jsonException");
}
return null;
}

// 放在AuthorityFilter之后
@Override
public int getOrder() {
return 200;
}

// 效验
private String sqlValidate(String str) {
//将正则表达式进行编译
Pattern p = Pattern.compile(badStr, Pattern.CASE_INSENSITIVE);
//要校验的字节序列
Matcher m = p.matcher(str);
int count = 0;
List<String> keywordList = new ArrayList<>();
while (m.find()) {
String keyword = m.group(0);
if (!keywordList.contains(keyword)) {
keywordList.add(keyword);
}
count++;
if (count == 3) {
break;
}
}
// 入参中出现三次及以上的sql关键字则拦截
if (3 == count) {
return keywordList.toString();
}
return null;
}

}
posted @ 2020-12-28 20:31  dahuinihao  阅读(884)  评论(0)  编辑  收藏  举报