from typing import Union, Optional, Dict, Callable, List
import torch
import torch.nn as nn

from ding.torch_utils import get_lstm, one_hot, to_tensor, to_ndarray
from ding.utils import MODEL_REGISTRY, SequenceType, squeeze
# from ding.torch_utils.data_helper import one_hot_embedding, one_hot_embedding_none
from ..common import FCEncoder, ConvEncoder, DiscreteHead, DuelingHead, MultiHead, RainbowHead, \
    QuantileHead, QRDQNHead, DistributionHead


def parallel_wrapper(forward_fn: Callable) -> Callable:
    """
    Overview:
        Process timestep T and batch_size B at the same time, in other words, treat different timestep data as \
        different trajectories in a batch.
    Arguments:
        - forward_fn (:obj:`Callable`): Normal ``nn.Module`` 's forward function.
    Returns:
        - wrapper (:obj:`Callable`): Wrapped function.
    """

    def wrapper(x: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
        # Interpret the first two dims of ``x`` as (time, batch); fold them into a
        # single batch dim, run ``forward_fn`` once, then unfold the result back.
        T, B = x.shape[:2]

        def reshape(d):
            # Recursively restore the (T, B, ...) layout for tensors nested inside
            # lists/dicts (e.g. the output dict of a head module).
            if isinstance(d, list):
                d = [reshape(t) for t in d]
            elif isinstance(d, dict):
                d = {k: reshape(v) for k, v in d.items()}
            else:
                d = d.reshape(T, B, *d.shape[1:])
            return d

        x = x.reshape(T * B, *x.shape[2:])
        x = forward_fn(x)
        x = reshape(x)
        return x

    return wrapper


@MODEL_REGISTRY.register('ngu')
class NGU(nn.Module):
    """
    Overview:
        The recurrent Q model for NGU(https://arxiv.org/pdf/2002.06038.pdf) policy, modified from the class DRQN in \
        q_learning.py. The implementation mentioned in the original paper is 'adapt the R2D2 agent that uses the \
        dueling network architecture with an LSTM layer after a convolutional neural network'. The NGU network \
        includes encoder, LSTM core(rnn) and head.
    Interface:
        ``__init__``, ``forward``.
    """

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            collector_env_num: Optional[int] = 1,  # TODO
            dueling: bool = True,
            head_hidden_size: Optional[int] = None,
            head_layer_num: int = 1,
            lstm_type: Optional[str] = 'normal',
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None
    ) -> None:
        """
        Overview:
            Init the DRQN Model for NGU according to arguments. The LSTM input is the concatenation of the \
            encoded observation, the one-hot previous action, the previous extrinsic reward (scalar) and the \
            one-hot beta index, hence ``input_size = head_hidden_size + action_shape + 1 + collector_env_num``.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation's space, such as 8 or [4, 84, 84].
            - action_shape (:obj:`Union[int, SequenceType]`): Action's space, such as 6 or [2, 3, 3].
            - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``.
            - collector_env_num (:obj:`Optional[int]`): The number of environments used to collect data \
                simultaneously. Also used as the size of the one-hot beta-index embedding fed to the LSTM.
            - dueling (:obj:`bool`): Whether choose ``DuelingHead`` (True) or ``DiscreteHead (False)``, \
                default to True.
            - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to ``Head``, should match the \
                last element of ``encoder_hidden_size_list``. Defaults to that element when ``None``.
            - head_layer_num (:obj:`int`): The number of layers in head network.
            - lstm_type (:obj:`Optional[str]`): Version of rnn cell, now support ['normal', 'pytorch', 'hpc', \
                'gru'], default is 'normal'.
            - activation (:obj:`Optional[nn.Module]`):
                The type of activation function to use in ``MLP`` the after ``layer_fn``, \
                if ``None`` then default set to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`):
                The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details`.
        Raises:
            - RuntimeError: If ``obs_shape`` is neither 1-dim (vector) nor 3-dim (image), since no pre-defined \
                encoder matches it.
        """
        super(NGU, self).__init__()
        # For compatibility: 1, (1, ), [4, H, H]
        obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
        self.action_shape = action_shape
        self.collector_env_num = collector_env_num
        if head_hidden_size is None:
            head_hidden_size = encoder_hidden_size_list[-1]
        # FC Encoder
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
        # Conv Encoder
        elif len(obs_shape) == 3:
            self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
        else:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own DRQN".format(obs_shape)
            )
        # NOTE: current obs hidden_state_dim, previous action, previous extrinsic reward, beta
        # TODO(pu): add prev_reward_intrinsic to network input, reward uses some kind of embedding instead of 1D value
        input_size = head_hidden_size + action_shape + 1 + self.collector_env_num
        # LSTM Type
        self.rnn = get_lstm(lstm_type, input_size=input_size, hidden_size=head_hidden_size)
        # Head Type
        if dueling:
            head_cls = DuelingHead
        else:
            head_cls = DiscreteHead
        # A non-int (squeezed sequence) action_shape means a multi-discrete action space.
        multi_head = not isinstance(action_shape, int)
        if multi_head:
            self.head = MultiHead(
                head_cls,
                head_hidden_size,
                action_shape,
                layer_num=head_layer_num,
                activation=activation,
                norm_type=norm_type
            )
        else:
            self.head = head_cls(
                head_hidden_size, action_shape, head_layer_num, activation=activation, norm_type=norm_type
            )

    def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps: Optional[list] = None) -> Dict:
        """
        Overview:
            Forward computation graph of NGU R2D2 network. Input observation, prev_action and \
            prev_reward_extrinsic to predict NGU Q output. In inference mode a single timestep is processed \
            through one LSTM step; in training mode a whole sequence is unrolled step by step through the LSTM.
        Arguments:
            - inputs (:obj:`Dict`):
                - obs (:obj:`torch.Tensor`): Encoded observation.
                - prev_state (:obj:`list`): Previous state's tensor of size ``(B, N)``.
                - beta (:obj:`torch.Tensor`): Index of the intrinsic-reward weight, one-hot encoded to \
                  ``collector_env_num`` classes before being fed to the LSTM.
                - When ``prev_action`` is absent (train mode), ``action`` and ``reward`` sequences are used to \
                  build the shifted previous-action/previous-reward inputs instead.
            - inference: (:obj:'bool'): If inference is True, we unroll the one timestep transition, \
                if inference is False, we unroll the sequence transitions.
            - saved_state_timesteps: (:obj:'Optional[list]'): When inference is False, \
                we unroll the sequence transitions, then we would save rnn hidden states at timesteps \
                that are listed in list saved_state_timesteps.
        Returns:
            - outputs (:obj:`Dict`):
                Run ``MLP`` with ``DRQN`` setups and return the result prediction dictionary.

        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): Logit tensor with same size as input ``obs``.
            - next_state (:obj:`list`): Next state's tensor of size ``(B, N)``.
            - hidden_state (:obj:`torch.Tensor`): (train mode only) concatenated LSTM hidden states h of \
              every unrolled timestep.
            - saved_state (:obj:`list`): (train mode only, when ``saved_state_timesteps`` is given) the full \
              recurrent states (h and c) captured at the requested timesteps.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N=obs_space)`, where B is batch size.
            - prev_state(:obj:`torch.FloatTensor list`): :math:`[(B, N)]`.
            - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`.
            - next_state(:obj:`torch.FloatTensor list`): :math:`[(B, N)]`.
        """
        x, prev_state = inputs['obs'], inputs['prev_state']
        if 'prev_action' in inputs.keys():
            # collect, eval mode: pass into one timestep mini-batch data (batchsize=env_num)
            prev_action = inputs['prev_action']
            prev_reward_extrinsic = inputs['prev_reward_extrinsic']
        else:
            # train mode: pass into H timesteps mini-batch data (batchsize=train_batch_size)
            # Shift the action/reward sequences right by one step; the first timestep has no
            # predecessor, so a sentinel is used: action -1 and reward 0.
            # NOTE(review): the -1 sentinel is later passed through one_hot — verify the project's
            # one_hot helper handles out-of-range/negative indices as intended.
            prev_action = torch.cat(
                [torch.ones_like(inputs['action'][:, 0].unsqueeze(1)) * (-1), inputs['action'][:, :-1]], dim=1
            )  # (B, 1) (B, H-1) -> (B, H, self.action_shape)
            prev_reward_extrinsic = torch.cat(
                [torch.zeros_like(inputs['reward'][:, 0].unsqueeze(1)), inputs['reward'][:, :-1]], dim=1
            )  # (B, 1, nstep) (B, H-1, nstep) -> (B, H, nstep)
        beta = inputs['beta']  # beta_index
        if inference:
            # collect, eval mode: pass into one timestep mini-batch data (batchsize=env_num)
            x = self.encoder(x)
            # unsqueeze(0) adds a length-1 time dim expected by the LSTM.
            x = x.unsqueeze(0)
            prev_reward_extrinsic = prev_reward_extrinsic.unsqueeze(0).unsqueeze(-1)

            env_num = self.collector_env_num
            beta_onehot = one_hot(beta, env_num).unsqueeze(0)
            prev_action_onehot = one_hot(prev_action, self.action_shape).unsqueeze(0)
            x_a_r_beta = torch.cat(
                [x, prev_action_onehot, prev_reward_extrinsic, beta_onehot], dim=-1
            )  # shape (1, B, hidden_size+action_dim+1+env_num)
            x, next_state = self.rnn(x_a_r_beta.to(torch.float32), prev_state)
            # TODO(pu): x, next_state = self.rnn(x, prev_state)
            x = x.squeeze(0)
            x = self.head(x)
            x['next_state'] = next_state
            return x
        else:
            # train mode: pass into H timesteps mini-batch data (batchsize=train_batch_size)
            # 3-dim obs is (.., .., obs_dim) vector input, 5-dim is image input.
            # NOTE(review): the comments below label dim 0 as B, but the unroll loop iterates
            # dim 0 as time — confirm whether the layout here is (T, B, ...) or (B, H, ...).
            assert len(x.shape) in [3, 5], x.shape  # (B, H, obs_dim)
            x = parallel_wrapper(self.encoder)(x)  # (B, H, hidden_dim)
            # Keep only the first element of the nstep reward vector as a scalar reward input.
            prev_reward_extrinsic = prev_reward_extrinsic[:, :, 0].unsqueeze(-1)  # (B,H,1)
            env_num = self.collector_env_num
            beta_onehot = one_hot(beta.view(-1), env_num).view([beta.shape[0], beta.shape[1], -1])  # (B, H, env_num)
            prev_action_onehot = one_hot(prev_action.view(-1), self.action_shape).view(
                [prev_action.shape[0], prev_action.shape[1], -1]
            )  # (B, H, action_dim)
            x_a_r_beta = torch.cat(
                [x, prev_action_onehot, prev_reward_extrinsic, beta_onehot], dim=-1
            )  # (B, H, 1+env_num+action_dim)
            x = x_a_r_beta
            lstm_embedding = []
            # TODO(nyz) how to deal with hidden_size key-value
            hidden_state_list = []
            if saved_state_timesteps is not None:
                saved_state = []
            # Unroll the LSTM one timestep at a time so that per-step hidden states can be captured.
            for t in range(x.shape[0]):  # T timesteps
                output, prev_state = self.rnn(x[t:t + 1], prev_state)
                # saved_state_timesteps uses 1-based timestep indices (t + 1).
                if saved_state_timesteps is not None and t + 1 in saved_state_timesteps:
                    saved_state.append(prev_state)
                lstm_embedding.append(output)
                # only take the hidden state h
                hidden_state_list.append(torch.cat([item['h'] for item in prev_state], dim=1))

            x = torch.cat(lstm_embedding, 0)  # concat per-step LSTM outputs along the time dim
            x = parallel_wrapper(self.head)(x)
            # the last timestep state including the hidden state (h) and the cell state (c)
            x['next_state'] = prev_state
            x['hidden_state'] = torch.cat(hidden_state_list, dim=-3)
            if saved_state_timesteps is not None:
                # the selected saved hidden states, including the hidden state (h) and the cell state (c)
                x['saved_state'] = saved_state
            return x