Skip to content

ding.policy.il

ding.policy.il

ILPolicy

Bases: Policy

Overview

Policy class of the imitation learning (IL) algorithm.

Interface: `__init__`, `set_setting`, `__repr__`, `state_dict_handle`. Property: `learn_mode`, `collect_mode`, `eval_mode`.

Full Source Code

../ding/policy/il.py

from typing import List, Dict, Any, Tuple, Union
from collections import namedtuple
import torch
import torch.nn as nn

from ding.torch_utils import Adam, to_device
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy

# Model class instantiated (with no arguments) as the collect-mode model in ``ILPolicy._init_collect``.
# Its output ``logit`` becomes the supervised target for the learn model.
# It is ``None`` in this module. NOTE(review): presumably it must be assigned/injected from elsewhere
# before collect mode is used; confirm.
FootballKaggle5thPlaceModel = None


@POLICY_REGISTRY.register('IL')
class ILPolicy(Policy):
    r"""
    Overview:
        Policy class of the imitation learning (IL) algorithm.
        In collect mode, actions and logits come from ``FootballKaggle5thPlaceModel``.
        In learn mode, ``self._model`` is trained to regress those logits with an MSE loss.
    Interface:
        __init__, set_setting, __repr__, state_dict_handle
    Property:
        learn_mode, collect_mode, eval_mode
    """
    config = dict(
        type='IL',
        cuda=True,
        # (bool) whether use on-policy training pipeline(behaviour policy and training policy are the same)
        on_policy=False,
        priority=False,
        # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        learn=dict(
            # (int) collect n_episode data, train model n_iteration time
            update_per_collect=20,
            # (int) the number of data for a train iteration
            batch_size=64,
            # (float) gradient-descent step size
            learning_rate=0.0002,
        ),
        collect=dict(
            # (int) collect n_sample data, train model n_iteration time
            # n_sample=128,
            # (float) discount factor for future reward, in [0, 1]
            discount_factor=0.99,
        ),
        eval=dict(evaluator=dict(eval_freq=800, ), ),
        other=dict(
            replay_buffer=dict(
                replay_buffer_size=100000,
                # (int) max use count of data, if count is bigger than this value,
                # the data will be removed from buffer
                max_reuse=10,
            ),
            command=dict(),
        ),
    )

    # TODO different collect model and learn model
    def default_model(self) -> Tuple[str, List[str]]:
        r"""
        Overview:
            Return the default model name and the list of module paths to import for it.
        """
        return 'football_iql', ['dizoo.gfootball.model.iql.iql_network']

    def _init_learn(self) -> None:
        r"""
        Overview:
            Learn mode init method. Called by ``self.__init__``.
            Init the optimizer and the learn model (``self._model`` wrapped with the ``base`` wrapper).
        """
        # A single Adam optimizer over all parameters of ``self._model``.
        self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)

        # Learn model; no separate target model is used by this policy.
        self._learn_model = model_wrap(self._model, wrapper_name='base')
        self._learn_model.train()
        self._learn_model.reset()

        self._forward_learn_cnt = 0  # count iterations

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Forward and backward function of learn mode: one supervised step regressing the collect
            model's logits.
        Arguments:
            - data (:obj:`dict`): Training batch passed to ``default_collate``. After collation it must \
                contain ``obs`` (with a tensor ``obs['processed_obs']``) and ``logit`` (the target logits). \
                NOTE(review): presumably a list of samples from ``_get_train_sample``; confirm with caller.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): ``cur_lr`` and ``supervised_loss`` (a Python float).
        """
        data = default_collate(data, cat_1dim=False)
        # NOTE(review): 'done' is overwritten with None and not read below; purpose unclear, kept as-is.
        data['done'] = None
        if self._cuda:
            data = to_device(data, self._device)
        # ====================
        # imitation learn forward
        # ====================
        # Index directly so a missing key fails here with a KeyError instead of a confusing error later.
        obs = data['obs']
        target_logit = data['logit']
        assert isinstance(obs['processed_obs'], torch.Tensor), obs['processed_obs']
        model_action_logit = self._learn_model.forward(obs['processed_obs'])['logit']
        # Default 'mean' reduction: identical to reduction='none' followed by .mean().
        supervised_loss = nn.functional.mse_loss(model_action_logit, target_logit)
        self._optimizer.zero_grad()
        supervised_loss.backward()
        self._optimizer.step()
        self._forward_learn_cnt += 1
        return {
            'cur_lr': self._optimizer.defaults['lr'],
            # Convert to a Python float so the monitored value does not keep the autograd graph alive.
            'supervised_loss': supervised_loss.item(),
        }

    def _state_dict_learn(self) -> Dict[str, Any]:
        r"""
        Overview:
            Return the learn model and optimizer state dicts for checkpointing.
        """
        return {
            'model': self._learn_model.state_dict(),
            'optimizer': self._optimizer.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        r"""
        Overview:
            Restore the learn model and optimizer from a dict produced by ``_state_dict_learn``.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._optimizer.load_state_dict(state_dict['optimizer'])

    def _init_collect(self) -> None:
        r"""
        Overview:
            Collect mode init method. Called by ``self.__init__``.
            Init the collect model (an instance of ``FootballKaggle5thPlaceModel``) and the discount factor.
        Raises:
            - RuntimeError: if ``FootballKaggle5thPlaceModel`` is still ``None`` (not provided).
        """
        if FootballKaggle5thPlaceModel is None:
            # Previously this failed with an opaque "'NoneType' object is not callable".
            raise RuntimeError(
                "ILPolicy collect mode requires FootballKaggle5thPlaceModel, "
                "but it is None in ding.policy.il; provide the model class before initializing collect mode."
            )
        self._collect_model = model_wrap(FootballKaggle5thPlaceModel(), wrapper_name='base')
        self._gamma = self._cfg.collect.discount_factor
        self._collect_model.eval()
        self._collect_model.reset()

    def _forward_collect(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of collect mode.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs.
        ReturnsKeys
            - necessary: ``action``
            - optional: ``logit``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        with torch.no_grad():
            # The collect model receives the per-env raw observations as a decollated list.
            output = self._collect_model.forward(default_decollate(data['obs']['raw_obs']))
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> Dict[str, Any]:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of collect model, including at least ['action', 'logit']
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
                (here 'obs' indicates obs after env step, i.e. next_obs).
        Return:
            - transition (:obj:`Dict[str, Any]`): Dict type transition data.
        """
        transition = {
            'obs': obs,
            'action': model_output['action'],
            'logit': model_output['logit'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return transition

    def _get_train_sample(self, origin_data: list) -> Union[None, List[Any]]:
        r"""
        Overview:
            Convert collected transitions into training samples.
            Each sample keeps ``obs``, ``action`` and the collect model's ``logit`` (the supervised target)
            and gets a constant ``priority`` of 1 so that samples are drawn uniformly.
        Arguments:
            - origin_data (:obj:`list`): Transitions produced by ``_process_transition``.
        Returns:
            - samples (:obj:`List[Dict[str, Any]]`): One sample per transition, in reverse (last-first) order.
        """
        # The learn loss only uses the logits, so no discounted return is computed here. The previous
        # version computed one but never stored it. The reverse order is kept for backward compatibility.
        return [
            {
                'obs': transition['obs'],
                'action': transition['action'],
                # sample uniformly
                'priority': 1,
                'logit': transition['logit'],
            } for transition in reversed(origin_data)
        ]

    def _init_eval(self) -> None:
        r"""
        Overview:
            Evaluate mode init method. Called by ``self.__init__``.
            Init eval model: ``self._model`` wrapped with the ``argmax_sample`` wrapper.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._eval_model.reset()

    def _forward_eval(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of eval mode, similar to ``self._forward_collect`` but feeding
            ``obs['processed_obs']`` to the eval model.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
        ReturnsKeys
            - necessary: ``action``
            - optional: ``logit``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        with torch.no_grad():
            output = self._eval_model.forward(data['obs']['processed_obs'])
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return the names of the variables to be used in the monitor.
        Returns:
            - vars (:obj:`List[str]`): Variables' name list.
        """
        return ['cur_lr', 'supervised_loss']