ding.policy.fqf

FQFPolicy

Bases: DQNPolicy

Overview

Policy class of FQF (Fully Parameterized Quantile Function) algorithm, proposed in https://arxiv.org/pdf/1911.02140.pdf.
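As a rough illustration of what "fully parameterized" means here, the sketch below works through the quantile-fraction arithmetic from the FQF paper for one state-action pair, in plain Python. The fraction boundaries and quantile values are made up, and the helper names `midpoints` and `expected_q` are hypothetical; this is not the code used by `FQFPolicy` or its model.

```python
def midpoints(taus):
    """tau_hat_i = (tau_i + tau_{i+1}) / 2 for each pair of consecutive fractions."""
    return [(lo + hi) / 2 for lo, hi in zip(taus[:-1], taus[1:])]


def expected_q(taus, quantile_values):
    """Approximate Q as sum_i (tau_{i+1} - tau_i) * F^{-1}(tau_hat_i)."""
    widths = [hi - lo for lo, hi in zip(taus[:-1], taus[1:])]
    return sum(w * q for w, q in zip(widths, quantile_values))


# N + 1 = 4 fraction boundaries give N = 3 fractions (made-up values).
taus = [0.0, 0.1, 0.5, 1.0]
tau_hats = midpoints(taus)

# Made-up quantile values F^{-1}(tau_hat_i) for a single action.
quantile_values = [-1.0, 0.5, 2.0]
q = expected_q(taus, quantile_values)
print(tau_hats, q)
```

The list lengths mirror the shapes noted in the source listing: `quantiles` has `num_quantiles+1` entries per sample and `quantiles_hats` has `num_quantiles`.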

Config

== ==================== ======== ============== ======================================== =======================
ID Symbol               Type     Default Value  Description                              Other(Shape)
== ==================== ======== ============== ======================================== =======================
1  type                 str      fqf            | RL policy register name, refer to      | this arg is optional,
                                                | registry POLICY_REGISTRY               | a placeholder
2  cuda                 bool     False          | Whether to use cuda for network        | this arg can be
                                                                                         | different from modes
3  on_policy            bool     False          | Whether the RL algorithm is on-policy
                                                | or off-policy
4  priority             bool     True           | Whether use priority(PER)              | priority sample,
                                                                                         | update priority
6  | other.eps          float    0.05           | Start value for epsilon decay. It's
   | .start                                     | small because rainbow use noisy net.
7  | other.eps          float    0.05           | End value for epsilon decay.
   | .end
8  | discount_          float    0.97,          | Reward's future discount factor, aka.  | may be 1 when sparse
   | factor                      [0.95, 0.999]  | gamma                                  | reward env
9  nstep                int      3,             | N-step reward discount sum for target
                                 [3, 5]         | q_value estimation
10 | learn.update       int      3              | How many updates(iterations) to train  | this args can be vary
   | per_collect                                | after collector's one collection. Only | from envs. Bigger val
                                                | valid in serial training               | means more off-policy
11 learn.kappa          float    /              | Threshold of Huber loss
== ==================== ======== ============== ======================================== =======================
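To make the table concrete, here is a minimal sketch of overriding a few of these options on top of the defaults. The default values are copied from the `config` dict in the source listing on this page; `merge_config` is a hypothetical helper written for this example and is not DI-engine's own merging utility.

```python
import copy

# Subset of the defaults from FQFPolicy.config in the source listing.
default_cfg = dict(
    type='fqf',
    cuda=False,
    priority=False,
    discount_factor=0.97,
    nstep=1,
    learn=dict(update_per_collect=3, batch_size=64, kappa=1.0),
)


def merge_config(base, override):
    """Hypothetical helper: recursively overlay `override` on a copy of `base`."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = value
    return merged


# Only the fields that differ from the defaults need to be given.
user_cfg = dict(nstep=3, priority=True, learn=dict(kappa=0.5))
cfg = merge_config(default_cfg, user_cfg)
print(cfg['nstep'], cfg['learn']['kappa'], cfg['learn']['batch_size'])  # 3 0.5 64
```

Nested keys such as `learn.kappa` are overridden without discarding their siblings (`batch_size` stays at 64), and the defaults themselves are left untouched.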

default_model()

Overview

Returns the default model configuration used by the FQF algorithm. The __init__ method calls this method automatically to get the default model setting and create the model.

Returns:

  - model_info (:obj:Tuple[str, List[str]]): Tuple containing the registered model name and model's import_names.

compute_grad_norm(model)

Overview

Compute the global L2 norm of the gradients of a network's parameters, that is, the L2 norm of the per-parameter gradient L2 norms.

Arguments:

  - model (:obj:nn.Module): The network to compute grad norm.

Returns:

  - grad_norm (:obj:torch.Tensor): The grad norm of the network's parameters.
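The arithmetic behind this function can be checked without PyTorch. The sketch below uses plain Python lists as stand-ins for per-parameter gradients (the values are made up, and `l2_norm` and `global_grad_norm` are hypothetical names). It shows that taking the L2 norm of the per-parameter L2 norms gives the same result as the L2 norm of all gradient entries flattened together.

```python
import math


def l2_norm(values):
    """Euclidean norm of a flat list of numbers."""
    return math.sqrt(sum(v * v for v in values))


def global_grad_norm(per_param_grads):
    """Norm of per-parameter norms, mirroring the nested torch.norm calls."""
    return l2_norm([l2_norm(g) for g in per_param_grads])


# Two made-up "parameters": one with gradient (3, 4), one with gradient (12,).
grads = [[3.0, 4.0], [12.0]]
total = global_grad_norm(grads)  # sqrt(5**2 + 12**2)

# Same value as the norm over every entry at once.
flat = [v for g in grads for v in g]
print(total, l2_norm(flat))
```

Note that the function in the source listing reads `p.grad` for every parameter, so it appears to assume that a backward pass has already populated all gradients.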

Full Source Code

../ding/policy/fqf.py

import copy
from typing import List, Dict, Any, Tuple

import torch

from ding.model import model_wrap
from ding.rl_utils import fqf_nstep_td_data, fqf_nstep_td_error, fqf_calculate_fraction_loss
from ding.torch_utils import Adam, RMSprop, to_device
from ding.utils import POLICY_REGISTRY
from .common_utils import default_preprocess_learn
from .dqn import DQNPolicy


def compute_grad_norm(model):
    """
    Overview:
        Compute grad norm of a network's parameters.
    Arguments:
        - model (:obj:`nn.Module`): The network to compute grad norm.
    Returns:
        - grad_norm (:obj:`torch.Tensor`): The grad norm of the network's parameters.
    """
    return torch.norm(torch.stack([torch.norm(p.grad.detach(), 2.0) for p in model.parameters()]), 2.0)


@POLICY_REGISTRY.register('fqf')
class FQFPolicy(DQNPolicy):
    """
    Overview:
        Policy class of FQF (Fully Parameterized Quantile Function) algorithm, proposed in
        https://arxiv.org/pdf/1911.02140.pdf.

    Config:
        == ==================== ======== ============== ======================================== =======================
        ID Symbol               Type     Default Value  Description                              Other(Shape)
        == ==================== ======== ============== ======================================== =======================
        1  ``type``             str      fqf            | RL policy register name, refer to      | this arg is optional,
                                                        | registry ``POLICY_REGISTRY``           | a placeholder
        2  ``cuda``             bool     False          | Whether to use cuda for network        | this arg can be diff-
                                                                                                 | erent from modes
        3  ``on_policy``        bool     False          | Whether the RL algorithm is on-policy
                                                        | or off-policy
        4  ``priority``         bool     True           | Whether use priority(PER)              | priority sample,
                                                                                                 | update priority
        6  | ``other.eps``      float    0.05           | Start value for epsilon decay. It's
           | ``.start``                                 | small because rainbow use noisy net.
        7  | ``other.eps``      float    0.05           | End value for epsilon decay.
           | ``.end``
        8  | ``discount_``      float    0.97,          | Reward's future discount factor, aka.  | may be 1 when sparse
           | ``factor``                  [0.95, 0.999]  | gamma                                  | reward env
        9  ``nstep``            int      3,             | N-step reward discount sum for target
                                         [3, 5]         | q_value estimation
        10 | ``learn.update``   int      3              | How many updates(iterations) to train  | this args can be vary
           | ``per_collect``                            | after collector's one collection. Only | from envs. Bigger val
                                                        | valid in serial training               | means more off-policy
        11 ``learn.kappa``      float    /              | Threshold of Huber loss
        == ==================== ======== ============== ======================================== =======================
    """

    config = dict(
        # (str) Name of the RL policy registered in "POLICY_REGISTRY" function.
        type='fqf',
        # (bool) Flag to enable/disable CUDA for network computation.
        cuda=False,
        # (bool) Indicator of the RL algorithm's policy type (True for on-policy algorithms).
        on_policy=False,
        # (bool) Toggle for using prioritized experience replay (priority sampling and updating).
        priority=False,
        # (float) Discount factor (gamma) for calculating the future reward.
        discount_factor=0.97,
        # (int) Number of steps to consider for calculating n-step returns.
        nstep=1,
        learn=dict(
            # (int) Number of training iterations per data collection from the environment.
            update_per_collect=3,
            # (int) Size of minibatch for each update.
            batch_size=64,
            # (float) Fractional learning rate for the fraction proposal network.
            learning_rate_fraction=2.5e-9,
            # (float) Learning rate for the quantile regression network.
            learning_rate_quantile=0.00005,
            # ==============================================================
            # Algorithm-specific configurations
            # ==============================================================
            # (int) Frequency of target network updates.
            target_update_freq=100,
            # (float) Huber loss threshold (kappa in the FQF paper).
            kappa=1.0,
            # (float) Coefficient for the entropy loss term.
            ent_coef=0,
            # (bool) If set to True, the 'done' signals that indicate the end of an episode due to environment time
            # limits are disregarded. By default, this is set to False. This setting is particularly useful for tasks
            # that have a predetermined episode length, such as HalfCheetah and various other MuJoCo environments,
            # where the maximum length is capped at 1000 steps. When enabled, any 'done' signal triggered by reaching
            # the maximum episode steps will be overridden to 'False'. This ensures the accurate calculation of the
            # Temporal Difference (TD) error, using the formula `gamma * (1 - done) * next_v + reward`,
            # even when the episode surpasses the predefined step limit.
            ignore_done=False,
        ),
        collect=dict(
            # (int) Specify one of [n_sample, n_step, n_episode] for data collection.
            # n_sample=8,
            # (int) Length of trajectory segments for processing.
            unroll_len=1,
        ),
        eval=dict(),
        other=dict(
            # Epsilon-greedy strategy with a decay mechanism.
            eps=dict(
                # (str) Type of decay mechanism ['exp' for exponential, 'linear'].
                type='exp',
                # (float) Initial value of epsilon in epsilon-greedy exploration.
                start=0.95,
                # (float) Final value of epsilon after decay.
                end=0.1,
                # (int) Number of environment steps over which epsilon is decayed.
                decay=10000,
            ),
            replay_buffer=dict(
                # (int) Size of the replay buffer.
                replay_buffer_size=10000,
            ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Returns the default model configuration used by the FQF algorithm. ``__init__`` method will \
            automatically call this method to get the default model setting and create model.

        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): \
                Tuple containing the registered model name and model's import_names.
        """
        return 'fqf', ['ding.model.template.q_learning']

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of policy, including related attributes and modules. For FQF, it mainly \
            contains optimizer, algorithm-specific arguments such as gamma, nstep, kappa ent_coef, main and \
            target model. This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \
            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
        """
        self._priority = self._cfg.priority
        # Optimizer
        self._fraction_loss_optimizer = RMSprop(
            self._model.head.quantiles_proposal.parameters(),
            lr=self._cfg.learn.learning_rate_fraction,
            alpha=0.95,
            eps=0.00001
        )
        self._quantile_loss_optimizer = Adam(
            list(self._model.head.Q.parameters()) + list(self._model.head.fqf_fc.parameters()) +
            list(self._model.encoder.parameters()),
            lr=self._cfg.learn.learning_rate_quantile,
            eps=1e-2 / self._cfg.learn.batch_size
        )

        self._gamma = self._cfg.discount_factor
        self._nstep = self._cfg.nstep
        self._kappa = self._cfg.learn.kappa
        self._ent_coef = self._cfg.learn.ent_coef

        # use model_wrapper for specialized demands of different modes
        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='assign',
            update_kwargs={'freq': self._cfg.learn.target_update_freq}
        )
        self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._learn_model.reset()
        self._target_model.reset()

    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode (training policy and updating parameters). Forward means \
            that the policy inputs some training batch data from the replay buffer and then returns the output \
            result, including various training information such as policy_loss, value_loss, entropy_loss.
        Arguments:
            - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
                training samples. For each element in list, the key of the dict is the name of data items and the \
                value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \
                combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \
                dimension by some utility functions such as ``default_preprocess_learn``. \
                For FQF, each element in list is a dict containing at least the following keys: \
                ['obs', 'action', 'reward', 'next_obs'].
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \
                recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
                detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
            For the data type that not supported, the main reason is that the corresponding model does not support it. \
            You can implement your own model rather than use the default model. For more information, please raise an \
            issue in GitHub repo and we will continue to follow up.
        """
        # Data preprocessing operations, such as stack data, cpu to cuda device
        data = default_preprocess_learn(
            data, use_priority=self._priority, ignore_done=self._cfg.learn.ignore_done, use_nstep=True
        )
        if self._cuda:
            data = to_device(data, self._device)
        # ====================
        # Q-learning forward
        # ====================
        self._learn_model.train()
        self._target_model.train()
        # Current q value (main model)
        ret = self._learn_model.forward(data['obs'])
        logit = ret['logit']  # [batch, action_dim(64)]
        q_value = ret['q']  # [batch, num_quantiles, action_dim(64)]
        quantiles = ret['quantiles']  # [batch, num_quantiles+1]
        quantiles_hats = ret['quantiles_hats']  # [batch, num_quantiles], requires_grad = False
        q_tau_i = ret['q_tau_i']  # [batch_size, num_quantiles-1, action_dim(64)]
        entropies = ret['entropies']  # [batch, 1]

        # Target q value
        with torch.no_grad():
            target_q_value = self._target_model.forward(data['next_obs'])['q']
            # Max q value action (main model)
            target_q_action = self._learn_model.forward(data['next_obs'])['action']

        data_n = fqf_nstep_td_data(
            q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], quantiles_hats,
            data['weight']
        )
        value_gamma = data.get('value_gamma')
        entropy_loss = -self._ent_coef * entropies.mean()
        fraction_loss = fqf_calculate_fraction_loss(q_tau_i.detach(), q_value, quantiles, data['action']) + entropy_loss
        quantile_loss, td_error_per_sample = fqf_nstep_td_error(
            data_n, self._gamma, nstep=self._nstep, kappa=self._kappa, value_gamma=value_gamma
        )

        # ====================
        # fraction_proposal network update
        # ====================
        self._fraction_loss_optimizer.zero_grad()
        fraction_loss.backward(retain_graph=True)
        if self._cfg.multi_gpu:
            self.sync_gradients(self._learn_model)
        with torch.no_grad():
            total_norm_quantiles_proposal = compute_grad_norm(self._model.head.quantiles_proposal)
        self._fraction_loss_optimizer.step()

        # ====================
        # Q-learning update
        # ====================
        self._quantile_loss_optimizer.zero_grad()
        quantile_loss.backward()
        if self._cfg.multi_gpu:
            self.sync_gradients(self._learn_model)
        with torch.no_grad():
            total_norm_Q = compute_grad_norm(self._model.head.Q)
            total_norm_fqf_fc = compute_grad_norm(self._model.head.fqf_fc)
            total_norm_encoder = compute_grad_norm(self._model.encoder)
        self._quantile_loss_optimizer.step()

        # =============
        # after update
        # =============
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr_fraction_loss': self._fraction_loss_optimizer.defaults['lr'],
            'cur_lr_quantile_loss': self._quantile_loss_optimizer.defaults['lr'],
            'logit': logit.mean().item(),
            'fraction_loss': fraction_loss.item(),
            'quantile_loss': quantile_loss.item(),
            'total_norm_quantiles_proposal': total_norm_quantiles_proposal,
            'total_norm_Q': total_norm_Q,
            'total_norm_fqf_fc': total_norm_fqf_fc,
            'total_norm_encoder': total_norm_encoder,
            'priority': td_error_per_sample.abs().tolist(),
            # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
            '[histogram]action_distribution': data['action'],
            '[histogram]quantiles_hats': quantiles_hats[0],  # quantiles_hats.requires_grad = False
        }

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
            as text logger, tensorboard logger, will use these keys to save the corresponding data.
        Returns:
            - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
        """
        return [
            'cur_lr_fraction_loss', 'cur_lr_quantile_loss', 'logit', 'fraction_loss', 'quantile_loss',
            'total_norm_quantiles_proposal', 'total_norm_Q', 'total_norm_fqf_fc', 'total_norm_encoder'
        ]

    def _state_dict_learn(self) -> Dict[str, Any]:
        """
        Overview:
            Return the state_dict of learn mode, usually including model and optimizer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring.
        """
        return {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer_fraction_loss': self._fraction_loss_optimizer.state_dict(),
            'optimizer_quantile_loss': self._quantile_loss_optimizer.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load the state_dict variable into policy learn mode.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.

        .. tip::
            If you want to only load some parts of model, you can simply set the ``strict`` argument in \
            load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
            complicated operation.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._fraction_loss_optimizer.load_state_dict(state_dict['optimizer_fraction_loss'])
        self._quantile_loss_optimizer.load_state_dict(state_dict['optimizer_quantile_loss'])
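The listing passes `kappa` to `fqf_nstep_td_error` as the Huber loss threshold. As a rough illustration of how such a threshold usually enters a quantile regression loss, the sketch below implements the standard quantile Huber loss from the QR-DQN line of work for a single TD error, in plain Python. It is an assumption that `fqf_nstep_td_error` follows this form; the function names here are hypothetical and the inputs are made up.

```python
def huber(u, kappa):
    """Huber loss: quadratic for |u| <= kappa, linear beyond it."""
    if abs(u) <= kappa:
        return 0.5 * u * u
    return kappa * (abs(u) - 0.5 * kappa)


def quantile_huber(u, tau, kappa=1.0):
    """Quantile Huber loss |tau - 1{u < 0}| * L_kappa(u) / kappa.

    u is the TD error (target minus prediction) and tau is the quantile
    fraction attached to the prediction.
    """
    indicator = 1.0 if u < 0 else 0.0
    return abs(tau - indicator) * huber(u, kappa) / kappa


# Under-estimation (u > 0) at a low fraction is penalised lightly ...
small = quantile_huber(0.5, tau=0.25, kappa=1.0)
# ... while over-estimation (u < 0) at the same fraction is penalised heavily.
large = quantile_huber(-2.0, tau=0.25, kappa=1.0)
print(small, large)
```

The asymmetric weight `|tau - 1{u < 0}|` is what pushes each output toward its own quantile, while `kappa` only controls where the loss switches from quadratic to linear.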