
ding.reward_model.guided_cost_reward_model


GuidedCostRewardModel

Bases: BaseRewardModel

Overview

Reward model class of the Guided Cost Learning algorithm. (https://arxiv.org/pdf/1603.00448.pdf)
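The cost network is trained with the importance-sampled IOC objective from the paper, which corresponds term by term to the `loss_IOC` computed in `train` in the source listing below (a sketch; here `q` denotes the sampling policy's probability of the sampled action, stored as `prob`):

```latex
\mathcal{L}_{\mathrm{IOC}}(\theta)
  = \frac{1}{N_{\mathrm{demo}}} \sum_{\tau \in \mathcal{D}_{\mathrm{demo}}} c_\theta(\tau)
  + \log \frac{1}{N_{\mathrm{samp}}} \sum_{\tau \in \mathcal{D}_{\mathrm{samp}}}
      \frac{\exp\bigl(-c_\theta(\tau)\bigr)}{q(\tau)}
```

Minimizing this drives the learned cost down on expert demonstrations and up on policy samples, with the importance weights `1/q` correcting for sampling from the current policy rather than the optimal one.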

Interface:
``estimate``, ``train``, ``collect_data``, ``clear_data``, ``__init__``, ``state_dict``, ``load_state_dict``, ``learn``, ``state_dict_reward_model``, ``load_state_dict_reward_model``

Config:

== ==================== ======== ============= ======================================== ================
ID Symbol               Type     Default Value Description                              Other(Shape)
== ==================== ======== ============= ======================================== ================
1  ``type``             str      guided_cost   | Reward model register name, refer
                                               | to registry ``REWARD_MODEL_REGISTRY``
2  ``continuous``       bool     True          | Whether action is continuous
3  ``learning_rate``    float    0.001         | Learning rate for optimizer
4  | ``update_per_``    int      100           | Number of updates per collect
   | ``collect``
5  ``batch_size``       int      64            | Training batch size
6  ``hidden_size``      int      128           | Linear model hidden size
7  ``action_shape``     int      1             | Action space shape
8  | ``log_every_n``    int      50            | Add loss to log every n iterations
   | ``_train``
9  | ``store_model_``   int      100           | Save model every n iterations
   | ``every_n_train``
== ==================== ======== ============= ======================================== ================
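As a quick reference, the defaults in the table above can be written as a plain Python dict mirroring the class's ``config`` attribute (a sketch for illustration; field names and values are taken directly from the table):

```python
# Default configuration for GuidedCostRewardModel, mirroring the table above.
guided_cost_default_config = dict(
    type='guided_cost',             # reward model register name
    continuous=True,                # whether the action space is continuous
    learning_rate=1e-3,             # optimizer step size
    update_per_collect=100,         # gradient updates per data collection
    batch_size=64,                  # training batch size
    hidden_size=128,                # hidden width of the MLP cost network
    action_shape=1,                 # action space shape
    log_every_n_train=50,           # log the loss every n training iterations
    store_model_every_n_train=100,  # checkpoint every n training iterations
)
```

In practice these defaults are merged with a user config (DI-engine wraps them in an ``EasyDict``), so only the fields you want to override need to be specified.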

collect_data(data)

Overview

Collecting training data; not implemented because the reward model (i.e. online_net) is only trained once. If online_net were trained continuously, the collect_data method would need an implementation.

clear_data()

Overview

Clearing collected data; not implemented because the reward model (i.e. online_net) is only trained once. If online_net were trained continuously, the clear_data method would need an implementation.

Full Source Code

../ding/reward_model/guided_cost_reward_model.py

```python
from typing import List, Dict, Any
from easydict import EasyDict

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Independent, Normal

from ding.utils import REWARD_MODEL_REGISTRY
from ding.utils.data import default_collate
from .base_reward_model import BaseRewardModel


class GuidedCostNN(nn.Module):

    def __init__(
        self,
        input_size,
        hidden_size=128,
        output_size=1,
    ):
        super(GuidedCostNN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, x):
        return self.net(x)


@REWARD_MODEL_REGISTRY.register('guided_cost')
class GuidedCostRewardModel(BaseRewardModel):
    """
    Overview:
        Reward model class of the Guided Cost Learning algorithm. (https://arxiv.org/pdf/1603.00448.pdf)
    Interface:
        ``estimate``, ``train``, ``collect_data``, ``clear_data``, \
        ``__init__``, ``state_dict``, ``load_state_dict``, ``learn``, \
        ``state_dict_reward_model``, ``load_state_dict_reward_model``
    Config:
        == ==================== ======== ============= ======================================== ================
        ID Symbol               Type     Default Value Description                              Other(Shape)
        == ==================== ======== ============= ======================================== ================
        1  ``type``             str      guided_cost   | Reward model register name, refer
                                                       | to registry ``REWARD_MODEL_REGISTRY``
        2  ``continuous``       bool     True          | Whether action is continuous
        3  ``learning_rate``    float    0.001         | Learning rate for optimizer
        4  | ``update_per_``    int      100           | Number of updates per collect
           | ``collect``
        5  ``batch_size``       int      64            | Training batch size
        6  ``hidden_size``      int      128           | Linear model hidden size
        7  ``action_shape``     int      1             | Action space shape
        8  | ``log_every_n``    int      50            | Add loss to log every n iterations
           | ``_train``
        9  | ``store_model_``   int      100           | Save model every n iterations
           | ``every_n_train``
        == ==================== ======== ============= ======================================== ================
    """

    config = dict(
        # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
        type='guided_cost',
        # (float) The step size of gradient descent.
        learning_rate=1e-3,
        # (int) Action space shape, such as 1.
        action_shape=1,
        # (bool) Whether action is continuous.
        continuous=True,
        # (int) How many samples in a training batch.
        batch_size=64,
        # (int) Linear model hidden size.
        hidden_size=128,
        # (int) How many updates (iterations) to train after the collector's one collection.
        # Bigger "update_per_collect" means bigger off-policy.
        # collect data -> update policy -> collect data -> ...
        update_per_collect=100,
        # (int) Add loss to log every n iterations.
        log_every_n_train=50,
        # (int) Save model every n iterations.
        store_model_every_n_train=100,
    )

    def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None:  # noqa
        super(GuidedCostRewardModel, self).__init__()
        self.cfg = config
        self.action_shape = self.cfg.action_shape
        assert device == "cpu" or device.startswith("cuda")
        self.device = device
        self.tb_logger = tb_logger
        self.reward_model = GuidedCostNN(config.input_size, config.hidden_size)
        self.reward_model.to(self.device)
        self.opt = optim.Adam(self.reward_model.parameters(), lr=config.learning_rate)

    def train(self, expert_demo: torch.Tensor, samp: torch.Tensor, iter, step):
        device_0 = expert_demo[0]['obs'].device
        device_1 = samp[0]['obs'].device
        # Expert transitions are treated as having probability 1.
        for i in range(len(expert_demo)):
            expert_demo[i]['prob'] = torch.FloatTensor([1]).to(device_0)
        if self.cfg.continuous:
            for i in range(len(samp)):
                (mu, sigma) = samp[i]['logit']
                dist = Independent(Normal(mu, sigma), 1)
                next_action = samp[i]['action']
                log_prob = dist.log_prob(next_action)
                samp[i]['prob'] = torch.exp(log_prob).unsqueeze(0).to(device_1)
        else:
            for i in range(len(samp)):
                probs = F.softmax(samp[i]['logit'], dim=-1)
                prob = probs[samp[i]['action']]
                samp[i]['prob'] = prob.to(device_1)
        # Mix the expert data and sampled data to train the reward model.
        samp.extend(expert_demo)
        expert_demo = default_collate(expert_demo)
        samp = default_collate(samp)
        cost_demo = self.reward_model(
            torch.cat([expert_demo['obs'], expert_demo['action'].float().reshape(-1, self.action_shape)], dim=-1)
        )
        cost_samp = self.reward_model(
            torch.cat([samp['obs'], samp['action'].float().reshape(-1, self.action_shape)], dim=-1)
        )

        prob = samp['prob'].unsqueeze(-1)
        # Importance-sampled IOC loss: mean expert cost plus log of the
        # importance-weighted partition-function estimate over samples.
        loss_IOC = torch.mean(cost_demo) + \
            torch.log(torch.mean(torch.exp(-cost_samp) / (prob + 1e-7)))
        # Update the cost function.
        self.opt.zero_grad()
        loss_IOC.backward()
        self.opt.step()
        if iter % self.cfg.log_every_n_train == 0:
            self.tb_logger.add_scalar('reward_model/loss_iter', loss_IOC, iter)
            self.tb_logger.add_scalar('reward_model/loss_step', loss_IOC, step)

    def estimate(self, data: list) -> List[Dict]:
        # NOTE: this estimate method of the GCL algorithm is a little different from the one in other
        # IRL algorithms, because its deepcopy is operated before the learner's train loop.
        train_data_augmented = data
        for i in range(len(train_data_augmented)):
            with torch.no_grad():
                reward = self.reward_model(
                    torch.cat([train_data_augmented[i]['obs'], train_data_augmented[i]['action'].float()]).unsqueeze(0)
                ).squeeze(0)
                # The reward is the negated learned cost.
                train_data_augmented[i]['reward'] = -reward

        return train_data_augmented

    def collect_data(self, data) -> None:
        """
        Overview:
            Collecting training data; not implemented because the reward model (i.e. online_net) is \
            only trained once. If online_net were trained continuously, the collect_data method \
            would need an implementation.
        """
        pass

    def clear_data(self):
        """
        Overview:
            Clearing collected data; not implemented because the reward model (i.e. online_net) is \
            only trained once. If online_net were trained continuously, the clear_data method \
            would need an implementation.
        """
        pass

    def state_dict_reward_model(self) -> Dict[str, Any]:
        return {
            'model': self.reward_model.state_dict(),
            'optimizer': self.opt.state_dict(),
        }

    def load_state_dict_reward_model(self, state_dict: Dict[str, Any]) -> None:
        self.reward_model.load_state_dict(state_dict['model'])
        self.opt.load_state_dict(state_dict['optimizer'])
```
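To see the cost network in isolation, the snippet below instantiates ``GuidedCostNN`` (redefined inline so the sketch is self-contained) on a toy observation-action batch, mirroring how ``train`` concatenates ``obs`` and ``action`` before the forward pass. The batch size and dimensions here are illustrative assumptions, not values from the library:

```python
import torch
import torch.nn as nn


class GuidedCostNN(nn.Module):
    """Small MLP cost network, copied from the listing above."""

    def __init__(self, input_size, hidden_size=128, output_size=1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, x):
        return self.net(x)


# Illustrative shapes: 4-dim observations, 1-dim continuous actions, batch of 8.
obs = torch.randn(8, 4)
action = torch.randn(8, 1)
model = GuidedCostNN(input_size=obs.shape[1] + action.shape[1])
# One scalar cost per (obs, action) pair; estimate() would negate this as the reward.
cost = model(torch.cat([obs, action], dim=-1))
print(cost.shape)  # torch.Size([8, 1])
```

Because the network input is the concatenated observation-action vector, ``input_size`` must equal the observation dimension plus the action dimension, which is what the class's ``config.input_size`` is expected to carry.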