Lucidrains-系列项目源码解析-五十二-

146 阅读23分钟

Lucidrains 系列项目源码解析(五十二)

Llama - QRLHF (wip)

Implementation of the Llama (or any language model) architecture with RLHF + Q-learning.

This is experimental / independent open research, built off nothing but speculation. But I'll throw some of my brain cycles at the problem in the coming month, just in case the rumors have any basis. Anything you PhD students can get working is up for grabs.

Will start off by adapting the autoregressive discrete Q-learning formulation in the cited paper below and run a few experiments on arithmetic, using a symbolic solver as reward generator.

Yannic Kilcher's educational Q-learning video

Citations

@inproceedings{qtransformer,
    title   = {Q-Transformer: Scalable Offline Reinforcement Learning via Autoregressive Q-Functions},
    authors = {Yevgen Chebotar and Quan Vuong and Alex Irpan and Karol Hausman and Fei Xia and Yao Lu and Aviral Kumar and Tianhe Yu and Alexander Herzog and Karl Pertsch and Keerthana Gopalakrishnan and Julian Ibarz and Ofir Nachum and Sumedh Sontakke and Grecia Salazar and Huong T Tran and Jodilyn Peralta and Clayton Tan and Deeksha Manjunath and Jaspiar Singht and Brianna Zitkovich and Tomas Jackson and Kanishka Rao and Chelsea Finn and Sergey Levine},
    booktitle = {7th Annual Conference on Robot Learning},
    year   = {2023}
}
@inproceedings{Wang2015DuelingNA,
    title   = {Dueling Network Architectures for Deep Reinforcement Learning},
    author  = {Ziyun Wang and Tom Schaul and Matteo Hessel and H. V. Hasselt and Marc Lanctot and Nando de Freitas},
    booktitle = {International Conference on Machine Learning},
    year    = {2015},
    url     = {https://api.semanticscholar.org/CorpusID:5389801}
}

.\lucidrains\llama-qrlhf\setup.py

# import setuptools' setup helper and package discovery
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'llama-qrlhf', # distribution name
  packages = find_packages(exclude=[]), # include every discovered package
  version = '0.0.1', # release version
  license='MIT', # license
  description = 'Experimental Q-RLHF applied to Language Modeling. Made compatible with Llama of course', # short description
  author = 'Phil Wang', # author
  author_email = 'lucidrains@gmail.com', # author email
  long_description_content_type = 'text/markdown', # README format
  url = 'https://github.com/lucidrains/llama-qrlhf', # project homepage
  keywords = [
    'artificial intelligence',
    'deep learning',
    'reinforcement learning with human feedback',
    'q learning',
  ], # search keywords
  install_requires = [
    'accelerate',
    'beartype',
    'ema-pytorch',
    'einops>=0.7.0',
    'torch>=2.0'
  ], # runtime dependencies
  classifiers=[
    'Development Status :: 4 - Beta', # development status
    'Intended Audience :: Developers', # target audience
    'Topic :: Scientific/Engineering :: Artificial Intelligence', # topic
    'License :: OSI Approved :: MIT License', # license classifier
    'Programming Language :: Python :: 3.6', # supported python version
  ],
)

Llama2 - Nim (wip)

Basically a transcription of Andrej Karpathy's llama2.c to Nim. Just to gain more experience with Nim.

Data source

The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/

.\lucidrains\local-attention\local_attention\local_attention.py

# 导入数学库
import math

# 导入 torch 库
import torch
from torch import nn, einsum
import torch.nn.functional as F

# 导入 einops 库中的函数
from einops import rearrange, repeat, pack, unpack

# 导入 rotary 模块中的函数
from local_attention.rotary import SinusoidalEmbeddings, apply_rotary_pos_emb

# 常量定义
TOKEN_SELF_ATTN_VALUE = -5e4

# 辅助函数

# whether a value was provided (i.e. is not None)
def exists(val):
    """Return True when val is anything other than None."""
    return not (val is None)

# fall back to a default when no value was given
def default(value, d):
    """Return value unless it is None, in which case return d."""
    if value is None:
        return d
    return value

# bundle a tensor's device and dtype as kwargs for constructing new tensors
def to(t):
    """Return {'device', 'dtype'} kwargs matching tensor t."""
    return dict(device = t.device, dtype = t.dtype)

# most negative finite value representable in the tensor's dtype
def max_neg_value(tensor):
    """Return the largest-magnitude negative finite value for tensor's dtype."""
    finfo = torch.finfo(tensor.dtype)
    return -finfo.max

# unit-normalize along the last dimension, preserving the incoming dtype
def l2norm(tensor):
    """L2-normalize tensor over dim -1 and cast back to its original dtype."""
    orig_dtype = tensor.dtype
    return F.normalize(tensor, dim = -1).type(orig_dtype)

# right-pad `tensor` along `dim` so its length becomes a multiple of `multiple`
def pad_to_multiple(tensor, multiple, dim=-1, value=0):
    """Return (was_padded, tensor); tensor is returned untouched when already a multiple."""
    seqlen = tensor.shape[dim]
    leftover = seqlen % multiple
    if leftover == 0:
        return False, tensor
    pad_len = multiple - leftover
    # F.pad consumes (before, after) pairs starting from the LAST dim,
    # so emit zero-pairs for every dimension after `dim`
    pad_offset = (0, 0) * (-1 - dim)
    return True, F.pad(tensor, (*pad_offset, 0, pad_len), value = value)

# gather each window together with its `backward`/`forward` neighboring windows
def look_around(x, backward = 1, forward = 0, pad_value = -1, dim = 2):
    """Pad along the window axis, then concatenate shifted copies along `dim`."""
    n = x.shape[1]
    # zero pad-pairs for every trailing dimension past `dim`
    pad_dims = (0, 0) * (len(x.shape) - dim)
    padded = F.pad(x, (*pad_dims, backward, forward), value = pad_value)
    shifted = [padded[:, ind:(ind + n), ...] for ind in range(forward + backward + 1)]
    return torch.cat(shifted, dim = dim)

# 主类

