Lucidrains Series Source Code Walkthrough (87)

.\lucidrains\RQ-Transformer\rq_transformer\rq_transformer.py

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops_exts import rearrange_with_anon_dims
from einops import rearrange, reduce, repeat

# helpers

# check that a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise a default
def default(val, d):
    return val if exists(val) else d

# amount of padding needed to reach the nearest multiple
def remainder_to_mult(num, mult):
    return (mult - num % mult) % mult

# log, clamped to avoid taking the log of very small values
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# generate Gumbel noise
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# sample from logits with Gumbel noise, at the given temperature
def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)

# keep only the top-k logits, setting the rest to -inf
def top_k(logits, thres = 0.5):
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
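
# quick illustration of the helpers above (added for this walkthrough, not in the original file):
# - remainder_to_mult(10, 4) == 2, i.e. a length-10 sequence needs 2 pad tokens to reach a multiple of 4
# - top_k(torch.randn(2, 10), thres = 0.9) keeps the top max(int((1 - 0.9) * 10), 1) = 1 logit per row, the rest set to -inf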

# helper classes

# feedforward network
def FeedForward(*, dim, mult = 4, dropout = 0.):
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, dim * mult),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(dim * mult, dim)
    )

# attention
class Attention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        h, device = self.heads, x.device

        x = self.norm(x)
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        i, j = sim.shape[-2:]
        mask_value = -torch.finfo(sim.dtype).max
        mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
        sim = sim.masked_fill(mask, mask_value)

        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# transformer block
class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        layers,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_dropout = 0.,
        ff_mult = 4
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(layers):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
            ]))

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)

# main class

class RQTransformer(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        max_spatial_seq_len,
        depth_seq_len,
        spatial_layers,
        depth_layers,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        pad_id = 0
    ):
        # call the parent constructor
        super().__init__()
        # model dimension
        self.dim = dim
        # maximum length of the spatial sequence
        self.max_spatial_seq_len = max_spatial_seq_len
        # length of the depth sequence (number of residual quantizer codes per spatial position)
        self.depth_seq_len = depth_seq_len

        # token embedding, mapping input token ids to vectors
        self.token_emb = nn.Embedding(num_tokens, dim)
        # learned start token for the spatial sequence
        self.spatial_start_token = nn.Parameter(torch.randn(dim))

        # spatial positional embedding
        self.spatial_pos_emb = nn.Embedding(max_spatial_seq_len + 1, dim) # account for a boundary case
        # depth positional embedding
        self.depth_pos_emb = nn.Embedding(depth_seq_len, dim)

        # spatial transformer, attending across spatial positions
        self.spatial_transformer = Transformer(
            dim = dim,
            layers = spatial_layers,
            dim_head = dim_head,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            ff_mult = ff_mult
        )

        # depth transformer, attending across the quantized codes at each spatial position
        self.depth_transformer = Transformer(
            dim = dim,
            layers = depth_layers,
            dim_head = dim_head,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            ff_mult = ff_mult
        )

        # final projection to token logits
        self.to_logits = nn.Linear(dim, num_tokens)
        # id of the padding token
        self.pad_id = pad_id

    def generate(self, prime = None, filter_thres = 0.9, temperature = 1., default_batch_size = 1):
        # total flattened sequence length
        total_seq_len = self.depth_seq_len * self.max_spatial_seq_len
        # device the model lives on
        device = next(self.parameters()).device

        # if no prime is given, start from an empty sequence
        if not exists(prime):
            prime = torch.empty((default_batch_size, 0), dtype = torch.long, device = device)

        seq = prime

        # autoregressively generate the remaining tokens
        for _ in range(total_seq_len - seq.shape[-1]):
            # get the logits for the next position
            logits = self.forward(seq)[:, -1]
            # filter down to the highest-probability tokens
            logits = top_k(logits, thres = filter_thres)
            # sample the next token with Gumbel sampling
            sampled = gumbel_sample(logits, dim = -1, temperature = temperature)
            # append the sampled token to the sequence
            seq = torch.cat((seq, rearrange(sampled, 'b -> b 1')), dim = -1)

        # reshape the flat sequence into (batch, spatial, depth)
        return rearrange(seq, 'b (s d) -> b s d', d = self.depth_seq_len)

    def forward_empty(self, batch_size):
        # take care of the special case where only the start token is present (nothing sampled yet)

        # repeat the spatial start token to the given batch size
        spatial_tokens = repeat(self.spatial_start_token, 'd -> b 1 d', b = batch_size)
        # run through the spatial transformer
        depth_tokens = self.spatial_transformer(spatial_tokens)
        # run through the depth transformer
        depth_tokens = self.depth_transformer(depth_tokens)
        # project to logits
        return self.to_logits(depth_tokens)
    # forward pass, taking input ids and a flag for whether to return the loss
    def forward(self, ids, return_loss = False):
        # input ids must be 2- or 3-dimensional
        assert ids.ndim in {2, 3}
        # whether the input came in flattened (batch, seq) form
        flattened_dim = ids.ndim == 2
        # save the original number of dimensions
        ids_orig_ndim = ids.ndim

        # if ids is empty, handle it with forward_empty
        if ids.numel() == 0:
            return self.forward_empty(ids.shape[0])

        if flattened_dim:
            # allow ids of shape (batch, seq), auto-padding to the nearest multiple of the depth sequence length
            seq_len = ids.shape[-1]
            padding = remainder_to_mult(seq_len, self.depth_seq_len)
            ids = F.pad(ids, (0, padding), value = self.pad_id)
            ids = rearrange(ids, 'b (s d) -> b s d', d = self.depth_seq_len)
        else:
            seq_len = ids.shape[1] * ids.shape[2]

        # unpack batch, spatial and depth dimensions, and the device
        b, space, depth, device = *ids.shape, ids.device
        assert space <= (self.max_spatial_seq_len + 1), 'spatial dimension is greater than the max_spatial_seq_len set'
        assert depth == self.depth_seq_len, 'depth dimension must be equal to depth_seq_len'

        # get token embeddings
        tokens = self.token_emb(ids)

        # get spatial and depth positional embeddings
        spatial_pos = self.spatial_pos_emb(torch.arange(space, device = device))
        depth_pos = self.depth_pos_emb(torch.arange(depth, device = device))

        # add depth positions to the token embeddings
        tokens_with_depth_pos = tokens + depth_pos

        # sum over the depth dimension to get the spatial tokens, then add spatial positions
        spatial_tokens = reduce(tokens_with_depth_pos, 'b s d f -> b s f', 'sum') + spatial_pos

        # prepend the spatial start token
        spatial_tokens = torch.cat((
            repeat(self.spatial_start_token, 'f -> b 1 f', b = b),
            spatial_tokens
        ), dim = -2)

        # run the spatial transformer
        spatial_tokens = self.spatial_transformer(spatial_tokens)

        # give each spatial token a depth axis
        spatial_tokens = rearrange(spatial_tokens, 'b s f -> b s 1 f')

        # the spatial tokens become the start tokens of the depth dimension
        tokens_with_depth_pos = F.pad(tokens_with_depth_pos, (0, 0, 0, 0, 0, 1), value = 0.)

        # concatenate to form the depth tokens
        depth_tokens = torch.cat((spatial_tokens, tokens_with_depth_pos), dim = -2)

        # fold the batch and spatial dimensions together
        depth_tokens = rearrange(depth_tokens, '... n d -> (...) n d')

        # run the depth transformer
        depth_tokens = self.depth_transformer(depth_tokens)

        # unfold back to (batch, spatial, depth, features)
        depth_tokens = rearrange(depth_tokens, '(b s) d f -> b s d f', b = b)

        # project to logits
        logits = self.to_logits(depth_tokens)
        logits = rearrange(logits, 'b ... f -> b (...) f')
        logits = logits[:, :(seq_len + 1)]

        # if not returning the loss, return the logits (dropping the start position)
        if not return_loss:
            logits = logits[:, 1:]

            # if the input was flattened, return flattened logits
            if flattened_dim:
                return rearrange(logits, 'b ... n -> b (...) n')

            return logits

        # otherwise compute the autoregressive loss
        logits = logits[:, :-1]

        # rearrange logits and ids for cross entropy
        preds = rearrange(logits, 'b ... c -> b c (...)')
        labels = rearrange(ids, 'b s d -> b (s d)')

        # cross entropy loss, ignoring padding
        loss = F.cross_entropy(preds, labels, ignore_index = self.pad_id)
        return loss
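
To make the interface above concrete, here is a minimal usage sketch of RQTransformer, following the project README (the hyperparameter and shape values are illustrative):

import torch
from rq_transformer import RQTransformer

model = RQTransformer(
    num_tokens = 16000,          # vocabulary of quantized codes
    dim = 512,
    max_spatial_seq_len = 1024,  # maximum number of spatial positions
    depth_seq_len = 4,           # residual quantizer codes per spatial position
    spatial_layers = 8,
    depth_layers = 4
)

ids = torch.randint(0, 16000, (1, 1024, 4))  # (batch, spatial, depth)

loss = model(ids, return_loss = True)
loss.backward()

sampled = model.generate(temperature = 0.99, filter_thres = 0.9)  # (1, 1024, 4)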

.\lucidrains\RQ-Transformer\rq_transformer\__init__.py

# import the RQTransformer class from the rq_transformer module
from rq_transformer.rq_transformer import RQTransformer
# import the HierarchicalCausalTransformer class from the rq_transformer module
from rq_transformer.hierarchical_causal_transformer import HierarchicalCausalTransformer

.\lucidrains\RQ-Transformer\setup.py

# import setup and package-finding utilities
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'RQ-transformer',  # package name
  packages = find_packages(exclude=[]),  # find all packages
  version = '0.1.9',  # version number
  license='MIT',  # license
  description = 'RQ Transformer - Autoregressive Transformer for Residual Quantized Codes',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/RQ-transformer',  # project url
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention-mechanism',
    'autoregressive',
  ],
  install_requires=[  # dependencies
    'einops>=0.4',
    'einops-exts',
    'torch>=1.6',
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\RQ-Transformer\train.py

# import required libraries
from rq_transformer import HierarchicalCausalTransformer

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
PRIME_LEN = 100
SEQ_LEN = 1024

# helper functions

# cycle a dataloader endlessly (required by the loaders below)
def cycle(loader):
    while True:
        for data in loader:
            yield data

# decode a token to a character
def decode_token(token):
    return str(chr(max(32, token)))

# decode tokens to a string
def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# instantiate the GPT-like decoder model
model = HierarchicalCausalTransformer(
    num_tokens = 256,
    dim = 512,
    depth = (4, 3, 3, 3),
    max_seq_len = (4, 4, 8, 8)
).cuda()

# prepare enwik8 data
with gzip.open('./data/enwik8.gz') as file:
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

# dataset that samples random subsequences of text
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# create dataloaders for the training and validation sets
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader), return_loss = True)
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader), return_loss = True)
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime_inp = inp[:PRIME_LEN]
        prime = decode_tokens(prime_inp)
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(prime_inp[None, :])
        sample = sample.flatten(1)

        output_str = decode_tokens(sample[0][PRIME_LEN:])
        print(output_str)

Data source

The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/

RVQ-VAE-GPT - Residual Vector Quantize VAE - GPT (wip)

My attempts at applying Soundstream design on learned tokenization of text and then applying a hierarchical transformer to text generation.

The Soundstream will be modified to use all local attention. Experiments will compare VQ, RVQ, and also multi-headed VQ.

Was told by a researcher friend this will likely fail 😂😂 but I will try it anyways, yolo. In the case it does not work, maybe it can still be useful for genomics. Come to think of it, why shouldn't it be able to at least learn bigrams (for English) and codons (for genomics)? Why don't we have hierarchical predictive coding? We should.

Update: Some live experiments

Todo

  • add a diff in the autoencoder training between input and reconstructed, so one can examine the failure cases easily

Citations

@misc{zeghidour2021soundstream,
    title     = {SoundStream: An End-to-End Neural Audio Codec},
    author    = {Zeghidour, Neil and Luebs, Alejandro and Omran, Ahmed and Skoglund, Jan and Tagliasacchi, Marco},
    publisher = {arXiv},
    url       = {https://arxiv.org/abs/2107.03312},
    year      = {2021}
}
@misc{lee2022autoregressive,
    title   = {Autoregressive Image Generation using Residual Quantization},
    author  = {Lee, Doyup and Kim, Chiheon and Kim, Saehoon and Cho, Minsu and Han, Wook-Shin},
    year    = {2022},
    month   = {03}
}
@article{Sunkara2022NoMS,
    title   = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
    author  = {Raja Sunkara and Tie Luo},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2208.03641}
}

.\lucidrains\rvq-vae-gpt\rvq_vae_gpt\rvq_vae_gpt.py

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

# LocalMHA from the local-attention package
from local_attention import LocalMHA
# VectorQuantize and ResidualVQ from the vector-quantize-pytorch package
from vector_quantize_pytorch import VectorQuantize, ResidualVQ

from beartype import beartype
from beartype.typing import Tuple, Optional, Union

from pathlib import Path
import pickle

# helper functions

# check that a value exists
def exists(val):
    return val is not None

# first element of an indexable
def first(it):
    return it[0]

# return the first value that exists
def default(*vals):
    for val in vals:
        if exists(val):
            return val
    return None

# whether numer is divisible by denom
def divisible_by(numer, denom):
    return (numer % denom) == 0

# cast a value to a tuple of the given length
def cast_tuple(t, len = 1):
    return ((t,) * len) if not isinstance(t, tuple) else t

# token shift - as used in RWKV

# split the feature dimension in two and shift the second half one position forward in time
def shift_tokens(t):
    t, t_shift = t.chunk(2, dim = -1)
    t_shift = F.pad(t_shift, (0, 0, 1, -1), value = 0.)
    return torch.cat((t, t_shift), dim = -1)
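
# illustrative example (added for this walkthrough, not in the original file): for t of shape (b, n, d),
# the second half of the features is shifted so shift_tokens(t)[:, i, d//2:] == t[:, i - 1, d//2:]
# (zeros at i == 0), giving each position a cheap view of the previous token's features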

# feedforward

# GEGLU activation
class GEGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return x * F.gelu(gate)

# feedforward module (GEGLU variant, inner dimension scaled by 2/3 to keep the parameter count comparable)
def FeedForward(dim, mult = 4):
    dim_inner = int(dim * mult * 2 / 3)

    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, dim_inner * 2),
        GEGLU(),
        nn.Linear(dim_inner, dim)
    )

# best way to upsample and downsample

# upsample module
class Upsample(nn.Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        factor = 2
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        linear = nn.Linear(dim, dim_out * factor)

        self.net = nn.Sequential(
            linear,
            nn.SiLU(),
            Rearrange('b n (p d) -> b (n p) d', p = factor)
        )

        self.factor = factor
        self.init_(linear)

    # initialize the linear layer's weights and bias, repeating the same weights for each upsampled position
    def init_(self, linear):
        o, i = linear.weight.shape

        linear_weight = torch.empty(o // self.factor, i)
        nn.init.kaiming_uniform_(linear_weight)

        linear_weight = repeat(linear_weight, 'o ... -> (o r) ...', r = self.factor)

        linear.weight.data.copy_(linear_weight)
        nn.init.zeros_(linear.bias.data)

    def forward(self, x):
        return self.net(x)

# downsample module
def Downsample(
    dim,
    dim_out = None,
    factor = 2
):
    dim_out = default(dim_out, dim)
    return nn.Sequential(
        Rearrange('b (n p) d -> b n (p d)', p = factor),
        nn.Linear(dim * factor, dim_out)
    )
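
# illustrative shape check (added for this walkthrough, not in the original file):
# Downsample(dim = 128, factor = 2) maps (b, n, 128) -> (b, n // 2, 128),
# and Upsample(dim = 128, factor = 2) maps that back to (b, n, 128)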

# local attention

# local transformer module
class LocalTransformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        heads,
        dim_head,
        window_size
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                LocalMHA(
                    dim = dim,
                    heads = heads,
                    dim_head = dim_head,
                    qk_rmsnorm = True,
                    window_size = window_size,
                    use_rotary_pos_emb = True,
                    use_xpos = True,
                    causal = True
                ),
                FeedForward(dim = dim)
            ]))

    def forward(self, x):

        for attn, ff in self.layers:
            x = attn(shift_tokens(x)) + x
            x = ff(shift_tokens(x)) + x

        return x

# main modules

# text VQ-VAE
@beartype
class TextVQVAE(nn.Module): # or genomics, eventually, with num_tokens set to 4
    def __init__(
        self,
        *,
        num_tokens,
        dim: Union[int, Tuple[int, ...]],
        depth: Union[int, Tuple[int, ...]],
        strides: Union[int, Tuple[int, ...]],
        codebook_size = 1024,
        local_attn_window_size = 32,
        local_attn_heads = 8,
        local_attn_dim_head = 64,
        num_codebooks = 4,
        vq_decay = 0.9,
        rvq_quantize_dropout = True
    ):
        super().__init__()

        # capture the constructor arguments so the model can be re-instantiated on load
        config = locals()
        config.pop('self')
        config.pop('__class__')
        self._config = config

        # vq_decay must be in (0, 1]
        assert 0 < vq_decay <= 1.

        # cast strides to a tuple
        strides = cast_tuple(strides)
        num_layers = len(strides)

        # cast dim, depth and local_attn_window_size to tuples of the same length
        dim = cast_tuple(dim, num_layers)
        depth = cast_tuple(depth, num_layers)
        local_attn_window_size = cast_tuple(local_attn_window_size, num_layers)

        # all per-layer hyperparameters must have the same length
        assert num_layers == len(depth) == len(local_attn_window_size) == len(dim)

        # initial dimension, and the dimension at the vector-quantized bottleneck
        init_dim, vq_dim = dim[0], dim[-1]

        # dimension list and per-layer in/out pairs
        dims = [first(dim), *dim]
        dim_pairs = tuple(zip(dims[:-1], dims[1:]))

        # token embedding
        self.token_emb = nn.Embedding(num_tokens, init_dim)

        # total downsampling factor across all layers
        self.total_strides = torch.tensor(list(strides)).cumprod(dim = -1)[-1].item()
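
        # illustrative note (added for this walkthrough): with strides = (2, 2, 2),
        # total_strides = 8, so forward() below requires input sequence lengths divisible by 8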

        # encoder layers
        self.encoder = nn.ModuleList([])

        # per-layer parameter tuples
        layer_params = tuple(zip(
            strides,
            depth,
            local_attn_window_size,
            dim_pairs
        ))

        # initial transformer, before any downsampling
        self.init_transformer = LocalTransformer(
            dim = init_dim,
            depth = first(depth),
            heads = local_attn_heads,
            dim_head = local_attn_dim_head,
            window_size = first(local_attn_window_size)
        )

        # final transformer, after all upsampling
        self.final_transformer = LocalTransformer(
            dim = init_dim,
            depth = first(depth),
            heads = local_attn_heads,
            dim_head = local_attn_dim_head,
            window_size = first(local_attn_window_size)
        )

        # build the encoder: downsample then locally attend, per layer
        for layer_stride, layer_depth, layer_local_attn_window_size, (dim_in, dim_out) in layer_params:
            self.encoder.append(nn.ModuleList([
                Downsample(dim = dim_in, dim_out = dim_out, factor = layer_stride),
                LocalTransformer(
                    dim = dim_out,
                    depth = layer_depth,
                    heads = local_attn_heads,
                    dim_head = local_attn_dim_head,
                    window_size = layer_local_attn_window_size
                )
            ]))

        # encoder norm
        self.encoder_norm = nn.LayerNorm(vq_dim)

        # residual vector quantizer
        self.vq = ResidualVQ(
            dim = vq_dim,
            num_quantizers = num_codebooks,
            codebook_size = codebook_size,
            decay = vq_decay,
            quantize_dropout = num_codebooks > 1 and rvq_quantize_dropout,
            commitment_weight = 0.,   # the weight on the commitment loss
            kmeans_init = True,
            kmeans_iters = 10
        )

        # decoder layers
        self.decoder = nn.ModuleList([])

        # build the decoder in reverse: locally attend then upsample, per layer
        for layer_stride, layer_depth, layer_local_attn_window_size, (dim_in, dim_out) in reversed(layer_params):
            self.decoder.append(nn.ModuleList([
                Upsample(dim = dim_out, dim_out = dim_in, factor = layer_stride),
                LocalTransformer(
                    dim = dim_out,
                    depth = layer_depth,
                    heads = local_attn_heads,
                    dim_head = local_attn_dim_head,
                    window_size = layer_local_attn_window_size
                )
            ]))

        # projection to token logits
        self.to_logits = nn.Sequential(
            nn.LayerNorm(init_dim),
            nn.Linear(init_dim, num_tokens)
        )

    # save model weights and config
    def save(self, path):
        path = Path(path)
        pkg = dict(
            model = self.state_dict(),
            config = pickle.dumps(self._config)
        )
        torch.save(pkg, str(path))

    # load model weights
    def load(self, path):
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path))
        self.load_state_dict(pkg['model'])

    # instantiate a model from a saved config, then load its weights
    @classmethod
    def init_and_load(cls, path):
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path))
        model = cls(**pickle.loads(pkg['config']))
        model.load(path)
        return model

    # device the model lives on
    @property
    def device(self):
        return next(self.parameters()).device
    # encoder: embed the ids, then downsample and locally attend, layer by layer
    def encode(self, ids):
        # embed the ids into tokens
        tokens = self.token_emb(ids)

        # initial transformer
        tokens = self.init_transformer(tokens)

        # run each encoder layer: downsample, then local attention
        for downsample, local_attn in self.encoder:
            tokens = downsample(tokens)
            tokens = local_attn(tokens)

        # normalize the encoded tokens
        return self.encoder_norm(tokens)

    # decoder: locally attend and upsample, layer by layer, then project to logits
    def decode(self, codes):
        tokens = codes

        # run each decoder layer: local attention, then upsample
        for upsample, local_attn in self.decoder:
            tokens = local_attn(tokens)
            tokens = upsample(tokens)

        # final transformer
        tokens = self.final_transformer(tokens)

        # project tokens to logits
        logits = self.to_logits(tokens)
        return logits

    # decode directly from codebook indices
    @torch.no_grad()
    def decode_from_codebook_ids(self, codebook_ids):
        # convert codebook indices to codes with the quantizer
        codes = self.vq.get_codes_from_indices(codebook_ids)
        # decode the codes to logits
        return self.decode(codes)

    # full forward pass
    def forward(
        self,
        ids,
        return_codebook_indices = False,
        return_reconstruction = False,
        return_loss_breakdown = False
    ):
        batch, seq = ids.shape
        # the sequence length must be divisible by the total downsampling factor
        assert divisible_by(seq, self.total_strides)

        ids = ids.to(self.device)

        # encode ids to continuous tokens
        tokens = self.encode(ids)

        # vector-quantize, getting the quantized tokens, the codebook indices, and the (unused) aux loss
        tokens, indices, _ = self.vq(tokens)

        # optionally return just the codebook indices
        if return_codebook_indices:
            return indices

        # decode back to logits
        logits = self.decode(tokens)

        # rearrange to (batch, classes, seq) for cross entropy
        logits = rearrange(logits, 'b n c -> b c n')

        # reconstruction cross entropy loss
        loss = F.cross_entropy(
            logits,
            ids
        )

        # optionally return the reconstruction alongside the loss
        if return_reconstruction:
            return loss, logits.argmax(dim = 1)

        return loss
# hierarchical transformer over the quantized codes (stub, work in progress)
class Transformer(nn.Module):
    pass
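
Before moving on, a minimal usage sketch of TextVQVAE (the hyperparameters mirror the training script further down; the sequence length just needs to be divisible by the total stride):

import torch
from rvq_vae_gpt import TextVQVAE

model = TextVQVAE(
    num_tokens = 256,
    dim = (128, 256, 512),
    depth = (2, 2, 4),
    strides = (2, 2, 2),           # total stride of 8
    local_attn_window_size = 64,
    num_codebooks = 8
)

ids = torch.randint(0, 256, (1, 2048))

loss = model(ids)                                      # reconstruction cross entropy
loss.backward()

indices = model(ids, return_codebook_indices = True)   # discrete codes for a downstream transformer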

.\lucidrains\rvq-vae-gpt\rvq_vae_gpt\__init__.py

# import the TextVQVAE and Transformer classes from the rvq_vae_gpt.rvq_vae_gpt module
from rvq_vae_gpt.rvq_vae_gpt import TextVQVAE, Transformer

.\lucidrains\rvq-vae-gpt\setup.py

# import setup and package-finding utilities
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'rvq-vae-gpt',  # package name
  packages = find_packages(exclude=[]),  # find all packages
  version = '0.0.4',  # version number
  license='MIT',  # license
  description = 'Yet another attempt at GPT in quantized latent space',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  long_description_content_type = 'text/markdown',  # long description content type
  url = 'https://github.com/lucidrains/rvq-vae-gpt',  # project url
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism'
  ],
  install_requires=[  # dependencies
    'beartype',
    'einops>=0.4',
    'local-attention>=1.0.0',
    'torch>=1.6',
    'vector-quantize-pytorch>=1.1.2'
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\rvq-vae-gpt\train.py

# import required libraries
import gzip
import random
import tqdm
import numpy as np

import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from rvq_vae_gpt import TextVQVAE

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
SAVE_EVERY = 1000
SEQ_LEN = 2048

# helper functions

# cycle a dataloader endlessly
def cycle(loader):
    while True:
        for data in loader:
            yield data

# first element of an indexable
def first(it):
    return it[0]

# decode a token to a character
def decode_token(token):
    return str(chr(max(32, token)))

# decode tokens to a string
def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))

# instantiate the TextVQVAE model
model = TextVQVAE(
    num_tokens = 256,
    dim = (128, 256, 512),
    depth = (2, 2, 4),
    local_attn_window_size = 64,
    num_codebooks = 8,
    strides = (2, 2, 2)
).cuda()

# prepare enwik8 data
with gzip.open("./data/enwik8.gz") as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)

# dataset that samples random subsequences of text
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# create dataloaders for the training and validation sets
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# optimizer
optim = Adam(model.parameters(), lr = LEARNING_RATE)

# training
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
    model.train()

    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward()

    print(f"training loss: {loss.item():.3f}")
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

    optim.step()
    optim.zero_grad()

    if i == 0:
        continue

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            valid_text = next(val_loader)
            loss, recon = model(valid_text, return_reconstruction = True)

            print(f"validation loss: {loss.item():.3f}")

            print(f"\n\n\n[input text]\n\n {decode_tokens(first(valid_text))}")
            print(f"\n\n[reconstructed text]\n\n {decode_tokens(first(recon))}\n\n")

    if i % SAVE_EVERY == 0:
        model.save('./text-vae.pt')

SAC (Soft Actor Critic) - Pytorch (wip)

Implementation of Soft Actor Critic and some of its improvements in Pytorch. Interest comes from watching this lecture

Temporary Discord

Citations

@article{Haarnoja2018SoftAA,
    title   = {Soft Actor-Critic Algorithms and Applications},
    author  = {Tuomas Haarnoja and Aurick Zhou and Kristian Hartikainen and G. Tucker and Sehoon Ha and Jie Tan and Vikash Kumar and Henry Zhu and Abhishek Gupta and P. Abbeel and Sergey Levine},
    journal = {ArXiv},
    year    = {2018},
    volume  = {abs/1812.05905},
    url     = {https://api.semanticscholar.org/CorpusID:55703664}
}
@article{Hiraoka2021DropoutQF,
    title   = {Dropout Q-Functions for Doubly Efficient Reinforcement Learning},
    author  = {Takuya Hiraoka and Takahisa Imagawa and Taisei Hashimoto and Takashi Onishi and Yoshimasa Tsuruoka},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2110.02034},
    url     = {https://api.semanticscholar.org/CorpusID:238353966}
}
@inproceedings{ObandoCeron2024MixturesOE,
    title   = {Mixtures of Experts Unlock Parameter Scaling for Deep RL},
    author  = {Johan S. Obando-Ceron and Ghada Sokar and Timon Willi and Clare Lyle and Jesse Farebrother and Jakob Foerster and Gintare Karolina Dziugaite and Doina Precup and Pablo Samuel Castro},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:267637059}
}

.\lucidrains\SAC-pytorch\SAC_pytorch\SAC.py

import torch
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList

from beartype import beartype
from beartype.typing import Tuple, List, Optional, Union

# get_at from the einx package
from einx import get_at
from einops import rearrange, repeat, reduce, pack, unpack

# exponential moving average wrapper
from ema_pytorch import EMA

# check that a value exists
def exists(v):
    return v is not None

# cast a value to a tuple of the given length
def cast_tuple(t, length = 1):
    return t if isinstance(t, tuple) else ((t,) * length)

# MLP builder, creating a simple multilayer perceptron
@beartype
def MLP(
    dim,
    dim_out,
    dim_hiddens: Union[int, Tuple[int, ...]],
    layernorm = False,
    dropout = 0.,
    activation = nn.ReLU
):
    """
    simple mlp for Q and value networks

    following Figure 1 in https://arxiv.org/pdf/2110.02034.pdf for placement of dropouts and layernorm
    however, be aware that Levine in his lecture has ablations that show layernorm alone (without dropout) is sufficient for regularization
    """

    dim_hiddens = cast_tuple(dim_hiddens)

    layers = []

    curr_dim = dim

    for dim_hidden in dim_hiddens:
        layers.append(nn.Linear(curr_dim, dim_hidden))

        layers.append(nn.Dropout(dropout))

        if layernorm:
            layers.append(nn.LayerNorm(dim_hidden))

        layers.append(activation())

        curr_dim = dim_hidden

    # final layer out

    layers.append(nn.Linear(curr_dim, dim_out))

    return nn.Sequential(*layers)
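
# illustrative usage (added for this walkthrough, not part of the original file):
# q_net = MLP(dim = 10, dim_out = 1, dim_hiddens = (256, 256), layernorm = True, dropout = 0.01)
# q_net(torch.randn(4, 10)).shape  # torch.Size([4, 1])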

# actor network
class Actor(Module):
    def __init__(
        self,
        *,
        dim_state,
        num_cont_actions,
        dim_hiddens: Tuple[int, ...] = tuple(),
        eps = 1e-5
    ):
        super().__init__()
        self.eps = eps

        self.to_cont_actions = MLP(
            dim_state,
            dim_hiddens = dim_hiddens,
            dim_out = num_cont_actions * 2
        )

    def forward(
        self,
        state,
        sample = False
    ):
        """
        einops notation
        n - num actions
        ms - mu sigma
        """

        out = self.to_cont_actions(state)
        mu, sigma = rearrange(out, '... (n ms) -> ms ... n', ms = 2)

        sigma = sigma.sigmoid().clamp(min = self.eps)

        if not sample:
            return mu, sigma

        return mu + sigma * torch.randn_like(sigma)
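
# illustrative note (added for this walkthrough): for state of shape (batch, dim_state),
# Actor returns (mu, sigma), each of shape (batch, num_cont_actions), or a reparameterized
# sample mu + sigma * noise when sample = True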

# critic network
class Critic(Module):
    @beartype
    def __init__(
        self,
        *,
        dim_state,
        num_continuous_actions,
        dim_hiddens: Tuple[int, ...] = tuple(),
        layernorm = False,
        dropout = 0.
    ):
        super().__init__()

        self.to_q = MLP(
            dim_state + num_continuous_actions,
            dim_out = 1,
            dim_hiddens = dim_hiddens,
            layernorm = layernorm,
            dropout = dropout
        )

    def forward(
        self,
        state,
        actions
    ):
        # pack state and actions along the feature dimension
        state_actions, _ = pack([state, actions], 'b *')

        q_values = self.to_q(state_actions)
        q_values = rearrange(q_values, 'b 1 -> b')

        return q_values

# value network
class ValueNetwork(Module):
    @beartype
    def __init__(
        self,
        *,
        dim_state,
        dim_hiddens: Tuple[int, ...] = tuple()
    ):
        super().__init__()

        self.to_values = MLP(
            dim_state,
            dim_out= 1,
            dim_hiddens = dim_hiddens
        )

    def forward(
        self,
        states
    ):
        values = self.to_values(states)
        values = rearrange(values, 'b 1 -> b')
        return values

# main SAC class (stub, work in progress)
class SAC(Module):
    def __init__(
        self
    ):
        super().__init__()

    def forward(self, x):
        return x
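
Since the SAC class above is still a stub, here is a minimal sketch of how the Actor and Critic defined in this file fit together for a single Q-value estimate (my own illustration with hypothetical dimensions, not the repository's API):

import torch
from SAC_pytorch.SAC import Actor, Critic

actor = Actor(dim_state = 8, num_cont_actions = 2, dim_hiddens = (64, 64))

critic = Critic(
    dim_state = 8,
    num_continuous_actions = 2,
    dim_hiddens = (64, 64),
    layernorm = True,
    dropout = 0.01
)

state = torch.randn(4, 8)

actions = actor(state, sample = True)  # reparameterized action sample, shape (4, 2)
q_values = critic(state, actions)      # shape (4,)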

.\lucidrains\SAC-pytorch\SAC_pytorch\__init__.py

# import the SAC class from the SAC_pytorch package
from SAC_pytorch.SAC import SAC

.\lucidrains\SAC-pytorch\setup.py

# import setup and package-finding utilities
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'SAC-pytorch',  # package name
  packages = find_packages(exclude=[]),  # find all packages
  version = '0.0.1',  # version number
  license='MIT',  # license
  description = 'Soft Actor Critic',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  long_description_content_type = 'text/markdown',  # long description content type
  url = 'https://github.com/lucidrains/SAC-pytorch',  # project url
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'reinforcement learning',
    'soft actor critic'
  ],
  install_requires=[  # dependencies
    'beartype',
    'einops>=0.7.0',
    'einx[torch]>=0.1.3',
    'ema-pytorch',
    'pytorch-custom-utils>=0.0.18',
    'soft-moe-pytorch>=0.1.6',
    'torch>=2.0'
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
)

Scattering Compositional Learner

Implementation of Scattering Compositional Learner, which reached superhuman levels on Raven's Progressive Matrices, a type of IQ test for analogical reasoning.

This repository is meant to be exploratory, so it may not follow the exact architecture of the paper down to the T. It is meant to find the underlying inductive bias that could be exported for use in attention networks. The paper suggests this to be the 'Scattering Transform', which is basically grouped convolutions, but where each group is transformed by one shared neural network.

If you would like the exact architecture used in the paper, the official repository is here.

Install

$ pip install scattering-transform

Use

Complete Scattering Compositional Learner network

import torch
import torch.nn.functional as F
from scattering_transform import SCL, SCLTrainingWrapper

# data - (batch, number of choices, channel dimension, image height, image width)

questions = torch.randn(1, 8, 1, 160, 160)
answers   = torch.randn(1, 8, 1, 160, 160)
labels    = torch.tensor([2])

# instantiate model

model = SCL(
    image_size = 160,                           # size of image
    set_size = 9,                               # number of questions + 1 answer
    conv_channels = [1, 16, 16, 32, 32, 32],    # convolutional channel progression, 1 for greyscale, 3 for rgb
    conv_output_dim = 80,                       # model dimension, the output dimension of the vision net
    attr_heads = 10,                            # number of attribute heads
    attr_net_hidden_dims = [128],               # attribute scatter transform MLP hidden dimension(s)
    rel_heads = 80,                             # number of relationship heads
    rel_net_hidden_dims = [64, 23, 5]           # MLP for relationship net
)

model = SCLTrainingWrapper(model)
logits = model(questions, answers) # (1, 8) - the logits of each answer being the correct match

# train

loss = F.cross_entropy(logits, labels)
loss.backward()

Scattering Transform, which is basically one MLP that acts over groups of the feature dimension

import torch
from torch import nn
from scattering_transform import ScatteringTransform

# for potential use in a Transformer

mlp = ScatteringTransform(
    dims = [1024, 4096, 1024],    # MLP - dimension in -> hidden sizes -> dimension out
    heads = 16,                   # number of groups (heads)
    activation = nn.LeakyReLU     # activation to use in the MLP
)

x = torch.randn(1, 512, 1024)
mlp(x) # (1, 512, 1024)

Citation

@misc{wu2020scattering,
    title={The Scattering Compositional Learner: Discovering Objects, Attributes, Relationships in Analogical Reasoning},
    author={Yuhuai Wu and Honghua Dong and Roger Grosse and Jimmy Ba},
    year={2020},
    eprint={2007.04212},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

.\lucidrains\scattering-compositional-learner\scattering_transform\scattering_transform.py

import torch
from torch import nn
import torch.nn.functional as F

# helper functions

# return val if it is not None, otherwise default_val
def default(val, default_val):
    return val if val is not None else default_val

# unsqueeze tensor t at the given dim and expand that dim to size k
def expand_dim(t, dim, k):
    t = t.unsqueeze(dim)
    expand_shape = [-1] * len(t.shape)
    expand_shape[dim] = k
    return t.expand(*expand_shape)
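
# illustrative example (added for this walkthrough, not in the original file):
# expand_dim(torch.randn(2, 3), dim = 1, k = 8).shape  # torch.Size([2, 8, 3])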

# simple multilayer perceptron with ReLU activation by default

class MLP(nn.Module):
    def __init__(self, *dims, activation = None):
        super().__init__()
        assert len(dims) > 2, 'must have at least 3 dimensions, for dimension in and dimension out'
        activation = default(activation, nn.ReLU)

        layers = []
        pairs = list(zip(dims[:-1], dims[1:]))

        for ind, (dim_in, dim_out) in enumerate(pairs):
            is_last = ind >= (len(pairs) - 1)
            layers.append(nn.Linear(dim_in, dim_out))
            if not is_last:
                layers.append(activation())

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

# feedforward residual block mentioned in the paper
# used after extracting visual features, as well as after extracting attribute information

class FeedForwardResidual(nn.Module):
    def __init__(self, dim, mult = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.LayerNorm(dim * mult),
            nn.ReLU(inplace = True),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return x + self.net(x)

# convolutional network
# to-do: make customizable, and add EvoNorm for batch-independent normalization

class ConvNet(nn.Module):
    def __init__(self, image_size, chans, output_dim):
        super().__init__()

        num_conv_layers = len(chans) - 1
        conv_output_size = image_size // (2 ** num_conv_layers)

        convolutions = []
        channel_pairs = list(zip(chans[:-1], chans[1:]))

        for ind, (chan_in, chan_out) in enumerate(channel_pairs):
            is_last = ind >= (len(channel_pairs) - 1)
            convolutions.append(nn.Conv2d(chan_in, chan_out, 3, padding=1, stride=2))
            if not is_last:
                convolutions.append(nn.BatchNorm2d(chan_out))

        self.net = nn.Sequential(
            *convolutions,
            nn.Flatten(1),
            nn.Linear(chans[-1] * (conv_output_size ** 2), output_dim),
            nn.ReLU(inplace=True),
            FeedForwardResidual(output_dim)
        )

    def forward(self, x):
        return self.net(x)

# scattering transform

class ScatteringTransform(nn.Module):
    def __init__(self, dims, heads, activation = None):
        super().__init__()
        assert len(dims) > 2, 'must have at least 3 dimensions, for dimension in, the hidden dimension, and dimension out'

        dim_in, *hidden_sizes, dim_out = dims

        dim_in //= heads
        dim_out //= heads

        self.heads = heads
        self.mlp = MLP(dim_in, *hidden_sizes, dim_out, activation = activation)

    def forward(self, x):
        shape, heads = x.shape, self.heads
        dim = shape[-1]

        assert (dim % heads) == 0, f'the dimension {dim} must be divisible by the number of heads {heads}'

        x = x.reshape(-1, heads, dim // heads)
        x = self.mlp(x)

        return x.reshape(shape)
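
# illustrative note (added for this walkthrough): with dims = [1024, 4096, 1024] and heads = 16,
# as in the README example above, dim_in and dim_out are split across heads while the hidden
# size is not, so one shared MLP maps each of the 16 groups 64 -> 4096 -> 64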

# main scattering compositional learner class

class SCL(nn.Module):
    # initializer, setting up the model's hyperparameters
    def __init__(
        self,
        image_size = 160,                         # image size
        set_size = 9,                             # set size (number of questions + 1 answer)
        conv_channels = [1, 16, 16, 32, 32, 32],  # convolutional channel progression
        conv_output_dim = 80,                     # output dimension of the vision net
        attr_heads = 10,                          # number of attribute heads
        attr_net_hidden_dims = [128],             # attribute network hidden dimension(s)
        rel_heads = 80,                           # number of relationship heads
        rel_net_hidden_dims = [64, 23, 5]):       # relationship network hidden dimensions

        super().__init__()
        # vision net
        self.vision = ConvNet(image_size, conv_channels, conv_output_dim)

        # attribute heads and attribute network
        self.attr_heads = attr_heads
        self.attr_net = ScatteringTransform([conv_output_dim, *attr_net_hidden_dims, conv_output_dim], heads = attr_heads)
        self.ff_residual = FeedForwardResidual(conv_output_dim)

        # relationship heads and relationship network
        self.rel_heads = rel_heads
        self.rel_net = MLP(set_size * (conv_output_dim // rel_heads), *rel_net_hidden_dims)

        # linear layer projecting to logits
        self.to_logit = nn.Linear(rel_net_hidden_dims[-1] * rel_heads, 1)

    def forward(self, sets):
        # input shape: (batch, permutations, set size, channels, height, width)
        b, m, n, c, h, w = sets.shape
        # flatten into a batch of images
        images = sets.view(-1, c, h, w)
        # extract visual features
        features = self.vision(images)

        # compute attributes
        attrs = self.attr_net(features)
        attrs = self.ff_residual(attrs)

        # regroup the attributes by relationship head
        attrs = attrs.reshape(b, m, n, self.rel_heads, -1).transpose(-2, -3).flatten(3)
        # compute relationships
        rels = self.rel_net(attrs)
        rels = rels.flatten(2)

        # compute logits
        logits = self.to_logit(rels).flatten(1)
        return logits

# wrapper class to make training easier
class SCLTrainingWrapper(nn.Module):
    def __init__(self, scl):
        super().__init__()
        self.scl = scl

    # forward pass, taking questions and answers
    def forward(self, questions, answers):
        # add a set dimension to the answers
        answers = answers.unsqueeze(2)
        # expand the questions, one copy per candidate answer (8 of them)
        questions = expand_dim(questions, dim = 1, k = 8)

        # concatenate questions and answers along the set dimension
        permutations = torch.cat((questions, answers), dim = 2)
        # run each (questions + candidate answer) permutation through SCL
        return self.scl(permutations)

.\lucidrains\scattering-compositional-learner\scattering_transform\__init__.py

# import the SCL, ScatteringTransform and SCLTrainingWrapper classes from the scattering_transform package
from scattering_transform.scattering_transform import SCL, ScatteringTransform, SCLTrainingWrapper

.\lucidrains\scattering-compositional-learner\setup.py

# import setup and package-finding utilities
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'scattering-transform',  # package name
  packages = find_packages(),  # find all packages
  version = '0.0.7',  # version number
  license='MIT',  # license
  description = 'Scattering Transform module from the paper Scattering Compositional Learner',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/scattering-compositional-learner',  # project url
  keywords = ['artificial intelligence', 'deep learning', 'reasoning'],  # keywords
  install_requires=[
      'torch'  # dependencies
  ],
  classifiers=[
      'Development Status :: 4 - Beta',
      'Intended Audience :: Developers',
      'Topic :: Scientific/Engineering :: Artificial Intelligence',
      'License :: OSI Approved :: MIT License',
      'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\se3-transformer-pytorch\denoise.py

import torch
import torch.nn.functional as F
from torch.optim import Adam

from einops import rearrange, repeat

import sidechainnet as scn
from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer

# use float64 by default for numerical stability
torch.set_default_dtype(torch.float64)

BATCH_SIZE = 1
GRADIENT_ACCUMULATE_EVERY = 16

# cycle the dataloader endlessly, skipping overly long sequences
def cycle(loader, len_thres = 500):
    while True:
        for data in loader:
            if data.seqs.shape[1] > len_thres:
                continue
            yield data

# create the SE3Transformer model
transformer = SE3Transformer(
    num_tokens = 24,
    dim = 8,
    dim_head = 8,
    heads = 2,
    depth = 2,
    attend_self = True,
    input_degrees = 1,
    output_degrees = 2,
    reduce_dim_out = True,
    differentiable_coors = True,
    num_neighbors = 0,
    attend_sparse_neighbors = True,
    num_adj_degrees = 2,
    adj_dim = 4,
    num_degrees=2,
)

# load the sidechainnet dataset
data = scn.load(
    casp_version = 12,
    thinning = 30,
    with_pytorch = 'dataloaders',
    batch_size = BATCH_SIZE,
    dynamic_batching = False
)

# dataloader
dl = cycle(data['train'])
# Adam optimizer over the SE3Transformer parameters
optim = Adam(transformer.parameters(), lr=1e-4)
# move the model to the GPU
transformer = transformer.cuda()

# training loop
for _ in range(10000):
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        # fetch a batch
        batch = next(dl)
        seqs, coords, masks = batch.seqs, batch.crds, batch.msks

        # move sequences to the GPU, taking the argmax of the one-hot encoding
        seqs = seqs.cuda().argmax(dim = -1)
        # move coordinates to the GPU as float64
        coords = coords.cuda().type(torch.float64)
        # move masks to the GPU as booleans
        masks = masks.cuda().bool()

        # sequence length
        l = seqs.shape[1]
        # group coordinates by residue (14 atoms per residue)
        coords = rearrange(coords, 'b (l s) c -> b l s c', s = 14)

        # keep only the backbone coordinates (first 3 atoms per residue)
        coords = coords[:, :, 0:3, :]
        coords = rearrange(coords, 'b l s c -> b (l s) c')

        # repeat the sequence tokens and masks, once per backbone atom
        seq = repeat(seqs, 'b n -> b (n c)', c = 3)
        masks = repeat(masks, 'b n -> b (n c)', c = 3)

        # add Gaussian noise to the coordinates
        noised_coords = coords + torch.randn_like(coords).cuda()

        # build the adjacency matrix (a banded chain)
        i = torch.arange(seq.shape[-1], device = seqs.device)
        adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

        # forward pass through the SE3Transformer
        out = transformer(
            seq,
            noised_coords,
            mask = masks,
            adj_mat = adj_mat,
            return_type = 1
        )

        # mean squared error between the denoised and original coordinates
        denoised_coords = noised_coords + out
        loss = F.mse_loss(denoised_coords[masks], coords[masks])
        # backpropagate, scaled by the accumulation factor
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    # log the loss and step the optimizer
    print('loss:', loss.item())
    optim.step()
    optim.zero_grad()
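
# note (added for this walkthrough): the adjacency matrix above marks positions i and j as
# connected when |i - j| <= 1, a banded matrix linking each backbone atom to itself and its
# immediate neighbors along the chain, which attend_sparse_neighbors then makes use of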

SE3 Transformer - Pytorch

Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. May be needed for replicating Alphafold2 results and other drug discovery applications.

Example of equivariance (Colab notebook)

If you have been using any version of SE3 Transformers prior to version 0.6.0, please update. A huge bug was uncovered by @MattMcPartlon, affecting anyone who was not using the adjacency sparse neighbors setting and was instead relying on the nearest-neighbors functionality

Update: It is recommended that you use Equiformer instead

Install

$ pip install se3-transformer-pytorch

Usage

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 512,
    heads = 8,
    depth = 6,
    dim_head = 64,
    num_degrees = 4,
    valid_radius = 10
)

feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
mask  = torch.ones(1, 1024).bool()

out = model(feats, coors, mask) # (1, 1024, 512)

Potential example usage in Alphafold2, as outlined here

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 2,
    input_degrees = 1,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True,
    differentiable_coors = True
)

atom_feats = torch.randn(2, 32, 64)
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refined_coors = coors + model(atom_feats, coors, mask, return_type = 1) # (2, 32, 3)

You can also let the base transformer class take care of embedding the type 0 features being passed in. Assuming they are atoms

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 28,       # 28 unique atoms
    dim = 64,
    depth = 2,
    input_degrees = 1,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True
)

atoms = torch.randint(0, 28, (2, 32))
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refined_coors = coors + model(atoms, coors, mask, return_type = 1) # (2, 32, 3)

If you think the net could further benefit from positional encoding, you can featurize your positions in space and pass it in as follows.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 2,
    input_degrees = 2,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True  # reduce out the final dimension
)

atom_feats  = torch.randn(2, 32, 64, 1) # b x n x d x type0
coors_feats = torch.randn(2, 32, 64, 3) # b x n x d x type1

# atom features are type 0, predicted coordinates are type 1
features = {'0': atom_feats, '1': coors_feats}
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refined_coors = coors + model(features, coors, mask, return_type = 1) # (2, 32, 3) - equivariant to input type 1 features and coordinates

Edges

To offer edge information to SE3 Transformers (say bond types between atoms), you just have to pass in two more keyword arguments on initialization.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 28,
    dim = 64,
    num_edge_tokens = 4,       # number of edge type, say 4 bond types
    edge_dim = 16,             # dimension of edge embedding
    depth = 2,
    input_degrees = 1,
    num_degrees = 3,
    output_degrees = 1,
    reduce_dim_out = True
)

atoms = torch.randint(0, 28, (2, 32))
bonds = torch.randint(0, 4, (2, 32, 32))
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

pred = model(atoms, coors, mask, edges = bonds, return_type = 0) # (2, 32, 1)

If you would like to pass in continuous values for your edges, you can choose not to set num_edge_tokens, encode your discrete bond types, and then concatenate them to the Fourier features of these continuous values

import torch
from se3_transformer_pytorch import SE3Transformer
from se3_transformer_pytorch.utils import fourier_encode

model = SE3Transformer(
    dim = 64,
    depth = 1,
    attend_self = True,
    num_degrees = 2,
    output_degrees = 2,
    edge_dim = 34           # edge dimension must match the final dimension of the edges being passed in
)

feats = torch.randn(1, 32, 64)
coors = torch.randn(1, 32, 3)
mask  = torch.ones(1, 32).bool()

pairwise_continuous_values = torch.randint(0, 4, (1, 32, 32, 2))  # say there are 2 continuous edge values

edges = fourier_encode(
    pairwise_continuous_values,
    num_encodings = 8,
    include_self = True
) # (1, 32, 32, 34) - {2 * (2 * 8 + 1)}

out = model(feats, coors, mask, edges = edges, return_type = 1)

Sparse Neighbors

If you know the connectivity of your points (say you are working with molecules), you can pass in an adjacency matrix, in the form of a boolean mask (where True indicates connectivity).

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 32,
    heads = 8,
    depth = 1,
    dim_head = 64,
    num_degrees = 2,
    valid_radius = 10,
    attend_sparse_neighbors = True,  # this must be set to true, in which case it will assert that you pass in the adjacency matrix
    num_neighbors = 0,               # if you set this to 0, it will only consider the connected neighbors as defined by the adjacency matrix. but if you set a value greater than 0, it will continue to fetch the closest points up to this many, excluding the ones already specified by the adjacency matrix
    max_sparse_neighbors = 8         # you can cap the number of neighbors, sampled from within your sparse set of neighbors as defined by the adjacency matrix, if specified
)

feats = torch.randn(1, 128, 32)
coors = torch.randn(1, 128, 3)
mask  = torch.ones(1, 128).bool()

# placeholder adjacency matrix
# naively assuming the sequence is one long chain (128, 128)

i = torch.arange(128)
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))

out = model(feats, coors, mask, adj_mat = adj_mat) # (1, 128, 512)

You can also have the network automatically derive for you the Nth-degree neighbors with one extra keyword num_adj_degrees. If you would like the system to differentiate between the degree of the neighbors as edge information, further pass in a non-zero adj_dim.

import torch
from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 1,
    attend_self = True,
    num_degrees = 2,
    output_degrees = 2,
    num_neighbors = 0,
    attend_sparse_neighbors = True,
    num_adj_degrees = 2,    # automatically derive 2nd degree neighbors
    adj_dim = 4             # embed 1st and 2nd degree neighbors (as well as null neighbors) with edge embeddings of this dimension
)

feats = torch.randn(1, 32, 64)
coors = torch.randn(1, 32, 3)
mask  = torch.ones(1, 32).bool()

# placeholder adjacency matrix
# naively assuming the sequence is one long chain (32, 32)

i = torch.arange(32)
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))

out = model(feats, coors, mask, adj_mat = adj_mat, return_type = 1)
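
Conceptually, kth-degree neighbors can be read off boolean powers of the adjacency matrix. The toy illustration below is not the library's internal code, just a way to see what num_adj_degrees computes:

import torch

i = torch.arange(6)
adj = (i[:, None] - i[None, :]).abs() == 1  # chain adjacency over 6 nodes

eye    = torch.eye(6, dtype = torch.bool)
first  = adj
second = (first.float() @ first.float()).bool() & ~first & ~eye  # nodes exactly two hops away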

To have fine control over the dimensionality of each type, you can use the hidden_fiber_dict and out_fiber_dict keywords to pass in a dictionary mapping each degree to its dimension.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 28,
    dim = 64,
    num_edge_tokens = 4,
    edge_dim = 16,
    depth = 2,
    input_degrees = 1,
    num_degrees = 3,
    output_degrees = 1,
    hidden_fiber_dict = {0: 16, 1: 8, 2: 4},
    out_fiber_dict = {0: 16, 1: 1},
    reduce_dim_out = False
)

atoms = torch.randint(0, 28, (2, 32))
bonds = torch.randint(0, 4, (2, 32, 32))
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

pred = model(atoms, coors, mask, edges = bonds)

pred['0'] # (2, 32, 16)
pred['1'] # (2, 32, 1, 3)

Note the trailing dimensions: a degree-m feature carries 2m + 1 components, which is why the single degree-1 output channel above ends in 3, while the degree-0 output has its singleton component dimension squeezed away.

Neighbors

You can further control which nodes can be considered by passing in a neighbor mask. All False values will be masked out of consideration.

import torch
from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 16,
    dim_head = 16,
    attend_self = True,
    num_degrees = 4,
    output_degrees = 2,
    num_edge_tokens = 4,
    num_neighbors = 8,      # make sure you set this value as the maximum number of neighbors set by your neighbor_mask, or it will throw a warning
    edge_dim = 2,
    depth = 3
)

feats = torch.randn(1, 32, 16)
coors = torch.randn(1, 32, 3)
mask  = torch.ones(1, 32).bool()
bonds = torch.randint(0, 4, (1, 32, 32))

neighbor_mask = torch.ones(1, 32, 32).bool() # set the nodes you wish to be masked out as False

out = model(
    feats,
    coors,
    mask,
    edges = bonds,
    neighbor_mask = neighbor_mask,
    return_type = 1
)
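
One natural way to derive such a mask is from a pairwise distance cutoff; a minimal sketch (the cutoff of 5.0 is an arbitrary illustrative value):

import torch

coors = torch.randn(1, 32, 3)

dist = torch.cdist(coors, coors)  # (1, 32, 32) pairwise distances
neighbor_mask = dist < 5.0        # True only for pairs within the cutoff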

Global Nodes

This feature allows you to pass in vectors that can be viewed as global nodes that are seen by all other nodes. The idea would be to pool your graph into a few feature vectors, which will be projected to key / values across all the attention layers in the network. All nodes will have full access to global node information, regardless of nearest neighbors or adjacency calculation.

import torch
from torch import nn
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 1,
    num_degrees = 2,
    num_neighbors = 4,
    valid_radius = 10,
    global_feats_dim = 32 # this must be set to the dimension of the global features, in this example, 32
)

feats = torch.randn(1, 32, 64)
coors = torch.randn(1, 32, 3)
mask  = torch.ones(1, 32).bool()

# naively derive global features
# by pooling features and projecting
global_feats = nn.Linear(64, 32)(feats.mean(dim = 1, keepdim = True)) # (1, 1, 32)

out = model(feats, coors, mask, return_type = 0, global_feats = global_feats)

Todo:

  • allow global nodes to attend to all other nodes, to give the network a global conduit for information. (Similar to BigBird, ETC, Longformer etc)

Autoregressive

You can use SE3 Transformers autoregressively with just one extra flag.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 512,
    heads = 8,
    depth = 6,
    dim_head = 64,
    num_degrees = 4,
    valid_radius = 10,
    causal = True          # set this to True
)

feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
mask  = torch.ones(1, 1024).bool()

out = model(feats, coors, mask) # (1, 1024, 512)
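
As a quick sanity check of causality (a hedged test sketch, not part of the library's test suite), you can perturb the features at the last position and confirm that outputs at earlier positions do not change:

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 32,
    depth = 1,
    num_degrees = 2,
    causal = True
).eval()

feats = torch.randn(1, 8, 32)
coors = torch.randn(1, 8, 3)
mask  = torch.ones(1, 8).bool()

out1 = model(feats, coors, mask)

feats = feats.clone()
feats[:, -1] = torch.randn(32)  # perturb only the last position

out2 = model(feats, coors, mask)

assert torch.allclose(out1[:, :-1], out2[:, :-1], atol = 1e-6)  # earlier positions unaffected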

Experimental Features

Non-pairwise convolved keys

I've discovered that using linearly projected keys (rather than the pairwise convolution) seems to do ok in a toy denoising task, and leads to 25% memory savings. You can try this feature by setting linear_proj_keys = True.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 1,
    num_degrees = 4,
    num_neighbors = 8,
    valid_radius = 10,
    splits = 4,
    linear_proj_keys = True # set this to True
).cuda()

feats = torch.randn(1, 32, 64).cuda()
coors = torch.randn(1, 32, 3).cuda()
mask  = torch.ones(1, 32).bool().cuda()

out = model(feats, coors, mask, return_type = 0)

Shared key / values across all heads

There is a relatively unknown technique for transformers where one can share one key / value head across all the heads of the queries. In my experience in NLP, this usually leads to worse performance, but if you really need to trade off memory for more depth or a higher number of degrees, this may be a good option.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 8,
    num_degrees = 4,
    num_neighbors = 8,
    valid_radius = 10,
    splits = 4,
    one_headed_key_values = True  # one head of key / values shared across all heads of the queries
).cuda()

feats = torch.randn(1, 32, 64).cuda()
coors = torch.randn(1, 32, 3).cuda()
mask  = torch.ones(1, 32).bool().cuda()

out = model(feats, coors, mask, return_type = 0)

Tied key / values

You can also tie the keys / values (have them be the same), for half the memory savings.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 8,
    num_degrees = 4,
    num_neighbors = 8,
    valid_radius = 10,
    splits = 4,
    tie_key_values = True # set this to True
).cuda()

feats = torch.randn(1, 32, 64).cuda()
coors = torch.randn(1, 32, 3).cuda()
mask  = torch.ones(1, 32).bool().cuda()

out = model(feats, coors, mask, return_type = 0)

Using EGNN

This is an experimental version of EGNN that works for higher types and dimensionality greater than 1 (for the coordinates). The class name is still SE3Transformer since it reuses some preexisting logic, so just ignore that for now until I clean it up later.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 32,
    num_neighbors = 8,
    num_edge_tokens = 4,
    edge_dim = 4,
    num_degrees = 4,       # number of higher order types - will use basis on a TCN to project to these dimensions
    use_egnn = True,       # set this to true to use EGNN instead of equivariant attention layers
    egnn_hidden_dim = 64,  # egnn hidden dimension
    depth = 4,             # depth of EGNN
    reduce_dim_out = True  # will project the dimension of the higher types to 1
).cuda()

feats = torch.randn(2, 32, 32).cuda()
coors = torch.randn(2, 32, 3).cuda()
bonds = torch.randint(0, 4, (2, 32, 32)).cuda()
mask  = torch.ones(2, 32).bool().cuda()

refinement = model(feats, coors, mask, edges = bonds, return_type = 1) # (2, 32, 3)

coors = coors + refinement  # update coors with refinement
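
Since the model outputs a refinement, nothing stops you from applying it iteratively, recycling the updated coordinates back in; a sketch (three iterations is an arbitrary choice):

for _ in range(3):
    refinement = model(feats, coors, mask, edges = bonds, return_type = 1)
    coors = coors + refinement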

If you would like to specify individual dimensions for each of the higher types, just pass in hidden_fiber_dict, where the dictionary is in the format {<degree>: <dim>}, instead of num_degrees.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 32,
    num_neighbors = 8,
    hidden_fiber_dict = {0: 32, 1: 16, 2: 8, 3: 4},
    use_egnn = True,
    depth = 4,
    egnn_hidden_dim = 64,
    egnn_weights_clamp_value = 2,  # clamp the EGNN weights to this absolute value, for stability
    reduce_dim_out = True
).cuda()

feats = torch.randn(2, 32, 32).cuda()
coors = torch.randn(2, 32, 3).cuda()
mask  = torch.ones(2, 32).bool().cuda()

refinement = model(feats, coors, mask, return_type = 1) # (2, 32, 3)

coors = coors + refinement  # update coors with refinement

Scaling (wip)

This section will list ongoing efforts to make SE3 Transformer scale a little better.

Firstly, I have added reversible networks. This allows me to add a little more depth before hitting the usual memory roadblocks. Equivariance preservation is demonstrated in the tests.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 20,
    dim = 32,
    dim_head = 32,
    heads = 4,
    depth = 12,             # 12 layers
    input_degrees = 1,
    num_degrees = 3,
    output_degrees = 1,
    reduce_dim_out = True,
    reversible = True       # set reversible to True
).cuda()

atoms = torch.randint(0, 4, (2, 32)).cuda()
coors = torch.randn(2, 32, 3).cuda()
mask  = torch.ones(2, 32).bool().cuda()

pred = model(atoms, coors, mask = mask, return_type = 0)

loss = pred.sum()
loss.backward()

Examples

First install sidechainnet

$ pip install sidechainnet

Then run the protein backbone denoising task

$ python denoise.py

Caching

By default, the basis vectors are cached. However, if there is ever the need to clear the cache, you simply have to set the environment variable CLEAR_CACHE to some value when invoking the script.

$ CLEAR_CACHE=1 python train.py

Or you can try deleting the cache directory, which should exist at

$ rm -rf ~/.cache.equivariant_attention

You can also designate your own directory where you want the caches to be stored, in case the default directory has permission issues.

$ CACHE_PATH=./path/to/my/cache python train.py
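
Equivalently, assuming the variable is read when the library is first imported, you can set it from Python before the import:

import os
os.environ['CACHE_PATH'] = './path/to/my/cache'  # must be set before importing the library

from se3_transformer_pytorch import SE3Transformer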

Testing

$ python setup.py pytest

Credit

This library is largely a port of Fabian's official repository, but without the DGL library.

Citations

@misc{fuchs2020se3transformers,
    title   = {SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks}, 
    author  = {Fabian B. Fuchs and Daniel E. Worrall and Volker Fischer and Max Welling},
    year    = {2020},
    eprint  = {2006.10503},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{satorras2021en,
    title   = {E(n) Equivariant Graph Neural Networks},
    author  = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
    year    = {2021},
    eprint  = {2102.09844},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{gomez2017reversible,
    title     = {The Reversible Residual Network: Backpropagation Without Storing Activations},
    author    = {Aidan N. Gomez and Mengye Ren and Raquel Urtasun and Roger B. Grosse},
    year      = {2017},
    eprint    = {1707.04585},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{shazeer2019fast,
    title   = {Fast Transformer Decoding: One Write-Head is All You Need},
    author  = {Noam Shazeer},
    year    = {2019},
    eprint  = {1911.02150},
    archivePrefix = {arXiv},
    primaryClass = {cs.NE}
}