Lucidrains-系列项目源码解析-四十-

75 阅读11分钟

Lucidrains 系列项目源码解析(四十)

.\lucidrains\geometric-vector-perceptron\examples\data_utils.py

# 作者:Eric Alcaide

# 导入必要的库
import os 
import sys
# 科学计算库
import torch
import torch_sparse
import numpy as np 
from einops import repeat, rearrange
# 导入自定义工具 - 来自 https://github.com/EleutherAI/mp_nerf
from data_handler import *

# 新数据构建函数
def get_atom_ids_dict():
    """ Build a dict mapping every atom name to an integer token. """
    # backbone atoms are present in every residue
    names = {"N", "CA", "C", "O"}

    # collect the side-chain atom names of every residue type
    for build_info in SC_BUILD_INFO.values():
        names.update(build_info["atom-names"])

    # deterministic tokenization: sort alphabetically, then enumerate
    return {name: token for token, name in enumerate(sorted(names))}

#################################
##### 原始项目数据 #####
#################################

# amino-acid alphabet (1-letter codes; "_" is the padding token) and its tokenization
AAS = "ARNDCQEGHILKMFPSTWYV_"
AAS2NUM = {k: AAS.index(k) for k in AAS}
# atom name -> integer token, shared by all residues
ATOM_IDS = get_atom_ids_dict()
# Per-residue covalent bond lists; each bond is a pair of atom indices in the
# sidechainnet 14-atom ordering (0=N, 1=CA, 2=C, 3=O, 4=CB, ...).
# NOTE(review): only 'A', 'R' and '_' appear here — the entries for the other
# 18 amino acids seem to have been elided from this copy of the file (see the
# "..." placeholder). Confirm against the original repository, since code
# below indexes GVP_DATA with every letter of AAS.
GVP_DATA = { 
    'A': {
        'bonds': [[0,1], [1,2], [2,3], [1,4]] 
         },
    'R': {
        'bonds': [[0,1], [1,2], [2,3], [2,4], [4,5], [5,6],
                  [6,7], [7,8], [8,9], [8,10]] 
         },
    # entries for the remaining amino acids
    # ...
    '_': {
        'bonds': []
        }
    }

#################################
##### 原始项目数据 #####
#################################

def graph_laplacian_embedds(edges, eigen_k, center_idx=1, norm=False):
    """ Returns the embeddings of points in the top-K eigenvectors
        of the graph Laplacian.
        Inputs:
        * edges: (2, N) long tensor or list. Undirected edges (one direction
                 is enough; symmetry is enforced below).
        * eigen_k: int. number of top eigenvectors to embed into.
        * center_idx: int. node whose embedding is used as the origin.
        * norm: bool. use the symmetric-normalized Laplacian. Not recommended.
        Outputs: (n_points, eigen_k)
    """
    if isinstance(edges, list):
        edges = torch.tensor(edges).long()
        # accept (N, 2) as well as (2, N)
        if edges.shape[0] != 2:
            edges = edges.t()
        # no edges at all -> a single zero embedding
        if edges.shape[0] == 0:
            return torch.zeros(1, eigen_k)
    device = edges.device
    # number of nodes = highest index + 1
    size = torch.max(edges) + 1
    # adjacency with self-loops (eye) plus the undirected edges
    adj_mat = torch.eye(size, device=device)
    for i, j in edges.t():
        adj_mat[i, j] = adj_mat[j, i] = 1.

    # degree matrix and (unnormalized) Laplacian L = D - A
    # bug fix: the eye here previously lacked `device=`, breaking GPU inputs
    deg_mat = torch.eye(size, device=device) * adj_mat.sum(dim=-1, keepdim=True)
    laplace = deg_mat - adj_mat
    # optionally overwrite off-diagonals with the symmetric-normalized form
    if norm:
        for i, j in edges.t():
            laplace[i, j] = laplace[j, i] = -1 / (deg_mat[i, i] * deg_mat[j, j])**0.5
    # torch.symeig was removed in torch >= 1.13; torch.linalg.eigh is the
    # replacement for symmetric matrices (eigenvalues ascending, like symeig)
    e, v = torch.linalg.eigh(laplace)
    # keep the eigenvectors with the largest |eigenvalue|
    idxs = torch.sort(e.abs(), descending=True)[1]
    embedds = v[:, idxs[:eigen_k]]
    # re-center so the chosen node sits at the origin
    embedds = embedds - embedds[center_idx].unsqueeze(-2)
    return embedds
# token ids of the (up to) 14 atoms of one amino acid
def make_atom_id_embedds(k):
    """ Return a (14,) long tensor of atom-type tokens for amino acid `k`;
        unused trailing slots stay 0. """
    token_ids = [ATOM_IDS[name]
                 for name in ["N", "CA", "C", "O"] + SC_BUILD_INFO[k]["atom-names"]]
    mask = torch.zeros(14).long()
    mask[:len(token_ids)] = torch.tensor(token_ids).long()
    return mask


#################################
########## SAVE INFO ############
#################################

# per-residue lookup table built once at import time: structural masks,
# graph-laplacian embeddings of the bond graph, and atom-type tokens for each
# of the 20 amino acids plus the padding token "_"
SUPREME_INFO = {k: {"cloud_mask": make_cloud_mask(k),
                    "bond_mask": make_bond_mask(k),
                    "theta_mask": make_theta_mask(k),
                    "torsion_mask": make_torsion_mask(k),
                    "idx_mask": make_idx_mask(k),
                    # graph-derived features from the per-residue bond lists
                    "eigen_embedd": graph_laplacian_embedds(GVP_DATA[k]["bonds"], eigen_k = 3),
                    "atom_id_embedd": make_atom_id_embedds(k)
                    } 
                for k in "ARNDCQEGHILKMFPSTWYV_"}

#################################
######### RANDOM UTILS ##########
#################################


# encode a distance with sines and cosines at several scales
def encode_dist(x, scales=[1,2,4,8], include_self = True):
    """ Encodes a distance with sines and cosines.
        Inputs:
        * x: (batch, N) or (N,). data to encode.
             Device and dtype (f16, f32, f64) are inferred from here.
        * scales: list, tuple or (s,) tensor. lower or higher depending
                  on the magnitude of the distances.
        * include_self: append the raw value as the last channel.
        Output: (..., num_scales*2 + 1) if include_self else (..., num_scales*2)
    """
    x = x.unsqueeze(-1)
    device, precise = x.device, x.type()
    # generalization: accept tuples as well as lists of scales
    if isinstance(scales, (list, tuple)):
        scales = torch.tensor([scales], device=device).type(precise)
    # fourier-style features at each scale
    sines   = torch.sin(x / scales)
    cosines = torch.cos(x / scales)
    enc_x = torch.cat([sines, cosines], dim=-1)
    # optionally append the raw value so it can be recovered exactly
    return torch.cat([enc_x, x], dim=-1) if include_self else enc_x

# decode distances previously encoded with `encode_dist`
def decode_dist(x, scales=[1,2,4,8], include_self = False):
    """ Decodes a distance from its sine/cosine fourier encoding.
        Inputs:
        * x: (batch, N, 2*fourier_feats (+1) ) or (N,). data to decode.
              Device and dtype (f16, f32, f64) are inferred from here.
        * scales: (s,) or list. must match the scales used to encode.
        * include_self: whether to average with the raw value (last channel).
        Output: (batch, N)
    """
    device, precise = x.device, x.type()
    # convert a python list of scales to a tensor with matching device/dtype
    if isinstance(scales, list):
        scales = torch.tensor([scales], device=device).type(precise)
    # recover the angle at each scale via atan2; shift negatives into [0, 2pi)
    half = x.shape[-1]//2
    decodes = torch.atan2(x[..., :half], x[..., half:2*half])
    decodes += (decodes<0).type(precise) * 2*np.pi 
    # unwrap: propagate full-turn offsets from coarser scales to finer ones.
    # NOTE(review): the [:, i] indexing assumes `decodes` is 2-D; a 1-D input
    # would fail here — confirm the expected input rank with callers.
    offsets = torch.zeros_like(decodes)
    for i in range(decodes.shape[-1]-1, 0, -1):
        offsets[:, i-1] = 2 * ( offsets[:, i] + (decodes[:, i]>np.pi).type(precise) * np.pi )
    decodes += offsets
    # average the per-scale estimates back into distance units
    avg_dec = (decodes * scales).mean(dim=-1, keepdim=True)
    if include_self:
        # blend with the raw (unencoded) value stored in the last channel
        return 0.5*(avg_dec + x[..., -1:])
    return avg_dec

# compute the n-th degree adjacency matrix and per-edge connectivity degree
def nth_deg_adjacency(adj_mat, n=1, sparse=False):
    """ Calculates the n-th degree adjacency matrix.
        Performs mm of adj_mat and adds the newly added.
        Default is dense. Mods for sparse version are done when needed.
        Inputs: 
        * adj_mat: (N, N) adjacency tensor
        * n: int. degree of the output adjacency
        * sparse: bool. whether to use torch-sparse module
        Outputs: 
        * edge_idxs: the ij positions of the adjacency matrix
        * edge_attrs: the degree of connectivity (1 for neighs, 2 for neighs^2 )
    """
    adj_mat = adj_mat.float()
    # attr_mat[i, j] ends up holding the smallest degree at which i reaches j
    attr_mat = torch.zeros_like(adj_mat)
    for i in range(n):
        # degree 1: the adjacency itself
        if i == 0:
            attr_mat += adj_mat
            continue

        # first squaring in sparse mode: switch to a sparse representation.
        # NOTE(review): torch.sparse.FloatTensor is a legacy constructor
        # (newer torch prefers torch.sparse_coo_tensor) — confirm the torch
        # version this must support.
        if i == 1 and sparse: 
            adj_mat = torch.sparse.FloatTensor(adj_mat.nonzero().t(),
                                                adj_mat[adj_mat != 0]).to(adj_mat.device).coalesce()
            idxs, vals = adj_mat.indices(), adj_mat.values()
            # rebinding `n` here is harmless: range(n) above is already built
            m, k, n = 3 * [adj_mat.shape[0]]  # (m, n) * (n, k), square: m=n=k

        if sparse:
            # sparse-sparse matmul, then scatter back into a dense 0/1 matrix
            idxs, vals = torch_sparse.spspmm(idxs, vals, idxs, vals, m=m, k=k, n=n)
            adj_mat = torch.zeros_like(attr_mat)
            adj_mat[idxs[0], idxs[1]] = vals.bool().float()
        else:
            # dense: square the adjacency, binarize to keep reachability only
            adj_mat = (adj_mat @ adj_mat).bool().float() 

        # label newly reachable pairs with their degree (i + 1)
        attr_mat[(adj_mat - attr_mat.bool().float()).bool()] += i + 1

    # adjacency after n steps and the degree labels for each edge
    return adj_mat, attr_mat
