Lucidrains-系列项目源码解析-三十九-

58 阅读10分钟

Lucidrains 系列项目源码解析(三十九)

.\lucidrains\gateloop-transformer\gateloop_transformer\gateloop_transformer_jax.py

# 导入必要的模块和函数
from typing import List, Tuple, Callable
from jax import random, jit, nn, lax, numpy as np
from jax.lax import associative_scan
from equinox import Module, static_field

# linear

# dense affine layer
class Linear(Module):
    """Dense affine layer computing `x @ weight + bias` with random init."""

    weight: np.ndarray
    bias: np.ndarray

    def __init__(self, dim_in, dim_out, *, key):
        # split the PRNG key so the weight and bias draws are independent
        w_key, b_key = random.split(key)
        self.weight = random.normal(w_key, (dim_in, dim_out))
        self.bias = random.normal(b_key, (dim_out,))

    def __call__(self, x, *, key = None):
        # affine projection; `key` is accepted only for call-interface uniformity
        projected = x @ self.weight
        return projected + self.bias

# rmsnorm

# RMS normalization module
class RMSNorm(Module):
    """Root-mean-square normalization with a learned per-dimension gain."""

    scale: float = static_field()
    eps: float = static_field()
    gamma: np.ndarray

    def __init__(self, dim, eps = 1e-5):
        self.eps = eps
        # multiplying by sqrt(dim) converts the sum-of-squares rsqrt into an RMS
        self.scale = dim ** 0.5
        self.gamma = np.ones((dim,))

    def __call__(self, x):
        sum_sq = np.sum(np.square(x), axis = -1, keepdims = True)
        # eps guards against division by zero on an all-zero vector
        inv_rms = lax.rsqrt(sum_sq + self.eps)
        return x * inv_rms * self.gamma * self.scale

# gate loop layer

# gate loop operator
def gate_loop_operator(k, v, q, a):
    """Data-controlled linear recurrence computed via a parallel associative scan.

    Scans x_t = a_t * x_{t-1} + k_t * v_t along axis 1, then reads out with q
    against the real part of the complex state.
    """
    # promote k*v to complex so it composes with the complex transitions
    kv = k * v + 0.j

    def combine(left, right):
        a_l, kv_l = left
        a_r, kv_r = right
        # composing two steps: transitions multiply, carried state is gated
        return a_r * a_l, kv_l * a_r + kv_r

    _, state = associative_scan(combine, (a, kv), axis = 1)

    return q * np.real(state)

# gate loop module
class GateLoop(Module):
    """GateLoop layer: prenorm, complex data-controlled state transition, silu gate."""

    norm: RMSNorm
    wq: np.ndarray
    wk: np.ndarray
    wv: np.ndarray
    wa: np.ndarray
    wg: np.ndarray
    wo: np.ndarray

    def __init__(
        self,
        dim,
        key
    ):
        """
        q - query
        k - key
        v - value
        a - state transition
        g - gating with silu activation
        o - output
        """

        # one independent PRNG key per projection
        q_key, k_key, v_key, a_key, g_key, o_key = random.split(key, 6)

        self.norm = RMSNorm(dim)

        self.wq = random.normal(q_key, (dim, dim))
        self.wk = random.normal(k_key, (dim, dim))
        self.wv = random.normal(v_key, (dim, dim))
        # twice the width: real and imaginary halves of the state transition
        self.wa = random.normal(a_key, (dim, dim * 2))
        self.wg = random.normal(g_key, (dim, dim))
        self.wo = random.normal(o_key, (dim, dim))

    def __call__(self, x):
        x = self.norm(x)

        q, k, v = x @ self.wq, x @ self.wk, x @ self.wv
        a, g = x @ self.wa, x @ self.wg

        # assemble the complex-valued state transition
        a_real, a_imag = np.split(a, 2, axis = -1)
        a_complex = lax.complex(a_real, a_imag)

        # squash the magnitude into (0, 1) while preserving the phase
        magnitude = nn.sigmoid(np.abs(a_complex))
        a_complex = magnitude * np.exp(1j * np.angle(a_complex))

        # associative scan over the data-controlled recurrence
        y = gate_loop_operator(k, v, q, a_complex)

        # silu gating, as in RetNet
        y = nn.silu(g) * y

        return y @ self.wo

# basic feedforward with pre-rmsnorm

# basic feedforward block with pre-RMSNorm
class FeedForward(Module):
    """Two-layer MLP with pre-RMSNorm and GELU activation."""

    norm: RMSNorm
    proj_in: Linear
    proj_out: Linear

    def __init__(
        self,
        *,
        dim,
        key,
        mult = 4
    ):
        hidden_dim = dim * mult
        self.norm = RMSNorm(dim)
        self.proj_in = Linear(dim, hidden_dim, key = key)
        self.proj_out = Linear(hidden_dim, dim, key = key)

    def __call__(self, x):
        # prenorm -> expand -> gelu -> contract
        return self.proj_out(nn.gelu(self.proj_in(self.norm(x))))

# main class

# gate loop transformer
class GateLoopTransformer(Module):
    """GateLoop transformer: token embedding, stacked residual (GateLoop, FeedForward)
    blocks, final RMSNorm, and weight-tied logits against the embedding matrix.
    """

    embedding: np.ndarray
    norm: Module
    layers: List[Tuple[GateLoop, FeedForward]]

    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        key,
        ff_mult = 4
    ):
        # NOTE(review): the extracted source was missing the `):` closing this
        # signature and had both method bodies dedented to module level,
        # which is a syntax error - structure restored here

        # embedding initialized from a scaled normal distribution
        self.embedding = random.normal(key, (num_tokens, dim)) * 0.02

        layers = []

        # one (GateLoop, FeedForward) pair per layer of depth
        for _ in range(depth):
            gateloop = GateLoop(dim = dim, key = key)
            ff = FeedForward(dim = dim, mult = ff_mult, key = key)
            layers.append((gateloop, ff))

        self.layers = layers

        # final normalization before the output projection
        self.norm = RMSNorm(dim)

    @jit
    def __call__(self, x):
        # look up token embeddings
        x = self.embedding[x]

        # residual gateloop followed by residual feedforward, per layer
        for gateloop, ff in self.layers:
            x = gateloop(x) + x
            x = ff(x) + x

        x = self.norm(x)

        # weight-tied readout: project against the transposed embedding
        logits = x @ self.embedding.transpose()

        return logits
# quick smoke test when this script is executed directly
if __name__ == '__main__':
    import jax
    # fixed PRNG seed for reproducibility
    key = jax.random.PRNGKey(0)

    # instantiate a GateLoopTransformer
    model = GateLoopTransformer(
        num_tokens = 20000,
        dim = 512,
        depth = 12,
        key = key
    )

    # random token sequence of length 1024
    seq = jax.random.randint(key, (1024,), 0, 20000)
    # single forward pass
    logits = model(seq)

    # expect one logit row per position over the vocabulary
    print(logits.shape) # (1024, 20000)

.\lucidrains\gateloop-transformer\gateloop_transformer\simplified_gate_loop.py

# 导入所需模块
from functools import partial
import torch
from torch import nn, Tensor
from torch.nn import Module
from typing import Tuple
from einops import rearrange, pack, unpack
from einops.layers.torch import Rearrange
from gateloop_transformer.gateloop_transformer import RMSNorm
from gateloop_transformer.associative_scan import associative_scan

# helper: check whether a value is present
def exists(v):
    """Return True unless `v` is None."""
    return not (v is None)

# clamp the magnitude of a tensor away from zero, keeping the sign
def abs_clamp_eps(t, eps = 1e-20):
    """Return `t` with its absolute value clamped to at least `eps` (sign preserved)."""
    magnitude = t.abs().clamp(min = eps)
    return torch.sign(t) * magnitude

# associative scan via the Heinsen (2023) log-space trick
def heinsen_associative_scan(a, kv, eps = 1e-20):
    """Parallel scan of x_t = a_t * x_{t-1} + kv_t computed in log space.

    kv may be negative, so it is routed through a complex log; returns the
    real parts of the cumulative transition and the scanned state.
    """
    log_a = a.clamp(min = eps).log()
    log_kv = abs_clamp_eps(kv, eps = eps).to(dtype = torch.complex64).log()
    # cumulative log transition a* and the log of the scanned sequence
    a_star = log_a.cumsum(dim = 1)
    log_x0_plus_b_star = torch.logcumsumexp(log_kv - a_star, dim = 1)
    log_x = log_x0_plus_b_star + a_star
    return a_star.exp().real, log_x.exp().real

# TorchScript-compiled associative combine for the gate loop recurrence
@torch.jit.script
def binary_operator(
    a: Tuple[Tensor, Tensor],
    b: Tuple[Tensor, Tensor]
):
    """Combine two scan elements: transitions multiply, state is gated and added."""
    a_left, kv_left = a
    a_right, kv_right = b
    combined_a = a_right * a_left
    # fused multiply-add: kv_right + a_right * kv_left
    combined_kv = torch.addcmul(kv_right, a_right, kv_left)
    return combined_a, combined_kv

# gate loop operator (torch)
def gate_loop_operator(q, kv, a, cache = None, heinsen = False):
    """Run the gate loop scan over (a, kv), optionally resuming from a cache.

    Returns the gated output `q * kv` and the (a, kv) state at the last step,
    which callers can pass back in as `cache` for incremental decoding.
    """
    has_cache = exists(cache)

    if has_cache:
        # prepend the cached final state so the scan continues the recurrence
        cache_a, cache_kv = cache
        a, a_ps = pack([cache_a, a], 'b * d')
        kv, kv_ps = pack([cache_kv, kv], 'b * d')

    if heinsen:
        a, kv = heinsen_associative_scan(a, kv)
    else:
        a, kv = associative_scan(binary_operator, (a, kv))

    if has_cache:
        # drop the prepended cache position from the outputs
        _, a = unpack(a, a_ps, 'b * d')
        _, kv = unpack(kv, kv_ps, 'b * d')

    return q * kv, (a[:, -1], kv[:, -1])

# gate loop operator backed by jax, wrapped for torch
def get_jax_gate_loop_operator():
    """Build a torch-callable gate loop operator backed by jax's associative_scan.

    Requires `jax` and `jax2torch`. Raises ImportError with install guidance
    when they are missing.
    """
    try:
        from jax import jit, numpy as jnp
        from jax.lax import associative_scan
        from jax2torch import jax2torch
    except ImportError as e:
        # bug fix: the original only printed here and fell through, which then
        # crashed with a confusing NameError on `jit` - re-raise instead
        raise ImportError('jax and jax2torch must be installed - `pip install jax2torch`') from e

    @jit
    def jax_gate_loop_operator(q, kv, a, cache = None):
        def binary_operator(e_i, e_j):
            a_i, kv_i = e_i
            a_j, kv_j = e_j
            return a_j * a_i, a_j * kv_i + kv_j

        if exists(cache):
            # prepend the cached final state so the scan continues the recurrence
            cache_a, cache_kv = cache
            a, a_ps = pack([cache_a, a], 'b * d')
            kv, kv_ps = pack([cache_kv, kv], 'b * d')

        _, y = associative_scan(binary_operator, (a, kv), axis = 1)

        if exists(cache):
            # drop the prepended cache position
            _, a = unpack(a, a_ps, 'b * d')
            _, kv = unpack(kv, kv_ps, 'b * d')

        return q * y, (a[:, -1], kv[:, -1])

    return jax2torch(jax_gate_loop_operator)

