JDK-In-Action-BIO-TcpSample

BIO

一个模拟TCP交互协议的阻塞套接字编程示例

TcpSample

模拟TCP连接状态机

package com.gitee.zhwcong.nio;

/**
 * A toy model of the TCP connection state machine, driven by string messages.
 *
 * <p>An instance plays either the client or the server role, chosen at construction time.
 * {@link #reply(String)} feeds one received message (for example {@code "[SYN]"} or
 * {@code "[FIN,ACK]"}) into the state machine and returns the message to send back, or
 * {@code null} when nothing should be sent.
 *
 * <p>This is a teaching simplification, not real TCP: the server answers {@code [FIN]} with a
 * combined {@code [FIN,ACK]} and moves straight to {@code LAST_ACK} (never entering
 * {@code CLOSE_WAIT}), and the client goes from {@code FIN_WAIT_1} directly to {@code CLOSED}
 * (never entering {@code FIN_WAIT_2} or lingering in {@code TIME_WAIT}). The class has no
 * synchronization, so an instance should be used by one thread at a time.
 */
public class TcpSample {

    /** {@code true} when this instance plays the client role, {@code false} for the server. */
    private final boolean client;
    /** Client-role state; only meaningful when {@link #client} is {@code true}. */
    private ClientState clientState = ClientState.CLOSED;
    /** Server-role state; only meaningful when {@link #client} is {@code false}. */
    private ServerState serverState = ServerState.CLOSED;

    /**
     * Returns whether both role states are {@code CLOSED}, i.e. no connection is open (or it has
     * been fully torn down).
     */
    public boolean isClosed() {
        return clientState == ClientState.CLOSED && serverState == ServerState.CLOSED;
    }

    /** Client-side states. {@code SYN_SEND} corresponds to TCP's {@code SYN_SENT}. */
    enum ClientState {
        CLOSED,
        SYN_SEND,
        ESTABLISHED,
        FIN_WAIT_1,
        FIN_WAIT_2,
        TIME_WAIT,
    }

    /** Server-side states. {@code LISTENER} corresponds to TCP's {@code LISTEN}. */
    enum ServerState {
        CLOSED,
        LISTENER,
        SYN_RCVD,
        ESTABLISHED,
        CLOSE_WAIT,
        LAST_ACK,
    }

    /**
     * Creates a state machine in the {@code CLOSED} state.
     *
     * @param client {@code true} to play the client role, {@code false} for the server role
     */
    public TcpSample(boolean client) {
        // both role states are already CLOSED via their field initializers
        this.client = client;
    }

    /** Puts the server role into the listening state so it can accept a {@code [SYN]}. */
    public void open() {
        this.serverState = ServerState.LISTENER;
    }

    /**
     * Starts the client handshake by moving to {@code SYN_SEND}.
     *
     * @return the {@code "[SYN]"} message to send to the server
     */
    public String connect() {
        this.clientState = ClientState.SYN_SEND;
        return "[SYN]";
    }

    /**
     * Starts an active close from the client side.
     *
     * <p>If the client is {@code ESTABLISHED} it moves to {@code FIN_WAIT_1}, so that the server's
     * {@code [FIN,ACK]} completes the close in {@link #reply(String)}.
     *
     * @return {@code "[FIN]"} for a client instance; {@code null} for a server instance, which
     *     never initiates the close in this model
     */
    public String close() {
        if (!client) {
            return null;
        }
        if (clientState == ClientState.ESTABLISHED) {
            // without this transition the [FIN,ACK] answer would be ignored and the client
            // would never reach CLOSED
            clientState = ClientState.FIN_WAIT_1;
        }
        return "[FIN]";
    }

    /**
     * Feeds one received message into the state machine for this instance's role.
     *
     * @param msg the received message; unknown messages and {@code null} are ignored
     * @return the message to send back, or {@code null} if nothing should be sent
     */
    public String reply(String msg) {
        return client ? replyAsClient(msg) : replyAsServer(msg);
    }

    /** Client-role transitions; returns the reply or {@code null}. */
    private String replyAsClient(String msg) {
        switch (clientState) {
            case SYN_SEND:
                if ("[SYN,ACK]".equals(msg)) {
                    clientState = ClientState.ESTABLISHED;
                    return "[ACK]";
                }
                break;
            case ESTABLISHED:
                if ("[HELLO]".equals(msg)) {
                    return "[HELLO]";
                }
                if ("[BYE]".equals(msg)) {
                    clientState = ClientState.FIN_WAIT_1;
                    return "[FIN]";
                }
                break;
            case FIN_WAIT_1:
                if ("[FIN,ACK]".equals(msg)) {
                    // simplification: the combined FIN+ACK skips FIN_WAIT_2, and TIME_WAIT is
                    // not modelled, so the client closes immediately after acknowledging
                    clientState = ClientState.CLOSED;
                    return "[ACK]";
                }
                break;
            default:
                // CLOSED, FIN_WAIT_2 and TIME_WAIT react to no message in this model
                break;
        }
        return null;
    }

