ding.model.template.language_transformer

LanguageTransformer

Bases: Module

Overview

The LanguageTransformer network. It downloads a pre-trained language model and adds a head on top of it. By default, a BERT model is used as the text encoder; its bidirectional nature makes it well suited to producing an embedding of the whole sentence.

Interfaces: __init__, forward

__init__(model_name='bert-base-uncased', add_linear=False, embedding_size=128, freeze_encoder=True, hidden_dim=768, norm_embedding=False)

Overview

Init the LanguageTransformer Model according to input arguments.

Arguments:

- model_name (:obj:`str`): The base language model name on Hugging Face, such as "bert-base-uncased".
- add_linear (:obj:`bool`): Whether to add a linear layer on top of the language model. Defaults to ``False``.
- embedding_size (:obj:`int`): The embedding size of the added linear layer, such as 128.
- freeze_encoder (:obj:`bool`): Whether to freeze the encoder language model while training. Defaults to ``True``.
- hidden_dim (:obj:`int`): The embedding dimension of the encoder model (e.g. BERT). This value should correspond to the model you use. For bert-base-uncased, it is 768.
- norm_embedding (:obj:`bool`): Whether to normalize the embedding vectors. Defaults to ``False``.
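Note that ``hidden_dim`` must match the hidden size of the chosen encoder, since it sizes the value head's input when no linear layer is added. A small hypothetical helper (not part of the library; the mapping below covers only two common BERT checkpoints) makes that constraint explicit:

```python
# Hypothetical lookup, NOT part of ding: hidden sizes of two common BERT checkpoints.
KNOWN_HIDDEN_DIMS = {
    "bert-base-uncased": 768,    # 12-layer base model
    "bert-large-uncased": 1024,  # 24-layer large model
}


def pick_hidden_dim(model_name: str, default: int = 768) -> int:
    # Fall back to 768 (the bert-base default used by LanguageTransformer) for unknown names.
    return KNOWN_HIDDEN_DIMS.get(model_name, default)
```

One would then pass ``hidden_dim=pick_hidden_dim(model_name)`` when constructing the model, instead of relying on the default matching the encoder by luck.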

forward(train_samples, candidate_samples=None, mode='compute_actor')

Overview

LanguageTransformer forward computation graph. The input is two lists of strings, and the model predicts their matching scores. Different ``mode`` values forward through different network modules to produce different outputs.

Arguments:

- train_samples (:obj:`List[str]`): One list of strings.
- candidate_samples (:obj:`Optional[List[str]]`): The other list of strings, used to calculate matching scores.
- mode (:obj:`str`): The forward mode; all modes are defined at the beginning of this class.

Returns:

- output (:obj:`Dict`): Output dict data, including the logit of matching scores and the corresponding ``torch.distributions.Categorical`` object.

Examples:

>>> test_pids = [1]
>>> cand_pids = [0, 2, 4]
>>> problems = [
...     "This is problem 0", "This is the first question", "Second problem is here", "Another problem",
...     "This is the last problem"
... ]
>>> ctxt_list = [problems[pid] for pid in test_pids]
>>> cands_list = [problems[pid] for pid in cand_pids]
>>> model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
>>> scores = model(ctxt_list, cands_list)['logit']
>>> assert scores.shape == (1, 3)
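Under the hood, the matching score is a dot product between the [CLS] embedding of each context and each candidate embedding, and a softmax turns each row of scores into the probabilities of a categorical distribution. A minimal, dependency-free sketch of that scoring step (the 3-dimensional embedding values here are made up for illustration, not real BERT outputs):

```python
import math


def matching_scores(ctxt_emb, cand_embs):
    # One row of logits: dot product of the context embedding with each candidate
    # embedding, mirroring torch.mm(prompt_embedding, cands_embedding.t()).
    return [sum(c * k for c, k in zip(ctxt_emb, cand)) for cand in cand_embs]


def softmax(logits):
    # Convert logits to the probabilities a Categorical distribution would use.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Toy embeddings (hypothetical values): one context, three candidates.
ctxt = [1.0, 0.0, 0.5]
cands = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.5]]
logits = matching_scores(ctxt, cands)   # [1.0, 0.0, 0.75]
probs = softmax(logits)                 # sums to 1; highest for candidate 0
```

In the real model these embeddings come from ``_calc_embedding``, and the logits row has shape ``(len(train_samples), len(candidate_samples))``.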

Full Source Code

../ding/model/template/language_transformer.py

from typing import List, Dict, Optional
import torch
from torch import nn

try:
    from transformers import AutoTokenizer, AutoModelForTokenClassification
except ImportError:
    from ditk import logging
    logging.warning("transformers not found, please install it using: pip install transformers")
from ding.utils import MODEL_REGISTRY


@MODEL_REGISTRY.register('language_transformer')
class LanguageTransformer(nn.Module):
    """
    Overview:
        The LanguageTransformer network. Download a pre-trained language model and add a head on top of it.
        In the default case, we use a BERT model as the text encoder, whose bidirectional nature is good
        for obtaining the embedding of the whole sentence.
    Interfaces:
        ``__init__``, ``forward``
    """
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
            self,
            model_name: str = "bert-base-uncased",
            add_linear: bool = False,
            embedding_size: int = 128,
            freeze_encoder: bool = True,
            hidden_dim: int = 768,
            norm_embedding: bool = False
    ) -> None:
        """
        Overview:
            Init the LanguageTransformer Model according to input arguments.
        Arguments:
            - model_name (:obj:`str`): The base language model name on Hugging Face, such as "bert-base-uncased".
            - add_linear (:obj:`bool`): Whether to add a linear layer on top of the language model. Defaults to \
                ``False``.
            - embedding_size (:obj:`int`): The embedding size of the added linear layer, such as 128.
            - freeze_encoder (:obj:`bool`): Whether to freeze the encoder language model while training. \
                Defaults to ``True``.
            - hidden_dim (:obj:`int`): The embedding dimension of the encoder model (e.g. BERT). This value should \
                correspond to the model you use. For bert-base-uncased, it is 768.
            - norm_embedding (:obj:`bool`): Whether to normalize the embedding vectors. Defaults to ``False``.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForTokenClassification.from_pretrained(model_name)
        in_channel = hidden_dim if not add_linear else embedding_size
        self.value_head = nn.Linear(in_channel, 1)
        self.norm = nn.Identity() if not norm_embedding else nn.LayerNorm(
            normalized_shape=in_channel, elementwise_affine=False
        )

        # Freeze the transformer encoder and only train the linear layer
        if freeze_encoder:
            for param in self.model.parameters():
                param.requires_grad = False

        if add_linear:
            # Add a small, adjustable linear layer on top of the language model, tuned through RL
            self.embedding_size = embedding_size
            self.linear = nn.Linear(self.model.config.hidden_size, embedding_size)
        else:
            self.linear = None

    def _calc_embedding(self, x: list) -> torch.Tensor:
        # ``truncation=True`` means that if the length of the prompt exceeds the ``max_length`` of the tokenizer,
        # the exceeded part will be truncated. ``padding=True`` means that if the length of the prompt does not reach
        # the ``max_length``, the latter part will be padded. These settings ensure the length of encoded tokens is
        # exactly ``max_length``, which enables batch-wise computation.
        input = self.tokenizer(x, truncation=True, padding=True, return_tensors="pt").to(self.model.device)
        output = self.model(**input, output_hidden_states=True)
        # Get last layer hidden states
        last_hidden_states = output.hidden_states[-1]
        # Get [CLS] hidden states
        sentence_embedding = last_hidden_states[:, 0, :]  # len(input_list) x hidden_size
        sentence_embedding = self.norm(sentence_embedding)

        if self.linear:
            sentence_embedding = self.linear(sentence_embedding)  # len(input_list) x embedding_size

        return sentence_embedding

    def forward(
            self,
            train_samples: List[str],
            candidate_samples: Optional[List[str]] = None,
            mode: str = 'compute_actor'
    ) -> Dict:
        """
        Overview:
            LanguageTransformer forward computation graph. The input is two lists of strings, and the model
            predicts their matching scores. Different ``mode`` values forward through different network modules
            to produce different outputs.
        Arguments:
            - train_samples (:obj:`List[str]`): One list of strings.
            - candidate_samples (:obj:`Optional[List[str]]`): The other list of strings, used to calculate \
                matching scores.
            - mode (:obj:`str`): The forward mode; all modes are defined at the beginning of this class.
        Returns:
            - output (:obj:`Dict`): Output dict data, including the logit of matching scores and the \
                corresponding ``torch.distributions.Categorical`` object.

        Examples:
            >>> test_pids = [1]
            >>> cand_pids = [0, 2, 4]
            >>> problems = [ \
                    "This is problem 0", "This is the first question", "Second problem is here", "Another problem", \
                    "This is the last problem" \
                ]
            >>> ctxt_list = [problems[pid] for pid in test_pids]
            >>> cands_list = [problems[pid] for pid in cand_pids]
            >>> model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
            >>> scores = model(ctxt_list, cands_list)['logit']
            >>> assert scores.shape == (1, 3)
        """
        assert mode in self.mode
        prompt_embedding = self._calc_embedding(train_samples)

        res_dict = {}
        if mode in ['compute_actor', 'compute_actor_critic']:
            cands_embedding = self._calc_embedding(candidate_samples)
            scores = torch.mm(prompt_embedding, cands_embedding.t())
            res_dict.update({'dist': torch.distributions.Categorical(logits=scores), 'logit': scores})
        if mode in ['compute_critic', 'compute_actor_critic']:
            value = self.value_head(prompt_embedding)
            res_dict.update({'value': value})
        return res_dict