Skip to content

ding.policy.sql

ding.policy.sql

SQLPolicy

Bases: Policy

Overview

Policy class of SQL algorithm.

default_model()

Overview

Return this algorithm default model setting for demonstration.

Returns: - model_info (:obj:Tuple[str, List[str]]): model name and model import_names

.. note:: The user can define and use a customized network model, but it must obey the same interface definition indicated by the import_names path. For DQN, ding.model.template.q_learning.DQN

Full Source Code

../ding/policy/sql.py

from typing import List, Dict, Any, Tuple, Union, Optional
from collections import namedtuple, deque
import copy
import torch
from torch.distributions import Categorical
from ditk import logging
from easydict import EasyDict
from ding.torch_utils import Adam, to_device
from ding.utils.data import default_collate, default_decollate
from ding.rl_utils import q_nstep_td_data, q_nstep_sql_td_error, get_nstep_return_data, get_train_sample
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from .base_policy import Policy
from .common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('sql')
class SQLPolicy(Policy):
    r"""
    Overview:
        Policy class of SQL (Soft Q-Learning) algorithm: n-step soft Q-learning
        with a target network and epsilon-greedy multinomial sampling for collection.
    """

    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='sql',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool) Whether the RL algorithm is on-policy or off-policy.
        on_policy=False,
        # (bool) Whether use priority(priority sample, IS weight, update priority)
        priority=False,
        # (float) Reward's future discount factor, aka. gamma.
        discount_factor=0.97,
        # (int) N-step reward for target q_value estimation
        nstep=1,
        learn=dict(
            # How many updates(iterations) to train after collector's one collection.
            # Bigger "update_per_collect" means bigger off-policy.
            # collect data -> update policy-> collect data -> ...
            update_per_collect=3,  # after the batch data come into the learner, train with the data for 3 times
            batch_size=64,
            learning_rate=0.001,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (int) Frequency of target network update.
            target_update_freq=100,
            # (bool) Whether ignore done(usually for max step termination env)
            ignore_done=False,
            # (float) Entropy temperature of the soft Q-learning objective.
            alpha=0.1,
        ),
        # collect_mode config
        collect=dict(
            # (int) Only one of [n_sample, n_episode] should be set
            # n_sample=8,  # collect 8 samples and put them in collector
            # (int) Cut trajectories into pieces with length "unroll_len".
            unroll_len=1,
        ),
        eval=dict(),
        # other config
        other=dict(
            # Epsilon greedy with decay.
            eps=dict(
                # (str) Decay type. Support ['exp', 'linear'].
                type='exp',
                start=0.95,
                end=0.1,
                # (int) Decay length(env step)
                decay=10000,
            ),
            replay_buffer=dict(replay_buffer_size=10000, ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return this algorithm default model setting for demonstration.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names

        .. note::
            The user can define and use customized network model but must obey the same interface definition \
            indicated by import_names path. For DQN, ``ding.model.template.q_learning.DQN``
        """
        return 'dqn', ['ding.model.template.q_learning']

    def _init_learn(self) -> None:
        r"""
        Overview:
            Learn mode init method. Called by ``self.__init__``.
            Init the optimizer, algorithm config, main and target models.
        """
        self._priority = self._cfg.priority
        # Optimizer
        self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
        self._gamma = self._cfg.discount_factor
        self._nstep = self._cfg.nstep
        self._alpha = self._cfg.learn.alpha
        # use wrapper instead of plugin
        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='assign',
            update_kwargs={'freq': self._cfg.learn.target_update_freq}
        )
        self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._learn_model.reset()
        self._target_model.reset()

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Forward and backward function of learn mode.
        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss.
        """
        data = default_preprocess_learn(
            data, use_priority=self._priority, ignore_done=self._cfg.learn.ignore_done, use_nstep=True
        )
        if self._cuda:
            data = to_device(data, self._device)
        # ====================
        # Q-learning forward
        # ====================
        self._learn_model.train()
        self._target_model.train()
        # Current q value (main model)
        q_value = self._learn_model.forward(data['obs'])['logit']
        with torch.no_grad():
            # Target q value
            target_q_value = self._target_model.forward(data['next_obs'])['logit']
            # Max q value action (main model)
            target_q_action = self._learn_model.forward(data['next_obs'])['action']

        data_n = q_nstep_td_data(
            q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], data['weight']
        )
        value_gamma = data.get('value_gamma')
        # Use the cached ``self._alpha`` set in ``_init_learn`` (was re-read from cfg before).
        loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(
            data_n, self._gamma, self._alpha, nstep=self._nstep, value_gamma=value_gamma
        )
        # Convert to a Python float for logging, consistent with ``loss.item()`` below.
        record_target_v = record_target_v.mean().item()
        # ====================
        # Q-learning update
        # ====================
        self._optimizer.zero_grad()
        loss.backward()
        if self._cfg.multi_gpu:
            self.sync_gradients(self._learn_model)
        self._optimizer.step()

        # =============
        # after update
        # =============
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': loss.item(),
            'priority': td_error_per_sample.abs().tolist(),
            'record_value_function': record_target_v,
            # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
            # '[histogram]action_distribution': data['action'],
        }

    def _state_dict_learn(self) -> Dict[str, Any]:
        """
        Overview:
            Return the state dict of learn mode (model, target model and optimizer), for saving checkpoints.
        """
        return {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer': self._optimizer.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load the state dict produced by ``_state_dict_learn`` into learn mode components.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): Dict with 'model', 'target_model' and 'optimizer' keys.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._optimizer.load_state_dict(state_dict['optimizer'])

    def _init_collect(self) -> None:
        r"""
        Overview:
            Collect mode init method. Called by ``self.__init__``.
            Init traj and unroll length, collect model.
            Enable the eps_greedy_sample
        """
        self._unroll_len = self._cfg.collect.unroll_len
        self._gamma = self._cfg.discount_factor  # necessary for parallel
        self._nstep = self._cfg.nstep  # necessary for parallel
        self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_multinomial_sample')
        self._collect_model.reset()

    def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
        r"""
        Overview:
            Forward function for collect mode with eps_greedy
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
            - eps (:obj:`float`): epsilon value for exploration, which is decayed by collected env step.
        Returns:
            - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs.
        ReturnsKeys
            - necessary: ``action``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()
        with torch.no_grad():
            # Read alpha from cfg here (not ``self._alpha``) because collect mode
            # may be initialized without learn mode in parallel settings.
            output = self._collect_model.forward(data, eps=eps, alpha=self._cfg.learn.alpha)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Overview:
            For a given trajectory(transitions, a list of transition) data, process it into a list of sample that \
            can be used for training directly. A train sample can be a processed transition(DQN with nstep TD) \
            or some continuous transitions(DRQN).
        Arguments:
            - data (:obj:`List[Dict[str, Any]`): The trajectory data(a list of transition), each element is the same \
                format as the return value of ``self._process_transition`` method.
        Returns:
            - samples (:obj:`dict`): The list of training samples.

        .. note::
            We will vectorize ``process_transition`` and ``get_train_sample`` method in the following release version. \
            And the user can customize this data processing procedure by overriding these two methods and collector \
            itself.
        """
        data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)
        return get_train_sample(data, self._unroll_len)

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of collect model, including at least ['action']
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
                (here 'obs' indicates obs after env step).
        Returns:
            - transition (:obj:`dict`): Dict type transition data.
        """
        transition = {
            'obs': obs,
            'next_obs': timestep.obs,
            'action': model_output['action'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return transition

    def _init_eval(self) -> None:
        r"""
        Overview:
            Evaluate mode init method. Called by ``self.__init__``.
            Init eval model with argmax strategy.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._eval_model.reset()

    def _forward_eval(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of eval mode, similar to ``self._forward_collect``.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
        ReturnsKeys
            - necessary: ``action``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the variable names monitored in learn mode, extending the base list
            with the soft value function recorded by ``_forward_learn``.
        """
        return super()._monitor_vars_learn() + ['record_value_function']