Lucidrains-系列项目源码解析-八十八-

63 阅读28分钟

Lucidrains 系列项目源码解析(八十八)

.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\basis.py

# 导入必要的库
import os
from math import pi
import torch
from torch import einsum
from einops import rearrange
from itertools import product
from contextlib import contextmanager

# 导入自定义库
from se3_transformer_pytorch.irr_repr import irr_repr, spherical_harmonics
from se3_transformer_pytorch.utils import torch_default_dtype, cache_dir, exists, default, to_order
from se3_transformer_pytorch.spherical_harmonics import clear_spherical_harmonics_cache

# constants

# cache path for precomputed basis matrices; defaults to
# ~/.cache.equivariant_attention, overridable via the CACHE_PATH env var
CACHE_PATH = default(os.getenv('CACHE_PATH'), os.path.expanduser('~/.cache.equivariant_attention'))
# setting the CLEAR_CACHE env var disables disk caching entirely (path becomes None)
CACHE_PATH = CACHE_PATH if not exists(os.environ.get('CLEAR_CACHE')) else None

# fixed set of random Euler-angle triples used when solving for the basis
# todo (figure ot why this was hard coded in official repo)
RANDOM_ANGLES = [ 
    [4.41301023, 5.56684102, 4.59384642],
    [4.93325116, 6.12697327, 4.14574096],
    [0.53878964, 4.09050444, 5.36539036],
    [2.16017393, 3.48835314, 5.55174441],
    [2.52385107, 0.2908958, 3.90040975]
]

# 辅助函数

# no-op context manager, used where a real context (e.g. torch.no_grad) is optional
@contextmanager
def null_context():
    """Do nothing on enter and exit; stand-in for an optional context manager."""
    yield

# 函数定义

def get_matrix_kernel(A, eps = 1e-10):
    '''
    Compute an orthonormal basis (x_1, x_2, ...) of the kernel of matrix A:
        A x_i = 0
        scalar_product(x_i, x_j) = delta_ij

    :param A: matrix of shape [m, n] (stacked tall, m >= n, in this codebase)
    :param eps: singular values below this threshold count as zero
    :return: matrix whose rows are basis vectors of the kernel of A
    '''
    # torch.svd is deprecated; torch.linalg.svd returns V^H directly, so no
    # transpose is needed before selecting the null-space rows
    _u, s, vh = torch.linalg.svd(A, full_matrices = False)
    # rows of V^H whose singular value vanishes span the (numerical) null space
    kernel = vh[s < eps]
    return kernel

def get_matrices_kernel(As, eps = 1e-10):
    '''
    Common kernel of all matrices in As: stack them vertically and take the
    kernel of the stacked matrix.
    '''
    stacked = torch.cat(As, dim=0)
    return get_matrix_kernel(stacked, eps)

def get_spherical_from_cartesian(cartesian, divide_radius_by = 1.0):
    """
    Convert cartesian coordinates to spherical (radius, alpha, beta) along the
    last axis.

    # ON ANGLE CONVENTION
    #
    # sh has following convention for angles:
    # :param theta: the colatitude / polar angle, ranging from 0(North Pole, (X, Y, Z) = (0, 0, 1)) to pi(South Pole, (X, Y, Z) = (0, 0, -1)).
    # :param phi: the longitude / azimuthal angle, ranging from 0 to 2 pi.
    #
    # the 3D steerable CNN code therefore (probably) has the following convention for alpha and beta:
    # beta = pi - theta; ranging from 0(South Pole, (X, Y, Z) = (0, 0, -1)) to pi(North Pole, (X, Y, Z) = (0, 0, 1).
    # alpha = phi
    #
    """
    # output layout along the last axis: (radius, alpha, beta)
    out = torch.zeros_like(cartesian)

    # the input stores (y, z, x) along its last axis
    x = cartesian[..., 2]
    y = cartesian[..., 0]
    z = cartesian[..., 1]

    # squared length of the projection onto the xy plane
    r_xy_sq = x ** 2 + y ** 2

    # beta: polar angle measured from the +Z axis downwards
    out[..., 2] = torch.atan2(torch.sqrt(r_xy_sq), z)

    # alpha: azimuthal angle in the xy plane
    out[..., 1] = torch.atan2(y, x)

    # full radius, optionally rescaled
    r = torch.sqrt(r_xy_sq + z ** 2)
    if divide_radius_by != 1.0:
        r = r / divide_radius_by
    out[..., 0] = r

    return out

def kron(a, b):
    """
    Kronecker product of (batched) matrices a and b:
    result shape is [..., a_rows * b_rows, a_cols * b_cols].
    """
    pairwise = einsum('... i j, ... k l -> ... i k j l', a, b)
    i, k, j, l = pairwise.shape[-4:]
    # flatten (i k) and (j l) into the two matrix dimensions
    return pairwise.reshape(*pairwise.shape[:-4], i * k, j * l)

def get_R_tensor(order_out, order_in, a, b, c):
    """Kronecker product of the output- and input-order irreducible representations
    for Euler angles (a, b, c). (Fixed: the original line was missing its closing
    parenthesis, a syntax error.)"""
    return kron(irr_repr(order_out, a, b, c), irr_repr(order_in, a, b, c))

def sylvester_submatrix(order_out, order_in, J, a, b, c):
    ''' Generate the Kronecker-product matrix for solving the Sylvester equation in subspace J. '''
    R_tensor = get_R_tensor(order_out, order_in, a, b, c)  # [m_out * m_in, m_out * m_in]
    R_irrep_J = irr_repr(J, a, b, c)  # [m, m]

    R_tensor_identity = torch.eye(R_tensor.shape[0])
    # fixed: the line below was missing its closing parenthesis (syntax error)
    R_irrep_J_identity = torch.eye(R_irrep_J.shape[0])

    # Kronecker products encode the Sylvester equation  R_tensor X - X R_irrep_J = 0
    return kron(R_tensor, R_irrep_J_identity) - kron(R_tensor_identity, R_irrep_J.t())  # [(m_out * m_in) * m, (m_out * m_in) * m]
# the three comments in the scraped source described decorators (disk cache,
# float64 default dtype, no grad) that were stripped from the code; restored here
@cache_dir(CACHE_PATH)
@torch_default_dtype(torch.float64)
@torch.no_grad()
def basis_transformation_Q_J(J, order_in, order_out, random_angles = RANDOM_ANGLES):
    """
    :param J: order of the spherical harmonics
    :param order_in: order of the input representation
    :param order_out: order of the output representation
    :return: one part of the Q^-1 matrix of the article
    """
    # Sylvester submatrices for several random rotations; their common null
    # space is the sought intertwiner subspace
    sylvester_submatrices = [sylvester_submatrix(order_out, order_in, J, a, b, c) for a, b, c in random_angles]
    null_space = get_matrices_kernel(sylvester_submatrices)
    # the solution subspace must be one-dimensional
    assert null_space.size(0) == 1, null_space.size()
    Q_J = null_space[0]  # [(m_out * m_in) * m]
    # reshape the flat null-space vector into the Q_J matrix
    Q_J = Q_J.view(to_order(order_out) * to_order(order_in), to_order(J))  # [m_out * m_in, m]
    return Q_J.float()  # [m_out * m_in, m]

# precompute spherical harmonics up to the maximal degree used in the network
def precompute_sh(r_ij, max_J):
    """
    Precompute spherical harmonics for every degree up to max_J.

    :param r_ij: relative positions in spherical coordinates (radius, alpha, beta)
    :param max_J: maximum degree used anywhere in the network
    :return: dict mapping J -> tensor; each entry has shape [B, N, K, 2J + 1]
    """
    ALPHA_INDEX, BETA_INDEX = 1, 2
    Y_Js = {}
    for J in range(max_J + 1):
        Y_Js[J] = spherical_harmonics(J, r_ij[..., ALPHA_INDEX], r_ij[..., BETA_INDEX])
    # the harmonics implementation memoizes intermediates; reset it between calls
    clear_spherical_harmonics_cache()
    return Y_Js

# get equivariant weight basis
def get_basis(r_ij, max_degree, differentiable = False):
    """Return equivariant weight basis (basis)

    Call this function *once* at the start of each forward pass of the model.
    It computes the equivariant weight basis, W_J^lk(x), and internodal 
    distances, needed to compute varphi_J^lk(x), of eqn 8 of
    https://arxiv.org/pdf/2006.10503.pdf. The return values of this function 
    can be shared as input across all SE(3)-Transformer layers in a model.

    Args:
        r_ij: relative positional vectors
        max_degree: non-negative int for degree of highest feature-type
        differentiable: whether r_ij should receive gradients from basis
    Returns:
        dict of equivariant bases, keys are in form '<d_in>,<d_out>'
    """

    # fixed: the condition was inverted ("if not differentiable" selected the
    # null context) — gradients must be tracked exactly when `differentiable`
    # is requested, and suppressed otherwise
    context = null_context if differentiable else torch.no_grad

    device, dtype = r_ij.device, r_ij.dtype

    with context():
        # cartesian -> spherical coordinates
        r_ij = get_spherical_from_cartesian(r_ij)

        # spherical harmonics up to degree 2 * max_degree
        Y = precompute_sh(r_ij, 2 * max_degree)

        # equivariant basis (dict['<d_in>,<d_out>'])
        basis = {}
        for d_in, d_out in product(range(max_degree+1), range(max_degree+1)):
            K_Js = []
            for J in range(abs(d_in - d_out), d_in + d_out + 1):
                # spherical-harmonic transformation matrix Q_J
                Q_J = basis_transformation_Q_J(J, d_in, d_out)
                Q_J = Q_J.type(dtype).to(device)

                # create kernel from spherical harmonics
                K_J = torch.matmul(Y[J], Q_J.T)
                K_Js.append(K_J)

            # reshape so the basis can be linearly combined with a dot product
            K_Js = torch.stack(K_Js, dim = -1)
            size = (*r_ij.shape[:-1], 1, to_order(d_out), 1, to_order(d_in), to_order(min(d_in,d_out)))
            basis[f'{d_in},{d_out}'] = K_Js.view(*size)

    # extra detach for safety when gradients are not requested
    if not differentiable:
        for k, v in basis.items():
            basis[k] = v.detach()

    return basis

.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\irr_repr.py

# 导入所需的库
import os
import numpy as np
import torch
from torch import sin, cos, atan2, acos
from math import pi
from pathlib import Path
from functools import wraps

