手写一个 RPC 通信框架

 

  在使用微服务的过程中,RPC 是永远绕不开的点。之前并没有磕的很深,一直觉着 RPC 是一个黑魔法。比如我们常用的 Dubbo、SpringCloud 等框架,将微服务模块间的方法调用封装的像本地方法调用一样,方便又令人费解。

  今天如愿以偿的仿照 Dubbo 自己手写了一个 rpc “框架”,虽然简陋但终于觉着抓住了 rpc 的精髓。同时也觉着,平时修炼一些 “无用” 的内功确实是很重要的,比如操作系统、计算机原理、计算机网络等等这些硬菜。它们让我真正动手进行这次实现时并不是很吃力(想想刚入行时就妄想手写Dubbo,无疾而终好多次,感觉当时真是浮躁)。

  总结一下实现的关键部分:

  1.调度中心没有使用 zookeeper,只是通过文件进行 服务端/客户端 进程间的数据共享以方便测试,因为本次实现的重点并不在调度中心。

  2. 客户端与服务端的通信需要定义固定的数据结构,将数据结构封装在 Invocation 类中,传输该类序列化对象进行通信。

   3. 支持通过 Http 以及 自定义协议 两种方式进行远程调用。Http 虽然非常通用,但对于过程调用来说太 “重” 了,绝大部分控制位不会被使用到,会浪费传输带宽。Dubbo 便在 TCP 协议的基础上封装了 Dubbo 协议用于更加高效的进行过程调用。这里我们自定义一个“协议”,因为我们的实现更简单或者说不太完善,需要的控制位更少。报文格式如下:

  基于自定义协议通信需要基于 Socket 通信框架,Dubbo 使用的是大佬 netty,这里为了方便我自己封装了一个简单的 TCP 通信框架,详见:手写一个模块化的 TCP 服务端客户端

  4. 服务端通过 Invocation 实例中的数据对目标方法进行反射调用。

  5. 客户端通过动态代理技术提供语法糖。动态代理技术为本地接口生成一个实现类,实现类中的方法都远程调用服务端暴露的服务,让我们在调用远程方式时感觉像在调用本地方法一样。客户端调用效果:

   6. 本次实现重点是进行 RPC 调用的封装,因此许多其它细节并没有打磨。比如:

  注册治理中心简单的使用文件的方式,会并发效率非常低;

  客户端在进行调用时,只提供了一种随机选取服务端的负载均衡策略,这在实际生产场景中是最少使用到的负载均衡策略,但好在实现简单;

  客户端应当维护一个与服务端通信的长连接,因为过程调用可能会非常频繁,频繁的创建和销毁连接会浪费非常多的性能,本次实现都是以短连接的方式实现的;

  可以自定义注解来标识对外提供的服务,服务端启动时动态扫描并将这些服务注册到注册中心。这个后续会做,目前还是采用手工配置的方式;

  通信报文内容可以更加丰富,比如添加标识调用成功或失败的标识位。

  7. 总的来说,调用过程如下:

   下面是实现,首先是注册中心:

/**
 * @Author Nxy
 * @Date 2020/3/19 11:00
 * @Description 服务端通过服务名找到对应实现
 */
public interface LocalRegistry {

    /**
     * @Author Nxy
     * @Date 2020/3/19 11:03
     * @Description 注册到服务中心
     */
    void register(String interfaceName, Class interfaceImplClass);

    /**
     * @Author Nxy
     * @Date 2020/3/19 11:03
     * @Description 根据服务名称获取实现类
     */
    Class get(String interfaceName);

}
/**
 * @Author Nxy
 * @Date 2020/3/19 11:02
 * @Description 服务调用方通过服务名找到调用地址
 */
public interface WebRegistry {

    /**
     * @Author Nxy
     * @Date 2020/3/19 11:02
     * @Description 注册到治理中心
     */
    void register(String interfaceName, URL host);

    /**
     * @Author Nxy
     * @Date 2020/3/19 11:02
     * @Description 根据服务名找到调用地址,当调用地址有多个时,采用随机选取的负载均衡策略
     */
    URL getRandomURL(String interfaceName);

}

  注册中心实现类:

/**
 * @Author Nxy
 * @Date 2020/3/23 23:02
 * @Description 简单的远程注册中心
 * 注册内容写入文件,用于进程间共享
 */
