itertools
— 用于高效循环的迭代器创建函数¶
该模块实现了一些受 APL、Haskell 和 SML 构造启发的 迭代器 构建块。每个构建块都以适合 Python 的形式进行了重铸。
该模块标准化了一组核心的快速、内存高效的工具,这些工具本身或组合使用都很有用。它们共同形成了一个“迭代器代数”,使得可以用纯 Python 简洁高效地构建专门的工具。
例如,SML 提供了一个制表工具:tabulate(f)
,它产生一个序列 f(0), f(1), ...
。在 Python 中,可以通过组合 map()
和 count()
来形成 map(f, count())
来实现相同的效果。
无限迭代器
迭代器 |
参数 |
结果 |
示例 |
---|---|---|---|
[start[, step]] |
start, start+step, start+2*step, … |
|
|
p |
p0, p1, … plast, p0, p1, … |
|
|
elem [,n] |
elem, elem, elem, … 无限次或最多 n 次 |
|
在最短输入序列上终止的迭代器
迭代器 |
参数 |
结果 |
示例 |
---|---|---|---|
p [,func] |
p0, p0+p1, p0+p1+p2, … |
|
|
p, n |
(p0, p1, …, p_n-1), … |
|
|
p, q, … |
p0, p1, … plast, q0, q1, … |
|
|
iterable |
p0, p1, … plast, q0, q1, … |
|
|
data, selectors |
(d[0] if s[0]), (d[1] if s[1]), … |
|
|
predicate, seq |
seq[n], seq[n+1], 从谓词失败时开始 |
|
|
predicate, seq |
seq 中 predicate(elem) 失败的元素 |
|
|
iterable[, key] |
按 key(v) 的值分组的子迭代器 |
|
|
seq, [start,] stop [, step] |
来自 seq[start:stop:step] 的元素 |
|
|
iterable |
(p[0], p[1]), (p[1], p[2]) |
|
|
func, seq |
func(*seq[0]), func(*seq[1]), … |
|
|
predicate, seq |
seq[0], seq[1], 直到谓词失败 |
|
|
it, n |
it1, it2, … itn 将一个迭代器分成 n 个 |
||
p, q, … |
(p[0], q[0]), (p[1], q[1]), … |
|
组合迭代器
迭代器 |
参数 |
结果 |
---|---|---|
p, q, … [repeat=1] |
笛卡尔积,等效于嵌套的 for 循环 |
|
p[, r] |
r 长度的元组,所有可能的排序,没有重复元素 |
|
p, r |
r 长度的元组,按排序顺序,没有重复元素 |
|
p, r |
r 长度的元组,按排序顺序,带有重复元素 |
示例 |
结果 |
---|---|
|
|
|
|
|
|
|
|
Itertool 函数¶
以下函数都构造并返回迭代器。有些函数提供无限长度的流,因此它们只能由截断流的函数或循环访问。
- itertools.accumulate(iterable[, function, *, initial=None])¶
创建一个返回累积和或来自其他二元函数的累积结果的迭代器。
function 默认为加法。function 应该接受两个参数,一个累积的总数和一个来自 iterable 的值。
如果提供 initial 值,则累积将从该值开始,并且输出将比输入 iterable 多一个元素。
大致等效于
def accumulate(iterable, function=operator.add, *, initial=None): 'Return running totals' # accumulate([1,2,3,4,5]) → 1 3 6 10 15 # accumulate([1,2,3,4,5], initial=100) → 100 101 103 106 110 115 # accumulate([1,2,3,4,5], operator.mul) → 1 2 6 24 120 iterator = iter(iterable) total = initial if initial is None: try: total = next(iterator) except StopIteration: return yield total for element in iterator: total = function(total, element) yield total
要计算运行最小值,请将 function 设置为
min()
。对于运行最大值,请将 function 设置为max()
。或者对于运行乘积,请将 function 设置为operator.mul()
。要构建摊销表,请累积利息并应用付款>>> data = [3, 4, 6, 2, 1, 9, 0, 7, 5, 8] >>> list(accumulate(data, max)) # running maximum [3, 4, 6, 6, 6, 9, 9, 9, 9, 9] >>> list(accumulate(data, operator.mul)) # running product [3, 12, 72, 144, 144, 1296, 0, 0, 0, 0] # Amortize a 5% loan of 1000 with 10 annual payments of 90 >>> update = lambda balance, payment: round(balance * 1.05) - payment >>> list(accumulate(repeat(90, 10), update, initial=1_000)) [1000, 960, 918, 874, 828, 779, 728, 674, 618, 559, 497]
有关仅返回最终累积值的类似函数,请参阅
functools.reduce()
。在版本 3.2 中添加。
在版本 3.3 中更改: 添加了可选的 function 参数。
在版本 3.8 中更改: 添加了可选的 initial 参数。
- itertools.batched(iterable, n, *, strict=False)¶
将来自 iterable 的数据分批处理成长度为 n 的元组。最后一个批次可能比 n 短。
如果 strict 为 true,当最后一批的长度小于 n 时,会引发
ValueError
异常。循环遍历输入的可迭代对象,并将数据累积到大小为 n 的元组中。输入以惰性方式消耗,只消耗足够填满一批的数据。一旦批次填满或输入的可迭代对象耗尽,就会立即产生结果。
>>> flattened_data = ['roses', 'red', 'violets', 'blue', 'sugar', 'sweet'] >>> unflattened = list(batched(flattened_data, 2)) >>> unflattened [('roses', 'red'), ('violets', 'blue'), ('sugar', 'sweet')]
大致等效于
def batched(iterable, n, *, strict=False): # batched('ABCDEFG', 3) → ABC DEF G if n < 1: raise ValueError('n must be at least one') iterator = iter(iterable) while batch := tuple(islice(iterator, n)): if strict and len(batch) != n: raise ValueError('batched(): incomplete batch') yield batch
在 3.12 版本中新增。
在 3.13 版本中更改: 增加了 strict 选项。
- itertools.chain(*iterables)¶
创建一个迭代器,该迭代器从第一个可迭代对象返回元素,直到它耗尽,然后继续处理下一个可迭代对象,直到所有可迭代对象都耗尽。这会将多个数据源组合成一个迭代器。大致等效于
def chain(*iterables): # chain('ABC', 'DEF') → A B C D E F for iterable in iterables: yield from iterable
- classmethod chain.from_iterable(iterable)¶
chain()
的替代构造函数。从一个惰性求值的可迭代参数中获取链式输入。大致等效于def from_iterable(iterables): # chain.from_iterable(['ABC', 'DEF']) → A B C D E F for iterable in iterables: yield from iterable
- itertools.combinations(iterable, r)¶
返回来自输入 iterable 的元素长度为 r 的子序列。
输出是
product()
的子序列,仅保留 iterable 的子序列的条目。输出的长度由math.comb()
给出,它计算n! / r! / (n - r)!
,当0 ≤ r ≤ n
时,或当r > n
时为零。组合元组按照输入 iterable 的顺序以字典序发出。如果输入 iterable 已排序,则输出元组将按排序顺序生成。
元素根据它们的位置而不是它们的值被视为唯一。如果输入元素是唯一的,则每个组合中都不会有重复的值。
大致等效于
def combinations(iterable, r): # combinations('ABCD', 2) → AB AC AD BC BD CD # combinations(range(4), 3) → 012 013 023 123 pool = tuple(iterable) n = len(pool) if r > n: return indices = list(range(r)) yield tuple(pool[i] for i in indices) while True: for i in reversed(range(r)): if indices[i] != i + n - r: break else: return indices[i] += 1 for j in range(i+1, r): indices[j] = indices[j-1] + 1 yield tuple(pool[i] for i in indices)
- itertools.combinations_with_replacement(iterable, r)¶
返回来自输入 iterable 的元素长度为 r 的子序列,允许单个元素重复多次。
输出是
product()
的子序列,仅保留 iterable 的子序列(具有可能的重复元素)的条目。返回的子序列的数量为(n + r - 1)! / r! / (n - 1)!
,当n > 0
时。组合元组按照输入 iterable 的顺序以字典序发出。如果输入 iterable 已排序,则输出元组将按排序顺序生成。
元素根据它们的位置而不是它们的值被视为唯一。如果输入元素是唯一的,则生成的组合也将是唯一的。
大致等效于
def combinations_with_replacement(iterable, r): # combinations_with_replacement('ABC', 2) → AA AB AC BB BC CC pool = tuple(iterable) n = len(pool) if not n and r: return indices = [0] * r yield tuple(pool[i] for i in indices) while True: for i in reversed(range(r)): if indices[i] != n - 1: break else: return indices[i:] = [indices[i] + 1] * (r - i) yield tuple(pool[i] for i in indices)
在 3.1 版本中新增。
- itertools.compress(data, selectors)¶
创建一个迭代器,该迭代器从 data 返回元素,其中 selectors 中对应的元素为 true。当 data 或 selectors 可迭代对象耗尽时停止。大致等效于
def compress(data, selectors): # compress('ABCDEF', [1,0,1,0,1,1]) → A C E F return (datum for datum, selector in zip(data, selectors) if selector)
在 3.1 版本中新增。
- itertools.count(start=0, step=1)¶
创建一个迭代器,该迭代器返回以 start 开头的均匀间隔的值。可以与
map()
一起使用以生成连续的数据点,或者与zip()
一起使用以添加序列号。大致等效于def count(start=0, step=1): # count(10) → 10 11 12 13 14 ... # count(2.5, 0.5) → 2.5 3.0 3.5 ... n = start while True: yield n n += step
当使用浮点数计数时,有时可以通过替换乘法代码来获得更好的精度,例如:
(start + step * i for i in count())
。在 3.1 版本中更改: 增加了 step 参数并允许非整数参数。
- itertools.cycle(iterable)¶
创建一个迭代器,该迭代器从 iterable 返回元素,并保存每个元素的副本。当可迭代对象耗尽时,从保存的副本返回元素。无限重复。大致等效于
def cycle(iterable): # cycle('ABCD') → A B C D A B C D A B C D ... saved = [] for element in iterable: yield element saved.append(element) while saved: for element in saved: yield element
此迭代工具可能需要大量辅助存储空间(取决于可迭代对象的长度)。
- itertools.dropwhile(predicate, iterable)¶
创建一个迭代器,当 predicate 为 true 时,该迭代器从 iterable 中删除元素,然后返回每个元素。大致等效于
def dropwhile(predicate, iterable): # dropwhile(lambda x: x<5, [1,4,6,3,8]) → 6 3 8 iterator = iter(iterable) for x in iterator: if not predicate(x): yield x break for x in iterator: yield x
请注意,在谓词首次变为 false 之前,这不会产生 任何 输出,因此此迭代工具的启动时间可能很长。
- itertools.filterfalse(predicate, iterable)¶
创建一个迭代器,该迭代器从 iterable 中筛选元素,仅返回那些 predicate 返回 false 值的元素。如果 predicate 是
None
,则返回那些为 false 的项目。大致等效于def filterfalse(predicate, iterable): # filterfalse(lambda x: x<5, [1,4,6,3,8]) → 6 8 if predicate is None: predicate = bool for x in iterable: if not predicate(x): yield x
- itertools.groupby(iterable, key=None)¶
创建一个迭代器,该迭代器从 iterable 返回连续的键和组。key 是一个函数,用于计算每个元素的键值。如果未指定或为
None
,则 key 默认为恒等函数并返回未更改的元素。通常,可迭代对象需要已按相同的键函数排序。groupby()
的操作类似于 Unix 中的uniq
过滤器。它每次键函数的值更改时生成一个中断或一个新组(这就是为什么通常需要使用相同的键函数对数据进行排序的原因)。这种行为与 SQL 的 GROUP BY 不同,后者会聚合公共元素,而不管它们的输入顺序如何。返回的组本身是一个迭代器,它与
groupby()
共享底层可迭代对象。由于源是共享的,因此当groupby()
对象前进时,之前的组将不再可见。因此,如果以后需要该数据,则应将其存储为列表groups = [] uniquekeys = [] data = sorted(data, key=keyfunc) for k, g in groupby(data, keyfunc): groups.append(list(g)) # Store group iterator as a list uniquekeys.append(k)
groupby()
大致等效于def groupby(iterable, key=None): # [k for k, g in groupby('AAAABBBCCDAABBB')] → A B C D A B # [list(g) for k, g in groupby('AAAABBBCCD')] → AAAA BBB CC D keyfunc = (lambda x: x) if key is None else key iterator = iter(iterable) exhausted = False def _grouper(target_key): nonlocal curr_value, curr_key, exhausted yield curr_value for curr_value in iterator: curr_key = keyfunc(curr_value) if curr_key != target_key: return yield curr_value exhausted = True try: curr_value = next(iterator) except StopIteration: return curr_key = keyfunc(curr_value) while not exhausted: target_key = curr_key curr_group = _grouper(target_key) yield curr_key, curr_group if curr_key == target_key: for _ in curr_group: pass
- itertools.islice(iterable, stop)¶
- itertools.islice(iterable, start, stop[, step])
创建一个迭代器,该迭代器返回来自可迭代对象的选定元素。工作方式类似于序列切片,但不支持 start、stop 或 step 的负值。
如果 start 为零或
None
,则迭代从零开始。否则,将跳过可迭代对象中的元素,直到到达 start。如果 stop 为
None
,则迭代会一直进行到输入耗尽为止(如果会耗尽)。否则,它会在指定位置停止。如果 step 为
None
,则步长默认为 1。除非 step 设置为大于 1 的值,从而跳过一些项,否则会连续返回元素。大致等效于
def islice(iterable, *args): # islice('ABCDEFG', 2) → A B # islice('ABCDEFG', 2, 4) → C D # islice('ABCDEFG', 2, None) → C D E F G # islice('ABCDEFG', 0, None, 2) → A C E G s = slice(*args) start = 0 if s.start is None else s.start stop = s.stop step = 1 if s.step is None else s.step if start < 0 or (stop is not None and stop < 0) or step <= 0: raise ValueError indices = count() if stop is None else range(max(start, stop)) next_i = start for i, element in zip(indices, iterable): if i == next_i: yield element next_i += step
如果输入是一个迭代器,则完全消耗 islice 会使输入迭代器前进
max(start, stop)
步,而与 step 值无关。
- itertools.pairwise(iterable)¶
返回从输入 iterable 中提取的连续重叠的对。
输出迭代器中的 2 元组的数量将比输入项的数量少一个。如果输入可迭代对象的值少于两个,则输出为空。
大致等效于
def pairwise(iterable): # pairwise('ABCDEFG') → AB BC CD DE EF FG iterator = iter(iterable) a = next(iterator, None) for b in iterator: yield a, b a = b
在 3.10 版本中添加。
- itertools.permutations(iterable, r=None)¶
返回来自 iterable 的长度为 r 的连续元素排列。
如果未指定 r 或 r 为
None
,则 r 默认为 iterable 的长度,并生成所有可能的完整长度排列。输出是
product()
的子序列,其中已过滤掉具有重复元素的条目。输出的长度由math.perm()
给出,它计算n! / (n - r)!
(当0 ≤ r ≤ n
时)或零 (当r > n
时)。排列元组根据输入 iterable 的顺序以字典顺序发出。如果输入 iterable 已排序,则输出元组将按排序顺序生成。
元素被视为基于其位置而不是其值的唯一元素。如果输入元素是唯一的,则排列中不会有重复的值。
大致等效于
def permutations(iterable, r=None): # permutations('ABCD', 2) → AB AC AD BA BC BD CA CB CD DA DB DC # permutations(range(3)) → 012 021 102 120 201 210 pool = tuple(iterable) n = len(pool) r = n if r is None else r if r > n: return indices = list(range(n)) cycles = list(range(n, n-r, -1)) yield tuple(pool[i] for i in indices[:r]) while n: for i in reversed(range(r)): cycles[i] -= 1 if cycles[i] == 0: indices[i:] = indices[i+1:] + indices[i:i+1] cycles[i] = n - i else: j = cycles[i] indices[i], indices[-j] = indices[-j], indices[i] yield tuple(pool[i] for i in indices[:r]) break else: return
- itertools.product(*iterables, repeat=1)¶
输入可迭代对象的笛卡尔积。
大致等同于生成器表达式中的嵌套 for 循环。例如,
product(A, B)
的返回值与((x,y) for x in A for y in B)
相同。嵌套循环像里程表一样循环,最右边的元素在每次迭代时前进。此模式创建了字典顺序,因此如果输入的迭代器已排序,则产品元组将按排序顺序发出。
要计算可迭代对象与其自身的乘积,请使用可选的 repeat 关键字参数指定重复次数。例如,
product(A, repeat=4)
与product(A, A, A, A)
的含义相同。此函数大致等效于以下代码,但实际实现不会在内存中构建中间结果
def product(*iterables, repeat=1): # product('ABCD', 'xy') → Ax Ay Bx By Cx Cy Dx Dy # product(range(2), repeat=3) → 000 001 010 011 100 101 110 111 if repeat < 0: raise ValueError('repeat argument cannot be negative') pools = [tuple(pool) for pool in iterables] * repeat result = [[]] for pool in pools: result = [x+[y] for x in result for y in pool] for prod in result: yield tuple(prod)
在
product()
运行之前,它会完全消耗输入的可迭代对象,将值池保存在内存中以生成产品。因此,它仅适用于有限的输入。
- itertools.repeat(object[, times])¶
创建一个迭代器,该迭代器会一次又一次地返回 object。无限期运行,除非指定了 times 参数。
大致等效于
def repeat(object, times=None): # repeat(10, 3) → 10 10 10 if times is None: while True: yield object else: for i in range(times): yield object
repeat 的常见用法是为 map 或 zip 提供常量值流
>>> list(map(pow, range(10), repeat(2))) [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
- itertools.starmap(function, iterable)¶
创建一个迭代器,该迭代器使用从 iterable 获取的参数计算 function。当参数参数已“预先压缩”到元组中时,将使用此方法而不是
map()
。map()
和starmap()
之间的区别类似于function(a,b)
和function(*c)
之间的区别。大致等效于def starmap(function, iterable): # starmap(pow, [(2,5), (3,2), (10,3)]) → 32 9 1000 for args in iterable: yield function(*args)
- itertools.takewhile(predicate, iterable)¶
创建一个迭代器,只要 predicate 为真,该迭代器就会从 iterable 返回元素。大致等效于
def takewhile(predicate, iterable): # takewhile(lambda x: x<5, [1,4,6,3,8]) → 1 4 for x in iterable: if not predicate(x): break yield x
注意,第一个未通过谓词条件的元素将从输入迭代器中消耗掉,并且无法访问该元素。如果应用程序希望在 takewhile 运行到耗尽后进一步消耗输入迭代器,这可能会成为一个问题。要解决此问题,请考虑改用 more-itertools before_and_after()。
- itertools.tee(iterable, n=2)¶
从单个可迭代对象返回 n 个独立的迭代器。
大致等效于
def tee(iterable, n=2): if n < 0: raise ValueError if n == 0: return () iterator = _tee(iterable) result = [iterator] for _ in range(n - 1): result.append(_tee(iterator)) return tuple(result) class _tee: def __init__(self, iterable): it = iter(iterable) if isinstance(it, _tee): self.iterator = it.iterator self.link = it.link else: self.iterator = it self.link = [None, None] def __iter__(self): return self def __next__(self): link = self.link if link[1] is None: link[0] = next(self.iterator) link[1] = [None, None] value, self.link = link return value
当输入 iterable 已经是 tee 迭代器对象时,返回元组的所有成员的构造方式就好像它们是由上游的
tee()
调用产生的。此“展平步骤”允许嵌套的tee()
调用共享相同的基础数据链,并且具有单个更新步骤,而不是一系列调用。展平属性使 tee 迭代器能够有效地进行窥视
def lookahead(tee_iterator): "Return the next value without moving the input forward" [forked_iterator] = tee(tee_iterator, 1) return next(forked_iterator)
>>> iterator = iter('abcdef') >>> [iterator] = tee(iterator, 1) # Make the input peekable >>> next(iterator) # Move the iterator forward 'a' >>> lookahead(iterator) # Check next value 'b' >>> next(iterator) # Continue moving forward 'b'
tee
迭代器不是线程安全的。当同时使用由同一个tee()
调用返回的迭代器时,即使原始的 iterable 是线程安全的,也可能会引发RuntimeError
。此迭代工具可能需要大量的辅助存储(取决于需要存储多少临时数据)。通常,如果一个迭代器在另一个迭代器启动之前使用大部分或所有数据,则使用
list()
比tee()
快。
- itertools.zip_longest(*iterables, fillvalue=None)¶
创建一个迭代器,该迭代器会聚合来自每个 iterables 的元素。
如果可迭代对象的长度不均匀,则缺少的值将用 fillvalue 填充。如果未指定,fillvalue 默认为
None
。迭代将一直持续到最长的可迭代对象耗尽为止。
大致等效于
def zip_longest(*iterables, fillvalue=None): # zip_longest('ABCD', 'xy', fillvalue='-') → Ax By C- D- iterators = list(map(iter, iterables)) num_active = len(iterators) if not num_active: return while True: values = [] for i, iterator in enumerate(iterators): try: value = next(iterator) except StopIteration: num_active -= 1 if not num_active: return iterators[i] = repeat(fillvalue) value = fillvalue values.append(value) yield tuple(values)
如果其中一个可迭代对象可能是无限的,则应使用某些限制调用次数的方法(例如
islice()
或takewhile()
)包装zip_longest()
函数。
Itertools 示例¶
本节显示了使用现有 itertools 作为构建块创建扩展工具集的示例。
itertools 配方的主要目的是教育性的。这些配方展示了思考各个工具的多种方式——例如,chain.from_iterable
与扁平化的概念相关。这些配方还提供了关于如何组合这些工具的想法——例如,starmap()
和 repeat()
如何协同工作。这些配方还展示了如何将 itertools 与 operator
和 collections
模块,以及内置的 itertools(如 map()
、filter()
、reversed()
和 enumerate()
)一起使用。
这些配方的第二个目的是作为孵化器。accumulate()
、compress()
和 pairwise()
这些 itertools 最初都是以配方的形式出现的。目前,sliding_window()
、iter_index()
和 sieve()
配方正在测试中,以检验它们的价值。
基本上所有这些配方以及许多其他配方都可以从 Python 包索引上的 more-itertools 项目中安装。
python -m pip install more-itertools
许多配方都提供了与底层工具集相同的高性能。通过一次处理一个元素而不是一次将整个可迭代对象加载到内存中来保持卓越的内存性能。通过以函数式风格将工具链接在一起,代码量保持较小。通过优先使用“向量化”构建块而不是使用 for 循环和会产生解释器开销的生成器来保持高速度。
from collections import deque
from contextlib import suppress
from functools import reduce
from math import sumprod, isqrt
from operator import itemgetter, getitem, mul, neg
def take(n, iterable):
"Return first n items of the iterable as a list."
return list(islice(iterable, n))
def prepend(value, iterable):
"Prepend a single value in front of an iterable."
# prepend(1, [2, 3, 4]) → 1 2 3 4
return chain([value], iterable)
def tabulate(function, start=0):
"Return function(0), function(1), ..."
return map(function, count(start))
def repeatfunc(function, times=None, *args):
"Repeat calls to a function with specified arguments."
if times is None:
return starmap(function, repeat(args))
return starmap(function, repeat(args, times))
def flatten(list_of_lists):
"Flatten one level of nesting."
return chain.from_iterable(list_of_lists)
def ncycles(iterable, n):
"Returns the sequence elements n times."
return chain.from_iterable(repeat(tuple(iterable), n))
def loops(n):
"Loop n times. Like range(n) but without creating integers."
# for _ in loops(100): ...
return repeat(None, n)
def tail(n, iterable):
"Return an iterator over the last n items."
# tail(3, 'ABCDEFG') → E F G
return iter(deque(iterable, maxlen=n))
def consume(iterator, n=None):
"Advance the iterator n-steps ahead. If n is None, consume entirely."
# Use functions that consume iterators at C speed.
if n is None:
deque(iterator, maxlen=0)
else:
next(islice(iterator, n, n), None)
def nth(iterable, n, default=None):
"Returns the nth item or a default value."
return next(islice(iterable, n, None), default)
def quantify(iterable, predicate=bool):
"Given a predicate that returns True or False, count the True results."
return sum(map(predicate, iterable))
def first_true(iterable, default=False, predicate=None):
"Returns the first true value or the *default* if there is no true value."
# first_true([a,b,c], x) → a or b or c or x
# first_true([a,b], x, f) → a if f(a) else b if f(b) else x
return next(filter(predicate, iterable), default)
def all_equal(iterable, key=None):
"Returns True if all the elements are equal to each other."
# all_equal('4٤௪౪໔', key=int) → True
return len(take(2, groupby(iterable, key))) <= 1
def unique_justseen(iterable, key=None):
"Yield unique elements, preserving order. Remember only the element just seen."
# unique_justseen('AAAABBBCCDAABBB') → A B C D A B
# unique_justseen('ABBcCAD', str.casefold) → A B c A D
if key is None:
return map(itemgetter(0), groupby(iterable))
return map(next, map(itemgetter(1), groupby(iterable, key)))
def unique_everseen(iterable, key=None):
"Yield unique elements, preserving order. Remember all elements ever seen."
# unique_everseen('AAAABBBCCDAABBB') → A B C D
# unique_everseen('ABBcCAD', str.casefold) → A B c D
seen = set()
if key is None:
for element in filterfalse(seen.__contains__, iterable):
seen.add(element)
yield element
else:
for element in iterable:
k = key(element)
if k not in seen:
seen.add(k)
yield element
def unique(iterable, key=None, reverse=False):
"Yield unique elements in sorted order. Supports unhashable inputs."
# unique([[1, 2], [3, 4], [1, 2]]) → [1, 2] [3, 4]
sequenced = sorted(iterable, key=key, reverse=reverse)
return unique_justseen(sequenced, key=key)
def sliding_window(iterable, n):
"Collect data into overlapping fixed-length chunks or blocks."
# sliding_window('ABCDEFG', 4) → ABCD BCDE CDEF DEFG
iterator = iter(iterable)
window = deque(islice(iterator, n - 1), maxlen=n)
for x in iterator:
window.append(x)
yield tuple(window)
def grouper(iterable, n, *, incomplete='fill', fillvalue=None):
"Collect data into non-overlapping fixed-length chunks or blocks."
# grouper('ABCDEFG', 3, fillvalue='x') → ABC DEF Gxx
# grouper('ABCDEFG', 3, incomplete='strict') → ABC DEF ValueError
# grouper('ABCDEFG', 3, incomplete='ignore') → ABC DEF
iterators = [iter(iterable)] * n
match incomplete:
case 'fill':
return zip_longest(*iterators, fillvalue=fillvalue)
case 'strict':
return zip(*iterators, strict=True)
case 'ignore':
return zip(*iterators)
case _:
raise ValueError('Expected fill, strict, or ignore')
def roundrobin(*iterables):
"Visit input iterables in a cycle until each is exhausted."
# roundrobin('ABC', 'D', 'EF') → A D E B F C
# Algorithm credited to George Sakkis
iterators = map(iter, iterables)
for num_active in range(len(iterables), 0, -1):
iterators = cycle(islice(iterators, num_active))
yield from map(next, iterators)
def subslices(seq):
"Return all contiguous non-empty subslices of a sequence."
# subslices('ABCD') → A AB ABC ABCD B BC BCD C CD D
slices = starmap(slice, combinations(range(len(seq) + 1), 2))
return map(getitem, repeat(seq), slices)
def iter_index(iterable, value, start=0, stop=None):
"Return indices where a value occurs in a sequence or iterable."
# iter_index('AABCADEAF', 'A') → 0 1 4 7
seq_index = getattr(iterable, 'index', None)
if seq_index is None:
iterator = islice(iterable, start, stop)
for i, element in enumerate(iterator, start):
if element is value or element == value:
yield i
else:
stop = len(iterable) if stop is None else stop
i = start
with suppress(ValueError):
while True:
yield (i := seq_index(value, i, stop))
i += 1
def iter_except(function, exception, first=None):
"Convert a call-until-exception interface to an iterator interface."
# iter_except(d.popitem, KeyError) → non-blocking dictionary iterator
with suppress(exception):
if first is not None:
yield first()
while True:
yield function()
以下配方具有更强的数学风味
def powerset(iterable):
"Subsequences of the iterable from shortest to longest."
# powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
s = list(iterable)
return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
def sum_of_squares(iterable):
"Add up the squares of the input values."
# sum_of_squares([10, 20, 30]) → 1400
return sumprod(*tee(iterable))
def reshape(matrix, columns):
"Reshape a 2-D matrix to have a given number of columns."
# reshape([(0, 1), (2, 3), (4, 5)], 3) → (0, 1, 2), (3, 4, 5)
return batched(chain.from_iterable(matrix), columns, strict=True)
def transpose(matrix):
"Swap the rows and columns of a 2-D matrix."
# transpose([(1, 2, 3), (11, 22, 33)]) → (1, 11) (2, 22) (3, 33)
return zip(*matrix, strict=True)
def matmul(m1, m2):
"Multiply two matrices."
# matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]) → (49, 80), (41, 60)
n = len(m2[0])
return batched(starmap(sumprod, product(m1, transpose(m2))), n)
def convolve(signal, kernel):
"""Discrete linear convolution of two iterables.
Equivalent to polynomial multiplication.
Convolutions are mathematically commutative; however, the inputs are
evaluated differently. The signal is consumed lazily and can be
infinite. The kernel is fully consumed before the calculations begin.
Article: https://betterexplained.com/articles/intuitive-convolution/
Video: https://www.youtube.com/watch?v=KuXjwB4LzSA
"""
# convolve([1, -1, -20], [1, -3]) → 1 -4 -17 60
# convolve(data, [0.25, 0.25, 0.25, 0.25]) → Moving average (blur)
# convolve(data, [1/2, 0, -1/2]) → 1st derivative estimate
# convolve(data, [1, -2, 1]) → 2nd derivative estimate
kernel = tuple(kernel)[::-1]
n = len(kernel)
padded_signal = chain(repeat(0, n-1), signal, repeat(0, n-1))
windowed_signal = sliding_window(padded_signal, n)
return map(sumprod, repeat(kernel), windowed_signal)
def polynomial_from_roots(roots):
"""Compute a polynomial's coefficients from its roots.
(x - 5) (x + 4) (x - 3) expands to: x³ -4x² -17x + 60
"""
# polynomial_from_roots([5, -4, 3]) → [1, -4, -17, 60]
factors = zip(repeat(1), map(neg, roots))
return list(reduce(convolve, factors, [1]))
def polynomial_eval(coefficients, x):
"""Evaluate a polynomial at a specific value.
Computes with better numeric stability than Horner's method.
"""
# Evaluate x³ -4x² -17x + 60 at x = 5
# polynomial_eval([1, -4, -17, 60], x=5) → 0
n = len(coefficients)
if not n:
return type(x)(0)
powers = map(pow, repeat(x), reversed(range(n)))
return sumprod(coefficients, powers)
def polynomial_derivative(coefficients):
"""Compute the first derivative of a polynomial.
f(x) = x³ -4x² -17x + 60
f'(x) = 3x² -8x -17
"""
# polynomial_derivative([1, -4, -17, 60]) → [3, -8, -17]
n = len(coefficients)
powers = reversed(range(1, n))
return list(map(mul, coefficients, powers))
def sieve(n):
"Primes less than n."
# sieve(30) → 2 3 5 7 11 13 17 19 23 29
if n > 2:
yield 2
data = bytearray((0, 1)) * (n // 2)
for p in iter_index(data, 1, start=3, stop=isqrt(n) + 1):
data[p*p : n : p+p] = bytes(len(range(p*p, n, p+p)))
yield from iter_index(data, 1, start=3)
def factor(n):
"Prime factors of n."
# factor(99) → 3 3 11
# factor(1_000_000_000_000_007) → 47 59 360620266859
# factor(1_000_000_000_000_403) → 1000000000000403
for prime in sieve(isqrt(n) + 1):
while not n % prime:
yield prime
n //= prime
if n == 1:
return
if n > 1:
yield n
def is_prime(n):
"Return True if n is prime."
# is_prime(1_000_000_000_000_403) → True
return n > 1 and next(factor(n)) == n
def totient(n):
"Count of natural numbers up to n that are coprime to n."
# https://mathworld.wolfram.com/TotientFunction.html
# totient(12) → 4 because len([1, 5, 7, 11]) == 4
for prime in set(factor(n)):
n -= n // prime
return n