Lucidrains 系列项目源码解析(八十八)
.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\basis.py
import os
from math import pi
import torch
from torch import einsum
from einops import rearrange
from itertools import product
from contextlib import contextmanager
from se3_transformer_pytorch.irr_repr import irr_repr, spherical_harmonics
from se3_transformer_pytorch.utils import torch_default_dtype, cache_dir, exists, default, to_order
from se3_transformer_pytorch.spherical_harmonics import clear_spherical_harmonics_cache
# Directory used to cache precomputed equivariant bases; override with the
# CACHE_PATH env var, or disable caching entirely by setting CLEAR_CACHE.
CACHE_PATH = default(os.getenv('CACHE_PATH'), os.path.expanduser('~/.cache.equivariant_attention'))
CACHE_PATH = CACHE_PATH if not exists(os.environ.get('CLEAR_CACHE')) else None

# Fixed set of random ZYZ Euler angle triples; used below to build the
# Sylvester systems whose common null space yields the Q_J basis transformation.
RANDOM_ANGLES = [
    [4.41301023, 5.56684102, 4.59384642],
    [4.93325116, 6.12697327, 4.14574096],
    [0.53878964, 4.09050444, 5.36539036],
    [2.16017393, 3.48835314, 5.55174441],
    [2.52385107, 0.2908958, 3.90040975]
]
@contextmanager
def null_context():
    """A no-op context manager (used in place of `torch.no_grad` when gradients are wanted)."""
    yield None
def get_matrix_kernel(A, eps = 1e-10):
    '''
    Compute an orthonormal basis (x_1, x_2, ...) of the kernel of matrix A:
        A x_i = 0
        scalar_product(x_i, x_j) = delta_ij

    :param A: matrix (expected square or tall; a reduced SVD is used)
    :param eps: singular values below this threshold are treated as zero
    :return: matrix whose rows are the basis vectors of the kernel of A
    '''
    # torch.svd is deprecated; torch.linalg.svd(full_matrices=False) is the
    # drop-in equivalent (returns Vh = V.t() directly)
    _u, s, vh = torch.linalg.svd(A, full_matrices = False)
    kernel = vh[s < eps]
    return kernel
def get_matrices_kernel(As, eps = 1e-10):
    '''Compute the common kernel of all matrices in `As` by stacking them row-wise.'''
    stacked = torch.cat(As, dim = 0)
    return get_matrix_kernel(stacked, eps)
def get_spherical_from_cartesian(cartesian, divide_radius_by = 1.0):
    """Convert cartesian coordinates to spherical (radius, alpha, beta).

    Angle convention (matching the spherical harmonics code):
      - theta: colatitude / polar angle, 0 at the north pole (X, Y, Z) = (0, 0, 1)
        and pi at the south pole (X, Y, Z) = (0, 0, -1)
      - phi: longitude / azimuthal angle, in [0, 2 pi)
      - the 3D steerable CNN convention used downstream is beta = pi - theta,
        alpha = phi

    NOTE: the input's last axis is laid out as (y, z, x) — see the index
    constants below.
    """
    spherical = torch.zeros_like(cartesian)

    # output channel layout
    ind_radius, ind_alpha, ind_beta = 0, 1, 2
    # input channel layout: y first, z second, x last
    cartesian_x, cartesian_y, cartesian_z = 2, 0, 1

    x = cartesian[..., cartesian_x]
    y = cartesian[..., cartesian_y]
    z = cartesian[..., cartesian_z]

    squared_xy = x ** 2 + y ** 2

    spherical[..., ind_beta] = torch.atan2(torch.sqrt(squared_xy), z)
    spherical[..., ind_alpha] = torch.atan2(y, x)

    radius = torch.sqrt(squared_xy + z ** 2)
    if divide_radius_by != 1.0:
        radius = radius / divide_radius_by
    spherical[..., ind_radius] = radius

    return spherical
def kron(a, b):
    """Kronecker product over the last two dimensions of `a` and `b` (batched)."""
    outer = einsum('... i j, ... k l -> ... i k j l', a, b)
    *batch, i, k, j, l = outer.shape
    return outer.reshape(*batch, i * k, j * l)
def get_R_tensor(order_out, order_in, a, b, c):
    """Tensor (Kronecker) product of the output- and input-degree Wigner-D representations.

    Bug fix: the original return line was missing its closing parenthesis.
    """
    return kron(irr_repr(order_out, a, b, c), irr_repr(order_in, a, b, c))
def sylvester_submatrix(order_out, order_in, J, a, b, c):
    '''Generate the Kronecker product matrix for solving the Sylvester equation in subspace J.

    Bug fix: the `torch.eye(...)` call for `R_irrep_J_identity` was missing its
    closing parenthesis.
    '''
    R_tensor = get_R_tensor(order_out, order_in, a, b, c)
    R_irrep_J = irr_repr(J, a, b, c)

    R_tensor_identity = torch.eye(R_tensor.shape[0])
    R_irrep_J_identity = torch.eye(R_irrep_J.shape[0])

    # vec-trick form of the Sylvester equation R_tensor X = X R_irrep_J
    return kron(R_tensor, R_irrep_J_identity) - kron(R_tensor_identity, R_irrep_J.t())
def basis_transformation_Q_J(J, order_in, order_out, random_angles = RANDOM_ANGLES):
    """
    :param J: order of the spherical harmonics
    :param order_in: order of the input representation
    :param order_out: order of the output representation
    :return: one part of the Q^-1 matrix of the paper
    """
    # the intertwiner is the common null space of Sylvester systems built
    # from several random rotations
    sylvester_submatrices = [
        sylvester_submatrix(order_out, order_in, J, *angles)
        for angles in random_angles
    ]
    null_space = get_matrices_kernel(sylvester_submatrices)
    assert null_space.size(0) == 1, null_space.size()  # the solution should be unique

    Q_J = null_space[0]
    Q_J = Q_J.view(to_order(order_out) * to_order(order_in), to_order(J))
    return Q_J.float()
def precompute_sh(r_ij, max_J):
    """Precompute spherical harmonics up to and including order `max_J`.

    :param r_ij: relative positions in spherical form (radius, alpha, beta)
    :param max_J: maximum order used anywhere in the network
    :return: dict mapping J -> tensor of shape [B, N, K, 2J + 1]
    """
    i_alpha, i_beta = 1, 2

    Y_Js = {}
    for J in range(max_J + 1):
        Y_Js[J] = spherical_harmonics(J, r_ij[..., i_alpha], r_ij[..., i_beta])

    clear_spherical_harmonics_cache()
    return Y_Js
def get_basis(r_ij, max_degree, differentiable = False):
    """Return equivariant weight basis (basis)

    Call this function *once* at the start of each forward pass of the model.
    It computes the equivariant weight basis, W_J^lk(x), and internodal
    distances, needed to compute varphi_J^lk(x), of eqn 8 of
    https://arxiv.org/pdf/2006.10503.pdf. The return values of this function
    can be shared as input across all SE(3)-Transformer layers in a model.

    Args:
        r_ij: relative positional vectors
        max_degree: non-negative int for degree of highest feature-type
        differentiable: whether r_ij should receive gradients from basis
    Returns:
        dict of equivariant bases, keys are in form '<d_in>,<d_out>'
    """
    # bug fix: the condition was inverted — only track gradients when the
    # basis is requested to be differentiable; otherwise compute under no_grad
    # (the detach loop at the end confirms this intent)
    context = null_context if differentiable else torch.no_grad

    device, dtype = r_ij.device, r_ij.dtype

    with context():
        r_ij = get_spherical_from_cartesian(r_ij)

        # spherical harmonics up to 2 * max_degree, shared by all degree pairs
        Y = precompute_sh(r_ij, 2 * max_degree)

        basis = {}
        for d_in, d_out in product(range(max_degree + 1), range(max_degree + 1)):
            K_Js = []
            # J ranges over the degrees in the Clebsch-Gordan decomposition
            for J in range(abs(d_in - d_out), d_in + d_out + 1):
                Q_J = basis_transformation_Q_J(J, d_in, d_out)
                Q_J = Q_J.type(dtype).to(device)
                K_J = torch.matmul(Y[J], Q_J.T)
                K_Js.append(K_J)

            K_Js = torch.stack(K_Js, dim = -1)
            size = (*r_ij.shape[:-1], 1, to_order(d_out), 1, to_order(d_in), to_order(min(d_in, d_out)))
            basis[f'{d_in},{d_out}'] = K_Js.view(*size)

    # ensure no autograd graph leaks out when differentiability is not requested
    if not differentiable:
        for k, v in basis.items():
            basis[k] = v.detach()

    return basis
.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\irr_repr.py
import os
import numpy as np
import torch
from torch import sin, cos, atan2, acos
from math import pi
from pathlib import Path
from functools import wraps
from se3_transformer_pytorch.utils import exists, default, cast_torch_tensor, to_order
from se3_transformer_pytorch.spherical_harmonics import get_spherical_harmonics, clear_spherical_harmonics_cache
# Precomputed J matrices (change-of-basis used to build Wigner-D matrices).
# Prefer the serialized torch tensor; fall back to the numpy archive.
DATA_PATH = Path(os.path.dirname(__file__)) / 'data'

try:
    path = DATA_PATH / 'J_dense.pt'
    Jd = torch.load(str(path))
except Exception:
    # bug fix: was a bare `except:`, which would also swallow KeyboardInterrupt
    # and SystemExit; `Exception` keeps the broad fallback without that hazard
    path = DATA_PATH / 'J_dense.npy'
    Jd_np = np.load(str(path), allow_pickle = True)
    Jd = list(map(torch.from_numpy, Jd_np))
