代码改变世界

如何在 Python 中使用 generators 和 yield

  abce  阅读(46)  评论(0编辑  收藏  举报

 

是否曾经需要处理一个大到足以耗尽机器内存的数据集?或者有一个复杂的函数,每次调用时都需要维护内部状态,但这个函数太小,不适合创建自己的类。在这些情况以及更多情况下,Generators 和 yield 语句都能帮上忙。

 

使用 generator

generator 函数是一种返回懒惰迭代器的特殊函数。generator 对象可以像列表一样循环使用。然而,与列表不同的是,懒惰迭代器不会将其内容存储在内存中。

 

generator示例1:读取大文件

generator 的一个常见用例是处理数据流或大文件,如 CSV 文件。这些文本文件使用逗号将数据分栏。这种格式是共享数据的常用方式。现在,如果要计算 CSV 文件中的行数怎么办?下面的代码块展示了一种计算行数的方法:

1
2
3
4
5
6
7
csv_gen = csv_reader("some_csv.txt")
row_count = 0
 
for row in csv_gen:
    row_count += 1
 
print(f"Row count is : {row_count}")

这里你可以认为 csv_gen 是一个列表。为了填充这个列表,csv_reader() 会打开一个文件并将其内容载入 csv_gen。然后,程序会遍历该列表,递增 row_count 计数。

这是一个合理的解释,但如果文件非常大,这种设计还能起作用吗?如果文件比可用内存还大呢?为了回答这个问题,我们假设 csv_reader()只是打开文件并将其读入一个数组:

1
2
3
4
def csv_reader(file_name):
    file = open(file_name)
    result = file.read().split("\n") #按行分割
    return result

函数打开一个给定的文件,并使用 file.read() 和 .split() 将每一行作为一个单独的元素添加到一个列表中。如果在前面看到的行计数代码块中使用此版本的 csv_reader(),就会得到以下输出结果:

1
2
3
4
5
6
7
8
Traceback (most recent call last):
  File "ex1_naive.py", line 22, in <module>
    main()
  File "ex1_naive.py", line 13, in main
    csv_gen = csv_reader("file.txt")
  File "ex1_naive.py", line 6, in csv_reader
    result = file.read().split("\n")
MemoryError

在这种情况下,如果 open() 会返回一个 generator 对象,你可以懒洋洋地逐行遍历。但是,file.read().split() 会一次性将所有内容加载到内存中,从而导致内存错误。

 

在此之前,你可能会发现电脑运行缓慢。甚至可能需要使用键盘中断(KeyboardInterrupt)来杀死程序。那么,如何处理这些庞大的数据文件呢?来看看 csv_reader() 的新定义:

1
2
3
def csv_reader(file_name):
    for row in open(file_name, "r"):
        yield row

在此代码版本中,打开文件,遍历文件,然后生成一行。该代码应产生以下输出,且不会出现内存错误

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def csv_reader(file_name):
    for row in open(file_name, "r"):
        yield row
 
 
csv_gen = csv_reader("/tmp/e.log")
 
row_count = 0
 
for row in csv_gen:
    row_count += 1
 
print(f"Row count is : {row_count}")
 
# 运行结果
Row count is : 406242721

这里把 csv_reader() 变成了一个 generator 函数。这个版本打开一个文件,循环查看每一行,然后生成每一行,而不是返回文件。

 

你还可以定义一个 generator 表达式,它的语法与列表表达式非常相似。这样,就可以在不调用函数的情况下使用 generator:

1
csv_gen = (row for row in open(file_name))

 

generator示例 2:生成无限序列

让我们换个角度来看看无穷序列的生成。在 Python 中,要获得有限序列,需要调用 range() 并在列表上下文中对其进行评估:

1
2
3
>>> a = range(5)
>>> list(a)
[0, 1, 2, 3, 4]

然而生成一个无限序列需要使用 generator,因为你的电脑的内存是有限的:

1
2
3
4
5
def infinite_sequence():
    num = 0
    while True:
        yield num
        num += 1

该代码块简短而精炼。首先,初始化变量 num 并启动一个无限循环。然后,立即 yeild num,以便捕捉初始状态。这模仿了 range() 的操作。

 

在 yield 之后,num 会递增 1。如果你用 for 循环尝试一下,就会发现它看起来确实是无限的:

1
2
3
4
5
6
7
8
9
10
11
12
>>> for i in infinite_sequence():
...     print(i, end=" ")
...
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
30 31 32 33 34 35 36 37 38 39 40 41 42
[...]
6157818 6157819 6157820 6157821 6157822 6157823 6157824 6157825 6157826 6157827
6157828 6157829 6157830 6157831 6157832 6157833 6157834 6157835 6157836 6157837
6157838 6157839 6157840 6157841 6157842
Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
KeyboardInterrupt

程序将继续执行,直到手动停止。

 

也可以直接在 generator 对象上调用 next() 来代替 for 循环。这对于在控制台中测试信号发生器特别有用:

1
2
3
4
5
6
7
8
9
>>> gen = infinite_sequence()
>>> next(gen)
0
>>> next(gen)
1
>>> next(gen)
2
>>> next(gen)
3

在这里,你有一个名为 gen 的 generator ,你可以通过重复调用 next() 来手动遍历该 generator 。这可以作为一个很好的正确性检查,以确保 generator 产生了所期望的输出。

 

generator示例 3:检测回文

无限序列可以用在很多方面,但其中一个实际用途是构建回文检测器。回文检测器会找出所有回文字母或数字序列。这些单词或数字的正向和反向读法相同,就像 121。首先,定义数字回文检测器:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def is_palindrome(num):
    # Skip single-digit inputs
    if num // 10 == 0:
        return False
    temp = num
    reversed_num = 0
 
    while temp != 0:
        reversed_num = (reversed_num * 10) + (temp % 10)
        temp = temp // 10
 
    if num == reversed_num:
        return num
    else:
        return False

不用太在意如何理解这段代码中的基本数学知识。只需注意,该函数接收一个输入数字,将其反转,并检查反转后的数字是否与原始数字相同。现在,你可以使用无限序列 generator 来获取所有数字回文的运行列表:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
>>> for i in infinite_sequence():
...     pal = is_palindrome(i)
...     if pal:
...         print(i)
...
11
22
33
[...]
99799
99899
99999
100001
101101
102201
Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
  File "<stdin>", line 5, in is_palindrome
KeyboardInterrupt

注:实际上,你不太可能编写自己的无限序列 generator 。itertools 模块通过 itertools.count() 提供了一个非常高效的无限序列 generator 。

 

现在你已经看到了无限序列生成器的一个简单用例,让我们深入了解生成器是如何工作的。

 

理解generator

到目前为止,已经了解了创建generator的两种主要方法:使用generator函数和generator表达式。

generator函数的外观和行为与普通函数一样,但有一个显著特点。generator函数使用 Python 的 yield 关键字来代替 return。回想一下之前编写的generator函数:

1
2
3
4
5
def infinite_sequence():
    num = 0
    while True:
        yield num
        num += 1

这看起来像一个典型的函数定义,除了 Python 的 yield 语句和它后面的代码。yield 表示将一个值发送回调用者,但与 return 不同的是,你不会在之后退出函数。

 

相反,函数的状态会被记住。这样,当在generator对象上调用 next() 时(无论是显式调用还是在 for 循环中隐式调用),先前yield的变量 num 会递增,然后再次yield。由于generator函数看起来像其他函数,其行为也与其他函数非常相似,因此可以认为generator表达式与 Python 中的其他表达式非常相似。

 

使用generator表达式创建generator

与列表表达式一样,generator表达式允许你用几行代码快速创建一个generator对象。generator表达式很有用,还有一个额外的好处:在迭代之前,创建generator表达式无需在内存中构建和保留整个对象。换句话说,使用generator表达式不会占用内存。请看下面这个数字平方的例子:

1
2
3
4
5
6
7
8
9
nums_squared_lc = [num**2 for num in range(5)]
nums_squared_gc = (num**2 for num in range(5))
 
print(nums_squared_lc)
print(nums_squared_gc)
 
# 输出
[0, 1, 4, 9, 16]
<generator object <genexpr> at 0x000001F43EB64AC0>

nums_squared_lc 和 nums_squared_gc 看起来基本相同,但有一个关键区别。发现了吗?第一个构建了一个列表,第二个通过括号生成了一个generator表达式,输出结果也表明生成了一个generator对象。

剖析generator的性能

前面已经了解到,generator是优化内存的一种好方法。虽然无限序列generator是这种优化的一个极端例子,但让我们把刚才看到的数字平方例子放大,并检查生成对象的大小。你可以调用 sys.getsizeof() 来做到这一点:

1
2
3
4
5
6
7
>>> import sys
>>> nums_squared_lc = [i ** 2 for i in range(10000)]
>>> sys.getsizeof(nums_squared_lc)
85176
>>> nums_squared_gc = (i ** 2 for i in range(10000))
>>> print(sys.getsizeof(nums_squared_gc))
112

在本例中,从列表表达式中得到的列表是 85176 字节,而generator对象只有 112 字节。这意味着列表比generator对象大 760 多倍!

 

不过有一点需要注意。如果列表小于运行机器的可用内存,那么列表表达式的运算速度会比等效的generator表达式更快。为了探讨这个问题,让我们对上面两个综合的结果进行求和。可以使用 cProfile.run() 生成读出结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
>>> import cProfile
>>> cProfile.run('sum([i * 2 for i in range(10000)])')
         5 function calls in 0.001 seconds
 
   Ordered by: standard name
 
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.000    0.000 <string>:1(<listcomp>)
        1    0.000    0.000    0.001    0.001 <string>:1(<module>)
        1    0.000    0.000    0.001    0.001 {built-in method builtins.exec}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.sum}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
 
 
