# Source code for flint.utils.data.dataset

# adapted from: https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataset.py

import bisect
from typing import List, Iterable

class Dataset:
    """Base class for map-style datasets.

    Subclasses must implement ``__getitem__`` to return the sample at a given
    index. Two datasets can be joined with ``+``, which produces a
    :class:`ConcatDataset` holding both operands in order.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __add__(self, other: 'Dataset') -> 'ConcatDataset':
        return ConcatDataset([self, other])


class ConcatDataset(Dataset):
    """
    Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (iterable): Datasets to be concatenated. Any iterable is
            accepted (including one-shot iterators/generators); it is
            materialized into a list. Each element must support ``len()``
            and integer indexing.
    """

    datasets: List[Dataset]
    # cumulative_sizes[i] == total number of samples in datasets[0..i]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence) -> List[int]:
        """Return the running totals of ``len(e)`` for each ``e`` in ``sequence``.

        Example: elements of lengths ``[2, 0, 3]`` give ``[2, 2, 5]``.
        """
        totals, running = [], 0
        for element in sequence:
            running += len(element)
            totals.append(running)
        return totals

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        # Materialize before checking emptiness: ``len()`` is undefined for
        # generic iterables (e.g. generators), so the check must run on the list.
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self) -> int:
        """Return the total number of samples across all sub-datasets."""
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        """Return the sample at global position ``idx``.

        Negative indices count from the end, as with Python sequences.

        Raises:
            ValueError: if ``idx`` is negative and ``-idx > len(self)``.
            IndexError: if ``idx >= len(self)``.
        """
        length = len(self)
        if idx < 0:
            if -idx > length:
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = length + idx
        elif idx >= length:
            # IndexError (not ValueError) so the legacy __getitem__-based
            # iteration protocol (``for x in ds`` / ``list(ds)``) terminates.
            raise IndexError("index out of range")
        # First sub-dataset whose cumulative size exceeds idx; zero-length
        # sub-datasets share a cumulative value and are skipped by bisect_right.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]