手写简单Rpc服务

NioServer.java
package top.icss.rpc.nio;

import top.icss.rpc.nio.entity.RpcRequest;
import top.icss.rpc.nio.entity.RpcResponse;
import top.icss.utils.IOUtil;

import java.io.*;
import java.lang.reflect.Method;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.Iterator;
import java.util.Set;

/**
 * Minimal single-threaded, selector-driven RPC server.
 *
 * <p>One thread (see {@link #start()}) waits on a {@link Selector}. The listening
 * channel's key carries an {@link Acceptor}; every accepted connection gets a
 * {@link Handler}. A connection serves exactly one request: the handler reads a
 * Java-serialized {@link RpcRequest}, invokes the matching public method on the
 * exported service object via reflection, writes back a Java-serialized
 * {@link RpcResponse} and closes the connection.
 *
 * <p>SECURITY NOTE(review): requests are decoded with Java native deserialization
 * straight from the network. That is unsafe for untrusted peers; only expose this
 * server to trusted clients, or install a deserialization filter / switch format.
 */
public class NioServer implements Runnable {

    /** Port used by {@link #NioServer(Object)}. */
    private static final int DEFAULT_PORT = 5891;

    /** Size of the scratch buffer used for each individual channel read. */
    private static final int READ_BUFFER_SIZE = 1024;

    /** Upper bound on the bytes buffered for one request, so a peer cannot exhaust memory. */
    private static final int MAX_REQUEST_BYTES = 1 << 20;

    private final Selector selector;
    private final ServerSocketChannel server;

    /** The object whose public methods are exposed to remote callers. */
    private final Object service;

    /**
     * Creates a server for {@code service} listening on the default port (5891).
     *
     * @param service object whose public methods are invoked for incoming requests
     * @throws IOException if the listening channel or the selector cannot be set up
     */
    public NioServer(Object service) throws IOException {
        this(service, DEFAULT_PORT);
    }

    /**
     * Creates a server for {@code service} listening on {@code port}.
     *
     * @param service object whose public methods are invoked for incoming requests
     * @param port    TCP port to bind
     * @throws IOException if the listening channel or the selector cannot be set up
     */
    public NioServer(final Object service, int port) throws IOException {
        this.service = service;
        ServerSocketChannel channel = ServerSocketChannel.open();
        Selector sel = null;
        try {
            channel.bind(new InetSocketAddress(port));
            channel.configureBlocking(false);
            sel = Selector.open();
            // Listen for OP_ACCEPT and attach the acceptor that handles new connections.
            channel.register(sel, SelectionKey.OP_ACCEPT, new Acceptor());
        } catch (IOException | RuntimeException e) {
            // Do not leak the channel/selector when setup fails (e.g. port already in use).
            IOUtil.closeQuietly(sel);
            IOUtil.closeQuietly(channel);
            throw e;
        }
        this.server = channel;
        this.selector = sel;
    }

    /** Starts the event loop on a new thread. */
    public void start() {
        new Thread(this).start();
    }

