
ding.utils.time_helper_cuda


get_cuda_time_wrapper()

Overview

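Return the TimeWrapperCuda class. The class is defined inside this function, so its CUDA events (class attributes) are only created when the function is called. This keeps the module importable on machines without a CUDA device.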
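A minimal sketch of the same factory pattern, assuming only the Python standard library. The hypothetical `get_cpu_time_wrapper()` defines its timer class inside the function body, so the class body (and any class attributes) is evaluated only when the factory is called, never at import time. It swaps the CUDA events for `time.perf_counter()`, and its `wrapper` classmethod, which returns `(result, seconds)`, is an assumption about how a `TimeWrapper` is typically used rather than code from this module.

```python
import time
from typing import Any, Callable, Tuple


def get_cpu_time_wrapper() -> type:
    """Hypothetical CPU-only analogue of get_cuda_time_wrapper()."""

    class TimeWrapperCpu:
        # Evaluated when this class statement runs, i.e. each time the
        # factory is called, not when the enclosing module is imported.
        created_at = time.perf_counter()
        start = 0.0

        @classmethod
        def start_time(cls) -> None:
            cls.start = time.perf_counter()

        @classmethod
        def end_time(cls) -> float:
            # Elapsed seconds since the last start_time() call.
            return time.perf_counter() - cls.start

        @classmethod
        def wrapper(cls, fn: Callable[..., Any]) -> Callable[..., Tuple[Any, float]]:
            # Assumed interface: return (result, elapsed_seconds).
            def timed(*args: Any, **kwargs: Any) -> Tuple[Any, float]:
                cls.start_time()
                ret = fn(*args, **kwargs)
                return ret, cls.end_time()
            return timed

    return TimeWrapperCpu


Timer = get_cpu_time_wrapper()
result, seconds = Timer.wrapper(sorted)([3, 1, 2])
print(result)  # [1, 2, 3]
```

Each call to the factory builds a fresh class, which is the price of deferring the class-attribute initialization.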

Returns:

Type                         Description
Callable[[], TimeWrapper]    TimeWrapperCuda (class): see the TimeWrapperCuda class.

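Note: timing must be bracketed by torch.cuda.synchronize(). CUDA operations run asynchronously with respect to the host, so the queued GPU work has to finish before the elapsed time between the two events can be read. Reference: https://blog.csdn.net/u013548568/article/details/81368019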
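The synchronization pattern can be seen in a standalone sketch. The hypothetical `time_call()` helper below brackets the timed call with `torch.cuda.synchronize()` and two CUDA events in the same order `TimeWrapperCuda` does, converting `elapsed_time()` (milliseconds) to seconds. When PyTorch is not installed or no CUDA device is available it falls back to `time.perf_counter()`; that fallback is an addition for illustration and is not something this module provides.

```python
import time
from typing import Any, Callable, Tuple


def time_call(fn: Callable[[], Any]) -> Tuple[Any, float]:
    """Run fn() and return (result, elapsed_seconds); hypothetical helper."""
    try:
        import torch
        use_cuda = torch.cuda.is_available()
    except ImportError:  # PyTorch not installed: use wall-clock timing
        use_cuda = False

    if not use_cuda:
        start = time.perf_counter()
        ret = fn()
        return ret, time.perf_counter() - start

    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()  # finish work queued before the timed region
    start_evt.record()
    ret = fn()
    end_evt.record()
    torch.cuda.synchronize()  # make sure end_evt has completed on the device
    return ret, start_evt.elapsed_time(end_evt) / 1000  # ms -> s


value, elapsed = time_call(lambda: sum(range(1000)))
print(value)  # 499500
```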

Full Source Code

../ding/utils/time_helper_cuda.py

```python
from typing import Callable
import torch
from .time_helper_base import TimeWrapper


def get_cuda_time_wrapper() -> Callable[[], 'TimeWrapper']:
    """
    Overview:
        Return the ``TimeWrapperCuda`` class. This wrapper aims to ensure compatibility on machines without a CUDA device.

    Returns:
        - TimeWrapperCuda(:obj:`class`): See ``TimeWrapperCuda`` class

    .. note::
        Must use ``torch.cuda.synchronize()``, reference: <https://blog.csdn.net/u013548568/article/details/81368019>

    """

    # TODO find a way to autodoc the class within method
    class TimeWrapperCuda(TimeWrapper):
        """
        Overview:
            A class that inherits from the ``TimeWrapper`` class

        Notes:
            Must use torch.cuda.synchronize(), reference: \
            <https://blog.csdn.net/u013548568/article/details/81368019>

        Interfaces:
            ``start_time``, ``end_time``
        """
        # cls variable is initialized on loading this class
        start_record = torch.cuda.Event(enable_timing=True)
        end_record = torch.cuda.Event(enable_timing=True)

        # overwrite
        @classmethod
        def start_time(cls):
            """
            Overview:
                Implement and override the ``start_time`` method in ``TimeWrapper`` class
            """
            torch.cuda.synchronize()
            cls.start = cls.start_record.record()

        # overwrite
        @classmethod
        def end_time(cls):
            """
            Overview:
                Implement and override the ``end_time`` method in ``TimeWrapper`` class
            Returns:
                - time(:obj:`float`): The time in seconds between ``start_time`` and ``end_time``
            """
            cls.end = cls.end_record.record()
            torch.cuda.synchronize()
            return cls.start_record.elapsed_time(cls.end_record) / 1000

    return TimeWrapperCuda
```