Netty+WebSocket简单实现网页聊天
基于Netty+WebSocket的网页聊天简单实现
一、pom依赖
<!-- Netty all-in-one artifact; the server code below uses its 4.1 API (HttpHeaderNames, request.method(), websocketx). -->
<dependency>
    <groupId>io.netty</groupId>
    <artifactId>netty-all</artifactId>
    <version>4.1.6.Final</version>
</dependency>
二、文件目录
三、服务端代码
WebSocketService
package com.netty;

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;

/**
 * Callback that receives every WebSocket frame arriving on a server channel.
 */
public interface WebSocketService {

    /**
     * Processes one decoded WebSocket frame.
     *
     * @param ctx   handler context of the channel the frame arrived on
     * @param frame the decoded frame; when dispatched by WebSocketServerHandler (a
     *              SimpleChannelInboundHandler) it is released after this call returns, so an
     *              implementation must retain it (or its content) to use it afterwards
     */
    void handleFrame(ChannelHandlerContext ctx, WebSocketFrame frame);
}
HttpService
package com.netty;

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;

/**
 * Callback that receives every aggregated HTTP request arriving on a server channel.
 */
public interface HttpService {

    /**
     * Processes one aggregated HTTP request (for example a WebSocket upgrade request).
     *
     * <p>The method name keeps its original spelling ("Requset") because implementations and
     * callers already refer to it.
     *
     * @param ctx     handler context of the channel the request arrived on
     * @param request the full HTTP request; when dispatched by WebSocketServerHandler it is
     *                released after this call returns
     */
    void handleHttpRequset(ChannelHandlerContext ctx, FullHttpRequest request);
}
WebSocketServerHandler
package com.netty; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.websocketx.WebSocketFrame; public class WebSocketServerHandler extends SimpleChannelInboundHandler<Object>{ private WebSocketService webSocketServiceImpl; private HttpService httpServiceImpl; public WebSocketServerHandler(WebSocketService webSocketServiceImpl, HttpService httpServiceImpl) { super(); this.webSocketServiceImpl = webSocketServiceImpl; this.httpServiceImpl = httpServiceImpl; } @Override protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { // TODO Auto-generated method stub if(msg instanceof FullHttpRequest){ httpServiceImpl.handleHttpRequset(ctx, (FullHttpRequest)msg); }else if(msg instanceof WebSocketFrame){ webSocketServiceImpl.handleFrame(ctx, (WebSocketFrame)msg); } } @Override public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { // TODO Auto-generated method stub ctx.flush(); } }
WebSocketServerImpl
package com.netty; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelId; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpServerCodec; import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; import io.netty.handler.codec.http.websocketx.PingWebSocketFrame; import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker; import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory; import io.netty.handler.stream.ChunkedWriteHandler; import io.netty.util.AttributeKey; public class WebSocketServerImpl implements WebSocketService, HttpService{ private static final String HN_HTTP_CODEC = "HN_HTTP_CODEC"; private static final String NH_HTTP_AGGREGATOR ="NH_HTTP_AGGREGATOR"; private static final String NH_HTTP_CHUNK = "HN_HTTP_CHUNK"; private static final String NH_SERVER = "NH_LOGIC"; private static final AttributeKey<WebSocketServerHandshaker> ATTR_HANDSHAKER = AttributeKey.newInstance("ATTR_KEY_CHANNELID"); private static final int MAX_CONTENT_LENGTH = 65536; private static final String WEBSOCKET_UPGRADE = "websocket"; 
private static final String WEBSOCKET_CONNECTION = "Upgrade"; private static final String WEBSOCKET_URI_ROOT_PATTERN = "ws://%s:%d"; //地址 private String host; //端口号 private int port; //存放websocket连接 private Map<ChannelId, Channel> channelMap = new ConcurrentHashMap<ChannelId, Channel>(); private final String WEBSOCKET_URI_ROOT; public WebSocketServerImpl(String host, int port) { super(); this.host = host; this.port = port; WEBSOCKET_URI_ROOT = String.format(WEBSOCKET_URI_ROOT_PATTERN, host, port); } //启动 public void start(){ EventLoopGroup bossGroup = new NioEventLoopGroup(); EventLoopGroup workerGroup = new NioEventLoopGroup(); ServerBootstrap sb = new ServerBootstrap(); sb.group(bossGroup, workerGroup); sb.channel(NioServerSocketChannel.class); sb.childHandler(new ChannelInitializer<Channel>() { @Override protected void initChannel(Channel ch) throws Exception { // TODO Auto-generated method stub ChannelPipeline pl = ch.pipeline(); //保存引用 channelMap.put(ch.id(), ch); ch.closeFuture().addListener(new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { // TODO Auto-generated method stub //关闭后抛弃 channelMap.remove(future.channel().id()); } }); pl.addLast(HN_HTTP_CODEC,new HttpServerCodec()); pl.addLast(NH_HTTP_AGGREGATOR,new HttpObjectAggregator(MAX_CONTENT_LENGTH)); pl.addLast(NH_HTTP_CHUNK,new ChunkedWriteHandler()); pl.addLast(NH_SERVER,new WebSocketServerHandler(WebSocketServerImpl.this,WebSocketServerImpl.this)); } }); try { ChannelFuture future = sb.bind(host,port).addListener(new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { // TODO Auto-generated method stub if(future.isSuccess()){ System.out.println("websocket started"); } } }).sync(); future.channel().closeFuture().addListener(new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { // TODO Auto-generated method stub 
System.out.println("channel is closed"); } }).sync(); } catch (InterruptedException e) { // TODO Auto-generated catch block e.printStackTrace(); } finally{ bossGroup.shutdownGracefully(); workerGroup.shutdownGracefully(); } System.out.println("websocket shutdown"); } @Override public void handleHttpRequset(ChannelHandlerContext ctx, FullHttpRequest request) { // TODO Auto-generated method stub if(isWebSocketUpgrade(request)){ String subProtocols = request.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL); WebSocketServerHandshakerFactory factory = new WebSocketServerHandshakerFactory(WEBSOCKET_URI_ROOT, subProtocols, false); WebSocketServerHandshaker handshaker = factory.newHandshaker(request); if(handshaker == null){ WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel()); }else{ //响应请求 handshaker.handshake(ctx.channel(), request); //将handshaker绑定给channel ctx.channel().attr(ATTR_HANDSHAKER).set(handshaker); } return; } } @Override public void handleFrame(ChannelHandlerContext ctx, WebSocketFrame frame) { // TODO Auto-generated method stub if(frame instanceof TextWebSocketFrame){ String text = ((TextWebSocketFrame) frame).text(); TextWebSocketFrame rsp = new TextWebSocketFrame(text); for(Channel ch:channelMap.values()){ if(ctx.channel().equals(ch)){ continue; } ch.writeAndFlush(rsp); } return; } //ping 回复 pong if(frame instanceof PingWebSocketFrame){ ctx.channel().writeAndFlush(new PongWebSocketFrame(frame.content().retain())); return; } if(frame instanceof PongWebSocketFrame){ return; } if(frame instanceof CloseWebSocketFrame){ WebSocketServerHandshaker handshaker = ctx.channel().attr(ATTR_HANDSHAKER).get(); if(handshaker == null){ return; } handshaker.close(ctx.channel(), (CloseWebSocketFrame)frame.retain()); return; } } //1、判断是否为get 2、判断Upgrade头 包含websocket字符串 3、Connection头 包换upgrade字符串 private boolean isWebSocketUpgrade(FullHttpRequest request){ HttpHeaders headers = request.headers(); return request.method().equals(HttpMethod.GET) && 
headers.get(HttpHeaderNames.UPGRADE).contains(WEBSOCKET_UPGRADE) && headers.get(HttpHeaderNames.CONNECTION).contains(WEBSOCKET_CONNECTION); } }
DoMain
package com.netty; public class DoMain { public static void main(String[] args) { WebSocketServerImpl socket = new WebSocketServerImpl("localhost", 9999); socket.start(); } }
四、客户端代码
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html xmlns="http://www.w3.org/1999/xhtml">
<head>
    <meta http-equiv="Content-Type" content="text/html; charset=utf-8" />
    <title></title>
</head>
<body>
    <form onSubmit="return false;">
        <input type="text" name="message" value="hello"/>
        <br/><br/>
        <input type="button" value="send" onClick="send(this.form.message.value)"/>
        <hr/>
        <h3>chat</h3>
        <textarea id="responseText" style="width: 800px;height: 300px;"></textarea>
    </form>
    <!-- Script sits at the end of <body> so #responseText already exists when a socket callback writes to it. -->
    <script type="text/javascript">
        var socket;
        // Fallback for old Firefox builds that exposed the API as MozWebSocket.
        if (!window.WebSocket) {
            window.WebSocket = window.MozWebSocket;
        }
        if (window.WebSocket) {
            // Must match the host/port the server is started with (DoMain: "localhost", 9999).
            socket = new WebSocket("ws://localhost:9999");
            socket.onmessage = function (event) {
                appendln("receive:" + event.data);
            };
            socket.onopen = function (event) {
                appendln("WebSocket is opened");
            };
            socket.onclose = function (event) {
                appendln("WebSocket is closed");
            };
        } else {
            alert("WebSocket is not supported");
        }

        // Sends message over the socket and echoes it into the local chat log.
        function send(message) {
            if (!window.WebSocket) {
                return;
            }
            if (socket.readyState == WebSocket.OPEN) {
                socket.send(message);
                appendln("send:" + message);
            } else {
                alert("WebSocket is not open");
            }
        }

        // Appends one line to the chat log.
        function appendln(text) {
            var ta = document.getElementById('responseText');
            ta.value += text + "\r\n";
        }

        // Empties the chat log (not wired to any control on this page).
        function clear() {
            var ta = document.getElementById('responseText');
            ta.value = "";
        }
    </script>
</body>
</html>
五、结果
六、实际使用中的问题:当同时连接3个及以上客户端时会报错:io.netty.util.IllegalReferenceCountException: refCnt: 0, decrement: 1
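查找资料后发现:TextWebSocketFrame 是引用计数对象,writeAndFlush 把它写出后会将其释放(引用计数减1)。原代码在循环中把同一个 rsp 帧写给多个连接。写完第一个连接后引用计数已变为0,再写下一个连接时就会抛出上述异常。只有2个客户端时,除发送者外只需写给1个连接,所以不会出错。解决方法: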
WebSocketServerImpl不再使用 private Map<ChannelId, Channel> channelMap 存放连接,而是使用netty提供的ChannelGroup存放连接。ChannelGroup.writeAndFlush 会为组内每个连接各写出一份消息副本,并负责释放原消息,因此群发同一个帧不会再出现引用计数错误。
创建变量 private ChannelGroup group = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
保存引用由 channelMap.put(ch.id(), ch); 改为 group.add(ch);
关闭由 channelMap.remove(future.channel().id()); 改为 group.remove(ch);(其实 DefaultChannelGroup 会在 channel 关闭时自动将其移除,这一步也可以省略)
发送 handleFrame方法 改为
String text = ((TextWebSocketFrame) frame).text(); TextWebSocketFrame rsp = new TextWebSocketFrame(text); group.writeAndFlush(rsp);
WebSocketServerImpl完整代码:
package com.netty; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelId; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoopGroup; import io.netty.channel.group.ChannelGroup; import io.netty.channel.group.DefaultChannelGroup; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpServerCodec; import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; import io.netty.handler.codec.http.websocketx.PingWebSocketFrame; import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker; import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory; import io.netty.handler.stream.ChunkedWriteHandler; import io.netty.util.AttributeKey; import io.netty.util.concurrent.GlobalEventExecutor; public class WebSocketServerImpl implements WebSocketService, HttpService{ private static final String HN_HTTP_CODEC = "HN_HTTP_CODEC"; private static final String NH_HTTP_AGGREGATOR ="NH_HTTP_AGGREGATOR"; private static final String NH_HTTP_CHUNK = "HN_HTTP_CHUNK"; private static final String NH_SERVER = "NH_LOGIC"; private static final AttributeKey<WebSocketServerHandshaker> ATTR_HANDSHAKER = 
AttributeKey.newInstance("ATTR_KEY_CHANNELID"); private static final int MAX_CONTENT_LENGTH = 65536; private static final String WEBSOCKET_UPGRADE = "websocket"; private static final String WEBSOCKET_CONNECTION = "Upgrade"; private static final String WEBSOCKET_URI_ROOT_PATTERN = "ws://%s:%d"; //地址 private String host; //端口号 private int port; //存放websocket连接 private Map<ChannelId, Channel> channelMap = new ConcurrentHashMap<ChannelId, Channel>(); private ChannelGroup group = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE); private final String WEBSOCKET_URI_ROOT; public WebSocketServerImpl(String host, int port) { super(); this.host = host; this.port = port; WEBSOCKET_URI_ROOT = String.format(WEBSOCKET_URI_ROOT_PATTERN, host, port); } //启动 public void start(){ EventLoopGroup bossGroup = new NioEventLoopGroup(); EventLoopGroup workerGroup = new NioEventLoopGroup(); ServerBootstrap sb = new ServerBootstrap(); sb.group(bossGroup, workerGroup); sb.channel(NioServerSocketChannel.class); sb.childHandler(new ChannelInitializer<Channel>() { @Override protected void initChannel(Channel ch) throws Exception { // TODO Auto-generated method stub ChannelPipeline pl = ch.pipeline(); //保存引用 channelMap.put(ch.id(), ch); group.add(ch); ch.closeFuture().addListener(new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { // TODO Auto-generated method stub //关闭后抛弃 channelMap.remove(future.channel().id()); group.remove(ch); } }); pl.addLast(HN_HTTP_CODEC,new HttpServerCodec()); pl.addLast(NH_HTTP_AGGREGATOR,new HttpObjectAggregator(MAX_CONTENT_LENGTH)); pl.addLast(NH_HTTP_CHUNK,new ChunkedWriteHandler()); pl.addLast(NH_SERVER,new WebSocketServerHandler(WebSocketServerImpl.this,WebSocketServerImpl.this)); } }); try { ChannelFuture future = sb.bind(host,port).addListener(new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { // TODO Auto-generated method stub 
if(future.isSuccess()){ System.out.println("websocket started"); } } }).sync(); future.channel().closeFuture().addListener(new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { // TODO Auto-generated method stub System.out.println("channel is closed"); } }).sync(); } catch (InterruptedException e) { // TODO Auto-generated catch block e.printStackTrace(); } finally{ bossGroup.shutdownGracefully(); workerGroup.shutdownGracefully(); } System.out.println("websocket shutdown"); } @Override public void handleHttpRequset(ChannelHandlerContext ctx, FullHttpRequest request) { // TODO Auto-generated method stub if(isWebSocketUpgrade(request)){ String subProtocols = request.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL); WebSocketServerHandshakerFactory factory = new WebSocketServerHandshakerFactory(WEBSOCKET_URI_ROOT, subProtocols, false); WebSocketServerHandshaker handshaker = factory.newHandshaker(request); if(handshaker == null){ WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel()); }else{ //响应请求 handshaker.handshake(ctx.channel(), request); //将handshaker绑定给channel ctx.channel().attr(ATTR_HANDSHAKER).set(handshaker); } return; } } @Override public void handleFrame(ChannelHandlerContext ctx, WebSocketFrame frame) { // TODO Auto-generated method stub if(frame instanceof TextWebSocketFrame){ String text = ((TextWebSocketFrame) frame).text(); TextWebSocketFrame rsp = new TextWebSocketFrame(text); // System.out.println(channelMap.size()); // // for(Channel ch:channelMap.values()){ // // // if (ctx.channel().equals(ch)) { // continue; // } // ch.writeAndFlush(rsp); // } group.writeAndFlush(rsp); // } //ping 回复 pong if(frame instanceof PingWebSocketFrame){ ctx.channel().writeAndFlush(new PongWebSocketFrame(frame.content().retain())); return; } if(frame instanceof PongWebSocketFrame){ return; } if(frame instanceof CloseWebSocketFrame){ WebSocketServerHandshaker handshaker = 
ctx.channel().attr(ATTR_HANDSHAKER).get(); if(handshaker == null){ return; } handshaker.close(ctx.channel(), (CloseWebSocketFrame)frame.retain()); return; } } //1、判断是否为get 2、判断Upgrade头 包含websocket字符串 3、Connection头 包换upgrade字符串 private boolean isWebSocketUpgrade(FullHttpRequest request){ HttpHeaders headers = request.headers(); return request.method().equals(HttpMethod.GET) && headers.get(HttpHeaderNames.UPGRADE).contains(WEBSOCKET_UPGRADE) && headers.get(HttpHeaderNames.CONNECTION).contains(WEBSOCKET_CONNECTION); } public void sendMessage(String message){ TextWebSocketFrame rsp = new TextWebSocketFrame(message); for(Channel ch:channelMap.values()){ ch.writeAndFlush(rsp); } } }