    /**
     * Event loop: blocks in {@code select()}, then dispatches every ready key to the
     * {@link Runnable} attached to it. Runs until the thread is interrupted or the
     * selector fails; the selector and the listening channel are closed on exit.
     */
    @Override
    public void run() {
        try {
            System.out.println("rpc 服务启动成功..");
            while (!Thread.interrupted()) {
                selector.select();
                Set<SelectionKey> selectionKeys = selector.selectedKeys();
                Iterator<SelectionKey> iterator = selectionKeys.iterator();
                while (iterator.hasNext()) {
                    dispatch(iterator.next());
                    // Selected keys must be removed manually or they are reported again.
                    iterator.remove();
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            IOUtil.closeQuietly(selector);
            IOUtil.closeQuietly(server);
        }
    }

    /**
     * Runs the handler attached to a ready key, if any.
     *
     * @param key a key reported ready by the selector
     */
    void dispatch(SelectionKey key) {
        Runnable runnable = (Runnable) key.attachment();
        if (runnable != null) {
            runnable.run();
        }
    }

    /**
     * Accepts a new connection and creates the {@link Handler} that will serve it.
     */
    class Acceptor implements Runnable {

        @Override
        public void run() {
            SocketChannel c = null;
            try {
                c = server.accept();
                // A non-blocking accept() returns null when no connection is pending.
                if (c == null) {
                    return;
                }
                System.out.println("新的连接:" + c.getRemoteAddress());
                new Handler(selector, c);
            } catch (IOException e) {
                e.printStackTrace();
                // The connection was accepted but could not be registered; release it.
                IOUtil.closeQuietly(c);
            }
        }
    }

    /**
     * Serves one connection: reads the request, invokes the service, writes the
     * response, then closes the connection.
     */
    final class Handler implements Runnable {
        final SelectionKey key;
        final SocketChannel socketChannel;

        /** Request bytes received so far; a request may arrive over several read events. */
        private final ByteArrayOutputStream pending = new ByteArrayOutputStream();

        /**
         * Switches the channel to non-blocking mode and registers it for read events
         * with this handler as the key attachment.
         *
         * @param selector      selector driving the event loop
         * @param socketChannel the accepted connection
         * @throws IOException if the channel cannot be configured or registered
         */
        Handler(Selector selector, SocketChannel socketChannel) throws IOException {
            this.socketChannel = socketChannel;
            socketChannel.configureBlocking(false);
            key = socketChannel.register(selector, SelectionKey.OP_READ, this);
            selector.wakeup();
        }

        @Override
        public void run() {
            // Set when the connection must stay registered (nothing to do yet).
            boolean keepOpen = false;
            try {
                if (!key.isReadable()) {
                    keepOpen = true;
                    return;
                }
                boolean endOfStream = readAvailable();
                RpcRequest request = decodeRequest();
                if (request == null) {
                    if (!endOfStream) {
                        // Only part of the request has arrived; wait for the next read event.
                        keepOpen = true;
                        return;
                    }
                    throw new EOFException("connection closed before a complete request was received");
                }

                RpcResponse response = new RpcResponse();
                response.setId(request.getId());
                try {
                    response.setResult(invoke(request));
                } catch (java.lang.reflect.InvocationTargetException e) {
                    // Report the exception thrown by the service method, not the reflection wrapper.
                    Throwable cause = e.getCause();
                    response.setError(cause != null ? cause : e);
                } catch (Throwable t) {
                    response.setError(t);
                }
                write(socketChannel, response);
            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                if (!keepOpen) {
                    // One request per connection; closing the channel also cancels its key.
                    IOUtil.closeQuietly(socketChannel);
                }
            }
        }

        /**
         * Drains every byte currently readable from the channel into {@link #pending}.
         *
         * @return {@code true} if the peer has closed its side of the connection
         * @throws IOException on read failure or when the request exceeds the size limit
         */
        private boolean readAvailable() throws IOException {
            ByteBuffer buffer = ByteBuffer.allocate(READ_BUFFER_SIZE);
            int n;
            while ((n = socketChannel.read(buffer)) > 0) {
                if (pending.size() + n > MAX_REQUEST_BYTES) {
                    throw new IOException("request exceeds " + MAX_REQUEST_BYTES + " bytes");
                }
                pending.write(buffer.array(), 0, n);
                buffer.clear();
            }
            return n < 0;
        }

        /**
         * Tries to deserialize a request from the bytes buffered so far.
         *
         * @return the request, or {@code null} if the buffered bytes are not yet a complete object
         * @throws IOException            if the bytes are not a valid serialization stream
         * @throws ClassNotFoundException if a serialized class is unknown to this JVM
         */
        private RpcRequest decodeRequest() throws IOException, ClassNotFoundException {
            if (pending.size() == 0) {
                return null;
            }
            try (ObjectInputStream in =
                         new ObjectInputStream(new ByteArrayInputStream(pending.toByteArray()))) {
                return (RpcRequest) in.readObject();
            } catch (EOFException incomplete) {
                // Truncated stream: more bytes are still on their way.
                return null;
            }
        }
    }

    /**
     * Serializes {@code response} and writes all of it to the channel.
     *
     * <p>The channel is non-blocking, so a single {@code write} may send only part of
     * the payload; this loops until everything is written. NOTE(review): this spins on
     * the event-loop thread while the peer's receive window is full.
     *
     * @param socketChannel channel to write to
     * @param response      response to send
     * @throws IOException if serialization or the write fails
     */
    private void write(SocketChannel socketChannel, RpcResponse response) throws IOException {
        ByteArrayOutputStream bout = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bout)) {
            out.writeObject(response);
        }
        ByteBuffer payload = ByteBuffer.wrap(bout.toByteArray());
        while (payload.hasRemaining()) {
            if (socketChannel.write(payload) == 0) {
                Thread.yield();
            }
        }
    }

