go gorm select * 优化

很多时候sql查询都不允许select * 出现要求指定列名,如果你是用gormv2 ,恭喜你可以使用QueryFields属性,如果是gormv1版本怎么样,难道要升级gormV2吗,这里提供种反射的实现,可能不是最优解,但只是一个方案。

首先mysq建一个表

CREATE TABLE `test` (
  `id` BIGINT(20) NOT NULL,
  `name` VARCHAR(5) DEFAULT NULL,
  `age` INT(11) DEFAULT NULL,
  PRIMARY KEY (`id`)
) ENGINE=INNODB DEFAULT CHARSET=utf8mb4 

go的代码如下:

package main

import (
    "fmt"
    "gorm.io/driver/mysql"
    "gorm.io/gorm/logger"
    "gorm.io/gorm/schema"
    "reflect"
    "strings"

    _ "github.com/go-sql-driver/mysql"
    "gorm.io/gorm"
)

func main() {
    dns := "root:root@tcp(192.168.100.30:3306)/demo?charset=utf8&parseTime=True&loc=Local"
    config := &gorm.Config{
        NamingStrategy: schema.NamingStrategy{
            SingularTable: true,
        },
        //QueryFields: true,
        Logger: logger.Default.LogMode(logger.Info),
    }

    db, err := gorm.Open(mysql.Open(dns), config)
    if err != nil {
        fmt.Println(fmt.Sprintf("Open err:%v", err))
    }

    var ret []*Test
    //一般查询
    fmt.Println("一般查询")
    err = db.Table("test").Where("id>1").Find(&ret).Debug().Error
    if err != nil {
        fmt.Println(fmt.Sprintf("select err:%v", err))
    }

    //通过反射指定列
    fmt.Println("通过反射指定列")
    err = db.Table("test").Where("id>1").Select(GetAllFields(new(Test))).Find(&ret).Debug().Error
    if err != nil {
        fmt.Println(fmt.Sprintf("select err:%v", err))
    }

    //通过QueryFields 指定列
    fmt.Println("通过QueryFields 指定列")
    tx := db.Table("test")
    tx.QueryFields = true
    err = tx.Where("id>1").Find(&ret).Debug().Error
    if err != nil {
        fmt.Println(fmt.Sprintf("select err:%v", err))
    }
}

type Test struct {
    ID   int64  `gorm:"type:bigint(20);column:id;primary_key"`
    Name string `gorm:"type:varchar(5);column:name"`
    Age  int    `gorm:"type:int(11);column:age"`
}

func GetAllFields(info interface{}) string {
    tagName := strings.ToUpper("column")
    var arr []string
    el := reflect.TypeOf(info).Elem()
    for i := 0; i < el.NumField(); i++ {
        structTag := el.Field(i).Tag
        tags := parseTagSetting(structTag)
        if columnName, ok := tags[tagName]; ok && columnName != "-" {
            arr = append(arr, columnName)
        }
    }
    sqlFields := "`" + strings.Join(arr, "`,`") + "`"
    return sqlFields
}

func parseTagSetting(tags reflect.StructTag) map[string]string {
    setting := map[string]string{}
    for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
        if str == "" {
            continue
        }
        tags := strings.Split(str, ";")
        for _, value := range tags {
            v := strings.Split(value, ":")
            k := strings.TrimSpace(strings.ToUpper(v[0]))
            if len(v) >= 2 {
                setting[k] = strings.Join(v[1:], ":")
            } else {
                setting[k] = k
            }
        }
    }
    return setting
}

运行效果:

 

posted on 2022-09-17 14:02  dz45693  阅读(881)  评论(0编辑  收藏  举报

导航