# 导入自定义的函数和类
from se3_transformer_pytorch.utils import exists, default, cast_torch_tensor, to_order
from se3_transformer_pytorch.spherical_harmonics import get_spherical_harmonics, clear_spherical_harmonics_cache

# path to the data files shipped with the package
DATA_PATH = path = Path(os.path.dirname(__file__)) / 'data'

# try the torch-format precomputed J_dense matrices first
try:
    path = DATA_PATH / 'J_dense.pt'
    Jd = torch.load(str(path))
except Exception:
    # fixed: was a bare `except:` (which would also swallow KeyboardInterrupt).
    # fall back to the numpy copy and convert each array to a torch tensor
    path = DATA_PATH / 'J_dense.npy'
    Jd_np = np.load(str(path), allow_pickle = True)
    Jd = list(map(torch.from_numpy, Jd_np))

# Wigner D matrix for a batch of ZYZ Euler angles
def wigner_d_matrix(degree, alpha, beta, gamma, dtype = None, device = None):
    """Create wigner D matrices for batch of ZYZ Euler anglers for degree l."""
    J = Jd[degree].type(dtype).to(device)
    size = to_order(degree)
    # Z rotations for each angle; J conjugates them into the middle Y rotation
    rot_alpha = z_rot_mat(alpha, degree)
    rot_beta = z_rot_mat(beta, degree)
    rot_gamma = z_rot_mat(gamma, degree)
    return (rot_alpha @ J @ rot_beta @ J @ rot_gamma).view(size, size)

# rotation about the Z axis in the degree-l representation basis
def z_rot_mat(angle, l):
    """Build the representation of a rotation by `angle` about the Z axis at degree l."""
    device, dtype = angle.device, angle.dtype
    size = to_order(l)
    mat = angle.new_zeros((size, size))

    diag_inds = torch.arange(0, size, 1, dtype=torch.long, device=device)
    anti_diag_inds = torch.arange(2 * l, -1, -1, dtype=torch.long, device=device)
    freqs = torch.arange(l, -l - 1, -1, dtype=dtype, device=device)[None]

    # sin terms on the anti-diagonal first, then cos terms on the diagonal
    # (the center element is written twice; cos wins, as in the original)
    mat[diag_inds, anti_diag_inds] = sin(freqs * angle[None])
    mat[diag_inds, diag_inds] = cos(freqs * angle[None])
    return mat

# irreducible representation of SO(3)
def irr_repr(order, alpha, beta, gamma, dtype = None):
    """
    irreducible representation of SO3
    - compatible with compose and spherical_harmonics
    """
    # identity wrapped with cast_torch_tensor coerces plain numbers to tensors
    to_tensor = cast_torch_tensor(lambda t: t)
    alpha, beta, gamma = [to_tensor(angle) for angle in (alpha, beta, gamma)]
    dtype = default(dtype, torch.get_default_dtype())
    return wigner_d_matrix(order, alpha, beta, gamma, dtype = dtype)

# 3x3 cartesian rotation about the Z axis
@cast_torch_tensor
def rot_z(gamma):
    '''
    Rotation around Z axis
    '''
    c, s = cos(gamma), sin(gamma)
    return torch.tensor([
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1]
    ], dtype=gamma.dtype)

# 3x3 cartesian rotation about the Y axis
@cast_torch_tensor
def rot_y(beta):
    '''
    Rotation around Y axis
    '''
    c, s = cos(beta), sin(beta)
    return torch.tensor([
        [c, 0, s],
        [0, 1, 0],
        [-s, 0, c]
    ], dtype=beta.dtype)

# convert a point (x, y, z) on the sphere into the angles (alpha, beta)
@cast_torch_tensor
def x_to_alpha_beta(x):
    '''
    Convert point (x, y, z) on the sphere into (alpha, beta)
    '''
    # normalize onto the unit sphere before reading off the angles
    unit = x / torch.norm(x)
    beta = acos(unit[2])
    alpha = atan2(unit[1], unit[0])
    return (alpha, beta)

# rotation from ZYZ Euler angles (3x3 cartesian matrix)
def rot(alpha, beta, gamma):
    '''
    ZYZ Euler angles rotation
    '''
    m = rot_z(alpha)
    m = m @ rot_y(beta)
    return m @ rot_z(gamma)

# compose two rotations, each given as ZYZ Euler angles
def compose(a1, b1, c1, a2, b2, c2):
    """
    (a, b, c) = (a1, b1, c1) composed with (a2, b2, c2)
    """
    # multiply the two rotation matrices
    comp = rot(a1, b1, c1) @ rot(a2, b2, c2)
    # the image of the north pole determines alpha and beta of the composition
    xyz = comp @ torch.tensor([0, 0, 1.])
    a, b = x_to_alpha_beta(xyz)
    # undo the (a, b) part; what remains is a pure Z rotation by gamma
    rotz = rot(0, -b, -a) @ comp
    c = atan2(rotz[1, 0], rotz[0, 0])
    return a, b, c

# spherical harmonics under this codebase's angle convention
def spherical_harmonics(order, alpha, beta, dtype = None):
    """Evaluate spherical harmonics with theta = pi - beta and phi = alpha."""
    theta, phi = pi - beta, alpha
    return get_spherical_harmonics(order, theta = theta, phi = phi)

.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\reversible.py

import torch  # 导入 PyTorch 库
import torch.nn as nn  # 导入 PyTorch 中的神经网络模块
from torch.autograd.function import Function  # 导入 PyTorch 中的自动微分函数
from torch.utils.checkpoint import get_device_states, set_device_states  # 导入 PyTorch 中的检查点函数

# 辅助函数

def map_values(fn, x):
    """Apply fn to every value of dict x, returning a new dict with the same keys."""
    return {k: fn(v) for k, v in x.items()}

def dict_chunk(x, chunks, dim):
    """Chunk every tensor value of x along `dim`; returns two dicts (callers pass chunks == 2)."""
    firsts, seconds = {}, {}
    for key, tensor in x.items():
        first, second = tensor.chunk(chunks, dim=dim)
        firsts[key] = first
        seconds[key] = second
    return firsts, seconds

def dict_sum(x, y):
    """Elementwise x[k] + y[k] over the keys of x."""
    return {k: x[k] + y[k] for k in x.keys()}

def dict_subtract(x, y):
    """Elementwise x[k] - y[k] over the keys of x."""
    return {k: x[k] - y[k] for k in x.keys()}

def dict_cat(x, y, dim):
    """Concatenate matching tensor values of x and y along `dim`."""
    return {k: torch.cat((v, y[k]), dim=dim) for k, v in x.items()}

def dict_set_(x, key, value):
    """In-place: set attribute `key` to `value` on every value of x (e.g. requires_grad)."""
    for v in x.values():
        setattr(v, key, value)
def dict_backwards_(outputs, grad_tensors):
    """Backpropagate each output tensor against its matching gradient, retaining the graph."""
    for key, output in outputs.items():
        torch.autograd.backward(output, grad_tensors[key], retain_graph=True)

def dict_del_(x):
    """Drop local references to a dict and its values (an early-free hint for the GC)."""
    for value in x.values():
        del value
    del x

def values(d):
    """Return the dict's values as a list, in insertion order (idiomatic form
    of the previous manual comprehension)."""
    return list(d.values())

# records and replays RNG state so a wrapped module's forward can be repeated
# exactly; see https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
class Deterministic(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        # CPU RNG state snapshot taken at the recorded forward
        self.cpu_state = None
        # whether CUDA was initialized during the recorded forward
        self.cuda_in_fwd = None
        # GPU devices touched during the recorded forward, and their RNG states
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):  # snapshot current RNG state
        self.cpu_state = torch.get_rng_state()
        # NOTE(review): torch.cuda._initialized is a private attribute — may
        # change across torch versions; confirm when upgrading
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng=False, set_rng=False, **kwargs):
        # record_rng: snapshot RNG before running; set_rng: replay under the snapshot
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        # replay path: fork the RNG, restore the recorded states, then run
        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)

# inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
# once multi-GPU is working, refactor and send PR back to source
class ReversibleBlock(nn.Module):
    """Reversible residual block over dicts of tensors:
    y1 = x1 + f(x2); y2 = x2 + g(y1).
    Inputs are reconstructed in backward_pass, so activations need not be stored."""
    def __init__(self, f, g):
        super().__init__()
        # wrap f and g so their RNG state can be recorded and replayed exactly
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, **kwargs):
        training = self.training
        # split every value into the two reversible streams along the feature dim
        x1, x2 = dict_chunk(x, 2, dim=-1)
        y1, y2 = None, None

        # no grad needed here: backward_pass reconstructs the inputs itself
        with torch.no_grad():
            y1 = dict_sum(x1, self.f(x2, record_rng=training, **kwargs))
            y2 = dict_sum(x2, self.g(y1, record_rng=training))

        return dict_cat(y1, y2, dim=-1)
    # manual backward: reconstructs (x, dx) from (y, dy) without stored activations
    def backward_pass(self, y, dy, **kwargs):
        # split the output back into its two streams
        y1, y2 = dict_chunk(y, 2, dim = -1)
        # drop the local reference to the merged dict
        dict_del_(y)

        # split the incoming gradient the same way
        dy1, dy2 = dict_chunk(dy, 2, dim = -1)
        dict_del_(dy)

        # recompute g(y1) with the recorded RNG and backprop dy2 through it
        with torch.enable_grad():
            dict_set_(y1, 'requires_grad', True)
            gy1 = self.g(y1, set_rng = True)
            dict_backwards_(gy1, dy2)

        # reconstruct x2 = y2 - g(y1) and accumulate dx1 = dy1 + dL/dy1
        with torch.no_grad():
            x2 = dict_subtract(y2, gy1)
            dict_del_(y2)
            dict_del_(gy1)

            dx1 = dict_sum(dy1, map_values(lambda t: t.grad, y1))
            # clear the consumed gradients
            dict_del_(dy1)
            dict_set_(y1, 'grad', None)

        # recompute f(x2) with the recorded RNG and backprop dx1 through it
        with torch.enable_grad():
            dict_set_(x2, 'requires_grad', True)
            fx2 = self.f(x2, set_rng = True, **kwargs)
            dict_backwards_(fx2, dx1)

        # reconstruct x1 = y1 - f(x2) and accumulate dx2 = dy2 + dL/dx2
        with torch.no_grad():
            x1 = dict_subtract(y1, fx2)
            dict_del_(y1)
            dict_del_(fx2)

            dx2 = dict_sum(dy2, map_values(lambda t: t.grad, x2))
            dict_del_(dy2)
            dict_set_(x2, 'grad', None)

            # detach so the reconstructed inputs do not keep the recompute graph alive
            x2 = map_values(lambda t: t.detach(), x2)

            # reassemble the full input and its gradient
            x = dict_cat(x1, x2, dim = -1)
            dx = dict_cat(dx1, dx2, dim = -1)

        return x, dx