public class BasicWebRegistry implements WebRegistry {
    private Map<String, List<URL>> registerMap = new HashMap<String, List<URL>>(1024);
    public static final String path = "C://tmp//rpc";
    private static BasicWebRegistry basicWebRegister;

    public static BasicWebRegistry getInstance() {
        if (basicWebRegister == null) {
            synchronized (BasicWebRegistry.class) {
                if (basicWebRegister == null) {
                    basicWebRegister = new BasicWebRegistry();
                }
            }
        }
        return basicWebRegister;
    }

    @Override
    public void register(String interfaceName, URL host) {
        if (registerMap.containsKey(interfaceName)) {
            List<URL> list = registerMap.get(interfaceName);
            list.add(host);
        } else {
            List<URL> list = new LinkedList<URL>();
            list.add(host);
            registerMap.put(interfaceName, list);
        }
        try {
            saveFile(path, registerMap);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    @Override
    public URL getRandomURL(String interfaceName) {
        int i = 0;
        //尝试 5 次读取注册文件
        try {
            while (i < 5) {
                return getRandomURLOnce(interfaceName);
            }
        } catch (IOException e) {
            e.printStackTrace();
            i++;
        }
        return null;
    }


    public URL getRandomURLOnce(String interfaceName) throws IOException {
        try {
            registerMap = (Map<String, List<URL>>) readFile(path);
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        }
        List<URL> list = registerMap.get(interfaceName);
        Random random = new Random();
        int i = random.nextInt(list.size());
        return list.get(i);
    }

    /**
     * 写入文件
     *
     * @param path
     * @param object
     * @throws IOException
     */
    private void saveFile(String path, Object object) throws IOException {
        FileOutputStream fileOutputStream = new FileOutputStream(new File(path));
        ObjectOutputStream objectOutputStream = new ObjectOutputStream(fileOutputStream);
        objectOutputStream.writeObject(object);
    }

    /**
     * 从文件中读取
     *
     * @param path
     * @return
     * @throws IOException
     * @throws ClassNotFoundException
     */
    private Object readFile(String path) throws IOException, ClassNotFoundException {
        FileInputStream fileInputStream = new FileInputStream(new File(path));
        ObjectInputStream inputStream = new ObjectInputStream(fileInputStream);
        return inputStream.readObject();
    }

    private Object readResolve() {
        return getInstance();
    }
}
/**
 * @Author Nxy
 * @Date 2020/3/23 23:03
 * @Description 本地注册中心,常驻内存,存储接口与实现类的对应关系
 */
public class BasicLocalRegistry implements LocalRegistry {
    private Map<String, Class> registerMap = new HashMap<String, Class>(1024);

    private static BasicLocalRegistry basicLocalRegister;

    @Override
    public void register(String interfaceName, Class interfaceImplClass) {
        registerMap.put(interfaceName, interfaceImplClass);
    }

    @Override
    public Class get(String interfaceName) {
        return registerMap.get(interfaceName);
    }

    public static BasicLocalRegistry getInstance() {
        if (basicLocalRegister == null) {
            synchronized (BasicLocalRegistry.class) {
                if (basicLocalRegister == null) {
                    basicLocalRegister = new BasicLocalRegistry();
                }
            }
        }
        return basicLocalRegister;
    }

    private Object readResolve() {
        return getInstance();
    }
}

  通信数据封装 Invocation:

/**
 * @Author Nxy
 * @Date 2020/3/19 10:57
 * @Description rpc 远程调用参数封装
 */
public class Invocation implements Serializable {
    private static final long serialVersionUID = 75929334234892747L;
    //远程调用接口名称
    private String interfaceName;
    //远程调用方法名称
    private String methodName;
    //方法参数类型列表
    private Class[] paramtypes;
    //方法参数列表
    private Object[] objects;

    /**
     * @param interfaceName 接口名字
     * @param methodName    方法名字
     * @param paramtypes    参数类型列表
     * @param objects       参数列表
     */
    public Invocation(String interfaceName, String methodName, Class[] paramtypes, Object[] objects) {
        this.interfaceName = interfaceName;
        this.methodName = methodName;
        this.paramtypes = paramtypes;
        this.objects = objects;
    }

    public String getInterfaceName() {
        return interfaceName;
    }

    public void setInterfaceName(String interfaceName) {
        this.interfaceName = interfaceName;
    }

    public String getMethodName() {
        return methodName;
    }

    public void setMethodName(String methodName) {
        this.methodName = methodName;
    }

    public Class[] getParamtypes() {
        return paramtypes;
    }

    public void setParamtypes(Class[] paramtypes) {
        this.paramtypes = paramtypes;
    }

    public Object[] getObjects() {
        return objects;
    }

    public void setObjects(Object[] objects) {
        this.objects = objects;
    }
}

  服务端:

public interface RpcServer {
    public void start();
}

  服务端 TCP 实现:

public class TcpServer implements RpcServer {
    private final int port;

    public TcpServer(int port) {
        this.port = port;
    }

    @Override
    public void start() {
        Server server = ServerFactory.getServer(port, SocketType.BIO, DelimitType.LengthFlag, DelimitType.LengthFlag,
                new BasicTcpHandler());
        server.start();
    }
}
public class BasicTcpHandler implements BiFunction {

    @Override
    public Object apply(Object o, Object o2) {
        byte[] bytes = (byte[]) o2;
        Invocation invocation = null;
        String exception = null;
        try {
            ByteArrayInputStream bis = new ByteArrayInputStream(bytes);
            ObjectInputStream ois = new ObjectInputStream(bis);
            invocation = (Invocation) ois.readObject();
            ois.close();
            bis.close();

            String interfaceName = invocation.getInterfaceName();
            String methodName = invocation.getMethodName();
            System.out.println("客户端调用:" + interfaceName + " : " + methodName);
            //从注册中心里面拿到接口的实现类
            Class interfaceImplClass = BasicLocalRegistry.getInstance().get(interfaceName);
            //获取类的方法
            Method method = interfaceImplClass.getMethod(invocation.getMethodName(), invocation.getParamtypes());
            //反射调用方法
            String result = (String) method.invoke(interfaceImplClass.newInstance(), invocation.getObjects());
            return result;
        } catch (IOException ex) {
            ex.printStackTrace();
            exception = ex.getMessage();
        } catch (ClassNotFoundException ex) {
            ex.printStackTrace();
        } catch (InstantiationException ie) {
            ie.printStackTrace();
        } catch (Exception illE) {
            illE.printStackTrace();
        }
        return "500";
    }
}

  服务端Http实现:

public class HttpServer implements RpcServer {
    private final int port;

    public HttpServer(int port) {
        this.port = port;
    }

    @Override
    public void start() {
        Tomcat tomcat = new Tomcat();
        Server server = tomcat.getServer();
        Service service = server.findService("tomcat");
        Connector connector = new Connector();
        connector.setPort(port);
        Engine engine = new StandardEngine();
        engine.setDefaultHost("locahost");

        Host host = new StandardHost();
        host.setName("locahost");
        String contextPath = "";
        Context context = new StandardContext();
        context.setPath(contextPath);
        context.addLifecycleListener(new Tomcat.FixContextListener());

        host.addChild(context);
        engine.addChild(host);

        service.setContainer(engine);
        service.addConnector(connector);

        tomcat.addServlet(contextPath, "dispather", new DispatcherServlet());
        context.addServletMapping("/*", "dispather");
        try {
            //启动tomcat
            tomcat.start();
            tomcat.getServer().await();
        } catch (LifecycleException e) {
            e.printStackTrace();
        }
    }
}
public class BasicHttpHandler {
    public void handler(HttpServletRequest req, HttpServletResponse resp) {
        // 获取对象
        try {
            //从流里面获取数据
            InputStream is = req.getInputStream();
            ObjectInputStream objectInputStream = new ObjectInputStream(is);
            //从流中读取数据反序列话成实体类。
            Invocation invocation = (Invocation) objectInputStream.readObject();
            //拿到服务的名字
            String interfaceName = invocation.getInterfaceName();
            //从注册中心里面拿到接口的实现类
            Class interfaceImplClass = BasicLocalRegistry.getInstance().get(interfaceName);
            //获取类的方法
            Method method = interfaceImplClass.getMethod(invocation.getMethodName(), invocation.getParamtypes());
            //反射调用方法
            String result = (String) method.invoke(interfaceImplClass.newInstance(), invocation.getObjects());
            //把结果返回给调用者
            IOUtils.write(result, resp.getOutputStream());
        } catch (IOException e) {
            e.printStackTrace();
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        } catch (NoSuchMethodException e) {
            e.printStackTrace();
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        } catch (InstantiationException e) {
            e.printStackTrace();
        } catch (InvocationTargetException e) {
            e.printStackTrace();
        }
    }
}
public class DispatchServlet extends HttpServlet {
    @Override
    protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        //把所有的请求交给HttpHandler接口处理
        new BasicHttpHandler().handler(req, resp);
    }
}

