from typing import Union, Tuple, List, Dict
from easydict import EasyDict

import random
import torch
import torch.nn as nn
import torch.optim as optim

from ding.utils import SequenceType, REWARD_MODEL_REGISTRY
from ding.model import FCEncoder, ConvEncoder
from ding.torch_utils import one_hot
from .base_reward_model import BaseRewardModel


def collect_states(iterator: list) -> Tuple[list, list, list]:
    """
    Overview:
        Split a list of transition dicts into parallel lists of observations,
        next observations and actions.
    Arguments:
        - iterator (:obj:`list`): Transitions, each a dict with keys \
            ``obs``, ``next_obs`` and ``action``.
    Returns:
        - states (:obj:`list`): The ``obs`` entries, in order.
        - next_states (:obj:`list`): The ``next_obs`` entries, in order.
        - actions (:obj:`list`): The ``action`` entries, in order.
    """
    states = [item['obs'] for item in iterator]
    next_states = [item['next_obs'] for item in iterator]
    actions = [item['action'] for item in iterator]
    return states, next_states, actions


class ICMNetwork(nn.Module):
    """
    Intrinsic Curiosity Model (ICM Module)
    Implementation of:
    [1] Curiosity-driven Exploration by Self-supervised Prediction
    Pathak, Agrawal, Efros, and Darrell - UC Berkeley - ICML 2017.
    https://arxiv.org/pdf/1705.05363.pdf
    [2] Code implementation reference:
    https://github.com/pathak22/noreward-rl
    https://github.com/jcwleo/curiosity-driven-exploration-pytorch

    1) Embedding observations into a latent space
    2) Predicting the action logit given two consecutive embedded observations
    3) Predicting the next embedded obs, given the embedded former observation and action
    """

    def __init__(self, obs_shape: Union[int, SequenceType], hidden_size_list: SequenceType, action_shape: int) -> None:
        """
        Overview:
            Build the feature encoder, the inverse model and the residual forward model.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation shape; an int or \
                1-dim shape selects an ``FCEncoder``, a 3-dim shape selects a ``ConvEncoder``.
            - hidden_size_list (:obj:`SequenceType`): Hidden sizes of the encoder; the last \
                entry is the embedded feature size.
            - action_shape (:obj:`int`): Number of discrete actions.
        Raises:
            - KeyError: If ``obs_shape`` matches no pre-defined encoder.
        """
        super(ICMNetwork, self).__init__()
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            self.feature = FCEncoder(obs_shape, hidden_size_list)
        elif len(obs_shape) == 3:
            self.feature = ConvEncoder(obs_shape, hidden_size_list)
        else:
            raise KeyError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own ICM model".
                format(obs_shape)
            )
        self.action_shape = action_shape
        feature_output = hidden_size_list[-1]
        # inverse model: (phi(s), phi(s')) -> action logit
        self.inverse_net = nn.Sequential(nn.Linear(feature_output * 2, 512), nn.ReLU(), nn.Linear(512, action_shape))
        # 8 residual layers, consumed in 4 pairs by ``forward`` (see the loop there)
        self.residual = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(action_shape + 512, 512),
                    nn.LeakyReLU(),
                    nn.Linear(512, 512),
                ) for _ in range(8)
            ]
        )
        # forward model head/tail around the residual stack
        self.forward_net_1 = nn.Sequential(nn.Linear(action_shape + feature_output, 512), nn.LeakyReLU())
        self.forward_net_2 = nn.Linear(action_shape + 512, feature_output)

    def forward(self, state: torch.Tensor, next_state: torch.Tensor,
                action_long: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""
        Overview:
            Use observation, next_observation and action to generate ICM module
            Parameter updates with ICMNetwork forward setup.
        Arguments:
            - state (:obj:`torch.Tensor`):
                The current state batch
            - next_state (:obj:`torch.Tensor`):
                The next state batch
            - action_long (:obj:`torch.Tensor`):
                The action batch
        Returns:
            - real_next_state_feature (:obj:`torch.Tensor`):
                Run with the encoder. Return the real next_state's embedded feature.
            - pred_next_state_feature (:obj:`torch.Tensor`):
                Run with the encoder and residual network. Return the predicted next_state's embedded feature.
            - pred_action_logit (:obj:`torch.Tensor`):
                Run with the encoder. Return the predicted action logit.
        Shapes:
            - state (:obj:`torch.Tensor`): :math:`(B, N)`, where B is the batch size and N is ``obs_shape``
            - next_state (:obj:`torch.Tensor`): :math:`(B, N)`, where B is the batch size and N is ``obs_shape``
            - action_long (:obj:`torch.Tensor`): :math:`(B)`, where B is the batch size
            - real_next_state_feature (:obj:`torch.Tensor`): :math:`(B, M)`, where B is the batch size
              and M is embedded feature size
            - pred_next_state_feature (:obj:`torch.Tensor`): :math:`(B, M)`, where B is the batch size
              and M is embedded feature size
            - pred_action_logit (:obj:`torch.Tensor`): :math:`(B, A)`, where B is the batch size
              and A is the ``action_shape``
        """
        action = one_hot(action_long, num=self.action_shape)
        encode_state = self.feature(state)
        encode_next_state = self.feature(next_state)
        # inverse model: predict which action connected the two embeddings
        concat_state = torch.cat((encode_state, encode_next_state), 1)
        pred_action_logit = self.inverse_net(concat_state)

        # forward model: predict the next embedding from (phi(s), a)
        pred_next_state_feature_orig = torch.cat((encode_state, action), 1)
        pred_next_state_feature_orig = self.forward_net_1(pred_next_state_feature_orig)

        # 4 residual pairs; the action is re-concatenated before every residual layer
        for i in range(4):
            pred_next_state_feature = self.residual[i * 2](torch.cat((pred_next_state_feature_orig, action), 1))
            pred_next_state_feature_orig = self.residual[i * 2 + 1](
                torch.cat((pred_next_state_feature, action), 1)
            ) + pred_next_state_feature_orig
        pred_next_state_feature = self.forward_net_2(torch.cat((pred_next_state_feature_orig, action), 1))
        real_next_state_feature = encode_next_state
        return real_next_state_feature, pred_next_state_feature, pred_action_logit


@REWARD_MODEL_REGISTRY.register('icm')
class ICMRewardModel(BaseRewardModel):
    """
    Overview:
        The ICM reward model class (https://arxiv.org/pdf/1705.05363.pdf)
    Interface:
        ``estimate``, ``train``, ``collect_data``, ``clear_data``, \
        ``__init__``, ``_train``, ``load_state_dict``, ``state_dict``
    Config:
        == ==================== ======== ============= ==================================== =======================
        ID Symbol               Type     Default Value Description                          Other(Shape)
        == ==================== ======== ============= ==================================== =======================
        1  ``type``             str      icm           | Reward model register name,       |
                                                       | refer to registry                 |
                                                       | ``REWARD_MODEL_REGISTRY``         |
        2  | ``intrinsic_``     str      add           | the intrinsic reward type         | including add, new
           | ``reward_type``                           |                                   | , or assign
        3  | ``learning_rate``  float    0.001         | The step size of gradient descent |
        4  | ``obs_shape``      Tuple(   6             | the observation shape             |
                                [int,
                                list])
        5  | ``action_shape``   int      7             | the action space shape            |
        6  | ``batch_size``     int      64            | Training batch size               |
        7  | ``hidden``         list     [64, 64,      | the MLP layer shape               |
           | ``_size_list``     (int)    128]          |                                   |
        8  | ``update_per_``    int      100           | Number of updates per collect     |
           | ``collect``                               |                                   |
        9  | ``reverse_scale``  float    1             | the importance weight of the      |
                                                       | forward and reverse loss          |
        10 | ``intrinsic_``     float    0.003         | the weight of intrinsic reward    | r = w*r_i + r_e
             ``reward_weight``
        11 | ``extrinsic_``     bool     True          | Whether to normlize
             ``reward_norm``                           | extrinsic reward
        12 | ``extrinsic_``     int      1             | the upper bound of the reward
             ``reward_norm_max``                       | normalization
        13 | ``clear_buffer``   int      1             | clear buffer per fixed iters      | make sure replay
             ``_per_iters``                                                                | buffer's data count
                                                                                           | isn't too few.
                                                                                           | (code work in entry)
        == ==================== ======== ============= ==================================== =======================
    """
    config = dict(
        # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
        type='icm',
        # (str) The intrinsic reward type, including add, new, or assign.
        intrinsic_reward_type='add',
        # (float) The step size of gradient descent.
        learning_rate=1e-3,
        # (Tuple[int, list]), The observation shape.
        obs_shape=6,
        # (int) The action shape, support discrete action only in this version.
        action_shape=7,
        # (float) Batch size.
        batch_size=64,
        # (list) The MLP layer shape.
        hidden_size_list=[64, 64, 128],
        # (int) How many updates(iterations) to train after collector's one collection.
        # Bigger "update_per_collect" means bigger off-policy.
        # collect data -> update policy-> collect data -> ...
        update_per_collect=100,
        # (float) The importance weight of the forward and reverse loss.
        reverse_scale=1,
        # (float) The weight of intrinsic reward.
        # r = intrinsic_reward_weight * r_i + r_e.
        intrinsic_reward_weight=0.003,  # 1/300
        # (bool) Whether to normlize extrinsic reward.
        # Normalize the reward to [0, extrinsic_reward_norm_max].
        extrinsic_reward_norm=True,
        # (int) The upper bound of the reward normalization.
        extrinsic_reward_norm_max=1,
        # (int) Clear buffer per fixed iters.
        clear_buffer_per_iters=100,
    )

    def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None:  # noqa
        """
        Overview:
            Initialize the ICM reward model: build the network, the optimizer and the
            losses, and set up internal transition buffers.
        Arguments:
            - config (:obj:`EasyDict`): Reward model config, see ``ICMRewardModel.config``.
            - device (:obj:`str`): ``"cpu"`` or a ``"cuda"``-prefixed device string.
            - tb_logger (:obj:`SummaryWriter`): TensorBoard logger for training/estimation scalars.
        """
        super(ICMRewardModel, self).__init__()
        self.cfg = config
        assert device == "cpu" or device.startswith("cuda")
        self.device = device
        self.tb_logger = tb_logger
        self.reward_model = ICMNetwork(config.obs_shape, config.hidden_size_list, config.action_shape)
        self.reward_model.to(self.device)
        self.intrinsic_reward_type = config.intrinsic_reward_type
        assert self.intrinsic_reward_type in ['add', 'new', 'assign']
        # raw transitions and their unpacked components, filled by ``collect_data``
        self.train_data = []
        self.train_states = []
        self.train_next_states = []
        self.train_actions = []
        self.opt = optim.Adam(self.reward_model.parameters(), config.learning_rate)
        self.ce = nn.CrossEntropyLoss(reduction="mean")
        # per-element MSE; reduced manually in ``_train`` / ``estimate``
        self.forward_mse = nn.MSELoss(reduction='none')
        self.reverse_scale = config.reverse_scale
        self.res = nn.Softmax(dim=-1)
        self.estimate_cnt_icm = 0
        self.train_cnt_icm = 0

    def _train(self) -> None:
        """
        Overview:
            Sample one mini-batch from the stored transitions and run a single
            gradient step on the combined inverse + forward ICM loss.
        """
        self.train_cnt_icm += 1
        # NOTE: raises ValueError if fewer than ``batch_size`` transitions are stored.
        train_data_index = random.sample(range(len(self.train_states)), self.cfg.batch_size)
        data_states: torch.Tensor = torch.stack([self.train_states[i] for i in train_data_index]).to(self.device)
        data_next_states: torch.Tensor = torch.stack([self.train_next_states[i]
                                                      for i in train_data_index]).to(self.device)
        data_actions: torch.Tensor = torch.cat([self.train_actions[i] for i in train_data_index]).to(self.device)

        real_next_state_feature, pred_next_state_feature, pred_action_logit = self.reward_model(
            data_states, data_next_states, data_actions
        )
        inverse_loss = self.ce(pred_action_logit, data_actions.long())
        # detach the target embedding so the forward loss does not train the encoder twice
        forward_loss = self.forward_mse(pred_next_state_feature, real_next_state_feature.detach()).mean()
        self.tb_logger.add_scalar('icm_reward/forward_loss', forward_loss, self.train_cnt_icm)
        self.tb_logger.add_scalar('icm_reward/inverse_loss', inverse_loss, self.train_cnt_icm)
        action = torch.argmax(self.res(pred_action_logit), -1)
        accuracy = torch.sum(action == data_actions.squeeze(-1)).item() / data_actions.shape[0]
        self.tb_logger.add_scalar('icm_reward/action_accuracy', accuracy, self.train_cnt_icm)
        # BUGFIX: the original computed this sum twice; compute once and reuse.
        loss = self.reverse_scale * inverse_loss + forward_loss
        self.tb_logger.add_scalar('icm_reward/total_loss', loss, self.train_cnt_icm)
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

    def train(self) -> None:
        """
        Overview:
            Run ``update_per_collect`` training iterations on the collected data.
        """
        for _ in range(self.cfg.update_per_collect):
            self._train()

    def estimate(self, data: list) -> List[Dict]:
        """
        Overview:
            Compute the intrinsic (curiosity) reward for a batch of transitions and
            combine it with the extrinsic reward according to ``intrinsic_reward_type``.
        Arguments:
            - data (:obj:`list`): Transitions with keys ``obs``, ``next_obs``, ``action``, ``reward``.
        Returns:
            - train_data_augmented (:obj:`List[Dict]`): A reward-deepcopied version of \
                ``data`` with the ``reward`` (and possibly ``intrinsic_reward``) fields updated.
        """
        # NOTE: deepcopy reward part of data is very important,
        # otherwise the reward of data in the replay buffer will be incorrectly modified.
        train_data_augmented = self.reward_deepcopy(data)
        states, next_states, actions = collect_states(train_data_augmented)
        states = torch.stack(states).to(self.device)
        next_states = torch.stack(next_states).to(self.device)
        actions = torch.cat(actions).to(self.device)
        with torch.no_grad():
            real_next_state_feature, pred_next_state_feature, _ = self.reward_model(states, next_states, actions)
            # prediction error of the forward model, per sample
            raw_icm_reward = self.forward_mse(real_next_state_feature, pred_next_state_feature).mean(dim=1)
        self.estimate_cnt_icm += 1
        self.tb_logger.add_scalar('icm_reward/raw_icm_reward_max', raw_icm_reward.max(), self.estimate_cnt_icm)
        self.tb_logger.add_scalar('icm_reward/raw_icm_reward_mean', raw_icm_reward.mean(), self.estimate_cnt_icm)
        self.tb_logger.add_scalar('icm_reward/raw_icm_reward_min', raw_icm_reward.min(), self.estimate_cnt_icm)
        self.tb_logger.add_scalar('icm_reward/raw_icm_reward_std', raw_icm_reward.std(), self.estimate_cnt_icm)
        # min-max normalize the batch to [0, 1] (BUGFIX: original computed this twice
        # and re-moved an already-on-device tensor to the device).
        icm_reward = (raw_icm_reward - raw_icm_reward.min()) / (raw_icm_reward.max() - raw_icm_reward.min() + 1e-8)
        self.tb_logger.add_scalar('icm_reward/icm_reward_max', icm_reward.max(), self.estimate_cnt_icm)
        self.tb_logger.add_scalar('icm_reward/icm_reward_mean', icm_reward.mean(), self.estimate_cnt_icm)
        self.tb_logger.add_scalar('icm_reward/icm_reward_min', icm_reward.min(), self.estimate_cnt_icm)
        self.tb_logger.add_scalar('icm_reward/icm_reward_std', icm_reward.std(), self.estimate_cnt_icm)
        for item, icm_rew in zip(train_data_augmented, icm_reward):
            if self.intrinsic_reward_type == 'add':
                if self.cfg.extrinsic_reward_norm:
                    item['reward'] = item[
                        'reward'] / self.cfg.extrinsic_reward_norm_max + icm_rew * self.cfg.intrinsic_reward_weight
                else:
                    item['reward'] = item['reward'] + icm_rew * self.cfg.intrinsic_reward_weight
            elif self.intrinsic_reward_type == 'new':
                item['intrinsic_reward'] = icm_rew
                if self.cfg.extrinsic_reward_norm:
                    item['reward'] = item['reward'] / self.cfg.extrinsic_reward_norm_max
            elif self.intrinsic_reward_type == 'assign':
                item['reward'] = icm_rew

        return train_data_augmented

    def collect_data(self, data: list) -> None:
        """
        Overview:
            Append a batch of transitions to the internal training buffers.
        Arguments:
            - data (:obj:`list`): Transitions with keys ``obs``, ``next_obs``, ``action``.
        """
        states, next_states, actions = collect_states(data)
        # BUGFIX: the original extended ``train_data`` with the 3-tuple of lists
        # returned by ``collect_states`` (and called it a second time for the
        # unpacked buffers); store the raw transitions and unpack only once.
        self.train_data.extend(data)
        self.train_states.extend(states)
        self.train_next_states.extend(next_states)
        self.train_actions.extend(actions)

    def clear_data(self) -> None:
        """
        Overview:
            Drop all buffered transitions (called periodically, see ``clear_buffer_per_iters``).
        """
        self.train_data.clear()
        self.train_states.clear()
        self.train_next_states.clear()
        self.train_actions.clear()

    def state_dict(self) -> Dict:
        """
        Overview:
            Return the state dict of the underlying ICM network.
        """
        return self.reward_model.state_dict()

    def load_state_dict(self, _state_dict: Dict) -> None:
        """
        Overview:
            Load a state dict into the underlying ICM network.
        """
        self.reward_model.load_state_dict(_state_dict)