# return the indices (and degrees) of the covalent bonds of a protein
def prot_covalent_bond(seq, adj_degree=1, cloud_mask=None):
    """ Returns the indices of the covalent bonds of a protein.
        Inputs
        * seq: str. protein sequence in 1-letter amino-acid codes.
        * adj_degree: int. expand bonds to this degree of adjacency.
        * cloud_mask: optional mask selecting the atoms actually present.
        Outputs: edge_idxs, edge_attrs
    """
    # build or infer the cloud_mask from the sequence
    if cloud_mask is None: 
        cloud_mask = scn_cloud_mask(seq).bool()
    device, precise = cloud_mask.device, cloud_mask.type()
    # flat index of the first atom (N) of every residue
    scaff = torch.zeros_like(cloud_mask)
    scaff[:, 0] = 1
    idxs = scaff[cloud_mask].nonzero().view(-1)
    # build a dense adjacency from the per-residue bond lists in GVP_DATA
    adj_mat = torch.zeros(idxs.amax()+14, idxs.amax()+14)
    for i,idx in enumerate(idxs):
        # peptide bond: C of this residue (offset 2) to N of the next one
        extra = []
        if i < idxs.shape[0]-1:
            extra = [[2, (idxs[i+1]-idx).item()]]

        bonds = idx + torch.tensor( GVP_DATA[seq[i]]['bonds'] + extra ).long().t() 
        adj_mat[bonds[0], bonds[1]] = 1.
    # make the graph undirected
    adj_mat = adj_mat + adj_mat.t()
    # expand to the requested adjacency degree
    adj_mat, attr_mat = nth_deg_adjacency(adj_mat, n=adj_degree, sparse=True)

    edge_idxs = attr_mat.nonzero().t().long()
    edge_attrs = attr_mat[edge_idxs[0], edge_idxs[1]]
    return edge_idxs, edge_attrs


def dist2ca(x, mask=None, eps=1e-7):
    """ Calculates the distance of every atom to its residue's C-alpha.
        Inputs:
        * x: (L, 14, D). coords in sidechainnet atom order (CA at index 1).
        * mask: optional (L, 14) boolean mask of existing atoms.
        * eps: small constant to avoid division by zero.
        Returns unit vectors and norms (flattened by the mask if given).
    """
    # vector from CA (atom index 1) to every atom of the residue
    x = x - x[:, 1].unsqueeze(1)
    norm = torch.norm(x, dim=-1, keepdim=True)
    x_norm = x / (norm+eps)
    # bug fix: `if mask:` on a multi-element tensor raises
    # "Boolean value of Tensor is ambiguous" — test for presence instead
    if mask is not None:
        return x_norm[mask], norm[mask]
    return x_norm, norm


def orient_aa(x, mask=None, eps=1e-7):
    """ Calculates unit vectors and norms of backbone features.
        Inputs:
        * x: (L, 14, D). Coords in sidechainnet format (N, CA, C, O, CB, ...).
        Returns 5 unit vectors and 3 norms.
    """
    device, precise = x.device, x.type()

    vec_wrap  = torch.zeros(5, x.shape[0], 3, device=device) # (feats, L, dims+1)
    norm_wrap = torch.zeros(3, x.shape[0], device=device)
    # feature 0: CB - CA (sidechain direction)
    vec_wrap[0]  = x[:, 4] - x[:, 1]
    norm_wrap[0] = torch.norm(vec_wrap[0], dim=-1)
    vec_wrap[0] /= norm_wrap[0].unsqueeze(dim=-1) + eps
    # feature 1: CA(i) - CA(i+1); last position stays zero
    vec_wrap[1, :-1]  = x[:-1, 1] - x[1:, 1]
    norm_wrap[1, :-1] = torch.norm(vec_wrap[1, :-1], dim=-1)
    vec_wrap[1, :-1] /= norm_wrap[1, :-1].unsqueeze(dim=-1) + eps
    # feature 2: the reversed vector of feature 1
    vec_wrap[2] = (-1)*vec_wrap[1]
    # feature 3: neighbor-CA difference, stored shifted by one position.
    # NOTE(review): this computes the same difference x[:-1,1] - x[1:,1] as
    # feature 1; the original comment claims "CA - CA-", which would be the
    # opposite sign. Confirm intended orientation against the consumer model.
    vec_wrap[3, 1:]  = x[:-1, 1] - x[1:, 1]
    norm_wrap[2, 1:] = torch.norm(vec_wrap[3, 1:], dim=-1)
    vec_wrap[3, 1:] /= norm_wrap[2, 1:].unsqueeze(dim=-1) + eps
    # feature 4: the reversed vector of feature 3
    vec_wrap[4] = (-1)*vec_wrap[3]

    return vec_wrap, norm_wrap


def chain2atoms(x, mask=None):
    """ Broadcast per-residue features to per-atom ones: (L, ...) -> (L, 14, ...). """
    device, precise = x.device, x.type()
    # build a (L, 14, ...) grid of ones and scale it by the residue features
    ones_grid = torch.ones(x.shape[0], 14, *x.shape[1:]).type(precise).to(device)
    per_atom = ones_grid * x.unsqueeze(1)
    # optionally keep only the atoms selected by the mask
    if mask is not None:
        return per_atom[mask]
    return per_atom


def from_encode_to_pred(whole_point_enc, use_fourier=False, embedd_info=None, needed_info=None, vec_dim=3):
    """ Turns the encoding produced by the functions above into the
        label / prediction format: just the essentials needed for position
        recovery (radial unit vector + norm).
        Inputs:
        * whole_point_enc: (atoms, vector_dims+scalar_dims). Same layout as
                           produced above; the radial unit vector must be the
                           first vector feature.
        * use_fourier: bool. decode the norm from its fourier features.
        * embedd_info: dict. holds the number of scalar / vector features.
        * needed_info: dict. holds the fourier scales used for positions.
        * vec_dim: int. dimensionality of each vector feature.
    """
    n_vec_feats = vec_dim * embedd_info["point_n_vectors"]
    dist_col = 2 * len(needed_info["atom_pos_scales"]) + n_vec_feats
    if use_fourier:
        # recover the norm from its sine/cosine features (raw value excluded)
        decoded_dist = decode_dist(whole_point_enc[:, n_vec_feats:dist_col+1],
                                   scales=needed_info["atom_pos_scales"],
                                   include_self=False)
    else:
        # the raw norm sits right after the fourier features - take it directly
        decoded_dist = whole_point_enc[:, dist_col:dist_col+1]
    # [radial unit vector (3) | decoded norm (1)]
    return torch.cat([whole_point_enc[:, :3], decoded_dist], dim=-1)
def encode_whole_bonds(x, x_format="coords", embedd_info={},
                       needed_info = {"cutoffs": [2,5,10],
                                      "bond_scales": [.5, 1, 2],
                                      "adj_degree": 1},
                       free_mem=False, eps=1e-7):
    """ Given some coordinates, and the needed info,
        encodes the bonds from point information.
        * x: (N, 3) or prediction format
        * x_format: "coords", or "encode" for the prediction format
        * embedd_info: dict. contains the needed embedding info
        * needed_info: dict. contains additional needed info
            { cutoffs: list. cutoff distances for bonds.
                       can be a string for the k closest (ex: "30_closest"),
                       empty list for just covalent.
              bond_scales: list. fourier encodings
              adj_degree: int. degree of adj (2 means adj of adj is my adj)
                               0 for no adjacency
            }
        * free_mem: whether to delete variables
        * eps: constant for numerical stability
    """ 
    device, precise = x.device, x.type()
    # convert to 3d coords if passed as preds (unit vector * norm)
    if x_format == "encode":
        pred_x = from_encode_to_pred(x, embedd_info=embedd_info, needed_info=needed_info)
        x = pred_x[:, :3] * pred_x[:, 3:4]

    # 1. BONDS: find the covalent bond_indices - allow arg -> DRY
    native = None
    # bug fix: the key checked and the key read must match ("covalent_bond");
    # the original checked "prot_covalent_bond" and then read "covalent_bond",
    # which could only raise a KeyError
    if "covalent_bond" in needed_info.keys():
        native = True
        native_bonds = needed_info["covalent_bond"]
    elif needed_info["adj_degree"]:
        native = True
        native_bonds = prot_covalent_bond(needed_info["seq"], needed_info["adj_degree"])

    if native: 
        native_idxs, native_attrs = native_bonds[0].to(device), native_bonds[1].to(device)

    # determine kind of cutoff (hard distance threshold or closest points)
    closest = None
    if len(needed_info["cutoffs"]) > 0: 
        cutoffs = needed_info["cutoffs"].copy() 
        if sum( isinstance(ci, str) for ci in cutoffs ) > 0:
            cutoffs = [-1e-3] # negative so no bond is taken  
            closest = int( needed_info["cutoffs"][0].split("_")[0] ) 

        # points under cutoff = d(i - j) < X 
        cutoffs = torch.tensor(cutoffs, device=device).type(precise)
        dist_mat = torch.cdist(x, x, p=2)

    # normal buckets
    bond_buckets = torch.zeros(*x.shape[:-1], x.shape[-2], device=device).type(precise)
    if len(needed_info["cutoffs"]) > 0 and not closest:
        # count from latest degree of adjacency given
        bond_buckets = torch.bucketize(dist_mat, cutoffs)
        bond_buckets[native_idxs[0], native_idxs[1]] = cutoffs.shape[0]
        # find the indexes - symmetric and we dont want the diag
        bond_buckets   += cutoffs.shape[0] * torch.eye(bond_buckets.shape[0], device=device).long()
        close_bond_idxs = ( bond_buckets < cutoffs.shape[0] ).nonzero().t()
        # move away from poses reserved for native
        bond_buckets[close_bond_idxs[0], close_bond_idxs[1]] += needed_info["adj_degree"]+1

    # the K closest (covalent bonds excluded) are considered bonds 
    elif closest:
        # mask out the diagonal and the covalent bonds before ranking
        masked_dist_mat = dist_mat.clone()
        masked_dist_mat += torch.eye(masked_dist_mat.shape[0], device=device) * torch.amax(masked_dist_mat)
        masked_dist_mat[native_idxs[0], native_idxs[1]] = masked_dist_mat[0,0].clone()
        # sort by distance; *(-1) so the smallest come first.
        # bug fix: the neighbor count is `closest` — the original referenced an
        # undefined name `k` here (NameError at runtime)
        _, sorted_col_idxs = torch.topk(-masked_dist_mat, k=closest, dim=-1)
        # flatten column indices and repeat row indices to match
        sorted_col_idxs = rearrange(sorted_col_idxs[:, :closest], '... n k -> ... (n k)')
        sorted_row_idxs = torch.repeat_interleave( torch.arange(dist_mat.shape[0]).long(), repeats=closest ).to(device)
        close_bond_idxs = torch.stack([ sorted_row_idxs, sorted_col_idxs ], dim=0)
        # move away from poses reserved for native
        bond_buckets = torch.ones_like(dist_mat) * (needed_info["adj_degree"]+1)

    # merge all bonds
    if len(needed_info["cutoffs"]) > 0:
        if close_bond_idxs.shape[0] > 0:
            whole_bond_idxs = torch.cat([native_idxs, close_bond_idxs], dim=-1)
    else:
        whole_bond_idxs = native_idxs

    # 2. ATTRS: encode bonds as attributes
    bond_vecs  = x[ whole_bond_idxs[0] ] - x[ whole_bond_idxs[1] ]
    bond_norms = torch.norm(bond_vecs, dim=-1)
    bond_vecs /= (bond_norms + eps).unsqueeze(-1)
    bond_norms_enc = encode_dist(bond_norms, scales=needed_info["bond_scales"]).squeeze()

    if native:
        bond_buckets[native_idxs[0], native_idxs[1]] = native_attrs
    bond_attrs = bond_buckets[whole_bond_idxs[0] , whole_bond_idxs[1]]
    # pack scalars and vectors - extra token for covalent bonds
    bond_n_vectors = 1
    bond_n_scalars = (2 * len(needed_info["bond_scales"]) + 1) + 1 # last one is an embedding of size 1+len(cutoffs)
    whole_bond_enc = torch.cat([bond_vecs, # 1 vector - no need to reverse - we do 2x bonds (symmetry)
                                # scalars
                                bond_norms_enc, # 2 * len(scales)
                                (bond_attrs-1).unsqueeze(-1) # 1 
                               ], dim=-1) 
    # free GPU memory
    if free_mem:
        # bug fix: the original deleted an undefined name (native_bond_idxs)
        # and unconditionally deleted dist_mat / close_bond_idxs, which do
        # not exist when cutoffs is empty
        del bond_buckets, bond_norms_enc, bond_vecs
        if len(needed_info["cutoffs"]) > 0:
            del dist_mat, close_bond_idxs
        if closest: 
            del masked_dist_mat, sorted_col_idxs, sorted_row_idxs

    embedd_info = {"bond_n_vectors": bond_n_vectors, 
                   "bond_n_scalars": bond_n_scalars, 
                   "bond_embedding_nums": [ len(needed_info["cutoffs"]) + needed_info["adj_degree"] ]} # extra one for covalent (default)

    return whole_bond_idxs, whole_bond_enc, embedd_info
