ding.rl_utils.vtrace
vtrace_nstep_return(clipped_rhos, clipped_cs, reward, bootstrap_values, gamma=0.99, lambda_=0.95)
Overview
Computation of vtrace return.
Returns:
- vtrace_return (:obj:torch.FloatTensor): the computed v-trace return (baseline target), a differentiable tensor
Shapes:
- clipped_rhos (:obj:torch.FloatTensor): :math:(T, B), where T is timestep, B is batch size
- clipped_cs (:obj:torch.FloatTensor): :math:(T, B)
- reward (:obj:torch.FloatTensor): :math:(T, B)
- bootstrap_values (:obj:torch.FloatTensor): :math:(T+1, B)
- vtrace_return (:obj:torch.FloatTensor): :math:(T, B)
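A minimal sketch of this computation, assuming the shapes listed above (illustrative only, not necessarily identical to the library's implementation):

import torch

def vtrace_nstep_return_sketch(clipped_rhos, clipped_cs, reward, bootstrap_values,
                               gamma=0.99, lambda_=0.95):
    # one-step TD errors scaled by the clipped importance weights rho
    deltas = clipped_rhos * (reward + gamma * bootstrap_values[1:] - bootstrap_values[:-1])
    # accumulate the corrections backwards in time:
    # acc_t = delta_t + gamma * lambda_ * c_t * acc_{t+1}
    acc = torch.zeros_like(bootstrap_values[-1])
    corrections = []
    for t in reversed(range(reward.size(0))):
        acc = deltas[t] + gamma * lambda_ * clipped_cs[t] * acc
        corrections.append(acc)
    corrections.reverse()
    # v-trace target: vs = V(x_s) + accumulated correction, shape (T, B)
    return bootstrap_values[:-1] + torch.stack(corrections)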
vtrace_advantage(clipped_pg_rhos, reward, return_, bootstrap_values, gamma)
Overview
Computation of vtrace advantage.
Returns:
- vtrace_advantage (:obj:torch.FloatTensor): the computed v-trace advantage for the policy gradient, a differentiable tensor
Shapes:
- clipped_pg_rhos (:obj:torch.FloatTensor): :math:(T, B), where T is timestep, B is batch size
- reward (:obj:torch.FloatTensor): :math:(T, B)
- return_ (:obj:torch.FloatTensor): :math:(T, B)
- bootstrap_values (:obj:torch.FloatTensor): :math:(T, B)
- vtrace_advantage (:obj:torch.FloatTensor): :math:(T, B)
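Under the assumption that return_ holds the one-step-shifted v-trace targets and bootstrap_values holds V(x_t), the advantage reduces to a single expression (illustrative sketch):

def vtrace_advantage_sketch(clipped_pg_rhos, reward, return_, bootstrap_values, gamma=0.99):
    # policy-gradient advantage: rho_t * (r_t + gamma * v_{t+1} - V(x_t))
    return clipped_pg_rhos * (reward + gamma * return_ - bootstrap_values)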
shape_fn_vtrace_discrete_action(args, kwargs)
Overview
Return the shape of vtrace inputs for HPC (the optimized kernel implementation).
Returns: shape: [T, B, N]
vtrace_error_discrete_action(data, gamma=0.99, lambda_=0.95, rho_clip_ratio=1.0, c_clip_ratio=1.0, rho_pg_clip_ratio=1.0)
Overview
Implementation of vtrace for discrete action spaces (IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures, arXiv:1802.01561).
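For reference, the baseline target defined in the IMPALA paper (which this loss builds on, with lambda_ additionally scaling each correction term) is

    v_s = V(x_s) + \sum_{t=s}^{s+n-1} \gamma^{t-s} \Big( \prod_{i=s}^{t-1} c_i \Big) \delta_t V,
    \qquad \delta_t V = \rho_t \big( r_t + \gamma V(x_{t+1}) - V(x_t) \big),

where :math:\rho_t = \min(\bar{\rho}, \pi(a_t|x_t)/\mu(a_t|x_t)) and :math:c_i = \min(\bar{c}, \pi(a_i|x_i)/\mu(a_i|x_i)); rho_clip_ratio and c_clip_ratio correspond to :math:\bar{\rho} and :math:\bar{c} respectively.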
Arguments:
- data (:obj:namedtuple): input data with fields shown in vtrace_data
- target_output (:obj:torch.Tensor): the output of the current (target) policy network for the taken action, usually the network's logit output
- behaviour_output (:obj:torch.Tensor): the output of the behaviour policy network for the taken action, i.e. the policy that produced the trajectory in the collector, usually the network's logit output
- action (:obj:torch.Tensor): the action chosen in the trajectory (an index for the discrete action space), i.e. the behaviour action
- gamma (:obj:float): the future discount factor, defaults to 0.99
- lambda_ (:obj:float): the mixing factor between 1-step (lambda_=0) and n-step returns, defaults to 0.95
- rho_clip_ratio (:obj:float): the clipping threshold for importance weights (rho) when calculating the baseline targets (vs)
- c_clip_ratio (:obj:float): the clipping threshold for the trace-cutting coefficients (c) when calculating the baseline targets (vs)
- rho_pg_clip_ratio (:obj:float): the clipping threshold for importance weights (rho) when calculating the policy gradient advantage
Returns:
- trace_loss (:obj:namedtuple): the vtrace loss items, each of which is a differentiable 0-dim tensor
Shapes:
- target_output (:obj:torch.FloatTensor): :math:(T, B, N), where T is timestep, B is batch size and N is action dim
- behaviour_output (:obj:torch.FloatTensor): :math:(T, B, N)
- action (:obj:torch.LongTensor): :math:(T, B)
- value (:obj:torch.FloatTensor): :math:(T+1, B)
- reward (:obj:torch.FloatTensor): :math:(T, B)
- weight (:obj:torch.FloatTensor): :math:(T, B)
Examples:
>>> T, B, N = 4, 8, 16
>>> value = torch.randn(T + 1, B).requires_grad_(True)
>>> reward = torch.rand(T, B)
>>> target_output = torch.randn(T, B, N).requires_grad_(True)
>>> behaviour_output = torch.randn(T, B, N)
>>> action = torch.randint(0, N, size=(T, B))
>>> data = vtrace_data(target_output, behaviour_output, action, value, reward, None)
>>> loss = vtrace_error_discrete_action(data, rho_clip_ratio=1.1)
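The importance weights themselves come from the logits; a hedged sketch of that step for the discrete case (variable names mirror the example above and are illustrative only):

import torch

T, B, N = 4, 8, 16
target_output = torch.randn(T, B, N)
behaviour_output = torch.randn(T, B, N)
action = torch.randint(0, N, size=(T, B))

# log pi(a|x) and log mu(a|x) for the taken actions, then rho = exp(log pi - log mu)
target_logp = torch.log_softmax(target_output, dim=-1).gather(-1, action.unsqueeze(-1)).squeeze(-1)
behaviour_logp = torch.log_softmax(behaviour_output, dim=-1).gather(-1, action.unsqueeze(-1)).squeeze(-1)
clipped_rhos = torch.clamp(torch.exp(target_logp - behaviour_logp), max=1.0)  # rho_clip_ratio = 1.0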
vtrace_error_continuous_action(data, gamma=0.99, lambda_=0.95, rho_clip_ratio=1.0, c_clip_ratio=1.0, rho_pg_clip_ratio=1.0)
Overview
Implementation of vtrace for continuous action spaces (IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures, arXiv:1802.01561).
Arguments:
- data (:obj:namedtuple): input data with fields shown in vtrace_data
- target_output (:obj:dict{key:torch.Tensor}): the output of the current (target) policy network for the taken action, usually a dict of distribution parameters produced by the reparameterization trick
- behaviour_output (:obj:dict{key:torch.Tensor}): the output of the behaviour policy network for the taken action, i.e. the policy that produced the trajectory in the collector, in the same dict-of-parameters form
- action (:obj:torch.Tensor): the action chosen in the trajectory (a continuous action vector), i.e. the behaviour action
- gamma (:obj:float): the future discount factor, defaults to 0.99
- lambda_ (:obj:float): the mixing factor between 1-step (lambda_=0) and n-step returns, defaults to 0.95
- rho_clip_ratio (:obj:float): the clipping threshold for importance weights (rho) when calculating the baseline targets (vs)
- c_clip_ratio (:obj:float): the clipping threshold for the trace-cutting coefficients (c) when calculating the baseline targets (vs)
- rho_pg_clip_ratio (:obj:float): the clipping threshold for importance weights (rho) when calculating the policy gradient advantage
Returns:
- trace_loss (:obj:namedtuple): the vtrace loss items, each of which is a differentiable 0-dim tensor
Shapes:
- target_output (:obj:dict{key:torch.FloatTensor}): :math:(T, B, N), where T is timestep, B is batch size and N is action dim. The keys are usually the parameters of the reparameterization trick (e.g. mu and sigma).
- behaviour_output (:obj:dict{key:torch.FloatTensor}): :math:(T, B, N)
- action (:obj:torch.FloatTensor): :math:(T, B, N)
- value (:obj:torch.FloatTensor): :math:(T+1, B)
- reward (:obj:torch.FloatTensor): :math:(T, B)
- weight (:obj:torch.FloatTensor): :math:(T, B)
Examples:
>>> T, B, N = 4, 8, 16
>>> value = torch.randn(T + 1, B).requires_grad_(True)
>>> reward = torch.rand(T, B)
>>> target_output = dict(
>>>     mu=torch.randn(T, B, N).requires_grad_(True),
>>>     sigma=torch.exp(torch.randn(T, B, N).requires_grad_(True)),
>>> )
>>> behaviour_output = dict(
>>>     mu=torch.randn(T, B, N),
>>>     sigma=torch.exp(torch.randn(T, B, N)),
>>> )
>>> action = torch.randn((T, B, N))
>>> data = vtrace_data(target_output, behaviour_output, action, value, reward, None)
>>> loss = vtrace_error_continuous_action(data, rho_clip_ratio=1.1)
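For the continuous case, a common way to obtain the importance weights from the mu/sigma outputs is to treat them as a diagonal Gaussian; a hedged sketch (illustrative, not necessarily the library's exact distribution construction):

import torch
from torch.distributions import Normal, Independent

T, B, N = 4, 8, 16
target_output = dict(mu=torch.randn(T, B, N), sigma=torch.exp(torch.randn(T, B, N)))
behaviour_output = dict(mu=torch.randn(T, B, N), sigma=torch.exp(torch.randn(T, B, N)))
action = torch.randn(T, B, N)

# rho = pi(a|x) / mu(a|x) via the log-probability difference under diagonal Gaussians
target_dist = Independent(Normal(target_output['mu'], target_output['sigma']), 1)
behaviour_dist = Independent(Normal(behaviour_output['mu'], behaviour_output['sigma']), 1)
log_rhos = target_dist.log_prob(action) - behaviour_dist.log_prob(action)  # shape (T, B)
clipped_rhos = torch.clamp(torch.exp(log_rhos), max=1.0)  # rho_clip_ratio = 1.0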
Full Source Code
../ding/rl_utils/vtrace.py