
ding.model.template.ebm


Vanilla DFO and EBM are adapted from https://github.com/kevinzakka/ibc. MCMC is adapted from https://github.com/google-research/ibc.

StochasticOptimizer

Bases: ABC

Overview

Base class for stochastic optimizers.

Interface: `__init__`, `_sample`, `_get_best_action_sample`, `set_action_bounds`, `sample`, `infer`

set_action_bounds(action_bounds)

Overview

Set action bounds calculated from the dataset statistics.

Arguments:
- action_bounds (:obj:`np.ndarray`): Array of shape (2, A), where `action_bounds[0]` is the lower bound and `action_bounds[1]` is the upper bound.

Returns:
- action_bounds (:obj:`torch.Tensor`): Action bounds.

Shapes:
- action_bounds (:obj:`np.ndarray`): :math:`(2, A)`.
- action_bounds (:obj:`torch.Tensor`): :math:`(2, A)`.

Examples:
>>> opt = StochasticOptimizer()
>>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))

sample(obs, ebm) abstractmethod

Overview

Create tiled observations and sample counter-negatives for InfoNCE loss.

Arguments:
- obs (:obj:`torch.Tensor`): Observations.
- ebm (:obj:`torch.nn.Module`): Energy based model.

Returns:
- tiled_obs (:obj:`torch.Tensor`): Tiled observations.
- action (:obj:`torch.Tensor`): Actions.

Shapes:
- obs (:obj:`torch.Tensor`): :math:`(B, O)`.
- ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`.
- tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
- action (:obj:`torch.Tensor`): :math:`(B, N, A)`.

Note: In the case of derivative-free optimization, this function simply calls `_sample`.

infer(obs, ebm) abstractmethod

Overview

Optimize for the best action conditioned on the current observation.

Arguments:
- obs (:obj:`torch.Tensor`): Observations.
- ebm (:obj:`torch.nn.Module`): Energy based model.

Returns:
- best_action_samples (:obj:`torch.Tensor`): Best actions.

Shapes:
- obs (:obj:`torch.Tensor`): :math:`(B, O)`.
- ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`.
- best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
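Concretely, every `infer` implementation ends the same way: given `(B, N)` candidate energies, pick the lowest-energy action per batch element. A minimal NumPy sketch of that selection (the helper `best_action` and the toy numbers below are illustrative, not part of the library):

```python
import numpy as np

def best_action(energies: np.ndarray, action_samples: np.ndarray) -> np.ndarray:
    """Pick, per batch element, the candidate action with the lowest energy."""
    # energies: (B, N); action_samples: (B, N, A)
    best_idxs = np.argmin(energies, axis=-1)                         # (B,)
    return action_samples[np.arange(energies.shape[0]), best_idxs]   # (B, A)

energies = np.array([[3.0, 0.5, 2.0],
                     [1.0, 4.0, 0.2]])                 # B=2, N=3
actions = np.arange(2 * 3 * 1, dtype=float).reshape(2, 3, 1)
best = best_action(energies, actions)                  # shape (2, 1)
```

Taking the argmax of `softmax(-E)` (as the library does) selects the same index as the argmin of `E` shown here.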

DFO

Bases: StochasticOptimizer

Overview

Derivative-Free Optimizer proposed in the paper Implicit Behavioral Cloning (https://arxiv.org/abs/2109.00137).

Interface: `__init__`, `sample`, `infer`

__init__(noise_scale=0.33, noise_shrink=0.5, iters=3, train_samples=8, inference_samples=16384, device='cpu')

Overview

Initialize the Derivative-Free Optimizer.

Arguments:
- noise_scale (:obj:`float`): Initial noise scale.
- noise_shrink (:obj:`float`): Noise scale shrink rate.
- iters (:obj:`int`): Number of iterations.
- train_samples (:obj:`int`): Number of samples for training.
- inference_samples (:obj:`int`): Number of samples for inference.
- device (:obj:`str`): Device.

sample(obs, ebm)

Overview

Draw action samples from a uniform random distribution and tile observations to the same shape as the action samples.

Arguments:
- obs (:obj:`torch.Tensor`): Observations.
- ebm (:obj:`torch.nn.Module`): Energy based model.

Returns:
- tiled_obs (:obj:`torch.Tensor`): Tiled observation.
- action_samples (:obj:`torch.Tensor`): Action samples.

Shapes:
- obs (:obj:`torch.Tensor`): :math:`(B, O)`.
- ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`.
- tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
- action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`.

Examples:
>>> obs = torch.randn(2, 4)
>>> ebm = EBM(4, 5)
>>> opt = DFO()
>>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
>>> tiled_obs, action_samples = opt.sample(obs, ebm)

infer(obs, ebm)

Overview

Optimize for the best action conditioned on the current observation.

Arguments:
- obs (:obj:`torch.Tensor`): Observations.
- ebm (:obj:`torch.nn.Module`): Energy based model.

Returns:
- best_action_samples (:obj:`torch.Tensor`): Actions.

Shapes:
- obs (:obj:`torch.Tensor`): :math:`(B, O)`.
- ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`.
- best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.

Examples:
>>> obs = torch.randn(2, 4)
>>> ebm = EBM(4, 5)
>>> opt = DFO()
>>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
>>> best_action_samples = opt.infer(obs, ebm)
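The derivative-free loop behind `DFO.infer` resamples candidates in proportion to `exp(-E)`, perturbs them with Gaussian noise of shrinking scale, and finally keeps the lowest-energy sample. A self-contained 1-D sketch on a toy energy (the `energy` function, bounds, and sample counts below are stand-ins for illustration, not the library code):

```python
import numpy as np

rng = np.random.default_rng(0)

def energy(actions: np.ndarray) -> np.ndarray:
    # Toy energy with its minimum at a = 0.7 (stand-in for a learned EBM).
    return (actions - 0.7) ** 2

low, high = 0.0, 1.0
n_samples, iters = 256, 3
noise_scale, noise_shrink = 0.33, 0.5

# Start from uniform candidates within the action bounds.
actions = rng.uniform(low, high, size=n_samples)
for _ in range(iters):
    probs = np.exp(-energy(actions))
    probs /= probs.sum()
    # Resample with replacement in proportion to exp(-E), then perturb and clip.
    actions = rng.choice(actions, size=n_samples, p=probs)
    actions = np.clip(actions + rng.normal(0, noise_scale, n_samples), low, high)
    noise_scale *= noise_shrink

# Keep the sample with the lowest energy.
best = actions[np.argmin(energy(actions))]
```

With enough samples, `best` lands near the energy minimum without ever computing a gradient, which is the point of the derivative-free scheme.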

AutoRegressiveDFO

Bases: DFO

Overview

AutoRegressive Derivative-Free Optimizer proposed in the paper Implicit Behavioral Cloning (https://arxiv.org/abs/2109.00137).

Interface: __init__, infer

__init__(noise_scale=0.33, noise_shrink=0.5, iters=3, train_samples=8, inference_samples=4096, device='cpu')

Overview

Initialize the AutoRegressive Derivative-Free Optimizer.

Arguments:
- noise_scale (:obj:`float`): Initial noise scale.
- noise_shrink (:obj:`float`): Noise scale shrink rate.
- iters (:obj:`int`): Number of iterations.
- train_samples (:obj:`int`): Number of samples for training.
- inference_samples (:obj:`int`): Number of samples for inference.
- device (:obj:`str`): Device.

infer(obs, ebm)

Overview

Optimize for the best action conditioned on the current observation.

Arguments:
- obs (:obj:`torch.Tensor`): Observations.
- ebm (:obj:`torch.nn.Module`): Energy based model.

Returns:
- best_action_samples (:obj:`torch.Tensor`): Actions.

Shapes:
- obs (:obj:`torch.Tensor`): :math:`(B, O)`.
- ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`.
- best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.

Examples:
>>> obs = torch.randn(2, 4)
>>> ebm = EBM(4, 5)
>>> opt = AutoRegressiveDFO()
>>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
>>> best_action_samples = opt.infer(obs, ebm)

MCMC

Bases: StochasticOptimizer

Overview

MCMC method used as a stochastic optimizer in the paper Implicit Behavioral Cloning (https://arxiv.org/abs/2109.00137).

Interface: __init__, sample, infer, grad_penalty

BaseScheduler

Bases: ABC

Overview

Base class for learning rate scheduler.

Interface: get_rate

get_rate(index) abstractmethod
Overview

Abstract method for getting learning rate.

ExponentialScheduler

Overview

Exponential learning rate schedule for Langevin sampler.

Interface: __init__, get_rate

__init__(init, decay)
Overview

Initialize the ExponentialScheduler.

Arguments:
- init (:obj:`float`): Initial learning rate.
- decay (:obj:`float`): Decay rate.

get_rate(index)
Overview

Get the learning rate. Assumes it is called sequentially.

Arguments:
- index (:obj:`int`): Current iteration.
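The scheduler is stateful: each call returns the current rate and multiplies it by `decay` for the next call, ignoring `index`. A stand-alone sketch of that behavior (the parameter values are just examples):

```python
class ExponentialScheduler:
    """Returns the current rate, then multiplies it by `decay` (call sequentially)."""

    def __init__(self, init: float, decay: float):
        self._latest_lr = init
        self._decay = decay

    def get_rate(self, index: int) -> float:
        lr = self._latest_lr
        self._latest_lr *= self._decay   # the index argument is unused
        return lr

sched = ExponentialScheduler(init=1.0, decay=0.5)
rates = [sched.get_rate(i) for i in range(4)]
```

Because the state advances on every call, calling `get_rate` twice for the same index would return different values; the Langevin loop calls it once per step.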

PolynomialScheduler

Overview

Polynomial learning rate schedule for Langevin sampler.

Interface: __init__, get_rate

__init__(init, final, power, num_steps)
Overview

Initialize the PolynomialScheduler.

Arguments:
- init (:obj:`float`): Initial learning rate.
- final (:obj:`float`): Final learning rate.
- power (:obj:`float`): Power of polynomial.
- num_steps (:obj:`int`): Number of steps.

get_rate(index)
Overview

Get learning rate for index.

Arguments:
- index (:obj:`int`): Current iteration.
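This schedule decays from `init` to `final` as a power of the remaining fraction of steps, with index `-1` returning `init`. A stand-alone sketch of that rule (parameter values are just examples):

```python
def polynomial_rate(index: int, init: float, final: float, power: float, num_steps: int) -> float:
    """Polynomial decay from `init` at step 0 down to `final` at step num_steps - 1."""
    if index == -1:
        return init
    remaining = 1.0 - float(index) / float(num_steps - 1)   # fraction of steps left
    return (init - final) * remaining ** power + final

# Starts at init (0.5) and reaches final (1e-5) at the last step.
start = polynomial_rate(0, 0.5, 1e-5, 2.0, 100)
end = polynomial_rate(99, 0.5, 1e-5, 2.0, 100)
```

With `power=2.0` (the MCMC default) the step size shrinks quadratically, so most of the decay happens early in the Langevin chain.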

__init__(iters=100, use_langevin_negative_samples=True, train_samples=8, inference_samples=512, stepsize_scheduler=dict(init=0.5, final=1e-05, power=2.0), optimize_again=True, again_stepsize_scheduler=dict(init=1e-05, final=1e-05, power=2.0), device='cpu', noise_scale=0.5, grad_clip=None, delta_action_clip=0.5, add_grad_penalty=True, grad_norm_type='inf', grad_margin=1.0, grad_loss_weight=1.0, **kwargs)

Overview

Initialize the MCMC.

Arguments:
- iters (:obj:`int`): Number of iterations.
- use_langevin_negative_samples (:obj:`bool`): Whether to use Langevin sampler.
- train_samples (:obj:`int`): Number of samples for training.
- inference_samples (:obj:`int`): Number of samples for inference.
- stepsize_scheduler (:obj:`dict`): Step size scheduler for Langevin sampler.
- optimize_again (:obj:`bool`): Whether to run a second optimization.
- again_stepsize_scheduler (:obj:`dict`): Step size scheduler for the second optimization.
- device (:obj:`str`): Device.
- noise_scale (:obj:`float`): Initial noise scale.
- grad_clip (:obj:`float`): Gradient clip.
- delta_action_clip (:obj:`float`): Action clip.
- add_grad_penalty (:obj:`bool`): Whether to add gradient penalty.
- grad_norm_type (:obj:`str`): Gradient norm type.
- grad_margin (:obj:`float`): Gradient margin.
- grad_loss_weight (:obj:`float`): Gradient loss weight.

grad_penalty(obs, action, ebm)

Overview

Calculate gradient penalty.

Arguments:
- obs (:obj:`torch.Tensor`): Observations.
- action (:obj:`torch.Tensor`): Actions.
- ebm (:obj:`torch.nn.Module`): Energy based model.

Returns:
- loss (:obj:`torch.Tensor`): Gradient penalty.

Shapes:
- obs (:obj:`torch.Tensor`): :math:`(B, N+1, O)`.
- action (:obj:`torch.Tensor`): :math:`(B, N+1, A)`.
- ebm (:obj:`torch.nn.Module`): :math:`(B, N+1, O)`.
- loss (:obj:`torch.Tensor`): :math:`(B, )`.
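The penalty is a hinge on the per-sample gradient norms: norms below `grad_margin` are free, and the excess above the margin is squared and averaged. A NumPy sketch of that computation on precomputed gradients, assuming the default inf-norm over action dimensions (the helper name and toy numbers are illustrative):

```python
import numpy as np

def grad_penalty(de_dact: np.ndarray, margin: float = 1.0, weight: float = 1.0) -> float:
    """Hinge-squared penalty on per-sample gradient norms (inf-norm over action dims)."""
    # de_dact: (B, N+1, A) gradients of the energy w.r.t. each action sample
    grad_norms = np.abs(de_dact).max(axis=-1)          # (B, N+1) inf-norm
    excess = np.clip(grad_norms - margin, 0.0, None)   # penalize only norms above the margin
    return weight * float((excess ** 2).mean())

# Two samples with gradient norms 0.5 and 2.0 against margin 1.0:
# only the second is penalized, by (2.0 - 1.0)^2 = 1.0.
g = np.array([[[0.5], [2.0]]])
loss = grad_penalty(g)
```

In the library the gradients come from `torch.autograd.grad` with `create_graph=True`, so the penalty itself is differentiable with respect to the model parameters.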

sample(obs, ebm)

Overview

Create tiled observations and sample counter-negatives for InfoNCE loss.

Arguments:
- obs (:obj:`torch.Tensor`): Observations.
- ebm (:obj:`torch.nn.Module`): Energy based model.

Returns:
- tiled_obs (:obj:`torch.Tensor`): Tiled observations.
- action_samples (:obj:`torch.Tensor`): Action samples.

Shapes:
- obs (:obj:`torch.Tensor`): :math:`(B, O)`.
- ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`.
- tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
- action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`.

Examples:
>>> obs = torch.randn(2, 4)
>>> ebm = EBM(4, 5)
>>> opt = MCMC()
>>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
>>> tiled_obs, action_samples = opt.sample(obs, ebm)

infer(obs, ebm)

Overview

Optimize for the best action conditioned on the current observation.

Arguments:
- obs (:obj:`torch.Tensor`): Observations.
- ebm (:obj:`torch.nn.Module`): Energy based model.

Returns:
- best_action_samples (:obj:`torch.Tensor`): Actions.

Shapes:
- obs (:obj:`torch.Tensor`): :math:`(B, O)`.
- ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`.
- best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.

Examples:
>>> obs = torch.randn(2, 4)
>>> ebm = EBM(4, 5)
>>> opt = MCMC()
>>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
>>> best_action_samples = opt.infer(obs, ebm)
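MCMC inference takes noisy gradient-descent steps on the energy (Langevin dynamics) with a decaying step size, then keeps the lowest-energy sample. A self-contained 1-D sketch with an analytic toy energy (the quadratic, sample count, and geometric step decay are stand-ins; the library uses a learned EBM, a polynomial step-size schedule, and per-step clipping of the action delta):

```python
import numpy as np

rng = np.random.default_rng(0)

def grad_energy(a: np.ndarray) -> np.ndarray:
    # Analytic gradient of the toy energy E(a) = (a - 0.3)^2.
    return 2.0 * (a - 0.3)

low, high = 0.0, 1.0
actions = rng.uniform(low, high, size=64)
stepsize, noise_scale = 0.5, 0.1
for _ in range(100):
    # Langevin update: a <- a - step * (0.5 * dE/da + noise), clipped to bounds.
    de_da = 0.5 * grad_energy(actions) + rng.normal(0, noise_scale, actions.shape)
    actions = np.clip(actions - stepsize * de_da, low, high)
    stepsize *= 0.95   # decaying step size (stand-in for the polynomial schedule)

# Keep the sample with the lowest energy.
best = actions[np.argmin((actions - 0.3) ** 2)]
```

The injected noise keeps the chain exploring early on, while the decaying step size anneals it toward a low-energy mode; the final argmin over samples mirrors `_get_best_action_sample`.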

EBM

Bases: Module

Overview

Energy based model.

Interface: __init__, forward

__init__(obs_shape, action_shape, hidden_size=512, hidden_layer_num=4, **kwargs)

Overview

Initialize the EBM.

Arguments:
- obs_shape (:obj:`int`): Observation shape.
- action_shape (:obj:`int`): Action shape.
- hidden_size (:obj:`int`): Hidden size.
- hidden_layer_num (:obj:`int`): Number of hidden layers.

forward(obs, action)

Overview

Forward computation graph of EBM.

Arguments:
- obs (:obj:`torch.Tensor`): Observation of shape (B, N, O).
- action (:obj:`torch.Tensor`): Action of shape (B, N, A).

Returns:
- pred (:obj:`torch.Tensor`): Energy of shape (B, N).

Examples:
>>> obs = torch.randn(2, 3, 4)
>>> action = torch.randn(2, 3, 5)
>>> ebm = EBM(4, 5)
>>> pred = ebm(obs, action)

AutoregressiveEBM

Bases: Module

Overview

Autoregressive energy based model.

Interface: __init__, forward

__init__(obs_shape, action_shape, hidden_size=512, hidden_layer_num=4)

Overview

Initialize the AutoregressiveEBM.

Arguments:
- obs_shape (:obj:`int`): Observation shape.
- action_shape (:obj:`int`): Action shape.
- hidden_size (:obj:`int`): Hidden size.
- hidden_layer_num (:obj:`int`): Number of hidden layers.

forward(obs, action)

Overview

Forward computation graph of AutoregressiveEBM.

Arguments:
- obs (:obj:`torch.Tensor`): Observation of shape (B, N, O).
- action (:obj:`torch.Tensor`): Action of shape (B, N, A).

Returns:
- pred (:obj:`torch.Tensor`): Energy of shape (B, N, A).

Examples:
>>> obs = torch.randn(2, 3, 4)
>>> action = torch.randn(2, 3, 5)
>>> arebm = AutoregressiveEBM(4, 5)
>>> pred = arebm(obs, action)

create_stochastic_optimizer(device, stochastic_optimizer_config)

Overview

Create stochastic optimizer.

Arguments:
- device (:obj:`str`): Device.
- stochastic_optimizer_config (:obj:`dict`): Stochastic optimizer config.

no_ebm_grad()

Decorator factory that disables energy-based model gradients inside the wrapped function.
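The pattern is: freeze the model found in the last positional argument, run the function, then restore gradients. A simplified stand-in (`no_model_grad` and `DummyModel` are hypothetical names for illustration; the library's `no_ebm_grad` is a decorator factory with an additional `isinstance` check and no try/finally):

```python
from functools import wraps

def no_model_grad(func):
    """Disable the model's gradients around `func`; the model is the last positional arg."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        model = args[-1]
        model.requires_grad_(False)      # freeze parameters during sampling/inference
        try:
            return func(*args, **kwargs)
        finally:
            model.requires_grad_(True)   # always restore, even if func raises

    return wrapper

class DummyModel:
    """Minimal stand-in recording requires_grad_ calls (not an nn.Module)."""

    def __init__(self):
        self.requires_grad = True

    def requires_grad_(self, flag):
        self.requires_grad = flag

@no_model_grad
def run(x, model):
    return x * 2, model.requires_grad

model = DummyModel()
result, grad_inside = run(3, model)
```

Freezing parameters this way still allows gradients with respect to the *inputs* (the actions), which is exactly what the Langevin sampler needs.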

Full Source Code

../ding/model/template/ebm.py

1""" 2Vanilla DFO and EBM are adapted from https://github.com/kevinzakka/ibc. 3MCMC is adapted from https://github.com/google-research/ibc. 4""" 5from typing import Callable, Tuple 6from functools import wraps 7 8import numpy as np 9import torch 10import torch.nn as nn 11import torch.nn.functional as F 12 13from abc import ABC, abstractmethod 14 15from ding.utils import MODEL_REGISTRY, STOCHASTIC_OPTIMIZER_REGISTRY 16from ding.torch_utils import unsqueeze_repeat 17from ding.model.wrapper import IModelWrapper 18from ding.model.common import RegressionHead 19 20 21def create_stochastic_optimizer(device: str, stochastic_optimizer_config: dict): 22 """ 23 Overview: 24 Create stochastic optimizer. 25 Arguments: 26 - device (:obj:`str`): Device. 27 - stochastic_optimizer_config (:obj:`dict`): Stochastic optimizer config. 28 """ 29 return STOCHASTIC_OPTIMIZER_REGISTRY.build( 30 stochastic_optimizer_config.pop("type"), device=device, **stochastic_optimizer_config 31 ) 32 33 34def no_ebm_grad(): 35 """Wrapper that disables energy based model gradients""" 36 37 def ebm_disable_grad_wrapper(func: Callable): 38 39 @wraps(func) 40 def wrapper(*args, **kwargs): 41 ebm = args[-1] 42 assert isinstance(ebm, (IModelWrapper, nn.Module)),\ 43 'Make sure ebm is the last positional arguments.' 44 ebm.requires_grad_(False) 45 result = func(*args, **kwargs) 46 ebm.requires_grad_(True) 47 return result 48 49 return wrapper 50 51 return ebm_disable_grad_wrapper 52 53 54class StochasticOptimizer(ABC): 55 """ 56 Overview: 57 Base class for stochastic optimizers. 58 Interface: 59 ``__init__``, ``_sample``, ``_get_best_action_sample``, ``set_action_bounds``, ``sample``, ``infer`` 60 """ 61 62 def _sample(self, obs: torch.Tensor, num_samples: int) -> Tuple[torch.Tensor, torch.Tensor]: 63 """ 64 Overview: 65 Drawing action samples from the uniform random distribution \ 66 and tiling observations to the same shape as action samples. 67 Arguments: 68 - obs (:obj:`torch.Tensor`): Observation. 
69 - num_samples (:obj:`int`): The number of negative samples. 70 Returns: 71 - tiled_obs (:obj:`torch.Tensor`): Observations tiled. 72 - action (:obj:`torch.Tensor`): Action sampled. 73 Shapes: 74 - obs (:obj:`torch.Tensor`): :math:`(B, O)`. 75 - num_samples (:obj:`int`): :math:`N`. 76 - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`. 77 - action (:obj:`torch.Tensor`): :math:`(B, N, A)`. 78 Examples: 79 >>> obs = torch.randn(2, 4) 80 >>> opt = StochasticOptimizer() 81 >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0)) 82 >>> tiled_obs, action = opt._sample(obs, 8) 83 """ 84 size = (obs.shape[0], num_samples, self.action_bounds.shape[1]) 85 low, high = self.action_bounds[0, :], self.action_bounds[1, :] 86 action_samples = low + (high - low) * torch.rand(size).to(self.device) 87 tiled_obs = unsqueeze_repeat(obs, num_samples, 1) 88 return tiled_obs, action_samples 89 90 @staticmethod 91 @torch.no_grad() 92 def _get_best_action_sample(obs: torch.Tensor, action_samples: torch.Tensor, ebm: nn.Module): 93 """ 94 Overview: 95 Return one action for each batch with highest probability (lowest energy). 96 Arguments: 97 - obs (:obj:`torch.Tensor`): Observation. 98 - action_samples (:obj:`torch.Tensor`): Action from uniform distributions. 99 Returns: 100 - best_action_samples (:obj:`torch.Tensor`): Best action. 101 Shapes: 102 - obs (:obj:`torch.Tensor`): :math:`(B, O)`. 103 - action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`. 104 - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`. 
105 Examples: 106 >>> obs = torch.randn(2, 4) 107 >>> action_samples = torch.randn(2, 8, 5) 108 >>> ebm = EBM(4, 5) 109 >>> opt = StochasticOptimizer() 110 >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0)) 111 >>> best_action_samples = opt._get_best_action_sample(obs, action_samples, ebm) 112 """ 113 # (B, N) 114 energies = ebm.forward(obs, action_samples) 115 probs = F.softmax(-1.0 * energies, dim=-1) 116 # (B, ) 117 best_idxs = probs.argmax(dim=-1) 118 return action_samples[torch.arange(action_samples.size(0)), best_idxs] 119 120 def set_action_bounds(self, action_bounds: np.ndarray): 121 """ 122 Overview: 123 Set action bounds calculated from the dataset statistics. 124 Arguments: 125 - action_bounds (:obj:`np.ndarray`): Array of shape (2, A), \ 126 where action_bounds[0] is lower bound and action_bounds[1] is upper bound. 127 Returns: 128 - action_bounds (:obj:`torch.Tensor`): Action bounds. 129 Shapes: 130 - action_bounds (:obj:`np.ndarray`): :math:`(2, A)`. 131 - action_bounds (:obj:`torch.Tensor`): :math:`(2, A)`. 132 Examples: 133 >>> opt = StochasticOptimizer() 134 >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0)) 135 """ 136 self.action_bounds = torch.as_tensor(action_bounds, dtype=torch.float32).to(self.device) 137 138 @abstractmethod 139 def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]: 140 """ 141 Overview: 142 Create tiled observations and sample counter-negatives for InfoNCE loss. 143 Arguments: 144 - obs (:obj:`torch.Tensor`): Observations. 145 - ebm (:obj:`torch.nn.Module`): Energy based model. 146 Returns: 147 - tiled_obs (:obj:`torch.Tensor`): Tiled observations. 148 - action (:obj:`torch.Tensor`): Actions. 149 Shapes: 150 - obs (:obj:`torch.Tensor`): :math:`(B, O)`. 151 - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. 152 - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`. 153 - action (:obj:`torch.Tensor`): :math:`(B, N, A)`. 154 155 .. 
note:: In the case of derivative-free optimization, this function will simply call _sample. 156 """ 157 raise NotImplementedError 158 159 @abstractmethod 160 def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor: 161 """ 162 Overview: 163 Optimize for the best action conditioned on the current observation. 164 Arguments: 165 - obs (:obj:`torch.Tensor`): Observations. 166 - ebm (:obj:`torch.nn.Module`): Energy based model. 167 Returns: 168 - best_action_samples (:obj:`torch.Tensor`): Best actions. 169 Shapes: 170 - obs (:obj:`torch.Tensor`): :math:`(B, O)`. 171 - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. 172 - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`. 173 """ 174 raise NotImplementedError 175 176 177@STOCHASTIC_OPTIMIZER_REGISTRY.register('dfo') 178class DFO(StochasticOptimizer): 179 """ 180 Overview: 181 Derivative-Free Optimizer in paper Implicit Behavioral Cloning. 182 https://arxiv.org/abs/2109.00137 183 Interface: 184 ``init``, ``sample``, ``infer`` 185 """ 186 187 def __init__( 188 self, 189 noise_scale: float = 0.33, 190 noise_shrink: float = 0.5, 191 iters: int = 3, 192 train_samples: int = 8, 193 inference_samples: int = 16384, 194 device: str = 'cpu', 195 ): 196 """ 197 Overview: 198 Initialize the Derivative-Free Optimizer 199 Arguments: 200 - noise_scale (:obj:`float`): Initial noise scale. 201 - noise_shrink (:obj:`float`): Noise scale shrink rate. 202 - iters (:obj:`int`): Number of iterations. 203 - train_samples (:obj:`int`): Number of samples for training. 204 - inference_samples (:obj:`int`): Number of samples for inference. 205 - device (:obj:`str`): Device. 
206 """ 207 self.action_bounds = None 208 self.noise_scale = noise_scale 209 self.noise_shrink = noise_shrink 210 self.iters = iters 211 self.train_samples = train_samples 212 self.inference_samples = inference_samples 213 self.device = device 214 215 def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]: 216 """ 217 Overview: 218 Drawing action samples from the uniform random distribution \ 219 and tiling observations to the same shape as action samples. 220 Arguments: 221 - obs (:obj:`torch.Tensor`): Observations. 222 - ebm (:obj:`torch.nn.Module`): Energy based model. 223 Returns: 224 - tiled_obs (:obj:`torch.Tensor`): Tiled observation. 225 - action_samples (:obj:`torch.Tensor`): Action samples. 226 Shapes: 227 - obs (:obj:`torch.Tensor`): :math:`(B, O)`. 228 - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. 229 - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`. 230 - action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`. 231 Examples: 232 >>> obs = torch.randn(2, 4) 233 >>> ebm = EBM(4, 5) 234 >>> opt = DFO() 235 >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0)) 236 >>> tiled_obs, action_samples = opt.sample(obs, ebm) 237 """ 238 return self._sample(obs, self.train_samples) 239 240 @torch.no_grad() 241 def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor: 242 """ 243 Overview: 244 Optimize for the best action conditioned on the current observation. 245 Arguments: 246 - obs (:obj:`torch.Tensor`): Observations. 247 - ebm (:obj:`torch.nn.Module`): Energy based model. 248 Returns: 249 - best_action_samples (:obj:`torch.Tensor`): Actions. 250 Shapes: 251 - obs (:obj:`torch.Tensor`): :math:`(B, O)`. 252 - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. 253 - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`. 
254 Examples: 255 >>> obs = torch.randn(2, 4) 256 >>> ebm = EBM(4, 5) 257 >>> opt = DFO() 258 >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0)) 259 >>> best_action_samples = opt.infer(obs, ebm) 260 """ 261 noise_scale = self.noise_scale 262 263 # (B, N, O), (B, N, A) 264 obs, action_samples = self._sample(obs, self.inference_samples) 265 266 for i in range(self.iters): 267 # (B, N) 268 energies = ebm.forward(obs, action_samples) 269 probs = F.softmax(-1.0 * energies, dim=-1) 270 271 # Resample with replacement. 272 idxs = torch.multinomial(probs, self.inference_samples, replacement=True) 273 action_samples = action_samples[torch.arange(action_samples.size(0)).unsqueeze(-1), idxs] 274 275 # Add noise and clip to target bounds. 276 action_samples = action_samples + torch.randn_like(action_samples) * noise_scale 277 action_samples = action_samples.clamp(min=self.action_bounds[0, :], max=self.action_bounds[1, :]) 278 279 noise_scale *= self.noise_shrink 280 281 # Return target with highest probability. 282 return self._get_best_action_sample(obs, action_samples, ebm) 283 284 285@STOCHASTIC_OPTIMIZER_REGISTRY.register('ardfo') 286class AutoRegressiveDFO(DFO): 287 """ 288 Overview: 289 AutoRegressive Derivative-Free Optimizer in paper Implicit Behavioral Cloning. 290 https://arxiv.org/abs/2109.00137 291 Interface: 292 ``__init__``, ``infer`` 293 """ 294 295 def __init__( 296 self, 297 noise_scale: float = 0.33, 298 noise_shrink: float = 0.5, 299 iters: int = 3, 300 train_samples: int = 8, 301 inference_samples: int = 4096, 302 device: str = 'cpu', 303 ): 304 """ 305 Overview: 306 Initialize the AutoRegressive Derivative-Free Optimizer 307 Arguments: 308 - noise_scale (:obj:`float`): Initial noise scale. 309 - noise_shrink (:obj:`float`): Noise scale shrink rate. 310 - iters (:obj:`int`): Number of iterations. 311 - train_samples (:obj:`int`): Number of samples for training. 312 - inference_samples (:obj:`int`): Number of samples for inference. 
313 - device (:obj:`str`): Device. 314 """ 315 super().__init__(noise_scale, noise_shrink, iters, train_samples, inference_samples, device) 316 317 @torch.no_grad() 318 def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor: 319 """ 320 Overview: 321 Optimize for the best action conditioned on the current observation. 322 Arguments: 323 - obs (:obj:`torch.Tensor`): Observations. 324 - ebm (:obj:`torch.nn.Module`): Energy based model. 325 Returns: 326 - best_action_samples (:obj:`torch.Tensor`): Actions. 327 Shapes: 328 - obs (:obj:`torch.Tensor`): :math:`(B, O)`. 329 - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. 330 - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`. 331 Examples: 332 >>> obs = torch.randn(2, 4) 333 >>> ebm = EBM(4, 5) 334 >>> opt = AutoRegressiveDFO() 335 >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0)) 336 >>> best_action_samples = opt.infer(obs, ebm) 337 """ 338 noise_scale = self.noise_scale 339 340 # (B, N, O), (B, N, A) 341 obs, action_samples = self._sample(obs, self.inference_samples) 342 343 for i in range(self.iters): 344 # j: action_dim index 345 for j in range(action_samples.shape[-1]): 346 # (B, N) 347 energies = ebm.forward(obs, action_samples)[..., j] 348 probs = F.softmax(-1.0 * energies, dim=-1) 349 350 # Resample with replacement. 351 idxs = torch.multinomial(probs, self.inference_samples, replacement=True) 352 action_samples = action_samples[torch.arange(action_samples.size(0)).unsqueeze(-1), idxs] 353 354 # Add noise and clip to target bounds. 
355 action_samples[..., j] = action_samples[..., j] + torch.randn_like(action_samples[..., j]) * noise_scale 356 357 action_samples[..., j] = action_samples[..., j].clamp( 358 min=self.action_bounds[0, j], max=self.action_bounds[1, j] 359 ) 360 361 noise_scale *= self.noise_shrink 362 363 # (B, N) 364 energies = ebm.forward(obs, action_samples)[..., -1] 365 probs = F.softmax(-1.0 * energies, dim=-1) 366 # (B, ) 367 best_idxs = probs.argmax(dim=-1) 368 return action_samples[torch.arange(action_samples.size(0)), best_idxs] 369 370 371@STOCHASTIC_OPTIMIZER_REGISTRY.register('mcmc') 372class MCMC(StochasticOptimizer): 373 """ 374 Overview: 375 MCMC method as stochastic optimizers in paper Implicit Behavioral Cloning. 376 https://arxiv.org/abs/2109.00137 377 Interface: 378 ``__init__``, ``sample``, ``infer``, ``grad_penalty`` 379 """ 380 381 class BaseScheduler(ABC): 382 """ 383 Overview: 384 Base class for learning rate scheduler. 385 Interface: 386 ``get_rate`` 387 """ 388 389 @abstractmethod 390 def get_rate(self, index): 391 """ 392 Overview: 393 Abstract method for getting learning rate. 394 """ 395 raise NotImplementedError 396 397 class ExponentialScheduler: 398 """ 399 Overview: 400 Exponential learning rate schedule for Langevin sampler. 401 Interface: 402 ``__init__``, ``get_rate`` 403 """ 404 405 def __init__(self, init, decay): 406 """ 407 Overview: 408 Initialize the ExponentialScheduler. 409 Arguments: 410 - init (:obj:`float`): Initial learning rate. 411 - decay (:obj:`float`): Decay rate. 412 """ 413 self._decay = decay 414 self._latest_lr = init 415 416 def get_rate(self, index): 417 """ 418 Overview: 419 Get learning rate. Assumes calling sequentially. 420 Arguments: 421 - index (:obj:`int`): Current iteration. 422 """ 423 del index 424 lr = self._latest_lr 425 self._latest_lr *= self._decay 426 return lr 427 428 class PolynomialScheduler: 429 """ 430 Overview: 431 Polynomial learning rate schedule for Langevin sampler. 
432 Interface: 433 ``__init__``, ``get_rate`` 434 """ 435 436 def __init__(self, init, final, power, num_steps): 437 """ 438 Overview: 439 Initialize the PolynomialScheduler. 440 Arguments: 441 - init (:obj:`float`): Initial learning rate. 442 - final (:obj:`float`): Final learning rate. 443 - power (:obj:`float`): Power of polynomial. 444 - num_steps (:obj:`int`): Number of steps. 445 """ 446 self._init = init 447 self._final = final 448 self._power = power 449 self._num_steps = num_steps 450 451 def get_rate(self, index): 452 """ 453 Overview: 454 Get learning rate for index. 455 Arguments: 456 - index (:obj:`int`): Current iteration. 457 """ 458 if index == -1: 459 return self._init 460 return ( 461 (self._init - self._final) * ((1 - (float(index) / float(self._num_steps - 1))) ** (self._power)) 462 ) + self._final 463 464 def __init__( 465 self, 466 iters: int = 100, 467 use_langevin_negative_samples: bool = True, 468 train_samples: int = 8, 469 inference_samples: int = 512, 470 stepsize_scheduler: dict = dict( 471 init=0.5, 472 final=1e-5, 473 power=2.0, 474 # num_steps, 475 ), 476 optimize_again: bool = True, 477 again_stepsize_scheduler: dict = dict( 478 init=1e-5, 479 final=1e-5, 480 power=2.0, 481 # num_steps, 482 ), 483 device: str = 'cpu', 484 # langevin_step 485 noise_scale: float = 0.5, 486 grad_clip=None, 487 delta_action_clip: float = 0.5, 488 add_grad_penalty: bool = True, 489 grad_norm_type: str = 'inf', 490 grad_margin: float = 1.0, 491 grad_loss_weight: float = 1.0, 492 **kwargs, 493 ): 494 """ 495 Overview: 496 Initialize the MCMC. 497 Arguments: 498 - iters (:obj:`int`): Number of iterations. 499 - use_langevin_negative_samples (:obj:`bool`): Whether to use Langevin sampler. 500 - train_samples (:obj:`int`): Number of samples for training. 501 - inference_samples (:obj:`int`): Number of samples for inference. 502 - stepsize_scheduler (:obj:`dict`): Step size scheduler for Langevin sampler. 
503 - optimize_again (:obj:`bool`): Whether to run a second optimization. 504 - again_stepsize_scheduler (:obj:`dict`): Step size scheduler for the second optimization. 505 - device (:obj:`str`): Device. 506 - noise_scale (:obj:`float`): Initial noise scale. 507 - grad_clip (:obj:`float`): Gradient clip. 508 - delta_action_clip (:obj:`float`): Action clip. 509 - add_grad_penalty (:obj:`bool`): Whether to add gradient penalty. 510 - grad_norm_type (:obj:`str`): Gradient norm type. 511 - grad_margin (:obj:`float`): Gradient margin. 512 - grad_loss_weight (:obj:`float`): Gradient loss weight. 513 """ 514 self.iters = iters 515 self.use_langevin_negative_samples = use_langevin_negative_samples 516 self.train_samples = train_samples 517 self.inference_samples = inference_samples 518 self.stepsize_scheduler = stepsize_scheduler 519 self.optimize_again = optimize_again 520 self.again_stepsize_scheduler = again_stepsize_scheduler 521 self.device = device 522 523 self.noise_scale = noise_scale 524 self.grad_clip = grad_clip 525 self.delta_action_clip = delta_action_clip 526 self.add_grad_penalty = add_grad_penalty 527 self.grad_norm_type = grad_norm_type 528 self.grad_margin = grad_margin 529 self.grad_loss_weight = grad_loss_weight 530 531 @staticmethod 532 def _gradient_wrt_act( 533 obs: torch.Tensor, 534 action: torch.Tensor, 535 ebm: nn.Module, 536 create_graph: bool = False, 537 ) -> torch.Tensor: 538 """ 539 Overview: 540 Calculate gradient w.r.t action. 541 Arguments: 542 - obs (:obj:`torch.Tensor`): Observations. 543 - action (:obj:`torch.Tensor`): Actions. 544 - ebm (:obj:`torch.nn.Module`): Energy based model. 545 - create_graph (:obj:`bool`): Whether to create graph. 546 Returns: 547 - grad (:obj:`torch.Tensor`): Gradient w.r.t action. 548 Shapes: 549 - obs (:obj:`torch.Tensor`): :math:`(B, N, O)`. 550 - action (:obj:`torch.Tensor`): :math:`(B, N, A)`. 551 - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. 552 - grad (:obj:`torch.Tensor`): :math:`(B, N, A)`. 
553 """ 554 action.requires_grad_(True) 555 energy = ebm.forward(obs, action).sum() 556 # `create_graph` set to `True` when second order derivative 557 # is needed i.e, d(de/da)/d_param 558 grad = torch.autograd.grad(energy, action, create_graph=create_graph)[0] 559 action.requires_grad_(False) 560 return grad 561 562 def grad_penalty(self, obs: torch.Tensor, action: torch.Tensor, ebm: nn.Module) -> torch.Tensor: 563 """ 564 Overview: 565 Calculate gradient penalty. 566 Arguments: 567 - obs (:obj:`torch.Tensor`): Observations. 568 - action (:obj:`torch.Tensor`): Actions. 569 - ebm (:obj:`torch.nn.Module`): Energy based model. 570 Returns: 571 - loss (:obj:`torch.Tensor`): Gradient penalty. 572 Shapes: 573 - obs (:obj:`torch.Tensor`): :math:`(B, N+1, O)`. 574 - action (:obj:`torch.Tensor`): :math:`(B, N+1, A)`. 575 - ebm (:obj:`torch.nn.Module`): :math:`(B, N+1, O)`. 576 - loss (:obj:`torch.Tensor`): :math:`(B, )`. 577 """ 578 if not self.add_grad_penalty: 579 return 0. 580 # (B, N+1, A), this gradient is differentiable w.r.t model parameters 581 de_dact = MCMC._gradient_wrt_act(obs, action, ebm, create_graph=True) 582 583 def compute_grad_norm(grad_norm_type, de_dact) -> torch.Tensor: 584 # de_deact: B, N+1, A 585 # return: B, N+1 586 grad_norm_type_to_ord = { 587 '1': 1, 588 '2': 2, 589 'inf': float('inf'), 590 } 591 ord = grad_norm_type_to_ord[grad_norm_type] 592 return torch.linalg.norm(de_dact, ord, dim=-1) 593 594 # (B, N+1) 595 grad_norms = compute_grad_norm(self.grad_norm_type, de_dact) 596 grad_norms = grad_norms - self.grad_margin 597 grad_norms = grad_norms.clamp(min=0., max=1e10) 598 grad_norms = grad_norms.pow(2) 599 600 grad_loss = grad_norms.mean() 601 return grad_loss * self.grad_loss_weight 602 603 # can not use @torch.no_grad() during the inference 604 # because we need to calculate gradient w.r.t inputs as MCMC updates. 
605 @no_ebm_grad() 606 def _langevin_step(self, obs: torch.Tensor, action: torch.Tensor, stepsize: float, ebm: nn.Module) -> torch.Tensor: 607 """ 608 Overview: 609 Run one langevin MCMC step. 610 Arguments: 611 - obs (:obj:`torch.Tensor`): Observations. 612 - action (:obj:`torch.Tensor`): Actions. 613 - stepsize (:obj:`float`): Step size. 614 - ebm (:obj:`torch.nn.Module`): Energy based model. 615 Returns: 616 - action (:obj:`torch.Tensor`): Actions. 617 Shapes: 618 - obs (:obj:`torch.Tensor`): :math:`(B, N, O)`. 619 - action (:obj:`torch.Tensor`): :math:`(B, N, A)`. 620 - stepsize (:obj:`float`): :math:`(B, )`. 621 - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. 622 """ 623 l_lambda = 1.0 624 de_dact = MCMC._gradient_wrt_act(obs, action, ebm) 625 626 if self.grad_clip: 627 de_dact = de_dact.clamp(min=-self.grad_clip, max=self.grad_clip) 628 629 gradient_scale = 0.5 630 de_dact = (gradient_scale * l_lambda * de_dact + torch.randn_like(de_dact) * l_lambda * self.noise_scale) 631 632 delta_action = stepsize * de_dact 633 delta_action_clip = self.delta_action_clip * 0.5 * (self.action_bounds[1] - self.action_bounds[0]) 634 delta_action = delta_action.clamp(min=-delta_action_clip, max=delta_action_clip) 635 636 action = action - delta_action 637 action = action.clamp(min=self.action_bounds[0], max=self.action_bounds[1]) 638 639 return action 640 641 @no_ebm_grad() 642 def _langevin_action_given_obs( 643 self, 644 obs: torch.Tensor, 645 action: torch.Tensor, 646 ebm: nn.Module, 647 scheduler: BaseScheduler = None 648 ) -> torch.Tensor: 649 """ 650 Overview: 651 Run langevin MCMC for `self.iters` steps. 652 Arguments: 653 - obs (:obj:`torch.Tensor`): Observations. 654 - action (:obj:`torch.Tensor`): Actions. 655 - ebm (:obj:`torch.nn.Module`): Energy based model. 656 - scheduler (:obj:`BaseScheduler`): Learning rate scheduler. 657 Returns: 658 - action (:obj:`torch.Tensor`): Actions. 659 Shapes: 660 - obs (:obj:`torch.Tensor`): :math:`(B, N, O)`. 
            - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.
            - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`.
        """
        if not scheduler:
            self.stepsize_scheduler['num_steps'] = self.iters
            scheduler = MCMC.PolynomialScheduler(**self.stepsize_scheduler)
        stepsize = scheduler.get_rate(-1)
        for i in range(self.iters):
            action = self._langevin_step(obs, action, stepsize, ebm)
            stepsize = scheduler.get_rate(i)
        return action

    @no_ebm_grad()
    def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Create tiled observations and sample counter-negatives for InfoNCE loss.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - tiled_obs (:obj:`torch.Tensor`): Tiled observations.
            - action_samples (:obj:`torch.Tensor`): Action samples.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`.
            - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`.
        Examples:
            >>> obs = torch.randn(2, 4)
            >>> ebm = EBM(4, 5)
            >>> opt = MCMC()
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
            >>> tiled_obs, action_samples = opt.sample(obs, ebm)
        """
        obs, uniform_action_samples = self._sample(obs, self.train_samples)
        if not self.use_langevin_negative_samples:
            return obs, uniform_action_samples
        langevin_action_samples = self._langevin_action_given_obs(obs, uniform_action_samples, ebm)
        return obs, langevin_action_samples

    @no_ebm_grad()
    def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Optimize for the best action conditioned on the current observation.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - best_action_samples (:obj:`torch.Tensor`): Actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`.
            - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
        Examples:
            >>> obs = torch.randn(2, 4)
            >>> ebm = EBM(4, 5)
            >>> opt = MCMC()
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
            >>> best_action_samples = opt.infer(obs, ebm)
        """
        # (B, N, O), (B, N, A)
        obs, uniform_action_samples = self._sample(obs, self.inference_samples)
        action_samples = self._langevin_action_given_obs(
            obs,
            uniform_action_samples,
            ebm,
        )

        # Run a second optimization, a trick for more precise inference.
        if self.optimize_again:
            self.again_stepsize_scheduler['num_steps'] = self.iters
            action_samples = self._langevin_action_given_obs(
                obs,
                action_samples,
                ebm,
                scheduler=MCMC.PolynomialScheduler(**self.again_stepsize_scheduler),
            )

        # action_samples: (B, N, A)
        return self._get_best_action_sample(obs, action_samples, ebm)


@MODEL_REGISTRY.register('ebm')
class EBM(nn.Module):
    """
    Overview:
        Energy based model.
    Interface:
        ``__init__``, ``forward``
    """

    def __init__(
        self,
        obs_shape: int,
        action_shape: int,
        hidden_size: int = 512,
        hidden_layer_num: int = 4,
        **kwargs,
    ):
        """
        Overview:
            Initialize the EBM.
        Arguments:
            - obs_shape (:obj:`int`): Observation shape.
            - action_shape (:obj:`int`): Action shape.
            - hidden_size (:obj:`int`): Hidden size.
            - hidden_layer_num (:obj:`int`): Number of hidden layers.
        """
        super().__init__()
        input_size = obs_shape + action_shape
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size), nn.ReLU(),
            RegressionHead(
                hidden_size,
                1,
                hidden_layer_num,
                final_tanh=False,
            )
        )

    def forward(self, obs, action):
        """
        Overview:
            Forward computation graph of EBM.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observation of shape (B, N, O).
            - action (:obj:`torch.Tensor`): Action of shape (B, N, A).
        Returns:
            - pred (:obj:`torch.Tensor`): Energy of shape (B, N).
        Examples:
            >>> obs = torch.randn(2, 3, 4)
            >>> action = torch.randn(2, 3, 5)
            >>> ebm = EBM(4, 5)
            >>> pred = ebm(obs, action)
        """
        x = torch.cat([obs, action], -1)
        x = self.net(x)
        return x['pred']


@MODEL_REGISTRY.register('arebm')
class AutoregressiveEBM(nn.Module):
    """
    Overview:
        Autoregressive energy based model.
    Interface:
        ``__init__``, ``forward``
    """

    def __init__(
        self,
        obs_shape: int,
        action_shape: int,
        hidden_size: int = 512,
        hidden_layer_num: int = 4,
    ):
        """
        Overview:
            Initialize the AutoregressiveEBM.
        Arguments:
            - obs_shape (:obj:`int`): Observation shape.
            - action_shape (:obj:`int`): Action shape.
            - hidden_size (:obj:`int`): Hidden size.
            - hidden_layer_num (:obj:`int`): Number of hidden layers.
        """
        super().__init__()
        self.ebm_list = nn.ModuleList()
        for i in range(action_shape):
            self.ebm_list.append(EBM(obs_shape, i + 1, hidden_size, hidden_layer_num))

    def forward(self, obs, action):
        """
        Overview:
            Forward computation graph of AutoregressiveEBM.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observation of shape (B, N, O).
            - action (:obj:`torch.Tensor`): Action of shape (B, N, A).
        Returns:
            - pred (:obj:`torch.Tensor`): Energy of shape (B, N, A).
        Examples:
            >>> obs = torch.randn(2, 3, 4)
            >>> action = torch.randn(2, 3, 5)
            >>> arebm = AutoregressiveEBM(4, 5)
            >>> pred = arebm(obs, action)
        """
        output_list = []
        for i, ebm in enumerate(self.ebm_list):
            output_list.append(ebm(obs, action[..., :i + 1]))
        return torch.stack(output_list, dim=-1)
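The Langevin MCMC inference loop above can be sketched end to end on a toy problem. This is a minimal, self-contained illustration only, assuming `torch` is available: `ToyEBM` is a hypothetical stand-in for `EBM` (it avoids the `RegressionHead` dependency), and `langevin_step` mirrors the core of `_langevin_step` (noisy energy gradient, fixed step size, clamp to action bounds `[0, 1]`) without the step-size scheduler, gradient clipping, or delta-action clipping of the real implementation.

```python
import torch
import torch.nn as nn


class ToyEBM(nn.Module):
    """Hypothetical stand-in for EBM: (obs, action) -> per-sample energy."""

    def __init__(self, obs_dim: int, action_dim: int, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + action_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        # (B, N, O) and (B, N, A) -> (B, N)
        return self.net(torch.cat([obs, action], dim=-1)).squeeze(-1)


def langevin_step(obs, action, ebm, stepsize=0.1, noise_scale=0.5):
    # One noisy gradient-descent step on the energy, as in _langevin_step:
    # compute de/da, perturb it with Gaussian noise, move downhill, clamp.
    action = action.detach().requires_grad_(True)
    energy = ebm(obs, action).sum()
    grad = torch.autograd.grad(energy, action)[0]
    update = 0.5 * grad + noise_scale * torch.randn_like(grad)
    return (action - stepsize * update).detach().clamp(0., 1.)


obs = torch.randn(2, 8, 4)    # (B, N, O): tiled observations
action = torch.rand(2, 8, 5)  # (B, N, A): uniform samples in the bounds [0, 1]
ebm = ToyEBM(4, 5)
for _ in range(10):
    action = langevin_step(obs, action, ebm)

# Pick the lowest-energy sample per batch element, as infer() does
# via _get_best_action_sample.
energy = ebm(obs, action)                              # (B, N)
best = action[torch.arange(2), energy.argmin(dim=-1)]  # (B, A)
```

The `sum()` before `torch.autograd.grad` is safe because each `(b, n)` energy depends only on its own action sample, so the summed gradient decomposes into per-sample gradients.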