Lucidrains 系列项目源码解析(二)
.\lucidrains\alphafold2\alphafold2_pytorch\constants.py
import torch
# Model-wide size constants. The meanings are only what the names suggest;
# confirm against the call sites before relying on them.
MAX_NUM_MSA = 20
MAX_NUM_TEMPLATES = 10
NUM_AMINO_ACIDS = 21
NUM_EMBEDDS_TR = 1280
NUM_EMBEDDS_T5 = 1024
NUM_COORDS_PER_RES = 14
# Bucket counts for the discretised outputs.
DISTOGRAM_BUCKETS = 37
THETA_BUCKETS = 25
PHI_BUCKETS = 13
OMEGA_BUCKETS = 25
# Pretrained-embedding settings: every *_MODEL_PATH is a two-item list and every
# *_EMBED_DIM is an integer width paired with it by name.
MSA_EMBED_DIM = 768
MSA_MODEL_PATH = ["facebookresearch/esm", "esm_msa1_t12_100M_UR50S"]
ESM_EMBED_DIM = 1280
ESM_MODEL_PATH = ["facebookresearch/esm", "esm1b_t33_650M_UR50S"]
PROTTRAN_EMBED_DIM = 1024
# Default device: 'cuda' when torch reports CUDA as available, otherwise 'cpu'.
DEVICE_NAME = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE = torch.device(DEVICE_NAME)
# Per-residue bond table keyed by one-letter amino-acid code ('_' has no bonds).
# Each bond is a pair of integer atom indices; every entry starts with the same
# [0,1], [1,2], [2,3] triple.
# presumably the indices follow the per-residue atom order used elsewhere in
# the project (N, CA, C, O, then side chain) — TODO confirm.
AA_DATA = {
    'A': {
        'bonds': [[0,1], [1,2], [2,3], [1,4]]
    },
    # NOTE(review): every other residue with a side chain continues with [1,4];
    # the [2,4] below may be a typo — confirm.
    'R': {
        'bonds': [[0,1], [1,2], [2,3], [2,4], [4,5], [5,6],
                  [6,7], [7,8], [8,9], [8,10]]
    },
    'N': {
        'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
                  [5,7]]
    },
    'D': {
        'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
                  [5,7]]
    },
    'C': {
        'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5]]
    },
    'Q': {
        'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
                  [6,7], [6,8]]
    },
    # NOTE(review): 'Q' ends with [6,7], [6,8] while 'E' ends with [6,7], [7,8],
    # which makes 'E' identical to 'K' — verify this is intended.
    'E': {
        'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
                  [6,7], [7,8]]
    },
    'G': {
        'bonds': [[0,1], [1,2], [2,3]]
    },
    'H': {
        'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
                  [6,7], [7,8], [8,9], [5,9]]
    },
    'I': {
        'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
                  [4,7]]
    },
    'L': {
        'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
                  [5,7]]
    },
    'K': {
        'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
                  [6,7], [7,8]]
    },
    'M': {
        'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
                  [6,7]]
    },
    'F': {
        'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
                  [6,7], [7,8], [8,9], [9,10], [5,10]]
    },
    'P': {
        'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
                  [0,6]]
    },
    'S': {
        'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5]]
    },
    'T': {
        'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [4,6]]
    },
    'W': {
        'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
                  [6,7], [7,8], [8,9], [9,10], [10,11], [11,12],
                  [12, 13], [5,13], [8,13]]
    },
    'Y': {
        'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
                  [6,7], [7,8], [8,9], [8,10], [10,11], [5,11]]
    },
    'V': {
        'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [4,6]]
    },
    '_': {
        'bonds': []
    }
}
.\lucidrains\alphafold2\alphafold2_pytorch\embeds.py
import torch
import torch.nn.functional as F
from torch import nn
from alphafold2_pytorch.utils import get_msa_embedd, get_esm_embedd, get_prottran_embedd, exists
from alphafold2_pytorch.constants import MSA_MODEL_PATH, MSA_EMBED_DIM, ESM_MODEL_PATH, ESM_EMBED_DIM, PROTTRAN_EMBED_DIM
from einops import rearrange
class ProtTranEmbedWrapper(nn.Module):
    """Run a wrapped alphafold2 model on embeddings from 'Rostlab/prot_bert'.

    The sequence and every MSA row are embedded with the pretrained model,
    projected from ``PROTTRAN_EMBED_DIM`` to ``alphafold2.dim`` and passed to
    the wrapped model as ``seq_embed`` / ``msa_embed``.
    """

    def __init__(self, *, alphafold2):
        super().__init__()
        # Imported lazily so `transformers` is only needed when this wrapper is used.
        from transformers import AutoTokenizer, AutoModel

        self.alphafold2 = alphafold2
        self.project_embed = nn.Linear(PROTTRAN_EMBED_DIM, alphafold2.dim)
        self.tokenizer = AutoTokenizer.from_pretrained('Rostlab/prot_bert', do_lower_case=False)
        self.model = AutoModel.from_pretrained('Rostlab/prot_bert')

    def forward(self, seq, msa, msa_mask = None, **kwargs):
        """Embed ``seq`` and ``msa``, then call the wrapped model.

        ``msa`` is indexed as (batch, rows, length); extra keyword arguments
        are forwarded untouched.
        """
        target_device = seq.device
        rows_per_sample = msa.shape[1]

        # Fold the MSA rows into the batch axis so one call embeds all of them.
        folded_msa = rearrange(msa, 'b m n -> (b m) n')

        raw_seq = get_prottran_embedd(seq, self.model, self.tokenizer, device = target_device)
        raw_msa = get_prottran_embedd(folded_msa, self.model, self.tokenizer, device = target_device)

        seq_embed = self.project_embed(raw_seq)
        msa_embed = self.project_embed(raw_msa)

        # Restore the (batch, rows, ...) layout.
        msa_embed = rearrange(msa_embed, '(b m) n d -> b m n d', m = rows_per_sample)

        return self.alphafold2(seq, msa, seq_embed = seq_embed, msa_embed = msa_embed, msa_mask = msa_mask, **kwargs)
class MSAEmbedWrapper(nn.Module):
    """Run a wrapped alphafold2 model on embeddings from the hub model at ``MSA_MODEL_PATH``.

    The primary sequence is stacked on top of the MSA, the stack is embedded
    with ``get_msa_embedd``, projected to ``alphafold2.dim`` when the widths
    differ, and split back into ``seq_embed`` (row 0) and ``msa_embed`` (the rest).
    """

    def __init__(self, *, alphafold2):
        super().__init__()
        self.alphafold2 = alphafold2
        # Load the pretrained model and its alphabet through torch.hub.
        model, alphabet = torch.hub.load(*MSA_MODEL_PATH)
        batch_converter = alphabet.get_batch_converter()
        self.model = model
        self.batch_converter = batch_converter
        # Skip the projection entirely when the widths already agree.
        self.project_embed = nn.Linear(MSA_EMBED_DIM, alphafold2.dim) if MSA_EMBED_DIM != alphafold2.dim else nn.Identity()

    def forward(self, seq, msa, msa_mask = None, **kwargs):
        """Embed sequence + MSA and call the wrapped model.

        Raises:
            AssertionError: when ``seq`` and ``msa`` differ in their last dimension.
        """
        assert seq.shape[-1] == msa.shape[-1], 'sequence and msa must have the same length if you wish to use MSA transformer embeddings'
        model, batch_converter, device = self.model, self.batch_converter, seq.device
        # Prepend the primary sequence as row 0 of the alignment.
        seq_and_msa = torch.cat((seq.unsqueeze(1), msa), dim = 1)
        if exists(msa_mask):
            # Per batch element: number of MSA rows with at least one unmasked position.
            num_msa = msa_mask.any(dim = -1).sum(dim = -1).tolist()
            seq_and_msa_list = seq_and_msa.unbind(dim = 0)
            num_rows = seq_and_msa.shape[1]
            embeds = []
            # Each batch element is embedded separately because its row count can differ.
            for num, batch_el in zip(num_msa, seq_and_msa_list):
                batch_el = rearrange(batch_el, '... -> () ...')
                # NOTE(review): `num` counts MSA rows only, yet batch_el also holds the
                # primary sequence as row 0 — confirm `:num` (not `:num + 1`) is intended.
                batch_el = batch_el[:, :num]
                embed = get_msa_embedd(batch_el, model, batch_converter, device = device)
                # Zero-pad the third-from-last axis back up to num_rows so results can be concatenated.
                embed = F.pad(embed, (0, 0, 0, 0, 0, num_rows - num), value = 0.)
                embeds.append(embed)
            embeds = torch.cat(embeds, dim = 0)
        else:
            embeds = get_msa_embedd(seq_and_msa, model, batch_converter, device = device)
        embeds = self.project_embed(embeds)
        # Row 0 is the primary sequence; the remaining rows are the MSA.
        seq_embed, msa_embed = embeds[:, 0], embeds[:, 1:]
        return self.alphafold2(seq, msa, seq_embed = seq_embed, msa_embed = msa_embed, msa_mask = msa_mask, **kwargs)
class ESMEmbedWrapper(nn.Module):
    """Run a wrapped alphafold2 model on embeddings from the hub model at ``ESM_MODEL_PATH``.

    The sequence (and each MSA row, when an MSA is given) is embedded with
    ``get_esm_embedd``, projected to ``alphafold2.dim`` when the widths
    differ, and passed on as ``seq_embed`` / ``msa_embed``.
    """

    def __init__(self, *, alphafold2):
        super().__init__()
        self.alphafold2 = alphafold2
        # Load the pretrained model and its alphabet through torch.hub.
        model, alphabet = torch.hub.load(*ESM_MODEL_PATH)
        batch_converter = alphabet.get_batch_converter()
        self.model = model
        self.batch_converter = batch_converter
        # Skip the projection entirely when the widths already agree.
        self.project_embed = nn.Linear(ESM_EMBED_DIM, alphafold2.dim) if ESM_EMBED_DIM != alphafold2.dim else nn.Identity()

    def forward(self, seq, msa=None, **kwargs):
        """Embed ``seq`` (and ``msa`` if given) and call the wrapped model.

        ``msa`` is indexed as (batch, rows, length). When it is None,
        ``msa_embed=None`` is forwarded.
        """
        model, batch_converter, device = self.model, self.batch_converter, seq.device

        # NOTE(review): confirm get_esm_embedd accepts a `device` keyword and returns
        # (batch, length, dim) as the rearrange pattern below assumes.
        seq_embeds = get_esm_embedd(seq, model, batch_converter, device = device)
        seq_embeds = self.project_embed(seq_embeds)

        if msa is not None:
            num_msa = msa.shape[1]
            # Fold the MSA rows into the batch axis so one call embeds all of them.
            flat_msa = rearrange(msa, 'b m n -> (b m) n')
            msa_embeds = get_esm_embedd(flat_msa, model, batch_converter, device = device)
            # The composite axis can only be split when one factor is given explicitly.
            msa_embeds = rearrange(msa_embeds, '(b m) n d -> b m n d', m = num_msa)
            msa_embeds = self.project_embed(msa_embeds)
        else:
            msa_embeds = None

        return self.alphafold2(seq, msa, seq_embed = seq_embeds, msa_embed = msa_embeds, **kwargs)
.\lucidrains\alphafold2\alphafold2_pytorch\mlm.py
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from alphafold2_pytorch import constants
from einops import rearrange
def get_mask_subset_with_prob(mask, prob):
    """Randomly pick a subset of the True positions of a 2-D boolean ``mask``.

    Each row gets ``ceil(prob * number_of_True_in_row)`` positions, chosen
    uniformly at random among that row's True entries.

    Returns a boolean tensor with the same shape as ``mask``.
    """
    n_rows, n_cols = mask.shape
    dev = mask.device

    # Upper bound on picks per row if every position were eligible.
    budget = math.ceil(prob * n_cols)

    # Per-row quota; budget slots beyond it are flagged so they can be discarded.
    per_row_quota = (mask.sum(dim=-1, keepdim=True) * prob).ceil()
    over_quota = (mask.cumsum(dim=-1) > per_row_quota)[:, :budget]

    # Random scores, with ineligible positions pushed to the bottom of the ranking.
    scores = torch.rand((n_rows, n_cols), device=dev).masked_fill(~mask, -1e9)
    picked = scores.topk(budget, dim=-1).indices

    # Shift indices by one so that slot 0 can absorb the discarded picks.
    picked = (picked + 1).masked_fill_(over_quota, 0)

    out = torch.zeros((n_rows, n_cols + 1), device=dev)
    out.scatter_(-1, picked, 1)
    return out[:, 1:].bool()