class _ReversibleFunction(Function):
    """Custom autograd Function that runs a stack of ReversibleBlocks forward
    without saving activations, then reconstructs them during backward."""
    @staticmethod
    def forward(ctx, x, blocks, kwargs):
        # recover the dict structure from the flat input tensor
        input_keys = kwargs.pop('input_keys')
        split_dims = kwargs.pop('split_dims')
        input_values = x.split(split_dims, dim = -1)
        x = dict(zip(input_keys, input_values))

        # stash everything backward() will need
        ctx.kwargs = kwargs
        ctx.split_dims = split_dims
        ctx.input_keys = input_keys

        for block in blocks:
            x = block(x, **kwargs)

        # keep detached outputs: backward reconstructs the inputs from them
        ctx.y = map_values(lambda t: t.detach(), x)
        ctx.blocks = blocks

        # flatten back to a single tensor for autograd
        x = torch.cat(values(x), dim = -1)
        return x

    @staticmethod
    def backward(ctx, dy):
        y = ctx.y
        kwargs = ctx.kwargs
        input_keys = ctx.input_keys
        split_dims = ctx.split_dims

        # restore the dict structure of the incoming gradient
        dy = dy.split(split_dims, dim = -1)
        dy = dict(zip(input_keys, dy))

        # walk the blocks in reverse, reconstructing inputs and gradients
        for block in ctx.blocks[::-1]:
            y, dy = block.backward_pass(y, dy, **kwargs)

        # flatten the gradient; `blocks` and `kwargs` receive no gradients
        dy = torch.cat(values(dy), dim = -1)
        return dy, None, None



class SequentialSequence(nn.Module):
    """Plain (non-reversible) stack of (attention, feedforward) layer pairs."""
    def __init__(self, blocks):
        super().__init__()
        # blocks: iterable of (attn, ff) callables, applied in order
        self.blocks = blocks

    def forward(self, x, **kwargs):
        """Run attention (with kwargs) followed by feedforward for every pair."""
        for attn_layer, ff_layer in self.blocks:
            x = ff_layer(attn_layer(x, **kwargs))
        return x



class ReversibleSequence(nn.Module):
    """Stack of ReversibleBlocks executed through _ReversibleFunction,
    trading recompute for memory (activations are reconstructed, not stored)."""
    def __init__(self, blocks):
        super().__init__()
        # each (f, g) pair becomes one reversible block
        self.blocks = nn.ModuleList([ReversibleBlock(f, g) for (f, g) in blocks])

    def forward(self, x, **kwargs):
        blocks = self.blocks

        # duplicate each value so it carries the two reversible streams side by side
        x = map_values(lambda t: torch.cat((t, t), dim = -1), x)

        # remember how to rebuild the dict from the flat tensor
        input_keys = x.keys()
        split_dims = tuple(map(lambda t: t.shape[-1], x.values()))
        block_kwargs = {'input_keys': input_keys, 'split_dims': split_dims, **kwargs}

        # flatten all values into one tensor for the autograd Function
        x = torch.cat(values(x), dim = -1)

        # run all blocks through the activation-free autograd function
        x = _ReversibleFunction.apply(x, blocks, block_kwargs)

        # split back into a dict, then average the two streams of each value
        x = dict(zip(input_keys, x.split(split_dims, dim = -1)))
        x = map_values(lambda t: torch.stack(t.chunk(2, dim = -1)).mean(dim = 0), x)
        return x

.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\rotary.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat

# sinusoidal frequency table for rotary position embeddings
class SinusoidalEmbeddings(nn.Module):
    def __init__(self, dim):
        """dim: embedding dimension; dim // 2 distinct frequencies are generated."""
        super().__init__()
        # classic transformer inverse-frequency schedule
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, t):
        """Map positions t to angles of shape (*t.shape, dim); each frequency appears twice."""
        angles = t[..., None].float() * self.inv_freq[None, :]
        # duplicate each frequency consecutively (same layout as the original
        # einops repeat '... d -> ... (d r)' with r = 2)
        return torch.repeat_interleave(angles, 2, dim = -1)

