
ding.utils.default_helper


LimitedSpaceContainer

Overview

A container simulating a limited amount of space: it tracks the current usage cur between min_val and max_val and lets callers acquire and release pieces of space.

Interfaces: __init__, get_residual_space, acquire_space, release_space, increase_space, decrease_space

__init__(min_val, max_val)

Overview

Set min_val and max_val of the container, also set cur to min_val for initialization.

Arguments:
- min_val (:obj:int): Min volume of the container, usually 0.
- max_val (:obj:int): Max volume of the container.

get_residual_space()

Overview

Get all residual pieces of space and set cur to max_val.

Returns:
- ret (:obj:int): Residual space, calculated as max_val - cur.

acquire_space()

Overview

Try to acquire one piece of space. If one is available, increment cur and return True; otherwise return False.

Returns:
- flag (:obj:bool): Whether there is any piece of residual space.

release_space()

Overview

Release only one piece of space. Decrement cur, but ensure it won't be negative.

increase_space()

Overview

Add one piece of space by incrementing max_val.

decrease_space()

Overview

Remove one piece of space by decrementing max_val.
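Taken together, these interfaces behave like a simple counting semaphore. A minimal pure-Python sketch of the documented semantics (the class name here is illustrative, not the library's):

```python
class LimitedSpaceContainerSketch:
    """Minimal sketch of the documented container semantics."""

    def __init__(self, min_val: int, max_val: int) -> None:
        assert max_val >= min_val
        self.min_val, self.max_val = min_val, max_val
        self.cur = min_val  # pieces currently acquired

    def acquire_space(self) -> bool:
        # Take one piece if any capacity remains.
        if self.cur < self.max_val:
            self.cur += 1
            return True
        return False

    def release_space(self) -> None:
        # Give one piece back, never dropping below min_val.
        self.cur = max(self.min_val, self.cur - 1)

    def get_residual_space(self) -> int:
        # Claim all remaining pieces at once and report how many there were.
        ret = self.max_val - self.cur
        self.cur = self.max_val
        return ret
```

Note that get_residual_space both reports the remaining capacity and claims it, so a subsequent acquire_space fails until space is released.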

RunningMeanStd

Bases: object

Overview

Maintains a running mean, variance, and sample count that are updated incrementally with new data batches.

Interfaces: __init__, update, reset, new_shape

Properties: mean, std, _epsilon, _shape, _mean, _var, _count

mean property

Overview

Property mean, read from self._mean.

std property

Overview

Property std, computed from self._var with a small constant added for numerical stability.

__init__(epsilon=0.0001, shape=(), device=torch.device('cpu'))

Overview

Initialize the running statistics and set up the properties.

Arguments:
- epsilon (:obj:float): Small constant used to initialize the count and stabilize the std output
- shape (:obj:tuple): Shape of the tracked mean and variance arrays
- device (:obj:torch.device): Device on which the mean and std tensors are returned

update(x)

Overview

Update the running mean, variance, and count with a new batch of data.

Arguments:
- x (:obj:np.ndarray): the new data batch

reset()

Overview

Reset the tracked statistics to their initial values: _mean, _var, _count.

new_shape(obs_shape, act_shape, rew_shape) staticmethod

Overview

Get new shape of observation, action, and reward; in this case unchanged.

Arguments:
- obs_shape (:obj:Any), act_shape (:obj:Any), rew_shape (:obj:Any)

Returns:
- obs_shape (:obj:Any), act_shape (:obj:Any), rew_shape (:obj:Any)
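The update rule documented above is a parallel (Chan-style) merge of the statistics of two sample sets. A self-contained sketch of one update step in plain Python, assuming scalar data for brevity (the function name is illustrative, not the library's):

```python
def merge_mean_var(mean, var, count, batch):
    """One step of the documented running-statistics update:
    merge the stats of `batch` into the running (mean, var, count)."""
    batch_count = len(batch)
    batch_mean = sum(batch) / batch_count
    batch_var = sum((x - batch_mean) ** 2 for x in batch) / batch_count
    new_count = count + batch_count
    delta = batch_mean - mean
    # Shift the mean toward the batch mean, weighted by batch size.
    new_mean = mean + delta * batch_count / new_count
    # Combine the two sums of squared deviations, plus a cross term.
    m2 = var * count + batch_var * batch_count + delta ** 2 * count * batch_count / new_count
    return new_mean, m2 / new_count, new_count
```

Merging [1, 2, 3] and then [5] yields the same mean and variance as computing them over [1, 2, 3, 5] directly.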

get_shape0(data)

Overview

Get shape[0] of the (possibly nested) torch tensor or treetensor in data.

Arguments:
- data (:obj:Union[List,Dict,torch.Tensor,ttorch.Tensor]): data to be analysed

Returns:
- shape[0] (:obj:int): first dimension length of data, usually the batch size.

lists_to_dicts(data, recursive=False)

Overview

Transform a list of dicts to a dict of lists.

Arguments:
- data (:obj:Union[List[Union[dict, NamedTuple]], Tuple[Union[dict, NamedTuple]]]): The list of dicts to be transformed
- recursive (:obj:bool): whether to recursively process dict elements

Returns:
- new_data (:obj:Union[Mapping[object, object], NamedTuple]): The resulting dict of lists

Example:
>>> from ding.utils import *
>>> lists_to_dicts([{1: 1, 10: 3}, {1: 2, 10: 4}])
{1: [1, 2], 10: [3, 4]}

dicts_to_lists(data)

Overview

Transform a dict of lists to a list of dicts.

Arguments:
- data (:obj:Mapping[object, list]): The dict of lists to be transformed

Returns:
- new_data (:obj:List[Mapping[object, object]]): The resulting list of dicts

Example:
>>> from ding.utils import *
>>> dicts_to_lists({1: [1, 2], 10: [3, 4]})
[{1: 1, 10: 3}, {1: 2, 10: 4}]

override(cls)

Overview

Annotation for documenting method overrides.

Arguments:
- cls (:obj:type): The superclass that provides the overridden method. If cls does not actually have the method, an error is raised.

squeeze(data)

Overview

Squeeze data from a tuple, list, or dict into a single object.

Arguments:
- data (:obj:object): data to be squeezed

Example:
>>> a = (4, )
>>> a = squeeze(a)
>>> print(a)
4

default_get(data, name, default_value=None, default_fn=None, judge_fn=None)

Overview

Get the value of name in data, with generic checks on at least data and name. If name exists in data, return the value at name; otherwise generate a value with default_fn (or use default_value directly), check it with judge_fn, record name in default_get_set, and warn once before returning it.

Arguments:
- data (:obj:dict): Data input dictionary
- name (:obj:str): Key name
- default_value (:obj:Optional[Any]): Default value used when name is missing and default_fn is None
- default_fn (:obj:Optional[Callable]): Function called to generate the default value
- judge_fn (:obj:Optional[Callable]): Function used to check that the default value is legal

Returns:
- value (:obj:Any): The value at name, or the generated default value
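The lookup-with-default logic can be sketched in a few lines of plain Python (warn-once bookkeeping omitted; the function name is illustrative, not the library's):

```python
def default_get_sketch(data, name, default_value=None, default_fn=None, judge_fn=None):
    """Sketch of the documented lookup-with-default logic."""
    if name in data:
        return data[name]
    # A fallback must be available when the key is missing.
    assert default_value is not None or default_fn is not None
    value = default_fn() if default_fn is not None else default_value
    if judge_fn is not None:
        assert judge_fn(value), "default value rejected by judge_fn"
    return value
```

For example, a positive learning rate can be enforced via judge_fn while still falling back to a generated default.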

list_split(data, step)

Overview

Split a list of data into chunks of length step.

Arguments:
- data (:obj:list): List of data to split
- step (:obj:int): Chunk length

Returns:
- ret (:obj:list): List of split chunks.
- residual (:obj:list): Residual list of leftover items. This value is None when len(data) is divisible by step.

Example:
>>> list_split([1,2,3,4],2)
([[1, 2], [3, 4]], None)
>>> list_split([1,2,3,4],3)
([[1, 2, 3]], [4])
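A pure-Python sketch matching the documented behavior, including the None residual for evenly divisible input (the function name is illustrative, not the library's):

```python
def list_split_sketch(data, step):
    """Sketch of the documented chunking behavior."""
    if len(data) < step:
        # Not enough items for even one chunk: everything is residual.
        return [], data
    full = len(data) - len(data) % step  # length covered by whole chunks
    chunks = [data[i:i + step] for i in range(0, full, step)]
    # Leftover items, or None when data divides evenly into chunks.
    residual = data[full:] or None
    return chunks, residual
```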

error_wrapper(fn, default_ret, warning_msg='')

Overview

Wrap the function so that any exception raised inside it is caught and default_ret is returned instead.

Arguments:
- fn (:obj:Callable): the function to be wrapped
- default_ret (:obj:Any): the default return value used when an exception occurs in the function
- warning_msg (:obj:str): optional message printed once when an exception is caught

Returns:
- wrapper (:obj:Callable): the wrapped function

Examples:
>>> # Used to check for FakeLink (refer to utils.linklink_dist_helper.py)
>>> def get_rank():  # Get the rank of linklink model, return 0 if use FakeLink.
>>>     if is_fake_link:
>>>         return 0
>>>     return error_wrapper(link.get_rank, 0)()
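Stripped of the warning machinery, the wrapper pattern looks like this (a sketch, not the library function):

```python
def error_wrapper_sketch(fn, default_ret):
    """Sketch: call fn, fall back to default_ret on any exception."""

    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception:
            # Swallow the error and substitute the fallback value.
            return default_ret

    return wrapper
```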

deep_merge_dicts(original, new_dict)

Overview

Merge two dicts by calling deep_update.

Arguments:
- original (:obj:dict): First dict.
- new_dict (:obj:dict): Second dict.

Returns:
- merged_dict (:obj:dict): A new dict with original and new_dict deeply merged.

deep_update(original, new_dict, new_keys_allowed=False, whitelist=None, override_all_if_type_changes=None)

Overview

Update original dict with values from new_dict recursively.

Arguments:
- original (:obj:dict): Dictionary with default values.
- new_dict (:obj:dict): Dictionary with values to be updated.
- new_keys_allowed (:obj:bool): Whether new keys are allowed.
- whitelist (:obj:Optional[List[str]]): List of keys that correspond to dict values where new subkeys can be introduced. This applies only at the top level.
- override_all_if_type_changes (:obj:Optional[List[str]]): List of top-level keys with value=dict, for which we always simply override the entire value (:obj:dict) if the "type" key in that value dict changes.

.. note::

If a new key is introduced in new_dict and new_keys_allowed is not
True, an error is raised. Further, for sub-dicts, if the key is
in the whitelist, new subkeys can be introduced.
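The core recursion, with the whitelist and type-override handling omitted, can be sketched as follows (the function name is illustrative, not the library's):

```python
def deep_update_sketch(original, new_dict, new_keys_allowed=False):
    """Sketch of the documented recursive merge (whitelist handling omitted)."""
    for k, value in new_dict.items():
        if k not in original and not new_keys_allowed:
            raise RuntimeError("Unknown config parameter `{}`".format(k))
        if isinstance(original.get(k), dict) and isinstance(value, dict):
            # Both sides are dicts: descend and merge key by key.
            deep_update_sketch(original[k], value, new_keys_allowed)
        else:
            # Otherwise the new value simply overrides the old one.
            original[k] = value
    return original
```

Sub-dicts are merged key by key rather than replaced wholesale, which is what makes this suitable for layering a user config over defaults.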

flatten_dict(data, delimiter='/')

Overview

Flatten a nested dict into a one-level dict whose keys join the nesting path with delimiter.

Arguments:
- data (:obj:dict): Original nested dict
- delimiter (:obj:str): Delimiter for the keys of the new dict

Returns:
- data (:obj:dict): Flattened dict

Example:
>>> a
{'a': {'b': 100}}
>>> flatten_dict(a)
{'a/b': 100}
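The same flattening can be expressed recursively; a self-contained sketch (the library's implementation iterates with a work loop instead, and the name here is illustrative):

```python
def flatten_dict_sketch(data, delimiter="/", prefix=""):
    """Recursive sketch of the documented flattening."""
    out = {}
    for k, v in data.items():
        # Join the nesting path into a single key.
        key = prefix + delimiter + k if prefix else k
        if isinstance(v, dict):
            out.update(flatten_dict_sketch(v, delimiter, key))
        else:
            out[key] = v
    return out
```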

set_pkg_seed(seed, use_cuda=True)

Overview

Side-effect function that seeds random, numpy.random, and torch (manual seed). This is usually called in an entry script when setting the random seed for all packages and instances.

Arguments:
- seed (:obj:int): The seed value
- use_cuda (:obj:bool): Whether CUDA is used

Examples:
>>> # ../entry/xxxenv_xxxpolicy_main.py
>>> ...  # Set random seed for all package and instance
>>> collector_env.seed(seed)
>>> evaluator_env.seed(seed, dynamic_seed=False)
>>> set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
>>> ...  # Set up RL Policy, etc.
>>> ...

one_time_warning(warning_msg) cached

Overview

Print warning message only once.

Arguments: - warning_msg (:obj:str): Warning message.

split_fn(data, indices, start, end)

Overview

Split data by indices

Arguments:
- data (:obj:Union[List, Dict, torch.Tensor, ttorch.Tensor]): data to be split
- indices (:obj:np.ndarray): indices used to select elements
- start (:obj:int): start index into indices
- end (:obj:int): end index into indices

split_data_generator(data, split_size, shuffle=True)

Overview

Split data into batches

Arguments:
- data (:obj:dict): data to be split
- split_size (:obj:int): size of each batch
- shuffle (:obj:bool): whether to shuffle before splitting

make_key_as_identifier(data)

Overview

Make the keys of a dict into legal python identifier strings so that they are compatible with python magic methods such as __getattr__.

Arguments:
- data (:obj:Dict[str, Any]): The original dict data.

Return:
- new_data (:obj:Dict[str, Any]): The new dict data with legal identifier keys.
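The key legalization documented above amounts to two fixes: prefix keys that start with a digit, and replace dots. A sketch (the helper name here is illustrative):

```python
def legalize_key(s: str) -> str:
    """Sketch of the documented key legalization."""
    if s[0].isdigit():
        # Identifiers cannot start with a digit.
        s = '_' + s
    # Dots are not valid in identifiers either.
    return s.replace('.', '_')
```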

remove_illegal_item(data)

Overview

Remove illegal items from a dict, such as str values, which are not compatible with Tensor.

Arguments:
- data (:obj:Dict[str, Any]): The original dict data.

Return:
- new_data (:obj:Dict[str, Any]): The new dict data without the illegal items.
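The filtering reduces to a dict comprehension; a sketch with an illustrative name:

```python
def remove_str_items(data):
    """Sketch: drop string-valued entries, which cannot be stacked into tensors."""
    return {k: v for k, v in data.items() if not isinstance(v, str)}
```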

Full Source Code

../ding/utils/default_helper.py

from typing import Union, Mapping, List, NamedTuple, Tuple, Callable, Optional, Any, Dict
import copy
from ditk import logging
import random
from functools import lru_cache  # in python3.9, we can change to cache
import numpy as np
import torch
import treetensor.torch as ttorch


def get_shape0(data: Union[List, Dict, torch.Tensor, ttorch.Tensor]) -> int:
    """
    Overview:
        Get shape[0] of data's torch tensor or treetensor
    Arguments:
        - data (:obj:`Union[List,Dict,torch.Tensor,ttorch.Tensor]`): data to be analysed
    Returns:
        - shape[0] (:obj:`int`): first dimension length of data, usually the batchsize.
    """
    if isinstance(data, list) or isinstance(data, tuple):
        return get_shape0(data[0])
    elif isinstance(data, dict):
        for k, v in data.items():
            return get_shape0(v)
    elif isinstance(data, torch.Tensor):
        return data.shape[0]
    elif isinstance(data, ttorch.Tensor):

        def fn(t):
            item = list(t.values())[0]
            if np.isscalar(item[0]):
                return item[0]
            else:
                return fn(item)

        return fn(data.shape)
    else:
        raise TypeError("Error in getting shape0, not support type: {}".format(data))


def lists_to_dicts(
        data: Union[List[Union[dict, NamedTuple]], Tuple[Union[dict, NamedTuple]]],
        recursive: bool = False,
) -> Union[Mapping[object, object], NamedTuple]:
    """
    Overview:
        Transform a list of dicts to a dict of lists.
    Arguments:
        - data (:obj:`Union[List[Union[dict, NamedTuple]], Tuple[Union[dict, NamedTuple]]]`):
            A dict of lists need to be transformed
        - recursive (:obj:`bool`): whether recursively deals with dict element
    Returns:
        - newdata (:obj:`Union[Mapping[object, object], NamedTuple]`): A list of dicts as a result
    Example:
        >>> from ding.utils import *
        >>> lists_to_dicts([{1: 1, 10: 3}, {1: 2, 10: 4}])
        {1: [1, 2], 10: [3, 4]}
    """
    if len(data) == 0:
        raise ValueError("empty data")
    if isinstance(data[0], dict):
        if recursive:
            new_data = {}
            for k in data[0].keys():
                if isinstance(data[0][k], dict) and k != 'prev_state':
                    tmp = [data[b][k] for b in range(len(data))]
                    new_data[k] = lists_to_dicts(tmp)
                else:
                    new_data[k] = [data[b][k] for b in range(len(data))]
        else:
            new_data = {k: [data[b][k] for b in range(len(data))] for k in data[0].keys()}
    elif isinstance(data[0], tuple) and hasattr(data[0], '_fields'):  # namedtuple
        new_data = type(data[0])(*list(zip(*data)))
    else:
        raise TypeError("not support element type: {}".format(type(data[0])))
    return new_data


def dicts_to_lists(data: Mapping[object, List[object]]) -> List[Mapping[object, object]]:
    """
    Overview:
        Transform a dict of lists to a list of dicts.

    Arguments:
        - data (:obj:`Mapping[object, list]`): A list of dicts need to be transformed

    Returns:
        - newdata (:obj:`List[Mapping[object, object]]`): A dict of lists as a result

    Example:
        >>> from ding.utils import *
        >>> dicts_to_lists({1: [1, 2], 10: [3, 4]})
        [{1: 1, 10: 3}, {1: 2, 10: 4}]
    """
    new_data = [v for v in data.values()]
    new_data = [{k: v for k, v in zip(data.keys(), t)} for t in list(zip(*new_data))]
    return new_data


def override(cls: type) -> Callable[[
        Callable,
], Callable]:
    """
    Overview:
        Annotation for documenting method overrides.

    Arguments:
        - cls (:obj:`type`): The superclass that provides the overridden method. If this
            cls does not actually have the method, an error is raised.
    """

    def check_override(method: Callable) -> Callable:
        if method.__name__ not in dir(cls):
            raise NameError("{} does not override any method of {}".format(method, cls))
        return method

    return check_override


def squeeze(data: object) -> object:
    """
    Overview:
        Squeeze data from tuple, list or dict to single object
    Arguments:
        - data (:obj:`object`): data to be squeezed
    Example:
        >>> a = (4, )
        >>> a = squeeze(a)
        >>> print(a)
        >>> 4
    """
    if isinstance(data, tuple) or isinstance(data, list):
        if len(data) == 1:
            return data[0]
        else:
            return tuple(data)
    elif isinstance(data, dict):
        if len(data) == 1:
            return list(data.values())[0]
    return data


default_get_set = set()


def default_get(
        data: dict,
        name: str,
        default_value: Optional[Any] = None,
        default_fn: Optional[Callable] = None,
        judge_fn: Optional[Callable] = None
) -> Any:
    """
    Overview:
        Getting the value by input, checks generically on the inputs with \
        at least ``data`` and ``name``. If ``name`` exists in ``data``, \
        get the value at ``name``; else, add ``name`` to ``default_get_set``\
        with value generated by \
        ``default_fn`` (or directly as ``default_value``) that \
        is checked by ``judge_fn`` to be legal.
    Arguments:
        - data(:obj:`dict`): Data input dictionary
        - name(:obj:`str`): Key name
        - default_value(:obj:`Optional[Any]`) = None,
        - default_fn(:obj:`Optional[Callable]`) = Value
        - judge_fn(:obj:`Optional[Callable]`) = None
    Returns:
        - value(:obj:`Any`): The value at ``name``, or the default value
    """
    if name in data:
        return data[name]
    else:
        assert default_value is not None or default_fn is not None
        value = default_fn() if default_fn is not None else default_value
        if judge_fn:
            assert judge_fn(value), "default value({}) is not accepted by judge_fn".format(type(value))
        if name not in default_get_set:
            logging.warning("{} use default value {}".format(name, value))
            default_get_set.add(name)
        return value


def list_split(data: list, step: int) -> List[list]:
    """
    Overview:
        Split list of data by step.
    Arguments:
        - data(:obj:`list`): List of data for splitting
        - step(:obj:`int`): Number of step for splitting
    Returns:
        - ret(:obj:`list`): List of split data.
        - residual(:obj:`list`): Residual list. This value is ``None`` when ``data`` divides ``steps``.
    Example:
        >>> list_split([1,2,3,4],2)
        ([[1, 2], [3, 4]], None)
        >>> list_split([1,2,3,4],3)
        ([[1, 2, 3]], [4])
    """
    if len(data) < step:
        return [], data
    ret = []
    divide_num = len(data) // step
    for i in range(divide_num):
        start, end = i * step, (i + 1) * step
        ret.append(data[start:end])
    if divide_num * step < len(data):
        residual = data[divide_num * step:]
    else:
        residual = None
    return ret, residual


def error_wrapper(fn, default_ret, warning_msg=""):
    """
    Overview:
        Wrap the function, so that any Exception in the function will be caught and return the default_ret
    Arguments:
        - fn (:obj:`Callable`): the function to be wrapped
        - default_ret (:obj:`obj`): the default return when an Exception occurred in the function
    Returns:
        - wrapper (:obj:`Callable`): the wrapped function
    Examples:
        >>> # Used to check for FakeLink (Refer to utils.linklink_dist_helper.py)
        >>> def get_rank():  # Get the rank of linklink model, return 0 if use FakeLink.
        >>>     if is_fake_link:
        >>>         return 0
        >>>     return error_wrapper(link.get_rank, 0)()
    """

    def wrapper(*args, **kwargs):
        try:
            ret = fn(*args, **kwargs)
        except Exception as e:
            ret = default_ret
            if warning_msg != "":
                one_time_warning(warning_msg + "\ndefault_ret = {}\terror = {}".format(default_ret, e))
        return ret

    return wrapper


class LimitedSpaceContainer:
    """
    Overview:
        A space simulator.
    Interfaces:
        ``__init__``, ``get_residual_space``, ``release_space``
    """

    def __init__(self, min_val: int, max_val: int) -> None:
        """
        Overview:
            Set ``min_val`` and ``max_val`` of the container, also set ``cur`` to ``min_val`` for initialization.
        Arguments:
            - min_val (:obj:`int`): Min volume of the container, usually 0.
            - max_val (:obj:`int`): Max volume of the container.
        """
        self.min_val = min_val
        self.max_val = max_val
        assert (max_val >= min_val)
        self.cur = self.min_val

    def get_residual_space(self) -> int:
        """
        Overview:
            Get all residual pieces of space. Set ``cur`` to ``max_val``
        Arguments:
            - ret (:obj:`int`): Residual space, calculated by ``max_val`` - ``cur``.
        """
        ret = self.max_val - self.cur
        self.cur = self.max_val
        return ret

    def acquire_space(self) -> bool:
        """
        Overview:
            Try to get one piece of space. If there is one, return True; otherwise return False.
        Returns:
            - flag (:obj:`bool`): Whether there is any piece of residual space.
        """
        if self.cur < self.max_val:
            self.cur += 1
            return True
        else:
            return False

    def release_space(self) -> None:
        """
        Overview:
            Release only one piece of space. Decrement ``cur``, but ensure it won't be negative.
        """
        self.cur = max(self.min_val, self.cur - 1)

    def increase_space(self) -> None:
        """
        Overview:
            Increase one piece in space. Increment ``max_val``.
        """
        self.max_val += 1

    def decrease_space(self) -> None:
        """
        Overview:
            Decrease one piece in space. Decrement ``max_val``.
        """
        self.max_val -= 1


def deep_merge_dicts(original: dict, new_dict: dict) -> dict:
    """
    Overview:
        Merge two dicts by calling ``deep_update``
    Arguments:
        - original (:obj:`dict`): Dict 1.
        - new_dict (:obj:`dict`): Dict 2.
    Returns:
        - merged_dict (:obj:`dict`): A new dict that is d1 and d2 deeply merged.
    """
    original = original or {}
    new_dict = new_dict or {}
    merged = copy.deepcopy(original)
    if new_dict:  # if new_dict is neither empty dict nor None
        deep_update(merged, new_dict, True, [])
    return merged


def deep_update(
        original: dict,
        new_dict: dict,
        new_keys_allowed: bool = False,
        whitelist: Optional[List[str]] = None,
        override_all_if_type_changes: Optional[List[str]] = None
):
    """
    Overview:
        Update original dict with values from new_dict recursively.
    Arguments:
        - original (:obj:`dict`): Dictionary with default values.
        - new_dict (:obj:`dict`): Dictionary with values to be updated
        - new_keys_allowed (:obj:`bool`): Whether new keys are allowed.
        - whitelist (:obj:`Optional[List[str]]`):
            List of keys that correspond to dict
            values where new subkeys can be introduced. This is only at the top
            level.
        - override_all_if_type_changes(:obj:`Optional[List[str]]`):
            List of top level
            keys with value=dict, for which we always simply override the
            entire value (:obj:`dict`), if the "type" key in that value dict changes.

    .. note::

        If new key is introduced in new_dict, then if new_keys_allowed is not
        True, an error will be thrown. Further, for sub-dicts, if the key is
        in the whitelist, then new subkeys can be introduced.
    """
    whitelist = whitelist or []
    override_all_if_type_changes = override_all_if_type_changes or []

    for k, value in new_dict.items():
        if k not in original and not new_keys_allowed:
            raise RuntimeError("Unknown config parameter `{}`. Base config have: {}.".format(k, original.keys()))

        # Both original value and new one are dicts.
        if isinstance(original.get(k), dict) and isinstance(value, dict):
            # Check old type vs old one. If different, override entire value.
            if k in override_all_if_type_changes and \
                    "type" in value and "type" in original[k] and \
                    value["type"] != original[k]["type"]:
                original[k] = value
            # Whitelisted key -> ok to add new subkeys.
            elif k in whitelist:
                deep_update(original[k], value, True)
            # Non-whitelisted key.
            else:
                deep_update(original[k], value, new_keys_allowed)
        # Original value not a dict OR new value not a dict:
        # Override entire value.
        else:
            original[k] = value
    return original


def flatten_dict(data: dict, delimiter: str = "/") -> dict:
    """
    Overview:
        Flatten the dict, see example
    Arguments:
        - data (:obj:`dict`): Original nested dict
        - delimiter (str): Delimiter of the keys of the new dict
    Returns:
        - data (:obj:`dict`): Flattened nested dict
    Example:
        >>> a
        {'a': {'b': 100}}
        >>> flatten_dict(a)
        {'a/b': 100}
    """
    data = copy.deepcopy(data)
    while any(isinstance(v, dict) for v in data.values()):
        remove = []
        add = {}
        for key, value in data.items():
            if isinstance(value, dict):
                for subkey, v in value.items():
                    add[delimiter.join([key, subkey])] = v
                remove.append(key)
        data.update(add)
        for k in remove:
            del data[k]
    return data


def set_pkg_seed(seed: int, use_cuda: bool = True) -> None:
    """
    Overview:
        Side effect function to set seed for ``random``, ``numpy random``, and ``torch's manual seed``.\
        This is usually used in entry script in the section of setting random seed for all package and instance
    Argument:
        - seed(:obj:`int`): Set seed
        - use_cuda(:obj:`bool`) Whether use CUDA
    Examples:
        >>> # ../entry/xxxenv_xxxpolicy_main.py
        >>> ...
        # Set random seed for all package and instance
        >>> collector_env.seed(seed)
        >>> evaluator_env.seed(seed, dynamic_seed=False)
        >>> set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
        >>> ...
        # Set up RL Policy, etc.
        >>> ...
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda and torch.cuda.is_available():
        torch.cuda.manual_seed(seed)


@lru_cache()
def one_time_warning(warning_msg: str) -> None:
    """
    Overview:
        Print warning message only once.
    Arguments:
        - warning_msg (:obj:`str`): Warning message.
    """

    logging.warning(warning_msg)


def split_fn(data, indices, start, end):
    """
    Overview:
        Split data by indices
    Arguments:
        - data (:obj:`Union[List, Dict, torch.Tensor, ttorch.Tensor]`): data to be analysed
        - indices (:obj:`np.ndarray`): indices to split
        - start (:obj:`int`): start index
        - end (:obj:`int`): end index
    """

    if data is None:
        return None
    elif isinstance(data, list):
        return [split_fn(d, indices, start, end) for d in data]
    elif isinstance(data, dict):
        return {k1: split_fn(v1, indices, start, end) for k1, v1 in data.items()}
    elif isinstance(data, str):
        return data
    else:
        return data[indices[start:end]]


def split_data_generator(data: dict, split_size: int, shuffle: bool = True) -> dict:
    """
    Overview:
        Split data into batches
    Arguments:
        - data (:obj:`dict`): data to be analysed
        - split_size (:obj:`int`): split size
        - shuffle (:obj:`bool`): whether shuffle
    """

    assert isinstance(data, dict), type(data)
    length = []
    for k, v in data.items():
        if v is None:
            continue
        elif k in ['prev_state', 'prev_actor_state', 'prev_critic_state']:
            length.append(len(v))
        elif isinstance(v, list) or isinstance(v, tuple):
            if isinstance(v[0], str):
                # some buffer data contains useless string infos, such as 'buffer_id',
                # which should not be split, so we just skip it
                continue
            else:
                length.append(get_shape0(v[0]))
        elif isinstance(v, dict):
            length.append(len(v[list(v.keys())[0]]))
        else:
            length.append(len(v))

    assert len(length) > 0
    # assert len(set(length)) == 1, "data values must have the same length: {}".format(length)
    # if continuous action, data['logit'] is list of length 2
    length = length[0]
    assert split_size >= 1
    if shuffle:
        indices = np.random.permutation(length)
    else:
        indices = np.arange(length)
    for i in range(0, length, split_size):
        if i + split_size > length:
            i = length - split_size
        batch = split_fn(data, indices, i, i + split_size)
        yield batch


class RunningMeanStd(object):
    """
    Overview:
        Wrapper to update new variable, new mean, and new count
    Interfaces:
        ``__init__``, ``update``, ``reset``, ``new_shape``
    Properties:
        - ``mean``, ``std``, ``_epsilon``, ``_shape``, ``_mean``, ``_var``, ``_count``
    """

    def __init__(self, epsilon=1e-4, shape=(), device=torch.device('cpu')):
        """
        Overview:
            Initialize ``self.`` See ``help(type(self))`` for accurate \
            signature; setup the properties.
        Arguments:
            - epsilon (:obj:`Float`): the epsilon used for self for the std output
            - shape (:obj:`np.array`): the np array shape used for the expression \
                of this wrapper on attributes of mean and variance
        """
        self._epsilon = epsilon
        self._shape = shape
        self._device = device
        self.reset()

    def update(self, x):
        """
        Overview:
            Update mean, variance, and count
        Arguments:
            - ``x``: the batch
        """
        batch_mean = np.mean(x, axis=0)
        batch_var = np.var(x, axis=0)
        batch_count = x.shape[0]

        new_count = batch_count + self._count
        mean_delta = batch_mean - self._mean
        new_mean = self._mean + mean_delta * batch_count / new_count
        # this method for calculating new variance might be numerically unstable
        m_a = self._var * self._count
        m_b = batch_var * batch_count
        m2 = m_a + m_b + np.square(mean_delta) * self._count * batch_count / new_count
        new_var = m2 / new_count
        self._mean = new_mean
        self._var = new_var
        self._count = new_count

    def reset(self):
        """
        Overview:
            Reset the tracked statistics and reset properties: ``_mean``, ``_var``, ``_count``
        """
        if len(self._shape) > 0:
            self._mean = np.zeros(self._shape, 'float32')
            self._var = np.ones(self._shape, 'float32')
        else:
            self._mean, self._var = 0., 1.
        self._count = self._epsilon

    @property
    def mean(self) -> np.ndarray:
        """
        Overview:
            Property ``mean`` gotten from ``self._mean``
        """
        if np.isscalar(self._mean):
            return self._mean
        else:
            return torch.FloatTensor(self._mean).to(self._device)

    @property
    def std(self) -> np.ndarray:
        """
        Overview:
            Property ``std`` calculated from ``self._var`` and the epsilon value of ``self._epsilon``
        """
        std = np.sqrt(self._var + 1e-8)
        if np.isscalar(std):
            return std
        else:
            return torch.FloatTensor(std).to(self._device)

    @staticmethod
    def new_shape(obs_shape, act_shape, rew_shape):
        """
        Overview:
            Get new shape of observation, action, and reward; in this case unchanged.
        Arguments:
            obs_shape (:obj:`Any`), act_shape (:obj:`Any`), rew_shape (:obj:`Any`)
        Returns:
            obs_shape (:obj:`Any`), act_shape (:obj:`Any`), rew_shape (:obj:`Any`)
        """
        return obs_shape, act_shape, rew_shape


def make_key_as_identifier(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Overview:
        Make the key of dict into legal python identifier string so that it is
        compatible with some python magic method such as ``__getattr__``.
    Arguments:
        - data (:obj:`Dict[str, Any]`): The original dict data.
    Return:
        - new_data (:obj:`Dict[str, Any]`): The new dict data with legal identifier keys.
    """

    def legalization(s: str) -> str:
        if s[0].isdigit():
            s = '_' + s
        return s.replace('.', '_')

    new_data = {}
    for k in data:
        new_k = legalization(k)
        new_data[new_k] = data[k]
    return new_data


def remove_illegal_item(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Overview:
        Remove illegal item in dict info, like str, which is not compatible with Tensor.
    Arguments:
        - data (:obj:`Dict[str, Any]`): The original dict data.
    Return:
        - new_data (:obj:`Dict[str, Any]`): The new dict data without the illegal items.
    """
    new_data = {}
    for k, v in data.items():
        if isinstance(v, str):
            continue
        new_data[k] = data[k]
    return new_data