class LocalAttention(nn.Module):
    def __init__(
        self,
        window_size,
        causal = False,
        look_backward = 1,
        look_forward = None,
        dropout = 0.,
        shared_qk = False,
        rel_pos_emb_config = None,
        dim = None,
        autopad = False,
        exact_windowsize = False,
        scale = None,
        use_rotary_pos_emb = True,
        use_xpos = False,
        xpos_scale_base = None
    ):
        """Windowed (local) attention.

        Args:
            window_size: number of tokens per attention window
            causal: if True, no attending to the future (look_forward must stay 0)
            look_backward / look_forward: how many neighboring windows each window
                may attend to; look_forward defaults to 0 when causal else 1
            dropout: post-attention dropout probability
            shared_qk: queries and keys share one space (Reformer-style)
            rel_pos_emb_config: legacy way of passing `dim` (first element is used)
            dim: per-head dimension; enables rotary positional embeddings when given
            autopad: pad q/k/v (and mask) to a multiple of window size automatically
            exact_windowsize: in the causal case, cap each query's receptive field
                at exactly `window_size` keys
            scale: optional override of the attention scale
            use_rotary_pos_emb: set False to disable rotary embeddings even when
                `dim` is provided
            use_xpos / xpos_scale_base: length-extrapolatable (xpos) rotary scaling
        """
        super().__init__()
        # non-causal attention defaults to also looking one window ahead
        look_forward = default(look_forward, 0 if causal else 1)
        assert not (causal and look_forward > 0), 'you cannot look forward if causal'

        self.scale = scale

        self.window_size = window_size
        self.autopad = autopad
        self.exact_windowsize = exact_windowsize

        self.causal = causal

        self.look_backward = look_backward
        self.look_forward = look_forward

        self.dropout = nn.Dropout(dropout)

        self.shared_qk = shared_qk

        # relative positional embedding (rotary / xpos)

        self.rel_pos = None
        self.use_xpos = use_xpos

        if use_rotary_pos_emb and (exists(rel_pos_emb_config) or exists(dim)):  # backwards compatible with the old `rel_pos_emb_config` argument
            if exists(rel_pos_emb_config):
                dim = rel_pos_emb_config[0]

            self.rel_pos = SinusoidalEmbeddings(
                dim,
                use_xpos = use_xpos,
                scale_base = default(xpos_scale_base, window_size // 2)
            )

    def forward(
        self,
        q, k, v,
        mask = None,
        input_mask = None,
        attn_bias = None,
        window_size = None

.\lucidrains\local-attention\local_attention\rotary.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum

# 从 einops 库中导入 rearrange 函数
from einops import rearrange

# 定义一个函数,用于检查变量是否存在
def exists(val):
    return val is not None

# rotary sinusoidal embeddings, optionally with xpos length-extrapolation scaling
class SinusoidalEmbeddings(nn.Module):
    """Produces (freqs, scale) for rotary position embedding of a sequence."""

    def __init__(
        self,
        dim,
        scale_base = None,
        use_xpos = False
    ):
        super().__init__()
        # standard rotary inverse frequencies over even channel indices
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # xpos (length-extrapolatable) configuration

        self.use_xpos = use_xpos
        self.scale_base = scale_base

        assert not (use_xpos and scale_base is None), 'scale base must be defined if using xpos'

        # per-channel base for the xpos decay; not saved in state dicts
        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.register_buffer('scale', scale, persistent = False)

    def forward(self, x):
        seq_len, device = x.shape[-2], x.device

        # position index for every token, matching inv_freq's dtype
        positions = torch.arange(seq_len, device = x.device).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', positions, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)

        if not self.use_xpos:
            # without xpos the scale is a no-op of 1
            return freqs, torch.ones(1, device = device)

        # xpos: exponent grows with distance from the sequence midpoint
        power = (positions - (seq_len // 2)) / self.scale_base
        scale = self.scale ** power.unsqueeze(-1)
        scale = torch.cat((scale, scale), dim = -1)

        return freqs, scale

# rotate the two halves of the feature dimension: (x1, x2) -> (-x2, x1)
def rotate_half(x):
    """Return x with its last-dim halves swapped and the second half negated."""
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

# apply rotary position embeddings (optionally xpos-scaled) to queries and keys
def apply_rotary_pos_emb(q, k, freqs, scale = 1):
    """Rotate q and k by `freqs`; q is scaled by `scale`, k by its inverse.

    Args:
        q, k: (..., seq, dim) tensors; q may be shorter than freqs (suffix is used)
        freqs: (..., seq, dim) rotary angles from SinusoidalEmbeddings
        scale: xpos scale tensor, or the default 1 for plain rotary
    """
    q_len = q.shape[-2]
    q_freqs = freqs[..., -q_len:, :]

    inv_scale = scale ** -1

    # fix: the default scale = 1 is a plain int, which has no .ndim attribute;
    # only slice when an xpos scale tensor of shape (seq, dim) was passed in
    if isinstance(scale, torch.Tensor) and scale.ndim == 2:
        scale = scale[-q_len:, :]

    q = (q * q_freqs.cos() * scale) + (rotate_half(q) * q_freqs.sin() * scale)
    k = (k * freqs.cos() * inv_scale) + (rotate_half(k) * freqs.sin() * inv_scale)
    return q, k

.\lucidrains\local-attention\local_attention\transformer.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 torch.nn 模块中导入 functional 模块
import torch.nn.functional as F

# 从 einops 库中导入 rearrange 函数
from einops import rearrange

# 从 local_attention 包中导入 LocalAttention 类
from local_attention.local_attention import LocalAttention

# 辅助函数

# truthiness-safe presence check
def exists(val):
    """True when val is not None (works for falsy-but-present values like 0)."""
    return not (val is None)

# substitute a default for a missing (None) value
def default(val, d):
    """Return val when it is not None, otherwise d."""
    if val is None:
        return d
    return val

# L2-normalize over the last dimension
def l2norm(t):
    """Return t scaled to unit L2 norm along dim -1."""
    normed = F.normalize(t, dim = -1)
    return normed

# decorator: run fn with the model in eval mode, then restore the prior training mode
def eval_decorator(fn):
    """Wrap fn(model, ...) so the model is temporarily switched to eval().

    The model's original training flag is restored afterwards. Uses
    functools.wraps so the wrapped function keeps its name/docstring
    (the original decorator discarded them).
    """
    from functools import wraps

    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# 采样函数

# keep the top (1 - thres) fraction of logits, set everything else to -inf
def top_k(logits, thres = 0.9):
    """Filter a (batch, vocab) logits tensor down to its top-k entries."""
    num_keep = int((1 - thres) * logits.shape[-1])
    vals, idx = logits.topk(num_keep)
    filtered = torch.full_like(logits, float('-inf'))
    filtered.scatter_(1, idx, vals)
    return filtered

# multi-head attention built on top of windowed LocalAttention

class LocalMHA(nn.Module):
    """Multi-head self-attention where every head attends within local windows."""
    def __init__(
        self,
        *,
        dim,
        window_size,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        causal = False,
        prenorm = False,
        qk_rmsnorm = False,
        qk_scale = 8,
        use_xpos = False,
        xpos_scale_base = None,
        exact_windowsize = None,
        gate_values_per_head = False,
        **kwargs
    ):
        super().__init__()        
        inner_dim = dim_head * heads

        # optional pre-normalization of the input
        self.norm = nn.LayerNorm(dim) if prenorm else None

        self.heads = heads
        # joint projection to queries, keys, values
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.qk_rmsnorm = qk_rmsnorm

        # learned per-channel scales used after rms-normalizing q and k
        if qk_rmsnorm:
            self.q_scale = nn.Parameter(torch.ones(dim_head))
            self.k_scale = nn.Parameter(torch.ones(dim_head))

        # windowed attention; autopad so arbitrary sequence lengths work
        self.attn_fn = LocalAttention(
            dim = dim_head,
            window_size = window_size,
            causal = causal,
            autopad = True,
            scale = (qk_scale if qk_rmsnorm else None),
            exact_windowsize = default(exact_windowsize, True),
            use_xpos = use_xpos,
            xpos_scale_base = xpos_scale_base,
            **kwargs
        )

        self.to_v_gate = None

        # optional per-head sigmoid gating of the attention output
        if gate_values_per_head:
            self.to_v_gate = nn.Sequential(
                nn.Linear(dim, heads)
            )

        # project concatenated heads back to the model dimension
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x, mask = None, attn_bias = None):
        # x: (batch, seq, dim); mask: (batch, seq) — presumably True = attend; confirm against LocalAttention
        if exists(self.norm):
            x = self.norm(x)

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v)) 

        # cosine-sim attention: unit-normalize q/k, then apply learned scales
        if self.qk_rmsnorm:
            q, k = map(l2norm, (q, k))
            q = q * self.q_scale
            k = k * self.k_scale

        out = self.attn_fn(q, k, v, mask = mask, attn_bias = attn_bias)

        # gate each head's output with a sigmoid computed from the (normed) input
        if exists(self.to_v_gate):
            gates = self.to_v_gate(x)
            gates = rearrange(gates, 'b n h -> b h n 1')
            out = out * gates.sigmoid()

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# 前馈网络

class GEGLU(nn.Module):
    """GELU-gated linear unit: split the last dim in half, gate one half by GELU of the other."""
    def forward(self, x):
        out, gate = x.chunk(2, dim = -1)
        return out * F.gelu(gate)

# build a pre-norm GEGLU feedforward block
def FeedForward(dim, mult = 4, dropout = 0.):
    """Return an nn.Sequential: LayerNorm -> Linear -> GEGLU -> Dropout -> Linear.

    The hidden width is scaled by 2/3 so the gated variant keeps roughly the
    same parameter count as a plain mult-times-wider feedforward.
    """
    hidden_dim = int(dim * mult * 2 / 3)

    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim * 2, bias = False),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, dim, bias = False)
    )

# dynamic position bias

# small MLP mapping relative distance -> per-head attention bias
class DynamicPositionBias(nn.Module):
    """Computes an (heads, i, j) attention bias from |row - col| distances."""
    def __init__(
        self,
        dim,
        heads
    ):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(1, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Linear(dim, heads)
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, i, j):
        """Return a (heads, i, j) bias table for i queries attending to j keys."""
        device = self.device
        assert j >= i

        # bias value for every possible distance 0 .. j-1
        distances = torch.arange(j, dtype = torch.float, device = device)
        dist_bias = self.mlp(distances.unsqueeze(-1))

        # queries occupy the last i positions of the j-long key axis
        rows = torch.arange(j - i, j, device = device)
        cols = torch.arange(j, device = device)

        # |row - col| lookup indices into the distance table
        rel_idx = (rows.unsqueeze(1) - cols.unsqueeze(0)).abs()

        bias = dist_bias[rel_idx]        # (i, j, heads)
        return bias.permute(2, 0, 1)     # -> (heads, i, j)
# the main local-attention transformer