    /** Server-role transitions; returns the reply or {@code null}. */
    private String replyAsServer(String msg) {
        switch (serverState) {
            case LISTENER:
                if ("[SYN]".equals(msg)) {
                    serverState = ServerState.SYN_RCVD;
                    return "[SYN,ACK]";
                }
                break;
            case SYN_RCVD:
                if ("[ACK]".equals(msg)) {
                    serverState = ServerState.ESTABLISHED;
                    return "[HELLO]";
                }
                break;
            case ESTABLISHED:
                if ("[HELLO]".equals(msg)) {
                    // stay ESTABLISHED and ask the client to end the conversation
                    return "[BYE]";
                }
                if ("[FIN]".equals(msg)) {
                    // simplification: ACK and FIN are sent together, skipping CLOSE_WAIT
                    serverState = ServerState.LAST_ACK;
                    return "[FIN,ACK]";
                }
                break;
            case LAST_ACK:
                if ("[ACK]".equals(msg)) {
                    serverState = ServerState.CLOSED;
                }
                break;
            default:
                // CLOSED and CLOSE_WAIT react to no message in this model
                break;
        }
        return null;
    }

    /**
     * Prints one exchange on a single line: {@code "<" + rcvd} and/or {@code ">" + send}, each
     * part omitted when {@code null}, followed by a newline.
     *
     * @param rcvd the received message, or {@code null}
     * @param send the sent message, or {@code null}
     */
    public static void log(String rcvd, String send) {
        if (rcvd != null) {
            System.out.print("<" + rcvd);
        }
        if (send != null) {
            System.out.print(">" + send);
        }
        System.out.println();
    }

}

TcpSampleSocket

套接字客户端

package com.gitee.zhwcong.nio;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.net.Socket;

/**
 * Blocking-socket client for the TCP sample.
 *
 * <p>Connects to {@code localhost:8080} and sends the {@code [SYN]} produced by a client-role
 * {@link TcpSample}. It then feeds every line received from the server into the state machine
 * and writes back any reply. The loop ends when the state machine is closed, when the server
 * closes the connection, or after {@code max} iterations as a safety limit.
 */