class MLM(nn.Module):
    """Masked-language-model head: noises token sequences and scores predictions.

    Args:
        dim: width of the embeddings handed to ``forward``.
        num_tokens: size of the output vocabulary.
        mask_id: token id written into masked positions.
        mask_prob: fraction of eligible tokens selected for prediction.
        random_replace_token_prob: scales how many selected tokens get a random token.
        keep_token_same_prob: fraction of selected tokens left unmasked.
        exclude_token_ids: token ids that are never selected or used as random replacements.
    """

    def __init__(
        self,
        dim,
        num_tokens,
        mask_id,
        mask_prob = 0.15,
        random_replace_token_prob = 0.1,
        keep_token_same_prob = 0.1,
        exclude_token_ids = (0,)
    ):
        super().__init__()
        self.to_logits = nn.Linear(dim, num_tokens)
        self.mask_id = mask_id
        self.mask_prob = mask_prob
        self.exclude_token_ids = exclude_token_ids
        self.keep_token_same_prob = keep_token_same_prob
        self.random_replace_token_prob = random_replace_token_prob

    def noise(self, seq, mask):
        """Return ``(noised_seq, mlm_mask)`` for a batch of token rows.

        ``seq`` and ``mask`` share their first two axes, which are flattened
        while noising and restored afterwards. ``mlm_mask`` marks the positions
        selected for prediction.
        """
        num_msa = seq.shape[1]
        seq = rearrange(seq, 'b n ... -> (b n) ...')
        mask = rearrange(mask, 'b n ... -> (b n) ...')

        # Positions eligible for selection: inside the mask and not an excluded token.
        excluded_tokens_mask = mask
        for token_id in self.exclude_token_ids:
            excluded_tokens_mask = excluded_tokens_mask & (seq != token_id)

        mlm_mask = get_mask_subset_with_prob(excluded_tokens_mask, self.mask_prob)

        # Of the selected positions, only this subset is overwritten with the mask
        # token; the remainder keep their original token.
        replace_token_with_mask = get_mask_subset_with_prob(mlm_mask, 1. - self.keep_token_same_prob)

        # Fix: the original filled every selected position (mlm_mask), which left
        # replace_token_with_mask unused and keep_token_same_prob without effect.
        seq = seq.masked_fill(replace_token_with_mask, self.mask_id)

        # A further subset of the selected positions receives a random token instead.
        random_replace_token_prob_mask = get_mask_subset_with_prob(mlm_mask, (1 - self.keep_token_same_prob) * self.random_replace_token_prob)
        random_tokens = torch.randint(1, constants.NUM_AMINO_ACIDS, seq.shape).to(seq.device)

        # Never substitute an excluded token.
        for token_id in self.exclude_token_ids:
            random_replace_token_prob_mask = random_replace_token_prob_mask & (random_tokens != token_id)

        noised_seq = torch.where(random_replace_token_prob_mask, random_tokens, seq)

        noised_seq = rearrange(noised_seq, '(b n) ... -> b n ...', n = num_msa)
        mlm_mask = rearrange(mlm_mask, '(b n) ... -> b n ...', n = num_msa)
        return noised_seq, mlm_mask

    def forward(self, seq_embed, original_seq, mask):
        """Mean cross-entropy between predicted and original tokens at ``mask`` positions."""
        logits = self.to_logits(seq_embed)
        seq_logits = logits[mask]
        seq_labels = original_seq[mask]
        loss = F.cross_entropy(seq_logits, seq_labels, reduction = 'mean')
        return loss
.\lucidrains\alphafold2\alphafold2_pytorch\reversible.py
import torch
import torch.nn as nn
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states
from contextlib import contextmanager
from einops import reduce
def exists(val):
    """Return True for any value other than None."""
    return not (val is None)
@contextmanager
def null_context():
    """Context manager that does nothing on entry or exit and yields None."""
    yield None
def split_at_index(dim, index, t):
    """Split ``t`` in two along axis ``dim``: ``[:index]`` and ``[index:]``."""
    # Full slices for every axis that precedes the one being split.
    leading = [slice(None)] * dim
    head = tuple(leading + [slice(None, index)])
    tail = tuple(leading + [slice(index, None)])
    return t[head], t[tail]
class Deterministic(nn.Module):
    """Wrap ``net`` so a call can later be replayed under the same RNG state.

    Calling with ``record_rng=True`` snapshots the RNG state before running
    ``net``; calling with ``set_rng=True`` restores that snapshot inside a
    forked RNG scope before running ``net`` again.
    """

    def __init__(self, net):
        super().__init__()
        self.net = net
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        """Snapshot the CPU RNG state and, once CUDA is initialised, the device states for ``args``."""
        self.cpu_state = torch.get_rng_state()
        if not torch.cuda._initialized:
            return
        self.cuda_in_fwd = True
        self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
        """Run ``net``; optionally record the RNG state first or replay a recorded one."""
        if record_rng:
            self.record_rng(*args)

        if set_rng:
            replay_cuda = self.cuda_in_fwd
            forked_devices = self.gpu_devices if replay_cuda else []
            # fork_rng confines the restored state to this block.
            with torch.random.fork_rng(devices=forked_devices, enabled=True):
                torch.set_rng_state(self.cpu_state)
                if replay_cuda:
                    set_device_states(self.gpu_devices, self.gpu_states)
                return self.net(*args, **kwargs)

        return self.net(*args, **kwargs)
class ReversibleSelfAttnBlock(nn.Module):
    """Reversible residual block over two streams, ``x`` and ``m``.

    Each stream is halved along dim 2. ``f``/``g`` update the two halves of
    ``x`` and ``j``/``k`` update the two halves of ``m``. All four
    sub-modules are wrapped in ``Deterministic`` so ``backward_pass`` can
    replay them under the recorded RNG state and rebuild the inputs from the
    outputs.
    """

    def __init__(self, f, g, j, k):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)
        self.j = Deterministic(j)
        self.k = Deterministic(k)

    def forward(self, x, m, mask = None, msa_mask = None, seq_shape = None, msa_shape = None, seq_pos_emb = None, msa_pos_emb = None, _reverse = True, **kwargs):
        """Compute the block outputs ``(y, n)`` for inputs ``(x, m)``.

        With ``_reverse=True`` the computation runs under ``torch.no_grad`` and,
        in training mode, the RNG state is recorded for the later replay.
        """
        x1, x2 = torch.chunk(x, 2, dim = 2)
        m1, m2 = torch.chunk(m, 2, dim = 2)
        y1, y2, n1, n2 = None, None, None, None

        # No autograd graph is kept in reversible mode.
        context = torch.no_grad if _reverse else null_context
        record_rng = self.training and _reverse

        with context():
            y1 = x1 + self.f(x2, shape = seq_shape, record_rng = record_rng, mask = mask, rotary_emb = seq_pos_emb)
            y2 = x2 + self.g(y1, shape = seq_shape, record_rng = record_rng)
            n1 = m1 + self.j(m2, shape = msa_shape, record_rng = record_rng, mask = msa_mask, rotary_emb = msa_pos_emb)
            n2 = m2 + self.k(n1, record_rng = record_rng)

        return torch.cat((y1, y2), dim = 2), torch.cat((n1, n2), dim = 2)

    def backward_pass(self, y, n, dy, dn, mask = None, msa_mask = None, seq_shape = None, msa_shape = None, seq_pos_emb = None, msa_pos_emb = None, **kwargs):
        """Rebuild the inputs ``(x, m)`` from the outputs and propagate gradients.

        Args:
            y, n: outputs of ``forward``.
            dy, dn: gradients with respect to ``y`` and ``n``.

        Returns:
            ``(x, m, dx, dm)``: the reconstructed inputs and their gradients.
        """
        y1, y2 = torch.chunk(y, 2, dim = 2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim = 2)
        del dy

        # Replay g on y1 with the recorded RNG state and backpropagate dy2 through it.
        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, shape = seq_shape, set_rng = True)
            torch.autograd.backward(gy1, dy2)

        # Invert y2 = x2 + g(y1); y1's gradient accumulates the path through g.
        with torch.no_grad():
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        # Replay f on the recovered x2 and backpropagate dx1 through it.
        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, shape = seq_shape, set_rng = True, mask = mask, rotary_emb = seq_pos_emb)
            torch.autograd.backward(fx2, dx1, retain_graph = True)

        # Invert y1 = x1 + f(x2).
        with torch.no_grad():
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim = 2)
            dx = torch.cat([dx1, dx2], dim = 2)

        # Same procedure for the second stream, with k and j in place of g and f.
        n1, n2 = torch.chunk(n, 2, dim = 2)
        del n

        dn1, dn2 = torch.chunk(dn, 2, dim = 2)
        del dn

        with torch.enable_grad():
            n1.requires_grad = True
            gn1 = self.k(n1, set_rng = True)
            torch.autograd.backward(gn1, dn2)

        # Invert n2 = m2 + k(n1).
        with torch.no_grad():
            m2 = n2 - gn1
            del n2, gn1

            dm1 = dn1 + n1.grad
            del dn1
            n1.grad = None

        with torch.enable_grad():
            m2.requires_grad = True
            fm2 = self.j(m2, shape = msa_shape, set_rng = True, mask = msa_mask, rotary_emb = msa_pos_emb)
            torch.autograd.backward(fm2, dm1, retain_graph=True)

        # Invert n1 = m1 + j(m2).
        with torch.no_grad():
            m1 = n1 - fm2
            del n1, fm2

            dm2 = dn2 + m2.grad
            del dn2
            m2.grad = None

            m = torch.cat([m1, m2.detach()], dim = 2)
            dm = torch.cat([dm1, dm2], dim = 2)

        return x, m, dx, dm
class ReversibleCrossAttnBlock(nn.Module):
    """Reversible residual block in which the two streams attend to each other.

    Each stream is halved along dim 2. ``f`` updates ``x`` from ``(x2, m2)``
    and ``j`` updates ``m`` from ``(m2, y2)``; ``k`` and ``g`` are the
    single-stream follow-ups for ``x`` and ``m``. All four sub-modules are
    wrapped in ``Deterministic`` so ``backward_pass`` can replay them under
    the recorded RNG state.
    """

    def __init__(self, f, g, j, k):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)
        self.j = Deterministic(j)
        self.k = Deterministic(k)

    def forward(self, x, m, mask = None, msa_mask = None, seq_shape = None, msa_shape = None, seq_to_msa_pos_emb = None, msa_to_seq_pos_emb = None, _reverse = True, **kwargs):
        """Compute the block outputs ``(y, n)`` for inputs ``(x, m)``.

        With ``_reverse=True`` the computation runs under ``torch.no_grad`` and,
        in training mode, the RNG state is recorded for the later replay.
        """
        x1, x2 = torch.chunk(x, 2, dim = 2)
        m1, m2 = torch.chunk(m, 2, dim = 2)
        y1, y2, n1, n2 = None, None, None, None

        # No autograd graph is kept in reversible mode.
        context = torch.no_grad if _reverse else null_context
        record_rng = self.training and _reverse

        with context():
            y1 = x1 + self.f(x2, m2, record_rng = record_rng, mask = mask, context_mask = msa_mask, shape = seq_shape, context_shape = msa_shape, rotary_emb = seq_to_msa_pos_emb)
            y2 = x2 + self.k(y1, shape = seq_shape, record_rng = record_rng)
            n1 = m1 + self.j(m2, y2, record_rng = record_rng, mask = msa_mask, context_mask = mask, shape = msa_shape, context_shape = seq_shape, rotary_emb = msa_to_seq_pos_emb)
            n2 = m2 + self.g(n1, record_rng = record_rng)

        return torch.cat((y1, y2), dim = 2), torch.cat((n1, n2), dim = 2)

    def backward_pass(self, y, n, dy, dn, mask = None, msa_mask = None, seq_shape = None, msa_shape = None, seq_to_msa_pos_emb = None, msa_to_seq_pos_emb = None, **kwargs):
        """Rebuild the inputs ``(x, m)`` from the outputs and propagate gradients.

        The four sub-modules are undone in the reverse of the forward order:
        g, j, k, f.

        Returns:
            ``(x, m, dx, dm)``: the reconstructed inputs and their gradients.
        """
        n1, n2 = torch.chunk(n, 2, dim = 2)
        del n

        dn1, dn2 = torch.chunk(dn, 2, dim = 2)
        del dn

        y1, y2 = torch.chunk(y, 2, dim = 2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim = 2)
        del dy

        # Undo n2 = m2 + g(n1).
        with torch.enable_grad():
            n1.requires_grad = True
            gn1 = self.g(n1, set_rng = True)
            torch.autograd.backward(gn1, dn2)

        with torch.no_grad():
            m2 = n2 - gn1
            del n2, gn1

            dm1 = dn1 + n1.grad
            del dn1
            n1.grad = None

        # Undo n1 = m1 + j(m2, y2); j depends on both streams, so both receive gradient.
        with torch.enable_grad():
            m2.requires_grad = True
            y2.requires_grad = True
            fm2 = self.j(m2, y2, set_rng=True, mask = msa_mask, context_mask = mask, shape = msa_shape, context_shape = seq_shape, rotary_emb = msa_to_seq_pos_emb)
            torch.autograd.backward(fm2, dm1)

        with torch.no_grad():
            m1 = n1 - fm2
            del n1, fm2

            dm2 = dn2 + m2.grad
            dx2 = dy2 + y2.grad
            del dn2
            del dy2
            m2.grad = None
            y2.grad = None

        # Undo y2 = x2 + k(y1).
        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.k(y1, shape = seq_shape, set_rng = True)
            torch.autograd.backward(gy1, dx2)

        with torch.no_grad():
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        # Undo y1 = x1 + f(x2, m2); the new gradients add to those already collected.
        with torch.enable_grad():
            x2.requires_grad = True
            m2.requires_grad = True
            fx2 = self.f(x2, m2, set_rng = True, mask = mask, context_mask = msa_mask, shape = seq_shape, context_shape = msa_shape, rotary_emb = seq_to_msa_pos_emb)
            torch.autograd.backward(fx2, dx1)

        with torch.no_grad():
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dx2 + x2.grad
            dm2 = dm2 + m2.grad
            x2.grad = None
            m2.grad = None

        # Stitch the halves back together.
        with torch.no_grad():
            m = torch.cat([m1, m2.detach()], dim = 2)
            dm = torch.cat([dm1, dm2], dim = 2)

            x = torch.cat([x1, x2.detach()], dim = 2)
            dx = torch.cat([dx1, dx2], dim = 2)

        return x, m, dx, dm
