ding.model.template.language_transformer

LanguageTransformer

Bases: Module
Overview
The LanguageTransformer network: it downloads a pre-trained language model and adds a head on top of it. By default, a BERT model is used as the text encoder, whose bidirectional nature makes it well suited to producing an embedding of the whole sentence.
Interfaces:
__init__, forward
__init__(model_name='bert-base-uncased', add_linear=False, embedding_size=128, freeze_encoder=True, hidden_dim=768, norm_embedding=False)
Overview
Init the LanguageTransformer Model according to input arguments.
Arguments:
- model_name (:obj:`str`): The name of the base language model on Hugging Face, such as "bert-base-uncased".
- add_linear (:obj:`bool`): Whether to add a linear layer on top of the language model, defaults to False.
- embedding_size (:obj:`int`): The output size of the added linear layer, such as 128.
- freeze_encoder (:obj:`bool`): Whether to freeze the encoder language model during training, defaults to True.
- hidden_dim (:obj:`int`): The hidden dimension of the encoder model (e.g. BERT). This value should match the model you use; for bert-base-uncased it is 768.
- norm_embedding (:obj:`bool`): Whether to normalize the embedding vectors, defaults to False.
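The effect of `add_linear`, `embedding_size`, and `norm_embedding` can be illustrated with a minimal numpy sketch. This is not the library's code: the pooled output, the random weight matrix, and the use of L2 normalization are all illustrative assumptions.

```python
import numpy as np

# Hypothetical sketch of the added head: project the encoder's pooled
# output (hidden_dim) down to embedding_size, then normalize it.
hidden_dim, embedding_size = 768, 128

rng = np.random.default_rng(0)
pooled = rng.standard_normal((2, hidden_dim))                 # stand-in for BERT's pooled output
W = rng.standard_normal((hidden_dim, embedding_size)) * 0.02  # the added linear head (add_linear=True)

emb = pooled @ W                                              # project to embedding_size
emb = emb / np.linalg.norm(emb, axis=-1, keepdims=True)       # unit-length embeddings (norm_embedding=True)

assert emb.shape == (2, embedding_size)
assert np.allclose(np.linalg.norm(emb, axis=-1), 1.0)
```

With `freeze_encoder=True`, only a head like `W` above would receive gradient updates, which keeps fine-tuning cheap.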
forward(train_samples, candidate_samples=None, mode='compute_actor')
Overview
The forward computation graph of LanguageTransformer: it takes two lists of strings as input and predicts their matching scores. Different modes forward through different network modules and produce different outputs.
Arguments:
- train_samples (:obj:`List[str]`): One list of strings.
- candidate_samples (:obj:`Optional[List[str]]`): The other list of strings, used to calculate matching scores.
- mode (:obj:`str`): The forward mode; all modes are defined at the beginning of this class.
Returns:
- output (:obj:`Dict`): The output dict, including the logit of matching scores and the corresponding torch.distributions.Categorical object.
Examples:
>>> test_pids = [1]
>>> cand_pids = [0, 2, 4]
>>> problems = [ "This is problem 0", "This is the first question", "Second problem is here", "Another problem", "This is the last problem" ]
>>> ctxt_list = [problems[pid] for pid in test_pids]
>>> cands_list = [problems[pid] for pid in cand_pids]
>>> model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
>>> output = model(ctxt_list, cands_list)
>>> assert output['logit'].shape == (1, 3)
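The shape of the matching logits above can be sketched with plain numpy: one context sample scored against three candidates yields a (1, 3) logit matrix, which a Categorical distribution can then be built from. Dot-product scoring between embeddings is an illustrative assumption here, not necessarily the library's exact scoring rule.

```python
import numpy as np

rng = np.random.default_rng(1)
d = 256                                  # matches embedding_size in the example above
ctxt_emb = rng.standard_normal((1, d))   # embedding of the single context sample
cand_emb = rng.standard_normal((3, d))   # embeddings of the three candidate samples

# One matching score per (context, candidate) pair, shape (1, 3).
logit = ctxt_emb @ cand_emb.T

# A softmax over the candidates gives the probabilities a Categorical
# distribution over candidates would use.
probs = np.exp(logit - logit.max()) / np.exp(logit - logit.max()).sum()

assert logit.shape == (1, 3)
assert np.isclose(probs.sum(), 1.0)
```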
Full Source Code
../ding/model/template/language_transformer.py