Skip to content

ding.data.buffer.middleware.priority

ding.data.buffer.middleware.priority

PriorityExperienceReplay

Overview

The middleware that implements priority experience replay (PER).

__init__(buffer, IS_weight=True, priority_power_factor=0.6, IS_weight_power_factor=0.4, IS_weight_anneal_train_iter=int(1e5))

Parameters:

Name Type Description Default
- buffer (:obj:`Buffer`): The buffer to use PER. (required)
- IS_weight (:obj:`bool`): Whether to use importance sampling or not. Default: True.
- priority_power_factor (:obj:`float`): The factor that adjusts the sensitivity between the sampling probability and the priority level. Default: 0.6.
- IS_weight_power_factor (:obj:`float`): The factor that adjusts the sensitivity between the sample rarity and the sampling probability in importance sampling. Default: 0.4.
- IS_weight_anneal_train_iter (:obj:`int`): The number of train iterations over which ``IS_weight_power_factor`` is annealed up to 1. Default: int(1e5).

Full Source Code

../ding/data/buffer/middleware/priority.py

from typing import Callable, Any, List, Dict, Optional, Union, TYPE_CHECKING
import copy
import numpy as np
import torch
from ding.utils import SumSegmentTree, MinSegmentTree
from ding.data.buffer.buffer import BufferedData
if TYPE_CHECKING:
    from ding.data.buffer.buffer import Buffer


