# Source code for flint.nn.modules.container

# adapted from: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/container.py

from collections import OrderedDict
from typing import overload, Iterator

from flint import Tensor
from .module import Module

class Sequential(Module):
    """A sequential container.

    Modules will be added to it in the order they are passed in the
    constructor. Alternatively, a single ``OrderedDict`` of modules can be
    passed in, in which case each key is used as the name under which the
    corresponding module is registered.

    Calling ``forward`` feeds the input through every contained module in
    iteration order, handing each module's output to the next one.
    """

    @overload
    def __init__(self, *args: Module) -> None:
        ...

    @overload
    def __init__(self, arg: 'OrderedDict[str, Module]') -> None:
        ...

    def __init__(self, *args):
        """Register the given modules on this container.

        Args:
            *args: Either any number of modules, registered under the names
                ``"0"``, ``"1"``, ... in the order given, or exactly one
                ``OrderedDict`` mapping names to modules, registered under
                its own keys in its own order.
        """
        super().__init__()

        # A lone OrderedDict argument carries explicit names; anything else
        # is treated as a plain sequence of modules named by position.
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def __len__(self) -> int:
        """Return the number of entries in ``self._modules``."""
        # NOTE(review): ``_modules`` is not assigned in this class; it is
        # presumably maintained by ``Module.add_module`` -- confirm in Module.
        return len(self._modules)

    def __iter__(self) -> Iterator[Module]:
        """Iterate over the values of ``self._modules``."""
        return iter(self._modules.values())

    def forward(self, input: Tensor) -> Tensor:
        """Pass ``input`` through every contained module in turn.

        Args:
            input: Value handed to the first module. The name shadows the
                builtin ``input`` but is kept because it is part of the
                caller-visible signature.

        Returns:
            The output of the last module, or ``input`` unchanged when the
            container holds no modules.
        """
        for module in self:
            input = module(input)
        return input