>>> cProfile.run('sum((i * 2 for i in range(10000)))')
         10005 function calls in 0.002 seconds
 
   Ordered by: standard name
 
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    10001    0.001    0.000    0.001    0.000 <string>:1(<genexpr>)
        1    0.000    0.000    0.002    0.002 <string>:1(<module>)
        1    0.000    0.000    0.002    0.002 {built-in method builtins.exec}
        1    0.001    0.001    0.002    0.002 {built-in method builtins.sum}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
 
 
>>>

从这里可以看出,对列表表达式中的所有值求和所需的时间约为对generator求和所需的时间的三分之一。如果速度是个问题,而内存不是,那么列表表达式可能是更好的工具。

 

注意:这些测量结果不仅适用于使用generator表达式生成的对象。它们也同样适用于使用类似generator函数生成的对象,因为generator的结果是等价的。

 

请记住,列表表达式返回完整列表,而generator表达式返回generator。无论使用函数还是表达式,generator的工作原理都是一样的。使用表达式只需在一行中定义简单的generator,并在每次内部迭代结束时设定 yield。

 

Python 的 yield 语句无疑是generator所有功能的关键所在,让我们深入了解一下 Python 中 yield 的工作原理。

 

理解 Python Yield 语句

总的来说,yield 是一个相当简单的语句。它的主要作用是以一种类似于 return 语句的方式控制generator函数的流程。正如上面提到的,Python yield 语句有一些小技巧。

