
ding.reward_model.gail_irl_model

GailRewardModel

Bases: BaseRewardModel

Overview

The Gail reward model class (https://arxiv.org/abs/1606.03476)

Interface: ``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_data``, ``__init__``, ``state_dict``, ``load_state_dict``, ``learn``

Config:

== ==================== ======== ============= =================================== =======================
ID Symbol               Type     Default Value Description                         Other(Shape)
== ==================== ======== ============= =================================== =======================
1  ``type``             str      gail          | RL policy register name, refer    | this arg is optional,
                                               | to registry ``POLICY_REGISTRY``   | a placeholder
2  | ``expert_data_``   str      expert_data.  | Path to the expert dataset        | Should be a '.pkl'
   | ``path``                    .pkl          |                                   | file
3  ``learning_rate``    float    0.001         | The step size of gradient descent |
4  | ``update_per_``    int      100           | Number of updates per collect     |
   | ``collect``                               |                                   |
5  ``batch_size``       int      64            | Training batch size               |
6  ``input_size``       int                    | Size of the input:                |
                                               | obs_dim + act_dim                 |
7  | ``target_new_``    int      64            | Collect steps per iteration       |
   | ``data_count``                            |                                   |
8  ``hidden_size``      int      128           | Linear model hidden size          |
9  ``collect_count``    int      100000        | Expert dataset size               | One entry is a (s,a)
                                               |                                   | tuple
10 | ``clear_buffer_``  int      1             | Clear buffer per fixed iters      | make sure replay
   | ``per_iters``                             |                                   | buffer's data count
                                               |                                   | isn't too few
                                               |                                   | (code works in entry)
== ==================== ======== ============= =================================== =======================
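As a concrete illustration, the defaults in the table can be collected into a config dict (a sketch; in DI-engine this is typically wrapped in an ``EasyDict``). ``input_size`` has no default and must be set to obs_dim + act_dim; the value 5 below is an illustrative assumption.

```python
# Config sketch mirroring the defaults in the table above.
# input_size=5 is an assumed example (e.g. 4-dim obs + 1-dim action).
gail_config = dict(
    type='gail',
    expert_data_path='expert_data.pkl',
    learning_rate=1e-3,
    update_per_collect=100,
    batch_size=64,
    input_size=5,
    target_new_data_count=64,
    hidden_size=128,
    collect_count=100000,
    clear_buffer_per_iters=1,
)
```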

__init__(config, device, tb_logger)

Overview

Initialize self. See help(type(self)) for accurate signature.

Arguments:
- cfg (:obj:`EasyDict`): Training config
- device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
- tb_logger (:obj:`SummaryWriter`): Logger, set to ``SummaryWriter`` by default for model summaries

load_expert_data()

Overview

Load the expert data from the path given by the ``config.data_path`` attribute in self.

Effects: This is a side effect function which updates the expert data attribute (i.e. self.expert_data) with fn:concat_state_action_pairs

learn(train_data, expert_data)

Overview

Helper function for train which calculates loss for train data and expert data.

Arguments:
- train_data (:obj:`torch.Tensor`): Data used for training
- expert_data (:obj:`torch.Tensor`): Expert data

Returns:
- The combined reward-model loss computed from ``train_data`` and ``expert_data``.
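The loss in ``learn`` is the standard GAIL discriminator objective with a small epsilon for numerical stability. A minimal standalone sketch of just the loss computation (the sign convention matches the source below: policy samples are pushed toward 1, expert samples toward 0):

```python
import torch

def gail_discriminator_loss(d_train: torch.Tensor, d_expert: torch.Tensor) -> torch.Tensor:
    """Sketch of the loss in ``learn``; inputs are sigmoid outputs in (0, 1)."""
    loss_train = torch.log(d_train + 1e-8).mean()        # push D(policy) -> 1
    loss_expert = torch.log(1 - d_expert + 1e-8).mean()  # push D(expert) -> 0
    # Both log terms are negative, so minimize their negated sum.
    return -(loss_train + loss_expert)
```

A well-separated discriminator yields a loss near zero, while an uninformative one (both outputs at 0.5) yields `2 * log 2 ≈ 1.386`.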

train()

Overview

Train the GAIL reward model. Training and expert batches of the size designated by the ``batch_size`` attribute in ``self.cfg`` are randomly sampled from the ``train_data`` and ``expert_data`` attributes of ``self``.

Effects: - This is a side effect function which updates the reward model and increments the train iteration count.
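One update step of this loop might be sketched as follows (``learn_fn`` stands in for the model's ``learn`` method; the helper name is illustrative, not part of the API):

```python
import random
import torch

def one_update(learn_fn, expert_data, train_data, batch_size, device='cpu'):
    # Randomly sample, stack into [B, obs_dim + act_dim] batches, move to device.
    expert_batch = torch.stack(random.sample(expert_data, batch_size)).to(device)
    train_batch = torch.stack(random.sample(train_data, batch_size)).to(device)
    return learn_fn(train_batch, expert_batch)
```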

estimate(data)

Overview

Estimate reward by rewriting the reward key in each row of the data.

Arguments:
- data (:obj:`list`): the list of data used for estimation, with at least ``obs`` and ``action`` keys.

Effects:
- This is a side effect function which updates the reward values in place.
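Concretely, each item's reward becomes ``-log(D(s, a) + 1e-8)``, so pairs the discriminator scores as expert-like (small output, under the training convention here) receive larger rewards. An illustrative computation with made-up discriminator outputs:

```python
import torch

# Hypothetical discriminator outputs in (0, 1) for three (s, a) pairs.
d_out = torch.tensor([0.9, 0.5, 0.1])

# estimate() rewrites each reward as -log(D(s, a) + 1e-8).
reward = -torch.log(d_out + 1e-8)
# The smallest discriminator output yields the largest reward.
```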

collect_data(data)

Overview

Collecting training data formatted by fn:concat_state_action_pairs.

Arguments:
- data (:obj:`Any`): Raw training data (e.g. some form of states, actions, obs, etc.)

Effects:
- This is a side effect function which updates the data attribute in ``self``

clear_data()

Overview

Clearing training data. This is a side effect function which clears the data attribute in self

concat_state_action_pairs(iterator)

Overview

Concatenate state and action pairs from input.

Arguments:
- iterator (:obj:`Iterable`): Iterable with at least ``obs`` and ``action`` tensor keys.

Returns:
- res (:obj:`torch.Tensor`): State and action pairs.
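A self-contained sketch of the flatten-and-concatenate behaviour on a toy transition (the 2x2 observation and 1-dim action are made up for illustration):

```python
import torch

def concat_state_action_pairs(iterator):
    # Flatten each obs (this allows 3-D image obs) and append the float-cast action.
    res = []
    for item in iterator:
        state = item['obs'].flatten()
        action = item['action']
        res.append(torch.cat([state, action.float()], dim=-1))
    return res

data = [{'obs': torch.zeros(2, 2), 'action': torch.tensor([1.0])}]
pairs = concat_state_action_pairs(data)  # one tensor of length 2*2 + 1 = 5
```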

concat_state_action_pairs_one_hot(iterator, action_size)

Overview

Concatenate state and action pairs from input. Action values are one-hot encoded.

Arguments:
- iterator (:obj:`Iterable`): Iterable with at least ``obs`` and ``action`` tensor keys.
- action_size (:obj:`int`): Number of discrete actions, i.e. the length of the one-hot encoding.

Returns:
- res (:obj:`torch.Tensor`): State and action pairs.
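The one-hot encoding step in isolation (an illustrative helper, not part of the module's public API):

```python
import torch

def one_hot_action(action: int, action_size: int) -> torch.Tensor:
    # Mirrors the encoding inside concat_state_action_pairs_one_hot:
    # a length-action_size vector with 1.0 at the action index.
    return torch.Tensor([int(i == action) for i in range(action_size)])
```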

Full Source Code

../ding/reward_model/gail_irl_model.py

```python
from typing import List, Dict, Any
import pickle
import random
from collections.abc import Iterable
from easydict import EasyDict

import torch
import torch.nn as nn
import torch.optim as optim

from ding.utils import REWARD_MODEL_REGISTRY
from .base_reward_model import BaseRewardModel
import torch.nn.functional as F
from functools import partial


def concat_state_action_pairs(iterator):
    """
    Overview:
        Concatenate state and action pairs from input.
    Arguments:
        - iterator (:obj:`Iterable`): Iterables with at least ``obs`` and ``action`` tensor keys.
    Returns:
        - res (:obj:`torch.Tensor`): State and action pairs.
    """
    assert isinstance(iterator, Iterable)
    res = []
    for item in iterator:
        state = item['obs'].flatten()  # to allow 3d obs and actions concatenation
        action = item['action']
        s_a = torch.cat([state, action.float()], dim=-1)
        res.append(s_a)
    return res


def concat_state_action_pairs_one_hot(iterator, action_size: int):
    """
    Overview:
        Concatenate state and action pairs from input. Action values are one-hot encoded.
    Arguments:
        - iterator (:obj:`Iterable`): Iterables with at least ``obs`` and ``action`` tensor keys.
        - action_size (:obj:`int`): Number of discrete actions, i.e. the length of the one-hot encoding.
    Returns:
        - res (:obj:`torch.Tensor`): State and action pairs.
    """
    assert isinstance(iterator, Iterable)
    res = []
    for item in iterator:
        state = item['obs'].flatten()  # to allow 3d obs and actions concatenation
        action = item['action']
        action = torch.Tensor([int(i == action) for i in range(action_size)])
        s_a = torch.cat([state, action], dim=-1)
        res.append(s_a)
    return res


class RewardModelNetwork(nn.Module):

    def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None:
        super(RewardModelNetwork, self).__init__()
        self.l1 = nn.Linear(input_size, hidden_size)
        self.l2 = nn.Linear(hidden_size, output_size)
        self.a1 = nn.Tanh()
        self.a2 = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x
        out = self.l1(out)
        out = self.a1(out)
        out = self.l2(out)
        out = self.a2(out)
        return out


class AtariRewardModelNetwork(nn.Module):

    def __init__(self, input_size: int, action_size: int) -> None:
        super(AtariRewardModelNetwork, self).__init__()
        self.input_size = input_size
        self.action_size = action_size
        self.conv1 = nn.Conv2d(4, 16, 7, stride=3)
        self.conv2 = nn.Conv2d(16, 16, 5, stride=2)
        self.conv3 = nn.Conv2d(16, 16, 3, stride=1)
        self.conv4 = nn.Conv2d(16, 16, 3, stride=1)
        self.fc1 = nn.Linear(784, 64)
        self.fc2 = nn.Linear(64 + self.action_size, 1)  # here we add 1 to take consideration of the action concat
        self.a = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # input: x = [B, 4 x 84 x 84 + self.action_size], last element is action
        actions = x[:, -self.action_size:]  # [B, self.action_size]
        # get observations
        x = x[:, :-self.action_size]
        x = x.reshape([-1] + self.input_size)  # [B, 4, 84, 84]
        x = F.leaky_relu(self.conv1(x))
        x = F.leaky_relu(self.conv2(x))
        x = F.leaky_relu(self.conv3(x))
        x = F.leaky_relu(self.conv4(x))
        x = x.reshape(-1, 784)
        x = F.leaky_relu(self.fc1(x))
        x = torch.cat([x, actions], dim=-1)
        x = self.fc2(x)
        r = self.a(x)
        return r


@REWARD_MODEL_REGISTRY.register('gail')
class GailRewardModel(BaseRewardModel):
    """
    Overview:
        The Gail reward model class (https://arxiv.org/abs/1606.03476)
    Interface:
        ``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_data``, \
        ``__init__``, ``state_dict``, ``load_state_dict``, ``learn``
    Config:
        == ==================== ======== ============= =================================== =======================
        ID Symbol               Type     Default Value Description                         Other(Shape)
        == ==================== ======== ============= =================================== =======================
        1  ``type``             str      gail          | RL policy register name, refer    | this arg is optional,
                                                       | to registry ``POLICY_REGISTRY``   | a placeholder
        2  | ``expert_data_``   str      expert_data.  | Path to the expert dataset        | Should be a '.pkl'
           | ``path``                    .pkl          |                                   | file
        3  ``learning_rate``    float    0.001         | The step size of gradient descent |
        4  | ``update_per_``    int      100           | Number of updates per collect     |
           | ``collect``                               |                                   |
        5  ``batch_size``       int      64            | Training batch size               |
        6  ``input_size``       int                    | Size of the input:                |
                                                       | obs_dim + act_dim                 |
        7  | ``target_new_``    int      64            | Collect steps per iteration       |
           | ``data_count``                            |                                   |
        8  ``hidden_size``      int      128           | Linear model hidden size          |
        9  ``collect_count``    int      100000        | Expert dataset size               | One entry is a (s,a)
                                                       |                                   | tuple
        10 | ``clear_buffer_``  int      1             | Clear buffer per fixed iters      | make sure replay
           | ``per_iters``                             |                                   | buffer's data count
                                                       |                                   | isn't too few
                                                       |                                   | (code works in entry)
        == ==================== ======== ============= =================================== =======================
    """
    config = dict(
        # (str) RL policy register name, refer to registry ``POLICY_REGISTRY``.
        type='gail',
        # (float) The step size of gradient descent.
        learning_rate=1e-3,
        # (int) How many updates (iterations) to train after collector's one collection.
        # Bigger "update_per_collect" means bigger off-policy.
        # collect data -> update policy -> collect data -> ...
        update_per_collect=100,
        # (int) How many samples in a training batch.
        batch_size=64,
        # (int) Size of the input: obs_dim + act_dim.
        input_size=4,
        # (int) Collect steps per iteration.
        target_new_data_count=64,
        # (int) Linear model hidden size.
        hidden_size=128,
        # (int) Expert dataset size.
        collect_count=100000,
        # (int) Clear buffer per fixed iters.
        clear_buffer_per_iters=1,
    )

    def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None:  # noqa
        """
        Overview:
            Initialize ``self.`` See ``help(type(self))`` for accurate signature.
        Arguments:
            - cfg (:obj:`EasyDict`): Training config
            - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
            - tb_logger (:obj:`SummaryWriter`): Logger, set to ``SummaryWriter`` by default for model summaries
        """
        super(GailRewardModel, self).__init__()
        self.cfg = config
        assert device in ["cpu", "cuda"] or "cuda" in device
        self.device = device
        self.tb_logger = tb_logger
        obs_shape = config.input_size
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            self.reward_model = RewardModelNetwork(config.input_size, config.hidden_size, 1)
            self.concat_state_action_pairs = concat_state_action_pairs
        elif len(obs_shape) == 3:
            action_shape = self.cfg.action_size
            self.reward_model = AtariRewardModelNetwork(config.input_size, action_shape)
            self.concat_state_action_pairs = partial(concat_state_action_pairs_one_hot, action_size=action_shape)
        self.reward_model.to(self.device)
        self.expert_data = []
        self.train_data = []
        self.expert_data_loader = None
        self.opt = optim.Adam(self.reward_model.parameters(), config.learning_rate)
        self.train_iter = 0

        self.load_expert_data()

    def load_expert_data(self) -> None:
        """
        Overview:
            Load the expert data from the path given by the ``config.data_path`` attribute in self.
        Effects:
            This is a side effect function which updates the expert data attribute \
            (i.e. ``self.expert_data``) with ``fn:concat_state_action_pairs``
        """
        with open(self.cfg.data_path + '/expert_data.pkl', 'rb') as f:
            self.expert_data_loader: list = pickle.load(f)
        self.expert_data = self.concat_state_action_pairs(self.expert_data_loader)

    def state_dict(self) -> Dict[str, Any]:
        return {
            'model': self.reward_model.state_dict(),
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.reward_model.load_state_dict(state_dict['model'])

    def learn(self, train_data: torch.Tensor, expert_data: torch.Tensor) -> float:
        """
        Overview:
            Helper function for ``train`` which calculates loss for train data and expert data.
        Arguments:
            - train_data (:obj:`torch.Tensor`): Data used for training
            - expert_data (:obj:`torch.Tensor`): Expert data
        Returns:
            - The combined reward-model loss computed from ``train_data`` and ``expert_data``.
        """
        # calculate loss, here are some hyper-params
        out_1: torch.Tensor = self.reward_model(train_data)
        loss_1: torch.Tensor = torch.log(out_1 + 1e-8).mean()
        out_2: torch.Tensor = self.reward_model(expert_data)
        loss_2: torch.Tensor = torch.log(1 - out_2 + 1e-8).mean()
        # log(x) with 0<x<1 is negative, so to reduce this loss we have to minimize the opposite
        loss: torch.Tensor = -(loss_1 + loss_2)
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        return loss.item()

    def train(self) -> None:
        """
        Overview:
            Train the GAIL reward model. Training and expert batches of the size designated by the \
            ``batch_size`` attribute in ``self.cfg`` are randomly sampled from the ``train_data`` \
            and ``expert_data`` attributes of ``self``.
        Effects:
            - This is a side effect function which updates the reward model and increments the train iteration count.
        """
        for _ in range(self.cfg.update_per_collect):
            sample_expert_data: list = random.sample(self.expert_data, self.cfg.batch_size)
            sample_train_data: list = random.sample(self.train_data, self.cfg.batch_size)
            sample_expert_data = torch.stack(sample_expert_data).to(self.device)
            sample_train_data = torch.stack(sample_train_data).to(self.device)
            loss = self.learn(sample_train_data, sample_expert_data)
            self.tb_logger.add_scalar('reward_model/gail_loss', loss, self.train_iter)
            self.train_iter += 1

    def estimate(self, data: list) -> List[Dict]:
        """
        Overview:
            Estimate reward by rewriting the reward key in each row of the data.
        Arguments:
            - data (:obj:`list`): the list of data used for estimation, with at least \
              ``obs`` and ``action`` keys.
        Effects:
            - This is a side effect function which updates the reward values in place.
        """
        # NOTE: deepcopy reward part of data is very important,
        # otherwise the reward of data in the replay buffer will be incorrectly modified.
        train_data_augmented = self.reward_deepcopy(data)
        res = self.concat_state_action_pairs(train_data_augmented)
        res = torch.stack(res).to(self.device)
        with torch.no_grad():
            reward = self.reward_model(res).squeeze(-1).cpu()
        reward = torch.chunk(reward, reward.shape[0], dim=0)
        for item, rew in zip(train_data_augmented, reward):
            item['reward'] = -torch.log(rew + 1e-8)

        return train_data_augmented

    def collect_data(self, data: list) -> None:
        """
        Overview:
            Collecting training data formatted by ``fn:concat_state_action_pairs``.
        Arguments:
            - data (:obj:`Any`): Raw training data (e.g. some form of states, actions, obs, etc.)
        Effects:
            - This is a side effect function which updates the data attribute in ``self``
        """
        self.train_data.extend(self.concat_state_action_pairs(data))

    def clear_data(self) -> None:
        """
        Overview:
            Clearing training data. \
            This is a side effect function which clears the data attribute in ``self``
        """
        self.train_data.clear()
```