def wigner_d_matrix(degree, alpha, beta, gamma, dtype = None, device = None):
    """Create wigner D matrices for batch of ZYZ Euler anglers for degree l."""
    # NOTE(review): Tensor.type(None) returns a type *string*, not a tensor, so
    # calling this with dtype=None would fail at `.to(device)`; the only caller
    # here (irr_repr) always passes a concrete dtype — confirm before using directly.
    J = Jd[degree].type(dtype).to(device)
    # dimensionality of the degree-l representation
    order = to_order(degree)
    x_a = z_rot_mat(alpha, degree)
    x_b = z_rot_mat(beta, degree)
    x_c = z_rot_mat(gamma, degree)
    # J is the precomputed change-of-basis loaded above; sandwiching the middle
    # rotation between two J's lets all three rotations reuse z_rot_mat
    res = x_a @ J @ x_b @ J @ x_c
    return res.view(order, order)
def z_rot_mat(angle, l):
    """Matrix of a rotation about the z axis acting on the degree-l representation space."""
    device, dtype = angle.device, angle.dtype
    order = to_order(l)

    out = angle.new_zeros((order, order))

    diag = torch.arange(0, order, 1, dtype = torch.long, device = device)
    anti_diag = torch.arange(2 * l, -1, -1, dtype = torch.long, device = device)
    freqs = torch.arange(l, -l - 1, -1, dtype = dtype, device = device)[None]

    # sines fill the anti-diagonal first; cosines then overwrite the main diagonal
    out[diag, anti_diag] = sin(freqs * angle[None])
    out[diag, diag] = cos(freqs * angle[None])
    return out
def irr_repr(order, alpha, beta, gamma, dtype = None):
    """Irreducible representation (Wigner-D matrix) of SO(3).

    Compatible with `compose` and `spherical_harmonics`.
    """
    to_tensor = cast_torch_tensor(lambda t: t)
    angles = tuple(map(to_tensor, (alpha, beta, gamma)))
    dtype = default(dtype, torch.get_default_dtype())
    return wigner_d_matrix(order, *angles, dtype = dtype)
@cast_torch_tensor
def rot_z(gamma):
    '''Rotation matrix about the Z axis by angle `gamma`.'''
    c, s = cos(gamma), sin(gamma)
    return torch.tensor([
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1]
    ], dtype = gamma.dtype)
@cast_torch_tensor
def rot_y(beta):
    '''Rotation matrix about the Y axis by angle `beta`.'''
    c, s = cos(beta), sin(beta)
    return torch.tensor([
        [c, 0, s],
        [0, 1, 0],
        [-s, 0, c]
    ], dtype = beta.dtype)
@cast_torch_tensor
def x_to_alpha_beta(x):
    '''Convert a point (x, y, z) on the sphere into angles (alpha, beta).'''
    unit = x / torch.norm(x)
    beta = acos(unit[2])
    alpha = atan2(unit[1], unit[0])
    return (alpha, beta)
def rot(alpha, beta, gamma):
    '''Rotation matrix from ZYZ Euler angles.'''
    return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
def compose(a1, b1, c1, a2, b2, c2):
    """Compose two ZYZ Euler rotations: (a, b, c) = (a1, b1, c1) ∘ (a2, b2, c2)."""
    composed = rot(a1, b1, c1) @ rot(a2, b2, c2)

    # recover (alpha, beta) from where the composition sends the north pole
    north = composed @ torch.tensor([0, 0, 1.])
    a, b = x_to_alpha_beta(north)

    # undo those two rotations; what remains is a pure z-rotation by gamma
    remainder = rot(0, -b, -a) @ composed
    c = atan2(remainder[1, 0], remainder[0, 0])
    return a, b, c
def spherical_harmonics(order, alpha, beta, dtype = None):
    """Spherical harmonics under the (alpha, beta) convention used by irr_repr (theta = pi - beta, phi = alpha)."""
    theta = pi - beta
    return get_spherical_harmonics(order, theta = theta, phi = alpha)
.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\reversible.py
import torch
import torch.nn as nn
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states
def map_values(fn, x):
    """Apply `fn` to every value of dict `x`, returning a new dict with the same keys."""
    # idiomatic dict comprehension instead of a manual accumulation loop
    return {k: fn(v) for k, v in x.items()}
def dict_chunk(x, chunks, dim):
    """Split every tensor value of `x` into two chunks along `dim`; returns the pair of dicts."""
    first, second = {}, {}
    for key, tensor in x.items():
        first[key], second[key] = tensor.chunk(chunks, dim = dim)
    return first, second
def dict_sum(x, y):
    """Element-wise sum of two dicts sharing the same keys."""
    # idiomatic dict comprehension instead of a manual accumulation loop
    return {k: x[k] + y[k] for k in x.keys()}
def dict_subtract(x, y):
    """Element-wise difference (x - y) of two dicts sharing the same keys."""
    # idiomatic dict comprehension instead of a manual accumulation loop
    return {k: x[k] - y[k] for k in x.keys()}
def dict_cat(x, y, dim):
    """Concatenate matching tensor values of two dicts along `dim`."""
    return {k: torch.cat((v, y[k]), dim = dim) for k, v in x.items()}
def dict_set_(x, key, value):
    """Set attribute `key` to `value` on every value of the dict, in place."""
    # iterate values directly — the dict keys were unused in the original loop
    for v in x.values():
        setattr(v, key, value)
def dict_backwards_(outputs, grad_tensors):
    """Backpropagate each output tensor with its matching gradient, retaining the graph."""
    for key, output in outputs.items():
        torch.autograd.backward(output, grad_tensors[key], retain_graph = True)
def dict_del_(x):
    """Drop local references to the dict and its values (best-effort memory hint)."""
    for value in x.values():
        del value
    del x
def values(d):
    """Return the dict's values as a list."""
    # use the builtin values view instead of a manual item-unpacking comprehension
    return list(d.values())
class Deterministic(nn.Module):
    """Wraps a module so its forward pass can be replayed later with identical RNG state.

    `record_rng=True` snapshots the CPU (and, if CUDA is initialized, GPU) RNG
    state; `set_rng=True` restores that snapshot inside a forked RNG scope
    before re-running the wrapped module.
    """
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        # snapshot RNG so the stochastic parts of `net` can be replayed exactly
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        # replay under the previously recorded RNG state, without disturbing
        # the ambient generators (fork_rng restores them on exit)
        rng_devices = self.gpu_devices if self.cuda_in_fwd else []
        with torch.random.fork_rng(devices = rng_devices, enabled = True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)
class ReversibleBlock(nn.Module):
    # A reversible residual block over dict-valued features (one tensor per key).
    # Forward splits each feature in half along the last dim into (x1, x2), then
    # computes y1 = x1 + f(x2) and y2 = x2 + g(y1). Activations need not be
    # stored: backward_pass reconstructs (x1, x2) from (y1, y2).
    def __init__(self, f, g):
        super().__init__()
        # wrap both sub-modules so their RNG state can be recorded and replayed
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, **kwargs):
        training = self.training
        x1, x2 = dict_chunk(x, 2, dim=-1)
        y1, y2 = None, None

        # no_grad: the reversible scheme recomputes activations during backward,
        # so no autograd graph is kept here; RNG is recorded (in training) so the
        # recomputation matches exactly
        with torch.no_grad():
            y1 = dict_sum(x1, self.f(x2, record_rng=training, **kwargs))
            y2 = dict_sum(x2, self.g(y1, record_rng=training))

        return dict_cat(y1, y2, dim=-1)

    def backward_pass(self, y, dy, **kwargs):
        # Reconstruct this block's inputs and input-gradients from its outputs
        # and output-gradients, mirroring the forward pass in reverse.
        y1, y2 = dict_chunk(y, 2, dim = -1)
        dict_del_(y)

        dy1, dy2 = dict_chunk(dy, 2, dim = -1)
        dict_del_(dy)

        # recompute g(y1) with the recorded RNG so it matches the forward pass,
        # then backprop dy2 through it onto y1 (gradients accumulate in y1.grad)
        with torch.enable_grad():
            dict_set_(y1, 'requires_grad', True)
            gy1 = self.g(y1, set_rng = True)
            dict_backwards_(gy1, dy2)

        # invert the second residual: x2 = y2 - g(y1)
        with torch.no_grad():
            x2 = dict_subtract(y2, gy1)
            dict_del_(y2)
            dict_del_(gy1)

            # total gradient flowing into y1: direct dy1 plus the path through g
            dx1 = dict_sum(dy1, map_values(lambda t: t.grad, y1))
            dict_del_(dy1)
            dict_set_(y1, 'grad', None)

        # recompute f(x2) (with replayed RNG) and backprop dx1 through it
        with torch.enable_grad():
            dict_set_(x2, 'requires_grad', True)
            fx2 = self.f(x2, set_rng = True, **kwargs)
            dict_backwards_(fx2, dx1)

        # invert the first residual: x1 = y1 - f(x2)
        with torch.no_grad():
            x1 = dict_subtract(y1, fx2)
            dict_del_(y1)
            dict_del_(fx2)

            dx2 = dict_sum(dy2, map_values(lambda t: t.grad, x2))
            dict_del_(dy2)
            dict_set_(x2, 'grad', None)

            # detach so the reconstructed x2 carries no stale graph
            x2 = map_values(lambda t: t.detach(), x2)

        x = dict_cat(x1, x2, dim = -1)
        dx = dict_cat(dx1, dx2, dim = -1)

        return x, dx
