Lucidrains 系列项目源码解析(三)
.\lucidrains\alphafold2\alphafold2_pytorch\__init__.py
# 从 alphafold2_pytorch.alphafold2 模块中导入 Alphafold2 和 Evoformer 类
from alphafold2_pytorch.alphafold2 import Alphafold2, Evoformer

Alphafold2 - Pytorch (wip)
To eventually become an unofficial working Pytorch implementation of Alphafold2, the breathtaking attention network that solved CASP14. Will be gradually implemented as more details of the architecture are released.
Once this is replicated, I intend to fold all available amino acid sequences out there in-silico and release it as an academic torrent, to further science. If you are interested in replication efforts, please drop by #alphafold at this Discord channel
Update: Deepmind has open sourced the official code in Jax, along with the weights 🙏! This repository will now be geared towards a straight pytorch translation with some improvements on positional encoding
Install
$ pip install alphafold2-pytorch
Status
lhatsk has reported training a modified trunk of this repository, using the same setup as trRosetta, with competitive results

blue used the trRosetta input (MSA -> potts -> axial attention), green used the ESM embedding (only sequence) -> tiling -> axial attention - lhatsk
Usage
Predicting distogram, like Alphafold-1, but with attention
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64,
reversible = False # set this to True for fully reversible self / cross attention for the trunk
).cuda()
seq = torch.randint(0, 21, (1, 128)).cuda() # AA length of 128
msa = torch.randint(0, 21, (1, 5, 120)).cuda() # MSA doesn't have to be the same length as primary sequence
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
distogram = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask
) # (1, 128, 128, 37)
You can also turn on prediction for the angles, by passing a predict_angles = True on init. The below example would be equivalent to trRosetta but with self / cross attention.
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64,
predict_angles = True # set this to True
).cuda()
seq = torch.randint(0, 21, (1, 128)).cuda()
msa = torch.randint(0, 21, (1, 5, 120)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
distogram, theta, phi, omega = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask
)
# distogram - (1, 128, 128, 37),
# theta - (1, 128, 128, 25),
# phi - (1, 128, 128, 13),
# omega - (1, 128, 128, 25)
Predicting Coordinates
Fabian's recent paper suggests iteratively feeding the coordinates back into SE3 Transformer, weight shared, may work. I have decided to execute based on this idea, even though it is still up in the air how it actually works.
You can also use E(n)-Transformer or EGNN for structural refinement.
Update: Baker's lab have shown that an end-to-end architecture from sequence and MSA embeddings to SE3 Transformers can best trRosetta and close the gap to Alphafold2. We will be using the Graph Transformer, which acts on the trunk embeddings, to generate the initial set of coordinates to be sent to the equivariant network. (This is further corroborated by Costa et al in their work teasing out 3d coordinates from MSA Transformer embeddings in a paper predating Baker lab's)
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64,
predict_coords = True,
structure_module_type = 'se3', # use SE3 Transformer - if set to False, will use E(n)-Transformer, Victor and Max Welling's new paper
structure_module_dim = 4, # se3 transformer dimension
structure_module_depth = 1, # depth
structure_module_heads = 1, # heads
structure_module_dim_head = 16, # dimension of heads
structure_module_refinement_iters = 2, # number of equivariant coordinate refinement iterations
structure_num_global_nodes = 1 # number of global nodes for the structure module, only works with SE3 transformer
).cuda()
seq = torch.randint(0, 21, (2, 64)).cuda()
msa = torch.randint(0, 21, (2, 5, 60)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
coords = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask
) # (2, 64 * 3, 3) <-- 3 atoms per residue
Atoms
The underlying assumption is that the trunk works on the residue level, and its output is then expanded to the atomic level for the structure module, whether it be SE3 Transformers, E(n)-Transformer, or EGNN doing the refinement. This library defaults to the 3 backbone atoms (C, Ca, N), but you can configure it to include any other atom you like, including Cb and the sidechains.
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64,
predict_coords = True,
atoms = 'backbone-with-cbeta'
).cuda()
seq = torch.randint(0, 21, (2, 64)).cuda()
msa = torch.randint(0, 21, (2, 5, 60)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
coords = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask
) # (2, 64 * 4, 3) <-- 4 atoms per residue (C, Ca, N, Cb)
Valid choices for atoms include:
- backbone - 3 backbone atoms (C, Ca, N) [default]
- backbone-with-cbeta - 3 backbone atoms and C beta
- backbone-with-oxygen - 3 backbone atoms and oxygen from carboxyl
- backbone-with-cbeta-and-oxygen - 3 backbone atoms with C beta and oxygen
- all - backbone and all other atoms from sidechain
You can also pass in a tensor of shape (14,) defining which atoms you would like to include
ex.
atoms = torch.tensor([1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
MSA, ESM, or ProtTrans Embeddings
This repository offers you an easy way to supplement the network with pre-trained embeddings from Facebook AI. It contains wrappers for the pre-trained ESM, MSA Transformers or Protein Transformer.
There are some prerequisites. You will need to make sure that you have Nvidia's apex library installed, as the pretrained transformers make use of some fused operations.
Or you can try running the script below
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
Next, you will simply have to import and wrap your Alphafold2 instance with a ESMEmbedWrapper, MSAEmbedWrapper, or ProtTranEmbedWrapper and it will take care of embedding both the sequence and the multiple-sequence alignments for you (and projecting it to the dimensions as specified on your model). Nothing needs to be changed save for adding the wrapper.
import torch
from alphafold2_pytorch import Alphafold2
from alphafold2_pytorch.embeds import MSAEmbedWrapper
alphafold2 = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64
)
model = MSAEmbedWrapper(
alphafold2 = alphafold2
).cuda()
seq = torch.randint(0, 21, (2, 16)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa = torch.randint(0, 21, (2, 5, 16)).cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
distogram = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask
)
By default, even if the wrapper supplies the trunk with the sequence and MSA embeddings, they would be summed with the usual token embeddings. If you want to train Alphafold2 without token embeddings (only rely on pretrained embeddings), you would need to set disable_token_embed to True on Alphafold2 init.
alphafold2 = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64,
disable_token_embed = True
)
Real-Value Distance Prediction
A paper by Jinbo Xu suggests that one doesn't need to bin the distances, and can instead predict the mean and standard deviation directly. You can use this by turning on one flag predict_real_value_distances, in which case, the distance prediction returned will have a dimension of 2 for the mean and standard deviation respectively.
If predict_coords is also turned on, then the MDS will accept the mean and standard deviation predictions directly without having to calculate that from the distogram bins.
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64,
predict_coords = True,
predict_real_value_distances = True, # set this to True
structure_module_type = 'se3',
structure_module_dim = 4,
structure_module_depth = 1,
structure_module_heads = 1,
structure_module_dim_head = 16,
structure_module_refinement_iters = 2
).cuda()
seq = torch.randint(0, 21, (2, 64)).cuda()
msa = torch.randint(0, 21, (2, 5, 60)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
coords = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask
) # (2, 64 * 3, 3) <-- 3 atoms per residue
Convolutions
You can add convolutional blocks, for both the primary sequence as well as the MSA, by simply setting one extra keyword argument use_conv = True
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64,
use_conv = True # set this to True
).cuda()
seq = torch.randint(0, 21, (1, 128)).cuda()
msa = torch.randint(0, 21, (1, 5, 120)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
distogram = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask
) # (1, 128, 128, 37)
The convolutional kernels follow the lead of this paper, combining 1d and 2d kernels in one resnet-like block. You can fully customize the kernels as such.
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64,
use_conv = True, # set this to True
conv_seq_kernels = ((9, 1), (1, 9), (3, 3)), # kernels for N x N primary sequence
conv_msa_kernels = ((1, 9), (3, 3)), # kernels for {num MSAs} x N MSAs
).cuda()
seq = torch.randint(0, 21, (1, 128)).cuda()
msa = torch.randint(0, 21, (1, 5, 120)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
distogram = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask
) # (1, 128, 128, 37)
You can also do cycle dilation with one extra keyword argument. Default dilation is 1 for all layers.
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64,
use_conv = True, # set this to True
dilations = (1, 3, 5) # cycle between dilations of 1, 3, 5
).cuda()
seq = torch.randint(0, 21, (1, 128)).cuda()
msa = torch.randint(0, 21, (1, 5, 120)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
distogram = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask
) # (1, 128, 128, 37)
Finally, instead of following the pattern of convolutions, self-attention, cross-attention per depth repeating, you can customize any order you wish with the custom_block_types keyword
ex. A network where you do predominately convolutions first, followed by self-attention + cross-attention blocks
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
heads = 8,
dim_head = 64,
custom_block_types = (
*(('conv',) * 6),
*(('self', 'cross') * 6)
)
).cuda()
seq = torch.randint(0, 21, (1, 128)).cuda()
msa = torch.randint(0, 21, (1, 5, 120)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
distogram = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask
) # (1, 128, 128, 37)
Sparse Attention
You can train with Microsoft Deepspeed's Sparse Attention, but you will have to endure the installation process. It takes two steps.
First, you need to install Deepspeed with Sparse Attention
$ sh install_deepspeed.sh
Next, you need to install the pip package triton
$ pip install triton
If both of the above succeeded, now you can train with Sparse Attention!
Sadly, the sparse attention is only supported for self attention, and not cross attention. I will bring in a different solution for making cross attention performant.
model = Alphafold2(
dim = 256,
depth = 12,
heads = 8,
dim_head = 64,
max_seq_len = 2048, # the maximum sequence length, this is required for sparse attention. the input cannot exceed what is set here
sparse_self_attn = (True, False) * 6 # interleave sparse and full attention for all 12 layers
).cuda()
Linear Attention
I have also added one of the best linear attention variants, in the hope of lessening the burden of cross attending. I personally have not found Performer to work that well, but since in the paper they reported some ok numbers for protein benchmarks, I thought I'd include it and allow others to experiment.
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64,
cross_attn_linear = True # simply set this to True to use Performer for all cross attention
).cuda()
You can also specify the exact layers you wish to use linear attention by passing in a tuple of the same length as the depth
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 6,
heads = 8,
dim_head = 64,
cross_attn_linear = (True, False) * 3 # interleave linear and full attention
).cuda()
Kronecker Attention for Cross Attention
This paper suggests that if you have queries or contexts that have defined axials (say an image), you can reduce the amount of attention needed by averaging across those axials (height and width) and concatenating the averaged axials into one sequence. You can turn this on as a memory saving technique for the cross attention, specifically for the primary sequence.
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 6,
heads = 8,
dim_head = 64,
cross_attn_kron_primary = True # make sure primary sequence undergoes the kronecker operator during cross attention
).cuda()
You can also apply the same operator to the MSAs during cross attention with the cross_attn_kron_msa flag, if your MSAs are aligned and of the same width.
Todo
- offer masked mean reduction method
- rotary embeddings
Memory Compressed Attention
To save on memory for cross attention, you can set a compression ratio for the key / values, following the scheme laid out in this paper. A compression ratio of 2-4 is usually acceptable.
model = Alphafold2(
dim = 256,
depth = 12,
heads = 8,
dim_head = 64,
cross_attn_compress_ratio = 3
).cuda()
MSA processing in Trunk

