Transformers源码分析-BERT(Pytorch)

1,098 阅读8分钟
Transofrmers版本4.21
Pytorch版本的Bert源码路径位于:

transformers/src/transformers/models/bert/modeling_bert.py



modeling_bert.py中的所有类

这里只考虑预训练的代码,毕竟预训练才是最重要的,一通百通。即 BertForPretraining这个类

获取类调用的关系图UML图

图中左侧箭头对应NSP任务的处理流程,右侧部分对应BERT的常规处理,以得到每一个token对应的向量(其中包括了[CLS]和[SEP])。

按照上面的关系图-正文开始

BertForPretraining

BertForPretraining 作为预训练的入口类别调用了两个类,一个用于正常的Bert计算,另一个用于取CLS向量计算NSP(Next sentence predict)二分类任务。

参数意义
bert计算每个TOKEN的向量
cls拿BERT的[CLS]向量,去做NSP任务

这里前向传播很简单,不贴代码了

BertModel

BertModel 首先将输入计算得到了Embdings,之所以带s是因为包含了不止一个Embdding,之后输入了Encoder中进行计算。

参数意义
Embdings调用BertEmbeddings对输入进行编码
encoderBert中的Encoder部分
config加载的配置
BertPooler-
Forward:

源码中的BERT封装了其可以作为Decoder使用的方法,即is_decoder=True.本文不涉及将Bert用作Decoder

def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.is_decoder: # 涉及Decoder
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None: # input_ids不能为None
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None: 
            input_shape = input_ids.size()
        elif inputs_embeds is not None: # 使用预训练的embeding
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length 也是用于Decoder的情况下,在自回归的条件下记录历史的K,V以减少计算量。
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        if token_type_ids is None: # 若没有输入token_type_ids 默认输入的是一句话。即句子中只有一个[SEP]
            if hasattr(self.embeddings, "token_type_ids"): # 如果编码层做了处理这里取过来扩展一下维度
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] # 取句子长度的向量[0,0,0,0,0....]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) 
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
       """  
       这部分均为作为Decoder应用的情况下的使用。
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None
        """
        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
       
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
     
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask, # 这里为None Encoder不需要MASK
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        # 单独映射一下CLS向量。
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )


BertEmbeddings

Bert的句子编码层,分为三层来编码分别是,word_embeddings,token_type_id_embedding,position_embeddins.三者累加构成了最终输入encoder的编码。其中:

  • word_embeddings也就是one-hot形式的token编码。
  • token_type_id_embedding 是用来区分两句话的编码,句子比如句子A长度为3,B长度为4。则对应编码为[1,1,1,0,0,0,0]
  • position_embeddins: 不同于Transformer的正余弦位置编码,Bert使用了绝对的位置即[0,1,2,3,...,seq_length] 然后经过一个Embeding层来学习其中的相对位置编码,再进行累加。
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        # max_position_embeddings 最大句子长度
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 
        # 
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) # register_buffer pytorch nn.Moudel中的方法 注册一个变量。这里注册了一个1*[0-seqs_length]的矩阵。即绝对位置编码
        if version.parse(torch.__version__) > version.parse("1.6.0"):
            self.register_buffer(
                "token_type_ids",
                torch.zeros(self.position_ids.size(), dtype=torch.long),
                persistent=False,
            )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]
        seq_length = input_shape[1]
        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
        if token_type_ids is None: # 这里和上一节BertModel中的处理一样
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = inputs_embeds + token_type_embeddings # 累加-1
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings # 累加-2 
        embeddings = self.LayerNorm(embeddings) # Transformer最早的框架中这里没有进行LayerNorm
        embeddings = self.dropout(embeddings)
        return embeddings

BertEncoder

得到完整的Embeding作为输入后进入BertEncoder,BertEncoder主要处理的EncoderLayer的堆叠过程。此外还提供了一种用算力换显存的操作,即:一般训练模式下,pytorch 每次运算后会保留一些中间变量用于求导,而使用 checkpoint 的函数,则不会保留中间变量,中间变量会在求导时再计算一次,因此减少了显存占用。 官方文档给了更加详细的说明:torch.utils.checkpoint.checkpoint

BertEncoder参数
config配置文件
layernn.ModulList(BertLayer)
gradient_checkpointing是否使用torch.utils.checkpoint.checkpoint
Forward

遍历layer进行forward.

def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None
            if self.gradient_checkpointing and self.training:# 这里和torch.util.checkpoint.checkpoint有关,更详细的可以看文档。
            # 主要处理:若选择了获取缓存,则优先使用缓存。
            # 若gradient_checkpointing为True,使用torch.utils.checkpoint.checkpoint来取消保存的梯度。
                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module): 
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions)

                    return custom_forward
                # 
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )

BertLayer

下面开始进入Bert的核心部分,可以堆叠的BertLayer.

image.png

BertLayer参数
chunk_size_feed_forward把slefAttention的输出在经过下个线性层之前进行拆分,依次经过线性层以实现缩小显存占用
attentionSlefAttention
is_decoder作为Encoder输入时无效
crossattention同上
seq_len_dim叫chunk_dim更合适,指定chunk的维度,这里指定为1,也就是seqs_length这一维
outputBertOutput封装输出

这个UML图基本可以看出来这一部分的实现流程,即

  • BertLayer实现了encoder的堆叠。
  • BertAttention类中进行了多头注意力的切分,然后调用SelfAttention实现自注意力,再之后调用Selfoutput对Q,K,V操作后的矩阵映射后进行残差和LN操作。
  • BertAttention最后通过Bertoutput和BertIntermediate来封装了FeedForward层。
