itertools — 为高效循环创建迭代器的函数

该模块实现了许多iterator构建块,其灵感来自APL、Haskell和SML的构造。每一个都被重写成适合Python的形式。

该模块标准化了一组快速、内存高效的核心工具,这些工具本身或在组合。它们一起构成了一个“迭代器代数学”,使得在纯Python中简洁有效地构造专门的工具成为可能。

例如,SML提供了一个制表工具: tabulate(f),它产生一个序列f(0)f(1),在Python中可以通过结合map()count()来形成map(f, count())来达到同样的效果。

这些工具和它们的内置对应物也可以很好地配合运算符模块中的高速函数。例如,乘法运算符可以跨两个向量进行映射,从而形成高效的dot-product: sum(starmap(operator.mul, zip(vecl, vec2, strict=True))

无穷迭代器:

Iterator Function Arguments Results
count() start, [step] start, start+step, start+2*step, ...
cycle() p p0, p1, ..., plast, p0, p1, ...
repeat() elem [,n] elem, elem, elem, ..., endlessly or up to n times

终止于最短输入序列的迭代器:

Function Arguments Results Example
accumulate p [,func] p0, p0+p1, p0+p1+p2, … accumulate([1,2,3,4,5]) --> 1 3 6 10 15
chain p, q, … p0, p1, … plast, q0, q1, … chain('ABC', 'DEF') --> A B C D E F
chain.from_iterable iterable p0, p1, … plast, q0, q1, … chain.from_iterable(['ABC', 'DEF']) --> A B C D E F
compress data, selectors (d[0] if s[0]), (d[1] if s[1]), … compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F
dropwhile pred, seq seq[n], seq[n+1], starting when pred fails dropwhile(lambda x: x<5, [1,4,6,4,1]) --> 6 4 1
filterfalse pred, seq elements of seq where pred(elem) is false filterfalse(lambda x: x%2, range(10)) --> 0 2 4 6 8
groupby iterable[, key] sub-iterators grouped by value of key(v)
islice seq, [start,] stop [, step] elements from seq[start:stop:step] islice('ABCDEFG', 2, None) --> C D E F G
pairwise iterable (p[0], p[1]), (p[1], p[2]) pairwise('ABCDEFG') --> AB BC CD DE EF FG
starmap func, seq func(seq[0]), func(seq[1]), … starmap(pow, [(2,5), (3,2), (10,3)]) --> 32 9 1000
takewhile pred, seq seq[0], seq[1], until pred fails takewhile(lambda x: x<5, [1,4,6,4,1]) --> 1 4
tee it, n it1, it2, … itn splits one iterator into n
zip_longest p, q, … (p[0], q[0]), (p[1], q[1]), … zip_longest('ABCD', 'xy', fillvalue='-') --> Ax By C- D-

组合数的迭代器:

Function Arguments Results
product() p, q, … [repeat=1] cartesian product, equivalent to a nested for-loop
permutations() p[, r] r-length tuples, all possible orderings, no repeated elements
combinations() p, r r-length tuples, in sorted order, no repeated elements
combinations_with_replacement p, r r-length tuples, in sorted order, with repeated elements
Examples Results
product('ABCD', repeat=2) AA AB AC AD BA BB BC BD CA CB CC CD DA DB DC DD
permutations('ABCD', 2) AB AC AD BA BC BD CA CB CD DA DB DC
combinations('ABCD', 2) AB AC AD BC BD CD
combinations_with_replacement('ABCD', 2) AA AB AC AD BB BC BD CC CD DD

It is not difficult to find that permutations_with_replacement is actually equivalent to product, which may be the reason why this API does not exist.

Itertool functions

以下模块函数都构造并返回迭代器。有些提供无限长度的流,因此只能由截断流的函数或循环访问。

itertools.accumulate(iterable[, func, *, initial=None])

创建一个迭代器,返回累计和或其他二进制函数的累计结果(通过可选参数func指定)。

如果提供了func,它应该是一个包含两个参数的函数。输入iterable的元素可以是任何类型,可以接受为func的实参。(例如,在默认的加法操作中,元素可以是任何可加类型,包括DecimalFraction。)

通常,输出的元素数量与输入的可迭代对象相匹配。然而,如果提供了关键字参数initial,则累加从初始值开始,以便输出比输入可迭代对象多一个元素。

大致相当于:

def accumulate(iterable, func=operator.add, *, initial=None):
    'Return running totals'
    # accumulate([1,2,3,4,5]) --> 1 3 6 10 15
    # accumulate([1,2,3,4,5], initial=100) --> 100 101 103 106 110 115
    # accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
    it = iter(iterable)
    total = initial
    if initial is None:
        try:
            total = next(it)
        except StopIteration:
            return
    yield total
    for element in it:
        total = func(total, element)
        yield total

叫chatgpt解释一下:

The function first creates an iterator for the input iterable and assigns the initial value to the total variable (or the first element of the iterable, if initial is None). The yield statement is used to return the initial value as the first output of the generator.

The function then iterates over the remaining elements of the iterable and applies the func operation to each element and the current running total(total), updating the running total each time. The new running total is then yielded as the next output of the generator.

func参数有很多用法。它可以被设置为min()用于运行的最小值,max()用于运行的最大值,或者operator.mul()用于运行的乘积。摊销表可以通过累积利息和应用付款来构建:

data = [3, 4, 6, 2, 1, 9, 0, 7, 5, 8]
list(accumulate(data, operator.mul))     # running product

list(accumulate(data, max))              # running maximum



cashflows = [1000, -90, -90, -90, -90]
list(accumulate(cashflows, lambda bal, pmt: bal*1.05 + pmt))

有关只返回最终累积值的类似函数,请参阅functools.reduce()

itertools.chain(*iterables)

创建一个迭代器,从第一个可迭代对象返回元素,直到耗尽它,然后继续到下一个可迭代对象,直到耗尽所有可迭代对象。用于将连续序列作为单个序列处理(注意它的形参是一个可变长参数)。

大致相当于:

def chain(*iterables):
    # chain('ABC', 'DEF') --> A B C D E F
    for it in iterables:
        for element in it:
            yield element

This Python code defines a generator function called chain that takes in any number of iterable arguments using the * operator.

The * operator in the function definition indicates that the function accepts a variable number of arguments, which are then packed into a tuple called iterables. In other words, you can pass any number of arguments to the chain function, and they will be treated as iterables that you want to chain together.

The chain function then iterates over each iterable passed in, and yields each element in the iterable one by one. This allows you to iterate over multiple iterables as if they were a single iterable, without having to concatenate them into a single list or tuple.

For example, if you called chain('ABC', 'DEF'), it would return a generator object that would yield the elements 'A', 'B', 'C', 'D', 'E', and 'F' in sequence when iterated over.

Overall, the chain function provides a way to iterate over multiple iterables in a single loop, without having to pre-process them or concatenate them into a single list or tuple.

itertools.combinations(iterable, r)

返回输入可迭代对象中元素的长度为r的子序列。组合元组根据输入iterable的顺序按字典顺序释放。所以,如果输入iterable是排序的,输出元组就会按照排序的顺序产生。

元素的唯一性是基于它们的位置,而不是它们的值。因此,如果输入元素是唯一的,那么在每个组合中就不会有重复的值。

大致相当于:

def combinations(iterable, r):
    # combinations('ABCD', 2) --> AB AC AD BC BD CD
    # combinations(range(4), 3) --> 012 013 023 123
    pool = tuple(iterable)
    n = len(pool)
    if r > n:
        return
    indices = list(range(r))
    yield tuple(pool[i] for i in indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != i + n - r:
                break
        else:
            return
        indices[i] += 1
        for j in range(i+1, r):
            indices[j] = indices[j-1] + 1
        yield tuple(pool[i] for i in indices)

这是经典的组合数生成算法,称为"字典序法"(Lexicographic order generation algorithm)。
算法的基本思想是从最小的组合开始,按照字典序逐个生成所有组合。
具体地,假设要从长度为n的集合中选取r个元素,首先将所有元素按照某种顺序(如从小到大)排列,生成最小的组合,即前r个元素。
然后,不断寻找下一个字典序的组合,方法是从右往左扫描当前组合,找到第一个位置i,使得该位置右边的元素个数不够填满剩余的空位(即n-i<r-indices[i]),然后将位置i的元素替换为它右边的某个元素,并将i右边的所有元素重新排列生成最小的组合。
重复以上步骤,直到最后一个字典序的组合被生成。

这个 combinations()代码也可以被表示为permutations() 过滤掉元素未排序的条目

def combinations(iterable, r):
    pool = tuple(iterable)
    n = len(pool)
    for indices in permutations(range(n), r):
        if sorted(indices) == list(indices):
            yield tuple(pool[i] for i in indices)

itertools.combinations_with_replacement(iterable, r)

返回输入iterable中元素的长度为r的子序列,允许单个元素重复多次。

组合元组根据输入可迭代对象的顺序按字典顺序释放。所以,如果输入iterable是排序的,输出元组就会按照排序的顺序产生。

元素的唯一性是基于它们的位置,而不是它们的值。因此,如果输入元素是唯一的,生成的组合也将是唯一的。

大致相当于:

def combinations_with_replacement(iterable, r):
    # combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC
    pool = tuple(iterable)
    n = len(pool)
    if not n and r:
        return
    indices = [0] * r
    yield tuple(pool[i] for i in indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != n - 1:
                break
        else:
            return
        indices[i:] = [indices[i] + 1] * (r - i)
        yield tuple(pool[i] for i in indices)

combinations_with_replacement()的代码也可以表示为product()的子序列,在过滤了元素没有排序的条目后:

def combinations_with_replacement(iterable, r):
    pool = tuple(iterable)
    n = len(pool)
    for indices in product(range(n), repeat=r):
        if sorted(indices) == list(indices):
            yield tuple(pool[i] for i in indices)

itertools.compress(data, selectors)

创建一个迭代器,从数据中筛选元素,只返回那些在选择器中有相应元素的元素,其计算值为True。当数据或选择器迭代对象耗尽时停止。大致相当于:

def compress(data, selectors):
    # compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F
    return (d for d, s in zip(data, selectors) if s)

itertools.count(start=0, step=1)

创建一个迭代器,返回以数字start开始的等间距值。通常用作map()的参数。生成连续的数据点。同样,与zip()一起使用来添加序列号。大致相当于:

def count(start=0, step=1):
    # count(10) --> 10 11 12 13 14 ...
    # count(2.5, 0.5) --> 2.5 3.0 3.5 ...
    n = start
    while True:
        yield n
        n += step

当使用浮点数计数时,有时可以通过替换乘法代码来获得更好的精度,例如:(start + step * i for i in count())

itertools.cycle(iterable)

创建一个迭代器,返回可迭代对象中的元素,并保存每个元素的副本。当迭代对象耗尽时,从保存的副本中返回元素。重复下去。大致相当于:

def cycle(iterable):
    # cycle('ABCD') --> A B C D A B C D A B C D ...
    saved = []
    for element in iterable:
        yield element
        saved.append(element)
    while saved:
        for element in saved:
              yield element

注意,工具箱的这个成员可能需要大量的辅助存储(取决于可迭代对象的长度)。

itertools.dropwhile(predicate, iterable)

创建一个迭代器,只要谓词为真,迭代器就会从iterable中删除元素;之后,返回每一个元素。注意,迭代器在谓词第一次变为false之前不会产生任何输出,因此它可能有很长的启动时间。大致相当于:

def dropwhile(predicate, iterable):
    # dropwhile(lambda x: x<5, [1,4,6,4,1]) --> 6 4 1
    iterable = iter(iterable)
    for x in iterable:
        if not predicate(x):
            yield x
            break
    for x in iterable:
        yield x

itertools.filterfalse(predicate, iterable)

创建一个迭代器,从iterable中过滤元素,只返回谓词为的元素假的。如果predicate为None,返回为false的项。大致相当于:

def filterfalse(predicate, iterable):
    # filterfalse(lambda x: x%2, range(10)) --> 0 2 4 6 8
    if predicate is None:
        predicate = bool
    for x in iterable:
        if not predicate(x):
            yield x

itertools.groupby(iterable, key=None)

创建一个迭代器,从迭代对象中返回连续的键和组。key是一个为每个元素计算键值的函数。如果没有指定或为None, key默认为一个恒等函数,并原样返回元素。一般来说,迭代对象需要已经在相同的key函数上排序。

groupby()的操作类似于Unix中的unig过滤器。每次键函数的值发生变化时,它都会生成一个break或新组(这就是为什么通常需要使用相同的键函数对数据进行排序)。这种行为不同于SQL的GROUP BY,后者聚合公共元素,而不管它们的输入顺序。

返回的组本身是一个迭代器,它与groupby()共享底层可迭代对象。因为源是共享的,所以当groupby()对象进阶时,前一个组就不再可见了。所以,如果以后需要使用这个数据,应该以列表的形式存储:

groups = []
uniquekeys = []
data = sorted(data, key=keyfunc)
for k, g in groupby(data, keyfunc):
    groups.append(list(g))      # Store group iterator as a list
    uniquekeys.append(k)

groupby()大致相当于:

class groupby:
    # [k for k, g in groupby('AAAABBBCCDAABBB')] --> A B C D A B
    # [list(g) for k, g in groupby('AAAABBBCCD')] --> AAAA BBB CC D

    def __init__(self, iterable, key=None):
        if key is None:
            key = lambda x: x
        self.keyfunc = key
        self.it = iter(iterable)
        self.tgtkey = self.currkey = self.currvalue = object()

    def __iter__(self):
        return self

    def __next__(self):
        self.id = object()
        while self.currkey == self.tgtkey:
            self.currvalue = next(self.it)    # Exit on StopIteration
            self.currkey = self.keyfunc(self.currvalue)
        self.tgtkey = self.currkey
        return (self.currkey, self._grouper(self.tgtkey, self.id))

    def _grouper(self, tgtkey, id):
        while self.id is id and self.currkey == tgtkey:
            yield self.currvalue
            try:
                self.currvalue = next(self.it)
            except StopIteration:
                return
            self.currkey = self.keyfunc(self.currvalue)

itertools.islice(iterable, stop)

itertools.islice(iterable, start, stop[, step])

创建一个迭代器,从可迭代对象中返回所选元素。如果start非零,则跳过迭代对象中的元素,直到到达start。之后,元素将连续返回,除非step设置的值高于1,从而导致元素被跳过。如果stop为None,则继续迭代直到迭代器耗尽(如果有的话);否则,则在指定位置停止。

如果start为None,则迭代从0开始。如果step为None,则step默认为1。

与常规切片不同,islice()不支持start、stop或step为负值。可以用于从内部结构被扁平化的数据中提取相关字段(用于例如,一个多行报告可能会在每第三行上列出一个名称字段)。

大致相当于:

def islice(iterable, *args):
    # islice('ABCDEFG', 2) --> A B
    # islice('ABCDEFG', 2, 4) --> C D
    # islice('ABCDEFG', 2, None) --> C D E F G
    # islice('ABCDEFG', 0, None, 2) --> A C E G
    s = slice(*args)
    start, stop, step = s.start or 0, s.stop or sys.maxsize, s.step or 1
    it = iter(range(start, stop, step))
    try:
        nexti = next(it)
    except StopIteration:
        # Consume *iterable* up to the *start* position.
        for i, element in zip(range(start), iterable):
            pass
        return
    try:
        for i, element in enumerate(iterable):
            if i == nexti:
                yield element
                nexti = next(it)
    except StopIteration:
        # Consume to *stop*.
        for i, element in zip(range(i + 1, stop), iterable):
            pass

itertools.pairwise(iterable)

返回从输入可迭代对象中获取的连续重叠对。

输出迭代器中2元组的数量将比输入的数量少1。如果输入迭代器的值少于两个,则为空。大致相当于:

def pairwise(iterable):
    # pairwise('ABCDEFG') --> AB BC CD DE EF FG
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)

3.10+ 才有这个函数

itertools.permutations(iterable, r=None)

返回迭代对象中元素的连续rlength排列。如果r未指定或为None,则r默认为可迭代对象的长度,并生成所有可能的全长度排列。

排列元组根据输入可迭代对象的顺序按字典顺序发出。所以,如果输入的iterable是排序的,输出的元组就会按照排序的顺序产生。

元素的唯一性是基于它们的位置,而不是它们的值。因此,如果输入的元素是唯一的,那么在一个排列中就不会有重复的值。大致相当于:

def permutations(iterable, r=None):
    # permutations('ABCD', 2) --> AB AC AD BA BC BD CA CB CD DA DB DC
    # permutations(range(3)) --> 012 021 102 120 201 210
    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r
    if r > n:
        return
    indices = list(range(n))
    cycles = list(range(n, n-r, -1))
    yield tuple(pool[i] for i in indices[:r])
    while n:
        for i in reversed(range(r)):
            cycles[i] -= 1
            if cycles[i] == 0:
                indices[i:] = indices[i+1:] + indices[i:i+1]
                cycles[i] = n - i
            else:
                j = cycles[i]
                indices[i], indices[-j] = indices[-j], indices[i]
                yield tuple(pool[i] for i in indices[:r])
                break
        else:
            return

permutations()的代码也可以表示为product()的子序列,过滤为排除元素重复的条目(来自输入池中相同位置的条目):

def permutations(iterable, r=None):
    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r
    for indices in product(range(n), repeat=r):
        if len(set(indices)) == r:
            yield tuple(pool[i] for i in indices)

itertools.product(*iterables, repeat=1)

输入可迭代对象的笛卡尔积。

大致相当于生成器表达式中的嵌套for循环。例如,product (A, B)返回的结果与((×,y) For x in A For y in B)相同。

嵌套循环像里程表一样循环,最右边的元素在每次迭代中都在前进。这个模式创建了一个字典顺序,这样如果输入的可迭代对象被排序,那么乘积元组就会按排序顺序发出。

要计算可迭代对象与其自身的乘积,请使用可选的repeat关键字参数指定重复次数。例如,product(A, repeat=4)的意思与product(A, A, A, A)相同。

这个函数大致相当于下面的代码,不同的是实际实现并不会在内存中建立中间结果:

def product(*args, repeat=1):
    # product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy
    # product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111
    pools = [tuple(pool) for pool in args] * repeat
    result = [[]]
    for pool in pools:
        result = [x+[y] for x in result for y in pool]
    for prod in result:
        yield tuple(prod)

product()运行之前,它完全消耗输入可迭代对象,在内存中保留值池以生成产品。因此,它只对有限的输入有用。

itertools.repeat(object[, times])

创建一个反复返回object的迭代器。无限期地运行,除非指定了times参数。大致相当于:

def repeat(object, times=None):
    # repeat(10, 3) --> 10 10 10
    if times is None:
        while True:
            yield object
    else:
        for i in range(times):
            yield object

repeat的一个常见用途是为mapzip提供一个常量值流:

list(map(pow, range(10), repeat(2)))

itertools.starmap(function, iterable)

创建一个迭代器,使用从iterable中获得的参数计算函数。当参数形参已经从单个可迭代对象中分组为元组(当数据已经“预压缩”)时,用于代替map()

map()starmap()之间的区别类似于function(a,b)function(*c)之间的区别。大致相当于:

def starmap(function, iterable):
    # starmap(pow, [(2,5), (3,2), (10,3)]) --> 32 9 1000
    for args in iterable:
        yield function(*args)

itertools.takewhile(predicate, iterable)

创建一个迭代器,只要谓词为真,该迭代器就从可迭代对象中返回元素。大致相当于:

def takewhile(predicate, iterable):
    # takewhile(lambda x: x<5, [1,4,6,4,1]) --> 1 4
    for x in iterable:
        if predicate(x):
            yield x
        else:
            break

itertools.tee(iterable, n=2)

从单个iterable返回n个独立的迭代器。下面的Python代码有助于解释tee的功能(尽管实际实现更复杂,只使用了单个底层FIFO队列):

def tee(iterable, n=2):
    it = iter(iterable)
    deques = [collections.deque() for i in range(n)]
    def gen(mydeque):
        while True:
            if not mydeque:             # when the local deque is empty
                try:
                    newval = next(it)   # fetch a new value and
                except StopIteration:
                    return
                for d in deques:        # load it to all the deques
                    d.append(newval)
            yield mydeque.popleft()
    return tuple(gen(d) for d in deques)

一旦创建了tee(),原始的可迭代对象就不应该在其他任何地方使用;否则,迭代对象可以在不通知tee对象的情况下被提升。

tee迭代器不是线程安全的。当同时使用同一tee()调用返回的迭代器时,可能会引发RuntimeError,即使原始迭代器是线程安全的。

这个itertool可能需要大量的辅助存储(取决于需要存储多少临时数据)。一般来说,如果一个迭代器在另一个迭代器启动之前使用了大部分或全部数据,那么使用list()tee()更快。

itertools.zip_longest(*iterables, fillvalue=None)

创建一个迭代器,聚合来自每个可迭代对象的元素。如果可迭代对象的长度不均匀,则用fillvalue填充缺失的值。迭代会继续,直到用完最长的可迭代对象。大致相当于:

def zip_longest(*args, fillvalue=None):
    # zip_longest('ABCD', 'xy', fillvalue='-') --> Ax By C- D-
    iterators = [iter(it) for it in args]
    num_active = len(iterators)
    if not num_active:
        return
    while True:
        values = []
        for i, it in enumerate(iterators):
            try:
                value = next(it)
            except StopIteration:
                num_active -= 1
                if not num_active:
                    return
                iterators[i] = repeat(fillvalue)
                value = fillvalue
            values.append(value)
        yield tuple(values)

如果其中一个可迭代对象可能是无限的,那么zip_longest()函数应该用限制调用次数的东西包装(例如islice()takewhile())。如果没有指定,fillvalue默认为None

Itertools Recipes

本节将展示使用现有itertools作为构建块创建扩展工具集的方法。

itertools配方的主要目的是教育。这些配方展示了思考单个工具的各种方式——例如,与扁平化的概念有关的chain.from_iterable。这些配方也给出了关于工具组合方式的想法——例如,compress()range()如何一起工作。还展示了与operatorcollections模块一起使用itertools的模式,以及与内置的itertools一起使用的模式,如map()filter()reversed()enumerate()

配方的第二个目的是作为一个孵化器。accumulate()compress()pairwise() itertools最初都是作为recipe使用的。目前,iter_index()配方正在测试中,看它是否证明了它的价值。基本上所有这些食谱和许多其他食谱都可以从Python包索引中的more-itertools项目中安装:

python -m pip install more-itertools

许多方法提供了与底层工具集相同的高性能。优越的内存性能是通过一次处理一个元素来保持的,而不是一次将整个可迭代对象都带入内存。通过以函数式的方式将工具链接在一起,有助于消除临时变量,从而保持较小的代码量。通过选择“向量化”的构建块而不是使用for循环和生成器来保持高速度,因为它们会引起解释器的开销。

import collections
import math
import operator
import random

def take(n, iterable):
    "Return first n items of the iterable as a list"
    return list(islice(iterable, n))

def prepend(value, iterator):
    "Prepend a single value in front of an iterator"
    # prepend(1, [2, 3, 4]) --> 1 2 3 4
    return chain([value], iterator)

def tabulate(function, start=0):
    "Return function(0), function(1), ..."
    return map(function, count(start))

def tail(n, iterable):
    "Return an iterator over the last n items"
    # tail(3, 'ABCDEFG') --> E F G
    return iter(collections.deque(iterable, maxlen=n))

def consume(iterator, n=None):
    "Advance the iterator n-steps ahead. If n is None, consume entirely."
    # Use functions that consume iterators at C speed.
    if n is None:
        # feed the entire iterator into a zero-length deque
        collections.deque(iterator, maxlen=0)
    else:
        # advance to the empty slice starting at position n
        next(islice(iterator, n, n), None)

def nth(iterable, n, default=None):
    "Returns the nth item or a default value"
    return next(islice(iterable, n, None), default)

def all_equal(iterable):
    "Returns True if all the elements are equal to each other"
    g = groupby(iterable)
    return next(g, True) and not next(g, False)

def quantify(iterable, pred=bool):
    "Count how many times the predicate is True"
    return sum(map(pred, iterable))

def ncycles(iterable, n):
    "Returns the sequence elements n times"
    return chain.from_iterable(repeat(tuple(iterable), n))

def batched(iterable, n):
    "Batch data into tuples of length n. The last batch may be shorter."
    # batched('ABCDEFG', 3) --> ABC DEF G
    if n < 1:
        raise ValueError('n must be at least one')
    it = iter(iterable)
    while (batch := tuple(islice(it, n))):
        yield batch

def grouper(iterable, n, *, incomplete='fill', fillvalue=None):
    "Collect data into non-overlapping fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx
    # grouper('ABCDEFG', 3, incomplete='strict') --> ABC DEF ValueError
    # grouper('ABCDEFG', 3, incomplete='ignore') --> ABC DEF
    args = [iter(iterable)] * n
    if incomplete == 'fill':
        return zip_longest(*args, fillvalue=fillvalue)
    if incomplete == 'strict':
        return zip(*args, strict=True)
    if incomplete == 'ignore':
        return zip(*args)
    else:
        raise ValueError('Expected fill, strict, or ignore')

def sumprod(vec1, vec2):
    "Compute a sum of products."
    return sum(starmap(operator.mul, zip(vec1, vec2, strict=True)))

def sum_of_squares(it):
    "Add up the squares of the input values."
    # sum_of_squares([10, 20, 30]) -> 1400
    return sumprod(*tee(it))

def transpose(it):
    "Swap the rows and columns of the input."
    # transpose([(1, 2, 3), (11, 22, 33)]) --> (1, 11) (2, 22) (3, 33)
    return zip(*it, strict=True)

def matmul(m1, m2):
    "Multiply two matrices."
    # matmul([(7, 5), (3, 5)], [[2, 5], [7, 9]]) --> (49, 80), (41, 60)
    n = len(m2[0])
    return batched(starmap(sumprod, product(m1, transpose(m2))), n)

def convolve(signal, kernel):
    # See:  https://betterexplained.com/articles/intuitive-convolution/
    # convolve(data, [0.25, 0.25, 0.25, 0.25]) --> Moving average (blur)
    # convolve(data, [1, -1]) --> 1st finite difference (1st derivative)
    # convolve(data, [1, -2, 1]) --> 2nd finite difference (2nd derivative)
    kernel = tuple(kernel)[::-1]
    n = len(kernel)
    window = collections.deque([0], maxlen=n) * n
    for x in chain(signal, repeat(0, n-1)):
        window.append(x)
        yield sumprod(kernel, window)

def polynomial_from_roots(roots):
    """Compute a polynomial's coefficients from its roots.

       (x - 5) (x + 4) (x - 3)  expands to:   x³ -4x² -17x + 60
    """
    # polynomial_from_roots([5, -4, 3]) --> [1, -4, -17, 60]
    roots = list(map(operator.neg, roots))
    return [
        sum(map(math.prod, combinations(roots, k)))
        for k in range(len(roots) + 1)
    ]

def iter_index(iterable, value, start=0):
    "Return indices where a value occurs in a sequence or iterable."
    # iter_index('AABCADEAF', 'A') --> 0 1 4 7
    try:
        seq_index = iterable.index
    except AttributeError:
        # Slow path for general iterables
        it = islice(iterable, start, None)
        i = start - 1
        try:
            while True:
                yield (i := i + operator.indexOf(it, value) + 1)
        except ValueError:
            pass
    else:
        # Fast path for sequences
        i = start - 1
        try:
            while True:
                yield (i := seq_index(value, i+1))
        except ValueError:
            pass

def sieve(n):
    "Primes less than n"
    # sieve(30) --> 2 3 5 7 11 13 17 19 23 29
    data = bytearray((0, 1)) * (n // 2)
    data[:3] = 0, 0, 0
    limit = math.isqrt(n) + 1
    for p in compress(range(limit), data):
        data[p*p : n : p+p] = bytes(len(range(p*p, n, p+p)))
    data[2] = 1
    return iter_index(data, 1) if n > 2 else iter([])

def factor(n):
    "Prime factors of n."
    # factor(99) --> 3 3 11
    for prime in sieve(math.isqrt(n) + 1):
        while True:
            quotient, remainder = divmod(n, prime)
            if remainder:
                break
            yield prime
            n = quotient
            if n == 1:
                return
    if n >= 2:
        yield n

def flatten(list_of_lists):
    "Flatten one level of nesting"
    return chain.from_iterable(list_of_lists)

def repeatfunc(func, times=None, *args):
    """Repeat calls to func with specified arguments.

    Example:  repeatfunc(random.random)
    """
    if times is None:
        return starmap(func, repeat(args))
    return starmap(func, repeat(args, times))

def triplewise(iterable):
    "Return overlapping triplets from an iterable"
    # triplewise('ABCDEFG') --> ABC BCD CDE DEF EFG
    for (a, _), (b, c) in pairwise(pairwise(iterable)):
        yield a, b, c

def sliding_window(iterable, n):
    # sliding_window('ABCDEFG', 4) --> ABCD BCDE CDEF DEFG
    it = iter(iterable)
    window = collections.deque(islice(it, n), maxlen=n)
    if len(window) == n:
        yield tuple(window)
    for x in it:
        window.append(x)
        yield tuple(window)

def roundrobin(*iterables):
    "roundrobin('ABC', 'D', 'EF') --> A D E B F C"
    # Recipe credited to George Sakkis
    num_active = len(iterables)
    nexts = cycle(iter(it).__next__ for it in iterables)
    while num_active:
        try:
            for next in nexts:
                yield next()
        except StopIteration:
            # Remove the iterator we just exhausted from the cycle.
            num_active -= 1
            nexts = cycle(islice(nexts, num_active))

def partition(pred, iterable):
    "Use a predicate to partition entries into false entries and true entries"
    # partition(is_odd, range(10)) --> 0 2 4 6 8   and  1 3 5 7 9
    t1, t2 = tee(iterable)
    return filterfalse(pred, t1), filter(pred, t2)

def before_and_after(predicate, it):
    """ Variant of takewhile() that allows complete
        access to the remainder of the iterator.

        >>> it = iter('ABCdEfGhI')
        >>> all_upper, remainder = before_and_after(str.isupper, it)
        >>> ''.join(all_upper)
        'ABC'
        >>> ''.join(remainder)     # takewhile() would lose the 'd'
        'dEfGhI'

        Note that the first iterator must be fully
        consumed before the second iterator can
        generate valid results.
    """
    it = iter(it)
    transition = []
    def true_iterator():
        for elem in it:
            if predicate(elem):
                yield elem
            else:
                transition.append(elem)
                return
    def remainder_iterator():
        yield from transition
        yield from it
    return true_iterator(), remainder_iterator()

def subslices(seq):
    "Return all contiguous non-empty subslices of a sequence"
    # subslices('ABCD') --> A AB ABC ABCD B BC BCD C CD D
    slices = starmap(slice, combinations(range(len(seq) + 1), 2))
    return map(operator.getitem, repeat(seq), slices)

def powerset(iterable):
    "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))

def unique_everseen(iterable, key=None):
    "List unique elements, preserving order. Remember all elements ever seen."
    # unique_everseen('AAAABBBCCDAABBB') --> A B C D
    # unique_everseen('ABBcCAD', str.lower) --> A B c D
    seen = set()
    if key is None:
        for element in filterfalse(seen.__contains__, iterable):
            seen.add(element)
            yield element
        # For order preserving deduplication,
        # a faster but non-lazy solution is:
        #     yield from dict.fromkeys(iterable)
    else:
        for element in iterable:
            k = key(element)
            if k not in seen:
                seen.add(k)
                yield element
        # For use cases that allow the last matching element to be returned,
        # a faster but non-lazy solution is:
        #      t1, t2 = tee(iterable)
        #      yield from dict(zip(map(key, t1), t2)).values()

def unique_justseen(iterable, key=None):
    "List unique elements, preserving order. Remember only the element just seen."
    # unique_justseen('AAAABBBCCDAABBB') --> A B C D A B
    # unique_justseen('ABBcCAD', str.lower) --> A B c A D
    return map(next, map(operator.itemgetter(1), groupby(iterable, key)))

def iter_except(func, exception, first=None):
    """ Call a function repeatedly until an exception is raised.

    Converts a call-until-exception interface to an iterator interface.
    Like builtins.iter(func, sentinel) but uses an exception instead
    of a sentinel to end the loop.

    Examples:
        iter_except(functools.partial(heappop, h), IndexError)   # priority queue iterator
        iter_except(d.popitem, KeyError)                         # non-blocking dict iterator
        iter_except(d.popleft, IndexError)                       # non-blocking deque iterator
        iter_except(q.get_nowait, Queue.Empty)                   # loop over a producer Queue
        iter_except(s.pop, KeyError)                             # non-blocking set iterator

    """
    try:
        if first is not None:
            yield first()            # For database APIs needing an initial cast to db.first()
        while True:
            yield func()
    except exception:
        pass

def first_true(iterable, default=False, pred=None):
    """Returns the first true value in the iterable.

    If no true value is found, returns *default*

    If *pred* is not None, returns the first item
    for which pred(item) is true.

    """
    # first_true([a,b,c], x) --> a or b or c or x
    # first_true([a,b], x, f) --> a if f(a) else b if f(b) else x
    return next(filter(pred, iterable), default)

def nth_combination(iterable, r, index):
    "Equivalent to list(combinations(iterable, r))[index]"
    pool = tuple(iterable)
    n = len(pool)
    c = math.comb(n, r)
    if index < 0:
        index += c
    if index < 0 or index >= c:
        raise IndexError
    result = []
    while r:
        c, n, r = c*r//n, n-1, r-1
        while index >= c:
            index -= c
            c, n = c*(n-r)//n, n-1
        result.append(pool[-1-n])
    return tuple(result)

翻译自官方文档itertools — Functions creating iterators for efficient looping

posted @ 2023-03-14 15:00  Rogn  阅读(100)  评论(0编辑  收藏  举报