from ditk import logging
import functools
import signal
import sys
import traceback
from typing import Callable
import torch
import torch.utils.data  # torch1.1.0 compatibility
from ding.utils import read_file, save_file

logger = logging.getLogger('default_logger')


def build_checkpoint_helper(cfg):
    """
    Overview:
        Use config to build checkpoint helper.
    Arguments:
        - cfg (:obj:`dict`): ckpt_helper config (currently unused, kept for interface compatibility)
    Returns:
        - (:obj:`CheckpointHelper`): checkpoint_helper created by this function
    """
    return CheckpointHelper()


class CheckpointHelper:
    """
    Overview:
        Help to save or load checkpoint by given args.
    Interfaces:
        ``__init__``, ``save``, ``load``, ``_remove_prefix``, ``_add_prefix``, ``_load_matched_model_state_dict``
    """

    def __init__(self):
        pass

    def _remove_prefix(self, state_dict: dict, prefix: str = 'module.') -> dict:
        """
        Overview:
            Remove prefix in state_dict
        Arguments:
            - state_dict (:obj:`dict`): model's state_dict
            - prefix (:obj:`str`): this prefix will be removed in keys
        Returns:
            - new_state_dict (:obj:`dict`): new state_dict after removing prefix
        """
        new_state_dict = {}
        for k, v in state_dict.items():
            # Strip only a *leading* prefix. The previous implementation used
            # ``''.join(k.split(prefix))``, which also deleted every interior
            # occurrence of the prefix (e.g. 'module.module_layer.w' was mangled).
            new_k = k[len(prefix):] if k.startswith(prefix) else k
            new_state_dict[new_k] = v
        return new_state_dict

    def _add_prefix(self, state_dict: dict, prefix: str = 'module.') -> dict:
        """
        Overview:
            Add prefix in state_dict
        Arguments:
            - state_dict (:obj:`dict`): model's state_dict
            - prefix (:obj:`str`): this prefix will be added in keys
        Returns:
            - (:obj:`dict`): new state_dict after adding prefix
        """
        return {prefix + k: v for k, v in state_dict.items()}

    def _apply_prefix_op(self, state_dict: dict, prefix_op: str, prefix: str) -> dict:
        """
        Overview:
            Dispatch a prefix operation on ``state_dict`` (shared by ``save`` and ``load``).
        Arguments:
            - state_dict (:obj:`dict`): model's state_dict
            - prefix_op (:obj:`str`): should be in ['remove', 'add']
            - prefix (:obj:`str`): prefix to be processed on state_dict
        Returns:
            - (:obj:`dict`): new state_dict after the prefix operation
        Raises:
            - KeyError: if ``prefix_op`` is not a supported operation
        """
        prefix_func = {'remove': self._remove_prefix, 'add': self._add_prefix}
        if prefix_op not in prefix_func:
            raise KeyError('invalid prefix_op:{}'.format(prefix_op))
        return prefix_func[prefix_op](state_dict, prefix)

    def save(
            self,
            path: str,
            model: torch.nn.Module,
            optimizer: torch.optim.Optimizer = None,
            last_iter: 'CountVar' = None,  # noqa
            last_epoch: 'CountVar' = None,  # noqa
            last_frame: 'CountVar' = None,  # noqa
            dataset: torch.utils.data.Dataset = None,
            collector_info: torch.nn.Module = None,
            prefix_op: str = None,
            prefix: str = None,
    ) -> None:
        """
        Overview:
            Save checkpoint by given args
        Arguments:
            - path (:obj:`str`): the path of saving checkpoint
            - model (:obj:`torch.nn.Module`): model to be saved
            - optimizer (:obj:`torch.optim.Optimizer`): optimizer obj
            - last_iter (:obj:`CountVar`): iter num, default None
            - last_epoch (:obj:`CountVar`): epoch num, default None
            - last_frame (:obj:`CountVar`): frame num, default None
            - dataset (:obj:`torch.utils.data.Dataset`): dataset, should be replaydataset
            - collector_info (:obj:`torch.nn.Module`): attr of checkpoint, save collector info
            - prefix_op (:obj:`str`): should be ['remove', 'add'], process on state_dict
            - prefix (:obj:`str`): prefix to be processed on state_dict
        """
        checkpoint = {}
        state_dict = model.state_dict()
        if prefix_op is not None:  # remove or add prefix to state_dict keys
            state_dict = self._apply_prefix_op(state_dict, prefix_op, prefix)
        checkpoint['model'] = state_dict

        if optimizer is not None:  # save optimizer and progress counters
            assert (last_iter is not None or last_epoch is not None)
            # Bug fix: the old code read ``last_iter.val`` unconditionally, which
            # raised AttributeError when only ``last_epoch`` was provided.
            if last_iter is not None:
                checkpoint['last_iter'] = last_iter.val
            if last_epoch is not None:
                checkpoint['last_epoch'] = last_epoch.val
            if last_frame is not None:
                checkpoint['last_frame'] = last_frame.val
            checkpoint['optimizer'] = optimizer.state_dict()

        if dataset is not None:
            checkpoint['dataset'] = dataset.state_dict()
        if collector_info is not None:
            checkpoint['collector_info'] = collector_info.state_dict()
        save_file(path, checkpoint)
        logger.info('save checkpoint in {}'.format(path))

    def _load_matched_model_state_dict(self, model: torch.nn.Module, ckpt_state_dict: dict) -> None:
        """
        Overview:
            Load matched model state_dict, and show mismatch keys between
            model's state_dict and checkpoint's state_dict
        Arguments:
            - model (:obj:`torch.nn.Module`): model
            - ckpt_state_dict (:obj:`dict`): checkpoint's state_dict
        """
        assert isinstance(model, torch.nn.Module)
        diff = {'miss_keys': [], 'redundant_keys': [], 'mismatch_shape_keys': []}
        model_state_dict = model.state_dict()
        model_keys = set(model_state_dict.keys())
        ckpt_keys = set(ckpt_state_dict.keys())
        diff['miss_keys'] = model_keys - ckpt_keys
        diff['redundant_keys'] = ckpt_keys - model_keys

        intersection_keys = model_keys.intersection(ckpt_keys)
        valid_keys = []
        for k in intersection_keys:
            if model_state_dict[k].shape == ckpt_state_dict[k].shape:
                valid_keys.append(k)
            else:
                diff['mismatch_shape_keys'].append(
                    '{}\tmodel_shape: {}\tckpt_shape: {}'.format(
                        k, model_state_dict[k].shape, ckpt_state_dict[k].shape
                    )
                )
        valid_ckpt_state_dict = {k: v for k, v in ckpt_state_dict.items() if k in valid_keys}
        # strict=False: keys missing from the checkpoint keep the model's init values
        model.load_state_dict(valid_ckpt_state_dict, strict=False)

        for n, keys in diff.items():
            for k in keys:
                logger.info('{}: {}'.format(n, k))

    def load(
            self,
            load_path: str,
            model: torch.nn.Module,
            optimizer: torch.optim.Optimizer = None,
            last_iter: 'CountVar' = None,  # noqa
            last_epoch: 'CountVar' = None,  # noqa
            last_frame: 'CountVar' = None,  # noqa
            lr_schduler: 'Scheduler' = None,  # noqa (typo kept for interface compatibility)
            dataset: torch.utils.data.Dataset = None,
            collector_info: torch.nn.Module = None,
            prefix_op: str = None,
            prefix: str = None,
            strict: bool = True,
            logger_prefix: str = '',
            state_dict_mask: list = None,
    ):
        """
        Overview:
            Load checkpoint by given path
        Arguments:
            - load_path (:obj:`str`): checkpoint's path
            - model (:obj:`torch.nn.Module`): model definition
            - optimizer (:obj:`torch.optim.Optimizer`): optimizer obj
            - last_iter (:obj:`CountVar`): iter num, default None
            - last_epoch (:obj:`CountVar`): epoch num, default None
            - last_frame (:obj:`CountVar`): frame num, default None
            - lr_schduler (:obj:`Schduler`): lr_schduler obj
            - dataset (:obj:`torch.utils.data.Dataset`): dataset, should be replaydataset
            - collector_info (:obj:`torch.nn.Module`): attr of checkpoint, save collector info
            - prefix_op (:obj:`str`): should be ['remove', 'add'], process on state_dict
            - prefix (:obj:`str`): prefix to be processed on state_dict
            - strict (:obj:`bool`): args of model.load_state_dict
            - logger_prefix (:obj:`str`): prefix of logger
            - state_dict_mask (:obj:`list`): A list containing state_dict keys, \
                which shouldn't be loaded into model(after prefix op)

        .. note::

            The checkpoint loaded from load_path is a dict, whose format is like '{'state_dict': OrderedDict(), ...}'
        """
        # TODO save config
        # Note: for reduce first GPU memory cost and compatible for cpu env
        if state_dict_mask is None:  # avoid mutable default argument
            state_dict_mask = []
        checkpoint = read_file(load_path)
        state_dict = checkpoint['model']
        if prefix_op is not None:
            state_dict = self._apply_prefix_op(state_dict, prefix_op, prefix)
        if len(state_dict_mask) > 0:
            if strict:
                logger.info(
                    logger_prefix +
                    '[Warning] non-empty state_dict_mask expects strict=False, but finds strict=True in input argument'
                )
                strict = False
            for m in state_dict_mask:
                state_dict_keys = list(state_dict.keys())
                for k in state_dict_keys:
                    if k.startswith(m):
                        state_dict.pop(k)  # ignore return value
        if strict:
            model.load_state_dict(state_dict, strict=True)
        else:
            self._load_matched_model_state_dict(model, state_dict)
        logger.info(logger_prefix + 'load model state_dict in {}'.format(load_path))

        if dataset is not None:
            if 'dataset' in checkpoint.keys():
                dataset.load_state_dict(checkpoint['dataset'])
                logger.info(logger_prefix + 'load online data in {}'.format(load_path))
            else:
                logger.info(logger_prefix + "dataset not in checkpoint, ignore load procedure")

        if optimizer is not None:
            if 'optimizer' in checkpoint.keys():
                optimizer.load_state_dict(checkpoint['optimizer'])
                logger.info(logger_prefix + 'load optimizer in {}'.format(load_path))
            else:
                logger.info(logger_prefix + "optimizer not in checkpoint, ignore load procedure")

        if last_iter is not None:
            if 'last_iter' in checkpoint.keys():
                last_iter.update(checkpoint['last_iter'])
                logger.info(
                    logger_prefix + 'load last_iter in {}, current last_iter is {}'.format(load_path, last_iter.val)
                )
            else:
                logger.info(logger_prefix + "last_iter not in checkpoint, ignore load procedure")

        # Bug fix: last_epoch/last_frame were accepted but never restored, while
        # ``save`` writes them; restore them symmetrically to last_iter.
        if last_epoch is not None:
            if 'last_epoch' in checkpoint.keys():
                last_epoch.update(checkpoint['last_epoch'])
                logger.info(
                    logger_prefix + 'load last_epoch in {}, current last_epoch is {}'.format(load_path, last_epoch.val)
                )
            else:
                logger.info(logger_prefix + "last_epoch not in checkpoint, ignore load procedure")

        if last_frame is not None:
            if 'last_frame' in checkpoint.keys():
                last_frame.update(checkpoint['last_frame'])
                logger.info(
                    logger_prefix + 'load last_frame in {}, current last_frame is {}'.format(load_path, last_frame.val)
                )
            else:
                logger.info(logger_prefix + "last_frame not in checkpoint, ignore load procedure")

        if collector_info is not None:
            # Guard added for consistency with the other optional fields; the old
            # code raised KeyError on checkpoints without collector_info.
            if 'collector_info' in checkpoint.keys():
                collector_info.load_state_dict(checkpoint['collector_info'])
                logger.info(logger_prefix + 'load collector info in {}'.format(load_path))
            else:
                logger.info(logger_prefix + "collector_info not in checkpoint, ignore load procedure")

        if lr_schduler is not None:
            assert (last_iter is not None)
            raise NotImplementedError


