import inspect
import os
from collections import OrderedDict
from typing import Optional, Iterable, Callable

# Whether the next failing ``Registry.build`` should print its argument hint. It is switched off after the
# first hint is printed, so the hint appears at most once per process. With nested ``build`` calls the
# innermost failure is handled first, so that is the one reported.
_innest_error = True

# Record where each module was registered (see ``Registry.query_details``). Enabled by setting the
# environment variable ``DIENGINEREGTRACE=ON`` (case-insensitive) before this module is imported.
_DI_ENGINE_REG_TRACE_IS_ON = os.environ.get('DIENGINEREGTRACE', 'OFF').upper() == 'ON'

# Sentinel distinguishing "no default given" from an explicit ``default=None`` in ``Registry.get``.
_MISSING = object()


class Registry(dict):
    """
    Overview:
        A dictionary that maps names to modules (any callable) and provides helpers to register, look up \
        and build them.
    Interfaces:
        ``__init__``, ``register``, ``get``, ``build``, ``query``, ``query_details``
    Examples (creating):
        >>> some_registry = Registry({"default": default_module})

    Examples (registering: normal way):
        >>> def foo():
        >>>     ...
        >>> some_registry.register("foo_module", foo)

    Examples (registering: decorator way):
        >>> @some_registry.register("foo_module")
        >>> @some_registry.register("foo_module_nickname")
        >>> def foo():
        >>>     ...

    Examples (accessing):
        >>> f = some_registry["foo_module"]
    """

    def __init__(self, *args, **kwargs) -> None:
        """
        Overview:
            Initialize the Registry object with the same arguments accepted by ``dict``.
        Arguments:
            - args (:obj:`Tuple`): Positional arguments forwarded to ``dict.__init__``.
            - kwargs (:obj:`Dict`): Keyword arguments forwarded to ``dict.__init__``.
        """
        super().__init__(*args, **kwargs)
        # name -> (filename, lineno) of the ``register`` call; only filled when tracing is on. Entries passed
        # here (or assigned with ``registry[name] = ...``) bypass ``register`` and have no trace record.
        self.__trace__ = dict()

    def register(
            self,
            module_name: Optional[str] = None,
            module: Optional[Callable] = None,
            force_overwrite: bool = False
    ) -> Optional[Callable]:
        """
        Overview:
            Register a module, either directly (``register(name, module)``) or as a decorator \
            (``@register(name)`` / ``@register()``).
        Arguments:
            - module_name (:obj:`Optional[str]`): The name to register under. Defaults to the module's \
                ``__name__``.
            - module (:obj:`Optional[Callable]`): The module to register. If ``None``, a decorator is returned.
            - force_overwrite (:obj:`bool`): Whether to overwrite an existing module with the same name.
        Returns:
            - decorator (:obj:`Optional[Callable]`): ``None`` when ``module`` is given; otherwise a decorator \
                that registers the function it wraps and returns that function unchanged.
        Raises:
            - AssertionError: If the name is already registered and ``force_overwrite`` is False.
        """
        filename, lineno = None, None
        if _DI_ENGINE_REG_TRACE_IS_ON:
            # Only the caller's frame is needed; ``inspect.stack()`` would build a record (and read source
            # context) for every frame on the stack.
            frame = inspect.currentframe()
            caller = frame.f_back if frame is not None else None
            try:
                if caller is not None:
                    info = inspect.getframeinfo(caller, context=0)
                    filename, lineno = info.filename, info.lineno
            finally:
                # Drop frame references promptly to avoid reference cycles.
                del frame, caller

        # used as a function call
        if module is not None:
            name = module.__name__ if module_name is None else module_name
            Registry._register_generic(self, name, module, force_overwrite)
            if _DI_ENGINE_REG_TRACE_IS_ON:
                self.__trace__[name] = (filename, lineno)
            return

        # used as a decorator
        def register_fn(fn: Callable) -> Callable:
            name = fn.__name__ if module_name is None else module_name
            Registry._register_generic(self, name, fn, force_overwrite)
            if _DI_ENGINE_REG_TRACE_IS_ON:
                self.__trace__[name] = (filename, lineno)
            return fn

        return register_fn

    @staticmethod
    def _register_generic(module_dict: dict, module_name: str, module: Callable, force_overwrite: bool = False) -> None:
        """
        Overview:
            Store ``module`` in ``module_dict`` under ``module_name``.
        Arguments:
            - module_dict (:obj:`dict`): The dict to store the module in.
            - module_name (:obj:`str`): The name of the module.
            - module (:obj:`Callable`): The module to be registered.
            - force_overwrite (:obj:`bool`): Whether to overwrite an existing module with the same name.
        Raises:
            - AssertionError: If ``module_name`` is already present and ``force_overwrite`` is False.
        """
        if not force_overwrite:
            assert module_name not in module_dict, module_name
        module_dict[module_name] = module

    def get(self, module_name: str, default: object = _MISSING) -> Callable:
        """
        Overview:
            Get the module registered under ``module_name``.
        Arguments:
            - module_name (:obj:`str`): The name of the module.
            - default (:obj:`object`): Optional value returned when the name is not registered. If omitted, \
                a missing name raises ``KeyError`` (unlike ``dict.get``).
        Returns:
            - module (:obj:`Callable`): The registered module, or ``default``.
        Raises:
            - KeyError: If ``module_name`` is not registered and no ``default`` was given.
        """
        if default is _MISSING:
            return self[module_name]
        return super().get(module_name, default)

    def build(self, obj_type: str, *obj_args, **obj_kwargs) -> object:
        """
        Overview:
            Look up the callable registered under ``obj_type`` and call it with the given arguments.
        Arguments:
            - obj_type (:obj:`str`): The registered name of the callable.
            - obj_args (:obj:`Tuple`): Positional arguments passed to the callable.
            - obj_kwargs (:obj:`Dict`): Keyword arguments passed to the callable.
        Returns:
            - obj (:obj:`object`): Whatever the registered callable returns.
        Raises:
            - KeyError: If ``obj_type`` is not registered. Exceptions raised by the callable itself \
                (including ``KeyError``) propagate unchanged; the first time that happens in this process, a \
                hint comparing the callable's signature with the given keyword argument names is printed.
        """
        global _innest_error
        try:
            build_fn = self[obj_type]
        except KeyError:
            # Only a failed lookup means "unsupported type"; a KeyError raised inside ``build_fn`` is handled
            # below and must not be disguised as a missing registration.
            raise KeyError("not support buildable-object type: {}".format(obj_type))
        try:
            return build_fn(*obj_args, **obj_kwargs)
        except Exception:
            # build_fn execution failed
            if _innest_error:
                try:
                    argspec = inspect.getfullargspec(build_fn)
                except TypeError:
                    # Some callables expose no introspectable signature; never let the hint itself replace
                    # the original error.
                    argspec = '<unavailable>'
                message = 'Hint: for {}(alias={})'.format(build_fn, obj_type)
                message += '\n\nExpected args are:\n {}\nGiven arguments keys are:\n{}\n'.format(
                    argspec, obj_kwargs.keys()
                )
                print(message)
                _innest_error = False
            raise

    def query(self) -> Iterable:
        """
        Overview:
            Return the names of all registered modules.
        Returns:
            - names (:obj:`Iterable`): A view of the registered names.
        """
        return self.keys()

    def query_details(self, aliases: Optional[Iterable] = None) -> OrderedDict:
        """
        Overview:
            Get the ``(filename, lineno)`` of the ``register`` call for each requested module. Requires \
            tracing to be enabled via ``DIENGINEREGTRACE=ON``.
        Arguments:
            - aliases (:obj:`Optional[Iterable]`): The names to report. Defaults to every registered name; in \
                that case names without a trace record (e.g. entries passed to ``__init__``) map to ``None``.
        Returns:
            - details (:obj:`OrderedDict`): Mapping from name to ``(filename, lineno)``.
        Raises:
            - AssertionError: If tracing is off.
            - KeyError: If an explicitly requested alias has no trace record.
        """
        assert _DI_ENGINE_REG_TRACE_IS_ON, "please exec 'export DIENGINEREGTRACE=ON' first"
        if aliases is None:
            return OrderedDict((alias, self.__trace__.get(alias)) for alias in self.keys())
        return OrderedDict((alias, self.__trace__[alias]) for alias in aliases)