Java 实现 WebSocket 集群转发:使用 Redis 发布订阅

视频说明:https://www.bilibili.com/video/BV1Yh4y1F7SV?p=3

场景

后端服务被部署到多个节点上,通过弹性负载均衡对外提供服务。

客户端(浏览器) 客户端1 连接到了服务端 A 的 WebSocket 节点。
客户端通过弹性负载均衡,把请求分配到了服务端 B,比如计算服务会输出一些过程信息,服务端 B 上没有 客户端1 的 WS 连接。

需求

服务端 B 把消息转发到服务端 A 上,找到 客户端1 的连接,发送出去。

画示意图

代码

代码:https://github.com/ioufev/websocket-cluster-forward

备份:蓝奏云

Redis 发布类

import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;

@Component
public class RedisPublisher {

    @Resource
    private RedisTemplate<String, byte[]> redisTemplate;

    public void publishMessage(String channel, byte[] message) {
        redisTemplate.convertAndSend(channel, message);
    }

}

Redis 订阅类

import com.ioufev.wsforward.consts.RedisConst;
import com.ioufev.wsforward.ws.WebSocketServer;
import org.springframework.context.annotation.Bean;
import org.springframework.data.redis.connection.Message;
import org.springframework.data.redis.connection.MessageListener;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.listener.ChannelTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;
import java.nio.charset.StandardCharsets;
import java.util.Base64;

@Component
public class RedisMessageListener implements MessageListener {

    @Resource
    private WebSocketServer webSocket;

    public RedisMessageListener(WebSocketServer webSocket) {
        this.webSocket = webSocket;
    }

    @Override
    public void onMessage(Message message, byte[] pattern) {

        // 获取频道名称
        String channel = new String(message.getChannel());

        // 判断是否为需要转发的频道
        if(channel.equals(RedisConst.PUB_SUB_TOPIC)){

            // 获取频道内容
            byte[] body = message.getBody();
            String contentBase64WithQuotes = new String(body, StandardCharsets.UTF_8); // 带引号的Base64
            String contentBase64 = removeQuotes(contentBase64WithQuotes); // base64
            String content = new String(Base64.getDecoder().decode(contentBase64), StandardCharsets.UTF_8); // 原来的字符串

            String key = content.split("::")[0];
            String wsContent  = content.substring((key + "::").length());
            webSocket.sendOneMessageForRedisMessage(key, wsContent);
        }

    }

    @Bean
    public RedisMessageListenerContainer container(RedisConnectionFactory factory,
                                                   RedisMessageListener listener) {
        RedisMessageListenerContainer container = new RedisMessageListenerContainer();
        container.setConnectionFactory(factory);
        container.addMessageListener(listener, new ChannelTopic(RedisConst.PUB_SUB_TOPIC));
        return container;
    }

    /**
     * 移除存在Redis中的值开头和结尾的引号
     * @param input 输入
     * @return 输出
     */
    private String removeQuotes(String input) {
        if (input != null && input.length() >= 2 && input.startsWith("\"") && input.endsWith("\"")) {
            return input.substring(1, input.length() - 1);
        }
        return input;
    }

}

WebSocket 服务端控制类

import com.ioufev.wsforward.consts.RedisConst;
import com.ioufev.wsforward.redis.RedisPublisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;

import javax.websocket.OnClose;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.PathParam;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;


import org.springframework.stereotype.Component;

import javax.websocket.server.ServerEndpoint;

@Component
@ServerEndpoint("/websocket/{key}")
public class WebSocketServer {

    private static final Logger log = LoggerFactory.getLogger(WebSocketServer.class);

    private String sessionId;
    private Session session;

    private static RedisPublisher redisPublisher;

    @Autowired
    public void setApplicationContext(RedisPublisher redisPublisher) {
        WebSocketServer.redisPublisher= redisPublisher;
    }

    private static CopyOnWriteArraySet<WebSocketServer> webSockets = new CopyOnWriteArraySet<>();