def encode_whole_protein(seq, true_coords, angles, padding_seq,
                         needed_info = { "cutoffs": [2, 5, 10],
                                          "bond_scales": [0.5, 1, 2]}, free_mem=False):
    """ Encodes a whole protein. In points + vectors. """
    device, precise = true_coords.device, true_coords.type()
    #################
    # encode points #
    #################
    # mask of atoms that actually exist; `or None` keeps the full slice
    # when padding_seq == 0
    cloud_mask = torch.tensor(scn_cloud_mask(seq[:-padding_seq or None])).bool().to(device)
    flat_mask = rearrange(cloud_mask, 'l c -> (l c)')

    # general position embedding: center the cloud, split into direction + norm
    center_coords = true_coords - true_coords.mean(dim=0)
    pos_unit_norms = torch.norm(center_coords, dim=-1, keepdim=True)
    pos_unit_vecs  = center_coords / pos_unit_norms
    pos_unit_norms_enc = encode_dist(pos_unit_norms, scales=needed_info["atom_pos_scales"]).squeeze()
    # reformat coordinates to scn (L, 14, 3) - TODO: handle padding == 0
    coords_wrap = rearrange(center_coords, '(l c) d -> l c d', c=14)[:-padding_seq or None] 

    # position-in-chain embedding (fourier features of the residue index)
    aa_pos = encode_dist( torch.arange(len(seq[:-padding_seq or None]), device=device).float(), scales=needed_info["aa_pos_scales"])
    atom_pos = chain2atoms(aa_pos)[cloud_mask]

    # atom identity tokens (embedded later by the model)
    atom_id_embedds = torch.stack([SUPREME_INFO[k]["atom_id_embedd"] for k in seq[:-padding_seq or None]], 
                                  dim=0)[cloud_mask].to(device)
    # amino-acid identity tokens
    seq_int = torch.tensor([AAS2NUM[aa] for aa in seq[:-padding_seq or None]], device=device).long()
    aa_id_embedds   = chain2atoms(seq_int, mask=cloud_mask)

    # CA - sidechain distances: unit vector + fourier-encoded norm
    dist2ca_vec, dist2ca_norm = dist2ca(coords_wrap) 
    dist2ca_norm_enc = encode_dist(dist2ca_norm, scales=needed_info["dist2ca_norm_scales"]).squeeze()

    # backbone features (see orient_aa)
    vecs, norms    = orient_aa(coords_wrap)
    bb_vecs_atoms  = chain2atoms(torch.transpose(vecs, 0, 1), mask=cloud_mask)
    bb_norms_atoms = chain2atoms(torch.transpose(norms, 0, 1), mask=cloud_mask)
    bb_norms_atoms_enc = encode_dist(bb_norms_atoms, scales=[0.5])

    ################
    # encode bonds #
    ################
    bond_info = encode_whole_bonds(x = coords_wrap[cloud_mask],
                                   x_format = "coords",
                                   embedd_info = {},
                                   needed_info = needed_info )
    whole_bond_idxs, whole_bond_enc, bond_embedd_info = bond_info
    #########
    # merge #
    #########

    # concatenate so the final layout is [vector dims | scalar dims]
    point_n_vectors = 1 + 1 + 5
    point_n_scalars = 2*len(needed_info["atom_pos_scales"]) + 1 +\
                      2*len(needed_info["aa_pos_scales"]) + 1 +\
                      2*len(needed_info["dist2ca_norm_scales"]) + 1+\
                      rearrange(bb_norms_atoms_enc, 'atoms feats encs -> atoms (feats encs)').shape[1] +\
                      2 # the last 2 are not embedded yet

    whole_point_enc = torch.cat([ pos_unit_vecs[ :-padding_seq*14 or None ][ flat_mask ], # 1
                                  dist2ca_vec[cloud_mask], # 1
                                  rearrange(bb_vecs_atoms, 'atoms n d -> atoms (n d)'), # 5
                                  # scalars
                                  pos_unit_norms_enc[ :-padding_seq*14 or None ][ flat_mask ], # 2n+1
                                  atom_pos, # 2n+1
                                  dist2ca_norm_enc[cloud_mask], # 2n+1
                                  rearrange(bb_norms_atoms_enc, 'atoms feats encs -> atoms (feats encs)'), # 2n+1
                                  atom_id_embedds.unsqueeze(-1),
                                  aa_id_embedds.unsqueeze(-1) ], dim=-1) # the last 2 are not embedded yet
    if free_mem:
        del pos_unit_vecs, dist2ca_vec, bb_vecs_atoms, pos_unit_norms_enc, cloud_mask,\
            atom_pos, dist2ca_norm_enc, bb_norms_atoms_enc, atom_id_embedds, aa_id_embedds
    # record embedding dimensions: number of point vectors and scalars
    point_embedd_info = {"point_n_vectors": point_n_vectors,
                         "point_n_scalars": point_n_scalars,}

    # merge point and bond embedding info
    embedd_info = {**point_embedd_info, **bond_embedd_info}

    # full point encoding, bond indices, bond encoding and dimension info
    return whole_point_enc, whole_bond_idxs, whole_bond_enc, embedd_info
def get_prot(dataloader_=None, vocab_=None, min_len=80, max_len=150, verbose=True):
    """ Gets a protein from sidechainnet and returns
        the right attrs for training. 
        Inputs: 
        * dataloader_: sidechainnet iterator over dataset
        * vocab_: sidechainnet VOCAB class
        * min_len: int. minimum sequence length
        * max_len: int. maximum sequence length
        * verbose: bool. verbosity level
    """
    # iterate training batches until a suitable sequence is found
    for batch in dataloader_['train']:
        # StopIteration is (ab)used to break out of both loops at once
        try:
            for i in range(batch.int_seqs.shape[0]):
                # decode the sequence and gather its tensors
                seq     = ''.join([vocab_.int2char(aa) for aa in batch.int_seqs[i].numpy()])
                int_seq = batch.int_seqs[i]
                angles  = batch.angs[i]
                mask    = batch.msks[i]
                # measure padding from both the angles and the sequence
                # (token 20 is padding - presumably; verify against vocab_)
                padding_angles = (torch.abs(angles).sum(dim=-1) == 0).long().sum()
                padding_seq    = (batch.int_seqs[i] == 20).sum()
                # only accept sequences with correct dims and no missing coords
                # greater than 0 to avoid negative indexing errors later
                if batch.crds[i].shape[0]//14 == int_seq.shape[0]:
                    if ( max_len > len(seq) and len(seq) > min_len ) and padding_seq == padding_angles: 
                        if verbose:
                            print("stopping at sequence of length", len(seq))
                            # print(len(seq), angles.shape, "paddings: ", padding_seq, padding_angles)
                        # raise to break out of both loops
                        raise StopIteration
                    else:
                        # print("found a seq of length:", len(seq),
                        #        "but oustide the threshold:", min_len, max_len)
                        pass
        except StopIteration:
            # break the outer loop
            break
            
    # NOTE(review): relies on the loop variables (seq, i, angles, ...) leaking
    # out of the loops; if no sequence ever matches, this raises NameError.
    return seq, batch.crds[i], angles, padding_seq, batch.msks[i], batch.pids[i]

GVP - Point Cloud

Geometric Vector Perceptron applied to Point Clouds

To install:

  1. git clone ${repo_url}
  2. install packages:
    • sidechainnet: github.com/jonathankin…
    • joblib, tqdm, numpy, einops, ...
    • torch (was developed using 1.7.1)
    • torch geometric: pytorch-geometric.readthedocs.io/en/latest/n…
    • cd this_repo_folder + pip install . OR pip install geometric-vector-perceptron (but installing from PyPi is not recommended for now - not updated)
    • any other just run: pip install package_name
  3. Try to run the notebooks (they should run, report errors if encountered)
    • proto_dev_model.ipynb: shows how to gather the data and train a simple model on it, then reconstruct original struct and calculate improvement.

