ding.rl_utils.upgo
tb_cross_entropy(logit, label, mask=None)
Overview
Compute the cross-entropy loss between label and logit, with optional mask support.
Arguments:
- logit (:obj:torch.Tensor): the logit tensor, of size [T, B, N] or [T, B, N, N2]
- label (:obj:torch.Tensor): the label tensor, of size [T, B] or [T, B, N2]
- mask (:obj:torch.Tensor or :obj:None): the mask tensor, of size [T, B] or [T, B, N2]
Returns:
- ce (:obj:torch.Tensor): the computed cross entropy, of size [T, B]
Examples:
>>> T, B, N, N2 = 4, 8, 5, 7
>>> logit = torch.randn(T, B, N, N2).softmax(-1).requires_grad_(True)
>>> action = logit.argmax(-1).detach()
>>> ce = tb_cross_entropy(logit, action)
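For the 3-D case, the computation above can be sketched as plain per-timestep cross entropy with an optional mask. This is a simplified illustration under stated assumptions, not the library implementation (the real `tb_cross_entropy` also handles the 4-D `[T, B, N, N2]` case and may use its own sign convention):

```python
import torch
import torch.nn.functional as F

def tb_cross_entropy_sketch(logit, label, mask=None):
    """Per-timestep cross entropy for logit [T, B, N] and label [T, B].

    Simplified sketch; the actual tb_cross_entropy also supports the
    4-D [T, B, N, N2] case and may differ in sign convention.
    """
    T, B, N = logit.shape
    # F.cross_entropy expects [batch, classes], so flatten time and batch dims
    ce = F.cross_entropy(logit.reshape(-1, N), label.reshape(-1), reduction='none')
    ce = ce.reshape(T, B)
    if mask is not None:
        ce = ce * mask  # zero out padded / invalid timesteps
    return ce

T, B, N = 2, 3, 4
logit = torch.zeros(T, B, N)                  # uniform distribution over N classes
label = torch.zeros(T, B, dtype=torch.long)
ce = tb_cross_entropy_sketch(logit, label)
print(ce.shape)  # torch.Size([2, 3]); each entry equals log(4)
```

With uniform logits every entry is `-log(1/N) = log(N)`, which makes the sketch easy to sanity-check.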
upgo_returns(rewards, bootstrap_values)
Overview
Compute UPGO return targets. Note that there is no special handling for the terminal state.
Arguments:
- rewards (:obj:torch.Tensor): the rewards from time step 0 to T-1, of size [T_traj, batchsize]
- bootstrap_values (:obj:torch.Tensor): estimation of the state value at step 0 to T, of size [T_traj+1, batchsize]
Returns:
- ret (:obj:torch.Tensor): Computed lambda return value for each state from 0 to T-1, of size [T_traj, batchsize]
Examples:
>>> T, B, N, N2 = 4, 8, 5, 7
>>> rewards = torch.randn(T, B)
>>> bootstrap_values = torch.randn(T + 1, B).requires_grad_(True)
>>> returns = upgo_returns(rewards, bootstrap_values)
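The UPGO target can be written recursively: the sampled return is kept while the trajectory keeps beating the baseline (`r_{t+1} + V_{t+2} >= V_{t+1}`), and otherwise the target cuts to the bootstrapped value `V_{t+1}`. The sketch below is a hypothetical reference implementation of this AlphaStar-style rule, assuming the shapes documented above; the library's own version may compute the same targets differently (e.g. vectorized as lambda returns):

```python
import torch

def upgo_returns_sketch(rewards: torch.Tensor, bootstrap_values: torch.Tensor) -> torch.Tensor:
    """Recursive UPGO targets (illustrative sketch, not the library code).

    rewards: [T, B]; bootstrap_values: [T+1, B]; result: [T, B].
    """
    T = rewards.shape[0]
    ret = torch.empty_like(rewards)
    # Last step bootstraps from the final value estimate.
    ret[T - 1] = rewards[T - 1] + bootstrap_values[T]
    for t in reversed(range(T - 1)):
        # Keep the sampled return only while it beats the baseline:
        # if r_{t+1} + V_{t+2} >= V_{t+1}, continue the trace; else cut to V_{t+1}.
        keep = (rewards[t + 1] + bootstrap_values[t + 2] >= bootstrap_values[t + 1]).float()
        ret[t] = rewards[t] + keep * ret[t + 1] + (1.0 - keep) * bootstrap_values[t + 1]
    return ret

T, B = 4, 8
rewards = torch.randn(T, B)
values = torch.randn(T + 1, B)
targets = upgo_returns_sketch(rewards, values)
print(targets.shape)  # torch.Size([4, 8])
```

Because the trace is cut whenever the behaviour falls below the baseline, UPGO only propagates returns along better-than-expected action sequences.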
upgo_loss(target_output, rhos, action, rewards, bootstrap_values, mask=None)
Overview
Compute the UPGO loss given constant gamma and lambda. There is no special handling for the terminal state value; if the last state in the trajectory is terminal, simply pass 0 as the terminal bootstrap value.
Arguments:
- target_output (:obj:torch.Tensor): the output computed by the target policy network, of size [T_traj, batchsize, n_output]
- rhos (:obj:torch.Tensor): the importance sampling ratio, of size [T_traj, batchsize]
- action (:obj:torch.Tensor): the action taken, of size [T_traj, batchsize]
- rewards (:obj:torch.Tensor): the rewards from time step 0 to T-1, of size [T_traj, batchsize]
- bootstrap_values (:obj:torch.Tensor): estimation of the state value at step 0 to T, of size [T_traj+1, batchsize]
Returns:
- loss (:obj:torch.Tensor): the computed importance-sampled UPGO loss, averaged over the samples; a scalar of size []
Examples:
>>> T, B, N = 4, 8, 5
>>> logit = torch.randn(T, B, N).requires_grad_(True)
>>> action, rhos = logit.argmax(-1).detach(), torch.randn(T, B)
>>> rewards, bootstrap_values = torch.randn(T, B), torch.randn(T + 1, B)
>>> loss = upgo_loss(logit, rhos, action, rewards, bootstrap_values)
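Putting the pieces together, an importance-weighted UPGO policy-gradient loss can be sketched as below: the precomputed UPGO return minus the baseline value gives the advantage, which weights the log-probability of the taken action. The function name and details (e.g. taking precomputed returns as input rather than rewards) are illustrative assumptions, not the exact library code:

```python
import torch

def upgo_loss_sketch(target_output, rhos, action, returns, bootstrap_values):
    """Importance-weighted policy-gradient loss on precomputed UPGO returns.

    target_output: [T, B, N] policy logits; rhos, action, returns: [T, B];
    bootstrap_values: [T+1, B]. Illustrative sketch, not the library code.
    """
    # Advantage of the UPGO target over the baseline, treated as a constant
    # target so no gradient flows into the return computation.
    advantages = (returns - bootstrap_values[:-1]).detach()
    log_probs = torch.log_softmax(target_output, dim=-1)
    # Log-probability of the action actually taken.
    action_logp = log_probs.gather(-1, action.unsqueeze(-1)).squeeze(-1)
    # Negative sign: minimizing this loss ascends the weighted log-likelihood.
    return -(rhos.detach() * action_logp * advantages).mean()

T, B, N = 4, 8, 5
logits = torch.randn(T, B, N, requires_grad=True)
rhos = torch.ones(T, B)
action = torch.randint(0, N, (T, B))
returns = torch.randn(T, B)
values = torch.randn(T + 1, B)
loss = upgo_loss_sketch(logits, rhos, action, returns, values)
loss.backward()  # gradients reach only the policy logits
print(loss.shape)  # scalar: torch.Size([])
```

Detaching both the advantages and the importance ratios ensures the loss is a pure policy-gradient term; the value function is trained by a separate critic loss.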
Full Source Code
../ding/rl_utils/upgo.py