用Golang手写一个RPC,理解RPC原理

代码结构#

Copy
. ├── client.go ├── coder.go ├── coder_test.go ├── rpc_test.go ├── server.go ├── session.go └── session_test.go

代码#

client.go#

Copy
package rpc import ( "net" "reflect" ) // rpc 客户端实现 // 抽象客户端方法 type Client struct { conn net.Conn } // client构造方法 func NewClient(conn net.Conn) *Client { return &Client{conn: conn} } // 客户端调用服务端rpc实现 // client.RpcCall("login", &req) func (c *Client) RpcCall(name string, fpr interface{}) { // 反射获取函数原型 fn := reflect.ValueOf(fpr).Elem() // 客户端逻辑的实现 f := func(args []reflect.Value) (results []reflect.Value) { // 从匿名函数中构建请求参数 inArgs := make([]interface{}, 0, len(args)) for _, v := range args { inArgs = append(inArgs, v.Interface()) } // 组装rpc data请求数据 reqData := RpcData{Name: name, Args: inArgs} // 进行数据编码 reqByteData, err := encode(reqData) if err != nil { return } // 创建session 对象 session := NewSession(c.conn) // 客户端发送数据 err = session.Write(reqByteData) if err != nil { return } // 读取客户端数据 rspByteData, err := session.Read() if err != nil { return } // 数据进行解码 rspData, err := decode(rspByteData) if err != nil { return } // 处理服务端返回的数据结果 outArgs := make([]reflect.Value, 0, len(rspData.Args)) for i, v := range rspData.Args { // 数据特殊情况处理 if v == nil { // reflect.Zero() 返回某类型的零值的value // .Out()返回函数输出的参数类型 // 得到具体第几个位置的参数的零值 outArgs = append(outArgs, reflect.Zero(fn.Type().Out(i))) continue } outArgs = append(outArgs, reflect.ValueOf(v)) } return outArgs } // 函数原型到调用的关键,需要2个参数 // 参数1:函数原型,是Type类型 // 参数2:返回类型是Value类型 // 简单理解:参数1是函数原型,参数2是客户端逻辑 v := reflect.MakeFunc(fn.Type(), f) fn.Set(v) }

coder.go#

Copy
package rpc import ( "bytes" "encoding/gob" "fmt" ) // 对传输的数据进行编解码 // 使用Golang自带的一个数据结构序列化编码/解码工具 gob // 定义rpc数据交互式数据传输格式 type RpcData struct { Name string // 调用方法名 Args []interface{} // 调用和返回的参数列表 } // 编码 func encode(data RpcData) ([]byte, error) { // gob进行编码 var buf bytes.Buffer // 得到字节编码器 encoder := gob.NewEncoder(&buf) // 进行编码 if err := encoder.Encode(data); err != nil { fmt.Printf("gob encode failed, err: %v\n", err) return nil, err } return buf.Bytes(), nil } // 解码 func decode(data []byte) (RpcData, error) { // 得到字节解码器 buf := bytes.NewBuffer(data) decoder := gob.NewDecoder(buf) // 解码数据 var rd RpcData if err := decoder.Decode(&rd); err != nil { fmt.Printf("gob decode failed, err: %v\n", err) return rd, err } return rd, nil }

server.go#

Copy
package rpc import ( "net" "reflect" ) // rpc 服务端实现 // 抽象服务端 type Server struct { add string // 连接地址 funcs map[string]reflect.Value // 存储方法名和方法的对应关系,服务注册 } // server 构造方法 func NewServer(addr string) *Server { return &Server{add: addr, funcs: make(map[string]reflect.Value)} } // 注册接口 func (s *Server) Register(name string, fc interface{}) { if _, ok := s.funcs[name]; ok { return } s.funcs[name] = reflect.ValueOf(fc) } func (s *Server) Run() (err error) { listener, err := net.Listen("tcp", s.add) if err != nil { return } for { // 监听连接 conn, err := listener.Accept() if err != nil { conn.Close() continue } // 创建会话 session := NewSession(conn) // 读取会话请求数据 reqData, err := session.Read() if err != nil { conn.Close() continue } // 数据解码 rpcReqData, err := decode(reqData) // 获取客户端要调用的方法 fc, ok := s.funcs[rpcReqData.Name]; if !ok { conn.Close() continue } // 获取请求的参数列表 args := make([]reflect.Value, 0, len(rpcReqData.Args)) for _, v := range rpcReqData.Args { args = append(args, reflect.ValueOf(v)) } // 调用 callReslut := fc.Call(args) // 处理调用返回的数据结果 rargs := make([]interface{}, 0, len(callReslut)) for _, rv := range callReslut { rargs = append(rargs, rv.Interface()) } // 构建返回的rpc数据 rpcRspData := RpcData{Name: rpcReqData.Name, Args: rargs} // 返回数据进行编码 rspData, err := encode(rpcRspData) if err != nil { conn.Close() continue } err = session.Write(rspData) if err != nil { conn.Close() continue } } return }

