ding.hpc_rl.wrapper

Full source code: ../ding/hpc_rl/wrapper.py

import importlib
from collections import OrderedDict
from functools import wraps

from ditk import logging

import ding

'''
Overview:
    `hpc_wrapper` is the wrapper for functions which are supported by hpc. If a function is wrapped by it, we will
    search for its hpc type and return the function implemented by hpc.
    We will use the following code as a sample to introduce `hpc_wrapper`:
    ```
    @hpc_wrapper(shape_fn=shape_fn_dntd, namedtuple_data=True, include_args=[0,1,2,3],
                 include_kwargs=['data', 'gamma', 'v_min', 'v_max'], is_cls_method=False)
    def dist_nstep_td_error(
            data: namedtuple,
            gamma: float,
            v_min: float,
            v_max: float,
            n_atom: int,
            nstep: int = 1,
    ) -> torch.Tensor:
        ...
    ```
Parameters:
    - shape_fn (:obj:`function`): a function which returns the shape needed by the hpc function. In fact, it returns
        all args that the hpc function needs.
    - namedtuple_data (:obj:`bool`): If True, when the hpc function is called, it will be called as
        hpc_function(*namedtuple). If False, namedtuple data will remain its `namedtuple` type.
    - include_args (:obj:`list`): a list of indexes of the args that need to be passed to the hpc function. As shown
        in the sample, include_args=[0,1,2,3] means `data`, `gamma`, `v_min` and `v_max` will be passed to the hpc
        function.
    - include_kwargs (:obj:`list`): a list of keys of the kwargs that need to be passed to the hpc function. As shown
        in the sample, include_kwargs=['data', 'gamma', 'v_min', 'v_max'] means `data`, `gamma`, `v_min` and `v_max`
        will be passed to the hpc function.
    - is_cls_method (:obj:`bool`): If True, it means the function we wrap is a method of a class. `self` will be put
        into args. We will get rid of `self` in args. Besides, we will use its class name as its fn_name.
        If False, it means the function is a simple function.
Q&A:
    - Q: Do `include_args` and `include_kwargs` need to be set at the same time?
    - A: Yes. `include_args` and `include_kwargs` can deal with all kinds of input, such as (data, gamma,
        v_min=v_min, v_max=v_max) and (data, gamma, v_min, v_max).
    - Q: What is `hpc_fns`?
    - A: Here we show a normal `hpc_fns`:
        ```
        hpc_fns = {
            'fn_name1': {
                'runtime_name1': hpc_fn1,
                'runtime_name2': hpc_fn2,
                ...
            },
            ...
        }
        ```
        Besides, `per_fn_limit` means the max length of `hpc_fns[fn_name]`. When a new function comes, the oldest
        function will be popped from `hpc_fns[fn_name]`.
'''

# Per-function LRU cache of instantiated hpc modules: {fn_name: OrderedDict{runtime_name: hpc_fn}}.
hpc_fns = {}
# Max number of cached runtime variants kept per fn_name (oldest is evicted first).
per_fn_limit = 3


def register_runtime_fn(fn_name, runtime_name, shape):
    """
    Overview:
        Instantiate the hpc implementation of ``fn_name`` for the given ``shape``, move it to CUDA,
        and cache it in ``hpc_fns[fn_name][runtime_name]`` with LRU-style eviction
        (at most ``per_fn_limit`` entries per ``fn_name``).
    Arguments:
        - fn_name (:obj:`str`): name of the wrapped function/class; must be a key of ``fn_name_mapping``.
        - runtime_name (:obj:`str`): cache key encoding ``fn_name`` and the concrete shape.
        - shape (:obj:`tuple`): positional constructor arguments for the hpc class.
    Returns:
        - hpc_fn (:obj:`object`): the instantiated, CUDA-resident hpc callable.
    Raises:
        - KeyError: if ``fn_name`` has no hpc counterpart registered in ``fn_name_mapping``.
    """
    # Maps a wrapped function name to [module path, class name] of its hpc implementation.
    fn_name_mapping = {
        'gae': ['hpc_rll.rl_utils.gae', 'GAE'],
        'dist_nstep_td_error': ['hpc_rll.rl_utils.td', 'DistNStepTD'],
        'LSTM': ['hpc_rll.torch_utils.network.rnn', 'LSTM'],
        'ppo_error': ['hpc_rll.rl_utils.ppo', 'PPO'],
        'q_nstep_td_error': ['hpc_rll.rl_utils.td', 'QNStepTD'],
        'q_nstep_td_error_with_rescale': ['hpc_rll.rl_utils.td', 'QNStepTDRescale'],
        'ScatterConnection': ['hpc_rll.torch_utils.network.scatter_connection', 'ScatterConnection'],
        'td_lambda_error': ['hpc_rll.rl_utils.td', 'TDLambda'],
        'upgo_loss': ['hpc_rll.rl_utils.upgo', 'UPGO'],
        'vtrace_error_discrete_action': ['hpc_rll.rl_utils.vtrace', 'VTrace'],
    }
    fn_str = fn_name_mapping[fn_name]
    cls = getattr(importlib.import_module(fn_str[0]), fn_str[1])
    hpc_fn = cls(*shape).cuda()
    if fn_name not in hpc_fns:
        hpc_fns[fn_name] = OrderedDict()
    hpc_fns[fn_name][runtime_name] = hpc_fn
    # Evict the oldest cached variant(s) so each fn_name keeps at most per_fn_limit entries.
    while len(hpc_fns[fn_name]) > per_fn_limit:
        hpc_fns[fn_name].popitem(last=False)
    return hpc_fn


