redis发布/订阅解决分布式websocket推送问题

分布式websocket推送

场景

项目中用到websocket推送消息,后台是分布式部署的,需要通过websocket将预警消息推送给前台。直接添加websocket后出现了一个问题:假设有两台服务S1、S2,客户端C和后端服务建立连接时经负载均衡分配到了S1。如果S1收到了预警消息,可以直接推送给客户端C;但假如S2收到了预警消息也要推送给客户端C,由于S2并没有和客户端C建立连接,该消息就会丢失而无法推送给客户端。

解决方案

使用MQ解耦消息和websocket服务端:收到预警消息后不直接推送到客户端,而是先发送到MQ,然后各个websocket服务端通过监听/拉取MQ中的消息,判断目标客户端是否连接在本机,再进行推送。当然,消息体的格式需要按你的业务来设计。

实现

既然要使用MQ,该如何选型呢?其实市面上常见的MQ都够用,比如RocketMQ、ActiveMQ、RabbitMQ等,Kafka也可以(不过有点儿大材小用了)。由于业务上不希望引入新的组件,而项目中刚好用到了Redis,于是决定用Redis的发布/订阅功能解决。需要注意,Redis的发布/订阅不会持久化消息:发布时如果某个服务实例没有处于订阅状态(例如正在重启),它就收不到这条消息;如果需要可靠投递,可以考虑Redis Stream或专业的MQ。

代码

websocket

配置类
EndpointConfig

import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;

import javax.websocket.server.ServerEndpointConfig;

/**
 * WebSocket endpoint configurator that obtains endpoint instances from the Spring context
 * instead of letting the WebSocket container instantiate them, so endpoint classes can use
 * Spring-injected dependencies.
 *
 * <p>The context is kept in a static field because the container creates configurator
 * instances itself (it is referenced as {@code configurator = EndpointConfig.class}); the
 * field is filled when Spring initialises the instance registered as a bean.
 *
 * <p>Note: {@code getBean} returns the same object for a singleton-scoped bean (e.g. a plain
 * {@code @Component}), so every connection then shares one endpoint instance and must not keep
 * per-connection state in instance fields.
 */
public class EndpointConfig extends ServerEndpointConfig.Configurator implements ApplicationContextAware {

    /** Spring bean factory shared by all configurator instances. */
    private static volatile BeanFactory context;

    /**
     * Captures the Spring application context.
     *
     * @param applicationContext the context that owns this bean
     */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        context = applicationContext;
    }

    /**
     * Resolves the endpoint object from the Spring bean factory.
     *
     * @param clazz endpoint class requested by the WebSocket container
     * @return the Spring-managed instance of {@code clazz}
     */
    @Override
    public <T> T getEndpointInstance(Class<T> clazz) throws InstantiationException {
        BeanFactory factory = context;
        return factory.getBean(clazz);
    }
}
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;

`WebSocketConfig02` 配置类
// @EnableWebSocket is intentionally omitted; it is not needed for @ServerEndpoint support.
@Configuration
public class WebSocketConfig02 {

    /**
     * Registers the Spring-aware endpoint configurator as a bean so that Spring hands it the
     * application context it later uses to look up endpoint instances.
     *
     * @return the configurator bean
     */
    @Bean
    public EndpointConfig newConfig() {
        EndpointConfig configurator = new EndpointConfig();
        return configurator;
    }

    /**
     * Exports {@code @ServerEndpoint}-annotated beans to the servlet container's WebSocket runtime.
     *
     * @return the exporter bean
     */
    @Bean
    public ServerEndpointExporter serverEndpointConfig() {
        ServerEndpointExporter exporter = new ServerEndpointExporter();
        return exporter;
    }
}

websocket请求类

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.bart.websocket.configuration.EndpointConfig;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;

