
ding.policy.common_utils

set_noise_mode(module, noise_enabled)

Overview

Recursively set the 'enable_noise' attribute for all NoiseLinearLayer modules within the given module. This function is typically used in algorithms such as NoisyNet and Rainbow. During training, 'enable_noise' should be set to True to enable noise for exploration. During inference or evaluation, it should be set to False to disable noise for deterministic behavior.

Parameters:

- module (:obj:`nn.Module`): The root module to search for NoiseLinearLayer instances. Required.
- noise_enabled (:obj:`bool`): Whether to enable or disable noise. Required.
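The recursive toggle can be sketched without DI-engine installed. `ToyNoisyLinear`, `ToyModule`, and `iter_modules` below are hypothetical stand-ins for `NoiseLinearLayer`, `nn.Module`, and `nn.Module.modules()`; the real function simply walks `module.modules()` and flips the flag on every matching layer:

```python
# Minimal sketch of set_noise_mode's traversal pattern (no torch required).
# ToyNoisyLinear / ToyModule / iter_modules are stand-ins for
# NoiseLinearLayer / nn.Module / nn.Module.modules().

class ToyNoisyLinear:
    def __init__(self):
        self.enable_noise = True
        self.children = []

class ToyModule:
    def __init__(self, children):
        self.children = children

def iter_modules(module):
    # Depth-first walk over the module tree, mirroring nn.Module.modules()
    yield module
    for child in getattr(module, 'children', []):
        yield from iter_modules(child)

def set_noise_mode(module, noise_enabled):
    for m in iter_modules(module):
        if isinstance(m, ToyNoisyLinear):
            m.enable_noise = noise_enabled

net = ToyModule([ToyNoisyLinear(), ToyModule([ToyNoisyLinear()])])
set_noise_mode(net, False)  # e.g. switch noise off before evaluation
```

In training code the typical pattern is `set_noise_mode(model, True)` before collecting samples and `set_noise_mode(model, False)` before evaluation.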

default_preprocess_learn(data, use_priority_IS_weight=False, use_priority=False, use_nstep=False, ignore_done=False)

Overview

Default data pre-processing in the policy's _forward_learn method, including stacking batch data and handling ignore-done, n-step reward reshaping, and priority IS weights.

Arguments:

- data (:obj:`List[Any]`): The list of training batch samples; each sample is a dict of PyTorch tensors.
- use_priority_IS_weight (:obj:`bool`): Whether to use priority IS weight correction; if True, the weight of each sample is set to its priority IS weight.
- use_priority (:obj:`bool`): Whether to use priority; if True, the priority IS weight is set.
- use_nstep (:obj:`bool`): Whether to use n-step TD error; if True, the reward is reshaped.
- ignore_done (:obj:`bool`): Whether to ignore done; if True, done is set to 0.

Returns:

- data (:obj:`Dict[str, torch.Tensor]`): The preprocessed dict data whose values can be directly used for the subsequent model forward pass and loss computation.
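When `use_nstep=True`, the reward is moved from batch-first to step-first layout so that per-step TD targets can be computed. A sketch of that reshaping using numpy in place of torch (`np.transpose` plays the role of `tensor.transpose`/`permute`; `reshape_nstep_reward` is a hypothetical helper name, not part of the library):

```python
import numpy as np

def reshape_nstep_reward(reward: np.ndarray) -> np.ndarray:
    # Mirrors the use_nstep branch of default_preprocess_learn.
    if reward.ndim == 1:
        reward = reward[:, None]  # (batch_size,) -> (batch_size, 1)
    if reward.ndim == 2:
        # single-agent: (batch_size, nstep) -> (nstep, batch_size)
        return np.transpose(reward, (1, 0))
    elif reward.ndim == 3:
        # multi-agent: (batch_size, agent_dim, nstep) -> (nstep, batch_size, agent_dim)
        return np.transpose(reward, (2, 0, 1))
    raise ValueError("reward must be 1D, 2D or 3D, got shape {}".format(reward.shape))

r = np.arange(6).reshape(2, 3)        # batch_size=2, nstep=3
print(reshape_nstep_reward(r).shape)  # (3, 2)
```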

single_env_forward_wrapper(forward_fn)

Overview

Wrap policy to support gym-style interaction between policy and single environment.

Arguments:

- forward_fn (:obj:`Callable`): The original forward function of the policy.

Returns:

- wrapped_forward_fn (:obj:`Callable`): The wrapped forward function of the policy.

Examples:

>>> env = gym.make('CartPole-v0')
>>> policy = DQNPolicy(...)
>>> forward_fn = single_env_forward_wrapper(policy.eval_mode.forward)
>>> obs = env.reset()
>>> action = forward_fn(obs)
>>> next_obs, rew, done, info = env.step(action)
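The wrapper is a thin closure: it wraps the single observation in a batch keyed by env id 0, calls the policy, and strips the batch dimension from the returned action. A dependency-free sketch of that pattern (`fake_policy_forward` is a hypothetical stand-in for `policy.eval_mode.forward`, which takes `{env_id: batched_obs}` and returns `{env_id: {'action': ...}}`):

```python
# Dependency-free sketch of single_env_forward_wrapper's closure pattern.

def single_env_forward_wrapper(forward_fn):
    def _forward(obs):
        batched = {0: [obs]}                       # add batch dim, keyed by env id 0
        action = forward_fn(batched)[0]['action']  # batched action for env 0
        return action[0]                           # strip batch dim
    return _forward

def fake_policy_forward(obs_dict):
    # Hypothetical policy: action 1 if the first obs feature is positive, else 0.
    return {i: {'action': [int(o[0][0] > 0)]} for i, o in obs_dict.items()}

step = single_env_forward_wrapper(fake_policy_forward)
print(step([0.5, 0.1]))  # 1
```

The real implementation does the same thing with `to_tensor`/`unsqueeze` on the way in and `squeeze`/`to_ndarray` on the way out.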

single_env_forward_wrapper_ttorch(forward_fn, cuda=True)

Overview

Wrap policy to support gym-style interaction between policy and single environment for treetensor (ttorch) data.

Arguments:

- forward_fn (:obj:`Callable`): The original forward function of the policy.
- cuda (:obj:`bool`): Whether to use CUDA in the policy; if True, the input data is moved to CUDA.

Returns:

- wrapped_forward_fn (:obj:`Callable`): The wrapped forward function of the policy.

Examples:

>>> env = gym.make('CartPole-v0')
>>> policy = PPOFPolicy(...)
>>> forward_fn = single_env_forward_wrapper_ttorch(policy.eval)
>>> obs = env.reset()
>>> action = forward_fn(obs)
>>> next_obs, rew, done, info = env.step(action)

Full Source Code

../ding/policy/common_utils.py

```python
from typing import List, Any, Dict, Callable
import torch
import torch.nn as nn
import numpy as np
import treetensor.torch as ttorch
from ding.utils.data import default_collate
from ding.torch_utils import to_tensor, to_ndarray, unsqueeze, squeeze
from ding.torch_utils import NoiseLinearLayer


def set_noise_mode(module: nn.Module, noise_enabled: bool):
    """
    Overview:
        Recursively set the 'enable_noise' attribute for all NoiseLinearLayer modules within the given module.
        This function is typically used in algorithms such as NoisyNet and Rainbow.
        During training, 'enable_noise' should be set to True to enable noise for exploration.
        During inference or evaluation, it should be set to False to disable noise for deterministic behavior.

    Arguments:
        - module (:obj:`nn.Module`): The root module to search for NoiseLinearLayer instances.
        - noise_enabled (:obj:`bool`): Whether to enable or disable noise.
    """
    for m in module.modules():
        if isinstance(m, NoiseLinearLayer):
            m.enable_noise = noise_enabled


def default_preprocess_learn(
        data: List[Any],
        use_priority_IS_weight: bool = False,
        use_priority: bool = False,
        use_nstep: bool = False,
        ignore_done: bool = False,
) -> Dict[str, torch.Tensor]:
    """
    Overview:
        Default data pre-processing in policy's ``_forward_learn`` method, including stacking batch data, preprocess \
        ignore done, nstep and priority IS weight.
    Arguments:
        - data (:obj:`List[Any]`): The list of a training batch samples, each sample is a dict of PyTorch Tensor.
        - use_priority_IS_weight (:obj:`bool`): Whether to use priority IS weight correction, if True, this function \
            will set the weight of each sample to the priority IS weight.
        - use_priority (:obj:`bool`): Whether to use priority, if True, this function will set the priority IS weight.
        - use_nstep (:obj:`bool`): Whether to use nstep TD error, if True, this function will reshape the reward.
        - ignore_done (:obj:`bool`): Whether to ignore done, if True, this function will set the done to 0.
    Returns:
        - data (:obj:`Dict[str, torch.Tensor]`): The preprocessed dict data whose values can be directly used for \
            the following model forward and loss computation.
    """
    # data preprocess
    elem = data[0]
    if isinstance(elem['action'], (np.ndarray, torch.Tensor)) and elem['action'].dtype in [np.int64, torch.int64]:
        data = default_collate(data, cat_1dim=True)  # for discrete action
    else:
        data = default_collate(data, cat_1dim=False)  # for continuous action
    if 'value' in data and data['value'].dim() == 2 and data['value'].shape[1] == 1:
        data['value'] = data['value'].squeeze(-1)
    if 'adv' in data and data['adv'].dim() == 2 and data['adv'].shape[1] == 1:
        data['adv'] = data['adv'].squeeze(-1)

    if ignore_done:
        data['done'] = torch.zeros_like(data['done']).float()
    else:
        data['done'] = data['done'].float()

    if data['done'].dim() == 2 and data['done'].shape[1] == 1:
        data['done'] = data['done'].squeeze(-1)

    if use_priority_IS_weight:
        assert use_priority, "Use IS Weight correction, but Priority is not used."
    if use_priority and use_priority_IS_weight:
        if 'priority_IS' in data:
            data['weight'] = data['priority_IS']
        else:  # for compatibility
            data['weight'] = data['IS']
    else:
        data['weight'] = data.get('weight', None)
    if use_nstep:
        # reward reshaping for n-step
        reward = data['reward']
        if len(reward.shape) == 1:
            reward = reward.unsqueeze(1)
        # single agent reward: (batch_size, nstep) -> (nstep, batch_size)
        # multi-agent reward: (batch_size, agent_dim, nstep) -> (nstep, batch_size, agent_dim)
        # Assuming 'reward' is a PyTorch tensor with shape (batch_size, nstep) or (batch_size, agent_dim, nstep)
        if reward.ndim == 2:
            # For a 2D tensor, simply transpose it to get (nstep, batch_size)
            data['reward'] = reward.transpose(0, 1).contiguous()
        elif reward.ndim == 3:
            # For a 3D tensor, move the last dimension to the front to get (nstep, batch_size, agent_dim)
            data['reward'] = reward.permute(2, 0, 1).contiguous()
        else:
            raise ValueError("The 'reward' tensor must be either 2D or 3D. Got shape: {}".format(reward.shape))
    else:
        if data['reward'].dim() == 2 and data['reward'].shape[1] == 1:
            data['reward'] = data['reward'].squeeze(-1)

    return data


def single_env_forward_wrapper(forward_fn: Callable) -> Callable:
    """
    Overview:
        Wrap policy to support gym-style interaction between policy and single environment.
    Arguments:
        - forward_fn (:obj:`Callable`): The original forward function of policy.
    Returns:
        - wrapped_forward_fn (:obj:`Callable`): The wrapped forward function of policy.
    Examples:
        >>> env = gym.make('CartPole-v0')
        >>> policy = DQNPolicy(...)
        >>> forward_fn = single_env_forward_wrapper(policy.eval_mode.forward)
        >>> obs = env.reset()
        >>> action = forward_fn(obs)
        >>> next_obs, rew, done, info = env.step(action)
    """

    def _forward(obs):
        obs = {0: unsqueeze(to_tensor(obs))}
        action = forward_fn(obs)[0]['action']
        action = to_ndarray(squeeze(action))
        return action

    return _forward


def single_env_forward_wrapper_ttorch(forward_fn: Callable, cuda: bool = True) -> Callable:
    """
    Overview:
        Wrap policy to support gym-style interaction between policy and single environment for treetensor (ttorch) data.
    Arguments:
        - forward_fn (:obj:`Callable`): The original forward function of policy.
        - cuda (:obj:`bool`): Whether to use cuda in policy, if True, this function will move the input data to cuda.
    Returns:
        - wrapped_forward_fn (:obj:`Callable`): The wrapped forward function of policy.

    Examples:
        >>> env = gym.make('CartPole-v0')
        >>> policy = PPOFPolicy(...)
        >>> forward_fn = single_env_forward_wrapper_ttorch(policy.eval)
        >>> obs = env.reset()
        >>> action = forward_fn(obs)
        >>> next_obs, rew, done, info = env.step(action)
    """

    def _forward(obs):
        # unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
        obs = ttorch.as_tensor(obs).unsqueeze(0)
        if cuda and torch.cuda.is_available():
            obs = obs.cuda()
        action = forward_fn(obs).action
        # squeeze means delete batch dim, i.e. (1, A) -> (A, )
        action = action.squeeze(0).cpu().numpy()
        return action

    return _forward
```