手写一个 RPC 通信框架
在使用微服务的过程中,RPC 是永远绕不开的点。之前并没有磕的很深,一直觉着 RPC 是一个黑魔法。比如我们常用的 Dubbo、SpringCloud 等框架,将微服务模块间的方法调用封装的像本地方法调用一样,方便又令人费解。
今天如愿以偿的仿照 Dubbo 自己手写了一个 rpc “框架”,虽然简陋但终于觉着抓住了 rpc 的精髓。同时也觉着,平时修炼一些 “无用” 的内功确实是很重要的,比如操作系统、计算机原理、计算机网络等等这些硬菜。它们让我真正动手进行这次实现时并不是很吃力(想想刚入行时就妄想手写Dubbo,无疾而终好多次,感觉当时真是浮躁)。
总结一下实现的关键部分:
1.调度中心没有使用 zookeeper,只是通过文件进行 服务端/客户端 进程间的数据共享以方便测试,因为本次实现的重点并不在调度中心。
2. 客户端与服务端的通信需要定义固定的数据结构,将数据结构封装在 Invocation 类中,传输该类序列化对象进行通信。
3. 支持通过 Http 以及 自定义协议 两种方式进行远程调用。Http 虽然非常通用,但对于过程调用来说太 “重” 了,绝大部分控制位不会被使用到,会浪费传输带宽。Dubbo 便在 TCP 协议的基础上封装了 Dubbo 协议用于更加高效的进行过程调用。这里我们自定义一个“协议”,因为我们的实现更简单或者说不太完善,需要的控制位更少。报文格式如下:
基于自定义协议通信需要基于 Socket 通信框架,Dubbo 使用的是大佬 netty,这里为了方便我自己封装了一个简单的 TCP 通信框架,详见:手写一个模块化的 TCP 服务端客户端。
4. 服务端通过 Invocation 实例中的数据对目标方法进行反射调用。
5. 客户端通过动态代理技术提供语法糖。动态代理技术为本地接口生成一个实现类,实现类中的方法都远程调用服务端暴露的服务,让我们在调用远程方式时感觉像在调用本地方法一样。客户端调用效果:
6. 本次实现重点是进行 RPC 调用的封装,因此许多其它细节并没有打磨。比如:
注册治理中心简单的使用文件的方式,会并发效率非常低;
客户端在进行调用时,只提供了一种随机选取服务端的负载均衡策略,这在实际生产场景中是最少使用到的负载均衡策略,但好在实现简单;
客户端应当维护一个与服务端通信的长连接,因为过程调用可能会非常频繁,频繁的创建和销毁连接会浪费非常多的性能,本次实现都是以短连接的方式实现的;
可以自定义注解来标识对外提供的服务,服务端启动时动态扫描并将这些服务注册到注册中心。这个后续会做,目前还是采用手工配置的方式;
通信报文内容可以更加丰富,比如添加标识调用成功或失败的标识位。
7. 总的来说,调用过程如下:
下面是实现,首先是注册中心:
/** * @Author Nxy * @Date 2020/3/19 11:00 * @Description 服务端通过服务名找到对应实现 */ public interface LocalRegistry { /** * @Author Nxy * @Date 2020/3/19 11:03 * @Description 注册到服务中心 */ void register(String interfaceName, Class interfaceImplClass); /** * @Author Nxy * @Date 2020/3/19 11:03 * @Description 根据服务名称获取实现类 */ Class get(String interfaceName); }
/** * @Author Nxy * @Date 2020/3/19 11:02 * @Description 服务调用方通过服务名找到调用地址 */ public interface WebRegistry { /** * @Author Nxy * @Date 2020/3/19 11:02 * @Description 注册到治理中心 */ void register(String interfaceName, URL host); /** * @Author Nxy * @Date 2020/3/19 11:02 * @Description 根据服务名找到调用地址,当调用地址有多个时,采用随机选取的负载均衡策略 */ URL getRandomURL(String interfaceName); }
注册中心实现类:
/** * @Author Nxy * @Date 2020/3/23 23:02 * @Description 简单的远程注册中心 * 注册内容写入文件,用于进程间共享 */ public class BasicWebRegistry implements WebRegistry { private Map<String, List<URL>> registerMap = new HashMap<String, List<URL>>(1024); public static final String path = "C://tmp//rpc"; private static BasicWebRegistry basicWebRegister; public static BasicWebRegistry getInstance() { if (basicWebRegister == null) { synchronized (BasicWebRegistry.class) { if (basicWebRegister == null) { basicWebRegister = new BasicWebRegistry(); } } } return basicWebRegister; } @Override public void register(String interfaceName, URL host) { if (registerMap.containsKey(interfaceName)) { List<URL> list = registerMap.get(interfaceName); list.add(host); } else { List<URL> list = new LinkedList<URL>(); list.add(host); registerMap.put(interfaceName, list); } try { saveFile(path, registerMap); } catch (IOException e) { e.printStackTrace(); } } @Override public URL getRandomURL(String interfaceName) { int i = 0; //尝试 5 次读取注册文件 try { while (i < 5) { return getRandomURLOnce(interfaceName); } } catch (IOException e) { e.printStackTrace(); i++; } return null; } public URL getRandomURLOnce(String interfaceName) throws IOException { try { registerMap = (Map<String, List<URL>>) readFile(path); } catch (ClassNotFoundException e) { e.printStackTrace(); } List<URL> list = registerMap.get(interfaceName); Random random = new Random(); int i = random.nextInt(list.size()); return list.get(i); } /** * 写入文件 * * @param path * @param object * @throws IOException */ private void saveFile(String path, Object object) throws IOException { FileOutputStream fileOutputStream = new FileOutputStream(new File(path)); ObjectOutputStream objectOutputStream = new ObjectOutputStream(fileOutputStream); objectOutputStream.writeObject(object); } /** * 从文件中读取 * * @param path * @return * @throws IOException * @throws ClassNotFoundException */ private Object readFile(String path) throws IOException, ClassNotFoundException { FileInputStream fileInputStream = new FileInputStream(new File(path)); ObjectInputStream inputStream = new ObjectInputStream(fileInputStream); return inputStream.readObject(); } private Object readResolve() { return getInstance(); } }
/** * @Author Nxy * @Date 2020/3/23 23:03 * @Description 本地注册中心,常驻内存,存储接口与实现类的对应关系 */ public class BasicLocalRegistry implements LocalRegistry { private Map<String, Class> registerMap = new HashMap<String, Class>(1024); private static BasicLocalRegistry basicLocalRegister; @Override public void register(String interfaceName, Class interfaceImplClass) { registerMap.put(interfaceName, interfaceImplClass); } @Override public Class get(String interfaceName) { return registerMap.get(interfaceName); } public static BasicLocalRegistry getInstance() { if (basicLocalRegister == null) { synchronized (BasicLocalRegistry.class) { if (basicLocalRegister == null) { basicLocalRegister = new BasicLocalRegistry(); } } } return basicLocalRegister; } private Object readResolve() { return getInstance(); } }
通信数据封装 Invocation:
/** * @Author Nxy * @Date 2020/3/19 10:57 * @Description rpc 远程调用参数封装 */ public class Invocation implements Serializable { private static final long serialVersionUID = 75929334234892747L; //远程调用接口名称 private String interfaceName; //远程调用方法名称 private String methodName; //方法参数类型列表 private Class[] paramtypes; //方法参数列表 private Object[] objects; /** * @param interfaceName 接口名字 * @param methodName 方法名字 * @param paramtypes 参数类型列表 * @param objects 参数列表 */ public Invocation(String interfaceName, String methodName, Class[] paramtypes, Object[] objects) { this.interfaceName = interfaceName; this.methodName = methodName; this.paramtypes = paramtypes; this.objects = objects; } public String getInterfaceName() { return interfaceName; } public void setInterfaceName(String interfaceName) { this.interfaceName = interfaceName; } public String getMethodName() { return methodName; } public void setMethodName(String methodName) { this.methodName = methodName; } public Class[] getParamtypes() { return paramtypes; } public void setParamtypes(Class[] paramtypes) { this.paramtypes = paramtypes; } public Object[] getObjects() { return objects; } public void setObjects(Object[] objects) { this.objects = objects; } }
服务端:
public interface RpcServer { public void start(); }
服务端 TCP 实现:
public class TcpServer implements RpcServer { private final int port; public TcpServer(int port) { this.port = port; } @Override public void start() { Server server = ServerFactory.getServer(port, SocketType.BIO, DelimitType.LengthFlag, DelimitType.LengthFlag, new BasicTcpHandler()); server.start(); } }
public class BasicTcpHandler implements BiFunction { @Override public Object apply(Object o, Object o2) { byte[] bytes = (byte[]) o2; Invocation invocation = null; String exception = null; try { ByteArrayInputStream bis = new ByteArrayInputStream(bytes); ObjectInputStream ois = new ObjectInputStream(bis); invocation = (Invocation) ois.readObject(); ois.close(); bis.close(); String interfaceName = invocation.getInterfaceName(); String methodName = invocation.getMethodName(); System.out.println("客户端调用:" + interfaceName + " : " + methodName); //从注册中心里面拿到接口的实现类 Class interfaceImplClass = BasicLocalRegistry.getInstance().get(interfaceName); //获取类的方法 Method method = interfaceImplClass.getMethod(invocation.getMethodName(), invocation.getParamtypes()); //反射调用方法 String result = (String) method.invoke(interfaceImplClass.newInstance(), invocation.getObjects()); return result; } catch (IOException ex) { ex.printStackTrace(); exception = ex.getMessage(); } catch (ClassNotFoundException ex) { ex.printStackTrace(); } catch (InstantiationException ie) { ie.printStackTrace(); } catch (Exception illE) { illE.printStackTrace(); } return "500"; } }
服务端Http实现:
public class HttpServer implements RpcServer { private final int port; public HttpServer(int port) { this.port = port; } @Override public void start() { Tomcat tomcat = new Tomcat(); Server server = tomcat.getServer(); Service service = server.findService("tomcat"); Connector connector = new Connector(); connector.setPort(port); Engine engine = new StandardEngine(); engine.setDefaultHost("locahost"); Host host = new StandardHost(); host.setName("locahost"); String contextPath = ""; Context context = new StandardContext(); context.setPath(contextPath); context.addLifecycleListener(new Tomcat.FixContextListener()); host.addChild(context); engine.addChild(host); service.setContainer(engine); service.addConnector(connector); tomcat.addServlet(contextPath, "dispather", new DispatcherServlet()); context.addServletMapping("/*", "dispather"); try { //启动tomcat tomcat.start(); tomcat.getServer().await(); } catch (LifecycleException e) { e.printStackTrace(); } } }
public class BasicHttpHandler { public void handler(HttpServletRequest req, HttpServletResponse resp) { // 获取对象 try { //从流里面获取数据 InputStream is = req.getInputStream(); ObjectInputStream objectInputStream = new ObjectInputStream(is); //从流中读取数据反序列话成实体类。 Invocation invocation = (Invocation) objectInputStream.readObject(); //拿到服务的名字 String interfaceName = invocation.getInterfaceName(); //从注册中心里面拿到接口的实现类 Class interfaceImplClass = BasicLocalRegistry.getInstance().get(interfaceName); //获取类的方法 Method method = interfaceImplClass.getMethod(invocation.getMethodName(), invocation.getParamtypes()); //反射调用方法 String result = (String) method.invoke(interfaceImplClass.newInstance(), invocation.getObjects()); //把结果返回给调用者 IOUtils.write(result, resp.getOutputStream()); } catch (IOException e) { e.printStackTrace(); } catch (ClassNotFoundException e) { e.printStackTrace(); } catch (NoSuchMethodException e) { e.printStackTrace(); } catch (IllegalAccessException e) { e.printStackTrace(); } catch (InstantiationException e) { e.printStackTrace(); } catch (InvocationTargetException e) { e.printStackTrace(); } } }
public class DispatchServlet extends HttpServlet { @Override protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { //把所有的请求交给HttpHandler接口处理 new BasicHttpHandler().handler(req, resp); } }
服务端工厂:
public class ServerFactory { public static RpcServer getServer(int port, ServerType serverType) { RpcServer server = null; switch (serverType) { case HTTP: server = new HttpServer(port); case TCP: server = new TcpServer(port); } return server; } public enum ServerType { TCP, HTTP } }
客户端:
/** * @Author Nxy * @Date 2020/3/23 21:10 * @Description 进行远程方法调用 */ public interface InvokeRegister<T> { public T invoke(URL url, Invocation invocation); }
Http实现:
/** * @Author Nxy * @Date 2020/3/23 21:11 * @Description 通过 http 协议进行远程调用 */ public class HttpInvokeRegister<T> implements InvokeRegister<T> { @Override public T invoke(URL url, Invocation invocation) { try { T re = (T) HttpUtil.httpPostSerialObject("http://" + url.getHost() + ":" + url.getPort(), 1000, 1000, invocation).toString(); return re; } catch (Exception e) { e.printStackTrace(); } return null; } }
Tcp实现:
/** * @Author Nxy * @Date 2020/3/23 21:11 * @Description 通过 TCP 进行远程调用 */ public class TcpInvokeRegister<T> implements InvokeRegister<T> { @Override public T invoke(URL url, Invocation invocation) { Client client = ClientFactory.getClient(url.getHost(), url.getPort(), SocketType.BIO, DelimitType.LengthFlag, DelimitType.LengthFlag); byte[] re = client.send(invocation); try { ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(re)); T o = (T) in.readObject(); return o; } catch (IOException e) { e.printStackTrace(); } catch (ClassNotFoundException ce) { ce.printStackTrace(); } return null; } }
客户端工厂:
/** * @Author Nxy * @Date 2020/3/23 21:09 * @Description 通过 JDK 动态代理获取远程调用方法的接口实例 */ public class ProxyFactory<T, R> { final InvokeRegister<R> invokeRegister; public ProxyFactory(InvokeType invokeType) { switch (invokeType) { case TCP: invokeRegister = new TcpInvokeRegister<R>(); break; case HTTP: invokeRegister = new HttpInvokeRegister<R>(); break; default: invokeRegister = null; } } /** * @Author Nxy * @Date 2020/3/23 21:14 * @Param interfaceClass:调用接口,invokeType:远程调用方式 * @Return * @Exception * @Description */ public T getProxy(final Class interfaceClass) { return (T) Proxy.newProxyInstance(interfaceClass.getClassLoader(), new Class[]{interfaceClass}, (Object proxy, Method method, Object[] args) -> { Invocation invocation = new Invocation(interfaceClass.getName(), method.getName(), method.getParameterTypes(), args); WebRegistry remoteRegister = BasicWebRegistry.getInstance(); URL randomURL = remoteRegister.getRandomURL(interfaceClass.getName()); System.out.println("调用地址host:" + randomURL.getHost() + ",port:" + randomURL.getPort()); return invokeRegister.invoke(randomURL, invocation); } ); } public enum InvokeType { HTTP, TCP } }
工具类:
public class HttpUtil { public static Object httpPostSerialObject(String requestUrl, int connTimeoutMills, int readTimeoutMills, Object serializedObject) throws Exception { HttpURLConnection httpUrlConn = null; InputStream inputStream = null; InputStreamReader inputStreamReader = null; BufferedReader bufferedReader = null; ObjectOutputStream oos = null; StringBuffer buffer = new StringBuffer(); try { URL url = new URL(requestUrl); httpUrlConn = (HttpURLConnection) url.openConnection(); // 设置content_type=SERIALIZED_OBJECT // 如果不设此项,在传送序列化对象时,当WEB服务默认的不是这种类型时可能抛java.io.EOFException httpUrlConn.setRequestProperty("Content-Type", "application/x-java-serialized-object"); httpUrlConn.setConnectTimeout(connTimeoutMills); httpUrlConn.setReadTimeout(readTimeoutMills); // 设置是否向httpUrlConn输出,因为是post请求,参数要放在http正文内,因此需要设为true, 默认情况下是false httpUrlConn.setDoOutput(true); // 设置是否从httpUrlConn读入,默认情况下是true httpUrlConn.setDoInput(true); // 不使用缓存 httpUrlConn.setUseCaches(false); // 设置请求方式,默认是GET httpUrlConn.setRequestMethod("POST"); httpUrlConn.connect(); if (serializedObject != null) { // 此处getOutputStream会隐含的进行connect,即:如同调用上面的connect()方法, // 所以在开发中不调用上述的connect()也可以,不过建议最好显式调用 // write object(impl Serializable) using ObjectOutputStream oos = new ObjectOutputStream(httpUrlConn.getOutputStream()); oos.writeObject(serializedObject); oos.flush(); // outputStream不是一个网络流,充其量是个字符串流,往里面写入的东西不会立即发送到网络, // 而是存在于内存缓冲区中,待outputStream流关闭时,根据输入的内容生成http正文。所以这里的close是必须的 oos.close(); } // 将返回的输入流转换成字符串 // 无论是post还是get,http请求实际上直到HttpURLConnection的getInputStream()这个函数里面才正式发送出去 inputStream = httpUrlConn.getInputStream();//注意,实际发送请求的代码段就在这里 inputStreamReader = new InputStreamReader(inputStream, "UTF-8"); bufferedReader = new BufferedReader(inputStreamReader); String str = null; while ((str = bufferedReader.readLine()) != null) { buffer.append(str); } } catch (Exception e) { throw e; } finally { try { IOUtils.closeQuietly(bufferedReader); IOUtils.closeQuietly(inputStreamReader); IOUtils.closeQuietly(inputStream); IOUtils.closeQuietly(oos); if (httpUrlConn != null) { httpUrlConn.disconnect(); } } catch (Exception e) { e.printStackTrace(); } } return buffer.toString(); } }
服务端 Demo :
public class Provider { public static void main(String[] args) { RpcServer sever = ServerFactory.getServer(80, ServerFactory.ServerType.TCP); //获取 远程调度中心 和 本地调度中心 实例 WebRegistry webRegistry = BasicWebRegistry.getInstance(); LocalRegistry localRegistry = BasicLocalRegistry.getInstance(); try { //远程注册中心注册本地提供的接口 webRegistry.register(MyRpcService.class.getName(), UrlUtil.getLocalUrl()); //本地缓存注册接口对应的实现类 localRegistry.register(MyRpcService.class.getName(), MyFirstService.class); } catch (Exception e) { System.out.println("向注册中心注册服务期间发生异常: " + e.getMessage()); return; } sever.start(); } }
客户端 Demo :
public class Consumer { public static void main(String[] args) { MyRpcService helloService = new ProxyFactory<MyRpcService, String>(ProxyFactory.InvokeType.TCP).getProxy(MyRpcService.class); String result = helloService.sayHellow("liuy"); System.out.println(result); } }
调用结果,客户端成功调用获得结果: