contextvars:上下文变量管理

楔子

Python 在 3.7 的时候引入了一个模块:contextvars,从名字上很容易看出它指的是 "上下文变量(Context Variables)",所以在介绍 contextvars 之前我们需要先了解一下什么是 "上下文(Context)"。

Context 就是一个包含了相关信息内容的对象,举个栗子:"比如一部 13 集的动漫,你直接点进第八集,看到女主角在男主角面前流泪了"。相信此时你是不知道为什么女主角会流泪的,因为你没有看前面几集的内容,所以缺失了相关的上下文信息。

所以 Context 并不是什么神奇的东西,它的作用就是携带一些指定的上下文。

web 框架中的 request

我们以 fastapi 和 sanic 为例,看看当一个请求过来的时候,它们是如何解析的。

# fastapi
from fastapi import FastAPI, Request
import uvicorn

app = FastAPI()


@app.get("/index")
async def index(request: Request):
    name = request.query_params.get("name")
    return {"name": name}


uvicorn.run("__main__:app", host="127.0.0.1", port=5555)

# -------------------------------------------------------

# sanic
from sanic import Sanic
from sanic.request import Request
from sanic import response

app = Sanic("sanic")


@app.get("/index")
async def index(request: Request):
    name = request.args.get("name")
    return response.json({"name": name})


app.run(host="127.0.0.1", port=6666)

发请求测试一下,看看结果是否正确。

可以看到请求都是成功的,但是我们看到对于 fastapi 和 sanic 而言,其 request 和 视图函数是绑定在一起的。也就是在请求到来的时候,会被封装成一个 Request 对象、然后传递到视图函数中。

但是对于 flask 而言则不是这样子的,我们看一下 flask 是如何接收请求参数的。

from flask import Flask, request

app = Flask("flask")


@app.route("/index")
def index():
    name = request.args.get("name")
    return {"name": name}


app.run(host="127.0.0.1", port=7777)

我们看到对于 flask 而言则是通过 import 的方式,如果不需要的话就不用 import,当然我这里并不是在比较哪种方式好,主要是为引出我们今天的主题。首先对于 flask 而言,如果我再定义一个视图函数的话,那么获取请求参数依旧是相同的方式,但是这样问题就来了,不同的视图函数内部使用同一个 request,难道不会发生冲突吗?

答案是不会的,至于原因,就是 ThreadLocal。

ThreadLocal

ThreadLocal,从名字上看可以看出它肯定是和线程相关的。没错,它是专门用来创建局部变量的,并且创建的局部变量是和线程进行绑定的。

import threading


# 创建一个 local 对象
local = threading.local()


def get():
    name = threading.current_thread().name
    # 获取绑定在 local 上的 value
    value = local.value
    print(f"线程 {name}, value: {value}")


def set_():
    name = threading.current_thread().name
    # 为不同的线程设置不同的值
    if name == "one":
        local.value = "ONE"
    elif name == "two":
        local.value = "TWO"
    # 执行 get 函数
    get()


t1 = threading.Thread(target=set_, name="one")
t2 = threading.Thread(target=set_, name="two")
t1.start()
t2.start()
"""
线程 one, value: ONE
线程 two, value: TWO
"""

可以看到两个线程之间是互不影响的,因为每个线程都有自己唯一的 id,在绑定值的时候会绑定在当前相应的线程中,获取也会从当前相应的线程中获取。可以把 ThreadLocal 想象成一个字典:

{
    "one": {"value": "ONE"},
    "two": {"value": "TWO"}
}

更准确的说 key 应该是线程的 id,但为了直观我们就用线程的 name 代替了,但总之在获取的时候只会获取绑定在该线程上的变量的值。而 flask 内部也是这么设计的,只不过它没有直接用 threading.local,而是自己实现了一个 Local 类,除了支持线程之外还支持 greenlet 的协程。那么它是怎么实现的呢?首先我们知道 flask 内部分为 "请求 context" 和 "应用 context",它们都是通过栈来维护的(两个不同的栈)。

# flask/globals.py
_request_ctx_stack = LocalStack()
_app_ctx_stack = LocalStack()
current_app = LocalProxy(_find_app)
request = LocalProxy(partial(_lookup_req_object, "request"))
session = LocalProxy(partial(_lookup_req_object, "session"))

