
ding.policy.td3_bc

TD3BCPolicy

Bases: DDPGPolicy

Overview

Policy class of the TD3_BC algorithm.

Since DDPG and TD3 share many components, this TD3_BC class can easily be derived from the DDPG class by changing _actor_update_freq, _twin_critic, and the noise in the model wrapper; on top of TD3, it adds a behavior-cloning regularization term to the actor loss.

https://arxiv.org/pdf/2106.06860.pdf
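The core change TD3_BC makes to the TD3 actor update (visible in _forward_learn in the source below) is a behavior-cloning regularizer with an adaptively normalized Q weight. A minimal NumPy sketch of that objective, with hypothetical function and argument names (the actual implementation works on torch tensors inside the policy):

```python
import numpy as np

def td3_bc_actor_loss(q_values, policy_actions, dataset_actions, alpha=2.5):
    """TD3_BC actor objective: lambda * (-Q) + behavior-cloning MSE.

    lambda = alpha / mean(|Q|) rescales the Q term so the BC term stays
    on a comparable scale regardless of the magnitude of Q estimates.
    """
    lmbda = alpha / np.abs(q_values).mean()
    rl_term = -q_values.mean()  # deterministic policy gradient surrogate
    bc_term = ((policy_actions - dataset_actions) ** 2).mean()  # MSE to dataset actions
    return lmbda * rl_term + bc_term
```

With alpha = 2.5 (the default in the config below) and policy actions that exactly match the dataset, the loss reduces to the pure scaled RL term.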

Property

learn_mode, collect_mode, eval_mode

Config:

== ==================== ======== ================== ================================= =======================
ID Symbol               Type     Default Value      Description                       Other(Shape)
== ==================== ======== ================== ================================= =======================
1  type                 str      td3_bc             | RL policy register name, refer  | this arg is optional,
                                                    | to registry POLICY_REGISTRY     | a placeholder
2  cuda                 bool     True               | Whether to use cuda for network |
3  random_collect_size  int      25000              | Number of randomly collected    | Default to 25000 for
                                                    | training samples in replay      | DDPG/TD3, 10000 for
                                                    | buffer when training starts.    | SAC.
4  model.twin_critic    bool     True               | Whether to use two critic       | Default True for TD3,
                                                    | networks or only one.           | Clipped Double
                                                    |                                 | Q-learning method in
                                                    |                                 | TD3 paper.
5  learn.learning_      float    1e-3               | Learning rate for actor         |
   rate_actor                                       | network (aka. policy).          |
6  learn.learning_      float    1e-3               | Learning rate for critic        |
   rate_critic                                      | network (aka. Q-network).       |
7  learn.actor_         int      2                  | When critic network updates     | Default 2 for TD3, 1
   update_freq                                      | once, how many times will actor | for DDPG. Delayed
                                                    | network update.                 | Policy Updates method
                                                    |                                 | in TD3 paper.
8  learn.noise          bool     True               | Whether to add noise on target  | Default True for TD3,
                                                    | network's action.               | False for DDPG.
                                                    |                                 | Target Policy
                                                    |                                 | Smoothing
                                                    |                                 | Regularization in
                                                    |                                 | TD3 paper.
9  learn.noise_range    dict     dict(min=-0.5,     | Limit for range of target       |
                                 max=0.5,)          | policy smoothing noise,         |
                                                    | aka. noise_clip.                |
10 learn.ignore_done    bool     False              | Determine whether to ignore     | Use ignore_done only
                                                    | done flag.                      | in halfcheetah env.
11 learn.target_theta   float    0.005              | Used for soft update of the     | aka. Interpolation
                                                    | target network.                 | factor in polyak
                                                    |                                 | averaging for target
                                                    |                                 | networks.
12 collect.noise_sigma  float    0.1                | Used for adding noise during    | Sample noise from
                                                    | collection, through controlling | distribution, Ornstein-
                                                    | the sigma of distribution.      | Uhlenbeck process in
                                                    |                                 | DDPG paper, Gaussian
                                                    |                                 | process in ours.
== ==================== ======== ================== ================================= =======================
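The defaults in the table above can be overridden in DI-engine's usual nested-dict config style. A sketch of such a fragment using plain dicts (field names taken from the table; the actual merging with the policy's default config is handled by DI-engine's config pipeline, not shown here):

```python
# Hypothetical user config fragment for TD3_BC; keys mirror the table above.
td3_bc_config = dict(
    type='td3_bc',
    cuda=True,
    random_collect_size=25000,
    model=dict(twin_critic=True),
    learn=dict(
        learning_rate_actor=1e-3,
        learning_rate_critic=1e-3,
        actor_update_freq=2,     # delayed policy updates: critic updates twice per actor update
        noise=True,              # target policy smoothing
        noise_range=dict(min=-0.5, max=0.5),
        ignore_done=False,
        target_theta=0.005,
        alpha=2.5,               # behavior-cloning regularization weight
    ),
    collect=dict(noise_sigma=0.1),
)
```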

Full Source Code

../ding/policy/td3_bc.py

from typing import List, Dict, Any, Tuple, Union
from easydict import EasyDict
from collections import namedtuple
import torch
import torch.nn.functional as F
import copy

from ding.torch_utils import Adam, to_device
from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy
from .common_utils import default_preprocess_learn
from .ddpg import DDPGPolicy


@POLICY_REGISTRY.register('td3_bc')
class TD3BCPolicy(DDPGPolicy):
    r"""
    Overview:
        Policy class of TD3_BC algorithm.

        Since DDPG and TD3 share many common things, we can easily derive this TD3_BC
        class from DDPG class by changing ``_actor_update_freq``, ``_twin_critic`` and noise in model wrapper.

        https://arxiv.org/pdf/2106.06860.pdf

    Property:
        learn_mode, collect_mode, eval_mode

    Config:

        == ==================== ======== ================== ================================= =======================
        ID Symbol               Type     Default Value      Description                       Other(Shape)
        == ==================== ======== ================== ================================= =======================
        1  ``type``             str      td3_bc             | RL policy register name, refer  | this arg is optional,
                                                            | to registry ``POLICY_REGISTRY`` | a placeholder
        2  ``cuda``             bool     True               | Whether to use cuda for network |
        3  | ``random_``        int      25000              | Number of randomly collected    | Default to 25000 for
           | ``collect_size``                               | training samples in replay      | DDPG/TD3, 10000 for
                                                            | buffer when training starts.    | SAC.
        4  | ``model.twin_``    bool     True               | Whether to use two critic       | Default True for TD3,
           | ``critic``                                     | networks or only one.           | Clipped Double
                                                            |                                 | Q-learning method in
                                                            |                                 | TD3 paper.
        5  | ``learn.learning`` float    1e-3               | Learning rate for actor         |
           | ``_rate_actor``                                | network (aka. policy).          |
        6  | ``learn.learning`` float    1e-3               | Learning rate for critic        |
           | ``_rate_critic``                               | network (aka. Q-network).       |
        7  | ``learn.actor_``   int      2                  | When critic network updates     | Default 2 for TD3, 1
           | ``update_freq``                                | once, how many times will actor | for DDPG. Delayed
                                                            | network update.                 | Policy Updates method
                                                            |                                 | in TD3 paper.
        8  ``learn.noise``      bool     True               | Whether to add noise on target  | Default True for TD3,
                                                            | network's action.               | False for DDPG.
                                                            |                                 | Target Policy
                                                            |                                 | Smoothing
                                                            |                                 | Regularization in
                                                            |                                 | TD3 paper.
        9  | ``learn.noise_``   dict     | dict(min=-0.5,   | Limit for range of target       |
           | ``range``                   | max=0.5,)        | policy smoothing noise,         |
                                                            | aka. noise_clip.                |
        10 | ``learn.``         bool     False              | Determine whether to ignore     | Use ignore_done only
           | ``ignore_done``                                | done flag.                      | in halfcheetah env.
        11 | ``learn.``         float    0.005              | Used for soft update of the     | aka. Interpolation
           | ``target_theta``                               | target network.                 | factor in polyak
                                                            |                                 | averaging for target
                                                            |                                 | networks.
        12 | ``collect.``       float    0.1                | Used for adding noise during    | Sample noise from
           | ``noise_sigma``                                | collection, through controlling | distribution, Ornstein-
                                                            | the sigma of distribution.      | Uhlenbeck process in
                                                            |                                 | DDPG paper, Gaussian
                                                            |                                 | process in ours.
        == ==================== ======== ================== ================================= =======================
    """

    # You can refer to DDPG's default config for more details.
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='td3_bc',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool) on_policy: Determine whether on-policy or off-policy.
        # on-policy setting influences the behaviour of buffer.
        # Default False in TD3.
        on_policy=False,
        # (bool) Whether to use priority (priority sample, IS weight, update priority).
        # Default False in TD3.
        priority=False,
        # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        # (int) Number of training samples (randomly collected) in replay buffer when training starts.
        # Default 25000 in DDPG/TD3.
        random_collect_size=25000,
        # (bool) Whether to use batch normalization for reward.
        reward_batch_norm=False,
        action_space='continuous',
        model=dict(
            # (bool) Whether to use two critic networks or only one.
            # Clipped Double Q-Learning for Actor-Critic in original TD3 paper (https://arxiv.org/pdf/1802.09477.pdf).
            # Default True for TD3, False for DDPG.
            twin_critic=True,
            # (str) action_space: Use regression trick for continuous action.
            action_space='regression',
            # (int) Hidden size for actor network head.
            actor_head_hidden_size=256,
            # (int) Hidden size for critic network head.
            critic_head_hidden_size=256,
        ),
        learn=dict(
            # (int) How many updates (iterations) to train after collector's one collection.
            # Bigger "update_per_collect" means bigger off-policy.
            # collect data -> update policy -> collect data -> ...
            update_per_collect=1,
            # (int) Minibatch size for gradient descent.
            batch_size=256,
            # (float) Learning rate for actor network (aka. policy).
            learning_rate_actor=1e-3,
            # (float) Learning rate for critic network (aka. Q-network).
            learning_rate_critic=1e-3,
            # (bool) Whether to ignore done (usually for max-step-termination envs, e.g. pendulum).
            # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
            # These limit HalfCheetah, and several other MuJoCo envs, to a max length of 1000.
            # However, a HalfCheetah episode never terminates on its own, so when the episode
            # step exceeds the max episode step we replace done==True with done==False to keep
            # the TD-error computation accurate (``gamma * (1 - done) * next_v + reward``).
            ignore_done=False,
            # (float) target_theta: Used for soft update of the target network,
            # aka. Interpolation factor in polyak averaging for target networks.
            # Default to 0.005.
            target_theta=0.005,
            # (float) Discount factor for the discounted sum of rewards, aka. gamma.
            discount_factor=0.99,
            # (int) When critic network updates once, how many times will actor network update.
            # Delayed Policy Updates in original TD3 paper (https://arxiv.org/pdf/1802.09477.pdf).
            # Default 1 for DDPG, 2 for TD3.
            actor_update_freq=2,
            # (bool) Whether to add noise on target network's action.
            # Target Policy Smoothing Regularization in original TD3 paper (https://arxiv.org/pdf/1802.09477.pdf).
            # Default True for TD3, False for DDPG.
            noise=True,
            # (float) Sigma for smoothing noise added to target policy.
            noise_sigma=0.2,
            # (dict) Limit for range of target policy smoothing noise, aka. noise_clip.
            noise_range=dict(
                min=-0.5,
                max=0.5,
            ),
            # (float) Weight of the Q term relative to the behavior-cloning term in TD3_BC.
            alpha=2.5,
        ),
        collect=dict(
            # (int) Cut trajectories into pieces with length "unroll_len".
            unroll_len=1,
            # (float) It is a must to add noise during collection. So here omits "noise" and only sets "noise_sigma".
            noise_sigma=0.1,
        ),
        eval=dict(
            evaluator=dict(
                # (int) Evaluate every "eval_freq" training iterations.
                eval_freq=5000,
            ),
        ),
        other=dict(
            replay_buffer=dict(
                # (int) Maximum size of replay buffer.
                replay_buffer_size=1000000,
            ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        return 'continuous_qac', ['ding.model.template.qac']

    def _init_learn(self) -> None:
        """
        Overview:
            Learn mode init method. Called by ``self.__init__``. Init actor and critic optimizers, algorithm config.
        """
        super(TD3BCPolicy, self)._init_learn()
        self._alpha = self._cfg.learn.alpha
        # actor and critic optimizer
        self._optimizer_actor = Adam(
            self._model.actor.parameters(),
            lr=self._cfg.learn.learning_rate_actor,
            grad_clip_type='clip_norm',
            clip_value=1.0,
        )
        self._optimizer_critic = Adam(
            self._model.critic.parameters(),
            lr=self._cfg.learn.learning_rate_critic,
            grad_clip_type='clip_norm',
            clip_value=1.0,
        )

        self.noise_sigma = self._cfg.learn.noise_sigma
        self.noise_range = self._cfg.learn.noise_range

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Forward and backward function of learn mode.
        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Including at least actor and critic lr, different losses.
        """
        loss_dict = {}
        data = default_preprocess_learn(
            data,
            use_priority=self._cfg.priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False
        )
        if self._cuda:
            data = to_device(data, self._device)
        # ====================
        # critic learn forward
        # ====================
        self._learn_model.train()
        self._target_model.train()
        next_obs = data['next_obs']
        reward = data['reward']
        if self._reward_batch_norm:
            reward = (reward - reward.mean()) / (reward.std() + 1e-8)
        # current q value
        q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']
        q_value_dict = {}
        if self._twin_critic:
            q_value_dict['q_value'] = q_value[0].mean()
            q_value_dict['q_value_twin'] = q_value[1].mean()
        else:
            q_value_dict['q_value'] = q_value.mean()
        # target q value
        with torch.no_grad():
            next_action = self._target_model.forward(next_obs, mode='compute_actor')['action']
            noise = (torch.randn_like(next_action) *
                     self.noise_sigma).clamp(self.noise_range['min'], self.noise_range['max'])
            next_action = (next_action + noise).clamp(-1, 1)
            next_data = {'obs': next_obs, 'action': next_action}
            target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value']
        if self._twin_critic:
            # TD3: two critic networks
            target_q_value = torch.min(target_q_value[0], target_q_value[1])  # find min one as target q value
            # critic network1
            td_data = v_1step_td_data(q_value[0], target_q_value, reward, data['done'], data['weight'])
            critic_loss, td_error_per_sample1 = v_1step_td_error(td_data, self._gamma)
            loss_dict['critic_loss'] = critic_loss
            # critic network2 (twin network)
            td_data_twin = v_1step_td_data(q_value[1], target_q_value, reward, data['done'], data['weight'])
            critic_twin_loss, td_error_per_sample2 = v_1step_td_error(td_data_twin, self._gamma)
            loss_dict['critic_twin_loss'] = critic_twin_loss
            td_error_per_sample = (td_error_per_sample1 + td_error_per_sample2) / 2
        else:
            # DDPG: single critic network
            td_data = v_1step_td_data(q_value, target_q_value, reward, data['done'], data['weight'])
            critic_loss, td_error_per_sample = v_1step_td_error(td_data, self._gamma)
            loss_dict['critic_loss'] = critic_loss
        # ================
        # critic update
        # ================
        self._optimizer_critic.zero_grad()
        for k in loss_dict:
            if 'critic' in k:
                loss_dict[k].backward()
        self._optimizer_critic.step()
        # ===============================
        # actor learn forward and update
        # ===============================
        # actor updates every ``self._actor_update_freq`` iters
        if (self._forward_learn_cnt + 1) % self._actor_update_freq == 0:
            actor_data = self._learn_model.forward(data['obs'], mode='compute_actor')
            actor_data['obs'] = data['obs']
            if self._twin_critic:
                q_value = self._learn_model.forward(actor_data, mode='compute_critic')['q_value'][0]
                actor_loss = -q_value.mean()
            else:
                q_value = self._learn_model.forward(actor_data, mode='compute_critic')['q_value']
                actor_loss = -q_value.mean()

            # add behavior cloning loss weight (\lambda)
            lmbda = self._alpha / q_value.abs().mean().detach()
            # bc_loss = ((actor_data['action'] - data['action']) ** 2).mean()
            bc_loss = F.mse_loss(actor_data['action'], data['action'])
            actor_loss = lmbda * actor_loss + bc_loss
            loss_dict['actor_loss'] = actor_loss
            # actor update
            self._optimizer_actor.zero_grad()
            actor_loss.backward()
            self._optimizer_actor.step()
        # =============
        # after update
        # =============
        loss_dict['total_loss'] = sum(loss_dict.values())
        self._forward_learn_cnt += 1
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr_actor': self._optimizer_actor.defaults['lr'],
            'cur_lr_critic': self._optimizer_critic.defaults['lr'],
            # 'q_value': np.array(q_value).mean(),
            'action': data.get('action').mean(),
            'priority': td_error_per_sample.abs().tolist(),
            'td_error': td_error_per_sample.abs().mean(),
            **loss_dict,
            **q_value_dict,
        }

    def _forward_eval(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of eval mode, similar to ``self._forward_collect``.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output (action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
        ReturnsKeys
            - necessary: ``action``
            - optional: ``logit``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data, mode='compute_actor')
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}
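To make the target computation in _forward_learn concrete, here is a NumPy sketch (hypothetical helper names, not part of the DI-engine API) of the two TD3 mechanisms it applies when building the critic target: Target Policy Smoothing and Clipped Double Q-learning:

```python
import numpy as np

def smooth_target_action(next_action, rng, noise_sigma=0.2, noise_clip=0.5):
    # Target Policy Smoothing: add clipped Gaussian noise to the target
    # policy's action, then clamp back into the action bounds (assumed [-1, 1],
    # as in the torch code: (next_action + noise).clamp(-1, 1)).
    noise = np.clip(rng.normal(0.0, noise_sigma, size=next_action.shape),
                    -noise_clip, noise_clip)
    return np.clip(next_action + noise, -1.0, 1.0)

def td_target(reward, done, q1_next, q2_next, gamma=0.99):
    # Clipped Double Q-learning: bootstrap from the element-wise min of the
    # twin target critics (evaluated at the smoothed action), masked by done.
    return reward + gamma * (1.0 - done) * np.minimum(q1_next, q2_next)
```

In the actual policy the smoothed action is fed back through the target critics to produce q1_next/q2_next, and the one-step TD error is then computed by v_1step_td_error.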