
ding.torch_utils.optimizer_helper

Adam

Bases: Adam

Overview

Rewritten Adam optimizer that supports more features.

Interfaces: __init__, step, _state_init, get_grad

__init__(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False, optim_type='adam', grad_clip_type=None, clip_value=None, clip_coef=5, clip_norm_type=2.0, clip_momentum_timestep=100, grad_norm_type=None, grad_ignore_type=None, ignore_value=None, ignore_coef=5, ignore_norm_type=2.0, ignore_momentum_timestep=100)

Overview

Init method of the refactored Adam class.

Arguments:

- params (:obj:`iterable`): an iterable of torch.Tensor s or dict s; specifies what Tensors should be optimized
- lr (:obj:`float`): learning rate, default set to 1e-3
- betas (:obj:`Tuple[float, float]`): coefficients used for computing running averages of the gradient and its square, default set to (0.9, 0.999)
- eps (:obj:`float`): term added to the denominator to improve numerical stability, default set to 1e-8
- weight_decay (:obj:`float`): weight decay coefficient, default set to 0
- amsgrad (:obj:`bool`): whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond (https://arxiv.org/abs/1904.09237)
- optim_type (:obj:`str`): support ["adam", "adamw"]
- grad_clip_type (:obj:`str`): support [None, 'clip_momentum', 'clip_value', 'clip_norm', 'clip_momentum_norm']
- clip_value (:obj:`float`): the value to start clipping
- clip_coef (:obj:`float`): the clipping coefficient
- clip_norm_type (:obj:`float`): 2.0 means use the two-norm to clip
- clip_momentum_timestep (:obj:`int`): after how many steps the momentum clipping starts
- grad_ignore_type (:obj:`str`): support [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', 'ignore_momentum_norm']
- ignore_value (:obj:`float`): the value to start ignoring
- ignore_coef (:obj:`float`): the ignoring coefficient
- ignore_norm_type (:obj:`float`): 2.0 means use the two-norm to ignore
- ignore_momentum_timestep (:obj:`int`): after how many steps the momentum ignoring starts
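The 'clip_momentum' mode keeps a running estimate of the squared gradient and clips each gradient element to within ±clip_coef·√v̂, where v̂ is the bias-corrected second-moment estimate. A minimal scalar sketch of that rule, in pure Python with illustrative names mirroring the optimizer's state keys (not the class's actual API):

```python
import math


def momentum_clip(grad, thre_exp_avg_sq, step, beta2=0.999, clip_coef=5.0):
    """Clip a scalar gradient to +/- clip_coef * sqrt(v_hat).

    thre_exp_avg_sq is the running average of squared gradients and
    step is the current timestep (used for bias correction); both names
    mirror the state kept by the rewritten Adam, but this is only a sketch.
    """
    # update the running second-moment estimate, as the optimizer does
    thre_exp_avg_sq = beta2 * thre_exp_avg_sq + (1 - beta2) * grad * grad
    bias_correction2 = 1 - beta2 ** step
    threshold = clip_coef * math.sqrt(thre_exp_avg_sq / bias_correction2)
    clipped = max(-threshold, min(threshold, grad))
    return clipped, thre_exp_avg_sq
```

After a warm-up on unit gradients the threshold settles near clip_coef, so ordinary gradients pass through unchanged while a sudden spike is capped; this is why the real optimizer also waits `clip_momentum_timestep` steps before trusting the estimate.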

step(closure=None)

Overview

Performs a single optimization step

Arguments:

- closure (:obj:`callable`): a closure that reevaluates the model and returns the loss, default set to None

RMSprop

Bases: RMSprop

Overview

Rewritten RMSprop optimizer that supports more features.

Interfaces: __init__, step, _state_init, get_grad

__init__(params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False, grad_clip_type=None, clip_value=None, clip_coef=5, clip_norm_type=2.0, clip_momentum_timestep=100, grad_norm_type=None, grad_ignore_type=None, ignore_value=None, ignore_coef=5, ignore_norm_type=2.0, ignore_momentum_timestep=100)

Overview

Init method of the refactored RMSprop class.

Arguments:

- params (:obj:`iterable`): an iterable of torch.Tensor s or dict s; specifies what Tensors should be optimized
- lr (:obj:`float`): learning rate, default set to 1e-2
- alpha (:obj:`float`): smoothing constant, default set to 0.99
- eps (:obj:`float`): term added to the denominator to improve numerical stability, default set to 1e-8
- weight_decay (:obj:`float`): weight decay coefficient, default set to 0
- centered (:obj:`bool`): if True, compute the centered RMSprop, where the gradient is normalized by an estimation of its variance
- grad_clip_type (:obj:`str`): support [None, 'clip_momentum', 'clip_value', 'clip_norm', 'clip_momentum_norm']
- clip_value (:obj:`float`): the value to start clipping
- clip_coef (:obj:`float`): the clipping coefficient
- clip_norm_type (:obj:`float`): 2.0 means use the two-norm to clip
- clip_momentum_timestep (:obj:`int`): after how many steps the momentum clipping starts
- grad_ignore_type (:obj:`str`): support [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', 'ignore_momentum_norm']
- ignore_value (:obj:`float`): the value to start ignoring
- ignore_coef (:obj:`float`): the ignoring coefficient
- ignore_norm_type (:obj:`float`): 2.0 means use the two-norm to ignore
- ignore_momentum_timestep (:obj:`int`): after how many steps the momentum ignoring starts

step(closure=None)

Overview

Performs a single optimization step

Arguments:

- closure (:obj:`callable`): a closure that reevaluates the model and returns the loss, default set to None

get_grad()

Overview

Calculate the gradient norm of the parameters whose gradients are not None in the model.

PCGrad

Overview

PCGrad optimizer for multi-task learning; see the paper at https://arxiv.org/pdf/2001.06782.pdf

Interfaces: __init__, zero_grad, step, pc_backward

Properties:

- optimizer (:obj:`torch.optim`): the optimizer to be used

optimizer property

Overview

Get the wrapped optimizer.

__init__(optimizer, reduction='mean')

Overview

Initialization of PCGrad optimizer

Arguments:

- optimizer (:obj:`torch.optim`): the optimizer to be used
- reduction (:obj:`str`): the reduction method, support ['mean', 'sum']

zero_grad()

Overview

Clear the gradients of the parameters.

step()

Overview

Update the parameters with the gradients.

pc_backward(objectives)

Overview

Compute per-task gradients, resolve conflicts between them by projection, and set the resulting gradients on the parameters.

Arguments:

- objectives: a list of objectives (one loss per task)
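The core of pc_backward is the PCGrad projection: for every task gradient g_i, any other task gradient g_j with a negative dot product is treated as conflicting, and g_i is projected onto the normal plane of g_j before the gradients are merged. A minimal sketch with plain Python lists standing in for flattened gradient tensors (illustrative only, not the class's internal API; a fixed seed replaces the class's unseeded shuffle):

```python
import copy
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def project_conflicting(grads, seed=0):
    """PCGrad projection over a list of per-task gradient vectors.

    For each g_i, subtract its component along any g_j with dot(g_i, g_j) < 0,
    then average the projected gradients (the 'mean' reduction).
    """
    rng = random.Random(seed)
    pc_grads = copy.deepcopy(grads)
    for g_i in pc_grads:
        others = grads[:]
        rng.shuffle(others)  # visit the other tasks in random order
        for g_j in others:
            d = dot(g_i, g_j)
            if d < 0:
                # remove the conflicting component of g_i along g_j
                scale = d / dot(g_j, g_j)
                for k in range(len(g_i)):
                    g_i[k] -= scale * g_j[k]
    n = len(pc_grads)
    return [sum(g[k] for g in pc_grads) / n for k in range(len(grads[0]))]
```

For two conflicting gradients such as [1, 0] and [-1, 1], each is projected onto the other's normal plane ([0.5, 0.5] and [0, 1]) before averaging, so neither task's update directly opposes the other.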

calculate_grad_norm(model, norm_type=2)

Overview

Calculate the gradient norm of the parameters whose gradients are not None in the model.

Arguments:

- model: torch.nn.Module
- norm_type (:obj:`int` or `inf`)
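The norm computation accumulates each parameter's p-norm raised to the p-th power and takes the p-th root of the total at the end, skipping parameters without gradients. The same arithmetic over plain Python lists (a sketch with illustrative names, standing in for per-parameter gradient tensors):

```python
def grad_norm(grads, norm_type=2):
    """Total p-norm over a list of flat gradient vectors.

    None entries are skipped, just as calculate_grad_norm skips
    parameters whose grad is None; 'inf' gives the max absolute entry.
    """
    grads = [g for g in grads if g is not None]
    if not grads:
        return 0.0
    if norm_type == 'inf':
        return float(max(max(abs(x) for x in g) for g in grads))
    total = 0.0
    for g in grads:
        # per-parameter p-norm, then re-raise to the p-th power and accumulate
        param_norm = sum(abs(x) ** norm_type for x in g) ** (1.0 / norm_type)
        total += param_norm ** norm_type
    return float(total ** (1.0 / norm_type))
```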

calculate_grad_norm_without_bias_two_norm(model)

Overview

Calculate the two-norm of the gradients of all non-bias parameters in the model.

Arguments:

- model: torch.nn.Module

grad_ignore_norm(parameters, max_norm, norm_type=2)

Overview

Zero the gradients of an iterable of parameters if their total norm exceeds max_norm.

Arguments:

- parameters (:obj:`Iterable`): an iterable of torch.Tensor
- max_norm (:obj:`float`): the max norm of the gradients
- norm_type (:obj:`float`): 2.0 means use the two-norm
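Unlike norm clipping, which rescales the update, the "ignore" helpers discard it entirely: if the combined norm exceeds the threshold, every gradient is zeroed. A sketch of that decision over plain Python lists (illustrative name; the real function operates on torch.Tensor gradients):

```python
def grad_ignore_norm_sketch(grads, max_norm, norm_type=2):
    """Zero every gradient vector if their combined norm exceeds max_norm.

    Mirrors grad_ignore_norm's logic: compute the total norm, and if
    max_norm / (total + eps) < 1, drop the whole update instead of scaling it.
    """
    total = sum(sum(abs(x) ** norm_type for x in g) for g in grads) ** (1.0 / norm_type)
    if max_norm / (total + 1e-6) < 1:
        for g in grads:
            for k in range(len(g)):
                g[k] = 0.0
    return total
```

This "all or nothing" behaviour is useful when an occasional pathological batch produces a gradient so large that even a rescaled step would be harmful.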

grad_ignore_value(parameters, clip_value)

Overview

Zero the gradients of an iterable of parameters if any gradient element reaches clip_value.

Arguments:

- parameters (:obj:`Iterable`): an iterable of torch.Tensor
- clip_value (:obj:`float`): the value above which gradients are ignored

configure_weight_decay(model, weight_decay)

Overview

Separates all parameters of the model into two buckets: those that will experience weight decay for regularization and those that won't (biases, and layer-norm or embedding weights).

Arguments:

- model (:obj:`nn.Module`): The given PyTorch model.
- weight_decay (:obj:`float`): Weight decay value for optimizer.

Returns:

- optim groups (:obj:`List`): The parameter groups to be set in the latter optimizer.
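The bucketing described above can be sketched as follows. This is a simplified illustration consistent with the docstring, not necessarily the library's exact implementation; the name-based `no_decay_keys` heuristic and the helper name are assumptions for the example:

```python
def split_decay_groups(named_params, weight_decay, no_decay_keys=('bias', 'ln', 'embed')):
    """Split (name, param) pairs into a decayed and an undecayed group.

    no_decay_keys is an illustrative heuristic: biases and layer-norm /
    embedding weights are exempted from weight decay, as the docstring says.
    """
    decay, no_decay = [], []
    for name, p in named_params:
        if any(key in name for key in no_decay_keys):
            no_decay.append(p)
        else:
            decay.append(p)
    # the two groups can be passed directly to a torch optimizer constructor
    return [
        {'params': decay, 'weight_decay': weight_decay},
        {'params': no_decay, 'weight_decay': 0.0},
    ]
```

The returned list has the shape torch optimizers accept as their `params` argument, so the decay setting is applied per group rather than globally.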

Full Source Code

../ding/torch_utils/optimizer_helper.py

import torch
import math
from torch.nn.utils import clip_grad_norm_, clip_grad_value_
from typing import Union, Iterable, Tuple, Callable, List
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import copy
import random

inf = math.inf


def calculate_grad_norm(model: torch.nn.Module, norm_type=2) -> float:
    """
    Overview:
        Calculate the gradient norm of the parameters whose gradients are not None in the model.
    Arguments:
        - model: torch.nn.Module
        - norm_type (:obj:`int` or `inf`)
    """
    parameters = list(filter(lambda p: p.grad is not None, model.parameters()))
    if parameters == []:
        return 0
    if norm_type == 'inf':
        total_norm = max(p.grad.data.abs().max() for p in parameters)
        return float(total_norm)
    else:
        total_norm = 0
        for p in parameters:
            param_norm = p.grad.data.norm(norm_type)
            total_norm += param_norm.item() ** norm_type
        total_norm = total_norm ** (1. / norm_type)
        return float(total_norm)


def calculate_grad_norm_without_bias_two_norm(model: torch.nn.Module) -> float:
    """
    Overview:
        Calculate the two-norm of the gradients of all non-bias parameters in the model.
    Arguments:
        - model: torch.nn.Module
    """
    _list = []
    for name, param in model.named_parameters():
        if 'bias' not in name and param.requires_grad:
            if param.grad is None:
                return 0
            _list.append(param.grad.data.norm(2).item() ** 2)
    return float(sum(_list) ** (1. / 2))


def grad_ignore_norm(parameters, max_norm, norm_type=2):
    """
    Overview:
        Zero the gradients of an iterable of parameters if their total norm exceeds ``max_norm``.
    Arguments:
        - parameters (:obj:`Iterable`): an iterable of torch.Tensor
        - max_norm (:obj:`float`): the max norm of the gradients
        - norm_type (:obj:`float`): 2.0 means use the two-norm
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if norm_type == inf:
        total_norm = max(p.grad.data.abs().max() for p in parameters)
    else:
        total_norm = 0
        for p in parameters:
            param_norm = p.grad.data.norm(norm_type)
            total_norm += param_norm.item() ** norm_type
        total_norm = total_norm ** (1. / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.zero_()
    return total_norm


def grad_ignore_value(parameters, clip_value):
    """
    Overview:
        Zero the gradients of an iterable of parameters if any gradient element reaches ``clip_value``.
    Arguments:
        - parameters (:obj:`Iterable`): an iterable of torch.Tensor
        - clip_value (:obj:`float`): the value above which gradients are ignored
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    clip_value = float(clip_value)
    flag = False
    for p in filter(lambda p: p.grad is not None, parameters):
        val = p.grad.data.abs().max()
        if val >= clip_value:
            flag = True
            break
    if flag:
        for p in filter(lambda p: p.grad is not None, parameters):
            p.grad.data.zero_()


class Adam(torch.optim.Adam):
    """
    Overview:
        Rewritten Adam optimizer that supports more features.
    Interfaces:
        ``__init__``, ``step``, ``_state_init``, ``get_grad``
    """

    def __init__(
            self,
            params: Iterable,
            lr: float = 1e-3,
            betas: Tuple[float, float] = (0.9, 0.999),
            eps: float = 1e-8,
            weight_decay: float = 0,
            amsgrad: bool = False,
            optim_type: str = 'adam',
            grad_clip_type: str = None,
            clip_value: Union[float, None] = None,
            clip_coef: float = 5,
            clip_norm_type: float = 2.0,
            clip_momentum_timestep: int = 100,
            grad_norm_type: str = None,
            grad_ignore_type: str = None,
            ignore_value: Union[float, None] = None,
            ignore_coef: float = 5,
            ignore_norm_type: float = 2.0,
            ignore_momentum_timestep: int = 100,
    ):
        """
        Overview:
            Init method of the refactored Adam class.
        Arguments:
            - params (:obj:`iterable`): an iterable of torch.Tensor s or dict s, \
                specifies what Tensors should be optimized
            - lr (:obj:`float`): learning rate, default set to 1e-3
            - betas (:obj:`Tuple[float, float]`): coefficients used for computing running averages of the \
                gradient and its square, default set to (0.9, 0.999)
            - eps (:obj:`float`): term added to the denominator to improve numerical stability, default set to 1e-8
            - weight_decay (:obj:`float`): weight decay coefficient, default set to 0
            - amsgrad (:obj:`bool`): whether to use the AMSGrad variant of this algorithm from the paper \
                `On the Convergence of Adam and Beyond <https://arxiv.org/abs/1904.09237>`_
            - optim_type (:obj:`str`): support ["adam", "adamw"]
            - grad_clip_type (:obj:`str`): support [None, 'clip_momentum', 'clip_value', 'clip_norm', \
                'clip_momentum_norm']
            - clip_value (:obj:`float`): the value to start clipping
            - clip_coef (:obj:`float`): the clipping coefficient
            - clip_norm_type (:obj:`float`): 2.0 means use the two-norm to clip
            - clip_momentum_timestep (:obj:`int`): after how many steps the momentum clipping starts
            - grad_ignore_type (:obj:`str`): support [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', \
                'ignore_momentum_norm']
            - ignore_value (:obj:`float`): the value to start ignoring
            - ignore_coef (:obj:`float`): the ignoring coefficient
            - ignore_norm_type (:obj:`float`): 2.0 means use the two-norm to ignore
            - ignore_momentum_timestep (:obj:`int`): after how many steps the momentum ignoring starts
        """

        self._support_type = {
            'optim': ['adam', 'adamw'],
            'grad_clip': [None, 'clip_momentum', 'clip_value', 'clip_norm', 'clip_momentum_norm'],
            'grad_norm': [None],
            'grad_ignore': [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', 'ignore_momentum_norm'],
        }

        assert optim_type in self._support_type['optim']
        assert grad_clip_type in self._support_type['grad_clip']
        assert grad_norm_type in self._support_type['grad_norm']
        assert grad_ignore_type in self._support_type['grad_ignore']
        if grad_clip_type:
            assert clip_value is not None
        if grad_ignore_type:
            assert ignore_value is not None

        self._optim_type = optim_type
        self._grad_clip_type = grad_clip_type
        self._grad_norm_type = grad_norm_type
        self._grad_ignore_type = grad_ignore_type
        self._clip_value = clip_value
        self._clip_norm_type = clip_norm_type
        self._clip_coef = clip_coef
        self._ignore_value = ignore_value
        self._ignore_norm_type = ignore_norm_type
        self._ignore_coef = ignore_coef
        self._clip_momentum_timestep = clip_momentum_timestep
        self._ignore_momentum_timestep = ignore_momentum_timestep

        if self._optim_type == 'adamw':
            self._weight_decay = weight_decay
            super(Adam, self).__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=0, amsgrad=amsgrad)
        elif self._optim_type == 'adam':
            super(Adam, self).__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
        else:
            raise NotImplementedError(
                "optimizer type {} is not implemented, support type is {}".format(
                    self._optim_type, self._support_type['optim']
                )
            )

    def _state_init(self, p, amsgrad):
        """
        Overview:
            Initialize the state of the optimizer
        Arguments:
            - p (:obj:`torch.Tensor`): the parameter to be optimized
            - amsgrad (:obj:`bool`): whether to use the AMSGrad variant of this algorithm from the paper \
                `On the Convergence of Adam and Beyond <https://arxiv.org/abs/1904.09237>`_
        """
        state = self.state[p]
        state['thre_exp_avg_sq'] = torch.zeros_like(p.data, device=p.data.device)
        # others
        if torch.__version__ < "1.12.0":
            state['step'] = 0
            # TODO: wait for torch upgrade; 1.3.1 didn't support memory_format for state['step']
        else:
            state['step'] = torch.zeros((1, ), dtype=torch.float, device=p.device) \
                if self.defaults['capturable'] else torch.tensor(0.)
        state['exp_avg'] = torch.zeros_like(p.data)
        # Exponential moving average of squared gradient values
        state['exp_avg_sq'] = torch.zeros_like(p.data)
        if amsgrad:
            # Maintains max of all exp. moving avg. of sq. grad. values
            state['max_exp_avg_sq'] = torch.zeros_like(p.data)

    def step(self, closure: Union[Callable, None] = None):
        """
        Overview:
            Performs a single optimization step
        Arguments:
            - closure (:obj:`callable`): a closure that reevaluates the model and returns the loss, \
                default set to None
        """
        # clipping
        new_params = [
            t for group in self.param_groups for t in group['params'] if t.requires_grad and t.grad is not None
        ]
        if self._grad_clip_type == 'clip_value':
            clip_grad_value_(new_params, self._clip_value)
        elif self._grad_clip_type == 'clip_norm':
            clip_grad_norm_(new_params, self._clip_value, self._clip_norm_type)
        elif self._grad_clip_type == 'clip_momentum':
            '''
            This implementation mimics the clipping used by OpenAI, quote:
            'Gradients are additionally clipped per parameter to be within between ±5√v
            where v is the running estimate of the second moment of the (unclipped) gradient'
            '''
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    if len(state) == 0:
                        self._state_init(p, group['amsgrad'])
                    grad = p.grad.data
                    # should we use the same beta group?
                    beta1, beta2 = group['betas']
                    bias_correction2 = 1 - beta2 ** state['step']
                    state['thre_exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                    if state['step'] >= self._clip_momentum_timestep:  # initial value is inaccurate
                        flag = grad.abs() > (state['thre_exp_avg_sq'].sqrt() /
                                             math.sqrt(bias_correction2)) * self._clip_coef
                        grad.mul_(~flag).add_(
                            ((state['thre_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) *
                             self._clip_coef).mul_(flag)
                        )
        elif self._grad_clip_type == 'clip_momentum_norm':
            # might have multiple param_groups, so we calculate each group separately
            for group in self.param_groups:
                total_norm = 0
                total_momentum_norm = 0
                step = inf
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    if len(state) == 0:
                        self._state_init(p, group['amsgrad'])
                    grad = p.grad.data
                    # should we use the same beta group?
                    beta1, beta2 = group['betas']
                    bias_correction2 = 1 - beta2 ** state['step']
                    state['thre_exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                    # sum total_norm
                    param_norm = grad.norm(self._clip_norm_type)
                    total_norm += param_norm.item() ** self._clip_norm_type

                    # sum momentum_norm
                    momentum = ((state['thre_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) *
                                self._clip_coef).norm(self._clip_norm_type)
                    total_momentum_norm += momentum.item() ** self._clip_norm_type
                    step = min(step, state['step'])
                if step > self._clip_momentum_timestep:
                    total_norm = total_norm ** (1. / self._clip_norm_type)
                    total_momentum_norm = total_momentum_norm ** (1. / self._clip_norm_type)
                    clip_coef = total_momentum_norm / (total_norm + 1e-6)
                    if clip_coef < 1:
                        for p in group['params']:
                            p.grad.data.mul_(clip_coef)

        if self._grad_ignore_type == 'ignore_value':
            grad_ignore_value(new_params, self._ignore_value)
        elif self._grad_ignore_type == 'ignore_norm':
            grad_ignore_norm(new_params, self._ignore_value, self._ignore_norm_type)
        elif self._grad_ignore_type == 'ignore_momentum':
            flag = False
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    if len(state) == 0:
                        self._state_init(p, group['amsgrad'])
                    grad = p.grad.data
                    # should we use the same beta group?
                    beta1, beta2 = group['betas']
                    bias_correction2 = 1 - beta2 ** state['step']
                    state['thre_exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                    if state['step'] >= self._ignore_momentum_timestep:  # initial value is inaccurate
                        if grad.abs() > (state['thre_exp_avg_sq'].sqrt() /
                                         math.sqrt(bias_correction2)) * self._ignore_coef:
                            flag = True
                            break
                else:
                    continue
                break

            if flag:
                for group in self.param_groups:
                    for p in group['params']:
                        if p.grad is None:
                            continue
                        p.grad.zero_()
        elif self._grad_ignore_type == 'ignore_momentum_norm':
            # might have multiple param_groups, so we calculate each group separately
            step = inf
            for group in self.param_groups:
                total_norm = 0
                total_momentum_norm = 0
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    if len(state) == 0:
                        self._state_init(p, group['amsgrad'])
                    grad = p.grad.data
                    # should we use the same beta group?
                    beta1, beta2 = group['betas']
                    bias_correction2 = 1 - beta2 ** state['step']
                    state['thre_exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                    # sum total_norm
                    param_norm = grad.norm(self._ignore_norm_type)
                    total_norm += param_norm.item() ** self._ignore_norm_type

                    # sum momentum_norm
                    momentum = ((state['thre_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) *
                                self._ignore_coef).norm(self._ignore_norm_type)
                    total_momentum_norm += momentum.item() ** self._ignore_norm_type
                    step = min(step, state['step'])

                if step > self._ignore_momentum_timestep:
                    total_norm = total_norm ** (1. / self._ignore_norm_type)
                    total_momentum_norm = total_momentum_norm ** (1. / self._ignore_norm_type)
                    ignore_coef = total_momentum_norm / (total_norm + 1e-6)
                    if ignore_coef < 1:
                        for p in group['params']:
                            p.grad.zero_()

        # Adam optim type
        if self._optim_type == 'adamw':
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    p.data = p.data.add(p.data, alpha=-self._weight_decay * group['lr'])
            return super().step(closure=closure)
        elif self._optim_type == 'adam':
            return super().step(closure=closure)

    def get_grad(self) -> float:
        total_norm = 0.
        params = [t for group in self.param_groups for t in group['params'] if t.requires_grad and t.grad is not None]
        for p in params:
            param_norm = p.grad.data.norm(self._clip_norm_type)
            total_norm += param_norm.item() ** self._clip_norm_type
        return total_norm


class RMSprop(torch.optim.RMSprop):
    r"""
    Overview:
        Rewritten RMSprop optimizer that supports more features.
    Interfaces:
        ``__init__``, ``step``, ``_state_init``, ``get_grad``
    """

    def __init__(
            self,
            params: Iterable,
            lr: float = 1e-2,
            alpha: float = 0.99,
            eps: float = 1e-8,
            weight_decay: float = 0,
            momentum: float = 0,
            centered: bool = False,
            grad_clip_type: str = None,
            clip_value: Union[float, None] = None,
            clip_coef: float = 5,
            clip_norm_type: float = 2.0,
            clip_momentum_timestep: int = 100,
            grad_norm_type: str = None,
            grad_ignore_type: str = None,
            ignore_value: Union[float, None] = None,
            ignore_coef: float = 5,
            ignore_norm_type: float = 2.0,
            ignore_momentum_timestep: int = 100,
    ):
        """
        Overview:
            Init method of the refactored RMSprop class.
        Arguments:
            - params (:obj:`iterable`): an iterable of torch.Tensor s or dict s, \
                specifies what Tensors should be optimized
            - lr (:obj:`float`): learning rate, default set to 1e-2
            - alpha (:obj:`float`): smoothing constant, default set to 0.99
            - eps (:obj:`float`): term added to the denominator to improve numerical stability, default set to 1e-8
            - weight_decay (:obj:`float`): weight decay coefficient, default set to 0
            - centered (:obj:`bool`): if True, compute the centered RMSprop, \
                the gradient is normalized by an estimation of its variance
            - grad_clip_type (:obj:`str`): support [None, 'clip_momentum', 'clip_value', 'clip_norm', \
                'clip_momentum_norm']
            - clip_value (:obj:`float`): the value to start clipping
            - clip_coef (:obj:`float`): the clipping coefficient
            - clip_norm_type (:obj:`float`): 2.0 means use the two-norm to clip
            - clip_momentum_timestep (:obj:`int`): after how many steps the momentum clipping starts
            - grad_ignore_type (:obj:`str`): support [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', \
                'ignore_momentum_norm']
            - ignore_value (:obj:`float`): the value to start ignoring
            - ignore_coef (:obj:`float`): the ignoring coefficient
            - ignore_norm_type (:obj:`float`): 2.0 means use the two-norm to ignore
            - ignore_momentum_timestep (:obj:`int`): after how many steps the momentum ignoring starts
        """

        self._support_type = {
            'grad_clip': [None, 'clip_momentum', 'clip_value', 'clip_norm', 'clip_momentum_norm'],
            'grad_norm': [None],
            'grad_ignore': [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', 'ignore_momentum_norm'],
        }

        assert grad_clip_type in self._support_type['grad_clip']
        assert grad_norm_type in self._support_type['grad_norm']
        assert grad_ignore_type in self._support_type['grad_ignore']
        if grad_clip_type:
            assert clip_value is not None
        if grad_ignore_type:
            assert ignore_value is not None

        self._grad_clip_type = grad_clip_type
        self._grad_norm_type = grad_norm_type
        self._grad_ignore_type = grad_ignore_type
        self._clip_value = clip_value
        self._clip_norm_type = clip_norm_type
        self._clip_coef = clip_coef
        self._ignore_value = ignore_value
        self._ignore_norm_type = ignore_norm_type
        self._ignore_coef = ignore_coef
        self._clip_momentum_timestep = clip_momentum_timestep
        self._ignore_momentum_timestep = ignore_momentum_timestep

        super(RMSprop, self).__init__(
            params, lr=lr, alpha=alpha, eps=eps, weight_decay=weight_decay, momentum=momentum, centered=centered
        )

    def _state_init(self, p, momentum, centered):
        """
        Overview:
            Initialize the state of the optimizer
        Arguments:
            - p (:obj:`torch.Tensor`): the parameter to be optimized
            - momentum (:obj:`float`): the momentum coefficient
            - centered (:obj:`bool`): if True, compute the centered RMSprop, \
                the gradient is normalized by an estimation of its variance
        """

        state = self.state[p]
        if torch.__version__ < "1.12.0":
            state['step'] = 0
            # TODO: wait for torch upgrade; 1.3.1 didn't support memory_format for state['step']
        else:
            state['step'] = torch.zeros((1, ), dtype=torch.float, device=p.device) \
                if ('capturable' in self.defaults and self.defaults['capturable']) else torch.tensor(0.)
        state['thre_square_avg'] = torch.zeros_like(p.data, device=p.data.device)
        state['square_avg'] = torch.zeros_like(p.data, device=p.data.device)
        if momentum:
            state['momentum_buffer'] = torch.zeros_like(p.data, device=p.data.device)
        if centered:
            state['grad_avg'] = torch.zeros_like(p.data, device=p.data.device)

    def step(self, closure: Union[Callable, None] = None):
        """
        Overview:
            Performs a single optimization step
        Arguments:
            - closure (:obj:`callable`): a closure that reevaluates the model and returns the loss, \
                default set to None
        """
        # clipping
        new_params = [
            t for group in self.param_groups for t in group['params'] if t.requires_grad and t.grad is not None
        ]
        if self._grad_clip_type == 'clip_value':
            clip_grad_value_(new_params, self._clip_value)
        elif self._grad_clip_type == 'clip_norm':
            clip_grad_norm_(new_params, self._clip_value, self._clip_norm_type)
        elif self._grad_clip_type == 'clip_momentum':
            '''
            This implementation mimics the clipping used by OpenAI, quote:
            'Gradients are additionally clipped per parameter to be within between ±5√v
            where v is the running estimate of the second moment of the (unclipped) gradient'
            '''
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    if len(state) == 0:
                        self._state_init(p, group['momentum'], group['centered'])
                    grad = p.grad.data
                    alpha = group['alpha']
                    state['thre_square_avg'].mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
                    if state['step'] >= self._clip_momentum_timestep:  # initial value is inaccurate
                        flag = grad.abs() > state['thre_square_avg'].sqrt() * self._clip_coef
                        grad.mul_(~flag).add_((state['thre_square_avg'].sqrt() * self._clip_coef).mul_(flag))
        elif self._grad_clip_type == 'clip_momentum_norm':
            # might have multiple param_groups, so we calculate each group separately
            for group in self.param_groups:
                total_norm = 0
                total_momentum_norm = 0
                step = inf
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    if len(state) == 0:
                        self._state_init(p, group['momentum'], group['centered'])
                    grad = p.grad.data
                    alpha = group['alpha']
                    state['thre_square_avg'].mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
                    # sum total_norm
                    param_norm = grad.norm(self._clip_norm_type)
                    total_norm += param_norm.item() ** self._clip_norm_type

                    # sum momentum_norm
                    momentum = (state['thre_square_avg'].sqrt() * self._clip_coef).norm(self._clip_norm_type)
                    total_momentum_norm += momentum.item() ** self._clip_norm_type
                    step = min(step, state['step'])
                if step > self._clip_momentum_timestep:
                    total_norm = total_norm ** (1. / self._clip_norm_type)
                    total_momentum_norm = total_momentum_norm ** (1. / self._clip_norm_type)
                    clip_coef = total_momentum_norm / (total_norm + 1e-6)
                    if clip_coef < 1:
                        for p in group['params']:
                            p.grad.data.mul_(clip_coef)

        if self._grad_ignore_type == 'ignore_value':
            grad_ignore_value(new_params, self._ignore_value)
        elif self._grad_ignore_type == 'ignore_norm':
            grad_ignore_norm(new_params, self._ignore_value, self._ignore_norm_type)
        elif self._grad_ignore_type == 'ignore_momentum':
            flag = False
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    if len(state) == 0:
                        self._state_init(p, group['momentum'], group['centered'])
                    grad = p.grad.data
                    alpha = group['alpha']
                    state['thre_square_avg'].mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
                    if state['step'] >= self._ignore_momentum_timestep:  # initial value is inaccurate
                        if grad.abs() > state['thre_square_avg'].sqrt() * self._ignore_coef:
                            flag = True
                            break
                else:
                    continue
                break

            if flag:
                for group in self.param_groups:
                    for p in group['params']:
                        if p.grad is None:
                            continue
                        p.grad.zero_()
        elif self._grad_ignore_type == 'ignore_momentum_norm':
            # might have multiple param_groups, so we calculate each group separately
            step = inf
            for group in self.param_groups:
                total_norm = 0
                total_momentum_norm = 0
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    if len(state) == 0:
                        self._state_init(p, group['momentum'], group['centered'])
                    grad = p.grad.data
                    alpha = group['alpha']
                    state['thre_square_avg'].mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
                    # sum total_norm
                    param_norm = grad.norm(self._ignore_norm_type)
                    total_norm += param_norm.item() ** self._ignore_norm_type

                    # sum momentum_norm
                    momentum = (state['thre_square_avg'].sqrt() * self._ignore_coef).norm(self._ignore_norm_type)
                    total_momentum_norm += momentum.item() ** self._ignore_norm_type
                    step = min(step, state['step'])

                if step > self._ignore_momentum_timestep:
                    total_norm = total_norm ** (1. / self._ignore_norm_type)
                    total_momentum_norm = total_momentum_norm ** (1. / self._ignore_norm_type)
                    ignore_coef = total_momentum_norm / (total_norm + 1e-6)
                    if ignore_coef < 1:
                        for p in group['params']:
                            p.grad.zero_()

        return super().step(closure=closure)

    def get_grad(self) -> float:
        """
        Overview:
            Calculate the gradient norm of the parameters whose gradients are not None in the model.
        """

        total_norm = 0.
        params = [t for group in self.param_groups for t in group['params'] if t.requires_grad and t.grad is not None]
        for p in params:
            param_norm = p.grad.data.norm(self._clip_norm_type)
            total_norm += param_norm.item() ** self._clip_norm_type
        return total_norm


class PCGrad():
    """
    Overview:
        PCGrad optimizer for multi-task learning. \
        You can view the paper in the following link <https://arxiv.org/pdf/2001.06782.pdf>
    Interfaces:
        ``__init__``, ``zero_grad``, ``step``, ``pc_backward``
    Properties:
        - optimizer (:obj:`torch.optim`): the optimizer to be used
    """

    def __init__(self, optimizer, reduction='mean'):
        """
        Overview:
            Initialization of PCGrad optimizer
        Arguments:
            - optimizer (:obj:`torch.optim`): the optimizer to be used
            - reduction (:obj:`str`): the reduction method, support ['mean', 'sum']
        """

        self._optim, self._reduction = optimizer, reduction

    @property
    def optimizer(self):
        """
        Overview:
            Get the wrapped optimizer.
        """

        return self._optim

    def zero_grad(self):
        """
        Overview:
            Clear the gradients of the parameters.
        """

        return self._optim.zero_grad(set_to_none=True)

    def step(self):
        """
        Overview:
            Update the parameters with the gradients.
        """

        return self._optim.step()

    def pc_backward(self, objectives):
        """
        Overview:
            Compute per-task gradients, resolve conflicts between them, and set the result on the parameters.
        Arguments:
            - objectives: a list of objectives (one loss per task)
        """

        grads, shapes, has_grads = self._pack_grad(objectives)
        pc_grad = self._project_conflicting(grads, has_grads)
        pc_grad = self._unflatten_grad(pc_grad, shapes[0])
        self._set_grad(pc_grad)
        return

    def _project_conflicting(self, grads, has_grads, shapes=None):
        """
        Overview:
            Project each conflicting gradient onto the normal plane of the gradient it conflicts with.
        Arguments:
            - grads (:obj:`list`): a list of the gradients of the parameters
            - has_grads (:obj:`list`): a list of masks representing whether the parameter has a gradient
            - shapes (:obj:`list`): a list of the shapes of the parameters
        """

        shared = torch.stack(has_grads).prod(0).bool()
        pc_grad, num_task = copy.deepcopy(grads), len(grads)
        for g_i in pc_grad:
            random.shuffle(grads)
            for g_j in grads:
                g_i_g_j = torch.dot(g_i, g_j)
                if g_i_g_j < 0:
                    g_i -= (g_i_g_j) * g_j / (g_j.norm() ** 2)
        merged_grad = torch.zeros_like(grads[0]).to(grads[0].device)
        if self._reduction == 'mean':
            merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).mean(dim=0)
        elif self._reduction == 'sum':
            merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).sum(dim=0)
        else:
            raise KeyError("invalid reduction method")

        merged_grad[~shared] = torch.stack([g[~shared] for g in pc_grad]).sum(dim=0)
        return merged_grad

    def _set_grad(self, grads):
        """
        Overview:
            Set the modified gradients on the network
        Arguments:
            - grads (:obj:`list`): a list of the gradients of the parameters
        """

        idx = 0
        for group in self._optim.param_groups:
            for p in group['params']:
                # if p.grad is None: continue
                p.grad = grads[idx]
                idx += 1
        return

    def _pack_grad(self, objectives):
        """
        Overview:
            Pack the gradients of the network parameters for each objective
        Arguments:
            - objectives: a list of objectives
        Returns:
            - grad: a list of the gradients of the parameters
            - shape: a list of the shapes of the parameters
            - has_grad: a list of masks representing whether the parameter has a gradient
        """

        grads, shapes, has_grads = [], [], []
        for obj in objectives:
            self._optim.zero_grad(set_to_none=True)
            obj.backward(retain_graph=True)
            grad, shape, has_grad = self._retrieve_grad()
            grads.append(self._flatten_grad(grad, shape))
            has_grads.append(self._flatten_grad(has_grad, shape))
            shapes.append(shape)
        return grads, shapes, has_grads

    def _unflatten_grad(self, grads, shapes):
        """
        Overview:
            Unflatten the gradients of the network parameters
        Arguments:
            - grads (:obj:`list`): a list of the gradients of the parameters
            - shapes (:obj:`list`): a list of the shapes of the parameters
        """

        unflatten_grad, idx = [], 0
        for shape in shapes:
            length = np.prod(shape)
            unflatten_grad.append(grads[idx:idx + length].view(shape).clone())
            idx += length
        return unflatten_grad

    def _flatten_grad(self, grads, shapes):
        """
        Overview:
            Flatten the gradients of the network parameters
        Arguments:
            - grads (:obj:`list`): a list of the gradients of the parameters
            - shapes (:obj:`list`): a list of the shapes of the parameters
        """

        flatten_grad = torch.cat([g.flatten() for g in grads])
        return flatten_grad

    def _retrieve_grad(self):
        """
        Overview:
            Get the gradients of the network parameters for the current objective
        Returns:
            - grad: a list of the gradients of the parameters
            - shape: a list of the shapes of the parameters
            - has_grad: a list of masks representing whether the parameter has a gradient
        """

        grad, shape, has_grad = [], [], []
        for group in self._optim.param_groups:
            for p in group['params']:
                # tackle the multi-head scenario
                if p.grad is None:
                    shape.append(p.shape)
                    grad.append(torch.zeros_like(p).to(p.device))
                    has_grad.append(torch.zeros_like(p).to(p.device))
                    continue
                shape.append(p.grad.shape)
                grad.append(p.grad.clone())
                has_grad.append(torch.ones_like(p).to(p.device))
        return grad, shape, has_grad


def configure_weight_decay(model: nn.Module, weight_decay: float) -> List:
    """
    Overview:
        Separate all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layer-norm or embedding weights).
    Arguments:
        - model (:obj:`nn.Module`): The given PyTorch model.
        - weight_decay (:obj:`float`): Weight decay value for optimizer.
    Returns:
        - optim groups (:obj:`List`): The parameter groups to be set in the latter optimizer.
    """
842 """ 843 decay = set() 844 no_decay = set() 845 whitelist_weight_modules = (torch.nn.Linear, ) 846 blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) 847 for mn, m in model.named_modules(): 848 for pn, p in m.named_parameters(): 849 fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 850 # Because named_modules and named_parameters are recursive 851 # we will see the same tensors p many times. But doing it this way 852 # allows us to know which parent module any tensor p belongs to. 853 if pn.endswith('bias'): 854 # all biases will not be decayed 855 no_decay.add(fpn) 856 elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 857 # weights of whitelist modules will be weight decayed 858 decay.add(fpn) 859 elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 860 # weights of blacklist modules will NOT be weight decayed 861 no_decay.add(fpn) 862 else: 863 decay.add(fpn) 864 865 decay = decay - no_decay 866 # validate that we considered every parameter 867 param_dict = {pn: p for pn, p in model.named_parameters()} 868 union_params = decay | no_decay 869 assert len( 870 param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ 871 % (str(param_dict.keys() - union_params),) 872 873 optim_groups = [ 874 { 875 "params": [param_dict[pn] for pn in sorted(list(decay))], 876 "weight_decay": weight_decay 877 }, 878 { 879 "params": [param_dict[pn] for pn in sorted(list(no_decay))], 880 "weight_decay": 0.0 881 }, 882 ] 883 884 return optim_groups
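The core of `_project_conflicting` is the PCGrad surgery step: when two task gradients point against each other (negative dot product), the conflicting component is subtracted out. A minimal numpy sketch (hypothetical helper name, deterministic iteration instead of the shuffled order used above):

```python
import numpy as np

def project_conflicting(grads):
    """PCGrad projection sketch: if g_i conflicts with g_j (dot < 0),
    remove the component of g_i that points along g_j."""
    pc = [g.copy() for g in grads]
    for g_i in pc:
        for g_j in grads:
            dot = float(np.dot(g_i, g_j))
            if dot < 0:
                # subtract the conflicting component of g_i along g_j
                g_i -= dot * g_j / (np.linalg.norm(g_j) ** 2)
    return np.mean(pc, axis=0)  # 'mean' reduction over tasks

# Two conflicting task gradients (dot product is -1)
g1 = np.array([1.0, 0.0])
g2 = np.array([-1.0, 1.0])
merged = project_conflicting([g1, g2])  # -> array([0.25, 0.75])
```

After projection, the merged direction no longer descends one task's loss at the expense of the other: its dot product with each original task gradient is non-negative.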
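`_flatten_grad` and `_unflatten_grad` exist because the projection operates on one flat vector per task, which must then be split back into per-parameter tensors. The round trip can be sketched in numpy (hypothetical helper names):

```python
import numpy as np

def flatten(grads):
    # concatenate all per-parameter gradients into one flat vector
    return np.concatenate([g.ravel() for g in grads])

def unflatten(flat, shapes):
    # split the flat vector back into tensors of the original shapes
    out, idx = [], 0
    for shape in shapes:
        length = int(np.prod(shape))
        out.append(flat[idx:idx + length].reshape(shape).copy())
        idx += length
    return out

grads = [np.arange(6.0).reshape(2, 3), np.array([7.0, 8.0])]
shapes = [g.shape for g in grads]
flat = flatten(grads)        # shape (8,)
restored = unflatten(flat, shapes)
```

The round trip is lossless: `restored` matches `grads` element-wise, which is what lets `pc_backward` hand the surgically modified flat gradient back to `_set_grad`.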
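The bucketing rule in `configure_weight_decay` can be illustrated without torch: biases and blacklist-module weights (LayerNorm, Embedding) go to the zero-decay group, everything else is decayed. A torch-free sketch operating on parameter names only (hypothetical helper; the real function checks module types via `isinstance`, here approximated by substring matching on the module path):

```python
# Stand-ins for the isinstance checks on blacklist module types
BLACKLIST = ('layernorm', 'embedding')

def split_decay_groups(param_names, weight_decay):
    decay, no_decay = set(), set()
    for name in param_names:
        module, _, leaf = name.rpartition('.')
        if leaf.endswith('bias') or any(b in module.lower() for b in BLACKLIST):
            no_decay.add(name)
        else:
            decay.add(name)
    decay -= no_decay  # every parameter lands in exactly one bucket
    return [
        {"params": sorted(decay), "weight_decay": weight_decay},
        {"params": sorted(no_decay), "weight_decay": 0.0},
    ]

groups = split_decay_groups(
    ["linear.weight", "linear.bias", "layernorm.weight", "layernorm.bias"],
    weight_decay=0.01,
)
# Only linear.weight is decayed; all biases and the LayerNorm weight are not.
```

In the real function the returned groups are passed straight to the optimizer constructor, e.g. `torch.optim.AdamW(configure_weight_decay(model, 0.01), lr=1e-3)`, so the decay/no-decay split is applied per parameter group.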