ding.torch_utils.nn_test_helper
is_differentiable(loss, model, print_instead=False)
Overview
Check whether a model, or every model in a list, is differentiable with respect to the given loss. The function first asserts that no parameter has a gradient yet (every grad is None), then calls loss.backward(), and finally checks that every parameter's grad is a torch.Tensor. A TypeError is raised if model is neither a torch.nn.Module nor a list.
Arguments:
- loss (torch.Tensor): loss tensor of the model
- model (Union[torch.nn.Module, List[torch.nn.Module]]): model or models to be checked
- print_instead (bool): If True, print the name and grad of every parameter whose grad is not a torch.Tensor instead of raising an AssertionError. Default: False.
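The post-backward check matters because a parameter that never takes part in computing the loss keeps grad None after backward. The following minimal sketch, which needs only PyTorch, reproduces the three steps on a toy model with an unused head. TwoHeads and missing_grads are hypothetical names used for illustration and are not part of DI-engine; unlike is_differentiable, missing_grads returns the offending parameter names instead of asserting.

```python
import torch
import torch.nn as nn


class TwoHeads(nn.Module):
    """Hypothetical toy model: one head is never used in forward()."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.used_head = nn.Linear(8, 1)
        self.unused_head = nn.Linear(8, 1)  # never called, so it gets no gradient

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.used_head(torch.relu(self.encoder(x)))


def missing_grads(loss: torch.Tensor, model: nn.Module) -> list:
    """Hypothetical helper mirroring the three steps of is_differentiable,
    but returning the names of parameters without a gradient."""
    # Step 1: every parameter must start without a gradient.
    for name, p in model.named_parameters():
        assert p.grad is None, name
    # Step 2: back-propagate the loss.
    loss.backward()
    # Step 3: collect parameters whose grad is still not a tensor.
    return [name for name, p in model.named_parameters()
            if not isinstance(p.grad, torch.Tensor)]


torch.manual_seed(0)
model = TwoHeads()
loss = model(torch.randn(2, 4)).sum()
print(missing_grads(loss, model))  # ['unused_head.weight', 'unused_head.bias']
```

Note that zero-valued gradients still count as tensors, so this check detects parameters that are disconnected from the loss, not parameters whose gradient happens to be zero.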
Full Source Code
../ding/torch_utils/nn_test_helper.py
from typing import Union, List
import torch


def is_differentiable(
        loss: torch.Tensor, model: Union[torch.nn.Module, List[torch.nn.Module]], print_instead: bool = False
) -> None:
    """
    Overview:
        Judge whether the model/models are differentiable. First check whether module's grad is None,
        then do loss's back propagation, finally check whether module's grad are torch.Tensor.
    Arguments:
        - loss (:obj:`torch.Tensor`): loss tensor of the model
        - model (:obj:`Union[torch.nn.Module, List[torch.nn.Module]]`): model or models to be checked
        - print_instead (:obj:`bool`): Whether to print module's final grad result, \
            instead of asserting. Default set to ``False``.
    """
    assert isinstance(loss, torch.Tensor)
    if isinstance(model, list):
        for m in model:
            assert isinstance(m, torch.nn.Module)
            for k, p in m.named_parameters():
                assert p.grad is None, k
    elif isinstance(model, torch.nn.Module):
        for k, p in model.named_parameters():
            assert p.grad is None, k
    else:
        raise TypeError('model must be list or nn.Module')

    loss.backward()

    if isinstance(model, list):
        for m in model:
            for k, p in m.named_parameters():
                if print_instead:
                    if not isinstance(p.grad, torch.Tensor):
                        print(k, "grad is:", p.grad)
                else:
                    assert isinstance(p.grad, torch.Tensor), k
    elif isinstance(model, torch.nn.Module):
        for k, p in model.named_parameters():
            if print_instead:
                if not isinstance(p.grad, torch.Tensor):
                    print(k, "grad is:", p.grad)
            else:
                assert isinstance(p.grad, torch.Tensor), k