Java xss攻击拦截,Java CSRF跨站点伪造请求拦截

Java xss攻击拦截,Java CSRF跨站点伪造请求拦截

 

================================

©Copyright 蕃薯耀 2021-05-07

https://www.cnblogs.com/fanshuyao/

 

一、Java XSS攻击拦截

XssFilter过滤器

复制代码
import java.io.IOException;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;

/**
 * Servlet filter that replaces each HTTP request with an {@link XssWrapper}, so that downstream
 * code reads HTML-escaped headers, parameters and JSON bodies.
 *
 * <p>Non-HTTP requests, and requests that are already an {@code XssWrapper} (for example when the
 * filter is applied twice), are passed through unchanged to avoid failing casts and double escaping.
 */
public class XssFilter implements Filter {
    FilterConfig filterConfig = null;

    public void init(FilterConfig filterConfig) throws ServletException {
        this.filterConfig = filterConfig;
    }

    public void destroy() {
        this.filterConfig = null;
    }

    /**
     * Wraps HTTP requests in an {@link XssWrapper} and continues the chain.
     *
     * @throws IOException if the wrapper fails to read a JSON request body
     */
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        if (!(request instanceof HttpServletRequest) || request instanceof XssWrapper) {
            // Not HTTP (cannot be wrapped) or already wrapped (wrapping again would escape twice).
            chain.doFilter(request, response);
            return;
        }
        chain.doFilter(new XssWrapper((HttpServletRequest) request), response);
    }
}
复制代码

 

XssWrapper类:

复制代码
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringEscapeUtils;

import com.szpl.csgx.utils.JsonUtil;


/**
 * {@link HttpServletRequestWrapper} that HTML-escapes (with {@code StringEscapeUtils.escapeHtml4})
 * header values, the query string, request parameters and the string values of JSON bodies.
 *
 * <p>Only requests whose Content-Type media type is {@value #JSON_TYPE} (parameters such as
 * {@code ;charset=UTF-8} are allowed) have their body read and buffered. Every other request keeps
 * its original, unread input stream, so form parameters and uploads still reach the application.
 */
public class XssWrapper extends HttpServletRequestWrapper {

    public static final String JSON_TYPE = "application/json";
    public static final String CONTENT_TYPE = "Content-Type";
    public static final String CHARSET = "UTF-8";

    /** Escaped JSON body; {@code null} when the body was not buffered (non-JSON request). */
    private String mBody;
    HttpServletRequest originalRequest = null;

    /**
     * Wraps {@code request}, buffering and escaping its body when it is JSON.
     *
     * @param request the request to wrap
     * @throws IOException if reading a JSON body fails
     */
    public XssWrapper(HttpServletRequest request) throws IOException {
        super(request);
        originalRequest = request;
        // Read the body only for JSON. Consuming any other body here would make form
        // parameters (getParameter) and multipart content unavailable downstream.
        if (isJsonContentType(request.getHeader(CONTENT_TYPE))) {
            setRequestBody(request.getInputStream());
        }
    }

    /**
     * Returns the wrapped, unescaped request. For JSON requests its input stream has already
     * been consumed by this wrapper.
     *
     * @return the original request
     */
    public HttpServletRequest getOrgRequest() {
        return originalRequest;
    }

    /**
     * Returns the original request behind {@code req} if it is an {@code XssWrapper},
     * otherwise {@code req} itself.
     *
     * @param req a possibly wrapped request
     * @return the unwrapped request
     */
    public static HttpServletRequest getOriginalRequest(HttpServletRequest req) {
        if (req instanceof XssWrapper) {
            return ((XssWrapper) req).getOrgRequest();
        }
        return req;
    }

    @Override
    public String getHeader(String name) {
        return escape(super.getHeader(name));
    }

    /** Returns the escaped query string, or {@code ""} (not {@code null}) when it is blank. */
    @Override
    public String getQueryString() {
        String queryString = super.getQueryString();
        return StringUtils.isBlank(queryString) ? "" : StringEscapeUtils.escapeHtml4(queryString);
    }

    @Override
    public String getParameter(String name) {
        return escape(super.getParameter(name));
    }

    @Override
    public String[] getParameterValues(String name) {
        return escapeAll(super.getParameterValues(name));
    }

    /**
     * Returns a new map with escaped copies of every value array. The container's own arrays are
     * never modified, so calling this repeatedly does not escape values twice.
     */
    @Override
    public Map<String, String[]> getParameterMap() {
        Map<String, String[]> parameterMap = super.getParameterMap();
        if (parameterMap == null) {
            return parameterMap;
        }
        Map<String, String[]> escaped = new LinkedHashMap<String, String[]>();
        for (Map.Entry<String, String[]> entry : parameterMap.entrySet()) {
            escaped.put(entry.getKey(), escapeAll(entry.getValue()));
        }
        return escaped;
    }

