
ding.rl_utils.upgo

tb_cross_entropy(logit, label, mask=None)

Overview

Compute the cross-entropy loss between the label and logit tensors, with optional mask support.

Arguments:
    - logit (:obj:`torch.Tensor`): the logit tensor, of size [T, B, N] or [T, B, N, N2]
    - label (:obj:`torch.Tensor`): the label tensor, of size [T, B] or [T, B, N2]
    - mask (:obj:`torch.Tensor` or :obj:`None`): the mask tensor, of size [T, B] or [T, B, N2]

Returns:
    - ce (:obj:`torch.Tensor`): the computed cross entropy, of size [T, B]

Examples:
    >>> T, B, N, N2 = 4, 8, 5, 7
    >>> logit = torch.randn(T, B, N, N2).softmax(-1).requires_grad_(True)
    >>> action = logit.argmax(-1).detach()
    >>> ce = tb_cross_entropy(logit, action)
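Note the sign convention: the returned value is the negated cross entropy, i.e. the log-probability of the chosen label. As a rough illustration of the unmasked [T, B, N] case, here is a pure-Python sketch (a hypothetical helper, not part of ding; it assumes `probs` already holds normalized action probabilities, whereas the real function works on logits):

```python
import math

def chosen_log_probs(probs, label):
    # probs: [T][B][N] normalized action distributions
    # label: [T][B] chosen action indices
    # Returns the log-probability of each chosen action, shape [T][B],
    # matching the sign convention ce = -cross_entropy above.
    return [
        [math.log(probs[t][b][label[t][b]]) for b in range(len(probs[t]))]
        for t in range(len(probs))
    ]
```

The real `tb_cross_entropy` additionally flattens the leading dimensions, supports a second label dimension N2, and applies the optional mask before reducing.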

upgo_returns(rewards, bootstrap_values)

Overview

Compute UPGO return targets. Note that there is no special handling for the terminal state.

Arguments:
    - rewards (:obj:`torch.Tensor`): the rewards from time step 0 to T-1, of size [T_traj, batchsize]
    - bootstrap_values (:obj:`torch.Tensor`): estimation of the state value at steps 0 to T, of size [T_traj+1, batchsize]

Returns:
    - ret (:obj:`torch.Tensor`): the computed lambda return value for each state from 0 to T-1, of size [T_traj, batchsize]

Examples:
    >>> T, B = 4, 8
    >>> rewards = torch.randn(T, B)
    >>> bootstrap_values = torch.randn(T + 1, B).requires_grad_(True)
    >>> returns = upgo_returns(rewards, bootstrap_values)
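The UPGO target follows the sampled trajectory while the next step looks at least as good as the value estimate, and otherwise falls back to the one-step bootstrap. A minimal pure-Python reference of that recursion (a hypothetical helper for a single trajectory, no batch dimension; the library version implements the same thing via per-step lambdas and generalized_lambda_returns):

```python
def upgo_returns_ref(rewards, values):
    # rewards: list of length T, values: list of length T+1 (last entry is the bootstrap)
    T = len(rewards)
    returns = [0.0] * T
    # the last step always bootstraps from the final value estimate
    returns[T - 1] = rewards[T - 1] + values[T]
    for t in range(T - 2, -1, -1):
        if rewards[t + 1] + values[t + 2] >= values[t + 1]:
            # the next step was at least as good as expected: keep following the trace
            returns[t] = rewards[t] + returns[t + 1]
        else:
            # otherwise fall back to the one-step bootstrap
            returns[t] = rewards[t] + values[t + 1]
    return returns
```

For example, with `rewards = [1.0, 0.0]` and `values = [0.0, 0.0, 5.0]`, the trace is followed at every step and the targets are `[6.0, 5.0]`.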

upgo_loss(target_output, rhos, action, rewards, bootstrap_values, mask=None)

Overview

Compute the UPGO loss given constant gamma and lambda. There is no special handling for the terminal state value; if the last state in the trajectory is terminal, simply pass 0 as the last entry of bootstrap_values.

Arguments:
    - target_output (:obj:`torch.Tensor`): the output computed by the target policy network, of size [T_traj, batchsize, n_output]
    - rhos (:obj:`torch.Tensor`): the importance sampling ratio, of size [T_traj, batchsize]
    - action (:obj:`torch.Tensor`): the action taken, of size [T_traj, batchsize]
    - rewards (:obj:`torch.Tensor`): the rewards from time step 0 to T-1, of size [T_traj, batchsize]
    - bootstrap_values (:obj:`torch.Tensor`): estimation of the state value at steps 0 to T, of size [T_traj+1, batchsize]

Returns:
    - loss (:obj:`torch.Tensor`): the computed importance-sampled UPGO loss, averaged over the samples, of size []

Examples:
    >>> T, B, N = 4, 8, 5
    >>> target_output = torch.randn(T, B, N).requires_grad_(True)
    >>> rhos = torch.randn(T, B)
    >>> action = torch.randint(0, N, (T, B))
    >>> rewards = torch.randn(T, B)
    >>> bootstrap_values = torch.randn(T + 1, B).requires_grad_(True)
    >>> loss = upgo_loss(target_output, rhos, action, rewards, bootstrap_values)
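Putting the pieces together: the loss is a policy-gradient term in which the UPGO advantage is held constant (computed under no_grad in the real code) and the log-probability of the taken action is the differentiable factor. A pure-Python sketch for a single trajectory (hypothetical helper names, no batch dimension, gradients omitted):

```python
def upgo_loss_ref(log_probs, rhos, returns, values):
    # log_probs: [T] log pi(a_t | s_t) for the taken actions (the differentiable term)
    # rhos:      [T] importance sampling ratios, treated as constants
    # returns:   [T] UPGO return targets
    # values:    [T+1] value estimates; only values[0:T] enter the advantage
    T = len(log_probs)
    total = 0.0
    for t in range(T):
        advantage = rhos[t] * (returns[t] - values[t])  # no-grad in the real code
        total += advantage * log_probs[t]
    # negative mean: minimizing this ascends the advantage-weighted log-probability
    return -total / T
```

This mirrors `advantages = rhos * (returns - bootstrap_values[:-1])` followed by `-(advantages * metric).mean()` in the library source, where `metric` is the per-step log-probability from `tb_cross_entropy`.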

Full Source Code

../ding/rl_utils/upgo.py

import torch
import torch.nn.functional as F
from ding.hpc_rl import hpc_wrapper
from .td import generalized_lambda_returns


def tb_cross_entropy(logit, label, mask=None):
    """
    Overview:
        Compute the cross entropy loss for label and logit, with mask support
    Arguments:
        - logit (:obj:`torch.Tensor`): the logit tensor, of size [T, B, N] or [T, B, N, N2]
        - label (:obj:`torch.Tensor`): the label tensor, of size [T, B] or [T, B, N2]
        - mask (:obj:`torch.Tensor` or :obj:`None`): the mask tensor, of size [T, B] or [T, B, N2]
    Returns:
        - ce (:obj:`torch.Tensor`): the computed cross entropy, of size [T, B]
    Examples:
        >>> T, B, N, N2 = 4, 8, 5, 7
        >>> logit = torch.randn(T, B, N, N2).softmax(-1).requires_grad_(True)
        >>> action = logit.argmax(-1).detach()
        >>> ce = tb_cross_entropy(logit, action)
    """
    assert (len(label.shape) >= 2)
    T, B = label.shape[:2]
    # Special 2D case
    if len(label.shape) > 2:
        assert len(label.shape) == 3
        s, n = logit.shape[-2:]
        logit = logit.reshape(-1, n)
        label = label.reshape(-1)
        ce = -F.cross_entropy(logit, label, reduction='none')
        ce = ce.view(T * B, -1)
        if mask is not None:
            ce *= mask.reshape(-1, s)
        ce = ce.sum(dim=1)
        ce = ce.reshape(T, B)
    else:
        label = label.reshape(-1)
        logit = logit.reshape(-1, logit.shape[-1])
        ce = -F.cross_entropy(logit, label, reduction='none')
        ce = ce.reshape(T, B, -1)
        ce = ce.mean(dim=2)
    return ce


def upgo_returns(rewards: torch.Tensor, bootstrap_values: torch.Tensor) -> torch.Tensor:
    """
    Overview:
        Computing UPGO return targets. Also notice there is no special handling for the terminal state.
    Arguments:
        - rewards (:obj:`torch.Tensor`): the returns from time step 0 to T-1, \
            of size [T_traj, batchsize]
        - bootstrap_values (:obj:`torch.Tensor`): estimation of the state value at step 0 to T, \
            of size [T_traj+1, batchsize]
    Returns:
        - ret (:obj:`torch.Tensor`): Computed lambda return value for each state from 0 to T-1, \
            of size [T_traj, batchsize]
    Examples:
        >>> T, B, N, N2 = 4, 8, 5, 7
        >>> rewards = torch.randn(T, B)
        >>> bootstrap_values = torch.randn(T + 1, B).requires_grad_(True)
        >>> returns = upgo_returns(rewards, bootstrap_values)
    """
    # UPGO can be viewed as a lambda return! The trace continues for V_t (i.e. lambda = 1.0) if r_tp1 + V_tp2 > V_tp1.
    # as the lambdas[-1, :] is ignored in generalized_lambda_returns, we don't care about bootstrap_values_tp2[-1]
    lambdas = (rewards + bootstrap_values[1:]) >= bootstrap_values[:-1]
    lambdas = torch.cat([lambdas[1:], torch.ones_like(lambdas[-1:])], dim=0)
    return generalized_lambda_returns(bootstrap_values, rewards, 1.0, lambdas)


@hpc_wrapper(
    shape_fn=lambda args: args[0].shape,
    namedtuple_data=True,
    include_args=5,
    include_kwargs=['target_output', 'rhos', 'action', 'rewards', 'bootstrap_values']
)
def upgo_loss(
        target_output: torch.Tensor,
        rhos: torch.Tensor,
        action: torch.Tensor,
        rewards: torch.Tensor,
        bootstrap_values: torch.Tensor,
        mask=None
) -> torch.Tensor:
    """
    Overview:
        Computing UPGO loss given constant gamma and lambda. There is no special handling for terminal state value,
        if the last state in trajectory is the terminal, just pass a 0 as bootstrap_terminal_value.
    Arguments:
        - target_output (:obj:`torch.Tensor`): the output computed by the target policy network, \
            of size [T_traj, batchsize, n_output]
        - rhos (:obj:`torch.Tensor`): the importance sampling ratio, of size [T_traj, batchsize]
        - action (:obj:`torch.Tensor`): the action taken, of size [T_traj, batchsize]
        - rewards (:obj:`torch.Tensor`): the returns from time step 0 to T-1, of size [T_traj, batchsize]
        - bootstrap_values (:obj:`torch.Tensor`): estimation of the state value at step 0 to T, \
            of size [T_traj+1, batchsize]
    Returns:
        - loss (:obj:`torch.Tensor`): Computed importance sampled UPGO loss, averaged over the samples, of size []
    Examples:
        >>> T, B, N, N2 = 4, 8, 5, 7
        >>> rhos = torch.randn(T, B)
        >>> loss = upgo_loss(logit, rhos, action, rewards, bootstrap_values)
    """
    # discard the value at T as it should be considered in the next slice
    with torch.no_grad():
        returns = upgo_returns(rewards, bootstrap_values)
        advantages = rhos * (returns - bootstrap_values[:-1])
    metric = tb_cross_entropy(target_output, action, mask)
    assert (metric.shape == action.shape[:2])
    losses = advantages * metric
    return -losses.mean()