ding.torch_utils.model_helper
get_num_params(model)
Overview
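Return the total number of scalar parameter elements in the model, i.e. the sum of p.numel() over model.parameters(). Parameters with requires_grad=False are included; buffers are not.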
Arguments:
- model (:obj:`torch.nn.Module`): The model object to calculate the parameter number.
Returns:
- n_params (:obj:`int`): The calculated number of parameters.
Examples:
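>>> import torch
>>> from ding.torch_utils.model_helper import get_num_params
>>> model = torch.nn.Linear(3, 5)
>>> num = get_num_params(model)
>>> assert num == 20  # 3 * 5 weights + 5 biases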
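The count for a torch.nn.Linear(in_features, out_features) layer follows directly from its shapes: a weight of shape (out_features, in_features) plus, when bias=True, a bias of shape (out_features,). A minimal pure-Python sketch of that arithmetic is shown below; linear_num_params and mlp_num_params are hypothetical helper names used only for illustration and are not part of DI-engine.

```python
from typing import Sequence


def linear_num_params(in_features: int, out_features: int, bias: bool = True) -> int:
    # Hypothetical helper: the weight has out_features * in_features entries,
    # plus out_features bias entries when bias is enabled.
    n = in_features * out_features
    if bias:
        n += out_features
    return n


def mlp_num_params(sizes: Sequence[int], bias: bool = True) -> int:
    # Hypothetical helper: a stack of Linear layers sizes[0] -> sizes[1] -> ...
    # Activations such as ReLU have no parameters, so only Linear layers count.
    return sum(
        linear_num_params(i, o, bias) for i, o in zip(sizes[:-1], sizes[1:])
    )


print(linear_num_params(3, 5))            # 3 * 5 + 5 = 20
print(linear_num_params(3, 5, bias=False))  # 15
print(mlp_num_params([3, 64, 64, 5]))     # 256 + 4160 + 325 = 4741
```

So a Linear(3, 5) layer has 20 parameters; 15 is the count only for Linear(3, 5, bias=False). Under these assumptions, get_num_params on an equivalent torch.nn.Sequential of Linear and ReLU layers should agree with mlp_num_params.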
Full Source Code
../ding/torch_utils/model_helper.py
import torch


def get_num_params(model: torch.nn.Module) -> int:
    """
    Overview:
        Return the number of parameters in the model.
    Arguments:
        - model (:obj:`torch.nn.Module`): The model object to calculate the parameter number.
    Returns:
        - n_params (:obj:`int`): The calculated number of parameters.
    Examples:
        >>> model = torch.nn.Linear(3, 5)
        >>> num = get_num_params(model)
        >>> assert num == 20  # 3 * 5 weights + 5 biases
    """
    return sum(p.numel() for p in model.parameters())
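Because the function sums p.numel() over model.parameters(), frozen parameters (requires_grad=False) are still counted, buffers registered with register_buffer (such as BatchNorm running statistics) are not, and a parameter tied between two submodules is counted once, since parameters() yields each parameter object a single time. The sketch below checks these properties on a toy model. TiedModel and get_num_trainable_params are hypothetical names for illustration, not DI-engine APIs; get_num_params is redefined with the same one-line body as the listing above so the block runs without DI-engine installed.

```python
import torch
import torch.nn as nn


def get_num_params(model: nn.Module) -> int:
    # Same body as ding.torch_utils.model_helper.get_num_params.
    return sum(p.numel() for p in model.parameters())


def get_num_trainable_params(model: nn.Module) -> int:
    # Hypothetical variant: only count parameters that receive gradients.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


class TiedModel(nn.Module):
    # Hypothetical toy model: the output layer reuses the embedding weight.
    def __init__(self, vocab: int = 10, dim: int = 4) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)         # weight: 10 x 4 = 40
        self.out = nn.Linear(dim, vocab, bias=False)  # weight shape (10, 4)
        self.out.weight = self.embed.weight           # tie: same Parameter object
        self.norm = nn.BatchNorm1d(dim)               # affine weight + bias: 4 + 4 = 8


model = TiedModel()
print(get_num_params(model))            # 40 (tied, counted once) + 8 = 48

model.norm.requires_grad_(False)        # freeze the BatchNorm affine parameters
print(get_num_params(model))            # still 48: frozen parameters are counted
print(get_num_trainable_params(model))  # 40
```

If a training script needs the number of parameters the optimizer will actually update, a filtered count like get_num_trainable_params is the more informative figure once parts of the model are frozen.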