论文链接:[2312.00752] Mamba: Linear-Time Sequence Modeling with Selective State Spaces
code: GitHub - state-spaces/mamba: Mamba SSM architecture
关于mamba的环境配置可以参考另一篇博客。
Mamba是具有一次线下复杂度的序列建模方案,被视为是缓解transformer N(O^2)复杂度的潜在解决方案。
关于SSM --> S4模型 (离散化,RNN表达式,卷积表达式)--> S6模型 (selective scan, parallelization)在文章中有详细介绍,同时也推荐去看对应的论文原文。本文将从代码层面详细去介绍Mamba、Vmamba、VisionMamba的相关机制。让我们开始吧!
一、selective_scan_fn() 函数解析
mamba中实现了selective scan,其中核心的算子是通过cuda编程实现的,其源码位于mamba源码中的下列目录,具有cuda编程基础的同学可以自行查阅。
src/selective_scan
cuda编程基础较弱的同学也没关系,只需要记住selective scan实现了SSM的所有过程,输入data:[B,L,D], 算子从第一个token开始,由ssm的公式依次计算中间状态并生成当前时刻的输出,以此类推知道产生L个输出结果,再堆叠起来得到了最终的输出output: [B,L,D]。
cuda代码被编译为一个Python包,其核心函数是
import selective_scan_cuda
在selective_scan_interface.py文件中对selective_scan_cuda这个函数进行了封装,同时mamba也给出了selective_scan_cuda的Python实现版本selective_scan_ref,其和selective_scan_cuda在功能上都是一样的,只是速度上的区别。接下来我们详细看一下这个文件中的代码。
为了能够理解selective scan,我们先从selective_scan_ref开始理解:
def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False): """ 下列r代表real,即动态变化的,即S6模型,c代表constant,全局共享,即S4模型 u: r(B D L) 输入序列 delta: r(B D L) 控制时间步的增量 A: c(D N) or r(D N) B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) D: r(D) z: r(B D L) delta_bias: r(D), fp32 out: r(B D L) last_state (optional): r(B D dstate) or c(B D dstate) """ dtype_in = u.dtype u = u.float() # [B D L] delta = delta.float() # [B D L] if delta_bias is not None: # 给delta的每个通道加上独立的偏置,如果delta太小或者太大,整体的状态更新会失效(爆炸或者消失) delta = delta + delta_bias[..., None].float() if delta_softplus: # softplus(x) = log(1 + exp(x)) 小x的时候 softplus(x) ≈ 0,大x的时候 softplus(x) ≈ x。确保 delta 是正数!! delta = F.softplus(delta) batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] is_variable_B = B.dim() >= 3 is_variable_C = C.dim() >= 3 if A.is_complex():# 判断A是不是复数类型,S4版本是用复数矩阵做状态演化,S6使用实数做状态演化 if is_variable_B: B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) if is_variable_C: C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2)) else: B = B.float() C = C.float() x = A.new_zeros((batch, dim, dstate)) # x代表隐藏状态,创建一个全0张量 ys = [] # 对应离散化SSM公式中的A_ba计算 deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) if not is_variable_B: deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) else: if B.dim() == 3: # 对应离散化SSM公式中的B_ba*x,可以看到这里对B_ba进行了简化,这个简化是基于delta一般很小, # 在delta趋向于0的时候省略那一部分趋向于1(洛必达法则可简单证明) deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) else: B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) if is_variable_C and C.dim() == 4: C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) last_state = None for i in range(u.shape[2]): # u.shape[2]==L # 对应于离散化SSM公式中隐藏状态更新 x = deltaA[:, :, i] * x + deltaB_u[:, :, i] if not is_variable_C: y = torch.einsum('bdn,dn->bd', x, C) # bd else: if C.dim() == 3: # 对应于输出当前时间步的输出 y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) else: y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) if i == u.shape[2] - 1: # 最后一个时间步 last_state = x if y.is_complex(): y = y.real * 2 ys.append(y) # 将每个时间步的输出放在一个列表中,y的shape是b,d y = torch.stack(ys, dim=2) # (batch dim L) out = y if D is None else y + u * rearrange(D, "d -> d 1") # 跳跃连接 if z is not None: out = out * F.silu(z) out = out.to(dtype=dtype_in) # 返回 return out if not return_last_state else (out, last_state)
为了便于理解selective_scan_ref在干什么,这里将离散化的SSM公式贴出来:
上述代码添加了必要的注释,对应着公式仔细阅读可以理解selective_scan_ref以及selective_scan_cuda的工作。
理解了selective_scan_ref以及selective_scan_cuda的工作后,我们会发现seldctiveScanFn和selective_scan_fn都是对selective_scan_cuda的封装,具体代码这里不再赘述。但是我们会发现在selective_scan_interface.py文件中还有一个ManbaInerFn以及mamba_inner_fn函数,这个函数和seldctiveScanFn以及selective_scan_fn又有什么关系呢?
简单说MambaInnerFn负责前面卷积 ➔ 特征分解(得到 delta、B、C) ➔ 归一化 ➔ SelectiveScan 的完整链路。SelectiveScanFn是:你已经有准备好的 u, delta, A, B, C,我就直接帮你算 selective scan,它是一个纯计算核,更加底层、更加简单。
其实我们会发现这两个类都是对selective_scan_fn的封装,SelectiveScanFn需要提供u, delta, A, B, C,而MambaInnerFn只需要提供输入数据x就行,其他的都在内部帮你完成,如果只是想用mamba的扫描可以使用amba_inner_fn,但是如果想对mamba进行自行修改selective_scan_ref将更加灵活。
二、如何使用selective_scan_fn()
了解了核心算子selective_scan_fn的工作原理我们可以来看到底该如何使用它,我们直接来看mamba_simple.py文件中的Mamba类。这个类代码量较大,我们挑重点来讲:
- 如何组织u,delta,A,B,C这些参数从而调用selective_scan_fn
很多人第一次接触mamba的时候很疑惑为什么要设置dt_rank,为什么要使用conv1d或者causal-conv1d。以下是简化过后的mamba代码:
class Mamba(nn.Module): def __init__( self, d_model,d_state=16,d_conv=4,expand=2,dt_rank="auto",dt_min=0.001,dt_max=0.1,dt_init="random",dt_scale=1.0 ): super().__init__() self.d_inner = int(self.expand * self.d_model) self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank # 这个线性层用于对输入的x进行投影得到xz,其实就是两份x,后续会分成x,z self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) # 这个卷积用于对输入的序列进行一维卷积(有的也用causal-conv1d),目的是提取序列数据的局部特征
# 即使用一维卷积提取局部特征,使用mamba建模全部特征
self.conv1d = nn.Conv1d( in_channels=self.d_inner, out_channels=self.d_inner, bias=conv_bias, kernel_size=d_conv, groups=self.d_inner, padding=d_conv - 1, **factory_kwargs, ) # x将投影到 self.dt_rank + self.d_state * 2这个维度去构造dt,B,C self.x_proj = nn.Linear( self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs )
# 将dt从dt_rank 投影到d_inner,注意最终的dt的shape就是[B,L,d_inner]
# 相当于dt的构造是先投影到dt_rank再投影到d_inner,部分资料乘dt_rank为dt的秩
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) # 这一段是对dt以及其偏置进行初始化 dt = torch.exp( torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) ).clamp(min=dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) with torch.no_grad(): self.dt_proj.bias.copy_(inv_dt) # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit self.dt_proj.bias._no_reinit = True
# S4D real initialization这一段是初始化A,采用的方法是HiPPO矩阵 A = repeat( torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), "n -> d n", d=self.d_inner, ).contiguous() A_log = torch.log(A) # Keep A_log in fp32 self.A_log = nn.Parameter(A_log) self.A_log._no_weight_decay = True # D "skip" parameter self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 self.D._no_weight_decay = True # 将最终的结果投影会d_model self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) def forward(self, hidden_states, inference_params=None): """ hidden_states: (B, L, D) Returns: same shape as hidden_states """ batch, seqlen, dim = hidden_states.shape # *****************************
# 这一段代码是执行自回归任务时使用的代码,一个一个token的处理
# step 函数:它主要在需要逐步解码或逐时刻计算时使用,每次处理一个时间步(即每次输入一个 token)。这在生成任务中比较常见,比如自回归模型。
conv_state, ssm_state = None, None if inference_params is not None: conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) if inference_params.seqlen_offset > 0: # The states are updated inplace out, _, _ = self.step(hidden_states, conv_state, ssm_state) return out
# **************************** # We do matmul and transpose BLH -> HBL at the same time xz = rearrange( self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"), "d (b l) -> b d l", l=seqlen, ) if self.in_proj.bias is not None: xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1") A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # 这一段使用快速路径,只需要传入A和xz,其他在函数内部处理 if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states out = mamba_inner_fn( xz, self.conv1d.weight, self.conv1d.bias, self.x_proj.weight, self.dt_proj.weight, self.out_proj.weight, self.out_proj.bias, A, None, # input-dependent B None, # input-dependent C self.D.float(), delta_bias=self.dt_proj.bias.float(), delta_softplus=True, ) else: x, z = xz.chunk(2, dim=1) # Compute short convolution if conv_state is not None: # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W) if causal_conv1d_fn is None: x = self.act(self.conv1d(x)[..., :seqlen]) else:# 下面是手动创建相关参数的模式 assert self.activation in ["silu", "swish"] x = causal_conv1d_fn( x=x, weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), bias=self.conv1d.bias, activation=self.activation, ) # We're careful here about the layout, to avoid extra transposes. # We want dt to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) dt = self.dt_proj.weight @ dt.t() dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() assert self.activation in ["silu", "swish"] y = selective_scan_fn( x, dt, A, B, C, self.D.float(), z=z, delta_bias=self.dt_proj.bias.float(), delta_softplus=True, return_last_state=ssm_state is not None, ) if ssm_state is not None: y, last_state = y ssm_state.copy_(last_state) y = rearrange(y, "b d l -> b l d") out = self.out_proj(y) return out
其实上述过程仔细读来就是通过输入x如何去构建selective_scan_fn所需要的参数,然后调用其对数据进行处理。