
ding.rl_utils.gae

shape_fn_gae(args, kwargs)

Overview

Return the shape of the gae input data, used by the HPC wrapper; the shape is read from the reward field.

Returns: shape: [T, B], where T is the trajectory length and B is the batch size.
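
A minimal illustration (with gae_data and torch as defined in the full source below; the tensor sizes are arbitrary):

>>> data = gae_data(torch.randn(2, 3), torch.randn(2, 3), torch.randn(2, 3), None, None)
>>> shape_fn_gae((data,), {})
torch.Size([2, 3])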

gae(data, gamma=0.99, lambda_=0.97)

Overview

Implementation of the Generalized Advantage Estimator (GAE), following arXiv:1506.02438.

Arguments:
- data (:obj:`namedtuple`): gae input data with fields ['value', 'next_value', 'reward', 'done', 'traj_flag'], containing data from one or more episodes or trajectories.
- gamma (:obj:`float`): the future discount factor, should be in [0, 1], defaults to 0.99.
- lambda_ (:obj:`float`): the GAE trade-off parameter, should be in [0, 1], defaults to 0.97. As lambda_ -> 0 the estimator reduces to the one-step TD error (more bias, less variance); as lambda_ -> 1 it approaches the Monte Carlo advantage, whose long sum of terms gives high variance.

Returns:
- adv (:obj:`torch.FloatTensor`): the calculated advantage

Shapes:
- value (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is trajectory length and B is batch size
- next_value (:obj:`torch.FloatTensor`): :math:`(T, B)`
- reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
- adv (:obj:`torch.FloatTensor`): :math:`(T, B)`

Examples:
>>> value = torch.randn(2, 3)
>>> next_value = torch.randn(2, 3)
>>> reward = torch.randn(2, 3)
>>> data = gae_data(value, next_value, reward, None, None)
>>> adv = gae(data)
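
For reference, the quantities this function computes are the TD residual and its exponentially weighted sum, as defined in the paper:

.. math::

    \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t), \qquad
    \hat{A}_t^{GAE(\gamma, \lambda)} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \, \delta_{t+l}

The implementation evaluates this with the equivalent backward recursion :math:`\hat{A}_t = \delta_t + \gamma \lambda \hat{A}_{t+1}`, truncated at the end of each trajectory (the done and traj_flag masks zero out the bootstrap value and cut the recursion across trajectory boundaries).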

Full Source Code

../ding/rl_utils/gae.py

from collections import namedtuple
import torch
from ding.hpc_rl import hpc_wrapper

gae_data = namedtuple('gae_data', ['value', 'next_value', 'reward', 'done', 'traj_flag'])


def shape_fn_gae(args, kwargs):
    r"""
    Overview:
        Return the shape of the gae input data for the HPC wrapper.
    Returns:
        shape: [T, B]
    """
    if len(args) <= 0:
        tmp = kwargs['data'].reward.shape
    else:
        tmp = args[0].reward.shape
    return tmp


@hpc_wrapper(
    shape_fn=shape_fn_gae, namedtuple_data=True, include_args=[0, 1, 2], include_kwargs=['data', 'gamma', 'lambda_']
)
def gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97) -> torch.FloatTensor:
    """
    Overview:
        Implementation of the Generalized Advantage Estimator (arXiv:1506.02438).
    Arguments:
        - data (:obj:`namedtuple`): gae input data with fields ['value', 'next_value', 'reward', 'done', \
            'traj_flag'], containing data from one or more episodes or trajectories.
        - gamma (:obj:`float`): the future discount factor, should be in [0, 1], defaults to 0.99.
        - lambda_ (:obj:`float`): the GAE trade-off parameter, should be in [0, 1], defaults to 0.97. As \
            lambda_ -> 0 the estimator reduces to the one-step TD error (more bias, less variance); as \
            lambda_ -> 1 it approaches the Monte Carlo advantage, whose long sum of terms gives high variance.
    Returns:
        - adv (:obj:`torch.FloatTensor`): the calculated advantage
    Shapes:
        - value (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is trajectory length and B is batch size
        - next_value (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - adv (:obj:`torch.FloatTensor`): :math:`(T, B)`
    Examples:
        >>> value = torch.randn(2, 3)
        >>> next_value = torch.randn(2, 3)
        >>> reward = torch.randn(2, 3)
        >>> data = gae_data(value, next_value, reward, None, None)
        >>> adv = gae(data)
    """
    value, next_value, reward, done, traj_flag = data
    if done is None:
        done = torch.zeros_like(reward, device=reward.device)
    if traj_flag is None:
        traj_flag = done
    done = done.float()
    traj_flag = traj_flag.float()
    if len(value.shape) == len(reward.shape) + 1:  # for some marl case: value(T, B, A), reward(T, B)
        reward = reward.unsqueeze(-1)
        done = done.unsqueeze(-1)
        traj_flag = traj_flag.unsqueeze(-1)

    # zero out the bootstrap value at terminal steps, then form the TD residual delta_t
    next_value *= (1 - done)
    delta = reward + gamma * next_value - value
    # the (gamma * lambda_) discount is cut off at trajectory boundaries via traj_flag
    factor = gamma * lambda_ * (1 - traj_flag)
    adv = torch.zeros_like(value)
    gae_item = torch.zeros_like(value[0])

    # accumulate the discounted sum of residuals from the last timestep backwards
    for t in reversed(range(reward.shape[0])):
        gae_item = delta[t] + factor[t] * gae_item
        adv[t] = gae_item
    return adv
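
As a quick usage sketch, the backward recursion can be checked against the direct (gamma * lambda_)-weighted sum of TD residuals from the paper. The tensor sizes and tolerance below are illustrative choices, not part of the library:

import torch
from ding.rl_utils.gae import gae, gae_data

T, B = 4, 2  # illustrative trajectory length and batch size
gamma, lambda_ = 0.99, 0.97
value = torch.randn(T, B)
next_value = torch.randn(T, B)
reward = torch.randn(T, B)

# done/traj_flag left as None: a single untruncated trajectory per batch column
adv = gae(gae_data(value, next_value, reward, None, None), gamma=gamma, lambda_=lambda_)

# naive reference: A_t = sum_{l=0}^{T-1-t} (gamma * lambda_)**l * delta_{t+l}
delta = reward + gamma * next_value - value
ref = torch.zeros_like(value)
for t in range(T):
    for l in range(T - t):
        ref[t] += (gamma * lambda_) ** l * delta[t + l]

assert torch.allclose(adv, ref, atol=1e-5)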