用 Go 反射实现实体映射
源码地址:https://github.com/marshhu/ma-tools
项目中经常要用到实体映射。以前做 .NET 时用 AutoMapper 做实体映射,感觉挺方便的。
然而最近做的 Go 项目还是用比较原始的手动赋值,写起来挺痛苦的。实在受不了,于是动手写了个简单的实体映射工具方法,代码如下:
// MapTo copies the data held by src into the value that dst points to,
// matching struct fields by name (case-insensitively) and recursing into
// nested structs, slices, arrays and maps.
//
// Destination struct fields honour these tags:
//
//	ignore:"true"        skip the field
//	mappingField:"Name"  read from the source field Name instead
//	timeFormat:"layout"  layout used when a time.Time is mapped to a string
//
// src may be a value or a (possibly multi-level) pointer; a nil src, or a nil
// pointer, resets *dst to its zero value. dst must be a non-nil pointer.
// The result is built in a fresh value and only stored into *dst when the
// mapping succeeds, so *dst is unchanged on error.
func MapTo(src interface{}, dst interface{}) error {
	dstValue := reflect.ValueOf(dst)
	if dstValue.Kind() != reflect.Ptr || dstValue.IsNil() {
		return errors.New("dst is not a pointer or is nil")
	}

	// Kind() (unlike Type()) is safe on the zero Value produced by a nil src
	// or a nil pointer's Elem(), so this loop cannot panic.
	srcValue := reflect.ValueOf(src)
	for srcValue.Kind() == reflect.Ptr {
		srcValue = srcValue.Elem()
	}

	target := dstValue.Elem()
	item := reflect.New(target.Type())
	if err := setValue(srcValue, item); err != nil {
		return err
	}
	// target is the Elem of a non-nil pointer, so it is always settable.
	target.Set(item.Elem())
	return nil
}
// timeType is the reflect.Type of time.Time. Its fields are unexported, so a
// time.Time must be copied as a whole value rather than mapped field by field.
var timeType = reflect.TypeOf(time.Time{})

// setValue maps srcValue into the value that dstValue points to.
//
// dstValue must be a non-nil pointer to a zero value (every caller passes a
// fresh reflect.New result or the address of a freshly zeroed element).
// Structs are mapped field by field (setStructFields), slices and arrays
// element by element, maps entry by entry, and any other value is copied when
// its type is assignable, or same-kind convertible, to the destination type;
// otherwise the destination is left untouched.
//
// An error is returned only when the top-level shapes are incompatible (for
// example a struct source with a non-struct destination); in that case the
// destination has not been written. Mismatches inside nested values are
// best-effort: the affected field or element keeps its zero value.
func setValue(srcValue reflect.Value, dstValue reflect.Value) error {
	if dstValue.Kind() != reflect.Ptr || dstValue.IsNil() {
		return errors.New("dst is not a pointer or is nil")
	}
	target := dstValue.Elem()

	// Copy time.Time whole when the destination accepts it; mapping its
	// unexported fields one by one would silently yield a zero time.
	if srcValue.Kind() == reflect.Struct && srcValue.Type() == timeType {
		if converted, ok := convertValue(srcValue, target.Type()); ok {
			target.Set(converted)
			return nil
		}
	}

	switch srcValue.Kind() {
	case reflect.Struct:
		if target.Kind() != reflect.Struct {
			return errors.New("dst type should be a struct pointer")
		}
		setStructFields(srcValue, target)
	case reflect.Slice:
		if target.Kind() != reflect.Slice {
			return errors.New("dst type should be a slice")
		}
		setSliceElems(srcValue, target)
	case reflect.Array:
		switch target.Kind() {
		case reflect.Slice:
			setSliceElems(srcValue, target)
		case reflect.Array:
			if target.Len() < srcValue.Len() {
				return errors.New("dst array length should be greater than or equal to src length")
			}
			setElems(srcValue, target)
		default:
			return errors.New("dst type should be a slice or a array")
		}
	case reflect.Map:
		if target.Kind() != reflect.Map {
			return errors.New("dst type should be a map")
		}
		return setMapEntries(srcValue, target)
	default:
		if converted, ok := convertValue(srcValue, target.Type()); ok {
			target.Set(converted)
		}
	}
	return nil
}

// setStructFields fills every settable field of dst from the matching field
// of src, located by findValueByName. Fields tagged `ignore:"true"`, fields
// with no readable match, and fields whose nested mapping fails are skipped.
// A time.Time source mapped onto a string-kinded field is rendered with the
// field's `timeFormat` tag, defaulting to "2006-01-02 15:04:05".
func setStructFields(src, dst reflect.Value) {
	dstType := dst.Type()
	for i := 0; i < dst.NumField(); i++ {
		fieldInfo := dstType.Field(i)
		field := dst.Field(i)
		if fieldInfo.Tag.Get("ignore") == "true" || !field.CanSet() {
			continue
		}
		value := findValueByName(src, fieldInfo)
		// Skip missing fields and unexported source fields: reflect refuses
		// to copy values read through unexported fields (Set would panic).
		if !value.IsValid() || !value.CanInterface() {
			continue
		}
		if value.Type() == timeType && field.Kind() == reflect.String {
			layout := fieldInfo.Tag.Get("timeFormat")
			if layout == "" {
				layout = "2006-01-02 15:04:05"
			}
			// SetString also works for named string types.
			field.SetString(value.Interface().(time.Time).Format(layout))
			continue
		}
		// Best-effort: on a nested shape mismatch setValue writes nothing,
		// so the field keeps its zero value.
		_ = setValue(value, field.Addr())
	}
}

// setSliceElems maps every element of src (a slice or an array) into a new
// slice of dst's type and stores it in dst. An empty src leaves dst untouched,
// so a nil or empty source produces a nil destination slice.
func setSliceElems(src, dst reflect.Value) {
	n := src.Len()
	if n == 0 {
		return
	}
	out := reflect.MakeSlice(dst.Type(), n, n)
	setElems(src, out)
	dst.Set(out)
}

// setElems maps src[i] into dst[i] for every index of src. dst must be
// addressable (a slice, or an array reached through a pointer) and at least
// as long as src. Elements that cannot be mapped keep their zero value.
func setElems(src, dst reflect.Value) {
	for i := 0; i < src.Len(); i++ {
		_ = setValue(src.Index(i), dst.Index(i).Addr())
	}
}

// setMapEntries builds a new map of dst's type holding every entry of src,
// each value mapped recursively, and stores it in dst. Keys must be assignable
// or same-kind convertible to dst's key type; otherwise an error is returned
// and dst is left untouched. Values that cannot be mapped are stored as the
// zero value, so the entry count always matches src.
func setMapEntries(src, dst reflect.Value) error {
	dstType := dst.Type()
	out := reflect.MakeMapWithSize(dstType, src.Len())
	for _, key := range src.MapKeys() {
		dstKey, ok := convertValue(key, dstType.Key())
		if !ok {
			return errors.New("dst map key type is incompatible with src")
		}
		item := reflect.New(dstType.Elem())
		_ = setValue(src.MapIndex(key), item)
		out.SetMapIndex(dstKey, item.Elem())
	}
	dst.Set(out)
	return nil
}

// convertValue returns src as a value of dstType when it can be stored there
// safely: directly assignable values are returned as-is, and values of the
// same kind (e.g. int -> a named int type) are converted. It reports false for
// invalid values, values read through unexported fields, and other types.
func convertValue(src reflect.Value, dstType reflect.Type) (reflect.Value, bool) {
	if !src.IsValid() || !src.CanInterface() {
		return reflect.Value{}, false
	}
	if src.Type().AssignableTo(dstType) {
		return src, true
	}
	if src.Kind() == dstType.Kind() && src.Type().ConvertibleTo(dstType) {
		return src.Convert(dstType), true
	}
	return reflect.Value{}, false
}
// findValueByName returns the field of the struct srcValue that should feed
// the destination field described by fieldInfo, or the zero Value if there is
// none.
//
// The lookup name is the `mappingField` tag when present (with no fallback to
// the field's own name), otherwise fieldInfo.Name. Names are compared
// case-insensitively and promoted fields of embedded structs are searched
// breadth-first. Following reflect's FieldByNameFunc rules, several matches at
// the same depth cancel out and yield no match. A match reached through a nil
// embedded pointer also yields the zero Value instead of panicking.
func findValueByName(srcValue reflect.Value, fieldInfo reflect.StructField) reflect.Value {
	name := fieldInfo.Tag.Get("mappingField")
	if name == "" {
		name = fieldInfo.Name
	}
	field, ok := srcValue.Type().FieldByNameFunc(func(s string) bool {
		return strings.EqualFold(s, name)
	})
	if !ok {
		return reflect.Value{}
	}
	// Walk the index path by hand: Value.FieldByIndex panics when it must
	// step through a nil embedded pointer.
	v := srcValue
	for _, idx := range field.Index {
		if v.Kind() == reflect.Ptr {
			if v.IsNil() {
				return reflect.Value{}
			}
			v = v.Elem()
		}
		v = v.Field(idx)
	}
	return v
}
测试代码:
// Fixture types for the mapping tests: User (with its embedded Model,
// Address and Hobby) is the source shape, and UserDto (with AddressB and
// HobbyB) is the destination shape. Field order is significant because
// GetData builds User, Model and Address with positional literals.
type (
	// Address is a source-side address with country and city.
	Address struct {
		Country, City string
	}

	// AddressB is the destination counterpart of Address; it has no City
	// field, so that part of the source is not carried over.
	AddressB struct {
		Country string
	}

	// Hobby is a source-side hobby with a skill level.
	Hobby struct {
		Name  string
		Level int
	}

	// HobbyB is the destination counterpart of Hobby, without Level.
	HobbyB struct {
		Name string
	}

	// Model holds the identifier and timestamps embedded in User.
	Model struct {
		ID                   int
		CreatedAt, UpdatedAt time.Time
	}

	// User is the source entity used by the mapping tests.
	User struct {
		Model
		Name    string
		Gender  int
		Tel     string
		Address Address
		Hobbies []Hobby
	}

	// UserDto is the destination entity; its tags exercise the ignore,
	// mappingField and timeFormat options.
	UserDto struct {
		Id        int
		Name      string `ignore:"true"` // skipped during mapping
		Avatar    string
		Phone     string `mappingField:"Tel"` // read from the Tel field
		Address   AddressB
		Hobbies   []HobbyB
		CreatedAt time.Time
		UpdatedAt string `timeFormat:"2006-01-02"` // date-only rendering
	}
)
// GetData returns four sample users with IDs 1 through 4. They share the same
// address and hobby list, and their CreatedAt/UpdatedAt come from time.Now().
func GetData() []User {
	samples := []struct {
		name   string
		gender int
		tel    string
	}{
		{"小明", 1, "18800188001"},
		{"小红", 0, "18800188002"},
		{"小李", 0, "18800188003"},
		{"小张", 1, "18800188004"},
	}
	users := make([]User, 0, len(samples))
	for i, s := range samples {
		users = append(users, User{
			Model:   Model{ID: i + 1, CreatedAt: time.Now(), UpdatedAt: time.Now()},
			Name:    s.name,
			Gender:  s.gender,
			Tel:     s.tel,
			Address: Address{Country: "中国", City: "深圳"},
			// A fresh slice per user, so the users never share backing arrays.
			Hobbies: []Hobby{{Name: "游泳", Level: 2}, {Name: "篮球", Level: 4}},
		})
	}
	return users
}
// CheckResult reports whether userDto carries the values expected after
// mapping user. It compares the ID, the phone (taken from Tel), the address
// country, CreatedAt to the second, UpdatedAt as a "2006-01-02" string, and
// the hobby names in order. Name is intentionally not compared because
// UserDto.Name is tagged ignore:"true".
func CheckResult(user *User, userDto *UserDto) bool {
	const fullLayout = "2006-01-02 15:04:05"
	switch {
	case userDto.Id != user.ID,
		userDto.Phone != user.Tel,
		userDto.Address.Country != user.Address.Country,
		userDto.CreatedAt.Format(fullLayout) != user.CreatedAt.Format(fullLayout),
		userDto.UpdatedAt != user.UpdatedAt.Format("2006-01-02"),
		len(userDto.Hobbies) != len(user.Hobbies):
		return false
	}
	for i := range userDto.Hobbies {
		if userDto.Hobbies[i].Name != user.Hobbies[i].Name {
			return false
		}
	}
	return true
}
// TestMapToStruct maps a single User value onto a *UserDto and verifies the
// result with CheckResult.
func TestMapToStruct(t *testing.T) {
	user := GetData()[0]
	userDto := &UserDto{}
	if err := MapTo(user, userDto); err != nil {
		t.Fatalf("MapTo(User, *UserDto) returned error: %v", err)
	}
	if !CheckResult(&user, userDto) {
		t.Fatalf("mapped dto %+v does not match user %+v", *userDto, user)
	}
}
// TestMapToSlice maps a []User onto a *[]UserDto and verifies the length and
// every element.
func TestMapToSlice(t *testing.T) {
	data := GetData()
	var userDtos []UserDto
	if err := MapTo(data, &userDtos); err != nil {
		t.Fatalf("MapTo([]User, *[]UserDto) returned error: %v", err)
	}
	if len(userDtos) != len(data) {
		t.Fatalf("got %d dtos, want %d", len(userDtos), len(data))
	}
	for i := range data {
		if !CheckResult(&data[i], &userDtos[i]) {
			t.Fatalf("element %d: dto %+v does not match user %+v", i, userDtos[i], data[i])
		}
	}
}
func TestMapToArray(t *testing.T) {
data := GetData()
var users [4]User
users[0] = data[0]
users[1] = data[1]
users[2] = data[2]
users[3] = data[3]
var userDtos [4]UserDto
err := MapTo(users, &userDtos)
if err != nil {
t.FailNow()
}
if len(userDtos) != len(users) {
t.FailNow()
}
for i := 0; i < len(users); i++ {
userDto := userDtos[i]
user := users[i]
if !CheckResult(&user, &userDto) {
t.FailNow()
}
}
}
// TestMapToMap maps a map[string]User onto a *map[string]UserDto and verifies
// that every key is present with a correctly mapped value.
func TestMapToMap(t *testing.T) {
	data := GetData()
	users := map[string]User{
		"小明": data[0],
		"小红": data[1],
		"小李": data[2],
		"小张": data[3],
	}
	userDtos := make(map[string]UserDto)
	if err := MapTo(users, &userDtos); err != nil {
		t.Fatalf("MapTo(map[string]User, *map[string]UserDto) returned error: %v", err)
	}
	if len(userDtos) != len(users) {
		t.Fatalf("got %d dtos, want %d", len(userDtos), len(users))
	}
	for key, user := range users {
		userDto, ok := userDtos[key]
		if !ok {
			t.Fatalf("key %q missing from mapped result", key)
		}
		if !CheckResult(&user, &userDto) {
			t.Fatalf("key %q: dto %+v does not match user %+v", key, userDto, user)
		}
	}
}
测试结果如下: