Mamba: Linear-Time Sequence Modeling with Selective State Spaces

502 阅读7分钟

论文链接:[2312.00752] Mamba: Linear-Time Sequence Modeling with Selective State Spaces

code: GitHub - state-spaces/mamba: Mamba SSM architecture

关于mamba的环境配置可以参考另一篇博客

Mamba是具有一次线下复杂度的序列建模方案,被视为是缓解transformer N(O^2)复杂度的潜在解决方案。

关于SSM --> S4模型 (离散化,RNN表达式,卷积表达式)--> S6模型 (selective scan, parallelization)在文章中有详细介绍,同时也推荐去看对应的论文原文。本文将从代码层面详细去介绍Mamba、Vmamba、VisionMamba的相关机制。让我们开始吧!

一、selective_scan_fn() 函数解析

mamba中实现了selective scan,其中核心的算子是通过cuda编程实现的,其源码位于mamba源码中的下列目录,具有cuda编程基础的同学可以自行查阅。

src/selective_scan 

cuda编程基础较弱的同学也没关系,只需要记住selective scan实现了SSM的所有过程,输入data:[B,L,D], 算子从第一个token开始,由ssm的公式依次计算中间状态并生成当前时刻的输出,以此类推知道产生L个输出结果,再堆叠起来得到了最终的输出output: [B,L,D]。

cuda代码被编译为一个Python包,其核心函数是

import selective_scan_cuda

在selective_scan_interface.py文件中对selective_scan_cuda这个函数进行了封装,同时mamba也给出了selective_scan_cuda的Python实现版本selective_scan_ref,其和selective_scan_cuda在功能上都是一样的,只是速度上的区别。接下来我们详细看一下这个文件中的代码。

为了能够理解selective scan,我们先从selective_scan_ref开始理解:

def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,                      return_last_state=False):    """    下列r代表real,即动态变化的,即S6模型,c代表constant,全局共享,即S4模型    u: r(B D L) 输入序列     delta: r(B D L) 控制时间步的增量    A: c(D N) or r(D N)    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)    D: r(D)    z: r(B D L)    delta_bias: r(D), fp32    out: r(B D L)    last_state (optional): r(B D dstate) or c(B D dstate)    """    dtype_in = u.dtype    u = u.float() # [B D L]    delta = delta.float() # [B D L]    if delta_bias is not None:        # 给delta的每个通道加上独立的偏置,如果delta太小或者太大,整体的状态更新会失效(爆炸或者消失)        delta = delta + delta_bias[..., None].float()    if delta_softplus:        # softplus(x) = log(1 + exp(x)) 小x的时候 softplus(x) ≈ 0,大x的时候 softplus(x) ≈ x。确保 delta 是正数!!        delta = F.softplus(delta)    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]    is_variable_B = B.dim() >= 3    is_variable_C = C.dim() >= 3    if A.is_complex():# 判断A是不是复数类型,S4版本是用复数矩阵做状态演化,S6使用实数做状态演化        if is_variable_B:            B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))        if is_variable_C:            C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))    else:        B = B.float()        C = C.float()    x = A.new_zeros((batch, dim, dstate)) # x代表隐藏状态,创建一个全0张量    ys = []    # 对应离散化SSM公式中的A_ba计算    deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))     if not is_variable_B:        deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)    else:        if B.dim() == 3:            # 对应离散化SSM公式中的B_ba*x,可以看到这里对B_ba进行了简化,这个简化是基于delta一般很小,            # 在delta趋向于0的时候省略那一部分趋向于1(洛必达法则可简单证明)            deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)        else:            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])            deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)    if is_variable_C and C.dim() == 4:        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])    last_state = None    for i in range(u.shape[2]): # u.shape[2]==L        # 对应于离散化SSM公式中隐藏状态更新        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]        if not is_variable_C:            y = torch.einsum('bdn,dn->bd', x, C) # bd        else:            if C.dim() == 3:                # 对应于输出当前时间步的输出                y = torch.einsum('bdn,bn->bd', x, C[:, :, i])            else:                y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])        if i == u.shape[2] - 1: # 最后一个时间步            last_state = x         if y.is_complex():            y = y.real * 2        ys.append(y) # 将每个时间步的输出放在一个列表中,y的shape是b,d    y = torch.stack(ys, dim=2) # (batch dim L)    out = y if D is None else y + u * rearrange(D, "d -> d 1") # 跳跃连接    if z is not None:        out = out * F.silu(z)    out = out.to(dtype=dtype_in)    # 返回    return out if not return_last_state else (out, last_state)

为了便于理解selective_scan_ref在干什么,这里将离散化的SSM公式贴出来:

上述代码添加了必要的注释,对应着公式仔细阅读可以理解selective_scan_ref以及selective_scan_cuda的工作。

理解了selective_scan_ref以及selective_scan_cuda的工作后,我们会发现seldctiveScanFn和selective_scan_fn都是对selective_scan_cuda的封装,具体代码这里不再赘述。但是我们会发现在selective_scan_interface.py文件中还有一个ManbaInerFn以及mamba_inner_fn函数,这个函数和seldctiveScanFn以及selective_scan_fn又有什么关系呢?

简单说MambaInnerFn负责前面卷积特征分解(得到 delta、B、C) ➔ 归一化SelectiveScan 的完整链路。SelectiveScanFn是:你已经有准备好的 u, delta, A, B, C,我就直接帮你算 selective scan,它是一个纯计算核,更加底层、更加简单。

