Sync之WaitGroup模块源码分析

前言

WaitGroup和channel一样,也是Golang应用开发过程中经常使用的并发控制技术,不过它和channel实现的机制不一样,它是使用信号量来控制的。

使用示例

package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	wg := sync.WaitGroup{}
	wg.Add(10) //Add的数量必须和for循环的次数保持一致
	for i := 0; i < 10; i++ {
		go func(i int) {
			time.Sleep(1 * time.Second)
			fmt.Println(i)
			wg.Done()
		}(i)
	}
	wg.Wait()
	fmt.Println("done")
}

结构体

这里以go1.19版本为例:

type WaitGroup struct {
	noCopy noCopy
	state1 uint64
	state2 uint32
}

noCopy:golang的一种防拷贝技术,如果有被拷贝则会报错,所以我们在初始化waitGroup实例后就是一个全局对象,不能被拷贝

state1: 用于存放任务计数器和等待者计数器

state2:信号量的地址,和select的结构体一样,也用到了信号量技术来控制线程的阻塞和唤醒

结构体实现了三个方法:Add()、Done()、Wait(),下面我看来分别看看它们的作用。

Add()方法

func (wg *WaitGroup) Add(delta int) {
	// 获取state1和semap的指针地址
	statep, semap := wg.state()
	if race.Enabled {
		_ = *statep // trigger nil deref early
		if delta < 0 {
			// Synchronize decrements with Wait.
			race.ReleaseMerge(unsafe.Pointer(wg))
		}
		race.Disable()
		defer race.Enable()
	}
	// 把delta左移32位累加到state,即累加到counter中
	state := atomic.AddUint64(statep, uint64(delta)<<32)
	// 获取任务counter的值
	v := int32(state >> 32)
	// 获取等待者的个数
	w := uint32(state)
	println("v=",v, "w=", w, "semap=", *semap)
	if race.Enabled && delta > 0 && v == int32(delta) {
		race.Read(unsafe.Pointer(semap))
	}
	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// 如果任务counter数大于0或等待者个数等于0,则不阻塞直接返回
	if v > 0 || w == 0 {
		return
	}
	if *statep != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// Reset waiters count to 0.
	*statep = 0
	// 走到这里说明任务都执行完了,并且还有等待者,因此会逐个释放信号量
	for ; w != 0; w-- {
		runtime_Semrelease(semap, false, 0)
	}
}

Add方法会做两件事情:

1、 更新任务counter个数,可增可减;

2、如果任务都执行完且还有等待者,则逐个唤醒等待者。

Done()方法

func (wg *WaitGroup) Done() {
	wg.Add(-1)
}

每执行一次Done芳芳,则将任务数减1,同时判断是否需要唤醒等待者,参考Add()方法。

Wait()方法

func (wg *WaitGroup) Wait() {
	// 获取state1和semap的指针地址
	statep, semap := wg.state()
	if race.Enabled {
		_ = *statep // trigger nil deref early
		race.Disable()
	}
	for {
		state := atomic.LoadUint64(statep)
		// 获取当前的任务counter数
		v := int32(state >> 32)
		// 获取当前的等待者个数
		w := uint32(state)
		if v == 0 {
			// Counter is 0, no need to wait.
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
		// 将当前等待者个数加1
		if atomic.CompareAndSwapUint64(statep, state, state+1) {
			if race.Enabled && w == 0 {
				race.Write(unsafe.Pointer(semap))
			}
			// 一直阻塞直到获取到信号量
			runtime_Semacquire(semap)
			if *statep != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
	}
}

Wait方法会做两件事:

1、不断地尝试将等待者个数加1(CAS操作);

2、一旦增加成功,则一直阻塞直到获取到信号量。

posted @ 2021-09-16 21:03  独揽风月  阅读(85)  评论(0编辑  收藏  举报