# rotate interleaved channel pairs by 90 degrees: (x1, x2) -> (-x2, x1)
def rotate_half(x):
    """x has shape (..., 2d, m); consecutive pairs along dim -2 are treated as (x1, x2)."""
    *lead, paired_dim, m = x.shape
    # split dim -2 into (d, 2); same layout as einops '... (d j) m -> ... d j m'
    paired = x.reshape(*lead, paired_dim // 2, 2, m)
    x1, x2 = paired[..., 0, :], paired[..., 1, :]
    return torch.cat((-x2, x1), dim = -2)

# apply a rotary position embedding to the leading rows of t
def apply_rotary_pos_emb(t, freqs):
    """Rotate the first rot_dim rows of t by the given angles; the rest pass through.
    NOTE(review): assumes freqs broadcasts against t[..., :rot_dim, :] — confirm at call sites."""
    rot_dim = freqs.shape[-2]
    t_rotated, t_passthrough = t[..., :rot_dim, :], t[..., rot_dim:, :]
    # standard rotary update: t * cos + rotate_half(t) * sin
    t_rotated = (t_rotated * freqs.cos()) + (rotate_half(t_rotated) * freqs.sin())
    return torch.cat((t_rotated, t_passthrough), dim = -2)

.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\se3_transformer_pytorch.py

# 从 math 模块中导入 sqrt 函数
from math import sqrt
# 从 itertools 模块中导入 product 函数
from itertools import product
# 从 collections 模块中导入 namedtuple 类
from collections import namedtuple

# 导入 torch 库
import torch
# 从 torch.nn.functional 模块中导入 F 别名
import torch.nn.functional as F
# 从 torch 模块中导入 nn 和 einsum 函数
from torch import nn, einsum

# 导入自定义模块
from se3_transformer_pytorch.basis import get_basis
from se3_transformer_pytorch.utils import exists, default, uniq, map_values, batched_index_select, masked_mean, to_order, fourier_encode, cast_tuple, safe_cat, fast_split, rand_uniform, broadcat
from se3_transformer_pytorch.reversible import ReversibleSequence, SequentialSequence
from se3_transformer_pytorch.rotary import SinusoidalEmbeddings, apply_rotary_pos_emb

# 从 einops 模块中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat

# one fiber element: a feature degree and its channel dimension
# NOTE(review): the first field is named 'degrees' (plural) though it holds a single degree
FiberEl = namedtuple('FiberEl', ['degrees', 'dim'])

# a fiber describes the collection of (degree, dim) feature types in a layer
class Fiber(nn.Module):
    def __init__(
        self,
        structure
    ):
        """structure: dict {degree: dim} or list of (degree, dim) pairs."""
        super().__init__()
        if isinstance(structure, dict):
            structure = [FiberEl(degree, dim) for degree, dim in structure.items()]
        self.structure = structure

    # unique channel dimensions across all degrees
    @property
    def dims(self):
        return uniq(map(lambda t: t[1], self.structure))

    # iterator over the degrees present in this fiber (fresh on each access)
    @property
    def degrees(self):
        return map(lambda t: t[0], self.structure)

    @staticmethod
    def create(num_degrees, dim):
        """Build a fiber for degrees 0..num_degrees-1; dim is an int or a per-degree tuple."""
        dim_tuple = dim if isinstance(dim, tuple) else ((dim,) * num_degrees)
        return Fiber([FiberEl(degree, dim) for degree, dim in zip(range(num_degrees), dim_tuple)])

    def __getitem__(self, degree):
        """Channel dimension at `degree`."""
        return dict(self.structure)[degree]

    def __iter__(self):
        return iter(self.structure)

    def __mul__(self, fiber):
        """Cartesian product of the two fibers' (degree, dim) pairs."""
        return product(self.structure, fiber.structure)

    def __and__(self, fiber):
        """Degrees common to both fibers, as (degree, dim_self, dim_other) triples.
        (Removed an unused local that materialized `fiber.degrees` without consuming it.)"""
        out = []
        for degree, dim in self:
            if degree in fiber.degrees:
                dim_out = fiber[degree]
                out.append((degree, dim, dim_out))
        return out

# device and dtype of an arbitrary tensor in a feature dict
def get_tensor_device_and_dtype(features):
    """Return (device, dtype) of the first-inserted tensor value of `features`
    (idiomatic `values()` access instead of indexing into `items()`)."""
    first_tensor = next(iter(features.values()))
    return first_tensor.device, first_tensor.dtype

# residual connection between two equivariant feature dicts
class ResidualSE3(nn.Module):
    """ only support instance where both Fibers are identical """
    def forward(self, x, res):
        """Add res[degree] to x[degree] wherever the (stringified) degree exists in res."""
        out = {}
        for degree, tensor in x.items():
            key = str(degree)
            out[key] = tensor + res[key] if key in res else tensor
        return out

# degree-wise linear projection acting on the channel dimension only
class LinearSE3(nn.Module):
    def __init__(
        self,
        fiber_in,
        fiber_out
    ):
        """One weight matrix per degree shared between fiber_in and fiber_out."""
        super().__init__()
        self.weights = nn.ParameterDict()

        # scaled-normal init: one (dim_in, dim_out) matrix per common degree
        for degree, dim_in, dim_out in (fiber_in & fiber_out):
            self.weights[str(degree)] = nn.Parameter(torch.randn(dim_in, dim_out) / sqrt(dim_in))

    def forward(self, x):
        """Project each degree's channels: 'b n d m' x 'd e' -> 'b n e m'."""
        return {degree: einsum('b n d m, d e -> b n e m', x[degree], weight) for degree, weight in self.weights.items()}

class NormSE3(nn.Module):
    """Norm-based SE(3)-equivariant nonlinearity.

    Nonlinearities are important in SE(3) equivariant GCNs. They are also quite 
    expensive to compute, so it is convenient for them to share resources with
    other layers, such as normalization. The general workflow is as follows:

    > for feature type in features:
    >    norm, phase <- feature
    >    output = fnc(norm) * phase

    where fnc: {R+}^m -> R^m is a learnable map from m norms to m scalars.
    """
    # fixed: the source contained two overlapping __init__ definitions (a scrape
    # artifact that made the class a syntax error); merged into one, keeping the
    # GELU defaults of the first signature
    def __init__(
        self,
        fiber,
        nonlin = nn.GELU(),
        gated_scale = False,
        eps = 1e-12
    ):
        super().__init__()
        self.fiber = fiber
        self.nonlin = nonlin
        self.eps = eps

        # Norm mappings: 1 per feature type
        self.transform = nn.ModuleDict()
        for degree, chan in fiber:
            # either a fixed learnable scale, or a gate matrix deriving the
            # scale from the norms themselves
            self.transform[str(degree)] = nn.ParameterDict({
                'scale': nn.Parameter(torch.ones(1, 1, chan)) if not gated_scale else None,
                'w_gate': nn.Parameter(rand_uniform((chan, chan), -1e-3, 1e-3)) if gated_scale else None
            })

    def forward(self, features):
        output = {}
        for degree, t in features.items():
            # split each feature into norm and direction (phase); the clamp
            # keeps the division numerically stable
            norm = t.norm(dim = -1, keepdim = True).clamp(min = self.eps)
            phase = t / norm

            # transform on norms
            parameters = self.transform[degree]
            gate_weights, scale = parameters['w_gate'], parameters['scale']

            transformed = rearrange(norm, '... () -> ...')

            # gated variant: derive the scale from the norms
            if not exists(scale):
                scale = einsum('b n d, d e -> b n e', transformed, gate_weights)

            transformed = self.nonlin(transformed * scale)
            transformed = rearrange(transformed, '... -> ... ()')

            # rescale the unit-phase directions with the transformed norms
            output[degree] = (transformed * phase).view(*t.shape)

        return output
class ConvSE3(nn.Module):
    """A tensor field network layer.

    ConvSE3 is an SE(3)-equivariant convolution layer. It is the equivalent of
    a linear layer in an MLP, a conv layer in a CNN, or a graph conv layer in
    a GCN.

    At each node, activations are split into different "feature types",
    indexed by the SE(3) representation type: non-negative integers 0, 1, 2, ...
    """
    def __init__(
        self,
        fiber_in,
        fiber_out,
        self_interaction = True,
        pool = True,
        edge_dim = 0,
        fourier_encode_dist = False,
        num_fourier_features = 4,
        splits = 4
    ):
        super().__init__()
        self.fiber_in = fiber_in
        self.fiber_out = fiber_out
        self.edge_dim = edge_dim
        self.self_interaction = self_interaction

        self.num_fourier_features = num_fourier_features
        self.fourier_encode_dist = fourier_encode_dist

        # radial function will assume a dimension of at minimum 1, for the relative distance - extra fourier features must be added to the edge dimension
        edge_dim += (0 if not fourier_encode_dist else (num_fourier_features * 2))

        # Neighbor -> center weights
        self.kernel_unary = nn.ModuleDict()

        self.splits = splits # for splitting the computation of kernel and basis, to reduce peak memory usage

        for (di, mi), (do, mo) in (self.fiber_in * self.fiber_out):
            self.kernel_unary[f'({di},{do})'] = PairwiseConv(di, mi, do, mo, edge_dim = edge_dim, splits = splits)

        self.pool = pool

        # Center -> center weights
        if self_interaction:
            assert self.pool, 'must pool edges if followed with self interaction'
            self.self_interact = LinearSE3(fiber_in, fiber_out)
            self.self_interact_sum = ResidualSE3()

    def forward(
        self,
        inp,
        edge_info,
        rel_dist = None,
        basis = None
    ):
        splits = self.splits
        neighbor_indices, neighbor_masks, edges = edge_info
        rel_dist = rearrange(rel_dist, 'b m n -> b m n ()')

        outputs = {}

        if self.fourier_encode_dist:
            # fourier-encode the relative distances as extra edge features
            rel_dist = fourier_encode(rel_dist[..., None], num_encodings = self.num_fourier_features)

        # split the basis along the sequence dimension
        # (fixed: a closing parenthesis was missing on the zip/map line below)
        basis_keys = basis.keys()
        split_basis_values = list(zip(*list(map(lambda t: fast_split(t, splits, dim = 1), basis.values()))))
        split_basis = list(map(lambda v: dict(zip(basis_keys, v)), split_basis_values))

        # go through every permutation of input degree type to output degree type
        for degree_out in self.fiber_out.degrees:
            output = 0
            degree_out_key = str(degree_out)

            for degree_in, m_in in self.fiber_in:
                etype = f'({degree_in},{degree_out})'

                x = inp[str(degree_in)]

                # gather neighbor features and flatten (order * channels)
                x = batched_index_select(x, neighbor_indices, dim = 1)
                x = x.view(*x.shape[:3], to_order(degree_in) * m_in, 1)

                kernel_fn = self.kernel_unary[etype]
                edge_features = torch.cat((rel_dist, edges), dim = -1) if exists(edges) else rel_dist

                output_chunk = None
                split_x = fast_split(x, splits, dim = 1)
                split_edge_features = fast_split(edge_features, splits, dim = 1)

                # process input, edges and basis in chunks along the sequence dimension
                for x_chunk, edge_features, basis in zip(split_x, split_edge_features, split_basis):
                    kernel = kernel_fn(edge_features, basis = basis)
                    chunk = einsum('... o i, ... i c -> ... o c', kernel, x_chunk)
                    output_chunk = safe_cat(output_chunk, chunk, dim = 1)

                output = output + output_chunk

            if self.pool:
                output = masked_mean(output, neighbor_masks, dim = 2) if exists(neighbor_masks) else output.mean(dim = 2)

            leading_shape = x.shape[:2] if self.pool else x.shape[:3]
            output = output.view(*leading_shape, -1, to_order(degree_out))

            outputs[degree_out_key] = output

        if self.self_interaction:
            self_interact_out = self.self_interact(inp)
            outputs = self.self_interact_sum(outputs, self_interact_out)

        return outputs
class RadialFunc(nn.Module):
    """Neural-network-parameterized radial function: maps edge features to
    one kernel weight per (out channel, in channel, frequency)."""
    def __init__(
        self,
        num_freq,
        in_dim,
        out_dim,
        edge_dim = None,
        mid_dim = 128
    ):
        super().__init__()
        self.num_freq = num_freq
        self.in_dim = in_dim
        self.mid_dim = mid_dim
        self.out_dim = out_dim
        # edge features are optional; only the scalar distance is guaranteed
        self.edge_dim = edge_dim if edge_dim is not None else 0

        # two-hidden-layer MLP from (distance + edge features) to all kernel entries
        self.net = nn.Sequential(
            nn.Linear(self.edge_dim + 1, mid_dim),
            nn.LayerNorm(mid_dim),
            nn.GELU(),
            nn.Linear(mid_dim, mid_dim),
            nn.LayerNorm(mid_dim),
            nn.GELU(),
            nn.Linear(mid_dim, num_freq * in_dim * out_dim)
        )

    def forward(self, x):
        """Return kernel weights of shape (..., out_dim, 1, in_dim, 1, num_freq)."""
        flat = self.net(x)
        # unflatten (o i f) and insert singleton axes; same layout as the
        # original einops '... (o i f) -> ... o () i () f'
        return flat.reshape(*flat.shape[:-1], self.out_dim, 1, self.in_dim, 1, self.num_freq)

class PairwiseConv(nn.Module):
    """SE(3)-equivariant convolution between two single-degree feature types."""
    def __init__(
        self,
        degree_in,
        nc_in,
        degree_out,
        nc_out,
        edge_dim = 0,
        splits = 4
    ):
        super().__init__()
        self.degree_in = degree_in
        self.degree_out = degree_out
        self.nc_in = nc_in
        self.nc_out = nc_out

        # number of frequencies equals the order of the smaller of the two degrees
        self.num_freq = to_order(min(degree_in, degree_out))
        self.d_out = to_order(degree_out)
        self.edge_dim = edge_dim

        # radial network producing per-frequency kernel weights from edge features
        self.rp = RadialFunc(self.num_freq, nc_in, nc_out, edge_dim)

        self.splits = splits

    def forward(self, feat, basis):
        radial_weights = self.rp(feat)
        equivariant_basis = basis[f'{self.degree_in},{self.degree_out}']

        kernel_shape = (*radial_weights.shape[:3], self.d_out * self.nc_out, -1)

        # a single torch.sum(R * B, dim = -1) is too memory intensive, so
        # accumulate one frequency slice at a time to keep peak memory low
        accumulated = 0
        for freq_index in range(radial_weights.shape[-1]):
            accumulated = accumulated + radial_weights[..., freq_index] * equivariant_basis[..., freq_index]

        flattened = rearrange(accumulated, 'b n h s ... -> (b n h s) ...')

        # reshape into the final kernel layout
        return flattened.view(*kernel_shape)

# feed forwards

class FeedForwardSE3(nn.Module):
    """Degree-wise feedforward: widen each degree by `mult`, apply the norm
    nonlinearity, then project back to the original fiber."""
    def __init__(
        self,
        fiber,
        mult = 4
    ):
        super().__init__()
        self.fiber = fiber
        fiber_hidden = Fiber([(degree, dim * mult) for degree, dim in fiber])

        self.project_in  = LinearSE3(fiber, fiber_hidden)
        self.nonlin      = NormSE3(fiber_hidden)
        self.project_out = LinearSE3(fiber_hidden, fiber)

    def forward(self, features):
        hidden = self.project_in(features)
        hidden = self.nonlin(hidden)
        return self.project_out(hidden)

class FeedForwardBlockSE3(nn.Module):
    """Pre-norm feedforward block with a residual connection."""
    def __init__(
        self,
        fiber,
        norm_gated_scale = False
    ):
        super().__init__()
        self.fiber = fiber
        self.prenorm = NormSE3(fiber, gated_scale = norm_gated_scale)
        self.feedforward = FeedForwardSE3(fiber)
        self.residual = ResidualSE3()

    def forward(self, features):
        # norm -> feedforward -> residual add of the untouched input
        normed = self.prenorm(features)
        transformed = self.feedforward(normed)
        return self.residual(transformed, features)

# attention

class AttentionSE3(nn.Module):
    """Multi-head SE(3)-equivariant attention over degree-keyed feature dicts."""
    def __init__(
        self,
        fiber,
        dim_head = 64,
        heads = 8,
        attend_self = False,
        edge_dim = None,
        fourier_encode_dist = False,
        rel_dist_num_fourier_features = 4,
        use_null_kv = False,
        splits = 4,
        global_feats_dim = None,
        linear_proj_keys = False,
        tie_key_values = False
        ):
        # call parent constructor
        super().__init__()
        # total hidden width = heads * per-head dimension
        hidden_dim = dim_head * heads
        # hidden fiber: every degree widened to hidden_dim
        hidden_fiber = Fiber(list(map(lambda t: (t[0], hidden_dim), fiber)))
        # output projection is skippable when one head already matches the fiber width
        project_out = not (heads == 1 and len(fiber.dims) == 1 and dim_head == fiber.dims[0])

        # attention logit scaling factor
        self.scale = dim_head ** -0.5
        self.heads = heads

        # whether keys come from a plain linear projection (gathered per
        # neighbor in forward) instead of a basis convolution
        self.linear_proj_keys = linear_proj_keys
        # queries: per-degree linear map
        self.to_q = LinearSE3(fiber, hidden_fiber)
        # values: SE(3)-equivariant convolution over neighbors
        self.to_v = ConvSE3(fiber, hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)

        # linear key projection and tied key/values are mutually exclusive
        assert not (linear_proj_keys and tie_key_values), 'you cannot do linear projection of keys and have shared key / values turned on at the same time'

        # pick the key pathway based on configuration
        if linear_proj_keys:
            self.to_k = LinearSE3(fiber, hidden_fiber)
        elif not tie_key_values:
            self.to_k = ConvSE3(fiber, hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)
        else:
            # tied: keys reuse the value tensors in forward
            self.to_k = None

        # output projection back to the input fiber
        self.to_out = LinearSE3(hidden_fiber, fiber) if project_out else nn.Identity()

        # optional learned null key / value per degree, so a query can attend to "nothing"
        self.use_null_kv = use_null_kv
        if use_null_kv:
            self.null_keys = nn.ParameterDict()
            self.null_values = nn.ParameterDict()

            # zero-initialized null keys and values, one pair per degree
            for degree in fiber.degrees:
                m = to_order(degree)
                degree_key = str(degree)
                self.null_keys[degree_key] = nn.Parameter(torch.zeros(heads, dim_head, m))
                self.null_values[degree_key] = nn.Parameter(torch.zeros(heads, dim_head, m))

        # optional attention to the token itself via separate linear k / v projections
        self.attend_self = attend_self
        if attend_self:
            # key projection for the self token
            self.to_self_k = LinearSE3(fiber, hidden_fiber)
            # value projection for the self token
            self.to_self_v = LinearSE3(fiber, hidden_fiber)

        # optionally fold global degree-0 features in as extra keys / values
        self.accept_global_feats = exists(global_feats_dim)
        if self.accept_global_feats:
            # key projection for global features
            global_input_fiber = Fiber.create(1, global_feats_dim)
            global_output_fiber = Fiber.create(1, hidden_fiber[0])
            self.to_global_k = LinearSE3(global_input_fiber, global_output_fiber)
            # value projection for global features
            self.to_global_v = LinearSE3(global_input_fiber, global_output_fiber)
    # forward pass: features, edge info, relative distances, equivariant basis,
    # optional global features, rotary position embeddings and mask
    def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None, mask = None):
        # unpack head count and self-attention flag
        h, attend_self = self.heads, self.attend_self
        # device / dtype of the incoming feature dict
        device, dtype = get_tensor_device_and_dtype(features)
        # unpack edge info
        neighbor_indices, neighbor_mask, edges = edge_info

        # add a broadcastable head axis to the neighbor mask
        if exists(neighbor_mask):
            neighbor_mask = rearrange(neighbor_mask, 'b i j -> b () i j')

        # project features to queries and values
        queries = self.to_q(features)
        values  = self.to_v(features, edge_info, rel_dist, basis)

        # keys: linear projection gathered per neighbor ...
        if self.linear_proj_keys:
            keys = self.to_k(features)
            keys = map_values(lambda val: batched_index_select(val, neighbor_indices, dim = 1), keys)
        # ... or shared with values when to_k is absent (tied k/v) ...
        elif not exists(self.to_k):
            keys = values
        else:
            # ... or their own basis convolution
            keys = self.to_k(features, edge_info, rel_dist, basis)

        # self-attention keys / values
        if attend_self:
            self_keys, self_values = self.to_self_k(features), self.to_self_v(features)

        # global feature keys / values
        if exists(global_feats):
            global_keys, global_values = self.to_global_k(global_feats), self.to_global_v(global_feats)

        # compute attention independently per degree
        outputs = {}
        for degree in features.keys():
            # select this degree's query, key, value tensors
            q, k, v = map(lambda t: t[degree], (queries, keys, values))

            # split out the head axis
            q = rearrange(q, 'b i (h d) m -> b h i d m', h = h)
            k, v = map(lambda t: rearrange(t, 'b i j (h d) m -> b h i j d m', h = h), (k, v))

            # prepend the self token's key / value along the neighbor axis
            if attend_self:
                self_k, self_v = map(lambda t: t[degree], (self_keys, self_values))
                self_k, self_v = map(lambda t: rearrange(t, 'b n (h d) m -> b h n () d m', h = h), (self_k, self_v))
                k = torch.cat((self_k, k), dim = 3)
                v = torch.cat((self_v, v), dim = 3)

            # rotary position embeddings only apply to scalar (degree-0) features
            if exists(pos_emb) and degree == '0':
                query_pos_emb, key_pos_emb = pos_emb
                query_pos_emb = rearrange(query_pos_emb, 'b i d -> b () i d ()')
                key_pos_emb = rearrange(key_pos_emb, 'b i j d -> b () i j d ()')
                q = apply_rotary_pos_emb(q, query_pos_emb)
                k = apply_rotary_pos_emb(k, key_pos_emb)
                v = apply_rotary_pos_emb(v, key_pos_emb)

            # prepend learned null key / value pairs
            if self.use_null_kv:
                null_k, null_v = map(lambda t: t[degree], (self.null_keys, self.null_values))
                null_k, null_v = map(lambda t: repeat(t, 'h d m -> b h i () d m', b = q.shape[0], i = q.shape[2]), (null_k, null_v))
                k = torch.cat((null_k, k), dim = 3)
                v = torch.cat((null_v, v), dim = 3)

            # prepend global keys / values (scalar degree only)
            if exists(global_feats) and degree == '0':
                global_k, global_v = map(lambda t: t[degree], (global_keys, global_values))
                global_k, global_v = map(lambda t: repeat(t, 'b j (h d) m -> b h i j d m', h = h, i = k.shape[2]), (global_k, global_v))
                k = torch.cat((global_k, k), dim = 3)
                v = torch.cat((global_v, v), dim = 3)

            # scaled dot-product attention logits
            sim = einsum('b h i d m, b h i j d m -> b h i j', q, k) * self.scale

            # left-pad the mask with True so prepended self / null / global
            # entries stay attendable, then mask out padding neighbors
            if exists(neighbor_mask):
                num_left_pad = sim.shape[-1] - neighbor_mask.shape[-1]
                mask = F.pad(neighbor_mask, (num_left_pad, 0), value = True)
                sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

            # attention-weighted aggregation of values, then merge heads
            attn = sim.softmax(dim = -1)
            out = einsum('b h i j, b h i j d m -> b h i d m', attn, v)
            outputs[degree] = rearrange(out, 'b h n d m -> b n (h d) m')

        # final output projection
        return self.to_out(outputs)
