Lucidrains 系列项目源码解析(一百零二)
.\lucidrains\triangle-multiplicative-module\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# 设置包的元数据
setup(
name = 'triangle-multiplicative-module', # 包名
packages = find_packages(), # 查找所有包
version = '0.0.3', # 版本号
license='MIT', # 许可证
description = 'Triangle Multiplicative Module', # 描述
author = 'Phil Wang', # 作者
author_email = 'lucidrains@gmail.com', # 作者邮箱
url = 'https://github.com/lucidrains/triangle-multiplicative-module', # 项目链接
keywords = [ # 关键词列表
'artificial intelligence',
'deep learning',
'protein folding'
],
install_requires=[ # 安装依赖
'einops>=0.3',
'torch>=1.7'
],
setup_requires=[ # 设置依赖
'pytest-runner',
],
tests_require=[ # 测试依赖
'pytest'
],
classifiers=[ # 分类器
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\triangle-multiplicative-module\triangle_multiplicative_module\triangle_multiplicative_module.py
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 torch 库中导入 nn.functional 模块,并重命名为 F
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 函数
from einops import rearrange
# 定义辅助函数
# 判断值是否存在
def exists(val):
return val is not None
# 如果值存在则返回该值,否则返回默认值
def default(val, d):
return val if exists(val) else d
# 定义类
# 三角形乘法模块类
class TriangleMultiplicativeModule(nn.Module):
def __init__(
self,
*,
dim,
hidden_dim = None,
mix = 'ingoing'
):
super().__init__()
# 断言 mix 参数只能为 'ingoing' 或 'outgoing'
assert mix in {'ingoing', 'outgoing'}, 'mix must be either ingoing or outgoing'
# 如果 hidden_dim 不存在,则设为 dim
hidden_dim = default(hidden_dim, dim)
# 对输入进行 LayerNorm 归一化
self.norm = nn.LayerNorm(dim)
# 左投影层
self.left_proj = nn.Linear(dim, hidden_dim)
# 右投影层
self.right_proj = nn.Linear(dim, hidden_dim)
# 左门控层
self.left_gate = nn.Linear(dim, hidden_dim)
# 右门控层
self.right_gate = nn.Linear(dim, hidden_dim)
# 输出门控层
self.out_gate = nn.Linear(dim, hidden_dim)
# 初始化所有门控层的权重为 0,偏置为 1
for gate in (self.left_gate, self.right_gate, self.out_gate):
nn.init.constant_(gate.weight, 0.)
nn.init.constant_(gate.bias, 1.)
# 根据 mix 参数确定 einsum 公式
if mix == 'outgoing':
self.mix_einsum_eq = '... i k d, ... j k d -> ... i j d'
elif mix == 'ingoing':
self.mix_einsum_eq = '... k j d, ... k i d -> ... i j d'
# 输出层归一化
self.to_out_norm = nn.LayerNorm(hidden_dim)
# 输出层线性变换
self.to_out = nn.Linear(hidden_dim, dim)
def forward(self, x, mask = None):
# 断言输入特征图必须是对称的
assert x.shape[1] == x.shape[2], 'feature map must be symmetrical'
# 如果 mask 存在,则重组 mask 的维度
if exists(mask):
mask = rearrange(mask, 'b i j -> b i j ()')
# 对输入进行归一化
x = self.norm(x)
# 左投影
left = self.left_proj(x)
# 右投影
right = self.right_proj(x)
# 如果 mask 存在,则对左右投影进行 mask 处理
if exists(mask):
left = left * mask
right = right * mask
# 计算左门控值
left_gate = self.left_gate(x).sigmoid()
# 计算右门控值
right_gate = self.right_gate(x).sigmoid()
# 计算输出门控值
out_gate = self.out_gate(x).sigmoid()
# 左投影乘以左门控值
left = left * left_gate
# 右投影乘以右门控值
right = right * right_gate
# 执行 einsum 运算,根据 mix_einsum_eq 公式计算输出
out = einsum(self.mix_einsum_eq, left, right)
# 对输出进行归一化
out = self.to_out_norm(out)
# 输出乘以输出门控值
out = out * out_gate
# 返回输出结果
return self.to_out(out)
.\lucidrains\triangle-multiplicative-module\triangle_multiplicative_module\__init__.py
# 从triangle_multiplicative_module.triangle_multiplicative_module模块中导入TriangleMultiplicativeModule类
from triangle_multiplicative_module.triangle_multiplicative_module import TriangleMultiplicativeModule
.\lucidrains\triton-transformer\assert.py
# 导入 PyTorch 库
import torch
# 从 triton_transformer 模块中导入 Transformer 类
from triton_transformer import Transformer
# 检查是否有可用的 CUDA 设备
assert torch.cuda.is_available()
# 实例化模型和数据
# 创建 Transformer 模型对象,设置参数:标记数量为 256,最大序列长度为 1024,维度为 512,深度为 6,头数为 8,头维度为 64,使用因果性,不使用 Triton
model = Transformer(
num_tokens = 256,
max_seq_len = 1024,
dim = 512,
depth = 6,
heads = 8,
dim_head = 64,
causal = True,
use_triton = False
).cuda()
# 生成一个大小为 (1, 1024) 的张量,填充随机整数,放在 CUDA 设备上
x = torch.randint(0, 256, (1, 1024)).cuda()
# 生成一个大小为 (1, 1024) 的张量,填充随机整数,放在 CUDA 设备上
labels = torch.randint(0, 256, (1, 1024)).cuda()
# 无 Triton 的前向传播和反向传播
# 计算模型输出和损失
loss = model(x, labels = labels)
# 反向传播计算梯度
loss.backward()
# 复制损失值
loss = loss.clone()
# 复制 token embeddings 的梯度
emb_grad = model.token_emb.weight.grad.clone()
# 复制 LayerNorm 层的权重梯度
ln_weight_grad = model.norm.weight.grad.clone()
# 复制 LayerNorm 层的偏置梯度
ln_bias_grad = model.norm.bias.grad.clone()
# 清零所有梯度
model.zero_grad()
# Triton 的前向传播和反向传播
# 使用 Triton 进行前向传播和反向传播
triton_loss = model(x, labels = labels, use_triton = True)
# Triton 反向传播计算梯度
triton_loss.backward()
# 复制 Triton 下的 token embeddings 的梯度
triton_emb_grad = model.token_emb.weight.grad.clone()
# 复制 Triton 下的 LayerNorm 层的权重梯度
triton_ln_weight_grad = model.norm.weight.grad.clone()
# 复制 Triton 下的 LayerNorm 层的偏置梯度
triton_ln_bias_grad = model.norm.bias.grad.clone()
# 应该相等,对输出和 token embeddings 的梯度进行检查
# 检查输出是否相等
assert torch.allclose(loss.cpu(), triton_loss.cpu(), atol=1e-6), 'output is the same'
# 检查 token embeddings 的梯度是否相等
assert torch.allclose(emb_grad.cpu(), triton_emb_grad.cpu(), atol=2e-6), 'grad is the same'
# 检查 LayerNorm 层的权重梯度是否相等
assert torch.allclose(ln_weight_grad.cpu(), triton_ln_weight_grad.cpu(), atol=2e-6), 'layernorm weight grad is the same'
# 检查 LayerNorm 层的偏置梯度是否相等
assert torch.allclose(ln_bias_grad.cpu(), triton_ln_bias_grad.cpu(), atol=2e-6), 'layernorm bias grad is the same'
# 打印成功信息
print('succeeded')
Data source
The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/
Transformer in Triton (wip)
Implementation of a Transformer, but completely in Triton. I'm completely new to lower-level neural net code, so this repository will mostly be a learning experience, with the end-goal being a vanilla transformer that is faster and more efficient to train.
Results
Layernorm forward

Layernorm forwards and backwards

Softmax forwards and backwards

Install
$ pip install triton-transformer
Usage
import torch
from triton_transformer import Transformer
model = Transformer(
num_tokens = 256, # vocab size
max_seq_len = 1024, # maximum sequence length
dim = 512, # dimension
depth = 6, # depth
heads = 8, # number of heads
dim_head = 64, # dimension per head
causal = True, # autoregressive or not
attn_dropout = 0.1, # attention dropout
ff_dropout = 0.1, # feedforward dropout
use_triton = True # use this to turn on / off triton
).cuda()
x = torch.randint(0, 256, (1, 1024)).cuda()
logits = model(x) # (1, 1024, 256)
To train, just pass in the labels with the keyword labels on forward, and the cross entropy loss will be returned for backprop.
ex. BERT
import torch
from triton_transformer import Transformer
model = Transformer(
num_tokens = 20000,
max_seq_len = 512,
dim = 512,
depth = 12,
heads = 8,
dim_head = 64,
use_triton = True
).cuda()
x = torch.randint(0, 20000, (1, 512)).cuda()
labels = torch.randint(0, 20000, (1, 512)).cuda()
mask = torch.ones(1, 512).bool().cuda()
loss = model(x, mask = mask, labels = labels)
loss.backward()
Test - GPT training
$ python train.py
Todo
- softmax
- cross-entropy (using triton ops)
- layernorm forward
- layernorm backwards
- batch matrix multiply + fused act forwards
- optimize layernorm backwards (figure out how much to store vs recompute)
- use memory efficient dropout from Triton tutorials
- batch matrix multiply + fused act backwards
- fused attention (expand on softmax)
- use triton matmul for other projections
- benchmark and optimize
- kernels conditional on inference vs training
- efficient triangular matmul kernel for causal attention
Citations
@article{Tillet2019TritonAI,
title = {Triton: an intermediate language and compiler for tiled neural network computations},
author = {Philippe Tillet and H. Kung and D. Cox},
journal = {Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages},
year = {2019}
}
@misc{vaswani2017attention,
title = {Attention Is All You Need},
author = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
year = {2017},
eprint = {1706.03762},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}
@misc{so2021primer,
title = {Primer: Searching for Efficient Transformers for Language Modeling},
author = {David R. So and Wojciech Mańke and Hanxiao Liu and Zihang Dai and Noam Shazeer and Quoc V. Le},
year = {2021},
eprint = {2109.08668},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@article{chowdhery2022PaLM,
title = {PaLM: Scaling Language Modeling with Pathways},
author = {Chowdhery, Aakanksha et al},
year = {2022}
}
.\lucidrains\triton-transformer\setup.py
# 导入设置和查找包的函数
from setuptools import setup, find_packages
# 设置包的元数据
setup(
name = 'triton-transformer', # 包的名称
packages = find_packages(), # 查找所有包
version = '0.1.1', # 版本号
license='MIT', # 许可证
description = 'Transformer in Triton', # 描述
author = 'Phil Wang', # 作者
author_email = 'lucidrains@gmail.com', # 作者邮箱
url = 'https://github.com/lucidrains/triton-transformer', # 项目链接
keywords = [
'artificial intelligence', # 关键词
'attention mechanism', # 关键词
'transformers' # 关键词
],
install_requires=[
'einops', # 安装所需的依赖包
'torch>=1.6', # 安装所需的依赖包
'triton==1.0.1.dev20210924' # 安装所需的依赖包
],
classifiers=[
'Development Status :: 4 - Beta', # 分类器
'Intended Audience :: Developers', # 分类器
'Topic :: Scientific/Engineering :: Artificial Intelligence', # 分类器
'License :: OSI Approved :: MIT License', # 分类器
'Programming Language :: Python :: 3.6', # 分类器
],
)
.\lucidrains\triton-transformer\train.py
# 导入所需的库
from triton_transformer import Transformer
from triton_transformer.autoregressive_wrapper import AutoregressiveWrapper
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# 定义常量
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 512
# 定义辅助函数
# 从 token 解码为字符
def decode_token(token):
return str(chr(max(32, token)))
# 从 tokens 解码为字符串
def decode_tokens(tokens):
return ''.join(list(map(decode_token, tokens)))
# 实例化类似 GPT 的解码器模型
model = Transformer(
num_tokens = 256,
dim = 512,
max_seq_len = SEQ_LEN,
depth = 8,
heads = 8,
causal = True,
use_triton = True,
attn_dropout = 0.1,
ff_dropout = 0.1,
)
model = AutoregressiveWrapper(model)
model.cuda()
# 准备 enwik8 数据
with gzip.open('./data/enwik8.gz') as file:
X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
trX, vaX = np.split(X, [int(90e6)])
data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
# 定义文本采样数据集类
class TextSamplerDataset(Dataset):
def __init__(self, data, seq_len):
super().__init__()
self.data = data
self.seq_len = seq_len
def __getitem__(self, index):
rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
return full_seq.cuda()
def __len__(self):
return self.data.size(0) // self.seq_len
# 创建训练集和验证集的数据加载器
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
# 定义优化器
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# 训练模型
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
model.train()
for __ in range(GRADIENT_ACCUMULATE_EVERY):
loss = model(next(train_loader))
loss.backward()
print(f'training loss: {loss.item()}')
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optim.step()
optim.zero_grad()
if i % VALIDATE_EVERY == 0:
model.eval()
with torch.no_grad():
loss = model(next(val_loader))
print(f'validation loss: {loss.item()}')
if i % GENERATE_EVERY == 0:
model.eval()
inp = random.choice(val_dataset)[:-1]
prime = decode_tokens(inp)
print(f'%s \n\n %s', (prime, '*' * 100))
sample = model.generate(inp[None, ...], GENERATE_LENGTH)
output_str = decode_tokens(sample[0])
print(output_str)
.\lucidrains\triton-transformer\triton_transformer\autoregressive_wrapper.py
import torch
from torch import nn
import torch.nn.functional as F
# helper function
# 检查值是否存在的辅助函数
def exists(val):
return val is not None
# 评估装饰器函数
def eval_decorator(fn):
def inner(model, *args, **kwargs):
# 保存模型原始训练状态
was_training = model.training
# 将模型设置为评估模式
model.eval()
# 调用传入的函数
out = fn(model, *args, **kwargs)
# 恢复模型原始训练状态
model.train(was_training)
return out
return inner
# top k filtering
# 根据阈值过滤 logits 中的 top k 值
def top_k(logits, thres = 0.9):
# 计算 top k 的数量
k = int((1 - thres) * logits.shape[-1])
# 获取 top k 的值和索引
val, ind = torch.topk(logits, k)
# 创建与 logits 相同形状的全为负无穷的张量
probs = torch.full_like(logits, float('-inf'))
# 根据索引将 top k 的值填充到 probs 中
probs.scatter_(1, ind, val)
return probs
# 自回归包装器类
class AutoregressiveWrapper(nn.Module):
def __init__(self, net, pad_value = 0):
super().__init__()
self.pad_value = pad_value
self.net = net
self.max_seq_len = net.max_seq_len
# 生成序列的方法
@torch.no_grad()
@eval_decorator
def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_thres = 0.9, **kwargs):
# 获取起始 tokens 的形状和设备信息
b, t, device = *start_tokens.shape, start_tokens.device
out = start_tokens
for _ in range(seq_len):
# 获取最后 self.max_seq_len 个 token
x = out[:, -self.max_seq_len:]
# 获取模型预测的 logits
logits = self.net(x, **kwargs)[:, -1, :]
# 过滤 logits 中的 top k 值
filtered_logits = top_k(logits, thres = filter_thres)
# 计算 softmax 温度调节后的概率
probs = F.softmax(filtered_logits / temperature, dim=-1)
# 从概率分布中采样一个 token
sample = torch.multinomial(probs, 1)
# 将采样的 token 添加到输出序列中
out = torch.cat((out, sample), dim=-1)
if exists(eos_token):
# 检查是否存在 eos_token
is_eos_token = (out == eos_token)
if is_eos_token.any(dim = -1).all():
# 如果所有序列中都存在 eos_token,则停止生成
# 创建一个向右移动一位�� eos_token mask
shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
# 创建一个 mask,标记 eos_token 后的所有位置
mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
# 将 mask 标记的位置填充为 pad_value
out = out.masked_fill(mask, self.pad_value)
break
# 去除起始 tokens,返回生成的序列
out = out[:, t:]
return out
# 前向传播方法
def forward(self, x, **kwargs):
# 将输入拆分为输入和标签
x_inp, x_labels = x[:, :-1], x[:, 1:]
return self.net(x_inp, labels = x_labels, **kwargs)
.\lucidrains\triton-transformer\triton_transformer\bmm.py
# 导入 torch 库
import torch
# 从 torch 库中导入 autograd 模块
from torch import autograd
# 从 torch.nn.functional 模块中导入 F 函数
import torch.nn.functional as F
# 从 triton_transformer.utils 模块中导入 calc_num_warps 和 exists 函数
from triton_transformer.utils import calc_num_warps, exists
# 导入 triton 库
import triton
# 从 triton.language 模块中导入 tl
import triton.language as tl
# 使用 triton.autotune 装饰器,配置自动调优参数
@triton.autotune(
configs=[
# 配置不同的参数组合
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32 , 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
],
key=['M', 'N', 'K'],
)
# 使用 triton.jit 装饰器,编译 bmm_kernel 函数
@triton.jit
def bmm_kernel(
x_ptr, y_ptr, o_ptr,
M, N, K,
stride_al, stride_am, stride_ak,
stride_bl, stride_bk, stride_bn,
stride_ol, stride_om, stride_on,
**meta,
):
# 定义常量
BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
BLOCK_SIZE_K = meta['BLOCK_SIZE_K']
GROUP_SIZE_M = 8
# 计算程序 ID
pid_batch = tl.program_id(0)
pid = tl.program_id(1)
# 计算分组数量
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# 计算偏移量
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
x_ptrs = x_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak + pid_batch*stride_al)
y_ptrs = y_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn + pid_batch*stride_bl)
# 初始化输出矩阵 o
o = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# 循环计算矩阵乘法
for k in range(0, K, BLOCK_SIZE_K):
x = tl.load(x_ptrs)
y = tl.load(y_ptrs)
o += tl.dot(x, y)
x_ptrs += BLOCK_SIZE_K * stride_ak
y_ptrs += BLOCK_SIZE_K * stride_bk
# 如果存在激活函数,则应用激活函数
if exists(meta['ACTIVATION']):
o = meta['ACTIVATION'](o)
# 计算偏移量
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
# 创建掩码
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
# 计算输出指针
o_ptrs = o_ptr + stride_om * offs_m[:, None] + stride_on * offs_n[None, :] + stride_ol * pid_batch
# 存储结果到输出指针
tl.store(o_ptrs, o, mask=mask)
# 定义 triton_bmm 函数
def triton_bmm(x, y, activation = None):
# 获取 x 的形状信息
B, M, K = x.shape
# 如果 y 的维度为 2,则扩展维度
if y.ndim == 2:
y = y.unsqueeze(0).expand(B, -1, -1)
# 获取 y 的形状信息
_, K, N = y.shape
# 断言 K 必须能被 32 整除
assert (K % 32 == 0), "K must be divisible by 32"
# 创建输出张量 o
o = torch.empty((B, M, N), device = x.device, dtype = x.dtype)
# 定义 grid 函数
grid = lambda META: (
B, triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
)
# 调用 bmm_kernel 函数
bmm_kernel[grid](
x, y, o,
M, N, K,
x.stride(0), x.stride(1), x.stride(2),
y.stride(0), y.stride(1), y.stride(2),
o.stride(0), o.stride(1), o.stride(2),
ACTIVATION = activation
)
# 返回结果张量 o
return o
# 使用 triton.jit 装饰器,编译 relu_squared_activation 函数
@triton.jit
def relu_squared_activation(x):
return tl.where(x > 0, x * x, 0.)
# 定义 _relu_squared 类
class _relu_squared(autograd.Function):
# 前向传播函数
@classmethod
def forward(self, ctx, x, w):
# 调用 triton_bmm 函数,应用 relu_squared_activation 激活函数
o = triton_bmm(x, w, activation = relu_squared_activation)
# 如果 x 需要梯度,则保存相关信息
if x.requires_grad:
ctx.save_for_backward(x, w, o)
return o
@classmethod
# 反向传播函数,接收上下文和梯度作为输入
def backward(self, ctx, dy):
# 从上下文中获取保存的张量 x, w, o
x, w, o = ctx.saved_tensors
# 计算 dy 乘以 o 的平方根乘以 2,得到新的梯度 dy
dy = torch.sqrt(o) * 2 * dy
# 计算 dy 与权重 w 的转置的矩阵乘积,得到输入 x 的梯度 dx
dx = triton_bmm(dy, w.t())
# 计算输入 x 的转置与梯度 dy 的矩阵乘积,得到权重 w 的梯度 dw
dw = triton_bmm(x.transpose(-1, -2), dy)
# 返回输入 x 和权重 w 的梯度
return dx, dw
# 将 _relu_squared.apply 赋值给 triton_relu_squared,用于后续调用
triton_relu_squared = _relu_squared.apply
# 定义一个融合了 ReLU 和平方操作的函数
def fused_relu_squared(x, w, use_triton = False):
# 如果 use_triton 为 True,则调用 triton_relu_squared 函数
if use_triton:
return triton_relu_squared(x, w)
# 如果 use_triton 为 False,则计算 x @ w 的矩阵乘法结果,然后对结果进行 ReLU 和平方操作
return F.relu(x @ w) ** 2
.\lucidrains\triton-transformer\triton_transformer\cross_entropy.py
# 导入 torch 库
import torch
# 导入 torch 中的函数库
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 函数
from einops import rearrange
# 导入 triton 库
import triton
# 从 triton 库中导入 language 模块并重命名为 tl
import triton.language as tl
# 定义交叉熵损失函数,接受 logits(预测值)、labels(真实标签)、ignore_index(忽略的索引,默认为0)、use_triton(是否使用 triton 加速,默认为 False)
def cross_entropy_fn(logits, labels, ignore_index = 0., use_triton = False):
# 重新排列 logits 张量的维度,将 'b n c' 转换为 '(b n) c'
logits = rearrange(logits, 'b n c -> (b n) c')
# 重新排列 labels 张量的维度,将 'b n' 转换为 '(b n)'
labels = rearrange(labels, 'b n -> (b n)')
# 如果 use_triton 为 True,则使用 triton 库中的 cross_entropy 函数计算损失
if use_triton:
loss = triton.ops.cross_entropy(logits, labels)
# 否则使用 torch.nn.functional 库中的 cross_entropy 函数计算损失
else:
loss = F.cross_entropy(logits, labels, reduction = 'none')
# 创建一个掩码,标记 labels 中不等于 ignore_index 的位置
mask = (labels != ignore_index)
# 返回经过掩码处理后的损失的均值
return loss[mask].mean()
.\lucidrains\triton-transformer\triton_transformer\dropout.py
# 导入所需的库
import torch
from torch import autograd
import torch.nn.functional as F
import triton
import triton.language as tl
from random import randrange
# 定义常量 BLOCK_SIZE
BLOCK_SIZE = 1024
# Triton JIT 编译的函数,实现带有随机种子的 dropout 操作
@triton.jit
def _seeded_dropout(x_ptr, output_ptr, n_elements, p, seed, **meta):
BLOCK_SIZE = meta['BLOCK_SIZE']
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE * 4
off0 = block_start + BLOCK_SIZE * 0 + tl.arange(0, BLOCK_SIZE)
off1 = block_start + BLOCK_SIZE * 1 + tl.arange(0, BLOCK_SIZE)
off2 = block_start + BLOCK_SIZE * 2 + tl.arange(0, BLOCK_SIZE)
off3 = block_start + BLOCK_SIZE * 3 + tl.arange(0, BLOCK_SIZE)
mask0 = off0 < n_elements
mask1 = off1 < n_elements
mask2 = off2 < n_elements
mask3 = off3 < n_elements
x0 = tl.load(x_ptr + off0, mask = mask0)
x1 = tl.load(x_ptr + off1, mask = mask1)
x2 = tl.load(x_ptr + off2, mask = mask2)
x3 = tl.load(x_ptr + off3, mask = mask3)
r0, r1, r2, r3 = tl.random.rand4x(seed, off0)
keep0, keep1, keep2, keep3 = r0 > p, r1 > p, r2 > p, r3 > p
o0 = tl.where(keep0, x0 / (1 - p), 0.0)
o1 = tl.where(keep1, x1 / (1 - p), 0.0)
o2 = tl.where(keep2, x2 / (1 - p), 0.0)
o3 = tl.where(keep3, x3 / (1 - p), 0.0)
tl.store(output_ptr + off0, o0, mask = mask0)
tl.store(output_ptr + off1, o1, mask = mask1)
tl.store(output_ptr + off2, o2, mask = mask2)
tl.store(output_ptr + off3, o3, mask = mask3)
# 带有随机种子的 dropout 操作的包装函数
def seeded_dropout(x, p, seed):
output = torch.empty_like(x)
n_elements = x.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE'] * 4),)
_seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE = BLOCK_SIZE)
return output
# 自定义 autograd.Function 类,实现 dropout 操作
class dropout_(autograd.Function):
@classmethod
def forward(cls, ctx, x, p):
seed = randrange(int(1e6))
ctx.p = p
ctx.seed = seed
return seeded_dropout(x, p, seed)
@classmethod
def backward(cls, ctx, dy):
p = ctx.p
seed = ctx.seed
return seeded_dropout(dy, p, seed), None
# dropout 操作的函数,根据 use_triton 参数选择使用 Triton 实现的 dropout 还是 PyTorch 自带的 dropout
def dropout_fn(x, p, use_triton = False):
if p == 0. or not x.requires_grad:
return x
if not use_triton:
return F.dropout(x, p, training = True)
return dropout_.apply(x, p)
.\lucidrains\triton-transformer\triton_transformer\layernorm.py
# 导入 torch 库
import torch
# 从 torch 库中导入 autograd 模块
from torch import autograd
# 从 torch 库中导入 functional 模块
import torch.nn.functional as F
# 导入 triton 库
import triton
# 从 triton 库中导入 language 模块并重命名为 tl
import triton.language as tl
# 从 triton_transformer.utils 模块中导入 calc_num_warps 和 exists 函数
from triton_transformer.utils import calc_num_warps, exists
# 定义 GAMMA_BLOCK_SIZE 常量为 64
GAMMA_BLOCK_SIZE = 64
# 定义 GAMMA_ROW_BLOCK_SIZE 常量为 64
GAMMA_ROW_BLOCK_SIZE = 64
# 定义 layernorm_kernel_forward_training 函数
@triton.jit
def layernorm_kernel_forward_training(
output_ptr,
mean_centered_ptr,
normed_ptr,
input_ptr,
gamma_ptr,
input_row_stride,
gamma_row_stride,
output_row_stride,
mean_centered_row_stride,
normed_row_stride,
n_cols,
stable,
eps,
**meta
):
# 获取当前程序的 ID
row_idx = tl.program_id(0)
# 从 meta 中获取 BLOCK_SIZE 常量
BLOCK_SIZE = meta['BLOCK_SIZE']
# 计算当前行的起始指针
row_start_ptr = input_ptr + row_idx * input_row_stride
# 计算当前行 gamma 的起始指针
gamma_row_start_ptr = gamma_ptr + row_idx * gamma_row_stride
# 生成列偏移量
col_offsets = tl.arange(0, BLOCK_SIZE)
# 计算当前行的输入指针
input_ptrs = row_start_ptr + col_offsets
# 计算当前行的 gamma 指针
gamma_ptrs = gamma_row_start_ptr + col_offsets
# 创建一个掩码,用于处理列偏移量小于 n_cols 的情况
mask = col_offsets < n_cols
# 从输入指针处加载数据到 row,如果掩码为 False,则加载 0.0
row = tl.load(input_ptrs, mask=mask, other=0.)
# 从 gamma 指针处加载数据到 gammas,如果掩码为 False,则加载 0.0
gammas = tl.load(gamma_ptrs, mask=mask, other=0.)
# 如果 stable 为 True
if stable:
# 计算当前行的最大值
row_max = tl.max(tl.where(mask, row, float('-inf')), axis=0)
# 对当前行进行归一化
row /= row_max
# 计算当前行的均值
row_mean = tl.sum(row, axis=0) / n_cols
# 计算当前行的中心化值
row_mean_centered = tl.where(mask, row - row_mean, 0.)
# 计算当前行的方差
row_var = tl.sum(row_mean_centered * row_mean_centered, axis=0) / n_cols
# 计算当前行的标准差的倒数
inv_var = 1. / tl.sqrt(row_var + eps)
# 计算当前行的归一化值
normed = row_mean_centered * inv_var
# 计算输出值
output = normed * gammas
# 计算输出行的起始指针
output_row_start_ptr = output_ptr + row_idx * output_row_stride
# 计算输出指针
output_ptrs = output_row_start_ptr + col_offsets
# 将输出值存储到输出指针处
tl.store(output_ptrs, output, mask=mask)
# 计算中心化行的起始指针
mean_centered_row_start_ptr = mean_centered_ptr + row_idx * mean_centered_row_stride
# 计算中心化指针
mean_centered_ptrs = mean_centered_row_start_ptr + col_offsets
# 将中心化值存储到中心化指针处
tl.store(mean_centered_ptrs, row_mean_centered, mask=mask)
# 计算归一化行的起始指针
normed_row_start_ptr = normed_ptr + row_idx * normed_row_stride
# 计算归一化指针
normed_ptrs = normed_row_start_ptr + col_offsets
# 将归一化值存储到归一化指针处
tl.store(normed_ptrs, normed, mask=mask)
# 定义 layernorm_kernel_forward_inference 函数
@triton.jit
def layernorm_kernel_forward_inference(
output_ptr,
input_ptr,
gamma_ptr,
input_row_stride,
gamma_row_stride,
output_row_stride,
n_cols,
stable,
eps,
**meta
):
# 获取当前程序的 ID
row_idx = tl.program_id(0)
# 从 meta 中获取 BLOCK_SIZE 常量
BLOCK_SIZE = meta['BLOCK_SIZE']
# 计算当前行的起始指针
row_start_ptr = input_ptr + row_idx * input_row_stride
# 计算当前行 gamma 的起始指针
gamma_row_start_ptr = gamma_ptr + row_idx * gamma_row_stride
# 生成列偏移量
col_offsets = tl.arange(0, BLOCK_SIZE)
# 计算当前行的输入指针
input_ptrs = row_start_ptr + col_offsets
# 计算当前行的 gamma 指针
gamma_ptrs = gamma_row_start_ptr + col_offsets
# 创建一个掩码,用于处理列偏移量小于 n_cols 的情况
mask = col_offsets < n_cols
# 从输入指针处加载数据到 row,如果掩码为 False,则加载 0.0
row = tl.load(input_ptrs, mask=mask, other=0.)
# 从 gamma 指针处加载数据到 gammas,如果掩码为 False,则加载 0.0
gammas = tl.load(gamma_ptrs, mask=mask, other=0.)
# 如果 stable 为 True
if stable:
# 计算当前行的最大值
row_max = tl.max(tl.where(mask, row, float('-inf')), axis=0)
# 对当前行进行归一化
row /= row_max
# 计算当前行的均值
row_mean = tl.sum(row, axis=0) / n_cols
# 计算当前行的中心化值
row_mean_centered = tl.where(mask, row - row_mean, 0.)
# 计算当前行的方差
row_var = tl.sum(row_mean_centered * row_mean_centered, axis=0) / n_cols
# 计算当前行的标准差的倒数
inv_var = 1. / tl.sqrt(row_var + eps)
# 计算当前行的归一化值
normed = row_mean_centered * inv_var
# 计算输出值
output = normed * gammas
# 计算输出行的起始指针
output_row_start_ptr = output_ptr + row_idx * output_row_stride
# 计算输出指针
output_ptrs = output_row_start_ptr + col_offsets
# 将输出值存储到输出指针处
tl.store(output_ptrs, output, mask=mask)
# 定义 layernorm_kernel_backward 函数
@triton.jit
def layernorm_kernel_backward(
output_ptr,
dy_ptr,
mean_centered_ptr,
output_row_stride,
dy_row_stride,
mean_centered_row_stride,
n_cols,
eps,
**meta
):
# 获取当前程序的 ID
row_idx = tl.program_id(0)
# 从 meta 中获取 BLOCK_SIZE 常量
BLOCK_SIZE = meta['BLOCK_SIZE']
# 计算当前行的 dy 起始指针
dy_row_start_ptr = dy_ptr + row_idx * dy_row_stride
# 计算当前行的中心化值起始指针
mean_centered_row_start_ptr = mean_centered_ptr + row_idx * mean_centered_row_stride
# 生成列偏移量
col_offsets = tl.arange(0, BLOCK_SIZE)
# 计算当前行的 dy 指针
dy_ptrs = dy_row_start_ptr + col_offsets
# 计算当前行的中心化值指针
mean_centered_ptrs = mean_centered_row_start_ptr + col_offsets
# 创建一个掩码,用于处理列偏移量小于 n_cols 的情况
mask = col_offsets < n_cols
# 从 dy 指针处加载数据到 dy,如果掩码为 False,则加载 0.0
dy = tl.load(dy_ptrs, mask=mask, other=0.)
# 从中心化值指针处加载数据到 mean_centered,如果掩码为 False,则加载 0.0
mean_centered = tl.load(mean_centered_ptrs, mask=mask, other=0.)
# 计算每行的方差
row_var = tl.sum(mean_centered * mean_centered, axis=0) / n_cols
# 计算每行的标准差的倒数
inv_var = 1. / tl.sqrt(row_var + eps)
# 对数据进行标准化处理
normed = mean_centered * inv_var
# 计算输出值
output = 1. / n_cols * inv_var * (n_cols * dy - tl.sum(dy, axis=0) - normed * tl.sum(dy * normed, axis=0))
# 计算输出行的起始指针
output_row_start_ptr = output_ptr + row_idx * output_row_stride
# 计算输出指针数组
output_ptrs = output_row_start_ptr + col_offsets
# 存储输出数据到指定的指针位置,使用掩码进行过滤
tl.store(output_ptrs, output, mask=mask)
# 定义一个使用 Triton JIT 编译的函数,用于计算 LayerNorm 操作的 gamma 反向传播
def layernorm_gamma_kernel_backward(
dgamma_ptr, # 存储计算得到的 dgamma 结果的指针
norm_ptr, # 存储 norm 数据的指针
dy_ptr, # 存储 dy 数据的指针
norm_stride, # norm 数据的步长
dy_stride, # dy 数据的步长
dgamma_row_stride, # dgamma 行步长
n_rows, # 数据行数
n_cols, # 数据列数
**meta # 其他元数据
):
# 获取当前程序的列索引和行索引
col_idx = tl.program_id(0)
row_idx = tl.program_id(1)
# 从元数据中获取 BLOCK_SIZE 和 ROW_BLOCK_SIZE
BLOCK_SIZE = meta['BLOCK_SIZE']
ROW_BLOCK_SIZE = meta['BLOCK_SIZE_ROW']
# 创建列偏移量和行偏移量
col_offsets = tl.arange(0, BLOCK_SIZE)
row_offsets = tl.arange(0, ROW_BLOCK_SIZE)
# 计算列范围和行范围
col_range = col_idx * BLOCK_SIZE + col_offsets
row_range = row_idx * ROW_BLOCK_SIZE + row_offsets
# 创建列掩码
col_mask = col_range < n_cols
# 创建掩码,用于过滤超出数据范围的行列
mask = (row_range < n_rows)[:, None] & col_mask[None, :]
# 更新 dy_ptr 和 norm_ptr 指针位置
dy_ptr += row_range[:, None] * dy_stride + col_range[None, :]
norm_ptr += row_range[:, None] * norm_stride + col_range[None, :]
# 从指定位置加载 dy 和 norm 数据
dy = tl.load(dy_ptr, mask=mask, other=0.)
norm = tl.load(norm_ptr, mask=mask, other=0.)
# 计算 dgamma
dgamma = tl.sum(dy * norm, axis=0)
# 更新 dgamma_ptr 指针位置
dgamma_ptr += row_idx * dgamma_row_stride + col_range
# 存储计算得到的 dgamma 结果
tl.store(dgamma_ptr, dgamma, mask=col_mask)
# 定义一个 autograd 函数 _layernorm
class _layernorm(autograd.Function):
@classmethod
def forward(cls, ctx, x, gamma, eps, stable):
# 获取输入 x 的形状和维度
shape = x.shape
dim = shape[-1]
x = x.view(-1, dim)
n_rows, n_cols = x.shape
# 扩展 gamma 到与 x 相同的形状
expanded_gamma = gamma[None, :].expand(n_rows, -1)
# 计算 BLOCK_SIZE 和 num_warps
BLOCK_SIZE = triton.next_power_of_2(n_cols)
num_warps = calc_num_warps(BLOCK_SIZE)
# 创建一个与 x 相同形状的输出张量
out = torch.empty_like(x)
# 保存 eps 到上下文中
ctx.eps = eps
if x.requires_grad:
# 创建 scaled_x 和 normed 张量
scaled_x = torch.empty_like(x)
normed = torch.empty_like(x)
# 调用 layernorm_kernel_forward_training 函数进行前向传播计算
layernorm_kernel_forward_training[(n_rows,)](
out,
scaled_x,
normed,
x,
expanded_gamma,
x.stride(0),
expanded_gamma.stride(0),
out.stride(0),
scaled_x.stride(0),
normed.stride(0),
n_cols,
stable,
eps,
num_warps=num_warps,
BLOCK_SIZE=BLOCK_SIZE,
)
# 保存 scaled_x, gamma, out 到上下文中
ctx.save_for_backward(scaled_x, gamma, out)
else:
# 调用 layernorm_kernel_forward_inference 函数进行前向传播计算(无梯度)
layernorm_kernel_forward_inference[(n_rows,)](
out,
x,
expanded_gamma,
x.stride(0),
expanded_gamma.stride(0),
out.stride(0),
n_cols,
stable,
eps,
num_warps=num_warps,
BLOCK_SIZE=BLOCK_SIZE,
)
# 返回输出张量,并恢复原始形状
return out.view(*shape)
@classmethod
def backward(cls, ctx, dy):
# 获取 dy 的形状和设备信息
shape, device = dy.shape, dy.device
dim = shape[-1]
dy = dy.view(-1, dim)
# 从上下文中获取保存的 scaled_x, gamma, normed 张量
scaled_x, gamma, normed = ctx.saved_tensors
n_rows, n_cols = dy.shape
# 计算 num_col_programs 和 num_row_programs
num_col_programs = triton.cdiv(n_cols, GAMMA_BLOCK_SIZE)
num_row_programs = triton.cdiv(n_rows, GAMMA_ROW_BLOCK_SIZE)
# 创建一个用于存储 dgamma 的张量
dgamma = torch.empty((num_row_programs, n_cols), device=device)
# 调用 layernorm_gamma_kernel_backward 函数进行 gamma 反向传播计算
layernorm_gamma_kernel_backward[(num_col_programs, num_row_programs)](
dgamma,
normed,
dy,
normed.stride(0),
dy.stride(0),
dgamma.stride(0),
n_rows,
n_cols,
num_warps=4,
BLOCK_SIZE=GAMMA_BLOCK_SIZE,
BLOCK_SIZE_ROW=GAMMA_ROW_BLOCK_SIZE
)
# 对 dgamma 沿指定维度求和
dgamma = dgamma.sum(dim=0)
# 计算 dxhat 和 dx
dxhat = dy * gamma
dx = torch.empty_like(dy)
# 计算 BLOCK_SIZE 和 num_warps
BLOCK_SIZE = triton.next_power_of_2(n_cols)
num_warps = calc_num_warps(BLOCK_SIZE)
# 调用 layernorm_kernel_backward 函数进行反向传播计算
layernorm_kernel_backward[(n_rows,)](
dx,
dxhat,
scaled_x,
dx.stride(0),
dxhat.stride(0),
scaled_x.stride(0),
n_cols,
ctx.eps,
num_warps=num_warps,
BLOCK_SIZE=BLOCK_SIZE,
)
# 恢复原始形状并返回 dx, dgamma
dx = dx.view(*shape)
return dx, dgamma, None, None
# 对输入数据进行 Layer Normalization 处理
def layernorm(x, gamma, eps = 1e-5, use_triton = False, stable = False):
# 如果使用 Triton 加速库
if use_triton:
# 调用 Triton 提供的 Layer Normalization 函数
out = _layernorm.apply(x, gamma, eps, stable)
else:
# 如果不使用 Triton 加速库
if stable:
# 对输入数据进行稳定处理,将每个元素除以最大值
x = x / torch.amax(x, dim = -1, keepdim = True)
# 使用 PyTorch 提供的 Layer Normalization 函数
out = F.layer_norm(x, (x.shape[-1],), gamma, torch.zeros_like(gamma), eps = eps)
# 返回处理后的数据
return out
.\lucidrains\triton-transformer\triton_transformer\softmax.py
# 导入 torch 库
import torch
# 从 torch 库中导入 autograd 模块
from torch import autograd
# 从 torch.nn.functional 模块中导入 F 函数
import torch.nn.functional as F
# 导入 triton 库
import triton
# 从 triton.language 模块中导入 tl
import triton.language as tl
# 从 triton_transformer.utils 模块中导入 calc_num_warps 函数
from triton_transformer.utils import calc_num_warps
# 定义 softmax_kernel_forward 函数,使用 triton.jit 装饰器
@triton.jit
def softmax_kernel_forward(
output_ptr,
input_ptr,
input_row_stride,
output_row_stride,
n_cols,
causal,
**meta
):
# 获取当前程序的行索引
row_idx = tl.program_id(0)
# 获取 meta 字典中的 BLOCK_SIZE 值
BLOCK_SIZE = meta['BLOCK_SIZE']
# 计算当前行的起始指针
row_start_ptr = input_ptr + row_idx * input_row_stride
# 生成列偏移量
col_offsets = tl.arange(0, BLOCK_SIZE)
# 计算输入指针
input_ptrs = row_start_ptr + col_offsets
# 创建掩码,用于处理超出列数的情况
mask = col_offsets < n_cols
# 从输入指针加载数据到 row 变量,处理超出列数的情况
row = tl.load(input_ptrs, mask = mask, other = -float('inf'))
# 如果是因果的情况,进行处理
if causal:
causal_mask = col_offsets > (row_idx % n_cols)
row = row + tl.where(causal_mask, -float('inf'), 0.)
# 计算 row 减去最大值
row_minus_max = row - tl.max(row, axis=0)
# 计算 softmax 的分子
numerator = tl.exp(row_minus_max)
# 计算 softmax 的分母
denominator = tl.sum(numerator, axis=0)
# 计算 softmax 输出
softmax_output = numerator / denominator
# 计算输出行的起始指针
output_row_start_ptr = output_ptr + row_idx * output_row_stride
# 计算输出指针
output_ptrs = output_row_start_ptr + col_offsets
# 存储 softmax 输出到输出指针,处理超出列数的情况
tl.store(output_ptrs, softmax_output, mask = mask)
# 定义 softmax_kernel_backward 函数,使用 triton.jit 装饰器
@triton.jit
def softmax_kernel_backward(
output_ptr,
input_ptr,
grad_ptr,
grad_row_stride,
input_row_stride,
output_row_stride,
n_cols,
**meta
):
# 获取当前程序的行索引
row_idx = tl.program_id(0)
# 获取 meta 字典中的 BLOCK_SIZE 值
BLOCK_SIZE = meta['BLOCK_SIZE']
# 计算当前行的起始指针
row_start_ptr = input_ptr + row_idx * input_row_stride
grad_row_start_ptr = grad_ptr + row_idx * grad_row_stride
# 生成列偏移量
col_offsets = tl.arange(0, BLOCK_SIZE)
# 计算输入指针和梯度指���
input_ptrs = row_start_ptr + col_offsets
grad_ptrs = grad_row_start_ptr + col_offsets
# 创建掩码,用于处理超出列数的情况
mask = col_offsets < n_cols
# 从输入指针加载数据到 probs_row 变量,处理超出列数的情况
probs_row = tl.load(input_ptrs, mask = mask, other = 0.)
# 从梯度指针加载数据到 grad_row 变量,处理超出列数的情况
grad_row = tl.load(grad_ptrs, mask = mask, other = 0.)
# 计算 dxhat
dxhat = probs_row * grad_row
# 计算 softmax 梯度输出
softmax_grad_output = dxhat - probs_row * tl.sum(dxhat, axis = 0)
# 计算输出行的起始指针
output_row_start_ptr = output_ptr + row_idx * output_row_stride
# 计算输出指针
output_ptrs = output_row_start_ptr + col_offsets
# 存储 softmax 梯度输出到输出指针,处理超出列数的情况
tl.store(output_ptrs, softmax_grad_output, mask = mask)
# 定义 _softmax 类,继承自 autograd.Function
class _softmax(autograd.Function):
# 前向传播函数
@classmethod
def forward(self, ctx, x, causal):
# 获取输入张量的形状
shape = x.shape
# 将输入张量展平为二维张量
x = x.view(-1, shape[-1])
n_rows, n_cols = x.shape
# 计算 BLOCK_SIZE 和 num_warps
BLOCK_SIZE = triton.next_power_of_2(n_cols)
num_warps = calc_num_warps(BLOCK_SIZE)
# 创建与输入张量相同形状的空张量 y
y = torch.empty_like(x)
# 调用 softmax_kernel_forward 函数进行前向传播计算
softmax_kernel_forward[(n_rows,)](
y,
x,
x.stride(0),
y.stride(0),
n_cols,
causal,
num_warps = num_warps,
BLOCK_SIZE = BLOCK_SIZE,
)
# 如果输入张量需要梯度,保存 y 用于反向传播
if x.requires_grad:
ctx.save_for_backward(y)
return y.view(*shape)
# 反向传播函数
@classmethod
def backward(self, ctx, grad_probs):
# 获取梯度张量的形状
shape = grad_probs.shape
# 从上下文中获取保存的张量 probs
probs, = ctx.saved_tensors
# 将梯度张量展平为二维张量
grad_probs = grad_probs.view(-1, grad_probs.shape[-1])
n_rows, n_cols = grad_probs.shape
# 计算 BLOCK_SIZE 和 num_warps
BLOCK_SIZE = triton.next_power_of_2(n_cols)
num_warps = calc_num_warps(BLOCK_SIZE)
# 创建与 probs 张量相同形状的空张量 dx
dx = torch.empty_like(probs)
# 调用 softmax_kernel_backward 函数进行反向传播计算
softmax_kernel_backward[(n_rows,)](
dx,
probs,
grad_probs,
grad_probs.stride(0),
probs.stride(0),
dx.stride(0),
n_cols,
num_warps = num_warps,
BLOCK_SIZE = BLOCK_SIZE
)
# 返回 dx 和 None,None 表示不需要额外的梯度信息
return dx.view(*shape), None
# 定义 triton_softmax 函数,调用 _softmax 类的 apply 方法
triton_softmax = _softmax.apply
# 定义 softmax 函数,实现 softmax 操作
def softmax(x, causal = False, use_triton = False):
# 如果使用 triton 进行计算
if use_triton:
# 调用 triton_softmax 函数
return triton_softmax(x, causal)
else:
# 使用 PyTorch 的 F.softmax 函数
return F.softmax(x, dim = -1)
.\lucidrains\triton-transformer\triton_transformer\transformer.py
# 导入必要的库
from functools import partial
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange
# 导入自定义的模块
from triton_transformer.layernorm import layernorm
from triton_transformer.softmax import softmax
from triton_transformer.cross_entropy import cross_entropy_fn
from triton_transformer.bmm import fused_relu_squared
from triton_transformer.dropout import dropout_fn
from triton_transformer.utils import exists, default
# 定义类
class PreNormResidual(nn.Module):
def __init__(self, dim, fn, use_triton = False):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
self.use_triton = use_triton
def forward(self, x, **kwargs):
use_triton = kwargs.get('use_triton', self.use_triton)
normed = layernorm(x, self.norm.weight, use_triton = use_triton)
return self.fn(normed, **kwargs) + x
# 辅助类
class Attention(nn.Module):
def __init__(
self,
dim,
dim_head = 64,
heads = 8,
causal = False,
dropout = 0.,
use_triton = False
):
super().__init__()
self.use_triton = use_triton
self.heads = heads
self.scale = dim_head ** -0.5
self.causal = causal
inner_dim = dim_head * heads
self.dropout = dropout
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x, mask = None, use_triton = None):
use_triton = default(use_triton, self.use_triton)
h = self.heads
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
q = q * self.scale
sim = einsum('b i d, b j d -> b i j', q, k)
if exists(mask):
mask_value = -torch.finfo(sim.dtype).max
sim = sim.masked_fill(mask, mask_value)
attn = softmax(sim, causal = self.causal, use_triton = use_triton)
attn = dropout_fn(attn, self.dropout, use_triton = use_triton)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
out = self.to_out(out)
return dropout_fn(out, self.dropout, use_triton = use_triton)
class FeedForward(nn.Module):
def __init__(
self,
dim,
mult = 4,
dropout = 0.,
use_triton = False
):
super().__init__()
self.use_triton = use_triton
inner_dim = dim * mult
self.dropout = dropout
self.proj_in_weight = nn.Parameter(torch.randn(dim, inner_dim))
self.proj_out = nn.Linear(inner_dim, dim)
def forward(self, x, use_triton = None):
use_triton = default(use_triton, self.use_triton)
x = fused_relu_squared(x, self.proj_in_weight, use_triton = use_triton)
x = dropout_fn(x, self.dropout, use_triton = use_triton)
x = self.proj_out(x)
return x
# 主类
class Transformer(nn.Module):
def __init__(
self,
*,
dim,
num_tokens,
max_seq_len,
depth,
causal = False,
heads = 8,
dim_head = 64,
ff_dropout = 0.,
ff_mult = 4,
attn_dropout = 0.,
use_triton = False
):
# 调用父类的构造函数
super().__init__()
# 初始化最大序列长度
self.max_seq_len = max_seq_len
# 创建 token embedding 层
self.token_emb = nn.Embedding(num_tokens, dim)
# 创建位置 embedding 层
self.pos_emb = nn.Embedding(max_seq_len, dim)
# 初始化层列表
self.layers = nn.ModuleList([])
# 创建部分预归一化残差块
wrapper = partial(PreNormResidual, dim)
# 循环创建指定深度的注意力和前馈网络层
for _ in range(depth):
self.layers.append(nn.ModuleList([
wrapper(Attention(dim, heads = heads, dim_head = dim_head, causal = causal, dropout = attn_dropout, use_triton = use_triton)),
wrapper(FeedForward(dim, dropout = ff_dropout, mult = ff_mult, use_triton = use_triton))
]))
# 创建层归一化层
self.norm = nn.LayerNorm(dim)
# 创建输出层
self.to_logits = nn.Linear(dim, num_tokens)
# 创建掩码
self.use_triton = use_triton
self.causal = causal
# 根据是否自回归创建掩码
mask = torch.ones(max_seq_len, max_seq_len, dtype = torch.bool).triu(1) if causal else None
self.register_buffer('mask', mask, persistent = False)
def forward(
self,
x,
mask = None,
*,
labels = None,
use_triton = None
):
# 设置使用 Triton 加速的标志
use_triton = default(use_triton, self.use_triton)
# 获取序列长度和设备信息
n, device = x.shape[1], x.device
# 嵌入 token 并添加位置嵌入
x = self.token_emb(x)
pos_emb = self.pos_emb(torch.arange(n, device = device))
x = x + rearrange(pos_emb, 'n d -> () n d')
# 生成掩码,取决于是否自回归
assert not (self.causal and exists(mask)), 'mask is not needed during autoregressive mode'
if self.causal and not use_triton:
mask = self.mask[:n, :n]
mask = rearrange(mask, 'i j -> () i j')
elif not self.causal and exists(mask):
mask = rearrange(mask, 'b i -> b i ()') * rearrange(mask, 'b j -> b () j')
mask = ~mask
# 通过层
for attn, ff in self.layers:
x = attn(x, mask = mask, use_triton = use_triton)
x = ff(x, use_triton = use_triton)
# 进行层归一化
x = layernorm(x, self.norm.weight, use_triton = use_triton, stable = True)
# 计算 logits
logits = self.to_logits(x)
if not exists(labels):
return logits
# 计算损失
loss = cross_entropy_fn(logits, labels, ignore_index = 0, use_triton = use_triton)
return loss
.\lucidrains\triton-transformer\triton_transformer\utils.py
# 检查值是否不为 None
def exists(val):
return val is not None
# 如果值存在,则返回该值,否则返回默认值
def default(val, d):
return val if exists(val) else d
# 根据块大小计算 warp 数量
def calc_num_warps(block_size):
# 默认 warp 数量为 4
num_warps = 4
# 如果块大小大于等于 2048,则 warp 数量为 8
if block_size >= 2048:
num_warps = 8
# 如果块大小大于等于 4096,则 warp 数量为 16
if block_size >= 4096:
num_warps = 16
# 返回 warp 数量
return num_warps
.\lucidrains\triton-transformer\triton_transformer\__init__.py
# 从 triton_transformer.transformer 模块中导入 Transformer 类
from triton_transformer.transformer import Transformer
Uformer - Pytorch
Implementation of Uformer, Attention-based Unet, in Pytorch. It will only offer the concat-cross-skip connection.
This repository will be geared towards use in a project for learning protein structures. Specifically, it will include the ability to condition on time steps (needed for DDPM), as well as 2d relative positional encoding using rotary embeddings (instead of the bias on the attention matrix in the paper).
Install
$ pip install uformer-pytorch
Usage
import torch
from uformer_pytorch import Uformer
model = Uformer(
dim = 64, # initial dimensions after input projection, which increases by 2x each stage
stages = 4, # number of stages
num_blocks = 2, # number of transformer blocks per stage
window_size = 16, # set window size (along one side) for which to do the attention within
dim_head = 64,
heads = 8,
ff_mult = 4
)
x = torch.randn(1, 3, 256, 256)
pred = model(x) # (1, 3, 256, 256)
To condition on time for DDPM training
import torch
from uformer_pytorch import Uformer
model = Uformer(
dim = 64,
stages = 4,
num_blocks = 2,
window_size = 16,
dim_head = 64,
heads = 8,
ff_mult = 4,
time_emb = True # set this to true
)
x = torch.randn(1, 3, 256, 256)
time = torch.arange(1)
pred = model(x, time = time) # (1, 3, 256, 256)
Citations
@misc{wang2021uformer,
title = {Uformer: A General U-Shaped Transformer for Image Restoration},
author = {Zhendong Wang and Xiaodong Cun and Jianmin Bao and Jianzhuang Liu},
year = {2021},
eprint = {2106.03106},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
.\lucidrains\uformer-pytorch\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# 设置包的元数据
setup(
name = 'uformer-pytorch', # 包的名称
packages = find_packages(), # 查找并包含所有包
version = '0.0.8', # 版本号
license='MIT', # 许可证信息
description = 'Uformer - Pytorch', # 描述
author = 'Phil Wang', # 作者
author_email = 'lucidrains@gmail.com', # 作者邮箱
url = 'https://github.com/lucidrains/uformer-pytorch', # 项目链接
keywords = [ # 关键词列表
'artificial intelligence',
'deep learning',
'transformers',
'image segmentation',
'unet'
],
install_requires=[ # 安装依赖列表
'einops>=0.3',
'torch>=1.6'
],
classifiers=[ # 分类器列表
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\uformer-pytorch\uformer_pytorch\uformer_pytorch.py
# 导入 math 模块
import math
# 从 math 模块导入 log, pi, sqrt 函数
from math import log, pi, sqrt
# 从 functools 模块导入 partial 函数
from functools import partial
# 导入 torch 模块
import torch
# 从 torch 模块导入 nn, einsum 函数
from torch import nn, einsum
# 从 torch.nn 模块导入 functional 模块
import torch.nn.functional as F
# 导入 einops 模块中的 rearrange, repeat 函数
from einops import rearrange, repeat
# 定义常量 List 为 nn.ModuleList 类
List = nn.ModuleList
# 辅助函数
# 判断变量是否存在的函数
def exists(val):
return val is not None
# 返回默认值的函数
def default(val, d):
return val if exists(val) else d
# 将变量转换为元组的函数
def cast_tuple(val, depth = 1):
return val if isinstance(val, tuple) else (val,) * depth
# 位置嵌入
# 应用旋转位置嵌入的函数
def apply_rotary_emb(q, k, pos_emb):
sin, cos = pos_emb
dim_rotary = sin.shape[-1]
(q, q_pass), (k, k_pass) = map(lambda t: (t[..., :dim_rotary], t[..., dim_rotary:]), (q, k))
q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
q, k = map(lambda t: torch.cat(t, dim = -1), ((q, q_pass), (k, k_pass)))
return q, k
# 每两个元素旋转的函数
def rotate_every_two(x):
x = rearrange(x, '... (d j) -> ... d j', j = 2)
x1, x2 = x.unbind(dim = -1)
x = torch.stack((-x2, x1), dim = -1)
return rearrange(x, '... d j -> ... (d j)')
# 轴向旋转嵌入类
class AxialRotaryEmbedding(nn.Module):
def __init__(self, dim, max_freq = 10):
super().__init__()
self.dim = dim
scales = torch.linspace(1., max_freq / 2, self.dim // 4)
self.register_buffer('scales', scales)
def forward(self, x):
device, dtype, h, w = x.device, x.dtype, *x.shape[-2:]
seq_x = torch.linspace(-1., 1., steps = h, device = device)
seq_x = seq_x.unsqueeze(-1)
seq_y = torch.linspace(-1., 1., steps = w, device = device)
seq_y = seq_y.unsqueeze(-1)
scales = self.scales[(*((None,) * (len(seq_x.shape) - 1)), Ellipsis)]
scales = scales.to(x)
scales = self.scales[(*((None,) * (len(seq_y.shape) - 1)), Ellipsis)]
scales = scales.to(x)
seq_x = seq_x * scales * pi
seq_y = seq_y * scales * pi
x_sinu = repeat(seq_x, 'i d -> i j d', j = w)
y_sinu = repeat(seq_y, 'j d -> i j d', i = h)
sin = torch.cat((x_sinu.sin(), y_sinu.sin()), dim = -1)
cos = torch.cat((x_sinu.cos(), y_sinu.cos()), dim = -1)
sin, cos = map(lambda t: rearrange(t, 'i j d -> i j d'), (sin, cos))
sin, cos = map(lambda t: repeat(t, 'i j d -> () i j (d r)', r = 2), (sin, cos))
return sin, cos
# 时间正弦位置嵌入类
class TimeSinuPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device = device) * -emb)
emb = einsum('i, j -> i j', x, emb)
emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
return emb
# 辅助类
# 层归一化类
class LayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1)
def forward(self, x):
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (std + self.eps) * self.g + self.b
# 预归一化��
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = LayerNorm(dim)
def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)
# 注意力类
class Attention(nn.Module):
def __init__(self, dim, dim_head = 64, heads = 8, window_size = 16):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
self.window_size = window_size
inner_dim = dim_head * heads
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim, dim, 1)
# 定义前向传播函数,接受输入 x,跳跃连接 skip,默认时间嵌入 time_emb 和位置嵌入 pos_emb
def forward(self, x, skip = None, time_emb = None, pos_emb = None):
# 获取头数 h,窗口大小 w,输入张量的批量大小 b
h, w, b = self.heads, self.window_size, x.shape[0]
# 如果时间嵌入存在,则将其重排维度并与输入相加
if exists(time_emb):
time_emb = rearrange(time_emb, 'b c -> b c () ()')
x = x + time_emb
# 将输入 x 转换为查询向量 q
q = self.to_q(x)
# 将键值对输入设置为 x
kv_input = x
# 如果跳跃连接存在,则将其与键值对输入连接在一起
if exists(skip):
kv_input = torch.cat((kv_input, skip), dim = 0)
# 将键值对输入转换为键 k 和值 v,并按维度进行分块
k, v = self.to_kv(kv_input).chunk(2, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) x y c', h = h), (q, k, v))
# 如果位置嵌入存在,则应用旋转位置嵌入到查询 q 和键 k 上
if exists(pos_emb):
q, k = apply_rotary_emb(q, k, pos_emb)
# 重排查询 q、键 k 和值 v 的维度
q, k, v = map(lambda t: rearrange(t, 'b (x w1) (y w2) c -> (b x y) (w1 w2) c', w1 = w, w2 = w), (q, k, v))
# 如果跳跃连接存在,则对键 k 和值 v 进行维度重排
if exists(skip):
k, v = map(lambda t: rearrange(t, '(r b) n d -> b (r n) d', r = 2), (k, v))
# 计算注意力相似度矩阵
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
# 对相似度矩阵进行 softmax 操作得到注意力权重
attn = sim.softmax(dim = -1)
# 根据注意力权重计算输出
out = einsum('b i j, b j d -> b i d', attn, v)
# 重排输出的维度
out = rearrange(out, '(b h x y) (w1 w2) c -> b (h c) (x w1) (y w2)', b = b, h = h, y = x.shape[-1] // w, w1 = w, w2 = w)
# 将输出传递给输出层并返回结果
return self.to_out(out)
# 定义一个前馈神经网络模块
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4):
super().__init__()
hidden_dim = dim * mult
# 输入投影层,将输入维度转换为隐藏维度
self.project_in = nn.Conv2d(dim, hidden_dim, 1)
# 输出投影层,包含卷积、GELU激活函数和再次卷积
self.project_out = nn.Sequential(
nn.Conv2d(hidden_dim, hidden_dim, 3, padding = 1),
nn.GELU(),
nn.Conv2d(hidden_dim, dim, 1)
)
def forward(self, x, time_emb = None):
# 对输入进行投影
x = self.project_in(x)
# 如果存在时间嵌入,则将其重排并加到输入上
if exists(time_emb):
time_emb = rearrange(time_emb, 'b c -> b c () ()')
x = x + time_emb
# 返回经过输出投影层的结果
return self.project_out(x)
# 定义一个块模块
class Block(nn.Module):
def __init__(
self,
dim,
depth,
dim_head = 64,
heads = 8,
ff_mult = 4,
window_size = 16,
time_emb_dim = None,
rotary_emb = True
):
super().__init__()
self.attn_time_emb = None
self.ff_time_emb = None
# 如果存在时间嵌入维度,则创建注意力和前馈的时间嵌入
if exists(time_emb_dim):
self.attn_time_emb = nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
self.ff_time_emb = nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim * ff_mult))
# 如果使用轴向旋转嵌入,则创建位置嵌入
self.pos_emb = AxialRotaryEmbedding(dim_head) if rotary_emb else None
# 创建多个块层
self.layers = List([])
for _ in range(depth):
self.layers.append(List([
PreNorm(dim, Attention(dim, dim_head = dim_head, heads = heads, window_size = window_size)),
PreNorm(dim, FeedForward(dim, mult = ff_mult))
]))
def forward(self, x, skip = None, time = None):
attn_time_emb = None
ff_time_emb = None
# 如果存在时间信息,则计算注意力和前馈的时间嵌入
if exists(time):
assert exists(self.attn_time_emb) and exists(self.ff_time_emb), 'time_emb_dim must be given on init if you are conditioning based on time'
attn_time_emb = self.attn_time_emb(time)
ff_time_emb = self.ff_time_emb(time)
pos_emb = None
# 如果存在位置嵌入,则计算位置嵌入
if exists(self.pos_emb):
pos_emb = self.pos_emb(x)
# 遍历每个块层,进行注意力和前馈操作
for attn, ff in self.layers:
x = attn(x, skip = skip, time_emb = attn_time_emb, pos_emb = pos_emb) + x
x = ff(x, time_emb = ff_time_emb) + x
# 返回处理后的结果
return x
# 定义一个 Uformer 模块
class Uformer(nn.Module):
def __init__(
self,
dim = 64,
channels = 3,
stages = 4,
num_blocks = 2,
dim_head = 64,
window_size = 16,
heads = 8,
ff_mult = 4,
time_emb = False,
input_channels = None,
output_channels = None
):
# 调用父类的构造函数
super().__init__()
# 设置输入通道数为默认值或者与输出通道数相同
input_channels = default(input_channels, channels)
output_channels = default(output_channels, channels)
self.to_time_emb = None
time_emb_dim = None
# 如果需要时间嵌入
if time_emb:
time_emb_dim = dim
# 创建时间嵌入层
self.to_time_emb = nn.Sequential(
TimeSinuPosEmb(dim),
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim)
)
# 输入通道到维度转换
self.project_in = nn.Sequential(
nn.Conv2d(input_channels, dim, 3, padding = 1),
nn.GELU()
)
# 维度到输出通道转换
self.project_out = nn.Sequential(
nn.Conv2d(dim, output_channels, 3, padding = 1),
)
# 下采样和上采样列表
self.downs = List([])
self.ups = List([])
# 将参数转换为指定深度的元组
heads, window_size, dim_head, num_blocks = map(partial(cast_tuple, depth = stages), (heads, window_size, dim_head, num_blocks))
# 遍历各个阶段
for ind, heads, window_size, dim_head, num_blocks in zip(range(stages), heads, window_size, dim_head, num_blocks):
is_last = ind == (stages - 1)
# 添加下采样模块
self.downs.append(List([
Block(dim, depth = num_blocks, dim_head = dim_head, heads = heads, ff_mult = ff_mult, window_size = window_size, time_emb_dim = time_emb_dim),
nn.Conv2d(dim, dim * 2, 4, stride = 2, padding = 1)
]))
# 添加上采样模块
self.ups.append(List([
nn.ConvTranspose2d(dim * 2, dim, 2, stride = 2),
Block(dim, depth = num_blocks, dim_head = dim_head, heads = heads, ff_mult = ff_mult, window_size = window_size, time_emb_dim = time_emb_dim)
]))
dim *= 2
# 如果是最后一个阶段,设置中间模块
if is_last:
self.mid = Block(dim = dim, depth = num_blocks, dim_head = dim_head, heads = heads, ff_mult = ff_mult, window_size = window_size, time_emb_dim = time_emb_dim)
# 前向传播函数
def forward(
self,
x,
time = None
):
# 如果存在时间信息
if exists(time):
assert exists(self.to_time_emb), 'time_emb must be set to true to condition on time'
time = time.to(x)
time = self.to_time_emb(time)
# 输入数据通过输入通道转换
x = self.project_in(x)
skips = []
# 对下采样模块进行迭代
for block, downsample in self.downs:
x = block(x, time = time)
skips.append(x)
x = downsample(x)
# 中间模块
x = self.mid(x, time = time)
# 对上采样模块进行迭代
for (upsample, block), skip in zip(reversed(self.ups), reversed(skips)):
x = upsample(x)
x = block(x, skip = skip, time = time)
# 输出数据通过输出通道转换
x = self.project_out(x)
return x