session.go#

Copy
package rpc import ( "encoding/binary" "fmt" "io" "net" ) // 处理连接会话 // 会话对象结构体 type Session struct { conn net.Conn } // 传输数据存储方式 // 字节数组, 添加4个字节的头,用来存储数据的长度 // 会话构造函数 func NewSession(conn net.Conn) *Session { return &Session{conn: conn} } // 从连接中读取数据 func (s *Session) Read() (data []byte, err error) { // 读取数据header数据 header := make([]byte, 4) _, err = s.conn.Read(header) if err != nil { fmt.Printf("read conn header data failed, err: %v\n", err) return } // 读取body数据 hlen := binary.BigEndian.Uint32(header) data = make([]byte, hlen) _, err = io.ReadFull(s.conn, data) if err != nil { fmt.Printf("read conn body data failed, err: %v\n", err) return } return } // 向连接中写入数据 func (s *Session) Write(data []byte) (err error) { // 创建数据字节切片 buf := make([]byte, 4+len(data)) // 向header写入数据长度 binary.BigEndian.PutUint32(buf[:4], uint32(len(data))) // 写入body内容 copy(buf[4:], data) // 写入连接数据 _, err = s.conn.Write(buf) if err != nil { fmt.Printf("write conn data failed, err: %v\n", err) return } return }

coder_test.go#

Copy
package rpc import ( "testing" ) func TestCoder(t *testing.T) { rd := RpcData{ Name: "login", Args: []interface{}{"zhangsan", "zs123"}, } eData, err := encode(rd) if err != nil { t.Error(err) return } t.Logf("gob 编码后数据长度: %d\n", len(eData)) dData, err := decode(eData) if err != nil { t.Error(err) return } t.Logf("%#v\n", dData) }

session_test.go#

Copy
package rpc import ( "net" "sync" "testing" ) func TestSession(t *testing.T) { addr := ":8080" test_data := "my is test data" var wg sync.WaitGroup wg.Add(2) // 写数据 go func() { defer wg.Done() listener, err := net.Listen("tcp", addr) if err != nil { t.Fatal(err) return } conn, _ := listener.Accept() s := NewSession(conn) data, err := s.Read() if err != nil { t.Error(err) return } t.Log(string(data)) }() // 读数据 go func() { defer wg.Done() conn, err := net.Dial("tcp", addr) if err != nil { t.Fatal(err) return } s := NewSession(conn) err = s.Write([]byte(test_data)) if err != nil { return } t.Log("写入数据成功") return }() wg.Wait() }

rpc_test.go#

Copy
package rpc import ( "encoding/gob" "fmt" "net" "testing" ) // rpc 客户端和服务端测试 // 定义一个服务端结构体 // 定义一个方法 // 通过调用rpc方法查询用户的信息 type User struct { Name string Age int } // 定义查询用户的方法 // 通过用户id查询用户数据 func queryUser(id int) (User, error) { // 造一些查询user的假数据 users := make(map[int]User) users[0] = User{"user01", 22} users[1] = User{"user02", 23} users[2] = User{"user03", 24} if u, ok := users[id]; ok { return u, nil } return User{}, fmt.Errorf("%d id not found", id) } func TestRpc(t *testing.T) { // 给gob注册类型 gob.Register(User{}) addr := ":8080" // 创建服务端 server := NewServer(addr) // 注册服务 server.Register("queryUser", queryUser) // 启动服务端 go server.Run() // 创建客户端连接 conn, err := net.Dial("tcp", addr) if err != nil { return } // 创客户端 client := NewClient(conn) // 定义函数调用原型 var query func(int) (User, error) // 客户端调用rpc client.RpcCall("queryUser", &query) // 得到返回结果 user, err := query(1) if err != nil { t.Error(err) return } fmt.Printf("%#v\n", user) }
posted @   ZhiChao&  阅读(1644)  评论(0编辑  收藏  举报
点击右上角即可分享
微信分享提示
CONTENTS