

Lucidrains Series Source Code Analysis (Part 33)

.\lucidrains\equiformer-pytorch\equiformer_pytorch\equiformer_pytorch.py

from math import sqrt
from functools import partial
from itertools import product
from collections import namedtuple

from beartype.typing import Optional, Union, Tuple, Dict
from beartype import beartype

import torch
from torch import nn, is_tensor, Tensor
import torch.nn.functional as F

from taylor_series_linear_attention import TaylorSeriesLinearAttn

from opt_einsum import contract as opt_einsum

from equiformer_pytorch.basis import (
    get_basis,
    get_D_to_from_z_axis
)

from equiformer_pytorch.reversible import (
    SequentialSequence,
    ReversibleSequence
)

from equiformer_pytorch.utils import (
    exists,
    default,
    masked_mean,
    to_order,
    cast_tuple,
    safe_cat,
    fast_split,
    slice_for_centering_y_to_x,
    pad_for_centering_y_to_x
)

from einx import get_at

from einops import rearrange, repeat, reduce, einsum, pack, unpack
from einops.layers.torch import Rearrange

# constants

# named tuple for the type-0 and type-1 outputs returned by the network
Return = namedtuple('Return', ['type0', 'type1'])

# named tuple for holding edge information
EdgeInfo = namedtuple('EdgeInfo', ['neighbor_indices', 'neighbor_mask', 'edges'])

# helpers

# pack a single tensor according to the given pattern
def pack_one(t, pattern):
    return pack([t], pattern)

# unpack a single tensor according to the given pattern
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]
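
A quick round trip shows what these two wrappers do (illustrative only, not part of the repository source; it assumes torch and the helpers above are in scope): pack_one flattens all leading dimensions into one axis, and unpack_one restores them from the saved packed shapes.

# illustrative only
t = torch.randn(2, 3, 4)
packed, ps = pack_one(t, '* c')           # leading dims flattened -> shape (6, 4)
restored = unpack_one(packed, ps, '* c')  # restored -> shape (2, 3, 4)
assert restored.shape == (2, 3, 4)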

# fiber functions

# cartesian product of the degrees of two fibers
@beartype
def fiber_product(
    fiber_in: Tuple[int, ...],
    fiber_out: Tuple[int, ...]
):
    fiber_in, fiber_out = tuple(map(lambda t: [(degree, dim) for degree, dim in enumerate(t)], (fiber_in, fiber_out)))
    return product(fiber_in, fiber_out)

# intersection of the degrees shared by two fibers
@beartype
def fiber_and(
    fiber_in: Tuple[int, ...],
    fiber_out: Tuple[int, ...]
):
    fiber_in = [(degree, dim) for degree, dim in enumerate(fiber_in)]
    fiber_out_degrees = set(range(len(fiber_out)))

    out = []
    for degree, dim in fiber_in:
        if degree not in fiber_out_degrees:
            continue

        dim_out = fiber_out[degree]
        out.append((degree, dim, dim_out))

    return out
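
To make the fiber helpers concrete, here is roughly what they produce for a three-degree fiber like the (4, 4, 2) used in the README example below (an illustrative sketch, assuming the corrected definitions above):

# illustrative only
fiber_in, fiber_out = (4, 4, 2), (4, 4, 2)

len(list(fiber_product(fiber_in, fiber_out)))   # 9 -> all (degree_in, degree_out) pairs
fiber_and(fiber_in, fiber_out)                  # [(0, 4, 4), (1, 4, 4), (2, 2, 2)]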

# helper functions

# split a number as evenly as possible into the given number of groups
def split_num_into_groups(num, groups):
    num_per_group = (num + groups - 1) // groups
    remainder = num % groups

    if remainder == 0:
        return (num_per_group,) * groups

    return (*((num_per_group,) * remainder), *((((num_per_group - 1),) * (groups - remainder))))
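
For example (illustrative only), the group sizes always sum back to the original number, with the larger groups coming first:

# illustrative only
split_num_into_groups(10, 3)   # (4, 3, 3)
split_num_into_groups(8, 4)    # (2, 2, 2, 2)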

# get the device and dtype of the first tensor in a feature dict
def get_tensor_device_and_dtype(features):
    _, first_tensor = next(iter(features.items()))
    return first_tensor.device, first_tensor.dtype

# residual connection across a dict of degree-indexed tensors
def residual_fn(x, residual):
    out = {}

    for degree, tensor in x.items():
        out[degree] = tensor

        if degree not in residual:
            continue

        if not any(t.requires_grad for t in (out[degree], residual[degree])):
            out[degree] += residual[degree]
        else:
            out[degree] = out[degree] + residual[degree]

    return out

# return a copy of a tuple with the value at the given index replaced
def tuple_set_at_index(tup, index, value):
    l = list(tup)
    l[index] = value
    return tuple(l)

# shapes of all tensors in a feature dict
def feature_shapes(feature):
    return tuple(v.shape for v in feature.values())

# fiber (channels per degree) of a feature dict
def feature_fiber(feature):
    return tuple(v.shape[-2] for v in feature.values())

# pairwise euclidean distance between two broadcastable tensors
def cdist(a, b, dim = -1, eps = 1e-5):
    a = a.expand_as(b)
    a, _ = pack_one(a, '* c')
    b, ps = pack_one(b, '* c')

    dist = F.pairwise_distance(a, b, p = 2)
    dist = unpack_one(dist, ps, '*')
    return dist
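
A small sanity check of the broadcasting behaviour (illustrative only, assuming torch is in scope): the first argument is expanded to the shape of the second before distances are taken along the last axis.

# illustrative only
a = torch.zeros(1, 3)
b = torch.ones(2, 3)
cdist(a, b)   # shape (2,), each entry approximately sqrt(3)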

# classes

# residual wrapper module
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        y = self.fn(x, **kwargs)
        if not y.requires_grad and not x.requires_grad:
            return x.add_(y)
        return x + y

# layernorm with a learned scale and a fixed zero bias buffer
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# degree-wise linear layer
class Linear(nn.Module):
    @beartype
    def __init__(
        self,
        fiber_in: Tuple[int, ...],
        fiber_out: Tuple[int, ...]
    ):
        super().__init__()
        # weight list and the degrees they correspond to
        self.weights = nn.ParameterList([])
        self.degrees = []

        # iterate over the degrees shared by the input and output fibers
        for (degree, dim_in, dim_out) in fiber_and(fiber_in, fiber_out):
            # randomly initialized weight, scaled by 1 / sqrt(dim_in)
            self.weights.append(nn.Parameter(torch.randn(dim_in, dim_out) / sqrt(dim_in)))
            # remember the degree
            self.degrees.append(degree)

    def init_zero_(self):
        # zero out all weights
        for weight in self.weights:
            weight.data.zero_()

    def forward(self, x):
        # output dict, keyed by degree
        out = {}

        # per-degree matrix multiply along the channel dimension
        for degree, weight in zip(self.degrees, self.weights):
            out[degree] = einsum(x[degree], weight, '... d m, d e -> ... e m')

        return out
class Norm(nn.Module):
    @beartype
    def __init__(
        self,
        fiber: Tuple[int, ...],
        eps = 1e-12,
    ):
        """
        deviates from the paper slightly, will use rmsnorm throughout (no mean centering or bias, even for type0 features)
        this has been proven at scale for a number of models, including T5 and alphacode
        """

        super().__init__()
        # epsilon for numerical stability
        self.eps = eps
        # transforms starts out as an empty nn.ParameterList
        self.transforms = nn.ParameterList([])

        # iterate over each degree in the fiber
        for degree, dim in enumerate(fiber):
            # one scale parameter per degree, initialized to ones
            self.transforms.append(nn.Parameter(torch.ones(dim, 1)))

    def forward(self, features):
        # output dict
        output = {}

        # iterate over the scales and the (degree, tensor) pairs
        for scale, (degree, t) in zip(self.transforms, features.items()):
            # number of channels for this degree
            dim = t.shape[-2]

            # l2 norm over the m dimension
            l2normed = t.norm(dim = -1, keepdim = True)
            # rms over the channel dimension
            rms = l2normed.norm(dim = -2, keepdim = True) * (dim ** -0.5)

            # normalize, scale, and store in the output dict
            output[degree] = t / rms.clamp(min = self.eps) * scale

        return output
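
A shape-level sketch of the per-degree rmsnorm (illustrative only; the batch size, node count, and fiber below are made up): each degree keeps its (..., channels, 2*degree + 1) layout.

# illustrative only
norm = Norm((4, 4, 2))
feats = {
    0: torch.randn(2, 8, 4, 1),   # (batch, nodes, channels, 2*0 + 1)
    1: torch.randn(2, 8, 4, 3),   # (batch, nodes, channels, 2*1 + 1)
    2: torch.randn(2, 8, 2, 5),   # (batch, nodes, channels, 2*2 + 1)
}
normed = norm(feats)
assert all(normed[k].shape == feats[k].shape for k in feats)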

class Gate(nn.Module):
    @beartype
    def __init__(
        self,
        fiber: Tuple[int, ...]
    ):
        super().__init__()

        # type-0 channels and how many of them are used up for gating the higher degrees
        type0_dim = fiber[0]
        dim_gate = sum(fiber[1:])

        # type-0 must have more channels than the sum of all higher degrees
        assert type0_dim > dim_gate, 'sum of channels from rest of the degrees must be less than the channels in type 0, as they would be used up for gating and subtracted out'

        # store the fiber and the split sizes for the type-0 tensor
        self.fiber = fiber
        self.num_degrees = len(fiber)
        self.type0_dim_split = [*fiber[1:], type0_dim - dim_gate]

    def forward(self, x):
        # output dict
        output = {}

        # type-0 tensor
        type0_tensor = x[0]
        # split off one gate per higher degree, the rest stays as type-0 features
        *gates, type0_tensor = type0_tensor.split(self.type0_dim_split, dim = -2)

        # silu activation for type 0
        output = {0: F.silu(type0_tensor)}

        # sigmoid gate for the higher types
        for degree, gate in zip(range(1, self.num_degrees), gates):
            output[degree] = x[degree] * gate.sigmoid()

        return output
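
The channel bookkeeping is easiest to see with a small made-up fiber (illustrative only): with fiber (8, 2, 2), four of the eight type-0 channels are consumed as gates for the two higher degrees.

# illustrative only
gate = Gate((8, 2, 2))
x = {
    0: torch.randn(2, 8, 8, 1),
    1: torch.randn(2, 8, 2, 3),
    2: torch.randn(2, 8, 2, 5),
}
out = gate(x)
assert out[0].shape == (2, 8, 4, 1)   # 8 - (2 + 2) type-0 channels remain after gating
assert out[1].shape == (2, 8, 2, 3) and out[2].shape == (2, 8, 2, 5)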

class DTP(nn.Module):
    """ 'Tensor Product' - in the equivariant sense """

    @beartype
    def __init__(
        self,
        fiber_in: Tuple[int, ...],
        fiber_out: Tuple[int, ...],
        self_interaction = True,
        project_xi_xj = True,   # whether to project xi and xj and then sum, as in paper
        project_out = True,     # whether to do a project out after the "tensor product"
        pool = True,
        edge_dim = 0,
        radial_hidden_dim = 16
    ):
        super().__init__()
        # store hyperparameters
        self.fiber_in = fiber_in
        self.fiber_out = fiber_out
        self.edge_dim = edge_dim
        self.self_interaction = self_interaction
        self.pool = pool

        self.project_xi_xj = project_xi_xj
        if project_xi_xj:
            # linear projections for xi and xj
            self.to_xi = Linear(fiber_in, fiber_in)
            self.to_xj = Linear(fiber_in, fiber_in)

        self.kernel_unary = nn.ModuleDict()

        # for every output degree, split its channels across the input degrees
        for degree_out, dim_out in enumerate(self.fiber_out):
            num_degrees_in = len(self.fiber_in)
            # split the output channels across the number of input degrees
            split_dim_out = split_num_into_groups(dim_out, num_degrees_in)

            # for every (input degree, output degree) pair
            for degree_in, (dim_in, dim_out_from_degree_in) in enumerate(zip(self.fiber_in, split_dim_out)):
                degree_min = min(degree_out, degree_in)

                # radial network producing the kernel for this degree pair
                self.kernel_unary[f'({degree_in},{degree_out})'] = Radial(degree_in, dim_in, degree_out, dim_out_from_degree_in, radial_hidden_dim = radial_hidden_dim, edge_dim = edge_dim)

        # whether to do self interaction per token
        if self_interaction:
            self.self_interact = Linear(fiber_in, fiber_out)

        self.project_out = project_out
        if project_out:
            # projection after the "tensor product"
            self.to_out = Linear(fiber_out, fiber_out)

    @beartype
    def forward(
        self,
        inp,
        basis,
        D,
        edge_info: EdgeInfo,
        rel_dist = None,
        ):
            # unpack edge information
            neighbor_indices, neighbor_masks, edges = edge_info

            # kernels and outputs, keyed by degree
            kernels = {}
            outputs = {}

            # neighbors

            # optionally project the input into xi and xj
            if self.project_xi_xj:
                source, target = self.to_xi(inp), self.to_xj(inp)
            else:
                source, target = inp, inp

            # iterate over every (input degree, output degree) combination
            for degree_out, _ in enumerate(self.fiber_out):
                output = None
                m_out = to_order(degree_out)

                for degree_in, _ in enumerate(self.fiber_in):
                    etype = f'({degree_in},{degree_out})'

                    m_in = to_order(degree_in)
                    m_min = min(m_in, m_out)

                    degree_min = min(degree_in, degree_out)

                    # get source and target (neighbor) representations
                    xi, xj = source[degree_in], target[degree_in]

                    x = get_at('b [i] d m, b j k -> b j k d m', xj, neighbor_indices)

                    # if projecting, sum xi with the gathered neighbor xj
                    if self.project_xi_xj:
                        xi = rearrange(xi, 'b i d m -> b i 1 d m')
                        x = x + xi

                    # multiply by D(R) - rotate into the z-axis aligned frame
                    if degree_in > 0:
                        Di = D[degree_in]
                        x = einsum(Di, x, '... mi1 mi2, ... li mi1 -> ... li mi2')

                    # if degree_in != degree_out, slice away some zeros
                    maybe_input_slice = slice_for_centering_y_to_x(m_in, m_min)
                    maybe_output_pad = pad_for_centering_y_to_x(m_out, m_min)
                    x = x[..., maybe_input_slice]

                    # radial kernel from the edge features, and fetch the basis for this degree pair
                    kernel_fn = self.kernel_unary[etype]
                    edge_features = safe_cat(edges, rel_dist, dim=-1)
                    B = basis.get(etype, None)
                    R = kernel_fn(edge_features)

                    # if there is no basis for this degree pair
                    if not exists(B):
                        output_chunk = einsum(R, x, '... lo li, ... li mi -> ... lo mi')
                    else:
                        y = x.clone()
                        x = repeat(x, '... mi -> ... mi mf r', mf=(B.shape[-1] + 1) // 2, r=2)
                        x, x_to_flip = x.unbind(dim=-1)
                        x_flipped = torch.flip(x_to_flip, dims=(-2,))
                        x = torch.stack((x, x_flipped), dim=-1)
                        x = rearrange(x, '... mf r -> ... (mf r)', r=2)
                        x = x[..., :-1]
                        output_chunk = opt_einsum('... o i, m f, ... i m f -> ... o m', R, B, x)

                    # pad back out if degree_out < degree_in
                    output_chunk = F.pad(output_chunk, (maybe_output_pad, maybe_output_pad), value=0.)
                    output = safe_cat(output, output_chunk, dim=-2)

                # multiply by D(R^-1) - rotate back from the z-axis aligned frame
                if degree_out > 0:
                    Do = D[degree_out]
                    output = einsum(output, Do, '... lo mo1, ... mo2 mo1 -> ... lo mo2')

                # pool (or not) along the j (neighbor) dimension
                if self.pool:
                    output = masked_mean(output, neighbor_masks, dim=2)

                outputs[degree_out] = output

            # if neither self interaction nor output projection is needed, return here
            if not self.self_interaction and not self.project_out:
                return outputs

            # project out, if requested
            if self.project_out:
                outputs = self.to_out(outputs)

            self_interact_out = self.self_interact(inp)

            # if pooling, add the self interaction as a residual
            if self.pool:
                return residual_fn(outputs, self_interact_out)

            self_interact_out = {k: rearrange(v, '... d m -> ... 1 d m') for k, v in self_interact_out.items()}
            outputs = {degree: torch.cat(tensors, dim=-3) for degree, tensors in enumerate(zip(self_interact_out.values(), outputs.values()))}
            return outputs
# radial network, mapping edge features to per-degree-pair kernels
class Radial(nn.Module):
    # constructor: input degree and channels, output degree and channels, edge dimension, and radial hidden dimension
    def __init__(
        self,
        degree_in,
        nc_in,
        degree_out,
        nc_out,
        edge_dim = 0,
        radial_hidden_dim = 64
    ):
        # call the parent constructor
        super().__init__()
        # store attributes
        self.degree_in = degree_in
        self.degree_out = degree_out
        self.nc_in = nc_in
        self.nc_out = nc_out

        # spherical order of the output degree
        self.d_out = to_order(degree_out)
        self.edge_dim = edge_dim

        # hidden dimension of the radial MLP
        mid_dim = radial_hidden_dim
        edge_dim = default(edge_dim, 0)

        # the radial MLP: edge features -> (nc_out, nc_in) kernel
        self.rp = nn.Sequential(
            nn.Linear(edge_dim + 1, mid_dim),
            nn.SiLU(),
            LayerNorm(mid_dim),
            nn.Linear(mid_dim, mid_dim),
            nn.SiLU(),
            LayerNorm(mid_dim),
            nn.Linear(mid_dim, nc_in * nc_out),
            Rearrange('... (lo li) -> ... lo li', li = nc_in, lo = nc_out)
        )

    # forward pass
    def forward(self, feat):
        # run the radial MLP on the edge features
        return self.rp(feat)

# feedforward module
class FeedForward(nn.Module):
    # constructor: input fiber, optional output fiber, multiplier, whether to include higher-type norms, and whether to zero-init the output
    @beartype
    def __init__(
        self,
        fiber: Tuple[int, ...],
        fiber_out: Optional[Tuple[int, ...]] = None,
        mult = 4,
        include_htype_norms = True,
        init_out_zero = True
    ):
        # call the parent constructor
        super().__init__()
        # store the fiber
        self.fiber = fiber

        # hidden fiber is the input fiber scaled by the multiplier
        fiber_hidden = tuple(dim * mult for dim in fiber)

        project_in_fiber = fiber
        project_in_fiber_hidden = tuple_set_at_index(fiber_hidden, 0, sum(fiber_hidden))

        # if including higher-type norms, the type-0 input of the first projection grows to sum(fiber)
        self.include_htype_norms = include_htype_norms
        if include_htype_norms:
            project_in_fiber = tuple_set_at_index(project_in_fiber, 0, sum(fiber))

        fiber_out = default(fiber_out, fiber)

        # prenorm, input projection, gate, and output projection
        self.prenorm     = Norm(fiber)
        self.project_in  = Linear(project_in_fiber, project_in_fiber_hidden)
        self.gate        = Gate(project_in_fiber_hidden)
        self.project_out = Linear(fiber_hidden, fiber_out)

        # optionally zero-init the output projection
        if init_out_zero:
            self.project_out.init_zero_()

    # forward pass
    def forward(self, features):
        # prenorm the input features
        outputs = self.prenorm(features)

        # if including higher-type norms, concatenate the norms of the higher types onto type 0
        if self.include_htype_norms:
            type0, *htypes = [*outputs.values()]
            htypes = map(lambda t: t.norm(dim = -1, keepdim = True), htypes)
            type0 = torch.cat((type0, *htypes), dim = -2)
            outputs[0] = type0

        # project in, gate, project out
        outputs = self.project_in(outputs)
        outputs = self.gate(outputs)
        outputs = self.project_out(outputs)
        return outputs

# global linear attention over the type-0 features
class LinearAttention(nn.Module):
    # constructor: feature dimension, head dimension, and number of heads
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8
    ):
        # call the parent constructor
        super().__init__()
        # number of heads and inner dimension
        self.heads = heads
        dim_inner = dim_head * heads
        # single linear projection to queries, keys, and values
        self.to_qkv = nn.Linear(dim, dim_inner * 3)

    # forward pass
    def forward(self, x, mask = None):
        # check whether the input still carries the trailing m dimension
        has_degree_m_dim = x.ndim == 4

        if has_degree_m_dim:
            x = rearrange(x, '... 1 -> ...')

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        if exists(mask):
            mask = rearrange(mask, 'b n -> b 1 n 1')
            k = k.masked_fill(~mask, -torch.finfo(q.dtype).max)
            v = v.masked_fill(~mask, 0.)

        k = k.softmax(dim = -2)
        q = q.softmax(dim = -1)

        kv = einsum(k, v, 'b h n d, b h n e -> b h d e')
        out = einsum(kv, q, 'b h d e, b h n d -> b h n e')
        out = rearrange(out, 'b h n d -> b n (h d)')

        if has_degree_m_dim:
            out = rearrange(out, '... -> ... 1')

        return out

# L2 distance attention
class L2DistAttention(nn.Module):
    @beartype
    # constructor, defining the hyperparameters and layers
    def __init__(
        self,
        fiber: Tuple[int, ...],                      # fiber (channels per degree) of the input features
        dim_head: Union[int, Tuple[int, ...]] = 64,  # dimension per head
        heads: Union[int, Tuple[int, ...]] = 8,      # number of heads
        attend_self = False,                         # whether each node also attends to itself
        edge_dim = None,                             # edge feature dimension
        single_headed_kv = False,                    # whether keys / values are single headed
        radial_hidden_dim = 64,                      # hidden dimension of the radial network
        splits = 4,                                  # number of splits
        linear_attn_dim_head = 8,                    # head dimension of the linear attention
        num_linear_attn_heads = 0,                   # number of linear attention heads
        init_out_zero = True,                        # whether to zero-init the output projection
        gate_attn_head_outputs = True                # whether to gate the attention head outputs
    ):
        super().__init__()
        num_degrees = len(fiber)

        dim_head = cast_tuple(dim_head, num_degrees)    # one head dimension per degree
        assert len(dim_head) == num_degrees

        heads = cast_tuple(heads, num_degrees)          # one head count per degree
        assert len(heads) == num_degrees

        hidden_fiber = tuple(dim * head for dim, head in zip(dim_head, heads))  # hidden dimensions per degree

        self.single_headed_kv = single_headed_kv        # whether keys / values are single headed
        self.attend_self = attend_self                  # whether to attend to self

        kv_hidden_fiber = hidden_fiber if not single_headed_kv else dim_head    # hidden fiber for keys / values
        kv_hidden_fiber = tuple(dim * 2 for dim in kv_hidden_fiber)             # doubled, for keys and values

        self.scale = tuple(dim ** -0.5 for dim in dim_head)                     # attention scale per degree
        self.heads = heads

        self.prenorm = Norm(fiber)                                              # pre-normalization

        self.to_q = Linear(fiber, hidden_fiber)                                 # query projection
        self.to_kv = DTP(fiber, kv_hidden_fiber, radial_hidden_dim = radial_hidden_dim, edge_dim = edge_dim, pool = False, self_interaction = attend_self)  # key / value tensor product

        # linear attention heads

        self.has_linear_attn = num_linear_attn_heads > 0

        if self.has_linear_attn:
            degree_zero_dim = fiber[0]                                          # type-0 channels
            self.linear_attn = TaylorSeriesLinearAttn(degree_zero_dim, dim_head = linear_attn_dim_head, heads = num_linear_attn_heads, combine_heads = False, gate_value_heads = True)
            hidden_fiber = tuple_set_at_index(hidden_fiber, 0, hidden_fiber[0] + linear_attn_dim_head * num_linear_attn_heads)  # widen the type-0 hidden dimension

        # gate the outputs for all degrees, to allow attending to nothing

        self.attn_head_gates = None

        if gate_attn_head_outputs:
            self.attn_head_gates = nn.Sequential(
                Rearrange('... d 1 -> ... d'),
                nn.Linear(fiber[0], sum(heads)),
                nn.Sigmoid(),
                Rearrange('... n h -> ... h n 1 1')
            )

        # combine heads

        self.to_out = Linear(hidden_fiber, fiber)                               # output projection

        if init_out_zero:
            self.to_out.init_zero_()

    @beartype
    def forward(
        self,
        features,              # input features, keyed by degree
        edge_info: EdgeInfo,   # edge information
        rel_dist,              # relative distances
        basis,                 # basis
        D,                     # Wigner D matrices
        mask = None            # mask
        ):
            # flag for single-headed key / values
            one_head_kv = self.single_headed_kv

            # device and dtype of the features
            device, dtype = get_tensor_device_and_dtype(features)
            # unpack neighbor indices, neighbor mask, and edges
            neighbor_indices, neighbor_mask, edges = edge_info

            # if a neighbor mask exists
            if exists(neighbor_mask):
                # add a head dimension to the neighbor mask
                neighbor_mask = rearrange(neighbor_mask, 'b i j -> b 1 i j')

                # if attending to self
                if self.attend_self:
                    # pad the neighbor mask to account for the self token
                    neighbor_mask = F.pad(neighbor_mask, (1, 0), value = True)

            # prenorm the features
            features = self.prenorm(features)

            # queries, keys, and values
            queries = self.to_q(features)

            keyvalues   = self.to_kv(
                features,
                edge_info = edge_info,
                rel_dist = rel_dist,
                basis = basis,
                D = D
            )

            # gates
            gates = (None,) * len(self.heads)

            # if attention head gating is enabled
            if exists(self.attn_head_gates):
                # compute gates from the type-0 features and split them per degree
                gates = self.attn_head_gates(features[0]).split(self.heads, dim = -4)

            # einsum equation differs between single-headed and multi-headed key / values
            kv_einsum_eq = 'b h i j d m' if not one_head_kv else 'b i j d m'

            outputs = {}

            # iterate over degrees along with their gates, head counts, and scales
            for degree, gate, h, scale in zip(features.keys(), gates, self.heads, self.scale):
                # whether this is degree 0
                is_degree_zero = degree == 0

                q, kv = map(lambda t: t[degree], (queries, keyvalues))

                q = rearrange(q, 'b i (h d) m -> b h i d m', h = h)

                # if not single-headed key / values, split out the heads
                if not one_head_kv:
                    kv = rearrange(kv, f'b i j (h d) m -> b h i j d m', h = h)

                k, v = kv.chunk(2, dim = -2)

                # if single-headed key / values, broadcast the keys across heads
                if one_head_kv:
                    k = repeat(k, 'b i j d m -> b h i j d m', h = h)

                q = repeat(q, 'b h i d m -> b h i j d m', j = k.shape[-3])

                # for degree 0, drop the trailing m dimension
                if is_degree_zero:
                    q, k = map(lambda t: rearrange(t, '... 1 -> ...'), (q, k))

                sim = -cdist(q, k) * scale

                # for higher degrees, sum the similarity over m and mask out invalid neighbors
                if not is_degree_zero:
                    sim = sim.sum(dim = -1)
                    sim = sim.masked_fill(~neighbor_mask, -torch.finfo(sim.dtype).max)

                attn = sim.softmax(dim = -1)
                out = einsum(attn, v, f'b h i j, {kv_einsum_eq} -> b h i d m')

                # apply the head gate, if present
                if exists(gate):
                    out = out * gate

                outputs[degree] = rearrange(out, 'b h n d m -> b n (h d) m')

            # linear attention on the type-0 features
            if self.has_linear_attn:
                linear_attn_input = rearrange(features[0], '... 1 -> ...')
                lin_attn_out = self.linear_attn(linear_attn_input, mask = mask)
                lin_attn_out = rearrange(lin_attn_out, '... -> ... 1')
                outputs[0] = torch.cat((outputs[0], lin_attn_out), dim = -2)

            # combine heads and project out
            return self.to_out(outputs)
# MLP attention, as proposed in the paper
class MLPAttention(nn.Module):
    # constructor
    @beartype
    def __init__(
        self,
        fiber: Tuple[int, ...],                      # fiber (channels per degree) of the input features
        dim_head: Union[int, Tuple[int, ...]] = 64,  # dimension per attention head
        heads: Union[int, Tuple[int, ...]] = 8,      # number of attention heads
        attend_self = False,                         # whether each node also attends to itself
        edge_dim = None,                             # edge feature dimension
        splits = 4,                                  # number of splits
        single_headed_kv = False,                    # whether keys / values are single headed
        attn_leakyrelu_slope = 0.1,                  # leaky relu slope for the attention branch
        attn_hidden_dim_mult = 4,                    # multiplier for the attention hidden dimension
        radial_hidden_dim = 16,                      # hidden dimension of the radial network
        linear_attn_dim_head = 8,                    # head dimension of the linear attention
        num_linear_attn_heads = 0,                   # number of linear attention heads
        init_out_zero = True,                        # whether to zero-init the output projection
        gate_attn_head_outputs = True,               # whether to gate the attention head outputs
        **kwargs
    ):
        super().__init__()
        num_degrees = len(fiber)

        dim_head = cast_tuple(dim_head, num_degrees)    # one head dimension per degree
        assert len(dim_head) == num_degrees

        heads = cast_tuple(heads, num_degrees)          # one head count per degree
        assert len(heads) == num_degrees

        hidden_fiber = tuple(dim * head for dim, head in zip(dim_head, heads))  # hidden dimensions per degree

        self.single_headed_kv = single_headed_kv        # whether keys / values are single headed
        value_hidden_fiber = hidden_fiber if not single_headed_kv else dim_head # hidden fiber for the values

        self.attend_self = attend_self                  # whether to attend to self

        self.scale = tuple(dim ** -0.5 for dim in dim_head)                     # attention scale per degree
        self.heads = heads

        self.prenorm = Norm(fiber)                                              # pre-normalization

        # type 0 needs extra dimensions, for
        # (1) gating the htypes on the values branch
        # (2) the attention logits, whose initial dimension equals the number of heads

        type0_dim = value_hidden_fiber[0]               # type-0 channels
        htype_dims = sum(value_hidden_fiber[1:])        # total channels across the higher types

        value_gate_fiber = tuple_set_at_index(value_hidden_fiber, 0, type0_dim + htype_dims)  # fiber for the value gating

        attn_hidden_dims = tuple(head * attn_hidden_dim_mult for head in heads) # attention hidden dimensions

        intermediate_fiber = tuple_set_at_index(value_hidden_fiber, 0, sum(attn_hidden_dims) + type0_dim + htype_dims)  # intermediate fiber
        self.intermediate_type0_split = [*attn_hidden_dims, type0_dim + htype_dims]  # split sizes for the type-0 intermediate

        # main branch tensor product

        self.to_attn_and_v = DTP(fiber, intermediate_fiber, radial_hidden_dim = radial_hidden_dim, edge_dim = edge_dim, pool = False, self_interaction = attend_self)  # tensor product for attention and values

        # non-linear projection of the attention branch to the attention logits

        self.to_attn_logits = nn.ModuleList([
            nn.Sequential(
                nn.LeakyReLU(attn_leakyrelu_slope),
                nn.Linear(attn_hidden_dim, h, bias = False)
            ) for attn_hidden_dim, h in zip(attn_hidden_dims, self.heads)
        ])

        # non-linearity for the values branch
        # todo - is a DTP needed here?

        self.to_values = nn.Sequential(
            Gate(value_gate_fiber),
            Linear(value_hidden_fiber, value_hidden_fiber)
        )

        # linear attention heads

        self.has_linear_attn = num_linear_attn_heads > 0

        if self.has_linear_attn:
            degree_zero_dim = fiber[0]
            self.linear_attn = TaylorSeriesLinearAttn(degree_zero_dim, dim_head = linear_attn_dim_head, heads = num_linear_attn_heads, combine_heads = False)

            hidden_fiber = tuple_set_at_index(hidden_fiber, 0, hidden_fiber[0] + linear_attn_dim_head * num_linear_attn_heads)  # widen the type-0 hidden dimension

        # gate the head outputs for all degrees
        # allows attending to nothing

        self.attn_head_gates = None

        if gate_attn_head_outputs:
            self.attn_head_gates = nn.Sequential(
                Rearrange('... d 1 -> ... d'),
                nn.Linear(fiber[0], sum(heads)),
                nn.Sigmoid(),
                Rearrange('... h -> ... h 1 1')
            )

        # combine heads and project out

        self.to_out = Linear(hidden_fiber, fiber)

        if init_out_zero:
            self.to_out.init_zero_()

    @beartype
    def forward(
        self,
        features,
        edge_info: EdgeInfo,
        rel_dist,
        basis,
        D,
        mask = None
        ):
            # flag for single-headed key / values
            one_headed_kv = self.single_headed_kv

            # unpack edge information
            _, neighbor_mask, _ = edge_info

            # if a neighbor mask exists
            if exists(neighbor_mask):
                # if attending to self, pad one position on the left
                if self.attend_self:
                    neighbor_mask = F.pad(neighbor_mask, (1, 0), value = True)

                # add a trailing dimension to the neighbor mask
                neighbor_mask = rearrange(neighbor_mask, '... -> ... 1')

            # prenorm the features
            features = self.prenorm(features)

            # intermediate features for the attention and value branches
            intermediate = self.to_attn_and_v(
                features,
                edge_info = edge_info,
                rel_dist = rel_dist,
                basis = basis,
                D = D
            )

            # split the type-0 intermediate into the attention branches and the value branch
            *attn_branch_type0, value_branch_type0 = intermediate[0].split(self.intermediate_type0_split, dim = -2)

            # put the value branch back into the intermediates
            intermediate[0] = value_branch_type0

            # gates
            gates = (None,) * len(self.heads)

            # if attention head gating is enabled
            if exists(self.attn_head_gates):
                gates = self.attn_head_gates(features[0]).split(self.heads, dim = -3)

            # attention branch
            attentions = []

            for fn, attn_intermediate, scale in zip(self.to_attn_logits, attn_branch_type0, self.scale):
                attn_intermediate = rearrange(attn_intermediate, '... 1 -> ...')
                attn_logits = fn(attn_intermediate)
                attn_logits = attn_logits * scale

                # mask out invalid neighbors, if a mask exists
                if exists(neighbor_mask):
                    attn_logits = attn_logits.masked_fill(~neighbor_mask, -torch.finfo(attn_logits.dtype).max)

                # attention weights
                attn = attn_logits.softmax(dim = -2) # (batch, source, target, heads)
                attentions.append(attn)

            # value branch
            values = self.to_values(intermediate)

            # aggregate the values with the attention matrices
            outputs = {}

            value_einsum_eq = 'b i j h d m' if not one_headed_kv else 'b i j d m'

            for degree, (attn, value, gate, h) in enumerate(zip(attentions, values.values(), gates, self.heads)):
                if not one_headed_kv:
                    value = rearrange(value, 'b i j (h d) m -> b i j h d m', h = h)

                out = einsum(attn, value, f'b i j h, {value_einsum_eq} -> b i h d m')

                if exists(gate):
                    out = out * gate

                out = rearrange(out, 'b i h d m -> b i (h d) m')
                outputs[degree] = out

            # linear attention
            if self.has_linear_attn:
                linear_attn_input = rearrange(features[0], '... 1 -> ...')
                lin_attn_out = self.linear_attn(linear_attn_input, mask = mask)
                lin_attn_out = rearrange(lin_attn_out, '... -> ... 1')

                outputs[0] = torch.cat((outputs[0], lin_attn_out), dim = -2)

            # combine heads and project out
            return self.to_out(outputs)
# main class

class Equiformer(nn.Module):
    # constructor; arguments are type-checked at runtime via beartype
    @beartype
    def __init__(
        self,
        *,
        dim: Union[int, Tuple[int, ...]],                       # hidden dimensions, a single int or one per degree
        dim_in: Optional[Union[int, Tuple[int, ...]]] = None,   # input dimensions, optional
        num_degrees = 2,                                        # number of degrees (types)
        input_degrees = 1,                                      # number of input degrees
        heads: Union[int, Tuple[int, ...]] = 8,                 # attention heads, int or one per degree
        dim_head: Union[int, Tuple[int, ...]] = 24,             # dimension per attention head, int or one per degree
        depth = 2,                                              # depth of the network
        valid_radius = 1e5,                                     # valid radius for neighbor selection
        num_neighbors = float('inf'),                           # maximum number of neighbors
        reduce_dim_out = False,                                 # whether to reduce the output to one channel per degree
        radial_hidden_dim = 64,                                 # hidden dimension of the radial networks
        num_tokens = None,                                      # number of tokens, if using a token embedding
        num_positions = None,                                   # number of positions, if using a positional embedding
        num_edge_tokens = None,                                 # number of edge tokens
        edge_dim = None,                                        # edge feature dimension
        attend_self = True,                                     # whether each node attends to itself
        splits = 4,                                             # number of splits
        linear_out = True,                                      # whether the output projection is a plain Linear
        embedding_grad_frac = 0.5,                              # fraction of the gradient allowed through the embeddings
        single_headed_kv = False,                               # whether to use single-headed key / values for dot product attention, to save on memory and compute
        ff_include_htype_norms = False,                         # whether the type-0 projection of the first feedforward also takes in the norms of all higher types, allowing all higher types to be gated by the norms of the others
        l2_dist_attention = True,                               # set to False to use the MLP attention proposed in the paper; however dot product attention with -cdist similarity still works much better, even without rotating distances (rotary embeddings) into the type-0 features
        reversible = False,                                     # turn on reversible networks, to scale depth without the depth-wise memory cost
        attend_sparse_neighbors = False,                        # ability to accept an adjacency matrix
        gate_attn_head_outputs = True,                          # gate each attention head output, to allow attending to nothing
        num_adj_degrees_embed = None,                           # number of adjacency degrees to embed
        adj_dim = 0,                                            # adjacency embedding dimension
        max_sparse_neighbors = float('inf'),                    # maximum number of sparse neighbors
        **kwargs                                                # remaining keyword arguments, forwarded to the attention layers
    ):
        # call the parent constructor
        super().__init__()

        # fraction of the gradient allowed through the embeddings, for more stable training
        self.embedding_grad_frac = embedding_grad_frac

        # decide hidden dimensions for all types
        self.dim = cast_tuple(dim, num_degrees)
        assert len(self.dim) == num_degrees

        self.num_degrees = len(self.dim)

        # decide input dimensions for all types
        dim_in = default(dim_in, (self.dim[0],))
        self.dim_in = cast_tuple(dim_in, input_degrees)
        assert len(self.dim_in) == input_degrees

        self.input_degrees = len(self.dim_in)

        # token embedding
        type0_feat_dim = self.dim_in[0]
        self.type0_feat_dim = type0_feat_dim
        self.token_emb = nn.Embedding(num_tokens, type0_feat_dim) if exists(num_tokens) else None

        # positional embedding
        self.num_positions = num_positions
        self.pos_emb = nn.Embedding(num_positions, type0_feat_dim) if exists(num_positions) else None

        # initialize embeddings
        if exists(self.token_emb):
            nn.init.normal_(self.token_emb.weight, std=1e-2)

        if exists(self.pos_emb):
            nn.init.normal_(self.pos_emb.weight, std=1e-2)

        # edges
        assert not (exists(num_edge_tokens) and not exists(edge_dim)), 'edge dimension (edge_dim) must be supplied if equiformer is to have edge tokens'
        self.edge_emb = nn.Embedding(num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None
        self.has_edges = exists(edge_dim) and edge_dim > 0

        # sparse neighbors, derived from the adjacency matrix or edges passed in
        self.attend_sparse_neighbors = attend_sparse_neighbors
        self.max_sparse_neighbors = max_sparse_neighbors

        # adjacent neighbor derivation and embedding
        assert not exists(num_adj_degrees_embed) or num_adj_degrees_embed >= 1, 'number of adjacent degrees to embed must be 1 or greater'
        self.num_adj_degrees_embed = num_adj_degrees_embed
        self.adj_emb = nn.Embedding(num_adj_degrees_embed + 1, adj_dim) if exists(num_adj_degrees_embed) and adj_dim > 0 else None
        edge_dim = (edge_dim if self.has_edges else 0) + (adj_dim if exists(self.adj_emb) else 0)

        # neighbor hyperparameters
        self.valid_radius = valid_radius
        self.num_neighbors = num_neighbors

        # main network
        self.tp_in = DTP(
            self.dim_in,
            self.dim,
            edge_dim=edge_dim,
            radial_hidden_dim=radial_hidden_dim
        )

        # trunk
        self.layers = []

        attention_klass = L2DistAttention if l2_dist_attention else MLPAttention

        for ind in range(depth):
            self.layers.append((
                attention_klass(
                    self.dim,
                    heads=heads,
                    dim_head=dim_head,
                    attend_self=attend_self,
                    edge_dim=edge_dim,
                    single_headed_kv=single_headed_kv,
                    radial_hidden_dim=radial_hidden_dim,
                    gate_attn_head_outputs=gate_attn_head_outputs,
                    **kwargs
                ),
                FeedForward(self.dim, include_htype_norms=ff_include_htype_norms)
            ))

        SequenceKlass = ReversibleSequence if reversible else SequentialSequence

        self.layers = SequenceKlass(self.layers)

        # output
        self.norm = Norm(self.dim)

        proj_out_klass = Linear if linear_out else FeedForward

        self.ff_out = proj_out_klass(self.dim, (1,) * self.num_degrees) if reduce_dim_out else None

        # the basis is now constant
        # pytorch does not have a BufferDict yet, so work around it with a python property
        self.basis = get_basis(self.num_degrees - 1)

    @property
    def basis(self):
        out = dict()
        for k in self.basis_keys:
            out[k] = getattr(self, f'basis:{k}')
        return out

    @basis.setter
    # setter that registers each basis tensor as a buffer
    def basis(self, basis):
        # remember the keys of the basis dict
        self.basis_keys = basis.keys()

        # register every entry as a buffer
        for k, v in basis.items():
            self.register_buffer(f'basis:{k}', v)

    # device the model parameters live on
    @property
    def device(self):
        return next(self.parameters()).device

    # forward pass: takes the inputs, coordinates, mask, adjacency matrix, edges, and whether to return pooled type-0 output
    @beartype
    def forward(
        self,
        inputs: Union[Tensor, Dict[int, Tensor]],
        coors: Tensor,
        mask = None,
        adj_mat = None,
        edges = None,
        return_pooled = False

.\lucidrains\equiformer-pytorch\equiformer_pytorch\irr_repr.py

# standard library helpers
from pathlib import Path
from functools import partial

# torch and functional ops
import torch
import torch.nn.functional as F
from torch import sin, cos, atan2, acos

# einops helpers
from einops import rearrange, pack, unpack

# utility functions from equiformer_pytorch.utils
from equiformer_pytorch.utils import (
    exists,
    default,
    cast_torch_tensor,
    to_order,
    identity,
    l2norm
)

# data directory containing the precomputed J matrices
DATA_PATH = Path(__file__).parents[0] / 'data'
# path to the dense J matrices
path = DATA_PATH / 'J_dense.pt'
# load the J matrices used to build the Wigner D matrices
Jd = torch.load(str(path))

# pack a single tensor according to the given pattern
def pack_one(t, pattern):
    return pack([t], pattern)

# unpack a single tensor according to the given pattern
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# Wigner D matrices for a batch of ZYZ Euler angles
def wigner_d_matrix(degree, alpha, beta, gamma, dtype = None, device = None):
    """Create wigner D matrices for batch of ZYZ Euler angles for degree l."""
    # batch size
    batch = alpha.shape[0]
    # J matrix for this degree, cast to the right dtype and device
    J = Jd[degree].type(dtype).to(device)
    # spherical order for this degree
    order = to_order(degree)
    # build the ZYZ rotation from three z-axis rotations conjugated by J
    x_a = z_rot_mat(alpha, degree)
    x_b = z_rot_mat(beta, degree)
    x_c = z_rot_mat(gamma, degree)
    res = x_a @ J @ x_b @ J @ x_c
    return res.view(batch, order, order)

# rotation about the Z axis, in the (2l + 1)-dimensional representation
def z_rot_mat(angle, l):
    device, dtype = angle.device, angle.dtype

    # batch size
    batch = angle.shape[0]
    # arange helper bound to the right device
    arange = partial(torch.arange, device = device)

    # spherical order for this degree
    order = to_order(l)

    # allocate the rotation matrix
    m = angle.new_zeros((batch, order, order))

    # index helpers and frequencies
    batch_range = arange(batch, dtype = torch.long)[..., None]
    inds = arange(order, dtype = torch.long)[None, ...]
    reversed_inds = arange(2 * l, -1, -1, dtype = torch.long)[None, ...]
    frequencies = arange(l, -l - 1, -1, dtype = dtype)[None]

    # fill in the sin and cos entries
    m[batch_range, inds, reversed_inds] = sin(frequencies * angle[..., None])
    m[batch_range, inds, inds] = cos(frequencies * angle[..., None])
    return m

# irreducible representation of SO(3)
def irr_repr(order, angles):
    """
    irreducible representation of SO3 - accepts multiple angles in tensor
    """
    dtype, device = angles.dtype, angles.device
    angles, ps = pack_one(angles, '* c')

    alpha, beta, gamma = angles.unbind(dim = -1)
    rep = wigner_d_matrix(order, alpha, beta, gamma, dtype = dtype, device = device)

    return unpack_one(rep, ps, '* o1 o2')
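
As a quick shape check (illustrative only, and assuming the bundled J_dense.pt data file is available so Jd loads): a degree-l representation acts on a (2l + 1)-dimensional space.

# illustrative only
angles = torch.randn(2, 3)      # a batch of ZYZ Euler angle triplets
assert irr_repr(1, angles).shape == (2, 3, 3)
assert irr_repr(2, angles).shape == (2, 5, 5)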

# rotation about the Z axis, with the input cast to a torch tensor
@cast_torch_tensor
def rot_z(gamma):
    '''
    Rotation around Z axis
    '''
    c = cos(gamma)
    s = sin(gamma)
    z = torch.zeros_like(gamma)
    o = torch.ones_like(gamma)

    out = torch.stack((
        c, -s, z,
        s, c, z,
        z, z, o
    ), dim = -1)

    return rearrange(out, '... (r1 r2) -> ... r1 r2', r1 = 3)

# rotation about the Y axis, with the input cast to a torch tensor
@cast_torch_tensor
def rot_y(beta):
    '''
    Rotation around Y axis
    '''
    c = cos(beta)
    s = sin(beta)
    z = torch.zeros_like(beta)
    o = torch.ones_like(beta)

    out = torch.stack((
        c, z, s,
        z, o, z,
        -s, z, c
    ), dim = -1)

    return rearrange(out, '... (r1 r2) -> ... r1 r2', r1 = 3)

# rotation matrix from ZYZ Euler angles
def rot(alpha, beta, gamma):
    '''
    ZYZ Euler angles rotation
    '''
    return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
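
A minimal check (illustrative only) that the composed ZYZ matrix is a proper rotation, i.e. orthogonal with determinant one:

# illustrative only
R = rot(0.1, 0.7, -0.3)   # the cast_torch_tensor decorator converts the floats to tensors
assert torch.allclose(R @ R.transpose(-1, -2), torch.eye(3), atol = 1e-5)
assert torch.allclose(torch.det(R), torch.tensor(1.), atol = 1e-5)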

# convert a rotation matrix to ZYZ Euler angles
def rot_to_euler_angles(R):
    '''
    Rotation matrix to ZYZ Euler angles
    '''
    device, dtype = R.device, R.dtype
    xyz = R @ torch.tensor([0.0, 1.0, 0.0], device = device, dtype = dtype)
    xyz = l2norm(xyz).clamp(-1., 1.)

    b = acos(xyz[..., 1])
    a = atan2(xyz[..., 0], xyz[..., 2])

    R = rot(a, b, torch.zeros_like(a)).transpose(-1, -2) @ R
    c = atan2(R[..., 0, 2], R[..., 0, 0])
    return torch.stack((a, b, c), dim = -1)

.\lucidrains\equiformer-pytorch\equiformer_pytorch\reversible.py

import torch
from torch.nn import Module
import torch.nn as nn
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states

from beartype import beartype
from beartype.typing import List, Tuple

from einops import rearrange, reduce

from equiformer_pytorch.utils import to_order

# helpers

# apply fn to every value of dict x
def map_values(fn, x):
    out = {}
    for (k, v) in x.items():
        out[k] = fn(v)
    return out

# chunk every value of dict x along dim, returning two dicts
def dict_chunk(x, chunks, dim):
    out1 = {}
    out2 = {}
    for (k, v) in x.items():
        c1, c2 = v.chunk(chunks, dim = dim)
        out1[k] = c1
        out2[k] = c2
    return out1, out2

# elementwise sum of two dicts
def dict_sum(x, y):
    out = {}
    for k in x.keys():
        out[k] = x[k] + y[k]
    return out

# elementwise difference of two dicts
def dict_subtract(x, y):
    out = {}
    for k in x.keys():
        out[k] = x[k] - y[k]
    return out

# concatenate corresponding values of two dicts along dim
def dict_cat(x, y, dim):
    out = {}
    for k, v1 in x.items():
        v2 = y[k]
        out[k] = torch.cat((v1, v2), dim = dim)
    return out

# set an attribute on every value of dict x
def dict_set_(x, key, value):
    for k, v in x.items():
        setattr(v, key, value)

# backpropagate through every value of outputs with the given gradients
def dict_backwards_(outputs, grad_tensors):
    for k, v in outputs.items():
        torch.autograd.backward(v, grad_tensors[k], retain_graph = True)

# delete all values of dict x
def dict_del_(x):
    for k, v in x.items():
        del v
    del x

# list of the values of dict d
def values(d):
    return [v for _, v in d.items()]
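
These dict helpers are what let the reversible machinery treat a degree-keyed feature dict like a single tensor; a tiny round trip (illustrative only, assuming torch and the helpers above are in scope):

# illustrative only
x = {0: torch.randn(2, 8, 4, 2), 1: torch.randn(2, 8, 4, 6)}
x1, x2 = dict_chunk(x, 2, dim = -1)   # halves along the last dimension
assert x1[1].shape == (2, 8, 4, 3)
merged = dict_cat(x1, x2, dim = -1)   # concatenated back to the original shapes
assert all(merged[k].shape == x[k].shape for k in x)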

# following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html

# module that records and restores RNG state around its wrapped net
class Deterministic(Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    # record the RNG state
    def record_rng(self, *args):
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    # forward pass, optionally recording or restoring the RNG state
    def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)

# heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
# once multi-GPU is confirmed working, refactor and send PR back to source

# reversible block, made of two functions f and g
class ReversibleBlock(Module):
    def __init__(self, f, g):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    # forward pass
    def forward(self, x, **kwargs):
        training = self.training
        x1, x2 = dict_chunk(x, 2, dim = -1)
        y1, y2 = None, None

        with torch.no_grad():
            y1 = dict_sum(x1, self.f(x2, record_rng = training, **kwargs))
            y2 = dict_sum(x2, self.g(y1, record_rng = training))

        return dict_cat(y1, y2, dim = -1)
    # backward pass, receiving the outputs y and the gradients dy
    def backward_pass(self, y, dy, **kwargs):
        # split y into y1 and y2 along the last dimension
        y1, y2 = dict_chunk(y, 2, dim = -1)
        # free the original y dict
        dict_del_(y)

        # split dy into dy1 and dy2 along the last dimension
        dy1, dy2 = dict_chunk(dy, 2, dim = -1)
        # free the original dy dict
        dict_del_(dy)

        # with gradients enabled
        with torch.enable_grad():
            # make y1 require grad
            dict_set_(y1, 'requires_grad', True)
            # recompute g(y1)
            gy1 = self.g(y1, set_rng = True)
            # backpropagate dy2 through gy1
            dict_backwards_(gy1, dy2)

        # without gradients
        with torch.no_grad():
            # recover x2 = y2 - g(y1)
            x2 = dict_subtract(y2, gy1)
            # free y2 and gy1
            dict_del_(y2)
            dict_del_(gy1)

            # dx1 = dy1 plus the gradients accumulated on y1
            dx1 = dict_sum(dy1, map_values(lambda t: t.grad, y1))
            # free dy1 and clear the gradients on y1
            dict_del_(dy1)
            dict_set_(y1, 'grad', None)

        # with gradients enabled
        with torch.enable_grad():
            # make x2 require grad
            dict_set_(x2, 'requires_grad', True)
            # recompute f(x2)
            fx2 = self.f(x2, set_rng = True, **kwargs)
            # backpropagate dx1 through fx2
            dict_backwards_(fx2, dx1)

        # without gradients
        with torch.no_grad():
            # recover x1 = y1 - f(x2)
            x1 = dict_subtract(y1, fx2)
            # free y1 and fx2
            dict_del_(y1)
            dict_del_(fx2)

            # dx2 = dy2 plus the gradients accumulated on x2
            dx2 = dict_sum(dy2, map_values(lambda t: t.grad, x2))
            # free dy2 and clear the gradients on x2
            dict_del_(dy2)
            dict_set_(x2, 'grad', None)

            # detach x2 so it is no longer tracked
            x2 = map_values(lambda t: t.detach(), x2)

            # concatenate x1 and x2 back into x
            x = dict_cat(x1, x2, dim = -1)
            # concatenate dx1 and dx2 back into dx
            dx = dict_cat(dx1, dx2, dim = -1)

        # return the reconstructed inputs and their gradients
        return x, dx
# custom autograd function driving the reversible blocks
class _ReversibleFunction(Function):
    @staticmethod
    def forward(ctx, x, blocks, kwargs):
        # pull the bookkeeping information out of kwargs
        input_keys = kwargs.pop('input_keys')
        split_dims = kwargs.pop('split_dims')
        # split the packed input back into per-degree tensors
        input_values = x.split(split_dims, dim = -1)

        # rebuild the degree-keyed dict
        x = dict(zip(input_keys, input_values))

        # stash everything needed for the backward pass
        ctx.kwargs = kwargs
        ctx.split_dims = split_dims
        ctx.input_keys = input_keys

        # restore the (d, m) layout, where m is twice the spherical order (the two reversible halves)
        x = {k: rearrange(v, '... (d m) -> ... d m', m = to_order(k) * 2) for k, v in x.items()}

        # run every reversible block
        for block in blocks:
            x = block(x, **kwargs)

        # save detached outputs and the blocks for the backward pass
        ctx.y = map_values(lambda t: t.detach(), x)
        ctx.blocks = blocks

        # flatten and concatenate the outputs back into a single tensor
        x = map_values(lambda t: rearrange(t, '... d m -> ... (d m)'), x)
        x = torch.cat(values(x), dim = -1)

        return x

    @staticmethod
    def backward(ctx, dy):
        # retrieve what was saved during the forward pass
        y = ctx.y
        kwargs = ctx.kwargs
        input_keys = ctx.input_keys
        split_dims = ctx.split_dims

        # split the incoming gradient back into a degree-keyed dict
        dy = dy.split(split_dims, dim = -1)
        dy = dict(zip(input_keys, dy))

        # restore the (d, m) layout
        dy = {k: rearrange(v, '... (d m) -> ... d m', m = to_order(k) * 2) for k, v in dy.items()}

        # walk the blocks in reverse, reconstructing activations and gradients
        for block in ctx.blocks[::-1]:
            y, dy = block.backward_pass(y, dy, **kwargs)

        # flatten and concatenate the gradients back into a single tensor
        dy = map_values(lambda t: rearrange(t, '... d m -> ... (d m)'), dy)
        dy = torch.cat(values(dy), dim = -1)

        # gradients for (x, blocks, kwargs) - only x receives one
        return dy, None, None

# sequential

# residual connection across a dict of degree-indexed tensors
def residual_fn(x, residual):
    out = {}

    for degree, tensor in x.items():
        out[degree] = tensor

        # skip degrees that have no residual counterpart
        if degree not in residual:
            continue

        # add in place when no gradients are being tracked, otherwise out of place
        if not any(t.requires_grad for t in (out[degree], residual[degree])):
            out[degree] += residual[degree]
        else:
            out[degree] = out[degree] + residual[degree]

    return out

# plain (non-reversible) sequence of attention / feedforward blocks
class SequentialSequence(Module):

    @beartype
    def __init__(
        self,
        blocks: List[Tuple[Module, Module]]
    ):
        super().__init__()

        # store each (attention, feedforward) pair as a ModuleList
        self.blocks = nn.ModuleList([nn.ModuleList([f, g]) for f, g in blocks])

    def forward(self, x, **kwargs):

        # attention followed by feedforward, each with a residual connection
        for attn, ff in self.blocks:
            x = residual_fn(attn(x, **kwargs), x)
            x = residual_fn(ff(x), x)

        return x

# reversible

# reversible sequence of attention / feedforward blocks
class ReversibleSequence(Module):

    @beartype
    def __init__(
        self,
        blocks: List[Tuple[Module, Module]]
    ):
        super().__init__()

        # wrap each (f, g) pair in a ReversibleBlock
        self.blocks = nn.ModuleList([ReversibleBlock(f, g) for (f, g) in blocks])

    def forward(self, x, **kwargs):
        blocks = self.blocks

        # merge into single tensor

        # duplicate each degree along the last dimension (the two reversible halves), then flatten
        x = map_values(lambda t: torch.cat((t, t), dim = -1), x)
        x = map_values(lambda t: rearrange(t, '... d m -> ... (d m)'), x)

        # remember the degree keys
        input_keys = x.keys()

        # per-degree sizes along the last dimension, needed to split the packed tensor back apart
        split_dims = tuple(map(lambda t: t.shape[-1], x.values()))
        block_kwargs = {'input_keys': input_keys, 'split_dims': split_dims, **kwargs}

        # concatenate everything into a single tensor
        x = torch.cat(values(x), dim = -1)

        # reversible function, tailored for equivariant network

        x = _ReversibleFunction.apply(x, blocks, block_kwargs)

        # reconstitute

        # split back into a degree-keyed dict
        x = dict(zip(input_keys, x.split(split_dims, dim = -1)))

        # average the two reversible halves back down to the original (d, m) layout
        x = {k: reduce(v, '... (d r m) -> ... d m', 'mean', r = 2, m = to_order(k)) for k, v in x.items()}

        return x

.\lucidrains\equiformer-pytorch\equiformer_pytorch\utils.py

# imports
from pathlib import Path

import time
import pickle
import gzip

import torch
import torch.nn.functional as F

import contextlib
from functools import wraps, lru_cache
from filelock import FileLock
from equiformer_pytorch.version import __version__

from einops import rearrange

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# identity function
def identity(t):
    return t

# return val if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# convert a degree to its spherical order (2 * degree + 1)
def to_order(degree):
    return 2 * degree + 1
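
So a degree-l (type-l) feature carries 2l + 1 components, for example (illustrative only):

# illustrative only
assert [to_order(l) for l in range(3)] == [1, 3, 5]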

# l2 normalize a tensor along the last dimension
def l2norm(t):
    return F.normalize(t, dim = -1)

# amount of padding on each side needed to center y within x
def pad_for_centering_y_to_x(x, y):
    assert y <= x
    total_pad = x - y
    assert (total_pad % 2) == 0
    return total_pad // 2

# slice that centers y within x
def slice_for_centering_y_to_x(x, y):
    pad = pad_for_centering_y_to_x(x, y)
    if pad == 0:
        return slice(None)
    return slice(pad, -pad)
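
For instance (illustrative only), centering a width-3 span inside a width-5 one pads or slices one element on each side:

# illustrative only
assert pad_for_centering_y_to_x(5, 3) == 1
assert slice_for_centering_y_to_x(5, 3) == slice(1, -1)
assert slice_for_centering_y_to_x(3, 3) == slice(None)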

# concatenate, treating a missing tensor as empty
def safe_cat(arr, el, dim):
    if not exists(arr):
        return el
    return torch.cat((arr, el), dim = dim)

# cast a value to a tuple of the given length
def cast_tuple(val, depth = 1):
    return val if isinstance(val, tuple) else (val,) * depth

# split a tensor into roughly equal chunks along a dimension
def fast_split(arr, splits, dim=0):
    axis_len = arr.shape[dim]
    splits = min(axis_len, max(splits, 1))
    chunk_size = axis_len // splits
    remainder = axis_len - chunk_size * splits
    s = 0
    for i in range(splits):
        adjust, remainder = 1 if remainder > 0 else 0, remainder - 1
        yield torch.narrow(arr, dim, s, chunk_size + adjust)
        s += chunk_size + adjust
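
A quick check of the chunk sizes (illustrative only): the remainder is spread one element at a time over the leading chunks.

# illustrative only
t = torch.arange(10)
chunks = list(fast_split(t, 3))
assert [c.shape[0] for c in chunks] == [4, 3, 3]
assert torch.equal(torch.cat(chunks), t)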

# masked mean along a dimension
def masked_mean(tensor, mask, dim = -1):
    if not exists(mask):
        return tensor.mean(dim = dim)

    diff_len = len(tensor.shape) - len(mask.shape)
    mask = mask[(..., *((None,) * diff_len))]
    tensor.masked_fill_(~mask, 0.)

    total_el = mask.sum(dim = dim)
    mean = tensor.sum(dim = dim) / total_el.clamp(min = 1.)
    mean.masked_fill_(total_el == 0, 0.)
    return mean
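
Note that masked_fill_ modifies the input tensor in place; masked positions are zeroed out and excluded from the average (illustrative only):

# illustrative only
t = torch.tensor([[1., 2., 3., 4.]])
mask = torch.tensor([[True, True, False, False]])
masked_mean(t, mask, dim = -1)   # tensor([1.5000]); t itself is now [[1., 2., 0., 0.]]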

# uniform random tensor in [min_val, max_val]
def rand_uniform(size, min_val, max_val):
    return torch.empty(size).uniform_(min_val, max_val)

# context manager for temporarily setting the default dtype

@contextlib.contextmanager
def torch_default_dtype(dtype):
    prev_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(prev_dtype)

# decorator casting the input to a torch tensor
def cast_torch_tensor(fn):
    @wraps(fn)
    def inner(t):
        if not torch.is_tensor(t):
            t = torch.tensor(t, dtype = torch.get_default_dtype())
        return fn(t)
    return inner

# benchmarking utility

def benchmark(fn):
    def inner(*args, **kwargs):
        start = time.time()
        res = fn(*args, **kwargs)
        diff = time.time() - start
        return diff, res
    return inner

# caching decorator backed by a dict

def cache(cache, key_fn):
    def cache_inner(fn):
        @wraps(fn)
        def inner(*args, **kwargs):
            key_name = key_fn(*args, **kwargs)
            if key_name in cache:
                return cache[key_name]
            res = fn(*args, **kwargs)
            cache[key_name] = res
            return res

        return inner
    return cache_inner

# cache results to a directory on disk

def cache_dir(dirname, maxsize=128):
    '''
    Cache a function with a directory

    :param dirname: the directory path
    :param maxsize: maximum size of the RAM cache (there is no limit for the directory cache)
    '''
    # decorator that wraps the target function
    def decorator(func):

        # layer an in-memory LRU cache on top of the on-disk cache
        @lru_cache(maxsize=maxsize)
        @wraps(func)
        def wrapper(*args, **kwargs):
            # if no cache directory was given, just call the function
            if not exists(dirname):
                return func(*args, **kwargs)

            # make sure the cache directory exists
            dirpath = Path(dirname)
            dirpath.mkdir(parents=True, exist_ok=True)

            # the index file maps call keys to cached filenames, guarded by a file lock
            indexfile = dirpath / 'index.pkl'
            lock = FileLock(str(dirpath / 'mutex'))

            with lock:
                # load the existing index, if any
                index = {}
                if indexfile.exists():
                    with open(indexfile, "rb") as file:
                        index = pickle.load(file)

                # build the cache key from the call arguments and the function defaults
                key = (args, frozenset(kwargs), func.__defaults__)

                # look up the filename for this key, or register a new one and persist the index
                if key in index:
                    filename = index[key]
                else:
                    index[key] = filename = f"{len(index)}.pkl.gz"
                    with open(indexfile, "wb") as file:
                        pickle.dump(index, file)

            filepath = dirpath / filename

            # if the result was already computed, load it from disk
            if filepath.exists():
                with lock:
                    with gzip.open(filepath, "rb") as file:
                        result = pickle.load(file)
                return result

            # otherwise compute the result and write it to disk
            print(f"compute {filename}... ", end="", flush=True)
            result = func(*args, **kwargs)
            print(f"save {filename}... ", end="", flush=True)

            with lock:
                with gzip.open(filepath, "wb") as file:
                    pickle.dump(result, file)

            print("done")

            return result
        return wrapper
    return decorator
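
# illustrative usage sketch (directory path and decorated function are hypothetical):
# expensive computations such as basis construction can be persisted across runs

@cache_dir('./equiformer_cache')   # pass None as dirname to skip the on-disk cache entirely
def slow_square_matrix(degree):
    return torch.randn(to_order(degree), to_order(degree))   # stand-in for an expensive computation

first  = slow_square_matrix(3)   # computed, then written as a gzipped pickle under ./equiformer_cache
second = slow_square_matrix(3)   # served from the in-memory lru_cache (or from disk in a new process)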

.\lucidrains\equiformer-pytorch\equiformer_pytorch\version.py

# define __version__ as the string '0.5.3'
__version__ = '0.5.3'

.\lucidrains\equiformer-pytorch\equiformer_pytorch\__init__.py

# expose the Equiformer class from the equiformer_pytorch.equiformer_pytorch module
from equiformer_pytorch.equiformer_pytorch import Equiformer

.\lucidrains\equiformer-pytorch\README.md

Equiformer - Pytorch (wip)

Implementation of the Equiformer, an SE3/E3 equivariant attention network that reaches new SOTA, and has been adopted by EquiFold (Prescient Design) for protein folding

The design of this seems to build off of SE3 Transformers, with the dot product attention replaced with MLP Attention and non-linear message passing from GATv2. It also does a depthwise tensor product for a bit more efficiency. If you think I am mistaken, please feel free to email me.

Update: There has been a new development that makes scaling the number of degrees for SE3 equivariant networks dramatically better! This paper first noted that by aligning the representations along the z-axis (or y-axis by some other convention), the spherical harmonics become sparse. This removes the m_f dimension from the equation. A follow up paper from Passaro et al. noted the Clebsch Gordan matrix has also become sparse, leading to removal of m_i and l_f. They also made the connection that the problem has been reduced from SO(3) to SO(2) after aligning the reps to one axis. Equiformer v2 (Official repository) leverages this in a transformer-like framework to reach new SOTA.

Will definitely be putting more work / exploration into this. For now, I've incorporated the tricks from the first two papers for Equiformer v1, save for complete conversion into SO(2).

Install

$ pip install equiformer-pytorch

Usage

import torch
from equiformer_pytorch import Equiformer

model = Equiformer(
    num_tokens = 24,
    dim = (4, 4, 2),               # dimensions per type, ascending, length must match number of degrees (num_degrees)
    dim_head = (4, 4, 4),          # dimension per attention head
    heads = (2, 2, 2),             # number of attention heads
    num_linear_attn_heads = 0,     # number of global linear attention heads, can see all the neighbors
    num_degrees = 3,               # number of degrees
    depth = 4,                     # depth of equivariant transformer
    attend_self = True,            # attending to self or not
    reduce_dim_out = True,         # whether to reduce out to dimension of 1, say for predicting new coordinates for type 1 features
    l2_dist_attention = False      # set to False to try out MLP attention
).cuda()

feats = torch.randint(0, 24, (1, 128)).cuda()
coors = torch.randn(1, 128, 3).cuda()
mask  = torch.ones(1, 128).bool().cuda()

out = model(feats, coors, mask) # (1, 128)

out.type0 # invariant type 0    - (1, 128)
out.type1 # equivariant type 1  - (1, 128, 3)

This repository also includes a way to decouple memory usage from depth using reversible networks. In other words, if you increase depth, the memory cost will stay constant at the usage of one equiformer transformer block (attention and feedforward).

import torch
from equiformer_pytorch import Equiformer

model = Equiformer(
    num_tokens = 24,
    dim = (4, 4, 2),
    dim_head = (4, 4, 4),
    heads = (2, 2, 2),
    num_degrees = 3,
    depth = 48,          # depth of 48 - just to show that it runs - in reality, seems to be quite unstable at higher depths, so architecture still needs more work
    reversible = True,   # just set this to True to use https://arxiv.org/abs/1707.04585
).cuda()

feats = torch.randint(0, 24, (1, 128)).cuda()
coors = torch.randn(1, 128, 3).cuda()
mask  = torch.ones(1, 128).bool().cuda()

out = model(feats, coors, mask)

out.type0.sum().backward()

Edges

with edges, ex. atomic bonds

import torch
from equiformer_pytorch import Equiformer

model = Equiformer(
    num_tokens = 28,
    dim = 64,
    num_edge_tokens = 4,       # number of edge type, say 4 bond types
    edge_dim = 16,             # dimension of edge embedding
    depth = 2,
    input_degrees = 1,
    num_degrees = 3,
    reduce_dim_out = True
)

atoms = torch.randint(0, 28, (2, 32))
bonds = torch.randint(0, 4, (2, 32, 32))
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

out = model(atoms, coors, mask, edges = bonds)

out.type0 # (2, 32)
out.type1 # (2, 32, 3)

with adjacency matrix

import torch
from equiformer_pytorch import Equiformer

model = Equiformer(
    dim = 32,
    heads = 8,
    depth = 1,
    dim_head = 64,
    num_degrees = 2,
    valid_radius = 10,
    reduce_dim_out = True,
    attend_sparse_neighbors = True,  # this must be set to true, in which case it will assert that you pass in the adjacency matrix
    num_neighbors = 0,               # if you set this to 0, it will only consider the connected neighbors as defined by the adjacency matrix. but if you set a value greater than 0, it will continue to fetch the closest points up to this many, excluding the ones already specified by the adjacency matrix
    num_adj_degrees_embed = 2,       # this will derive the second degree connections and embed it correctly
    max_sparse_neighbors = 8         # you can cap the number of neighbors, sampled from within your sparse set of neighbors as defined by the adjacency matrix, if specified
)

feats = torch.randn(1, 128, 32)
coors = torch.randn(1, 128, 3)
mask  = torch.ones(1, 128).bool()

# placeholder adjacency matrix
# naively assuming the sequence is one long chain (128, 128)

i = torch.arange(128)
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))

out = model(feats, coors, mask, adj_mat = adj_mat)

out.type0 # (1, 128)
out.type1 # (1, 128, 3)

Appreciation

  • StabilityAI for the generous sponsorship, as well as my other sponsors out there

Testing

Tests for equivariance etc

$ python setup.py test

Example

First install sidechainnet

$ pip install sidechainnet

Then run the protein backbone denoising task

$ python denoise.py

Todo

  • move xi and xj separate project and sum logic into Conv class

  • move self interacting key / value production into Conv, fix no pooling in conv with self interaction

  • go with a naive way to split up contribution from input degrees for DTP

  • for dot product attention in higher types, try euclidean distance

  • consider an all-neighbors attention layer just for type0, using linear attention

  • integrate the new finding from spherical channels paper, followed up by so(3) -> so(2) paper, which reduces the computation from O(L^6) -> O(L^3)!

    • add rotation matrix -> ZYZ euler angles
    • function for deriving rotation matrix for r_ij -> (0, 1, 0)
    • prepare get_basis to return D for rotating representations to (0, 1, 0) to greatly simplify spherical harmonics
    • add tests for batch rotating vectors to align with another - handle edge cases (0, 0, 0)?
    • redo get_basis to only calculate spherical harmonics Y for (0, 1, 0) and cache
    • do the further optimization to remove clebsch gordan (since m_i only depends on m_o), as noted in eSCN paper
    • validate one can train at higher degrees
    • figure out the whole linear bijection argument in appendix of eSCN and why parameterized l_f can be removed
    • figure out why training NaNs with float32
    • refactor into full so3 -> so2 linear layer, as proposed in eSCN paper
    • add equiformer v2, and start looking into equivariant protein backbone diffusion again

Citations

@article{Liao2022EquiformerEG,
    title   = {Equiformer: Equivariant Graph Attention Transformer for 3D Atomistic Graphs},
    author  = {Yi Liao and Tess E. Smidt},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2206.11990}
}
@article {Lee2022.10.07.511322,
    author  = {Lee, Jae Hyeon and Yadollahpour, Payman and Watkins, Andrew and Frey, Nathan C. and Leaver-Fay, Andrew and Ra, Stephen and Cho, Kyunghyun and Gligorijevic, Vladimir and Regev, Aviv and Bonneau, Richard},
    title   = {EquiFold: Protein Structure Prediction with a Novel Coarse-Grained Structure Representation},
    elocation-id = {2022.10.07.511322},
    year    = {2022},
    doi     = {10.1101/2022.10.07.511322},
    publisher = {Cold Spring Harbor Laboratory},
    URL     = {https://www.biorxiv.org/content/early/2022/10/08/2022.10.07.511322},
    eprint  = {https://www.biorxiv.org/content/early/2022/10/08/2022.10.07.511322.full.pdf},
    journal = {bioRxiv}
}
@article{Shazeer2019FastTD,
    title   = {Fast Transformer Decoding: One Write-Head is All You Need},
    author  = {Noam M. Shazeer},
    journal = {ArXiv},
    year    = {2019},
    volume  = {abs/1911.02150}
}
@misc{ding2021cogview,
    title   = {CogView: Mastering Text-to-Image Generation via Transformers},
    author  = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
    year    = {2021},
    eprint  = {2105.13290},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@inproceedings{Kim2020TheLC,
    title   = {The Lipschitz Constant of Self-Attention},
    author  = {Hyunjik Kim and George Papamakarios and Andriy Mnih},
    booktitle = {International Conference on Machine Learning},
    year    = {2020}
}
@article{Zitnick2022SphericalCF,
    title   = {Spherical Channels for Modeling Atomic Interactions},
    author  = {C. Lawrence Zitnick and Abhishek Das and Adeesh Kolluru and Janice Lan and Muhammed Shuaibi and Anuroop Sriram and Zachary W. Ulissi and Brandon C. Wood},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2206.14331}
}
@article{Passaro2023ReducingSC,
    title   = {Reducing SO(3) Convolutions to SO(2) for Efficient Equivariant GNNs},
    author  = {Saro Passaro and C. Lawrence Zitnick},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2302.03655}
}
@inproceedings{Gomez2017TheRR,
    title   = {The Reversible Residual Network: Backpropagation Without Storing Activations},
    author  = {Aidan N. Gomez and Mengye Ren and Raquel Urtasun and Roger Baker Grosse},
    booktitle = {NIPS},
    year    = {2017}
}
@article{Bondarenko2023QuantizableTR,
    title   = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
    author  = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2306.12929},
    url     = {https://api.semanticscholar.org/CorpusID:259224568}
}
@inproceedings{Arora2023ZoologyMA,
    title   = {Zoology: Measuring and Improving Recall in Efficient Language Models},
    author  = {Simran Arora and Sabri Eyuboglu and Aman Timalsina and Isys Johnson and Michael Poli and James Zou and Atri Rudra and Christopher R'e},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:266149332}
}