class LocalTransformer(nn.Module):
    """Decoder-style language model whose attention layers are all local/windowed."""
    def __init__(
        self,
        *,
        num_tokens,  # vocabulary size
        max_seq_len,  # maximum sequence length (size of the position embedding table)
        dim,  # model dimension
        depth,  # number of attention + feedforward layer pairs
        causal = True,  # autoregressive masking
        local_attn_window_size = 512,  # local attention window size
        dim_head = 64,  # dimension per attention head
        heads = 8,  # number of attention heads
        ff_mult = 4,  # feedforward hidden multiplier
        attn_dropout = 0.,  # attention dropout
        ff_dropout = 0.,  # feedforward dropout
        ignore_index = -1,  # target value ignored by the cross-entropy loss
        use_xpos = False,  # xpos rotary scaling
        xpos_scale_base = None,  # xpos scale base
        use_dynamic_pos_bias = False,  # MLP-computed relative position bias
        **kwargs
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)  # token embedding
        self.pos_emb = nn.Embedding(max_seq_len, dim)  # absolute position embedding

        self.max_seq_len = max_seq_len  # maximum sequence length
        self.layers = nn.ModuleList([])  # transformer layers

        self.local_attn_window_size = local_attn_window_size  # window size for attention bias
        self.dynamic_pos_bias = None
        if use_dynamic_pos_bias:
            self.dynamic_pos_bias = DynamicPositionBias(dim = dim // 2, heads = heads)  # dynamic position bias

        # rotary embeddings are disabled when dynamic position bias is used instead
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                LocalMHA(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, causal = causal, window_size = local_attn_window_size, use_xpos = use_xpos, xpos_scale_base = xpos_scale_base, use_rotary_pos_emb = not use_dynamic_pos_bias, prenorm = True, **kwargs),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
            ]))  # one local attention layer + one feedforward layer per depth

        self.ignore_index = ignore_index  # ignored target index
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),  # final norm
            nn.Linear(dim, num_tokens, bias = False)  # projection to vocabulary logits
        )

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        prime,  # (batch, n) prompt token ids
        seq_len,  # number of new tokens to sample
        temperature = 1.,  # sampling temperature
        filter_thres = 0.9,  # top-k filter threshold
        **kwargs
    ):
        """Autoregressively sample seq_len tokens after `prime`; returns only the new tokens."""
        n, device = prime.shape[1], prime.device

        out = prime

        for _ in range(seq_len):
            # only the last max_seq_len tokens fit the position embedding table
            logits = self.forward(out[:, -self.max_seq_len:], **kwargs)
            filtered_logits = top_k(logits[:, -1], thres = filter_thres)  # top-k filter the last step
            probs = F.softmax(filtered_logits / temperature, dim = -1)  # temperature softmax
            sampled = torch.multinomial(probs, 1)  # sample one token per batch element
            out = torch.cat((out, sampled), dim = -1)  # append to the running sequence

        return out[:, n:]  # strip the prompt, return the generated part

    def forward(self, x, mask = None, return_loss = False):
        """Return vocabulary logits for ids `x`, or the LM loss when return_loss=True."""
        if return_loss:
            # next-token objective: inputs x[:-1], targets x[1:]
            x, labels = x[:, :-1], x[:, 1:]

        n, device = x.shape[1], x.device
        x = self.token_emb(x)  # embed tokens

        assert n <= self.max_seq_len
        x = x + self.pos_emb(torch.arange(n, device = device))  # add absolute positions

        # dynamic position bias, shared across all layers (window x 2*window table)

        attn_bias = None
        if exists(self.dynamic_pos_bias):
            w = self.local_attn_window_size
            attn_bias = self.dynamic_pos_bias(w, w * 2)

        # attention / feedforward layers with residual connections

        for attn, ff in self.layers:
            x = attn(x, mask = mask, attn_bias = attn_bias) + x  # local attention
            x = ff(x) + x  # feedforward

        logits = self.to_logits(x)  # project to vocabulary

        if not return_loss:
            return logits

        logits = rearrange(logits, 'b n c -> b c n')  # cross_entropy expects (b, classes, n)
        loss = F.cross_entropy(logits, labels, ignore_index = self.ignore_index)
        return loss

.\lucidrains\local-attention\local_attention\__init__.py

# 从 local_attention 包中导入 LocalAttention 类
from local_attention.local_attention import LocalAttention
# 从 local_attention 包中导入 LocalTransformer、LocalMHA 和 DynamicPositionBias 类
from local_attention.transformer import LocalTransformer, LocalMHA, DynamicPositionBias

Local attention

An implementation of local windowed attention, which sets an incredibly strong baseline for language modeling. It is becoming apparent that a transformer needs local attention in the bottom layers, with the top layers reserved for global attention to integrate the findings of previous layers. This repository makes it easy to immediately employ local window attention.

This code has been battletested in multiple repositories already, alongside different implementations of sparse long-range attention.

Install

$ pip install local-attention

Usage

import torch
from local_attention import LocalAttention

q = torch.randn(2, 8, 2048, 64)
k = torch.randn(2, 8, 2048, 64)
v = torch.randn(2, 8, 2048, 64)

attn = LocalAttention(
    dim = 64,                # dimension of each head (you need to pass this in for relative positional encoding)
    window_size = 512,       # window size. 512 is optimal, but 256 or 128 yields good enough results
    causal = True,           # auto-regressive or not
    look_backward = 1,       # each window looks at the window before
    look_forward = 0,        # for non-auto-regressive case, will default to 1, so each window looks at the window before and after it
    dropout = 0.1,           # post-attention dropout
    exact_windowsize = False # if this is set to true, in the causal setting, each query will see at maximum the number of keys equal to the window size
)

mask = torch.ones(2, 2048).bool()
out = attn(q, k, v, mask = mask) # (2, 8, 2048, 64)

This library also allows for local attention in the setting of shared query/key space (Reformer architecture). The normalization of the keys, as well as the masking of tokens to itself, will be taken care of.

import torch
from local_attention import LocalAttention

qk = torch.randn(2, 8, 2048, 64)
v  = torch.randn(2, 8, 2048, 64)

attn = LocalAttention(
    dim = 64,
    window_size = 512,
    shared_qk = True,
    causal = True
)

mask = torch.ones(2, 2048).bool()
out = attn(qk, qk, v, mask = mask) # (2, 8, 2048, 64)

If you wish for the module to automagically pad your query / key / values as well as the mask, simply set the autopad keyword to True

import torch
from local_attention import LocalAttention

q = torch.randn(8, 2057, 64)
k = torch.randn(8, 2057, 64)
v = torch.randn(8, 2057, 64)

attn = LocalAttention(
    window_size = 512,
    causal = True,
    autopad = True      # auto pads both inputs and mask, then truncates output appropriately
)

mask = torch.ones(1, 2057).bool()
out = attn(q, k, v, mask = mask) # (8, 2057, 64)

Local Attention Transformer

A full local attention transformer

import torch
from local_attention import LocalTransformer

model = LocalTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 6,
    max_seq_len = 8192,
    causal = True,
    local_attn_window_size = 256
).cuda()

x = torch.randint(0, 256, (1, 8192)).cuda()

logits = model(x) # (1, 8192, 256)

Enwik8 at 4096

window size of 256, lookback of 1, total receptive field of 512

$ python train.py

Citation

@inproceedings{rae-razavi-2020-transformers,
    title   = "Do Transformers Need Deep Long-Range Memory?",
    author  = "Rae, Jack  and Razavi, Ali",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
    month   = jul,
    year    = "2020",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url     = "https://www.aclweb.org/anthology/2020.acl-main.672"
}
@misc{roy*2020efficient,
    title   = {Efficient Content-Based Sparse Attention with Routing Transformers},
    author  = {Aurko Roy* and Mohammad Taghi Saffar* and David Grangier and Ashish Vaswani},
    year    = {2020},
    url     = {https://arxiv.org/pdf/2003.05997.pdf}
}
@misc{beltagy2020longformer,
    title   = {Longformer: The Long-Document Transformer},
    author  = {Iz Beltagy and Matthew E. Peters and Arman Cohan},
    year    = {2020},
    eprint  = {2004.05150},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@inproceedings{Sun2022ALT,
    title     = {A Length-Extrapolatable Transformer},
    author    = {Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei},
    year      = {2022}
}
@article{Bondarenko2023QuantizableTR,
    title   = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
    author  = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2306.12929},
    url     = {https://api.semanticscholar.org/CorpusID:259224568}
}

.\lucidrains\local-attention\setup.py

# import setuptools' setup helper and package discovery
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'local-attention',  # distribution name
  packages = find_packages(),  # include every discovered package
  version = '1.9.0',  # release version
  license='MIT',  # license
  description = 'Local attention, window with lookback, for language modeling',  # short description
  long_description_content_type = 'text/markdown',  # README format
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/local-attention',  # project homepage
  keywords = [
    'transformers',  # keyword
    'attention',  # keyword
    'artificial intelligence'  # keyword
  ],
  install_requires=[
    'einops>=0.6.0',  # tensor rearrangement helper
    'torch'  # pytorch
  ],
  classifiers=[
    'Development Status :: 4 - Beta',  # development status
    'Intended Audience :: Developers',  # target audience
    'Topic :: Scientific/Engineering :: Artificial Intelligence',  # topic
    'License :: OSI Approved :: MIT License',  # license classifier
    'Programming Language :: Python :: 3.6',  # supported python version
  ],
)

.\lucidrains\local-attention\train.py

# 导入所需的库
import random
import tqdm
import gzip
import numpy as np

import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from local_attention import LocalTransformer

# 定义常量
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
GENERATE_LENGTH = 2048
SEQ_LEN = 2048

# 定义辅助函数

# map one token id to a printable character (control codes below 32 become space)
def decode_token(token):
    return chr(max(32, token))

# map a sequence of token ids to a string, one character per token
def decode_tokens(tokens):
    return ''.join(chr(max(32, token)) for token in tokens)

# instantiate a GPT-like decoder-only model with local attention
model = LocalTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 6,
    causal = True,
    local_attn_window_size = 256,
    max_seq_len = SEQ_LEN,
    use_dynamic_pos_bias = True
).cuda()

# prepare the enwik8 data: first 90M bytes for training, next 5M for validation

with gzip.open('./data/enwik8.gz') as file:
    # np.fromstring is deprecated (and removed for binary input in newer NumPy);
    # frombuffer is the supported replacement. Copy so the arrays are writable
    # before handing them to torch.from_numpy.
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8)
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(np.copy(trX)), torch.from_numpy(np.copy(vaX))