    /**
     * Invokes the requested method on the exported service via reflection.
     *
     * @param request the decoded request
     * @return whatever the service method returns
     * @throws Exception if no matching method exists or the invocation fails; an exception
     *                   thrown by the service method itself arrives wrapped in
     *                   {@code InvocationTargetException}
     */
    private Object invoke(RpcRequest request) throws Exception {
        if (service == null) {
            throw new IllegalArgumentException("service instance == null");
        }

        String className = request.getClassName();
        String methodName = request.getMethodName();
        // The JDK proxy passes null (not an empty array) for no-arg methods.
        Object[] parameters = request.getParameters() != null ? request.getParameters() : new Object[0];

        System.out.println("服务端开始调用--"+ request + ", "+ className + " ," + service.getClass().getName());

        Method method = findMethod(methodName, parameters);
        return method.invoke(service, parameters);
    }

    /**
     * Resolves the public service method to call. The exact runtime classes of the
     * arguments are tried first; if that fails (primitive, interface or supertype
     * parameters, or null arguments) the first public method with the same name whose
     * parameters accept the arguments is used.
     */
    private Method findMethod(String methodName, Object[] args) throws NoSuchMethodException {
        Class<?> serviceType = service.getClass();
        Class<?>[] exactTypes = new Class<?>[args.length];
        boolean allNonNull = true;
        for (int i = 0; i < args.length && allNonNull; i++) {
            if (args[i] == null) {
                allNonNull = false;
            } else {
                exactTypes[i] = args[i].getClass();
            }
        }
        if (allNonNull) {
            try {
                return serviceType.getMethod(methodName, exactTypes);
            } catch (NoSuchMethodException ignored) {
                // Fall through to the assignability-based search below.
            }
        }
        for (Method candidate : serviceType.getMethods()) {
            if (candidate.getName().equals(methodName) && accepts(candidate.getParameterTypes(), args)) {
                return candidate;
            }
        }
        throw new NoSuchMethodException(
                serviceType.getName() + "." + methodName + " with " + args.length + " argument(s)");
    }

    /** Tells whether a method with the given parameter types can be called with {@code args}. */
    private static boolean accepts(Class<?>[] types, Object[] args) {
        if (types.length != args.length) {
            return false;
        }
        for (int i = 0; i < types.length; i++) {
            if (args[i] == null) {
                if (types[i].isPrimitive()) {
                    return false;
                }
            } else if (!boxed(types[i]).isInstance(args[i])) {
                return false;
            }
        }
        return true;
    }

    /** Maps a primitive type to its wrapper class; other types are returned unchanged. */
    private static Class<?> boxed(Class<?> type) {
        if (!type.isPrimitive()) {
            return type;
        }
        if (type == int.class) {
            return Integer.class;
        }
        if (type == long.class) {
            return Long.class;
        }
        if (type == boolean.class) {
            return Boolean.class;
        }
        if (type == double.class) {
            return Double.class;
        }
        if (type == float.class) {
            return Float.class;
        }
        if (type == char.class) {
            return Character.class;
        }
        if (type == short.class) {
            return Short.class;
        }
        if (type == byte.class) {
            return Byte.class;
        }
        return type;
    }

}

 

NioClient.java
import top.icss.rpc.nio.entity.RpcRequest;
import top.icss.rpc.nio.entity.RpcResponse;
import top.icss.utils.IOUtil;

import java.io.*;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.util.UUID;


/**
 * RPC client: opens one blocking connection per call, sends a Java-serialized
 * {@link RpcRequest} and reads back a Java-serialized {@link RpcResponse}.
 * {@link #proxy(Class)} hides this behind a dynamic proxy of a service interface.
 */
public class NioClient {
    private final int port = 5891;
    private final String host = "127.0.0.1";

    /**
     * Performs one remote call over a fresh connection.
     *
     * @param request the call to perform
     * @return the server's response
     * @throws RuntimeException wrapping any connection, I/O or deserialization failure
     */
    public RpcResponse send(RpcRequest request) {
        // try-with-resources closes the connection on every path.
        try (SocketChannel client = SocketChannel.open(new InetSocketAddress(host, port))) {
            write(client, request);

            // The channel is in blocking mode, so the object stream can read straight
            // from the socket: it blocks until the whole response has arrived, whatever
            // its size and however it is split across TCP segments.
            ObjectInputStream inputStream = new ObjectInputStream(client.socket().getInputStream());
            return (RpcResponse) inputStream.readObject();
        } catch (Exception e) {
            throw new RuntimeException("发起远程调用异常!", e);
        }
    }

