Skip to content

ding.utils.log_helper

ding.utils.log_helper

TBLoggerFactory

Bases: object

Overview

TBLoggerFactory is a factory class for SummaryWriter.

Interfaces: create_logger. Properties: - tb_loggers (:obj:Dict[str, SummaryWriter]): A dict that stores the created SummaryWriter instances, keyed by their log directory.

LoggerFactory

Bases: object

Overview

LoggerFactory is a factory class for logging.Logger.

Interfaces: create_logger, get_tabulate_vars, get_tabulate_vars_hor

create_logger(path, name='default', level=logging.INFO) classmethod

Overview

Create logger using logging

Arguments: - name (:obj:str): Logger's name - path (:obj:str): Logger's save dir - level (:obj:int or :obj:str): Used to set the level. Reference: Logger.setLevel method. Returns: - (:obj:logging.Logger): new logging logger

get_tabulate_vars(variables) staticmethod

Overview

Get the text description in tabular form of all vars

Arguments: - variables (:obj:Dict[str, Any]): Mapping from each var's name to its value. Returns: - string (:obj:str): Text description in tabular form of all vars

get_tabulate_vars_hor(variables) staticmethod

Overview

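Get the text description in horizontal tabular form of all vars: names are placed in one row and their values in the row below, with at most 5 columns per table (including the leading "Name"/"Value" header column).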

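Arguments: - variables (:obj:Dict[str, Any]): Mapping from each var's name to its value. Returns: - string (:obj:str): Text description in horizontal tabular form of all vars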

build_logger(path, name=None, need_tb=True, need_text=True, text_level=logging.INFO)

Overview

Build text logger and tensorboard logger.

Arguments: - path (:obj:str): Logger(Textlogger & SummaryWriter)'s saved dir - name (:obj:str): The logger file name, 'default' is used when it is None - need_tb (:obj:bool): Whether SummaryWriter instance would be created and returned - need_text (:obj:bool): Whether logging.Logger instance would be created and returned - text_level (:obj:int or :obj:str): Logging level of logging.Logger, default set to logging.INFO Returns: - logger (:obj:Optional[logging.Logger]): Logger that displays terminal output, None when need_text is False - tb_logger (:obj:Optional['SummaryWriter']): Saves output to tfboard, None when need_tb is False.

pretty_print(result, direct_print=True)

Overview

Print a dict result in a pretty way

Arguments: - result (:obj:dict): The result to print - direct_print (:obj:bool): Whether to print directly Returns: - string (:obj:str): The pretty-printed result in str format

Full Source Code

../ding/utils/log_helper.py

import json
from ditk import logging
import os
from typing import Optional, Tuple, Union, Dict, Any

import ditk.logging
import numpy as np
import yaml
from hbutils.system import touch
from tabulate import tabulate

from .log_writer_helper import DistributedWriter


def build_logger(
    path: str,
    name: Optional[str] = None,
    need_tb: bool = True,
    need_text: bool = True,
    text_level: Union[int, str] = logging.INFO
) -> Tuple[Optional[logging.Logger], Optional['SummaryWriter']]:  # noqa
    """
    Overview:
        Build text logger and tensorboard logger.
    Arguments:
        - path (:obj:`str`): Logger(``Textlogger`` & ``SummaryWriter``)'s saved dir
        - name (:obj:`str`): The logger file name, ``'default'`` is used when it is ``None``
        - need_tb (:obj:`bool`): Whether ``SummaryWriter`` instance would be created and returned
        - need_text (:obj:`bool`): Whether ``logging.Logger`` instance would be created and returned
        - text_level (:obj:`int` or :obj:`str`): Logging level of ``logging.Logger``, default set to \
            ``logging.INFO``
    Returns:
        - logger (:obj:`Optional[logging.Logger]`): Logger that displays terminal output, ``None`` when \
            ``need_text`` is False
        - tb_logger (:obj:`Optional['SummaryWriter']`): Saves output to tfboard (created through \
            ``TBLoggerFactory``), ``None`` when ``need_tb`` is False
    """
    if name is None:
        name = 'default'
    logger = LoggerFactory.create_logger(path, name=name, level=text_level) if need_text else None
    # The tensorboard writer gets its own sub-directory ``<path>/<name>_tb_logger``.
    tb_name = name + '_tb_logger'
    tb_logger = TBLoggerFactory.create_logger(os.path.join(path, tb_name)) if need_tb else None
    return logger, tb_logger


class TBLoggerFactory(object):
    """
    Overview:
        TBLoggerFactory is a factory class for ``SummaryWriter``.
    Interfaces:
        ``create_logger``
    Properties:
        - ``tb_loggers`` (:obj:`Dict[str, SummaryWriter]`): A dict that stores ``SummaryWriter`` instances, \
            keyed by their log directory.
    """

    # Class-level cache shared by every caller: one writer per log directory.
    tb_loggers = {}

    @classmethod
    def create_logger(cls: type, logdir: str) -> DistributedWriter:
        """
        Overview:
            Return the writer bound to ``logdir``, creating and caching it on first use.
        Arguments:
            - logdir (:obj:`str`): Directory passed to ``DistributedWriter``; also used as the cache key.
        Returns:
            - tb_logger (:obj:`DistributedWriter`): The cached or newly created writer.
        """
        if logdir in cls.tb_loggers:
            return cls.tb_loggers[logdir]
        tb_logger = DistributedWriter(logdir)
        cls.tb_loggers[logdir] = tb_logger
        return tb_logger


class LoggerFactory(object):
    """
    Overview:
        LoggerFactory is a factory class for ``logging.Logger``.
    Interfaces:
        ``create_logger``, ``get_tabulate_vars``, ``get_tabulate_vars_hor``
    """

    @classmethod
    def create_logger(cls, path: str, name: str = 'default', level: Union[int, str] = logging.INFO) -> logging.Logger:
        """
        Overview:
            Create logger using logging
        Arguments:
            - name (:obj:`str`): Logger's name
            - path (:obj:`str`): Logger's save dir
            - level (:obj:`int` or :obj:`str`): Used to set the level. Reference: ``Logger.setLevel`` method.
        Returns:
            - (:obj:`logging.Logger`): new logging logger
        """
        ditk.logging.try_init_root(level)

        logger_name = f'{name}_logger'
        logger_file_path = os.path.join(path, f'{logger_name}.txt')
        # Create the log file up front before it is handed to the logger.
        touch(logger_file_path)

        logger = ditk.logging.getLogger(logger_name, level, [logger_file_path])
        # Expose the tabulate helpers on the logger object itself for caller convenience.
        logger.get_tabulate_vars = LoggerFactory.get_tabulate_vars
        logger.get_tabulate_vars_hor = LoggerFactory.get_tabulate_vars_hor

        return logger

    @staticmethod
    def get_tabulate_vars(variables: Dict[str, Any]) -> str:
        """
        Overview:
            Get the text description in tabular form of all vars
        Arguments:
            - variables (:obj:`Dict[str, Any]`): Mapping from each var's name to its value.
        Returns:
            - string (:obj:`str`): Text description in tabular form of all vars
        """
        headers = ["Name", "Value"]
        data = []
        for k, v in variables.items():
            try:
                cell = "{:.6f}".format(v)
            except (TypeError, ValueError):
                # Values that cannot be rendered as a float (str, None, arrays, ...) are shown
                # as-is instead of aborting the whole table.
                cell = v
            data.append([k, cell])
        return "\n" + tabulate(data, headers=headers, tablefmt='grid')

    @staticmethod
    def get_tabulate_vars_hor(variables: Dict[str, Any]) -> str:
        """
        Overview:
            Get the text description in horizontal tabular form of all vars: names are placed in one row \
            and their values in the row below, split into several tables of limited width.
        Arguments:
            - variables (:obj:`Dict[str, Any]`): Mapping from each var's name to its value.
        Returns:
            - string (:obj:`str`): Text description in horizontal tabular form of all vars
        """

        column_to_divide = 5  # which includes the header "Name & Value"

        datak = []
        datav = []

        divide_count = 0
        for k, v in variables.items():
            # Every chunk of ``column_to_divide`` cells starts with a "Name"/"Value" header cell.
            if divide_count == 0 or divide_count >= (column_to_divide - 1):
                datak.append("Name")
                datav.append("Value")
                if divide_count >= (column_to_divide - 1):
                    divide_count = 0
            divide_count += 1

            datak.append(k)
            if not isinstance(v, str) and np.isscalar(v):
                datav.append("{:.6f}".format(v))
            else:
                datav.append(v)

        # Step through the cells chunk by chunk. Deriving the number of rows from the actual
        # length avoids rendering a spurious empty table when the cell count is an exact
        # multiple of ``column_to_divide`` (or when ``variables`` is empty).
        tables = []
        for item_start in range(0, len(datak), column_to_divide):
            item_end = item_start + column_to_divide
            data = [datak[item_start:item_end], datav[item_start:item_end]]
            tables.append(tabulate(data, tablefmt='grid') + "\n")

        return "\n" + "".join(tables)


def pretty_print(result: dict, direct_print: bool = True) -> str:
    """
    Overview:
        Print a dict ``result`` in a pretty way
    Arguments:
        - result (:obj:`dict`): The result to print
        - direct_print (:obj:`bool`): Whether to print directly
    Returns:
        - string (:obj:`str`): The pretty-printed result in str format
    """
    # Top-level entries whose value is None are left out of the output.
    out = {k: v for k, v in result.items() if v is not None}
    # Round-trip through JSON before dumping so that YAML only sees what json.loads produces.
    cleaned = json.dumps(out)
    string = yaml.safe_dump(json.loads(cleaned), default_flow_style=False)
    if direct_print:
        print(string)
    return string