Description:

  1. encode a protein (3d) into some features (scalars and position vectors)
    • we encode both point features and edge features
  2. train the model to predict the right point features back
  3. reconstruct the 3d case to see the improvement

TO DO LIST:

See the issues tab?

Contribute

PRs and ideas are welcome. Describe a list of the changes you've made and provide tests/examples if possible (they're not required, but they surely help understanding).

.\lucidrains\geometric-vector-perceptron\examples\scn_data_module.py

# 导入必要的模块
from argparse import ArgumentParser
from typing import List, Optional
from typing import Union

import numpy as np
import pytorch_lightning as pl
import sidechainnet
from sidechainnet.dataloaders.collate import get_collate_fn
from sidechainnet.utils.sequence import ProteinVocabulary
from torch.utils.data import DataLoader, Dataset

# dataset wrapper around a raw sidechainnet dataset
class ScnDataset(Dataset):
    """Wraps a sidechainnet dataset and truncates each sample to `max_len`."""

    def __init__(self, dataset, max_len: int):
        super(ScnDataset, self).__init__()
        self.dataset = dataset

        self.max_len = max_len
        self.scn_collate_fn = get_collate_fn(False)
        self.vocab = ProteinVocabulary()

    def collate_fn(self, batch):
        """Collate a raw batch, then truncate the first sample to max_len."""
        collated = self.scn_collate_fn(batch)
        # decode every integer sequence back into a 1-letter-code string
        decoded = []
        for int_seq in collated.int_seqs.numpy():
            decoded.append("".join(self.vocab.int2char(aa) for aa in int_seq))

        limit = self.max_len
        seq = decoded[0][:limit]
        # reshape to (L, 14, 3), truncate residues, then flatten back
        true_coords = collated.crds[0].view(-1, 14, 3)[:limit].view(-1, 3)
        angles = collated.angs[0, :limit]
        mask = collated.msks[0, :limit]

        # number of padding ("_") positions left in the truncated sequence
        padding_seq = (np.array([*seq]) == "_").sum()
        return {
            "seq": seq,
            "true_coords": true_coords,
            "angles": angles,
            "padding_seq": padding_seq,
            "mask": mask,
        }

    def __getitem__(self, index: int):
        return self.dataset[index]

    def __len__(self) -> int:
        return len(self.dataset)

# lightning data module wrapping the sidechainnet dataloaders
class ScnDataModule(pl.LightningDataModule):
    """Loads sidechainnet and exposes train/val/test dataloaders (batch size 1)."""

    @staticmethod
    def add_data_specific_args(parent_parser):
        """Register the data-related CLI arguments on a child of `parent_parser`."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--casp_version", type=int, default=7)
        parser.add_argument("--scn_dir", type=str, default="./sidechainnet_data")
        parser.add_argument("--train_batch_size", type=int, default=1)
        parser.add_argument("--eval_batch_size", type=int, default=1)
        parser.add_argument("--num_workers", type=int, default=1)
        parser.add_argument("--train_max_len", type=int, default=256)
        parser.add_argument("--eval_max_len", type=int, default=256)

        return parser

    def __init__(
        self,
        casp_version: int = 7,
        scn_dir: str = "./sidechainnet_data",
        train_batch_size: int = 1,
        eval_batch_size: int = 1,
        num_workers: int = 1,
        train_max_len: int = 256,
        eval_max_len: int = 256,
        **kwargs,
    ):
        super().__init__()

        assert train_batch_size == eval_batch_size == 1, "batch size must be 1 for now"

        self.casp_version = casp_version
        self.scn_dir = scn_dir
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.num_workers = num_workers
        self.train_max_len = train_max_len
        self.eval_max_len = eval_max_len

    def setup(self, stage: Optional[str] = None):
        """Load sidechainnet and wrap the train/val/test splits."""
        dataloaders = sidechainnet.load(
            casp_version=self.casp_version,
            scn_dir=self.scn_dir,
            with_pytorch="dataloaders",
        )
        print(
            dataloaders.keys()
        )  # ['train', 'train_eval', 'valid-10', ..., 'valid-90', 'test']

        self.train = ScnDataset(dataloaders["train"].dataset, self.train_max_len)
        self.val = ScnDataset(dataloaders["valid-90"].dataset, self.eval_max_len)
        self.test = ScnDataset(dataloaders["test"].dataset, self.eval_max_len)

    def _make_loader(self, dataset, batch_size: int, shuffle: bool) -> DataLoader:
        """Build a DataLoader bound to `dataset` and its own collate_fn."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=dataset.collate_fn,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        return self._make_loader(self.train, self.train_batch_size, shuffle=True)

    def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        return self._make_loader(self.val, self.eval_batch_size, shuffle=False)

    def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        return self._make_loader(self.test, self.eval_batch_size, shuffle=False)
# smoke test: build the data module and report the size of every loader
if __name__ == "__main__":
    dm = ScnDataModule()
    dm.setup()

    loaders = {
        "train": dm.train_dataloader(),
        "valid": dm.val_dataloader(),
        "test": dm.test_dataloader(),
    }
    for split, loader in loaders.items():
        print(f"{split} length", len(loader))

    # show only the first test batch
    for batch in loaders["test"]:
        print(batch)
        break

.\lucidrains\geometric-vector-perceptron\examples\train_lightning.py

import gc
from argparse import ArgumentParser
from functools import partial
from pathlib import Path
from pprint import pprint
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from einops import rearrange
from loguru import logger
from pytorch_lightning.callbacks import (
    GPUStatsMonitor,
    LearningRateMonitor,
    ModelCheckpoint,
    ProgressBar,
)
from pytorch_lightning.loggers import TensorBoardLogger

from examples.data_handler import kabsch_torch, scn_cloud_mask
from examples.data_utils import (
    encode_whole_bonds,
    encode_whole_protein,
    from_encode_to_pred,
    prot_covalent_bond,
)
from examples.scn_data_module import ScnDataModule
from geometric_vector_perceptron.geometric_vector_perceptron import GVP_Network