# simple gate loop layer
class SimpleGateLoopLayer(Module):
    """
    Simplified gate loop layer, meant to supplement attention.
    See https://github.com/lucidrains/mega-pytorch
    """

    def __init__(
        self,
        dim,
        prenorm = True,
        use_heinsen = False,
        use_jax_associative_scan = False,
        post_ln = False,
        reverse = False
    ):
        super().__init__()
        # at most one alternative scan implementation may be selected
        assert (int(use_heinsen) + int(use_jax_associative_scan)) <= 1

        self.norm = RMSNorm(dim) if prenorm else nn.Identity()

        self.dim = dim

        # single projection to q, kv and a, rearranged into a max-heads layout
        # (each feature dimension acts as its own head)
        self.to_qkva = nn.Sequential(
            nn.Linear(dim, dim * 3, bias = False),
            Rearrange('b n (qkva d) -> qkva (b d) n 1', qkva = 3)
        )

        self.use_heinsen = use_heinsen
        self.use_jax = use_jax_associative_scan

        # select the scan implementation once, at construction time
        if use_jax_associative_scan:
            gate_loop_fn = get_jax_gate_loop_operator()
        elif use_heinsen:
            gate_loop_fn = partial(gate_loop_operator, heinsen = True)
        else:
            gate_loop_fn = gate_loop_operator

        self.gate_loop_fn = gate_loop_fn

        self.maybe_post_ln = nn.LayerNorm(dim) if post_ln else nn.Identity()
        # undo the max-heads layout back to (batch, seq, dim)
        self.split_heads = Rearrange('(b d) n 1 -> b n d', d = dim)

        # when True, the sequence is processed right-to-left
        self.reverse = reverse

    def forward(
        self,
        x,
        cache = None,
        return_cache = False
    ):
        should_reverse = self.reverse

        if should_reverse:
            x = torch.flip(x, dims = (-2,))

        normed = self.norm(x)

        q, kv, a = self.to_qkva(normed)

        # sigmoid keeps the data-dependent transition inside (0, 1)
        out, cache = self.gate_loop_fn(q, kv, a.sigmoid(), cache = cache)

        out = self.split_heads(out)
        out = self.maybe_post_ln(out)

        if should_reverse:
            out = torch.flip(out, dims = (-2,))

        if not return_cache:
            return out

        assert not should_reverse, 'caching only works with non-reversed seq'

        return out, cache

.\lucidrains\gateloop-transformer\gateloop_transformer\__init__.py

# 从 gateloop_transformer.gateloop_transformer 模块中导入 CausalFullAttention, GateLoopedAttention, Transformer 类
# 从 gateloop_transformer.simplified_gate_loop 模块中导入 SimpleGateLoopLayer 类
from gateloop_transformer.gateloop_transformer import (
    CausalFullAttention,
    GateLoopedAttention,
    Transformer
)

from gateloop_transformer.simplified_gate_loop import (
    SimpleGateLoopLayer
)

GateLoop Transformer

Implementation of GateLoop Transformer in Pytorch and Jax, to be tested on Enwik8 character level modeling.

Update: A transformer run with regular attention + data dependent xpos relative positions did not converge at all. Gate loop's associative scan is also unable to train on sequence lengths of even 128. I'm not sure if it can be done without a specialized CUDA kernel, much like autoregressive linear attention (RWKV and the like)

Update 2: Got a smaller GateLoop transformer (gate loop dimensions of 128) to run on sequence length of 256. It is converging very well with a quick eyeball. Will run some more rigorous experiments tomorrow.

Update 3: Fixed a misunderstanding and definitely seems to be converging better than vanilla linear attention (from my memories of those experiments).

Update 4: Ongoing experiments

Update 5: Author has reviewed the code, and there was another misunderstanding. They use maximum heads (heads == dimension). This is kind of a plot twist, as this is infeasible for normal attention. It also obviates the need for a fused CUDA kernel, as in autoregressive linear attention.

Update 6: Corrected gateloop transformer run looks amazing. Cautiously optimistic now.

Update 7: Ablating state transition shows expected negative result. Ablating complex valued states though, I see no difference, at least, early in the run.

Update 8: Directly projecting to kv with one projection for the max-heads setting (instead of keys and values separately followed by element-wise multiplication) yields similar results

Update 9: Head to head to 20k, just to make sure Gateloop doesn't get exceeded later on

Update 10: and it got passed by attention, at least, assuming the implementation in the repo is correct.

Update 11: I'm seeing a steady improvement increasing the head dimension, so I no longer believe max-heads is optimal. Increasing the head dimension brings us right back to linear attention and needing the fused CUDA kernel.

Update 12: Nikil spotted a potential error with the kv not being kept in complex (and real component taken at end). Rerunning experiments

Update 13: Still clearly worse

Update 14: See some synergy when mixing gateloop and attention on a small scale, when holding parameters constant. Will be adding a tiny bit of simplified gateloop layers to transformers to address a main weakness in attention for future projects.

Update 15: There may be a way to combine associative scan based works with the findings from the recently proposed taylor series linear attention. will carry out some independent research before end of January 2024 and share the results here.

Appreciation

Install

$ pip install gateloop-transformer

Usage

import torch
from gateloop_transformer import Transformer

model = Transformer(
    num_tokens = 256,
    dim = 624,
    depth = 6,
    use_gate_looped_attn = True
)

ids = torch.randint(0, 256, (1, 1024))
logits = model(ids) # (1, 1024, 256)

A simplified gate loop layer

import torch
from gateloop_transformer import SimpleGateLoopLayer

gateloop = SimpleGateLoopLayer(512)

x = torch.randn(1, 65536, 512)
x = gateloop(x) + x

Character-level Language Modeling

Install requirements

$ pip install -r requirements.txt

Then run the train.py script for autoregressive modeling on enwik8

$ python train.py

Todo

  • jax version with equinox
  • start with naive memory checkpointing of gate loop operation
  • retry the failed full attention experiments (with data dependent xpos), but with complex valued scales (didn't work)
  • separate out a minimal gateloop circuit, to augment attention, rather than to replace it, as done in Mega
  • experiments
    • do all the ablations and figure out how much the data controlled state transitions adds (as well as whether it needs to be complex)
    • do complete runs between transformer + rotary against gateloop with max heads, parameter held constant to 20k steps
  • just use jax's associative scan, wrapped with jax2torch, for now. pytorch team claim they will implement this eventually

Citations

@inproceedings{Katsch2023GateLoopFD,
    title   = {GateLoop: Fully Data-Controlled Linear Recurrence for Sequence Modeling},
    author  = {Tobias Katsch},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:265018962}
}
@inproceedings{Heinsen2023EfficientPO,
    title   = {Efficient Parallelization of a Ubiquitous Sequential Computation},
    author  = {Franz A. Heinsen},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:265213659}
}

.\lucidrains\gateloop-transformer\setup.py

# packaging metadata for distribution on PyPI
from setuptools import setup, find_packages

setup(
  name = 'gateloop-transformer',  # distribution name
  packages = find_packages(exclude=[]),  # include all discovered packages
  version = '0.2.4',
  license='MIT',
  description = 'GateLoop Transformer',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  long_description_content_type = 'text/markdown',
  url = 'https://github.com/lucidrains/gateloop-transformer',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'gated linear attention'
  ],
  install_requires=[
    'einops>=0.7.0',
    'rotary-embedding-torch',
    'torch>=2.1',
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\gateloop-transformer\train.py

# 导入所需的库
import math
import gzip
import random
import tqdm
import numpy as np
from functools import wraps, partial

import torch
from torch.optim import Adam, AdamW
from torch import Tensor
from torch.nn import Module, functional as F
from torch.utils.data import DataLoader, Dataset

# 导入加速库
from accelerate import Accelerator

# 导入自定义的 Transformer 模型
from gateloop_transformer import Transformer

# training hyperparameters
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRAD_ACCUM_EVERY = 4
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.
VALIDATE_EVERY = 100
PRIME_LENGTH = 128
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 256

WANDB = True
PROJECT_NAME = 'gateloop'
RUN_NAME = 'baseline gateloop'

# accelerator handles device placement and (optionally) wandb logging
accelerator = Accelerator(log_with='wandb' if WANDB else None)

# helper functions
def exists(v):
    """Return True unless `v` is None."""
    return not (v is None)

def cycle(loader):
    """Yield items from `loader` forever, restarting it whenever exhausted."""
    while True:
        yield from loader

def decode_token(token):
    """Map a token id to a printable character; control codes (< 32) become space."""
    return chr(max(token, 32))

def decode_tokens(tokens):
    """Decode a sequence of token ids into a single string."""
    return "".join(decode_token(t) for t in tokens)

# sampling helpers
def log(t, eps=1e-20):
    """Numerically safe log: input is clamped to at least `eps` first."""
    return t.clamp(min=eps).log()

def gumbel_noise(t):
    """Sample Gumbel(0, 1) noise with the same shape as `t` via -log(-log(U))."""
    uniform = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(uniform))

def gumbel_sample(t, temperature=1., dim=-1):
    """Sample an index via the Gumbel-max trick; temperature is floored at 1e-10."""
    scaled = t / max(temperature, 1e-10)
    return (scaled + gumbel_noise(t)).argmax(dim=dim)

def top_k(logits, thres=0.9):
    """Keep the top (1 - thres) fraction of logits; everything else becomes -inf."""
    num_keep = math.ceil((1 - thres) * logits.shape[-1])
    vals, idxs = logits.topk(num_keep)
    filtered = torch.full_like(logits, float('-inf'))
    filtered.scatter_(-1, idxs, vals)
    return filtered

def base_decoding(net: Module, prompt: Tensor, seq_len: int, temperature=1., filter_thres=0.9):
    """Autoregressively decode up to `seq_len` with top-k filtering + Gumbel sampling.

    Returns only the newly generated tokens (the prompt is excluded).
    """
    prompt_seq_len = prompt.shape[-1]
    out = prompt.clone()
    num_steps = max(0, seq_len - prompt_seq_len)

    for _ in range(num_steps):
        # only the logits at the final position matter for the next token
        logits = net(out)[:, -1]

        filtered = top_k(logits, thres=filter_thres)
        sample = gumbel_sample(filtered, temperature=temperature, dim=-1)

        out = torch.cat((out, sample[..., None]), dim=-1)

    return out[..., prompt_seq_len:]

# optimizer helpers
def separate_weight_decayable_params(params):
    """Split params into (decayable, non-decayable) lists by dimensionality.

    Biases and norm gains (ndim < 2) should not receive weight decay.
    Iterates `params` only once, so generators are fine.
    """
    wd_params, no_wd_params = [], []

    for p in params:
        (wd_params if p.ndim >= 2 else no_wd_params).append(p)

    return wd_params, no_wd_params

def get_optimizer(params, lr=1e-4, wd=0., betas=(0.9, 0.99), eps=1e-8, group_wd_params=True, **kwargs):
    """Build Adam (when wd == 0) or AdamW, optionally excluding ndim < 2 params from decay."""
    opt_kwargs = dict(lr=lr, betas=betas, eps=eps)

    # no decay at all -> plain Adam
    if wd == 0:
        return Adam(params, **opt_kwargs)

    opt_kwargs = {'weight_decay': wd, **opt_kwargs}

    if not group_wd_params:
        return AdamW(params, **opt_kwargs)

    # exclude biases / norm parameters from weight decay
    wd_params, no_wd_params = separate_weight_decayable_params(params)

    grouped = [
        {'params': wd_params},
        {'params': no_wd_params, 'weight_decay': 0},
    ]

    return AdamW(grouped, **opt_kwargs)

# instantiate the Transformer
hparams = dict(
    num_tokens=256,
    dim=512,
    depth=6,
    use_gate_looped_attn=True,
    gate_loop_heads=512,
    data_dependent_rel_pos=False,
    attn_softmax_normalize=True,
    ablate_complex=False,
    ablate_state_transition=False,
    rotary_emb=False,
    post_ln_norm=True
)

model = Transformer(**hparams)

# experiment tracking
num_parameters = sum(p.numel() for p in model.parameters())
print(f'number of parameters: {num_parameters}')

wandb_config = {**hparams, 'num_parameters': num_parameters}
accelerator.init_trackers(PROJECT_NAME, config=wandb_config)

if WANDB and exists(RUN_NAME) and len(accelerator.trackers) > 0:
    accelerator.trackers[0].run.name = RUN_NAME

# prepare enwik8 data
with gzip.open("./data/enwik8.gz") as file:
    # read the first 95e6 bytes as uint8; copy because frombuffer is read-only
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    # first 90e6 bytes for training, the remainder for validation
    np_train, np_valid = np.split(data, [int(90e6)])
    # wrap the numpy arrays as torch tensors
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)
# dataset that samples random fixed-length windows from a token tensor
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        # 1-D tensor of token ids
        self.data = data
        # length of each training sequence (one extra token is sampled for the target)
        self.seq_len = seq_len

    def __getitem__(self, index):
        # random window start; +1 token so input/target pairs can be made by shifting
        start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        return self.data[start : start + self.seq_len + 1].long()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# create train / validation datasets
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
# and their dataloaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# optimizer
optim = get_optimizer(
    model.parameters(),
    lr = LEARNING_RATE,
    wd = WEIGHT_DECAY
)

