Source code for flint.nn.modules.loss

from flint import Tensor
from .module import Module
from .. import functional as F

# Public API of this module: the names exported by a wildcard import.
__all__ = [
    'Loss',
    'NllLoss',
    'CrossEntropyLoss',
    'MSELoss',
    'BCELoss'
]


class Loss(Module):
    """
    Common base class for the loss modules defined in this file.

    It only records the requested reduction mode. Subclasses read
    ``self.reduction`` inside ``forward`` and hand it to the matching
    function in :mod:`flint.nn.functional`.

    Args:
        reduction (str, optional): 'none' / 'mean' / 'sum'
    """

    def __init__(self, reduction: str = 'mean') -> None:
        super().__init__()
        # Stored as given; no validation is performed here.
        self.reduction = reduction
class NllLoss(Loss):
    """
    Negative Log Likelihood Loss.

    Module wrapper around :func:`flint.nn.functional.nll_loss`; see that
    function for more details.

    Args:
        reduction (str, optional): 'none' / 'mean' / 'sum'
    """

    def __init__(self, reduction: str = 'mean') -> None:
        super().__init__(reduction)

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Compute the loss, keep it on ``self.data`` and return it."""
        loss = F.nll_loss(input, target, reduction=self.reduction)
        self.data = loss
        return self.data
class CrossEntropyLoss(Loss):
    """
    Cross Entropy Loss, combines :func:`~flint.Tensor.softmax` and
    :func:`~flint.nn.functional.nll_loss`.

    Module wrapper around :func:`flint.nn.functional.cross_entropy`; see
    that function for more details.

    Args:
        reduction (str, optional): 'none' / 'mean' / 'sum'
    """

    def __init__(self, reduction: str = 'mean') -> None:
        super().__init__(reduction)

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Compute the loss, keep it on ``self.data`` and return it."""
        loss = F.cross_entropy(input, target, reduction=self.reduction)
        self.data = loss
        return self.data
class MSELoss(Loss):
    """
    Mean Squared Error Loss: :math:`(x - y)^2`.

    Module wrapper around :func:`flint.nn.functional.mse_loss`; see that
    function for more details.

    Args:
        reduction (str, optional): 'none' / 'mean' / 'sum'
    """

    def __init__(self, reduction: str = 'mean') -> None:
        super().__init__(reduction)

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Compute the loss, keep it on ``self.data`` and return it."""
        loss = F.mse_loss(input, target, reduction=self.reduction)
        self.data = loss
        return self.data
class BCELoss(Loss):
    r"""
    Binary Cross Entropy Loss:

    .. math::
        \text{loss} = y \log(x) + (1 - y) \log(1 - x)

    See :func:`flint.nn.functional.binary_cross_entropy` for more details.

    Args:
        reduction (str, optional): 'none' / 'mean' / 'sum'
    """

    # The docstring above is a raw string so that ``\log`` is not parsed as
    # an (invalid) escape sequence; the rendered text is unchanged.
    # NOTE(review): the textbook BCE is the negative of the expression shown
    # in the docstring - confirm against F.binary_cross_entropy and add the
    # leading minus sign to the formula if it applies.

    def __init__(self, reduction: str = 'mean') -> None:
        super().__init__(reduction)

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Compute the loss, keep it on ``self.data`` and return it."""
        self.data = F.binary_cross_entropy(
            input, target, reduction=self.reduction
        )
        return self.data