Lucidrains 系列项目源码解析(四十)
.\lucidrains\geometric-vector-perceptron\examples\data_utils.py
# 作者:Eric Alcaide
# 导入必要的库
import os
import sys
# 科学计算库
import torch
import torch_sparse
import numpy as np
from einops import repeat, rearrange
# 导入自定义工具 - 来自 https://github.com/EleutherAI/mp_nerf
from data_handler import *
# 新数据构建函数
def get_atom_ids_dict():
    """ Build a dict mapping every atom name to an integer token id. """
    # backbone atoms are always present
    atom_names = {"N", "CA", "C", "O"}
    # add every sidechain atom name declared in SC_BUILD_INFO
    for build_info in SC_BUILD_INFO.values():
        atom_names.update(build_info["atom-names"])
    # deterministic ids: sort alphabetically, then enumerate
    return {name: token for token, name in enumerate(sorted(atom_names))}
#################################
#####  ORIGINAL PROJECT DATA  ###
#################################
# amino-acid alphabet (1-letter codes; "_" is the padding token)
AAS = "ARNDCQEGHILKMFPSTWYV_"
# map each 1-letter code to its position in AAS
AAS2NUM = {k: AAS.index(k) for k in AAS}
# atom-name -> integer token mapping
ATOM_IDS = get_atom_ids_dict()
# per-amino-acid graph data: covalent bonds between the 14 atom slots
# (indices refer to the sidechainnet 14-atom layout)
GVP_DATA = {
    'A': {
        'bonds': [[0,1], [1,2], [2,3], [1,4]]
    },
    'R': {
        'bonds': [[0,1], [1,2], [2,3], [2,4], [4,5], [5,6],
                  [6,7], [7,8], [8,9], [8,10]]
    },
    # NOTE: the entries for the remaining amino acids were elided in this excerpt
    # ...
    '_': {
        'bonds': []
    }
}
#################################
##### 原始项目数据 #####
#################################
def graph_laplacian_embedds(edges, eigen_k, center_idx=1, norm=False):
    """ Returns the embeddings of points as the first K eigenvectors
        of the graph Laplacian.
        Inputs:
        * edges: (2, N) long tensor or list. enough to represent undirected edges.
        * eigen_k: int. number of leading eigenvectors to return.
        * center_idx: int. index used as the center of the embedding.
        * norm: bool. use the normalized Laplacian. not recommended.
        Output: (n_points, eigen_k)
    """
    # accept a plain python list of edges
    if isinstance(edges, list):
        edges = torch.tensor(edges).long()
    # correct the layout to (2, N)
    if edges.shape[0] != 2:
        edges = edges.t()
    # no edges -> trivial zero embedding
    if edges.shape[0] == 0:
        return torch.zeros(1, eigen_k)
    # fix: torch.eye requires a python int, not a 0-dim tensor
    size   = int(torch.max(edges)) + 1
    device = edges.device
    # adjacency with self-loops (identity), as in the original formulation
    adj_mat = torch.eye(size, device=device)
    for i, j in edges.t():
        adj_mat[i, j] = adj_mat[j, i] = 1.
    # degree matrix — fix: build it on the same device as the edges
    # (the original created it on CPU, which breaks for CUDA inputs)
    deg_mat = torch.eye(size, device=device) * adj_mat.sum(dim=-1, keepdim=True)
    laplace = deg_mat - adj_mat
    # optionally switch to the normalized Laplacian
    if norm:
        for i, j in edges.t():
            laplace[i, j] = laplace[j, i] = -1 / (deg_mat[i, i] * deg_mat[j, j])**0.5
    # fix: torch.symeig was deprecated in 1.9 and removed in 2.0;
    # torch.linalg.eigh is the documented replacement (ascending eigenvalues)
    e, v = torch.linalg.eigh(laplace)
    # sort eigenvectors by descending |eigenvalue|
    idxs = torch.sort(e.abs(), descending=True)[1]
    # keep the first eigen_k eigenvectors as the embedding
    embedds = v[:, idxs[:eigen_k]]
    # center the embedding on the chosen point
    embedds = embedds - embedds[center_idx].unsqueeze(-2)
    return embedds
# token of every atom for a given amino acid
def make_atom_id_embedds(k):
    """ Return the atom-token ids for the 14 atom slots of amino acid `k`. """
    # backbone atoms first, then the sidechain atoms of this residue
    atom_list = ["N", "CA", "C", "O"] + SC_BUILD_INFO[k]["atom-names"]
    # look up every atom token; unused trailing slots stay 0
    tokens = [ATOM_IDS[atom] for atom in atom_list]
    mask = torch.zeros(14).long()
    mask[:len(tokens)] = torch.tensor(tokens).long()
    return mask
#################################
##########  SAVE INFO  ##########
#################################
# Precomputed per-amino-acid information, keyed by 1-letter code.
# The *_mask entries come from data_handler (mp_nerf); the embeddings
# are computed by the helpers defined above in this module.
SUPREME_INFO = {k: {"cloud_mask": make_cloud_mask(k),
                    "bond_mask": make_bond_mask(k),
                    "theta_mask": make_theta_mask(k),
                    "torsion_mask": make_torsion_mask(k),
                    "idx_mask": make_idx_mask(k),
                    #
                    "eigen_embedd": graph_laplacian_embedds(GVP_DATA[k]["bonds"], eigen_k = 3),
                    "atom_id_embedd": make_atom_id_embedds(k)
                    }
                for k in "ARNDCQEGHILKMFPSTWYV_"}
#################################
######### RANDOM UTILS ##########
#################################

# sinusoidal distance encoding
def encode_dist(x, scales=[1,2,4,8], include_self = True):
    """ Encodes a distance with sines and cosines.
        Inputs:
        * x: (batch, N) or (N,). data to encode.
             Infers device and dtype (f16, f32, f64) from here.
        * scales: (s,) tensor, list or tuple. lower or higher depending on distances.
        * include_self: bool. whether to append the raw value (so it can be
          recovered exactly by decode_dist).
        Output: (..., num_scales*2 + 1) if include_self else (..., num_scales*2)
    """
    x = x.unsqueeze(-1)
    # infer device / dtype from the input
    device, precise = x.device, x.type()
    # accept python sequences as well as tensors (fix: tuples were rejected)
    if isinstance(scales, (list, tuple)):
        scales = torch.tensor([scales], device=device).type(precise)
    # sinusoidal features at every scale
    sines   = torch.sin(x / scales)
    cosines = torch.cos(x / scales)
    enc_x = torch.cat([sines, cosines], dim=-1)
    # optionally append the raw value as the last column
    return torch.cat([enc_x, x], dim=-1) if include_self else enc_x
# inverse of encode_dist
def decode_dist(x, scales=[1,2,4,8], include_self = False):
    """ Decodes a distance from its sine/cosine encoding (fix: the original
        docstring said "Encodes", but this is the inverse of encode_dist).
        Inputs:
        * x: (N, 2*fourier_feats (+1) ). encoded distances.
             Infers device and dtype (f16, f32, f64) from here.
        * scales: (s,) tensor, list or tuple. must match the encoding scales.
        * include_self: whether to average with the raw value stored in the
          last column by encode_dist(include_self=True).
        Output: (N, 1)
    """
    device, precise = x.device, x.type()
    # accept python sequences as well as tensors
    if isinstance(scales, (list, tuple)):
        scales = torch.tensor([scales], device=device).type(precise)
    # recover angles via atan2 and shift negatives into [0, 2*pi)
    half = x.shape[-1]//2
    decodes = torch.atan2(x[..., :half], x[..., half:2*half])
    decodes += (decodes<0).type(precise) * 2*np.pi
    # unwrap: propagate whole turns from coarser scales to finer ones
    # (assumes x is 2-D: the loop indexes decodes[:, i])
    offsets = torch.zeros_like(decodes)
    for i in range(decodes.shape[-1]-1, 0, -1):
        offsets[:, i-1] = 2 * ( offsets[:, i] + (decodes[:, i]>np.pi).type(precise) * np.pi )
    decodes += offsets
    # average the per-scale estimates
    avg_dec = (decodes * scales).mean(dim=-1, keepdim=True)
    if include_self:
        return 0.5*(avg_dec + x[..., -1:])
    return avg_dec
# n-th degree adjacency matrix
def nth_deg_adjacency(adj_mat, n=1, sparse=False):
    """ Calculates the n-th degree adjacency matrix.
        Performs mm of adj_mat and adds the newly added.
        Default is dense. Mods for sparse version are done when needed.
        Inputs:
        * adj_mat: (N, N) adjacency tensor
        * n: int. degree of the output adjacency
        * sparse: bool. whether to use torch-sparse module
        Outputs:
        * edge_idxs: the ij positions of the adjacency matrix
        * edge_attrs: the degree of connectivity (1 for neighs, 2 for neighs^2 )
    """
    adj_mat = adj_mat.float()
    # accumulates the degree at which each pair first becomes connected
    attr_mat = torch.zeros_like(adj_mat)
    for i in range(n):
        # degree 1: the input adjacency itself
        if i == 0:
            attr_mat += adj_mat
            continue
        if i == 1 and sparse:
            # convert to a sparse COO tensor once, on the first power step
            adj_mat = torch.sparse.FloatTensor(adj_mat.nonzero().t(),
                                               adj_mat[adj_mat != 0]).to(adj_mat.device).coalesce()
            idxs, vals = adj_mat.indices(), adj_mat.values()
            # NOTE(review): this rebinds `n` (the loop bound). It is harmless
            # only because range(n) was already evaluated — consider renaming.
            m, k, n = 3 * [adj_mat.shape[0]] # (m, n) * (n, k) , but adj_mats are square: m=n=k
        if sparse:
            # sparse @ sparse matmul via the torch_sparse package
            idxs, vals = torch_sparse.spspmm(idxs, vals, idxs, vals, m=m, k=k, n=n)
            # scatter the product back into a dense, binarized matrix
            adj_mat = torch.zeros_like(attr_mat)
            adj_mat[idxs[0], idxs[1]] = vals.bool().float()
        else:
            # dense path: square the adjacency, then binarize
            adj_mat = (adj_mat @ adj_mat).bool().float()
        # mark pairs that became reachable exactly at this degree (i+1)
        attr_mat[(adj_mat - attr_mat.bool().float()).bool()] += i + 1
    return adj_mat, attr_mat