# dataset that yields random (seq_len + 1)-token crops of a 1-D token tensor
class TextSamplerDataset(Dataset):
    """Random-crop sampler over a flat token tensor; items are moved to GPU."""

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # the index argument is deliberately ignored — every access is a fresh random crop
        start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        seq = self.data[start: start + self.seq_len + 1].long()
        return seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# build infinite DataLoaders for training and validation

# fix: the original referenced `cycle` without ever defining or importing it,
# which raises NameError at runtime — define the infinite-loader helper here
def cycle(loader):
    """Yield batches from `loader` forever, restarting it when exhausted."""
    while True:
        for data in loader:
            yield data

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer
optim = Adam(model.parameters(), lr=LEARNING_RATE)

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    # accumulate gradients over several micro-batches before stepping
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader), return_loss = True)
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    # periodic validation
    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader), return_loss = True)
            print(f'validation loss: {loss.item()}')

    # periodic sampling from the model
    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        # fix: the original passed '%s' placeholders inside an f-string plus a
        # separate tuple argument, so print emitted the placeholders literally
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str)

.\lucidrains\local-attention-flax\local_attention_flax\local_attention_flax.py

# 导入必要的库
import flax.linen as nn
from jax import numpy as np
from einops import rearrange

# value used to mask out disallowed attention logits
ATTN_MASK_VALUE = -1e10

# autoregressive local (windowed) attention as a Flax module
class LocalAttention(nn.Module):
    # dim: output feature dimension
    # window_size: tokens per window — sequence length must divide evenly
    # heads: number of attention heads
    # dim_head: dimension per head
    dim: int
    window_size: int
    heads: int = 8
    dim_head: int = 64

    # forward pass; x: (seq_len, features) — note there is no batch dimension
    @nn.compact
    def __call__(self, x):
        n, h, dim_head, wsz = x.shape[0], self.heads, self.dim_head, self.window_size
        assert (n % wsz) == 0, 'sequence length must be divisible by the window size'
        # standard 1/sqrt(d) attention scaling
        scale = dim_head ** -0.5
        # number of windows
        window = n // wsz

        # joint projection to queries, keys and values
        qkv = nn.Dense(features = 3 * h * dim_head, use_bias = False)(x)
        q, k, v = np.split(qkv, 3, axis = -1)
        # split heads and group the sequence into windows: (heads, window, wsz, dim_head)
        q, k, v = map(lambda t: rearrange(t, '(w n) (h d) -> h w n d', w = window, h = h), (q, k, v))

        # prepend one zero window so the first window has a (padded) predecessor
        k, v = map(lambda t: np.pad(t, ((0, 0), (1, 0), (0, 0), (0, 0)), constant_values = 0.), (k ,v))
        # each window's keys/values = previous window concatenated with itself
        k, v = map(lambda t: np.concatenate((t[:, :-1], t[:, 1:]), axis = 2), (k, v))

        # attention logits within each (previous window + own window) span
        sim = np.einsum('h w i d, h w j d -> h w i j', q, k) * scale

        # causal mask offset by wsz: each query sees the previous window plus its own past
        mask = np.tril(np.ones((wsz, wsz * 2)), wsz)
        sim = np.where(mask, sim, ATTN_MASK_VALUE)

        # softmax attention and weighted sum of values
        attn = nn.softmax(sim, axis = -1)
        out = np.einsum('h w i j, h w j d -> h w i d', attn, v)
        # merge windows back into a sequence, heads back into features
        out = rearrange(out, 'h w n d -> (w n) (h d)')
        # final output projection
        out =  nn.Dense(features = self.dim)(out)
        return out

.\lucidrains\local-attention-flax\local_attention_flax\__init__.py

# 从 local_attention_flax 模块中导入 LocalAttention 类
from local_attention_flax.local_attention_flax import LocalAttention

Local Attention - Flax

Autoregressive Local Attention - Flax module for Jax

Install

$ pip install local-attention-flax

Usage

from jax import random
from local_attention_flax import LocalAttention

attn = LocalAttention(
    dim = 256,
    dim_head = 64,
    heads = 8,
    window_size = 128
)

key = random.PRNGKey(0)
x = random.normal(key, (512, 256))

params = attn.init(key, x)
out = attn.apply(params, x)  # (512, 256)

.\lucidrains\local-attention-flax\setup.py

# import setuptools' setup helper and package discovery
from setuptools import setup, find_packages

# package metadata
setup(
    name="local-attention-flax",  # distribution name
    packages=find_packages(),  # include every discovered package
    version="0.0.2",  # release version
    license="MIT",  # license
    description="Local Attention - Flax Module in Jax",  # short description
    author="Phil Wang",  # author
    author_email="",  # author email (intentionally blank)
    url="https://github.com/lucidrains/local-attention-flax",  # project homepage
    keywords=[  # search keywords
        "artificial intelligence",
        "deep learning",
        "attention mechanism",
        "jax"
    ],
    install_requires=[  # runtime dependencies
        "einops>=0.3",
        "flax",
        "jax",
        "jaxlib"
    ],
    classifiers=[  # trove classifiers
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.6",
    ],
)

.\lucidrains\logavgexp-torch\logavgexp_pytorch\logavgexp_pytorch.py

import math
from functools import partial

import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange
from unfoldNd import unfoldNd

# helper functions

# presence check that treats only None as missing
def exists(t):
    """Return True when t is anything other than None."""
    return not (t is None)

# numerically-safe log: offsets the input by eps so log(0) never occurs
def log(t, eps = 1e-20):
    return (t + eps).log()

# pass tuples through untouched; repeat any other value `length` times
def cast_tuple(t, length = 1):
    if isinstance(t, tuple):
        return t
    return (t,) * length

# per-dimension conv/unfold output size: floor((x - k + 2p) / s) + 1
# fix: the original line was missing a closing parenthesis (SyntaxError)
def calc_conv_output(shape, kernel_size, padding, stride):
    """Return the output spatial shape of a convolution/unfold over `shape`."""
    return tuple(
        int((x - k + 2 * p) / s + 1)
        for x, k, p, s in zip(shape, kernel_size, padding, stride)
    )

# main function

# temperature-scaled log-average-exp reduction over one dimension
def logavgexp(
    t,
    mask = None,
    dim = -1,
    eps = 1e-20,
    temp = 0.01,
    keepdim = False
):
    """Smoothly interpolates between max (temp -> 0) and mean (temp -> 1) over `dim`.

    `mask` marks valid positions; masked-out entries are excluded from both the
    sum and the count used for averaging.
    """
    if exists(mask):
        mask_value = -torch.finfo(t.dtype).max
        t = t.masked_fill(~mask, mask_value)
        n = mask.sum(dim = dim)
        norm = torch.log(n)
    else:
        n = t.shape[dim]
        norm = math.log(n)

    t = t / temp
    # shift by the max for numerical stability before exponentiating
    max_t = t.amax(dim = dim).detach()
    shifted_exp = (t - max_t.unsqueeze(dim)).exp()
    avg_exp = shifted_exp.sum(dim = dim).clamp(min = eps) / n
    out = (log(avg_exp, eps = eps) + max_t - norm) * temp

    return out.unsqueeze(dim) if keepdim else out

# learned temperature - logavgexp class

# module wrapper around logavgexp, optionally with a learned temperature
class LogAvgExp(nn.Module):
    """logavgexp reduction as a module; temperature may be fixed or learned."""

    def __init__(
        self,
        dim = -1,
        eps = 1e-20,
        temp = 0.01,
        keepdim = False,
        learned_temp = False
    ):
        super().__init__()
        assert temp >= 0 and temp <= 1., 'temperature must be between 0 and 1'

        self.learned_temp = learned_temp

        # store log-temperature as the parameter so exp() keeps it positive
        if learned_temp:
            self.temp = nn.Parameter(torch.ones((1,)) * math.log(temp))
        else:
            self.temp = temp

        self.dim = dim
        self.keepdim = keepdim

    def forward(self, x, mask = None, eps = 1e-8):
        if self.learned_temp:
            temp = self.temp.exp().clamp(min = eps)
        else:
            temp = self.temp

        return logavgexp(
            x,
            mask = mask,
            dim = self.dim,
            temp = temp,
            keepdim = self.keepdim
        )

# logavgexp 2d

# LogAvgExp2D 类,用于 2D logavgexp 操作
class LogAvgExp2D(nn.Module):
    """2D logavgexp pooling over sliding windows (a smooth alternative to max-pool)."""
    def __init__(
        self,
        kernel_size,
        *,
        padding = 0,
        stride = 1,
        temp = 0.01,
        learned_temp = True,
        eps = 1e-20,
        **kwargs
    ):
        super().__init__()
        self.padding = cast_tuple(padding, 2)
        self.stride = cast_tuple(stride, 2)
        self.kernel_size = cast_tuple(kernel_size, 2)

        # unfold extracts the pooling windows; logavgexp reduces each window
        self.unfold = nn.Unfold(self.kernel_size, padding = self.padding, stride = self.stride)
        self.logavgexp = LogAvgExp(dim = -1, eps = eps, learned_temp = learned_temp, temp = temp)

    def forward(self, x):
        """
        b - batch
        c - channels
        h - height
        w - width
        j - reducing dimension
        """

        batch, channels, height, width = x.shape
        out_h, out_w = calc_conv_output((height, width), self.kernel_size, self.padding, self.stride)

        # when padding is used, mask out the padded window positions
        window_mask = None
        if any(p > 0 for p in self.padding):
            window_mask = self.unfold(torch.ones((batch, 1, height, width), device = x.device))
            window_mask = rearrange(window_mask, 'b j (h w) -> b 1 h w j', h = out_h, w = out_w)
            window_mask = window_mask == 1.

        unfolded = self.unfold(x)
        unfolded = rearrange(unfolded, 'b (c j) (h w) -> b c h w j', h = out_h, w = out_w, c = channels)
        return self.logavgexp(unfolded, mask = window_mask)

