from functools import partial, lru_cache
from typing import Callable, Optional

import numpy as np

import ding
from .default_helper import one_time_warning


@lru_cache()
def njit():
    """
    Overview:
        Return the decorator used to compile the segment tree kernels with numba.
        ``numba.njit`` is returned only when ``ding.enable_numba`` is True and numba >= 0.53.0 can be
        imported; otherwise ``functools.partial`` is returned, so ``@njit()`` leaves the decorated function
        as plain Python (wrapped in a ``partial`` object).
    Returns:
        - decorator (:obj:`Callable`): ``numba.njit`` or ``functools.partial``.
    """

    try:
        if ding.enable_numba:
            import numba
            from numba import njit as _njit
            # Compare (major, minor) as a tuple: checking only the minor component would wrongly
            # disable numba for any release with a higher major version (e.g. 1.0).
            major, minor = numba.__version__.split(".")[:2]
            if (int(major), int(minor)) < (0, 53):
                _njit = partial  # noqa
                one_time_warning(
                    "Due to your numba version < 0.53.0, DI-engine disables it. "
                    "And you can install numba>=0.53.0 if you want to speed up something"
                )
        else:
            _njit = partial
    except ImportError:
        one_time_warning("If you want to use numba to speed up segment tree, please install numba first")
        _njit = partial
    return _njit


class SegmentTree:
    """
    Overview:
        Segment tree data structure, implemented by a tree-like array. Only the leaf nodes hold real values;
        each non-leaf node stores ``operation`` applied to its left and right children.
    Interfaces:
        ``__init__``, ``reduce``, ``__setitem__``, ``__getitem__``
    """

    def __init__(self, capacity: int, operation: str, neutral_element: Optional[float] = None) -> None:
        """
        Overview:
            Initialize the segment tree. Tree's root node is at index 1.
        Arguments:
            - capacity (:obj:`int`): Capacity of the tree (the number of the leaf nodes), must be a power of 2.
            - operation (:obj:`str`): Name of the operation used to build the tree, one of ``'sum'``, \
                ``'min'`` or ``'max'``.
            - neutral_element (:obj:`float` or :obj:`None`): The value of the neutral element, which is used to \
                init all nodes value in the tree. If None, it defaults to ``0.`` for sum, ``np.inf`` for min \
                and ``-np.inf`` for max.
        Raises:
            - ValueError: If ``operation`` is not one of ``'sum'``, ``'min'``, ``'max'``.
        """
        assert capacity > 0 and capacity & (capacity - 1) == 0
        # The kernels (_setitem / _reduce) only implement these operations; any other value would
        # silently leave the internal nodes un-updated, so reject it even when neutral_element is given.
        if operation not in ('sum', 'min', 'max'):
            raise ValueError("operation argument should be one of 'sum', 'min', 'max'.")
        self.capacity = capacity
        self.operation = operation
        # Set neutral value (initial value) for all elements: the identity element of the operation.
        if neutral_element is None:
            neutral_element = {'sum': 0., 'min': np.inf, 'max': -np.inf}[operation]
        self.neutral_element = neutral_element
        # Index 1 is the root; Index ranging in [capacity, 2 * capacity - 1] are the leaf nodes.
        # For each parent node with index i, left child is value[2*i] and right child is value[2*i+1].
        # value[0] is not part of the tree.
        self.value = np.full([capacity * 2], neutral_element)
        self._compile()

    def reduce(self, start: int = 0, end: Optional[int] = None) -> float:
        """
        Overview:
            Reduce the tree in range ``[start, end)``
        Arguments:
            - start (:obj:`int`): Start index(relative index, the first leaf node is 0), default set to 0
            - end (:obj:`int` or :obj:`None`): End index(relative index), default set to ``self.capacity``
        Returns:
            - reduce_result (:obj:`float`): The reduce result value, which is dependent on data type and operation
        """
        # TODO(nyz) check if directly reduce from the array(value) can be faster
        if end is None:
            end = self.capacity
        # Out-of-range indices would read past the array (unchecked when compiled with numba).
        assert 0 <= start < end <= self.capacity, (start, end)
        # Change to absolute leaf index by adding capacity.
        start += self.capacity
        end += self.capacity
        return _reduce(self.value, start, end, self.neutral_element, self.operation)

    def __setitem__(self, idx: int, val: float) -> None:
        """
        Overview:
            Set ``leaf[idx] = val``; Then update the related nodes.
        Arguments:
            - idx (:obj:`int`): Leaf node index(relative index), should add ``capacity`` to change to absolute index.
            - val (:obj:`float`): The value that will be assigned to ``leaf[idx]``.
        """
        assert (0 <= idx < self.capacity), idx
        # ``idx`` should add ``capacity`` to change to absolute index.
        _setitem(self.value, idx + self.capacity, val, self.operation)

    def __getitem__(self, idx: int) -> float:
        """
        Overview:
            Get ``leaf[idx]``
        Arguments:
            - idx (:obj:`int`): Leaf node ``index(relative index)``, add ``capacity`` to change to absolute index.
        Returns:
            - val (:obj:`float`): The value of ``leaf[idx]``
        """
        assert (0 <= idx < self.capacity)
        return self.value[idx + self.capacity]

    def _compile(self) -> None:
        """
        Overview:
            Call every kernel once on tiny float64, float32 and int64 arrays. When numba is enabled this
            triggers compilation for those dtypes up front; otherwise it is a cheap no-op-like call.
        """

        f64 = np.array([0, 1], dtype=np.float64)
        f32 = np.array([0, 1], dtype=np.float32)
        i64 = np.array([0, 1], dtype=np.int64)
        for d in [f64, f32, i64]:
            _setitem(d, 0, 3.0, 'sum')
            _reduce(d, 0, 1, 0.0, 'min')
            _find_prefixsum_idx(d, 1, 0.5, 0.0)