# let accelerate handle device placement / distributed wrapping
(
    model,
    optim,
    train_loader,
    val_loader
) = accelerator.prepare(
    model,
    optim,
    train_loader,
    val_loader
)

# infinite iterators over the dataloaders
train_loader = cycle(train_loader)
val_loader = cycle(val_loader)

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
    model.train()

    for _ in range(GRAD_ACCUM_EVERY):
        data = next(train_loader)

        loss = model(data, return_loss = True)

        # scale so accumulated gradients average over the accumulation steps
        accelerator.backward(loss / GRAD_ACCUM_EVERY)

    print(f"training loss: {loss.item():.3f}")
    accelerator.log(dict(loss = loss.item()), step = i)

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

    optim.step()
    optim.zero_grad()

    accelerator.wait_for_everyone()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            valid_data = next(val_loader)

            loss = model(valid_data, return_loss = True)
            print(f"validation loss: {loss.item():.3f}")
            accelerator.log(dict(valid_loss = loss.item()), step = i)

    accelerator.wait_for_everyone()

    if i % GENERATE_EVERY == 0:
        model.eval()

        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        inp = inp.to(accelerator.device)

        prime = decode_tokens(inp)
        # bug fix: the original passed a tuple to print alongside an
        # unformatted "%s" template, printing the literal template
        # instead of the decoded prompt
        print(f"{prime} \n\n {'*' * 100}")

        prompt = inp[None, ...]

        sampled = base_decoding(model, prompt, GENERATE_LENGTH)

        base_decode_output = decode_tokens(sampled[0])

        print("\n\n", base_decode_output, "\n")

.\lucidrains\genetic-algorithm-pytorch\bcga.py

"""
Bee Colonies Genetic Algorithm

Here we simulate different colonies to maintain diversity. At each generation, one allow a small subset of bees from each colony to immigrate to another
"""

import torch
import einx

# constants

GOAL = 'Attention is all you need'

COLONIES = 10
POP_SIZE = 250
MUTATION_PROB = 0.05
STRONG_MUTATION_PROB = 0.15
NUM_TOURNAMENT_PARTICIPANTS = 25
MIGRATE_EVERY = 5
FRAC_BEES_MIGRANTS = 0.1

# encode and decode functions

# helper: check whether a value is present
def exists(v):
    """Return True unless `v` is None."""
    return not (v is None)

# encode a string into a tensor of character codepoints
def encode(s):
    """Return a 1-D tensor of the codepoints of `s`."""
    return torch.tensor(list(map(ord, s)))

# decode a tensor of codepoints back into a string
def decode(t):
    """Return the string whose characters have the codepoints in `t`."""
    return ''.join(map(chr, t.tolist()))

# fitness function
def calc_fitness(genes, target):
    """Inverse squared distance to the target gene; inf on an exact match."""
    squared_error = (genes - target).square().sum(dim = -1)
    return 1. / squared_error

# derived constants

# number of characters in the target
gene_length = len(GOAL)
# encoded target
target_gene = encode(GOAL)

# expected number of mutated positions per gene
num_code_mutate = MUTATION_PROB * gene_length
strong_num_code_mutate = STRONG_MUTATION_PROB * gene_length

# number of bees that migrate at each migration event
num_bees_migrate = int((POP_SIZE - 1) * FRAC_BEES_MIGRANTS)

# queen bee genetic algorithm

generation = 1

# random initial populations, one per colony (queen slot held separately)
colonies = torch.randint(0, 255, (COLONIES, POP_SIZE - 1, gene_length))
colonies_arange = torch.arange(COLONIES)[..., None]

# random initial queens, one per colony, with their fitnesses
queens = torch.randint(0, 255, (COLONIES, gene_length))
queen_fitnesses = calc_fitness(queens, target_gene)

while True:
    print(f"\n\ngeneration {generation}\n")

    # sort population by fitness

    # fitness of every bee in every colony
    colony_fitnesses = calc_fitness(colonies, target_gene)

    # sort each colony by descending fitness
    indices = colony_fitnesses.sort(descending = True).indices
    colonies, colony_fitnesses = colonies[colonies_arange, indices], colony_fitnesses[colonies_arange, indices]

    # display every generation

    for i, (pool, fitnesses) in enumerate(zip(colonies[:, :10], colony_fitnesses[:, :10])):
        print(f'\ncolony {i + 1}:\n')

        if exists(queens):
            queen, queen_fitness = queens[i], queen_fitnesses[i]
            print(f"Q: {decode(queen)} ({queen_fitness.item():.3f})\n")

        for gene, fitness in zip(pool, fitnesses):
            print(f"{decode(gene)} ({fitness.item():.3f})")

    # if one of the children has a better fitness than queen, that child becomes the new queen
    # and the queen replaces the worst bee in the population, kept around for at least one generation more

    has_new_queen = colony_fitnesses[:, 0] > queen_fitnesses

    # offsetting the arange by 1 where a new queen exists selects the old queen
    # (prepended) plus the best children, dropping the appended copy
    pop_arange = torch.arange(POP_SIZE)
    pop_arange_with_offset = pop_arange + has_new_queen[:, None]

    colonies = torch.cat((
        queens[:, None, :],
        colonies,
        queens[:, None, :]
    ), dim = -2)

    colony_fitnesses = torch.cat((
        queen_fitnesses[:, None],
        colony_fitnesses,
        queen_fitnesses[:, None]
    ), dim = -1)

    colonies = colonies[colonies_arange, pop_arange_with_offset]
    colony_fitnesses = colony_fitnesses[colonies_arange, pop_arange_with_offset]

    # position 0 is now the (possibly new) queen of each colony
    queens, colonies = colonies[:, 0], colonies[:, 1:]
    queen_fitnesses, colony_fitnesses = colony_fitnesses[:, 0], colony_fitnesses[:, 1:]

    # solved if any fitness is inf

    if (queen_fitnesses == float('inf')).any():
        print(f'\nsolved at generation {generation}')
        break

    # deterministic tournament selection - let top winner become parent with queen

    colonies_arange_ = colonies_arange[..., None]
    contender_ids = torch.randn((COLONIES, POP_SIZE - 1, POP_SIZE - 1)).argsort(dim = -1)[..., :NUM_TOURNAMENT_PARTICIPANTS]
    participants, tournaments = colonies[colonies_arange_, contender_ids], colony_fitnesses[colonies_arange_, contender_ids]
    top_winner = tournaments.topk(1, dim = -1, largest = True, sorted = False).indices
    parents = einx.get_at('... [t] g, ... 1 -> ... g', participants, top_winner)

    # potential parents with queen is strongly mutated ("Mutant Bee")

    strong_mutate_mask = torch.randn(parents.shape).argsort(dim = -1) < strong_num_code_mutate
    noise = torch.randint(0, 2, parents.shape) * 2 - 1
    mutated_parents = torch.where(strong_mutate_mask, parents + noise, parents)
    mutated_parents.clamp_(0, 255)
    # mix 50% of the gene positions at random rather than a contiguous
    # crossover at the midpoint

    # mask selecting which gene positions come from the queen
    rand_mix_mask = torch.randn(mutated_parents.shape).argsort(dim=-1) < (gene_length // 2)

    # crossover: blend queen genes with the mutated parents per the random mask
    colonies = einx.where('c p g, c g, c p g', rand_mix_mask, queens, mutated_parents)

    # mutate genes in the population

    # mask selecting which gene positions to mutate
    mutate_mask = torch.randn(colonies.shape).argsort(dim=-1) < num_code_mutate
    # random +/-1 noise applied at the mutated positions
    noise = torch.randint(0, 2, colonies.shape) * 2 - 1

    colonies = torch.where(mutate_mask, colonies + noise, colonies)
    # keep codepoints within byte range
    colonies.clamp_(0, 255)

    # allow a fraction of the bees to migrate to the neighboring colony

    if not (generation % MIGRATE_EVERY) and num_bees_migrate > 0:
        # split off the migrants (the least fit tail of each colony)
        colonies, migrant_colonies = colonies[:, :-num_bees_migrate], colonies[:, -num_bees_migrate:]
        # rotate the migrants over to the next colony
        migrant_colonies = torch.roll(migrant_colonies, 1, dims=0)
        # merge the migrants back into the populations
        colonies = torch.cat((colonies, migrant_colonies), dim=1)

    # next generation

    generation += 1

.\lucidrains\genetic-algorithm-pytorch\bega.py

"""
Queen-bee evolution for genetic algorithms - Jung 2003

Inspired by evolution of bees, the fittest solution is designated the "queen", and the rest of the population contends to mate with it. The strong exploitation is balanced by a higher than normal mutation rate.
For some problems, the paper claims convergence at 2-3 orders of magnitude faster

https://www.researchgate.net/publication/3385719_Queen-bee_evolution_for_genetic_algorithms
"""

import torch
from einops import repeat
from einx import get_at

# constants

GOAL = 'Attention is all you need'  # 目标字符串

POP_SIZE = 100  # 种群大小
MUTATION_PROB = 0.04  # 突变概率
STRONG_MUTATION_RATE = 0.1  # 强突变率
STRONG_MUTATION_PROB = 0.25  # 强突变概率
NUM_TOURNAMENT_PARTICIPANTS = 25  # 锦标赛参与者数量

# encode and decode functions

def encode(s):
    """Encode a string into a 1-D tensor of character codepoints."""
    return torch.tensor(list(map(ord, s)))

def decode(t):
    """Decode a 1-D tensor of codepoints back into a string."""
    return ''.join(map(chr, t.tolist()))

# derived constants

gene_length = len(GOAL)  # number of characters in the target
gene_midpoint = gene_length // 2  # crossover point
target_gene = encode(GOAL)  # encoded target

strong_mutate_pool_size = STRONG_MUTATION_RATE * POP_SIZE  # individuals that receive strong mutation
num_code_mutate = MUTATION_PROB * gene_length  # expected mutated positions per gene
strong_num_code_mutate = STRONG_MUTATION_PROB * gene_length  # expected positions under strong mutation

# queen bee genetic algorithm

generation = 1  # generation counter

pool = torch.randint(0, 255, (POP_SIZE, gene_length))  # random initial population

queen = queen_fitness = None  # no queen until the first generation is ranked

while True:
    print(f"\n\ngeneration {generation}\n")  # show the current generation

    # sort population by fitness

    fitnesses = 1. / torch.square(pool - target_gene).sum(dim = -1)  # inverse squared distance to the target

    indices = fitnesses.sort(descending = True).indices  # rank the population by descending fitness
    pool, fitnesses = pool[indices], fitnesses[indices]

    # display every generation

    if queen is not None:
        print("queen:")
        print(f"{decode(queen)} ({queen_fitness.item():.3f})\n")  # show the queen and her fitness

    for gene, fitness in zip(pool, fitnesses):
        print(f"{decode(gene)} ({fitness.item():.3f})")  # show each gene and its fitness

    # if one of the children has a better fitness than queen, that child becomes the new queen
    # and the queen replaces the worst bee in the population, kept around for at least one generation more

    if queen is not None and queen_fitness < fitnesses[0]:
        pool = torch.cat((pool, queen[None, :]), dim = 0)  # demote the old queen back into the population
        fitnesses = torch.cat((fitnesses, queen_fitness[None]), dim = 0)
        queen = queen_fitness = None

    # separate the queen bee from the rest of the population

    if queen is None:
        queen, pool = pool[0], pool[1:]  # the fittest individual becomes the queen
        queen_fitness, fitnesses = fitnesses[0], fitnesses[1:]

    # solved if any queen fitness is inf

    if (queen_fitness == float('inf')).any():  # exact match reached
        break

    # deterministic tournament selection - let top winner become parent with queen

    contender_ids = torch.randn((POP_SIZE - 1, POP_SIZE - 1)).argsort(dim = -1)[..., :NUM_TOURNAMENT_PARTICIPANTS]  # random tournament participants
    participants, tournaments = pool[contender_ids], fitnesses[contender_ids]
    top_winner = tournaments.topk(1, dim = -1, largest = True, sorted = False).indices  # best contender per tournament
    parents = get_at('p [t] g, p 1 -> p g', participants, top_winner)  # gather the winning parent genes

    # cross over all chosen drones with the queen

    queen_parents = repeat(queen, '... -> p ...', p = POP_SIZE - 1)  # one copy of the queen per parent
    queen_and_parents = torch.stack((queen_parents, parents), dim = 1)  # pair the queen with each parent

    rand_crossover_order = torch.randn(queen_and_parents.shape[:2]).argsort(dim = -1)  # randomly swap which side contributes which half

    batch_arange = torch.arange(POP_SIZE - 1)[..., None]
    queen_and_parents = queen_and_parents[batch_arange, rand_crossover_order]
    queen_parents, parents = queen_and_parents.unbind(dim = 1)

    pool = torch.cat((queen_parents[:, :gene_midpoint], parents[:, gene_midpoint:]), dim = -1)  # midpoint crossover produces the new pool

    # mutate genes in population

    mutate_mask = torch.randn(pool.shape).argsort(dim = -1) < num_code_mutate  # positions receiving normal mutation
    noise = torch.randint(0, 2, pool.shape) * 2 - 1
    mutated_pool = torch.where(mutate_mask, pool + noise, pool)  # normally mutated variant

    strong_mutate_mask = torch.randn(pool.shape).argsort(dim = -1) < strong_num_code_mutate  # positions receiving strong mutation
    noise = torch.randint(0, 2, pool.shape) * 2 - 1
    strong_mutated_pool = torch.where(strong_mutate_mask, pool + noise, pool)  # strongly mutated variant
    # boolean mask choosing which individuals take the strongly mutated variant
    strong_mutate_pool_mask = torch.randn(POP_SIZE - 1).argsort(dim=-1) < strong_mutate_pool_size

    # per individual, pick between the strongly and normally mutated pools
    pool = torch.where(strong_mutate_pool_mask[:, None], strong_mutated_pool, mutated_pool)
    # keep codepoints within byte range
    pool.clamp_(0, 255)

    # next generation
    generation += 1

.\lucidrains\genetic-algorithm-pytorch\ga.py

"""
Genetic Algorithm - formalized by John H. Holland in 1992, but has been talked about since 1960-70s

https://www.researchgate.net/figure/Hollands-canonical-genetic-algorithm-Holland-1992_fig4_221174380
"""

import torch
from einx import get_at

# constants

GOAL = 'Attention is all you need'  # 目标字符串

POP_SIZE = 100  # 种群大小
MUTATION_RATE = 0.04  # 变异率
FRAC_FITTEST_SURVIVE = 0.25  # 最适应个体存活比例
FRAC_TOURNAMENT = 0.25  # 锦标赛选择比例
ELITE_FRAC = 0.05  # 精英比例

# encode and decode functions

def encode(s):
    """Encode a string as a 1-D tensor of Unicode code points."""
    code_points = list(map(ord, s))
    return torch.tensor(code_points)

def decode(t):
    """Decode a tensor of Unicode code points back into a string."""
    return ''.join(map(chr, t.tolist()))

# derived constants

gene_length = len(GOAL)  # length of the target string
gene_midpoint = gene_length // 2  # single-point crossover position (midpoint)
target_gene = encode(GOAL)  # encoded target string

keep_fittest_len = int(POP_SIZE * FRAC_FITTEST_SURVIVE)  # number of fittest individuals kept
num_elite = int(ELITE_FRAC * POP_SIZE)  # number of elites
num_repro_and_mutate = keep_fittest_len - num_elite  # individuals eligible to reproduce and mutate
num_tournament_contenders = int(num_repro_and_mutate * FRAC_TOURNAMENT)  # contenders per tournament
num_children = POP_SIZE - keep_fittest_len  # number of children bred per generation
num_mutate = MUTATION_RATE * gene_length  # expected mutated positions per gene (float threshold)

assert num_tournament_contenders >= 2  # a tournament needs at least two contenders

# genetic algorithm

generation = 1  # generation counter

pool = torch.randint(0, 255, (POP_SIZE, gene_length))  # initialize the population with random genes
# NOTE(review): randint's upper bound is exclusive, so initial genes span 0..254 even though
# clamp_ below allows 255 — harmless for this GOAL (max code point 121), but confirm intent

while True:
    print(f"\n\ngeneration {generation}\n")  # print the current generation

    # sort population by fitness

    fitnesses = 1. / torch.square(pool - target_gene).sum(dim = -1)  # inverse squared distance; inf on an exact match

    indices = fitnesses.sort(descending = True).indices  # order the population by fitness
    pool, fitnesses = pool[indices], fitnesses[indices]

    # keep the fittest

    pool, fitnesses = pool[:keep_fittest_len], fitnesses[:keep_fittest_len]  # survivors only

    # display every generation

    for gene, fitness in zip(pool, fitnesses):
        print(f"{decode(gene)} ({fitness.item():.3f})")  # show each individual's gene and fitness

    # solved if any fitness is inf

    if (fitnesses == float('inf')).any():  # an infinite fitness means the target was matched exactly
        break

    # elites can pass directly to next generation

    elites, pool = pool[:num_elite], pool[num_elite:]  # elites skip crossover and mutation
    elites_fitnesses, fitnesses = fitnesses[:num_elite], fitnesses[num_elite:]

    # deterministic tournament selection - let top 2 winners become parents

    contender_ids = torch.randn((num_children, num_repro_and_mutate)).argsort(dim = -1)[..., :num_tournament_contenders]  # random contenders for each child's tournament
    participants, tournaments = pool[contender_ids], fitnesses[contender_ids]
    top2_winners = tournaments.topk(2, dim = -1, largest = True, sorted = False).indices  # the two fittest contenders become parents
    parents = get_at('p [t] g, p w -> p w g', participants, top2_winners)  # gather parent genes

    # cross over recombination of parents

    parent1, parent2 = parents.unbind(dim = 1)  # split the parent pairs
    children = torch.cat((parent1[:, :gene_midpoint], parent2[:, gene_midpoint:]), dim = -1)  # single-point crossover at the midpoint

    pool = torch.cat((pool, children))  # add the children to the population

    # mutate genes in population

    mutate_mask = torch.randn(pool.shape).argsort(dim = -1) < num_mutate  # select ~num_mutate positions per gene
    noise = torch.randint(0, 2, pool.shape) * 2 - 1  # random +/-1 noise
    pool = torch.where(mutate_mask, pool + noise, pool)  # apply the mutation
    pool.clamp_(0, 255)  # keep gene values within the 0-255 byte range

    # add back the elites

    pool = torch.cat((elites, pool))  # elites rejoin unchanged

    generation += 1  # next generation

.\lucidrains\genetic-algorithm-pytorch\inbreed.py

"""
Genetic Algorithm

but without first generation inbreeding
"""

import torch
import einx
from einx import get_at, rearrange

# constants

GOAL = 'Attention is all you need'  # 目标字符串

POP_SIZE = 100  # 种群大小
MUTATION_RATE = 0.04  # 变异率
FRAC_FITTEST_SURVIVE = 0.25  # 存活最适应个体的比例
FRAC_TOURNAMENT = 0.25  # 锦标赛选择的比例
ELITE_FRAC = 0.05  # 精英个体的比例

# encode and decode functions

def encode(s):
    """Turn a string into a 1-D tensor of Unicode code points."""
    code_points = [ord(ch) for ch in s]
    return torch.tensor(code_points)

def decode(t):
    """Turn a tensor of Unicode code points back into a string."""
    chars = [chr(code) for code in t.tolist()]
    return ''.join(chars)

# derived constants

gene_length = len(GOAL)  # length of the target string
gene_midpoint = gene_length // 2  # single-point crossover position (midpoint)
target_gene = encode(GOAL)  # encoded target string

keep_fittest_len = int(POP_SIZE * FRAC_FITTEST_SURVIVE)  # number of fittest individuals kept
num_elite = int(ELITE_FRAC * POP_SIZE)  # number of elites
num_repro_and_mutate = keep_fittest_len - num_elite  # individuals eligible to reproduce and mutate
num_tournament_contenders = int(num_repro_and_mutate * FRAC_TOURNAMENT)  # contenders per tournament
num_children = POP_SIZE - keep_fittest_len  # number of children bred per generation
num_mutate = MUTATION_RATE * gene_length  # expected mutated positions per gene (float threshold)

assert num_tournament_contenders >= 2  # a tournament needs at least two contenders

# genetic algorithm

generation = 1  # generation counter

parent_ids = torch.full((POP_SIZE, 2), -1, dtype=torch.long)  # parent ids per individual; -1 is the "no recorded parent" sentinel
pool = torch.randint(0, 255, (POP_SIZE, gene_length))  # individuals in the population

while True:
    print(f"\n\ngeneration {generation}\n")  # print the current generation

    # sort population by fitness

    fitnesses = 1. / torch.square(pool - target_gene).sum(dim=-1)  # inverse squared distance; inf on an exact match

    indices = fitnesses.sort(descending=True).indices  # order the population by fitness
    pool, parent_ids, fitnesses = pool[indices], parent_ids[indices], fitnesses[indices]

    # keep the fittest

    pool, parent_ids, fitnesses = pool[:keep_fittest_len], parent_ids[:keep_fittest_len], fitnesses[:keep_fittest_len]  # survivors only

    # display every generation

    for gene, fitness in zip(pool, fitnesses):
        print(f"{decode(gene)} ({fitness.item():.3f})")  # show each individual's gene and fitness

    # solved if any fitness is inf

    if (fitnesses == float('inf')).any():  # an infinite fitness means the target was matched exactly
        break

    # elites can pass directly to next generation

    elites, pool = pool[:num_elite], pool[num_elite:]  # elites skip crossover and mutation
    elites_fitnesses, fitnesses = fitnesses[:num_elite], fitnesses[num_elite:]
    elites_parent_ids, parent_ids = parent_ids[:num_elite], parent_ids[num_elite:]

    elites_parent_ids.fill_(-1)  # elites forget their lineage

    # deterministic tournament selection
    # 2 tournaments - the second tournament removes all contestants with shared parents with 1st winner

    first_contender_ids = torch.randn((num_children, num_repro_and_mutate)).argsort(dim=-1)[..., :num_tournament_contenders]  # contender ids for the first tournament
    first_participants, participants_parent_ids, tournaments = pool[first_contender_ids], parent_ids[first_contender_ids], fitnesses[first_contender_ids]

    first_winner = tournaments.topk(1, dim=-1, largest=True, sorted=False).indices  # winner of the first tournament
    first_winner = rearrange('p 1 -> p', first_winner)

    first_parent_ids = get_at('p [t] i, p -> p i', participants_parent_ids, first_winner)  # parent ids of the first winners

    # second tournament, masking out any siblings to first winners

    contender_scores = torch.randn((num_children, num_repro_and_mutate))  # random contender scores (lower is more likely to be picked)
    self_mask = rearrange('i -> i 1', first_winner) == torch.arange(num_repro_and_mutate)  # exclude the first winner itself
    contender_scores = torch.where(self_mask, 1e6, contender_scores)

    sibling_mask = (rearrange('p i -> p 1 i 1', first_parent_ids) == rearrange('c j -> 1 c 1 j', parent_ids))  # candidates sharing a parent with the first winner
    valid_parent_mask = (rearrange('p i -> p 1 i 1', first_parent_ids) != -1) & (rearrange('c j -> 1 c 1 j', parent_ids) != -1)  # ignore the -1 "no parent" sentinel
    num_shared_parents = (sibling_mask & valid_parent_mask).float().sum(dim=(-1, -2))  # how many parents are shared
    contender_scores += num_shared_parents * 1e3

    second_contender_ids = contender_scores.argsort(dim=-1)[..., :num_tournament_contenders]  # contender ids for the second tournament
    second_participants, second_tournaments = pool[second_contender_ids], fitnesses[second_contender_ids]
    second_winner = second_tournaments.topk(1, dim=-1, largest=True, sorted=False).indices  # winner of the second tournament
    second_winner = rearrange('p 1 -> p', second_winner)

    # get parents

    first_ids = get_at('p [t], p -> p', first_contender_ids, first_winner)  # pool index of each first winner
    second_ids = get_at('p [t], p -> p', second_contender_ids, second_winner)  # pool index of each second winner

    new_parent_ids = torch.stack((first_ids, second_ids), dim=-1)  # new parent-id pairs for the children
    # gather parent 1 from the first tournament's participants and winners
    parent1 = get_at('p [t] g, p -> p g', first_participants, first_winner)
    # gather parent 2 from the second tournament's participants and winners
    parent2 = get_at('p [t] g, p -> p g', second_participants, second_winner)

    # cross over recombination of the parents

    # concatenate the first half of parent 1 with the second half of parent 2 to form the children
    children = torch.cat((parent1[:, :gene_midpoint], parent2[:, gene_midpoint:]), dim=-1)

    # add the children to the population
    pool = torch.cat((pool, children))

    # reset the survivors' parent ids and append the children's new parent ids
    parent_ids.fill_(-1)
    parent_ids = torch.cat((parent_ids, new_parent_ids))

    # mutate genes in the population

    # mask selecting ~num_mutate positions per gene for mutation
    mutate_mask = torch.randn(pool.shape).argsort(dim=-1) < num_mutate
    # random +/-1 noise used for mutation
    noise = torch.randint(0, 2, pool.shape) * 2 - 1
    # apply the noise where the mask selects
    pool = torch.where(mutate_mask, pool + noise, pool)
    # keep gene values within the 0-255 byte range
    pool.clamp_(0, 255)

    # add the elites back to the population

    # elites rejoin unchanged
    pool = torch.cat((elites, pool))
    # their (cleared) parent ids rejoin as well
    parent_ids = torch.cat((elites_parent_ids, parent_ids))

    # increment the generation counter
    generation += 1

.\lucidrains\genetic-algorithm-pytorch\qbmb.py

"""
Queen-bee and Mutant-bee evolution for genetic algorithms - Jung 2007

4 years after proposing the Queen bee evolution genetic algorithm, Jung proposes a simplification to get rid of a few hyperparameters

In the new scheme, all the selected bees to mate with the queen undergo strong mutation prior to crossover
This scheme therefore better preserves the queen's genetic code. He shows through various experiments that this performs just as well as the original algorithm while being simpler

https://www.researchgate.net/publication/290131255_Queen-bee_and_Mutant-bee_Evolution_for_Genetic_Algorithms
"""

import torch
from einops import repeat
from einx import get_at

# constants

GOAL = 'Attention is all you need'

POP_SIZE = 100
MUTATION_PROB = 0.04
STRONG_MUTATION_PROB = 0.25
NUM_TOURNAMENT_PARTICIPANTS = 25

# encode and decode functions

def encode(s):
    """Encode string *s* as a tensor of character ordinals."""
    return torch.tensor(list(map(ord, s)))

def decode(t):
    """Decode a tensor of character ordinals back to a string."""
    return ''.join(chr(code) for code in t.tolist())

# derived constants

gene_length = len(GOAL)  # length of the target string
gene_midpoint = gene_length // 2  # single-point crossover position (midpoint)
target_gene = encode(GOAL)  # encoded target string

num_code_mutate = MUTATION_PROB * gene_length  # expected mutated positions per gene (normal mutation)
strong_num_code_mutate = STRONG_MUTATION_PROB * gene_length  # expected mutated positions per gene (strong mutation)

# queen bee genetic algorithm

generation = 1  # generation counter

pool = torch.randint(0, 255, (POP_SIZE, gene_length))  # initialize the population with random genes

queen = queen_fitness = None  # queen and her fitness, selected below

while True:
    print(f"\n\ngeneration {generation}\n")

    # sort population by fitness

    fitnesses = 1. / torch.square(pool - target_gene).sum(dim = -1)  # inverse squared distance; inf on an exact match

    indices = fitnesses.sort(descending = True).indices
    pool, fitnesses = pool[indices], fitnesses[indices]

    # display every generation

    if queen is not None:
        print("queen:")
        print(f"{decode(queen)} ({queen_fitness.item():.3f})\n")

    for gene, fitness in zip(pool, fitnesses):
        print(f"{decode(gene)} ({fitness.item():.3f})")

    # if one of the children has a better fitness than queen, that child becomes the new queen
    # and the queen replaces the worst bee in the population, kept around for at least one generation more

    if queen is not None and queen_fitness < fitnesses[0]:
        pool = torch.cat((pool, queen[None, :]), dim = 0)
        fitnesses = torch.cat((fitnesses, queen_fitness[None]), dim = 0)
        queen = queen_fitness = None

    # separate the queen bee from the rest of the population

    if queen is None:
        queen, pool = pool[0], pool[1:]
        queen_fitness, fitnesses = fitnesses[0], fitnesses[1:]

    # solved if any fitness is inf

    if (queen_fitness == float('inf')).any():  # queen matched the target exactly
        break

    # deterministic tournament selection - let top winner become parent with queen

    contender_ids = torch.randn((POP_SIZE - 1, POP_SIZE - 1)).argsort(dim = -1)[..., :NUM_TOURNAMENT_PARTICIPANTS]  # random contenders per tournament
    participants, tournaments = pool[contender_ids], fitnesses[contender_ids]
    top_winner = tournaments.topk(1, dim = -1, largest = True, sorted = False).indices  # fittest contender wins
    parents = get_at('... [t] g, ... 1 -> ... g', participants, top_winner)  # gather winner genes

    # potential parents with queen is strongly mutated ("Mutant Bee")

    strong_mutate_mask = torch.randn(parents.shape).argsort(dim = -1) < strong_num_code_mutate  # ~strong_num_code_mutate positions per gene
    noise = torch.randint(0, 2, parents.shape) * 2 - 1  # random +/-1 noise
    mutated_parents = torch.where(strong_mutate_mask, parents + noise, parents)
    mutated_parents.clamp_(0, 255)

    # cross over all chosen drones with the queen

    queen_parents = repeat(queen, '... -> p ...', p = POP_SIZE - 1)  # one copy of the queen per drone
    queen_and_parents = torch.stack((queen_parents, mutated_parents), dim = 1)

    # in my experiments, the crossover point must be random between queen and drones for this to work
    # todo: get caught up with all the different types of crossover operators

    rand_crossover_order = torch.randn(queen_and_parents.shape[:2]).argsort(dim = -1)  # randomly swap queen/drone roles per pair

    batch_arange = torch.arange(POP_SIZE - 1)[..., None]
    queen_and_parents = queen_and_parents[batch_arange, rand_crossover_order]
    # unbind the (possibly swapped) queen copies and mutated drones along dim 1
    queen_parents, mutated_parents = queen_and_parents.unbind(dim = 1)

    # single-point crossover: first half from one, second half from the other
    pool = torch.cat((queen_parents[:, :gene_midpoint], mutated_parents[:, gene_midpoint:]), dim = -1)

    # mutate genes in the population

    # mask selecting ~num_code_mutate positions per gene for normal mutation
    mutate_mask = torch.randn(pool.shape).argsort(dim = -1) < num_code_mutate
    # random +/-1 noise
    noise = torch.randint(0, 2, pool.shape) * 2 - 1

    # apply the mutation where the mask selects
    pool = torch.where(mutate_mask, pool + noise, pool)
    # keep gene values within the 0-255 byte range
    pool.clamp_(0, 255)

    # next generation

    generation += 1

genetic-algorithm-pytorch

A simple genetic algorithm written in PyTorch.

running

$ python ga.py

.\lucidrains\geometric-vector-perceptron\examples\data_handler.py

# 作者:Eric Alcaide

# 从 https://github.com/jonathanking/sidechainnet 借用了大部分代码
# 下面是其许可证:

# 版权所有 2020 Jonathan King
# 允许以源代码和二进制形式重新分发和使用,无论是否进行修改,只要满足以下条件:
#
# 1. 源代码的再分发必须保留上述版权声明、此条件列表和以下免责声明。
#
# 2. 以二进制形式再分发时,必须在提供的文档和/或其他材料中复制上述版权声明、此条件列表和以下免责声明。
#
# 3. 未经特定事先书面许可,不得使用版权持有人或其贡献者的名称来认可或推广从本软件派生的产品。
#
# 版权持有人和贡献者提供的本软件是按原样提供的,不提供任何明示或暗示的担保,包括但不限于对适销性和特定用途的暗示担保。
# 在任何情况下,无论是在合同、严格责任还是侵权(包括疏忽或其他方式)的情况下,版权持有人或贡献者均不对任何直接、间接、附带、特殊、惩罚性或后果性损害(包括但不限于替代商品或服务的采购、使用、数据或利润损失或业务中断)负责,即使已被告知可能发生此类损害。

import warnings
warnings.filterwarnings("ignore")

import torch
import numpy as np 
from einops import repeat, rearrange

######################
## structural utils ##
######################

def get_dihedral(c1, c2, c3, c4):
    """Return the dihedral (torsion) angle in radians.

    Uses the atan2 formulation from:
    https://en.wikipedia.org/wiki/Dihedral_angle#In_polymer_physics

    Inputs:
    * c1, c2, c3, c4: (batch, 3) or (3,)
    """
    b1 = c2 - c1
    b2 = c3 - c2
    b3 = c4 - c3

    cross_23 = torch.cross(b2, b3, dim=-1)
    cross_12 = torch.cross(b1, b2, dim=-1)

    # y = (|b2| * b1) . (b2 x b3), x = (b1 x b2) . (b2 x b3)
    y = ((torch.norm(b2, dim=-1, keepdim=True) * b1) * cross_23).sum(dim=-1)
    x = (cross_12 * cross_23).sum(dim=-1)
    return torch.atan2(y, x)


def get_angle(c1, c2, c3):
    """Return the angle (radians) at c2 formed by c1-c2-c3.

    Deliberately avoids arccos (which yields the "smallest angle") and uses
    atan2(|u1 x u2|, dot) instead, per:
    https://johnblackburne.blogspot.com/2012/05/angle-between-two-3d-vectors.html

    The dot product is negated so the angle comes out reversed — this matches
    a sidechainnet convention (see original source comments).

    Inputs:
    * c1, c2, c3: (batch, 3) or (3,)
    """
    u1 = c2 - c1
    u2 = c3 - c2

    cross_norm = torch.norm(torch.cross(u1, u2, dim=-1), dim=-1)
    neg_dot = -(u1 * u2).sum(dim=-1)
    return torch.atan2(cross_norm, neg_dot)


def kabsch_torch(X, Y):
    """Kabsch alignment of X onto Y.

    Assumes X, Y are both (D, N) - usually (3, N).
    Returns the centered-and-rotated X and the centered Y.
    """
    # center both point clouds at the origin
    Xc = X - X.mean(dim=-1, keepdim=True)
    Yc = Y - Y.mean(dim=-1, keepdim=True)
    # covariance matrix between the two clouds
    C = torch.matmul(Xc, Yc.t())
    # optimal rotation via SVD on the detached covariance — careful! W must be transposed
    V, S, W = torch.svd(C.detach())
    # fix improper rotations (reflections): flip the sign of the last singular vector
    if (torch.det(V) * torch.det(W)) < 0.0:
        S[-1] = S[-1] * (-1)
        V[:, -1] = V[:, -1] * (-1)
    # rotation matrix U
    U = torch.matmul(V, W.t())
    # apply the rotation to the centered X
    Xc = torch.matmul(Xc.t(), U).t()
    # return the centered + aligned X and the centered Y
    return Xc, Yc
# root-mean-square deviation between two batched point clouds
def rmsd_torch(X, Y):
    """ Assumes x,y are both (batch, d, n) - usually (batch, 3, N). """
    squared_diff = (X - Y) ** 2
    return torch.sqrt(squared_diff.mean(dim=(-1, -2)))


############
### INFO ###
############

# per-amino-acid sidechain build info: bond lengths, bond angles, torsions and atom names
# (keyed by one-letter amino acid code; '_' is the padding token)
SC_BUILD_INFO = {
    'A': {
        'angles-names': ['N-CA-CB'],
        'angles-types': ['N -CX-CT'],
        'angles-vals': [1.9146261894377796],
        'atom-names': ['CB'],
        'bonds-names': ['CA-CB'],
        'bonds-types': ['CX-CT'],
        'bonds-vals': [1.526],
        'torsion-names': ['C-N-CA-CB'],
        'torsion-types': ['C -N -CX-CT'],
        'torsion-vals': ['p']
    },
    'R': {
        'angles-names': [
            'N-CA-CB', 'CA-CB-CG', 'CB-CG-CD', 'CG-CD-NE', 'CD-NE-CZ', 'NE-CZ-NH1',
            'NE-CZ-NH2'
        ],
        'angles-types': [
            'N -CX-C8', 'CX-C8-C8', 'C8-C8-C8', 'C8-C8-N2', 'C8-N2-CA', 'N2-CA-N2',
            'N2-CA-N2'
        ],
        'angles-vals': [
            1.9146261894377796, 1.911135530933791, 1.911135530933791, 1.9408061282176945,
            2.150245638457014, 2.0943951023931953, 2.0943951023931953
        ],
        'atom-names': ['CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2'],
        'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD', 'CD-NE', 'NE-CZ', 'CZ-NH1', 'CZ-NH2'],
        'bonds-types': ['CX-C8', 'C8-C8', 'C8-C8', 'C8-N2', 'N2-CA', 'CA-N2', 'CA-N2'],
        'bonds-vals': [1.526, 1.526, 1.526, 1.463, 1.34, 1.34, 1.34],
        'torsion-names': [
            'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD', 'CB-CG-CD-NE', 'CG-CD-NE-CZ',
            'CD-NE-CZ-NH1', 'CD-NE-CZ-NH2'
        ],
        'torsion-types': [
            'C -N -CX-C8', 'N -CX-C8-C8', 'CX-C8-C8-C8', 'C8-C8-C8-N2', 'C8-C8-N2-CA',
            'C8-N2-CA-N2', 'C8-N2-CA-N2'
        ],
        'torsion-vals': ['p', 'p', 'p', 'p', 'p', 'p', 'i']
    },
    'N': {
        'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-OD1', 'CB-CG-ND2'],
        'angles-types': ['N -CX-2C', 'CX-2C-C ', '2C-C -O ', '2C-C -N '],
        'angles-vals': [
            1.9146261894377796, 1.9390607989657, 2.101376419401173, 2.035053907825388
        ],
        'atom-names': ['CB', 'CG', 'OD1', 'ND2'],
        'bonds-names': ['CA-CB', 'CB-CG', 'CG-OD1', 'CG-ND2'],
        'bonds-types': ['CX-2C', '2C-C ', 'C -O ', 'C -N '],
        'bonds-vals': [1.526, 1.522, 1.229, 1.335],
        'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-OD1', 'CA-CB-CG-ND2'],
        'torsion-types': ['C -N -CX-2C', 'N -CX-2C-C ', 'CX-2C-C -O ', 'CX-2C-C -N '],
        'torsion-vals': ['p', 'p', 'p', 'i']
    },
    'D': {
        'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-OD1', 'CB-CG-OD2'],
        'angles-types': ['N -CX-2C', 'CX-2C-CO', '2C-CO-O2', '2C-CO-O2'],
        'angles-vals': [
            1.9146261894377796, 1.9390607989657, 2.0420352248333655, 2.0420352248333655
        ],
        'atom-names': ['CB', 'CG', 'OD1', 'OD2'],
        'bonds-names': ['CA-CB', 'CB-CG', 'CG-OD1', 'CG-OD2'],
        'bonds-types': ['CX-2C', '2C-CO', 'CO-O2', 'CO-O2'],
        'bonds-vals': [1.526, 1.522, 1.25, 1.25],
        'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-OD1', 'CA-CB-CG-OD2'],
        'torsion-types': ['C -N -CX-2C', 'N -CX-2C-CO', 'CX-2C-CO-O2', 'CX-2C-CO-O2'],
        'torsion-vals': ['p', 'p', 'p', 'i']
    },
    'C': {
        'angles-names': ['N-CA-CB', 'CA-CB-SG'],
        'angles-types': ['N -CX-2C', 'CX-2C-SH'],
        'angles-vals': [1.9146261894377796, 1.8954275676658419],
        'atom-names': ['CB', 'SG'],
        'bonds-names': ['CA-CB', 'CB-SG'],
        'bonds-types': ['CX-2C', '2C-SH'],
        'bonds-vals': [1.526, 1.81],
        'torsion-names': ['C-N-CA-CB', 'N-CA-CB-SG'],
        'torsion-types': ['C -N -CX-2C', 'N -CX-2C-SH'],
        'torsion-vals': ['p', 'p']
    },
    'Q': {
        'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-CD', 'CG-CD-OE1', 'CG-CD-NE2'],
        'angles-types': ['N -CX-2C', 'CX-2C-2C', '2C-2C-C ', '2C-C -O ', '2C-C -N '],
        'angles-vals': [
            1.9146261894377796, 1.911135530933791, 1.9390607989657, 2.101376419401173,
            2.035053907825388
        ],
        'atom-names': ['CB', 'CG', 'CD', 'OE1', 'NE2'],
        'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD', 'CD-OE1', 'CD-NE2'],
        'bonds-types': ['CX-2C', '2C-2C', '2C-C ', 'C -O ', 'C -N '],
        'bonds-vals': [1.526, 1.526, 1.522, 1.229, 1.335],
        'torsion-names': [
            'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD', 'CB-CG-CD-OE1', 'CB-CG-CD-NE2'
        ],
        'torsion-types': [
            'C -N -CX-2C', 'N -CX-2C-2C', 'CX-2C-2C-C ', '2C-2C-C -O ', '2C-2C-C -N '
        ],
        'torsion-vals': ['p', 'p', 'p', 'p', 'i']
    },
    'E': {
        'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-CD', 'CG-CD-OE1', 'CG-CD-OE2'],
        'angles-types': ['N -CX-2C', 'CX-2C-2C', '2C-2C-CO', '2C-CO-O2', '2C-CO-O2'],
        'angles-vals': [
            1.9146261894377796, 1.911135530933791, 1.9390607989657, 2.0420352248333655,
            2.0420352248333655
        ],
        'atom-names': ['CB', 'CG', 'CD', 'OE1', 'OE2'],
        'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD', 'CD-OE1', 'CD-OE2'],
        'bonds-types': ['CX-2C', '2C-2C', '2C-CO', 'CO-O2', 'CO-O2'],
        'bonds-vals': [1.526, 1.526, 1.522, 1.25, 1.25],
        'torsion-names': [
            'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD', 'CB-CG-CD-OE1', 'CB-CG-CD-OE2'
        ],
        'torsion-types': [
            'C -N -CX-2C', 'N -CX-2C-2C', 'CX-2C-2C-CO', '2C-2C-CO-O2', '2C-2C-CO-O2'
        ],
        'torsion-vals': ['p', 'p', 'p', 'p', 'i']
    },
    'G': {
        'angles-names': [],
        'angles-types': [],
        'angles-vals': [],
        'atom-names': [],
        'bonds-names': [],
        'bonds-types': [],
        'bonds-vals': [],
        'torsion-names': [],
        'torsion-types': [],
        'torsion-vals': []
    },
    'H': {
        'angles-names': [
            'N-CA-CB', 'CA-CB-CG', 'CB-CG-ND1', 'CG-ND1-CE1', 'ND1-CE1-NE2', 'CE1-NE2-CD2'
        ],
        'angles-types': [
            'N -CX-CT', 'CX-CT-CC', 'CT-CC-NA', 'CC-NA-CR', 'NA-CR-NB', 'CR-NB-CV'
        ],
        'angles-vals': [
            1.9146261894377796, 1.9739673840055867, 2.0943951023931953,
            1.8849555921538759, 1.8849555921538759, 1.8849555921538759
        ],
        'atom-names': ['CB', 'CG', 'ND1', 'CE1', 'NE2', 'CD2'],
        'bonds-names': ['CA-CB', 'CB-CG', 'CG-ND1', 'ND1-CE1', 'CE1-NE2', 'NE2-CD2'],
        'bonds-types': ['CX-CT', 'CT-CC', 'CC-NA', 'NA-CR', 'CR-NB', 'NB-CV'],
        'bonds-vals': [1.526, 1.504, 1.385, 1.343, 1.335, 1.394],
        'torsion-names': [
            'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-ND1', 'CB-CG-ND1-CE1', 'CG-ND1-CE1-NE2',
            'ND1-CE1-NE2-CD2'
        ],
        'torsion-types': [
            'C -N -CX-CT', 'N -CX-CT-CC', 'CX-CT-CC-NA', 'CT-CC-NA-CR', 'CC-NA-CR-NB',
            'NA-CR-NB-CV'
        ],
        'torsion-vals': ['p', 'p', 'p', 3.141592653589793, 0.0, 0.0]
    },
    'I': {
        'angles-names': ['N-CA-CB', 'CA-CB-CG1', 'CB-CG1-CD1', 'CA-CB-CG2'],
        'angles-types': ['N -CX-3C', 'CX-3C-2C', '3C-2C-CT', 'CX-3C-CT'],
        'angles-vals': [
            1.9146261894377796, 1.911135530933791, 1.911135530933791, 1.911135530933791
        ],
        'atom-names': ['CB', 'CG1', 'CD1', 'CG2'],
        'bonds-names': ['CA-CB', 'CB-CG1', 'CG1-CD1', 'CB-CG2'],
        'bonds-types': ['CX-3C', '3C-2C', '2C-CT', '3C-CT'],
        'bonds-vals': [1.526, 1.526, 1.526, 1.526],
        'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG1', 'CA-CB-CG1-CD1', 'N-CA-CB-CG2'],
        'torsion-types': ['C -N -CX-3C', 'N -CX-3C-2C', 'CX-3C-2C-CT', 'N -CX-3C-CT'],
        'torsion-vals': ['p', 'p', 'p', 'p']
    },
    'L': {
        'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-CD1', 'CB-CG-CD2'],
        'angles-types': ['N -CX-2C', 'CX-2C-3C', '2C-3C-CT', '2C-3C-CT'],
        'angles-vals': [
            1.9146261894377796, 1.911135530933791, 1.911135530933791, 1.911135530933791
        ],
        'atom-names': ['CB', 'CG', 'CD1', 'CD2'],
        'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD1', 'CG-CD2'],
        'bonds-types': ['CX-2C', '2C-3C', '3C-CT', '3C-CT'],
        'bonds-vals': [1.526, 1.526, 1.526, 1.526],
        'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD1', 'CA-CB-CG-CD2'],
        'torsion-types': ['C -N -CX-2C', 'N -CX-2C-3C', 'CX-2C-3C-CT', 'CX-2C-3C-CT'],
        'torsion-vals': ['p', 'p', 'p', 'p']
    },
    'K': {
        'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-CD', 'CG-CD-CE', 'CD-CE-NZ'],
        'angles-types': ['N -CX-C8', 'CX-C8-C8', 'C8-C8-C8', 'C8-C8-C8', 'C8-C8-N3'],
        'angles-vals': [
            1.9146261894377796, 1.911135530933791, 1.911135530933791, 1.911135530933791,
            1.9408061282176945
        ],
        'atom-names': ['CB', 'CG', 'CD', 'CE', 'NZ'],
        'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD', 'CD-CE', 'CE-NZ'],
        'bonds-types': ['CX-C8', 'C8-C8', 'C8-C8', 'C8-C8', 'C8-N3'],
        'bonds-vals': [1.526, 1.526, 1.526, 1.526, 1.471],
        'torsion-names': [
            'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD', 'CB-CG-CD-CE', 'CG-CD-CE-NZ'
        ],
        'torsion-types': [
            'C -N -CX-C8', 'N -CX-C8-C8', 'CX-C8-C8-C8', 'C8-C8-C8-C8', 'C8-C8-C8-N3'
        ],
        'torsion-vals': ['p', 'p', 'p', 'p', 'p']
    },
    'M': {
        'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-SD', 'CG-SD-CE'],
        'angles-types': ['N -CX-2C', 'CX-2C-2C', '2C-2C-S ', '2C-S -CT'],
        'angles-vals': [
            1.9146261894377796, 1.911135530933791, 2.0018926520374962, 1.726130630222392
        ],
        'atom-names': ['CB', 'CG', 'SD', 'CE'],
        'bonds-names': ['CA-CB', 'CB-CG', 'CG-SD', 'SD-CE'],
        'bonds-types': ['CX-2C', '2C-2C', '2C-S ', 'S -CT'],
        'bonds-vals': [1.526, 1.526, 1.81, 1.81],
        'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-SD', 'CB-CG-SD-CE'],
        'torsion-types': ['C -N -CX-2C', 'N -CX-2C-2C', 'CX-2C-2C-S ', '2C-2C-S -CT'],
        'torsion-vals': ['p', 'p', 'p', 'p']
    },
    'F': {
        'angles-names': [
            'N-CA-CB', 'CA-CB-CG', 'CB-CG-CD1', 'CG-CD1-CE1', 'CD1-CE1-CZ', 'CE1-CZ-CE2',
            'CZ-CE2-CD2'
        ],
        'angles-types': [
            'N -CX-CT', 'CX-CT-CA', 'CT-CA-CA', 'CA-CA-CA', 'CA-CA-CA', 'CA-CA-CA',
            'CA-CA-CA'
        ],
        'angles-vals': [
            1.9146261894377796, 1.9896753472735358, 2.0943951023931953,
            2.0943951023931953, 2.0943951023931953, 2.0943951023931953, 2.0943951023931953
        ],
        'atom-names': ['CB', 'CG', 'CD1', 'CE1', 'CZ', 'CE2', 'CD2'],
        'bonds-names': [
            'CA-CB', 'CB-CG', 'CG-CD1', 'CD1-CE1', 'CE1-CZ', 'CZ-CE2', 'CE2-CD2'
        ],
        'bonds-types': ['CX-CT', 'CT-CA', 'CA-CA', 'CA-CA', 'CA-CA', 'CA-CA', 'CA-CA'],
        'bonds-vals': [1.526, 1.51, 1.4, 1.4, 1.4, 1.4, 1.4],
        'torsion-names': [
            'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD1', 'CB-CG-CD1-CE1', 'CG-CD1-CE1-CZ',
            'CD1-CE1-CZ-CE2', 'CE1-CZ-CE2-CD2'
        ],
        'torsion-types': [
            'C -N -CX-CT', 'N -CX-CT-CA', 'CX-CT-CA-CA', 'CT-CA-CA-CA', 'CA-CA-CA-CA',
            'CA-CA-CA-CA', 'CA-CA-CA-CA'
        ],
        'torsion-vals': ['p', 'p', 'p', 3.141592653589793, 0.0, 0.0, 0.0]
    },
    'P': {
        'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-CD'],
        'angles-types': ['N -CX-CT', 'CX-CT-CT', 'CT-CT-CT'],
        'angles-vals': [1.9146261894377796, 1.911135530933791, 1.911135530933791],
        'atom-names': ['CB', 'CG', 'CD'],
        'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD'],
        'bonds-types': ['CX-CT', 'CT-CT', 'CT-CT'],
        'bonds-vals': [1.526, 1.526, 1.526],
        'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD'],
        'torsion-types': ['C -N -CX-CT', 'N -CX-CT-CT', 'CX-CT-CT-CT'],
        'torsion-vals': ['p', 'p', 'p']
    },
    # serine ("S"): angle, bond and torsion info
    'S': {
        'angles-names': ['N-CA-CB', 'CA-CB-OG'],
        'angles-types': ['N -CX-2C', 'CX-2C-OH'],
        'angles-vals': [1.9146261894377796, 1.911135530933791],
        'atom-names': ['CB', 'OG'],
        'bonds-names': ['CA-CB', 'CB-OG'],
        'bonds-types': ['CX-2C', '2C-OH'],
        'bonds-vals': [1.526, 1.41],
        'torsion-names': ['C-N-CA-CB', 'N-CA-CB-OG'],
        'torsion-types': ['C -N -CX-2C', 'N -CX-2C-OH'],
        'torsion-vals': ['p', 'p']
    },
    # threonine ("T"): angle, bond and torsion info
    'T': {
        'angles-names': ['N-CA-CB', 'CA-CB-OG1', 'CA-CB-CG2'],
        'angles-types': ['N -CX-3C', 'CX-3C-OH', 'CX-3C-CT'],
        'angles-vals': [1.9146261894377796, 1.911135530933791, 1.911135530933791],
        'atom-names': ['CB', 'OG1', 'CG2'],
        'bonds-names': ['CA-CB', 'CB-OG1', 'CB-CG2'],
        'bonds-types': ['CX-3C', '3C-OH', '3C-CT'],
        'bonds-vals': [1.526, 1.41, 1.526],
        'torsion-names': ['C-N-CA-CB', 'N-CA-CB-OG1', 'N-CA-CB-CG2'],
        'torsion-types': ['C -N -CX-3C', 'N -CX-3C-OH', 'N -CX-3C-CT'],
        'torsion-vals': ['p', 'p', 'p']
    },
    # tryptophan ("W"): angle, bond and torsion info
    'W': {
        'angles-names': [
            'N-CA-CB', 'CA-CB-CG', 'CB-CG-CD1', 'CG-CD1-NE1', 'CD1-NE1-CE2',
            'NE1-CE2-CZ2', 'CE2-CZ2-CH2', 'CZ2-CH2-CZ3', 'CH2-CZ3-CE3', 'CZ3-CE3-CD2'
        ],
        'angles-types': [
            'N -CX-CT', 'CX-CT-C*', 'CT-C*-CW', 'C*-CW-NA', 'CW-NA-CN', 'NA-CN-CA',
            'CN-CA-CA', 'CA-CA-CA', 'CA-CA-CA', 'CA-CA-CB'
        ],
        'angles-vals': [
            1.9146261894377796, 2.0176006153054447, 2.181661564992912, 1.8971728969178363,
            1.9477874452256716, 2.3177972466484698, 2.0943951023931953,
            2.0943951023931953, 2.0943951023931953, 2.0943951023931953
        ],
        'atom-names': [
            'CB', 'CG', 'CD1', 'NE1', 'CE2', 'CZ2', 'CH2', 'CZ3', 'CE3', 'CD2'
        ],
        'bonds-names': [
            'CA-CB', 'CB-CG', 'CG-CD1', 'CD1-NE1', 'NE1-CE2', 'CE2-CZ2', 'CZ2-CH2',
            'CH2-CZ3', 'CZ3-CE3', 'CE3-CD2'
        ],
        'bonds-types': [
            'CX-CT', 'CT-C*', 'C*-CW', 'CW-NA', 'NA-CN', 'CN-CA', 'CA-CA', 'CA-CA',
            'CA-CA', 'CA-CB'
        ],
        'bonds-vals': [1.526, 1.495, 1.352, 1.381, 1.38, 1.4, 1.4, 1.4, 1.4, 1.404],
        'torsion-names': [
            'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD1', 'CB-CG-CD1-NE1', 'CG-CD1-NE1-CE2',
            'CD1-NE1-CE2-CZ2', 'NE1-CE2-CZ2-CH2', 'CE2-CZ2-CH2-CZ3', 'CZ2-CH2-CZ3-CE3',
            'CH2-CZ3-CE3-CD2'
        ],
        'torsion-types': [
            'C -N -CX-CT', 'N -CX-CT-C*', 'CX-CT-C*-CW', 'CT-C*-CW-NA', 'C*-CW-NA-CN',
            'CW-NA-CN-CA', 'NA-CN-CA-CA', 'CN-CA-CA-CA', 'CA-CA-CA-CA', 'CA-CA-CA-CB'
        ],
        'torsion-vals': [
            'p', 'p', 'p', 3.141592653589793, 0.0, 3.141592653589793, 3.141592653589793,
            0.0, 0.0, 0.0
        ]
    },
    'Y': {
        'angles-names': [
            'N-CA-CB', 'CA-CB-CG', 'CB-CG-CD1', 'CG-CD1-CE1', 'CD1-CE1-CZ', 'CE1-CZ-OH',
            'CE1-CZ-CE2', 'CZ-CE2-CD2'
        ],
        'angles-types': [
            'N -CX-CT', 'CX-CT-CA', 'CT-CA-CA', 'CA-CA-CA', 'CA-CA-C ', 'CA-C -OH',
            'CA-C -CA', 'C -CA-CA'
        ],
        'angles-vals': [
            1.9146261894377796, 1.9896753472735358, 2.0943951023931953,
            2.0943951023931953, 2.0943951023931953, 2.0943951023931953,
            2.0943951023931953, 2.0943951023931953
        ],
        'atom-names': ['CB', 'CG', 'CD1', 'CE1', 'CZ', 'OH', 'CE2', 'CD2'],
        'bonds-names': [
            'CA-CB', 'CB-CG', 'CG-CD1', 'CD1-CE1', 'CE1-CZ', 'CZ-OH', 'CZ-CE2', 'CE2-CD2'
        ],
        'bonds-types': [
            'CX-CT', 'CT-CA', 'CA-CA', 'CA-CA', 'CA-C ', 'C -OH', 'C -CA', 'CA-CA'
        ],
        'bonds-vals': [1.526, 1.51, 1.4, 1.4, 1.409, 1.364, 1.409, 1.4],
        'torsion-names': [
            'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD1', 'CB-CG-CD1-CE1', 'CG-CD1-CE1-CZ',
            'CD1-CE1-CZ-OH', 'CD1-CE1-CZ-CE2', 'CE1-CZ-CE2-CD2'
        ],
        'torsion-types': [
            'C -N -CX-CT', 'N -CX-CT-CA', 'CX-CT-CA-CA', 'CT-CA-CA-CA', 'CA-CA-CA-C ',
            'CA-CA-C -OH', 'CA-CA-C -CA', 'CA-C -CA-CA'
        ],
        'torsion-vals': [
            'p', 'p', 'p', 3.141592653589793, 0.0, 3.141592653589793, 0.0, 0.0
        ]
    },
    'V': {
        'angles-names': ['N-CA-CB', 'CA-CB-CG1', 'CA-CB-CG2'],
        'angles-types': ['N -CX-3C', 'CX-3C-CT', 'CX-3C-CT'],
        'angles-vals': [1.9146261894377796, 1.911135530933791, 1.911135530933791],
        'atom-names': ['CB', 'CG1', 'CG2'],
        'bonds-names': ['CA-CB', 'CB-CG1', 'CB-CG2'],
        'bonds-types': ['CX-3C', '3C-CT', '3C-CT'],
        'bonds-vals': [1.526, 1.526, 1.526],
        'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG1', 'N-CA-CB-CG2'],
        'torsion-types': ['C -N -CX-3C', 'N -CX-3C-CT', 'N -CX-3C-CT'],
        'torsion-vals': ['p', 'p', 'p']
    },
    '_': {
        'angles-names': [],
        'angles-types': [],
        'angles-vals': [],
        'atom-names': [],
        'bonds-names': [],
        'bonds-types': [],
        'bonds-vals': [],
        'torsion-names': [],
        'torsion-types': [],
        'torsion-vals': []
    }
