WaitGroup的用法和原理、常见错误

WaitGroup的介绍

WaitGroup就是package sync用来做任务编排的一个并发原语,这个要解决的就是并发-等待的问题:现有一个goroutine A在检查点(chaeckpoint)等待一组goroutine全部完成,如果在执行任务的这些goroutine还没有全部完成,那么goroutine A就会阻塞在检查点,直到所有的goroutine都完成后才能继续执行。

WaitGroup的用法

创建一个WaitGroup对象后,可以使用Add方法向计数器添加值,然后使用Done方法从计数器减去值。最后,Wait方法将阻塞当前goroutine,直到计数器归零。

WaitGroup的实现原理

WaitGroup的实现原理比较简单。它有一个计数器,初始值为零。当调用Add方法时,它会将计数器加上传入的值。每次调用Done方法时,计数器减去一个值。最后,在调用Wait方法时,程序将阻塞,直到计数器归零。

看一下Add方法的逻辑,Add方法主要操作的是state的计数部分。可以为计数值增加一个delta值,内部通过原子操把这个值加到计数值上。但这个delta也可以是负数,相当于为计数值减去了一个值,Done方法内部其实就是通过Add(-1) 实现的。

看看WaitGroup的数据结构,它包括了一个noCopy的辅助字段,一个state1记录WaitGroup状态的数组。

noCopy,表示不可复制,辅助字段,主要就是辅助vet工具检查是否通过copy赋值这个WaitGroup实例。
state1,一个具有复合意义的字段,包含WaitGroup的计数、waiter数和信号量。

注意:如果想要自己定义的数据结构不被复制使用,或不能通过vet工具检查出复制使用的报警,就可以通过嵌入noCopy这个数据类型来实现。

WaitGroup的源代码:

复制代码
type WaitGroup struct {
    counter int64
    mutex   sync.Mutex
    waiters sync.Cond
}

func (wg *WaitGroup) Add(delta int) {
    wg.mutex.Lock()
    defer wg.mutex.Unlock()
    wg.counter += int64(delta)
}

func (wg *WaitGroup) Done() {
    wg.Add(-1)
}

func (wg *WaitGroup) Wait() {
    wg.mutex.Lock()
    defer wg.mutex.Unlock()
    for wg.counter > 0 {
        wg.waiters.Wait()
    }
}
复制代码

在这个实现中,WaitGroup使用了一个互斥锁和一个条件变量,以确保在多个goroutine之间同步计数器的值。Add方法和Done方法都使用互斥锁来保护计数器的访问。Wait方法使用条件变量来等待计数器归零。

WaitGroup的实践

Demo1:使用WaitGroup等待多个goroutine执行完毕后再继续执行主函数

复制代码
package main

import (
    "fmt"
    "sync"
)

func worker(id int, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Printf("Worker %d starting\n", id)
    // 模拟工作时间
    for i := 0; i < 100000000; i++ {

    }
    fmt.Printf("Worker %d done\n", id)
}

func main() {
    var wg sync.WaitGroup
    for i := 1; i <= 5; i++ {
        wg.Add(1)
        go worker(i, &wg)
    }
    wg.Wait()
    fmt.Println("All workers done")
}
View Code
复制代码

说明:该程序会启动5个goroutine,每个goroutine都会打印出自己的ID并模拟一段工作时间后结束。主函数会等待所有goroutine都执行完毕后再打印"All workers done"。

Demo2:使用WaitGroup等待多个http请求完成后再继续执行主函数

复制代码
package main

import (
    "fmt"
    "io/ioutil"
    "net/http"
    "sync"
)

func fetch(url string, wg *sync.WaitGroup) {
    defer wg.Done()
    resp, err := http.Get(url)
    if err != nil {
        fmt.Println(err)
        return
    }
    defer resp.Body.Close()
    body, err := ioutil.ReadAll(resp.Body)
    if err != nil {
        fmt.Println(err)
        return
    }
    fmt.Printf("Fetched %s, Body size: %d\n", url, len(body))
}

