typing模块中Protocol协议的使用

说明

在 Python 的 typing 模块中,Protocol 是一个用于定义协议(Protocol)的类。
协议是一种形式化的接口,定义了一组方法或属性的规范,而不关心具体的实现。Protocol 类提供了一种方式来定义这些协议。

使用 Protocol 类时,可以定义一个类,继承自 Protocol 并在类中定义需要的方法或属性。
这样,通过继承 Protocol,可以告诉静态类型检查器,该类遵循了特定的协议。

有点类似go语言中的interface,但又有所不同,感觉Protocol只是为了解决静态类型检查的问题

示例

from typing import Protocol

# 理解为定义接口及接口中的方法
class Animal(Protocol):
    def speak(self) -> str:
        pass

    def eat(self) -> str:
        pass

# 实现类,dog实现了接口中的全部方法
class Dog:
    def speak(self) -> str:
        return "Woof!"

    def eat(self) -> str:
        return "Dog is eating hotdog"

# 实现类,但是cat只实现了接口中的一个方法
class Cat:
    def speak(self) -> str:
        return "Meow!"

# 参数为接口类型
def make_sound(animal: Animal) -> str:
    return animal.speak()


dog = Dog()
cat = Cat()

# 如果单独运行,是没有问题的,所以你需要用mypy检查工具运行该代码
print(make_sound(dog))  # Output: Woof!
print(make_sound(cat))  # Output: Meow!

# mypy类型检查会提示如下报错,表示Cat类没有实现接口中的eat方法
part3.py:33: error: Argument 1 to "make_sound" has incompatible type "Cat"; expected "Animal"  [arg-type]
part3.py:33: note: "Cat" is missing following "Animal" protocol member:
part3.py:33: note:     eat

示例2

from typing import Protocol, Any, TypeVar, TYPE_CHECKING
from collections.abc import Iterable

from typing_extensions import reveal_type

# 定义接口,需要实现可以比较的__lt__方法
class SupportsLessThan(Protocol):
    def __lt__(self, other: Any) -> bool: ...


LT = TypeVar('LT', bound=SupportsLessThan) # 表示泛型上限为SupportsLessThan


def top(series: Iterable[LT], length: int) -> list[LT]: # 返回值也可以用LT,因为list也实现了__lt__方法
    ordered = sorted(series, reverse=True)
    return ordered[:length]


if __name__ == '__main__':
    fruit = 'mango pear apple kiwi banana'.split()
    # tuple实现了__lt__方法
    series: Iterable[tuple[int, str]] = (
        (len(s), s) for s in fruit
    )
    length = 3
    expected = [(6, 'banana'), (5, 'mango'), (5, 'apple')]
    result = top(series, length) # 所以可以将series传递到top中
    TYPE_CHECKING = True
    if TYPE_CHECKING:
        reveal_type(series)
        reveal_type(expected)
        reveal_type(result)
    print(result == expected)

posted @ 2023-06-25 03:50  我在路上回头看  阅读(752)  评论(0编辑  收藏  举报