A new paper by Roshan Rao proposes using axial attention for pretraining on MSA's. Given the strong results, this repository will use the same scheme in the trunk, specifically for the MSA self-attention.
You can also tie the row attentions of the MSA with the msa_tie_row_attn = True setting on initialization of Alphafold2. However, in order to use this, you must make sure that if you have uneven number of MSAs per primary sequence, that the MSA mask is properly set to False for the rows not in use.
model = Alphafold2(
dim = 256,
depth = 2,
heads = 8,
dim_head = 64,
msa_tie_row_attn = True # just set this to true
)
Template processing in Trunk
Template processing is also largely done with axial attention, with cross attention done along the number of templates dimension. This largely follows the same scheme as in the recent all-attention approach to video classification as shown here.
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 5,
heads = 8,
dim_head = 64,
reversible = True,
sparse_self_attn = False,
max_seq_len = 256,
cross_attn_compress_ratio = 3
).cuda()
seq = torch.randint(0, 21, (1, 16)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa = torch.randint(0, 21, (1, 10, 16)).cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
templates_seq = torch.randint(0, 21, (1, 2, 16)).cuda()
templates_coors = torch.randint(0, 37, (1, 2, 16, 3)).cuda()
templates_mask = torch.ones_like(templates_seq).bool().cuda()
distogram = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask,
templates_seq = templates_seq,
templates_coors = templates_coors,
templates_mask = templates_mask
)
If sidechain information is also present, in the form of the unit vector between the C and C-alpha coordinates of each residue, you can also pass it in as follows.
import torch
from alphafold2_pytorch import Alphafold2
model = Alphafold2(
dim = 256,
depth = 5,
heads = 8,
dim_head = 64,
reversible = True,
sparse_self_attn = False,
max_seq_len = 256,
cross_attn_compress_ratio = 3
).cuda()
seq = torch.randint(0, 21, (1, 16)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa = torch.randint(0, 21, (1, 10, 16)).cuda()
msa_mask = torch.ones_like(msa).bool().cuda()
templates_seq = torch.randint(0, 21, (1, 2, 16)).cuda()
templates_coors = torch.randn(1, 2, 16, 3).cuda()
templates_mask = torch.ones_like(templates_seq).bool().cuda()
templates_sidechains = torch.randn(1, 2, 16, 3).cuda() # unit vectors of difference of C and C-alpha coordinates
distogram = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask,
templates_seq = templates_seq,
templates_mask = templates_mask,
templates_coors = templates_coors,
templates_sidechains = templates_sidechains
)
Equivariant Attention
I have prepared a reimplementation of SE3 Transformer, as explained by Fabian Fuchs in a speculatory blogpost.
In addition, a new paper from Victor and Welling uses invariant features for E(n) equivariance, reaching SOTA and outperforming SE3 Transformer at a number of benchmarks, while being much faster. I have taken the main ideas from this paper and modified it to become a transformer (added attention to both features and coordinate updates).
All three of the equivariant networks above have been integrated and are available for use in the repository for atomic coordinate refinement by simply setting one hyperparameter structure_module_type.
- se3 - SE3 Transformer
- egnn - EGNN
Of interest to readers, each of the three frameworks have also been validated by researchers on related problems.
Testing
$ python setup.py test
Data
This library will use the awesome work by Jonathan King at this repository. Thank you Jonathan 🙏!
We also have the MSA data, all ~3.5 TB worth, downloaded and hosted by Archivist, who owns The-Eye project. (They also host the data and models for Eleuther AI) Please consider a donation if you find them helpful.
$ curl -s https://the-eye.eu/eleuther_staging/globus_stuffs/tree.txt
Speculation
moalquraishi.wordpress.com/2020/12/08/…


