itertools
--- 为高效循环创建迭代器的函数¶
此模块实现了一些 迭代器 构建块,其灵感来自于 APL、Haskell 和 SML 中的构想。每个构建块都经过重构,以使其适用于 Python。
该模块将一组快速、高效利用内存的工具标准化,这些工具本身或组合起来都很有用。它们共同构成了一个“迭代器代数”,使得在纯 Python 中能够简洁高效地构建专用工具。
例如,SML 提供了一个制表工具:tabulate(f)
,它可以生成序列 f(0), f(1), ...
。在 Python 中,可以通过组合 map()
和 count()
形成 map(f, count())
来实现相同的效果。
无限迭代器
迭代器 |
参数 |
结果 |
示例 |
---|---|---|---|
[start[, step]] |
start, start+step, start+2*step, … |
|
|
p |
p0, p1, … plast, p0, p1, … |
|
|
elem [,n] |
elem, elem, elem, … 无限重复或重复 n 次 |
|
在最短输入序列上终止的迭代器
迭代器 |
参数 |
结果 |
示例 |
---|---|---|---|
p [,func] |
p0, p0+p1, p0+p1+p2, … |
|
|
p, n |
(p0, p1, …, p_n-1), … |
|
|
p, q, … |
p0, p1, … plast, q0, q1, … |
|
|
可迭代对象 |
p0, p1, … plast, q0, q1, … |
|
|
data, selectors |
(d[0] if s[0]), (d[1] if s[1]), … |
|
|
predicate, seq |
seq[n], seq[n+1], 从 predicate 为假时开始 |
|
|
predicate, seq |
seq 中 predicate(elem) 为假的元素 |
|
|
iterable[, key] |
按 key(v) 值分组的子迭代器 |
|
|
seq, [start,] stop [, step] |
来自 seq[start:stop:step] 的元素 |
|
|
可迭代对象 |
(p[0], p[1]), (p[1], p[2]) |
|
|
func, seq |
func(*seq[0]), func(*seq[1]), … |
|
|
predicate, seq |
seq[0], seq[1], 直到 predicate 为假 |
|
|
it, n |
it1, it2, … itn 将一个迭代器拆分为 n 个 |
|
|
p, q, … |
(p[0], q[0]), (p[1], q[1]), … |
|
组合迭代器
迭代器 |
参数 |
结果 |
---|---|---|
p, q, … [repeat=1] |
笛卡尔积,相当于嵌套的 for 循环 |
|
p[, r] |
长度为 r 的元组,所有可能的排序,无重复元素 |
|
p, r |
长度为 r 的元组,按排序顺序,无重复元素 |
|
p, r |
长度为 r 的元组,按排序顺序,有重复元素 |
示例: |
结果 |
---|---|
|
|
|
|
|
|
|
|
Itertool 函数¶
以下所有函数都用于构建并返回迭代器。有些提供无限长度的数据流,因此它们只应被那些会截断数据流的函数或循环访问。
- itertools.accumulate(iterable[, function, *, initial=None])¶
创建一个迭代器,返回累加的和或来自其他二元函数的累加结果。
function 默认为加法。function 应接受两个参数:一个累加的总值和一个来自 iterable 的值。
如果提供了 initial 值,累加将从该值开始,并且输出将比输入的可迭代对象多一个元素。
大致相当于:
def accumulate(iterable, function=operator.add, *, initial=None): 'Return running totals' # accumulate([1,2,3,4,5]) → 1 3 6 10 15 # accumulate([1,2,3,4,5], initial=100) → 100 101 103 106 110 115 # accumulate([1,2,3,4,5], operator.mul) → 1 2 6 24 120 iterator = iter(iterable) total = initial if initial is None: try: total = next(iterator) except StopIteration: return yield total for element in iterator: total = function(total, element) yield total
要计算一个运行中的最小值,请将 function 设置为
min()
。对于运行中的最大值,请将 function 设置为max()
。或者对于运行中的乘积,请将 function 设置为operator.mul()
。要构建摊销表,可累加利息并应用付款。>>> data = [3, 4, 6, 2, 1, 9, 0, 7, 5, 8] >>> list(accumulate(data, max)) # running maximum [3, 4, 6, 6, 6, 9, 9, 9, 9, 9] >>> list(accumulate(data, operator.mul)) # running product [3, 12, 72, 144, 144, 1296, 0, 0, 0, 0] # Amortize a 5% loan of 1000 with 10 annual payments of 90 >>> update = lambda balance, payment: round(balance * 1.05) - payment >>> list(accumulate(repeat(90, 10), update, initial=1_000)) [1000, 960, 918, 874, 828, 779, 728, 674, 618, 559, 497]
另请参阅
functools.reduce()
,它是一个类似函数,只返回最终的累加值。在 3.2 版本加入。
在 3.3 版本发生变更: 添加了可选的 function 形参。
在 3.8 版本发生变更: 添加了可选的 initial 形参。
- itertools.batched(iterable, n, *, strict=False)¶
将来自 iterable 的数据分批成长度为 n 的元组。最后一批的长度可能小于 n。
如果 strict 为真,且最后一批的长度小于 n,将引发
ValueError
。遍历输入的可迭代对象,并将数据累积成大小至多为 n 的元组。输入是被惰性消耗的,仅消耗足以填充一批的数量。一旦批次满员或输入的可迭代对象耗尽,就会产生结果。
>>> flattened_data = ['roses', 'red', 'violets', 'blue', 'sugar', 'sweet'] >>> unflattened = list(batched(flattened_data, 2)) >>> unflattened [('roses', 'red'), ('violets', 'blue'), ('sugar', 'sweet')]
大致相当于:
def batched(iterable, n, *, strict=False): # batched('ABCDEFG', 2) → AB CD EF G if n < 1: raise ValueError('n must be at least one') iterator = iter(iterable) while batch := tuple(islice(iterator, n)): if strict and len(batch) != n: raise ValueError('batched(): incomplete batch') yield batch
3.12 新版功能.
在 3.13 版本发生变更: 添加了 strict 选项。
- itertools.chain(*iterables)¶
创建一个迭代器,它从第一个可迭代对象中返回元素,直到耗尽,然后继续到下一个可迭代对象,直到所有的可迭代对象都被耗尽。这将多个数据源组合成一个单一的迭代器。大致等同于:
def chain(*iterables): # chain('ABC', 'DEF') → A B C D E F for iterable in iterables: yield from iterable
- classmethod chain.from_iterable(iterable)¶
chain()
的备用构造函数。从一个惰性求值的单一可迭代对象参数中获取链式输入。大致等同于:def from_iterable(iterables): # chain.from_iterable(['ABC', 'DEF']) → A B C D E F for iterable in iterables: yield from iterable
- itertools.combinations(iterable, r)¶
返回输入 iterable 中元素的长度为 r 的子序列。
输出是
product()
的一个子序列,只保留那些是 iterable 的子序列的条目。输出的长度由math.comb()
给出,当0 ≤ r ≤ n
时计算n! / r! / (n - r)!
,当r > n
时为零。组合元组按输入 iterable 的顺序以字典序发出。如果输入 iterable 是排序的,则输出的元组将按排序顺序生成。
元素根据其位置而不是其值被视为唯一的。如果输入元素是唯一的,则每个组合中将不会有重复的值。
大致相当于:
def combinations(iterable, r): # combinations('ABCD', 2) → AB AC AD BC BD CD # combinations(range(4), 3) → 012 013 023 123 pool = tuple(iterable) n = len(pool) if r > n: return indices = list(range(r)) yield tuple(pool[i] for i in indices) while True: for i in reversed(range(r)): if indices[i] != i + n - r: break else: return indices[i] += 1 for j in range(i+1, r): indices[j] = indices[j-1] + 1 yield tuple(pool[i] for i in indices)
- itertools.combinations_with_replacement(iterable, r)¶
返回输入 iterable 中元素的长度为 r 的子序列,允许单个元素重复多次。
输出是
product()
的一个子序列,只保留那些是 iterable 的子序列(可能带有重复元素)的条目。当n > 0
时,返回的子序列数量为(n + r - 1)! / r! / (n - 1)!
。组合元组按输入 iterable 的顺序以字典序发出。如果输入 iterable 是排序的,则输出的元组将按排序顺序生成。
元素根据其位置而不是其值被视为唯一的。如果输入元素是唯一的,则生成的组合也将是唯一的。
大致相当于:
def combinations_with_replacement(iterable, r): # combinations_with_replacement('ABC', 2) → AA AB AC BB BC CC pool = tuple(iterable) n = len(pool) if not n and r: return indices = [0] * r yield tuple(pool[i] for i in indices) while True: for i in reversed(range(r)): if indices[i] != n - 1: break else: return indices[i:] = [indices[i] + 1] * (r - i) yield tuple(pool[i] for i in indices)
在 3.1 版本加入。
- itertools.compress(data, selectors)¶
创建一个迭代器,它返回 data 中那些在 selectors 中相应元素为真的元素。当 data 或 selectors 的任一可迭代对象被耗尽时停止。大致等同于:
def compress(data, selectors): # compress('ABCDEF', [1,0,1,0,1,1]) → A C E F return (datum for datum, selector in zip(data, selectors) if selector)
在 3.1 版本加入。
- itertools.count(start=0, step=1)¶
创建一个迭代器,返回从 start 开始的等间距值。可与
map()
一起生成连续的数据点,或与zip()
一起添加序列号。大致等同于:def count(start=0, step=1): # count(10) → 10 11 12 13 14 ... # count(2.5, 0.5) → 2.5 3.0 3.5 ... n = start while True: yield n n += step
当使用浮点数计数时,通过使用乘法代码,有时可以获得更好的精度,例如:
(start + step * i for i in count())
。在 3.1 版本发生变更: 增加了 step 参数并允许非整数参数。
- itertools.cycle(iterable)¶
创建一个迭代器,返回来自 iterable 的元素,并保存每个元素的副本。当可迭代对象被耗尽时,从保存的副本中返回元素。无限重复。大致等同于:
def cycle(iterable): # cycle('ABCD') → A B C D A B C D A B C D ... saved = [] for element in iterable: yield element saved.append(element) while saved: for element in saved: yield element
此迭代工具可能需要大量的辅助存储空间(取决于可迭代对象的长度)。
- itertools.dropwhile(predicate, iterable)¶
创建一个迭代器,只要 predicate 为真,就从 iterable 中丢弃元素,之后返回每个元素。大致等同于:
def dropwhile(predicate, iterable): # dropwhile(lambda x: x<5, [1,4,6,3,8]) → 6 3 8 iterator = iter(iterable) for x in iterator: if not predicate(x): yield x break for x in iterator: yield x
请注意,在断言首次变为假之前,它不会产生*任何*输出,因此这个迭代工具可能会有很长的启动时间。
- itertools.filterfalse(predicate, iterable)¶
创建一个迭代器,它从 iterable 中过滤元素,只返回那些 predicate 返回假值的元素。如果 predicate 为
None
,则返回值为假的项目。大致等同于:def filterfalse(predicate, iterable): # filterfalse(lambda x: x<5, [1,4,6,3,8]) → 6 8 if predicate is None: predicate = bool for x in iterable: if not predicate(x): yield x
- itertools.groupby(iterable, key=None)¶
创建一个迭代器,返回来自 iterable 的连续键和组。key 是一个计算每个元素的键值的函数。如果未指定或为
None
,key 默认为一个恒等函数并返回元素本身。通常,可迭代对象需要已经按相同的键函数排序。groupby()
的操作类似于 Unix 中的uniq
过滤器。每当键函数的值发生变化时,它就会生成一个中断或新组(这就是为什么通常需要使用相同的键函数对数据进行排序的原因)。这种行为不同于 SQL 的 GROUP BY,后者聚合公共元素,而不考虑它们的输入顺序。返回的组本身是一个迭代器,它与
groupby()
共享底层的可迭代对象。因为源是共享的,当groupby()
对象前进时,前一个组就不再可见。因此,如果以后需要这些数据,应该将其存储为一个列表。groups = [] uniquekeys = [] data = sorted(data, key=keyfunc) for k, g in groupby(data, keyfunc): groups.append(list(g)) # Store group iterator as a list uniquekeys.append(k)
groupby()
大致等同于:def groupby(iterable, key=None): # [k for k, g in groupby('AAAABBBCCDAABBB')] → A B C D A B # [list(g) for k, g in groupby('AAAABBBCCD')] → AAAA BBB CC D keyfunc = (lambda x: x) if key is None else key iterator = iter(iterable) exhausted = False def _grouper(target_key): nonlocal curr_value, curr_key, exhausted yield curr_value for curr_value in iterator: curr_key = keyfunc(curr_value) if curr_key != target_key: return yield curr_value exhausted = True try: curr_value = next(iterator) except StopIteration: return curr_key = keyfunc(curr_value) while not exhausted: target_key = curr_key curr_group = _grouper(target_key) yield curr_key, curr_group if curr_key == target_key: for _ in curr_group: pass
- itertools.islice(iterable, stop)¶
- itertools.islice(iterable, start, stop[, step])
创建一个迭代器,返回可迭代对象中的选定元素。工作方式类似于序列切片,但不支持 start、stop 或 step 的负值。
如果 start 为零或
None
,迭代从零开始。否则,跳过可迭代对象中的元素,直到达到 start。如果 stop 是
None
,迭代将持续进行直到输入耗尽(如果会耗尽的话)。 否则,它会在指定位置停止。如果 step 是
None
,则步长默认为 1。元素会连续地被返回,除非 step 设置为大于 1 的值,这会导致一些条目被跳过。大致相当于:
def islice(iterable, *args): # islice('ABCDEFG', 2) → A B # islice('ABCDEFG', 2, 4) → C D # islice('ABCDEFG', 2, None) → C D E F G # islice('ABCDEFG', 0, None, 2) → A C E G s = slice(*args) start = 0 if s.start is None else s.start stop = s.stop step = 1 if s.step is None else s.step if start < 0 or (stop is not None and stop < 0) or step <= 0: raise ValueError indices = count() if stop is None else range(max(start, stop)) next_i = start for i, element in zip(indices, iterable): if i == next_i: yield element next_i += step
如果输入是一个迭代器,那么完全消耗 islice 会使输入迭代器前进
max(start, stop)
步,而与 step 的值无关。
- itertools.pairwise(iterable)¶
返回从输入 iterable 中获取的连续重叠对。
输出迭代器中的二元组数量将比输入数量少一。如果输入的可迭代对象的值少于两个,它将为空。
大致相当于:
def pairwise(iterable): # pairwise('ABCDEFG') → AB BC CD DE EF FG iterator = iter(iterable) a = next(iterator, None) for b in iterator: yield a, b a = b
在 3.10 版本加入。
- itertools.permutations(iterable, r=None)¶
返回来自 iterable 的元素的连续 r 长度排列。
如果 r 未指定或为
None
,则 r 默认为 iterable 的长度,并生成所有可能的全长排列。输出是
product()
的一个子序列,其中已过滤掉有重复元素的条目。输出的长度由math.perm()
给出,当0 ≤ r ≤ n
时计算n! / (n - r)!
,当r > n
时为零。排列元组按输入 iterable 的顺序以字典序发出。如果输入 iterable 是排序的,则输出的元组将按排序顺序生成。
元素根据其位置而不是其值被视为唯一的。如果输入元素是唯一的,则排列中不会有重复的值。
大致相当于:
def permutations(iterable, r=None): # permutations('ABCD', 2) → AB AC AD BA BC BD CA CB CD DA DB DC # permutations(range(3)) → 012 021 102 120 201 210 pool = tuple(iterable) n = len(pool) r = n if r is None else r if r > n: return indices = list(range(n)) cycles = list(range(n, n-r, -1)) yield tuple(pool[i] for i in indices[:r]) while n: for i in reversed(range(r)): cycles[i] -= 1 if cycles[i] == 0: indices[i:] = indices[i+1:] + indices[i:i+1] cycles[i] = n - i else: j = cycles[i] indices[i], indices[-j] = indices[-j], indices[i] yield tuple(pool[i] for i in indices[:r]) break else: return
- itertools.product(*iterables, repeat=1)¶
输入可迭代对象的笛卡尔积。
大致相当于生成器表达式中的嵌套 for 循环。例如,
product(A, B)
返回的结果与((x,y) for x in A for y in B)
相同。嵌套的循环像里程表一样循环,最右边的元素在每次迭代时都会前进。这种模式创建了一个字典序的排序,因此如果输入的迭代器是排序的,那么乘积元组将以排序的顺序发出。
要计算一个可迭代对象与自身的乘积,请使用可选的 repeat 关键字参数指定重复次数。例如,
product(A, repeat=4)
与product(A, A, A, A)
的意思相同。此函数大致等同于以下代码,但实际实现不会在内存中构建中间结果:
def product(*iterables, repeat=1): # product('ABCD', 'xy') → Ax Ay Bx By Cx Cy Dx Dy # product(range(2), repeat=3) → 000 001 010 011 100 101 110 111 if repeat < 0: raise ValueError('repeat argument cannot be negative') pools = [tuple(pool) for pool in iterables] * repeat result = [[]] for pool in pools: result = [x+[y] for x in result for y in pool] for prod in result: yield tuple(prod)
在
product()
运行之前,它会完全消耗输入的迭代器,将值池保存在内存中以生成乘积。因此,它只适用于有限的输入。
- itertools.repeat(object[, times])¶
创建一个迭代器,它一次又一次地返回 object。除非指定了 times 参数,否则它会无限期运行。
大致相当于:
def repeat(object, times=None): # repeat(10, 3) → 10 10 10 if times is None: while True: yield object else: for i in range(times): yield object
repeat 的一个常见用途是向 map 或 zip 提供一个常量值流:
>>> list(map(pow, range(10), repeat(2))) [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
- itertools.starmap(function, iterable)¶
创建一个迭代器,使用从 iterable 获得的参数来计算 function。当参数已经“预先打包”成元组时,使用它来代替
map()
。map()
和starmap()
之间的区别与function(a,b)
和function(*c)
之间的区别相类似。大致等同于:def starmap(function, iterable): # starmap(pow, [(2,5), (3,2), (10,3)]) → 32 9 1000 for args in iterable: yield function(*args)
- itertools.takewhile(predicate, iterable)¶
创建一个迭代器,只要 predicate 为真,就从 iterable 中返回元素。大致等同于:
def takewhile(predicate, iterable): # takewhile(lambda x: x<5, [1,4,6,3,8]) → 1 4 for x in iterable: if not predicate(x): break yield x
注意,第一个不满足谓词条件的元素会从输入迭代器中消耗掉,并且无法访问它。如果应用程序希望在 takewhile 运行完毕后继续消耗输入迭代器,这可能会成为一个问题。为了解决这个问题,可以考虑使用 more-itertools before_and_after()。
- itertools.tee(iterable, n=2)¶
从单个可迭代对象返回 n 个独立的迭代器。
大致相当于:
def tee(iterable, n=2): if n < 0: raise ValueError if n == 0: return () iterator = _tee(iterable) result = [iterator] for _ in range(n - 1): result.append(_tee(iterator)) return tuple(result) class _tee: def __init__(self, iterable): it = iter(iterable) if isinstance(it, _tee): self.iterator = it.iterator self.link = it.link else: self.iterator = it self.link = [None, None] def __iter__(self): return self def __next__(self): link = self.link if link[1] is None: link[0] = next(self.iterator) link[1] = [None, None] value, self.link = link return value
当输入 iterable 已经是一个 tee 迭代器对象时,返回的元组的所有成员的构造方式就好像它们是由上游的
tee()
调用产生的一样。这种“扁平化步骤”允许嵌套的tee()
调用共享相同的底层数据链,并且只有一个更新步骤而不是一连串的调用。扁平化特性使得 tee 迭代器可以高效地进行窥视:
def lookahead(tee_iterator): "Return the next value without moving the input forward" [forked_iterator] = tee(tee_iterator, 1) return next(forked_iterator)
>>> iterator = iter('abcdef') >>> [iterator] = tee(iterator, 1) # Make the input peekable >>> next(iterator) # Move the iterator forward 'a' >>> lookahead(iterator) # Check next value 'b' >>> next(iterator) # Continue moving forward 'b'
tee
迭代器不是线程安全的。当同时使用由同一个tee()
调用返回的迭代器时,可能会引发RuntimeError
,即使原始的 iterable 是线程安全的。这个迭代工具可能需要大量的辅助存储空间(取决于需要存储多少临时数据)。一般来说,如果一个迭代器在另一个迭代器开始之前使用了大部分或全部数据,使用
list()
会比使用tee()
更快。
- itertools.zip_longest(*iterables, fillvalue=None)¶
创建一个迭代器,聚合来自每个 iterables 的元素。
如果可迭代对象的长度不均匀,缺失的值将用 fillvalue 填充。如果未指定,fillvalue 默认为
None
。迭代将持续到最长的可迭代对象被耗尽。
大致相当于:
def zip_longest(*iterables, fillvalue=None): # zip_longest('ABCD', 'xy', fillvalue='-') → Ax By C- D- iterators = list(map(iter, iterables)) num_active = len(iterators) if not num_active: return while True: values = [] for i, iterator in enumerate(iterators): try: value = next(iterator) except StopIteration: num_active -= 1 if not num_active: return iterators[i] = repeat(fillvalue) value = fillvalue values.append(value) yield tuple(values)
如果其中一个可迭代对象可能是无限的,那么
zip_longest()
函数应该被包装在限制调用次数的东西中(例如islice()
或takewhile()
)。
Itertools 配方¶
本节展示了使用现有的 itertools 作为构建块来创建扩展工具集的配方。
itertools 配方的主要目的是教育。这些配方展示了思考单个工具的各种方式——例如,chain.from_iterable
与扁平化的概念有关。这些配方也提供了关于工具如何组合的想法——例如,starmap()
和 repeat()
如何协同工作。这些配方还展示了将 itertools 与 operator
和 collections
模块以及内置的 itertools(如 map()
、filter()
、reversed()
和 enumerate()
)一起使用的模式。
配方的次要目的是作为孵化器。 accumulate()
, compress()
, and pairwise()
等 itertools 工具最初都是以配方的形式出现的。 当前,sliding_window()
, iter_index()
和 sieve()
等配方正在测试中,以检验它们是否能证明其价值。
基本上所有这些配方以及许多许多其他的配方都可以从 Python 包索引上的 more-itertools 项目安装。
python -m pip install more-itertools
许多配方提供了与底层工具集相同的高性能。通过一次处理一个元素而不是将整个可迭代对象一次性加载到内存中,保持了卓越的内存性能。通过以函数式风格链接工具,代码量得以保持较小。通过优先选择“矢量化”构建块,而不是使用会产生解释器开销的 for 循环和生成器,从而保持了高速。
from collections import Counter, deque
from contextlib import suppress
from functools import reduce
from math import comb, prod, sumprod, isqrt
from operator import itemgetter, getitem, mul, neg
def take(n, iterable):
"Return first n items of the iterable as a list."
return list(islice(iterable, n))
def prepend(value, iterable):
"Prepend a single value in front of an iterable."
# prepend(1, [2, 3, 4]) → 1 2 3 4
return chain([value], iterable)
def tabulate(function, start=0):
"Return function(0), function(1), ..."
return map(function, count(start))
def repeatfunc(function, times=None, *args):
"Repeat calls to a function with specified arguments."
if times is None:
return starmap(function, repeat(args))
return starmap(function, repeat(args, times))
def flatten(list_of_lists):
"Flatten one level of nesting."
return chain.from_iterable(list_of_lists)
def ncycles(iterable, n):
"Returns the sequence elements n times."
return chain.from_iterable(repeat(tuple(iterable), n))
def loops(n):
"Loop n times. Like range(n) but without creating integers."
# for _ in loops(100): ...
return repeat(None, n)
def tail(n, iterable):
"Return an iterator over the last n items."
# tail(3, 'ABCDEFG') → E F G
return iter(deque(iterable, maxlen=n))
def consume(iterator, n=None):
"Advance the iterator n-steps ahead. If n is None, consume entirely."
# Use functions that consume iterators at C speed.
if n is None:
deque(iterator, maxlen=0)
else:
next(islice(iterator, n, n), None)
def nth(iterable, n, default=None):
"Returns the nth item or a default value."
return next(islice(iterable, n, None), default)
def quantify(iterable, predicate=bool):
"Given a predicate that returns True or False, count the True results."
return sum(map(predicate, iterable))
def first_true(iterable, default=False, predicate=None):
"Returns the first true value or the *default* if there is no true value."
# first_true([a,b,c], x) → a or b or c or x
# first_true([a,b], x, f) → a if f(a) else b if f(b) else x
return next(filter(predicate, iterable), default)
def all_equal(iterable, key=None):
"Returns True if all the elements are equal to each other."
# all_equal('4٤௪౪໔', key=int) → True
return len(take(2, groupby(iterable, key))) <= 1
def unique_justseen(iterable, key=None):
"Yield unique elements, preserving order. Remember only the element just seen."
# unique_justseen('AAAABBBCCDAABBB') → A B C D A B
# unique_justseen('ABBcCAD', str.casefold) → A B c A D
if key is None:
return map(itemgetter(0), groupby(iterable))
return map(next, map(itemgetter(1), groupby(iterable, key)))
def unique_everseen(iterable, key=None):
"Yield unique elements, preserving order. Remember all elements ever seen."
# unique_everseen('AAAABBBCCDAABBB') → A B C D
# unique_everseen('ABBcCAD', str.casefold) → A B c D
seen = set()
if key is None:
for element in filterfalse(seen.__contains__, iterable):
seen.add(element)
yield element
else:
for element in iterable:
k = key(element)
if k not in seen:
seen.add(k)
yield element
def unique(iterable, key=None, reverse=False):
"Yield unique elements in sorted order. Supports unhashable inputs."
# unique([[1, 2], [3, 4], [1, 2]]) → [1, 2] [3, 4]
sequenced = sorted(iterable, key=key, reverse=reverse)
return unique_justseen(sequenced, key=key)
def sliding_window(iterable, n):
"Collect data into overlapping fixed-length chunks or blocks."
# sliding_window('ABCDEFG', 4) → ABCD BCDE CDEF DEFG
iterator = iter(iterable)
window = deque(islice(iterator, n - 1), maxlen=n)
for x in iterator:
window.append(x)
yield tuple(window)
def grouper(iterable, n, *, incomplete='fill', fillvalue=None):
"Collect data into non-overlapping fixed-length chunks or blocks."
# grouper('ABCDEFG', 3, fillvalue='x') → ABC DEF Gxx
# grouper('ABCDEFG', 3, incomplete='strict') → ABC DEF ValueError
# grouper('ABCDEFG', 3, incomplete='ignore') → ABC DEF
iterators = [iter(iterable)] * n
match incomplete:
case 'fill':
return zip_longest(*iterators, fillvalue=fillvalue)
case 'strict':
return zip(*iterators, strict=True)
case 'ignore':
return zip(*iterators)
case _:
raise ValueError('Expected fill, strict, or ignore')
def roundrobin(*iterables):
"Visit input iterables in a cycle until each is exhausted."
# roundrobin('ABC', 'D', 'EF') → A D E B F C
# Algorithm credited to George Sakkis
iterators = map(iter, iterables)
for num_active in range(len(iterables), 0, -1):
iterators = cycle(islice(iterators, num_active))
yield from map(next, iterators)
def subslices(seq):
"Return all contiguous non-empty subslices of a sequence."
# subslices('ABCD') → A AB ABC ABCD B BC BCD C CD D
slices = starmap(slice, combinations(range(len(seq) + 1), 2))
return map(getitem, repeat(seq), slices)
def iter_index(iterable, value, start=0, stop=None):
"Return indices where a value occurs in a sequence or iterable."
# iter_index('AABCADEAF', 'A') → 0 1 4 7
seq_index = getattr(iterable, 'index', None)
if seq_index is None:
iterator = islice(iterable, start, stop)
for i, element in enumerate(iterator, start):
if element is value or element == value:
yield i
else:
stop = len(iterable) if stop is None else stop
i = start
with suppress(ValueError):
while True:
yield (i := seq_index(value, i, stop))
i += 1
def iter_except(function, exception, first=None):
"Convert a call-until-exception interface to an iterator interface."
# iter_except(d.popitem, KeyError) → non-blocking dictionary iterator
with suppress(exception):
if first is not None:
yield first()
while True:
yield function()
以下配方更具数学风格:
def multinomial(*counts):
"Number of distinct arrangements of a multiset."
# Counter('abracadabra').values() → 5 2 2 1 1
# multinomial(5, 2, 2, 1, 1) → 83160
return prod(map(comb, accumulate(counts), counts))
def powerset(iterable):
"Subsequences of the iterable from shortest to longest."
# powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
s = list(iterable)
return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
def sum_of_squares(iterable):
"Add up the squares of the input values."
# sum_of_squares([10, 20, 30]) → 1400
return sumprod(*tee(iterable))
def reshape(matrix, columns):
"Reshape a 2-D matrix to have a given number of columns."
# reshape([(0, 1), (2, 3), (4, 5)], 3) → (0, 1, 2), (3, 4, 5)
return batched(chain.from_iterable(matrix), columns, strict=True)
def transpose(matrix):
"Swap the rows and columns of a 2-D matrix."
# transpose([(1, 2, 3), (11, 22, 33)]) → (1, 11) (2, 22) (3, 33)
return zip(*matrix, strict=True)
def matmul(m1, m2):
"Multiply two matrices."
# matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]) → (49, 80), (41, 60)
n = len(m2[0])
return batched(starmap(sumprod, product(m1, transpose(m2))), n)
def convolve(signal, kernel):
"""Discrete linear convolution of two iterables.
Equivalent to polynomial multiplication.
Convolutions are mathematically commutative; however, the inputs are
evaluated differently. The signal is consumed lazily and can be
infinite. The kernel is fully consumed before the calculations begin.
Article: https://betterexplained.com/articles/intuitive-convolution/
Video: https://www.youtube.com/watch?v=KuXjwB4LzSA
"""
# convolve([1, -1, -20], [1, -3]) → 1 -4 -17 60
# convolve(data, [0.25, 0.25, 0.25, 0.25]) → Moving average (blur)
# convolve(data, [1/2, 0, -1/2]) → 1st derivative estimate
# convolve(data, [1, -2, 1]) → 2nd derivative estimate
kernel = tuple(kernel)[::-1]
n = len(kernel)
padded_signal = chain(repeat(0, n-1), signal, repeat(0, n-1))
windowed_signal = sliding_window(padded_signal, n)
return map(sumprod, repeat(kernel), windowed_signal)
def polynomial_from_roots(roots):
"""Compute a polynomial's coefficients from its roots.
(x - 5) (x + 4) (x - 3) expands to: x³ -4x² -17x + 60
"""
# polynomial_from_roots([5, -4, 3]) → [1, -4, -17, 60]
factors = zip(repeat(1), map(neg, roots))
return list(reduce(convolve, factors, [1]))
def polynomial_eval(coefficients, x):
"""Evaluate a polynomial at a specific value.
Computes with better numeric stability than Horner's method.
"""
# Evaluate x³ -4x² -17x + 60 at x = 5
# polynomial_eval([1, -4, -17, 60], x=5) → 0
n = len(coefficients)
if not n:
return type(x)(0)
powers = map(pow, repeat(x), reversed(range(n)))
return sumprod(coefficients, powers)
def polynomial_derivative(coefficients):
"""Compute the first derivative of a polynomial.
f(x) = x³ -4x² -17x + 60
f'(x) = 3x² -8x -17
"""
# polynomial_derivative([1, -4, -17, 60]) → [3, -8, -17]
n = len(coefficients)
powers = reversed(range(1, n))
return list(map(mul, coefficients, powers))
def sieve(n):
"Primes less than n."
# sieve(30) → 2 3 5 7 11 13 17 19 23 29
if n > 2:
yield 2
data = bytearray((0, 1)) * (n // 2)
for p in iter_index(data, 1, start=3, stop=isqrt(n) + 1):
data[p*p : n : p+p] = bytes(len(range(p*p, n, p+p)))
yield from iter_index(data, 1, start=3)
def factor(n):
"Prime factors of n."
# factor(99) → 3 3 11
# factor(1_000_000_000_000_007) → 47 59 360620266859
# factor(1_000_000_000_000_403) → 1000000000000403
for prime in sieve(isqrt(n) + 1):
while not n % prime:
yield prime
n //= prime
if n == 1:
return
if n > 1:
yield n
def is_prime(n):
"Return True if n is prime."
# is_prime(1_000_000_000_000_403) → True
return n > 1 and next(factor(n)) == n
def totient(n):
"Count of natural numbers up to n that are coprime to n."
# https://mathworld.net.cn/TotientFunction.html
# totient(12) → 4 because len([1, 5, 7, 11]) == 4
for prime in set(factor(n)):
n -= n // prime
return n