Go 语言标准库之 context 包

context.Context是一个非常抽象的概念,中文翻译为 ”上下文“,可看做为goroutine的上下文。Context是线程安全的,所以可以在多个goroutine之间传递上下文信息,包括信号、超时时间、K-V键值对等,同时它也可以用作并发控制。

Context 接口

type Context interface {
   // 返回当前 Context 被取消的时间(完成工作的截止时间);如果没有设置时间,ok 返回 false
   Deadline() (deadline time.Time, ok bool)
    
   // 返回一个 Channel,该 Channel 会在当前工作完成或者上下文被取消之后关闭
   Done() <-chan struct{}
    
   // 返回当前 Context 结束的原因,它只会在 Done() 方法对应的 Channel 关闭时返回非 nil 的值
   // 1. 如果当前 Context 被取消,返回 Canneled 错误
   // 2. 如果当前 Context 超时,返回 DeadlineExceeded 错误 
   Err() error
    
   // 从当前 Context 中返回 key 对应的 value
   Value(key interface{}) interface{}
}

Background() 和 TODO() 函数

context包提供Background()TODO()函数,分别返回实现了Context接口的内置的上下文对象backgroundtodo。一般而言,我们代码最开始都是以这两个内置的上下文对象作为最顶层的partent context,衍生出更多的子上下文对象。

  • Background():主要用于main函数、初始化以及测试代码中,作为Context这个树结构的最顶层的Context,也就是根Context
  • TODO():如果目前还不知道具体的使用场景,不知道该使用什么Context的时候,可以使用这个。

Background()TODO()函数的源码如下:

// background 和 todo 本质是 emptyCtx 结构体类型,是一个不可取消,没有设置截止时间,没有携带任何值的Context
var (
    background = new(emptyCtx)
    todo       = new(emptyCtx)
)

func Background() Context {
    return background
}

func TODO() Context {
    return todo
}

type emptyCtx int

func (*emptyCtx) Deadline() (deadline time.Time, ok bool) {
    return
}

func (*emptyCtx) Done() <-chan struct{} {
    return nil
}

func (*emptyCtx) Err() error {
    return nil
}

func (*emptyCtx) Value(key interface{}) interface{} {
    return nil
}

Background()TODO()返回的上下文对象一般作为最顶层的根Context,然后通过调用withCancel()WithDeadline()WithTimeout()WithValue()函数创建其派生的子上下文。当一个上下文被取消时,它派生的所有上下文也会被取消。


WithCancel() 函数

// 返回具有新的 Done 通道的父上下文的副本 ctx
// 当调用返回的 cancel 函数或关闭父上下文的 Done 通道时,返回的上下文 ctx 的 Done 通道也会被关闭
func WithCancel(parent Context) (ctx Context, cancel CancelFunc)

☕️ 示例代码

package main

import (
    "context"
    "fmt"
)

func gen(ctx context.Context) <-chan int {
    dst := make(chan int)
    n := 1
    go func() {
        for {
            select {
            case <-ctx.Done():
                // return 结束该 goroutine,防止泄露
                return
            case dst <- n:
                n++
            }
        }
    }()
    return dst
}

func main() {
    ctx, cancel := context.WithCancel(context.Background())
    // 获取完需要的整数后,调用 cancel
    defer cancel()

    for n := range gen(ctx) {
        fmt.Println(n)
        if n == 5 {
            break
        }
    }
}

// 1
// 2
// 3
// 4
// 5
package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

var wg sync.WaitGroup

func worker(ctx context.Context) {
    go worker2(ctx)
LOOP:
    for {
        fmt.Println("worker")
        time.Sleep(time.Second)
        select {
        case <-ctx.Done(): // 等待上级通知
            break LOOP
        default:
        }
    }
    wg.Done()
}

func worker2(ctx context.Context) {
LOOP:
    for {
        fmt.Println("worker2")
        time.Sleep(time.Second)
        select {
        case <-ctx.Done(): // 等待上级通知
            break LOOP
        default:
        }
    }
}

func main() {
    ctx, cancel := context.WithCancel(context.Background())
    wg.Add(1)
    go worker(ctx)
    time.Sleep(time.Second * 3)
    cancel() // 通知子 goroutine 结束
    wg.Wait()
    fmt.Println("over")
}

// worker2
// worker
// worker
// worker2
// worker2
// worker 
// over

WithDeadline() 函数

// 返回父上下文的副本 ctx,完成工作的截止时间(deadline)调整为不迟于 d
// 如果父上下文的 deadline 早于 d,则 WithDeadline(parent, d) 在语义上等同于父上下文
// 当遇到以下三种情况时,返回的上下文 ctx 的 Done 通道将被关闭,以最先发生的情况为准:
// 1. 设置的截止日期 d 过期;2. 调用返回的 cancal 函数;3. 父上下文的 Done 通道被关闭
func WithDeadline(parent Context, d time.Time) (ctx Context, cancal CancelFunc)

⭐️ 示例代码

package main

import (
    "context"
    "fmt"
    "time"
)

func main() {
    // 设置当前上下文 50ms 后过期
    d := time.Now().Add(50 * time.Millisecond)
    ctx, cancel := context.WithDeadline(context.Background(), d)

    // 尽管 ctx 会过期,但在任何情况下调用它的 cancel 函数都是很好的实践。
    // 如果不这样做,可能会使上下文及其父类存活的时间超过必要的时间。
    defer cancel()

    select {
    case <-time.After(1 * time.Second):
        // 等待 1 秒后打印 overslept 退出
        fmt.Println("overslept")
    case <-ctx.Done():
        // 等待 ctx 过期后,退出
        fmt.Println(ctx.Err())
    }
}

