Skip to content

ding.policy.offppo_collect_traj

ding.policy.offppo_collect_traj

OffPPOCollectTrajPolicy

Bases: Policy

Overview

Policy class of the off-policy PPO algorithm, used to collect expert trajectories for R2D3.

Full Source Code

../ding/policy/offppo_collect_traj.py

from typing import List, Dict, Any, Tuple, Union
from collections import namedtuple
import torch
import copy
import numpy as np
from torch.distributions import Independent, Normal

from ding.torch_utils import Adam, to_device
from ding.rl_utils import ppo_data, ppo_error, ppo_policy_error, ppo_policy_data, get_gae_with_default_last_value, \
    v_nstep_td_data, v_nstep_td_error, get_nstep_return_data, get_train_sample, gae, gae_data, ppo_error_continuous,\
    get_gae
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY, split_data_generator, RunningMeanStd
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy
from .common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('offppo_collect_traj')
class OffPPOCollectTrajPolicy(Policy):
    r"""
    Overview:
        Policy class of the off-policy PPO algorithm, used to collect expert trajectories for R2D3.
        Collect and eval modes both use the ``argmax_sample`` model wrapper, so the gathered
        trajectories are deterministic (expert-like) rather than sampled from the policy
        distribution.
    """
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='ppo',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool) Whether the RL algorithm is on-policy or off-policy.
        # (Note: in practice PPO can be used off-policy.)
        on_policy=True,
        # (bool) Whether to use priority (priority sample, IS weight, update priority).
        priority=False,
        # (bool) Whether to use Importance Sampling Weight to correct biased update.
        # If True, priority must be True.
        priority_IS_weight=False,
        # (bool) Whether to use n-step return for the value loss.
        nstep_return=False,
        # (int) Number of steps for the n-step return (only used when ``nstep_return=True``).
        nstep=3,
        learn=dict(
            # (int) How many updates (iterations) to train after the collector's one collection.
            # Bigger "update_per_collect" means bigger off-policy.
            # collect data -> update policy -> collect data -> ...
            update_per_collect=5,
            # (int) Number of samples in one training batch.
            batch_size=64,
            # (float) Learning rate of the Adam optimizer.
            learning_rate=0.001,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (float) The loss weight of value network, policy network weight is set to 1
            value_weight=0.5,
            # (float) The loss weight of entropy regularization, policy network weight is set to 1
            entropy_weight=0.01,
            # (float) PPO clip ratio, defaults to 0.2
            clip_ratio=0.2,
            # (bool) Whether to use advantage norm in a whole training batch
            adv_norm=False,
            # (bool) Whether to ignore the done flag in training data preprocessing.
            ignore_done=False,
        ),
        collect=dict(
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (float) Reward's future discount factor, aka. gamma.
            discount_factor=0.99,
            # (float) GAE lambda factor for the balance of bias and variance(1-step td and mc)
            gae_lambda=0.95,
        ),
        eval=dict(),
        other=dict(replay_buffer=dict(replay_buffer_size=10000, ), ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        r"""
        Overview:
            Return the default model name ('vac', the actor-critic model) and its import path.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): Registered model name and import names.
        """
        return 'vac', ['ding.model.template.vac']

    def _init_learn(self) -> None:
        r"""
        Overview:
            Learn mode init method. Called by ``self.__init__``.
            Init the optimizer, algorithm config and the main model.
        """
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        assert not self._priority and not self._priority_IS_weight, "Priority is not implemented in PPO"
        # Orthogonal init for conv and linear layer weights.
        for m in self._model.modules():
            if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
                torch.nn.init.orthogonal_(m.weight)
        # Optimizer
        self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
        self._learn_model = model_wrap(self._model, wrapper_name='base')

        # Algorithm config
        self._value_weight = self._cfg.learn.value_weight
        self._entropy_weight = self._cfg.learn.entropy_weight
        self._clip_ratio = self._cfg.learn.clip_ratio
        self._adv_norm = self._cfg.learn.adv_norm
        self._nstep = self._cfg.nstep
        self._nstep_return = self._cfg.nstep_return
        # BUGFIX: ``self._gamma`` was previously assigned only in ``_init_collect``, so the
        # nstep-return branch of ``_forward_learn`` (``v_nstep_td_error``) raised
        # AttributeError when learn mode was initialized without collect mode.
        self._gamma = self._cfg.collect.discount_factor
        # Main model
        self._learn_model.reset()

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Forward and backward function of learn mode.
        Arguments:
            - data (:obj:`dict`): Dict type data
        Returns:
            - info_dict (:obj:`Dict[str, Any]`):
              Including current lr, total_loss, policy_loss, value_loss, entropy_loss, \
              adv_abs_max, approx_kl, clipfrac
        """
        data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=self._nstep_return)
        if self._cuda:
            data = to_device(data, self._device)
        # ====================
        # PPO forward
        # ====================
        self._learn_model.train()
        if not self._nstep_return:
            # Normal PPO: joint actor-critic forward with a single ppo_error loss.
            output = self._learn_model.forward(data['obs'], mode='compute_actor_critic')
            adv = data['adv']
            return_ = data['value'] + adv
            if self._adv_norm:
                # Normalize advantage in a total train_batch
                adv = (adv - adv.mean()) / (adv.std() + 1e-8)
            # Calculate ppo error
            ppodata = ppo_data(
                output['logit'], data['logit'], data['action'], output['value'], data['value'], adv, return_,
                data['weight']
            )
            ppo_loss, ppo_info = ppo_error(ppodata, self._clip_ratio)
            wv, we = self._value_weight, self._entropy_weight
            total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss
        else:
            # PPO policy loss combined with a separate n-step TD critic loss.
            output = self._learn_model.forward(data['obs'], mode='compute_actor')
            adv = data['adv']
            if self._adv_norm:
                # Normalize advantage in a total train_batch
                adv = (adv - adv.mean()) / (adv.std() + 1e-8)
            # Calculate ppo error
            ppodata = ppo_policy_data(output['logit'], data['logit'], data['action'], adv, data['weight'])
            ppo_policy_loss, ppo_info = ppo_policy_error(ppodata, self._clip_ratio)
            wv, we = self._value_weight, self._entropy_weight
            next_obs = data.get('next_obs')
            value_gamma = data.get('value_gamma')
            reward = data.get('reward')
            # current value
            value = self._learn_model.forward(data['obs'], mode='compute_critic')
            # target value
            target_value = self._learn_model.forward(next_obs, mode='compute_critic')
            # TODO what should we do here to keep shape
            assert self._nstep > 1
            td_data = v_nstep_td_data(
                value['value'], target_value['value'], reward.t(), data['done'], data['weight'], value_gamma
            )
            # calculate v_nstep_td critic_loss
            critic_loss, td_error_per_sample = v_nstep_td_error(td_data, self._gamma, self._nstep)
            ppo_loss_data = namedtuple('ppo_loss', ['policy_loss', 'value_loss', 'entropy_loss'])
            ppo_loss = ppo_loss_data(ppo_policy_loss.policy_loss, critic_loss, ppo_policy_loss.entropy_loss)
            total_loss = ppo_policy_loss.policy_loss + wv * critic_loss - we * ppo_policy_loss.entropy_loss
        # ====================
        # PPO update
        # ====================
        self._optimizer.zero_grad()
        total_loss.backward()
        self._optimizer.step()
        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': total_loss.item(),
            'policy_loss': ppo_loss.policy_loss.item(),
            'value_loss': ppo_loss.value_loss.item(),
            'entropy_loss': ppo_loss.entropy_loss.item(),
            'adv_abs_max': adv.abs().max().item(),
            'approx_kl': ppo_info.approx_kl,
            'clipfrac': ppo_info.clipfrac,
        }

    def _state_dict_learn(self) -> Dict[str, Any]:
        r"""
        Overview:
            Return the state dict of learn mode: model and optimizer.
        """
        return {
            'model': self._learn_model.state_dict(),
            'optimizer': self._optimizer.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        r"""
        Overview:
            Load the state dict saved by ``_state_dict_learn`` into model and optimizer.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._optimizer.load_state_dict(state_dict['optimizer'])

    def _init_collect(self) -> None:
        r"""
        Overview:
            Collect mode init method. Called by ``self.__init__``.
            Init traj and unroll length, collect model.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        # NOTE this policy is to collect expert traj, so we have to use argmax_sample wrapper
        # instead of the usual multinomial_sample wrapper.
        self._collect_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._collect_model.reset()
        self._gamma = self._cfg.collect.discount_factor
        self._gae_lambda = self._cfg.collect.gae_lambda
        self._nstep = self._cfg.nstep
        self._nstep_return = self._cfg.nstep_return

    def _forward_collect(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function for collect mode
        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs'].
        Returns:
            - data (:obj:`dict`): The collected data
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()
        with torch.no_grad():
            output = self._collect_model.forward(data, mode='compute_actor_critic')
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        """
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of collect model, including at least ['action']
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done']\
                (here 'obs' indicates obs after env step).
        Returns:
            - transition (:obj:`dict`): Dict type transition data.
        """
        transition = {
            'obs': obs,
            'action': model_output['action'],
            # prev_state is kept for R2D3 data format compatibility; this policy is not recurrent.
            'prev_state': None,
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return transition

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        r"""
        Overview:
            Get the trajectory and calculate n-step return data, then split it into training samples.
        Arguments:
            - data (:obj:`list`): The trajectory's cache
        Returns:
            - samples (:obj:`dict`): The training samples generated
        """
        # Deep-copy before n-step processing, because get_nstep_return_data mutates reward/done.
        data_one_step = copy.deepcopy(data)
        data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)
        for i in range(len(data)):
            # here we record the one-step done, we don't need record one-step reward,
            # because the n-step reward in data already include one-step reward
            data[i]['done_one_step'] = data_one_step[i]['done']
        return get_train_sample(data, self._unroll_len)

    def _init_eval(self) -> None:
        r"""
        Overview:
            Evaluate mode init method. Called by ``self.__init__``.
            Init eval model with argmax strategy.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._eval_model.reset()

    def _forward_eval(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function for eval mode, similar to ``self._forward_collect``.
        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs'].
        Returns:
            - output (:obj:`dict`): Dict type data, including at least inferred action according to input obs.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data, mode='compute_actor')
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return the variable names monitored in learn mode, extending the base policy's list.
        """
        return super()._monitor_vars_learn() + [
            'policy_loss', 'value_loss', 'entropy_loss', 'adv_abs_max', 'approx_kl', 'clipfrac'
        ]