# adopted from: https://github.com/teddykoker/tinyloader/blob/main/dataloader.py
import math
from typing import Optional, Any, Callable, List
from .dataset import Dataset
from . import _utils
[docs]class DataLoader:
    """
    DataLoader provides an iterable over the given dataset. It supports
    automatic mini-batching now.
    Parameters
    ----------
    dataset : Dataset
        Dataset from which to load the data.
    batch_size : int, optional, default=1
        How many samples per batch to load.
    collate_fn : callable, optional
        Merge a list of samples to form a mini-batch of Tensor(s).
    """
    def __init__(
        self,
        dataset: Dataset,
        batch_size: Optional[int] = 1,
        collate_fn: Optional[Callable] = None
    ):
        self.index = 0
        self.dataset = dataset
        self.batch_size = batch_size
        if collate_fn is None:
            collate_fn = _utils.default_collate
        self.collate_fn = collate_fn
    def __iter__(self):
        self.index = 0
        return self
    def __next__(self) -> Any:
        if self.index >= len(self.dataset):
            raise StopIteration
        batch_size = min(len(self.dataset) - self.index, self.batch_size)
        return self.collate_fn([self.get() for _ in range(batch_size)])
    def __len__(self) -> int:
        return math.ceil(len(self.dataset) / self.batch_size)
[docs]    def get(self):
        item = self.dataset[self.index]
        self.index += 1
        return item