9、手写基于 Netty 的 RPC 框架
测试 demo：
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import org.junit.Test;
import java.io.*;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.net.InetSocketAddress;
import java.util.Random;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
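/*
Assumed requirement: write a simple RPC framework.
Concerns to handle: request/response round trips, the number of connections (connection pooling),
and message framing (a frame split across reads, or several frames arriving in one read).
*/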
/**
 * Hand-written RPC demo on Netty. {@link #startServer()} runs the provider on localhost:8888;
 * {@link #get()} starts it in the background and issues 20 concurrent {@code Car.ooxx} calls
 * through a dynamic proxy.
 */
public class MyRPCTest {

    private static final String HOST = "localhost";
    private static final int PORT = 8888;

    /** Starts the RPC server and blocks until its server channel is closed. */
    @Test
    public void startServer() {
        startServer(null);
    }

    /**
     * Binds the server on HOST:PORT and blocks until the server channel closes.
     *
     * @param bound if non-null, counted down once the bind attempt has finished (successfully,
     *     or on failure) so a caller can wait for the server instead of racing it
     */
    private void startServer(CountDownLatch bound) {
        // A single one-thread event loop acts as both boss (accept) and worker (I/O) group.
        NioEventLoopGroup boss = new NioEventLoopGroup(1);
        NioEventLoopGroup worker = boss;
        try {
            ServerBootstrap sbs = new ServerBootstrap();
            ChannelFuture bind = sbs.group(boss, worker)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<NioSocketChannel>() {
                        @Override
                        protected void initChannel(NioSocketChannel nioSocketChannel) throws Exception {
                            System.out.println("server accept client port: " + nioSocketChannel.remoteAddress().getPort());
                            // A fresh handler instance per accepted connection.
                            ChannelPipeline pipeline = nioSocketChannel.pipeline();
                            pipeline.addLast(new ServerRequestHandler());
                        }
                    }).bind(new InetSocketAddress(HOST, PORT));
            Channel serverChannel = bind.sync().channel();
            if (bound != null) {
                bound.countDown();
            }
            serverChannel.closeFuture().sync();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } finally {
            // Release any waiter even if bind failed, and stop the event loop thread.
            if (bound != null) {
                bound.countDown();
            }
            boss.shutdownGracefully();
        }
    }

    /** Simulated client: 20 threads each make one {@code Car.ooxx("hello")} call. */
    @Test
    public void get() {
        CountDownLatch serverBound = new CountDownLatch(1);
        new Thread(() -> startServer(serverBound)).start();
        try {
            // Clients must not try to connect before the server has bound its port.
            serverBound.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return;
        }
        System.out.println("server started......");
        Thread[] threads = new Thread[20];
        for (int i = 0; i < threads.length; i++) {
            threads[i] = new Thread(() -> {
                Car car = proxyGet(Car.class); // dynamic proxy
                car.ooxx("hello");
            });
        }
        for (Thread thread : threads) {
            thread.start();
        }
        try {
            // Block the test thread so the asynchronous calls can finish.
            System.in.read();
        } catch (IOException e) {
            e.printStackTrace();
        }
        // Single calls work the same way, e.g.:
        // proxyGet(Car.class).ooxx("hello");
        // proxyGet(Fly.class).xxoo("hello");
    }

    /**
     * Creates a client-side proxy for {@code interfaceInfo}.
     *
     * <p>Each call (other than {@code equals}/{@code hashCode}/{@code toString}, which are
     * answered locally) is serialized into a {@link MyContent} body prefixed by a
     * {@link MyHeader}. It is sent over a pooled connection to HOST:PORT, and the calling
     * thread blocks until a response with the same request id arrives.
     *
     * <p>The response payload is not decoded yet, so every remote call returns {@code null};
     * only methods returning {@code void} or a reference type are usable.
     *
     * @param interfaceInfo the service interface to proxy
     * @return a proxy implementing {@code interfaceInfo}
     */
    @SuppressWarnings("unchecked")
    public <T> T proxyGet(Class<T> interfaceInfo) {
        ClassLoader classLoader = interfaceInfo.getClassLoader();
        Class<?>[] methodInfo = {interfaceInfo};
        return (T) Proxy.newProxyInstance(classLoader, methodInfo, new InvocationHandler() {
            @Override
            public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
                // Sending Object methods over the wire would return null and NPE when the
                // result is unboxed to boolean/int, so handle them locally.
                if (method.getDeclaringClass() == Object.class) {
                    return invokeObjectMethod(interfaceInfo, proxy, method, args);
                }
                // 1. Wrap the service, method and arguments into the message body.
                MyContent content = new MyContent();
                content.setArgs(args);
                content.setName(interfaceInfo.getName());
                content.setMethodName(method.getName());
                content.setParameterTypes(method.getParameterTypes());
                byte[] msgBody = serialize(content);
                // 2. Protocol: header (flag, requestID, body length) followed by the body.
                MyHeader header = createHeader(msgBody);
                byte[] msgHeader = serialize(header);
                // 3. Take a connection from the pool.
                ClientFactory factory = ClientFactory.getFactory();
                NioSocketChannel clientChannel = factory.getClient(new InetSocketAddress(HOST, PORT));
                if (clientChannel == null) {
                    throw new IllegalStateException("no connection to " + HOST + ":" + PORT);
                }
                // 4. Register the callback before writing so a fast response cannot be missed.
                CountDownLatch countDownLatch = new CountDownLatch(1);
                ResponseHandler.addCallBack(header.getRequestID(), countDownLatch::countDown);
                // Allocate only once we can write; writeAndFlush takes ownership of the buffer.
                ByteBuf byteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(msgHeader.length + msgBody.length);
                byteBuf.writeBytes(msgHeader);
                byteBuf.writeBytes(msgBody);
                clientChannel.writeAndFlush(byteBuf).sync();
                countDownLatch.await();
                return null;
            }
        });
    }

    /** Local implementations of equals/hashCode/toString for proxies made by proxyGet. */
    private static Object invokeObjectMethod(Class<?> iface, Object proxy, Method method, Object[] args) {
        switch (method.getName()) {
            case "equals":
                return proxy == args[0];
            case "hashCode":
                return System.identityHashCode(proxy);
            default: // toString, the only other Object method a proxy dispatches
                return iface.getName() + "@" + Integer.toHexString(System.identityHashCode(proxy));
        }
    }

    /** Java-serializes {@code obj} into a standalone stream that carries its own stream header. */
    private static byte[] serialize(Object obj) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(obj);
        }
        return bytes.toByteArray();
    }

    /**
     * Builds the header for a body of {@code msgBody.length} bytes, with flag 0x14141414 and a
     * random non-negative request id derived from a UUID.
     */
    private MyHeader createHeader(byte[] msgBody) {
        MyHeader header = new MyHeader();
        header.setFlag(0x14141414);
        header.setDataLen(msgBody.length);
        // Mask the sign bit: Math.abs(Long.MIN_VALUE) would still be negative.
        header.setRequestID(UUID.randomUUID().getLeastSignificantBits() & Long.MAX_VALUE);
        return header;
    }
}
/**
 * Server-side inbound handler. A new instance is installed per connection, so its buffer
 * state is per-connection.
 *
 * <p>Inbound bytes are accumulated until a whole frame is available: a serialized
 * {@link MyHeader} followed by {@code header.dataLen} bytes of serialized {@link MyContent}.
 * Each complete frame is logged and echoed back verbatim. This works whether several frames
 * arrive in one read or one frame is split across reads.
 *
 * <p>NOTE(review): this Java-deserializes bytes read from the network. Do not expose it to
 * untrusted peers without an ObjectInputFilter (Java 9+) or a safer encoding.
 */
class ServerRequestHandler extends ChannelInboundHandlerAdapter {

    /**
     * Size of a serialized MyHeader. All its instance fields are primitives, so the size is the
     * same for every instance. Computing it avoids a magic constant that silently breaks when
     * the class is renamed or moved to another package.
     */
    private static final int HEADER_LEN = serializedHeaderLength();

    /** Received bytes not yet consumed as a complete frame; null until the first read. */
    private ByteBuf cumulation;

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        ByteBuf in = (ByteBuf) msg;
        try {
            if (cumulation == null) {
                cumulation = ctx.alloc().buffer(in.readableBytes());
            }
            cumulation.writeBytes(in);
        } finally {
            // The bytes were copied; this handler is the last consumer and must release msg.
            in.release();
        }
        while (cumulation.readableBytes() >= HEADER_LEN) {
            int frameStart = cumulation.readerIndex();
            MyHeader header = (MyHeader) readObject(cumulation, HEADER_LEN);
            long dataLen = header.getDataLen();
            if (dataLen < 0 || dataLen > Integer.MAX_VALUE - HEADER_LEN) {
                throw new IOException("invalid body length " + dataLen + " for request " + header.requestID);
            }
            if (cumulation.readableBytes() < dataLen) {
                // The body has not fully arrived yet: rewind to the frame start and wait.
                cumulation.readerIndex(frameStart);
                break;
            }
            System.out.println("server response @ id: " + header.requestID);
            MyContent content = (MyContent) readObject(cumulation, (int) dataLen);
            System.out.println(content.getName());
            // Echo the whole frame back. It is a copy, so discardReadBytes() below cannot
            // disturb it while the write is pending.
            ctx.write(cumulation.copy(frameStart, HEADER_LEN + (int) dataLen));
        }
        cumulation.discardReadBytes();
        // No sync() here: blocking on a future inside the event loop can throw
        // BlockingOperationException.
        ctx.flush();
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        // After a corrupt frame the byte stream is out of sync and cannot be recovered.
        cause.printStackTrace();
        ctx.close();
    }

    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) {
        // Free any partial frame still buffered when the connection goes away.
        if (cumulation != null) {
            cumulation.release();
            cumulation = null;
        }
    }

    /** Reads {@code length} bytes from {@code buf} and Java-deserializes one object from them. */
    private static Object readObject(ByteBuf buf, int length) throws IOException, ClassNotFoundException {
        byte[] bytes = new byte[length];
        buf.readBytes(bytes);
        try (ObjectInputStream oin = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
            return oin.readObject();
        }
    }

    /** Length of a MyHeader written by a fresh ObjectOutputStream, including the stream header. */
    private static int serializedHeaderLength() {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(new MyHeader());
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.size();
    }
}
/**
 * Process-wide client connection factory (singleton). Per the original author's note, the
 * design comes from Spark's source (presumably its TransportClientFactory; unverified).
 *
 * <p>One consumer may connect to many providers. Each provider address gets its own
 * {@link ClientPool} of {@code poolSize} connections.
 */
class ClientFactory {
    int poolSize = 1;
    Random rand = new Random();
    /** One event loop shared by every client connection; created lazily on first connect. */
    NioEventLoopGroup clientWorker;
    private static final ClientFactory factory; // singleton

    private ClientFactory() {
    }

    static {
        factory = new ClientFactory();
    }

    public static ClientFactory getFactory() {
        return factory;
    }

    // Connection pools keyed by provider address.
    ConcurrentHashMap<InetSocketAddress, ClientPool> outboxs = new ConcurrentHashMap<>();

    /**
     * Returns a connected channel to {@code address} from a random slot of its pool. The slot is
     * (re)connected first if it is empty or no longer active.
     *
     * @return the channel, or {@code null} if the calling thread was interrupted while connecting
     */
    public synchronized NioSocketChannel getClient(InetSocketAddress address) {
        ClientPool clientPool = outboxs.computeIfAbsent(address, a -> new ClientPool(poolSize));
        int i = rand.nextInt(poolSize);
        synchronized (clientPool.lock[i]) {
            NioSocketChannel client = clientPool.clients[i];
            if (client == null || !client.isActive()) {
                client = create(address);
                clientPool.clients[i] = client;
            }
            return client;
        }
    }

    /**
     * Opens a Netty client connection to {@code address}.
     * Only called from the synchronized getClient, which guards the lazy init below.
     */
    private NioSocketChannel create(InetSocketAddress address) {
        // Reuse one event loop group. Creating a group per connection, and overwriting the
        // field each time, leaks the previous groups' threads.
        if (clientWorker == null) {
            clientWorker = new NioEventLoopGroup(1);
        }
        Bootstrap bootstrap = new Bootstrap();
        ChannelFuture connect = bootstrap.group(clientWorker)
                .channel(NioSocketChannel.class)
                .handler(new ChannelInitializer<NioSocketChannel>() {
                    @Override
                    protected void initChannel(NioSocketChannel nioSocketChannel) throws Exception {
                        ChannelPipeline pipeline = nioSocketChannel.pipeline();
                        // Routes each response to the caller waiting on its requestID.
                        pipeline.addLast(new ClientResponses());
                    }
                }).connect(address);
        try {
            return (NioSocketChannel) connect.sync().channel();
        } catch (InterruptedException e) {
            // Preserve the interrupt for the caller instead of swallowing it.
            Thread.currentThread().interrupt();
            e.printStackTrace();
            return null;
        }
    }
}
/**
 * Client-side inbound handler. A new instance is installed per connection, so its buffer
 * state is per-connection.
 *
 * <p>Bytes are accumulated until a complete response frame is available: a serialized
 * {@link MyHeader} followed by {@code header.dataLen} body bytes. For each frame, the callback
 * registered under its requestID is run via {@link ResponseHandler#runCallBack(long)}. This
 * works whether several frames arrive in one read or one frame is split across reads. The body
 * is currently skipped, not decoded.
 *
 * <p>NOTE(review): headers are Java-deserialized from network bytes. Only talk to trusted
 * providers.
 */
class ClientResponses extends ChannelInboundHandlerAdapter {

    /**
     * Size of a serialized MyHeader. All its instance fields are primitives, so the size is
     * constant. Computing it avoids a magic number tied to the class's package/name.
     */
    private static final int HEADER_LEN = serializedHeaderLength();

    /** Received bytes not yet consumed as a complete frame; null until the first read. */
    private ByteBuf cumulation;

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        ByteBuf in = (ByteBuf) msg;
        try {
            if (cumulation == null) {
                cumulation = ctx.alloc().buffer(in.readableBytes());
            }
            cumulation.writeBytes(in);
        } finally {
            // The bytes were copied and nothing downstream needs msg, so release it here.
            in.release();
        }
        while (cumulation.readableBytes() >= HEADER_LEN) {
            int frameStart = cumulation.readerIndex();
            MyHeader header = readHeader(cumulation);
            long dataLen = header.getDataLen();
            if (dataLen < 0 || dataLen > Integer.MAX_VALUE) {
                throw new IOException("invalid body length " + dataLen + " for request " + header.requestID);
            }
            if (cumulation.readableBytes() < dataLen) {
                // The body has not fully arrived yet: rewind to the frame start and wait.
                cumulation.readerIndex(frameStart);
                break;
            }
            // Skip the (not yet decoded) body so the next frame starts at the right offset.
            cumulation.skipBytes((int) dataLen);
            System.out.println("client response @ id: " + header.requestID);
            ResponseHandler.runCallBack(header.requestID);
        }
        cumulation.discardReadBytes();
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        // After a corrupt frame the byte stream is out of sync and cannot be recovered.
        cause.printStackTrace();
        ctx.close();
    }

    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) {
        // Free any partial frame still buffered when the connection goes away.
        if (cumulation != null) {
            cumulation.release();
            cumulation = null;
        }
    }

    /** Reads HEADER_LEN bytes from {@code buf} and deserializes them as a MyHeader. */
    private static MyHeader readHeader(ByteBuf buf) throws IOException, ClassNotFoundException {
        byte[] bytes = new byte[HEADER_LEN];
        buf.readBytes(bytes);
        try (ObjectInputStream oin = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
            return (MyHeader) oin.readObject();
        }
    }

    /** Length of a MyHeader written by a fresh ObjectOutputStream, including the stream header. */
    private static int serializedHeaderLength() {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(new MyHeader());
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.size();
    }
}
/**
 * Process-wide registry of pending RPC callbacks keyed by request id. A caller registers a
 * callback before sending a request, and the connection's inbound handler runs it when the
 * response carrying that id arrives.
 */
class ResponseHandler {
    static ConcurrentHashMap<Long, Runnable> mapping = new ConcurrentHashMap<>();

    /** Registers {@code cb} for {@code requestID}; an existing registration for the id is kept. */
    public static void addCallBack(long requestID, Runnable cb) {
        mapping.putIfAbsent(requestID, cb);
    }

    /**
     * Atomically removes, then runs, the callback registered for {@code requestID}.
     *
     * <p>Because removal is atomic and happens first, the callback runs at most once. A
     * duplicate or unknown response is reported instead of throwing NullPointerException.
     */
    public static void runCallBack(long requestID) {
        Runnable cb = mapping.remove(requestID);
        if (cb == null) {
            System.out.println("no pending call for id: " + requestID);
            return;
        }
        cb.run();
    }
}
/**
 * Fixed-size pool of connections to a single provider address. Every slot starts empty, and
 * each slot has its own dedicated monitor object that is used when filling it.
 */
class ClientPool {
    NioSocketChannel[] clients;
    Object[] lock;

    ClientPool(int size) {
        lock = new Object[size];
        clients = new NioSocketChannel[size]; // no connection opened yet
        int slot = 0;
        while (slot < size) {
            lock[slot++] = new Object(); // one distinct monitor per slot
        }
    }
}
/**
 * Protocol header sent before every serialized {@link MyContent} body.
 *
 * <ol>
 *   <li>{@code flag}: a 32-bit value that can carry assorted protocol information.</li>
 *   <li>{@code requestID}: correlates a response with its pending call; the client derives it
 *       from a UUID.</li>
 *   <li>{@code dataLen}: length in bytes of the serialized body that follows.</li>
 * </ol>
 *
 * <p>All instance fields are primitives, so every serialized instance has the same length.
 * Readers rely on that to read the header before the body.
 */
class MyHeader implements Serializable {
    // Pinned explicitly: the default serialVersionUID is computed from class details and is
    // compiler-sensitive, so client and server builds could otherwise disagree and fail with
    // InvalidClassException.
    private static final long serialVersionUID = 1L;

    int flag; // 32 bits that can encode various settings
    long requestID;
    long dataLen;

    public int getFlag() {
        return flag;
    }

    public void setFlag(int flag) {
        this.flag = flag;
    }

    public long getRequestID() {
        return requestID;
    }

    public void setRequestID(long requestID) {
        this.requestID = requestID;
    }

    public long getDataLen() {
        return dataLen;
    }

    public void setDataLen(long dataLen) {
        this.dataLen = dataLen;
    }
}
/**
 * RPC request body: the target interface name, method name, parameter types and arguments.
 * It is Java-serialized on the wire, so every element of {@code args} must itself be
 * Serializable.
 */
class MyContent implements Serializable {
    // Pinned explicitly: the default serialVersionUID is compiler-sensitive, so client and
    // server builds could otherwise disagree and fail with InvalidClassException.
    private static final long serialVersionUID = 1L;

    String name;
    String methodName;
    Class<?>[] parameterTypes;
    Object[] args;

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public String getMethodName() {
        return methodName;
    }

    public void setMethodName(String methodName) {
        this.methodName = methodName;
    }

    public Class<?>[] getParameterTypes() {
        return parameterTypes;
    }

    public void setParameterTypes(Class<?>[] parameterTypes) {
        this.parameterTypes = parameterTypes;
    }

    public Object[] getArgs() {
        return args;
    }

    public void setArgs(Object[] args) {
        this.args = args;
    }
}
/** Sample service interface called remotely through the RPC proxy in this demo. */
interface Car {
    /** Sends {@code msg} to the remote provider. */
    void ooxx(String msg);
}
/** A second sample service interface for the RPC proxy demo. */
interface Fly {
    /** Sends {@code msg} to the remote provider. */
    void xxoo(String msg);
}
《三体》中有句话——弱小和无知不是生存的障碍，傲慢才是。
所以，我们不要做一只坐井观天的小青蛙。