# covalent-bond indices of a protein
def prot_covalent_bond(seq, adj_degree=1, cloud_mask=None):
    """ Returns the indices of the covalent bonds of a protein.
        Inputs:
        * seq: str. protein sequence in 1-letter amino-acid codes.
        * adj_degree: int. degree of adjacency to expand the bonds to.
        * cloud_mask: mask selecting the atoms that are present.
        Outputs: edge_idxs, edge_attrs
    """
    # create or infer the cloud_mask
    if cloud_mask is None:
        cloud_mask = scn_cloud_mask(seq).bool()
    device, precise = cloud_mask.device, cloud_mask.type()
    # flat index of the first atom (N) of each amino acid
    scaff = torch.zeros_like(cloud_mask)
    scaff[:, 0] = 1
    idxs = scaff[cloud_mask].nonzero().view(-1)
    # build the adjacency from the per-residue bonds in GVP_DATA
    adj_mat = torch.zeros(idxs.amax()+14, idxs.amax()+14)
    for i,idx in enumerate(idxs):
        # peptide bond to the next amino acid: C (slot 2) -> next residue's N
        extra = []
        if i < idxs.shape[0]-1:
            extra = [[2, (idxs[i+1]-idx).item()]]
        # shift the residue-local bond indices to flat atom indices
        bonds = idx + torch.tensor( GVP_DATA[seq[i]]['bonds'] + extra ).long().t()
        adj_mat[bonds[0], bonds[1]] = 1.
    # make the graph undirected
    adj_mat = adj_mat + adj_mat.t()
    # expand to the requested degree of adjacency (sparse path uses torch_sparse)
    adj_mat, attr_mat = nth_deg_adjacency(adj_mat, n=adj_degree, sparse=True)
    edge_idxs = attr_mat.nonzero().t().long()
    edge_attrs = attr_mat[edge_idxs[0], edge_idxs[1]]
    return edge_idxs, edge_attrs
def dist2ca(x, mask=None, eps=1e-7):
    """ Calculates the distance of every atom to its C-alpha.
        Inputs:
        * x: (L, 14, D) coordinates in sidechainnet format.
        * mask: (L, 14) boolean mask, optional.
        * eps: small constant for numerical stability.
        Returns unit vectors and norms (masked if a mask is given).
    """
    # C-alpha sits at slot 1 of the 14-atom layout
    x = x - x[:, 1].unsqueeze(1)
    norm = torch.norm(x, dim=-1, keepdim=True)
    x_norm = x / (norm + eps)
    # fix: `if mask:` raises "Boolean value of Tensor is ambiguous" for
    # multi-element tensor masks — test for None explicitly
    if mask is not None:
        return x_norm[mask], norm[mask]
    return x_norm, norm
def orient_aa(x, mask=None, eps=1e-7):
    """ Calculates unit vectors and norms of the backbone features.
        Inputs:
        * x: (L, 14, D). coordinates in sidechainnet format.
        * mask: unused here; kept for interface consistency with siblings.
        * eps: small constant for numerical stability.
        Returns unit vectors (5) and norms (3) — the two reversed vectors
        share the norms of their forward counterparts.
    """
    # get tensor info
    device, precise = x.device, x.type()

    vec_wrap  = torch.zeros(5, x.shape[0], 3, device=device) # (feats, L, dims+1)
    norm_wrap = torch.zeros(3, x.shape[0], device=device)
    # first feat is CB-CA (slot 4 assumed to be CB in the sidechainnet layout)
    vec_wrap[0] = x[:, 4] - x[:, 1]
    norm_wrap[0] = torch.norm(vec_wrap[0], dim=-1)
    vec_wrap[0] /= norm_wrap[0].unsqueeze(dim=-1) + eps
    # second is CA+ - CA (last residue has no successor; stays zero)
    vec_wrap[1, :-1] = x[:-1, 1] - x[1:, 1]
    norm_wrap[1, :-1] = torch.norm(vec_wrap[1, :-1], dim=-1)
    vec_wrap[1, :-1] /= norm_wrap[1, :-1].unsqueeze(dim=-1) + eps
    # same but reversed direction (shares norm_wrap[1])
    vec_wrap[2] = (-1)*vec_wrap[1]
    # third is CA - CA- (first residue has no predecessor; stays zero)
    vec_wrap[3, 1:] = x[:-1, 1] - x[1:, 1]
    norm_wrap[2, 1:] = torch.norm(vec_wrap[3, 1:], dim=-1)
    vec_wrap[3, 1:] /= norm_wrap[2, 1:].unsqueeze(dim=-1) + eps
    # reversed direction of the third (shares norm_wrap[2])
    vec_wrap[4] = (-1)*vec_wrap[3]

    return vec_wrap, norm_wrap
def chain2atoms(x, mask=None):
    """ Expand from (L, other) to (L, C, other), C = 14 atom slots. """
    device, precise = x.device, x.type()
    # broadcast each residue-level feature across its 14 atom slots
    expanded = torch.ones(x.shape[0], 14, *x.shape[1:]).type(precise).to(device)
    expanded = expanded * x.unsqueeze(1)
    # optionally keep only the atoms selected by the mask
    return expanded[mask] if mask is not None else expanded
def from_encode_to_pred(whole_point_enc, use_fourier=False, embedd_info=None, needed_info=None, vec_dim=3):
    """ Turns the encoding produced above into a label / prediction format.
        Keeps only the essentials for position recovery (radial unit vector + norm).
        Inputs:
        * whole_point_enc: (atoms, vector_dims + scalar_dims). radial unit
          vector must occupy the first vector dims.
        * embedd_info: dict with the number of scalar and vector features.
        * needed_info: dict with the encoding scales.
    """
    # width of the flattened vector features
    n_vec_cols = vec_dim * embedd_info["point_n_vectors"]
    # column where the raw (non-fourier) distance lives
    raw_dist_col = 2 * len(needed_info["atom_pos_scales"]) + n_vec_cols
    if use_fourier:
        # invert the sinusoidal encoding of the radial norm
        dist = decode_dist(whole_point_enc[:, n_vec_cols:raw_dist_col + 1],
                           scales=needed_info["atom_pos_scales"],
                           include_self=False)
    else:
        # raw distance was stored directly at raw_dist_col
        dist = whole_point_enc[:, raw_dist_col:raw_dist_col + 1]
    # radial unit vector followed by its norm
    return torch.cat([whole_point_enc[:, :3], dist], dim=-1)
def encode_whole_bonds(x, x_format="coords", embedd_info={},
                       needed_info = {"cutoffs": [2,5,10],
                                      "bond_scales": [.5, 1, 2],
                                      "adj_degree": 1},
                       free_mem=False, eps=1e-7):
    """ Given some coordinates, and the needed info,
        encodes the bonds from point information.
        * x: (N, 3) or prediction format
        * x_format: one of ["coords" or "encode"]
        * embedd_info: dict. contains the needed embedding info
        * needed_info: dict. contains additional needed info
            { cutoffs: list. cutoff distances for bonds.
                       can be a string for the k closest (ex: "30_closest"),
                       empty list for just covalent.
              bond_scales: list. fourier encodings
              adj_degree: int. degree of adj (2 means adj of adj is my adj)
                          0 for no adjacency
            }
        * free_mem: whether to delete variables
        * eps: constant for numerical stability
    """
    device, precise = x.device, x.type()
    # fix: callers (e.g. the lightning example) omit "adj_degree"; default to
    # prot_covalent_bond's own default instead of raising KeyError
    adj_degree = needed_info.get("adj_degree", 1)
    # convert to 3d coords if passed as preds
    if x_format == "encode":
        pred_x = from_encode_to_pred(x, embedd_info=embedd_info, needed_info=needed_info)
        x = pred_x[:, :3] * pred_x[:, 3:4]

    # 1. BONDS: find the covalent bond indices - allow precomputed arg -> DRY
    native = None
    # fix: the original tested the wrong key ("prot_covalent_bond"), so bonds
    # passed in as "covalent_bond" were silently ignored and recomputed
    if "covalent_bond" in needed_info.keys():
        native = True
        native_bonds = needed_info["covalent_bond"]
    elif adj_degree:
        native = True
        native_bonds = prot_covalent_bond(needed_info["seq"], adj_degree)
    if native:
        native_idxs, native_attrs = native_bonds[0].to(device), native_bonds[1].to(device)

    # determine kind of cutoff (hard distance threshold or closest points)
    closest = None
    if len(needed_info["cutoffs"]) > 0:
        cutoffs = needed_info["cutoffs"].copy()
        if sum( isinstance(ci, str) for ci in cutoffs ) > 0:
            cutoffs = [-1e-3] # negative so no bond is taken
            closest = int( needed_info["cutoffs"][0].split("_")[0] )
        # points under cutoff = d(i - j) < X
        cutoffs = torch.tensor(cutoffs, device=device).type(precise)

    dist_mat = torch.cdist(x, x, p=2)
    # default buckets
    bond_buckets = torch.zeros(*x.shape[:-1], x.shape[-2], device=device).type(precise)

    if len(needed_info["cutoffs"]) > 0 and not closest:
        # count from latest degree of adjacency given
        bond_buckets = torch.bucketize(dist_mat, cutoffs)
        bond_buckets[native_idxs[0], native_idxs[1]] = cutoffs.shape[0]
        # find the indexes - symmetric and we dont want the diag
        bond_buckets += cutoffs.shape[0] * torch.eye(bond_buckets.shape[0], device=device).long()
        close_bond_idxs = ( bond_buckets < cutoffs.shape[0] ).nonzero().t()
        # move away from poses reserved for native
        bond_buckets[close_bond_idxs[0], close_bond_idxs[1]] += adj_degree+1
    # the K closest (covalent bonds excluded) are considered bonds
    elif closest:
        # fix: the original referenced an undefined name `k` here (NameError)
        k = closest
        # mask out covalent bonds and the diagonal before ranking distances
        masked_dist_mat = dist_mat.clone()
        masked_dist_mat += torch.eye(masked_dist_mat.shape[0], device=device) * torch.amax(masked_dist_mat)
        masked_dist_mat[native_idxs[0], native_idxs[1]] = masked_dist_mat[0,0].clone()
        # sort by distance; *(-1) so the smallest come first
        _, sorted_col_idxs = torch.topk(-masked_dist_mat, k=k, dim=-1)
        # flatten columns and repeat row indices to match
        sorted_col_idxs = rearrange(sorted_col_idxs[:, :k], '... n k -> ... (n k)')
        sorted_row_idxs = torch.repeat_interleave( torch.arange(dist_mat.shape[0]).long(), repeats=k ).to(device)
        close_bond_idxs = torch.stack([ sorted_row_idxs, sorted_col_idxs ], dim=0)
        # move away from poses reserved for native
        bond_buckets = torch.ones_like(dist_mat) * (adj_degree+1)

    # merge all bonds
    if len(needed_info["cutoffs"]) > 0:
        if close_bond_idxs.shape[0] > 0:
            whole_bond_idxs = torch.cat([native_idxs, close_bond_idxs], dim=-1)
        else:
            whole_bond_idxs = native_idxs
    else:
        # fix: the original left whole_bond_idxs unbound for empty cutoffs,
        # although the docstring promises "empty list for just covalent"
        whole_bond_idxs = native_idxs

    # 2. ATTRS: encode bonds as attributes
    bond_vecs  = x[ whole_bond_idxs[0] ] - x[ whole_bond_idxs[1] ]
    bond_norms = torch.norm(bond_vecs, dim=-1)
    bond_vecs /= (bond_norms + eps).unsqueeze(-1)
    bond_norms_enc = encode_dist(bond_norms, scales=needed_info["bond_scales"]).squeeze()

    if native:
        bond_buckets[native_idxs[0], native_idxs[1]] = native_attrs
    bond_attrs = bond_buckets[whole_bond_idxs[0] , whole_bond_idxs[1]]
    # pack scalars and vectors - extra token for covalent bonds
    bond_n_vectors = 1
    bond_n_scalars = (2 * len(needed_info["bond_scales"]) + 1) + 1 # last one is an embedding of size 1+len(cutoffs)
    whole_bond_enc = torch.cat([bond_vecs, # 1 vector - no reverse needed - we do 2x bonds (symmetry)
                                # scalars
                                bond_norms_enc, # 2 * len(scales)
                                (bond_attrs-1).unsqueeze(-1) # 1
                               ], dim=-1)
    # free gpu mem
    if free_mem:
        # fix: the original deleted the undefined name `native_bond_idxs` and
        # unconditionally deleted names that may not exist on every path
        del bond_buckets, bond_norms_enc, bond_vecs, dist_mat, native_idxs
        if len(needed_info["cutoffs"]) > 0:
            del close_bond_idxs
        if closest:
            del masked_dist_mat, sorted_col_idxs, sorted_row_idxs

    embedd_info = {"bond_n_vectors": bond_n_vectors,
                   "bond_n_scalars": bond_n_scalars,
                   "bond_embedding_nums": [ len(needed_info["cutoffs"]) + adj_degree ]} # extra one for covalent (default)

    return whole_bond_idxs, whole_bond_enc, embedd_info
