Source code for flint.utils.data.dataloader

# adapted from: https://github.com/teddykoker/tinyloader/blob/main/dataloader.py

import math
from typing import Optional, Any, Callable, List
from .dataset import Dataset
from . import _utils

class DataLoader:
    """
    DataLoader provides an iterable over the given dataset. It supports
    automatic mini-batching.

    The loader is its own iterator: it keeps a single cursor (``index``)
    into the dataset, so nested or concurrent iterations over the same
    loader share that cursor. Calling ``iter()`` rewinds it to the start.

    Parameters
    ----------
    dataset : Dataset
        Dataset from which to load the data. It is accessed only through
        ``len(dataset)`` and ``dataset[i]`` with an integer ``i``.

    batch_size : int, optional, default=1
        How many samples per batch to load. ``None`` falls back to the
        default of 1. The last batch is smaller when the dataset size is
        not divisible by ``batch_size``.

    collate_fn : callable, optional
        Merge a list of samples to form a mini-batch of Tensor(s). When
        omitted, ``_utils.default_collate`` is used.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.
    """

    def __init__(
        self,
        dataset: "Dataset",
        batch_size: Optional[int] = 1,
        collate_fn: Optional[Callable] = None,
    ):
        # The signature allows None; interpret it as "use the default".
        if batch_size is None:
            batch_size = 1
        # A non-positive batch size would make __next__ yield empty batches
        # forever (the cursor never advances) and break __len__.
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )

        if collate_fn is None:
            collate_fn = _utils.default_collate

        self.index = 0
        self.dataset = dataset
        self.batch_size = batch_size
        self.collate_fn = collate_fn

    def __iter__(self):
        """Rewind the cursor to the first sample and return the loader."""
        self.index = 0
        return self

    def __next__(self) -> Any:
        """
        Return the next collated mini-batch.

        Raises
        ------
        StopIteration
            When every sample of the dataset has been consumed.
        """
        remaining = len(self.dataset) - self.index
        if remaining <= 0:
            raise StopIteration

        # The final batch may hold fewer than ``batch_size`` samples.
        n_samples = min(remaining, self.batch_size)
        return self.collate_fn([self.get() for _ in range(n_samples)])

    def __len__(self) -> int:
        """Return the number of batches, counting a trailing partial batch."""
        return math.ceil(len(self.dataset) / self.batch_size)

    def get(self):
        """Return the sample at the cursor and advance the cursor by one."""
        item = self.dataset[self.index]
        self.index += 1
        return item