# attention variant with a single key / value projection shared across all query heads
class OneHeadedKVAttentionSE3(nn.Module):
    """Like AttentionSE3 but keys / values have no head axis (width dim_head)."""
    def __init__(
        self,
        fiber,
        dim_head = 64,
        heads = 8,
        attend_self = False,
        edge_dim = None,
        fourier_encode_dist = False,
        rel_dist_num_fourier_features = 4,
        use_null_kv = False,
        splits = 4,
        global_feats_dim = None,
        linear_proj_keys = False,
        tie_key_values = False
    ):
        super().__init__()
        hidden_dim = dim_head * heads
        hidden_fiber = Fiber(list(map(lambda t: (t[0], hidden_dim), fiber)))
        # single-headed k / v fiber: width dim_head, not heads * dim_head
        kv_hidden_fiber = Fiber(list(map(lambda t: (t[0], dim_head), fiber)))
        project_out = not (heads == 1 and len(fiber.dims) == 1 and dim_head == fiber.dims[0])

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.linear_proj_keys = linear_proj_keys # whether to linearly project features for keys, rather than convolve with basis

        # query projection (multi-headed)
        self.to_q = LinearSE3(fiber, hidden_fiber)
        # value convolution (single-headed)
        self.to_v = ConvSE3(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)

        assert not (linear_proj_keys and tie_key_values), 'you cannot do linear projection of keys and have shared key / values turned on at the same time'

        if linear_proj_keys:
            # keys via linear projection
            self.to_k = LinearSE3(fiber, kv_hidden_fiber)
        elif not tie_key_values:
            # keys via their own basis convolution
            self.to_k = ConvSE3(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)
        else:
            # tied: keys reuse the value tensors in forward
            self.to_k = None

        # output projection back to the input fiber
        self.to_out = LinearSE3(hidden_fiber, fiber) if project_out else nn.Identity()

        self.use_null_kv = use_null_kv
        if use_null_kv:
            # learned per-degree null key / value (no head axis here)
            self.null_keys = nn.ParameterDict()
            self.null_values = nn.ParameterDict()

            for degree in fiber.degrees:
                m = to_order(degree)
                degree_key = str(degree)
                self.null_keys[degree_key] = nn.Parameter(torch.zeros(dim_head, m))
                self.null_values[degree_key] = nn.Parameter(torch.zeros(dim_head, m))

        self.attend_self = attend_self
        if attend_self:
            # separate linear projections for the self token's key / value
            self.to_self_k = LinearSE3(fiber, kv_hidden_fiber)
            self.to_self_v = LinearSE3(fiber, kv_hidden_fiber)

        self.accept_global_feats = exists(global_feats_dim)
        if self.accept_global_feats:
            # projections folding global degree-0 features in as extra keys / values
            global_input_fiber = Fiber.create(1, global_feats_dim)
            global_output_fiber = Fiber.create(1, kv_hidden_fiber[0])
            self.to_global_k = LinearSE3(global_input_fiber, global_output_fiber)
            self.to_global_v = LinearSE3(global_input_fiber, global_output_fiber)
    # forward pass: features, edge info, relative distances, equivariant basis,
    # optional global features, rotary position embeddings and mask
    def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None, mask = None):
        # unpack head count and self-attention flag
        h, attend_self = self.heads, self.attend_self
        # device / dtype of the incoming feature dict
        device, dtype = get_tensor_device_and_dtype(features)
        # unpack edge info
        neighbor_indices, neighbor_mask, edges = edge_info

        # add a broadcastable head axis to the neighbor mask
        if exists(neighbor_mask):
            neighbor_mask = rearrange(neighbor_mask, 'b i j -> b () i j')

        # project features to queries and values
        queries = self.to_q(features)
        values  = self.to_v(features, edge_info, rel_dist, basis)

        # keys: linear projection gathered per neighbor ...
        if self.linear_proj_keys:
            keys = self.to_k(features)
            keys = map_values(lambda val: batched_index_select(val, neighbor_indices, dim = 1), keys)
        # ... or shared with values when to_k is absent (tied k/v) ...
        elif not exists(self.to_k):
            keys = values
        else:
            # ... or their own basis convolution
            keys = self.to_k(features, edge_info, rel_dist, basis)

        # self-attention keys / values
        if attend_self:
            self_keys, self_values = self.to_self_k(features), self.to_self_v(features)

        # global feature keys / values
        if exists(global_feats):
            global_keys, global_values = self.to_global_k(global_feats), self.to_global_v(global_feats)

        # compute attention independently per degree
        outputs = {}
        for degree in features.keys():
            # select this degree's query, key, value tensors
            q, k, v = map(lambda t: t[degree], (queries, keys, values))

            # only queries get a head axis; keys / values stay single-headed
            q = rearrange(q, 'b i (h d) m -> b h i d m', h = h)

            # prepend the self token's key / value along the neighbor axis
            if attend_self:
                self_k, self_v = map(lambda t: t[degree], (self_keys, self_values))
                self_k, self_v = map(lambda t: rearrange(t, 'b n d m -> b n () d m'), (self_k, self_v))
                k = torch.cat((self_k, k), dim = 2)
                v = torch.cat((self_v, v), dim = 2)

            # rotary position embeddings only apply to scalar (degree-0) features
            if exists(pos_emb) and degree == '0':
                query_pos_emb, key_pos_emb = pos_emb
                query_pos_emb = rearrange(query_pos_emb, 'b i d -> b () i d ()')
                key_pos_emb = rearrange(key_pos_emb, 'b i j d -> b i j d ()')
                q = apply_rotary_pos_emb(q, query_pos_emb)
                k = apply_rotary_pos_emb(k, key_pos_emb)
                v = apply_rotary_pos_emb(v, key_pos_emb)

            # prepend learned null key / value pairs
            if self.use_null_kv:
                null_k, null_v = map(lambda t: t[degree], (self.null_keys, self.null_values))
                null_k, null_v = map(lambda t: repeat(t, 'd m -> b i () d m', b = q.shape[0], i = q.shape[2]), (null_k, null_v))
                k = torch.cat((null_k, k), dim = 2)
                v = torch.cat((null_v, v), dim = 2)

            # prepend global keys / values (scalar degree only)
            if exists(global_feats) and degree == '0':
                global_k, global_v = map(lambda t: t[degree], (global_keys, global_values))
                global_k, global_v = map(lambda t: repeat(t, 'b j d m -> b i j d m', i = k.shape[1]), (global_k, global_v))
                k = torch.cat((global_k, k), dim = 2)
                v = torch.cat((global_v, v), dim = 2)

            # scaled dot-product logits; keys broadcast across the head axis
            sim = einsum('b h i d m, b i j d m -> b h i j', q, k) * self.scale

            # left-pad the mask with True so prepended self / null / global
            # entries stay attendable, then mask out padding neighbors
            if exists(neighbor_mask):
                num_left_pad = sim.shape[-1] - neighbor_mask.shape[-1]
                mask = F.pad(neighbor_mask, (num_left_pad, 0), value = True)
                sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

            # attention-weighted aggregation of the shared values, merge heads
            attn = sim.softmax(dim = -1)
            out = einsum('b h i j, b i j d m -> b h i d m', attn, v)
            outputs[degree] = rearrange(out, 'b h n d m -> b n (h d) m')

        # final output projection
        return self.to_out(outputs)
