大模型高效微调方法原理总结

181 阅读2分钟

Soft Prompts

soft prompt & hard prompt

hard prompt,简单来说,就是我们平时写的中文或者英文,这是一种离散的token(discrete token)
soft prompt,通过在模型中添加可学习的、连续的向量,以影响模型的输出。

Prompt-tuning

prompt-tuning的原理是在输入层前加一段soft prompt,在训练过程中冻结原始模型所有的参数,只更新优化添加的参数。

prefix tuning

在模型的每一层都可以添加可训练的前缀,一般是在K和V前进行添加。训练时同样冻结原始模型中的所有参数,只更新优化添加的参数。

P-tuning

P-tuning有两个版本,分别是P-tuning v1和P-tuning v2
P-tuning v1是先让初始化的soft prompt通过一个简单的LSTM/MLP神经网络,将soft prompt转换成更适合模型使用的表示,然后再添加到输入层之前
但是v1存在部分缺点,如缺乏模型通用性和任务普遍性。所以后面又提出了v2来改善这种情况
P-tuning v2改动其实很简单,就是从只添加到输入层前修改为每一层前都可以添加soft prompt

低秩适配器

Lora

Lora的核心原理是在原始权重矩阵旁边添加两个低秩矩阵AABB,作为可学习的参数参加前向传播,前向传播的权重可以表示为W+ΔWW+\Delta{W},其中ΔW=AB\Delta{W}=AB。在后向传播更新参数的过程中,原始权重矩阵WW被冻结,只更新ΔW\Delta{W}的权重,从而大大减少参数的训练量

低秩矩阵初始化问题

A矩阵会服从随机高斯分布进行初始化,B矩阵则是初始化为全0矩阵。

lora代码实现

import torch
imort math
import torch.nn as nn
import torch.nn.functional as F

class LoraLinear(nn.Module):
    def __init__(self,in_features,out_features,merge,rank=16,lora_alpha=16,dropout=0.5):
        super(LoraLinear).__init__()
        self.in_features=in_features
        self.out_features=out_features
        self.merge=merge
        self.rank=rank
        self.lora_alpha=lora_alpha
        self.dropout_rate=dropout
        self.linear=nn.Linear(in_features,out_features)
        if rank > 0:
            #全零初始化
            self.lora_a=nn.Parameter(torch.zeros(rank,in_features))
            self.lora_b=nn.Parameter(torch.zeros(out_features,rank))
            self.scale=self.lora_alpha//self.rank
            self.linear.weight.require_grad=False
        if dropout > 0:
            self.dropout_rate=nn.Dropout(self.dropout_rate)
        else:
            self.dropout=nn.Identity()
        self.initial_weights()
    def initial_weights():
        nn.init.kaiming_uniform(self.lora_a,a=math.sqrt(5))
        nn.init.zeros(self.lora_b)
    def forward(self,x):
         if self.rank > 0 and self.merge:
             output=F.linear(x,self.linear.weight+self.lora_b@self.lora_a * self.scale,self.linear.bias)
             output=self.dropout(output)
             return output
         else:
             return self.dropout(self.linear(x))