# LightningModule that learns to denoise Gaussian-perturbed protein coordinates
# with a GVP message-passing network.
class StructureModel(pl.LightningModule):
    """Coordinate-denoising model.

    The forward pass encodes the clean structure (the target), encodes a
    Gaussian-noised copy (the input), runs the GVP network, and scores the
    prediction with the RMSD after Kabsch alignment.  `loss_base` is the RMSD
    of the noised input itself, i.e. the do-nothing baseline.
    """

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Register model and optimizer hyper-parameters on a CLI parser."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        # model hyper-parameters
        parser.add_argument("--depth", type=int, default=4)
        parser.add_argument("--cutoffs", type=float, default=1.0)
        parser.add_argument("--noise", type=float, default=1.0)
        # optimizer / scheduler hyper-parameters
        parser.add_argument("--init_lr", type=float, default=1e-3)

        return parser

    def __init__(
        self,
        depth: int = 1,
        cutoffs: float = 1.0,
        noise: float = 1.0,
        init_lr: float = 1e-3,
        **kwargs,
    ):
        """
        Args:
            depth: number of GVP message-passing layers.
            cutoffs: neighborhood cutoff used when building graph edges.
            noise: stddev of the Gaussian noise added to the coordinates.
            init_lr: initial learning rate for Adam.
        """
        super().__init__()

        self.save_hyperparameters()

        # settings consumed by the encoding helpers in examples.data_utils
        self.needed_info = {
            "cutoffs": [cutoffs],  # -1e-3 for just covalent, "30_closest", 5. for under 5A, etc
            "bond_scales": [1, 2, 4],
            "aa_pos_scales": [1, 2, 4, 8, 16, 32, 64, 128],
            "atom_pos_scales": [1, 2, 4, 8, 16, 32],
            "dist2ca_norm_scales": [1, 2, 4],
            "bb_norms_atoms": [0.5],  # will encode 3 vectors with this
        }

        self.model = GVP_Network(
            n_layers=depth,
            feats_x_in=48,
            vectors_x_in=7,
            feats_x_out=48,
            vectors_x_out=7,
            feats_edge_in=8,
            vectors_edge_in=1,
            feats_edge_out=8,
            vectors_edge_out=1,
            embedding_nums=[36, 20],
            embedding_dims=[16, 16],
            edge_embedding_nums=[2],
            edge_embedding_dims=[2],
            residual=True,
            recalc=1,
        )

        self.noise = noise
        self.init_lr = init_lr

        self.baseline_losses = []  # per-batch RMSD of the raw noised input
        self.epoch_losses = []     # per-batch training RMSD

    def forward(self, seq, true_coords, angles, padding_seq, mask):
        """Noise, encode, predict and score one protein.

        Returns:
            dict with `loss` (model RMSD vs target) and `loss_base`
            (RMSD of the noised input vs target — the baseline).
        """
        needed_info = self.needed_info
        device = true_coords.device

        # strip padding; `(-0) or None` selects the whole sequence
        needed_info["seq"] = seq[: (-padding_seq) or None]
        needed_info["covalent_bond"] = prot_covalent_bond(needed_info["seq"])

        # encode the clean structure - this is the regression target
        pre_target = encode_whole_protein(
            seq,
            true_coords,
            angles,
            padding_seq,
            needed_info=needed_info,
            free_mem=True,
        )
        pre_target_x, _, _, embedd_info = pre_target

        # encode a Gaussian-noised copy - this is the model input
        encoded = encode_whole_protein(
            seq,
            true_coords + self.noise * torch.randn_like(true_coords),
            angles,
            padding_seq,
            needed_info=needed_info,
            free_mem=True,
        )

        x, edge_index, edge_attrs, embedd_info = encoded

        # single-protein batch: every node belongs to graph 0
        batch = torch.tensor([0 for i in range(x.shape[0])], device=x.device).long()

        # masks selecting real atoms of real (non-padding) residues
        cloud_mask = scn_cloud_mask(seq[: (-padding_seq) or None]).to(device)
        chain_mask = mask[: (-padding_seq) or None].unsqueeze(-1) * cloud_mask
        flat_chain_mask = rearrange(chain_mask.bool(), "l c -> (l c)")
        cloud_mask = cloud_mask.bool()
        flat_cloud_mask = rearrange(cloud_mask, "l c -> (l c)")

        # hook used by the network to rebuild edges between MPNN layers
        recalc_edge = partial(
            encode_whole_bonds,
            x_format="encode",
            embedd_info=embedd_info,
            needed_info=needed_info,
            free_mem=True,
        )

        # predict
        scores = self.model.forward(
            x,
            edge_index,
            batch=batch,
            edge_attr=edge_attrs,
            recalc_edge=recalc_edge,
            verbose=False,
        )

        # convert prediction, baseline and target to a comparable format
        target = from_encode_to_pred(
            pre_target_x, embedd_info=embedd_info, needed_info=needed_info
        )
        pred = from_encode_to_pred(
            scores, embedd_info=embedd_info, needed_info=needed_info
        )
        base = from_encode_to_pred(x, embedd_info=embedd_info, needed_info=needed_info)

        # measure error

        # option 1: MSE on the output tokens
        # loss_ = (target-pred)**2
        # loss  = loss_.mean()

        # option 2: RMSD of the reconstructed coords (norm * unit direction)
        target_coords = target[:, 3:4] * target[:, :3]
        pred_coords = pred[:, 3:4] * pred[:, :3]
        base_coords = base[:, 3:4] * base[:, :3]

        ## align - the SVD inside kabsch occasionally fails to converge
        try:
            pred_aligned, target_aligned = kabsch_torch(pred_coords.t(), target_coords.t())  # (3, N)
            base_aligned, _ = kabsch_torch(base_coords.t(), target_coords.t())
            loss = ((pred_aligned.t() - target_aligned.t())[flat_chain_mask[flat_cloud_mask]] ** 2).mean() ** 0.5
            loss_base = ((base_aligned.t() - target_aligned.t())[flat_chain_mask[flat_cloud_mask]] ** 2).mean() ** 0.5
        except Exception:
            # fall back to unaligned RMSD.  (The original handler printed an
            # undefined variable `ep`, raising a NameError inside the except.)
            pred_aligned, target_aligned = None, None
            print("svd failed convergence")
            loss = ((pred_coords.t() - target_coords.t())[flat_chain_mask[flat_cloud_mask]] ** 2).mean() ** 0.5
            loss_base = ((base_coords - target_coords)[flat_chain_mask[flat_cloud_mask]] ** 2).mean() ** 0.5

        # eagerly free GPU memory - the encodings are large
        del true_coords, angles, pre_target_x, edge_index, edge_attrs
        del scores, target_coords, pred_coords, base_coords
        del encoded, pre_target, target_aligned, pred_aligned
        gc.collect()

        return {"loss": loss, "loss_base": loss_base}

    def configure_optimizers(self):
        """Plain Adam with the configured initial learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.init_lr)
        return optimizer

    def on_train_start(self) -> None:
        # reset the loss histories collected for the final plot
        self.baseline_losses = []
        self.epoch_losses = []

    def training_step(self, batch, batch_idx):
        """One optimization step; returns None to skip batches with NaN loss."""
        output = self.forward(**batch)
        loss = output["loss"]
        loss_base = output["loss_base"]

        # skip unusable batches (e.g. SVD failures producing NaNs)
        if loss is None or torch.isnan(loss):
            return None

        self.epoch_losses.append(loss.item())
        self.baseline_losses.append(loss_base.item())

        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        self.log("train_loss_base", output["loss_base"], on_epoch=True, prog_bar=False)

        return loss

    def on_train_end(self) -> None:
        """Plot the smoothed per-batch loss against the noise baseline."""
        plt.figure(figsize=(15, 6))
        plt.title(
            f"Loss Evolution - Denoising of Gaussian-masked Coordinates (mu=0, sigma={self.noise})"
        )

        # sliding-window means at several window sizes
        # (the original had unbalanced parentheses here -> SyntaxError; fixed)
        for window in [8, 16, 32]:
            warmup = [
                np.mean(self.epoch_losses[:window][0 : i + 1])
                for i in range(min(window, len(self.epoch_losses)))
            ]
            rolling = [
                np.mean(self.epoch_losses[i : i + window + 1])
                for i in range(len(self.epoch_losses) - window)
            ]
            plt.plot(warmup + rolling, label="Window mean n={0}".format(window))

        # dashed horizontal line: mean RMSD of the raw noised input
        plt.plot(
            np.ones(len(self.epoch_losses)) * np.mean(self.baseline_losses),
            "k--",
            label="Baseline",
        )

        plt.xlim(-0.01 * len(self.epoch_losses), 1.01 * len(self.epoch_losses))
        plt.ylabel("RMSD")
        plt.xlabel("Batch number")
        plt.legend()
        plt.savefig("loss.pdf")

    def validation_step(self, batch, batch_idx):
        """Log validation losses (model and baseline)."""
        output = self.forward(**batch)
        self.log("val_loss", output["loss"], on_epoch=True, sync_dist=True)
        self.log("val_loss_base", output["loss_base"], on_epoch=True, sync_dist=True)

    def test_step(self, batch, batch_idx):
        """Log test losses (model and baseline)."""
        output = self.forward(**batch)
        self.log("test_loss", output["loss"], on_epoch=True, sync_dist=True)
        self.log("test_loss_base", output["loss_base"], on_epoch=True, sync_dist=True)
# Build a configured PyTorch-Lightning Trainer from parsed CLI arguments.
def get_trainer(args):
    """Seed all RNGs, wire up loggers and callbacks, and build the Trainer."""
    pl.seed_everything(args.seed)

    # logging: one TensorBoard logger under <default_root_dir>/tb
    root_dir = Path(args.default_root_dir).expanduser().resolve()
    root_dir.mkdir(parents=True, exist_ok=True)
    tb_save_dir = root_dir / "tb"
    loggers = [TensorBoardLogger(save_dir=tb_save_dir)]
    logger.info(f"Run tensorboard --logdir {tb_save_dir}")

    # callbacks: LR monitoring, progress bar, checkpointing, GPU stats
    callbacks = [
        LearningRateMonitor(logging_interval="step"),
        ProgressBar(refresh_rate=args.progress_bar_refresh_rate),
        ModelCheckpoint(verbose=True),
        GPUStatsMonitor(),
    ]

    trainer = pl.Trainer.from_argparse_args(
        args, logger=loggers, callbacks=callbacks, plugins=[]
    )

    return trainer


def main(args):
    """Train a StructureModel on sidechainnet data, then evaluate it."""
    hparams = vars(args)
    dm = ScnDataModule(**hparams)
    model = StructureModel(**hparams)
    trainer = get_trainer(args)

    trainer.fit(model, datamodule=dm)
    metrics = trainer.test(model, datamodule=dm)
    print("test", metrics)


if __name__ == "__main__":
    # Assemble the CLI from shared, model, data and trainer argument groups.
    parser = ArgumentParser()
    parser.add_argument("--seed", type=int, default=23333, help="Seed everything.")
    parser = StructureModel.add_model_specific_args(parser)
    parser = ScnDataModule.add_data_specific_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)

    args = parser.parse_args()
    pprint(vars(args))
    main(args)

.\lucidrains\geometric-vector-perceptron\examples\__init__.py

def calculate_area(length, width):
    """Return the area of a rectangle given its side lengths."""
    return length * width

.\lucidrains\geometric-vector-perceptron\geometric_vector_perceptron\geometric_vector_perceptron.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 torch_geometric.nn 模块中导入 MessagePassing 类
from torch_geometric.nn import MessagePassing

# types

# 导入类型提示相关的模块和类型
from typing import Optional, List, Union
from torch_geometric.typing import OptPairTensor, Adj, Size, OptTensor, Tensor

# helper functions

# A value "exists" when it is anything other than None.
def exists(val):
    """Return True if `val` is not None."""
    return val is not None

# classes

# Geometric Vector Perceptron module (https://openreview.net/forum?id=1YLJDvSx6J4)
class GVP(nn.Module):
    """Geometric Vector Perceptron.

    Maps a tuple (scalar_feats, vector_feats) to a new tuple of the same
    structure.  Scalar outputs mix scalar inputs with the (rotation-invariant)
    norms of linearly-combined vectors; vector outputs are linear combinations
    of input vectors, gated by a scalar signal.
    """
    def __init__(
        self,
        *,
        dim_vectors_in,
        dim_vectors_out,
        dim_feats_in,
        dim_feats_out,
        feats_activation = nn.Sigmoid(),
        vectors_activation = nn.Sigmoid(),
        vector_gating = False
    ):
        super().__init__()
        self.dim_vectors_in = dim_vectors_in
        self.dim_feats_in = dim_feats_in

        self.dim_vectors_out = dim_vectors_out
        # hidden vector-channel count: large enough for both input and output
        dim_h = max(dim_vectors_in, dim_vectors_out)

        # weight matrices acting on the vector-channel dimension only
        # (they never mix the 3 spatial components, preserving equivariance)
        self.Wh = nn.Parameter(torch.randn(dim_vectors_in, dim_h))
        self.Wu = nn.Parameter(torch.randn(dim_h, dim_vectors_out))

        self.vectors_activation = vectors_activation

        # scalar branch: consumes input scalars + norms of the hidden vectors
        self.to_feats_out = nn.Sequential(
            nn.Linear(dim_h + dim_feats_in, dim_feats_out),
            feats_activation
        )

        # optional vector gating: gate vector outputs with a linear function
        # of the scalar outputs instead of the vectors' own norms
        self.scalar_to_vector_gates = nn.Linear(dim_feats_out, dim_vectors_out) if vector_gating else None

    def forward(self, data):
        """data: (feats (b, n), vectors (b, v, 3)) -> (feats_out, vectors_out)."""
        feats, vectors = data
        b, n, _, v, c  = *feats.shape, *vectors.shape

        # sanity-check input dimensions
        assert c == 3 and v == self.dim_vectors_in, 'vectors have wrong dimensions'
        assert n == self.dim_feats_in, 'scalar features have wrong dimensions'

        # project vector channels: input -> hidden -> output
        Vh = einsum('b v c, v h -> b h c', vectors, self.Wh)
        Vu = einsum('b h c, h u -> b u c', Vh, self.Wu)

        # rotation-invariant norms of the hidden vectors
        sh = torch.norm(Vh, p = 2, dim = -1)

        # scalar path sees the input scalars and the vector norms
        s = torch.cat((feats, sh), dim = 1)

        feats_out = self.to_feats_out(s)

        # gate vectors by learned scalar gates, or by their own norms
        if exists(self.scalar_to_vector_gates):
            gating = self.scalar_to_vector_gates(feats_out)
            gating = gating.unsqueeze(dim = -1)
        else:
            gating = torch.norm(Vu, p = 2, dim = -1, keepdim = True)

        vectors_out = self.vectors_activation(gating) * Vu
        return (feats_out, vectors_out)

# Dropout wrapper that treats the two GVP streams independently.
class GVPDropout(nn.Module):
    """Separate dropout for scalar features and vector features.

    Vector features use Dropout2d so entire vectors are zeroed together.
    """
    def __init__(self, rate):
        super().__init__()
        self.feat_dropout = nn.Dropout(rate)
        self.vector_dropout = nn.Dropout2d(rate)

    def forward(self, feats, vectors):
        dropped_feats = self.feat_dropout(feats)
        dropped_vectors = self.vector_dropout(vectors)
        return dropped_feats, dropped_vectors

# Normalization wrapper for the two GVP streams.
class GVPLayerNorm(nn.Module):
    """Learnable LayerNorm for scalars; non-trainable norm for vectors.

    Vectors are divided by their joint L2 norm over the channel and spatial
    dimensions (plus `eps` for numerical stability).
    """
    def __init__(self, feats_h_size, eps = 1e-8):
        super().__init__()
        self.eps = eps
        self.feat_norm = nn.LayerNorm(feats_h_size)

    def forward(self, feats, vectors):
        magnitude = vectors.norm(dim=(-1, -2), keepdim=True)
        scaled_feats = self.feat_norm(feats)
        scaled_vectors = vectors / (magnitude + self.eps)
        return scaled_feats, scaled_vectors

# Message-passing layer whose message/update MLPs are replaced with GVPs.
class GVP_MPNN(MessagePassing):
    r"""The Geometric Vector Perceptron message passing layer
        introduced in https://openreview.net/forum?id=1YLJDvSx6J4.
        
        Uses a Geometric Vector Perceptron instead of the normal 
        MLP in aggregation phase.

        Inputs will be a concatenation of (vectors, features)

        Args:
        * feats_x_in: int. number of scalar dimensions in the x inputs.
        * vectors_x_in: int. number of vector dimensions in the x inputs.
        * feats_x_out: int. number of scalar dimensions in the x outputs.
        * vectors_x_out: int. number of vector dimensions in the x outputs.
        * feats_edge_in: int. number of scalar dimensions in the edge_attr inputs.
        * vectors_edge_in: int. number of vector dimensions in the edge_attr inputs.
        * feats_edge_out: int. number of scalar dimensions in the edge_attr outputs.
        * vectors_edge_out: int. number of vector dimensions in the edge_attr outputs.
        * dropout: float. dropout rate.
        * vector_dim: int. dimensions of the space containing the vectors.
        * verbose: bool. verbosity level.
    """
    # Constructor: records channel counts and builds the GVP sub-networks.
    def __init__(self, feats_x_in, vectors_x_in,
                       feats_x_out, vectors_x_out,
                       feats_edge_in, vectors_edge_in,
                       feats_edge_out, vectors_edge_out,
                       dropout, residual=False, vector_dim=3, 
                       verbose=False, **kwargs):
        # aggregate neighbor messages by mean
        super(GVP_MPNN, self).__init__(aggr="mean",**kwargs)
        # verbosity flag
        self.verbose = verbose
        # scalar / vector channel counts for node features
        self.feats_x_in    = feats_x_in 
        self.vectors_x_in  = vectors_x_in # number of vector features in input
        self.feats_x_out   = feats_x_out 
        self.vectors_x_out = vectors_x_out # number of vector features in output
        # scalar / vector channel counts for edge attributes
        self.feats_edge_in    = feats_edge_in 
        self.vectors_edge_in  = vectors_edge_in # number of vector features in input
        self.feats_edge_out   = feats_edge_out 
        self.vectors_edge_out = vectors_edge_out # number of vector features in output
        # auxiliary layers
        self.vector_dim = vector_dim
        # two norm layers: after aggregation and after the feedforward update
        self.norm = nn.ModuleList([GVPLayerNorm(self.feats_x_out), # + self.feats_edge_out
                                   GVPLayerNorm(self.feats_x_out)])
        self.dropout = GVPDropout(dropout)
        self.residual = residual
        # message network over concatenated (node, edge) features.
        # NOTE(review): the vector dims below mix `feats_edge_out` into the
        # vector-channel counts (e.g. vectors_x_out + feats_edge_out); this
        # may have been intended as `vectors_edge_out` — confirm against the
        # reference implementation before changing, as it alters all shapes.
        self.W_EV = nn.Sequential(GVP(
                                      dim_vectors_in = self.vectors_x_in + self.vectors_edge_in, 
                                      dim_vectors_out = self.vectors_x_out + self.feats_edge_out,
                                      dim_feats_in = self.feats_x_in + self.feats_edge_in, 
                                      dim_feats_out = self.feats_x_out + self.feats_edge_out
                                  ), 
                                  GVP(
                                      dim_vectors_in = self.vectors_x_out + self.feats_edge_out, 
                                      dim_vectors_out = self.vectors_x_out + self.feats_edge_out,
                                      dim_feats_in = self.feats_x_out + self.feats_edge_out,
                                      dim_feats_out = self.feats_x_out + self.feats_edge_out
                                  ),
                                  GVP(
                                      dim_vectors_in = self.vectors_x_out + self.feats_edge_out, 
                                      dim_vectors_out = self.vectors_x_out + self.feats_edge_out,
                                      dim_feats_in = self.feats_x_out + self.feats_edge_out,
                                      dim_feats_out = self.feats_x_out + self.feats_edge_out
                                  ))
        
        # position-wise feedforward (expand then contract), applied per node
        self.W_dh = nn.Sequential(GVP(
                                      dim_vectors_in = self.vectors_x_out,
                                      dim_vectors_out = 2*self.vectors_x_out,
                                      dim_feats_in = self.feats_x_out,
                                      dim_feats_out = 4*self.feats_x_out
                                  ),
                                  GVP(
                                      dim_vectors_in = 2*self.vectors_x_out,
                                      dim_vectors_out = self.vectors_x_out,
                                      dim_feats_in = 4*self.feats_x_out,
                                      dim_feats_out = self.feats_x_out
                                  ))
    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                edge_attr: OptTensor = None, size: Size = None) -> Tensor:
        """Run one round of message passing and update the node features.

        `x` packs each node as [flattened vectors | scalar feats]; the same
        layout is returned.
        """
        # width of the packed node representation (currently unused)
        x_size = list(x.shape)[-1]
        # aggregate scalar features and vectors separately
        feats, vectors = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        # dropout on the aggregated message, vectors reshaped to (n, ch, dim)
        feats, vectors = self.dropout(feats, vectors.reshape(vectors.shape[0], -1, self.vector_dim))
        # keep only the node-related slice - edge channels are not returned
        feats_nodes  = feats[:, :self.feats_x_in]
        vector_nodes = vectors[:, :self.vectors_x_in]
        # unflatten the vector part of x back to (n, channels, vector_dim)
        x_vectors    = x[:, :self.vectors_x_in * self.vector_dim].reshape(x.shape[0], -1, self.vector_dim)
        # residual add of the aggregated message, then norm
        feats, vectors = self.norm[0]( x[:, self.vectors_x_in * self.vector_dim:]+feats_nodes, x_vectors+vector_nodes )
        # position-wise feedforward update with dropout, then norm
        feats_, vectors_ = self.dropout( *self.W_dh( (feats, vectors) ) )
        feats, vectors   = self.norm[1]( feats+feats_, vectors+vectors_ )
        # repack: [flattened vectors | scalar feats]
        new_x = torch.cat( [feats, vectors.flatten(start_dim=-2)], dim=-1 )
        if self.residual:
          return new_x + x
        return new_x


    def message(self, x_j, edge_attr) -> Tensor:
        """Build the per-edge message from source node x_j and edge_attr."""
        # concatenate scalar parts of node and edge representations
        feats   = torch.cat([ x_j[:, self.vectors_x_in * self.vector_dim:],
                              edge_attr[:, self.vectors_edge_in * self.vector_dim:] ], dim=-1)
        # concatenate vector parts, then unflatten to (n_edges, ch, vector_dim)
        vectors = torch.cat([ x_j[:, :self.vectors_x_in * self.vector_dim], 
                              edge_attr[:, :self.vectors_edge_in * self.vector_dim] ], dim=-1).reshape(x_j.shape[0],-1,self.vector_dim)
        feats, vectors = self.W_EV( (feats, vectors) )
        # vectors are flattened again so the base class can aggregate them
        return feats, vectors.flatten(start_dim=-2)


    def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
        r"""The initial call to start propagating messages.

        Overridden so that scalar features and vectors are aggregated
        separately (the stock implementation aggregates a single tensor).

        Args:
            adj (Tensor or SparseTensor): `edge_index` holds the indices of a general (sparse)
                assignment matrix of shape :obj:`[N, M]`.
            size (tuple, optional): If set to :obj:`None`, the size will be automatically inferred
                and assumed to be quadratic.
                This argument is ignored in case :obj:`edge_index` is a
                :obj:`torch_sparse.SparseTensor`. (default: :obj:`None`)
            **kwargs: Any additional data which is needed to construct and
                aggregate messages, and to update node embeddings.

        NOTE(review): relies on MessagePassing internals (`__check_input__`,
        `__collect__`, `inspector`) — tied to a specific torch_geometric
        version; verify when upgrading PyG.
        """
        size = self.__check_input__(edge_index, size)
        coll_dict = self.__collect__(self.__user_args__,
                                     edge_index, size, kwargs)
        msg_kwargs = self.inspector.distribute('message', coll_dict)
        feats, vectors = self.message(**msg_kwargs)
        # aggregate the two streams independently
        aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
        out_feats   = self.aggregate(feats, **aggr_kwargs)
        out_vectors = self.aggregate(vectors, **aggr_kwargs)
        # return as a tuple
        update_kwargs = self.inspector.distribute('update', coll_dict)
        return self.update((out_feats, out_vectors), **update_kwargs)

        
    def __repr__(self):
        """Summarize the layer's channel configuration."""
        dict_print = { "feats_x_in": self.feats_x_in,
                       "vectors_x_in": self.vectors_x_in,
                       "feats_x_out": self.feats_x_out,
                       "vectors_x_out": self.vectors_x_out,
                       "feats_edge_in": self.feats_edge_in,
                       "vectors_edge_in": self.vectors_edge_in,
                       "feats_edge_out": self.feats_edge_out,
                       "vectors_edge_out": self.vectors_edge_out,
                       "vector_dim": self.vector_dim }
        return  'GVP_MPNN Layer with the following attributes: ' + str(dict_print)