def hpc_wrapper(shape_fn=None, namedtuple_data=False, include_args=(), include_kwargs=(), is_cls_method=False):
    """
    Overview:
        Decorator factory that dispatches a wrapped function to its hpc implementation when
        ``ding.enable_hpc_rl`` is True, otherwise calls the original function unchanged.
        See the module doc-string above for a detailed description of each parameter.
    Arguments:
        - shape_fn (:obj:`function`): returns the shape (constructor args) needed by the hpc function.
        - namedtuple_data (:obj:`bool`): if True, args[0] is a namedtuple that is unpacked into the hpc call.
        - include_args (:obj:`Sequence[int]`): indexes of positional args forwarded to the hpc function.
        - include_kwargs (:obj:`Sequence[str]`): keys of keyword args forwarded to the hpc function.
        - is_cls_method (:obj:`bool`): if True, ``self`` is stripped from args and the class name is used
            as fn_name.
    Returns:
        - decorate (:obj:`function`): the actual decorator.
    """
    # NOTE: defaults are immutable tuples (not mutable []) to avoid the shared-mutable-default pitfall;
    # both are only iterated/converted to set below, so sequences of any kind are accepted.

    def decorate(fn):

        @wraps(fn)
        def wrapper(*args, **kwargs):
            if ding.enable_hpc_rl:
                shape = shape_fn(args, kwargs)
                if is_cls_method:
                    fn_name = args[0].__class__.__name__
                else:
                    fn_name = fn.__name__
                # One hpc instance is cached per (fn_name, shape) combination.
                runtime_name = '_'.join([fn_name] + [str(s) for s in shape])
                if fn_name not in hpc_fns or runtime_name not in hpc_fns[fn_name]:
                    hpc_fn = register_runtime_fn(fn_name, runtime_name, shape)
                else:
                    hpc_fn = hpc_fns[fn_name][runtime_name]
                if is_cls_method:
                    args = args[1:]
                # Keep only the whitelisted positional args/kwargs; warn about anything dropped.
                clean_args = []
                for i in include_args:
                    if i < len(args):
                        clean_args.append(args[i])
                nouse_args = list(set(range(len(args))).difference(set(include_args)))
                clean_kwargs = {}
                for k, v in kwargs.items():
                    if k in include_kwargs:
                        # `lambda` is a Python keyword, so callers spell it `lambda_`; the hpc
                        # function expects the bare name.
                        if k == 'lambda_':
                            k = 'lambda'
                        clean_kwargs[k] = v
                nouse_kwargs = list(set(kwargs.keys()).difference(set(include_kwargs)))
                if len(nouse_args) > 0 or len(nouse_kwargs) > 0:
                    # logging.warn is a deprecated alias of logging.warning.
                    logging.warning(
                        'in {}, index {} of args are dropped, and keys {} of kwargs are dropped.'.format(
                            runtime_name, nouse_args, nouse_kwargs
                        )
                    )
                if namedtuple_data:
                    data = args[0]  # args[0] is a namedtuple
                    return hpc_fn(*data, *clean_args[1:], **clean_kwargs)
                else:
                    return hpc_fn(*clean_args, **clean_kwargs)
            else:
                return fn(*args, **kwargs)

        return wrapper

    return decorate