from collections import OrderedDict
from typing import Optional, Union, Iterator, Tuple, Set
from flint import Tensor
from .. import Parameter
[docs]class Module(object):
"""
Base class for all modules.
Args:
name (str): name of the module
"""
training: bool
def __init__(self) -> None:
self.training = True
self._parameters = OrderedDict()
self._modules = OrderedDict()
[docs] def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
"""
Add a parameter to the module.
Args:
name (str): name of the parameter
param (Parameter): parameter to be added to the module
"""
if param is None:
self._parameters[name] = None
else:
self._parameters[name] = param
[docs] def add_module(self, name: str, module: Optional['Module']) -> None:
"""
Add a child module to the current module.
Args:
name (str): name of the child module
module (Module): child module to be added to the module
"""
if module is None:
self._modules[name] = None
else:
self._modules[name] = module
[docs] def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
"""
Returns an iterator over module parameters, only yielding the parameter itself.
Args:
recurse (bool): If ``True``, yields parameters of this module and all
submodules. If ``False``, yields only parameters that are direct
members of this module.
Yields:
Parameter: module parameter
"""
for name, param in self.named_parameters(recurse = recurse):
yield param
[docs] def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
"""
Returns an iterator over module parameters, yielding both the name of the
parameter as well as the parameter itself.
Adapted from: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py
Args:
prefix (str): prefix to prepend to all parameter names.
recurse (bool):
True: yield parameters of this module and all submodules
False: yield only parameters that are direct members of this module
Yields:
(string, Parameter): Tuple containing the name and parameter
"""
memo = set()
modules = self.named_modules(prefix = prefix) if recurse else [(prefix, self)]
for module_prefix, module in modules:
params = module._parameters.items()
for k, v in params:
if v is None or v in memo:
continue
memo.add(v)
name = module_prefix + ('.' if module_prefix else '') + k
yield name, v
[docs] def children(self) -> Iterator['Module']:
"""
Returns an iterator over immediate children modules.
Yields:
module (Module): A child module
"""
for name, module in self.named_children():
yield module
[docs] def named_children(self) -> Iterator[Tuple[str, 'Module']]:
"""
Returns an iterator over immediate children modules, yielding both the
name of the module as well as the module itself.
Yields:
(string, Module): Tuple containing a name and child module
"""
memo = set()
for name, module in self._modules.items():
if module is not None and module not in memo:
memo.add(module)
yield name, module
[docs] def modules(self) -> Iterator['Module']:
"""
Returns an iterator over all modules in the network, only yielding the module itself.
yields:
Module: a module in the network
NOTE:
Duplicate modules are returned only once.
"""
for name, module in self.named_modules():
yield module
[docs] def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '') -> Iterator[Tuple[str, 'Module']]:
"""
Returns an iterator over all modules in the network, yielding
both the name of the module as well as the module itself.
Borrowed from: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py
Args:
memo (Set): a set for recording visited modules
prefix (str): prefix to prepend to all parameter names
yields:
(string, Module): Tuple of name and module
NOTE:
Duplicate modules are returned only once.
"""
if memo is None:
memo = set()
if self not in memo:
memo.add(self)
yield prefix, self
for name, module in self._modules.items():
if module is None:
continue
submodule_prefix = prefix + ('.' if prefix else '') + name
for m in module.named_modules(memo, submodule_prefix):
yield m
[docs] def train(self, mode: bool = True) -> 'Module':
"""
Sets the module in training mode.
This has effect only on the following modules:
- :class:`flint.nn.Dropout`
See their documentations for details of their behaviors in training /
evaluation mode.
Parameters
----------
mode : bool, optional, default=True
Whether to set training mode (``True``) or evaluation mode (``False``)
Returns
-------
module : Module
"""
self.training = mode
for module in self.children():
module.train(mode)
return self
[docs] def eval(self) -> 'Module':
"""
Sets the module in evaluation mode.
This has effect only on the following modules:
- :class:`flint.nn.Dropout`
See their documentations for details of their behaviors in training /
evaluation mode.
Returns
-------
module : Module
"""
return self.train(False)
def __call__(self, *input, **kwargs) -> Tensor:
out = self.forward(*input, **kwargs)
return out
def __setattr__(self, name: str, value):
# add a parameter to the module
if isinstance(value, Parameter):
self.register_parameter(name, value)
# add a child module to the module
elif isinstance(value, Module):
self.add_module(name, value)
object.__setattr__(self, name, value)
def __delattr__(self, name):
# delete a parameter from the module
if name in self._parameters:
del self._parameters[name]
# delete a child module from the module
elif name in self.modules:
del self._modules[name]