对 NIO 底层的一些认识及报文定界实现
在 NIO 中,我们读取 channel 中的数据,会通过 channel 的 read 尽最大努力将 Buffer 填满,填满后做一些其它处理。
对于 TCP 协议来说,这种做法无可厚非,因为 TCP 协议本身就不提供定界策略,只负责提供可靠的连接,也就是数据可靠的收发( 以 ack 应答机制为核心)。
但是报文总是需要被分割,才能被正确的解析。没有经过定界的报文会造成半包/粘包问题,因此我们需要在 TCP 协议的上层自己实现定界策略。
对于 BIO 的定界在前面的博客中已经给出过实现:手写一个模块化的 TCP 服务端客户端 。这次我们实现一下 NIO 的报文定界。
NIO 与 BIO 本质的不同在于:BIO 由我们编写代码监视链接的输入输出缓冲区,也就是监视数据的到达和发送,由 OS 完成数据从内核空间到网卡的传输。
而对于 NIO 来说,OS 更进一步,帮我们把监视输入输出缓冲区的活儿也干了。我们只需要编写数据处理的代码,而不需要主动的等待数据的就绪,比如:
socket.getInputStream.read(bytes);
这种阻塞等待的操作,在 NIO 中就没有用到的必要。
NIO 底层对应的是 OS 的 select/poll/epoll 等函数。我们向 OS 注册感兴趣的事件,OS 在事件发生时通知我们处理事件,也就是所谓的事件驱动方式。
从某种意义上说,OS 本身便是事件驱动的,因为其不得不处理许多意料外的情况。比如键盘/鼠标/网卡等外设的输入,OS 不知道这些事件会在什么时间发生,但事件一旦发生,OS 必须做出响应。
为了达到这一目的,处理器提供了中断机制。外设向处理器的中断引脚发送中断信号,OS 立即切换上下文并调用 OS 注册的中断处理函数来处理这些信号。
打个比方就是,OS 在启动时把一大堆中断处理函数注册到中断向量表,这就像一个反应堆。一旦有中断信号到达处理器,立即“点燃”反应堆中对应的函数,发生反应。
OS 的事件驱动是借助硬件层提供的“中断”机制实现的, NIO 的事件驱动也一样。网卡数据到达发送中断信号,OS 通过中断处理函数将其封装为事件通知用户程序(边缘触发的一般实现)。或者采用更加低效的方式,OS 主动循环检查连接在内核空间中对应的缓冲区,发现有事件便通知用户程序(水平触发的一般实现)。
对于边缘触发,数据缓冲区状态改变时,也就是有数据到达时(中断发生时,个人理解),才会触发事件。这就要求每次在事件发生时,我们必须处理所有触发该事件的数据。
对于水平触发,只要数据缓冲区中的数据没有被处理完成,OS 就会不断的触发事件。这样我们可以相对灵活的处理以就绪的数据,比如客户端发送了 512K 数据,我们可以在一次事件处理中只处理 100K 数据。因为 这 512K 数据只要没有被处理完,OS 就会不断的通知我们有事件发生,我们在接下来的事件处理中处理剩余数据即可。
select 与 poll 均为水平触发模式,epoll 支持边缘触发与水平触发两种模式,JAVA 的 NIO 仅支持水平触发模式。为何这样设计暂不考虑,猜测是基于平台一致性或代码编写难度方面的考虑。
水平触发的情况下,我们对数据的处理相对灵活。我们可以准备一个缓冲区存放定长数据,缓冲区填满便为接收到一条完整的报文,把数据交给数据处理函数。当 channel 中数据出现粘包情况时,不管 channel 还可以读取到多少数据(内核空间该连接的接收缓冲区还存在多少数据),我们将缓冲区填满后便不再读取,剩下的数据放到下个事件处理周期中去处理。
下面以通过固定长度的包头标识报文内容长度的定界策略为例,进行实现。
接收策略:
public interface ReciveRegister { public void doRecive(SocketChannel socketChannel) throws Exception; }
/** * @Author Niuxy * @Date 2020/5/28 8:36 下午 * @Description 报文头标示报文长度的定界策略 */ public class HLRegisterImpl implements ReciveRegister { //报文头长度 private int headLength = 0; //报文内容长度 private int messageLength = 0; private MessageHandler messageHandler; boolean isCache = false; ByteBuffer headCacheBuffer; ByteBuffer messageCacheBuffer; public HLRegisterImpl(int headLength, MessageHandler messageHandler) { this.messageHandler = messageHandler; this.headLength = headLength; headCacheBuffer = ByteBuffer.allocate(headLength); } @Override public void doRecive(SocketChannel socketChannel) throws Exception { //判断是否已读取报文头 if (messageLength == 0) { int readLen = socketChannel.read(headCacheBuffer); if (Util.isFullBuffer(headCacheBuffer)) { headCacheBuffer.flip(); messageLength = headCacheBuffer.getInt(); messageCacheBuffer = ByteBuffer.allocate(messageLength); headCacheBuffer.clear(); } } else { int readLen = socketChannel.read(messageCacheBuffer); if (Util.isFullBuffer(messageCacheBuffer)) { messageHandler.doHandler(socketChannel, messageCacheBuffer); messageLength = 0; headLength = 0; messageCacheBuffer = null; System.gc(); } } } }
数据处理器:
public interface MessageHandler { public void doHandler(SocketChannel socketChannel,ByteBuffer messageBuffer) throws Exception; }
public class PrintMessageHandlerImpl implements MessageHandler { String target = "hellow server!hellow server!hellow server!hellow server!hellow server!hellow server!hellow server!hellowhellow server!hellow server!hellow server!hellow server!hellow server!hellow server!hellow"; @Override public void doHandler(SocketChannel socketChannel, ByteBuffer messageBuffer) throws IOException { String message = new String(messageBuffer.array()); // String message=Util.bufferToString(messageBuffer); if (!target.equals(message)) { System.out.println("error!: " + message); } else { System.out.println("success!"); } messageBuffer = null; } }
server端实现:
public class NioSever { private int port; private ReciveRegisterType reciveRegisterType; public NioSever(int port, ReciveRegisterType reciveRegisterType) { this.port = port; this.reciveRegisterType = reciveRegisterType; } public void start() throws IOException { if (port == 0) { throw new NullPointerException("缺少启动参数!"); } Selector selector = Selector.open(); ServerSocketChannel severChannel = ServerSocketChannel.open(); severChannel.configureBlocking(false); severChannel.bind(new InetSocketAddress(port)); System.out.println("Server start!"); severChannel.register(selector, SelectionKey.OP_ACCEPT); //select会阻塞,知道有就绪连接写入selectionKeys while (!Thread.currentThread().isInterrupted()) { if (selector.select(100) == 0) { continue; } Iterator<SelectionKey> keys = selector.selectedKeys().iterator(); while (keys.hasNext()) { //SelectionKey为select中记录的就绪请求的数据结构,其中包括了连接所属的socket及就绪的类型 SelectionKey key = keys.next(); //处理事件,不管是否可以处理完成,都删除 key。因为 soketChannel 为水平触发的, // 未处理完成的事件删除后会被再次通知 keys.remove(); if (key.isAcceptable()) { SocketChannel socketChannel = severChannel.accept(); System.out.println("与 client:" + socketChannel.getRemoteAddress() + " 建立连接!"); socketChannel.configureBlocking(false); SelectionKey readKey = socketChannel.register(selector, SelectionKey.OP_READ); readKey.attach(getRegister()); } else if (key.isReadable()) { SocketChannel socketChannel = (SocketChannel) key.channel(); try { ReciveRegister rec = (ReciveRegister) key.attachment(); rec.doRecive(socketChannel); } catch (Exception e) { e.printStackTrace(); } } } } } //构建报文接收策略 private ReciveRegister getRegister() { if (this.reciveRegisterType == ReciveRegisterType.HL) { return new HLRegisterImpl(4, new PrintMessageHandlerImpl()); } //to-do FL and other type return null; } /** * @Author Niuxy * @Date 2020/5/29 3:47 下午 * @Description 内部枚举,报文接受方式 */ public enum ReciveRegisterType { //报文头标识长度 HL, //固定长度 FL } }
启动 server:
public class ServerDemo { public static void main(String[] args) throws IOException { NioSever server = new NioSever(8000, NioSever.ReciveRegisterType.HL); server.start(); } }
写一个客户端进行测试,测试单个连接多个数据包以及多个连接并发的情况下,是否可以正确的定界:
public class TestClient { public static void main(String[] args) throws Exception { final String msg = "hellow server!hellow server!hellow server!hellow server!hellow server!hellow server!hellow server!hellowhellow server!hellow server!hellow server!hellow server!hellow server!hellow server!hellow"; Thread thread0 = new Thread(() -> { sendMsg(msg); }); Thread thread1 = new Thread(() -> { sendMsg(msg); }); Thread thread2 = new Thread(() -> { sendMsg(msg); }); Thread thread3 = new Thread(() -> { sendMsg(msg); }); thread0.start(); thread1.start(); thread2.start(); thread3.start(); } private static void sendMsg(String msg) { try {
for(int i=0;i<10;i++){ send(msg);
} } catch (Exception e) { e.printStackTrace(); } } public static void send(String message) throws Exception { Socket socket = new Socket("127.0.0.1", 8000); byte[] messageBytes = message.getBytes(); Integer length = messageBytes.length; System.out.println(length); OutputStream outputStream = socket.getOutputStream(); outputStream.write(ByteBuffer.allocate(4).putInt(length).array()); Thread.sleep(100); outputStream.write(messageBytes); outputStream.flush(); outputStream.close(); } }
测试结果,全部正确定界: