from typing import List, Dict
from ditk import logging
import numpy as np
import torch
import pickle
try:
    from sklearn.svm import SVC
except ImportError:
    SVC = None
from ding.torch_utils import cov
from ding.utils import REWARD_MODEL_REGISTRY, one_time_warning
from .base_reward_model import BaseRewardModel


@REWARD_MODEL_REGISTRY.register('pdeil')
class PdeilRewardModel(BaseRewardModel):
    """
    Overview:
        The Pdeil reward model class (https://arxiv.org/abs/2112.06746)
    Interface:
        ``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_data``, \
            ``__init__``, ``_train``, ``_batch_mn_pdf``
    Config:
        == ==================== ===== ============= ======================================= =======================
        ID Symbol               Type  Default Value Description                             Other(Shape)
        == ==================== ===== ============= ======================================= =======================
        1  ``type``             str   pdeil         | Reward model register name, refer
                                                    | to registry ``REWARD_MODEL_REGISTRY``
        2  | ``expert_data_``   str   expert_data.  | Path to the expert dataset            | Should be a '.pkl'
           | ``path``                 .pkl          |                                       | file
        3  | ``discrete_``      bool  False         | Whether the action is discrete        |
           | ``action``                             |                                       |
        4  | ``alpha``          float 0.5           | coefficient for Probability           |
           |                                        | Density Estimator                     |
        5  | ``clear_buffer``   int   1             | clear buffer per fixed iters          | make sure replay
           ``_per_iters``                                                                   | buffer's data count
                                                                                            | isn't too few.
                                                                                            | (code work in entry)
        == ==================== ===== ============= ======================================= =======================
    """
    config = dict(
        # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
        type='pdeil',
        # (str) Path to the expert dataset.
        # expert_data_path='expert_data.pkl',
        # (bool) Whether the action is discrete.
        discrete_action=False,
        # (float) Coefficient for Probability Density Estimator.
        # alpha + beta = 1, alpha is in [0, 1].
        # When alpha is close to 0, the estimator has high variance and low bias;
        # when alpha is close to 1, the estimator has high bias and low variance.
        alpha=0.5,
        # (int) Clear buffer per fixed iters.
        clear_buffer_per_iters=1,
    )

    def __init__(self, cfg: dict, device, tb_logger: 'SummaryWriter') -> None:  # noqa
        """
        Overview:
            Initialize ``self.`` See ``help(type(self))`` for accurate signature.
            Some rules in naming the attributes of ``self.``:

            - ``e_`` : expert values
            - ``_sigma_`` : standard division values
            - ``p_`` : current policy values
            - ``_s_`` : states
            - ``_a_`` : actions
        Arguments:
            - cfg (:obj:`Dict`): Training config
            - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
            - tb_logger (:obj:`SummaryWriter`): Logger for model summary (unused here, kept for interface parity)
        """
        super(PdeilRewardModel, self).__init__()
        try:
            import scipy.stats as stats
            self.stats = stats
        except ImportError:
            import sys
            logging.warning("Please install scipy first, such as `pip3 install scipy`.")
            sys.exit(1)
        self.cfg: dict = cfg
        # Expert state statistics: mean and covariance of expert observations.
        self.e_u_s = None
        self.e_sigma_s = None
        if cfg.discrete_action:
            # For discrete actions an SVM models p(a|s); may stay None if sklearn is missing.
            self.svm = None
        else:
            # For continuous actions a joint Gaussian over (state, action) is used instead.
            self.e_u_s_a = None
            self.e_sigma_s_a = None
        # Current-policy state statistics, filled in by ``_train``.
        self.p_u_s = None
        self.p_sigma_s = None
        self.expert_data = None
        self.train_data: list = []
        assert device in ["cpu", "cuda"] or "cuda" in device
        # pdeil defaults to the cpu device regardless of the requested one.
        self.device = 'cpu'

        self.load_expert_data()
        states: list = []
        actions: list = []
        for item in self.expert_data:
            states.append(item['obs'])
            actions.append(item['action'])
        states: torch.Tensor = torch.stack(states, dim=0)
        actions: torch.Tensor = torch.stack(actions, dim=0)
        self.e_u_s: torch.Tensor = torch.mean(states, axis=0)
        self.e_sigma_s: torch.Tensor = cov(states, rowvar=False)
        if self.cfg.discrete_action and SVC is None:
            one_time_warning("You are using discrete action while the SVC is not installed!")
        if self.cfg.discrete_action and SVC is not None:
            self.svm: SVC = SVC(probability=True)
            self.svm.fit(states.cpu().numpy(), actions.cpu().numpy())
        else:
            # states action conjunct
            state_actions = torch.cat((states, actions.float()), dim=-1)
            self.e_u_s_a = torch.mean(state_actions, axis=0)
            self.e_sigma_s_a = cov(state_actions, rowvar=False)

    def load_expert_data(self) -> None:
        """
        Overview:
            Getting the expert data from ``config['expert_data_path']`` attribute in self.
        Effects:
            This is a side effect function which updates the expert data attribute (e.g. ``self.expert_data``)
        """
        expert_data_path: str = self.cfg.expert_data_path
        # NOTE(security): pickle.load is unsafe on untrusted files; the expert dataset
        # path is assumed to come from a trusted local config.
        with open(expert_data_path, 'rb') as f:
            self.expert_data: list = pickle.load(f)

    def _train(self, states: torch.Tensor) -> None:
        """
        Overview:
            Helper function for ``train`` which calculates the current policy's state statistics.
        Arguments:
            - states (:obj:`torch.Tensor`): current policy states
        Effects:
            - Update attributes of ``p_u_s`` and ``p_sigma_s``
        """
        # we only need to collect the current policy state
        self.p_u_s = torch.mean(states, axis=0)
        self.p_sigma_s = cov(states, rowvar=False)

    def train(self):
        """
        Overview:
            Training the Pdeil reward model.
        """
        states = torch.stack([item['obs'] for item in self.train_data], dim=0)
        self._train(states)

    def _batch_mn_pdf(self, x: np.ndarray, mean: np.ndarray, cov: np.ndarray) -> np.ndarray:
        """
        Overview:
            Get multivariate normal pdf of given np array.
        """
        return np.asarray(
            self.stats.multivariate_normal.pdf(x, mean=mean, cov=cov, allow_singular=False), dtype=np.float32
        )

    def estimate(self, data: list) -> List[Dict]:
        """
        Overview:
            Estimate reward by rewriting the reward keys.
        Arguments:
            - data (:obj:`list`): the list of data used for estimation,\
                 with at least ``obs`` and ``action`` keys.
        Effects:
            - This is a side effect function which updates the reward values in place.
        """
        # NOTE: deepcopy reward part of data is very important,
        # otherwise the reward of data in the replay buffer will be incorrectly modified.
        train_data_augmented = self.reward_deepcopy(data)
        s = torch.stack([item['obs'] for item in train_data_augmented], dim=0)
        a = torch.stack([item['action'] for item in train_data_augmented], dim=0)
        if self.p_u_s is None:
            logging.warning("The reward model should be trained before calling estimate; rewards are set to zero.")
            for item in train_data_augmented:
                item['reward'].zero_()
        else:
            # rho_1: expert state density; rho_2: current-policy state density.
            rho_1 = self._batch_mn_pdf(s.cpu().numpy(), self.e_u_s.cpu().numpy(), self.e_sigma_s.cpu().numpy())
            rho_1 = torch.from_numpy(rho_1)
            rho_2 = self._batch_mn_pdf(s.cpu().numpy(), self.p_u_s.cpu().numpy(), self.p_sigma_s.cpu().numpy())
            rho_2 = torch.from_numpy(rho_2)
            if self.cfg.discrete_action:
                # ``predict_proba`` returns an (N, n_classes) matrix whose columns follow
                # ``svm.classes_``. Pick each sample's probability of its own taken action
                # via paired (row, column) indexing.
                # NOTE(fix): the original code used ``proba[a]``, which selects whole rows
                # by action value and yields a wrongly-shaped (N, n_classes) tensor.
                proba = self.svm.predict_proba(s.cpu().numpy())
                action_index = a.cpu().numpy().reshape(-1).astype(np.int64)
                # Map raw action values to probability columns (classes_ is sorted ascending).
                col = np.searchsorted(self.svm.classes_, action_index)
                rho_3 = torch.from_numpy(
                    np.asarray(proba[np.arange(proba.shape[0]), col], dtype=np.float32)
                )
            else:
                # Joint expert density over (state, action), normalized by the state density
                # to obtain the conditional density of the action given the state.
                s_a = torch.cat([s, a.float()], dim=-1)
                rho_3 = self._batch_mn_pdf(
                    s_a.cpu().numpy(),
                    self.e_u_s_a.cpu().numpy(),
                    self.e_sigma_s_a.cpu().numpy()
                )
                rho_3 = torch.from_numpy(rho_3)
                rho_3 = rho_3 / rho_1
            alpha = self.cfg.alpha
            beta = 1 - alpha
            den = rho_1 * rho_3
            frac = alpha * rho_1 + beta * rho_2
            if frac.abs().max() < 1e-4:
                # Mixture density is numerically negligible everywhere: avoid dividing by
                # (near) zero and emit zero rewards instead.
                for item in train_data_augmented:
                    item['reward'].zero_()
            else:
                reward = den / frac
                reward = torch.chunk(reward, reward.shape[0], dim=0)
                for item, rew in zip(train_data_augmented, reward):
                    item['reward'] = rew
        return train_data_augmented

    def collect_data(self, item: list):
        """
        Overview:
            Collecting training data by iterating data items in the input list
        Arguments:
            - item (:obj:`list`): Raw training data (e.g. some form of states, actions, obs, etc)
        Effects:
            - This is a side effect function which updates the data attribute in ``self`` by \
                iterating data items in the input data items' list
        """
        self.train_data.extend(item)

    def clear_data(self):
        """
        Overview:
            Clearing training data. \
            This is a side effect function which clears the data attribute in ``self``
        """
        self.train_data.clear()