import functools
from typing import TYPE_CHECKING, Optional

from tensorboardX import SummaryWriter

if TYPE_CHECKING:
    # TYPE_CHECKING is always False at runtime, but mypy will evaluate the contents of this block.
    # So if you import this module within TYPE_CHECKING, you will get code hints and other benefits.
    # Here is a good answer on stackoverflow:
    # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
    from ding.framework import Parallel


class DistributedWriter(SummaryWriter):
    """
    Overview:
        A simple subclass of SummaryWriter that supports writing to one process in multi-process mode.
        The best way is to use it in conjunction with the ``router`` to take advantage of the message \
            and event components of the router (see ``writer.plugin``).
    Interfaces:
        ``get_instance``, ``plugin``, ``initialize``, ``__del__``
    """
    # Instance returned by ``get_instance()`` when it is called without arguments.
    # Only set by ``get_instance``; plain construction does not touch it.
    root = None

    def __init__(self, *args, **kwargs):
        """
        Overview:
            Initialize the DistributedWriter object.
        Arguments:
            - args (:obj:`Tuple`): The arguments passed to the ``__init__`` function of the parent class, \
                SummaryWriter.
            - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__init__`` function of the parent class, \
                SummaryWriter.
        """
        # Remember the caller's ``write_to_disk`` choice (True when not given as a keyword) so that
        # ``initialize`` can restore it later.
        self._default_writer_to_disk = kwargs.get("write_to_disk", True)
        # We need to write data to files lazily, so we should not create the file writer in __init__.
        # Instead, the file writer is created when the user calls an add_* function for the first time
        # (or eagerly in ``plugin`` for the writer process).
        kwargs["write_to_disk"] = False
        super().__init__(*args, **kwargs)
        self._in_parallel = False
        self._router = None
        self._is_writer = False
        self._lazy_initialized = False

    @classmethod
    def get_instance(cls, *args, **kwargs) -> Optional["DistributedWriter"]:
        """
        Overview:
            Get an instance, and set the root-level instance on the first creation. \
                If neither args nor kwargs are given, this method returns the root instance, \
                which is None if no instance has been created through this method yet.
        Arguments:
            - args (:obj:`Tuple`): The arguments passed to the ``__init__`` function of the parent class, \
                SummaryWriter.
            - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__init__`` function of the parent class, \
                SummaryWriter.
        Returns:
            - instance (:obj:`Optional[DistributedWriter]`): A new instance when arguments are given, \
                otherwise ``cls.root``.
        """
        if not (args or kwargs):
            return cls.root
        ins = cls(*args, **kwargs)
        if cls.root is None:
            cls.root = ins
        return ins

    def plugin(self, router: "Parallel", is_writer: bool = False) -> "DistributedWriter":
        """
        Overview:
            Plugin ``router``, so when using this writer with active router, it will automatically send requests\
                to the main writer instead of writing it to the disk. So we can collect data from multiple processes\
                and write them into one file.
        Arguments:
            - router (:obj:`Parallel`): The router to be plugged in.
            - is_writer (:obj:`bool`): Whether this writer is the main writer.
        Returns:
            - writer (:obj:`DistributedWriter`): ``self``, to allow chaining.
        Examples:
            >>> DistributedWriter().plugin(router, is_writer=True)
        """
        if router.is_active:
            self._in_parallel = True
            self._router = router
            self._is_writer = is_writer
            if is_writer:
                # The writer process opens its file writer right away.
                self.initialize()
            # Non-writer processes only forward calls through the router, so they must never
            # lazily create a local file writer on their first add_* call.
            self._lazy_initialized = True
            router.on("distributed_writer", self._on_distributed_writer)
        return self

    def _on_distributed_writer(self, fn_name: str, *args, **kwargs):
        """
        Overview:
            This method is called when the router receives a request to write data. \
                Only the writer process executes the request; other processes ignore it.
        Arguments:
            - fn_name (:obj:`str`): The name of the function to be called.
            - args (:obj:`Tuple`): The arguments passed to the function to be called.
            - kwargs (:obj:`Dict`): The keyword arguments passed to the function to be called.
        """
        if self._is_writer:
            getattr(self, fn_name)(*args, **kwargs)

    def initialize(self):
        """
        Overview:
            (Re)create the file writer: close any existing writers, restore the ``write_to_disk`` \
                setting given at construction time, and ask SummaryWriter for a fresh file writer.
        """
        self.close()
        self._write_to_disk = self._default_writer_to_disk
        self._get_file_writer()
        self._lazy_initialized = True

    def __del__(self):
        """
        Overview:
            Close the file writer.
        """
        self.close()


def enable_parallel(fn_name, fn):
    """
    Overview:
        Wrap a SummaryWriter method so that, in parallel mode, calls made by non-writer processes are \
            forwarded to the writer process through the router instead of being executed locally.
    Arguments:
        - fn_name (:obj:`str`): The name of the wrapped method. It is sent with the router event so the \
            writer process can look up the same method.
        - fn (:obj:`Callable`): The original (unwrapped) method.
    Returns:
        - parallel_fn (:obj:`Callable`): The wrapper. It keeps ``fn``'s name, docstring and signature metadata.
    """

    # functools.wraps keeps __name__/__doc__/__wrapped__ of the tensorboardX method, so help() and
    # inspect.signature() still describe e.g. ``add_scalar`` rather than ``_parallel_fn``.
    @functools.wraps(fn)
    def _parallel_fn(self: DistributedWriter, *args, **kwargs):
        # Create the file writer on first use unless ``plugin`` already marked this writer initialized.
        if not self._lazy_initialized:
            self.initialize()
        if self._in_parallel and not self._is_writer:
            self._router.emit("distributed_writer", fn_name, *args, **kwargs)
        else:
            fn(self, *args, **kwargs)

    return _parallel_fn


ready_to_parallel_fns = [
    'add_audio',
    'add_custom_scalars',
    'add_custom_scalars_marginchart',
    'add_custom_scalars_multilinechart',
    'add_embedding',
    'add_figure',
    'add_graph',
    'add_graph_deprecated',
    'add_histogram',
    'add_histogram_raw',
    'add_hparams',
    'add_image',
    'add_image_with_boxes',
    'add_images',
    'add_mesh',
    'add_onnx_graph',
    'add_openvino_graph',
    'add_pr_curve',
    'add_pr_curve_raw',
    'add_scalar',
    'add_scalars',
    'add_text',
    'add_video',
]
# Patch only the methods that exist in the installed tensorboardX version.
for fn_name in ready_to_parallel_fns:
    if hasattr(DistributedWriter, fn_name):
        setattr(DistributedWriter, fn_name, enable_parallel(fn_name, getattr(DistributedWriter, fn_name)))

# Examples:
# In main, `distributed_writer.plugin(task.router, is_writer=True)`,
# In middleware, `distributed_writer.record()`
distributed_writer = DistributedWriter()