Spring Boot 集成 WebSocket

1. 首先添加 Maven 依赖（版本由 Spring Boot 父工程统一管理）

<dependency>
	<groupId>org.springframework.boot</groupId>
	<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>

 

2. 添加握手拦截器（在建立连接前校验 token）

import cn.hutool.core.util.StrUtil;
import cn.hutool.extra.spring.SpringUtil;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;

import javax.servlet.http.HttpSession;
import java.util.Map;

/**
 * Handshake interceptor that admits a WebSocket upgrade only when the request carries a non-blank
 * {@code token} request parameter.
 *
 * <p>The WebSocket configuration creates this class with {@code new} rather than through Spring,
 * so collaborators cannot be injected. They are looked up through {@link SpringUtil} instead.
 */
public class CustomWebSocketInterceptor implements HandshakeInterceptor {

    /** Handshake attribute under which the accepted token is copied into the WebSocket session. */
    public static final String TOKEN_ATTRIBUTE = "token";

    /**
     * Cached token settings (cannot be injected; resolved dynamically).
     *
     * <p>The bean is fetched on first use rather than in a static initializer. A static initializer
     * that fails, e.g. because it runs before the application context is ready, leaves the class
     * permanently unusable (every later use throws {@code NoClassDefFoundError}).
     */
    private static volatile TokenProperties tokenProperties;

    /**
     * Returns the {@link TokenProperties} bean, looking it up from the Spring context on first call.
     * Intended for the token-decryption step in {@link #beforeHandshake}.
     */
    private static TokenProperties tokenProperties() {
        TokenProperties props = tokenProperties;
        if (props == null) {
            props = SpringUtil.getBean(TokenProperties.class);
            tokenProperties = props;
        }
        return props;
    }

    /**
     * Decides whether the WebSocket handshake may proceed.
     *
     * <p>The handshake is rejected unless the request is a servlet request with a non-blank
     * {@code token} parameter. On success the token is put into {@code attributes} under
     * {@link #TOKEN_ATTRIBUTE}, so the handler can identify the client from the session attributes.
     *
     * @return {@code true} to continue the handshake, {@code false} to abort it
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        // Only servlet requests expose request parameters here; reject anything else instead of
        // failing with a ClassCastException.
        if (!(request instanceof ServletServerHttpRequest)) {
            return false;
        }
        // No HttpSession is needed: calling getSession() would create one as a side effect.
        String token = ((ServletServerHttpRequest) request).getServletRequest().getParameter("token");
        if (StrUtil.isBlank(token)) {
            return false;
        }

        // TODO: decrypt/verify the token (e.g. with tokenProperties()) and return false when it is
        // invalid. Until that is implemented, ANY non-blank token is accepted.

        attributes.put(TOKEN_ATTRIBUTE, token);
        return true;
    }

    /** Nothing to do after the handshake. */
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception e) {

    }
}

 

3. 会话管理工具类

import cn.hutool.core.lang.Console;
import cn.hutool.json.JSONUtil;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/**
 * Static registry of WebSocket sessions keyed by an arbitrary identifier object.
 *
 * <p>Identifiers are normalised with {@link #getKey(Object)} (their JSON string form), so two
 * identifiers with the same JSON representation address the same entry.
 */
public class TelSocketSessionUtil {
    // Collections.synchronizedMap makes single operations thread-safe and, unlike ConcurrentHashMap,
    // permits null keys and values.
    private static final Map<String, WebSocketSession> clients = Collections.synchronizedMap(new HashMap<>());

    /**
     * Registers a connection, replacing any session previously stored under the same key.
     *
     * @param o       identifier of the client
     * @param session the session to store
     */
    public static void add(Object o, WebSocketSession session) {
        clients.put(getKey(o), session);
    }

    /**
     * Returns the session registered for {@code o}, or {@code null} if none is registered.
     * The session may already be closed.
     *
     * @param o identifier of the client
     * @return the registered session, or {@code null}
     */
    public static WebSocketSession get(Object o) {
        return clients.get(getKey(o));
    }

    /**
     * Unregisters the connection for {@code o}; does nothing if none is registered.
     *
     * @param o identifier of the client
     * @throws IOException declared for source compatibility; not thrown by this implementation
     */
    public static void remove(Object o) throws IOException {
        clients.remove(getKey(o));
    }

    /**
     * Builds the registry key for an identifier: its JSON string produced by Hutool's
     * {@code JSONUtil.toJsonStr}.
     *
     * @param o identifier of the client
     * @return the registry key
     */
    public static String getKey(Object o) {
        return JSONUtil.toJsonStr(o);
    }

    /**
     * Checks whether a usable connection exists for {@code o}.
     *
     * <p>Returns {@code true} only if a session is registered and still open. A registered session
     * that has been closed is removed from the registry as a side effect.
     *
     * @param o identifier of the client
     * @return {@code true} if an open session is registered
     */
    public static boolean hasConnection(Object o) {
        return openSession(getKey(o)) != null;
    }

    /**
     * Returns the number of registered sessions, including any closed ones not yet evicted.
     *
     * @return registry size
     */
    public static int getSize() {
        return clients.size();
    }