class _ReversibleFunction(Function):
    # Custom autograd Function driving the reversible stack: intermediate
    # activations are not saved; each block reconstructs its inputs in backward.
    @staticmethod
    def forward(ctx, x, blocks, kwargs):
        # features arrive flattened into a single tensor so autograd can track
        # them through apply(); recover the per-degree dict here
        input_keys = kwargs.pop('input_keys')
        split_dims = kwargs.pop('split_dims')
        input_values = x.split(split_dims, dim = -1)
        x = dict(zip(input_keys, input_values))

        ctx.kwargs = kwargs
        ctx.split_dims = split_dims
        ctx.input_keys = input_keys

        for block in blocks:
            x = block(x, **kwargs)

        # keep only detached outputs — backward re-derives intermediate states
        ctx.y = map_values(lambda t: t.detach(), x)
        ctx.blocks = blocks

        x = torch.cat(values(x), dim = -1)
        return x

    @staticmethod
    def backward(ctx, dy):
        y = ctx.y
        kwargs = ctx.kwargs
        input_keys = ctx.input_keys
        split_dims = ctx.split_dims

        dy = dy.split(split_dims, dim = -1)
        dy = dict(zip(input_keys, dy))

        # walk the blocks in reverse, reconstructing inputs and grads as we go
        for block in ctx.blocks[::-1]:
            y, dy = block.backward_pass(y, dy, **kwargs)

        dy = torch.cat(values(dy), dim = -1)
        # gradient only w.r.t. the input tensor; blocks and kwargs get None
        return dy, None, None
class SequentialSequence(nn.Module):
    """Non-reversible fallback: apply (attention, feedforward) pairs in order."""
    def __init__(self, blocks):
        super().__init__()
        self.blocks = blocks

    def forward(self, x, **kwargs):
        for attn, ff in self.blocks:
            x = ff(attn(x, **kwargs))
        return x
class ReversibleSequence(nn.Module):
    """Stack of reversible (f, g) blocks over dict-valued features."""
    def __init__(self, blocks):
        super().__init__()
        self.blocks = nn.ModuleList([ReversibleBlock(f, g) for f, g in blocks])

    def forward(self, x, **kwargs):
        # duplicate every feature along the last dim to form the two reversible streams
        x = map_values(lambda t: torch.cat((t, t), dim = -1), x)

        input_keys = x.keys()
        split_dims = tuple(map(lambda t: t.shape[-1], x.values()))
        block_kwargs = {'input_keys': input_keys, 'split_dims': split_dims, **kwargs}

        # flatten to a single tensor so the custom autograd Function can track it
        flat = torch.cat(values(x), dim = -1)
        flat = _ReversibleFunction.apply(flat, self.blocks, block_kwargs)

        # recover the dict and average the two streams back together
        restored = dict(zip(input_keys, flat.split(split_dims, dim = -1)))
        return map_values(lambda t: torch.stack(t.chunk(2, dim = -1)).mean(dim = 0), restored)
.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\rotary.py
import torch
from torch import nn, einsum
from einops import rearrange, repeat
class SinusoidalEmbeddings(nn.Module):
    """Rotary-style sinusoidal frequency embeddings for scalar positions `t`."""
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, t):
        freqs = t[..., None].float() * self.inv_freq[None, :]
        # duplicate each frequency consecutively, equivalent to
        # einops repeat '... d -> ... (d r)' with r = 2
        return freqs.repeat_interleave(2, dim = -1)
