# adopted from: https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataset.py
import bisect
from typing import List, Iterable
class Dataset:
    """
    Abstract base class for an indexable dataset.

    Subclasses are expected to override :meth:`__getitem__`; the
    implementation here always raises ``NotImplementedError``.
    """

    def __getitem__(self, index):
        """Return the sample stored at ``index`` (must be overridden)."""
        raise NotImplementedError

    def __add__(self, other: 'Dataset') -> 'ConcatDataset':
        """Support ``a + b`` by wrapping both operands in a ``ConcatDataset``."""
        combined = ConcatDataset([self, other])
        return combined
class ConcatDataset(Dataset):
    """
    Dataset as a concatenation of multiple datasets. This class is useful
    to assemble different existing datasets.

    Args:
        datasets (iterable): Datasets to be concatenated. Any iterable is
            accepted (including generators); it is copied into a list.
            Each dataset must support ``len()`` and integer indexing.

    Raises:
        AssertionError: If ``datasets`` yields no elements.
    """

    datasets: List[Dataset]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        """Return the running totals of ``len(item)`` over ``sequence``.

        For lengths ``[2, 3, 1]`` the result is ``[2, 5, 6]``.
        """
        totals, running = [], 0
        for item in sequence:
            running += len(item)
            totals.append(running)
        return totals

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        # Materialize before validating: a generic iterable (e.g. a
        # generator) has no ``len()`` and could only be consumed once.
        self.datasets = list(datasets)
        # Raised explicitly (instead of via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is kept for
        # backward compatibility.
        if not self.datasets:
            raise AssertionError('datasets should not be an empty iterable')
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        """Return the total number of samples across all datasets."""
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        """Return the sample at global position ``idx``.

        Negative indices count from the end.

        Raises:
            ValueError: If ``idx`` is negative and its absolute value
                exceeds the total length.
            IndexError: If ``idx`` is greater than or equal to the total
                length (raised by the lookup in ``self.datasets``).
        """
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # bisect_right finds the first dataset whose cumulative end is
        # strictly greater than idx, which also skips empty datasets.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            # Convert the global index to an offset inside that dataset.
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]