
ding.torch_utils.model_helper


get_num_params(model)

Overview

Return the total number of parameters in the model, i.e. the sum of the element counts of every tensor yielded by model.parameters().

Arguments:
- model (:obj:`torch.nn.Module`): The model whose parameters are counted.

Returns:
- n_params (:obj:`int`): The total number of parameters in the model.

Examples:
>>> model = torch.nn.Linear(3, 5)
>>> num = get_num_params(model)
>>> assert num == 20  # 3 * 5 weights + 5 biases
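To make the arithmetic behind the example concrete, the sketch below computes the parameter count of a fully connected layer by hand, without requiring PyTorch. The helper linear_num_params is hypothetical and not part of ding; it assumes the standard layout of torch.nn.Linear, which stores an out_features x in_features weight matrix plus, by default (bias=True), an out_features-long bias vector.

```python
def linear_num_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Hypothetical helper: parameter count of a fully connected layer.

    Assumes the layer holds an (out_features x in_features) weight matrix and,
    when bias=True, an additional bias vector of length out_features.
    """
    n_weight = out_features * in_features
    n_bias = out_features if bias else 0
    return n_weight + n_bias


# torch.nn.Linear(3, 5) uses bias=True by default: 5 * 3 + 5 = 20.
print(linear_num_params(3, 5))              # 20
# Without the bias term only the 15 weights remain.
print(linear_num_params(3, 5, bias=False))  # 15
```

This is why a bias-free layer, such as torch.nn.Linear(3, 5, bias=False), would report 15, while the default layer reports 20.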

Full Source Code

../ding/torch_utils/model_helper.py

import torch


def get_num_params(model: torch.nn.Module) -> int:
    """
    Overview:
        Return the number of parameters in the model.
    Arguments:
        - model (:obj:`torch.nn.Module`): The model object to calculate the parameter number.
    Returns:
        - n_params (:obj:`int`): The calculated number of parameters.
    Examples:
        >>> model = torch.nn.Linear(3, 5)
        >>> num = get_num_params(model)
        >>> assert num == 20  # 3 * 5 weights + 5 biases
    """
    # Sum the element count of every parameter tensor registered in the model.
    return sum(p.numel() for p in model.parameters())
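Because the function simply iterates over model.parameters(), it counts frozen parameters as well as trainable ones, and a parameter shared between submodules is yielded (and therefore counted) only once. The sketch below illustrates both behaviours. It re-declares get_num_params so the snippet runs on its own; get_num_trainable_params is a hypothetical variant for comparison and is not part of ding.

```python
import torch


def get_num_params(model: torch.nn.Module) -> int:
    # Same one-liner as above: sum the element counts of all registered parameters.
    return sum(p.numel() for p in model.parameters())


def get_num_trainable_params(model: torch.nn.Module) -> int:
    # Hypothetical variant: only count parameters that will receive gradients.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Two stacked linear layers: (3*5 + 5) + (5*2 + 2) = 20 + 12 = 32.
mlp = torch.nn.Sequential(torch.nn.Linear(3, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2))
print(get_num_params(mlp))  # 32

# Freezing the first layer keeps its parameters in the total
# but removes them from the trainable count.
for p in mlp[0].parameters():
    p.requires_grad_(False)
print(get_num_params(mlp), get_num_trainable_params(mlp))  # 32 12

# Tied weights: the same module registered twice shares one weight tensor,
# and model.parameters() yields it only once, so it is not double-counted.
enc = torch.nn.Linear(4, 4, bias=False)
tied = torch.nn.Sequential(enc, enc)
print(get_num_params(tied))  # 16
```

If a model freezes part of its backbone, a filter on requires_grad like the one above is one way to report the number of parameters actually being optimized.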