Skip to content

ding.policy.prompt_awr

ding.policy.prompt_awr

PromptAWRPolicy

Bases: Policy

Overview

Policy class of AWR (Advantage Weighted Regression) algorithm, proposed in https://arxiv.org/abs/1910.00177. Especially, this policy is designed for training a language model policy. In this policy, the environment's observation includes the current context, a list of optional actions (strings). The final output of the policy is a set of optional actions with a size of shot_number.

default_model()

Overview

Returns the default model configuration used by the AWR algorithm. __init__ method will automatically call this method to get the default model setting and create model.

Returns:

Type Description
Tuple[str, List[str]]
  • model_info (Tuple[str, List[str]]): Tuple containing the registered model name and the model's import_names.

Full Source Code

../ding/policy/prompt_awr.py

from collections import namedtuple
from typing import List, Dict, Any, Tuple, Union

import torch

from ding.model import model_wrap
from ding.rl_utils import get_train_sample
from ding.torch_utils import Adam, to_device
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy


@POLICY_REGISTRY.register('prompt_awr')
class PromptAWRPolicy(Policy):
    """
    Overview:
        Policy class of AWR (Advantage Weighted Regression) algorithm, proposed in https://arxiv.org/abs/1910.00177.
        Especially, this policy is designed for training a language model policy.
        In this policy, the environment's observation includes the current context, a list of optional actions
        (strings). The final output of the policy is a set of optional actions with a size of ``shot_number``.
    """
    config = dict(
        # (str) Name of the registered RL policy (refer to the "register_policy" function).
        type='prompt_awr',
        # (bool) Flag to enable CUDA for model computation.
        cuda=False,
        # (bool) Flag for using on-policy training (training policy is the same as the behavior policy).
        on_policy=False,
        # (bool) Flag for enabling priority experience replay. Must be False when priority_IS_weight is False.
        priority=False,
        # (bool) Flag for using Importance Sampling weights to correct updates. Requires `priority` to be True.
        priority_IS_weight=False,
        # (str) Type of action space used in the policy, with valid options ['discrete', 'continuous'].
        action_space='discrete',
        # (int) The number of actions that can be done simultaneously in one timestep.
        shot_number=1,
        # learn_mode configuration
        learn=dict(
            # (int) Number of updates per data collection. A2C requires this to be set to 1.
            update_per_collect=1,
            # (int) Batch size for learning.
            batch_size=64,
            # (float) Learning rate for optimizer.
            learning_rate=0.001,
            # (Tuple[float, float]) Coefficients used for computing running averages of gradient and its square.
            betas=(0.9, 0.999),
            # (float) Term added to the denominator to improve numerical stability in optimizer.
            eps=1e-8,
            # (float) Maximum norm for gradients.
            grad_norm=0.5,
            # (float) Scaling factor for value network loss relative to policy network loss.
            value_weight=0.5,
            # (float) Coefficient that controls the exp scale in awr algorithm.
            beta=1.0,
            # (float) Weight of entropy regularization in the loss function.
            entropy_weight=0.001,
            # (Tuple[float, float]) The range of adv. Value that exceeds this range will be clipped.
            adv_range=(-0.5, 0.5),
            # (bool) If set to True, the 'done' signals that indicate the end of an episode due to environment time
            # limits are disregarded. By default, this is set to False. This setting is particularly useful for tasks
            # that have a predetermined episode length, such as HalfCheetah and various other MuJoCo environments,
            # where the maximum length is capped at 1000 steps. When enabled, any 'done' signal triggered by reaching
            # the maximum episode steps will be overridden to 'False'. This ensures the accurate calculation of the
            # Temporal Difference (TD) error, using the formula `gamma * (1 - done) * next_v + reward`,
            # even when the episode surpasses the predefined step limit.
            ignore_done=False,
        ),
        # collect_mode configuration
        collect=dict(
            # (int) The length of rollout for data collection.
            unroll_len=1,
            # (float) Discount factor for calculating future rewards, typically in the range [0, 1].
            discount_factor=0.9,
            # (float) Trade-off parameter for balancing TD-error and Monte Carlo error in GAE.
            gae_lambda=0.95,
        ),
        # eval_mode configuration (kept empty for compatibility purposes)
        eval=dict(),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Returns the default model configuration used by the AWR algorithm. ``__init__`` method will \
            automatically call this method to get the default model setting and create model.

        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): \
                Tuple containing the registered model name and model's import_names.
        """
        return 'language_transformer', ['ding.model.template.language_transformer']

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of policy, including related attributes and modules. For AWR, it mainly \
            contains optimizer, algorithm-specific arguments such as value_weight, entropy_weight, adv_range \
            and grad_norm, and main model. \
            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \
            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
        """
        # This policy only supports a discrete (candidate-action) space.
        assert self._cfg.action_space == "discrete"
        # Optimizer
        self._optimizer = Adam(
            self._model.parameters(),
            lr=self._cfg.learn.learning_rate,
            betas=self._cfg.learn.betas,
            eps=self._cfg.learn.eps
        )

        # Algorithm config
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        self._value_weight = self._cfg.learn.value_weight
        self._entropy_weight = self._cfg.learn.entropy_weight
        # BUGFIX: the config key is ``adv_range`` — there is no ``adv_norm`` entry in ``config``, so the
        # original ``self._cfg.learn.adv_norm`` lookup failed at init. Cache the advantage clip range here
        # so ``_forward_learn`` does not have to reach into the config on every batch.
        self._adv_range = self._cfg.learn.adv_range
        self._grad_norm = self._cfg.learn.grad_norm

        # Main model (no separate target model is used by AWR).
        self._learn_model = self._model

    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode. Iterates over ``data`` in minibatches of
            ``learn.batch_size``, computes the AWR loss (policy + value + entropy terms) and
            performs one optimizer step per minibatch.
        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): The collected transitions; each item holds at least \
                ``obs`` (with ``train_sample`` and ``candidate_samples``), ``action``, ``value`` and ``return``.
        Returns:
            - info (:obj:`Dict[str, Any]`): Training statistics of the last minibatch, whose keys match \
                ``_monitor_vars_learn``.
        """
        # Data preprocessing operations, such as stack data, cpu to cuda device
        self._learn_model.train()

        for i in range(0, len(data), self._cfg.learn.batch_size):
            batch = default_collate(data[i:i + self._cfg.learn.batch_size])
            if self._cuda:
                batch = to_device(batch, self._device)

            # Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected)
            train_samples, cand_samples = batch["obs"]["train_sample"], batch["obs"]["candidate_samples"]
            for cand_n in range(len(cand_samples)):
                cand_samples[cand_n] = cand_samples[cand_n][0]
            output = self._learn_model.forward(train_samples, cand_samples, mode='compute_actor_critic')
            return_ = batch['return']

            # Calculate AWR loss
            real_act = batch['action']

            # Ensure the shape of real_act is: (B, shot_number)
            if len(real_act.shape) == 1:
                real_act = real_act.unsqueeze(-1)

            # Calculate different parts of loss.
            total_policy_loss, total_entropy_loss, total_value_loss = 0, 0, 0
            for shot_n in range(self._cfg.shot_number):
                log_prob = output['dist'].log_prob(real_act[:, shot_n])
                # Clamp the adv for better stability.
                # BUGFIX: use the configured ``adv_range`` (cached in ``_init_learn``); the original code
                # read a nonexistent ``norm_range`` config key.
                adv = torch.clamp(
                    return_ - batch['value'], min=self._adv_range[0], max=self._adv_range[1]
                )
                # The policy loss for AWR algorithm: advantage-weighted log-likelihood.
                policy_loss = -(log_prob * torch.exp(adv / self._cfg.learn.beta)).mean()
                total_policy_loss += policy_loss
                # The value loss for AWR algorithm (regression of the critic towards the MC return).
                value_loss = ((return_ - output['value']) ** 2).mean()
                total_value_loss += value_loss
                # The entropy loss for AWR algorithm (use the weight cached in ``_init_learn``).
                total_entropy_loss += -self._entropy_weight * output['dist'].entropy().mean()
            # BUGFIX: apply the configured ``value_weight`` ("Scaling factor for value network loss relative
            # to policy network loss") — it was cached in ``_init_learn`` but never used before.
            total_loss = total_policy_loss + self._value_weight * total_value_loss + total_entropy_loss

            self._optimizer.zero_grad()
            total_loss.backward()

            grad_norm = torch.nn.utils.clip_grad_norm_(
                list(self._learn_model.parameters()),
                max_norm=self._grad_norm,
            )
            self._optimizer.step()

        # Statistics of the last processed minibatch.
        return {
            'cur_lr': self._optimizer.param_groups[0]['lr'],
            'total_loss': total_loss.item(),
            'policy_loss': total_policy_loss.item(),
            'entropy_loss': total_entropy_loss.item(),
            'value_loss': total_value_loss.item(),
            'return_abs_max': return_.abs().max().item(),
            # ``clip_grad_norm_`` returns a tensor; convert for consistency with the other scalar stats.
            'grad_norm': grad_norm.item(),
        }

    def _init_collect(self) -> None:
        """
        Overview:
            Initialize the collect mode: cache unroll length and discount factor, and wrap the model \
            with a multinomial sampler that draws ``shot_number`` action combinations.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        self._gamma = self._cfg.collect.discount_factor
        self._collect_model = model_wrap(self._model, wrapper_name='combination_multinomial_sample')

    def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Any]:
        """
        Overview:
            Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
            that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
            data, such as the action to interact with the envs.
        Arguments:
            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
                key of the dict is environment id and the value is the corresponding data of the env.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
                other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \
                dict is the same as the input data, i.e. environment id.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        self._model.eval()
        with torch.no_grad():
            # Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected)
            for ii in range(len(data['candidate_samples'])):
                data['candidate_samples'][ii] = data['candidate_samples'][ii][0]
            output = self._collect_model.forward(
                self._cfg.shot_number, data['train_sample'], data['candidate_samples'], mode="compute_actor_critic"
            )
            if self._cuda:
                output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _process_transition(self, obs: Any, policy_output: Dict[str, torch.Tensor],
                            timestep: namedtuple) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Assemble one transition dict from the observation, the collect-mode policy output and the \
            environment timestep, keeping exactly the fields learn mode consumes.
        """
        return {
            'obs': obs,
            'action': policy_output['action'],
            'value': policy_output['value'],
            'reward': timestep.reward,
            'done': timestep.done,
        }

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        """
        Overview:
            Get the trajectory and compute the discounted Monte Carlo return for each transition, then \
            sample training data from it.
        Arguments:
            - data (:obj:`list`): The trajectory's buffer list
        Returns:
            - samples (:obj:`dict`): The training samples generated
        """
        if self._cfg.learn.ignore_done:
            raise NotImplementedError

        # Backward pass over the trajectory: R_t = gamma * R_{t+1} + r_t.
        R = 0.
        for i in reversed(range(len(data))):
            R = self._gamma * R + data[i]['reward']
            data[i]['return'] = R
        return get_train_sample(data, self._unroll_len)

    def _init_eval(self) -> None:
        """
        Overview:
            Initialize the eval mode: wrap the model with an argmax sampler so evaluation picks the \
            greedy action combination instead of sampling.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='combination_argmax_sample')

    def _forward_eval(self, data: dict) -> dict:
        """
        Overview:
            Policy forward function of eval mode: greedily select ``shot_number`` candidate actions for \
            each environment observation.
        Arguments:
            - data (:obj:`dict`): Env-id keyed observations, each with ``train_sample`` and ``candidate_samples``.
        Returns:
            - output (:obj:`dict`): Env-id keyed outputs containing at least the chosen action.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        self._model.eval()
        with torch.no_grad():
            # Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected)
            for ii in range(len(data['candidate_samples'])):
                data['candidate_samples'][ii] = data['candidate_samples'][ii][0]
            output = self._eval_model.forward(self._cfg.shot_number, data['train_sample'], data['candidate_samples'])
            if self._cuda:
                output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the list of scalar variable names (keys of ``_forward_learn``'s output) to be monitored.
        """
        return super()._monitor_vars_learn() + \
            ['policy_loss', 'entropy_loss', 'return_abs_max', 'grad_norm', 'value_loss']