# adapted from: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/container.py
from collections import OrderedDict
from typing import overload, Iterator
from flint import Tensor
from .module import Module
class Sequential(Module):
    """
    A sequential container.

    Modules are registered in the order they are passed to the constructor
    and are applied in that same order by :meth:`forward`. Alternatively, a
    single ``OrderedDict`` mapping names to modules can be passed in, in
    which case the dict keys are used as the module names.
    """

    @overload
    def __init__(self, *args: Module) -> None:
        ...

    @overload
    def __init__(self, arg: 'OrderedDict[str, Module]') -> None:
        ...

    def __init__(self, *args):
        """
        Register the given modules on this container.

        Parameters
        ----------
        *args : Module or OrderedDict[str, Module]
            Either any number of modules, which are registered under the
            names ``"0"``, ``"1"``, ... following their positional order, or
            exactly one ``OrderedDict`` whose ``(key, module)`` pairs are
            registered under their own keys, in dict order.
        """
        super().__init__()
        # Only a *single* OrderedDict argument is treated as a name->module
        # mapping; every other call shape is handled as positional modules.
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                # Positional modules are named by their stringified index.
                self.add_module(str(idx), module)

    def __len__(self) -> int:
        """Return the number of entries in ``self._modules``."""
        return len(self._modules)

    def __iter__(self) -> Iterator[Module]:
        """Iterate over the values of ``self._modules`` in their stored order."""
        return iter(self._modules.values())

    def forward(self, input: Tensor) -> Tensor:
        """
        Pass ``input`` through every contained module in turn.

        Each module is called with the output of the previous one; the
        result of the last call is returned. With no modules registered,
        ``input`` is returned unchanged.

        Parameters
        ----------
        input : Tensor
            Value fed to the first module.

        Returns
        -------
        Tensor
            Output of the last module in the chain.
        """
        # NOTE: the parameter name ``input`` shadows the builtin, but it is
        # part of the public signature, so it is kept for callers that pass
        # it by keyword.
        for module in self:
            input = module(input)
        return input