ding.policy.bcq
BCQPolicy
Bases: Policy
Overview
Policy class of BCQ (Batch-Constrained deep Q-learning) algorithm, proposed in https://arxiv.org/abs/1812.02900.
default_model()
Overview
Returns the default model configuration used by the BCQ algorithm. The `__init__` method automatically calls this method to obtain the default model setting and create the model.
Returns:

| Type | Description |
|---|---|
| `Tuple[str, List[str]]` | Tuple containing the registered model name and the model's import_names. |
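The returned name/import-names pair is resolved into a concrete model class by DI-engine's model-creation machinery. A minimal, illustrative sketch of that resolution is below; the `resolve_model` helper is hypothetical and not DI-engine's actual API, and the usage example substitutes a stdlib module so the sketch is self-contained.

```python
import importlib


def resolve_model(model_info):
    """Illustrative only: import the modules named by a
    (model_name, import_names) pair, such as the
    ('bcq', ['ding.model.template.bcq']) returned by
    BCQPolicy.default_model()."""
    name, import_names = model_info
    # Import each listed module so the registered model name can be looked up.
    modules = [importlib.import_module(n) for n in import_names]
    return name, modules


# Demonstration with a stdlib module instead of ding:
name, modules = resolve_model(('json', ['json']))
```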
Full Source Code
../ding/policy/bcq.py
```python
import copy
from collections import namedtuple
from typing import List, Dict, Any, Tuple

import torch
import torch.nn.functional as F

from ding.model import model_wrap
from ding.policy import Policy
from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, get_nstep_return_data
from ding.torch_utils import Adam, to_device
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('bcq')
class BCQPolicy(Policy):
    """
    Overview:
        Policy class of BCQ (Batch-Constrained deep Q-learning) algorithm, proposed in \
        https://arxiv.org/abs/1812.02900.
    """

    config = dict(
        # (str) Name of the registered RL policy (refer to the "register_policy" function).
        type='bcq',
        # (bool) Indicates if CUDA should be used for network operations.
        cuda=False,
        # (bool) Determines whether priority sampling is used in the replay buffer. Default is False.
        priority=False,
        # (bool) If True, Importance Sampling Weight is used to correct updates. Requires 'priority' to be True.
        priority_IS_weight=False,
        # (int) Number of random samples in replay buffer before training begins. Default is 10000.
        random_collect_size=10000,
        # (int) The number of steps for calculating target q_value.
        nstep=1,
        model=dict(
            # (List[int]) Sizes of the hidden layers in the actor network.
            actor_head_hidden_size=[400, 300],
            # (List[int]) Sizes of the hidden layers in the critic network.
            critic_head_hidden_size=[400, 300],
            # (float) Maximum perturbation for BCQ. Controls exploration in action space.
            phi=0.05,
        ),
        learn=dict(
            # (int) Number of policy updates per data collection step. Higher values indicate more off-policy training.
            update_per_collect=1,
            # (int) Batch size for each gradient descent step.
            batch_size=100,
            # (float) Learning rate for the Q-network. Set to 1e-3 if `model.value_network` is True.
            learning_rate_q=3e-4,
            # (float) Learning rate for the policy network. Set to 1e-3 if `model.value_network` is True.
            learning_rate_policy=3e-4,
            # (float) Learning rate for the VAE network. Initialize if `model.vae_network` is True.
            learning_rate_vae=3e-4,
            # (bool) If set to True, the 'done' signals that indicate the end of an episode due to environment time
            # limits are disregarded. By default, this is set to False. This setting is particularly useful for tasks
            # that have a predetermined episode length, such as HalfCheetah and various other MuJoCo environments,
            # where the maximum length is capped at 1000 steps. When enabled, any 'done' signal triggered by reaching
            # the maximum episode steps will be overridden to 'False'. This ensures the accurate calculation of the
            # Temporal Difference (TD) error, using the formula `gamma * (1 - done) * next_v + reward`,
            # even when the episode surpasses the predefined step limit.
            ignore_done=False,
            # (float) Polyak averaging coefficient for the target network update. Typically small.
            target_theta=0.005,
            # (float) Discount factor for future rewards, often denoted as gamma.
            discount_factor=0.99,
            # (float) Lambda for TD(lambda) learning. Weighs the trade-off between bias and variance.
            lmbda=0.75,
            # (float) Range for uniform weight initialization in the output layer.
            init_w=3e-3,
        ),
        collect=dict(
            # (int) Length of trajectory segments for unrolling. Set to higher for longer dependencies.
            unroll_len=1,
        ),
        eval=dict(),
        other=dict(
            replay_buffer=dict(
                # (int) Maximum size of the replay buffer.
                replay_buffer_size=1000000,
            ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Returns the default model configuration used by the BCQ algorithm. ``__init__`` method will \
            automatically call this method to get the default model setting and create model.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): \
                Tuple containing the registered model name and model's import_names.
        """
        return 'bcq', ['ding.model.template.bcq']

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of policy, including related attributes and modules. For BCQ, it mainly \
            contains optimizer, algorithm-specific arguments such as gamma, main and target model. \
            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
        """
        # Init
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        self.lmbda = self._cfg.learn.lmbda
        self.latent_dim = self._cfg.model.action_shape * 2

        # Optimizers
        self._optimizer_q = Adam(
            self._model.critic.parameters(),
            lr=self._cfg.learn.learning_rate_q,
        )
        self._optimizer_policy = Adam(
            self._model.actor.parameters(),
            lr=self._cfg.learn.learning_rate_policy,
        )
        self._optimizer_vae = Adam(
            self._model.vae.parameters(),
            lr=self._cfg.learn.learning_rate_vae,
        )

        # Algorithm config
        self._gamma = self._cfg.learn.discount_factor

        # Main and target models
        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_theta}
        )
        self._learn_model = model_wrap(self._model, wrapper_name='base')
        self._learn_model.reset()
        self._target_model.reset()
        self._forward_learn_cnt = 0

    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode (training policy and updating parameters). Forward means \
            that the policy inputs some training batch data from the replay buffer and then returns the output \
            result, including various training information such as policy_loss, value_loss, entropy_loss.
        Arguments:
            - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
                training samples. For each element in the list, the key of the dict is the name of data items and \
                the value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their \
                dict/list combinations. In the ``_forward_learn`` method, data often needs to be stacked in the \
                batch dimension first by some utility functions such as ``default_preprocess_learn``. \
                For BCQ, each element in the list is a dict containing at least the following keys: \
                ['obs', 'action', 'adv', 'value', 'weight'].
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which \
                will be recorded in text log and tensorboard, values must be python scalar or a list of scalars. \
                For the detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
            For data types that are not supported, the main reason is that the corresponding model does not support \
            them. You can implement your own model rather than use the default model. For more information, please \
            raise an issue in GitHub repo and we will continue to follow up.
        """
        loss_dict = {}
        # Data preprocessing operations, such as stacking data and moving it from cpu to cuda device
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False
        )
        if len(data.get('action').shape) == 1:
            data['action'] = data['action'].reshape(-1, 1)

        if self._cuda:
            data = to_device(data, self._device)

        self._learn_model.train()
        self._target_model.train()
        obs = data['obs']
        next_obs = data['next_obs']
        reward = data['reward']
        done = data['done']
        batch_size = obs.shape[0]

        # train_vae
        vae_out = self._model.forward(data, mode='compute_vae')
        recon, mean, log_std = vae_out['recons_action'], vae_out['mu'], vae_out['log_var']
        recons_loss = F.mse_loss(recon, data['action'])
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_std - mean ** 2 - log_std.exp(), dim=1), dim=0)
        loss_dict['recons_loss'] = recons_loss
        loss_dict['kld_loss'] = kld_loss
        vae_loss = recons_loss + 0.5 * kld_loss
        loss_dict['vae_loss'] = vae_loss
        self._optimizer_vae.zero_grad()
        vae_loss.backward()
        self._optimizer_vae.step()

        # train_critic
        q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']

        with torch.no_grad():
            next_obs_rep = torch.repeat_interleave(next_obs, 10, 0)
            z = torch.randn((next_obs_rep.shape[0], self.latent_dim)).to(self._device).clamp(-0.5, 0.5)
            vae_action = self._model.vae.decode_with_obs(z, next_obs_rep)['reconstruction_action']
            next_action = self._target_model.forward({
                'obs': next_obs_rep,
                'action': vae_action
            }, mode='compute_actor')['action']

            next_data = {'obs': next_obs_rep, 'action': next_action}
            target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value']
            # Clipped double-Q: combine the twin critics with a convex min/max
            # mixture to obtain the target q value
            target_q_value = self.lmbda * torch.min(target_q_value[0], target_q_value[1]) \
                + (1 - self.lmbda) * torch.max(target_q_value[0], target_q_value[1])
            target_q_value = target_q_value.reshape(batch_size, -1).max(1)[0].reshape(-1, 1)

        q_data0 = v_1step_td_data(q_value[0], target_q_value, reward, done, data['weight'])
        loss_dict['critic_loss'], td_error_per_sample0 = v_1step_td_error(q_data0, self._gamma)
        q_data1 = v_1step_td_data(q_value[1], target_q_value, reward, done, data['weight'])
        loss_dict['twin_critic_loss'], td_error_per_sample1 = v_1step_td_error(q_data1, self._gamma)
        td_error_per_sample = (td_error_per_sample0 + td_error_per_sample1) / 2

        self._optimizer_q.zero_grad()
        (loss_dict['critic_loss'] + loss_dict['twin_critic_loss']).backward()
        self._optimizer_q.step()

        # train_policy
        z = torch.randn((obs.shape[0], self.latent_dim)).to(self._device).clamp(-0.5, 0.5)
        sample_action = self._model.vae.decode_with_obs(z, obs)['reconstruction_action']
        input = {'obs': obs, 'action': sample_action}
        perturbed_action = self._model.forward(input, mode='compute_actor')['action']
        q_input = {'obs': obs, 'action': perturbed_action}
        q = self._learn_model.forward(q_input, mode='compute_critic')['q_value'][0]
        loss_dict['actor_loss'] = -q.mean()
        self._optimizer_policy.zero_grad()
        loss_dict['actor_loss'].backward()
        self._optimizer_policy.step()
        self._forward_learn_cnt += 1
        self._target_model.update(self._learn_model.state_dict())
        return {
            'td_error': td_error_per_sample.detach().mean().item(),
            'target_q_value': target_q_value.detach().mean().item(),
            **loss_dict
        }

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
            as text logger, tensorboard logger, will use these keys to save the corresponding data.
        Returns:
            - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
        """
        return [
            'td_error', 'target_q_value', 'critic_loss', 'twin_critic_loss', 'actor_loss', 'recons_loss', 'kld_loss',
            'vae_loss'
        ]

    def _state_dict_learn(self) -> Dict[str, Any]:
        """
        Overview:
            Return the state_dict of learn mode, usually including model and optimizer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring.
        """
        ret = {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer_q': self._optimizer_q.state_dict(),
            'optimizer_policy': self._optimizer_policy.state_dict(),
            'optimizer_vae': self._optimizer_vae.state_dict(),
        }
        return ret

    def _init_eval(self) -> None:
        """
        Overview:
            Initialize the eval mode of policy, including related attributes and modules.
            This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.

        .. note::
            If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
            with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='base')
        self._eval_model.reset()

    def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
        """
        Overview:
            Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \
            means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
            action to interact with the envs.
        Arguments:
            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
                key of the dict is environment id and the value is the corresponding data of the env.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
                key of the dict is the same as the input data, i.e., environment id.

        .. note::
            The input value can be ``torch.Tensor`` or dict/list combinations, current policy supports all of them. \
            For the data type that is not supported, the main reason is that the corresponding model does not \
            support it. You can implement your own model rather than use the default model. For more information, \
            please raise an issue in GitHub repo, and we will continue to follow up.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        data = {'obs': data}
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data, mode='compute_eval')
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _init_collect(self) -> None:
        """
        Overview:
            Initialize the collect mode of policy, including related attributes and modules. For BCQ, it contains the \
            collect_model to balance the exploration and exploitation with ``eps_greedy_sample`` \
            mechanism, and other algorithm-specific arguments such as gamma and nstep.
            This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.

        .. note::
            If you want to set some special member variables in ``_init_collect`` method, you'd better name them \
            with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        self._gamma = self._cfg.discount_factor
        self._nstep = self._cfg.nstep
        self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_sample')
        self._collect_model.reset()

    def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        pass

    def _forward_collect(self, data: dict, **kwargs) -> dict:
        pass

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        pass
```