Spring Boot 集成 WebSocket
1. 首先添加 Maven 依赖
<!-- Spring Boot WebSocket starter. No <version> is given, so this assumes the version is managed by the Spring Boot parent POM or BOM. -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
2. 添加握手拦截器（HandshakeInterceptor），在建立连接前校验请求中的 token 参数
import cn.hutool.core.util.StrUtil; import cn.hutool.extra.spring.SpringUtil; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.server.HandshakeInterceptor; import javax.servlet.http.HttpSession; import java.util.Map; public class CustomWebSocketInterceptor implements HandshakeInterceptor { private static TokenProperties tokenProperties; static { //不能注入 动态设置 tokenProperties = SpringUtil.getBean(TokenProperties.class); } @Override public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception { ServletServerHttpRequest serverHttpRequest = (ServletServerHttpRequest) request; HttpSession session = serverHttpRequest.getServletRequest().getSession(); String token = serverHttpRequest.getServletRequest().getParameter("token"); if (StrUtil.isBlank(token)) { return false; } //解密token return true; } @Override public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception e) { } }
3. 会话管理工具类（按 key 保存 WebSocketSession，并向指定连接发送消息）
import cn.hutool.core.lang.Console; import cn.hutool.json.JSONUtil; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketSession; import java.io.IOException; import java.util.Collections; import java.util.HashMap; import java.util.Map; public class TelSocketSessionUtil { private static Map<String, WebSocketSession> clients = Collections.synchronizedMap(new HashMap<>());//new ConcurrentHashMap<>(); /** * 保存一个连接 * * @param session */ public static void add(Object o, WebSocketSession session) { clients.put(getKey(o), session); } /** * 获取一个连接 * * @return */ public static WebSocketSession get(Object o) { return clients.get(getKey(o)); } /** * 移除一个连接 */ public static void remove(Object o) throws IOException { clients.remove(getKey(o)); } /** * 组装sessionId * * @return */ public static String getKey(Object o) { return JSONUtil.toJsonStr(o); //return JsonUtils.serialize(o); } /** * 判断是否有效连接 * 判断是否存在 * 判断连接是否开启 * 无效的进行清除 * * @return */ public static boolean hasConnection(Object o) { String key = getKey(o); if (clients.containsKey(key)) { return true; } return false; } /** * 获取连接数的数量 * * @return */ public static int getSize() { return clients.size(); } /** * 发送消息到客户端 * * @throws Exception */ public static void sendMessage(Object key, String message) throws Exception { if (!hasConnection(key)) { throw new NullPointerException(getKey(key) + " connection does not exist"); } WebSocketSession session = get(key); try { session.sendMessage(new TextMessage(message)); } catch (IOException e) { Console.log("WebSocket sendMessage exception: {}", getKey(key)); Console.log(e.getMessage(), e); clients.remove(getKey(key)); } } }
4. 实现消息处理器（继承 TextWebSocketHandler）
import org.springframework.stereotype.Component; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.handler.TextWebSocketHandler; import java.io.IOException; import java.util.concurrent.CopyOnWriteArraySet; @Component public class CustomWebSocketHandler extends TextWebSocketHandler { private static final CopyOnWriteArraySet<WebSocketSession> sessions = new CopyOnWriteArraySet<>(); @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { sessions.add(session); TelSocketSessionUtil.add("", session); // TelSocketSessionUtil.sendMessage("", "我给你发消息了"); System.out.println("New connection established: " + session.getId()); } @Override protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception { String payload = message.getPayload(); System.out.println("Received message: " + payload); // Broadcast the received message to all connected clients for (WebSocketSession webSocketSession : sessions) { if (webSocketSession.isOpen()) { webSocketSession.sendMessage(new TextMessage("Server received: " + payload)); } } } @Override public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception { sessions.remove(session); System.out.println("Connection closed: " + session.getId()); } @Override public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception { System.err.println("Transport error: " + exception.getMessage()); } private void sendMessageToAll(String message) throws IOException { for (WebSocketSession session : sessions) { if (session.isOpen()) { session.sendMessage(new TextMessage(message)); } } } }
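5. 注册处理器并启用 WebSocket
客户端连接地址示例：ws://{host}:{port}/ws?token={token}（若配置了 server.servlet.context-path，需要加上该前缀）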
import cn.boxitec.websocket.CustomWebSocketHandler; import cn.boxitec.websocket.CustomWebSocketInterceptor; import org.springframework.context.annotation.Configuration; import org.springframework.web.socket.config.annotation.EnableWebSocket; import org.springframework.web.socket.config.annotation.WebSocketConfigurer; import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry; @Configuration @EnableWebSocket public class WebSocketConfig implements WebSocketConfigurer { private final CustomWebSocketHandler customWebSocketHandler; public WebSocketConfig(CustomWebSocketHandler customWebSocketHandler) { this.customWebSocketHandler = customWebSocketHandler; } @Override public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { registry .addHandler(customWebSocketHandler, "ws") .addInterceptors(new CustomWebSocketInterceptor()) .setAllowedOrigins("*"); } }