itertools — 用于高效循环创建迭代器的函数


此模块实现了一些 迭代器 构建块,其灵感来自 APL、Haskell 和 SML 中的结构。每个结构都已转换为适合 Python 的形式。

该模块标准化了一组核心的快速、内存高效的工具,这些工具可以单独使用或组合使用。它们共同构成了一个“迭代器代数”,使得在纯 Python 中简洁高效地构建专用工具成为可能。

例如,SML 提供了一个列表工具:tabulate(f),它生成一个序列 f(0), f(1), ...。在 Python 中可以通过组合 map()count() 来达到相同的效果,形成 map(f, count())

这些工具及其内置对应工具也可以与 operator 模块中的高速函数很好地配合使用。例如,可以将乘法运算符映射到两个向量上,以形成高效的点积:sum(starmap(operator.mul, zip(vec1, vec2, strict=True)))

无限迭代器

迭代器

参数

结果

示例

count()

[start[, step]]

start, start+step, start+2*step, …

count(10) 10 11 12 13 14 ...

cycle()

p

p0, p1, … plast, p0, p1, …

cycle('ABCD') A B C D A B C D ...

repeat()

elem [,n]

elem, elem, elem, … 无限循环或最多 n 次

repeat(10, 3) 10 10 10

在最短输入序列处终止的迭代器

迭代器

参数

结果

示例

accumulate()

p [,func]

p0, p0+p1, p0+p1+p2, …

accumulate([1,2,3,4,5]) 1 3 6 10 15

batched()

p, n

(p0, p1, …, p_n-1), …

batched('ABCDEFG', n=3) ABC DEF G

chain()

p, q, …

p0, p1, … plast, q0, q1, …

chain('ABC', 'DEF') A B C D E F

chain.from_iterable()

iterable

p0, p1, … plast, q0, q1, …

chain.from_iterable(['ABC', 'DEF']) A B C D E F

compress()

data, selectors

(d[0] if s[0]), (d[1] if s[1]), …

compress('ABCDEF', [1,0,1,0,1,1]) A C E F

dropwhile()

predicate, seq

seq[n], seq[n+1], 从谓词失败时开始

dropwhile(lambda x: x<5, [1,4,6,3,8]) 6 3 8

filterfalse()

predicate, seq

predicate(elem) 为假的 seq 元素

filterfalse(lambda x: x<5, [1,4,6,3,8]) 6 8

groupby()

iterable[, key]

按 key(v) 的值分组的子迭代器

islice()

seq, [start,] stop [, step]

来自 seq[start:stop:step] 的元素

islice('ABCDEFG', 2, None) C D E F G

pairwise()

iterable

(p[0], p[1]), (p[1], p[2])

pairwise('ABCDEFG') AB BC CD DE EF FG

starmap()

func, seq

func(*seq[0]), func(*seq[1]), …

starmap(pow, [(2,5), (3,2), (10,3)]) 32 9 1000

takewhile()

predicate, seq

seq[0], seq[1], 直到谓词失败

takewhile(lambda x: x<5, [1,4,6,3,8]) 1 4

tee()

it, n

it1, it2, … itn 将一个迭代器拆分为 n 个

zip_longest()

p, q, …

(p[0], q[0]), (p[1], q[1]), …

zip_longest('ABCD', 'xy', fillvalue='-') Ax By C- D-

组合迭代器

迭代器

参数

结果

product()

p, q, … [repeat=1]

笛卡尔积,相当于嵌套的 for 循环

permutations()

p[, r]

长度为 r 的元组,所有可能的排序,没有重复元素

combinations()

p, r

长度为 r 的元组,按排序顺序排列,没有重复元素

combinations_with_replacement()

p, r

长度为 r 的元组,按排序顺序排列,允许重复元素

示例

结果

product('ABCD', repeat=2)

AA AB AC AD BA BB BC BD CA CB CC CD DA DB DC DD

permutations('ABCD', 2)

AB AC AD BA BC BD CA CB CD DA DB DC

combinations('ABCD', 2)

AB AC AD BC BD CD

combinations_with_replacement('ABCD', 2)

AA AB AC AD BB BC BD CC CD DD

迭代器函数

以下模块函数都用于构造和返回迭代器。有些函数提供无限长度的流,因此只能由截断流的函数或循环访问它们。

itertools.accumulate(iterable[, function, *, initial=None])

