itertools --- 为高效循环创建迭代器的函数


此模块实现了一些 迭代器 构建块,其灵感来自于 APL、Haskell 和 SML 中的构想。每个构建块都经过重构,以使其适用于 Python。

该模块将一组快速、高效利用内存的工具标准化,这些工具本身或组合起来都很有用。它们共同构成了一个“迭代器代数”,使得在纯 Python 中能够简洁高效地构建专用工具。

例如,SML 提供了一个制表工具:tabulate(f),它可以生成序列 f(0), f(1), ...。在 Python 中,可以通过组合 map()count() 形成 map(f, count()) 来实现相同的效果。

无限迭代器

迭代器

参数

结果

示例

count()

[start[, step]]

start, start+step, start+2*step, …

count(10) 10 11 12 13 14 ...

cycle()

p

p0, p1, … plast, p0, p1, …

cycle('ABCD') A B C D A B C D ...

repeat()

elem [,n]

elem, elem, elem, … 无限重复或重复 n 次

repeat(10, 3) 10 10 10

在最短输入序列上终止的迭代器

迭代器

参数

结果

示例

accumulate()

p [,func]

p0, p0+p1, p0+p1+p2, …

accumulate([1,2,3,4,5]) 1 3 6 10 15

batched()

p, n

(p0, p1, …, p_n-1), …

batched('ABCDEFG', n=2) AB CD EF G

chain()

p, q, …

p0, p1, … plast, q0, q1, …

chain('ABC', 'DEF') A B C D E F

chain.from_iterable()

可迭代对象

p0, p1, … plast, q0, q1, …

chain.from_iterable(['ABC', 'DEF']) A B C D E F

compress()

data, selectors

(d[0] if s[0]), (d[1] if s[1]), …

compress('ABCDEF', [1,0,1,0,1,1]) A C E F

dropwhile()

predicate, seq

seq[n], seq[n+1], 从 predicate 为假时开始

dropwhile(lambda x: x<5, [1,4,6,3,8]) 6 3 8

filterfalse()

predicate, seq

seq 中 predicate(elem) 为假的元素

filterfalse(lambda x: x<5, [1,4,6,3,8]) 6 8

groupby()

iterable[, key]

按 key(v) 值分组的子迭代器

groupby(['A','B','DEF'], len) (1, A B) (3, DEF)

islice()

seq, [start,] stop [, step]

来自 seq[start:stop:step] 的元素

islice('ABCDEFG', 2, None) C D E F G

pairwise()

可迭代对象

(p[0], p[1]), (p[1], p[2])

pairwise('ABCDEFG') AB BC CD DE EF FG

starmap()

func, seq

func(*seq[0]), func(*seq[1]), …

starmap(pow, [(2,5), (3,2), (10,3)]) 32 9 1000

takewhile()

predicate, seq

seq[0], seq[1], 直到 predicate 为假

takewhile(lambda x: x<5, [1,4,6,3,8]) 1 4

tee()

it, n

it1, it2, … itn 将一个迭代器拆分为 n 个

tee('ABC', 2) A B C, A B C

zip_longest()

p, q, …

(p[0], q[0]), (p[1], q[1]), …

zip_longest('ABCD', 'xy', fillvalue='-') Ax By C- D-

组合迭代器

迭代器

参数

结果

product()

p, q, … [repeat=1]

笛卡尔积,相当于嵌套的 for 循环

permutations()

p[, r]

长度为 r 的元组,所有可能的排序,无重复元素

combinations()

p, r

长度为 r 的元组,按排序顺序,无重复元素

combinations_with_replacement()

p, r

长度为 r 的元组,按排序顺序,有重复元素

示例:

结果

product('ABCD', repeat=2)

AA AB AC AD BA BB BC BD CA CB CC CD DA DB DC DD

permutations('ABCD', 2)

AB AC AD BA BC BD CA CB CD DA DB DC

combinations('ABCD', 2)

AB AC AD BC BD CD

combinations_with_replacement('ABCD', 2)

AA AB AC AD BB BC BD CC CD DD

Itertool 函数

以下所有函数都用于构建并返回迭代器。有些提供无限长度的数据流,因此它们只应被那些会截断数据流的函数或循环访问。

itertools.accumulate(iterable[, function, *, initial=None])

创建一个迭代器,返回累加的和或来自其他二元函数的累加结果。

function 默认为加法。function 应接受两个参数:一个累加的总值和一个来自 iterable 的值。

如果提供了 initial 值,累加将从该值开始,并且输出将比输入的可迭代对象多一个元素。

