
ding.rl_utils.happo


happo_error(data, clip_ratio=0.2, use_value_clip=True, dual_clip=None)

Overview

Implementation of Proximal Policy Optimization (arXiv:1707.06347) with value_clip and dual_clip

Arguments:
    - data (:obj:`namedtuple`): the PPO input data with fields shown in ``happo_data``
    - clip_ratio (:obj:`float`): the PPO clip ratio for the constraint of the policy update, defaults to 0.2
    - use_value_clip (:obj:`bool`): whether to clip the value loss with the same ratio as the policy
    - dual_clip (:obj:`float`): the parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf); 5.0 is a common choice. If you don't want to use it, set this parameter to None (the default)

Returns:
    - happo_loss (:obj:`namedtuple`): the PPO loss items; all of them are differentiable 0-dim tensors
    - happo_info (:obj:`namedtuple`): the PPO optim information for monitoring; all of them are Python scalars

Shapes:
    - logit_new (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
    - logit_old (:obj:`torch.FloatTensor`): :math:`(B, N)`
    - action (:obj:`torch.LongTensor`): :math:`(B, )`
    - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`
    - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`
    - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
    - return (:obj:`torch.FloatTensor`): :math:`(B, )`
    - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
    - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
    - value_loss (:obj:`torch.FloatTensor`): :math:`()`
    - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`

Examples:
    >>> action_dim = 4
    >>> data = happo_data(
    >>>     logit_new=torch.randn(3, action_dim),
    >>>     logit_old=torch.randn(3, action_dim),
    >>>     action=torch.randint(0, action_dim, (3,)),
    >>>     value_new=torch.randn(3),
    >>>     value_old=torch.randn(3),
    >>>     adv=torch.randn(3),
    >>>     return_=torch.randn(3),
    >>>     weight=torch.ones(3),
    >>>     factor=torch.ones(3, 1),
    >>> )
    >>> loss, info = happo_error(data)

.. note::

    ``adv`` is an already-normalized value, i.e. ``(adv - adv.mean()) / (adv.std() + 1e-8)``. There are
    many ways to compute this mean and std (e.g. over the whole data buffer or over a single train
    batch), so this step is not coupled into ``happo_error``; refer to our examples for the different options.
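The normalization the note describes can be sketched in a few lines; this is a minimal batch-wise variant (only one of the options the note mentions), with the helper name `normalize_adv` chosen here for illustration:

```python
import torch


def normalize_adv(adv: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Batch-wise normalization: zero mean and (approximately) unit std over the batch.
    # eps guards against division by zero for near-constant advantages.
    return (adv - adv.mean()) / (adv.std() + eps)


adv = torch.randn(64)
adv_norm = normalize_adv(adv)
```

A buffer-wide variant would simply compute the mean and std over all stored transitions instead of one train batch.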

happo_policy_error(data, clip_ratio=0.2, dual_clip=None)

Overview

Get PPO policy loss

Arguments:
    - data (:obj:`namedtuple`): PPO input data with fields shown in ``happo_policy_data``
    - clip_ratio (:obj:`float`): clip value for the ratio
    - dual_clip (:obj:`float`): the parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf); 5.0 is a common choice. If you don't want to use it, set this parameter to None (the default)

Returns:
    - happo_policy_loss (:obj:`namedtuple`): the PPO policy loss items; all of them are differentiable 0-dim tensors
    - happo_info (:obj:`namedtuple`): the PPO optim information for monitoring; all of them are Python scalars

Shapes:
    - logit_new (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
    - logit_old (:obj:`torch.FloatTensor`): :math:`(B, N)`
    - action (:obj:`torch.LongTensor`): :math:`(B, )`
    - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
    - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
    - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
    - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`

Examples:
    >>> action_dim = 4
    >>> data = happo_policy_data(
    >>>     logit_new=torch.randn(3, action_dim),
    >>>     logit_old=torch.randn(3, action_dim),
    >>>     action=torch.randint(0, action_dim, (3,)),
    >>>     adv=torch.randn(3),
    >>>     weight=torch.ones(3),
    >>>     factor=torch.ones(3, 1),
    >>> )
    >>> loss, info = happo_policy_error(data)
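For intuition, the clipped surrogate with the optional dual clip (arXiv:1912.09729) can be written as a few tensor ops. This is a simplified sketch that omits the HAPPO ``factor`` and ``weight`` terms; the helper name is illustrative:

```python
import torch


def dual_clip_surrogate(ratio: torch.Tensor, adv: torch.Tensor,
                        clip_ratio: float = 0.2, dual_clip: float = 5.0) -> torch.Tensor:
    # Standard PPO clipped surrogate: min of unclipped and clipped objectives.
    surr1 = ratio * adv
    surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv
    clip1 = torch.min(surr1, surr2)
    # Dual clip lower-bounds the objective by c * adv, applied only where adv < 0.
    clip2 = torch.max(clip1, dual_clip * adv)
    return torch.where(adv < 0, clip2, clip1)


# Large ratio with negative advantage is floored at dual_clip * adv.
out = dual_clip_surrogate(torch.tensor([0.5, 1.0, 10.0]), torch.tensor([1.0, -1.0, -1.0]))
# → tensor([ 0.5000, -1.0000, -5.0000])
```

The loss used for gradient descent is the negative mean of this quantity, as in the source below.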

happo_value_error(data, clip_ratio=0.2, use_value_clip=True)

Overview

Get PPO value loss

Arguments:
    - data (:obj:`namedtuple`): PPO input data with fields shown in ``happo_value_data``
    - clip_ratio (:obj:`float`): clip value for the ratio
    - use_value_clip (:obj:`bool`): whether to use value clip

Returns:
    - value_loss (:obj:`torch.FloatTensor`): the PPO value loss, a differentiable 0-dim tensor

Shapes:
    - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size
    - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`
    - return (:obj:`torch.FloatTensor`): :math:`(B, )`
    - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
    - value_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor

Examples:
    >>> data = happo_value_data(
    >>>     value_new=torch.randn(3),
    >>>     value_old=torch.randn(3),
    >>>     return_=torch.randn(3),
    >>>     weight=torch.ones(3),
    >>> )
    >>> loss = happo_value_error(data)
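The value-clip branch mirrors the policy clip: the new value prediction is kept within ``clip_ratio`` of the old one, and the pessimistic (larger) squared error is used. A minimal sketch of that computation, omitting the ``weight`` term (the helper name is illustrative):

```python
import torch


def clipped_value_loss(value_new: torch.Tensor, value_old: torch.Tensor,
                       return_: torch.Tensor, clip_ratio: float = 0.2) -> torch.Tensor:
    # Clip the new value estimate so it moves at most clip_ratio away from the old one.
    value_clip = value_old + (value_new - value_old).clamp(-clip_ratio, clip_ratio)
    v1 = (return_ - value_new).pow(2)
    v2 = (return_ - value_clip).pow(2)
    # Element-wise maximum (pessimistic) error, then average; 0.5 matches the MSE convention.
    return 0.5 * torch.max(v1, v2).mean()


loss = clipped_value_loss(torch.randn(8), torch.randn(8), torch.randn(8))
```

When ``value_new`` equals ``value_old``, the clip is inactive and this reduces to plain ``0.5 * MSE``.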

happo_error_continuous(data, clip_ratio=0.2, use_value_clip=True, dual_clip=None)

Overview

Implementation of Proximal Policy Optimization (arXiv:1707.06347) with value_clip and dual_clip

Arguments:
    - data (:obj:`namedtuple`): the PPO input data with fields shown in ``ppo_data``
    - clip_ratio (:obj:`float`): the PPO clip ratio for the constraint of the policy update, defaults to 0.2
    - use_value_clip (:obj:`bool`): whether to clip the value loss with the same ratio as the policy
    - dual_clip (:obj:`float`): the parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf); 5.0 is a common choice. If you don't want to use it, set this parameter to None (the default)

Returns:
    - happo_loss (:obj:`namedtuple`): the PPO loss items; all of them are differentiable 0-dim tensors
    - happo_info (:obj:`namedtuple`): the PPO optim information for monitoring; all of them are Python scalars

Shapes:
    - mu_sigma_new (:obj:`tuple`): :math:`((B, N), (B, N))`, where B is batch size and N is action dim
    - mu_sigma_old (:obj:`tuple`): :math:`((B, N), (B, N))`
    - action (:obj:`torch.FloatTensor`): :math:`(B, N)`
    - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`
    - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`
    - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
    - return (:obj:`torch.FloatTensor`): :math:`(B, )`
    - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
    - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
    - value_loss (:obj:`torch.FloatTensor`): :math:`()`
    - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`

Examples:
    >>> action_dim = 4
    >>> data = ppo_data_continuous(
    >>>     mu_sigma_new=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
    >>>     mu_sigma_old=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
    >>>     action=torch.randn(3, action_dim),
    >>>     value_new=torch.randn(3),
    >>>     value_old=torch.randn(3),
    >>>     adv=torch.randn(3),
    >>>     return_=torch.randn(3),
    >>>     weight=torch.ones(3),
    >>> )
    >>> loss, info = happo_error_continuous(data)

.. note::

    ``adv`` is an already-normalized value, i.e. ``(adv - adv.mean()) / (adv.std() + 1e-8)``. There are
    many ways to compute this mean and std (e.g. over the whole data buffer or over a single train
    batch), so this step is not coupled into ``happo_error``; refer to our examples for the different options.

happo_policy_error_continuous(data, clip_ratio=0.2, dual_clip=None)

Overview

Implementation of Proximal Policy Optimization (arXiv:1707.06347) with dual_clip

Arguments:
    - data (:obj:`namedtuple`): the PPO input data with fields shown in ``ppo_data``
    - clip_ratio (:obj:`float`): the PPO clip ratio for the constraint of the policy update, defaults to 0.2
    - dual_clip (:obj:`float`): the parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf); 5.0 is a common choice. If you don't want to use it, set this parameter to None (the default)

Returns:
    - happo_loss (:obj:`namedtuple`): the PPO loss items; all of them are differentiable 0-dim tensors
    - happo_info (:obj:`namedtuple`): the PPO optim information for monitoring; all of them are Python scalars

Shapes:
    - mu_sigma_new (:obj:`tuple`): :math:`((B, N), (B, N))`, where B is batch size and N is action dim
    - mu_sigma_old (:obj:`tuple`): :math:`((B, N), (B, N))`
    - action (:obj:`torch.FloatTensor`): :math:`(B, N)`
    - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
    - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
    - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
    - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`

Examples:
    >>> action_dim = 4
    >>> data = ppo_policy_data_continuous(
    >>>     mu_sigma_new=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
    >>>     mu_sigma_old=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
    >>>     action=torch.randn(3, action_dim),
    >>>     adv=torch.randn(3),
    >>>     weight=torch.ones(3),
    >>> )
    >>> loss, info = happo_policy_error_continuous(data)
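In the continuous case the probability ratio comes from diagonal-Gaussian log-probabilities. Wrapping ``Normal`` in ``Independent(..., 1)``, as ``happo_policy_error_continuous`` does, sums the per-dimension log-probs, so exponentiating the difference gives the joint per-sample ratio. A small sketch with toy parameters (identical old and new distributions, so the ratio is exactly one):

```python
import torch
from torch.distributions import Independent, Normal

B, N = 3, 4  # batch size, action dim
mu_new, sigma_new = torch.zeros(B, N), torch.ones(B, N)
mu_old, sigma_old = torch.zeros(B, N), torch.ones(B, N)
action = torch.randn(B, N)

# Independent(..., 1) treats the last dim as event dims: log_prob sums over action dims.
dist_new = Independent(Normal(mu_new, sigma_new), 1)
dist_old = Independent(Normal(mu_old, sigma_old), 1)
ratio = torch.exp(dist_new.log_prob(action) - dist_old.log_prob(action))  # shape (B,)
```

``happo_error_continuous`` reaches the same quantity differently: it takes per-dimension ratios from plain ``Normal`` and multiplies them with ``torch.prod(ratio, dim=-1)``.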

Full Source Code

../ding/rl_utils/happo.py

from collections import namedtuple
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch.distributions import Independent, Normal
from ding.hpc_rl import hpc_wrapper

happo_value_data = namedtuple('happo_value_data', ['value_new', 'value_old', 'return_', 'weight'])
happo_loss = namedtuple('happo_loss', ['policy_loss', 'value_loss', 'entropy_loss'])
happo_policy_loss = namedtuple('happo_policy_loss', ['policy_loss', 'entropy_loss'])
happo_info = namedtuple('happo_info', ['approx_kl', 'clipfrac'])
happo_data = namedtuple(
    'happo_data', ['logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight', 'factor']
)
happo_policy_data = namedtuple('happo_policy_data', ['logit_new', 'logit_old', 'action', 'adv', 'weight', 'factor'])


def happo_error(
        data: namedtuple,
        clip_ratio: float = 0.2,
        use_value_clip: bool = True,
        dual_clip: Optional[float] = None,
) -> Tuple[namedtuple, namedtuple]:
    """
    Overview:
        Implementation of Proximal Policy Optimization (arXiv:1707.06347) with value_clip and dual_clip
    Arguments:
        - data (:obj:`namedtuple`): the ppo input data with fields shown in ``happo_data``
        - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2
        - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
        - dual_clip (:obj:`float`): the parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf);\
            5.0 is a common choice. If you don't want to use it, set this parameter to None (the default)
    Returns:
        - happo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
        - happo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
    Shapes:
        - logit_new (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
        - logit_old (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`
        - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - return (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - value_loss (:obj:`torch.FloatTensor`): :math:`()`
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> action_dim = 4
        >>> data = happo_data(
        >>>     logit_new=torch.randn(3, action_dim),
        >>>     logit_old=torch.randn(3, action_dim),
        >>>     action=torch.randint(0, action_dim, (3,)),
        >>>     value_new=torch.randn(3),
        >>>     value_old=torch.randn(3),
        >>>     adv=torch.randn(3),
        >>>     return_=torch.randn(3),
        >>>     weight=torch.ones(3),
        >>>     factor=torch.ones(3, 1),
        >>> )
        >>> loss, info = happo_error(data)

    .. note::

        adv is already normalized value (adv - adv.mean()) / (adv.std() + 1e-8), and there are many
        ways to calculate this mean and std, like among data buffer or train batch, so we don't couple
        this part into happo_error, you can refer to our examples for different ways.
    """
    assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
        dual_clip
    )
    logit_new, logit_old, action, value_new, value_old, adv, return_, weight, factor = data
    policy_data = happo_policy_data(logit_new, logit_old, action, adv, weight, factor)
    policy_output, policy_info = happo_policy_error(policy_data, clip_ratio, dual_clip)
    value_data = happo_value_data(value_new, value_old, return_, weight)
    value_loss = happo_value_error(value_data, clip_ratio, use_value_clip)

    return happo_loss(policy_output.policy_loss, value_loss, policy_output.entropy_loss), policy_info


def happo_policy_error(
        data: namedtuple,
        clip_ratio: float = 0.2,
        dual_clip: Optional[float] = None,
) -> Tuple[namedtuple, namedtuple]:
    '''
    Overview:
        Get PPO policy loss
    Arguments:
        - data (:obj:`namedtuple`): ppo input data with fields shown in ``happo_policy_data``
        - clip_ratio (:obj:`float`): clip value for ratio
        - dual_clip (:obj:`float`): the parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf);\
            5.0 is a common choice. If you don't want to use it, set this parameter to None (the default)
    Returns:
        - happo_policy_loss (:obj:`namedtuple`): the ppo policy loss item, all of them are the differentiable \
            0-dim tensor.
        - happo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
    Shapes:
        - logit_new (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
        - logit_old (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> action_dim = 4
        >>> data = happo_policy_data(
        >>>     logit_new=torch.randn(3, action_dim),
        >>>     logit_old=torch.randn(3, action_dim),
        >>>     action=torch.randint(0, action_dim, (3,)),
        >>>     adv=torch.randn(3),
        >>>     weight=torch.ones(3),
        >>>     factor=torch.ones(3, 1),
        >>> )
        >>> loss, info = happo_policy_error(data)
    '''
    logit_new, logit_old, action, adv, weight, factor = data
    if weight is None:
        weight = torch.ones_like(adv)
    dist_new = torch.distributions.categorical.Categorical(logits=logit_new)
    dist_old = torch.distributions.categorical.Categorical(logits=logit_old)
    logp_new = dist_new.log_prob(action)
    logp_old = dist_old.log_prob(action)
    dist_new_entropy = dist_new.entropy()
    if dist_new_entropy.shape != weight.shape:
        dist_new_entropy = dist_new.entropy().mean(dim=1)
    entropy_loss = (dist_new_entropy * weight).mean()
    # policy_loss
    ratio = torch.exp(logp_new - logp_old)
    if ratio.shape != adv.shape:
        ratio = ratio.mean(dim=1)
    surr1 = ratio * adv
    surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv
    # shape factor: (B,1) surr1: (B,)
    clip1 = torch.min(surr1, surr2) * factor.squeeze(1)
    if dual_clip is not None:
        clip2 = torch.max(clip1, dual_clip * adv)
        # only use dual_clip when adv < 0
        policy_loss = -(torch.where(adv < 0, clip2, clip1) * weight).mean()
    else:
        policy_loss = (-clip1 * weight).mean()
    with torch.no_grad():
        approx_kl = (logp_old - logp_new).mean().item()
        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
        clipfrac = torch.as_tensor(clipped).float().mean().item()
    return happo_policy_loss(policy_loss, entropy_loss), happo_info(approx_kl, clipfrac)


def happo_value_error(
        data: namedtuple,
        clip_ratio: float = 0.2,
        use_value_clip: bool = True,
) -> torch.Tensor:
    '''
    Overview:
        Get PPO value loss
    Arguments:
        - data (:obj:`namedtuple`): ppo input data with fields shown in ``happo_value_data``
        - clip_ratio (:obj:`float`): clip value for ratio
        - use_value_clip (:obj:`bool`): whether use value clip
    Returns:
        - value_loss (:obj:`torch.FloatTensor`): the ppo value loss item, \
            a differentiable 0-dim tensor
    Shapes:
        - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size
        - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`
        - return (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - value_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
    Examples:
        >>> data = happo_value_data(
        >>>     value_new=torch.randn(3),
        >>>     value_old=torch.randn(3),
        >>>     return_=torch.randn(3),
        >>>     weight=torch.ones(3),
        >>> )
        >>> loss = happo_value_error(data)
    '''
    value_new, value_old, return_, weight = data
    if weight is None:
        weight = torch.ones_like(value_old)
    # value_loss
    if use_value_clip:
        value_clip = value_old + (value_new - value_old).clamp(-clip_ratio, clip_ratio)
        v1 = (return_ - value_new).pow(2)
        v2 = (return_ - value_clip).pow(2)
        value_loss = 0.5 * (torch.max(v1, v2) * weight).mean()
    else:
        value_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean()
    return value_loss


def happo_error_continuous(
        data: namedtuple,
        clip_ratio: float = 0.2,
        use_value_clip: bool = True,
        dual_clip: Optional[float] = None,
) -> Tuple[namedtuple, namedtuple]:
    """
    Overview:
        Implementation of Proximal Policy Optimization (arXiv:1707.06347) with value_clip and dual_clip
    Arguments:
        - data (:obj:`namedtuple`): the ppo input data with fields shown in ``ppo_data``
        - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2
        - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
        - dual_clip (:obj:`float`): the parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf);\
            5.0 is a common choice. If you don't want to use it, set this parameter to None (the default)
    Returns:
        - happo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
        - happo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
    Shapes:
        - mu_sigma_new (:obj:`tuple`): :math:`((B, N), (B, N))`, where B is batch size and N is action dim
        - mu_sigma_old (:obj:`tuple`): :math:`((B, N), (B, N))`, where B is batch size and N is action dim
        - action (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`
        - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - return (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - value_loss (:obj:`torch.FloatTensor`): :math:`()`
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> action_dim = 4
        >>> data = ppo_data_continuous(
        >>>     mu_sigma_new=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
        >>>     mu_sigma_old=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
        >>>     action=torch.randn(3, action_dim),
        >>>     value_new=torch.randn(3),
        >>>     value_old=torch.randn(3),
        >>>     adv=torch.randn(3),
        >>>     return_=torch.randn(3),
        >>>     weight=torch.ones(3),
        >>> )
        >>> loss, info = happo_error_continuous(data)

    .. note::

        adv is already normalized value (adv - adv.mean()) / (adv.std() + 1e-8), and there are many
        ways to calculate this mean and std, like among data buffer or train batch, so we don't couple
        this part into happo_error, you can refer to our examples for different ways.
    """
    assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
        dual_clip
    )
    mu_sigma_new, mu_sigma_old, action, value_new, value_old, adv, return_, weight, factor_batch = data
    if weight is None:
        weight = torch.ones_like(adv)

    dist_new = Normal(mu_sigma_new['mu'], mu_sigma_new['sigma'])
    if len(mu_sigma_old['mu'].shape) == 1:
        dist_old = Normal(mu_sigma_old['mu'].unsqueeze(-1), mu_sigma_old['sigma'].unsqueeze(-1))
    else:
        dist_old = Normal(mu_sigma_old['mu'], mu_sigma_old['sigma'])
    logp_new = dist_new.log_prob(action)
    logp_old = dist_old.log_prob(action)
    entropy_loss = (dist_new.entropy() * weight.unsqueeze(1)).mean()

    # policy_loss
    ratio = torch.exp(logp_new - logp_old)
    ratio = torch.prod(ratio, dim=-1)
    surr1 = ratio * adv
    surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv
    if dual_clip is not None:
        # shape factor: (B,1) surr1: (B,)
        policy_loss = (-torch.max(factor_batch.squeeze(1) * torch.min(surr1, surr2), dual_clip * adv) * weight).mean()
    else:
        policy_loss = (-factor_batch.squeeze(1) * torch.min(surr1, surr2) * weight).mean()
    with torch.no_grad():
        approx_kl = (logp_old - logp_new).mean().item()
        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
        clipfrac = torch.as_tensor(clipped).float().mean().item()
    # value_loss
    if use_value_clip:
        value_clip = value_old + (value_new - value_old).clamp(-clip_ratio, clip_ratio)
        v1 = (return_ - value_new).pow(2)
        v2 = (return_ - value_clip).pow(2)
        value_loss = 0.5 * (torch.max(v1, v2) * weight).mean()
    else:
        value_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean()

    return happo_loss(policy_loss, value_loss, entropy_loss), happo_info(approx_kl, clipfrac)


def happo_policy_error_continuous(data: namedtuple,
                                  clip_ratio: float = 0.2,
                                  dual_clip: Optional[float] = None) -> Tuple[namedtuple, namedtuple]:
    """
    Overview:
        Implementation of Proximal Policy Optimization (arXiv:1707.06347) with dual_clip
    Arguments:
        - data (:obj:`namedtuple`): the ppo input data with fields shown in ``ppo_data``
        - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2
        - dual_clip (:obj:`float`): the parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf);\
            5.0 is a common choice. If you don't want to use it, set this parameter to None (the default)
    Returns:
        - happo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
        - happo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
    Shapes:
        - mu_sigma_new (:obj:`tuple`): :math:`((B, N), (B, N))`, where B is batch size and N is action dim
        - mu_sigma_old (:obj:`tuple`): :math:`((B, N), (B, N))`, where B is batch size and N is action dim
        - action (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> action_dim = 4
        >>> data = ppo_policy_data_continuous(
        >>>     mu_sigma_new=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
        >>>     mu_sigma_old=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
        >>>     action=torch.randn(3, action_dim),
        >>>     adv=torch.randn(3),
        >>>     weight=torch.ones(3),
        >>> )
        >>> loss, info = happo_policy_error_continuous(data)
    """
    assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
        dual_clip
    )
    mu_sigma_new, mu_sigma_old, action, adv, weight = data
    if weight is None:
        weight = torch.ones_like(adv)

    dist_new = Independent(Normal(mu_sigma_new['mu'], mu_sigma_new['sigma']), 1)
    if len(mu_sigma_old['mu'].shape) == 1:
        dist_old = Independent(Normal(mu_sigma_old['mu'].unsqueeze(-1), mu_sigma_old['sigma'].unsqueeze(-1)), 1)
    else:
        dist_old = Independent(Normal(mu_sigma_old['mu'], mu_sigma_old['sigma']), 1)
    logp_new = dist_new.log_prob(action)
    logp_old = dist_old.log_prob(action)
    entropy_loss = (dist_new.entropy() * weight).mean()
    # policy_loss
    ratio = torch.exp(logp_new - logp_old)
    surr1 = ratio * adv
    surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv
    if dual_clip is not None:
        policy_loss = (-torch.max(torch.min(surr1, surr2), dual_clip * adv) * weight).mean()
    else:
        policy_loss = (-torch.min(surr1, surr2) * weight).mean()
    with torch.no_grad():
        approx_kl = (logp_old - logp_new).mean().item()
        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
        clipfrac = torch.as_tensor(clipped).float().mean().item()
    return happo_policy_loss(policy_loss, entropy_loss), happo_info(approx_kl, clipfrac)