class SumSegmentTree(SegmentTree):
    """
    Overview:
        Sum segment tree, which is inherited from ``SegmentTree``. Init by passing ``operation='sum'``.
    Interfaces:
        ``__init__``, ``find_prefixsum_idx``
    """

    def __init__(self, capacity: int) -> None:
        """
        Overview:
            Init sum segment tree by passing ``operation='sum'``
        Arguments:
            - capacity (:obj:`int`): Capacity of the tree (the number of the leaf nodes).
        """
        super(SumSegmentTree, self).__init__(capacity, operation='sum')

    def find_prefixsum_idx(self, prefixsum: float, trust_caller: bool = True) -> int:
        """
        Overview:
            Find the highest non-zero index i, sum_{j}leaf[j] <= ``prefixsum`` (where 0 <= j < i)
            and sum_{j}leaf[j] > ``prefixsum`` (where 0 <= j < i+1)
        Arguments:
            - prefixsum (:obj:`float`): The target prefixsum.
            - trust_caller (:obj:`bool`): Whether to trust caller, which means whether to check whether \
                this tree's sum is greater than the input ``prefixsum`` by calling ``reduce`` function.
                Default set to True.
        Returns:
            - idx (:obj:`int`): Eligible index.
        Raises:
            - ValueError: If the search ends on the last leaf and every leaf equals the neutral element.
        """
        if not trust_caller:
            assert 0 <= prefixsum <= self.reduce() + 1e-5, prefixsum
        return _find_prefixsum_idx(self.value, self.capacity, prefixsum, self.neutral_element)


class MinSegmentTree(SegmentTree):
    """
    Overview:
        Min segment tree, which is inherited from ``SegmentTree``. Init by passing ``operation='min'``.
    Interfaces:
        ``__init__``
    """

    def __init__(self, capacity: int) -> None:
        """
        Overview:
            Initialize min segment tree by passing ``operation='min'``
        Arguments:
            - capacity (:obj:`int`): Capacity of the tree (the number of the leaf nodes).
        """
        super(MinSegmentTree, self).__init__(capacity, operation='min')


@njit()
def _setitem(tree: np.ndarray, idx: int, val: float, operation: str) -> None:
    """
    Overview:
        Set ``tree[idx] = val``; Then update every ancestor of ``idx`` up to the root.
    Arguments:
        - tree (:obj:`np.ndarray`): The tree array.
        - idx (:obj:`int`): The absolute index of the leaf node.
        - val (:obj:`float`): The value that will be assigned to ``tree[idx]``.
        - operation (:obj:`str`): The operation used to build the tree: ``'sum'``, ``'min'`` or ``'max'``.
    """

    tree[idx] = val
    # Update from specified node to the root node
    while idx > 1:
        idx = idx >> 1  # To parent node idx
        left, right = tree[2 * idx], tree[2 * idx + 1]
        if operation == 'sum':
            tree[idx] = left + right
        elif operation == 'min':
            tree[idx] = min([left, right])
        elif operation == 'max':
            tree[idx] = max([left, right])