def encode_whole_protein(seq, true_coords, angles, padding_seq,
                         needed_info = { "cutoffs": [2, 5, 10],
                                         "bond_scales": [0.5, 1, 2]}, free_mem=False):
    """ Encodes a whole protein. In points + vectors.
        Inputs:
        * seq: str. 1-letter amino-acid sequence (may include padding).
        * true_coords: (L*14, 3) flat atom coordinates.
        * angles: unused here; kept for interface consistency.
        * padding_seq: int. number of padding residues at the end.
        * needed_info: dict of encoding scales / cutoffs (see encode_whole_bonds).
        * free_mem: bool. delete intermediates after use.
        Returns: whole_point_enc, whole_bond_idxs, whole_bond_enc, embedd_info
    """
    # get device and dtype
    device, precise = true_coords.device, true_coords.type()
    #################
    # encode points #
    #################
    # which of the 14 atom slots exist per (non-padding) residue
    cloud_mask = torch.tensor(scn_cloud_mask(seq[:-padding_seq or None])).bool().to(device)
    flat_mask = rearrange(cloud_mask, 'l c -> (l c)')
    # general position embedding: radial unit vector + encoded norm from the centroid
    center_coords = true_coords - true_coords.mean(dim=0)
    pos_unit_norms = torch.norm(center_coords, dim=-1, keepdim=True)
    pos_unit_vecs = center_coords / pos_unit_norms
    pos_unit_norms_enc = encode_dist(pos_unit_norms, scales=needed_info["atom_pos_scales"]).squeeze()
    # reformat coordinates to scn (L, 14, 3) - TODO: handle padding=0
    coords_wrap = rearrange(center_coords, '(l c) d -> l c d', c=14)[:-padding_seq or None]

    # position-in-chain embedding
    aa_pos = encode_dist( torch.arange(len(seq[:-padding_seq or None]), device=device).float(), scales=needed_info["aa_pos_scales"])
    atom_pos = chain2atoms(aa_pos)[cloud_mask]

    # atom identity embedding
    atom_id_embedds = torch.stack([SUPREME_INFO[k]["atom_id_embedd"] for k in seq[:-padding_seq or None]],
                                  dim=0)[cloud_mask].to(device)
    # amino-acid identity embedding
    seq_int = torch.tensor([AAS2NUM[aa] for aa in seq[:-padding_seq or None]], device=device).long()
    aa_id_embedds = chain2atoms(seq_int, mask=cloud_mask)
    # CA - sidechain distance features
    dist2ca_vec, dist2ca_norm = dist2ca(coords_wrap)
    dist2ca_norm_enc = encode_dist(dist2ca_norm, scales=needed_info["dist2ca_norm_scales"]).squeeze()

    # backbone orientation features
    vecs, norms = orient_aa(coords_wrap)
    bb_vecs_atoms = chain2atoms(torch.transpose(vecs, 0, 1), mask=cloud_mask)
    bb_norms_atoms = chain2atoms(torch.transpose(norms, 0, 1), mask=cloud_mask)
    bb_norms_atoms_enc = encode_dist(bb_norms_atoms, scales=[0.5])
    ################
    # encode bonds #
    ################
    bond_info = encode_whole_bonds(x = coords_wrap[cloud_mask],
                                   x_format = "coords",
                                   embedd_info = {},
                                   needed_info = needed_info )
    whole_bond_idxs, whole_bond_enc, bond_embedd_info = bond_info
    #########
    # merge #
    #########
    # concat so that the final layout is [vector_dims, scalar_dims]
    point_n_vectors = 1 + 1 + 5
    point_n_scalars = 2*len(needed_info["atom_pos_scales"]) + 1 +\
                      2*len(needed_info["aa_pos_scales"]) + 1 +\
                      2*len(needed_info["dist2ca_norm_scales"]) + 1+\
                      rearrange(bb_norms_atoms_enc, 'atoms feats encs -> atoms (feats encs)').shape[1] +\
                      2 # the last 2 are not yet embedded

    whole_point_enc = torch.cat([ pos_unit_vecs[ :-padding_seq*14 or None ][ flat_mask ], # 1
                                  dist2ca_vec[cloud_mask], # 1
                                  rearrange(bb_vecs_atoms, 'atoms n d -> atoms (n d)'), # 5
                                  # scalars
                                  pos_unit_norms_enc[ :-padding_seq*14 or None ][ flat_mask ], # 2n+1
                                  atom_pos, # 2n+1
                                  dist2ca_norm_enc[cloud_mask], # 2n+1
                                  rearrange(bb_norms_atoms_enc, 'atoms feats encs -> atoms (feats encs)'), # 2n+1
                                  atom_id_embedds.unsqueeze(-1),
                                  aa_id_embedds.unsqueeze(-1) ], dim=-1) # the last 2 are not yet embedded
    if free_mem:
        del pos_unit_vecs, dist2ca_vec, bb_vecs_atoms, pos_unit_norms_enc, cloud_mask,\
            atom_pos, dist2ca_norm_enc, bb_norms_atoms_enc, atom_id_embedds, aa_id_embedds

    # record embedding dims: number of point vectors and scalars
    point_embedd_info = {"point_n_vectors": point_n_vectors,
                         "point_n_scalars": point_n_scalars,}
    # merge point and bond embedding info
    embedd_info = {**point_embedd_info, **bond_embedd_info}

    return whole_point_enc, whole_bond_idxs, whole_bond_enc, embedd_info
def get_prot(dataloader_=None, vocab_=None, min_len=80, max_len=150, verbose=True):
    """ Gets a protein from sidechainnet and returns
        the right attrs for training.
        Inputs:
        * dataloader_: sidechainnet iterator over dataset
        * vocab_: sidechainnet VOCAB class
        * min_len: int. minimum sequence length
        * max_len: int. maximum sequence length
        * verbose: bool. verbosity level
        NOTE(review): if no protein in the split passes the filters, `seq`
        and friends are unbound at the final return — NameError. Confirm
        callers always use permissive enough bounds.
    """
    # iterate training batches until an acceptable protein is found
    for batch in dataloader_['train']:
        # use StopIteration to break out of both loops at once
        try:
            for i in range(batch.int_seqs.shape[0]):
                # get variables
                seq = ''.join([vocab_.int2char(aa) for aa in batch.int_seqs[i].numpy()])
                int_seq = batch.int_seqs[i]
                angles = batch.angs[i]
                mask = batch.msks[i]
                # padding counts: residues with all-zero angles / padding tokens (20)
                padding_angles = (torch.abs(angles).sum(dim=-1) == 0).long().sum()
                padding_seq = (batch.int_seqs[i] == 20).sum()
                # only accept sequences with correct dims and no missing coords
                # (> 0 avoids negative-index errors later)
                if batch.crds[i].shape[0]//14 == int_seq.shape[0]:
                    if ( max_len > len(seq) and len(seq) > min_len ) and padding_seq == padding_angles:
                        if verbose:
                            print("stopping at sequence of length", len(seq))
                        # print(len(seq), angles.shape, "paddings: ", padding_seq, padding_angles)
                        # found one: break both loops
                        raise StopIteration
                    else:
                        # print("found a seq of length:", len(seq),
                        #       "but outside the threshold:", min_len, max_len)
                        pass
        except StopIteration:
            # break the outer loop
            break
    # return sequence, coords, angles, seq padding, mask and protein id
    return seq, batch.crds[i], angles, padding_seq, batch.msks[i], batch.pids[i]
