Source code for flint.nn.modules.flatten

from flint import Tensor
from .module import Module
from .. import functional as F

class Flatten(Module):
    """
    Flattens the input while leaving the batch size unchanged.

    NOTE:
        If inputs are shaped ``(batch,)`` without a feature axis, then
        flattening adds an extra channel dimension and output shape is
        ``(batch, 1)``.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: Tensor) -> Tensor:
        """
        Run ``F.flatten`` on ``input``.

        The result is also stored on ``self.output`` before being returned.
        """
        flattened = F.flatten(input)
        self.output = flattened
        return flattened