君临-行者无界

导航

netty学习总结(三)

  Netty开发中会遇到粘包、拆包问题,因此需要制定协议来读取数据。下面给出一个自定义协议的Demo(每个数据包固定为220字节):

/**
 * Splits the inbound byte stream into fixed-size frames of {@link #PACKAGER_SIZE} bytes.
 *
 * <p>{@link ByteToMessageDecoder} already accumulates bytes across reads and passes any unread
 * leftover back in on the next call. So this decoder only consumes complete frames and leaves a
 * trailing partial frame unread in {@code byteBuf}.
 *
 * <p>The earlier hand-rolled {@code tempMsg} buffer had three problems:
 * <ul>
 *   <li>it duplicated that cumulation work;</li>
 *   <li>it logged the "leftover" length after the buffer had already been drained (always 0);</li>
 *   <li>it leaked the buffer returned by {@code message.readBytes(size)}.</li>
 * </ul>
 * Netty's built-in {@code FixedLengthFrameDecoder} implements the same framing.
 *
 * <p>Must not be {@code @Sharable}: {@link ByteToMessageDecoder} keeps per-channel state.
 */
public class XDecoder extends ByteToMessageDecoder {

    /** Size in bytes of every frame of the demo protocol. */
    static final int PACKAGER_SIZE = 220;

    /**
     * Emits one {@link ByteBuf} per complete frame currently readable in {@code byteBuf}.
     *
     * @param channelHandlerContext the handler context (unused)
     * @param byteBuf accumulated inbound bytes; bytes left unread are kept for the next call
     * @param list receives one newly allocated buffer per frame; downstream handlers must release them
     */
    @Override
    protected void decode(ChannelHandlerContext channelHandlerContext, ByteBuf byteBuf, List<Object> list) throws Exception {

        System.out.println(Thread.currentThread().getName() + "收到数据包,大小为" + byteBuf.readableBytes());

        while (byteBuf.readableBytes() >= PACKAGER_SIZE) {
            // readBytes(int) copies the frame into a new buffer, so it stays valid after the
            // cumulation buffer is discarded or compacted
            list.add(byteBuf.readBytes(PACKAGER_SIZE));
        }

        int size = byteBuf.readableBytes();
        if (size > 0) {
            // partial frame: left unread so ByteToMessageDecoder prepends it to the next read
            System.out.println("剩余长度为" + size);
        }
    }
}
public class XHandller extends ChannelInboundHandlerAdapter {

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {

        ByteBuf byteBuf = (ByteBuf) msg;
        System.out.println(msg);

        byte[] content = new byte[byteBuf.readableBytes()];
        byteBuf.readBytes(content);
        System.out.println(new String(content));
        byteBuf.release();
    }

    @Override
    public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
        ctx.flush();
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        cause.printStackTrace();
        ctx.flush();
    }
}
public class XNettyServer {

    public static void main(String[] args) throws InterruptedException {

        EventLoopGroup acceptGroup = new NioEventLoopGroup();

        EventLoopGroup readGroup = new NioEventLoopGroup();

        try {
            ServerBootstrap bootstrap = new ServerBootstrap();
            bootstrap.group(acceptGroup,readGroup);
            bootstrap.channel(NioServerSocketChannel.class);
            bootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
                @Override
                protected void initChannel(SocketChannel socketChannel) throws Exception {

                    ChannelPipeline pipeline = socketChannel.pipeline();
                    pipeline.addLast(new XDecoder());
                    pipeline.addLast(new XHandller());

                }
            });
            System.out.println("启动成功");
            bootstrap.bind(9999).sync().channel().closeFuture().sync();
        } finally {

            acceptGroup.shutdownGracefully();
            readGroup.shutdownGracefully();
        }

    }
}
/**
 * Demo client for the fixed-length protocol. Ten threads each write one 220-byte frame
 * (10-byte user id followed by text, zero padded) at the same moment, so the server is likely to
 * see several frames coalesced into one TCP read, or split across reads (sticky/half packets).
 */
public class MySocketClient {

