Go 语言标准库之 context 包
context.Context
是一个非常抽象的概念,中文翻译为 ”上下文“,可看做为goroutine
的上下文。Context
是线程安全的,所以可以在多个goroutine
之间传递上下文信息,包括信号、超时时间、K-V
键值对等,同时它也可以用作并发控制。
Context 接口
type Context interface {
// 返回当前 Context 被取消的时间(完成工作的截止时间);如果没有设置时间,ok 返回 false
Deadline() (deadline time.Time, ok bool)
// 返回一个 Channel,该 Channel 会在当前工作完成或者上下文被取消之后关闭
Done() <-chan struct{}
// 返回当前 Context 结束的原因,它只会在 Done() 方法对应的 Channel 关闭时返回非 nil 的值
// 1. 如果当前 Context 被取消,返回 Canneled 错误
// 2. 如果当前 Context 超时,返回 DeadlineExceeded 错误
Err() error
// 从当前 Context 中返回 key 对应的 value
Value(key interface{}) interface{}
}
Background() 和 TODO() 函数
context
包提供Background()
和TODO()
函数,分别返回实现了Context
接口的内置的上下文对象background
和todo
。一般而言,我们代码最开始都是以这两个内置的上下文对象作为最顶层的partent context
,衍生出更多的子上下文对象。
Background()
:主要用于main
函数、初始化以及测试代码中,作为Context
这个树结构的最顶层的Context
,也就是根Context
。TODO()
:如果目前还不知道具体的使用场景,不知道该使用什么Context
的时候,可以使用这个。
Background()
和TODO()
函数的源码如下:
// background 和 todo 本质是 emptyCtx 结构体类型,是一个不可取消,没有设置截止时间,没有携带任何值的Context
var (
background = new(emptyCtx)
todo = new(emptyCtx)
)
func Background() Context {
return background
}
func TODO() Context {
return todo
}
type emptyCtx int
func (*emptyCtx) Deadline() (deadline time.Time, ok bool) {
return
}
func (*emptyCtx) Done() <-chan struct{} {
return nil
}
func (*emptyCtx) Err() error {
return nil
}
func (*emptyCtx) Value(key interface{}) interface{} {
return nil
}
Background()
和TODO()
返回的上下文对象一般作为最顶层的根Context
,然后通过调用withCancel()
、WithDeadline()
、WithTimeout()
或WithValue()
函数创建其派生的子上下文。当一个上下文被取消时,它派生的所有上下文也会被取消。
WithCancel() 函数
// 返回具有新的 Done 通道的父上下文的副本 ctx
// 当调用返回的 cancel 函数或关闭父上下文的 Done 通道时,返回的上下文 ctx 的 Done 通道也会被关闭
func WithCancel(parent Context) (ctx Context, cancel CancelFunc)
☕️ 示例代码
package main
import (
"context"
"fmt"
)
func gen(ctx context.Context) <-chan int {
dst := make(chan int)
n := 1
go func() {
for {
select {
case <-ctx.Done():
// return 结束该 goroutine,防止泄露
return
case dst <- n:
n++
}
}
}()
return dst
}
func main() {
ctx, cancel := context.WithCancel(context.Background())
// 获取完需要的整数后,调用 cancel
defer cancel()
for n := range gen(ctx) {
fmt.Println(n)
if n == 5 {
break
}
}
}
// 1
// 2
// 3
// 4
// 5
package main
import (
"context"
"fmt"
"sync"
"time"
)
var wg sync.WaitGroup
func worker(ctx context.Context) {
go worker2(ctx)
LOOP:
for {
fmt.Println("worker")
time.Sleep(time.Second)
select {
case <-ctx.Done(): // 等待上级通知
break LOOP
default:
}
}
wg.Done()
}
func worker2(ctx context.Context) {
LOOP:
for {
fmt.Println("worker2")
time.Sleep(time.Second)
select {
case <-ctx.Done(): // 等待上级通知
break LOOP
default:
}
}
}
func main() {
ctx, cancel := context.WithCancel(context.Background())
wg.Add(1)
go worker(ctx)
time.Sleep(time.Second * 3)
cancel() // 通知子 goroutine 结束
wg.Wait()
fmt.Println("over")
}
// worker2
// worker
// worker
// worker2
// worker2
// worker
// over
WithDeadline() 函数
// 返回父上下文的副本 ctx,完成工作的截止时间(deadline)调整为不迟于 d
// 如果父上下文的 deadline 早于 d,则 WithDeadline(parent, d) 在语义上等同于父上下文
// 当遇到以下三种情况时,返回的上下文 ctx 的 Done 通道将被关闭,以最先发生的情况为准:
// 1. 设置的截止日期 d 过期;2. 调用返回的 cancal 函数;3. 父上下文的 Done 通道被关闭
func WithDeadline(parent Context, d time.Time) (ctx Context, cancal CancelFunc)
⭐️ 示例代码
package main
import (
"context"
"fmt"
"time"
)
func main() {
// 设置当前上下文 50ms 后过期
d := time.Now().Add(50 * time.Millisecond)
ctx, cancel := context.WithDeadline(context.Background(), d)
// 尽管 ctx 会过期,但在任何情况下调用它的 cancel 函数都是很好的实践。
// 如果不这样做,可能会使上下文及其父类存活的时间超过必要的时间。
defer cancel()
select {
case <-time.After(1 * time.Second):
// 等待 1 秒后打印 overslept 退出
fmt.Println("overslept")
case <-ctx.Done():
// 等待 ctx 过期后,退出
fmt.Println(ctx.Err())
}
}
// context deadline exceeded
WithTimeout() 函数
// 等同于 WithDeadline(parent, time.Now().Add(timeout))
func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc)
✏️ 示例代码
package main
import (
"context"
"fmt"
"sync"
"time"
)
var wg sync.WaitGroup
func worker(ctx context.Context) {
LOOP:
for {
fmt.Println("db connecting...")
time.Sleep(time.Millisecond * 10) // 假设正常连接数据库耗时 10 毫秒
select {
case <-ctx.Done(): // 50ms 后调用
break LOOP
default:
}
}
fmt.Println("worker done!")
wg.Done()
}
func main() {
// 设置当前上下文 50ms 后过期
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
wg.Add(1)
go worker(ctx)
time.Sleep(time.Second * 5)
cancel() // 通知子 goroutine 结束
wg.Wait()
fmt.Println("over")
}
// db connecting...
// db connecting...
// db connecting...
// worker done!
// over
withValue() 函数
// 返回父节点的副本,可传递一个与 key 关联的 val 值
func WithValue(parent Context, key, val interface{}) Context
📚 示例代码
package main
import (
"context"
"fmt"
"sync"
"time"
)
type TraceCode string
var wg sync.WaitGroup
func worker(ctx context.Context) {
key := TraceCode("TRACE_CODE")
traceCode, ok := ctx.Value(key).(string) // 在子 goroutine 中获取 trace code
if !ok {
fmt.Println("invalid trace code")
}
LOOP:
for {
fmt.Printf("worker, trace code:%s\n", traceCode)
time.Sleep(time.Millisecond * 10) // 假设正常连接数据库耗时 10 毫秒
select {
case <-ctx.Done(): // 50毫秒后自动调用
break LOOP
default:
}
}
fmt.Println("worker done!")
wg.Done()
}
func main() {
// 创建一个过期时间为 50ms 的上下文
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
// 在系统的入口中设置 trace code 传递给后续启动的 goroutine 实现日志数据聚合
ctx = context.WithValue(ctx, TraceCode("TRACE_CODE"), "12512312234")
wg.Add(1)
go worker(ctx)
time.Sleep(time.Second * 5)
cancel() // 通知子goroutine结束
wg.Wait()
fmt.Println("over")
}
// worker, trace code:12512312234
// worker, trace code:12512312234
// worker, trace code:12512312234
// worker, trace code:12512312234
// worker, trace code:12512312234
// worker done!
// over
客户端超时取消示例
✌ 服务端代码
package main
import (
"fmt"
"math/rand"
"net/http"
"time"
)
func main() {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
random := rand.New(rand.NewSource(time.Now().UnixNano()))
number := random.Intn(2)
if number == 0 {
time.Sleep(time.Second * 10) // 耗时 10 秒的慢响应
fmt.Fprintf(w, "slow response")
return
}
fmt.Fprintf(w, "quick response") // 快速响应
})
err := http.ListenAndServe(":8080", nil)
if err != nil {
panic(err)
}
}
✍ 客户端代码
package main
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"sync"
"time"
)
type respData struct {
resp *http.Response
err error
}
func doCall(ctx context.Context) {
client := http.Client{
Transport: &http.Transport{
// 设置为长连接
DisableKeepAlives: true,
},
}
respChan := make(chan *respData, 1)
req, err := http.NewRequest("GET", "http://127.0.0.1:8080/", nil)
if err != nil {
fmt.Println("new request failed, err:%v\n", err)
return
}
// 使用带超时的 ctx 创建一个新的 client request
req = req.WithContext(ctx)
var wg sync.WaitGroup
wg.Add(1)
defer wg.Wait()
go func() {
resp, err := client.Do(req)
fmt.Printf("client.do resp:%v, err:%v\n", resp, err)
respChan <- &respData{
resp: resp,
err: err,
}
wg.Done()
}()
select {
case <-ctx.Done():
fmt.Println("call api timeout")
case result := <-respChan:
fmt.Println("call server api success")
if result.err != nil {
fmt.Printf("call server api failed, err:%v\n", result.err)
return
}
defer result.resp.Body.Close()
data, _ := ioutil.ReadAll(result.resp.Body)
fmt.Printf("resp:%v\n", string(data))
}
}
func main() {
// 将当前 ctx 的超时时间设置为 100ms
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
// 调用cancel释放子goroutine资源
defer cancel()
doCall(ctx)
}
// 超时的输出:
// call api timeout
// client.do resp:<nil>, err:Get "http://127.0.0.1:8080/": context deadline exceeded
// 不超时的输出:
// client.do resp:&{200 OK 200 HTTP/1.1 1 1 map[Content-Length:[14] Content-Type:[text/plain; charset=utf-8] Date:[Sun, 22 May 2022 07:50:29 GMT]] 0xc000206080 14 [] true false map[] 0xc00013a100 <nil>}, err:<nil> call server api success
// resp:quick response