class ReversibleFunction(Function):
    """Autograd function that runs a stack of reversible blocks.

    Only the final outputs are saved; ``backward`` walks the blocks in reverse
    and lets each block's ``backward_pass`` rebuild its own inputs.
    """

    @staticmethod
    def forward(ctx, inp, ind, blocks, kwargs):
        # `inp` holds both streams concatenated along dim 1; `ind` is the split point.
        seq_stream, msa_stream = split_at_index(1, ind, inp)

        for layer in blocks:
            seq_stream, msa_stream = layer(seq_stream, msa_stream, _reverse = True, **kwargs)

        ctx.blocks = blocks
        ctx.kwargs = kwargs
        ctx.ind = ind
        ctx.save_for_backward(seq_stream.detach(), msa_stream.detach())
        return torch.cat((seq_stream, msa_stream), dim = 1)

    @staticmethod
    def backward(ctx, d):
        layers, call_kwargs = ctx.blocks, ctx.kwargs

        grad_seq, grad_msa = split_at_index(1, ctx.ind, d)
        seq_out, msa_out = ctx.saved_tensors

        for layer in layers[::-1]:
            seq_out, msa_out, grad_seq, grad_msa = layer.backward_pass(seq_out, msa_out, grad_seq, grad_msa, **call_kwargs)

        # Gradient only for `inp`; ind, blocks and kwargs receive none.
        return torch.cat((grad_seq, grad_msa), dim = 1), None, None, None


reversible_apply = ReversibleFunction.apply
def irreversible_apply(inputs, ind, blocks, kwargs):
    """Run ``blocks`` in order with ``_reverse=False``.

    ``inputs`` holds both streams concatenated along dim 1 and ``ind`` is the
    split point; the two outputs are concatenated the same way.
    """
    seq_stream, msa_stream = split_at_index(1, ind, inputs)
    for layer in blocks:
        seq_stream, msa_stream = layer(seq_stream, msa_stream, _reverse = False, **kwargs)
    return torch.cat((seq_stream, msa_stream), dim = 1)
class ReversibleSequence(nn.Module):
    """A stack of reversible blocks applied to a sequence stream and an MSA stream.

    Args:
        input_blocks: iterable of argument tuples, each unpacked into the
            constructor of the block class chosen for it.
        block_types: iterable of ``'self'``, ``'cross'`` or ``'conv'``, paired
            with ``input_blocks``.

    Raises:
        ValueError: when a block type is not one of the three above.
    """

    def __init__(self, input_blocks, block_types):
        super().__init__()
        self.block_types = block_types

        # 'conv' blocks share the self-attention wrapper.
        klass_by_type = {
            'self': ReversibleSelfAttnBlock,
            'cross': ReversibleCrossAttnBlock,
            'conv': ReversibleSelfAttnBlock,
        }

        blocks = nn.ModuleList([])
        for block, block_type in zip(input_blocks, block_types):
            # Fail loudly rather than hit an unbound name or silently reuse the
            # class chosen on the previous iteration.
            if block_type not in klass_by_type:
                raise ValueError(f"unknown block type {block_type!r}, expected one of {sorted(klass_by_type)}")
            blocks.append(klass_by_type[block_type](*block))

        self.blocks = blocks

    def forward(
        self,
        seq,
        msa,
        seq_shape = None,
        msa_shape = None,
        mask = None,
        msa_mask = None,
        seq_pos_emb = None,
        msa_pos_emb = None,
        seq_to_msa_pos_emb = None,
        msa_to_seq_pos_emb = None,
        reverse = True
    ):
        """Run the stack and return ``[seq, msa]``.

        ``reverse=True`` uses the memory-saving reversible path; otherwise the
        blocks run with ordinary autograd.

        Raises:
            AssertionError: when ``msa`` is None.
        """
        assert exists(msa), 'reversibility does not work with no MSA sequences yet'

        blocks = self.blocks
        # Duplicate the channels so each block can split its input into two halves.
        seq, msa = list(map(lambda t: torch.cat((t, t), dim = -1), (seq, msa)))
        kwargs = {'mask': mask, 'msa_mask': msa_mask, 'seq_shape': seq_shape, 'msa_shape': msa_shape, 'seq_pos_emb': seq_pos_emb, 'msa_pos_emb': msa_pos_emb, 'seq_to_msa_pos_emb': seq_to_msa_pos_emb, 'msa_to_seq_pos_emb': msa_to_seq_pos_emb}

        fn = reversible_apply if reverse else irreversible_apply
        # Both streams travel as one tensor; `ind` remembers where to split them.
        ind = seq.shape[1]
        inp = torch.cat((seq, msa), dim = 1)
        out = fn(inp, ind, blocks, kwargs)
        seq, msa = split_at_index(1, ind, out)
        # Average the two channel halves back down to the original width.
        return list(map(lambda t: reduce(t, 'b n (c d) -> b n d', 'mean', c = 2), (seq, msa)))
.\lucidrains\alphafold2\alphafold2_pytorch\rotary.py
def rotate_every_two(x):
    """Map each consecutive pair ``(a, b)`` along the last axis to ``(-b, a)``."""
    pairs = rearrange(x, '... (d j) -> ... d j', j = 2)
    first, second = pairs.unbind(dim = -1)
    rotated = torch.stack((-second, first), dim = -1)
    return rearrange(rotated, '... d j -> ... (d j)')
def apply_rotary_pos_emb(x, sinu_pos):
    """Apply rotary position embeddings to the leading features of ``x``.

    ``sinu_pos`` is a ``(sin, cos)`` pair. Only the first ``sin.shape[-1]``
    features of ``x`` are rotated; the rest pass through unchanged.
    """
    # Insert a singleton axis after the first one in both tables.
    sin, cos = (rearrange(table, 'b ... -> b () ...') for table in sinu_pos)
    rot_dim = sin.shape[-1]
    rotated_part = x[..., :rot_dim]
    untouched_part = x[..., rot_dim:]
    rotated_part = rotated_part * cos + rotate_every_two(rotated_part) * sin
    return torch.cat((rotated_part, untouched_part), dim = -1)
class DepthWiseConv1d(nn.Module):
    """Grouped 1-D convolution followed by a kernel-size-1 convolution.

    ``groups`` falls back to ``dim_in`` when it is not given.
    """

    def __init__(self, dim_in, dim_out, kernel_size, padding = 0, stride = 1, bias = True, groups = None):
        super().__init__()
        groups = default(groups, dim_in)
        grouped_conv = nn.Conv1d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = groups, stride = stride, bias = bias)
        pointwise_conv = nn.Conv1d(dim_in, dim_out, 1, bias = bias)
        self.net = nn.Sequential(grouped_conv, pointwise_conv)

    def forward(self, x):
        """Apply both convolutions in order."""
        return self.net(x)
class FixedPositionalEmbedding(nn.Module):
    """Sinusoidal position tables built from a fixed inverse-frequency buffer."""

    def __init__(self, dim):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1. / (10000 ** exponents))

    def forward(self, n, device):
        """Return ``[sin, cos]`` tables for positions ``0 .. n-1``."""
        positions = torch.arange(n, device = device).type_as(self.inv_freq)
        # Outer product: one row per position, one column per frequency.
        angles = einsum('i , j -> i j', positions, self.inv_freq)
        # Repeat every frequency twice and add a leading singleton axis.
        angles = repeat(angles, 'i j -> () i (j r)', r = 2)
        return [angles.sin(), angles.cos()]
class AxialRotaryEmbedding(nn.Module):
    """Rotary position tables for an ``n x n`` grid flattened to ``n * n`` positions.

    Half of ``dim`` is used for each grid axis. ``max_freq`` is accepted but
    not used by this implementation.
    """

    def __init__(self, dim, max_freq = 10):
        super().__init__()
        self.dim = dim // 2
        exponents = torch.arange(0, self.dim, 2).float() / self.dim
        self.register_buffer('inv_freq', 1. / (10000 ** exponents))

    def forward(self, n, device):
        """Return ``[sin, cos]`` tables for the flattened grid."""
        positions = torch.arange(n, device = device).type_as(self.inv_freq)

        # Both axes use the same position/frequency outer product.
        freqs = einsum('n, d -> n d', positions, self.inv_freq)

        # Broadcast along the other axis so every grid cell carries both.
        row_freqs = repeat(freqs, 'i d -> i j d', j = n)
        col_freqs = repeat(freqs, 'j d -> i j d', i = n)

        sin = torch.cat((row_freqs.sin(), col_freqs.sin()), dim = -1)
        cos = torch.cat((row_freqs.cos(), col_freqs.cos()), dim = -1)

        # Flatten the grid, repeat every feature twice, add a leading singleton axis.
        sin, cos = (repeat(table, 'i j d -> () (i j) (d r)', r = 2) for table in (sin, cos))
        return [sin, cos]
.\lucidrains\alphafold2\alphafold2_pytorch\utils.py
import os
import re
import numpy as np
import torch
import contextlib
from functools import wraps
from einops import rearrange, repeat
from Bio import SeqIO
import itertools
import string
from sidechainnet.utils.sequence import ProteinVocabulary, ONE_TO_THREE_LETTER_MAP
from sidechainnet.utils.measure import GLOBAL_PAD_CHAR
from sidechainnet.structure.build_info import NUM_COORDS_PER_RES, BB_BUILD_INFO, SC_BUILD_INFO
from sidechainnet.structure.StructureBuilder import _get_residue_build_iter
import mp_nerf
# Shared vocabulary instance; the helpers below use its `_int2char` table to
# turn integer residue ids into characters.
VOCAB = ProteinVocabulary()
import alphafold2_pytorch.constants as constants
def exists(val):
    """Return True for any value other than None."""
    return not (val is None)
# Evenly spaced values from 2 to 20, one per distogram bucket.
DISTANCE_THRESHOLDS = torch.linspace(2, 20, steps = constants.DISTOGRAM_BUCKETS)
def get_bucketed_distance_matrix(coords, mask, num_buckets = constants.DISTOGRAM_BUCKETS, ignore_index = -100):
    """Discretise pairwise Euclidean distances between ``coords`` into buckets.

    Bucket boundaries are ``num_buckets`` evenly spaced values from 2 to 20,
    with the last boundary dropped. Pairs in which either point is masked out
    are set to ``ignore_index``.
    """
    # A pair is valid only when both of its points are valid.
    pair_is_valid = mask[..., None] & mask[..., None, :]
    edges = torch.linspace(2, 20, steps = num_buckets, device = coords.device)
    pairwise = torch.cdist(coords, coords, p=2)
    buckets = torch.bucketize(pairwise, edges[:-1])
    buckets.masked_fill_(~pair_is_valid, ignore_index)
    return buckets
def set_backend_kwarg(fn):
    """Decorator that resolves a ``backend`` keyword before calling ``fn``.

    With ``backend='auto'`` (the default) the backend is ``'torch'`` when the
    first positional argument is a ``torch.Tensor`` and ``'numpy'`` otherwise.
    """
    @wraps(fn)
    def inner(*args, backend = 'auto', **kwargs):
        resolved = backend
        if resolved == 'auto':
            first_is_tensor = isinstance(args[0], torch.Tensor)
            resolved = 'torch' if first_is_tensor else 'numpy'
        return fn(*args, **{**kwargs, 'backend': resolved})
    return inner
def expand_dims_to(t, length = 3):
    """Prepend ``length`` singleton axes to ``t``; ``length == 0`` returns ``t`` itself."""
    if length == 0:
        return t
    singleton_axes = (1,) * length
    return t.reshape(*singleton_axes, *t.shape)
def expand_arg_dims(dim_len = 3):
    """Decorator factory that pads both positional inputs with leading singleton axes.

    The wrapped function receives ``x`` and ``y`` expanded to ``dim_len``
    dimensions. Both inputs must have the same number of dimensions.
    """
    def outer(fn):
        @wraps(fn)
        def inner(x, y, **kwargs):
            assert len(x.shape) == len(y.shape), "Shapes of A and B must match."
            missing = dim_len - len(x.shape)
            padded_x = expand_dims_to(x, length = missing)
            padded_y = expand_dims_to(y, length = missing)
            return fn(padded_x, padded_y, **kwargs)
        return inner
    return outer