    /**
     * Sends a text message to the client identified by {@code key}.
     *
     * <p>If sending fails with an {@link IOException}, the failure is logged and the session is
     * unregistered. The exception is not rethrown.
     *
     * @param key     identifier of the client
     * @param message text payload
     * @throws NullPointerException if no open connection is registered for {@code key}
     * @throws Exception            declared for source compatibility
     */
    public static void sendMessage(Object key, String message) throws Exception {
        String clientKey = getKey(key);
        // Single lookup: a separate hasConnection() + get() pair could observe a concurrent remove()
        // in between and then dereference null.
        WebSocketSession session = openSession(clientKey);
        if (session == null) {
            throw new NullPointerException(clientKey + " connection does not exist");
        }
        try {
            // A WebSocketSession does not support concurrent sends; serialise writes per session.
            synchronized (session) {
                session.sendMessage(new TextMessage(message));
            }
        } catch (IOException e) {
            // The Throwable overload prints the stack trace; passing e as a template argument did not.
            Console.log(e, "WebSocket sendMessage exception: {}", clientKey);
            // Conditional remove: do not evict a newer session registered under the same key meanwhile.
            clients.remove(clientKey, session);
        }
    }

    /**
     * Returns the session registered under {@code key} if it is still open.
     * A registered but closed session is evicted and {@code null} is returned.
     */
    private static WebSocketSession openSession(String key) {
        WebSocketSession session = clients.get(key);
        if (session == null) {
            return null;
        }
        if (session.isOpen()) {
            return session;
        }
        // Remove only this stale session, not a replacement added concurrently under the same key.
        clients.remove(key, session);
        return null;
    }
}

 

4. 实现消息处理器

import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;

import java.io.IOException;
import java.util.concurrent.CopyOnWriteArraySet;

/**
 * Text WebSocket handler that tracks open sessions and echoes every received message, prefixed with
 * {@code "Server received: "}, to all connected clients (including the sender).
 */
@Component
public class CustomWebSocketHandler extends TextWebSocketHandler {

    /** Sessions currently connected to this endpoint; shared by all handler instances. */
    private static final CopyOnWriteArraySet<WebSocketSession> sessions = new CopyOnWriteArraySet<>();

    /**
     * Tracks the new session and registers it with {@link TelSocketSessionUtil}.
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        sessions.add(session);
        // NOTE(review): every connection is registered under the same key "", so each new connection
        // replaces the previous one in TelSocketSessionUtil. A per-client key (e.g. derived from the
        // handshake token) is presumably intended — confirm.
        TelSocketSessionUtil.add("", session);
        System.out.println("New connection established: " + session.getId());
    }

    /**
     * Logs the incoming payload and broadcasts it back to every open session.
     */
    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        String payload = message.getPayload();
        System.out.println("Received message: " + payload);
        // Broadcast the received message to all connected clients
        sendMessageToAll("Server received: " + payload);
    }

    /**
     * Stops tracking the closed session and unregisters it from {@link TelSocketSessionUtil}.
     * Without this, closed sessions would stay in the registry.
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        sessions.remove(session);
        // Only unregister if the entry still points at this session: a newer connection may already
        // have replaced it under the same key.
        if (TelSocketSessionUtil.get("") == session) {
            TelSocketSessionUtil.remove("");
        }
        System.out.println("Connection closed: " + session.getId());
    }

    /** Logs transport-level errors. */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        System.err.println("Transport error: " + exception.getMessage());
    }

    /**
     * Sends {@code message} to every open session.
     *
     * <p>A failure for one recipient is logged and skipped. It does not abort the broadcast, and it
     * does not propagate into the handler of the client whose message triggered it.
     */
    private void sendMessageToAll(String message) {
        TextMessage textMessage = new TextMessage(message);
        for (WebSocketSession target : sessions) {
            if (!target.isOpen()) {
                continue;
            }
            try {
                // Messages from different clients are handled on different threads and may broadcast
                // to the same session concurrently; WebSocketSession does not support concurrent sends.
                synchronized (target) {
                    target.sendMessage(textMessage);
                }
            } catch (IOException e) {
                System.err.println("Failed to send message to " + target.getId() + ": " + e.getMessage());
            }
        }
    }
}

 

5. 注册处理器并启用 WebSocket

import cn.boxitec.websocket.CustomWebSocketHandler;
import cn.boxitec.websocket.CustomWebSocketInterceptor;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;

/**
 * Enables Spring's WebSocket support and exposes {@link CustomWebSocketHandler} on the {@code ws}
 * path. Every handshake first passes through {@link CustomWebSocketInterceptor}.
 */
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    /** Handler that serves every connection accepted on the endpoint. */
    private final CustomWebSocketHandler customWebSocketHandler;

    /**
     * @param customWebSocketHandler handler bean supplied by Spring through constructor injection
     */
    public WebSocketConfig(CustomWebSocketHandler customWebSocketHandler) {
        this.customWebSocketHandler = customWebSocketHandler;
    }

    /**
     * Maps the handler to {@code ws} and installs the handshake interceptor.
     */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        // The interceptor is constructed directly (not a Spring bean), which is why it resolves its
        // own dependencies through SpringUtil.
        CustomWebSocketInterceptor handshakeGuard = new CustomWebSocketInterceptor();
        // NOTE(review): "*" accepts handshakes from any Origin; consider restricting it to the
        // front-end's origin(s) in production.
        String[] permittedOrigins = {"*"};

        registry.addHandler(customWebSocketHandler, "ws")
                .addInterceptors(handshakeGuard)
                .setAllowedOrigins(permittedOrigins);
    }
}

  

posted @ 2024-06-18 09:53  那知归不归  阅读(12)  评论(0)  编辑  收藏  举报