def rotate_half(x):
    """Rotate interleaved pairs along dim -2: (x1, x2) -> (-x2, x1)."""
    # split dim -2 into (pairs, 2) — equivalent to einops '... (d j) m -> ... d j m', j = 2
    x = x.reshape(*x.shape[:-2], x.shape[-2] // 2, 2, x.shape[-1])
    x1, x2 = x.unbind(dim = -2)
    return torch.cat((-x2, x1), dim = -2)
def apply_rotary_pos_emb(t, freqs):
    """Apply rotary embedding `freqs` to the leading `rot_dim` entries of `t` along dim -2."""
    rot_dim = freqs.shape[-2]
    t_rot, t_pass = t[..., :rot_dim, :], t[..., rot_dim:, :]
    # standard rotary rotation; the tail (t_pass) is left untouched
    t_rot = (t_rot * freqs.cos()) + (rotate_half(t_rot) * freqs.sin())
    return torch.cat((t_rot, t_pass), dim = -2)
.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\se3_transformer_pytorch.py
from math import sqrt
from itertools import product
from collections import namedtuple
import torch
import torch.nn.functional as F
from torch import nn, einsum
from se3_transformer_pytorch.basis import get_basis
from se3_transformer_pytorch.utils import exists, default, uniq, map_values, batched_index_select, masked_mean, to_order, fourier_encode, cast_tuple, safe_cat, fast_split, rand_uniform, broadcat
from se3_transformer_pytorch.reversible import ReversibleSequence, SequentialSequence
from se3_transformer_pytorch.rotary import SinusoidalEmbeddings, apply_rotary_pos_emb
from einops import rearrange, repeat
FiberEl = namedtuple('FiberEl', ['degrees', 'dim'])
class Fiber(nn.Module):
    """A fiber: the list of (degree, channel-dim) pairs describing an equivariant feature map.

    Stored as a list of FiberEl(degrees, dim). Subclasses nn.Module only so it
    can live inside module trees; it holds no parameters.
    """
    def __init__(
        self,
        structure
    ):
        super().__init__()
        if isinstance(structure, dict):
            structure = [FiberEl(degree, dim) for degree, dim in structure.items()]
        self.structure = structure

    @property
    def dims(self):
        # unique channel dimensions (uniq is a project utility); fresh iterator per access
        return uniq(map(lambda t: t[1], self.structure))

    @property
    def degrees(self):
        # NOTE: returns a fresh iterator on every access
        return map(lambda t: t[0], self.structure)

    @staticmethod
    def create(num_degrees, dim):
        """Alternate constructor: degrees 0..num_degrees-1, each with `dim` channels (or a per-degree tuple)."""
        dim_tuple = dim if isinstance(dim, tuple) else ((dim,) * num_degrees)
        return Fiber([FiberEl(degree, dim) for degree, dim in zip(range(num_degrees), dim_tuple)])

    def __getitem__(self, degree):
        return dict(self.structure)[degree]

    def __iter__(self):
        return iter(self.structure)

    def __mul__(self, fiber):
        # cartesian product of the two structures, e.g. for building pairwise kernels
        return product(self.structure, fiber.structure)

    def __and__(self, fiber):
        """Intersect on degrees: returns (degree, dim_self, dim_other) triples."""
        out = []
        # fix: the original bound `fiber.degrees` to an unused local and then
        # re-scanned a fresh `fiber.degrees` iterator on every loop iteration;
        # hoist a membership set once instead
        other_degrees = set(fiber.degrees)
        for degree, dim in self:
            if degree in other_degrees:
                dim_out = fiber[degree]
                out.append((degree, dim, dim_out))
        return out
def get_tensor_device_and_dtype(features):
    """Return the (device, dtype) of the first tensor in the features dict."""
    # use the values view directly instead of unpacking items()
    first_tensor = next(iter(features.values()))
    return first_tensor.device, first_tensor.dtype
class ResidualSE3(nn.Module):
    """ only support instance where both Fibers are identical """
    def forward(self, x, res):
        out = {}
        for degree, tensor in x.items():
            key = str(degree)
            # add the residual only for degrees present in `res`
            out[key] = tensor + res[key] if key in res else tensor
        return out
class LinearSE3(nn.Module):
    """Degree-wise linear projection on the channel dimension of each feature type."""
    def __init__(
        self,
        fiber_in,
        fiber_out
    ):
        super().__init__()
        self.weights = nn.ParameterDict()
        # one weight matrix per degree present in both the input and output fibers
        for degree, dim_in, dim_out in (fiber_in & fiber_out):
            self.weights[str(degree)] = nn.Parameter(torch.randn(dim_in, dim_out) / sqrt(dim_in))

    def forward(self, x):
        return {
            degree: einsum('b n d m, d e -> b n e m', x[degree], weight)
            for degree, weight in self.weights.items()
        }
class NormSE3(nn.Module):
    """Norm-based SE(3)-equivariant nonlinearity.

    Nonlinearities are important in SE(3) equivariant GCNs. They are also quite
    expensive to compute, so it is convenient for them to share resources with
    other layers, such as normalization. The general workflow is as follows:

    > for feature type in features:
    >    norm, phase <- feature
    >    output = fnc(norm) * phase

    where fnc: {R+}^m -> R^m is a learnable map from m norms to m scalars.
    """
    def __init__(
        self,
        fiber,
        nonlin = nn.GELU(),
        gated_scale = False,
        eps = 1e-12
    ):
        # bug fix: the source contained two overlapping __init__ signatures (a
        # syntax error); merged here, keeping the first signature's defaults.
        # NOTE: the nn.GELU() default is shared across instances — harmless
        # since GELU is stateless.
        super().__init__()
        self.fiber = fiber
        self.nonlin = nonlin
        self.eps = eps

        # per degree: either a learned static scale over channels, or (gated)
        # a weight matrix mapping norms to scales
        self.transform = nn.ModuleDict()
        for degree, chan in fiber:
            self.transform[str(degree)] = nn.ParameterDict({
                'scale': nn.Parameter(torch.ones(1, 1, chan)) if not gated_scale else None,
                'w_gate': nn.Parameter(rand_uniform((chan, chan), -1e-3, 1e-3)) if gated_scale else None
            })

    def forward(self, features):
        output = {}
        for degree, t in features.items():
            # decompose each feature into (invariant) norm and (equivariant) phase
            norm = t.norm(dim = -1, keepdim = True).clamp(min = self.eps)
            phase = t / norm

            parameters = self.transform[degree]
            gate_weights, scale = parameters['w_gate'], parameters['scale']

            transformed = rearrange(norm, '... () -> ...')

            if not exists(scale):
                # gated variant: derive the scale from the norms themselves
                scale = einsum('b n d, d e -> b n e', transformed, gate_weights)

            transformed = self.nonlin(transformed * scale)
            transformed = rearrange(transformed, '... -> ... ()')

            # only the norm is rescaled; the direction (phase) is untouched,
            # which preserves equivariance
            output[degree] = (transformed * phase).view(*t.shape)

        return output
class ConvSE3(nn.Module):
    """A tensor field network layer.

    ConvSE3 is an SE(3)-equivariant convolution: the analogue of a linear layer
    in an MLP, a conv layer in a CNN, or a graph conv layer in a GCN. At each
    node, activations are split into different "feature types", indexed by the
    SE(3) representation type: non-negative integers 0, 1, 2, ...
    """
    def __init__(
        self,
        fiber_in,
        fiber_out,
        self_interaction = True,
        pool = True,
        edge_dim = 0,
        fourier_encode_dist = False,
        num_fourier_features = 4,
        splits = 4
    ):
        super().__init__()
        self.fiber_in = fiber_in
        self.fiber_out = fiber_out
        self.edge_dim = edge_dim
        self.self_interaction = self_interaction

        self.num_fourier_features = num_fourier_features
        self.fourier_encode_dist = fourier_encode_dist

        # fourier encoding of the distance widens the edge feature dimension
        edge_dim += (0 if not fourier_encode_dist else (num_fourier_features * 2))

        # one pairwise kernel per (degree_in, degree_out) combination
        self.kernel_unary = nn.ModuleDict()

        # chunk along the neighbor dimension to bound peak memory
        self.splits = splits

        for (di, mi), (do, mo) in (self.fiber_in * self.fiber_out):
            self.kernel_unary[f'({di},{do})'] = PairwiseConv(di, mi, do, mo, edge_dim = edge_dim, splits = splits)

        self.pool = pool

        if self_interaction:
            assert self.pool, 'must pool edges if followed with self interaction'
            self.self_interact = LinearSE3(fiber_in, fiber_out)
            self.self_interact_sum = ResidualSE3()

    def forward(
        self,
        inp,
        edge_info,
        rel_dist = None,
        basis = None
    ):
        splits = self.splits
        neighbor_indices, neighbor_masks, edges = edge_info
        rel_dist = rearrange(rel_dist, 'b m n -> b m n ()')

        outputs = {}

        if self.fourier_encode_dist:
            rel_dist = fourier_encode(rel_dist[..., None], num_encodings = self.num_fourier_features)

        # split basis into chunks along the neighbor dimension
        # (bug fix: the original line was missing a closing parenthesis)
        basis_keys = basis.keys()
        split_basis_values = list(zip(*list(map(lambda t: fast_split(t, splits, dim = 1), basis.values()))))
        split_basis = list(map(lambda v: dict(zip(basis_keys, v)), split_basis_values))

        # go through every permutation of input degree type to output degree type
        for degree_out in self.fiber_out.degrees:
            output = 0
            degree_out_key = str(degree_out)

            for degree_in, m_in in self.fiber_in:
                etype = f'({degree_in},{degree_out})'

                x = inp[str(degree_in)]
                x = batched_index_select(x, neighbor_indices, dim = 1)
                x = x.view(*x.shape[:3], to_order(degree_in) * m_in, 1)

                kernel_fn = self.kernel_unary[etype]
                edge_features = torch.cat((rel_dist, edges), dim = -1) if exists(edges) else rel_dist

                output_chunk = None
                split_x = fast_split(x, splits, dim = 1)
                split_edge_features = fast_split(edge_features, splits, dim = 1)

                # process input, edges, and basis in chunks along the neighbor
                # dimension (loop variables renamed so they no longer shadow
                # the `edge_features` / `basis` names above)
                for x_chunk, edge_chunk, basis_chunk in zip(split_x, split_edge_features, split_basis):
                    kernel = kernel_fn(edge_chunk, basis = basis_chunk)
                    chunk = einsum('... o i, ... i c -> ... o c', kernel, x_chunk)
                    output_chunk = safe_cat(output_chunk, chunk, dim = 1)

                output = output + output_chunk

            if self.pool:
                output = masked_mean(output, neighbor_masks, dim = 2) if exists(neighbor_masks) else output.mean(dim = 2)

            leading_shape = x.shape[:2] if self.pool else x.shape[:3]
            output = output.view(*leading_shape, -1, to_order(degree_out))

            outputs[degree_out_key] = output

        if self.self_interaction:
            self_interact_out = self.self_interact(inp)
            outputs = self.self_interact_sum(outputs, self_interact_out)

        return outputs
class RadialFunc(nn.Module):
    """A neural-network-parameterized radial function: maps edge scalars to per-frequency kernel coefficients."""
    def __init__(
        self,
        num_freq,
        in_dim,
        out_dim,
        edge_dim = None,
        mid_dim = 128
    ):
        super().__init__()
        self.num_freq = num_freq
        self.in_dim = in_dim
        self.mid_dim = mid_dim
        self.out_dim = out_dim
        self.edge_dim = edge_dim if edge_dim is not None else 0

        # two hidden layers; input is the distance scalar plus any edge features
        self.net = nn.Sequential(
            nn.Linear(self.edge_dim + 1, mid_dim),
            nn.LayerNorm(mid_dim),
            nn.GELU(),
            nn.Linear(mid_dim, mid_dim),
            nn.LayerNorm(mid_dim),
            nn.GELU(),
            nn.Linear(mid_dim, num_freq * in_dim * out_dim)
        )

    def forward(self, x):
        y = self.net(x)
        # unflatten to (..., out, 1, in, 1, freq); the out-major grouping matches
        # einops '... (o i f) -> ... o () i () f'
        return y.reshape(*y.shape[:-1], self.out_dim, 1, self.in_dim, 1, -1)
class PairwiseConv(nn.Module):
    """SE(3)-equivariant convolution between two single-type features."""
    def __init__(
        self,
        degree_in,
        nc_in,
        degree_out,
        nc_out,
        edge_dim = 0,
        splits = 4
    ):
        super().__init__()
        self.degree_in = degree_in
        self.degree_out = degree_out
        self.nc_in = nc_in
        self.nc_out = nc_out

        self.num_freq = to_order(min(degree_in, degree_out))
        self.d_out = to_order(degree_out)
        self.edge_dim = edge_dim

        # radial network producing one coefficient per frequency per channel pair
        self.rp = RadialFunc(self.num_freq, nc_in, nc_out, edge_dim)

        self.splits = splits

    def forward(self, feat, basis):
        R = self.rp(feat)
        B = basis[f'{self.degree_in},{self.degree_out}']

        out_shape = (*R.shape[:3], self.d_out * self.nc_out, -1)

        # combine radial coefficients with the angular basis, frequency by frequency
        out = 0
        for freq_idx in range(R.shape[-1]):
            out = out + R[..., freq_idx] * B[..., freq_idx]

        # collapse the four leading dims, matching einops 'b n h s ... -> (b n h s) ...'
        out = out.reshape(-1, *out.shape[4:])
        return out.view(*out_shape)
class FeedForwardSE3(nn.Module):
    """Degree-wise feedforward: project up by `mult`, apply the norm nonlinearity, project back."""
    def __init__(
        self,
        fiber,
        mult = 4
    ):
        super().__init__()
        self.fiber = fiber
        fiber_hidden = Fiber([(degree, dim * mult) for degree, dim in fiber])

        self.project_in = LinearSE3(fiber, fiber_hidden)
        self.nonlin = NormSE3(fiber_hidden)
        self.project_out = LinearSE3(fiber_hidden, fiber)

    def forward(self, features):
        hidden = self.project_in(features)
        hidden = self.nonlin(hidden)
        return self.project_out(hidden)
class FeedForwardBlockSE3(nn.Module):
    """Pre-norm feedforward block with a residual connection."""
    def __init__(
        self,
        fiber,
        norm_gated_scale = False
    ):
        super().__init__()
        self.fiber = fiber
        self.prenorm = NormSE3(fiber, gated_scale = norm_gated_scale)
        self.feedforward = FeedForwardSE3(fiber)
        self.residual = ResidualSE3()

    def forward(self, features):
        residual = features
        out = self.feedforward(self.prenorm(features))
        return self.residual(out, residual)
class AttentionSE3(nn.Module):
    # SE(3)-equivariant multi-head attention over per-degree feature dicts.
    # Queries come from a linear projection of node features; keys and values
    # come from edge-wise equivariant convolutions (or linear projections,
    # depending on the flags below). Attention logits sum over both the channel
    # and representation (m) axes, so each (i, j) score is invariant.
    def __init__(
        self,
        fiber,
        dim_head = 64,
        heads = 8,
        attend_self = False,                 # also attend to the node itself via separate k/v projections
        edge_dim = None,                     # dimension of extra edge features concatenated to rel_dist
        fourier_encode_dist = False,
        rel_dist_num_fourier_features = 4,
        use_null_kv = False,                 # learned per-degree null key/value (attention sink)
        splits = 4,                          # chunking factor forwarded to ConvSE3
        global_feats_dim = None,             # if set, degree-0 global tokens contribute keys/values
        linear_proj_keys = False,            # project keys linearly from node features instead of ConvSE3
        tie_key_values = False               # reuse the value projection as the key projection
    ):
        super().__init__()
        hidden_dim = dim_head * heads
        hidden_fiber = Fiber(list(map(lambda t: (t[0], hidden_dim), fiber)))
        # skip the output projection when a single head already matches the fiber dim
        project_out = not (heads == 1 and len(fiber.dims) == 1 and dim_head == fiber.dims[0])

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.linear_proj_keys = linear_proj_keys

        self.to_q = LinearSE3(fiber, hidden_fiber)
        self.to_v = ConvSE3(fiber, hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)

        assert not (linear_proj_keys and tie_key_values), 'you cannot do linear projection of keys and have shared key / values turned on at the same time'

        if linear_proj_keys:
            self.to_k = LinearSE3(fiber, hidden_fiber)
        elif not tie_key_values:
            self.to_k = ConvSE3(fiber, hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)
        else:
            # tie_key_values: keys will simply alias the values in forward()
            self.to_k = None

        self.to_out = LinearSE3(hidden_fiber, fiber) if project_out else nn.Identity()

        self.use_null_kv = use_null_kv
        if use_null_kv:
            # learned null key/value per degree, prepended to every neighbor list
            self.null_keys = nn.ParameterDict()
            self.null_values = nn.ParameterDict()

            for degree in fiber.degrees:
                m = to_order(degree)
                degree_key = str(degree)
                self.null_keys[degree_key] = nn.Parameter(torch.zeros(heads, dim_head, m))
                self.null_values[degree_key] = nn.Parameter(torch.zeros(heads, dim_head, m))

        self.attend_self = attend_self
        if attend_self:
            # separate linear projections for the node attending to itself
            self.to_self_k = LinearSE3(fiber, hidden_fiber)
            self.to_self_v = LinearSE3(fiber, hidden_fiber)

        self.accept_global_feats = exists(global_feats_dim)
        if self.accept_global_feats:
            # global (degree-0 only) tokens provide additional keys/values
            global_input_fiber = Fiber.create(1, global_feats_dim)
            global_output_fiber = Fiber.create(1, hidden_fiber[0])
            self.to_global_k = LinearSE3(global_input_fiber, global_output_fiber)
            self.to_global_v = LinearSE3(global_input_fiber, global_output_fiber)

    def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None, mask = None):
        h, attend_self = self.heads, self.attend_self
        device, dtype = get_tensor_device_and_dtype(features)
        neighbor_indices, neighbor_mask, edges = edge_info

        if exists(neighbor_mask):
            neighbor_mask = rearrange(neighbor_mask, 'b i j -> b () i j')

        queries = self.to_q(features)
        values = self.to_v(features, edge_info, rel_dist, basis)

        if self.linear_proj_keys:
            # keys projected from node features, then gathered per neighbor
            keys = self.to_k(features)
            keys = map_values(lambda val: batched_index_select(val, neighbor_indices, dim = 1), keys)
        elif not exists(self.to_k):
            # tied keys/values
            keys = values
        else:
            keys = self.to_k(features, edge_info, rel_dist, basis)

        if attend_self:
            self_keys, self_values = self.to_self_k(features), self.to_self_v(features)

        if exists(global_feats):
            global_keys, global_values = self.to_global_k(global_feats), self.to_global_v(global_feats)

        outputs = {}
        for degree in features.keys():
            q, k, v = map(lambda t: t[degree], (queries, keys, values))

            q = rearrange(q, 'b i (h d) m -> b h i d m', h = h)
            k, v = map(lambda t: rearrange(t, 'b i j (h d) m -> b h i j d m', h = h), (k, v))

            if attend_self:
                # prepend the node's own key/value to its neighbor list (dim 3 = j)
                self_k, self_v = map(lambda t: t[degree], (self_keys, self_values))
                self_k, self_v = map(lambda t: rearrange(t, 'b n (h d) m -> b h n () d m', h = h), (self_k, self_v))
                k = torch.cat((self_k, k), dim = 3)
                v = torch.cat((self_v, v), dim = 3)

            if exists(pos_emb) and degree == '0':
                # rotary positional embedding, applied only to scalar (degree-0) features
                query_pos_emb, key_pos_emb = pos_emb
                query_pos_emb = rearrange(query_pos_emb, 'b i d -> b () i d ()')
                key_pos_emb = rearrange(key_pos_emb, 'b i j d -> b () i j d ()')
                q = apply_rotary_pos_emb(q, query_pos_emb)
                k = apply_rotary_pos_emb(k, key_pos_emb)
                v = apply_rotary_pos_emb(v, key_pos_emb)

            if self.use_null_kv:
                null_k, null_v = map(lambda t: t[degree], (self.null_keys, self.null_values))
                null_k, null_v = map(lambda t: repeat(t, 'h d m -> b h i () d m', b = q.shape[0], i = q.shape[2]), (null_k, null_v))
                k = torch.cat((null_k, k), dim = 3)
                v = torch.cat((null_v, v), dim = 3)

            if exists(global_feats) and degree == '0':
                global_k, global_v = map(lambda t: t[degree], (global_keys, global_values))
                global_k, global_v = map(lambda t: repeat(t, 'b j (h d) m -> b h i j d m', h = h, i = k.shape[2]), (global_k, global_v))
                k = torch.cat((global_k, k), dim = 3)
                v = torch.cat((global_v, v), dim = 3)

            # similarity sums over channel (d) and representation (m) axes
            sim = einsum('b h i d m, b h i j d m -> b h i j', q, k) * self.scale

            if exists(neighbor_mask):
                # left-pad with True so prepended self/null/global entries stay unmasked
                num_left_pad = sim.shape[-1] - neighbor_mask.shape[-1]
                mask = F.pad(neighbor_mask, (num_left_pad, 0), value = True)
                sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

            attn = sim.softmax(dim = -1)
            out = einsum('b h i j, b h i j d m -> b h i d m', attn, v)
            outputs[degree] = rearrange(out, 'b h n d m -> b n (h d) m')

        return self.to_out(outputs)