def invoke_torch_or_numpy(torch_fn, numpy_fn):
    """Decorator factory that dispatches to ``torch_fn`` or ``numpy_fn``.

    The decorated function prepares the call: it returns an iterable of
    positional arguments, optionally ending with a dict of keyword arguments.
    The required ``backend`` keyword selects the implementation
    (``'torch'`` picks ``torch_fn``; anything else picks ``numpy_fn``).
    """
    def outer(fn):
        @wraps(fn)
        def inner(*args, **kwargs):
            backend = kwargs.pop('backend')
            prepared = list(fn(*args, **kwargs))
            # A trailing dict carries keyword arguments for the backend function.
            extra_kwargs = prepared.pop() if isinstance(prepared[-1], dict) else {}
            chosen = torch_fn if backend == 'torch' else numpy_fn
            return chosen(*prepared, **extra_kwargs)
        return inner
    return outer
@contextlib.contextmanager
def torch_default_dtype(dtype):
    """Temporarily set torch's default dtype to ``dtype``.

    The previous default is restored on exit, including when the body of the
    ``with`` block raises.
    """
    prev_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Without this, an exception in the body would leak the new default globally.
        torch.set_default_dtype(prev_dtype)
def get_atom_ids_dict():
    """ Get a dictionary mapping each atom name to a token id.

        Names are the backbone atoms, the empty string and every side-chain
        atom name in SC_BUILD_INFO; ids follow the sorted order of the names.
    """
    names = {"", "N", "CA", "C", "O"}
    for build_info in SC_BUILD_INFO.values():
        names.update(build_info["atom-names"])
    return {name: token for token, name in enumerate(sorted(names))}
def make_cloud_mask(aa):
    """ 1 for atom slots the residue occupies, 0 for padding slots.

        The padding residue "_" yields an all-zero mask.
    """
    mask = np.zeros(constants.NUM_COORDS_PER_RES)
    if aa != "_":
        side_chain_atoms = SC_BUILD_INFO[ ONE_TO_THREE_LETTER_MAP[aa] ]["atom-names"]
        # Four backbone atoms come first, then the side chain.
        mask[:4 + len(side_chain_atoms)] = 1
    return mask
def make_atom_id_embedds(aa, atom_ids):
    """ Return the token of each atom in the residue ``aa``.

        Inputs:
        * aa: str. one-letter amino-acid code, or "_" for padding.
        * atom_ids: dict mapping atom name -> token id.
        Output: array of length NUM_COORDS_PER_RES; slots past the residue's
                atoms (and every slot for "_") stay 0.
    """
    mask = np.zeros(constants.NUM_COORDS_PER_RES)
    if aa == "_":
        return mask
    atom_list = ["N", "CA", "C", "O"] + SC_BUILD_INFO[ ONE_TO_THREE_LETTER_MAP[aa] ]["atom-names"]
    for i, atom in enumerate(atom_list):
        # Use the mapping that was passed in; the original ignored the
        # parameter and read the module-level ATOM_IDS instead.
        mask[i] = atom_ids[atom]
    return mask
# Atom-name -> token-id table, built once at import time.
ATOM_IDS = get_atom_ids_dict()
# Per-residue lookup keyed by one-letter code (plus "_"): the cloud mask and
# the per-atom token array for each residue, precomputed at import time.
CUSTOM_INFO = {k: {"cloud_mask": make_cloud_mask(k),
                   "atom_id_embedd": make_atom_id_embedds(k, atom_ids=ATOM_IDS),
                  } for k in "ARNDCQEGHILKMFPSTWYV_"}
def download_pdb(name, route):
    """ Downloads a PDB entry from the RCSB PDB.
        Inputs:
        * name: str. the PDB entry id. 4 characters, capitalized.
        * route: str. route of the destin file. usually ".pdb" extension
        Output: route of destin file
    """
    # Imported locally: only this function needs it.
    import subprocess

    url = f"https://files.rcsb.org/download/{name}.pdb"
    # An argument list (no shell) keeps routes with spaces or shell
    # metacharacters from breaking or altering the command. As before, a failed
    # download is not raised to the caller.
    subprocess.run(["curl", url, "-o", route], check=False)
    return route
def clean_pdb(name, route=None, chain_num=None):
    """ Keeps only the atoms of the selected chain(s) and saves the result.
        Inputs:
        * name: str. route of the input .pdb file
        * route: str. route of the output. will overwrite input if not provided
        * chain_num: int. index of the chain to keep; all chains when None.
                     NOTE(review): it is compared directly with mdtraj's
                     `chain.index` — confirm whether callers pass a 0- or
                     1-based value.
        Output: route of destin file.
    """
    import mdtraj

    destin = name if route is None else route
    raw_prot = mdtraj.load_pdb(name)
    topology = raw_prot.topology

    # Collect the atom indices of every chain that passes the filter.
    idxs = []
    for chain in topology.chains:
        if chain_num is not None and chain_num != chain.index:
            continue
        selected = topology.select(f"chainid == {str(chain.index)}")
        idxs.extend( selected.tolist() )
    idxs.sort()

    # Rebuild a trajectory restricted to the kept atoms.
    prot = mdtraj.Trajectory(xyz=raw_prot.xyz[:, idxs],
                             topology=topology.subset(idxs))
    prot.save(destin)
    return destin
def custom2pdb(coords, proteinnet_id, route):
    """ Takes a custom representation and turns into a .pdb file.
        Inputs:
        * coords: array/tensor of shape (3 x N) or (N x 3). in Angstroms.
                  same order as in the proteinnnet is assumed (same as raw pdb file)
        * proteinnet_id: str. proteinnet id format (<class>#<pdb_id>_<chain_number>_<chain_id>)
                         see: https://github.com/aqlaboratory/proteinnet/
        * route: str. destin route.
        Output: tuple of routes: (original, generated) for the structures.
    """
    import mdtraj

    if isinstance(coords, torch.Tensor):
        coords = coords.detach().cpu().numpy()

    # mdtraj stores xyz as (n_frames, n_atoms, 3): bring (3, N) input to (N, 3)
    # and leave (N, 3) input alone. The original transposed the wrong case.
    if coords.shape[-1] != 3:
        coords = coords.T
    # Add the frame axis. `np.newaxis` is None and cannot be called.
    coords = np.expand_dims(coords, axis=0)

    # The id's last "_" field (chain id) is dropped; the other two are kept.
    pdb_name, chain_num = proteinnet_id.split("#")[-1].split("_")[:-1]

    # Save the reference structure next to `route`. A route with no directory
    # part now resolves to the current directory rather than the filesystem root.
    pdb_destin = os.path.join(os.path.dirname(route), pdb_name + ".pdb")

    download_pdb(pdb_name, pdb_destin)
    # clean_pdb compares chain_num with an integer index, so convert the text field.
    # NOTE(review): confirm whether the proteinnet chain number needs a 1 -> 0 based shift.
    clean_pdb(pdb_destin, chain_num=int(chain_num))

    # Reuse the cleaned structure's topology and swap in the new coordinates.
    scaffold = mdtraj.load_pdb(pdb_destin)
    scaffold.xyz = coords
    scaffold.save(route)
    return pdb_destin, route
def coords2pdb(seq, coords, cloud_mask, prefix="", name="af2_struct.pdb"):
    """ Turns coordinates into PDB files ready to be visualized.
        Inputs:
        * seq: (L,) tensor of ints (sidechainnet aa-key pairs)
        * coords: (3, N) coords of atoms
        * cloud_mask: (L, C) boolean mask of occupied spaces in scn format
        * prefix: str. directory to save files.
        * name: str. name of destin file (ex: pred1.pdb)
    """
    # `scn` was never imported at module level; bring it in where it is used.
    import sidechainnet as scn

    # One xyz triple per mask slot. The shape must be unpacked: torch.zeros
    # does not accept a Size object followed by an int.
    scaffold = torch.zeros( *cloud_mask.shape, 3 )
    # NOTE(review): this assignment needs one row of 3 values per True entry of
    # cloud_mask, i.e. (N, 3), while the docstring says (3, N) — confirm.
    scaffold[cloud_mask] = coords.cpu().float()

    # Build the structure and save it.
    pred = scn.StructureBuilder( seq, crd=scaffold )
    pred.to_pdb(prefix+name)
def remove_insertions(sequence: str) -> str:
    """ Strips lowercase ASCII letters, "." and "*" from an aligned sequence.
        Needed to load aligned sequences in an MSA.
    """
    dropped_chars = string.ascii_lowercase + ".*"
    return sequence.translate(str.maketrans("", "", dropped_chars))
def read_msa(filename: str, nseq: int):
    """ Reads the first nseq sequences of a FASTA MSA file as
        (description, sequence) pairs, with insertions removed.
    """
    first_records = itertools.islice(SeqIO.parse(filename, "fasta"), nseq)
    pairs = []
    for record in first_records:
        pairs.append((record.description, remove_insertions(str(record.seq))))
    return pairs
def ids_to_embed_input(x):
    """ Converts nested lists of amino acid ids into the input expected by the
        ESM and MSA transformer batch converters.
        Inputs:
        * x: any deeply nested list of integers that correspond with amino acid id
        Output: a list of ints becomes a ``(None, sequence_string)`` tuple;
                lists of lists are converted recursively.
    """
    assert isinstance(x, list), 'input must be a list'
    id2aa = VOCAB._int2char

    def convert(el):
        # Nested lists recurse; integers map to their residue character.
        if isinstance(el, list):
            return ids_to_embed_input(el)
        if isinstance(el, int):
            return id2aa[el]
        raise TypeError('type must be either list or character')

    out = [convert(el) for el in x]

    # A flat list of characters is one sequence: join it and pair it with a None label.
    if all(isinstance(item, str) for item in out):
        return (None, ''.join(out))
    return out
def ids_to_prottran_input(x):
    """ Converts lists of amino acid ids into space-separated residue strings.
        The letters U, Z, O and B are replaced with X.
        Inputs:
        * x: list of lists of integers that correspond with amino acid id
        Output: list of str, one per inner list.
    """
    assert isinstance(x, list), 'input must be a list'
    id2aa = VOCAB._int2char

    def to_spaced_string(ids):
        spaced = ' '.join([id2aa[i] for i in ids])
        return re.sub(r"[UZOB]", "X", spaced)

    return [to_spaced_string(ids) for ids in x]
def get_prottran_embedd(seq, model, tokenizer, device = None):
    """ Embeds integer-encoded sequences with a transformers feature-extraction pipeline.
        Inputs:
        * seq: (batch, L) tensor of ints
        * model, tokenizer: passed straight to ``transformers.pipeline``
        * device: torch device for the output; its ``index`` selects the
                  pipeline device (-1 when no device is given)
        Output: the extracted features with the first position dropped and at
                most L positions kept.
    """
    from transformers import pipeline

    pipeline_device = device.index if exists(device) else -1
    extractor = pipeline('feature-extraction', model = model, tokenizer = tokenizer, device = pipeline_device)

    seq_len = seq.shape[1]
    features = extractor(ids_to_prottran_input(seq.cpu().tolist()))
    features = torch.tensor(features, device = device)
    # Skip position 0 and keep the next seq_len positions.
    return features[:, 1:(seq_len + 1)]
def get_msa_embedd(msa, embedd_model, batch_converter, device = None):
    """ Returns the MSA_tr embeddings for a protein.
        Inputs:
        * msa: tensor of ints (in sidechainnet int-char convention)
        * embedd_model: MSA_tr model (see train_end2end.py for an example)
        * batch_converter: MSA_tr batch converter (see train_end2end.py for an example)
        * device: torch device to run the model on. Defaults to msa.device.
        Outputs: tensor of (batch, n_seqs, L, embedd_dim)
            * n_seqs: number of sequences in the MSA
            * embedd_dim: number of embedding dimensions. 768 for MSA_Transformer
    """
    # Layer whose representations are returned.
    REPR_LAYER_NUM = 12
    # The original read `seq.device`, but no `seq` exists in this function;
    # honor the `device` argument and fall back to the input's device.
    device = msa.device if device is None else device
    max_seq_len = msa.shape[-1]
    embedd_inputs = ids_to_embed_input(msa.cpu().tolist())

    msa_batch_labels, msa_batch_strs, msa_batch_tokens = batch_converter(embedd_inputs)
    with torch.no_grad():
        results = embedd_model(msa_batch_tokens.to(device), repr_layers=[REPR_LAYER_NUM], return_contacts=False)
    # Skip position 0 and keep the next max_seq_len positions.
    token_reps = results["representations"][REPR_LAYER_NUM][..., 1:max_seq_len+1, :]
    return token_reps
