
ding.torch_utils.network.popart

Implementation of the POPART algorithm for reward rescaling.

POPART is an adaptive normalization algorithm that normalizes the targets used in the learning updates. Its two main components are:

- ART: update the scale and shift so that the return is appropriately normalized.
- POP: preserve the outputs of the unnormalized function when the scale and shift change.
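The POP step can be made concrete with a small sketch. When the statistics change from (mu, sigma) to (mu', sigma'), the last layer's weight and bias are rescaled so that the unnormalized prediction sigma * (w*x + b) + mu is unchanged for every input. This is a minimal scalar illustration in plain Python (not the library implementation; `pop_preserve` is an illustrative name):

```python
def pop_preserve(w, b, old_mu, old_sigma, new_mu, new_sigma):
    """Return (w', b') such that new_sigma * (w'*x + b') + new_mu
    equals old_sigma * (w*x + b) + old_mu for every x."""
    w_new = w * old_sigma / new_sigma
    b_new = (old_sigma * b + old_mu - new_mu) / new_sigma
    return w_new, b_new

def unnormalized(w, b, mu, sigma, x):
    # The prediction mapped back to the original target scale.
    return sigma * (w * x + b) + mu

w, b = 0.7, -0.2
old_mu, old_sigma = 1.0, 2.0
new_mu, new_sigma = 1.5, 3.0
w2, b2 = pop_preserve(w, b, old_mu, old_sigma, new_mu, new_sigma)
# The unnormalized output is preserved for any input.
for x in (-1.0, 0.0, 2.5):
    before = unnormalized(w, b, old_mu, old_sigma, x)
    after = unnormalized(w2, b2, new_mu, new_sigma, x)
    assert abs(before - after) < 1e-9
```

This is exactly the correction applied to `self.weight.data` and `self.bias.data` at the end of `update_parameters` in the source below, vectorized over output features.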

PopArt

Bases: Module

Overview

A linear layer with popart normalization. This class implements a linear transformation followed by PopArt normalization, which is a method to automatically adapt the contribution of each task to the agent's updates in multi-task learning, as described in the paper https://arxiv.org/abs/1809.04474.

Interfaces

__init__, reset_parameters, forward, update_parameters

__init__(input_features=None, output_features=None, beta=0.5)

Overview

Initialize the class with input features, output features, and the beta parameter.

Arguments:

- input_features (:obj:`Union[int, None]`): The size of each input sample.
- output_features (:obj:`Union[int, None]`): The size of each output sample.
- beta (:obj:`float`): The parameter for the moving average.

reset_parameters()

Overview

Reset the parameters including weights and bias using kaiming_uniform_ and uniform_ initialization.
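As the source below shows, the weight uses `kaiming_uniform_` with `a=math.sqrt(5)` and the bias uses a uniform bound of `1/sqrt(fan_in)`. A small sketch of PyTorch's bound formula (gain * sqrt(3 / fan_in), with gain = sqrt(2 / (1 + a**2)) for the leaky-relu nonlinearity) shows that the a=sqrt(5) choice makes both bounds coincide; `kaiming_uniform_bound` is an illustrative helper, not a ding or PyTorch function:

```python
import math

def kaiming_uniform_bound(fan_in, a):
    """Bound of a kaiming_uniform_-style init for the leaky_relu nonlinearity:
    gain * sqrt(3 / fan_in), with gain = sqrt(2 / (1 + a**2))."""
    gain = math.sqrt(2.0 / (1.0 + a ** 2))
    return gain * math.sqrt(3.0 / fan_in)

fan_in = 64
w_bound = kaiming_uniform_bound(fan_in, a=math.sqrt(5))
b_bound = 1.0 / math.sqrt(fan_in)  # the bias bound used in reset_parameters
# With a = sqrt(5), the weight bound also reduces to 1/sqrt(fan_in).
assert abs(w_bound - b_bound) < 1e-12
```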

forward(x)

Overview

Implement the forward computation of the linear layer and return both the normalized output and the unnormalized output of the layer.

Arguments:

- x (:obj:`torch.Tensor`): The input tensor to be normalized.

Returns:

- output (:obj:`Dict[str, torch.Tensor]`): A dictionary containing 'pred' and 'unnormalized_pred'.
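The relationship between the two entries is simple: 'unnormalized_pred' is 'pred' rescaled by the stored statistics. A scalar sketch in plain Python (the real layer operates on tensors; `popart_forward` is an illustrative name):

```python
def popart_forward(x, w, b, mu, sigma):
    """Scalar sketch of PopArt.forward: the linear layer produces a
    normalized prediction; the stored (mu, sigma) buffers un-normalize it."""
    pred = w * x + b                        # normalized output (trained against scaled targets)
    unnormalized_pred = pred * sigma + mu   # value back on the original target scale
    return {"pred": pred, "unnormalized_pred": unnormalized_pred}

out = popart_forward(x=2.0, w=0.5, b=0.1, mu=10.0, sigma=4.0)
assert out["pred"] == 0.5 * 2.0 + 0.1
assert out["unnormalized_pred"] == out["pred"] * 4.0 + 10.0
```

Note that the unnormalization in the real layer happens under `torch.no_grad()`, so gradients only flow through the normalized prediction.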

update_parameters(value)

Overview

Update the normalization parameters based on the given value, and return the new mean and standard deviation after the update.

Arguments:

- value (:obj:`torch.Tensor`): The tensor used to update the parameters.

Returns:

- update_results (:obj:`Dict[str, torch.Tensor]`): A dictionary containing 'new_mean' and 'new_std'.
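The ART step keeps exponential moving averages of the first moment (mean) and second moment of the targets, then derives the standard deviation as sqrt(v - mu**2). A scalar sketch in plain Python, mirroring the moving-average and clamping logic of the source below (`art_update` is an illustrative name, not part of the library):

```python
import math

def art_update(mu, v, batch, beta):
    """Scalar sketch of the ART step in update_parameters: exponential moving
    averages of the first and second moments, then std = sqrt(v - mu**2)."""
    batch_mean = sum(batch) / len(batch)
    batch_sq = sum(t * t for t in batch) / len(batch)
    new_mu = (1 - beta) * mu + beta * batch_mean
    new_v = (1 - beta) * v + beta * batch_sq
    new_std = math.sqrt(new_v - new_mu ** 2)
    # Clamp as the implementation does, to reject degenerate scales.
    new_std = min(max(new_std, 1e-4), 1e6)
    return new_mu, new_v, new_std

mu, v = 0.0, 1.0              # initial buffer values: mu=0, v=1 (sigma=1)
targets = [2.0, 4.0, 6.0]     # batch of target values
mu, v, std = art_update(mu, v, targets, beta=0.5)
# batch mean = 4, so new_mu = 0.5 * 0 + 0.5 * 4 = 2
assert mu == 2.0
```

A larger `beta` weights the current batch more heavily; `beta=0.5` splits the update evenly between the old statistics and the new batch.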

Full Source Code

../ding/torch_utils/network/popart.py

```python
"""
Implementation of the ``POPART`` algorithm for reward rescaling.
<link https://arxiv.org/abs/1602.07714 link>

POPART is an adaptive normalization algorithm to normalize the targets used in the learning updates.
The two main components in POPART are:
**ART**: to update scale and shift such that the return is appropriately normalized,
**POP**: to preserve the outputs of the unnormalized function when we change the scale and shift.
"""
from typing import Optional, Union, Dict
import math
import torch
import torch.nn as nn


class PopArt(nn.Module):
    """
    Overview:
        A linear layer with popart normalization. This class implements a linear transformation followed by
        PopArt normalization, which is a method to automatically adapt the contribution of each task to the agent's
        updates in multi-task learning, as described in the paper <https://arxiv.org/abs/1809.04474>.

    Interfaces:
        ``__init__``, ``reset_parameters``, ``forward``, ``update_parameters``
    """

    def __init__(
        self,
        input_features: Union[int, None] = None,
        output_features: Union[int, None] = None,
        beta: float = 0.5
    ) -> None:
        """
        Overview:
            Initialize the class with input features, output features, and the beta parameter.
        Arguments:
            - input_features (:obj:`Union[int, None]`): The size of each input sample.
            - output_features (:obj:`Union[int, None]`): The size of each output sample.
            - beta (:obj:`float`): The parameter for the moving average.
        """
        super(PopArt, self).__init__()

        self.beta = beta
        self.input_features = input_features
        self.output_features = output_features
        # Initialize the linear layer parameters, weight and bias.
        self.weight = nn.Parameter(torch.Tensor(output_features, input_features))
        self.bias = nn.Parameter(torch.Tensor(output_features))
        # Register buffers for the normalization parameters, which are not model parameters.
        # The normalization parameters will be used later to save the target value's scale and shift.
        self.register_buffer('mu', torch.zeros(output_features, requires_grad=False))
        self.register_buffer('sigma', torch.ones(output_features, requires_grad=False))
        self.register_buffer('v', torch.ones(output_features, requires_grad=False))

        self.reset_parameters()

    def reset_parameters(self):
        """
        Overview:
            Reset the parameters including weights and bias using ``kaiming_uniform_`` and ``uniform_`` initialization.
        """
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Implement the forward computation of the linear layer and return both the normalized
            and the unnormalized output of the layer.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor to be normalized.
        Returns:
            - output (:obj:`Dict[str, torch.Tensor]`): A dictionary containing 'pred' and 'unnormalized_pred'.
        """
        normalized_output = x.mm(self.weight.t())
        normalized_output += self.bias.unsqueeze(0).expand_as(normalized_output)
        # Unnormalize the output using the stored scale and shift.
        with torch.no_grad():
            output = normalized_output * self.sigma + self.mu

        return {'pred': normalized_output.squeeze(1), 'unnormalized_pred': output.squeeze(1)}

    def update_parameters(self, value: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Update the normalization parameters based on the given value, and return the new mean and
            standard deviation after the update.
        Arguments:
            - value (:obj:`torch.Tensor`): The tensor used to update the parameters.
        Returns:
            - update_results (:obj:`Dict[str, torch.Tensor]`): A dictionary containing 'new_mean' and 'new_std'.
        """
        # Move the normalization parameters to the device of the input value.
        self.mu = self.mu.to(value.device)
        self.sigma = self.sigma.to(value.device)
        self.v = self.v.to(value.device)

        old_mu = self.mu
        old_std = self.sigma

        # Calculate the first and second moments (mean and mean square) of the target value:
        batch_mean = torch.mean(value, 0)
        batch_v = torch.mean(torch.pow(value, 2), 0)
        batch_mean[torch.isnan(batch_mean)] = self.mu[torch.isnan(batch_mean)]
        batch_v[torch.isnan(batch_v)] = self.v[torch.isnan(batch_v)]
        batch_mean = (1 - self.beta) * self.mu + self.beta * batch_mean
        batch_v = (1 - self.beta) * self.v + self.beta * batch_v
        batch_std = torch.sqrt(batch_v - (batch_mean ** 2))
        # Clip the standard deviation to reject outlier data.
        batch_std = torch.clamp(batch_std, min=1e-4, max=1e+6)
        # Replace any nan values with the old values.
        batch_std[torch.isnan(batch_std)] = self.sigma[torch.isnan(batch_std)]

        self.mu = batch_mean
        self.v = batch_v
        self.sigma = batch_std
        # Update weight and bias with the mean and standard deviation to preserve unnormalized outputs.
        self.weight.data = (self.weight.data.t() * old_std / self.sigma).t()
        self.bias.data = (old_std * self.bias.data + old_mu - self.mu) / self.sigma

        return {'new_mean': batch_mean, 'new_std': batch_std}
```