调用generator函数或使用generator表达式时,会返回一个特殊的迭代器,称为generator。可以将generator赋值给一个变量来使用它。当调用generator上的特殊方法(如 next())时,函数内的代码会一直执行到 yield。

当 Python yield 语句被执行时,程序会暂停函数的执行,并将 yield 值返回给调用者。(相反,return 会完全停止函数的执行。)当函数被暂停时,函数的状态会被保存。这包括generator本地的任何变量绑定、指令指针、内部堆栈和任何异常处理。

 

这样,只要调用generator的某个方法,就能恢复函数的执行。这样,所有函数的评估都会在 yield 结束后立即恢复。可以通过使用多个 Python yield 语句来了解这种方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
>>> def multi_yield():
...     yield_str = "This will print the first string"
...     yield yield_str
...     yield_str = "This will print the second string"
...     yield yield_str
...
>>> multi_obj = multi_yield()
>>> print(next(multi_obj))
This will print the first string
>>> print(next(multi_obj))
This will print the second string
>>> print(next(multi_obj))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
StopIteration

仔细看最后一次调用 next()。你可以看到执行过程中出现了Traceback。这是因为generator和所有迭代器一样,可能会被耗尽。除非generator是无限的,否则只能遍历一次。一旦评估完所有值,迭代就会停止,for 循环也会退出。如果使用 next(),则会出现明确的 StopIteration 异常。

 

