200行代码实现RPC框架

之前因为项目需要,我基于 ZooKeeper 和 Thrift 协议实现了一个简单易用的 RPC 框架,核心代码不超过 200 行。

其中,ZooKeeper 负责服务发现,Thrift 作为通信传输协议,连接池基于 commons-pool2 构建。

大家感兴趣的话可以参考,具体代码如下:

/**
 * Abstract Thrift client with a built-in socket connection pool and thread pool.
 * It offers synchronous blocking calls and calls bounded by a timeout.
 * A concrete Thrift client extends this class, implements {@link #doService}
 * and overrides the other hooks (for example {@link #hashRule}) as needed.
 *
 * @author zhangkai
 */
public abstract class AbstractThriftClient {
    private final static int MAX_FRAME_SIZE = 1024 * 1024 * 1024;
    private final static int MIN_FRAME_SIZE = 1024;

    /** Shared generator for the default {@link #hashRule}; avoids one allocation per request. */
    private final static Random RANDOM = new Random();

    protected ThreadPoolExecutor executor;
    protected AbstractThriftClient client = this;
    protected ClientConfig clientConfig;
    protected CuratorFramework zkClient;
    /**
     * One connection pool per discovered server node. The reference is swapped as a
     * whole by {@link #buildConnPool()}, which is also invoked from the watcher
     * callback, so it is volatile to make the swap visible to request threads.
     */
    protected volatile List<TConnectionPool> shardInfos = Lists.newArrayList();

    /**
     * Initializes the thread pool, starts the ZooKeeper client and builds the
     * initial connection pools from the registered service nodes.
     *
     * @param clientConfig supplies the ZooKeeper addresses, the namespace to watch
     *                     and the request timeout
     */
    protected AbstractThriftClient(ClientConfig clientConfig) {
        int processors = Runtime.getRuntime().availableProcessors();
        // When the bounded queue is full, CallerRunsPolicy executes the task on the
        // submitting thread instead of rejecting it.
        this.executor = new ThreadPoolExecutor(processors * 5, processors * 10, 60L, TimeUnit.SECONDS,
                new ArrayBlockingQueue<Runnable>(processors * 100),
                Executors.defaultThreadFactory(), new ThreadPoolExecutor.CallerRunsPolicy());
        this.clientConfig = clientConfig;
        this.zkClient = CuratorFrameworkFactory.builder()
                .connectString(clientConfig.getZkAddrs())
                .retryPolicy(new ExponentialBackoffRetry(500, 4)).build();
        this.zkClient.start();
        buildConnPool();
    }

    /**
     * The only abstract method a subclass has to implement.
     * It receives the packaged request, performs the real RPC call over the given
     * protocol and wraps the service result in an {@code RPCResponse}.
     *
     * @param rpcRequest the packaged request
     * @param protocol   protocol bound to a pooled connection
     * @return the packaged response
     * @throws Exception on any failure; the caller maps it to a FAILED response
     */
    protected abstract RPCResponse doService(RPCRequest rpcRequest, TProtocol protocol) throws Exception;

    /**
     * Chooses which connection pool serves a request. The default picks one at
     * random; subclasses may override this with their own hash rule. The returned
     * value is reduced modulo the number of pools by the caller.
     *
     * @param request the request being routed
     * @return an index used to select a pool
     */
    protected int hashRule(RPCRequest request){
        int size = shardInfos.size();
        // The list can be swapped for an empty one between the caller's check and
        // this read; Random.nextInt(0) would throw.
        if(size <= 0){
            return 0;
        }
        return RANDOM.nextInt(size);
    }

    /**
     * Processes one request synchronously:
     * 1. borrow a connection from the selected pool;
     * 2. wrap it in the framed transport and binary protocol;
     * 3. call {@link #doService} to obtain the RPC result.
     *
     * @param rpcRequest the request to send
     * @return the response from {@link #doService}, or a response with code
     *         {@code RPCResponse.FAILED} when the request could not be completed
     */
    protected RPCResponse processRequest(RPCRequest rpcRequest){
        String serviceName = rpcRequest.getServiceName();
        if(serviceName == null){
            LogUtils.warn("serviceName can not be null");
            return failedResponse();
        }
        TConnectionPool connPool = getConnPool(rpcRequest);
        if(connPool == null){
            return failedResponse();
        }
        TSocket socket = connPool.getSocket();
        if(socket == null){
            LogUtils.warn("no socket available");
            return failedResponse();
        }
        // Tracks whether the socket was already handed to removeSocket, so the
        // finally block never gives a discarded socket back to the pool.
        boolean discarded = false;
        try {
            TTransport transport = new TFastFramedTransport(socket, MIN_FRAME_SIZE, MAX_FRAME_SIZE);
            if (!transport.isOpen()) {
                transport.open();
            }
            TProtocol protocol = new TBinaryProtocol(transport);
            return this.doService(rpcRequest, protocol);
        } catch (Exception e) {
            LogUtils.error("", e);
            discarded = true;
            connPool.removeSocket(socket);
            return failedResponse();
        } finally {
            if (!discarded) {
                if (socket.isOpen()) {
                    connPool.returnSocket(socket);
                } else {
                    // A closed socket is useless to the next borrower; remove it so
                    // the pool does not keep counting it as borrowed.
                    // NOTE(review): assumes removeSocket is the pool's discard operation — confirm.
                    connPool.removeSocket(socket);
                }
            }
        }
    }

