
ding.torch_utils.parameter


NonegativeParameter

Bases: Module

Overview

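This module outputs a non-negative parameter during the forward process. Internally it stores the logarithm of the value as a trainable tensor (log_data) and returns its exponential, so the output can never be negative.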

Interfaces: __init__, forward, set_data.
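The non-negativity comes from optimizing in log space: the trainable tensor is `log_data`, and `forward` returns `torch.exp(log_data)`, and `exp` never returns a negative number. The minimal sketch below shows that idea with plain tensors rather than the class itself (the names `log_x` and the target value are illustrative): even when the loss pulls the exposed value toward a negative target, it only shrinks toward zero.

```python
import torch
from torch import nn

# Trainable tensor stored in log space (illustrative stand-in for NonegativeParameter.log_data).
log_x = nn.Parameter(torch.log(torch.tensor([0.5])))
optimizer = torch.optim.SGD([log_x], lr=1.0)

# The loss pushes the exposed value toward -3.0, which it can never reach.
for _ in range(20):
    optimizer.zero_grad()
    x = torch.exp(log_x)              # what forward() would return
    loss = ((x - (-3.0)) ** 2).sum()
    loss.backward()
    optimizer.step()

value = torch.exp(log_x).item()       # shrinks toward 0 but stays positive
```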

__init__(data=None, requires_grad=True, delta=1e-08)

Overview

Initialize the NonegativeParameter object using the given arguments.

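Arguments:
- data (:obj:Optional[torch.Tensor]): The initial value of the generated parameter. If set to None, the default value is 0 (stored as log(0 + delta), so the output is approximately delta rather than exactly 0).
- requires_grad (:obj:bool): Whether this parameter requires grad.
- delta (:obj:float): A small constant added to data before taking the logarithm, so that a zero value does not produce log(0) = -inf.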
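The role of `delta` can be checked with the same arithmetic the constructor performs on its defaults (`data = torch.zeros(1)`, `delta = 1e-8`). This is a sketch with plain tensor ops, not a call into the library:

```python
import math
import torch

data = torch.zeros(1)                   # default initial value when data is None
delta = 1e-8                            # default delta

without_delta = torch.log(data)         # -inf: not usable as a parameter value
with_delta = torch.log(data + delta)    # finite, about -18.42
expected_log = math.log(delta)          # reference value computed in double precision
recovered = torch.exp(with_delta)       # about 1e-8, i.e. effectively zero
```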

forward()

Overview

Output the non-negative parameter during the forward process.

Returns: parameter (:obj:torch.Tensor): The generated parameter.
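Because the returned value is x = exp(log_data), the chain rule gives dL/d(log_data) = exp(log_data) · dL/dx: for a fixed dL/dx, gradient steps on `log_data` get smaller as the value approaches zero. The check below compares autograd against that formula; the tensors and the squared-error loss are illustrative.

```python
import torch
from torch import nn

log_data = nn.Parameter(torch.tensor([0.3, -1.2]))
x = torch.exp(log_data)                         # value returned by forward()
target = torch.tensor([2.0, 0.5])
loss = ((x - target) ** 2).sum()
loss.backward()

dl_dx = 2 * (x.detach() - target)               # gradient w.r.t. the exposed value
expected = torch.exp(log_data.detach()) * dl_dx # chain rule through exp
```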

set_data(data)

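Overview

Set the value of the non-negative parameter by replacing log_data with a new parameter initialized to log(data + 1e-8). The requires_grad flag of the previous parameter is kept. Note that the offset is fixed at 1e-8 here, regardless of the delta passed to __init__.

Arguments: data (:obj:torch.Tensor): The new value of the non-negative parameter.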
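Since `set_data` assigns a brand-new `nn.Parameter` instead of copying into the existing one, any optimizer built before the call still refers to the old tensor object. The sketch below mirrors the storage scheme with a hypothetical module `LogParam` (not part of `ding`) to show both that and the preserved `requires_grad` flag.

```python
import torch
from torch import nn


class LogParam(nn.Module):
    """Hypothetical stand-in that mirrors NonegativeParameter's log-space storage."""

    def __init__(self, data: torch.Tensor, requires_grad: bool = True):
        super().__init__()
        self.log_data = nn.Parameter(torch.log(data + 1e-8), requires_grad=requires_grad)

    def forward(self) -> torch.Tensor:
        return torch.exp(self.log_data)

    def set_data(self, data: torch.Tensor) -> None:
        # Replace the Parameter object, carrying over its requires_grad flag.
        self.log_data = nn.Parameter(torch.log(data + 1e-8), requires_grad=self.log_data.requires_grad)


module = LogParam(torch.tensor([1.0]))
optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
old_param = optimizer.param_groups[0]["params"][0]
module.set_data(torch.tensor([3.0]))   # optimizer still holds old_param, not module.log_data

frozen = LogParam(torch.tensor([1.0]), requires_grad=False)
frozen.set_data(torch.tensor([2.0]))   # stays frozen after the update
```

If the value is reset during training, re-creating the optimizer (or adding the new parameter to it) keeps the two in sync.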

TanhParameter

Bases: Module

Overview

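This module outputs a parameter bounded to the open interval (-1, 1) during the forward process. Internally it stores the inverse tanh of the value as an unconstrained trainable tensor (data_inv) and returns its tanh.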

Interfaces: __init__, forward, set_data.
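The bound follows from optimizing the unconstrained tensor and exposing only its tanh. A minimal sketch with plain tensors (the names and the target value are illustrative): even when the loss asks for 5.0, the exposed value approaches 1 from below and never leaves (-1, 1).

```python
import torch
from torch import nn

raw = nn.Parameter(torch.atanh(torch.tensor([0.2])))  # unconstrained storage, like data_inv
optimizer = torch.optim.SGD([raw], lr=0.1)

# The target 5.0 lies outside (-1, 1).
for _ in range(50):
    optimizer.zero_grad()
    y = torch.tanh(raw)                                # what forward() would return
    loss = ((y - 5.0) ** 2).sum()
    loss.backward()
    optimizer.step()

bounded = torch.tanh(raw).item()
```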

__init__(data=None, requires_grad=True)

Overview

Initialize the TanhParameter object using the given arguments.

Arguments:
- data (:obj:Optional[torch.Tensor]): The initial value of the generated parameter. It must lie strictly inside (-1, 1), since the inverse tanh of ±1 is infinite. If set to None, the default value is 0.
- requires_grad (:obj:bool): Whether this parameter requires grad.
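In the source, a `None` value is replaced by `torch.zeros(1)`, and tanh(atanh(0)) = 0, so the default output is 0. The same plain tensor ops show why the endpoints are excluded: atanh(±1) is infinite, while values just inside the interval round-trip cleanly.

```python
import torch

default = torch.zeros(1)                         # what __init__ uses when data is None
default_out = torch.tanh(torch.atanh(default))   # default output: 0

edge = torch.atanh(torch.tensor([1.0, -1.0]))    # inf and -inf: not a usable parameter
inside = torch.tanh(torch.atanh(torch.tensor([0.999])))  # round-trips to about 0.999
```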

forward()

Overview

Output the tanh parameter during the forward process.

Returns: parameter (:obj:torch.Tensor): The generated parameter.
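With y = tanh(data_inv), the chain rule gives dL/d(data_inv) = (1 - y^2) · dL/dy, so gradients vanish as the output saturates near ±1. A quick autograd check with illustrative tensors and a squared-error loss:

```python
import torch
from torch import nn

data_inv = nn.Parameter(torch.tensor([0.1, 2.5]))
y = torch.tanh(data_inv)                       # value returned by forward()
target = torch.tensor([0.5, -0.5])
loss = ((y - target) ** 2).sum()
loss.backward()

dl_dy = 2 * (y.detach() - target)              # gradient w.r.t. the exposed value
expected = (1 - y.detach() ** 2) * dl_dy       # derivative of tanh is 1 - tanh^2
```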

set_data(data)

Overview

Set the value of the tanh parameter.

Arguments: data (:obj:torch.Tensor): The new value of the tanh parameter.
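The set/get round trip can be sketched with a hypothetical module `TanhParam` (not part of `ding`). As an assumption made for simplicity, it calls `torch.atanh`/`torch.tanh` directly instead of going through `TanhTransform`, which computes the same mapping.

```python
from typing import Optional

import torch
from torch import nn


class TanhParam(nn.Module):
    """Hypothetical stand-in for TanhParameter using torch.tanh / torch.atanh directly."""

    def __init__(self, data: Optional[torch.Tensor] = None, requires_grad: bool = True):
        super().__init__()
        if data is None:
            data = torch.zeros(1)
        self.data_inv = nn.Parameter(torch.atanh(data), requires_grad=requires_grad)

    def forward(self) -> torch.Tensor:
        return torch.tanh(self.data_inv)

    def set_data(self, data: torch.Tensor) -> None:
        # New Parameter built from atanh(data); the requires_grad flag carries over.
        self.data_inv = nn.Parameter(torch.atanh(data), requires_grad=self.data_inv.requires_grad)


param = TanhParam(requires_grad=False)
param.set_data(torch.tensor([0.3, -0.7]))
values = param()                      # recovers the values passed to set_data
```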

Full Source Code

../ding/torch_utils/parameter.py

```python
from typing import Optional
import torch
from torch import nn
from torch.distributions.transforms import TanhTransform


class NonegativeParameter(nn.Module):
    """
    Overview:
        This module will output a non-negative parameter during the forward process.
    Interfaces:
        ``__init__``, ``forward``, ``set_data``.
    """

    def __init__(self, data: Optional[torch.Tensor] = None, requires_grad: bool = True, delta: float = 1e-8):
        """
        Overview:
            Initialize the NonegativeParameter object using the given arguments.
        Arguments:
            - data (:obj:`Optional[torch.Tensor]`): The initial value of generated parameter. If set to ``None``, the \
                default value is 0.
            - requires_grad (:obj:`bool`): Whether this parameter requires grad.
            - delta (:obj:`float`): Small constant added to ``data`` before taking the log, to avoid ``log(0)``.
        """
        super().__init__()
        if data is None:
            data = torch.zeros(1)
        self.log_data = nn.Parameter(torch.log(data + delta), requires_grad=requires_grad)

    def forward(self) -> torch.Tensor:
        """
        Overview:
            Output the non-negative parameter during the forward process.
        Returns:
            parameter (:obj:`torch.Tensor`): The generated parameter.
        """
        return torch.exp(self.log_data)

    def set_data(self, data: torch.Tensor) -> None:
        """
        Overview:
            Set the value of the non-negative parameter.
        Arguments:
            data (:obj:`torch.Tensor`): The new value of the non-negative parameter.
        """
        self.log_data = nn.Parameter(torch.log(data + 1e-8), requires_grad=self.log_data.requires_grad)


class TanhParameter(nn.Module):
    """
    Overview:
        This module will output a tanh parameter during the forward process.
    Interfaces:
        ``__init__``, ``forward``, ``set_data``.
    """

    def __init__(self, data: Optional[torch.Tensor] = None, requires_grad: bool = True):
        """
        Overview:
            Initialize the TanhParameter object using the given arguments.
        Arguments:
            - data (:obj:`Optional[torch.Tensor]`): The initial value of generated parameter. If set to ``None``, the \
                default value is 0.
            - requires_grad (:obj:`bool`): Whether this parameter requires grad.
        """
        super().__init__()
        if data is None:
            data = torch.zeros(1)
        self.transform = TanhTransform(cache_size=1)

        self.data_inv = nn.Parameter(self.transform.inv(data), requires_grad=requires_grad)

    def forward(self) -> torch.Tensor:
        """
        Overview:
            Output the tanh parameter during the forward process.
        Returns:
            parameter (:obj:`torch.Tensor`): The generated parameter.
        """
        return self.transform(self.data_inv)

    def set_data(self, data: torch.Tensor) -> None:
        """
        Overview:
            Set the value of the tanh parameter.
        Arguments:
            data (:obj:`torch.Tensor`): The new value of the tanh parameter.
        """
        self.data_inv = nn.Parameter(self.transform.inv(data), requires_grad=self.data_inv.requires_grad)
```
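To see both storage schemes working together, the sketch below fits a non-negative scale and a tanh-bounded coefficient to fixed targets with plain SGD. It reproduces the scheme with raw tensors instead of importing `ding`, so the names `log_scale`, `raw_coef`, and the targets are illustrative.

```python
import torch
from torch import nn

log_scale = nn.Parameter(torch.log(torch.tensor([1.0]) + 1e-8))  # NonegativeParameter-style storage
raw_coef = nn.Parameter(torch.atanh(torch.tensor([0.0])))        # TanhParameter-style storage
optimizer = torch.optim.SGD([log_scale, raw_coef], lr=0.1)

for _ in range(500):
    optimizer.zero_grad()
    scale = torch.exp(log_scale)   # never negative
    coef = torch.tanh(raw_coef)    # always in (-1, 1)
    loss = ((scale - 2.0) ** 2 + (coef - 0.5) ** 2).sum()
    loss.backward()
    optimizer.step()

final_scale = torch.exp(log_scale).item()
final_coef = torch.tanh(raw_coef).item()
```

The optimizer only ever sees unconstrained tensors, so no projection or clipping step is needed to keep the exposed values in range.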