Netty 实现 WebSocket
public class WebSocketClientCache { private static ConcurrentHashMap<String, WebSocketClientNode> channelsMap = new ConcurrentHashMap<>(); //心跳 1 已发送 private static ConcurrentHashMap<String, Integer> pingPongChannelsMap = new ConcurrentHashMap<>(); private static ConcurrentHashMap<String, Long> registTrackIdMap = new ConcurrentHashMap<String, Long>(); //轮训时间 检测过期连接 private final static int SCHEDULE_SECONDS = 180; private static ScheduledExecutorService scheduleService = Executors.newScheduledThreadPool(1); private static Lock pingPongLock = new ReentrantLock(); /*标记状态*/ private static volatile boolean isSent = true; static { scheduleService.scheduleAtFixedRate(new Runnable() { @Override public void run() { //3分鐘 send ping to all LogUtils.logDebug(WebSocketClientCache.class, "WebSocketClientCache scheduleWithFixedDelay starting ..."); if (isSent) { isSent = false; //定时发送心跳 sendPingMessageToAll(); } else { isSent = true; clearNotPingPongMessage(); } } }, 1L, SCHEDULE_SECONDS, TimeUnit.SECONDS); } public static void putChannelHandlerContext(String channelId, ChannelHandlerContext channelHandlerContext, WebSocketServerHandshaker handshaker) { if (channelsMap.containsKey(channelId)) { return; } try { WebSocketClientNode webSocketClientNode = new WebSocketClientNode(); webSocketClientNode.handshaker = handshaker; webSocketClientNode.ctx = channelHandlerContext; channelsMap.put(channelId, webSocketClientNode); pingPongChannelsMap.remove(channelId); } catch (Exception e) { LogUtils.logError(WebSocketClientCache.class, e); } } public static ChannelHandlerContext getChannelHandlerContext(String channelId) { WebSocketClientNode webSocketClientNode = channelsMap.get(channelId); if (webSocketClientNode != null && webSocketClientNode.ctx != null) { return webSocketClientNode.ctx; } return null; } public static void removeChannelHandlerContext(String channelId) { WebSocketClientNode webSocketClientNode = channelsMap.remove(channelId); if (webSocketClientNode != null) { Channel 
channel = webSocketClientNode.ctx.channel(); //关闭连接 if (channel.isOpen()) { webSocketClientNode.handshaker.close(channel, new CloseWebSocketFrame()); channel.close(); } } registTrackIdMap.remove(channelId); pingPongChannelsMap.remove(channelId); } public static List<ChannelHandlerContext> getChannelHandlerContextList() { if (channelsMap.isEmpty()) { return null; } Set<String> keySet = channelsMap.keySet(); List<ChannelHandlerContext> list = new LinkedList<>(); for (String key : keySet) { WebSocketClientNode webSocketClientNode = channelsMap.get(key); list.add(webSocketClientNode.ctx); } if (list.size() == 0) { return null; } return list; } /** * 全部发送心跳 */ public static void sendPingMessageToAll() { if (channelsMap.isEmpty()) { return; } Set<String> keySet = channelsMap.keySet(); for (String key : keySet) { WebSocketClientNode webSocketClientNode = channelsMap.get(key); //往客户端发ping 客户端会返回pong 可以用来判断客户端存活 PingWebSocketFrame pingWebSocketFrame = new PingWebSocketFrame(); Channel channel = webSocketClientNode.ctx.channel(); if (channel.isOpen()) { channel.writeAndFlush(pingWebSocketFrame); } //标记为已发送 pingPongChannelsMap.put(key, 1); } } public static void getPongMessage(String channelId) { if (channelId == null) { return; } pingPongChannelsMap.remove(channelId); } public static void clearNotPingPongMessage() { if (pingPongChannelsMap.isEmpty()) { return; } Set<String> keySet = pingPongChannelsMap.keySet(); for (String key : keySet) { Integer status = pingPongChannelsMap.get(key); if (status != null && status.intValue() == 1) { WebSocketClientNode webSocketClientNode = channelsMap.get(key); //关闭websocket // 握手关闭连接 Channel channel = webSocketClientNode.ctx.channel(); if (channel.isOpen()) { webSocketClientNode.handshaker.close(channel, new CloseWebSocketFrame()); } } //删除连接 removeChannelHandlerContext(key); } } private static class WebSocketClientNode { WebSocketServerHandshaker handshaker; ChannelHandlerContext ctx; } /** * 全部发送消息 * 往客户端推送消息 */ public static void 
sendMessageToAll(String message) { if (channelsMap.isEmpty()) { return; } Set<String> keySet = channelsMap.keySet(); for (String key : keySet) { WebSocketClientNode webSocketClientNode = channelsMap.get(key); TextWebSocketFrame textWebSocketFrame = new TextWebSocketFrame(message); if (webSocketClientNode.ctx.channel().isOpen() && webSocketClientNode.ctx.channel().isWritable()) { webSocketClientNode.ctx.channel().writeAndFlush(textWebSocketFrame); } else { channelsMap.remove(key); registTrackIdMap.remove(key); } } } public static void registTrackTagId(String key, Long id) { registTrackIdMap.put(key, id); } public static void removeTrackTagId(String key) { registTrackIdMap.remove(key); } public static boolean isRegistTrackTagId(Long tagId) { Collection<Long> ids = registTrackIdMap.values(); return ids.contains(tagId); } }
package com.hikvision.energy.websocket.server; import com.hikvision.energy.core.util.log.LogUtils; import com.hikvision.energy.websocket.handler.WebSocketServerHandler; import com.hikvision.energy.websocket.handler.WebSocketServerHandlerFactory; import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.*; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpServerCodec; import io.netty.handler.stream.ChunkedWriteHandler; import org.springframework.beans.factory.annotation.Autowired; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; /** * netty websocket 服务器 * <p> * Created by zhuangjiesen on 2017/8/8. */ public class WebSocketNettyServer { /** * 端口号 **/ private int port; @Autowired private WebSocketServerHandlerFactory webSocketServerHandlerFactory; private ExecutorService executorService = Executors.newSingleThreadExecutor(); public WebSocketNettyServer() { super(); } /** * 创建一个单例的线程池, */ public void startNettyServer() { executorService.execute(new Runnable() { @Override public void run() { // TODO Auto-generated method stub EventLoopGroup boss = new NioEventLoopGroup(); EventLoopGroup worker = new NioEventLoopGroup(); try { ServerBootstrap bootstrap = new ServerBootstrap(); bootstrap.group(boss, worker); bootstrap.channel(NioServerSocketChannel.class); bootstrap.option(ChannelOption.SO_BACKLOG, 1024); //连接数 bootstrap.option(ChannelOption.TCP_NODELAY, true); //不延迟,消息立即发送 // bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 2000); //超时时间 bootstrap.childOption(ChannelOption.SO_KEEPALIVE, true); //长连接 bootstrap.childHandler(new ChannelInitializer<SocketChannel>() { @Override protected void initChannel(SocketChannel socketChannel) throws Exception { ChannelPipeline p = socketChannel.pipeline(); p.addLast("http-codec", new 
HttpServerCodec()); p.addLast("aggregator", new HttpObjectAggregator(65536)); p.addLast("http-chunked", new ChunkedWriteHandler()); p.addLast("handler", webSocketServerHandlerFactory.newWebSocketServerHandler()); p.addLast(" ", webSocketServerHandlerFactory.newWebSocketServerOutboundHandler()); } }); ChannelFuture f = bootstrap.bind(port).sync(); if (f.isSuccess()) { LogUtils.logDebug(WebSocketNettyServer.class, " WebSocketNettyServer start successfully ...."); } f.channel().closeFuture().sync(); LogUtils.logDebug(WebSocketNettyServer.class, " WebSocketNettyServer start successfully ...."); } catch (Exception e) { LogUtils.logError(e.getClass(), e); } finally { LogUtils.logError(WebSocketNettyServer.class, " WebSocketNettyServer shutdownGracefully ...."); boss.shutdownGracefully(); worker.shutdownGracefully(); } } }); } public int getPort() { return port; } public void setPort(int port) { this.port = port; } public WebSocketServerHandlerFactory getWebSocketServerHandlerFactory() { return webSocketServerHandlerFactory; } public void setWebSocketServerHandlerFactory(WebSocketServerHandlerFactory webSocketServerHandlerFactory) { this.webSocketServerHandlerFactory = webSocketServerHandlerFactory; } }
/**
 * Outbound handler that evicts the channel from {@link WebSocketClientCache} whenever a close
 * operation reaches it. The close is then passed on unchanged.
 */