# attention block: pre-norm SE(3) attention with a residual connection
class AttentionBlockSE3(nn.Module):
    def __init__(
        self,
        fiber,
        dim_head = 24,
        heads = 8,
        attend_self = False,
        edge_dim = None,
        use_null_kv = False,
        fourier_encode_dist = False,
        rel_dist_num_fourier_features = 4,
        splits = 4,
        global_feats_dim = False,
        linear_proj_keys = False,
        tie_key_values = False,
        attention_klass = AttentionSE3,
        norm_gated_scale = False
    ):
        super().__init__()
        # attention operator; the class is pluggable so the one-headed
        # key/value variant can be swapped in
        self.attn = attention_klass(
            fiber,
            heads = heads,
            dim_head = dim_head,
            attend_self = attend_self,
            edge_dim = edge_dim,
            use_null_kv = use_null_kv,
            rel_dist_num_fourier_features = rel_dist_num_fourier_features,
            fourier_encode_dist = fourier_encode_dist,
            splits = splits,
            global_feats_dim = global_feats_dim,
            linear_proj_keys = linear_proj_keys,
            tie_key_values = tie_key_values
        )
        self.prenorm = NormSE3(fiber, gated_scale = norm_gated_scale)
        self.residual = ResidualSE3()

    def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None, mask = None):
        # pre-norm -> attention -> residual add of the untouched input
        normed = self.prenorm(features)
        attended = self.attn(normed, edge_info, rel_dist, basis, global_feats, pos_emb, mask)
        return self.residual(attended, features)

# Swish / SiLU activation: x * sigmoid(x); fallback for torch builds without nn.SiLU
class Swish_(nn.Module):
    def forward(self, x):
        gate = x.sigmoid()
        return gate * x

# prefer torch's native nn.SiLU when this torch version provides it,
# otherwise fall back to the equivalent Swish_ module defined above
SiLU = nn.SiLU if hasattr(nn, 'SiLU') else Swish_

# norm-based rescaling for higher-degree feature types: normalizes each
# vector to unit length, then rescales by a learned affine map of its norm
class HtypesNorm(nn.Module):
    def __init__(self, dim, eps = 1e-8, scale_init = 1e-2, bias_init = 1e-2):
        super().__init__()
        self.eps = eps
        # learned affine parameters applied to the per-vector norms,
        # shaped for broadcasting over (b, i, j, dim, m) inputs
        self.scale = nn.Parameter(torch.full((1, 1, 1, dim, 1), scale_init))
        self.bias = nn.Parameter(torch.full((1, 1, 1, dim, 1), bias_init))

    def forward(self, coors):
        # vector magnitudes along the last (m) axis
        magnitudes = coors.norm(dim = -1, keepdim = True)
        # unit directions, with eps clamp guarding against division by zero
        directions = coors / magnitudes.clamp(min = self.eps)
        return directions * (magnitudes * self.scale + self.bias)

# EGNN layer: E(n)-equivariant message passing over degree-0 (scalar) node
# features plus any number of higher-degree ("htype") vector features
class EGNN(nn.Module):
    def __init__(
        self,
        fiber,
        hidden_dim = 32,
        edge_dim = 0,
        init_eps = 1e-3,
        coor_weights_clamp_value = None
    ):
        """
        fiber                    - Fiber describing per-degree feature dims
        hidden_dim               - width of the edge message representation m_ij
        edge_dim                 - dim of optional explicit edge features
        init_eps                 - std of the small-normal linear weight init
        coor_weights_clamp_value - optional symmetric clamp on htype weights
        """
        super().__init__()
        self.fiber = fiber
        node_dim = fiber[0]

        # higher types = every degree other than 0
        htypes = list(filter(lambda t: t.degrees != 0, fiber))
        htype_dims = sum([fiberel.dim for fiberel in htypes])

        # edge MLP consumes: node_i, node_j, per-htype relative distances,
        # and the scalar radial distance
        edge_input_dim = node_dim * 2 + htype_dims + edge_dim + 1

        # normalization of scalar node features before the node MLP
        self.node_norm = nn.LayerNorm(node_dim)

        # edge message MLP
        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_input_dim, edge_input_dim * 2),
            SiLU(),
            nn.Linear(edge_input_dim * 2, hidden_dim),
            SiLU()
        )

        self.htype_norms = nn.ModuleDict({})
        self.htype_gating = nn.ModuleDict({})

        # per-degree norm and gating modules for the higher types
        for degree, dim in fiber:
            if degree == 0:
                continue
            self.htype_norms[str(degree)] = HtypesNorm(dim)
            self.htype_gating[str(degree)] = nn.Linear(node_dim, dim)

        # maps edge messages to one weight per higher-type channel
        self.htypes_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            SiLU(),
            nn.Linear(hidden_dim * 4, htype_dims)
        )

        # residual node update MLP
        self.node_mlp = nn.Sequential(
            nn.Linear(node_dim + hidden_dim, node_dim * 2),
            SiLU(),
            nn.Linear(node_dim * 2, node_dim)
        )

        self.coor_weights_clamp_value = coor_weights_clamp_value
        self.init_eps = init_eps
        self.apply(self.init_)

    def init_(self, module):
        # initialize every linear layer with small-std normal weights for stability
        if type(module) in {nn.Linear}:
            nn.init.normal_(module.weight, std = self.init_eps)

    def forward(
        self,
        features,
        edge_info,
        rel_dist,
        mask = None,
        **kwargs
    ):
        """Run one round of message passing.

        Returns the `features` dict (mutated in place) with updated degree-0
        scalars and residually-updated, gated higher types.
        """
        # unpack edge info; the neighbor mask doubles as the message mask
        neighbor_indices, neighbor_masks, edges = edge_info
        mask = neighbor_masks

        # type-0 (scalar) node features, squeezing the trailing m = 1 axis

        nodes = features['0']
        nodes = rearrange(nodes, '... () -> ...')

        # higher types (degree > 0)

        htypes = list(filter(lambda t: t[0] != '0', features.items()))
        htype_degrees = list(map(lambda t: t[0], htypes))
        htype_dims = list(map(lambda t: t[1].shape[-2], htypes))

        # pairwise relative htype vectors and their per-channel norms

        rel_htypes = []
        rel_htypes_dists = []

        for degree, htype in htypes:
            rel_htype = rearrange(htype, 'b i d m -> b i () d m') - rearrange(htype, 'b j d m -> b () j d m')
            rel_htype_dist = rel_htype.norm(dim = -1)

            rel_htypes.append(rel_htype)
            rel_htypes_dists.append(rel_htype_dist)

        # assemble edge MLP inputs

        nodes_i = rearrange(nodes, 'b i d -> b i () d')
        nodes_j = batched_index_select(nodes, neighbor_indices, dim = 1)
        neighbor_higher_type_dists = map(lambda t: batched_index_select(t, neighbor_indices, dim = 2), rel_htypes_dists)
        coor_rel_dist = rearrange(rel_dist, 'b i j -> b i j ()')

        edge_mlp_inputs = broadcat((nodes_i, nodes_j, *neighbor_higher_type_dists, coor_rel_dist), dim = -1)

        if exists(edges):
            edge_mlp_inputs = torch.cat((edge_mlp_inputs, edges), dim = -1)

        # intermediate edge messages

        m_ij = self.edge_mlp(edge_mlp_inputs)

        # per-htype-channel weights derived from the messages

        htype_weights = self.htypes_mlp(m_ij)

        if exists(self.coor_weights_clamp_value):
            clamp_value = self.coor_weights_clamp_value
            htype_weights.clamp_(min = -clamp_value, max = clamp_value)

        # zero out the weights of masked (padding) neighbors BEFORE splitting.
        # bugfix: previously the masked_fill happened after .split(...), so the
        # split views still referenced the unmasked tensor and padding
        # neighbors contributed to the higher-type updates
        if exists(mask):
            htype_mask = rearrange(mask, 'b i j -> b i j ()')
            htype_weights = htype_weights.masked_fill(~htype_mask, 0.)

        split_htype_weights = htype_weights.split(htype_dims, dim = -1)

        htype_updates = []

        for degree, rel_htype, htype_weight in zip(htype_degrees, rel_htypes, split_htype_weights):
            normed_rel_htype = self.htype_norms[str(degree)](rel_htype)
            normed_rel_htype = batched_index_select(normed_rel_htype, neighbor_indices, dim = 2)

            # weighted sum of neighbor relative vectors per channel
            htype_update = einsum('b i j d m, b i j d -> b i d m', normed_rel_htype, htype_weight)
            htype_updates.append(htype_update)

        # aggregate messages into nodes, masking out padding first

        if exists(mask):
            m_ij_mask = rearrange(mask, '... -> ... ()')
            m_ij = m_ij.masked_fill(~m_ij_mask, 0.)

        m_i = m_ij.sum(dim = -2)

        normed_nodes = self.node_norm(nodes)
        node_mlp_input = torch.cat((normed_nodes, m_i), dim = -1)
        node_out = self.node_mlp(node_mlp_input) + nodes

        # write updated scalars back, restoring the trailing m = 1 axis

        features['0'] = rearrange(node_out, '... -> ... ()')

        # residual-update higher types

        update_htype_dicts = dict(zip(htype_degrees, htype_updates))

        for degree, update_htype in update_htype_dicts.items():
            features[degree] = features[degree] + update_htype

        # gate the higher types on the freshly updated node scalars

        for degree in htype_degrees:
            gating = self.htype_gating[str(degree)](node_out).sigmoid()
            features[degree] = features[degree] * rearrange(gating, '... -> ... ()')

        return features