class PriorityExperienceReplay:
    """
    Overview:
        The middleware that implements priority experience replay (PER).
    """

    def __init__(
            self,
            buffer: 'Buffer',
            IS_weight: bool = True,
            priority_power_factor: float = 0.6,
            IS_weight_power_factor: float = 0.4,
            IS_weight_anneal_train_iter: int = int(1e5),
    ) -> None:
        """
        Arguments:
            - buffer (:obj:`Buffer`): The buffer to use PER.
            - IS_weight (:obj:`bool`): Whether to use importance sampling or not.
            - priority_power_factor (:obj:`float`): The factor that adjusts the sensitivity between\
                the sampling probability and the priority level.
            - IS_weight_power_factor (:obj:`float`): The factor that adjusts the sensitivity between\
                the sample rarity and sampling probability in importance sampling.
            - IS_weight_anneal_train_iter (:obj:`int`): The number of train iterations over which\
                ``IS_weight_power_factor`` is annealed up to 1.
        """

        self.buffer = buffer
        # Mapping from priority-tree slot (pivot) to the underlying buffer index.
        self.buffer_idx = {}
        self.buffer_size = buffer.size
        self.IS_weight = IS_weight
        self.priority_power_factor = priority_power_factor
        self.IS_weight_power_factor = IS_weight_power_factor
        self.IS_weight_anneal_train_iter = IS_weight_anneal_train_iter

        # Max priority till now, it's used to initialize data's priority if "priority" is not passed in with the data.
        self.max_priority = 1.0
        self._init_trees()
        if self.IS_weight:
            # Per-sample-call increment applied until IS_weight_power_factor reaches 1.
            self.delta_anneal = (1 - self.IS_weight_power_factor) / self.IS_weight_anneal_train_iter
        self.pivot = 0

    def _init_trees(self) -> None:
        # Build fresh segment trees; capacity needs to be a power of 2.
        capacity = int(np.power(2, np.ceil(np.log2(self.buffer_size))))
        self.sum_tree = SumSegmentTree(capacity)
        if self.IS_weight:
            self.min_tree = MinSegmentTree(capacity)

    def push(self, chain: Callable, data: Any, meta: Optional[dict] = None, *args, **kwargs) -> BufferedData:
        # Assign a priority (from data, meta, or the running max) and register it in the tree(s).
        if meta is None:
            if 'priority' in data:
                meta = {'priority': data.pop('priority')}
            else:
                meta = {'priority': self.max_priority}
        else:
            if 'priority' not in meta:
                meta['priority'] = self.max_priority
        meta['priority_idx'] = self.pivot
        self._update_tree(meta['priority'], self.pivot)
        buffered = chain(data, meta=meta, *args, **kwargs)
        index = buffered.index
        self.buffer_idx[self.pivot] = index
        # Tree slots are reused cyclically once the buffer is full.
        self.pivot = (self.pivot + 1) % self.buffer_size
        return buffered

    def sample(self, chain: Callable, size: int, *args,
               **kwargs) -> Union[List[BufferedData], List[List[BufferedData]]]:
        # Divide [0, 1) into size intervals on average (stratified sampling).
        intervals = np.array([i * 1.0 / size for i in range(size)])
        # Uniformly sample within each interval
        mass = intervals + np.random.uniform(size=(size, )) * 1. / size
        # Rescale to [0, S), where S is the sum of all datas' priority (root value of sum tree)
        mass *= self.sum_tree.reduce()
        indices = [self.sum_tree.find_prefixsum_idx(m) for m in mass]
        indices = [self.buffer_idx[i] for i in indices]
        # Sample with indices
        data = chain(indices=indices, *args, **kwargs)
        if self.IS_weight:
            # Calculate max weight for normalizing IS
            sum_tree_root = self.sum_tree.reduce()
            p_min = self.min_tree.reduce() / sum_tree_root
            buffer_count = self.buffer.count()
            max_weight = (buffer_count * p_min) ** (-self.IS_weight_power_factor)
            for i in range(len(data)):
                meta = data[i].meta
                priority_idx = meta['priority_idx']
                p_sample = self.sum_tree[priority_idx] / sum_tree_root
                weight = (buffer_count * p_sample) ** (-self.IS_weight_power_factor)
                meta['priority_IS'] = weight / max_weight
                data[i].data['priority_IS'] = torch.as_tensor([meta['priority_IS']]).float()  # for compatibility
            # Anneal the IS exponent towards 1 as training proceeds.
            self.IS_weight_power_factor = min(1.0, self.IS_weight_power_factor + self.delta_anneal)
        return data

    def update(self, chain: Callable, index: str, data: Any, meta: Any, *args, **kwargs) -> None:
        update_flag = chain(index, data, meta, *args, **kwargs)
        if update_flag:  # when update succeeds
            assert meta is not None, "Please indicate dict-type meta in priority update"
            new_priority, idx = meta['priority'], meta['priority_idx']
            # BUGFIX: message previously said "greater than 0" for a >= 0 check.
            assert new_priority >= 0, "new_priority should be >= 0, but found {}".format(new_priority)
            new_priority += 1e-5  # Add epsilon to avoid priority == 0
            self._update_tree(new_priority, idx)
            self.max_priority = max(self.max_priority, new_priority)

    def delete(self, chain: Callable, index: str, *args, **kwargs) -> None:
        item = self.buffer.get_by_index(index)
        priority_idx = item.meta['priority_idx']
        # Reset the tree leaves to their neutral element so the deleted entry no longer affects sampling.
        self.sum_tree[priority_idx] = self.sum_tree.neutral_element
        # BUGFIX: min_tree only exists when importance sampling is enabled.
        if self.IS_weight:
            self.min_tree[priority_idx] = self.min_tree.neutral_element
        self.buffer_idx.pop(priority_idx)
        return chain(index, *args, **kwargs)

    def clear(self, chain: Callable) -> None:
        self.max_priority = 1.0
        self._init_trees()
        self.buffer_idx = {}
        self.pivot = 0
        chain()

    def _update_tree(self, priority: float, idx: int) -> None:
        # Priorities are raised to priority_power_factor before being stored (PER's alpha).
        weight = priority ** self.priority_power_factor
        self.sum_tree[idx] = weight
        if self.IS_weight:
            self.min_tree[idx] = weight

    def state_dict(self) -> Dict:
        # BUGFIX: was `self.sumtree` / `self.mintree`, which do not exist (AttributeError).
        # Keys match attribute names so load_state_dict restores them via setattr.
        state = {
            'max_priority': self.max_priority,
            'IS_weight_power_factor': self.IS_weight_power_factor,
            'sum_tree': self.sum_tree,
            'buffer_idx': self.buffer_idx,
        }
        if self.IS_weight:
            state['min_tree'] = self.min_tree
        return state

    def load_state_dict(self, _state_dict: Dict, deepcopy: bool = False) -> None:
        # Restore every saved attribute, optionally deep-copying to decouple from the source dict.
        for k, v in _state_dict.items():
            if deepcopy:
                setattr(self, k, copy.deepcopy(v))
            else:
                setattr(self, k, v)

    def __call__(self, action: str, chain: Callable, *args, **kwargs) -> Any:
        # Intercept the buffer actions this middleware cares about; pass everything else through.
        if action in ["push", "sample", "update", "delete", "clear"]:
            return getattr(self, action)(chain, *args, **kwargs)
        return chain(*args, **kwargs)