HttpServletRequest重复读取InputStream

背景

HttpServletRequest的InputStream流只能读取一次,在很多情况下都是需要在多个重复读取,比如日志打印时需要打印body,body需要预读取作其他处理,这个时候就需要HttpServletRequest能够缓存依稀输入流,达到重复读取的效果。

org.springframework.web.util.ContentCachingRequestWrapper

为什么不用这个类?

仔细看源码会发现,该类缓存的处理不是针对InputStream流的,不能达到getInputStream()重复读取的效果。第一次调用getInputStream()方法时会同步写到 ByteArrayOutputStream 中,如果需要用到ByteArrayOutputStream中的bytes,只有在getInputStream()拿到流并执行了read、readLine等操作读取完毕后,通过getContentAsByteArray方法拿到完整的body字节数据。

 

application/x-www-form-urlencoded

form表单请求的特殊之处(下面的分析都是基于org.eclipse.jetty.server.Request)

 public String getParameter(String name) {
     return (String)this.getParameters().getValue(name, 0);
 }
 
 public Map<String, String[]> getParameterMap() {
     return Collections.unmodifiableMap(this.getParameters().toStringArrayMap());
 }
 
 public Enumeration<String> getParameterNames() {
     return Collections.enumeration(this.getParameters().keySet());
 }
 
 public String[] getParameterValues(String name) {
     List<String> vals = this.getParameters().getValues(name);
     return vals == null ? null : (String[])vals.toArray(new String[vals.size()]);

该类中的这四个方法都调用了this.getParameters(),这个方法底层初始化ParameterMap时,如果是form表单请求也会同时去读取InputStream,拿到body里面的数据,这里会导致流失效,不能再次读取

 

遇到的场景

目前所负责的项目是Java/Clojure的混合项目,存在Springmvc和Ring两个web框架,请求由spring启动的jetty容器接管的,springmvc没有实现的请求需要通过一系列的适配逻辑转发到clojure的ring框架中,form请求在中转的时候就是通过getInputStream中转的。但是在springmvc中有前置逻辑调用了getParameter(“key”),所以导致代理到ring框架的请求找不到form参数。

实现

解决方案

getInputStream不能重复读取的问题

通过在第一次请求getInputStream时,去完全读取当前请求最初的InputStream,然后缓存到ByteArrayOutputStream中,再封装一个InputStream的自定义子类(读取ByteArrayOutputStream的bytes)返回。

解决form请求在jetty Request下不能getParameterMap() 和getInputStream()同时调用的问题

这里直接借鉴了ContentCachingRequestWrapper中的writeRequestParametersToCachedContent方法,在初始化的时候直接判断是否form请求,如果是直接先调用getParameters()去读取流,然后通过getParameterMap()根据form body的规则重写成bytes

遗留问题

目前的form请求的解决方案对file类参数不太友好,后续再优化这块解析重写逻辑

代码

 import lombok.SneakyThrows;
 import org.apache.commons.io.IOUtils;
 import org.springframework.core.Ordered;
 import org.springframework.core.annotation.Order;
 import org.springframework.http.HttpMethod;
 import org.springframework.web.filter.OncePerRequestFilter;
 
 import javax.servlet.FilterChain;
 import javax.servlet.ReadListener;
 import javax.servlet.ServletException;
 import javax.servlet.ServletInputStream;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletRequestWrapper;
 import javax.servlet.http.HttpServletResponse;
 import java.io.*;
 import java.net.URLEncoder;
 import java.nio.charset.StandardCharsets;
 import java.util.*;
 
 /**
  * @class-name: RepeatReadHttpServletRequestWrapFilter
  * @description:
  * @author: Mr.Zeng
  * @date: 2022-08-11 23:25
  */
 @Order(Ordered.HIGHEST_PRECEDENCE)
 public class RepeatReadHttpServletRequestWrapFilter extends OncePerRequestFilter {
 
     @Override
     protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
         RepeatReadHttpServletRequest repeatReadRequest = new RepeatReadHttpServletRequest(request);
         filterChain.doFilter(repeatReadRequest, response);
    }
 
     public static class RepeatReadServletInputStream extends ServletInputStream {
 
         private final ByteArrayInputStream copiedInputStream;
 
         public RepeatReadServletInputStream(ByteArrayOutputStream stream) {
             this.copiedInputStream = new ByteArrayInputStream(stream.toByteArray());
        }
 
         public RepeatReadServletInputStream(byte[] buf) {
             this.copiedInputStream = new ByteArrayInputStream(buf);
        }
 
         public boolean isFinished() {
             return false;
        }
 
         public boolean isReady() {
             return true;
        }
 
         public void setReadListener(ReadListener readListener) {
 
        }
 
         public int read() {
             return this.copiedInputStream.read();
        }
 
    }
 
     public static class RepeatReadHttpServletRequest extends HttpServletRequestWrapper {
 
         private static final String FORM_CONTENT_TYPE = "application/x-www-form-urlencoded";
 
         private final ByteArrayOutputStream cachedStream;
 
         public RepeatReadHttpServletRequest(HttpServletRequest request) {
             super(request);
             this.cachedStream = initialCachedStream(request);
             prepareFormRequestIfNeed();
        }
 
         public String getCharacterEncoding() {
             return Optional.ofNullable(super.getCharacterEncoding()).orElse(StandardCharsets.UTF_8.toString());
        }
 
         public ServletInputStream getInputStream() {
             if (isEmptyCachedStream()) {
                 this.cacheInputStream();
            }
             return new RepeatReadServletInputStream(this.cachedStream);
        }
 
         public BufferedReader getReader() {
             return new BufferedReader(new InputStreamReader(this.getInputStream()));
        }
 
         private ByteArrayOutputStream initialCachedStream(HttpServletRequest request) {
             int contentLength = request.getContentLength();
             return new ByteArrayOutputStream(contentLength >= 0 ? contentLength : 1024);
        }
 
         private void prepareFormRequestIfNeed() {
             if (isFormPost()) {
                 writeRequestParametersToCachedContent();
            }
        }
 
         private void writeRequestParametersToCachedContent() {
             try {
                 String requestEncoding = getCharacterEncoding();
                 Map<String, String[]> form = super.getParameterMap();
                 for (Iterator<String> nameIterator = form.keySet().iterator(); nameIterator.hasNext(); ) {
                     String name = nameIterator.next();
                     List<String> values = Arrays.asList(form.get(name));
                     for (Iterator<String> valueIterator = values.iterator(); valueIterator.hasNext(); ) {
                         String value = valueIterator.next();
                         this.cachedStream.write(URLEncoder.encode(name, requestEncoding).getBytes());
                         if (value != null) {
                             this.cachedStream.write('=');
                             this.cachedStream.write(URLEncoder.encode(value, requestEncoding).getBytes());
                             if (valueIterator.hasNext()) {
                                 this.cachedStream.write('&');
                            }
                        }
                    }
                     if (nameIterator.hasNext()) {
                         this.cachedStream.write('&');
                    }
                }
            } catch (IOException ex) {
                 throw new IllegalStateException("Failed to write request parameters to cached content", ex); 
          } 
      } 
​ 
       private boolean isFormPost() { 
           String contentType = getContentType(); 
           return contentType != null && contentType.contains(FORM_CONTENT_TYPE) && HttpMethod.POST.matches(getMethod()); 
      } 
​ 
       @SneakyThrows 
       private void cacheInputStream() { 
           if (isEmptyCachedStream()) { 
               synchronized (this) { 
                   if (isEmptyCachedStream()) { 
                       IOUtils.copy(super.getInputStream(), this.cachedStream); 
                  } 
              } 
          } 
      } 
​ 
       private boolean isEmptyCachedStream() { 
           return this.cachedStream.size() == 0; 
      } 
​ 
  } 
​ 
}
 
posted @ 2022-08-13 10:24  原则  阅读(497)  评论(0编辑  收藏  举报