大致相当于:

def accumulate(iterable, function=operator.add, *, initial=None):
    'Return running totals'
    # accumulate([1,2,3,4,5]) → 1 3 6 10 15
    # accumulate([1,2,3,4,5], initial=100) → 100 101 103 106 110 115
    # accumulate([1,2,3,4,5], operator.mul) → 1 2 6 24 120

    iterator = iter(iterable)
    total = initial
    if initial is None:
        try:
            total = next(iterator)
        except StopIteration:
            return

    yield total
    for element in iterator:
        total = function(total, element)
        yield total

要计算一个运行中的最小值,请将 function 设置为 min()。对于运行中的最大值,请将 function 设置为 max()。或者对于运行中的乘积,请将 function 设置为 operator.mul()。要构建摊销表,可累加利息并应用付款。

>>> data = [3, 4, 6, 2, 1, 9, 0, 7, 5, 8]
>>> list(accumulate(data, max))              # running maximum
[3, 4, 6, 6, 6, 9, 9, 9, 9, 9]
>>> list(accumulate(data, operator.mul))     # running product
[3, 12, 72, 144, 144, 1296, 0, 0, 0, 0]

# Amortize a 5% loan of 1000 with 10 annual payments of 90
>>> update = lambda balance, payment: round(balance * 1.05) - payment
>>> list(accumulate(repeat(90, 10), update, initial=1_000))
[1000, 960, 918, 874, 828, 779, 728, 674, 618, 559, 497]

另请参阅 functools.reduce(),它是一个类似函数,只返回最终的累加值。

在 3.2 版本加入。

在 3.3 版本发生变更: 添加了可选的 function 形参。

在 3.8 版本发生变更: 添加了可选的 initial 形参。

itertools.batched(iterable, n, *, strict=False)

将来自 iterable 的数据分批成长度为 n 的元组。最后一批的长度可能小于 n

如果 strict 为真,且最后一批的长度小于 n,将引发 ValueError

遍历输入的可迭代对象,并将数据累积成大小至多为 n 的元组。输入是被惰性消耗的,仅消耗足以填充一批的数量。一旦批次满员或输入的可迭代对象耗尽,就会产生结果。

>>> flattened_data = ['roses', 'red', 'violets', 'blue', 'sugar', 'sweet']
>>> unflattened = list(batched(flattened_data, 2))
>>> unflattened
[('roses', 'red'), ('violets', 'blue'), ('sugar', 'sweet')]

大致相当于:

def batched(iterable, n, *, strict=False):
    # batched('ABCDEFG', 2) → AB CD EF G
    if n < 1:
        raise ValueError('n must be at least one')
    iterator = iter(iterable)
    while batch := tuple(islice(iterator, n)):
        if strict and len(batch) != n:
            raise ValueError('batched(): incomplete batch')
        yield batch

3.12 新版功能.

在 3.13 版本发生变更: 添加了 strict 选项。

itertools.chain(*iterables)

创建一个迭代器,它从第一个可迭代对象中返回元素,直到耗尽,然后继续到下一个可迭代对象,直到所有的可迭代对象都被耗尽。这将多个数据源组合成一个单一的迭代器。大致等同于:

def chain(*iterables):
    # chain('ABC', 'DEF') → A B C D E F
    for iterable in iterables:
        yield from iterable
classmethod chain.from_iterable(iterable)

chain() 的备用构造函数。从一个惰性求值的单一可迭代对象参数中获取链式输入。大致等同于:

def from_iterable(iterables):
    # chain.from_iterable(['ABC', 'DEF']) → A B C D E F
    for iterable in iterables:
        yield from iterable
itertools.combinations(iterable, r)

返回输入 iterable 中元素的长度为 r 的子序列。

输出是 product() 的一个子序列,只保留那些是 iterable 的子序列的条目。输出的长度由 math.comb() 给出,当 0 r n 时计算 n! / r! / (n - r)!,当 r > n 时为零。

组合元组按输入 iterable 的顺序以字典序发出。如果输入 iterable 是排序的,则输出的元组将按排序顺序生成。

元素根据其位置而不是其值被视为唯一的。如果输入元素是唯一的,则每个组合中将不会有重复的值。

大致相当于:

def combinations(iterable, r):
    # combinations('ABCD', 2) → AB AC AD BC BD CD
    # combinations(range(4), 3) → 012 013 023 123

    pool = tuple(iterable)
    n = len(pool)
    if r > n:
        return
    indices = list(range(r))

    yield tuple(pool[i] for i in indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != i + n - r:
                break
        else:
            return
        indices[i] += 1
        for j in range(i+1, r):
            indices[j] = indices[j-1] + 1
        yield tuple(pool[i] for i in indices)
itertools.combinations_with_replacement(iterable, r)

返回输入 iterable 中元素的长度为 r 的子序列,允许单个元素重复多次。

输出是 product() 的一个子序列,只保留那些是 iterable 的子序列(可能带有重复元素)的条目。当 n > 0 时,返回的子序列数量为 (n + r - 1)! / r! / (n - 1)!

组合元组按输入 iterable 的顺序以字典序发出。如果输入 iterable 是排序的,则输出的元组将按排序顺序生成。

元素根据其位置而不是其值被视为唯一的。如果输入元素是唯一的,则生成的组合也将是唯一的。

大致相当于:

def combinations_with_replacement(iterable, r):
    # combinations_with_replacement('ABC', 2) → AA AB AC BB BC CC

    pool = tuple(iterable)
    n = len(pool)
    if not n and r:
        return
    indices = [0] * r

    yield tuple(pool[i] for i in indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != n - 1:
                break
        else:
            return
        indices[i:] = [indices[i] + 1] * (r - i)
        yield tuple(pool[i] for i in indices)

在 3.1 版本加入。

itertools.compress(data, selectors)

创建一个迭代器,它返回 data 中那些在 selectors 中相应元素为真的元素。当 dataselectors 的任一可迭代对象被耗尽时停止。大致等同于:

def compress(data, selectors):
    # compress('ABCDEF', [1,0,1,0,1,1]) → A C E F
    return (datum for datum, selector in zip(data, selectors) if selector)

在 3.1 版本加入。

itertools.count(start=0, step=1)

创建一个迭代器,返回从 start 开始的等间距值。可与 map() 一起生成连续的数据点,或与 zip() 一起添加序列号。大致等同于:

def count(start=0, step=1):
    # count(10) → 10 11 12 13 14 ...
    # count(2.5, 0.5) → 2.5 3.0 3.5 ...
    n = start
    while True:
        yield n
        n += step

当使用浮点数计数时,通过使用乘法代码,有时可以获得更好的精度,例如:(start + step * i for i in count())

在 3.1 版本发生变更: 增加了 step 参数并允许非整数参数。

itertools.cycle(iterable)

创建一个迭代器,返回来自 iterable 的元素,并保存每个元素的副本。当可迭代对象被耗尽时,从保存的副本中返回元素。无限重复。大致等同于:

def cycle(iterable):
    # cycle('ABCD') → A B C D A B C D A B C D ...

    saved = []
    for element in iterable:
        yield element
        saved.append(element)

    while saved:
        for element in saved:
            yield element

此迭代工具可能需要大量的辅助存储空间(取决于可迭代对象的长度)。

itertools.dropwhile(predicate, iterable)

创建一个迭代器,只要 predicate 为真,就从 iterable 中丢弃元素,之后返回每个元素。大致等同于:

def dropwhile(predicate, iterable):
    # dropwhile(lambda x: x<5, [1,4,6,3,8]) → 6 3 8

    iterator = iter(iterable)
    for x in iterator:
        if not predicate(x):
            yield x
            break

    for x in iterator:
        yield x

请注意,在断言首次变为假之前,它不会产生*任何*输出,因此这个迭代工具可能会有很长的启动时间。

itertools.filterfalse(predicate, iterable)

创建一个迭代器,它从 iterable 中过滤元素,只返回那些 predicate 返回假值的元素。如果 predicateNone,则返回值为假的项目。大致等同于:

def filterfalse(predicate, iterable):
    # filterfalse(lambda x: x<5, [1,4,6,3,8]) → 6 8

    if predicate is None:
        predicate = bool

    for x in iterable:
        if not predicate(x):
            yield x
itertools.groupby(iterable, key=None)

创建一个迭代器,返回来自 iterable 的连续键和组。key 是一个计算每个元素的键值的函数。如果未指定或为 Nonekey 默认为一个恒等函数并返回元素本身。通常,可迭代对象需要已经按相同的键函数排序。

groupby() 的操作类似于 Unix 中的 uniq 过滤器。每当键函数的值发生变化时,它就会生成一个中断或新组(这就是为什么通常需要使用相同的键函数对数据进行排序的原因)。这种行为不同于 SQL 的 GROUP BY,后者聚合公共元素,而不考虑它们的输入顺序。

返回的组本身是一个迭代器,它与 groupby() 共享底层的可迭代对象。因为源是共享的,当 groupby() 对象前进时,前一个组就不再可见。因此,如果以后需要这些数据,应该将其存储为一个列表。

groups = []
uniquekeys = []
data = sorted(data, key=keyfunc)
for k, g in groupby(data, keyfunc):
    groups.append(list(g))      # Store group iterator as a list
    uniquekeys.append(k)

groupby() 大致等同于:

def groupby(iterable, key=None):
    # [k for k, g in groupby('AAAABBBCCDAABBB')] → A B C D A B
    # [list(g) for k, g in groupby('AAAABBBCCD')] → AAAA BBB CC D

    keyfunc = (lambda x: x) if key is None else key
    iterator = iter(iterable)
    exhausted = False

    def _grouper(target_key):
        nonlocal curr_value, curr_key, exhausted
        yield curr_value
        for curr_value in iterator:
            curr_key = keyfunc(curr_value)
            if curr_key != target_key:
                return
            yield curr_value
        exhausted = True

    try:
        curr_value = next(iterator)
    except StopIteration:
        return
    curr_key = keyfunc(curr_value)

    while not exhausted:
        target_key = curr_key
        curr_group = _grouper(target_key)
        yield curr_key, curr_group
        if curr_key == target_key:
            for _ in curr_group:
                pass
itertools.islice(iterable, stop)
itertools.islice(iterable, start, stop[, step])

创建一个迭代器,返回可迭代对象中的选定元素。工作方式类似于序列切片,但不支持 startstopstep 的负值。

如果 start 为零或 None,迭代从零开始。否则,跳过可迭代对象中的元素,直到达到 start

如果 stopNone,迭代将持续进行直到输入耗尽(如果会耗尽的话)。 否则,它会在指定位置停止。

如果 stepNone,则步长默认为 1。元素会连续地被返回,除非 step 设置为大于 1 的值,这会导致一些条目被跳过。

大致相当于:

def islice(iterable, *args):
    # islice('ABCDEFG', 2) → A B
    # islice('ABCDEFG', 2, 4) → C D
    # islice('ABCDEFG', 2, None) → C D E F G
    # islice('ABCDEFG', 0, None, 2) → A C E G

    s = slice(*args)
    start = 0 if s.start is None else s.start
    stop = s.stop
    step = 1 if s.step is None else s.step
    if start < 0 or (stop is not None and stop < 0) or step <= 0:
        raise ValueError

    indices = count() if stop is None else range(max(start, stop))
    next_i = start
    for i, element in zip(indices, iterable):
        if i == next_i:
            yield element
            next_i += step

如果输入是一个迭代器,那么完全消耗 islice 会使输入迭代器前进 max(start, stop) 步,而与 step 的值无关。

itertools.pairwise(iterable)

返回从输入 iterable 中获取的连续重叠对。

输出迭代器中的二元组数量将比输入数量少一。如果输入的可迭代对象的值少于两个,它将为空。

大致相当于:

def pairwise(iterable):
    # pairwise('ABCDEFG') → AB BC CD DE EF FG

    iterator = iter(iterable)
    a = next(iterator, None)

    for b in iterator:
        yield a, b
        a = b

在 3.10 版本加入。

itertools.permutations(iterable, r=None)

返回来自 iterable 的元素的连续 r 长度排列

如果 r 未指定或为 None,则 r 默认为 iterable 的长度,并生成所有可能的全长排列。

输出是 product() 的一个子序列,其中已过滤掉有重复元素的条目。输出的长度由 math.perm() 给出,当 0 r n 时计算 n! / (n - r)!,当 r > n 时为零。

排列元组按输入 iterable 的顺序以字典序发出。如果输入 iterable 是排序的,则输出的元组将按排序顺序生成。

元素根据其位置而不是其值被视为唯一的。如果输入元素是唯一的,则排列中不会有重复的值。

大致相当于:

def permutations(iterable, r=None):
    # permutations('ABCD', 2) → AB AC AD BA BC BD CA CB CD DA DB DC
    # permutations(range(3)) → 012 021 102 120 201 210

    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r
    if r > n:
        return

    indices = list(range(n))
    cycles = list(range(n, n-r, -1))
    yield tuple(pool[i] for i in indices[:r])

    while n:
        for i in reversed(range(r)):
            cycles[i] -= 1
            if cycles[i] == 0:
                indices[i:] = indices[i+1:] + indices[i:i+1]
                cycles[i] = n - i
            else:
                j = cycles[i]
                indices[i], indices[-j] = indices[-j], indices[i]
                yield tuple(pool[i] for i in indices[:r])
                break
        else:
            return
itertools.product(*iterables, repeat=1)

输入可迭代对象的笛卡尔积

大致相当于生成器表达式中的嵌套 for 循环。例如,product(A, B) 返回的结果与 ((x,y) for x in A for y in B) 相同。

嵌套的循环像里程表一样循环,最右边的元素在每次迭代时都会前进。这种模式创建了一个字典序的排序,因此如果输入的迭代器是排序的,那么乘积元组将以排序的顺序发出。

要计算一个可迭代对象与自身的乘积,请使用可选的 repeat 关键字参数指定重复次数。例如,product(A, repeat=4)product(A, A, A, A) 的意思相同。

此函数大致等同于以下代码,但实际实现不会在内存中构建中间结果:

def product(*iterables, repeat=1):
    # product('ABCD', 'xy') → Ax Ay Bx By Cx Cy Dx Dy
    # product(range(2), repeat=3) → 000 001 010 011 100 101 110 111

    if repeat < 0:
        raise ValueError('repeat argument cannot be negative')
    pools = [tuple(pool) for pool in iterables] * repeat

    result = [[]]
    for pool in pools:
        result = [x+[y] for x in result for y in pool]

    for prod in result:
        yield tuple(prod)

product() 运行之前,它会完全消耗输入的迭代器,将值池保存在内存中以生成乘积。因此,它只适用于有限的输入。

itertools.repeat(object[, times])

创建一个迭代器,它一次又一次地返回 object。除非指定了 times 参数,否则它会无限期运行。

大致相当于:

def repeat(object, times=None):
    # repeat(10, 3) → 10 10 10
    if times is None:
        while True:
            yield object
    else:
        for i in range(times):
            yield object

repeat 的一个常见用途是向 mapzip 提供一个常量值流:

>>> list(map(pow, range(10), repeat(2)))
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
itertools.starmap(function, iterable)

创建一个迭代器,使用从 iterable 获得的参数来计算 function。当参数已经“预先打包”成元组时,使用它来代替 map()

map()starmap() 之间的区别与 function(a,b)function(*c) 之间的区别相类似。大致等同于:

def starmap(function, iterable):
    # starmap(pow, [(2,5), (3,2), (10,3)]) → 32 9 1000
    for args in iterable:
        yield function(*args)
itertools.takewhile(predicate, iterable)

创建一个迭代器,只要 predicate 为真,就从 iterable 中返回元素。大致等同于:

def takewhile(predicate, iterable):
    # takewhile(lambda x: x<5, [1,4,6,3,8]) → 1 4
    for x in iterable:
        if not predicate(x):
            break
        yield x

注意,第一个不满足谓词条件的元素会从输入迭代器中消耗掉,并且无法访问它。如果应用程序希望在 takewhile 运行完毕后继续消耗输入迭代器,这可能会成为一个问题。为了解决这个问题,可以考虑使用 more-itertools before_and_after()

itertools.tee(iterable, n=2)

从单个可迭代对象返回 n 个独立的迭代器。

大致相当于:

def tee(iterable, n=2):
    if n < 0:
        raise ValueError
    if n == 0:
        return ()
    iterator = _tee(iterable)
    result = [iterator]
    for _ in range(n - 1):
        result.append(_tee(iterator))
    return tuple(result)

class _tee:

    def __init__(self, iterable):
        it = iter(iterable)
        if isinstance(it, _tee):
            self.iterator = it.iterator
            self.link = it.link
        else:
            self.iterator = it
            self.link = [None, None]

    def __iter__(self):
        return self

    def __next__(self):
        link = self.link
        if link[1] is None:
            link[0] = next(self.iterator)
            link[1] = [None, None]
        value, self.link = link
        return value

当输入 iterable 已经是一个 tee 迭代器对象时,返回的元组的所有成员的构造方式就好像它们是由上游的 tee() 调用产生的一样。这种“扁平化步骤”允许嵌套的 tee() 调用共享相同的底层数据链,并且只有一个更新步骤而不是一连串的调用。

扁平化特性使得 tee 迭代器可以高效地进行窥视:

def lookahead(tee_iterator):
     "Return the next value without moving the input forward"
     [forked_iterator] = tee(tee_iterator, 1)
     return next(forked_iterator)
>>> iterator = iter('abcdef')
>>> [iterator] = tee(iterator, 1)   # Make the input peekable
>>> next(iterator)                  # Move the iterator forward
'a'
>>> lookahead(iterator)             # Check next value
'b'
>>> next(iterator)                  # Continue moving forward
'b'

tee 迭代器不是线程安全的。当同时使用由同一个 tee() 调用返回的迭代器时,可能会引发 RuntimeError,即使原始的 iterable 是线程安全的。

这个迭代工具可能需要大量的辅助存储空间(取决于需要存储多少临时数据)。一般来说,如果一个迭代器在另一个迭代器开始之前使用了大部分或全部数据,使用 list() 会比使用 tee() 更快。

itertools.zip_longest(*iterables, fillvalue=None)

创建一个迭代器,聚合来自每个 iterables 的元素。

如果可迭代对象的长度不均匀,缺失的值将用 fillvalue 填充。如果未指定,fillvalue 默认为 None

迭代将持续到最长的可迭代对象被耗尽。

大致相当于:

def zip_longest(*iterables, fillvalue=None):
    # zip_longest('ABCD', 'xy', fillvalue='-') → Ax By C- D-

    iterators = list(map(iter, iterables))
    num_active = len(iterators)
    if not num_active:
        return

    while True:
        values = []
        for i, iterator in enumerate(iterators):
            try:
                value = next(iterator)
            except StopIteration:
                num_active -= 1
                if not num_active:
                    return
                iterators[i] = repeat(fillvalue)
                value = fillvalue
            values.append(value)
        yield tuple(values)

如果其中一个可迭代对象可能是无限的,那么 zip_longest() 函数应该被包装在限制调用次数的东西中(例如 islice()takewhile())。

Itertools 配方

本节展示了使用现有的 itertools 作为构建块来创建扩展工具集的配方。

itertools 配方的主要目的是教育。这些配方展示了思考单个工具的各种方式——例如,chain.from_iterable 与扁平化的概念有关。这些配方也提供了关于工具如何组合的想法——例如,starmap()repeat() 如何协同工作。这些配方还展示了将 itertools 与 operatorcollections 模块以及内置的 itertools(如 map()filter()reversed()enumerate())一起使用的模式。

配方的次要目的是作为孵化器。 accumulate(), compress(), and pairwise() 等 itertools 工具最初都是以配方的形式出现的。 当前,sliding_window(), iter_index()sieve() 等配方正在测试中,以检验它们是否能证明其价值。

基本上所有这些配方以及许多许多其他的配方都可以从 Python 包索引上的 more-itertools 项目安装。

python -m pip install more-itertools

许多配方提供了与底层工具集相同的高性能。通过一次处理一个元素而不是将整个可迭代对象一次性加载到内存中,保持了卓越的内存性能。通过以函数式风格链接工具,代码量得以保持较小。通过优先选择“矢量化”构建块,而不是使用会产生解释器开销的 for 循环和生成器,从而保持了高速。

from collections import Counter, deque
from contextlib import suppress
from functools import reduce
from math import comb, prod, sumprod, isqrt
from operator import itemgetter, getitem, mul, neg

def take(n, iterable):
    "Return first n items of the iterable as a list."
    return list(islice(iterable, n))

def prepend(value, iterable):
    "Prepend a single value in front of an iterable."
    # prepend(1, [2, 3, 4]) → 1 2 3 4
    return chain([value], iterable)

def tabulate(function, start=0):
    "Return function(0), function(1), ..."
    return map(function, count(start))

def repeatfunc(function, times=None, *args):
    "Repeat calls to a function with specified arguments."
    if times is None:
        return starmap(function, repeat(args))
    return starmap(function, repeat(args, times))

def flatten(list_of_lists):
    "Flatten one level of nesting."
    return chain.from_iterable(list_of_lists)

def ncycles(iterable, n):
    "Returns the sequence elements n times."
    return chain.from_iterable(repeat(tuple(iterable), n))

def loops(n):
    "Loop n times. Like range(n) but without creating integers."
    # for _ in loops(100): ...
    return repeat(None, n)

def tail(n, iterable):
    "Return an iterator over the last n items."
    # tail(3, 'ABCDEFG') → E F G
    return iter(deque(iterable, maxlen=n))

def consume(iterator, n=None):
    "Advance the iterator n-steps ahead. If n is None, consume entirely."
    # Use functions that consume iterators at C speed.
    if n is None:
        deque(iterator, maxlen=0)
    else:
        next(islice(iterator, n, n), None)

def nth(iterable, n, default=None):
    "Returns the nth item or a default value."
    return next(islice(iterable, n, None), default)

def quantify(iterable, predicate=bool):
    "Given a predicate that returns True or False, count the True results."
    return sum(map(predicate, iterable))

def first_true(iterable, default=False, predicate=None):
    "Returns the first true value or the *default* if there is no true value."
    # first_true([a,b,c], x) → a or b or c or x
    # first_true([a,b], x, f) → a if f(a) else b if f(b) else x
    return next(filter(predicate, iterable), default)

def all_equal(iterable, key=None):
    "Returns True if all the elements are equal to each other."
    # all_equal('4٤௪౪໔', key=int) → True
    return len(take(2, groupby(iterable, key))) <= 1

def unique_justseen(iterable, key=None):
    "Yield unique elements, preserving order. Remember only the element just seen."
    # unique_justseen('AAAABBBCCDAABBB') → A B C D A B
    # unique_justseen('ABBcCAD', str.casefold) → A B c A D
    if key is None:
        return map(itemgetter(0), groupby(iterable))
    return map(next, map(itemgetter(1), groupby(iterable, key)))

def unique_everseen(iterable, key=None):
    "Yield unique elements, preserving order. Remember all elements ever seen."
    # unique_everseen('AAAABBBCCDAABBB') → A B C D
    # unique_everseen('ABBcCAD', str.casefold) → A B c D
    seen = set()
    if key is None:
        for element in filterfalse(seen.__contains__, iterable):
            seen.add(element)
            yield element
    else:
        for element in iterable:
            k = key(element)
            if k not in seen:
                seen.add(k)
                yield element

def unique(iterable, key=None, reverse=False):
   "Yield unique elements in sorted order. Supports unhashable inputs."
   # unique([[1, 2], [3, 4], [1, 2]]) → [1, 2] [3, 4]
   sequenced = sorted(iterable, key=key, reverse=reverse)
   return unique_justseen(sequenced, key=key)

def sliding_window(iterable, n):
    "Collect data into overlapping fixed-length chunks or blocks."
    # sliding_window('ABCDEFG', 4) → ABCD BCDE CDEF DEFG
    iterator = iter(iterable)
    window = deque(islice(iterator, n - 1), maxlen=n)
    for x in iterator:
        window.append(x)
        yield tuple(window)

def grouper(iterable, n, *, incomplete='fill', fillvalue=None):
    "Collect data into non-overlapping fixed-length chunks or blocks."
    # grouper('ABCDEFG', 3, fillvalue='x') → ABC DEF Gxx
    # grouper('ABCDEFG', 3, incomplete='strict') → ABC DEF ValueError
    # grouper('ABCDEFG', 3, incomplete='ignore') → ABC DEF
    iterators = [iter(iterable)] * n
    match incomplete:
        case 'fill':
            return zip_longest(*iterators, fillvalue=fillvalue)
        case 'strict':
            return zip(*iterators, strict=True)
        case 'ignore':
            return zip(*iterators)
        case _:
            raise ValueError('Expected fill, strict, or ignore')

def roundrobin(*iterables):
    "Visit input iterables in a cycle until each is exhausted."
    # roundrobin('ABC', 'D', 'EF') → A D E B F C
    # Algorithm credited to George Sakkis
    iterators = map(iter, iterables)
    for num_active in range(len(iterables), 0, -1):
        iterators = cycle(islice(iterators, num_active))
        yield from map(next, iterators)

def subslices(seq):
    "Return all contiguous non-empty subslices of a sequence."
    # subslices('ABCD') → A AB ABC ABCD B BC BCD C CD D
    slices = starmap(slice, combinations(range(len(seq) + 1), 2))
    return map(getitem, repeat(seq), slices)

def iter_index(iterable, value, start=0, stop=None):
    "Return indices where a value occurs in a sequence or iterable."
    # iter_index('AABCADEAF', 'A') → 0 1 4 7
    seq_index = getattr(iterable, 'index', None)
    if seq_index is None:
        iterator = islice(iterable, start, stop)
        for i, element in enumerate(iterator, start):
            if element is value or element == value:
                yield i
    else:
        stop = len(iterable) if stop is None else stop
        i = start
        with suppress(ValueError):
            while True:
                yield (i := seq_index(value, i, stop))
                i += 1

def iter_except(function, exception, first=None):
    "Convert a call-until-exception interface to an iterator interface."
    # iter_except(d.popitem, KeyError) → non-blocking dictionary iterator
    with suppress(exception):
        if first is not None:
            yield first()
        while True:
            yield function()

以下配方更具数学风格:

def multinomial(*counts):
    "Number of distinct arrangements of a multiset."
    # Counter('abracadabra').values() → 5 2 2 1 1
    # multinomial(5, 2, 2, 1, 1) → 83160
    return prod(map(comb, accumulate(counts), counts))

def powerset(iterable):
    "Subsequences of the iterable from shortest to longest."
    # powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))

def sum_of_squares(iterable):
    "Add up the squares of the input values."
    # sum_of_squares([10, 20, 30]) → 1400
    return sumprod(*tee(iterable))

def reshape(matrix, columns):
    "Reshape a 2-D matrix to have a given number of columns."
    # reshape([(0, 1), (2, 3), (4, 5)], 3) →  (0, 1, 2), (3, 4, 5)
    return batched(chain.from_iterable(matrix), columns, strict=True)

def transpose(matrix):
    "Swap the rows and columns of a 2-D matrix."
    # transpose([(1, 2, 3), (11, 22, 33)]) → (1, 11) (2, 22) (3, 33)
    return zip(*matrix, strict=True)

def matmul(m1, m2):
    "Multiply two matrices."
    # matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]) → (49, 80), (41, 60)
    n = len(m2[0])
    return batched(starmap(sumprod, product(m1, transpose(m2))), n)

def convolve(signal, kernel):
    """Discrete linear convolution of two iterables.
    Equivalent to polynomial multiplication.

    Convolutions are mathematically commutative; however, the inputs are
    evaluated differently.  The signal is consumed lazily and can be
    infinite. The kernel is fully consumed before the calculations begin.

    Article:  https://betterexplained.com/articles/intuitive-convolution/
    Video:    https://www.youtube.com/watch?v=KuXjwB4LzSA
    """
    # convolve([1, -1, -20], [1, -3]) → 1 -4 -17 60
    # convolve(data, [0.25, 0.25, 0.25, 0.25]) → Moving average (blur)
    # convolve(data, [1/2, 0, -1/2]) → 1st derivative estimate
    # convolve(data, [1, -2, 1]) → 2nd derivative estimate
    kernel = tuple(kernel)[::-1]
    n = len(kernel)
    padded_signal = chain(repeat(0, n-1), signal, repeat(0, n-1))
    windowed_signal = sliding_window(padded_signal, n)
    return map(sumprod, repeat(kernel), windowed_signal)

def polynomial_from_roots(roots):
    """Compute a polynomial's coefficients from its roots.

       (x - 5) (x + 4) (x - 3)  expands to:   x³ -4x² -17x + 60
    """
    # polynomial_from_roots([5, -4, 3]) → [1, -4, -17, 60]
    factors = zip(repeat(1), map(neg, roots))
    return list(reduce(convolve, factors, [1]))

def polynomial_eval(coefficients, x):
    """Evaluate a polynomial at a specific value.

    Computes with better numeric stability than Horner's method.
    """
    # Evaluate x³ -4x² -17x + 60 at x = 5
    # polynomial_eval([1, -4, -17, 60], x=5) → 0
    n = len(coefficients)
    if not n:
        return type(x)(0)
    powers = map(pow, repeat(x), reversed(range(n)))
    return sumprod(coefficients, powers)

def polynomial_derivative(coefficients):
    """Compute the first derivative of a polynomial.

       f(x)  =  x³ -4x² -17x + 60
       f'(x) = 3x² -8x  -17
    """
    # polynomial_derivative([1, -4, -17, 60]) → [3, -8, -17]
    n = len(coefficients)
    powers = reversed(range(1, n))
    return list(map(mul, coefficients, powers))

def sieve(n):
    "Primes less than n."
    # sieve(30) → 2 3 5 7 11 13 17 19 23 29
    if n > 2:
        yield 2
    data = bytearray((0, 1)) * (n // 2)
    for p in iter_index(data, 1, start=3, stop=isqrt(n) + 1):
        data[p*p : n : p+p] = bytes(len(range(p*p, n, p+p)))
    yield from iter_index(data, 1, start=3)

def factor(n):
    "Prime factors of n."
    # factor(99) → 3 3 11
    # factor(1_000_000_000_000_007) → 47 59 360620266859
    # factor(1_000_000_000_000_403) → 1000000000000403
    for prime in sieve(isqrt(n) + 1):
        while not n % prime:
            yield prime
            n //= prime
            if n == 1:
                return
    if n > 1:
        yield n

def is_prime(n):
    "Return True if n is prime."
    # is_prime(1_000_000_000_000_403) → True
    return n > 1 and next(factor(n)) == n

def totient(n):
    "Count of natural numbers up to n that are coprime to n."
    # https://mathworld.net.cn/TotientFunction.html
    # totient(12) → 4 because len([1, 5, 7, 11]) == 4
    for prime in set(factor(n)):
        n -= n // prime
    return n