    public static void main(String[] args) throws IOException, InterruptedException {

        byte[] request = new byte[220];
        byte[] userId = "1000000000".getBytes();
        byte[] content = "自定义协议自定义协议自定义协议".getBytes();
        System.arraycopy(userId, 0, request, 0, 10);
        // truncate rather than throw if the text does not fit into the 210-byte body
        System.arraycopy(content, 0, request, 10, Math.min(content.length, request.length - 10));

        // try-with-resources closes the socket even if a write or join fails
        try (Socket client = new Socket("127.0.0.1", 9999)) {
            OutputStream outputStream = client.getOutputStream();
            CountDownLatch countDownLatch = new CountDownLatch(1);
            Thread[] writers = new Thread[10];
            for (int i = 0; i < writers.length; i++) {
                writers[i] = new Thread(() -> {
                    try {
                        countDownLatch.await();
                        // Write one whole frame at a time. Socket stream writes are not documented
                        // as atomic across threads, and interleaved bytes would break the
                        // fixed-length framing. TCP still coalesces the frames on the wire.
                        synchronized (outputStream) {
                            outputStream.write(request);
                        }
                    } catch (IOException e) {
                        e.printStackTrace();
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                    }
                });
                writers[i].start();
            }
            // release all writers at once
            countDownLatch.countDown();
            // wait for every write to finish instead of sleeping a fixed 2 s before closing
            for (Thread writer : writers) {
                writer.join();
            }
        }
    }
}

  自研协议比较麻烦,Netty为我们封装了常用的协议,大部分情况下直接使用即可。下面是一个基于Netty实现的WebSocket服务:

  server端

package com.example.test.websocket;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.QueryStringDecoder;
import io.netty.util.AttributeKey;

import java.util.List;
import java.util.Map;

/**
 * Registers a new WebSocket connection. It reads {@code userId} from the handshake request's query
 * string, stores it on the channel and registers the channel with {@link MessageCenter}.
 * Requests without a usable {@code userId} are closed instead of failing with a NullPointerException.
 */
public class NewConnectHandler extends SimpleChannelInboundHandler<FullHttpRequest> {
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest req) throws Exception {
        // parse the request, (eventually) check the token, obtain the user id
        Map<String, List<String>> parameters = new QueryStringDecoder(req.uri()).parameters();
        // TODO: not everyone should be allowed to connect, e.g. require a push token issued after login:
        //  String token = parameters.get("token").get(0);
        List<String> userIds = parameters.get("userId");
        if (userIds == null || userIds.isEmpty() || userIds.get(0).isEmpty()) {
            // no user to bind this connection to
            ctx.close();
            return;
        }
        String userId = userIds.get(0);
        ctx.channel().attr(AttributeKey.valueOf("userId")).set(userId); // remember the user on the channel
        MessageCenter.saveConnection(userId, ctx.channel()); // register the connection
    }
}
package com.example.test.websocket;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.*;
import io.netty.handler.codec.http.websocketx.*;
import io.netty.util.AttributeKey;
import io.netty.util.CharsetUtil;

import java.util.concurrent.atomic.LongAdder;

import static io.netty.handler.codec.http.HttpMethod.GET;
import static io.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;

/**
 * Handles the WebSocket opening handshake and the frames that follow it.
 *
 * <p>Keeps the per-connection {@code handshaker} in a field, so it needs one instance per channel.
 */
public class WebSocketServerHandler extends SimpleChannelInboundHandler<Object> {

    private static final String WEBSOCKET_PATH = "/websocket";

    private WebSocketServerHandshaker handshaker;

    /** Number of inbound messages (HTTP requests and frames) seen by all instances. */
    public static final LongAdder counter = new LongAdder();

    @Override
    public void channelRead0(ChannelHandlerContext ctx, Object msg) {
        counter.add(1);
        if (msg instanceof FullHttpRequest) {
            // the WebSocket handshake
            handleHttpRequest(ctx, (FullHttpRequest) msg);
        } else if (msg instanceof WebSocketFrame) {
            // frames exchanged after the handshake
            handleWebSocketFrame(ctx, (WebSocketFrame) msg);
        }
    }

    /** Flushes the replies queued with write() while handling the current batch of reads. */
    @Override
    public void channelReadComplete(ChannelHandlerContext ctx) {
        ctx.flush();
    }

