Skip to content

ding.utils.linklink_dist_helper

Overview

A context manager for linklink distribution

Interfaces: __init__, __enter__, __exit__

Overview

Initialize the DistContext

Overview

Initialize linklink distribution

Overview

Finalize linklink distribution

Arguments: - args (:obj:Tuple): The arguments passed to the __exit__ function. - kwargs (:obj:Dict): The keyword arguments passed to the __exit__ function.

Overview

Get the rank of linklink model, return 0 if use FakeLink.

.. note:: Reference import_helper.try_import_link and linklink.get_rank.

Overview

Get the world_size of linklink model, return 1 if use FakeLink.

.. note:: Reference import_helper.try_import_link and linklink.get_world_size.

Overview

Use linklink.broadcast and raise error when using FakeLink

Arguments: - value (:obj:obj): the value to broadcast - rank (:obj:int): the rank to broadcast on

Overview

Call linklink.allreduce on the data

Arguments: - data (:obj:obj): the data to reduce - op (:obj:str): the operation to perform on data, support ['sum', 'max']

Overview

Call linklink.allreduce_async on the data

Arguments: - data (:obj:obj): the data to reduce - op (:obj:str): the operation to perform on data, support ['sum', 'max']

Overview

Get the group segmentation of group_size each group

Arguments: - group_size (:obj:int): the group_size

Overview

Wrap the function so that it can init and finalize automatically before each call

Arguments: - func (:obj:Callable): the function to wrap

Overview

Init the distribution

Arguments: - method (:obj:str): Support ['slurm', 'single_node'] - device_id (:obj:int): Default device when using single_node method

Overview

Finalize linklink, see linklink.finalize()

Overview

Split the group according to world_size, rank and num_groups

Arguments: - world_size (:obj:int): The world size - rank (:obj:int): The rank - num_groups (:obj:int): The number of groups

.. note:: With faulty input, raise array split does not result in an equal division

Overview

Synchronize the process

Full Source Code

../ding/utils/linklink_dist_helper.py

from functools import lru_cache, wraps
from typing import Callable, Tuple, List, Any

import numpy as np
import torch

from .default_helper import error_wrapper
from .fake_linklink import FakeLink
from .import_helper import try_import_link


@lru_cache()
def get_link():
    """
    Overview:
        Lazily import the ``linklink`` module (or the ``FakeLink`` fallback) and cache it.
    """
    return try_import_link()


@lru_cache()
def is_fake_link():
    """
    Overview:
        Whether the cached link module is the ``FakeLink`` stand-in.
    """
    return isinstance(get_link(), FakeLink)


def get_rank() -> int:
    """
    Overview:
        Get the rank of ``linklink`` model, return 0 if use ``FakeLink``.

    .. note::
        Reference ``import_helper.try_import_link`` and ``linklink.get_rank``.
    """
    if is_fake_link():
        return 0
    return error_wrapper(get_link().get_rank, 0, "[WARNING]: call linklink error, return default_ret.")()


def get_world_size() -> int:
    """
    Overview:
        Get the ``world_size`` of ``linklink model``, return 1 if use ``FakeLink``.

    .. note::
        Reference ``import_helper.try_import_link`` and ``linklink.get_world_size``.
    """
    if is_fake_link():
        return 1
    return error_wrapper(get_link().get_world_size, 1, "[WARNING]: call linklink error, return default_ret.")()


def broadcast(value: torch.Tensor, rank: int) -> None:
    """
    Overview:
        Use ``linklink.broadcast`` and raise error when using ``FakeLink``
    Arguments:
        - value (:obj:`obj`): the value to broadcast
        - rank (:obj:`int`): the rank to broadcast on
    """
    if is_fake_link():
        raise NotImplementedError
    get_link().broadcast(value, rank)


def allreduce(data: torch.Tensor, op: str = 'sum') -> None:
    """
    Overview:
        Call ``linklink.allreduce`` on the data
    Arguments:
        - data (:obj:`obj`): the data to reduce
        - op (:obj:`str`): the operation to perform on data, support ``['sum', 'max']``
    """
    link_op_map = {'sum': get_link().allreduceOp_t.Sum, 'max': get_link().allreduceOp_t.Max}
    if op not in link_op_map.keys():
        raise KeyError("not support allreduce op type: {}".format(op))
    else:
        link_op = link_op_map[op]
    if is_fake_link():
        return data
    get_link().allreduce(data, reduce_op=link_op)
    if op == 'sum':
        # divide by world_size so 'sum' actually produces the mean across processes
        data.div_(get_world_size())