GVP - Point Cloud
Geometric Vector Perceptron applied to Point Clouds
To install:
git clone ${repo_url}
- install packages:
sidechainnet: github.com/jonathankin…- joblib, tqdm, numpy, einops, ...
- torch (was developed using 1.7.1)
- torch geometric: pytorch-geometric.readthedocs.io/en/latest/n…
cd this_repo_folder+pip install .ORpip install geometric-vector-perceptron(but installing from PyPi is not recommended for now - not updated)- any other just run:
pip install package_name
- Try to run the notebooks (they should run, report errors if encountered)
proto_dev_model.ipynb: shows how to gather the data and train a simple model on it, then reconstruct original struct and calculate improvement.
Description:
- encode a protein (3d) into some features (scalars and position vectors)
- we encode both point features and edge features
- train the model to predict the right point features back
- reconstruct the 3d case to see the improvement
TO DO LIST:
See the issues tab?
Contribute
PRs and ideas are welcome. Describe a list of the changes you've made and provide tests/examples if possible (they're not required, but they surely help understanding).
.\lucidrains\geometric-vector-perceptron\examples\scn_data_module.py
# 导入必要的模块
from argparse import ArgumentParser
from typing import List, Optional
from typing import Union
import numpy as np
import pytorch_lightning as pl
import sidechainnet
from sidechainnet.dataloaders.collate import get_collate_fn
from sidechainnet.utils.sequence import ProteinVocabulary
from torch.utils.data import DataLoader, Dataset
# custom dataset wrapper
class ScnDataset(Dataset):
    """Thin wrapper around a sidechainnet dataset that crops every sample
    to at most `max_len` residues at collate time."""

    def __init__(self, dataset, max_len: int):
        super(ScnDataset, self).__init__()
        self.dataset = dataset
        self.max_len = max_len
        self.scn_collate_fn = get_collate_fn(False)
        self.vocab = ProteinVocabulary()

    def collate_fn(self, batch):
        """Collate a raw batch and crop every field to `max_len` residues."""
        collated = self.scn_collate_fn(batch)
        # decode integer sequences back to 1-letter amino-acid strings
        real_seqs = ["".join(self.vocab.int2char(aa) for aa in int_seq)
                     for int_seq in collated.int_seqs.numpy()]
        seq = real_seqs[0][: self.max_len]
        true_coords = collated.crds[0].view(-1, 14, 3)[: self.max_len].view(-1, 3)
        angles = collated.angs[0, : self.max_len]
        mask = collated.msks[0, : self.max_len]
        # number of padding residues ("_") in the cropped sequence
        padding_seq = (np.array([*seq]) == "_").sum()
        return {
            "seq": seq,
            "true_coords": true_coords,
            "angles": angles,
            "padding_seq": padding_seq,
            "mask": mask,
        }

    def __getitem__(self, index: int):
        return self.dataset[index]

    def __len__(self) -> int:
        return len(self.dataset)
# lightning data module
class ScnDataModule(pl.LightningDataModule):
    """Lightning data module wrapping the sidechainnet dataloaders."""

    @staticmethod
    def add_data_specific_args(parent_parser):
        """Register the data-related CLI arguments."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--casp_version", type=int, default=7)
        parser.add_argument("--scn_dir", type=str, default="./sidechainnet_data")
        parser.add_argument("--train_batch_size", type=int, default=1)
        parser.add_argument("--eval_batch_size", type=int, default=1)
        parser.add_argument("--num_workers", type=int, default=1)
        parser.add_argument("--train_max_len", type=int, default=256)
        parser.add_argument("--eval_max_len", type=int, default=256)
        return parser

    def __init__(
        self,
        casp_version: int = 7,
        scn_dir: str = "./sidechainnet_data",
        train_batch_size: int = 1,
        eval_batch_size: int = 1,
        num_workers: int = 1,
        train_max_len: int = 256,
        eval_max_len: int = 256,
        **kwargs,
    ):
        super().__init__()
        # the downstream model only supports single-protein batches
        assert train_batch_size == eval_batch_size == 1, "batch size must be 1 for now"
        self.casp_version = casp_version
        self.scn_dir = scn_dir
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.num_workers = num_workers
        self.train_max_len = train_max_len
        self.eval_max_len = eval_max_len

    def setup(self, stage: Optional[str] = None):
        """Load sidechainnet and wrap the train/valid-90/test splits."""
        dataloaders = sidechainnet.load(
            casp_version=self.casp_version,
            scn_dir=self.scn_dir,
            with_pytorch="dataloaders",
        )
        print(
            dataloaders.keys()
        )  # ['train', 'train_eval', 'valid-10', ..., 'valid-90', 'test']
        self.train = ScnDataset(dataloaders["train"].dataset, self.train_max_len)
        self.val = ScnDataset(dataloaders["valid-90"].dataset, self.eval_max_len)
        self.test = ScnDataset(dataloaders["test"].dataset, self.eval_max_len)

    def _make_loader(self, dataset, batch_size, shuffle):
        # all three loaders share the same construction, bar dataset/shuffle
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=dataset.collate_fn,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        """Shuffled loader over the training split."""
        return self._make_loader(self.train, self.train_batch_size, True)

    def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        """Deterministic loader over the validation split."""
        return self._make_loader(self.val, self.eval_batch_size, False)

    def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        """Deterministic loader over the test split."""
        return self._make_loader(self.test, self.eval_batch_size, False)
# quick smoke test when executed directly
if __name__ == "__main__":
    # build the data module with default arguments
    dm = ScnDataModule()
    # download/load sidechainnet and build the datasets
    dm.setup()
    # training loader
    train = dm.train_dataloader()
    print("train length", len(train))
    # validation loader
    valid = dm.val_dataloader()
    print("valid length", len(valid))
    # test loader
    test = dm.test_dataloader()
    print("test length", len(test))
    # print only the first test batch
    for batch in test:
        print(batch)
        break
.\lucidrains\geometric-vector-perceptron\examples\train_lightning.py
import gc
from argparse import ArgumentParser
from functools import partial
from pathlib import Path
from pprint import pprint
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from einops import rearrange
from loguru import logger
from pytorch_lightning.callbacks import (
GPUStatsMonitor,
LearningRateMonitor,
ModelCheckpoint,
ProgressBar,
)
from pytorch_lightning.loggers import TensorBoardLogger
from examples.data_handler import kabsch_torch, scn_cloud_mask
from examples.data_utils import (
encode_whole_bonds,
encode_whole_protein,
from_encode_to_pred,
prot_covalent_bond,
)
from examples.scn_data_module import ScnDataModule
from geometric_vector_perceptron.geometric_vector_perceptron import GVP_Network
# LightningModule for structure denoising
class StructureModel(pl.LightningModule):
    """Denoises Gaussian-perturbed protein coordinates with a GVP network
    and reports RMSD against the clean structure."""

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Register the model-related CLI arguments."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        # model
        parser.add_argument("--depth", type=int, default=4)
        parser.add_argument("--cutoffs", type=float, default=1.0)
        parser.add_argument("--noise", type=float, default=1.0)
        # optimizer & scheduler
        parser.add_argument("--init_lr", type=float, default=1e-3)
        return parser

    def __init__(
        self,
        depth: int = 1,
        cutoffs: float = 1.0,
        noise: float = 1.0,
        init_lr: float = 1e-3,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()

        # info consumed by the encoding helpers in data_utils
        self.needed_info = {
            "cutoffs": [cutoffs],  # -1e-3 for just covalent, "30_closest", 5. for under 5A, etc
            "bond_scales": [1, 2, 4],
            "aa_pos_scales": [1, 2, 4, 8, 16, 32, 64, 128],
            "atom_pos_scales": [1, 2, 4, 8, 16, 32],
            "dist2ca_norm_scales": [1, 2, 4],
            "bb_norms_atoms": [0.5],  # will encode 3 vectors with this
        }

        self.model = GVP_Network(
            n_layers=depth,
            feats_x_in=48,
            vectors_x_in=7,
            feats_x_out=48,
            vectors_x_out=7,
            feats_edge_in=8,
            vectors_edge_in=1,
            feats_edge_out=8,
            vectors_edge_out=1,
            embedding_nums=[36, 20],
            embedding_dims=[16, 16],
            edge_embedding_nums=[2],
            edge_embedding_dims=[2],
            residual=True,
            recalc=1
        )
        self.noise = noise
        self.init_lr = init_lr
        self.baseline_losses = []  # per-batch RMSD of the noisy input vs target
        self.epoch_losses = []     # per-batch RMSD of the prediction vs target

    def forward(self, seq, true_coords, angles, padding_seq, mask):
        """Encode clean and noisy versions of the protein, denoise with the
        GVP network, and return RMSD losses for prediction and baseline."""
        needed_info = self.needed_info
        device = true_coords.device

        # crop padding off the sequence and precompute covalent bonds
        needed_info["seq"] = seq[: (-padding_seq) or None]
        needed_info["covalent_bond"] = prot_covalent_bond(needed_info["seq"])

        # clean encoding -> target
        pre_target = encode_whole_protein(
            seq,
            true_coords,
            angles,
            padding_seq,
            needed_info=needed_info,
            free_mem=True,
        )
        pre_target_x, _, _, embedd_info = pre_target

        # noisy encoding -> model input
        encoded = encode_whole_protein(
            seq,
            true_coords + self.noise * torch.randn_like(true_coords),
            angles,
            padding_seq,
            needed_info=needed_info,
            free_mem=True,
        )
        x, edge_index, edge_attrs, embedd_info = encoded

        # single-protein batch vector
        batch = torch.tensor([0 for i in range(x.shape[0])], device=x.device).long()

        # masks: atoms that exist (cloud) and residues that are resolved (chain)
        cloud_mask = scn_cloud_mask(seq[: (-padding_seq) or None]).to(device)
        chain_mask = mask[: (-padding_seq) or None].unsqueeze(-1) * cloud_mask
        flat_chain_mask = rearrange(chain_mask.bool(), "l c -> (l c)")
        cloud_mask = cloud_mask.bool()
        flat_cloud_mask = rearrange(cloud_mask, "l c -> (l c)")

        # partial to re-derive edges from updated coords between layers
        recalc_edge = partial(
            encode_whole_bonds,
            x_format="encode",
            embedd_info=embedd_info,
            needed_info=needed_info,
            free_mem=True,
        )

        # predict
        scores = self.model.forward(
            x,
            edge_index,
            batch=batch,
            edge_attr=edge_attrs,
            recalc_edge=recalc_edge,
            verbose=False,
        )

        # format prediction, baseline and target as (unit radial vec, norm)
        target = from_encode_to_pred(
            pre_target_x, embedd_info=embedd_info, needed_info=needed_info
        )
        pred = from_encode_to_pred(
            scores, embedd_info=embedd_info, needed_info=needed_info
        )
        base = from_encode_to_pred(x, embedd_info=embedd_info, needed_info=needed_info)

        # loss = RMSD of the reconstructed coordinates
        target_coords = target[:, 3:4] * target[:, :3]
        pred_coords = pred[:, 3:4] * pred[:, :3]
        base_coords = base[:, 3:4] * base[:, :3]

        # align - sometimes svd fails to converge - unclear why
        try:
            pred_aligned, target_aligned = kabsch_torch(pred_coords.t(), target_coords.t())  # (3, N)
            base_aligned, _ = kabsch_torch(base_coords.t(), target_coords.t())
            loss = ( (pred_aligned.t() - target_aligned.t())[flat_chain_mask[flat_cloud_mask]]**2 ).mean()**0.5
            loss_base = ( (base_aligned.t() - target_aligned.t())[flat_chain_mask[flat_cloud_mask]]**2 ).mean()**0.5
        except Exception:  # fix: was a bare `except:`
            pred_aligned, target_aligned = None, None
            # fix: the original printed the undefined name `ep` (NameError)
            print("svd failed convergence")
            # fix: operands must be (N, 3), not transposed, for the atom mask
            # to index the first dimension correctly
            loss = ( (pred_coords - target_coords)[flat_chain_mask[flat_cloud_mask]]**2 ).mean()**0.5
            loss_base = ( (base_coords - target_coords)[flat_chain_mask[flat_cloud_mask]]**2 ).mean()**0.5

        # free gpu memory
        del true_coords, angles, pre_target_x, edge_index, edge_attrs
        del scores, target_coords, pred_coords, base_coords
        del encoded, pre_target, target_aligned, pred_aligned
        gc.collect()

        return {"loss": loss, "loss_base": loss_base}

    def configure_optimizers(self):
        """Plain Adam with the configured initial learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.init_lr)
        return optimizer

    def on_train_start(self) -> None:
        # reset the loss histories used by the final plot
        self.baseline_losses = []
        self.epoch_losses = []

    def training_step(self, batch, batch_idx):
        """One denoising step; skips batches whose loss is NaN."""
        output = self.forward(**batch)
        loss = output["loss"]
        loss_base = output["loss_base"]
        # skip invalid batches
        if loss is None or torch.isnan(loss):
            return None
        self.epoch_losses.append(loss.item())
        self.baseline_losses.append(loss_base.item())
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        self.log("train_loss_base", output["loss_base"], on_epoch=True, prog_bar=False)
        return loss

    def on_train_end(self) -> None:
        """Plot the loss evolution (windowed means vs the noisy baseline)."""
        plt.figure(figsize=(15, 6))
        plt.title(
            f"Loss Evolution - Denoising of Gaussian-masked Coordinates (mu=0, sigma={self.noise})"
        )
        # running means: expanding at the start, then a sliding window
        for window in [8, 16, 32]:
            plt.plot(
                [
                    np.mean(self.epoch_losses[:window][0 : i + 1])
                    # fix: the original was missing the closing parenthesis
                    # on this `range(...)` call (SyntaxError)
                    for i in range(min(window, len(self.epoch_losses)))
                ]
                + [
                    np.mean(self.epoch_losses[i : i + window + 1])
                    for i in range(len(self.epoch_losses) - window)
                ],
                label="Window mean n={0}".format(window),
            )
        # horizontal dashed line for the baseline loss
        plt.plot(
            np.ones(len(self.epoch_losses)) * np.mean(self.baseline_losses),
            "k--",
            label="Baseline",
        )
        plt.xlim(-0.01 * len(self.epoch_losses), 1.01 * len(self.epoch_losses))
        plt.ylabel("RMSD")
        plt.xlabel("Batch number")
        plt.legend()
        plt.savefig("loss.pdf")

    def validation_step(self, batch, batch_idx):
        """Log validation losses."""
        output = self.forward(**batch)
        self.log("val_loss", output["loss"], on_epoch=True, sync_dist=True)
        self.log("val_loss_base", output["loss_base"], on_epoch=True, sync_dist=True)

    def test_step(self, batch, batch_idx):
        """Log test losses."""
        output = self.forward(**batch)
        self.log("test_loss", output["loss"], on_epoch=True, sync_dist=True)
        self.log("test_loss_base", output["loss_base"], on_epoch=True, sync_dist=True)