    private void handleHttpRequest(ChannelHandlerContext ctx, FullHttpRequest req) {
        // Reject requests that failed HTTP decoding, are not GET, or do not ask for a WebSocket
        // upgrade. RFC 6455 (4.2.1) requires the Upgrade value to be matched case-insensitively.
        // The previous exact "websocket".equals(...) rejected valid clients sending e.g. "WebSocket".
        if (!req.decoderResult().isSuccess() || req.method() != GET
                || !"websocket".equalsIgnoreCase(req.headers().get(HttpHeaderNames.UPGRADE))) {
            sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, BAD_REQUEST));
            return;
        }

        // build and send the handshake response
        WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
                getWebSocketLocation(req), null, true, 5 * 1024 * 1024);
        handshaker = wsFactory.newHandshaker(req);
        if (handshaker == null) {
            // unsupported WebSocket version
            WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
        } else {
            handshaker.handshake(ctx.channel(), req);
            // Pass the request on to the next handler. retain() is needed because
            // SimpleChannelInboundHandler releases req once channelRead0 returns.
            ctx.fireChannelRead(req.retain());
        }
    }

    private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) {
        // the client is closing: unregister the user, then complete the close handshake
        if (frame instanceof CloseWebSocketFrame) {
            Object userId = ctx.channel().attr(AttributeKey.valueOf("userId")).get();
            MessageCenter.removeConnection(userId);
            handshaker.close(ctx.channel(), (CloseWebSocketFrame) frame.retain());
            return;
        }
        if (frame instanceof PingWebSocketFrame) { // ping/pong serve as the heartbeat
            System.out.println("ping: " + frame);
            ctx.write(new PongWebSocketFrame(frame.content().retain()));
            return;
        }
        if (frame instanceof TextWebSocketFrame) {
            // echo the text back to the client together with a greeting and the current time
            ctx.channel().write(new TextWebSocketFrame(((TextWebSocketFrame) frame).text()
                    + ", 欢迎使用Netty WebSocket服务, 现在时刻:"
                    + new java.util.Date().toString()));

            return;
        }
        // binary frames are echoed back unchanged
        if (frame instanceof BinaryWebSocketFrame) {
            ctx.write(frame.retain());
        }
    }

    private static void sendHttpResponse(
            ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse res) {
        // for a non-200 status, use the status line as the response body
        if (res.status().code() != 200) {
            ByteBuf buf = Unpooled.copiedBuffer(res.status().toString(), CharsetUtil.UTF_8);
            res.content().writeBytes(buf);
            buf.release();
            HttpUtil.setContentLength(res, res.content().readableBytes());
        }

        // send the response; close the connection unless keep-alive was requested and the status is 200
        ChannelFuture f = ctx.channel().writeAndFlush(res);
        if (!HttpUtil.isKeepAlive(req) || res.status().code() != 200) {
            f.addListener(ChannelFutureListener.CLOSE);
        }
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        cause.printStackTrace();
        ctx.close();
    }

    /** Builds the ws:// URL advertised in the handshake from the request's Host header. */
    private static String getWebSocketLocation(FullHttpRequest req) {
        String location = req.headers().get(HttpHeaderNames.HOST) + WEBSOCKET_PATH;
        return "ws://" + location;
    }
}
package com.example.test.websocket;

import io.netty.channel.Channel;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;

import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

// In production the back-end system would push messages into an MQ queue through an API and the
// push server would consume them from there; this class is an in-memory stand-in.
public class MessageCenter {
    // Assumes one device per user; otherwise a user would map to several channels.
    // TODO: a periodic task should also probe connections that have been idle for a long time
    // (similar to LRU eviction in a cache) and drop the ones that are already closed.
    static ConcurrentHashMap<String, Channel> userInfos = new ConcurrentHashMap<String, Channel>();

    /** Switch for the stress test: when false, the test task only prints statistics. */
    private static final boolean TEST_SEND_ENABLED = true;

    /** Used by the single test scheduler thread; avoids creating a Random for every message. */
    private static final Random RANDOM = new Random();

    // Registers (or replaces) the channel of a user.
    public static void saveConnection(String userId, Channel channel) {
        userInfos.put(userId, channel);
    }

    // Removes the user's channel when the user leaves; a null id is ignored.
    public static void removeConnection(Object userId) {
        if (userId != null) {
            userInfos.remove(userId.toString());
        }
    }

    final static byte[] JUST_TEST = new byte[1024];

