Source code for flint.nn.modules.flatten
from flint import Tensor
from .module import Module
from .. import functional as F
class Flatten(Module):
    """
    Flatten the input. Does not affect the batch size.

    NOTE:
        If inputs are shaped ``(batch,)`` without a feature axis, then flattening
        adds an extra channel dimension and output shape is ``(batch, 1)``.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: Tensor) -> Tensor:
        """
        Flatten ``input`` with ``F.flatten`` and return the result.

        Args:
            input (Tensor): The tensor to flatten.

        Returns:
            Tensor: The tensor produced by ``F.flatten``. The same object is
            also stored on ``self.output``.
        """
        # NOTE(review): the result is kept on the module, presumably so other
        # parts of the framework (e.g. a backward pass or ``Module`` itself)
        # can read the most recent output -- confirm before removing.
        self.output = F.flatten(input)
        return self.output