ding.rl_utils.isw
compute_importance_weights(target_output, behaviour_output, action, action_space_type='discrete', requires_grad=False)

Overview

Compute the importance sampling weights from the given policy outputs and actions.

Arguments:

- target_output (:obj:`Union[torch.Tensor, dict]`): the output of the current policy network for the taken action; usually the network's output logits if the action space is discrete, or a dict containing the parameters of the action distribution if the action space is continuous.
- behaviour_output (:obj:`Union[torch.Tensor, dict]`): the output of the behaviour policy network for the taken action; usually the network's output logits if the action space is discrete, or a dict containing the parameters of the action distribution if the action space is continuous.
- action (:obj:`torch.Tensor`): the chosen action in the trajectory (an index for the discrete action space), i.e. behaviour_action.
- action_space_type (:obj:`str`): the action space type, one of ['discrete', 'continuous'].
- requires_grad (:obj:`bool`): whether gradient computation is required.

Returns:

- rhos (:obj:`torch.Tensor`): the importance sampling weights.

Shapes:

- target_output (:obj:`Union[torch.FloatTensor, dict]`): :math:`(T, B, N)`, where T is the timestep, B is the batch size and N is the action dim.
- behaviour_output (:obj:`Union[torch.FloatTensor, dict]`): :math:`(T, B, N)`.
- action (:obj:`torch.LongTensor`): :math:`(T, B)`.
- rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`.

Examples:

>>> target_output = torch.randn(2, 3, 4)
>>> behaviour_output = torch.randn(2, 3, 4)
>>> action = torch.randint(0, 4, (2, 3))
>>> rhos = compute_importance_weights(target_output, behaviour_output, action)
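The example above covers the discrete case. For the continuous case, both outputs are dicts of Gaussian parameters ('mu' and 'sigma'), and the action carries the full action dimension. The following self-contained sketch reproduces the continuous branch of the computation with `torch.distributions` (tensor shapes chosen for illustration only):

```python
import torch
from torch.distributions import Independent, Normal

T, B, N = 2, 3, 4  # timestep, batch size, action dim
# Gaussian parameters from the target and behaviour policies (illustrative values)
target_output = {'mu': torch.randn(T, B, N), 'sigma': torch.rand(T, B, N) + 0.1}
behaviour_output = {'mu': torch.randn(T, B, N), 'sigma': torch.rand(T, B, N) + 0.1}
action = torch.randn(T, B, N)  # continuous actions taken by the behaviour policy

with torch.no_grad():
    # Independent(..., 1) treats the last dim as event dim, so log_prob sums over N
    dist_target = Independent(Normal(target_output['mu'], target_output['sigma']), 1)
    dist_behaviour = Independent(Normal(behaviour_output['mu'], behaviour_output['sigma']), 1)
    # rho = pi_target(a) / pi_behaviour(a), computed in log space for stability
    rhos = torch.exp(dist_target.log_prob(action) - dist_behaviour.log_prob(action))

print(rhos.shape)  # torch.Size([2, 3]): one weight per (timestep, batch) entry
```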

Full Source Code

../ding/rl_utils/isw.py

```python
from typing import Union
import torch
from torch.distributions import Categorical, Independent, Normal


def compute_importance_weights(
        target_output: Union[torch.Tensor, dict],
        behaviour_output: Union[torch.Tensor, dict],
        action: torch.Tensor,
        action_space_type: str = 'discrete',
        requires_grad: bool = False
):
    """
    Overview:
        Compute the importance sampling weights from the given policy outputs and actions.
    Arguments:
        - target_output (:obj:`Union[torch.Tensor, dict]`): the output of the current policy network \
            for the taken action; usually the network's output logits if the action space is discrete, \
            or a dict containing the parameters of the action distribution if the action space is continuous.
        - behaviour_output (:obj:`Union[torch.Tensor, dict]`): the output of the behaviour policy network \
            for the taken action; usually the network's output logits if the action space is discrete, \
            or a dict containing the parameters of the action distribution if the action space is continuous.
        - action (:obj:`torch.Tensor`): the chosen action in the trajectory (an index for the discrete \
            action space), i.e. behaviour_action.
        - action_space_type (:obj:`str`): the action space type, one of ['discrete', 'continuous'].
        - requires_grad (:obj:`bool`): whether gradient computation is required.
    Returns:
        - rhos (:obj:`torch.Tensor`): the importance sampling weights.
    Shapes:
        - target_output (:obj:`Union[torch.FloatTensor, dict]`): :math:`(T, B, N)`, \
            where T is the timestep, B is the batch size and N is the action dim.
        - behaviour_output (:obj:`Union[torch.FloatTensor, dict]`): :math:`(T, B, N)`.
        - action (:obj:`torch.LongTensor`): :math:`(T, B)`.
        - rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`.
    Examples:
        >>> target_output = torch.randn(2, 3, 4)
        >>> behaviour_output = torch.randn(2, 3, 4)
        >>> action = torch.randint(0, 4, (2, 3))
        >>> rhos = compute_importance_weights(target_output, behaviour_output, action)
    """
    grad_context = torch.enable_grad() if requires_grad else torch.no_grad()
    assert isinstance(action, torch.Tensor)
    assert action_space_type in ['discrete', 'continuous']

    with grad_context:
        if action_space_type == 'continuous':
            dist_target = Independent(Normal(loc=target_output['mu'], scale=target_output['sigma']), 1)
            dist_behaviour = Independent(Normal(loc=behaviour_output['mu'], scale=behaviour_output['sigma']), 1)
            rhos = dist_target.log_prob(action) - dist_behaviour.log_prob(action)
            rhos = torch.exp(rhos)
            return rhos
        elif action_space_type == 'discrete':
            dist_target = Categorical(logits=target_output)
            dist_behaviour = Categorical(logits=behaviour_output)
            rhos = dist_target.log_prob(action) - dist_behaviour.log_prob(action)
            rhos = torch.exp(rhos)
            return rhos
```
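In the discrete branch, the weight for each chosen action is the ratio of the two policies' softmax probabilities for that action. The following self-contained sketch mirrors the source above and cross-checks the `Categorical` log-prob computation against explicit softmax probabilities (tensor shapes are illustrative):

```python
import torch
from torch.distributions import Categorical

T, B, N = 2, 3, 4  # timestep, batch size, number of discrete actions
target_output = torch.randn(T, B, N)     # logits from the current policy
behaviour_output = torch.randn(T, B, N)  # logits from the behaviour policy
action = torch.randint(0, N, (T, B))     # action indices from the trajectory

with torch.no_grad():
    # Same computation as the discrete branch above
    rhos = torch.exp(
        Categorical(logits=target_output).log_prob(action)
        - Categorical(logits=behaviour_output).log_prob(action)
    )

# Cross-check: rho = pi_target(a | s) / pi_behaviour(a | s)
p_target = torch.softmax(target_output, dim=-1).gather(-1, action.unsqueeze(-1)).squeeze(-1)
p_behaviour = torch.softmax(behaviour_output, dim=-1).gather(-1, action.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(rhos, p_target / p_behaviour, atol=1e-5)
```

Working in log space and exponentiating at the end (as the source does) avoids the numerical issues of dividing two very small probabilities directly.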