  服务端工厂:

public class ServerFactory {
    public static RpcServer getServer(int port, ServerType serverType) {
        RpcServer server = null;
        switch (serverType) {
            case HTTP:
                server = new HttpServer(port);
            case TCP:
                server = new TcpServer(port);
        }
        return server;
    }

    public enum ServerType {
        TCP, HTTP
    }
}

  客户端:

/**
 * @Author Nxy
 * @Date 2020/3/23 21:10
 * @Description 进行远程方法调用
 */
public interface InvokeRegister<T> {
    public T invoke(URL url, Invocation invocation);
}

  Http实现:

/**
 * @Author Nxy
 * @Date 2020/3/23 21:11
 * @Description 通过 http 协议进行远程调用
 */
public class HttpInvokeRegister<T> implements InvokeRegister<T> {
    @Override
    public T invoke(URL url, Invocation invocation) {
        try {
            T re = (T) HttpUtil.httpPostSerialObject("http://" + url.getHost() + ":" + url.getPort(),
                    1000, 1000, invocation).toString();
            return re;
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }
}

  Tcp实现:

/**
 * @Author Nxy
 * @Date 2020/3/23 21:11
 * @Description 通过 TCP 进行远程调用
 */
public class TcpInvokeRegister<T> implements InvokeRegister<T> {
    @Override
    public T invoke(URL url, Invocation invocation) {
        Client client =
                ClientFactory.getClient(url.getHost(), url.getPort(), SocketType.BIO,
                        DelimitType.LengthFlag, DelimitType.LengthFlag);
        byte[] re = client.send(invocation);
        try {
            ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(re));
            T o = (T) in.readObject();
            return o;
        } catch (IOException e) {
            e.printStackTrace();
        } catch (ClassNotFoundException ce) {
            ce.printStackTrace();
        }
        return null;
    }
}