class OneHeadedKVAttentionSE3(nn.Module):
    # Variant of AttentionSE3 where keys/values have a single head of width
    # dim_head shared across all query heads, reducing key/value compute
    # (compare the einsums: 'b h i d m, b i j d m -> b h i j').
    def __init__(
        self,
        fiber,
        dim_head = 64,
        heads = 8,
        attend_self = False,
        edge_dim = None,
        fourier_encode_dist = False,
        rel_dist_num_fourier_features = 4,
        use_null_kv = False,
        splits = 4,
        global_feats_dim = None,
        linear_proj_keys = False,
        tie_key_values = False
    ):
        super().__init__()
        hidden_dim = dim_head * heads
        hidden_fiber = Fiber(list(map(lambda t: (t[0], hidden_dim), fiber)))
        # keys/values are projected to a single head of width dim_head
        kv_hidden_fiber = Fiber(list(map(lambda t: (t[0], dim_head), fiber)))
        project_out = not (heads == 1 and len(fiber.dims) == 1 and dim_head == fiber.dims[0])

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.linear_proj_keys = linear_proj_keys

        self.to_q = LinearSE3(fiber, hidden_fiber)
        self.to_v = ConvSE3(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)

        assert not (linear_proj_keys and tie_key_values), 'you cannot do linear projection of keys and have shared key / values turned on at the same time'

        if linear_proj_keys:
            self.to_k = LinearSE3(fiber, kv_hidden_fiber)
        elif not tie_key_values:
            self.to_k = ConvSE3(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)
        else:
            # tie_key_values: keys will alias the values in forward()
            self.to_k = None

        self.to_out = LinearSE3(hidden_fiber, fiber) if project_out else nn.Identity()

        self.use_null_kv = use_null_kv
        if use_null_kv:
            # single-headed null key/value per degree
            self.null_keys = nn.ParameterDict()
            self.null_values = nn.ParameterDict()

            for degree in fiber.degrees:
                m = to_order(degree)
                degree_key = str(degree)
                self.null_keys[degree_key] = nn.Parameter(torch.zeros(dim_head, m))
                self.null_values[degree_key] = nn.Parameter(torch.zeros(dim_head, m))

        self.attend_self = attend_self
        if attend_self:
            self.to_self_k = LinearSE3(fiber, kv_hidden_fiber)
            self.to_self_v = LinearSE3(fiber, kv_hidden_fiber)

        self.accept_global_feats = exists(global_feats_dim)
        if self.accept_global_feats:
            global_input_fiber = Fiber.create(1, global_feats_dim)
            global_output_fiber = Fiber.create(1, kv_hidden_fiber[0])
            self.to_global_k = LinearSE3(global_input_fiber, global_output_fiber)
            self.to_global_v = LinearSE3(global_input_fiber, global_output_fiber)

    def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None, mask = None):
        h, attend_self = self.heads, self.attend_self
        device, dtype = get_tensor_device_and_dtype(features)
        neighbor_indices, neighbor_mask, edges = edge_info

        if exists(neighbor_mask):
            neighbor_mask = rearrange(neighbor_mask, 'b i j -> b () i j')

        queries = self.to_q(features)
        values = self.to_v(features, edge_info, rel_dist, basis)

        if self.linear_proj_keys:
            # keys projected from node features, then gathered per neighbor
            keys = self.to_k(features)
            keys = map_values(lambda val: batched_index_select(val, neighbor_indices, dim = 1), keys)
        elif not exists(self.to_k):
            keys = values
        else:
            keys = self.to_k(features, edge_info, rel_dist, basis)

        if attend_self:
            self_keys, self_values = self.to_self_k(features), self.to_self_v(features)

        if exists(global_feats):
            global_keys, global_values = self.to_global_k(global_feats), self.to_global_v(global_feats)

        outputs = {}
        for degree in features.keys():
            q, k, v = map(lambda t: t[degree], (queries, keys, values))

            # only queries are split into heads; k/v stay single-headed
            q = rearrange(q, 'b i (h d) m -> b h i d m', h = h)

            if attend_self:
                # prepend the node's own key/value to its neighbor list (dim 2 = j)
                self_k, self_v = map(lambda t: t[degree], (self_keys, self_values))
                self_k, self_v = map(lambda t: rearrange(t, 'b n d m -> b n () d m'), (self_k, self_v))
                k = torch.cat((self_k, k), dim = 2)
                v = torch.cat((self_v, v), dim = 2)

            if exists(pos_emb) and degree == '0':
                # rotary positional embedding, applied only to scalar (degree-0) features
                query_pos_emb, key_pos_emb = pos_emb
                query_pos_emb = rearrange(query_pos_emb, 'b i d -> b () i d ()')
                key_pos_emb = rearrange(key_pos_emb, 'b i j d -> b i j d ()')
                q = apply_rotary_pos_emb(q, query_pos_emb)
                k = apply_rotary_pos_emb(k, key_pos_emb)
                v = apply_rotary_pos_emb(v, key_pos_emb)

            if self.use_null_kv:
                null_k, null_v = map(lambda t: t[degree], (self.null_keys, self.null_values))
                null_k, null_v = map(lambda t: repeat(t, 'd m -> b i () d m', b = q.shape[0], i = q.shape[2]), (null_k, null_v))
                k = torch.cat((null_k, k), dim = 2)
                v = torch.cat((null_v, v), dim = 2)

            if exists(global_feats) and degree == '0':
                global_k, global_v = map(lambda t: t[degree], (global_keys, global_values))
                global_k, global_v = map(lambda t: repeat(t, 'b j d m -> b i j d m', i = k.shape[1]), (global_k, global_v))
                k = torch.cat((global_k, k), dim = 2)
                v = torch.cat((global_v, v), dim = 2)

            # every query head attends over the shared single-headed keys
            sim = einsum('b h i d m, b i j d m -> b h i j', q, k) * self.scale

            if exists(neighbor_mask):
                # left-pad with True so prepended self/null/global entries stay unmasked
                num_left_pad = sim.shape[-1] - neighbor_mask.shape[-1]
                mask = F.pad(neighbor_mask, (num_left_pad, 0), value = True)
                sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

            attn = sim.softmax(dim = -1)
            out = einsum('b h i j, b i j d m -> b h i d m', attn, v)
            outputs[degree] = rearrange(out, 'b h n d m -> b n (h d) m')

        return self.to_out(outputs)