创建一个迭代器,该迭代器返回累积总和或其他二元函数的累积结果。

function 默认为加法。 function 应该接受两个参数,一个累积总数和一个来自 iterable 的值。

如果提供了 initial 值,则累积将从该值开始,并且输出将比输入迭代器多一个元素。

大致相当于

def accumulate(iterable, function=operator.add, *, initial=None):
    'Return running totals'
    # accumulate([1,2,3,4,5]) → 1 3 6 10 15
    # accumulate([1,2,3,4,5], initial=100) → 100 101 103 106 110 115
    # accumulate([1,2,3,4,5], operator.mul) → 1 2 6 24 120

    iterator = iter(iterable)
    total = initial
    if initial is None:
        try:
            total = next(iterator)
        except StopIteration:
            return

    yield total
    for element in iterator:
        total = function(total, element)
        yield total

function 参数可以设置为 min() 用于运行最小值,max() 用于运行最大值,或 operator.mul() 用于运行乘积。 摊销表 可以通过累积利息和应用付款来构建

>>> data = [3, 4, 6, 2, 1, 9, 0, 7, 5, 8]
>>> list(accumulate(data, max))              # running maximum
[3, 4, 6, 6, 6, 9, 9, 9, 9, 9]
>>> list(accumulate(data, operator.mul))     # running product
[3, 12, 72, 144, 144, 1296, 0, 0, 0, 0]

# Amortize a 5% loan of 1000 with 10 annual payments of 90
>>> update = lambda balance, payment: round(balance * 1.05) - payment
>>> list(accumulate(repeat(90, 10), update, initial=1_000))
[1000, 960, 918, 874, 828, 779, 728, 674, 618, 559, 497]

有关仅返回最终累积值的类似函数,请参阅 functools.reduce()

3.2 版新增。

3.3 版更改: 添加了可选的 function 参数。

3.8 版更改: 添加了可选的 initial 参数。

itertools.batched(iterable, n)

iterable 中的数据批量处理成长度为 n 的元组。 最后一批可能短于 n

循环遍历输入迭代器并将数据累积到大小不超过 n 的元组中。 输入是延迟消耗的,刚好足以填充一批。 一旦批处理已满或输入迭代器耗尽,就会立即生成结果

>>> flattened_data = ['roses', 'red', 'violets', 'blue', 'sugar', 'sweet']
>>> unflattened = list(batched(flattened_data, 2))
>>> unflattened
[('roses', 'red'), ('violets', 'blue'), ('sugar', 'sweet')]

大致相当于

def batched(iterable, n):
    # batched('ABCDEFG', 3) → ABC DEF G
    if n < 1:
        raise ValueError('n must be at least one')
    iterator = iter(iterable)
    while batch := tuple(islice(iterator, n)):
        yield batch

3.12 版新增。

itertools.chain(*iterables)

创建一个迭代器,该迭代器从第一个迭代器返回元素,直到它耗尽,然后继续到下一个迭代器,直到所有迭代器都耗尽。 用于将连续序列视为单个序列。 大致相当于

def chain(*iterables):
    # chain('ABC', 'DEF') → A B C D E F
    for iterable in iterables:
        yield from iterable
classmethod chain.from_iterable(iterable)

chain() 的备用构造函数。 从延迟评估的单个可迭代参数获取链接输入。 大致相当于

def from_iterable(iterables):
    # chain.from_iterable(['ABC', 'DEF']) → A B C D E F
    for iterable in iterables:
        yield from iterable
itertools.combinations(iterable, r)

返回来自输入 iterable 的元素的 r 长度子序列。

输出是 product() 的子序列,仅保留作为 iterable 子序列的条目。 输出的长度由 math.comb() 给出,它在 0 r n 时计算 n! / r! / (n - r)!,在 r > n 时计算零。

组合元组根据输入 iterable 的顺序按字典顺序发出。 如果输入 iterable 已排序,则输出元组将按排序顺序生成。

元素的唯一性基于其位置,而不是其值。 如果输入元素是唯一的,则每个组合中将没有重复的值。

大致相当于