注意:StopIteration 是一种自然异常,用于提示迭代器的结束。你甚至可以使用 while 循环来实现自己的 for 循环:

1
2
3
4
5
6
7
8
9
10
11
12
13
>>> letters = ["a", "b", "c", "y"]
>>> it = iter(letters)
>>> while True:
...     try:
...         letter = next(it)
...     except StopIteration:
...         break
...     print(letter)
...
a
b
c
y

yield 可以通过多种方式来控制generator的执行流。在你的创造力允许范围内,可以使用多个 Python yield 语句。

 

使用高级的generator方法

除了 yield 之外,generator对象还可以使用以下方法:

1
2
3
.send()
.throw()
.close()

如何使用send()

下面将写一个使用上面提到的三种方法的程序。这个程序会像之前一样打印数字回文,但会做一些调整。在遇到一个回文时,新程序会添加一个数字,然后开始搜索下一个数字。还将使用 .throw() 来处理异常,并使用 .close() 在输入一定数量的数字后停止generator。首先回顾一下回文检测器的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def is_palindrome(num):
    # Skip single-digit inputs
    if num // 10 == 0:
        return False
    temp = num
    reversed_num = 0
 
    while temp != 0:
        reversed_num = (reversed_num * 10) + (temp % 10)
        temp = temp // 10
 
    if num == reversed_num:
        return True
    else:
        return False

这个程序和之前的代码差不多,只是将程序的返回设定成了True或False。无限序列generator的代码也要做出调整:

1
2
3
4
5
6
7
8
def infinite_palindromes():
    num = 0
    while True:
        if is_palindrome(num):
            i = (yield num)
            if i is not None:
                num = i
        num += 1

这里有很多改动!第一个变化出现在第 5 行,即 i = (yield num)。虽然在前面学到 yield 是一个语句,但这并不是全部。

 

从 Python 2.5 开始,yield 是一个表达式,而不是语句。当然,仍然可以将它用作语句。但现在,也可以像上面代码块中看到的那样使用它,其中 i 取值为 yield 的值。这样,就可以对产生的值进行操作。更重要的是,它允许你向generator发送一个值。当 yield 结束后重新开始执行时,i 将获取发送的值。

 

还要检查 i 是否为 None,如果在generator对象上调用 next(),可能会出现这种情况。(如果 i 有一个值,就用新值更新 num。但无论 i 是否有值,都会递增 num 并再次开始循环)

 

现在来看一下主函数代码,它将最低的数字与另一个数字一起发送回generator。例如,如果回文是 121,那么它将 .send() 1000:

1
2
3
4
pal_gen = infinite_palindromes()
for i in pal_gen:
    digits = len(str(i))
    pal_gen.send(10 ** (digits))

通过这段代码创建generator对象并遍历它。只有在找到一个回文字符串时,程序才会产生一个值。它使用 len() 来确定该回文的位数。然后,向generator发送 10 ** 位数字。由于 i 现在有了一个值,程序会更新 num、递增并再次检查回文。

 

一旦代码找到并产生另一个回文,就会通过 for 循环进行迭代。这与使用 next() 进行迭代是一样的。在第 5 行,代码生成器也使用 i = (yield num)。不过,现在 i 是 None,因为你没有明确发送一个值。

 

这里创建的是一个 coroutine 或generator函数,可以向其中传递数据。这些函数对于构建数据管道非常有用,但正如你很快就会看到的,它们并不是构建管道所必需的。

 

在学习了 .send() 之后,让我们来看看 .throw()。

如何使用throw()