import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * WebSocket endpoint that pushes messages to connected clients, keyed by user id.
 *
 * <p>Two ways to receive the user id:
 * <ol>
 *   <li>{@code value = "/ws/{userId}"} with
 *       {@code onOpen(@PathParam("userId") String userId, Session session)}: the client must
 *       append the id after the slash, e.g. {@code ws://localhost:7889/productWebSocket/123},
 *       otherwise the handshake returns 404.</li>
 *   <li>{@code value = "/ws"} with {@code onOpen(Session session)}: read {@code ?userId=123}
 *       from {@code session.getRequestParameterMap()}.</li>
 * </ol>
 *
 * <p>Instances come from {@link EndpointConfig}, which returns this Spring {@code @Component}
 * (singleton by default) for every connection. All per-connection state therefore lives in the
 * static session map, never in instance fields.
 *
 * @author bart
 */
@Component
@ServerEndpoint(
        value = "/ws/{userId}",
        configurator = EndpointConfig.class,
        encoders = { ProductWebSocket.MessageEncoder.class } // custom message encoder
)
public class ProductWebSocket {

    static final Logger log = LoggerFactory.getLogger(ProductWebSocket.class);

    /** Number of currently registered connections. */
    private static final AtomicInteger onlineCount = new AtomicInteger(0);

    /** Open session -> user id; one user may own several sessions (e.g. several tabs). */
    private static final ConcurrentHashMap<Session, String> userIdSessionMap = new ConcurrentHashMap<>();

    /**
     * Called when a connection opens: registers the session and acknowledges the user.
     *
     * @param userId  user id taken from the request path
     * @param session the new WebSocket session
     * @throws IllegalArgumentException if the user id is missing
     */
    @OnOpen
    public void onOpen(@PathParam("userId") String userId, Session session) {
        if (userId == null) {
            log.error("websocket连接 缺少参数 id");
            throw new IllegalArgumentException("websocket连接 缺少参数 id");
        }
        log.info("websocket 新客户端连入,用户id:{}", userId);
        userIdSessionMap.put(session, userId);
        addOnlineCount();
        // Acknowledge the user that just connected.
        JSONObject jsonObject = new JSONObject();
        jsonObject.put("code", 200);
        jsonObject.put("message", "OK");
        send(userId, JSON.toJSONString(jsonObject));
    }

    /**
     * Called when a connection closes: unregisters the session.
     *
     * @param session the closed session
     */
    @OnClose
    public void onClose(Session session) {
        log.info("一个客户端关闭连接");
        // Only decrement for sessions that onOpen actually registered (onOpen may have thrown
        // before registering), so the counter cannot drift below the real number of sessions.
        if (userIdSessionMap.remove(session) != null) {
            subOnlineCount();
        }
    }

    /**
     * Called when a client sends a message; currently the message is only logged.
     *
     * @param message text received from the client
     * @param session the sending session
     */
    @OnMessage
    public void onMessage(String message, Session session) {
        log.info("用户发送过来的消息为:" + message);
    }

    /**
     * Called when the WebSocket runtime reports an error on a session.
     *
     * @param session the affected session
     * @param error   the reported error
     */
    @OnError
    public void onError(Session session, Throwable error) {
        // Log through SLF4J (with stack trace) instead of printStackTrace to stderr.
        log.error("websocket出现错误", error);
    }

    /**
     * Sends a message to every open session of one user on this instance.
     * User ids are compared case-insensitively.
     *
     * @param id      user id; a null id matches no session
     * @param message text to send
     */
    public void send(String id, String message) {
        log.info("#### 点对点消息,userId={}", id);
        if (userIdSessionMap.isEmpty()) {
            log.warn("当前无websocket连接");
            return;
        }
        if (id == null) {
            log.error("未找到当前id对应的session, id = {}", id);
            return;
        }
        List<Session> sessionList = new ArrayList<>();
        for (Map.Entry<Session, String> entry : userIdSessionMap.entrySet()) {
            if (id.equalsIgnoreCase(entry.getValue())) {
                sessionList.add(entry.getKey());
            }
        }
        if (sessionList.isEmpty()) {
            log.error("未找到当前id对应的session, id = {}", id);
            return;
        }
        for (Session session : sessionList) {
            if (!session.isOpen()) {
                // Closing is in progress; onClose will unregister it.
                continue;
            }
            try {
                sendText(session, message);
                log.info("推送用户【{}】消息成功,消息为:【{}】", id, message);
            } catch (Exception e) {
                log.error("推送用户【{}】消息失败,消息为:【{}】,原因是:【{}】", id, message, e.getMessage(), e);
            }
        }
    }

    /**
     * Sends a message to every open session on this instance.
     *
     * @param message text to send
     */
    public void broadcast(String message) {
        log.info("#### 广播消息");
        if (userIdSessionMap.isEmpty()) {
            log.warn("当前无websocket连接");
            return;
        }
        for (Session session : userIdSessionMap.keySet()) {
            if (!session.isOpen()) {
                continue;
            }
            try {
                sendText(session, message);
            } catch (Exception e) {
                log.error("websocket 发送【{}】消息出错:{}", session, e.getMessage(), e);
            }
        }
    }

    /**
     * Writes one text frame to a session.
     *
     * <p>RemoteEndpoint.Basic does not allow a new send while a previous one on the same session
     * is still in progress. Pushes arrive from different threads (the Redis listener, onOpen),
     * so sends are serialised per session.
     */
    private static void sendText(Session session, String message) throws java.io.IOException {
        synchronized (session) {
            session.getBasicRemote().sendText(message);
        }
    }

    /** @return the number of currently registered connections */
    public static int getOnlineCount() {
        return onlineCount.get();
    }

    /** Increments the connection counter (atomic, no extra locking needed). */
    public static void addOnlineCount() {
        onlineCount.incrementAndGet();
    }

    /** Decrements the connection counter (atomic, no extra locking needed). */
    public static void subOnlineCount() {
        onlineCount.decrementAndGet();
    }

    /**
     * Encodes a fastjson {@link JSONObject} as JSON text; {@code null} becomes an empty string.
     */
    public static class MessageEncoder implements Encoder.Text<JSONObject> {
        @Override
        public void init(javax.websocket.EndpointConfig endpointConfig) {
            // no initialisation needed
        }

        @Override
        public void destroy() {
            // no resources to release
        }

        @Override
        public String encode(JSONObject object) throws EncodeException {
            return object == null ? "" : object.toJSONString();
        }
    }
}

redis

常量类

/**
 * Redis key and channel names shared by publishers and subscribers.
 *
 * <p>Constants holder only; it cannot be instantiated or subclassed.
 */
public final class RedisKeyConstants {

    /** Pub/sub channel on which {@code TopicMsg} JSON payloads are published and consumed. */
    public static final String REDIS_TOPIC_MSG = "redis_topic_msg";

    private RedisKeyConstants() {
        throw new AssertionError("No RedisKeyConstants instances");
    }
}

配置类

import java.util.Arrays;

import com.bart.websocket.common.RedisKeyConstants;
import com.bart.websocket.configuration.redis.listener.RedisTopicListener;
import com.bart.websocket.service.WarnMsgService;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.listener.ChannelTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.data.redis.serializer.StringRedisSerializer;

/**
 * Redis pub/sub wiring: a listener container plus the listener subscribed to
 * {@link RedisKeyConstants#REDIS_TOPIC_MSG}.
 *
 * @author bart
 */
@Configuration
public class RedisConfig {

    /**
     * Spring's {@link RedisMessageListenerContainer}, bound to the application's connection factory.
     *
     * @param connectionFactory factory used for the subscription connection
     * @return the listener container
     */
    @Bean
    RedisMessageListenerContainer container(RedisConnectionFactory connectionFactory) {
        RedisMessageListenerContainer listenerContainer = new RedisMessageListenerContainer();
        listenerContainer.setConnectionFactory(connectionFactory);
        return listenerContainer;
    }

    /**
     * Builds the topic listener. Its constructor registers it with {@code container} for the
     * message channel.
     *
     * @param container           container the listener registers itself with
     * @param stringRedisTemplate template handed to the listener
     * @param warnMsgService      service the listener forwards messages to
     * @return the configured listener
     */
    @Bean
    RedisTopicListener redisTopicListener(
            RedisMessageListenerContainer container,
            StringRedisTemplate stringRedisTemplate,
            WarnMsgService warnMsgService) {
        ChannelTopic messageTopic = new ChannelTopic(RedisKeyConstants.REDIS_TOPIC_MSG);
        RedisTopicListener listener =
                new RedisTopicListener(container, Arrays.asList(messageTopic), warnMsgService);
        listener.setStringRedisTemplate(stringRedisTemplate);
        listener.setStringRedisSerializer(new StringRedisSerializer());
        return listener;
    }
}

redis消息体

import com.bart.websocket.entity.WarnMsg;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

/**
 * Envelope for messages sent through the Redis topic; it is serialized to and parsed from JSON
 * with fastjson.
 *
 * <p>A null or empty {@code userId} means the message is broadcast to every connected client;
 * otherwise it is delivered only to that user's sessions.
 */
@Data
@AllArgsConstructor
@NoArgsConstructor
public class TopicMsg {

    // Target user id; null/empty for a broadcast.
    private String userId;

    // The warning message to push.
    private WarnMsg msg;
}

监听器

import java.util.List;

import com.alibaba.fastjson.JSON;
import com.bart.websocket.common.RedisKeyConstants;
import com.bart.websocket.entity.WarnMsg;
import com.bart.websocket.service.WarnMsgService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.connection.Message;
import org.springframework.data.redis.connection.MessageListener;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.data.redis.listener.Topic;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import org.springframework.util.StringUtils;

/**
 * Listener for a custom Redis topic: it forwards {@link TopicMsg} payloads to
 * {@link WarnMsgService} for delivery to the WebSocket clients connected to this instance.
 *
 * <p>Every application instance subscribes, so a message published by any instance reaches the
 * instance that holds the target user's session.
 *
 * @author bart
 */
public class RedisTopicListener implements MessageListener {

    private static final Logger log = LoggerFactory.getLogger(RedisTopicListener.class);

    /** Decodes channel/pattern/body bytes; defaults to UTF-8 so the setter is optional. */
    private StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
    private StringRedisTemplate stringRedisTemplate;

    /** Target for decoded messages; null when the two-arg constructor is used. */
    private final WarnMsgService warnMsgService;

    /**
     * Creates the listener and subscribes it to {@code topics}.
     *
     * @param listenerContainer container to register with
     * @param topics            topics to subscribe to
     * @param warnMsgService    service that pushes decoded messages to WebSocket clients
     */
    public RedisTopicListener(RedisMessageListenerContainer listenerContainer, List<? extends Topic> topics, WarnMsgService warnMsgService) {
        // Assign state before registering: once added, the container may invoke onMessage.
        this.warnMsgService = warnMsgService;
        listenerContainer.addMessageListener(this, topics);
    }

    /**
     * Creates a listener without a {@link WarnMsgService}; received messages are logged and dropped.
     *
     * @param listenerContainer container to register with
     * @param topics            topics to subscribe to
     */
    public RedisTopicListener(RedisMessageListenerContainer listenerContainer, List<? extends Topic> topics) {
        this(listenerContainer, topics, null);
    }

    /**
     * Decodes a message and pushes it: to one user when {@code userId} is set, otherwise to all.
     * Failures are logged here, so a bad payload or push error is not thrown back into the
     * listener container.
     */
    @Override
    public void onMessage(Message message, byte[] pattern) {
        String patternStr = stringRedisSerializer.deserialize(pattern);
        String channel = stringRedisSerializer.deserialize(message.getChannel());
        String body = stringRedisSerializer.deserialize(message.getBody());
        log.info("event = {}, message.channel = {},  message.body = {}", patternStr, channel, body);
        if (!RedisKeyConstants.REDIS_TOPIC_MSG.equals(channel)) {
            return;
        }
        if (warnMsgService == null) {
            log.warn("no WarnMsgService configured, dropping message from channel {}", channel);
            return;
        }
        try {
            TopicMsg topicMsg = JSON.parseObject(body, TopicMsg.class);
            if (topicMsg == null || topicMsg.getMsg() == null) {
                log.warn("ignore empty topic message: {}", body);
                return;
            }
            String userId = topicMsg.getUserId();
            WarnMsg msg = topicMsg.getMsg();
            // An empty user id means broadcast.
            if (StringUtils.hasLength(userId)) {
                warnMsgService.push(userId, msg);
            } else {
                warnMsgService.push(msg);
            }
        } catch (RuntimeException e) {
            log.error("failed to handle message from channel {}: {}", channel, body, e);
        }
    }

    public StringRedisSerializer getStringRedisSerializer() {
        return stringRedisSerializer;
    }

    public void setStringRedisSerializer(StringRedisSerializer stringRedisSerializer) {
        this.stringRedisSerializer = stringRedisSerializer;
    }

    public StringRedisTemplate getStringRedisTemplate() {
        return stringRedisTemplate;
    }

    public void setStringRedisTemplate(StringRedisTemplate stringRedisTemplate) {
        this.stringRedisTemplate = stringRedisTemplate;
    }
}

重点方法在这里:

com.bart.websocket.configuration.redis.listener.RedisTopicListener#onMessage

测试接口

import com.bart.websocket.entity.WarnMsg;
import com.bart.websocket.service.WarnMsgService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.HashMap;

/**
 * Test endpoint for the distributed WebSocket push.
 */
@RestController
public class IndexController {

    /** Timestamp format for the test title; DateTimeFormatter is immutable, so one instance is shared. */
    private static final DateTimeFormatter TITLE_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

    @Autowired
    WarnMsgService warnMsgService;

    /**
     * Publishes a test warning through the Redis topic. Every instance receives it, and the one
     * holding the client's WebSocket session delivers it.
     *
     * @param id target user id; when absent, the message is broadcast to all clients
     */
    @GetMapping("/push")
    public void initMsg(String id) {
        WarnMsg warnMsg = new WarnMsg();
        warnMsg.setTitle(LocalDateTime.now().format(TITLE_FORMAT));
        warnMsg.setBody("吃了没?");
        // Publish via Redis rather than pushing locally: a local push only reaches clients
        // connected to this instance, which is exactly what the cross-instance test must avoid.
        warnMsgService.pushWithTopic(id, warnMsg);
    }
}

消息处理器类

WarnMsgService接口

/**
 * Delivers {@code WarnMsg} warnings to WebSocket clients, either directly from this instance or
 * through the Redis topic so that any instance can deliver them.
 */
public interface WarnMsgService {

    /**
     * Pushes a message directly from this instance to one user's sessions.
     *
     * @param userId target user id
     * @param msg    message to push
     */
    void push(String userId, WarnMsg msg);

    /**
     * Pushes a message directly from this instance without a target user (broadcast).
     *
     * @param msg message to push
     */
    void push(WarnMsg msg);

    /**
     * Publishes a message for one user on the Redis topic.
     *
     * @param userId target user id
     * @param msg    message to publish
     */
    void pushWithTopic(String userId, WarnMsg msg);

    /**
     * Publishes a broadcast message on the Redis topic.
     *
     * @param msg message to publish
     */
    void pushWithTopic(WarnMsg msg);

}

WarnMsgServiceImpl实现类

package com.bart.websocket.service.impl;

import com.alibaba.fastjson.JSON;
import com.bart.websocket.common.RedisKeyConstants;
import com.bart.websocket.configuration.redis.listener.TopicMsg;
import com.bart.websocket.controller._02_spring_annotation.ProductWebSocket;
import com.bart.websocket.entity.WarnMsg;
import com.bart.websocket.service.WarnMsgService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.WebSocketHandler;

import java.util.Collections;

/**
 * Default {@link WarnMsgService}. Local pushes go straight to {@link ProductWebSocket}; topic
 * pushes are published on Redis as {@link TopicMsg} JSON.
 */
@Service
public class WarnMsgServiceImpl implements WarnMsgService, ApplicationContextAware {

    private final static Logger log = LoggerFactory.getLogger(WarnMsgServiceImpl.class);

    ProductWebSocket webSocketHandler;

    @Autowired
    StringRedisTemplate stringRedisTemplate;

    /**
     * Looks up the WebSocket endpoint bean by type.
     *
     * <p>ProductWebSocket is a {@code @Component} registered under its default bean name and does
     * not implement Spring's {@code WebSocketHandler}. A lookup of "webSocketHandler" with that
     * type therefore cannot find it.
     */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        webSocketHandler = applicationContext.getBean(ProductWebSocket.class);
        Assert.notNull(webSocketHandler, "初始化webSocketHandler失败!");
    }

    /**
     * Broadcasts a message to every session on this instance.
     *
     * @param msg message to push
     */
    @Override
    public void push(WarnMsg msg) {
        push("", msg);
    }

    /**
     * Pushes a message to sessions on this instance.
     * A null body is replaced by an empty map before serialisation.
     *
     * @param userId target user; null/empty broadcasts to all local sessions
     * @param msg    message to push
     * @throws IllegalArgumentException if {@code msg} is null
     */
    @Override
    public void push(String userId, WarnMsg msg) {
        Assert.notNull(msg, "消息对象不能为空!");
        if (msg.getBody() == null) {
            msg.setBody(Collections.emptyMap());
        }
        String payload = JSON.toJSONString(msg);
        if (StringUtils.hasLength(userId)) {
            webSocketHandler.send(userId, payload);
        } else {
            webSocketHandler.broadcast(payload);
        }
    }

    /*
     * Publishes a message on the Redis topic.
     * To test the topic listener from redis-cli:
     *   SUBSCRIBE redisChat                                       // subscribe to a channel
     *   PSUBSCRIBE it* big*                                       // subscribe to channels matching patterns
     *   PUBLISH redisChat "Redis is a great caching technique"    // publish to a channel
     *   PUNSUBSCRIBE it* big*                                     // unsubscribe from pattern channels
     *   UNSUBSCRIBE channel it_info big_data                      // unsubscribe from specific channels
     *
     * A null userId is published as "" (broadcast); a null msg is ignored.
     */
    @Override
    public void pushWithTopic(String userId, WarnMsg msg) {
        if (null == userId) {
            userId = "";
        }
        if (msg == null) {
            log.debug("send to userId = [{}] msg is empty, just ignore!", userId);
            return;
        }
        String body = JSON.toJSONString(new TopicMsg(userId, msg));
        log.debug("send topic=[{}], msg=[{}]", RedisKeyConstants.REDIS_TOPIC_MSG, body);
        stringRedisTemplate.convertAndSend(RedisKeyConstants.REDIS_TOPIC_MSG, body);
    }

    /**
     * Publishes a broadcast message on the Redis topic.
     *
     * @param msg message to publish
     */
    @Override
    public void pushWithTopic(WarnMsg msg) {
        pushWithTopic("", msg);
    }
}