def combinations(iterable, r):
    # combinations('ABCD', 2) → AB AC AD BC BD CD
    # combinations(range(4), 3) → 012 013 023 123

    pool = tuple(iterable)
    n = len(pool)
    if r > n:
        return
    indices = list(range(r))

    yield tuple(pool[i] for i in indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != i + n - r:
                break
        else:
            return
        indices[i] += 1
        for j in range(i+1, r):
            indices[j] = indices[j-1] + 1
        yield tuple(pool[i] for i in indices)
itertools.combinations_with_replacement(iterable, r)

返回来自输入 iterable 的元素的 r 长度子序列,允许单个元素重复多次。

输出是 product() 的子序列,仅保留作为 iterable 的子序列(可能包含重复元素)的条目。 返回的子序列数在 n > 0 时为 (n + r - 1)! / r! / (n - 1)!

组合元组根据输入 iterable 的顺序按字典顺序发出。 如果输入 iterable 已排序,则输出元组将按排序顺序生成。

元素的唯一性基于其位置,而不是其值。 如果输入元素是唯一的,则生成的组合也将是唯一的。

大致相当于

def combinations_with_replacement(iterable, r):
    # combinations_with_replacement('ABC', 2) → AA AB AC BB BC CC

    pool = tuple(iterable)
    n = len(pool)
    if not n and r:
        return
    indices = [0] * r

    yield tuple(pool[i] for i in indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != n - 1:
                break
        else:
            return
        indices[i:] = [indices[i] + 1] * (r - i)
        yield tuple(pool[i] for i in indices)

3.1 版新增。

itertools.compress(data, selectors)

创建一个迭代器,该迭代器从 data 中返回元素,其中 selectors 中的对应元素为 true。 当 dataselectors 迭代器耗尽时停止。 大致相当于

def compress(data, selectors):
    # compress('ABCDEF', [1,0,1,0,1,1]) → A C E F
    return (datum for datum, selector in zip(data, selectors) if selector)

3.1 版新增。

itertools.count(start=0, step=1)

创建一个迭代器,该迭代器返回从 start 开始的均匀间隔的值。 可以与 map() 一起使用以生成连续数据点,或与 zip() 一起使用以添加序列号。 大致相当于

def count(start=0, step=1):
    # count(10) → 10 11 12 13 14 ...
    # count(2.5, 0.5) → 2.5 3.0 3.5 ...
    n = start
    while True:
        yield n
        n += step

使用浮点数计数时,有时可以通过替换乘法代码来获得更好的精度,例如:(start + step * i for i in count())

3.1 版更改: 添加了 step 参数并允许非整数参数。

itertools.cycle(iterable)

创建一个迭代器,该迭代器从 iterable 返回元素并保存每个元素的副本。 当迭代器耗尽时,从保存的副本中返回元素。 无限重复。 大致相当于

def cycle(iterable):
    # cycle('ABCD') → A B C D A B C D A B C D ...
    saved = []
    for element in iterable:
        yield element
        saved.append(element)
    while saved:
        for element in saved:
            yield element

此迭代工具可能需要大量的辅助存储空间(取决于迭代器的长度)。

itertools.dropwhile(predicate, iterable)

创建一个迭代器,该迭代器在 predicate 为 true 时从 iterable 中删除元素,然后返回每个元素。 大致相当于

def dropwhile(predicate, iterable):
    # dropwhile(lambda x: x<5, [1,4,6,3,8]) → 6 3 8

    iterator = iter(iterable)
    for x in iterator:
        if not predicate(x):
            yield x
            break

    for x in iterator:
        yield x

请注意,在谓词第一次变为 false 之前,这不会产生*任何*输出,因此此迭代工具的启动时间可能很长。

itertools.filterfalse(predicate, iterable)

创建一个迭代器,它从 *iterable* 中过滤元素,只返回那些 *predicate* 返回假值的元素。如果 *predicate* 是 None,则返回假值项。大致相当于

def filterfalse(predicate, iterable):
    # filterfalse(lambda x: x<5, [1,4,6,3,8]) → 6 8
    if predicate is None:
        predicate = bool
    for x in iterable:
        if not predicate(x):
            yield x
itertools.groupby(iterable, key=None)

创建一个迭代器,它从 *iterable* 中返回连续的键和组。*key* 是一个计算每个元素键值的函数。如果没有指定或为 None,则 *key* 默认为标识函数,并按原样返回元素。通常,iterable 需要已经按照相同的键函数排序。

groupby() 的操作类似于 Unix 中的 uniq 过滤器。每次键函数的值发生变化时,它都会生成一个中断或新组(这就是为什么通常需要使用相同的键函数对数据进行排序)。这种行为不同于 SQL 的 GROUP BY,后者会聚合公共元素,而不管它们的输入顺序如何。

返回的组本身是一个迭代器,它与 groupby() 共享底层迭代器。因为源是共享的,所以当 groupby() 对象前进时,前一个组就不再可见。因此,如果以后需要这些数据,应该将其存储为列表

groups = []
uniquekeys = []
data = sorted(data, key=keyfunc)
for k, g in groupby(data, keyfunc):
    groups.append(list(g))      # Store group iterator as a list
    uniquekeys.append(k)

groupby() 大致相当于

def groupby(iterable, key=None):
    # [k for k, g in groupby('AAAABBBCCDAABBB')] → A B C D A B
    # [list(g) for k, g in groupby('AAAABBBCCD')] → AAAA BBB CC D

    keyfunc = (lambda x: x) if key is None else key
    iterator = iter(iterable)
    exhausted = False

    def _grouper(target_key):
        nonlocal curr_value, curr_key, exhausted
        yield curr_value
        for curr_value in iterator:
            curr_key = keyfunc(curr_value)
            if curr_key != target_key:
                return
            yield curr_value
        exhausted = True

    try:
        curr_value = next(iterator)
    except StopIteration:
        return
    curr_key = keyfunc(curr_value)

    while not exhausted:
        target_key = curr_key
        curr_group = _grouper(target_key)
        yield curr_key, curr_group
        if curr_key == target_key:
            for _ in curr_group:
                pass
itertools.islice(iterable, stop)
itertools.islice(iterable, start, stop[, step])

创建一个迭代器,它从可迭代对象中返回选定的元素。工作原理类似于序列切片,但不支持 *start*、*stop* 或 *step* 的负值。

如果 *start* 为零或 None,则迭代从零开始。否则,将跳过可迭代对象中的元素,直到达到 *start*。

如果 *stop* 为 None,则迭代将继续进行,直到迭代器耗尽为止(如果有的话)。否则,它将在指定位置停止。

如果 *step* 为 None,则步长默认为 1。元素将连续返回,除非 *step* 设置为大于 1,这将导致跳过项目。

大致相当于

def islice(iterable, *args):
    # islice('ABCDEFG', 2) → A B
    # islice('ABCDEFG', 2, 4) → C D
    # islice('ABCDEFG', 2, None) → C D E F G
    # islice('ABCDEFG', 0, None, 2) → A C E G

    s = slice(*args)
    start = 0 if s.start is None else s.start
    stop = s.stop
    step = 1 if s.step is None else s.step
    if start < 0 or (stop is not None and stop < 0) or step <= 0:
        raise ValueError

    indices = count() if stop is None else range(max(start, stop))
    next_i = start
    for i, element in zip(indices, iterable):
        if i == next_i:
            yield element
            next_i += step
itertools.pairwise(iterable)

从输入的 *iterable* 中返回连续的重叠对。

输出迭代器中 2 元组的数量将比输入的数量少 1。如果输入迭代器的值少于两个,它将为空。

大致相当于

def pairwise(iterable):
    # pairwise('ABCDEFG') → AB BC CD DE EF FG
    iterator = iter(iterable)
    a = next(iterator, None)
    for b in iterator:
        yield a, b
        a = b

在 3.10 版本中添加。

itertools.permutations(iterable, r=None)

从 *iterable* 中返回连续的 *r* 长度 元素排列

如果未指定 *r* 或 *r* 为 None,则 *r* 默认为 *iterable* 的长度,并生成所有可能的全长排列。

输出是 product() 的子序列,其中已过滤掉具有重复元素的条目。输出的长度由 math.perm() 给出,它在 0 r n 时计算 n! / (n - r)!,在 r > n 时计算零。

排列元组按照输入 *iterable* 的顺序以字典顺序发出。如果输入 *iterable* 已排序,则输出元组将按排序顺序生成。

元素的唯一性取决于它们的位置,而不是它们的值。如果输入元素是唯一的,则在一个排列中将没有重复的值。

大致相当于

def permutations(iterable, r=None):
    # permutations('ABCD', 2) → AB AC AD BA BC BD CA CB CD DA DB DC
    # permutations(range(3)) → 012 021 102 120 201 210

    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r
    if r > n:
        return

    indices = list(range(n))
    cycles = list(range(n, n-r, -1))
    yield tuple(pool[i] for i in indices[:r])

    while n:
        for i in reversed(range(r)):
            cycles[i] -= 1
            if cycles[i] == 0:
                indices[i:] = indices[i+1:] + indices[i:i+1]
                cycles[i] = n - i
            else:
                j = cycles[i]
                indices[i], indices[-j] = indices[-j], indices[i]
                yield tuple(pool[i] for i in indices[:r])
                break
        else:
            return
itertools.product(*iterables, repeat=1)

输入迭代器的笛卡尔积。

大致相当于生成器表达式中的嵌套 for 循环。例如,product(A, B) 返回的结果与 ((x,y) for x in A for y in B) 相同。

嵌套循环像里程表一样循环,最右边的元素在每次迭代中前进。这种模式创建了一个字典顺序,因此如果输入的迭代器已排序,则产品元组将按排序顺序发出。

要计算一个迭代器与其自身的乘积,请使用可选的 *repeat* 关键字参数指定重复次数。例如,product(A, repeat=4)product(A, A, A, A) 含义相同。

此函数大致等效于以下代码,但实际实现不会在内存中构建中间结果

def product(*iterables, repeat=1):
    # product('ABCD', 'xy') → Ax Ay Bx By Cx Cy Dx Dy
    # product(range(2), repeat=3) → 000 001 010 011 100 101 110 111

    pools = [tuple(pool) for pool in iterables] * repeat

    result = [[]]
    for pool in pools:
        result = [x+[y] for x in result for y in pool]

    for prod in result:
        yield tuple(prod)

product() 运行之前,它会完全消耗输入迭代器,并在内存中保留值池以生成产品。因此,它只对有限的输入有用。

itertools.repeat(object[, times])

创建一个迭代器,它反复返回 *object*。无限期地运行,除非指定了 *times* 参数。

大致相当于

def repeat(object, times=None):
    # repeat(10, 3) → 10 10 10
    if times is None:
        while True:
            yield object
    else:
        for i in range(times):
            yield object

*repeat* 的常见用途是向 *map* 或 *zip* 提供恒定值的流

>>> list(map(pow, range(10), repeat(2)))
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
itertools.starmap(function, iterable)

创建一个迭代器,它使用从 *iterable* 获取的参数来计算 *function*。当参数参数已经被“预压缩”成元组时,使用它来代替 map()

map()starmap() 之间的差异类似于 function(a,b)function(*c) 之间的区别。大致相当于

def starmap(function, iterable):
    # starmap(pow, [(2,5), (3,2), (10,3)]) → 32 9 1000
    for args in iterable:
        yield function(*args)
itertools.takewhile(predicate, iterable)

创建一个迭代器,只要 *predicate* 为真,就从 *iterable* 中返回元素。大致相当于

def takewhile(predicate, iterable):
    # takewhile(lambda x: x<5, [1,4,6,3,8]) → 1 4
    for x in iterable:
        if not predicate(x):
            break
        yield x

注意,第一个不满足谓词条件的元素将从输入迭代器中消耗掉,并且无法访问它。如果应用程序希望在 *takewhile* 运行到耗尽后进一步使用输入迭代器,这可能是一个问题。要解决此问题,请考虑改用 more-iterools before_and_after()

itertools.tee(iterable, n=2)

从单个可迭代对象返回 *n* 个独立的迭代器。

大致相当于

def tee(iterable, n=2):
    iterator = iter(iterable)
    shared_link = [None, None]
    return tuple(_tee(iterator, shared_link) for _ in range(n))

def _tee(iterator, link):
    try:
        while True:
            if link[1] is None:
                link[0] = next(iterator)
                link[1] = [None, None]
            value, link = link
            yield value
    except StopIteration:
        return

创建 tee() 后,不应在其他任何地方使用原始的 *iterable*;否则,*iterable* 可能会在 tee 对象不知情的情况下被提前。

tee 迭代器不是线程安全的。即使原始的 *iterable* 是线程安全的,在同时使用由同一个 tee() 调用返回的迭代器时,也可能会引发 RuntimeError

此 itertool 可能需要大量的辅助存储空间(取决于需要存储多少临时数据)。通常,如果一个迭代器在另一个迭代器启动之前使用大部分或全部数据,则使用 list() 比使用 tee() 更快。

itertools.zip_longest(*iterables, fillvalue=None)

创建一个迭代器,用于聚合来自每个 *iterables* 的元素。

如果可迭代对象的长度不一致,则使用 *fillvalue* 填充缺失值。如果未指定,则 *fillvalue* 默认为 None

迭代将持续到最长的可迭代对象耗尽。

大致相当于

def zip_longest(*iterables, fillvalue=None):
    # zip_longest('ABCD', 'xy', fillvalue='-') → Ax By C- D-

    iterators = list(map(iter, iterables))
    num_active = len(iterators)
    if not num_active:
        return

    while True:
        values = []
        for i, iterator in enumerate(iterators):
            try:
                value = next(iterator)
            except StopIteration:
                num_active -= 1
                if not num_active:
                    return
                iterators[i] = repeat(fillvalue)
                value = fillvalue
            values.append(value)
        yield tuple(values)

如果其中一个可迭代对象可能是无限的,则应使用限制调用次数的内容(例如 islice()takewhile())包装 zip_longest() 函数。

Itertools 配方

本节展示了使用现有 itertools 作为构建块创建扩展工具集的配方。

itertools 配方的主要目的是教育意义。这些配方展示了思考单个工具的各种方式,例如,chain.from_iterable 与扁平化的概念有关。这些配方还提供了如何组合工具的想法,例如,starmap()repeat() 如何协同工作。这些配方还展示了将 itertools 与 operatorcollections 模块以及内置 itertools(如 map()filter()reversed()enumerate())一起使用的模式。

配方的第二个目的是充当孵化器。accumulate()compress()pairwise() itertools 最初都是作为配方出现的。目前,sliding_window()iter_index()sieve() 配方正在测试中,以查看它们是否具有价值。

几乎所有这些配方以及许多其他配方都可以从 Python 包索引中找到的 more-itertools 项目中安装

python -m pip install more-itertools

许多配方都提供与底层工具集相同的高性能。通过一次处理一个元素而不是一次将整个可迭代对象加载到内存中,可以保持卓越的内存性能。通过以 函数式风格 将工具链接在一起,代码量保持较小。通过优先选择“向量化”构建块而不是使用 for 循环和 生成器(它们会产生解释器开销),可以保持高速。

import collections
import contextlib
import functools
import math
import operator
import random

def take(n, iterable):
    "Return first n items of the iterable as a list."
    return list(islice(iterable, n))

def prepend(value, iterable):
    "Prepend a single value in front of an iterable."
    # prepend(1, [2, 3, 4]) → 1 2 3 4
    return chain([value], iterable)

def tabulate(function, start=0):
    "Return function(0), function(1), ..."
    return map(function, count(start))

def repeatfunc(func, times=None, *args):
    "Repeat calls to func with specified arguments."
    if times is None:
        return starmap(func, repeat(args))
    return starmap(func, repeat(args, times))

def flatten(list_of_lists):
    "Flatten one level of nesting."
    return chain.from_iterable(list_of_lists)

def ncycles(iterable, n):
    "Returns the sequence elements n times."
    return chain.from_iterable(repeat(tuple(iterable), n))

def tail(n, iterable):
    "Return an iterator over the last n items."
    # tail(3, 'ABCDEFG') → E F G
    return iter(collections.deque(iterable, maxlen=n))

def consume(iterator, n=None):
    "Advance the iterator n-steps ahead. If n is None, consume entirely."
    # Use functions that consume iterators at C speed.
    if n is None:
        collections.deque(iterator, maxlen=0)
    else:
        next(islice(iterator, n, n), None)

def nth(iterable, n, default=None):
    "Returns the nth item or a default value."
    return next(islice(iterable, n, None), default)

def quantify(iterable, predicate=bool):
    "Given a predicate that returns True or False, count the True results."
    return sum(map(predicate, iterable))

def first_true(iterable, default=False, predicate=None):
    "Returns the first true value or the *default* if there is no true value."
    # first_true([a,b,c], x) → a or b or c or x
    # first_true([a,b], x, f) → a if f(a) else b if f(b) else x
    return next(filter(predicate, iterable), default)

def all_equal(iterable, key=None):
    "Returns True if all the elements are equal to each other."
    # all_equal('4٤௪౪໔', key=int) → True
    return len(take(2, groupby(iterable, key))) <= 1

def unique_justseen(iterable, key=None):
    "Yield unique elements, preserving order. Remember only the element just seen."
    # unique_justseen('AAAABBBCCDAABBB') → A B C D A B
    # unique_justseen('ABBcCAD', str.casefold) → A B c A D
    if key is None:
        return map(operator.itemgetter(0), groupby(iterable))
    return map(next, map(operator.itemgetter(1), groupby(iterable, key)))

def unique_everseen(iterable, key=None):
    "Yield unique elements, preserving order. Remember all elements ever seen."
    # unique_everseen('AAAABBBCCDAABBB') → A B C D
    # unique_everseen('ABBcCAD', str.casefold) → A B c D
    seen = set()
    if key is None:
        for element in filterfalse(seen.__contains__, iterable):
            seen.add(element)
            yield element
    else:
        for element in iterable:
            k = key(element)
            if k not in seen:
                seen.add(k)
                yield element

def unique(iterable, key=None, reverse=False):
   "Yield unique elements in sorted order. Supports unhashable inputs."
   # unique([[1, 2], [3, 4], [1, 2]]) → [1, 2] [3, 4]
   return unique_justseen(sorted(iterable, key=key, reverse=reverse), key=key)

def sliding_window(iterable, n):
    "Collect data into overlapping fixed-length chunks or blocks."
    # sliding_window('ABCDEFG', 4) → ABCD BCDE CDEF DEFG
    iterator = iter(iterable)
    window = collections.deque(islice(iterator, n - 1), maxlen=n)
    for x in iterator:
        window.append(x)
        yield tuple(window)

def grouper(iterable, n, *, incomplete='fill', fillvalue=None):
    "Collect data into non-overlapping fixed-length chunks or blocks."
    # grouper('ABCDEFG', 3, fillvalue='x') → ABC DEF Gxx
    # grouper('ABCDEFG', 3, incomplete='strict') → ABC DEF ValueError
    # grouper('ABCDEFG', 3, incomplete='ignore') → ABC DEF
    iterators = [iter(iterable)] * n
    match incomplete:
        case 'fill':
            return zip_longest(*iterators, fillvalue=fillvalue)
        case 'strict':
            return zip(*iterators, strict=True)
        case 'ignore':
            return zip(*iterators)
        case _:
            raise ValueError('Expected fill, strict, or ignore')

def roundrobin(*iterables):
    "Visit input iterables in a cycle until each is exhausted."
    # roundrobin('ABC', 'D', 'EF') → A D E B F C
    # Algorithm credited to George Sakkis
    iterators = map(iter, iterables)
    for num_active in range(len(iterables), 0, -1):
        iterators = cycle(islice(iterators, num_active))
        yield from map(next, iterators)

def partition(predicate, iterable):
    """Partition entries into false entries and true entries.

    If *predicate* is slow, consider wrapping it with functools.lru_cache().
    """
    # partition(is_odd, range(10)) → 0 2 4 6 8   and  1 3 5 7 9
    t1, t2 = tee(iterable)
    return filterfalse(predicate, t1), filter(predicate, t2)

def subslices(seq):
    "Return all contiguous non-empty subslices of a sequence."
    # subslices('ABCD') → A AB ABC ABCD B BC BCD C CD D
    slices = starmap(slice, combinations(range(len(seq) + 1), 2))
    return map(operator.getitem, repeat(seq), slices)

def iter_index(iterable, value, start=0, stop=None):
    "Return indices where a value occurs in a sequence or iterable."
    # iter_index('AABCADEAF', 'A') → 0 1 4 7
    seq_index = getattr(iterable, 'index', None)
    if seq_index is None:
        iterator = islice(iterable, start, stop)
        for i, element in enumerate(iterator, start):
            if element is value or element == value:
                yield i
    else:
        stop = len(iterable) if stop is None else stop
        i = start
        with contextlib.suppress(ValueError):
            while True:
                yield (i := seq_index(value, i, stop))
                i += 1

def iter_except(func, exception, first=None):
    "Convert a call-until-exception interface to an iterator interface."
    # iter_except(d.popitem, KeyError) → non-blocking dictionary iterator
    with contextlib.suppress(exception):
        if first is not None:
            yield first()
        while True:
            yield func()

以下配方具有更强的数学风格

def powerset(iterable):
    "powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))

def sum_of_squares(iterable):
    "Add up the squares of the input values."
    # sum_of_squares([10, 20, 30]) → 1400
    return math.sumprod(*tee(iterable))

def reshape(matrix, cols):
    "Reshape a 2-D matrix to have a given number of columns."
    # reshape([(0, 1), (2, 3), (4, 5)], 3) →  (0, 1, 2), (3, 4, 5)
    return batched(chain.from_iterable(matrix), cols)

def transpose(matrix):
    "Swap the rows and columns of a 2-D matrix."
    # transpose([(1, 2, 3), (11, 22, 33)]) → (1, 11) (2, 22) (3, 33)
    return zip(*matrix, strict=True)

def matmul(m1, m2):
    "Multiply two matrices."
    # matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]) → (49, 80), (41, 60)
    n = len(m2[0])
    return batched(starmap(math.sumprod, product(m1, transpose(m2))), n)

def convolve(signal, kernel):
    """Discrete linear convolution of two iterables.
    Equivalent to polynomial multiplication.

    Convolutions are mathematically commutative; however, the inputs are
    evaluated differently.  The signal is consumed lazily and can be
    infinite. The kernel is fully consumed before the calculations begin.

    Article:  https://betterexplained.com/articles/intuitive-convolution/
    Video:    https://www.youtube.com/watch?v=KuXjwB4LzSA
    """
    # convolve([1, -1, -20], [1, -3]) → 1 -4 -17 60
    # convolve(data, [0.25, 0.25, 0.25, 0.25]) → Moving average (blur)
    # convolve(data, [1/2, 0, -1/2]) → 1st derivative estimate
    # convolve(data, [1, -2, 1]) → 2nd derivative estimate
    kernel = tuple(kernel)[::-1]
    n = len(kernel)
    padded_signal = chain(repeat(0, n-1), signal, repeat(0, n-1))
    windowed_signal = sliding_window(padded_signal, n)
    return map(math.sumprod, repeat(kernel), windowed_signal)

def polynomial_from_roots(roots):
    """Compute a polynomial's coefficients from its roots.

       (x - 5) (x + 4) (x - 3)  expands to:   x³ -4x² -17x + 60
    """
    # polynomial_from_roots([5, -4, 3]) → [1, -4, -17, 60]
    factors = zip(repeat(1), map(operator.neg, roots))
    return list(functools.reduce(convolve, factors, [1]))

def polynomial_eval(coefficients, x):
    """Evaluate a polynomial at a specific value.

    Computes with better numeric stability than Horner's method.
    """
    # Evaluate x³ -4x² -17x + 60 at x = 5
    # polynomial_eval([1, -4, -17, 60], x=5) → 0
    n = len(coefficients)
    if not n:
        return type(x)(0)
    powers = map(pow, repeat(x), reversed(range(n)))
    return math.sumprod(coefficients, powers)

def polynomial_derivative(coefficients):
    """Compute the first derivative of a polynomial.

       f(x)  =  x³ -4x² -17x + 60
       f'(x) = 3x² -8x  -17
    """
    # polynomial_derivative([1, -4, -17, 60]) → [3, -8, -17]
    n = len(coefficients)
    powers = reversed(range(1, n))
    return list(map(operator.mul, coefficients, powers))

def sieve(n):
    "Primes less than n."
    # sieve(30) → 2 3 5 7 11 13 17 19 23 29
    if n > 2:
        yield 2
    data = bytearray((0, 1)) * (n // 2)
    for p in iter_index(data, 1, start=3, stop=math.isqrt(n) + 1):
        data[p*p : n : p+p] = bytes(len(range(p*p, n, p+p)))
    yield from iter_index(data, 1, start=3)

def factor(n):
    "Prime factors of n."
    # factor(99) → 3 3 11
    # factor(1_000_000_000_000_007) → 47 59 360620266859
    # factor(1_000_000_000_000_403) → 1000000000000403
    for prime in sieve(math.isqrt(n) + 1):
        while not n % prime:
            yield prime
            n //= prime
            if n == 1:
                return
    if n > 1:
        yield n

def totient(n):
    "Count of natural numbers up to n that are coprime to n."
    # https://mathworld.wolfram.com/TotientFunction.html
    # totient(12) → 4 because len([1, 5, 7, 11]) == 4
    for prime in set(factor(n)):
        n -= n // prime
    return n