    /**
     * Serializes {@code request} and writes all of it to the channel.
     *
     * @param socketChannel connected channel to write to
     * @param request       request to send
     * @throws IOException if serialization or the write fails
     */
    private void write(SocketChannel socketChannel, RpcRequest request) throws IOException {
        ByteArrayOutputStream bout = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bout)) {
            out.writeObject(request);
        }
        ByteBuffer payload = ByteBuffer.wrap(bout.toByteArray());
        // A single write() is not guaranteed to consume the whole buffer.
        while (payload.hasRemaining()) {
            socketChannel.write(payload);
        }
    }

    /**
     * Creates a dynamic proxy for {@code interfaceClass}. Every method call on the proxy
     * is turned into an {@link RpcRequest} (fresh UUID id, interface name, method name,
     * arguments) and sent with {@link #send(RpcRequest)}.
     *
     * @param interfaceClass the service interface to proxy
     * @param <T>            the interface type
     * @return a proxy whose calls return the remote result or rethrow the remote error
     */
    @SuppressWarnings("unchecked") // The proxy implements interfaceClass by construction.
    public <T> T proxy(Class<T> interfaceClass) {
        return (T) Proxy.newProxyInstance(interfaceClass.getClassLoader(), new Class<?>[]{interfaceClass}, new InvocationHandler() {
            @Override
            public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
                RpcRequest request = new RpcRequest();
                request.setId(UUID.randomUUID().toString());
                request.setClassName(interfaceClass.getName());
                request.setMethodName(method.getName());
                request.setParameters(args);
                RpcResponse response = send(request);
                System.out.println("调用完成--" + response);
                // Surface a remote failure as an exception on the caller's side.
                if (response.isError()) {
                    throw response.getError();
                } else {
                    return response.getResult();
                }
            }
        });
    }

}

 

RpcRequest.java
/**
 * Serializable description of one remote call: which method to invoke and with
 * which arguments. Sent from the client to the server.
 */
public class RpcRequest implements java.io.Serializable {
    /** Correlation id; the server copies it into the response. */
    private String id;
    /** Name of the service interface the call was made through. */
    private String className;
    /** Name of the method to invoke. */
    private String methodName;
    /** Call arguments; may be {@code null} for a no-arg method. */
    private Object[] parameters;

    /** @return the correlation id */
    public String getId() {
        return id;
    }

    /** @param id the correlation id */
    public void setId(String id) {
        this.id = id;
    }

    /** @return the service interface name */
    public String getClassName() {
        return className;
    }

    /** @param className the service interface name */
    public void setClassName(String className) {
        this.className = className;
    }

    /** @return the name of the method to invoke */
    public String getMethodName() {
        return methodName;
    }

    /** @param methodName the name of the method to invoke */
    public void setMethodName(String methodName) {
        this.methodName = methodName;
    }

    /** @return the call arguments (the stored array itself, not a copy) */
    public Object[] getParameters() {
        return parameters;
    }

    /** @param parameters the call arguments (stored as-is, not copied) */
    public void setParameters(Object[] parameters) {
        this.parameters = parameters;
    }

    @Override
    public String toString() {
        return "RpcRequest{" +
                "id='" + id + '\'' +
                ", className='" + className + '\'' +
                ", methodName='" + methodName + '\'' +
                ", parameters=" + java.util.Arrays.toString(parameters) +
                '}';
    }
}

 

RpcResponse.java
/**
 * Serializable outcome of one remote call: either a result value or the error
 * raised while handling the request, tagged with the request's id.
 */
public class RpcResponse implements java.io.Serializable {
    /** Id of the request this response answers. */
    private String id;
    /** Failure raised while handling the request, or {@code null} on success. */
    private Throwable error;
    /** Value returned by the invoked method. */
    private Object result;

    /** @return the id of the request this response answers */
    public String getId() {
        return this.id;
    }

    /** @param id the id of the request this response answers */
    public void setId(String id) {
        this.id = id;
    }

    /** @return the value returned by the invoked method */
    public Object getResult() {
        return this.result;
    }

    /** @param result the value returned by the invoked method */
    public void setResult(Object result) {
        this.result = result;
    }

    /** @return the recorded failure, or {@code null} if none was set */
    public Throwable getError() {
        return this.error;
    }

    /** @param error the failure to report to the caller */
    public void setError(Throwable error) {
        this.error = error;
    }

    /** @return {@code true} when an error has been set on this response */
    public boolean isError() {
        return this.error != null;
    }

    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("RpcResponse{");
        text.append("id='").append(id).append('\'');
        text.append(", error=").append(error);
        text.append(", result=").append(result);
        text.append('}');
        return text.toString();
    }
}

 

IOUtil.java
/**
 * Small I/O helpers: closing resources without propagating failures, and
 * rendering byte counts as human-readable sizes.
 */
