
ding.rl_utils.ppg

ppg_joint_error(data, clip_ratio=0.2, use_value_clip=True)

Overview

Get the PPG (Phasic Policy Gradient) joint loss.

Arguments

- data (:obj:`namedtuple`): PPG input data, with fields shown in ``ppg_data``
- clip_ratio (:obj:`float`): the clip range; here it bounds how far ``value_new`` may move from ``value_old``
- use_value_clip (:obj:`bool`): whether to use value clipping

Returns

- ppg_joint_loss (:obj:`namedtuple`): the PPG loss items (``auxiliary_loss``, ``behavioral_cloning_loss``); both are differentiable 0-dim tensors

Shapes

- logit_new (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
- logit_old (:obj:`torch.FloatTensor`): :math:`(B, N)`
- action (:obj:`torch.LongTensor`): :math:`(B,)`
- value_new (:obj:`torch.FloatTensor`): :math:`(B, 1)`
- value_old (:obj:`torch.FloatTensor`): :math:`(B, 1)`
- return_ (:obj:`torch.FloatTensor`): :math:`(B, 1)`
- weight (:obj:`torch.FloatTensor`): :math:`(B,)`
- auxiliary_loss (:obj:`torch.FloatTensor`): :math:`()`, a 0-dim tensor
- behavioral_cloning_loss (:obj:`torch.FloatTensor`): :math:`()`

Examples

>>> action_dim = 4
>>> data = ppg_data(
...     logit_new=torch.randn(3, action_dim),
...     logit_old=torch.randn(3, action_dim),
...     action=torch.randint(0, action_dim, (3,)),
...     value_new=torch.randn(3, 1),
...     value_old=torch.randn(3, 1),
...     return_=torch.randn(3, 1),
...     weight=torch.ones(3),
... )
>>> loss = ppg_joint_error(data, clip_ratio=0.2, use_value_clip=True)
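To see what ``use_value_clip`` does, here is a minimal standalone sketch of the clipped auxiliary (value) loss term on hand-picked numbers. The variable names ``v1``/``v2``/``value_clip`` mirror the source below; the concrete tensor values are illustrative, not from the library:

```python
import torch

# Hypothetical batch of 2: new value estimates, old estimates, and returns.
value_new = torch.tensor([[0.0], [2.0]])
value_old = torch.tensor([[1.0], [1.0]])
return_ = torch.tensor([[1.5], [1.5]])
clip_ratio = 0.2

# Clamp the *change* in the value estimate to [-clip_ratio, clip_ratio]
# around the old estimate, PPO-style.
value_clip = value_old + (value_new - value_old).clamp(-clip_ratio, clip_ratio)
v1 = (return_ - value_new).pow(2)   # unclipped squared error
v2 = (return_ - value_clip).pow(2)  # clipped squared error

# Pessimistic: take the larger of the two errors per sample, then average.
auxiliary_loss = 0.5 * torch.max(v1, v2).mean()
print(auxiliary_loss.item())  # -> 0.625
```

Note that ``torch.max(v1, v2)`` keeps the clipped error only when it is larger, so the clip never makes the loss smaller than the plain squared error.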

Full Source Code

../ding/rl_utils/ppg.py

```python
from typing import Tuple
from collections import namedtuple
import torch
import torch.nn.functional as F

ppg_data = namedtuple('ppg_data', ['logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'return_', 'weight'])
ppg_joint_loss = namedtuple('ppg_joint_loss', ['auxiliary_loss', 'behavioral_cloning_loss'])


def ppg_joint_error(
        data: namedtuple,
        clip_ratio: float = 0.2,
        use_value_clip: bool = True,
) -> Tuple[namedtuple, namedtuple]:
    '''
    Overview:
        Get PPG joint loss
    Arguments:
        - data (:obj:`namedtuple`): ppg input data with fields shown in ``ppg_data``
        - clip_ratio (:obj:`float`): clip value for ratio
        - use_value_clip (:obj:`bool`): whether to use value clip
    Returns:
        - ppg_joint_loss (:obj:`namedtuple`): the ppg loss items, all of them are differentiable 0-dim tensors
    Shapes:
        - logit_new (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
        - logit_old (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B,)`
        - value_new (:obj:`torch.FloatTensor`): :math:`(B, 1)`
        - value_old (:obj:`torch.FloatTensor`): :math:`(B, 1)`
        - return_ (:obj:`torch.FloatTensor`): :math:`(B, 1)`
        - weight (:obj:`torch.FloatTensor`): :math:`(B,)`
        - auxiliary_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - behavioral_cloning_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> action_dim = 4
        >>> data = ppg_data(
        ...     logit_new=torch.randn(3, action_dim),
        ...     logit_old=torch.randn(3, action_dim),
        ...     action=torch.randint(0, action_dim, (3,)),
        ...     value_new=torch.randn(3, 1),
        ...     value_old=torch.randn(3, 1),
        ...     return_=torch.randn(3, 1),
        ...     weight=torch.ones(3),
        ... )
        >>> loss = ppg_joint_error(data, clip_ratio=0.2, use_value_clip=True)
    '''
    logit_new, logit_old, action, value_new, value_old, return_, weight = data

    if weight is None:
        weight = torch.ones_like(return_)

    # auxiliary_loss
    if use_value_clip:
        value_clip = value_old + (value_new - value_old).clamp(-clip_ratio, clip_ratio)
        v1 = (return_ - value_new).pow(2)
        v2 = (return_ - value_clip).pow(2)
        auxiliary_loss = 0.5 * (torch.max(v1, v2) * weight).mean()
    else:
        auxiliary_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean()

    dist_new = torch.distributions.categorical.Categorical(logits=logit_new)
    dist_old = torch.distributions.categorical.Categorical(logits=logit_old)
    logp_new = dist_new.log_prob(action)
    logp_old = dist_old.log_prob(action)

    # behavioral cloning loss
    behavioral_cloning_loss = F.kl_div(logp_new, logp_old, reduction='batchmean')

    return ppg_joint_loss(auxiliary_loss, behavioral_cloning_loss)
```
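Putting it together, here is a self-contained usage sketch. The function body is transcribed from the listing above so the example runs without installing DI-engine; with the library installed you would simply ``from ding.rl_utils.ppg import ppg_data, ppg_joint_error`` instead:

```python
from collections import namedtuple
import torch
import torch.nn.functional as F

ppg_data = namedtuple('ppg_data', ['logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'return_', 'weight'])
ppg_joint_loss = namedtuple('ppg_joint_loss', ['auxiliary_loss', 'behavioral_cloning_loss'])


def ppg_joint_error(data, clip_ratio=0.2, use_value_clip=True):
    # Transcribed from ding/rl_utils/ppg.py (see listing above).
    logit_new, logit_old, action, value_new, value_old, return_, weight = data
    if weight is None:
        weight = torch.ones_like(return_)
    if use_value_clip:
        value_clip = value_old + (value_new - value_old).clamp(-clip_ratio, clip_ratio)
        v1 = (return_ - value_new).pow(2)
        v2 = (return_ - value_clip).pow(2)
        auxiliary_loss = 0.5 * (torch.max(v1, v2) * weight).mean()
    else:
        auxiliary_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean()
    logp_new = torch.distributions.Categorical(logits=logit_new).log_prob(action)
    logp_old = torch.distributions.Categorical(logits=logit_old).log_prob(action)
    behavioral_cloning_loss = F.kl_div(logp_new, logp_old, reduction='batchmean')
    return ppg_joint_loss(auxiliary_loss, behavioral_cloning_loss)


# Build a batch of B = 3 transitions over an N = 4 dim discrete action space.
torch.manual_seed(0)
B, N = 3, 4
data = ppg_data(
    logit_new=torch.randn(B, N, requires_grad=True),
    logit_old=torch.randn(B, N),
    action=torch.randint(0, N, (B,)),
    value_new=torch.randn(B, 1, requires_grad=True),
    value_old=torch.randn(B, 1),
    return_=torch.randn(B, 1),
    weight=torch.ones(B),
)

loss = ppg_joint_error(data, clip_ratio=0.2, use_value_clip=True)

# Both loss terms are differentiable 0-dim tensors, so they can be
# summed (optionally with coefficients) and backpropagated jointly.
total = loss.auxiliary_loss + loss.behavioral_cloning_loss
total.backward()
```

Only ``logit_new`` and ``value_new`` carry gradients here; ``logit_old``, ``value_old``, and ``return_`` act as fixed targets, which matches how the joint phase of PPG treats data collected by the old policy.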