class GVP_Network(nn.Module):
    r"""Sample GNN model architecture that uses the Geometric Vector Perceptron
        message passing layer to learn over point clouds. 
        Main MPNN layer introduced in https://openreview.net/forum?id=1YLJDvSx6J4.

        Inputs will be standard GNN: x, edge_index, edge_attr, batch, ...

        Args:
        * n_layers: int. number of MPNN layers
        * feats_x_in: int. number of scalar dimensions in the x inputs.
        * vectors_x_in: int. number of vector dimensions in the x inputs.
        * feats_x_out: int. number of scalar dimensions in the x outputs.
        * vectors_x_out: int. number of vector dimensions in the x outputs.
        * feats_edge_in: int. number of scalar dimensions in the edge_attr inputs.
        * vectors_edge_in: int. number of vector dimensions in the edge_attr inputs.
        * feats_edge_out: int. number of scalar dimensions in the edge_attr outputs.
        * embedding_nums: list. number of unique keys to embedd. for points
                          1 entry per embedding needed. 
        * embedding_dims: list. point - number of dimensions of
                          the resulting embedding. 1 entry per embedding needed. 
        * edge_embedding_nums: list. number of unique keys to embedd. for edges.
                               1 entry per embedding needed. 
        * edge_embedding_dims: list. point - number of dimensions of
                               the resulting embedding. 1 entry per embedding needed. 
        * vectors_edge_out: int. number of vector dimensions in the edge_attr outputs.
        * dropout: float. dropout rate.
        * vector_dim: int. dimensions of the space containing the vectors.
        * recalc: int. recalculate edge features every `recalc` MPNN layers.
        * verbose: bool. verbosity level.
    """
    def __init__(self, n_layers,
                       feats_x_in, vectors_x_in,
                       feats_x_out, vectors_x_out,
                       feats_edge_in, vectors_edge_in,
                       feats_edge_out, vectors_edge_out,
                       embedding_nums=[], embedding_dims=[],
                       edge_embedding_nums=[], edge_embedding_dims=[],
                       dropout=0.0, residual=False, vector_dim=3,
                       recalc=1, verbose=False):
        super().__init__()

        # NOTE: the mutable list defaults are only ever read, never mutated,
        # so the shared-default pitfall does not apply here.
        self.n_layers            = n_layers
        self.embedding_nums      = embedding_nums
        self.embedding_dims      = embedding_dims
        self.emb_layers          = torch.nn.ModuleList()
        self.edge_embedding_nums = edge_embedding_nums
        self.edge_embedding_dims = edge_embedding_dims
        self.edge_emb_layers     = torch.nn.ModuleList()

        # instantiate node embeddings; each categorical column (1 wide) is
        # replaced by its embedding, hence the `dim - 1` width adjustment
        for i in range(len(self.embedding_dims)):
            self.emb_layers.append(nn.Embedding(num_embeddings=embedding_nums[i],
                                                embedding_dim=embedding_dims[i]))
            feats_x_in  += embedding_dims[i] - 1
            feats_x_out += embedding_dims[i] - 1
        # same for edge embeddings
        for i in range(len(self.edge_embedding_dims)):
            self.edge_emb_layers.append(nn.Embedding(num_embeddings=edge_embedding_nums[i],
                                                     embedding_dim=edge_embedding_dims[i]))
            feats_edge_in  += edge_embedding_dims[i] - 1
            feats_edge_out += edge_embedding_dims[i] - 1

        self.fc_layers        = torch.nn.ModuleList()
        self.gcnn_layers      = torch.nn.ModuleList()
        self.feats_x_in       = feats_x_in
        self.vectors_x_in     = vectors_x_in
        self.feats_x_out      = feats_x_out
        self.vectors_x_out    = vectors_x_out
        self.feats_edge_in    = feats_edge_in
        self.vectors_edge_in  = vectors_edge_in
        self.feats_edge_out   = feats_edge_out
        self.vectors_edge_out = vectors_edge_out
        self.dropout          = dropout
        self.residual         = residual
        self.vector_dim       = vector_dim
        self.recalc           = recalc
        self.verbose          = verbose

        # instantiate the stack of message-passing layers
        for i in range(n_layers):
            layer = GVP_MPNN(feats_x_in, vectors_x_in,
                             feats_x_out, vectors_x_out,
                             feats_edge_in, vectors_edge_in,
                             feats_edge_out, vectors_edge_out,
                             dropout, residual=residual,
                             vector_dim=vector_dim, verbose=verbose)
            self.gcnn_layers.append(layer)

    def forward(self, x, edge_index, batch, edge_attr,
                bsize=None, recalc_edge=None, verbose=0):
        """ Embedding of inputs when necessary, then pass layers.
            Recalculate edge features every `recalc` layers with the
            `recalc_edge` function.
        """
        # keep pristine copies so edge info can be restored on layers
        # where it is not recalculated
        original_edge_index = edge_index.clone()
        original_edge_attr = edge_attr.clone()

        # embed the categorical node columns (trailing columns of x)
        to_embedd = x[:, -len(self.embedding_dims):].long()
        for i, emb_layer in enumerate(self.emb_layers):
            # first iteration drops the raw categorical columns; afterwards
            # everything already concatenated is kept
            stop_concat = -len(self.embedding_dims) if i == 0 else x.shape[-1]
            x = torch.cat([x[:, :stop_concat],
                           emb_layer(to_embedd[:, i])], dim=-1)

        for i, layer in enumerate(self.gcnn_layers):
            # embed categorical edge columns (edge_attr is rebuilt each pass).
            # Fix: use the same drop-then-append scheme as the node loop —
            # the original indexed `[:-len(dims) + j]`, which cuts into a
            # previously appended embedding when more than one is configured.
            to_embedd = edge_attr[:, -len(self.edge_embedding_dims):].long()
            for j, edge_emb_layer in enumerate(self.edge_emb_layers):
                stop_concat = -len(self.edge_embedding_dims) if j == 0 else edge_attr.shape[-1]
                edge_attr = torch.cat([edge_attr[:, :stop_concat],
                                       edge_emb_layer(to_embedd[:, j])], dim=-1)

            # message passing
            x = layer(x, edge_index, edge_attr, size=bsize)

            # recalculate edge info every `recalc` layers, but not after the
            # last layer.  Fix: the original tested `1 % self.recalc == 0`,
            # which only ever recalculated when recalc == 1.
            if ((i + 1) % self.recalc == 0) and not (i == self.n_layers - 1):
                edge_index, edge_attr, _ = recalc_edge(x)  # returns index, attrs, embedd_info
            else:
                edge_attr = original_edge_attr.clone()
                edge_index = original_edge_index.clone()

            if verbose:
                print("========")
                # Fix: the original printed the inner-loop index `j`, which is
                # stale (or undefined when there are no edge embeddings).
                print("layer:", i, "nlinks:", edge_attr.shape)

        return x

    def __repr__(self):
        return 'GVP_Network of: {0} layers'.format(len(self.gcnn_layers))

