
ding.policy.d4pg


D4PGPolicy

Bases: DDPGPolicy

Overview

Policy class of the D4PG algorithm. D4PG is a variant of DDPG that replaces the scalar critic with a distributional critic, implemented as a categorical distribution over a fixed value support (configured by v_min, v_max, and n_atom, as in C51). Paper link: https://arxiv.org/abs/1804.08617.
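The heart of the distributional critic update is projecting the shifted target distribution (reward plus discounted support) back onto the fixed atom grid before the cross-entropy loss is taken; this is what `dist_nstep_td_error` performs internally. Below is a minimal numpy sketch of that categorical projection for a single transition. The function name and signature are invented for illustration and this is not DI-engine's implementation:

```python
import numpy as np

def categorical_projection(next_probs, reward, done, v_min=-10.0, v_max=10.0,
                           n_atom=51, gamma=0.99, nstep=3):
    """Project the target distribution r + gamma^n * z onto the fixed support.

    `next_probs` is the target critic's distribution over the atoms, and
    `reward` is already the n-step discounted reward sum.
    """
    support = np.linspace(v_min, v_max, n_atom)
    delta_z = (v_max - v_min) / (n_atom - 1)
    # shifted support; collapses to the reward alone when the episode is done
    tz = np.clip(reward + (1.0 - done) * (gamma ** nstep) * support, v_min, v_max)
    b = (tz - v_min) / delta_z                    # fractional atom index
    lower = np.floor(b).astype(int)
    upper = np.ceil(b).astype(int)
    proj = np.zeros(n_atom)
    for j in range(n_atom):                       # split mass between neighbours
        if lower[j] == upper[j]:                  # landed exactly on an atom
            proj[lower[j]] += next_probs[j]
        else:
            proj[lower[j]] += next_probs[j] * (upper[j] - b[j])
            proj[upper[j]] += next_probs[j] * (b[j] - lower[j])
    return proj
```

The projection preserves total probability mass, so the result is still a valid distribution over the 51 atoms.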

Property

learn_mode, collect_mode, eval_mode

Config:

| ID | Symbol | Type | Default Value | Description | Other (Shape) |
|----|--------|------|---------------|-------------|---------------|
| 1 | `type` | str | d4pg | RL policy register name, refer to registry `POLICY_REGISTRY` | this arg is optional, a placeholder |
| 2 | `cuda` | bool | True | Whether to use cuda for network | |
| 3 | `random_collect_size` | int | 25000 | Number of randomly collected training samples in replay buffer when training starts | Default to 25000 for DDPG/TD3, 10000 for SAC |
| 5 | `learn.learning_rate_actor` | float | 1e-3 | Learning rate for actor network (aka. policy) | |
| 6 | `learn.learning_rate_critic` | float | 1e-3 | Learning rate for critic network (aka. Q-network) | |
| 7 | `learn.actor_update_freq` | int | 1 | When the critic network updates once, how many times the actor network updates | Default 1 |
| 8 | `learn.noise` | bool | False | Whether to add noise on the target network's action | Default False for D4PG; Target Policy Smoothing Regularization in the TD3 paper |
| 9 | `learn.ignore_done` | bool | False | Whether to ignore the done flag | Use ignore_done only in the halfcheetah env |
| 10 | `learn.target_theta` | float | 0.005 | Used for soft update of the target network | aka. interpolation factor in polyak averaging for target networks |
| 11 | `collect.noise_sigma` | float | 0.1 | Sigma of the noise distribution used during collection | Noise sampled from a Gaussian process |
| 12 | `model.v_min` | float | -10 | Value of the smallest atom in the support set | |
| 13 | `model.v_max` | float | 10 | Value of the largest atom in the support set | |
| 14 | `model.n_atom` | int | 51 | Number of atoms in the support set of the value distribution | |
| 15 | `nstep` | int | 3, [1, 5] | N-step reward discount sum for target q_value estimation | |
| 16 | `priority` | bool | True | Whether to use prioritized experience replay (PER) | priority sample, update priority |
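In practice these defaults are overridden through DI-engine's config-merging pipeline. A hypothetical user-side override dict mirroring the keys in the table above (the values shown are simply the documented defaults; the surrounding training-pipeline code is omitted):

```python
# Hypothetical override of D4PGPolicy's default config; key names follow
# the table above, values are the documented defaults.
d4pg_config = dict(
    type='d4pg',
    cuda=True,
    random_collect_size=25000,
    nstep=3,
    priority=True,
    model=dict(v_min=-10, v_max=10, n_atom=51),
    learn=dict(
        learning_rate_actor=1e-3,   # actor (policy) network
        learning_rate_critic=1e-3,  # critic (Q-network)
        actor_update_freq=1,        # actor updates per critic update
        noise=False,                # no target policy smoothing in D4PG
        ignore_done=False,
        target_theta=0.005,         # polyak averaging factor
    ),
    collect=dict(noise_sigma=0.1),  # Gaussian exploration noise
)
```

Any key left out falls back to the value in `D4PGPolicy.config`.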

default_model()

Overview

Return the default neural network model configuration for D4PGPolicy. The `__init__` method automatically calls this method to get the default model setting and create the model.

Returns:

- model_info (`Tuple[str, List[str]]`): The registered model name and the model's import_names.
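The returned tuple is consumed by the policy base class: the listed modules are imported so the model class registers itself, and it can then be looked up by name. A simplified sketch of that resolution step (the real lookup goes through DI-engine's `MODEL_REGISTRY`, which is omitted here):

```python
import importlib

def resolve_model(model_info):
    """Import the modules listed in import_names (which registers the model
    class as a side effect) and return the registered name. Illustrative
    only; DI-engine resolves the name through its MODEL_REGISTRY."""
    name, import_names = model_info
    for module_path in import_names:
        importlib.import_module(module_path)
    return name
```

For D4PGPolicy this receives `('qac_dist', ['ding.model.template.qac_dist'])`.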

Full Source Code

../ding/policy/d4pg.py

```python
from typing import List, Dict, Any, Tuple, Union
import torch
import copy

from ding.torch_utils import Adam, to_device
from ding.rl_utils import get_train_sample
from ding.rl_utils import dist_nstep_td_data, dist_nstep_td_error, get_nstep_return_data
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from .ddpg import DDPGPolicy
from .common_utils import default_preprocess_learn
import numpy as np


@POLICY_REGISTRY.register('d4pg')
class D4PGPolicy(DDPGPolicy):
    """
    Overview:
        Policy class of D4PG algorithm. D4PG is a variant of DDPG, which uses distributional critic. \
        The distributional critic is implemented as a categorical distribution over a fixed value support. \
        Paper link: https://arxiv.org/abs/1804.08617.

    Property:
        learn_mode, collect_mode, eval_mode
    Config:
        == ==================== ======== ============= ================================= =======================
        ID Symbol               Type     Default Value Description                       Other(Shape)
        == ==================== ======== ============= ================================= =======================
        1  ``type``             str      d4pg          | RL policy register name, refer | this arg is optional,
                                                       | to registry ``POLICY_REGISTRY``| a placeholder
        2  ``cuda``             bool     True          | Whether to use cuda for network|
        3  | ``random_``        int      25000         | Number of randomly collected   | Default to 25000 for
           | ``collect_size``                          | training samples in replay     | DDPG/TD3, 10000 for
                                                       | buffer when training starts.   | sac.
        5  | ``learn.learning`` float    1e-3          | Learning rate for actor        |
           | ``_rate_actor``                           | network(aka. policy).          |
        6  | ``learn.learning`` float    1e-3          | Learning rates for critic      |
           | ``_rate_critic``                          | network (aka. Q-network).      |
        7  | ``learn.actor_``   int      1             | When critic network updates    | Default 1
           | ``update_freq``                           | once, how many times will actor|
                                                       | network update.                |
        8  | ``learn.noise``    bool     False         | Whether to add noise on target | Default False for
                                                       | network's action.              | D4PG.
                                                       |                                | Target Policy Smoo-
                                                       |                                | thing Regularization
                                                       |                                | in TD3 paper.
        9  | ``learn.-``        bool     False         | Determine whether to ignore    | Use ignore_done only
           | ``ignore_done``                           | done flag.                     | in halfcheetah env.
        10 | ``learn.-``        float    0.005         | Used for soft update of the    | aka. Interpolation
           | ``target_theta``                          | target network.                | factor in polyak aver
                                                       |                                | aging for target
                                                       |                                | networks.
        11 | ``collect.-``      float    0.1           | Used for add noise during co-  | Sample noise from dis
           | ``noise_sigma``                           | llection, through controlling  | tribution, Gaussian
                                                       | the sigma of distribution      | process.
        12 | ``model.v_min``    float    -10           | Value of the smallest atom     |
                                                       | in the support set.            |
        13 | ``model.v_max``    float    10            | Value of the largest atom      |
                                                       | in the support set.            |
        14 | ``model.n_atom``   int      51            | Number of atoms in the support |
                                                       | set of the value distribution. |
        15 | ``nstep``          int      3, [1, 5]     | N-step reward discount sum for |
                                                       | target q_value estimation      |
        16 | ``priority``       bool     True          | Whether use priority(PER)      | priority sample,
                                                       |                                | update priority
        == ==================== ======== ============= ================================= =======================
    """

    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='d4pg',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool type) on_policy: Determine whether on-policy or off-policy.
        # on-policy setting influences the behaviour of buffer.
        # Default False in D4PG.
        on_policy=False,
        # (bool) Whether use priority(priority sample, IS weight, update priority)
        # Default True in D4PG.
        priority=True,
        # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=True,
        # (int) Number of training samples(randomly collected) in replay buffer when training starts.
        # Default 25000 in D4PG.
        random_collect_size=25000,
        # (int) N-step reward for target q_value estimation
        nstep=3,
        # (str) Action space type
        action_space='continuous',  # ['continuous', 'hybrid']
        # (bool) Whether use batch normalization for reward
        reward_batch_norm=False,
        # (bool) Whether to need policy data in process transition
        transition_with_policy_data=False,
        model=dict(
            # (float) Value of the smallest atom in the support set.
            # Default to -10.0.
            v_min=-10,
            # (float) Value of the largest atom in the support set.
            # Default to 10.0.
            v_max=10,
            # (int) Number of atoms in the support set of the
            # value distribution. Default to 51.
            n_atom=51
        ),
        learn=dict(
            # How many updates(iterations) to train after collector's one collection.
            # Bigger "update_per_collect" means bigger off-policy.
            # collect data -> update policy-> collect data -> ...
            update_per_collect=1,
            # (int) Minibatch size for gradient descent.
            batch_size=256,
            # Learning rates for actor network(aka. policy).
            learning_rate_actor=1e-3,
            # Learning rates for critic network(aka. Q-network).
            learning_rate_critic=1e-3,
            # (bool) Whether ignore done(usually for max step termination env. e.g. pendulum)
            # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
            # These limit HalfCheetah, and several other MuJoCo envs, to max length of 1000.
            # However, interaction with HalfCheetah always gets done with done is False,
            # Since we inplace done==True with done==False to keep
            # TD-error accurate computation(``gamma * (1 - done) * next_v + reward``),
            # when the episode step is greater than max episode step.
            ignore_done=False,
            # (float type) target_theta: Used for soft update of the target network,
            # aka. Interpolation factor in polyak averaging for target networks.
            # Default to 0.005.
            target_theta=0.005,
            # (float) discount factor for the discounted sum of rewards, aka. gamma.
            discount_factor=0.99,
            # (int) When critic network updates once, how many times will actor network update.
            actor_update_freq=1,
            # (bool) Whether to add noise on target network's action.
            # Target Policy Smoothing Regularization in original TD3 paper.
            noise=False,
        ),
        collect=dict(
            # (int) Only one of [n_sample, n_episode] should be set
            # n_sample=1,
            # (int) Cut trajectories into pieces with length "unroll_len".
            unroll_len=1,
            # It is a must to add noise during collection. So here omits "noise" and only set "noise_sigma".
            noise_sigma=0.1,
        ),
        eval=dict(evaluator=dict(eval_freq=1000, ), ),
        other=dict(
            replay_buffer=dict(
                # (int) Maximum size of replay buffer.
                replay_buffer_size=1000000,
            ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return the default neural network model class for D4PGPolicy. ``__init__`` method will \
            automatically call this method to get the default model setting and create model.

        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
        """
        return 'qac_dist', ['ding.model.template.qac_dist']

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the D4PG policy's learning mode, which involves setting up key components \
            specific to the D4PG algorithm. This includes creating separate optimizers for the actor \
            and critic networks, a distinctive trait of D4PG's actor-critic approach, and configuring \
            algorithm-specific parameters such as v_min, v_max, and n_atom for the distributional aspect \
            of the critic. Additionally, the method sets up the target model with momentum-based updates, \
            crucial for stabilizing learning, and optionally integrates noise into the target model for \
            effective exploration. This method is invoked during ``__init__`` if ``learn`` is specified \
            in ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some special member variables in the ``_init_learn`` method, you'd better name them \
            with the prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
        """
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        # actor and critic optimizer
        self._optimizer_actor = Adam(
            self._model.actor.parameters(),
            lr=self._cfg.learn.learning_rate_actor,
        )
        self._optimizer_critic = Adam(
            self._model.critic.parameters(),
            lr=self._cfg.learn.learning_rate_critic,
        )
        self._reward_batch_norm = self._cfg.reward_batch_norm

        self._gamma = self._cfg.learn.discount_factor
        self._nstep = self._cfg.nstep
        self._actor_update_freq = self._cfg.learn.actor_update_freq

        # main and target models
        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_theta}
        )
        if self._cfg.learn.noise:
            self._target_model = model_wrap(
                self._target_model,
                wrapper_name='action_noise',
                noise_type='gauss',
                noise_kwargs={
                    'mu': 0.0,
                    'sigma': self._cfg.learn.noise_sigma
                },
                noise_range=self._cfg.learn.noise_range
            )
        self._learn_model = model_wrap(self._model, wrapper_name='base')
        self._learn_model.reset()
        self._target_model.reset()

        self._v_max = self._cfg.model.v_max
        self._v_min = self._cfg.model.v_min
        self._n_atom = self._cfg.model.n_atom

        self._forward_learn_cnt = 0  # count iterations

    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode (training policy and updating parameters). Forward means \
            that the policy inputs some training batch data from the replay buffer and then returns the output \
            result, including various training information such as different loss, actor and critic lr.
        Arguments:
            - data (:obj:`dict`): Input data used for policy forward, including the \
                collected training samples from replay buffer. For each element in dict, the key of the \
                dict is the name of data items and the value is the corresponding data. Usually, the value is \
                torch.Tensor or np.ndarray or their dict/list combinations. In the ``_forward_learn`` method, data \
                often need to first be stacked in the batch dimension by some utility functions such as \
                ``default_preprocess_learn``. \
                For D4PG, each element in list is a dict containing at least the following keys: ``obs``, \
                ``action``, ``reward``, ``next_obs``. Sometimes, it also contains other keys such as ``weight``.

        Returns:
            - info_dict (:obj:`Dict[str, Any]`): The output result dict of forward learn, containing at \
                least the "cur_lr_actor", "cur_lr_critic", "different losses", "q_value", "action", "priority" \
                keys. Additionally, loss_dict also contains other keys, which are mainly used for monitoring and \
                debugging. "q_value_dict" is used to record the q_value statistics.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
            For a data type that is not supported, the main reason is that the corresponding model does not support it. \
            You can implement your own model rather than use the default model. For more information, please raise an \
            issue in GitHub repo and we will continue to follow up.

        .. note::
            For more detailed examples, please refer to our unittest for D4PGPolicy: ``ding.policy.tests.test_d4pg``.
        """
        loss_dict = {}
        data = default_preprocess_learn(
            data,
            use_priority=self._cfg.priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=True
        )
        if self._cuda:
            data = to_device(data, self._device)
        # ====================
        # critic learn forward
        # ====================
        self._learn_model.train()
        self._target_model.train()
        next_obs = data.get('next_obs')
        reward = data.get('reward')
        if self._reward_batch_norm:
            reward = (reward - reward.mean()) / (reward.std() + 1e-8)
        # current q value
        q_value = self._learn_model.forward(data, mode='compute_critic')
        q_value_dict = {}
        q_dist = q_value['distribution']
        q_value_dict['q_value'] = q_value['q_value'].mean()
        # target q value.
        with torch.no_grad():
            next_action = self._target_model.forward(next_obs, mode='compute_actor')['action']
            next_data = {'obs': next_obs, 'action': next_action}
            target_q_dist = self._target_model.forward(next_data, mode='compute_critic')['distribution']

        value_gamma = data.get('value_gamma')
        action_index = np.zeros(next_action.shape[0])
        # since the action is a scalar value, action index is set to 0 which is the only possible choice
        td_data = dist_nstep_td_data(
            q_dist, target_q_dist, action_index, action_index, reward, data['done'], data['weight']
        )
        critic_loss, td_error_per_sample = dist_nstep_td_error(
            td_data, self._gamma, self._v_min, self._v_max, self._n_atom, nstep=self._nstep, value_gamma=value_gamma
        )
        loss_dict['critic_loss'] = critic_loss
        # ================
        # critic update
        # ================
        self._optimizer_critic.zero_grad()
        for k in loss_dict:
            if 'critic' in k:
                loss_dict[k].backward()
        self._optimizer_critic.step()
        # ===============================
        # actor learn forward and update
        # ===============================
        # actor updates every ``self._actor_update_freq`` iters
        if (self._forward_learn_cnt + 1) % self._actor_update_freq == 0:
            actor_data = self._learn_model.forward(data['obs'], mode='compute_actor')
            actor_data['obs'] = data['obs']
            actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'].mean()

            loss_dict['actor_loss'] = actor_loss
            # actor update
            self._optimizer_actor.zero_grad()
            actor_loss.backward()
            self._optimizer_actor.step()
        # =============
        # after update
        # =============
        loss_dict['total_loss'] = sum(loss_dict.values())
        self._forward_learn_cnt += 1
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr_actor': self._optimizer_actor.defaults['lr'],
            'cur_lr_critic': self._optimizer_critic.defaults['lr'],
            'q_value': q_value['q_value'].mean().item(),
            'action': data['action'].mean().item(),
            'priority': td_error_per_sample.abs().tolist(),
            **loss_dict,
            **q_value_dict,
        }

    def _get_train_sample(self, traj: list) -> Union[None, List[Any]]:
        """
        Overview:
            Process the data of a given trajectory (transitions, a list of transition) into a list of sample that \
            can be used for training directly. The sample is generated by the following steps: \
            1. Calculate the nstep return data. \
            2. Sample the data from the nstep return data. \
            3. Stack the data in the batch dimension. \
            4. Return the sample data. \
            For D4PG, the nstep return data is generated by ``get_nstep_return_data`` and the sample data is \
            generated by ``get_train_sample``.

        Arguments:
            - traj (:obj:`list`): The trajectory data (a list of transition), each element is \
                the same format as the return value of ``self._process_transition`` method.

        Returns:
            - samples (:obj:`dict`): The training samples generated, including at least the following keys: \
                ``'obs'``, ``'next_obs'``, ``'action'``, ``'reward'``, ``'done'``, ``'weight'``, ``'value_gamma'``. \
                For more information, please refer to the ``get_train_sample`` method.
        """
        data = get_nstep_return_data(traj, self._nstep, gamma=self._gamma)
        return get_train_sample(data, self._unroll_len)

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
            as text logger, tensorboard logger, will use these keys to save the corresponding data.
        Returns:
            - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
        """
        ret = ['cur_lr_actor', 'cur_lr_critic', 'critic_loss', 'actor_loss', 'total_loss', 'q_value', 'action']
        return ret
```
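`_get_train_sample` relies on `get_nstep_return_data` to fold the next `nstep` rewards into a single discounted sum for the TD target. The core arithmetic, as a simplified standalone sketch (not DI-engine's utility, which additionally carries the `value_gamma = gamma ** nstep` factor used to discount the bootstrap term):

```python
def nstep_discounted_reward(rewards, gamma=0.99, nstep=3):
    """Discounted sum of the next `nstep` rewards:
    r_0 + gamma * r_1 + ... + gamma^(n-1) * r_{n-1}.
    Rewards beyond `nstep` are excluded; they are covered by the
    bootstrapped value term in the TD target."""
    return sum((gamma ** i) * r for i, r in enumerate(rewards[:nstep]))
```

With `gamma=0.5` and rewards `[1, 1, 1]`, for example, the 3-step return is `1 + 0.5 + 0.25 = 1.75`.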