class AttentionBlockSE3(nn.Module):
    """Pre-norm SE(3)-equivariant attention block with a residual connection.

    Wraps an attention layer (``attention_klass``) between an equivariant
    pre-normalization (NormSE3) and an equivariant residual add (ResidualSE3).
    """
    def __init__(
        self,
        fiber,
        dim_head = 24,
        heads = 8,
        attend_self = False,
        edge_dim = None,
        use_null_kv = False,
        fourier_encode_dist = False,
        rel_dist_num_fourier_features = 4,
        splits = 4,
        global_feats_dim = False,
        linear_proj_keys = False,
        tie_key_values = False,
        attention_klass = AttentionSE3,
        norm_gated_scale = False
    ):
        super().__init__()
        self.attn = attention_klass(
            fiber,
            heads = heads,
            dim_head = dim_head,
            attend_self = attend_self,
            edge_dim = edge_dim,
            use_null_kv = use_null_kv,
            rel_dist_num_fourier_features = rel_dist_num_fourier_features,
            fourier_encode_dist = fourier_encode_dist,
            splits = splits,
            global_feats_dim = global_feats_dim,
            linear_proj_keys = linear_proj_keys,
            tie_key_values = tie_key_values
        )
        self.prenorm = NormSE3(fiber, gated_scale = norm_gated_scale)
        self.residual = ResidualSE3()

    def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None, mask = None):
        # pre-norm: normalize, attend, then add the skip connection back
        shortcut = features
        normed = self.prenorm(features)
        attended = self.attn(normed, edge_info, rel_dist, basis, global_feats, pos_emb, mask)
        return self.residual(attended, shortcut)
class Swish_(nn.Module):
    """Swish / SiLU activation: x * sigmoid(x). Fallback for torch builds lacking nn.SiLU."""
    def forward(self, x):
        return torch.sigmoid(x) * x
# prefer the native fused implementation when this torch version provides nn.SiLU
SiLU = nn.SiLU if hasattr(nn, 'SiLU') else Swish_
class HtypesNorm(nn.Module):
    """Norm for higher-degree (type > 0) features.

    Normalizes each feature vector along its last (representation) axis, then
    rescales by a learned affine transform of the original vector norm.
    """
    def __init__(self, dim, eps = 1e-8, scale_init = 1e-2, bias_init = 1e-2):
        super().__init__()
        self.eps = eps
        # learned per-channel affine on the norms; shapes broadcast against
        # inputs whose degree axis is second-to-last
        self.scale = nn.Parameter(torch.full((1, 1, 1, dim, 1), scale_init))
        self.bias = nn.Parameter(torch.full((1, 1, 1, dim, 1), bias_init))

    def forward(self, coors):
        magnitudes = coors.norm(dim = -1, keepdim = True)
        unit = coors / magnitudes.clamp(min = self.eps)
        return unit * (magnitudes * self.scale + self.bias)
class EGNN(nn.Module):
    """E(n)-equivariant GNN layer operating on a fiber of SE(3) features.

    Degree-0 (scalar) features are updated with standard message passing,
    while higher-degree ("htype") features are updated equivariantly from
    pairwise relative differences weighted by learned invariant coefficients.

    Bug fix vs. original: the neighbor mask was applied to ``htype_weights``
    *after* ``.split()``, so the already-split weight tensors used in the
    update loop never saw the mask (the masking was a no-op). Masking now
    happens before the split.
    """
    def __init__(
        self,
        fiber,
        hidden_dim = 32,
        edge_dim = 0,
        init_eps = 1e-3,
        coor_weights_clamp_value = None
    ):
        super().__init__()
        self.fiber = fiber
        node_dim = fiber[0]

        htypes = list(filter(lambda t: t.degrees != 0, fiber))
        htype_dims = sum([fiberel.dim for fiberel in htypes])

        # edge MLP consumes both endpoint node features, the invariant norms
        # of all higher-type relative differences, the radial distance, and
        # any explicit edge features
        edge_input_dim = node_dim * 2 + htype_dims + edge_dim + 1

        self.node_norm = nn.LayerNorm(node_dim)

        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_input_dim, edge_input_dim * 2),
            SiLU(),
            nn.Linear(edge_input_dim * 2, hidden_dim),
            SiLU()
        )

        self.htype_norms = nn.ModuleDict({})
        self.htype_gating = nn.ModuleDict({})

        for degree, dim in fiber:
            if degree == 0:
                continue
            self.htype_norms[str(degree)] = HtypesNorm(dim)
            self.htype_gating[str(degree)] = nn.Linear(node_dim, dim)

        self.htypes_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            SiLU(),
            nn.Linear(hidden_dim * 4, htype_dims)
        )

        self.node_mlp = nn.Sequential(
            nn.Linear(node_dim + hidden_dim, node_dim * 2),
            SiLU(),
            nn.Linear(node_dim * 2, node_dim)
        )

        self.coor_weights_clamp_value = coor_weights_clamp_value
        self.init_eps = init_eps
        self.apply(self.init_)

    def init_(self, module):
        # small-variance init keeps the equivariant updates near-identity at start
        if type(module) in {nn.Linear}:
            nn.init.normal_(module.weight, std = self.init_eps)

    def forward(
        self,
        features,
        edge_info,
        rel_dist,
        mask = None,
        **kwargs
    ):
        neighbor_indices, neighbor_masks, edges = edge_info
        mask = neighbor_masks

        # type-0 features, squeezed from (b, n, d, 1) to (b, n, d)
        nodes = features['0']
        nodes = rearrange(nodes, '... () -> ...')

        # higher types (htypes): everything except degree 0
        htypes = list(filter(lambda t: t[0] != '0', features.items()))
        htype_degrees = list(map(lambda t: t[0], htypes))
        htype_dims = list(map(lambda t: t[1].shape[-2], htypes))

        # pairwise relative differences and their rotation-invariant norms
        rel_htypes = []
        rel_htypes_dists = []

        for degree, htype in htypes:
            rel_htype = rearrange(htype, 'b i d m -> b i () d m') - rearrange(htype, 'b j d m -> b () j d m')
            rel_htype_dist = rel_htype.norm(dim = -1)

            rel_htypes.append(rel_htype)
            rel_htypes_dists.append(rel_htype_dist)

        # assemble per-edge MLP inputs
        nodes_i = rearrange(nodes, 'b i d -> b i () d')
        nodes_j = batched_index_select(nodes, neighbor_indices, dim = 1)
        neighbor_higher_type_dists = map(lambda t: batched_index_select(t, neighbor_indices, dim = 2), rel_htypes_dists)
        coor_rel_dist = rearrange(rel_dist, 'b i j -> b i j ()')

        edge_mlp_inputs = broadcat((nodes_i, nodes_j, *neighbor_higher_type_dists, coor_rel_dist), dim = -1)

        if exists(edges):
            edge_mlp_inputs = torch.cat((edge_mlp_inputs, edges), dim = -1)

        # messages per directed edge
        m_ij = self.edge_mlp(edge_mlp_inputs)

        # invariant coefficients weighting the higher-type updates
        htype_weights = self.htypes_mlp(m_ij)

        if exists(self.coor_weights_clamp_value):
            clamp_value = self.coor_weights_clamp_value
            htype_weights.clamp_(min = -clamp_value, max = clamp_value)

        # BUGFIX: zero out padded neighbors BEFORE splitting, so the split
        # views actually carry the mask into the aggregation below
        if exists(mask):
            htype_mask = rearrange(mask, 'b i j -> b i j ()')
            htype_weights = htype_weights.masked_fill(~htype_mask, 0.)

        split_htype_weights = htype_weights.split(htype_dims, dim = -1)

        # weighted aggregation of normalized relative differences per degree
        htype_updates = []

        for degree, rel_htype, htype_weight in zip(htype_degrees, rel_htypes, split_htype_weights):
            normed_rel_htype = self.htype_norms[str(degree)](rel_htype)
            normed_rel_htype = batched_index_select(normed_rel_htype, neighbor_indices, dim = 2)

            htype_update = einsum('b i j d m, b i j d -> b i d m', normed_rel_htype, htype_weight)
            htype_updates.append(htype_update)

        # mask messages before summing over the neighbor axis
        if exists(mask):
            m_ij_mask = rearrange(mask, '... -> ... ()')
            m_ij = m_ij.masked_fill(~m_ij_mask, 0.)

        m_i = m_ij.sum(dim = -2)

        # residual MLP update of the type-0 features
        normed_nodes = self.node_norm(nodes)
        node_mlp_input = torch.cat((normed_nodes, m_i), dim = -1)
        node_out = self.node_mlp(node_mlp_input) + nodes

        features['0'] = rearrange(node_out, '... -> ... ()')

        # residual update, then sigmoid gating of each higher type by node features
        update_htype_dicts = dict(zip(htype_degrees, htype_updates))

        for degree, update_htype in update_htype_dicts.items():
            features[degree] = features[degree] + update_htype

        for degree in htype_degrees:
            gating = self.htype_gating[str(degree)](node_out).sigmoid()
            features[degree] = features[degree] * rearrange(gating, '... -> ... ()')

        return features
