Skip to content

ding.policy.sqn

ding.policy.sqn

SQNPolicy

Bases: Policy

Overview

Policy class of SQN algorithm (arxiv: 1912.10891).

Full Source Code

../ding/policy/sqn.py

from typing import List, Dict, Any, Tuple, Union
from collections import namedtuple
import math
import itertools
import numpy as np
import torch
import torch.nn.functional as F
import copy

from ding.torch_utils import Adam, to_device
from ding.rl_utils import get_train_sample
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy
from .common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('sqn')
class SQNPolicy(Policy):
    r"""
    Overview:
        Policy class of SQN algorithm (arxiv: 1912.10891).
    """

    config = dict(
        cuda=False,
        type='sqn',
        on_policy=False,
        priority=False,
        # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        learn=dict(
            update_per_collect=16,
            batch_size=64,
            learning_rate_q=0.001,
            learning_rate_alpha=0.001,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (float) Momentum factor for the soft (EMA) target network update.
            target_theta=0.005,
            # (float) Initial entropy temperature; the learnable parameter is log(alpha).
            alpha=0.2,
            discount_factor=0.99,
            # If env's action shape is int type, we recommend `self._action_shape / 10`; else, we recommend 0.2
            target_entropy=0.2,
            # (bool) Whether ignore done(usually for max step termination env)
            ignore_done=False,
        ),
        collect=dict(
            # n_sample=16,
            # Cut trajectories into pieces with length "unroll_len".
            unroll_len=1,
        ),
        eval=dict(),
        other=dict(
            eps=dict(
                type='exp',
                start=1.,
                end=0.8,
                decay=2000,
            ),
            replay_buffer=dict(replay_buffer_size=100000, ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return the registered model name and its import path for this policy.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): Model name and list of import names.
        """
        return 'sqn', ['ding.model.template.sqn']

    def _init_learn(self) -> None:
        r"""
        Overview:
            Learn mode init method. Called by ``self.__init__``.
            Init q, value and policy's optimizers, algorithm config, main and target models.
        """
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        # Optimizers
        self._optimizer_q = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate_q)

        # Algorithm config
        self._gamma = self._cfg.learn.discount_factor
        self._action_shape = self._cfg.model.action_shape
        self._target_entropy = self._cfg.learn.target_entropy
        # Optimize log(alpha) rather than alpha itself, so alpha = exp(log_alpha) stays positive.
        self._log_alpha = torch.FloatTensor([math.log(self._cfg.learn.alpha)]).to(self._device).requires_grad_(True)
        self._optimizer_alpha = Adam([self._log_alpha], lr=self._cfg.learn.learning_rate_alpha)

        # Main and target models; target is updated by momentum (EMA) with theta from config.
        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_theta}
        )
        self._learn_model = model_wrap(self._model, wrapper_name='base')
        self._learn_model.reset()
        self._target_model.reset()

        self._forward_learn_cnt = 0

    def q_1step_td_loss(self, td_data: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Overview:
            Compute the one-step soft Q TD loss and the temperature (alpha) loss for one
            (possibly sub-env-sliced) batch, using the twin-Q minimum as target.
        Arguments:
            - td_data (:obj:`Dict[str, Any]`): Dict with keys ``q_value`` (pair of current Q tensors), \
                ``target_q_value`` (pair of target-net Q tensors), ``action``, ``done``, ``reward``.
        Returns:
            - total_q_loss (:obj:`torch.Tensor`): Sum of the two Q-head MSE losses.
            - alpha_loss (:obj:`torch.Tensor`): Loss for the learnable log-alpha.
            - entropy (:obj:`torch.Tensor`): Per-sample policy entropy (for logging).
        """
        q_value = td_data["q_value"]
        target_q_value = td_data["target_q_value"]
        action = td_data.get('action')
        done = td_data.get('done')
        reward = td_data.get('reward')
        q0 = q_value[0]
        q1 = q_value[1]
        batch_range = torch.arange(action.shape[0])
        # Gather Q(s, a) for the taken actions from both heads.
        q0_a = q0[batch_range, action]
        q1_a = q1[batch_range, action]
        # Target
        with torch.no_grad():
            q0_targ = target_q_value[0]
            q1_targ = target_q_value[1]
            # Clipped double-Q: use the element-wise minimum of the twin target heads.
            q_targ = torch.min(q0_targ, q1_targ)
            # discrete policy: softmax over Q / alpha defines the (implicit) policy.
            alpha = torch.exp(self._log_alpha.clone())
            # TODO use q_targ or q0 for pi
            log_pi = F.log_softmax(q_targ / alpha, dim=-1)
            pi = torch.exp(log_pi)
            # v = \sum_a \pi(a | s) (Q(s, a) - \alpha \log(\pi(a|s)))
            target_v_value = (pi * (q_targ - alpha * log_pi)).sum(axis=-1)
            # q = r + \gamma v
            q_backup = reward + (1 - done) * self._gamma * target_v_value
            # alpha_loss ingredients (no grad: gradient reaches alpha only through _log_alpha below)
            entropy = (-pi * log_pi).sum(axis=-1)
            expect_entropy = (pi * self._target_entropy).sum(axis=-1)

        # Q loss
        q0_loss = F.mse_loss(q0_a, q_backup)
        q1_loss = F.mse_loss(q1_a, q_backup)
        total_q_loss = q0_loss + q1_loss
        # alpha loss: pushes alpha up when entropy is below target, down when above.
        alpha_loss = self._log_alpha * (entropy - expect_entropy).mean()
        return total_q_loss, alpha_loss, entropy

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        """
        Overview:
            Forward and backward function of learn mode.
        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs', 'done',\
                'weight']
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Learn info, including current lr and loss.
        """
        data = default_preprocess_learn(
            data,
            use_priority=self._cfg.priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False
        )
        if self._cuda:
            data = to_device(data, self._device)

        self._learn_model.train()
        self._target_model.train()
        obs = data.get('obs')
        next_obs = data.get('next_obs')
        reward = data.get('reward')
        action = data.get('action')
        done = data.get('done')
        # Q-function
        q_value = self._learn_model.forward(obs)['q_value']
        target_q_value = self._target_model.forward(next_obs)['q_value']

        num_s_env = 1 if isinstance(self._action_shape, int) else len(self._action_shape)  # num of separate env

        for s_env_id in range(num_s_env):
            if isinstance(self._action_shape, int):
                # Single discrete action space: use the tensors as-is.
                td_data = {
                    "q_value": q_value,
                    "target_q_value": target_q_value,
                    "obs": obs,
                    "next_obs": next_obs,
                    "reward": reward,
                    "action": action,
                    "done": done
                }
            else:
                # Multi-discrete action space: slice out this sub-env's Q heads and action.
                td_data = {
                    "q_value": [q_value[0][s_env_id], q_value[1][s_env_id]],
                    "target_q_value": [target_q_value[0][s_env_id], target_q_value[1][s_env_id]],
                    "obs": obs,
                    "next_obs": next_obs,
                    "reward": reward,
                    "action": action[s_env_id],
                    "done": done
                }
            total_q_loss, alpha_loss, entropy = self.q_1step_td_loss(td_data)
            # NOTE(review): the first sub-env's losses are accumulated at full weight while
            # later ones are scaled by ~1/num_s_env — confirm this weighting is intended.
            if s_env_id == 0:
                a_total_q_loss, a_alpha_loss, a_entropy = total_q_loss, alpha_loss, entropy  # accumulate
            else:  # running average, accumulate loss
                a_total_q_loss += total_q_loss / (num_s_env + 1e-6)
                a_alpha_loss += alpha_loss / (num_s_env + 1e-6)
                a_entropy += entropy / (num_s_env + 1e-6)

        self._optimizer_q.zero_grad()
        a_total_q_loss.backward()
        self._optimizer_q.step()

        self._optimizer_alpha.zero_grad()
        a_alpha_loss.backward()
        self._optimizer_alpha.step()

        # target update
        self._target_model.update(self._learn_model.state_dict())
        self._forward_learn_cnt += 1
        # some useful info
        return {
            '[histogram]action_distribution': np.stack([a.cpu().numpy() for a in data['action']]).flatten(),
            'q_loss': a_total_q_loss.item(),
            'alpha_loss': a_alpha_loss.item(),
            'entropy': a_entropy.mean().item(),
            'alpha': math.exp(self._log_alpha.item()),
            'q_value': np.mean([x.cpu().detach().numpy() for x in itertools.chain(*q_value)], dtype=float),
        }

    def _state_dict_learn(self) -> Dict[str, Any]:
        """
        Overview:
            Return the learn-mode state dict: model weights and both optimizers' states.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): Keys ``model``, ``optimizer_q``, ``optimizer_alpha``.
        """
        return {
            'model': self._learn_model.state_dict(),
            'optimizer_q': self._optimizer_q.state_dict(),
            'optimizer_alpha': self._optimizer_alpha.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Restore learn-mode state previously produced by ``_state_dict_learn``.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): Keys ``model``, ``optimizer_q``, ``optimizer_alpha``.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._optimizer_q.load_state_dict(state_dict['optimizer_q'])
        self._optimizer_alpha.load_state_dict(state_dict['optimizer_alpha'])

    def _init_collect(self) -> None:
        r"""
        Overview:
            Collect mode init method. Called by ``self.__init__``.
            Init traj and unroll length, collect model.
            Use action noise for exploration.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        self._collect_model = model_wrap(self._model, wrapper_name='base')
        self._collect_model.reset()

    def _forward_collect(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of collect mode.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs.
        ReturnsKeys
            - necessary: ``action``
            - optional: ``logit``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()
        with torch.no_grad():
            # start with random action for better exploration
            output = self._collect_model.forward(data)
            _decay = self._cfg.other.eps.decay
            # Probability of sampling from the softmax policy (vs. uniform random): grows as
            # training progresses, saturating at 0.999 near the end of the decay window.
            _act_p = 1 / \
                (_decay - self._forward_learn_cnt) if self._forward_learn_cnt < _decay - 1000 else 0.999

            if np.random.random(1) < _act_p:
                if isinstance(self._action_shape, int):
                    # Subtract the row max before softmax for numerical stability.
                    logits = output['logit'] / math.exp(self._log_alpha.item())
                    prob = torch.softmax(logits - logits.max(axis=-1, keepdim=True).values, dim=-1)
                    pi_action = torch.multinomial(prob, 1)
                else:
                    logits = [_logit / math.exp(self._log_alpha.item()) for _logit in output['logit']]
                    prob = [
                        torch.softmax(_logits - _logits.max(axis=-1, keepdim=True).values, dim=-1) for _logits in logits
                    ]
                    pi_action = [torch.multinomial(_prob, 1) for _prob in prob]
            else:
                # Uniform random action(s) for exploration.
                if isinstance(self._action_shape, int):
                    pi_action = torch.randint(0, self._action_shape, (output["logit"].shape[0], ))
                else:
                    pi_action = [torch.randint(0, d, (output["logit"][0].shape[0], )) for d in self._action_shape]

            output['action'] = pi_action
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of collect model, including at least ['action']
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
                (here 'obs' indicates obs after env step, i.e. next_obs).
        Return:
            - transition (:obj:`Dict[str, Any]`): Dict type transition data.
        """
        transition = {
            'obs': obs,
            'next_obs': timestep.obs,
            'action': model_output['action'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return transition

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        """
        Overview:
            Cut the collected trajectory into training samples of length ``unroll_len``.
        Arguments:
            - data (:obj:`list`): Collected transitions.
        Returns:
            - samples (:obj:`Union[None, List[Any]]`): Training samples.
        """
        return get_train_sample(data, self._unroll_len)

    def _init_eval(self) -> None:
        r"""
        Overview:
            Evaluate mode init method. Called by ``self.__init__``.
            Init eval model, which use argmax for selecting action
        """
        self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._eval_model.reset()

    def _forward_eval(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of eval mode, similar to ``self._forward_collect``.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
        ReturnsKeys
            - necessary: ``action``
            - optional: ``logit``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return variables' name if variables are to used in monitor.
        Returns:
            - vars (:obj:`List[str]`): Variables' name list.
        """
        return ['alpha_loss', 'alpha', 'entropy', 'q_loss', 'q_value']