
ding.model.wrapper.model_wrappers

IModelWrapper

Bases: ABC

Overview

The basic interface class of model wrappers. A model wrapper is a wrapper class of a torch.nn.Module model, used to add extra operations to the wrapped model, such as hidden state maintenance for RNN-based models, argmax action selection for discrete action spaces, etc.

Interfaces: __init__, __getattr__, info, reset, forward.
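The delegation pattern behind these interfaces can be sketched with a toy stand-in (simplified names, not the actual DI-engine class): `__getattr__` is only invoked when normal attribute lookup fails, so wrapper attributes take precedence and everything else falls through to the wrapped model.

```python
import torch
import torch.nn as nn

class MinimalWrapper:
    """Toy illustration of the wrapper delegation pattern."""

    def __init__(self, model: nn.Module) -> None:
        self._model = model

    def __getattr__(self, key):
        # Called only when normal lookup fails, so wrapper attributes win
        # and everything else falls through to the wrapped model.
        return getattr(self._model, key)

    def forward(self, *args, **kwargs):
        return self._model.forward(*args, **kwargs)

net = nn.Linear(4, 2)
wrapped = MinimalWrapper(net)
out = wrapped.forward(torch.randn(3, 4))
assert out.shape == (3, 2)
assert wrapped.weight is net.weight  # attribute access falls through to the model
```

Because `__getattr__` forwards unknown names, code that expects the raw model (e.g. optimizer construction over `wrapped.parameters()`) keeps working on the wrapped version.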

__init__(model)

Overview

Initialize model and other necessary member variables in the model wrapper.

__getattr__(key)

Overview

Get original attributes of the torch.nn.Module model, such as variables and methods defined in the model.

Arguments:
- key (:obj:str): The string key to query.

Returns:
- ret (:obj:Any): The queried attribute.

info(attr_name)

Overview

Get some string information of the indicated attr_name, which is used for debugging wrappers. This method will recursively search for the indicated attr_name.

Arguments:
- attr_name (:obj:str): The string key to query information.

Returns:
- info_string (:obj:str): The information string of the indicated attr_name.

reset(data_id=None, **kwargs)

Overview

Basic interface: reset some stateful variables in the model wrapper, such as the hidden state of an RNN. Here we do nothing and just implement this interface method. Other derived model wrappers can override this method to add extra operations.

Arguments:
- data_id (:obj:List[int]): The data id list to reset. If None, reset all data. In practice, model wrappers often need to maintain stateful variables for each data trajectory, so we leave this data_id argument to reset the stateful variables of the indicated data.

forward(*args, **kwargs)

Overview

Basic interface, call the wrapped model's forward method. Other derived model wrappers can override this method to add some extra operations.

BaseModelWrapper

Bases: IModelWrapper

Overview

Placeholder class for the model wrapper. This class is used to wrap the model without any extra operations, including an empty reset method and a forward method which directly calls the wrapped model's forward. To keep the consistency of the model wrapper interface, we use this class to wrap models that need no specific operations in the implementation of DI-engine's policies.

HiddenStateWrapper

Bases: IModelWrapper

Overview

Maintain the hidden state for RNN-based models. Each sample in a batch has its own state.

Interfaces: __init__, reset, forward.

__init__(model, state_num, save_prev_state=False, init_fn=lambda: None)

Overview

Maintain the hidden state for RNN-based models. Each sample in a batch has its own state. Initialize the maintained state and the state init function; then wrap the model.forward method with automatically saved ['prev_state'] input, and create the model.reset method.

Arguments:
- model (:obj:Any): Wrapped model class, should contain forward method.
- state_num (:obj:int): Number of states to process.
- save_prev_state (:obj:bool): Whether to include the prev state in the output.
- init_fn (:obj:Callable): The function used to initialize every hidden state on init and reset; by default returns None for hidden states.

.. note::
1. This helper must deal with an actual batch that contains only part of the samples, e.g. 6 samples with state_num 8.
2. This helper must deal with single-sample state resets.
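The per-sample bookkeeping described in the note can be illustrated with a minimal standalone sketch (toy values, not DI-engine code): one state slot per env/sample, updated after forward, and reset individually when an episode ends.

```python
import torch

# Toy illustration of per-sample hidden-state upkeep: one slot per sample,
# reset individually when its trajectory finishes.
state_num = 4
init_fn = lambda: None
state = {i: init_fn() for i in range(state_num)}

# After a forward pass over samples 0 and 2, store their next hidden states.
next_h = {0: torch.zeros(8), 2: torch.ones(8)}
for idx, h in next_h.items():
    state[idx] = h

# The episode in env 2 finished: reset only that slot, others are untouched.
for idx in [2]:
    state[idx] = init_fn()

assert state[2] is None          # reset slot is back to the init value
assert state[0] is not None      # untouched slot keeps its hidden state
```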

TransformerInputWrapper

Bases: IModelWrapper

__init__(model, seq_len, init_fn=lambda: None)

Overview

Given N, the length of the sequences received by a Transformer model, maintain the last N-1 input observations. In this way we can provide at each step all the observations the Transformer needs to compute its output. We need this because some methods such as 'collect' and 'evaluate' only provide the model one observation per step and keep no memory of past observations, but the Transformer needs a sequence of N observations. The wrapper method forward will save the input observation in a FIFO memory of length N, and the method reset will reset the memory. The empty memory slots will be initialized with 'init_fn' or zero by calling the method reset_input. Since different envs can terminate at different steps, the method reset_memory_entry only initializes the memory of specific environments in the batch.

Arguments:
- model (:obj:Any): Wrapped model class, should contain forward method.
- seq_len (:obj:int): Number of past observations to remember.
- init_fn (:obj:Callable): The function used to initialize every memory location on init and reset.

forward(input_obs, only_last_logit=True, data_id=None, **kwargs)

Arguments:
- input_obs (:obj:torch.Tensor): Input observation without sequence shape: (bs, *obs_shape).
- only_last_logit (:obj:bool): If True, 'logit' only contains the output corresponding to the current observation (shape: (bs, embedding_dim)); otherwise logit has shape (seq_len, bs, embedding_dim).
- data_id (:obj:List): Ids of the envs that are currently running. Memory updates and returned logits only take effect for those environments. If None, all envs are considered running.

Returns:
- Dictionary containing the input sequence 'input_seq' stored in memory and the transformer output 'logit'.
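The FIFO behavior of forward can be sketched as standalone code (toy shapes, not the wrapper itself): while a slot is still filling, write at the next free index; once it is full, roll the sequence left by one step and write the newest observation at the end.

```python
import torch

# Sketch of the FIFO observation memory, shape (seq_len, bs, obs_dim).
seq_len, bs, obs_dim = 3, 2, 5
obs_memory = torch.zeros(seq_len, bs, obs_dim)
memory_idx = [0, 0]  # next write position per batch entry

def push(b: int, obs: torch.Tensor) -> None:
    if memory_idx[b] == seq_len:
        # Memory full: roll left along the sequence dim, newest obs goes last.
        obs_memory[:, b] = torch.roll(obs_memory[:, b], -1, 0)
        obs_memory[seq_len - 1, b] = obs
    else:
        obs_memory[memory_idx[b], b] = obs
        memory_idx[b] += 1

for t in range(5):
    push(0, torch.full((obs_dim,), float(t)))

# After 5 pushes with seq_len 3, the memory holds the last 3 observations.
assert obs_memory[:, 0, 0].tolist() == [2.0, 3.0, 4.0]
```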

reset_input(input_obs)

Overview

Initialize the whole memory

reset_memory_entry(state_id=None)

Overview

Reset specific batches of the memory; the batch ids are specified in 'state_id'.

TransformerSegmentWrapper

Bases: IModelWrapper

__init__(model, seq_len)

Overview

Given T, the length of a trajectory, and N, the length of the sequences received by a Transformer model, split T into sequences of N elements and forward each sequence one by one. If T % N != 0, the last sequence will be zero-padded. Usually used during the Transformer training phase.

Arguments:
- model (:obj:Any): Wrapped model class, should contain forward method.
- seq_len (:obj:int): N, the length of a sequence.

forward(obs, **kwargs)

Arguments:
- obs (:obj:torch.Tensor): Trajectory observations of shape (T, bs, *obs_shape).

Returns:
- Dict of the model output, with the outputs of all sequences concatenated along the time dimension.
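The split-and-pad step described above can be sketched as follows (the shapes are assumptions for illustration): a trajectory of T = 7 steps with N = 3 becomes three sequences of shape (3, bs, obs_dim), the last one zero-padded by N - (T % N) = 2 steps.

```python
import torch
import torch.nn.functional as F

T, bs, obs_dim, seq_len = 7, 2, 4, 3
obs = torch.randn(T, bs, obs_dim)

# Split along the time dim, then zero-pad the last (short) sequence.
sequences = list(torch.split(obs, seq_len, dim=0))
if sequences[-1].shape[0] < seq_len:
    diff = seq_len - sequences[-1].shape[0]
    # F.pad's tuple pads trailing dims first: (dim -1, dim -2, dim 0).
    sequences[-1] = F.pad(sequences[-1], pad=(0, 0, 0, 0, 0, diff))

assert len(sequences) == 3
assert all(s.shape == (seq_len, bs, obs_dim) for s in sequences)
assert sequences[-1][1:].abs().sum() == 0  # the padded steps are zeros
```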

TransformerMemoryWrapper

Bases: IModelWrapper

__init__(model, batch_size)

Overview

Stores a copy of the Transformer memory so that it can be reused across different phases. To make it clearer, suppose the training pipeline is divided into 3 phases: evaluate, collect, learn. The goal of the wrapper is to maintain the content of the memory at the end of each phase and reuse it when the same phase is executed again. In this way, it prevents different phases from interfering with each other's memory.

Arguments:
- model (:obj:Any): Wrapped model class, should contain forward method.
- batch_size (:obj:int): Memory batch size.

forward(*args, **kwargs)

Arguments:
- *args, **kwargs: Inputs passed through to the wrapped model's forward method.

Returns:
- Output of the wrapped model's forward method, computed with the stored memory.

reset_memory_entry(state_id=None)

Overview

Reset specific batches of the memory; the batch ids are specified in 'state_id'.

ArgmaxSampleWrapper

Bases: IModelWrapper

Overview

Used to help the model to sample argmax action.

Interfaces: forward.

forward(*args, **kwargs)

Overview

Employ model forward computation graph, and use the output logit to greedily select max action (argmax).
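The masked-argmax step the wrapper performs can be sketched with toy logits (the 1e8 offset mirrors the convention used in the source listing below): invalid actions get a large negative offset before argmax, so only allowed actions can be selected.

```python
import torch

# Masked greedy action selection: mask entry 0 disables that action.
logit = torch.tensor([[2.0, 5.0, 1.0],
                      [0.5, 3.0, 4.0]])
mask = torch.tensor([[1.0, 0.0, 1.0],   # action 1 forbidden in sample 0
                     [1.0, 1.0, 1.0]])  # all actions allowed in sample 1
masked = logit - 1e8 * (1 - mask)
action = masked.argmax(dim=-1)
assert action.tolist() == [0, 2]  # sample 0 falls back to its best legal action
```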

CombinationArgmaxSampleWrapper

Bases: IModelWrapper

Overview

Used to help the model to sample combination argmax action.

Interfaces: forward.

CombinationMultinomialSampleWrapper

Bases: IModelWrapper

Overview

Used to help the model to sample combination multinomial action.

Interfaces: forward.

HybridArgmaxSampleWrapper

Bases: IModelWrapper

Overview

Used to help the model to sample argmax action in hybrid action space, i.e. {'action_type': discrete, 'action_args': continuous}.

Interfaces: forward.

MultinomialSampleWrapper

Bases: IModelWrapper

Overview

Used to help the model get the corresponding action from the output['logit'].

Interfaces: forward.
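The multinomial draw this wrapper relies on (a module-level sample_action helper in the source listing below) can be sketched as: softmax the logits into probabilities, flatten the leading dims, and draw one action per sample.

```python
import torch

# Sketch of multinomial action sampling via torch.multinomial.
def sample_action(logit: torch.Tensor) -> torch.Tensor:
    prob = torch.softmax(logit, dim=-1)
    shape = prob.shape
    # multinomial expects 2D (rows of weights); restore leading dims after.
    action = torch.multinomial(prob.view(-1, shape[-1]), 1).squeeze(-1)
    return action.view(*shape[:-1])

logit = torch.tensor([[0.0, 100.0], [100.0, 0.0]])
action = sample_action(logit)
assert action.tolist() == [1, 0]  # near-deterministic at these extreme logits
```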

EpsGreedySampleWrapper

Bases: IModelWrapper

Overview

Epsilon greedy sampler used in collector_model to help balance exploration and exploitation. The type of eps can vary across algorithms, such as:
- float (i.e. a Python native scalar): for the common case
- Dict[str, float]: for the NGU algorithm

Interfaces: forward.
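For the scalar-eps case, the sampling rule can be sketched as standalone toy code (action-mask handling omitted): with probability eps take a uniform random action, otherwise the greedy argmax.

```python
import numpy as np
import torch

# Scalar-eps epsilon-greedy sketch.
def eps_greedy(logit: torch.Tensor, eps: float) -> torch.Tensor:
    if np.random.random() > eps:
        return logit.argmax(dim=-1)          # exploit
    return torch.randint(0, logit.shape[-1], size=logit.shape[:-1])  # explore

logit = torch.tensor([[0.1, 2.0], [3.0, 0.2]])
greedy = eps_greedy(logit, eps=0.0)    # (almost surely) never explores
assert greedy.tolist() == [1, 0]
random_a = eps_greedy(logit, eps=1.0)  # always explores
assert all(0 <= a < 2 for a in random_a.tolist())
```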

EpsGreedyMultinomialSampleWrapper

Bases: IModelWrapper

Overview

Epsilon greedy sampler coupled with multinomial sample used in collector_model to help balance exploration and exploitation.

Interfaces: forward.

HybridEpsGreedySampleWrapper

Bases: IModelWrapper

Overview

Epsilon greedy sampler used in collector_model to help balance exploration and exploitation, in hybrid action space, i.e. {'action_type': discrete, 'action_args': continuous}.

Interfaces: forward.

HybridEpsGreedyMultinomialSampleWrapper

Bases: IModelWrapper

Overview

Epsilon greedy sampler coupled with multinomial sample used in collector_model to help balance exploration and exploitation, in hybrid action space, i.e. {'action_type': discrete, 'action_args': continuous}.

Interfaces: forward.

HybridReparamMultinomialSampleWrapper

Bases: IModelWrapper

Overview

Reparameterization sampler coupled with multinomial sample used in collector_model to help balance exploration and exploitation, in hybrid action space, i.e. {'action_type': discrete, 'action_args': continuous}.

Interfaces: forward

HybridDeterministicArgmaxSampleWrapper

Bases: IModelWrapper

Overview

Deterministic sampler coupled with argmax sample used in eval_model, in hybrid action space, i.e. {'action_type': discrete, 'action_args': continuous}.

Interfaces: forward

DeterministicSampleWrapper

Bases: IModelWrapper

Overview

Deterministic sampler (just use mu directly) used in eval_model.

Interfaces: forward

ReparamSampleWrapper

Bases: IModelWrapper

Overview

Reparameterization gaussian sampler used in collector_model.

Interfaces: forward

ActionNoiseWrapper

Bases: IModelWrapper

Overview

Add noise to the collector's action output; clip both the generated noise and the action after adding noise.

Interfaces: __init__, forward.

Arguments:
- model (:obj:Any): Wrapped model class. Should contain forward method.
- noise_type (:obj:str): The type of noise to generate, supports ['gauss', 'ou'].
- noise_kwargs (:obj:dict): Keyword args used in noise init; depends on noise_type.
- noise_range (:obj:Optional[dict]): Range of noise, used for clipping.
- action_range (:obj:Optional[dict]): Range of action + noise, used for clipping; defaults to clipping into [-1, 1].

add_noise(action)

Overview

Generate noise and clip noise if needed. Add noise to action and clip action if needed.

Arguments:
- action (:obj:torch.Tensor): Model's action output.

Returns:
- noised_action (:obj:torch.Tensor): Action processed after adding noise and clipping.
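A minimal sketch of the add-noise pipeline with Gaussian noise (the sigma and the clip ranges here are assumptions for illustration): clip the noise first, add it to the action, then clip the sum into the action range.

```python
import torch

def add_noise(action: torch.Tensor) -> torch.Tensor:
    noise = torch.randn_like(action) * 0.1     # 'gauss' noise, assumed sigma=0.1
    noise = noise.clamp(-0.5, 0.5)             # assumed noise_range
    return (action + noise).clamp(-1.0, 1.0)   # assumed action_range [-1, 1]

action = torch.tensor([0.99, -0.99, 0.0])
noised = add_noise(action)
# Even near the boundary, the noised action stays inside the action range.
assert (noised >= -1.0).all() and (noised <= 1.0).all()
```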

TargetNetworkWrapper

Bases: IModelWrapper

Overview

Maintain and update the target network

Interfaces: update, reset

update(state_dict, direct=False)

Overview

Update the target network state dict

Arguments:
- state_dict (:obj:dict): The state_dict from the learner model.
- direct (:obj:bool): Whether to update the target network directly; if True, simply call the model's load_state_dict method.
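The two update modes can be sketched as follows (the soft-update momentum value is an assumption for illustration; direct=True corresponds to the plain load_state_dict copy):

```python
import torch
import torch.nn as nn

learner, target = nn.Linear(2, 2), nn.Linear(2, 2)

# Soft/momentum-style update: target <- (1 - theta) * target + theta * learner.
theta = 0.005  # assumed momentum value
with torch.no_grad():
    for t, s in zip(target.parameters(), learner.parameters()):
        t.mul_(1 - theta).add_(theta * s)

# Hard update (direct=True): simply load the learner's state dict.
target.load_state_dict(learner.state_dict())
assert torch.equal(target.weight, learner.weight)
```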

reset_state(target_update_count=None)

Overview

Reset the update_count

Arguments:
- target_update_count (:obj:int): The value to reset the target update count to.

TeacherNetworkWrapper

Bases: IModelWrapper

Overview

Set the teacher network: set the model's model.teacher_cfg to the input teacher_cfg.

model_wrap(model, wrapper_name=None, **kwargs)

Overview

Wrap the model with the specified wrapper and return the wrapped model.

Arguments:
- model (:obj:Any): The model to be wrapped.
- wrapper_name (:obj:str): The name of the wrapper to be used.

.. note:: The arguments of the wrapper should be passed in as kwargs.

register_wrapper(name, wrapper_type)

Overview

Register a new wrapper to wrapper_name_map. When users implement a new wrapper, they must call this function to complete the registration. Then the wrapper can be used via model_wrap.

Arguments:
- name (:obj:str): The name of the new wrapper to be registered.
- wrapper_type (:obj:type): The wrapper class to be added to wrapper_name_map. This argument should be a subclass of IModelWrapper.
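The registry mechanism behind model_wrap and register_wrapper can be sketched with a toy version (an illustration of the name-to-class lookup, not DI-engine's actual implementation):

```python
import torch.nn as nn

# Toy wrapper registry: model_wrap looks a wrapper class up by name
# and instantiates it around the model.
wrapper_name_map = {}

def register_wrapper(name, wrapper_type):
    wrapper_name_map[name] = wrapper_type

def model_wrap(model, wrapper_name=None, **kwargs):
    assert wrapper_name in wrapper_name_map, 'unknown wrapper: {}'.format(wrapper_name)
    return wrapper_name_map[wrapper_name](model, **kwargs)

class IdentityWrapper:
    """Toy wrapper that just forwards to the wrapped model."""

    def __init__(self, model):
        self._model = model

    def forward(self, *args, **kwargs):
        return self._model.forward(*args, **kwargs)

register_wrapper('base', IdentityWrapper)
wrapped = model_wrap(nn.Linear(3, 3), wrapper_name='base')
assert isinstance(wrapped, IdentityWrapper)
```

Extra keyword arguments to model_wrap (e.g. state_num for a hidden-state wrapper) are passed straight into the wrapper's constructor, which matches the note above.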

Full Source Code

../ding/model/wrapper/model_wrappers.py

from typing import Any, Tuple, Callable, Optional, List, Dict, Union
from abc import ABC
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical, Independent, Normal
from ding.torch_utils import get_tensor_data, zeros_like
from ding.rl_utils import create_noise_generator
from ding.utils.data import default_collate


class IModelWrapper(ABC):
    """
    Overview:
        The basic interface class of model wrappers. Model wrapper is a wrapper class of torch.nn.Module model, which \
        is used to add some extra operations for the wrapped model, such as hidden state maintain for RNN-base model, \
        argmax action selection for discrete action space, etc.
    Interfaces:
        ``__init__``, ``__getattr__``, ``info``, ``reset``, ``forward``.
    """

    def __init__(self, model: nn.Module) -> None:
        """
        Overview:
            Initialize model and other necessary member variabls in the model wrapper.
        """
        self._model = model

    def __getattr__(self, key: str) -> Any:
        """
        Overview:
            Get original attrbutes of torch.nn.Module model, such as variables and methods defined in model.
        Arguments:
            - key (:obj:`str`): The string key to query.
        Returns:
            - ret (:obj:`Any`): The queried attribute.
        """
        return getattr(self._model, key)

    def info(self, attr_name: str) -> str:
        """
        Overview:
            Get some string information of the indicated ``attr_name``, which is used for debug wrappers.
            This method will recursively search for the indicated ``attr_name``.
        Arguments:
            - attr_name (:obj:`str`): The string key to query information.
        Returns:
            - info_string (:obj:`str`): The information string of the indicated ``attr_name``.
        """
        if attr_name in dir(self):
            if isinstance(self._model, IModelWrapper):
                return '{} {}'.format(self.__class__.__name__, self._model.info(attr_name))
            else:
                if attr_name in dir(self._model):
                    return '{} {}'.format(self.__class__.__name__, self._model.__class__.__name__)
                else:
                    return '{}'.format(self.__class__.__name__)
        else:
            if isinstance(self._model, IModelWrapper):
                return '{}'.format(self._model.info(attr_name))
            else:
                return '{}'.format(self._model.__class__.__name__)

    def reset(self, data_id: List[int] = None, **kwargs) -> None:
        """
        Overview
            Basic interface, reset some stateful varaibles in the model wrapper, such as hidden state of RNN.
            Here we do nothing and just implement this interface method.
            Other derived model wrappers can override this method to add some extra operations.
        Arguments:
            - data_id (:obj:`List[int]`): The data id list to reset. If None, reset all data. In practice, \
                model wrappers often needs to maintain some stateful variables for each data trajectory, \
                so we leave this ``data_id`` argument to reset the stateful variables of the indicated data.
        """
        # This is necessary when multiple model wrappers.
        if hasattr(self._model, 'reset'):
            return self._model.reset(data_id=data_id, **kwargs)

    def forward(self, *args, **kwargs) -> Any:
        """
        Overview:
            Basic interface, call the wrapped model's forward method. Other derived model wrappers can override this \
            method to add some extra operations.
        """
        return self._model.forward(*args, **kwargs)


class BaseModelWrapper(IModelWrapper):
    """
    Overview:
        Placeholder class for the model wrapper. This class is used to wrap the model without any extra operations, \
        including a empty ``reset`` method and a ``forward`` method which directly call the wrapped model's forward.
        To keep the consistency of the model wrapper interface, we use this class to wrap the model without specific \
        operations in the implementation of DI-engine's policy.
    """
    pass


class HiddenStateWrapper(IModelWrapper):
    """
    Overview:
        Maintain the hidden state for RNN-base model. Each sample in a batch has its own state.
    Interfaces:
        ``__init__``, ``reset``, ``forward``.
    """

    def __init__(
            self,
            model: Any,
            state_num: int,
            save_prev_state: bool = False,
            init_fn: Callable = lambda: None,
    ) -> None:
        """
        Overview:
            Maintain the hidden state for RNN-base model. Each sample in a batch has its own state. \
            Init the maintain state and state function; Then wrap the ``model.forward`` method with auto \
            saved data ['prev_state'] input, and create the ``model.reset`` method.
        Arguments:
            - model(:obj:`Any`): Wrapped model class, should contain forward method.
            - state_num (:obj:`int`): Number of states to process.
            - save_prev_state (:obj:`bool`): Whether to output the prev state in output.
            - init_fn (:obj:`Callable`): The function which is used to init every hidden state when init and reset, \
                default return None for hidden states.

        .. note::
            1. This helper must deal with an actual batch with some parts of samples, e.g: 6 samples of state_num 8.
            2. This helper must deal with the single sample state reset.
        """
        super().__init__(model)
        self._state_num = state_num
        # This is to maintain hidden states (when it comes to this wrapper, \
        # map self._state into data['prev_value] and update next_state, store in self._state)
        self._state = {i: init_fn() for i in range(state_num)}
        self._save_prev_state = save_prev_state
        self._init_fn = init_fn

    def forward(self, data, **kwargs):
        state_id = kwargs.pop('data_id', None)
        valid_id = kwargs.pop('valid_id', None)  # None, not used in any code in DI-engine
        data, state_info = self.before_forward(data, state_id)  # update data['prev_state'] with self._state
        output = self._model.forward(data, **kwargs)
        h = output.pop('next_state', None)
        if h is not None:
            self.after_forward(h, state_info, valid_id)  # this is to store the 'next hidden state' for each time step
        if self._save_prev_state:
            prev_state = get_tensor_data(data['prev_state'])
            # for compatibility, because of the incompatibility between None and torch.Tensor
            for i in range(len(prev_state)):
                if prev_state[i] is None:
                    prev_state[i] = zeros_like(h[0])
            output['prev_state'] = prev_state
        return output

    def reset(self, *args, **kwargs):
        state = kwargs.pop('state', None)
        state_id = kwargs.get('data_id', None)
        self.reset_state(state, state_id)
        if hasattr(self._model, 'reset'):
            return self._model.reset(*args, **kwargs)

    def reset_state(self, state: Optional[list] = None, state_id: Optional[list] = None) -> None:
        if state_id is None:  # train: init all states
            state_id = [i for i in range(self._state_num)]
        if state is None:  # collect: init state that are done
            state = [self._init_fn() for i in range(len(state_id))]
        assert len(state) == len(state_id), '{}/{}'.format(len(state), len(state_id))
        for idx, s in zip(state_id, state):
            self._state[idx] = s

    def before_forward(self, data: dict, state_id: Optional[list]) -> Tuple[dict, dict]:
        if state_id is None:
            state_id = [i for i in range(self._state_num)]

        state_info = {idx: self._state[idx] for idx in state_id}
        data['prev_state'] = list(state_info.values())
        return data, state_info

    def after_forward(self, h: Any, state_info: dict, valid_id: Optional[list] = None) -> None:
        assert len(h) == len(state_info), '{}/{}'.format(len(h), len(state_info))
        for i, idx in enumerate(state_info.keys()):
            if valid_id is None:
                self._state[idx] = h[i]
            else:
                if idx in valid_id:
                    self._state[idx] = h[i]


class TransformerInputWrapper(IModelWrapper):

    def __init__(self, model: Any, seq_len: int, init_fn: Callable = lambda: None) -> None:
        """
        Overview:
            Given N the length of the sequences received by a Transformer model, maintain the last N-1 input
            observations. In this way we can provide at each step all the observations needed by Transformer to
            compute its output. We need this because some methods such as 'collect' and 'evaluate' only provide the
            model 1 observation per step and don't have memory of past observations, but Transformer needs a sequence
            of N observations. The wrapper method ``forward`` will save the input observation in a FIFO memory of
            length N and the method ``reset`` will reset the memory. The empty memory spaces will be initialized
            with 'init_fn' or zero by calling the method ``reset_input``. Since different env can terminate at
            different steps, the method ``reset_memory_entry`` only initializes the memory of specific environments in
            the batch size.
        Arguments:
            - model (:obj:`Any`): Wrapped model class, should contain forward method.
            - seq_len (:obj:`int`): Number of past observations to remember.
            - init_fn (:obj:`Callable`): The function which is used to init every memory locations when init and reset.
        """
        super().__init__(model)
        self.seq_len = seq_len
        self._init_fn = init_fn
        self.obs_memory = None  # shape (N, bs, *obs_shape)
        self.init_obs = None  # sample of observation used to initialize the memory
        self.bs = None
        self.memory_idx = []  # len bs, index of where to put the next element in the sequence for each batch

    def forward(self,
                input_obs: torch.Tensor,
                only_last_logit: bool = True,
                data_id: List = None,
                **kwargs) -> Dict[str, torch.Tensor]:
        """
        Arguments:
            - input_obs (:obj:`torch.Tensor`): Input observation without sequence shape: ``(bs, *obs_shape)``.
            - only_last_logit (:obj:`bool`): if True 'logit' only contains the output corresponding to the current \
                observation (shape: bs, embedding_dim), otherwise logit has shape (seq_len, bs, embedding_dim).
            - data_id (:obj:`List`): id of the envs that are currently running. Memory update and logits return has \
                only effect for those environments. If `None` it is considered that all envs are running.
        Returns:
            - Dictionary containing the input_sequence 'input_seq' stored in memory and the transformer output 'logit'.
        """
        if self.obs_memory is None:
            self.reset_input(torch.zeros_like(input_obs))  # init the memory with the size of the input observation
        if data_id is None:
            data_id = list(range(self.bs))
        assert self.obs_memory.shape[0] == self.seq_len
        # implements a fifo queue, self.memory_idx is index where to put the last element
        for i, b in enumerate(data_id):
            if self.memory_idx[b] == self.seq_len:
                # roll back of 1 position along dim 1 (sequence dim)
                self.obs_memory[:, b] = torch.roll(self.obs_memory[:, b], -1, 0)
                self.obs_memory[self.memory_idx[b] - 1, b] = input_obs[i]
            if self.memory_idx[b] < self.seq_len:
                self.obs_memory[self.memory_idx[b], b] = input_obs[i]
                if self.memory_idx != self.seq_len:
                    self.memory_idx[b] += 1
        out = self._model.forward(self.obs_memory, **kwargs)
        out['input_seq'] = self.obs_memory
        if only_last_logit:
            # return only the logits for running environments
            out['logit'] = [out['logit'][self.memory_idx[b] - 1][b] for b in range(self.bs) if b in data_id]
            out['logit'] = default_collate(out['logit'])
        return out

    def reset_input(self, input_obs: torch.Tensor):
        """
        Overview:
            Initialize the whole memory
        """
        init_obs = torch.zeros_like(input_obs)
        self.init_obs = init_obs
        self.obs_memory = []  # List(bs, *obs_shape)
        for i in range(self.seq_len):
            self.obs_memory.append(init_obs.clone() if init_obs is not None else self._init_fn())
        self.obs_memory = default_collate(self.obs_memory)  # shape (N, bs, *obs_shape)
        self.bs = self.init_obs.shape[0]
        self.memory_idx = [0 for _ in range(self.bs)]

    # called before evaluation
    # called after each evaluation iteration for each done env
    # called after each collect iteration for each done env
    def reset(self, *args, **kwargs):
        state_id = kwargs.get('data_id', None)
        input_obs = kwargs.get('input_obs', None)
        if input_obs is not None:
            self.reset_input(input_obs)
        if state_id is not None:
            self.reset_memory_entry(state_id)
        if input_obs is None and state_id is None:
            self.obs_memory = None
        if hasattr(self._model, 'reset'):
            return self._model.reset(*args, **kwargs)

    def reset_memory_entry(self, state_id: Optional[list] = None) -> None:
        """
        Overview:
            Reset specific batch of the memory, batch ids are specified in 'state_id'
        """
        assert self.init_obs is not None, 'Call method "reset_memory" first'
        for _id in state_id:
            self.memory_idx[_id] = 0
            self.obs_memory[:, _id] = self.init_obs[_id]  # init the corresponding sequence with broadcasting


class TransformerSegmentWrapper(IModelWrapper):

    def __init__(self, model: Any, seq_len: int) -> None:
        """
        Overview:
            Given T the length of a trajectory and N the length of the sequences received by a Transformer model,
            split T in sequences of N elements and forward each sequence one by one. If T % N != 0, the last sequence
            will be zero-padded. Usually used during Transformer training phase.
        Arguments:
            - model (:obj:`Any`): Wrapped model class, should contain forward method.
            - seq_len (:obj:`int`): N, length of a sequence.
        """
        super().__init__(model)
        self.seq_len = seq_len

    def forward(self, obs: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Arguments:
            - data (:obj:`dict`): Dict type data, including at least \
                ['main_obs', 'target_obs', 'action', 'reward', 'done', 'weight']
        Returns:
            - List containing a dict of the model output for each sequence.
        """
        sequences = list(torch.split(obs, self.seq_len, dim=0))
        if sequences[-1].shape[0] < self.seq_len:
            last = sequences[-1].clone()
            diff = self.seq_len - last.shape[0]
            sequences[-1] = F.pad(input=last, pad=(0, 0, 0, 0, 0, diff), mode='constant', value=0)
        outputs = []
        for i, seq in enumerate(sequences):
            out = self._model.forward(seq, **kwargs)
            outputs.append(out)
        out = {}
        for k in outputs[0].keys():
            out_k = [o[k] for o in outputs]
            out_k = torch.cat(out_k, dim=0)
            out[k] = out_k
        return out


class TransformerMemoryWrapper(IModelWrapper):

    def __init__(
            self,
            model: Any,
            batch_size: int,
    ) -> None:
        """
        Overview:
            Stores a copy of the Transformer memory in order to be reused across different phases. To make it more
            clear, suppose the training pipeline is divided into 3 phases: evaluate, collect, learn. The goal of the
            wrapper is to maintain the content of the memory at the end of each phase and reuse it when the same phase
            is executed again. In this way, it prevents different phases to interferer each other memory.
        Arguments:
            - model (:obj:`Any`): Wrapped model class, should contain forward method.
            - batch_size (:obj:`int`): Memory batch size.
        """
        super().__init__(model)
        # shape (layer_num, memory_len, bs, embedding_dim)
        self._model.reset_memory(batch_size=batch_size)
        self.memory = self._model.get_memory()
        self.mem_shape = self.memory.shape

    def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Arguments:
            - data (:obj:`dict`): Dict type data, including at least \
                ['main_obs', 'target_obs', 'action', 'reward', 'done', 'weight']
        Returns:
            - Output of the forward method.
        """
        self._model.reset_memory(state=self.memory)
        out = self._model.forward(*args, **kwargs)
        self.memory = self._model.get_memory()
        return out

    def reset(self, *args, **kwargs):
        state_id = kwargs.get('data_id', None)
        if state_id is None:
            self.memory = torch.zeros(self.mem_shape)
        else:
            self.reset_memory_entry(state_id)
        if hasattr(self._model, 'reset'):
            return self._model.reset(*args, **kwargs)

    def reset_memory_entry(self, state_id: Optional[list] = None) -> None:
        """
        Overview:
            Reset specific batch of the memory, batch ids are specified in 'state_id'
        """
        for _id in state_id:
            self.memory[:, :, _id] = torch.zeros((self.mem_shape[-1]))

    def show_memory_occupancy(self, layer=0) -> None:
        memory = self.memory
        memory_shape = memory.shape
        print('Layer {}-------------------------------------------'.format(layer))
        for b in range(memory_shape[-2]):
            print('b{}: '.format(b), end='')
            for m in range(memory_shape[1]):
                if sum(abs(memory[layer][m][b].flatten())) != 0:
                    print(1, end='')
                else:
                    print(0, end='')
            print()


def sample_action(logit=None, prob=None):
    if prob is None:
        prob = torch.softmax(logit, dim=-1)
    shape = prob.shape
    prob += 1e-8
    prob = prob.view(-1, shape[-1])
    # prob can also be treated as weight in multinomial sample
    action = torch.multinomial(prob, 1).squeeze(-1)
    action = action.view(*shape[:-1])
    return action


class ArgmaxSampleWrapper(IModelWrapper):
    """
    Overview:
        Used to help the model to sample argmax action.
    Interfaces:
        ``forward``.
    """

    def forward(self, *args, **kwargs):
        """
        Overview:
            Employ model forward computation graph, and use the output logit to greedily select max action (argmax).
        """
        output = self._model.forward(*args, **kwargs)
        assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
        logit = output['logit']
        assert isinstance(logit, torch.Tensor) or isinstance(logit, list)
        if isinstance(logit, torch.Tensor):
            logit = [logit]
        if 'action_mask' in output:
            mask = output['action_mask']
            if isinstance(mask, torch.Tensor):
                mask = [mask]
            logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)]
        action = [l.argmax(dim=-1) for l in logit]
        if len(action) == 1:
            action, logit = action[0], logit[0]
        output['action'] = action
        return output


class CombinationArgmaxSampleWrapper(IModelWrapper):
    r"""
    Overview:
        Used to help the model to sample combination argmax action.
    Interfaces:
        ``forward``.
    """

    def forward(self, shot_number, *args, **kwargs):
        output = self._model.forward(*args, **kwargs)
        # Generate actions.
        act = []
        mask = torch.zeros_like(output['logit'])
        for ii in range(shot_number):
            masked_logit = output['logit'] + mask
            actions = masked_logit.argmax(dim=-1)
            act.append(actions)
            for jj in range(actions.shape[0]):
                mask[jj][actions[jj]] = -1e8
        # `act` is shaped: (B, shot_number)
        act = torch.stack(act, dim=1)
        output['action'] = act
        return output


class CombinationMultinomialSampleWrapper(IModelWrapper):
    r"""
    Overview:
        Used to help the model to sample combination multinomial action.
    Interfaces:
        ``forward``.
    """

    def forward(self, shot_number, *args, **kwargs):
        output = self._model.forward(*args, **kwargs)
        # Generate actions.
        act = []
        mask = torch.zeros_like(output['logit'])
        for ii in range(shot_number):
            dist = torch.distributions.Categorical(logits=output['logit'] + mask)
            actions = dist.sample()
            act.append(actions)
            for jj in range(actions.shape[0]):
                mask[jj][actions[jj]] = -1e8

        # `act` is shaped: (B, shot_number)
        act = torch.stack(act, dim=1)
        output['action'] = act
        return output


class HybridArgmaxSampleWrapper(IModelWrapper):
    r"""
    Overview:
        Used to help the model to sample argmax action in hybrid action space,
        i.e.{'action_type': discrete, 'action_args', continuous}
    Interfaces:
        ``forward``.
    """

    def forward(self, *args, **kwargs):
        output = self._model.forward(*args, **kwargs)
        assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
        if 'logit' not in output:
            return output
        logit = output['logit']
        assert isinstance(logit, torch.Tensor) or isinstance(logit, list)
        if isinstance(logit, torch.Tensor):
            logit = [logit]
        if 'action_mask' in output:
            mask = output['action_mask']
            if isinstance(mask, torch.Tensor):
                mask = [mask]
            logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)]
        action = [l.argmax(dim=-1) for l in logit]
        if len(action) == 1:
            action, logit = action[0], logit[0]
        output = {'action': {'action_type': action, 'action_args': output['action_args']}, 'logit': logit}
        return output


class MultinomialSampleWrapper(IModelWrapper):
    """
    Overview:
        Used to help the model get the corresponding action from the output['logits']self.
    Interfaces:
        ``forward``.
    """

    def forward(self, *args, **kwargs):
        if 'alpha' in kwargs.keys():
            alpha = kwargs.pop('alpha')
        else:
            alpha = None
        output = self._model.forward(*args, **kwargs)
        assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
        logit = output['logit']
        assert isinstance(logit, torch.Tensor) or isinstance(logit, list)
        if isinstance(logit, torch.Tensor):
            logit = [logit]
        if 'action_mask' in output:
            mask = output['action_mask']
            if isinstance(mask, torch.Tensor):
                mask = [mask]
            logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)]
        if alpha is None:
            action = [sample_action(logit=l) for l in logit]
        else:
            # Note that if alpha is passed in here, we will divide logit by alpha.
            action = [sample_action(logit=l / alpha) for l in logit]
        if len(action) == 1:
            action, logit = action[0], logit[0]
        output['action'] = action
        return output


class EpsGreedySampleWrapper(IModelWrapper):
    r"""
    Overview:
        Epsilon greedy sampler used in collector_model to help balance exploratin and exploitation.
        The type of eps can vary from different algorithms, such as:
        - float (i.e. python native scalar): for almost normal case
        - Dict[str, float]: for algorithm NGU
    Interfaces:
        ``forward``.
    """

    def forward(self, *args, **kwargs):
        eps = kwargs.pop('eps')
        output = self._model.forward(*args, **kwargs)
        assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
        logit = output['logit']
        assert isinstance(logit, torch.Tensor) or isinstance(logit, list)
        if isinstance(logit, torch.Tensor):
            logit = [logit]
        if 'action_mask' in output:
            mask = output['action_mask']
            if isinstance(mask, torch.Tensor):
                mask = [mask]
            logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)]
        else:
            mask = None
        action = []
        if isinstance(eps, dict):
            # for NGU policy, eps is a dict, each collect env has a different eps
            for i, l in enumerate(logit[0]):
                eps_tmp = eps[i]
                if np.random.random() > eps_tmp:
                    action.append(l.argmax(dim=-1))
                else:
                    if mask is not None:
                        action.append(
                            sample_action(prob=mask[0][i].float().unsqueeze(0)).to(logit[0].device).squeeze(0)
                        )
                    else:
                        action.append(torch.randint(0, l.shape[-1], size=l.shape[:-1]).to(logit[0].device))
            action = torch.stack(action, dim=-1)  # shape torch.size([env_num])
        else:
            for i, l in enumerate(logit):
                if np.random.random() > eps:
                    action.append(l.argmax(dim=-1))
                else:
                    if mask is not None:
                        action.append(sample_action(prob=mask[i].float()))
                    else:
                        action.append(torch.randint(0, l.shape[-1], size=l.shape[:-1]))
        if len(action) == 1:
            action, logit = action[0], logit[0]
        output['action'] = action
        return output


class EpsGreedyMultinomialSampleWrapper(IModelWrapper):
    r"""
    Overview:
        Epsilon greedy sampler coupled with multinomial sample used in collector_model
        to help balance exploration and exploitation.
    Interfaces:
        ``forward``.
624 """ 625 626 def forward(self, *args, **kwargs): 627 eps = kwargs.pop('eps') 628 if 'alpha' in kwargs.keys(): 629 alpha = kwargs.pop('alpha') 630 else: 631 alpha = None 632 output = self._model.forward(*args, **kwargs) 633 assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output)) 634 logit = output['logit'] 635 assert isinstance(logit, torch.Tensor) or isinstance(logit, list) 636 if isinstance(logit, torch.Tensor): 637 logit = [logit] 638 if 'action_mask' in output: 639 mask = output['action_mask'] 640 if isinstance(mask, torch.Tensor): 641 mask = [mask] 642 logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)] 643 else: 644 mask = None 645 action = [] 646 for i, l in enumerate(logit): 647 if np.random.random() > eps: 648 if alpha is None: 649 action = [sample_action(logit=l) for l in logit] 650 else: 651 # Note that if alpha is passed in here, we will divide logit by alpha. 652 action = [sample_action(logit=l / alpha) for l in logit] 653 else: 654 if mask: 655 action.append(sample_action(prob=mask[i].float())) 656 else: 657 action.append(torch.randint(0, l.shape[-1], size=l.shape[:-1])) 658 if len(action) == 1: 659 action, logit = action[0], logit[0] 660 output['action'] = action 661 return output 662 663 664class HybridEpsGreedySampleWrapper(IModelWrapper): 665 r""" 666 Overview: 667 Epsilon greedy sampler used in collector_model to help balance exploration and exploitation. 668 In hybrid action space, i.e.{'action_type': discrete, 'action_args', continuous} 669 Interfaces: 670 ``forward``. 
671 """ 672 673 def forward(self, *args, **kwargs): 674 eps = kwargs.pop('eps') 675 output = self._model.forward(*args, **kwargs) 676 assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output)) 677 logit = output['logit'] 678 assert isinstance(logit, torch.Tensor) or isinstance(logit, list) 679 if isinstance(logit, torch.Tensor): 680 logit = [logit] 681 if 'action_mask' in output: 682 mask = output['action_mask'] 683 if isinstance(mask, torch.Tensor): 684 mask = [mask] 685 logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)] 686 else: 687 mask = None 688 action = [] 689 for i, l in enumerate(logit): 690 if np.random.random() > eps: 691 action.append(l.argmax(dim=-1)) 692 else: 693 if mask: 694 action.append(sample_action(prob=mask[i].float())) 695 else: 696 action.append(torch.randint(0, l.shape[-1], size=l.shape[:-1])) 697 if len(action) == 1: 698 action, logit = action[0], logit[0] 699 output = {'action': {'action_type': action, 'action_args': output['action_args']}, 'logit': logit} 700 return output 701 702 703class HybridEpsGreedyMultinomialSampleWrapper(IModelWrapper): 704 """ 705 Overview: 706 Epsilon greedy sampler coupled with multinomial sample used in collector_model 707 to help balance exploration and exploitation. 708 In hybrid action space, i.e.{'action_type': discrete, 'action_args', continuous} 709 Interfaces: 710 ``forward``. 
711 """ 712 713 def forward(self, *args, **kwargs): 714 eps = kwargs.pop('eps') 715 output = self._model.forward(*args, **kwargs) 716 assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output)) 717 if 'logit' not in output: 718 return output 719 720 logit = output['logit'] 721 assert isinstance(logit, torch.Tensor) or isinstance(logit, list) 722 if isinstance(logit, torch.Tensor): 723 logit = [logit] 724 if 'action_mask' in output: 725 mask = output['action_mask'] 726 if isinstance(mask, torch.Tensor): 727 mask = [mask] 728 logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)] 729 else: 730 mask = None 731 action = [] 732 for i, l in enumerate(logit): 733 if np.random.random() > eps: 734 action = [sample_action(logit=l) for l in logit] 735 else: 736 if mask: 737 action.append(sample_action(prob=mask[i].float())) 738 else: 739 action.append(torch.randint(0, l.shape[-1], size=l.shape[:-1])) 740 if len(action) == 1: 741 action, logit = action[0], logit[0] 742 output = {'action': {'action_type': action, 'action_args': output['action_args']}, 'logit': logit} 743 return output 744 745 746class HybridReparamMultinomialSampleWrapper(IModelWrapper): 747 """ 748 Overview: 749 Reparameterization sampler coupled with multinomial sample used in collector_model 750 to help balance exploration and exploitation. 
        In hybrid action space, i.e. {'action_type': discrete, 'action_args': continuous}
    Interfaces:
        ``forward``.
    """

    def forward(self, *args, **kwargs):
        output = self._model.forward(*args, **kwargs)
        assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))

        logit = output['logit']  # logit: {'action_type': action_type_logit, 'action_args': action_args_logit}
        # discrete part
        action_type_logit = logit['action_type']
        prob = torch.softmax(action_type_logit, dim=-1)
        pi_action = Categorical(prob)
        action_type = pi_action.sample()
        # continuous part
        mu, sigma = logit['action_args']['mu'], logit['action_args']['sigma']
        dist = Independent(Normal(mu, sigma), 1)
        action_args = dist.sample()
        action = {'action_type': action_type, 'action_args': action_args}
        output['action'] = action
        return output


class HybridDeterministicArgmaxSampleWrapper(IModelWrapper):
    """
    Overview:
        Deterministic sampler coupled with argmax sample used in eval_model.
        In hybrid action space, i.e. {'action_type': discrete, 'action_args': continuous}
    Interfaces:
        ``forward``.
    """

    def forward(self, *args, **kwargs):
        output = self._model.forward(*args, **kwargs)
        assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
        logit = output['logit']  # logit: {'action_type': action_type_logit, 'action_args': action_args_logit}
        # discrete part
        action_type_logit = logit['action_type']
        action_type = action_type_logit.argmax(dim=-1)
        # continuous part
        mu = logit['action_args']['mu']
        action_args = mu
        action = {'action_type': action_type, 'action_args': action_args}
        output['action'] = action
        return output


class DeterministicSampleWrapper(IModelWrapper):
    """
    Overview:
        Deterministic sampler (just use mu directly) used in eval_model.
    Interfaces:
        ``forward``.
    """

    def forward(self, *args, **kwargs):
        output = self._model.forward(*args, **kwargs)
        assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
        output['action'] = output['logit']['mu']
        return output


class ReparamSampleWrapper(IModelWrapper):
    """
    Overview:
        Reparameterization Gaussian sampler used in collector_model.
    Interfaces:
        ``forward``.
    """

    def forward(self, *args, **kwargs):
        output = self._model.forward(*args, **kwargs)
        assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
        mu, sigma = output['logit']['mu'], output['logit']['sigma']
        dist = Independent(Normal(mu, sigma), 1)
        output['action'] = dist.sample()
        return output


class ActionNoiseWrapper(IModelWrapper):
    r"""
    Overview:
        Add noise to the collector's action output; clip both the generated noise and the action after adding noise.
    Interfaces:
        ``__init__``, ``forward``.
    Arguments:
        - model (:obj:`Any`): Wrapped model class. Should contain ``forward`` method.
        - noise_type (:obj:`str`): The type of noise that should be generated, support ['gauss', 'ou'].
        - noise_kwargs (:obj:`dict`): Keyword args that should be used in noise init. Depends on ``noise_type``.
        - noise_range (:obj:`Optional[dict]`): Range of noise, used for clipping.
        - action_range (:obj:`Optional[dict]`): Range of action + noise, used for clipping, defaults to [-1, 1].
843 """ 844 845 def __init__( 846 self, 847 model: Any, 848 noise_type: str = 'gauss', 849 noise_kwargs: dict = {}, 850 noise_range: Optional[dict] = None, 851 action_range: Optional[dict] = { 852 'min': -1, 853 'max': 1 854 } 855 ) -> None: 856 super().__init__(model) 857 self.noise_generator = create_noise_generator(noise_type, noise_kwargs) 858 self.noise_range = noise_range 859 self.action_range = action_range 860 861 def forward(self, *args, **kwargs): 862 # if noise sigma need decay, update noise kwargs. 863 if 'sigma' in kwargs: 864 sigma = kwargs.pop('sigma') 865 if sigma is not None: 866 self.noise_generator.sigma = sigma 867 output = self._model.forward(*args, **kwargs) 868 assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output)) 869 if 'action' in output or 'action_args' in output: 870 key = 'action' if 'action' in output else 'action_args' 871 # handle hybrid action space by adding noise to continuous part of model output 872 action = output[key]['action_args'] if isinstance(output[key], dict) else output[key] 873 assert isinstance(action, torch.Tensor) 874 action = self.add_noise(action) 875 if isinstance(output[key], dict): 876 output[key]['action_args'] = action 877 else: 878 output[key] = action 879 return output 880 881 def add_noise(self, action: torch.Tensor) -> torch.Tensor: 882 r""" 883 Overview: 884 Generate noise and clip noise if needed. Add noise to action and clip action if needed. 885 Arguments: 886 - action (:obj:`torch.Tensor`): Model's action output. 887 Returns: 888 - noised_action (:obj:`torch.Tensor`): Action processed after adding noise and clipping. 
889 """ 890 noise = self.noise_generator(action.shape, action.device) 891 if self.noise_range is not None: 892 noise = noise.clamp(self.noise_range['min'], self.noise_range['max']) 893 action += noise 894 if self.action_range is not None: 895 action = action.clamp(self.action_range['min'], self.action_range['max']) 896 return action 897 898 899class TargetNetworkWrapper(IModelWrapper): 900 r""" 901 Overview: 902 Maintain and update the target network 903 Interfaces: 904 update, reset 905 """ 906 907 def __init__(self, model: Any, update_type: str, update_kwargs: dict): 908 super().__init__(model) 909 assert update_type in ['momentum', 'assign'] 910 self._update_type = update_type 911 self._update_kwargs = update_kwargs 912 self._update_count = 0 913 914 def reset(self, *args, **kwargs): 915 target_update_count = kwargs.pop('target_update_count', None) 916 self.reset_state(target_update_count) 917 if hasattr(self._model, 'reset'): 918 return self._model.reset(*args, **kwargs) 919 920 def update(self, state_dict: dict, direct: bool = False) -> None: 921 r""" 922 Overview: 923 Update the target network state dict 924 925 Arguments: 926 - state_dict (:obj:`dict`): the state_dict from learner model 927 - direct (:obj:`bool`): whether to update the target network directly, \ 928 if true then will simply call the load_state_dict method of the model 929 """ 930 if direct: 931 self._model.load_state_dict(state_dict, strict=True) 932 self._update_count = 0 933 else: 934 if self._update_type == 'assign': 935 if (self._update_count + 1) % self._update_kwargs['freq'] == 0: 936 self._model.load_state_dict(state_dict, strict=True) 937 self._update_count += 1 938 elif self._update_type == 'momentum': 939 theta = self._update_kwargs['theta'] 940 for name, p in self._model.named_parameters(): 941 # default theta = 0.001 942 p.data = (1 - theta) * p.data + theta * state_dict[name] 943 944 def reset_state(self, target_update_count: int = None) -> None: 945 r""" 946 Overview: 947 Reset 
the update_count 948 Arguments: 949 target_update_count (:obj:`int`): reset target update count value. 950 """ 951 if target_update_count is not None: 952 self._update_count = target_update_count 953 954 955class TeacherNetworkWrapper(IModelWrapper): 956 """ 957 Overview: 958 Set the teacher Network. Set the model's model.teacher_cfg to the input teacher_cfg 959 """ 960 961 def __init__(self, model, teacher_cfg): 962 super().__init__(model) 963 self._model._teacher_cfg = teacher_cfg 964 raise NotImplementedError 965 966 967wrapper_name_map = { 968 'base': BaseModelWrapper, 969 'hidden_state': HiddenStateWrapper, 970 'argmax_sample': ArgmaxSampleWrapper, 971 'hybrid_argmax_sample': HybridArgmaxSampleWrapper, 972 'eps_greedy_sample': EpsGreedySampleWrapper, 973 'eps_greedy_multinomial_sample': EpsGreedyMultinomialSampleWrapper, 974 'deterministic_sample': DeterministicSampleWrapper, 975 'reparam_sample': ReparamSampleWrapper, 976 'hybrid_eps_greedy_sample': HybridEpsGreedySampleWrapper, 977 'hybrid_eps_greedy_multinomial_sample': HybridEpsGreedyMultinomialSampleWrapper, 978 'hybrid_reparam_multinomial_sample': HybridReparamMultinomialSampleWrapper, 979 'hybrid_deterministic_argmax_sample': HybridDeterministicArgmaxSampleWrapper, 980 'multinomial_sample': MultinomialSampleWrapper, 981 'action_noise': ActionNoiseWrapper, 982 'transformer_input': TransformerInputWrapper, 983 'transformer_segment': TransformerSegmentWrapper, 984 'transformer_memory': TransformerMemoryWrapper, 985 # model wrapper 986 'target': TargetNetworkWrapper, 987 'teacher': TeacherNetworkWrapper, 988 'combination_argmax_sample': CombinationArgmaxSampleWrapper, 989 'combination_multinomial_sample': CombinationMultinomialSampleWrapper, 990} 991 992 993def model_wrap(model: Union[nn.Module, IModelWrapper], wrapper_name: str = None, **kwargs): 994 """ 995 Overview: 996 Wrap the model with the specified wrapper and return the wrappered model. 
997 Arguments: 998 - model (:obj:`Any`): The model to be wrapped. 999 - wrapper_name (:obj:`str`): The name of the wrapper to be used.10001001 .. note::1002 The arguments of the wrapper should be passed in as kwargs.1003 """1004 if wrapper_name in wrapper_name_map:1005 # TODO test whether to remove this if branch1006 if not isinstance(model, IModelWrapper):1007 model = wrapper_name_map['base'](model)1008 model = wrapper_name_map[wrapper_name](model, **kwargs)1009 else:1010 raise TypeError("not support model_wrapper type: {}".format(wrapper_name))1011 return model101210131014def register_wrapper(name: str, wrapper_type: type) -> None:1015 """1016 Overview:1017 Register new wrapper to ``wrapper_name_map``. When user implements a new wrapper, they must call this function \1018 to complete the registration. Then the wrapper can be called by ``model_wrap``.1019 Arguments:1020 - name (:obj:`str`): The name of the new wrapper to be registered.1021 - wrapper_type (:obj:`type`): The wrapper class needs to be added in ``wrapper_name_map``. This argument \1022 should be the subclass of ``IModelWrapper``.1023 """1024 assert isinstance(name, str)1025 assert issubclass(wrapper_type, IModelWrapper)1026 wrapper_name_map[name] = wrapper_type
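All of the wrappers above rely on the same delegation pattern documented in ``IModelWrapper``: a wrapper overrides ``forward`` (and sometimes ``reset``) while forwarding every other attribute lookup to the wrapped model through ``__getattr__``. The torch-free sketch below illustrates that pattern together with argmax action selection; the ``Toy*`` classes are hypothetical illustrations, not part of DI-engine, and logits are plain Python lists rather than tensors.

```python
from typing import Any, Dict, List


class ToyWrapper:
    """Minimal stand-in for IModelWrapper: delegate unknown attributes to the wrapped model."""

    def __init__(self, model: Any) -> None:
        self._model = model

    def __getattr__(self, key: str) -> Any:
        # __getattr__ is only called when normal lookup fails, so attributes
        # defined on the wrapper (like _model) resolve without recursion.
        return getattr(self._model, key)


class ToyArgmaxSampleWrapper(ToyWrapper):
    """Mimics ArgmaxSampleWrapper: pick the index of the largest logit as the action."""

    def forward(self, *args, **kwargs) -> Dict[str, Any]:
        output = self._model.forward(*args, **kwargs)
        logit: List[float] = output['logit']
        output['action'] = max(range(len(logit)), key=logit.__getitem__)  # argmax
        return output


class ToyModel:
    def forward(self, obs: List[float]) -> Dict[str, Any]:
        # A fake "network": the logits are just the observation itself.
        return {'logit': list(obs)}


model = ToyArgmaxSampleWrapper(ToyModel())
out = model.forward([0.1, 2.5, -0.3])
print(out['action'])  # -> 1, the index of the largest logit
```

In DI-engine itself the equivalent effect is obtained with ``model_wrap(model, wrapper_name='argmax_sample')``, which first wraps a bare ``nn.Module`` in ``BaseModelWrapper`` and then stacks the named wrapper on top.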