Skip to content

ding.utils.fast_copy

ding.utils.fast_copy

Full Source Code

../ding/utils/fast_copy.py

import torch
import numpy as np
from typing import Any, List, Tuple


class _FastCopy:
    """
    Overview:
        The idea of this class comes from this article \
        https://newbedev.com/what-is-a-fast-pythonic-way-to-deepcopy-just-data-from-a-python-dict-or-list.
        We use recursive calls to copy each object that needs to be copied, which is typically much faster \
        than copy.deepcopy (the article above reports roughly 5x).
        Dispatch is on the exact type of each object: ``list``, ``dict``, ``tuple``, ``torch.Tensor`` and \
        ``np.ndarray`` are copied; every other object (including subclasses of those types, e.g. \
        ``collections.OrderedDict``, a namedtuple or ``torch.nn.Parameter``) is returned as-is and is \
        therefore shared between the source and the copy.
    Interfaces:
        ``__init__``, ``_copy_list``, ``_copy_dict``, ``_copy_tuple``, ``_copy_tensor``, ``_copy_ndarray``, ``copy``.
    """

    def __init__(self):
        """
        Overview:
            Initialize the _FastCopy object and build the ``type -> copy function`` dispatch table.
        """
        # Keyed by exact ``type(obj)`` so that each element costs a single dict lookup
        # instead of a chain of ``isinstance`` checks.
        self.dispatch = {
            list: self._copy_list,
            dict: self._copy_dict,
            tuple: self._copy_tuple,
            torch.Tensor: self._copy_tensor,
            np.ndarray: self._copy_ndarray,
        }

    def _copy_list(self, l: List) -> List:
        """
        Overview:
            Copy the list, recursively copying every element whose exact type is registered in ``dispatch``.
        Arguments:
            - l (:obj:`List`): The list to be copied.
        Returns:
            - ret (:obj:`List`): A new list; elements of unregistered types are shared with ``l``.
        """
        ret = l.copy()
        for idx, item in enumerate(ret):
            cp = self.dispatch.get(type(item))
            if cp is not None:
                ret[idx] = cp(item)
        return ret

    def _copy_dict(self, d: dict) -> dict:
        """
        Overview:
            Copy the dict, recursively copying every value whose exact type is registered in ``dispatch``. \
            Keys are not copied.
        Arguments:
            - d (:obj:`dict`): The dict to be copied.
        Returns:
            - ret (:obj:`dict`): A new dict; values of unregistered types are shared with ``d``.
        """
        ret = d.copy()
        # Iterate the source rather than ``ret`` so the dict being written is not the one being iterated.
        for key, value in d.items():
            cp = self.dispatch.get(type(value))
            if cp is not None:
                ret[key] = cp(value)
        return ret

    def _copy_tuple(self, t: Tuple) -> Tuple:
        """
        Overview:
            Copy the tuple. The tuple itself is immutable, but it may hold mutable data (e.g. tensors) that \
            must not be shared with the source, so every element whose exact type is registered in \
            ``dispatch`` is copied recursively and a new tuple is built.
        Arguments:
            - t (:obj:`Tuple`): The tuple to be copied.
        Returns:
            - ret (:obj:`Tuple`): A new tuple; elements of unregistered types are shared with ``t``.
        """
        dispatch = self.dispatch
        items = []
        for item in t:
            cp = dispatch.get(type(item))
            items.append(item if cp is None else cp(item))
        return tuple(items)

    def _copy_tensor(self, t: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Copy the tensor.
        Arguments:
            - t (:obj:`torch.Tensor`): The tensor to be copied.
        Returns:
            - ret (:obj:`torch.Tensor`): The result of ``t.clone()``.
        """
        return t.clone()

    def _copy_ndarray(self, a: np.ndarray) -> np.ndarray:
        """
        Overview:
            Copy the ndarray.
        Arguments:
            - a (:obj:`np.ndarray`): The ndarray to be copied.
        Returns:
            - ret (:obj:`np.ndarray`): The result of ``np.copy(a)``.
        """
        return np.copy(a)

    def copy(self, sth: Any) -> Any:
        """
        Overview:
            Copy the object. Objects whose exact type is not registered in ``dispatch`` are returned unchanged \
            (the same object, not a copy).
        Arguments:
            - sth (:obj:`Any`): The object to be copied.
        Returns:
            - ret (:obj:`Any`): The copied object, or ``sth`` itself if its type is not handled.
        """
        cp = self.dispatch.get(type(sth))
        if cp is None:
            return sth
        return cp(sth)


# Module-level instance shared by callers.
fastcopy = _FastCopy()