# stack of EGNN layers, each optionally followed by a feedforward block
class EGnnNetwork(nn.Module):
    def __init__(
        self,
        *,
        fiber,
        depth,
        edge_dim = 0,
        hidden_dim = 32,
        coor_weights_clamp_value = None,
        feedforward = False
    ):
        super().__init__()
        self.fiber = fiber
        self.layers = nn.ModuleList([])
        # one (EGNN, optional FeedForwardBlockSE3) pair per layer
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                EGNN(fiber = fiber, edge_dim = edge_dim, hidden_dim = hidden_dim, coor_weights_clamp_value = coor_weights_clamp_value),
                FeedForwardBlockSE3(fiber) if feedforward else None
            ]))

    def forward(
        self,
        features,
        edge_info,
        rel_dist,
        basis,
        global_feats = None,
        pos_emb = None,
        mask = None,
        **kwargs
    ):
        # unpack edge info
        neighbor_indices, neighbor_masks, edges = edge_info
        device = neighbor_indices.device

        # modify neighbors to include self (the SE3 transformer relies on
        # removing attention to self, but that does not apply to EGNN)

        # prepend each token's own index to its neighbor list
        self_indices = torch.arange(neighbor_indices.shape[1], device = device)
        self_indices = rearrange(self_indices, 'i -> () i ()')
        neighbor_indices = broadcat((self_indices, neighbor_indices), dim = -1)

        # left-pad masks / distances so the self entry is valid with distance 0
        neighbor_masks = F.pad(neighbor_masks, (1, 0), value = True)
        rel_dist = F.pad(rel_dist, (1, 0), value = 0.)

        # if explicit edge features exist, left-pad those as well
        if exists(edges):
            edges = F.pad(edges, (0, 0, 1, 0), value = 0.)  # make edge of token to self 0 for now

        edge_info = (neighbor_indices, neighbor_masks, edges)

        # run through the layers
        for egnn, ff in self.layers:
            # EGNN message-passing update
            features = egnn(
                features,
                edge_info = edge_info,
                rel_dist = rel_dist,
                basis = basis,
                global_feats = global_feats,
                pos_emb = pos_emb,
                mask = mask,
                **kwargs
            )

            # optional feedforward block
            if exists(ff):
                features = ff(features)

        return features

# 主类
class SE3Transformer(nn.Module):
    # 初始化函数,接收多个参数
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 24,
        depth = 2,
        input_degrees = 1,
        num_degrees = None,
        output_degrees = 1,
        valid_radius = 1e5,
        reduce_dim_out = False,
        num_tokens = None,
        num_positions = None,
        num_edge_tokens = None,
        edge_dim = None,
        reversible = False,
        attend_self = True,
        use_null_kv = False,
        differentiable_coors = False,
        fourier_encode_dist = False,
        rel_dist_num_fourier_features = 4,
        num_neighbors = float('inf'),
        attend_sparse_neighbors = False,
        num_adj_degrees = None,
        adj_dim = 0,
        max_sparse_neighbors = float('inf'),
        dim_in = None,
        dim_out = None,
        norm_out = False,
        num_conv_layers = 0,
        causal = False,
        splits = 4,
        global_feats_dim = None,
        linear_proj_keys = False,
        one_headed_key_values = False,
        tie_key_values = False,
        rotary_position = False,
        rotary_rel_dist = False,
        norm_gated_scale = False,
        use_egnn = False,
        egnn_hidden_dim = 32,
        egnn_weights_clamp_value = None,
        egnn_feedforward = False,
        hidden_fiber_dict = None,
        out_fiber_dict = None
    # 前向传播函数,接收多个参数
    def forward(
        self,
        feats,
        coors,
        mask = None,
        adj_mat = None,
        edges = None,
        return_type = None,
        return_pooled = False,
        neighbor_mask = None,
        global_feats = None

.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\spherical_harmonics.py

# 从 math 模块中导入 pi 和 sqrt 函数
# 从 functools 模块中导入 reduce 函数
# 从 operator 模块中导入 mul 函数
# 导入 torch 模块
from math import pi, sqrt
from functools import reduce
from operator import mul
import torch

# 从 functools 模块中导入 lru_cache 装饰器
# 从 se3_transformer_pytorch.utils 模块中导入 cache 函数
from functools import lru_cache
from se3_transformer_pytorch.utils import cache

# module-level cache of associated Legendre results, keyed by (degree, order)
CACHE = {}

# empty the spherical harmonics cache; required whenever the angular input
# tensor changes, since lpmv's cache key deliberately ignores x
def clear_spherical_harmonics_cache():
    CACHE.clear()

# cache key for lpmv: deliberately ignores the tensor argument, so cached
# results are keyed on (degree, order) only and are valid for a single x
def lpmv_cache_key_fn(l, m, x):
    return l, m

# double factorial x!! = x * (x - 2) * (x - 4) * ... as a float; memoized
@lru_cache(maxsize = 1000)
def semifactorial(x):
    result = 1.
    for factor in range(x, 1, -2):
        result *= factor
    return result

# Pochhammer symbol (rising factorial) x * (x+1) * ... * (x+k-1) as a float; memoized
@lru_cache(maxsize = 1000)
def pochhammer(x, k):
    result = float(x)
    for factor in range(x + 1, x + k):
        result *= factor
    return result

# apply the factor (-1)^m (l+m)! / (l-m)! relating P_l^{-|m|} to P_l^{|m|};
# a no-op for non-negative orders. NOTE: multiplies tensor arguments in place.
def negative_lpmv(l, m, y):
    if m >= 0:
        return y
    y *= ((-1) ** m / pochhammer(l + m + 1, -2 * m))
    return y

# associated Legendre recursion, memoized in the module-level CACHE
# (the cache key ignores x — see lpmv_cache_key_fn — so the cache must be
# cleared whenever the input tensor changes)
@cache(cache=CACHE, key_fn=lpmv_cache_key_fn)
def lpmv(l, m, x):
    """Associated Legendre function including Condon-Shortley phase.

    Args:
        m: int order 
        l: int degree
        x: float argument tensor
    Returns:
        tensor of x-shape
    """
    # compute with the non-negative order; the m < 0 correction factor is
    # applied at the end via negative_lpmv
    m_abs = abs(m)

    if m_abs > l:
        return None

    if l == 0:
        return torch.ones_like(x)
    
    if m_abs == l:
        # closed-form diagonal case P_l^l
        y = (-1)**m_abs * semifactorial(2*m_abs-1)
        y *= torch.pow(1-x*x, m_abs/2)
        return negative_lpmv(l, m, y)

    # recursive call made purely for its side effect: it populates CACHE with
    # the lower-degree entries read by the recurrence below
    lpmv(l-1, m, x)

    # upward recurrence in degree l at fixed order m_abs
    y = ((2*l-1) / (l-m_abs)) * x * lpmv(l-1, m_abs, x)

    if l - m_abs > 1:
        # second recurrence term, read straight from the cache filled above
        y -= ((l+m_abs-1)/(l-m_abs)) * CACHE[(l-2, m_abs)]
    
    if m < 0:
        y = negative_lpmv(l, m, y)
    return y

# single tesseral (real) spherical harmonic Y_l^m
def get_spherical_harmonics_element(l, m, theta, phi):
    """Tesseral spherical harmonic with Condon-Shortley phase.

    The Tesseral spherical harmonics are also known as the real spherical
    harmonics.

    Args:
        l: int for degree
        m: int for order, where -l <= m < l
        theta: collatitude or polar angle
        phi: longitude or azimuth
    Returns:
        tensor of shape theta
    """
    order_abs = abs(m)
    assert order_abs <= l, "absolute value of order m must be <= degree l"

    normalization = sqrt((2*l + 1) / (4 * pi))
    legendre = lpmv(l, order_abs, torch.cos(theta))

    # zonal harmonic: no azimuthal dependence
    if m == 0:
        return normalization * legendre

    # azimuthal factor: cosine for positive orders, sine for negative ones
    azimuth = torch.cos(m * phi) if m > 0 else torch.sin(order_abs * phi)

    result = azimuth * legendre
    normalization *= sqrt(2. / pochhammer(l - order_abs + 1, 2 * order_abs))
    result = result * normalization
    return result

# all 2l+1 tesseral harmonics of degree l, stacked along a trailing axis
def get_spherical_harmonics(l, theta, phi):
    """ Tesseral harmonic with Condon-Shortley phase.

    The Tesseral spherical harmonics are also known as the real spherical
    harmonics.

    Args:
        l: int for degree
        theta: collatitude or polar angle
        phi: longitude or azimuth
    Returns:
        tensor of shape [*theta.shape, 2*l+1]
    """
    elements = [get_spherical_harmonics_element(l, order, theta, phi) for order in range(-l, l + 1)]
    return torch.stack(elements, dim = -1)

.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\utils.py

# 导入必要的库
import os
import sys
import time
import pickle
import gzip
import torch
import contextlib
from functools import wraps, lru_cache
from filelock import FileLock
from einops import rearrange

# 辅助函数

# Predicate: True when a value is present (anything other than None).
# Falsy-but-present values such as 0 or '' still count as existing.
def exists(val):
    if val is None:
        return False
    return True

# Fall back to `d` when `val` is absent; a present-but-falsy `val` is kept.
def default(val, d):
    if val is None:
        return d
    return val

# De-duplicate while preserving first-seen order (dicts keep insertion order).
def uniq(arr):
    return list(dict.fromkeys(arr))

# Dimension 2l+1 of the multiplicity for a given degree l.
def to_order(degree):
    return 1 + 2 * degree

# Apply `fn` to every value of dict `d`, keeping the keys unchanged.
def map_values(fn, d):
    return dict(zip(d.keys(), map(fn, d.values())))

# Concatenate `el` onto accumulator `arr`; a missing accumulator (None)
# simply becomes `el`.
def safe_cat(arr, el, dim):
    return el if arr is None else torch.cat((arr, el), dim=dim)

# Ensure a tuple: existing tuples pass through, scalars are repeated `depth` times.
def cast_tuple(val, depth):
    if isinstance(val, tuple):
        return val
    return (val,) * depth

# Concatenate along `dim`, broadcasting every other axis to a shared size first.
def broadcat(tensors, dim=-1):
    ranks = set(t.dim() for t in tensors)
    assert len(ranks) == 1, 'tensors must all have the same number of dimensions'
    rank = ranks.pop()

    if dim < 0:
        dim += rank

    shapes = [list(t.shape) for t in tensors]
    target_shapes = [list(s) for s in shapes]

    for axis in range(rank):
        if axis == dim:
            continue  # the concat axis keeps each tensor's own size
        sizes = [s[axis] for s in shapes]
        # at most two distinct sizes per broadcastable axis (e.g. 1 and n)
        assert len(set(sizes)) <= 2, 'invalid dimensions for broadcastable concatentation'
        widest = max(sizes)
        for ts in target_shapes:
            ts[axis] = widest

    expanded = [t.expand(*shape) for t, shape in zip(tensors, target_shapes)]
    return torch.cat(expanded, dim=dim)

# Batched gather: select entries of `values` along `dim` using per-batch `indices`.
# For dim=1: out[b, i, ...] = values[b, indices[b, i], ...].
def batched_index_select(values, indices, dim=1):
    # trailing dims of `values` after the indexed axis are carried through unchanged
    value_dims = values.shape[(dim + 1):]
    values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
    # bugfix: original line was missing a closing parenthesis (SyntaxError)
    indices = indices[(..., *((None,) * len(value_dims)))]
    indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
    value_expand_len = len(indices_shape) - (dim + 1)
    values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]

    # broadcast `values` so its leading dims line up with the index dims
    value_expand_shape = [-1] * len(values.shape)
    expand_slice = slice(dim, (dim + value_expand_len))
    value_expand_shape[expand_slice] = indices.shape[expand_slice]
    values = values.expand(*value_expand_shape)

    dim += value_expand_len
    return values.gather(dim, indices)

