
ding.utils.pytorch_ddp_dist_helper


DDPContext

Overview

A context manager for linklink distribution

Interfaces: __init__, __enter__, __exit__

__init__()

Overview

Initialize the DDPContext

__enter__()

Overview

Initialize linklink distribution

__exit__(*args, **kwargs)

Overview

Finalize linklink distribution

get_rank()

Overview

Get the rank of current process in total world_size

get_world_size()

Overview

Get the world_size (total process number in data parallel training)

allreduce(x)

Overview

All reduce the tensor x in the world

Arguments:
- x (:obj:`torch.Tensor`): the tensor to be reduced
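The operation amounts to a sum across all ranks followed by a division by the world size. A minimal sketch of that averaging semantics, simulated in plain Python for a hypothetical 4-process world (no real process group is created):

```python
# Illustrative sketch (not the library code): each entry stands for one
# rank's local value of x in a hypothetical 4-process world.
world = [2.0, 4.0, 6.0, 8.0]

# dist.all_reduce(x) sums the tensor in place across all ranks ...
reduced = sum(world)
# ... and x.div_(get_world_size()) turns the sum into a mean.
averaged = reduced / len(world)

print(averaged)  # 5.0
```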

allreduce_with_indicator(grad, indicator)

Overview

Custom allreduce: Sum both the gradient and indicator tensors across all processes. Then, if at least one process contributed (i.e., the summation of indicator > 0), divide the gradient by the summed indicator. This ensures that if only a subset of GPUs contributed a gradient, the averaging is performed based on the actual number of contributors rather than the total number of GPUs.

Arguments:
- grad (torch.Tensor): Local gradient tensor to be reduced.
- indicator (torch.Tensor): A tensor flag (1 if the gradient is computed, 0 otherwise).
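The contributor-based averaging can be sketched in plain Python, simulating a hypothetical 4-process world in which only two ranks computed a gradient (no real process group is involved):

```python
# Illustrative sketch (not the library code): local gradients and indicator
# flags for a hypothetical 4-process world; ranks 1 and 3 contributed.
grads = [0.0, 3.0, 0.0, 9.0]   # zeros where no gradient was computed
indicators = [0, 1, 0, 1]      # 1 if the rank contributed a gradient

summed_grad = sum(grads)       # dist.all_reduce(grad)      -> 12.0
summed_ind = sum(indicators)   # dist.all_reduce(indicator) -> 2

# Divide by the number of contributors, not by the world size,
# guarding against the all-zero case exactly as the function does.
if summed_ind > 0:
    summed_grad /= summed_ind

print(summed_grad)  # 6.0, the mean over the 2 contributing ranks
```

Dividing by 4 (the world size) would instead yield 3.0, under-weighting the contributed gradients.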

allreduce_async(name, x)

Overview

All reduce the tensor x in the world asynchronously

Arguments:
- name (:obj:`str`): the name of the tensor
- x (:obj:`torch.Tensor`): the tensor to be reduced

reduce_data(x, dst)

Overview

Reduce the tensor x to the destination process dst

Arguments:
- x (:obj:`Union[int, float, torch.Tensor]`): the tensor to be reduced
- dst (:obj:`int`): the destination process

allreduce_data(x, op)

Overview

All reduce the tensor x in the world

Arguments:
- x (:obj:`Union[int, float, torch.Tensor]`): the tensor to be reduced
- op (:obj:`str`): the operation to perform on data, support ``['sum', 'avg']``
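A minimal sketch of the ``'sum'`` vs ``'avg'`` semantics for a scalar, simulated for a hypothetical 4-process world without a real process group:

```python
# Illustrative sketch (not the library code): one scalar per rank in a
# hypothetical 4-process world.
world_values = [1.0, 2.0, 3.0, 4.0]

# op == 'sum': the all-reduced value is just the sum across ranks.
summed = sum(world_values)

# op == 'avg': the sum is additionally divided by the world size.
averaged = sum(world_values) / len(world_values)

print(summed, averaged)  # 10.0 2.5
```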

get_group(group_size)

Overview

Get the group segmentation, with ``group_size`` processes in each group

Arguments:
- group_size (:obj:`int`): the ``group_size``

dist_mode(func)

Overview

Wrap the function so that it can init and finalize automatically before each call

Arguments:
- func (:obj:`Callable`): the function to be wrapped

dist_init(backend='nccl', addr=None, port=None, rank=None, world_size=None, timeout=datetime.timedelta(seconds=60000))

Overview

Initialize the distributed training setting.

Arguments:
- backend (:obj:`str`): The backend of the distributed training, supports ``['nccl', 'gloo']``.
- addr (:obj:`str`): The address of the master node.
- port (:obj:`str`): The port of the master node.
- rank (:obj:`int`): The rank of the current process.
- world_size (:obj:`int`): The total number of processes.
- timeout (:obj:`datetime.timedelta`): The timeout for operations executed against the process group. Default is 60000 seconds.
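When ``rank`` or ``world_size`` is omitted, the function falls back to SLURM environment variables (``SLURM_LOCALID``, ``SLURM_NTASKS``) and then to torchrun-style ones (``RANK``, ``WORLD_SIZE``). A sketch of that resolution order, using a hypothetical helper name and a plain dict in place of ``os.environ``:

```python
# Illustrative sketch (not the library code): how dist_init resolves rank and
# world_size from the environment. `resolve_rank_and_world_size` is a
# hypothetical helper name introduced here for illustration.
def resolve_rank_and_world_size(env):
    # Prefer SLURM variables, then fall back to torchrun-style ones.
    rank = env.get('SLURM_LOCALID', env.get('RANK'))
    if rank is None:
        raise RuntimeError("please indicate rank explicitly in dist_init method")
    world_size = env.get('SLURM_NTASKS', env.get('WORLD_SIZE'))
    if world_size is None:
        raise RuntimeError("please indicate world_size explicitly in dist_init method")
    return int(rank), int(world_size)

print(resolve_rank_and_world_size({'RANK': '1', 'WORLD_SIZE': '4'}))  # (1, 4)
```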

dist_finalize()

Overview

Finalize distributed training resources

simple_group_split(world_size, rank, num_groups)

Overview

Split the group according to ``world_size``, ``rank`` and ``num_groups``

Arguments:
- world_size (:obj:`int`): The world size
- rank (:obj:`int`): The rank
- num_groups (:obj:`int`): The number of groups

.. note:: With faulty input (``world_size`` not evenly divisible by ``num_groups``), raises ``array split does not result in an equal division``
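The rank partitioning can be reproduced without creating real process groups; a sketch assuming ``world_size=8`` and ``num_groups=2``:

```python
import numpy as np

# Illustrative sketch (not the library code): the same np.split partitioning
# that simple_group_split performs, without calling dist.new_group.
world_size, num_groups = 8, 2
rank_list = [list(map(int, x)) for x in np.split(np.arange(world_size), num_groups)]
print(rank_list)  # [[0, 1, 2, 3], [4, 5, 6, 7]]

# Each rank then returns its own group via groups[rank // group_size].
group_size = world_size // num_groups
assert 5 // group_size == 1  # rank 5 belongs to the second group
```

If ``world_size`` is not evenly divisible by ``num_groups``, ``np.split`` raises the ``array split does not result in an equal division`` error mentioned in the note above.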

to_ddp_config(cfg)

Overview

Convert the config to ddp config

Arguments:
- cfg (:obj:`EasyDict`): The config to be converted
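A sketch of the per-process scaling the conversion applies to each field (``batch_size``, ``n_sample``, ``n_episode``), assuming a hypothetical world size of 4 and a global batch size of 130:

```python
import numpy as np

# Illustrative sketch (not the library code): the ceil-division that
# to_ddp_config applies, with hypothetical example values.
world_size = 4
global_batch_size = 130

# Each process gets the ceiling of global / world_size, so the combined
# total is never smaller than the original global value.
per_process = int(np.ceil(global_batch_size / world_size))
print(per_process)  # 33
```

Note the rounding up: 4 processes with a per-process batch of 33 yield a global batch of 132, slightly larger than the original 130.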

Full Source Code

../ding/utils/pytorch_ddp_dist_helper.py

from typing import Callable, Tuple, List, Any, Union
from easydict import EasyDict

import os
import numpy as np
import torch
import torch.distributed as dist
import datetime

from .default_helper import error_wrapper

# from .slurm_helper import get_master_addr


def get_rank() -> int:
    """
    Overview:
        Get the rank of current process in total world_size
    """
    # return int(os.environ.get('SLURM_PROCID', 0))
    return error_wrapper(dist.get_rank, 0)()


def get_world_size() -> int:
    """
    Overview:
        Get the world_size (total process number in data parallel training)
    """
    # return int(os.environ.get('SLURM_NTASKS', 1))
    return error_wrapper(dist.get_world_size, 1)()


broadcast = dist.broadcast
allgather = dist.all_gather
broadcast_object_list = dist.broadcast_object_list


def allreduce(x: torch.Tensor) -> None:
    """
    Overview:
        All reduce the tensor ``x`` in the world
    Arguments:
        - x (:obj:`torch.Tensor`): the tensor to be reduced
    """
    dist.all_reduce(x)
    x.div_(get_world_size())


def allreduce_with_indicator(grad: torch.Tensor, indicator: torch.Tensor) -> None:
    """
    Overview:
        Custom allreduce: Sum both the gradient and indicator tensors across all processes.
        Then, if at least one process contributed (i.e., the summation of indicator > 0),
        divide the gradient by the summed indicator. This ensures that if only a subset of
        GPUs contributed a gradient, the averaging is performed based on the actual number
        of contributors rather than the total number of GPUs.
    Arguments:
        - grad (torch.Tensor): Local gradient tensor to be reduced.
        - indicator (torch.Tensor): A tensor flag (1 if the gradient is computed, 0 otherwise).
    """
    # Allreduce (sum) the gradient and indicator
    dist.all_reduce(grad)
    dist.all_reduce(indicator)

    # Avoid division by zero. If indicator is close to 0 (extreme case), grad remains zeros.
    if not torch.isclose(indicator, torch.tensor(0.0)):
        grad.div_(indicator.item())


def allreduce_async(name: str, x: torch.Tensor) -> None:
    """
    Overview:
        All reduce the tensor ``x`` in the world asynchronously
    Arguments:
        - name (:obj:`str`): the name of the tensor
        - x (:obj:`torch.Tensor`): the tensor to be reduced
    """
    x.div_(get_world_size())
    dist.all_reduce(x, async_op=True)


def reduce_data(x: Union[int, float, torch.Tensor], dst: int) -> Union[int, float, torch.Tensor]:
    """
    Overview:
        Reduce the tensor ``x`` to the destination process ``dst``
    Arguments:
        - x (:obj:`Union[int, float, torch.Tensor]`): the tensor to be reduced
        - dst (:obj:`int`): the destination process
    """
    if np.isscalar(x):
        x_tensor = torch.as_tensor([x]).cuda()
        dist.reduce(x_tensor, dst)
        return x_tensor.item()
    elif isinstance(x, torch.Tensor):
        dist.reduce(x, dst)
        return x
    else:
        raise TypeError("not supported type: {}".format(type(x)))


def allreduce_data(x: Union[int, float, torch.Tensor], op: str) -> Union[int, float, torch.Tensor]:
    """
    Overview:
        All reduce the tensor ``x`` in the world
    Arguments:
        - x (:obj:`Union[int, float, torch.Tensor]`): the tensor to be reduced
        - op (:obj:`str`): the operation to perform on data, support ``['sum', 'avg']``
    """
    assert op in ['sum', 'avg'], op
    if np.isscalar(x):
        x_tensor = torch.as_tensor([x]).cuda()
        dist.all_reduce(x_tensor)
        if op == 'avg':
            x_tensor.div_(get_world_size())
        return x_tensor.item()
    elif isinstance(x, torch.Tensor):
        dist.all_reduce(x)
        if op == 'avg':
            x.div_(get_world_size())
        return x
    else:
        raise TypeError("not supported type: {}".format(type(x)))


synchronize = torch.cuda.synchronize


def get_group(group_size: int) -> List:
    """
    Overview:
        Get the group segmentation of ``group_size`` each group
    Arguments:
        - group_size (:obj:`int`): the ``group_size``
    """
    rank = get_rank()
    world_size = get_world_size()
    if group_size is None:
        group_size = world_size
    assert (world_size % group_size == 0)
    return simple_group_split(world_size, rank, world_size // group_size)


def dist_mode(func: Callable) -> Callable:
    """
    Overview:
        Wrap the function so that it can init and finalize automatically before each call
    Arguments:
        - func (:obj:`Callable`): the function to be wrapped
    """

    def wrapper(*args, **kwargs):
        dist_init()
        func(*args, **kwargs)
        dist_finalize()

    return wrapper


def dist_init(
        backend: str = 'nccl',
        addr: str = None,
        port: str = None,
        rank: int = None,
        world_size: int = None,
        timeout: datetime.timedelta = datetime.timedelta(seconds=60000)
) -> Tuple[int, int]:
    """
    Overview:
        Initialize the distributed training setting.
    Arguments:
        - backend (:obj:`str`): The backend of the distributed training, supports ``['nccl', 'gloo']``.
        - addr (:obj:`str`): The address of the master node.
        - port (:obj:`str`): The port of the master node.
        - rank (:obj:`int`): The rank of the current process.
        - world_size (:obj:`int`): The total number of processes.
        - timeout (:obj:`datetime.timedelta`): The timeout for operations executed against the process group. \
            Default is 60000 seconds.
    """
    assert backend in ['nccl', 'gloo'], backend
    os.environ['MASTER_ADDR'] = addr or os.environ.get('MASTER_ADDR', "localhost")
    os.environ['MASTER_PORT'] = port or os.environ.get('MASTER_PORT', "10314")  # hard-code

    if rank is None:
        local_id = os.environ.get('SLURM_LOCALID', os.environ.get('RANK', None))
        if local_id is None:
            raise RuntimeError("please indicate rank explicitly in dist_init method")
        else:
            rank = int(local_id)
    if world_size is None:
        ntasks = os.environ.get('SLURM_NTASKS', os.environ.get('WORLD_SIZE', None))
        if ntasks is None:
            raise RuntimeError("please indicate world_size explicitly in dist_init method")
        else:
            world_size = int(ntasks)

    dist.init_process_group(backend=backend, rank=rank, world_size=world_size, timeout=timeout)

    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    world_size = get_world_size()
    rank = get_rank()
    return rank, world_size


def dist_finalize() -> None:
    """
    Overview:
        Finalize distributed training resources
    """
    # This operation usually hangs, so we ignore it temporarily.
    # dist.destroy_process_group()
    pass


class DDPContext:
    """
    Overview:
        A context manager for ``linklink`` distribution
    Interfaces:
        ``__init__``, ``__enter__``, ``__exit__``
    """

    def __init__(self) -> None:
        """
        Overview:
            Initialize the ``DDPContext``
        """
        pass

    def __enter__(self) -> None:
        """
        Overview:
            Initialize ``linklink`` distribution
        """
        dist_init()

    def __exit__(self, *args, **kwargs) -> Any:
        """
        Overview:
            Finalize ``linklink`` distribution
        """
        dist_finalize()


def simple_group_split(world_size: int, rank: int, num_groups: int) -> List:
    """
    Overview:
        Split the group according to ``world_size``, ``rank`` and ``num_groups``
    Arguments:
        - world_size (:obj:`int`): The world size
        - rank (:obj:`int`): The rank
        - num_groups (:obj:`int`): The number of groups

    .. note::
        With faulty input, raises ``array split does not result in an equal division``
    """
    groups = []
    rank_list = np.split(np.arange(world_size), num_groups)
    rank_list = [list(map(int, x)) for x in rank_list]
    for i in range(num_groups):
        groups.append(dist.new_group(rank_list[i]))
    group_size = world_size // num_groups
    return groups[rank // group_size]


def to_ddp_config(cfg: EasyDict) -> EasyDict:
    """
    Overview:
        Convert the config to ddp config
    Arguments:
        - cfg (:obj:`EasyDict`): The config to be converted
    """
    w = get_world_size()
    if 'batch_size' in cfg.policy:
        cfg.policy.batch_size = int(np.ceil(cfg.policy.batch_size / w))
    if 'batch_size' in cfg.policy.learn:
        cfg.policy.learn.batch_size = int(np.ceil(cfg.policy.learn.batch_size / w))
    if 'n_sample' in cfg.policy.collect:
        cfg.policy.collect.n_sample = int(np.ceil(cfg.policy.collect.n_sample / w))
    if 'n_episode' in cfg.policy.collect:
        cfg.policy.collect.n_episode = int(np.ceil(cfg.policy.collect.n_episode / w))
    return cfg