Java xss攻击拦截,Java CSRF跨站点伪造请求拦截
Java xss攻击拦截,Java CSRF跨站点伪造请求拦截
================================
©Copyright 蕃薯耀 2021-05-07
https://www.cnblogs.com/fanshuyao/
一、Java xss攻击拦截
XssFilter过滤器
import java.io.IOException; import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; public class XssFilter implements Filter{ FilterConfig filterConfig = null; public void init(FilterConfig filterConfig) throws ServletException { this.filterConfig = filterConfig; } public void destroy() { this.filterConfig = null; } public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { chain.doFilter(new XssWrapper((HttpServletRequest) request), response); } }
XssWrapper类:
import java.io.BufferedReader; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.nio.charset.Charset; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.Map; import javax.servlet.ReadListener; import javax.servlet.ServletInputStream; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import org.apache.commons.lang3.StringUtils; import org.apache.commons.text.StringEscapeUtils; import com.szpl.csgx.utils.JsonUtil; public class XssWrapper extends HttpServletRequestWrapper { public static final String JSON_TYPE = "application/json"; public static final String CONTENT_TYPE = "Content-Type"; public static final String CHARSET = "UTF-8"; private String mBody; HttpServletRequest originalRequest = null; public XssWrapper(HttpServletRequest request) throws IOException { super(request); // 将body数据存储起来 originalRequest = request; setRequestBody(request.getInputStream()); } /** * 获取最原始的request。已经被getInputStream()了。 * * @return */ public HttpServletRequest getOrgRequest() { return originalRequest; } /** * 获取最原始的request的静态方法。已经被getInputStream()了。 * * @return */ public static HttpServletRequest getOriginalRequest(HttpServletRequest req) { if (req instanceof XssWrapper) { return ((XssWrapper) req).getOrgRequest(); } return req; } @Override public String getHeader(String name) { String value = super.getHeader(name); if(StringUtils.isBlank(value)) { return value; } return StringEscapeUtils.escapeHtml4(value); } @Override public String getQueryString() { return StringUtils.isBlank(super.getQueryString()) ? 
"" : StringEscapeUtils.escapeHtml4(super.getQueryString()); } @Override public String getParameter(String name) { String value = super.getParameter(name); if(StringUtils.isBlank(value)) { return value; } return StringEscapeUtils.escapeHtml4(value); } @Override public String[] getParameterValues(String name) { String[] values = super.getParameterValues(name); if (values == null) { return values; } for (int i=0; i < values.length; i++) { values[i] = StringEscapeUtils.escapeHtml4(values[i]); } return values; } @Override public Map<String, String[]> getParameterMap() { Map<String, String[]> map = new LinkedHashMap<String, String[]>(); Map<String, String[]> parameterMap = super.getParameterMap(); if(parameterMap == null) { return super.getParameterMap(); } for (String key : parameterMap.keySet()) { String[] values = parameterMap.get(key); if(values != null && values.length > 0) { for (int i = 0; i < values.length; i++) { values[i] = StringEscapeUtils.escapeHtml4(values[i]); } } map.put(key, values); } return map; } private void setRequestBody(InputStream stream) { String line = ""; StringBuilder body = new StringBuilder(); // 读取POST提交的数据内容 BufferedReader reader = new BufferedReader(new InputStreamReader(stream, Charset.forName(CHARSET))); try { while ((line = reader.readLine()) != null) { body.append(line); } } catch (IOException e) { e.printStackTrace(); } mBody = body.toString(); if(StringUtils.isBlank(mBody)) {//为空时,直接返回 return; } @SuppressWarnings("unchecked") Map<String,Object> map= JsonUtil.string2Obj(mBody, Map.class); Map<String,Object> resultMap=new HashMap<>(map.size()); for(String key : map.keySet()){ Object val = map.get(key); if(map.get(key) instanceof String){ resultMap.put(key, StringEscapeUtils.escapeHtml4(val.toString())); }else{ resultMap.put(key, val); } } mBody = JsonUtil.obj2String(resultMap); } @Override public BufferedReader getReader() throws IOException { return new BufferedReader(new InputStreamReader(getInputStream())); } @Override public 
ServletInputStream getInputStream() throws IOException { if(!JSON_TYPE.equalsIgnoreCase(super.getHeader(CONTENT_TYPE))) {//非json类型,直接返回 return super.getInputStream(); } if(StringUtils.isBlank(mBody)) {//为空时,直接返回 return super.getInputStream(); } final ByteArrayInputStream bais = new ByteArrayInputStream(mBody.getBytes(CHARSET)); return new ServletInputStream() { @Override public int read() throws IOException { return bais.read(); } @Override public boolean isFinished() { return false; } @Override public boolean isReady() { return false; } @Override public void setReadListener(ReadListener listener) { } }; } }
二、Java CSRF跨站点伪造请求拦截
import java.io.IOException; import java.net.URL; import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component; import cn.hutool.json.JSONUtil; /** * CSRF跨站点伪造请求拦截 */ @Component public class CsrfFilter implements Filter { //后台日志打印 private Logger log = LoggerFactory.getLogger(CsrfFilter.class); //跨站点请求白名单,通过英文逗号分隔。在application.properties配置 @Value("${csrf.white.paths}") private String[] csrfWhitePaths; //跨站点请求域名白名单,通过英文逗号分隔。在application.properties配置 @Value("${csrf.white.domains}") private String[] csrfWhiteDomains; @Override public void init(FilterConfig filterConfig) throws ServletException { log.info("init……"); } public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) throws IOException, ServletException { HttpServletRequest req = (HttpServletRequest) request; HttpServletResponse res = (HttpServletResponse) response; String referer = req.getHeader("Referer"); if (!StringUtils.isBlank(referer)) { //log.info("referer = " + referer); URL refererUrl = new URL(referer); String refererHost = refererUrl.getHost(); int refererPort = refererUrl.getPort(); String refererHostAndPort; if(refererPort == -1) { refererHostAndPort = refererHost; }else { refererHostAndPort = refererHost + ":" + refererPort; } //log.info("refererHostAndPort = " + refererHostAndPort); //log.info("refererHost = " + refererHost); String requestURL = req.getRequestURL().toString(); //log.info("requestURL = " + requestURL); URL urlRequest = new URL(requestURL); String requestHost = urlRequest.getHost(); int requestPort = 
urlRequest.getPort(); String requestHostAndPort; if(requestPort == -1) { requestHostAndPort = requestHost; }else { requestHostAndPort = requestHost + ":" + requestPort; } //log.info("requestHost = " + requestHost); if(requestHostAndPort.equalsIgnoreCase(refererHostAndPort)) {//同域名和同端口,即同一个域的系统,通过 filterChain.doFilter(request, response); }else { if(isCsrfWhiteDomains(refererHostAndPort)) {//域名白名单 filterChain.doFilter(request, response); return; } String path = urlRequest.getPath(); log.info("path = " + path); String actionPath = path.replaceAll(request.getServletContext().getContextPath(), ""); log.info("actionPath = " + actionPath); if(isCsrfWhitePaths(actionPath)) {//访问路径白名单 filterChain.doFilter(request, response); return; } log.warn("csrf跨站点伪造请求已经被拦截:"); log.warn("requestURL = " + requestURL); log.warn("referer = " + referer); res.sendRedirect(req.getContextPath() + "/illegal"); return; } }else{ filterChain.doFilter(request, response); } } @Override public void destroy() { log.info("destroy……"); } /** * 本系统不拦截的路径白名单 * @param path * @return */ private boolean isCsrfWhitePaths(String path) { if(csrfWhitePaths != null && csrfWhitePaths.length > 0) { for (String csrfWhitePath : csrfWhitePaths) { if(!StringUtils.isBlank(csrfWhitePath)) { if(csrfWhitePath.equals(path)) { log.info("跨站点请求所有路径白名单:csrfWhitePaths = " + JSONUtil.toJsonStr(csrfWhitePaths)); log.info("符合跨站点请求路径白名单:path = " + path); return true; } } } } return false; } /** * 不拦截外部系统的域名(可带端口)白名单 * @param path * @return */ private boolean isCsrfWhiteDomains(String refererHostAndPort) { if(csrfWhiteDomains != null && csrfWhiteDomains.length > 0) { for (String csrfWhiteDomain : csrfWhiteDomains) { if(!StringUtils.isBlank(csrfWhiteDomain)) { if(csrfWhiteDomain.equals(refererHostAndPort)) { log.info("跨站点请求所有【域名】]白名单:csrfWhiteDomains = " + JSONUtil.toJsonStr(csrfWhiteDomains)); log.info("符合跨站点请求【域名】白名单:refererHost = " + refererHostAndPort); return true; } } } log.info("跨站点请求非法【域名】:refererHost = " + 
refererHostAndPort); } return false; } }
配置文件:
#跨站点请求域名白名单,通过英文逗号分隔。如(abc.com:9010,abc.org:9010)
csrf.white.domains=www.abc.com,abc.cn:9011
#跨站点伪造请求
#跨站点请求路径白名单,通过英文逗号分隔。如(/illegal,/illegal2,/illegal3)
csrf.white.paths=/illegal
三、SpringBoot注册过滤器
import javax.servlet.Filter; import org.springframework.boot.web.servlet.FilterRegistrationBean; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import com.szpl.csgx.security.CsrfFilter; import com.szpl.csgx.security.XssFilter; /** * 使用配置方式开发Filter,否则其中的自动注入无效 * */ @Configuration public class HttpFilterConfig { /** * xss攻击过滤器 * @return */ @Bean public FilterRegistrationBean<Filter> xssFilter() { FilterRegistrationBean<Filter> XssBean = new FilterRegistrationBean<>(new XssFilter()); XssBean.setName("xssFilter"); XssBean.addUrlPatterns("/*"); XssBean.setOrder(4); return XssBean; } /** * csrf跨站点欺骗过滤器 */ @Bean public FilterRegistrationBean<Filter> csrfFilterRegistrationBean(CsrfFilter csrfFilter) { FilterRegistrationBean<Filter> registration = new FilterRegistrationBean<Filter>(); registration.setFilter(csrfFilter);//这里不能直接使用New,因为直接New出来的东西,CsrfFilter不受spring管理,不能通过@value注入变量 registration.addUrlPatterns("/*"); registration.setName("csrfFilter"); registration.setOrder(0); return registration; } }