
ding.reward_model.trex_reward_model


TrexConvEncoder

Bases: Module

Overview

The Convolution Encoder used in models. Used to encode raw 2-dim observations.

Interfaces: __init__, forward

__init__(obs_shape, hidden_size_list=[16, 16, 16, 16, 64, 1], activation=nn.LeakyReLU())

Overview

Init the Trex Convolution Encoder according to the arguments. TrexConvEncoder differs from the ConvEncoder in model.common.encoder: their stride and kernel size parameters are different.

Arguments:

- obs_shape (:obj:SequenceType): Sequence of in_channel, some output size
- hidden_size_list (:obj:SequenceType): The collection of hidden_size
- activation (:obj:nn.Module): The type of activation to use in the conv layers; if None, defaults to nn.LeakyReLU()
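The fixed kernel/stride schedule ([7, 5, 3, 3] with strides [3, 2, 1, 1]) determines the flattened feature size that feeds the first linear layer. A minimal sketch of that arithmetic, assuming a hypothetical 84x84 Atari-style frame (the encoder itself probes this size at runtime via a dummy forward pass):

```python
def conv_out(size: int, kernel: int, stride: int) -> int:
    # Output size of an unpadded convolution: floor((size - kernel) / stride) + 1
    return (size - kernel) // stride + 1

def trex_flatten_size(h: int, w: int, channels: int = 16) -> int:
    # Kernel/stride schedule used by TrexConvEncoder; channels=16 is the
    # last conv width in the default hidden_size_list.
    for k, s in zip([7, 5, 3, 3], [3, 2, 1, 1]):
        h, w = conv_out(h, k, s), conv_out(w, k, s)
    # nn.Flatten() yields channels * h * w in-features for the first nn.Linear.
    return channels * h * w

print(trex_flatten_size(84, 84))  # 16 * 7 * 7 = 784
```

The encoder avoids hard-coding this number by calling _get_flatten_size(), so any 3-dim obs_shape works without redoing this arithmetic by hand.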

forward(x)

Overview

Return embedding tensor of the env observation

Arguments:

- x (:obj:torch.Tensor): Env raw observation

Returns:

- outputs (:obj:torch.Tensor): Embedding tensor

TrexModel

Bases: Module

cum_return(traj, mode='sum')

Calculate the cumulative return of a trajectory.

forward(traj_i, traj_j)

Compute the cumulative return for each trajectory and return the logits.
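The two scalar returns act as the logits of a binary preference classifier, as in the T-REX paper: softmax over (r_i, r_j) gives the probability that each trajectory is the preferred one, and training minimizes cross-entropy against the preference label. A minimal pure-Python sketch of that loss (in the real model, r_i and r_j come from the learned encoder, not constants):

```python
import math

def trex_preference_loss(r_i: float, r_j: float, label: int) -> float:
    # Softmax over the two cumulative returns (Bradley-Terry model), then
    # cross-entropy against the preference label (0 -> traj_i, 1 -> traj_j).
    m = max(r_i, r_j)  # subtract the max for numerical stability
    z = math.exp(r_i - m) + math.exp(r_j - m)
    log_p = (r_i - m if label == 0 else r_j - m) - math.log(z)
    return -log_p

# When traj_j has a much higher predicted return and is labeled preferred,
# the loss is near zero; the reversed label is heavily penalized.
print(trex_preference_loss(1.0, 5.0, 1) < trex_preference_loss(1.0, 5.0, 0))  # True
```

This is exactly what _train does with nn.CrossEntropyLoss over the stacked pair of returns, plus an optional l1 penalty on the absolute rewards.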

TrexRewardModel

Bases: BaseRewardModel

Overview

The Trex reward model class (https://arxiv.org/pdf/1904.06387.pdf)

Interface: estimate, train, load_expert_data, collect_data, clear_data, __init__, _train

Config:

== ==================== ====== ============= ============================================ =============
ID Symbol               Type   Default Value Description                                  Other(Shape)
== ==================== ====== ============= ============================================ =============
1  type                 str    trex          Reward model register name, refer
                                             to registry REWARD_MODEL_REGISTRY
3  learning_rate        float  0.00001       learning rate for optimizer
4  update_per_collect   int    100           Number of updates per collect
5  num_trajs            int    0             Number of downsampled full trajectories
6  num_snippets         int    6000          Number of short subtrajectories to sample
== ==================== ====== ============= ============================================ =============
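The defaults in the table correspond to the class-level config dict, which __init__ reads through the reward_model section of the training config (e.g. cfg.reward_model.learning_rate). A small sketch; the override values below are illustrative assumptions, not defaults:

```python
# Default values taken from the table above.
trex_default_config = dict(
    type='trex',
    learning_rate=1e-5,      # learning rate for the Adam optimizer
    update_per_collect=100,  # number of updates per collect
    num_trajs=0,             # number of downsampled full trajectories
    num_snippets=6000,       # number of short subtrajectories to sample
)

# Illustrative user override merged on top of the defaults (assumed values):
my_config = {**trex_default_config, 'num_trajs': 10, 'num_snippets': 4000}
print(my_config['num_trajs'], my_config['num_snippets'])  # 10 4000
```

Note that __init__ additionally reads min_snippet_length and max_snippet_length from the reward_model config even though they are not listed in the table, so a real config must supply them.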

__init__(config, device, tb_logger)

Overview

Initialize self. See help(type(self)) for accurate signature.

Arguments:

- cfg (:obj:EasyDict): Training config
- device (:obj:str): Device usage, i.e. "cpu" or "cuda"
- tb_logger (:obj:SummaryWriter): Logger, defaultly set as 'SummaryWriter' for model summary

load_expert_data()

Overview

Getting the expert data.

Effects: This is a side effect function which updates the expert data attribute (i.e. self.expert_data) with fn:concat_state_action_pairs

estimate(data)

Overview

Estimate reward by rewriting the reward key in each row of the data.

Arguments:

- data (:obj:list): the list of data used for estimation, with at least obs and action keys.

Effects:

- This is a side effect function which updates the reward values in place.
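estimate deep-copies the reward field first, then overwrites each row's reward with the model's prediction, so rewards already stored in the replay buffer are not mutated. A minimal sketch of that rewrite pattern, using a hypothetical stand-in reward_fn in place of the learned model:

```python
from copy import deepcopy

def estimate_rewards(data: list, reward_fn) -> list:
    # Copy first so the caller's (replay buffer's) rewards stay untouched.
    augmented = deepcopy(data)
    for item in augmented:
        # Overwrite the 'reward' key row by row, as TrexRewardModel.estimate does.
        item['reward'] = reward_fn(item['obs'])
    return augmented

batch = [{'obs': 1.0, 'action': 0, 'reward': 0.0},
         {'obs': 2.0, 'action': 1, 'reward': 0.0}]
out = estimate_rewards(batch, reward_fn=lambda obs: 0.5 * obs)
print([item['reward'] for item in out])  # [0.5, 1.0]
print(batch[0]['reward'])                # original data unchanged: 0.0
```

The real implementation batches this: it stacks all observations, runs cum_return(..., mode='batch') once under torch.no_grad(), and then writes the per-row results back.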

collect_data(data)

Overview

Collecting training data formatted by fn:concat_state_action_pairs.

Arguments:

- data (:obj:Any): Raw training data (e.g. some form of states, actions, obs, etc)

Effects:

- This is a side effect function which updates the data attribute in self

clear_data()

Overview

Clearing training data. This is a side effect function which clears the data attribute in self

Full Source Code

../ding/reward_model/trex_reward_model.py

```python
from copy import deepcopy
from typing import Tuple, Optional, List, Dict
from easydict import EasyDict
import pickle
import os
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

from ding.utils import REWARD_MODEL_REGISTRY
from ding.utils import SequenceType
from ding.model.common import FCEncoder
from ding.utils import build_logger
from ding.utils.data import default_collate

from .base_reward_model import BaseRewardModel
from .rnd_reward_model import collect_states


class TrexConvEncoder(nn.Module):
    r"""
    Overview:
        The ``Convolution Encoder`` used in models. Used to encoder raw 2-dim observation.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            obs_shape: SequenceType,
            hidden_size_list: SequenceType = [16, 16, 16, 16, 64, 1],
            activation: Optional[nn.Module] = nn.LeakyReLU()
    ) -> None:
        r"""
        Overview:
            Init the Trex Convolution Encoder according to arguments. TrexConvEncoder is different \
                from the ConvEncoder in model.common.encoder, their stride and kernel size parameters \
                are different
        Arguments:
            - obs_shape (:obj:`SequenceType`): Sequence of ``in_channel``, some ``output size``
            - hidden_size_list (:obj:`SequenceType`): The collection of ``hidden_size``
            - activation (:obj:`nn.Module`):
                The type of activation to use in the conv ``layers``,
                if ``None`` then default set to ``nn.LeakyReLU()``
        """
        super(TrexConvEncoder, self).__init__()
        self.obs_shape = obs_shape
        self.act = activation
        self.hidden_size_list = hidden_size_list

        layers = []
        kernel_size = [7, 5, 3, 3]
        stride = [3, 2, 1, 1]
        input_size = obs_shape[0]  # in_channel
        for i in range(len(kernel_size)):
            layers.append(nn.Conv2d(input_size, hidden_size_list[i], kernel_size[i], stride[i]))
            layers.append(self.act)
            input_size = hidden_size_list[i]
        layers.append(nn.Flatten())
        self.main = nn.Sequential(*layers)

        flatten_size = self._get_flatten_size()
        self.mid = nn.Sequential(
            nn.Linear(flatten_size, hidden_size_list[-2]), self.act,
            nn.Linear(hidden_size_list[-2], hidden_size_list[-1])
        )

    def _get_flatten_size(self) -> int:
        r"""
        Overview:
            Get the encoding size after ``self.main`` to get the number of ``in-features`` to feed to ``nn.Linear``.
        Arguments:
            - x (:obj:`torch.Tensor`): Encoded Tensor after ``self.main``
        Returns:
            - outputs (:obj:`torch.Tensor`): Size int, also number of in-feature
        """
        test_data = torch.randn(1, *self.obs_shape)
        with torch.no_grad():
            output = self.main(test_data)
        return output.shape[1]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Overview:
            Return embedding tensor of the env observation
        Arguments:
            - x (:obj:`torch.Tensor`): Env raw observation
        Returns:
            - outputs (:obj:`torch.Tensor`): Embedding tensor
        """
        x = self.main(x)
        x = self.mid(x)
        return x


class TrexModel(nn.Module):

    def __init__(self, obs_shape):
        super(TrexModel, self).__init__()
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            self.encoder = nn.Sequential(FCEncoder(obs_shape, [512, 64]), nn.Linear(64, 1))
        # Conv Encoder
        elif len(obs_shape) == 3:
            self.encoder = TrexConvEncoder(obs_shape)
        else:
            raise KeyError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own Trex model".
                format(obs_shape)
            )

    def cum_return(self, traj: torch.Tensor, mode: str = 'sum') -> Tuple[torch.Tensor, torch.Tensor]:
        '''calculate cumulative return of trajectory'''
        r = self.encoder(traj)
        if mode == 'sum':
            sum_rewards = torch.sum(r)
            sum_abs_rewards = torch.sum(torch.abs(r))
            return sum_rewards, sum_abs_rewards
        elif mode == 'batch':
            return r, torch.abs(r)
        else:
            raise KeyError("not support mode: {}, please choose mode=sum or mode=batch".format(mode))

    def forward(self, traj_i: torch.Tensor, traj_j: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        '''compute cumulative return for each trajectory and return logits'''
        cum_r_i, abs_r_i = self.cum_return(traj_i)
        cum_r_j, abs_r_j = self.cum_return(traj_j)
        return torch.cat((cum_r_i.unsqueeze(0), cum_r_j.unsqueeze(0)), 0), abs_r_i + abs_r_j


@REWARD_MODEL_REGISTRY.register('trex')
class TrexRewardModel(BaseRewardModel):
    """
    Overview:
        The Trex reward model class (https://arxiv.org/pdf/1904.06387.pdf)
    Interface:
        ``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_date``, \
            ``__init__``, ``_train``,
    Config:
        == ==================== ====== ============= ============================================ =============
        ID Symbol               Type   Default Value Description                                  Other(Shape)
        == ==================== ====== ============= ============================================ =============
        1  ``type``             str    trex          | Reward model register name, refer          |
                                                     | to registry ``REWARD_MODEL_REGISTRY``      |
        3  | ``learning_rate``  float  0.00001       | learning rate for optimizer                |
        4  | ``update_per_``    int    100           | Number of updates per collect              |
           | ``collect``                             |                                            |
        5  | ``num_trajs``      int    0             | Number of downsampled full trajectories    |
        6  | ``num_snippets``   int    6000          | Number of short subtrajectories to sample  |
        == ==================== ====== ============= ============================================ =============
    """
    config = dict(
        # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
        type='trex',
        # (float) The step size of gradient descent.
        learning_rate=1e-5,
        # (int) How many updates(iterations) to train after collector's one collection.
        # Bigger "update_per_collect" means bigger off-policy.
        # collect data -> update policy-> collect data -> ...
        update_per_collect=100,
        # (int) Number of downsampled full trajectories.
        num_trajs=0,
        # (int) Number of short subtrajectories to sample.
        num_snippets=6000,
    )

    def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None:  # noqa
        """
        Overview:
            Initialize ``self.`` See ``help(type(self))`` for accurate signature.
        Arguments:
            - cfg (:obj:`EasyDict`): Training config
            - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
            - tb_logger (:obj:`SummaryWriter`): Logger, defaultly set as 'SummaryWriter' for model summary
        """
        super(TrexRewardModel, self).__init__()
        self.cfg = config
        assert device in ["cpu", "cuda"] or "cuda" in device
        self.device = device
        self.tb_logger = tb_logger
        self.reward_model = TrexModel(self.cfg.policy.model.obs_shape)
        self.reward_model.to(self.device)
        self.pre_expert_data = []
        self.train_data = []
        self.expert_data_loader = None
        self.opt = optim.Adam(self.reward_model.parameters(), config.reward_model.learning_rate)
        self.train_iter = 0
        self.learning_returns = []
        self.training_obs = []
        self.training_labels = []
        self.num_trajs = self.cfg.reward_model.num_trajs
        self.num_snippets = self.cfg.reward_model.num_snippets
        # minimum number of short subtrajectories to sample
        self.min_snippet_length = config.reward_model.min_snippet_length
        # maximum number of short subtrajectories to sample
        self.max_snippet_length = config.reward_model.max_snippet_length
        self.l1_reg = 0
        self.data_for_save = {}
        self._logger, self._tb_logger = build_logger(
            path='./{}/log/{}'.format(self.cfg.exp_name, 'trex_reward_model'), name='trex_reward_model'
        )
        self.load_expert_data()

    def load_expert_data(self) -> None:
        """
        Overview:
            Getting the expert data.
        Effects:
            This is a side effect function which updates the expert data attribute \
                (i.e. ``self.expert_data``) with ``fn:concat_state_action_pairs``
        """
        with open(os.path.join(self.cfg.exp_name, 'episodes_data.pkl'), 'rb') as f:
            self.pre_expert_data = pickle.load(f)
        with open(os.path.join(self.cfg.exp_name, 'learning_returns.pkl'), 'rb') as f:
            self.learning_returns = pickle.load(f)

        self.create_training_data()
        self._logger.info("num_training_obs: {}".format(len(self.training_obs)))
        self._logger.info("num_labels: {}".format(len(self.training_labels)))

    def create_training_data(self):
        num_trajs = self.num_trajs
        num_snippets = self.num_snippets
        min_snippet_length = self.min_snippet_length
        max_snippet_length = self.max_snippet_length

        demo_lengths = []
        for i in range(len(self.pre_expert_data)):
            demo_lengths.append([len(d) for d in self.pre_expert_data[i]])

        self._logger.info("demo_lengths: {}".format(demo_lengths))
        max_snippet_length = min(np.min(demo_lengths), max_snippet_length)
        self._logger.info("min snippet length: {}".format(min_snippet_length))
        self._logger.info("max snippet length: {}".format(max_snippet_length))

        # collect training data
        max_traj_length = 0
        num_bins = len(self.pre_expert_data)
        assert num_bins >= 2

        # add full trajs (for use on Enduro)
        si = np.random.randint(6, size=num_trajs)
        sj = np.random.randint(6, size=num_trajs)
        step = np.random.randint(3, 7, size=num_trajs)
        for n in range(num_trajs):
            # pick two random demonstrations
            bi, bj = np.random.choice(num_bins, size=(2, ), replace=False)
            ti = np.random.choice(len(self.pre_expert_data[bi]))
            tj = np.random.choice(len(self.pre_expert_data[bj]))
            # create random partial trajs by finding random start frame and random skip frame
            traj_i = self.pre_expert_data[bi][ti][si[n]::step[n]]  # slice(start,stop,step)
            traj_j = self.pre_expert_data[bj][tj][sj[n]::step[n]]

            label = int(bi <= bj)

            self.training_obs.append((traj_i, traj_j))
            self.training_labels.append(label)
            max_traj_length = max(max_traj_length, len(traj_i), len(traj_j))

        # fixed size snippets with progress prior
        rand_length = np.random.randint(min_snippet_length, max_snippet_length, size=num_snippets)
        for n in range(num_snippets):
            # pick two random demonstrations
            bi, bj = np.random.choice(num_bins, size=(2, ), replace=False)
            ti = np.random.choice(len(self.pre_expert_data[bi]))
            tj = np.random.choice(len(self.pre_expert_data[bj]))
            # create random snippets
            # find min length of both demos to ensure we can pick a demo no earlier
            # than that chosen in worse preferred demo
            min_length = min(len(self.pre_expert_data[bi][ti]), len(self.pre_expert_data[bj][tj]))
            if bi < bj:  # pick tj snippet to be later than ti
                ti_start = np.random.randint(min_length - rand_length[n] + 1)
                # print(ti_start, len(demonstrations[tj]))
                tj_start = np.random.randint(ti_start, len(self.pre_expert_data[bj][tj]) - rand_length[n] + 1)
            else:  # ti is better so pick later snippet in ti
                tj_start = np.random.randint(min_length - rand_length[n] + 1)
                # print(tj_start, len(demonstrations[ti]))
                ti_start = np.random.randint(tj_start, len(self.pre_expert_data[bi][ti]) - rand_length[n] + 1)
            # skip everyother framestack to reduce size
            traj_i = self.pre_expert_data[bi][ti][ti_start:ti_start + rand_length[n]:2]
            traj_j = self.pre_expert_data[bj][tj][tj_start:tj_start + rand_length[n]:2]

            max_traj_length = max(max_traj_length, len(traj_i), len(traj_j))
            label = int(bi <= bj)
            self.training_obs.append((traj_i, traj_j))
            self.training_labels.append(label)
        self._logger.info(("maximum traj length: {}".format(max_traj_length)))
        return self.training_obs, self.training_labels

    def _train(self):
        # check if gpu available
        device = self.device  # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # Assume that we are on a CUDA machine, then this should print a CUDA device:
        self._logger.info("device: {}".format(device))
        training_inputs, training_outputs = self.training_obs, self.training_labels
        loss_criterion = nn.CrossEntropyLoss()

        cum_loss = 0.0
        training_data = list(zip(training_inputs, training_outputs))
        for epoch in range(self.cfg.reward_model.update_per_collect):  # todo
            np.random.shuffle(training_data)
            training_obs, training_labels = zip(*training_data)
            for i in range(len(training_labels)):

                # traj_i, traj_j has the same length, however, they change as i increases
                traj_i, traj_j = training_obs[i]  # traj_i is a list of array generated by env.step
                traj_i = np.array(traj_i)
                traj_j = np.array(traj_j)
                traj_i = torch.from_numpy(traj_i).float().to(device)
                traj_j = torch.from_numpy(traj_j).float().to(device)

                # training_labels[i] is a boolean integer: 0 or 1
                labels = torch.tensor([training_labels[i]]).to(device)

                # forward + backward + zero out gradient + optimize
                outputs, abs_rewards = self.reward_model.forward(traj_i, traj_j)
                outputs = outputs.unsqueeze(0)
                loss = loss_criterion(outputs, labels) + self.l1_reg * abs_rewards
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()

                # print stats to see if learning
                item_loss = loss.item()
                cum_loss += item_loss
                if i % 100 == 99:
                    self._logger.info("[epoch {}:{}] loss {}".format(epoch, i, cum_loss))
                    self._logger.info("abs_returns: {}".format(abs_rewards))
                    cum_loss = 0.0
                    self._logger.info("check pointing")
                    if not os.path.exists(os.path.join(self.cfg.exp_name, 'ckpt_reward_model')):
                        os.makedirs(os.path.join(self.cfg.exp_name, 'ckpt_reward_model'))
                    torch.save(self.reward_model.state_dict(), os.path.join(self.cfg.exp_name, 'ckpt_reward_model/latest.pth.tar'))
        self._logger.info("finished training")

    def train(self):
        self._train()
        # print out predicted cumulative returns and actual returns
        sorted_returns = sorted(self.learning_returns, key=lambda s: s[0])
        demonstrations = [
            x for _, x in sorted(zip(self.learning_returns, self.pre_expert_data), key=lambda pair: pair[0][0])
        ]
        with torch.no_grad():
            pred_returns = [self.predict_traj_return(self.reward_model, traj[0]) for traj in demonstrations]
        for i, p in enumerate(pred_returns):
            self._logger.info("{} {} {}".format(i, p, sorted_returns[i][0]))
        info = {
            "demo_length": [len(d[0]) for d in self.pre_expert_data],
            "min_snippet_length": self.min_snippet_length,
            "max_snippet_length": min(np.min([len(d[0]) for d in self.pre_expert_data]), self.max_snippet_length),
            "len_num_training_obs": len(self.training_obs),
            "lem_num_labels": len(self.training_labels),
            "accuracy": self.calc_accuracy(self.reward_model, self.training_obs, self.training_labels),
        }
        self._logger.info(
            "accuracy and comparison:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))
        )

    def predict_traj_return(self, net, traj):
        device = self.device
        # torch.set_printoptions(precision=20)
        # torch.use_deterministic_algorithms(True)
        with torch.no_grad():
            rewards_from_obs = net.cum_return(
                torch.from_numpy(np.array(traj)).float().to(device), mode='batch'
            )[0].squeeze().tolist()
        # rewards_from_obs1 = net.cum_return(torch.from_numpy(np.array([traj[0]])).float().to(device))[0].item()
        # different precision
        return sum(rewards_from_obs)  # rewards_from_obs is a list of floats

    def calc_accuracy(self, reward_network, training_inputs, training_outputs):
        device = self.device
        loss_criterion = nn.CrossEntropyLoss()
        num_correct = 0.
        with torch.no_grad():
            for i in range(len(training_inputs)):
                label = training_outputs[i]
                traj_i, traj_j = training_inputs[i]
                traj_i = np.array(traj_i)
                traj_j = np.array(traj_j)
                traj_i = torch.from_numpy(traj_i).float().to(device)
                traj_j = torch.from_numpy(traj_j).float().to(device)

                # forward to get logits
                outputs, abs_return = reward_network.forward(traj_i, traj_j)
                _, pred_label = torch.max(outputs, 0)
                if pred_label.item() == label:
                    num_correct += 1.
        return num_correct / len(training_inputs)

    def pred_data(self, data):
        obs = [default_collate(data[i])['obs'] for i in range(len(data))]
        res = [torch.sum(default_collate(data[i])['reward']).item() for i in range(len(data))]
        pred_returns = [self.predict_traj_return(self.reward_model, obs[i]) for i in range(len(obs))]
        return {'real': res, 'pred': pred_returns}

    def estimate(self, data: list) -> List[Dict]:
        """
        Overview:
            Estimate reward by rewriting the reward key in each row of the data.
        Arguments:
            - data (:obj:`list`): the list of data used for estimation, with at least \
                ``obs`` and ``action`` keys.
        Effects:
            - This is a side effect function which updates the reward values in place.
        """
        # NOTE: deepcopy reward part of data is very important,
        # otherwise the reward of data in the replay buffer will be incorrectly modified.
        train_data_augmented = self.reward_deepcopy(data)

        res = collect_states(train_data_augmented)
        res = torch.stack(res).to(self.device)
        with torch.no_grad():
            sum_rewards, sum_abs_rewards = self.reward_model.cum_return(res, mode='batch')

        for item, rew in zip(train_data_augmented, sum_rewards):  # TODO optimise this loop as well ?
            item['reward'] = rew

        return train_data_augmented

    def collect_data(self, data: list) -> None:
        """
        Overview:
            Collecting training data formatted by ``fn:concat_state_action_pairs``.
        Arguments:
            - data (:obj:`Any`): Raw training data (e.g. some form of states, actions, obs, etc)
        Effects:
            - This is a side effect function which updates the data attribute in ``self``
        """
        pass

    def clear_data(self) -> None:
        """
        Overview:
            Clearing training data. \
                This is a side effect function which clears the data attribute in ``self``
        """
        self.training_obs.clear()
        self.training_labels.clear()
```