def get_esm_embedd(seq, embedd_model, batch_converter, msa_data=None, device=None):
    """ Returns the ESM embeddings for a protein.
        Inputs:
        * seq: ( (b,) L,) tensor of ints (in sidechainnet int-char convention)
        * embedd_model: ESM model (see train_end2end.py for an example)
        * batch_converter: ESM batch converter (see train_end2end.py for an example)
        * msa_data: unused; kept for interface compatibility.
        * device: torch device to run the model on. Defaults to seq.device.
        Outputs: tensor of (batch, n_seqs, L, embedd_dim)
            * n_seqs: number of sequences in the MSA. 1 for ESM-1b
            * embedd_dim: number of embedding dimensions. 1280 for ESM-1b
    """
    # `device` is a new optional keyword, matching the sibling embedding helpers
    # and the callers that pass `device=...`.
    device = seq.device if device is None else device
    # Layer whose representations are returned.
    REPR_LAYER_NUM = 33
    max_seq_len = seq.shape[-1]
    embedd_inputs = ids_to_embed_input(seq.cpu().tolist())

    batch_labels, batch_strs, batch_tokens = batch_converter(embedd_inputs)
    with torch.no_grad():
        results = embedd_model(batch_tokens.to(device), repr_layers=[REPR_LAYER_NUM], return_contacts=False)
    # Skip position 0, keep the next max_seq_len positions, then add the n_seqs axis.
    token_reps = results["representations"][REPR_LAYER_NUM][..., 1:max_seq_len+1, :].unsqueeze(dim=1)
    return token_reps
def get_t5_embedd(seq, tokenizer, encoder, msa_data=None, device=None):
    """ Returns the ProtT5-XL-U50 embeddings for a protein.
        Inputs:
        * seq: ( (b,) L,) tensor of ints (in sidechainnet int-char convention)
        * tokenizer: tokenizer model: T5Tokenizer
        * encoder: encoder model: T5EncoderModel
                   ex: from transformers import T5EncoderModel, T5Tokenizer
                       model_name = "Rostlab/prot_t5_xl_uniref50"
                       tokenizer = T5Tokenizer.from_pretrained(model_name, do_lower_case=False )
                       model = T5EncoderModel.from_pretrained(model_name)
                       # prepare model
                       model = model.to(device)
                       model = model.eval()
                       if torch.cuda.is_available():
                           model = model.half()
        * msa_data: unused.
        * device: torch device to run on. Defaults to seq.device.
        Outputs: tensor of (batch, n_seqs, L, embedd_dim)
            * n_seqs: number of sequences in the MSA. 1 for T5 models
            * embedd_dim: number of embedding dimensions. 1024 for T5 models
    """
    if device is None:
        device = seq.device

    residue_strings = ids_to_prottran_input(seq.cpu().tolist())

    # The last position of the encoder output is dropped; nothing is trimmed on the left.
    shift_left, shift_right = 0, -1

    encoded = tokenizer.batch_encode_plus(residue_strings, add_special_tokens=True,
                                          padding=True,
                                          return_tensors="pt")

    input_ids = torch.tensor(encoded['input_ids']).to(device)
    attention_mask = torch.tensor(encoded["attention_mask"]).to(device)
    with torch.no_grad():
        embedding = encoder(input_ids=input_ids, attention_mask=attention_mask)

    token_reps = embedding.last_hidden_state[:, shift_left:shift_right].to(device)
    # Pad with leading singleton axes up to four dimensions.
    token_reps = expand_dims_to(token_reps, 4-len(token_reps.shape))
    return token_reps.float()
def get_all_protein_ids(dataloader, verbose=False):
    """ Given a sidechainnet dataloader for a CASP version,
        Returns all the ids belonging to proteins.
        Inputs:
        * dataloader: a sidechainnet dataloader for a CASP version. Either a
                      dict holding a 'train' loader or an iterable of batches.
        * verbose: bool. print every id that is skipped.
        Outputs: a set containing the ids for all protein entries.
    """
    # The original read an undefined global `dataloaders` and wrapped it in an
    # un-imported `tqdm`; use the argument, and accept the bare loader as well.
    batches = dataloader['train'] if isinstance(dataloader, dict) else dataloader

    ids = set()
    for batch in batches:
        for idx in range(batch.int_seqs.shape[0]):
            pid = batch.pids[idx]
            # Keep ids shorter than 10 characters whose "_"-separated parts
            # are all at most 4 characters long.
            max_len_10 = len(pid) < 10
            fragments = [len(x) <= 4 for x in pid.split("_")]
            fragments_under_4 = sum(fragments) == len(fragments)
            if max_len_10 and fragments_under_4:
                ids.add(pid)
            elif verbose:
                print("skip:", pid, "under 4", fragments)
    return ids
def scn_cloud_mask(scn_seq, boolean=True, coords=None):
    """ Builds the mask of atom positions present for each residue
        (not every amino acid has the same atoms).
        Inputs:
        * scn_seq: (batch, length) sequence as provided by Sidechainnet package
        * boolean: whether to return boolean values (True) or the indexes
                   of the non-zero entries (False)
        * coords: optional. (batch, lc, 3) sidechainnet coords. If passed,
                  the mask is derived from the coordinates instead of the
                  sequence: an atom counts as present unless all of its
                  coordinates are exactly 0.
        Outputs: (batch, length, NUM_COORDS_PER_RES) boolean mask
                 (or its non-zero indexes if boolean is False)
    """
    scn_seq = expand_dims_to(scn_seq, 2 - len(scn_seq.shape))

    if coords is None:
        # per-residue lookup, done on python lists and moved back to the
        # device of the input sequence
        target_device = scn_seq.device
        per_protein = []
        for protein in scn_seq.cpu().tolist():
            residue_masks = [CUSTOM_INFO[VOCAB._int2char[aa]]['cloud_mask'] for aa in protein]
            per_protein.append(torch.tensor(residue_masks).bool().to(target_device))
        mask = torch.stack(per_protein, dim=0)
    else:
        # group the flat atom axis per residue and count zeroed coordinates
        per_residue = rearrange(coords, '... (l c) d -> ... l c d',
                                c=constants.NUM_COORDS_PER_RES)
        zeroed_dims = (per_residue == 0).sum(dim=-1)
        mask = zeroed_dims < coords.shape[-1]

    return mask.bool() if boolean else mask.nonzero()
def scn_backbone_mask(scn_seq, boolean=True, n_aa=3):
    """ Builds the masks selecting the N, CA and C backbone positions.
        Inputs:
        * scn_seq: sequence(s) as provided by Sidechainnet package (int tensor/s)
        * boolean: whether to return boolean masks (True) or the indexes of
                   the selected positions (False)
        * n_aa: number of atoms per residue in the flattened backbone
                (may include cbeta as 4th pos)
        Outputs: (N_mask, CA_mask, C_mask), each over the flattened
                 (length * n_aa) atom axis
    """
    codes = (1, 2, 3)  # labels for N, CA, C; any further slot stays 0
    labels = torch.zeros(*scn_seq.shape, n_aa).to(scn_seq.device)
    for slot, code in enumerate(codes):
        labels[..., slot] = code
    labels = rearrange(labels, '... l c -> ... (l c)')

    masks = tuple(labels == code for code in codes)
    if not boolean:
        return tuple(torch.nonzero(mask) for mask in masks)
    return masks
def scn_atom_embedd(scn_seq):
    """ Returns the token for each atom of each residue.
        Inputs:
        * scn_seq: sequence(s) as provided by Sidechainnet package (int tensor/s)
        Outputs: long tensor stacking, per sequence, the "atom_id_embedd"
                 entry of every residue, on the device of the input.
    """
    target_device = scn_seq.device
    per_protein = []
    for protein in scn_seq.cpu().tolist():
        atom_tokens = [CUSTOM_INFO[VOCAB.int2char(aa)]["atom_id_embedd"] for aa in protein]
        per_protein.append(torch.tensor(atom_tokens))
    return torch.stack(per_protein, dim=0).long().to(target_device)