    /**
     * Starts a stress test. Every 2 s (first run after 1 s) it sends a 1024-byte text frame to about
     * 1/10 of the connected users, picked at random. The task runs on a daemon thread, so it no
     * longer keeps the JVM alive after the server has shut down.
     */
    public static void startTest() {
        System.arraycopy("123456".getBytes(), 0, JUST_TEST, 0, 4);
        String payload = new String(JUST_TEST);
        Executors.newScheduledThreadPool(1, runnable -> {
            Thread thread = new Thread(runnable, "message-center-test");
            thread.setDaemon(true);
            return thread;
        }).scheduleAtFixedRate(() -> {
            try {
                // Take the key snapshot once and size everything from it. userInfos.size() can
                // differ from the snapshot under concurrent connects/disconnects, which used to
                // cause ArrayIndexOutOfBoundsException or nextInt(0) failures.
                String[] keys = userInfos.keySet().toArray(new String[0]);
                int size = keys.length;
                if (size == 0) {
                    return;
                }
                System.out.println(WebSocketServerHandler.counter.sum() + " : 当前用户数量" + size);
                if (!TEST_SEND_ENABLED) {
                    return;
                }
                int batch = size > 10 ? size / 10 : size;
                for (int i = 0; i < batch; i++) {
                    String key = keys[RANDOM.nextInt(size)];
                    Channel channel = userInfos.get(key);
                    if (channel == null) {
                        continue;
                    }
                    if (!channel.isActive()) {
                        // evict only this dead channel, not a newer one the user may have reconnected with
                        userInfos.remove(key, channel);
                        continue;
                    }
                    // push 1024 bytes on the channel's own event loop
                    channel.eventLoop().execute(() -> channel.writeAndFlush(new TextWebSocketFrame(payload)));
                }
            } catch (Exception ex) {
                ex.printStackTrace();
            }
        }, 1000L, 2000L, TimeUnit.MILLISECONDS);
    }
}
package com.example.test.websocket;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;

/**
 * Starts the WebSocket push server on {@link #PORT}, then runs the random-push stress test until a
 * key is pressed on stdin.
 */
public final class WebSocketServer {

    static int PORT = 9000;

    public static void main(String[] args) throws Exception {

        EventLoopGroup bossGroup = new NioEventLoopGroup(1);
        EventLoopGroup workerGroup = new NioEventLoopGroup();
        try {
            ServerBootstrap b = new ServerBootstrap();
            b.group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .option(ChannelOption.SO_REUSEADDR, true)
                    .childHandler(new WebSocketServerInitializer())
                    .childOption(ChannelOption.SO_REUSEADDR, true);

            // Wait for the bind. A failure (e.g. port already in use) now throws instead of being
            // reported as "bound" by a listener that ignored future.isSuccess().
            ChannelFuture bindFuture = b.bind(PORT).sync();
            System.out.println("端口绑定完成:" + bindFuture.channel().localAddress());

            // the port is bound: start the random push test
            MessageCenter.startTest();

            System.in.read();
        } finally {
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
        }
    }
}

  客户端

package com.example.test.websocket;

import io.netty.channel.*;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.websocketx.*;
import io.netty.util.CharsetUtil;

import java.net.InetSocketAddress;
import java.net.URI;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Client-side handler. Once the TCP connection is up, it performs the WebSocket handshake (sending
 * a per-connection {@code userId} query parameter) and then prints the frames it receives.
 * {@link #handshakeFuture()} completes when the handshake succeeds or fails.
 */
public class WebSocketClientHandler extends SimpleChannelInboundHandler<Object> {

    private WebSocketClientHandshaker handshaker;
    private ChannelPromise handshakeFuture;

    public ChannelFuture handshakeFuture() {
        return handshakeFuture;
    }

    @Override
    public void handlerAdded(ChannelHandlerContext ctx) {
        handshakeFuture = ctx.newPromise();
    }

    /** Source of the userId sent by each new connection (1, 2, 3, ...). */
    static AtomicInteger counter = new AtomicInteger(0);

    @Override
    public void channelActive(ChannelHandlerContext ctx) {
        if (handshaker == null) {
            InetSocketAddress address = (InetSocketAddress) ctx.channel().remoteAddress();
            URI uri;
            try {
                // The multi-argument constructor adds the brackets an IPv6 literal host needs.
                // Concatenating "ws://" + host + ":" + port produced an invalid URI for IPv6.
                uri = new URI("ws", null, address.getHostString(), address.getPort(),
                        "/websocket", "userId=" + counter.incrementAndGet(), null);
            } catch (java.net.URISyntaxException e) {
                // Previously the error was only printed and the null URI caused an NPE below.
                handshakeFuture.tryFailure(e);
                ctx.close();
                return;
            }
            handshaker = WebSocketClientHandshakerFactory.newHandshaker(
                    uri, WebSocketVersion.V13, null, true, new DefaultHttpHeaders());
        }
        handshaker.handshake(ctx.channel());
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) {
        if ("true".equals(System.getProperty("netease.debug"))) {
            System.out.println("WebSocket Client disconnected!");
        }
        // a connection lost mid-handshake must not leave handshakeFuture() pending forever
        handshakeFuture.tryFailure(
                new IllegalStateException("connection closed before the WebSocket handshake completed"));
    }

    @Override
    public void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
        Channel ch = ctx.channel();
        // the first inbound message is the server's handshake response
        if (!handshaker.isHandshakeComplete()) {
            try {
                handshaker.finishHandshake(ch, (FullHttpResponse) msg);
                System.out.println("WebSocket Client connected!");
                handshakeFuture.setSuccess();
            } catch (WebSocketHandshakeException e) {
                System.out.println("WebSocket Client failed to connect");
                handshakeFuture.setFailure(e);
            }
            return;
        }

        // after the handshake only WebSocket frames are expected
        if (msg instanceof FullHttpResponse) {
            FullHttpResponse response = (FullHttpResponse) msg;
            throw new IllegalStateException(
                    "Unexpected FullHttpResponse (getStatus=" + response.status() +
                            ", content=" + response.content().toString(CharsetUtil.UTF_8) + ')');
        }

        WebSocketFrame frame = (WebSocketFrame) msg;
        if (frame instanceof TextWebSocketFrame) {
            TextWebSocketFrame textFrame = (TextWebSocketFrame) frame;
            System.out.println("WebSocket Client received message: " + textFrame.text());
        } else if (frame instanceof PongWebSocketFrame) {
            System.out.println("WebSocket Client received pong");
        } else if (frame instanceof CloseWebSocketFrame) {
            System.out.println("WebSocket Client received closing");
            ch.close();
        }
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        cause.printStackTrace();
        // no-op if the handshake already completed
        handshakeFuture.tryFailure(cause);
        ctx.close();
    }
}
package com.example.test.websocket;

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketClientCompressionHandler;

