Golang 中 mock 库的实现
mock 库的地址: https://github.com/golang/mock
mock 库算是 go 项目中编写单元测试时的必备库了,它分为两个模块
- mockgen: 可以根据接口来生成单元测试代码
- gomock: 利用 mockgen 生成测试代码来实现打桩 (stub) 功能
其实之前对这个库就有一些好奇,这次趁着五一在家隔离,所以看了看 mock 库的实现。
好奇点
首先列举一下我好奇的点,之后围绕着这些点在代码中寻找答案。
- mockgen 是如何根据接口生成代码的?
- gomock 是怎样将 Expect() 中指定的参数 (ArgsIn) 与执行时接收到参数进行匹配的?
- gomock 如何做到在执行接口时的返回在 Expect() 中指定的返回值?
- gomock 是怎么判断某个方法的执行次数与预期不符的?
mockgen 是如何根据接口生成代码的?
首先把 mock 库 clone 到本地,寻找 mockgen 的入口函数 (main) ,在 mockgen/mockgen.go 文件中。
我常用的利用 mockgen 生成代码的命令为为
mockgen -destination ../examplemock/example_mock.go -package examplemock -source example.go IExample
-destination
指定了 mock 代码的目标路径 (destination)
-package
指定了 mock 代码所在的包名 (packageOut)
-source
指定了源文件名 (source)
IExample
指定了接口名 (srcInterfaces)。
main() 函数的逻辑为如下几步
- parse flag
- parse 源文件得到 model.Package 对象
- 创建目标文件,以及文件句柄
- 创建代码生成器 generator 对象,并给对象内的字段赋值
- 利用generator 对象生成代码
- 将生成的代码输出到目标文件中
func main() {
// parse flag
flag.Parse()
var pkg *model.Package
var err error
if *source != "" {
// parse 源文件得到 pkg 对象
pkg, err = sourceMode(*source)
} else {
// ...
}
// ...
// 创建目标文件,以及文件句柄
dst := os.Stdout
if len(*destination) > 0 {
if err := os.MkdirAll(filepath.Dir(*destination), os.ModePerm); err != nil {
log.Fatalf("Unable to create directory: %v", err)
}
// ...
}
outputPackageName := *packageOut
// ...
// 创建代码生成器 generator 对象,并给对象内的字段赋值
g := new(generator)
if *source != "" {
g.filename = *source
} else {
// ...
}
g.destination = *destination
// ...
// 利用 g (generator 对象) 生成代码
if err := g.Generate(pkg, outputPackageName, outputPackagePath); err != nil {
log.Fatalf("Failed generating mock: %v", err)
}
// 将生成的代码输出到文件中
if _, err := dst.Write(g.Output()); err != nil {
log.Fatalf("Failed writing to destination: %v", err)
}
}
model.Package 是如何生成的?
先看看 Package struct 的定义:
// mockgen/model/model.go
// Package is a Go package. It may be a subset.
type Package struct {
Name string
PkgPath string
Interfaces []*Interface
DotImports []string
}
// Interface is a Go interface.
type Interface struct {
Name string
Methods []*Method
}
// Method is a single method of an interface.
type Method struct {
Name string
In, Out []*Parameter
Variadic *Parameter // may be nil
}
// Parameter is an argument or return parameter of a method.
type Parameter struct {
Name string // may be empty
Type Type
}
Package 对象是通过 sourceMode() 函数完成的,它内部用到了 go 标准库 parse 包的方法,parse 包我就不去深究了,估计是一些 AST 语法树相关的知识。这里我获取到的启发就是,如果以后碰到需要分析 go 文件的需求,可以使用标准库提供的 parse 包。
import (
// ...
"go/parser"
"go/token"
// ...
)
// sourceMode generates mocks via source file.
func sourceMode(source string) (*model.Package, error) {
// ...
fs := token.NewFileSet()
file, err := parser.ParseFile(fs, source, nil, 0)
// ...
}
generator 是如何生成代码的?
generator 对象通过 Generate 这个方法来生成代码
func (g *generator) p(format string, args ...interface{}) {
fmt.Fprintf(&g.buf, g.indent+format+"\n", args...)
}
func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPackagePath string) error {
// ...
// 生成包名和 import 语句
g.p("package %v", outputPkgName)
g.p("")
g.p("import (")
g.in()
for pkgPath, pkgName := range g.packageMap {
if pkgPath == outputPackagePath {
continue
}
g.p("%v %q", pkgName, pkgPath)
}
for _, pkgPath := range pkg.DotImports {
g.p(". %q", pkgPath)
}
g.out()
g.p(")")
// 生成接口
for _, intf := range pkg.Interfaces {
if err := g.GenerateMockInterface(intf, outputPackagePath); err != nil {
return err
}
}
return nil
}
GenerateMockInterface 生成接口的 mock 代码,主要包括 mock struct 的定义以及 struct 的一些方法,基本都是一些 for range + Fprintf() 操作了。
gomock 的主流程
下面 gomock 相关的代码我们通过一个简单的单元测试来剖析:
func TestCallExample(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
e := NewMockExample(ctrl)
e.EXPECT().someMethod(gomock.Any()).Return("it works!")
fmt.Println(e.someMethod("test"))
}
Example 接口的定义为:
// Example is an interface with a non exported method
type Example interface {
someMethod(string) string
}
MockExample 与 MockExampleMockRecorder 的定义
MockExample 与 MockExampleMockRecorder 两者互相套娃
// MockExample is a mock of Example interface.
type MockExample struct {
ctrl *gomock.Controller
recorder *MockExampleMockRecorder
}
// MockExampleMockRecorder is the mock recorder for MockExample.
type MockExampleMockRecorder struct {
mock *MockExample
}
理论上可以省掉 MockExampleMockRecorder 这个 struct 的定义,直接这样定义:
type MockExample struct {
ctrl *gomock.Controller
this *MockExample
}
对这么做的原因是因为我们需要两个同名的方法,这两个方法的作用不一样,一个为打桩,一个为真正执行。
MockExample 有一个 someMethod 方法,用于真正执行。
MockExampleMockRecorder 也得有一个 someMethod 方法,用于打桩。
显然无法为 MockExample 这个 struct 定义两个同名的方法。
e.EXPECT().someMethod(gomock.Any()).Return("it works!")
这句代码是三个方法的组合,分别为 EXPECT(), someMethod() 和 Return(),咱们一个一个分析
EXPECT()
// MockExample is a mock of Example interface.
type MockExample struct {
ctrl *gomock.Controller
recorder *MockExampleMockRecorder
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockExample) EXPECT() *MockExampleMockRecorder {
return m.recorder
}
Expect() 仅仅就是返回了内部的 MockExampleMockRecorder 成员。
someMethod()
// someMethod indicates an expected call of someMethod.
func (mr *MockExampleMockRecorder) someMethod(arg0 interface{}) *gomock.Call {
// ...
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "someMethod", reflect.TypeOf((*MockExample)(nil).someMethod), arg0)
}
// RecordCallWithMethodType is called by a mock. It should not be called by user code.
func (ctrl *Controller) RecordCallWithMethodType(receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call {
// ...
call := newCall(ctrl.T, receiver, method, methodType, args...)
// ...
ctrl.expectedCalls.Add(call)
return call
}
someMethod() 调用了 Controller.RecordCallWithMethodType(),RecordCallWithMethodType() 内构造了一个 Call 对象,然后将 Call 对象放入了 expectedCalls 内
看看 Controller 对象的定义
type Controller struct {
// ...
expectedCalls *callSet
finished bool
}
type callSet struct {
// Calls that are still expected.
expected map[callSetKey][]*Call
// Calls that have been exhausted.
exhausted map[callSetKey][]*Call
}
// callSetKey is the key in the maps in callSet
type callSetKey struct {
receiver interface{}
fname string
}
Controller.expectedCalls 类型为 callSet,内部有两个 map,expected 和 exhausted,map 的 key 是一个 struct(其实我们很少用 struct 作为 map 的 key),callSetKey 的作用就是唯一识别出一个方法。
newCall() 的实现可以先放在一旁,这里我们需要记住的就是 someMethod 方法生成了一个 Call 对象,并放入了 Controller.callSet 内,同时把 Call 对象的指针给返回了。
Return()
// Call represents an expected call to a mock.
type Call struct {
// ...
// actions are called when this Call is called. Each action gets the args and
// can set the return values by returning a non-nil slice. Actions run in the
// order they are created.
actions []func([]interface{}) []interface{}
}
func (c *Call) Return(rets ...interface{}) *Call {
// ...
c.addAction(func([]interface{}) []interface{} {
return rets
})
return c
}
func (c *Call) addAction(action func([]interface{}) []interface{}) {
c.actions = append(c.actions, action)
}
所以 Return 是将 rets 封装成了一个函数,并把这个函数放入了 Call.actions 里面。
e.someMethod("test")
接下来看看 e.someMethod() 是如何执行的。
func (m *MockExample) someMethod(arg0 string) string {
// ...
ret := m.ctrl.Call(m, "someMethod", arg0)
ret0, _ := ret[0].(string)
return ret0
}
e.someMethod() 调用了 ctrl.Call(),并将 receiver,方法名和参数传递了进去。
Controller.Call()
// Call is called by a mock. It should not be called by user code.
func (ctrl *Controller) Call(receiver interface{}, method string, args ...interface{}) []interface{} {
// ...
actions := func() []func([]interface{}) []interface{} {
// ...
// 从 expectedCalls 找到之前注册的方法
expected, err := ctrl.expectedCalls.FindMatch(receiver, method, args)
// ...
actions := expected.call()
if expected.exhausted() {
// 如果方法的次数用尽 (exhausted) 了,就 Remove 掉
ctrl.expectedCalls.Remove(expected)
}
// 返回了 actions, 与上文中 Return() 里面的 action 联动
return actions
}()
var rets []interface{}
// 执行 actions 得到返回值 rets
for _, action := range actions {
if r := action(args); r != nil {
rets = r
}
}
// 返回
return rets
}
Controller.Call() 主要做了如下几步操作:
-
调用 expectedCalls.FindMatch() 方法找到之前注册的方法,与上文提到的 someMethod() 里面对方法的注册联动。FindMatch() 的实现后文会说。
-
执行 expected.call(),它的作用很简单,就是增加一次调用次数
func (c *Call) call() []func([]interface{}) []interface{} { c.numCalls++ return c.actions }
-
通过 expected.exhausted() 检查调用次数是否用尽 (感觉 exhausted 这个方法名取的挺好的,exhausted 是精疲力竭的意思,精疲力竭 ≈ 用尽)
func (c *Call) exhausted() bool { return c.numCalls >= c.maxCalls }
-
如果次数用尽了,就通过 Remove() 将 Call 对象从 callSet.expected 移动到 callSet.exhausted 中
func (cs callSet) Remove(call *Call) { key := callSetKey{call.receiver, call.method} calls := cs.expected[key] for i, c := range calls { if c == call { // maintain order for remaining calls cs.expected[key] = append(calls[:i], calls[i+1:]...) cs.exhausted[key] = append(cs.exhausted[key], call) break } } }
-
执行 action() ,得到返回值
ctrl.Finish()
ctrl.Finish() 会调用 ctrl.finish(),ctrl.finish() 会找到执行次数没有达标的方法,并返回错误。
func (ctrl *Controller) Finish() {
// ...
ctrl.finish(false, err)
}
func (ctrl *Controller) finish(cleanup bool, panicErr interface{}) {
// ...
// Check that all remaining expected calls are satisfied.
failures := ctrl.expectedCalls.Failures()
for _, call := range failures {
ctrl.T.Errorf("missing call(s) to %v", call)
}
if len(failures) != 0 {
if !cleanup {
ctrl.T.Fatalf("aborting test due to missing call(s)")
return
}
ctrl.T.Errorf("aborting test due to missing call(s)")
}
}
// Failures returns the calls that are not satisfied.
func (cs callSet) Failures() []*Call {
failures := make([]*Call, 0, len(cs.expected))
for _, calls := range cs.expected {
for _, call := range calls {
if !call.satisfied() {
failures = append(failures, call)
}
}
}
return failures
}
// Returns true if the minimum number of calls have been made.
func (c *Call) satisfied() bool {
return c.numCalls >= c.minCalls
}
此外,在调用 NewController() 函数时,也会将 ctrl.finish() 注册到 Test.Cleanup() 中。
func NewController(t TestReporter) *Controller {
// ...
ctrl := &Controller{
T: h,
expectedCalls: newCallSet(),
}
if c, ok := isCleanuper(ctrl.T); ok {
c.Cleanup(func() {
ctrl.T.Helper()
ctrl.finish(true, nil) // 这里
})
}
return ctrl
}
到此为止,我们已经知道 gomock 的主流程了,可以解答最初的疑问了。
gomock 是怎样将 Expect() 中指定的参数 (ArgsIn) 与执行时接收到参数进行匹配的?
这里就需要涉及到 FindMatch() 了。
callSet.FindMatch()
func (cs callSet) FindMatch(receiver interface{}, method string, args []interface{}) (*Call, error) {
key := callSetKey{receiver, method}
// Search through the expected calls.
expected := cs.expected[key]
var callsErrors bytes.Buffer
for _, call := range expected {
err := call.matches(args)
if err != nil {
_, _ = fmt.Fprintf(&callsErrors, "\n%v", err)
} else {
return call, nil
}
}
// If we haven't found a match then search through the exhausted calls so we
// get useful error messages.
exhausted := cs.exhausted[key]
for _, call := range exhausted {
if err := call.matches(args); err != nil {
_, _ = fmt.Fprintf(&callsErrors, "\n%v", err)
continue
}
_, _ = fmt.Fprintf(
&callsErrors, "all expected calls for method %q have been exhausted", method,
)
}
if len(expected)+len(exhausted) == 0 {
_, _ = fmt.Fprintf(&callsErrors, "there are no expected calls of the method %q for that receiver", method)
}
return nil, errors.New(callsErrors.String())
}
注释说的很明白了,FindMatch() 会先在 expected 内寻找,调用 call.matches(args) 判断参数是否匹配,如果找不到,就会返回 error,之所以额外在 exhausted 内也寻找一次,是为了得到一些有用的 error 信息。
Call.matches()
type Call struct {
// ...
args []Matcher // the args
}
type Matcher interface {
// Matches returns whether x is a match.
Matches(x interface{}) bool
// String describes what the matcher matches.
String() string
}
func (c *Call) matches(args []interface{}) error {
// IsVariadic() 用户判断方法的参数是否是可变的 (最后一个参数是 ... 的形式)
// 这里我们仅讨论非可变参数的情况
if !c.methodType.IsVariadic() {
if len(args) != len(c.args) {
return fmt.Errorf("expected call at %s has the wrong number of arguments. Got: %d, want: %d",
c.origin, len(args), len(c.args))
}
// 利用 Matches() 方法逐个比较注册的参数 (c.args) 与实际到达的参数是否一致
for i, m := range c.args {
if !m.Matches(args[i]) {
return fmt.Errorf(
"expected call at %s doesn't match the argument at index %d.\nGot: %v\nWant: %v",
c.origin, i, formatGottenArg(m, args[i]), m,
)
}
}
} else {
// ...
}
// ...
// Check that the call is not exhausted.
if c.exhausted() {
return fmt.Errorf("expected call at %s has already been called the max number of times", c.origin)
}
return nil
}
入参的注册
参数的匹配是利用 Matcher.Matches() 完成的,但我我们却很少为 Args 实现 Matcher() 接口,我们最终用的是传一个普通的类型,例如 int,string,struct 等。
其实 gomock 在 newCall() 为我们做了从普通类型到 Matcher 接口的转换。
// newCall creates a *Call. It requires the method type in order to support
// unexported methods.
func newCall(t TestHelper, receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call {
// ...
mArgs := make([]Matcher, len(args))
for i, arg := range args {
if m, ok := arg.(Matcher); ok {
// 如果参数实现了 Matcher 接口,那就不用做额外工作了
mArgs[i] = m
} else if arg == nil {
// Handle nil specially so that passing a nil interface value
// will match the typed nils of concrete args.
// Nil() 的定义见下面
mArgs[i] = Nil()
} else {
// Eq() 的定义见下面
mArgs[i] = Eq(arg)
}
}
// ...
}
一些 Matcher 接口的 Implement
nilMatcher
func Nil() Matcher { return nilMatcher{} }
type nilMatcher struct{}
func (nilMatcher) Matches(x interface{}) bool {
if x == nil {
return true
}
v := reflect.ValueOf(x)
switch v.Kind() {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
reflect.Ptr, reflect.Slice:
return v.IsNil()
}
return false
}
eqMatcher
func Eq(x interface{}) Matcher { return eqMatcher{x} }
type eqMatcher struct {
x interface{}
}
func (e eqMatcher) Matches(x interface{}) bool {
// In case, some value is nil
if e.x == nil || x == nil {
return reflect.DeepEqual(e.x, x)
}
// Check if types assignable and convert them to common type
x1Val := reflect.ValueOf(e.x)
x2Val := reflect.ValueOf(x)
if x1Val.Type().AssignableTo(x2Val.Type()) {
x1ValConverted := x1Val.Convert(x2Val.Type())
return reflect.DeepEqual(x1ValConverted.Interface(), x2Val.Interface())
}
return false
}
anyMatcher
func Any() Matcher { return anyMatcher{} }
type anyMatcher struct{}
func (anyMatcher) Matches(interface{}) bool {
return true
}
gomock 如何做到在执行接口时的返回在 Expect() 中指定的返回值?
上面提到了,通过在 return 里面注册 actions,在执行接口时 apply actions。
gomock 是怎么判断某个方法的执行次数与预期不符的?
每执行一次,执行次数就 +1,通过 exhaust() 方法来判断调用次数不会超过预期次数,通过 satisfied() 来判断调用次数不会少于预期次数。
exhaust() 会在每次执行方法时被调用。
satisfied() 会在 Finish() 内调用。
意外的收货
在阅读 gomock 的代码时,发现它会在很多地方调用 T.Helper() 这个方法。
在利用 T.log() 打印信息时,如果当前的函数栈调用了 T.Helper(),就不再打印当前函数栈的文件行信息,而是打印该函数栈的上一层信息,例子如下,下面这几个 snippet 在利用 t.Log() (t.Log() 会调用 t.log())打印信息时,输出的行号时不一样的。
package XXXTest
import "testing"
func TestXXX(t *testing.T) {
// t.Helper()
Say1(t)
}
func Say1(t *testing.T) {
// t.Helper()
Say2(t)
}
func Say2(t *testing.T) {
// t.Helper()
t.Log("hello world") // xxx_test.go:16: hello world
}
package XXXTest
import "testing"
func TestXXX(t *testing.T) {
// t.Helper()
Say1(t)
}
func Say1(t *testing.T) {
// t.Helper()
Say2(t)
}
func Say2(t *testing.T) {
t.Helper()
t.Log("hello world") // xxx_test.go:11: hello world
}
package XXXTest
import "testing"
func TestXXX(t *testing.T) {
// t.Helper()
Say1(t)
}
func Say1(t *testing.T) {
t.Helper()
Say2(t)
}
func Say2(t *testing.T) {
t.Helper()
t.Log("hello world") // xxx_test.go:7: hello world
}
package XXXTest
import "testing"
func TestXXX(t *testing.T) {
t.Helper()
Say1(t)
}
func Say1(t *testing.T) {
t.Helper()
Say2(t)
}
func Say2(t *testing.T) {
t.Helper()
t.Log("hello world") // xxx_test.go:7: hello world
}