    /**
     * Reads the whole body from {@code stream} and stores an escaped copy in {@code mBody}.
     *
     * <p>When the body parses as a JSON object (via {@code JsonUtil.string2Obj(..., Map.class)}),
     * every string value, including those nested in objects and arrays, is HTML-escaped and the
     * object is re-serialized with {@code JsonUtil.obj2String}. Keys are left unchanged.
     *
     * @throws IOException if reading the stream fails
     */
    private void setRequestBody(InputStream stream) throws IOException {
        mBody = readFully(stream);
        if (StringUtils.isBlank(mBody)) { // nothing to escape
            return;
        }
        @SuppressWarnings("unchecked")
        Map<String, Object> map = JsonUtil.string2Obj(mBody, Map.class);
        if (map == null) {
            // NOTE(review): assumes JsonUtil.string2Obj returns null when the body is not a JSON
            // object (e.g. a top-level array); such bodies are forwarded unescaped - confirm.
            return;
        }
        mBody = JsonUtil.obj2String(escapeJsonValue(map));
    }

    /**
     * Returns a reader over the body: the buffered, escaped JSON body (decoded as {@value #CHARSET})
     * when one was buffered, otherwise the wrapped request's own reader.
     */
    @Override
    public BufferedReader getReader() throws IOException {
        if (mBody == null) {
            return super.getReader();
        }
        return new BufferedReader(new InputStreamReader(getInputStream(), Charset.forName(CHARSET)));
    }

    /**
     * Returns a fresh stream over the buffered, escaped JSON body, or the wrapped request's
     * stream when the body was not buffered (non-JSON request).
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        if (mBody == null) { // non-JSON: the original stream was never touched
            return super.getInputStream();
        }

        final ByteArrayInputStream buffered = new ByteArrayInputStream(mBody.getBytes(CHARSET));
        return new ServletInputStream() {
            @Override
            public int read() throws IOException {
                return buffered.read();
            }

            @Override
            public int read(byte[] b, int off, int len) throws IOException {
                return buffered.read(b, off, len);
            }

            @Override
            public int available() throws IOException {
                return buffered.available();
            }

            @Override
            public boolean isFinished() {
                return buffered.available() == 0;
            }

            @Override
            public boolean isReady() {
                // Data is fully in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener listener) {
                // Non-blocking (async) reads are not supported for the buffered body.
            }
        };
    }

    /** Returns {@code true} when the media type of {@code contentType} is {@value #JSON_TYPE}. */
    private static boolean isJsonContentType(String contentType) {
        if (StringUtils.isBlank(contentType)) {
            return false;
        }
        // Drop parameters such as ";charset=UTF-8" before comparing the media type.
        String mediaType = StringUtils.substringBefore(contentType, ";").trim();
        return JSON_TYPE.equalsIgnoreCase(mediaType);
    }

    /** Reads {@code stream} to the end as {@value #CHARSET} text, keeping line breaks. */
    private static String readFully(InputStream stream) throws IOException {
        InputStreamReader reader = new InputStreamReader(stream, Charset.forName(CHARSET));
        StringBuilder body = new StringBuilder();
        char[] buffer = new char[4096];
        int read;
        while ((read = reader.read(buffer)) != -1) {
            body.append(buffer, 0, read);
        }
        return body.toString();
    }

    /**
     * Recursively HTML-escapes strings inside a parsed JSON value. Maps and lists are copied;
     * other values (numbers, booleans, null) are returned as is.
     */
    private static Object escapeJsonValue(Object value) {
        if (value instanceof String) {
            return StringEscapeUtils.escapeHtml4((String) value);
        }
        if (value instanceof Map) {
            Map<?, ?> source = (Map<?, ?>) value;
            Map<Object, Object> escaped = new LinkedHashMap<Object, Object>();
            for (Map.Entry<?, ?> entry : source.entrySet()) {
                escaped.put(entry.getKey(), escapeJsonValue(entry.getValue()));
            }
            return escaped;
        }
        if (value instanceof List) {
            List<?> source = (List<?>) value;
            List<Object> escaped = new ArrayList<Object>(source.size());
            for (Object item : source) {
                escaped.add(escapeJsonValue(item));
            }
            return escaped;
        }
        return value;
    }

    /** Escapes a single value; null or blank values are returned unchanged. */
    private static String escape(String value) {
        if (StringUtils.isBlank(value)) {
            return value;
        }
        return StringEscapeUtils.escapeHtml4(value);
    }