class CountVar(object):
    """
    Overview:
        Number counter
    Interfaces:
        ``__init__``, ``update``, ``add``
    Properties:
        - val (:obj:`int`): the value of the counter
    """

    def __init__(self, init_val: int) -> None:
        """
        Overview:
            Init the var counter
        Arguments:
            - init_val (:obj:`int`): the init value of the counter
        """
        self._val = init_val

    @property
    def val(self) -> int:
        """
        Overview:
            Get the var counter
        """
        return self._val

    def update(self, val: int) -> None:
        """
        Overview:
            Update the var counter
        Arguments:
            - val (:obj:`int`): the update value of the counter
        """
        self._val = val

    def add(self, add_num: int):
        """
        Overview:
            Add the number to the counter
        Arguments:
            - add_num (:obj:`int`): the number added to the counter
        """
        self._val += add_num


def auto_checkpoint(func: Callable) -> Callable:
    """
    Overview:
        Create a wrapper to wrap function, and the wrapper will call the save_checkpoint method
        whenever an exception or a catchable termination signal happens.
    Arguments:
        - func (:obj:`Callable`): the function to be wrapped; its first positional argument \
            must expose a ``save_checkpoint`` method
    Returns:
        - wrapper (:obj:`Callable`): the wrapped function
    """
    dead_signals = ['SIGILL', 'SIGINT', 'SIGKILL', 'SIGQUIT', 'SIGSEGV', 'SIGSTOP', 'SIGTERM', 'SIGBUS']
    all_signals = dead_signals + ['SIGUSR1']

    def register_signal_handler(handler):
        valid_sig = []
        invalid_sig = []
        for name in all_signals:
            # SIGKILL/SIGSTOP can never be caught, and some names don't exist on
            # every platform (e.g. Windows); both cases land in invalid_sig.
            try:
                sig = getattr(signal, name)
                signal.signal(sig, handler)
                valid_sig.append(sig)
            except Exception:
                # Record the name string consistently; the old code rebound ``sig``
                # mid-loop and logged a mix of strings and Signals objects.
                invalid_sig.append(name)
        logger.info('valid sig: ({})\ninvalid sig: ({})'.format(valid_sig, invalid_sig))

    @functools.wraps(func)  # preserve func's __name__/__doc__ on the wrapper
    def wrapper(*args, **kwargs):
        handle = args[0]
        assert (hasattr(handle, 'save_checkpoint'))

        def signal_handler(signal_num, frame):
            sig = signal.Signals(signal_num)
            logger.info("SIGNAL: {}({})".format(sig.name, sig.value))
            handle.save_checkpoint('ckpt_interrupt.pth.tar')
            sys.exit(1)

        register_signal_handler(signal_handler)
        try:
            return func(*args, **kwargs)
        except Exception:
            # Best-effort: persist a checkpoint, log the traceback, and swallow the
            # exception (original behavior); the wrapper then returns None.
            handle.save_checkpoint('ckpt_exception.pth.tar')
            traceback.print_exc()

    return wrapper