def mat_input_to_masked(x, x_mask=None, edges_mat=None, edges=None,
                        edge_mask=None, edge_attr_mat=None,
                        edge_attr=None):
    """ Turns the padded input and edges + mask into the
        non-padded inputs and edges.
        At least one of (edges_mat, edges) must be provided.
        The same format for edges and edge_attr must be provided
        (either adj matrix form or flattened form).
        Inputs:
        * x: ((batch), N, D) a tensor of N nodes and D dims for each one
        * x_mask: ((batch), N,) boolean mask for x. All nodes are kept if None.
        * edges: (2, E) optional. indices of the corresponding adjacency matrix.
        * edges_mat: ((batch), N, N) optional. adjacency matrix for x
        * edge_mask: optional. boolean mask over the E edges.
        * edge_attr: (E, D_edge) optional. edge attributes of D_edge dims.
        * edge_attr_mat: ((batch), N, N) optional. adjacency matrix with features
        Outputs:
        * x: (N_, D) the masked node features
        * edge_index: (2, E_) the masked x-indices for the edges
        * edge_attr: (E_, D_edge) the masked edge attributes
        * batch: (N_,) the corresponding index in the batch for each node
        Raises:
        * ValueError if neither edges nor edges_mat is given, or if
          edge_attr_mat is given without edges_mat.
    """
    if edges is None and edges_mat is None:
        raise ValueError("At least one of (edges_mat, edges) must be provided")

    if len(x.shape) == 3:
        # collapse the batch dimension: node j of sample b becomes b*N + j
        batch_dim = x.shape[1]
        x = rearrange(x, 'b n d ... -> (b n) d ...')
        if x_mask is not None:
            x_mask = rearrange(x_mask, 'b n ... -> (b n) ...')
        else:
            x_mask = torch.ones_like(x[..., 0]).bool()
        if edges_mat is not None and edges is None:
            edges = torch.nonzero(edges_mat, as_tuple=False).t()
            # rows are (sample, i, j): offset i and j by the sample start
            edges = edges[1:] + edges[:1]*batch_dim
        batch = (torch.arange(x.shape[0], device=x.device) // batch_dim)[x_mask]
    else:
        # a missing mask used to be passed straight to `x[x_mask]`,
        # which adds a dimension instead of keeping every node
        if x_mask is None:
            x_mask = torch.ones_like(x[..., 0]).bool()
        if edges_mat is not None and edges is None:
            edges = torch.nonzero(edges_mat, as_tuple=False).t()
        # masked like the batched branch so that it matches the returned x
        batch = torch.zeros(x.shape[0], device=x.device)[x_mask]

    if edge_attr_mat is not None and edge_attr is None:
        if edges_mat is None:
            raise ValueError("edge_attr_mat requires edges_mat to be provided")
        # read the attributes from the attribute matrix (the previous code
        # indexed `edge_attr`, which is None in this branch)
        edge_attr = edge_attr_mat[edges_mat.bool()]
    if edge_mask is None:
        edge_mask = torch.ones_like(edges[-1]).bool()

    x = x[x_mask]
    # dense scratch adjacency sized by the number of nodes, so that it can
    # always be indexed by x_mask (sizing it by the largest edge index fails
    # when the last nodes have no edges, or when there are no edges at all)
    n_nodes = x_mask.shape[0]
    wrapper = torch.zeros(n_nodes, n_nodes).to(x.device)
    wrapper[edges[0][edge_mask], edges[1][edge_mask]] = 1
    wrapper = wrapper[x_mask, :][:, x_mask]
    edge_index = torch.nonzero(wrapper, as_tuple=False).t()
    # NOTE(review): edge_attr is filtered by edge_mask only, while edge_index
    # also drops edges touching masked-out nodes - confirm callers expect this.
    edge_attr = edge_attr[edge_mask] if edge_attr is not None else None
    return x, edge_index, edge_attr, batch
def nth_deg_adjacency(adj_mat, n=1, sparse=False):
    """ Calculates the n-th degree adjacency matrix.
        Performs mm of adj_mat and adds the newly added.
        Default is dense. Mods for sparse version are done when needed.
        Inputs:
        * adj_mat: (N, N) adjacency tensor
        * n: int. degree of the output adjacency
        * sparse: bool. whether to use torch-sparse module
        Outputs:
        * new_adj_mat: the result of multiplying adj_mat by itself up to
                       the n-th power, binarised (a copy of adj_mat if n <= 1)
        * attr_mat: degree of connectivity (1 for neighs, 2 for neighs^2, ... ).
                    Every position keeps the first degree at which it appeared.
    """
    adj_mat = adj_mat.float()
    attr_mat = torch.zeros_like(adj_mat)
    new_adj_mat = adj_mat.clone()

    if sparse and n > 1:
        # coordinate form of the base adjacency, reused at every step
        idxs = adj_mat.nonzero().t()
        vals = adj_mat[idxs[0], idxs[1]]
        new_idxs, new_vals = idxs.clone(), vals.clone()
        # named `size` so the `n` argument is not shadowed
        size = adj_mat.shape[0]

    for degree in range(1, n + 1):
        if degree == 1:
            attr_mat += adj_mat
            continue

        if sparse:
            new_idxs, new_vals = torch_sparse.spspmm(new_idxs, new_vals, idxs, vals,
                                                     m=size, k=size, n=size)
            new_vals = new_vals.bool().float()
            new_adj_mat = torch.zeros_like(attr_mat)
            new_adj_mat[new_idxs[0], new_idxs[1]] = new_vals
        else:
            new_adj_mat = (new_adj_mat @ adj_mat).bool().float()

        # label only positions reached now for the first time. The previous
        # dense code discarded the out-of-place `masked_fill` result and its
        # mask also matched already-labelled edges; the previous sparse code
        # reset already-labelled positions to 0.
        # Note the diagonal gets labelled too (a node reaches itself in 2 hops).
        newly_reached = new_adj_mat.bool() & ~attr_mat.bool()
        attr_mat.masked_fill_(newly_reached, degree)

    return new_adj_mat, attr_mat
def prot_covalent_bond(seqs, adj_degree=1, cloud_mask=None, mat=True, sparse=False):
    """ Returns the idxs of covalent bonds for a protein.
        Atoms are indexed compactly: every residue occupies as many
        consecutive slots as atoms appear in its bond list.
        Inputs
        * seq: (b, n) torch long.
        * adj_degree: int. adjacency degree
        * cloud_mask: mask selecting the present atoms. (currently unused)
        * mat: whether to return as indexes of only atoms (PyG version)
               or matrices of masked atoms (for batched training).
               for indexes, only 1 seq is supported.
        * sparse: bool. whether to use torch_sparse for adj_mat calc
        Outputs: edge_idxs, edge_types (degree of adjacency).
                 If mat is True: (boolean adjacency, degree matrix) instead.
    """
    device = seqs.device
    # container sized for the worst case of NUM_COORDS_PER_RES atoms per residue
    n_slots = seqs.shape[1] * NUM_COORDS_PER_RES
    adj_mat = torch.zeros(seqs.shape[0], n_slots, n_slots)

    for s, seq in enumerate(seqs.cpu().tolist()):
        # collect the bond lists up to the first token without bonds
        # (no bonds -> padding token -> the chain ends there)
        residue_bonds = []
        for token in seq:
            aa_bonds = constants.AA_DATA[VOCAB._int2char[token]]['bonds']
            if len(aa_bonds) == 0:
                break
            residue_bonds.append(aa_bonds)

        next_idx = 0
        last_residue = len(residue_bonds) - 1
        for i, aa_bonds in enumerate(residue_bonds):
            # atom indexes are 0-based, so the residue spans (max index + 1)
            # slots. Using the max index itself made the peptide bond point
            # at the residue's own last atom and overlapped consecutive residues.
            n_atoms = 1 + max(max(bond) for bond in aa_bonds)
            pairs = list(aa_bonds)
            if i < last_residue:
                # peptide bond: C (pos 2) with the first atom of the next
                # residue; skipped for the last real residue, padded or not
                pairs.append([2, n_atoms])
            bonds = next_idx + torch.tensor(pairs).t()
            adj_mat[s, bonds[0], bonds[1]] = 1
            next_idx += n_atoms
        # convert to undirected
        adj_mat[s] = adj_mat[s] + adj_mat[s].t()

    # n-th degree adjacency, computed once after the whole batch is filled
    adj_mat, attr_mat = nth_deg_adjacency(adj_mat, n=adj_degree, sparse=sparse)

    if mat:
        return attr_mat.bool().to(device), attr_mat.to(device)
    edge_idxs = attr_mat[0].nonzero().t().long()
    edge_types = attr_mat[0, edge_idxs[0], edge_idxs[1]]
    return edge_idxs.to(device), edge_types.to(device)
def sidechain_container(seqs, backbones, atom_mask, cloud_mask=None, padding_tok=20):
    """ Gets a backbone of the protein, returns the whole coordinates
        with sidechains (same format as sidechainnet). Keeps differentiability.
        Inputs:
        * seqs: (batch, L) either tensor or list of strings
        * backbones: (batch, L*n_aa, 3): assume batch=1 (could be extended (?not tested)).
                     Coords for (N-term, C-alpha, C-term, (c_beta)) of every aa.
        * atom_mask: (14,). int or bool tensor specifying which atoms are passed.
        * cloud_mask: (batch, l, c). optional. cloud mask from `scn_cloud_mask`.
                      sets point outside of mask to 0. if passed, else c_alpha
        * padding_tok: int. padding token. same as in sidechainnet: 20
        Outputs: whole coordinates of shape (batch, L, 14, 3)
        Raises: TypeError if an entry of seqs is neither a tensor nor a str.
    """
    atom_mask = atom_mask.bool().cpu().detach()
    cum_atom_mask = atom_mask.cumsum(dim=-1).tolist()
    atom_flags = atom_mask.tolist()

    device = backbones.device
    batch, length = backbones.shape[0], backbones.shape[1] // cum_atom_mask[-1]
    predicted = rearrange(backbones, 'b (l back) d -> b l back d', l=length)

    # every atom was passed: nothing to rebuild
    if cum_atom_mask[-1] == 14:
        return predicted

    # copy the provided atoms into their slot of the full container
    new_coords = torch.zeros(batch, length, constants.NUM_COORDS_PER_RES, 3)
    predicted = predicted.cpu() if predicted.is_cuda else predicted
    for i, atom in enumerate(atom_flags):
        if atom:
            new_coords[:, :, i] = predicted[:, :, cum_atom_mask[i]-1]

    for s, seq in enumerate(seqs):
        if isinstance(seq, torch.Tensor):
            padding = (seq == padding_tok).sum().item()
            seq_str = ''.join([VOCAB._int2char[aa] for aa in seq.cpu().numpy()[:-padding or None]])
        elif isinstance(seq, str):
            padding = 0
            seq_str = seq
        else:
            raise TypeError("each sequence must be a torch.Tensor or a str")
        # `-padding or None` evaluates to None when there is no padding
        scaffolds = mp_nerf.proteins.build_scaffolds_from_scn_angles(seq_str, angles=None, device="cpu")
        coords, _ = mp_nerf.proteins.sidechain_fold(wrapper = new_coords[s, :-padding or None].detach(),
                                                    **scaffolds, c_beta = cum_atom_mask[4]==5)
        # fill only the atoms that were not provided, and only for this
        # sequence (the previous code wrote them into every batch entry)
        for i, atom in enumerate(atom_flags):
            if not atom:
                new_coords[s, :-padding or None, i] = coords[:, i]

    new_coords = new_coords.to(device)
    if cloud_mask is not None:
        new_coords[torch.logical_not(cloud_mask)] = 0.

    # replace NaN atoms by the coordinates of another atom of the same residue
    # (the closing bracket of this indexing was missing in the previous code).
    # NOTE(review): the replacement atom index is taken modulo the coordinate
    # dimension (shape[-1] == 3), not the atoms-per-residue dimension - confirm.
    nan_mask = list(torch.nonzero(new_coords!=new_coords, as_tuple=True))
    new_coords[nan_mask[0], nan_mask[1], nan_mask[2]] = new_coords[nan_mask[0],
                                                                   nan_mask[1],
                                                                   (nan_mask[-2]+1) % new_coords.shape[-1]]
    return new_coords.to(device)
def center_distogram_torch(distogram, bins=DISTANCE_THRESHOLDS, min_t=1., center="mean", wide="std"):
    """ Returns the central estimate of a distogram and a confidence weight.
        Inputs:
        * distogram: (batch, N, N, B) where B is the number of buckets.
        * bins: (B,) containing the cutoffs for the different buckets
        * min_t: float. lower bound for distances. (currently unused)
        * center: one of ["mean", "median"]. measure of centrality.
        * wide: one of ["std", "var"]. measure of dispersion used for the
                weights. Any other value means no dispersion (zeros).
        Outputs:
        * central: (batch, N, N)
        * weights: (batch, N, N)
        Raises: ValueError if center is not "mean" or "median".
    """
    shape, device = distogram.shape, distogram.device
    # representative distance of every bucket: threshold minus half a bucket
    n_bins = ( bins - 0.5 * (bins[2] - bins[1]) ).to(device)
    n_bins[0] = 1.5
    n_bins[-1] = 1.33*bins[-1]
    # highest valid bucket index (this name was used but never defined)
    max_bin_allowed = torch.tensor(n_bins.shape[0] - 1, device=device).long()

    magnitudes = distogram.sum(dim=-1)
    if center == "median":
        cum_dist = torch.cumsum(distogram, dim=-1)
        medium = 0.5 * cum_dist[..., -1:]
        # drop only the searched-value axis so that the shape matches
        # the "mean" branch even when another axis has size 1
        central = torch.searchsorted(cum_dist, medium).squeeze(-1)
        central = n_bins[torch.min(central, max_bin_allowed)]
    elif center == "mean":
        central = (distogram * n_bins).sum(dim=-1) / magnitudes
    else:
        raise ValueError('center must be one of ["mean", "median"]')

    # estimates above the second-to-last threshold get no weight
    mask = (central <= bins[-2].item()).float()
    diag_idxs = np.arange(shape[-2])
    central = expand_dims_to(central, 3 - len(central.shape))
    central[:, diag_idxs, diag_idxs] *= 0.

    if wide == "var":
        dispersion = (distogram * (n_bins - central.unsqueeze(-1))**2).sum(dim=-1) / magnitudes
    elif wide == "std":
        dispersion = ((distogram * (n_bins - central.unsqueeze(-1))**2).sum(dim=-1) / magnitudes).sqrt()
    else:
        dispersion = torch.zeros_like(central, device=device)

    weights = mask / (1 + dispersion)
    # NaN weights (x != x) and the diagonal are zeroed
    weights[weights != weights] *= 0.
    weights[:, diag_idxs, diag_idxs] *= 0.
    return central, weights
def mds_torch(pre_dist_mat, weights=None, iters=10, tol=1e-5, eigen=False, verbose=2):
    """ Gets distance matrix. Outputs 3d. See below for wrapper.
        Assumes (for now) distogram is (N x N) and symmetric
        Inputs:
        * pre_dist_mat: ((batch), N, N) target distance matrix
        * weights: optional. pairwise weights for the stress. ones if None
        * iters: int. maximum number of update iterations
        * tol: float. stop once the mean improvement is at or below it
        * eigen: bool. return the low-rank SVD initialisation directly
                 (only honoured when no weights are passed)
        * verbose: int. 0 silent, 1 final message, >=2 every iteration
        Outputs:
        * best_3d_coords: (batch x 3 x N)
        * historic_stresses: (batch x steps)
    """
    device = pre_dist_mat.device
    pre_dist_mat = expand_dims_to(pre_dist_mat, length=(3 - len(pre_dist_mat.shape)))
    batch, N, _ = pre_dist_mat.shape
    diag_idxs = np.arange(N)
    his = [torch.tensor([np.inf]*batch, device=device)]

    # initialise from a low-rank SVD of the matrix built from the squared
    # distances relative to the first point
    D = pre_dist_mat**2
    M = 0.5 * (D[:, :1, :] + D[:, :, :1] - D)
    svds = [torch.svd_lowrank(mi) for mi in M]
    u = torch.stack([svd[0] for svd in svds], dim=0)
    s = torch.stack([svd[1] for svd in svds], dim=0)
    best_3d_coords = torch.bmm(u, torch.diag_embed(s).abs().sqrt())[..., :3]

    if eigen:
        if weights is None:
            return torch.transpose(best_3d_coords, -1, -2), torch.zeros_like(torch.stack(his, dim=0))
        if verbose:
            print("如果激活权重,则无法使用特征分解标志。回退到迭代方式")

    if weights is None:
        weights = torch.ones_like(pre_dist_mat)

    for i in range(iters):
        best_3d_coords = best_3d_coords.contiguous()
        dist_mat = torch.cdist(best_3d_coords, best_3d_coords, p=2).clone()
        stress = (weights * (dist_mat - pre_dist_mat)**2).sum(dim=(-1,-2)) * 0.5
        # avoid dividing by zero distances
        dist_mat[dist_mat <= 0] += 1e-7
        ratio = weights * (pre_dist_mat / dist_mat)
        B = -ratio
        B[:, diag_idxs, diag_idxs] += ratio.sum(dim=-1)

        coords = (1. / N * torch.matmul(B, best_3d_coords))
        dis = torch.norm(coords, dim=(-1, -2))

        if verbose >= 2:
            # this message previously contained a mis-encoded character
            print('迭代次数:%d,应力 %s' % (i, stress))
        if (his[-1] - stress / dis).mean() <= tol:
            if verbose:
                print('在迭代 %d 中以应力 %s 结束' % (i, stress / dis))
            break

        best_3d_coords = coords
        his.append(stress / dis)

    return torch.transpose(best_3d_coords, -1, -2), torch.stack(his, dim=0)
def mds_numpy(pre_dist_mat, weights=None, iters=10, tol=1e-5, eigen=False, verbose=2):
    """ Gets distance matrix. Outputs 3d. See below for wrapper.
        Assumes (for now) distogram is (N x N) and symmetric
        Inputs:
        * pre_dist_mat: ((batch), N, N) target distance matrix
        * weights: optional. pairwise weights. ones if None
        * iters: int. maximum number of update iterations
        * tol: float. stop once the mean improvement is at or below it
        * eigen: unused in the numpy backend
        * verbose: int. 0 silent, 1 final message, >=2 every iteration
        Outputs:
        * best_3d_coords: (batch x 3 x N)
        * historic_stress: (steps x batch)
    """
    if weights is None:
        weights = np.ones_like(pre_dist_mat)

    pre_dist_mat = expand_dims_to(pre_dist_mat, length=(3 - len(pre_dist_mat.shape)))
    batch, N, _ = pre_dist_mat.shape
    # one value per batch entry from the start: mixing a scalar with the
    # (batch,) arrays appended below made `np.array(his)` ragged
    his = [np.full(batch, np.inf)]
    best_stress = np.inf * np.ones(batch)
    # random start in [-1, 1)
    best_3d_coords = 2*np.random.rand(batch, 3, N) - 1

    for i in range(iters):
        dist_mat = np.linalg.norm(best_3d_coords[:, :, :, None] - best_3d_coords[:, :, None, :], axis=-3)
        stress = (( weights * (dist_mat - pre_dist_mat) )**2).sum(axis=(-1, -2)) * 0.5
        # avoid dividing by zero distances
        dist_mat[dist_mat == 0] = 1e-7
        ratio = weights * (pre_dist_mat / dist_mat)
        B = -ratio
        B[:, np.arange(N), np.arange(N)] += ratio.sum(axis=-1)

        coords = (1. / N * np.matmul(best_3d_coords, B))
        dis = np.linalg.norm(coords, axis=(-1, -2))

        if verbose >= 2:
            print('it: %d, stress %s' % (i, stress))
        if (best_stress - stress / dis).mean() <= tol:
            if verbose:
                print('breaking at iteration %d with stress %s' % (i,
                                                                   stress / dis))
            break

        best_3d_coords = coords
        best_stress = stress / dis
        his.append(best_stress)

    return best_3d_coords, np.array(his)
def get_dihedral_torch(c1, c2, c3, c4):
    """ Returns the dihedral angle in radians defined by four points.
        Uses the atan2 formula from:
        https://en.wikipedia.org/wiki/Dihedral_angle#In_polymer_physics
        Dot products are written as (a * b).sum(-1) since torch.dot
        does not broadcast.
        Inputs:
        * c1: (batch, 3) or (3,)
        * c2: (batch, 3) or (3,)
        * c3: (batch, 3) or (3,)
        * c4: (batch, 3) or (3,)
        Outputs: (batch,) or scalar tensor with the angle(s)
    """
    bond_12 = c2 - c1
    bond_23 = c3 - c2
    bond_34 = c4 - c3

    normal_first = torch.cross(bond_12, bond_23, dim=-1)
    normal_second = torch.cross(bond_23, bond_34, dim=-1)
    scaled_12 = torch.norm(bond_23, dim=-1, keepdim=True) * bond_12

    sin_term = (scaled_12 * normal_second).sum(dim=-1)
    cos_term = (normal_first * normal_second).sum(dim=-1)
    return torch.atan2(sin_term, cos_term)
def get_dihedral_numpy(c1, c2, c3, c4):
    """ Returns the dihedral angle in radians defined by four points.
        Uses the atan2 formula from:
        https://en.wikipedia.org/wiki/Dihedral_angle#In_polymer_physics
        Inputs:
        * c1: (batch, 3) or (3,)
        * c2: (batch, 3) or (3,)
        * c3: (batch, 3) or (3,)
        * c4: (batch, 3) or (3,)
        Outputs: (batch,) array or scalar with the angle(s)
    """
    bond_12 = c2 - c1
    bond_23 = c3 - c2
    bond_34 = c4 - c3

    normal_first = np.cross(bond_12, bond_23, axis=-1)
    normal_second = np.cross(bond_23, bond_34, axis=-1)
    scaled_12 = np.linalg.norm(bond_23, axis=-1, keepdims=True) * bond_12

    sin_term = (scaled_12 * normal_second).sum(axis=-1)
    cos_term = (normal_first * normal_second).sum(axis=-1)
    return np.arctan2(sin_term, cos_term)
def calc_phis_torch(pred_coords, N_mask, CA_mask, C_mask=None,
                    prop=True, verbose=0):
    """ Computes the phi dihedrals of a predicted backbone, or the
        proportion of them that is negative (used to pick between mirrors
        in the MDScaling wrapper).
        Each phi is the dihedral of (C{i-1}, N{i}, CA{i}, C{i}).
        Inputs:
        * pred_coords: (batch, 3, N) predicted coordinates
        * N_mask: (batch, N) boolean mask for N-term positions
        * CA_mask: (batch, N) boolean mask for C-alpha positions
        * C_mask: (batch, N) or None. boolean mask for C-term positions.
                  Taken as everything that is neither N nor CA if None.
        * prop: bool. whether to return the proportion of negative phis.
        * verbose: unused.
        Output: list with one tensor of phi angles per batch entry, or a
                (batch,) tensor with the proportions if prop is True.
        Note: the masks of the first batch entry are applied to every
              entry, since all prots in batch have same backbone
    """
    # work on a detached cpu copy shaped (batch, N, 3)
    coords = torch.transpose(pred_coords.detach(), -1, -2).cpu()

    N_mask = expand_dims_to(N_mask, 2 - len(N_mask.shape))
    CA_mask = expand_dims_to(CA_mask, 2 - len(CA_mask.shape))
    if C_mask is None:
        C_mask = torch.logical_not(torch.logical_or(N_mask, CA_mask))
    else:
        C_mask = expand_dims_to(C_mask, 2 - len(C_mask.shape))

    n_terms = coords[:, N_mask[0].squeeze()]
    c_alphas = coords[:, CA_mask[0].squeeze()]
    c_terms = coords[:, C_mask[0].squeeze()]

    phis = []
    for i in range(pred_coords.shape[0]):
        phis.append(get_dihedral_torch(c_terms[i, :-1],
                                       n_terms[i, 1:],
                                       c_alphas[i, 1:],
                                       c_terms[i, 1:]))

    if not prop:
        return phis
    return torch.stack([(angles < 0).float().mean() for angles in phis], dim=0)
def calc_phis_numpy(pred_coords, N_mask, CA_mask, C_mask=None,
                    prop=True, verbose=0):
    """ Computes the phi dihedrals of a predicted backbone, or the
        proportion of them that is negative (used to pick between mirrors
        in the MDScaling wrapper).
        Each phi is the dihedral of (C{i-1}, N{i}, CA{i}, C{i}).
        Inputs:
        * pred_coords: (batch, 3, N) predicted coordinates
        * N_mask: (N, ) boolean mask for N-term positions
        * CA_mask: (N, ) boolean mask for C-alpha positions
        * C_mask: (N, ) or None. boolean mask for C-term positions.
                  Taken as everything that is neither N nor CA if None.
        * prop: bool. whether to return the proportion of negative phis.
        * verbose: unused.
        Output: list with one array of phi angles per batch entry, or a
                (batch,) array with the proportions if prop is True.
    """
    coords = np.transpose(pred_coords, (0, 2, 1))
    n_terms = coords[:, N_mask.squeeze()]
    c_alphas = coords[:, CA_mask.squeeze()]

    if C_mask is None:
        # complement of the other two masks. The previous arithmetic form
        # (ones - N_mask - CA_mask) raises TypeError on boolean arrays.
        C_mask = np.logical_not(np.logical_or(N_mask, CA_mask))
    # squeezed like the other two masks
    c_terms = coords[:, np.asarray(C_mask).squeeze()]

    phis = [get_dihedral_numpy(c_terms[i, :-1],
                               n_terms[i, 1:],
                               c_alphas[i, 1:],
                               c_terms[i, 1:]) for i in range(pred_coords.shape[0])]

    if prop:
        return np.array([(angles < 0).mean() for angles in phis])
    return phis
def kabsch_torch(X, Y, cpu=True):
    """ Kabsch alignment of X into Y.
        Assumes X,Y are both (Dims x N_points). See below for wrapper.
        Inputs:
        * X: (Dims x N_points) coordinates to be rotated
        * Y: (Dims x N_points) reference coordinates
        * cpu: bool. whether to run the SVD on the cpu
        Outputs: (X_, Y_): X centered and rotated onto Y, and Y centered.
    """
    device = X.device
    # center both clouds at the origin
    X_ = X - X.mean(dim=-1, keepdim=True)
    Y_ = Y - Y.mean(dim=-1, keepdim=True)
    # covariance matrix (detached: the rotation is treated as a constant)
    C = torch.matmul(X_, Y_.t()).detach()
    if cpu:
        C = C.cpu()

    # pick the SVD by feature detection: comparing only the minor version
    # number sent torch 2.x down the deprecated `torch.svd` path
    if hasattr(torch, "linalg") and hasattr(torch.linalg, "svd"):
        V, S, W = torch.linalg.svd(C)
    else:
        V, S, W = torch.svd(C)
        W = W.t()  # torch.svd returns V, torch.linalg.svd returns V^H

    # flip the last axis if the result would be a reflection
    if (torch.det(V) * torch.det(W)) < 0.0:
        V[:, -1] = V[:, -1] * (-1)

    U = torch.matmul(V, W).to(device)
    X_ = torch.matmul(X_.t(), U).t()
    return X_, Y_
def kabsch_numpy(X, Y):
    """ Kabsch alignment of X into Y.
        Assumes X,Y are both (Dims x N_points). See below for wrapper.
        Outputs: (X_, Y_): X centered and rotated onto Y, and Y centered.
    """
    X_centered = X - X.mean(axis=-1, keepdims=True)
    Y_centered = Y - Y.mean(axis=-1, keepdims=True)

    covariance = np.dot(X_centered, Y_centered.transpose())
    V, S, W = np.linalg.svd(covariance)

    # flip the last axis if the result would be a reflection
    is_reflection = (np.linalg.det(V) * np.linalg.det(W)) < 0.0
    if is_reflection:
        S[-1] *= -1
        V[:, -1] *= -1

    rotation = np.dot(V, W)
    X_aligned = np.dot(X_centered.T, rotation).T
    return X_aligned, Y_centered
def distmat_loss_torch(X=None, Y=None, X_mat=None, Y_mat=None, p=2, q=2,
                       custom=None, distmat_mask=None, clamp=None):
    """ Calculates a loss on the distance matrix - no need to align structs.
        Inputs:
        * X: (N, d) tensor. the predicted structure. One of (X, X_mat) is needed.
        * X_mat: (N, N) tensor. the predicted distance matrix. Optional ()
        * Y: (N, d) tensor. the true structure. One of (Y, Y_mat) is needed.
        * Y_mat: (N, N) tensor. the true distance matrix. Optional ()
        * p: int. power for the distance calculation (2 for euclidean)
        * q: float. power for the scaling of the loss (2 for MSE, 1 for MAE, etc)
        * custom: func or None. custom loss over distance matrices.
                  ex: lambda x,y: 1 - 1/ (1 + ((x-y))**2) (1 is very bad. 0 is good)
                  (distmat_mask is not applied to a custom loss)
        * distmat_mask: (N, N) mask (boolean or weights for each ij pos). optional.
        * clamp: tuple of (min,max) values for clipping distance matrices. ex: (0,150)
        Outputs: scalar tensor with the mean loss.
    """
    assert (X is not None or X_mat is not None) and \
           (Y is not None or Y_mat is not None), "The true and predicted coords or dist mats must be provided"

    if X_mat is None:
        X = X.squeeze()
        X_mat = torch.cdist(X, X, p=p)
    if Y_mat is None:
        Y = Y.squeeze()
        Y_mat = torch.cdist(Y, Y, p=p)

    # clip the distances, as documented. The previous code clipped the raw
    # coordinates instead, which e.g. collapses every negative coordinate
    # to 0 for clamp=(0, 150).
    if clamp is not None:
        X_mat = torch.clamp(X_mat, *clamp)
        Y_mat = torch.clamp(Y_mat, *clamp)

    if distmat_mask is None:
        distmat_mask = torch.ones_like(Y_mat).bool()

    if custom is not None:
        return custom(X_mat.squeeze(), Y_mat.squeeze()).mean()

    loss = ( X_mat - Y_mat )**2
    if q != 2:
        # (d^2)^(q/2) == |d|^q
        loss = loss**(q/2)
    return loss[distmat_mask].mean()
def rmsd_torch(X, Y):
    """ Assumes x,y are both (B x D x N). See below for wrapper.
        Returns the square root of the squared deviation averaged over
        the last two axes, one value per batch entry.
    """
    squared_dev = (X - Y)**2
    return squared_dev.mean(dim=(-1, -2)).sqrt()
def rmsd_numpy(X, Y):
    """ Assumes x,y are both (B x D x N). See below for wrapper.
        Returns the square root of the squared deviation averaged over
        the last two axes, one value per batch entry.
    """
    squared_dev = (X - Y)**2
    mean_squared = squared_dev.mean(axis=(-1, -2))
    return np.sqrt(mean_squared)
def gdt_torch(X, Y, cutoffs, weights=None):
    """ Assumes x,y are both (B x D x N). see below for wrapper.
        * cutoffs is a list of `K` thresholds
        * weights is a list of `K` weights (1 x each threshold)
        Returns a (B,) tensor: the weighted mean, over the thresholds, of
        the fraction of points whose distance is within each threshold.
    """
    device = X.device
    if weights is None:
        # created on the device of the inputs (it used to be left on the
        # cpu, which fails when multiplying with a GPU tensor)
        weights = torch.ones(1, len(cutoffs), device=device)
    else:
        weights = torch.tensor([weights]).to(device)

    GDT = torch.zeros(X.shape[0], len(cutoffs), device=device)
    # per-point distance between both structures: (B, N)
    dist = ((X - Y)**2).sum(dim=1).sqrt()
    for i, cutoff in enumerate(cutoffs):
        GDT[:, i] = (dist <= cutoff).float().mean(dim=-1)
    return (GDT*weights).mean(-1)
def gdt_numpy(X, Y, cutoffs, weights=None):
    """ Assumes x,y are both (B x D x N). see below for wrapper.
        * cutoffs is a list of `K` thresholds
        * weights is a list of `K` weights (1 x each threshold)
        Returns a (B,) array: the weighted mean, over the thresholds, of
        the fraction of points whose distance is within each threshold.
    """
    n_cutoffs = len(cutoffs)
    weights = np.ones((1, n_cutoffs)) if weights is None else np.array([weights])

    # per-point distance between both structures: (B, N)
    point_dist = np.sqrt(((X - Y)**2).sum(axis=1))
    fractions = np.zeros((X.shape[0], n_cutoffs))
    for col, threshold in enumerate(cutoffs):
        fractions[:, col] = (point_dist <= threshold).mean(axis=-1)
    return (fractions*weights).mean(-1)
def tmscore_torch(X, Y):
    """ Assumes x,y are both (B x D x N). see below for wrapper.
        Returns a (B,) tensor with the mean of 1 / (1 + (d_i/d0)^2) over
        the points, where d0 depends on the number of points (at least 15).
    """
    n_points = max(15, X.shape[-1])
    d0 = 1.24 * (n_points - 15)**(1/3) - 1.8
    # per-point distance between both structures: (B, N)
    point_dist = ((X - Y)**2).sum(dim=1).sqrt()
    scaled = point_dist / d0
    return (1 / (1 + scaled**2)).mean(dim=-1)
def tmscore_numpy(X, Y):
    """ Assumes x,y are both (B x D x N). see below for wrapper.
        Returns a (B,) array with the mean of 1 / (1 + (d_i/d0)^2) over
        the points, where d0 depends on the number of points (at least 15).
    """
    n_points = max(15, X.shape[-1])
    d0 = 1.24 * np.cbrt(n_points - 15) - 1.8
    # per-point distance between both structures: (B, N)
    point_dist = np.sqrt(((X - Y)**2).sum(axis=1))
    scaled = point_dist / d0
    return (1 / (1 + scaled**2)).mean(axis=-1)
def mdscaling_torch(pre_dist_mat, weights=None, iters=10, tol=1e-5,
                    fix_mirror=True, N_mask=None, CA_mask=None, C_mask=None,
                    eigen=False, verbose=2):
    """ Handles the specifics of MDS for proteins (mirrors, ...)
        Runs `mds_torch` and, if fix_mirror is set, negates the last
        coordinate axis of every structure whose proportion of negative
        phi angles is below 0.5.
        Outputs: (preds, stresses) as returned by `mds_torch`.
    """
    preds, stresses = mds_torch(pre_dist_mat, weights=weights, iters=iters,
                                tol=tol, eigen=eigen, verbose=verbose)
    if fix_mirror:
        phi_ratios = calc_phis_torch(preds, N_mask, CA_mask, C_mask, prop=True)
        flip_idxs = torch.nonzero(phi_ratios < 0.5).view(-1)
        # mirror in place by negating the last axis
        preds[flip_idxs, -1] = (-1)*preds[flip_idxs, -1]
        if verbose == 2:
            print("Corrected mirror idxs:", flip_idxs)
    return preds, stresses
def mdscaling_numpy(pre_dist_mat, weights=None, iters=10, tol=1e-5,
                    fix_mirror=True, N_mask=None, CA_mask=None, C_mask=None, verbose=2):
    """ Handles the specifics of MDS for proteins (mirrors, ...)
        Runs `mds_numpy` and, if fix_mirror is set, negates the last
        coordinate axis of every structure whose proportion of negative
        phi angles is below 0.5.
        Outputs: (preds, stresses) as returned by `mds_numpy`.
    """
    preds, stresses = mds_numpy(pre_dist_mat, weights=weights, iters=iters,
                                tol=tol, verbose=verbose)
    if not fix_mirror:
        return preds, stresses

    phi_ratios = calc_phis_numpy(preds, N_mask, CA_mask, C_mask, prop=True)
    # decide per structure: testing the whole `phi_ratios` array at once is
    # ambiguous (and raises) for more than one structure
    for i, ratio in enumerate(phi_ratios):
        if ratio < 0.5:
            preds[i, -1] = (-1)*preds[i, -1]
            if verbose == 2:
                print("Corrected mirror in struct no.", i)
    return preds, stresses
def lddt_ca_torch(true_coords, pred_coords, cloud_mask, r_0=15.):
    """ Computes the lddt score for each C_alpha.
        https://academic.oup.com/bioinformatics/article/29/21/2722/195896
        Inputs:
        * true_coords: (b, l, c, d) in sidechainnet format.
        * pred_coords: (b, l, c, d) in sidechainnet format.
        * cloud_mask : (b, l, c) adapted for scn format.
        * r_0: float. maximum inclusion radius in reference struct.
        Outputs:
        * (b, l) lddt for c_alpha scores (ranging between 0 and 1).
          Residues without a C_alpha in cloud_mask keep a score of 0.
          A C_alpha with no other C_alpha within r_0 yields NaN (0 / 0).
        See wrapper below.
    """
    device, dtype = true_coords.device, true_coords.type()
    thresholds = torch.tensor([0.5, 1, 2, 4], device=device).type(dtype)
    n_thresholds = thresholds.shape[0]
    # keep the mask on the device of the coordinates: it is combined with
    # `c_alpha_mask` and used to index them (it used to be forced to the cpu)
    cloud_mask = cloud_mask.bool().to(device)
    # C_alpha sits at position 1 of every residue
    c_alpha_mask = torch.zeros(cloud_mask.shape[1:], device=device).bool()
    c_alpha_mask[..., 1] = True

    wrapper = torch.zeros(true_coords.shape[:2], device=device).type(dtype)
    for bi in range(true_coords.shape[0]):
        c_alphas = cloud_mask[bi] & c_alpha_mask
        selected_pred = pred_coords[bi, c_alphas, :]
        selected_target = true_coords[bi, c_alphas, :]
        dist_mat_pred = torch.cdist(selected_pred, selected_pred, p=2)
        dist_mat_target = torch.cdist(selected_target, selected_target, p=2)

        # only pairs that are close in the reference structure are scored
        under_r0_target = dist_mat_target < r_0
        compare_dists = torch.abs(dist_mat_pred - dist_mat_target)[under_r0_target]

        # every pair earns one point per threshold its deviation stays within
        score = torch.zeros_like(under_r0_target).float()
        max_score = torch.zeros_like(under_r0_target).float()
        max_score[under_r0_target] = float(n_thresholds)
        score[under_r0_target] = n_thresholds - \
            torch.bucketize(compare_dists, boundaries=thresholds).float()

        # subtract the self-pair (always a full score) from both sums
        l_mask = c_alphas.float().sum(dim=-1).bool()
        wrapper[bi, l_mask] = ( score.sum(dim=-1) - n_thresholds ) / \
                              ( max_score.sum(dim=-1) - n_thresholds )
    return wrapper
@set_backend_kwarg
@invoke_torch_or_numpy(mdscaling_torch, mdscaling_numpy)
def MDScaling(pre_dist_mat, **kwargs):
    """ Gets distance matrix (-ces). Outputs 3d.
        Assumes (for now) distogram is (N x N) and symmetric.
        For support of distograms: see `center_distogram_torch()`
        This body only batches the input; the decorators presumably forward
        the returned (args, kwargs) to mdscaling_torch / mdscaling_numpy.
        Inputs:
        * pre_dist_mat: (1, N, N) distance matrix.
        * weights: optional. (N x N) pairwise relative weights .
        * iters: number of iterations to run the algorithm on
        * tol: relative tolerance at which to stop the algorithm if no better
               improvement is achieved
        * backend: one of ["numpy", "torch", "auto"] for backend choice
        * fix_mirror: whether to correct mirrored structures (the ones with
                      less than half of their phi angles negative)
        * N_mask: indexing array/tensor for indices of backbone N.
                  Only used if fix_mirror is set.
        * CA_mask: indexing array/tensor for indices of backbone C_alpha.
                   Only used if fix_mirror is set.
        * verbose: whether to print logs
        Outputs:
        * best_3d_coords: (3 x N)
        * historic_stress: (timesteps, )
    """
    missing_dims = 3 - len(pre_dist_mat.shape)
    batched_dist_mat = expand_dims_to(pre_dist_mat, missing_dims)
    return batched_dist_mat, kwargs
@expand_arg_dims(dim_len = 2)
@set_backend_kwarg
@invoke_torch_or_numpy(kabsch_torch, kabsch_numpy)
def Kabsch(A, B):
    """ Kabsch alignment of A into B.
        Adapted from https://github.com/charnley/rmsd/
        This body only hands the inputs over; the decorators presumably
        dispatch them to kabsch_torch / kabsch_numpy.
        * Inputs:
            * A,B are (3 x N)
            * backend: one of ["numpy", "torch", "auto"] for backend choice
        * Outputs: tensor/array of shape (3 x N)
    """
    return A, B
# The backend-dispatch decorators were missing here, so the function returned
# its inputs instead of the documented score; they follow the pattern used by
# `MDScaling` and `Kabsch`.
@expand_arg_dims(dim_len = 3)
@set_backend_kwarg
@invoke_torch_or_numpy(rmsd_torch, rmsd_numpy)
def RMSD(A, B):
    """ Returns RMSD score as defined here (lower is better):
        https://en.wikipedia.org/wiki/
        Root-mean-square_deviation_of_atomic_positions
        This body only hands the inputs over; the decorators presumably
        dispatch them to rmsd_torch / rmsd_numpy.
        * Inputs:
            * A,B are (B x 3 x N) or (3 x N)
            * backend: one of ["numpy", "torch", "auto"] for backend choice
        * Outputs: tensor/array of size (B,)
    """
    return A, B
# The backend-dispatch decorators were missing here, so the function returned
# its inputs instead of the documented score; they follow the pattern used by
# `MDScaling` and `Kabsch`.
@expand_arg_dims(dim_len = 3)
@set_backend_kwarg
@invoke_torch_or_numpy(gdt_torch, gdt_numpy)
def GDT(A, B, *, mode="TS", cutoffs=None, weights=None):
    """ Returns GDT score as defined here (higher is better):
        Supports both TS and HA
        http://predictioncenter.org/casp12/doc/help.html
        This body only prepares the arguments; the decorators presumably
        dispatch them to gdt_torch / gdt_numpy.
        * Inputs:
            * A,B are (B x 3 x N) (np.array or torch.tensor)
            * mode: one of ["TS", "HA"]. selects the default thresholds:
                    [1,2,4,8] for TS and [0.5,1,2,4] for HA.
            * cutoffs: optional. thresholds for gdt. overrides the ones
                       selected by `mode` when passed.
            * weights: list containing the weights
            * backend: one of ["numpy", "torch", "auto"] for backend choice
        * Outputs: tensor/array of size (B,)
    """
    # `cutoffs` used to be a mutable default that was always overwritten
    # by `mode`; an explicit value is now honoured.
    if cutoffs is None:
        cutoffs = [0.5,1,2,4] if mode in ["HA", "ha"] else [1,2,4,8]
    return A, B, cutoffs, {'weights': weights}
# The backend-dispatch decorators were missing here, so the function returned
# its inputs instead of the documented score; they follow the pattern used by
# `MDScaling` and `Kabsch`.
@expand_arg_dims(dim_len = 3)
@set_backend_kwarg
@invoke_torch_or_numpy(tmscore_torch, tmscore_numpy)
def TMscore(A, B):
    """ Returns TMscore as defined here (higher is better):
        >0.5 (likely) >0.6 (highly likely) same folding.
        = 0.2. https://en.wikipedia.org/wiki/Template_modeling_score
        Warning! It's not exactly the code in:
        https://zhanglab.ccmb.med.umich.edu/TM-score/TMscore.cpp
        but will suffice for now.
        This body only hands the inputs over; the decorators presumably
        dispatch them to tmscore_torch / tmscore_numpy.
        Inputs:
            * A,B are (B x 3 x N) (np.array or torch.tensor)
            * backend: one of ["numpy", "torch", "auto"] for backend choice
        Outputs: tensor/array of size (B,)
    """
    return A, B