# Source code for flint.nn.modules.module

from collections import OrderedDict
from typing import Optional, Union, Iterator, Tuple, Set
from flint import Tensor
from .. import Parameter

class Module(object):
    """
    Base class for all modules.

    Parameters and child modules assigned as attributes are recorded in the
    ``_parameters`` and ``_modules`` ordered dictionaries, which is what the
    ``parameters()`` / ``modules()`` family of iterators walks.

    ``__call__`` delegates to ``self.forward``; this base class does not
    define ``forward`` itself, so subclasses must provide it.
    """

    training: bool

    def __init__(self) -> None:
        # ``training`` is a plain bool, so it can be assigned before the
        # registries exist without tripping the registration in __setattr__.
        self.training = True
        self._parameters = OrderedDict()
        self._modules = OrderedDict()

    def register_parameter(self, name: str, param: Optional['Parameter']) -> None:
        """
        Add a parameter to the module.

        Only the ``_parameters`` registry is updated; no instance attribute
        is created by this method.

        Args:
            name (str): name of the parameter
            param (Parameter): parameter to be added to the module. ``None``
                is stored as a placeholder and skipped by the iterators.
        """
        self._parameters[name] = param

    def add_module(self, name: str, module: Optional['Module']) -> None:
        """
        Add a child module to the current module.

        Only the ``_modules`` registry is updated; no instance attribute is
        created by this method.

        Args:
            name (str): name of the child module
            module (Module): child module to be added to the module. ``None``
                is stored as a placeholder and skipped by the iterators.
        """
        self._modules[name] = module

    def parameters(self, recurse: bool = True) -> Iterator['Parameter']:
        """
        Returns an iterator over module parameters, only yielding the
        parameter itself.

        Args:
            recurse (bool): If ``True``, yields parameters of this module and
                all submodules. If ``False``, yields only parameters that are
                direct members of this module.

        Yields:
            Parameter: module parameter
        """
        for _, param in self.named_parameters(recurse=recurse):
            yield param

    def named_parameters(
        self, prefix: str = '', recurse: bool = True
    ) -> Iterator[Tuple[str, 'Parameter']]:
        """
        Returns an iterator over module parameters, yielding both the name of
        the parameter as well as the parameter itself.

        Adapted from:
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py

        Args:
            prefix (str): prefix to prepend to all parameter names.
            recurse (bool):
                True: yield parameters of this module and all submodules
                False: yield only parameters that are direct members of this
                module

        Yields:
            (string, Parameter): Tuple containing the name and parameter

        NOTE: ``None`` placeholders are skipped and a parameter reachable
        under several names is yielded only once, under the first name found.
        """
        seen = set()
        if recurse:
            modules = self.named_modules(prefix=prefix)
        else:
            modules = [(prefix, self)]

        for module_prefix, module in modules:
            for key, param in module._parameters.items():
                if param is None or param in seen:
                    continue
                seen.add(param)
                separator = '.' if module_prefix else ''
                yield module_prefix + separator + key, param

    def children(self) -> Iterator['Module']:
        """
        Returns an iterator over immediate children modules.

        Yields:
            module (Module): A child module
        """
        for _, module in self.named_children():
            yield module

    def named_children(self) -> Iterator[Tuple[str, 'Module']]:
        """
        Returns an iterator over immediate children modules, yielding both
        the name of the module as well as the module itself.

        Yields:
            (string, Module): Tuple containing a name and child module

        NOTE: ``None`` placeholders are skipped and a child registered under
        several names is yielded only once.
        """
        seen = set()
        for name, module in self._modules.items():
            if module is None or module in seen:
                continue
            seen.add(module)
            yield name, module

    def modules(self) -> Iterator['Module']:
        """
        Returns an iterator over all modules in the network, only yielding
        the module itself.

        Yields:
            Module: a module in the network

        NOTE: Duplicate modules are returned only once.
        """
        for _, module in self.named_modules():
            yield module

    def named_modules(
        self, memo: Optional[Set['Module']] = None, prefix: str = ''
    ) -> Iterator[Tuple[str, 'Module']]:
        """
        Returns an iterator over all modules in the network, yielding both
        the name of the module as well as the module itself.

        Borrowed from:
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py

        Args:
            memo (Set): a set for recording visited modules; it is shared
                with the recursive calls and mutated in place
            prefix (str): prefix to prepend to all module names

        Yields:
            (string, Module): Tuple of name and module

        NOTE: Duplicate modules are returned only once.
        """
        if memo is None:
            memo = set()
        if self in memo:
            # Already visited through another path; its subtree was walked then.
            return

        memo.add(self)
        yield prefix, self
        for name, module in self._modules.items():
            if module is None:
                continue
            submodule_prefix = prefix + ('.' if prefix else '') + name
            yield from module.named_modules(memo, submodule_prefix)

    def train(self, mode: bool = True) -> 'Module':
        """
        Sets the module in training mode.

        This has effect only on the following modules:
            - :class:`flint.nn.Dropout`

        See their documentations for details of their behaviors in
        training / evaluation mode.

        Parameters
        ----------
        mode : bool, optional, default=True
            Whether to set training mode (``True``) or evaluation mode
            (``False``)

        Returns
        -------
        module : Module
        """
        self.training = mode
        # Each child recurses into its own children, so the whole tree is set.
        for module in self.children():
            module.train(mode)
        return self

    def eval(self) -> 'Module':
        """
        Sets the module in evaluation mode.

        This has effect only on the following modules:
            - :class:`flint.nn.Dropout`

        See their documentations for details of their behaviors in
        training / evaluation mode.

        Returns
        -------
        module : Module
        """
        return self.train(False)

    def __call__(self, *input, **kwargs) -> 'Tensor':
        """Run ``self.forward`` with the given arguments and return its result."""
        return self.forward(*input, **kwargs)

    def __setattr__(self, name: str, value) -> None:
        """
        Set an attribute, registering parameters and child modules.

        A ``Parameter`` is recorded in ``_parameters`` and a ``Module`` in
        ``_modules``; every value is also stored as a normal attribute. If
        ``name`` was previously registered as a different kind of value, the
        outdated registry entry is dropped so the iterators never yield an
        object that the attribute no longer refers to.
        """
        if isinstance(value, Parameter):
            # add a parameter to the module
            self.register_parameter(name, value)
            stale_registries = ('_modules',)
        elif isinstance(value, Module):
            # add a child module to the module
            self.add_module(name, value)
            stale_registries = ('_parameters',)
        else:
            stale_registries = ('_parameters', '_modules')

        # Look the registries up in __dict__: they do not exist yet while
        # __init__ is still assigning its first attributes.
        for registry_name in stale_registries:
            registry = self.__dict__.get(registry_name)
            if registry is not None:
                registry.pop(name, None)

        object.__setattr__(self, name, value)

    def __delattr__(self, name: str) -> None:
        """
        Delete an attribute and any parameter / child-module registration.

        Raises:
            AttributeError: if ``name`` is neither a registered entry nor an
                instance attribute.
        """
        registered = False
        for registry in (self._parameters, self._modules):
            if name in registry:
                del registry[name]
                registered = True

        # Entries added through register_parameter / add_module may have no
        # matching instance attribute; removing the registration is enough
        # for those. Everything else goes through the normal deletion path.
        if name in self.__dict__ or not registered:
            object.__delattr__(self, name)