# build a pl.Trainer from parsed CLI args
def get_trainer(args):
    """Create loggers, callbacks and the Lightning Trainer from `args`."""
    # seed everything for reproducibility
    pl.seed_everything(args.seed)

    # loggers: tensorboard under <default_root_dir>/tb
    root_dir = Path(args.default_root_dir).expanduser().resolve()
    root_dir.mkdir(parents=True, exist_ok=True)
    tb_save_dir = root_dir / "tb"
    tb_logger = TensorBoardLogger(save_dir=tb_save_dir)
    loggers = [tb_logger]
    logger.info(f"Run tensorboard --logdir {tb_save_dir}")

    # callbacks: checkpointing, LR monitoring, progress bar, GPU stats
    ckpt_cb = ModelCheckpoint(verbose=True)
    lr_cb = LearningRateMonitor(logging_interval="step")
    pb_cb = ProgressBar(refresh_rate=args.progress_bar_refresh_rate)
    callbacks = [lr_cb, pb_cb]
    callbacks.append(ckpt_cb)

    gpu_cb = GPUStatsMonitor()
    callbacks.append(gpu_cb)

    plugins = []
    # build the trainer from the remaining argparse args
    trainer = pl.Trainer.from_argparse_args(
        args, logger=loggers, callbacks=callbacks, plugins=plugins
    )
    return trainer
def main(args):
    """Build the data module and model from CLI args, train, then run the test loop."""
    config = vars(args)
    dm = ScnDataModule(**config)
    model = StructureModel(**config)
    trainer = get_trainer(args)
    trainer.fit(model, datamodule=dm)
    # evaluate on the test split and echo the metrics
    metrics = trainer.test(model, datamodule=dm)
    print("test", metrics)
# Script entry point: assemble the CLI from model, data and Trainer arguments.
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--seed", type=int, default=23333, help="Seed everything.")
    # model-specific arguments
    parser = StructureModel.add_model_specific_args(parser)
    # data-specific arguments
    parser = ScnDataModule.add_data_specific_args(parser)
    # standard pytorch-lightning Trainer arguments
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()
    # echo the full run configuration before starting
    pprint(vars(args))
    main(args)
.\lucidrains\geometric-vector-perceptron\examples\__init__.py
def calculate_area(length, width):
    """Compute the area of a rectangle from its two side lengths."""
    return length * width
.\lucidrains\geometric-vector-perceptron\geometric_vector_perceptron\geometric_vector_perceptron.py
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 torch_geometric.nn 模块中导入 MessagePassing 类
from torch_geometric.nn import MessagePassing
# types
# 导入类型提示相关的模块和类型
from typing import Optional, List, Union
from torch_geometric.typing import OptPairTensor, Adj, Size, OptTensor, Tensor
# helper functions
def exists(val):
    """Return True when a value (anything other than None) was provided."""
    return not (val is None)
# classes
class GVP(nn.Module):
    """Geometric vector perceptron: jointly transforms scalar features and 3-d
    vector features while keeping the vector branch rotation-equivariant.

    Input/output is a tuple ``(feats, vectors)`` with ``feats`` of shape
    (batch, dim_feats) and ``vectors`` of shape (batch, dim_vectors, 3).
    """

    def __init__(
        self,
        *,
        dim_vectors_in,
        dim_vectors_out,
        dim_feats_in,
        dim_feats_out,
        feats_activation = nn.Sigmoid(),
        vectors_activation = nn.Sigmoid(),
        vector_gating = False
    ):
        super().__init__()
        self.dim_vectors_in = dim_vectors_in
        self.dim_feats_in = dim_feats_in
        self.dim_vectors_out = dim_vectors_out

        dim_h = max(dim_vectors_in, dim_vectors_out)

        # linear maps over the vector channel dimension (applied per coordinate,
        # so they commute with rotations)
        self.Wh = nn.Parameter(torch.randn(dim_vectors_in, dim_h))
        self.Wu = nn.Parameter(torch.randn(dim_h, dim_vectors_out))

        self.vectors_activation = vectors_activation

        # scalar branch consumes the (rotation-invariant) hidden vector norms
        # alongside the input scalar features
        self.to_feats_out = nn.Sequential(
            nn.Linear(dim_h + dim_feats_in, dim_feats_out),
            feats_activation
        )

        # optional gating of the output vectors by the scalar branch
        # (https://arxiv.org/abs/2106.03843); None disables it
        self.scalar_to_vector_gates = nn.Linear(dim_feats_out, dim_vectors_out) if vector_gating else None

    def forward(self, data):
        feats, vectors = data
        batch, feat_dim = feats.shape
        _, vec_dim, coord_dim = vectors.shape

        assert coord_dim == 3 and vec_dim == self.dim_vectors_in, 'vectors have wrong dimensions'
        assert feat_dim == self.dim_feats_in, 'scalar features have wrong dimensions'

        hidden = einsum('b v c, v h -> b h c', vectors, self.Wh)
        updated = einsum('b h c, h u -> b u c', hidden, self.Wu)

        # per-channel L2 norms of the hidden vectors: rotation invariant scalars
        hidden_norms = hidden.norm(p = 2, dim = -1)

        feats_out = self.to_feats_out(torch.cat((feats, hidden_norms), dim = 1))

        if self.scalar_to_vector_gates is not None:
            gate = self.scalar_to_vector_gates(feats_out).unsqueeze(dim = -1)
        else:
            gate = updated.norm(p = 2, dim = -1, keepdim = True)

        vectors_out = self.vectors_activation(gate) * updated
        return (feats_out, vectors_out)
class GVPDropout(nn.Module):
    """ Separate dropout for scalars and vectors. """

    def __init__(self, rate):
        super().__init__()
        # plain dropout on scalar entries; channel-wise (2d) dropout on vectors,
        # so a whole vector is either kept intact or zeroed
        self.scalar_drop = nn.Dropout(rate)
        self.vector_drop = nn.Dropout2d(rate)

    def forward(self, feats, vectors):
        dropped_feats = self.scalar_drop(feats)
        dropped_vectors = self.vector_drop(vectors)
        return dropped_feats, dropped_vectors