throw() 支持使用generator抛出异常。在下面的示例中,将在第 6 行引发异常。一旦digits达到 5,这段代码就会抛出 ValueError:

1
2
3
4
5
6
7
pal_gen = infinite_palindromes()
for i in pal_gen:
    print(i)
    digits = len(str(i))
    if digits == 5:
        pal_gen.throw(ValueError("We don't like large palindromes"))
    pal_gen.send(10 ** (digits))

这与之前的代码相同,但现在要检查 digits 是否等于 5,如果是,则 .throw() 一个 ValueError。要确认代码是否按预期运行,请查看代码的输出结果:

1
2
3
4
5
6
7
8
9
10
11
12
11
111
1111
10101
Traceback (most recent call last):
  File "advanced_gen.py", line 47, in <module>
    main()
  File "advanced_gen.py", line 41, in main
    pal_gen.throw(ValueError("We don't like large palindromes"))
  File "advanced_gen.py", line 26, in infinite_palindromes
    i = (yield num)
ValueError: We don't like large palindromes

.throw() 在需要捕获异常时非常有用。在本例中,使用 .throw() 来控制何时停止迭代generator。使用 .close() 可以更优雅地实现这一目的。

如何使用 close()

顾名思义,close() 允许停止generator。这在控制无限序列generator时尤其方便。让我们更新上面的代码,将 .throw() 改为 .close() 以停止迭代:

1
2
3
4
5
6
7
pal_gen = infinite_palindromes()
for i in pal_gen:
    print(i)
    digits = len(str(i))
    if digits == 5:
        pal_gen.close()
    pal_gen.send(10 ** (digits))

在第 6 行中使用 .close() 代替调用 .throw() 。使用 .close() 的好处是,它会引发 StopIteration 异常,这是一个用于提示有限迭代器结束的异常:

1
2
3
4
5
6
7
8
9
10
11
111
1111
10101
Traceback (most recent call last):
  File "advanced_gen.py", line 46, in <module>
    main()
  File "advanced_gen.py", line 42, in main
    pal_gen.send(10 ** (digits))
StopIteration

下面我们就来谈谈如何使用generator来构建数据管道。

使用generator创建数据管道

数据管道支持将代码串联起来处理大型数据集或数据流,而不会耗尽机器的内存。想象一下,你有一个大型 CSV 文件,其第一行如下:

1
2
3
4
5
6
permalink,company,numEmps,category,city,state,fundedDate,raisedAmt,raisedCurrenc y,round
digg,Digg,60,web,San Francisco,CA,1-Dec-06,8500000,USD,b
digg,Digg,60,web,San Francisco,CA,1-Oct-05,2800000,USD,a
facebook,Facebook,450,web,Palo Alto,CA,1-Sep-04,500000,USD,angel
facebook,Facebook,450,web,Palo Alto,CA,1-May-05,12700000,USD,a
photobucket,Photobucket,60,web,Palo Alto,CA,1-Mar-05,3000000,USD,a

本示例取自 TechCrunch 美国大陆数据集,该数据集描述了美国各种初创企业的融资轮次和金额。

 

是时候用 Python 进行一些处理了!为了演示如何使用generator构建流水线,分析该文件,以获得数据集中所有 A 系列回合的总数和平均值。

 

让我们想想策略:

1.读取文件的每一行。

2.将每一行拆分成一个值列表。

3.提取列名。

4.使用列名和列表创建字典。

5.过滤掉你不感兴趣的轮次。

6.计算你感兴趣的轮次的总值和平均值。

通常情况下,你可以使用像 pandas 这样的软件包来实现这一功能,但你也可以使用一些generator来实现这一功能。首先,我们将使用generator表达式读取文件中的每一行:

1
2
file_name = "techcrunch.csv"
lines = (line for line in open(file_name))

然后,使用另一个generator表达式与前一个generator表达式配合使用,将每一行分割成一个列表:

1
list_line = (s.rstrip().split(",") for s in lines)

