Skip to content

ding.data.shm_buffer

ding.data.shm_buffer

ShmBuffer

Overview

Shared memory buffer to store numpy array.

__init__(dtype, shape, copy_on_get=True, ctype=None)

Overview

Initialize the buffer.

Arguments: - dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer. - shape (:obj:`Tuple[int]`): The shape of the data to limit the size of the buffer. - copy_on_get (:obj:`bool`): Whether to copy data when calling get method. - ctype (:obj:`Optional[type]`): Origin class type, e.g. np.ndarray, torch.Tensor.

fill(src_arr)

Overview

Fill the shared memory buffer with a numpy array. (Replace the original one.)

Arguments: - src_arr (:obj:np.ndarray): array to fill the buffer.

get()

Overview

Get the array stored in the buffer.

Return: - data (:obj:np.ndarray): A copy of the data stored in the buffer.

ShmBufferContainer

Bases: object

Overview

Support multiple shared memory buffers. Each key-value is name-buffer.

__init__(dtype, shape, copy_on_get=True)

Overview

Initialize the buffer container.

Arguments: - dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer. - shape (:obj:`Union[Dict[Any, tuple], tuple]`): If `Dict[Any, tuple]`, use a dict to manage multiple buffers; If `tuple`, use single buffer. - copy_on_get (:obj:`bool`): Whether to copy data when calling get method.

fill(src_arr)

Overview

Fill one or many of the shared memory buffers.

Arguments: - src_arr (:obj:Union[Dict[Any, np.ndarray], np.ndarray]): array to fill the buffer.

get()

Overview

Get the one or many arrays stored in the buffer.

Return: - data (:obj:`np.ndarray`): The array(s) stored in the buffer.

Full Source Code

../ding/data/shm_buffer.py

from typing import Any, Optional, Union, Tuple, Dict
from multiprocessing import Array
import ctypes
import numpy as np
import torch

# Maps a numpy scalar type to the ctypes type used to allocate the
# underlying multiprocessing.Array of matching element size.
_NTYPE_TO_CTYPE = {
    np.bool_: ctypes.c_bool,
    np.uint8: ctypes.c_uint8,
    np.uint16: ctypes.c_uint16,
    np.uint32: ctypes.c_uint32,
    np.uint64: ctypes.c_uint64,
    np.int8: ctypes.c_int8,
    np.int16: ctypes.c_int16,
    np.int32: ctypes.c_int32,
    np.int64: ctypes.c_int64,
    np.float32: ctypes.c_float,
    np.float64: ctypes.c_double,
}


class ShmBuffer():
    """
    Overview:
        A fixed-size shared-memory region holding a single numpy array,
        usable across processes via ``multiprocessing.Array``.
    """

    def __init__(
            self,
            dtype: Union[type, np.dtype],
            shape: Tuple[int],
            copy_on_get: bool = True,
            ctype: Optional[type] = None
    ) -> None:
        """
        Overview:
            Allocate the shared buffer sized by ``dtype`` and ``shape``.
        Arguments:
            - dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer.
            - shape (:obj:`Tuple[int]`): The shape of the data to limit the size of the buffer.
            - copy_on_get (:obj:`bool`): Whether to copy data when calling get method.
            - ctype (:obj:`Optional[type]`): Origin class type, e.g. np.ndarray, torch.Tensor.
        """
        # Normalize an np.dtype instance (e.g. coming from gym.spaces) to its scalar type,
        # which is what _NTYPE_TO_CTYPE is keyed on.
        if isinstance(dtype, np.dtype):
            dtype = dtype.type
        total_elems = int(np.prod(shape))
        self.buffer = Array(_NTYPE_TO_CTYPE[dtype], total_elems)
        self.dtype = dtype
        self.shape = shape
        self.copy_on_get = copy_on_get
        self.ctype = ctype

    def fill(self, src_arr: np.ndarray) -> None:
        """
        Overview:
            Fill the shared memory buffer with a numpy array. (Replace the original one.)
        Arguments:
            - src_arr (:obj:`np.ndarray`): array to fill the buffer.
        """
        assert isinstance(src_arr, np.ndarray), type(src_arr)
        # View the raw shared memory as an ndarray of the configured dtype/shape
        # and copy into it in place. Reshaping this destination view is far cheaper
        # than flattening the source (per the original author's benchmarks:
        # 15~20x for float32 (4, 84, 84), 5~7x for uint8).
        dst_view = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape)
        np.copyto(dst_view, src_arr)

    def get(self) -> np.ndarray:
        """
        Overview:
            Get the array stored in the buffer.
        Return:
            - data (:obj:`np.ndarray`): A copy of the data stored in the buffer.
        """
        arr = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape)
        if self.copy_on_get:
            # A real np.ndarray.copy() is required here: torch.from_numpy and
            # torch.as_tensor would still alias the shared memory.
            arr = arr.copy()
        # Wrap into a tensor only when the origin type was torch.Tensor.
        return torch.from_numpy(arr) if self.ctype is torch.Tensor else arr


class ShmBufferContainer(object):
    """
    Overview:
        Support multiple shared memory buffers. Each key-value is name-buffer.
    """

    def __init__(
            self,
            dtype: Union[Dict[Any, type], type, np.dtype],
            shape: Union[Dict[Any, tuple], tuple],
            copy_on_get: bool = True
    ) -> None:
        """
        Overview:
            Initialize the buffer container.
        Arguments:
            - dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer.
            - shape (:obj:`Union[Dict[Any, tuple], tuple]`): If `Dict[Any, tuple]`, use a dict to manage \
                multiple buffers; If `tuple`, use single buffer.
            - copy_on_get (:obj:`bool`): Whether to copy data when calling get method.
        """
        if isinstance(shape, dict):
            # Recurse per key so nested dict shapes are supported as well.
            self._data = {
                key: ShmBufferContainer(dtype[key], sub_shape, copy_on_get)
                for key, sub_shape in shape.items()
            }
        elif isinstance(shape, (tuple, list)):
            self._data = ShmBuffer(dtype, shape, copy_on_get)
        else:
            raise RuntimeError("not support shape: {}".format(shape))
        self._shape = shape

    def fill(self, src_arr: Union[Dict[Any, np.ndarray], np.ndarray]) -> None:
        """
        Overview:
            Fill the one or many shared memory buffer.
        Arguments:
            - src_arr (:obj:`Union[Dict[Any, np.ndarray], np.ndarray]`): array to fill the buffer.
        """
        if isinstance(self._shape, dict):
            for key in self._shape:
                self._data[key].fill(src_arr[key])
        elif isinstance(self._shape, (tuple, list)):
            self._data.fill(src_arr)

    def get(self) -> Union[Dict[Any, np.ndarray], np.ndarray]:
        """
        Overview:
            Get the one or many arrays stored in the buffer.
        Return:
            - data (:obj:`np.ndarray`): The array(s) stored in the buffer.
        """
        if isinstance(self._shape, dict):
            return {key: self._data[key].get() for key in self._shape}
        elif isinstance(self._shape, (tuple, list)):
            return self._data.get()