class EGnnNetwork(nn.Module):
    """Stack of EGNN layers, each optionally followed by a feedforward block.

    Before running the layers, every node is prepended to its own neighbor
    list (a self loop), with the mask / distance / edge tensors padded to match.
    """
    def __init__(
        self,
        *,
        fiber,
        depth,
        edge_dim = 0,
        hidden_dim = 32,
        coor_weights_clamp_value = None,
        feedforward = False
    ):
        super().__init__()
        self.fiber = fiber

        layers = []
        for _ in range(depth):
            egnn = EGNN(fiber = fiber, edge_dim = edge_dim, hidden_dim = hidden_dim, coor_weights_clamp_value = coor_weights_clamp_value)
            ff = FeedForwardBlockSE3(fiber) if feedforward else None
            layers.append(nn.ModuleList([egnn, ff]))
        self.layers = nn.ModuleList(layers)

    def forward(
        self,
        features,
        edge_info,
        rel_dist,
        basis,
        global_feats = None,
        pos_emb = None,
        mask = None,
        **kwargs
    ):
        neighbor_indices, neighbor_masks, edges = edge_info
        device = neighbor_indices.device

        # self loop: node i becomes its own first neighbor
        num_nodes = neighbor_indices.shape[1]
        self_indices = rearrange(torch.arange(num_nodes, device = device), 'i -> () i ()')
        neighbor_indices = broadcat((self_indices, neighbor_indices), dim = -1)

        # pad the companion tensors for the new leading neighbor slot
        neighbor_masks = F.pad(neighbor_masks, (1, 0), value = True)
        rel_dist = F.pad(rel_dist, (1, 0), value = 0.)

        if exists(edges):
            edges = F.pad(edges, (0, 0, 1, 0), value = 0.)

        edge_info = (neighbor_indices, neighbor_masks, edges)

        for egnn, ff in self.layers:
            features = egnn(
                features,
                edge_info = edge_info,
                rel_dist = rel_dist,
                basis = basis,
                global_feats = global_feats,
                pos_emb = pos_emb,
                mask = mask,
                **kwargs
            )

            if exists(ff):
                features = ff(features)

        return features
class SE3Transformer(nn.Module):
def __init__(
self,
*,
dim,
heads = 8,
dim_head = 24,
depth = 2,
input_degrees = 1,
num_degrees = None,
output_degrees = 1,
valid_radius = 1e5,
reduce_dim_out = False,
num_tokens = None,
num_positions = None,
num_edge_tokens = None,
edge_dim = None,
reversible = False,
attend_self = True,
use_null_kv = False,
differentiable_coors = False,
fourier_encode_dist = False,
rel_dist_num_fourier_features = 4,
num_neighbors = float('inf'),
attend_sparse_neighbors = False,
num_adj_degrees = None,
adj_dim = 0,
max_sparse_neighbors = float('inf'),
dim_in = None,
dim_out = None,
norm_out = False,
num_conv_layers = 0,
causal = False,
splits = 4,
global_feats_dim = None,
linear_proj_keys = False,
one_headed_key_values = False,
tie_key_values = False,
rotary_position = False,
rotary_rel_dist = False,
norm_gated_scale = False,
use_egnn = False,
egnn_hidden_dim = 32,
egnn_weights_clamp_value = None,
egnn_feedforward = False,
hidden_fiber_dict = None,
out_fiber_dict = None
def forward(
self,
feats,
coors,
mask = None,
adj_mat = None,
edges = None,
return_type = None,
return_pooled = False,
neighbor_mask = None,
global_feats = None
# ---- se3_transformer_pytorch/spherical_harmonics.py ----
from math import pi, sqrt
from functools import reduce
from operator import mul
import torch
from functools import lru_cache
from se3_transformer_pytorch.utils import cache
# module-level memo for associated Legendre results, keyed on (l, m)
CACHE = {}

def clear_spherical_harmonics_cache():
    """Empty the spherical-harmonics memo.

    The cache key ignores the tensor argument (see lpmv_cache_key_fn), so
    callers clear it between evaluations on different inputs.
    """
    CACHE.clear()
def lpmv_cache_key_fn(l, m, x):
    """Cache key for lpmv: degree and order only — the tensor argument x is ignored."""
    return (l, m)
@lru_cache(maxsize = 1000)
def semifactorial(x):
    """Return the semifactorial x!! = x * (x-2) * (x-4) * ... as a float (1.0 for x <= 1)."""
    result = 1.
    for factor in range(x, 1, -2):
        result *= factor
    return result
@lru_cache(maxsize = 1000)
def pochhammer(x, k):
    """Return the rising factorial x * (x+1) * ... * (x+k-1) as a float.

    Note: for k == 0 this returns float(x), matching the original reduce-based
    implementation (the empty product keeps the initial value x).
    """
    total = float(x)
    for factor in range(x + 1, x + k):
        total *= factor
    return total
def negative_lpmv(l, m, y):
    """Convert P_l^{|m|} into P_l^{m} for negative orders m; non-negative orders pass through.

    Applies the (-1)^m phase and a normalization built from the rising
    factorial (pochhammer). Mutates tensor inputs in place via ``*=``.
    """
    if m >= 0:
        return y
    y *= ((-1) ** m / pochhammer(l + m + 1, -2 * m))
    return y
@cache(cache=CACHE, key_fn=lpmv_cache_key_fn)
def lpmv(l, m, x):
    """Associated Legendre function including Condon-Shortley phase.
    Computed via the standard degree recurrence. Results are memoized in the
    module-level CACHE keyed on (l, m) only — the tensor argument x is NOT
    part of the key, so the cache must be cleared (see
    clear_spherical_harmonics_cache) before calling with a different x.
    Args:
        l: int degree
        m: int order
        x: float argument tensor
    Returns:
        tensor of x-shape, or None when |m| > l (no such term exists)
    """
    m_abs = abs(m)
    # out-of-range order: no valid associated Legendre term
    if m_abs > l:
        return None
    if l == 0:
        return torch.ones_like(x)
    # closed form for the diagonal term P_l^l
    if m_abs == l:
        y = (-1)**m_abs * semifactorial(2*m_abs-1)
        y *= torch.pow(1-x*x, m_abs/2)
        return negative_lpmv(l, m, y)
    # called only for its side effect: the recursion populates CACHE for the
    # lower degrees, which the recurrence below reads back at (l-2, m_abs)
    lpmv(l-1, m, x)
    # two-term recurrence in l at fixed order m_abs
    y = ((2*l-1) / (l-m_abs)) * x * lpmv(l-1, m_abs, x)
    # second recurrence term only exists once l exceeds m_abs + 1
    if l - m_abs > 1:
        y -= ((l+m_abs-1)/(l-m_abs)) * CACHE[(l-2, m_abs)]
    if m < 0:
        y = negative_lpmv(l, m, y)
    return y
def get_spherical_harmonics_element(l, m, theta, phi):
    """Real (tesseral) spherical harmonic Y_lm with Condon-Shortley phase.

    Args:
        l: int for degree
        m: int for order, where -l <= m < l
        theta: collatitude or polar angle
        phi: longitude or azimuth
    Returns:
        tensor of shape theta
    """
    m_abs = abs(m)
    assert m_abs <= l, "absolute value of order m must be <= degree l"

    norm_const = sqrt((2 * l + 1) / (4 * pi))
    legendre = lpmv(l, m_abs, torch.cos(theta))

    # zonal harmonics (m == 0) carry no azimuthal dependence
    if m == 0:
        return norm_const * legendre

    # cosine branch for positive orders, sine branch for negative orders
    if m > 0:
        azimuth = torch.cos(m * phi)
    else:
        azimuth = torch.sin(m_abs * phi)

    azimuth = azimuth * legendre
    norm_const *= sqrt(2. / pochhammer(l - m_abs + 1, 2 * m_abs))
    return azimuth * norm_const
def get_spherical_harmonics(l, theta, phi):
    """Real (tesseral) spherical harmonics for every order of degree l.

    Args:
        l: int for degree
        theta: collatitude or polar angle
        phi: longitude or azimuth
    Returns:
        tensor of shape [*theta.shape, 2*l+1], orders stacked from -l to l
    """
    components = [
        get_spherical_harmonics_element(l, m, theta, phi)
        for m in range(-l, l + 1)
    ]
    return torch.stack(components, dim = -1)
# ---- se3_transformer_pytorch/utils.py ----
import os
import sys
import time
import pickle
import gzip
import torch
import contextlib
from functools import wraps, lru_cache
from filelock import FileLock
from einops import rearrange
def exists(val):
    """Return True when val is not None (zero, empty string etc. still count as existing)."""
    return val is not None
def default(val, d):
    """Return val unless it is None, in which case return the fallback d."""
    return val if val is not None else d
def uniq(arr):
    """Return the elements of arr with duplicates dropped, preserving first-seen order."""
    return list(dict.fromkeys(arr))
def to_order(degree):
    """Number of components (2l + 1) in the degree-l irreducible representation."""
    return degree * 2 + 1
def map_values(fn, d):
    """Return a new dict with fn applied to every value of d (keys unchanged)."""
    return {key: fn(value) for key, value in d.items()}
def safe_cat(arr, el, dim):
    """Concatenate el onto arr along dim; when arr is None, return el unchanged."""
    if arr is None:
        return el
    return torch.cat((arr, el), dim = dim)
def cast_tuple(val, depth):
    """Return val as-is when already a tuple, otherwise repeat it depth times."""
    if isinstance(val, tuple):
        return val
    return (val,) * depth
def broadcat(tensors, dim=-1):
    """Concatenate tensors along dim, broadcasting all other axes to their max size.

    Every non-concat axis may hold at most two distinct sizes across the
    tensors (typically 1 and n); each tensor is expanded to the max before
    torch.cat.
    """
    num_tensors = len(tensors)
    ndims = set(t.dim() for t in tensors)
    assert len(ndims) == 1, 'tensors must all have the same number of dimensions'
    ndim = ndims.pop()
    if dim < 0:
        dim += ndim

    # per-axis tuples of the sizes each tensor has on that axis
    axis_sizes = list(zip(*(tuple(t.shape) for t in tensors)))
    expandable = [(i, sizes) for i, sizes in enumerate(axis_sizes) if i != dim]
    assert all(len(set(sizes)) <= 2 for _, sizes in expandable), 'invalid dimensions for broadcastable concatentation'

    # broadcast every non-concat axis up to its maximum size
    target = [(i, (max(sizes),) * num_tensors) for i, sizes in expandable]
    target.insert(dim, (dim, axis_sizes[dim]))
    per_tensor_shapes = zip(*(sizes for _, sizes in target))

    expanded = [t.expand(*shape) for t, shape in zip(tensors, per_tensor_shapes)]
    return torch.cat(expanded, dim=dim)
def batched_index_select(values, indices, dim=1):
    """Gather entries of `values` along axis `dim` using per-batch `indices`.

    `indices` shares the leading axes of `values` up to `dim`; any trailing
    axes of `values` after `dim` are carried through unchanged, so the result
    has the shape of `indices` with those trailing axes appended.

    Args:
        values: tensor with the axis to index at position `dim`
        indices: integer tensor whose leading axes match `values`
        dim: axis of `values` to index into
    Returns:
        gathered tensor of shape (*indices.shape, *values.shape[dim+1:])
    """
    value_dims = values.shape[(dim + 1):]
    values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))

    # append singleton axes so indices broadcast over the trailing value dims
    # (fixes a missing closing parenthesis in the original line)
    indices = indices[(..., *((None,) * len(value_dims)))]
    indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)

    # insert singleton axes into values so it lines up with the expanded indices
    value_expand_len = len(indices_shape) - (dim + 1)
    values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]

    value_expand_shape = [-1] * len(values.shape)
    expand_slice = slice(dim, (dim + value_expand_len))
    value_expand_shape[expand_slice] = indices.shape[expand_slice]
    values = values.expand(*value_expand_shape)

    # the gather axis shifts right by however many axes were inserted
    dim += value_expand_len
    return values.gather(dim, indices)