public class WebSocketServerOutboundHandler extends ChannelOutboundHandlerAdapter {

    @Override
    public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
        // the cache is keyed by the long-text form of the channel id
        final String cacheKey = ctx.channel().id().asLongText();
        WebSocketClientCache.removeChannelHandlerContext(cacheKey);
        super.close(ctx, promise);
    }
}
/**
 * Creates the per-channel handlers. Every call returns a new instance.
 * {@link WebSocketServerHandler} keeps per-connection state, so instances must not be shared
 * between channels.
 */
public class WebSocketServerHandlerFactory {

    /** Message callback injected into every inbound handler this factory creates. */
    private WebSocketMessageHandler webSocketMessageHandler;

    /**
     * @return a new inbound handler wired to the configured message handler
     */
    public WebSocketServerHandler newWebSocketServerHandler() {
        final WebSocketServerHandler inbound = new WebSocketServerHandler();
        inbound.setWebSocketMessageHandler(this.webSocketMessageHandler);
        return inbound;
    }

    /**
     * @return a new outbound handler
     */
    public WebSocketServerOutboundHandler newWebSocketServerOutboundHandler() {
        return new WebSocketServerOutboundHandler();
    }

    public WebSocketMessageHandler getWebSocketMessageHandler() {
        return this.webSocketMessageHandler;
    }

