ding.framework.middleware.functional.termination_checker
Full Source Code
../ding/framework/middleware/functional/termination_checker.py
from typing import TYPE_CHECKING, Union, Callable, Optional
from ditk import logging
import numpy as np
import torch
from ding.utils import broadcast
from ding.framework import task

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext, OfflineRLContext


def termination_checker(max_env_step: Optional[int] = None, max_train_iter: Optional[int] = None) -> Callable:
    """
    Build a middleware that flags ``task.finish`` once training has run long enough.

    Arguments:
        - max_env_step: Stop once ``ctx.env_step`` exceeds this value; ``None`` means no limit.
        - max_train_iter: Stop once ``ctx.train_iter`` exceeds this value; ``None`` means no limit.
    Returns:
        A middleware callable that inspects the context and sets ``task.finish = True`` when
        either budget is exhausted.
    """
    # Treat a missing limit as "unbounded" so the comparisons below never fire for it.
    env_step_limit = np.inf if max_env_step is None else max_env_step
    train_iter_limit = np.inf if max_train_iter is None else max_train_iter

    def _check(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
        # ">" (not ">=") is deliberate: it lets the final step/iteration still be logged
        # before the run is declared finished.
        assert hasattr(ctx, "env_step") or hasattr(ctx, "train_iter"), "Context must have env_step or train_iter"
        env_step_exceeded = hasattr(ctx, "env_step") and ctx.env_step > env_step_limit
        train_iter_exceeded = hasattr(ctx, "train_iter") and ctx.train_iter > train_iter_limit
        if env_step_exceeded:
            task.finish = True
            logging.info('Exceeded maximum number of env_step({}), program is terminated'.format(ctx.env_step))
        elif train_iter_exceeded:
            task.finish = True
            logging.info('Exceeded maximum number of train_iter({}), program is terminated'.format(ctx.train_iter))

    return _check


def ddp_termination_checker(max_env_step=None, max_train_iter=None, rank=0):
    """
    Build a termination-checker middleware for DDP training.

    Only rank 0 evaluates the step/iteration budgets; the verdict is then broadcast
    to every other worker so all processes agree on ``task.finish``.

    Arguments:
        - max_env_step: Stop once ``ctx.env_step`` exceeds this value; ``None`` means no limit.
        - max_train_iter: Stop once ``ctx.train_iter`` exceeds this value; ``None`` means no limit.
        - rank: DDP rank of the current process; the decision is made on rank 0 only.
    Returns:
        A middleware callable that synchronizes the finish flag across DDP workers.

    NOTE(review): tensors are placed on CUDA unconditionally, so this assumes a
    GPU backend (e.g. NCCL) — confirm before using on CPU-only workers.
    """
    if rank == 0:
        # Limits are only consulted on rank 0, so defaults are resolved there alone.
        if max_env_step is None:
            max_env_step = np.inf
        if max_train_iter is None:
            max_train_iter = np.inf

    def _check(ctx):
        if rank != 0:
            # Non-leader ranks contribute a placeholder; it is overwritten by the broadcast.
            finish = torch.zeros(1).long().cuda()
        elif ctx.env_step > max_env_step:
            finish = torch.ones(1).long().cuda()
            logging.info('Exceeded maximum number of env_step({}), program is terminated'.format(ctx.env_step))
        elif ctx.train_iter > max_train_iter:
            finish = torch.ones(1).long().cuda()
            logging.info('Exceeded maximum number of train_iter({}), program is terminated'.format(ctx.train_iter))
        else:
            # Propagate whatever finish state rank 0 already holds.
            finish = torch.LongTensor([task.finish]).cuda()
        # broadcast finish result to other DDP workers
        broadcast(finish, 0)
        task.finish = finish.cpu().bool().item()

    return _check