如何在 Python 中使用 generators 和 yield
2025-01-03 20:02 abce 阅读(46) 评论(0) 编辑 收藏 举报
是否曾经需要处理一个大到足以耗尽机器内存的数据集?或者有一个复杂的函数,每次调用时都需要维护内部状态,但这个函数太小,不适合创建自己的类。在这些情况以及更多情况下,Generators 和 yield 语句都能帮上忙。
使用 generator
generator 函数是一种返回懒惰迭代器的特殊函数。generator 对象可以像列表一样循环使用。然而,与列表不同的是,懒惰迭代器不会将其内容存储在内存中。
generator示例1:读取大文件
generator 的一个常见用例是处理数据流或大文件,如 CSV 文件。这些文本文件使用逗号将数据分栏。这种格式是共享数据的常用方式。现在,如果要计算 CSV 文件中的行数怎么办?下面的代码块展示了一种计算行数的方法:
1 2 3 4 5 6 7 | csv_gen = csv_reader( "some_csv.txt" ) row_count = 0 for row in csv_gen: row_count += 1 print(f "Row count is : {row_count}" ) |
这里你可以认为 csv_gen 是一个列表。为了填充这个列表,csv_reader() 会打开一个文件并将其内容载入 csv_gen。然后,程序会遍历该列表,递增 row_count 计数。
这是一个合理的解释,但如果文件非常大,这种设计还能起作用吗?如果文件比可用内存还大呢?为了回答这个问题,我们假设 csv_reader()只是打开文件并将其读入一个数组:
1 2 3 4 | def csv_reader(file_name): file = open (file_name) result = file. read ().split( "\n" ) #按行分割 return result |
函数打开一个给定的文件,并使用 file.read() 和 .split() 将每一行作为一个单独的元素添加到一个列表中。如果在前面看到的行计数代码块中使用此版本的 csv_reader(),就会得到以下输出结果:
1 2 3 4 5 6 7 8 | Traceback (most recent call last ): File "ex1_naive.py" , line 22, in <module> main() File "ex1_naive.py" , line 13, in main csv_gen = csv_reader( "file.txt" ) File "ex1_naive.py" , line 6, in csv_reader result = file. read ().split( "\n" ) MemoryError |
在这种情况下,如果 open() 会返回一个 generator 对象,你可以懒洋洋地逐行遍历。但是,file.read().split() 会一次性将所有内容加载到内存中,从而导致内存错误。
在此之前,你可能会发现电脑运行缓慢。甚至可能需要使用键盘中断(KeyboardInterrupt)来杀死程序。那么,如何处理这些庞大的数据文件呢?来看看 csv_reader() 的新定义:
1 2 3 | def csv_reader(file_name): for row in open (file_name, "r" ): yield row |
在此代码版本中,打开文件,遍历文件,然后生成一行。该代码应产生以下输出,且不会出现内存错误
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | def csv_reader(file_name): for row in open (file_name, "r" ): yield row csv_gen = csv_reader( "/tmp/e.log" ) row_count = 0 for row in csv_gen: row_count += 1 print(f "Row count is : {row_count}" ) # 运行结果 Row count is : 406242721 |
这里把 csv_reader() 变成了一个 generator 函数。这个版本打开一个文件,循环查看每一行,然后生成每一行,而不是返回文件。
你还可以定义一个 generator 表达式,它的语法与列表表达式非常相似。这样,就可以在不调用函数的情况下使用 generator:
1 | csv_gen = (row for row in open (file_name)) |
generator示例 2:生成无限序列
让我们换个角度来看看无穷序列的生成。在 Python 中,要获得有限序列,需要调用 range() 并在列表上下文中对其进行评估:
1 2 3 | >>> a = range(5) >>> list(a) [0, 1, 2, 3, 4] |
然而生成一个无限序列需要使用 generator,因为你的电脑的内存是有限的:
1 2 3 4 5 | def infinite_sequence(): num = 0 while True : yield num num += 1 |
该代码块简短而精炼。首先,初始化变量 num 并启动一个无限循环。然后,立即 yeild num,以便捕捉初始状态。这模仿了 range() 的操作。
在 yield 之后,num 会递增 1。如果你用 for 循环尝试一下,就会发现它看起来确实是无限的:
1 2 3 4 5 6 7 8 9 10 11 12 | >>> for i in infinite_sequence(): ... print(i, end = " " ) ... 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 [...] 6157818 6157819 6157820 6157821 6157822 6157823 6157824 6157825 6157826 6157827 6157828 6157829 6157830 6157831 6157832 6157833 6157834 6157835 6157836 6157837 6157838 6157839 6157840 6157841 6157842 Traceback (most recent call last ): File "<stdin>" , line 2, in <module> KeyboardInterrupt |
程序将继续执行,直到手动停止。
也可以直接在 generator 对象上调用 next() 来代替 for 循环。这对于在控制台中测试信号发生器特别有用:
1 2 3 4 5 6 7 8 9 | >>> gen = infinite_sequence() >>> next (gen) 0 >>> next (gen) 1 >>> next (gen) 2 >>> next (gen) 3 |
在这里,你有一个名为 gen 的 generator ,你可以通过重复调用 next() 来手动遍历该 generator 。这可以作为一个很好的正确性检查,以确保 generator 产生了所期望的输出。
generator示例 3:检测回文
无限序列可以用在很多方面,但其中一个实际用途是构建回文检测器。回文检测器会找出所有回文字母或数字序列。这些单词或数字的正向和反向读法相同,就像 121。首先,定义数字回文检测器:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 | def is_palindrome(num): # Skip single-digit inputs if num // 10 == 0: return False temp = num reversed_num = 0 while temp != 0: reversed_num = (reversed_num * 10) + ( temp % 10) temp = temp // 10 if num == reversed_num: return num else : return False |
不用太在意如何理解这段代码中的基本数学知识。只需注意,该函数接收一个输入数字,将其反转,并检查反转后的数字是否与原始数字相同。现在,你可以使用无限序列 generator 来获取所有数字回文的运行列表:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 | >>> for i in infinite_sequence(): ... pal = is_palindrome(i) ... if pal: ... print(i) ... 11 22 33 [...] 99799 99899 99999 100001 101101 102201 Traceback (most recent call last ): File "<stdin>" , line 2, in <module> File "<stdin>" , line 5, in is_palindrome KeyboardInterrupt |
注:实际上,你不太可能编写自己的无限序列 generator 。itertools 模块通过 itertools.count() 提供了一个非常高效的无限序列 generator 。
现在你已经看到了无限序列生成器的一个简单用例,让我们深入了解生成器是如何工作的。
理解generator
到目前为止,已经了解了创建generator的两种主要方法:使用generator函数和generator表达式。
generator函数的外观和行为与普通函数一样,但有一个显著特点。generator函数使用 Python 的 yield 关键字来代替 return。回想一下之前编写的generator函数:
1 2 3 4 5 | def infinite_sequence(): num = 0 while True : yield num num += 1 |
这看起来像一个典型的函数定义,除了 Python 的 yield 语句和它后面的代码。yield 表示将一个值发送回调用者,但与 return 不同的是,你不会在之后退出函数。
相反,函数的状态会被记住。这样,当在generator对象上调用 next() 时(无论是显式调用还是在 for 循环中隐式调用),先前yield的变量 num 会递增,然后再次yield。由于generator函数看起来像其他函数,其行为也与其他函数非常相似,因此可以认为generator表达式与 Python 中的其他表达式非常相似。
使用generator表达式创建generator
与列表表达式一样,generator表达式允许你用几行代码快速创建一个generator对象。generator表达式很有用,还有一个额外的好处:在迭代之前,创建generator表达式无需在内存中构建和保留整个对象。换句话说,使用generator表达式不会占用内存。请看下面这个数字平方的例子:
1 2 3 4 5 6 7 8 9 | nums_squared_lc = [num**2 for num in range(5)] nums_squared_gc = (num**2 for num in range(5)) print(nums_squared_lc) print(nums_squared_gc) # 输出 [0, 1, 4, 9, 16] <generator object <genexpr> at 0x000001F43EB64AC0> |
nums_squared_lc 和 nums_squared_gc 看起来基本相同,但有一个关键区别。发现了吗?第一个构建了一个列表,第二个通过括号生成了一个generator表达式,输出结果也表明生成了一个generator对象。
剖析generator的性能
前面已经了解到,generator是优化内存的一种好方法。虽然无限序列generator是这种优化的一个极端例子,但让我们把刚才看到的数字平方例子放大,并检查生成对象的大小。你可以调用 sys.getsizeof() 来做到这一点:
1 2 3 4 5 6 7 | >>> import sys >>> nums_squared_lc = [i ** 2 for i in range(10000)] >>> sys.getsizeof(nums_squared_lc) 85176 >>> nums_squared_gc = (i ** 2 for i in range(10000)) >>> print(sys.getsizeof(nums_squared_gc)) 112 |
在本例中,从列表表达式中得到的列表是 85176 字节,而generator对象只有 112 字节。这意味着列表比generator对象大 760 多倍!
不过有一点需要注意。如果列表小于运行机器的可用内存,那么列表表达式的运算速度会比等效的generator表达式更快。为了探讨这个问题,让我们对上面两个综合的结果进行求和。可以使用 cProfile.run() 生成读出结果:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 | >>> import cProfile >>> cProfile.run( 'sum([i * 2 for i in range(10000)])' ) 5 function calls in 0.001 seconds Ordered by : standard name ncalls tottime percall cumtime percall filename:lineno( function ) 1 0.000 0.000 0.000 0.000 <string>:1(<listcomp>) 1 0.000 0.000 0.001 0.001 <string>:1(<module>) 1 0.000 0.000 0.001 0.001 {built- in method builtins. exec } 1 0.000 0.000 0.000 0.000 {built- in method builtins. sum } 1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects} >>> cProfile.run( 'sum((i * 2 for i in range(10000)))' ) 10005 function calls in 0.002 seconds Ordered by : standard name ncalls tottime percall cumtime percall filename:lineno( function ) 10001 0.001 0.000 0.001 0.000 <string>:1(<genexpr>) 1 0.000 0.000 0.002 0.002 <string>:1(<module>) 1 0.000 0.000 0.002 0.002 {built- in method builtins. exec } 1 0.001 0.001 0.002 0.002 {built- in method builtins. sum } 1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects} >>> |
从这里可以看出,对列表表达式中的所有值求和所需的时间约为对generator求和所需的时间的三分之一。如果速度是个问题,而内存不是,那么列表表达式可能是更好的工具。
注意:这些测量结果不仅适用于使用generator表达式生成的对象。它们也同样适用于使用类似generator函数生成的对象,因为generator的结果是等价的。
请记住,列表表达式返回完整列表,而generator表达式返回generator。无论使用函数还是表达式,generator的工作原理都是一样的。使用表达式只需在一行中定义简单的generator,并在每次内部迭代结束时设定 yield。
Python 的 yield 语句无疑是generator所有功能的关键所在,让我们深入了解一下 Python 中 yield 的工作原理。
理解 Python Yield 语句
总的来说,yield 是一个相当简单的语句。它的主要作用是以一种类似于 return 语句的方式控制generator函数的流程。正如上面提到的,Python yield 语句有一些小技巧。
调用generator函数或使用generator表达式时,会返回一个特殊的迭代器,称为generator。可以将generator赋值给一个变量来使用它。当调用generator上的特殊方法(如 next())时,函数内的代码会一直执行到 yield。
当 Python yield 语句被执行时,程序会暂停函数的执行,并将 yield 值返回给调用者。(相反,return 会完全停止函数的执行。)当函数被暂停时,函数的状态会被保存。这包括generator本地的任何变量绑定、指令指针、内部堆栈和任何异常处理。
这样,只要调用generator的某个方法,就能恢复函数的执行。这样,所有函数的评估都会在 yield 结束后立即恢复。可以通过使用多个 Python yield 语句来了解这种方法:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 | >>> def multi_yield(): ... yield_str = "This will print the first string" ... yield yield_str ... yield_str = "This will print the second string" ... yield yield_str ... >>> multi_obj = multi_yield() >>> print( next (multi_obj)) This will print the first string >>> print( next (multi_obj)) This will print the second string >>> print( next (multi_obj)) Traceback (most recent call last ): File "<stdin>" , line 1, in <module> StopIteration |
仔细看最后一次调用 next()。你可以看到执行过程中出现了Traceback。这是因为generator和所有迭代器一样,可能会被耗尽。除非generator是无限的,否则只能遍历一次。一旦评估完所有值,迭代就会停止,for 循环也会退出。如果使用 next(),则会出现明确的 StopIteration 异常。
注意:StopIteration 是一种自然异常,用于提示迭代器的结束。你甚至可以使用 while 循环来实现自己的 for 循环:
1 2 3 4 5 6 7 8 9 10 11 12 13 | >>> letters = [ "a" , "b" , "c" , "y" ] >>> it = iter(letters) >>> while True : ... try: ... letter = next (it) ... except StopIteration: ... break ... print(letter) ... a b c y |
yield 可以通过多种方式来控制generator的执行流。在你的创造力允许范围内,可以使用多个 Python yield 语句。
使用高级的generator方法
除了 yield 之外,generator对象还可以使用以下方法:
1 2 3 | .send() .throw() . close () |
如何使用send()
下面将写一个使用上面提到的三种方法的程序。这个程序会像之前一样打印数字回文,但会做一些调整。在遇到一个回文时,新程序会添加一个数字,然后开始搜索下一个数字。还将使用 .throw() 来处理异常,并使用 .close() 在输入一定数量的数字后停止generator。首先回顾一下回文检测器的代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 | def is_palindrome(num): # Skip single-digit inputs if num // 10 == 0: return False temp = num reversed_num = 0 while temp != 0: reversed_num = (reversed_num * 10) + ( temp % 10) temp = temp // 10 if num == reversed_num: return True else : return False |
这个程序和之前的代码差不多,只是将程序的返回设定成了True或False。无限序列generator的代码也要做出调整:
1 2 3 4 5 6 7 8 | def infinite_palindromes(): num = 0 while True : if is_palindrome(num): i = (yield num) if i is not None: num = i num += 1 |
这里有很多改动!第一个变化出现在第 5 行,即 i = (yield num)。虽然在前面学到 yield 是一个语句,但这并不是全部。
从 Python 2.5 开始,yield 是一个表达式,而不是语句。当然,仍然可以将它用作语句。但现在,也可以像上面代码块中看到的那样使用它,其中 i 取值为 yield 的值。这样,就可以对产生的值进行操作。更重要的是,它允许你向generator发送一个值。当 yield 结束后重新开始执行时,i 将获取发送的值。
还要检查 i 是否为 None,如果在generator对象上调用 next(),可能会出现这种情况。(如果 i 有一个值,就用新值更新 num。但无论 i 是否有值,都会递增 num 并再次开始循环)
现在来看一下主函数代码,它将最低的数字与另一个数字一起发送回generator。例如,如果回文是 121,那么它将 .send() 1000:
1 2 3 4 | pal_gen = infinite_palindromes() for i in pal_gen: digits = len(str(i)) pal_gen.send(10 ** (digits)) |
通过这段代码创建generator对象并遍历它。只有在找到一个回文字符串时,程序才会产生一个值。它使用 len() 来确定该回文的位数。然后,向generator发送 10 ** 位数字。由于 i 现在有了一个值,程序会更新 num、递增并再次检查回文。
一旦代码找到并产生另一个回文,就会通过 for 循环进行迭代。这与使用 next() 进行迭代是一样的。在第 5 行,代码生成器也使用 i = (yield num)。不过,现在 i 是 None,因为你没有明确发送一个值。
这里创建的是一个 coroutine 或generator函数,可以向其中传递数据。这些函数对于构建数据管道非常有用,但正如你很快就会看到的,它们并不是构建管道所必需的。
在学习了 .send() 之后,让我们来看看 .throw()。
如何使用throw()
throw() 支持使用generator抛出异常。在下面的示例中,将在第 6 行引发异常。一旦digits达到 5,这段代码就会抛出 ValueError:
1 2 3 4 5 6 7 | pal_gen = infinite_palindromes() for i in pal_gen: print(i) digits = len(str(i)) if digits == 5: pal_gen.throw(ValueError( "We don't like large palindromes" )) pal_gen.send(10 ** (digits)) |
这与之前的代码相同,但现在要检查 digits 是否等于 5,如果是,则 .throw() 一个 ValueError。要确认代码是否按预期运行,请查看代码的输出结果:
1 2 3 4 5 6 7 8 9 10 11 12 | 11 111 1111 10101 Traceback (most recent call last ): File "advanced_gen.py" , line 47, in <module> main() File "advanced_gen.py" , line 41, in main pal_gen.throw(ValueError( "We don't like large palindromes" )) File "advanced_gen.py" , line 26, in infinite_palindromes i = (yield num) ValueError: We don't like large palindromes |
.throw() 在需要捕获异常时非常有用。在本例中,使用 .throw() 来控制何时停止迭代generator。使用 .close() 可以更优雅地实现这一目的。
如何使用 close()
顾名思义,close() 允许停止generator。这在控制无限序列generator时尤其方便。让我们更新上面的代码,将 .throw() 改为 .close() 以停止迭代:
1 2 3 4 5 6 7 | pal_gen = infinite_palindromes() for i in pal_gen: print(i) digits = len(str(i)) if digits == 5: pal_gen. close () pal_gen.send(10 ** (digits)) |
在第 6 行中使用 .close() 代替调用 .throw() 。使用 .close() 的好处是,它会引发 StopIteration 异常,这是一个用于提示有限迭代器结束的异常:
1 2 3 4 5 6 7 8 9 10 | 11 111 1111 10101 Traceback (most recent call last ): File "advanced_gen.py" , line 46, in <module> main() File "advanced_gen.py" , line 42, in main pal_gen.send(10 ** (digits)) StopIteration |
下面我们就来谈谈如何使用generator来构建数据管道。
使用generator创建数据管道
数据管道支持将代码串联起来处理大型数据集或数据流,而不会耗尽机器的内存。想象一下,你有一个大型 CSV 文件,其第一行如下:
1 2 3 4 5 6 | permalink,company,numEmps,category,city,state,fundedDate,raisedAmt,raisedCurrenc y,round digg,Digg,60,web,San Francisco,CA,1- Dec -06,8500000,USD,b digg,Digg,60,web,San Francisco,CA,1-Oct-05,2800000,USD,a facebook,Facebook,450,web,Palo Alto,CA,1-Sep-04,500000,USD,angel facebook,Facebook,450,web,Palo Alto,CA,1-May-05,12700000,USD,a photobucket,Photobucket,60,web,Palo Alto,CA,1-Mar-05,3000000,USD,a |
本示例取自 TechCrunch 美国大陆数据集,该数据集描述了美国各种初创企业的融资轮次和金额。
是时候用 Python 进行一些处理了!为了演示如何使用generator构建流水线,分析该文件,以获得数据集中所有 A 系列回合的总数和平均值。
让我们想想策略:
1.读取文件的每一行。
2.将每一行拆分成一个值列表。
3.提取列名。
4.使用列名和列表创建字典。
5.过滤掉你不感兴趣的轮次。
6.计算你感兴趣的轮次的总值和平均值。
通常情况下,你可以使用像 pandas 这样的软件包来实现这一功能,但你也可以使用一些generator来实现这一功能。首先,我们将使用generator表达式读取文件中的每一行:
1 2 | file_name = "techcrunch.csv" lines = (line for line in open (file_name)) |
然后,使用另一个generator表达式与前一个generator表达式配合使用,将每一行分割成一个列表:
1 | list_line = (s.rstrip().split( "," ) for s in lines) |
在这里,创建了generator list_line,它会generator 的第一行遍历。这是设计generator管道时常用的模式。接下来,将从 techcrunch.csv 中提取列名。由于列名往往是 CSV 文件的第一行,因此可以通过调用简短的 next() 来获取:
1 | cols = next (list_line) |
调用 next() 会使迭代器在 generator list_line上前进一次。将所有代码组合在一起,代码应该是这样的:
1 2 3 4 | file_name = "techcrunch.csv" lines = (line for line in open (file_name)) list_line = (s.rstrip().split( "," ) for s in lines) cols = next (list_line) |
首先创建一个generator表达式 lines,生成文件中的每一行。然后,在另一个名为 list_line 的generator表达式的定义中迭代该generator,将每一行转化为一个值列表。然后,使用 next() 将 list_line 的迭代向前推进一次,以获取 CSV 文件中的列名列表。
注:注意尾部换行!这段代码利用了 list_line 生成器表达式中的 .rstrip() 来确保没有尾部换行符,因为 CSV 文件中可能存在尾部换行符。
为了帮助过滤数据并对其执行操作,将创建一个字典,其中的键是 CSV 中的列名:
1 | company_dicts = (dict(zip(cols, data)) for data in list_line) |
这个generator表达式会遍历 list_line 生成的列表。然后,它使用 zip() 和 dict() 创建上述指定的字典。现在,使用第四个generator来过滤你想要的融资轮次,并同时提取 raisedAmt:
1 2 3 4 5 | funding = ( int (company_dict[ "raisedAmt" ]) for company_dict in company_dicts if company_dict[ "round" ] == "a" ) |
在这段代码中,generator表达式会遍历 company_dicts 的结果,并获取 round 键为 “a ”的任何 company_dict 的 raisedAmt。
请记住,在generator表达式中并不是一次遍历所有这些结果。事实上,在使用 for 循环或对可迭代表起作用的函数(如 sum())之前,你什么都没有迭代。下面调用 sum() 来遍历generator:
1 | total_series_a = sum (funding) |
将这一切组合在一起,就能生成下面的代码:
1 2 3 4 5 6 7 8 9 10 11 12 | file_name = "techcrunch.csv" lines = (line for line in open (file_name)) list_line = (s.rstrip().split( "," ) for s in lines) cols = next (list_line) company_dicts = (dict(zip(cols, data)) for data in list_line) funding = ( int (company_dict[ "raisedAmt" ]) for company_dict in company_dicts if company_dict[ "round" ] == "a" ) total_series_a = sum (funding) print(f "Total series A fundraising: ${total_series_a}" ) |
该脚本将创建的每个generator都整合在一起,它们就像一个大数据管道。下面是一行一行的分解:
·第 2 行读入文件的每一行。
·第 3 行将每行分割成不同的值,并将这些值放入一个列表中。
·第 4 行使用 next() 将列名存储到列表中。
·第 5 行创建字典,并调用 zip() 将其合并:
--键是第 4 行中的列名 cols。
--值是第 3 行创建的列表形式的行。
·第 6 行获取每家公司的 A 轮融资额。它还会过滤掉任何其他融资金额。
·第 11 行开始迭代过程,调用 sum() 获得 CSV 中 A 轮融资的总金额。
在 techcrunch.csv 上运行此代码后,你会发现 A 轮融资总额为 438,015,000 美元。
注意:本教程中开发的处理 CSV 文件的方法对于理解如何使用generator和 Python yield 语句非常重要。不过,在 Python 中处理 CSV 文件时,应使用 Python 标准库中的 csv 模块。该模块可以更稳健地解析逗号分隔文件,并有优化的方法来高效处理它们。
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· DeepSeek “源神”启动!「GitHub 热点速览」
· 微软正式发布.NET 10 Preview 1:开启下一代开发框架新篇章
· 我与微信审核的“相爱相杀”看个人小程序副业
· C# 集成 DeepSeek 模型实现 AI 私有化(本地部署与 API 调用教程)
· spring官宣接入deepseek,真的太香了~
2023-01-03 fn_dblog()和fn_full_dblog()的使用