"""
Some of the code is borrowed from: https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
"""
import math
import numpy as np
from typing import Optional, Union
from flint import Tensor

def calculate_gain(nonlinearity: str, param: Optional[Union[int, float]] = None) -> float:
    r"""
    Return the recommended gain value for the given nonlinearity function.
    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky ReLU        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    SELU              :math:`\frac{3}{4}`
    ================= ====================================================

    Parameters
    ----------
    nonlinearity : str
        Name of the non-linear function

    param : Union[int, float], optional
        Optional parameter for the non-linear function (e.g. the negative
        slope of a leaky ReLU)
    """
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d']
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    elif nonlinearity == 'tanh':
        return 5.0 / 3
    elif nonlinearity == 'relu':
        return math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif not isinstance(param, bool) and isinstance(param, (int, float)):
            # True/False are instances of int, hence the explicit bool check above
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope ** 2))
    elif nonlinearity == 'selu':
        return 3.0 / 4  # value found empirically (https://github.com/pytorch/pytorch/pull/50664)
    else:
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))

def zeros_(tensor: Tensor) -> None:
    """
    Fill the tensor with the scalar value ``0``.

    Parameters
    ----------
    tensor : Tensor
        A Tensor
    """
    tensor.zero_()


def ones_(tensor: Tensor) -> None:
    """
    Fill the tensor with the scalar value ``1``.

    Parameters
    ----------
    tensor : Tensor
        A Tensor
    """
    tensor.one_()


def constant_(tensor: Tensor, val: float) -> None:
    """
    Fill the tensor with the given scalar value ``val``.

    Parameters
    ----------
    tensor : Tensor
        A Tensor

    val : float
        The value to fill the tensor with
    """
    tensor.fill_(val)


def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> None:
    """
    Fill the tensor with values drawn from the normal distribution
    ``N(mean, std^2)``.

    Parameters
    ----------
    tensor : Tensor
        A Tensor

    mean : float, optional, default=0.
        The mean of the normal distribution

    std : float, optional, default=1.
        The standard deviation of the normal distribution
    """
    tensor.normal_(mean=mean, std=std)
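
# Minimal usage sketch for the in-place fillers above. How a flint ``Tensor``
# is constructed is an assumption here (a NumPy-backed constructor is guessed
# for illustration, not confirmed by this module):
#
#     w = Tensor(np.empty((3, 4)))
#     zeros_(w)                        # all zeros
#     constant_(w, 0.5)                # all 0.5
#     normal_(w, mean=0., std=0.02)    # samples from N(0, 0.02^2)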

def _calculate_fan_in_and_fan_out(tensor: Tensor):
    """
    Compute the number of input and output nodes for a tensor.

    Parameters
    ----------
    tensor : Tensor
        A Tensor

    Returns
    -------
    fan_in : int
        Number of input nodes

    fan_out : int
        Number of output nodes
    """
    dimensions = tensor.ndim
    if dimensions < 2:
        raise ValueError('Fan in and fan out cannot be computed for a tensor with fewer than 2 dimensions')

    num_input_fmaps = tensor.shape[1]
    num_output_fmaps = tensor.shape[0]
    receptive_field_size = 1
    if dimensions > 2:
        # e.g. kernel_height * kernel_width for a conv weight
        receptive_field_size = int(np.prod(tensor.shape[2:]))

    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out
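
# Worked example: for a conv2d weight of shape (out_channels, in_channels, kH, kW)
# = (16, 8, 3, 3), the receptive field size is 3 * 3 = 9, so fan_in = 8 * 9 = 72
# and fan_out = 16 * 9 = 144. (The Tensor construction is an assumed API, shown
# only for illustration.)
#
#     w = Tensor(np.empty((16, 8, 3, 3)))
#     _calculate_fan_in_and_fan_out(w)   # -> (72, 144)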

def xavier_normal_(tensor: Tensor, gain: float = 1.) -> None:
    r"""
    Xavier initialization (also known as Glorot initialization), proposed
    in [1], using a normal distribution.

    The resulting tensor will have values sampled from :math:`N(0, \text{std}^2)`,
    where ``std = gain * sqrt(2 / (fan_in + fan_out))``.

    Parameters
    ----------
    tensor : Tensor
        A Tensor

    gain : float, optional, default=1.
        An optional scaling factor

    References
    ----------
    1. "`Understanding the Difficulty of Training Deep Feedforward Neural Networks. <http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_" Xavier Glorot and Yoshua Bengio. AISTATS 2010.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    tensor.normal_(mean=0, std=std)
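
# Usage sketch: initialize a linear layer's weight of shape (out_features,
# in_features) with a tanh-appropriate gain. The Tensor constructor is an
# assumed flint API, shown only to illustrate the call pattern:
#
#     w = Tensor(np.empty((256, 128)))
#     xavier_normal_(w, gain=calculate_gain('tanh'))
#     # std = (5/3) * sqrt(2 / (128 + 256)) ≈ 0.1203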

def _calculate_correct_fan(tensor: Tensor, mode: str):
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))

    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    return fan_in if mode == 'fan_in' else fan_out

def kaiming_normal_(
    tensor: Tensor,
    a: float = 0.,
    mode: str = 'fan_in',
    nonlinearity: str = 'leaky_relu'
) -> None:
    r"""
    Kaiming initialization (also known as He initialization), proposed
    in [1], using a normal distribution.

    The resulting tensor will have values sampled from :math:`N(0, \text{std}^2)`,
    where ``std = gain / sqrt(fan_mode)``.

    Parameters
    ----------
    tensor : Tensor
        A Tensor

    a : float, optional, default=0.
        The negative slope of the rectifier used after this layer (only used
        with ``'leaky_relu'``)

    mode : str, optional, default='fan_in'
        Either ``'fan_in'`` or ``'fan_out'``. ``'fan_in'`` preserves the
        magnitude of the variance of the weights in the forward pass;
        ``'fan_out'`` preserves the magnitudes in the backward pass.

    nonlinearity : str, optional, default='leaky_relu'
        Name of the non-linear function, recommended for use only with
        ``'relu'`` or ``'leaky_relu'``

    References
    ----------
    1. "`Delving Deep into Rectifiers: Surpassing Human-level Performance on ImageNet Classification. <https://arxiv.org/pdf/1502.01852.pdf>`_" Kaiming He, et al. ICCV 2015.
    """
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    tensor.normal_(mean=0, std=std)
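
# Usage sketch for a conv layer followed by ReLU; ``mode='fan_in'`` preserves
# forward-pass variance. The Tensor construction is again an assumed API:
#
#     w = Tensor(np.empty((64, 32, 3, 3)))
#     kaiming_normal_(w, mode='fan_in', nonlinearity='relu')
#     # fan_in = 32 * 3 * 3 = 288, std = sqrt(2) / sqrt(288) ≈ 0.0833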

def lecun_normal_(tensor: Tensor) -> None:
    r"""
    LeCun initialization, as described in [1], using a normal distribution.

    The resulting tensor will have values sampled from :math:`N(0, \text{std}^2)`,
    where ``std = sqrt(1 / fan_in)``.

    Parameters
    ----------
    tensor : Tensor
        A Tensor

    References
    ----------
    1. "`Efficient Backprop. <http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf>`_" Yann LeCun, et al. 1998.
    """
    fan_in, _ = _calculate_fan_in_and_fan_out(tensor)
    std = math.sqrt(1.0 / fan_in)
    tensor.normal_(mean=0, std=std)
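
# Usage sketch (Tensor construction assumed, as above). LeCun initialization is
# the scheme paired with SELU activations in the self-normalizing network setup:
#
#     w = Tensor(np.empty((128, 64)))
#     lecun_normal_(w)   # fan_in = 64, std = sqrt(1 / 64) = 0.125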