Lucidrains-系列项目源码解析-六十八-

131 阅读12分钟

Lucidrains 系列项目源码解析(六十八)

.\lucidrains\nuwa-pytorch\nuwa_pytorch\__init__.py

# 从 nuwa_pytorch.nuwa_pytorch 模块中导入 NUWA、NUWASketch、NUWAVideoAudio、Sparse3DNA、CrossModalityCrossAttention 类
# 以及从 nuwa_pytorch.vqgan_vae 模块中导入 VQGanVAE 类
from nuwa_pytorch.nuwa_pytorch import NUWA, NUWASketch, NUWAVideoAudio, Sparse3DNA, CrossModalityCrossAttention
from nuwa_pytorch.vqgan_vae import VQGanVAE

# 从 nuwa_pytorch.train_vqgan_vae 模块中导入 VQGanVAETrainer 类
# 以及从 nuwa_pytorch.train_nuwa 模块中导入 NUWATrainer 类
from nuwa_pytorch.train_vqgan_vae import VQGanVAETrainer
from nuwa_pytorch.train_nuwa import NUWATrainer

NÜWA - Pytorch

Join us on Discord

Implementation of NÜWA, state of the art attention network for text to video synthesis, in Pytorch. It also contains an extension into video and audio generation, using a dual decoder approach.

Yannic Kilcher

DeepReader

Status

  • March 2022 - seeing signs of life with a difficult version of moving mnist

  • April 2022 - It seems as though a diffusion based method has taken the new throne for SOTA. However, I will continue on with NUWA, extending it to use multi-headed codes + hierarchical causal transformer. I think that direction is untapped for improving on this line of work.

Install

$ pip install nuwa-pytorch

Usage

First train the VAE

import torch
from nuwa_pytorch import VQGanVAE

vae = VQGanVAE(
    dim = 512,
    channels = 3,               # default is 3, but can be changed to any value for the training of the segmentation masks (sketches)
    image_size = 256,           # image size
    num_layers = 4,             # number of downsampling layers
    num_resnet_blocks = 2,      # number of resnet blocks
    vq_codebook_size = 8192,    # codebook size
    vq_decay = 0.8              # codebook exponential decay
)

imgs = torch.randn(10, 3, 256, 256)

# alternate learning for autoencoder ...

loss = vae(imgs, return_loss = True)
loss.backward()

# and the discriminator ...

discr_loss = vae(imgs, return_discr_loss = True)
discr_loss.backward()

# do above for many steps

# return reconstructed images and make sure they look ok

recon_imgs = vae(imgs)

Then, with your learned VAE

import torch
from nuwa_pytorch import NUWA, VQGanVAE

# autoencoder

vae = VQGanVAE(
    dim = 64,
    num_layers = 4,
    image_size = 256,
    num_conv_blocks = 2,
    vq_codebook_size = 8192
)

# NUWA transformer

nuwa = NUWA(
    vae = vae,
    dim = 512,
    text_num_tokens = 20000,                # number of text tokens
    text_enc_depth = 12,                    # text encoder depth
    text_enc_heads = 8,                     # number of attention heads for encoder
    text_max_seq_len = 256,                 # max sequence length of text conditioning tokens (keep at 256 as in paper, or shorter, if your text is not that long)
    max_video_frames = 10,                  # number of video frames
    image_size = 256,                       # size of each frame of video
    dec_depth = 64,                         # video decoder depth
    dec_heads = 8,                          # number of attention heads in decoder
    dec_reversible = True,                  # reversible networks - from reformer, decoupling memory usage from depth
    enc_reversible = True,                  # reversible encoders, if you need it
    attn_dropout = 0.05,                    # dropout for attention
    ff_dropout = 0.05,                      # dropout for feedforward
    sparse_3dna_kernel_size = (5, 3, 3),    # kernel size of the sparse 3dna attention. can be a single value for frame, height, width, or different values (to simulate axial attention, etc)
    sparse_3dna_dilation = (1, 2, 4),       # cycle dilation of 3d conv attention in decoder, for more range
    shift_video_tokens = True               # cheap relative positions for sparse 3dna transformer, by shifting along spatial dimensions by one
).cuda()

# data

text = torch.randint(0, 20000, (1, 256)).cuda()
video = torch.randn(1, 10, 3, 256, 256).cuda() # (batch, frames, channels, height, width)

loss = nuwa(
    text = text,
    video = video,
    return_loss = True  # set this to True, only for training, to return cross entropy loss
)

loss.backward()

# do above with as much data as possible

# then you can generate a video from text

video = nuwa.generate(text = text, num_frames = 5) # (1, 5, 3, 256, 256)

Conditioning on Sketches

In the paper, they also present a way to condition the video generation based on segmentation mask(s). You can easily do this as well, given you train a VQGanVAE on the sketches before hand.

Then, you will use NUWASketch instead of NUWA, which can accept the sketch VAE as a reference

ex.

import torch
from nuwa_pytorch import NUWASketch, VQGanVAE

# autoencoder, one for main video, the other for the sketch

vae = VQGanVAE(
    dim = 64,
    num_layers = 4,
    image_size = 256,
    num_conv_blocks = 2,
    vq_codebook_size = 8192
)

sketch_vae = VQGanVAE(
    dim = 512,
    channels = 5,                # say the sketch has 5 classes
    num_layers = 4,
    image_size = 256,
    num_conv_blocks = 2,
    vq_codebook_size = 8192
)

# NUWA transformer for conditioning with sketches

nuwa = NUWASketch(
    vae = vae,
    sketch_vae = sketch_vae,
    dim = 512,                              # model dimensions
    sketch_enc_depth = 12,                  # sketch encoder depth
    sketch_enc_heads = 8,                   # number of attention heads for sketch encoder
    sketch_max_video_frames = 3,            # max number of frames for sketches
    sketch_enc_use_sparse_3dna = True,      # whether to use 3d-nearby attention (of full attention if False) for sketch encoding transformer
    max_video_frames = 10,                  # number of video frames
    image_size = 256,                       # size of each frame of video
    dec_depth = 64,                         # video decoder depth
    dec_heads = 8,                          # number of attention heads in decoder
    dec_reversible = True,                  # reversible networks - from reformer, decoupling memory usage from depth
    enc_reversible = True,                  # reversible encoders, if you need it
    attn_dropout = 0.05,                    # dropout for attention
    ff_dropout = 0.05,                      # dropout for feedforward
    sparse_3dna_kernel_size = (5, 3, 3),    # kernel size of the sparse 3dna attention. can be a single value for frame, height, width, or different values (to simulate axial attention, etc)
    sparse_3dna_dilation = (1, 2, 4),       # cycle dilation of 3d conv attention in decoder, for more range
    cross_2dna_kernel_size = 5,             # 2d kernel size of spatial grouping of attention from video frames to sketches
    cross_2dna_dilation = 1,                # 2d dilation of spatial attention from video frames to sketches
    shift_video_tokens = True               # cheap relative positions for sparse 3dna transformer, by shifting along spatial dimensions by one
).cuda()

# data

sketch = torch.randn(2, 2, 5, 256, 256).cuda() # (batch, frames, segmentation classes, height, width)
sketch_mask = torch.ones(2, 2).bool().cuda()   # (batch, frames) [Optional]
video = torch.randn(2, 10, 3, 256, 256).cuda() # (batch, frames, channels, height, width)

loss = nuwa(
    sketch = sketch,
    sketch_mask = sketch_mask,
    video = video,
    return_loss = True  # set this to True, only for training, to return cross entropy loss
)

loss.backward()

# do above with as much data as possible

# then you can generate a video from sketch(es)

video = nuwa.generate(sketch = sketch, num_frames = 5) # (1, 5, 3, 256, 256)

Text to Video and Audio

This repository will also offer a variant of NUWA that can produce both video and audio. For now, the audio will need to be encoded manually.

import torch
from nuwa_pytorch import NUWAVideoAudio, VQGanVAE

# autoencoder

vae = VQGanVAE(
    dim = 64,
    num_layers = 4,
    image_size = 256,
    num_conv_blocks = 2,
    vq_codebook_size = 100
)

# NUWA transformer