    /** Returns an escaped copy of {@code values}, or {@code null} when {@code values} is null. */
    private static String[] escapeAll(String[] values) {
        if (values == null) {
            return null;
        }
        String[] escaped = values.clone();
        for (int i = 0; i < escaped.length; i++) {
            escaped[i] = StringEscapeUtils.escapeHtml4(escaped[i]);
        }
        return escaped;
    }
}
复制代码

 

 

二、Java CSRF跨站点伪造请求拦截

复制代码
import java.io.IOException;
import java.net.URL;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

import cn.hutool.json.JSONUtil;


/** 
 * CSRF跨站点伪造请求拦截
 */
@Component
/**
 * CSRF (cross-site request forgery) filter based on the {@code Referer} header.
 *
 * <p>A request is passed on when:
 * <ul>
 *   <li>it has no Referer,</li>
 *   <li>the Referer's host[:port] equals the request's host[:port] (case-insensitive),</li>
 *   <li>the Referer's host[:port] is listed in {@code csrf.white.domains}, or</li>
 *   <li>the request path without the context path is listed in {@code csrf.white.paths}.</li>
 * </ul>
 * Any other request is redirected to {@code <contextPath>/illegal}.
 *
 * <p>A port equal to the scheme's default (e.g. {@code :80} for http) is treated as omitted, so
 * whitelist entries are written as {@code host} or {@code host:port}.
 *
 * <p>NOTE(review): requests without a Referer are always allowed, so a client that suppresses
 * the Referer bypasses this check.
 */
@Component
public class CsrfFilter implements Filter {

    private static final Logger log = LoggerFactory.getLogger(CsrfFilter.class);

    // Whitelisted request paths (comma separated), configured in application.properties.
    @Value("${csrf.white.paths}")
    private String[] csrfWhitePaths;

    // Whitelisted Referer domains, optionally with port (comma separated), configured in application.properties.
    @Value("${csrf.white.domains}")
    private String[] csrfWhiteDomains;

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        log.info("init……");
    }

    /**
     * Passes the request on, or redirects it to {@code /illegal} when it is judged cross-site
     * (see the class documentation for the rules).
     */
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain)
            throws IOException, ServletException {
        HttpServletRequest req = (HttpServletRequest) request;
        HttpServletResponse res = (HttpServletResponse) response;
        String referer = req.getHeader("Referer");

        if (StringUtils.isBlank(referer)) {
            filterChain.doFilter(request, response);
            return;
        }

        String requestURL = req.getRequestURL().toString();
        URL urlRequest = new URL(requestURL);
        String requestHostAndPort = hostAndPort(urlRequest);
        // The Referer is client-controlled: a malformed value yields null and is treated as foreign
        // instead of failing the request with an exception.
        String refererHostAndPort = parseHostAndPort(referer);

        // Same host and port: the request comes from this system.
        if (refererHostAndPort != null && requestHostAndPort.equalsIgnoreCase(refererHostAndPort)) {
            filterChain.doFilter(request, response);
            return;
        }

        if (isCsrfWhiteDomains(refererHostAndPort)) {
            filterChain.doFilter(request, response);
            return;
        }

        String path = urlRequest.getPath();
        log.info("path = {}", path);
        String actionPath = stripContextPath(path, req.getContextPath());
        log.info("actionPath = {}", actionPath);

        if (isCsrfWhitePaths(actionPath)) {
            filterChain.doFilter(request, response);
            return;
        }

        log.warn("csrf跨站点伪造请求已经被拦截:");
        log.warn("requestURL = {}", requestURL);
        log.warn("referer = {}", referer);
        res.sendRedirect(req.getContextPath() + "/illegal");
    }

    @Override
    public void destroy() {
        log.info("destroy……");
    }

    /**
     * Returns {@code url}'s host, followed by {@code :port} only when the URL specifies a
     * non-default port.
     */
    private static String hostAndPort(URL url) {
        int port = url.getPort();
        // Treat an explicit default port (e.g. http://host:80) like an omitted one.
        if (port == -1 || port == url.getDefaultPort()) {
            return url.getHost();
        }
        return url.getHost() + ":" + port;
    }

    /** Parses {@code url} and returns its host[:port], or {@code null} if it is not a valid URL. */
    private static String parseHostAndPort(String url) {
        try {
            return hostAndPort(new URL(url));
        } catch (java.net.MalformedURLException e) {
            log.warn("invalid referer = {}", url);
            return null;
        }
    }

    /** Removes a leading {@code contextPath} from {@code path} (literal prefix, not a regex). */
    private static String stripContextPath(String path, String contextPath) {
        if (StringUtils.isEmpty(contextPath) || !path.startsWith(contextPath)) {
            return path;
        }
        return path.substring(contextPath.length());
    }

    /**
     * Checks whether {@code path} is in the configured path whitelist (exact match after trimming
     * each entry).
     *
     * @param path request path without the context path
     * @return {@code true} if the path is whitelisted
     */
    private boolean isCsrfWhitePaths(String path) {
        if (csrfWhitePaths == null) {
            return false;
        }
        for (String csrfWhitePath : csrfWhitePaths) {
            if (!StringUtils.isBlank(csrfWhitePath) && csrfWhitePath.trim().equals(path)) {
                log.info("跨站点请求所有路径白名单:csrfWhitePaths = {}", JSONUtil.toJsonStr(csrfWhitePaths));
                log.info("符合跨站点请求路径白名单:path = {}", path);
                return true;
            }
        }
        return false;
    }

    /**
     * Checks whether the Referer's host[:port] is in the configured domain whitelist. Hostnames
     * are compared case-insensitively, consistent with the same-origin comparison.
     *
     * @param refererHostAndPort the Referer's host[:port], or {@code null} if it was unparseable
     * @return {@code true} if the domain is whitelisted
     */
    private boolean isCsrfWhiteDomains(String refererHostAndPort) {
        if (csrfWhiteDomains == null || csrfWhiteDomains.length == 0) {
            return false;
        }
        for (String csrfWhiteDomain : csrfWhiteDomains) {
            if (!StringUtils.isBlank(csrfWhiteDomain) && csrfWhiteDomain.trim().equalsIgnoreCase(refererHostAndPort)) {
                log.info("跨站点请求所有【域名】白名单:csrfWhiteDomains = {}", JSONUtil.toJsonStr(csrfWhiteDomains));
                log.info("符合跨站点请求【域名】白名单:refererHost = {}", refererHostAndPort);
                return true;
            }
        }
        log.info("跨站点请求非法【域名】:refererHost = {}", refererHostAndPort);
        return false;
    }
}
复制代码

 