其实我们会发现这两个类都是对selective_scan_fn的封装,SelectiveScanFn需要提供u, delta, A, B, C,而MambaInnerFn只需要提供输入数据x就行,其他的都在内部帮你完成,如果只是想用mamba的扫描可以使用amba_inner_fn,但是如果想对mamba进行自行修改selective_scan_ref将更加灵活。

二、如何使用selective_scan_fn()

了解了核心算子selective_scan_fn的工作原理我们可以来看到底该如何使用它,我们直接来看mamba_simple.py文件中的Mamba类。这个类代码量较大,我们挑重点来讲:

  1. 如何组织u,delta,A,B,C这些参数从而调用selective_scan_fn

很多人第一次接触mamba的时候很疑惑为什么要设置dt_rank,为什么要使用conv1d或者causal-conv1d。以下是简化过后的mamba代码:

class Mamba(nn.Module):    def __init__(        self,        d_model,d_state=16,d_conv=4,expand=2,dt_rank="auto",dt_min=0.001,dt_max=0.1,dt_init="random",dt_scale=1.0    ):        super().__init__()        self.d_inner = int(self.expand * self.d_model)        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank        # 这个线性层用于对输入的x进行投影得到xz,其实就是两份x,后续会分成x,z        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)        # 这个卷积用于对输入的序列进行一维卷积(有的也用causal-conv1d),目的是提取序列数据的局部特征
        # 即使用一维卷积提取局部特征,使用mamba建模全部特征
        self.conv1d = nn.Conv1d(            in_channels=self.d_inner,            out_channels=self.d_inner,            bias=conv_bias,            kernel_size=d_conv,            groups=self.d_inner,            padding=d_conv - 1,            **factory_kwargs,        )        # x将投影到 self.dt_rank + self.d_state * 2这个维度去构造dt,B,C        self.x_proj = nn.Linear(            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs        )

       # 将dt从dt_rank 投影到d_inner,注意最终的dt的shape就是[B,L,d_inner]
       # 相当于dt的构造是先投影到dt_rank再投影到d_inner,部分资料乘dt_rank为dt的秩
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)        # 这一段是对dt以及其偏置进行初始化        dt = torch.exp(            torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))            + math.log(dt_min)        ).clamp(min=dt_init_floor)        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759        inv_dt = dt + torch.log(-torch.expm1(-dt))        with torch.no_grad():            self.dt_proj.bias.copy_(inv_dt)        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit        self.dt_proj.bias._no_reinit = True        

        # S4D real initialization这一段是初始化A,采用的方法是HiPPO矩阵        A = repeat(            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),            "n -> d n",            d=self.d_inner,        ).contiguous()        A_log = torch.log(A)  # Keep A_log in fp32        self.A_log = nn.Parameter(A_log)        self.A_log._no_weight_decay = True        # D "skip" parameter        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32        self.D._no_weight_decay = True        # 将最终的结果投影会d_model        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)    def forward(self, hidden_states, inference_params=None):        """        hidden_states: (B, L, D)        Returns: same shape as hidden_states        """        batch, seqlen, dim = hidden_states.shape        # *****************************
        # 这一段代码是执行自回归任务时使用的代码,一个一个token的处理
        # step 函数:它主要在需要逐步解码或逐时刻计算时使用,每次处理一个时间步(即每次输入一个 token)。这在生成任务中比较常见,比如自回归模型。
        conv_state, ssm_state = None, None        if inference_params is not None:            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)            if inference_params.seqlen_offset > 0:                # The states are updated inplace                out, _, _ = self.step(hidden_states, conv_state, ssm_state)                return out
        # ****************************        # We do matmul and transpose BLH -> HBL at the same time        xz = rearrange(            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),            "d (b l) -> b d l",            l=seqlen,        )        if self.in_proj.bias is not None:            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)        # 这一段使用快速路径,只需要传入A和xz,其他在函数内部处理        if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None:  # Doesn't support outputting the states            out = mamba_inner_fn(                xz,                self.conv1d.weight,                self.conv1d.bias,                self.x_proj.weight,                self.dt_proj.weight,                self.out_proj.weight,                self.out_proj.bias,                A,                None,  # input-dependent B                None,  # input-dependent C                self.D.float(),                delta_bias=self.dt_proj.bias.float(),                delta_softplus=True,            )        else:            x, z = xz.chunk(2, dim=1)            # Compute short convolution            if conv_state is not None:                # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv                # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.                conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # Update state (B D W)            if causal_conv1d_fn is None:                x = self.act(self.conv1d(x)[..., :seqlen])            else:# 下面是手动创建相关参数的模式                assert self.activation in ["silu", "swish"]                x = causal_conv1d_fn(                    x=x,                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),                    bias=self.conv1d.bias,                    activation=self.activation,                )            # We're careful here about the layout, to avoid extra transposes.            # We want dt to have d as the slowest moving dimension            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)            dt = self.dt_proj.weight @ dt.t()            dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)            B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()            C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()            assert self.activation in ["silu", "swish"]            y = selective_scan_fn(                x,                dt,                A,                B,                C,                self.D.float(),                z=z,                delta_bias=self.dt_proj.bias.float(),                delta_softplus=True,                return_last_state=ssm_state is not None,            )            if ssm_state is not None:                y, last_state = y                ssm_state.copy_(last_state)            y = rearrange(y, "b d l -> b l d")            out = self.out_proj(y)        return out

其实上述过程仔细读来就是通过输入x如何去构建selective_scan_fn所需要的参数,然后调用其对数据进行处理。