Go中反射的利用(通用sql封装,字段拼接对应)

在后端接口开发中,往往需要针对每一张表编写相应的增删改查SQL方法。比如查询某张表的数据:

// GetHostsByModel runs a SELECT whose WHERE clause holds one equality filter
// per set field of model: numeric fields count as set when non-zero, string
// fields when they still have content after trimming spaces. A nil model
// applies no filters. When startNum >= 0 and pagesize > 0 a "LIMIT ?,?" clause
// pages the result. It returns the scanned rows and any ExecSelect error.
//
// NOTE(review): the SQL uses the literal table name `table`, which looks like
// a redacted placeholder (and is a reserved word in MySQL) — confirm the real
// table name.
func (r *DomainDao) GetHostsByModel(model *config.HostsModel, startNum int, pagesize int) ([]*config.HostsModel, error) {
	db := r.Db.GetDB()
	if db == nil {
		return nil, errors.New("db is nil")
	}

	rows := make([]*config.HostsModel, 0)
	params := make([]interface{}, 0)
	query := "SELECT * FROM table WHERE 1=1 "

	// filter appends one equality clause together with its bind value.
	filter := func(clause string, value interface{}) {
		query += clause
		params = append(params, value)
	}
	// hasText reports whether s has anything left after trimming spaces.
	hasText := func(s string) bool {
		return strings.Trim(s, " ") != ""
	}

	if model != nil {
		if model.Gid != 0 {
			filter(" AND gid=? ", model.Gid)
		}
		if hasText(model.CdnDomain) {
			filter(" AND cdn_domain=? ", model.CdnDomain)
		}
		if hasText(model.CdnName) {
			filter(" AND cdn_name=? ", model.CdnName)
		}
		if model.CdnType != 0 {
			filter(" AND cdn_type=? ", model.CdnType)
		}
		if model.Master != 0 {
			filter(" AND master=? ", model.Master)
		}
		if model.Mode != 0 {
			filter(" AND mode=? ", model.Mode)
		}
		if model.AuthOutTime != 0 {
			filter(" AND auth_out_time=? ", model.AuthOutTime)
		}
		if model.Enable != 0 {
			filter(" AND enable=? ", model.Enable)
		}
		if hasText(model.AuthKey) {
			filter(" AND auth_key=? ", model.AuthKey)
		}
		if hasText(model.Reserve) {
			filter(" AND reserve=? ", model.Reserve)
		}
	}

	if startNum >= 0 && pagesize > 0 {
		query += " LIMIT ?,?  "
		params = append(params, startNum, pagesize)
	}

	var err error
	if len(params) == 0 {
		_, err = db.ExecSelect(query, &rows)
	} else {
		_, err = db.ExecSelect(query, &rows, params...)
	}
	return rows, err
}

出于封装的目的,我们常常动态拼接查询条件,以满足不同的查询业务。但每张表的字段各不相同,如果每张表都像这样手写一遍,就显得冗余且低效,所以需要对其进一步封装。

通过反射进行封装:

// GetSelectTableSql builds a generic paged "SELECT *" query for tableName.
//
// model is a struct, or a pointer to one, whose fields act as equality
// filters. Every exported field whose gorm tag declares "column:<name>" and
// holds a set value (a non-empty string, a non-zero integer, or a time.Time
// that tool.CheckIsDefaultTime does not treat as default) adds
// " AND <name> = ? ". A delete_at time.Time field holding a default value adds
// " AND delete_at IS NULL " instead. Fields without a column option,
// unexported fields, and fields of other types are ignored. A nil model or a
// nil pointer adds no filters; any other non-struct model panics.
//
// When startNum >= 0 and pagesize > 0, " LIMIT ?,?  " is appended with
// startNum and pagesize as its arguments.
//
// It returns the SQL text and the bind arguments in placeholder order.
// tableName and the column names are concatenated verbatim, so they must be
// trusted identifiers; only the values are parameterized.
func (db *MySQL) GetSelectTableSql(model interface{}, tableName string, startNum int, pagesize int) (string, []interface{}) {
	where, args := modelWhereConditions(model)
	strSql := " SELECT * FROM " + tableName + " WHERE 1=1 " + where
	if startNum >= 0 && pagesize > 0 {
		strSql += " LIMIT ?,?  "
		args = append(args, startNum, pagesize)
	}
	return strSql, args
}