func main() {
    var wg sync.WaitGroup
    urls := []string{
        "https://www.baidu.com",
        "https://www.google.com",
        "https://www.bing.com",
    }
    for _, url := range urls {
        wg.Add(1)
        go fetch(url, &wg)
    }
    wg.Wait()
    fmt.Println("All fetches done")
}
View Code
复制代码

说明:该程序会启动3个goroutine,每个goroutine都会向一个URL发送http请求并打印出返回的body大小。主函数会等待所有goroutine都执行完毕后再打印"All fetches done"。

Demo3:使用WaitGroup实现协程池

复制代码
package main

import (
    "fmt"
    "sync"
)

const (
    workerCount = 5
    taskCount   = 20
)

func worker(id int, tasks <-chan int, wg *sync.WaitGroup) {
    defer wg.Done()
    for task := range tasks {
        fmt.Printf("Worker %d processing task %d\n", id, task)
        // 模拟工作时间
        for i := 0; i < 100000000; i++ {

        }
    }
}

func main() {
    var wg sync.WaitGroup
    tasks := make(chan int, taskCount)
    for i := 1; i <= workerCount; i++ {
        wg.Add(1)
        go worker(i, tasks, &wg)
    }
    for i := 1; i <= taskCount; i++ {
        tasks <- i
    }
    close(tasks)
    wg.Wait()
    fmt.Println("All tasks done")
}
View Code
复制代码

说明:该程序会创建一个由5个goroutine构成的协程池,并向任务通道中发送20个任务。每个goroutine会从任务通道中获取一个任务并处理,直到任务通道关闭。主函数会等待所有协程都执行完毕后再打印"All tasks done"。

WaitGroup的常见错误

常见问题一:计数器设置为负值

两种情况会导致计数器设置为负数

1)调用Add的时候传递一个负数。

复制代码
func main() {
    var wg sync.WaitGroup
    wg.Add(10)
    
    wg.Add(-10)// 将-10作为参数调用Add,计数值被设置为0
    
    wg.Add(-1)// 将-1作为参数调用Add,如果加上-1计数值就会变为负数。这是不对的,所以会触发panic
}
复制代码

如果能保证当前的计数器加上这个负数后还是大于等于0,没有问题,否则就会对导致panic。

2)调用Done方法的次数过多,超过了WaitGroup的计数值

使用WaitGroup的正确姿势是,预先确定好WaitGroup的计数值,然后调用相同次数的Done完成相应的任务。例如,在WaitGroup变量声明之后,就立即设置它的计数值,或者在goroutine启动之前增加1,然后在goroutine中调用Done。

如果没有遵循这些规则,就很可能会导致Done方法调用的次数和计数值不一致,进而造成死锁(Done调用次数比计数值少)或者panic(Done调用次数比计数值多)。

比如像以下情况,多调用了一次Donef方法后,会导致计数值为负,所以程序运行到这一行出现panic:

复制代码
func main() {
    var wg sync.WaitGroup
    wg.Add(1)
    
    wg.Done()
    
    wg.Done()
}
复制代码

常见问题二:不期望的Add时机

使用WaitGroup一定要遵守的原则就是,等所有的Add方法调用之后再调用Wait,否则就可能导致panic或者不期望的结果。

构造一个场景:只有部分的Add/Done执行完后,Wait就返回。这里我们启动四个goroutine,每个goroutine内部调用Add(1)然后调用Done(),主goroutine调用Wait等待任务完成。

复制代码
func main() {
    var wg sync.WaitGroup
    go dosomething(100, &wg) // 启动第一个goroutine
    go dosomething(110, &wg) // 启动第二个goroutine
    go dosomething(120, &wg) // 启动第三个goroutine
    go dosomething(130, &wg) // 启动第四个goroutine
    
    wg.Wait() // 主goroutine等待完成
    fmt.Println("Done")
}

func dosomething(millisecs time.Duration, wg *sync.WaitGroup) {
    duration := millisecs * time.Millisecond
    time.Sleep(duration) // 故意sleep一段时间
    
    wg.Add(1)
    fmt.Println("后台执行,duration:",duration)
    wg.Done()
}
复制代码

