使用errgroup并发查询数据库

参考文章

[Golang]并发编程包之errgroup —— 文中最后那个ctx的坑注意一下!

errgroup使用方法及适用场景

代码演示

package aerrgrouptests

import (
    "context"
    "fmt"
    "sort"
    "sync"
    "testing"

    "github.com/stretchr/testify/require"
    "golang.org/x/sync/errgroup"
    "gorm.io/driver/mysql"
    "gorm.io/gorm"
)

const (
    studentTableName = "student"
)

type Student struct {
    ID    uint    `gorm:"primaryKey;column:id"`
    Name  string  `gorm:"column:name"`
    Age   uint8   `gorm:"column:age"`
    Email string  `gorm:"column:email"`
    Score float32 `gorm:"column:score"`
    Notes string  `gorm:"column:notes"`
}

func (s Student) TableName() string {
    return studentTableName
}

func genDb() (*gorm.DB, error) {
    dsn := "root:123@tcp(127.0.0.1:3307)/whw2?charset=utf8mb4&parseTime=True&loc=Local"
    db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
    fmt.Println("db: ", db)
    return db, err
}

// 往数据库中加数据
func TestEGP0(t *testing.T) {
    db, errDB := genDb()
    require.Equal(t, errDB, nil)

    for i := 0; i < 100; i++ {
        currName := fmt.Sprintf("name%d", i)
        currAge := i
        currEmail := fmt.Sprintf("email%d", i)
        currScore := i
        currNotes := fmt.Sprintf("notes%d", i)
        currStu := Student{
            Name:  currName,
            Email: currEmail,
            Age:   uint8(currAge),
            Score: float32(currScore),
            Notes: currNotes,
        }
        _ = db.Table(studentTableName).Create(&currStu).Error
    }

}

// 不推荐的循环查数据库的方法 ———— 会使接口变慢
func TestEGP1(t *testing.T) {
    db, errDB := genDb()
    require.Equal(t, errDB, nil)

    ids := make([]int, 0)

    // 数据库的数据不足200
    for i := 1; i < 200; i++ {
        ids = append(ids, i)
    }

    students := make([]*Student, 0)
    // for循环中查询数据库不可取
    for _, id := range ids {
        currStu := new(Student)
        errStu := db.Table(studentTableName).Where("id = ?", id).First(currStu).Error
        if errStu != nil {
            fmt.Println("error: ", errStu, "id: ", id)
        } else {
            fmt.Println("currStu: ", currStu)
            students = append(students, currStu)
        }
    }

    fmt.Println("students: ", len(students), students)
}

func TestEGP2(t *testing.T) {
    db, errDB := genDb()
    require.Equal(t, errDB, nil)

    ids := make([]int, 0)

    // 数据库的数据不足200
    for i := 1; i < 200; i++ {
        ids = append(ids, i)
    }

    students := make([]*Student, 0)

    var mutex sync.Mutex
    // 暂时不用ctx
    group, _ := errgroup.WithContext(context.Background())
    // 控制下并发数量
    limitChan := make(chan struct{}, 5)
    for _, id := range ids {
        limitChan <- struct{}{}
        // 防止惰性计算的情况:https://golang.org/doc/faq#closures_and_goroutines
        currId := id
        group.Go(func() error {
            defer func() {
                <-limitChan
            }()
            currStu := new(Student)
            errStu := db.Table(studentTableName).Where("id = ?", currId).First(currStu).Error
            // NOTE 模拟错误
            if currId == 2 || currId == 3 || currId == 12 {
                return fmt.Errorf(fmt.Sprintf("发生错误了--%d", currId))
            }

            if errStu != nil {
                fmt.Println("errStu: ", errStu)
                return errStu // 子协裎中返回错误
            }
            fmt.Println("currStu: ", currStu)
            // slice不是线程安全的,最好加上mutex
            mutex.Lock()
            defer mutex.Unlock()
            students = append(students, currStu)
            return nil
        })
    }
    errGP := group.Wait()
    fmt.Println("errGP: ", errGP) // 实际上,只会捕获第一个发生错误的协裎的错误

// NOTE 注意,得到的结果是无序的,如果业务中需要对返回的结果按照id排序的话还需要做一下排序:
    // 根据id从小到大排序
    sort.SliceStable(students, func(i, j int) bool {
        return students[i].ID < students[j].ID
    })

    fmt.Println("students: ", len(students))
    for _, item := range students {
        fmt.Println("id: ", item.ID)
    }
}

~~~

posted on 2022-12-13 20:34  江湖乄夜雨  阅读(155)  评论(0编辑  收藏  举报