
ding.worker.replay_buffer.base_buffer

ding.worker.replay_buffer.base_buffer

IBuffer

Bases: ABC

Overview

Abstract interface for replay buffers.

Interfaces: default_config, push, update, sample, clear, count, save_data, load_data, state_dict, load_state_dict
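The interface can be illustrated with a minimal list-backed implementation. So that the sketch runs without DI-engine installed, it subclasses `BufferInterface`, a hypothetical local stand-in that declares the same abstract methods as `IBuffer`; `ListBuffer` is likewise a hypothetical name. In real code you would subclass `IBuffer` directly.

```python
import copy
import pickle
import random
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Union


class BufferInterface(ABC):
    """Hypothetical stand-in that mirrors IBuffer's abstract methods."""

    @abstractmethod
    def push(self, data: Union[List[Any], Any], cur_collector_envstep: int) -> None: ...

    @abstractmethod
    def update(self, info: Dict[str, list]) -> None: ...

    @abstractmethod
    def sample(self, batch_size: int, cur_learner_iter: int) -> list: ...

    @abstractmethod
    def clear(self) -> None: ...

    @abstractmethod
    def count(self) -> int: ...

    @abstractmethod
    def save_data(self, file_name: str): ...

    @abstractmethod
    def load_data(self, file_name: str): ...

    @abstractmethod
    def state_dict(self) -> Dict[str, Any]: ...

    @abstractmethod
    def load_state_dict(self, _state_dict: Dict[str, Any]) -> None: ...


class ListBuffer(BufferInterface):
    """Unbounded, uniformly sampled buffer kept in a Python list."""

    def __init__(self) -> None:
        self._data: List[Any] = []

    def push(self, data, cur_collector_envstep):
        # Accept either a single item or a list of items.
        if isinstance(data, list):
            self._data.extend(data)
        else:
            self._data.append(data)

    def update(self, info):
        pass  # No per-item metadata (such as priority) to update.

    def sample(self, batch_size, cur_learner_iter):
        # Uniform sampling without replacement; raises ValueError if batch_size > count().
        return random.sample(self._data, batch_size)

    def clear(self):
        self._data = []

    def count(self):
        return len(self._data)

    def save_data(self, file_name):
        with open(file_name, 'wb') as f:
            pickle.dump(self._data, f)

    def load_data(self, file_name):
        with open(file_name, 'rb') as f:
            self._data = pickle.load(f)

    def state_dict(self):
        return {'data': copy.deepcopy(self._data)}

    def load_state_dict(self, _state_dict):
        self._data = copy.deepcopy(_state_dict['data'])
```

Because `ABC` enforces the abstract methods, a subclass that leaves any of them unimplemented raises `TypeError` when instantiated.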

default_config() classmethod

Overview

Default config of this buffer class.

Returns:
- default_config (`EasyDict`): A deep copy of the class-level `config`, with `cfg_type` set to the class name followed by `'Dict'`.
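`default_config` deep-copies the class-level `config` attribute and tags the copy with `cfg_type`. The sketch below reproduces that logic with a plain `dict` in place of `EasyDict` (an assumption made so it runs without extra dependencies); `DemoBuffer`, `OtherBuffer`, and the config keys are hypothetical.

```python
import copy


class DemoBuffer:
    # Hypothetical class-level defaults; real buffers define their own `config`.
    config = {'replay_buffer_size': 10000, 'nested': {'alpha': 0.6}}

    @classmethod
    def default_config(cls) -> dict:
        # Deep copy so callers can edit the result without touching the class defaults.
        cfg = copy.deepcopy(cls.config)
        cfg['cfg_type'] = cls.__name__ + 'Dict'
        return cfg


class OtherBuffer(DemoBuffer):
    pass  # Inherits `config`, but `cls.__name__` yields its own cfg_type.


cfg = DemoBuffer.default_config()
cfg['nested']['alpha'] = 0.7  # Does not leak back into DemoBuffer.config.
```

Since `default_config` is a classmethod that uses `cls.__name__`, subclasses get a `cfg_type` matching their own name without overriding it.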

push(data, cur_collector_envstep) abstractmethod

Overview

Push data into the buffer.

Arguments:
- data (`Union[List[Any], Any]`): The data to push into the buffer. Either a single item (of any type) or several items (as a `List[Any]`).
- cur_collector_envstep (`int`): The collector's current env step.
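Because `push` accepts either one item or a list of items, an implementation typically normalises both forms into a single code path. In the sketch below, `StampingBuffer` and the `collect_envstep` key are hypothetical; the interface only states that `cur_collector_envstep` is the collector's current env step, not how a buffer should use it.

```python
from typing import Any, Dict, List, Union


class StampingBuffer:
    """Hypothetical buffer that records the collector env step for each pushed item."""

    def __init__(self) -> None:
        self._data: List[Dict[str, Any]] = []

    def push(self, data: Union[List[Any], Any], cur_collector_envstep: int) -> None:
        # Wrap a single item in a list so both call forms share one loop.
        items = data if isinstance(data, list) else [data]
        for item in items:
            # 'collect_envstep' is a hypothetical key used only in this sketch.
            self._data.append({'data': item, 'collect_envstep': cur_collector_envstep})

    def count(self) -> int:
        return len(self._data)
```

With this design, a single item that is itself a `list` would be split into several items; wrap it as `[item]` before pushing in that case.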

update(info) abstractmethod

Overview

Update per-item information, such as priority.

Arguments:
- info (`Dict[str, list]`): Info dict. Its keys depend on the specific buffer type.
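For a prioritized buffer, a natural `info` layout is a set of parallel lists: the indices of the updated items and their new priorities. The keys `'idx'` and `'priority'` and the `PriorityStore` class below are assumptions made for illustration, not DI-engine's actual key names.

```python
from typing import Dict, List


class PriorityStore:
    """Hypothetical store holding one priority per buffer slot."""

    def __init__(self, size: int, init_priority: float = 1.0) -> None:
        self.priorities: List[float] = [init_priority] * size

    def update(self, info: Dict[str, list]) -> None:
        # Assumed layout: parallel lists of slot indices and new priorities.
        for idx, priority in zip(info['idx'], info['priority']):
            # Clamp to a small positive floor so every item remains sampleable.
            self.priorities[idx] = max(float(priority), 1e-6)
```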

sample(batch_size, cur_learner_iter) abstractmethod

Overview

Sample a batch of `batch_size` data items.

Arguments:
- batch_size (`int`): The number of data items to sample.
- cur_learner_iter (`int`): The learner's current iteration.
Returns:
- sampled_data (`list`): A list of `batch_size` data items.
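The interface does not prescribe a sampling scheme. The sketch below swaps in proportional prioritized sampling with replacement, built on the standard library's `random.choices`. `sample_proportional` is a hypothetical helper, and returning `None` when fewer than `batch_size` items are stored is one possible policy rather than something the interface specifies.

```python
import random
from typing import Any, List, Optional


def sample_proportional(data: List[Any], priorities: List[float], batch_size: int,
                        rng: random.Random) -> Optional[list]:
    """Draw `batch_size` items with probability proportional to priority (with replacement)."""
    if len(data) < batch_size:
        # Hypothetical policy: refuse to sample until enough data has been collected.
        return None
    return rng.choices(data, weights=priorities, k=batch_size)
```

Passing an explicit `random.Random` instance keeps sampling reproducible under a fixed seed.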

clear() abstractmethod

Overview

Clear all the data and reset the related variables.

count() abstractmethod

Overview

Count how many valid data items are in the buffer.

Returns:
- count (`int`): The number of valid data items.
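`count` reports valid data, which can be smaller than the buffer's capacity. In the fixed-capacity ring buffer sketched below, `clear` drops the stored items and also resets the write pointer and the valid-item counter. `RingBuffer` is hypothetical, and its `push` omits `cur_collector_envstep` for brevity.

```python
from typing import Any, List, Optional


class RingBuffer:
    """Hypothetical fixed-capacity buffer; empty slots hold None."""

    def __init__(self, capacity: int) -> None:
        self._capacity = capacity
        self.clear()

    def push(self, item: Any) -> None:
        # Once full, overwrite the oldest slot.
        self._slots[self._tail] = item
        self._tail = (self._tail + 1) % self._capacity
        self._valid = min(self._valid + 1, self._capacity)

    def clear(self) -> None:
        # Drop all data and reset the related variables.
        self._slots: List[Optional[Any]] = [None] * self._capacity
        self._tail = 0
        self._valid = 0

    def count(self) -> int:
        # Number of slots holding real data, never more than the capacity.
        return self._valid
```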

save_data(file_name) abstractmethod

Overview

Save buffer data into a file.

Arguments:
- file_name (`str`): Name of the file to save the buffer data to.

load_data(file_name) abstractmethod

Overview

Load buffer data from a file.

Arguments:
- file_name (`str`): Name of the file to load the buffer data from.
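One way to implement `save_data` and `load_data` is to pickle the stored items. In the hypothetical `FileBackedBuffer` below, only the data is written; bookkeeping such as `push_count` is not. This is one reading of how these methods differ from `state_dict`, and the pickle file format is an assumption, not necessarily DI-engine's.

```python
import os
import pickle
import tempfile
from typing import Any, List


class FileBackedBuffer:
    """Hypothetical buffer whose save_data/load_data persist only the stored items."""

    def __init__(self) -> None:
        self._data: List[Any] = []
        self.push_count = 0  # Bookkeeping that save_data deliberately does not persist.

    def push(self, item: Any) -> None:
        self._data.append(item)
        self.push_count += 1

    def save_data(self, file_name: str) -> None:
        with open(file_name, 'wb') as f:
            pickle.dump(self._data, f)

    def load_data(self, file_name: str) -> None:
        with open(file_name, 'rb') as f:
            self._data = pickle.load(f)


# Round trip through a temporary file.
src = FileBackedBuffer()
src.push({'reward': 1.0})
src.push({'reward': 0.0})
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'buffer.pkl')
    src.save_data(path)
    dst = FileBackedBuffer()
    dst.load_data(path)
```

Unpickling can execute arbitrary code, so only load buffer files from trusted sources.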

state_dict() abstractmethod

Overview

Return a state dict that records the current state of the buffer.

Returns:
- state_dict (`Dict[str, Any]`): A dict containing all important values in the buffer, from which the buffer can be reproduced.

load_state_dict(_state_dict) abstractmethod

Overview

Load a state dict to reproduce the buffer.

Arguments:
- _state_dict (`Dict[str, Any]`): A dict containing all important values in the buffer, as returned by `state_dict`.
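A state dict should capture everything needed to reproduce the buffer, not just its contents. For the hypothetical `ResumableRing` below, that includes the write position, so a restored buffer overwrites the same slot next as the original would.

```python
import copy
from typing import Any, Dict, List, Optional


class ResumableRing:
    """Hypothetical ring buffer whose state dict captures data and write position."""

    def __init__(self, capacity: int) -> None:
        self._slots: List[Optional[Any]] = [None] * capacity
        self._tail = 0

    def push(self, item: Any) -> None:
        self._slots[self._tail] = item
        self._tail = (self._tail + 1) % len(self._slots)

    def state_dict(self) -> Dict[str, Any]:
        # Deep copy so later pushes do not mutate the snapshot.
        return {'slots': copy.deepcopy(self._slots), 'tail': self._tail}

    def load_state_dict(self, _state_dict: Dict[str, Any]) -> None:
        self._slots = copy.deepcopy(_state_dict['slots'])
        self._tail = _state_dict['tail']
```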

create_buffer(cfg, *args, **kwargs)

Overview

Create a buffer according to cfg and other arguments.

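Arguments:
- cfg (`EasyDict`): Buffer config. Required key: `type`, the name under which the buffer class is registered in `BUFFER_REGISTRY`. An optional `import_names` list names modules to import before the lookup.
- *args, **kwargs: Forwarded to the buffer's constructor. When `cfg.type == 'naive'`, any `tb_logger` keyword argument is dropped first.
Returns:
- buffer (`IBuffer`): The newly created buffer instance.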
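`create_buffer` looks the buffer class up by `cfg.type` and builds it with the remaining arguments. The sketch below mimics that flow with a plain-dict registry. `REGISTRY`, `register`, `create_buffer_sketch`, `NaiveSketch`, `PrioritizedSketch`, and the `'prioritized_sketch'` key are hypothetical. The `cfg` is a plain `dict` rather than an `EasyDict`, and importing each name in `import_names` with `importlib.import_module` only approximates `ding.utils.import_module`.

```python
import importlib
from typing import Any, Callable, Dict

# Hypothetical registry standing in for BUFFER_REGISTRY.
REGISTRY: Dict[str, type] = {}


def register(name: str) -> Callable[[type], type]:
    def deco(cls: type) -> type:
        REGISTRY[name] = cls
        return cls
    return deco


@register('naive')
class NaiveSketch:
    def __init__(self, cfg: dict) -> None:  # Accepts no tb_logger.
        self.cfg = cfg


@register('prioritized_sketch')
class PrioritizedSketch:
    def __init__(self, cfg: dict, tb_logger: Any = None) -> None:
        self.cfg = cfg
        self.tb_logger = tb_logger


def create_buffer_sketch(cfg: dict, *args, **kwargs) -> Any:
    # Import listed modules first so any registration decorators in them run.
    for module_name in cfg.get('import_names', []):
        importlib.import_module(module_name)
    if cfg['type'] == 'naive':
        # Mirrors the source: the naive buffer is built without a tb_logger.
        kwargs.pop('tb_logger', None)
    return REGISTRY[cfg['type']](cfg, *args, **kwargs)
```

Without the `tb_logger` pop, passing the same keyword arguments to every buffer type would raise `TypeError` for `NaiveSketch`.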

get_buffer_cls(cfg)

Overview

Get a buffer class according to cfg.

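Arguments:
- cfg (`EasyDict`): Buffer config. Required key: `type`. An optional `import_names` list names modules to import before the lookup.
Returns:
- buffer_cls (`type`): The class registered under `cfg.type` (a class, not an instance).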
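Fetching the class rather than an instance lets a caller read the class defaults before constructing anything. In the sketch below, `SketchBuffer`, `REGISTRY`, and `get_buffer_cls_sketch` are hypothetical, the config is a plain `dict`, and the shallow `{**defaults, **user_cfg}` merge is a simplification chosen for illustration.

```python
import copy
from typing import Dict


class SketchBuffer:
    config = {'replay_buffer_size': 1000}

    @classmethod
    def default_config(cls) -> dict:
        cfg = copy.deepcopy(cls.config)
        cfg['cfg_type'] = cls.__name__ + 'Dict'
        return cfg

    def __init__(self, cfg: dict) -> None:
        self.cfg = cfg


REGISTRY: Dict[str, type] = {'sketch': SketchBuffer}


def get_buffer_cls_sketch(cfg: dict) -> type:
    # Look up the class only; nothing is instantiated here.
    return REGISTRY[cfg['type']]


user_cfg = {'type': 'sketch', 'replay_buffer_size': 50}
buffer_cls = get_buffer_cls_sketch(user_cfg)
merged = {**buffer_cls.default_config(), **user_cfg}  # User values override defaults.
buffer = buffer_cls(merged)
```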

Full Source Code

../ding/worker/replay_buffer/base_buffer.py

```python
from typing import Union, Dict, Any, List
from abc import ABC, abstractmethod
import copy
from easydict import EasyDict

from ding.utils import import_module, BUFFER_REGISTRY


class IBuffer(ABC):
    r"""
    Overview:
        Buffer interface
    Interfaces:
        default_config, push, update, sample, clear, count, state_dict, load_state_dict
    """

    @classmethod
    def default_config(cls) -> EasyDict:
        r"""
        Overview:
            Default config of this buffer class.
        Returns:
            - default_config (:obj:`EasyDict`)
        """
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    @abstractmethod
    def push(self, data: Union[List[Any], Any], cur_collector_envstep: int) -> None:
        r"""
        Overview:
            Push a data into buffer.
        Arguments:
            - data (:obj:`Union[List[Any], Any]`): The data which will be pushed into buffer. Can be one \
                (in `Any` type), or many(int `List[Any]` type).
            - cur_collector_envstep (:obj:`int`): Collector's current env step.
        """
        raise NotImplementedError

    @abstractmethod
    def update(self, info: Dict[str, list]) -> None:
        r"""
        Overview:
            Update data info, e.g. priority.
        Arguments:
            - info (:obj:`Dict[str, list]`): Info dict. Keys depends on the specific buffer type.
        """
        raise NotImplementedError

    @abstractmethod
    def sample(self, batch_size: int, cur_learner_iter: int) -> list:
        r"""
        Overview:
            Sample data with length ``batch_size``.
        Arguments:
            - size (:obj:`int`): The number of the data that will be sampled.
            - cur_learner_iter (:obj:`int`): Learner's current iteration.
        Returns:
            - sampled_data (:obj:`list`): A list of data with length `batch_size`.
        """
        raise NotImplementedError

    @abstractmethod
    def clear(self) -> None:
        """
        Overview:
            Clear all the data and reset the related variables.
        """
        raise NotImplementedError

    @abstractmethod
    def count(self) -> int:
        """
        Overview:
            Count how many valid datas there are in the buffer.
        Returns:
            - count (:obj:`int`): Number of valid data.
        """
        raise NotImplementedError

    @abstractmethod
    def save_data(self, file_name: str):
        """
        Overview:
            Save buffer data into a file.
        Arguments:
            - file_name (:obj:`str`): file name of buffer data
        """
        raise NotImplementedError

    @abstractmethod
    def load_data(self, file_name: str):
        """
        Overview:
            Load buffer data from a file.
        Arguments:
            - file_name (:obj:`str`): file name of buffer data
        """
        raise NotImplementedError

    @abstractmethod
    def state_dict(self) -> Dict[str, Any]:
        """
        Overview:
            Provide a state dict to keep a record of current buffer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): A dict containing all important values in the buffer. \
                With the dict, one can easily reproduce the buffer.
        """
        raise NotImplementedError

    @abstractmethod
    def load_state_dict(self, _state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load state dict to reproduce the buffer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): A dict containing all important values in the buffer.
        """
        raise NotImplementedError


def create_buffer(cfg: EasyDict, *args, **kwargs) -> IBuffer:
    r"""
    Overview:
        Create a buffer according to cfg and other arguments.
    Arguments:
        - cfg (:obj:`EasyDict`): Buffer config.
    ArgumentsKeys:
        - necessary: `type`
    """
    import_module(cfg.get('import_names', []))
    if cfg.type == 'naive':
        kwargs.pop('tb_logger', None)
    return BUFFER_REGISTRY.build(cfg.type, cfg, *args, **kwargs)


def get_buffer_cls(cfg: EasyDict) -> type:
    r"""
    Overview:
        Get a buffer class according to cfg.
    Arguments:
        - cfg (:obj:`EasyDict`): Buffer config.
    ArgumentsKeys:
        - necessary: `type`
    """
    import_module(cfg.get('import_names', []))
    return BUFFER_REGISTRY.get(cfg.type)
```