Skip to content

ding.reward_model.rnd_reward_model

ding.reward_model.rnd_reward_model

RndRewardModel

Bases: BaseRewardModel

Overview

The RND reward model class (https://arxiv.org/abs/1810.12894v1)

Interface: ``estimate``, ``train``, ``collect_data``, ``clear_data``, ``__init__``, ``_train``, ``load_state_dict``, ``state_dict``

Config:

== ==================== ===== ============= ======================================= =======================
ID Symbol               Type  Default Value Description                             Other(Shape)
== ==================== ===== ============= ======================================= =======================
1  type                 str   rnd           Reward model register name, refer
                                            to registry REWARD_MODEL_REGISTRY
2  intrinsic_           str   add           the intrinsic reward type               including add, new,
   reward_type                                                                      or assign
3  learning_rate        float 0.001         The step size of gradient descent
4  batch_size           int   64            Training batch size
5  hidden_size_list     list  [64, 64,      the MLP layer shape
                        (int) 128]
6  update_per_collect   int   100           Number of updates per collect
7  obs_norm             bool  True          Observation normalization
8  obs_norm_clamp_min   int   -1            min clip value for obs normalization
9  obs_norm_clamp_max   int   1             max clip value for obs normalization
10 intrinsic_           float 0.01          the weight of intrinsic reward          r = w*r_i + r_e
   reward_weight
11 extrinsic_           bool  True          Whether to normalize extrinsic reward
   reward_norm
12 extrinsic_           int   1             the upper bound of the reward
   reward_norm_max                          normalization
== ==================== ===== ============= ======================================= =======================

estimate(data)

Rewrite the reward key in each row of the data.

Full Source Code

../ding/reward_model/rnd_reward_model.py

from typing import Union, Tuple, List, Dict
from easydict import EasyDict

import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from ding.utils import SequenceType, REWARD_MODEL_REGISTRY
from ding.model import FCEncoder, ConvEncoder
from .base_reward_model import BaseRewardModel
from ding.utils import RunningMeanStd
from ding.torch_utils.data_helper import to_tensor
import numpy as np


def collect_states(iterator) -> list:
    """
    Overview:
        Extract the ``'obs'`` entry from every transition in ``iterator``.
    Arguments:
        - iterator: An iterable of dict-like transitions, each containing an ``'obs'`` key.
    Returns:
        - (:obj:`list`): The observations, in iteration order.
    """
    return [item['obs'] for item in iterator]


class RndNetwork(nn.Module):
    """
    Overview:
        The RND network pair: a fixed, randomly initialized ``target`` encoder and a trainable
        ``predictor`` encoder. The predictor's error in matching the target's features on an
        observation is used as the novelty (intrinsic reward) signal.
    """

    def __init__(self, obs_shape: Union[int, SequenceType], hidden_size_list: SequenceType) -> None:
        """
        Overview:
            Build target and predictor encoders whose architecture depends on the observation shape.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation shape; int or 1-dim shape \
                selects an MLP encoder, a 3-dim shape (image) selects a conv encoder.
            - hidden_size_list (:obj:`SequenceType`): Hidden layer sizes passed to the encoder.
        Raises:
            - KeyError: If ``obs_shape`` matches neither pre-defined encoder.
        """
        super(RndNetwork, self).__init__()
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            self.target = FCEncoder(obs_shape, hidden_size_list)
            self.predictor = FCEncoder(obs_shape, hidden_size_list)
        elif len(obs_shape) == 3:
            self.target = ConvEncoder(obs_shape, hidden_size_list)
            self.predictor = ConvEncoder(obs_shape, hidden_size_list)
        else:
            raise KeyError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own RND model".
                format(obs_shape)
            )
        # The target network is frozen: only the predictor is trained to imitate it.
        for param in self.target.parameters():
            param.requires_grad = False

    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Encode ``obs`` with both networks. Target features are computed under ``no_grad``
            since the target is never updated.
        Returns:
            - (:obj:`Tuple[torch.Tensor, torch.Tensor]`): ``(predict_feature, target_feature)``.
        """
        predict_feature = self.predictor(obs)
        with torch.no_grad():
            target_feature = self.target(obs)
        return predict_feature, target_feature


@REWARD_MODEL_REGISTRY.register('rnd')
class RndRewardModel(BaseRewardModel):
    """
    Overview:
        The RND reward model class (https://arxiv.org/abs/1810.12894v1)
    Interface:
        ``estimate``, ``train``, ``collect_data``, ``clear_data``, \
        ``__init__``, ``_train``, ``load_state_dict``, ``state_dict``
    Config:
        == ==================== ===== ============= ======================================= =======================
        ID Symbol               Type  Default Value Description                             Other(Shape)
        == ==================== ===== ============= ======================================= =======================
        1  ``type``             str   rnd           | Reward model register name, refer     |
                                                    | to registry ``REWARD_MODEL_REGISTRY`` |
        2  | ``intrinsic_``     str   add           | the intrinsic reward type             | including add, new
           | ``reward_type``                        |                                       | , or assign
        3  | ``learning_rate``  float 0.001         | The step size of gradient descent     |
        4  | ``batch_size``     int   64            | Training batch size                   |
        5  | ``hidden``         list  [64, 64,      | the MLP layer shape                   |
           | ``_size_list``     (int) 128]          |                                       |
        6  | ``update_per_``    int   100           | Number of updates per collect         |
           | ``collect``                            |                                       |
        7  | ``obs_norm``       bool  True          | Observation normalization             |
        8  | ``obs_norm_``      int   -1            | min clip value for obs normalization  |
           | ``clamp_min``
        9  | ``obs_norm_``      int   1             | max clip value for obs normalization  |
           | ``clamp_max``
        10 | ``intrinsic_``     float 0.01          | the weight of intrinsic reward        | r = w*r_i + r_e
           ``reward_weight``
        11 | ``extrinsic_``     bool  True          | Whether to normalize extrinsic reward
           ``reward_norm``
        12 | ``extrinsic_``     int   1             | the upper bound of the reward
           ``reward_norm_max``                      | normalization
        == ==================== ===== ============= ======================================= =======================
    """
    config = dict(
        # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
        type='rnd',
        # (str) The intrinsic reward type, including add, new, or assign.
        intrinsic_reward_type='add',
        # (float) The step size of gradient descent.
        learning_rate=1e-3,
        # (float) Batch size.
        batch_size=64,
        # (list(int)) Sequence of ``hidden_size`` of reward network.
        # If obs.shape == 1, use MLP layers.
        # If obs.shape == 3, use conv layer and final dense layer.
        hidden_size_list=[64, 64, 128],
        # (int) How many updates(iterations) to train after collector's one collection.
        # Bigger "update_per_collect" means bigger off-policy.
        # collect data -> update policy-> collect data -> ...
        update_per_collect=100,
        # (bool) Observation normalization: transform obs to mean 0, std 1.
        obs_norm=True,
        # (int) Min clip value for observation normalization.
        obs_norm_clamp_min=-1,
        # (int) Max clip value for observation normalization.
        obs_norm_clamp_max=1,
        # Means the relative weight of RND intrinsic_reward.
        # (float) The weight of intrinsic reward
        # r = intrinsic_reward_weight * r_i + r_e.
        intrinsic_reward_weight=0.01,
        # (bool) Whether to normalize extrinsic reward.
        # Normalize the reward to [0, extrinsic_reward_norm_max].
        extrinsic_reward_norm=True,
        # (int) The upper bound of the reward normalization.
        extrinsic_reward_norm_max=1,
    )

    def __init__(self, config: EasyDict, device: str = 'cpu', tb_logger: 'SummaryWriter' = None) -> None:  # noqa
        """
        Overview:
            Initialize the RND reward model: build the RND network, the optimizer for the
            predictor only (the target is frozen), and the running statistics used for
            observation and reward normalization.
        Arguments:
            - config (:obj:`EasyDict`): Config dict, see ``RndRewardModel.config``; must also \
                contain ``obs_shape``.
            - device (:obj:`str`): ``"cpu"`` or a ``"cuda"`` device string.
            - tb_logger (:obj:`SummaryWriter`): Tensorboard logger; a default one is created \
                if ``None``.
        """
        super(RndRewardModel, self).__init__()
        self.cfg = config
        assert device == "cpu" or device.startswith("cuda")
        self.device = device
        if tb_logger is None:  # TODO
            from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter('rnd_reward_model')
        self.tb_logger = tb_logger
        self.reward_model = RndNetwork(config.obs_shape, config.hidden_size_list)
        self.reward_model.to(self.device)
        self.intrinsic_reward_type = config.intrinsic_reward_type
        assert self.intrinsic_reward_type in ['add', 'new', 'assign']
        self.train_obs = []
        # Only the predictor is optimized; the target network stays at its random init.
        self.opt = optim.Adam(self.reward_model.predictor.parameters(), config.learning_rate)
        self._running_mean_std_rnd_reward = RunningMeanStd(epsilon=1e-4)
        self.estimate_cnt_rnd = 0
        self.train_cnt_icm = 0
        self._running_mean_std_rnd_obs = RunningMeanStd(epsilon=1e-4)

    def _train(self) -> None:
        """
        Overview:
            One gradient step of the predictor on a random mini-batch of collected observations.
        .. note::
            ``random.sample`` raises ``ValueError`` if fewer than ``batch_size`` observations
            have been collected; callers should collect enough data before training.
        """
        train_data: list = random.sample(self.train_obs, self.cfg.batch_size)
        train_data: torch.Tensor = torch.stack(train_data).to(self.device)
        if self.cfg.obs_norm:
            # Note: observation normalization: transform obs to mean 0, std 1
            self._running_mean_std_rnd_obs.update(train_data.cpu().numpy())
            train_data = (train_data - to_tensor(self._running_mean_std_rnd_obs.mean).to(self.device)) / to_tensor(
                self._running_mean_std_rnd_obs.std
            ).to(self.device)
            train_data = torch.clamp(train_data, min=self.cfg.obs_norm_clamp_min, max=self.cfg.obs_norm_clamp_max)

        predict_feature, target_feature = self.reward_model(train_data)
        loss = F.mse_loss(predict_feature, target_feature.detach())
        self.tb_logger.add_scalar('rnd_reward/loss', loss, self.train_cnt_icm)
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

    def train(self) -> None:
        """
        Overview:
            Run ``update_per_collect`` training iterations on the collected observations.
        """
        for _ in range(self.cfg.update_per_collect):
            self._train()
            self.train_cnt_icm += 1

    def estimate(self, data: list) -> List[Dict]:
        """
        Rewrite the reward key in each row of the data.
        """
        # NOTE: deepcopy reward part of data is very important,
        # otherwise the reward of data in the replay buffer will be incorrectly modified.
        train_data_augmented = self.reward_deepcopy(data)

        obs = collect_states(train_data_augmented)
        obs = torch.stack(obs).to(self.device)
        if self.cfg.obs_norm:
            # Note: observation normalization: transform obs to mean 0, std 1
            obs = (obs - to_tensor(self._running_mean_std_rnd_obs.mean
                                   ).to(self.device)) / to_tensor(self._running_mean_std_rnd_obs.std).to(self.device)
            obs = torch.clamp(obs, min=self.cfg.obs_norm_clamp_min, max=self.cfg.obs_norm_clamp_max)

        with torch.no_grad():
            predict_feature, target_feature = self.reward_model(obs)
            # Per-sample prediction error is the raw novelty signal.
            mse = F.mse_loss(predict_feature, target_feature, reduction='none').mean(dim=1)
            self._running_mean_std_rnd_reward.update(mse.cpu().numpy())

            # Note: according to the min-max normalization, transform rnd reward to [0,1]
            rnd_reward = (mse - mse.min()) / (mse.max() - mse.min() + 1e-8)

        # save the rnd_reward statistics into tb_logger
        self.estimate_cnt_rnd += 1
        self.tb_logger.add_scalar('rnd_reward/rnd_reward_max', rnd_reward.max(), self.estimate_cnt_rnd)
        self.tb_logger.add_scalar('rnd_reward/rnd_reward_mean', rnd_reward.mean(), self.estimate_cnt_rnd)
        self.tb_logger.add_scalar('rnd_reward/rnd_reward_min', rnd_reward.min(), self.estimate_cnt_rnd)
        self.tb_logger.add_scalar('rnd_reward/rnd_reward_std', rnd_reward.std(), self.estimate_cnt_rnd)

        rnd_reward = rnd_reward.to(self.device)
        # Split into one 1-element tensor per transition so rewards keep their tensor shape.
        rnd_reward = torch.chunk(rnd_reward, rnd_reward.shape[0], dim=0)
        # NOTE: Following normalization approach to extrinsic reward seems be not reasonable,
        # because this approach compresses the extrinsic reward magnitude, resulting in less
        # informative reward signals.
        # rewards = torch.stack([data[i]['reward'] for i in range(len(data))])
        # rewards = (rewards - torch.min(rewards)) / (torch.max(rewards) - torch.min(rewards))

        for item, rnd_rew in zip(train_data_augmented, rnd_reward):
            if self.intrinsic_reward_type == 'add':
                if self.cfg.extrinsic_reward_norm:
                    item['reward'] = item[
                        'reward'] / self.cfg.extrinsic_reward_norm_max + rnd_rew * self.cfg.intrinsic_reward_weight
                else:
                    item['reward'] = item['reward'] + rnd_rew * self.cfg.intrinsic_reward_weight
            elif self.intrinsic_reward_type == 'new':
                item['intrinsic_reward'] = rnd_rew
                if self.cfg.extrinsic_reward_norm:
                    item['reward'] = item['reward'] / self.cfg.extrinsic_reward_norm_max
            elif self.intrinsic_reward_type == 'assign':
                item['reward'] = rnd_rew

        # save the augmented_reward statistics into tb_logger
        rew = [item['reward'].cpu().numpy() for item in train_data_augmented]
        self.tb_logger.add_scalar('augmented_reward/reward_max', np.max(rew), self.estimate_cnt_rnd)
        self.tb_logger.add_scalar('augmented_reward/reward_mean', np.mean(rew), self.estimate_cnt_rnd)
        self.tb_logger.add_scalar('augmented_reward/reward_min', np.min(rew), self.estimate_cnt_rnd)
        self.tb_logger.add_scalar('augmented_reward/reward_std', np.std(rew), self.estimate_cnt_rnd)
        return train_data_augmented

    def collect_data(self, data: list) -> None:
        """
        Overview:
            Append the observations of the newly collected transitions to the training buffer.
        """
        self.train_obs.extend(collect_states(data))

    def clear_data(self) -> None:
        """
        Overview:
            Drop all buffered training observations.
        """
        self.train_obs.clear()

    def state_dict(self) -> Dict:
        """
        Overview:
            Return the RND network's state dict (optimizer and running stats are not included).
        """
        return self.reward_model.state_dict()

    def load_state_dict(self, _state_dict: Dict) -> None:
        """
        Overview:
            Load a state dict previously produced by ``state_dict``.
        """
        self.reward_model.load_state_dict(_state_dict)