# closes the SC_BUILD_INFO dict (the original comment mislabeled it as BB_BUILD_INFO)
}

# backbone build info: bond lengths, bond angles and torsions for the protein backbone
BB_BUILD_INFO = {
    "BONDLENS": {
        # updated values, verified against crystal data from 1DPE_1_A
        # the commented values are sidechainnet's
        'n-ca': 1.4664931, # 1.442,
        'ca-c': 1.524119,  # 1.498,
        'c-n': 1.3289373,  # 1.379,
        'c-o': 1.229,  # from parm10.dat || varies a lot depending on the structure
        # we get 1.3389416 from 1DPE_1_A, but also 1.2289 from 2F2H_d2f2hf1
        'c-oh': 1.364
    },  # for OXT, from parm10.dat
    # used to place the oxygen atom
    "BONDANGS": {
        'ca-c-o': 2.0944,  # approximately 2pi / 3; parm10.dat says 2.0350539
        'ca-c-oh': 2.0944
    },  # same as 'ca-c-o', for OXT
    "BONDTORSIONS": {
        'n-ca-c-n': -0.785398163
    }  # a simple approximation, not meant to be exact
}

# build the per-residue atom presence mask
def make_cloud_mask(aa):
    """ relevent points will be 1. paddings will be 0. """
    mask = np.zeros(14)
    if aa == "_":
        # padding token: no atoms present
        return mask
    n_atoms = 4 + len(SC_BUILD_INFO[aa]["atom-names"])  # 4 backbone atoms + sidechain
    mask[:n_atoms] = 1
    return mask

# Per-residue table of the bond length that originates each atom slot
def make_bond_mask(aa):
    """Return a (14,) array with the length of the bond originating each atom."""
    lengths = np.zeros(14)
    # backbone slots: N, CA, C, O
    bb_lens = BB_BUILD_INFO["BONDLENS"]
    lengths[0] = bb_lens['c-n']
    lengths[1] = bb_lens['n-ca']
    lengths[2] = bb_lens['ca-c']
    lengths[3] = bb_lens['c-o']
    # sidechain slots - the padding token '_' simply has no entries
    if aa in SC_BUILD_INFO:
        for slot, bond_len in enumerate(SC_BUILD_INFO[aa]['bonds-vals'], start=4):
            lengths[slot] = bond_len
    return lengths

# Per-residue table of the bond angle (theta) that originates each atom slot
def make_theta_mask(aa):
    """Return a (14,) array with the theta of the bond originating each atom."""
    thetas = np.zeros(14)
    # backbone slots are left at 0 here; they are filled later from data
    # sidechain slots start at index 4
    for slot, theta in enumerate(SC_BUILD_INFO[aa]['angles-vals'], start=4):
        thetas[slot] = theta
    return thetas

# Per-residue table of the dihedral that originates each atom slot
def make_torsion_mask(aa):
    """Return a (14,) array with the dihedral of the bond originating each atom.

    NaN marks torsions to be taken from predicted sidechain angles ('p');
    999 is a sentinel for torsions inferred later as previous - pi ('i').
    """
    dihedrals = np.zeros(14)
    # backbone slots are filled elsewhere; sidechain slots start at index 4
    for slot, torsion in enumerate(SC_BUILD_INFO[aa]['torsion-vals'], start=4):
        if torsion == 'p':
            # placeholder: to be replaced by a predicted sidechain torsion
            dihedrals[slot] = np.nan
        elif torsion == "i":
            # https://github.com/jonathanking/sidechainnet/blob/master/sidechainnet/structure/StructureBuilder.py#L372
            # sentinel resolved later as the previous dihedral minus pi
            dihedrals[slot] = 999
        else:
            # fixed numeric dihedral straight from the build table
            dihedrals[slot] = torsion
    return dihedrals