def allreduce_async(data: torch.Tensor, op: str = 'sum') -> None:
    """
    Overview:
        Call ``linklink.allreduce_async`` on the data
    Arguments:
        - data (:obj:`obj`): the data to reduce
        - op (:obj:`str`): the operation to perform on data, support ``['sum', 'max']``
    """
    link_op_map = {'sum': get_link().allreduceOp_t.Sum, 'max': get_link().allreduceOp_t.Max}
    if op not in link_op_map.keys():
        raise KeyError("not support allreduce op type: {}".format(op))
    else:
        link_op = link_op_map[op]
    if is_fake_link():
        return data
    if op == 'sum':
        # pre-divide before the async reduce; summing the pre-divided values yields the mean
        data.div_(get_world_size())
    get_link().allreduce_async(data, reduce_op=link_op)


def get_group(group_size: int) -> List:
    """
    Overview:
        Get the group segmentation of ``group_size`` each group
    Arguments:
        - group_size (:obj:`int`): the ``group_size``; ``None`` means one group of the whole world
    """
    rank = get_rank()
    world_size = get_world_size()
    if group_size is None:
        group_size = world_size
    assert (world_size % group_size == 0)
    return simple_group_split(world_size, rank, world_size // group_size)


def dist_mode(func: Callable) -> Callable:
    """
    Overview:
        Wrap the function so that it can init and finalize automatically before each call
    Arguments:
        - func (:obj:`Callable`): the function to wrap
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        dist_init()
        # fix: propagate the wrapped function's return value, and make sure
        # dist_finalize() runs even when the wrapped function raises
        try:
            return func(*args, **kwargs)
        finally:
            dist_finalize()

    return wrapper


def dist_init(method: str = 'slurm', device_id: int = 0) -> Tuple[int, int]:
    """
    Overview:
        Init the distribution
    Arguments:
        - method (:obj:`str`): Support ``['slurm', 'single_node']``
        - device_id (:obj:`int`): Default device when using ``single_node`` method
    Returns:
        - (:obj:`Tuple[int, int]`): the ``(rank, world_size)`` pair of this process
    """
    get_link().initialize()
    world_size = get_link().get_world_size()
    rank = get_link().get_rank()

    if method == 'slurm':
        # proc_id = int(os.environ['SLURM_PROCID'])
        # ntasks = int(os.environ['SLURM_NTASKS'])
        # node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(rank % num_gpus)
    elif method == 'single_node':
        torch.cuda.set_device(device_id)

    return rank, world_size


def dist_finalize() -> None:
    """
    Overview:
        Finalize ``linklink``, see ``linklink.finalize()``
    """
    get_link().finalize()


class DistContext:
    """
    Overview:
        A context manager for ``linklink`` distribution
    Interfaces:
        ``__init__``, ``__enter__``, ``__exit__``
    """

    def __init__(self) -> None:
        """
        Overview:
            Initialize the ``DistContext``
        """
        pass

    def __enter__(self) -> None:
        """
        Overview:
            Initialize ``linklink`` distribution
        """
        dist_init()

    def __exit__(self, *args, **kwargs) -> Any:
        """
        Overview:
            Finalize ``linklink`` distribution
        Arguments:
            - args (:obj:`Tuple`): The arguments passed to the ``__exit__`` function.
            - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__exit__`` function.
        """
        dist_finalize()


def simple_group_split(world_size: int, rank: int, num_groups: int) -> List:
    """
    Overview:
        Split the group according to ``world_size``, ``rank`` and ``num_groups``
    Arguments:
        - world_size (:obj:`int`): The world size
        - rank (:obj:`int`): The rank
        - num_groups (:obj:`int`): The number of groups

    .. note::
        With faulty input, raise ``array split does not result in an equal division``
    """
    groups = []
    rank_list = np.split(np.arange(world_size), num_groups)
    rank_list = [list(map(int, x)) for x in rank_list]
    for i in range(num_groups):
        groups.append(get_link().new_group(rank_list[i]))
    group_size = world_size // num_groups
    return groups[rank // group_size]


def synchronize():
    """
    Overview:
        Synchronize the process
    """
    get_link().synchronize()