ding.torch_utils.dataparallel
DataParallel
Bases: torch.nn.DataParallel
Overview
A thin wrapper around torch.nn.DataParallel whose parameters() method returns the parameters of the wrapped module directly.
Interfaces:
__init__, parameters
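A minimal usage sketch, assuming only PyTorch is installed (the ding package is not imported here). To stay self-contained, it re-declares a stand-in DataParallel subclass that mirrors the wrapper documented on this page, with the constructor arguments and the recurse flag passed through. This stand-in is an assumption for illustration, not the library's code. On a machine without CUDA, nn.DataParallel simply calls the wrapped module, so the example also runs on CPU.

```python
import torch
import torch.nn as nn


class DataParallel(nn.DataParallel):
    """Stand-in mirroring the documented wrapper (assumption, not ding's code)."""

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        # Pass the caller's arguments through to nn.DataParallel.
        super().__init__(module, device_ids=device_ids, output_device=output_device, dim=dim)

    def parameters(self, recurse: bool = True):
        # Delegate to the wrapped module so callers see its parameters directly.
        return self.module.parameters(recurse=recurse)


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
dp_model = DataParallel(model)

# With CUDA, parameters must live on the first device before the forward pass;
# without CUDA, nn.DataParallel just calls model(x) on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dp_model = dp_model.to(device)

x = torch.randn(16, 4, device=device)
y = dp_model(x)
print(y.shape)  # torch.Size([16, 2])

# An optimizer can be built from the wrapper exactly as from the bare model.
optimizer = torch.optim.SGD(dp_model.parameters(), lr=0.01)
```

Because parameters() delegates to the wrapped module, the wrapper and the bare model expose the very same parameter objects, so code written against either one sees identical parameters.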
__init__(module, device_ids=None, output_device=None, dim=0)
Overview
Initialize the DataParallel object.
Arguments:
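- module (:obj:`nn.Module`): The module to be parallelized.
- device_ids (:obj:`list`): The list of GPU ids to replicate the module on. Defaults to None, which lets nn.DataParallel use all visible GPUs.
- output_device (:obj:`int`): The GPU id on which the outputs are gathered. Defaults to None, which means device_ids[0].
- dim (:obj:`int`): The dimension along which input tensors are split across devices. Defaults to 0 (the batch dimension).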
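To illustrate what dim controls, the following CPU-only sketch approximates the scatter, apply, and gather steps that nn.DataParallel performs across GPUs, using torch.chunk and torch.cat. The name num_replicas is a hypothetical stand-in for the number of devices. The real implementation also replicates the module onto each device and runs the replicas in parallel, which this sketch omits.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 3)
batch = torch.randn(10, 4)

num_replicas = 2  # hypothetical number of GPUs
dim = 0           # same meaning as the ``dim`` argument

# Scatter: split the batch along ``dim`` into roughly equal chunks.
chunks = torch.chunk(batch, num_replicas, dim=dim)

# Apply: each replica runs the same module on its own chunk.
outputs = [layer(chunk) for chunk in chunks]

# Gather: concatenate the per-replica outputs along the same ``dim``.
gathered = torch.cat(outputs, dim=dim)

full = layer(batch)
print([c.shape[0] for c in chunks])  # [5, 5]
print(torch.allclose(gathered, full, atol=1e-6))
```

Because nn.Linear processes each sample independently, splitting along the batch dimension and concatenating the results reproduces the full-batch output. Layers that mix samples across the batch, such as nn.BatchNorm1d in training mode, compute their statistics per chunk and would not have this property.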
parameters(recurse=True)
Overview
Return the parameters of the wrapped module.
Arguments:
- recurse (:obj:`bool`): If True, also yield the parameters of all submodules; otherwise, yield only the parameters owned directly by the wrapped module.
Returns:
- params (:obj:`generator`): A generator over the parameters.
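The recurse flag follows the semantics of torch.nn.Module.parameters. The sketch below uses a hypothetical toy module Net and two hypothetical helper functions to show why the flag must be forwarded: a wrapper that hard-codes recurse=True ignores a caller who asks for direct parameters only.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical toy module with one direct parameter and one child module."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))  # owned directly by Net
        self.fc = nn.Linear(4, 2)                 # child: contributes weight and bias


def params_hardcoded(module: nn.Module, recurse: bool = True):
    # Hypothetical: hard-codes recurse=True, ignoring the caller's flag.
    return module.parameters(recurse=True)


def params_forwarded(module: nn.Module, recurse: bool = True):
    # Hypothetical: honours the caller's ``recurse`` flag.
    return module.parameters(recurse=recurse)


net = Net()
print(len(list(params_hardcoded(net, recurse=False))))  # 3
print(len(list(params_forwarded(net, recurse=False))))  # 1
print(len(list(params_forwarded(net, recurse=True))))   # 3
```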
Full Source Code
../ding/torch_utils/dataparallel.py
```python
import torch
import torch.nn as nn


class DataParallel(nn.DataParallel):
    """
    Overview:
        A wrapper class for nn.DataParallel.
    Interfaces:
        ``__init__``, ``parameters``
    """

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        """
        Overview:
            Initialize the DataParallel object.
        Arguments:
            - module (:obj:`nn.Module`): The module to be parallelized.
            - device_ids (:obj:`list`): The list of GPU ids.
            - output_device (:obj:`int`): The output GPU id.
            - dim (:obj:`int`): The dimension to be parallelized.
        """
        # Pass the caller's arguments through to nn.DataParallel.
        super().__init__(module, device_ids=device_ids, output_device=output_device, dim=dim)
        # nn.DataParallel.__init__ already sets ``self.module``; kept for explicitness.
        self.module = module

    def parameters(self, recurse: bool = True):
        """
        Overview:
            Return the parameters of the module.
        Arguments:
            - recurse (:obj:`bool`): Whether to return the parameters of the submodules.
        Returns:
            - params (:obj:`generator`): The generator of the parameters.
        """
        # Delegate to the wrapped module, honouring the ``recurse`` flag.
        return self.module.parameters(recurse=recurse)
```