Skip to content

ding.policy.bc

ding.policy.bc

BehaviourCloningPolicy

Bases: Policy

Overview

Behaviour Cloning (BC) policy class, which supports both discrete and continuous action space. The policy is trained by supervised learning, and the data is an offline dataset collected by an expert.

default_model()

Overview

Return this algorithm default neural network model setting for demonstration. __init__ method will automatically call this method to get the default model setting and create model.

Returns: - model_info (:obj:Tuple[str, List[str]]): The registered model name and model's import_names.

.. note:: The user can define and use a customized network model but must obey the same interface definition indicated by import_names path. For example about discrete BC, its registered name is discrete_bc and the import_names is ding.model.template.bc.

Full Source Code

../ding/policy/bc.py

import math
import torch
import torch.nn as nn
import copy
from torch.optim import Adam, SGD, AdamW
from torch.optim.lr_scheduler import LambdaLR
import logging
from typing import List, Dict, Any, Tuple, Union, Optional
from collections import namedtuple, defaultdict
from easydict import EasyDict
from ding.policy import Policy
from ding.model import model_wrap
from ding.torch_utils import to_device, to_list
from ding.utils import EasyTimer
from ding.utils.data import default_collate, default_decollate
from ding.rl_utils import get_nstep_return_data, get_train_sample
from ding.utils import POLICY_REGISTRY
from ding.torch_utils.loss.cross_entropy_loss import LabelSmoothCELoss


