ding.policy.pg
PGPolicy
Bases: Policy
Overview
Policy class of Policy Gradient (REINFORCE) algorithm. Paper link: https://proceedings.neurips.cc/paper_files/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf
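The REINFORCE objective this class implements can be sketched in isolation. The snippet below is a minimal illustration (not DI-engine code) of the loss computed in `_forward_learn`: the negative mean of log-probabilities weighted by discounted returns, plus an entropy bonus; the tensor shapes and values are made up for demonstration.

```python
import torch

# Toy batch: 4 states, 3 discrete actions (shapes are illustrative).
logits = torch.randn(4, 3, requires_grad=True)
dist = torch.distributions.Categorical(logits=logits)
action = dist.sample()
# Discounted episode returns G_t for each sample (made-up values).
return_ = torch.tensor([1.0, 0.5, -0.2, 0.8])
entropy_weight = 0.01

# REINFORCE: maximize E[log pi(a|s) * G], so minimize its negative mean.
policy_loss = -(dist.log_prob(action) * return_).mean()
# Entropy regularization encourages exploration.
entropy_loss = -entropy_weight * dist.entropy().mean()
total_loss = policy_loss + entropy_loss
total_loss.backward()  # gradients flow into the policy parameters (here, logits)
```

In the real policy, `logits` come from the registered `pg` model and the optimizer step plus gradient clipping follow the backward pass.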
default_model()
Overview
Return this algorithm's default neural network model setting for demonstration. The __init__ method will automatically call this method to get the default model setting and create the model.
Returns:
- model_info (`Tuple[str, List[str]]`): The registered model name and the model's import_names.
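The `(model_name, import_names)` tuple supports a registry-with-lazy-import pattern: importing the listed modules runs their registration side effects so the name can then be looked up. The following is a toy sketch of that pattern (not DI-engine code; all names here are hypothetical):

```python
# Toy model registry mimicking the pattern behind default_model().
MODEL_REGISTRY = {}

def register(name):
    """Decorator that records a class in the registry under `name`."""
    def deco(cls):
        MODEL_REGISTRY[name] = cls
        return cls
    return deco

@register('pg')
class ToyPGModel:
    pass

def create_model(model_info):
    name, import_names = model_info
    # In DI-engine, each module in import_names would be imported here,
    # triggering the @register side effects before the lookup.
    return MODEL_REGISTRY[name]()

model = create_model(('pg', ['ding.model.template.pg']))
```

In DI-engine itself, this wiring happens inside the policy base class's `__init__` when no custom model is supplied.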
Full Source Code
ding/policy/pg.py
```python
from typing import List, Dict, Any, Tuple, Union
from collections import namedtuple
import torch
import treetensor as ttorch

from ding.rl_utils import get_gae_with_default_last_value, get_train_sample
from ding.torch_utils import Adam, to_device
from ding.utils import POLICY_REGISTRY, split_data_generator
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy
from .common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('pg')
class PGPolicy(Policy):
    """
    Overview:
        Policy class of Policy Gradient (REINFORCE) algorithm. Paper link: \
        https://proceedings.neurips.cc/paper_files/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf
    """
    config = dict(
        # (str) RL policy register name (refer to function "register_policy").
        type='pg',
        # (bool) Whether to use cuda for the network.
        cuda=False,
        # (bool) Whether to use the on-policy training pipeline (behaviour policy and training policy are the same).
        on_policy=True,  # PG is a strictly on-policy algorithm, so this line should not be modified by users
        # (str) Action space type: ['discrete', 'continuous'].
        action_space='discrete',
        # (bool) Whether to use deterministic actions for evaluation.
        deterministic_eval=True,
        learn=dict(
            # (int) The number of samples for one update.
            batch_size=64,
            # (float) The step size of one gradient descent.
            learning_rate=0.001,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (float) Loss weight of the entropy regularization; the weight of the policy network is set to 1.
            entropy_weight=0.01,
            # (float) Max grad norm value.
            grad_norm=5,
            # (bool) Whether to ignore the done signal for a non-terminating env.
            ignore_done=False,
        ),
        collect=dict(
            # (int) Collect n_sample data, train model n_iteration times.
            # n_episode=8,
            # (int) Trajectory unroll length.
            unroll_len=1,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (float) Discount factor for future reward, in range [0, 1].
            discount_factor=0.99,
            collector=dict(get_train_sample=True),
        ),
        eval=dict(),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return this algorithm's default neural network model setting for demonstration. The ``__init__`` \
            method will automatically call this method to get the default model setting and create the model.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and the model's import_names.
        """
        return 'pg', ['ding.model.template.pg']

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of the policy, including related attributes and modules. For PG, it mainly \
            contains the optimizer and algorithm-specific arguments such as entropy weight and grad norm. This \
            method also executes some special network initializations.
            This method will be called in the ``__init__`` method if the ``learn`` field is in ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some special member variables in the ``_init_learn`` method, you'd better name them \
            with the prefix ``_learn_`` to avoid conflicts with other modes, such as ``self._learn_attr1``.
        """
        # Optimizer
        self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)

        self._entropy_weight = self._cfg.learn.entropy_weight
        self._grad_norm = self._cfg.learn.grad_norm
        self._learn_model = self._model  # for compatibility

    def _forward_learn(self, data: List[Dict[int, Any]]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode (training policy and updating parameters). Forward means \
            that the policy inputs some training batch data from the replay buffer and then returns the output \
            result, including various training information such as loss, clipfrac, approx_kl.
        Arguments:
            - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including the latest \
                collected training samples for on-policy algorithms like PG. For each element in the list, the \
                key of the dict is the name of the data item and the value is the corresponding data. Usually, \
                the value is torch.Tensor, np.ndarray, or their dict/list combinations. In the ``_forward_learn`` \
                method, data often needs to first be stacked in the batch dimension by some utility functions \
                such as ``default_preprocess_learn``. \
                For PG, each element in the list is a dict containing at least the following keys: ``obs``, \
                ``action``, ``return``.
        Returns:
            - return_infos (:obj:`List[Dict[str, Any]]`): The information list that indicates the training \
                result; each training iteration appends an information dict into the final list. The list will \
                be processed and recorded in the text log and tensorboard. The values of the dict must be python \
                scalars or lists of scalars. For the detailed definition of the dict, refer to the code of the \
                ``_monitor_vars_learn`` method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all \
            of them. For a data type that is not supported, the main reason is that the corresponding model does \
            not support it. You can implement your own model rather than use the default model. For more \
            information, please raise an issue in the GitHub repo and we will continue to follow up.
        """
        data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=False)
        if self._cuda:
            data = to_device(data, self._device)
        self._model.train()

        return_infos = []
        for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
            # forward
            output = self._learn_model.forward(batch['obs'])
            return_ = batch['return']
            dist = output['dist']
            # calculate PG loss
            log_prob = dist.log_prob(batch['action'])
            policy_loss = -(log_prob * return_).mean()
            entropy_loss = -self._cfg.learn.entropy_weight * dist.entropy().mean()
            total_loss = policy_loss + entropy_loss

            # update
            self._optimizer.zero_grad()
            total_loss.backward()

            grad_norm = torch.nn.utils.clip_grad_norm_(
                list(self._learn_model.parameters()),
                max_norm=self._grad_norm,
            )
            self._optimizer.step()

            # only record the last update's information in the logger
            return_info = {
                'cur_lr': self._optimizer.param_groups[0]['lr'],
                'total_loss': total_loss.item(),
                'policy_loss': policy_loss.item(),
                'entropy_loss': entropy_loss.item(),
                'return_abs_max': return_.abs().max().item(),
                'grad_norm': grad_norm,
            }
            return_infos.append(return_info)
        return return_infos

    def _init_collect(self) -> None:
        """
        Overview:
            Initialize the collect mode of the policy, including related attributes and modules. For PG, it \
            contains algorithm-specific arguments such as unroll_len and gamma.
            This method will be called in the ``__init__`` method if the ``collect`` field is in ``enable_field``.

        .. note::
            If you want to set some special member variables in the ``_init_collect`` method, you'd better name \
            them with the prefix ``_collect_`` to avoid conflicts with other modes, such as ``self._collect_attr1``.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        self._gamma = self._cfg.collect.discount_factor

    def _forward_collect(self, data: Dict[int, Any]) -> dict:
        """
        Overview:
            Policy forward function of collect mode (collecting training data by interacting with envs). Forward \
            means that the policy gets some necessary data (mainly observation) from the envs and then returns \
            the output data, such as the action to interact with the envs.
        Arguments:
            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. \
                The key of the dict is the environment id and the value is the corresponding data of the env.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action \
                and other necessary data (action logit) for learn mode defined in the ``self._process_transition`` \
                method. The key of the dict is the same as in the input data, i.e. environment id.

        .. tip::
            If you want to add more tricks to this policy, like a temperature factor in multinomial sampling, \
            you can pass related data as extra keyword arguments of this method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all \
            of them. For a data type that is not supported, the main reason is that the corresponding model does \
            not support it. You can implement your own model rather than use the default model. For more \
            information, please raise an issue in the GitHub repo and we will continue to follow up.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._model.eval()
        with torch.no_grad():
            output = self._model.forward(data)
            output['action'] = output['dist'].sample()
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _process_transition(self, obs: Any, model_output: Dict[str, torch.Tensor], timestep: namedtuple) -> dict:
        """
        Overview:
            Process and pack one timestep transition data into a dict, which can be directly used for training \
            and saved in the replay buffer. For PG, it contains obs, action, reward and done.
        Arguments:
            - obs (:obj:`torch.Tensor`): The env observation of the current timestep, such as a stacked 2D image \
                in Atari.
            - model_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the \
                observation as input. For PG, it contains the action.
            - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step \
                method, except that all the elements have been transformed into tensor data. Usually, it \
                contains the next obs, reward, done, info, etc.
        Returns:
            - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep.
        """
        return {
            'obs': obs,
            'action': model_output['action'],
            'reward': timestep.reward,
            'done': timestep.done,
        }

    def _get_train_sample(self, data: List[Dict[str, Any]]) -> Union[None, List[Any]]:
        """
        Overview:
            For given entire episode data (a list of transitions), process it into a list of samples that \
            can be used for training directly. In PG, a train sample is a processed transition with a newly \
            computed ``return`` field. This method is usually used in collectors to execute necessary \
            RL data preprocessing before training, which can help the learner amortize the relevant time \
            consumption. In addition, you can also implement this method as an identity function and do the \
            data processing in the ``self._forward_learn`` method.
        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): The episode data (a list of transitions), each element is in \
                the same format as the return value of the ``self._process_transition`` method. Note that PG \
                needs a complete episode.
        Returns:
            - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is in a similar \
                format to the input transitions, but may contain more data for training, such as the discounted \
                episode return.
        """
        assert data[-1]['done'], "PG needs a complete episode"

        if self._cfg.learn.ignore_done:
            raise NotImplementedError

        R = 0.
        if isinstance(data, list):
            for i in reversed(range(len(data))):
                R = self._gamma * R + data[i]['reward']
                data[i]['return'] = R
            return get_train_sample(data, self._unroll_len)
        elif isinstance(data, ttorch.Tensor):
            data_size = data['done'].shape[0]
            data['return'] = ttorch.torch.zeros(data_size)
            for i in reversed(range(data_size)):
                R = self._gamma * R + data['reward'][i]
                data['return'][i] = R
            return get_train_sample(data, self._unroll_len)
        else:
            raise ValueError

    def _init_eval(self) -> None:
        pass

    def _forward_eval(self, data: Dict[int, Any]) -> dict:
        """
        Overview:
            Policy forward function of eval mode (evaluating policy performance by interacting with envs). \
            Forward means that the policy gets some necessary data (mainly observation) from the envs and then \
            returns the action to interact with the envs. ``_forward_eval`` in PG often uses a deterministic \
            sample method to get actions while ``_forward_collect`` usually uses a stochastic sample method to \
            balance exploration and exploitation.
        Arguments:
            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. \
                The key of the dict is the environment id and the value is the corresponding data of the env.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. \
                The key of the dict is the same as in the input data, i.e. environment id.

        .. note::
            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all \
            of them. For a data type that is not supported, the main reason is that the corresponding model does \
            not support it. You can implement your own model rather than use the default model. For more \
            information, please raise an issue in the GitHub repo and we will continue to follow up.

        .. note::
            For more detailed examples, please refer to our unittest for PGPolicy: ``ding.policy.tests.test_pg``.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._model.eval()
        with torch.no_grad():
            output = self._model.forward(data)
            if self._cfg.deterministic_eval:
                if self._cfg.action_space == 'discrete':
                    output['action'] = output['logit'].argmax(dim=-1)
                elif self._cfg.action_space == 'continuous':
                    output['action'] = output['logit']['mu']
                else:
                    raise KeyError("invalid action_space: {}".format(self._cfg.action_space))
            else:
                output['action'] = output['dist'].sample()
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, \
            such as the text logger or tensorboard logger, will use these keys to save the corresponding data.
        Returns:
            - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
        """
        return super()._monitor_vars_learn() + ['policy_loss', 'entropy_loss', 'return_abs_max', 'grad_norm']
```
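The discounted-return recursion used by `_get_train_sample` (R_t = r_t + gamma * R_{t+1}, computed backward over a complete episode) can be sketched in isolation. This is a plain-Python illustration, not DI-engine code:

```python
def discounted_returns(rewards, gamma=0.99):
    """Compute per-step discounted returns for one complete episode.

    Walks the reward list backward, accumulating R = r_t + gamma * R,
    mirroring the loop in PGPolicy._get_train_sample.
    """
    R, out = 0.0, []
    for r in reversed(rewards):
        R = r + gamma * R
        out.append(R)
    return out[::-1]  # restore chronological order
```

For example, `discounted_returns([1, 1, 1], gamma=0.5)` yields `[1.75, 1.5, 1.0]`: the last step keeps its raw reward, and earlier steps fold in the discounted future. These returns become the `return` field of each train sample.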