Forward

贴一下主要步骤

        # 计算attention  self_attention_outputs输出为元组,attention_outputs,attention_map (Q*K^T) 
        
        # 这里的attention_outputs实际上是已经做过残差和layerNorm之后的结果。具体的操作在BertSelfOutput这个类中进行。
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0] # 
                # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights 这时Outputs仅为attention map
         # 同样的使用Chunk的方法进行
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs # 合并layeroutputs 和 outputs
        return outputs

    def feed_forward_chunk(self, attention_output): # Feedforward层计算
        intermediate_output = self.intermediate(attention_output) # 线性映射并进行激活
         # 映射回原来维度并进行dropout,dorpout之后,进行残差链接,在进行LayerNorm。 post-LayerNorm(原来的Transformer是Pre-LayerNorm(先LayerNorm,再进行残差链接)
        layer_output = self.output(intermediate_output, attention_output) 
        return layer_output

BertAttention

当BertLayer拿到embdding之后首先进入BertAttention层,也就是hidden_states。在这里hidden_states将经历头的拆分,selfattention的计算、头的合并、再经过一个Linear层,最后和初始的hidden_states进行残差链接进行LayerNorm,便得到了Attention部分的操作。啊,这实际上是完整的Attention的操作。

实际上在这里,进行了进一步的封装,SelfAttention的计算使用了BertSelfAttention来计算以得到attention的输出。BertSelfOuputs来进行残差链接、映射、dropout、layerNorm这些操作。

实际上这里整个脉络已经十分清晰了,仍然值得关注的是对于Transformer,Bert在这一部分做了什么改动以及为什么这样改动?

BertAttention
output封装attention的后处理过程,残差链接、映射、dropout、layerNorm这些操作。
pruned_heads切分多头,计算注意力
selfBertSelfAttention
prune_heads注意力头的剪枝操作
Forward

很简单

        self_outputs = self.self(  # 计算注意力
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states) # 进行attention后处理
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them 记录attention map

BertSelfAttention&BertSelfOuputs

两个合在一起分析,因为两个合在一起组成了Attention的计算部分。

BertSelfAttention
  • BertSelfAttention: 计算Attentionscore=QKTAttention score=QK^T
    • 为了使得多头注意力具有想图像一样的多通道特性,一定是先进行映射再进行头的拆分,若先拆分在映射则得到的每个头的矩阵都是一样的。
    • 代码考虑了相对位置编码,至于为什么在这里考虑,从T5模型来看,这里的相对位置编码同样是通过Embedding来实现,但是不同在于这里的relation_embdding相对编码矩阵形状是 [batch_size, seqs_len,seqs_len],对应QKTQK^T的形状,而且若使用相对位置编码则 Attentionscore=relationembdding+QKTAttentionscore = relation_{embdding}+QK^T。 这里先不考虑相对位置编码,在T5模型中再拿来看看具体实现。
BertSelfAttention
attention_head_size根据设定的n_head计算得到每个头的维度
queryQuery映射
key同上
value同上
dropout-
position_embedding_type位置编码类型,论文里是绝对,代码中考虑了相对位置编码
is_decoder-
transpose_for_scores拆分多头
    mixed_query_layer = self.query(hidden_states) # 映射Q
    query_layer = self.transpose_for_scores(mixed_query_layer) # 拆分为多头的Q
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_scores = attention_scores / math.sqrt(self.attention_head_size) 
    attention_probs = nn.functional.softmax(attention_scores, dim=-1) 
    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = self.dropout(attention_probs) # 这里的dropout位置很有意思,直接在token级别进行dorpout, 直观上看这里的dropout可以很有效的避免过拟合。
    

    # Mask heads if we want to
    if head_mask is not None:  # 作为encoder 不考虑attention mask
        attention_probs = attention_probs * head_mask

    context_layer = torch.matmul(attention_probs, value_layer) # 

    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
    context_layer = context_layer.view(new_context_layer_shape)
    outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

    if self.is_decoder:
        outputs = outputs + (past_key_value,)
    return outputs
BertSelfOuputs
  • BertSelfOuputs:封装attention的后处理过程,残差链接、映射、dropout、layerNorm这些操作。
    • 不同于Transformer,这里的所有LayrNorm全是Post-LayerNorm.
BertSelfOuputs参数
LayerNorm带有平滑系数的layerNorm, eps=config.layer_norm_eps
DenseLinear(hidden_size,hidden_size)
Dropout-
    # 清晰明了
    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

BertIntermediate&Bertoutput

经过上面的Attention计算之后我们取得了Attenion的最终结果,然后回到了BertLayer中,合并上FeedForward层就组成了一个可以堆叠的BertLayer。而BertIntermediate&Bertoutput两者共同代表了FeedForwar层,替换了一下激活函数从ReLu变为了Gelu

BertIntermediate
BertIntermediate参数
denseLinear 用于升维
intermediate_act_fn激活函数,使用Gleu
Bertoutput
Bertoutput参数
LayerNorm-
denseLinear,将BertIntermediate升的维度降回原来的。
dropout-
# Gelu源码:
def gelu(input_tensor):
	cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0)))
	return input_tesnsor*cdf


总结:

  • Bert同一代替了transformer中的Pre-LaryNorm,这样降低梯度。
  • Bert使用了绝对位置编码,即0,seqs_length的位置编码,造成了Bert的文本长度收到限制
  • Bert在Feedforward层替换了Relu为Gelu