    /**
     * Sends a request, blocking without limit when the configured request timeout
     * is not positive and bounding the wait by that timeout otherwise.
     *
     * @param request the request to send
     * @return the response, or a FAILED response on error or timeout
     */
    protected RPCResponse sendRequest(RPCRequest request){
        int requestTimeout = clientConfig.getRequestTimeout();
        if(requestTimeout <= 0){
            return this.processRequest(request);
        }
        return this.processRequestTimeout(request, requestTimeout);
    }

    /**
     * Selects the connection pool for a request using {@link #hashRule}.
     *
     * @return the selected pool, or {@code null} when no node is available
     */
    private TConnectionPool getConnPool(RPCRequest request){
        // Work on one snapshot so a concurrent swap cannot change the size midway.
        List<TConnectionPool> pools = shardInfos;
        if(pools.isEmpty()){
            LogUtils.warn("no valid node available");
            return null;
        }
        int size = pools.size();
        int index = hashRule(request) % size;
        // An overriding hashRule may return a negative value (e.g. a raw hashCode).
        if(index < 0){
            index += size;
        }
        return pools.get(index);
    }

    /**
     * Runs the request on the executor and waits at most {@code timeout}
     * milliseconds for its result.
     *
     * @param request the request to send
     * @param timeout maximum wait in milliseconds
     * @return the response, or a FAILED response on timeout, interruption or error
     */
    private RPCResponse processRequestTimeout(RPCRequest request, int timeout){
        Future<RPCResponse> future = executor.submit(new RPCRequestTask(request));

        try {
            return future.get(timeout, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            LogUtils.warn("[ExecutorService]The current thread was interrupted while waiting: ", e);
            future.cancel(true);
            // Restore the flag so callers further up can observe the interruption.
            Thread.currentThread().interrupt();
        } catch (ExecutionException e) {
            LogUtils.warn("[ExecutorService]The computation threw an exception: ", e);
        } catch (TimeoutException e) {
            LogUtils.warn("[ExecutorService]The wait " + timeout + " timed out: ", e);
            // Nobody will read the result any more; stop the task if it has not
            // started and interrupt it otherwise.
            future.cancel(true);
        } catch(Exception e){
            LogUtils.warn("[ExecutorService] failed", e);
        }
        return failedResponse();
    }

    /** Creates a response whose code is {@code RPCResponse.FAILED}. */
    private static RPCResponse failedResponse(){
        RPCResponse response = new RPCResponse();
        response.setCode(RPCResponse.FAILED);
        return response;
    }

    /** Executor task that runs {@link #processRequest} for a single request. */
    private class RPCRequestTask implements Callable<RPCResponse> {
        private final RPCRequest rpcRequest;

        public RPCRequestTask(RPCRequest request) {
            this.rpcRequest = request;
        }

        @Override
        public RPCResponse call() {
            return client.processRequest(rpcRequest);
        }
    }

    /**
     * Reads the child nodes under the configured ZooKeeper namespace, creates one
     * connection pool per node and swaps the result into {@link #shardInfos}.
     * The watcher passed to the children query calls this method again on a
     * children-changed event. On any failure the error is logged and the previous
     * list is kept.
     */
    private void buildConnPool(){
        try{
            List<String> nodes = zkClient
                    .getChildren()
                    .usingWatcher(new Watcher(){
                        @Override
                        public void process(WatchedEvent event) {
                            if(event.getType() == EventType.NodeChildrenChanged){
                                buildConnPool();
                            }
                        }})
                    .forPath(clientConfig.getZkNamespace());
            List<TConnectionPool> currShardInfos = Lists.newArrayList();
            for(String node : nodes){
                String path = clientConfig.getZkNamespace() + "/" + node;
                byte[] dataArray = zkClient.getData().forPath(path);
                // Decode with an explicit charset instead of the platform default.
                String dataStr = new String(dataArray, "UTF-8");
                RegistryInfo info = JsonUtils.fromJson(dataStr, RegistryInfo.class);
                TServerInfo server = new TServerInfo(info.getIp(), info.getPort());
                currShardInfos.add(new TConnectionPool(server));
            }
            // NOTE(review): the pools in the replaced list are not closed here; if
            // TConnectionPool holds open sockets they may need an explicit shutdown — confirm.
            this.shardInfos = currShardInfos;
        }catch(Exception e){
            LogUtils.error("build conn pool failed", e);
        }
    }
}

完整的代码和demo可以参考:https://github.com/zhangkai253/simpleRPC

posted @ 2018-08-20 19:43  一只笨鼠  阅读(693)  评论(0编辑  收藏  举报