
ding.torch_utils.dataparallel


DataParallel

Bases: torch.nn.DataParallel

Overview

A wrapper class for nn.DataParallel.

Interfaces: __init__, parameters

__init__(module, device_ids=None, output_device=None, dim=0)

Overview

Initialize the DataParallel object.

Arguments:
- module (:obj:`nn.Module`): The module to be parallelized.
- device_ids (:obj:`list`): The list of GPU ids to replicate the module on.
- output_device (:obj:`int`): The GPU id on which outputs are gathered.
- dim (:obj:`int`): The dimension along which input tensors are split across devices.
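The constructor arguments carry the same meaning as in `torch.nn.DataParallel`. The sketch below uses plain `torch.nn.DataParallel` (not the DI-engine wrapper) to show them on whatever hardware is available; the model and batch sizes are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

# A small model to parallelize; any nn.Module works.
model = nn.Linear(8, 2)

if torch.cuda.is_available():
    # Use every visible GPU, gather outputs on the first one, split inputs along dim 0.
    device_ids = list(range(torch.cuda.device_count()))
    # With more than one GPU, the module must already live on device_ids[0].
    model = model.to(f"cuda:{device_ids[0]}")
    wrapped = nn.DataParallel(model, device_ids=device_ids, output_device=device_ids[0], dim=0)
    batch = torch.randn(4, 8, device=f"cuda:{device_ids[0]}")
else:
    # Without CUDA, nn.DataParallel simply calls the wrapped module directly.
    wrapped = nn.DataParallel(model)
    batch = torch.randn(4, 8)

out = wrapped(batch)
print(out.shape)  # torch.Size([4, 2])
```

On a multi-GPU machine the batch of 4 rows is split along `dim=0`, each replica processes its chunk, and the chunks are concatenated back on `output_device`.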

parameters(recurse=True)

Overview

Return the parameters of the module.

Arguments:
- recurse (:obj:`bool`): Whether to also return the parameters of submodules.

Returns:
- params (:obj:`generator`): A generator over the parameters.
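The `recurse` flag follows `torch.nn.Module.parameters`. The sketch below, again using plain `torch.nn.DataParallel` and an arbitrary three-layer model, shows the effect of `recurse` and that the wrapper's inherited `parameters()` already yields the same tensor objects as the wrapped module, because `nn.DataParallel` registers the wrapped module as its child `module`.

```python
import torch.nn as nn

# A nested module: the Linear layers are children of the Sequential.
model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
wrapped = nn.DataParallel(model)

# recurse=True (the default) walks into submodules: 2 weights + 2 biases.
all_params = list(model.parameters(recurse=True))
# recurse=False returns only parameters registered directly on `model`,
# and nn.Sequential registers none of its own.
own_params = list(model.parameters(recurse=False))

# The wrapper holds no parameters of its own; everything comes from `module`.
same = [id(p) for p in wrapped.parameters()] == [id(p) for p in all_params]

print(len(all_params), len(own_params), same)  # 4 0 True
```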

Full Source Code

../ding/torch_utils/dataparallel.py

import torch
import torch.nn as nn


class DataParallel(nn.DataParallel):
    """
    Overview:
        A wrapper class for nn.DataParallel.
    Interfaces:
        ``__init__``, ``parameters``
    """

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        """
        Overview:
            Initialize the DataParallel object.
        Arguments:
            - module (:obj:`nn.Module`): The module to be parallelized.
            - device_ids (:obj:`list`): The list of GPU ids.
            - output_device (:obj:`int`): The output GPU id.
            - dim (:obj:`int`): The dimension to be parallelized.
        """
        super().__init__(module, device_ids=None, output_device=None, dim=0)
        self.module = module

    def parameters(self, recurse: bool = True):
        """
        Overview:
            Return the parameters of the module.
        Arguments:
            - recurse (:obj:`bool`): Whether to return the parameters of the submodules.
        Returns:
            - params (:obj:`generator`): The generator of the parameters.
        """
        return self.module.parameters(recurse=True)
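As listed, `__init__` passes the literals `device_ids=None, output_device=None, dim=0` to `nn.DataParallel` instead of the values it received, so on a CUDA machine the wrapper always replicates across every visible GPU and splits along dim 0. Likewise, `parameters` always passes `recurse=True`, regardless of its own `recurse` argument. If forwarding is the intended behaviour, a minimal sketch could look like the following; the class name `ForwardingDataParallel` is hypothetical and not part of DI-engine.

```python
import torch.nn as nn


class ForwardingDataParallel(nn.DataParallel):
    """Hypothetical variant that forwards its arguments instead of hard-coding them."""

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        # Pass the caller's values through to torch.nn.DataParallel,
        # which also stores `module` as the child `self.module`.
        super().__init__(module, device_ids=device_ids, output_device=output_device, dim=dim)

    def parameters(self, recurse: bool = True):
        # Delegate to the wrapped module and honour the caller's `recurse` flag.
        return self.module.parameters(recurse=recurse)


model = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 1))
wrapped = ForwardingDataParallel(model)
print(len(list(wrapped.parameters())), len(list(wrapped.parameters(recurse=False))))  # 4 0
```

Because `nn.DataParallel.__init__` already assigns `self.module = module`, the extra assignment in the listing is redundant, so the sketch omits it.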