@POLICY_REGISTRY.register('bc')
class BehaviourCloningPolicy(Policy):
    """
    Overview:
        Behaviour Cloning (BC) policy class, which supports both discrete and continuous action space. \
        The policy is trained by supervised learning, and the data is an offline dataset collected by an expert.
    """

    config = dict(
        type='bc',
        cuda=False,
        on_policy=False,
        continuous=False,
        action_shape=19,
        learn=dict(
            update_per_collect=1,
            batch_size=32,
            learning_rate=1e-5,
            lr_decay=False,
            decay_epoch=30,
            decay_rate=0.1,
            warmup_lr=1e-4,
            warmup_epoch=3,
            optimizer='SGD',
            momentum=0.9,
            weight_decay=1e-4,
            ce_label_smooth=False,
            show_accuracy=False,
            tanh_mask=False,  # if actions always converge to 1 or -1, use this.
        ),
        collect=dict(
            unroll_len=1,
            noise=False,
            noise_sigma=0.2,
            noise_range=dict(
                min=-0.5,
                max=0.5,
            ),
        ),
        eval=dict(),  # for compatibility
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \
            automatically call this method to get the default model setting and create model.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.

        .. note::
            The user can define and use a customized network model but must obey the same interface definition \
            indicated by import_names path. For example about discrete BC, its registered name is ``discrete_bc`` \
            and the import_names is ``ding.model.template.bc``.
        """
        if self._cfg.continuous:
            return 'continuous_bc', ['ding.model.template.bc']
        else:
            return 'discrete_bc', ['ding.model.template.bc']

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of policy, including related attributes and modules. For BC, it mainly \
            contains optimizer, algorithm-specific arguments such as lr_scheduler, loss, etc. \
            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
        """
        assert self._cfg.learn.optimizer in ['SGD', 'Adam'], self._cfg.learn.optimizer
        if self._cfg.learn.optimizer == 'SGD':
            self._optimizer = SGD(
                self._model.parameters(),
                lr=self._cfg.learn.learning_rate,
                weight_decay=self._cfg.learn.weight_decay,
                momentum=self._cfg.learn.momentum
            )
        elif self._cfg.learn.optimizer == 'Adam':
            # A ``None`` weight_decay selects plain Adam; otherwise AdamW applies decoupled weight decay.
            if self._cfg.learn.weight_decay is None:
                self._optimizer = Adam(
                    self._model.parameters(),
                    lr=self._cfg.learn.learning_rate,
                )
            else:
                self._optimizer = AdamW(
                    self._model.parameters(),
                    lr=self._cfg.learn.learning_rate,
                    weight_decay=self._cfg.learn.weight_decay
                )
        if self._cfg.learn.lr_decay:

            def lr_scheduler_fn(epoch):
                # LambdaLR multiplies the base lr by the returned factor, so the warmup factor
                # is expressed as warmup_lr / learning_rate; afterwards decay by decay_rate
                # every decay_epoch epochs.
                if epoch <= self._cfg.learn.warmup_epoch:
                    return self._cfg.learn.warmup_lr / self._cfg.learn.learning_rate
                else:
                    ratio = (epoch - self._cfg.learn.warmup_epoch) // self._cfg.learn.decay_epoch
                    return math.pow(self._cfg.learn.decay_rate, ratio)

            self._lr_scheduler = LambdaLR(self._optimizer, lr_scheduler_fn)
        self._timer = EasyTimer(cuda=True)
        self._learn_model = model_wrap(self._model, 'base')
        self._learn_model.reset()

        if self._cfg.learn.show_accuracy:
            # Fix: these containers were previously referenced in ``_forward_learn`` but never
            # initialized, raising AttributeError on the first training iteration.
            self.total_accuracy_in_dataset = []
            self.action_accuracy_in_dataset = defaultdict(list)

        if self._cfg.continuous:
            # NOTE(review): ``loss_type`` is not present in the class default config, so continuous
            # BC requires the user config to set it explicitly ('l1_loss' or 'mse_loss').
            if self._cfg.loss_type == 'l1_loss':
                self._loss = nn.L1Loss()
            elif self._cfg.loss_type == 'mse_loss':
                self._loss = nn.MSELoss()
            else:
                raise KeyError("not support loss type: {}".format(self._cfg.loss_type))
        else:
            if not self._cfg.learn.ce_label_smooth:
                self._loss = nn.CrossEntropyLoss()
            else:
                self._loss = LabelSmoothCELoss(0.1)

    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode (training policy and updating parameters). Forward means \
            that the policy inputs some training batch data from the replay buffer and then returns the output \
            result, including various training information such as loss and time.
        Arguments:
            - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
                training samples. For each element in list, the key of the dict is the name of data items and the \
                value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
                combinations. For BC, each element in list is a dict containing at least ``obs`` and ``action``.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \
                recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
                detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
            For the data type that not supported, the main reason is that the corresponding model does not support \
            it. You can implement your own model rather than use the default model. For more information, please \
            raise an issue in GitHub repo and we will continue to follow up.
        """
        if isinstance(data, list):
            data = default_collate(data)
        if self._cuda:
            data = to_device(data, self._device)
        self._learn_model.train()
        with self._timer:
            obs, action = data['obs'], data['action'].squeeze()
            if self._cfg.continuous:
                if self._cfg.learn.tanh_mask:
                    """tanh_mask
                    We mask the action out of range of [tanh(-1),tanh(1)], model will learn information
                    and produce action in [-1,1]. So the action won't always converge to -1 or 1.
                    """
                    # Fix: learn mode must use ``self._learn_model`` — ``self._eval_model`` only
                    # exists when eval mode is initialized, and gradients should flow through the
                    # learn wrapper as in every other training path.
                    mu = self._learn_model.forward(data['obs'])['action']
                    bound = 1 - 2 / (math.exp(2) + 1)  # tanh(1): (e-e**(-1))/(e+e**(-1))
                    mask = mu.ge(-bound) & mu.le(bound)
                    mask_percent = 1 - mask.sum().item() / mu.numel()
                    if mask_percent > 0.8:  # if there is too little data to learn(<80%). So we use all data.
                        loss = self._loss(mu, action.detach())
                    else:
                        loss = self._loss(mu.masked_select(mask), action.masked_select(mask).detach())
                else:
                    mu = self._learn_model.forward(data['obs'])['action']
                    # When we use bco, action is predicted by idm, gradient is not expected.
                    loss = self._loss(mu, action.detach())
            else:
                a_logit = self._learn_model.forward(obs)
                # When we use bco, action is predicted by idm, gradient is not expected.
                loss = self._loss(a_logit['logit'], action.detach())

                if self._cfg.learn.show_accuracy:
                    # Calculate the overall accuracy and the accuracy of each class
                    total_accuracy = (a_logit['action'] == action.view(-1)).float().mean()
                    self.total_accuracy_in_dataset.append(total_accuracy)
                    logging.info(f'the total accuracy in current train mini-batch is: {total_accuracy.item()}')
                    for action_unique in to_list(torch.unique(action)):
                        action_index = (action == action_unique).nonzero(as_tuple=True)[0]
                        # Fix: convert to a python float before the nan check, so the logging
                        # below never calls ``.item()`` on the float fallback value.
                        action_accuracy = (a_logit['action'][action_index] == action.view(-1)[action_index]
                                           ).float().mean().item()
                        if math.isnan(action_accuracy):
                            action_accuracy = 0.0
                        self.action_accuracy_in_dataset[action_unique].append(action_accuracy)
                        logging.info(
                            f'the accuracy of action {action_unique} in current train mini-batch is: '
                            f'{action_accuracy}, '
                            f'(nan means the action does not appear in the mini-batch)'
                        )
        forward_time = self._timer.value
        with self._timer:
            self._optimizer.zero_grad()
            loss.backward()
        backward_time = self._timer.value
        with self._timer:
            if self._cfg.multi_gpu:
                self.sync_gradients(self._learn_model)
        sync_time = self._timer.value
        self._optimizer.step()
        cur_lr = [param_group['lr'] for param_group in self._optimizer.param_groups]
        cur_lr = sum(cur_lr) / len(cur_lr)
        return {
            'cur_lr': cur_lr,
            'total_loss': loss.item(),
            'forward_time': forward_time,
            'backward_time': backward_time,
            'sync_time': sync_time,
        }

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, \
            such as text logger, tensorboard logger, will use these keys to save the corresponding data.
        Returns:
            - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
        """
        return ['cur_lr', 'total_loss', 'forward_time', 'backward_time', 'sync_time']

    def _init_eval(self):
        """
        Overview:
            Initialize the eval mode of policy, including related attributes and modules. For BC, it contains the \
            eval model to greedily select action with argmax q_value mechanism for discrete action space. \
            This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.

        .. note::
            If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
            with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
        """
        if self._cfg.continuous:
            self._eval_model = model_wrap(self._model, wrapper_name='base')
        else:
            self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._eval_model.reset()

    def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
        """
        Overview:
            Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \
            means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
            action to interact with the envs.
        Arguments:
            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
                key of the dict is environment id and the value is the corresponding data of the env.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
                key of the dict is the same as the input data, i.e. environment id.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
            For the data type that not supported, the main reason is that the corresponding model does not support \
            it. You can implement your own model rather than use the default model. For more information, please \
            raise an issue in GitHub repo and we will continue to follow up.
        """
        tensor_input = isinstance(data, torch.Tensor)
        if tensor_input:
            data = default_collate(list(data))
        else:
            data_id = list(data.keys())
            data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data)
        if self._cuda:
            output = to_device(output, 'cpu')
        if tensor_input:
            return output
        else:
            output = default_decollate(output)
            return {i: d for i, d in zip(data_id, output)}

    def _init_collect(self) -> None:
        """
        Overview:
            BC policy uses offline dataset so it does not need to collect data. However, sometimes we need to use \
            the trained BC policy to collect data for other purposes.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        if self._cfg.continuous:
            # Fix: the class default config sets ``noise_sigma`` to a plain float (0.2), while some
            # user configs provide a scheduler-style dict with a ``start`` field; accept both
            # instead of unconditionally reading ``.start`` (which crashes with the default config).
            noise_sigma = self._cfg.collect.noise_sigma
            if isinstance(noise_sigma, dict):
                noise_sigma = noise_sigma['start']
            self._collect_model = model_wrap(
                self._model,
                wrapper_name='action_noise',
                noise_type='gauss',
                noise_kwargs={
                    'mu': 0.0,
                    'sigma': noise_sigma
                },
                noise_range=self._cfg.collect.noise_range
            )
        else:
            self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_sample')
        self._collect_model.reset()

    def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]:
        # Collect-mode forward: batch the per-env observations, run the (noise- or
        # eps-greedy-wrapped) model without gradients, then split results back per env id.
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()
        with torch.no_grad():
            # Both branches forward identically; kwargs carry e.g. eps for eps_greedy_sample.
            output = self._collect_model.forward(data, **kwargs)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _process_transition(self, obs: Any, policy_output: dict, timestep: namedtuple) -> dict:
        # Pack one environment step into the standard transition dict used by collectors.
        transition = {
            'obs': obs,
            'next_obs': timestep.obs,
            'action': policy_output['action'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return EasyDict(transition)

    def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # 1-step return over the trajectory, then slice into unroll_len-sized training samples.
        data = get_nstep_return_data(data, 1, 1)
        return get_train_sample(data, self._unroll_len)