class GVPLayerNorm(nn.Module):
    """ Normal layer norm for scalars, nontrainable norm for vectors. """

    def __init__(self, feats_h_size, eps = 1e-8):
        super().__init__()
        self.eps = eps
        self.feat_norm = nn.LayerNorm(feats_h_size)

    def forward(self, feats, vectors):
        # rescale each sample's whole vector block to unit Frobenius norm
        # (eps guards against division by zero); scalars get trainable LayerNorm
        frob = vectors.norm(dim = (-1, -2), keepdim = True)
        return self.feat_norm(feats), vectors / (frob + self.eps)
# Message passing layer built on torch-geometric's MessagePassing.
class GVP_MPNN(MessagePassing):
    r"""The Geometric Vector Perceptron message passing layer
    introduced in https://openreview.net/forum?id=1YLJDvSx6J4.

    Uses a Geometric Vector Perceptron instead of the normal
    MLP in aggregation phase.

    Inputs will be a concatenation of (vectors, features): node features ``x``
    and edge features ``edge_attr`` are flat 2-d tensors whose last dim lays
    out the flattened vector part first (``n_vectors * vector_dim`` entries),
    followed by the scalar part (see ``message``).

    Args:
        * feats_x_in: int. number of scalar dimensions in the x inputs.
        * vectors_x_in: int. number of vector dimensions in the x inputs.
        * feats_x_out: int. number of scalar dimensions in the x outputs.
        * vectors_x_out: int. number of vector dimensions in the x outputs.
        * feats_edge_in: int. number of scalar dimensions in the edge_attr inputs.
        * vectors_edge_in: int. number of vector dimensions in the edge_attr inputs.
        * feats_edge_out: int. number of scalar dimensions in the edge_attr outputs.
        * vectors_edge_out: int. number of vector dimensions in the edge_attr outputs.
        * dropout: float. dropout rate.
        * residual: bool. whether to add the layer input back onto the output.
        * vector_dim: int. dimensions of the space containing the vectors.
        * verbose: bool. verbosity level.
    """
    def __init__(self, feats_x_in, vectors_x_in,
                 feats_x_out, vectors_x_out,
                 feats_edge_in, vectors_edge_in,
                 feats_edge_out, vectors_edge_out,
                 dropout, residual=False, vector_dim=3,
                 verbose=False, **kwargs):
        # neighbor messages are averaged
        super(GVP_MPNN, self).__init__(aggr="mean",**kwargs)
        self.verbose = verbose
        # scalar / vector widths of the node features
        self.feats_x_in = feats_x_in
        self.vectors_x_in = vectors_x_in       # N vector feats in the input
        self.feats_x_out = feats_x_out
        self.vectors_x_out = vectors_x_out     # N vector feats in the output
        # scalar / vector widths of the edge attributes
        self.feats_edge_in = feats_edge_in
        self.vectors_edge_in = vectors_edge_in   # N vector feats in the input
        self.feats_edge_out = feats_edge_out
        self.vectors_edge_out = vectors_edge_out # N vector feats in the output
        # auxiliary layers
        self.vector_dim = vector_dim
        # norms for the two residual sub-blocks in forward()
        self.norm = nn.ModuleList([GVPLayerNorm(self.feats_x_out), # + self.feats_edge_out
                                   GVPLayerNorm(self.feats_x_out)])
        self.dropout = GVPDropout(dropout)
        self.residual = residual
        # message function: 3-layer GVP stack over concatenated (node, edge) channels.
        # NOTE(review): the vector widths below mix in `feats_edge_out` rather than
        # `vectors_edge_out` (e.g. dim_vectors_out of the first GVP) — looks like a
        # typo that is only harmless when the two are equal, as in the tests. Confirm.
        self.W_EV = nn.Sequential(GVP(
                                    dim_vectors_in = self.vectors_x_in + self.vectors_edge_in,
                                    dim_vectors_out = self.vectors_x_out + self.feats_edge_out,
                                    dim_feats_in = self.feats_x_in + self.feats_edge_in,
                                    dim_feats_out = self.feats_x_out + self.feats_edge_out
                                  ),
                                  GVP(
                                    dim_vectors_in = self.vectors_x_out + self.feats_edge_out,
                                    dim_vectors_out = self.vectors_x_out + self.feats_edge_out,
                                    dim_feats_in = self.feats_x_out + self.feats_edge_out,
                                    dim_feats_out = self.feats_x_out + self.feats_edge_out
                                  ),
                                  GVP(
                                    dim_vectors_in = self.vectors_x_out + self.feats_edge_out,
                                    dim_vectors_out = self.vectors_x_out + self.feats_edge_out,
                                    dim_feats_in = self.feats_x_out + self.feats_edge_out,
                                    dim_feats_out = self.feats_x_out + self.feats_edge_out
                                  ))
        # position-wise feedforward: expand (2x vectors, 4x scalars) then project back
        self.W_dh = nn.Sequential(GVP(
                                    dim_vectors_in = self.vectors_x_out,
                                    dim_vectors_out = 2*self.vectors_x_out,
                                    dim_feats_in = self.feats_x_out,
                                    dim_feats_out = 4*self.feats_x_out
                                  ),
                                  GVP(
                                    dim_vectors_in = 2*self.vectors_x_out,
                                    dim_vectors_out = self.vectors_x_out,
                                    dim_feats_in = 4*self.feats_x_out,
                                    dim_feats_out = self.feats_x_out
                                  ))

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                edge_attr: OptTensor = None, size: Size = None) -> Tensor:
        """Run one round of message passing and return updated flat node features."""
        x_size = list(x.shape)[-1]  # NOTE(review): unused — kept as in the original
        # aggregate scalar feats and vectors separately over neighbors
        feats, vectors = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        # dropout after aggregation; vectors reshaped to (n, n_vectors, vector_dim)
        feats, vectors = self.dropout(feats, vectors.reshape(vectors.shape[0], -1, self.vector_dim))
        # keep only the node-related slices — edge outputs are not returned
        feats_nodes = feats[:, :self.feats_x_in]
        vector_nodes = vectors[:, :self.vectors_x_in]
        # reshape the vector part of x back to 3-d
        x_vectors = x[:, :self.vectors_x_in * self.vector_dim].reshape(x.shape[0], -1, self.vector_dim)
        # residual add of the input (scalars are everything after the vector slice) + norm
        feats, vectors = self.norm[0]( x[:, self.vectors_x_in * self.vector_dim:]+feats_nodes, x_vectors+vector_nodes )
        # position-wise feedforward update with its own dropout + norm
        feats_, vectors_ = self.dropout( *self.W_dh( (feats, vectors) ) )
        feats, vectors = self.norm[1]( feats+feats_, vectors+vectors_ )
        # flatten vectors and concatenate back into the flat layout.
        # NOTE(review): the input layout is [vectors | feats] (see message()), but
        # new_x concatenates feats first — confirm this ordering is intended,
        # especially for the residual add below.
        new_x = torch.cat( [feats, vectors.flatten(start_dim=-2)], dim=-1 )
        if self.residual:
            return new_x + x
        return new_x

    def message(self, x_j, edge_attr) -> Tensor:
        # scalar parts of neighbor nodes and edges (everything after the vector slice)
        feats = torch.cat([ x_j[:, self.vectors_x_in * self.vector_dim:],
                            edge_attr[:, self.vectors_edge_in * self.vector_dim:] ], dim=-1)
        # vector parts, reshaped to (n_edges, n_vectors, vector_dim)
        vectors = torch.cat([ x_j[:, :self.vectors_x_in * self.vector_dim],
                              edge_attr[:, :self.vectors_edge_in * self.vector_dim] ], dim=-1).reshape(x_j.shape[0],-1,self.vector_dim)
        feats, vectors = self.W_EV( (feats, vectors) )
        return feats, vectors.flatten(start_dim=-2)

    def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
        r"""The initial call to start propagating messages.

        Overridden so that the scalar and vector channels produced by
        ``message`` can be aggregated independently (same "mean" reduction).

        NOTE(review): this mirrors the internals of an older torch-geometric
        ``MessagePassing.propagate`` (``__check_input__`` / ``__collect__`` /
        ``inspector.distribute`` are private APIs that changed in later
        versions) — pin the torch-geometric version accordingly.

        Args:
            adj (Tensor or SparseTensor): `edge_index` holds the indices of a general (sparse)
                assignment matrix of shape :obj:`[N, M]`.
            size (tuple, optional): If set to :obj:`None`, the size will be automatically inferred
                and assumed to be quadratic.
                This argument is ignored in case :obj:`edge_index` is a
                :obj:`torch_sparse.SparseTensor`. (default: :obj:`None`)
            **kwargs: Any additional data which is needed to construct and
                aggregate messages, and to update node embeddings.
        """
        size = self.__check_input__(edge_index, size)
        coll_dict = self.__collect__(self.__user_args__,
                                     edge_index, size, kwargs)
        msg_kwargs = self.inspector.distribute('message', coll_dict)
        feats, vectors = self.message(**msg_kwargs)
        # aggregate scalars and vectors separately
        aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
        out_feats = self.aggregate(feats, **aggr_kwargs)
        out_vectors = self.aggregate(vectors, **aggr_kwargs)
        # return the (feats, vectors) tuple
        update_kwargs = self.inspector.distribute('update', coll_dict)
        return self.update((out_feats, out_vectors), **update_kwargs)

    def __repr__(self):
        dict_print = { "feats_x_in": self.feats_x_in,
                       "vectors_x_in": self.vectors_x_in,
                       "feats_x_out": self.feats_x_out,
                       "vectors_x_out": self.vectors_x_out,
                       "feats_edge_in": self.feats_edge_in,
                       "vectors_edge_in": self.vectors_edge_in,
                       "feats_edge_out": self.feats_edge_out,
                       "vectors_edge_out": self.vectors_edge_out,
                       "vector_dim": self.vector_dim }
        return 'GVP_MPNN Layer with the following attributes: ' + str(dict_print)
