ding.rl_utils.happo

happo_error(data, clip_ratio=0.2, use_value_clip=True, dual_clip=None)
Overview
Implementation of the HAPPO loss, based on Proximal Policy Optimization (arXiv:1707.06347), with value_clip and dual_clip
Arguments:
- data (:obj:namedtuple): the HAPPO input data with fields shown in happo_data
- clip_ratio (:obj:float): the ppo clip ratio for the constraint of policy update, defaults to 0.2
- use_value_clip (:obj:bool): whether to use clip in value loss with the same ratio as policy
- dual_clip (:obj:float): a parameter c mentioned in arXiv:1912.09729 Eq. 5, should be in [1, inf); set it to None (the default) to disable dual clip
Returns:
- happo_loss (:obj:namedtuple): the HAPPO loss items; all of them are differentiable 0-dim tensors
- happo_info (:obj:namedtuple): the HAPPO optim information for monitoring; all of them are Python scalars
Shapes:
- logit_new (:obj:torch.FloatTensor): :math:(B, N), where B is batch size and N is action dim
- logit_old (:obj:torch.FloatTensor): :math:(B, N)
- action (:obj:torch.LongTensor): :math:(B, )
- value_new (:obj:torch.FloatTensor): :math:(B, )
- value_old (:obj:torch.FloatTensor): :math:(B, )
- adv (:obj:torch.FloatTensor): :math:(B, )
- return (:obj:torch.FloatTensor): :math:(B, )
- weight (:obj:torch.FloatTensor or :obj:None): :math:(B, )
- policy_loss (:obj:torch.FloatTensor): :math:(), 0-dim tensor
- value_loss (:obj:torch.FloatTensor): :math:()
- entropy_loss (:obj:torch.FloatTensor): :math:()
Examples:
>>> action_dim = 4
>>> data = happo_data(
>>> logit_new=torch.randn(3, action_dim),
>>> logit_old=torch.randn(3, action_dim),
>>> action=torch.randint(0, action_dim, (3,)),
>>> value_new=torch.randn(3),
>>> value_old=torch.randn(3),
>>> adv=torch.randn(3),
>>> return_=torch.randn(3),
>>> weight=torch.ones(3),
>>> factor=torch.ones(3, 1),
>>> )
>>> loss, info = happo_error(data)
.. note::
adv is the already-normalized advantage, i.e. (adv - adv.mean()) / (adv.std() + 1e-8). There are many
ways to compute this mean and std, e.g., over the whole data buffer or over each train batch, so this
normalization is not coupled into happo_error; refer to our examples for the different ways.
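For reference, a minimal sketch of the batch-level normalization the note describes; the buffer-level variant only changes which tensor the statistics are computed over:

import torch

def normalize_adv(adv: torch.Tensor) -> torch.Tensor:
    # Per-train-batch normalization, as described in the note above; the same
    # formula can instead be applied over a whole data buffer.
    return (adv - adv.mean()) / (adv.std() + 1e-8)

adv = normalize_adv(torch.randn(3))  # pass this as the adv field of happo_data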
happo_policy_error(data, clip_ratio=0.2, dual_clip=None)
Overview
Get the HAPPO policy loss
Arguments:
- data (:obj:namedtuple): the HAPPO policy input data with fields shown in happo_policy_data
- clip_ratio (:obj:float): clip value for ratio
- dual_clip (:obj:float): a parameter c mentioned in arXiv:1912.09729 Eq. 5, should be in [1, inf); set it to None (the default) to disable dual clip
Returns:
- happo_policy_loss (:obj:namedtuple): the HAPPO policy loss items; all of them are differentiable 0-dim tensors
- happo_info (:obj:namedtuple): the HAPPO optim information for monitoring; all of them are Python scalars
Shapes:
- logit_new (:obj:torch.FloatTensor): :math:(B, N), where B is batch size and N is action dim
- logit_old (:obj:torch.FloatTensor): :math:(B, N)
- action (:obj:torch.LongTensor): :math:(B, )
- adv (:obj:torch.FloatTensor): :math:(B, )
- weight (:obj:torch.FloatTensor or :obj:None): :math:(B, )
- policy_loss (:obj:torch.FloatTensor): :math:(), 0-dim tensor
- entropy_loss (:obj:torch.FloatTensor): :math:()
Examples:
>>> action_dim = 4
>>> data = happo_policy_data(
>>> logit_new=torch.randn(3, action_dim),
>>> logit_old=torch.randn(3, action_dim),
>>> action=torch.randint(0, action_dim, (3,)),
>>> adv=torch.randn(3),
>>> weight=torch.ones(3),
>>> factor=torch.ones(3, 1),
>>> )
>>> loss, info = happo_policy_error(data)
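For intuition, the following sketch shows the standard PPO clipped surrogate plus the dual clip of arXiv:1912.09729 that these arguments control; the handling of the HAPPO factor is an illustrative assumption, not the exact library code:

import torch

def clipped_surrogate_loss(ratio, adv, clip_ratio=0.2, dual_clip=None, factor=None):
    # PPO clipped objective: min(r * A, clip(r, 1-eps, 1+eps) * A).
    surr1 = ratio * adv
    surr2 = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
    obj = torch.min(surr1, surr2)
    if dual_clip is not None:
        # Dual clip (arXiv:1912.09729 Eq. 5): lower-bound the objective by c * A,
        # which only bites when the advantage is negative.
        obj = torch.where(adv < 0, torch.max(obj, dual_clip * adv), obj)
    if factor is not None:
        # HAPPO-style per-sample factor from the sequential agent updates
        # (illustrative placement).
        obj = factor.squeeze(-1) * obj
    return -obj.mean()  # the loss is the negative objective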
happo_value_error(data, clip_ratio=0.2, use_value_clip=True)
Overview
Get the HAPPO value loss
Arguments:
- data (:obj:namedtuple): the HAPPO value input data with fields shown in happo_value_data
- clip_ratio (:obj:float): clip value for ratio
- use_value_clip (:obj:bool): whether to use value clip
Returns:
- value_loss (:obj:torch.FloatTensor): the value loss, a differentiable 0-dim tensor
Shapes:
- value_new (:obj:torch.FloatTensor): :math:(B, ), where B is batch size
- value_old (:obj:torch.FloatTensor): :math:(B, )
- return (:obj:torch.FloatTensor): :math:(B, )
- weight (:obj:torch.FloatTensor or :obj:None): :math:(B, )
- value_loss (:obj:torch.FloatTensor): :math:(), 0-dim tensor
Examples:
>>> data = happo_value_data(
>>> value_new=torch.randn(3),
>>> value_old=torch.randn(3),
>>> return_=torch.randn(3),
>>> weight=torch.ones(3),
>>> )
>>> loss = happo_value_error(data)
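A sketch of the standard PPO-style value clipping that use_value_clip enables; the function name and the 0.5 scaling are illustrative assumptions:

import torch

def clipped_value_loss(value_new, value_old, return_, clip_ratio=0.2):
    # Unclipped squared error against the return target.
    v_loss1 = (value_new - return_).pow(2)
    # Keep the new value within clip_ratio of the old value estimate.
    value_clipped = value_old + (value_new - value_old).clamp(-clip_ratio, clip_ratio)
    v_loss2 = (value_clipped - return_).pow(2)
    # The elementwise max makes the update pessimistic; average over the batch.
    return 0.5 * torch.max(v_loss1, v_loss2).mean()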
happo_error_continuous(data, clip_ratio=0.2, use_value_clip=True, dual_clip=None)
Overview
Implementation of the HAPPO loss for continuous action spaces, based on Proximal Policy Optimization (arXiv:1707.06347), with value_clip and dual_clip
Arguments:
- data (:obj:namedtuple): the HAPPO input data with fields shown in ppo_data_continuous
- clip_ratio (:obj:float): the ppo clip ratio for the constraint of policy update, defaults to 0.2
- use_value_clip (:obj:bool): whether to use clip in value loss with the same ratio as policy
- dual_clip (:obj:float): a parameter c mentioned in arXiv:1912.09729 Eq. 5, should be in [1, inf); set it to None (the default) to disable dual clip
Returns:
- happo_loss (:obj:namedtuple): the HAPPO loss items; all of them are differentiable 0-dim tensors
- happo_info (:obj:namedtuple): the HAPPO optim information for monitoring; all of them are Python scalars
Shapes:
- mu_sigma_new (:obj:tuple): :math:((B, N), (B, N)), where B is batch size and N is action dim
- mu_sigma_old (:obj:tuple): :math:((B, N), (B, N)), where B is batch size and N is action dim
- action (:obj:torch.FloatTensor): :math:(B, N)
- value_new (:obj:torch.FloatTensor): :math:(B, )
- value_old (:obj:torch.FloatTensor): :math:(B, )
- adv (:obj:torch.FloatTensor): :math:(B, )
- return (:obj:torch.FloatTensor): :math:(B, )
- weight (:obj:torch.FloatTensor or :obj:None): :math:(B, )
- policy_loss (:obj:torch.FloatTensor): :math:(), 0-dim tensor
- value_loss (:obj:torch.FloatTensor): :math:()
- entropy_loss (:obj:torch.FloatTensor): :math:()
Examples:
>>> action_dim = 4
>>> data = ppo_data_continuous(
>>> mu_sigma_new=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
>>> mu_sigma_old=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
>>> action=torch.randn(3, action_dim),
>>> value_new=torch.randn(3),
>>> value_old=torch.randn(3),
>>> adv=torch.randn(3),
>>> return_=torch.randn(3),
>>> weight=torch.ones(3),
>>> )
>>> loss, info = happo_error_continuous(data)
.. note::
adv is the already-normalized advantage, i.e. (adv - adv.mean()) / (adv.std() + 1e-8). There are many
ways to compute this mean and std, e.g., over the whole data buffer or over each train batch, so this
normalization is not coupled into happo_error_continuous; refer to our examples for the different ways.
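For the continuous variants, mu_sigma_new/mu_sigma_old parameterize diagonal Gaussian policies; below is a sketch of how an importance ratio would be formed from them (the Independent-Normal construction is an assumption for illustration):

import torch
from torch.distributions import Independent, Normal

B, N = 3, 4
mu_new, sigma_new = torch.randn(B, N), torch.randn(B, N) ** 2 + 1e-3
mu_old, sigma_old = torch.randn(B, N), torch.randn(B, N) ** 2 + 1e-3

# Treat the N action dims as one joint event, giving per-sample log-probs.
dist_new = Independent(Normal(mu_new, sigma_new), 1)
dist_old = Independent(Normal(mu_old, sigma_old), 1)
action = dist_old.sample()  # shape (B, N)
# Importance ratio pi_new(a|s) / pi_old(a|s), computed in log space for stability.
ratio = (dist_new.log_prob(action) - dist_old.log_prob(action)).exp()  # shape (B,)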
happo_policy_error_continuous(data, clip_ratio=0.2, dual_clip=None)
Overview
Implementation of the HAPPO policy loss for continuous action spaces, based on Proximal Policy Optimization (arXiv:1707.06347), with dual_clip
Arguments:
- data (:obj:namedtuple): the HAPPO policy input data with fields shown in ppo_policy_data_continuous
- clip_ratio (:obj:float): the ppo clip ratio for the constraint of policy update, defaults to 0.2
- dual_clip (:obj:float): a parameter c mentioned in arXiv:1912.09729 Eq. 5, should be in [1, inf); set it to None (the default) to disable dual clip
Returns:
- happo_policy_loss (:obj:namedtuple): the HAPPO policy loss items; all of them are differentiable 0-dim tensors
- happo_info (:obj:namedtuple): the HAPPO optim information for monitoring; all of them are Python scalars
Shapes:
- mu_sigma_new (:obj:tuple): :math:((B, N), (B, N)), where B is batch size and N is action dim
- mu_sigma_old (:obj:tuple): :math:((B, N), (B, N)), where B is batch size and N is action dim
- action (:obj:torch.FloatTensor): :math:(B, N)
- adv (:obj:torch.FloatTensor): :math:(B, )
- weight (:obj:torch.FloatTensor or :obj:None): :math:(B, )
- policy_loss (:obj:torch.FloatTensor): :math:(), 0-dim tensor
- entropy_loss (:obj:torch.FloatTensor): :math:()
Examples:
>>> action_dim = 4
>>> data = ppo_policy_data_continuous(
>>> mu_sigma_new=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
>>> mu_sigma_old=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
>>> action=torch.randn(3, action_dim),
>>> adv=torch.randn(3),
>>> weight=torch.ones(3),
>>> )
>>> loss, info = happo_policy_error_continuous(data)
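Putting it together, a runnable end-to-end sketch for the discrete-action case; the import path follows this page's module name, and the 0.5/0.01 loss weights are illustrative assumptions:

import torch
from ding.rl_utils.happo import happo_data, happo_error  # path assumed from this page

B, N = 3, 4
data = happo_data(
    logit_new=torch.randn(B, N, requires_grad=True),  # from the policy net in real use
    logit_old=torch.randn(B, N),
    action=torch.randint(0, N, (B,)),
    value_new=torch.randn(B, requires_grad=True),  # from the value net in real use
    value_old=torch.randn(B),
    adv=torch.randn(B),
    return_=torch.randn(B),
    weight=torch.ones(B),
    factor=torch.ones(B, 1),
)
loss, info = happo_error(data, clip_ratio=0.2, use_value_clip=True, dual_clip=None)
# Combine the loss items with illustrative weights, then backpropagate.
total = loss.policy_loss + 0.5 * loss.value_loss - 0.01 * loss.entropy_loss
total.backward()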
Full Source Code
ding/rl_utils/happo.py