.\lucidrains\geometric-vector-perceptron\geometric_vector_perceptron\__init__.py

# 从 geometric_vector_perceptron 模块中导入 GVP, GVPDropout, GVPLayerNorm, GVP_MPNN, GVP_Network 类
from geometric_vector_perceptron.geometric_vector_perceptron import GVP, GVPDropout, GVPLayerNorm, GVP_MPNN, GVP_Network

Geometric Vector Perceptron

Implementation of Geometric Vector Perceptron, a simple circuit with 3d rotation equivariance for learning over large biomolecules, in Pytorch. The repository may also contain experimentation to see if this could be easily extended to self-attention.

Install

$ pip install geometric-vector-perceptron

Functionality

  • GVP: Implementing the basic geometric vector perceptron.
  • GVPDropout: Adapted dropout for GVP in MPNN context
  • GVPLayerNorm: Adapted LayerNorm for GVP in MPNN context
  • GVP_MPNN: Adapted instance of Message Passing class from torch-geometric package. Still not tested.
  • GVP_Network: Functional model architecture ready for working with arbitrary point clouds.

Usage

import torch
from geometric_vector_perceptron import GVP

model = GVP(
    dim_vectors_in = 1024,
    dim_feats_in = 512,
    dim_vectors_out = 256,
    dim_feats_out = 512,
    vector_gating = True   # use the vector gating as proposed in https://arxiv.org/abs/2106.03843
)

feats, vectors = (torch.randn(1, 512), torch.randn(1, 1024, 3))

feats_out, vectors_out = model( (feats, vectors) ) # (1, 256), (1, 512, 3)

With the specialized dropout and layernorm as described in the paper

import torch
from torch import nn
from geometric_vector_perceptron import GVP, GVPDropout, GVPLayerNorm

model = GVP(
    dim_vectors_in = 1024,
    dim_feats_in = 512,
    dim_vectors_out = 256,
    dim_feats_out = 512,
    vector_gating = True
)

dropout = GVPDropout(0.2)
norm = GVPLayerNorm(512)

feats, vectors = (torch.randn(1, 512), torch.randn(1, 1024, 3))

feats, vectors = model( (feats, vectors) )
feats, vectors = dropout(feats, vectors)
feats, vectors = norm(feats, vectors)  # (1, 256), (1, 512, 3)

TF implementation:

The original implementation in TF by the paper authors can be found here: github.com/drorlab/gvp…

Citations

