ding.reward_model.drex_reward_model

DrexRewardModel

Bases: TrexRewardModel

Overview

The Drex reward model class (https://arxiv.org/pdf/1907.03976.pdf)

Interface: ``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_data``, ``__init__``, ``_train``

Config:

== ==================== ====== ============= ====================================== ============
ID Symbol               Type   Default Value Description                            Other(Shape)
== ==================== ====== ============= ====================================== ============
1  ``type``             str    drex          Reward model register name, refer
                                             to registry ``REWARD_MODEL_REGISTRY``
3  ``learning_rate``    float  0.00001       Learning rate for optimizer
4  ``update_per_``      int    100           Number of updates per collect
   ``collect``
5  ``batch_size``       int    64            How many samples in a training batch
6  ``hidden_size``      int    128           Linear model hidden size
7  ``num_trajs``        int    0             Number of downsampled full
                                             trajectories
8  ``num_snippets``     int    6000          Number of short subtrajectories
                                             to sample
== ==================== ====== ============= ====================================== ============
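For orientation, here is a minimal config sketch assembled with ``EasyDict``; the keys and defaults mirror the table above, while ``offline_data_path`` is an assumption inferred from ``load_expert_data`` in the source below, not a documented default.

```python
from easydict import EasyDict

# Hypothetical config sketch: keys and defaults mirror the table above.
drex_config = EasyDict(
    type='drex',             # reward model register name in REWARD_MODEL_REGISTRY
    learning_rate=1e-5,      # optimizer step size
    update_per_collect=100,  # training updates per collection phase
    batch_size=64,           # samples per training batch
    hidden_size=128,         # linear model hidden size
    num_trajs=0,             # number of downsampled full trajectories
    num_snippets=6000,       # number of short subtrajectories to sample
    # Assumed field, inferred from load_expert_data() in the source below;
    # the real config layout may nest this differently.
    offline_data_path='./drex_data',
)
```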

__init__(config, device, tb_logger)

Overview

Initialize ``self``. See ``help(type(self))`` for accurate signature.

Arguments:

- cfg (:obj:`EasyDict`): Training config
- device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
- tb_logger (:obj:`SummaryWriter`): Logger, set by default as ``SummaryWriter`` for model summary
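A hedged instantiation sketch using only the three documented arguments; ``drex_config`` is the hypothetical config from above, and the log directory is illustrative.

```python
from torch.utils.tensorboard import SummaryWriter

# Illustrative instantiation; drex_config is the hypothetical EasyDict sketched above.
tb_logger = SummaryWriter(log_dir='./log/drex_reward_model')
reward_model = DrexRewardModel(drex_config, device='cpu', tb_logger=tb_logger)
```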

load_expert_data()

Overview

Get the expert data from the ``config.expert_data_path`` attribute in ``self``.

Effects: This is a side-effect function which updates the expert data attribute (i.e. ``self.expert_data``) with ``fn:concat_state_action_pairs``.
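For intuition about what a ``concat_state_action_pairs``-style transform produces, here is an illustrative re-implementation, not DI-engine's actual helper: each (state, action) pair of an episode is flattened and concatenated into a single vector.

```python
import torch

def concat_state_action_pairs_sketch(episode):
    # Illustrative only: mimic a concat_state_action_pairs-style helper by
    # flattening each transition's observation and action into one vector.
    # Assumes each transition is a dict with tensor-valued 'obs' and 'action'.
    return [
        torch.cat([step['obs'].flatten().float(), step['action'].flatten().float()])
        for step in episode
    ]
```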

Full Source Code

../ding/reward_model/drex_reward_model.py

```python
import copy
from easydict import EasyDict
import pickle

from ding.utils import REWARD_MODEL_REGISTRY

from .trex_reward_model import TrexRewardModel


@REWARD_MODEL_REGISTRY.register('drex')
class DrexRewardModel(TrexRewardModel):
    """
    Overview:
        The Drex reward model class (https://arxiv.org/pdf/1907.03976.pdf)
    Interface:
        ``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_data``, \
        ``__init__``, ``_train``
    Config:
        == ==================== ====== ============= ====================================== ============
        ID Symbol               Type   Default Value Description                            Other(Shape)
        == ==================== ====== ============= ====================================== ============
        1  ``type``             str    drex          | Reward model register name, refer
                                                     | to registry ``REWARD_MODEL_REGISTRY``
        3  | ``learning_rate``  float  0.00001       | learning rate for optimizer
        4  | ``update_per_``    int    100           | Number of updates per collect
           | ``collect``
        5  | ``batch_size``     int    64            | How many samples in a training batch
        6  | ``hidden_size``    int    128           | Linear model hidden size
        7  | ``num_trajs``      int    0             | Number of downsampled full
                                                     | trajectories
        8  | ``num_snippets``   int    6000          | Number of short subtrajectories
                                                     | to sample
        == ==================== ====== ============= ====================================== ============
    """
    config = dict(
        # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
        type='drex',
        # (float) The step size of gradient descent.
        learning_rate=1e-5,
        # (int) How many updates (iterations) to train after the collector's one collection.
        # Bigger "update_per_collect" means more off-policy.
        # collect data -> update policy -> collect data -> ...
        update_per_collect=100,
        # (int) How many samples in a training batch.
        batch_size=64,
        # (int) Linear model hidden size.
        hidden_size=128,
        # (int) Number of downsampled full trajectories.
        num_trajs=0,
        # (int) Number of short subtrajectories to sample.
        num_snippets=6000,
    )

    bc_cfg = None

    def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None:  # noqa
        """
        Overview:
            Initialize ``self``. See ``help(type(self))`` for accurate signature.
        Arguments:
            - cfg (:obj:`EasyDict`): Training config
            - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
            - tb_logger (:obj:`SummaryWriter`): Logger, set by default as ``SummaryWriter`` for model summary
        """
        super(DrexRewardModel, self).__init__(copy.deepcopy(config), device, tb_logger)

        self.demo_data = []
        self.load_expert_data()

    def load_expert_data(self) -> None:
        """
        Overview:
            Get the expert data from the ``config.expert_data_path`` attribute in self.
        Effects:
            This is a side-effect function which updates the expert data attribute \
            (i.e. ``self.expert_data``) with ``fn:concat_state_action_pairs``.
        """
        super(DrexRewardModel, self).load_expert_data()

        with open(self.cfg.reward_model.offline_data_path + '/suboptimal_data.pkl', 'rb') as f:
            self.demo_data = pickle.load(f)

    def train(self):
        self._train()
        return_dict = self.pred_data(self.demo_data)
        res, pred_returns = return_dict['real'], return_dict['pred']
        self._logger.info("real: " + str(res))
        self._logger.info("pred: " + str(pred_returns))

        info = {
            "min_snippet_length": self.min_snippet_length,
            "max_snippet_length": self.max_snippet_length,
            "len_num_training_obs": len(self.training_obs),
            "len_num_labels": len(self.training_labels),
            "accuracy": self.calc_accuracy(self.reward_model, self.training_obs, self.training_labels),
        }
        self._logger.info(
            "accuracy and comparison:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))
        )
```
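Finally, a hedged usage sketch based only on the interface listed above (``train``, ``estimate``) and the file name in the source (``suboptimal_data.pkl``); the surrounding pipeline and the transition schema are assumptions.

```python
# Illustrative flow, assuming reward_model was built as sketched earlier and
# that <offline_data_path>/suboptimal_data.pkl already holds suboptimal
# demonstration data.
reward_model.train()  # fits the reward model, then logs real vs. predicted returns

# ``estimate`` is inherited from TrexRewardModel; the transition schema it
# expects (e.g. a list of dicts with 'obs', 'action', 'reward') is an assumption.
# estimated_transitions = reward_model.estimate(collected_transitions)
```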