ding.rl_utils.isw
compute_importance_weights(target_output, behaviour_output, action, action_space_type='discrete', requires_grad=False)
Overview:
Compute the importance sampling weights for the given policy outputs and actions.
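Concretely, for each timestep the weight is the likelihood ratio :math:`\rho_t = \pi_{target}(a_t \mid s_t) / \pi_{behaviour}(a_t \mid s_t)`, i.e. the probability the current (target) policy assigns to the behaviour action, divided by the probability the behaviour policy assigned to it.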
Arguments:
- target_output (:obj:`Union[torch.Tensor, dict]`): the output of the current (target) policy network for the taken action; usually the network's output logits if the action space is discrete, or a dict containing the parameters of the action distribution if the action space is continuous.
- behaviour_output (:obj:`Union[torch.Tensor, dict]`): the output of the behaviour policy network for the taken action; usually the network's output logits if the action space is discrete, or a dict containing the parameters of the action distribution if the action space is continuous.
- action (:obj:`torch.Tensor`): the action chosen in the trajectory (an index for a discrete action space), i.e. the behaviour action.
- action_space_type (:obj:`str`): the type of action space, one of ['discrete', 'continuous'].
- requires_grad (:obj:`bool`): whether gradient computation is required.
Returns:
- rhos (:obj:`torch.Tensor`): the importance sampling weights.
Shapes:
- target_output (:obj:`Union[torch.FloatTensor, dict]`): :math:`(T, B, N)`, where T is the timestep, B is the batch size and N is the action dim
- behaviour_output (:obj:`Union[torch.FloatTensor, dict]`): :math:`(T, B, N)`
- action (:obj:`torch.LongTensor`): :math:`(T, B)`
- rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`
Examples:
>>> import torch
>>> from ding.rl_utils.isw import compute_importance_weights
>>> target_output = torch.randn(2, 3, 4)  # (T=2, B=3, N=4) logits from the current policy
>>> behaviour_output = torch.randn(2, 3, 4)  # logits from the behaviour policy
>>> action = torch.randint(0, 4, (2, 3))  # behaviour actions, shape (T, B)
>>> rhos = compute_importance_weights(target_output, behaviour_output, action)
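Below is a minimal sketch of how such weights can be computed with torch.distributions; it is an illustration under stated assumptions, not the library's exact implementation, and the helper name importance_weights_sketch is hypothetical. The discrete branch treats both outputs as Categorical logits; the continuous branch assumes (purely for illustration) dicts with 'mu' and 'sigma' keys parameterizing an independent Normal. Computing the ratio as the exponential of a log-probability difference avoids dividing two potentially tiny probabilities.

    import torch
    from torch.distributions import Categorical, Independent, Normal

    def importance_weights_sketch(target_output, behaviour_output, action,
                                  action_space_type='discrete', requires_grad=False):
        # Disable autograd unless the caller explicitly asks for gradients.
        grad_context = torch.enable_grad() if requires_grad else torch.no_grad()
        with grad_context:
            if action_space_type == 'discrete':
                # Interpret both outputs as unnormalized logits over N actions.
                dist_target = Categorical(logits=target_output)
                dist_behaviour = Categorical(logits=behaviour_output)
            else:  # 'continuous'; 'mu'/'sigma' keys are an assumption for this sketch
                dist_target = Independent(Normal(target_output['mu'], target_output['sigma']), 1)
                dist_behaviour = Independent(Normal(behaviour_output['mu'], behaviour_output['sigma']), 1)
            # rho_t = pi_target(a_t|s_t) / pi_behaviour(a_t|s_t), computed in log
            # space for numerical stability, then exponentiated back to a ratio.
            rhos = torch.exp(dist_target.log_prob(action) - dist_behaviour.log_prob(action))
        return rhos

Applied to the tensors from the example above, this sketch returns a :math:`(T, B)` tensor of ratios, matching the documented shape of rhos.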
Full Source Code
../ding/rl_utils/isw.py