
ding.rl_utils.td


q_1step_td_error(data, gamma, criterion=nn.MSELoss(reduction='none'))

Overview

1-step td_error, supporting both the single-agent and multi-agent cases.

Arguments:
- data (:obj:`q_1step_td_data`): The input data, q_1step_td_data to calculate loss
- gamma (:obj:`float`): Discount factor
- criterion (:obj:`torch.nn.modules`): Loss function criterion

Returns:
- loss (:obj:`torch.Tensor`): 1-step td error

Shapes:
- data (:obj:`q_1step_td_data`): the q_1step_td_data containing ['q', 'next_q', 'act', 'next_act', 'reward', 'done', 'weight']
- q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
- next_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
- act (:obj:`torch.LongTensor`): :math:`(B, )`
- next_act (:obj:`torch.LongTensor`): :math:`(B, )`
- reward (:obj:`torch.FloatTensor`): :math:`(B, )`
- done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep
- weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight

Examples:
>>> action_dim = 4
>>> data = q_1step_td_data(
>>>     q=torch.randn(3, action_dim),
>>>     next_q=torch.randn(3, action_dim),
>>>     act=torch.randint(0, action_dim, (3,)),
>>>     next_act=torch.randint(0, action_dim, (3,)),
>>>     reward=torch.randn(3),
>>>     done=torch.randint(0, 2, (3,)).bool(),
>>>     weight=torch.ones(3),
>>> )
>>> loss = q_1step_td_error(data, 0.99)
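The target that q_1step_td_error regresses toward is r + gamma * (1 - done) * Q'(s', a'). A framework-free sketch of that target for a single transition (a hypothetical scalar helper, not part of ding, which works on batched tensors):

```python
def one_step_td_target(reward, next_q_sa, done, gamma):
    # r + gamma * (1 - done) * Q'(s', a'); bootstrapping is masked out at terminal states
    return reward + gamma * (0.0 if done else 1.0) * next_q_sa

# non-terminal transition: bootstrap from the next state's Q-value
t0 = one_step_td_target(reward=1.0, next_q_sa=2.0, done=False, gamma=0.9)  # 1.0 + 0.9 * 2.0 = 2.8
# terminal transition: the target collapses to the immediate reward
t1 = one_step_td_target(reward=1.0, next_q_sa=2.0, done=True, gamma=0.9)   # 1.0
```

The library version additionally detaches the target and applies the per-sample criterion and weight before averaging.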

m_q_1step_td_error(data, gamma, tau, alpha, criterion=nn.MSELoss(reduction='none'))

Overview

Munchausen td_error for the DQN algorithm, supporting 1-step td error.

Arguments:
- data (:obj:`m_q_1step_td_data`): The input data, m_q_1step_td_data to calculate loss
- gamma (:obj:`float`): Discount factor
- tau (:obj:`float`): Entropy factor for Munchausen DQN
- alpha (:obj:`float`): Discount factor for the Munchausen term
- criterion (:obj:`torch.nn.modules`): Loss function criterion

Returns:
- loss (:obj:`torch.Tensor`): 1-step td error, 0-dim tensor

Shapes:
- data (:obj:`m_q_1step_td_data`): the m_q_1step_td_data containing ['q', 'target_q', 'next_q', 'act', 'reward', 'done', 'weight']
- q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
- target_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
- next_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
- act (:obj:`torch.LongTensor`): :math:`(B, )`
- reward (:obj:`torch.FloatTensor`): :math:`(B, )`
- done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep
- weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight

Examples:
>>> action_dim = 4
>>> data = m_q_1step_td_data(
>>>     q=torch.randn(3, action_dim),
>>>     target_q=torch.randn(3, action_dim),
>>>     next_q=torch.randn(3, action_dim),
>>>     act=torch.randint(0, action_dim, (3,)),
>>>     reward=torch.randn(3),
>>>     done=torch.randint(0, 2, (3,)),
>>>     weight=torch.ones(3),
>>> )
>>> loss = m_q_1step_td_error(data, 0.99, 0.01, 0.01)
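Munchausen DQN augments the reward with a scaled, clipped log-policy bonus alpha * clip(tau * log pi(a|s), lower_bound, 0), where pi = softmax(Q / tau). A plain-Python sketch of that bonus for a single state (hypothetical helper; the library computes it in batch with torch.logsumexp):

```python
import math

def munchausen_bonus(q_row, act, tau, alpha, lower_bound=-1.0):
    # tau * log pi(a|s) for pi = softmax(q / tau), computed with the log-sum-exp trick
    m = max(q_row)
    tau_logsumexp = tau * math.log(sum(math.exp((q - m) / tau) for q in q_row)) + m
    tau_log_pi = q_row[act] - tau_logsumexp  # always <= 0
    # clip to [lower_bound, 0] before scaling, as in the Munchausen DQN paper
    return alpha * min(max(tau_log_pi, lower_bound), 0.0)

# two equally good actions: tau * log pi(a|s) = -tau * log(2)
bonus = munchausen_bonus([1.0, 1.0], act=0, tau=1.0, alpha=0.9)
```

The clipping (lower_bound = -1 in the source below) keeps the bonus from diverging when the policy probability of the replayed action approaches zero.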

q_v_1step_td_error(data, gamma, criterion=nn.MSELoss(reduction='none'))

Overview

td_error between q and v values for the SAC algorithm, supporting 1-step td error.

Arguments:
- data (:obj:`q_v_1step_td_data`): The input data, q_v_1step_td_data to calculate loss
- gamma (:obj:`float`): Discount factor
- criterion (:obj:`torch.nn.modules`): Loss function criterion

Returns:
- loss (:obj:`torch.Tensor`): 1-step td error, 0-dim tensor

Shapes:
- data (:obj:`q_v_1step_td_data`): the q_v_1step_td_data containing ['q', 'v', 'act', 'reward', 'done', 'weight']
- q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
- v (:obj:`torch.FloatTensor`): :math:`(B, )`
- act (:obj:`torch.LongTensor`): :math:`(B, )`
- reward (:obj:`torch.FloatTensor`): :math:`(B, )`
- done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep
- weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight

Examples:
>>> action_dim = 4
>>> data = q_v_1step_td_data(
>>>     q=torch.randn(3, action_dim),
>>>     v=torch.randn(3),
>>>     act=torch.randint(0, action_dim, (3,)),
>>>     reward=torch.randn(3),
>>>     done=torch.randint(0, 2, (3,)),
>>>     weight=torch.ones(3),
>>> )
>>> loss = q_v_1step_td_error(data, 0.99)

nstep_return(data, gamma, nstep, value_gamma=None)

Overview

Calculate the n-step return for the DQN algorithm, supporting both the single-agent and multi-agent cases.

Arguments:
- data (:obj:`nstep_return_data`): The input data, nstep_return_data to calculate the return
- gamma (:obj:`float`): Discount factor
- nstep (:obj:`int`): nstep num
- value_gamma (:obj:`torch.Tensor`): Discount factor for value

Returns:
- return (:obj:`torch.Tensor`): nstep return

Shapes:
- data (:obj:`nstep_return_data`): the nstep_return_data containing ['reward', 'next_value', 'done']
- reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep (nstep)
- next_value (:obj:`torch.FloatTensor`): :math:`(B, )`
- done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep

Examples:
>>> data = nstep_return_data(
>>>     reward=torch.randn(3, 3),
>>>     next_value=torch.randn(3),
>>>     done=torch.randint(0, 2, (3,)),
>>> )
>>> return_ = nstep_return(data, 0.99, 3)
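The n-step return combines the first n discounted rewards with a bootstrapped value: sum_{t<n} gamma^t * r_t + gamma^n * V(s_n) * (1 - done). A minimal sketch for one sample (hypothetical helper, scalar rather than batched, ignoring the value_gamma override):

```python
def n_step_return(rewards, next_value, done, gamma):
    # sum of discounted rewards over the n collected steps ...
    ret = sum((gamma ** t) * r for t, r in enumerate(rewards))
    # ... plus the bootstrapped tail, discounted by gamma ** n and masked at terminals
    n = len(rewards)
    return ret + (gamma ** n) * next_value * (0.0 if done else 1.0)

g = n_step_return([1.0, 1.0, 1.0], next_value=10.0, done=False, gamma=0.5)
# 1 + 0.5 + 0.25 + 0.125 * 10 = 3.0
```

The value_gamma argument of the library function replaces gamma ** n when the trajectory tail in the buffer is shorter than n steps.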

dist_1step_td_error(data, gamma, v_min, v_max, n_atom)

Overview

1-step td_error for distributional q-learning based algorithms (e.g. C51).

Arguments:
- data (:obj:`dist_1step_td_data`): The input data, dist_1step_td_data to calculate loss
- gamma (:obj:`float`): Discount factor
- v_min (:obj:`float`): The min value of support
- v_max (:obj:`float`): The max value of support
- n_atom (:obj:`int`): The num of atoms

Returns:
- loss (:obj:`torch.Tensor`): 1-step td error, 0-dim tensor

Shapes:
- data (:obj:`dist_1step_td_data`): the dist_1step_td_data containing ['dist', 'next_dist', 'act', 'next_act', 'reward', 'done', 'weight']
- dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)` i.e. [batch_size, action_dim, n_atom]
- next_dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)`
- act (:obj:`torch.LongTensor`): :math:`(B, )`
- next_act (:obj:`torch.LongTensor`): :math:`(B, )`
- reward (:obj:`torch.FloatTensor`): :math:`(B, )`
- done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep
- weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight

Examples:
>>> dist = torch.randn(4, 3, 51).abs().requires_grad_(True)
>>> next_dist = torch.randn(4, 3, 51).abs()
>>> act = torch.randint(0, 3, (4,))
>>> next_act = torch.randint(0, 3, (4,))
>>> reward = torch.randn(4)
>>> done = torch.randint(0, 2, (4,))
>>> data = dist_1step_td_data(dist, next_dist, act, next_act, reward, done, None)
>>> loss = dist_1step_td_error(data, 0.99, -10.0, 10.0, 51)
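The core of the distributional 1-step update is projecting the shifted support r + gamma * z_j back onto the fixed atoms between v_min and v_max. A scalar sketch of that categorical projection for a single sample (hypothetical helper; the library version is batched in torch):

```python
import math

def project_distribution(next_dist, reward, done, gamma, v_min, v_max, n_atom):
    # split each atom's probability between the two nearest fixed atoms (C51 projection)
    delta_z = (v_max - v_min) / (n_atom - 1)
    proj = [0.0] * n_atom
    for j, p in enumerate(next_dist):
        z_j = v_min + j * delta_z
        # shift the atom by the Bellman backup, clamped to the support range
        tz = min(max(reward + gamma * (0.0 if done else 1.0) * z_j, v_min), v_max)
        b = (tz - v_min) / delta_z  # fractional atom index of the shifted value
        l, u = int(math.floor(b)), int(math.ceil(b))
        if l == u:
            proj[l] += p
        else:
            proj[l] += p * (u - b)
            proj[u] += p * (b - l)
    return proj

# support {-1, 0, 1}: a point mass at 0 shifted by r = 0.5 splits between atoms 0 and 1
proj = project_distribution([0.0, 1.0, 0.0], reward=0.5, done=False, gamma=1.0,
                            v_min=-1.0, v_max=1.0, n_atom=3)
```

Probability mass is conserved by the projection; the resulting distribution is the cross-entropy target for dist.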

shape_fn_dntd(args, kwargs)

Overview

Return dntd shape for hpc

Returns: shape: [T, B, N, n_atom]

dist_nstep_td_error(data, gamma, v_min, v_max, n_atom, nstep=1, value_gamma=None)

Overview

Multistep (1-step or n-step) td_error for distributional q-learning based algorithms, supporting both the single-agent and multi-agent cases.

Arguments:
- data (:obj:`dist_nstep_td_data`): The input data, dist_nstep_td_data to calculate loss
- gamma (:obj:`float`): Discount factor
- v_min (:obj:`float`): The min value of support
- v_max (:obj:`float`): The max value of support
- n_atom (:obj:`int`): The num of atoms
- nstep (:obj:`int`): nstep num, default set to 1
- value_gamma (:obj:`torch.Tensor` or None): Discount factor for value

Returns:
- loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor

Shapes:
- data (:obj:`dist_nstep_td_data`): the dist_nstep_td_data containing ['dist', 'next_n_dist', 'act', 'next_n_act', 'reward', 'done', 'weight']
- dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)` i.e. [batch_size, action_dim, n_atom]
- next_n_dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)`
- act (:obj:`torch.LongTensor`): :math:`(B, )`
- next_n_act (:obj:`torch.LongTensor`): :math:`(B, )`
- reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep (nstep)
- done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep

Examples:
>>> dist = torch.randn(4, 3, 51).abs().requires_grad_(True)
>>> next_n_dist = torch.randn(4, 3, 51).abs()
>>> done = torch.randn(4)
>>> action = torch.randint(0, 3, size=(4, ))
>>> next_action = torch.randint(0, 3, size=(4, ))
>>> reward = torch.randn(5, 4)
>>> data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, None)
>>> loss, _ = dist_nstep_td_error(data, 0.95, -10.0, 10.0, 51, 5)

v_1step_td_error(data, gamma, criterion=nn.MSELoss(reduction='none'))

Overview

1-step td_error for value-based algorithms.

Arguments:
- data (:obj:`v_1step_td_data`): The input data, v_1step_td_data to calculate loss
- gamma (:obj:`float`): Discount factor
- criterion (:obj:`torch.nn.modules`): Loss function criterion

Returns:
- loss (:obj:`torch.Tensor`): 1-step td error, 0-dim tensor

Shapes:
- data (:obj:`v_1step_td_data`): the v_1step_td_data containing ['v', 'next_v', 'reward', 'done', 'weight']
- v (:obj:`torch.FloatTensor`): :math:`(B, )` i.e. [batch_size, ]
- next_v (:obj:`torch.FloatTensor`): :math:`(B, )`
- reward (:obj:`torch.FloatTensor`): :math:`(B, )`
- done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep
- weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight

Examples:
>>> v = torch.randn(5).requires_grad_(True)
>>> next_v = torch.randn(5)
>>> reward = torch.rand(5)
>>> done = torch.zeros(5)
>>> data = v_1step_td_data(v, next_v, reward, done, None)
>>> loss, td_error_per_sample = v_1step_td_error(data, 0.99)

v_nstep_td_error(data, gamma, nstep=1, criterion=nn.MSELoss(reduction='none'))

Overview

Multistep (n-step) td_error for value-based algorithms.

Arguments:
- data (:obj:`v_nstep_td_data`): The input data, v_nstep_td_data to calculate loss
- gamma (:obj:`float`): Discount factor
- nstep (:obj:`int`): nstep num, default set to 1
- criterion (:obj:`torch.nn.modules`): Loss function criterion

Returns:
- loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor

Shapes:
- data (:obj:`v_nstep_td_data`): The v_nstep_td_data containing ['v', 'next_n_v', 'reward', 'done', 'weight', 'value_gamma']
- v (:obj:`torch.FloatTensor`): :math:`(B, )` i.e. [batch_size, ]
- next_n_v (:obj:`torch.FloatTensor`): :math:`(B, )`
- reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep (nstep)
- done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep
- weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
- value_gamma (:obj:`torch.Tensor`): If the remaining data in the buffer is less than n_step, we use value_gamma as the gamma discount value for next_v rather than gamma**n_step

Examples:
>>> v = torch.randn(5).requires_grad_(True)
>>> next_v = torch.randn(5)
>>> reward = torch.rand(5, 5)
>>> done = torch.zeros(5)
>>> data = v_nstep_td_data(v, next_v, reward, done, 0.9, 0.99)
>>> loss, td_error_per_sample = v_nstep_td_error(data, 0.99, 5)

shape_fn_qntd(args, kwargs)

Overview

Return qntd shape for hpc

Returns: shape: [T, B, N]

q_nstep_td_error(data, gamma, nstep=1, cum_reward=False, value_gamma=None, criterion=nn.MSELoss(reduction='none'))

Overview

Multistep (1-step or n-step) td_error for q-learning based algorithms.

Arguments:
- data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss
- gamma (:obj:`float`): Discount factor
- cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
- value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value
- criterion (:obj:`torch.nn.modules`): Loss function criterion
- nstep (:obj:`int`): nstep num, default set to 1

Returns:
- loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
- td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor

Shapes:
- data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing ['q', 'next_n_q', 'action', 'reward', 'done']
- q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
- next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
- action (:obj:`torch.LongTensor`): :math:`(B, )`
- next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
- reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep (nstep)
- done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep
- td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`

Examples:
>>> next_q = torch.randn(4, 3)
>>> done = torch.randn(4)
>>> action = torch.randint(0, 3, size=(4, ))
>>> next_action = torch.randint(0, 3, size=(4, ))
>>> nstep = 3
>>> q = torch.randn(4, 3).requires_grad_(True)
>>> reward = torch.rand(nstep, 4)
>>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
>>> loss, td_error_per_sample = q_nstep_td_error(data, 0.95, nstep=nstep)

bdq_nstep_td_error(data, gamma, nstep=1, cum_reward=False, value_gamma=None, criterion=nn.MSELoss(reduction='none'))

Overview

Multistep (1-step or n-step) td_error for the BDQ algorithm; see the paper "Action Branching Architectures for Deep Reinforcement Learning" (https://arxiv.org/pdf/1711.08946). The original paper only provides the 1-step TD-error; here the calculation is extended to the n-step case.

Arguments:
- data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss
- gamma (:obj:`float`): Discount factor
- cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
- value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value
- criterion (:obj:`torch.nn.modules`): Loss function criterion
- nstep (:obj:`int`): nstep num, default set to 1

Returns:
- loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
- td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor

Shapes:
- data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing ['q', 'next_n_q', 'action', 'reward', 'done']
- q (:obj:`torch.FloatTensor`): :math:`(B, D, N)` i.e. [batch_size, branch_num, action_bins_per_branch]
- next_n_q (:obj:`torch.FloatTensor`): :math:`(B, D, N)`
- action (:obj:`torch.LongTensor`): :math:`(B, D)`
- next_n_action (:obj:`torch.LongTensor`): :math:`(B, D)`
- reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep (nstep)
- done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep
- td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`

Examples:
>>> action_per_branch = 3
>>> next_q = torch.randn(8, 6, action_per_branch)
>>> done = torch.randn(8)
>>> action = torch.randint(0, action_per_branch, size=(8, 6))
>>> next_action = torch.randint(0, action_per_branch, size=(8, 6))
>>> nstep = 3
>>> q = torch.randn(8, 6, action_per_branch).requires_grad_(True)
>>> reward = torch.rand(nstep, 8)
>>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
>>> loss, td_error_per_sample = bdq_nstep_td_error(data, 0.95, nstep=nstep)
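In BDQ each action branch gets its own TD target r + gamma * (1 - done) * Q_d(s', a'_d), sharing the same scalar reward and done flag, and the resulting per-branch TD errors are averaged. A 1-step scalar sketch of the per-branch targets (hypothetical helper, not the batched library code):

```python
def bdq_branch_targets(reward, next_q, next_act, done, gamma):
    # one TD target per branch; reward and done are shared across all branches
    return [
        reward + gamma * (0.0 if done else 1.0) * branch_q[a]
        for branch_q, a in zip(next_q, next_act)
    ]

# two branches with 2 action bins each
targets = bdq_branch_targets(1.0, next_q=[[0.0, 2.0], [4.0, 0.0]], next_act=[1, 0],
                             done=False, gamma=0.5)
# branch 0: 1 + 0.5 * 2 = 2.0; branch 1: 1 + 0.5 * 4 = 3.0
```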

shape_fn_qntd_rescale(args, kwargs)

Overview

Return qntd_rescale shape for hpc

Returns: shape: [T, B, N]

q_nstep_td_error_with_rescale(data, gamma, nstep=1, value_gamma=None, criterion=nn.MSELoss(reduction='none'), trans_fn=value_transform, inv_trans_fn=value_inv_transform)

Overview

Multistep (1-step or n-step) td_error with value rescaling.

Arguments:
- data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss
- gamma (:obj:`float`): Discount factor
- nstep (:obj:`int`): nstep num, default set to 1
- criterion (:obj:`torch.nn.modules`): Loss function criterion
- trans_fn (:obj:`Callable`): Value transform function, defaults to value_transform (refer to rl_utils/value_rescale.py)
- inv_trans_fn (:obj:`Callable`): Value inverse transform function, defaults to value_inv_transform (refer to rl_utils/value_rescale.py)

Returns:
- loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor

Shapes:
- data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing ['q', 'next_n_q', 'action', 'reward', 'done']
- q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
- next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
- action (:obj:`torch.LongTensor`): :math:`(B, )`
- next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
- reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep (nstep)
- done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep

Examples:
>>> next_q = torch.randn(4, 3)
>>> done = torch.randn(4)
>>> action = torch.randint(0, 3, size=(4, ))
>>> next_action = torch.randint(0, 3, size=(4, ))
>>> nstep = 3
>>> q = torch.randn(4, 3).requires_grad_(True)
>>> reward = torch.rand(nstep, 4)
>>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
>>> loss, _ = q_nstep_td_error_with_rescale(data, 0.95, nstep=nstep)
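The default trans_fn/inv_trans_fn pair follows the value rescaling h(x) = sign(x)(sqrt(|x| + 1) - 1) + eps * x used in R2D2-style agents (Pohlen et al., 2018). A scalar sketch of the transform and its closed-form inverse (hypothetical standalone versions of what rl_utils/value_rescale.py provides):

```python
import math

def value_transform(x, eps=1e-2):
    # h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x: compresses large target magnitudes
    return math.copysign(1.0, x) * (math.sqrt(abs(x) + 1) - 1) + eps * x

def value_inv_transform(x, eps=1e-2):
    # exact inverse of h, applied to network outputs before building the Bellman target
    return math.copysign(1.0, x) * (
        ((math.sqrt(1 + 4 * eps * (abs(x) + 1 + eps)) - 1) / (2 * eps)) ** 2 - 1
    )

y = value_transform(5.0)          # compressed: noticeably smaller than 5.0
roundtrip = value_inv_transform(y)  # recovers 5.0 up to floating-point error
```

The target is built as trans_fn(reward + gamma * inv_trans_fn(next_q)), so learning happens in the compressed space while bootstrapping happens in the raw value space.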

dqfd_nstep_td_error(data, gamma, lambda_n_step_td, lambda_supervised_loss, margin_function, lambda_one_step_td=1.0, nstep=1, cum_reward=False, value_gamma=None, criterion=nn.MSELoss(reduction='none'))

Overview

Multistep n-step td_error + 1-step td_error + supervised margin loss for DQfD.

Arguments:
- data (:obj:`dqfd_nstep_td_data`): The input data, dqfd_nstep_td_data to calculate loss
- gamma (:obj:`float`): Discount factor
- cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
- value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value
- criterion (:obj:`torch.nn.modules`): Loss function criterion
- nstep (:obj:`int`): nstep num, default set to 1

Returns:
- loss (:obj:`torch.Tensor`): Multistep n-step td_error + 1-step td_error + supervised margin loss, 0-dim tensor
- td_error_per_sample (:obj:`torch.Tensor`): Multistep n-step td_error + 1-step td_error + supervised margin loss, 1-dim tensor

Shapes:
- data (:obj:`q_nstep_td_data`): the q_nstep_td_data containing ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight', 'new_n_q_one_step', 'next_n_action_one_step', 'is_expert']
- q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
- next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
- action (:obj:`torch.LongTensor`): :math:`(B, )`
- next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
- reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep (nstep)
- done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep
- td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
- new_n_q_one_step (:obj:`torch.FloatTensor`): :math:`(B, N)`
- next_n_action_one_step (:obj:`torch.LongTensor`): :math:`(B, )`
- is_expert (:obj:`int`): 0 or 1

Examples:
>>> next_q = torch.randn(4, 3)
>>> done = torch.randn(4)
>>> done_1 = torch.randn(4)
>>> next_q_one_step = torch.randn(4, 3)
>>> action = torch.randint(0, 3, size=(4, ))
>>> next_action = torch.randint(0, 3, size=(4, ))
>>> next_action_one_step = torch.randint(0, 3, size=(4, ))
>>> is_expert = torch.ones((4))
>>> nstep = 3
>>> q = torch.randn(4, 3).requires_grad_(True)
>>> reward = torch.rand(nstep, 4)
>>> data = dqfd_nstep_td_data(
>>>     q, next_q, action, next_action, reward, done, done_1, None,
>>>     next_q_one_step, next_action_one_step, is_expert
>>> )
>>> loss, td_error_per_sample, loss_statistics = dqfd_nstep_td_error(
>>>     data, 0.95, lambda_n_step_td=1, lambda_supervised_loss=1,
>>>     margin_function=0.8, nstep=nstep
>>> )

dqfd_nstep_td_error_with_rescale(data, gamma, lambda_n_step_td, lambda_supervised_loss, lambda_one_step_td, margin_function, nstep=1, cum_reward=False, value_gamma=None, criterion=nn.MSELoss(reduction='none'), trans_fn=value_transform, inv_trans_fn=value_inv_transform)

Overview

Multistep n-step td_error + 1-step td_error + supervised margin loss for DQfD, with value rescaling.

Arguments:
- data (:obj:`dqfd_nstep_td_data`): The input data, dqfd_nstep_td_data to calculate loss
- gamma (:obj:`float`): Discount factor
- cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
- value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value
- criterion (:obj:`torch.nn.modules`): Loss function criterion
- nstep (:obj:`int`): nstep num, default set to 1

Returns:
- loss (:obj:`torch.Tensor`): Multistep n-step td_error + 1-step td_error + supervised margin loss, 0-dim tensor
- td_error_per_sample (:obj:`torch.Tensor`): Multistep n-step td_error + 1-step td_error + supervised margin loss, 1-dim tensor

Shapes:
- data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight', 'new_n_q_one_step', 'next_n_action_one_step', 'is_expert']
- q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
- next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
- action (:obj:`torch.LongTensor`): :math:`(B, )`
- next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
- reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep (nstep)
- done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep
- td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
- new_n_q_one_step (:obj:`torch.FloatTensor`): :math:`(B, N)`
- next_n_action_one_step (:obj:`torch.LongTensor`): :math:`(B, )`
- is_expert (:obj:`int`): 0 or 1

qrdqn_nstep_td_error(data, gamma, nstep=1, value_gamma=None)

Overview

Multistep (1-step or n-step) td_error in QRDQN.

Arguments:
- data (:obj:`qrdqn_nstep_td_data`): The input data, qrdqn_nstep_td_data to calculate loss
- gamma (:obj:`float`): Discount factor
- nstep (:obj:`int`): nstep num, default set to 1

Returns:
- loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor

Shapes:
- data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing ['q', 'next_n_q', 'action', 'reward', 'done']
- q (:obj:`torch.FloatTensor`): :math:`(tau, B, N)` i.e. [tau, batch_size, action_dim]
- next_n_q (:obj:`torch.FloatTensor`): :math:`(tau', B, N)`
- action (:obj:`torch.LongTensor`): :math:`(B, )`
- next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
- reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep (nstep)
- done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep

Examples:
>>> next_q = torch.randn(4, 3, 3)
>>> done = torch.randn(4)
>>> action = torch.randint(0, 3, size=(4, ))
>>> next_action = torch.randint(0, 3, size=(4, ))
>>> nstep = 3
>>> q = torch.randn(4, 3, 3).requires_grad_(True)
>>> reward = torch.rand(nstep, 4)
>>> data = qrdqn_nstep_td_data(q, next_q, action, next_action, reward, done, 3, None)
>>> loss, td_error_per_sample = qrdqn_nstep_td_error(data, 0.95, nstep=nstep)
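QRDQN (and IQN/FQF below) score each pairwise TD residual u with the quantile Huber loss rho^kappa_tau(u) = |tau - 1{u < 0}| * L_kappa(u) / kappa. A scalar sketch (hypothetical helper; the library computes it over all quantile pairs at once):

```python
def quantile_huber(u, tau, kappa=1.0):
    # Huber part: quadratic near zero, linear beyond kappa
    huber = 0.5 * u * u if abs(u) <= kappa else kappa * (abs(u) - 0.5 * kappa)
    # asymmetric quantile weight: under- and over-estimation are penalized differently
    return abs(tau - (1.0 if u < 0.0 else 0.0)) * huber / kappa

# at the median (tau = 0.5) positive and negative residuals weigh the same
lo = quantile_huber(-0.5, tau=0.5)
hi = quantile_huber(0.5, tau=0.5)
```

For tau closer to 1, negative residuals are down-weighted, which is what pushes each output toward its assigned quantile of the return distribution.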

q_nstep_sql_td_error(data, gamma, alpha, nstep=1, cum_reward=False, value_gamma=None, criterion=nn.MSELoss(reduction='none'))

Overview

Multistep (1-step or n-step) td_error for soft q-learning.

Arguments:
- data (:obj:`q_nstep_td_data`): The input data, q_nstep_sql_td_data to calculate loss
- gamma (:obj:`float`): Discount factor
- alpha (:obj:`float`): A parameter to weight the entropy term in the policy equation
- cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
- value_gamma (:obj:`torch.Tensor`): Gamma discount value for target soft_q_value
- criterion (:obj:`torch.nn.modules`): Loss function criterion
- nstep (:obj:`int`): nstep num, default set to 1

Returns:
- loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
- td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor

Shapes:
- data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing ['q', 'next_n_q', 'action', 'reward', 'done']
- q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
- next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
- action (:obj:`torch.LongTensor`): :math:`(B, )`
- next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
- reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep (nstep)
- done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep
- td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`

Examples:
>>> next_q = torch.randn(4, 3)
>>> done = torch.randn(4)
>>> action = torch.randint(0, 3, size=(4, ))
>>> next_action = torch.randint(0, 3, size=(4, ))
>>> nstep = 3
>>> q = torch.randn(4, 3).requires_grad_(True)
>>> reward = torch.rand(nstep, 4)
>>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
>>> loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(data, 0.95, 1.0, nstep=nstep)
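Soft q-learning replaces the hard max over next-state Q-values with the entropy-regularized soft value V_soft(s) = alpha * logsumexp(Q(s, .) / alpha). A plain-Python sketch (hypothetical helper, stabilized by subtracting the max before exponentiating):

```python
import math

def soft_value(q_row, alpha):
    # alpha * logsumexp(Q / alpha), shifted by the max for numerical stability
    m = max(q_row)
    return alpha * math.log(sum(math.exp((q - m) / alpha) for q in q_row)) + m

v_soft = soft_value([1.0, 1.0], alpha=1.0)   # 1 + log(2): both actions contribute
v_hard = soft_value([1.0, 1.0], alpha=1e-6)  # approaches max(q) as alpha -> 0
```

Larger alpha weights the entropy term more heavily, smoothing the backup over all actions; as alpha shrinks, the target reduces to the standard q-learning max.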

iqn_nstep_td_error(data, gamma, nstep=1, kappa=1.0, value_gamma=None)

Overview

Multistep (1-step or n-step) td_error in IQN; see the paper Implicit Quantile Networks for Distributional Reinforcement Learning (https://arxiv.org/pdf/1806.06923.pdf).

Arguments:
- data (:obj:`iqn_nstep_td_data`): The input data, iqn_nstep_td_data to calculate loss
- gamma (:obj:`float`): Discount factor
- nstep (:obj:`int`): nstep num, default set to 1
- criterion (:obj:`torch.nn.modules`): Loss function criterion
- beta_function (:obj:`Callable`): The risk function

Returns:
- loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor

Shapes:
- data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing ['q', 'next_n_q', 'action', 'reward', 'done']
- q (:obj:`torch.FloatTensor`): :math:`(tau, B, N)` i.e. [tau, batch_size, action_dim]
- next_n_q (:obj:`torch.FloatTensor`): :math:`(tau', B, N)`
- action (:obj:`torch.LongTensor`): :math:`(B, )`
- next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
- reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep (nstep)
- done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep

Examples:
>>> next_q = torch.randn(3, 4, 3)
>>> done = torch.randn(4)
>>> action = torch.randint(0, 3, size=(4, ))
>>> next_action = torch.randint(0, 3, size=(4, ))
>>> nstep = 3
>>> q = torch.randn(3, 4, 3).requires_grad_(True)
>>> replay_quantile = torch.randn([3, 4, 1])
>>> reward = torch.rand(nstep, 4)
>>> data = iqn_nstep_td_data(q, next_q, action, next_action, reward, done, replay_quantile, None)
>>> loss, td_error_per_sample = iqn_nstep_td_error(data, 0.95, nstep=nstep)

fqf_nstep_td_error(data, gamma, nstep=1, kappa=1.0, value_gamma=None)

Overview

Multistep (1-step or n-step) td_error in FQF; see the paper Fully Parameterized Quantile Function for Distributional Reinforcement Learning (https://arxiv.org/pdf/1911.02140.pdf).

Arguments:
- data (:obj:`fqf_nstep_td_data`): The input data, fqf_nstep_td_data to calculate loss
- gamma (:obj:`float`): Discount factor
- nstep (:obj:`int`): nstep num, default set to 1
- criterion (:obj:`torch.nn.modules`): Loss function criterion
- beta_function (:obj:`Callable`): The risk function

Returns:
- loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor

Shapes:
- data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing ['q', 'next_n_q', 'action', 'reward', 'done']
- q (:obj:`torch.FloatTensor`): :math:`(B, tau, N)` i.e. [batch_size, tau, action_dim]
- next_n_q (:obj:`torch.FloatTensor`): :math:`(B, tau', N)`
- action (:obj:`torch.LongTensor`): :math:`(B, )`
- next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
- reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep (nstep)
- done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep
- quantiles_hats (:obj:`torch.FloatTensor`): :math:`(B, tau)`

Examples:
>>> next_q = torch.randn(4, 3, 3)
>>> done = torch.randn(4)
>>> action = torch.randint(0, 3, size=(4, ))
>>> next_action = torch.randint(0, 3, size=(4, ))
>>> nstep = 3
>>> q = torch.randn(4, 3, 3).requires_grad_(True)
>>> quantiles_hats = torch.randn([4, 3])
>>> reward = torch.rand(nstep, 4)
>>> data = fqf_nstep_td_data(q, next_q, action, next_action, reward, done, quantiles_hats, None)
>>> loss, td_error_per_sample = fqf_nstep_td_error(data, 0.95, nstep=nstep)

fqf_calculate_fraction_loss(q_tau_i, q_value, quantiles, actions)

Overview

Calculate the fraction loss in FQF; see the paper Fully Parameterized Quantile Function for Distributional Reinforcement Learning (https://arxiv.org/pdf/1911.02140.pdf).

Arguments:
- q_tau_i (:obj:`torch.FloatTensor`): :math:`(batch_size, num_quantiles - 1, action_dim)`
- q_value (:obj:`torch.FloatTensor`): :math:`(batch_size, num_quantiles, action_dim)`
- quantiles (:obj:`torch.FloatTensor`): :math:`(batch_size, num_quantiles + 1)`
- actions (:obj:`torch.LongTensor`): :math:`(batch_size, )`

Returns:
- fraction_loss (:obj:`torch.Tensor`): fraction loss, 0-dim tensor

shape_fn_td_lambda(args, kwargs)

Overview

Return td_lambda shape for hpc

Returns: shape: [T, B]

td_lambda_error(data, gamma=0.9, lambda_=0.8)

Overview

Compute the TD(lambda) loss given constant gamma and lambda. There is no special handling for terminal state values: if some state has reached the terminal, just fill in zeros for values and rewards beyond the terminal (including the terminal state itself, i.e. values[terminal] should also be 0).

Arguments:
- data (:obj:`namedtuple`): td_lambda input data with fields ['value', 'reward', 'weight']
- gamma (:obj:`float`): Constant discount factor gamma, should be in [0, 1], defaults to 0.9
- lambda_ (:obj:`float`): Constant lambda, should be in [0, 1], defaults to 0.8

Returns:
- loss (:obj:`torch.Tensor`): Computed MSE loss, averaged over the batch

Shapes:
- value (:obj:`torch.FloatTensor`): :math:`(T+1, B)`, where T is trajectory length and B is batch size, the estimation of state values at steps 0 to T
- reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, the rewards from time step 0 to T-1
- weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
- loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor

Examples:
>>> T, B = 8, 4
>>> value = torch.randn(T + 1, B).requires_grad_(True)
>>> reward = torch.rand(T, B)
>>> loss = td_lambda_error(td_lambda_data(value, reward, None))
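The TD(lambda) targets can be built with a single backward recursion G_t = r_t + gamma * ((1 - lambda) * V(s_{t+1}) + lambda * G_{t+1}); lambda = 0 recovers 1-step TD targets and lambda = 1 the discounted Monte-Carlo return. A list-based sketch for one trajectory (hypothetical helper; the library works on (T+1, B) tensors):

```python
def lambda_returns(values, rewards, gamma, lam):
    # values has length T + 1: the last entry is the bootstrap value
    T = len(rewards)
    out = [0.0] * T
    g = values[T]
    for t in reversed(range(T)):
        # mix the 1-step bootstrap with the accumulated multistep return
        g = rewards[t] + gamma * ((1 - lam) * values[t + 1] + lam * g)
        out[t] = g
    return out

vals, rews = [0.0, 1.0, 2.0, 3.0], [1.0, 1.0, 1.0]
td1 = lambda_returns(vals, rews, gamma=1.0, lam=0.0)  # [2.0, 3.0, 4.0]: 1-step targets
mc = lambda_returns(vals, rews, gamma=1.0, lam=1.0)   # [6.0, 5.0, 4.0]: full returns
```

td_lambda_error then regresses value[0:T] toward these targets with an MSE loss.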

generalized_lambda_returns(bootstrap_values, rewards, gammas, lambda_, done=None)

Overview

Functional equivalent of trfl.value_ops.generalized_lambda_returns (https://github.com/deepmind/trfl/blob/2c07ac22512a16715cc759f0072be43a5d12ae45/trfl/value_ops.py#L74). Passing in a number instead of a tensor makes that value constant for all samples in the batch.

Arguments:
- bootstrap_values (:obj:`torch.Tensor` or :obj:`float`): Estimation of the values at steps 0 to T, of size [T_traj+1, batchsize]
- rewards (:obj:`torch.Tensor`): The rewards from 0 to T-1, of size [T_traj, batchsize]
- gammas (:obj:`torch.Tensor` or :obj:`float`): Discount factor for each step (from 0 to T-1), of size [T_traj, batchsize]
- lambda_ (:obj:`torch.Tensor` or :obj:`float`): Determining the mix of bootstrapping vs further accumulation of multistep returns at each timestep, of size [T_traj, batchsize]
- done (:obj:`torch.Tensor` or :obj:`float`): Whether the episode is done at the current step (from 0 to T-1), of size [T_traj, batchsize]

Returns:
- return (:obj:`torch.Tensor`): Computed lambda return value for each state from 0 to T-1, of size [T_traj, batchsize]

multistep_forward_view(bootstrap_values, rewards, gammas, lambda_, done=None)

Overview

Same as trfl.sequence_ops.multistep_forward_view, which implements (12.18) in Sutton & Barto. Assumes the first dim of the input tensors corresponds to the index in the batch.

.. note::
    result[T-1] = rewards[T-1] + gammas[T-1] * bootstrap_values[T]
    for t in 0 ... T-2:
        result[t] = rewards[t] + gammas[t] * (lambdas[t] * result[t+1] + (1 - lambdas[t]) * bootstrap_values[t+1])

Arguments:
- bootstrap_values (:obj:`torch.Tensor`): Estimation of the values at steps 1 to T, of size [T_traj, batchsize]
- rewards (:obj:`torch.Tensor`): The rewards from 0 to T-1, of size [T_traj, batchsize]
- gammas (:obj:`torch.Tensor`): Discount factor for each step (from 0 to T-1), of size [T_traj, batchsize]
- lambda_ (:obj:`torch.Tensor`): Determining the mix of bootstrapping vs further accumulation of multistep returns at each timestep, of size [T_traj, batchsize]; the element for T-1 is ignored and effectively set to 0, as there is no information about future rewards.
- done (:obj:`torch.Tensor` or :obj:`float`): Whether the episode is done at the current step (from 0 to T-1), of size [T_traj, batchsize]

Returns:
- ret (:obj:`torch.Tensor`): Computed lambda return value for each state from 0 to T-1, of size [T_traj, batchsize]

Full Source Code

../ding/rl_utils/td.py

import copy
import numpy as np
from collections import namedtuple
from typing import Union, Optional, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F

from ding.hpc_rl import hpc_wrapper
from ding.rl_utils.value_rescale import value_transform, value_inv_transform
from ding.torch_utils import to_tensor

q_1step_td_data = namedtuple('q_1step_td_data', ['q', 'next_q', 'act', 'next_act', 'reward', 'done', 'weight'])


def discount_cumsum(x, gamma: float = 1.0) -> np.ndarray:
    assert abs(gamma - 1.) < 1e-5, "gamma equals to 1.0 in original decision transformer paper"
    disc_cumsum = np.zeros_like(x)
    disc_cumsum[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        disc_cumsum[t] = x[t] + gamma * disc_cumsum[t + 1]
    return disc_cumsum


def q_1step_td_error(
        data: namedtuple,
        gamma: float,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none')  # noqa
) -> torch.Tensor:
    """
    Overview:
        1 step td_error, support single agent case and multi agent case.
    Arguments:
        - data (:obj:`q_1step_td_data`): The input data, q_1step_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
    Returns:
        - loss (:obj:`torch.Tensor`): 1step td error
    Shapes:
        - data (:obj:`q_1step_td_data`): the q_1step_td_data containing\
            ['q', 'next_q', 'act', 'next_act', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - act (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_act (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`( , B)`
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
    Examples:
        >>> action_dim = 4
        >>> data = q_1step_td_data(
        >>>     q=torch.randn(3, action_dim),
        >>>     next_q=torch.randn(3, action_dim),
        >>>     act=torch.randint(0, action_dim, (3,)),
        >>>     next_act=torch.randint(0, action_dim, (3,)),
        >>>     reward=torch.randn(3),
        >>>     done=torch.randint(0, 2, (3,)).bool(),
        >>>     weight=torch.ones(3),
        >>> )
        >>> loss = q_1step_td_error(data, 0.99)
    """
    q, next_q, act, next_act, reward, done, weight = data
    assert len(act.shape) == 1, act.shape
    assert len(reward.shape) == 1, reward.shape
    batch_range = torch.arange(act.shape[0])
    if weight is None:
        weight = torch.ones_like(reward)
    q_s_a = q[batch_range, act]
    target_q_s_a = next_q[batch_range, next_act]
    target_q_s_a = gamma * (1 - done) * target_q_s_a + reward
    return (criterion(q_s_a, target_q_s_a.detach()) * weight).mean()


m_q_1step_td_data = namedtuple('m_q_1step_td_data', ['q', 'target_q', 'next_q', 'act', 'reward', 'done', 'weight'])


def m_q_1step_td_error(
        data: namedtuple,
        gamma: float,
        tau: float,
        alpha: float,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none')  # noqa
) -> torch.Tensor:
    """
    Overview:
        Munchausen td_error for DQN algorithm, support 1 step td error.
    Arguments:
        - data (:obj:`m_q_1step_td_data`): The input data, m_q_1step_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - tau (:obj:`float`): Entropy factor for Munchausen DQN
        - alpha (:obj:`float`): Discount factor for Munchausen term
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
    Returns:
        - loss (:obj:`torch.Tensor`): 1step td error, 0-dim tensor
    Shapes:
        - data (:obj:`m_q_1step_td_data`): the m_q_1step_td_data containing\
            ['q', 'target_q', 'next_q', 'act', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - target_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - act (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`( , B)`
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
    Examples:
        >>> action_dim = 4
        >>> data = m_q_1step_td_data(
        >>>     q=torch.randn(3, action_dim),
        >>>     target_q=torch.randn(3, action_dim),
        >>>     next_q=torch.randn(3, action_dim),
        >>>     act=torch.randint(0, action_dim, (3,)),
        >>>     reward=torch.randn(3),
        >>>     done=torch.randint(0, 2, (3,)),
        >>>     weight=torch.ones(3),
        >>> )
        >>> loss = m_q_1step_td_error(data, 0.99, 0.01, 0.01)
    """
    q, target_q, next_q, act, reward, done, weight = data
    lower_bound = -1
    assert len(act.shape) == 1, act.shape
    assert len(reward.shape) == 1, reward.shape
    batch_range = torch.arange(act.shape[0])
    if weight is None:
        weight = torch.ones_like(reward)
    q_s_a = q[batch_range, act]
    # calculate munchausen addon
    # replay_log_policy
    target_v_s = target_q[batch_range].max(1)[0].unsqueeze(-1)

    logsum = torch.logsumexp((target_q - target_v_s) /
tau, 1).unsqueeze(-1) 132 log_pi = target_q - target_v_s - tau * logsum 133 act_get = act.unsqueeze(-1) 134 # same to the last second tau_log_pi_a 135 munchausen_addon = log_pi.gather(1, act_get) 136 137 muchausen_term = alpha * torch.clamp(munchausen_addon, min=lower_bound, max=1) 138 139 # replay_next_log_policy 140 target_v_s_next = next_q[batch_range].max(1)[0].unsqueeze(-1) 141 logsum_next = torch.logsumexp((next_q - target_v_s_next) / tau, 1).unsqueeze(-1) 142 tau_log_pi_next = next_q - target_v_s_next - tau * logsum_next 143 # do stable softmax == replay_next_policy 144 pi_target = F.softmax((next_q - target_v_s_next) / tau) 145 target_q_s_a = (gamma * (pi_target * (next_q - tau_log_pi_next) * (1 - done.unsqueeze(-1))).sum(1)).unsqueeze(-1) 146 147 target_q_s_a = reward.unsqueeze(-1) + muchausen_term + target_q_s_a 148 td_error_per_sample = criterion(q_s_a.unsqueeze(-1), target_q_s_a.detach()).squeeze(-1) 149 150 # calculate action_gap and clipfrac 151 with torch.no_grad(): 152 top2_q_s = target_q[batch_range].topk(2, dim=1, largest=True, sorted=True)[0] 153 action_gap = (top2_q_s[:, 0] - top2_q_s[:, 1]).mean() 154 155 clipped = munchausen_addon.gt(1) | munchausen_addon.lt(lower_bound) 156 clipfrac = torch.as_tensor(clipped).float() 157 158 return (td_error_per_sample * weight).mean(), td_error_per_sample, action_gap, clipfrac 159 160 161q_v_1step_td_data = namedtuple('q_v_1step_td_data', ['q', 'v', 'act', 'reward', 'done', 'weight']) 162 163 164def q_v_1step_td_error( 165 data: namedtuple, gamma: float, criterion: torch.nn.modules = nn.MSELoss(reduction='none') 166) -> torch.Tensor: 167 # we will use this function in discrete sac algorithm to calculate td error between q and v value. 168 """ 169 Overview: 170 td_error between q and v value for SAC algorithm, support 1 step td error. 
171 Arguments: 172 - data (:obj:`q_v_1step_td_data`): The input data, q_v_1step_td_data to calculate loss 173 - gamma (:obj:`float`): Discount factor 174 - criterion (:obj:`torch.nn.modules`): Loss function criterion 175 Returns: 176 - loss (:obj:`torch.Tensor`): 1step td error, 0-dim tensor 177 Shapes: 178 - data (:obj:`q_v_1step_td_data`): the q_v_1step_td_data containing\ 179 ['q', 'v', 'act', 'reward', 'done', 'weight'] 180 - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim] 181 - v (:obj:`torch.FloatTensor`): :math:`(B, )` 182 - act (:obj:`torch.LongTensor`): :math:`(B, )` 183 - reward (:obj:`torch.FloatTensor`): :math:`( , B)` 184 - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep 185 - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight 186 Examples: 187 >>> action_dim = 4 188 >>> data = q_v_1step_td_data( 189 >>> q=torch.randn(3, action_dim), 190 >>> v=torch.randn(3), 191 >>> act=torch.randint(0, action_dim, (3,)), 192 >>> reward=torch.randn(3), 193 >>> done=torch.randint(0, 2, (3,)), 194 >>> weight=torch.ones(3), 195 >>> ) 196 >>> loss = q_v_1step_td_error(data, 0.99) 197 """ 198 q, v, act, reward, done, weight = data 199 if len(act.shape) == 1: 200 assert len(reward.shape) == 1, reward.shape 201 batch_range = torch.arange(act.shape[0]) 202 if weight is None: 203 weight = torch.ones_like(reward) 204 q_s_a = q[batch_range, act] 205 target_q_s_a = gamma * (1 - done) * v + reward 206 else: 207 assert len(reward.shape) == 1, reward.shape 208 batch_range = torch.arange(act.shape[0]) 209 actor_range = torch.arange(act.shape[1]) 210 batch_actor_range = torch.arange(act.shape[0] * act.shape[1]) 211 if weight is None: 212 weight = torch.ones_like(act) 213 temp_q = q.reshape(act.shape[0] * act.shape[1], -1) 214 temp_act = act.reshape(act.shape[0] * act.shape[1]) 215 q_s_a = temp_q[batch_actor_range, temp_act] 216 q_s_a = q_s_a.reshape(act.shape[0], act.shape[1]) 217 target_q_s_a = 
gamma * (1 - done).unsqueeze(1) * v + reward.unsqueeze(1) 218 td_error_per_sample = criterion(q_s_a, target_q_s_a.detach()) 219 return (td_error_per_sample * weight).mean(), td_error_per_sample 220 221 222def view_similar(x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 223 size = list(x.shape) + [1 for _ in range(len(target.shape) - len(x.shape))] 224 return x.view(*size) 225 226 227nstep_return_data = namedtuple('nstep_return_data', ['reward', 'next_value', 'done']) 228 229 230def nstep_return(data: namedtuple, gamma: Union[float, list], nstep: int, value_gamma: Optional[torch.Tensor] = None): 231 ''' 232 Overview: 233 Calculate nstep return for DQN algorithm, support single agent case and multi agent case. 234 Arguments: 235 - data (:obj:`nstep_return_data`): The input data, nstep_return_data to calculate loss 236 - gamma (:obj:`float`): Discount factor 237 - nstep (:obj:`int`): nstep num 238 - value_gamma (:obj:`torch.Tensor`): Discount factor for value 239 Returns: 240 - return (:obj:`torch.Tensor`): nstep return 241 Shapes: 242 - data (:obj:`nstep_return_data`): the nstep_return_data containing\ 243 ['reward', 'next_value', 'done'] 244 - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep) 245 - next_value (:obj:`torch.FloatTensor`): :math:`(, B)` 246 - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep 247 Examples: 248 >>> data = nstep_return_data( 249 >>> reward=torch.randn(3, 3), 250 >>> next_value=torch.randn(3), 251 >>> done=torch.randint(0, 2, (3,)), 252 >>> ) 253 >>> loss = nstep_return(data, 0.99, 3) 254 ''' 255 256 reward, next_value, done = data 257 assert reward.shape[0] == nstep 258 device = reward.device 259 260 if isinstance(gamma, float): 261 reward_factor = torch.ones(nstep).to(device) 262 for i in range(1, nstep): 263 reward_factor[i] = gamma * reward_factor[i - 1] 264 reward_factor = view_similar(reward_factor, reward) 265 return_tmp = reward.mul(reward_factor).sum(0) 266 if 
value_gamma is None: 267 return_ = return_tmp + (gamma ** nstep) * next_value * (1 - done) 268 else: 269 if np.isscalar(value_gamma): 270 value_gamma = torch.full_like(next_value, value_gamma) 271 value_gamma = view_similar(value_gamma, next_value) 272 done = view_similar(done, next_value) 273 return_ = return_tmp + value_gamma * next_value * (1 - done) 274 275 elif isinstance(gamma, list): 276 # if gamma is list, for NGU policy case 277 reward_factor = torch.ones([nstep + 1, done.shape[0]]).to(device) 278 for i in range(1, nstep + 1): 279 reward_factor[i] = torch.stack(gamma, dim=0).to(device) * reward_factor[i - 1] 280 reward_factor = view_similar(reward_factor, reward) 281 return_tmp = reward.mul(reward_factor[:nstep]).sum(0) 282 return_ = return_tmp + reward_factor[nstep] * next_value * (1 - done) 283 else: 284 raise TypeError("The type of gamma should be float or list") 285 286 return return_ 287 288 289dist_1step_td_data = namedtuple( 290 'dist_1step_td_data', ['dist', 'next_dist', 'act', 'next_act', 'reward', 'done', 'weight'] 291) 292 293 294def dist_1step_td_error( 295 data: namedtuple, 296 gamma: float, 297 v_min: float, 298 v_max: float, 299 n_atom: int, 300) -> torch.Tensor: 301 """ 302 Overview: 303 1 step td_error for distributed q-learning based algorithm 304 Arguments: 305 - data (:obj:`dist_1step_td_data`): The input data, dist_nstep_td_data to calculate loss 306 - gamma (:obj:`float`): Discount factor 307 - v_min (:obj:`float`): The min value of support 308 - v_max (:obj:`float`): The max value of support 309 - n_atom (:obj:`int`): The num of atom 310 Returns: 311 - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor 312 Shapes: 313 - data (:obj:`dist_1step_td_data`): the dist_1step_td_data containing\ 314 ['dist', 'next_n_dist', 'act', 'reward', 'done', 'weight'] 315 - dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)` i.e. 
[batch_size, action_dim, n_atom] 316 - next_dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)` 317 - act (:obj:`torch.LongTensor`): :math:`(B, )` 318 - next_act (:obj:`torch.LongTensor`): :math:`(B, )` 319 - reward (:obj:`torch.FloatTensor`): :math:`(, B)` 320 - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep 321 - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight 322 Examples: 323 >>> dist = torch.randn(4, 3, 51).abs().requires_grad_(True) 324 >>> next_dist = torch.randn(4, 3, 51).abs() 325 >>> act = torch.randint(0, 3, (4,)) 326 >>> next_act = torch.randint(0, 3, (4,)) 327 >>> reward = torch.randn(4) 328 >>> done = torch.randint(0, 2, (4,)) 329 >>> data = dist_1step_td_data(dist, next_dist, act, next_act, reward, done, None) 330 >>> loss = dist_1step_td_error(data, 0.99, -10.0, 10.0, 51) 331 """ 332 dist, next_dist, act, next_act, reward, done, weight = data 333 device = reward.device 334 assert len(reward.shape) == 1, reward.shape 335 support = torch.linspace(v_min, v_max, n_atom).to(device) 336 delta_z = (v_max - v_min) / (n_atom - 1) 337 338 if len(act.shape) == 1: 339 reward = reward.unsqueeze(-1) 340 done = done.unsqueeze(-1) 341 batch_size = act.shape[0] 342 batch_range = torch.arange(batch_size) 343 if weight is None: 344 weight = torch.ones_like(reward) 345 next_dist = next_dist[batch_range, next_act].detach() 346 else: 347 reward = reward.unsqueeze(-1).repeat(1, act.shape[1]) 348 done = done.unsqueeze(-1).repeat(1, act.shape[1]) 349 350 batch_size = act.shape[0] * act.shape[1] 351 batch_range = torch.arange(act.shape[0] * act.shape[1]) 352 action_dim = dist.shape[2] 353 dist = dist.reshape(act.shape[0] * act.shape[1], action_dim, -1) 354 reward = reward.reshape(act.shape[0] * act.shape[1], -1) 355 done = done.reshape(act.shape[0] * act.shape[1], -1) 356 next_dist = next_dist.reshape(act.shape[0] * act.shape[1], action_dim, -1) 357 358 next_act = next_act.reshape(act.shape[0] * 
act.shape[1]) 359 next_dist = next_dist[batch_range, next_act].detach() 360 next_dist = next_dist.reshape(act.shape[0] * act.shape[1], -1) 361 act = act.reshape(act.shape[0] * act.shape[1]) 362 if weight is None: 363 weight = torch.ones_like(reward) 364 target_z = reward + (1 - done) * gamma * support 365 target_z = target_z.clamp(min=v_min, max=v_max) 366 b = (target_z - v_min) / delta_z 367 l = b.floor().long() 368 u = b.ceil().long() 369 # Fix disappearing probability mass when l = b = u (b is int) 370 l[(u > 0) * (l == u)] -= 1 371 u[(l < (n_atom - 1)) * (l == u)] += 1 372 373 proj_dist = torch.zeros_like(next_dist) 374 offset = torch.linspace(0, (batch_size - 1) * n_atom, batch_size).unsqueeze(1).expand(batch_size, 375 n_atom).long().to(device) 376 proj_dist.view(-1).index_add_(0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1)) 377 proj_dist.view(-1).index_add_(0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1)) 378 379 log_p = torch.log(dist[batch_range, act]) 380 381 loss = -(log_p * proj_dist * weight).sum(-1).mean() 382 383 return loss 384 385 386dist_nstep_td_data = namedtuple( 387 'dist_1step_td_data', ['dist', 'next_n_dist', 'act', 'next_n_act', 'reward', 'done', 'weight'] 388) 389 390 391def shape_fn_dntd(args, kwargs): 392 r""" 393 Overview: 394 Return dntd shape for hpc 395 Returns: 396 shape: [T, B, N, n_atom] 397 """ 398 if len(args) <= 0: 399 tmp = [kwargs['data'].reward.shape[0]] 400 tmp.extend(list(kwargs['data'].dist.shape)) 401 else: 402 tmp = [args[0].reward.shape[0]] 403 tmp.extend(list(args[0].dist.shape)) 404 return tmp 405 406 407@hpc_wrapper( 408 shape_fn=shape_fn_dntd, 409 namedtuple_data=True, 410 include_args=[0, 1, 2, 3], 411 include_kwargs=['data', 'gamma', 'v_min', 'v_max'] 412) 413def dist_nstep_td_error( 414 data: namedtuple, 415 gamma: float, 416 v_min: float, 417 v_max: float, 418 n_atom: int, 419 nstep: int = 1, 420 value_gamma: Optional[torch.Tensor] = None, 421) -> torch.Tensor: 422 """ 423 
Overview: 424 Multistep (1 step or n step) td_error for distributed q-learning based algorithm, support single\ 425 agent case and multi agent case. 426 Arguments: 427 - data (:obj:`dist_nstep_td_data`): The input data, dist_nstep_td_data to calculate loss 428 - gamma (:obj:`float`): Discount factor 429 - nstep (:obj:`int`): nstep num, default set to 1 430 Returns: 431 - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor 432 Shapes: 433 - data (:obj:`dist_nstep_td_data`): the dist_nstep_td_data containing\ 434 ['dist', 'next_n_dist', 'act', 'reward', 'done', 'weight'] 435 - dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)` i.e. [batch_size, action_dim, n_atom] 436 - next_n_dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)` 437 - act (:obj:`torch.LongTensor`): :math:`(B, )` 438 - next_n_act (:obj:`torch.LongTensor`): :math:`(B, )` 439 - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep) 440 - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep 441 Examples: 442 >>> dist = torch.randn(4, 3, 51).abs().requires_grad_(True) 443 >>> next_n_dist = torch.randn(4, 3, 51).abs() 444 >>> done = torch.randn(4) 445 >>> action = torch.randint(0, 3, size=(4, )) 446 >>> next_action = torch.randint(0, 3, size=(4, )) 447 >>> reward = torch.randn(5, 4) 448 >>> data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, None) 449 >>> loss, _ = dist_nstep_td_error(data, 0.95, -10.0, 10.0, 51, 5) 450 """ 451 dist, next_n_dist, act, next_n_act, reward, done, weight = data 452 device = reward.device 453 reward_factor = torch.ones(nstep).to(device) 454 for i in range(1, nstep): 455 reward_factor[i] = gamma * reward_factor[i - 1] 456 reward = torch.matmul(reward_factor, reward) 457 support = torch.linspace(v_min, v_max, n_atom).to(device) 458 delta_z = (v_max - v_min) / (n_atom - 1) 459 if len(act.shape) == 1: 460 reward = reward.unsqueeze(-1) 461 done = done.unsqueeze(-1) 462 batch_size = act.shape[0] 
463 batch_range = torch.arange(batch_size) 464 if weight is None: 465 weight = torch.ones_like(reward) 466 elif isinstance(weight, float): 467 weight = torch.tensor(weight) 468 469 next_n_dist = next_n_dist[batch_range, next_n_act].detach() 470 else: 471 reward = reward.unsqueeze(-1).repeat(1, act.shape[1]) 472 done = done.unsqueeze(-1).repeat(1, act.shape[1]) 473 474 batch_size = act.shape[0] * act.shape[1] 475 batch_range = torch.arange(act.shape[0] * act.shape[1]) 476 action_dim = dist.shape[2] 477 dist = dist.reshape(act.shape[0] * act.shape[1], action_dim, -1) 478 reward = reward.reshape(act.shape[0] * act.shape[1], -1) 479 done = done.reshape(act.shape[0] * act.shape[1], -1) 480 next_n_dist = next_n_dist.reshape(act.shape[0] * act.shape[1], action_dim, -1) 481 482 next_n_act = next_n_act.reshape(act.shape[0] * act.shape[1]) 483 next_n_dist = next_n_dist[batch_range, next_n_act].detach() 484 next_n_dist = next_n_dist.reshape(act.shape[0] * act.shape[1], -1) 485 act = act.reshape(act.shape[0] * act.shape[1]) 486 if weight is None: 487 weight = torch.ones_like(reward) 488 elif isinstance(weight, float): 489 weight = torch.tensor(weight) 490 491 if value_gamma is None: 492 target_z = reward + (1 - done) * (gamma ** nstep) * support 493 elif isinstance(value_gamma, float): 494 value_gamma = torch.tensor(value_gamma).unsqueeze(-1) 495 target_z = reward + (1 - done) * value_gamma * support 496 else: 497 value_gamma = value_gamma.unsqueeze(-1) 498 target_z = reward + (1 - done) * value_gamma * support 499 target_z = target_z.clamp(min=v_min, max=v_max) 500 b = (target_z - v_min) / delta_z 501 l = b.floor().long() 502 u = b.ceil().long() 503 # Fix disappearing probability mass when l = b = u (b is int) 504 l[(u > 0) * (l == u)] -= 1 505 u[(l < (n_atom - 1)) * (l == u)] += 1 506 507 proj_dist = torch.zeros_like(next_n_dist) 508 offset = torch.linspace(0, (batch_size - 1) * n_atom, batch_size).unsqueeze(1).expand(batch_size, 509 n_atom).long().to(device) 510 
proj_dist.view(-1).index_add_(0, (l + offset).view(-1), (next_n_dist * (u.float() - b)).view(-1)) 511 proj_dist.view(-1).index_add_(0, (u + offset).view(-1), (next_n_dist * (b - l.float())).view(-1)) 512 513 assert (dist[batch_range, act] > 0.0).all(), ("dist act", dist[batch_range, act], "dist:", dist) 514 log_p = torch.log(dist[batch_range, act]) 515 516 if len(weight.shape) == 1: 517 weight = weight.unsqueeze(-1) 518 519 td_error_per_sample = -(log_p * proj_dist).sum(-1) 520 521 loss = -(log_p * proj_dist * weight).sum(-1).mean() 522 523 return loss, td_error_per_sample 524 525 526v_1step_td_data = namedtuple('v_1step_td_data', ['v', 'next_v', 'reward', 'done', 'weight']) 527 528 529def v_1step_td_error( 530 data: namedtuple, 531 gamma: float, 532 criterion: torch.nn.modules = nn.MSELoss(reduction='none') # noqa 533) -> torch.Tensor: 534 ''' 535 Overview: 536 1 step td_error for distributed value based algorithm 537 Arguments: 538 - data (:obj:`v_1step_td_data`): The input data, v_1step_td_data to calculate loss 539 - gamma (:obj:`float`): Discount factor 540 - criterion (:obj:`torch.nn.modules`): Loss function criterion 541 Returns: 542 - loss (:obj:`torch.Tensor`): 1step td error, 0-dim tensor 543 Shapes: 544 - data (:obj:`v_1step_td_data`): the v_1step_td_data containing\ 545 ['v', 'next_v', 'reward', 'done', 'weight'] 546 - v (:obj:`torch.FloatTensor`): :math:`(B, )` i.e. 
[batch_size, ] 547 - next_v (:obj:`torch.FloatTensor`): :math:`(B, )` 548 - reward (:obj:`torch.FloatTensor`): :math:`(, B)` 549 - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep 550 - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight 551 Examples: 552 >>> v = torch.randn(5).requires_grad_(True) 553 >>> next_v = torch.randn(5) 554 >>> reward = torch.rand(5) 555 >>> done = torch.zeros(5) 556 >>> data = v_1step_td_data(v, next_v, reward, done, None) 557 >>> loss, td_error_per_sample = v_1step_td_error(data, 0.99) 558 ''' 559 v, next_v, reward, done, weight = data 560 if weight is None: 561 weight = torch.ones_like(v) 562 if len(v.shape) == len(reward.shape): 563 if done is not None: 564 target_v = gamma * (1 - done) * next_v + reward 565 else: 566 target_v = gamma * next_v + reward 567 else: 568 if done is not None: 569 target_v = gamma * (1 - done).unsqueeze(1) * next_v + reward.unsqueeze(1) 570 else: 571 target_v = gamma * next_v + reward.unsqueeze(1) 572 td_error_per_sample = criterion(v, target_v.detach()) 573 return (td_error_per_sample * weight).mean(), td_error_per_sample 574 575 576v_nstep_td_data = namedtuple('v_nstep_td_data', ['v', 'next_n_v', 'reward', 'done', 'weight', 'value_gamma']) 577 578 579def v_nstep_td_error( 580 data: namedtuple, 581 gamma: float, 582 nstep: int = 1, 583 criterion: torch.nn.modules = nn.MSELoss(reduction='none') # noqa 584) -> torch.Tensor: 585 """ 586 Overview: 587 Multistep (n step) td_error for distributed value based algorithm 588 Arguments: 589 - data (:obj:`dist_nstep_td_data`): The input data, v_nstep_td_data to calculate loss 590 - gamma (:obj:`float`): Discount factor 591 - nstep (:obj:`int`): nstep num, default set to 1 592 Returns: 593 - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor 594 Shapes: 595 - data (:obj:`dist_nstep_td_data`): The v_nstep_td_data containing \ 596 ['v', 'next_n_v', 'reward', 'done', 'weight', 'value_gamma'] 597 - v 
(:obj:`torch.FloatTensor`): :math:`(B, )` i.e. [batch_size, ] 598 - next_v (:obj:`torch.FloatTensor`): :math:`(B, )` 599 - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep) 600 - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep 601 - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight 602 - value_gamma (:obj:`torch.Tensor`): If the remaining data in the buffer is less than n_step \ 603 we use value_gamma as the gamma discount value for next_v rather than gamma**n_step 604 Examples: 605 >>> v = torch.randn(5).requires_grad_(True) 606 >>> next_v = torch.randn(5) 607 >>> reward = torch.rand(5, 5) 608 >>> done = torch.zeros(5) 609 >>> data = v_nstep_td_data(v, next_v, reward, done, 0.9, 0.99) 610 >>> loss, td_error_per_sample = v_nstep_td_error(data, 0.99, 5) 611 """ 612 v, next_n_v, reward, done, weight, value_gamma = data 613 if weight is None: 614 weight = torch.ones_like(v) 615 target_v = nstep_return(nstep_return_data(reward, next_n_v, done), gamma, nstep, value_gamma) 616 td_error_per_sample = criterion(v, target_v.detach()) 617 return (td_error_per_sample * weight).mean(), td_error_per_sample 618 619 620q_nstep_td_data = namedtuple( 621 'q_nstep_td_data', ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight'] 622) 623 624dqfd_nstep_td_data = namedtuple( 625 'dqfd_nstep_td_data', [ 626 'q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'done_one_step', 'weight', 'new_n_q_one_step', 627 'next_n_action_one_step', 'is_expert' 628 ] 629) 630 631 632def shape_fn_qntd(args, kwargs): 633 r""" 634 Overview: 635 Return qntd shape for hpc 636 Returns: 637 shape: [T, B, N] 638 """ 639 if len(args) <= 0: 640 tmp = [kwargs['data'].reward.shape[0]] 641 tmp.extend(list(kwargs['data'].q.shape)) 642 else: 643 tmp = [args[0].reward.shape[0]] 644 tmp.extend(list(args[0].q.shape)) 645 return tmp 646 647 648@hpc_wrapper(shape_fn=shape_fn_qntd, namedtuple_data=True, 
include_args=[0, 1], include_kwargs=['data', 'gamma']) 649def q_nstep_td_error( 650 data: namedtuple, 651 gamma: Union[float, list], 652 nstep: int = 1, 653 cum_reward: bool = False, 654 value_gamma: Optional[torch.Tensor] = None, 655 criterion: torch.nn.modules = nn.MSELoss(reduction='none'), 656) -> torch.Tensor: 657 """ 658 Overview: 659 Multistep (1 step or n step) td_error for q-learning based algorithm 660 Arguments: 661 - data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss 662 - gamma (:obj:`float`): Discount factor 663 - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data 664 - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value 665 - criterion (:obj:`torch.nn.modules`): Loss function criterion 666 - nstep (:obj:`int`): nstep num, default set to 1 667 Returns: 668 - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor 669 - td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor 670 Shapes: 671 - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\ 672 ['q', 'next_n_q', 'action', 'reward', 'done'] 673 - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. 
[batch_size, action_dim] 674 - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)` 675 - action (:obj:`torch.LongTensor`): :math:`(B, )` 676 - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )` 677 - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep) 678 - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep 679 - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )` 680 Examples: 681 >>> next_q = torch.randn(4, 3) 682 >>> done = torch.randn(4) 683 >>> action = torch.randint(0, 3, size=(4, )) 684 >>> next_action = torch.randint(0, 3, size=(4, )) 685 >>> nstep =3 686 >>> q = torch.randn(4, 3).requires_grad_(True) 687 >>> reward = torch.rand(nstep, 4) 688 >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None) 689 >>> loss, td_error_per_sample = q_nstep_td_error(data, 0.95, nstep=nstep) 690 """ 691 q, next_n_q, action, next_n_action, reward, done, weight = data 692 if weight is None: 693 weight = torch.ones_like(reward) 694 695 if len(action.shape) == 1 or len(action.shape) < len(q.shape): 696 # we need to unsqueeze action and q to make them have the same shape 697 # e.g. single agent case: action is [B, ] and q is [B, ] 698 # e.g. 
multi agent case: action is [B, agent_num] and q is [B, agent_num, action_shape] 699 action = action.unsqueeze(-1) 700 elif len(action.shape) > 1: # MARL case 701 reward = reward.unsqueeze(-1) 702 weight = weight.unsqueeze(-1) 703 done = done.unsqueeze(-1) 704 if value_gamma is not None: 705 value_gamma = value_gamma.unsqueeze(-1) 706 707 q_s_a = q.gather(-1, action).squeeze(-1) 708 709 target_q_s_a = next_n_q.gather(-1, next_n_action.unsqueeze(-1)).squeeze(-1) 710 711 if cum_reward: 712 if value_gamma is None: 713 target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done) 714 else: 715 target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done) 716 else: 717 target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma) 718 td_error_per_sample = criterion(q_s_a, target_q_s_a.detach()) 719 return (td_error_per_sample * weight).mean(), td_error_per_sample 720 721 722def bdq_nstep_td_error( 723 data: namedtuple, 724 gamma: Union[float, list], 725 nstep: int = 1, 726 cum_reward: bool = False, 727 value_gamma: Optional[torch.Tensor] = None, 728 criterion: torch.nn.modules = nn.MSELoss(reduction='none'), 729) -> torch.Tensor: 730 """ 731 Overview: 732 Multistep (1 step or n step) td_error for BDQ algorithm, referenced paper "Action Branching Architectures for \ 733 Deep Reinforcement Learning", link: https://arxiv.org/pdf/1711.08946. 
734 In fact, the original paper only provides the 1-step TD-error calculation method, and here we extend the \ 735 calculation method of n-step, i.e., TD-error: 736 Arguments: 737 - data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss 738 - gamma (:obj:`float`): Discount factor 739 - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data 740 - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value 741 - criterion (:obj:`torch.nn.modules`): Loss function criterion 742 - nstep (:obj:`int`): nstep num, default set to 1 743 Returns: 744 - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor 745 - td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor 746 Shapes: 747 - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing \ 748 ['q', 'next_n_q', 'action', 'reward', 'done'] 749 - q (:obj:`torch.FloatTensor`): :math:`(B, D, N)` i.e. [batch_size, branch_num, action_bins_per_branch] 750 - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, D, N)` 751 - action (:obj:`torch.LongTensor`): :math:`(B, D)` 752 - next_n_action (:obj:`torch.LongTensor`): :math:`(B, D)` 753 - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep) 754 - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep 755 - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )` 756 Examples: 757 >>> action_per_branch = 3 758 >>> next_q = torch.randn(8, 6, action_per_branch) 759 >>> done = torch.randn(8) 760 >>> action = torch.randint(0, action_per_branch, size=(8, 6)) 761 >>> next_action = torch.randint(0, action_per_branch, size=(8, 6)) 762 >>> nstep =3 763 >>> q = torch.randn(8, 6, action_per_branch).requires_grad_(True) 764 >>> reward = torch.rand(nstep, 8) 765 >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None) 766 >>> loss, td_error_per_sample = bdq_nstep_td_error(data, 0.95, nstep=nstep) 767 """ 768 q, 
next_n_q, action, next_n_action, reward, done, weight = data
    if weight is None:
        weight = torch.ones_like(reward)
    reward = reward.unsqueeze(-1)
    done = done.unsqueeze(-1)
    if value_gamma is not None:
        value_gamma = value_gamma.unsqueeze(-1)

    q_s_a = q.gather(-1, action.unsqueeze(-1)).squeeze(-1)
    target_q_s_a = next_n_q.gather(-1, next_n_action.unsqueeze(-1)).squeeze(-1)

    if cum_reward:
        if value_gamma is None:
            target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done)
        else:
            target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done)
    else:
        target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
    td_error_per_sample = td_error_per_sample.mean(-1)
    return (td_error_per_sample * weight).mean(), td_error_per_sample


def shape_fn_qntd_rescale(args, kwargs):
    r"""
    Overview:
        Return qntd_rescale shape for hpc
    Returns:
        shape: [T, B, N]
    """
    if len(args) <= 0:
        tmp = [kwargs['data'].reward.shape[0]]
        tmp.extend(list(kwargs['data'].q.shape))
    else:
        tmp = [args[0].reward.shape[0]]
        tmp.extend(list(args[0].q.shape))
    return tmp


@hpc_wrapper(
    shape_fn=shape_fn_qntd_rescale, namedtuple_data=True, include_args=[0, 1], include_kwargs=['data', 'gamma']
)
def q_nstep_td_error_with_rescale(
    data: namedtuple,
    gamma: Union[float, list],
    nstep: int = 1,
    value_gamma: Optional[torch.Tensor] = None,
    criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
    trans_fn: Callable = value_transform,
    inv_trans_fn: Callable = value_inv_transform,
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error with value rescaling
    Arguments:
        - data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - nstep (:obj:`int`): nstep num, default set to 1
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
        - trans_fn (:obj:`Callable`): Value transform function, default to value_transform\
            (refer to rl_utils/value_rescale.py)
        - inv_trans_fn (:obj:`Callable`): Value inverse transform function, default to value_inv_transform\
            (refer to rl_utils/value_rescale.py)
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
    Shapes:
        - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\
            ['q', 'next_n_q', 'action', 'reward', 'done']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep (nstep)
        - done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep
    Examples:
        >>> next_q = torch.randn(4, 3)
        >>> done = torch.randn(4)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> nstep = 3
        >>> q = torch.randn(4, 3).requires_grad_(True)
        >>> reward = torch.rand(nstep, 4)
        >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        >>> loss, _ = q_nstep_td_error_with_rescale(data, 0.95, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, weight = data
    assert len(action.shape) == 1, action.shape
    if weight is None:
        weight = torch.ones_like(action)

    batch_range = torch.arange(action.shape[0])
    q_s_a = q[batch_range, action]
    target_q_s_a = next_n_q[batch_range, next_n_action]

    target_q_s_a = inv_trans_fn(target_q_s_a)
    target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
    target_q_s_a = trans_fn(target_q_s_a)

    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
    return (td_error_per_sample * weight).mean(), td_error_per_sample
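The default `trans_fn` / `inv_trans_fn` pair referenced above lives in `rl_utils/value_rescale.py`. As a hedged, self-contained sketch (not the library's exact implementation), the commonly used choice is the invertible h / h^-1 transform of Pohlen et al. (2018); the target is inverse-transformed, bootstrapped, then re-transformed, so the pair must round-trip exactly:

```python
import torch


def value_transform(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
    # h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x
    return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + eps * x


def value_inv_transform(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
    # closed-form inverse of h
    return torch.sign(x) * (((torch.sqrt(1 + 4 * eps * (torch.abs(x) + 1 + eps)) - 1) / (2 * eps)) ** 2 - 1)


# sanity check: h^{-1}(h(x)) == x, which the rescaled TD target relies on
x = torch.tensor([-10.0, -1.0, 0.0, 1.0, 10.0], dtype=torch.float64)
assert torch.allclose(value_inv_transform(value_transform(x)), x, atol=1e-6)
```

The transform compresses large returns roughly like a square root, which stabilizes learning when reward magnitudes vary widely across games or tasks.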

def dqfd_nstep_td_error(
    data: namedtuple,
    gamma: float,
    lambda_n_step_td: float,
    lambda_supervised_loss: float,
    margin_function: float,
    lambda_one_step_td: float = 1.,
    nstep: int = 1,
    cum_reward: bool = False,
    value_gamma: Optional[torch.Tensor] = None,
    criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
) -> torch.Tensor:
    """
    Overview:
        Multistep n step td_error + 1 step td_error + supervised margin loss for DQfD
    Arguments:
        - data (:obj:`dqfd_nstep_td_data`): The input data, dqfd_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
        - nstep (:obj:`int`): nstep num, default set to 1
    Returns:
        - loss (:obj:`torch.Tensor`): Multistep n step td_error + 1 step td_error + supervised margin loss, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): Multistep n step td_error + 1 step td_error\
            + supervised margin loss, 1-dim tensor
    Shapes:
        - data (:obj:`dqfd_nstep_td_data`): The dqfd_nstep_td_data containing\
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight',\
            'new_n_q_one_step', 'next_n_action_one_step', 'is_expert']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep (nstep)
        - done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep
        - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
        - new_n_q_one_step (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - next_n_action_one_step (:obj:`torch.LongTensor`): :math:`(B, )`
        - is_expert (:obj:`int`): 0 or 1
    Examples:
        >>> next_q = torch.randn(4, 3)
        >>> done = torch.randn(4)
        >>> done_1 = torch.randn(4)
        >>> next_q_one_step = torch.randn(4, 3)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> next_action_one_step = torch.randint(0, 3, size=(4, ))
        >>> is_expert = torch.ones((4))
        >>> nstep = 3
        >>> q = torch.randn(4, 3).requires_grad_(True)
        >>> reward = torch.rand(nstep, 4)
        >>> data = dqfd_nstep_td_data(
        >>>     q, next_q, action, next_action, reward, done, done_1, None,
        >>>     next_q_one_step, next_action_one_step, is_expert
        >>> )
        >>> loss, td_error_per_sample, loss_statistics = dqfd_nstep_td_error(
        >>>     data, 0.95, lambda_n_step_td=1, lambda_supervised_loss=1,
        >>>     margin_function=0.8, nstep=nstep
        >>> )
    """
    q, next_n_q, action, next_n_action, reward, done, done_one_step, weight, new_n_q_one_step, next_n_action_one_step, \
        is_expert = data  # set is_expert flag (expert 1, agent 0)
    assert len(action.shape) == 1, action.shape
    if weight is None:
        weight = torch.ones_like(action)

    batch_range = torch.arange(action.shape[0])
    q_s_a = q[batch_range, action]
    target_q_s_a = next_n_q[batch_range, next_n_action]
    target_q_s_a_one_step = new_n_q_one_step[batch_range, next_n_action_one_step]

    # calculate n-step TD-loss
    if cum_reward:
        if value_gamma is None:
            target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done)
        else:
            target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done)
    else:
        target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())

    # calculate 1-step TD-loss
    nstep = 1
    reward = reward[0].unsqueeze(0)  # get the one-step reward
    value_gamma = None
    if cum_reward:
        if value_gamma is None:
            target_q_s_a_one_step = reward + (gamma ** nstep) * target_q_s_a_one_step * (1 - done_one_step)
        else:
            target_q_s_a_one_step = reward + value_gamma * target_q_s_a_one_step * (1 - done_one_step)
    else:
        target_q_s_a_one_step = nstep_return(
            nstep_return_data(reward, target_q_s_a_one_step, done_one_step), gamma, nstep, value_gamma
        )
    td_error_one_step_per_sample = criterion(q_s_a, target_q_s_a_one_step.detach())
    device = q_s_a.device
    device_cpu = torch.device('cpu')
    # calculate the supervised loss
    l = margin_function * torch.ones_like(q).to(device_cpu)  # q shape (B, A), action shape (B, )
    l.scatter_(1, torch.LongTensor(action.unsqueeze(1).to(device_cpu)), torch.zeros_like(q, device=device_cpu))
    # along the first dimension, for the index of the action, fill the corresponding position in l with 0
    JE = is_expert * (torch.max(q + l.to(device), dim=1)[0] - q_s_a)

    return (
        (
            (
                lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample +
                lambda_supervised_loss * JE
            ) * weight
        ).mean(), lambda_n_step_td * td_error_per_sample.abs() +
        lambda_one_step_td * td_error_one_step_per_sample.abs() + lambda_supervised_loss * JE.abs(),
        (td_error_per_sample.mean(), td_error_one_step_per_sample.mean(), JE.mean())
    )
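The supervised term JE above is DQfD's large-margin classification loss: it pushes the expert's action to beat every other action by at least the margin. A minimal standalone sketch (`margin_loss` is a hypothetical helper name, not part of this module):

```python
import torch


# J_E(Q) = max_a [Q(s, a) + l(a_E, a)] - Q(s, a_E),
# where l(a_E, a) = margin for a != a_E and 0 for a == a_E.
def margin_loss(q: torch.Tensor, expert_action: torch.Tensor, margin: float = 0.8) -> torch.Tensor:
    batch_range = torch.arange(q.shape[0])
    l = margin * torch.ones_like(q)
    l[batch_range, expert_action] = 0.  # zero margin at the expert action
    return (q + l).max(dim=1)[0] - q[batch_range, expert_action]


q = torch.tensor([[1.0, 2.0, 0.5]])
# expert action is already the greedy action by more than the margin -> zero loss
assert margin_loss(q, torch.tensor([1])).item() == 0.
# otherwise the expert action must beat the rest by the margin
assert abs(margin_loss(q, torch.tensor([0])).item() - 1.8) < 1e-6
```

The loss is zero exactly when the expert action is greedy with margin to spare, which keeps imitation pressure from distorting Q-values the demonstrations already agree with.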

def dqfd_nstep_td_error_with_rescale(
    data: namedtuple,
    gamma: float,
    lambda_n_step_td: float,
    lambda_supervised_loss: float,
    lambda_one_step_td: float,
    margin_function: float,
    nstep: int = 1,
    cum_reward: bool = False,
    value_gamma: Optional[torch.Tensor] = None,
    criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
    trans_fn: Callable = value_transform,
    inv_trans_fn: Callable = value_inv_transform,
) -> torch.Tensor:
    """
    Overview:
        Multistep n step td_error + 1 step td_error + supervised margin loss for DQfD, with value rescaling
    Arguments:
        - data (:obj:`dqfd_nstep_td_data`): The input data, dqfd_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
        - nstep (:obj:`int`): nstep num, default set to 1
    Returns:
        - loss (:obj:`torch.Tensor`): Multistep n step td_error + 1 step td_error + supervised margin loss, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): Multistep n step td_error + 1 step td_error\
            + supervised margin loss, 1-dim tensor
    Shapes:
        - data (:obj:`dqfd_nstep_td_data`): The dqfd_nstep_td_data containing\
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight',\
            'new_n_q_one_step', 'next_n_action_one_step', 'is_expert']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep (nstep)
        - done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep
        - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
        - new_n_q_one_step (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - next_n_action_one_step (:obj:`torch.LongTensor`): :math:`(B, )`
        - is_expert (:obj:`int`): 0 or 1
    """
    q, next_n_q, action, next_n_action, reward, done, done_one_step, weight, new_n_q_one_step, next_n_action_one_step, \
        is_expert = data  # set is_expert flag (expert 1, agent 0)
    assert len(action.shape) == 1, action.shape
    if weight is None:
        weight = torch.ones_like(action)

    batch_range = torch.arange(action.shape[0])
    q_s_a = q[batch_range, action]

    target_q_s_a = next_n_q[batch_range, next_n_action]
    target_q_s_a = inv_trans_fn(target_q_s_a)  # rescale

    target_q_s_a_one_step = new_n_q_one_step[batch_range, next_n_action_one_step]
    target_q_s_a_one_step = inv_trans_fn(target_q_s_a_one_step)  # rescale

    # calculate n-step TD-loss
    if cum_reward:
        if value_gamma is None:
            target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done)
        else:
            target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done)
    else:
        # to use value_gamma in n-step TD-loss
        target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)

    target_q_s_a = trans_fn(target_q_s_a)  # rescale
    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())

    # calculate 1-step TD-loss
    nstep = 1
    reward = reward[0].unsqueeze(0)  # get the one-step reward
    value_gamma = None  # This is very important, to use gamma in 1-step TD-loss
    if cum_reward:
        if value_gamma is None:
            target_q_s_a_one_step = reward + (gamma ** nstep) * target_q_s_a_one_step * (1 - done_one_step)
        else:
            target_q_s_a_one_step = reward + value_gamma * target_q_s_a_one_step * (1 - done_one_step)
    else:
        target_q_s_a_one_step = nstep_return(
            nstep_return_data(reward, target_q_s_a_one_step, done_one_step), gamma, nstep, value_gamma
        )

    target_q_s_a_one_step = trans_fn(target_q_s_a_one_step)  # rescale
    td_error_one_step_per_sample = criterion(q_s_a, target_q_s_a_one_step.detach())
    device = q_s_a.device
    device_cpu = torch.device('cpu')
    # calculate the supervised loss
    l = margin_function * torch.ones_like(q).to(device_cpu)  # q shape (B, A), action shape (B, )
    l.scatter_(1, torch.LongTensor(action.unsqueeze(1).to(device_cpu)), torch.zeros_like(q, device=device_cpu))
    # along the first dimension, for the index of the action, fill the corresponding position in l with 0
    JE = is_expert * (torch.max(q + l.to(device), dim=1)[0] - q_s_a)

    return (
        (
            (
                lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample +
                lambda_supervised_loss * JE
            ) * weight
        ).mean(), lambda_n_step_td * td_error_per_sample.abs() +
        lambda_one_step_td * td_error_one_step_per_sample.abs() + lambda_supervised_loss * JE.abs(),
        (td_error_per_sample.mean(), td_error_one_step_per_sample.mean(), JE.mean())
    )

qrdqn_nstep_td_data = namedtuple(
    'qrdqn_nstep_td_data', ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'tau', 'weight']
)


def qrdqn_nstep_td_error(
    data: namedtuple,
    gamma: float,
    nstep: int = 1,
    value_gamma: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error in QRDQN
    Arguments:
        - data (:obj:`qrdqn_nstep_td_data`): The input data, qrdqn_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - nstep (:obj:`int`): nstep num, default set to 1
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
    Shapes:
        - data (:obj:`qrdqn_nstep_td_data`): The qrdqn_nstep_td_data containing\
            ['q', 'next_n_q', 'action', 'reward', 'done']
        - q (:obj:`torch.FloatTensor`): :math:`(tau, B, N)` i.e. [tau, batch_size, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(tau', B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep (nstep)
        - done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep
    Examples:
        >>> next_q = torch.randn(4, 3, 3)
        >>> done = torch.randn(4)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> nstep = 3
        >>> q = torch.randn(4, 3, 3).requires_grad_(True)
        >>> reward = torch.rand(nstep, 4)
        >>> data = qrdqn_nstep_td_data(q, next_q, action, next_action, reward, done, 3, None)
        >>> loss, td_error_per_sample = qrdqn_nstep_td_error(data, 0.95, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, tau, weight = data

    assert len(action.shape) == 1, action.shape
    assert len(next_n_action.shape) == 1, next_n_action.shape
    assert len(done.shape) == 1, done.shape
    assert len(q.shape) == 3, q.shape
    assert len(next_n_q.shape) == 3, next_n_q.shape
    assert len(reward.shape) == 2, reward.shape

    if weight is None:
        weight = torch.ones_like(action)

    batch_range = torch.arange(action.shape[0])

    # shape: batch_size x num x 1
    q_s_a = q[batch_range, action, :].unsqueeze(2)
    # shape: batch_size x 1 x num
    target_q_s_a = next_n_q[batch_range, next_n_action, :].unsqueeze(1)

    assert reward.shape[0] == nstep
    reward_factor = torch.ones(nstep).to(reward)
    for i in range(1, nstep):
        reward_factor[i] = gamma * reward_factor[i - 1]
    # shape: batch_size
    reward = torch.matmul(reward_factor, reward)
    # shape: batch_size x 1 x num
    if value_gamma is None:
        target_q_s_a = reward.unsqueeze(-1).unsqueeze(-1) + (gamma ** nstep
                                                             ) * target_q_s_a * (1 - done).unsqueeze(-1).unsqueeze(-1)
    else:
        target_q_s_a = reward.unsqueeze(-1).unsqueeze(
            -1
        ) + value_gamma.unsqueeze(-1).unsqueeze(-1) * target_q_s_a * (1 - done).unsqueeze(-1).unsqueeze(-1)

    # shape: batch_size x num x num
    u = F.smooth_l1_loss(target_q_s_a, q_s_a, reduction="none")
    # shape: batch_size
    loss = (u * (tau - (target_q_s_a - q_s_a).detach().le(0.).float()).abs()).sum(-1).mean(1)

    return (loss * weight).mean(), loss
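The `reward_factor` loop used by the distributional losses in this module folds an `(nstep, B)` reward tensor into one discounted n-step return per sample; a minimal sketch of just that folding step:

```python
import torch

# fold an (nstep, B) reward tensor into a discounted n-step return of shape (B, )
gamma, nstep = 0.95, 3
reward = torch.ones(nstep, 2)  # (T, B) with T == nstep
reward_factor = torch.ones(nstep)
for i in range(1, nstep):
    reward_factor[i] = gamma * reward_factor[i - 1]  # [1, gamma, gamma^2]
ret = torch.matmul(reward_factor, reward)  # shape: (B, )
assert torch.allclose(ret, torch.full((2, ), 1 + gamma + gamma ** 2))
```

The matmul contracts the time dimension, so the result is the usual sum of rewards weighted by powers of gamma; `gamma ** nstep` is then applied separately to the bootstrap term.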

def q_nstep_sql_td_error(
    data: namedtuple,
    gamma: float,
    alpha: float,
    nstep: int = 1,
    cum_reward: bool = False,
    value_gamma: Optional[torch.Tensor] = None,
    criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error for soft q-learning based algorithm
    Arguments:
        - data (:obj:`q_nstep_td_data`): The input data, q_nstep_sql_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - alpha (:obj:`float`): A parameter to weight the entropy term in the policy equation
        - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target soft_q_value
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
        - nstep (:obj:`int`): nstep num, default set to 1
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor
    Shapes:
        - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\
            ['q', 'next_n_q', 'action', 'reward', 'done']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep (nstep)
        - done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep
        - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
    Examples:
        >>> next_q = torch.randn(4, 3)
        >>> done = torch.randn(4)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> nstep = 3
        >>> q = torch.randn(4, 3).requires_grad_(True)
        >>> reward = torch.rand(nstep, 4)
        >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        >>> loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(data, 0.95, 1.0, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, weight = data
    assert len(action.shape) == 1, action.shape
    if weight is None:
        weight = torch.ones_like(action)

    batch_range = torch.arange(action.shape[0])
    q_s_a = q[batch_range, action]
    # target_q_s_a = next_n_q[batch_range, next_n_action]
    target_v = alpha * torch.logsumexp(
        next_n_q / alpha, 1
    )  # target_v = alpha * torch.log(torch.sum(torch.exp(next_n_q / alpha), 1))
    target_v[target_v == float("Inf")] = 20
    target_v[target_v == float("-Inf")] = -20
    # For an appropriate hyper-parameter alpha, these hardcodes can be removed.
    # However, algorithms may face the danger of explosion for other alphas.
    # The hardcodes above are to prevent this situation from happening.
    record_target_v = copy.deepcopy(target_v)
    # print(target_v)
    if cum_reward:
        if value_gamma is None:
            target_v = reward + (gamma ** nstep) * target_v * (1 - done)
        else:
            target_v = reward + value_gamma * target_v * (1 - done)
    else:
        target_v = nstep_return(nstep_return_data(reward, target_v, done), gamma, nstep, value_gamma)
    td_error_per_sample = criterion(q_s_a, target_v.detach())
    return (td_error_per_sample * weight).mean(), td_error_per_sample, record_target_v
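The soft value target in `q_nstep_sql_td_error` replaces the hard max over actions with `alpha * logsumexp(q / alpha)`, a smooth upper bound on the greedy value that adds an entropy bonus; a small self-contained sketch of that behavior:

```python
import torch

# V(s) = alpha * logsumexp(Q(s, .) / alpha): a soft maximum over actions.
# Large alpha keeps an entropy bonus; as alpha -> 0 it approaches max_a Q(s, a).
q = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float64)
v_soft = lambda alpha: (alpha * torch.logsumexp(q / alpha, dim=1)).item()

assert v_soft(1.0) > 3.0            # entropy bonus visible at alpha = 1
assert abs(v_soft(0.01) - 3.0) < 1e-3  # essentially the greedy value at small alpha
```

This also shows why the function clamps `target_v` at +/-20: for very small alpha the argument of `logsumexp` can overflow to infinity in low precision.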

iqn_nstep_td_data = namedtuple(
    'iqn_nstep_td_data', ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'replay_quantiles', 'weight']
)


def iqn_nstep_td_error(
    data: namedtuple,
    gamma: float,
    nstep: int = 1,
    kappa: float = 1.0,
    value_gamma: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error in IQN, \
        referenced paper Implicit Quantile Networks for Distributional Reinforcement Learning \
        <https://arxiv.org/pdf/1806.06923.pdf>
    Arguments:
        - data (:obj:`iqn_nstep_td_data`): The input data, iqn_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - nstep (:obj:`int`): nstep num, default set to 1
        - kappa (:obj:`float`): Threshold of the Huber loss
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
    Shapes:
        - data (:obj:`iqn_nstep_td_data`): The iqn_nstep_td_data containing\
            ['q', 'next_n_q', 'action', 'reward', 'done']
        - q (:obj:`torch.FloatTensor`): :math:`(tau, B, N)` i.e. [tau, batch_size, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(tau', B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep (nstep)
        - done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep
    Examples:
        >>> next_q = torch.randn(3, 4, 3)
        >>> done = torch.randn(4)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> nstep = 3
        >>> q = torch.randn(3, 4, 3).requires_grad_(True)
        >>> replay_quantile = torch.randn([3, 4, 1])
        >>> reward = torch.rand(nstep, 4)
        >>> data = iqn_nstep_td_data(q, next_q, action, next_action, reward, done, replay_quantile, None)
        >>> loss, td_error_per_sample = iqn_nstep_td_error(data, 0.95, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, replay_quantiles, weight = data

    assert len(action.shape) == 1, action.shape
    assert len(next_n_action.shape) == 1, next_n_action.shape
    assert len(done.shape) == 1, done.shape
    assert len(q.shape) == 3, q.shape
    assert len(next_n_q.shape) == 3, next_n_q.shape
    assert len(reward.shape) == 2, reward.shape

    if weight is None:
        weight = torch.ones_like(action)

    batch_size = done.shape[0]
    tau = q.shape[0]
    tau_prime = next_n_q.shape[0]

    action = action.repeat([tau, 1]).unsqueeze(-1)
    next_n_action = next_n_action.repeat([tau_prime, 1]).unsqueeze(-1)

    # shape: batch_size x tau x 1
    q_s_a = torch.gather(q, -1, action).permute([1, 0, 2])
    # shape: batch_size x tau_prime x 1
    target_q_s_a = torch.gather(next_n_q, -1, next_n_action).permute([1, 0, 2])

    assert reward.shape[0] == nstep
    device = torch.device("cuda" if reward.is_cuda else "cpu")
    reward_factor = torch.ones(nstep).to(device)
    for i in range(1, nstep):
        reward_factor[i] = gamma * reward_factor[i - 1]
    reward = torch.matmul(reward_factor, reward)
    if value_gamma is None:
        target_q_s_a = reward.unsqueeze(-1) + (gamma ** nstep) * target_q_s_a.squeeze(-1) * (1 - done).unsqueeze(-1)
    else:
        target_q_s_a = reward.unsqueeze(-1) + value_gamma.unsqueeze(-1) * target_q_s_a.squeeze(-1) * (1 - done
                                                                                                      ).unsqueeze(-1)
    target_q_s_a = target_q_s_a.unsqueeze(-1)

    # shape: batch_size x tau' x tau x 1.
    bellman_errors = (target_q_s_a[:, :, None, :] - q_s_a[:, None, :, :])

    # The huber loss (see Section 2.3 of the paper) is defined via two cases:
    huber_loss = torch.where(
        bellman_errors.abs() <= kappa, 0.5 * bellman_errors ** 2, kappa * (bellman_errors.abs() - 0.5 * kappa)
    )

    # Reshape replay_quantiles to batch_size x num_tau_samples x 1
    replay_quantiles = replay_quantiles.reshape([tau, batch_size, 1]).permute([1, 0, 2])

    # shape: batch_size x num_tau_prime_samples x num_tau_samples x 1.
    replay_quantiles = replay_quantiles[:, None, :, :].repeat([1, tau_prime, 1, 1])

    # shape: batch_size x tau_prime x tau x 1.
    quantile_huber_loss = (torch.abs(replay_quantiles - ((bellman_errors < 0).float()).detach()) * huber_loss) / kappa

    # shape: batch_size
    loss = quantile_huber_loss.sum(dim=2).mean(dim=1)[:, 0]

    return (loss * weight).mean(), loss
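Both the QRDQN and IQN losses above are built from the same elementwise quantile Huber loss; a hedged standalone sketch (`quantile_huber` is a hypothetical helper name) showing the asymmetric weighting:

```python
import torch


# rho_tau^kappa(u) = |tau - 1{u < 0}| * L_kappa(u) / kappa, with L_kappa the Huber loss
def quantile_huber(u: torch.Tensor, tau: torch.Tensor, kappa: float = 1.0) -> torch.Tensor:
    huber = torch.where(u.abs() <= kappa, 0.5 * u ** 2, kappa * (u.abs() - 0.5 * kappa))
    return (tau - (u.detach() < 0).float()).abs() * huber / kappa


# positive error at a high quantile -> weighted by tau
assert abs(quantile_huber(torch.tensor(0.5), torch.tensor(0.9)).item() - 0.1125) < 1e-6
# negative error of the same size -> weighted by 1 - tau
assert abs(quantile_huber(torch.tensor(-0.5), torch.tensor(0.9)).item() - 0.0125) < 1e-6
```

Overestimation and underestimation are penalized asymmetrically per quantile, which is what drives the predicted quantiles toward the true return distribution.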

fqf_nstep_td_data = namedtuple(
    'fqf_nstep_td_data', ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'quantiles_hats', 'weight']
)


def fqf_nstep_td_error(
    data: namedtuple,
    gamma: float,
    nstep: int = 1,
    kappa: float = 1.0,
    value_gamma: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error in FQF, \
        referenced paper Fully Parameterized Quantile Function for Distributional Reinforcement Learning \
        <https://arxiv.org/pdf/1911.02140.pdf>
    Arguments:
        - data (:obj:`fqf_nstep_td_data`): The input data, fqf_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - nstep (:obj:`int`): nstep num, default set to 1
        - kappa (:obj:`float`): Threshold of the Huber loss
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
    Shapes:
        - data (:obj:`fqf_nstep_td_data`): The fqf_nstep_td_data containing\
            ['q', 'next_n_q', 'action', 'reward', 'done']
        - q (:obj:`torch.FloatTensor`): :math:`(B, tau, N)` i.e. [batch_size, tau, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, tau', N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep (nstep)
        - done (:obj:`torch.BoolTensor`): :math:`(B, )`, whether done in last timestep
        - quantiles_hats (:obj:`torch.FloatTensor`): :math:`(B, tau)`
    Examples:
        >>> next_q = torch.randn(4, 3, 3)
        >>> done = torch.randn(4)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> nstep = 3
        >>> q = torch.randn(4, 3, 3).requires_grad_(True)
        >>> quantiles_hats = torch.randn([4, 3])
        >>> reward = torch.rand(nstep, 4)
        >>> data = fqf_nstep_td_data(q, next_q, action, next_action, reward, done, quantiles_hats, None)
        >>> loss, td_error_per_sample = fqf_nstep_td_error(data, 0.95, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, quantiles_hats, weight = data

    assert len(action.shape) == 1, action.shape
    assert len(next_n_action.shape) == 1, next_n_action.shape
    assert len(done.shape) == 1, done.shape
    assert len(q.shape) == 3, q.shape
    assert len(next_n_q.shape) == 3, next_n_q.shape
    assert len(reward.shape) == 2, reward.shape

    if weight is None:
        weight = torch.ones_like(action)

    batch_size = done.shape[0]
    tau = q.shape[1]
    tau_prime = next_n_q.shape[1]

    # shape: batch_size x tau x 1
    q_s_a = evaluate_quantile_at_action(q, action)
    # shape: batch_size x tau_prime x 1
    target_q_s_a = evaluate_quantile_at_action(next_n_q, next_n_action)

    assert reward.shape[0] == nstep
    reward_factor = torch.ones(nstep).to(reward.device)
    for i in range(1, nstep):
        reward_factor[i] = gamma * reward_factor[i - 1]
    reward = torch.matmul(reward_factor, reward)  # [batch_size]
    if value_gamma is None:
        target_q_s_a = reward.unsqueeze(-1) + (gamma ** nstep) * target_q_s_a.squeeze(-1) * (1 - done).unsqueeze(-1)
    else:
        target_q_s_a = reward.unsqueeze(-1) + value_gamma.unsqueeze(-1) * target_q_s_a.squeeze(-1) * (1 - done
                                                                                                      ).unsqueeze(-1)
    target_q_s_a = target_q_s_a.unsqueeze(-1)

    # shape: batch_size x tau' x tau x 1.
    bellman_errors = (target_q_s_a.unsqueeze(2) - q_s_a.unsqueeze(1))

    # shape: batch_size x tau' x tau x 1
    huber_loss = F.smooth_l1_loss(target_q_s_a.unsqueeze(2), q_s_a.unsqueeze(1), reduction="none")

    # shape: batch_size x num_tau_prime_samples x num_tau_samples x 1.
    quantiles_hats = quantiles_hats[:, None, :, None].repeat([1, tau_prime, 1, 1])

    # shape: batch_size x tau_prime x tau x 1.
    quantile_huber_loss = (torch.abs(quantiles_hats - ((bellman_errors < 0).float()).detach()) * huber_loss) / kappa

    # shape: batch_size
    loss = quantile_huber_loss.sum(dim=2).mean(dim=1)[:, 0]

    return (loss * weight).mean(), loss


def evaluate_quantile_at_action(q_s, actions):
    assert q_s.shape[0] == actions.shape[0]

    batch_size, num_quantiles = q_s.shape[:2]

    # Expand actions into (batch_size, num_quantiles, 1).
    action_index = actions[:, None, None].expand(batch_size, num_quantiles, 1)

    # Calculate quantile values at specified actions.
    q_s_a = q_s.gather(dim=2, index=action_index)

    return q_s_a

def fqf_calculate_fraction_loss(q_tau_i, q_value, quantiles, actions):
    """
    Overview:
        Calculate the fraction loss in FQF, \
        referenced paper Fully Parameterized Quantile Function for Distributional Reinforcement Learning \
        <https://arxiv.org/pdf/1911.02140.pdf>
    Arguments:
        - q_tau_i (:obj:`torch.FloatTensor`): :math:`(batch_size, num_quantiles-1, action_dim)`
        - q_value (:obj:`torch.FloatTensor`): :math:`(batch_size, num_quantiles, action_dim)`
        - quantiles (:obj:`torch.FloatTensor`): :math:`(batch_size, num_quantiles+1)`
        - actions (:obj:`torch.LongTensor`): :math:`(batch_size, )`
    Returns:
        - fraction_loss (:obj:`torch.Tensor`): fraction loss, 0-dim tensor
    """
    assert q_value.requires_grad

    batch_size = q_value.shape[0]
    num_quantiles = q_value.shape[1]

    with torch.no_grad():
        sa_quantiles = evaluate_quantile_at_action(q_tau_i, actions)
        assert sa_quantiles.shape == (batch_size, num_quantiles - 1, 1)
        q_s_a_hats = evaluate_quantile_at_action(q_value, actions)  # [batch_size, num_quantiles, 1]
        assert q_s_a_hats.shape == (batch_size, num_quantiles, 1)
        assert not q_s_a_hats.requires_grad

    # NOTE: Proposition 1 in the paper requires F^{-1} to be non-decreasing.
    # We relax this requirement and calculate gradients of quantiles even when
    # F^{-1} is not non-decreasing.

    values_1 = sa_quantiles - q_s_a_hats[:, :-1]
    signs_1 = sa_quantiles > torch.cat([q_s_a_hats[:, :1], sa_quantiles[:, :-1]], dim=1)
    assert values_1.shape == signs_1.shape

    values_2 = sa_quantiles - q_s_a_hats[:, 1:]
    signs_2 = sa_quantiles < torch.cat([sa_quantiles[:, 1:], q_s_a_hats[:, -1:]], dim=1)
    assert values_2.shape == signs_2.shape

    gradient_of_taus = (torch.where(signs_1, values_1, -values_1) +
                        torch.where(signs_2, values_2, -values_2)).view(batch_size, num_quantiles - 1)
    assert not gradient_of_taus.requires_grad
    assert gradient_of_taus.shape == quantiles[:, 1:-1].shape

    # Gradients of the network parameters and corresponding loss
    # are calculated using the chain rule.
    fraction_loss = (gradient_of_taus * quantiles[:, 1:-1]).sum(dim=1).mean()

    return fraction_loss


td_lambda_data = namedtuple('td_lambda_data', ['value', 'reward', 'weight'])


def shape_fn_td_lambda(args, kwargs):
    r"""
    Overview:
        Return td_lambda shape for hpc
    Returns:
        shape: [T, B]
    """
    if len(args) <= 0:
        tmp = kwargs['data'].reward.shape[0]
    else:
        tmp = args[0].reward.shape
    return tmp
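The TD(lambda) utilities in this module compute forward-view lambda returns by a backward recursion over time. A self-contained sketch of that recursion, under the same conventions (values of shape `(T+1, B)`, rewards of shape `(T, B)`; `lambda_returns` is a hypothetical helper name):

```python
import torch


# Forward-view lambda return via backward recursion:
#   G[T-1] = r[T-1] + gamma * V[T]
#   G[t]   = r[t] + gamma * (lambda * G[t+1] + (1 - lambda) * V[t+1])
def lambda_returns(value: torch.Tensor, reward: torch.Tensor, gamma: float, lambda_: float) -> torch.Tensor:
    # value: (T+1, B) state values, reward: (T, B)
    T = reward.shape[0]
    result = torch.empty_like(reward)
    result[-1] = reward[-1] + gamma * value[-1]
    for t in reversed(range(T - 1)):
        result[t] = reward[t] + gamma * (lambda_ * result[t + 1] + (1 - lambda_) * value[t + 1])
    return result


value = torch.tensor([[0.], [0.], [2.]])  # V(s_0), V(s_1), V(s_2)
reward = torch.ones(2, 1)
# lambda = 1 recovers the Monte-Carlo return with a bootstrap at step T
assert lambda_returns(value, reward, 1.0, 1.0)[0].item() == 4.0
# lambda = 0 recovers the 1-step TD target
assert lambda_returns(value, reward, 1.0, 0.0)[0].item() == 1.0
```

Lambda interpolates between the low-variance 1-step target and the low-bias Monte-Carlo return, which is exactly the trade-off `td_lambda_error` exposes.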
8, 41560 >>> value = torch.randn(T + 1, B).requires_grad_(True)1561 >>> reward = torch.rand(T, B)1562 >>> loss = td_lambda_error(td_lambda_data(value, reward, None))1563 """1564 value, reward, weight = data1565 if weight is None:1566 weight = torch.ones_like(reward)1567 with torch.no_grad():1568 return_ = generalized_lambda_returns(value, reward, gamma, lambda_)1569 # discard the value at T as it should be considered in the next slice1570 loss = 0.5 * (F.mse_loss(return_, value[:-1], reduction='none') * weight).mean()1571 return loss157215731574def generalized_lambda_returns(1575 bootstrap_values: torch.Tensor,1576 rewards: torch.Tensor,1577 gammas: float,1578 lambda_: float,1579 done: Optional[torch.Tensor] = None1580) -> torch.Tensor:1581 r"""1582 Overview:1583 Functional equivalent to trfl.value_ops.generalized_lambda_returns1584 https://github.com/deepmind/trfl/blob/2c07ac22512a16715cc759f0072be43a5d12ae45/trfl/value_ops.py#L741585 Passing in a number instead of tensor to make the value constant for all samples in batch1586 Arguments:1587 - bootstrap_values (:obj:`torch.Tensor` or :obj:`float`):1588 estimation of the value at step 0 to *T*, of size [T_traj+1, batchsize]1589 - rewards (:obj:`torch.Tensor`): The returns from 0 to T-1, of size [T_traj, batchsize]1590 - gammas (:obj:`torch.Tensor` or :obj:`float`):1591 Discount factor for each step (from 0 to T-1), of size [T_traj, batchsize]1592 - lambda (:obj:`torch.Tensor` or :obj:`float`): Determining the mix of bootstrapping1593 vs further accumulation of multistep returns at each timestep, of size [T_traj, batchsize]1594 - done (:obj:`torch.Tensor` or :obj:`float`):1595 Whether the episode done at current step (from 0 to T-1), of size [T_traj, batchsize]1596 Returns:1597 - return (:obj:`torch.Tensor`): Computed lambda return value1598 for each state from 0 to T-1, of size [T_traj, batchsize]1599 """1600 if not isinstance(gammas, torch.Tensor):1601 gammas = gammas * torch.ones_like(rewards)1602 if not 
isinstance(lambda_, torch.Tensor):1603 lambda_ = lambda_ * torch.ones_like(rewards)1604 bootstrap_values_tp1 = bootstrap_values[1:, :]1605 return multistep_forward_view(bootstrap_values_tp1, rewards, gammas, lambda_, done)160616071608def multistep_forward_view(1609 bootstrap_values: torch.Tensor,1610 rewards: torch.Tensor,1611 gammas: float,1612 lambda_: float,1613 done: Optional[torch.Tensor] = None1614) -> torch.Tensor:1615 """1616 Overview:1617 Same as trfl.sequence_ops.multistep_forward_view, which implements (12.18) in Sutton & Barto.1618 Assuming the first dim of input tensors correspond to the index in batch.16191620 .. note::1621 result[T-1] = rewards[T-1] + gammas[T-1] * bootstrap_values[T]1622 for t in 0...T-2 :1623 result[t] = rewards[t] + gammas[t]*(lambdas[t]*result[t+1] + (1-lambdas[t])*bootstrap_values[t+1])16241625 Arguments:1626 - bootstrap_values (:obj:`torch.Tensor`): Estimation of the value at *step 1 to T*, of size [T_traj, batchsize]1627 - rewards (:obj:`torch.Tensor`): The returns from 0 to T-1, of size [T_traj, batchsize]1628 - gammas (:obj:`torch.Tensor`): Discount factor for each step (from 0 to T-1), of size [T_traj, batchsize]1629 - lambda (:obj:`torch.Tensor`): Determining the mix of bootstrapping vs further accumulation of \1630 multistep returns at each timestep of size [T_traj, batchsize], the element for T-1 is ignored \1631 and effectively set to 0, as there is no information about future rewards.1632 - done (:obj:`torch.Tensor` or :obj:`float`):1633 Whether the episode done at current step (from 0 to T-1), of size [T_traj, batchsize]1634 Returns:1635 - ret (:obj:`torch.Tensor`): Computed lambda return value \1636 for each state from 0 to T-1, of size [T_traj, batchsize]1637 """1638 result = torch.empty_like(rewards)1639 if done is None:1640 done = torch.zeros_like(rewards)1641 # Forced cutoff at the last one1642 result[-1, :] = rewards[-1, :] + (1 - done[-1, :]) * gammas[-1, :] * bootstrap_values[-1, :]1643 discounts = gammas * 
lambda_1644 for t in reversed(range(rewards.size()[0] - 1)):1645 result[t, :] = rewards[t, :] + (1 - done[t, :]) * \1646 (1647 discounts[t, :] * result[t + 1, :] +1648 (gammas[t, :] - discounts[t, :]) * bootstrap_values[t, :]1649 )16501651 return result
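The backward recursion implemented by multistep_forward_view can be sketched in plain Python for a single trajectory (no batch dimension, scalar gamma and lambda). The helper below, `lambda_returns`, is an illustrative reconstruction for this page, not part of ding; `values_tp1[t]` plays the role of `bootstrap_values[t]` after the `[1:, :]` shift performed in `generalized_lambda_returns`.

```python
from typing import List, Optional


def lambda_returns(rewards: List[float], values_tp1: List[float],
                   gamma: float, lam: float,
                   dones: Optional[List[int]] = None) -> List[float]:
    """Forward-view lambda returns for one trajectory, computed backward."""
    T = len(rewards)
    if dones is None:
        dones = [0] * T
    out = [0.0] * T
    # Forced cutoff at the last step: bootstrap with the final value estimate.
    out[-1] = rewards[-1] + (1 - dones[-1]) * gamma * values_tp1[-1]
    for t in reversed(range(T - 1)):
        # values_tp1[t] is the value estimate at step t+1; done[t] == 1 cuts
        # the trajectory so nothing is bootstrapped past an episode boundary.
        out[t] = rewards[t] + (1 - dones[t]) * gamma * (
            lam * out[t + 1] + (1 - lam) * values_tp1[t]
        )
    return out
```

With `lam=0` each entry collapses to the one-step TD target `r[t] + gamma * v[t+1]`; with `lam=1` it is the full multistep return bootstrapped only at the final step, matching the mixing behavior described in the docstrings above.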