nuwa = NUWAVideoAudio(
    vae = vae,
    dim = 512,
    num_audio_tokens = 2048,                # codebook size for audio tokens
    num_audio_tokens_per_video_frame = 32,  # number of audio tokens per video frame
    cross_modality_attn_every = 3,          # cross modality attention every N layers
    text_num_tokens = 20000,                # number of text tokens
    text_enc_depth = 1,                     # text encoder depth
    text_enc_heads = 8,                     # number of attention heads for encoder
    text_max_seq_len = 256,                 # max sequence length of text conditioning tokens (keep at 256 as in paper, or shorter, if your text is not that long)
    max_video_frames = 10,                  # number of video frames
    image_size = 256,                       # size of each frame of video
    dec_depth = 4,                          # video decoder depth
    dec_heads = 8,                          # number of attention heads in decoder
    enc_reversible = True,                  # reversible encoders, if you need it
    dec_reversible = True,                  # quad-branched reversible network, for making depth of twin video / audio decoder independent of network depth. recommended to be turned on unless you have a ton of memory at your disposal
    attn_dropout = 0.05,                    # dropout for attention
    ff_dropout = 0.05,                      # dropout for feedforward
    sparse_3dna_kernel_size = (5, 3, 3),    # kernel size of the sparse 3dna attention. can be a single value for frame, height, width, or different values (to simulate axial attention, etc)
    sparse_3dna_dilation = (1, 2, 4),       # cycle dilation of 3d conv attention in decoder, for more range
    shift_video_tokens = True               # cheap relative positions for sparse 3dna transformer, by shifting along spatial dimensions by one
).cuda()

# data

text = torch.randint(0, 20000, (1, 256)).cuda()
audio = torch.randint(0, 2048, (1, 32 * 10)).cuda() # (batch, audio tokens per frame * max video frames)
video = torch.randn(1, 10, 3, 256, 256).cuda() # (batch, frames, channels, height, width)

loss = nuwa(
    text = text,
    video = video,
    audio = audio,
    return_loss = True  # set this to True, only for training, to return cross entropy loss
)

loss.backward()

# do above with as much data as possible

# then you can generate a video from text

video, audio = nuwa.generate(text = text, num_frames = 5) # (1, 5, 3, 256, 256), (1, 32 * 5 == 160)

Trainers

This library will offer some utilities to make training easier. For starters, you can use the VQGanVAETrainer class to take care of training the VQGanVAE. Simply wrap the model and also pass in the image folder path as well as the various training hyperparameters.

import torch
from nuwa_pytorch import VQGanVAE, VQGanVAETrainer

vae = VQGanVAE(
    dim = 64,
    image_size = 256,
    num_layers = 5,
    vq_codebook_size = 1024,
    vq_use_cosine_sim = True,
    vq_codebook_dim = 32,
    vq_orthogonal_reg_weight = 10,
    vq_orthogonal_reg_max_codes = 128,
).cuda()

trainer = VQGanVAETrainer(
    vae,                           # VAE defined above
    folder ='/path/to/images',     # path to images
    lr = 3e-4,                     # learning rate
    num_train_steps = 100000,      # number of training steps
    batch_size = 8,                # batch size
    grad_accum_every = 4           # gradient accumulation (effective batch size is (batch_size x grad_accum_every))
)

trainer.train()

# results and model checkpoints will be saved periodically to ./results

To train NUWA, first you need to organize a folder of .gif files with corresponding .txt files containing its caption. It should be organized as such.

ex.

📂video-and-text-data
 ┣ 📜cat.gif
 ┣ 📜cat.txt
 ┣ 📜dog.gif
 ┣ 📜dog.txt
 ┣ 📜turtle.gif
 ┗ 📜turtle.txt
```

Then you will load your previously trained VQGan-VAE and train NUWA with the `GifVideoDataset` and `NUWATrainer` classes.

```py
import torch
from nuwa_pytorch import NUWA, VQGanVAE
from nuwa_pytorch.train_nuwa import GifVideoDataset, NUWATrainer

# dataset

ds = GifVideoDataset(
    folder = './path/to/videos/',
    channels = 1
)

# autoencoder

vae = VQGanVAE(
    dim = 64,
    image_size = 256,
    num_layers = 5,
    num_resnet_blocks = 2,
    vq_codebook_size = 512,
    attn_dropout = 0.1
)

vae.load_state_dict(torch.load('./path/to/trained/vae.pt'))

# NUWA transformer

nuwa = NUWA(
    vae = vae,
    dim = 512,
    text_enc_depth = 6,
    text_max_seq_len = 256,
    max_video_frames = 10,
    dec_depth = 12,
    dec_reversible = True,
    enc_reversible = True,
    attn_dropout = 0.05,
    ff_dropout = 0.05,
    sparse_3dna_kernel_size = (5, 3, 3),
    sparse_3dna_dilation = (1, 2, 4),
    shift_video_tokens = True
).cuda()

# data

trainer = NUWATrainer(
    nuwa = nuwa,                 # NUWA transformer
    dataset = ds,                # video dataset class (GifVideoDataset defined above)
    num_train_steps = 1000000,   # number of training steps
    lr = 3e-4,                   # learning rate
    wd = 0.01,                   # weight decay
    batch_size = 8,              # batch size
    grad_accum_every = 4,        # gradient accumulation
    max_grad_norm = 0.5,         # gradient clipping
    num_sampled_frames = 10,     # number of frames to sample
    results_folder = './results' # folder to store checkpoints and samples
)

trainer.train()
```

## VQ improvements

This library depends on this <a href="https://github.com/lucidrains/vector-quantize-pytorch">vector quantization</a> library, which comes with a number of improvements (improved vqgan, orthogonal codebook regularization, etc). To use any of these improvements, you can configure the vector quantizer keyword params by prepending `vq_` on `VQGanVAE` initialization.

ex. cosine sim proposed in <a href="https://arxiv.org/abs/2110.04627">improved vqgan</a>

```py
from nuwa_pytorch import VQGanVAE

