ding.rl_utils.retrace
compute_q_retraces(q_values, v_pred, rewards, actions, weights, ratio, gamma=0.9)
Shapes:
    - q_values (:obj:`torch.Tensor`): :math:`(T + 1, B, N)`, where T is unroll_len, B is batch size, N is discrete action dim.
    - v_pred (:obj:`torch.Tensor`): :math:`(T + 1, B, 1)`
    - rewards (:obj:`torch.Tensor`): :math:`(T, B)`
    - actions (:obj:`torch.Tensor`): :math:`(T, B)`
    - weights (:obj:`torch.Tensor`): :math:`(T, B)`
    - ratio (:obj:`torch.Tensor`): :math:`(T, B, N)`
    - q_retraces (:obj:`torch.Tensor`): :math:`(T + 1, B, 1)`
Examples:
    >>> T = 2
    >>> B = 3
    >>> N = 4
    >>> q_values = torch.randn(T + 1, B, N)
    >>> v_pred = torch.randn(T + 1, B, 1)
    >>> rewards = torch.randn(T, B)
    >>> actions = torch.randint(0, N, (T, B))
    >>> weights = torch.ones(T, B)
    >>> ratio = torch.randn(T, B, N)
    >>> q_retraces = compute_q_retraces(q_values, v_pred, rewards, actions, weights, ratio)
.. note:: The q_retrace operation does not require gradient computation; it only performs the forward pass.
Full Source Code
../ding/rl_utils/retrace.py
import torch
import torch.nn.functional as F
from collections import namedtuple
from ding.rl_utils.isw import compute_importance_weights


def compute_q_retraces(
        q_values: torch.Tensor,
        v_pred: torch.Tensor,
        rewards: torch.Tensor,
        actions: torch.Tensor,
        weights: torch.Tensor,
        ratio: torch.Tensor,
        gamma: float = 0.9
) -> torch.Tensor:
    """
    Shapes:
        - q_values (:obj:`torch.Tensor`): :math:`(T + 1, B, N)`, where T is unroll_len, B is batch size, N is discrete \
            action dim.
        - v_pred (:obj:`torch.Tensor`): :math:`(T + 1, B, 1)`
        - rewards (:obj:`torch.Tensor`): :math:`(T, B)`
        - actions (:obj:`torch.Tensor`): :math:`(T, B)`
        - weights (:obj:`torch.Tensor`): :math:`(T, B)`
        - ratio (:obj:`torch.Tensor`): :math:`(T, B, N)`
        - q_retraces (:obj:`torch.Tensor`): :math:`(T + 1, B, 1)`
    Examples:
        >>> T = 2
        >>> B = 3
        >>> N = 4
        >>> q_values = torch.randn(T + 1, B, N)
        >>> v_pred = torch.randn(T + 1, B, 1)
        >>> rewards = torch.randn(T, B)
        >>> actions = torch.randint(0, N, (T, B))
        >>> weights = torch.ones(T, B)
        >>> ratio = torch.randn(T, B, N)
        >>> q_retraces = compute_q_retraces(q_values, v_pred, rewards, actions, weights, ratio)

    .. note::
        The q_retrace operation does not require gradient computation; it only performs the forward pass.
    """
    T = q_values.size()[0] - 1
    rewards = rewards.unsqueeze(-1)
    actions = actions.unsqueeze(-1)
    weights = weights.unsqueeze(-1)
    q_retraces = torch.zeros_like(v_pred)  # shape (T+1),B,1
    tmp_retraces = v_pred[-1]  # shape B,1
    q_retraces[-1] = v_pred[-1]

    q_gather = torch.zeros_like(v_pred)  # shape (T+1),B,1
    q_gather[0:-1] = q_values[0:-1].gather(-1, actions)  # filled entries: shape T,B,1
    ratio_gather = ratio.gather(-1, actions)  # shape T,B,1

    # Backward recursion: Q_ret(s_t, a_t) = r_t + gamma * [c_t * (Q_ret(s_{t+1}, a_{t+1})
    # - Q(s_{t+1}, a_{t+1})) + V(s_{t+1})], with c_t = min(ratio_t, 1).
    for idx in reversed(range(T)):
        q_retraces[idx] = rewards[idx] + gamma * weights[idx] * tmp_retraces
        tmp_retraces = ratio_gather[idx].clamp(max=1.0) * (q_retraces[idx] - q_gather[idx]) + v_pred[idx]
    return q_retraces  # shape (T+1),B,1
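To make the backward recursion easy to check by hand, here is a minimal pure-Python sketch of the same update for a single trajectory (B = 1), assuming the importance ratios have already been gathered for the actions actually taken. The function name `retrace_targets` is illustrative, not part of the ding API.

```python
def retrace_targets(q_taken, v_pred, rewards, weights, ratios, gamma=0.9):
    """Compute Q-retrace targets for one trajectory (illustrative sketch).

    q_taken: Q(s_t, a_t) for t = 0..T-1 (length T)
    v_pred:  V(s_t) for t = 0..T (length T + 1)
    rewards, weights, ratios: length T; ratios[t] = pi(a_t|s_t) / mu(a_t|s_t)
    Returns a list of length T + 1; the last entry bootstraps from V(s_T).
    """
    T = len(rewards)
    targets = [0.0] * (T + 1)
    targets[-1] = v_pred[-1]
    tmp = v_pred[-1]
    for t in reversed(range(T)):
        # Same recursion as the torch loop above, one scalar at a time.
        targets[t] = rewards[t] + gamma * weights[t] * tmp
        c = min(ratios[t], 1.0)  # truncated importance weight, like .clamp(max=1.0)
        tmp = c * (targets[t] - q_taken[t]) + v_pred[t]
    return targets
```

For example, with T = 2, rewards [1.0, 2.0], unit weights, v_pred [0.5, 0.4, 0.3], q_taken [0.2, 0.1], and ratios [0.5, 2.0], the recursion gives targets[2] = 0.3, targets[1] = 2.0 + 0.9 * 0.3 = 2.27, and targets[0] = 1.0 + 0.9 * (2.27 - 0.1 + 0.4) = 3.313 (the second ratio is truncated to 1).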