@njit()
def _reduce(tree: np.ndarray, start: int, end: int, neutral_element: float, operation: str) -> float:
    """
    Overview:
        Reduce the tree in range ``[start, end)``
    Arguments:
        - tree (:obj:`np.ndarray`): The tree array.
        - start (:obj:`int`): Start index (absolute index, i.e. relative leaf index + capacity).
        - end (:obj:`int`): End index (absolute index, exclusive).
        - neutral_element (:obj:`float`): The value of the neutral element, used as the initial result.
        - operation (:obj:`str`): The operation used to build the tree: ``'sum'``, ``'min'`` or ``'max'``.
    Returns:
        - result (:obj:`float`): ``operation`` folded over the leaves in ``[start, end)``.
    """

    # Nodes in [start, end) will be aggregated
    result = neutral_element
    while start < end:
        if start & 1:
            # start is odd, so tree[start] is a right child whose parent also covers start - 1 (outside the
            # range): fold tree[start] in directly and move start one node to the right.
            if operation == 'sum':
                result = result + tree[start]
            elif operation == 'min':
                result = min([result, tree[start]])
            elif operation == 'max':
                result = max([result, tree[start]])
            start += 1
        if end & 1:
            # end is odd, so tree[end - 1] is a left child whose parent also covers end (outside the
            # range): decrease end by 1 and fold tree[end] in directly.
            end -= 1
            if operation == 'sum':
                result = result + tree[end]
            elif operation == 'min':
                result = min([result, tree[end]])
            elif operation == 'max':
                result = max([result, tree[end]])
        # Both start and end transform to respective parent node
        start = start >> 1
        end = end >> 1
    return result


@njit()
def _find_prefixsum_idx(tree: np.ndarray, capacity: int, prefixsum: float, neutral_element: float) -> int:
    """
    Overview:
        Find the highest non-zero index i, sum_{j}leaf[j] <= ``prefixsum`` (where 0 <= j < i)
        and sum_{j}leaf[j] > ``prefixsum`` (where 0 <= j < i+1)
    Arguments:
        - tree (:obj:`np.ndarray`): The tree array.
        - capacity (:obj:`int`): Capacity of the tree (the number of the leaf nodes).
        - prefixsum (:obj:`float`): The target prefixsum.
        - neutral_element (:obj:`float`): The value of the neutral element, which is used to init \
            all nodes value in the tree.
    Returns:
        - idx (:obj:`int`): Relative index of the found leaf.
    Raises:
        - ValueError: If the search ends on the last leaf and all leaves equal ``neutral_element``.
    """

    # Descend from the root to a leaf: if the left child's sum exceeds the remaining prefixsum the target
    # lies in the left subtree; otherwise subtract the left sum and continue in the right subtree.
    # Equivalently, with leaf intervals [num_0, num_1), [num_1, num_2), ... [num_k, num_k+1), the function
    # finds which interval the input prefixsum falls in and returns that interval's index.
    idx = 1  # start from root node
    while idx < capacity:
        child_base = 2 * idx
        if tree[child_base] > prefixsum:
            idx = child_base
        else:
            prefixsum -= tree[child_base]
            idx = child_base + 1
    # Special case: The last element of ``self.value`` is neutral_element(0),
    # and caller wants to ``find_prefixsum_idx(root_value)``.
    # However, input prefixsum should be smaller than root_value.
    # Walk left to the last leaf that is not the neutral element.
    if idx == 2 * capacity - 1 and tree[idx] == neutral_element:
        tmp = idx
        while tmp >= capacity and tree[tmp] == neutral_element:
            tmp -= 1
        # tmp == capacity is a valid leaf (relative index 0); only tmp < capacity means every leaf is neutral.
        if tmp >= capacity:
            idx = tmp
        else:
            raise ValueError("All elements in tree are the neutral_element(0), can't find non-zero element")
    assert (tree[idx] != neutral_element)
    return idx - capacity