ding.worker.collector.base_serial_collector

ISerialCollector

Bases: ABC

Overview

Abstract base class for serial collectors.

Interfaces: default_config, reset_env, reset_policy, reset, collect

Property: envstep

default_config() classmethod

Overview

Get the collector's default config. The collector's default config is merged with other components' default configs and the user's config to produce the final config.

Return: - cfg (:obj:EasyDict): The collector's default config.
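
The merge order can be sketched with plain dicts (EasyDict is essentially a dict with attribute access). `deep_merge` here is a hypothetical helper for illustration, not part of ding; later sources take precedence over earlier ones:

```python
import copy

def deep_merge(base: dict, override: dict) -> dict:
    """Recursively overlay `override` onto a deep copy of `base`."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

# Collector class defaults, then the user's config on top.
collector_default = {'n_sample': 8, 'collect_print_freq': 100}
user_cfg = {'n_sample': 16}
final_cfg = deep_merge(collector_default, user_cfg)
```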

reset_env(_env=None) abstractmethod

Overview

Reset the collector's environment. In some cases, the collector needs to use the same policy to collect data in different environments; reset_env resets the environment for this purpose.

reset_policy(_policy=None) abstractmethod

Overview

Reset the collector's policy. In some cases, the collector needs to work in the same environment but use a different policy to collect data; reset_policy resets the policy for this purpose.

reset(_policy=None, _env=None) abstractmethod

Overview

Reset the collector's policy and environment, then use the new policy and environment to collect data.

collect(per_collect_target) abstractmethod

Overview

Collect and return data according to the specified target. The target is defined differently in episode mode (number of episodes) and sample mode (number of samples).

envstep property

Overview

Get the total number of environment steps (envstep).

CachePool

Bases: object

Overview

CachePool is the repository of cache items.

Interfaces: __init__, update, __getitem__, reset

__init__(name, env_num, deepcopy=False)

Overview

Initialization method.

Arguments: - name (:obj:str): name of cache - env_num (:obj:int): number of environments - deepcopy (:obj:bool): whether to deepcopy data

update(data)

Overview

Update elements in cache pool.

Arguments: - data (:obj:Dict[int, Any]): A dict containing update index-value pairs. Key is index in cache pool, and value is the new element.

__getitem__(idx)

Overview

Get item in cache pool.

Arguments: - idx (:obj:int): The index of the item we need to get.

Return: - item (:obj:Any): The item we get.

reset(idx)

Overview

Reset the given position in the cache pool.

Arguments: - idx (:obj:int): The index of the position we need to reset.
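
The class is small enough to restate in full (the code below mirrors the implementation in the full source listing at the end of this page). Each environment gets one slot, and the keys of the update dict are environment indices:

```python
import copy

class CachePool:
    """One cache slot per environment; see the full source listing."""

    def __init__(self, name, env_num, deepcopy=False):
        self._pool = [None for _ in range(env_num)]
        self._deepcopy = deepcopy

    def update(self, data):
        # A single dict is treated as a one-element list of updates.
        if isinstance(data, dict):
            data = [data]
        for index in range(len(data)):
            for i in data[index].keys():
                d = data[index][i]
                copy_d = copy.deepcopy(d) if self._deepcopy else d
                if index == 0:
                    self._pool[i] = [copy_d]
                else:
                    self._pool[i].append(copy_d)

    def __getitem__(self, idx):
        data = self._pool[idx]
        # Unwrap single-element slots for convenience.
        if data is not None and len(data) == 1:
            data = data[0]
        return data

    def reset(self, idx):
        self._pool[idx] = None

# Three environments; only envs 0 and 2 have new observations.
obs_pool = CachePool('obs', env_num=3)
obs_pool.update({0: 'obs_a', 2: 'obs_c'})
```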

TrajBuffer

Bases: list

Overview

TrajBuffer is used to store up to traj_len transitions.

Interfaces: __init__, append

__init__(maxlen, *args, deepcopy=False, **kwargs)

Overview

Initialize the TrajBuffer.

Arguments: - maxlen (:obj:int): The maximum length of the trajectory buffer. - deepcopy (:obj:bool): Whether to deepcopy data when performing operations.

append(data)

Overview

Append data to the TrajBuffer, evicting the oldest entries when maxlen is exceeded.
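
The eviction behavior is easy to see in isolation (the class below mirrors the implementation in the full source listing): once the buffer is full, each append drops the oldest transition first.

```python
import copy

class TrajBuffer(list):
    """A list with a length cap; see the full source listing."""

    def __init__(self, maxlen, *args, deepcopy=False, **kwargs):
        self._maxlen = maxlen
        self._deepcopy = deepcopy
        super().__init__(*args, **kwargs)

    def append(self, data):
        # Evict the oldest transitions until there is room for the new one.
        if self._maxlen is not None:
            while len(self) >= self._maxlen:
                del self[0]
        if self._deepcopy:
            data = copy.deepcopy(data)
        super().append(data)

buf = TrajBuffer(maxlen=2)
for t in range(4):
    buf.append({'step': t})
# Only the two most recent transitions remain.
```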

create_serial_collector(cfg, **kwargs)

Overview

Create a specific collector instance based on the config.

get_serial_collector_cls(cfg)

Overview

Get the specific collector class according to the config.
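
Both helpers rely on SERIAL_COLLECTOR_REGISTRY, which maps a `type` string from the config to a registered collector class. A toy dict-based registry sketches the idea; the `Registry` class and the `'sample'` registration name here are illustrative assumptions, only the `get`/`build` usage comes from the source:

```python
class Registry(dict):
    """Toy registry: maps a type string to a class (illustrative only)."""

    def register(self, name):
        def decorator(cls):
            self[name] = cls
            return cls
        return decorator

    def get(self, name):
        # Mirrors SERIAL_COLLECTOR_REGISTRY.get(cfg.type)
        return self[name]

    def build(self, name, **kwargs):
        # Mirrors SERIAL_COLLECTOR_REGISTRY.build(cfg.type, cfg=cfg, **kwargs)
        return self[name](**kwargs)

SERIAL_COLLECTOR_REGISTRY = Registry()

@SERIAL_COLLECTOR_REGISTRY.register('sample')
class SampleSerialCollector:
    def __init__(self, cfg=None):
        self.cfg = cfg

cfg = {'type': 'sample'}
collector = SERIAL_COLLECTOR_REGISTRY.build(cfg['type'], cfg=cfg)
```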

to_tensor_transitions(data, shallow_copy_next_obs=True)

Overview

Transform the original transitions returned from the env to tensor format.

Arguments: - data (:obj:List[Dict[str, Any]]): The data that will be transformed to tensor. - shallow_copy_next_obs (:obj:bool): Whether to shallow copy next_obs. Default: True.

Return: - data (:obj:List[Dict[str, Any]]): The transformed tensor-like data.

.. tip:: To save memory, if next_obs is present in the passed data, next_obs receives special treatment: the next_obs of each state in the data fragment becomes the next state's obs, and only the last state keeps its own next_obs. Besides, transform_scalar is set to False to avoid the extra .item() operation.
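
The memory-saving rewiring can be shown without torch: after the loop, each transition's next_obs is a shared reference to the following transition's obs, so no separate copy is stored. `rewire_next_obs` below is a stand-in for the aliasing step inside to_tensor_transitions, not ding's actual function:

```python
def rewire_next_obs(data):
    """Alias each transition's next_obs to the next transition's obs;
    only the last transition keeps its own next_obs."""
    for i in range(len(data) - 1):
        data[i]['next_obs'] = data[i + 1]['obs']
    return data

transitions = [
    {'obs': [0.0], 'next_obs': [1.0]},
    {'obs': [1.0], 'next_obs': [2.0]},
    {'obs': [2.0], 'next_obs': [3.0]},
]
rewire_next_obs(transitions)
```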

Full Source Code

../ding/worker/collector/base_serial_collector.py

```python
from abc import ABC, abstractmethod, abstractproperty
from typing import List, Dict, Any, Optional, Union
from collections import namedtuple
from easydict import EasyDict
import copy

from ding.envs import BaseEnvManager
from ding.utils import SERIAL_COLLECTOR_REGISTRY, import_module
from ding.torch_utils import to_tensor

INF = float("inf")


class ISerialCollector(ABC):
    """
    Overview:
        Abstract base class for serial collector.
    Interfaces:
        default_config, reset_env, reset_policy, reset, collect
    Property:
        envstep
    """

    @classmethod
    def default_config(cls: type) -> EasyDict:
        """
        Overview:
            Get collector's default config. We merge collector's default config with other default configs \
            and user's config to get the final config.
        Return:
            cfg: (:obj:`EasyDict`): collector's default config
        """
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    @abstractmethod
    def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
        """
        Overview:
            Reset collector's environment. In some cases, we need the collector to use the same policy to collect \
            data in different environments. We can use reset_env to reset the environment.
        """
        raise NotImplementedError

    @abstractmethod
    def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
        """
        Overview:
            Reset collector's policy. In some cases, we need the collector to work in the same environment but use \
            a different policy to collect data. We can use reset_policy to reset the policy.
        """
        raise NotImplementedError

    @abstractmethod
    def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None:
        """
        Overview:
            Reset collector's policy and environment. Use new policy and environment to collect data.
        """
        raise NotImplementedError

    @abstractmethod
    def collect(self, per_collect_target: Any) -> List[Any]:
        """
        Overview:
            Collect the corresponding data according to the specified target and return. \
            There are different definitions in episode and sample mode.
        """
        raise NotImplementedError

    @abstractproperty
    def envstep(self) -> int:
        """
        Overview:
            Get the total envstep num.
        """
        raise NotImplementedError


def create_serial_collector(cfg: EasyDict, **kwargs) -> ISerialCollector:
    """
    Overview:
        Create a specific collector instance based on the config.
    """
    import_module(cfg.get('import_names', []))
    return SERIAL_COLLECTOR_REGISTRY.build(cfg.type, cfg=cfg, **kwargs)


def get_serial_collector_cls(cfg: EasyDict) -> type:
    """
    Overview:
        Get the specific collector class according to the config.
    """
    assert hasattr(cfg, 'type'), "{}-{}-{}".format(type(cfg), cfg.keys(), cfg['type'])
    import_module(cfg.get('import_names', []))
    return SERIAL_COLLECTOR_REGISTRY.get(cfg.type)


class CachePool(object):
    """
    Overview:
        CachePool is the repository of cache items.
    Interfaces:
        __init__, update, __getitem__, reset
    """

    def __init__(self, name: str, env_num: int, deepcopy: bool = False) -> None:
        """
        Overview:
            Initialization method.
        Arguments:
            - name (:obj:`str`): name of cache
            - env_num (:obj:`int`): number of environments
            - deepcopy (:obj:`bool`): whether to deepcopy data
        """
        self._pool = [None for _ in range(env_num)]
        # TODO(nyz) whether must use deepcopy
        self._deepcopy = deepcopy

    def update(self, data: Union[Dict[int, Any], list]) -> None:
        """
        Overview:
            Update elements in cache pool.
        Arguments:
            - data (:obj:`Dict[int, Any]`): A dict containing update index-value pairs. Key is index in cache pool, \
                and value is the new element.
        """
        if isinstance(data, dict):
            data = [data]
        for index in range(len(data)):
            for i in data[index].keys():
                d = data[index][i]
                if self._deepcopy:
                    copy_d = copy.deepcopy(d)
                else:
                    copy_d = d
                if index == 0:
                    self._pool[i] = [copy_d]
                else:
                    self._pool[i].append(copy_d)

    def __getitem__(self, idx: int) -> Any:
        """
        Overview:
            Get item in cache pool.
        Arguments:
            - idx (:obj:`int`): The index of the item we need to get.
        Return:
            - item (:obj:`Any`): The item we get.
        """
        data = self._pool[idx]
        if data is not None and len(data) == 1:
            data = data[0]
        return data

    def reset(self, idx: int) -> None:
        """
        Overview:
            Reset the cache pool.
        Arguments:
            - idx (:obj:`int`): The index of the position we need to reset.
        """
        self._pool[idx] = None


class TrajBuffer(list):
    """
    Overview:
        TrajBuffer is used to store traj_len pieces of transitions.
    Interfaces:
        __init__, append
    """

    def __init__(self, maxlen: int, *args, deepcopy: bool = False, **kwargs) -> None:
        """
        Overview:
            Initialize trajBuffer.
        Arguments:
            - maxlen (:obj:`int`): The maximum length of trajectory buffer.
            - deepcopy (:obj:`bool`): Whether to deepcopy data when performing operations.
        """
        self._maxlen = maxlen
        self._deepcopy = deepcopy
        super().__init__(*args, **kwargs)

    def append(self, data: Any) -> None:
        """
        Overview:
            Append data to trajBuffer.
        """
        if self._maxlen is not None:
            while len(self) >= self._maxlen:
                del self[0]
        if self._deepcopy:
            data = copy.deepcopy(data)
        super().append(data)


def to_tensor_transitions(data: List[Dict[str, Any]], shallow_copy_next_obs: bool = True) -> List[Dict[str, Any]]:
    """
    Overview:
        Transform the original transition return from env to tensor format.
    Argument:
        - data (:obj:`List[Dict[str, Any]]`): The data that will be transformed to tensor.
        - shallow_copy_next_obs (:obj:`bool`): Whether to shallow copy next_obs. Default: True.
    Return:
        - data (:obj:`List[Dict[str, Any]]`): The transformed tensor-like data.

    .. tip::
        In order to save memory, if there are next_obs in the passed data, we do special \
        treatment on next_obs so that the next_obs of each state in the data fragment is \
        the next state's obs and the next_obs of the last state is its own next_obs. \
        Besides, we set transform_scalar to False to avoid the extra ``.item()`` operation.
    """
    if 'next_obs' not in data[0]:
        return to_tensor(data, transform_scalar=False)
    else:
        # to_tensor will assign the separate memory to next_obs, if shallow_copy_next_obs is True,
        # we can add ignore_keys to avoid this data copy for saving memory of next_obs.
        if shallow_copy_next_obs:
            data = to_tensor(data, ignore_keys=['next_obs'], transform_scalar=False)
            for i in range(len(data) - 1):
                data[i]['next_obs'] = data[i + 1]['obs']
            data[-1]['next_obs'] = to_tensor(data[-1]['next_obs'], transform_scalar=False)
            return data
        else:
            data = to_tensor(data, transform_scalar=False)
            return data
```