每个请求都会绑定在当前的 Context 中,等到请求结束之后再销毁,这个过程由框架完成,开发者只需要直接使用 request 即可。所以请求的具体细节流程可以点进源码中查看,这里我们重点关注一个对象:werkzeug.local.Local,也就是上面说的 Local 类,它是变量的设置和获取的关键。直接看部分源码:

# werkzeug/local.py

class Local(object):
    __slots__ = ("__storage__", "__ident_func__")

    def __init__(self):
        # 内部有两个成员:__storage__ 是一个字典,值就存在这里里面
        # __ident_func__ 只需要知道它是用来获取线程 id 的即可
        object.__setattr__(self, "__storage__", {})
        object.__setattr__(self, "__ident_func__", get_ident)

    def __call__(self, proxy):
        """Create a proxy for a name."""
        return LocalProxy(self, proxy)

    def __release_local__(self):
        self.__storage__.pop(self.__ident_func__(), None)

    def __getattr__(self, name):
        try:
            # 根据线程 id 得到 value(一个字典),然后再根据 name 获取对应的值
            # 所以只会获取绑定在当前线程上的值
            return self.__storage__[self.__ident_func__()][name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        ident = self.__ident_func__()
        storage = self.__storage__
        try:
            # 将线程 id 作为 key,然后将值设置在对应的字典中
            # 所以只会将值设置在当前的线程中
            storage[ident][name] = value
        except KeyError:
            storage[ident] = {name: value}

    def __delattr__(self, name):
        # 删除逻辑也很简单
        try:
            del self.__storage__[self.__ident_func__()][name]
        except KeyError:
            raise AttributeError(name)

所以我们看到 flask 内部的逻辑其实很简单,通过 ThreadLocal 实现了线程之间的隔离,每个请求都会绑定在各自的 Context 中,获取值的时候也会从各自的 Context 中获取,因为它就是用来保存相关信息的(重要的是同时也实现了隔离)。

相应此刻你已经理解了上下文,但是问题来了,不管是 threading.local 也好、还是类似于 flask 自己实现的 Local 也罢,它们都是针对线程的。如果是使用 async 定义的协程该怎么办呢?如何实现每个协程的上下文隔离呢?所以终于引出了我们的主角:contextvars。

contextvars

该模块提供了一组接口,可用于在协程中管理、设置、访问局部 Context 的状态。

import asyncio
import contextvars

context_var = contextvars.ContextVar("只是一个标识, 用于调试")


async def get():
    # 获取值
    return context_var.get() + "~~~"


async def set_(val):
    # 设置值
    context_var.set(val)
    print(await get())


async def main():
    coro1 = set_("协程1")
    coro2 = set_("协程2")
    await asyncio.gather(coro1, coro2)


asyncio.run(main())
"""
协程1~~~
协程2~~~
"""

我们看到和 threading.Local 是类似的用法,数据在协程之间是隔离的,不会受到彼此的影响。但是我们再仔细观察一下,我们是在 set_ 中设置的值,然后在 get 中获取值。可 await get() 相当于是开启了一个新的协程,那么意味着设置值和获取值不是在同一个协程当中。但即便如此,我们依旧可以获取到希望的结果,那么是不是意味着可以嵌套多层呢?

import asyncio
import contextvars

context_var = contextvars.ContextVar("只是一个标识, 用于调试")


async def get2():
    # 获取值
    return context_var.get() + "~~~"


async def get1():
    return await get2()


async def set_(index, val):
    # 设置值
    context_var.set(val)
    print(index, await get1())
    print(index, await get2())


async def main():
    coro1 = set_(1, "协程1")
    coro2 = set_(2, "协程2")
    await asyncio.gather(coro1, coro2)


asyncio.run(main())
"""
1 协程1~~~
1 协程1~~~
2 协程2~~~
2 协程2~~~
"""

我们看到不管是 await get1() 还是 await get2(),得到的都是 set_ 中设置的结果,说明它是可以嵌套的。但是问题又来了,如果我们在 get1 中重新设置值的话,会有什么影响呢?

import asyncio
import contextvars

context_var = contextvars.ContextVar("只是一个标识, 用于调试")


async def get2():
    # 获取值
    return context_var.get() + "~~~"


async def get1():
    # 重新设置值
    context_var.set("重新设置")
    return await get2()


async def set_(index, val):
    # 设置值
    context_var.set(val)
    print(index, await get2())
    print(index, await get1())
    print(index, await get2())


async def main():
    coro1 = set_(1, "协程1")
    coro2 = set_(2, "协程2")
    await asyncio.gather(coro1, coro2)


asyncio.run(main())
"""
1 协程1~~~
1 重新设置~~~
1 重新设置~~~
2 协程2~~~
2 重新设置~~~
2 重新设置~~~
"""

我们看到先 await get2() 得到的就是 set_ 中设置的值,这是符合预期的;但是我们在 get1 中将值重新设置了,那么之后不管是 await get1() 还是直接 await get2(),得到的都是新设置的值。这也说明了,一个协程内部 await 另一个协程,在另一个协程内部在 await 另另一个协程,不管套娃(await)多少次,它们获取的值都是一样的。并且在任意一个协程内部都可以进行设置,然后获取会得到最后一次设置的值。再举个栗子:

import asyncio
import contextvars

context_var = contextvars.ContextVar("只是一个标识, 用于调试")


async def get2():
    # 获取值
    val = context_var.get() + "~~~"
    # 重新设置
    context_var.set("重新设置啦")
    return val


async def get1():
    return await get2()


async def set_(val):
    # 设置值
    context_var.set(val)
    print(await get1())
    # 获取值
    print(context_var.get())


async def main():
    coro1 = set_("夏色祭")
    await coro1


asyncio.run(main())
"""
夏色祭~~~
重新设置啦
"""

完全符合预期,首先 set_ 中设置值,然后在 get2 中获取,最终打印 "夏色祭~~~"。但是在 get2 中又把值重新设置了,于是在 set_ 中又获取到了 get2 中设置的值。

如果 context_var 在 get 之前没有先 set,那么会抛出一个 LookupError,所以 contextvar.ContextVar 支持默认值:

import asyncio
import contextvars

context_var = contextvars.ContextVar("只是一个标识, 用于调试", default="哼哼")


async def get():
    # 获取值
    return context_var.get() + "~~~"


async def set_(val):
    #context_var.set(val)
    print(await get())


async def main():
    coro1 = set_("协程1")
    coro2 = set_("协程2")
    await asyncio.gather(coro1, coro2)


asyncio.run(main())
"""
哼哼~~~
哼哼~~~
"""

除了在 ContextVar 中指定默认值之外,也可以在 get 中指定:

import asyncio
import contextvars

context_var = contextvars.ContextVar("只是一个标识, 用于调试", default="哼哼")


async def get():
    # 获取值
    return context_var.get("哈哈") + "~~~"


async def set_(val):
    #context_var.set(val)
    print(await get())


async def main():
    coro1 = set_("协程1")
    coro2 = set_("协程2")
    await asyncio.gather(coro1, coro2)


asyncio.run(main())
"""
哈哈~~~
哈哈~~~
"""

所以结论如下,如果在 context_var.set 之前使用 context_var.get:

  • 当 ContextVar 和 get 中都没有指定默认值,会抛出 LookupError
  • 只要有一方设置了,那么会得到默认值
  • 如果都设置了,那么以 get 为准

如果 context_var.get 之前执行了 context_var.set,那么无论 ContextVar 和 get 有没有指定默认值,获取到的都是 context_var.set 设置的值。

所以总的来说还是比较好理解的,并且 context.ContextVar 除了可以作用在协程上面,它也可以用在线程上面。没错,它可以替代 threading.local,我们来试一下:

import threading
import contextvars

context_var = contextvars.ContextVar("context_var")


def get():
    name = threading.current_thread().name
    value = context_var.get()
    print(f"线程 {name}, value: {value}")


def set_():
    name = threading.current_thread().name
    if name == "one":
        context_var.set("ONE")
    elif name == "two":
        context_var.set("TWO")
    get()


t1 = threading.Thread(target=set_, name="one")
t2 = threading.Thread(target=set_, name="two")
t1.start()
t2.start()
"""
线程 one, value: ONE
线程 two, value: TWO
"""

和 threading.local 的表现是一样的,并且在高并发环境中更建议使用 contextvars.ContextVars。

context_var.Token

当我们调用 context_var.set 的时候,其实会返回一个 contextvars.Token 对象:

import contextvars

context_var = contextvars.ContextVar("context_var")
token = context_var.set("val")
print(token)  # <Token var=<ContextVar name='context_var' at 0x000002038EA6E4F0> at 0x000002038EA97FC0>

Token 对象有一个 var 属性,它是只读的,会返回指向此 token 的 ContextVar 对象。

import contextvars


context_var = contextvars.ContextVar("context_var")
token = context_var.set("val")

print(token.var)  # <ContextVar name='context_var' at 0x000002038EA6E4F0>
print(token.var is context_var)  # True
print(token.var.get())  # val

print(token.var.set("val2").var.set("val3").var is context_var)  # True
print(context_var.get())  # val3

Token 对象还有一个 old_value 属性,它会返回上一次 set 设置的值,如果是第一次 set,那么会返回一个 <Token.MISSING>。

import contextvars


context_var = contextvars.ContextVar("context_var")
token = context_var.set("val")

# 该 token 是 context_var 第一次 set 所返回的,在此之前没有 set,所以 old_value 是 <Token.MISSING>
print(token.old_value)  # <Token.MISSING>

token = context_var.set("val1")
# 返回上一次 set 的值
print(token.old_value)  # val

那么这个 contextvars.Token 有什么作用呢?从目前来看貌似没太大用处,其实它最大的用处就是和 reset 搭配使用,可以对状态进行重置。

import contextvars

context_var = contextvars.ContextVar("context_var")
token = context_var.set("val")
# 显然是可以获取的
print(context_var.get())  # val

# 将其重置为 token 之前的状态,token 之前意味着是没有进行 set 的
# 因为这个 token 是第一次 set 返回的,那么之前就相当于没有 set 了
context_var.reset(token)
try:
    context_var.get()  # 此时就会报错
except LookupError:
    print("报错啦")  # 报错啦

# 但是我们可以指定默认值
print(context_var.get("默认值"))  # 默认值

contextvars.Context

它负责保存 ContextVars 对象和设置的值之间的映射,但是我们不会直接通过 contextvars.Context 来创建,而是通过 contentvars.copy_context 函数来创建。

import contextvars

context_var1 = contextvars.ContextVar("context_var1")
context_var1.set("val1")
context_var2 = contextvars.ContextVar("context_var2")
context_var2.set("val2")

# 此时得到的是所有 ContextVar 对象和设置的值之间的映射,它实现了 collections.abc.Mapping 接口
# 因此我们可以像操作字典一样操作它
context = contextvars.copy_context()  
# key 就是对应的 ContextVar 对象,value 就是设置的值
print(context[context_var1])  # val1
for ctx, value in context.items():
    print(ctx.get(), ctx.name, value)
    """
    val1 context_var1 val1
    val2 context_var2 val2
    """

print(len(context))  # 2

除此之外,context 还有一个 run 方法:

import contextvars

context_var1 = contextvars.ContextVar("context_var1")
context_var1.set("val1")
context_var2 = contextvars.ContextVar("context_var2")
context_var2.set("val2")

context = contextvars.copy_context()


def change(val1, val2):
    context_var1.set(val1)
    context_var2.set(val2)
    print(context_var1.get(), context[context_var1])
    print(context_var2.get(), context[context_var2])

context.run(change, "VAL1", "VAL2")
"""
VAL1 VAL1
VAL2 VAL2
"""
print(context_var1.get(), context[context_var1])  # val1 VAL1
print(context_var2.get(), context[context_var2])  # val2 VAL2

我们看到 run 方法接收一个 callable,如果在里面修改了 ContextVars 设置的值,那么对于 ContextVars 而言只会在函数内部生效,一旦出了函数,那么还是原来的值。但是对于 Context 而言,它是会受到影响的,即便出了函数,也是新设置的值。

小结

总的来说 contextvars 还是非常有用的,尤其是在多个协程之间进行变量传递的时候。这让我想起了 Go 在 1.7 版本时引入的 context 包,同样是对多个 goroutine 进行级联管理提供了非常清蒸的解决方案。

posted @ 2019-07-27 01:43  古明地盆  阅读(2070)  评论(0编辑  收藏  举报