from typing import List, Dict, Any, Tuple
from collections import namedtuple
import copy
import torch
from torch.optim import AdamW
from ding.torch_utils import Adam, to_device
from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, get_nstep_return_data, get_train_sample, \
    dqfd_nstep_td_error, dqfd_nstep_td_data
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .dqn import DQNPolicy
from .common_utils import default_preprocess_learn
from copy import deepcopy


@POLICY_REGISTRY.register('dqfd')
class DQFDPolicy(DQNPolicy):
    r"""
    Overview:
        Policy class of DQFD (Deep Q-learning from Demonstrations) algorithm, extended by \
        Double DQN/Dueling DQN/PER/multi-step TD.

    Config:
        == ==================== ======== ============== ======================================== =======================
        ID Symbol               Type     Default Value  Description                              Other(Shape)
        == ==================== ======== ============== ======================================== =======================
        1  ``type``             str      dqfd           | RL policy register name, refer to      | This arg is optional,
                                                        | registry ``POLICY_REGISTRY``           | a placeholder
        2  ``cuda``             bool     False          | Whether to use cuda for network        | This arg can be diff-
                                                                                                 | erent from modes
        3  ``on_policy``        bool     False          | Whether the RL algorithm is on-policy
                                                        | or off-policy
        4  ``priority``         bool     True           | Whether use priority(PER)              | Priority sample,
                                                                                                 | update priority
        5  | ``priority_IS``    bool     True           | Whether use Importance Sampling Weight
           | ``_weight``                                | to correct biased update. If True,
                                                        | priority must be True.
        6  | ``discount_``      float    0.97,          | Reward's future discount factor, aka.  | May be 1 when sparse
           | ``factor``                  [0.95, 0.999]  | gamma                                  | reward env
        7  ``nstep``            int      10,            | N-step reward discount sum for target
                                         [3, 5]         | q_value estimation
        8  | ``lambda1``        float    1              | multiplicative factor for n-step
        9  | ``lambda2``        float    1              | multiplicative factor for the
                                                        | supervised margin loss
        10 | ``lambda3``        float    1e-5           | L2 loss
        11 | ``margin_fn``      float    0.8            | margin function in JE, here we set
                                                        | this as a constant
        12 | ``per_train_``     int      10             | number of pretraining iterations
           | ``iter_k``
        13 | ``learn.update``   int      3              | How many updates(iterations) to train  | This args can be vary
           | ``per_collect``                            | after collector's one collection. Only | from envs. Bigger val
                                                        | valid in serial training               | means more off-policy
        14 | ``learn.batch_``   int      64             | The number of samples of an iteration
           | ``size``
        15 | ``learn.learning`` float    0.001          | Gradient step length of an iteration.
           | ``_rate``
        16 | ``learn.target_``  int      100            | Frequency of target network update.    | Hard(assign) update
           | ``update_freq``
        17 | ``learn.ignore_``  bool     False          | Whether ignore done for target value   | Enable it for some
           | ``done``                                   | calculation.                           | fake termination env
        18 ``collect.n_sample`` int      [8, 128]       | The number of training samples of a    | It varies from
                                                        | call of collector.                     | different envs
        19 | ``collect.unroll`` int      1              | unroll length of an iteration          | In RNN, unroll_len>1
           | ``_len``
        == ==================== ======== ============== ======================================== =======================
    """

    config = dict(
        type='dqfd',
        cuda=False,
        on_policy=False,
        priority=True,
        # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=True,
        discount_factor=0.99,
        nstep=10,
        learn=dict(
            # multiplicative factor for each loss
            lambda1=1.0,  # n-step return
            lambda2=1.0,  # supervised loss
            lambda3=1e-5,  # L2
            # margin function in JE, here we implement this as a constant
            margin_function=0.8,
            # number of pretraining iterations
            per_train_iter_k=10,
            # How many updates(iterations) to train after collector's one collection.
            # Bigger "update_per_collect" means bigger off-policy.
            # collect data -> update policy-> collect data -> ...
            update_per_collect=3,
            batch_size=64,
            learning_rate=0.001,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (int) Frequency of target network update.
            target_update_freq=100,
            # (bool) Whether ignore done(usually for max step termination env)
            ignore_done=False,
        ),
        # collect_mode config
        collect=dict(
            # (int) Only one of [n_sample, n_episode] should be set
            # n_sample=8,
            # (int) Cut trajectories into pieces with length "unroll_len".
            unroll_len=1,
            # The hyperparameter pho (i.e. rho), the demo ratio, controls the proportion of data
            # coming from expert demonstrations versus from the agent's own experience.
            pho=0.5,
        ),
        eval=dict(),
        # other config
        other=dict(
            # Epsilon greedy with decay.
            eps=dict(
                # (str) Decay type. Support ['exp', 'linear'].
                type='exp',
                start=0.95,
                end=0.1,
                # (int) Decay length(env step)
                decay=10000,
            ),
            replay_buffer=dict(replay_buffer_size=10000, ),
        ),
    )

    def _init_learn(self) -> None:
        """
        Overview:
            Learn mode init method. Called by ``self.__init__``, initialize the optimizer, algorithm arguments, main \
            and target model.
        """
        self.lambda1 = self._cfg.learn.lambda1  # weight of the n-step return loss
        self.lambda2 = self._cfg.learn.lambda2  # weight of the supervised margin loss
        self.lambda3 = self._cfg.learn.lambda3  # L2 regularization strength
        # margin function in JE, here we implement this as a constant
        self.margin_function = self._cfg.learn.margin_function
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        # Optimizer: AdamW's decoupled weight decay implements the L2 term (lambda3) of the DQFD loss.
        # Empirically AdamW performs better than Adam here, so AdamW is used.
        self._optimizer = AdamW(self._model.parameters(), lr=self._cfg.learn.learning_rate, weight_decay=self.lambda3)

        self._gamma = self._cfg.discount_factor
        self._nstep = self._cfg.nstep

        # use model_wrapper for specialized demands of different modes
        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='assign',
            update_kwargs={'freq': self._cfg.learn.target_update_freq}
        )
        self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._learn_model.reset()
        self._target_model.reset()

    def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Overview:
            Forward computation graph of learn mode(updating policy).
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
                np.ndarray or dict/list combinations.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \
                recorded in text log and tensorboard, values are python scalar or a list of scalars.
        ArgumentsKeys:
            - necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done``
            - optional: ``value_gamma``, ``IS``
        ReturnsKeys:
            - necessary: ``cur_lr``, ``total_loss``, ``priority``
            - optional: ``action_distribution``
        """
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=True
        )
        # one-step done flag must be float for the JE computation below
        data['done_1'] = data['done_1'].float()
        if self._cuda:
            data = to_device(data, self._device)
        # ====================
        # Q-learning forward
        # ====================
        self._learn_model.train()
        self._target_model.train()
        # Current q value (main model)
        q_value = self._learn_model.forward(data['obs'])['logit']
        # Target q value
        with torch.no_grad():
            target_q_value = self._target_model.forward(data['next_obs'])['logit']
            target_q_value_one_step = self._target_model.forward(data['next_obs_1'])['logit']
            # Max q value action (main model)
            target_q_action = self._learn_model.forward(data['next_obs'])['action']
            target_q_action_one_step = self._learn_model.forward(data['next_obs_1'])['action']

        # modify the tensor type to match the JE computation in dqfd_nstep_td_error
        is_expert = data['is_expert'].float()
        data_n = dqfd_nstep_td_data(
            q_value,
            target_q_value,
            data['action'],
            target_q_action,
            data['reward'],
            data['done'],
            data['done_1'],
            data['weight'],
            target_q_value_one_step,
            target_q_action_one_step,
            is_expert  # set is_expert flag(expert 1, agent 0)
        )
        value_gamma = data.get('value_gamma')
        loss, td_error_per_sample, loss_statistics = dqfd_nstep_td_error(
            data_n,
            self._gamma,
            self.lambda1,
            self.lambda2,
            self.margin_function,
            nstep=self._nstep,
            value_gamma=value_gamma
        )

        # ====================
        # Q-learning update
        # ====================
        self._optimizer.zero_grad()
        loss.backward()
        if self._cfg.multi_gpu:
            self.sync_gradients(self._learn_model)
        self._optimizer.step()

        # =============
        # after update
        # =============
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': loss.item(),
            'priority': td_error_per_sample.abs().tolist(),
            # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
            # '[histogram]action_distribution': data['action'],
        }

    def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Overview:
            For a given trajectory(transitions, a list of transition) data, process it into a list of sample that \
            can be used for training directly. A train sample can be a processed transition(DQN with nstep TD) \
            or some continuous transitions(DRQN).
        Arguments:
            - data (:obj:`List[Dict[str, Any]`): The trajectory data(a list of transition), each element is the same \
                format as the return value of ``self._process_transition`` method.
        Returns:
            - samples (:obj:`dict`): The list of training samples.

        .. note::
            We will vectorize ``process_transition`` and ``get_train_sample`` method in the following release version. \
            And the user can customize the this data processing procedure by overriding this two methods and collector \
            itself.
        """
        # one-step returns are computed (and deep-copied) first, so the subsequent n-step
        # computation cannot alias/overwrite the one-step next_obs/done fields
        data_1 = deepcopy(get_nstep_return_data(data, 1, gamma=self._gamma))
        data = get_nstep_return_data(
            data, self._nstep, gamma=self._gamma
        )  # here we want to include one-step next observation
        for i in range(len(data)):
            data[i]['next_obs_1'] = data_1[i]['next_obs']  # concat the one-step next observation
            data[i]['done_1'] = data_1[i]['done']
        return get_train_sample(data, self._unroll_len)