Go源码解读-rpc包
服务端代码
一个简单的rpc server示例如下:
package main
import (
"log"
"net"
"net/http"
"net/rpc"
"github.com/monoxy/rpc/common"
)
func main() {
server := rpc.NewServer()
server.Register(new(common.Embed))
lis, err := net.Listen("tcp", ":1234")
if err != nil {
log.Fatalf("list err: %v", err)
}
http.Serve(lis, server)
}
我们通过rpc.NewServer
拿到Server结构体,表示rpc server,其定义为:
type Server struct {
serviceMap sync.Map // map[string]*service
reqLock sync.Mutex // 对freeReq提供锁保护
freeReq *Request
respLock sync.Mutex // 对freeResp提供锁保护
freeResp *Response
}
提供服务的实体抽象为service(一般与一个接收器对应,接收器可能是结构体或指针):
type service struct {
name string // 服务名称,没有指定便为结构体的导出名称
rcvr reflect.Value // receiver of methods for the service
typ reflect.Type // type of the receiver
method map[string]*methodType // registered methods
}
serviceMap是一个同步sync.Map,以serviceName为key,Serice为value。通过server.Register
将相应的服务,也就是service注册到当前的server中去。Register的代码和详细注释如下:
func (server *Server) register(rcvr interface{}, name string, useName bool) error {
s := new(service)
s.typ = reflect.TypeOf(rcvr)
s.rcvr = reflect.ValueOf(rcvr)
sname := reflect.Indirect(s.rcvr).Type().Name()
if useName { // 如果指定service name,则使用name参数
sname = name
}
if sname == "" {
s := "rpc.Register: no service name for type " + s.typ.String()
log.Print(s)
return errors.New(s)
}
if !token.IsExported(sname) && !useName { // service name必须是可导出的,即首字母大写
s := "rpc.Register: type " + sname + " is not exported"
log.Print(s)
return errors.New(s)
}
s.name = sname
// 调用suitableMethods,将接收器中的导出方法注册到s.method变量中
s.method = suitableMethods(s.typ, true)
if len(s.method) == 0 {
str := ""
// 这里兼容reciever是指针,但注册的对象是结构体的情况,通过reflect.PtrTo去相应的结构体指针中查找方法并注册
method := suitableMethods(reflect.PtrTo(s.typ), false)
if len(method) != 0 {
str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)"
} else {
str = "rpc.Register: type " + sname + " has no exported methods of suitable type"
}
log.Print(str)
return errors.New(str)
}
// 当service的method注册完成后,将service整体注册到server的serviceMap变量中,注意这里使用了sync.Map
if _, dup := server.serviceMap.LoadOrStore(sname, s); dup {
return errors.New("rpc: service already defined: " + sname)
}
return nil
}
再来看suitabaleMethod
,这个函数的作用是遍历接收器中的方法,这里就不列举源码,简单来说,对每个方法判断以下条件是否满足:
- 方法名可导出,即首字母大写
- 方法有三个入参,接收器本身、请求参数(*args)、响应参数(*reply)
- 请求参数和响应参数均可导出
- 响应参数必须为指针类型
- 方法只有一个出参error
如果满足,则构造methodType对象,将其注册到service.method
中
注册后方法后,需要提供服务了,rpc包实现了tcp和http两种方式。开关的示例代码即是采用了http server。显然我们的rpc server需要提供ServerHTTP,以实现http.Handler接口。
func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// 客户端需要发起CONNECT请求,才能进行后续的rpc调用
if req.Method != "CONNECT" {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusMethodNotAllowed)
io.WriteString(w, "405 must CONNECT\n")
return
}
// 注意这里通过Hijack劫持了http协议,拿到conn连接
conn, _, err := w.(http.Hijacker).Hijack()
if err != nil {
log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
return
}
io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
server.ServeConn(conn)
}
server.ServeConn
里表示对每一条tcp连接的处理过程,首先构建gob编解码器,再将流程转到server.ServeCodec
中处理
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
buf := bufio.NewWriter(conn)
srv := &gobServerCodec{
rwc: conn,
dec: gob.NewDecoder(conn),
enc: gob.NewEncoder(buf),
encBuf: buf,
}
server.ServeCodec(srv)
}
func (server *Server) ServeCodec(codec ServerCodec) {
sending := new(sync.Mutex)
wg := new(sync.WaitGroup)
for {
// 从codec中读取每一次rpc请求的service,方法,请求参数和响应参数等信息
service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
if err != nil {
if debugLog && err != io.EOF {
log.Println("rpc:", err)
}
if !keepReading {
break
}
// send a response if we actually managed to read a header.
if req != nil {
server.sendResponse(sending, req, invalidRequest, codec, err.Error())
server.freeRequest(req)
}
continue
}
wg.Add(1)
// 将相关的参数传入到call中执行,函数内部利用反射调用接收体的对应方法,将处理结果写回到response中
go service.call(server, sending, wg, mtype, req, argv, replyv, codec)
}
// 在关闭codec之前,等全部的service.call协程退出
wg.Wait()
codec.Close()
}
server.readRequest
从codec中解码每一次的rpc请求,一定是按先解码请求头,再解码请求体的顺序,分别对应codec.ReadRequestHeader
和codec.ReadRequestBody
两个函数。
func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err error) {
service, mtype, req, keepReading, err = server.readRequestHeader(codec)
if err != nil {
if !keepReading {
return
}
// 读取request header出错,则丢弃剩余的body数据
codec.ReadRequestBody(nil)
return
}
// 解码请求参数
argIsValue := false // if true, need to indirect before calling.
if mtype.ArgType.Kind() == reflect.Ptr {
argv = reflect.New(mtype.ArgType.Elem()) // 通过relfect.New构造对应类型的指针
} else {
argv = reflect.New(mtype.ArgType) // 通过relfect.New构造对应类型的指针
argIsValue = true
}
// 此时argv的内部value为指针,通过argv。Interface()拿到指针变量
if err = codec.ReadRequestBody(argv.Interface()); err != nil {
return
}
// 这里,如果请求参数原本是值类型,也要通过argv。Elem()将指针还原为值类型
if argIsValue {
argv = argv.Elem()
}
// 返回参数一定是指针类型,所以直接通过reflect.New构造
replyv = reflect.New(mtype.ReplyType.Elem())
// 针对map和slice两类参数,提前make
switch mtype.ReplyType.Elem().Kind() {
case reflect.Map:
replyv.Elem().Set(reflect.MakeMap(mtype.ReplyType.Elem()))
case reflect.Slice:
replyv.Elem().Set(reflect.MakeSlice(mtype.ReplyType.Elem(), 0, 0))
}
return
}
func (server *Server) readRequestHeader(codec ServerCodec) (svc *service, mtype *methodType, req *Request, keepReading bool, err error) {
// 从server中取一个request指针
req = server.getRequest()
err = codec.ReadRequestHeader(req) // 将gob字节流解码到req上
if err != nil {
req = nil
if err == io.EOF || err == io.ErrUnexpectedEOF {
return
}
err = errors.New("rpc: server cannot decode request: " + err.Error())
return
}
// We read the header successfully. If we see an error now,
// we can still recover and move on to the next request.
keepReading = true
// 请求头中的ServiceName一定是"Service.Method"这种格式,注意service中可能也包含.,我们需要取最后一个.作为分隔符。
dot := strings.LastIndex(req.ServiceMethod, ".")
if dot < 0 {
err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod)
return
}
serviceName := req.ServiceMethod[:dot]
methodName := req.ServiceMethod[dot+1:]
// 凭借serviceName去serviceMap中查找对应的service
svci, ok := server.serviceMap.Load(serviceName)
if !ok {
err = errors.New("rpc: can't find service " + req.ServiceMethod)
return
}
svc = svci.(*service)
mtype = svc.method[methodName]
if mtype == nil {
err = errors.New("rpc: can't find method " + req.ServiceMethod)
}
return
}
客户端代码
简单的客户端示例如下:
package main
import (
"fmt"
"log"
"net/rpc"
"github.com/monoxy/rpc/common"
)
func main() {
client, err := rpc.DialHTTP("tcp", "127.0.0.1:1234")
if err != nil {
log.Fatal(err)
}
args := common.Args{A: 1, B: 2}
reply := new(common.Reply)
err = client.Call("Embed.Add", args, reply)
fmt.Println(reply)
fmt.Println(err)
}
针对获取rpc.Client,rpc包提供了两类方法,DialHTTP和Dial,前者以http协议的形式进行初始化,后者为tcp协议。
DialHTTP的使用前提是服务端需要提供http服务。DialHTTP中,以默认rpc路由为参数调用DialHTTPPath。可以看到DialHTTPPath内部首先对服务端发送了一次CONNET请求,如果成功,则将conn构造rpc.Client对象并返回。
我们可以发现,DialHTTP相比于Dial,实际上就多发起了一次CONNECT请求。
func DialHTTP(network, address string) (*Client, error) {
return DialHTTPPath(network, address, DefaultRPCPath)
}
// DialHTTPPath connects to an HTTP RPC server
// at the specified network address and path.
func DialHTTPPath(network, address, path string) (*Client, error) {
var err error
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n\n")
// 需要成功的HTTP响应,才能切换到rpc协议
resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
if err == nil && resp.Status == connected {
return NewClient(conn), nil
}
if err == nil {
err = errors.New("unexpected HTTP response: " + resp.Status)
}
conn.Close()
return nil, &net.OpError{
Op: "dial-http",
Net: network + " " + address,
Addr: nil,
Err: err,
}
}
NewClient的参数只有一个conn,先基于conn构造客户端gob解码器,再将其作为参数传入NewClientWithCodec。构造完成后,开启一个读协程,读取服务端返回的数据。
func NewClient(conn io.ReadWriteCloser) *Client {
encBuf := bufio.NewWriter(conn)
client := &gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(encBuf), encBuf}
return NewClientWithCodec(client)
}
// NewClientWithCodec is like NewClient but uses the specified
// codec to encode requests and decode responses.
func NewClientWithCodec(codec ClientCodec) *Client {
client := &Client{
codec: codec,
pending: make(map[uint64]*Call),
}
// 注意这里开启了读协程
go client.input()
return client
}
// 读协程
func (client *Client) input() {
var err error
var response Response
for err == nil {
response = Response{}
err = client.codec.ReadResponseHeader(&response)
if err != nil {
break
}
seq := response.Seq
client.mutex.Lock()
call := client.pending[seq]
delete(client.pending, seq)
client.mutex.Unlock()
switch {
case call == nil:
// call为空,表明在执行client.Call或者client.Go发送请求时已经出现失败,将对应的call删除,那么接下的处理就是丢弃对应的body,只获取错误信息
err = client.codec.ReadResponseBody(nil)
if err != nil {
err = errors.New("reading error body: " + err.Error())
}
case response.Error != "":
// 服务端返回了错误,将错误信息返回给call,并终止循环。那么所有后续的处于pending的call将收到同样的错误信息
call.Error = ServerError(response.Error)
err = client.codec.ReadResponseBody(nil)
if err != nil {
err = errors.New("reading error body: " + err.Error())
}
call.done()
default:
err = client.codec.ReadResponseBody(call.Reply)
if err != nil {
call.Error = errors.New("reading body " + err.Error())
}
call.done()
}
}
// 终止pending状态的call
client.reqMutex.Lock()
client.mutex.Lock()
client.shutdown = true
closing := client.closing
if err == io.EOF {
if closing {
err = ErrShutdown
} else {
err = io.ErrUnexpectedEOF
}
}
for _, call := range client.pending {
call.Error = err
call.done()
}
client.mutex.Unlock()
client.reqMutex.Unlock()
if debugLog && err != io.EOF && !closing {
log.Println("rpc: client protocol error:", err)
}
}
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 一个奇形怪状的面试题:Bean中的CHM要不要加volatile?
· [.NET]调用本地 Deepseek 模型
· 一个费力不讨好的项目,让我损失了近一半的绩效!
· PowerShell开发游戏 · 打蜜蜂
· 在鹅厂做java开发是什么体验
· 百万级群聊的设计实践
· WPF到Web的无缝过渡:英雄联盟客户端的OpenSilver迁移实战
· 永远不要相信用户的输入:从 SQL 注入攻防看输入验证的重要性