
ding.policy.ibc


IBCPolicy

Bases: BehaviourCloningPolicy

Overview

Policy class of IBC (Implicit Behavior Cloning), proposed in https://arxiv.org/abs/2109.00137.pdf.

.. note:: The code is adapted from the PyTorch version of IBC (https://github.com/kevinzakka/ibc), which only supports the derivative-free optimization (DFO) variant. This implementation goes a step further and supports all variants of the energy-based model mentioned in the paper (DFO, autoregressive DFO, and MCMC).
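The idea shared by these variants can be followed in a minimal derivative-free optimization loop (a sketch only, not DI-engine's API): given an energy function E(obs, action), inference samples candidate actions within the action bounds, keeps the lowest-energy one, and resamples around it with a shrinking radius. The quadratic `energy` below is a toy stand-in for the learned energy-based model; `dfo_infer` and its parameters are hypothetical names.

```python
import numpy as np

def dfo_infer(energy_fn, action_low, action_high, n_samples=1024, n_iters=3, sigma=0.1, rng=None):
    """Derivative-free optimizer sketch: sample, keep the best, resample around it."""
    rng = rng or np.random.default_rng(0)
    dim = len(action_low)
    # initial candidates drawn uniformly from the action bounds
    candidates = rng.uniform(action_low, action_high, size=(n_samples, dim))
    for _ in range(n_iters):
        energies = energy_fn(candidates)          # (n_samples,)
        best = candidates[np.argmin(energies)]    # keep the lowest-energy action
        # resample around the current best, clipped back into the bounds
        candidates = np.clip(
            best + sigma * rng.standard_normal((n_samples, dim)),
            action_low, action_high,
        )
        sigma *= 0.5                              # shrink the search radius
    return best

# toy energy: minimized at action = (0.3, -0.2)
target = np.array([0.3, -0.2])
energy = lambda a: np.sum((a - target) ** 2, axis=-1)
action = dfo_infer(energy, np.array([-1.0, -1.0]), np.array([1.0, 1.0]))
```

The autoregressive and MCMC variants replace this inner loop (optimizing one action dimension at a time, or following gradient-based Langevin updates), but consume the same energy function.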

default_model()

Overview

Returns the default model configuration used by the IBC algorithm. The __init__ method automatically calls this method to get the default model setting and create the model.

Returns:

- model_info (:obj:`Tuple[str, List[str]]`): Tuple containing the registered model name and the model's import_names.

set_statistic(statistics)

Overview

Set the statistics of the environment, including the action space and the observation space.

Arguments:

- statistics (:obj:`EasyDict`): The statistics of the environment. For IBC, it contains at least the following keys: ['action_bounds'].
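The exact layout of `action_bounds` is defined by the stochastic optimizer; a common convention, and an assumption in this sketch, is a (2, action_dim) array whose rows are the per-dimension lower and upper bounds. `SimpleNamespace` stands in for `EasyDict` attribute access so the example has no external dependencies.

```python
import numpy as np
from types import SimpleNamespace  # stand-in for EasyDict attribute access

# Hypothetical statistics object: action_bounds assumed to be a (2, action_dim)
# array holding per-dimension [min, max] of the action space.
statistics = SimpleNamespace(action_bounds=np.array([
    [-1.0, -0.5],   # row 0: lower bounds
    [1.0, 0.5],     # row 1: upper bounds
]))

low, high = statistics.action_bounds
# A stochastic optimizer can use the bounds e.g. to draw uniform negative samples:
rng = np.random.default_rng(0)
negatives = rng.uniform(low, high, size=(8, low.shape[0]))  # (n_samples, action_dim)
```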

Full Source Code

../ding/policy/ibc.py

```python
from typing import Dict, Any, List, Tuple
from collections import namedtuple
from easydict import EasyDict

import torch
import torch.nn.functional as F

from ding.model import model_wrap
from ding.torch_utils import to_device
from ding.utils.data import default_collate, default_decollate
from ding.utils import POLICY_REGISTRY
from .bc import BehaviourCloningPolicy
from ding.model.template.ebm import create_stochastic_optimizer
from ding.model.template.ebm import StochasticOptimizer, MCMC, AutoRegressiveDFO
from ding.torch_utils import unsqueeze_repeat
from ding.utils import EasyTimer


@POLICY_REGISTRY.register('ibc')
class IBCPolicy(BehaviourCloningPolicy):
    r"""
    Overview:
        Policy class of IBC (Implicit Behavior Cloning), proposed in https://arxiv.org/abs/2109.00137.pdf.

    .. note::
        The code is adapted from the PyTorch version of IBC https://github.com/kevinzakka/ibc, which only supports \
        the derivative-free optimization (dfo) variant. This implementation goes a step further and supports all \
        variants of the energy-based model mentioned in the paper (dfo, autoregressive dfo, and mcmc).
    """

    config = dict(
        # (str) The policy type. 'ibc' refers to Implicit Behavior Cloning.
        type='ibc',
        # (bool) Whether to use CUDA for training. False means CPU will be used.
        cuda=False,
        # (bool) If True, the policy will operate on-policy. Here it's False, indicating off-policy.
        on_policy=False,
        # (bool) Whether the action space is continuous. True for continuous action space.
        continuous=True,
        # (dict) Configuration for the model, including stochastic optimization settings.
        model=dict(
            # (dict) Configuration for the stochastic optimization, specifying the type of optimizer.
            stochastic_optim=dict(
                # (str) The type of stochastic optimizer. 'mcmc' refers to Markov Chain Monte Carlo methods.
                type='mcmc',
            ),
        ),
        # (dict) Configuration for the learning process.
        learn=dict(
            # (int) The number of training epochs.
            train_epoch=30,
            # (int) The size of batches used during training.
            batch_size=256,
            # (dict) Configuration for the optimizer used during training.
            optim=dict(
                # (float) The learning rate for the optimizer.
                learning_rate=1e-5,
                # (float) The weight decay regularization term for the optimizer.
                weight_decay=0.0,
                # (float) The beta1 hyperparameter for the AdamW optimizer.
                beta1=0.9,
                # (float) The beta2 hyperparameter for the AdamW optimizer.
                beta2=0.999,
            ),
        ),
        # (dict) Configuration for the evaluation process.
        eval=dict(
            # (dict) Configuration for the evaluator.
            evaluator=dict(
                # (int) The frequency of evaluations during training, in terms of number of training steps.
                eval_freq=10000,
            ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Returns the default model configuration used by the IBC algorithm. The ``__init__`` method will \
            automatically call this method to get the default model setting and create the model.

        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): \
                Tuple containing the registered model name and the model's import_names.
        """
        return 'ebm', ['ding.model.template.ebm']

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of policy, including related attributes and modules. For IBC, it mainly \
            contains the optimizer and the main model. \
            This method will be called in the ``__init__`` method if the ``learn`` field is in ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some special member variables in the ``_init_learn`` method, you'd better name them \
            with the prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
        """
        self._timer = EasyTimer(cuda=self._cfg.cuda)
        self._sync_timer = EasyTimer(cuda=self._cfg.cuda)
        optim_cfg = self._cfg.learn.optim
        self._optimizer = torch.optim.AdamW(
            self._model.parameters(),
            lr=optim_cfg.learning_rate,
            weight_decay=optim_cfg.weight_decay,
            betas=(optim_cfg.beta1, optim_cfg.beta2),
        )
        self._stochastic_optimizer: StochasticOptimizer = \
            create_stochastic_optimizer(self._device, self._cfg.model.stochastic_optim)
        self._learn_model = model_wrap(self._model, 'base')
        self._learn_model.reset()

    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode (training policy and updating parameters). Forward means \
            that the policy inputs some training batch data from the replay buffer and then returns the output \
            result, including various training information such as policy_loss, value_loss, entropy_loss.
        Arguments:
            - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
                training samples. For each element in the list, the key of the dict is the name of the data item \
                and the value is the corresponding data. Usually, the value is a torch.Tensor, an np.ndarray, or \
                their dict/list combinations. In the ``_forward_learn`` method, data often needs to first be \
                stacked in the batch dimension by some utility functions such as ``default_preprocess_learn``. \
                For IBC, each element in the list is a dict containing at least the following keys: \
                ['obs', 'action'].
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which \
                will be recorded in text log and tensorboard; values must be python scalars or lists of scalars. \
                For the detailed definition of the dict, refer to the code of the ``_monitor_vars_learn`` method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of \
            them. For a data type that is not supported, the main reason is that the corresponding model does not \
            support it. You can implement your own model rather than use the default model. For more information, \
            please raise an issue in the GitHub repo and we will continue to follow up.
        """
        with self._timer:
            data = default_collate(data)
            if self._cuda:
                data = to_device(data, self._device)
            self._learn_model.train()

            loss_dict = dict()

            # obs: (B, O)
            # action: (B, A)
            obs, action = data['obs'], data['action']
            # When the action/observation space is 1, the action/observation dimension will
            # be squeezed in the first place, therefore unsqueeze here to make the data
            # compatible with the ibc pipeline.
            if len(obs.shape) == 1:
                obs = obs.unsqueeze(-1)
            if len(action.shape) == 1:
                action = action.unsqueeze(-1)

            # N refers to the number of negative samples, i.e. self._stochastic_optimizer.inference_samples.
            # (B, N, O), (B, N, A)
            obs, negatives = self._stochastic_optimizer.sample(obs, self._learn_model)

            # (B, N+1, A)
            targets = torch.cat([action.unsqueeze(dim=1), negatives], dim=1)
            # (B, N+1, O)
            obs = torch.cat([obs[:, :1], obs], dim=1)

            permutation = torch.rand(targets.shape[0], targets.shape[1]).argsort(dim=1)
            targets = targets[torch.arange(targets.shape[0]).unsqueeze(-1), permutation]

            # (B, )
            ground_truth = (permutation == 0).nonzero()[:, 1].to(self._device)

            # (B, N+1) for ebm
            # (B, N+1, A) for autoregressive ebm
            energy = self._learn_model.forward(obs, targets)

            logits = -1.0 * energy
            if isinstance(self._stochastic_optimizer, AutoRegressiveDFO):
                # autoregressive case
                # (B, A)
                ground_truth = unsqueeze_repeat(ground_truth, logits.shape[-1], -1)
            loss = F.cross_entropy(logits, ground_truth)
            loss_dict['ebm_loss'] = loss.item()

            if isinstance(self._stochastic_optimizer, MCMC):
                grad_penalty = self._stochastic_optimizer.grad_penalty(obs, targets, self._learn_model)
                loss += grad_penalty
                loss_dict['grad_penalty'] = grad_penalty.item()
            loss_dict['total_loss'] = loss.item()

            self._optimizer.zero_grad()
            loss.backward()
            with self._sync_timer:
                if self._cfg.multi_gpu:
                    self.sync_gradients(self._learn_model)
            sync_time = self._sync_timer.value
            self._optimizer.step()

        total_time = self._timer.value

        return {
            'total_time': total_time,
            'sync_time': sync_time,
            **loss_dict,
        }

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, \
            such as text logger, tensorboard logger, will use these keys to save the corresponding data.
        Returns:
            - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
        """
        if isinstance(self._stochastic_optimizer, MCMC):
            return ['total_loss', 'ebm_loss', 'grad_penalty', 'total_time', 'sync_time']
        else:
            return ['total_loss', 'ebm_loss', 'total_time', 'sync_time']

    def _init_eval(self) -> None:
        """
        Overview:
            Initialize the eval mode of policy, including related attributes and modules.
            This method will be called in the ``__init__`` method if the ``eval`` field is in ``enable_field``.

        .. note::
            If you want to set some special member variables in the ``_init_eval`` method, you'd better name them \
            with the prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='base')
        self._eval_model.reset()

    def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
        """
        Overview:
            Policy forward function of eval mode (evaluating policy performance by interacting with envs). Forward \
            means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
            action to interact with the envs.
        Arguments:
            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. \
                The key of the dict is environment id and the value is the corresponding data of the env.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. \
                The key of the dict is the same as the input data, i.e., environment id.

        .. note::
            The input value can be ``torch.Tensor`` or dict/list combinations; the current policy supports all of \
            them. For a data type that is not supported, the main reason is that the corresponding model does not \
            support it. You can implement your own model rather than use the default model. For more information, \
            please raise an issue in the GitHub repo, and we will continue to follow up.
        """
        tensor_input = isinstance(data, torch.Tensor)
        if not tensor_input:
            data_id = list(data.keys())
            data = default_collate(list(data.values()))

        if self._cuda:
            data = to_device(data, self._device)

        self._eval_model.eval()
        output = self._stochastic_optimizer.infer(data, self._eval_model)
        output = dict(action=output)

        if self._cuda:
            output = to_device(output, 'cpu')
        if tensor_input:
            return output
        else:
            output = default_decollate(output)
            return {i: d for i, d in zip(data_id, output)}

    def set_statistic(self, statistics: EasyDict) -> None:
        """
        Overview:
            Set the statistics of the environment, including the action space and the observation space.
        Arguments:
            - statistics (:obj:`EasyDict`): The statistics of the environment. For IBC, it contains at least the \
                following keys: ['action_bounds'].
        """
        self._stochastic_optimizer.set_action_bounds(statistics.action_bounds)

    # =================================================================== #
    # Implicit Behavioral Cloning does not need `collect`-related functions
    # =================================================================== #
    def _init_collect(self):
        raise NotImplementedError

    def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
        raise NotImplementedError

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        raise NotImplementedError

    def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        raise NotImplementedError
```
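The counter-example classification at the heart of `_forward_learn` can be traced in a small, torch-free numpy sketch: the ground-truth action is concatenated with N negatives, the N+1 candidates are shuffled per batch row, the shuffled position of the ground truth becomes the class label, and a softmax cross-entropy is taken over negated energies (as `F.cross_entropy(-energy, ground_truth)` does in the real code). The distance-based `energy` here is a toy stand-in for the learned EBM.

```python
import numpy as np

rng = np.random.default_rng(0)
B, N, A = 4, 7, 2                        # batch size, negatives, action dim

action = rng.normal(size=(B, A))         # ground-truth actions
negatives = rng.normal(size=(B, N, A))   # sampled counter-examples

# (B, N+1, A): ground truth at index 0, followed by the negatives
targets = np.concatenate([action[:, None, :], negatives], axis=1)

# shuffle the candidates within each row, remembering where the ground truth went
permutation = rng.random((B, N + 1)).argsort(axis=1)
targets = targets[np.arange(B)[:, None], permutation]
ground_truth = (permutation == 0).argmax(axis=1)          # (B,) class labels

# toy energy: squared distance to the true action (a learned EBM in the pipeline)
energy = np.sum((targets - action[:, None, :]) ** 2, axis=-1)   # (B, N+1)

# softmax cross-entropy over negated energies
logits = -energy
m = logits.max(axis=1, keepdims=True)
log_probs = logits - (m + np.log(np.exp(logits - m).sum(axis=1, keepdims=True)))
loss = -log_probs[np.arange(B), ground_truth].mean()
```

Minimizing this loss pushes the ground-truth action's energy down relative to the negatives, which is exactly what makes the energy landscape usable by the stochastic optimizer at inference time.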