配置文件(application.properties):

#跨站点请求域名白名单,通过英文逗号分隔。如(abc.com:9010,abc.org:9010)
csrf.white.domains=www.abc.com,abc.cn:9011

 

#跨站点伪造请求
#跨站点请求路径白名单,通过英文逗号分隔。如(/illegal,/illegal2,/illegal3)
csrf.white.paths=/illegal

 

 

三、在Spring Boot中注册过滤器

复制代码
import javax.servlet.Filter;

import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import com.szpl.csgx.security.CsrfFilter;
import com.szpl.csgx.security.XssFilter;

/**
 * 使用配置方式开发Filter,否则其中的自动注入无效
 *
 */
/**
 * Registers the application's servlet filters through {@link FilterRegistrationBean}s.
 *
 * <p>Filters are registered here as configuration beans. Otherwise, fields injected into them
 * (such as {@code @Value}) would not be populated.
 */
@Configuration
public class HttpFilterConfig {

    /**
     * Registers {@link XssFilter} under the name "xssFilter" for every URL ("/*"), with order 4.
     *
     * @return the XSS filter registration
     */
    @Bean
    public FilterRegistrationBean<Filter> xssFilter() {
        FilterRegistrationBean<Filter> xssRegistration = new FilterRegistrationBean<>(new XssFilter());
        xssRegistration.setOrder(4);
        xssRegistration.addUrlPatterns("/*");
        xssRegistration.setName("xssFilter");
        return xssRegistration;
    }

    /**
     * Registers the Spring-managed {@link CsrfFilter} under the name "csrfFilter" for every URL
     * ("/*"), with order 0.
     *
     * @param csrfFilter the CSRF filter bean injected by Spring
     * @return the CSRF filter registration
     */
    @Bean
    public FilterRegistrationBean<Filter> csrfFilterRegistrationBean(CsrfFilter csrfFilter) {
        FilterRegistrationBean<Filter> csrfRegistration = new FilterRegistrationBean<Filter>();
        // Use the injected bean rather than `new CsrfFilter()`: an instance created with `new`
        // is not managed by Spring, so its @Value fields would never be injected.
        csrfRegistration.setFilter(csrfFilter);
        csrfRegistration.setOrder(0);
        csrfRegistration.setName("csrfFilter");
        csrfRegistration.addUrlPatterns("/*");
        return csrfRegistration;
    }
}
复制代码

posted on 2021-07-27 11:39  Hi,王松柏  阅读(511)  评论(0)  编辑  收藏  举报

导航