def masked_mean(tensor, mask, dim=-1):
    """Mean of `tensor` over `dim`, counting only positions where mask is True.

    Rows whose mask is entirely False yield 0. NOTE: mutates `tensor` in place
    (masked positions are zeroed), matching the original implementation.
    """
    pad = len(tensor.shape) - len(mask.shape)
    mask = mask[(..., *((None,) * pad))]
    tensor.masked_fill_(~mask, 0.)

    count = mask.sum(dim=dim)
    avg = tensor.sum(dim=dim) / count.clamp(min=1.)
    avg.masked_fill_(count == 0, 0.)
    return avg
def rand_uniform(size, min_val, max_val):
    """Tensor of shape `size` filled with uniform samples from [min_val, max_val)."""
    out = torch.empty(size)
    return out.uniform_(min_val, max_val)
def fast_split(arr, splits, dim=0):
    """Yield `splits` near-equal chunks of `arr` along `dim` (views via narrow).

    The remainder is distributed one element at a time to the earliest chunks,
    so chunk sizes differ by at most one.
    """
    total = arr.shape[dim]
    splits = min(total, max(splits, 1))
    base, extra = divmod(total, splits)
    start = 0
    for i in range(splits):
        size = base + (1 if i < extra else 0)
        yield torch.narrow(arr, dim, start, size)
        start += size
def fourier_encode(x, num_encodings=4, include_self=True, flatten=True):
    """Fourier-feature encode x with sin/cos at dyadically scaled frequencies.

    Each scalar becomes 2*num_encodings features (sin then cos of x / 2^k),
    plus the raw value when include_self is True. With flatten, the trailing
    axes are merged into one feature axis (assumes a leading b m n layout —
    TODO confirm against callers).
    """
    orig = x = x.unsqueeze(-1)
    divisors = 2 ** torch.arange(num_encodings, device=x.device, dtype=x.dtype)
    scaled = x / divisors

    encoded = torch.cat((scaled.sin(), scaled.cos()), dim=-1)
    if include_self:
        encoded = torch.cat((encoded, orig), dim=-1)
    if flatten:
        encoded = rearrange(encoded, 'b m n ... -> b m n (...)')
    return encoded
@contextlib.contextmanager
def torch_default_dtype(dtype):
    """Context manager that temporarily sets torch's default dtype.

    Fix: the previous dtype is now restored in a ``finally`` block, so an
    exception raised inside the ``with`` body no longer leaks the temporary
    default dtype to the rest of the process.
    """
    prev_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(prev_dtype)
def cast_torch_tensor(fn):
    """Decorator: coerce the single argument to a torch tensor before calling fn."""
    @wraps(fn)
    def wrapper(value):
        tensor = value if torch.is_tensor(value) else torch.tensor(value, dtype=torch.get_default_dtype())
        return fn(tensor)
    return wrapper
def benchmark(fn):
    """Decorator: make fn return (elapsed_seconds, result) instead of just result.

    Fix: added @wraps so the wrapped function keeps fn's name/docstring,
    consistent with cast_torch_tensor in this module. Wall-clock time.time()
    is kept to preserve the original timing semantics.
    """
    @wraps(fn)
    def inner(*args, **kwargs):
        start = time.time()
        res = fn(*args, **kwargs)
        diff = time.time() - start
        return diff, res
    return inner
def cache(cache, key_fn):
    """Decorator factory memoizing results in the supplied mapping.

    `key_fn` receives the call's args/kwargs and produces the cache key.
    Parameter names are part of the public interface (callers pass them as
    keywords), so they are kept unchanged.
    """
    def cache_inner(fn):
        @wraps(fn)
        def inner(*args, **kwargs):
            key = key_fn(*args, **kwargs)
            try:
                return cache[key]
            except KeyError:
                pass
            value = fn(*args, **kwargs)
            cache[key] = value
            return value
        return inner
    return cache_inner
def cache_dir(dirname, maxsize=128):
    '''
    Cache a function with a directory.
    Results are pickled (gzip-compressed) under `dirname`, indexed by an
    index.pkl file mapping call keys to filenames; an in-memory lru_cache
    sits in front of the on-disk store. A FileLock guards cross-process
    access to the index and result files.
    :param dirname: the directory path (None disables disk caching entirely)
    :param maxsize: maximum size of the RAM cache (there is no limit for the directory cache)
    '''
    def decorator(func):
        # lru_cache wraps the disk lookup, so repeated in-process calls skip I/O
        @lru_cache(maxsize=maxsize)
        @wraps(func)
        def wrapper(*args, **kwargs):
            # no directory configured: plain call, no persistence
            if not exists(dirname):
                return func(*args, **kwargs)
            os.makedirs(dirname, exist_ok=True)
            indexfile = os.path.join(dirname, "index.pkl")
            # cross-process mutex guarding index reads/writes
            lock = FileLock(os.path.join(dirname, "mutex"))
            with lock:
                index = {}
                if os.path.exists(indexfile):
                    with open(indexfile, "rb") as file:
                        index = pickle.load(file)
                # NOTE(review): frozenset(kwargs) keys on kwarg NAMES only —
                # calls differing solely in a keyword value would collide;
                # confirm callers never rely on distinct kwarg values
                key = (args, frozenset(kwargs), func.__defaults__)
                if key in index:
                    filename = index[key]
                else:
                    # new key: allocate the next sequential result file and persist the index
                    index[key] = filename = f"{len(index)}.pkl.gz"
                    with open(indexfile, "wb") as file:
                        pickle.dump(index, file)
            filepath = os.path.join(dirname, filename)
            # disk hit: load and return the stored result
            if os.path.exists(filepath):
                with lock:
                    with gzip.open(filepath, "rb") as file:
                        result = pickle.load(file)
                return result
            print(f"compute (unknown)... ", end="", flush=True)
            result = func(*args, **kwargs)
            print(f"save (unknown)... ", end="", flush=True)
            # persist the freshly computed result for future processes
            with lock:
                with gzip.open(filepath, "wb") as file:
                    pickle.dump(result, file)
            print("done")
            return result
        return wrapper
    return decorator