    public void setWebSocketMessageHandler(WebSocketMessageHandler webSocketMessageHandler) {
        this.webSocketMessageHandler = webSocketMessageHandler;
    }
}
package com.hikvision.energy.websocket.handler; import com.alibaba.fastjson.JSON; import com.hikvision.energy.core.util.log.LogUtils; import com.hikvision.energy.util.CfgMgr; import com.hikvision.energy.websocket.cache.WebSocketClientCache; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.*; import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpVersion; import io.netty.handler.codec.http.websocketx.*; import io.netty.util.CharsetUtil; import java.util.concurrent.ConcurrentHashMap; public class WebSocketServerHandler extends ChannelInboundHandlerAdapter { public final static String WEBSOCKET = "websocket"; public final static String Upgrade = "Upgrade"; public final static String WEBSOCKET_ADDRESS_FORMAT = "ws://${host}/web/websocket"; private WebSocketServerHandshaker webSocketServerHandshaker; private WebSocketMessageHandler webSocketMessageHandler; @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { // TODO Auto-generated method stub // LogUtils.logError(WebSocketServerHandler.class , "WebSocketServerHandler channelActive ..." ); } @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { // TODO Auto-generated method stub if (msg instanceof FullHttpRequest) { //处理http请求 handleHttpRequest(ctx, ((FullHttpRequest) msg)); } else if (msg instanceof WebSocketFrame) { //处理websocket请求 handlerWebSocketFrame(ctx, (WebSocketFrame) msg); } else if (msg instanceof CloseWebSocketFrame) { //关闭链路 onWebSocketFrameColsed(ctx, (CloseWebSocketFrame) msg); } else { throw new RuntimeException("无法处理的请求"); } } @Override public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { // TODO Auto-generated method stub // LogUtils.logError(WebSocketServerHandler.class , "WebSocketServerHandler channelReadComplete ..." 
); } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { // TODO Auto-generated method stub // LogUtils.logError(WebSocketServerHandler.class , "WebSocketServerHandler exceptionCaught ..." , " exception : " , cause ); } /** * 处理http请求 */ public void handleHttpRequest(ChannelHandlerContext ctx, FullHttpRequest req) { String headVal = ""; if (req.headers() != null) { headVal = req.headers().get(this.Upgrade); } //不是websocket请求 if (!req.decoderResult().isSuccess() || (headVal != null && !this.WEBSOCKET.equals(headVal.toLowerCase()))) { DefaultFullHttpResponse defaultFullHttpResponse = new DefaultFullHttpResponse( HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST); // 返回应答给客户端 if (defaultFullHttpResponse.status().code() != 200) { ByteBuf buf = Unpooled.copiedBuffer(defaultFullHttpResponse.status().toString(), CharsetUtil.UTF_8); defaultFullHttpResponse.content().writeBytes(buf); buf.release(); } // 如果是非Keep-Alive,关闭连接 ChannelFuture f = ctx.channel().writeAndFlush(defaultFullHttpResponse); //boolean isKeepAlive = false; //if ((!isKeepAlive) || defaultFullHttpResponse.status().code() != 200) { if (defaultFullHttpResponse.status().code() != 200) { f.addListener(ChannelFutureListener.CLOSE); } return; } String host = ""; if (req.headers().get("Host") != null) { host = req.headers().get("Host").toString(); } else { host = CfgMgr.getInstallIp(); } String webAddress = WEBSOCKET_ADDRESS_FORMAT.replace("${host}", host); WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory( webAddress, null, false); //wbSocket Factory 通過請求可以獲取 握手。 WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req); if (handshaker == null) { WebSocketServerHandshakerFactory .sendUnsupportedVersionResponse(ctx.channel()); } else { handshaker.handshake(ctx.channel(), req); webSocketServerHandshaker = handshaker; WebSocketClientCache.putChannelHandlerContext(ctx.channel().id().asLongText(), ctx, handshaker); } } /** * 
处理websocket 请求 */ public void handlerWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) { if (frame == null) { return; } // 判断是否关闭链路的指令 if ((frame instanceof CloseWebSocketFrame)) { LogUtils.logDebug(WebSocketServerHandler.class, "关闭连接...."); onWebSocketFrameColsed(ctx, frame); return; } // 判断是否ping消息 if (frame instanceof PingWebSocketFrame) { LogUtils.logDebug(WebSocketServerHandler.class, "i am ping message ...."); ctx.channel().write( new PongWebSocketFrame(frame.content().retain())); return; } if (frame instanceof PongWebSocketFrame) { LogUtils.logDebug(WebSocketServerHandler.class, "i am pong message ...."); ctx.channel().write( new PingWebSocketFrame(frame.content().retain())); //获取pong WebSocketClientCache.getPongMessage(ctx.channel().id().asLongText()); return; } LogUtils.logDebug(WebSocketServerHandler.class, "frame class : " + frame.getClass().getName()); // 本例程仅支持文本消息,不支持二进制消息 if (!(frame instanceof TextWebSocketFrame)) { LogUtils.logDebug(WebSocketServerHandler.class, "not support other frame data : ", JSON.toJSONString(frame)); throw new UnsupportedOperationException(String.format( "%s frame types not supported", frame.getClass().getName())); } String channelId = ctx.channel().id().asLongText(); //每次插入,保证可靠 WebSocketClientCache.putChannelHandlerContext(channelId, ctx, webSocketServerHandshaker); //获取pong WebSocketClientCache.getPongMessage(ctx.channel().id().asLongText()); if (frame != null) { TextWebSocketFrame textWebSocketFrame = (TextWebSocketFrame) frame; // String frameContent = textWebSocketFrame.text(); // Channel channel = ctx.channel(); // String channelId = channel.id().asLongText(); // //获取pong // WebSocketClientCache.getPongMessage(ctx.channel().id().asLongText()); // TextWebSocketFrame tws = new TextWebSocketFrame("我回应了。。。。"); // channel.writeAndFlush(tws); // LogUtils.logDebug(WebSocketServerHandler.class , "frameContent : " + frameContent); webSocketMessageHandler.onMessage(textWebSocketFrame, ctx); } } /* * 处理websocket 
关闭请求 * */ public void onWebSocketFrameColsed(ChannelHandlerContext ctx, WebSocketFrame frame) { if (webSocketServerHandshaker != null) { WebSocketClientCache.removeChannelHandlerContext(ctx.channel().id().asLongText()); webSocketServerHandshaker.close(ctx.channel(), (CloseWebSocketFrame) frame .retain()); } } public WebSocketMessageHandler getWebSocketMessageHandler() { return webSocketMessageHandler; } public void setWebSocketMessageHandler(WebSocketMessageHandler webSocketMessageHandler) { this.webSocketMessageHandler = webSocketMessageHandler; } }
package com.hikvision.energy.websocket.handler.impl; import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import com.hikvision.energy.core.util.log.LogUtils; import com.hikvision.energy.websocket.cache.WebSocketClientCache; import com.hikvision.energy.websocket.handler.WebSocketMessageHandler; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketFrame; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.math.NumberUtils; /** * Created by zhuangjiesen on 2017/8/9. */ public class TextFrameWebSocketMessageHandler implements WebSocketMessageHandler { @Override public void onMessage(WebSocketFrame webSocketFrame, ChannelHandlerContext ctx) { TextWebSocketFrame textWebSocketFrame = (TextWebSocketFrame) webSocketFrame; String message = textWebSocketFrame.text(); if (StringUtils.isNotBlank(message)) { try { Channel channel = ctx.channel(); String channelId = channel.id().asLongText(); String[] splitArr = message.split("="); if (splitArr != null && splitArr.length == 2) { String trackTagIdStr = splitArr[1]; if (NumberUtils.isDigits(trackTagIdStr)) { //在cache 的 map 中存 通道id 和标签的id WebSocketClientCache.registTrackTagId(channelId, NumberUtils.toLong(trackTagIdStr)); } } else { //删除 tag Id WebSocketClientCache.removeTrackTagId(channelId); } } catch (Exception e) { LogUtils.logError(this, e, "analyze message error", message); } } LogUtils.logDebug(TextFrameWebSocketMessageHandler.class , "textWebSocketFrame : " + ((TextWebSocketFrame) webSocketFrame).text()); } }