在这里,创建了generator list_line,它会generator 的第一行遍历。这是设计generator管道时常用的模式。接下来,将从 techcrunch.csv 中提取列名。由于列名往往是 CSV 文件的第一行,因此可以通过调用简短的 next() 来获取:

1
cols = next(list_line)

调用 next() 会使迭代器在 generator list_line上前进一次。将所有代码组合在一起,代码应该是这样的:

1
2
3
4
file_name = "techcrunch.csv"
lines = (line for line in open(file_name))
list_line = (s.rstrip().split(",") for s in lines)
cols = next(list_line)

首先创建一个generator表达式 lines,生成文件中的每一行。然后,在另一个名为 list_line 的generator表达式的定义中迭代该generator,将每一行转化为一个值列表。然后,使用 next() 将 list_line 的迭代向前推进一次,以获取 CSV 文件中的列名列表。

 

注:注意尾部换行!这段代码利用了 list_line 生成器表达式中的 .rstrip() 来确保没有尾部换行符,因为 CSV 文件中可能存在尾部换行符。

 

为了帮助过滤数据并对其执行操作,将创建一个字典,其中的键是 CSV 中的列名:

1
company_dicts = (dict(zip(cols, data)) for data in list_line)

这个generator表达式会遍历 list_line 生成的列表。然后,它使用 zip() 和 dict() 创建上述指定的字典。现在,使用第四个generator来过滤你想要的融资轮次,并同时提取 raisedAmt:

1
2
3
4
5
funding = (
    int(company_dict["raisedAmt"])
    for company_dict in company_dicts
    if company_dict["round"] == "a"
)

在这段代码中,generator表达式会遍历 company_dicts 的结果,并获取 round 键为 “a ”的任何 company_dict 的 raisedAmt。

 

请记住,在generator表达式中并不是一次遍历所有这些结果。事实上,在使用 for 循环或对可迭代表起作用的函数(如 sum())之前,你什么都没有迭代。下面调用 sum() 来遍历generator:

1
total_series_a = sum(funding)

将这一切组合在一起,就能生成下面的代码:

1
2
3
4
5
6
7
8
9
10
11
12
file_name = "techcrunch.csv"
lines = (line for line in open(file_name))
list_line = (s.rstrip().split(",") for s in lines)
cols = next(list_line)
company_dicts = (dict(zip(cols, data)) for data in list_line)
funding = (
    int(company_dict["raisedAmt"])
    for company_dict in company_dicts
    if company_dict["round"] == "a"
)
total_series_a = sum(funding)
print(f"Total series A fundraising: ${total_series_a}")

该脚本将创建的每个generator都整合在一起,它们就像一个大数据管道。下面是一行一行的分解:

·第 2 行读入文件的每一行。

·第 3 行将每行分割成不同的值,并将这些值放入一个列表中。

·第 4 行使用 next() 将列名存储到列表中。

·第 5 行创建字典,并调用 zip() 将其合并:

--键是第 4 行中的列名 cols。

--值是第 3 行创建的列表形式的行。

·第 6 行获取每家公司的 A 轮融资额。它还会过滤掉任何其他融资金额。

·第 11 行开始迭代过程,调用 sum() 获得 CSV 中 A 轮融资的总金额。

在 techcrunch.csv 上运行此代码后,你会发现 A 轮融资总额为 438,015,000 美元。

注意:本教程中开发的处理 CSV 文件的方法对于理解如何使用generator和 Python yield 语句非常重要。不过,在 Python 中处理 CSV 文件时,应使用 Python 标准库中的 csv 模块。该模块可以更稳健地解析逗号分隔文件,并有优化的方法来高效处理它们。

相关博文:
阅读排行:
· DeepSeek “源神”启动!「GitHub 热点速览」
· 微软正式发布.NET 10 Preview 1:开启下一代开发框架新篇章
· 我与微信审核的“相爱相杀”看个人小程序副业
· C# 集成 DeepSeek 模型实现 AI 私有化(本地部署与 API 调用教程)
· spring官宣接入deepseek,真的太香了~
历史上的今天:
2023-01-03 fn_dblog()和fn_full_dblog()的使用
点击右上角即可分享
微信分享提示