Skip to content

ding.framework.wrapper.step_timer

ding.framework.wrapper.step_timer

StepTimer

__init__(print_per_step=1, smooth_window=10)

Overview

Print the time cost of each step (one execution of a middleware).

Arguments:

- print_per_step (`int`): Print the timing once every N steps.
- smooth_window (`int`): The window size used to smooth the mean; the mean is taken over the last `print_per_step * smooth_window` recorded costs.

Full Source Code

../ding/framework/wrapper/step_timer.py

from collections import deque, defaultdict
from functools import wraps
from types import GeneratorType
from typing import Callable
import numpy as np
import time
from ditk import logging
from ding.framework import task


class StepTimer:

    def __init__(self, print_per_step: int = 1, smooth_window: int = 10) -> None:
        """
        Overview:
            Measure and log the time cost of each step, i.e. one execution of a wrapped middleware.
        Arguments:
            - print_per_step (:obj:`int`): Log a timing line on every step whose ``ctx.total_step`` is a \
                multiple of this value. Must be at least 1.
            - smooth_window (:obj:`int`): Number of logging periods kept for the running mean; the mean is \
                taken over the last ``print_per_step * smooth_window`` recorded costs. Must be at least 1.
        Raises:
            - ValueError: If ``print_per_step`` or ``smooth_window`` is smaller than 1.
        """
        # print_per_step is used as a modulus and both values size the history deque, so reject
        # values that would otherwise fail later (division by zero) or silently produce NaN means.
        if print_per_step < 1:
            raise ValueError("print_per_step must be a positive integer, got {}".format(print_per_step))
        if smooth_window < 1:
            raise ValueError("smooth_window must be a positive integer, got {}".format(smooth_window))

        self.print_per_step = print_per_step
        # One bounded history per step name; the oldest costs are dropped once the window is full.
        self.records = defaultdict(lambda: deque(maxlen=print_per_step * smooth_window))

    def __call__(self, fn: Callable) -> Callable:
        """
        Overview:
            Wrap a middleware so that its execution time is recorded and periodically logged.
        Arguments:
            - fn (:obj:`Callable`): The middleware to wrap; it is called as ``fn(ctx)``. If the call returns a \
                generator, the parts before and after its first ``yield`` are timed together, while the time \
                during which it is suspended is excluded.
        Returns:
            - executor (:obj:`Callable`): A generator function taking ``ctx``. When ``fn(ctx)`` returns a \
                generator, ``executor`` yields exactly once; otherwise it finishes without yielding.
        """
        # Plain functions have ``__name__``; callable objects fall back to their class name.
        step_name = getattr(fn, "__name__", type(fn).__name__)

        @wraps(fn)
        def executor(ctx):
            # perf_counter is monotonic and high-resolution, unlike time.time(), which follows the
            # adjustable system clock and has coarse resolution on some platforms.
            start_time = time.perf_counter()
            g = fn(ctx)
            if isinstance(g, GeneratorType):
                # Run the middleware up to its first ``yield`` (or to completion if it never yields).
                try:
                    next(g)
                except StopIteration:
                    pass
                time_cost = time.perf_counter() - start_time
                # Hand control back to the caller; time spent while suspended is not counted.
                yield
                start_time = time.perf_counter()
                # Resume the middleware to run the code after its ``yield``.
                try:
                    next(g)
                except StopIteration:
                    pass
                time_cost += time.perf_counter() - start_time
            else:
                time_cost = time.perf_counter() - start_time

            self.records[step_name].append(time_cost)
            if ctx.total_step % self.print_per_step == 0:
                # NOTE(review): node_id appears to be falsy (e.g. None) outside a distributed setup — confirm;
                # it is reported as 0 in that case.
                logging.info(
                    "[Step Timer][Node:{:>2}] {}: Cost: {:.2f}ms, Mean: {:.2f}ms".format(
                        task.router.node_id or 0, step_name, time_cost * 1000,
                        np.mean(self.records[step_name]) * 1000
                    )
                )

        return executor