# logavgexp 3d

# LogAvgExp3D 类,用于 3D logavgexp 操作
class LogAvgExp3D(nn.Module):
    """3D (video) logavgexp pooling over sliding windows.

    Fix: the extracted source contained a stray, duplicated `def __init__(` line
    and the constructor signature was never closed, making the class a
    SyntaxError.  This restores a single coherent constructor.
    """
    def __init__(
        self,
        kernel_size,
        *,
        padding = 0,
        stride = 1,
        temp = 0.01,
        learned_temp = True,
        eps = 1e-20,
        **kwargs
    ):
        super().__init__()
        # normalize padding / stride / kernel size to 3-tuples
        self.padding = cast_tuple(padding, 3)
        self.stride = cast_tuple(stride, 3)
        self.kernel_size = cast_tuple(kernel_size, 3)

        # unfoldNd generalizes nn.Unfold to 3D window extraction
        self.unfold = partial(unfoldNd, kernel_size = self.kernel_size, padding = self.padding, stride = self.stride)
        self.logavgexp = LogAvgExp(dim = -1, eps = eps, learned_temp = learned_temp, temp = temp)

    def forward(self, x):
        """
        b - batch
        c - channels
        f - depth
        h - height
        w - width
        j - reducing dimension
        """

        b, c, f, h, w = x.shape
        # compute the pooled output spatial/temporal extents
        out_f, out_h, out_w = calc_conv_output((f, h, w), self.kernel_size, self.padding, self.stride)

        # when padding is used, mask out padded window positions
        mask = None
        if any([i > 0 for i in self.padding]):
            mask = torch.ones((b, 1, f, h, w), device = x.device)
            mask = self.unfold(mask)
            mask = rearrange(mask, 'b j (f h w) -> b 1 f h w j', f = out_f, h = out_h, w = out_w)
            mask = mask == 1.

        x = self.unfold(x)
        x = rearrange(x, 'b (c j) (f h w) -> b c f h w j', f = out_f, h = out_h, w = out_w, c = c)
        return self.logavgexp(x, mask = mask)

.\lucidrains\logavgexp-torch\logavgexp_pytorch\__init__.py

# 从logavgexp_pytorch.logavgexp_pytorch模块中导入logavgexp、LogAvgExp、LogAvgExp2D、LogAvgExp3D类和函数
from logavgexp_pytorch.logavgexp_pytorch import logavgexp, LogAvgExp, LogAvgExp2D, LogAvgExp3D

LogAvgExp - Pytorch

Implementation of LogAvgExp for Pytorch

Install

$ pip install logavgexp-pytorch

Usage

import torch
from logavgexp_pytorch import logavgexp

# basically it is an improved logsumexp (differentiable max)
# normalized for length

x = torch.arange(1000)
y = logavgexp(x, dim = 0, temp = 0.01) # ~998.8

# more than 1 dimension

x = torch.randn(1, 2048, 5)
y = logavgexp(x, dim = 1, temp = 0.2) # (1, 5)

# keep dimension

x = torch.randn(1, 2048, 5)
y = logavgexp(x, dim = 1, temp = 0.2, keepdim = True) # (1, 1, 5)

# masking (False for mask out with large negative value)

x = torch.randn(1, 2048, 5)
m = torch.randint(0, 2, (1, 2048, 1)).bool()

y = logavgexp(x, mask = m, dim = 1, temp = 0.2, keepdim = True) # (1, 1, 5)

With learned temperature

# learned temperature
import torch
from torch import nn
from logavgexp_pytorch import logavgexp

learned_temp = nn.Parameter(torch.ones(1) * -5).exp().clamp(min = 1e-8) # make sure temperature can't hit 0

x = torch.randn(1, 2048, 5)
y = logavgexp(x, temp = learned_temp, dim = 1) # (1, 5)

Or you can use the LogAvgExp class to handle the learned temperature parameter

import torch
from logavgexp_pytorch import LogAvgExp

logavgexp = LogAvgExp(
    temp = 0.01,
    dim = 1,
    learned_temp = True
)

x = torch.randn(1, 2048, 5)
y = logavgexp(x) # (1, 5)

LogAvgExp2D

import torch
from logavgexp_pytorch import LogAvgExp2D

logavgexp_pool = LogAvgExp2D((2, 2), stride = 2) # (2 x 2) pooling

img = torch.randn(1, 16, 64, 64)
out = logavgexp_pool(img) # (1, 16, 32, 32)

Todo

Citations

