记一次文件上传进度的开发经验
一、概览
1,浏览器发送文件给nginx,nginx针对特定接口(比如文件上传接口)关闭缓存,收到数据立即反向代理给下游,方便下游监听进度;
2,springboot服务器DispatcherServlet处理这个请求,checkMultipart检查是否是文件上传请求,调用multipartResolver解析这个请求;
3,排除springboot默认指定的multipartResolver对象StandardServletMultipartResolver,注入定制的multipartResolver,并重写parseRequest方法;
4,parseRequest方法中加入自定义进度监听器,使用redis实时更新上传进度,为此处上传进度(网络传输、写入临时文件)分配比重;
5,DispatcherServlet分发请求到controller,controller将临时文件真实写入,并分配一定进度比重;
6,webSocket使用定时器每隔(比如0.2秒)向客户端(浏览器)推送进度数据;
二、具体实现
1,nginx配置
location 上传接口url { proxy_pass 网关地址/服务器地址; proxy_request_buffering off; #关闭代理请求缓存,收到请求数据直接反向代理给下游;默认是开启的,只有request sent结束了,才会整个反向代理给下游; client_max_body_size 200M; #文件大小限制 keepalive_timeout 300; #超时限制 }
2,DispatcherServlet解析请求
DispatcherServlet.doService —> doDispatch —> checkMultipart —>this.multipartResolver.resolveMultipart
3,指定DispatcherServlet的multipartResolver对象
3.1,排除springboot默认的multipartResolver
// 可以在springboot的启动类上排除这个bean
@SpringBootApplication(exclude = {MultipartAutoConfiguration.class})
3.2,指定注入定制的multipartResolver,重写相关方法,监听进度
// 可以在springboot的启动类里注入这个bean @Bean(name = "multipartResolver") public MultipartResolver multipartResolver() { return new CustomMultipartResolver(); }
public class CustomMultipartResolver extends CommonsMultipartResolver { /** * 可以注入其他的业务接口 */ @Autowired private 业务接口; @Override protected MultipartParsingResult parseRequest(HttpServletRequest request) throws MultipartException { String encoding = determineEncoding(request); FileUpload fileUpload = prepareFileUpload(encoding); // 针对某些接口进行监听,判断的时候可以不用contains,直接用equals String url = request.getServletPath(); if (StringUtils.isNotBlank(url) && url.contains(文件上传接口url)) { // 文件上传进度监听器 FileUploadProgressListener listener = new FileUploadProgressListener(); // 从request对象中获取请求体中的参数比较麻烦和困难,所以参数放在header中;也许可以参考springboot的StandardServletMultipartResolver是如何获取请求体参数的 listener.setUserId(request.getHeader("userId")); listener.setUuid(request.getHeader("uuid")); String fileName = request.getHeader("fileName"); if (StringUtils.isNotBlank(fileName)) { try { // header参数不支持中文,使用url编解码 fileName = URLDecoder.decode(fileName, "UTF-8"); } catch (UnsupportedEncodingException e) { throw new BusinessException(ErrorCode.Failure.getCode(), "上传失败,文件名解析失败!"); } fileName = fileName.substring(0, fileName.lastIndexOf(".")); } listener.setFileName(fileName); if (StringUtils.isNotBlank(listener.getUserId()) && StringUtils.isNotBlank(listener.getUuid()) && StringUtils.isNotBlank(listener.getFileName())) { listener.set业务接口(业务接口); fileUpload.setProgressListener(listener); } } try { List<FileItem> fileItems = ((ServletFileUpload) fileUpload).parseRequest(request); return parseFileItems(fileItems, encoding); } catch (FileUploadException ex) { throw new MultipartException("Failed to parse multipart servlet request", ex); } } }
import org.apache.commons.fileupload.ProgressListener;
/** * 文件上传进度监听 * 参数都是业务需要用到的参数,你可能根据业务确定相关参数
*/ @Slf4j public class FileUploadProgressListener implements ProgressListener { /** * 用户id */ private String userId; /** * 文件唯一标识 */ private String uuid; /** * 文件名 */ private String fileName; private 业务接口; /** * 上次更新进度时,当时读取的字节数 */ private long lastBytesRead; /** * 标识是否用户执行了“停止上传” */ private boolean isStopUpload = false; /** * 监听读取文件的进度 * * @param pBytesRead 已读取的字节数 * @param pContentLength 文件的总字节数 * @param pItems The number of the field, which is currently being read. * (0 = no item so far, 1 = first item is being read, ...) */ @Override public void update(long pBytesRead, long pContentLength, int pItems) { // 每读取1%的字节数,记录一次进度 if ((pBytesRead - lastBytesRead) < (pContentLength / 100)) { return; } // 如果用户执行了“停止上传”,则不再更新进度 if (!isStopUpload && 业务接口.isStopUpload(uuid)) { isStopUpload = true; // 防止用户执行了“停止上传”,删除了对应的进度后,某个场景又添加了这个进度 进度对象 = new 进度对象(); 进度对象.setUserId(userId); 进度对象.setStoreFileName(uuid); 业务接口.deleteRedisProgress(进度对象); } if (isStopUpload) { return; } // 进度(单位%) long percent = pBytesRead * 分配的比重 / pContentLength; 进度对象 = new 进度对象(); 进度对象.setUploadProgress(percent); 进度对象.setFileName(fileName); 进度对象.setStoreFileName(uuid); 进度对象.setUserId(userId); 业务接口.updateRedisProgress(进度对象); // 记录此次更新进度时,已完成读取的字节数 lastBytesRead = pBytesRead; } public String getUserId() { return userId; } public void setUserId(String userId) { this.userId = userId; } public String getUuid() { return uuid; } public void setUuid(String uuid) { this.uuid = uuid; } public String getFileName() { return fileName; } public void setFileName(String fileName) { this.fileName = fileName; } public FileNodeService getFileNodeService() { return fileNodeService; } public void setFileNodeService(FileNodeService fileNodeService) { this.fileNodeService = fileNodeService; } }
4,controller将临时文件写入
/** * 将上传的文件写入本地文件中 * * @param in 上传的文件 * @param out 目标路径文件 * @param fileUploadRequest 档案管理,文件上传请求参数 * @throws IOException io异常 */ private void copy(InputStream in, OutputStream out, 业务对象, 业务接口) throws IOException { Assert.notNull(in, "No InputStream specified"); Assert.notNull(out, "No OutputStream specified"); // 是否用户执行了“停止上传” boolean isStopUpload = false; // 已写入的字节数 int bytesWrite = 0; // 记录上次已写入的字节数 int lastBytesWrite = 0; // 文件总大小 long totalSize = 业务对象.getFileBytes(); long frequency = totalSize / 5; // 每次读取和写入的字节数组 byte[] buffer = new byte[4KB(4096)或者随便]; // 每次读取的字节数量 int bytesRead; try { while ((bytesRead = in.read(buffer)) != -1) { // 写入数据 out.write(buffer, 0, bytesRead); bytesWrite += bytesRead; // 每写入20%的字节数,记录一次进度 if ((bytesWrite - lastBytesWrite) < frequency) { continue; } // 记录进度 业务对象.setUploadProgress(前面网络传输写临时文件分配的比重 + (long) bytesWrite * 此处写入真实文件分配的比重 / totalSize); 业务接口.updateRedisProgress(业务对象); lastBytesWrite = bytesWrite; // 如果用户执行了“停止上传”,则不再写入文件和更新进度 if (业务接口.isStopUpload(业务对象.get文件标识())) { isStopUpload = true; break; } } out.flush(); } catch (Exception e) { // 出现异常时,删除已写好的文件 } finally { IOUtils.closeQuietly(in); IOUtils.closeQuietly(out); } // 用户执行了“停止上传”,删除文件和进度记录 if (isStopUpload) { 业务接口.deleteRedisProgress(业务对象); 删除已经写好的文件; } }
5,websocket定频推送进度信息
import com.corundumstudio.socketio.AckRequest; import com.corundumstudio.socketio.HandshakeData; import com.corundumstudio.socketio.SocketIOClient; import com.corundumstudio.socketio.SocketIOServer; import com.corundumstudio.socketio.annotation.OnConnect; import com.corundumstudio.socketio.annotation.OnDisconnect; import com.corundumstudio.socketio.annotation.OnEvent; import 业务对象; import 业务接口; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang.StringUtils; import org.springframework.stereotype.Component; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; /** * websocket - 事件处理 * 1,同时支持的客户端连接数是有限的 * 2,使用线程池处理单个客户端连接的定频推送任务*/ @Slf4j @Component @AllArgsConstructor public class WebSocketMessageEventHandler { /** * websocket服务端 */ private final SocketIOServer server; private final 业务接口; /** * websocket定频推送上传进度的线程池 */ private final ScheduledExecutorService sendProcessPool; /** * 客户端连接事件 * * @param client 客户端 */ @OnConnect public void onConnect(SocketIOClient client) { log.info("[websocket]onConnect method start"); if (null == client) { log.error("[websocket]onConnect failed, client is null"); return; } // 加入房间号为:userId;你也可以传递和使用其他参数 HandshakeData handshakeData = client.getHandshakeData(); String userId = handshakeData.getSingleUrlParam("userId"); if (StringUtils.isBlank(userId)) { log.error("[websocket]onConnect failed, param-userId is null"); return; } client.joinRoom(userId); log.info("[websocket]onConnect method end"); } /** * 客户端断开连接事件 * * @param client 客户端 */ @OnDisconnect public void onDisconnect(SocketIOClient client) { log.info("[websocket]on'Dis'connect method start"); if (null == client) { log.error("[websocket]on'Dis'connect failed, client is null"); return; } // 离开房间号为:userId;你也可以传递和使用其他参数 HandshakeData handshakeData = client.getHandshakeData(); String userId = handshakeData.getSingleUrlParam("userId"); if (StringUtils.isBlank(userId)) { log.error("[websocket]on'Dis'connect failed, userId is null"); return; } client.leaveRoom(userId); log.info("[websocket]on'Dis'connect method end"); } /** * 请求文件上传进度 * * @param client 客户端 * @param ackRequest websocket ack * @param 业务对象 */ @OnEvent(value = "getFileUploadProgress") public void onEvent(SocketIOClient client, AckRequest ackRequest, 业务对象) throws InterruptedException { log.info("[websocket]OnEvent-getFileUploadProgress start"); if (null == client || null == 业务对象|| StringUtils.isBlank(业务对象.getUserId())) { log.error("[websocket]OnEvent-getFileUploadProgress failed, client is null or userId is null"); return; } if (StringUtils.isNotBlank(业务对象.getMsg())) { // 如果客户端发来的是心跳信息,则原值返回 client.sendEvent("fileUploadProgressResult", 业务对象.getMsg()); } else { Runnable task = new SendProgressTimerTask(server, client, 业务接口, 业务对象, log); // 等上一个消息推送结束后,再过0.2秒,继续推送消息 sendProcessPool.scheduleWithFixedDelay(task, 0, 200, TimeUnit.MILLISECONDS); } } }
public class SendProgressTimerTask implements Runnable { /** * websocket server */ private final SocketIOServer server; /** * websocket client */ private final SocketIOClient client; private final 业务接口; private final 业务对象; /** * 主线程的日志记录器 */ private final Logger log; /** * 标记是否是向客户端发送的第一条消息 */ private boolean isFirstCycle = true; /** * 上次推送的消息 */ private List<业务对象> previousProgressList; public SendProgressTimerTask(SocketIOServer server, SocketIOClient client, 业务接口, 业务对象, Logger log) { this.server = server; this.client = client; this.业务接口 = 业务接口; this.业务对象 = 业务对象; this.log = log; } @Override public void run() { // 如果客户端离开房间了(断连),则结束任务 if (null == server.getClient(client.getSessionId())) { log.info("[websocket]OnEvent-getFileUploadProgress end, client left the room");
// 此处抛出异常,可以结束线程池中的定时任务,下同 throw new RuntimeException(); } // 获取用户的文件上传进度列表 List<业务对象> progressList = 业务接口.listFileUploadProgress(业务对象); // 如果连续两次返回的消息都是空,则不再继续推送消息 if (!isFirstCycle && CollectionUtils.isEmpty(progressList) && CollectionUtils.isEmpty(previousProgressList)) { log.info("[websocket]OnEvent-getFileUploadProgress end, empty result twice"); throw new RuntimeException(); } // 标记:第一条消息已经发送出去了 isFirstCycle = false; // 保存本次消息 previousProgressList = progressList; // 向客户端发送用户的文件上传进度 if (CollectionUtils.isNotEmpty(progressList)) { Collections.sort(progressList); } client.sendEvent("fileUploadProgressResult", progressList); } }
@Bean public ScheduledExecutorService sendProcessPool() {
// 可以一直扩充线程 return new ScheduledThreadPoolExecutor(10, new BasicThreadFactory.Builder().namingPattern("sendFileUploadProgress-%d").daemon(true).build()); }
// 下面是从网上copy的
@Configuration @AllArgsConstructor public class NettySocketConfig { private final SocketProperties socketProperties; @Bean public SocketIOServer socketIOServer() { com.corundumstudio.socketio.Configuration config = new com.corundumstudio.socketio.Configuration(); // 设置主机名,默认是0.0.0.0 config.setHostname(socketProperties.getServerIp()); // 设置监听端口 config.setPort(socketProperties.getSocketPort()); // 协议升级超时时间(毫秒),默认10000。HTTP握手升级为ws协议超时时间 config.setUpgradeTimeout(10000); // Ping消息间隔(毫秒),默认25000。客户端向服务器发送一条心跳消息间隔 config.setPingInterval(socketProperties.getPingInterval()); // Ping消息超时时间(毫秒),默认60000,这个时间间隔内没有接收到心跳消息就会发送超时事件 config.setPingTimeout(socketProperties.getPingTimeout()); // 握手协议参数使用JWT的Token认证方案 config.setAuthorizationListener(data -> { return true; }); return new SocketIOServer(config); } @Bean public SpringAnnotationScanner springAnnotationScanner(SocketIOServer socketServer) { return new SpringAnnotationScanner(socketServer); } }
@Data @Component @ConfigurationProperties(prefix = "socketio") public class SocketProperties { /** * socket端口 */ private String serverIp; /** * socket端口 */ private Integer socketPort; /** * Ping消息间隔(毫秒) */ private Integer pingInterval; /** * Ping消息超时时间(毫秒) */ private Integer pingTimeout; }
socketio:
serverIp: #设置socket所在ip地址,填入localhost可能会失败,需要对应处理;填入死的ip则没问题
socketPort: #socket端口
pingInterval: 25000 #Ping消息间隔(毫秒)
pingTimeout: 60000 #Ping消息超时时间(毫秒)
// 下面是从网上copy的
@Slf4j @Component @Order(1) public class ServerRunner implements CommandLineRunner { @Autowired private SocketIOServer server; @Override public void run(String... args) { log.info("SocketIOServerRunner start..."); try { server.start(); } catch (Exception e) { log.error("socket.io start failed!", e); String hostName = server.getConfiguration().getHostname(); log.error("hostName=" + hostName, e); } } }
#nginx配置这个,可以解决websocket跨域的问题
location /socket.io { proxy_pass http://socketio的ip:socketio的端口; }
三、问题补充
1,websocket版本问题
java端用nettysocketio,前端用socket.io-client;java端的版本一般都是2,前端的版本容易超出2导致问题,所以服务端和客户端版本要一致;
2,注意浏览器(前端js)、nginx、网关超时配置;
3,注意websocket跨域问题,可通过nginx反向代理解决;
4,注意nginx默认开启请求缓存,会等请求体整个传输到nginx,nginx才会反向代理到下游,导致下游监听的进度几乎是瞬间完成的;可以配置nginx缓存的大小,或者直接关闭缓存;
5,注意nginx配置的请求体的大小限制;
6,while(true){Thread.sleep()},如果不放心的话,可以使用ScheduledExecutorService代替;Timer过时了;
定时任务不要让他一直执行,如客户端离开,任务就不需要执行了;定时任务注意try catch,遇到异常,任务会停止,下面的任务也不会执行了;
7,使用redis hash(map)保存该用户的进度记录,就像购物车,上传一个文件就放入一个记录,查询、修改、删除都很方便,我这边userId是购物车的名字,文件的标识是商品名字;但是只能对整个map设置超时,无法对map中的某个记录单独设置超时;可以通过一些其他方法解决,如存时间字段,判断当前时间减去存入时间是否超时,删除记录;或者map保存其他redis对象,其他redis对象设置超时时间等;map是没有顺序的,可以查出后根据某个字段排序;
8,一个50MB的文件,网络传输可能花个15秒,但是本地写入时可能只要0.2秒,所以也许只要监听网络传输的部分就行了;使用阿里巴巴的oss,有进度监听回调接口;
9,URLEncode/URLDecode注意加号空格问题:https://www.cnblogs.com/seeall/p/16561328.html;
10,nginx、网关的异常,或者其他异常,后台服务器可能无法获取,可以设定上传进度30秒无变化,则改状态为失败;
11,multipartresolver中获取请求体参数,暂定认为是有困难的;不管是springboot默认的multipartresolver获取文件名的方式,还是本例获取文件名的方式,都存在进度和文件名无法兼得的问题;进度需要在文件解析的过程中获得,文件名要在文件解析结束后获得;