
Lucidrains Series Project Source Code Analysis (83)

ResMLP - Pytorch

Implementation of ResMLP, an all-MLP solution to image classification out of Facebook AI, in Pytorch

Install

$ pip install res-mlp-pytorch

Usage

import torch
from res_mlp_pytorch import ResMLP

model = ResMLP(
    image_size = 256,
    patch_size = 16,
    dim = 512,
    depth = 12,
    num_classes = 1000
)

img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)

Rectangular image

import torch
from res_mlp_pytorch import ResMLP

model = ResMLP(
    image_size = (128, 256), # (128 x 256)
    patch_size = 16,
    dim = 512,
    depth = 12,
    num_classes = 1000
)

img = torch.randn(1, 3, 128, 256)
pred = model(img) # (1, 1000)

Citations

@misc{touvron2021resmlp,
    title   = {ResMLP: Feedforward networks for image classification with data-efficient training}, 
    author  = {Hugo Touvron and Piotr Bojanowski and Mathilde Caron and Matthieu Cord and Alaaeldin El-Nouby and Edouard Grave and Armand Joulin and Gabriel Synnaeve and Jakob Verbeek and Hervé Jégou},
    year    = {2021},
    eprint  = {2105.03404},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

.\lucidrains\res-mlp-pytorch\res_mlp_pytorch\res_mlp_pytorch.py

# import the necessary libraries
import torch
from torch import nn, einsum
from einops.layers.torch import Rearrange, Reduce

# returns a tuple, duplicating the value if the input is not already a tuple
def pair(val):
    return (val, val) if not isinstance(val, tuple) else val

# an affine transform with learnable per-channel scale and shift
class Affine(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, 1, dim))
        self.b = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        return x * self.g + self.b

# pre-affine, post-layerscale residual wrapper
class PreAffinePostLayerScale(nn.Module): # https://arxiv.org/abs/2103.17239
    def __init__(self, dim, depth, fn):
        super().__init__()
        # choose the layerscale init value based on depth (following CaiT)
        if depth <= 18:
            init_eps = 0.1
        elif depth > 18 and depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        scale = torch.zeros(1, 1, dim).fill_(init_eps)
        self.scale = nn.Parameter(scale)
        self.affine = Affine(dim)
        self.fn = fn

    def forward(self, x):
        return self.fn(self.affine(x)) * self.scale + x

# the ResMLP model constructor
def ResMLP(*, image_size, patch_size, dim, depth, num_classes, expansion_factor = 4):
    image_height, image_width = pair(image_size)
    assert (image_height % patch_size) == 0 and (image_width % patch_size) == 0, 'image height and width must be divisible by patch size'
    num_patches = (image_height // patch_size) * (image_width // patch_size)
    wrapper = lambda i, fn: PreAffinePostLayerScale(dim, i + 1, fn)

    return nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
        nn.Linear((patch_size ** 2) * 3, dim),
        *[nn.Sequential(
            wrapper(i, nn.Conv1d(num_patches, num_patches, 1)),
            wrapper(i, nn.Sequential(
                nn.Linear(dim, dim * expansion_factor),
                nn.GELU(),
                nn.Linear(dim * expansion_factor, dim)
            ))
        ) for i in range(depth)],
        Affine(dim),
        Reduce('b n c -> b c', 'mean'),
        nn.Linear(dim, num_classes)
    )

# the full ResMLP model is returned as an nn.Sequential
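Note that the cross-patch communication above is implemented with nn.Conv1d(num_patches, num_patches, 1), which acts as a linear layer over the patch dimension, applied independently to every feature channel. A minimal sketch verifying that equivalence (the names and shapes below are illustrative, not from the repo):

import torch
from torch import nn

num_patches, dim = 256, 512
x = torch.randn(1, num_patches, dim)

conv = nn.Conv1d(num_patches, num_patches, 1)

# build the equivalent Linear over the patch dimension by copying the conv weights
lin = nn.Linear(num_patches, num_patches)
lin.weight.data = conv.weight.data.squeeze(-1)
lin.bias.data = conv.bias.data

out_conv = conv(x)                               # mixes patches, per feature channel
out_lin = lin(x.transpose(1, 2)).transpose(1, 2) # the same mixing via a Linear

assert torch.allclose(out_conv, out_lin, atol = 1e-5)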

.\lucidrains\res-mlp-pytorch\res_mlp_pytorch\__init__.py

# import the ResMLP class from the res_mlp_pytorch.res_mlp_pytorch module
from res_mlp_pytorch.res_mlp_pytorch import ResMLP

.\lucidrains\res-mlp-pytorch\setup.py

# import setuptools helpers for packaging
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'res-mlp-pytorch', # package name
  packages = find_packages(exclude=[]), # find and include all packages
  version = '0.0.6', # version number
  license='MIT', # license
  description = 'ResMLP - Pytorch', # description
  author = 'Phil Wang', # author
  author_email = 'lucidrains@gmail.com', # author email
  url = 'https://github.com/lucidrains/res-mlp-pytorch', # project url
  keywords = [ # keywords
    'artificial intelligence',
    'deep learning',
    'image recognition'
  ],
  install_requires=[ # dependencies
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers=[ # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Retrieval-Augmented Denoising Diffusion Probabilistic Models (wip)

Implementation of Retrieval-Augmented Denoising Diffusion Probabilistic Models in Pytorch

This will make use of the Clip Retrieval library made by @rom1504
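As a rough illustration of the retrieval side, clip-retrieval exposes a ClipClient for querying a prebuilt KNN index. A minimal sketch of what fetching neighbors for conditioning might look like (the service URL and index name below are assumptions, not from this repo):

from clip_retrieval.clip_client import ClipClient

# hypothetical: query a public LAION knn service for neighbors of a text prompt
client = ClipClient(
    url = 'https://knn.laion.ai/knn-service',  # assumed public endpoint
    indice_name = 'laion5B-L-14',              # assumed index name
    num_images = 8
)

results = client.query(text = 'an oil painting of a lighthouse')
# each result is a dict with keys such as 'url', 'caption' and 'similarity'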

Citations

@article{Blattmann2022RetrievalAugmentedDM,
    title   = {Retrieval-Augmented Diffusion Models},
    author  = {A. Blattmann and Robin Rombach and K. Oktay and Bj{\"o}rn Ommer},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2204.11824}
}

.\lucidrains\retrieval-augmented-ddpm\retrieval_augmented_ddpm\retrieval_augmented_ddpm.py

# define a function named calculate_area that computes the area of a rectangle
def calculate_area(length, width):
    # compute the rectangle's area
    area = length * width
    # return the computed area
    return area

.\lucidrains\retrieval-augmented-ddpm\retrieval_augmented_ddpm\__init__.py

# define a function named calculate_area that computes the area of a rectangle
def calculate_area(length, width):
    # compute the rectangle's area
    area = length * width
    # return the computed area
    return area

.\lucidrains\retrieval-augmented-ddpm\setup.py

# import setuptools helpers for packaging
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'retrieval-augmented-ddpm',  # package name
  packages = find_packages(exclude=[]),  # find and include all packages
  version = '0.0.1',  # version number
  license='MIT',  # license
  description = 'Retrieval-Augmented Denoising Diffusion Probabilistic Models',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/retrieval-augmented-ddpm',  # project url
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'denoising diffusion',
    'retrieval'
  ],
  install_requires=[  # dependencies
    'clip-retrieval',
    'einops>=0.4',
    'torch>=1.6',
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

RETRO - Pytorch

Implementation of RETRO, DeepMind's retrieval-based attention net, in Pytorch. This will deviate from the paper slightly, using rotary embeddings for relative positional encoding, as well as the Faiss library in place of Scann.

This library leverages autofaiss for building the index and calculating the k-nearest neighbors for all chunks.

Jay Alammar explanatory blogpost

The selling point of this retrieval approach is reaching GPT-3 performance with 10x fewer parameters. More research in this area is definitely deserved.

I have also included the features necessary to scale the retrieval transformer to 1000 layers, if the claims of the DeepNet paper are to be believed.

Update: Someone on Reddit has gifted me a Gold Award. Not sure what it is, but thank you! 🙏

Update: DeepNorm has been validated at scale in a 130B model out of Tsinghua. It is now recommended that you train with use_deepnet set to True.

Install

$ pip install retro-pytorch

Usage

import torch
from retro_pytorch import RETRO

retro = RETRO(
    chunk_size = 64,                         # the chunk size that is indexed and retrieved (needed for proper relative positions as well as causal chunked cross attention)
    max_seq_len = 2048,                      # max sequence length
    enc_dim = 896,                           # encoder model dim
    enc_depth = 2,                           # encoder depth
    dec_dim = 796,                           # decoder model dim
    dec_depth = 12,                          # decoder depth
    dec_cross_attn_layers = (3, 6, 9, 12),   # decoder cross attention layers (with causal chunk cross attention)
    heads = 8,                               # attention heads
    dim_head = 64,                           # dimension per head
    dec_attn_dropout = 0.25,                 # decoder attention dropout
    dec_ff_dropout = 0.25,                   # decoder feedforward dropout
    use_deepnet = True                       # turn on post-normalization with DeepNet residual scaling and initialization, for scaling to 1000 layers
)

seq = torch.randint(0, 20000, (2, 2048 + 1))      # plus one since it is split into input and labels for training
retrieved = torch.randint(0, 20000, (2, 32, 2, 128)) # retrieved tokens - (batch, num chunks, num retrieved neighbors, retrieved chunk with continuation)

loss = retro(seq, retrieved, return_loss = True)
loss.backward()

# do above for many steps

RETRO Training Wrapper

The aim of the TrainingWrapper is to process a folder of text documents into the necessary memmapped numpy arrays to begin training RETRO.

import torch
from retro_pytorch import RETRO, TrainingWrapper

# instantiate RETRO, fit it into the TrainingWrapper with correct settings

retro = RETRO(
    max_seq_len = 2048,                      # max sequence length
    enc_dim = 896,                           # encoder model dimension
    enc_depth = 3,                           # encoder depth
    dec_dim = 768,                           # decoder model dimensions
    dec_depth = 12,                          # decoder depth
    dec_cross_attn_layers = (1, 3, 6, 9),    # decoder cross attention layers (with causal chunk cross attention)
    heads = 8,                               # attention heads
    dim_head = 64,                           # dimension per head
    dec_attn_dropout = 0.25,                 # decoder attention dropout
    dec_ff_dropout = 0.25                    # decoder feedforward dropout
).cuda()

wrapper = TrainingWrapper(
    retro = retro,                                 # the RETRO instance to train
    knn = 2,                                       # knn (2 in paper was sufficient)
    chunk_size = 64,                               # chunk size (64 in paper)
    documents_path = './text_folder',              # path to folder of text
    glob = '**/*.txt',                             # text glob
    chunks_memmap_path = './train.chunks.dat',     # path to chunks
    seqs_memmap_path = './train.seq.dat',          # path to sequence data
    doc_ids_memmap_path = './train.doc_ids.dat',   # path to document ids per chunk (used for filtering neighbors belonging to same document)
    max_chunks = 1_000_000,                        # maximum cap to chunks
    max_seqs = 100_000,                            # maximum seqs
    knn_extra_neighbors = 100,                     # num extra neighbors to fetch
    max_index_memory_usage = '100m',
    current_memory_available = '1G'
)

# get the dataloader and optimizer (AdamW with all the correct settings)

train_dl = iter(wrapper.get_dataloader(batch_size = 2, shuffle = True))
optim = wrapper.get_optimizer(lr = 3e-4, wd = 0.01)

# now do your training
# ex. one gradient step

seq, retrieved = map(lambda t: t.cuda(), next(train_dl))

# seq       - (2, 2049)         - 1 extra token since split by seq[:, :-1], seq[:, 1:]
# retrieved - (2, 32, 2, 128)   - 128 since chunk + continuation, each 64 tokens

loss = retro(
    seq,
    retrieved,
    return_loss = True
)

# one gradient step

loss.backward()
optim.step()
optim.zero_grad()

# do above for many steps, then ...

# topk sampling with retrieval at chunk boundaries

sampled = wrapper.generate(filter_thres = 0.9, temperature = 1.0) # (1, <2049) terminates early if all <eos>

# or you can generate with a prompt, knn retrieval for initial chunks all taken care of

prompt = torch.randint(0, 1000, (1, 128))  # start with two chunks worth of sequence
sampled = wrapper.generate(prompt, filter_thres = 0.9, temperature = 1.0) # (1, <2049) terminates early if all <eos>

If you wish to force a reprocess of the training data, simply run your script with the REPROCESS=1 environment flag, like so

$ REPROCESS=1 python train.py
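Under the hood this just checks an environment variable via is_true_env_flag in retro_pytorch/utils.py; a plausible sketch of such a check (the exact implementation may differ):

import os

# hypothetical re-implementation of the environment flag check
def is_true_env_flag(env_flag):
    return os.getenv(env_flag, 'false').lower() in ('1', 'true')

if is_true_env_flag('REPROCESS'):
    print('reprocessing training data ...')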

RETRO Datasets

The RETRODataset class accepts paths to a number of memmapped numpy arrays containing the chunks, the index of the first chunk in the sequence to be trained on (in RETRO decoder), and the pre-calculated indices of the k-nearest neighbors per chunk.

You can use this to easily assemble the data for RETRO training, if you do not wish to use the TrainingWrapper from above.

Furthermore, all the functions needed to create the necessary memmapped data are in the sections that follow.

import torch
from torch.utils.data import DataLoader
from retro_pytorch import RETRO, RETRODataset

# mock data constants

import numpy as np

NUM_CHUNKS = 1000
CHUNK_SIZE = 64
NUM_SEQS = 100
NUM_NEIGHBORS = 2

def save_memmap(path, tensor):
    f = np.memmap(path, dtype = tensor.dtype, mode = 'w+', shape = tensor.shape)
    f[:] = tensor
    del f

# generate mock chunk data

save_memmap(
    './train.chunks.dat',
    np.int32(np.random.randint(0, 8192, size = (NUM_CHUNKS, CHUNK_SIZE + 1)))
)

# generate nearest neighbors for each chunk

save_memmap(
    './train.chunks.knn.dat',
    np.int32(np.random.randint(0, 1000, size = (NUM_CHUNKS, NUM_NEIGHBORS)))
)

# generate seq data

save_memmap(
    './train.seq.dat',
    np.int32(np.random.randint(0, 128, size = (NUM_SEQS,)))
)

# instantiate dataset class
# which constructs the sequence and neighbors from memmapped chunk and neighbor information

train_ds = RETRODataset(
    num_sequences = NUM_SEQS,
    num_chunks = NUM_CHUNKS,
    num_neighbors = NUM_NEIGHBORS,
    chunk_size = CHUNK_SIZE,
    seq_len = 2048,
    chunk_memmap_path = './train.chunks.dat',
    chunk_nn_memmap_path = './train.chunks.knn.dat',
    seq_memmap_path = './train.seq.dat'
)

train_dl = iter(DataLoader(train_ds, batch_size = 2))

# one forwards and backwards

retro = RETRO(
    max_seq_len = 2048,                      # max sequence length
    enc_dim = 896,                           # encoder model dimension
    enc_depth = 3,                           # encoder depth
    dec_dim = 768,                           # decoder model dimensions
    dec_depth = 12,                          # decoder depth
    dec_cross_attn_layers = (1, 3, 6, 9),    # decoder cross attention layers (with causal chunk cross attention)
    heads = 8,                               # attention heads
    dim_head = 64,                           # dimension per head
    dec_attn_dropout = 0.25,                 # decoder attention dropout
    dec_ff_dropout = 0.25                    # decoder feedforward dropout
).cuda()

seq, retrieved = map(lambda t: t.cuda(), next(train_dl))

# seq       - (2, 2049)         - 1 extra token since split by seq[:, :-1], seq[:, 1:]
# retrieved - (2, 32, 2, 128)   - 128 since chunk + continuation, each 64 tokens

loss = retro(
    seq,
    retrieved,
    return_loss = True
)

loss.backward()

Retrieval related tools

This repository will use the default tokenizer (WordPiece) for the cased version of BERT. Embeddings will be fetched from the vanilla BERT, and can be either the masked mean pooled representation, or the CLS token.

ex. masked mean pooled representation

from retro_pytorch.retrieval import bert_embed, tokenize

ids = tokenize([
    'hello world',
    'foo bar'
])

embeds = bert_embed(ids) # (2, 768) - 768 is hidden dimension of BERT

ex. CLS token representation

from retro_pytorch.retrieval import bert_embed, tokenize

ids = tokenize([
    'hello world',
    'foo bar'
])

embeds = bert_embed(ids, return_cls_repr = True) # (2, 768)

Create your chunks and chunk start indices (for calculating sequence ranges for autoregressive training) using text_folder_to_chunks_

from retro_pytorch.retrieval import text_folder_to_chunks_

stats = text_folder_to_chunks_(
    folder = './text_folder',
    glob = '**/*.txt',
    chunks_memmap_path = './train.chunks.dat',
    seqs_memmap_path = './train.seq.dat',
    doc_ids_memmap_path = './train.doc_ids.dat',  # document ids are needed for filtering out neighbors belonging to same document appropriately during computation of nearest neighbors
    chunk_size = 64,
    seq_len = 2048,
    max_chunks = 1_000_000,
    max_seqs = 100_000
)

# {'chunks': <number of chunks>, 'docs': <number of documents>, 'seqs': <number of sequences>}

Fetching Nearest Neighbors

You can turn your memmapped chunks numpy array into embeddings and a faiss index with one command

from retro_pytorch.retrieval import chunks_to_index_and_embed

index, embeddings = chunks_to_index_and_embed(
    num_chunks = 1000,
    chunk_size = 64,
    chunk_memmap_path = './train.chunks.dat'
)

query_vector = embeddings[:1]                   # use first embedding as query
_, indices = index.search(query_vector, k = 2)  # fetch 2 neighbors, first indices should be self

neighbor_embeddings = embeddings[indices]       # (1, 2, 768)

You can also directly calculate the nearest neighbor file necessary for training, with the chunks_to_precalculated_knn_ command

from retro_pytorch.retrieval import chunks_to_precalculated_knn_

chunks_to_precalculated_knn_(
    num_chunks = 1000,
    chunk_size = 64,
    chunk_memmap_path = './train.chunks.dat',    # path to main chunks dataset
    doc_ids_memmap_path = './train.doc_ids.dat', # path to document ids created by text_folder_to_chunks_, used for filtering out neighbors that belong to the same document
    num_nearest_neighbors = 2,                   # number of nearest neighbors you'd like to use
    num_extra_neighbors = 10                     # fetch 10 extra neighbors, in the case that fetched neighbors are frequently from same document (filtered out)
)

# nearest neighbor info saved to ./train.chunks.knn.dat
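The precomputed neighbor file is just a memmapped int32 array of shape (num chunks, num nearest neighbors); a minimal sketch of reading it back, assuming the 1000 chunks and 2 neighbors from above:

import numpy as np

knns = np.memmap('./train.chunks.knn.dat', dtype = np.int32, shape = (1000, 2), mode = 'r')
print(knns[0])  # the 2 nearest neighbor chunk indices for chunk 0 (-1 where none survived filtering)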

Citations

@misc{borgeaud2022improving,
    title   = {Improving language models by retrieving from trillions of tokens}, 
    author  = {Sebastian Borgeaud and Arthur Mensch and Jordan Hoffmann and Trevor Cai and Eliza Rutherford and Katie Millican and George van den Driessche and Jean-Baptiste Lespiau and Bogdan Damoc and Aidan Clark and Diego de Las Casas and Aurelia Guy and Jacob Menick and Roman Ring and Tom Hennigan and Saffron Huang and Loren Maggiore and Chris Jones and Albin Cassirer and Andy Brock and Michela Paganini and Geoffrey Irving and Oriol Vinyals and Simon Osindero and Karen Simonyan and Jack W. Rae and Erich Elsen and Laurent Sifre},
    year    = {2022},
    eprint  = {2112.04426},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@article{Wang2022DeepNetST,
    title   = {DeepNet: Scaling Transformers to 1,000 Layers},
    author  = {Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2203.00555}
}
@misc{zhang2021sparse,
    title   = {Sparse Attention with Linear Units},
    author  = {Biao Zhang and Ivan Titov and Rico Sennrich},
    year    = {2021},
    eprint  = {2104.07012},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

I consider always the adult life to be the continuous retrieval of childhood. - Umberto Eco

.\lucidrains\RETRO-pytorch\retro_pytorch\data.py

# import the required libraries
from functools import partial
import numpy as np
import torch
from torch.utils.data import Dataset

# import project-local modules
from retro_pytorch.retrieval import EOS_ID
from retro_pytorch.utils import memmap

# convert knn indices into the retrieved chunk tokens, with optional continuations
def knn_to_retrieved_chunks(
    knns,
    chunks_memmap,
    *,
    add_continuations,
    num_chunks,
    pad_id = 0,
    eos_id = EOS_ID,
):

    # derive a mask for the knns where no neighbor was found (-1)
    no_neighbor_mask = knns == -1
    knns = np.maximum(knns, 0)

    # fetch the neighbor chunks and their continuations
    knn_chunks = chunks_memmap[knns]
    is_last_document_chunk = np.any(knn_chunks == eos_id, axis = -1, keepdims = True)

    # use the presence of [EOS] within a chunk to detect document boundaries
    retrieved = knn_chunks[..., :-1]

    if add_continuations:
        continuation_indices = np.clip(knns + 1, 0, num_chunks - 1) # chunks are stored contiguously, so index + 1 is the continuation
        continuation_chunks = chunks_memmap[continuation_indices][..., :-1]
        continuation_chunks *= ~is_last_document_chunk # zero out continuations that cross a document boundary

        # merge the neighbors with their continuations
        retrieved = np.concatenate((retrieved, continuation_chunks), axis = -1)

    # mask out any nearest neighbor chunks that were -1 (not found at index time) with the padding id
    retrieved = np.where(~no_neighbor_mask[..., None], retrieved, pad_id)
    return retrieved
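# a minimal usage sketch on mock arrays (illustrative, not part of the original file)
if __name__ == '__main__':
    num_chunks, chunk_size, num_neighbors = 100, 64, 2

    # mock contiguous chunk storage, each chunk carrying one extra token
    mock_chunks = np.random.randint(0, 8192, size = (num_chunks, chunk_size + 1)).astype(np.int32)

    # mock knn indices for 4 query chunks, where -1 marks 'no neighbor found'
    mock_knns = np.random.randint(-1, num_chunks, size = (4, num_neighbors)).astype(np.int32)

    retrieved = knn_to_retrieved_chunks(
        mock_knns,
        mock_chunks,
        add_continuations = True,
        num_chunks = num_chunks
    )

    print(retrieved.shape) # (4, 2, 128) - each neighbor chunk plus its 64-token continuation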

# RETRODataset class, subclassing torch's Dataset
class RETRODataset(Dataset):
    def __init__(
        self,
        *,
        num_chunks,
        chunk_size,
        seq_len,
        num_sequences,
        num_neighbors,
        chunk_memmap_path,
        chunk_nn_memmap_path,
        seq_memmap_path,
        eos_id = EOS_ID,
        pad_id = 0.,
        add_continuations = True
    ):
        super().__init__()
        self.num_chunks = num_chunks
        self.num_sequences = num_sequences
        self.seq_num_chunks = seq_len // chunk_size
        self.eos_id = eos_id
        self.pad_id = pad_id

        num_chunks_with_padding = num_chunks + self.seq_num_chunks

        chunks_shape = (num_chunks_with_padding, chunk_size + 1)
        knn_shape = (num_chunks_with_padding, num_neighbors)

        self.add_continuations = add_continuations
        self.get_chunks = partial(memmap, chunk_memmap_path, dtype = np.int32, shape = chunks_shape)
        self.get_knns = partial(memmap, chunk_nn_memmap_path, dtype = np.int32, shape = knn_shape)
        self.get_seqs = partial(memmap, seq_memmap_path, dtype = np.int32, shape = (num_sequences,))

    # return the length of the dataset
    def __len__(self):
        return self.num_sequences

    # fetch the training sequence and its retrieved neighbors at the given index
    def __getitem__(self, ind):
        with self.get_chunks() as chunks_memmap, self.get_knns() as knns_memmap, self.get_seqs() as seqs_memmap:
            begin_chunk_index = seqs_memmap[ind]
            chunk_range = slice(begin_chunk_index, (begin_chunk_index + self.seq_num_chunks))

            chunks = chunks_memmap[chunk_range]

            # excise the last token of each chunk, except for the last token of the final chunk
            seq_tokens = np.concatenate((chunks[:, :-1].flatten(), chunks[-1, -1:]))

            # mask out (with padding tokens) any token following an <eos> | disallow more than one document per sequence, as it would break RETRO's chunked cross attention (CCA)
            seq_mask = np.cumsum(seq_tokens == self.eos_id, axis = 0)
            seq_mask = np.pad(seq_mask, (1, 0))[:-1] == 0.
            seq_tokens = np.where(seq_mask, seq_tokens, 0.)

            # derive the retrieved tokens
            knns = knns_memmap[chunk_range]

            retrieved = knn_to_retrieved_chunks(
                knns,
                chunks_memmap,
                add_continuations = self.add_continuations,
                eos_id = self.eos_id,
                num_chunks = self.num_chunks
            )

        seq_tokens_torch = torch.from_numpy(seq_tokens).long()
        retrieved_torch = torch.from_numpy(retrieved).long()
        return seq_tokens_torch, retrieved_torch

.\lucidrains\RETRO-pytorch\retro_pytorch\optimizer.py

# import the AdamW optimizer from torch.optim
from torch.optim import AdamW

# separate the parameters into those that should and should not receive weight decay
def separate_weight_decayable_params(params):
    # parameters with fewer than 2 dimensions (biases, norms) should not be weight decayed
    no_wd_params = set([param for param in params if param.ndim < 2])
    # everything else receives weight decay
    wd_params = set(params) - no_wd_params
    return wd_params, no_wd_params

# build an AdamW optimizer from the parameters and hyperparameters
def get_optimizer(params, lr = 3e-4, wd = 1e-1, filter_by_requires_grad = False):
    # optionally keep only the parameters with requires_grad = True
    if filter_by_requires_grad:
        params = list(filter(lambda t: t.requires_grad, params))

    # convert the parameters to a set
    params = set(params)
    # split into weight-decayable and non-weight-decayable groups
    wd_params, no_wd_params = separate_weight_decayable_params(params)

    # build the parameter groups - weight decay is applied only to the first group
    param_groups = [
        {'params': list(wd_params)},
        {'params': list(no_wd_params), 'weight_decay': 0},
    ]

    # return an AdamW optimizer over the groups, with the given lr and weight decay
    return AdamW(param_groups, lr = lr, weight_decay = wd)
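# a minimal usage sketch (illustrative, not part of the original file)
if __name__ == '__main__':
    from torch import nn

    model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
    optim = get_optimizer(model.parameters(), lr = 3e-4, wd = 1e-1)
    # biases and norm parameters (ndim < 2) land in the zero weight decay group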

.\lucidrains\RETRO-pytorch\retro_pytorch\retrieval.py

# import the required modules
from pathlib import Path
from math import ceil

import torch
import torch.nn.functional as F
import logging
import numpy as np
from einops import rearrange

import faiss
from autofaiss import build_index

from retro_pytorch.utils import memmap, reset_folder_

# constants

SOS_ID = 101
EOS_ID = 102
BERT_MODEL_DIM = 768
BERT_VOCAB_SIZE = 28996

TMP_PATH = Path('./.tmp')
INDEX_FOLDER_PATH = TMP_PATH / '.index'
EMBEDDING_TMP_SUBFOLDER = 'embeddings'

# helper functions

def exists(val):
    return val is not None

def range_chunked(max_value, *, batch_size):
    counter = 0
    while counter < max_value:
        curr = counter + batch_size
        curr = min(curr, max_value)
        yield slice(counter, curr)
        counter = curr
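
# ex. list(range_chunked(10, batch_size = 4)) -> [slice(0, 4), slice(4, 8), slice(8, 10)] (illustrative)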

# index helper functions

def faiss_read_index(path):
    return faiss.read_index(str(path), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)

# singleton global variables

MODEL = None
TOKENIZER = None

def get_tokenizer():
    global TOKENIZER
    if not exists(TOKENIZER):
        TOKENIZER = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased')
    return TOKENIZER

def get_bert():
    global MODEL
    if not exists(MODEL):
        MODEL = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-cased')
        if torch.cuda.is_available():
            MODEL = MODEL.cuda()

    return MODEL

# tokenization

def tokenize(texts, add_special_tokens = True):
    if not isinstance(texts, (list, tuple)):
        texts = [texts]

    tokenizer = get_tokenizer()

    encoding = tokenizer.batch_encode_plus(
        texts,
        add_special_tokens = add_special_tokens,
        padding = True,
        return_tensors = 'pt'
    )

    token_ids = encoding.input_ids
    return token_ids

# text to chunks and sequence indices

def doc_text_to_chunks_and_seq_indices(
    *,
    doc_text,
    chunk_size = 64,
    seq_len = 2048,
    pad_id = 0
):
    assert (seq_len % chunk_size) == 0, 'sequence length must be divisible by chunk size'

    ids = tokenize(doc_text)
    ids = rearrange(ids, '1 ... -> ...')

    text_len = ids.shape[-1]

    # pad with extra tokens to a multiple of the chunk size

    padding = chunk_size - ((text_len - 1) % chunk_size)
    ids = F.pad(ids, (0, padding))

    # split off the very last token

    ids, last_token = ids[:-1], ids[-1:]
    ids = rearrange(ids, '(n c) -> n c', c = chunk_size)

    # the first tokens of chunks [2:] onward become the last tokens of chunks [1:]

    last_token_per_chunk = ids[1:, 0]
    all_last_tokens = torch.cat((last_token_per_chunk, last_token), dim = 0)
    all_last_tokens = rearrange(all_last_tokens, 'n -> n 1')

    # append the last tokens to the chunks, forming (num_chunks, chunk_size + 1)

    chunks_with_extra_token = torch.cat((ids, all_last_tokens), dim = -1)

    # compute the chunk indices starting at 0, spaced one sequence's worth of chunks apart

    total_chunks = ids.shape[0]
    num_chunks_per_seq = seq_len // chunk_size
    seq = torch.arange(0, total_chunks, num_chunks_per_seq)

    return chunks_with_extra_token, seq
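# a minimal usage sketch (illustrative, not part of the original file; downloads the BERT tokenizer on first run)
if __name__ == '__main__':
    chunks, seq = doc_text_to_chunks_and_seq_indices(
        doc_text = 'the quick brown fox jumps over the lazy dog. ' * 200,
        chunk_size = 64,
        seq_len = 2048
    )
    print(chunks.shape) # (num_chunks, 65) - each chunk carries one extra token
    print(seq)          # index of the first chunk of each training sequence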

def text_folder_to_chunks_(
    *,
    folder,
    chunks_memmap_path,
    seqs_memmap_path,
    doc_ids_memmap_path,
    chunk_size = 64,
    seq_len = 2048,
    glob = '**/*.txt',
    max_chunks = 1_000_000,
    max_seqs = 100_000
):
    paths = sorted([*Path(folder).glob(glob)])

    total_chunks = 0
    total_docs = 0
    total_seqs = 0

    chunks_shape = (max_chunks, chunk_size + 1)
    seqs_shape = (max_seqs,)
    doc_ids_shape = (max_chunks,)
    # open the three memmaps (chunks, seqs, doc ids) with context managers
    with memmap(chunks_memmap_path, shape = chunks_shape, dtype = np.int32, mode = 'w+') as chunks_memmap\
        , memmap(seqs_memmap_path, shape = seqs_shape, dtype = np.int32, mode = 'w+') as seqs_memmap\
        , memmap(doc_ids_memmap_path, shape = doc_ids_shape, dtype = np.int32, mode = 'w+') as doc_ids_memmap:

        # iterate over all the text files
        for path in paths:
            # log the file currently being processed
            print(f'processing {path}')

            # convert the document text into chunks and sequence indices
            chunks, seq = doc_text_to_chunks_and_seq_indices(
                doc_text = path.read_text(),
                chunk_size = chunk_size,
                seq_len = seq_len
            )

            # number of chunks and sequences in this document
            doc_chunk_len = chunks.shape[0]
            doc_seq_len = seq.shape[0]

            # write this document's chunks into the chunks memmap
            chunks_memmap[total_chunks:(total_chunks + doc_chunk_len)] = chunks.numpy()
            # write the sequence indices, offset by the running total of chunks
            seqs_memmap[total_seqs:(total_seqs + doc_seq_len)] = seq.numpy() + total_chunks
            # record the document id for every chunk of this document
            doc_ids_memmap[total_chunks:(total_chunks + doc_chunk_len)] = np.full((doc_chunk_len,), total_docs)

            # update the running totals of chunks, seqs and docs
            total_chunks += doc_chunk_len
            total_seqs += doc_seq_len
            total_docs += 1

    # return a dict with the totals of chunks, docs and seqs
    return dict(
        chunks = total_chunks,
        docs = total_docs,
        seqs = total_seqs
    )
# embedding function

@torch.no_grad()
def bert_embed(
    token_ids,
    return_cls_repr = False,
    eps = 1e-8,
    pad_id = 0.
):
    # get the BERT model
    model = get_bert()
    # build a mask marking the non-padding positions
    mask = token_ids != pad_id

    # move to GPU if available
    if torch.cuda.is_available():
        token_ids = token_ids.cuda()
        mask = mask.cuda()

    # forward through BERT
    outputs = model(
        input_ids = token_ids,
        attention_mask = mask,
        output_hidden_states = True
    )

    # take the last hidden state
    hidden_state = outputs.hidden_states[-1]

    # if requested, return the [cls] hidden state as the representation
    if return_cls_repr:
        return hidden_state[:, 0]

    # if no mask exists, average over all tokens
    if not exists(mask):
        return hidden_state.mean(dim = 1)

    # adjust the mask to exclude [cls], accounting for length
    mask = mask[:, 1:]
    mask = rearrange(mask, 'b n -> b n 1')

    # compute the masked mean
    numer = (hidden_state[:, 1:] * mask).sum(dim = 1)
    denom = mask.sum(dim = 1)
    masked_mean = numer / (denom + eps)
    return masked_mean

# chunks to knn

def chunks_to_embeddings_(
    *,
    num_chunks,
    chunks_memmap_path,
    embeddings_memmap_path,
    chunk_size = 64,
    embed_dim = BERT_MODEL_DIM,
    batch_size = 16,
    use_cls_repr = False,
    pad_id = 0.
):
    chunks_shape = (num_chunks, chunk_size + 1)
    embed_shape = (num_chunks, embed_dim)

    # open the chunks and embeddings as memmaps
    with memmap(chunks_memmap_path, shape = chunks_shape, dtype = np.int32) as chunks\
        , memmap(embeddings_memmap_path, shape = embed_shape, dtype = np.float32, mode = 'w+') as embeddings:

        # process the chunks in batches
        for dim_slice in range_chunked(num_chunks, batch_size = batch_size):
            batch_chunk_npy = chunks[dim_slice]

            batch_chunk = torch.from_numpy(batch_chunk_npy)

            cls_tokens = torch.full((batch_chunk.shape[0], 1), SOS_ID)
            batch_chunk = torch.cat((cls_tokens, batch_chunk), dim = 1)

            batch_chunk = batch_chunk[:, :-1] # omit the last token (the next chunk's first token, kept for autoregressive training)

            # embed the batch of chunks
            batch_embed = bert_embed(
                batch_chunk,
                return_cls_repr = use_cls_repr
            )

            # write the embeddings into the memmap
            embeddings[dim_slice] = batch_embed.detach().cpu().numpy()
            print(f'embedded {dim_slice.stop} / {num_chunks}')


def memmap_file_to_chunks_(
    memmap_path,
    *,
    folder,
    shape,
    dtype,
    max_rows_per_file = 500
):
    rows, _ = shape

    # split the memmapped file into .npy shards and save them
    with memmap(memmap_path, shape = shape, dtype = dtype, mode = 'r') as f:
        root_path = TMP_PATH / folder
        reset_folder_(root_path)

        for ind, dim_slice in enumerate(range_chunked(rows, batch_size = max_rows_per_file)):
            filename = root_path / f'{ind:05d}.npy'
            data_slice = f[dim_slice]

            np.save(str(filename), data_slice)
            print(f'saved {str(filename)}')

def index_embeddings(
    embeddings_folder,
    *,
    index_file = 'knn.index',
    index_infos_file = 'index_infos.json',
    max_index_memory_usage = '100m',
    current_memory_available = '1G'
):
    embeddings_path = TMP_PATH / embeddings_folder
    index_path = INDEX_FOLDER_PATH / index_file

    reset_folder_(INDEX_FOLDER_PATH)

    # build the index
    build_index(
        embeddings = str(embeddings_path),
        index_path = str(index_path),
        index_infos_path = str(INDEX_FOLDER_PATH / index_infos_file),
        metric_type = "l2",
        max_index_memory_usage = max_index_memory_usage,
        current_memory_available = current_memory_available,
        make_direct_map = True,
        should_be_memory_mappable = False,
        use_gpu = torch.cuda.is_available(),
    )

    # read the index back
    index = faiss_read_index(index_path)
    return index

def chunks_to_index_and_embed(
    *,
    num_chunks,
    chunk_size,
    chunk_memmap_path,
    use_cls_repr = False,
    max_rows_per_file = 500,
    chunks_to_embeddings_batch_size = 16,
    embed_dim = BERT_MODEL_DIM,
    index_file = 'knn.index',
    **index_kwargs
):
    embedding_path = f'{chunk_memmap_path}.embedded'
    embed_shape = (num_chunks, embed_dim)
    # convert the chunks into embeddings
    chunks_to_embeddings_(
        num_chunks = num_chunks,
        chunk_size = chunk_size,
        chunks_memmap_path = chunk_memmap_path,
        embeddings_memmap_path = embedding_path,
        use_cls_repr = use_cls_repr,
        batch_size = chunks_to_embeddings_batch_size,
        embed_dim = embed_dim
    )

    # shard the embeddings memmap into .npy files for autofaiss
    memmap_file_to_chunks_(
        embedding_path,
        shape = embed_shape,
        dtype = np.float32,
        folder = EMBEDDING_TMP_SUBFOLDER,
        max_rows_per_file = max_rows_per_file
    )

    # build the faiss index over the embeddings
    index = index_embeddings(
        embeddings_folder = EMBEDDING_TMP_SUBFOLDER,
        index_file = index_file,
        **index_kwargs
    )

    # reopen the embeddings memmap read-only
    embeddings = np.memmap(embedding_path, shape = embed_shape, dtype = np.float32, mode = 'r')
    # return the index along with the embeddings
    return index, embeddings
# compute and store the precalculated k nearest neighbors for every chunk
def chunks_to_precalculated_knn_(
    *,
    num_nearest_neighbors,  # number of nearest neighbors to keep
    num_chunks,  # number of chunks
    chunk_size,  # chunk size
    chunk_memmap_path,  # memmap path of the chunks
    doc_ids_memmap_path,  # memmap path of the document ids
    use_cls_repr = False,  # whether to use the [cls] representation
    max_rows_per_file = 500,  # maximum rows per shard file
    chunks_to_embeddings_batch_size = 16,  # batch size for embedding the chunks
    embed_dim = BERT_MODEL_DIM,  # embedding dimension
    num_extra_neighbors = 10,  # number of extra neighbors to fetch
    force_reprocess = False,  # whether to force a reprocess
    index_file = 'knn.index',  # index filename
    **index_kwargs  # extra index arguments
):
    # path of the chunks
    chunk_path = Path(chunk_memmap_path)
    # path of the knn file
    knn_path = chunk_path.parents[0] / f'{chunk_path.stem}.knn{chunk_path.suffix}'
    # path of the index file
    index_path = INDEX_FOLDER_PATH / index_file

    # if the index and knn files already exist and no reprocess is forced, return them directly
    if index_path.exists() and knn_path.exists() and not force_reprocess:
        print(f'preprocessed knn found at {str(knn_path)}, faiss index reconstituted from {str(index_path)}')
        index = faiss_read_index(index_path)
        return knn_path, index

    # get the faiss index along with the chunk embeddings
    index, embeddings = chunks_to_index_and_embed(
        num_chunks = num_chunks,
        chunk_size = chunk_size,
        chunk_memmap_path = chunk_memmap_path,
        index_file = index_file,
        **index_kwargs
    )

    # total number of neighbors to fetch (extras + desired neighbors + self)
    total_neighbors_to_fetch = num_extra_neighbors + num_nearest_neighbors + 1

    # open the knn and document id arrays as memmaps
    with memmap(knn_path, shape = (num_chunks, num_nearest_neighbors), dtype = np.int32, mode = 'w+') as knns\
        , memmap(doc_ids_memmap_path, shape = (num_chunks,), dtype = np.int32, mode = 'r') as doc_ids:

        # process the chunks slice by slice
        for dim_slice in range_chunked(num_chunks, batch_size = max_rows_per_file):
            # the query vectors for this slice
            query_vector = embeddings[dim_slice]

            # search the index for the nearest neighbors
            distances, indices = index.search(query_vector, k = total_neighbors_to_fetch)

            # remove self from the fetched neighbors
            distances = distances[:, 1:]
            indices = indices[:, 1:]

            # mark any neighbors belonging to the same document with -1
            query_doc_ids = doc_ids[dim_slice]
            neighbor_doc_ids = doc_ids[indices]
            neighbor_from_same_doc = query_doc_ids[..., None] == neighbor_doc_ids

            indices = np.where(neighbor_from_same_doc, -1, indices)
            distances = np.where(neighbor_from_same_doc, 1e3, distances)

            # re-sort the indices by the updated distances
            indices = np.take_along_axis(indices, np.argsort(distances, axis = 1), axis = 1)

            # store the nearest neighbors into the knn memmap
            knns[dim_slice] = indices[:, :num_nearest_neighbors]

            print(f'knns calculated for {dim_slice.stop} / {num_chunks}')

    # log where the knn file was saved
    print(f'knn saved to {knn_path}')
    return knn_path, index

.\lucidrains\RETRO-pytorch\retro_pytorch\retro_pytorch.py

# import the necessary libraries
from functools import partial

import torch
import torch.nn.functional as F
from torch import nn, einsum

# import project-local modules
from retro_pytorch.retrieval import BERT_VOCAB_SIZE
from einops import rearrange, repeat

# constants
MIN_DIM_HEAD = 32

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# check whether a number is evenly divisible by a divisor
def divisible_by(val, divisor):
    return (val / divisor).is_integer()

# cast a value to a tuple of the given length
def cast_tuple(val, num = 1):
    return val if isinstance(val, tuple) else ((val,) * num)

# deepnet initialization of the weight matrices
def deepnorm_init(transformer, beta, module_name_match_list = ['.ff.', '.to_v', '.to_out']):
    for name, module in transformer.named_modules():
        if type(module) != nn.Linear:
            continue

        needs_beta_gain = any(map(lambda substr: substr in name, module_name_match_list))
        gain = beta if needs_beta_gain else 1
        nn.init.xavier_normal_(module.weight.data, gain = gain)

        if exists(module.bias):
            nn.init.constant_(module.bias.data, 0)

# normalization

# RMS normalization
class RMSNorm(nn.Module):
    def __init__(
        self,
        dim,
        *,
        eps = 1e-8,
        gated = False
    ):
        super().__init__()
        self.eps = eps
        self.scale = dim ** -0.5
        self.gamma = nn.Parameter(torch.ones(dim))
        self.weight = nn.Parameter(torch.ones(dim)) if gated else None

    def forward(self, x):
        norm = x.norm(keepdim = True, dim = -1) * self.scale
        out = (x / norm.clamp(min = self.eps)) * self.gamma

        if not exists(self.weight):
            return out

        return out * (x * self.weight).sigmoid()

# pre and post norm residual wrapper modules

# pre-normalization residual wrapper
class PreNorm(nn.Module):
    def __init__(self, dim, fn, norm_klass = RMSNorm):
        super().__init__()
        self.fn = fn
        self.norm = norm_klass(dim)

    def forward(self, x, *args, **kwargs):
        return self.fn(self.norm(x), *args, **kwargs) + x

# post-normalization residual wrapper
class PostNorm(nn.Module):
    def __init__(self, dim, fn, scale_residual = 1, norm_klass = RMSNorm):
        super().__init__()
        self.fn = fn
        self.scale_residual = scale_residual
        self.norm = norm_klass(dim)

    def forward(self, x, *args, **kwargs):
        residual = x * self.scale_residual
        out = self.fn(x, *args, **kwargs) + residual
        return self.norm(out)

# positional embedding

# rotary embedding
class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, max_seq_len, *, device, offset = 0):
        seq = torch.arange(max_seq_len, device = device) + offset
        freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq)
        emb = torch.cat((freqs, freqs), dim = -1)
        return rearrange(emb, 'n d -> 1 1 n d')

# rotate half the dimensions
def rotate_half(x):
    x = rearrange(x, '... (j d) -> ... j d', j = 2)
    x1, x2 = x.unbind(dim = -2)
    return torch.cat((-x2, x1), dim = -1)

# apply rotary positional embedding
def apply_rotary_pos_emb(t, freqs):
    seq_len, rot_dim = t.shape[-2], freqs.shape[-1]
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
    t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
    return torch.cat((t, t_pass), dim = -1)
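
# ex. a minimal shape sketch (illustrative, not from the original file): only the first rot_dim dims are rotated
#   rotary = RotaryEmbedding(dim = 32)
#   freqs = rotary(1024, device = torch.device('cpu'))  # (1, 1, 1024, 32)
#   q = torch.randn(2, 8, 1024, 64)                     # (batch, heads, seq, dim_head)
#   q = apply_rotary_pos_emb(q, freqs)                  # first 32 dims rotated, last 32 passed through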

# feedforward

# feedforward network
class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        inner_dim = int(mult * dim)

        self.ff = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim)
        )

    def forward(self, x):
        return self.ff(x)

# attention

class Attention(nn.Module):
    # initialization, sets up the model parameters
    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        causal = False,
        dropout = 0.,
        null_kv = False
    ):
        super().__init__()
        # the context dimension defaults to the input dimension
        context_dim = default(context_dim, dim)

        # set the number of heads and the scale
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.causal = causal
        inner_dim = dim_head * heads

        # dropout layer
        self.dropout = nn.Dropout(dropout)

        # linear projections to queries, keys and values
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias = False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        # allows attending to a null key / value, to protect attention from breaking down
        self.null_k = nn.Parameter(torch.randn(inner_dim)) if null_kv else None
        self.null_v = nn.Parameter(torch.randn(inner_dim)) if null_kv else None

    # forward pass
    def forward(self, x, mask = None, context = None, pos_emb = None):
        # gather the batch size, device, number of heads and scale
        b, device, h, scale = x.shape[0], x.device, self.heads, self.scale

        # the key / value input defaults to x itself (self attention)
        kv_input = default(context, x)

        # project the inputs to queries, keys and values
        q, k, v = self.to_q(x), self.to_k(kv_input), self.to_v(kv_input)

        # split out the attention heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # scale the queries
        q = q * scale

        # apply relative positional encoding (rotary embeddings), if given
        if exists(pos_emb):
            q_pos_emb, k_pos_emb = cast_tuple(pos_emb, num = 2)
            q = apply_rotary_pos_emb(q, q_pos_emb)
            k = apply_rotary_pos_emb(k, k_pos_emb)

        # add the null key / value
        if exists(self.null_k):
            nk, nv = self.null_k, self.null_v
            nk, nv = map(lambda t: repeat(t, '(h d) -> b h 1 d', b = b, h = h), (nk, nv))
            k = torch.cat((nk, k), dim = -2)
            v = torch.cat((nv, v), dim = -2)

        # compute query-key similarities
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # masking
        mask_value = -torch.finfo(sim.dtype).max
        if exists(mask):
            if exists(self.null_k):
                mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, mask_value)

        # causal masking
        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones(i, j, device = device, dtype = torch.bool).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, mask_value)

        # attention
        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        # aggregate the values
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge the heads
        out = rearrange(out, 'b h n d -> b n (h d)')

        # project to the output
        return self.to_out(out)
class ChunkedCrossAttention(nn.Module):
    def __init__(
        self,
        chunk_size,
        **kwargs
    ):
        super().__init__()
        self.chunk_size = chunk_size
        self.cross_attn = Attention(null_kv = True, **kwargs)

    def forward(self, x, *, context_mask = None, context, pos_emb = None):
        # derive variables
        chunk_size = self.chunk_size

        b, n, num_chunks, num_retrieved = x.shape[0], x.shape[-2], *context.shape[-4:-2]

        # if sequence length less than chunk size, do an early return
        if n < self.chunk_size:
            return torch.zeros_like(x)

        # causal padding
        causal_padding = chunk_size - 1

        x = F.pad(x, (0, 0, -causal_padding, causal_padding), value = 0.)

        # remove sequence which is ahead of the neighbors retrieved (during inference)
        seq_index = (n // chunk_size) * chunk_size
        x, x_remainder = x[:, :seq_index], x[:, seq_index:]

        seq_remain_len = x_remainder.shape[-2]

        # take care of rotary positional embedding
        # make sure queries positions are properly shifted to the future
        q_pos_emb, k_pos_emb = pos_emb
        q_pos_emb = F.pad(q_pos_emb, (0, 0, -causal_padding, causal_padding), value = 0.)

        k_pos_emb = repeat(k_pos_emb, 'b h n d -> b h (r n) d', r = num_retrieved)
        pos_emb = (q_pos_emb, k_pos_emb)

        # reshape so we have chunk to chunk attention, without breaking causality
        x = rearrange(x, 'b (k n) d -> (b k) n d', k = num_chunks)
        context = rearrange(context, 'b k r n d -> (b k) (r n) d')

        if exists(context_mask):
            context_mask = rearrange(context_mask, 'b k r n -> (b k) (r n)')

        # cross attention
        out = self.cross_attn(x, context = context, mask = context_mask, pos_emb = pos_emb)

        # reshape back to original sequence
        out = rearrange(out, '(b k) n d -> b (k n) d', b = b)

        # pad back to original, with 0s at the beginning (which will be added to the residual and be fine)
        out = F.pad(out, (0, 0, causal_padding, -causal_padding + seq_remain_len), value = 0.)
        return out
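
# a minimal shape sketch (illustrative, not part of the original file)
if __name__ == '__main__':
    cca = ChunkedCrossAttention(chunk_size = 64, dim = 768, dim_head = 64, heads = 8)

    x = torch.randn(2, 2048, 768)              # decoder hidden states
    context = torch.randn(2, 32, 2, 128, 768)  # (batch, chunks, neighbors, chunk + continuation, dim)

    rotary = RotaryEmbedding(32)
    q_pos_emb = rotary(64, device = x.device, offset = 63)  # query positions shifted, as done in the decoder
    k_pos_emb = rotary(128, device = x.device)

    out = cca(x, context = context, pos_emb = (q_pos_emb, k_pos_emb))
    print(out.shape) # (2, 2048, 768)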

# encoder and decoder classes

class Encoder(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        context_dim = None,
        causal = False,
        heads = 8,
        dim_head = 64,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        final_norm = True,
        cross_attn_layers = None,
        post_norm = False,
        output_dim = None,
        norm_klass = RMSNorm,
        scale_residual = 1.
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        # partial rotary embeddings, which is better than full rotary
        # Wang and Komatsuzaki et al https://github.com/kingoflolz/mesh-transformer-jax/

        rotary_emb_dim = min(dim_head, MIN_DIM_HEAD)
        self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim)

        wrapper = partial(PreNorm, dim, norm_klass = norm_klass) if not post_norm else partial(PostNorm, dim, scale_residual = scale_residual, norm_klass = norm_klass)

        for layer_num in range(1, depth + 1):
            has_cross_attn = not exists(cross_attn_layers) or layer_num in cross_attn_layers

            self.layers.append(nn.ModuleList([
                wrapper(Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, causal = causal)),
                wrapper(Attention(dim = dim, context_dim = context_dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)) if has_cross_attn else None,
                wrapper(FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)),
            ]))

        self.norm_out = norm_klass(dim) if final_norm and not post_norm else nn.Identity()
        self.project_out = nn.Linear(dim, output_dim) if exists(output_dim) else nn.Identity()
    # forward pass, taking the input x plus keyword arguments mask and chunked_seq
    def forward(self, x, *, mask = None, chunked_seq):
        # gather the device, chunk size and the context sequence length
        device, chunk_size, seq_len = x.device, x.shape[-2], chunked_seq.shape[-2]

        # positional embeddings for the queries
        q_pos_emb = self.rotary_pos_emb(chunk_size, device = device)
        # positional embeddings for the keys / values
        k_pos_emb = self.rotary_pos_emb(seq_len, device = device)

        # go through each layer's attention, cross attention and feedforward
        for attn, cross_attn, ff in self.layers:
            # self attention over x, with rotary positions q_pos_emb
            x = attn(x, mask = mask, pos_emb = q_pos_emb)

            # if this layer has cross attention
            if exists(cross_attn):
                # cross attend from x to the chunked decoder sequence
                x = cross_attn(x, context = chunked_seq, pos_emb = (q_pos_emb, k_pos_emb))

            # feedforward
            x = ff(x)

        # final output normalization
        x = self.norm_out(x)
        # final output projection
        return self.project_out(x)
class Decoder(nn.Module):
    # the RETRO decoder
    def __init__(
        self,
        dim,
        *,
        depth,
        heads = 8,
        dim_head = 64,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        final_norm = True,
        cross_attn_layers = None,
        chunk_size = 64,
        post_norm = False,
        norm_klass = RMSNorm,
        scale_residual = 1.
    ):
        # initialization, sets the decoder parameters
        super().__init__()
        self.layers = nn.ModuleList([])

        # partial rotary embeddings, which is better than full rotary
        # Wang and Komatsuzaki et al https://github.com/kingoflolz/mesh-transformer-jax/
        rotary_emb_dim = min(dim_head, MIN_DIM_HEAD)
        self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim)

        wrapper = partial(PreNorm, dim, norm_klass = norm_klass) if not post_norm else partial(PostNorm, dim, scale_residual = scale_residual, norm_klass = norm_klass)

        self.chunk_size = chunk_size

        for layer_num in range(1, depth + 1):
            has_cross_attn = not exists(cross_attn_layers) or layer_num in cross_attn_layers

            self.layers.append(nn.ModuleList([
                wrapper(Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, causal = True)),
                wrapper(ChunkedCrossAttention(chunk_size = chunk_size, dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)) if has_cross_attn else None,
                wrapper(FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)),
            ]))

        self.norm_out = norm_klass(dim) if final_norm and not post_norm else nn.Identity()

    def forward(self, x, *, encoder = None, encoder_retrieved_mask = None, context_mask = None, retrieved = None):
        # forward pass, taking the input x, the encoder, and the retrieved chunk tokens
        device, seq_len = x.device, x.shape[-2]
        self_attn_pos_emb = self.rotary_pos_emb(seq_len, device = device)

        # calculate the sequence index covered by whole chunks
        num_seq_chunks = seq_len // self.chunk_size
        seq_index = num_seq_chunks * self.chunk_size

        # rotary positions on the retrieved chunks
        if exists(retrieved):
            num_chunks, num_neighbors, chunk_size = retrieved.shape[-4:-1]

            cross_attn_q_pos_emb = self.rotary_pos_emb(self.chunk_size, device = device, offset = self.chunk_size - 1)  # need to add an extra chunk size, since it will be shifted
            cross_attn_k_pos_emb = self.rotary_pos_emb(chunk_size, device = device)

            cross_attn_pos_emb = (cross_attn_q_pos_emb, cross_attn_k_pos_emb)

        # keep track of whether the retrieved tokens have been encoded
        retrieved_encoded = False

        # go through the decoder layers
        for attn, cross_attn, ff in self.layers:
            x = attn(x, pos_emb = self_attn_pos_emb)

            if exists(cross_attn) and exists(retrieved):
                if not retrieved_encoded:
                    retrieved = rearrange(retrieved, 'b k r n d -> (b k r) n d')
                    seq_as_context = repeat(x[:, :seq_index], 'b (k n) d -> (b k r) n d', n = self.chunk_size, r = num_neighbors)

                    retrieved = encoder(retrieved, mask = encoder_retrieved_mask, chunked_seq = seq_as_context)
                    retrieved = rearrange(retrieved, '(b k r) n d -> b k r n d', k = num_chunks, r = num_neighbors)
                    retrieved_encoded = True

                x = cross_attn(
                    x,
                    context = retrieved,
                    context_mask = context_mask,
                    pos_emb = cross_attn_pos_emb
                )

            x = ff(x)

        return self.norm_out(x)

# main class
class RETRO(nn.Module):
    # initialize the model hyperparameters
    def __init__(
        self,
        *,
        num_tokens = BERT_VOCAB_SIZE,  # vocabulary size, defaults to the BERT vocab size
        max_seq_len = 2048,  # maximum sequence length
        enc_dim = 896,  # encoder model dimension
        enc_depth = 2,  # encoder depth
        enc_cross_attn_layers = None,  # encoder cross attention layers
        dec_depth = 12,  # decoder depth
        dec_cross_attn_layers = (1, 3, 6, 9),  # decoder cross attention layers
        heads = 8,  # number of attention heads
        dec_dim = 768,  # decoder model dimension
        dim_head = 64,  # dimension per head
        enc_attn_dropout = 0.,  # encoder attention dropout
        enc_ff_dropout = 0.,  # encoder feedforward dropout
        dec_attn_dropout = 0.,  # decoder attention dropout
        dec_ff_dropout = 0.,  # decoder feedforward dropout
        chunk_size = 64,  # chunk size
        pad_id = 0,  # padding token id
        enc_scale_residual = None,  # encoder residual scale
        dec_scale_residual = None,  # decoder residual scale
        norm_klass = None,  # normalization class
        gated_rmsnorm = False,  # whether to use gated RMSNorm
        use_deepnet = False  # whether to use deepnet post-normalization
    ):
        super().__init__()
        assert dim_head >= MIN_DIM_HEAD, f'dimension per head must be greater than {MIN_DIM_HEAD}'  # the per-head dimension must meet the minimum
        self.seq_len = max_seq_len  # store the maximum sequence length
        self.pad_id = pad_id  # store the padding id

        self.token_emb = nn.Embedding(num_tokens, enc_dim)  # token embedding
        self.pos_emb = nn.Embedding(max_seq_len, enc_dim)  # absolute positional embedding

        self.chunk_size = chunk_size  # store the chunk size

        self.to_decoder_model_dim = nn.Linear(enc_dim, dec_dim) if enc_dim != dec_dim else nn.Identity()  # project the encoder dim to the decoder dim when they differ

        # for deepnet, residual scales
        # follow equation in Figure 2. in https://arxiv.org/abs/2203.00555

        norm_klass = default(norm_klass, RMSNorm)  # normalization class defaults to RMSNorm

        if use_deepnet:
            enc_scale_residual = default(enc_scale_residual, 0.81 * ((enc_depth ** 4) * dec_depth) ** .0625)  # encoder residual scale, per the deepnet paper
            dec_scale_residual = default(dec_scale_residual, (3 * dec_depth) ** 0.25)  # decoder residual scale, per the deepnet paper
            norm_klass = nn.LayerNorm  # deepnet uses post-norm with LayerNorm
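            # e.g. with the defaults enc_depth = 2, dec_depth = 12 (worked numbers, not from the source):
            #   enc_scale_residual = 0.81 * (2 ** 4 * 12) ** 0.0625 ≈ 1.125
            #   dec_scale_residual = (3 * 12) ** 0.25 ≈ 2.449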

        # allow for gated rmsnorm

        if gated_rmsnorm:
            norm_klass = partial(RMSNorm, gated = True)  # swap in RMSNorm with gating

        # define encoder and decoders

        self.encoder = Encoder(
            dim = enc_dim,
            context_dim = dec_dim,
            dim_head = dim_head,
            depth = enc_depth,
            attn_dropout = enc_attn_dropout,
            ff_dropout = enc_ff_dropout,
            cross_attn_layers = enc_cross_attn_layers,
            post_norm = use_deepnet,
            norm_klass = norm_klass,
            scale_residual = enc_scale_residual,
            output_dim = dec_dim
        )  # define the encoder

        self.decoder = Decoder(
            dim = dec_dim,
            depth = dec_depth,
            dim_head = dim_head,
            attn_dropout = dec_attn_dropout,
            ff_dropout = dec_ff_dropout,
            cross_attn_layers = dec_cross_attn_layers,
            chunk_size = chunk_size,
            post_norm = use_deepnet,
            norm_klass = norm_klass,
            scale_residual = dec_scale_residual
        )  # define the decoder

        self.to_logits = nn.Linear(dec_dim, num_tokens)  # project the decoder output to vocabulary logits

        # deepnet has special init of weight matrices

        if use_deepnet:
            deepnorm_init(self.encoder, 0.87 * ((enc_depth ** 4) * dec_depth) ** -0.0625)  # deepnet init for the encoder
            deepnorm_init(self.decoder, (12 * dec_depth) ** -0.25)  # deepnet init for the decoder

    def forward_without_retrieval(
        self,
        seq
    ):
        # embed sequence

        embed = self.token_emb(seq)  # embed the sequence tokens
        embed = embed[:, :self.seq_len]  # truncate to the maximum sequence length

        # get absolute positional embedding

        pos_emb = self.pos_emb(torch.arange(embed.shape[1], device = embed.device))  # compute absolute positional embeddings
        pos_emb = rearrange(pos_emb, 'n d -> 1 n d')  # add a batch dimension
        embed = embed + pos_emb  # add positions to the token embeddings

        embed = self.to_decoder_model_dim(embed)  # project to the decoder model dimension
        embed = self.decoder(embed)  # run the decoder without retrieval

        # project to logits

        return self.to_logits(embed)  # map the decoder output to vocabulary logits

    def forward(
        self,
        seq,
        retrieved = None,
        return_loss = False
    ):
        """
        b - batch
        n - sequence length / chunk length
        k - number of chunks
        d - feature dimension
        r - num retrieved neighbors
        """

        # if no retrieved chunks are passed in, fall back to the decoder-only forward pass
        if not exists(retrieved):
            return self.forward_without_retrieval(seq)

        # the loss can only be returned while training
        assert not (return_loss and not self.training), 'must be training if returning loss'

        # assume the padding token id (usually 0) needs to be masked out
        mask = retrieved != self.pad_id

        # handle some user input
        if retrieved.ndim == 3:
            # add a singleton neighbor dimension
            retrieved = rearrange(retrieved, 'b k n -> b k 1 n') # 1 neighbor retrieved

        # derive the labels if the loss is to be returned
        if return_loss:
            seq, labels = seq[:, :-1], seq[:, 1:]

        # gather variables
        n, num_chunks, num_neighbors, chunk_size, retrieved_shape, device = seq.shape[-1], *retrieved.shape[-3:], retrieved.shape, seq.device

        # the chunk size of the retrieval input must be at least the chunk size designated at RETRO initialization
        assert chunk_size >= self.chunk_size, 'chunk size of retrieval input must be greater or equal to the designated chunk_size on RETRO initialization'

        # derive how many chunks the sequence requires, and check the number passed in
        num_seq_chunks = n // self.chunk_size
        assert num_chunks == num_seq_chunks, f'sequence requires {num_seq_chunks} retrieved chunks, but only {num_chunks} passed in'

        # sequence index past which the k nearest neighbors have not yet been fetched
        seq_index = num_seq_chunks * self.chunk_size

        # embed both the sequence and the retrieved chunks
        embed = self.token_emb(seq)
        retrieved = self.token_emb(retrieved)

        # get absolute positional embedding
        pos_emb = self.pos_emb(torch.arange(n, device=device))
        pos_emb = rearrange(pos_emb, 'n d -> 1 n d')
        embed = embed + pos_emb

        # handle the masks for the encoder and decoder, if needed
        encoder_retrieved_mask = decoder_retrieved_mask = None
        if exists(mask):
            assert mask.shape == retrieved_shape, 'retrieval mask must be of the same shape as the retrieval tokens'
            encoder_retrieved_mask = rearrange(mask, 'b k r n -> (b k r) n')
            decoder_retrieved_mask = mask

        # project the sequence embedding to the decoder dimension, if needed
        embed = self.to_decoder_model_dim(embed)

        # decode
        embed = self.decoder(
            embed,
            encoder=self.encoder,
            context_mask=decoder_retrieved_mask,
            encoder_retrieved_mask=encoder_retrieved_mask,
            retrieved=retrieved
        )

        # project to logits
        logits = self.to_logits(embed)

        # return the logits if the loss is not needed
        if not return_loss:
            return logits

        # compute the cross entropy loss
        loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels, ignore_index=self.pad_id)
        return loss
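
A minimal end-to-end sketch of the forward pass above. The constructor arguments are the parameters referenced in __init__; the concrete values, and the assumption that the remaining arguments have defaults, are illustrative rather than taken from the source:

import torch
from retro_pytorch import RETRO

retro = RETRO(
    num_tokens = 20000,  # vocabulary size
    max_seq_len = 512,   # maximum sequence length
    enc_dim = 512,
    enc_depth = 2,
    dec_dim = 512,
    dec_depth = 6,
    chunk_size = 64
)

# one extra token, since the sequence is split into inputs and labels when return_loss = True
seq = torch.randint(0, 20000, (2, 512 + 1))

# (batch, num chunks = 512 // 64, num neighbors, chunk size with continuation)
retrieved = torch.randint(0, 20000, (2, 8, 2, 128))

loss = retro(seq, retrieved = retrieved, return_loss = True)
loss.backward()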

.\lucidrains\RETRO-pytorch\retro_pytorch\training.py

import numpy as np
from functools import partial
import json
from pathlib import Path

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from retro_pytorch import RETRO, RETRODataset
from retro_pytorch.data import knn_to_retrieved_chunks
from retro_pytorch.optimizer import get_optimizer
from retro_pytorch.retrieval import text_folder_to_chunks_, chunks_to_precalculated_knn_, bert_embed, SOS_ID, EOS_ID
from retro_pytorch.utils import memmap, is_true_env_flag

from einops import rearrange

# helpers

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# decorator that runs the wrapped function in eval mode, then restores the previous training state
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# concatenate tensors, treating a None accumulator as empty
def safe_cat(accum, t, dim = -1):
    if not exists(accum):
        return t
    return torch.cat((accum, t), dim = dim)
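
# e.g. in the sampling loop of TrainingWrapper.generate below, `retrieved` starts out as None,
# and safe_cat lets each newly fetched batch of neighbor chunks be appended without a special case:
#   retrieved = safe_cat(None, chunks)              # -> chunks
#   retrieved = safe_cat(retrieved, more, dim = 1)  # -> concatenated along the chunk dimension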

# sampling helpers

# numerically stable log
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# generate gumbel noise
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# gumbel sampling - argmax over logits perturbed by gumbel noise, scaled by temperature
def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)

# top-k filtering
def top_k(logits, thres = 0.9):
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# top-p (nucleus) filtering
def top_p(logits, thres = 0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cum_probs > (1 - thres)
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = 0

    sorted_logits[sorted_indices_to_remove] = float('-inf')
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)
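
# a quick illustration of a filter feeding gumbel_sample (shapes and values are made up):
#   logits   = torch.randn(1, 20000)                      # (batch, vocab)
#   filtered = top_k(logits, thres = 0.9)                 # keep only the largest 10% of logits
#   token    = gumbel_sample(filtered, temperature = 1.)  # -> shape (1,), one sampled token id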

# function for fetching the knn chunks for a batch of sequence chunks
def knn_chunks_from_seq_chunks(
    seq_chunks,
    *,
    knn,
    faiss_index,
    num_chunks,
    chunk_size,
    chunks_memmap_path,
):
    b, device = seq_chunks.shape[0], seq_chunks.device

    # prepare the last chunk with sos and eos tokens for the BERT embedding

    ones = torch.ones((b, 1), dtype = torch.bool, device = device)
    sos = ones * SOS_ID
    eos = ones * EOS_ID

    seq_chunks = torch.cat((sos, seq_chunks, eos), dim = 1)

    # embed with frozen BERT

    embeds = bert_embed(seq_chunks.cpu()) # fetch the embeddings on CPU for now

    # retrieve the knn indices with faiss

    _, knn_indices = faiss_index.search(embeds.cpu().numpy(), k = knn)

    # numpy to torch

    with memmap(chunks_memmap_path, dtype = np.int32, shape = (num_chunks + 1, chunk_size + 1)) as chunk_memmap:
        knn_chunks = knn_to_retrieved_chunks(
            knn_indices,
            chunk_memmap,
            add_continuations = True,
            num_chunks = num_chunks
        )

        knn_chunks_torch = torch.from_numpy(knn_chunks).to(device)

    return knn_chunks_torch
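
# shape sketch, assuming the convention that each retrieved neighbor carries its continuation
# and is therefore twice the chunk size (an assumption, not checked here):
#   seq_chunks : (b, chunk_size)            token ids of the chunks to look up
#   knn_chunks : (b, knn, 2 * chunk_size)   retrieved neighbor chunks plus continuations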

# training wrapper class
class TrainingWrapper(nn.Module):
    def __init__(
        self,
        *,
        retro,
        chunk_size,
        documents_path,
        knn,
        glob = '**/*.txt',
        chunks_memmap_path = './train.chunks.dat',
        seqs_memmap_path = './train.seq.dat',
        doc_ids_memmap_path = './train.doc_ids.dat',
        max_chunks = 1_000_000,
        max_seqs = 100_000,
        knn_extra_neighbors = 100,
        processed_stats_json_path = './processed-stats.json',
        faiss_index_filename = 'knn.index',
        **index_kwargs
    ):
        super().__init__()
        assert isinstance(retro, RETRO), 'retro must be instance of RETRO'
        self.retro = retro

        # check whether the data should be forcibly reprocessed
        force_reprocess = is_true_env_flag('REPROCESS')

        # stats of the processed training data, such as the number of chunks and sequences
        stats_path = Path(processed_stats_json_path)

        # process the text folder if the stats file does not exist, or if reprocessing is forced
        if not stats_path.exists() or force_reprocess:
            self.stats = text_folder_to_chunks_(
                folder = documents_path,
                glob = glob,
                chunks_memmap_path = chunks_memmap_path,
                seqs_memmap_path = seqs_memmap_path,
                doc_ids_memmap_path = doc_ids_memmap_path,
                chunk_size = chunk_size,
                seq_len = retro.seq_len,
                max_chunks = max_chunks,
                max_seqs = max_seqs
            )
            # write the stats to the json file
            with open(processed_stats_json_path, 'w') as f:
                json.dump(self.stats, f)
        else:
            # otherwise load the previously computed stats
            print(f'found to be previously processed at {str(stats_path)}')
            self.stats = json.loads(stats_path.read_text())

        # number of chunks and sequences
        num_chunks = self.stats['chunks']
        num_seqs = self.stats['seqs']

        # precompute the knn memmap and build the faiss index
        knn_memmap_path, faiss_index = chunks_to_precalculated_knn_(
            num_chunks = num_chunks,
            chunk_size = chunk_size,
            chunk_memmap_path = chunks_memmap_path,
            doc_ids_memmap_path = doc_ids_memmap_path,
            num_nearest_neighbors = knn,
            num_extra_neighbors = knn_extra_neighbors,
            index_file = faiss_index_filename,
            force_reprocess = force_reprocess,
            **index_kwargs
        )

        # instantiate the dataset
        self.ds = RETRODataset(
            num_sequences = num_seqs,
            num_chunks = num_chunks,
            num_neighbors = knn,
            chunk_size = chunk_size,
            seq_len = retro.seq_len,
            chunk_memmap_path = chunks_memmap_path,
            chunk_nn_memmap_path = knn_memmap_path,
            seq_memmap_path = seqs_memmap_path
        )

        # parameters needed for generation
        self.chunk_size = chunk_size
        self.max_seq_len = self.retro.seq_len

        # partial function for fetching the knn chunks from sequence chunks
        self.fetch_knn_chunks_fn = partial(
            knn_chunks_from_seq_chunks,
            knn = knn,
            chunk_size = chunk_size,
            num_chunks = num_chunks,
            chunks_memmap_path = chunks_memmap_path,
            faiss_index = faiss_index
        )

    # text generation, run without gradients and in eval mode
    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        start = None,
        retrieved = None,
        filter_fn = top_k,
        filter_thres = 0.9,
        temperature = 1.0,
    ):
        # the filter function must be either top-k or nucleus (top-p)
        assert filter_fn in {top_k, top_p}, 'filter function must be either top-k or nucleus'

        # infer the device from the model parameters
        device = next(self.retro.parameters()).device

        # if no start tokens are given, assume a batch of one, starting from the SOS token
        if not exists(start):
            start = torch.full((1, 1), SOS_ID, device=device).long()

        b, start_seq_len = start.shape

        # move the start tokens to the same device as RETRO
        start = start.to(device)

        # prepare the retrieval related variables
        if start_seq_len >= self.chunk_size:
            seq_index = (start_seq_len // self.chunk_size) * self.chunk_size
            past_seq_chunks = rearrange(start[:, :seq_index], 'b (n c) -> (b n) c', c=self.chunk_size)

            # fetch the knn chunks for the past sequence chunks
            retrieved = self.fetch_knn_chunks_fn(past_seq_chunks)
            retrieved = rearrange(retrieved, '(b n) k c -> b n k c', b=b)

        # the output starts as the given start tokens
        out = start

        # sampling loop
        for i in range(start_seq_len - 1, self.max_seq_len):

            logits = self.retro(out, retrieved=retrieved)
            logits = logits[:, i]

            logits = filter_fn(logits, thres=filter_thres)
            sampled = gumbel_sample(logits, temperature=temperature, dim=-1)
            sampled = rearrange(sampled, 'b -> b 1')

            out = torch.cat((out, sampled), dim=1)

            # terminate early if every sequence contains an EOS token
            is_eos_tokens = (out == EOS_ID)

            if is_eos_tokens.any(dim=-1).all():

                # mask out everything after the EOS tokens
                shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
                mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1
                out = out.masked_fill(mask, self.retro.pad_id)
                break

            # when the sequence length hits a multiple of the chunk size, retrieve the next set of knn chunks
            curr_seq_len = out.shape[-1]

            if (curr_seq_len % self.chunk_size) == 0:
                last_chunk = rearrange(out, 'b (c n) -> b c n', n=self.chunk_size)[:, -1]

                knn_chunks = self.fetch_knn_chunks_fn(last_chunk)

                # concat the newly retrieved knn chunks to everything retrieved so far,
                # to be sent to RETRO for chunked cross attention on the next iteration
                knn_chunks = rearrange(knn_chunks, 'b k r -> b 1 k r')
                retrieved = safe_cat(retrieved, knn_chunks, dim=1)

                print(f'retrieved at {curr_seq_len} / {self.max_seq_len}')

        return out

    # dataloader over the dataset
    def get_dataloader(self, **kwargs):
        return DataLoader(self.ds, **kwargs)

    # optimizer over RETRO's parameters
    def get_optimizer(self, **kwargs):
        return get_optimizer(self.retro.parameters(), **kwargs)

    # the wrapper itself is not meant to be called directly
    def forward(self):
        raise NotImplementedError
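
Putting the wrapper to use looks roughly like the following sketch. The TrainingWrapper arguments are taken from its constructor above; the folder path, hyperparameter values, the assumption that the remaining RETRO arguments have defaults, the keyword arguments forwarded to get_dataloader / get_optimizer, and the assumption that the dataset yields (sequence, retrieved) pairs are all illustrative rather than confirmed by the source:

import torch
from retro_pytorch import RETRO
from retro_pytorch.training import TrainingWrapper

retro = RETRO(
    max_seq_len = 512,
    chunk_size = 64
)

wrapper = TrainingWrapper(
    retro = retro,
    chunk_size = 64,
    documents_path = './text_folder',  # folder of .txt files to chunk, embed and index
    knn = 2,                           # number of nearest neighbors to retrieve per chunk
    glob = '**/*.txt'
)

train_dl = iter(wrapper.get_dataloader(batch_size = 2, shuffle = True))  # kwargs forwarded to DataLoader
optim = wrapper.get_optimizer(lr = 3e-4)                                 # kwargs forwarded to get_optimizer

# one training step
seq, retrieved = next(train_dl)
loss = retro(seq, retrieved, return_loss = True)
loss.backward()
optim.step()
optim.zero_grad()

# sampling fetches a new set of neighbors every chunk_size tokens
sampled = wrapper.generate(filter_thres = 0.9, temperature = 1.0)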