Source code for flint.nn.modules.fold

from flint import Tensor
from .. import functional as F
from ..types import _size_2_t
from .module import Module


class Unfold(Module):
    r"""
    Extracts sliding local blocks from a batched input tensor.

    See :func:`flint.nn.functional.unfold` for more details.

    - input shape: :math:`(N, C, H, W)`
    - output shape: :math:`(N, C \times \prod(\text{kernel\_size}), L)`

    where:

    .. math::
        L = \prod_d \left\lfloor \frac{\text{spatial\_size}[d]
            + 2 \times \text{padding}[d]
            - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}
            {\text{stride}[d]} + 1 \right\rfloor

    where :math:`\text{spatial\_size}` is formed by the spatial dimensions
    of ``input`` (H and W above), and :math:`d` is over all spatial
    dimensions.

    Parameters
    ----------
    kernel_size : int or tuple
        Size of the sliding blocks.
    stride : int or tuple, optional, default=1
        Stride of the sliding blocks in the input spatial dimensions.
    padding : int or tuple, optional, default=0
        Implicit zero padding to be added on both sides of input.
    dilation : int or tuple, optional, default=1
        A parameter that controls the stride of elements within the
        neighborhood.
    """

    def __init__(
        self,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: _size_2_t = 0,
        dilation: _size_2_t = 1
    ) -> None:
        super().__init__()
        # Stored as given; they are forwarded unchanged to F.unfold in forward().
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = padding
        self.stride = stride

    def forward(self, input: Tensor) -> Tensor:
        r"""
        Extract sliding local blocks from ``input``.

        Parameters
        ----------
        input : Tensor
            Input tensor of shape :math:`(N, C, H, W)`.

        Returns
        -------
        Tensor
            The unfolded blocks, of shape
            :math:`(N, C \times \prod(\text{kernel\_size}), L)`.
        """
        # F.unfold returns three values; only the first (the unfolded
        # tensor) is needed by this module, the other two are discarded.
        out, _, _ = F.unfold(
            input, self.kernel_size, self.stride, self.padding, self.dilation
        )
        return out