ding.policy.mbpolicy.dreamer
Full Source Code
../ding/policy/mbpolicy/dreamer.py
from typing import List, Dict, Any, Tuple, Union
from collections import namedtuple
import torch
from torch import nn
from copy import deepcopy
from ding.torch_utils import Adam, to_device
from ding.rl_utils import get_train_sample
from ding.utils import POLICY_REGISTRY, deep_merge_dicts
from ding.utils.data import default_collate, default_decollate
from ding.policy import Policy
from ding.model import model_wrap
from ding.policy.common_utils import default_preprocess_learn

from .utils import imagine, compute_target, compute_actor_loss, RewardEMA, tensorstats


@POLICY_REGISTRY.register('dreamer')
class DREAMERPolicy(Policy):
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='dreamer',
        # (bool) Whether to use cuda for network and loss computation.
        cuda=False,
        # (int) Number of training samples (randomly collected) in replay buffer when training starts.
        random_collect_size=5000,
        # (bool) Whether to need policy-specific data in preprocess transition.
        transition_with_policy_data=False,
        # (int) Imagination horizon, i.e. number of latent steps rolled out by the world model.
        imag_horizon=15,
        learn=dict(
            # (float) Lambda for TD-lambda return.
            lambda_=0.95,
            # (float) Max norm of gradients.
            grad_clip=100,
            learning_rate=3e-5,
            batch_size=16,
            batch_length=64,
            imag_sample=True,
            slow_value_target=True,
            slow_target_update=1,
            slow_target_fraction=0.02,
            discount=0.997,
            reward_EMA=True,
            actor_entropy=3e-4,
            actor_state_entropy=0.0,
            value_decay=0.0,
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        return 'dreamervac', ['ding.model.template.vac']

    def _init_learn(self) -> None:
        r"""
        Overview:
            Learn mode init method. Called by ``self.__init__``.
            Init the optimizer, algorithm config, main and target models.
        """
        # Algorithm config
        self._lambda = self._cfg.learn.lambda_
        self._grad_clip = self._cfg.learn.grad_clip

        self._critic = self._model.critic
        self._actor = self._model.actor

        if self._cfg.learn.slow_value_target:
            self._slow_value = deepcopy(self._critic)
            self._updates = 0

        # Optimizer
        self._optimizer_value = Adam(
            self._critic.parameters(),
            lr=self._cfg.learn.learning_rate,
        )
        self._optimizer_actor = Adam(
            self._actor.parameters(),
            lr=self._cfg.learn.learning_rate,
        )

        self._learn_model = model_wrap(self._model, wrapper_name='base')
        self._learn_model.reset()

        self._forward_learn_cnt = 0

        if self._cfg.learn.reward_EMA:
            self.reward_ema = RewardEMA(device=self._device)

    def _forward_learn(self, start: dict, world_model, envstep) -> Dict[str, Any]:
        # log dict
        log_vars = {}
        self._learn_model.train()
        self._update_slow_target()

        self._actor.requires_grad_(requires_grad=True)
        # start is dict of {stoch, deter, logit}
        if self._cuda:
            start = to_device(start, self._device)

        # train self._actor
        imag_feat, imag_state, imag_action = imagine(
            self._cfg.learn, world_model, start, self._actor, self._cfg.imag_horizon
        )
        reward = world_model.heads["reward"](world_model.dynamics.get_feat(imag_state)).mode()
        actor_ent = self._actor(imag_feat).entropy()
        state_ent = world_model.dynamics.get_dist(imag_state).entropy()
        # this target is not scaled
        # slow is flag to indicate whether slow_target is used for lambda-return
        target, weights, base = compute_target(
            self._cfg.learn, world_model, self._critic, imag_feat, imag_state, reward, actor_ent, state_ent
        )
        actor_loss, mets = compute_actor_loss(
            self._cfg.learn,
            self._actor,
            self.reward_ema,
            imag_feat,
            imag_state,
            imag_action,
            target,
            actor_ent,
            state_ent,
            weights,
            base,
        )
        log_vars.update(mets)
        value_input = imag_feat
        self._actor.requires_grad_(requires_grad=False)

        self._critic.requires_grad_(requires_grad=True)
        value = self._critic(value_input[:-1].detach())
        # to do
        # target = torch.stack(target, dim=1)
        # (time, batch, 1), (time, batch, 1) -> (time, batch)
        value_loss = -value.log_prob(target.detach())
        slow_target = self._slow_value(value_input[:-1].detach())
        if self._cfg.learn.slow_value_target:
            value_loss = value_loss - value.log_prob(slow_target.mode().detach())
        if self._cfg.learn.value_decay:
            value_loss += self._cfg.learn.value_decay * value.mode()
        # (time, batch, 1), (time, batch, 1) -> (1,)
        value_loss = torch.mean(weights[:-1] * value_loss[:, :, None])
        self._critic.requires_grad_(requires_grad=False)

        log_vars.update(tensorstats(value.mode(), "value"))
        log_vars.update(tensorstats(target, "target"))
        log_vars.update(tensorstats(reward, "imag_reward"))
        log_vars.update(tensorstats(imag_action, "imag_action"))
        log_vars["actor_ent"] = torch.mean(actor_ent).detach().cpu().numpy().item()
        # ====================
        # actor-critic update
        # ====================
        self._model.requires_grad_(requires_grad=True)
        world_model.requires_grad_(requires_grad=True)

        loss_dict = {
            'critic_loss': value_loss,
            'actor_loss': actor_loss,
        }

        norm_dict = self._update(loss_dict)

        self._model.requires_grad_(requires_grad=False)
        world_model.requires_grad_(requires_grad=False)
        # =============
        # after update
        # =============
        self._forward_learn_cnt += 1

        return {
            **log_vars,
            **norm_dict,
            **loss_dict,
        }

    def _update(self, loss_dict):
        # update actor
        self._optimizer_actor.zero_grad()
        loss_dict['actor_loss'].backward()
        actor_norm = nn.utils.clip_grad_norm_(self._model.actor.parameters(), self._grad_clip)
        self._optimizer_actor.step()
        # update critic
        self._optimizer_value.zero_grad()
        loss_dict['critic_loss'].backward()
        critic_norm = nn.utils.clip_grad_norm_(self._model.critic.parameters(), self._grad_clip)
        self._optimizer_value.step()
        return {'actor_grad_norm': actor_norm, 'critic_grad_norm': critic_norm}

    def _update_slow_target(self):
        if self._cfg.learn.slow_value_target:
            if self._updates % self._cfg.learn.slow_target_update == 0:
                mix = self._cfg.learn.slow_target_fraction
                for s, d in zip(self._critic.parameters(), self._slow_value.parameters()):
                    d.data = mix * s.data + (1 - mix) * d.data
            self._updates += 1

    def _state_dict_learn(self) -> Dict[str, Any]:
        ret = {
            'model': self._learn_model.state_dict(),
            'optimizer_value': self._optimizer_value.state_dict(),
            'optimizer_actor': self._optimizer_actor.state_dict(),
        }
        return ret

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        self._learn_model.load_state_dict(state_dict['model'])
        self._optimizer_value.load_state_dict(state_dict['optimizer_value'])
        self._optimizer_actor.load_state_dict(state_dict['optimizer_actor'])

    def _init_collect(self) -> None:
        self._unroll_len = self._cfg.collect.unroll_len
        self._collect_model = model_wrap(self._model, wrapper_name='base')
        self._collect_model.reset()

    def _forward_collect(self, data: dict, world_model, envstep, reset=None, state=None) -> dict:
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()

        if state is None:
            batch_size = len(data_id)
            latent = world_model.dynamics.initial(batch_size)  # {logit, stoch, deter}
            action = torch.zeros((batch_size, self._cfg.collect.action_size)).to(self._device)
        else:
            # state = default_collate(list(state.values()))
            latent = to_device(default_collate(list(zip(*state))[0]), self._device)
            action = to_device(default_collate(list(zip(*state))[1]), self._device)
            if len(action.shape) == 1:
                action = action.unsqueeze(-1)
            if reset.any():
                mask = 1 - reset
                for key in latent.keys():
                    for i in range(latent[key].shape[0]):
                        latent[key][i] *= mask[i]
                for i in range(len(action)):
                    action[i] *= mask[i]
        assert world_model.obs_type == 'vector' or world_model.obs_type == 'RGB', \
            "obs type must be vector or RGB"
        # normalize RGB image input
        if world_model.obs_type == 'RGB':
            data = data - 0.5
        embed = world_model.encoder(data)
        latent, _ = world_model.dynamics.obs_step(latent, action, embed, self._cfg.collect.collect_dyn_sample)
        feat = world_model.dynamics.get_feat(latent)

        actor = self._actor(feat)
        action = actor.sample()
        logprob = actor.log_prob(action)
        latent = {k: v.detach() for k, v in latent.items()}
        action = action.detach()

        state = (latent, action)
        assert world_model.action_type == 'discrete' or world_model.action_type == 'continuous', \
            "action type must be continuous or discrete"
        if world_model.action_type == 'discrete':
            action = torch.where(action == 1)[1]
        output = {"action": action, "logprob": logprob, "state": state}

        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        if world_model.action_type == 'discrete':
            for l in range(len(output)):
                output[l]['action'] = output[l]['action'].squeeze(0)
        return {i: d for i, d in zip(data_id, output)}

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of collect model, including at least ['action']
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
                (here 'obs' indicates obs after env step).
        Returns:
            - transition (:obj:`dict`): Dict type transition data.
        """
        transition = {
            'obs': obs,
            'action': model_output['action'],
            # TODO(zp) random_collect just have action
            # 'logprob': model_output['logprob'],
            'reward': timestep.reward,
            'discount': 1. - timestep.done,  # timestep.info['discount'],
            'done': timestep.done,
        }
        return transition

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        return get_train_sample(data, self._unroll_len)

    def _init_eval(self) -> None:
        self._eval_model = model_wrap(self._model, wrapper_name='base')
        self._eval_model.reset()

    def _forward_eval(self, data: dict, world_model, reset=None, state=None) -> dict:
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()

        if state is None:
            batch_size = len(data_id)
            latent = world_model.dynamics.initial(batch_size)  # {logit, stoch, deter}
            action = torch.zeros((batch_size, self._cfg.collect.action_size)).to(self._device)
        else:
            # state = default_collate(list(state.values()))
            latent = to_device(default_collate(list(zip(*state))[0]), self._device)
            action = to_device(default_collate(list(zip(*state))[1]), self._device)
            if len(action.shape) == 1:
                action = action.unsqueeze(-1)
            if reset.any():
                mask = 1 - reset
                for key in latent.keys():
                    for i in range(latent[key].shape[0]):
                        latent[key][i] *= mask[i]
                for i in range(len(action)):
                    action[i] *= mask[i]

        # normalize RGB image input
        if world_model.obs_type == 'RGB':
            data = data - 0.5
        embed = world_model.encoder(data)
        latent, _ = world_model.dynamics.obs_step(latent, action, embed, self._cfg.collect.collect_dyn_sample)
        feat = world_model.dynamics.get_feat(latent)

        actor = self._actor(feat)
        action = actor.mode()
        logprob = actor.log_prob(action)
        latent = {k: v.detach() for k, v in latent.items()}
        action = action.detach()

        state = (latent, action)
        if world_model.action_type == 'discrete':
            action = torch.where(action == 1)[1]
        output = {"action": action, "logprob": logprob, "state": state}

        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        if world_model.action_type == 'discrete':
            for l in range(len(output)):
                output[l]['action'] = output[l]['action'].squeeze(0)
        return {i: d for i, d in zip(data_id, output)}

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return variables' names if variables are to be used in monitor.
        Returns:
            - vars (:obj:`List[str]`): Variables' name list.
        """
        return [
            'normed_target_mean', 'normed_target_std', 'normed_target_min', 'normed_target_max', 'EMA_005', 'EMA_095',
            'actor_entropy', 'actor_state_entropy', 'value_mean', 'value_std', 'value_min', 'value_max', 'target_mean',
            'target_std', 'target_min', 'target_max', 'imag_reward_mean', 'imag_reward_std', 'imag_reward_min',
            'imag_reward_max', 'imag_action_mean', 'imag_action_std', 'imag_action_min', 'imag_action_max', 'actor_ent',
            'actor_loss', 'critic_loss', 'actor_grad_norm', 'critic_grad_norm'
        ]
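The lambda-return targets that drive both the critic loss and the actor loss are produced by compute_target (imported from .utils), whose implementation is not shown on this page. As a rough, standalone illustration of the underlying technique, the sketch below accumulates a TD(lambda) return backwards over an imagined trajectory, in the spirit of the lambda_=0.95 and discount=0.997 entries of the default config. The function name lambda_return_sketch, the (time, batch) tensor layout, and the bootstrapping convention are assumptions made for this example and may differ from the actual compute_target code.

import torch


def lambda_return_sketch(rewards, values, discounts, lam=0.95):
    # Illustrative only (not the DI-engine implementation): TD(lambda) return
    # accumulated backwards in time.
    # rewards, discounts: (T, batch); values: (T + 1, batch), where values[t + 1]
    # estimates the state reached after rewards[t], and values[-1] bootstraps
    # the tail of the imagined trajectory.
    last = values[-1]
    returns = []
    for t in reversed(range(rewards.shape[0])):
        last = rewards[t] + discounts[t] * ((1 - lam) * values[t + 1] + lam * last)
        returns.append(last)
    returns.reverse()
    return torch.stack(returns)


# Example usage: 15 imagined steps (imag_horizon) with a batch of 16 trajectories.
T, B = 15, 16
targets = lambda_return_sketch(torch.randn(T, B), torch.randn(T + 1, B), torch.full((T, B), 0.997))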