class GVP_Network(nn.Module):
    r"""Sample GNN model architecture that uses the Geometric Vector Perceptron
    message passing layer to learn over point clouds.
    Main MPNN layer introduced in https://openreview.net/forum?id=1YLJDvSx6J4.

    Inputs will be standard GNN: x, edge_index, edge_attr, batch, ...

    Args:
        * n_layers: int. number of MPNN layers
        * feats_x_in: int. number of scalar dimensions in the x inputs.
        * vectors_x_in: int. number of vector dimensions in the x inputs.
        * feats_x_out: int. number of scalar dimensions in the x outputs.
        * vectors_x_out: int. number of vector dimensions in the x outputs.
        * feats_edge_in: int. number of scalar dimensions in the edge_attr inputs.
        * vectors_edge_in: int. number of vector dimensions in the edge_attr inputs.
        * feats_edge_out: int. number of scalar dimensions in the edge_attr outputs.
        * embedding_nums: list. number of unique keys to embedd. for points
                          1 entry per embedding needed.
        * embedding_dims: list. point - number of dimensions of
                          the resulting embedding. 1 entry per embedding needed.
        * edge_embedding_nums: list. number of unique keys to embedd. for edges.
                               1 entry per embedding needed.
        * edge_embedding_dims: list. point - number of dimensions of
                               the resulting embedding. 1 entry per embedding needed.
        * vectors_edge_out: int. number of vector dimensions in the edge_attr outputs.
        * dropout: float. dropout rate.
        * vector_dim: int. dimensions of the space containing the vectors.
        * recalc: int. recalculate edge features every `recalc` MPNN layers
                  (docstring originally said bool — see NOTE in forward()).
        * verbose: bool. verbosity level.
    """
    # NOTE(review): mutable default args ([]) are only read, never mutated here,
    # so they are safe — but worth normalizing to None in a behavior change pass.
    def __init__(self, n_layers,
                 feats_x_in, vectors_x_in,
                 feats_x_out, vectors_x_out,
                 feats_edge_in, vectors_edge_in,
                 feats_edge_out, vectors_edge_out,
                 embedding_nums=[], embedding_dims=[],
                 edge_embedding_nums=[], edge_embedding_dims=[],
                 dropout=0.0, residual=False, vector_dim=3,
                 recalc=1, verbose=False):
        super().__init__()
        self.n_layers = n_layers
        self.embedding_nums = embedding_nums
        self.embedding_dims = embedding_dims
        self.emb_layers = torch.nn.ModuleList()
        self.edge_embedding_nums = edge_embedding_nums
        self.edge_embedding_dims = edge_embedding_dims
        self.edge_emb_layers = torch.nn.ModuleList()
        # instantiate the point and edge embedding layers; each embedding replaces
        # one integer column with `embedding_dim` channels, hence the `- 1` below
        for i in range( len(self.embedding_dims) ):
            self.emb_layers.append(nn.Embedding(num_embeddings = embedding_nums[i],
                                                embedding_dim = embedding_dims[i]))
            feats_x_in += embedding_dims[i] - 1
            feats_x_out += embedding_dims[i] - 1
        for i in range( len(self.edge_embedding_dims) ):
            self.edge_emb_layers.append(nn.Embedding(num_embeddings = edge_embedding_nums[i],
                                                     embedding_dim = edge_embedding_dims[i]))
            feats_edge_in += edge_embedding_dims[i] - 1
            feats_edge_out += edge_embedding_dims[i] - 1
        # remaining attributes (note: widths above were adjusted by the embeddings)
        self.fc_layers = torch.nn.ModuleList()
        self.gcnn_layers = torch.nn.ModuleList()
        self.feats_x_in = feats_x_in
        self.vectors_x_in = vectors_x_in
        self.feats_x_out = feats_x_out
        self.vectors_x_out = vectors_x_out
        self.feats_edge_in = feats_edge_in
        self.vectors_edge_in = vectors_edge_in
        self.feats_edge_out = feats_edge_out
        self.vectors_edge_out = vectors_edge_out
        self.dropout = dropout
        self.residual = residual
        self.vector_dim = vector_dim
        self.recalc = recalc
        self.verbose = verbose
        # instantiate the GVP message passing layers
        for i in range(n_layers):
            layer = GVP_MPNN(feats_x_in, vectors_x_in,
                             feats_x_out, vectors_x_out,
                             feats_edge_in, vectors_edge_in,
                             feats_edge_out, vectors_edge_out,
                             dropout, residual=residual,
                             vector_dim=vector_dim, verbose=verbose)
            self.gcnn_layers.append(layer)

    def forward(self, x, edge_index, batch, edge_attr,
                bsize=None, recalc_edge=None, verbose=0):
        """ Embedding of inputs when necessary, then pass layers.
            Recalculate edge features every time with the
            `recalc_edge` function.

            NOTE(review): `batch` is accepted but never used in this body;
            `original_x` is cloned but never read. Presumably kept for API
            symmetry — confirm before removing.
        """
        original_x = x.clone()
        original_edge_index = edge_index.clone()
        original_edge_attr = edge_attr.clone()
        # embed when necessary: replace the trailing integer columns of x with
        # their learned embeddings, one at a time
        to_embedd = x[:, -len(self.embedding_dims):].long()
        for i, emb_layer in enumerate(self.emb_layers):
            # the portion corresponding to `to_embedd` is dropped on the 1st iteration
            stop_concat = -len(self.embedding_dims) if i == 0 else x.shape[-1]
            x = torch.cat([x[:, :stop_concat],
                           emb_layer(to_embedd[:, i])], dim=-1)
        # pass the MPNN layers
        for i, layer in enumerate(self.gcnn_layers):
            # embed edge attrs (done every pass: edge_attr/edge_index may be recalculated)
            to_embedd = edge_attr[:, -len(self.edge_embedding_dims):].long()
            for j, edge_emb_layer in enumerate(self.edge_emb_layers):
                # NOTE(review): `stop_concat` is computed but unused here — the
                # cat below slices with `-len(...) + j` instead. Confirm intent.
                stop_concat = -len(self.edge_embedding_dims) if j == 0 else x.shape[-1]
                edge_attr = torch.cat([edge_attr[:, :-len(self.edge_embedding_dims) + j],
                                       edge_emb_layer(to_embedd[:, j])], dim=-1)
            # pass the layer
            x = layer(x, edge_index, edge_attr, size=bsize)
            # recalculate edge info every `self.recalc` steps,
            # but not after the very last layer.
            # NOTE(review): `1 % self.recalc` is constant (0 iff recalc == 1), so
            # for recalc > 1 edges are never recalculated — likely meant to be
            # `i % self.recalc`. Left unchanged pending confirmation.
            if (1 % self.recalc == 0) and not (i == self.n_layers - 1):
                edge_index, edge_attr, _ = recalc_edge(x) # returns attrs, idx, any other info
            else:
                edge_attr = original_edge_attr.clone()
                edge_index = original_edge_index.clone()
            if verbose:
                print("========")
                # NOTE(review): `j` here is the stale index of the inner
                # edge-embedding loop, not the layer iteration — confirm.
                print("iter:", j, "layer:", i, "nlinks:", edge_attr.shape)
        return x

    def __repr__(self):
        return 'GVP_Network of: {0} layers'.format(len(self.gcnn_layers))
.\lucidrains\geometric-vector-perceptron\geometric_vector_perceptron\__init__.py
# 从 geometric_vector_perceptron 模块中导入 GVP, GVPDropout, GVPLayerNorm, GVP_MPNN, GVP_Network 类
from geometric_vector_perceptron.geometric_vector_perceptron import GVP, GVPDropout, GVPLayerNorm, GVP_MPNN, GVP_Network

