gomock是Go官方提供的测试框架,可以使用它对代码中的那些接口类型进行mock,方便编写单元测试。
安装mockgen
go install github.com/golang/mock/mockgen@v1.6.0
构建mock
为数据库函数编写单元测试代码,可我们又不能在单元测试过程中连接真实的数据库,这个时候就需要mock DB这个接口来方便进行单元测试。
使用mockgen 工具来为生成相应的mock代码。通过执行下面的命令,我们就能在当前项目下生成一个mock文件夹,里面存放了一个mock_employees.go文件。
mockgen -source=employees.go -destination=mock/mock_employees.go -package=mock
mock_employees.go文件中的内容就是mock相关接口的代码了。
我们通常不需要编辑它,只需要在单元测试中按照规定的方式使用它们就可以了。
-source:包含要mock的接口的文件。
-destination:生成的源代码写入的文件。如果不设置此项,代码将打印到标准输出。
-package:用于生成的模拟类源代码的包名。如果不设置此项包名默认在原包名前添加mock_前缀。
-imports:在生成的源代码中使用的显式导入列表。值为foo=bar/baz形式的逗号分隔的元素列表,其中bar/baz是要导入的包,foo是要在生成的源代码中用于包的标识符。
mock文件 无需编辑
// Code generated by MockGen. DO NOT EDIT.
// Source: employees.go
// Package mock is a generated GoMock package.
package mock
import (
reflect "reflect"
domain "server/domain"
gomock "github.com/golang/mock/gomock"
)
// MockEmployeesRepository is a mock of EmployeesRepository interface.
type MockEmployeesRepository struct {
ctrl *gomock.Controller
recorder *MockEmployeesRepositoryMockRecorder
}
// MockEmployeesRepositoryMockRecorder is the mock recorder for MockEmployeesRepository.
type MockEmployeesRepositoryMockRecorder struct {
mock *MockEmployeesRepository
}
// NewMockEmployeesRepository creates a new mock instance.
func NewMockEmployeesRepository(ctrl *gomock.Controller) *MockEmployeesRepository {
mock := &MockEmployeesRepository{ctrl: ctrl}
mock.recorder = &MockEmployeesRepositoryMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockEmployeesRepository) EXPECT() *MockEmployeesRepositoryMockRecorder {
return m.recorder
}
// CreateEmployee mocks base method.
func (m *MockEmployeesRepository) CreateEmployee(arg0 *domain.Employees) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateEmployee", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// CreateEmployee indicates an expected call of CreateEmployee.
func (mr *MockEmployeesRepositoryMockRecorder) CreateEmployee(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateEmployee", reflect.TypeOf((*MockEmployeesRepository)(nil).CreateEmployee), arg0)
}
函数
func (repo *EmployeesRepository) CreateEmployee(employee *domain.Employees) (err error) {
err = repo.db.Debug().Create(&employee).Error
if err != nil {
err = fmt.Errorf("[repository.r.CreateEmployee] failed: employee = %+v, error = %w ", employee, err)
return
}
return
}
测试用例
package impl
import (
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/magiconair/properties/assert"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"server/domain"
)
func getDBMock() (*gorm.DB, sqlmock.Sqlmock, error) {
db, mock, err := sqlmock.New()
if err != nil {
return nil, nil, err
}
//defer db.Close()
gdb, err := gorm.Open(postgres.New(postgres.Config{
DriverName: "postgres",
PreferSimpleProtocol: true,
Conn: db,
}), &gorm.Config{})
if err != nil {
return nil, nil, err
}
return gdb, mock, nil
}
func TestEmployeesRepository_CreateEmployees(t *testing.T) {
type fields struct {
db *gorm.DB
}
type args struct {
employees *domain.Employees
}
db, mock, err := getDBMock()
assert.Equal(t, err, nil)
tests := []struct {
name string
fields fields
args args
invoke func(args)
wantErr bool
}{
{
name: "create Employees successful",
fields: fields{
db: db,
},
args: args{employees: &domain.Employees{
ID:0,
Code: "1",
Name:"zhangsan",
DepartmentID:1,
}},
invoke: func(args args) {
mock.ExpectQuery("INSERT INTO (.+)").WithArgs(
args.employees.Code, args.employees.Name,args.employees.DepartmentID,sqlmock.AnyArg(),sqlmock.AnyArg(),sqlmock.AnyArg()).
WillReturnRows(sqlmock.NewRows([]string{"ID"}).AddRow(1))
},
},
{
name: "create Employees failed",
fields: fields{
db: db,
},
args: args{employees: &domain.Employees{
Code: "1",
Name:"zhangsan",
DepartmentID:2,
}},
invoke: func(args args) {
mock.ExpectQuery("INSERT INTO (.+)").WithArgs(
args.employees.Code,args.employees.Name,args.employees.DepartmentID,sqlmock.AnyArg(),sqlmock.AnyArg(),sqlmock.AnyArg()).
WillReturnError(gorm.ErrInvalidData)
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &EmployeesRepository{
db: tt.fields.db,
}
tt.invoke(tt.args)
if err := c.CreateEmployee(tt.args.employees); (err != nil) != tt.wantErr {
t.Errorf("CreateEmployees() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
所有博客均为自己学习的笔记。如有错误敬请理解。