ding.bonus.model
Full Source Code
../ding/bonus/model.py
from typing import Union, Optional
from easydict import EasyDict
import torch
import torch.nn as nn
import treetensor.torch as ttorch
from copy import deepcopy
from ding.utils import SequenceType, squeeze
from ding.model.common import ReparameterizationHead, RegressionHead, MultiHead, \
    FCEncoder, ConvEncoder, IMPALAConvEncoder, PopArtVHead
from ding.torch_utils import MLP, fc_block


class DiscretePolicyHead(nn.Module):
    """
    Policy head for discrete action spaces: an ``MLP`` trunk followed by a final
    fully-connected block that maps a hidden embedding to one logit per action.
    """

    def __init__(
        self,
        hidden_size: int,
        output_size: int,
        layer_num: int = 1,
        activation: Optional[nn.Module] = nn.ReLU(),
        norm_type: Optional[str] = None,
    ) -> None:
        """
        Arguments:
            - hidden_size (:obj:`int`): Input and hidden feature dimension of the MLP trunk.
            - output_size (:obj:`int`): Number of discrete actions, i.e. the logit dimension.
            - layer_num (:obj:`int`): Number of layers in the MLP trunk.
            - activation (:obj:`Optional[nn.Module]`): Activation module used inside the MLP.
              NOTE(review): a module *instance* as a default is shared across all heads built
              with the default — this mirrors the convention used throughout this file.
            - norm_type (:obj:`Optional[str]`): Normalization type forwarded to ``MLP``.
        """
        super(DiscretePolicyHead, self).__init__()
        self.main = nn.Sequential(
            MLP(
                hidden_size,
                hidden_size,
                hidden_size,
                layer_num,
                layer_fn=nn.Linear,
                activation=activation,
                norm_type=norm_type
            ), fc_block(hidden_size, output_size)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map embedding ``x`` to raw (unnormalized) action logits."""
        return self.main(x)


class PPOFModel(nn.Module):
    """
    Actor-critic model for the PPO (bonus/PPOF) pipeline. Supports discrete,
    continuous and hybrid action spaces, an optionally shared observation
    encoder, a custom user-supplied encoder, and an optional PopArt value head.
    """
    # Names of the valid forward modes dispatched by ``forward``.
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
        self,
        obs_shape: Union[int, SequenceType],
        action_shape: Union[int, SequenceType, EasyDict],
        action_space: str = 'discrete',
        share_encoder: bool = True,
        encoder_hidden_size_list: SequenceType = [128, 128, 64],
        actor_head_hidden_size: int = 64,
        actor_head_layer_num: int = 1,
        critic_head_hidden_size: int = 64,
        critic_head_layer_num: int = 1,
        activation: Optional[nn.Module] = nn.ReLU(),
        norm_type: Optional[str] = None,
        sigma_type: Optional[str] = 'independent',
        # FIX: default 0.3 is a float, so the annotation must be Optional[float]
        # (was Optional[int]); runtime behavior is unchanged.
        fixed_sigma_value: Optional[float] = 0.3,
        bound_type: Optional[str] = None,
        encoder: Optional[torch.nn.Module] = None,
        popart_head: bool = False,
    ) -> None:
        """
        Arguments:
            - obs_shape: Observation shape; an int / 1-dim shape selects ``FCEncoder``,
              a 3-dim shape selects ``ConvEncoder``; anything else raises ``RuntimeError``.
            - action_shape: Action shape. For ``'hybrid'`` it is an ``EasyDict`` with
              ``action_type_shape`` (discrete) and ``action_args_shape`` (continuous).
            - action_space (:obj:`str`): One of ``'discrete'``, ``'continuous'``, ``'hybrid'``.
            - share_encoder (:obj:`bool`): Whether actor and critic share one encoder;
              if so, actor and critic head hidden sizes must be equal.
            - encoder_hidden_size_list: Hidden sizes for the pre-defined encoder.
            - actor_head_hidden_size / actor_head_layer_num: Actor head dimensions.
            - critic_head_hidden_size / critic_head_layer_num: Critic head dimensions.
            - activation / norm_type: Activation module and normalization type for all sub-modules.
            - sigma_type / fixed_sigma_value / bound_type: ``ReparameterizationHead`` options
              for continuous (and hybrid ``action_args``) actions.
            - encoder: Optional user-supplied encoder module; when encoders are not shared
              it is deep-copied so actor and critic get independent weights.
            - popart_head (:obj:`bool`): Use ``PopArtVHead`` instead of ``RegressionHead``
              for the critic.
        """
        super(PPOFModel, self).__init__()
        obs_shape = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        self.obs_shape, self.action_shape = obs_shape, action_shape
        self.share_encoder = share_encoder

        # Encoder Type
        # NOTE(review): ``outsize`` is never used — the encoder's output size is
        # implicitly ``encoder_hidden_size_list[-1]``, which must match the head
        # hidden size for the defaults to be consistent. TODO confirm upstream.
        def new_encoder(outsize):
            if isinstance(obs_shape, int) or len(obs_shape) == 1:
                return FCEncoder(
                    obs_shape=obs_shape,
                    hidden_size_list=encoder_hidden_size_list,
                    activation=activation,
                    norm_type=norm_type
                )
            elif len(obs_shape) == 3:
                return ConvEncoder(
                    obs_shape=obs_shape,
                    hidden_size_list=encoder_hidden_size_list,
                    activation=activation,
                    norm_type=norm_type
                )
            else:
                raise RuntimeError(
                    "not support obs_shape for pre-defined encoder: {}, please customize your own encoder".
                    format(obs_shape)
                )

        if self.share_encoder:
            assert actor_head_hidden_size == critic_head_hidden_size, \
                "actor and critic network head should have same size."
            if encoder:
                if isinstance(encoder, torch.nn.Module):
                    self.encoder = encoder
                else:
                    raise ValueError("illegal encoder instance.")
            else:
                self.encoder = new_encoder(actor_head_hidden_size)
        else:
            if encoder:
                if isinstance(encoder, torch.nn.Module):
                    self.actor_encoder = encoder
                    # Deep copy so actor and critic encoders train independently.
                    self.critic_encoder = deepcopy(encoder)
                else:
                    raise ValueError("illegal encoder instance.")
            else:
                self.actor_encoder = new_encoder(actor_head_hidden_size)
                self.critic_encoder = new_encoder(critic_head_hidden_size)

        # Head Type
        if not popart_head:
            self.critic_head = RegressionHead(
                critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
            )
        else:
            self.critic_head = PopArtVHead(
                critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
            )

        self.action_space = action_space
        assert self.action_space in ['discrete', 'continuous', 'hybrid'], self.action_space
        if self.action_space == 'continuous':
            self.multi_head = False
            self.actor_head = ReparameterizationHead(
                actor_head_hidden_size,
                action_shape,
                actor_head_layer_num,
                sigma_type=sigma_type,
                activation=activation,
                norm_type=norm_type,
                bound_type=bound_type
            )
        elif self.action_space == 'discrete':
            actor_head_cls = DiscretePolicyHead
            # A non-int action_shape means several independent discrete sub-actions.
            multi_head = not isinstance(action_shape, int)
            self.multi_head = multi_head
            if multi_head:
                self.actor_head = MultiHead(
                    actor_head_cls,
                    actor_head_hidden_size,
                    action_shape,
                    layer_num=actor_head_layer_num,
                    activation=activation,
                    norm_type=norm_type
                )
            else:
                self.actor_head = actor_head_cls(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    activation=activation,
                    norm_type=norm_type
                )
        elif self.action_space == 'hybrid':  # HPPO
            # hybrid action space: action_type(discrete) + action_args(continuous),
            # such as {'action_type_shape': torch.LongTensor([0]), 'action_args_shape': torch.FloatTensor([0.1, -0.27])}
            action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
            action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
            actor_action_args = ReparameterizationHead(
                actor_head_hidden_size,
                action_shape.action_args_shape,
                actor_head_layer_num,
                sigma_type=sigma_type,
                fixed_sigma_value=fixed_sigma_value,
                activation=activation,
                norm_type=norm_type,
                bound_type=bound_type,
            )
            actor_action_type = DiscretePolicyHead(
                actor_head_hidden_size,
                action_shape.action_type_shape,
                actor_head_layer_num,
                activation=activation,
                norm_type=norm_type,
            )
            self.actor_head = nn.ModuleList([actor_action_type, actor_action_args])

        # must use list, not nn.ModuleList
        if self.share_encoder:
            self.actor = [self.encoder, self.actor_head]
            self.critic = [self.encoder, self.critic_head]
        else:
            self.actor = [self.actor_encoder, self.actor_head]
            self.critic = [self.critic_encoder, self.critic_head]
        # Convenient for calling some apis (e.g. self.critic.parameters()),
        # but may cause misunderstanding when `print(self)`
        self.actor = nn.ModuleList(self.actor)
        self.critic = nn.ModuleList(self.critic)

    def forward(self, inputs: ttorch.Tensor, mode: str) -> ttorch.Tensor:
        """Dispatch to one of the modes listed in ``self.mode``."""
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, x: ttorch.Tensor) -> ttorch.Tensor:
        """
        Encode observation ``x`` and return the actor output: logits for
        discrete actions, (mu, sigma) for continuous, or a dict with
        ``action_type`` / ``action_args`` for hybrid.
        """
        if self.share_encoder:
            x = self.encoder(x)
        else:
            x = self.actor_encoder(x)

        if self.action_space == 'discrete':
            return self.actor_head(x)
        elif self.action_space == 'continuous':
            x = self.actor_head(x)  # mu, sigma
            return ttorch.as_tensor(x)
        elif self.action_space == 'hybrid':
            action_type = self.actor_head[0](x)
            action_args = self.actor_head[1](x)
            return ttorch.as_tensor({'action_type': action_type, 'action_args': action_args})

    def compute_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
        """
        Encode observation ``x`` and return the raw critic-head output
        (a dict containing the value prediction under 'pred').
        """
        if self.share_encoder:
            x = self.encoder(x)
        else:
            x = self.critic_encoder(x)
        x = self.critic_head(x)
        return x

    def compute_actor_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
        """
        Single pass producing both actor output ('logit') and value estimate
        ('value'); with a shared encoder the embedding is computed only once.
        """
        if self.share_encoder:
            actor_embedding = critic_embedding = self.encoder(x)
        else:
            actor_embedding = self.actor_encoder(x)
            critic_embedding = self.critic_encoder(x)

        value = self.critic_head(critic_embedding)

        if self.action_space == 'discrete':
            logit = self.actor_head(actor_embedding)
            return ttorch.as_tensor({'logit': logit, 'value': value['pred']})
        elif self.action_space == 'continuous':
            x = self.actor_head(actor_embedding)
            return ttorch.as_tensor({'logit': x, 'value': value['pred']})
        elif self.action_space == 'hybrid':
            action_type = self.actor_head[0](actor_embedding)
            action_args = self.actor_head[1](actor_embedding)
            return ttorch.as_tensor(
                {
                    'logit': {
                        'action_type': action_type,
                        'action_args': action_args
                    },
                    'value': value['pred']
                }
            )