Skip to content

ding.reward_model.base_reward_model

ding.reward_model.base_reward_model

BaseRewardModel

Bases: ABC

Overview

the base class of reward model

Interface: default_config, estimate, train, clear_data, collect_data, load_expert_data

estimate(data) abstractmethod

Overview

estimate reward

Arguments: - data (:obj:List): the list of data used for estimation Returns / Effects: - This can be a side effect function which updates the reward value - If this function returns, an example returned object can be reward (:obj:Any): the estimated reward

train(data) abstractmethod

Overview

Training the reward model

Arguments: - data (:obj:Any): Data used for training Effects: - This is mostly a side effect function which updates the reward model

collect_data(data) abstractmethod

Overview

Collecting training data in designated format or with designated transition.

Arguments: - data (:obj:Any): Raw training data (e.g. some form of states, actions, obs, etc) Returns / Effects: - This can be a side effect function which updates the data attribute in self

clear_data() abstractmethod

Overview

Clearing training data. This can be a side effect function which clears the data attribute in self

load_expert_data(data)

Overview

Getting the expert data, usually used in inverse RL reward model

Arguments: - data (:obj:Any): Expert data Effects: This is mostly a side effect function which updates the expert data attribute (e.g. self.expert_data)

reward_deepcopy(train_data)

Overview

This method deep-copies the reward part of train_data, while keeping shallow copies of the other parts, to avoid the reward part of train_data in the replay buffer being incorrectly modified.

Arguments: - train_data (:obj:List): the List of train data in which the reward part will be operated by deepcopy.

create_reward_model(cfg, device, tb_logger)

Overview

Reward Estimation Model.

Arguments: - cfg (:obj:Dict): Training config - device (:obj:str): Device usage, i.e. "cpu" or "cuda" - tb_logger (:obj:SummaryWriter): Logger for model summary Returns: - reward (:obj:Any): The reward model

Full Source Code

../ding/reward_model/base_reward_model.py

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from easydict import EasyDict
from ditk import logging
import os
import copy
from ding.utils import REWARD_MODEL_REGISTRY, import_module, save_file


class BaseRewardModel(ABC):
    """
    Overview:
        The base class of reward model.
    Interface:
        ``default_config``, ``estimate``, ``train``, ``clear_data``, ``collect_data``, ``load_expert_data``
    """

    @classmethod
    def default_config(cls: type) -> EasyDict:
        """
        Overview:
            Return a deep copy of the class-level ``config`` as an ``EasyDict``,
            with a ``cfg_type`` field derived from the class name appended.
        """
        # Deep copy so mutating the returned cfg never touches the class default.
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    @abstractmethod
    def estimate(self, data: list) -> Any:
        """
        Overview:
            Estimate reward.
        Arguments:
            - data (:obj:`List`): the list of data used for estimation
        Returns / Effects:
            - This can be a side effect function which updates the reward value
            - If this function returns, an example returned object can be reward (:obj:`Any`): the estimated reward
        """
        raise NotImplementedError()

    @abstractmethod
    def train(self, data) -> None:
        """
        Overview:
            Training the reward model.
        Arguments:
            - data (:obj:`Any`): Data used for training
        Effects:
            - This is mostly a side effect function which updates the reward model
        """
        raise NotImplementedError()

    @abstractmethod
    def collect_data(self, data) -> None:
        """
        Overview:
            Collecting training data in designated format or with designated transition.
        Arguments:
            - data (:obj:`Any`): Raw training data (e.g. some form of states, actions, obs, etc)
        Returns / Effects:
            - This can be a side effect function which updates the data attribute in ``self``
        """
        raise NotImplementedError()

    @abstractmethod
    def clear_data(self) -> None:
        """
        Overview:
            Clearing training data. This can be a side effect function which clears
            the data attribute in ``self``.
        """
        raise NotImplementedError()

    def load_expert_data(self, data) -> None:
        """
        Overview:
            Getting the expert data, usually used in inverse RL reward model.
        Arguments:
            - data (:obj:`Any`): Expert data
        Effects:
            This is mostly a side effect function which updates the expert data attribute
            (e.g. ``self.expert_data``)
        """
        pass

    def reward_deepcopy(self, train_data) -> Any:
        """
        Overview:
            Deep-copy only the ``reward`` field of each sample in ``train_data``, keeping
            shallow references for every other field, so the reward stored in the replay
            buffer cannot be modified by accident.
        Arguments:
            - train_data (:obj:`List`): the List of train data in which the reward part will be deep-copied.
        """
        train_data_reward_deepcopy = [
            {k: copy.deepcopy(v) if k == 'reward' else v
             for k, v in sample.items()} for sample in train_data
        ]
        return train_data_reward_deepcopy

    def state_dict(self) -> Dict:
        # This method should be overridden by subclasses that have state to persist.
        return {}

    def load_state_dict(self, _state_dict) -> None:
        # This method should be overridden by subclasses that have state to restore.
        pass

    def save(self, path: Optional[str] = None, name: str = 'best') -> None:
        """
        Overview:
            Save the reward model checkpoint to ``<path>/reward_model/ckpt/ckpt_<name>.pth.tar``.
        Arguments:
            - path (:obj:`str`): Root directory; defaults to ``self.cfg.exp_name`` when ``None``
            - name (:obj:`str`): Checkpoint tag, defaults to ``'best'``
        """
        if path is None:
            path = self.cfg.exp_name
        path = os.path.join(path, 'reward_model', 'ckpt')
        if not os.path.exists(path):
            try:
                os.makedirs(path)
            except FileExistsError:
                # Another process created the directory concurrently; that is fine.
                pass
        path = os.path.join(path, 'ckpt_{}.pth.tar'.format(name))
        state_dict = self.state_dict()
        save_file(path, state_dict)
        logging.info('Saved reward model ckpt in {}'.format(path))


def create_reward_model(cfg: dict, device: str, tb_logger: 'SummaryWriter') -> BaseRewardModel:  # noqa
    """
    Overview:
        Create a reward estimation model according to ``cfg``.
    Arguments:
        - cfg (:obj:`Dict`): Training config
        - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
        - tb_logger (:obj:`SummaryWriter`): Logger for model summary
    Returns:
        - reward (:obj:`Any`): The reward model
    """
    # Deep copy so the pops below never mutate the caller's config.
    cfg = copy.deepcopy(cfg)
    if 'import_names' in cfg:
        import_module(cfg.pop('import_names'))
    # The type key may live either under cfg.reward_model or at the top level.
    if hasattr(cfg, 'reward_model'):
        reward_model_type = cfg.reward_model.pop('type')
    else:
        reward_model_type = cfg.pop('type')
    return REWARD_MODEL_REGISTRY.build(reward_model_type, cfg, device=device, tb_logger=tb_logger)


def get_reward_model_cls(cfg: EasyDict) -> type:
    """
    Overview:
        Look up and return the reward model class registered under ``cfg.type``,
        importing the modules listed in ``cfg.import_names`` (if any) first.
    """
    import_module(cfg.get('import_names', []))
    return REWARD_MODEL_REGISTRY.get(cfg.type)