  客户端工厂:

/**
 * @Author Nxy
 * @Date 2020/3/23 21:09
 * @Description 通过 JDK 动态代理获取远程调用方法的接口实例
 */
public class ProxyFactory<T, R> {
    final InvokeRegister<R> invokeRegister;

    public ProxyFactory(InvokeType invokeType) {
        switch (invokeType) {
            case TCP:
                invokeRegister = new TcpInvokeRegister<R>();
                break;
            case HTTP:
                invokeRegister = new HttpInvokeRegister<R>();
                break;
            default:
                invokeRegister = null;
        }
    }

    /**
     * @Author Nxy
     * @Date 2020/3/23 21:14
     * @Param interfaceClass:调用接口,invokeType:远程调用方式
     * @Return
     * @Exception
     * @Description
     */
    public T getProxy(final Class interfaceClass) {
        return (T) Proxy.newProxyInstance(interfaceClass.getClassLoader(), new Class[]{interfaceClass},
                (Object proxy, Method method, Object[] args) -> {
                    Invocation invocation = new Invocation(interfaceClass.getName(), method.getName(), method.getParameterTypes(), args);
                    WebRegistry remoteRegister = BasicWebRegistry.getInstance();
                    URL randomURL = remoteRegister.getRandomURL(interfaceClass.getName());
                    System.out.println("调用地址host:" + randomURL.getHost() + ",port:" + randomURL.getPort());
                    return invokeRegister.invoke(randomURL, invocation);
                }
        );
    }

    public enum InvokeType {
        HTTP, TCP
    }
}