public class TcpSampleSocket {
    public static void main(String[] args) {
        String host = "localhost";
        int port = 8080;
        // upper bound on loop iterations so a misbehaving peer cannot keep the client running forever
        int max = 30;
        TcpSample tcpSample = new TcpSample(true);
        // try-with-resources closes the streams and the socket even when an I/O error occurs
        try (Socket socket = new Socket(host, port);
             PrintWriter writer = new PrintWriter(socket.getOutputStream(), true);
             BufferedReader reader = new BufferedReader(new InputStreamReader(socket.getInputStream()))) {
            String connect = tcpSample.connect();
            writer.println(connect);
            System.out.println(connect);
            int count = 1;
            while (!tcpSample.isClosed() && count < max) {
                String msg = reader.readLine();
                if (msg == null) {
                    // end of stream: the server closed the connection, no more messages will arrive
                    break;
                }
                String reply = tcpSample.reply(msg);
                if (reply != null) {
                    writer.println(reply);
                    System.out.println(msg + ">>" + reply);
                } else {
                    System.out.println(msg);
                }
                count++;
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        System.out.println("Client is Closed");
    }
}

TcpSampleMultiThreadServerSocket

套接字服务端

package com.gitee.zhwcong.nio;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * Blocking, thread-per-connection server for the TCP sample.
 *
 * <p>Listens on port 8080 and hands every accepted connection to a cached thread pool. Each
 * connection is driven by its own server-role {@link TcpSample}.
 */
public class TcpSampleMultiThreadServerSocket {
    public static void main(String[] args) {
        start();
    }

    /**
     * Runs the accept loop on the calling thread. It only returns if the server socket fails;
     * the worker pool is then shut down.
     */
    public static void start() {
        ExecutorService executor = Executors.newCachedThreadPool();
        int port = 8080;
        try (ServerSocket serverSocket = new ServerSocket(port)) {
            System.out.println("Server is Running Port:" + port);
            while (true) {
                Socket socket = serverSocket.accept();
                System.out.println("Accept New One:" + socket.getInetAddress() + ":" + socket.getPort());
                executor.submit(() -> accept(socket));
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            // stop accepting new tasks; connections already being served run to completion
            executor.shutdown();
        }
    }

    /**
     * Serves one client connection until the state machine closes, the client disconnects, or
     * {@code max} messages have been handled. The socket is always closed on return.
     */
    private static void accept(Socket socket) {
        // upper bound on handled messages so a misbehaving client cannot hold a thread forever
        int max = 30;
        TcpSample tcpSample = new TcpSample(false);
        tcpSample.open();
        try (Socket connection = socket;
             PrintWriter writer = new PrintWriter(connection.getOutputStream(), true);
             BufferedReader reader = new BufferedReader(new InputStreamReader(connection.getInputStream()))) {
            int count = 0;
            do {
                String msg = reader.readLine();
                if (msg == null) {
                    // end of stream: the client closed the connection
                    break;
                }
                String reply = tcpSample.reply(msg);
                if (reply != null) {
                    writer.println(reply);
                    System.out.println(msg + ">>" + reply);
                } else {
                    System.out.println(msg);
                }
                count++;
            } while (!tcpSample.isClosed() && count < max);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}

运行效果

  • server
Server is Running Port:8080
Accept New One:/127.0.0.1:64802
[SYN]>>[SYN,ACK]
[ACK]>>[HELLO]
[HELLO]>>[BYE]
[FIN]>>[FIN,ACK]
[ACK]
  • client
[SYN]
[SYN,ACK]>>[ACK]
[HELLO]>>[HELLO]
[BYE]>>[FIN]
[FIN,ACK]>>[ACK]
Client is Closed

改进：使用 NIO 实现客户端的示例

  • TcpSampleSocketChannel.java
package com.gitee.zhwcong.nio;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

/**
 * NIO (selector-driven, non-blocking) client for the TCP sample.
 *
 * <p>{@link #main(String[])} starts the blocking multi-threaded server on a background thread and
 * then runs this client against {@code localhost:8080}. Received bytes are split into lines, and
 * each line is fed to a client-role {@link TcpSample}. Every reply is written back terminated by
 * {@link #LINE_SPE}. Incoming lines may end in either {@code "\r\n"} or {@code "\n"}; the server
 * uses {@code PrintWriter.println}, whose separator is platform dependent.
 *
 * <p>About NIO Server See : https://github.com/jjenkov/java-nio-server
 */
public class TcpSampleSocketChannel {

    /** Line terminator appended to every message this client sends. */
    public static final String LINE_SPE = "\r\n";

    private Selector selector;
    /** Receive buffer; fully drained by {@link #byteToMessage()} after every read. */
    private ByteBuffer rcvdBuffer = ByteBuffer.allocateDirect(1024);
    /** Bytes of a line whose terminator has not arrived yet (half packet); grows on demand. */
    private ByteBuffer msgBuffer = ByteBuffer.allocate(16);

    public static void main(String[] args) throws Exception {
        // NOTE: nothing waits for the server to bind its port, so the client's connection attempt
        // may fail if it runs first. The server thread also never exits, so the JVM keeps running
        // after the client finishes.
        new Thread(() -> {
            try {
                TcpSampleMultiThreadServerSocket.start();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }).start();
        new TcpSampleSocketChannel().start();
    }

    /**
     * Connects to {@code localhost:8080} and runs the selector loop until the conversation ends
     * and the selector is closed. The channel and selector are closed on every exit path.
     */
    public void start() throws IOException {
        selector = Selector.open();
        try (SocketChannel socketChannel = SocketChannel.open()) {
            socketChannel.configureBlocking(false);
            boolean connected = socketChannel.connect(new InetSocketAddress("localhost", 8080));
            SelectionKey connectKey = socketChannel.register(selector, SelectionKey.OP_CONNECT);
            if (connected) {
                // the connection completed immediately (possible for local connections); OP_CONNECT
                // only fires while a connection is pending, so handle it here
                fireChannelConnect(connectKey);
            }

            while (selector.isOpen()) {
                // timeout is in milliseconds
                selector.select(3);
                final Iterator<SelectionKey> iterator = selector.selectedKeys().iterator();
                // a handler may close the selector; stop touching its keys once that happens
                while (selector.isOpen() && iterator.hasNext()) {
                    SelectionKey key = iterator.next();
                    iterator.remove();
                    if (!key.isValid()) {
                        // the channel was closed/cancelled; querying ready ops would throw
                        continue;
                    }
                    if (key.isConnectable()) {
                        fireChannelConnect(key);
                    } else if (key.isReadable()) {
                        fireChannelRead(key);
                    } else if (key.isWritable()) {
                        fireChannelWrite(key);
                    }
                }
            }
        } finally {
            // idempotent: a no-op if a handler already closed it
            selector.close();
        }
    }

    /**
     * Splits the bytes in {@link #rcvdBuffer} into complete lines (handles half-packet reads).
     *
     * <p>A line ends at {@code '\n'}; {@code '\r'} bytes are dropped, so both CRLF and LF
     * terminators work, even when CR and LF arrive in different reads. Bytes of an unterminated
     * line stay in {@link #msgBuffer} until a later call completes it.
     *
     * @return the complete lines, without terminators, in arrival order (possibly empty)
     */
    private List<String> byteToMessage() {
        rcvdBuffer.flip();
        List<String> message = new ArrayList<>();
        while (rcvdBuffer.hasRemaining()) {
            byte b = rcvdBuffer.get();
            if (b == '\n') {
                // complete line
                msgBuffer.flip();
                byte[] d = new byte[msgBuffer.remaining()];
                msgBuffer.get(d);
                msgBuffer.clear();
                message.add(new String(d));
            } else if (b != '\r') {
                // half-packet data: keep it until the terminator arrives
                if (!msgBuffer.hasRemaining()) {
                    growMsgBuffer();
                }
                msgBuffer.put(b);
            }
        }
        rcvdBuffer.clear();
        return message;
    }

    /** Doubles the capacity of {@link #msgBuffer}, keeping the bytes already buffered. */
    private void growMsgBuffer() {
        ByteBuffer bigger = ByteBuffer.allocate(msgBuffer.capacity() * 2);
        msgBuffer.flip();
        bigger.put(msgBuffer);
        msgBuffer = bigger;
    }

    /**
     * Reads available bytes and feeds every complete line to the state machine. Each reply is
     * written back. The channel and selector are closed once the state machine is closed or the
     * server has closed the connection.
     */
    private void fireChannelRead(SelectionKey key) throws IOException {
        final TcpSample tcpSample = (TcpSample) key.attachment();
        final SocketChannel socketChannel = (SocketChannel) key.channel();
        if (socketChannel.read(rcvdBuffer) < 0) {
            // end of stream: the server closed the connection; without this the key would stay
            // readable forever and the selector loop would spin
            socketChannel.close();
            closeSelector(key);
            return;
        }
        for (String msg : byteToMessage()) {
            if (tcpSample.isClosed()) {
                break;
            }
            String reply = tcpSample.reply(msg);
            if (reply != null) {
                writeFully(socketChannel, reply + LINE_SPE);
                TcpSample.log(msg, reply);
            } else {
                System.out.println(msg);
            }
        }
        // the key is still registered for OP_READ with this TcpSample attached; no re-register needed
        if (tcpSample.isClosed()) {
            socketChannel.close();
            closeSelector(key);
        }
    }

    /**
     * Writes all of {@code text} to the channel. It loops because a non-blocking write may be
     * partial. This busy-waits while the send buffer is full, which is acceptable for the tiny
     * messages used here.
     */
    private static void writeFully(SocketChannel channel, String text) throws IOException {
        final ByteBuffer buffer = ByteBuffer.wrap(text.getBytes());
        while (buffer.hasRemaining()) {
            channel.write(buffer);
        }
    }

    /** Closes the selector that owns {@code key}, which ends the loop in {@link #start()}. */
    private void closeSelector(SelectionKey key) throws IOException {
        key.selector().close();
    }

    /** Placeholder: this client never registers for {@code OP_WRITE}. */
    private void fireChannelWrite(SelectionKey key) throws IOException {
    }

    /**
     * Completes the connection, sends the initial {@code [SYN]} and switches the key to
     * {@code OP_READ} with a fresh client-role {@link TcpSample} attached.
     */
    private void fireChannelConnect(SelectionKey key) throws IOException {
        final SocketChannel socketChannel = (SocketChannel) key.channel();
        if (socketChannel.isConnectionPending() && !socketChannel.finishConnect()) {
            // not connected yet: stay registered for OP_CONNECT and retry on the next select
            return;
        }

        final TcpSample tcpSample = new TcpSample(true);
        final String msg = tcpSample.connect();
        TcpSample.log(null, msg);
        writeFully(socketChannel, msg + LINE_SPE);
        socketChannel.register(selector, SelectionKey.OP_READ, tcpSample);
    }
}

引用

posted @ 2020-05-06 18:06  onion94  阅读(250)  评论(0)  编辑  收藏  举报