xorm + PostgreSQL 实现事务与 in 查询

model包

package model

import (
    "github.com/xormplus/xorm"
    "whw_go_scripts/a_xorm_tests/utils"
)

// ClassTableName is the name of the table that Class maps to.
const ClassTableName = "class"

// Class is one row of the class table; it exists to test join queries
// against the student table.
type Class struct {
    // CId is the primary key, stored in column cid.
    CId  string `json:"cid" xorm:"varchar(255) pk 'cid' comment('班级cid')"`
    // Name is the class name, stored in column name.
    Name string `json:"name" xorm:"varchar(25) notnull 'name' comment('班级名称')"`
}

// TableName reports the table name xorm should use for *Class.
func (*Class) TableName() string { return ClassTableName }

// InsertClass inserts c into the class table through the shared engine
// directly (no explicit session). It returns the affected-row count reported
// by xorm together with any error.
func (c *Class) InsertClass() (int64, error) {
    return utils.DBEngine.Table(ClassTableName).InsertOne(c)
}

// InsertClassInSession inserts c into the class table through the supplied
// session, so the write goes through whatever transaction that session holds.
// It returns the affected-row count and any error.
func (c *Class) InsertClassInSession(session *xorm.Session) (int64, error) {
    return session.Table(ClassTableName).InsertOne(c)
}

// GetClassByCIdInSession loads the class whose cid equals the given id,
// using session. The bool result reports whether a row was found. When it is
// false, the returned *Class points at a zero-value Class rather than being nil.
func GetClassByCIdInSession(session *xorm.Session, cid string) (*Class, bool, error) {
    class := &Class{}
    found, err := session.Where("cid = ?", cid).Get(class)
    return class, found, err
}

// UpdateClassByCIdInSession writes c to the class row whose cid equals c.CId.
// AllCols is used so every column is written, including zero values such as
// an empty Name. It returns the affected-row count and any error.
func (c *Class) UpdateClassByCIdInSession(session *xorm.Session) (int64, error) {
    allColumns := session.Table(ClassTableName).AllCols()
    return allColumns.Where("cid = ?", c.CId).Update(c)
}
class.go
package model

import (
    "fmt"
    "github.com/xormplus/xorm"
    "strings"
    "time"
    "whw_go_scripts/a_xorm_tests/utils"
)

// StudentTableName is the name of the table that Student maps to.
const StudentTableName = "student"

// Student is one row of the student table.
type Student struct {
    // SId is the primary key, stored in column sid.
    SId     string  `json:"sid" xorm:"varchar(255) pk 'sid' comment('学生ID')"`
    Name    string  `json:"name" xorm:"varchar(25) notnull 'name' comment('学生姓名')"`
    Age     int     `json:"age" xorm:"notnull 'age' comment('学生年龄')"`
    Score   float64 `json:"score" xorm:"notnull 'score' comment('学生成绩')"`
    // ClassId holds the cid of the class the student belongs to.
    ClassId string  `json:"class_id" xorm:"notnull 'class_id' comment('学生所在班级')"`
    // Created and Updated are the row's creation and last-modification
    // timestamps; their xorm `created`/`updated` tags let xorm fill them in.
    Created time.Time `json:"created" xorm:"created notnull"`
    Updated time.Time `json:"updated" xorm:"updated notnull"`
}

// TableName reports the table name xorm should use for *Student.
func (*Student) TableName() string { return StudentTableName }

// InsertStudentOnConflictDoNothing inserts one student row, with created and
// updated both set to utils.GetCurrentTime(). If a row with the same sid
// already exists, PostgreSQL's "on conflict(sid) do nothing" leaves it
// untouched and no error is reported.
func InsertStudentOnConflictDoNothing(sid, name string, age int, score float64, classID string) error {
    now := utils.GetCurrentTime()
    stmt := fmt.Sprintf(`insert into %s (sid,name,age,score,class_id,created,updated) 
        values (?,?,?,?,?,?,?) on conflict(sid) do nothing`, StudentTableName)
    _, err := utils.DBEngine.SQL(stmt, sid, name, age, score, classID, now, now).Execute()
    return err
}

// InsertStudentOnConflictDoUpdate inserts one student row. If a row with the
// same sid already exists, it instead overwrites that row's age, score,
// class_id and updated columns with the new values; name and created are kept.
// Both timestamps in the insert come from utils.GetCurrentTime().
func InsertStudentOnConflictDoUpdate(sid, name string, age int, score float64, classID string) error {
    now := utils.GetCurrentTime()
    stmt := fmt.Sprintf(`insert into %s (sid,name,age,score,class_id,created,updated)
        values(?,?,?,?,?,?,?) on conflict(sid) do update set age=excluded.age,score=excluded.score,
        class_id=excluded.class_id,updated=excluded.updated`, StudentTableName)
    if _, err := utils.DBEngine.SQL(stmt, sid, name, age, score, classID, now, now).Execute(); err != nil {
        return err
    }
    return nil
}

// InsertStudent inserts s into the student table through the shared engine
// directly (no explicit session). It returns the affected-row count and any
// error.
func (s *Student) InsertStudent() (int64, error) {
    return utils.DBEngine.Table(StudentTableName).InsertOne(s)
}

// InsertStudentInSession inserts s into the student table through the
// supplied session, so the write goes through that session's transaction (if
// any). It returns the affected-row count and any error.
func (s *Student) InsertStudentInSession(session *xorm.Session) (int64, error) {
    return session.Table(StudentTableName).InsertOne(s)
}

// GetStudentBySIdInSession loads the student whose sid equals the given id,
// using session. The bool result reports whether a row was found. When it is
// false, the returned *Student points at a zero-value Student rather than
// being nil.
func GetStudentBySIdInSession(session *xorm.Session, sid string) (*Student, bool, error) {
    student := &Student{}
    found, err := session.Table(StudentTableName).Where("sid = ?", sid).Get(student)
    return student, found, err
}

// UpdateStudentBySIdInSession writes s to the student row whose sid equals
// s.SId. AllCols is used so zero values are written too. It returns the
// affected-row count and any error.
func (s *Student) UpdateStudentBySIdInSession(session *xorm.Session) (int64, error) {
    byID := session.Table(StudentTableName).AllCols().Where("sid = ?", s.SId)
    return byID.Update(s)
}

// GetCountInSession counts the student rows that match the WHERE condition
// query with its bound args, e.g. `class_id=? and sid in (?,?,?)`.
func GetCountInSession(session *xorm.Session, query interface{}, args ...interface{}) (int64, error) {
    filtered := session.Where(query, args...)
    return filtered.Count(new(Student))
}

// GetStudentRecordsByInOperationInSession returns the students that match the
// caller-supplied WHERE condition, for example `sid in (?,?,?)` with its
// arguments. Rows are ordered newest first by the created column and capped
// at 100.
//
// query and args are passed unchanged to session.Where. On error the returned
// slice is nil.
func GetStudentRecordsByInOperationInSession(session *xorm.Session, query interface{}, args ...interface{}) ([]Student, error) {
    // maxRecords caps the result size. It is a local constant so the exported
    // signature stays unchanged.
    const maxRecords = 100

    records := make([]Student, 0)
    // Find is used rather than FindAndCount: the total count is never returned
    // to the caller, so FindAndCount's extra COUNT query would be wasted.
    // Find needs a pointer to the slice so it can fill it.
    err := session.Table(StudentTableName).
        Where(query, args...).
        Desc("created").
        Limit(maxRecords).
        Find(&records)
    if err != nil {
        return nil, err
    }
    return records, nil
}

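// IN query, approach 1 (inflexible): no dedicated helper is needed. Pass a
// condition such as `class_id=? and sid in (?,?,?)` to GetCountInSession, as
// TestIn1 in model_test.go does. It is inflexible because the number of `?`
// placeholders is fixed in the query string.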

// GetStudentMapListBySIDListInSession is IN query approach 2. It runs
// `SELECT sid,name,age,score ... WHERE sid in (...)` for the given student IDs
// and returns each row as a map from column name to string value.
//
// Every ID is sent as a bound parameter (one `?` per element), never spliced
// into the SQL text. IDs containing quotes or `%` are therefore matched
// literally and cannot change the statement.
//
// An empty sidList returns an empty result without querying, because `in ()`
// is a syntax error in PostgreSQL.
func GetStudentMapListBySIDListInSession(session *xorm.Session, sidList []string) ([]map[string]string, error) {
    if len(sidList) == 0 {
        return []map[string]string{}, nil
    }

    placeholders := make([]string, len(sidList))
    args := make([]interface{}, len(sidList))
    for i, sid := range sidList {
        placeholders[i] = "?"
        args[i] = sid
    }

    queryString := fmt.Sprintf(`SELECT sid,name,age,score from %s WHERE sid in (%s)`,
        StudentTableName, strings.Join(placeholders, ","))
    retList, err := session.SQL(queryString, args...).QueryString()
    if err != nil {
        return nil, err
    }
    return retList, nil
}
student.go
package model

import (
    "fmt"
    "github.com/stretchr/testify/require"
    "testing"
    "whw_go_scripts/a_xorm_tests/utils"
)

// TestCreateClassTable creates (or syncs) the class table schema from Class.
//
// A *Class is passed because TableName is declared with a pointer receiver. A
// Class value does not have that method in its method set, so only the pointer
// lets xorm see it.
func TestCreateClassTable(t *testing.T) {
    err := utils.DBEngine.Sync2(new(Class))
    require.NoError(t, err)
}

// TestCreateStudentTable creates (or syncs) the student table schema from
// Student.
//
// A *Student is passed because TableName is declared with a pointer receiver.
// A Student value does not have that method in its method set, so only the
// pointer lets xorm see it.
func TestCreateStudentTable(t *testing.T) {
    err := utils.DBEngine.Sync2(new(Student))
    require.NoError(t, err)
}

// TestInsertClass inserts class cid2 through a session and prints the result.
// It asserts nothing, so a re-run (which is expected to hit the cid primary
// key) only shows the error in the output.
func TestInsertClass(t *testing.T) {
    // Engine-level variant, kept for reference:
    //classModel1 := Class{
    //    CId: "cid1",
    //    Name: "class1",
    //}
    //count1, err1 := classModel1.InsertClass()
    //fmt.Println("count1: ",count1, "err1: ", err1)

    session := utils.DBEngine.NewSession()
    defer session.Close()

    cls := Class{CId: "cid2", Name: "class2"}
    count2, err2 := cls.InsertClassInSession(session)
    fmt.Println("count2: ", count2, "err2: ", err2)
}

// TestGetClass fetches class cid1 through a session and prints the model, the
// found flag and any error. It asserts nothing.
func TestGetClass(t *testing.T) {
    session := utils.DBEngine.NewSession()
    defer session.Close()

    cls, found, err := GetClassByCIdInSession(session, "cid1")
    fmt.Println("classModel: ", cls, "has: ", found, "err: ", err)
}

// TestUpdateClass overwrites every column of class cid1 with a new name and
// prints the affected-row count and any error. It asserts nothing.
func TestUpdateClass(t *testing.T) {
    session := utils.DBEngine.NewSession()
    defer session.Close()

    updated := Class{CId: "cid1", Name: "cidnannnnzxxxxx"}
    count, err := updated.UpdateClassByCIdInSession(session)
    fmt.Println("count: ", count, "err: ", err)
}

// TestInsertStudent inserts student s4 (class cid1) through a session and
// prints the affected-row count and any error. It asserts nothing.
func TestInsertStudent(t *testing.T) {
    session := utils.DBEngine.NewSession()
    defer session.Close()

    student := Student{SId: "s4", Name: "naruto4", Age: 22, Score: 99, ClassId: "cid1"}
    count1, err1 := student.InsertStudentInSession(session)
    fmt.Println("count1: ", count1, "err1: ", err1)
}

// TestIn1 demonstrates IN query approach 1: a fixed `in (?,?,?)` condition
// passed to GetCountInSession and to GetStudentRecordsByInOperationInSession.
// Query errors fail the test through require rather than panicking.
func TestIn1(t *testing.T) {
    session := utils.DBEngine.NewSession()
    defer session.Close()

    const cond = `class_id=? and sid in (?,?,?)`

    count, err := GetCountInSession(session, cond, "cid1", "s1", "s2", "s3")
    require.NoError(t, err)
    fmt.Println("count>>>> ", count)

    records, err := GetStudentRecordsByInOperationInSession(session, cond, "cid1", "s1", "s2", "s3")
    require.NoError(t, err)
    fmt.Println("records: ", records)
}

// TestIn2 demonstrates IN query approach 2 (GetStudentMapListBySIDListInSession)
// and prints each returned row, a map from column name to string value.
func TestIn2(t *testing.T) {
    // ShowSQL toggles logging on the shared engine. Switch it back off when
    // this test ends so SQL logging does not leak into later tests.
    // NOTE(review): assumes logging was off beforehand (xorm's default).
    utils.DBEngine.ShowSQL(true)
    defer utils.DBEngine.ShowSQL(false)

    session := utils.DBEngine.NewSession()
    defer session.Close()

    sidList := []string{"s1", "s2", "s3"}
    ret, err := GetStudentMapListBySIDListInSession(session, sidList)
    require.NoError(t, err)
    fmt.Println("ret: ", ret)
    for _, item := range ret {
        fmt.Println("item: ", item)
    }
}
model_test.go

utils包

package utils

import (
    _ "github.com/lib/pq"
    "github.com/xormplus/xorm"
    "time"
)

// LocalConn is the connection URL for the local test PostgreSQL instance
// (127.0.0.1:5433, database whw_test_db, TLS disabled).
//
// The URL must end at "sslmode=disable". A trailing ";" is not a query
// separator: Go 1.17+ net/url drops the whole "sslmode=disable;" pair, and
// lib/pq then falls back to its default sslmode.
// NOTE(review): credentials are hard-coded; acceptable only for a local sandbox.
const LocalConn = "postgresql://postgres:password@127.0.0.1:5433/whw_test_db?sslmode=disable"

// Package-level state populated by init.
var (
    // DBEngine is the shared xorm engine connected through LocalConn.
    DBEngine *xorm.Engine
    // errNewEngine receives the error from xorm.NewEngine; init panics if it is non-nil.
    errNewEngine error
    // TimeLocation is the Asia/Shanghai zone used by GetCurrentTime.
    TimeLocation *time.Location
)

// init builds the shared PostgreSQL engine, configures its connection pool and
// loads TimeLocation. It panics if either the engine or the time zone cannot
// be set up.
func init() {
    DBEngine, errNewEngine = xorm.NewEngine("postgres", LocalConn)
    if errNewEngine != nil {
        panic("初始化数据库错误" + errNewEngine.Error())
    }

    // Connection-pool limits: at most 5 open / 2 idle connections, each
    // recycled after 10 minutes.
    DBEngine.SetMaxOpenConns(5)
    DBEngine.SetMaxIdleConns(2)
    DBEngine.SetConnMaxLifetime(10 * time.Minute)

    // Time zone used by GetCurrentTime.
    loc, errLoad := time.LoadLocation("Asia/Shanghai")
    if errLoad != nil {
        panic(errLoad)
    }
    TimeLocation = loc
}

// GetCurrentTime returns the current instant expressed in TimeLocation
// (Asia/Shanghai, loaded by init).
func GetCurrentTime() time.Time {
    now := time.Now()
    return now.In(TimeLocation)
}
utils.go

main函数

package main

import (
    "errors"
    "fmt"
    "github.com/xormplus/xorm"
    "whw_go_scripts/a_xorm_tests/model"
    "whw_go_scripts/a_xorm_tests/utils"
)

// main runs handleInTransaction inside an engine-managed transaction and
// prints its result. It panics if the transaction returns an error.
func main() {
    // Option 1 (disabled): run handleInTransaction on a plain session,
    // outside any transaction.
    //session := utils.DBEngine.NewSession()
    //defer session.Close()
    //resp, err := handleInTransaction(session)
    //if err != nil {
    //    panic("err: " + err.Error())
    //}
    //fmt.Println("resp: ", resp)

    // Option 2: let the engine wrap handleInTransaction in a transaction.
    result, txErr := utils.DBEngine.Transaction(handleInTransaction)
    if txErr != nil {
        panic("err: " + txErr.Error())
    }
    fmt.Println("resp: ", result)
}

// handleInTransaction is the unit of work that main passes to
// utils.DBEngine.Transaction. Using the supplied session it:
//  1. loads student "sid1",
//  2. loads class "cid1",
//  3. renames the class and writes it back,
//  4. changes the student's score and age and writes it back.
//
// Any non-nil error stops the transaction from being committed.
// simulateFailure uses this to show that the step-3 class update is undone.
// A missing student or class is reported as an error instead of updating a
// zero-value model. On success it returns map[string]interface{}{"data": "success"}.
func handleInTransaction(session *xorm.Session) (interface{}, error) {
    const (
        studentID = "sid1"
        classID   = "cid1"
        // simulateFailure aborts before step 4 to demonstrate rollback.
        // Set it to false to let the transaction commit.
        simulateFailure = true
    )

    resp := make(map[string]interface{})

    // 1. Load the student.
    stuModel, hasStu, errGetStu := model.GetStudentBySIdInSession(session, studentID)
    if errGetStu != nil {
        return nil, errGetStu
    }
    if !hasStu {
        return nil, fmt.Errorf("student %q not found", studentID)
    }
    fmt.Println("stuModel: ", stuModel)

    // 2. Load the class. Return this lookup's own error; returning the
    // student error here (nil at this point) would let the transaction commit.
    classModel, hasClass, errGetClass := model.GetClassByCIdInSession(session, classID)
    if errGetClass != nil {
        return nil, errGetClass
    }
    if !hasClass {
        return nil, fmt.Errorf("class %q not found", classID)
    }
    fmt.Println("classModel: ", classModel)

    // 3. Update the class.
    classModel.Name = "class1_test1xxxxx123412"
    if _, errUpdateClass := classModel.UpdateClassByCIdInSession(session); errUpdateClass != nil {
        return nil, errUpdateClass
    }
    fmt.Println("classModelUpdate: ", classModel)

    // 4. Update the student.
    stuModel.Score = 89
    stuModel.Age = 32
    // Simulate a mid-transaction failure. A constant switch replaces the
    // unconditional return, which made the code below unreachable.
    if simulateFailure {
        return nil, errors.New("测试返回错误!!!!!")
    }
    if _, errUpdateStu := stuModel.UpdateStudentBySIdInSession(session); errUpdateStu != nil {
        return nil, errUpdateStu
    }

    resp["data"] = "success"
    return resp, nil
}
main.go

~~~

posted on 2022-04-04 20:28  江湖乄夜雨  阅读(262)  评论(0)  编辑  收藏  举报