from __future__ import annotations
from collections.abc import Iterable
from itertools import islice
from sys import maxsize
def batched[T](
iterable: Iterable[T],
n: int,
*,
strict: bool = False,
) -> Iterable[tuple[T, ...]]:
"""Split an iterable into batches of size n."""
if not 1 <= n <= maxsize:
msg = "Batch size n must be at least one and at most sys.maxsize."
raise ValueError(msg)
iterator = iter(iterable)
while batch := tuple(islice(iterator, n)):
if strict and len(batch) != n:
msg = "Incomplete batch for strict batching."
raise ValueError(msg)
yield batch
Working with iterables is a cornerstone of Python programming. While the built-in `itertools` module provides a powerful toolkit, sometimes you need a specific recipe for a common problem. This post collects a few handy and (mostly!) type-safe utility functions for working with iterators and iterables, complete with explanations and examples.
batched
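This function splits an iterable into non-overlapping batches of a given size. It’s a common requirement when processing large datasets or making batch API calls. A similar function was added to `itertools` in Python 3.12, but this implementation is useful for older Python versions or as a standalone utility. Note that, as written, it uses the generic-function syntax (`def batched[T]`) introduced in Python 3.12, so running it on an older version requires declaring `T` with `typing.TypeVar` instead.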
Example
unbox
Sometimes you have an iterable that is expected to contain exactly one item. This function ‘unboxes’ it, returning the single element. It raises a `ValueError` if the iterable is empty or contains more than one element, ensuring that your assumptions about the data are met.
Example
unzip
The `unzip` function is the reverse of the built-in `zip`. Given an iterable of tuples, it separates the elements into a tuple of iterables, where each iterable contains the elements from the corresponding position in the input tuples.
Example
flatten_iterable
Dealing with nested lists or other iterables is a common task. `flatten_iterable` recursively traverses a nested structure of iterables and returns a single, flat list of all the non-iterable elements. It’s careful to not flatten strings (or bytes), which are often treated as atomic values rather than iterables of characters.
from collections.abc import Generator
from typing import cast
type NestedIterable[T] = Iterable[NestedIterable[T] | T]
def flatten_iterable[T](iterable: NestedIterable[T]) -> list[T]:
"""Flatten an iterable of iterables."""
def _helper(lst: NestedIterable[T]) -> Generator[T]:
for item in lst:
if isinstance(item, Iterable) and not isinstance(item, (str, bytes)):
yield from _helper(cast(NestedIterable[T], item))
else:
yield cast(T, item)
return list(_helper(iterable))
Example
flatten_string_key_dict
This function flattens a nested dictionary into a single-level dictionary by concatenating the keys with a dot (`.`). This is particularly useful for handling configuration files or nested JSON objects where you want a simple key-value representation.
from collections.abc import Mapping
type NestedDict[T] = Mapping[str, NestedDict[T] | T]
def flatten_string_key_dict[T](dictionary: NestedDict[T]) -> dict[str, T]:
"""Flatten a nested string-keyed dictionary."""
def _generate_items(
d: NestedDict[T] | T, prefix: str = ""
) -> Generator[tuple[str, T]]:
if isinstance(d, dict):
for key, value in d.items():
yield from _generate_items(value, prefix + key + ".")
else:
yield prefix[:-1], cast(T, d)
return dict(_generate_items(dictionary))
Example
group_by_non_consecutive
The standard `itertools.groupby` requires the input iterable to be sorted by the grouping key. This utility simplifies the process by first sorting the entire iterable before grouping. This is less memory-efficient for very large iterables but is very convenient for smaller datasets where the input is not pre-sorted.
import itertools
from collections.abc import Callable, Iterator
from typing import Any, Protocol, Self
class _SupportsLessThan(Protocol):
def __lt__(self, other: Self) -> bool: ...
def group_by_non_consecutive[T, SLT: _SupportsLessThan](
iterable: Iterable[T],
*,
key: Callable[[T], SLT] | None = None,
reverse: bool = False,
) -> Iterator[tuple[SLT, Iterator[T]]]:
"""Sort and group an iterable by a key function.
Unlike itertools.groupby, this function does not require the iterable
to be pre-sorted. It first sorts the entire iterable and then applies
the grouping.
"""
if key is None:
def default_key(x: T) -> SLT:
return cast(SLT, x)
key = default_key
sorted_iterable = sorted(iterable, key=key, reverse=reverse) # type: ignore[arg-type]
return itertools.groupby(sorted_iterable, key=key) # type: ignore[arg-type]
Example
Show the code
# Group words by an explicit key: their first letter.
data = ["apple", "banana", "ant", "bear", "apricot"]
print(f"group_by_non_consecutive({data}, key=lambda x: x[0])")
grouped_words = group_by_non_consecutive(data, key=lambda x: x[0])
for initial, words in grouped_words:
    print(f" {initial}: {list(words)}")

# Group numbers with no key: each value acts as its own key.
data_nums = [1, 2, 1, 3, 2, 1]
print(f"group_by_non_consecutive({data_nums})")
grouped_nums = group_by_non_consecutive(data_nums)
for value, occurrences in grouped_nums:
    print(f" {value}: {list(occurrences)}")
group_by_non_consecutive(['apple', 'banana', 'ant', 'bear', 'apricot'], key=lambda x: x[0])
a: ['apple', 'ant', 'apricot']
b: ['banana', 'bear']
group_by_non_consecutive([1, 2, 1, 3, 2, 1])
1: [1, 1, 1]
2: [2, 2]
3: [3]
These recipes provide robust, type-safe solutions to common problems when working with iterables in Python. Feel free to use them in your own projects. You can download the complete script with all the utilities and examples below.
Download the whole code here.