ding.policy.prompt_pg

Full Source Code: ../ding/policy/prompt_pg.py
from typing import List, Dict, Any, Tuple, Union
from collections import namedtuple
import torch

from ding.rl_utils import get_train_sample
from ding.torch_utils import Adam, to_device
from ding.utils import POLICY_REGISTRY, split_data_generator
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy
from ..model import model_wrap


@POLICY_REGISTRY.register('prompt_pg')
class PromptPGPolicy(Policy):
    r"""
    Overview:
        Policy class of Prompt Policy Gradient (PromptPG) algorithm.
        PromptPG selects in-context examples (prompts) for a language model with a policy
        trained by vanilla policy gradient (REINFORCE).
        Link of the original paper: https://arxiv.org/abs/2209.14610
    """
    config = dict(
        # (string) RL policy register name (refer to function "register_policy").
        type='prompt_pg',
        # (bool) whether to use cuda for network.
        cuda=True,
        # (bool) whether use on-policy training pipeline(behaviour policy and training policy are the same)
        on_policy=True,  # for pg strictly on policy algorithm, this line should not be modified by users
        # (bool) whether to use deterministic action for evaluation.
        deterministic_eval=True,
        # (int) The number of actions that can be done simultaneously in one timestep.
        shot_number=1,
        learn=dict(
            # (int) the number of samples for one update.
            batch_size=64,
            # (float) the step size of one gradient descend.
            learning_rate=0.001,
            # ==============================================================
            # The following configs is algorithm-specific
            # ==============================================================
            # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
            entropy_weight=0.01,
            # (float) max grad norm value.
            grad_norm=5,
            # (bool) whether to ignore done signal for non-termination env.
            ignore_done=False,
        ),
        collect=dict(
            # (int) collect n_sample data, train model n_iteration times
            # n_episode=8,
            # (int) trajectory unroll length
            unroll_len=1,
            # ==============================================================
            # The following configs is algorithm-specific
            # ==============================================================
            # (float) discount factor for future reward, defaults in [0, 1]
            discount_factor=0,
            collector=dict(get_train_sample=True),
        ),
        eval=dict(),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return the registered name of the default model and its import path.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): model name and import path list.
        """
        return 'language_transformer', ['ding.model.template.language_transformer']

    def _init_learn(self) -> None:
        r"""
        Overview:
            Learn mode init method. Called by ``self.__init__``.
            Init the optimizer, cache algorithm-specific config values, and alias the main model.
        """
        # Optimizer
        self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)

        # Cache frequently-used config values once (used in _forward_learn).
        self._entropy_weight = self._cfg.learn.entropy_weight
        self._grad_norm = self._cfg.learn.grad_norm
        self._learn_model = self._model  # for compatibility

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Forward and backward function of learn mode.
        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward']
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss.
        """
        self._model.train()

        return_infos = []
        # Iterate over the whole sample list in mini-batches of cfg.learn.batch_size.
        for i in range(0, len(data), self._cfg.learn.batch_size):
            batch = default_collate(data[i:i + self._cfg.learn.batch_size])
            if self._cuda:
                batch = to_device(batch, self._device)

            # Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected)
            train_samples, cand_samples = batch["obs"]["train_sample"], batch["obs"]["candidate_samples"]
            # Unwrap each collated candidate entry (collation nests every candidate in a 1-element list).
            for ii in range(len(cand_samples)):
                cand_samples[ii] = cand_samples[ii][0]
            output = self._learn_model.forward(train_samples, cand_samples)
            return_ = batch['return']

            # calculate PG loss
            real_act = batch['action']  # shape: (B, shot_number)
            if len(real_act.shape) == 1:
                real_act = real_act.unsqueeze(-1)
            # Calculate loss: sum REINFORCE loss and entropy bonus over each of the shot_number actions.
            total_policy_loss, total_entropy_loss = 0, 0
            for ii in range(self._cfg.shot_number):
                log_prob = output['dist'].log_prob(real_act[:, ii])
                policy_loss = -(log_prob * return_).mean()
                total_policy_loss += policy_loss
                # NOTE: use the weight cached in _init_learn for consistency with self._grad_norm below.
                total_entropy_loss += -self._entropy_weight * output['dist'].entropy().mean()
            total_loss = total_entropy_loss + total_policy_loss

            # update
            self._optimizer.zero_grad()
            total_loss.backward()

            grad_norm = torch.nn.utils.clip_grad_norm_(
                list(self._learn_model.parameters()),
                max_norm=self._grad_norm,
            )
            self._optimizer.step()

            # only record last updates information in logger
            return_info = {
                'cur_lr': self._optimizer.param_groups[0]['lr'],
                'total_loss': total_loss.item(),
                'policy_loss': total_policy_loss.item(),
                'entropy_loss': total_entropy_loss.item(),
                'return_abs_max': return_.abs().max().item(),
                # clip_grad_norm_ returns a tensor; convert to a plain float like the other entries.
                'grad_norm': grad_norm.item(),
            }
            return_infos.append(return_info)
        return return_infos

    def _init_collect(self) -> None:
        """
        Overview:
            Collect mode init method. Cache unroll length and discount factor, and wrap the
            model with a multinomial sampler that draws ``shot_number`` distinct actions.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        self._gamma = self._cfg.collect.discount_factor
        self._collect_model = model_wrap(self._model, wrapper_name='combination_multinomial_sample')

    def _forward_collect(self, data: dict) -> dict:
        """
        Overview:
            Sample actions for data collection (stochastic, via multinomial sampling).
        Arguments:
            - data (:obj:`dict`): env_id -> observation dict.
        Returns:
            - output (:obj:`dict`): env_id -> model output (including at least ['action']).
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        self._model.eval()
        with torch.no_grad():
            # Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected)
            for ii in range(len(data['candidate_samples'])):
                data['candidate_samples'][ii] = data['candidate_samples'][ii][0]
            output = self._collect_model.forward(self._cfg.shot_number, data['train_sample'], data['candidate_samples'])
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of collect model, including at least ['action']
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
                (here 'obs' indicates obs after env step).
        Returns:
            - transition (:obj:`dict`): Dict type transition data.
        """
        return {
            'obs': obs,
            'action': model_output['action'],
            'reward': timestep.reward,
            'done': timestep.done,
        }

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        r"""
        Overview:
            Compute the discounted return-to-go for each transition of a finished trajectory,
            then slice the trajectory into training samples of length ``unroll_len``.
        Arguments:
            - data (:obj:`list`): The trajectory's buffer list
        Returns:
            - samples (:obj:`dict`): The training samples generated
        """
        if self._cfg.learn.ignore_done:
            raise NotImplementedError

        # Backward accumulation of discounted returns: R_t = gamma * R_{t+1} + r_t.
        R = 0.
        for i in reversed(range(len(data))):
            R = self._gamma * R + data[i]['reward']
            data[i]['return'] = R
        return get_train_sample(data, self._unroll_len)

    def _init_eval(self) -> None:
        """
        Overview:
            Eval mode init method. Wrap the model with an argmax sampler so evaluation
            picks the ``shot_number`` highest-probability prompts deterministically.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='combination_argmax_sample')

    def _forward_eval(self, data: dict) -> dict:
        """
        Overview:
            Select actions for evaluation (deterministic, via argmax sampling).
        Arguments:
            - data (:obj:`dict`): env_id -> observation dict.
        Returns:
            - output (:obj:`dict`): env_id -> model output (including at least ['action']).
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        self._model.eval()
        with torch.no_grad():
            # Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected)
            for ii in range(len(data['candidate_samples'])):
                data['candidate_samples'][ii] = data['candidate_samples'][ii][0]
            output = self._eval_model.forward(self._cfg.shot_number, data['train_sample'], data['candidate_samples'])
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the variable names recorded by the logger in learn mode,
            extending the base Policy's list with the PG-specific metrics.
        """
        return super()._monitor_vars_learn() + ['policy_loss', 'entropy_loss', 'return_abs_max', 'grad_norm']