// context deadline exceeded

WithTimeout() 函数

// 等同于 WithDeadline(parent, time.Now().Add(timeout))
func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc)

✏️ 示例代码

package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

var wg sync.WaitGroup

func worker(ctx context.Context) {
LOOP:
    for {
        fmt.Println("db connecting...")
        time.Sleep(time.Millisecond * 10) // 假设正常连接数据库耗时 10 毫秒
        select {
        case <-ctx.Done(): // 50ms 后调用
            break LOOP
        default:
        }
    }
    fmt.Println("worker done!")
    wg.Done()
}

func main() {
    // 设置当前上下文 50ms 后过期
    ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
    wg.Add(1)
    go worker(ctx)
    time.Sleep(time.Second * 5)
    cancel() // 通知子 goroutine 结束
    wg.Wait()
    fmt.Println("over")
}

// db connecting...
// db connecting...
// db connecting...
// worker done!
// over

withValue() 函数

// 返回父节点的副本,可传递一个与 key 关联的 val 值
func WithValue(parent Context, key, val interface{}) Context

📚 示例代码

package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

type TraceCode string

var wg sync.WaitGroup

func worker(ctx context.Context) {
    key := TraceCode("TRACE_CODE")
    traceCode, ok := ctx.Value(key).(string) // 在子 goroutine 中获取 trace code
    if !ok {
        fmt.Println("invalid trace code")
    }
LOOP:
    for {
        fmt.Printf("worker, trace code:%s\n", traceCode)
        time.Sleep(time.Millisecond * 10) // 假设正常连接数据库耗时 10 毫秒
        select {
        case <-ctx.Done(): // 50毫秒后自动调用
            break LOOP
        default:
        }
    }
    fmt.Println("worker done!")
    wg.Done()
}

func main() {
    // 创建一个过期时间为 50ms 的上下文
    ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
    // 在系统的入口中设置 trace code 传递给后续启动的 goroutine 实现日志数据聚合
    ctx = context.WithValue(ctx, TraceCode("TRACE_CODE"), "12512312234")
    wg.Add(1)
    go worker(ctx)
    time.Sleep(time.Second * 5)
    cancel() // 通知子goroutine结束
    wg.Wait()
    fmt.Println("over")
}

// worker, trace code:12512312234
// worker, trace code:12512312234
// worker, trace code:12512312234
// worker, trace code:12512312234
// worker, trace code:12512312234
// worker done!
// over

客户端超时取消示例

✌ 服务端代码

package main

import (
    "fmt"
    "math/rand"
    "net/http"
    "time"
)

func main() {
    http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
        random := rand.New(rand.NewSource(time.Now().UnixNano()))
        number := random.Intn(2)
        if number == 0 {
            time.Sleep(time.Second * 10) // 耗时 10 秒的慢响应
            fmt.Fprintf(w, "slow response")
            return
        }
        fmt.Fprintf(w, "quick response") // 快速响应
    })

    err := http.ListenAndServe(":8080", nil)
    if err != nil {
        panic(err)
    }
}

✍ 客户端代码

package main

import (
    "context"
    "fmt"
    "io/ioutil"
    "net/http"
    "sync"
    "time"
)

type respData struct {
    resp *http.Response
    err  error
}

func doCall(ctx context.Context) {
    client := http.Client{
        Transport: &http.Transport{
            // 设置为长连接
            DisableKeepAlives: true,
        },
    }

    respChan := make(chan *respData, 1)
    req, err := http.NewRequest("GET", "http://127.0.0.1:8080/", nil)
    if err != nil {
        fmt.Println("new request failed, err:%v\n", err)
        return
    }
    // 使用带超时的 ctx 创建一个新的 client request
    req = req.WithContext(ctx)

    var wg sync.WaitGroup
    wg.Add(1)
    defer wg.Wait()
    go func() {
        resp, err := client.Do(req)
        fmt.Printf("client.do resp:%v, err:%v\n", resp, err)
        respChan <- &respData{
            resp: resp,
            err:  err,
        }
        wg.Done()
    }()

    select {
    case <-ctx.Done():
        fmt.Println("call api timeout")
    case result := <-respChan:
        fmt.Println("call server api success")
        if result.err != nil {
            fmt.Printf("call server api failed, err:%v\n", result.err)
            return
        }
        defer result.resp.Body.Close()
        data, _ := ioutil.ReadAll(result.resp.Body)
        fmt.Printf("resp:%v\n", string(data))
    }
}

func main() {
    // 将当前 ctx 的超时时间设置为 100ms
    ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
    // 调用cancel释放子goroutine资源
    defer cancel()

    doCall(ctx)
}

// 超时的输出:
// call api timeout
// client.do resp:<nil>, err:Get "http://127.0.0.1:8080/": context deadline exceeded

// 不超时的输出:
// client.do resp:&{200 OK 200 HTTP/1.1 1 1 map[Content-Length:[14] Content-Type:[text/plain; charset=utf-8] Date:[Sun, 22 May 2022 07:50:29 GMT]] 0xc000206080 14 [] true false map[] 0xc00013a100 <nil>}, err:<nil> call server api success
// resp:quick response

参考

  1. Go标准库Context
posted @ 2022-05-22 15:54  呵呵233  阅读(143)  评论(0编辑  收藏  举报