Source code for nlpstack.common.iterutil
import itertools
import math
from collections import abc
from typing import Any, Callable, Generic, Iterable, Iterator, List, TypeVar
T = TypeVar("T")
[docs]class SizedIterator(Generic[T]):
"""
A wrapper for an iterator that knows its size.
Args:
iterator: The iterator.
size: The size of the iterator.
"""
def __init__(self, iterator: Iterator[T], size: int):
self.iterator = iterator
self.size = size
def __iter__(self) -> Iterator[T]:
return self.iterator
def __next__(self) -> T:
return next(self.iterator)
def __len__(self) -> int:
return self.size
[docs]def batched(iterable: Iterable[T], batch_size: int, drop_last: bool = False) -> Iterator[List[T]]:
"""
Batch an iterable into lists of the given size.
Args:
iterable: The iterable.
batch_size: The size of each batch.
drop_last: Whether to drop the last batch if it is smaller than the given size.
Returns:
An iterator over batches.
"""
def iterator() -> Iterator[List[T]]:
batch = []
for item in iterable:
batch.append(item)
if len(batch) == batch_size:
yield batch
batch = []
if batch and not drop_last:
yield batch
if isinstance(iterable, abc.Sized):
num_batches = math.ceil(len(iterable) / batch_size)
return SizedIterator(iterator(), num_batches)
return iterator()
[docs]def batched_iterator(iterable: Iterable[T], batch_size: int) -> Iterator[Iterator[T]]:
"""
Batch an iterable into iterators of the given size.
Args:
iterable: The iterable.
batch_size: The size of each batch.
Returns:
An iterator over batches.
"""
def iterator() -> Iterator[Iterator[T]]:
stop = False
batch_progress = 0
def iterator_wrapper() -> Iterator[T]:
nonlocal batch_progress
for item in iterable:
yield item
batch_progress += 1
iterator = iterator_wrapper()
def consume(n: int) -> Iterator[T]:
for _ in range(n):
try:
yield next(iterator)
except StopIteration:
nonlocal stop
stop = True
break
while not stop:
try:
batch_progress = 0
yield itertools.chain([next(iterator)], consume(batch_size - 1))
for _ in range(batch_size - batch_progress):
next(iterator)
except StopIteration:
break
if isinstance(iterable, abc.Sized):
num_batches = math.ceil(len(iterable) / batch_size)
return SizedIterator(iterator(), num_batches)
return iterator()
[docs]def iter_with_callback(
iterable: Iterable[T],
callback: Callable[[T], Any],
) -> Iterator[T]:
"""
Iterate over an iterable and call a callback for each item.
Args:
iterable: The iterable.
callback: The callback to call for each item.
Returns:
An iterator over the iterable.
"""
def iterator() -> Iterator[T]:
for item in iterable:
yield item
callback(item)
if isinstance(iterable, abc.Sized):
return SizedIterator(iterator(), len(iterable))
return iterator()
[docs]def wrap_iterator(wrapper: Callable[[Iterable[T]], Iterator[T]], iterable: Iterable[T]) -> Iterator:
"""
Wrap an iterator with a function.
Note:
This function assume that the wrapped iterator is of the same size as the input iterator.
Args:
wrapper: The function to wrap the iterator.
Returns:
An iterator wrapped with the function.
"""
def wrapped() -> Iterator[T]:
return wrapper(iterable)
if isinstance(iterable, abc.Sized):
return SizedIterator(wrapped(), len(iterable))
return wrapped()