@inproceedings{anonymous2021learning,
    title   = {Learning from Protein Structure with Geometric Vector Perceptrons},
    author  = {Anonymous},
    booktitle = {Submitted to International Conference on Learning Representations},
    year    = {2021},
    url     = {https://openreview.net/forum?id=1YLJDvSx6J4}
}
@misc{jing2021equivariant,
    title   = {Equivariant Graph Neural Networks for 3D Macromolecular Structure}, 
    author  = {Bowen Jing and Stephan Eismann and Pratham N. Soni and Ron O. Dror},
    year    = {2021},
    eprint  = {2106.03843},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

.\lucidrains\geometric-vector-perceptron\setup.py

# Build configuration for the geometric-vector-perceptron distribution.
from setuptools import setup, find_packages

# All package metadata gathered in one mapping, then unpacked into setup().
PACKAGE_INFO = dict(
    name = 'geometric-vector-perceptron',                      # distribution name on PyPI
    packages = find_packages(),                                # auto-discover sub-packages
    version = '0.0.14',                                        # release version
    license = 'MIT',                                           # license identifier
    description = 'Geometric Vector Perceptron - Pytorch',     # short summary
    author = 'Phil Wang, Eric Alcaide',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/geometric-vector-perceptron',
    # search keywords for the package index
    keywords = [
        'artificial intelligence',
        'deep learning',
        'proteins',
        'biomolecules',
        'equivariance',
    ],
    # runtime dependencies
    install_requires = [
        'torch>=1.6',
        'torch-scatter',
        'torch-sparse',
        'torch-cluster',
        'torch-spline-conv',
        'torch-geometric',
    ],
    # build-time dependencies
    setup_requires = [
        'pytest-runner',
    ],
    # test-time dependencies
    tests_require = [
        'pytest',
    ],
    # trove classifiers
    classifiers = [
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)

setup(**PACKAGE_INFO)

.\lucidrains\geometric-vector-perceptron\tests\tests.py

# 导入 torch 库
import torch
# 从 geometric_vector_perceptron 库中导入 GVP, GVPDropout, GVPLayerNorm, GVP_MPNN
from geometric_vector_perceptron import GVP, GVPDropout, GVPLayerNorm, GVP_MPNN

# 定义容差值
TOL = 1e-2

# Generate a random rotation matrix
def random_rotation():
    """Return a random 3x3 orthonormal (rotation) matrix via QR decomposition.

    Note: torch.qr is deprecated (removed in recent PyTorch releases);
    torch.linalg.qr is the supported replacement and returns the same Q.
    """
    q, r = torch.linalg.qr(torch.randn(3, 3))
    return q

# Pairwise difference matrix between points
def diff_matrix(vectors):
    """Compute all pairwise differences of points, flattened over the pair axes.

    Input: (b, n, d) -> output: (b, n*n, d) where entry (i*n + j) is v_i - v_j.
    """
    batch, num_points, dim = vectors.shape
    left = vectors.unsqueeze(-2)    # (b, n, 1, d)
    right = vectors.unsqueeze(1)    # (b, 1, n, d)
    pairwise = left - right         # broadcast to (b, n, n, d)
    return pairwise.reshape(batch, num_points * num_points, dim)

# Equivariance test: rotating the inputs must rotate the outputs identically
def test_equivariance():
    """Check f(Rx) == R f(x) for the vector channel of GVP, up to TOL."""
    rotation = random_rotation()

    gvp = GVP(
        dim_vectors_in = 1024,
        dim_feats_in = 512,
        dim_vectors_out = 256,
        dim_feats_out = 512
    )

    scalar_feats = torch.randn(1, 512)
    coords = torch.randn(1, 32, 3)

    # run the model on original and pre-rotated point clouds
    _, out_vectors = gvp( (scalar_feats, diff_matrix(coords)) )
    _, out_vectors_rotated = gvp( (scalar_feats, diff_matrix(coords @ rotation)) )

    # rotating the outputs of the unrotated pass must match the rotated pass
    err = ((out_vectors @ rotation) - out_vectors_rotated).max()
    assert err < TOL, 'equivariance must be respected'

# Smoke-test every layer type: GVP, GVPDropout, GVPLayerNorm
def test_all_layer_types():
    """Check that each layer preserves the expected feature/vector shapes.

    Fix: the original version created a rotation matrix `R` and a `message`
    tensor that were never used - both removed.
    """
    model = GVP(
        dim_vectors_in = 1024,
        dim_feats_in = 512,
        dim_vectors_out = 256,
        dim_feats_out = 512
    )
    dropout = GVPDropout(0.2)
    layer_norm = GVPLayerNorm(512)

    feats = torch.randn(1, 512)
    vectors = torch.randn(1, 32, 3)

    # GVP layer: 32 points expand to 32*32 = 1024 pairwise difference vectors
    feats_out, vectors_out = model( (feats, diff_matrix(vectors)) )
    assert list(feats_out.shape) == [1, 512] and list(vectors_out.shape) == [1, 256, 3]

    # GVP dropout: shapes must be unchanged
    feats_out, vectors_out = dropout(feats_out, vectors_out)
    assert list(feats_out.shape) == [1, 512] and list(vectors_out.shape) == [1, 256, 3]

    # GVP layer norm: shapes must be unchanged
    feats_out, vectors_out = layer_norm(feats_out, vectors_out)
    assert list(feats_out.shape) == [1, 512] and list(vectors_out.shape) == [1, 256, 3]

# Smoke-test the GVP message-passing network
def test_mpnn():
    """Check GVP_MPNN maps node features to the same shape.

    Fix: the `dropout` local was assigned but a duplicated literal 0.1 was
    passed instead - now the variable is used, so the two cannot drift apart.
    """
    # input data: 5 nodes, 5 edges
    x = torch.randn(5, 32)
    edge_idx = torch.tensor([[0,2,3,4,1], [1,1,3,3,4]]).long()
    edge_attr = torch.randn(5, 16)
    # nodes (8 scalars and 8 vectors) || edges (4 scalars and 4 vectors)
    dropout = 0.1
    # define the layer
    gvp_mpnn = GVP_MPNN(feats_x_in = 8,
                        vectors_x_in = 8,
                        feats_x_out = 8,
                        vectors_x_out = 8, 
                        feats_edge_in = 4,
                        vectors_edge_in = 4,
                        feats_edge_out = 4,
                        vectors_edge_out = 4,
                        dropout = dropout )
    x_out = gvp_mpnn(x, edge_idx, edge_attr)

    assert x.shape == x_out.shape, "Input and output shapes don't match"

# Script entry point: run the full test suite when executed directly
if __name__ == "__main__":
    for test_fn in (test_equivariance, test_all_layer_types, test_mpnn):
        test_fn()

.\lucidrains\gigagan-pytorch\gigagan_pytorch\attend.py

# 导入必要的库和模块
from functools import wraps
from packaging import version
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F

# Named tuple selecting which scaled-dot-product-attention kernels are allowed
AttentionConfig = namedtuple('AttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helper: presence check - anything except the None sentinel counts as present
def exists(val):
    return not (val is None)

# Decorator: the wrapped single-argument function executes only on its
# first call; every later call is a no-op returning None.
def once(fn):
    state = {'called': False}
    @wraps(fn)
    def wrapper(x):
        if state['called']:
            return
        state['called'] = True
        return fn(x)
    return wrapper

# print wrapped with once() so repeated calls emit the message a single time
print_once = once(print)

# main attention module

class Attend(nn.Module):
    """Scaled dot-product attention, optionally dispatched to the fused
    PyTorch 2.0 flash-attention kernels."""

    def __init__(
        self,
        dropout = 0.,
        flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # kernel whitelists: cpu allows every backend; the cuda config is
        # resolved below only when flash attention was requested
        self.cpu_config = AttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        # A100 (compute capability 8.0) gets the pure flash kernel; any other
        # GPU falls back to the math / memory-efficient kernels
        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(False, True, True)

    def flash_attn(self, q, k, v):
        """Run attention through torch's fused scaled_dot_product_attention."""
        on_cuda = q.is_cuda

        q, k, v = (t.contiguous() for t in (q, k, v))

        # choose the kernel whitelist matching the tensors' device
        kernel_config = self.cuda_config if on_cuda else self.cpu_config

        with torch.backends.cuda.sdp_kernel(**kernel_config._asdict()):
            return F.scaled_dot_product_attention(
                q, k, v,
                dropout_p = self.dropout if self.training else 0.
            )

    def forward(self, q, k, v):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        if self.flash:
            return self.flash_attn(q, k, v)

        scaling = q.shape[-1] ** -0.5

        # similarity scores, scaled by 1/sqrt(d)
        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scaling

        # normalize into attention weights, then apply dropout
        attn = self.attn_dropout(sim.softmax(dim = -1))

        # weighted sum of values
        return einsum("b h i j, b h j d -> b h i d", attn, v)

.\lucidrains\gigagan-pytorch\gigagan_pytorch\data.py

# 导入必要的库
from functools import partial
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from PIL import Image
from torchvision import transforms as T

from beartype.door import is_bearable
from beartype.typing import Tuple

# 辅助函数

# presence check: True for every value except the None sentinel
def exists(val):
    return False if val is None else True

# Convert an image to the requested mode, returning it untouched when it
# already matches (avoids a needless copy).
def convert_image_to_fn(img_type, image):
    needs_conversion = image.mode != img_type
    return image.convert(img_type) if needs_conversion else image

# Custom collate function: tensors are stacked as usual, but string fields
# are gathered into plain List[str] instead of being tensorized.
def collate_tensors_or_str(data):
    # dataset yields a single tensor per index (not a tuple of fields)
    if not isinstance(data[0], tuple):
        return (torch.stack(data),)

    collated = []
    for field in zip(*data):
        if is_bearable(field, Tuple[str, ...]):
            collated.append(list(field))
        else:
            collated.append(torch.stack(field))

    return tuple(collated)

# 数据集类

# Image dataset: all files under `folder` with an allowed extension,
# resized and center-cropped to `image_size`.
class ImageDataset(Dataset):
    def __init__(
        self,
        folder,
        image_size,
        exts = ['jpg', 'jpeg', 'png', 'tiff'],
        augment_horizontal_flip = False,
        convert_image_to = None
    ):
        """
        folder: root directory searched recursively for images
        image_size: side length of the square output tensor
        exts: file extensions to include
        augment_horizontal_flip: random horizontal flip augmentation
        convert_image_to: optional PIL mode (e.g. 'RGB') to convert to first
        """
        super().__init__()
        self.folder = folder
        self.image_size = image_size

        # collect every matching file, recursively
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        assert len(self.paths) > 0, 'your folder contains no images'
        assert len(self.paths) > 100, 'you need at least 100 images, 10k for research paper, millions for miraculous results (try Laion-5B)'

        # optionally convert the image mode before the tensor pipeline
        maybe_convert_fn = partial(convert_image_to_fn, convert_image_to) if exists(convert_image_to) else nn.Identity()

        self.transform = T.Compose([
            T.Lambda(maybe_convert_fn),
            T.Resize(image_size),
            T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def get_dataloader(self, *args, **kwargs):
        """Return a shuffled DataLoader over this dataset (drops last partial batch)."""
        return DataLoader(self, *args, shuffle = True, drop_last = True, **kwargs)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        # context manager ensures the underlying file handle is closed
        # (PIL opens lazily; the original version leaked the handle)
        with Image.open(path) as img:
            return self.transform(img)

# Abstract text-image dataset: subclasses supply the actual storage
class TextImageDataset(Dataset):
    def __init__(self):
        # must be implemented by a concrete subclass
        raise NotImplementedError

    def get_dataloader(self, *args, **kwargs):
        """DataLoader whose collate keeps captions as List[str]."""
        return DataLoader(self, *args, collate_fn = collate_tensors_or_str, **kwargs)

# Synthetic text-image dataset for smoke-testing training loops
class MockTextImageDataset(TextImageDataset):
    def __init__(
        self,
        image_size,
        length = int(1e5),
        channels = 3
    ):
        # deliberately skips super().__init__(), which raises NotImplementedError
        self.image_size = image_size
        self.channels = channels
        self.length = length

    def get_dataloader(self, *args, **kwargs):
        """DataLoader whose collate keeps captions as List[str]."""
        return DataLoader(self, *args, collate_fn = collate_tensors_or_str, **kwargs)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # every item is fresh gaussian noise paired with a constant caption
        return torch.randn(self.channels, self.image_size, self.image_size), 'mock text'