  工具类:

public class HttpUtil {
    public static Object httpPostSerialObject(String requestUrl, int connTimeoutMills,
                                              int readTimeoutMills, Object serializedObject) throws Exception {
        HttpURLConnection httpUrlConn = null;
        InputStream inputStream = null;
        InputStreamReader inputStreamReader = null;
        BufferedReader bufferedReader = null;
        ObjectOutputStream oos = null;
        StringBuffer buffer = new StringBuffer();
        try {
            URL url = new URL(requestUrl);
            httpUrlConn = (HttpURLConnection) url.openConnection();
            // 设置content_type=SERIALIZED_OBJECT
            // 如果不设此项,在传送序列化对象时,当WEB服务默认的不是这种类型时可能抛java.io.EOFException
            httpUrlConn.setRequestProperty("Content-Type", "application/x-java-serialized-object");
            httpUrlConn.setConnectTimeout(connTimeoutMills);
            httpUrlConn.setReadTimeout(readTimeoutMills);
            // 设置是否向httpUrlConn输出,因为是post请求,参数要放在http正文内,因此需要设为true, 默认情况下是false
            httpUrlConn.setDoOutput(true);
            // 设置是否从httpUrlConn读入,默认情况下是true
            httpUrlConn.setDoInput(true);
            // 不使用缓存
            httpUrlConn.setUseCaches(false);

            // 设置请求方式,默认是GET
            httpUrlConn.setRequestMethod("POST");
            httpUrlConn.connect();

            if (serializedObject != null) {
                // 此处getOutputStream会隐含的进行connect,即:如同调用上面的connect()方法,
                // 所以在开发中不调用上述的connect()也可以,不过建议最好显式调用
                // write object(impl Serializable) using ObjectOutputStream
                oos = new ObjectOutputStream(httpUrlConn.getOutputStream());
                oos.writeObject(serializedObject);
                oos.flush();
                // outputStream不是一个网络流,充其量是个字符串流,往里面写入的东西不会立即发送到网络,
                // 而是存在于内存缓冲区中,待outputStream流关闭时,根据输入的内容生成http正文。所以这里的close是必须的
                oos.close();
            }
            // 将返回的输入流转换成字符串
            // 无论是post还是get,http请求实际上直到HttpURLConnection的getInputStream()这个函数里面才正式发送出去
            inputStream = httpUrlConn.getInputStream();//注意,实际发送请求的代码段就在这里
            inputStreamReader = new InputStreamReader(inputStream, "UTF-8");
            bufferedReader = new BufferedReader(inputStreamReader);
            String str = null;
            while ((str = bufferedReader.readLine()) != null) {
                buffer.append(str);
            }
        } catch (Exception e) {
            throw e;
        } finally {
            try {
                IOUtils.closeQuietly(bufferedReader);
                IOUtils.closeQuietly(inputStreamReader);
                IOUtils.closeQuietly(inputStream);
                IOUtils.closeQuietly(oos);
                if (httpUrlConn != null) {
                    httpUrlConn.disconnect();
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
        return buffer.toString();
    }

}

  服务端 Demo :

public class Provider {
    public static void main(String[] args) {
        RpcServer sever = ServerFactory.getServer(80, ServerFactory.ServerType.TCP);
        //获取 远程调度中心 和 本地调度中心 实例
        WebRegistry webRegistry = BasicWebRegistry.getInstance();
        LocalRegistry localRegistry = BasicLocalRegistry.getInstance();
        try {
            //远程注册中心注册本地提供的接口
            webRegistry.register(MyRpcService.class.getName(), UrlUtil.getLocalUrl());
            //本地缓存注册接口对应的实现类
            localRegistry.register(MyRpcService.class.getName(), MyFirstService.class);
        } catch (Exception e) {
            System.out.println("向注册中心注册服务期间发生异常: " + e.getMessage());
            return;
        }
        sever.start();
    }
}

  客户端 Demo :

public class Consumer {
    public static void main(String[] args) {
        MyRpcService helloService =
                new ProxyFactory<MyRpcService, String>(ProxyFactory.InvokeType.TCP).getProxy(MyRpcService.class);
        String result = helloService.sayHellow("liuy");
        System.out.println(result);
    }
}

  调用结果,客户端成功调用获得结果:

posted @ 2020-03-24 11:27  牛有肉  阅读(553)  评论(0编辑  收藏  举报