@misc{lowe2021logavgexp,
    title   = {LogAvgExp Provides a Principled and Performant Global Pooling Operator}, 
    author  = {Scott C. Lowe and Thomas Trappenberg and Sageev Oore},
    year    = {2021},
    eprint  = {2111.01742},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

.\lucidrains\logavgexp-torch\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# package metadata for the logavgexp-pytorch distribution
setup(
  name = 'logavgexp-pytorch', # package name
  packages = find_packages(exclude=[]), # discover all packages
  version = '0.0.6', # version number
  license='MIT', # license
  description = 'LogAvgExp - Pytorch', # short description
  author = 'Phil Wang', # author
  author_email = 'lucidrains@gmail.com', # author email
  url = 'https://github.com/lucidrains/logavgexp-pytorch', # project url
  keywords = [ # search keywords
    'artificial intelligence',
    'deep learning',
    'pytorch',
    'logsumexp'
  ],
  install_requires=[ # runtime dependencies
    'einops>=0.4.1',
    'torch>=1.6',
    'unfoldNd'
  ],
  classifiers=[ # PyPI classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Data source

The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/

.\lucidrains\long-short-transformer\long_short_transformer\autoregressive_wrapper.py

import torch
from torch import nn
import torch.nn.functional as F

# 定义一个装饰器函数,用于在模型评估时切换模型状态
def eval_decorator(fn):
    """Run `fn` with the model in eval mode, then restore the previous train/eval mode."""
    def wrapper(model, *args, **kwargs):
        prev_mode = model.training
        model.eval()
        result = fn(model, *args, **kwargs)
        model.train(prev_mode)
        return result
    return wrapper

# 定义一个函数用于对 logits 进行 top-k 过滤
def top_k(logits, thres = 0.9):
    """Keep the top (1 - thres) fraction of logits; fill the rest with -inf.

    Fix: for small vocabularies the computed k could truncate to 0, producing
    an all--inf row (NaNs after softmax); k is now clamped to at least 1.
    """
    k = max(1, int((1 - thres) * logits.shape[-1]))
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# 定义一个包装类,用于自回归模型
class AutoregressiveWrapper(nn.Module):
    """Wraps a language model for autoregressive training and sampling.

    forward() computes next-token cross-entropy loss; generate() samples tokens
    one at a time with filtered multinomial sampling.
    """
    def __init__(self, net, ignore_index = -100, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = ignore_index

        self.net = net
        self.max_seq_len = net.max_seq_len

    # sampling: supports custom prompt, length, eos token, temperature and logit filtering
    @torch.no_grad()
    @eval_decorator
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        device = start_tokens.device
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        b, t = start_tokens.shape

        out = start_tokens
        mask = kwargs.pop('mask', None)

        if mask is None:
            mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)

        for _ in range(seq_len):
            # only the trailing max_seq_len tokens fit in the model's context
            x = out[:, -self.max_seq_len:]
            mask = mask[:, -self.max_seq_len:]

            logits = self.net(x, mask=mask, **kwargs)[:, -1, :]

            # bug fix: previously hard-coded `top_k`, silently ignoring the
            # `filter_logits_fn` argument callers passed in
            filtered_logits = filter_logits_fn(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)

            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)
            mask = F.pad(mask, (0, 1), value=True)

            if eos_token is not None and (sample == eos_token).all():
                break

        # strip the prompt, return only newly generated tokens
        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        return out

    # training: next-token prediction loss
    def forward(self, x, **kwargs):
        xi = x[:, :-1]
        xo = x[:, 1:]

        # the mask covers the full sequence; trim it to align with the shifted input
        mask = kwargs.get('mask', None)
        if mask is not None and mask.shape[1] == x.shape[1]:
            mask = mask[:, :-1]
            kwargs.update(mask = mask)

        out = self.net(xi, **kwargs)
        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
        return loss

.\lucidrains\long-short-transformer\long_short_transformer\long_short_transformer.py

# 从 math 模块中导入 gcd(最大公约数)和 ceil(向上取整)函数
from math import gcd, ceil
# 导入 functools 模块
import functools

# 导入 torch 模块
import torch
# 从 torch 模块中导入 nn(神经网络)和 einsum(张量乘法)模块
from torch import nn, einsum
# 从 torch.nn 模块中导入 functional 模块
import torch.nn.functional as F

# 导入 rotary_embedding_torch 模块中的 RotaryEmbedding 和 apply_rotary_emb 函数
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb

# 导入 einops 模块中的 rearrange 和 repeat 函数
from einops import rearrange, repeat

# 定义函数 exists,判断值是否存在
def exists(val):
    """True when `val` is not None."""
    return not (val is None)

# 定义函数 default,如果值存在则返回该值,否则返回默认值
def default(val, d):
    """Return `val` when it is not None, otherwise the fallback `d`."""
    if val is None:
        return d
    return val

# 定义函数 lcm,计算多个数的最小公倍数
def lcm(*numbers):
    """Least common multiple of any number of integers (1 for no arguments).

    Fix: the original used true (float) division, which silently loses
    precision for products beyond 2**53; integer floor division is exact.
    """
    return int(functools.reduce(lambda x, y: (x * y) // gcd(x, y), numbers, 1))

# 定义函数 pad_to_multiple,将张量的长度填充到指定的倍数
def pad_to_multiple(tensor, multiple, dim = -1, value = 0):
    """Right-pad `tensor` along `dim` so its length becomes a multiple of `multiple`."""
    seqlen = tensor.shape[dim]
    quotient = seqlen / multiple

    # already a multiple — return the tensor untouched
    if quotient.is_integer():
        return tensor

    target_len = ceil(quotient) * multiple
    # F.pad pads from the last dim backwards; emit (0, 0) pairs until we reach `dim`
    pad_spec = (0, 0) * (-1 - dim)
    return F.pad(tensor, (*pad_spec, 0, target_len - seqlen), value=value)

# 定义函数 look_around,根据给定的向前和向后偏移量,在张量周围填充指定值
def look_around(x, backward = 1, forward = 0, pad_value = -1, dim = 2):
    """Concatenate shifted copies of `x` (from `backward` steps back to `forward`
    steps ahead along the sequence axis), padding the edges with `pad_value`."""
    seq_len = x.shape[1]
    # (0, 0) pairs skip the trailing dims so the pad lands on the sequence axis
    pad_spec = (0, 0) * (len(x.shape) - dim)
    padded = F.pad(x, (*pad_spec, backward, forward), value=pad_value)
    shifted = [padded[:, i:(i + seq_len), ...] for i in range(forward + backward + 1)]
    return torch.cat(shifted, dim=dim)

# 定义类 PreNorm,实现预层归一化
class PreNorm(nn.Module):
    """Apply LayerNorm to the input before calling the wrapped function."""
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

# 定义类 FeedForward,实现前馈神经网络
class FeedForward(nn.Module):
    """Transformer MLP block: Linear -> GELU -> Dropout -> Linear."""
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        inner_dim = dim * mult
        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim)
        )

    def forward(self, x):
        return self.net(x)

# 定义类 LongShortAttention,实现长短注意力机制
class LongShortAttention(nn.Module):
    """Long-short attention: local windowed attention plus a low-rank global
    projection (r << sequence length), per the Long-Short Transformer paper.

    Note: only the constructor is present in this excerpt.
    """
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        causal = True,
        window_size = 128,
        pos_emb = None,
        segment_size = 16,
        r = 1,
        dropout = 0.
    ):
        super().__init__()
        # autoregressive mode projects each segment down to r, so r must stay below the segment size
        assert not (causal and r >= segment_size), 'r should be less than segment size, if autoregressive'

        self.heads = heads
        self.causal = causal
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head

        self.window_size = window_size
        self.segment_size = segment_size
        # causal mode must pad to a common multiple of window and segment sizes
        self.pad_to_multiple = window_size if not causal else lcm(window_size, segment_size)

        # dynamic low-rank projection and separate norms for the local/global branches
        self.to_dynamic_proj = nn.Linear(dim_head, r, bias = False)
        self.local_norm = nn.LayerNorm(dim_head)
        self.global_norm = nn.LayerNorm(dim_head)

        self.pos_emb = default(pos_emb, RotaryEmbedding(dim_head))

        self.attn_dropout = nn.Dropout(dropout)

        # query / key-value / output projections
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

# 定义主类 LongShortTransformer,实现长短变换器
class LongShortTransformer(nn.Module):
    """Long-Short Transformer: a stack of pre-norm long-short attention and
    feed-forward blocks with residual connections, plus token embedding and a
    final projection to logits."""
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        max_seq_len,
        window_size = 128,
        causal = True,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        segment_size = None,
        r = None,
        ff_dropout = 0.,
        attn_dropout = 0.
    ):
        super().__init__()
        self.max_seq_len = max_seq_len

        self.token_emb = nn.Embedding(num_tokens, dim)
        # shared rotary positional embedding across all layers
        pos_emb = RotaryEmbedding(dim_head)

        # defaults differ by mode: segments are only used autoregressively;
        # r is a global projection size (non-causal) or a per-segment size (causal)
        segment_size = default(segment_size, 16 if causal else None)
        r = default(r, 1 if causal else 128)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            attn = LongShortAttention(dim = dim, heads = heads, dim_head = dim_head, window_size = window_size, causal = causal, pos_emb = pos_emb, segment_size = segment_size, r = r, dropout = attn_dropout)
            ff = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
            self.layers.append(nn.ModuleList([
                PreNorm(dim, attn),
                PreNorm(dim, ff)
            ]))

        # final norm + projection to vocabulary logits
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    def forward(self, x, mask = None):
        x = self.token_emb(x)

        # residual attention followed by residual feed-forward, per layer
        for attn, ff in self.layers:
            x = attn(x, mask = mask) + x
            x = ff(x) + x

        return self.to_logits(x)

.\lucidrains\long-short-transformer\long_short_transformer\__init__.py

# 从 long_short_transformer.long_short_transformer 模块中导入 LongShortTransformer 和 LongShortAttention 类
from long_short_transformer.long_short_transformer import LongShortTransformer, LongShortAttention

Long-Short Transformer

Implementation of Long-Short Transformer, combining local and global inductive biases for attention over long sequences, in Pytorch

Install

$ pip install long-short-transformer

Usage

import torch
from long_short_transformer import LongShortTransformer

model = LongShortTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,             # how deep
    heads = 8,             # number of heads
    dim_head = 64,         # dimension per head
    max_seq_len = 1024,    # maximum sequence length
    window_size = 128,     # local attention window size
    r = 256                # like linformer, the sequence length is projected down to this value to avoid the quadratic, where r << n (seq len)
)

x = torch.randint(0, 20000, (1, 1024))
mask = torch.ones(1, 1024).bool()

logits = model(x, mask = mask) # (1, 1024, 20000)

For the autoregressive case, you will have to also supply the segment_size and set causal to True

import torch
from long_short_transformer import LongShortTransformer

model = LongShortTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,             # how deep
    heads = 8,             # number of heads
    dim_head = 64,         # dimension per head
    causal = True,         # autoregressive or not
    max_seq_len = 1024,    # maximum sequence length
    window_size = 128,     # local attention window size
    segment_size = 16,     # sequence is divided into segments of this size, to be projected down to r
    r = 1                  # paper claimed best results with segment to r of 16:1
)

x = torch.randint(0, 20000, (1, 1024))
mask = torch.ones(1, 1024).bool()

logits = model(x, mask = mask) # (1, 1024, 20000)

You can test the autoregressive on enwik8 with

$ python train.py

Citations

