# Source code for nlpstack.data.dataloaders
import math
import random
from typing import Dict, Iterator, List, Sequence
from collatable.collator import Collator
from collatable.instance import Instance
from collatable.typing import DataArray
class BatchSampler:
    """
    Interface for objects that split a dataset into batches of indices.

    Concrete samplers are expected to override every method below; the
    implementations here only raise ``NotImplementedError``.
    """

    def get_batch_indices(self, dataset: Sequence[Instance]) -> Iterator[List[int]]:
        """
        Produce the batches of indices to draw from ``dataset``.

        Args:
            dataset: The dataset to sample from.

        Returns:
            An iterator that yields one list of dataset indices per batch.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def get_num_batches(self, dataset: Sequence[Instance]) -> int:
        """
        Report how many batches ``get_batch_indices`` produces for ``dataset``.

        Args:
            dataset: The dataset to sample from.

        Returns:
            The number of batches.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def get_batch_size(self) -> int:
        """
        Report the batch size this sampler is configured with.

        Returns:
            The batch size.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
class BasicBatchSampler(BatchSampler):
    """
    A basic batch sampler that splits dataset indices into consecutive batches.

    The indices ``0 .. len(dataset) - 1`` are optionally shuffled with the
    module-level :mod:`random` generator and then cut into chunks of
    ``batch_size``. The last chunk may be shorter unless ``drop_last`` is set.

    Args:
        batch_size: The number of indices per batch. Must be at least 1.
        shuffle: Whether to shuffle the indices before batching.
        drop_last: Whether to drop the last batch if it is smaller than the
            batch size.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """

    def __init__(
        self,
        batch_size: int,
        shuffle: bool = False,
        drop_last: bool = False,
    ) -> None:
        # A non-positive batch size used to make ``get_num_batches`` divide by
        # zero (or return a negative count) while ``get_batch_indices`` silently
        # emitted a single oversized batch; reject it at construction instead.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        self._batch_size = batch_size
        self._shuffle = shuffle
        self._drop_last = drop_last

    def get_batch_indices(self, dataset: Sequence[Instance]) -> Iterator[List[int]]:
        """
        Yield batches of indices into ``dataset``.

        Args:
            dataset: The dataset to sample from; only its length is used.

        Returns:
            An iterator over lists of indices. Every batch has ``batch_size``
            elements except possibly the last one, which is skipped when
            ``drop_last`` is set.
        """
        indices = list(range(len(dataset)))
        if self._shuffle:
            random.shuffle(indices)
        for start in range(0, len(indices), self._batch_size):
            batch = indices[start : start + self._batch_size]
            # Only the final slice can be short; drop it if requested.
            if self._drop_last and len(batch) < self._batch_size:
                break
            yield batch

    def get_num_batches(self, dataset: Sequence[Instance]) -> int:
        """
        Return how many batches ``get_batch_indices`` yields for ``dataset``.

        Args:
            dataset: The dataset to sample from; only its length is used.

        Returns:
            The number of full batches, plus one for a trailing partial batch
            unless ``drop_last`` is set.
        """
        # Exact integer arithmetic (rather than ``ceil(n / batch_size)``) so the
        # count can never drift from what ``get_batch_indices`` actually yields.
        num_full, remainder = divmod(len(dataset), self._batch_size)
        if remainder and not self._drop_last:
            return num_full + 1
        return num_full

    def get_batch_size(self) -> int:
        """Return the configured batch size."""
        return self._batch_size
class BatchIterator:
    """
    Iterator that yields collated batches drawn from a dataset.

    Index batches come from ``sampler.get_batch_indices(dataset)``. Each index
    batch is mapped to the corresponding instances, which are then passed
    through a ``Collator``.

    Args:
        dataset: The dataset to draw instances from.
        sampler: The batch sampler that decides which indices form each batch.
    """

    def __init__(
        self,
        dataset: Sequence[Instance],
        sampler: BatchSampler,
    ) -> None:
        self._dataset = dataset
        self._sampler = sampler
        self._collator = Collator()
        index_batches = self._sampler.get_batch_indices(self._dataset)
        self._batch_indices = iter(index_batches)

    def __len__(self) -> int:
        """Return the total number of batches reported by the sampler."""
        # NOTE: this is the sampler's total, not the number of batches left.
        return self._sampler.get_num_batches(self._dataset)

    def __next__(self) -> Dict[str, DataArray]:
        """Collate and return the next batch; ``StopIteration`` ends iteration."""
        indices = next(self._batch_indices)
        instances = [self._dataset[position] for position in indices]
        return self._collator(instances)

    def __iter__(self) -> Iterator[Dict[str, DataArray]]:
        """Return ``self``; this object is a single-pass iterator."""
        return self
class DataLoader:
    """
    Callable that turns a dataset into an iterator over collated batches.

    Every call wraps the given dataset, together with the configured sampler,
    in a fresh ``BatchIterator``.

    Args:
        sampler: The batch sampler used to group instances into batches.
    """

    def __init__(self, sampler: BatchSampler) -> None:
        self._sampler = sampler

    def __call__(self, dataset: Sequence[Instance]) -> BatchIterator:
        """
        Build a new batch iterator over ``dataset``.

        Args:
            dataset: The dataset to iterate over.

        Returns:
            A ``BatchIterator`` that yields collated batches from ``dataset``.
        """
        return BatchIterator(dataset=dataset, sampler=self._sampler)

    def get_batch_size(self) -> int:
        """
        Report the batch size of the underlying sampler.

        Returns:
            The value of ``sampler.get_batch_size()``.
        """
        sampler = self._sampler
        return sampler.get_batch_size()