前端

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>websocket</title>
    <script src="js/sockjs.js"></script>
    <script src="js/jquery.min.js"></script>
</head>
<body>

<fieldset>
    <legend>User01</legend>
    <button onclick="online('bart')">上线</button>
    session:<input type="text" id="session-bart"/>
    host:<input type="text" id="host-bart" value="localhost"/>
    port:<input type="text" id="port-bart" value="8089"/>
    <div>发送消息:</div>
    <input type="text" id="msgContent-bart"/>
    <input type="button" value="点我发送" onclick="chat('bart')"/>
    <div>接受消息:</div>
    <div id="receiveMsg-bart" style="background-color: gainsboro;"></div>
</fieldset>
<script>
    // panel name -> CHAT instance for each panel that has gone online
    var map = {};

    /**
     * Opens a WebSocket for one panel: ws://{host}:{port}/ws/{id}.
     * The server endpoint is mapped to /ws/{userId}, so the "session" box holds the user id.
     */
    function online(name) {
        var host = $("#host-" + name).val();
        var port = $("#port-" + name).val();
        var userId = $("#session-" + name).val();
        if (!userId) {
            // "/ws/" without an id does not match /ws/{userId} on the server (404).
            alert("请先填写 session(用户id)");
            return;
        }
        // Close a previous connection of this panel so a second click does not leak it.
        var previous = map[name];
        if (previous && previous.socket) {
            previous.socket.close();
        }
        var client = new CHAT(name, "ws://" + host + ":" + port + "/ws/" + userId);
        client.init();
        map[name] = client;
    }

    /** Button handler: sends the panel's input text over its connection. */
    function chat(name) {
        var client = map[name];
        if (!client) {
            console.log("not online yet: " + name);
            return false;
        }
        client.chat();
        return false;
    }

    /** One WebSocket client bound to a panel name and URL. */
    function CHAT(name, url) {
        var self = this;
        this.name = name;
        this.socket = null;

        this.init = function () {
            if (!('WebSocket' in window)) {
                console.log("your broswer not support websocket!");
                alert("your broswer not support websocket!");
                return;
            }
            console.log("WebSocket -> " + url);
            self.socket = new WebSocket(url);
            self.socket.onopen = function () {
                console.log("连接建立成功...");
            };
            self.socket.onclose = function () {
                console.log("连接关闭...");
            };
            self.socket.onerror = function () {
                console.log("发生错误...");
            };
            self.socket.onmessage = function (e) {
                var res;
                try {
                    res = JSON.parse(e.data);
                } catch (err) {
                    // Not JSON: keep the raw text.
                    res = e.data;
                }
                console.log(name, res);
                // Show the received text; textContent keeps server data from being parsed as HTML.
                var line = document.createElement("div");
                line.textContent = e.data;
                document.getElementById("receiveMsg-" + name).appendChild(line);
                // business logic goes here
            };
        };

        this.chat = function () {
            var id = "msgContent-" + name;
            var value = document.getElementById(id).value;
            console.log("发送消息", id, value);
            if (!self.socket || self.socket.readyState !== WebSocket.OPEN) {
                console.log("socket not open: " + name);
                return;
            }
            var msg = {
                "type": 1, // 1 = send to everyone (the server's onMessage currently only logs it)
                "msg": value
            };
            self.socket.send(JSON.stringify(msg));
        };
    }
</script>
</body>
</html>

测试

启动两个后端项目,端口分别为8080和8081。

1、浏览器中连接8080端口的websocket;

2、然后访问8081的接口 http://localhost:8081/push ,发现连接8080的客户端也收到了消息(前提是该接口通过 `pushWithTopic` 经 Redis 发布消息;如果直接调用本机的 `push`,只有连接在8081上的客户端才能收到)。

posted @ 2021-02-18 17:49  bartggg  阅读(5002)  评论(5)  编辑  收藏  举报