Skip to content

ding.data.level_replay.level_sampler

ding.data.level_replay.level_sampler

LevelSampler

Overview

Policy class of the Prioritized Level Replay (PLR) algorithm: https://arxiv.org/pdf/2010.03934.pdf

PLR is a method for improving generalization and sample-efficiency of deep RL agents on procedurally-generated environments by adaptively updating a sampling distribution over the training levels based on a score of the learning potential of replaying each level.

Full Source Code

../ding/data/level_replay/level_sampler.py

from typing import Optional, Union, Any, List
from easydict import EasyDict
from ding.utils import deep_merge_dicts, SequenceType
from collections import namedtuple
import numpy as np
import torch


class LevelSampler():
    """
    Overview:
        Policy class of Prioritized Level Replay algorithm.
        https://arxiv.org/pdf/2010.03934.pdf

        PLR is a method for improving generalization and sample-efficiency of \
        deep RL agents on procedurally-generated environments by adaptively updating \
        a sampling distribution over the training levels based on a score of the learning \
        potential of replaying each level.
    """
    config = dict(
        # Scoring strategy for a level's learning potential. Supported values:
        # 'random', 'policy_entropy', 'least_confidence', 'min_margin',
        # 'gae', 'value_l1', 'one_step_td_error' (see ``update_with_rollouts``).
        strategy='policy_entropy',
        # 'fixed' gates replay sampling on ``rho``/``nu``; any other value uses
        # the proportionate schedule (replay prob == fraction of levels seen).
        replay_schedule='fixed',
        # Transform applied to raw scores before normalization: 'rank' or 'power'.
        score_transform='rank',
        temperature=1.0,
        eps=0.05,
        rho=0.2,
        nu=0.5,
        alpha=1.0,
        staleness_coef=0,
        staleness_transform='power',
        staleness_temperature=1.0,
    )

    def __init__(
            self,
            seeds: Optional[List[int]],
            obs_space: Union[int, SequenceType],
            action_space: int,
            num_actors: int,
            cfg: EasyDict,
    ):
        """
        Overview:
            Initialize sampler state for a fixed set of level seeds.
        Arguments:
            - seeds (:obj:`Optional[List[int]]`): training level seeds to sample from.
            - obs_space (:obj:`Union[int, SequenceType]`): observation space description \
                (stored but not otherwise used in this class).
            - action_space (:obj:`int`): number of discrete actions; used to normalize \
                the entropy score.
            - num_actors (:obj:`int`): number of parallel actors; one partial-score row \
                is kept per actor.
            - cfg (:obj:`EasyDict`): user config, merged over the class-level default ``config``.
        """
        self.cfg = EasyDict(deep_merge_dicts(self.config, cfg))
        # Re-apply top-level user keys after the deep merge (kept for backward
        # compatibility with the original merge behavior).
        self.cfg.update(cfg)
        self.obs_space = obs_space
        self.action_space = action_space
        self.strategy = self.cfg.strategy
        self.replay_schedule = self.cfg.replay_schedule
        self.score_transform = self.cfg.score_transform
        self.temperature = self.cfg.temperature
        # Eps means the level replay epsilon for eps-greedy sampling
        self.eps = self.cfg.eps
        # Rho means the minimum size of replay set relative to total number of levels before sampling replays
        self.rho = self.cfg.rho
        # Nu means the probability of sampling a new level instead of a replay level
        self.nu = self.cfg.nu
        # Alpha means the level score EWA smoothing factor
        self.alpha = self.cfg.alpha
        self.staleness_coef = self.cfg.staleness_coef
        self.staleness_transform = self.cfg.staleness_transform
        self.staleness_temperature = self.cfg.staleness_temperature

        # Track seeds and scores as np arrays backed by shared memory
        self.seeds = np.array(seeds, dtype=np.int64)
        self.seed2index = {seed: i for i, seed in enumerate(seeds)}

        # 1.0 while a seed has never been scored, 0.0 afterwards; used both to
        # pick unseen levels and to mask them out of replay weights.
        self.unseen_seed_weights = np.ones(len(seeds))
        # EWA of each seed's learning-potential score.
        self.seed_scores = np.zeros(len(seeds))
        # In-progress (episode not yet done) running mean score/steps per actor.
        self.partial_seed_scores = np.zeros((num_actors, len(seeds)), dtype=np.float32)
        self.partial_seed_steps = np.zeros((num_actors, len(seeds)), dtype=np.int64)
        # Number of samples since each seed was last selected (see ``_update_staleness``).
        self.seed_staleness = np.zeros(len(seeds))

        self.next_seed_index = 0  # Only used for sequential strategy

    def update_with_rollouts(self, train_data: dict, num_actors: int):
        """
        Overview:
            Update level scores from a batch of rollout data, then commit and
            reset all pending partial per-actor scores so the next batch starts clean.
        Arguments:
            - train_data (:obj:`dict`): rollout data with at least 'reward', 'seed', \
                'logit' and 'done' entries (plus 'adv'/'value' for value-based strategies).
            - num_actors (:obj:`int`): number of parallel actors that produced the data.
        Raises:
            - ValueError: if ``self.strategy`` is not a supported scoring strategy.
        """
        total_steps = train_data['reward'].shape[0]
        if self.strategy == 'random':
            return

        score_functions = {
            'policy_entropy': self._entropy,
            'least_confidence': self._least_confidence,
            'min_margin': self._min_margin,
            'gae': self._gae,
            'value_l1': self._value_l1,
            'one_step_td_error': self._one_step_td_error,
        }
        if self.strategy not in score_functions:
            raise ValueError('Not supported strategy: {}'.format(self.strategy))
        score_function = score_functions[self.strategy]

        self._update_with_rollouts(train_data, num_actors, total_steps, score_function)

        # Commit any leftover partial scores: calling ``update_seed_score`` with
        # score=0, num_steps=0 merges in the stored partial value unchanged.
        for actor_index in range(self.partial_seed_scores.shape[0]):
            for seed_idx in range(self.partial_seed_scores.shape[1]):
                if self.partial_seed_scores[actor_index][seed_idx] != 0:
                    self.update_seed_score(actor_index, seed_idx, 0, 0)
        self.partial_seed_scores.fill(0)
        self.partial_seed_steps.fill(0)

    def update_seed_score(self, actor_index: int, seed_idx: int, score: float, num_steps: int):
        """
        Overview:
            Finalize an episode's score for ``seed_idx`` (merging any partial score
            from ``actor_index``) and fold it into the global per-seed score via an
            exponentially weighted average with factor ``alpha``.
        Arguments:
            - actor_index (:obj:`int`): actor whose partial score is being finalized.
            - seed_idx (:obj:`int`): index of the seed in ``self.seeds``.
            - score (:obj:`float`): score of the newly finished episode segment.
            - num_steps (:obj:`int`): number of steps in that segment.
        """
        score = self._partial_update_seed_score(actor_index, seed_idx, score, num_steps, done=True)

        self.unseen_seed_weights[seed_idx] = 0.  # No longer unseen

        old_score = self.seed_scores[seed_idx]
        self.seed_scores[seed_idx] = (1 - self.alpha) * old_score + self.alpha * score

    def _partial_update_seed_score(
            self, actor_index: int, seed_idx: int, score: float, num_steps: int, done: bool = False
    ):
        """
        Overview:
            Merge ``score`` (over ``num_steps`` steps) into the running step-weighted
            mean score kept for ``(actor_index, seed_idx)``. If ``done``, the running
            state is cleared; otherwise it is updated in place.
        Returns:
            - merged_score (:obj:`float`): the step-weighted mean after merging.
        """
        partial_score = self.partial_seed_scores[actor_index][seed_idx]
        partial_num_steps = self.partial_seed_steps[actor_index][seed_idx]

        running_num_steps = partial_num_steps + num_steps
        # Incremental weighted-mean update: new mean moves toward ``score`` in
        # proportion to the fraction of steps the new segment contributes.
        merged_score = partial_score + (score - partial_score) * num_steps / float(running_num_steps)

        if done:
            self.partial_seed_scores[actor_index][seed_idx] = 0.  # zero partial score, partial num_steps
            self.partial_seed_steps[actor_index][seed_idx] = 0
        else:
            self.partial_seed_scores[actor_index][seed_idx] = merged_score
            self.partial_seed_steps[actor_index][seed_idx] = running_num_steps

        return merged_score

    def _entropy(self, **kwargs):
        """
        Overview:
            Mean policy entropy over the episode, normalized by the entropy of a
            uniform distribution over ``action_space`` actions.
            ``episode_logits`` are log-probabilities (the caller applies log_softmax).
        """
        episode_logits = kwargs['episode_logits']
        num_actions = self.action_space
        max_entropy = -(1. / num_actions) * np.log(1. / num_actions) * num_actions

        return (-torch.exp(episode_logits) * episode_logits).sum(-1).mean().item() / max_entropy

    def _least_confidence(self, **kwargs):
        """
        Overview:
            Mean of (1 - max action probability): high when the policy is unsure.
        """
        episode_logits = kwargs['episode_logits']
        return (1 - torch.exp(episode_logits.max(-1, keepdim=True)[0])).mean().item()

    def _min_margin(self, **kwargs):
        """
        Overview:
            One minus the mean gap between the top-2 action probabilities:
            high when the policy's best two actions are nearly tied.
        """
        episode_logits = kwargs['episode_logits']
        top2_confidence = torch.exp(episode_logits.topk(2, dim=-1)[0])
        return 1 - (top2_confidence[:, 0] - top2_confidence[:, 1]).mean().item()

    def _gae(self, **kwargs):
        """
        Overview:
            Mean (signed) advantage over the episode.
        """
        advantages = kwargs['adv']

        return advantages.mean().item()

    def _value_l1(self, **kwargs):
        """
        Overview:
            Mean absolute advantage over the episode.
        """
        advantages = kwargs['adv']
        # If the absolute value of ADV is large, it means that the level can significantly change
        # the policy and can be used to learn more

        return advantages.abs().mean().item()

    def _one_step_td_error(self, **kwargs):
        """
        Overview:
            Mean absolute one-step value residual ``|r_t + V_t - V_{t+1}|``.
            NOTE(review): no discount factor is applied and the residual uses
            ``+V_t - V_{t+1}`` rather than the usual ``+V_{t+1} - V_t``; this
            matches the original implementation — confirm before changing.
        """
        rewards = kwargs['rewards']
        value = kwargs['value']

        max_t = len(rewards)
        td_errors = (rewards[:-1] + value[:max_t - 1] - value[1:max_t]).abs()

        # ``td_errors`` is already non-negative; the redundant second ``abs`` was removed.
        return td_errors.mean().item()

    def _value_score_kwargs(
            self, train_data: dict, num_actors: int, all_total_steps: int, actor_index: int,
            start_t: int, end_t: Optional[int]
    ):
        """
        Overview:
            Reshape reward/adv/value to (T, num_actors) and slice the episode
            segment ``[start_t:end_t]`` for one actor (``end_t=None`` means "to the end").
        """
        steps_per_actor = int(all_total_steps / num_actors)
        rewards = train_data['reward'].reshape(num_actors, steps_per_actor).transpose(0, 1)
        adv = train_data['adv'].reshape(num_actors, steps_per_actor).transpose(0, 1)
        value = train_data['value'].reshape(num_actors, steps_per_actor).transpose(0, 1)
        return {
            'adv': adv[start_t:end_t, actor_index],
            'rewards': rewards[start_t:end_t, actor_index],
            'value': value[start_t:end_t, actor_index],
        }

    def _update_with_rollouts(self, train_data: dict, num_actors: int, all_total_steps: int, score_function):
        """
        Overview:
            Walk each actor's trajectory, split it at done steps, score every
            completed episode segment with ``score_function`` and commit it via
            ``update_seed_score``; the trailing (unfinished) segment only updates
            the partial score.
        """
        steps_per_actor = int(all_total_steps / num_actors)
        level_seeds = train_data['seed'].reshape(num_actors, steps_per_actor).transpose(0, 1)
        policy_logits = train_data['logit'].reshape(num_actors, steps_per_actor, -1).transpose(0, 1)
        done = train_data['done'].reshape(num_actors, steps_per_actor).transpose(0, 1)
        total_steps, num_actors = policy_logits.shape[:2]

        for actor_index in range(num_actors):
            # Indices of steps at which an episode ended for this actor.
            done_steps = done[:, actor_index].nonzero()[:total_steps, 0]
            start_t = 0

            for t in done_steps:
                if not start_t < total_steps:
                    break

                if t == 0:  # if t is 0, then this done step caused a full update of previous seed last cycle
                    continue

                seed_t = int(level_seeds[start_t, actor_index].item())
                seed_idx_t = self.seed2index[seed_t]

                episode_logits = policy_logits[start_t:t, actor_index]
                score_function_kwargs = {'episode_logits': torch.log_softmax(episode_logits, -1)}

                if self.strategy in ['gae', 'value_l1', 'one_step_td_error']:
                    score_function_kwargs.update(
                        self._value_score_kwargs(train_data, num_actors, all_total_steps, actor_index, start_t, t)
                    )

                score = score_function(**score_function_kwargs)
                num_steps = len(episode_logits)
                self.update_seed_score(actor_index, seed_idx_t, score, num_steps)

                start_t = t.item()

            if start_t < total_steps:
                # Trailing unfinished segment: only a partial update (no EWA commit).
                # ``int()`` cast added for consistency with the loop above.
                seed_t = int(level_seeds[start_t, actor_index].item())
                seed_idx_t = self.seed2index[seed_t]

                episode_logits = policy_logits[start_t:, actor_index]
                score_function_kwargs = {'episode_logits': torch.log_softmax(episode_logits, -1)}

                if self.strategy in ['gae', 'value_l1', 'one_step_td_error']:
                    score_function_kwargs.update(
                        self._value_score_kwargs(train_data, num_actors, all_total_steps, actor_index, start_t, None)
                    )

                score = score_function(**score_function_kwargs)
                num_steps = len(episode_logits)
                self._partial_update_seed_score(actor_index, seed_idx_t, score, num_steps)

    def _update_staleness(self, selected_idx: int):
        """
        Overview:
            Age every seed by one sample and reset the staleness of the seed just
            selected. No-op unless staleness weighting is enabled.
        """
        if self.staleness_coef > 0:
            self.seed_staleness += 1
            self.seed_staleness[selected_idx] = 0

    def _sample_replay_level(self):
        """
        Overview:
            Sample a previously-seen level according to ``_sample_weights``
            (uniform fallback when all weights are ~0) and return its seed.
        """
        sample_weights = self._sample_weights()

        if np.isclose(np.sum(sample_weights), 0):
            sample_weights = np.ones_like(sample_weights, dtype=np.float32) / len(sample_weights)

        seed_idx = np.random.choice(range(len(self.seeds)), 1, p=sample_weights)[0]
        seed = self.seeds[seed_idx]

        self._update_staleness(seed_idx)

        return int(seed)

    def _sample_unseen_level(self):
        """
        Overview:
            Sample uniformly among the seeds that have never been scored and
            return the chosen seed.
        """
        sample_weights = self.unseen_seed_weights / self.unseen_seed_weights.sum()
        seed_idx = np.random.choice(range(len(self.seeds)), 1, p=sample_weights)[0]
        seed = self.seeds[seed_idx]

        self._update_staleness(seed_idx)

        return int(seed)

    def sample(self, strategy: Optional[str] = None):
        """
        Overview:
            Sample the next training level seed. 'random' and 'sequential' bypass
            scoring entirely; otherwise the replay schedule decides between
            replaying a seen level and trying an unseen one.
        Arguments:
            - strategy (:obj:`Optional[str]`): override for ``self.strategy``.
        Returns:
            - seed (:obj:`int`): the sampled level seed.
        """
        if not strategy:
            strategy = self.strategy

        if strategy == 'random':
            seed_idx = np.random.choice(range(len(self.seeds)))
            seed = self.seeds[seed_idx]
            return int(seed)

        elif strategy == 'sequential':
            seed_idx = self.next_seed_index
            self.next_seed_index = (self.next_seed_index + 1) % len(self.seeds)
            seed = self.seeds[seed_idx]
            return int(seed)

        num_unseen = (self.unseen_seed_weights > 0).sum()
        proportion_seen = (len(self.seeds) - num_unseen) / len(self.seeds)

        if self.replay_schedule == 'fixed':
            if proportion_seen >= self.rho:
                # Sample replay level with fixed prob = 1 - nu OR if all levels seen
                if np.random.rand() > self.nu or not proportion_seen < 1.0:
                    return self._sample_replay_level()

            # Otherwise, sample a new level
            return self._sample_unseen_level()

        else:  # Default to proportionate schedule
            if proportion_seen >= self.rho and np.random.rand() < proportion_seen:
                return self._sample_replay_level()
            else:
                return self._sample_unseen_level()

    def _sample_weights(self):
        """
        Overview:
            Build the replay sampling distribution: transformed scores (unseen
            levels zeroed out), normalized, and optionally mixed with a staleness
            distribution weighted by ``staleness_coef``.
        Returns:
            - weights (:obj:`np.ndarray`): per-seed sampling weights (may sum to 0 \
                if nothing has been seen yet; callers handle that case).
        """
        weights = self._score_transform(self.score_transform, self.temperature, self.seed_scores)
        weights = weights * (1 - self.unseen_seed_weights)  # zero out unseen levels

        z = np.sum(weights)
        if z > 0:
            weights /= z

        staleness_weights = 0
        if self.staleness_coef > 0:
            staleness_weights = self._score_transform(
                self.staleness_transform, self.staleness_temperature, self.seed_staleness
            )
            staleness_weights = staleness_weights * (1 - self.unseen_seed_weights)
            z = np.sum(staleness_weights)
            if z > 0:
                staleness_weights /= z

            weights = (1 - self.staleness_coef) * weights + self.staleness_coef * staleness_weights

        return weights

    def _score_transform(self, transform: Optional[str], temperature: float, scores: Optional[List[float]]):
        """
        Overview:
            Turn raw scores into unnormalized sampling weights.
            'rank': weight 1/rank^(1/T), rank 1 = highest score.
            'power': weight (score + eps)^(1/T).
        Raises:
            - ValueError: for an unsupported transform (previously this path \
                crashed with ``UnboundLocalError``).
        """
        if transform == 'rank':
            temp = np.flip(scores.argsort())
            ranks = np.empty_like(temp)
            ranks[temp] = np.arange(len(temp)) + 1
            weights = 1 / ranks ** (1. / temperature)
        elif transform == 'power':
            # Small eps keeps zero scores sampleable; skipped when staleness
            # weighting is active (zero staleness should stay exactly zero).
            eps = 0 if self.staleness_coef > 0 else 1e-3
            weights = (np.array(scores) + eps) ** (1. / temperature)
        else:
            raise ValueError('Not supported score transform: {}'.format(transform))

        return weights