# Mean over `dim`, counting only positions where `mask` is True.
# NOTE: zeroes out the masked entries of `tensor` IN PLACE (masked_fill_).
def masked_mean(tensor, mask, dim=-1):
    # right-pad the mask with singleton dims so it broadcasts over tensor
    extra_dims = tensor.dim() - mask.dim()
    mask = mask[(..., *((None,) * extra_dims))]
    tensor.masked_fill_(~mask, 0.)

    counts = mask.sum(dim=dim)
    mean = tensor.sum(dim=dim) / counts.clamp(min=1.)
    # rows with no valid entries yield 0 instead of a division artifact
    mean.masked_fill_(counts == 0, 0.)
    return mean

# Tensor of the given size filled with draws from U(min_val, max_val).
def rand_uniform(size, min_val, max_val):
    out = torch.empty(size)  # uninitialized; uniform_ overwrites every entry
    return out.uniform_(min_val, max_val)

# Yield `splits` contiguous chunks of `arr` along `dim` as narrow views
# (no copies); the first `remainder` chunks are one element longer.
def fast_split(arr, splits, dim=0):
    total = arr.shape[dim]
    # clamp to [1, total] so every yielded chunk is non-empty
    splits = min(total, max(splits, 1))
    base, extra = divmod(total, splits)
    start = 0
    for i in range(splits):
        size = base + (1 if i < extra else 0)
        yield torch.narrow(arr, dim, start, size)
        start += size

# Sinusoidal positional features: sin/cos of x at num_encodings octave scales,
# optionally followed by the raw value, optionally flattened over trailing dims.
def fourier_encode(x, num_encodings=4, include_self=True, flatten=True):
    x = x.unsqueeze(-1)
    orig = x
    # octave scales 1, 2, 4, ... matching x's device and dtype
    octaves = 2 ** torch.arange(num_encodings, device=x.device, dtype=x.dtype)
    scaled = x / octaves
    enc = torch.cat([scaled.sin(), scaled.cos()], dim=-1)
    if include_self:
        enc = torch.cat((enc, orig), dim=-1)
    if flatten:
        # NOTE(review): pattern assumes x came in as (b, m, n) — TODO confirm
        enc = rearrange(enc, 'b m n ... -> b m n (...)')
    return enc

# Context manager: temporarily switch torch's global default dtype,
# restoring the previous one on exit.
@contextlib.contextmanager
def torch_default_dtype(dtype):
    prev_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # bugfix: restore even when the with-body raises — the original had
        # no try/finally, so an exception leaked the changed default dtype
        torch.set_default_dtype(prev_dtype)

# Decorator: coerce the single argument into a torch tensor before calling fn.
def cast_torch_tensor(fn):
    @wraps(fn)
    def inner(t):
        # tensors pass through untouched; raw values use the default dtype
        arg = t if torch.is_tensor(t) else torch.tensor(t, dtype=torch.get_default_dtype())
        return fn(arg)
    return inner
# Decorator: time a call, returning (elapsed_seconds, result).
def benchmark(fn):
    # bugfix: add @wraps so fn's metadata survives, consistent with the
    # other decorators in this file (cast_torch_tensor, cache)
    @wraps(fn)
    def inner(*args, **kwargs):
        start = time.time()
        res = fn(*args, **kwargs)
        elapsed = time.time() - start
        return elapsed, res
    return inner

# Decorator factory: memoize fn's results in the supplied mapping `cache`,
# using `key_fn(*args, **kwargs)` to derive the cache key.
def cache(cache, key_fn):
    def cache_inner(fn):
        @wraps(fn)
        def inner(*args, **kwargs):
            key = key_fn(*args, **kwargs)
            # compute and store only on a miss; hits reuse the stored value
            if key not in cache:
                cache[key] = fn(*args, **kwargs)
            return cache[key]

        return inner
    return cache_inner

# Disk-backed memoization: results are pickled into `dirname` (one gzip file
# per distinct call), with an lru_cache layer in RAM on top.
def cache_dir(dirname, maxsize=128):
    '''
    Cache a function with a directory

    :param dirname: the directory path
    :param maxsize: maximum size of the RAM cache (there is no limit for the directory cache)
    '''
    def decorator(func):

        @lru_cache(maxsize=maxsize)
        @wraps(func)
        def wrapper(*args, **kwargs):
            # no cache directory configured -> plain call, no persistence
            if not exists(dirname):
                return func(*args, **kwargs)

            os.makedirs(dirname, exist_ok=True)

            indexfile = os.path.join(dirname, "index.pkl")
            # cross-process mutex guarding both the index and the data files
            lock = FileLock(os.path.join(dirname, "mutex"))

            with lock:
                index = {}
                if os.path.exists(indexfile):
                    with open(indexfile, "rb") as file:
                        index = pickle.load(file)

                # bugfix: hash kwarg (name, value) pairs, not just the names —
                # frozenset(kwargs) made calls differing only in kwarg values
                # collide onto the same cache file and return stale results.
                # (Old index entries keyed the broken way simply miss and are
                # recomputed.)
                key = (args, frozenset(kwargs.items()), func.__defaults__)

                if key in index:
                    filename = index[key]
                else:
                    # new entry: assign the next sequential filename and persist the index
                    index[key] = filename = f"{len(index)}.pkl.gz"
                    with open(indexfile, "wb") as file:
                        pickle.dump(index, file)

            filepath = os.path.join(dirname, filename)

            # disk hit: load the previously computed result
            if os.path.exists(filepath):
                with lock:
                    with gzip.open(filepath, "rb") as file:
                        result = pickle.load(file)
                return result

            # bugfix: the progress messages printed the literal "(unknown)" —
            # report the cache file being computed/saved instead
            print(f"compute {filename}... ", end="", flush=True)
            result = func(*args, **kwargs)
            print(f"save {filename}... ", end="", flush=True)

            with lock:
                with gzip.open(filepath, "wb") as file:
                    pickle.dump(result, file)

            print("done")

            return result
        return wrapper
    return decorator