# Per-residue table of the 3 reference-atom indices that place each atom
def make_idx_mask(aa):
    """Return a (11, 3) array of the indices of the 3 previous reference points."""
    idxs = np.zeros((11, 3))
    # backbone row references N, CA, C directly
    idxs[0, :] = np.arange(3)
    # map backbone atom names to their column in the 14-slot coords array
    backbone_pos = {"N": 0, "CA": 1, "C": 2, "CB": 4}
    for row, torsion_name in enumerate(SC_BUILD_INFO[aa]['torsion-names'], start=1):
        # split the torsion name into its constituent atom names
        atoms = [token.rstrip(" ") for token in torsion_name.split("-")]
        # the last atom is the one being placed; the first three are references
        for col, atom in enumerate(atoms[:-1]):
            if atom in backbone_pos:
                idxs[row][col] = backbone_pos[atom]
            else:
                # sidechain atoms live after the 4 backbone slots
                idxs[row][col] = 4 + SC_BUILD_INFO[aa]['atom-names'].index(atom)
    return idxs

# Precomputed per-residue lookup tables, keyed by 1-letter amino acid code
# ('_' is the padding token)
SUPREME_INFO = {
    aa: {
        "cloud_mask": make_cloud_mask(aa),
        "bond_mask": make_bond_mask(aa),
        "theta_mask": make_theta_mask(aa),
        "torsion_mask": make_torsion_mask(aa),
        "idx_mask": make_idx_mask(aa),
    }
    for aa in "ARNDCQEGHILKMFPSTWYV_"
}

