# ding.policy.a2c

## A2CPolicy
Bases: Policy
Overview
Policy class of A2C (Advantage Actor-Critic) algorithm, proposed in https://arxiv.org/abs/1602.01783.
### default_model()
Overview
Returns the default model configuration used by the A2C algorithm. __init__ method will automatically call this method to get the default model setting and create model.
Returns:

| Type | Description |
|---|---|
| `Tuple[str, List[str]]` | Tuple containing the registered model name (`'vac'`) and the model's import names. |
Full Source Code
../ding/policy/a2c.py
from collections import namedtuple
from typing import List, Dict, Any, Tuple

import torch

from ding.model import model_wrap
from ding.rl_utils import a2c_data, a2c_error, get_gae_with_default_last_value, get_train_sample, \
    a2c_error_continuous
from ding.torch_utils import Adam, to_device
from ding.utils import POLICY_REGISTRY, split_data_generator
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy
from .common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('a2c')
class A2CPolicy(Policy):
    """
    Overview:
        Policy class of A2C (Advantage Actor-Critic) algorithm, proposed in https://arxiv.org/abs/1602.01783.
    """
    config = dict(
        # (str) Name of the registered RL policy (refer to the "register_policy" function).
        type='a2c',
        # (bool) Flag to enable CUDA for model computation.
        cuda=False,
        # (bool) Flag for using on-policy training (training policy is the same as the behavior policy).
        on_policy=True,
        # (bool) Flag for enabling priority experience replay. Must be False when priority_IS_weight is False.
        priority=False,
        # (bool) Flag for using Importance Sampling weights to correct updates. Requires `priority` to be True.
        priority_IS_weight=False,
        # (str) Type of action space used in the policy, with valid options ['discrete', 'continuous'].
        action_space='discrete',
        # learn_mode configuration
        learn=dict(
            # (int) Number of updates per data collection. A2C requires this to be set to 1.
            update_per_collect=1,
            # (int) Batch size for learning.
            batch_size=64,
            # (float) Learning rate for optimizer.
            learning_rate=0.001,
            # (Tuple[float, float]) Coefficients used for computing running averages of gradient and its square.
            betas=(0.9, 0.999),
            # (float) Term added to the denominator to improve numerical stability in optimizer.
            eps=1e-8,
            # (float) Maximum norm for gradients.
            grad_norm=0.5,
            # (float) Scaling factor for value network loss relative to policy network loss.
            value_weight=0.5,
            # (float) Weight of entropy regularization in the loss function.
            entropy_weight=0.01,
            # (bool) Flag to enable normalization of advantages.
            adv_norm=False,
            # (bool) If set to True, the 'done' signals that indicate the end of an episode due to environment time
            # limits are disregarded. By default, this is set to False. This setting is particularly useful for tasks
            # that have a predetermined episode length, such as HalfCheetah and various other MuJoCo environments,
            # where the maximum length is capped at 1000 steps. When enabled, any 'done' signal triggered by reaching
            # the maximum episode steps will be overridden to 'False'. This ensures the accurate calculation of the
            # Temporal Difference (TD) error, using the formula `gamma * (1 - done) * next_v + reward`,
            # even when the episode surpasses the predefined step limit.
            ignore_done=False,
        ),
        # collect_mode configuration
        collect=dict(
            # (int) The length of rollout for data collection.
            unroll_len=1,
            # (float) Discount factor for calculating future rewards, typically in the range [0, 1].
            discount_factor=0.9,
            # (float) Trade-off parameter for balancing TD-error and Monte Carlo error in GAE.
            gae_lambda=0.95,
        ),
        # eval_mode configuration (kept empty for compatibility purposes)
        eval=dict(),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Returns the default model configuration used by the A2C algorithm. ``__init__`` method will \
            automatically call this method to get the default model setting and create model.

        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): \
                Tuple containing the registered model name and model's import_names.
        """
        return 'vac', ['ding.model.template.vac']

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of policy, including related attributes and modules. For A2C, it mainly \
            contains optimizer, algorithm-specific arguments such as value_weight, entropy_weight, adv_norm \
            and grad_norm, and main model. \
            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \
            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
        """
        assert self._cfg.action_space in ["continuous", "discrete"]
        # BUGFIX: record the action space here as well. ``_forward_learn`` branches on
        # ``self._action_space``, but previously it was only assigned in ``_init_collect`` and
        # ``_init_eval``; a policy created with only ``learn`` in ``enable_field`` therefore raised
        # AttributeError on the first training step.
        self._action_space = self._cfg.action_space
        # Optimizer
        self._optimizer = Adam(
            self._model.parameters(),
            lr=self._cfg.learn.learning_rate,
            betas=self._cfg.learn.betas,
            eps=self._cfg.learn.eps
        )

        # Algorithm config
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        self._value_weight = self._cfg.learn.value_weight
        self._entropy_weight = self._cfg.learn.entropy_weight
        self._adv_norm = self._cfg.learn.adv_norm
        self._grad_norm = self._cfg.learn.grad_norm

        # Main and target models
        self._learn_model = model_wrap(self._model, wrapper_name='base')
        self._learn_model.reset()

    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode (training policy and updating parameters). Forward means \
            that the policy inputs some training batch data from the replay buffer and then returns the output \
            result, including various training information such as policy_loss, value_loss, entropy_loss.
        Arguments:
            - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
                training samples. For each element in the list, the key of the dict is the name of data items and the \
                value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \
                combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \
                dimension by some utility functions such as ``default_preprocess_learn``. \
                For A2C, each element in the list is a dict containing at least the following keys: \
                ['obs', 'action', 'adv', 'value', 'weight'].
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \
                recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
                detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
            For the data type that is not supported, the main reason is that the corresponding model does not support \
            it. You can implement your own model rather than use the default model. For more information, please \
            raise an issue in GitHub repo, and we will continue to follow up.
        """
        # Data preprocessing operations, such as stack data, cpu to cuda device
        data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=False)
        if self._cuda:
            data = to_device(data, self._device)
        self._learn_model.train()

        for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
            # forward
            output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic')

            adv = batch['adv']
            # The stored 'value' is the baseline, so the return target is baseline + advantage.
            return_ = batch['value'] + adv
            if self._adv_norm:
                # norm adv in total train_batch
                adv = (adv - adv.mean()) / (adv.std() + 1e-8)
            error_data = a2c_data(output['logit'], batch['action'], output['value'], adv, return_, batch['weight'])

            # Calculate A2C loss
            if self._action_space == 'continuous':
                a2c_loss = a2c_error_continuous(error_data)
            elif self._action_space == 'discrete':
                a2c_loss = a2c_error(error_data)

            wv, we = self._value_weight, self._entropy_weight
            # Entropy is subtracted: maximizing entropy encourages exploration.
            total_loss = a2c_loss.policy_loss + wv * a2c_loss.value_loss - we * a2c_loss.entropy_loss

            # ====================
            # A2C-learning update
            # ====================
            self._optimizer.zero_grad()
            total_loss.backward()

            grad_norm = torch.nn.utils.clip_grad_norm_(
                list(self._learn_model.parameters()),
                max_norm=self._grad_norm,
            )
            self._optimizer.step()

        # =============
        # after update
        # =============
        # only record last updates information in logger
        return {
            'cur_lr': self._optimizer.param_groups[0]['lr'],
            'total_loss': total_loss.item(),
            'policy_loss': a2c_loss.policy_loss.item(),
            'value_loss': a2c_loss.value_loss.item(),
            'entropy_loss': a2c_loss.entropy_loss.item(),
            'adv_abs_max': adv.abs().max().item(),
            # ``clip_grad_norm_`` returns a Tensor; convert it to a python scalar so every value in
            # this dict obeys the logging contract stated in the docstring above.
            'grad_norm': grad_norm.item(),
        }

    def _state_dict_learn(self) -> Dict[str, Any]:
        """
        Overview:
            Return the state_dict of learn mode, usually including model and optimizer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring.
        """
        return {
            'model': self._learn_model.state_dict(),
            'optimizer': self._optimizer.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load the state_dict variable into policy learn mode.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.

        .. tip::
            If you want to only load some parts of model, you can simply set the ``strict`` argument in \
            load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
            complicated operation.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._optimizer.load_state_dict(state_dict['optimizer'])

    def _init_collect(self) -> None:
        """
        Overview:
            Initialize the collect mode of policy, including related attributes and modules. For A2C, it contains the \
            collect_model to balance the exploration and exploitation with ``reparam_sample`` or \
            ``multinomial_sample`` mechanism, and other algorithm-specific arguments such as gamma and gae_lambda. \
            This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.

        .. note::
            If you want to set some spacial member variables in ``_init_collect`` method, you'd better name them \
            with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
        """
        assert self._cfg.action_space in ["continuous", "discrete"]
        self._unroll_len = self._cfg.collect.unroll_len

        self._action_space = self._cfg.action_space
        if self._action_space == 'continuous':
            self._collect_model = model_wrap(self._model, wrapper_name='reparam_sample')
        elif self._action_space == 'discrete':
            self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample')
        self._collect_model.reset()
        # Algorithm
        self._gamma = self._cfg.collect.discount_factor
        self._gae_lambda = self._cfg.collect.gae_lambda

    def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Any]:
        """
        Overview:
            Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
            that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
            data, such as the action to interact with the envs.
        Arguments:
            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
                key of the dict is environment id and the value is the corresponding data of the env.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
                other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \
                dict is the same as the input data, i.e. environment id.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()
        with torch.no_grad():
            output = self._collect_model.forward(data, mode='compute_actor_critic')
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _process_transition(self, obs: Any, policy_output: Dict[str, torch.Tensor],
                            timestep: namedtuple) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Process and pack one timestep transition data into a dict, which can be directly used for training and \
            saved in replay buffer. For A2C, it contains obs, next_obs, action, value, reward, done.
        Arguments:
            - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari.
            - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \
                as input. For A2C, it contains the action and the value of the state.
            - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \
                except all the elements have been transformed into tensor data. Usually, it contains the next obs, \
                reward, done, info, etc.
        Returns:
            - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep.
        """
        transition = {
            'obs': obs,
            'next_obs': timestep.obs,
            'action': policy_output['action'],
            'value': policy_output['value'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return transition

    def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Overview:
            For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \
            can be used for training directly. In A2C, a train sample is a processed transition. \
            This method is usually used in collectors to execute necessary \
            RL data preprocessing before training, which can help the learner amortize relevant time consumption. \
            In addition, you can also implement this method as an identity function and do the data processing \
            in ``self._forward_learn`` method.
        Arguments:
            - transitions (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transition), each element is \
                in the same format as the return value of ``self._process_transition`` method.
        Returns:
            - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is similar in format \
                to input transitions, but may contain more data for training, such as advantages.
        """
        transitions = get_gae_with_default_last_value(
            transitions,
            transitions[-1]['done'],
            gamma=self._gamma,
            gae_lambda=self._gae_lambda,
            cuda=self._cuda,
        )
        return get_train_sample(transitions, self._unroll_len)

    def _init_eval(self) -> None:
        """
        Overview:
            Initialize the eval mode of policy, including related attributes and modules. For A2C, it contains the \
            eval model to greedily select action with ``argmax_sample`` mechanism (For discrete action space) and \
            ``deterministic_sample`` mechanism (For continuous action space). \
            This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.

        .. note::
            If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \
            with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
        """
        assert self._cfg.action_space in ["continuous", "discrete"]
        self._action_space = self._cfg.action_space
        if self._action_space == 'continuous':
            self._eval_model = model_wrap(self._model, wrapper_name='deterministic_sample')
        elif self._action_space == 'discrete':
            self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._eval_model.reset()

    def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
        """
        Overview:
            Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \
            means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
            action to interact with the envs.
        Arguments:
            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
                key of the dict is environment id and the value is the corresponding data of the env.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
                key of the dict is the same as the input data, i.e., environment id.

        .. note::
            The input value can be ``torch.Tensor`` or dict/list combinations, current policy supports all of them. \
            For the data type that is not supported, the main reason is that the corresponding model does not \
            support it. You can implement your own model rather than use the default model. For more information, \
            please raise an issue in GitHub repo, and we will continue to follow up.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data, mode='compute_actor')
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
            as text logger, tensorboard logger, will use these keys to save the corresponding data.
        Returns:
            - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
        """
        return super()._monitor_vars_learn() + ['policy_loss', 'value_loss', 'entropy_loss', 'adv_abs_max', 'grad_norm']