// GetSelectTableCountSql builds the "SELECT COUNT(1)" counterpart of
// GetSelectTableSql for tableName. It applies the same filters derived from
// model but no paging, and returns the SQL text and its bind arguments.
func (db *MySQL) GetSelectTableCountSql(model interface{}, tableName string) (string, []interface{}) {
	where, args := modelWhereConditions(model)
	return " SELECT COUNT(1) FROM " + tableName + " WHERE 1=1 " + where, args
}

// InsertTable inserts one row built from model into tableName.
//
// model may be a struct or a pointer to one. Only set fields (same rules as
// GetSelectTableSql) are written; unset fields are omitted from the column
// list. It fails with "model is nil" for a nil model or nil pointer, with an
// error for a non-struct model, and with "args is nil" when no field is set.
// The int64 is whatever db.Exec returns (presumably the new row's gid —
// confirm against MySQL.Exec); -1 is returned on validation errors.
func (db *MySQL) InsertTable(model interface{}, tableName string) (int64, error) {
	v, err := modelStructValue(model)
	if err != nil {
		return -1, err
	}

	columns := make([]string, 0, v.NumField())
	args := make([]interface{}, 0, v.NumField())
	forEachModelColumn(v, func(column string, field reflect.Value) {
		if modelFieldIsSet(field) {
			columns = append(columns, column)
			args = append(args, field.Interface())
		}
	})
	if len(args) < 1 {
		return -1, errors.New("args is nil")
	}

	placeholders := strings.TrimSuffix(strings.Repeat("?,", len(args)), ",")
	insertSql := "insert " + tableName + " (" + strings.Join(columns, ",") + ")  values (" + placeholders + ")"
	result, err := db.Exec(insertSql, args...)
	return result, err
}

// UpdateTableByColumn updates rows of tableName from model.
//
// Every set field of model (same rules as GetSelectTableSql) becomes a
// "<column>=?" assignment, and a default-valued delete_at time field becomes
// "delete_at=NULL". The WHERE clause is mapcolumn rendered as "<key>=?"
// conditions joined with AND. When mapcolumn is nil or empty, rows are matched
// by the model's gid column instead, and the call fails if that gid is zero.
// An empty map is treated like nil because a bare "where 1=1" would update
// every row in the table.
//
// tableName and the mapcolumn keys are concatenated verbatim and must be
// trusted identifiers. Map iteration order is unspecified, so the order of the
// WHERE conditions may vary between calls; each value stays paired with its
// own placeholder. The int64 is whatever db.Exec returns (confirm its meaning
// against MySQL.Exec); -1 is returned on validation errors.
func (db *MySQL) UpdateTableByColumn(model interface{}, tableName string, mapcolumn map[string]interface{}) (int64, error) {
	v, err := modelStructValue(model)
	if err != nil {
		return -1, err
	}

	sets := make([]string, 0, v.NumField())
	args := make([]interface{}, 0, v.NumField())
	var gid int64
	forEachModelColumn(v, func(column string, field reflect.Value) {
		if column == "gid" {
			// Accept any integer kind rather than asserting int64, which
			// would panic for int/int32/uint gid fields.
			switch field.Kind() {
			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
				gid = field.Int()
			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
				gid = int64(field.Uint())
			}
		}
		switch {
		case modelFieldIsSet(field):
			sets = append(sets, column+"=?")
			args = append(args, field.Interface())
		case column == "delete_at" && modelFieldIsDefaultTime(field):
			sets = append(sets, "delete_at=NULL")
		}
	})
	if len(args) < 1 {
		return -1, errors.New("args is nil")
	}

	// Default to matching by gid; never run an UPDATE without a condition.
	if len(mapcolumn) == 0 {
		if gid == 0 {
			return -1, errors.New("update where is nil")
		}
		mapcolumn = map[string]interface{}{"gid": gid}
	}

	var sb strings.Builder
	sb.WriteString("update " + tableName + " SET " + strings.Join(sets, ",") + " where 1=1")
	for key, val := range mapcolumn {
		sb.WriteString(" AND " + key + "=? ")
		args = append(args, val)
	}

	result, err := db.Exec(sb.String(), args...)
	return result, err
}