# Boolean atom-presence mask for a whole sequence, from a lookup or from coords
def scn_cloud_mask(seq, coords=None):
    """Boolean (L, 14) mask of atom positions (not all aas have the same atoms).
    Inputs:
    * seq: (length) iterable of 1-letter aa codes of a protein
    * coords: optional (batch, L*14, 3) sidechainnet coords; when provided,
      the true mask is derived from the data itself (handles atoms that
      might not be provided).
    Outputs: (length, 14) boolean mask
    """
    if coords is None:
        # no coords given: read the precomputed per-residue mask rows
        return torch.tensor([SUPREME_INFO[aa]['cloud_mask'] for aa in seq])
    # group the flat atom axis into (residue, 14 slots)
    per_slot = rearrange(coords, '... (l c) d -> ... l c d', c=14)
    # an atom is present iff fewer than all of its coordinate components are 0
    num_zero = (per_slot == 0).sum(dim=-1)
    return (num_zero < coords.shape[-1]).float().cpu()
# Bond-length table for a whole sequence
def scn_bond_mask(seq):
    """Inputs:
    * seq: (length). iterable of 1-letter aa codes of a protein
    Outputs: (L, 14) tensor mapping each point to its bond length
    """
    rows = [SUPREME_INFO[aa]['bond_mask'] for aa in seq]
    return torch.tensor(rows)


# Build theta / dihedral angle tables for a whole sequence
def scn_angle_mask(seq, angles):
    """ Inputs: 
        * seq: (length). iterable of 1-letter aa codes of a protein
        * angles: (length, 12). [phi, psi, omega, b_angle(n_ca_c), b_angle(ca_c_n), b_angle(c_n_ca), 6_scn_torsions]
        Outputs: (2, L, 14) maps point to theta and dihedral.
                 first angle is theta, second is dihedral
    """ 
    # keep device and dtype so the result matches the input tensor
    device, precise = angles.device, angles.type()
    # per-residue template tables; NaN / 999 entries are resolved below
    theta_mask   = torch.tensor([SUPREME_INFO[aa]['theta_mask'] for aa in seq]).type(precise)
    torsion_mask = torch.tensor([SUPREME_INFO[aa]['torsion_mask'] for aa in seq]).type(precise)
    # backbone thetas come straight from the measured bond angles
    theta_mask[:, 0] = angles[:, 4] # ca_c_n
    theta_mask[1:, 1] = angles[:-1, 5] # c_n_ca
    theta_mask[:, 2] = angles[:, 3] # n_ca_c
    theta_mask[:, 3] = BB_BUILD_INFO["BONDANGS"]["ca-c-o"]
    # backbone dihedrals
    torsion_mask[:, 0] = angles[:, 1] # n determined by psi of previous
    torsion_mask[1:, 1] = angles[:-1, 2] # ca determined by omega of previous
    torsion_mask[:, 2] = angles[:, 0] # c determined by phi
    # O is placed opposite the next N; the last residue has no next N
    torsion_mask[:, 3] = angles[:, 1] - np.pi 
    torsion_mask[-1, 3] += np.pi              
    # NaN entries ('p' in the build table) take the 6 predicted sidechain
    # torsions in order; 999 entries ('i') become previous dihedral - pi
    to_fill = torsion_mask != torsion_mask   # NaN != NaN picks the 'p' slots
    to_pick = torsion_mask == 999
    for i in range(len(seq)):
        number = to_fill[i].long().sum()
        torsion_mask[i, to_fill[i]] = angles[i, 6:6+number]
        for j, val in enumerate(to_pick[i]):
            if val:
                torsion_mask[i, j] = torsion_mask[i, j-1] - np.pi

    return torch.stack([theta_mask, torsion_mask], dim=0).to(device)


# Reference-atom index table for a whole sequence
def scn_index_mask(seq):
    """Inputs:
    * seq: (length). iterable of 1-letter aa codes of a protein
    Outputs: (3, L, 11) indices of the 3 reference points for each placed atom
    """
    idxs = torch.tensor([SUPREME_INFO[aa]['idx_mask'] for aa in seq])
    # (L, 11, 3) -> (3, L, 11), same as rearrange 'l s d -> d l s'
    return idxs.permute(2, 0, 1)


# Build the scaffold dict (masks and tables) used by the coordinate builder
def build_scaffolds_from_scn_angles(seq, angles, coords=None, device="auto"):
    """ Builds scaffolds for fast access to data
        Inputs: 
        * seq: string of aas (1 letter code)
        * angles: (L, 12) tensor containing the internal angles.
                  Distributed as follows (following sidechainnet convention):
                  * (L, 3) for torsion angles
                  * (L, 3) bond angles
                  * (L, 6) sidechain angles
        * coords: (L, 3) sidechainnet coords. builds the mask with those instead
                  (better accuracy if modified residues present).
        * device: "auto" takes the device of `angles`; otherwise a torch device.
        Outputs:
        * cloud_mask: (L, 14 ) mask of points that should be converted to coords 
        * point_ref_mask: (3, L, 11) maps point (except n-ca-c) to idxs of
                                     previous 3 points in the coords array
        * angles_mask: (2, L, 14) maps point to theta and dihedral
        * bond_mask: (L, 14) gives the length of the bond originating that atom
    """
    precise = angles.type()
    if device == "auto":
        device = angles.device

    # prefer the data-driven mask when coords are available (handles
    # modified residues with missing atoms)
    if coords is not None: 
        cloud_mask = scn_cloud_mask(seq, coords=coords)
    else: 
        cloud_mask = scn_cloud_mask(seq)

    # the scn_* helpers already return tensors; wrapping them again with
    # torch.tensor(tensor) made a needless copy and raised a UserWarning,
    # so cast / move them directly instead
    cloud_mask = cloud_mask.bool().to(device)
    point_ref_mask = scn_index_mask(seq).long().to(device)
    angles_mask = scn_angle_mask(seq, angles).type(precise).to(device)
    bond_mask = scn_bond_mask(seq).type(precise).to(device)

    # bundle everything into one dict
    return {"cloud_mask":     cloud_mask, 
            "point_ref_mask": point_ref_mask,
            "angles_mask":    angles_mask,
            "bond_mask":      bond_mask }
#############################
####### ENCODERS ############
#############################


# Overwrite scaffold templates with values measured from real coordinates
def modify_scaffolds_with_coords(scaffolds, coords):
    """ Gets scaffolds and fills in the right data.
        Inputs: 
        * scaffolds: dict. as returned by `build_scaffolds_from_scn_angles`
        * coords: (L, 14, 3). sidechainnet tensor. same device as scaffolds
        Outputs: corrected scaffolds
    """

    # measure bond lengths and update:
    # N, CA, C (N's bond comes from the previous residue's C)
    scaffolds["bond_mask"][1:, 0] = torch.norm(coords[1:, 0] - coords[:-1, 2], dim=-1) # N
    scaffolds["bond_mask"][:, 1] = torch.norm(coords[:, 1] - coords[:, 0], dim=-1) # CA
    scaffolds["bond_mask"][:, 2] = torch.norm(coords[:, 2] - coords[:, 1], dim=-1) # C
    # O, CB, sidechain slots
    selector = np.arange(len(coords))
    for i in range(3, 14):
        # get the indices of the 3 reference atoms for slot i
        idx_a, idx_b, idx_c = scaffolds["point_ref_mask"][:, :, i-3] # (3, L, 11) -> 3 * (L, 11)
        # correct the bond length
        scaffolds["bond_mask"][:, i] = torch.norm(coords[:, i] - coords[selector, idx_c], dim=-1)
        # measure the bond angle
        scaffolds["angles_mask"][0, :, i] = get_angle(coords[selector, idx_b], 
                                                      coords[selector, idx_c], 
                                                      coords[:, i])
        # handle C-beta, whose requested C comes from the previous residue
        if i == 4:
            # for the first residue, use the second residue's N position instead
            first_next_n = coords[1, :1] # 1, 3
            # the requested C comes from the previous residue
            main_c_prev_idxs = coords[selector[:-1], idx_a[1:]] # (L-1), 3
            # concatenate into the full reference-atom column
            coords_a = torch.cat([first_next_n, main_c_prev_idxs])
        else:
            coords_a = coords[selector, idx_a]
        # measure the dihedral
        scaffolds["angles_mask"][1, :, i] = get_dihedral(coords_a,
                                                         coords[selector, idx_b], 
                                                         coords[selector, idx_c], 
                                                         coords[:, i])
    # correct angles and dihedrals for the backbone
    scaffolds["angles_mask"][0, :-1, 0] = get_angle(coords[:-1, 1], coords[:-1, 2], coords[1:, 0]) # ca_c_n
    scaffolds["angles_mask"][0, 1:, 1] = get_angle(coords[:-1, 2], coords[1:, 0], coords[1:, 1]) # c_n_ca
    scaffolds["angles_mask"][0, :, 2] = get_angle(coords[:, 0], coords[:, 1], coords[:, 2]) # n_ca_c
    
    # N is determined by the previous psi = f(n, ca, c, n+1)
    scaffolds["angles_mask"][1, :-1, 0] = get_dihedral(coords[:-1, 0], coords[:-1, 1], coords[:-1, 2], coords[1:, 0])
    # CA is determined by omega = f(ca, c, n+1, ca+1)
    scaffolds["angles_mask"][1, 1:, 1] = get_dihedral(coords[:-1, 1], coords[:-1, 2], coords[1:, 0], coords[1:, 1])
    # C is determined by phi = f(c-1, n, ca, c)
    scaffolds["angles_mask"][1, 1:, 2] = get_dihedral(coords[:-1, 2], coords[1:, 0], coords[1:, 1], coords[1:, 2])

    return scaffolds



if __name__ == "__main__":
    # quick smoke test: show the atom mask for a 4-alanine peptide
    demo_mask = scn_cloud_mask("AAAA")
    print(demo_mask)