public class IOUtil {

    /** Formatter with exactly one fraction digit, used for the KB/MB/GB figures. */
    private static DecimalFormat fileSizeFormater = FormatUtil.decimalFormat(1);

    /**
     * Closes {@code o} if it is non-null. An {@link IOException} raised by
     * {@code close()} is printed to standard error instead of being rethrown.
     *
     * @param o the resource to close; may be {@code null}
     */
    public static void closeQuietly(java.io.Closeable o) {
        if (o != null) {
            try {
                o.close();
            } catch (IOException closeFailure) {
                closeFailure.printStackTrace();
            }
        }
    }

    /**
     * Formats a byte count using the largest binary unit (GB, MB, KB) in which the
     * value is at least 1, with one fraction digit; smaller values are rendered as
     * whole bytes with a {@code B} suffix.
     *
     * @param length size in bytes
     * @return the formatted size, e.g. {@code 1.5KB} or {@code 512B}
     */
    public static String getFormatFileSize(long length) {
        // Units from largest to smallest; each step down is 2^10 smaller.
        final String[] suffixes = {"GB", "MB", "KB"};
        int shift = 30;
        for (String suffix : suffixes) {
            double scaled = ((double) length) / (1 << shift);
            if (scaled >= 1) {
                return fileSizeFormater.format(scaled) + suffix;
            }
            shift -= 10;
        }
        return length + "B";
    }

    /**
     * Number-format helpers.
     */
    static class FormatUtil {
        /**
         * Builds a decimal format that always shows exactly {@code fractions}
         * fraction digits, rounding half up.
         *
         * @param fractions number of fraction digits to keep
         * @return the configured format
         */
        public static DecimalFormat decimalFormat(int fractions) {
            DecimalFormat format = new DecimalFormat("#0.0");
            format.setRoundingMode(RoundingMode.HALF_UP);
            format.setMinimumFractionDigits(fractions);
            format.setMaximumFractionDigits(fractions);
            return format;
        }
    }
}

 

测试

/**
 * Sample service contract used to exercise the RPC client and server.
 */
public interface HelloService {
    /**
     * Returns a reply for the given message.
     *
     * @param msg text supplied by the caller
     * @return the reply text; its content is defined by the implementation
     */
    String hello(String msg);
}

/**
 * {@link HelloService} implementation that answers with a fixed greeting prefix.
 */
public class HelloServiceImpl implements HelloService {
    /**
     * Prepends {@code "Hello "} to the given message.
     *
     * @param msg text to greet; {@code null} is rendered as the string {@code "null"}
     * @return the greeting
     */
    @Override
    public String hello(String msg) {
        final StringBuilder greeting = new StringBuilder("Hello ");
        greeting.append(msg);
        return greeting.toString();
    }
}

 

 

RpcServer.java
package top.icss.rpc.nio.test;

import top.icss.rpc.nio.NioServer;
import top.icss.rpc.nio.test.service.HelloService;
import top.icss.rpc.nio.test.service.HelloServiceImpl;

import java.io.IOException;

/**
 * Demo entry point that exports a {@link HelloServiceImpl} through a {@link NioServer}.
 *
 * @author cd
 * @create 2020/3/16 16:59
 * @since 1.0.0
 */
public class RpcServer {

    /**
     * Creates the server on its default port and starts its event loop thread.
     *
     * @param args unused
     * @throws IOException if the server cannot be set up
     */
    public static void main(String[] args) throws IOException {
        new NioServer(new HelloServiceImpl()).start();
    }
}

 

RpcClient.java
package top.icss.rpc.nio.test;

import top.icss.rpc.nio.NioClient;
import top.icss.rpc.nio.test.service.HelloService;

/**
 * Demo entry point that calls {@link HelloService} remotely through a {@link NioClient} proxy.
 *
 * @author cd
 * @create 2020/3/16 17:02
 * @since 1.0.0
 */
public class RpcClient {
    /**
     * Calls {@code hello} with a running counter, printing each reply and pausing
     * 500 ms between calls.
     *
     * @param args unused
     * @throws InterruptedException if the pause between calls is interrupted
     */
    public static void main(String[] args) throws InterruptedException {
        HelloService remote = new NioClient().proxy(HelloService.class);
        int round = 0;
        while (round < Integer.MAX_VALUE) {
            System.out.println(remote.hello("World " + round));
            Thread.sleep(500);
            round++;
        }
    }
}

 

posted @ 2020-04-02 16:36  不朽丶  阅读(584)  评论(0)  编辑  收藏  举报
页脚