Geometric Vector Perceptron
Implementation of Geometric Vector Perceptron, a simple circuit with 3d rotation equivariance for learning over large biomolecules, in Pytorch. The repository may also contain experimentation to see if this could be easily extended to self-attention.
Install
$ pip install geometric-vector-perceptron
Functionality
GVP: Implements the basic geometric vector perceptron.
GVPDropout: Adapted dropout for GVP in the MPNN context.
GVPLayerNorm: Adapted LayerNorm for GVP in the MPNN context.
GVP_MPNN: Adapted instance of the Message Passing class from the torch-geometric package. Still not tested.
GVP_Network: Functional model architecture ready for working with arbitrary point clouds.
Usage
import torch
from geometric_vector_perceptron import GVP
model = GVP(
dim_vectors_in = 1024,
dim_feats_in = 512,
dim_vectors_out = 256,
dim_feats_out = 512,
vector_gating = True # use the vector gating as proposed in https://arxiv.org/abs/2106.03843
)
feats, vectors = (torch.randn(1, 512), torch.randn(1, 1024, 3))
feats_out, vectors_out = model( (feats, vectors) ) # (1, 256), (1, 512, 3)
With the specialized dropout and layernorm as described in the paper
import torch
from torch import nn
from geometric_vector_perceptron import GVP, GVPDropout, GVPLayerNorm
model = GVP(
dim_vectors_in = 1024,
dim_feats_in = 512,
dim_vectors_out = 256,
dim_feats_out = 512,
vector_gating = True
)
dropout = GVPDropout(0.2)
norm = GVPLayerNorm(512)
feats, vectors = (torch.randn(1, 512), torch.randn(1, 1024, 3))
feats, vectors = model( (feats, vectors) )
feats, vectors = dropout(feats, vectors)
feats, vectors = norm(feats, vectors) # (1, 256), (1, 512, 3)
TF implementation:
The original implementation in TF by the paper authors can be found here: github.com/drorlab/gvp…
Citations
@inproceedings{anonymous2021learning,
title = {Learning from Protein Structure with Geometric Vector Perceptrons},
author = {Anonymous},
booktitle = {Submitted to International Conference on Learning Representations},
year = {2021},
url = {https://openreview.net/forum?id=1YLJDvSx6J4}
}
@misc{jing2021equivariant,
title = {Equivariant Graph Neural Networks for 3D Macromolecular Structure},
author = {Bowen Jing and Stephan Eismann and Pratham N. Soni and Ron O. Dror},
year = {2021},
eprint = {2106.03843},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
.\lucidrains\geometric-vector-perceptron\setup.py
# setuptools helpers: setup() registers the distribution, find_packages() discovers packages
from setuptools import setup, find_packages

# package metadata and build configuration
setup(
  name = 'geometric-vector-perceptron',                              # distribution name
  packages = find_packages(),                                        # include every discovered package
  version = '0.0.14',                                                # release version
  license='MIT',                                                     # license
  description = 'Geometric Vector Perceptron - Pytorch',             # short description
  author = 'Phil Wang, Eric Alcaide',                                # authors
  author_email = 'lucidrains@gmail.com',                             # contact email
  url = 'https://github.com/lucidrains/geometric-vector-perceptron', # project homepage
  keywords = [                                                       # search keywords
    'artificial intelligence',
    'deep learning',
    'proteins',
    'biomolecules',
    'equivariance'
  ],
  install_requires=[                                                 # runtime dependencies
    'torch>=1.6',
    'torch-scatter',
    'torch-sparse',
    'torch-cluster',
    'torch-spline-conv',
    'torch-geometric'
  ],
  setup_requires=[                                                   # build-time dependencies
    'pytest-runner',
  ],
  tests_require=[                                                    # test dependencies
    'pytest'
  ],
  classifiers=[                                                      # PyPI trove classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
.\lucidrains\geometric-vector-perceptron\tests\tests.py
# 导入 torch 库
import torch
# 从 geometric_vector_perceptron 库中导入 GVP, GVPDropout, GVPLayerNorm, GVP_MPNN
from geometric_vector_perceptron import GVP, GVPDropout, GVPLayerNorm, GVP_MPNN
# numeric tolerance used by the equivariance assertions below
TOL = 1e-2
def random_rotation():
    """Sample a random 3x3 proper rotation matrix (orthogonal, det = +1).

    QR decomposition of a Gaussian matrix yields an orthogonal Q, but Q may be
    a reflection (det = -1). Flipping the sign of one column restores det = +1
    without breaking orthogonality, so the result is always a true rotation.
    """
    q, r = torch.qr(torch.randn(3, 3))
    if torch.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q
def diff_matrix(vectors):
    """All pairwise differences v_i - v_j, flattened to (b, n*n, d)."""
    b, _, d = vectors.shape
    # broadcast (b, n, 1, d) against (b, 1, n, d)
    pairwise = vectors.unsqueeze(-2) - vectors.unsqueeze(1)
    return pairwise.reshape(b, -1, d)
def test_equivariance():
    """The vector outputs of GVP must rotate along with the inputs."""
    rotation = random_rotation()

    model = GVP(
        dim_vectors_in = 1024,
        dim_feats_in = 512,
        dim_vectors_out = 256,
        dim_feats_out = 512
    )

    feats = torch.randn(1, 512)
    vectors = torch.randn(1, 32, 3)

    # one pass on the raw pairwise differences, one on the rotated ones
    feats_out, vectors_out = model( (feats, diff_matrix(vectors)) )
    feats_out_r, vectors_out_r = model( (feats, diff_matrix(vectors @ rotation)) )

    # rotating the first pass' outputs should reproduce the second pass
    err = ((vectors_out @ rotation) - vectors_out_r).max()
    assert err < TOL, 'equivariance must be respected'
def test_all_layer_types():
    """Chain GVP -> GVPDropout -> GVPLayerNorm and check every output shape."""
    R = random_rotation()  # unused draw, kept so downstream RNG state matches

    model = GVP(
        dim_vectors_in = 1024,
        dim_feats_in = 512,
        dim_vectors_out = 256,
        dim_feats_out = 512
    )
    dropout = GVPDropout(0.2)
    layer_norm = GVPLayerNorm(512)

    feats = torch.randn(1, 512)
    message = torch.randn(1, 512)  # unused draw, kept for RNG parity
    vectors = torch.randn(1, 32, 3)

    expected = ([1, 512], [1, 256, 3])

    # GVP layer
    feats_out, vectors_out = model( (feats, diff_matrix(vectors)) )
    assert (list(feats_out.shape), list(vectors_out.shape)) == expected

    # specialized dropout
    feats_out, vectors_out = dropout(feats_out, vectors_out)
    assert (list(feats_out.shape), list(vectors_out.shape)) == expected

    # specialized layer norm
    feats_out, vectors_out = layer_norm(feats_out, vectors_out)
    assert (list(feats_out.shape), list(vectors_out.shape)) == expected
def test_mpnn():
    """GVP_MPNN must map node features back to their original flat shape."""
    # five nodes, 32 dims each: 8 vector feats * 3 coords + 8 scalars
    x = torch.randn(5, 32)
    edge_idx = torch.tensor([[0,2,3,4,1], [1,1,3,3,4]]).long()
    # five edges, 16 dims each: 4 vector feats * 3 coords + 4 scalars
    edge_attr = torch.randn(5, 16)
    dropout = 0.1

    layer_kwargs = dict(
        feats_x_in = 8, vectors_x_in = 8,
        feats_x_out = 8, vectors_x_out = 8,
        feats_edge_in = 4, vectors_edge_in = 4,
        feats_edge_out = 4, vectors_edge_out = 4,
        dropout = 0.1,
    )
    gvp_mpnn = GVP_MPNN(**layer_kwargs)

    x_out = gvp_mpnn(x, edge_idx, edge_attr)
    assert x.shape == x_out.shape, "Input and output shapes don't match"
# Run the full suite when executed directly as a script.
if __name__ == "__main__":
    # GVP equivariance under random rotations
    test_equivariance()
    # shapes through GVP, dropout and layer norm
    test_all_layer_types()
    # end-to-end message passing layer
    test_mpnn()
.\lucidrains\gigagan-pytorch\gigagan_pytorch\attend.py
# 导入必要的库和模块
from functools import wraps
from packaging import version
from collections import namedtuple
import torch
from torch import nn, einsum
import torch.nn.functional as F
# Backend flags forwarded verbatim to torch.backends.cuda.sdp_kernel
# (which scaled-dot-product-attention kernels may be used).
AttentionConfig = namedtuple('AttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
def exists(val):
    """Return True unless val is None."""
    return None is not val
def once(fn):
    """Decorator: run *fn* only on the first call; later calls are no-ops returning None.

    Generalized to accept any signature — the original wrapper took exactly one
    positional argument, which broke the decorator for zero- or multi-argument
    functions while only ever being exercised with ``print``.
    """
    called = False
    @wraps(fn)
    def inner(*args, **kwargs):
        nonlocal called
        if called:
            return
        called = True
        return fn(*args, **kwargs)
    return inner
# print wrapped with `once` so repeated diagnostics are emitted a single time
print_once = once(print)
# main attention class
class Attend(nn.Module):
    """Attention kernel: uses PyTorch SDPA ("flash") when enabled, otherwise an
    explicit einsum softmax attention with dropout."""

    def __init__(
        self,
        dropout = 0.,
        flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # sdp backend configs for cpu / cuda; cpu allows every backend
        self.cpu_config = AttentionConfig(True, True, True)
        self.cuda_config = None

        # nothing more to decide unless flash attention on cuda is requested
        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            # A100 (sm_80): flash kernel only
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(True, False, False)
        else:
            # anything else: fall back to math / memory-efficient kernels
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(False, True, True)

    def flash_attn(self, q, k, v):
        # pick the backend config matching the tensors' device
        config = self.cuda_config if q.is_cuda else self.cpu_config
        q, k, v = (t.contiguous() for t in (q, k, v))

        # run scaled-dot-product attention under the selected kernel config
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            return F.scaled_dot_product_attention(
                q, k, v,
                dropout_p = self.dropout if self.training else 0.
            )

    def forward(self, q, k, v):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        if self.flash:
            return self.flash_attn(q, k, v)

        scale = q.shape[-1] ** -0.5

        # similarity logits -> softmax attention -> dropout -> weighted values
        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
        attn = self.attn_dropout(sim.softmax(dim = -1))
        return einsum("b h i j, b h j d -> b h i d", attn, v)
.\lucidrains\gigagan-pytorch\gigagan_pytorch\data.py
# 导入必要的库
from functools import partial
from pathlib import Path
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms as T
from beartype.door import is_bearable
from beartype.typing import Tuple
# 辅助函数
def exists(val):
    """A value counts as existing when it is anything other than None."""
    return not (val is None)
def convert_image_to_fn(img_type, image):
    """Return *image* in mode *img_type*, skipping the conversion when it already matches."""
    if image.mode != img_type:
        image = image.convert(img_type)
    return image
# custom collate so datasets may return strings, which are gathered into List[str]
def collate_tensors_or_str(data):
    """Collate a batch whose samples are either single tensors or tuples mixing
    tensors and strings; tensor fields are stacked, string fields become lists."""
    first = data[0]
    if not isinstance(first, tuple):
        # single-field dataset: stack straight into one batched tensor
        return (torch.stack(data),)

    collated = []
    for field in zip(*data):
        if is_bearable(field, Tuple[str, ...]):
            collated.append(list(field))
        else:
            collated.append(torch.stack(field))
    return tuple(collated)
# 数据集类
# folder-backed image dataset yielding fixed-size tensors
class ImageDataset(Dataset):
    """Recursively indexes image files under *folder* and serves them as
    center-cropped tensors of side *image_size*."""

    def __init__(
        self,
        folder,
        image_size,
        exts = ['jpg', 'jpeg', 'png', 'tiff'],
        augment_horizontal_flip = False,
        convert_image_to = None
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size

        # every file under `folder` (recursively) matching one of the extensions
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        assert len(self.paths) > 0, 'your folder contains no images'
        assert len(self.paths) > 100, 'you need at least 100 images, 10k for research paper, millions for miraculous results (try Laion-5B)'

        # optional color-mode conversion, otherwise a no-op
        if exists(convert_image_to):
            maybe_convert_fn = partial(convert_image_to_fn, convert_image_to)
        else:
            maybe_convert_fn = nn.Identity()

        self.transform = T.Compose([
            T.Lambda(maybe_convert_fn),
            T.Resize(image_size),
            T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def get_dataloader(self, *args, **kwargs):
        return DataLoader(self, *args, shuffle = True, drop_last = True, **kwargs)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        return self.transform(Image.open(self.paths[index]))
class TextImageDataset(Dataset):
    """Abstract (text, image) dataset — construction must be provided by a subclass."""

    def __init__(self):
        raise NotImplementedError

    def get_dataloader(self, *args, **kwargs):
        # strings and tensors need the custom collate
        return DataLoader(self, *args, collate_fn = collate_tensors_or_str, **kwargs)
# synthetic (image, caption) pairs for smoke-testing training loops
class MockTextImageDataset(TextImageDataset):
    """Serves random noise images paired with a fixed caption."""

    def __init__(
        self,
        image_size,
        length = int(1e5),
        channels = 3
    ):
        # deliberately no super().__init__(): the parent constructor raises
        self.image_size = image_size
        self.channels = channels
        self.length = length

    def get_dataloader(self, *args, **kwargs):
        return DataLoader(self, *args, collate_fn = collate_tensors_or_str, **kwargs)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        noise_image = torch.randn(self.channels, self.image_size, self.image_size)
        return noise_image, 'mock text'