// modelTimeType is the reflect.Type of time.Time; fields of exactly this type
// are treated as timestamps by the generic SQL builders.
var modelTimeType = reflect.TypeOf(time.Time{})

// modelStructValue dereferences model (a struct or a pointer to one) and
// returns the struct value. It returns an error when model is nil, a nil
// pointer, or not a struct.
func modelStructValue(model interface{}) (reflect.Value, error) {
	v := reflect.Indirect(reflect.ValueOf(model))
	if !v.IsValid() {
		return reflect.Value{}, errors.New("model is nil")
	}
	if v.Kind() != reflect.Struct {
		return reflect.Value{}, errors.New("model must be a struct or a pointer to a struct")
	}
	return v, nil
}

// modelColumnName extracts the column name from a gorm struct tag such as
// "column:gid;primary_key". The "column:" option may appear anywhere among the
// ';'-separated options. It returns "" when there is no column option (this
// includes an empty tag and "-"), so such fields are skipped instead of
// producing malformed SQL.
func modelColumnName(tag string) string {
	for _, opt := range strings.Split(tag, ";") {
		opt = strings.TrimSpace(opt)
		if strings.HasPrefix(opt, "column:") {
			return strings.TrimSpace(strings.TrimPrefix(opt, "column:"))
		}
	}
	return ""
}

// forEachModelColumn calls fn, in field declaration order, for every exported
// field of the struct value v whose gorm tag declares a column name.
// Unexported fields are skipped because reflect cannot read them through
// Interface().
func forEachModelColumn(v reflect.Value, fn func(column string, field reflect.Value)) {
	t := v.Type()
	for i := 0; i < t.NumField(); i++ {
		sf := t.Field(i)
		if sf.PkgPath != "" {
			continue
		}
		if column := modelColumnName(sf.Tag.Get("gorm")); column != "" {
			fn(column, v.Field(i))
		}
	}
}

// modelFieldIsSet reports whether field holds a value worth writing to SQL:
// a non-empty string, a non-zero signed or unsigned integer, or a time.Time
// that tool.CheckIsDefaultTime does not treat as default. The check is by
// kind, so named types such as `type Status int` are included. Fields of any
// other type are never considered set.
func modelFieldIsSet(field reflect.Value) bool {
	switch field.Kind() {
	case reflect.String:
		return field.Len() > 0
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return field.Int() != 0
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return field.Uint() != 0
	}
	if field.Type() == modelTimeType {
		return !tool.CheckIsDefaultTime(field.Interface().(time.Time))
	}
	return false
}

// modelFieldIsDefaultTime reports whether field is a time.Time that
// tool.CheckIsDefaultTime treats as a default (unset) value.
func modelFieldIsDefaultTime(field reflect.Value) bool {
	return field.Type() == modelTimeType && tool.CheckIsDefaultTime(field.Interface().(time.Time))
}

// modelWhereConditions renders one " AND <column> = ? " condition per set
// column of model, in field order, and returns the text with its bind
// arguments. A default-valued delete_at time field renders
// " AND delete_at IS NULL " (presumably to exclude soft-deleted rows). A nil
// model or nil pointer yields no conditions; any other non-struct model
// panics. The returned argument slice is never nil.
func modelWhereConditions(model interface{}) (string, []interface{}) {
	args := make([]interface{}, 0)
	v := reflect.Indirect(reflect.ValueOf(model))
	if !v.IsValid() {
		return "", args
	}

	var sb strings.Builder
	forEachModelColumn(v, func(column string, field reflect.Value) {
		switch {
		case modelFieldIsSet(field):
			sb.WriteString(" AND " + column + " = ? ")
			args = append(args, field.Interface())
		case column == "delete_at" && modelFieldIsDefaultTime(field):
			sb.WriteString(" AND delete_at IS NULL ")
		}
	})
	return sb.String(), args
}

  

 

调用:

// InsertRatetemplate inserts model into the rate-template table
// (ratetemplateTableName) through the generic InsertTable helper and returns
// its result. It returns -1 and an error when the DAO has no database handle
// or model is nil.
func (r *RatetemplateDao) InsertRatetemplate(model *config.RatetemplateModel) (int64, error) {
	db := r.Db.GetDB()
	if db == nil {
		return -1, errors.New(dao.DbErrMsg)
	}
	if model == nil {
		return -1, errors.New(daoErrMsg)
	}

	// Hand the helper the struct value: it reflects over the struct's fields.
	result, err := db.InsertTable(*model, ratetemplateTableName)
	return result, err
}

// UpdateRatetemplateById updates the rate-template row identified by the
// model's gid, writing its set fields through the generic
// UpdateTableByColumn helper (nil mapcolumn selects the gid default). It
// returns -1 and an error when the DAO has no database handle or model is nil.
func (r *RatetemplateDao) UpdateRatetemplateById(model *config.RatetemplateModel) (int64, error) {
	db := r.Db.GetDB()
	if db == nil {
		return -1, errors.New(dao.DbErrMsg)
	}
	if model == nil {
		return -1, errors.New(daoErrMsg)
	}

	// Hand the helper the struct value: it reflects over the struct's fields.
	result, err := db.UpdateTableByColumn(*model, ratetemplateTableName, nil)
	return result, err
}

// GetRatetemplateByModel lists rate-template rows matching the set fields of
// model, paged by startNum/pagesize as GetSelectTableSql decides. A nil model
// means "no filter". It returns the scanned rows and any ExecSelect error.
func (r *RatetemplateDao) GetRatetemplateByModel(model *config.RatetemplateModel, startNum int, pagesize int) ([]*config.RatetemplateModel, error) {
	db := r.Db.GetDB()
	if db == nil {
		return nil, errors.New(dao.DbErrMsg)
	}

	// The SQL builder reflects over a struct value, so dereference the
	// pointer; leaving filter nil yields an unfiltered query.
	var filter interface{}
	if model != nil {
		filter = *model
	}

	query, params := db.GetSelectTableSql(filter, ratetemplateTableName, startNum, pagesize)
	rows := make([]*config.RatetemplateModel, 0)

	var err error
	if len(params) == 0 {
		_, err = db.ExecSelect(query, &rows)
	} else {
		_, err = db.ExecSelect(query, &rows, params...)
	}
	return rows, err
}

传入的参数结构体需要在 gorm tag 中通过 `column:` 定义对应的数据库列名,例如:

// RatetemplateModel describes a rate-template row for the generic
// reflection-based SQL helpers, which take each field's database column name
// from the "column:" option of its gorm tag. Field declaration order is the
// order in which those helpers emit columns.
type RatetemplateModel struct {
	// Gid is the row's unique identifier (tag: primary_key;auto_increment).
	Gid          int64     `sql:"Gid" gorm:"column:gid;primary_key;auto_increment;comment:'唯一标识';type:bigint(20)" json:"gid"`
	// CdnGid is the gid of the associated pull-stream domain (tag comment: '拉流域名gid').
	CdnGid       int64     `sql:"CdnGid" gorm:"column:cdn_gid;not null;comment:'拉流域名gid';type:bigint(20)" json:"cdn_gid"`
	// AppName is the business-line name, e.g. "live" (tag comment: '业务线名称(live)').
	AppName      string    `sql:"AppName" gorm:"column:app_name;not null;comment:'业务线名称(live)';type:varchar(32)" json:"app_name"`
}

这样就不必再为每张表编写常规的增删改查SQL语句了;当存在大量类似的单表业务时,还可以编写一个代码生成工具,根据数据库表结构自动生成这些代码。

ps:反射时传入的应是结构体值,而不是指针;如果外层业务传入的是指针类型,需要先解引用转换为值类型,关键代码为:
`intoModel = *model`。以上内容仅用于记录和理解反射逻辑,部分具体的SQL执行方法(如 `ExecSelect`、`Exec`)未贴出。

 

posted @ 2021-09-07 20:12  MrZhaoLin  阅读(396)  评论(0)  编辑  收藏  举报