ding.reward_model.pwil_irl_model

PwilRewardModel

Bases: BaseRewardModel

Overview

The PWIL (Primal Wasserstein Imitation Learning) reward model class (https://arxiv.org/pdf/2006.04678.pdf)

Interface: ``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_data``, ``__init__``, ``_train``, ``_get_state_distance``, ``_get_action_distance``

Config:

== ==================== ===== ============= ======================================== =======================
ID Symbol               Type  Default Value Description                              Other(Shape)
== ==================== ===== ============= ======================================== =======================
1  ``type``             str   pwil          | Reward model register name, refer     |
                                            | to registry ``REWARD_MODEL_REGISTRY`` |
2  | ``expert_data_``   str   expert_data.  | Path to the expert dataset            | Should be a '.pkl'
   | ``path``                 .pkl          |                                       | file
3  ``sample_size``      int   1000          | Sample data from expert dataset       |
                                            | with fixed size                       |
4  ``alpha``            int   5             | Factor alpha                          |
5  ``beta``             int   5             | Factor beta                           |
6  ``s_size``           int   4             | State size                            |
7  ``a_size``           int   2             | Action size                           |
8  | ``clear_buffer``   int   1             | Clear buffer per fixed iters          | Make sure replay
   | ``_per_iters``                                                                 | buffer's data count
                                                                                    | isn't too few
                                                                                    | (code works in entry)
== ==================== ===== ============= ======================================== =======================

Properties:
- reward_table (:obj:`Dict`): In this algorithm, the reward model is a dictionary.
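The ``alpha``, ``beta``, ``s_size`` and ``a_size`` entries above combine into the PWIL reward ``r = alpha * exp(-(beta * T / sqrt(s_size + a_size)) * c_i)``, where ``T`` is the number of collected state-action pairs and ``c_i`` the accumulated transport cost of one pair. A minimal sketch with the default config values; ``T`` and ``c_i`` here are hypothetical:

```python
import math

# Defaults from the config table above
alpha, beta = 5, 5
s_size, a_size = 4, 2
T = 1000      # number of collected state-action pairs (assumed)
c_i = 0.01    # transport cost accumulated for one pair (assumed)

# reward_factor mirrors self.reward_factor in collect_data
reward_factor = -beta * T / math.sqrt(s_size + a_size)
reward = alpha * math.exp(reward_factor * c_i)
```

A zero transport cost yields the maximum reward ``alpha``; larger costs shrink the reward exponentially toward zero.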

__init__(config, device, tb_logger)

Overview

Initialize self. See help(type(self)) for accurate signature.

Arguments:
- cfg (:obj:`Dict`): Training config
- device (:obj:`str`): Device to use, i.e. "cpu" or "cuda"
- tb_logger (:obj:`SummaryWriter`): Logger for model summary, a ``SummaryWriter`` by default

load_expert_data()

Overview

Load the expert data from the path given by ``config['expert_data_path']``.

Effects: This is a side-effect function which updates the expert data attribute (``self.expert_data``); in this algorithm, ``self.expert_s`` and ``self.expert_a`` (states and actions) are also updated.
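The expected on-disk format is a pickled list of dicts carrying at least ``obs`` and ``action`` keys (tensors in practice; plain ints stand in here). A self-contained sketch of the loading logic under those assumptions:

```python
import os
import pickle
import random
import tempfile

# Hypothetical expert dataset written the way load_expert_data expects it:
# a pickled list of dicts with at least 'obs' and 'action'.
expert = [{'obs': i, 'action': i % 2} for i in range(10)]
path = os.path.join(tempfile.mkdtemp(), 'expert_data.pkl')
with open(path, 'wb') as f:
    pickle.dump(expert, f)

# Mirror of the loading steps: load, subsample, convert to tuples, unzip.
with open(path, 'rb') as f:
    data = pickle.load(f)
sample_size = min(1000, len(data))          # cfg.sample_size default is 1000
data = random.sample(data, sample_size)
pairs = [(item['obs'], item['action']) for item in data]
expert_s, expert_a = zip(*pairs)            # self.expert_s / self.expert_a
```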

collect_data(data)

Overview

Collecting training data formatted by ``collect_state_action_pairs``.

Arguments:
- data (:obj:`list`): Raw training data (e.g. some form of states, actions, obs, etc.)
Effects:
- This is a side-effect function which extends the ``train_data`` attribute in ``self`` and updates ``self.T``; in this algorithm, the ``s_size`` and ``a_size`` entries of the ``self.cfg`` dict are used, and ``self.reward_factor`` is recomputed each time ``collect_data`` is called.
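The bookkeeping described above can be sketched without the class machinery; the batch values here are hypothetical, the names mirror the source:

```python
import math

# Sketch of collect_data's bookkeeping with an assumed batch of
# (state, action) pairs such as collect_state_action_pairs would produce.
train_data = []
new_batch = [(0.0, 1), (0.5, 0), (1.0, 1)]

train_data.extend(new_batch)
T = len(train_data)                 # self.T: total collected pairs
s_size, a_size, beta = 4, 2, 5      # defaults from the config table
reward_factor = -beta * T / math.sqrt(s_size + a_size)
```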

train()

Overview

Training the PWIL reward model.

estimate(data)

Overview

Estimate reward by rewriting the reward key in each row of the data.

Arguments:
- data (:obj:`list`): The list of data used for estimation, with at least ``obs`` and ``action`` keys.
Effects:
- This function rewrites the ``reward`` key of each item using the ``reward_table`` entry for its ``(obs, action)`` tuple, defaulting to zero for pairs that were never trained; a deep copy of the reward field is returned so replay buffer data is not modified in place.
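The lookup itself is a plain dictionary access with a zero fallback. A minimal sketch, with floats standing in for tensors and a hypothetical pre-filled table:

```python
import copy

# reward_table maps (obs, action) -> reward; values here are made up.
reward_table = {(0.0, 1): 4.2, (1.0, 0): 3.1}

def estimate(data):
    # Deep-copy first so rows living in a replay buffer are never mutated.
    out = copy.deepcopy(data)
    for item in out:
        key = (item['obs'], item['action'])
        # Unseen (obs, action) pairs fall back to the default reward 0.
        item['reward'] = reward_table.get(key, 0.0)
    return out

batch = [{'obs': 0.0, 'action': 1, 'reward': None},
         {'obs': 2.0, 'action': 1, 'reward': None}]
scored = estimate(batch)
```

The real method hashes tensor objects as dict keys and writes ``torch.zeros_like`` for misses; the control flow is the same.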

clear_data()

Overview

Clearing the training data. This is a side-effect function which clears the ``train_data`` attribute in ``self``.

collect_state_action_pairs(iterator)

Overview

Concatenate state and action pairs from the input iterator.

Arguments:
- iterator (:obj:`Iterable`): Iterable with at least ``obs`` and ``action`` tensor keys.
Returns:
- res (:obj:`list`): List of ``(state, action)`` pairs.
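Despite its name, the helper only pairs up ``obs`` and ``action`` (the ``torch.cat`` concatenation in the source is commented out), so the result is a list of tuples. A usage sketch with plain values standing in for tensors:

```python
# Mirror of collect_state_action_pairs: pair 'obs' with 'action' per item.
def collect_state_action_pairs(iterator):
    res = []
    for item in iterator:
        res.append((item['obs'], item['action']))
    return res

rollout = [{'obs': 0.1, 'action': 0}, {'obs': 0.2, 'action': 1}]
pairs = collect_state_action_pairs(rollout)
```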

Full Source Code

../ding/reward_model/pwil_irl_model.py

from typing import Dict, List
import math
import random
import pickle
import torch

from ding.utils import REWARD_MODEL_REGISTRY
from .base_reward_model import BaseRewardModel


def collect_state_action_pairs(iterator):
    # concat state and action
    """
    Overview:
        Concatenate state and action pairs from the input iterator.
    Arguments:
        - iterator (:obj:`Iterable`): Iterable with at least ``obs`` and ``action`` tensor keys.
    Returns:
        - res (:obj:`list`): List of ``(state, action)`` pairs.
    """
    res = []
    for item in iterator:
        state = item['obs']
        action = item['action']
        # s_a = torch.cat([state, action.float()], dim=-1)
        res.append((state, action))
    return res


@REWARD_MODEL_REGISTRY.register('pwil')
class PwilRewardModel(BaseRewardModel):
    """
    Overview:
        The PWIL reward model class (https://arxiv.org/pdf/2006.04678.pdf)
    Interface:
        ``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_data``, \
        ``__init__``, ``_train``, ``_get_state_distance``, ``_get_action_distance``
    Config:
        == ==================== ===== ============= ======================================== =======================
        ID Symbol               Type  Default Value Description                              Other(Shape)
        == ==================== ===== ============= ======================================== =======================
        1  ``type``             str   pwil          | Reward model register name, refer     |
                                                    | to registry ``REWARD_MODEL_REGISTRY`` |
        2  | ``expert_data_``   str   expert_data.  | Path to the expert dataset            | Should be a '.pkl'
           | ``path``                 .pkl          |                                       | file
        3  ``sample_size``      int   1000          | Sample data from expert dataset       |
                                                    | with fixed size                       |
        4  ``alpha``            int   5             | Factor alpha                          |
        5  ``beta``             int   5             | Factor beta                           |
        6  ``s_size``           int   4             | State size                            |
        7  ``a_size``           int   2             | Action size                           |
        8  | ``clear_buffer``   int   1             | Clear buffer per fixed iters          | Make sure replay
           | ``_per_iters``                                                                 | buffer's data count
                                                                                            | isn't too few
                                                                                            | (code works in entry)
        == ==================== ===== ============= ======================================== =======================
    Properties:
        - reward_table (:obj:`Dict`): In this algorithm, the reward model is a dictionary.
    """
    config = dict(
        # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
        type='pwil',
        # (str) Path to the expert dataset.
        # expert_data_path='expert_data.pkl',
        # (int) Sample data from expert dataset with fixed size.
        sample_size=1000,
        # r = alpha * exp((-beta * T / sqrt(|s_size| + |a_size|)) * c_i)
        # The key idea of this reward is to minimize
        # the Wasserstein distance between the state-action distributions.
        # (int) Factor alpha.
        alpha=5,
        # (int) Factor beta.
        beta=5,
        # (int) State size.
        # s_size=4,
        # (int) Action size.
        # a_size=2,
        # (int) Clear buffer per fixed iters.
        clear_buffer_per_iters=1,
    )

    def __init__(self, config: Dict, device: str, tb_logger: 'SummaryWriter') -> None:  # noqa
        """
        Overview:
            Initialize ``self.`` See ``help(type(self))`` for accurate signature.
        Arguments:
            - cfg (:obj:`Dict`): Training config
            - device (:obj:`str`): Device to use, i.e. "cpu" or "cuda"
            - tb_logger (:obj:`SummaryWriter`): Logger for model summary, a ``SummaryWriter`` by default
        """
        super(PwilRewardModel, self).__init__()
        self.cfg: Dict = config
        assert device in ["cpu", "cuda"] or "cuda" in device
        self.device = device
        self.expert_data: List[tuple] = []
        self.train_data: List[tuple] = []
        # In this algo, model is a dict
        self.reward_table: Dict = {}
        self.T: int = 0

        self.load_expert_data()

    def load_expert_data(self) -> None:
        """
        Overview:
            Load the expert data from the path given by ``config['expert_data_path']``.
        Effects:
            This is a side-effect function which updates the expert data attribute \
            (``self.expert_data``); in this algorithm, ``self.expert_s`` and \
            ``self.expert_a`` (states and actions) are also updated.
        """
        with open(self.cfg.expert_data_path, 'rb') as f:
            self.expert_data = pickle.load(f)
        print("the data size is:", len(self.expert_data))
        sample_size = min(self.cfg.sample_size, len(self.expert_data))
        self.expert_data = random.sample(self.expert_data, sample_size)
        self.expert_data = [(item['obs'], item['action']) for item in self.expert_data]
        self.expert_s, self.expert_a = list(zip(*self.expert_data))
        print('the expert data demonstrations is:', len(self.expert_data))

    def collect_data(self, data: list) -> None:
        """
        Overview:
            Collecting training data formatted by ``collect_state_action_pairs``.
        Arguments:
            - data (:obj:`list`): Raw training data (e.g. some form of states, actions, obs, etc.)
        Effects:
            - This is a side-effect function which extends the ``train_data`` attribute in \
              ``self`` and updates ``self.T``; the ``s_size`` and ``a_size`` entries of the \
              ``self.cfg`` dict are used, and ``self.reward_factor`` is recomputed each time \
              ``collect_data`` is called.
        """
        self.train_data.extend(collect_state_action_pairs(data))
        self.T = len(self.train_data)

        s_size = self.cfg.s_size
        a_size = self.cfg.a_size
        beta = self.cfg.beta
        self.reward_factor = -beta * self.T / math.sqrt(s_size + a_size)

    def train(self) -> None:
        """
        Overview:
            Training the PWIL reward model.
        """
        self._train(self.train_data)

    def estimate(self, data: list) -> List[Dict]:
        """
        Overview:
            Estimate reward by rewriting the reward key in each row of the data.
        Arguments:
            - data (:obj:`list`): The list of data used for estimation, \
              with at least ``obs`` and ``action`` keys.
        Effects:
            - This function rewrites the ``reward`` key of each item using the ``reward_table`` \
              entry for its ``(obs, action)`` tuple, defaulting to zero for untrained pairs.
        """
        # NOTE: deepcopy reward part of data is very important,
        # otherwise the reward of data in the replay buffer will be incorrectly modified.
        train_data_augmented = self.reward_deepcopy(data)
        for item in train_data_augmented:
            s = item['obs']
            a = item['action']
            if (s, a) in self.reward_table:
                item['reward'] = self.reward_table[(s, a)]
            else:
                # when (s, a) pair is not trained, set the reward value to default value (e.g. 0)
                item['reward'] = torch.zeros_like(item['reward'])
        return train_data_augmented

    def _get_state_distance(self, s1: list, s2: list) -> torch.Tensor:
        """
        Overview:
            Getting distances of states given 2 state lists. One single state \
            is of shape ``torch.Size([n])`` (``n`` referred to in in-code comments)
        Arguments:
            - s1 (:obj:`torch.Tensor list`): the 1st states' list of size M
            - s2 (:obj:`torch.Tensor list`): the 2nd states' list of size N
        Returns:
            - distance (:obj:`torch.Tensor`): distance tensor of \
              the state tensor lists, of size M x N.
        """
        # Format the values in the tensors to be of float type
        s1 = torch.stack(s1).float()
        s2 = torch.stack(s2).float()
        M, N = s1.shape[0], s2.shape[0]
        # Automatically fill in length
        s1 = s1.view(M, -1)
        s2 = s2.view(N, -1)
        # Automatically fill in & format the tensor size to be (M x N x n)
        s1 = s1.unsqueeze(1).repeat(1, N, 1)
        s2 = s2.unsqueeze(0).repeat(M, 1, 1)
        # Return the distance tensor of size M x N
        return ((s1 - s2) ** 2).mean(dim=-1)

    def _get_action_distance(self, a1: list, a2: list) -> torch.Tensor:
        # TODO the metric of action distance may differ between envs
        """
        Overview:
            Getting distances of actions given 2 action lists. One single action \
            is of shape ``torch.Size([n])`` (``n`` referred to in in-code comments)
        Arguments:
            - a1 (:obj:`torch.Tensor list`): the 1st actions' list of size M
            - a2 (:obj:`torch.Tensor list`): the 2nd actions' list of size N
        Returns:
            - distance (:obj:`torch.Tensor`): distance tensor of \
              the action tensor lists, of size M x N.
        """
        a1 = torch.stack(a1).float()
        a2 = torch.stack(a2).float()
        M, N = a1.shape[0], a2.shape[0]
        a1 = a1.view(M, -1)
        a2 = a2.view(N, -1)
        a1 = a1.unsqueeze(1).repeat(1, N, 1)
        a2 = a2.unsqueeze(0).repeat(M, 1, 1)
        return ((a1 - a2) ** 2).mean(dim=-1)

    def _train(self, data: list):
        """
        Overview:
            Helper function for ``train``; finds the minimum-distance ``s_e``, ``a_e``.
        Arguments:
            - data (:obj:`list`): Raw training data (e.g. some form of states, actions, obs, etc.)
        Effects:
            - This is a side-effect function which updates the ``reward_table`` attribute in ``self``.
        """
        batch_s, batch_a = list(zip(*data))
        s_distance_matrix = self._get_state_distance(batch_s, self.expert_s)
        a_distance_matrix = self._get_action_distance(batch_a, self.expert_a)
        distance_matrix = s_distance_matrix + a_distance_matrix
        w_e_list = [1 / len(self.expert_data)] * len(self.expert_data)
        for i, item in enumerate(data):
            s, a = item
            w_pi = 1 / self.T
            c = 0
            expert_data_idx = torch.arange(len(self.expert_data)).tolist()
            while w_pi > 0:
                selected_dist = distance_matrix[i, expert_data_idx]
                nearest_distance = selected_dist.min().item()
                nearest_index_selected = selected_dist.argmin().item()
                nearest_index = expert_data_idx[nearest_index_selected]
                if w_pi >= w_e_list[nearest_index]:
                    c = c + nearest_distance * w_e_list[nearest_index]
                    w_pi = w_pi - w_e_list[nearest_index]
                    expert_data_idx.pop(nearest_index_selected)
                else:
                    c = c + w_pi * nearest_distance
                    w_e_list[nearest_index] = w_e_list[nearest_index] - w_pi
                    w_pi = 0
            reward = self.cfg.alpha * math.exp(self.reward_factor * c)
            self.reward_table[(s, a)] = torch.FloatTensor([reward])

    def clear_data(self) -> None:
        """
        Overview:
            Clearing the training data. This is a side-effect function \
            which clears the ``train_data`` attribute in ``self``.
        """
        self.train_data.clear()
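The inner loop of ``_train`` is the greedy coupling from the PWIL paper: each agent sample carries mass ``1/T`` and greedily consumes expert atoms of mass ``1/len(expert_data)`` in order of increasing distance, accumulating the transport cost ``c``. A torch-free sketch of that loop; the distances, weights, and ``reward_factor`` below are hypothetical toy values:

```python
import math

def greedy_cost(distances, w_pi, w_e_list):
    """Accumulate transport cost c by greedily draining the nearest expert atoms."""
    w_e = list(w_e_list)            # remaining expert masses
    idx = list(range(len(w_e)))     # expert atoms still available
    c = 0.0
    while w_pi > 0:
        # Nearest remaining expert atom
        nearest_sel = min(range(len(idx)), key=lambda j: distances[idx[j]])
        nearest = idx[nearest_sel]
        d = distances[nearest]
        if w_pi >= w_e[nearest]:
            # Drain the whole atom and remove it from the pool
            c += d * w_e[nearest]
            w_pi -= w_e[nearest]
            idx.pop(nearest_sel)
        else:
            # Take only the remaining agent mass from this atom
            c += w_pi * d
            w_e[nearest] -= w_pi
            w_pi = 0
    return c

# Two expert atoms of mass 0.5 each; the agent sample carries mass 0.75,
# so it drains the nearest atom (d=1.0) and takes 0.25 from the next (d=2.0).
c = greedy_cost([1.0, 2.0], w_pi=0.75, w_e_list=[0.5, 0.5])
reward = 5 * math.exp(-0.1 * c)   # alpha=5 with an assumed reward_factor of -0.1
```

The real loop does the same bookkeeping over rows of the precomputed ``distance_matrix`` and stores each reward in ``self.reward_table``.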