ding.framework.middleware.ckpt_handler

CkptSaver

Overview

The class used to save checkpoint data.

__init__(policy, save_dir, train_freq=None, save_finish=True)

Overview

Initialize the CkptSaver.

Arguments:

- policy (:obj:Policy): Policy whose state is saved to the checkpoint.
- save_dir (:obj:str): The directory path in which to save the ckpt.
- train_freq (:obj:int): Number of training iterations between checkpoint saves.
- save_finish (:obj:bool): Whether to save a final ckpt when task.finish is True.
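As the full source below shows, the constructor also normalizes save_dir so that checkpoints always land in a "ckpt" subdirectory. A stdlib-only sketch of that path handling (the helper name ckpt_prefix is hypothetical, introduced here for illustration):

```python
import os

def ckpt_prefix(save_dir: str) -> str:
    # Mirror of CkptSaver.__init__ path handling: if save_dir does not
    # already end in a "ckpt" directory, append one; otherwise keep it.
    if os.path.basename(os.path.normpath(save_dir)) != "ckpt":
        return '{}/ckpt'.format(os.path.normpath(save_dir))
    return '{}/'.format(os.path.normpath(save_dir))

print(ckpt_prefix("exp/run1"))  # exp/run1/ckpt
print(ckpt_prefix("exp/ckpt"))  # exp/ckpt/
```

So passing either "exp/run1" or "exp/run1/ckpt" as save_dir yields checkpoints under the same ckpt directory.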

__call__(ctx)

Overview

The method used to save checkpoint data. The checkpoint data will be saved to a file in the following 3 cases:

- When a multiple of self.train_freq iterations has elapsed since the beginning of training;
- When the evaluation episode return is the best so far;
- When task.finish is True.

Input of ctx:

- train_iter (:obj:int): Number of training iterations, i.e. the number of updates to the policy's networks.
- eval_value (:obj:float): The episode return of the current iteration.
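The three save conditions can be sketched as a pure function over the ctx fields and the saver's internal state. This is a stdlib-only illustration, not the actual implementation; the function name ckpt_actions and its return convention (a list of filenames that would be written) are hypothetical:

```python
def ckpt_actions(train_iter, eval_value, last_save_iter, max_eval_value,
                 train_freq=None, finished=False, save_finish=True):
    # Return the checkpoint filenames that __call__ would write,
    # mirroring its three conditions in order.
    actions = []
    # 1) enough training iterations elapsed (or the very first iteration)
    if train_freq and (train_iter == 0 or train_iter - last_save_iter >= train_freq):
        actions.append('iteration_{}.pth.tar'.format(train_iter))
    # 2) best evaluation episode return so far
    if eval_value is not None and eval_value > max_eval_value:
        actions.append('eval.pth.tar')
    # 3) task finished and final save enabled
    if finished and save_finish:
        actions.append('final.pth.tar')
    return actions
```

For example, at iteration 100 with train_freq=100 and a new best eval_value, both iteration_100.pth.tar and eval.pth.tar would be written; a finished task additionally writes final.pth.tar.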

Full Source Code

../ding/framework/middleware/ckpt_handler.py

from typing import TYPE_CHECKING, Optional, Union
from easydict import EasyDict
import os
import numpy as np

from ding.utils import save_file
from ding.policy import Policy
from ding.framework import task

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext, OfflineRLContext


class CkptSaver:
    """
    Overview:
        The class used to save checkpoint data.
    """

    def __new__(cls, *args, **kwargs):
        if task.router.is_active and not (task.has_role(task.role.LEARNER) or task.has_role(task.role.EVALUATOR)):
            return task.void()
        return super(CkptSaver, cls).__new__(cls)

    def __init__(self, policy: Policy, save_dir: str, train_freq: Optional[int] = None, save_finish: bool = True):
        """
        Overview:
            Initialize the `CkptSaver`.
        Arguments:
            - policy (:obj:`Policy`): Policy used to save the checkpoint.
            - save_dir (:obj:`str`): The directory path to save ckpt.
            - train_freq (:obj:`int`): Number of training iterations between each saving checkpoint data.
            - save_finish (:obj:`bool`): Whether save final ckpt when ``task.finish = True``.
        """
        self.policy = policy
        self.train_freq = train_freq
        if str(os.path.basename(os.path.normpath(save_dir))) != "ckpt":
            self.prefix = '{}/ckpt'.format(os.path.normpath(save_dir))
        else:
            self.prefix = '{}/'.format(os.path.normpath(save_dir))
        if not os.path.exists(self.prefix):
            os.makedirs(self.prefix)
        self.last_save_iter = 0
        self.max_eval_value = -np.inf
        self.save_finish = save_finish

    def __call__(self, ctx: Union["OnlineRLContext", "OfflineRLContext"]) -> None:
        """
        Overview:
            The method used to save checkpoint data. \
            The checkpoint data will be saved in a file in following 3 cases: \
            - When a multiple of `self.train_freq` iterations have elapsed since the beginning of training; \
            - When the evaluation episode return is the best so far; \
            - When `task.finish` is True.
        Input of ctx:
            - train_iter (:obj:`int`): Number of training iteration, i.e. the number of updating policy related network.
            - eval_value (:obj:`float`): The episode return of current iteration.
        """
        # train enough iteration
        if self.train_freq:
            if ctx.train_iter == 0 or ctx.train_iter - self.last_save_iter >= self.train_freq:
                save_file(
                    "{}/iteration_{}.pth.tar".format(self.prefix, ctx.train_iter), self.policy.learn_mode.state_dict()
                )
                self.last_save_iter = ctx.train_iter

        # best episode return so far
        if ctx.eval_value is not None and ctx.eval_value > self.max_eval_value:
            save_file("{}/eval.pth.tar".format(self.prefix), self.policy.learn_mode.state_dict())
            self.max_eval_value = ctx.eval_value

        # finish
        if task.finish and self.save_finish:
            save_file("{}/final.pth.tar".format(self.prefix), self.policy.learn_mode.state_dict())