Lucidrains Series Source Code Walkthrough (56)
.\lucidrains\med-seg-diff-pytorch\med_seg_diff_pytorch\__init__.py
# Import the MedSegDiff and Unet classes from the med_seg_diff_pytorch.med_seg_diff_pytorch module
from med_seg_diff_pytorch.med_seg_diff_pytorch import MedSegDiff, Unet

MedSegDiff - Pytorch
Implementation of MedSegDiff in Pytorch - SOTA medical segmentation out of Baidu using DDPM and enhanced conditioning on the feature level, with filtering of features in Fourier space.
Appreciation
- StabilityAI for the generous sponsorship, as well as my other sponsors out there
- Isamu and Daniel for adding a training script for a skin lesion dataset!
Install
$ pip install med-seg-diff-pytorch
Usage
import torch
from med_seg_diff_pytorch import Unet, MedSegDiff
model = Unet(
dim = 64,
image_size = 128,
mask_channels = 1, # segmentation has 1 channel
input_img_channels = 3, # input images have 3 channels
dim_mults = (1, 2, 4, 8)
)
diffusion = MedSegDiff(
model,
timesteps = 1000
).cuda()
segmented_imgs = torch.rand(8, 1, 128, 128) # inputs are normalized from 0 to 1
input_imgs = torch.rand(8, 3, 128, 128)
loss = diffusion(segmented_imgs, input_imgs)
loss.backward()
# after a lot of training
pred = diffusion.sample(input_imgs) # pass in your unsegmented images
pred.shape # predicted segmentation masks - (8, 1, 128, 128)
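The bundled sample.py (walked through below) builds an ensemble by calling diffusion.sample several times and averaging; a minimal sketch of that idea, reusing the tensors from the example above:
ensemble = torch.stack([diffusion.sample(input_imgs) for _ in range(5)], dim = 1)
pred_mean = ensemble.mean(dim = 1) # averaged segmentation
pred_std = ensemble.std(dim = 1) # per-pixel uncertainty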
Training
Command to run
accelerate launch driver.py --mask_channels=1 --input_img_channels=3 --image_size=64 --data_path='./data' --dim=64 --epochs=100 --batch_size=1 --scale_lr --gradient_accumulation_steps=4
If you want to add self-conditioning, where the model is additionally conditioned on the mask predicted so far, pass --self_condition
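In code, the same option is just the self_condition flag on the Unet; a minimal sketch, with the remaining arguments matching the usage example above:
model = Unet(
dim = 64,
image_size = 128,
mask_channels = 1,
input_img_channels = 3,
dim_mults = (1, 2, 4, 8),
self_condition = True # condition on the mask predicted so far
)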
Todo
- some basic training code, with Trainer taking in custom dataset tailored for medical image formats - thanks to @isamu-isozaki
- full blown transformer of any depth in the middle, as done in simple diffusion
Citations
@article{Wu2022MedSegDiffMI,
title = {MedSegDiff: Medical Image Segmentation with Diffusion Probabilistic Model},
author = {Junde Wu and Huihui Fang and Yu Zhang and Yehui Yang and Yanwu Xu},
journal = {ArXiv},
year = {2022},
volume = {abs/2211.00611}
}
@inproceedings{Hoogeboom2023simpleDE,
title = {simple diffusion: End-to-end diffusion for high resolution images},
author = {Emiel Hoogeboom and Jonathan Heek and Tim Salimans},
year = {2023}
}
.\lucidrains\med-seg-diff-pytorch\sample.py
# Import the required libraries
import os
import argparse
from tqdm import tqdm
import torch
import torchvision.transforms as transforms
from med_seg_diff_pytorch import Unet, MedSegDiff
from med_seg_diff_pytorch.dataset import ISICDataset, GenericNpyDataset
from accelerate import Accelerator
import skimage.io as io
## Parse command-line arguments ##
def parse_args():
# Create the argument parser
parser = argparse.ArgumentParser()
# Add the command-line arguments
parser.add_argument('-od', '--output_dir', type=str, default="output", help="Output dir.")
parser.add_argument('-ld', '--logging_dir', type=str, default="logs", help="Logging dir.")
parser.add_argument('-mp', '--mixed_precision', type=str, default="no", choices=["no", "fp16", "bf16"],
help="Whether to do mixed precision")
parser.add_argument('-img', '--img_folder', type=str, default='ISBI2016_ISIC_Part3B_Training_Data',
help='The image file path from data_path')
parser.add_argument('-csv', '--csv_file', type=str, default='ISBI2016_ISIC_Part3B_Training_GroundTruth.csv',
help='The csv file to load in from data_path')
parser.add_argument('-sc', '--self_condition', action='store_true', help='Whether to do self condition')
parser.add_argument('-ic', '--mask_channels', type=int, default=1, help='number of mask channels (default: 1)')
parser.add_argument('-c', '--input_img_channels', type=int, default=3,
help='number of input image channels (default: 3)')
parser.add_argument('-is', '--image_size', type=int, default=128, help='input image size (default: 128)')
parser.add_argument('-dd', '--data_path', default='./data', help='directory of input image')
parser.add_argument('-d', '--dim', type=int, default=64, help='dim (default: 64)')
parser.add_argument('-e', '--epochs', type=int, default=10000, help='number of epochs (default: 10000)')
parser.add_argument('-bs', '--batch_size', type=int, default=8, help='batch size to train on (default: 8)')
parser.add_argument('--timesteps', type=int, default=1000, help='number of timesteps (default: 1000)')
parser.add_argument('-ds', '--dataset', default='generic', help='Dataset to use')
parser.add_argument('--save_every', type=int, default=100, help='save_every n epochs (default: 100)')
parser.add_argument('--num_ens', type=int, default=5,
help='number of times to sample to make an ensemble of predictions like in the paper (default: 5)')
parser.add_argument('--load_model_from', default=None, help='path to pt file to load from')
parser.add_argument('--save_uncertainty', action='store_true',
help='Whether to store the uncertainty in predictions (only works for ensembles)')
# Parse the command-line arguments and return them
return parser.parse_args()
def load_data(args):
# Load the dataset
if args.dataset == 'ISIC':
# Define the data transforms
transform_list = [transforms.Resize((args.image_size, args.image_size)), transforms.ToTensor(), ]
transform_train = transforms.Compose(transform_list)
# Create the ISIC dataset object
dataset = ISICDataset(args.data_path, args.csv_file, args.img_folder, transform=transform_train, training=False,
flip_p=0.5)
elif args.dataset == 'generic':
# Define the data transforms
transform_list = [transforms.ToPILImage(), transforms.Resize(args.image_size), transforms.ToTensor()]
transform_train = transforms.Compose(transform_list)
# Create the generic .npy dataset object
dataset = GenericNpyDataset(args.data_path, transform=transform_train, test_flag=True)
else:
# Raise an error for datasets that are not implemented
raise NotImplementedError(f"Your dataset {args.dataset} hasn't been implemented yet.")
## Define the PyTorch data loader
training_generator = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=False)
return training_generator
def main():
# Parse the command-line arguments
args = parse_args()
# Set up the logging directory
logging_dir = os.path.join(args.output_dir, args.logging_dir)
inference_dir = os.path.join(args.output_dir, 'inference')
# Create the inference directory
os.makedirs(inference_dir, exist_ok=True)
# Create the Accelerator object, used for mixed precision
accelerator = Accelerator(
mixed_precision=args.mixed_precision,
)
# Define the model
model = Unet(
dim=args.dim,
image_size=args.image_size,
dim_mults=(1, 2, 4, 8),
mask_channels=args.mask_channels,
input_img_channels=args.input_img_channels,
self_condition=args.self_condition
)
# Load the data
data_loader = load_data(args)
# Create the MedSegDiff object for the diffusion process
diffusion = MedSegDiff(
model,
timesteps=args.timesteps
).to(accelerator.device)
# If a checkpoint path was given, load the model weights
if args.load_model_from is not None:
save_dict = torch.load(args.load_model_from)
diffusion.model.load_state_dict(save_dict['model_state_dict'])
# Iterate over the data loader
for (imgs, masks, fnames) in tqdm(data_loader):
# Pre-allocate space for the predictions
preds = torch.zeros((imgs.shape[0], args.num_ens, imgs.shape[2], imgs.shape[3]))
# Sample each batch multiple times
for i in range(args.num_ens):
preds[:, i:i+1, :, :] = diffusion.sample(imgs).cpu().detach()
# Compute the mean and standard deviation of the predictions
preds_mean = preds.mean(dim=1)
preds_std = preds.std(dim=1)
# Save the predictions
for idx in range(preds.shape[0]):
io.imsave(os.path.join(inference_dir, fnames[idx].replace('.npy', '.png')), preds_mean[idx, :, :])
# If requested, also save the standard deviation of the predictions as the uncertainty
if args.save_uncertainty:
io.imsave(os.path.join(inference_dir, fnames[idx].replace('.npy', '_std.png')), preds_std[idx, :, :])
# Call main() when the script is executed directly
if __name__ == '__main__':
main()
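A typical invocation of this sampling script mirrors the training command from the README above; the checkpoint filename below is hypothetical and only illustrates --load_model_from:
accelerate launch sample.py --dataset=generic --data_path='./data' --image_size=128 --dim=64 --batch_size=8 --num_ens=5 --load_model_from='output/checkpoint.pt'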
.\lucidrains\med-seg-diff-pytorch\setup.py
# Import the setup and find_packages utilities
from setuptools import setup, find_packages
# Configure the package metadata
setup(
# package name
name = 'med-seg-diff-pytorch',
# find all packages, excluding none
packages = find_packages(exclude=[]),
# version number
version = '0.3.3',
# license
license='MIT',
# description
description = 'MedSegDiff - SOTA medical image segmentation - Pytorch',
# author
author = 'Phil Wang',
# author email
author_email = 'lucidrains@gmail.com',
# long description content type
long_description_content_type = 'text/markdown',
# project URL
url = 'https://github.com/lucidrains/med-seg-diff-pytorch',
# keywords
keywords = [
'artificial intelligence',
'deep learning',
'denoising diffusion',
'medical segmentation'
],
# install dependencies
install_requires = [
'beartype',
'einops',
'lion-pytorch',
'torch',
'torchvision',
'tqdm',
'accelerate>=0.25.0',
'wandb'
],
# classifiers
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
Medical AI Experiments (wip)
A repository to house some personal attempts to beat some state-of-the-art for medical datasets. Will start with basic arrhythmia detection and work my way up to EEG seizure classification / detection.
I will apply everything that I know from the attention field.
.\lucidrains\medical-chatgpt\medical_chatgpt\medical_chatgpt.py
# Import the torch library
import torch
# Import torch's functional API
import torch.nn.functional as F
# Import nn and einsum from torch
from torch import nn, einsum
# Import isfunction, used by the default() helper below
from inspect import isfunction
# Import rearrange and repeat from einops (repeat is used for the null key / values)
from einops import rearrange, repeat
# Helper that checks whether a value exists
def exists(val):
return val is not None
# Helper that returns a default when the value is missing
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
# Attention module
class Attention(nn.Module):
def __init__(
self,
dim,
causal = False,
dim_head = 64,
dim_context = None,
heads = 8,
norm_context = False,
num_null_kv = 0,
dropout = 0.1
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
self.causal = causal
inner_dim = dim_head * heads
dim_context = default(dim_context, dim)
self.norm = nn.LayerNorm(dim)
self.context_norm = nn.LayerNorm(dim_context) if norm_context else nn.Identity()
self.attn_dropout = nn.Dropout(dropout)
self.num_null_kv = num_null_kv
self.null_kv = nn.Parameter(torch.randn(2, num_null_kv, dim_head))
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim_context, dim_head * 2, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias = False),
nn.Dropout(dropout)
)
def forward(
self,
x,
context = None,
mask = None,
attn_bias = None
):
b = x.shape[0]
if exists(context):
context = self.context_norm(context)
kv_input = default(context, x)
x = self.norm(x)
q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)
if self.num_null_kv > 0:
null_k, null_v = repeat(self.null_kv, 'kv n d -> kv b n d', b = b).unbind(dim = 0)
k = torch.cat((null_k, k), dim = -2)
v = torch.cat((null_v, v), dim = -2)
q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
q = q * self.scale
sim = einsum('b h i d, b j d -> b h i j', q, k)
if exists(attn_bias):
attn_bias = F.pad(attn_bias, (self.num_null_kv, 0), value = 0.)
sim = sim + attn_bias
if exists(mask):
mask = F.pad(mask, (self.num_null_kv, 0), value = True)
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
if self.causal:
i, j = sim.shape[-2:]
causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
attn = sim.softmax(dim = -1)
attn = self.attn_dropout(attn)
out = einsum('b h i j, b j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
.\lucidrains\medical-chatgpt\medical_chatgpt\__init__.py
# Define a function named calculate_area that computes the area of a rectangle
def calculate_area(length, width):
# Compute the area of the rectangle
area = length * width
# Return the computed area
return area
Explorations into training a ChatGPT, but tailored towards primary care medicine, with the reward being able to collect patient histories in a thorough and efficient manner and come up with a differential diagnosis. May also explore to see if it can be further fine-tuned on pirated copies of Up-To-Date for specialist knowledge
Sadly, I no longer think this is possible in its current state. It will probably see some utility with scribing of basic bread-and-butter cases; however, assess and plan it cannot.
Citations
@inproceedings{Singhal2022LargeLM,
title = {Large Language Models Encode Clinical Knowledge},
author = {Karan Singhal and Shekoofeh Azizi and Tao Tu and Said Mahdavi and Jason Lee Kai Wei and Hyung Won Chung and Nathan Scales and Ajay Kumar Tanwani and Heather J. Cole-Lewis and Stephen J. Pfohl and P A Payne and Martin G. Seneviratne and Paul Gamble and Chris Kelly and Nathaneal Scharli and Aakanksha Chowdhery and P. D. Mansfield and Blaise Ag{\"u}era y Arcas and Dale R. Webster and Greg S. Corrado and Y. Matias and Katherine Hui-Ling Chou and Juraj Gottweis and Nenad Toma{\vs}ev and Yun Liu and Alvin Rajkomar and Jo{\"e}lle K. Barral and Christopher Semturs and Alan Karthikesalingam and Vivek Natarajan},
year = {2022}
}
@article {Kung2022.12.19.22283643,
author = {Kung, Tiffany H. and Cheatham, Morgan and , and Medenilla, Arielle and Sillos, Czarina and De Leon, Lorie and Elepa{\~n}o, Camille and Madriaga, Maria and Aggabao, Rimel and Diaz-Candido, Giezel and Maningo, James and Tseng, Victor},
title = {Performance of ChatGPT on USMLE: Potential for AI-Assisted Medical Education Using Large Language Models},
elocation-id = {2022.12.19.22283643},
year = {2022},
doi = {10.1101/2022.12.19.22283643},
publisher = {Cold Spring Harbor Laboratory Press},
URL = {https://www.medrxiv.org/content/early/2022/12/21/2022.12.19.22283643},
eprint = {https://www.medrxiv.org/content/early/2022/12/21/2022.12.19.22283643.full.pdf},
journal = {medRxiv}
}
@misc{https://doi.org/10.48550/arxiv.2301.10035,
doi = {10.48550/ARXIV.2301.10035},
url = {https://arxiv.org/abs/2301.10035},
author = {Nov, Oded and Singh, Nina and Mann, Devin},
keywords = {Human-Computer Interaction (cs.HC), FOS: Computer and information sciences, FOS: Computer and information sciences},
title = {Putting ChatGPT's Medical Advice to the (Turing) Test},
publisher = {arXiv},
year = {2023},
copyright = {Creative Commons Attribution Share Alike 4.0 International}
}
@inproceedings{Schick2023ToolformerLM,
title = {Toolformer: Language Models Can Teach Themselves to Use Tools},
author = {Timo Schick and Jane Dwivedi-Yu and Roberto Dessi and Roberta Raileanu and Maria Lomeli and Luke Zettlemoyer and Nicola Cancedda and Thomas Scialom},
year = {2023}
}
@inproceedings{Peng2023CheckYF,
title = {Check Your Facts and Try Again: Improving Large Language Models with External Knowledge and Automated Feedback},
author = {Baolin Peng and Michel Galley and Pengcheng He and Hao Cheng and Yujia Xie and Yu Hu and Qiuyuan Huang and Lars Lid{\'e}n and Zhou Yu and Weizhu Chen and Jianfeng Gao},
year = {2023}
}
@inproceedings{Nori2023CapabilitiesOG,
title = {Capabilities of GPT-4 on Medical Challenge Problems},
author = {Harsha Nori and Nicholas King and Scott Mayer McKinney and Dean Carignan and Eric Horvitz},
year = {2023}
}
.\lucidrains\medical-chatgpt\setup.py
# Import the setup and find_packages utilities
from setuptools import setup, find_packages
# Configure the package metadata
setup(
name = 'medical-chatgpt', # package name
packages = find_packages(exclude=[]), # find all packages
version = '0.0.1', # version number
license='MIT', # license
description = 'Medical ChatGPT', # description
author = 'Phil Wang', # author
author_email = 'lucidrains@gmail.com', # author email
long_description_content_type = 'text/markdown', # long description content type
url = 'https://github.com/lucidrains/medical-chatgpt', # URL
keywords = [ # keyword list
'artificial intelligence',
'deep learning',
'transformers',
'attention mechanism',
'reinforcement learning with human feedback'
],
install_requires=[ # install dependencies
'einops>=0.6',
'django-ninja',
'torch>=1.6',
],
classifiers=[ # classifiers
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
Data source
The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/
.\lucidrains\Mega-pytorch\mega_pytorch\autoregressive_wrapper.py
# Import the torch library
import torch
# Import the nn module from torch
from torch import nn
# Import torch.nn.functional and alias it as F
import torch.nn.functional as F
# Import the rearrange function from einops
from einops import rearrange
# Helper function that checks whether a value exists
def exists(val):
return val is not None
# Decorator that puts the model in eval mode for the wrapped call and restores the previous training state afterwards
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
# top-k filtering: keep only the top fraction of logits and set the rest to -inf
def top_k(logits, thres = 0.9):
k = int((1 - thres) * logits.shape[-1])
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float('-inf'))
probs.scatter_(1, ind, val)
return probs
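# A minimal usage sketch (illustrative only, not part of the original module): with
# thres = 0.9 and a vocabulary of 256 logits, k = int((1 - 0.9) * 256) = 25, so only the
# 25 largest logits per row survive and every other entry becomes -inf before the softmax.
example_logits = torch.randn(2, 256)
filtered_example = top_k(example_logits, thres = 0.9)
assert (filtered_example > float('-inf')).sum(dim = -1).eq(25).all()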
# AutoregressiveWrapper class that wraps a decoder model for training and sampling
class AutoregressiveWrapper(nn.Module):
def __init__(self, net, pad_value = 0):
super().__init__()
self.pad_value = pad_value
self.net = net
# Generation function for sampling a sequence
@torch.no_grad()
@eval_decorator
def generate(self, start_tokens, seq_len, temperature = 1., filter_thres = 0.9, **kwargs):
b, t, device = *start_tokens.shape, start_tokens.device
out = start_tokens
for _ in range(seq_len):
logits = self.net(out, **kwargs)[:, -1, :]
filtered_logits = top_k(logits, thres = filter_thres)
probs = F.softmax(filtered_logits / temperature, dim=-1)
sample = torch.multinomial(probs, 1)
out = torch.cat((out, sample), dim=-1)
out = out[:, t:]
return out
# Forward pass that computes the autoregressive cross-entropy loss
def forward(self, x, **kwargs):
x_inp, x_labels = x[:, :-1], x[:, 1:]
logits = self.net(x_inp, **kwargs)
return F.cross_entropy(rearrange(logits, 'b c n -> b n c'), x_labels)
.\lucidrains\Mega-pytorch\mega_pytorch\mega_pytorch.py
# Import the math library
import math
# Import partial from functools
from functools import partial
# Import the torch library
import torch
# Import torch.nn.functional and alias it as F
import torch.nn.functional as F
# Import nn and einsum from torch
from torch import nn, einsum
# Import rfft and irfft from torch.fft
from torch.fft import rfft, irfft
# Import rearrange from einops and the Rearrange layer from einops.layers.torch
from einops import rearrange
from einops.layers.torch import Rearrange
# Import next_fast_len from scipy.fftpack (used by conv1d_fft below)
from scipy.fftpack import next_fast_len
# functions
# Helper that checks whether a value exists
def exists(val):
return val is not None
# Identity function, returns its input unchanged
def identity(t, *args, **kwargs):
return t
# Return the value if it exists, otherwise the default
def default(val, d):
return val if exists(val) else d
# Append the given number of trailing singleton dimensions to a tensor
def append_dims(x, num_dims):
if num_dims <= 0:
return x
return x.view(*x.shape, *((1,) * num_dims))
# O(N log N) 1d convolution using the Fourier (FFT) trick
def conv1d_fft(x, weights, dim = -2, weight_dim = -1):
# O(N log(N)) 1d convolution using some fourier trick
assert weight_dim >= dim
N = x.shape[dim]
M = weights.shape[weight_dim]
fast_len = next_fast_len(N + M - 1)
f_x = rfft(x, n = fast_len, dim = dim)
f_weight = rfft(weights, n = fast_len, dim = weight_dim)
f_v_weight = f_x * append_dims(f_weight.conj(), weight_dim - dim)
out = irfft(f_v_weight, fast_len, dim = dim)
out = out.roll(-1, dims = (dim,))
indices = torch.arange(start = fast_len - N, end = fast_len, dtype = torch.long, device = x.device)
out = out.index_select(dim, indices)
return out
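# A small shape sketch of how MultiHeadedEMA below calls this helper (the tensors here are
# illustrative): x holds per-head features of shape (batch, seq, heads, dim), the kernel K
# holds one decay profile per head of shape (seq, heads), and the output keeps x's shape.
example_x = torch.randn(2, 32, 4, 8)
example_K = torch.randn(32, 4)
assert conv1d_fft(example_x, example_K, dim = -3, weight_dim = -2).shape == (2, 32, 4, 8)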
# Relative position bias for the single-headed attention
class T5RelativePositionBias(nn.Module):
def __init__(
self,
scale,
causal = False,
num_buckets = 32,
max_distance = 128
):
super().__init__()
self.scale = scale
self.causal = causal
self.num_buckets = num_buckets
self.max_distance = max_distance
self.relative_attention_bias = nn.Embedding(num_buckets, 1)
@staticmethod
def _relative_position_bucket(
relative_position,
causal = True,
num_buckets = 32,
max_distance = 128
):
ret = 0
n = -relative_position
if not causal:
num_buckets //= 2
ret += (n < 0).long() * num_buckets
n = torch.abs(n)
else:
n = torch.max(n, torch.zeros_like(n))
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = max_exact + (
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
).long()
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
ret += torch.where(is_small, n, val_if_large)
return ret
def forward(self, x):
i, j, device = *x.shape[-2:], x.device
q_pos = torch.arange(i, dtype = torch.long, device = device)
k_pos = torch.arange(j, dtype = torch.long, device = device)
rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
values = self.relative_attention_bias(rp_bucket)
bias = rearrange(values, 'i j 1 -> i j')
return bias * self.scale
# classes
# Laplacian attention function (an alternative to softmax, from the Mega paper)
class LaplacianAttnFn(nn.Module):
def forward(self, x):
mu = math.sqrt(0.5)
std = math.sqrt((4 * math.pi) ** -1)
return (1 + torch.special.erf((x - mu) / (std * math.sqrt(2)))) * 0.5
# Offset and scale module (produces queries and keys from a shared projection)
class OffsetScale(nn.Module):
def __init__(self, dim, heads = 1):
super().__init__()
self.gamma = nn.Parameter(torch.ones(heads, dim))
self.beta = nn.Parameter(torch.zeros(heads, dim))
nn.init.normal_(self.gamma, std = 0.02)
def forward(self, x):
out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
return out.unbind(dim = -2)
# Single-headed attention
class SingleHeadedAttention(nn.Module):
def __init__(
self,
*,
dim,
dim_qk,
dim_value,
causal = False,
laplacian_attn_fn = False
):
# Call the parent constructor
super().__init__()
# Whether to use causal masking and the Laplacian attention function
self.causal = causal
self.laplacian_attn_fn = laplacian_attn_fn
# Choose softmax or the Laplacian attention function
self.attn_fn = partial(F.softmax, dim = -1) if not laplacian_attn_fn else LaplacianAttnFn()
# Relative position bias
self.rel_pos_bias = T5RelativePositionBias(causal = causal, scale = dim_qk ** 0.5)
# Project the input to a shared query / key representation
self.to_qk = nn.Sequential(
nn.Linear(dim, dim_qk),
nn.SiLU()
)
# Offset and scale to split the shared projection into queries and keys
self.offsetscale = OffsetScale(dim_qk, heads = 2)
# Project the input to values
self.to_v = nn.Sequential(
nn.Linear(dim, dim_value),
nn.SiLU()
)
# Forward pass
def forward(self, x, v_input = None):
# Get the sequence length, dimension, device and dtype
seq_len, dim, device, dtype = *x.shape[-2:], x.device, x.dtype
# If no separate value input is given, use x itself
v_input = default(v_input, x)
# Project the inputs to a shared query / key representation and to values
qk, v = self.to_qk(x), self.to_v(v_input)
q, k = self.offsetscale(qk)
# Compute the scaling factor
scale = (seq_len ** -1) if self.laplacian_attn_fn else (dim ** -0.5)
# Compute the attention similarity matrix
sim = einsum('b i d, b j d -> b i j', q, k) * scale
# Add the relative position bias
sim = sim + self.rel_pos_bias(sim)
# Build the causal mask if needed
if self.causal:
causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).triu(1)
# For causal softmax attention, fill masked positions with a large negative value before the softmax
if self.causal and not self.laplacian_attn_fn:
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
# Compute the attention weights
attn = self.attn_fn(sim)
# For causal Laplacian attention, zero out the masked positions after the activation
if self.causal and self.laplacian_attn_fn:
attn = attn.masked_fill(causal_mask, 0.)
# Aggregate the values
return einsum('b i j, b j d -> b i d', attn, v)
class MultiHeadedEMA(nn.Module):
# Multi-headed exponential moving average (EMA) module
def __init__(
self,
*,
dim,
heads,
bidirectional = False,
norm_mhesa_heads = False
):
# Initialize
super().__init__()
self.bidirectional = bidirectional
# Per-head expansion and reduction projections
self.expansion = nn.Parameter(torch.randn(heads * (2 if bidirectional else 1), dim))
self.reduction = nn.Parameter(torch.randn(heads * (2 if bidirectional else 1), dim))
# Learned alphas and damping factors
self.alphas = nn.Parameter(torch.randn(heads))
self.dampen_factors = nn.Parameter(torch.randn(heads))
if bidirectional:
self.reverse_alphas = nn.Parameter(torch.randn(heads))
self.reverse_dampen_factors = nn.Parameter(torch.randn(heads))
self.heads = heads
self.norm_heads = nn.Identity()
if norm_mhesa_heads:
# Use group norm across the heads as the sub-layer normalization
self.norm_heads = nn.Sequential(
Rearrange('b n h d -> b (h d) n'),
nn.GroupNorm(heads, dim * heads),
Rearrange('b (h d) n -> b n h d', h = heads)
)
def forward(self, x):
# Forward pass
device, seq_len = x.device, x.shape[1]
# Project and split into heads
x = einsum('... d, h d -> ... h d', x, self.expansion)
if self.bidirectional:
x, x_reversed = x.chunk(2, dim = -2)
x_reversed = torch.flip(x_reversed, dims = (1,))
# Weights derived from the alphas (learned exponential smoothing decay rates)
def apply_learned_ema_with_damping(x, alphas, dampen_factors):
alphas = alphas.sigmoid()
dampen_factors = dampen_factors.sigmoid()
reversed_powers = torch.arange(seq_len - 1, -1, -1, device = device)
K = alphas * (((1 - alphas) * dampen_factors) ** rearrange(reversed_powers, '... l -> ... l 1'))
# Compute with the FFT-based conv1d
return conv1d_fft(x, K, dim = -3, weight_dim = -2)
x = apply_learned_ema_with_damping(x, self.alphas, self.dampen_factors)
if self.bidirectional:
x_reversed = apply_learned_ema_with_damping(x_reversed, self.reverse_alphas, self.reverse_dampen_factors)
x_reversed = torch.flip(x_reversed, dims = (1,))
x = torch.cat((x, x_reversed), dim = -2)
# Optionally normalize the heads
x = self.norm_heads(x)
# Combine the heads for the output
return einsum('... h d, h d -> ... d', x, self.reduction)
# Mega Layer
# Single-headed attention + multi-headed EMA, followed by GRU-like gating
class MegaLayer(nn.Module):
# MegaLayer module
def __init__(
self,
*,
dim = 128,
ema_heads = 16,
attn_dim_qk = 64,
attn_dim_value = 256,
laplacian_attn_fn = False,
causal = True,
norm_mhesa_heads = False
):
# Initialize
super().__init__()
# Single-headed attention
self.single_headed_attn = SingleHeadedAttention(
dim = dim,
dim_qk = attn_dim_qk,
dim_value = attn_dim_value,
causal = causal,
laplacian_attn_fn = laplacian_attn_fn
)
# Multi-headed EMA
self.multi_headed_ema = MultiHeadedEMA(
dim = dim,
heads = ema_heads,
bidirectional = not causal,
norm_mhesa_heads = norm_mhesa_heads
)
# Reset gate
self.to_reset_gate = nn.Sequential(
nn.Linear(dim, attn_dim_value),
nn.SiLU()
)
# Update gate
self.to_update_gate = nn.Sequential(
nn.Linear(dim, dim),
nn.Sigmoid()
)
# Parameters for computing H (equation 14 in the paper)
self.Wh = nn.Parameter(torch.randn(dim, dim))
self.Uh = nn.Parameter(torch.randn(attn_dim_value, dim))
self.bh = nn.Parameter(torch.randn(dim))
# Forward pass, taking input x and an optional residual (defaults to x)
def forward(self, x, residual = None):
# If no residual is passed, use x itself
residual = default(residual, x)
# Run the multi-headed EMA over the input
ema_output = self.multi_headed_ema(x)
# Run single-headed attention on the EMA output, with x as the value input
attn_output = self.single_headed_attn(ema_output, x)
# Compute the reset and update gates
reset_gate = self.to_reset_gate(ema_output)
update_gate = self.to_update_gate(ema_output)
# Gate the attention output with the reset gate
gated_attn_output = attn_output * reset_gate
# Compute H, following equation 14 of the paper
H = F.silu(ema_output @ self.Wh + gated_attn_output @ self.Uh + self.bh)
# Blend H with the residual via the update gate
return update_gate * H + (1 - update_gate) * residual
# Feedforward layer: a linear layer, GELU activation, and a second linear layer
def FeedForward(dim, ff_mult):
# Hidden layer dimension
dim_hidden = int(dim * ff_mult)
return nn.Sequential(
nn.Linear(dim, dim_hidden), # linear layer from dim to dim_hidden
nn.GELU(), # GELU activation
nn.Linear(dim_hidden, dim) # linear layer from dim_hidden back to dim
)
# Mega model class, a full stack of MegaLayers
class Mega(nn.Module):
def __init__(
self,
*,
dim,
num_tokens,
depth,
ff_mult = 2,
pre_norm = False,
**kwargs
):
super().__init__()
self.token_emb = nn.Embedding(num_tokens, dim) # token embedding, maps tokens to dim-dimensional vectors
self.pre_norm = pre_norm # whether to use pre-layernorm
self.layers = nn.ModuleList([]) # empty ModuleList holding the MegaLayers and their companion layers
# Create depth MegaLayers, each with its norms and feedforward, and add them to layers
for _ in range(depth):
self.layers.append(nn.ModuleList([
MegaLayer(dim = dim, **kwargs), # MegaLayer
nn.LayerNorm(dim), # LayerNorm
FeedForward(dim = dim, ff_mult = ff_mult), # FeedForward
nn.LayerNorm(dim) # LayerNorm
]))
# Project the model output to logits over num_tokens
self.to_logits = nn.Sequential(
nn.LayerNorm(dim) if pre_norm else nn.Identity(), # final LayerNorm when using pre-norm, otherwise identity
nn.Linear(dim, num_tokens) # linear layer mapping dim to num_tokens
)
# Forward pass
def forward(self, x):
pre_norm = self.pre_norm
post_norm = not self.pre_norm
x = self.token_emb(x) # embed the input tokens into dim-dimensional vectors
# Iterate over each MegaLayer and its accompanying norms / feedforward
for mega_layer, mega_norm, ff, ff_norm in self.layers:
mega_maybe_prenorm = mega_norm if pre_norm else identity
ff_maybe_prenorm = ff_norm if pre_norm else identity
mega_maybe_postnorm = mega_norm if post_norm else identity
ff_maybe_postnorm = ff_norm if post_norm else identity
x = mega_layer(mega_maybe_prenorm(x), x) # MegaLayer forward pass, with residual
x = mega_maybe_postnorm(x) # possible post-norm
x = ff(ff_maybe_prenorm(x)) + x # feedforward with residual
x = ff_maybe_postnorm(x) # possible post-norm
return self.to_logits(x) # project the output to logits over num_tokens
.\lucidrains\Mega-pytorch\mega_pytorch\__init__.py
# Import the MegaLayer, Mega and MultiHeadedEMA classes from the mega_pytorch.mega_pytorch module
from mega_pytorch.mega_pytorch import MegaLayer, Mega, MultiHeadedEMA

Mega - Moving Average Equipped Gated Attention - Pytorch
Implementation of the Mega layer, the Single-head Attention with Multi-headed EMA layer that exists in the architecture that currently holds SOTA on Long Range Arena, beating S4 on Pathfinder-X and all the other tasks save for audio.
Install
$ pip install mega-pytorch
Usage
The Mega Layer, a combination of attention and a learned EMA
import torch
from mega_pytorch import MegaLayer
layer = MegaLayer(
dim = 128, # model dimensions
ema_heads = 16, # number of EMA heads
attn_dim_qk = 64, # dimension of queries / keys in attention
attn_dim_value = 256, # dimension of values in attention
laplacian_attn_fn = False, # whether to use softmax (false) or laplacian attention activation fn (true)
)
x = torch.randn(1, 1024, 128) # (batch, seq, dim)
out = layer(x) # (1, 1024, 128)
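The multi-headed EMA sublayer is also exported on its own; a minimal sketch (it preserves the input shape, and bidirectional should be False for causal / autoregressive use):
import torch
from mega_pytorch import MultiHeadedEMA
ema = MultiHeadedEMA(
dim = 128, # model dimensions
heads = 16, # number of EMA heads
bidirectional = True # set to False for causal / autoregressive use
)
x = torch.randn(1, 1024, 128)
out = ema(x) # (1, 1024, 128)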
Full Mega (with layernorm for now)
import torch
from mega_pytorch import Mega
mega = Mega(
num_tokens = 256, # number of tokens
dim = 128, # model dimensions
depth = 6, # depth
ema_heads = 16, # number of EMA heads
attn_dim_qk = 64, # dimension of queries / keys in attention
attn_dim_value = 256, # dimension of values in attention
laplacian_attn_fn = True, # whether to use softmax (false) or laplacian attention activation fn (true)
)
x = torch.randint(0, 256, (1, 1024))
logits = mega(x) # (1, 1024, 256)
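The repository's train.py (reproduced later in this section) wraps Mega in the AutoregressiveWrapper from mega_pytorch/autoregressive_wrapper.py for byte-level language modeling; a condensed sketch of that pattern (the token count and sequence lengths here are illustrative):
import torch
from mega_pytorch.mega_pytorch import Mega
from mega_pytorch.autoregressive_wrapper import AutoregressiveWrapper
model = AutoregressiveWrapper(Mega(
num_tokens = 256, # byte-level vocabulary
dim = 512,
depth = 8
))
seq = torch.randint(0, 256, (1, 513)) # one extra token, since forward shifts inputs / labels
loss = model(seq) # autoregressive cross-entropy loss
loss.backward()
prime = torch.randint(0, 256, (1, 1))
sampled = model.generate(prime, seq_len = 128) # (1, 128) newly generated tokens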
Todo
- add dynamic positional bias for best length extrapolation arch
Citations
@inproceedings{Ma2022MegaMA,
title = {Mega: Moving Average Equipped Gated Attention},
author = {Xuezhe Ma and Chunting Zhou and Xiang Kong and Junxian He and Liangke Gui and Graham Neubig and Jonathan May and Luke Zettlemoyer},
year = {2022}
}
.\lucidrains\Mega-pytorch\setup.py
# Import the setup and find_packages utilities
from setuptools import setup, find_packages
# Configure the package metadata
setup(
# package name
name = 'Mega-pytorch',
# find all packages, excluding none
packages = find_packages(exclude=[]),
# version number
version = '0.1.0',
# license
license='MIT',
# description
description = 'Mega - Pytorch',
# author
author = 'Phil Wang',
# author email
author_email = 'lucidrains@gmail.com',
# long description content type
long_description_content_type = 'text/markdown',
# project URL
url = 'https://github.com/lucidrains/Mega-pytorch',
# keywords
keywords = [
'artificial intelligence',
'deep learning',
'attention mechanism',
'exponential moving average',
'long range arena'
],
# install dependencies
install_requires=[
'einops>=0.4',
'scipy',
'torch>=1.6',
],
# classifiers
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\Mega-pytorch\train.py
# Import the required libraries
from mega_pytorch.mega_pytorch import Mega
from mega_pytorch.autoregressive_wrapper import AutoregressiveWrapper
import argparse
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# Constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 512
# Helper functions
# Cycle endlessly through a data loader (used by the train / val loaders below)
def cycle(loader):
while True:
for data in loader:
yield data
# Decode a single token into a character
def decode_token(token):
return str(chr(max(32, token)))
# Decode a sequence of tokens into a string
def decode_tokens(tokens):
return ''.join(list(map(decode_token, tokens)))
# Instantiate the GPT-like decoder model
model = Mega(
num_tokens = 256,
dim = 512,
depth = 8
)
model = AutoregressiveWrapper(model)
model.cuda()
# Prepare the enwik8 data
with gzip.open('./data/enwik8.gz') as file:
x = np.array(np.frombuffer(file.read(int(95e6)), dtype = np.uint8))
train_x, valid_x = np.split(x, [int(90e6)])
data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)
# Text sampling dataset
class TextSamplerDataset(Dataset):
def __init__(self, data, seq_len):
super().__init__()
self.data = data
self.seq_len = seq_len
def __getitem__(self, index):
rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
return full_seq.cuda()
def __len__(self):
return self.data.size(0) // self.seq_len
# Create the training and validation data loaders
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
# Optimizer
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
model.train()
for __ in range(GRADIENT_ACCUMULATE_EVERY):
loss = model(next(train_loader))
loss.backward()
print(f'training loss: {loss.item()}')
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optim.step()
optim.zero_grad()
if i % VALIDATE_EVERY == 0:
model.eval()
with torch.no_grad():
loss = model(next(val_loader))
print(f'validation loss: {loss.item()}')
if i % GENERATE_EVERY == 0:
model.eval()
inp = random.choice(val_dataset)[:-1]
prime = decode_tokens(inp)
print(f"\n\n {prime} \n\n {'-' * 80} \n")
sample = model.generate(inp[None, ...], GENERATE_LENGTH)
output_str = decode_tokens(sample[0])
print(output_str + "\n\n")
Data source
The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/
.\lucidrains\MEGABYTE-pytorch\MEGABYTE_pytorch\attend.py
# Import the required libraries
from collections import namedtuple
from functools import wraps
from packaging import version
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange
# Named tuple EfficientAttentionConfig holding three boolean flags
EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
# Helper that checks whether a value exists
def exists(val):
return val is not None
# Decorator that ensures the wrapped function only runs once
def once(fn):
called = False
@wraps(fn)
def inner(x):
nonlocal called
if called:
return
called = True
return fn(x)
return inner
# Wrap print with once, so each message is printed a single time
print_once = once(print)
# Main Attend class
class Attend(nn.Module):
def __init__(
self,
causal = False,
dropout = 0.,
flash = False
):
super().__init__()
self.dropout = dropout
self.attn_dropout = nn.Dropout(dropout)
self.causal = causal
self.flash = flash
assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
# Determine the efficient attention configurations for cuda and cpu
self.cpu_config = EfficientAttentionConfig(True, True, True)
self.cuda_config = None
if not torch.cuda.is_available() or not flash:
return
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
if device_properties.major == 8 and device_properties.minor == 0:
print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
self.cuda_config = EfficientAttentionConfig(True, False, False)
else:
print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
self.cuda_config = EfficientAttentionConfig(False, True, True)
# Build the causal mask
def get_mask(self, i, j, device):
return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1)
# Flash attention path
def flash_attn(self, q, k, v, mask = None, attn_bias = None):
_, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
# Single-headed key / values
if k.ndim == 3:
k = rearrange(k, 'b n d -> b 1 n d')
if v.ndim == 3:
v = rearrange(v, 'b n d -> b 1 n d')
# If a mask exists, expand it to a compatible shape
if exists(mask) and mask.ndim != 4:
mask = rearrange(mask, 'b j -> b 1 1 j')
mask = mask.expand(-1, heads, q_len, -1)
# Pick the efficient attention config matching the device
config = self.cuda_config if is_cuda else self.cpu_config
# Run pytorch 2.0 flash attention under torch.backends.cuda.sdp_kernel(**config._asdict())
with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask = mask,
dropout_p = self.dropout if self.training else 0.,
is_causal = self.causal
)
return out
# Forward pass over queries, keys, values and an optional mask
def forward(self, q, k, v, mask = None):
"""
einstein notation
b - batch
h - heads
n, i, j - sequence length (base sequence length, source, target)
d - feature dimension
"""
# Get the query / key sequence lengths and the device
q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
# Scaling factor
scale = q.shape[-1] ** -0.5
# Choose the einsum equation depending on whether keys / values are single-headed
kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
# Dispatch to the flash attention path if enabled
if self.flash:
return self.flash_attn(q, k, v, mask = mask)
# Similarity
sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
# Causal masking
if self.causal:
# Build the causal mask
causal_mask = self.get_mask(q_len, k_len, device)
# Mask out future positions with a large negative value
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
# Attention weights
attn = sim.softmax(dim=-1)
attn = self.attn_dropout(attn)
# Aggregate the values
out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
return out
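# A minimal shape sketch (illustrative tensors, flash disabled so it also runs on cpu):
# queries are (batch, heads, seq, dim_head); keys / values may share the head dimension
# as here, or be single-headed (batch, seq, dim_head) as handled by kv_einsum_eq above.
example_attend = Attend(causal = True, dropout = 0., flash = False)
q_ex = torch.randn(1, 8, 128, 64)
k_ex = torch.randn(1, 8, 128, 64)
v_ex = torch.randn(1, 8, 128, 64)
assert example_attend(q_ex, k_ex, v_ex).shape == (1, 8, 128, 64)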
.\lucidrains\MEGABYTE-pytorch\MEGABYTE_pytorch\megabyte.py
# Import the math library
import math
# Import the functools library
import functools
# Import zip_longest from itertools
from itertools import zip_longest
# Import the torch library
import torch
# Import torch.nn.functional and alias it as F
import torch.nn.functional as F
# Import nn and einsum from torch
from torch import nn, einsum
# Import rearrange, reduce, repeat, pack and unpack from einops
from einops import rearrange, reduce, repeat, pack, unpack
# Import the Rearrange layer from einops.layers.torch
from einops.layers.torch import Rearrange
# Import beartype from the beartype library
from beartype import beartype
# Import Tuple and Union from beartype.typing
from beartype.typing import Tuple, Union
# Import Attend from MEGABYTE_pytorch.attend
from MEGABYTE_pytorch.attend import Attend
# Import tqdm
from tqdm import tqdm
# helper functions
# Check whether a value exists
def exists(val):
return val is not None
# Return the value if it exists, otherwise the default
def default(val, d):
return val if exists(val) else d
# Pack a single tensor according to the given pattern
def pack_one(t, pattern):
return pack([t], pattern)
# Unpack a single tensor according to the given pattern
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
# Remainder needed to round num up to a multiple of mult
def remainder_to_mult(num, mult):
return (mult - num % mult) % mult
# Cast the input to a tuple, repeating it length times if it is not already a tuple
def cast_tuple(t, length = 1):
return t if isinstance(t, tuple) else ((t,) * length)
# Product of a sequence of numbers
def reduce_mult(nums):
return functools.reduce(lambda x, y: x * y, nums, 1)
# tensor helpers
# Natural logarithm, clamped to avoid values smaller than eps
def log(t, eps = 1e-20):
return torch.log(t.clamp(min = eps))
# Gumbel noise
def gumbel_noise(t):
noise = torch.zeros_like(t).uniform_(0, 1)
return -log(-log(noise))
# Gumbel sampling
def gumbel_sample(t, temperature = 1., dim = -1):
return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)
# Keep only the top-k largest logits, setting the rest to -inf
def top_k(logits, thres = 0.5):
num_logits = logits.shape[-1]
k = max(int((1 - thres) * num_logits), 1)
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float('-inf'))
probs.scatter_(1, ind, val)
return probs
# Token shift, borrowed from Peng et al. of RWKV
def token_shift(t):
t, t_shift = t.chunk(2, dim = -1)
t_shift = F.pad(t_shift, (0, 0, 1, -1))
return torch.cat((t, t_shift), dim = -1)
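# A small sketch of the effect (illustrative): half of the feature channels are shifted
# forward by one timestep (zero-padded at the front), so each position also sees the
# previous position's features, while the overall shape is unchanged.
example_t = torch.randn(1, 16, 32)
assert token_shift(example_t).shape == (1, 16, 32)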
# Rotary positional embedding
class RotaryEmbedding(nn.Module):
def __init__(self, dim, theta = 10000):
super().__init__()
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
@property
def device(self):
return next(self.buffers()).device
def forward(self, seq_len):
t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq)
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
freqs = torch.cat((freqs, freqs), dim = -1)
return freqs
# Rotate half of the tensor
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
# Apply the rotary positional embedding
def apply_rotary_pos_emb(pos, t):
return t * pos.cos() + rotate_half(t) * pos.sin()
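# A brief sketch of how the Attention class further below uses these pieces (shapes are
# illustrative): RotaryEmbedding(dim_head) returns per-position angles of size dim_head,
# applied identically to the queries and keys.
example_rotary = RotaryEmbedding(64)
example_freqs = example_rotary(128) # (128, 64)
example_q = torch.randn(1, 8, 128, 64) # (batch, heads, seq, dim_head)
assert apply_rotary_pos_emb(example_freqs, example_q).shape == example_q.shape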
# normalization
class RMSNorm(nn.Module):
def __init__(self, dim, eps = 1e-8):
super().__init__()
self.scale = dim ** -0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))
def forward(self, x):
norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
return x / norm.clamp(min = self.eps) * self.g
# helper classes
# FeedForward network: RMSNorm, linear, GELU, dropout, linear
def FeedForward(*, dim, mult = 4, dropout = 0.):
return nn.Sequential(
RMSNorm(dim),
nn.Linear(dim, dim * mult),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim * mult, dim)
)
# attention
class Attention(nn.Module):
def __init__(
self,
*,
dim,
dim_head = 64,
heads = 8,
dropout = 0.,
flash = False
):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
inner_dim = dim_head * heads
self.attend = Attend(
causal = True,
flash = flash,
dropout = dropout
)
self.dropout = nn.Dropout(dropout)
self.norm = RMSNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
# Attention forward pass
def forward(self, x, rotary_emb = None):
# Get the number of heads and the device
h, device = self.heads, x.device
# Normalize the input
x = self.norm(x)
# Project the input x to queries q, keys k and values v
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
# Rearrange the queries to shape 'b h n d'
q = rearrange(q, 'b n (h d) -> b h n d', h = h)
# If rotary positional embeddings are given, apply them to the queries and keys
if exists(rotary_emb):
q, k = map(lambda t: apply_rotary_pos_emb(rotary_emb, t), (q, k))
# Compute attention
out = self.attend(q, k, v)
# Rearrange the output back to shape 'b n (h d)'
out = rearrange(out, 'b h n d -> b n (h d)')
# Final output projection
return self.to_out(out)
# Transformer class
class Transformer(nn.Module):
# Initialization
def __init__(
self,
*,
dim, # dimension
layers, # number of layers
dim_head = 64, # dimension per head
heads = 8, # number of heads
attn_dropout = 0., # attention dropout
ff_dropout = 0., # feedforward dropout
ff_mult = 4, # feedforward expansion factor
rel_pos = True, # whether to use rotary relative positions
flash_attn = False # whether to use flash attention
):
super().__init__() # call the parent constructor
self.rotary_emb = RotaryEmbedding(dim_head) if rel_pos else None # rotary embedding if relative positions are used, otherwise None
self.layers = nn.ModuleList([]) # empty nn.ModuleList
# Create the requested number of attention / feedforward layers
for _ in range(layers):
self.layers.append(nn.ModuleList([
Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = flash_attn), # attention
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout) # feedforward
]))
self.norm = RMSNorm(dim) # RMS normalization
# Forward pass over input x
def forward(self, x):
n = x.shape[-2] # sequence length of x
rotary_emb = self.rotary_emb(n) if exists(self.rotary_emb) else None # rotary embeddings for this length, if enabled
# Iterate over each attention / feedforward pair
for attn, ff in self.layers:
x = attn(token_shift(x), rotary_emb = rotary_emb) + x # attention with residual connection
x = ff(token_shift(x)) + x # feedforward with residual connection
return self.norm(x) # return the normalized output
# Main MEGABYTE class
class MEGABYTE(nn.Module):
@beartype
# Initialization
def __init__(
self,
*,
num_tokens, # number of tokens
dim: Union[Tuple, int], # dimension
depth: Tuple, # depth
max_seq_len: Tuple, # maximum sequence lengths
dim_head = 64, # dimension per head
heads = 8, # number of heads
attn_dropout = 0., # attention dropout
ff_mult = 4, # feedforward expansion factor
ff_dropout = 0., # feedforward dropout
pad_id = 0, # id of the padding token
rel_pos = False, # whether to use rotary relative positions
pos_emb = False, # whether to use absolute positional embeddings
flash_attn = False # whether to use flash attention
):
# Call the parent constructor
super().__init__()
# Simplified per-stage configuration
# depth = (2, 2, 4) means depth 2 for the first stage, 2 for the second, and 4 for the third
# max_seq_len = (16, 8, 4) means a maximum sequence length of 16 for the first stage, 8 for the second, and 4 for the last
assert isinstance(depth, tuple) and isinstance(max_seq_len, tuple)
assert len(depth) == len(max_seq_len)
self.stages = len(depth)
dim = cast_tuple(dim, self.stages)
assert len(dim) == self.stages
coarsest_dim, *_, fine_dim = dim
self.max_seq_len = max_seq_len
# Initialize the start tokens
self.start_tokens = nn.ParameterList([nn.Parameter(torch.randn(h_dim)) for h_dim, seq_len in zip(dim, max_seq_len)])
# Initialize the positional embeddings
self.pos_embs = nn.ModuleList([nn.Embedding(seq_len, h_dim) for h_dim, seq_len in zip(dim, max_seq_len)]) if pos_emb else None
self.token_embs = nn.ModuleList([])
patch_size = 1
# Add the token embeddings
self.token_embs.append(nn.Embedding(num_tokens, fine_dim))
for dim_out, seq_len in zip(reversed(dim[:-1]), reversed(max_seq_len[1:])):
patch_size *= seq_len
# Build the token embedding pipeline for this stage
self.token_embs.append(nn.Sequential(
nn.Embedding(num_tokens, fine_dim),
Rearrange('... r d -> ... (r d)'),
nn.LayerNorm(patch_size * fine_dim),
nn.Linear(patch_size * fine_dim, dim_out),
nn.LayerNorm(dim_out)
))
self.transformers = nn.ModuleList([])
self.to_next_transformer_projections = nn.ModuleList([])
for h_dim, next_h_dim, stage_depth, next_seq_len in zip_longest(dim, dim[1:], depth, max_seq_len[1:]):
# Add the Transformer for this stage
self.transformers.append(Transformer(
dim = h_dim,
layers = stage_depth,
dim_head = dim_head,
heads = heads,
attn_dropout = attn_dropout,
ff_dropout = ff_dropout,
ff_mult = ff_mult,
rel_pos = rel_pos,
flash_attn = flash_attn
))
proj = nn.Identity()
if exists(next_h_dim) and next_h_dim != dim:
proj = nn.Sequential(
Rearrange('b ... d -> b (...) d'),
nn.Linear(h_dim, next_h_dim * next_seq_len),
Rearrange('b m (n d) -> (b m) n d', n = next_seq_len)
)
self.to_next_transformer_projections.append(proj)
# Linear layer producing the output logits
self.to_logits = nn.Linear(fine_dim, num_tokens)
self.pad_id = pad_id
# Generate text
def generate(self, prime = None, filter_thres = 0.9, temperature = 1., default_batch_size = 1):
total_seq_len = reduce_mult(self.max_seq_len)
device = next(self.parameters()).device
if not exists(prime):
prime = torch.empty((default_batch_size, 0), dtype = torch.long, device = device)
seq = prime
batch = seq.shape[0]
# Autoregressively generate the token sequence
for _ in tqdm(range(total_seq_len - seq.shape[-1])):
logits = self.forward(seq)[:, -1]
logits = top_k(logits, thres = filter_thres)
sampled = gumbel_sample(logits, dim = -1, temperature = temperature)
seq = torch.cat((seq, rearrange(sampled, 'b -> b 1')), dim = -1)
return seq.reshape(batch, *self.max_seq_len)
# Handle the special case of sampling from empty input (start tokens only)
def forward_empty(self, batch_size):
# The previous stage's token representation starts as None
prev_stage_tokens_repr = None
# Iterate over the start tokens, transformer and projection of each stage
for stage_start_tokens, transformer, proj in zip(self.start_tokens, self.transformers, self.to_next_transformer_projections):
# Repeat the start tokens across the batch
tokens = repeat(stage_start_tokens, 'd -> b 1 d', b = batch_size)
# If a previous stage representation exists, add it to the current tokens
if exists(prev_stage_tokens_repr):
tokens = tokens + prev_stage_tokens_repr[..., :tokens.shape[-2], :]
# Run the tokens through this stage's transformer
tokens = transformer(tokens)
# Project to obtain the representation passed to the next stage
prev_stage_tokens_repr = proj(tokens)
# Return the logits
return self.to_logits(tokens)
# Forward pass, taking the input ids and a flag for whether to return the loss
def forward(self, ids, return_loss = False):
# Get the batch size
batch = ids.shape[0]
# The input ids must have 2 dimensions or self.stages + 1 dimensions
assert ids.ndim in {2, self.stages + 1}
# Check whether the ids come in flattened form
flattened_dims = ids.ndim == 2
ids_orig_ndim = ids.ndim
# If ids is empty, fall back to forward_empty
if ids.numel() == 0:
return self.forward_empty(ids.shape[0])
# If flattened, automatically pad to the nearest multiple of the local sequence lengths
if flattened_dims:
# Sequence length
seq_len = ids.shape[-1]
# Compute the padding
multiple_of = reduce_mult(self.max_seq_len[1:])
padding = remainder_to_mult(seq_len, multiple_of)
# Pad the ids
ids = F.pad(ids, (0, padding), value = self.pad_id)
ids = ids.reshape(batch, -1, *self.max_seq_len[1:])
# Get the shape and device of the ids
b, *prec_dims, device = *ids.shape, ids.device
# Check the dimensions
assert prec_dims[0] <= self.max_seq_len[0], 'the first dimension of your axial autoregressive transformer must be less than the first tuple element of max_seq_len (like any autoregressive transformer)'
assert tuple(prec_dims[1:]) == tuple(self.max_seq_len[1:]), 'all subsequent dimensions must match exactly'
# Get the tokens for all hierarchy stages, reducing along the appropriate dimensions and adding absolute positional embeddings where enabled
tokens_at_stages = []
pos_embs = default(self.pos_embs, (None,))
for ind, pos_emb, token_emb in zip_longest(range(len(prec_dims)), pos_embs, self.token_embs):
is_first = ind == 0
tokens = token_emb(ids)
if exists(pos_emb):
positions = pos_emb(torch.arange(tokens.shape[-2], device = device))
tokens = tokens + positions
tokens_at_stages.insert(0, tokens)
if is_first:
continue
ids = rearrange(ids, '... m n -> ... (m n)')
# The previous hierarchy's (non-flattened) representation, starting as None
prev_stage_tokens_repr = None
# The spatial tokens are the tokens reduced along the depth dimension plus the spatial positions
for stage_start_tokens, stage_tokens, transformer, proj in zip(self.start_tokens, tokens_at_stages, self.transformers, self.to_next_transformer_projections):
stage_tokens, ps = pack_one(stage_tokens, '* n d')
stage_start_tokens = repeat(stage_start_tokens, 'f -> b 1 f', b = stage_tokens.shape[0])
# Concatenate the start token
stage_tokens = torch.cat((
stage_start_tokens,
stage_tokens,
), dim = -2)
# Sum in the previous hierarchy's representation
if exists(prev_stage_tokens_repr):
prev_stage_tokens_repr = F.pad(prev_stage_tokens_repr, (0, 0, 1, 0), value = 0.)
stage_tokens = stage_tokens + prev_stage_tokens_repr
attended = transformer(stage_tokens)
attended = unpack_one(attended, ps, '* n d')
# Project for the next hierarchy stage
prev_stage_tokens_repr = proj(attended[..., :-1, :])
# Project to logits
logits = self.to_logits(attended)
start_tokens = logits[(slice(None), *((0,) * (logits.ndim - 2)), slice(None))]
start_tokens = rearrange(start_tokens, 'b d -> b 1 d')
logits = logits[..., 1:, :]
if not return_loss:
if flattened_dims:
logits = rearrange(logits, 'b ... c -> b (...) c')
logits = logits[:, :seq_len]
return logits
logits = rearrange(logits, 'b ... c -> b (...) c')
logits = torch.cat((start_tokens, logits), dim = -2)
preds = rearrange(logits, 'b n c -> b c n')
labels = rearrange(ids, 'b ... -> b (...)')
loss = F.cross_entropy(
preds[..., :-1],
labels,
ignore_index = self.pad_id
)
return loss
.\lucidrains\MEGABYTE-pytorch\MEGABYTE_pytorch\__init__.py
# Import the MEGABYTE class from the MEGABYTE_pytorch package
from MEGABYTE_pytorch.megabyte import MEGABYTE

MEGABYTE - Pytorch
Implementation of MEGABYTE, Predicting Million-byte Sequences with Multiscale Transformers, in Pytorch. Took the liberty to generalize it even further so one can have multiple local models.
Similar independent research that is a further generalization
Appreciation
- Stability and 🤗 Huggingface for the generous sponsorship to work on and open source cutting edge artificial intelligence research
Install
$ pip install MEGABYTE-pytorch
Usage
import torch
from MEGABYTE_pytorch import MEGABYTE
model = MEGABYTE(
num_tokens = 16000, # number of tokens
dim = (512, 256), # transformer model dimension (512 for coarsest, 256 for fine in this example)
max_seq_len = (1024, 4), # sequence length for global and then local. this can be more than 2
depth = (6, 4), # number of layers for global and then local. this can be more than 2, but length must match the max_seq_len's
dim_head = 64, # dimension per head
heads = 8, # number of attention heads
flash_attn = True # use flash attention
)
x = torch.randint(0, 16000, (1, 1024, 4))
loss = model(x, return_loss = True)
loss.backward()
# then after much training
logits = model(x)
# and sample from the logits accordingly
# or you can use the generate function
sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4)
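generate also accepts a prime of already-flattened token ids to continue from; a brief sketch (the prefix length here is arbitrary):
prime = x.reshape(1, -1)[:, :512] # flatten the (1, 1024, 4) ids and keep a prefix
sampled = model.generate(prime = prime, temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4)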
Test
Train on character-level enwik8 with patches of size 4 - length 8192
$ python train.py
Citations
@misc{yu2023megabyte,
title = {MEGABYTE: Predicting Million-byte Sequences with Multiscale Transformers},
author = {Lili Yu and Dániel Simig and Colin Flaherty and Armen Aghajanyan and Luke Zettlemoyer and Mike Lewis},
year = {2023},
eprint = {2305.07185},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@misc{https://doi.org/10.48550/arxiv.2302.01327,
doi = {10.48550/ARXIV.2302.01327},
url = {https://arxiv.org/abs/2302.01327},
author = {Kumar, Manoj and Dehghani, Mostafa and Houlsby, Neil},
title = {Dual PatchNorm},
publisher = {arXiv},
year = {2023},
copyright = {Creative Commons Attribution 4.0 International}
}
@inproceedings{dao2022flashattention,
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}
@software{peng_bo_2021_5196578,
author = {PENG Bo},
title = {BlinkDL/RWKV-LM: 0.01},
month = {aug},
year = {2021},
publisher = {Zenodo},
version = {0.01},
doi = {10.5281/zenodo.5196578},
url = {https://doi.org/10.5281/zenodo.5196578}
}
@article{Kazemnejad2023TheIO,
title = {The Impact of Positional Encoding on Length Generalization in Transformers},
author = {Amirhossein Kazemnejad and Inkit Padhi and Karthikeyan Natesan Ramamurthy and Payel Das and Siva Reddy},
journal = {ArXiv},
year = {2023},
volume = {abs/2305.19466}
}
@misc{su2021roformer,
title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
year = {2021},
eprint = {2104.09864},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}