    private static Map<String, Session> sessionPool = new ConcurrentHashMap<>();


    @OnOpen
    public void onOpen(Session session, @PathParam(value = "key") String key) {

        this.sessionId = key;
        this.session = session;
        webSockets.add(this);
        sessionPool.put(key, session);
        log.info(key + "【websocket消息】有新的连接,总数为:" + webSockets.size() + ", session count is :" + sessionPool.size());
        for(WebSocketServer webSocket : webSockets) {
            log.info("【webSocket】key is :" + webSocket.sessionId);
        }

    }

    @OnClose
    public void onClose() {
        sessionPool.remove(this.sessionId);
        webSockets.remove(this);
        log.info("【websocket消息】连接断开,总数为:" + webSockets.size());
    }

    @OnMessage
    public void onMessage(@PathParam(value = "key") String key, String message) {
        log.info("【websocket消息】收到消息message:" + message);
        sendOneMessage(key, message);
    }

    /**
     * 广播消息
     */
    public void sendAllMessage(String message) {
        for (WebSocketServer webSocket : webSockets) {
            log.info("【websocket消息】广播消息:" + message);
            try {
                webSocket.session.getAsyncRemote().sendText(message);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * 单点消息
     */
    public void sendOneMessage(String key, String message) {

//		Session session = sessionPool.get(key);
        Session session = getSession(key);
        if (session != null) {
            try {
                session.getBasicRemote().sendText(message);
            } catch (Exception e) {
                e.printStackTrace();
            }
        } else {
            redisPublisher.publishMessage(RedisConst.PUB_SUB_TOPIC, (key + "::" + message).getBytes(StandardCharsets.UTF_8));
        }
    }

    /**
     * 用来Redis订阅后使用
     */
    public void sendOneMessageForRedisMessage(String key, String message) {
        Session session = getSession(key);
        if (session != null) {
            try {
                session.getBasicRemote().sendText(message);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    private static Session getSession(String key){
        for (WebSocketServer webSocket : webSockets) {
            if(webSocket.sessionId.equals(key)){
                return webSocket.session;
            }
        }
        return null;
    }

}

参考文章

💧 WebSocket 集群解决方案
👉 图画的好,理解起来很清楚。

💧 WebSocket 集群解决方案,不用 MQ
👉 在上面的思路基础上,想给服务端添加一个标识,用来记录用户连接和服务端的关联关系,我也有类似的想法,不过关于用户ID和服务端ID关联关系的存储问题,还没处理好。

💧 Spring Cloud 一个配置注解实现 WebSocket 集群方案
👉 这个思路更大胆,既然是集群转发,没什么不能直接使用 WebSocket 本身

💧 分布式 WebSocket 集群解决方案
👉 用户连接和服务端的关联关系,用一致性哈希存储

💧 Spring Boot WebSocket 的 6 种集成方式
👉 喜欢文章的标题,内容看看目录就行了。

💧 构建通用 WebSocket 推送网关的设计与实践
👉 生产环境值得参考,但是用来入门参考显然没说清楚重点和难点

💧 石墨文档是如何通过 WebSocket 实现百万长连接的?
👉 生产环境值得参考,但是用来入门参考显然没说清楚重点和难点,这个比上面文章说更详细,显然具有可操作性。

总结

1、需要有一个统一的地方来保存用户连接和服务端的关联关系,可以是: Redis、MQ、Zookeeper、微服务的服务发现。

2、Redis 发布订阅用来集群转发非常简单,适用于实时发布消息那种,比如一个计算过程的实时步骤输出。

3、如果要确保消息不丢失,尽量送达之类的,那就用 MQ。

4、最佳方式:每个服务端有一个ID,每个用户连接也有一个ID,然后服务端转发的时候,找到需要的服务端,只转发一次就好了。

posted @ 2023-07-17 11:05  ioufev  阅读(1627)  评论(0编辑  收藏  举报