原本期望的是,等四个goroutine都执行接受后输出Done的信息,但是它的错误之处在于,将WaitGroup.Add方法的调用放在子gorotuine中。等主goroutine调用Wait的时候,因为四个任务goroutine一开始都休眠,所以可能WaitGroup的Add方法还没有被调用,WaitGroup的计数还是0,所以它并没有等待四个子goroutine执行完毕才继续执行,而是立刻执行了下一步。

导致这个错误的原因就是,没有遵循先完成所有的Add 之后才Wait。

要解决这个问题,第一个方法是,预先设置计数值:

复制代码
func main() {
    var wg sync.WaitGroup
    wg.Add(4) // 预先设定WaitGroup的计数值
    
    go dosomething(100, &wg) // 启动第一个goroutine
    go dosomething(110, &wg) // 启动第二个goroutine
    go dosomething(120, &wg) // 启动第三个goroutine
    go dosomething(130, &wg) // 启动第四个goroutine
    
    wg.Wait() // 主goroutine等待完成
    fmt.Println("Done")
}

func dosomething(millisecs time.Duration, wg *sync.WaitGroup) {
    duration := millisecs * time.Millisecond
    time.Sleep(duration) // 故意sleep一段时间
    
    fmt.Println("后台执行,duration:",duration)
    wg.Done()
}
复制代码

第二个方法是在启动子goroutine之前才调用Add:

复制代码
func main() {
    var wg sync.WaitGroup
    
    go dosomething(100, &wg) // 调用方法,把计数值加1,并启动任务goroutine
    go dosomething(110, &wg) // 调用方法,把计数值加1,并启动任务goroutine
    go dosomething(120, &wg) // 调用方法,把计数值加1,并启动任务goroutine
    go dosomething(130, &wg) // 调用方法,把计数值加1,并启动任务goroutine
    
    wg.Wait() // 主goroutine等待,代码逻辑保证四次Add(1)都已经执行完了
    fmt.Println("Done")
}

func dosomething(millisecs time.Duration, wg *sync.WaitGroup) {
    wg.Add(1) // 计数值加1,再启动goroutine
    
    go func() {
        duration := millisecs * time.Millisecond
        time.Sleep(duration) // 故意sleep一段时间
        fmt.Println("后台执行,duration:",duration)
        wg.Done()
    }()
}
复制代码

结论:无论哪种解决方案,都要保证所有的Add方法是在Wait方法之前被调用的。

常见问题三:前一个Wait还没有结束就重用WaitGroup

举例:在田径比赛的百米小组赛中,需要把选手分为计组,一组选手比赛之后,就可以进行下一组了。为了确保两组比赛上没有冲突,在模型化这个场景的时候,可以使用WaitGroup。

WaitGroup 等一组比赛的所有选手都跑完后 5 分钟,才开始下一组比赛。下一组比赛还可以使用这个 WaitGroup 来控制,因为 WaitGroup 是可以重用的。只要 WaitGroup 的计数值恢复到零值的状态,那么它就可以被看作是新创建的 WaitGroup,被重复使用。如果在WaitGroup的计数值还没有恢复到零值的时候就重用,就会导致程序panic。假如初始设置WaitGroup的计数值为1,启动一个goroutine先调用Done方法,接着就调用Add方法,Add方法有可能和主goroutine并发执行。

复制代码
func main() {
    var wg sync.WaitGroup
    wg.Add(1)
    go func() {
        time.Sleep(time.Millisecond)
        wg.Done() // 计数器减1
        wg.Add(1) // 计数值加1
    }()
    wg.Wait() // 主goroutine等待,有可能和第7行并发执行
}
复制代码

在第6行虽然让WaitGroup的计数恢复到0,但是因为第9行有个waiter在等待,如果等待Wait的goroutine,刚被唤醒就和Add调用(第7行)由并发执行的冲突,所以就会出现panic。

结论:WaitGroup虽然可以重用,前提是必须等到上一轮的Wait完成之后,才能重用WaitGroup执行下一轮的Add/Wait,如果在Wait还没执行完的时候就调用下一轮Add方法,就有可能出现panic。

posted @   李若盛开  阅读(689)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· DeepSeek 开源周回顾「GitHub 热点速览」
点击右上角即可分享
微信分享提示