ding.policy.pc
==============
Full Source Code
../ding/policy/pc.py
import math
from typing import List, Dict, Any, Tuple
from collections import namedtuple

import torch
import torch.nn as nn
from torch.optim import Adam, SGD, AdamW
from torch.optim.lr_scheduler import LambdaLR

from ding.policy import Policy
from ding.model import model_wrap
from ding.torch_utils import to_device
from ding.utils import EasyTimer
from ding.utils import POLICY_REGISTRY


@POLICY_REGISTRY.register('pc_bfs')
class ProcedureCloningBFSPolicy(Policy):
    """Procedure-cloning policy that imitates BFS expansion maps on grid mazes.

    At evaluation time the model is applied repeatedly to a (state, partial
    BFS action map) pair until the map assigns an action to the agent's
    current cell, or the ``max_bfs_steps`` budget is exhausted.
    """

    def default_model(self) -> Tuple[str, List[str]]:
        """Return the default model registry name and its import path."""
        return 'pc_bfs', ['ding.model.template.procedure_cloning']

    config = dict(
        type='pc',
        cuda=False,
        on_policy=False,
        continuous=False,
        # Upper bound on model applications per evaluated observation.
        max_bfs_steps=100,
        learn=dict(
            update_per_collect=1,
            batch_size=32,
            learning_rate=1e-5,
            lr_decay=False,
            decay_epoch=30,
            decay_rate=0.1,
            warmup_lr=1e-4,
            warmup_epoch=3,
            optimizer='SGD',
            momentum=0.9,
            weight_decay=1e-4,
        ),
        collect=dict(
            unroll_len=1,
            noise=False,
            noise_sigma=0.2,
            noise_range=dict(
                min=-0.5,
                max=0.5,
            ),
        ),
        eval=dict(),
        other=dict(replay_buffer=dict(replay_buffer_size=10000)),
    )

    def _init_learn(self) -> None:
        """Build the optimizer, optional LR scheduler, loss and learn-mode model."""
        assert self._cfg.learn.optimizer in ['SGD', 'Adam'], self._cfg.learn.optimizer
        if self._cfg.learn.optimizer == 'SGD':
            self._optimizer = SGD(
                self._model.parameters(),
                lr=self._cfg.learn.learning_rate,
                weight_decay=self._cfg.learn.weight_decay,
                momentum=self._cfg.learn.momentum
            )
        elif self._cfg.learn.optimizer == 'Adam':
            # AdamW (decoupled weight decay) is preferred whenever decay is requested.
            if self._cfg.learn.weight_decay is None:
                self._optimizer = Adam(
                    self._model.parameters(),
                    lr=self._cfg.learn.learning_rate,
                )
            else:
                self._optimizer = AdamW(
                    self._model.parameters(),
                    lr=self._cfg.learn.learning_rate,
                    weight_decay=self._cfg.learn.weight_decay
                )
        if self._cfg.learn.lr_decay:

            def lr_scheduler_fn(epoch):
                # Warm-up phase: hold lr at warmup_lr (expressed as a multiplier
                # of the base learning rate, as LambdaLR expects).
                if epoch <= self._cfg.learn.warmup_epoch:
                    return self._cfg.learn.warmup_lr / self._cfg.learn.learning_rate
                # Afterwards: step decay, one `decay_rate` factor every `decay_epoch` epochs.
                ratio = (epoch - self._cfg.learn.warmup_epoch) // self._cfg.learn.decay_epoch
                return math.pow(self._cfg.learn.decay_rate, ratio)

            self._lr_scheduler = LambdaLR(self._optimizer, lr_scheduler_fn)
        self._timer = EasyTimer(cuda=True)
        self._learn_model = model_wrap(self._model, 'base')
        self._learn_model.reset()
        self._max_bfs_steps = self._cfg.max_bfs_steps
        self._maze_size = self._cfg.maze_size
        self._num_actions = self._cfg.num_actions

        self._loss = nn.CrossEntropyLoss()

    def process_states(self, observations, maze_maps):
        """Returns [B, W, W, 3] binary values. Channels are (wall; goal; obs).

        ``observations`` holds (row, col) agent coordinates; they are encoded
        as a one-hot position plane and concatenated onto ``maze_maps``.
        """
        # One-hot encode the flattened (row * W + col) agent position.
        loc = torch.nn.functional.one_hot(
            (observations[:, 0] * self._maze_size + observations[:, 1]).long(),
            self._maze_size * self._maze_size,
        ).long()
        loc = torch.reshape(loc, [observations.shape[0], self._maze_size, self._maze_size])
        # NOTE(review): `loc` is rank-3 here while the docstring implies a
        # channel concatenation; verify `maze_maps`'s layout makes this cat valid.
        states = torch.cat([maze_maps, loc], dim=-1).long()
        return states

    def _forward_learn(self, data):
        """One supervised step: predict the next BFS action map from the current one.

        Returns a dict of logging scalars: current lr, cross-entropy loss and
        per-cell prediction accuracy.
        """
        if self._cuda:
            collated_data = to_device(data, self._device)
        else:
            collated_data = data
        # Fix: the original had a trailing comma here, turning `observations`
        # into a 1-tuple and breaking the torch.cat below.
        observations = collated_data['obs']
        bfs_input_maps, bfs_output_maps = collated_data['bfs_in'].long(), collated_data['bfs_out'].long()
        states = observations
        bfs_input_onehot = torch.nn.functional.one_hot(bfs_input_maps, self._num_actions + 1).float()

        bfs_states = torch.cat([
            states,
            bfs_input_onehot,
        ], dim=-1)
        logits = self._model(bfs_states)['logit']
        # Flatten every leading (batch/spatial) dim so cells become independent samples.
        logits = logits.flatten(0, -2)
        labels = bfs_output_maps.flatten(0, -1)

        loss = self._loss(logits, labels)
        preds = torch.argmax(logits, dim=-1)
        # .item() so 'acc' is a plain float like the other logged scalars.
        acc = torch.sum(preds == labels).item() / preds.shape[0]

        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()
        pred_loss = loss.item()

        # Average lr across param groups for logging.
        cur_lr = [param_group['lr'] for param_group in self._optimizer.param_groups]
        cur_lr = sum(cur_lr) / len(cur_lr)
        return {'cur_lr': cur_lr, 'total_loss': pred_loss, 'acc': acc}

    def _monitor_vars_learn(self) -> List[str]:
        """Variables recorded by the learner logger."""
        return ['cur_lr', 'total_loss', 'acc']

    def _init_eval(self) -> None:
        """Wrap the shared model for evaluation."""
        self._eval_model = model_wrap(self._model, wrapper_name='base')
        self._eval_model.reset()

    def _forward_eval(self, data):
        """Iteratively expand the BFS action map until the agent's cell is labeled.

        ``data`` maps env ids to state tensors; the returned dict maps each id
        to ``{'action': tensor, 'info': {}}``.
        """
        if self._cuda:
            data = to_device(data, self._device)
        max_len = self._max_bfs_steps
        data_id = list(data.keys())
        output = {}

        for ii in data_id:
            states = data[ii].unsqueeze(0)
            # Initialize every cell with the sentinel value `num_actions` ("unlabeled").
            bfs_input_maps = self._num_actions * torch.ones([1, self._maze_size, self._maze_size]).long()
            if self._cuda:
                bfs_input_maps = to_device(bfs_input_maps, self._device)
            # The last state channel one-hot encodes the agent position.
            xy = torch.where(states[:, :, :, -1] == 1)
            observation = (xy[1][0].item(), xy[2][0].item())

            i = 0
            # Re-apply the model until the agent cell gets a real action or budget runs out.
            while bfs_input_maps[0, observation[0], observation[1]].item() == self._num_actions and i < max_len:
                bfs_input_onehot = torch.nn.functional.one_hot(bfs_input_maps, self._num_actions + 1).long()

                bfs_states = torch.cat([
                    states,
                    bfs_input_onehot,
                ], dim=-1)
                logits = self._model(bfs_states)['logit']
                bfs_input_maps = torch.argmax(logits, dim=-1)
                i += 1
            action = bfs_input_maps[0, observation[0], observation[1]]
            # Fix: the original only built the output dict on the CUDA path,
            # so non-CUDA runs crashed on `output[ii]['action']` below.
            if self._cuda:
                action = to_device(action, 'cpu')
            # Fall back to a uniformly random action if the cell was never labeled.
            if action.item() == self._num_actions:
                action = torch.randint(low=0, high=self._num_actions, size=[1])[0]
            output[ii] = {'action': action, 'info': {}}
        return output

    def _init_collect(self) -> None:
        # Collect mode is intentionally unsupported for this offline policy.
        raise NotImplementedError

    def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]:
        raise NotImplementedError

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        raise NotImplementedError

    def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        raise NotImplementedError