
Lucidrains Series Source Code Analysis (86)

.\lucidrains\routing-transformer\routing_transformer\autopadder.py

# import the math library and PyTorch
import math
import torch
# import the nn module from torch
from torch import nn
# import the RoutingTransformer class from the routing_transformer module
from routing_transformer.routing_transformer import RoutingTransformer
# import torch.nn.functional under the alias F
import torch.nn.functional as F

# helper that returns the first submodule of the given type, or None
def find_module(nn_module, type):
    # iterate over all submodules of nn_module
    for module in nn_module.modules():
        # return the module if it is an instance of the requested type
        if isinstance(module, type):
            return module
    # return None if no module of the requested type was found
    return None

# pad a tensor along the given dimension so its length becomes a multiple of `multiple`
def pad_to_multiple(tensor, multiple, dim=-1, value=0):
    # length of the tensor along the target dimension
    seqlen = tensor.shape[dim]
    # ratio of the current length to the target multiple
    m = seqlen / multiple
    # if the length is already an exact multiple, return the tensor unchanged
    if m.is_integer():
        return tensor

    # leading (0, 0) pairs so that F.pad targets the right dimension, plus the pad amount
    pre_pad_offset = (0,) * (-1 - dim) * 2
    padding = math.ceil(m) * multiple - seqlen
    # pad the tensor on the right with the given fill value
    padded_tensor = F.pad(tensor, (*pre_pad_offset, *(0, padding)), value=value)
    return padded_tensor
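
# illustrative aside (not part of the original file): padding a hypothetical (1, 1000) batch of
# token ids to a multiple of 256 along dim=1 appends 24 values on the right:
#   pad_to_multiple(torch.zeros(1, 1000, dtype=torch.long), 256, dim=1).shape  # -> torch.Size([1, 1024])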

# auto-padding wrapper module, inherits from nn.Module
class Autopadder(nn.Module):
    def __init__(self, net):
        super().__init__()
        # locate the RoutingTransformer module inside the wrapped network
        transformer = find_module(net, RoutingTransformer)
        self.net = net
        # read the pad_to_multiple attribute off the RoutingTransformer
        self.pad_multiple = transformer.pad_to_multiple

    def forward(self, x, **kwargs):
        # if pad_multiple is not positive, no padding is needed, so just call the wrapped network
        if self.pad_multiple <= 0:
            return self.net(x, **kwargs)

        # unpack the batch size, sequence length and device of the input
        b, t, device = *x.shape, x.device

        # fetch input_mask from kwargs, defaulting to an all-True mask
        input_mask = kwargs.get('input_mask')
        if input_mask is None:
            input_mask = torch.full((b, t), True, device=device, dtype=torch.bool)

        # pad both the input and the mask up to the required multiple
        x = pad_to_multiple(x, self.pad_multiple, dim=1)
        new_mask = pad_to_multiple(input_mask, self.pad_multiple, dim=1, value=False)
        kwargs.update(input_mask=new_mask)

        # run the wrapped network and trim the output back to the original length
        out, loss = self.net(x, **kwargs)
        return out[:, 0:t], loss
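
To make the analysis concrete, here is a minimal usage sketch of the wrapper above. The hyperparameters and tensor sizes are illustrative assumptions, not taken from the repository:

import torch
from routing_transformer import RoutingTransformerLM, Autopadder

model = Autopadder(RoutingTransformerLM(
    num_tokens = 256,
    dim = 512,
    depth = 2,
    max_seq_len = 1024,
    window_size = 128,
    causal = True
))

x = torch.randint(0, 256, (1, 1000))   # 1000 is not a multiple of the model's padding requirement
logits, aux_loss = model(x)            # padded internally, then trimmed back to the original length
print(logits.shape)                    # expected: torch.Size([1, 1000, 256])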

.\lucidrains\routing-transformer\routing_transformer\autoregressive_wrapper.py

# required imports
from functools import partial
import torch
import random
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from routing_transformer.routing_transformer import RoutingTransformerLM
from routing_transformer.autopadder import Autopadder

# return the given value if it is not None, otherwise the default
def default(value, default):
    return value if value is not None else default

# nucleus (top-p) filtering: keep only the logits within the cumulative probability threshold
def top_p(logits, thres = 0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cum_probs > 1.0 - thres
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = 0

    sorted_logits[sorted_indices_to_remove] = float('-inf')
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)

# top-k filtering: keep only the k highest logits and mask the rest to -inf
def top_k(logits, thres = 0.9):
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
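
# illustrative aside (not part of the original file): with the default thres of 0.9 and a
# hypothetical vocabulary of 10 logits, k = int((1 - 0.9) * 10) = 1, so only the single
# highest logit survives and every other position is set to -inf before sampling:
#   filtered = top_k(torch.randn(1, 10))   # (1, 10), nine entries are -inf
#   probs = F.softmax(filtered, dim=-1)    # all probability mass on the argmax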

# right-pad a list of sequences to the same length (the value argument is unused in this version)
def pad_sequence_right(seqs, value):
    m = max([len(s) for s in seqs])
    return torch.stack([F.pad(s, (0, m - len(s))) for s in seqs])

# randomly truncate a batch of sequences (and its mask) to a random length
def truncate_sequence(inputs, mask = None, pad_value=0):
    b, t, device, dtype = *inputs.shape, inputs.device, inputs.dtype
    mask = default(mask, torch.ones_like(inputs).bool())
    rand_length = random.randint(2, t)
    return inputs[:, :rand_length], mask[:, :rand_length]

# autoregressive wrapper class
class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, ignore_index = None, pad_value = 0):
        super().__init__()
        assert isinstance(net, RoutingTransformerLM), 'generative trainer wrapper can only accept RoutingTransformerLM class'
        self.pad_value = pad_value
        self.ignore_index = default(ignore_index, pad_value)

        self.net = Autopadder(net)
        self.max_seq_len = net.max_seq_len
        self.base_net = net

    # update the kmeans centroids
    def update_kmeans(self):
        self.base_net.update_kmeans()

    # autoregressive sequence generation
    @torch.no_grad()
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        b, t = start_tokens.shape

        self.net.eval()
        out = start_tokens
        input_mask = kwargs.pop('input_mask', None)

        if input_mask is None:
            input_mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)

        for _ in range(seq_len):
            x = out[:, -self.max_seq_len:]
            input_mask = input_mask[:, -self.max_seq_len:]
            logits, _ = self.net(x, input_mask=input_mask, **kwargs)
            logits = logits[:, -1, :]
            filtered_logits = filter_logits_fn(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)
            input_mask = F.pad(input_mask, (1, 0), value=True)
            if eos_token is not None and (sample == eos_token).all():
                break

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        self.net.train(was_training)
        return out
    # forward pass; accepts the input x, whether to return the loss, and whether to randomly truncate the sequence
    def forward(self, x, return_loss = False, randomly_truncate_sequence = False, **kwargs):
        # padding helper that pads sequences in a batch to the same length
        pad = partial(pad_sequence, batch_first = True, padding_value = self.pad_value)

        # if no loss is requested
        if not return_loss:
            # pad the input if it is not already a tensor
            if not isinstance(x, torch.Tensor):
                x = pad(x)
            # return the network output directly
            return self.net(x, **kwargs)

        # fetch the input mask
        m = kwargs.get('input_mask', None)

        # optionally truncate the sequence at a random length
        if randomly_truncate_sequence:
            x, m = truncate_sequence(x, m, pad_value = self.pad_value)

        # if the input is a tensor
        if isinstance(x, torch.Tensor):
            # split the sequence into inputs and (shifted) targets
            xi, xo = x[:, :-1], x[:, 1:]
        else:
            # pad and shift the list of sequences into inputs and targets
            xi = pad(list(map(lambda t: t[:-1], x)))
            xo = pad(list(map(lambda t: t[1:], x)))

        # if an input mask was provided
        if m is not None:
            # the mask must have the same (batch, seq) shape as the input
            assert m.shape == x.shape[0:2], 'input mask must be the same shape as the input of the auto-regressive wrapper to automatically handle'
            # trim the mask to match the shifted input
            kwargs['input_mask'] = m[:, :-1]

        # run the network to get the logits and the auxiliary loss
        out, aux_loss = self.net(xi, **kwargs)

        # cross entropy loss against the shifted targets
        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
        # add the auxiliary loss to the main loss
        loss = loss + aux_loss
        # return the loss
        return loss
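
A minimal training and sampling sketch for this wrapper (hyperparameters are illustrative; the pattern follows the repository README):

import torch
from routing_transformer import RoutingTransformerLM, AutoregressiveWrapper

model = AutoregressiveWrapper(RoutingTransformerLM(
    num_tokens = 20000,
    dim = 512,
    depth = 2,
    max_seq_len = 4096,
    window_size = 128,
    causal = True
))

x = torch.randint(0, 20000, (1, 4096))
loss = model(x, return_loss = True)    # shifts inputs/targets and folds in the auxiliary kmeans loss
loss.backward()

prime = torch.randint(0, 20000, (1, 16))
sampled = model.generate(prime, 64)    # returns only the 64 newly sampled tokens: shape (1, 64)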

.\lucidrains\routing-transformer\routing_transformer\encoder_decoder.py

# import re for regular expression operations
import re
# import isfunction to check whether an object is a function
from inspect import isfunction
# import torch and the nn module
import torch
from torch import nn
# import RoutingTransformerLM and update_kmeans_on_backwards from routing_transformer.routing_transformer
from routing_transformer.routing_transformer import RoutingTransformerLM, update_kmeans_on_backwards
# import AutoregressiveWrapper from routing_transformer.autoregressive_wrapper
from routing_transformer.autoregressive_wrapper import AutoregressiveWrapper

# prefix used for encoder keyword arguments
ENC_PREFIX = 'enc_'
# prefix used for decoder keyword arguments
DEC_PREFIX = 'dec_'

# default helper: if x is None return d, calling d first if it is a function
def default(x, d):
    if x is None:
        return d if not isfunction(d) else d()
    return x

# split dictionary d into two dictionaries according to the predicate cond on its keys
def group_dict_by_key(cond, d):
    return_val = [dict(),dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

# check whether a string begins with the given prefix
def string_begins_with(prefix, str):
    return bool(re.match(f'^{prefix}', str))

# group dictionary d by whether keys start with the prefix
def group_by_key_prefix(prefix, d):
    return group_dict_by_key(lambda x: string_begins_with(prefix, x), d)

# group dictionary d by prefix and strip the prefix from the matching keys
def group_by_key_prefix_and_remove_prefix(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: string_begins_with(prefix, x), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs

# pull out the encoder- and decoder-specific keyword arguments
def extract_enc_dec_kwargs(kwargs):
    enc_kwargs, kwargs = group_by_key_prefix_and_remove_prefix(ENC_PREFIX, kwargs)
    dec_kwargs, kwargs = group_by_key_prefix_and_remove_prefix(DEC_PREFIX, kwargs)
    return enc_kwargs, dec_kwargs, kwargs

# extract encoder/decoder kwargs and default the decoder context_mask to the encoder input_mask
def extract_and_set_enc_dec_kwargs(kwargs):
    enc_kwargs, dec_kwargs, kwargs = extract_enc_dec_kwargs(kwargs)
    if 'input_mask' in enc_kwargs:
        dec_kwargs.setdefault('context_mask', enc_kwargs['input_mask'])
    return enc_kwargs, dec_kwargs, kwargs
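
# illustrative aside (not part of the original file): hypothetical kwargs passed to the
# encoder-decoder are split by prefix and the prefix is stripped, e.g.
#   extract_and_set_enc_dec_kwargs({'enc_input_mask': mask, 'dec_context_mask': ctx, 'return_loss': True})
#   # -> ({'input_mask': mask}, {'context_mask': ctx}, {'return_loss': True})
# and when the decoder has no context_mask of its own, the encoder's input_mask is reused for it.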

# RoutingTransformerEncDec class, inherits from nn.Module
class RoutingTransformerEncDec(nn.Module):
    # constructor
    def __init__(self, dim, ignore_index = None, pad_value = 0, **kwargs):
        super().__init__()
        ignore_index = default(ignore_index, pad_value)
        enc_kwargs, dec_kwargs, _ = extract_enc_dec_kwargs(kwargs)

        # the encoder's return-embeddings flag is managed internally and must not be set manually
        assert 'return_embedding' not in enc_kwargs, 'you cannot manually set the return embeddings flag for the encoder'
        # dim is set once at the top level, not separately on the encoder or decoder
        assert 'dim' not in dec_kwargs and 'dim' not in enc_kwargs, 'you must set the dim for both encoder and decoder'

        # set the dimension for encoder and decoder
        enc_kwargs['dim'] = dec_kwargs['dim'] = dim
        enc_kwargs['return_embeddings'] = True
        dec_kwargs['causal'] = True
        dec_kwargs['receives_context'] = True
        enc_kwargs['_register_kmeans_update'] = dec_kwargs['_register_kmeans_update'] = False

        # default window size
        enc_kwargs.setdefault('window_size', 256)
        dec_kwargs.setdefault('window_size', 256)

        # instantiate the encoder and decoder
        enc = RoutingTransformerLM(**enc_kwargs)
        dec = RoutingTransformerLM(**dec_kwargs)

        self.enc = enc
        self.dec = AutoregressiveWrapper(dec, ignore_index = ignore_index, pad_value = pad_value)

        # if the decoder is reversible, the user must manually call backward on the encoder auxiliary loss
        # (a bug bounty should be placed on this)
        self.dec_reversible = dec_kwargs.pop('reversible', False)

        # display a warning message
        if self.dec_reversible:
            print('Warning! Due to an issue with reversible nets and encoder auxiliary losses, you must explicitly call backwards on the encoder auxiliary loss, which is supplied as the second element of the returned tuple on forward')

        self._handle = None
        self.register_kmeans_update()

    # cancel the kmeans update hook
    def cancel_kmeans_update(self):
        if self._handle is None:
            return
        self._handle.remove()
        self._handle = None

    # register the kmeans update hook
    def register_kmeans_update(self):
        self.cancel_kmeans_update()
        return update_kmeans_on_backwards(self)

    # decorated with torch.no_grad()
    # generate a target sequence from an input sequence and the start of the output sequence
    @torch.no_grad()
    def generate(self, seq_in, seq_out_start, max_seq_len = None, **kwargs):
        # fall back to the decoder's maximum sequence length if none is given
        max_seq_len = default(max_seq_len, self.dec.max_seq_len)
        # extract and set the encoder and decoder keyword arguments
        enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
        # run the encoder over the input sequence to obtain the context
        context, _ = self.enc(seq_in, **enc_kwargs)
        # let the decoder generate the target sequence conditioned on that context
        return self.dec.generate(seq_out_start, max_seq_len, context = context, **{**dec_kwargs, **kwargs})

    # forward pass: encode the input sequence, decode the target sequence, compute the loss
    def forward(self, seq_in, seq_out, return_loss = False, randomly_truncate_sequence = False, **kwargs):
        # extract and set the encoder and decoder keyword arguments
        enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
        # run the encoder to get the context and the encoder auxiliary loss
        context, enc_aux_loss = self.enc(seq_in, **enc_kwargs)
        # compute the decoder loss conditioned on the context, passing the encoder auxiliary loss through
        loss = self.dec(seq_out, return_loss = return_loss, randomly_truncate_sequence = randomly_truncate_sequence, context = context, aux_loss = enc_aux_loss, **dec_kwargs)

        # if the decoder is reversible, the user must call backward on the encoder auxiliary loss themselves
        if self.dec_reversible:
            return loss, enc_aux_loss

        # otherwise fold the encoder auxiliary loss into the main loss and return a zero auxiliary loss
        aux_loss = torch.tensor(0., requires_grad = True)
        loss = loss + enc_aux_loss
        return loss, aux_loss
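
A minimal sketch of how the enc_/dec_ prefixed keyword arguments are used in practice (sizes are illustrative):

import torch
from routing_transformer import RoutingTransformerEncDec

model = RoutingTransformerEncDec(
    dim = 512,
    enc_num_tokens = 20000,
    enc_depth = 2,
    enc_max_seq_len = 4096,
    dec_num_tokens = 20000,
    dec_depth = 2,
    dec_max_seq_len = 4096
)

src = torch.randint(0, 20000, (1, 4096))
tgt = torch.randint(0, 20000, (1, 4096))
src_mask = torch.ones_like(src).bool()

loss, aux_loss = model(src, tgt, enc_input_mask = src_mask, return_loss = True)
loss.backward()

# generated = model.generate(src, tgt[:, :1], max_seq_len = 128, enc_input_mask = src_mask)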

.\lucidrains\routing-transformer\routing_transformer\reversible.py

import torch
import torch.nn as nn
from operator import itemgetter
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states

# helper for routing keyword arguments into the f/g functions of each reversible layer

def route_args(router, args, depth):
    # one (f_args, g_args) pair per layer
    routed_args = [(dict(), dict()) for _ in range(depth)]
    # keys in args that the router knows how to route
    matched_keys = [key for key in args.keys() if key in router]

    for key in matched_keys:
        val = args[key]
        for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
            # hand the argument to f and/or g according to the route flags for this layer
            new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
            routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
    return routed_args

def layer_drop(layers, prob):
    # randomly drop layers with the given probability, always keeping at least one layer
    to_drop = torch.empty(len(layers)).uniform_(0, 1) < prob
    blocks = [block for block, drop in zip(layers, to_drop) if not drop]
    blocks = layers[:1] if len(blocks) == 0 else blocks
    return blocks

def cast_return(ret, requires_grad = True):
    # normalize a return value to an (output, auxiliary loss) tuple for gradient computation
    if type(ret) is not tuple:
        loss = torch.tensor(0., device=ret.device, dtype=ret.dtype, requires_grad=requires_grad)
        return (ret, loss)
    return ret

# saves and restores RNG state, following https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
class Deterministic(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        # 记录随机数生成器状态
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)

# inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
# once multi-GPU is working correctly, refactor and send a PR back to the source
class ReversibleBlock(nn.Module):
    def __init__(self, f, g):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, f_args = {}, g_args = {}):
        x1, x2 = torch.chunk(x, 2, dim=2)
        y1, y2 = None, None

        f_args['_reverse'] = g_args['_reverse'] = False

        with torch.no_grad():
            f_out, f_loss = cast_return(self.f(x2, record_rng=self.training, **f_args), requires_grad = False)
            y1 = x1 + f_out

            g_out, g_loss = cast_return(self.g(y1, record_rng=self.training, **g_args), requires_grad = False)
            y2 = x2 + g_out

        return torch.cat([y1, y2], dim=2), f_loss, g_loss
    # 定义反向传播函数,接收输入 y、梯度 dy、损失函数 dl_f 和 dl_g,以及额外参数 f_args 和 g_args
    def backward_pass(self, y, dy, dl_f, dl_g, f_args = {}, g_args = {}):
        # 将 y 沿着第二维度分成两部分 y1 和 y2
        y1, y2 = torch.chunk(y, 2, dim=2)
        # 释放 y 变量的内存
        del y

        # 将 dy 沿着第二维度分成两部分 dy1 和 dy2
        dy1, dy2 = torch.chunk(dy, 2, dim=2)
        # 释放 dy 变量的内存
        del dy

        # 设置 f_args 和 g_args 中的 '_reverse' 参数为 True
        f_args['_reverse'] = g_args['_reverse'] = True

        # 启用梯度计算环境
        with torch.enable_grad():
            # 设置 y1 可以计算梯度
            y1.requires_grad = True
            # 调用 self.g 函数计算 gy1 和 g_loss
            gy1, g_loss = cast_return(self.g(y1, set_rng=True, **g_args))
            # 反向传播计算梯度
            torch.autograd.backward((gy1, g_loss), (dy2, dl_g))

        # 禁用梯度计算环境
        with torch.no_grad():
            # 计算 x2
            x2 = y2 - gy1
            # 释放 y2 和 gy1 变量的内存
            del y2, gy1

            # 计算 dx1
            dx1 = dy1 + y1.grad
            # 释放 dy1 变量的内存
            del dy1
            # 清空 y1 的梯度
            y1.grad = None

        # 再次启用梯度计算环境
        with torch.enable_grad():
            # 设置 x2 可以计算梯度
            x2.requires_grad = True
            # 调用 self.f 函数计算 fx2 和 f_loss
            fx2, f_loss = cast_return(self.f(x2, set_rng=True, **f_args))
            # 反向传播计算梯度,保留计算图
            torch.autograd.backward((fx2, f_loss), (dx1, dl_f), retain_graph=True)

        # 禁用梯度计算环境
        with torch.no_grad():
            # 计算 x1
            x1 = y1 - fx2
            # 释放 y1 和 fx2 变量的内存
            del y1, fx2

            # 计算 dx2
            dx2 = dy2 + x2.grad
            # 释放 dy2 变量的内存
            del dy2
            # 清空 x2 的梯度
            x2.grad = None

            # 拼接 x1 和去除梯度的 x2,沿着第二维度
            x = torch.cat([x1, x2.detach()], dim=2)
            # 拼接 dx1 和 dx2,沿着第二维度
            dx = torch.cat([dx1, dx2], dim=2)

        # 返回拼接后的 x 和 dx
        return x, dx
class _ReversibleFunction(Function):
    # 静态方法,定义前向传播逻辑
    @staticmethod
    def forward(ctx, x, blocks, args):
        # 保存参数
        ctx.args = args

        # 初始化辅助损失列表
        f_aux_loss = []
        g_aux_loss = []

        # 遍历每个块并执行前向传播
        for block, kwarg in zip(blocks, args):
            x, f_loss, g_loss = block(x, **kwarg)
            f_aux_loss.append(f_loss)
            g_aux_loss.append(g_loss)

        # 保存中间结果和块信息
        ctx.y = x.detach()
        ctx.blocks = blocks
        return x, torch.stack(f_aux_loss), torch.stack(g_aux_loss)

    # 静态方法,定义反向传播逻辑
    @staticmethod
    def backward(ctx, dy, dl_f, dl_g):
        # 获取保存的中间结果和参数
        y = ctx.y
        args = ctx.args
        # 反向遍历每个块并执行反向传播
        for block, kwargs, ind in zip(ctx.blocks[::-1], args[::-1], range(len(ctx.blocks))[::-1]):
            y, dy = block.backward_pass(y, dy, dl_f[ind], dl_g[ind], **kwargs)
        return dy, None, None

class SequentialSequence(nn.Module):
    # 初始化顺序序列模块
    def __init__(self, layers, args_route = {}, layer_dropout = 0.):
        super().__init__()
        # 断言每个参数路由映射的深度与顺序层的数量相同
        assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
        self.layers = layers
        self.args_route = args_route
        self.layer_dropout = layer_dropout

    # 前向传播逻辑
    def forward(self, x, **kwargs):
        # 根据参数路由获取参数
        args = route_args(self.args_route, kwargs, len(self.layers))
        layers_and_args = list(zip(self.layers, args))

        # 如果处于训练状态且存在层丢弃率,则执行层丢弃
        if self.training and self.layer_dropout > 0:
            layers_and_args = layer_drop(layers_and_args, self.layer_dropout)

        # 初始化辅助损失
        aux_loss = torch.zeros(1, device=x.device, dtype=x.dtype)

        # 遍历每个层并执行前向传播
        for (f, g), (f_args, g_args) in layers_and_args:
            res, loss = cast_return(f(x, **f_args))
            aux_loss += loss
            x = x + res

            res, loss = cast_return(g(x, **g_args))
            aux_loss += loss
            x = x + res
        return x, aux_loss

class ReversibleSequence(nn.Module):
    # 初始化可逆序列模块
    def __init__(self, blocks, args_route = {}, layer_dropout = 0.):
        super().__init__()
        self.args_route = args_route
        self.layer_dropout = layer_dropout
        # 创建可逆块模块列表
        self.blocks = nn.ModuleList([ReversibleBlock(f, g) for f, g in blocks])

    # 前向传播逻辑
    def forward(self, x, **kwargs):
        # 将输入张量在最后一个维度上进行拼接
        x = torch.cat([x, x], dim=-1)

        blocks = self.blocks
        # 根据参数路由获取参数
        args = route_args(self.args_route, kwargs, len(blocks))
        args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args))

        layers_and_args = list(zip(blocks, args))

        # 如果处于训练状态且存在层丢弃率,则执行层丢弃
        if self.training and self.layer_dropout > 0:
            layers_and_args = layer_drop(layers_and_args, self.layer_dropout)
            blocks, args = map(lambda ind: list(map(itemgetter(ind), layers_and_args)), (0, 1))

        # 调用_ReversibleFunction的apply方法执行前向传播
        out, f_loss, g_loss =  _ReversibleFunction.apply(x, blocks, args)
        # 将输出张量在最后一个维度上分割成两部分并取平均
        out = torch.stack(out.chunk(2, dim=-1)).mean(dim=0)
        # 计算辅助损失
        aux_loss = f_loss.sum() + g_loss.sum()
        return out, aux_loss
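
The key property exploited by ReversibleBlock and _ReversibleFunction is that the additive coupling can be inverted exactly, so activations do not need to be stored for the backward pass. A standalone sketch of that identity (independent of the classes above):

import torch
from torch import nn

f, g = nn.Linear(4, 4), nn.Linear(4, 4)

x1, x2 = torch.randn(2, 4), torch.randn(2, 4)
y1 = x1 + f(x2)          # forward coupling, as in ReversibleBlock.forward
y2 = x2 + g(y1)

x2_rec = y2 - g(y1)      # inversion, as in ReversibleBlock.backward_pass
x1_rec = y1 - f(x2_rec)
print(torch.allclose(x1, x1_rec), torch.allclose(x2, x2_rec))   # True True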

.\lucidrains\routing-transformer\routing_transformer\routing_transformer.py

# import torch
import torch
# import the torch neural network module
import torch.nn as nn
# import the torch functional module
import torch.nn.functional as F
# import the math library
import math
# import isfunction from inspect
from inspect import isfunction
# import mul from operator
from operator import mul
# import partial, reduce and wraps from functools
from functools import partial, reduce, wraps

# import rearrange and repeat from einops
from einops import rearrange, repeat
# import the Rearrange layer from einops.layers.torch
from einops.layers.torch import Rearrange

# import LocalAttention from local_attention
from local_attention import LocalAttention
# import PKM from product_key_memory
from product_key_memory import PKM
# import MoE from mixture_of_experts
from mixture_of_experts import MoE
# import ReversibleSequence and SequentialSequence from routing_transformer.reversible
from routing_transformer.reversible import ReversibleSequence, SequentialSequence

# constants

# attention value assigned to a token attending to itself when queries and keys are shared
TOKEN_SELF_ATTN_VALUE = -5e4
# number of kmeans iterations run at initialization
KMEAN_INIT_ITERS = 10

# 辅助函数

# 判断值是否存在的函数
def exists(val):
    return val is not None

# 返回输入值的函数
def identity(x, *args, **kwargs):
    return x

# 如果输入值不存在,则返回默认值的函数
def default(x, d):
    if not exists(x):
        return d if not isfunction(d) else d()
    return x

# 将输入值转换为元组的函数
def cast_tuple(x):
    return x if isinstance(x, tuple) else (x,)

# 缓存函数的装饰器
def cache_fn(f):
    cache = None
    @wraps(f)
    def cached_fn(*args, **kwargs):
        nonlocal cache
        if exists(cache):
            return cache
        cache = f(*args, **kwargs)
        return cache
    return cached_fn

# 组合多个函数的函数
def compose(*fns):
    def inner(x, *args, **kwargs):
        for fn in reversed(fns):
            x = fn(x, *args, **kwargs)
        return x
    return inner

# 返回输入张量的设备和数据类型的字典的函数
def to(t):
    return {'device': t.device, 'dtype': t.dtype}

# 查找神经网络模块中指定类型的模块的函数
def find_modules(nn_module, type):
    return [module for module in nn_module.modules() if isinstance(module, type)]

# 判断张量是否为空的函数
def is_empty(t):
    return t.nelement() == 0

# 返回指定张量数据类型的最大负值的函数
def max_neg_value(tensor):
    return -torch.finfo(tensor.dtype).max

# 在指定维度上对张量进行批量索引选择的函数
def batched_index_select(values, indices):
    last_dim = values.shape[-1]
    return values.gather(2, expand_dim(indices, -1, last_dim))

# 合并张量的维度的函数
def merge_dims(ind_from, ind_to, tensor):
    shape = list(tensor.shape)
    arr_slice = slice(ind_from, ind_to + 1)
    shape[arr_slice] = [reduce(mul, shape[arr_slice])]
    return tensor.reshape(*shape)

# 在指定维度上扩展张量的函数
def expand_dim(t, dim, k):
    t = t.unsqueeze(dim)
    expand_shape = [-1] * len(t.shape)
    expand_shape[dim] = k
    return t.expand(*expand_shape)

# 在指定维度上对张量进行均值散开的函数
def scatter_mean(src, t, index, dim, eps = 1e-5):
    numer = src.scatter_add(dim, index, t)
    denom = src.scatter_add(dim, index, torch.ones_like(t))
    return numer / (denom + eps)

# 在指定维度上将张量拆分为两部分的函数
def split_at_index(dim, index, t):
    pre_slices = (slice(None),) * dim
    l = (*pre_slices, slice(None, index))
    r = (*pre_slices, slice(index, None))
    return t[l], t[r]

# 重塑张量的维度的函数
def reshape_dim(t, dim, split_dims):
    shape = list(t.shape)
    num_dims = len(shape)
    dim = (dim + num_dims) % num_dims
    shape[dim:dim+1] = split_dims
    return t.reshape(shape)

# 指数移动平均的函数
def ema(old, new, decay):
    if not exists(old):
        return new
    return old * decay + new * (1 - decay)

# 就地指数移动平均的函数
def ema_inplace(moving_avg, new, decay):
    if is_empty(moving_avg):
        moving_avg.data.copy_(new)
        return
    moving_avg.data.mul_(decay).add_(new, alpha= (1 - decay))

# 辅助类

# 对第一个元组或元素应用函数的类
class Chunk(nn.Module):
    def __init__(self, chunks, fn, along_dim = -1):
        super().__init__()
        self.dim = along_dim
        self.chunks = chunks
        self.fn = fn

    def forward(self, x, **kwargs):
        if self.chunks <= 1:
            return self.fn(x, **kwargs)
        chunks = x.chunk(self.chunks, dim = self.dim)
        return torch.cat([self.fn(c, **kwargs) for c in chunks], dim = self.dim)

# 具有预处理的模块列表的类
class PreNorm(nn.ModuleList):
    def __init__(self, norm_class, dim, fn):
        super().__init__()
        self.norm = norm_class(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

# ReZero 模块
class ReZero(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.residual_weight = nn.Parameter(torch.zeros(1))
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.fn(x, **kwargs)
        return map_first_tuple_or_el(x, lambda t: t * self.residual_weight)
# 定义 ScaleNorm 类,用于对输入进行归一化处理
class ScaleNorm(nn.Module):
    # 初始化函数,设置归一化参数和阈值
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1))
        self.eps = eps

    # 前向传播函数,对输入进行归一化处理
    def forward(self, x):
        # 定义内部函数 norm,用于计算归一化后的值
        def norm(t):
            # 计算输入张量 t 在指定维度上的 L2 范数,并进行归一化处理
            n = torch.norm(t, dim=-1, keepdim=True).clamp(min=self.eps)
            return t / n * self.g
        # 调用 map_first_tuple_or_el 函数,对输入进行处理
        return map_first_tuple_or_el(x, norm)

# 定义 ProjectInOut 类,用于对输入进行线性投影
class ProjectInOut(nn.Module):
    # 初始化函数,设置投影函数和维度参数
    def __init__(self, fn, dim_in, dim_out, project_out = True):
        super().__init__()
        self.fn = fn
        self.project_in = nn.Linear(dim_in, dim_out)
        self.project_out = nn.Linear(dim_out, dim_in) if project_out else identity

    # 前向传播函数,对输入进行线性投影处理
    def forward(self, x, **kwargs):
        # 对输入进行投影处理
        x = self.project_in(x)
        # 调用 fn 函数处理投影后的结果
        x, loss = self.fn(x, **kwargs)
        # 对输出进行反向投影处理
        x = self.project_out(x)
        return x, loss

# 定义 MatrixMultiply 类,用于矩阵乘法操作
class MatrixMultiply(nn.Module):
    # 初始化函数,设置矩阵和是否转置参数
    def __init__(self, tensor, transpose = False):
        super().__init__()
        self.tensor = tensor
        self.transpose = transpose

    # 前向传播函数,进行矩阵乘法操作
    def forward(self, x):
        tensor = self.tensor
        # 如果需要转置,则对矩阵进行转置操作
        if self.transpose:
            tensor = tensor.t()
        return x @ tensor

# 定义 token shift 函数,用于对输入进行位移操作
def shift(t, amount, mask = None):
    # 如果位移量为 0,则直接返回输入
    if amount == 0:
        return t

    # 如果存在掩码,则根据掩码进行填充操作
    if exists(mask):
        t = t.masked_fill(~mask[..., None], 0.)

    return F.pad(t, (0, 0, amount, -amount), value = 0.)
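
# illustrative aside (not part of the original file): shift moves tokens along the sequence
# dimension, zero-filling the vacated positions, e.g. for a (1, 3, 2) tensor
#   t = torch.tensor([[[0., 1.], [2., 3.], [4., 5.]]])
#   shift(t, 1)   # -> [[[0., 0.], [0., 1.], [2., 3.]]]  (shifted one step toward later positions)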

# 定义 PreShiftTokens 类,用于对输入进行预位移操作
class PreShiftTokens(nn.Module):
    # 初始化函数,设置位移量和处理函数
    def __init__(self, shifts, fn):
        super().__init__()
        self.fn = fn
        self.shifts = tuple(shifts)

    # 前向传播函数,对输入进行预位移处理
    def forward(self, x, **kwargs):
        # 获取掩码信息
        mask = kwargs.get('mask', None)
        shifts = self.shifts
        segments = len(shifts)
        feats_per_shift = x.shape[-1] // segments
        splitted = x.split(feats_per_shift, dim = -1)
        segments_to_shift, rest = splitted[:segments], splitted[segments:]
        segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
        x = torch.cat((*segments_to_shift, *rest), dim = -1)
        return self.fn(x, **kwargs)

# 定义 FixedPositionalEmbedding 类,用于固定位置编码
class FixedPositionalEmbedding(nn.Module):
    # 初始化函数,设置维度和最大序列长度
    def __init__(self, dim, max_seq_len):
        super().__init__()
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        position = torch.arange(0, max_seq_len, dtype=torch.float)
        sinusoid_inp = torch.einsum("i,j->ij", position, inv_freq)
        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
        self.register_buffer('emb', emb)

    # 前向传播函数,返回固定位置编码结果
    def forward(self, x):
        return self.emb[None, :x.shape[1], :].to(x)

# 定义 rotate_every_two 函数,用于对输入进行旋转操作
def rotate_every_two(x):
    x = rearrange(x, '... (d j) -> ... d j', j = 2)
    x1, x2 = x.unbind(dim = -1)
    x = torch.stack((-x2, x1), dim = -1)
    return rearrange(x, '... d j -> ... (d j)')

# 定义 apply_rotary_pos_emb 函数,用于应用旋转位置编码
def apply_rotary_pos_emb(q, k, v, sinu_pos):
    sinu_pos = sinu_pos.type(q.dtype)
    sinu_pos = rearrange(sinu_pos, '() n (j d) -> n j d', j = 2)
    sin, cos = sinu_pos.unbind(dim = -2)
    sin, cos = map(lambda t: repeat(t, 'b n -> b (n j)', j = 2), (sin, cos))
    q, k, v = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k, v))
    return q, k, v

# 定义 update_kmeans_on_backwards 函数,用于在反向传播时更新 kmeans 模块
def update_kmeans_on_backwards(module):
    module.kmean_modules = find_modules(module, Kmeans)
    def hook(_, grad_in, grad_out):
        for m in module.kmean_modules:
            m.update()

    return module.register_backward_hook(hook)

# 定义 similarity 函数,用于计算输入与均值之间的相似度
def similarity(x, means):
    return torch.einsum('bhld,hcd->bhlc', x, means)

# 定义 dists_and_buckets 函数,用于计算距离和分桶
def dists_and_buckets(x, means):
    dists = similarity(x, means)
    _, buckets = torch.max(dists, dim=-1)
    return dists, buckets

# 定义 batched_bincount 函数,用于批量计算索引的频次
def batched_bincount(index, num_classes, dim=-1):
    shape = list(index.shape)
    shape[dim] = num_classes
    out = index.new_zeros(shape)
    out.scatter_add_(dim, index, torch.ones_like(index, dtype=index.dtype))
    return out

# 定义 kmeans_iter 函数,用于执行 kmeans 迭代
def kmeans_iter(x, means, buckets = None):
    b, h, l, d, dtype, num_clusters = *x.shape, x.dtype, means.shape[1]
    # 如果 buckets 不存在,则通过 dists_and_buckets 函数计算出来
    if not exists(buckets):
        _, buckets = dists_and_buckets(x, means)

    # 对 buckets 进行批量计数,然后对结果进行求和
    bins = batched_bincount(buckets, num_clusters).sum(0, keepdim=True)
    # 创建一个与 bins 形状相同的布尔张量,标记 bins 中为 0 的位置
    zero_mask = bins.long() == 0

    # 创建一个与 buckets 相同形状的全零张量 means_
    means_ = buckets.new_zeros(b, h, num_clusters, d, dtype=dtype)
    # 在指定维度上对 means_ 进行 scatter_add_ 操作,将 x 散射到 means_ 上
    means_.scatter_add_(-2, expand_dim(buckets, -1, d), x)
    # 对 means_ 沿着指定维度求和,并进行归一化,然后转换为指定数据类型
    means_ = F.normalize(means_.sum(0, keepdim=True), dim=-1).type(dtype)

    # 使用 torch.where 函数根据 zero_mask 的值选择更新后的 means_ 或保持原来的 means
    means = torch.where(zero_mask.unsqueeze(-1), means, means_)
    # 去除 means 的第一个维度,返回结果
    means = means.squeeze(0)
    # 返回计算得到的 means
    return means
# for each cluster, pick the window_size positions that are closest to it (by routing distance)
def distribution(dists, window_size):
    # take the top window_size indices along the sequence dimension for every cluster
    _, topk_indices = dists.topk(k=window_size, dim=-2)
    # move the cluster dimension before the window dimension
    indices = topk_indices.transpose(-2, -1)
    # flatten the cluster and window dimensions
    return indices.reshape(*indices.size()[:2], -1)
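
# illustrative aside (not part of the original file): with hypothetical shapes
# x: (1, 2, 16, 8) normalized queries/keys and means: (2, 4, 8) cluster centroids,
#   dists, buckets = dists_and_buckets(x, means)   # dists: (1, 2, 16, 4), buckets: (1, 2, 16)
#   indices = distribution(dists, window_size=4)   # (1, 2, 16) = 4 clusters x 4 nearest positions each, flattened
# so every cluster ends up attending over the window_size positions routed to it.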

# Kmeans 类定义
class Kmeans(nn.Module):
    def __init__(self, num_heads, head_dim, num_clusters, ema_decay = 0.999, commitment = 1e-4):
        super().__init__()
        self.commitment = commitment
        self.ema_decay = ema_decay

        # 注册缓冲区,存储聚类中心和初始化状态
        self.register_buffer('means', torch.randn(num_heads, num_clusters, head_dim))
        self.register_buffer('initted', torch.tensor(False))
        self.num_new_means = 0
        self.new_means = None

    @torch.no_grad()
    def init(self, x):
        if self.initted:
            return
        _, h, _, d, device, dtype = *x.shape, x.device, x.dtype

        num_clusters = self.means.shape[1]

        # 调整输入数据形状
        means = x.transpose(0, 1).contiguous().view(h, -1, d)
        num_samples = means.shape[1]

        # 初始化聚类中心
        if num_samples >= num_clusters:
            indices = torch.randperm(num_samples, device=device)[:num_clusters]
        else:
            indices = torch.randint(0, num_samples, (num_clusters,), device=device)

        means = means[:, indices]

        # 迭代更新聚类中心
        for _ in range(KMEAN_INIT_ITERS):
            means = kmeans_iter(x, means)

        self.num_new_means = 0
        self.means.data.copy_(means)
        self.initted.data.copy_(torch.tensor(True))

    @torch.no_grad()
    def update(self, new_means = None):
        new_means = default(new_means, self.new_means)
        assert exists(new_means), 'new kmeans has not been supplied'
        # 更新聚类中心
        ema_inplace(self.means, new_means, self.ema_decay)

        del self.new_means
        self.new_means = None
        self.num_new_means = 0

    def forward(self, x, update_means = False):
        self.init(x)

        b, dtype = x.shape[0], x.dtype
        means = self.means.type(dtype)
        x = F.normalize(x, 2, dim=-1).type(dtype)

        with torch.no_grad():
            dists, buckets = dists_and_buckets(x, means)

        routed_means = batched_index_select(expand_dim(means, 0, b), buckets)
        loss = F.mse_loss(x, routed_means) * self.commitment

        if update_means:
            with torch.no_grad():
                means = kmeans_iter(x, means, buckets)
            self.new_means = ema(self.new_means, means, self.num_new_means / (self.num_new_means + 1))
            self.num_new_means += 1

        return dists, loss

# KmeansAttention 类定义
class KmeansAttention(nn.Module):
    def __init__(self, num_clusters, window_size, num_heads, head_dim, causal = False, dropout = 0., ema_decay = 0.999, commitment = 1e-4, context_window_size = None, receives_context = False, num_mem_kv = 0, shared_qk = False):
        super().__init__()
        self.num_heads = num_heads
        self.num_clusters = num_clusters
        self.head_dim = head_dim

        self.window_size = window_size
        self.context_window_size = default(context_window_size, window_size)
        self.causal = causal

        self.shared_qk = shared_qk
        self.receives_context = receives_context
        self.kmeans = Kmeans(num_heads, head_dim, num_clusters, ema_decay, commitment)
        self.dropout = nn.Dropout(dropout)

        self.num_mem_kv = max(num_mem_kv, 1 if causal and not shared_qk else 0)
        self.mem_key = nn.Parameter(torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim))
        self.mem_value = nn.Parameter(torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim))
    # 定义前向传播函数,接受查询 q、键 k、值 v,以及可选的查询和键的掩码
    def forward(self, q, k, v, query_mask = None, key_mask = None, **kwargs):
        # 解包变量 b、h、t、d、kv_t、wsz、c_wsz、nc、device、dtype
        b, h, t, d, kv_t, wsz, c_wsz, nc, device, dtype = *q.shape, k.shape[2], self.window_size, self.context_window_size, self.num_clusters, q.device, q.dtype
        # 从 kwargs 中弹出 '_reverse' 键值对,默认为 False
        is_reverse = kwargs.pop('_reverse', False)

        # 创建与 q 相同形状的零张量 out
        out = torch.zeros_like(q, dtype=dtype)

        # 更新 kmeans 模型的标志,训练中且非反向传播时更新
        update_kmeans = self.training and not is_reverse
        
        # 如果不接收上下文信息,则 key_mask 默认为 query_mask
        key_mask = default(key_mask, query_mask) if not self.receives_context else key_mask
        # 如果不接收上下文信息,则 kv_wsz 为 wsz,否则为 c_wsz
        kv_wsz = wsz if not self.receives_context else c_wsz

        # 更新 wsz 和 kv_wsz 为 t 和 kv_t 的最小值
        wsz = min(wsz, t)
        kv_wsz = min(kv_wsz, kv_t)

        # 如果不共享查询和键或者接收上下文信息
        if not self.shared_qk or self.receives_context:
            # 使用 kmeans 模型计算 q 和 k 的聚类中心距离,返回聚类中心距离和辅助损失
            dists, aux_loss = self.kmeans(torch.cat((q, k), dim=2), update_kmeans)
            # 将 dists 按索引 2 分割为 q_dists 和 k_dists
            q_dists, k_dists = split_at_index(2, t, dists)
            # 根据 q_dists 和 wsz 计算索引
            indices = distribution(q_dists, wsz)
            # 根据 k_dists 和 kv_wsz 计算索引
            kv_indices = distribution(k_dists, kv_wsz)
        else:
            # 使用 kmeans 模型计算 q 的聚类中心距离,返回聚类中心距离和辅助损失
            dists, aux_loss = self.kmeans(q, update_kmeans)
            # 对 k 进行归一化,并转换为与 q 相同的类型
            k = F.normalize(k, dim=-1).to(q)
            # 根据 dists 和 wsz 计算索引
            indices = distribution(dists, wsz)
            # kv_indices 与 indices 相同
            kv_indices = indices

        # 根据索引选择 q、k、v 的子集
        q = batched_index_select(q, indices)
        k = batched_index_select(k, kv_indices)
        v = batched_index_select(v, kv_indices)

        # reshape into (batch, heads, clusters, window, head_dim)
        reshape_with_window = lambda x: x.reshape(b, h, nc, -1, d)
        # apply the reshape to q, k and v
        q, k, v = map(reshape_with_window, (q, k, v))

        # 将 self.mem_key 和 self.mem_value 扩展为与 q 相同的形状
        m_k, m_v = map(lambda x: expand_dim(x, 0, b).to(q), (self.mem_key, self.mem_value))
        # 将 k、v 与 m_k、m_v 连接在最后一个维度上
        k, v = map(lambda x: torch.cat(x, dim=3), ((m_k, k), (m_v, v)))

        # 计算点积,乘以缩放因子
        dots = torch.einsum('bhnid,bhnjd->bhnij', q, k) * (d ** -0.5)

        # 计算掩码值
        mask_value = max_neg_value(dots)

        # 如果存在查询或键的掩码
        if exists(query_mask) or exists(key_mask):
            # 默认创建查询掩码为全 1,键掩码为全 1
            query_mask = default(query_mask, lambda: torch.ones((b, t), device=device).bool())
            key_mask = default(key_mask, lambda: torch.ones((b, kv_t), device=device).bool())

            # 根据 indices 和 kv_indices 从掩码中选择子集
            q_mask = expand_dim(query_mask, 1, h).gather(2, indices)
            kv_mask = expand_dim(key_mask, 1, h).gather(2, kv_indices)
            # 将 q_mask、kv_mask 重塑为指定形状
            q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (q_mask, kv_mask))
            # 创建掩码,填充边界
            mask = q_mask[:, :, :, :, None] * kv_mask[:, :, :, None, :]
            mask = F.pad(mask, (self.num_mem_kv, 0), value=True)
            # 将 dots 中不符合掩码条件的位置填充为 mask_value
            dots.masked_fill_(~mask, mask_value)
            del mask

        # 如果是因果注意力机制
        if self.causal:
            # 将 indices、kv_indices 重塑为指定形状
            q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices))
            # 创建因果掩码
            mask = q_mask[:, :, :, :, None] >= kv_mask[:, :, :, None, :]
            mask = F.pad(mask, (self.num_mem_kv, 0), value=True)
            # 将 dots 中不符合掩码条件的位置填充为 mask_value
            dots.masked_fill_(~mask, mask_value)
            del mask            

        # 如果共享查询和键
        if self.shared_qk:
            # 将 indices、kv_indices 重塑为指定形状
            q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices))
            # 创建自注意力掩码
            mask = q_mask[:, :, :, :, None] == kv_mask[:, :, :, None, :]
            mask = F.pad(mask, (self.num_mem_kv, 0), value=False)
            # 将 dots 中符合掩码条件的位置填充为 TOKEN_SELF_ATTN_VALUE
            dots.masked_fill_(mask, TOKEN_SELF_ATTN_VALUE)
            del mask

        # 对 dots 进行 softmax 操作
        dots = dots.softmax(dim=-1)
        # 对 dots 进行 dropout 操作
        dots = self.dropout(dots)

        # 计算输出张量 bo
        bo = torch.einsum('bhcij,bhcjd->bhcid', dots, v)
        # 将 bo 重塑为指定形状
        so = torch.reshape(bo, (b, h, -1, bo.shape[-1])).type(dtype)
        # 对输出张量 out 进行 scatter_mean 操作
        out = scatter_mean(out, so, indices.unsqueeze(-1).expand_as(so), -2)
        # 返回输出张量 out 和辅助损失
        return out, aux_loss
# GELU activation class (fallback when torch does not provide nn.GELU)
class GELU_(nn.Module):
    # forward pass
    def forward(self, x):
        # tanh approximation of the GELU activation
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

# use nn.GELU if available, otherwise fall back to the definition above
GELU = nn.GELU if hasattr(nn, 'GELU') else GELU_

# 定义前馈神经网络类
class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0., activation = None, glu = False):
        super().__init__()
        # 设置激活函数为 GELU
        activation = default(activation, GELU)

        self.glu = glu
        # 第一个全连接层
        self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1))
        # 激活函数层
        self.act = activation()
        # Dropout 层
        self.dropout = nn.Dropout(dropout)
        # 第二个全连接层
        self.w2 = nn.Linear(dim * mult, dim)

    # 前向传播函数
    def forward(self, x, **kwargs):
        if not self.glu:
            # 非 GLU 模式下的前向传播
            x = self.w1(x)
            x = self.act(x)
        else:
            # GLU 模式下的前向传播
            x, v = self.w1(x).chunk(2, dim=-1)
            x = self.act(x) * v

        x = self.dropout(x)
        x = self.w2(x)
        return x

# 自注意力机制类
class SelfAttention(nn.Module):
    def __init__(self,  dim, depth, max_seq_len, heads, local_attn_heads, window_size, dim_head = None, local_attn_window_size = None, local_attn_radius_blocks = 1, causal = False, attn_dropout = 0., dropout = 0., kmeans_ema_decay = 0.999, commitment_factor = 1e-4, receives_context = False, context_window_size = None, rel_pos_emb = True, num_mem_kv = 0, shared_qk = False, conv_query_kernel = 9):
        super().__init__()
        # 断言确保隐藏维度可以被头数整除
        assert dim_head or (dim % heads) == 0, 'hidden dimension must be divisible by number of heads'
        # 断言确保最大序列长度可以被窗口大小整除
        assert (max_seq_len % window_size) == 0, 'maximum sequence length must be divisible by the target window size'
        # 断言确保本地注意力头数小于总头数
        assert local_attn_heads <= heads, 'number of local attention heads must be less than total heads'
        # local attention cannot be combined with receiving context
        assert not (receives_context and local_attn_heads > 0), 'local attention cannot be used for self attention with context'
        # a contextual (cross-attention) layer cannot be causal
        assert not (receives_context and causal), 'contextual attention layer cannot be causal'

        local_attn_window_size = default(local_attn_window_size, window_size)
        context_window_size = default(context_window_size, window_size)

        self.shared_qk = shared_qk
        self.receives_context = receives_context
        self.heads = heads
        self.local_attn_heads = local_attn_heads
        self.global_attn_heads = heads - local_attn_heads

        self.causal = causal
        self.window_size = window_size

        dim_head = default(dim_head, dim // heads)
        dim_heads = dim_head * heads
        self.dim_head = dim_head

        num_clusters = max_seq_len // window_size

        # 本地注意力
        local_dim_heads = dim_head * self.local_attn_heads
        if self.local_attn_heads > 0:
            rel_pos_emb_config = (dim_head, local_attn_heads) if rel_pos_emb else None
            self.local_attn = LocalAttention(dim = dim_head, window_size = local_attn_window_size, causal = causal, dropout = attn_dropout, rel_pos_emb_config = rel_pos_emb_config, look_backward = local_attn_radius_blocks, look_forward = 0 if causal else local_attn_radius_blocks)
            self.local_to_qkv = nn.Linear(dim, 3 * local_dim_heads)

        # 全局注意力
        global_dim_heads = dim_head * self.global_attn_heads
        if self.global_attn_heads > 0:
            self.global_attn = KmeansAttention(num_clusters, window_size, self.global_attn_heads, dim_head, causal = causal, dropout = attn_dropout, ema_decay = kmeans_ema_decay, commitment = commitment_factor, receives_context = receives_context, num_mem_kv = num_mem_kv, shared_qk = shared_qk)

        self.to_q = nn.Linear(dim, global_dim_heads, bias = False)
        self.to_v = nn.Linear(dim, global_dim_heads, bias = False)

        if not self.shared_qk:
            self.to_k = nn.Linear(dim, global_dim_heads, bias = False)

        # 输出
        self.to_out = nn.Linear(dim_heads, dim, bias = False)
        self.dropout = nn.Dropout(dropout)
    # 定义前向传播函数,接受输入 x 和其他参数
    def forward(self, x, context = None, input_mask = None, context_mask = None, pos_emb = None, **kwargs):
        # 断言如果需要上下文信息但未传入,则抛出异常
        assert not (self.receives_context and not exists(context)), 'context must be passed if self attention is set to receive context'
        # 获取输入 x 的形状信息
        b, t, e, h, dh = *x.shape, self.heads, self.dim_head
        # 判断是否存在局部和全局注意力头
        has_local, has_global = map(lambda x: x > 0, (self.local_attn_heads, self.global_attn_heads))

        # 定义函数用于将输入张量按照头数进行分割
        split_heads = lambda v: reshape_dim(v, -1, (-1, dh)).transpose(1, 2).contiguous()

        # 如果存在局部注意力头
        if has_local:
            # 将局部注意力头的查询、键、值分别提取出来并按头数分割
            local_qkv = self.local_to_qkv(x).chunk(3, dim=-1)
            lq, lk, lv = map(split_heads, local_qkv)

        # 如果存在全局注意力头
        if has_global:
            # 根据是否接收上下文信息选择输入作为查询和值
            kv_input = x if not self.receives_context else context

            # 将查询和值分别转换为 Q 和 V,并按头数分割
            q, v = self.to_q(x), self.to_v(kv_input)

            # 如果不共享 Q 和 K,则将键也转换为 K,否则根据是否接收上下文信息选择使用 Q 或者 K
            if not self.shared_qk:
                k = self.to_k(kv_input)
            else:
                k = self.to_q(kv_input) if self.receives_context else q

            q, k, v = map(split_heads, (q, k, v))

        # 初始化输出列表和总损失
        out = []
        total_loss = torch.tensor(0., requires_grad=True, **to(x))

        # 如果存在局部注意力头
        if has_local:
            # 使用局部注意力计算输出
            local_out = self.local_attn(lq, lk, lv, input_mask = input_mask)
            out.append(local_out)

        # 如果存在全局注意力头
        if has_global:
            # 如果不接收上下文信息且存在位置编码,则应用位置编码
            if not self.receives_context and exists(pos_emb):
                q, k, v = apply_rotary_pos_emb(q, k, v, pos_emb)

            # 使用全局注意力计算输出和损失
            global_out, loss = self.global_attn(q, k, v, query_mask = input_mask, key_mask = context_mask)
            total_loss = total_loss + loss

            out.append(global_out)

        # 将所有输出拼接在一起
        out = torch.cat(out, dim=1)
        # 重塑输出张量的形状
        out = out.reshape(b, h, t, -1).transpose(1, 2).reshape(b, t, -1)
        # 将输出传递给输出层,并应用 dropout
        out = self.to_out(out)
        return self.dropout(out), total_loss
class RoutingTransformer(nn.Module):
    # 定义一个路由变换器类,继承自 nn.Module
    def __init__(
        self,
        dim,
        depth,
        max_seq_len,
        heads = 8,
        dim_head = None,
        window_size = 64,
        local_attn_window_size = 256,
        local_attn_radius_blocks = 1,
        causal = False,
        weight_tie = False,
        attn_dropout = 0.,
        ff_dropout = 0.,
        attn_layer_dropout = 0.,
        layer_dropout = 0.,
        n_local_attn_heads = 0,
        ff_glu = False,
        reversible = False,
        ff_chunks = 1,
        kmeans_ema_decay = 0.999,
        commitment_factor = 1e-4,
        receives_context = False,
        context_window_size = None,
        _register_kmeans_update = False,
        rel_pos_emb = True,
        pkm_layers = tuple(),
        pkm_num_keys = 128,
        moe_layers = tuple(),
        moe_num_experts = 4,
        moe_loss_coef = 1e-2,
        num_mem_kv = 0,
        shared_qk = None,
        context_shared_qk = False,
        use_rezero = False,
        use_scale_norm = False,
        ff_activation = None,
        shift_tokens = False
    # 初始化函数,设置路由变换器的各种参数
    def cancel_kmeans_update(self):
        # 取消 K-means 更新
        if not exists(self._handle):
            return
        self._handle.remove()
        self._handle = None

    def register_kmeans_update(self):
        # 注册 K-means 更新
        self._handle = update_kmeans_on_backwards(self)

    def forward(self, x, **kwargs):
        # 前向传播函数
        x, loss = self.layers(x, **kwargs)
        return x, loss

class RoutingTransformerLM(nn.Module):
    # 定义一个路由变换器语言模型类,继承自 nn.Module
    def __init__(
        self,
        num_tokens,
        dim,
        depth,
        max_seq_len,
        heads = 8,
        dim_head = 64,
        window_size = 64,
        local_attn_window_size = None,
        local_attn_radius_blocks = 1,
        causal = False,
        emb_dim = None,
        weight_tie = False,
        attn_dropout = 0.,
        ff_dropout = 0.,
        attn_layer_dropout = 0.,
        layer_dropout = 0.,
        ff_mult = 4,
        ff_activation = None,
        ff_glu = False,
        return_embeddings = False,
        n_local_attn_heads = 0,
        reversible = False,
        ff_chunks = 1,
        kmeans_ema_decay = 0.999,
        commitment_factor = 1e-4,
        receives_context = False,
        context_window_size = None,
        rel_pos_emb = True,
        _register_kmeans_update = True,
        pkm_layers = tuple(),
        pkm_num_keys = 128,
        moe_layers = tuple(),
        moe_num_experts = 4,
        moe_loss_coef = 1e-2,
        num_mem_kv = 0,
        shared_qk = None,
        context_shared_qk = False,
        use_rezero = False,
        use_scale_norm = False,
        tie_embedding = False,
        use_absolute_pos_emb = False,
        shift_tokens = False
    # 初始化函数,设置路由变换器语言模型的各种参数
    ):
        # 调用父类的构造函数
        super().__init__()
        # 断言最大序列长度必须能被窗口大小整除,以计算 kmeans 簇的数量
        assert (max_seq_len % window_size) == 0, 'max sequence length must be divisible by the window size, to calculate number of kmeans cluster'
        # 如果未指定嵌入维度,则使用默认维度
        emb_dim = default(emb_dim, dim)

        # 初始化最大序列长度和正弦位置编码
        self.max_seq_len = max_seq_len
        self.sinu_pos_emb = FixedPositionalEmbedding(dim_head, max_seq_len)

        # 初始化标记嵌入层
        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        # 使用正态分布初始化权重
        nn.init.normal_(self.token_emb.weight, std = 0.02)

        # 初始化路由变换器
        self.routing_transformer = RoutingTransformer(dim, depth, max_seq_len, heads = heads, dim_head = dim_head, window_size = window_size, local_attn_window_size = local_attn_window_size, local_attn_radius_blocks = local_attn_radius_blocks, causal = causal, weight_tie = weight_tie, ff_dropout = ff_dropout, attn_dropout = attn_dropout, attn_layer_dropout = attn_layer_dropout, layer_dropout = layer_dropout, n_local_attn_heads = n_local_attn_heads, ff_glu = ff_glu, reversible = reversible, ff_chunks = ff_chunks, kmeans_ema_decay = kmeans_ema_decay, receives_context = receives_context, context_window_size = context_window_size, rel_pos_emb = rel_pos_emb, pkm_layers = pkm_layers, pkm_num_keys = pkm_num_keys,  moe_layers = moe_layers, moe_num_experts = moe_num_experts, moe_loss_coef = moe_loss_coef, num_mem_kv = num_mem_kv, shared_qk = shared_qk, context_shared_qk = context_shared_qk, _register_kmeans_update = _register_kmeans_update, use_rezero = use_rezero, use_scale_norm = use_scale_norm, ff_activation = ff_activation, shift_tokens = shift_tokens)

        # 如果嵌入维度不等于维度,则使用 ProjectInOut 进行维度转换
        if emb_dim != dim:
            self.routing_transformer = ProjectInOut(self.routing_transformer, emb_dim, dim, project_out = not return_embeddings)

        # 初始化 LayerNorm 层
        self.norm = nn.LayerNorm(emb_dim)

        # 根据返回嵌入标志选择输出层
        if return_embeddings:
            self.out = nn.Identity()
        elif tie_embedding:
            self.out = MatrixMultiply(self.token_emb.weight, transpose = True)
        else:
            self.out = nn.Linear(emb_dim, num_tokens)

    # cancel the kmeans update hook
    def cancel_kmeans_update(self):
        # find the RoutingTransformer module and cancel its kmeans update
        transformer = find_modules(self, RoutingTransformer)[0]
        transformer.cancel_kmeans_update()

    # update kmeans
    def update_kmeans(self):
        # run the update on every Kmeans module in the model
        for m in find_modules(self, Kmeans):
            m.update()

    # 前向传播函数
    def forward(self, x, **kwargs):
        # 对输入进行标记嵌入
        x = self.token_emb(x)

        # 计算旋转位置编码
        rotary_pos_emb = self.sinu_pos_emb(x)
        # 使用路由变换器进行前向传播
        x, loss = self.routing_transformer(x, pos_emb = rotary_pos_emb, **kwargs)

        # 对输出进行 LayerNorm
        x = self.norm(x)
        # 返回输出和损失
        return self.out(x), loss
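
Finally, a minimal end-to-end sketch of the language model defined above (hyperparameters are illustrative and follow the style of the repository README):

import torch
from routing_transformer import RoutingTransformerLM

model = RoutingTransformerLM(
    num_tokens = 20000,
    dim = 512,
    heads = 8,
    depth = 6,
    max_seq_len = 8192,
    window_size = 128,        # kmeans clusters = max_seq_len / window_size = 64
    causal = True
)

x = torch.randint(0, 20000, (1, 8192)).long()
logits, aux_loss = model(x)   # logits: (1, 8192, 20000); add aux_loss to the training loss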

.\lucidrains\routing-transformer\routing_transformer\__init__.py

# export RoutingTransformer, RoutingTransformerLM, KmeansAttention and update_kmeans_on_backwards
from routing_transformer.routing_transformer import RoutingTransformer, RoutingTransformerLM, KmeansAttention, update_kmeans_on_backwards
# export RoutingTransformerEncDec
from routing_transformer.encoder_decoder import RoutingTransformerEncDec
# export AutoregressiveWrapper
from routing_transformer.autoregressive_wrapper import AutoregressiveWrapper
# export Autopadder
from routing_transformer.autopadder import Autopadder

.\lucidrains\routing-transformer\setup.py

# import the setup and find_packages functions
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'routing_transformer',  # package name
  packages = find_packages(exclude=['examples']),  # include every package except examples
  version = '1.6.1',  # version number
  license='MIT',  # license
  description = 'Routing Transformer (Pytorch)',  # description
  author = 'Phil Wang, Aran Komatsuzaki',  # authors
  author_email = 'lucidrains@gmail.com, aran1234321@gmail.com',  # author emails
  url = 'https://github.com/lucidrains/routing-transformer',  # project URL
  keywords = ['transformers', 'attention', 'artificial intelligence'],  # keywords
  install_requires=[
      'einops',  # required dependencies
      'local-attention>=1.4.0',
      'mixture-of-experts>=0.2.0',
      'product-key-memory',
      'torch'
  ],
  classifiers=[
      'Development Status :: 4 - Beta',  # classifiers
      'Intended Audience :: Developers',
      'Topic :: Scientific/Engineering :: Artificial Intelligence',
      'License :: OSI Approved :: MIT License',
      'Programming Language :: Python :: 3.6',
  ],
)

Data source

The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/

RQ-Transformer

Implementation of RQ Transformer, which proposes a more efficient way of training multi-dimensional sequences autoregressively. This repository will only contain the transformer for now. You can use this vector quantization library for the residual VQ.

This type of axial autoregressive transformer should be compatible with memcodes, proposed in NWT. It would likely also work well with multi-headed VQ

Install

$ pip install RQ-transformer

Usage

import torch
from rq_transformer import RQTransformer

model = RQTransformer(
    num_tokens = 16000,             # number of tokens, in the paper they had a codebook size of 16k
    dim = 512,                      # transformer model dimension
    max_spatial_seq_len = 1024,     # maximum positions along space
    depth_seq_len = 4,              # number of positions along depth (residual quantizations in paper)
    spatial_layers = 8,             # number of layers for space
    depth_layers = 4,               # number of layers for depth
    dim_head = 64,                  # dimension per head
    heads = 8,                      # number of attention heads
)

x = torch.randint(0, 16000, (1, 1024, 4))

loss = model(x, return_loss = True)
loss.backward()

# then after much training

logits = model(x)

# and sample from the logits accordingly
# or you can use the generate function

sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4)

I also think there is something deeper going on, and have generalized this to any number of dimensions. You can use it by importing the HierarchicalCausalTransformer

import torch
from rq_transformer import HierarchicalCausalTransformer

model = HierarchicalCausalTransformer(
    num_tokens = 16000,                   # number of tokens
    dim = 512,                            # feature dimension
    dim_head = 64,                        # dimension of attention heads
    heads = 8,                            # number of attention heads
    depth = (4, 4, 2),                    # 3 stages (but can be any number) - transformer of depths 4, 4, 2
    max_seq_len = (16, 4, 5)              # the maximum sequence length of the first stage, then the fixed sequence length of all subsequent stages
).cuda()

x = torch.randint(0, 16000, (1, 10, 4, 5)).cuda()

loss = model(x, return_loss = True)
loss.backward()

# after a lot of training

sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 16, 4, 5)

Todo

  • move hierarchical causal transformer to separate repository, seems to be working

Citations

@unknown{unknown,
    author  = {Lee, Doyup and Kim, Chiheon and Kim, Saehoon and Cho, Minsu and Han, Wook-Shin},
    year    = {2022},
    month   = {03},
    title   = {Autoregressive Image Generation using Residual Quantization}
}
@misc{press2021ALiBi,
    title   = {Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation},
    author  = {Ofir Press and Noah A. Smith and Mike Lewis},
    year    = {2021},
    url     = {https://ofir.io/train_short_test_long.pdf}
}

.\lucidrains\RQ-Transformer\rq_transformer\hierarchical_causal_transformer.py

# import the math library
import math
# import functools
import functools
# import torch
import torch
# import torch.nn.functional
import torch.nn.functional as F
# import nn and einsum from torch
from torch import nn, einsum
# import rearrange_with_anon_dims from einops_exts
from einops_exts import rearrange_with_anon_dims
# import rearrange, reduce, repeat from einops (used throughout the module)
from einops import rearrange, reduce, repeat

# helpers

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return val if it exists, otherwise the default d
def default(val, d):
    return val if exists(val) else d

# how much padding is needed to round num up to a multiple of mult
def remainder_to_mult(num, mult):
    return (mult - num % mult) % mult

# cast a value to a tuple of the given length if it is not already a tuple
def cast_tuple(t, length = 1):
    return t if isinstance(t, tuple) else ((t,) * length)

# multiply a sequence of numbers together
def reduce_mult(nums):
    return functools.reduce(lambda x, y: x * y, nums, 1)
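
# quick illustrative checks of the helpers above (values chosen arbitrarily,
# not taken from the original file):
assert remainder_to_mult(10, 4) == 2           # 10 + 2 = 12, the next multiple of 4
assert remainder_to_mult(12, 4) == 0           # already a multiple, no padding needed
assert reduce_mult((16, 4, 5)) == 320          # a (16, 4, 5) grid flattens to 320 positions
assert cast_tuple(8, length = 3) == (8, 8, 8)  # broadcast a scalar setting to all stages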

# tensor helpers

# numerically stable log of a tensor (clamped at eps)
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# generate Gumbel noise
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# sample from logits with the Gumbel-max trick, with temperature
def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)

# keep only the top-k logits, setting the rest to -inf
def top_k(logits, thres = 0.5):
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
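
# together these helpers implement top-k filtered sampling; an illustrative
# sketch of how they combine (shapes are arbitrary, not from the original file):
logits = torch.randn(1, 16000)                      # logits for a single position
filtered = top_k(logits, thres = 0.9)               # keep roughly the top 10% of logits
token = gumbel_sample(filtered, temperature = 0.9)  # one sampled token id, shape (1,)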

# positional bias

# Alibi class, implementing the ALiBi relative positional bias
class Alibi(nn.Module):
    def __init__(self, heads, **kwargs):
        super().__init__()
        self.heads = heads
        slopes = torch.Tensor(self._get_slopes(heads))
        slopes = rearrange(slopes, 'h -> h 1 1')
        self.register_buffer('slopes', slopes, persistent = False)
        self.register_buffer('bias', None, persistent = False)

    @staticmethod
    def _get_slopes(heads):
        def get_slopes_power_of_2(n):
            start = (2**(-2**-(math.log2(n)-3)))
            ratio = start
            return [start*ratio**i for i in range(n)]

        if math.log2(heads).is_integer():
            return get_slopes_power_of_2(heads)

        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
        return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]

    def forward(self, i, j, device):
        if exists(self.bias) and self.bias.shape[-1] >= j:
            return self.bias[..., :j]

        bias = torch.arange(j, device = device)
        bias = rearrange(bias, 'j -> 1 1 j')
        bias = bias * self.slopes

        self.register_buffer('bias', bias, persistent = False)
        return self.bias
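
# for intuition (an illustrative check, not part of the original file): the
# per-head slopes halve geometrically (for 8 heads: 1/2, 1/4, ... down to 1/256),
# and the returned bias grows linearly with the key position
alibi = Alibi(heads = 8)
bias = alibi(4, 4, device = torch.device('cpu'))
print(bias.shape)              # torch.Size([8, 1, 4]), one linear bias per head
print(alibi.slopes.flatten())  # 0.5, 0.25, ..., 0.00390625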

# norm

# RMSNorm class
class RMSNorm(nn.Module):
    def __init__(self, dim, eps = 1e-8):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
        return x / norm.clamp(min = self.eps) * self.g
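
# sanity check (illustrative, not part of the original file): with g at its
# initial value of ones, the output has unit root-mean-square along the
# feature dimension regardless of the input scale
rmsnorm = RMSNorm(dim = 512)
print(rmsnorm(torch.randn(2, 16, 512) * 3.).pow(2).mean(dim = -1).sqrt())  # values all close to 1.0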

# helper classes

# FeedForward block: RMSNorm followed by a two-layer MLP with GELU
def FeedForward(*, dim, mult = 4, dropout = 0.):
    return nn.Sequential(
        RMSNorm(dim),
        nn.Linear(dim, dim * mult),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(dim * mult, dim)
    )

# Attention class (keys/values are computed as a single head shared across all query heads)
class Attention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.dropout = nn.Dropout(dropout)
        self.norm = RMSNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)
    # forward pass of the self-attention layer
    def forward(self, x, attn_bias = None):
        # number of heads and device
        h, device = self.heads, x.device

        # pre-normalize the input
        x = self.norm(x)
        # project to queries, keys and values
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
        # split the queries into heads
        q = rearrange(q, 'b n (h d) -> b h n d', h = h)

        # scale the queries
        q = q * self.scale
        # compute attention scores
        sim = einsum('b h i d, b j d -> b h i j', q, k)

        # add the positional bias, if provided
        if exists(attn_bias):
            sim = sim + attn_bias

        # build the causal mask
        i, j = sim.shape[-2:]
        mask_value = -torch.finfo(sim.dtype).max
        mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
        sim = sim.masked_fill(mask, mask_value)

        # normalize the attention scores (numerically stable softmax)
        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        # aggregate the values and merge the heads
        out = einsum('b h i j, b j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# Transformer class (a stack of attention + feedforward blocks)
class Transformer(nn.Module):
    # initializer
    def __init__(
        self,
        *,
        dim,  # model dimension
        layers,  # number of layers
        dim_head = 64,  # dimension per attention head
        heads = 8,  # number of attention heads
        attn_dropout = 0.,  # attention dropout
        ff_dropout = 0.,  # feedforward dropout
        ff_mult = 4,  # feedforward expansion factor
        rel_pos_bias = True  # whether to use the ALiBi relative positional bias
    ):
        super().__init__()
        # create the Alibi module if the relative positional bias is enabled, otherwise None
        self.alibi = Alibi(heads = heads) if rel_pos_bias else None
        # container for the layers
        self.layers = nn.ModuleList([])

        # stack `layers` blocks
        for _ in range(layers):
            # each block is an attention layer followed by a feedforward layer
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
            ]))

        # final normalization
        self.norm = RMSNorm(dim)

    # forward pass
    def forward(self, x):
        # sequence length
        n = x.shape[-2]
        # build the attention bias for this sequence length, if ALiBi is enabled
        attn_bias = self.alibi(n, n, device = x.device) if exists(self.alibi) else None

        # run every block with residual connections
        for attn, ff in self.layers:
            # attention with a residual connection
            x = attn(x, attn_bias = attn_bias) + x
            # feedforward with a residual connection
            x = ff(x) + x

        # return the normalized output
        return self.norm(x)
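
# illustrative shape check (not part of the original file): a stack of blocks
# maps (batch, seq, dim) to (batch, seq, dim)
stage = Transformer(dim = 512, layers = 2)
print(stage(torch.randn(1, 16, 512)).shape)  # torch.Size([1, 16, 512])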

# main class
class HierarchicalCausalTransformer(nn.Module):
    # initializer
    def __init__(
        self,
        *,
        num_tokens,  # number of tokens
        dim,  # model dimension
        depth,  # depth of each stage
        max_seq_len,  # maximum sequence length of each stage
        dim_head = 64,  # dimension per attention head
        heads = 8,  # number of attention heads
        attn_dropout = 0.,  # attention dropout
        ff_mult = 4,  # feedforward expansion factor
        ff_dropout = 0.,  # feedforward dropout
        pad_id = 0,  # id of the padding token
        rel_pos_bias = True  # whether to use the ALiBi relative positional bias
    ):
        super().__init__()

        # simplified per-stage configuration
        # depth = (2, 2, 4) means the first stage has depth 2, the second stage depth 2, the third stage depth 4
        # max_seq_len = (16, 8, 4) means the first stage has a maximum sequence length of 16, the second 8, the third 4

        assert isinstance(depth, tuple) and isinstance(max_seq_len, tuple)
        assert len(depth) == len(max_seq_len)

        # the number of stages is the length of the depth tuple
        self.stages = len(depth)

        # token embedding
        self.token_emb = nn.Embedding(num_tokens, dim)
        # learned start token
        self.start_tokens = nn.Parameter(torch.randn(dim))

        # maximum sequence lengths and one positional embedding per stage
        self.max_seq_len = max_seq_len
        self.pos_embs = nn.ModuleList([nn.Embedding(seq_len, dim) for seq_len in max_seq_len])

        # list of per-stage Transformer modules
        self.transformers = nn.ModuleList([])

        # iterate over the depth of each stage
        for stage_depth in depth:
            # build the stage's Transformer and append it to the list
            self.transformers.append(Transformer(
                dim = dim,
                layers = stage_depth,
                dim_head = dim_head,
                heads = heads,
                attn_dropout = attn_dropout,
                ff_dropout = ff_dropout,
                ff_mult = ff_mult,
                rel_pos_bias = rel_pos_bias
            ))

        # linear projection to the output logits
        self.to_logits = nn.Linear(dim, num_tokens)
        # id of the padding token
        self.pad_id = pad_id

    # generation
    def generate(self, prime = None, filter_thres = 0.9, temperature = 1., default_batch_size = 1):
        # total (flattened) sequence length
        total_seq_len = reduce_mult(self.max_seq_len)
        # device of the model parameters
        device = next(self.parameters()).device

        # if no prime is given, start from an empty sequence
        if not exists(prime):
            prime = torch.empty((default_batch_size, 0), dtype = torch.long, device = device)

        # the running sequence starts as the prime
        seq = prime

        # sample token by token until the full sequence is generated
        for _ in range(total_seq_len - seq.shape[-1]):
            # logits for the last position
            logits = self.forward(seq)[:, -1]
            # filter to the top-k logits according to filter_thres
            logits = top_k(logits, thres = filter_thres)
            # sample with the Gumbel-max trick
            sampled = gumbel_sample(logits, dim = -1, temperature = temperature)
            # append the sampled token to the sequence
            seq = torch.cat((seq, rearrange(sampled, 'b -> b 1')), dim = -1)

        # reshape the flat sequence back to its multi-dimensional form and return it
        return rearrange_with_anon_dims(seq, 'b (...d) -> b ...d', d = self.max_seq_len)

    # forward pass for empty input
    def forward_empty(self, batch_size):
        # handle the special case of sampling from an input of length 0 (start token only)

        # repeat the start token to form the batch
        tokens = repeat(self.start_tokens, 'd -> b 1 d', b = batch_size)

        # run through every stage's Transformer
        for transformer in self.transformers:
            tokens = transformer(tokens)

        # return the logits
        return self.to_logits(tokens)
    # forward pass, taking the input ids and a flag for whether to return the loss
    def forward(self, ids, return_loss = False):
        # ids must have either 2 dimensions (flattened) or self.stages + 1 dimensions
        assert ids.ndim in {2, self.stages + 1}
        # whether the input is flattened
        flattened_dims = ids.ndim == 2
        # remember the original number of dimensions
        ids_orig_ndim = ids.ndim

        # if ids is empty, fall back to forward_empty
        if ids.numel() == 0:
            return self.forward_empty(ids.shape[0])

        # if the input is flattened, auto-pad and reshape it
        if flattened_dims:
            # sequence length
            seq_len = ids.shape[-1]
            # pad up to a multiple of the product of the inner stage lengths
            multiple_of = reduce_mult(self.max_seq_len[1:])
            padding = remainder_to_mult(seq_len, multiple_of)
            # pad the ids and reshape them into the multi-dimensional form
            ids = F.pad(ids, (0, padding), value = self.pad_id)
            ids = rearrange_with_anon_dims(ids, 'b (l ...d) -> b l ...d', d = self.max_seq_len[1:])

        # shape and device of the ids
        b, *prec_dims, device = *ids.shape, ids.device

        # sanity-check the dimensions

        assert prec_dims[0] <= self.max_seq_len[0], 'the first dimension of your axial autoregressive transformer must be less than the first tuple element of max_seq_len (like any autoregressive transformer)'
        assert tuple(prec_dims[1:]) == tuple(self.max_seq_len[1:]), 'all subsequent dimensions must match exactly'

        # get the token embeddings

        tokens = self.token_emb(ids)

        # get the tokens at every hierarchy stage, reducing (summing over) the appropriate dimension and adding absolute positional embeddings

        tokens_at_stages = []
        reduced_tokens = tokens

        for ind, pos_emb in zip(range(len(prec_dims)), reversed(self.pos_embs)):
            is_first = ind == 0

            if not is_first:
                reduced_tokens = reduce(reduced_tokens, 'b ... r d -> b ... d', 'sum')

            positions = pos_emb(torch.arange(reduced_tokens.shape[-2], device = device))
            tokens_with_position = reduced_tokens + positions
            tokens_at_stages.insert(0, tokens_with_position)

        # get the start tokens, to be prepended at the coarsest stage

        start_tokens = repeat(self.start_tokens, 'f -> b 1 f', b = b)

        # the spatial tokens are the depth-reduced tokens plus the spatial positions

        for ind, (stage_tokens, transformer) in enumerate(zip(tokens_at_stages, self.transformers)):
            is_last = ind == (self.stages - 1)

            stage_tokens = torch.cat((
                start_tokens,
                stage_tokens,
            ), dim = -2)

            *prec_dims, _, _ = stage_tokens.shape

            stage_tokens = rearrange(stage_tokens, '... n d -> (...) n d')
            attended = transformer(stage_tokens)
            attended = rearrange_with_anon_dims(attended, '(...b) n d -> ...b n d', b = prec_dims)

            start_tokens = rearrange(attended[..., :-1, :], '... n d -> ... n 1 d')

        logits = self.to_logits(attended)

        logits = logits[..., 1:, :]

        # if the loss is not requested, return the logits

        if not return_loss:

            if flattened_dims:
                logits = rearrange(logits, 'b ... n -> b (...) n')
                logits = logits[:, :seq_len]

            return logits

        preds = rearrange(logits, 'b ... c -> b c (...)')
        labels = rearrange(ids, 'b ... -> b (...)')

        # cross entropy loss
        loss = F.cross_entropy(
            preds[..., :-1],
            labels[..., 1:],
            ignore_index = self.pad_id
        )
        return loss
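
# a minimal usage sketch of the flattened-input path above (ids.ndim == 2),
# illustrative and not part of the original file; the dimensions follow the
# README example, and the length-290 sequence is padded internally to 300
# (a multiple of 4 * 5) before being reshaped to (batch, 15, 4, 5)
model = HierarchicalCausalTransformer(
    num_tokens = 16000,
    dim = 512,
    depth = (4, 4, 2),
    max_seq_len = (16, 4, 5)
)

ids = torch.randint(0, 16000, (1, 290))
loss = model(ids, return_loss = True)
loss.backward()

logits = model(ids)  # (1, 290, 16000), trimmed back to the original flat length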