vae = VQGanVAE(
    dim = 64,
    image_size = 256,
    num_layers = 4,
    vq_use_cosine_sim = True
    # VectorQuantize will be initialized with use_cosine_sim = True
    # https://github.com/lucidrains/vector-quantize-pytorch#cosine-similarity
).cuda()
```py

## Todo

- [x] complete 3dna causal attention in decoder
- [x] write up easy generation functions
- [x] make sure GAN portion of VQGan is correct, reread paper
- [x] make sure adaptive weight in vqgan is correctly built
- [x] offer new vqvae improvements (orthogonal reg and smaller codebook dimensions)
- [x] batch video tokens -> vae during video generation, to prevent oom
- [x] query chunking in 3dna attention, to put a cap on peak memory
- [x] flesh out VAE resnet blocks, offer some choices
- [x] add all stability tricks from cogview paper by default
- [x] make VQGan able to accept custom VGG for LPAPs loss (audio)
- [x] add feedforward chunking
- [x] add shift token in decoder for cheap powerful RPE
- [x] add reversible networks, to save on memory on depth
- [x] support kernel sizes different along each dimension for sparse 3dna
- [x] add some autotrainer that takes care of the alternating updates of discriminator and VQVAE generator
- [x] segmentation mask encoder, make sure embeddings can undergo 3dna attention with decoder during cross attention
- [x] finish 2d-nearby cross attention for sketches
- [x] able to add convnext blocks to other layers in vqgan vae
- [x] offer vqvae training script
- [x] handle variable lengthed sketches, accept a mask on the sketch frames dimension
- [x] take care of audio transformer and cross modality attention
- [x] add audio transformer, and build audio / video nearby cross attention
- [x] make dual decoder reversible
- [x] rotary embeddings for encoder
- [x] add cycle dilation to audio
- [x] omit vgg from VAE state dict
- [x] add cosine sim attention from swinv2 as an option
- [x] add axial positional embedding to audio
- [ ] Triton kernel for 3dna attention
- [ ] offer a colab with moving mnist example, conditioned on present digits
- [ ] build NUWA controller class that can accept text or sketch
- [ ] key masking for 3dna attention - for variable sketch length masking
- [ ] figure out spec vqgan and fit it into the framework, take care of audio encoding / decoding automatically
- [ ] turn into CLI tool, like stylegan2-pytorch
- [ ] look into integrating https://github.com/lucidrains/RQ-Transformer for both video and audio
- [ ] inference caching

## Citations

```py
@misc{wu2021nuwa,
    title   = {N\"UWA: Visual Synthesis Pre-training for Neural visUal World creAtion}, 
    author  = {Chenfei Wu and Jian Liang and Lei Ji and Fan Yang and Yuejian Fang and Daxin Jiang and Nan Duan},
    year    = {2021},
    eprint  = {2111.12417},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```py

```py
@misc{esser2021taming,
    title   = {Taming Transformers for High-Resolution Image Synthesis},
    author  = {Patrick Esser and Robin Rombach and Björn Ommer},
    year    = {2021},
    eprint  = {2012.09841},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```py

```py
@misc{iashin2021taming,
    title   = {Taming Visually Guided Sound Generation},
    author  = {Vladimir Iashin and Esa Rahtu},
    year    = {2021},
    eprint  = {2110.08791},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```py

```py
@misc{ding2021cogview,
    title   = {CogView: Mastering Text-to-Image Generation via Transformers},
    author  = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
    year    = {2021},
    eprint  = {2105.13290},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```py

```py
@misc{kitaev2020reformer,
    title   = {Reformer: The Efficient Transformer},
    author  = {Nikita Kitaev and Łukasz Kaiser and Anselm Levskaya},
    year    = {2020},
    eprint  = {2001.04451},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
```py

```py
@misc{shazeer2020talkingheads,
    title   = {Talking-Heads Attention}, 
    author  = {Noam Shazeer and Zhenzhong Lan and Youlong Cheng and Nan Ding and Le Hou},
    year    = {2020},
    eprint  = {2003.02436},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
```py

```py
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}    
}
```py

```py
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
```py

```py
@inproceedings{ho2021classifierfree,
    title   = {Classifier-Free Diffusion Guidance},
    author  = {Jonathan Ho and Tim Salimans},
    booktitle = {NeurIPS 2021 Workshop on Deep Generative Models and Downstream Applications},
    year    = {2021},
    url     = {https://openreview.net/forum?id=qw8AKxfYbI}
}
```py

```py
@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```py

```py
@misc{crowson2022,
    author  = {Katherine Crowson},
    url     = {https://twitter.com/RiversHaveWings/status/1478093658716966912}
}

Attention is the rarest and purest form of generosity. - Simone Weil

.\lucidrains\nuwa-pytorch\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# package metadata for the nuwa-pytorch distribution
setup(
  # distribution name on PyPI
  name = 'nuwa-pytorch',
  # include every package found, excluding none
  packages = find_packages(exclude=[]),
  # ship non-python package data files
  include_package_data = True,
  # release version
  version = '0.7.8',
  # license type
  license='MIT',
  # short description
  description = 'NÜWA - Pytorch',
  # long description is rendered as markdown
  long_description_content_type = 'text/markdown',
  # author name
  author = 'Phil Wang',
  # author contact email
  author_email = 'lucidrains@gmail.com',
  # project homepage
  url = 'https://github.com/lucidrains/nuwa-pytorch',
  # search keywords
  keywords = [
    'artificial intelligence',
    'attention mechanism',
    'transformers'
  ],
  # runtime dependencies
  install_requires=[
    'einops>=0.4.1',
    'ftfy',
    'pillow',
    'regex',
    'torch>=1.6',
    'torchvision',
    'tqdm',
    'unfoldNd',
    'vector-quantize-pytorch>=0.4.10'
  ],
  # trove classifiers describing the package
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\NWT-pytorch\nwt_pytorch\nwt_pytorch.py

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat
from einops.layers.torch import EinMix as Mix

# Memcodes: multi-head hard attention to learned codebooks (discrete latent from the NWT paper)
class Memcodes(nn.Module):
    """Discrete latent layer where each head hard-attends to one code in its own codebook.

    The input is split into `heads` sub-vectors; each acts as a query over a
    per-head codebook of `num_codes` learned codes (of dimension `dim // heads`).
    During training a straight-through Gumbel softmax makes the hard code
    selection differentiable; at eval time the argmax code is taken directly.
    """

    def __init__(
        self,
        *,
        dim,  # feature dimension of incoming vectors
        num_codes,  # number of codes per codebook
        heads = 8,  # number of attention heads == number of codebooks
        temperature = 1.,  # gumbel softmax temperature
    ):
        super().__init__()
        assert (dim % heads) == 0, 'dimension must be divisible by number of heads'
        self.heads = heads
        self.dim = dim
        self.scale = (dim // heads) ** -0.5  # query scaling factor (1 / sqrt(head dim))
        self.temperature = temperature
        self.num_codes = num_codes

        num_codebooks = heads
        codebook_dim = dim // heads

        # learned codes: one codebook of shape (num_codes, codebook_dim) per head
        self.codes = nn.Parameter(torch.randn(num_codebooks, num_codes, codebook_dim))
        # per-head linear projection of codes to attention keys
        self.to_k = Mix('h n d -> h n c', weight_shape = 'h d c', h = heads, d = codebook_dim, c = codebook_dim)
        # per-head linear projection of codes to attention values
        self.to_v = Mix('h n d -> h n c', weight_shape = 'h d c', h = heads, d = codebook_dim, c = codebook_dim)

    # reconstruct quantized outputs from previously selected codebook indices
    def get_codes_from_indices(self, codebook_indices, *, merge_output_heads = True):
        """Gather value-projected codes at ``codebook_indices`` (shape (batch, heads, seq))."""
        batch = codebook_indices.shape[0]

        values = self.to_v(self.codes)
        values = repeat(values, 'h n d -> b h n d', b = batch)

        # expand indices along the feature dimension so gather picks whole code vectors
        codebook_indices = repeat(codebook_indices, '... -> ... d', d = values.shape[-1])
        out = values.gather(2, codebook_indices)

        if not merge_output_heads:
            return out

        return rearrange(out, 'b h n d -> b n (h d)')

    def forward(self, x, *, merge_output_heads = True):
        """Quantize ``x`` of shape (batch, seq, dim); returns (quantized, codebook_indices)."""
        assert x.shape[-1] == self.dim

        # split the input into per-head query sub-vectors
        q = rearrange(x, 'b n (h d) -> b h n d', h = self.heads)

        q = q * self.scale

        # project codes to attention keys and values
        k, v = self.to_k(self.codes), self.to_v(self.codes)

        # straight-through gumbel softmax over codes
        logits = einsum('b h i d, h j d -> b h i j', q, k)

        if self.training:
            attn = F.gumbel_softmax(logits, tau = self.temperature, dim = -1, hard = True)
            codebook_indices = attn.argmax(dim = -1)
        else:
            # deterministic selection at eval time: one-hot at the argmax code
            codebook_indices = logits.argmax(dim = -1)
            attn = F.one_hot(codebook_indices, num_classes = self.num_codes).float()

        out = einsum('b h i j, h j d -> b h i d', attn, v)

        if not merge_output_heads:
            return out, codebook_indices

        # merge heads back into the feature dimension
        out = rearrange(out, 'b h n d -> b n (h d)')
        return out, codebook_indices

.\lucidrains\NWT-pytorch\nwt_pytorch\__init__.py

# 从 nwt_pytorch.nwt_pytorch 模块中导入 Memcodes 类
from nwt_pytorch.nwt_pytorch import Memcodes

NWT - Pytorch (wip)

Implementation of NWT, audio-to-video generation, in Pytorch.

Generated samples

Install

$ pip install nwt-pytorch

Usage

The paper proposes a new discrete latent representation named Memcodes, which can be succinctly described as a type of multi-head hard-attention to learned memory (codebook) key / values. They claim the need for less codes and smaller codebook dimension in order to achieve better reconstructions.

import torch
from nwt_pytorch import Memcodes

codebook = Memcodes(
    dim = 512,            # dimension of incoming features (codebook dimension will be dim / heads)
    heads = 8,            # head dimension, which is equivalent to the number of codebooks
    num_codes = 1024,     # number of codes per codebook
    temperature = 1.      # gumbel softmax temperature
)

x = torch.randn(1, 1024, 512)
out, codebook_indices = codebook(x) # (1, 1024, 512), (1, 1024, 8)
# (batch, seq, dimension), (batch, seq, heads)

# reconstruct output from codebook indices (codebook indices are autoregressed out from an attention net in paper)

assert torch.allclose(codebook.get_codes_from_indices(codebook_indices), out)

Citations

@misc{mama2021nwt,
    title   = {NWT: Towards natural audio-to-video generation with representation learning}, 
    author  = {Rayhane Mama and Marc S. Tyndel and Hashiam Kadhim and Cole Clifford and Ragavan Thurairatnam},
    year    = {2021},
    eprint  = {2106.04283},
    archivePrefix = {arXiv},
    primaryClass = {cs.SD}
}

.\lucidrains\NWT-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# package metadata for the nwt-pytorch distribution
setup(
  # distribution name on PyPI
  name = 'nwt-pytorch',
  # include every package found
  packages = find_packages(),
  # release version
  version = '0.0.4',
  # license type
  license='MIT',
  # short description
  description = 'NWT - Pytorch',
  # author name
  author = 'Phil Wang',
  # author contact email
  author_email = 'lucidrains@gmail.com',
  # project homepage
  url = 'https://github.com/lucidrains/NWT-pytorch',
  # search keywords
  keywords = [
    'artificial intelligence',
    'deep learning',
    'pytorch',
    'audio to video synthesis'
  ],
  # runtime dependencies
  install_requires=[
    'einops>=0.4',
    'torch'
  ],
  # trove classifiers describing the package
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\nystrom-attention\nystrom_attention\nystrom_attention.py

# 从 math 模块中导入 ceil 函数
from math import ceil
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 torch.nn.functional 中导入 F 模块

import torch.nn.functional as F

# 从 einops 库中导入 rearrange 和 reduce 函数

from einops import rearrange, reduce

# helper predicate: is this value set?
def exists(val):
    """Return True when ``val`` is anything other than None (falsy values count as present)."""
    return not (val is None)

# iterative Moore-Penrose pseudoinverse (third-order Newton-Schulz iteration)
def moore_penrose_iter_pinv(x, iters = 6):
    """Approximate the Moore-Penrose pseudoinverse of a batch of square matrices.

    Bug fix vs. original: the update line was missing a closing parenthesis,
    making the module a syntax error.

    Args:
        x: tensor of shape (..., n, n).
        iters: number of fixed-point iterations (6 is the paper's recommendation).

    Returns:
        Tensor of shape (..., n, n) approximating pinv(x).
    """
    device = x.device

    # initialize z = x^T / (max row abs-sum * max column abs-sum), a scaling
    # that guarantees the iteration converges
    abs_x = torch.abs(x)
    col = abs_x.sum(dim = -1)
    row = abs_x.sum(dim = -2)
    z = x.transpose(-1, -2) / (torch.max(col) * torch.max(row))

    # identity with a leading broadcast dimension
    I = torch.eye(x.shape[-1], device = device)
    I = I.unsqueeze(0)

    # z <- 0.25 * z * (13 I - xz (15 I - xz (7 I - xz)))
    for _ in range(iters):
        xz = x @ z
        z = 0.25 * z @ (13 * I - (xz @ (15 * I - (xz @ (7 * I - xz)))))

    return z

# main attention class: Nystrom approximation of softmax self-attention
class NystromAttention(nn.Module):
    """Nystrom-approximated self-attention (Nystromformer).

    Full attention is approximated as attn1 @ pinv(attn2) @ attn3 using
    `num_landmarks` landmark queries/keys, avoiding the full n x n
    attention matrix. Optionally adds a depthwise-conv residual of the values.
    """

    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        num_landmarks = 256,
        pinv_iterations = 6,
        residual = True,
        residual_conv_kernel = 33,
        eps = 1e-8,
        dropout = 0.
    ):
        super().__init__()
        self.eps = eps
        inner_dim = heads * dim_head

        self.num_landmarks = num_landmarks
        self.pinv_iterations = pinv_iterations

        self.heads = heads
        self.scale = dim_head ** -0.5
        # single projection producing queries, keys and values (3 * inner_dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        # output projection followed by dropout
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

        self.residual = residual
        # optional depthwise conv residual over the values (one group per head)
        if residual:
            kernel_size = residual_conv_kernel
            padding = residual_conv_kernel // 2
            self.res_conv = nn.Conv2d(heads, heads, (kernel_size, 1), padding = (padding, 0), groups = heads, bias = False)
    # forward pass; x is (batch, seq, dim), mask is (batch, seq) booleans
    def forward(self, x, mask = None, return_attn = False):
        # b: batch, n: seq length, h: heads, m: landmarks,
        # iters: pseudoinverse iterations, eps: numerical epsilon
        b, n, _, h, m, iters, eps = *x.shape, self.heads, self.num_landmarks, self.pinv_iterations, self.eps

        # left-pad the sequence so its length divides evenly into m landmark groups
        remainder = n % m
        if remainder > 0:
            padding = m - (n % m)
            x = F.pad(x, (0, 0, padding, 0), value = 0)

            if exists(mask):
                mask = F.pad(mask, (padding, 0), value = False)

        # derive queries, keys, values
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # zero out masked positions in queries, keys and values
        if exists(mask):
            mask = rearrange(mask, 'b n -> b () n')
            q, k, v = map(lambda t: t * mask[..., None], (q, k, v))

        q = q * self.scale

        # generate landmarks by sum reduction, then divide for a (masked) mean
        l = ceil(n / m)
        landmark_einops_eq = '... (n l) d -> ... n d'
        q_landmarks = reduce(q, landmark_einops_eq, 'sum', l = l)
        k_landmarks = reduce(k, landmark_einops_eq, 'sum', l = l)

        # landmark mask, and count of unmasked elements per group for the masked mean
        divisor = l
        if exists(mask):
            mask_landmarks_sum = reduce(mask, '... (n l) -> ... n', 'sum', l = l)
            divisor = mask_landmarks_sum[..., None] + eps
            mask_landmarks = mask_landmarks_sum > 0

        # turn sums into (masked) means
        q_landmarks = q_landmarks / divisor
        k_landmarks = k_landmarks / divisor

        # the three similarity matrices of the Nystrom factorization
        einops_eq = '... i d, ... j d -> ... i j'
        sim1 = einsum(einops_eq, q, k_landmarks)
        sim2 = einsum(einops_eq, q_landmarks, k_landmarks)
        sim3 = einsum(einops_eq, q_landmarks, k)

        # mask out invalid query/key positions before the softmax
        if exists(mask):
            mask_value = -torch.finfo(q.dtype).max
            sim1.masked_fill_(~(mask[..., None] * mask_landmarks[..., None, :]), mask_value)
            sim2.masked_fill_(~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value)
            sim3.masked_fill_(~(mask_landmarks[..., None] * mask[..., None, :]), mask_value)

        # eq (15) of the Nystromformer paper, then aggregate values
        attn1, attn2, attn3 = map(lambda t: t.softmax(dim = -1), (sim1, sim2, sim3))
        attn2_inv = moore_penrose_iter_pinv(attn2, iters)

        out = (attn1 @ attn2_inv) @ (attn3 @ v)

        # add depthwise conv residual of the values
        if self.residual:
            out = out + self.res_conv(v)

        # merge heads, project out, and drop the left padding added earlier
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        out = self.to_out(out)
        out = out[:, -n:]

        # optionally also return the (approximate) attention matrix
        if return_attn:
            attn = attn1 @ attn2_inv @ attn3
            return out, attn

        return out
# transformer

# pre-normalization wrapper: LayerNorm the input, then apply the wrapped module
class PreNorm(nn.Module):
    """Applies LayerNorm to the input before delegating to ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        # normalize first; any extra keyword arguments pass straight through
        return self.fn(self.norm(x), **kwargs)

# position-wise feedforward block: expand by `mult`, GELU, dropout, project back
class FeedForward(nn.Module):
    """Two-layer MLP with GELU activation and dropout, applied position-wise."""

    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        inner_dim = dim * mult
        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim)
        )

    def forward(self, x):
        # (..., dim) -> (..., dim)
        return self.net(x)

# full transformer: a stack of pre-normed Nystrom attention + feedforward layers
class Nystromformer(nn.Module):
    """Stack of `depth` (NystromAttention, FeedForward) layers with residual connections."""

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_landmarks = 256,
        pinv_iterations = 6,
        attn_values_residual = True,
        attn_values_residual_conv_kernel = 33,
        attn_dropout = 0.,
        ff_dropout = 0.
    ):
        super().__init__()

        def build_layer():
            # one layer: pre-normed attention followed by pre-normed feedforward
            attn = NystromAttention(
                dim = dim,
                dim_head = dim_head,
                heads = heads,
                num_landmarks = num_landmarks,
                pinv_iterations = pinv_iterations,
                residual = attn_values_residual,
                residual_conv_kernel = attn_values_residual_conv_kernel,
                dropout = attn_dropout
            )
            ff = FeedForward(dim = dim, dropout = ff_dropout)
            return nn.ModuleList([PreNorm(dim, attn), PreNorm(dim, ff)])

        self.layers = nn.ModuleList([build_layer() for _ in range(depth)])

    def forward(self, x, mask = None):
        # residual connection around each sublayer
        for attn, ff in self.layers:
            x = x + attn(x, mask = mask)
            x = x + ff(x)
        return x

.\lucidrains\nystrom-attention\nystrom_attention\__init__.py

# re-export the public classes from the implementation module
from nystrom_attention.nystrom_attention import NystromAttention, Nystromformer
# alternate spelling alias: expose Nystromformer as `Nystromer` as well
Nystromer = Nystromformer

Nyström Attention

Implementation of Nyström Self-attention, from the paper Nyströmformer.

Yannic Kilcher video

Install

$ pip install nystrom-attention

Usage

import torch
from nystrom_attention import NystromAttention

attn = NystromAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    num_landmarks = 256,    # number of landmarks
    pinv_iterations = 6,    # number of moore-penrose iterations for approximating pinverse. 6 was recommended by the paper
    residual = True         # whether to do an extra residual with the value or not. supposedly faster convergence if turned on
)

x = torch.randn(1, 16384, 512)
mask = torch.ones(1, 16384).bool()

attn(x, mask = mask) # (1, 16384, 512)

Nyströmformer, layers of Nyström attention

import torch
from nystrom_attention import Nystromformer

model = Nystromformer(
    dim = 512,
    dim_head = 64,
    heads = 8,
    depth = 6,
    num_landmarks = 256,
    pinv_iterations = 6
)

x = torch.randn(1, 16384, 512)
mask = torch.ones(1, 16384).bool()

model(x, mask = mask) # (1, 16384, 512)

You can also import it as Nyströmer if you wish

from nystrom_attention import Nystromer

Citations

@misc{xiong2021nystromformer,
    title   = {Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention},
    author  = {Yunyang Xiong and Zhanpeng Zeng and Rudrasis Chakraborty and Mingxing Tan and Glenn Fung and Yin Li and Vikas Singh},
    year    = {2021},
    eprint  = {2102.03902},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

.\lucidrains\nystrom-attention\setup.py

# import the setup and package-discovery helpers from setuptools
from setuptools import setup, find_packages

# package metadata for the nystrom-attention distribution
setup(
  # distribution name on PyPI
  name = 'nystrom-attention',
  # include every package found in the source tree
  packages = find_packages(),
  # release version
  version = '0.0.12',
  # license identifier
  license='MIT',
  # short description
  description = 'Nystrom Attention - Pytorch',
  # content type of the long description
  long_description_content_type = 'text/markdown',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # project homepage
  url = 'https://github.com/lucidrains/nystrom-attention',
  # search keywords
  keywords = [
    'artificial intelligence',
    'attention mechanism'
  ],
  # runtime dependencies
  install_requires=[
    'einops>=0.7.0',
    'torch>=2.0'
  ],
  # trove classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\omninet-pytorch\omninet_pytorch\omninet_pytorch.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 torch 库中导入 nn.functional 模块,并重命名为 F
import torch.nn.functional as F

# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 从 einops.layers.torch 库中导入 Rearrange 类
from einops.layers.torch import Rearrange

# 使用 PerformerAttention 作为自注意力机制,因为它有最好的报告数字
from performer_pytorch import SelfAttention as PerformerAttention

# 辅助函数

# truthiness-safe presence check: anything but None counts as present
def exists(val):
    return not (val is None)

# device of the module's first parameter (assumes all parameters share a device)
def get_module_device(module):
    first_param = next(iter(module.parameters()))
    return first_param.device

# collect every (sub)module of `nn_module` that is an instance of `type`
def find_modules(nn_module, type):
    found = []
    for module in nn_module.modules():
        if isinstance(module, type):
            found.append(module)
    return found

# 类定义

# pre-norm wrapper: LayerNorm the input before calling the wrapped module
class PreNorm(nn.Module):
    """Apply LayerNorm to the input, then delegate to `fn`."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)

# position-wise feedforward network
class FeedForward(nn.Module):
    """Two-layer position-wise MLP (dim -> dim * mult -> dim) with GELU and dropout."""

    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden = dim * mult
        stages = [
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

# standard multi-head softmax self-attention
class Attention(nn.Module):
    """Multi-head attention with optional boolean padding mask and causal masking.

    forward(x, mask = None): `x` is (batch, seq, dim); `mask` is a (batch, seq)
    boolean tensor where True marks valid tokens.
    """

    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        causal = False
    ):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.causal = causal

        # fused projection producing queries, keys and values in one matmul
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        # project concatenated heads back to the model dimension
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None):
        b, n, _ = x.shape
        h = self.heads
        device = x.device

        # project, split into q/k/v, and fold the head axis into the batch axis
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = (t.reshape(b, n, h, -1).permute(0, 2, 1, 3).reshape(b * h, n, -1) for t in qkv)

        # scaled dot-product similarities
        sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

        neg_inf = -torch.finfo(sim.dtype).max

        # padding mask: position pair (i, j) is kept only when both tokens are valid
        if mask is not None:
            pair_mask = mask[:, :, None] * mask[:, None, :]
            sim.masked_fill_(~pair_mask, neg_inf)

        # causal mask: forbid attending to future positions
        if self.causal:
            i, j = sim.shape[-2:]
            future = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
            sim.masked_fill_(future.unsqueeze(0), neg_inf)

        attn = sim.softmax(dim = -1)

        # aggregate values, then unfold heads back into the feature dimension
        out = torch.einsum('b i j, b j d -> b i d', attn, v)
        out = out.reshape(b, h, n, -1).permute(0, 2, 1, 3).reshape(b, n, -1)
        return self.to_out(out)

# main class

class Omninet(nn.Module):
    """OmniNet: a transformer where, on a periodic schedule, an efficient
    (Performer) attention attends over the tokens of all recent layers and the
    max-pooled result is added back into the residual stream.

    forward(x, mask = None): `x` is (batch, seq, dim); `mask` is a boolean
    (batch, seq) padding mask (True = keep).
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        pool_layer_tokens_every = 2,
        attn_dropout = 0.,
        ff_dropout = 0.,
        feature_redraw_interval = 1000
    ):
        super().__init__()

        layers = nn.ModuleList([])
        for ind in range(depth):
            num_layers = ind + 1
            # NOTE(review): a truthy remainder means cross-layer pooling runs on
            # every layer *except* each `pool_layer_tokens_every`-th one; the
            # README wording ("every N layers, omni attend") suggests the
            # opposite schedule — confirm intent before changing.
            should_pool = num_layers % pool_layer_tokens_every

            layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)),
                PreNorm(dim, FeedForward(dim = dim, dropout = ff_dropout)),
                PerformerAttention(dim = dim, heads = heads, dim_head = dim_head) if should_pool else None
            ]))

        self.layers = layers
        self.pool_num_layers = pool_layer_tokens_every

        # schedule for periodically redrawing the Performer random projections
        self.feature_redraw_interval = feature_redraw_interval
        self.register_buffer('calls_since_last_redraw', torch.tensor(0))

    def fix_projection_matrices_(self):
        """Disable periodic redrawing, freezing the Performer projections."""
        self.feature_redraw_interval = None

    def check_redraw_projections(self):
        """Redraw Performer projection matrices every `feature_redraw_interval` training calls."""
        # only redraw while training
        if not self.training:
            return

        if exists(self.feature_redraw_interval) and self.calls_since_last_redraw >= self.feature_redraw_interval:
            device = get_module_device(self)

            # BUGFIX: the original called `find_modules(self, FastAttention)`, but
            # `FastAttention` is never imported in this file, so the first redraw
            # raised a NameError. The redrawable FastAttention modules live inside
            # performer-pytorch's SelfAttention wrapper; locate them by capability
            # (presence of `redraw_projection_matrix`) instead of by concrete type.
            redrawable = [m for m in self.modules() if hasattr(m, 'redraw_projection_matrix')]
            for module in redrawable:
                module.redraw_projection_matrix(device)

            # reset the call counter after a redraw
            self.calls_since_last_redraw.zero_()
            return

        self.calls_since_last_redraw += 1

    def forward(self, x, mask = None):
        self.check_redraw_projections()
        pool_num_layers = self.pool_num_layers

        # keep every layer's output so pooling layers can attend across them
        hiddens = [x]

        for attn, ff, efficient_attn in self.layers:
            # attention and feedforward, each with a residual connection
            x = attn(x, mask = mask) + x
            x = ff(x) + x

            hiddens.append(x)
            if exists(efficient_attn):
                # gather the most recent layers' outputs
                layers_to_pool = hiddens[-pool_num_layers:]
                num_layers = len(layers_to_pool)

                # interleave tokens of all gathered layers into one long sequence
                all_tokens = torch.stack(layers_to_pool)
                all_tokens = rearrange(all_tokens, 'l b n d -> b (n l) d')

                # expand the padding mask to cover the per-layer copies
                pool_attn_mask = None
                if exists(mask):
                    pool_attn_mask = repeat(mask, 'b n -> b (n l)', l = num_layers)

                # efficient omni-attention over all layers' tokens
                attended_tokens = efficient_attn(all_tokens, mask = pool_attn_mask)

                # max-pool the per-layer copies of each token back down to one
                attended_tokens = rearrange(attended_tokens, 'b n c -> b c n')
                pooled_tokens = F.max_pool1d(attended_tokens, kernel_size = num_layers, stride = num_layers)
                x += rearrange(pooled_tokens, 'b c n -> b n c')

        return x
# OmninetCausal: causal variant; uses layer-axial attention with layer positional
# embeddings until the linear-attention CUDA kernel supports the causal case
class OmninetCausal(nn.Module):
    """Causal OmniNet variant.

    On pooling layers, hidden states from the most recent layers are stacked
    and attended over along the *layer* axis; the max over layers is added
    back into the running representation.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        pool_layer_tokens_every = 2,
        attn_dropout = 0.,
        ff_dropout = 0.
    ):
        super().__init__()

        # one learned positional embedding per layer (+1 for the input itself)
        self.layer_pos_emb = nn.Parameter(torch.randn(depth + 1, dim))

        # build the layer stack
        layers = nn.ModuleList([])
        for ind in range(depth):
            num_layers = ind + 1
            # NOTE(review): truthy remainder means the layer-axial pooling is
            # *skipped* on every `pool_layer_tokens_every`-th layer and runs on
            # the others — confirm this inversion is intended
            should_pool = num_layers % pool_layer_tokens_every

            # each layer: pre-normed causal self-attention, pre-normed feedforward,
            # and (when pooling) a non-causal attention applied along the layer axis
            layers.append(nn.ModuleList([
                PreNorm(dim, Attention(causal = True, dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)),
                PreNorm(dim, FeedForward(dim = dim, dropout = ff_dropout)),
                Attention(dim = dim, heads= heads, dim_head = dim_head) if should_pool else None
            ]))

        self.layers = layers
        self.pool_num_layers = pool_layer_tokens_every

    # forward pass
    def forward(self, x, mask = None):
        pool_num_layers = self.pool_num_layers

        b = x.shape[0]
        pos_embs = rearrange(self.layer_pos_emb, 'n d -> () n d')

        # NOTE(review): in-place `+=` mutates the caller's input tensor — confirm
        # callers do not reuse `x` after this call
        x += pos_embs[:, 0]
        hiddens = [x]

        for ind, (attn, ff, layer_axial_attn) in enumerate(self.layers):

            # causal self-attention with residual
            x = attn(x, mask = mask) + x
            # feedforward with residual
            x = ff(x) + x

            # add this layer's positional embedding and record the hidden state
            x += pos_embs[:, ind + 1]
            hiddens.append(x)

            if exists(layer_axial_attn):
                layers_to_pool = hiddens[-pool_num_layers:]
                num_layers = len(layers_to_pool)

                # stack recent hidden states and attend along the layer axis
                layer_tokens = rearrange(torch.stack(layers_to_pool), 'l b n d -> (b n) l d')

                attended_tokens = layer_axial_attn(layer_tokens)
                attended_tokens = rearrange(attended_tokens, '(b n) l d -> b n l d', b = b)
                # max-pool across layers and fold back into the residual stream
                pooled_attended_tokens = attended_tokens.max(dim = -2).values
                x += pooled_attended_tokens

        return x

.\lucidrains\omninet-pytorch\omninet_pytorch\__init__.py

# 从 omninet_pytorch 模块中导入 Omninet 和 OmninetCausal 类
from omninet_pytorch.omninet_pytorch import Omninet, OmninetCausal

Omninet - Pytorch

Implementation of OmniNet, Omnidirectional Representations from Transformers, in Pytorch. The authors propose that we should be attending to all the tokens of the previous layers, leveraging recent efficient attention advances to achieve this goal.

Install

$ pip install omninet-pytorch

Usage

import torch
from omninet_pytorch import Omninet

omninet = Omninet(
    dim = 512,                     # model dimension
    depth = 6,                     # depth
    dim_head = 64,                 # dimension per head
    heads = 8,                     # number of heads
    pool_layer_tokens_every = 3,   # key to this paper - every N layers, omni attend to all tokens of all layers
    attn_dropout = 0.1,            # attention dropout
    ff_dropout = 0.1,              # feedforward dropout
    feature_redraw_interval = 1000 # how often to redraw the projection matrix for omni attention net - Performer
)

x = torch.randn(1, 1024, 512)
mask = torch.ones(1, 1024).bool()

omninet(x, mask = mask) # (1, 1024, 512)

Causal case, just use the class OmninetCausal. At the moment, it isn't faithful to the paper (I am using layer axial attention with layer positional embeddings to draw up information), but will fix this once I rework the linear attention CUDA kernel.

import torch
from omninet_pytorch import OmninetCausal

omninet = OmninetCausal(
    dim = 512,                     # model dimension
    depth = 6,                     # depth
    dim_head = 64,                 # dimension per head
    heads = 8,                     # number of heads
    pool_layer_tokens_every = 3,   # key to this paper - every N layers, omni attend to all tokens of all layers
    attn_dropout = 0.1,            # attention dropout
    ff_dropout = 0.1               # feedforward dropout
)

x = torch.randn(1, 1024, 512)
mask = torch.ones(1, 1024).bool()

omninet(x, mask = mask) # (1, 1024, 512)

Citations

@misc{tay2021omninet,
    title   = {OmniNet: Omnidirectional Representations from Transformers}, 
    author  = {Yi Tay and Mostafa Dehghani and Vamsi Aribandi and Jai Gupta and Philip Pham and Zhen Qin and Dara Bahri and Da-Cheng Juan and Donald Metzler},
    year    = {2021},
    eprint  = {2103.01075},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

.\lucidrains\omninet-pytorch\setup.py

# import the setup and package-discovery helpers from setuptools
from setuptools import setup, find_packages

# package metadata for the omninet-pytorch distribution
setup(
  name = 'omninet-pytorch', # distribution name on PyPI
  packages = find_packages(), # include every package found in the source tree
  version = '0.0.6', # release version
  license='MIT', # license identifier
  description = 'Omninet - Pytorch', # short description
  author = 'Phil Wang', # author
  author_email = 'lucidrains@gmail.com', # author email
  url = 'https://github.com/lucidrains/omninet-pytorch', # project homepage
  keywords = [ # search keywords
    'artificial intelligence',
    'deep learning',
    'transformer',
    'attention mechanism'
  ],
  install_requires=[ # runtime dependencies
    'einops>=0.3',
    'torch>=1.6',
    'performer-pytorch'
  ],
  classifiers=[ # trove classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Data source

The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/

.\lucidrains\PaLM-jax\palm_jax\palm.py

# 导入所需的模块和库
from typing import List, Tuple

import numpy as onp
from jax import random, nn, lax, jit, numpy as np
from jax.numpy import einsum

from equinox import Module, static_field
from einops import rearrange, repeat

# bias-less layernorm

class LayerNorm(Module):
    """LayerNorm without a bias term: gamma * (x - mean) / sqrt(var + eps)."""
    gamma: np.ndarray
    eps: float = static_field()

    def __init__(self, dim, eps = 1e-5):
        # learnable per-feature scale, no shift
        self.gamma = np.ones((dim,))
        self.eps = eps

    def __call__(self, x):
        # variance via E[x^2] - E[x]^2 over the feature axis
        mean = np.mean(x, axis = -1, keepdims = True)
        second_moment = np.mean(np.square(x), axis = -1, keepdims = True)
        variance = second_moment - np.square(mean)
        # fused reciprocal square root
        inv = lax.rsqrt(variance + self.eps)
        return inv * (x - mean) * self.gamma

# Rotary embedding

def fixed_pos_embedding(inv_freq, seq):
    """Return (sin, cos) rotary tables of shape (seq, 2 * len(inv_freq)).

    Each frequency is duplicated along the last axis so consecutive feature
    pairs share the same angle, matching the `rotate_every_two` pairing.
    """
    # outer product: angle[i, j] = position_i * inv_freq[j]
    angles = np.arange(seq)[:, None] * inv_freq[None, :]
    # duplicate each frequency: [a, b] -> [a, a, b, b]
    angles = np.repeat(angles, 2, axis = -1)
    return np.sin(angles), np.cos(angles)

def rotate_every_two(x):
    """Rotate adjacent feature pairs: (x0, x1) -> (-x1, x0), for rotary embeddings."""
    # view the last axis as pairs
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    evens, odds = pairs[..., 0], pairs[..., 1]
    rotated = np.stack((-odds, evens), axis = -1)
    # flatten the pairs back into the feature axis
    return rotated.reshape(x.shape)

def apply_rotary_pos_emb(x, sincos):
    """Apply rotary position embedding given precomputed (sin, cos) tables."""
    sin, cos = sincos
    rotated = rotate_every_two(x)
    return (x * cos) + (rotated * sin)

# attention - multi-query, one-headed key / values variant
# feedforward - Shazeer's SwiGLU variant

class ParallelTransformerBlock(Module):
    """PaLM block: attention and feedforward are computed in parallel from one
    fused input projection, and their outputs are summed.

    Attention is multi-query (many query heads, a single shared key/value
    head); the feedforward is Shazeer's SwiGLU variant.
    """
    norm: Module
    wi: np.ndarray        # fused input projection: [q | k | v | ff | ff_gate]
    attn_wo: np.ndarray   # attention output projection
    ff_wo: np.ndarray     # feedforward output projection

    heads: int = static_field()
    fused_dims: Tuple[int, ...] = static_field()
    scale: float = static_field()
    mask_value: float = static_field()

    def __init__(
        self,
        dim,
        dim_head,
        heads,
        key,
        ff_mult = 4,
        mask_value = -1e10
    ):
        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        self.norm = LayerNorm(dim)
        # widths of the fused projection: queries, key, value, ff, ff gate
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, ff_inner_dim, ff_inner_dim)

        # NOTE(review): plain standard-normal init (no fan-in scaling), and the
        # same PRNG `key` is reused for all three weights — confirm intent
        self.wi = random.normal(key, (dim, sum(self.fused_dims)))
        self.attn_wo = random.normal(key, (attn_inner_dim, dim))
        self.ff_wo = random.normal(key, (ff_inner_dim, dim))

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.mask_value = mask_value

    def __call__(self, x, *, pos_emb, causal_mask):
        n, split_indices = x.shape[-2], onp.cumsum(self.fused_dims[:-1])

        x = self.norm(x)

        # fused attention and feedforward projections

        q, k, v, ff, ff_gate = np.split(x @ self.wi, split_indices, axis = -1)

        # split out heads (k and v stay single-headed: multi-query attention)

        q = rearrange(q, '... n (h d) -> ... h n d', h = self.heads)

        # scale

        q *= self.scale

        # apply rotary embeddings

        q, k = map(lambda t: apply_rotary_pos_emb(t, pos_emb), (q, k))

        # sim

        sim = einsum('... h i d, ... j d -> ... h i j', q, k)

        # causal mask

        sim = np.where(causal_mask, sim, self.mask_value)

        # attention

        attn = nn.softmax(sim, axis = -1)

        # aggregate values

        out = einsum('... h i j, ... j d -> ... h i d', attn, v)

        # merge heads

        out = rearrange(out, '... h n d -> ... n (h d)')

        # feedforward out

        attn_out = out @ self.attn_wo

        ff_out = (ff * nn.swish(ff_gate)) @ self.ff_wo

        # sum of the parallel attention and feedforward branches

        return attn_out + ff_out

# main class

class PaLM(Module):
    """Autoregressive PaLM language model (JAX / equinox implementation).

    __call__ maps integer token ids of shape (..., seq) to vocabulary logits,
    using a tied input/output embedding.
    """
    embedding: np.ndarray
    norm: Module
    layers: List[List[Module]]
    inv_freq: onp.ndarray = static_field()

    def __init__(
        self,
        *,
        num_tokens,
        dim,
        dim_head,
        depth,
        heads,
        key,
        ff_mult = 4
    # initialize the transformer parameters
    ):
        # token embedding, scaled-down standard-normal init
        self.embedding = random.normal(key, (num_tokens, dim)) * 0.02
        # rotary embedding inverse frequencies
        self.inv_freq = 1.0 / (10000 ** (np.arange(0, dim_head, 2) / dim_head))

        # NOTE(review): every block receives the same PRNG `key`, so all layers
        # start with identical weights — confirm whether keys should be split
        self.layers = [ParallelTransformerBlock(dim = dim, dim_head = dim_head, heads = heads, ff_mult = ff_mult, key = key) for _ in range(depth)]
        # final pre-logits normalization
        self.norm = LayerNorm(dim)

    # JIT-compiled forward pass
    @jit
    def __call__(self, x):
        # x: (..., seq) integer token ids
        n = x.shape[-1]
        # look up token embeddings
        x = self.embedding[x]

        # rotary tables and lower-triangular causal mask for this sequence length
        rotary_emb = fixed_pos_embedding(self.inv_freq, n)
        causal_mask = np.tril(np.ones((n, n)))

        # residual connection around each parallel block
        for block in self.layers:
            x = block(x, pos_emb = rotary_emb, causal_mask = causal_mask) + x

        x = self.norm(x)
        # tied output head: project back onto the embedding matrix
        return x @ self.embedding.transpose()

.\lucidrains\PaLM-jax\palm_jax\palm_lite.py

# 从 math 模块中导入 log2 和 floor 函数
from math import log2, floor
# 从 typing 模块中导入 List 和 Tuple 类型
from typing import List, Tuple
import numpy as onp
# 从 jax 模块中导入 random, jit, nn, lax, numpy 模块,并将 numpy 模块重命名为 np
from jax import random, jit, nn, lax, numpy as np
# 从 jax.numpy 模块中导入 einsum 函数
from jax.numpy import einsum
# 从 equinox 模块中导入 Module, static_field 类
from equinox import Module, static_field
# 从 einops 模块中导入 rearrange, repeat 函数
from einops import rearrange, repeat

# RMSNorm: normalize by the root-mean-square of the features (no mean subtraction)
class RMSNorm(Module):
    """Scale-only normalization: x / ||x|| * sqrt(dim) * gamma."""
    gamma: np.ndarray
    scale: float = static_field()
    eps: float = static_field()

    def __init__(self, dim, eps = 1e-5):
        # learnable per-feature scale
        self.gamma = np.ones((dim,))
        self.eps = eps
        # compensates for normalizing by the full vector norm instead of the RMS
        self.scale = dim ** 0.5

    def __call__(self, x):
        # 1 / sqrt(sum(x^2) + eps) over the feature axis
        squared = np.square(x)
        inv_norm = lax.rsqrt(np.sum(squared, axis = -1, keepdims = True) + self.eps)
        return inv_norm * x * self.gamma * self.scale

# per-head ALiBi slopes (Press et al.): a geometric sequence in the head index
def get_alibi_slopes(heads):
    """Return one slope per attention head for the ALiBi distance penalty."""

    def power_of_two_slopes(n):
        # head 0 gets 2^(-8/n); each later head multiplies by the same ratio
        start = (2 ** (-2 ** -(log2(n) - 3)))
        ratio = start
        return [start * ratio ** i for i in range(n)]

    if log2(heads).is_integer():
        # exact power of two: plain geometric sequence
        return power_of_two_slopes(heads)

    # otherwise: slopes for the nearest lower power of two, topped up with
    # every other slope from the next power of two
    nearest = 2 ** floor(log2(heads))
    extra = power_of_two_slopes(2 * nearest)[0::2][:heads - nearest]
    return power_of_two_slopes(nearest) + extra

# ALiBi bias tensor of shape (heads, 1, seq_len): slope_h * position_j
def calc_alibi_bias(seq_len, heads):
    """Build the per-head linear position penalty used as an attention bias."""
    slopes = onp.array(get_alibi_slopes(heads)).reshape(heads, 1, 1)
    positions = onp.arange(seq_len).reshape(1, 1, seq_len)
    return slopes * positions

# PaLM-lite parallel transformer block
class ParallelTransformerBlock(Module):
    """Parallel attention + SwiGLU feedforward from one fused input projection.

    Attention is multi-query with a single `kv` projection that serves as BOTH
    the keys and the values; the caller-supplied ALiBi bias provides causal
    masking and positional information.
    """
    norm: Module
    wi: np.ndarray        # fused input projection: [q | kv | ff | ff_gate]
    attn_wo: np.ndarray   # attention output projection
    ff_wo: np.ndarray     # feedforward output projection
    heads: int = static_field()
    fused_dims: Tuple[int, ...] = static_field()
    scale: float = static_field()
    mask_value: float = static_field()

    def __init__(
        self,
        dim,
        dim_head,
        heads,
        key,
        ff_mult = 4,
        mask_value = -1e10
    ):
        # inner widths for attention and feedforward
        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        self.norm = RMSNorm(dim)
        # widths of the fused projection: queries, shared kv, ff, ff gate
        self.fused_dims = (attn_inner_dim, dim_head, ff_inner_dim, ff_inner_dim)

        # NOTE(review): plain standard-normal init; same PRNG `key` reused for
        # all three weights — confirm intent
        self.wi = random.normal(key, (dim, sum(self.fused_dims)))
        self.attn_wo = random.normal(key, (attn_inner_dim, dim))
        self.ff_wo = random.normal(key, (ff_inner_dim, dim))

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.mask_value = mask_value

    def __call__(self, x, *, attn_bias):
        # sequence length and split points for the fused projection
        n, split_indices = x.shape[-2], onp.cumsum(self.fused_dims[:-1])

        # pre-normalize
        x = self.norm(x)

        # fused attention and feedforward projections

        q, kv, ff, ff_gate = np.split(x @ self.wi, split_indices, axis = -1)

        # split out query heads (kv stays single-headed)

        q = rearrange(q, '... n (h d) -> ... h n d', h = self.heads)

        # scale queries

        q *= self.scale

        # similarities against the shared key/value projection

        sim = einsum('... h i d, ... j d -> ... h i j', q, kv)

        # add ALiBi bias (already carries the causal mask)

        sim = sim + attn_bias

        # attention

        attn = nn.softmax(sim, axis = -1)

        # aggregate values — `kv` doubles as the values here

        out = einsum('... h i j, ... j d -> ... h i d', attn, kv)

        # merge heads

        out = rearrange(out, '... h n d -> ... n (h d)')

        # attention branch output

        attn_out = out @ self.attn_wo

        # SwiGLU feedforward branch output

        ff_out = (ff * nn.swish(ff_gate)) @ self.ff_wo

        # sum the two parallel branches

        return attn_out + ff_out

# main class

class PaLM(Module):
    """PaLM-lite: autoregressive LM with RMSNorm and a precomputed ALiBi bias.

    __call__ maps integer token ids of shape (..., seq) to vocabulary logits
    via a tied input/output embedding.

    BUGFIX: the extracted source was missing the closing `):` of `__init__`'s
    signature (a syntax error); it is restored here with no other code change.
    """
    embedding: np.ndarray
    norm: Module
    layers: List[List[Module]]
    attn_bias: onp.ndarray = static_field()

    def __init__(
        self,
        *,
        num_tokens,
        dim,
        dim_head,
        depth,
        heads,
        key,
        ff_mult = 4,
        max_seq_len = 2048,
        mask_value = -1e10
    ):
        # token embedding, scaled-down standard-normal init
        self.embedding = random.normal(key, (num_tokens, dim)) * 0.02

        # precompute the causal ALiBi bias for the maximum sequence length:
        # allowed positions carry the per-head linear distance penalty,
        # future positions are filled with `mask_value`
        causal_mask = onp.tril(onp.ones((max_seq_len, max_seq_len)))
        alibi_bias = calc_alibi_bias(max_seq_len, heads = heads)
        self.attn_bias = np.where(causal_mask, repeat(alibi_bias, 'h 1 j -> h i j', i = max_seq_len), mask_value)

        # NOTE(review): all blocks share the same PRNG `key` (identical initial
        # weights) — confirm whether keys should be split per layer
        self.layers = [ParallelTransformerBlock(dim = dim, dim_head = dim_head, heads = heads, key = key, ff_mult = ff_mult) for _ in range(depth)]
        # final pre-logits normalization
        self.norm = RMSNorm(dim)

    # JIT-compiled forward pass
    @jit
    def __call__(self, x):
        # x: (..., seq) integer token ids
        n = x.shape[-1]
        x = self.embedding[x]

        # slice the precomputed bias down to this sequence length
        attn_bias = self.attn_bias[..., :n, :n]

        # residual connection around each parallel block
        for block in self.layers:
            x = block(x, attn_bias = attn_bias) + x

        x = self.norm(x)
        # tied output head: project back onto the embedding matrix
        return x @ self.embedding.transpose()

.\lucidrains\PaLM-jax\palm_jax\utils.py

# 导入所需的库
from jax import random
from jax.lax import top_k
import jax.numpy as np

# 辅助函数

# presence check: everything except None counts as present
def exists(val):
    return not (val is None)

# numerically-safe log: offset by a tiny eps so log(0) stays finite
def log(t, eps = 1e-20):
    return np.log(eps + t)

# 采样函数

# keep only the k largest entries of `tensor`, zeroing the rest
def select_top_k(tensor, k):
    """Return (mask, filtered): `mask` is True at the k largest entries and
    `filtered` equals `tensor` there, 0 elsewhere. Ties with the k-th value
    may admit extra entries.

    BUGFIX: the original compared with a strict `>` against the k-th largest
    value, which always dropped the k-th entry itself — and with k == 1 it
    selected nothing at all, breaking top-k sampling.
    """
    values, _ = top_k(tensor, k)
    # `>=` keeps the k-th largest value itself
    mask = tensor >= values.min()
    return mask, np.where(mask, tensor, 0.)

# sample standard Gumbel noise: -log(-log(U)), U ~ Uniform(0, 1)
def gumbel_noise(key, shape):
    """Draw Gumbel(0, 1) noise for the Gumbel-max sampling trick.

    Both logs are the numerically-safe variant (offset by 1e-20).
    """
    uniform = random.uniform(key, shape = shape, minval = 0., maxval = 1.)
    inner = np.log(uniform + 1e-20)
    return -np.log(-inner + 1e-20)

# autoregressively extend `prime` to `length` tokens via Gumbel-max sampling
def sample(key, model, prime, length, top_k = None):
    """Generate a sequence of `length` token ids, seeded with `prime`.

    Optionally restricts each sampling step to the `top_k` highest logits.
    """
    start_pos = prime.shape[-1]
    # right-pad the prompt with zeros up to the target length
    seq = np.pad(prime, (0, length - prime.shape[-1]))
    one_hots = np.eye(length, dtype = int)

    for curr_pos in range(start_pos, length):
        # logits at the position preceding the slot being filled
        step_logits = model(seq)[curr_pos - 1]

        # fresh PRNG key per step; add Gumbel noise for sampling-by-argmax
        _, key = random.split(key)
        noise = gumbel_noise(key, step_logits.shape)

        if exists(top_k):
            # restrict to the top-k logits; zero the noise elsewhere too
            keep_mask, step_logits = select_top_k(step_logits, top_k)
            noise *= keep_mask

        sampled_ind = np.argmax(step_logits + noise, axis = -1)

        # write the sampled id into position `curr_pos` via a one-hot update
        seq += one_hots[curr_pos] * sampled_ind

    return seq