@misc{zhu2021longshort,
    title   = {Long-Short Transformer: Efficient Transformers for Language and Vision}, 
    author  = {Chen Zhu and Wei Ping and Chaowei Xiao and Mohammad Shoeybi and Tom Goldstein and Anima Anandkumar and Bryan Catanzaro},
    year    = {2021},
    eprint  = {2107.02192},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

.\lucidrains\long-short-transformer\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# package metadata for the long-short-transformer distribution
setup(
  name = 'long-short-transformer',  # package name
  packages = find_packages(),  # discover all packages
  version = '0.0.5',  # version number
  license='MIT',  # license
  description = 'Long Short Transformer - Pytorch',  # short description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/long-short-transformer',  # project url
  keywords = [  # search keywords
    'artificial intelligence',
    'deep learning',
    'transformers',
    'efficient attention'
  ],
  install_requires=[  # runtime dependencies
    'einops>=0.3',
    'rotary-embedding-torch',
    'torch>=1.6'
  ],
  classifiers=[  # PyPI classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\long-short-transformer\train.py

# 导入所需的模块和类
from long_short_transformer import LongShortTransformer
from long_short_transformer.autoregressive_wrapper import AutoregressiveWrapper

import random
import tqdm
import gzip
import numpy as np

import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# training hyper-parameters
NUM_BATCHES = int(1e6)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4  # gradients accumulate over this many batches per optimizer step
LEARNING_RATE = 3e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
GENERATE_LENGTH = 512
SEQ_LEN = 1024

# 定义辅助函数

# 将 token 解码为字符
def decode_token(token):
    """Map a byte token to a printable character (control chars become space)."""
    return chr(max(32, token))

# 将 tokens 解码为字符串
def decode_tokens(tokens):
    """Decode a sequence of byte tokens into a string (control chars become spaces)."""
    return ''.join(chr(max(32, token)) for token in tokens)

# instantiate a GPT-like decoder model

model = LongShortTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    heads = 8,
    max_seq_len = SEQ_LEN,
    causal = True,
    window_size = 128
)

model = AutoregressiveWrapper(model)
# NOTE(review): requires a CUDA-capable device
model.cuda()

# prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    # np.fromstring is deprecated for binary input; frombuffer + copy yields a writable array
    data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
    data_train, data_val = map(torch.from_numpy, np.split(data, [int(90e6)]))

# 定义 Dataset 类用于采样文本数据
class TextSamplerDataset(Dataset):
    """Samples random (seq_len + 1)-length windows from a 1D token tensor and moves them to GPU."""
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # `index` is ignored: each access draws a fresh random crop
        max_start = self.data.size(0) - self.seq_len - 1
        start = torch.randint(0, max_start, (1,))
        window = self.data[start: start + self.seq_len + 1].long()
        return window.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# dataloader cycling helper — fixes a NameError: `cycle` was used below but never defined
def cycle(loader):
    while True:
        for data in loader:
            yield data

# create DataLoaders for the training and validation splits
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer
optim = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
    model.train()

    # accumulate gradients over several batches before stepping
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward()

    print(f'training loss: {loss.item()}')

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        # bug fix: the original passed a tuple to print alongside an unformatted f-string
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)

.\lucidrains\lumiere-pytorch\lumiere_pytorch\lumiere.py

"""
einstein notation
b - batch
t - time
c - channels
h - height
w - width
"""

from copy import deepcopy
from functools import wraps

import torch
from torch import nn, einsum, Tensor, is_tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F

from beartype import beartype
from beartype.typing import List, Tuple, Optional, Type

from einops import rearrange, pack, unpack, repeat

from optree import tree_flatten, tree_unflatten

from x_transformers.x_transformers import (
    Attention,
    RMSNorm
)

# helpers

# 检查变量是否存在
def exists(v):
    """True when `v` is not None."""
    return not (v is None)

# 如果变量存在则返回变量,否则返回默认值
def default(v, d):
    """Return `v` when it is not None, otherwise the fallback `d`."""
    if v is None:
        return d
    return v

# 将单个张量按照指定模式打包
def pack_one(t, pattern):
    """Pack a single tensor with einops.pack, returning (packed, packed_shapes)."""
    packed = pack([t], pattern)
    return packed

# 将单个张量按照指定模式解包
def unpack_one(t, ps, pattern):
    """Inverse of pack_one: unpack and return the single tensor."""
    unpacked, = unpack(t, ps, pattern)
    return unpacked

# 判断一个数是否可以被另一个数整除
def divisible_by(num, den):
    """True when `num` divides evenly by `den`."""
    return num % den == 0

# 判断一个数是否为奇数
def is_odd(n):
    """True for odd integers (divisibility-by-2 check inlined)."""
    return n % 2 == 1

# 压缩字典中存在值的键值对
def compact_values(d: dict):
    """Return a copy of `d` with all None-valued entries dropped."""
    out = {}
    for key, value in d.items():
        if value is not None:
            out[key] = value
    return out

# extract dimensions using hooks

# 使用钩子函数提取模块的输出形状
@beartype
def extract_output_shapes(
    modules: List[Module],
    model: Module,
    model_input,
    model_kwargs: dict = dict()
):
    """Run one gradient-free forward pass and record each listed module's output
    shape via temporary forward hooks."""
    collected = []

    def record_shape(_, __, output):
        collected.append(output.shape)

    handles = [module.register_forward_hook(record_shape) for module in modules]

    with torch.no_grad():
        model(model_input, **model_kwargs)

    # always detach the hooks again
    for handle in handles:
        handle.remove()

    return collected

# freezing text-to-image, and only learning temporal parameters

# 冻结所有层,只学习时间参数
@beartype
def set_module_requires_grad_(
    module: Module,
    requires_grad: bool
):
    """Toggle requires_grad on every parameter of `module` (in place)."""
    for parameter in module.parameters():
        parameter.requires_grad = requires_grad

def freeze_all_layers_(module):
    """Freeze `module` in place by disabling gradients on all its parameters."""
    set_module_requires_grad_(module, False)

# function that takes in the entire text-to-video network, and sets the time dimension

# 设置时间维度
def set_time_dim_(
    klasses: Tuple[Type[Module]],
    model: Module,
    time_dim: int
):
    """Set `time_dim` on every submodule of `model` that is an instance of `klasses`.

    Fix: the original loop variable was also named `model`, shadowing the
    argument; renamed to `module` for clarity (behavior unchanged, since the
    generator was created before rebinding).
    """
    for module in model.modules():
        if isinstance(module, klasses):
            module.time_dim = time_dim

# decorator for residual

# 用于添加残差的装饰器
def residualize(fn):
    """Decorator: add the input back onto the wrapped method's output (residual connection)."""
    @wraps(fn)
    def wrapper(
        self,
        x,
        *args,
        **kwargs
    ):
        return fn(self, x, *args, **kwargs) + x

    return wrapper

# decorator for converting an input tensor from either image or video format to 1d time

# 将输入张量从图像或视频格式转换为1维时间的装饰器
def image_or_video_to_time(fn):
    """Decorator: fold the spatial dims into the batch and expose time as the
    final axis, call `fn`, then restore the original (image or video) layout."""

    @wraps(fn)
    def wrapper(
        self,
        x,
        batch_size = None,
        **kwargs
    ):
        video_input = x.ndim == 5

        if video_input:
            batch_size = x.shape[0]
            x = rearrange(x, 'b c t h w -> b h w c t')
        else:
            # flattened (batch * time) images need either batch_size or a preset time_dim
            assert exists(batch_size) or exists(self.time_dim)
            rearrange_kwargs = dict(b = batch_size, t = self.time_dim)
            x = rearrange(x, '(b t) c h w -> b h w c t', **compact_values(rearrange_kwargs))

        x, packed_shape = pack_one(x, '* c t')

        x = fn(self, x, **kwargs)

        x = unpack_one(x, packed_shape, '* c t')

        if video_input:
            return rearrange(x, 'b h w c t -> b c t h w')

        return rearrange(x, 'b h w c t -> (b t) c h w')

    return wrapper

# handle channel last

# 处理通道在最后的情况
def handle_maybe_channel_last(fn):
    """Decorator: when self.channel_last is set, move channels around the call
    so `fn` always sees a channels-first tensor."""

    @wraps(fn)
    def wrapper(
        self,
        x,
        *args,
        **kwargs
    ):
        # fast path: nothing to rearrange
        if not self.channel_last:
            return fn(self, x, *args, **kwargs)

        x = rearrange(x, 'b c ... -> b ... c')
        out = fn(self, x, *args, **kwargs)
        return rearrange(out, 'b c ... -> b ... c')

    return wrapper

# helpers

# 创建一个序列模块,过滤掉不存在的模块
def Sequential(*modules):
    """Build an nn.Sequential, silently dropping any None entries."""
    kept = [m for m in modules if m is not None]
    return nn.Sequential(*kept)

# 定义一个带有残差连接的模块
class Residual(Module):
    """Wrap `fn` with a skip connection: forward(t) = fn(t) + t."""
    def __init__(self, fn: Module):
        super().__init__()
        self.fn = fn

    def forward(self, t, *args, **kwargs):
        out = self.fn(t, *args, **kwargs)
        return out + t

# temporal down and upsample
# 初始化一维双线性插值卷积核
def init_bilinear_kernel_1d_(conv: Module):
    """In-place init of a size-3 Conv1d as a per-channel bilinear filter:
    zero everywhere except [0.5, 1, 0.5] on the channel diagonal."""
    nn.init.zeros_(conv.weight)
    if conv.bias is not None:
        nn.init.zeros_(conv.bias)

    num_channels = conv.weight.shape[0]
    bilinear = Tensor([0.5, 1., 0.5])
    # boolean eye selects the (i, i) input/output channel pairs
    diagonal = torch.eye(num_channels).bool()
    conv.weight.data[diagonal] = bilinear

# 时间下采样模块
class TemporalDownsample(Module):
    """Halve the time axis with a stride-2 Conv1d initialized as a bilinear filter."""
    def __init__(
        self,
        dim,
        channel_last = False,
        time_dim = None
    ):
        super().__init__()
        self.time_dim = time_dim
        self.channel_last = channel_last

        # bilinear init means the layer starts out as a smoothing downsample
        self.conv = nn.Conv1d(dim, dim, kernel_size = 3, stride = 2, padding = 1)
        init_bilinear_kernel_1d_(self.conv)

    @handle_maybe_channel_last
    @image_or_video_to_time
    def forward(
        self,
        x
    ):
        # a single time step cannot be compressed further
        assert x.shape[-1] > 1, 'time dimension must be greater than 1 to be compressed'
        return self.conv(x)

# 时间上采样模块
class TemporalUpsample(Module):
    """Double the time axis with a stride-2 ConvTranspose1d initialized as a bilinear filter."""
    def __init__(
        self,
        dim,
        channel_last = False,
        time_dim = None
    ):
        super().__init__()
        self.time_dim = time_dim
        self.channel_last = channel_last

        # bilinear init means the layer starts out as linear interpolation
        self.conv = nn.ConvTranspose1d(dim, dim, kernel_size = 3, stride = 2, padding = 1, output_padding = 1)
        init_bilinear_kernel_1d_(self.conv)

    @handle_maybe_channel_last
    @image_or_video_to_time
    def forward(
        self,
        x
    ):
        return self.conv(x)

# 卷积膨胀块
class ConvolutionInflationBlock(Module):
    """Residual conv inflation block: spatial 2D conv per frame, then a temporal
    1D conv across frames, closed by a zero-initialized projection so the block
    starts as an identity mapping."""
    def __init__(
        self,
        *,
        dim,
        conv2d_kernel_size = 3,
        conv1d_kernel_size = 3,
        groups = 8,
        channel_last = False,
        time_dim = None
    ):
        super().__init__()
        # odd kernels keep 'same' padding symmetric
        assert is_odd(conv2d_kernel_size)
        assert is_odd(conv1d_kernel_size)

        self.time_dim = time_dim
        self.channel_last = channel_last

        # spatial convolution applied per frame
        self.spatial_conv = nn.Sequential(
            nn.Conv2d(dim, dim, conv2d_kernel_size, padding = conv2d_kernel_size // 2),
            nn.GroupNorm(groups, num_channels = dim),
            nn.SiLU()
        )

        # temporal convolution applied across frames
        self.temporal_conv = nn.Sequential(
            nn.Conv1d(dim, dim, conv1d_kernel_size, padding = conv1d_kernel_size // 2),
            nn.GroupNorm(groups, num_channels = dim),
            nn.SiLU()
        )

        # output projection
        self.proj_out = nn.Conv1d(dim, dim, 1)

        # zero init: with the residual decorator the block initially passes input through unchanged
        nn.init.zeros_(self.proj_out.weight)
        nn.init.zeros_(self.proj_out.bias)

    # forward pass; residual add and channel-last handling come from the decorators
    @residualize
    @handle_maybe_channel_last
    def forward(
        self,
        x,
        batch_size = None
    ):
        # 5D input is (batch, channels, time, height, width); 4D is (batch * time, channels, height, width)
        is_video = x.ndim == 5

        if is_video:
            batch_size = x.shape[0]
            x = rearrange(x, 'b c t h w -> (b t) c h w')

        x = self.spatial_conv(x)

        rearrange_kwargs = compact_values(dict(b = batch_size, t = self.time_dim))

        assert len(rearrange_kwargs) > 0, 'either batch_size is passed in on forward, or time_dim is set on init'
        x = rearrange(x, '(b t) c h w -> b h w c t', **rearrange_kwargs)

        x, ps = pack_one(x, '* c t')

        x = self.temporal_conv(x)
        x = self.proj_out(x)

        x = unpack_one(x, ps, '* c t')

        # restore the caller's layout
        if is_video:
            x = rearrange(x, 'b h w c t -> b c t h w')
        else:
            x = rearrange(x, 'b h w c t -> (b t) c h w')

        return x

# 注意力膨胀块
class AttentionInflationBlock(Module):
    """Residual attention inflation block: applies a stack of temporal
    self-attention layers over the time axis, closed by a zero-initialized
    projection so the block starts as an identity mapping.

    Fix: the extracted source contained a second, conflicting `def __init__`
    header (different parameter order, a mutable `attn_kwargs = {}` default);
    this restores a single coherent constructor matching the keyword-only
    signature above it.
    """
    def __init__(
        self,
        *,
        dim,
        depth = 1,
        prenorm = True,
        residual_attn = True,
        time_dim = None,
        channel_last = False,
        **attn_kwargs
    ):
        super().__init__()

        self.time_dim = time_dim
        self.channel_last = channel_last

        # stack of temporal attention layers
        self.temporal_attns = ModuleList([])

        for _ in range(depth):
            attn = Sequential(
                RMSNorm(dim) if prenorm else None,
                Attention(
                    dim = dim,
                    **attn_kwargs
                )
            )

            # optionally wrap each attention layer with its own residual
            if residual_attn:
                attn = Residual(attn)

            self.temporal_attns.append(attn)

        # zero-initialized projection: combined with @residualize the block is initially an identity
        self.proj_out = nn.Linear(dim, dim)

        nn.init.zeros_(self.proj_out.weight)
        nn.init.zeros_(self.proj_out.bias)

    @residualize
    @handle_maybe_channel_last
    def forward(
        self,
        x,
        batch_size = None
    ):
        # 5D input is a full video; 4D input is flattened (batch * time) frames
        is_video = x.ndim == 5
        assert is_video ^ (exists(batch_size) or exists(self.time_dim)), 'either a tensor of shape (batch, channels, time, height, width) is passed in, or (batch * time, channels, height, width) along with `batch_size`'

        if self.channel_last:
            x = rearrange(x, 'b ... c -> b c ...')

        # expose time as its own axis just before channels
        if is_video:
            batch_size = x.shape[0]
            x = rearrange(x, 'b c t h w -> b h w t c')
        else:
            assert exists(batch_size) or exists(self.time_dim)

            rearrange_kwargs = dict(b = batch_size, t = self.time_dim)
            x = rearrange(x, '(b t) c h w -> b h w t c', **compact_values(rearrange_kwargs))

        x, ps = pack_one(x, '* t c')

        # run the temporal attention stack over the time axis
        for attn in self.temporal_attns:
            x = attn(x)

        x = self.proj_out(x)

        x = unpack_one(x, ps, '* t c')

        # restore the caller's layout
        if is_video:
            x = rearrange(x, 'b h w t c -> b c t h w')
        else:
            x = rearrange(x, 'b h w t c -> (b t) c h w')

        if self.channel_last:
            x = rearrange(x, 'b c ... -> b ... c')

        return x
# 定义一个包装器类,用于在模块后添加钩子
class PostModuleHookWrapper(Module):
    """Forward-hook wrapper: passes a hooked module's output through `temporal_module`."""
    def __init__(self, temporal_module: Module):
        super().__init__()
        self.temporal_module = temporal_module

    def forward(self, _, input, output):
        # hook signature is (module, input, output); returning a value replaces the output
        return self.temporal_module(output)

# 将临时模块插入到模块列表中
def insert_temporal_modules_(modules: List[Module], temporal_modules: ModuleList):
    """Attach each temporal module as a forward hook after its paired module (in place)."""
    assert len(modules) == len(temporal_modules)

    for paired, temporal in zip(modules, temporal_modules):
        paired.register_forward_hook(PostModuleHookWrapper(temporal))

# 主要的文本到图像模型包装器
class Lumiere(Module):
    """Wrapper that inflates a frozen text-to-image model into a text-to-video
    model by interleaving temporal conv/attention/resampling modules."""

    # constructor
    # NOTE(review): the `__init__` signature below was truncated during
    # extraction — it is never closed with `):` and its body is missing, so
    # this class as shown is not syntactically valid. Restore the constructor
    # body from the upstream source before use.
    @beartype
    def __init__(
        self,
        model: Module,
        *,
        image_size: int,
        unet_time_kwarg: str,
        conv_module_names: List[str],
        attn_module_names: List[str] = [],
        downsample_module_names: List[str] = [],
        upsample_module_names: List[str] = [],
        channels: int = 3,
        conv_inflation_kwargs: dict = dict(),
        attn_inflation_kwargs: dict = dict(),
        downsample_kwargs: dict = dict(),
        upsample_kwargs: dict = dict(),
        conv_klass = ConvolutionInflationBlock,
        attn_klass = AttentionInflationBlock,
        downsample_klass = TemporalDownsample,
        upsample_klass = TemporalUpsample
    @property
    def downsample_factor(self):
        # each temporal downsample halves the time axis
        return 2 ** len(self.downsamples)

    # only the temporal (inflation) parameters are exposed for training;
    # the wrapped text-to-image model stays frozen
    def parameters(self):
        return [
            *self.convs.parameters(),
            *self.attns.parameters(),
            *self.downsamples.parameters(),
            *self.upsamples.parameters(),
        ]

    # forward pass: denoise a video through the inflated image model
    @beartype
    def forward(
        self,
        video: Tensor,
        *args,
        **kwargs
    ) -> Tensor:

        assert video.ndim == 5
        batch, channels, time, height, width = video.shape

        assert channels == self.channels
        assert (height, width) == (self.image_size, self.image_size)

        assert divisible_by(time, self.downsample_factor)

        # flatten the video into a stack of images
        images = rearrange(video, 'b c t h w -> (b t) c h w')

        # propagate the correct time dimension to all temporal layers
        set_time_dim_(self.temporal_klasses, self, time)

        # pass all frames through the text-to-image model
        images = self.model(images, *args, **kwargs)

        # reshape the result back into a denoised video
        return rearrange(images, '(b t) c h w -> b c t h w', b = batch)