/**
 * Opens one WebSocket client connection to the local push server (127.0.0.1:9000) and keeps it
 * open until a key is pressed on stdin.
 */
public final class WebSocketClient {

    public static void main(String[] args) throws Exception {
        final String host = "127.0.0.1";
        final int port = 9000;

        EventLoopGroup group = new NioEventLoopGroup();
        try {
            Bootstrap bootstrap = new Bootstrap()
                    .group(group)
                    .channel(NioSocketChannel.class)
                    .option(ChannelOption.SO_REUSEADDR, true)
                    .handler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) {
                            // HTTP codec + aggregator for the handshake, permessage-deflate support,
                            // then the handler that drives the handshake and prints frames
                            ChannelPipeline pipeline = ch.pipeline();
                            pipeline.addLast(new HttpClientCodec());
                            pipeline.addLast(new HttpObjectAggregator(8192));
                            pipeline.addLast(WebSocketClientCompressionHandler.INSTANCE);
                            pipeline.addLast("webSocketClientHandler", new WebSocketClientHandler());
                        }
                    });

            // TCP connect; sync() rethrows the failure cause if the connection cannot be established
            bootstrap.connect(host, port).sync();

            System.in.read();
        } finally {
            group.shutdownGracefully();
        }
    }
}

  使用Netty开发时的注意事项

  1、无状态的handler可以复用,没必要为每一个channel创建独立的handler实例,这样可以节省资源。只需在对应的handler类上加@ChannelHandler.Sharable注解即可。注意:有状态的handler(例如ByteToMessageDecoder的子类,如上面的XDecoder)不能共享。

  2、耗时逻辑应放入业务线程池执行,避免阻塞IO线程。可以为XHandller指定专用的EventExecutorGroup(NioEventLoopGroup本质上也是一组线程),方式如下:

// Create this group ONCE (e.g. as a static field), not inside initChannel(). Every group owns its
// own threads, so creating one group per connection leaks threads. DefaultEventExecutorGroup is the
// executor Netty's ChannelPipeline documentation uses for handlers that block. A NioEventLoopGroup
// also works, but it opens a Selector per thread that this handler never uses.
// The size (16) follows that documentation example; adjust it to the workload.
io.netty.util.concurrent.EventExecutorGroup xhandlerGroup = new io.netty.util.concurrent.DefaultEventExecutorGroup(16);
// The handler's events now run on xhandlerGroup instead of the channel's I/O thread.
// Reusing one xHandller instance across channels requires it to be @ChannelHandler.Sharable.
pipeline.addLast(xhandlerGroup, xHandller);

  3、大数据量要分批推送,防止IO线程被单个连接长时间占用。

  4、ByteBuf对象的复用(池化)是有条件的:使用完毕后必须调用release方法释放,否则会造成内存泄漏;及时释放可以让缓冲区被复用,从而降低GC压力。

 

posted on 2020-02-14 18:19  请叫我西毒  阅读(540)  评论(0)  编辑  收藏  举报