Recent works by competing labs
pubmed.ncbi.nlm.nih.gov/33637700/
tFold presentation, from Tencent AI labs
External packages
- Final step - Fast Relax - Installation Instructions:
- Download the pyrosetta wheel from: www.pyrosetta.org/dow (select appropriate version) - beware the file is heavy (approx 1.2 Gb)
- The download should be free for anyone with an academic email
- Bash > cd downloads_folder > pip install pyrosetta_wheel_filename.whl
- Download the pyrosetta wheel from: www.pyrosetta.org/dow (select appropriate version) - beware the file is heavy (approx 1.2 Gb)
Citations
@misc{unpublished2021alphafold2,
title = {Alphafold2},
author = {John Jumper},
year = {2020},
archivePrefix = {arXiv},
primaryClass = {q-bio.BM}
}
@article{Rao2021.02.12.430858,
author = {Rao, Roshan and Liu, Jason and Verkuil, Robert and Meier, Joshua and Canny, John F. and Abbeel, Pieter and Sercu, Tom and Rives, Alexander},
title = {MSA Transformer},
year = {2021},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2021/02/13/2021.02.12.430858},
journal = {bioRxiv}
}
@article {Rives622803,
author = {Rives, Alexander and Goyal, Siddharth and Meier, Joshua and Guo, Demi and Ott, Myle and Zitnick, C. Lawrence and Ma, Jerry and Fergus, Rob},
title = {Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences},
year = {2019},
doi = {10.1101/622803},
publisher = {Cold Spring Harbor Laboratory},
journal = {bioRxiv}
}
@article {Elnaggar2020.07.12.199554,
author = {Elnaggar, Ahmed and Heinzinger, Michael and Dallago, Christian and Rehawi, Ghalia and Wang, Yu and Jones, Llion and Gibbs, Tom and Feher, Tamas and Angerer, Christoph and Steinegger, Martin and BHOWMIK, DEBSINDHU and Rost, Burkhard},
title = {ProtTrans: Towards Cracking the Language of Life{\textquoteright}s Code Through Self-Supervised Deep Learning and High Performance Computing},
elocation-id = {2020.07.12.199554},
year = {2021},
doi = {10.1101/2020.07.12.199554},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2021/05/04/2020.07.12.199554},
eprint = {https://www.biorxiv.org/content/early/2021/05/04/2020.07.12.199554.full.pdf},
journal = {bioRxiv}
}
@misc{king2020sidechainnet,
title = {SidechainNet: An All-Atom Protein Structure Dataset for Machine Learning},
author = {Jonathan E. King and David Ryan Koes},
year = {2020},
eprint = {2010.08162},
archivePrefix = {arXiv},
primaryClass = {q-bio.BM}
}
@misc{alquraishi2019proteinnet,
title = {ProteinNet: a standardized data set for machine learning of protein structure},
author = {Mohammed AlQuraishi},
year = {2019},
eprint = {1902.00249},
archivePrefix = {arXiv},
primaryClass = {q-bio.BM}
}
@misc{gomez2017reversible,
title = {The Reversible Residual Network: Backpropagation Without Storing Activations},
author = {Aidan N. Gomez and Mengye Ren and Raquel Urtasun and Roger B. Grosse},
year = {2017},
eprint = {1707.04585},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@misc{fuchs2021iterative,
title = {Iterative SE(3)-Transformers},
author = {Fabian B. Fuchs and Edward Wagstaff and Justas Dauparas and Ingmar Posner},
year = {2021},
eprint = {2102.13419},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@misc{satorras2021en,
title = {E(n) Equivariant Graph Neural Networks},
author = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
year = {2021},
eprint = {2102.09844},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@misc{su2021roformer,
title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
year = {2021},
eprint = {2104.09864},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}
@article{Gao_2020,
title = {Kronecker Attention Networks},
ISBN = {9781450379984},
url = {http://dx.doi.org/10.1145/3394486.3403065},
DOI = {10.1145/3394486.3403065},
journal = {Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining},
publisher = {ACM},
author = {Gao, Hongyang and Wang, Zhengyang and Ji, Shuiwang},
year = {2020},
month = {Jul}
}
@article {Si2021.05.10.443415,
author = {Si, Yunda and Yan, Chengfei},
title = {Improved protein contact prediction using dimensional hybrid residual networks and singularity enhanced loss function},
elocation-id = {2021.05.10.443415},
year = {2021},
doi = {10.1101/2021.05.10.443415},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2021/05/11/2021.05.10.443415},
eprint = {https://www.biorxiv.org/content/early/2021/05/11/2021.05.10.443415.full.pdf},
journal = {bioRxiv}
}
@article {Costa2021.06.02.446809,
author = {Costa, Allan and Ponnapati, Manvitha and Jacobson, Joseph M. and Chatterjee, Pranam},
title = {Distillation of MSA Embeddings to Folded Protein Structures with Graph Transformers},
year = {2021},
doi = {10.1101/2021.06.02.446809},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2021/06/02/2021.06.02.446809},
eprint = {https://www.biorxiv.org/content/early/2021/06/02/2021.06.02.446809.full.pdf},
journal = {bioRxiv}
}
@article {Baek2021.06.14.448402,
author = {Baek, Minkyung and DiMaio, Frank and Anishchenko, Ivan and Dauparas, Justas and Ovchinnikov, Sergey and Lee, Gyu Rie and Wang, Jue and Cong, Qian and Kinch, Lisa N. and Schaeffer, R. Dustin and Mill{\'a}n, Claudia and Park, Hahnbeom and Adams, Carson and Glassman, Caleb R. and DeGiovanni, Andy and Pereira, Jose H. and Rodrigues, Andria V. and van Dijk, Alberdina A. and Ebrecht, Ana C. and Opperman, Diederik J. and Sagmeister, Theo and Buhlheller, Christoph and Pavkov-Keller, Tea and Rathinaswamy, Manoj K and Dalwadi, Udit and Yip, Calvin K and Burke, John E and Garcia, K. Christopher and Grishin, Nick V. and Adams, Paul D. and Read, Randy J. and Baker, David},
title = {Accurate prediction of protein structures and interactions using a 3-track network},
year = {2021},
doi = {10.1101/2021.06.14.448402},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2021/06/15/2021.06.14.448402},
eprint = {https://www.biorxiv.org/content/early/2021/06/15/2021.06.14.448402.full.pdf},
journal = {bioRxiv}
}
.\lucidrains\alphafold2\scripts\refinement.py
# 导入所需的库和模块
import os
import json
import warnings
# 科学计算库
import numpy as np
# 尝试导入 pyrosetta 模块,如果导入失败则发出警告
try:
import pyrosetta
except ModuleNotFoundError:
msg = "Unable to find an existing installation of the PyRosetta module. " +\
"Functions involving this module such as the FastRelax pipeline " +\
"will not work."
warnings.warn(msg) # no pyRosetta was found
#####################
### ROSETTA STUFF ###
#####################
def pdb2rosetta(route):
    """ Loads pdb file(s) from disk as rosetta pose(s).
        Input:
        * route: list or string. path(s) of the pdb file(s) to read.
        Output: list with one pose per input path (a 1-element list
                when a single string is given).
    """
    # Anything that is not a plain string is handed to the batch loader.
    if not isinstance(route, str):
        return list(pyrosetta.io.poses_from_files(route))
    # A single path: load it and wrap it so the return type is always a list.
    single_pose = pyrosetta.io.pose_from_pdb(route)
    return [single_pose]
def rosetta2pdb(pose, route, verbose=True):
    """ Takes pose(s) as input and saves pdb(s) to disk.
        Input:
        * pose: a single rosetta pose object, or a list/tuple of them.
        * route: list or string. destination filename(s) to be written.
        * verbose: bool. warns if lengths dont match and @ every write.
        Output: None. One pdb file is written per (pose, route) pair; if the
                two inputs differ in length, writing stops at the shorter one.
        Inspo:
        * https://www.rosettacommons.org/demos/latest/tutorials/input_and_output/input_and_output#controlling-output_common-structure-output-files_pdb-file
        * https://graylab.jhu.edu/PyRosetta.documentation/pyrosetta.rosetta.core.io.pdb.html#pyrosetta.rosetta.core.io.pdb.dump_pdb
    """
    # Normalise both inputs to sequences. A lone pose is never a `str`, so
    # wrap anything that is not already a list/tuple.
    if not isinstance(pose, (list, tuple)):
        pose = [pose]
    if isinstance(route, str):
        route = [route]
    # Warn about mismatched lengths before writing anything.
    if verbose and (len(pose) != len(route)):
        print("Length of pose and route are not the same. Will stop at the minimum.")
    # zip() stops at the shorter input, which is what the warning promises
    # (indexing route[i] used to raise IndexError instead).
    for pos, destination in zip(pose, route):
        pyrosetta.rosetta.core.io.pdb.dump_pdb(pos, destination)
        if verbose:
            # Report the individual file; concatenating the whole list to a
            # string raised TypeError.
            print("Saved structure @ " + destination)
    return
def run_fast_relax(config_route, pdb_route=None, pose=None):
    """ Runs the Fast-Relax pipeline.
        * config_route: route to json file with config (an already opened
                        file-like object is accepted as well)
        * pdb_route: optional string or list of strings. pdb file(s) to load;
                     a list yields one recursive call per file.
        * pose: rosetta pose to run the pipeline on
        Output: rosetta pose (list of results when `pdb_route` is a list)
        Raises: NotImplementedError, always, once the config is loaded -
                the relax step itself has not been written yet.
    """
    # Load the rosetta pose - if a string or list is passed, convert to
    # pose(s) and re-call. The config *route* is forwarded: `config` is not
    # defined yet at this point.
    if isinstance(pdb_route, str):
        # pdb2rosetta always returns a list; a single path gives one pose.
        pose = pdb2rosetta(pdb_route)[0]
        return run_fast_relax(config_route, pose=pose)
    elif isinstance(pdb_route, list):
        return [run_fast_relax(config_route, pdb_route=pdb) for pdb in pdb_route]
    # Load the config. json.load needs a file object, so open the path
    # unless the caller already handed over something readable.
    if hasattr(config_route, "read"):
        config = json.load(config_route)
    else:
        with open(config_route) as config_file:
            config = json.load(config_file)
    # Run the Fast-Relax pipeline - examples:
    # https://colab.research.google.com/github/RosettaCommons/PyRosetta.notebooks/blob/master/notebooks/06.02-Packing-design-and-regional-relax.ipynb#scrollTo=PYr025Rn1Q8i
    # https://nbviewer.jupyter.org/github/RosettaCommons/PyRosetta.notebooks/blob/master/notebooks/06.03-Design-with-a-resfile-and-relax.ipynb
    # https://faculty.washington.edu/dimaio/files/demo2.py
    raise NotImplementedError("Last step. Not implemented yet.")
.\lucidrains\alphafold2\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# 设置包的元数据
setup(
# Distribution name
name = 'alphafold2-pytorch',
# Discover and include every package found by find_packages()
packages = find_packages(),
# Version number
version = '0.4.32',
# License
license='MIT',
# Short description
description = 'AlphaFold2 - Pytorch',
# Content type of the long description
long_description_content_type = 'text/markdown',
# Authors
author = 'Phil Wang, Eric Alcaide',
# Author emails
author_email = 'lucidrains@gmail.com, ericalcaide1@gmail.com',
# Project URL
url = 'https://github.com/lucidrains/alphafold2',
# Keywords
keywords = [
'artificial intelligence',
'attention mechanism',
'protein folding'
],
# Install-time dependencies
install_requires=[
'einops>=0.3',
'En-transformer>=0.2.3',
'invariant-point-attention',
'mdtraj>=1.8',
'numpy',
'proDy',
'pytorch3d',
'requests',
'sidechainnet',
'torch>=1.6',
'transformers',
'tqdm',
'biopython',
'mp-nerf>=0.1.5'
],
# Dependencies required to run setup itself
setup_requires=[
'pytest-runner',
],
# Dependencies required for the tests
tests_require=[
'pytest'
],
# Trove classifiers
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.7',
],
)
.\lucidrains\alphafold2\tests\test_attention.py
import torch
from torch import nn
from einops import repeat
from alphafold2_pytorch.alphafold2 import Alphafold2
from alphafold2_pytorch.utils import *
# 定义测试函数 test_main
def test_main():
    """Smoke test: forward pass with a primary sequence and an MSA."""
    net = Alphafold2(dim = 32, depth = 2, heads = 2, dim_head = 32)

    # random tokens for the primary sequence and the alignment, all unmasked
    primary = torch.randint(0, 21, (2, 128))
    alignment = torch.randint(0, 21, (2, 5, 128))
    primary_mask = torch.ones_like(primary).bool()
    alignment_mask = torch.ones_like(alignment).bool()

    # only checks that the forward pass completes without raising
    _ = net(primary, alignment, mask = primary_mask, msa_mask = alignment_mask)
    assert True
# 定义测试函数 test_no_msa
def test_no_msa():
    """Smoke test: forward pass with only the primary sequence (no MSA)."""
    net = Alphafold2(dim = 32, depth = 2, heads = 2, dim_head = 32)

    # random primary sequence with a fully-on mask
    primary = torch.randint(0, 21, (2, 128))
    primary_mask = torch.ones_like(primary).bool()

    # only checks that the forward pass completes without raising
    _ = net(primary, mask = primary_mask)
    assert True
# 定义测试函数 test_anglegrams
def test_anglegrams():
    """Smoke test: forward pass with angle prediction switched on."""
    net = Alphafold2(
        dim = 32,
        depth = 2,
        heads = 2,
        dim_head = 32,
        predict_angles = True,
    )

    # random sequence / MSA tokens, all positions unmasked
    primary = torch.randint(0, 21, (2, 128))
    alignment = torch.randint(0, 21, (2, 5, 128))
    primary_mask = torch.ones_like(primary).bool()
    alignment_mask = torch.ones_like(alignment).bool()

    # only checks that the forward pass completes without raising
    _ = net(primary, alignment, mask = primary_mask, msa_mask = alignment_mask)
    assert True
# 定义测试函数 test_templates
def test_templates():
    """Smoke test: forward pass with template features, angles and mask."""
    net = Alphafold2(
        dim = 32,
        depth = 2,
        heads = 2,
        dim_head = 32,
        templates_dim = 32,
        templates_angles_feats_dim = 32,
    )

    # random sequence and MSA inputs
    primary = torch.randint(0, 21, (2, 16))
    primary_mask = torch.ones_like(primary).bool()
    alignment = torch.randint(0, 21, (2, 5, 16))
    alignment_mask = torch.ones_like(alignment).bool()

    # random template inputs, gathered as keyword arguments
    template_inputs = dict(
        templates_feats = torch.randn(2, 3, 16, 16, 32),
        templates_angles = torch.randn(2, 3, 16, 32),
        templates_mask = torch.ones(2, 3, 16).bool(),
    )

    # only checks that the forward pass completes without raising
    _ = net(
        primary,
        alignment,
        mask = primary_mask,
        msa_mask = alignment_mask,
        **template_inputs
    )
    assert True
# 定义测试函数 test_extra_msa
def test_extra_msa():
    """Smoke test: coordinate prediction with an additional (extra) MSA."""
    net = Alphafold2(
        dim = 128,
        depth = 2,
        heads = 2,
        dim_head = 32,
        predict_coords = True,
    )

    # random sequence and main MSA
    primary = torch.randint(0, 21, (2, 4))
    primary_mask = torch.ones_like(primary).bool()
    alignment = torch.randint(0, 21, (2, 5, 4))
    alignment_mask = torch.ones_like(alignment).bool()

    # random extra MSA and its mask
    extra_alignment = torch.randint(0, 21, (2, 5, 4))
    extra_alignment_mask = torch.ones_like(extra_alignment).bool()

    # only checks that the forward pass completes without raising
    _ = net(
        primary,
        alignment,
        mask = primary_mask,
        msa_mask = alignment_mask,
        extra_msa = extra_alignment,
        extra_msa_mask = extra_alignment_mask,
    )
    assert True
# 定义测试函数 test_embeddings
def test_embeddings():
    """Smoke test: forward pass fed with embeddings, without and with a mask."""
    net = Alphafold2(dim = 32, depth = 2, heads = 2, dim_head = 32)

    # random sequence, its mask, and random embeddings
    primary = torch.randint(0, 21, (2, 16))
    primary_mask = torch.ones_like(primary).bool()
    embeddings = torch.randn(2, 1, 16, 1280)

    # first pass: no mask for the embeddings; second pass: an all-ones mask
    # built from the embeddings with their last axis indexed away
    full_mask = torch.ones_like(embeddings[..., -1]).bool()
    for embed_mask in (None, full_mask):
        _ = net(
            primary,
            mask = primary_mask,
            embedds = embeddings,
            msa_mask = embed_mask,
        )

    # only checks that both forward passes complete without raising
    assert True
# 定义测试函数 test_coords
def test_coords():
    """Coordinate prediction must return a tensor of shape (2, 16, 3)."""
    structure_kwargs = dict(
        structure_module_depth = 1,
        structure_module_heads = 1,
        structure_module_dim_head = 1,
    )
    net = Alphafold2(
        dim = 32,
        depth = 2,
        heads = 2,
        dim_head = 32,
        predict_coords = True,
        **structure_kwargs
    )

    # random sequence / MSA tokens, all positions unmasked
    primary = torch.randint(0, 21, (2, 16))
    primary_mask = torch.ones_like(primary).bool()
    alignment = torch.randint(0, 21, (2, 5, 16))
    alignment_mask = torch.ones_like(alignment).bool()

    coords = net(primary, alignment, mask = primary_mask, msa_mask = alignment_mask)

    # the output coordinates must have shape (2, 16, 3)
    assert coords.shape == (2, 16, 3), 'must output coordinates'
# Test: coordinate prediction output shape (backbone-with-C-beta variant)
def test_coords_backbone_with_cbeta():
    """Check that the coordinate-predicting model returns a ``(2, 16, 3)`` tensor.

    NOTE(review): no backbone/C-beta specific option is passed to the model
    here, so this exercises the default coordinate path only.
    """
    model = Alphafold2(
        dim=32,
        depth=2,
        heads=2,
        dim_head=32,
        predict_coords=True,
        structure_module_depth=1,
        structure_module_heads=1,
        structure_module_dim_head=1,
    )

    batch, length, rows = 2, 16, 5
    primary = torch.randint(0, 21, (batch, length))
    primary_mask = torch.ones_like(primary, dtype=torch.bool)
    alignment = torch.randint(0, 21, (batch, rows, length))
    alignment_mask = torch.ones_like(alignment, dtype=torch.bool)

    coords = model(primary, alignment, mask=primary_mask, msa_mask=alignment_mask)

    assert coords.shape == (2, 16, 3), 'must output coordinates'
# Test: coordinate prediction output shape (all-atoms variant)
def test_coords_all_atoms():
    """Check that the coordinate-predicting model returns a ``(2, 16, 3)`` tensor.

    NOTE(review): no all-atom specific option is passed to the model here,
    so this exercises the default coordinate path only.
    """
    structure_kwargs = dict(
        structure_module_depth=1,
        structure_module_heads=1,
        structure_module_dim_head=1,
    )
    model = Alphafold2(
        dim=32,
        depth=2,
        heads=2,
        dim_head=32,
        predict_coords=True,
        **structure_kwargs,
    )

    primary = torch.randint(0, 21, (2, 16))
    primary_mask = torch.ones_like(primary, dtype=torch.bool)
    alignment = torch.randint(0, 21, (2, 5, 16))
    alignment_mask = torch.ones_like(alignment, dtype=torch.bool)

    inputs = dict(mask=primary_mask, msa_mask=alignment_mask)
    coords = model(primary, alignment, **inputs)

    assert coords.shape == (2, 16, 3), 'must output coordinates'
# Test: coordinate prediction output shape (MDS variant)
def test_mds():
    """Check that the coordinate-predicting model returns a ``(2, 16, 3)`` tensor.

    NOTE(review): nothing MDS-specific is configured here; the model is built
    with the same options as the other coordinate tests.
    """
    model = Alphafold2(
        dim=32,
        depth=2,
        heads=2,
        dim_head=32,
        predict_coords=True,
        structure_module_depth=1,
        structure_module_heads=1,
        structure_module_dim_head=1,
    )

    # inputs: random residue ids plus fully-valid masks
    primary = torch.randint(0, 21, (2, 16))
    primary_mask = torch.ones_like(primary, dtype=torch.bool)
    alignment = torch.randint(0, 21, (2, 5, 16))
    alignment_mask = torch.ones_like(alignment, dtype=torch.bool)

    coords = model(
        primary,
        alignment,
        mask=primary_mask,
        msa_mask=alignment_mask,
    )

    expected_shape = (2, 16, 3)
    assert coords.shape == expected_shape, 'must output coordinates'
# Test: coordinates + angles prediction with confidence output
def test_edges_to_equivariant_network():
    """Smoke-test a model with both ``predict_coords`` and ``predict_angles``.

    The call must return a pair (coordinates, confidences); values are not
    inspected.
    """
    model = Alphafold2(
        **dict(
            dim=32,
            depth=1,
            heads=2,
            dim_head=32,
            predict_coords=True,
            predict_angles=True,
        )
    )

    primary = torch.randint(0, 21, (2, 32))
    primary_mask = torch.ones_like(primary, dtype=torch.bool)
    alignment = torch.randint(0, 21, (2, 5, 32))
    alignment_mask = torch.ones_like(alignment, dtype=torch.bool)

    # unpacking enforces that exactly two values are returned
    coords, confidences = model(
        primary,
        alignment,
        mask=primary_mask,
        msa_mask=alignment_mask,
        return_confidence=True,
    )

    assert True, 'should run without errors'
# Test: gradients can flow back from the predicted coordinates
def test_coords_backwards():
    """Run a forward pass and back-propagate the sum of the coordinates."""
    model_config = dict(
        dim=256,
        depth=2,
        heads=2,
        dim_head=32,
        predict_coords=True,
        structure_module_depth=1,
        structure_module_heads=1,
        structure_module_dim_head=1,
    )
    model = Alphafold2(**model_config)

    primary = torch.randint(0, 21, (2, 16))
    primary_mask = torch.ones_like(primary, dtype=torch.bool)
    alignment = torch.randint(0, 21, (2, 5, 16))
    alignment_mask = torch.ones_like(alignment, dtype=torch.bool)

    coords = model(
        primary,
        alignment,
        mask=primary_mask,
        msa_mask=alignment_mask,
    )

    # a scalar is needed to call backward(); the sum is the simplest choice
    scalar = coords.sum()
    scalar.backward()

    assert True, 'must be able to go backwards through MDS and center distogram'
# Test: confidences line up with coordinates on all but the last axis
def test_confidence():
    """Check that coordinates and confidences share their leading dimensions."""
    model = Alphafold2(
        **dict(dim=256, depth=1, heads=2, dim_head=32, predict_coords=True)
    )

    primary = torch.randint(0, 21, (2, 16))
    primary_mask = torch.ones_like(primary, dtype=torch.bool)
    alignment = torch.randint(0, 21, (2, 5, 16))
    alignment_mask = torch.ones_like(alignment, dtype=torch.bool)

    coords, confidences = model(
        primary,
        alignment,
        mask=primary_mask,
        msa_mask=alignment_mask,
        return_confidence=True,
    )

    # compare every dimension except the trailing one
    assert coords.shape[:-1] == confidences.shape[:-1]
# Test: recyclables returned by one pass can be fed into the next
def test_recycling():
    """Run two forward passes, handing the first pass's recyclables to the second.

    Only checks that both passes complete.
    """
    model = Alphafold2(
        dim=128,
        depth=2,
        heads=2,
        dim_head=32,
        predict_coords=True,
    )

    primary = torch.randint(0, 21, (2, 4))
    primary_mask = torch.ones_like(primary, dtype=torch.bool)
    alignment = torch.randint(0, 21, (2, 5, 4))
    alignment_mask = torch.ones_like(alignment, dtype=torch.bool)

    extra_alignment = torch.randint(0, 21, (2, 5, 4))
    extra_alignment_mask = torch.ones_like(extra_alignment, dtype=torch.bool)

    # keyword arguments shared by both passes
    shared_kwargs = dict(
        mask=primary_mask,
        msa_mask=alignment_mask,
        extra_msa=extra_alignment,
        extra_msa_mask=extra_alignment_mask,
    )

    # first pass: ask for auxiliary logits and recyclables
    coords, ret = model(
        primary,
        alignment,
        **shared_kwargs,
        return_aux_logits=True,
        return_recyclables=True,
    )

    # second pass: reuse the recyclables produced by the first pass
    coords, ret = model(
        primary,
        alignment,
        **shared_kwargs,
        recyclables=ret.recyclables,
        return_aux_logits=True,
        return_recyclables=True,
    )

    # reaching this line means neither pass raised
    assert True
.\lucidrains\alphafold2\tests\test_utils.py
import torch
import numpy as np
from alphafold2_pytorch.utils import *
# Test for mat_input_to_masked
def test_mat_to_masked():
    """Call ``mat_input_to_masked`` with edge lists, edge matrices and batches.

    Only checks that none of the three calls raises.
    """
    # nodes
    nodes = torch.ones(19, 3)
    node_mask = torch.randn(19) > -0.3

    # edges, as a boolean adjacency matrix and as a (2, num_edges) index list
    adjacency = torch.randn(19, 19) < 1
    edge_index = adjacency.nonzero(as_tuple=False).t()

    # un-batched input: edge list, then edge matrix
    cleaned = mat_input_to_masked(nodes, node_mask, edges=edge_index)
    cleaned_2 = mat_input_to_masked(nodes, node_mask, edges_mat=adjacency)

    # batched input: duplicate every tensor along a new leading axis
    nodes_batched = torch.stack((nodes, nodes), dim=0)
    node_mask_batched = torch.stack((node_mask, node_mask), dim=0)
    adjacency_batched = torch.stack((adjacency, adjacency), dim=0)
    cleaned_3 = mat_input_to_masked(
        nodes_batched, node_mask_batched, edges_mat=adjacency_batched
    )

    assert True
# Test for center_distogram_torch
def test_center_distogram_median():
    """Call ``center_distogram_torch`` with ``center='median'`` on random input.

    Only checks that the call returns a pair without raising.
    """
    random_distogram = torch.randn(1, 128, 128, 37)
    result = center_distogram_torch(random_distogram, center='median')
    distances, weights = result
    assert True
# Test for scn_backbone_mask
def test_masks():
    """Call ``scn_backbone_mask`` on random sequences and unpack three masks."""
    # random ids in [0, 20) for 2 sequences of length 50
    sequences = torch.randint(0, 20, (2, 50))
    masks = scn_backbone_mask(sequences, boolean=True)
    N_mask, CA_mask, C_mask = masks
    assert True
# Test for MDScaling
def test_mds_and_mirrors():
    """Run ``MDScaling`` with mirror fixing and check the coordinate shape."""
    num_points = 32 * 3
    distogram = torch.randn(2, num_points, num_points, 37)
    distances, weights = center_distogram_torch(distogram)

    # zero the weights of the trailing `pad` rows/columns of each batch entry
    for batch_idx, pad in enumerate([7, 0]):
        if pad <= 0:
            continue
        weights[batch_idx, -pad:, -pad:] = 0.

    # positions cycle through three slots; slot 0 feeds N_mask, slot 1 CA_mask
    slot = torch.arange(distogram.shape[1]) % 3
    N_mask = slot.eq(0).bool()
    CA_mask = slot.eq(1).bool()

    coords_3d, _ = MDScaling(
        distances,
        weights=weights,
        iters=5,
        fix_mirror=2,
        N_mask=N_mask,
        CA_mask=CA_mask,
        C_mask=None,
    )

    assert list(coords_3d.shape) == [2, 3, num_points], 'coordinates must be of the right shape after MDS'
# Test for sidechain_container
def test_sidechain_container():
    """Check that ``sidechain_container`` returns a ``(2, 137, 14, 3)`` tensor."""
    length = 137
    # two constant sequences: all token 0 and all token 3
    seqs = torch.tensor([[0] * length, [3] * length]).long()
    backbone = torch.randn(2, length * 4, 3)
    # first 4 of the 14 slots are marked with 1, the remaining 10 with 0
    atom_mask = torch.tensor([1] * 4 + [0] * (14 - 4))

    proto_3d = sidechain_container(seqs, backbone, atom_mask=atom_mask)

    assert list(proto_3d.shape) == [2, length, 14, 3]
# Test for distmat_loss_torch
def test_distmat_loss():
    """Call ``distmat_loss_torch`` with ``p=2, q=2`` on two random point clouds."""
    cloud_shape = (2, 137, 14, 3)
    predicted = torch.randn(*cloud_shape)
    target = torch.randn(*cloud_shape)
    # p=2, q=2: mse on the distance matrix
    loss = distmat_loss_torch(predicted, target, p=2, q=2)
    assert True
# Test for lddt_ca_torch
def test_lddt():
    """Check that ``lddt_ca_torch`` returns one value per residue: ``(2, 137)``."""
    predicted = torch.randn(2, 137, 14, 3)
    target = torch.randn(2, 137, 14, 3)
    # all-True mask over every dimension except the trailing xyz axis
    cloud_mask = torch.ones(predicted.shape[:-1], dtype=torch.bool)

    lddt_result = lddt_ca_torch(predicted, target, cloud_mask)

    assert list(lddt_result.shape) == [2, 137]
# Test for Kabsch
def test_kabsch():
    """Check that ``Kabsch`` keeps the shape of its first input."""
    first = torch.randn(3, 8)
    second = torch.randn(3, 8)
    first_aligned, second_aligned = Kabsch(first, second)
    assert first.shape == first_aligned.shape
# Test for TMscore
def test_tmscore():
    """Call ``TMscore`` on two random ``(2, 3, 8)`` tensors; must not raise."""
    shape = (2, 3, 8)
    first = torch.randn(*shape)
    second = torch.randn(*shape)
    score = TMscore(first, second)
    assert True
# Test for GDT
def test_gdt():
    """Call ``GDT`` with ``weights=1`` on two random ``(1, 3, 8)`` tensors."""
    shape = (1, 3, 8)
    first = torch.randn(*shape)
    second = torch.randn(*shape)
    GDT(first, second, weights=1)
    assert True
.\lucidrains\alphafold2\training_scripts\datasets\trrosetta.py
import pickle
import string
from argparse import ArgumentParser
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import numpy.linalg as LA
import prody
import torch
from Bio import SeqIO
from einops import repeat
from sidechainnet.utils.measure import get_seq_coords_and_angles
from sidechainnet.utils.sequence import ProteinVocabulary
from torch.utils.data import DataLoader, Dataset
from alphafold2_pytorch.constants import DISTOGRAM_BUCKETS
from tqdm import tqdm
# pytorch_lightning is optional: when it cannot be imported, fall back to
# `object` so that classes deriving from `LightningDataModule` still define.
try:
    import pytorch_lightning as pl
    LightningDataModule = pl.LightningDataModule
except ImportError:
    LightningDataModule = object

# Root of the on-disk cache, inside the user's home directory.
CACHE_PATH = Path("~/.cache/alphafold2_pytorch").expanduser()
# Default dataset directory inside the cache.
DATA_DIR = CACHE_PATH / "trrosetta" / "trrosetta"
# Default download location of the trRosetta archive.
URL = "http://s3.amazonaws.com/proteindata/data_pytorch/trrosetta.tar.gz"

# Translation table that deletes lowercase letters, "." and "*" from a string
# (every listed character maps to None); used by `read_fasta`.
REMOVE_KEYS = dict.fromkeys(string.ascii_lowercase)
REMOVE_KEYS["."] = None
REMOVE_KEYS["*"] = None
translation = str.maketrans(REMOVE_KEYS)

# Vocabulary used by `default_tokenize` to map characters to integer ids.
DEFAULT_VOCAB = ProteinVocabulary()
def default_tokenize(seq: str) -> List[int]:
    """Look up every character of ``seq`` in ``DEFAULT_VOCAB``.

    Returns the looked-up values as a list, in the order of the characters.
    """
    token_ids = []
    for residue in seq:
        token_ids.append(DEFAULT_VOCAB[residue])
    return token_ids
def read_fasta(filename: str) -> List[Tuple[str, str]]:
    """Read a FASTA-format file into ``(description, sequence)`` pairs.

    Each sequence is passed through the module-level ``translation`` table,
    which deletes lowercase letters, ``.`` and ``*``.
    """
    entries = []
    for record in SeqIO.parse(filename, "fasta"):
        cleaned = str(record.seq).translate(translation)
        entries.append((record.description, cleaned))
    return entries
def read_pdb(pdb: str):
    """Parse a PDB file and return the data of its first chain.

    Args:
        pdb: path handed to ``prody.parsePDB``.

    Returns:
        The ``(angles, coords, seq)`` triple that
        ``get_seq_coords_and_angles`` produces for the first chain yielded by
        ``iterChains()``. Any further chains are ignored.

    Raises:
        ValueError: if nothing could be parsed or the structure has no chain.
            Previously this case returned ``None`` implicitly, which only
            surfaced later as an unpacking error in the caller.
    """
    structure = prody.parsePDB(pdb)
    # NOTE(review): guard for parsePDB yielding None on failure — confirm
    # against the ProDy version in use.
    if structure is None:
        raise ValueError(f"could not parse PDB file: {pdb}")

    for chain in structure.iterChains():
        angles, coords, seq = get_seq_coords_and_angles(chain)
        # only the first chain is used
        return angles, coords, seq

    raise ValueError(f"no chains found in PDB file: {pdb}")
def download_file(url, filename=None, root=CACHE_PATH):
    """Download ``url`` into ``root`` unless the target file already exists.

    The data is first written to ``tmp.<filename>`` and renamed once the
    transfer has finished, so an interrupted download never leaves a
    partial file under the final name.

    Args:
        url: address to fetch.
        filename: name of the file inside ``root``; defaults to the last
            path component of ``url``.
        root: destination directory (created if missing).

    Returns:
        Path of the downloaded (or already present) file.

    Raises:
        RuntimeError: if the target path exists but is not a regular file.
    """
    import os
    # `import urllib` alone does not load the `request` submodule, so the
    # submodule has to be imported explicitly.
    import urllib.request

    root.mkdir(exist_ok=True, parents=True)
    filename = filename or os.path.basename(url)

    download_target = root / filename
    download_target_tmp = root / f"tmp.{filename}"

    if download_target.exists() and not download_target.is_file():
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if download_target.is_file():
        return download_target

    with urllib.request.urlopen(url) as source, open(
        download_target_tmp, "wb"
    ) as output:
        # The header may be absent (e.g. chunked responses); tqdm accepts
        # total=None and then shows a counter without a known end.
        content_length = source.info().get("Content-Length")
        total = int(content_length) if content_length is not None else None
        with tqdm(total=total, ncols=80) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break
                output.write(buffer)
                loop.update(len(buffer))

    download_target_tmp.rename(download_target)
    return download_target
def get_or_download(url: str = URL):
    """Download and extract the trRosetta data into the cache directory.

    If ``CACHE_PATH / "trrosetta"`` already exists it is returned as is.
    Otherwise the archive is downloaded (unless already cached) and
    extracted into a temporary directory that is renamed on success.

    Args:
        url: location of the ``.tar.gz`` archive.

    Returns:
        Path of the extracted data directory.
    """
    import tarfile

    archive = CACHE_PATH / "trrosetta.tar.gz"
    target_dir = CACHE_PATH / "trrosetta"
    temp_dir = CACHE_PATH / "trrosetta_tmp"

    if target_dir.is_dir():
        print(f"Load cached data from {target_dir}")
        return target_dir

    if not archive.is_file():
        print(f"Cache not found, download from {url} to {archive}")
        # Store the download under the name that is opened below; otherwise a
        # URL whose basename is not "trrosetta.tar.gz" would be saved elsewhere.
        download_file(url, filename=archive.name)

    print(f"Extract data from {archive} to {target_dir}")
    with tarfile.open(archive, "r:gz") as tar:
        # NOTE(review): extractall() trusts the member paths inside the
        # archive; only use this with archives from a trusted source.
        tar.extractall(temp_dir)

    temp_dir.rename(target_dir)
    return target_dir
def pad_sequences(
    sequences, constant_value=0, dtype=None
) -> Union[np.ndarray, torch.Tensor]:
    """Stack arrays of different sizes into one batch, padding with a constant.

    Args:
        sequences: non-empty list of ``np.ndarray`` or ``torch.Tensor``
            objects that all have the same number of dimensions. The type of
            the first element decides the type of the result.
        constant_value: fill value for the padded positions.
        dtype: dtype of the result; defaults to the dtype of the first element.

    Returns:
        Array or tensor of shape ``[len(sequences), *max_shape]`` where
        ``max_shape`` is the element-wise maximum of the input shapes. Each
        input occupies the leading corner of its slot.

    Raises:
        ValueError: if ``sequences`` is empty.
        TypeError: if the elements are neither numpy arrays nor torch tensors.
    """
    if len(sequences) == 0:
        raise ValueError("pad_sequences requires at least one sequence")

    first = sequences[0]
    if isinstance(first, np.ndarray):
        make_full = np.full
    elif isinstance(first, torch.Tensor):
        make_full = torch.full
    else:
        raise TypeError(
            "pad_sequences expects numpy arrays or torch tensors, "
            f"got {type(first).__name__}"
        )

    batch_size = len(sequences)
    # element-wise maximum over all input shapes
    shape = [batch_size] + np.max([seq.shape for seq in sequences], 0).tolist()

    if dtype is None:
        dtype = first.dtype

    array = make_full(shape, constant_value, dtype=dtype)

    # iterating over `array` yields views, so the writes land in `array`
    for arr, seq in zip(array, sequences):
        arrslice = tuple(slice(dim) for dim in seq.shape)
        arr[arrslice] = seq

    return array
class TrRosettaDataset(Dataset):
    """Dataset over trRosetta entries stored as ``npz`` / ``a3m`` / ``pdb`` files.

    Each item is a dict with the keys ``id``, ``seq``, ``msa``, ``coords``,
    ``angles`` and ``dist``. Items are parsed once and pickled under
    ``<data_dir>/cache``; later accesses read the pickle unless
    ``overwrite`` is set.
    """

    def __init__(
        self,
        data_dir: Path,
        list_path: Path,
        tokenize: Callable[[str], List[int]],
        seq_pad_value: int = 20,
        random_sample_msa: bool = False,
        max_seq_len: int = 300,
        max_msa_num: int = 300,
        overwrite: bool = False,
    ):
        """Store the configuration and resolve the list of entries.

        Args:
            data_dir: directory containing the ``npz``, ``a3m`` and ``pdb``
                sub-directories.
            list_path: text file with whitespace-separated ``*.npz`` names
                that belong to this split.
            tokenize: maps a sequence of characters to integer ids.
            seq_pad_value: padding id for sequences and MSAs in ``collate_fn``.
            random_sample_msa: sample MSA rows at random instead of taking
                the first ones.
            max_seq_len: crop length; values <= 0 disable cropping.
            max_msa_num: maximum number of MSA rows; values <= 0 keep all.
            overwrite: ignore existing cache files and rebuild them.
        """
        self.data_dir = data_dir
        self.file_list: List[Path] = self.read_file_list(data_dir, list_path)

        self.tokenize = tokenize
        self.seq_pad_value = seq_pad_value

        self.random_sample_msa = random_sample_msa
        self.max_seq_len = max_seq_len
        self.max_msa_num = max_msa_num
        self.overwrite = overwrite

    def __len__(self) -> int:
        return len(self.file_list)

    def read_file_list(self, data_dir: Path, list_path: Path):
        """Return the ``npz`` files in ``data_dir`` that are named in ``list_path``.

        Raises:
            ValueError: if the split file names no entries.
            FileNotFoundError: if some named files are missing on disk.
        """
        file_glob = (data_dir / "npz").glob("*.npz")
        files = set(list_path.read_text().split())
        if len(files) == 0:
            raise ValueError("Passed an empty split file set")

        # glob order depends on the file system; sort so that a given index
        # always refers to the same entry
        file_list = sorted(f for f in file_glob if f.name in files)
        if len(file_list) != len(files):
            num_missing = len(files) - len(file_list)
            raise FileNotFoundError(
                f"{num_missing} specified split files not found in directory"
            )

        return file_list

    def _cache_path(self, index) -> Path:
        """Return the pickle path used to cache the item at ``index``."""
        return (self.data_dir / "cache" / self.file_list[index].stem).with_suffix(
            ".pkl"
        )

    def has_cache(self, index):
        """Tell whether a usable cache file exists (never when ``overwrite``)."""
        if self.overwrite:
            return False
        return self._cache_path(index).is_file()

    def write_cache(self, index, data):
        """Pickle ``data`` to the cache file of ``index``."""
        path = self._cache_path(index)
        path.parent.mkdir(exist_ok=True, parents=True)
        with open(path, "wb") as file:
            pickle.dump(data, file)

    def read_cache(self, index):
        """Load the pickled item of ``index``."""
        # NOTE(review): pickle.load can execute arbitrary code; the cache
        # directory is assumed to contain only files written by write_cache.
        with open(self._cache_path(index), "rb") as file:
            return pickle.load(file)

    def __getitem__(self, index):
        """Return the item at ``index`` after MSA sampling and cropping."""
        if self.has_cache(index):
            item = self.read_cache(index)
        else:
            entry_id = self.file_list[index].stem
            pdb_path = self.data_dir / "pdb" / f"{entry_id}.pdb"
            msa_path = self.data_dir / "a3m" / f"{entry_id}.a3m"

            _, msa = zip(*read_fasta(str(msa_path)))
            msa = np.array([np.array(list(seq)) for seq in msa])

            angles, coords, seq = read_pdb(str(pdb_path))
            seq = np.array(list(seq))
            # flat list of points -> (residues, 14, 3)
            coords = coords.reshape((coords.shape[0] // 14, 14, 3))
            dist = self.get_bucketed_distance(seq, coords, subset="ca")

            item = {
                "id": entry_id,
                "seq": seq,
                "msa": msa,
                "coords": coords,
                "angles": angles,
                "dist": dist,
            }
            # the cache stores the full, unsampled and uncropped item
            self.write_cache(index, item)

        item["msa"] = self.sample(item["msa"], self.max_msa_num, self.random_sample_msa)
        item = self.crop(item, self.max_seq_len)
        return item

    def calc_cb(self, coord):
        """Compute a C-beta position from rows 0, 1, 2 of ``coord`` (N, CA, C)."""
        N = coord[0]
        CA = coord[1]
        C = coord[2]

        b = CA - N
        c = C - CA
        a = np.cross(b, c)
        # fixed linear combination of the two bond vectors and their normal
        CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
        return CB

    def get_bucketed_distance(
        self, seq, coords, subset="ca", start=2, bins=DISTOGRAM_BUCKETS-1, step=0.5
    ):
        """Turn per-residue coordinates into a matrix of distance-bin indices.

        Args:
            seq: residue letters, one per row of ``coords``.
            coords: array of shape ``(residues, 14, 3)``.
            subset: ``"ca"`` uses row 1 of every residue; ``"cb"`` uses row 4,
                or ``calc_cb`` for residues whose letter is ``"G"``.
            start: lower edge of the first bin.
            bins: index assigned to distances at or beyond
                ``start + step * bins``.
            step: bin width.

        Returns:
            Integer array of shape ``(residues, residues)``. The diagonal and
            all distances below ``start`` are 0.
        """
        assert subset in ("ca", "cb")

        if subset == "ca":
            coords = coords[:, 1, :]
        elif subset == "cb":
            cb_coords = []
            for res, coord in zip(seq, coords):
                if res == "G":
                    cb_coords.append(self.calc_cb(coord))
                else:
                    cb_coords.append(coord[4, :])
            coords = np.array(cb_coords)

        # broadcast to (L, L, 3), then subtract the transpose to get all
        # pairwise difference vectors
        vcs = coords + np.zeros([coords.shape[0]] + list(coords.shape))
        vcs = vcs - np.swapaxes(vcs, 0, 1)
        distance_map = LA.norm(vcs, axis=2)

        # mask is 0 on the diagonal and wherever the distance is below `start`
        mask = np.ones(distance_map.shape) - np.eye(distance_map.shape[0])
        low_pos = np.where(distance_map < start)
        high_pos = np.where(distance_map >= start + step * bins)

        mask[low_pos] = 0
        distance_map = (distance_map - start) // step
        distance_map[high_pos] = bins
        dist = (distance_map * mask).astype(int)
        return dist

    def crop(self, item, max_seq_len: int):
        """Keep only the first ``max_seq_len`` residues of every field.

        The item is returned unchanged when it is short enough or when
        ``max_seq_len`` is <= 0. Cropping modifies ``item`` in place.
        """
        seq_len = len(item["seq"])

        if seq_len <= max_seq_len or max_seq_len <= 0:
            return item

        # the crop window always starts at the first residue
        start = 0
        end = start + max_seq_len

        item["seq"] = item["seq"][start:end]
        item["msa"] = item["msa"][:, start:end]
        item["coords"] = item["coords"][start:end]
        item["angles"] = item["angles"][start:end]
        item["dist"] = item["dist"][start:end, start:end]
        return item

    def sample(self, msa, max_msa_num: int, random: bool):
        """Reduce ``msa`` to at most ``max_msa_num`` rows.

        With ``random`` set, row 0 is always kept and the remaining rows are
        drawn without replacement; otherwise the first rows are taken.
        ``max_msa_num`` <= 0 keeps every row.
        """
        num_msa = len(msa)

        if num_msa <= max_msa_num or max_msa_num <= 0:
            return msa

        if not random:
            return msa[:max_msa_num]

        # draw from rows 1..num_msa-1, then put index 0 in front
        num_sample = max_msa_num - 1
        indices = np.random.choice(num_msa - 1, size=num_sample, replace=False) + 1
        indices = np.pad(indices, [1, 0], "constant")
        return msa[indices]

    def collate_fn(self, batch):
        """Pad a list of items into batched tensors and build the masks.

        Returns:
            Dict with ``id`` (list), padded ``seq``, ``msa``, ``coords``,
            ``angles`` and ``dist`` tensors, ``mask`` of shape
            ``(batch, max_len)`` and ``msa_mask`` of shape
            ``(batch, max_depth, max_len)``.
        """
        b = len(batch)
        batch = {k: [item[k] for item in batch] for k in batch[0]}

        ids = batch["id"]
        seq = batch["seq"]
        msa = batch["msa"]
        coords = batch["coords"]
        angles = batch["angles"]
        dist = batch["dist"]

        # per-item alignment length and number of MSA rows
        lengths = torch.LongTensor([len(x[0]) for x in msa])
        depths = torch.LongTensor([len(x) for x in msa])
        # plain Python ints so einops receives integer axis lengths
        max_len = int(lengths.max())
        max_depth = int(depths.max())

        seq = pad_sequences(
            [torch.LongTensor(self.tokenize(seq_)) for seq_ in seq], self.seq_pad_value,
        )

        msa = pad_sequences(
            [torch.LongTensor([self.tokenize(seq_) for seq_ in msa_]) for msa_ in msa],
            self.seq_pad_value,
        )

        coords = pad_sequences([torch.FloatTensor(x) for x in coords], 0.0)

        angles = pad_sequences([torch.FloatTensor(x) for x in angles], 0.0)

        # NOTE(review): -100 presumably matches a loss ignore index — confirm
        dist = pad_sequences([torch.LongTensor(x) for x in dist], -100)

        # True where the position lies inside the item's real length
        mask = repeat(torch.arange(max_len), "l -> b l", b=b) < repeat(
            lengths, "b -> b l", l=max_len
        )
        msa_seq_mask = repeat(
            torch.arange(max_len), "l -> b s l", b=b, s=max_depth
        ) < repeat(lengths, "b -> b s l", s=max_depth, l=max_len)
        # True where the row lies inside the item's real MSA depth
        msa_depth_mask = repeat(
            torch.arange(max_depth), "s -> b s l", b=b, l=max_len
        ) < repeat(depths, "b -> b s l", s=max_depth, l=max_len)
        msa_mask = msa_seq_mask & msa_depth_mask

        return {
            "id": ids,
            "seq": seq,
            "msa": msa,
            "coords": coords,
            "angles": angles,
            "mask": mask,
            "msa_mask": msa_mask,
            "dist": dist,
        }
class TrRosettaDataModule(LightningDataModule):
    """Data module that builds train / validation / test loaders for trRosetta.

    NOTE(review): the MSA-count defaults of the command-line options
    (256 / 256 / 1000) differ from the constructor defaults (32 / 32 / 64) —
    confirm which set is intended.
    """

    @staticmethod
    def add_data_specific_args(parent_parser):
        """Return a parser that extends ``parent_parser`` with the data options."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--data_dir", type=str, default=str(DATA_DIR))

        # integer options, registered in this order: (flag, default)
        integer_options = (
            ("--train_batch_size", 1),
            ("--eval_batch_size", 1),
            ("--test_batch_size", 1),
            ("--num_workers", 0),
            ("--train_max_seq_len", 256),
            ("--eval_max_seq_len", 256),
            ("--test_max_seq_len", -1),
            ("--train_max_msa_num", 256),
            ("--eval_max_msa_num", 256),
            ("--test_max_msa_num", 1000),
        )
        for flag, default in integer_options:
            parser.add_argument(flag, type=int, default=default)

        parser.add_argument("--overwrite", dest="overwrite", action="store_true")
        return parser

    def __init__(
        self,
        data_dir: str = DATA_DIR,
        train_batch_size: int = 1,
        eval_batch_size: int = 1,
        test_batch_size: int = 1,
        num_workers: int = 0,
        train_max_seq_len: int = 256,
        eval_max_seq_len: int = 256,
        test_max_seq_len: int = -1,
        train_max_msa_num: int = 32,
        eval_max_msa_num: int = 32,
        test_max_msa_num: int = 64,
        tokenize: Callable[[str], List[int]] = default_tokenize,
        seq_pad_value: int = 20,
        overwrite: bool = False,
        **kwargs,
    ):
        """Store the configuration and make sure the data is available.

        Extra keyword arguments are accepted and ignored. The constructor
        ends by calling ``get_or_download()`` with its default arguments.
        """
        super().__init__()
        self.data_dir = Path(data_dir).expanduser().resolve()

        # loader settings
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.test_batch_size = test_batch_size
        self.num_workers = num_workers

        # per-split cropping and MSA limits
        self.train_max_seq_len = train_max_seq_len
        self.eval_max_seq_len = eval_max_seq_len
        self.test_max_seq_len = test_max_seq_len
        self.train_max_msa_num = train_max_msa_num
        self.eval_max_msa_num = eval_max_msa_num
        self.test_max_msa_num = test_max_msa_num

        # tokenisation and caching
        self.tokenize = tokenize
        self.seq_pad_value = seq_pad_value
        self.overwrite = overwrite

        get_or_download()

    def _make_dataset(self, list_name, random_sample_msa, max_seq_len, max_msa_num):
        """Build a ``TrRosettaDataset`` for the split file ``list_name``."""
        return TrRosettaDataset(
            self.data_dir,
            self.data_dir / list_name,
            self.tokenize,
            self.seq_pad_value,
            random_sample_msa=random_sample_msa,
            max_seq_len=max_seq_len,
            max_msa_num=max_msa_num,
            overwrite=self.overwrite,
        )

    def _make_loader(self, dataset, batch_size, shuffle) -> DataLoader:
        """Wrap ``dataset`` in a ``DataLoader`` that uses its ``collate_fn``."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=dataset.collate_fn,
            num_workers=self.num_workers,
        )

    def setup(self, stage: Optional[str] = None):
        """Create the three datasets; ``stage`` is not used."""
        # only the training split samples MSA rows at random
        self.train = self._make_dataset(
            "train_files.txt", True, self.train_max_seq_len, self.train_max_msa_num
        )
        self.val = self._make_dataset(
            "valid_files.txt", False, self.eval_max_seq_len, self.eval_max_msa_num
        )
        # the test split reads the same list file as the validation split
        self.test = self._make_dataset(
            "valid_files.txt", False, self.test_max_seq_len, self.test_max_msa_num
        )

    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        """Return the shuffled training loader."""
        return self._make_loader(self.train, self.train_batch_size, True)

    def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        """Return the validation loader (no shuffling)."""
        return self._make_loader(self.val, self.eval_batch_size, False)

    def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        """Return the test loader (no shuffling)."""
        return self._make_loader(self.test, self.test_batch_size, False)
def test():
    """Build the data module and print a summary of the first training batch."""
    data_module = TrRosettaDataModule(train_batch_size=1, num_workers=4)
    data_module.setup()

    # fetch only the first batch; nothing is printed if the loader is empty
    first_batch = next(iter(data_module.train_dataloader()), None)
    if first_batch is None:
        return

    seq = first_batch["seq"]
    msa = first_batch["msa"]
    dist = first_batch["dist"]

    print("id", first_batch["id"])
    print("seq", seq.shape, seq)
    # first and last 20 columns of the MSA
    print("msa", msa.shape, msa[..., :20])
    print("msa", msa.shape, msa[..., -20:])
    print("coords", first_batch["coords"].shape)
    print("angles", first_batch["angles"].shape)
    print("mask", first_batch["mask"].shape)
    print("msa_mask", first_batch["msa_mask"].shape)
    print("dist", dist.shape, dist)
# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test()