ding.rl_utils.vtrace

vtrace_nstep_return(clipped_rhos, clipped_cs, reward, bootstrap_values, gamma=0.99, lambda_=0.95)

Overview

Computation of vtrace return.

Returns:
    - vtrace_return (:obj:`torch.FloatTensor`): the computed v-trace return, a differentiable tensor

Shapes:
    - clipped_rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep, B is batch size
    - clipped_cs (:obj:`torch.FloatTensor`): :math:`(T, B)`
    - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
    - bootstrap_values (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
    - vtrace_return (:obj:`torch.FloatTensor`): :math:`(T, B)`
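To see what the recursion computes, here is a self-contained sketch (the function body is copied from the source listing below; the toy inputs are hypothetical). With clipped_rhos = clipped_cs = 1 and lambda_ = 1, v-trace reduces to the ordinary on-policy n-step bootstrapped return:

```python
import torch


# Copy of vtrace_nstep_return from the source below, for standalone illustration.
def vtrace_nstep_return(clipped_rhos, clipped_cs, reward, bootstrap_values, gamma=0.99, lambda_=0.95):
    deltas = clipped_rhos * (reward + gamma * bootstrap_values[1:] - bootstrap_values[:-1])
    factor = gamma * lambda_
    result = bootstrap_values[:-1].clone()
    vtrace_item = 0.
    for t in reversed(range(reward.size()[0])):
        vtrace_item = deltas[t] + factor * clipped_cs[t] * vtrace_item
        result[t] += vtrace_item
    return result


T, B = 5, 3
ones = torch.ones(T, B)
reward = torch.rand(T, B)
values = torch.randn(T + 1, B)

# With rhos = cs = 1 and lambda_ = 1, the recursion satisfies
# vs[t] = r[t] + gamma * vs[t+1], bootstrapped by values[-1].
vs = vtrace_nstep_return(ones, ones, reward, values, gamma=0.9, lambda_=1.0)

# Compute the same n-step return directly and compare at t = 0.
expected = values[-1]
for t in reversed(range(T)):
    expected = reward[t] + 0.9 * expected
assert torch.allclose(vs[0], expected, atol=1e-5)
```

With non-trivial clipped weights, each temporal-difference term is scaled by rho and its propagation backwards in time is damped by the product of the c weights, which is what corrects for the off-policy gap between collector and learner.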

vtrace_advantage(clipped_pg_rhos, reward, return_, bootstrap_values, gamma)

Overview

Computation of vtrace advantage.

Returns:
    - vtrace_advantage (:obj:`torch.FloatTensor`): the computed v-trace advantage

Shapes:
    - clipped_pg_rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep, B is batch size
    - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
    - return (:obj:`torch.FloatTensor`): :math:`(T, B)`
    - bootstrap_values (:obj:`torch.FloatTensor`): :math:`(T, B)`
    - vtrace_advantage (:obj:`torch.FloatTensor`): :math:`(T, B)`
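In the notation of the IMPALA paper, reading `clipped_pg_rhos` as :math:`\bar{\rho}_t`, `return_` as the one-step-shifted v-trace target :math:`v_{t+1}`, and `bootstrap_values` as :math:`V(x_t)`, the single elementwise expression in the source computes:

```latex
\mathrm{adv}_t = \bar{\rho}_t \,\bigl( r_t + \gamma\, v_{t+1} - V(x_t) \bigr)
```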

shape_fn_vtrace_discrete_action(args, kwargs)

Overview

Return the shape of the v-trace input tensor for the hpc wrapper.

Returns:
    - shape: [T, B, N]

vtrace_error_discrete_action(data, gamma=0.99, lambda_=0.95, rho_clip_ratio=1.0, c_clip_ratio=1.0, rho_pg_clip_ratio=1.0)

Overview

Implementation of v-trace (IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures, arXiv:1802.01561).

Arguments:
    - data (:obj:`namedtuple`): input data with fields shown in ``vtrace_data``
        - target_output (:obj:`torch.Tensor`): the output of the current policy network for the taken action, usually the network output logit
        - behaviour_output (:obj:`torch.Tensor`): the output of the behaviour policy network for the taken action, usually the network output logit, which is used to produce the trajectory (collector)
        - action (:obj:`torch.Tensor`): the chosen action (index for the discrete action space) in the trajectory, i.e.: behaviour_action
    - gamma (:obj:`float`): the future discount factor, defaults to 0.99
    - lambda_ (:obj:`float`): mix factor between 1-step (lambda_=0) and n-step, defaults to 0.95
    - rho_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating the baseline targets (vs)
    - c_clip_ratio (:obj:`float`): the clipping threshold for importance weights (c) when calculating the baseline targets (vs)
    - rho_pg_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating the policy gradient advantage

Returns:
    - trace_loss (:obj:`namedtuple`): the vtrace loss item; all of them are differentiable 0-dim tensors

Shapes:
    - target_output (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where T is timestep, B is batch size and N is action dim
    - behaviour_output (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
    - action (:obj:`torch.LongTensor`): :math:`(T, B)`
    - value (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
    - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
    - weight (:obj:`torch.FloatTensor`): :math:`(T, B)`

Examples:
    >>> T, B, N = 4, 8, 16
    >>> value = torch.randn(T + 1, B).requires_grad_(True)
    >>> reward = torch.rand(T, B)
    >>> target_output = torch.randn(T, B, N).requires_grad_(True)
    >>> behaviour_output = torch.randn(T, B, N)
    >>> action = torch.randint(0, N, size=(T, B))
    >>> data = vtrace_data(target_output, behaviour_output, action, value, reward, None)
    >>> loss = vtrace_error_discrete_action(data, rho_clip_ratio=1.1)
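The importance weights clipped above come from ding's ``compute_importance_weights`` helper. Purely as an illustration of the discrete case (the actual helper may differ in detail), this hedged sketch computes the ratio pi_target / pi_behaviour with plain ``torch.distributions`` and applies the three clipping thresholds:

```python
import torch
from torch.distributions import Categorical

# Hypothetical toy shapes matching the docstring example.
T, B, N = 4, 8, 16
target_output = torch.randn(T, B, N)     # learner logits
behaviour_output = torch.randn(T, B, N)  # collector logits
action = torch.randint(0, N, size=(T, B))

with torch.no_grad():
    # rho_t = pi_target(a_t | x_t) / pi_behaviour(a_t | x_t), computed in log space.
    log_rho = Categorical(logits=target_output).log_prob(action) \
        - Categorical(logits=behaviour_output).log_prob(action)
    IS = log_rho.exp()  # (T, B) importance weights

# The three clip ratios default to 1.0, as in the signature above.
rhos = torch.clamp(IS, max=1.0)      # for the baseline targets (vs)
cs = torch.clamp(IS, max=1.0)        # for backward propagation of TD terms
pg_rhos = torch.clamp(IS, max=1.0)   # for the policy-gradient advantage
```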

vtrace_error_continuous_action(data, gamma=0.99, lambda_=0.95, rho_clip_ratio=1.0, c_clip_ratio=1.0, rho_pg_clip_ratio=1.0)

Overview

Implementation of v-trace (IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures, arXiv:1802.01561) for continuous action spaces.

Arguments:
    - data (:obj:`namedtuple`): input data with fields shown in ``vtrace_data``
        - target_output (:obj:`dict{key:torch.Tensor}`): the output of the current policy network for the taken action, usually the network output, which represents the action distribution via the reparameterization trick
        - behaviour_output (:obj:`dict{key:torch.Tensor}`): the output of the behaviour policy network for the taken action, usually the network output, which represents the action distribution via the reparameterization trick
        - action (:obj:`torch.Tensor`): the chosen action (a continuous action vector) in the trajectory, i.e.: behaviour_action
    - gamma (:obj:`float`): the future discount factor, defaults to 0.99
    - lambda_ (:obj:`float`): mix factor between 1-step (lambda_=0) and n-step, defaults to 0.95
    - rho_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating the baseline targets (vs)
    - c_clip_ratio (:obj:`float`): the clipping threshold for importance weights (c) when calculating the baseline targets (vs)
    - rho_pg_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating the policy gradient advantage

Returns:
    - trace_loss (:obj:`namedtuple`): the vtrace loss item; all of them are differentiable 0-dim tensors

Shapes:
    - target_output (:obj:`dict{key:torch.FloatTensor}`): :math:`(T, B, N)`, where T is timestep, B is batch size and N is action dim; the keys are usually parameters of the reparameterization trick
    - behaviour_output (:obj:`dict{key:torch.FloatTensor}`): :math:`(T, B, N)`
    - action (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
    - value (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
    - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
    - weight (:obj:`torch.FloatTensor`): :math:`(T, B)`

Examples:
    >>> T, B, N = 4, 8, 16
    >>> value = torch.randn(T + 1, B).requires_grad_(True)
    >>> reward = torch.rand(T, B)
    >>> target_output = dict(
    >>>     mu=torch.randn(T, B, N).requires_grad_(True),
    >>>     sigma=torch.exp(torch.randn(T, B, N).requires_grad_(True)),
    >>> )
    >>> behaviour_output = dict(
    >>>     mu=torch.randn(T, B, N),
    >>>     sigma=torch.exp(torch.randn(T, B, N)),
    >>> )
    >>> action = torch.randn((T, B, N))
    >>> data = vtrace_data(target_output, behaviour_output, action, value, reward, None)
    >>> loss = vtrace_error_continuous_action(data, rho_clip_ratio=1.1)
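Both error functions return a ``vtrace_loss`` namedtuple of three scalar terms, which a training loop typically combines into one scalar before backprop. A minimal sketch with illustrative coefficients (the weights below are hypothetical, not specified by this module):

```python
import torch

# Stand-ins for the three 0-dim tensors returned in vtrace_loss.
policy_loss = torch.tensor(0.5, requires_grad=True)
value_loss = torch.tensor(0.2, requires_grad=True)
entropy_loss = torch.tensor(1.0, requires_grad=True)

# Hypothetical coefficients: entropy is a bonus, so it is subtracted.
value_weight, entropy_weight = 0.5, 0.01
total_loss = policy_loss + value_weight * value_loss - entropy_weight * entropy_loss
total_loss.backward()
```

In practice the coefficients are tuning knobs of the surrounding IMPALA-style policy, not arguments of ``vtrace_error_discrete_action`` / ``vtrace_error_continuous_action`` themselves.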

Full Source Code

../ding/rl_utils/vtrace.py

import torch
import torch.nn.functional as F
from torch.distributions import Categorical, Independent, Normal
from collections import namedtuple
from .isw import compute_importance_weights
from ding.hpc_rl import hpc_wrapper


def vtrace_nstep_return(clipped_rhos, clipped_cs, reward, bootstrap_values, gamma=0.99, lambda_=0.95):
    """
    Overview:
        Computation of the v-trace n-step return.
    Returns:
        - vtrace_return (:obj:`torch.FloatTensor`): the computed v-trace return, a differentiable tensor
    Shapes:
        - clipped_rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep, B is batch size
        - clipped_cs (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - bootstrap_values (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
        - vtrace_return (:obj:`torch.FloatTensor`): :math:`(T, B)`
    """
    deltas = clipped_rhos * (reward + gamma * bootstrap_values[1:] - bootstrap_values[:-1])
    factor = gamma * lambda_
    result = bootstrap_values[:-1].clone()
    vtrace_item = 0.
    for t in reversed(range(reward.size()[0])):
        vtrace_item = deltas[t] + factor * clipped_cs[t] * vtrace_item
        result[t] += vtrace_item
    return result


def vtrace_advantage(clipped_pg_rhos, reward, return_, bootstrap_values, gamma):
    """
    Overview:
        Computation of the v-trace advantage.
    Returns:
        - vtrace_advantage (:obj:`torch.FloatTensor`): the computed v-trace advantage
    Shapes:
        - clipped_pg_rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep, B is batch size
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - return (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - bootstrap_values (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - vtrace_advantage (:obj:`torch.FloatTensor`): :math:`(T, B)`
    """
    return clipped_pg_rhos * (reward + gamma * return_ - bootstrap_values)


vtrace_data = namedtuple('vtrace_data', ['target_output', 'behaviour_output', 'action', 'value', 'reward', 'weight'])
vtrace_loss = namedtuple('vtrace_loss', ['policy_loss', 'value_loss', 'entropy_loss'])


def shape_fn_vtrace_discrete_action(args, kwargs):
    r"""
    Overview:
        Return the shape of the v-trace input tensor for the hpc wrapper.
    Returns:
        shape: [T, B, N]
    """
    if len(args) <= 0:
        tmp = kwargs['data'].target_output.shape
    else:
        tmp = args[0].target_output.shape
    return tmp


@hpc_wrapper(
    shape_fn=shape_fn_vtrace_discrete_action,
    namedtuple_data=True,
    include_args=[0, 1, 2, 3, 4, 5],
    include_kwargs=['data', 'gamma', 'lambda_', 'rho_clip_ratio', 'c_clip_ratio', 'rho_pg_clip_ratio']
)
def vtrace_error_discrete_action(
    data: namedtuple,
    gamma: float = 0.99,
    lambda_: float = 0.95,
    rho_clip_ratio: float = 1.0,
    c_clip_ratio: float = 1.0,
    rho_pg_clip_ratio: float = 1.0
):
    """
    Overview:
        Implementation of v-trace (IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner
        Architectures, arXiv:1802.01561).
    Arguments:
        - data (:obj:`namedtuple`): input data with fields shown in ``vtrace_data``
            - target_output (:obj:`torch.Tensor`): the output of the current policy network for the taken action,
              usually the network output logit
            - behaviour_output (:obj:`torch.Tensor`): the output of the behaviour policy network for the taken
              action, usually the network output logit, which is used to produce the trajectory (collector)
            - action (:obj:`torch.Tensor`): the chosen action (index for the discrete action space) in the
              trajectory, i.e.: behaviour_action
        - gamma (:obj:`float`): the future discount factor, defaults to 0.99
        - lambda_ (:obj:`float`): mix factor between 1-step (lambda_=0) and n-step, defaults to 0.95
        - rho_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating
          the baseline targets (vs)
        - c_clip_ratio (:obj:`float`): the clipping threshold for importance weights (c) when calculating
          the baseline targets (vs)
        - rho_pg_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating
          the policy gradient advantage
    Returns:
        - trace_loss (:obj:`namedtuple`): the vtrace loss item; all of them are differentiable 0-dim tensors
    Shapes:
        - target_output (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where T is timestep, B is batch size and
          N is action dim
        - behaviour_output (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(T, B)`
        - value (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - weight (:obj:`torch.FloatTensor`): :math:`(T, B)`
    Examples:
        >>> T, B, N = 4, 8, 16
        >>> value = torch.randn(T + 1, B).requires_grad_(True)
        >>> reward = torch.rand(T, B)
        >>> target_output = torch.randn(T, B, N).requires_grad_(True)
        >>> behaviour_output = torch.randn(T, B, N)
        >>> action = torch.randint(0, N, size=(T, B))
        >>> data = vtrace_data(target_output, behaviour_output, action, value, reward, None)
        >>> loss = vtrace_error_discrete_action(data, rho_clip_ratio=1.1)
    """
    target_output, behaviour_output, action, value, reward, weight = data
    with torch.no_grad():
        IS = compute_importance_weights(target_output, behaviour_output, action, 'discrete')
        rhos = torch.clamp(IS, max=rho_clip_ratio)
        cs = torch.clamp(IS, max=c_clip_ratio)
        return_ = vtrace_nstep_return(rhos, cs, reward, value, gamma, lambda_)
        pg_rhos = torch.clamp(IS, max=rho_pg_clip_ratio)
        return_t_plus_1 = torch.cat([return_[1:], value[-1:]], 0)
        adv = vtrace_advantage(pg_rhos, reward, return_t_plus_1, value[:-1], gamma)

    if weight is None:
        weight = torch.ones_like(reward)
    dist_target = Categorical(logits=target_output)
    pg_loss = -(dist_target.log_prob(action) * adv * weight).mean()
    value_loss = (F.mse_loss(value[:-1], return_, reduction='none') * weight).mean()
    entropy_loss = (dist_target.entropy() * weight).mean()
    return vtrace_loss(pg_loss, value_loss, entropy_loss)


def vtrace_error_continuous_action(
    data: namedtuple,
    gamma: float = 0.99,
    lambda_: float = 0.95,
    rho_clip_ratio: float = 1.0,
    c_clip_ratio: float = 1.0,
    rho_pg_clip_ratio: float = 1.0
):
    """
    Overview:
        Implementation of v-trace (IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner
        Architectures, arXiv:1802.01561) for continuous action spaces.
    Arguments:
        - data (:obj:`namedtuple`): input data with fields shown in ``vtrace_data``
            - target_output (:obj:`dict{key:torch.Tensor}`): the output of the current policy network for the
              taken action, usually the network output, which represents the action distribution via the
              reparameterization trick
            - behaviour_output (:obj:`dict{key:torch.Tensor}`): the output of the behaviour policy network for
              the taken action, usually the network output, which represents the action distribution via the
              reparameterization trick
            - action (:obj:`torch.Tensor`): the chosen action (a continuous action vector) in the trajectory,
              i.e.: behaviour_action
        - gamma (:obj:`float`): the future discount factor, defaults to 0.99
        - lambda_ (:obj:`float`): mix factor between 1-step (lambda_=0) and n-step, defaults to 0.95
        - rho_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating
          the baseline targets (vs)
        - c_clip_ratio (:obj:`float`): the clipping threshold for importance weights (c) when calculating
          the baseline targets (vs)
        - rho_pg_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating
          the policy gradient advantage
    Returns:
        - trace_loss (:obj:`namedtuple`): the vtrace loss item; all of them are differentiable 0-dim tensors
    Shapes:
        - target_output (:obj:`dict{key:torch.FloatTensor}`): :math:`(T, B, N)`, where T is timestep, B is batch
          size and N is action dim; the keys are usually parameters of the reparameterization trick
        - behaviour_output (:obj:`dict{key:torch.FloatTensor}`): :math:`(T, B, N)`
        - action (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - value (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - weight (:obj:`torch.FloatTensor`): :math:`(T, B)`
    Examples:
        >>> T, B, N = 4, 8, 16
        >>> value = torch.randn(T + 1, B).requires_grad_(True)
        >>> reward = torch.rand(T, B)
        >>> target_output = dict(
        >>>     mu=torch.randn(T, B, N).requires_grad_(True),
        >>>     sigma=torch.exp(torch.randn(T, B, N).requires_grad_(True)),
        >>> )
        >>> behaviour_output = dict(
        >>>     mu=torch.randn(T, B, N),
        >>>     sigma=torch.exp(torch.randn(T, B, N)),
        >>> )
        >>> action = torch.randn((T, B, N))
        >>> data = vtrace_data(target_output, behaviour_output, action, value, reward, None)
        >>> loss = vtrace_error_continuous_action(data, rho_clip_ratio=1.1)
    """
    target_output, behaviour_output, action, value, reward, weight = data
    with torch.no_grad():
        IS = compute_importance_weights(target_output, behaviour_output, action, 'continuous')
        rhos = torch.clamp(IS, max=rho_clip_ratio)
        cs = torch.clamp(IS, max=c_clip_ratio)
        return_ = vtrace_nstep_return(rhos, cs, reward, value, gamma, lambda_)
        pg_rhos = torch.clamp(IS, max=rho_pg_clip_ratio)
        return_t_plus_1 = torch.cat([return_[1:], value[-1:]], 0)
        adv = vtrace_advantage(pg_rhos, reward, return_t_plus_1, value[:-1], gamma)

    if weight is None:
        weight = torch.ones_like(reward)
    dist_target = Independent(Normal(loc=target_output['mu'], scale=target_output['sigma']), 1)
    pg_loss = -(dist_target.log_prob(action) * adv * weight).mean()
    value_loss = (F.mse_loss(value[:-1], return_, reduction='none') * weight).mean()
    entropy_loss = (dist_target.entropy() * weight).mean()
    return vtrace_loss(pg_loss, value_loss, entropy_loss)