使用errgroup并发查询数据库
参考文章
[Golang] 并发编程包之 errgroup —— 请特别注意文中最后提到的 ctx 陷阱!
代码演示
~~~go
package aerrgrouptests import ( "context" "fmt" "sort" "sync" "testing" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" "gorm.io/driver/mysql" "gorm.io/gorm" ) const ( studentTableName = "student" ) type Student struct { ID uint `gorm:"primaryKey;column:id"` Name string `gorm:"column:name"` Age uint8 `gorm:"column:age"` Email string `gorm:"column:email"` Score float32 `gorm:"column:score"` Notes string `gorm:"column:notes"` } func (s Student) TableName() string { return studentTableName } func genDb() (*gorm.DB, error) { dsn := "root:123@tcp(127.0.0.1:3307)/whw2?charset=utf8mb4&parseTime=True&loc=Local" db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{}) fmt.Println("db: ", db) return db, err } // 往数据库中加数据 func TestEGP0(t *testing.T) { db, errDB := genDb() require.Equal(t, errDB, nil) for i := 0; i < 100; i++ { currName := fmt.Sprintf("name%d", i) currAge := i currEmail := fmt.Sprintf("email%d", i) currScore := i currNotes := fmt.Sprintf("notes%d", i) currStu := Student{ Name: currName, Email: currEmail, Age: uint8(currAge), Score: float32(currScore), Notes: currNotes, } _ = db.Table(studentTableName).Create(&currStu).Error } } // 不推荐的循环查数据库的方法 ———— 会使接口变慢 func TestEGP1(t *testing.T) { db, errDB := genDb() require.Equal(t, errDB, nil) ids := make([]int, 0) // 数据库的数据不足200 for i := 1; i < 200; i++ { ids = append(ids, i) } students := make([]*Student, 0) // for循环中查询数据库不可取 for _, id := range ids { currStu := new(Student) errStu := db.Table(studentTableName).Where("id = ?", id).First(currStu).Error if errStu != nil { fmt.Println("error: ", errStu, "id: ", id) } else { fmt.Println("currStu: ", currStu) students = append(students, currStu) } } fmt.Println("students: ", len(students), students) } func TestEGP2(t *testing.T) { db, errDB := genDb() require.Equal(t, errDB, nil) ids := make([]int, 0) // 数据库的数据不足200 for i := 1; i < 200; i++ { ids = append(ids, i) } students := make([]*Student, 0) var mutex sync.Mutex // 暂时不用ctx group, _ := 
errgroup.WithContext(context.Background()) // 控制下并发数量 limitChan := make(chan struct{}, 5) for _, id := range ids { limitChan <- struct{}{} // 防止惰性计算的情况:https://golang.org/doc/faq#closures_and_goroutines currId := id group.Go(func() error { defer func() { <-limitChan }() currStu := new(Student) errStu := db.Table(studentTableName).Where("id = ?", currId).First(currStu).Error // NOTE 模拟错误 if currId == 2 || currId == 3 || currId == 12 { return fmt.Errorf(fmt.Sprintf("发生错误了--%d", currId)) } if errStu != nil { fmt.Println("errStu: ", errStu) return errStu // 子协裎中返回错误 } fmt.Println("currStu: ", currStu) // slice不是线程安全的,最好加上mutex mutex.Lock() defer mutex.Unlock() students = append(students, currStu) return nil }) } errGP := group.Wait() fmt.Println("errGP: ", errGP) // 实际上,只会捕获第一个发生错误的协裎的错误 // NOTE 注意,得到的结果是无序的,如果业务中需要对返回的结果按照id排序的话还需要做一下排序: // 根据id从小到大排序 sort.SliceStable(students, func(i, j int) bool { return students[i].ID < students[j].ID }) fmt.Println("students: ", len(students)) for _, item := range students { fmt.Println("id: ", item.ID) } }
~~~