itertools
— 用于高效循环创建迭代器的函数¶
此模块实现了一些 迭代器 构建块,其灵感来自 APL、Haskell 和 SML 中的结构。每个结构都已转换为适合 Python 的形式。
该模块标准化了一组核心的快速、内存高效的工具,这些工具可以单独使用或组合使用。它们共同构成了一个“迭代器代数”,使得在纯 Python 中简洁高效地构建专用工具成为可能。
例如,SML 提供了一个列表工具:tabulate(f)
,它生成一个序列 f(0), f(1), ...
。在 Python 中可以通过组合 map()
和 count()
来达到相同的效果,形成 map(f, count())
。
这些工具及其内置对应工具也可以与 operator
模块中的高速函数很好地配合使用。例如,可以将乘法运算符映射到两个向量上,以形成高效的点积:sum(starmap(operator.mul, zip(vec1, vec2, strict=True)))
。
无限迭代器
迭代器 |
参数 |
结果 |
示例 |
---|---|---|---|
[start[, step]] |
start, start+step, start+2*step, … |
|
|
p |
p0, p1, … plast, p0, p1, … |
|
|
elem [,n] |
elem, elem, elem, … 无限循环或最多 n 次 |
|
在最短输入序列处终止的迭代器
迭代器 |
参数 |
结果 |
示例 |
---|---|---|---|
p [,func] |
p0, p0+p1, p0+p1+p2, … |
|
|
p, n |
(p0, p1, …, p_n-1), … |
|
|
p, q, … |
p0, p1, … plast, q0, q1, … |
|
|
iterable |
p0, p1, … plast, q0, q1, … |
|
|
data, selectors |
(d[0] if s[0]), (d[1] if s[1]), … |
|
|
predicate, seq |
seq[n], seq[n+1], 从谓词失败时开始 |
|
|
predicate, seq |
predicate(elem) 为假的 seq 元素 |
|
|
iterable[, key] |
按 key(v) 的值分组的子迭代器 |
||
seq, [start,] stop [, step] |
来自 seq[start:stop:step] 的元素 |
|
|
iterable |
(p[0], p[1]), (p[1], p[2]) |
|
|
func, seq |
func(*seq[0]), func(*seq[1]), … |
|
|
predicate, seq |
seq[0], seq[1], 直到谓词失败 |
|
|
it, n |
it1, it2, … itn 将一个迭代器拆分为 n 个 |
||
p, q, … |
(p[0], q[0]), (p[1], q[1]), … |
|
组合迭代器
迭代器 |
参数 |
结果 |
---|---|---|
p, q, … [repeat=1] |
笛卡尔积,相当于嵌套的 for 循环 |
|
p[, r] |
长度为 r 的元组,所有可能的排序,没有重复元素 |
|
p, r |
长度为 r 的元组,按排序顺序排列,没有重复元素 |
|
p, r |
长度为 r 的元组,按排序顺序排列,允许重复元素 |
示例 |
结果 |
---|---|
|
|
|
|
|
|
|
|
迭代器函数¶
以下模块函数都用于构造和返回迭代器。有些函数提供无限长度的流,因此只能由截断流的函数或循环访问它们。
- itertools.accumulate(iterable[, function, *, initial=None])¶
创建一个迭代器,该迭代器返回累积总和或其他二元函数的累积结果。
function 默认为加法。 function 应该接受两个参数,一个累积总数和一个来自 iterable 的值。
如果提供了 initial 值,则累积将从该值开始,并且输出将比输入迭代器多一个元素。
大致相当于
def accumulate(iterable, function=operator.add, *, initial=None): 'Return running totals' # accumulate([1,2,3,4,5]) → 1 3 6 10 15 # accumulate([1,2,3,4,5], initial=100) → 100 101 103 106 110 115 # accumulate([1,2,3,4,5], operator.mul) → 1 2 6 24 120 iterator = iter(iterable) total = initial if initial is None: try: total = next(iterator) except StopIteration: return yield total for element in iterator: total = function(total, element) yield total
function 参数可以设置为
min()
用于运行最小值,max()
用于运行最大值,或operator.mul()
用于运行乘积。 摊销表 可以通过累积利息和应用付款来构建>>> data = [3, 4, 6, 2, 1, 9, 0, 7, 5, 8] >>> list(accumulate(data, max)) # running maximum [3, 4, 6, 6, 6, 9, 9, 9, 9, 9] >>> list(accumulate(data, operator.mul)) # running product [3, 12, 72, 144, 144, 1296, 0, 0, 0, 0] # Amortize a 5% loan of 1000 with 10 annual payments of 90 >>> update = lambda balance, payment: round(balance * 1.05) - payment >>> list(accumulate(repeat(90, 10), update, initial=1_000)) [1000, 960, 918, 874, 828, 779, 728, 674, 618, 559, 497]
有关仅返回最终累积值的类似函数,请参阅
functools.reduce()
。3.2 版新增。
3.3 版更改: 添加了可选的 function 参数。
3.8 版更改: 添加了可选的 initial 参数。
- itertools.batched(iterable, n)¶
将 iterable 中的数据批量处理成长度为 n 的元组。 最后一批可能短于 n。
循环遍历输入迭代器并将数据累积到大小不超过 n 的元组中。 输入是延迟消耗的,刚好足以填充一批。 一旦批处理已满或输入迭代器耗尽,就会立即生成结果
>>> flattened_data = ['roses', 'red', 'violets', 'blue', 'sugar', 'sweet'] >>> unflattened = list(batched(flattened_data, 2)) >>> unflattened [('roses', 'red'), ('violets', 'blue'), ('sugar', 'sweet')]
大致相当于
def batched(iterable, n): # batched('ABCDEFG', 3) → ABC DEF G if n < 1: raise ValueError('n must be at least one') iterator = iter(iterable) while batch := tuple(islice(iterator, n)): yield batch
3.12 版新增。
- itertools.chain(*iterables)¶
创建一个迭代器,该迭代器从第一个迭代器返回元素,直到它耗尽,然后继续到下一个迭代器,直到所有迭代器都耗尽。 用于将连续序列视为单个序列。 大致相当于
def chain(*iterables): # chain('ABC', 'DEF') → A B C D E F for iterable in iterables: yield from iterable
- classmethod chain.from_iterable(iterable)¶
chain()
的备用构造函数。 从延迟评估的单个可迭代参数获取链接输入。 大致相当于def from_iterable(iterables): # chain.from_iterable(['ABC', 'DEF']) → A B C D E F for iterable in iterables: yield from iterable
- itertools.combinations(iterable, r)¶
返回来自输入 iterable 的元素的 r 长度子序列。
输出是
product()
的子序列,仅保留作为 iterable 子序列的条目。 输出的长度由math.comb()
给出,它在0 ≤ r ≤ n
时计算n! / r! / (n - r)!
,在r > n
时计算零。组合元组根据输入 iterable 的顺序按字典顺序发出。 如果输入 iterable 已排序,则输出元组将按排序顺序生成。
元素的唯一性基于其位置,而不是其值。 如果输入元素是唯一的,则每个组合中将没有重复的值。
大致相当于
def combinations(iterable, r): # combinations('ABCD', 2) → AB AC AD BC BD CD # combinations(range(4), 3) → 012 013 023 123 pool = tuple(iterable) n = len(pool) if r > n: return indices = list(range(r)) yield tuple(pool[i] for i in indices) while True: for i in reversed(range(r)): if indices[i] != i + n - r: break else: return indices[i] += 1 for j in range(i+1, r): indices[j] = indices[j-1] + 1 yield tuple(pool[i] for i in indices)
- itertools.combinations_with_replacement(iterable, r)¶
返回来自输入 iterable 的元素的 r 长度子序列,允许单个元素重复多次。
输出是
product()
的子序列,仅保留作为 iterable 的子序列(可能包含重复元素)的条目。 返回的子序列数在n > 0
时为(n + r - 1)! / r! / (n - 1)!
。组合元组根据输入 iterable 的顺序按字典顺序发出。 如果输入 iterable 已排序,则输出元组将按排序顺序生成。
元素的唯一性基于其位置,而不是其值。 如果输入元素是唯一的,则生成的组合也将是唯一的。
大致相当于
def combinations_with_replacement(iterable, r): # combinations_with_replacement('ABC', 2) → AA AB AC BB BC CC pool = tuple(iterable) n = len(pool) if not n and r: return indices = [0] * r yield tuple(pool[i] for i in indices) while True: for i in reversed(range(r)): if indices[i] != n - 1: break else: return indices[i:] = [indices[i] + 1] * (r - i) yield tuple(pool[i] for i in indices)
3.1 版新增。
- itertools.compress(data, selectors)¶
创建一个迭代器,该迭代器从 data 中返回元素,其中 selectors 中的对应元素为 true。 当 data 或 selectors 迭代器耗尽时停止。 大致相当于
def compress(data, selectors): # compress('ABCDEF', [1,0,1,0,1,1]) → A C E F return (datum for datum, selector in zip(data, selectors) if selector)
3.1 版新增。
- itertools.count(start=0, step=1)¶
创建一个迭代器,该迭代器返回从 start 开始的均匀间隔的值。 可以与
map()
一起使用以生成连续数据点,或与zip()
一起使用以添加序列号。 大致相当于def count(start=0, step=1): # count(10) → 10 11 12 13 14 ... # count(2.5, 0.5) → 2.5 3.0 3.5 ... n = start while True: yield n n += step
使用浮点数计数时,有时可以通过替换乘法代码来获得更好的精度,例如:
(start + step * i for i in count())
。3.1 版更改: 添加了 step 参数并允许非整数参数。
- itertools.cycle(iterable)¶
创建一个迭代器,该迭代器从 iterable 返回元素并保存每个元素的副本。 当迭代器耗尽时,从保存的副本中返回元素。 无限重复。 大致相当于
def cycle(iterable): # cycle('ABCD') → A B C D A B C D A B C D ... saved = [] for element in iterable: yield element saved.append(element) while saved: for element in saved: yield element
此迭代工具可能需要大量的辅助存储空间(取决于迭代器的长度)。
- itertools.dropwhile(predicate, iterable)¶
创建一个迭代器,该迭代器在 predicate 为 true 时从 iterable 中删除元素,然后返回每个元素。 大致相当于
def dropwhile(predicate, iterable): # dropwhile(lambda x: x<5, [1,4,6,3,8]) → 6 3 8 iterator = iter(iterable) for x in iterator: if not predicate(x): yield x break for x in iterator: yield x
请注意,在谓词第一次变为 false 之前,这不会产生*任何*输出,因此此迭代工具的启动时间可能很长。
- itertools.filterfalse(predicate, iterable)¶
创建一个迭代器,它从 *iterable* 中过滤元素,只返回那些 *predicate* 返回假值的元素。如果 *predicate* 是
None
,则返回假值项。大致相当于def filterfalse(predicate, iterable): # filterfalse(lambda x: x<5, [1,4,6,3,8]) → 6 8 if predicate is None: predicate = bool for x in iterable: if not predicate(x): yield x
- itertools.groupby(iterable, key=None)¶
创建一个迭代器,它从 *iterable* 中返回连续的键和组。*key* 是一个计算每个元素键值的函数。如果没有指定或为
None
,则 *key* 默认为标识函数,并按原样返回元素。通常,iterable 需要已经按照相同的键函数排序。groupby()
的操作类似于 Unix 中的uniq
过滤器。每次键函数的值发生变化时,它都会生成一个中断或新组(这就是为什么通常需要使用相同的键函数对数据进行排序)。这种行为不同于 SQL 的 GROUP BY,后者会聚合公共元素,而不管它们的输入顺序如何。返回的组本身是一个迭代器,它与
groupby()
共享底层迭代器。因为源是共享的,所以当groupby()
对象前进时,前一个组就不再可见。因此,如果以后需要这些数据,应该将其存储为列表groups = [] uniquekeys = [] data = sorted(data, key=keyfunc) for k, g in groupby(data, keyfunc): groups.append(list(g)) # Store group iterator as a list uniquekeys.append(k)
groupby()
大致相当于def groupby(iterable, key=None): # [k for k, g in groupby('AAAABBBCCDAABBB')] → A B C D A B # [list(g) for k, g in groupby('AAAABBBCCD')] → AAAA BBB CC D keyfunc = (lambda x: x) if key is None else key iterator = iter(iterable) exhausted = False def _grouper(target_key): nonlocal curr_value, curr_key, exhausted yield curr_value for curr_value in iterator: curr_key = keyfunc(curr_value) if curr_key != target_key: return yield curr_value exhausted = True try: curr_value = next(iterator) except StopIteration: return curr_key = keyfunc(curr_value) while not exhausted: target_key = curr_key curr_group = _grouper(target_key) yield curr_key, curr_group if curr_key == target_key: for _ in curr_group: pass
- itertools.islice(iterable, stop)¶
- itertools.islice(iterable, start, stop[, step])
创建一个迭代器,它从可迭代对象中返回选定的元素。工作原理类似于序列切片,但不支持 *start*、*stop* 或 *step* 的负值。
如果 *start* 为零或
None
,则迭代从零开始。否则,将跳过可迭代对象中的元素,直到达到 *start*。如果 *stop* 为
None
,则迭代将继续进行,直到迭代器耗尽为止(如果有的话)。否则,它将在指定位置停止。如果 *step* 为
None
,则步长默认为 1。元素将连续返回,除非 *step* 设置为大于 1,这将导致跳过项目。大致相当于
def islice(iterable, *args): # islice('ABCDEFG', 2) → A B # islice('ABCDEFG', 2, 4) → C D # islice('ABCDEFG', 2, None) → C D E F G # islice('ABCDEFG', 0, None, 2) → A C E G s = slice(*args) start = 0 if s.start is None else s.start stop = s.stop step = 1 if s.step is None else s.step if start < 0 or (stop is not None and stop < 0) or step <= 0: raise ValueError indices = count() if stop is None else range(max(start, stop)) next_i = start for i, element in zip(indices, iterable): if i == next_i: yield element next_i += step
- itertools.pairwise(iterable)¶
从输入的 *iterable* 中返回连续的重叠对。
输出迭代器中 2 元组的数量将比输入的数量少 1。如果输入迭代器的值少于两个,它将为空。
大致相当于
def pairwise(iterable): # pairwise('ABCDEFG') → AB BC CD DE EF FG iterator = iter(iterable) a = next(iterator, None) for b in iterator: yield a, b a = b
在 3.10 版本中添加。
- itertools.permutations(iterable, r=None)¶
从 *iterable* 中返回连续的 *r* 长度 元素排列。
如果未指定 *r* 或 *r* 为
None
,则 *r* 默认为 *iterable* 的长度,并生成所有可能的全长排列。输出是
product()
的子序列,其中已过滤掉具有重复元素的条目。输出的长度由math.perm()
给出,它在0 ≤ r ≤ n
时计算n! / (n - r)!
,在r > n
时计算零。排列元组按照输入 *iterable* 的顺序以字典顺序发出。如果输入 *iterable* 已排序,则输出元组将按排序顺序生成。
元素的唯一性取决于它们的位置,而不是它们的值。如果输入元素是唯一的,则在一个排列中将没有重复的值。
大致相当于
def permutations(iterable, r=None): # permutations('ABCD', 2) → AB AC AD BA BC BD CA CB CD DA DB DC # permutations(range(3)) → 012 021 102 120 201 210 pool = tuple(iterable) n = len(pool) r = n if r is None else r if r > n: return indices = list(range(n)) cycles = list(range(n, n-r, -1)) yield tuple(pool[i] for i in indices[:r]) while n: for i in reversed(range(r)): cycles[i] -= 1 if cycles[i] == 0: indices[i:] = indices[i+1:] + indices[i:i+1] cycles[i] = n - i else: j = cycles[i] indices[i], indices[-j] = indices[-j], indices[i] yield tuple(pool[i] for i in indices[:r]) break else: return
- itertools.product(*iterables, repeat=1)¶
输入迭代器的笛卡尔积。
大致相当于生成器表达式中的嵌套 for 循环。例如,
product(A, B)
返回的结果与((x,y) for x in A for y in B)
相同。嵌套循环像里程表一样循环,最右边的元素在每次迭代中前进。这种模式创建了一个字典顺序,因此如果输入的迭代器已排序,则产品元组将按排序顺序发出。
要计算一个迭代器与其自身的乘积,请使用可选的 *repeat* 关键字参数指定重复次数。例如,
product(A, repeat=4)
与product(A, A, A, A)
含义相同。此函数大致等效于以下代码,但实际实现不会在内存中构建中间结果
def product(*iterables, repeat=1): # product('ABCD', 'xy') → Ax Ay Bx By Cx Cy Dx Dy # product(range(2), repeat=3) → 000 001 010 011 100 101 110 111 pools = [tuple(pool) for pool in iterables] * repeat result = [[]] for pool in pools: result = [x+[y] for x in result for y in pool] for prod in result: yield tuple(prod)
在
product()
运行之前,它会完全消耗输入迭代器,并在内存中保留值池以生成产品。因此,它只对有限的输入有用。
- itertools.repeat(object[, times])¶
创建一个迭代器,它反复返回 *object*。无限期地运行,除非指定了 *times* 参数。
大致相当于
def repeat(object, times=None): # repeat(10, 3) → 10 10 10 if times is None: while True: yield object else: for i in range(times): yield object
*repeat* 的常见用途是向 *map* 或 *zip* 提供恒定值的流
>>> list(map(pow, range(10), repeat(2))) [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
- itertools.starmap(function, iterable)¶
创建一个迭代器,它使用从 *iterable* 获取的参数来计算 *function*。当参数参数已经被“预压缩”成元组时,使用它来代替
map()
。map()
和starmap()
之间的差异类似于function(a,b)
和function(*c)
之间的区别。大致相当于def starmap(function, iterable): # starmap(pow, [(2,5), (3,2), (10,3)]) → 32 9 1000 for args in iterable: yield function(*args)
- itertools.takewhile(predicate, iterable)¶
创建一个迭代器,只要 *predicate* 为真,就从 *iterable* 中返回元素。大致相当于
def takewhile(predicate, iterable): # takewhile(lambda x: x<5, [1,4,6,3,8]) → 1 4 for x in iterable: if not predicate(x): break yield x
注意,第一个不满足谓词条件的元素将从输入迭代器中消耗掉,并且无法访问它。如果应用程序希望在 *takewhile* 运行到耗尽后进一步使用输入迭代器,这可能是一个问题。要解决此问题,请考虑改用 more-iterools before_and_after()。
- itertools.tee(iterable, n=2)¶
从单个可迭代对象返回 *n* 个独立的迭代器。
大致相当于
def tee(iterable, n=2): iterator = iter(iterable) shared_link = [None, None] return tuple(_tee(iterator, shared_link) for _ in range(n)) def _tee(iterator, link): try: while True: if link[1] is None: link[0] = next(iterator) link[1] = [None, None] value, link = link yield value except StopIteration: return
创建
tee()
后,不应在其他任何地方使用原始的 *iterable*;否则,*iterable* 可能会在 tee 对象不知情的情况下被提前。tee
迭代器不是线程安全的。即使原始的 *iterable* 是线程安全的,在同时使用由同一个tee()
调用返回的迭代器时,也可能会引发RuntimeError
。此 itertool 可能需要大量的辅助存储空间(取决于需要存储多少临时数据)。通常,如果一个迭代器在另一个迭代器启动之前使用大部分或全部数据,则使用
list()
比使用tee()
更快。
- itertools.zip_longest(*iterables, fillvalue=None)¶
创建一个迭代器,用于聚合来自每个 *iterables* 的元素。
如果可迭代对象的长度不一致,则使用 *fillvalue* 填充缺失值。如果未指定,则 *fillvalue* 默认为
None
。迭代将持续到最长的可迭代对象耗尽。
大致相当于
def zip_longest(*iterables, fillvalue=None): # zip_longest('ABCD', 'xy', fillvalue='-') → Ax By C- D- iterators = list(map(iter, iterables)) num_active = len(iterators) if not num_active: return while True: values = [] for i, iterator in enumerate(iterators): try: value = next(iterator) except StopIteration: num_active -= 1 if not num_active: return iterators[i] = repeat(fillvalue) value = fillvalue values.append(value) yield tuple(values)
如果其中一个可迭代对象可能是无限的,则应使用限制调用次数的内容(例如
islice()
或takewhile()
)包装zip_longest()
函数。
Itertools 配方¶
本节展示了使用现有 itertools 作为构建块创建扩展工具集的配方。
itertools 配方的主要目的是教育意义。这些配方展示了思考单个工具的各种方式,例如,chain.from_iterable
与扁平化的概念有关。这些配方还提供了如何组合工具的想法,例如,starmap()
和 repeat()
如何协同工作。这些配方还展示了将 itertools 与 operator
和 collections
模块以及内置 itertools(如 map()
、filter()
、reversed()
和 enumerate()
)一起使用的模式。
配方的第二个目的是充当孵化器。accumulate()
、compress()
和 pairwise()
itertools 最初都是作为配方出现的。目前,sliding_window()
、iter_index()
和 sieve()
配方正在测试中,以查看它们是否具有价值。
几乎所有这些配方以及许多其他配方都可以从 Python 包索引中找到的 more-itertools 项目中安装
python -m pip install more-itertools
许多配方都提供与底层工具集相同的高性能。通过一次处理一个元素而不是一次将整个可迭代对象加载到内存中,可以保持卓越的内存性能。通过以 函数式风格 将工具链接在一起,代码量保持较小。通过优先选择“向量化”构建块而不是使用 for 循环和 生成器(它们会产生解释器开销),可以保持高速。
import collections
import contextlib
import functools
import math
import operator
import random
def take(n, iterable):
"Return first n items of the iterable as a list."
return list(islice(iterable, n))
def prepend(value, iterable):
"Prepend a single value in front of an iterable."
# prepend(1, [2, 3, 4]) → 1 2 3 4
return chain([value], iterable)
def tabulate(function, start=0):
"Return function(0), function(1), ..."
return map(function, count(start))
def repeatfunc(func, times=None, *args):
"Repeat calls to func with specified arguments."
if times is None:
return starmap(func, repeat(args))
return starmap(func, repeat(args, times))
def flatten(list_of_lists):
"Flatten one level of nesting."
return chain.from_iterable(list_of_lists)
def ncycles(iterable, n):
"Returns the sequence elements n times."
return chain.from_iterable(repeat(tuple(iterable), n))
def tail(n, iterable):
"Return an iterator over the last n items."
# tail(3, 'ABCDEFG') → E F G
return iter(collections.deque(iterable, maxlen=n))
def consume(iterator, n=None):
"Advance the iterator n-steps ahead. If n is None, consume entirely."
# Use functions that consume iterators at C speed.
if n is None:
collections.deque(iterator, maxlen=0)
else:
next(islice(iterator, n, n), None)
def nth(iterable, n, default=None):
"Returns the nth item or a default value."
return next(islice(iterable, n, None), default)
def quantify(iterable, predicate=bool):
"Given a predicate that returns True or False, count the True results."
return sum(map(predicate, iterable))
def first_true(iterable, default=False, predicate=None):
"Returns the first true value or the *default* if there is no true value."
# first_true([a,b,c], x) → a or b or c or x
# first_true([a,b], x, f) → a if f(a) else b if f(b) else x
return next(filter(predicate, iterable), default)
def all_equal(iterable, key=None):
"Returns True if all the elements are equal to each other."
# all_equal('4٤௪౪໔', key=int) → True
return len(take(2, groupby(iterable, key))) <= 1
def unique_justseen(iterable, key=None):
"Yield unique elements, preserving order. Remember only the element just seen."
# unique_justseen('AAAABBBCCDAABBB') → A B C D A B
# unique_justseen('ABBcCAD', str.casefold) → A B c A D
if key is None:
return map(operator.itemgetter(0), groupby(iterable))
return map(next, map(operator.itemgetter(1), groupby(iterable, key)))
def unique_everseen(iterable, key=None):
"Yield unique elements, preserving order. Remember all elements ever seen."
# unique_everseen('AAAABBBCCDAABBB') → A B C D
# unique_everseen('ABBcCAD', str.casefold) → A B c D
seen = set()
if key is None:
for element in filterfalse(seen.__contains__, iterable):
seen.add(element)
yield element
else:
for element in iterable:
k = key(element)
if k not in seen:
seen.add(k)
yield element
def unique(iterable, key=None, reverse=False):
"Yield unique elements in sorted order. Supports unhashable inputs."
# unique([[1, 2], [3, 4], [1, 2]]) → [1, 2] [3, 4]
return unique_justseen(sorted(iterable, key=key, reverse=reverse), key=key)
def sliding_window(iterable, n):
"Collect data into overlapping fixed-length chunks or blocks."
# sliding_window('ABCDEFG', 4) → ABCD BCDE CDEF DEFG
iterator = iter(iterable)
window = collections.deque(islice(iterator, n - 1), maxlen=n)
for x in iterator:
window.append(x)
yield tuple(window)
def grouper(iterable, n, *, incomplete='fill', fillvalue=None):
"Collect data into non-overlapping fixed-length chunks or blocks."
# grouper('ABCDEFG', 3, fillvalue='x') → ABC DEF Gxx
# grouper('ABCDEFG', 3, incomplete='strict') → ABC DEF ValueError
# grouper('ABCDEFG', 3, incomplete='ignore') → ABC DEF
iterators = [iter(iterable)] * n
match incomplete:
case 'fill':
return zip_longest(*iterators, fillvalue=fillvalue)
case 'strict':
return zip(*iterators, strict=True)
case 'ignore':
return zip(*iterators)
case _:
raise ValueError('Expected fill, strict, or ignore')
def roundrobin(*iterables):
"Visit input iterables in a cycle until each is exhausted."
# roundrobin('ABC', 'D', 'EF') → A D E B F C
# Algorithm credited to George Sakkis
iterators = map(iter, iterables)
for num_active in range(len(iterables), 0, -1):
iterators = cycle(islice(iterators, num_active))
yield from map(next, iterators)
def partition(predicate, iterable):
"""Partition entries into false entries and true entries.
If *predicate* is slow, consider wrapping it with functools.lru_cache().
"""
# partition(is_odd, range(10)) → 0 2 4 6 8 and 1 3 5 7 9
t1, t2 = tee(iterable)
return filterfalse(predicate, t1), filter(predicate, t2)
def subslices(seq):
"Return all contiguous non-empty subslices of a sequence."
# subslices('ABCD') → A AB ABC ABCD B BC BCD C CD D
slices = starmap(slice, combinations(range(len(seq) + 1), 2))
return map(operator.getitem, repeat(seq), slices)
def iter_index(iterable, value, start=0, stop=None):
"Return indices where a value occurs in a sequence or iterable."
# iter_index('AABCADEAF', 'A') → 0 1 4 7
seq_index = getattr(iterable, 'index', None)
if seq_index is None:
iterator = islice(iterable, start, stop)
for i, element in enumerate(iterator, start):
if element is value or element == value:
yield i
else:
stop = len(iterable) if stop is None else stop
i = start
with contextlib.suppress(ValueError):
while True:
yield (i := seq_index(value, i, stop))
i += 1
def iter_except(func, exception, first=None):
"Convert a call-until-exception interface to an iterator interface."
# iter_except(d.popitem, KeyError) → non-blocking dictionary iterator
with contextlib.suppress(exception):
if first is not None:
yield first()
while True:
yield func()
以下配方具有更强的数学风格
def powerset(iterable):
"powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
s = list(iterable)
return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
def sum_of_squares(iterable):
"Add up the squares of the input values."
# sum_of_squares([10, 20, 30]) → 1400
return math.sumprod(*tee(iterable))
def reshape(matrix, cols):
"Reshape a 2-D matrix to have a given number of columns."
# reshape([(0, 1), (2, 3), (4, 5)], 3) → (0, 1, 2), (3, 4, 5)
return batched(chain.from_iterable(matrix), cols)
def transpose(matrix):
"Swap the rows and columns of a 2-D matrix."
# transpose([(1, 2, 3), (11, 22, 33)]) → (1, 11) (2, 22) (3, 33)
return zip(*matrix, strict=True)
def matmul(m1, m2):
"Multiply two matrices."
# matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]) → (49, 80), (41, 60)
n = len(m2[0])
return batched(starmap(math.sumprod, product(m1, transpose(m2))), n)
def convolve(signal, kernel):
"""Discrete linear convolution of two iterables.
Equivalent to polynomial multiplication.
Convolutions are mathematically commutative; however, the inputs are
evaluated differently. The signal is consumed lazily and can be
infinite. The kernel is fully consumed before the calculations begin.
Article: https://betterexplained.com/articles/intuitive-convolution/
Video: https://www.youtube.com/watch?v=KuXjwB4LzSA
"""
# convolve([1, -1, -20], [1, -3]) → 1 -4 -17 60
# convolve(data, [0.25, 0.25, 0.25, 0.25]) → Moving average (blur)
# convolve(data, [1/2, 0, -1/2]) → 1st derivative estimate
# convolve(data, [1, -2, 1]) → 2nd derivative estimate
kernel = tuple(kernel)[::-1]
n = len(kernel)
padded_signal = chain(repeat(0, n-1), signal, repeat(0, n-1))
windowed_signal = sliding_window(padded_signal, n)
return map(math.sumprod, repeat(kernel), windowed_signal)
def polynomial_from_roots(roots):
"""Compute a polynomial's coefficients from its roots.
(x - 5) (x + 4) (x - 3) expands to: x³ -4x² -17x + 60
"""
# polynomial_from_roots([5, -4, 3]) → [1, -4, -17, 60]
factors = zip(repeat(1), map(operator.neg, roots))
return list(functools.reduce(convolve, factors, [1]))
def polynomial_eval(coefficients, x):
"""Evaluate a polynomial at a specific value.
Computes with better numeric stability than Horner's method.
"""
# Evaluate x³ -4x² -17x + 60 at x = 5
# polynomial_eval([1, -4, -17, 60], x=5) → 0
n = len(coefficients)
if not n:
return type(x)(0)
powers = map(pow, repeat(x), reversed(range(n)))
return math.sumprod(coefficients, powers)
def polynomial_derivative(coefficients):
"""Compute the first derivative of a polynomial.
f(x) = x³ -4x² -17x + 60
f'(x) = 3x² -8x -17
"""
# polynomial_derivative([1, -4, -17, 60]) → [3, -8, -17]
n = len(coefficients)
powers = reversed(range(1, n))
return list(map(operator.mul, coefficients, powers))
def sieve(n):
"Primes less than n."
# sieve(30) → 2 3 5 7 11 13 17 19 23 29
if n > 2:
yield 2
data = bytearray((0, 1)) * (n // 2)
for p in iter_index(data, 1, start=3, stop=math.isqrt(n) + 1):
data[p*p : n : p+p] = bytes(len(range(p*p, n, p+p)))
yield from iter_index(data, 1, start=3)
def factor(n):
"Prime factors of n."
# factor(99) → 3 3 11
# factor(1_000_000_000_000_007) → 47 59 360620266859
# factor(1_000_000_000_000_403) → 1000000000000403
for prime in sieve(math.isqrt(n) + 1):
while not n % prime:
yield prime
n //= prime
if n == 1:
return
if n > 1:
yield n
def totient(n):
"Count of natural numbers up to n that are coprime to n."
# https://mathworld.wolfram.com/TotientFunction.html
# totient(12) → 4 because len([1, 5, 7, 11]) == 4
for prime in set(factor(n)):
n -= n // prime
return n