Lucidrains 系列项目源码解析(六十九)
.\lucidrains\PaLM-jax\palm_jax\__init__.py
# 从 palm_jax.palm 模块中导入 PaLM 类
from palm_jax.palm import PaLM

PaLM - Jax
Implementation of the specific Transformer architecture from PaLM - Scaling Language Modeling with Pathways - in Jax using Equinox
May as well start doing more Jax work, given Facebook (Meta's) uncertain future
Flax version from Enrico!
Install
$ pip install PaLM-jax
Usage
The model is built so that it does not require `vmap` at all: inputs can have any number of leading dimensions.
import jax
from palm_jax import PaLM
key = jax.random.PRNGKey(0)
model = PaLM(
num_tokens = 20000,
dim = 512,
depth = 12,
heads = 8,
dim_head = 64,
key = key
)
seq = jax.random.randint(key, (1, 1024), 0, 20000)
logits = model(seq) # (1, 1024, 20000)
The 540B PaLM in the paper would be
model = PaLM(
num_tokens = 256000,
dim = 18432,
depth = 118,
heads = 48,
dim_head = 256,
key = key
)
That's all it is. Attention (and scale) is all we need.
Todos
- bring in optax and setup a basic training on enwik8 (thanks to Enrico)
- ALiBi positional encoding (https://arxiv.org/abs/2108.12409) for PaLM-lite
Citations
@inproceedings{Chowdhery2022PaLMSL,
title = {PaLM: Scaling Language Modeling with Pathways},
author = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
year = {2022}
}
@misc{press2021ALiBi,
title = {Train Short, Test Long: Attention with Linear Biases Enable Input Length Extrapolation},
author = {Ofir Press and Noah A. Smith and Mike Lewis},
year = {2021},
url = {https://ofir.io/train_short_test_long.pdf}
}
@article{Rae2021ScalingLM,
title = {Scaling Language Models: Methods, Analysis \& Insights from Training Gopher},
author = {Jack W. Rae and Sebastian Borgeaud and Trevor Cai and Katie Millican and Jordan Hoffmann and Francis Song and John Aslanides and Sarah Henderson and Roman Ring and Susannah Young and Eliza Rutherford and Tom Hennigan and Jacob Menick and Albin Cassirer and Richard Powell and George van den Driessche and Lisa Anne Hendricks and Maribeth Rauh and Po-Sen Huang and Amelia Glaese and Johannes Welbl and Sumanth Dathathri and Saffron Huang and Jonathan Uesato and John F. J. Mellor and Irina Higgins and Antonia Creswell and Nathan McAleese and Amy Wu and Erich Elsen and Siddhant M. Jayakumar and Elena Buchatskaya and David Budden and Esme Sutherland and Karen Simonyan and Michela Paganini and L. Sifre and Lena Martens and Xiang Lorraine Li and Adhiguna Kuncoro and Aida Nematzadeh and Elena Gribovskaya and Domenic Donato and Angeliki Lazaridou and Arthur Mensch and Jean-Baptiste Lespiau and Maria Tsimpoukelli and N. K. Grigorev and Doug Fritz and Thibault Sottiaux and Mantas Pajarskas and Tobias Pohlen and Zhitao Gong and Daniel Toyama and Cyprien de Masson d'Autume and Yujia Li and Tayfun Terzi and Vladimir Mikulik and Igor Babuschkin and Aidan Clark and Diego de Las Casas and Aurelia Guy and Chris Jones and James Bradbury and Matthew G. Johnson and Blake A. Hechtman and Laura Weidinger and Iason Gabriel and William S. Isaac and Edward Lockhart and Simon Osindero and Laura Rimell and Chris Dyer and Oriol Vinyals and Kareem W. Ayoub and Jeff Stanway and L. L. Bennett and Demis Hassabis and Koray Kavukcuoglu and Geoffrey Irving},
journal = {ArXiv},
year = {2021},
volume = {abs/2112.11446}
}
@inproceedings{Zhang2019RootMS,
title = {Root Mean Square Layer Normalization},
author = {Biao Zhang and Rico Sennrich},
booktitle = {NeurIPS},
year = {2019}
}
.\lucidrains\PaLM-jax\setup.py
# 导入设置工具和查找包工具
from setuptools import setup, find_packages
# Distribution metadata for the PaLM-jax package. Keyword order is irrelevant
# to setuptools, so the arguments are grouped by topic.
setup(
    # identity
    name = 'PaLM-jax',
    version = '0.1.2',
    description = 'PaLM: Scaling Language Modeling with Pathways - Jax',
    # NOTE(review): no long_description is passed, so this content type has
    # nothing to apply to — confirm whether the README should be included.
    long_description_content_type = 'text/markdown',
    url = 'https://github.com/lucidrains/PaLM-jax',
    license = 'MIT',
    # authorship
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    # contents: every package found under the project root
    packages = find_packages(exclude = []),
    # runtime dependencies
    install_requires = [
        'einops==0.4',
        'equinox>=0.5',
        'jax>=0.3.4',
        'jaxlib>=0.1',
        'optax',
        'numpy',
    ],
    # discovery metadata
    keywords = [
        'artificial intelligence',
        'deep learning',
        'transformers',
        'attention mechanism',
    ],
    classifiers = [
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\PaLM-jax\train.py
# 导入必要的库
import os
from random import randrange
from functools import partial
import tqdm
import gzip
import numpy as np
import jax
import jax.numpy as jnp
from jax import nn
# 导入自定义库
import equinox as eqx
from optax import adam, clip_by_global_norm, chain, apply_every
# 导入自定义模块
from palm_jax.palm_lite import PaLM
from palm_jax.utils import sample
# Turn off XLA's up-front GPU memory preallocation for this process.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

# optimisation hyperparameters
NUM_BATCHES = 100_000
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
MAX_GRAD_NORM = 0.5

# evaluation / sampling cadence (in outer training steps)
VALIDATE_EVERY = 100
SAMPLE_EVERY = 500

# number of tokens per training window
SEQ_LEN = 1024
def cycle(loader):
    """Yield items from ``loader`` endlessly, re-iterating it after each pass.

    ``loader`` is expected to be re-iterable (a list, a DataLoader, ...): it is
    iterated afresh on every pass. If a whole pass produces no items (an empty
    loader, or a one-shot iterator that is already exhausted), the generator
    stops instead of spinning forever in a busy loop that never yields.
    """
    while True:
        produced_any = False
        for data in loader:
            produced_any = True
            yield data
        if not produced_any:
            return
def decode_token(token):
    """Map one byte-level token id to a printable character.

    Ids below 32 (ASCII control characters) are shown as a space.
    """
    code_point = token if token > 32 else 32
    return chr(code_point)


def decode_tokens(tokens):
    """Decode a sequence of token ids into a string, one character per id."""
    return ''.join(decode_token(t) for t in tokens)
# Load (up to) the first 95M bytes of enwik8: the first 90M are the training
# split, the rest the validation split.
with gzip.open('./data/enwik8.gz') as file:
    # np.fromstring is deprecated for binary input; np.frombuffer is its
    # replacement. The result is a read-only uint8 view, which is fine here
    # because the data is only indexed, never written.
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8)
data_train, data_val = np.split(X, [int(90e6)])
def sample_seq_from_data(data, *, seq_len, batch_size):
    """Draw ``batch_size`` random contiguous windows of ``seq_len`` tokens.

    Args:
        data: 1-D array of token ids.
        seq_len: length of each window.
        batch_size: number of windows to draw.

    Returns:
        Array of shape ``(batch_size, seq_len)`` gathered from ``data``.

    Raises:
        ValueError: if ``data`` holds fewer than ``seq_len`` tokens.
    """
    total_seq_len = data.shape[0]
    if total_seq_len < seq_len:
        raise ValueError(
            f'data has {total_seq_len} tokens, fewer than seq_len={seq_len}')
    base_arange = np.arange(seq_len)
    # The last valid start index is total_seq_len - seq_len. randint's upper
    # bound is exclusive, hence the + 1; without it the final window was never
    # drawn, and data of exactly seq_len tokens raised.
    start_indices = np.random.randint(0, total_seq_len - seq_len + 1, (batch_size,))
    token_indices = start_indices[:, None] + base_arange
    return data[token_indices]
# Sampler with the window length and batch size fixed once.
sample_seq_fn = partial(sample_seq_from_data, batch_size = BATCH_SIZE, seq_len = SEQ_LEN)

# Root PRNG key, handed to the model constructor below.
key = jax.random.PRNGKey(0)

# Byte-level model: one token per possible byte value.
model = PaLM(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    heads = 8,
    dim_head = 64,
    key = key,
)
def cross_entropy(logits, targets, axis = -1):
    """Mean negative log-likelihood of integer ``targets`` under ``logits``.

    ``targets`` is expected to have the shape of ``logits`` without ``axis``.
    The log-probability of each target class is gathered along ``axis``, and
    the result is averaged over every position.
    """
    log_probs = nn.log_softmax(logits, axis = axis)
    target_idx = jnp.expand_dims(targets, axis = axis)
    picked = jnp.take_along_axis(log_probs, target_idx, axis = axis)
    return -picked.mean()
@eqx.filter_value_and_grad
def loss_fn(model, data):
    """Next-token prediction loss.

    The decorator makes calls return ``(loss, grads)`` with respect to ``model``.
    """
    # shift by one: the token at position t is predicted from tokens < t
    inputs = data[:, :-1]
    targets = data[:, 1:]
    return cross_entropy(model(inputs), targets, axis = -1)
# Optimizer pipeline (applied in this order to each gradient update).
optim = chain(
    clip_by_global_norm(MAX_GRAD_NORM),       # bound the global gradient norm
    adam(LEARNING_RATE),                      # Adam step
    apply_every(GRADIENT_ACCUMULATE_EVERY),   # accumulate, apply every k calls
)
# optimizer state built from the model's pytree
optim_state = optim.init(model)
@eqx.filter_jit(kwargs=dict(data=True))
def train_step(model, data, optim_state):
    """One jitted optimisation step; returns ``(new_model, new_optim_state, loss)``."""
    loss, grads = loss_fn(model, data)
    updates, new_optim_state = optim.update(grads, optim_state)
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_optim_state, loss
# Training loop. Each outer step pushes GRADIENT_ACCUMULATE_EVERY micro-batches
# through train_step (apply_every in the optimizer chain decides when the
# accumulated update is applied), then periodically samples from the model.
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        data = sample_seq_fn(data_train)
        model, optim_state, loss = train_step(model, data, optim_state)

    print(f'loss: {loss.item()}')

    if i % SAMPLE_EVERY == 0:
        valid_data = sample_seq_fn(data_val)
        prime = valid_data[0][:100]
        prime_str = decode_tokens(prime)
        print(prime_str, "\n", "*" * 40)

        # Split off a fresh subkey for each sampling round. Reusing one key
        # replays identical random draws every time, which is a JAX PRNG misuse.
        key, sample_key = jax.random.split(key)
        sampled = sample(sample_key, model, prime, SEQ_LEN, top_k = 25)
        sampled_str = decode_tokens(sampled[100:])
        print(sampled_str)
Data source
The enwik8 data was downloaded from the Hutter Prize page: http://prize.hutter1.net/
PaLM-pytorch with Deepspeed for Enwik8
DeepSpeed is the framework Microsoft used to train what was, at the time of writing, the world's largest attention model (17B parameters). They have open-sourced it, and it works with PaLM-pytorch!
1. First install DeepSpeed following the instructions in the official repository: https://github.com/microsoft/DeepSpeed
2. Then run the following command in this folder:
$ deepspeed train.py --deepspeed --deepspeed_config ds_config.json
.\lucidrains\PaLM-pytorch\examples\enwik8_deepspeed\train.py
import deepspeed
# 导入 deepspeed 库
from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
# 从 palm_pytorch 库中导入 PaLM 类和 AutoregressiveWrapper 类
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from einops import rearrange
from torch import einsum, nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
# 导入所需的库
def add_argument():
    """Build and parse the command-line arguments for this script.

    Besides the script's own flags, DeepSpeed's configuration arguments
    (e.g. ``--deepspeed`` and ``--deepspeed_config``) are registered via
    ``deepspeed.add_config_arguments``.

    Returns:
        argparse.Namespace with the parsed arguments.
    """
    # argparse is not among this script's top-level imports; without this
    # import the call below raised NameError at startup.
    import argparse

    parser = argparse.ArgumentParser(description='enwik8')
    parser.add_argument('--with_cuda', default=False, action='store_true',
                        help='use CPU in case there\'s no GPU support')
    parser.add_argument('--use_ema', default=False, action='store_true',
                        help='whether use exponential moving average')
    parser.add_argument('-b', '--batch_size', default=32, type=int,
                        help='mini-batch size (default: 32)')
    parser.add_argument('-e', '--epochs', default=30, type=int,
                        help='number of total epochs (default: 30)')
    parser.add_argument('--local_rank', type=int, default=-1,
                        help='local rank passed from distributed launcher')

    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()
    return args
# ---- constants ----
# training schedule
EPOCHS = 20
GRADIENT_ACCUMULATE_EVERY = 4

# how often (in steps) to validate and to sample text
VALIDATE_EVERY = 100
GENERATE_EVERY = 500

# number of tokens to sample, and the training context length
GENERATE_LENGTH = 512
SEQ_LEN = 1024
# ---- helpers ----
def decode_token(token):
    """Turn a byte-level token id into a printable character (ids < 32 become a space)."""
    return chr(32 if token < 32 else token)


def decode_tokens(tokens):
    """Concatenate the decoded characters of ``tokens`` into a single string."""
    chars = [decode_token(t) for t in tokens]
    return "".join(chars)
# ---- model ----
# Byte-level PaLM decoder, wrapped for autoregressive training and generation,
# then moved to the GPU (Module.cuda() returns the module itself).
model = AutoregressiveWrapper(
    PaLM(num_tokens = 256, dim = 512, depth = 8),
    max_seq_len = 2048,
).cuda()
# ---- data ----
# Read (up to) the first 95M bytes of enwik8: 90M for training, the rest for validation.
with gzip.open('./data/enwik8.gz') as file:
    # np.fromstring is deprecated for binary input. np.frombuffer replaces it,
    # but returns a read-only view, and torch.from_numpy warns on non-writable
    # arrays, so take a writable copy.
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
trX, vaX = np.split(X, [int(90e6)])
data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
class TextSamplerDataset(Dataset):
    """Dataset of random windows over one long 1-D token tensor.

    Every item is a contiguous slice of ``seq_len + 1`` tokens starting at a
    random offset; the index argument is ignored. ``len()`` reports how many
    ``seq_len``-sized chunks fit in the data.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # pick a start so that start + seq_len + 1 <= len(data)
        start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        window = self.data[start : start + self.seq_len + 1]
        return window.long()

    def __len__(self):
        return self.data.size(0) // self.seq_len
# Random-window datasets over the training and validation splits.
train_dataset, val_dataset = (
    TextSamplerDataset(split, SEQ_LEN) for split in (data_train, data_val)
)

# ---- deepspeed ----
cmd_args = add_argument()
# DeepSpeed wraps the model and returns the engine, the optimizer and a
# loader over train_dataset.
model_engine, optimizer, trainloader, _ = deepspeed.initialize(
    args=cmd_args,
    model=model,
    model_parameters=model.parameters(),
    training_data=train_dataset,
)
# ---- training ----
for _ in range(EPOCHS):
    for i, data in enumerate(trainloader):
        model_engine.train()
        data = data.to(model_engine.local_rank)

        # the wrapped AutoregressiveWrapper returns the next-token loss
        loss = model_engine(data)
        model_engine.backward(loss)
        # NOTE(review): manual clipping kept as before; DeepSpeed can also clip
        # via "gradient_clipping" in ds_config.json — confirm which is intended.
        torch.nn.utils.clip_grad_norm_(model_engine.parameters(), 0.5)
        model_engine.step()

        # This script never divides the loss by GRADIENT_ACCUMULATE_EVERY, so
        # it is already the per-batch value; scaling it up misreported the loss.
        print(loss.item())

        if i % VALIDATE_EVERY == 0:
            model.eval()
            with torch.no_grad():
                inp = random.choice(val_dataset)[:-1]
                loss = model(inp[None, :].cuda())
                print(f'validation loss: {loss.item()}')

        if i % GENERATE_EVERY == 0:
            model.eval()
            inp = random.choice(val_dataset)[:-1]
            prime = decode_tokens(inp)
            # The previous f-string with %s placeholders printed the literal
            # "%s \n\n %s" followed by a tuple; format the prime properly.
            print('%s \n\n %s' % (prime, '*' * 100))
            sample = model.generate(inp[None, ...].cuda(), GENERATE_LENGTH)
            output_str = decode_tokens(sample[0])
            print(output_str)
.\lucidrains\PaLM-pytorch\palm_pytorch\autoregressive_wrapper.py
# 导入 torch 库
import torch
# 导入 torch.nn.functional 模块,并重命名为 F
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 函数
from einops import rearrange
# 从 torch 模块中导入 nn 模块
from torch import nn
def exists(val):
    """Return True when ``val`` is anything other than None."""
    return not (val is None)
def eval_decorator(fn):
    """Run ``fn(model, *args, **kwargs)`` with ``model`` switched to eval mode.

    ``model.training`` is recorded before the call and restored with
    ``model.train(...)`` afterwards. Restoration also happens when ``fn``
    raises, so a failed call cannot leave the model stuck in eval mode.
    ``functools.wraps`` preserves ``fn``'s name and docstring.
    """
    from functools import wraps

    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            model.train(was_training)

    return inner
def top_k(logits, thres=0.9):
    """Keep only the top ``(1 - thres)`` fraction of logits along the last dim.

    Every other entry is set to ``-inf`` so that a following softmax gives it
    zero probability. At least one logit is always kept: with a small
    vocabulary or a ``thres`` close to 1 the computed ``k`` would otherwise be
    0, every entry would become ``-inf`` and the softmax would yield NaNs.

    Args:
        logits: score tensor; filtering happens along its last dimension.
        thres: fraction of entries to discard (0.9 keeps the top 10%).

    Returns:
        Tensor with the same shape as ``logits``.
    """
    size = logits.shape[-1]
    k = min(max(int((1 - thres) * size), 1), size)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float("-inf"))
    # topk works on the last dim, so scatter back along it (same as dim 1 for 2-D input)
    probs.scatter_(-1, ind, val)
    return probs
class AutoregressiveWrapper(nn.Module):
    """Wrap a token-to-logits network for next-token training and sampling.

    ``forward`` returns the cross-entropy of predicting every token from the
    tokens before it; ``generate`` extends a prompt one sampled token at a time.
    """

    def __init__(self, net, max_seq_len=2048, pad_value=0):
        super().__init__()
        self.max_seq_len = max_seq_len  # stored; not read by this class
        self.pad_value = pad_value      # fill value for positions after an eos token
        self.net = net

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        start_tokens,
        seq_len,
        eos_token=None,
        temperature=1.0,
        filter_thres=0.9,
        **kwargs
    ):
        """Sample ``seq_len`` tokens after the 2-D prompt ``start_tokens``.

        Each step takes the logits of the last position, applies ``top_k``
        filtering and ``temperature``, and draws one token per row. If
        ``eos_token`` is given and every row has produced it, positions after
        each row's first eos are replaced with ``pad_value`` and sampling stops.
        Only the newly generated part (prompt excluded) is returned.
        """
        _, prompt_len = start_tokens.shape
        tokens = start_tokens

        for _ in range(seq_len):
            last_logits = self.net(tokens, **kwargs)[:, -1, :]
            filtered = top_k(last_logits, thres=filter_thres)
            probs = F.softmax(filtered / temperature, dim=-1)
            next_token = torch.multinomial(probs, 1)
            tokens = torch.cat((tokens, next_token), dim=-1)

            if not exists(eos_token):
                continue
            hit_eos = tokens == eos_token
            if not hit_eos.any(dim=-1).all():
                continue
            # every row has an eos: blank out everything after each row's first eos
            after_eos = F.pad(hit_eos, (1, -1)).float().cumsum(dim=-1) >= 1
            tokens = tokens.masked_fill(after_eos, self.pad_value)
            break

        return tokens[:, prompt_len:]

    def forward(self, x, **kwargs):
        """Mean next-token cross-entropy for the token batch ``x``."""
        inputs, targets = x[:, :-1], x[:, 1:]
        logits = self.net(inputs, **kwargs)
        # swap the last two axes so the class dim comes second, as cross_entropy
        # expects (assumes net returns (batch, seq, num_tokens) — confirm)
        return F.cross_entropy(rearrange(logits, "b c n -> b n c"), targets)
.\lucidrains\PaLM-pytorch\palm_pytorch\palm_lite.py
# 导入 torch 库
import torch
# 导入 torch.nn.functional 模块
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 从 torch 库中导入 einsum 和 nn 模块
from torch import einsum, nn
# 从 math 库中导入 log2 和 floor 函数
from math import log2, floor
def exists(val):
    """True if ``val`` is not None."""
    if val is None:
        return False
    return True
# normalization

class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-feature gain ``g``.

    Computes ``x / max(||x||_2 * dim**-0.5, eps) * g`` over the last dimension.
    """

    def __init__(self, dim, eps = 1e-8):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # the L2 norm scaled by 1/sqrt(dim) is the root mean square of x
        rms = x.norm(dim = -1, keepdim = True) * self.scale
        denom = rms.clamp(min = self.eps)
        return x / denom * self.g
# AliBi

class AlibiPositionalBias(nn.Module):
    """ALiBi attention bias ``-slope_h * |j - i|`` to add to attention scores.

    One slope per head, from the geometric sequence in ``_get_slopes``. The
    bias for the largest (i, j) computed so far is cached in a non-persistent
    buffer and sliced for smaller inputs.
    """

    def __init__(self, heads, **kwargs):
        super().__init__()
        self.heads = heads
        slopes = torch.Tensor(self._get_slopes(heads))
        # (h,) -> (h, 1, 1) so the slopes broadcast over an (i, j) bias matrix
        slopes = slopes.view(-1, 1, 1)
        self.register_buffer('slopes', slopes, persistent = False)
        self.register_buffer('bias', None, persistent = False)

    def get_bias(self, i, j, device):
        """Return ``-|j - i|`` for query rows 0..i-1 and key columns 0..j-1, shape (1, i, j)."""
        i_arange = torch.arange(i, device = device)
        j_arange = torch.arange(j, device = device)
        bias = -torch.abs(j_arange[None, None, :] - i_arange[None, :, None])
        return bias

    @staticmethod
    def _get_slopes(heads):
        """Per-head slopes; non-power-of-two head counts interleave two sequences."""
        def get_slopes_power_of_2(n):
            start = (2**(-2**-(log2(n)-3)))
            ratio = start
            return [start*ratio**i for i in range(n)]

        if log2(heads).is_integer():
            return get_slopes_power_of_2(heads)

        closest_power_of_2 = 2 ** floor(log2(heads))
        return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]

    def forward(self, qk_sim):
        """Return the bias for attention scores ``qk_sim`` of shape (..., h, i, j)."""
        h, i, j, device = *qk_sim.shape[-3:], qk_sim.device

        # Reuse the cache only if it covers both the query (i) and key (j)
        # extents. Checking j alone returned a too-short bias whenever i grew
        # while j did not.
        if self.bias is not None and self.bias.shape[-2] >= i and self.bias.shape[-1] >= j:
            return self.bias[..., :i, :j]

        bias = self.get_bias(i, j, device)
        bias = bias * self.slopes

        # heads beyond the number of slopes (if any) receive zero bias
        num_heads_unalibied = h - bias.shape[0]
        bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
        self.register_buffer('bias', bias, persistent=False)
        return bias
# residual

class Residual(nn.Module):
    """Add the wrapped module's output back onto its input: ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        branch = self.fn(x)
        return branch + x
# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
# https://arxiv.org/abs/2002.05202

class SwiGLU(nn.Module):
    """Gated activation: split the last dim into (value, gate), return ``silu(gate) * value``."""

    def forward(self, x):
        value, gate = x.chunk(2, dim=-1)
        gated = F.silu(gate)
        return gated * value
# parallel attention and feedforward with residual
# discovered by Wang et al + EleutherAI from GPT-J fame

class ParallelTransformerBlock(nn.Module):
    """PaLM-lite block: attention and feedforward computed in parallel from one
    pre-normalized input.

    Attention is multi-query: ``heads`` query heads share a single
    ``dim_head``-wide projection that serves as both key and value. Positions
    are encoded by an ALiBi bias added to the attention scores.
    """

    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.norm = RMSNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        # one fused projection yields: all query heads | shared key/value | feedforward (value + gate)
        self.fused_dims = (attn_inner_dim, dim_head, (ff_inner_dim * 2))

        self.heads = heads
        self.scale = dim_head**-0.5

        self.alibi_pos_biases = AlibiPositionalBias(heads = self.heads)

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )

        # cached causal mask (non-persistent buffer)
        self.register_buffer("mask", None, persistent=False)

    def get_mask(self, n, device):
        """Boolean (n, n) mask, True strictly above the diagonal; cached and sliced."""
        cached = self.mask
        if cached is not None and cached.shape[-1] >= n:
            return cached[:n, :n]
        mask = torch.triu(torch.ones((n, n), device=device, dtype=torch.bool), 1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def forward(self, x):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        seq_len, device = x.shape[1], x.device

        normed = self.norm(x)
        # queries, the shared key/value (sharing k and v is the author's own
        # finding) and the feedforward input
        q, kv, ff = self.fused_attn_ff_proj(normed).split(self.fused_dims, dim=-1)

        # multi-query attention, another Noam Shazeer paper:
        # https://arxiv.org/abs/1911.02150
        q = rearrange(q, "b n (h d) -> b h n d", h = self.heads)
        q = q * self.scale

        sim = einsum("b h i d, b j d -> b h i j", q, kv)
        sim = sim + self.alibi_pos_biases(sim)

        future = self.get_mask(seq_len, device)
        sim = sim.masked_fill(future, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim=-1)
        out = einsum("b h i j, b j d -> b h i d", attn, kv)
        out = rearrange(out, "b h n d -> b n (h d)")

        return self.attn_out(out) + self.ff_out(ff)
# PaLM-lite model factory
def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
    """Build the PaLM-lite language model as an ``nn.Sequential``.

    Layout: token embedding -> ``depth`` residual ParallelTransformerBlocks ->
    RMSNorm -> linear projection to ``num_tokens`` logits. The output
    projection shares its weight with the embedding, which is initialised
    from a normal distribution with std 0.02.
    """
    # modules are created in network order (keeps parameter init order unchanged)
    embedding = nn.Embedding(num_tokens, dim)
    layers = [
        Residual(ParallelTransformerBlock(dim, dim_head, heads, ff_mult))
        for _ in range(depth)
    ]
    net = nn.Sequential(
        embedding,
        *layers,
        RMSNorm(dim),
        nn.Linear(dim, num_tokens, bias=False),
    )

    # weight tying: the output projection reuses the embedding matrix
    net[-1].weight = embedding.weight
    nn.init.normal_(embedding.weight, std=0.02)
    return net
if __name__ == "__main__":
    # Smoke test: a one-layer model run on a random 2048-token sequence.
    palm = PaLM(num_tokens = 20000, dim = 512, depth = 1, heads = 8, dim_head = 64)

    tokens = torch.randint(0, 20000, (1, 2048))
    logits = palm(tokens)  # (1, 2048, 20000)

    trainable = (p.numel() for p in palm.parameters() if p.requires_grad)
    n_params_torch = sum(trainable)
    print(f"Number of parameters in torch model: {n_params_torch}")
.\lucidrains\PaLM-pytorch\palm_pytorch\palm_pytorch.py
# 导入 torch 库
import torch
# 导入 torch.nn.functional 模块,并重命名为 F
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 函数
from einops import rearrange
# 从 torch 库中导入 einsum 和 nn 模块
from torch import einsum, nn
# normalization
# they use layernorm without bias, something older pytorch versions do not offer

class LayerNorm(nn.Module):
    """Layer norm over the last dim with a learned gain and a fixed all-zero bias.

    ``beta`` is registered as a buffer, not a parameter, so it is never trained.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        normalized_shape = x.shape[-1:]
        return F.layer_norm(x, normalized_shape, self.gamma, self.beta)
# residual

class Residual(nn.Module):
    """Residual wrapper computing ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        out = self.fn(x)
        out = out + x
        return out
# rotary positional embedding
# https://arxiv.org/abs/2104.09864

class RotaryEmbedding(nn.Module):
    """Rotary position angles: ``position * inv_freq``, duplicated to ``dim`` columns."""

    def __init__(self, dim):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (10000 ** exponents)
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        """Return a (max_seq_len, dim) tensor of rotation angles."""
        positions = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        # outer product: angle[p, k] = position_p * inv_freq_k
        freqs = positions[:, None] * self.inv_freq[None, :]
        return torch.cat((freqs, freqs), dim=-1)
# rotary embedding helpers

def rotate_half(x):
    """Split the last dim into two equal halves (x1, x2) and return ``cat(-x2, x1)``."""
    halves = x.reshape(*x.shape[:-1], 2, -1)
    x1, x2 = halves.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(pos, t):
    """Rotate ``t`` by the angles in ``pos``: ``t * cos(pos) + rotate_half(t) * sin(pos)``."""
    cos, sin = pos.cos(), pos.sin()
    return (t * cos) + (rotate_half(t) * sin)
# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
# https://arxiv.org/abs/2002.05202

class SwiGLU(nn.Module):
    """SwiGLU gate: first half of the last dim times SiLU of the second half."""

    def forward(self, x):
        halves = x.chunk(2, dim=-1)
        return F.silu(halves[1]) * halves[0]
# parallel attention and feedforward with residual
# discovered by Wang et al + EleutherAI from GPT-J fame

class ParallelTransformerBlock(nn.Module):
    """PaLM block: attention and feedforward computed in parallel from one
    pre-normalized input.

    Queries have ``heads`` heads; keys and values are single ``dim_head``-wide
    heads shared by all queries (multi-query attention). Rotary embeddings are
    applied to queries and keys.
    """

    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.norm = LayerNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        # fused projection layout: queries | key | value | feedforward (value + gate)
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

        self.heads = heads
        self.scale = dim_head**-0.5
        self.rotary_emb = RotaryEmbedding(dim_head)

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )

        # for caching causal mask and rotary embeddings
        self.register_buffer("mask", None, persistent=False)
        self.register_buffer("pos_emb", None, persistent=False)

    def get_mask(self, n, device):
        """Boolean (n, n) mask, True strictly above the diagonal; cached and sliced."""
        cached = self.mask
        if cached is not None and cached.shape[-1] >= n:
            return cached[:n, :n]
        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def get_rotary_embedding(self, n, device):
        """Rotary angles for the first ``n`` positions; cached and sliced."""
        cached = self.pos_emb
        if cached is not None and cached.shape[-2] >= n:
            return cached[:n]
        pos_emb = self.rotary_emb(n, device=device)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb

    def forward(self, x):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        seq_len, device = x.shape[1], x.device

        normed = self.norm(x)
        q, k, v, ff = self.fused_attn_ff_proj(normed).split(self.fused_dims, dim=-1)

        # only the queries are split into heads; k and v stay single-headed
        q = rearrange(q, "b n (h d) -> b h n d", h=self.heads)

        rotary = self.get_rotary_embedding(seq_len, device)
        q = apply_rotary_pos_emb(rotary, q)
        k = apply_rotary_pos_emb(rotary, k)
        q = q * self.scale

        sim = einsum("b h i d, b j d -> b h i j", q, k)
        future = self.get_mask(seq_len, device)
        sim = sim.masked_fill(future, -torch.finfo(sim.dtype).max)
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.attn_out(out) + self.ff_out(ff)
# PaLM model factory
def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
    """Build the PaLM language model as an ``nn.Sequential``.

    Layout: token embedding -> ``depth`` residual ParallelTransformerBlocks ->
    bias-free LayerNorm -> linear projection to ``num_tokens`` logits. The
    output projection is tied to the embedding, which is initialised from a
    normal distribution with std 0.02.
    """
    # created in network order so parameter initialisation order is unchanged
    embedding = nn.Embedding(num_tokens, dim)
    layers = [
        Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))
        for _ in range(depth)
    ]
    net = nn.Sequential(
        embedding,
        *layers,
        LayerNorm(dim),
        nn.Linear(dim, num_tokens, bias=False),
    )

    # weight tying between the output projection and the embedding
    net[-1].weight = embedding.weight
    nn.init.normal_(embedding.weight, std=0.02)
    return net
.\lucidrains\PaLM-pytorch\palm_pytorch\triton\layernorm.py
# 从 Phil Tillet 的 Triton 的 layernorm 教程中获取的代码
# Triton - https://triton-lang.org
# Layernorm 教程 - https://triton-lang.org/master/getting-started/tutorials/05-layer-norm.html#sphx-glr-getting-started-tutorials-05-layer-norm-py
# 修改为无偏置
# 导入必要的库
import torch
import triton
import triton.language as tl
# Forward Triton kernel: each program instance normalizes one row of X.
@triton.jit
def _layer_norm_fwd_fused(X, Y, W, M, V, stride, N,
                          BLOCK_SIZE: tl.constexpr):
    # Computes y = (x - mean) * rstd * w for one row (no bias term) and stores
    # the row mean in M and the reciprocal std in V for the backward pass.
    # The whole row must fit in one BLOCK_SIZE tile (N <= BLOCK_SIZE).
    # NOTE(review): Y is addressed with X's row stride — assumes both share a layout.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    # move the pointers to the start of this row
    X += row * stride
    Y += row * stride
    # compute statistics in float32
    x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
    mean = tl.sum(x, axis=0) / N
    # zero the padding lanes so they do not contribute to the variance
    xmean = tl.where(mask, x - mean, 0.)
    var = tl.sum(xmean * xmean, axis=0) / N
    rstd = 1 / tl.sqrt(var + 1e-5)
    xhat = xmean * rstd
    # save per-row statistics for backward
    tl.store(M + row, mean)
    tl.store(V + row, rstd)
    # scale by the weight and write the output row
    w = tl.load(W + cols, mask=mask)
    y = xhat * w
    tl.store(Y + cols, y, mask=mask)
# Backward Triton kernel (part 1): computes dx for one row, and adds that row's
# contribution to dw into a shared buffer of partial sums.
@triton.jit
def _layer_norm_bwd_dx_fused(DX, DY, DW, X, W, M, V, Lock, stride, N,
                             GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
    # one program per row; the row must fit in one BLOCK_SIZE_N tile
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE_N)
    mask = cols < N
    # move the pointers to this row (X, DY and DX are addressed with the same row stride)
    X += row * stride
    DY += row * stride
    DX += row * stride
    # Rows are spread over GROUP_SIZE_M rows of partial dw. Lock[lock_id] is a
    # spin lock guarding DW row `lock_id`; Count (the second GROUP_SIZE_M ints
    # of the same buffer) flags whether that DW row already holds data.
    lock_id = row % GROUP_SIZE_M
    Lock += lock_id
    Count = Lock + GROUP_SIZE_M
    DW = DW + lock_id * N + cols
    # load the data needed for the gradients, in float32
    x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
    dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    # row statistics saved by the forward kernel
    mean = tl.load(M + row)
    rstd = tl.load(V + row)
    # dx = rstd * (w*dy - (xhat * mean(xhat * w*dy) + mean(w*dy)))
    xhat = (x - mean) * rstd
    wdy = w * dy
    xhat = tl.where(mask, xhat, 0.)
    wdy = tl.where(mask, wdy, 0.)
    mean1 = tl.sum(xhat * wdy, axis=0) / N
    mean2 = tl.sum(wdy, axis=0) / N
    dx = (wdy - (xhat * mean1 + mean2)) * rstd
    tl.store(DX + cols, dx, mask=mask)
    # this row's contribution to dw (no bias, so there is no db)
    partial_dw = (dy * xhat).to(w.dtype)
    # acquire the lock for DW row `lock_id`: spin until the 0 -> 1 CAS succeeds
    while tl.atomic_cas(Lock, 0, 1) == 1:
        pass
    count = tl.load(Count)
    # the first writer stores its values; later writers accumulate onto them
    if count == 0:
        tl.atomic_xchg(Count, 1)
    else:
        partial_dw += tl.load(DW, mask=mask)
    tl.store(DW, partial_dw, mask=mask)
    # release the lock
    tl.atomic_xchg(Lock, 0)
# Backward Triton kernel (part 2): reduces the partial dw rows into the final dw.
@triton.jit
def _layer_norm_bwd_dw(DW, FINAL_DW, M, N,
                       BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
    # DW is an (M, N) buffer of partial sums. Here M is the number of partial
    # rows (the launcher passes GROUP_SIZE_M), not the number of input rows.
    # Each program sums BLOCK_SIZE_N columns over all M rows.
    pid = tl.program_id(0)
    cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # walk down the rows of DW in tiles of BLOCK_SIZE_M, accumulating in float32
    for i in range(0, M, BLOCK_SIZE_M):
        rows = i + tl.arange(0, BLOCK_SIZE_M)
        mask = (rows[:, None] < M) & (cols[None, :] < N)
        offs = rows[:, None] * N + cols[None, :]
        dw += tl.load(DW + offs, mask=mask, other=0.)
    # collapse the tile to one value per column and write it out
    sum_dw = tl.sum(dw, axis=0)
    tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
# Bias-free layer norm as an autograd Function backed by the Triton kernels above.
class LayerNorm(torch.autograd.Function):
    """y = (x - mean) / sqrt(var + 1e-5) * weight over the last dimension (no bias).

    ``normalized_shape`` is accepted for parity with ``F.layer_norm`` but is not
    used: normalization is always over the last dimension.
    """

    @staticmethod
    def forward(ctx, x, normalized_shape, weight):
        # The kernels address rows as base + row * stride with unit column
        # stride. Work on contiguous data; otherwise ``y = empty_like(x)``
        # could inherit a layout that differs from the flattened ``x_arg``.
        x = x.contiguous()
        weight = weight.contiguous()
        y = torch.empty_like(x)
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape
        # Per-row statistics for backward. They are allocated on x's device;
        # device='cuda' would mean the *current* device, which can differ on
        # multi-GPU hosts.
        mean = torch.empty((M, ), dtype=torch.float32, device=x.device)
        rstd = torch.empty((M, ), dtype=torch.float32, device=x.device)
        # the whole row must fit in one fused block (< 64KB per feature row)
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
        # heuristic for the number of warps
        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
        _layer_norm_fwd_fused[(M,)](x_arg, y, weight, mean, rstd,
                                    x_arg.stride(0), N,
                                    BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
        ctx.save_for_backward(x, weight, mean, rstd)
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        return y

    @staticmethod
    def backward(ctx, dy):
        x, w, m, v = ctx.saved_tensors
        # the dx kernel indexes dy with x's row stride, so it must be contiguous
        dy = dy.contiguous()
        N = w.shape[0]
        # number of partial-dw rows (and locks): more rows for narrower features
        GROUP_SIZE_M = 64
        if N <= 8192: GROUP_SIZE_M = 96
        if N <= 4096: GROUP_SIZE_M = 128
        if N <= 1024: GROUP_SIZE_M = 256
        # first GROUP_SIZE_M ints: spin locks; second GROUP_SIZE_M ints: "row written" flags
        locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device=w.device)
        # partial dw sums, reduced into dw by the second kernel
        _dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
        dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
        dx = torch.empty_like(dy)
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape
        _layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, x, w, m, v, locks,
                                       x_arg.stride(0), N,
                                       BLOCK_SIZE_N=ctx.BLOCK_SIZE,
                                       GROUP_SIZE_M=GROUP_SIZE_M,
                                       num_warps=ctx.num_warps)
        grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
        _layer_norm_bwd_dw[grid](_dw, dw, GROUP_SIZE_M, N,
                                 BLOCK_SIZE_M=32,
                                 BLOCK_SIZE_N=128)
        # one gradient per forward input: x, normalized_shape (not a tensor), weight
        return dx, None, dw


# functional entry point: layernorm_without_bias(x, normalized_shape, weight)
layernorm_without_bias = LayerNorm.apply
.\lucidrains\PaLM-pytorch\palm_pytorch\triton\palm.py
# 导入所需的库
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import einsum, nn
# 导入自定义的模块
from palm_pytorch.triton.softmax import causal_softmax
from palm_pytorch.triton.layernorm import layernorm_without_bias
# normalization

class LayerNorm(nn.Module):
    """Bias-free layer norm over the last dim, computed by the Triton kernel."""

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        feature_shape = x.shape[-1:]
        return layernorm_without_bias(x, feature_shape, self.gamma)
# residual

class Residual(nn.Module):
    """Skip connection around ``fn``: returns ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        transformed = self.fn(x)
        return transformed + x
# rotary positional embedding

class RotaryEmbedding(nn.Module):
    """Rotation angles ``position * inv_freq`` for rotary embeddings, repeated twice along the last dim."""

    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        """Return a (max_seq_len, dim) tensor of angles."""
        seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.outer(seq, self.inv_freq)
        return torch.cat((angles, angles), dim=-1)
# rotary helpers

def rotate_half(x):
    """Return ``cat(-second_half, first_half)`` along the last dim (which must be even)."""
    first, second = x.reshape(*x.shape[:-1], 2, -1).unbind(dim=-2)
    return torch.cat((-second, first), dim=-1)


def apply_rotary_pos_emb(pos, t):
    """Apply the rotary rotation with angles ``pos`` to ``t``."""
    rotated = rotate_half(t) * pos.sin()
    return (t * pos.cos()) + rotated
# feedforward

class SwiGLU(nn.Module):
    """Swish-gated linear unit: ``silu(gate) * x`` with ``x, gate`` the two halves of the last dim."""

    def forward(self, x):
        value, gate = x.chunk(2, dim=-1)
        return value * F.silu(gate) if False else F.silu(gate) * value
# parallel attention and feedforward with residual
# discovered by Wang et al + EleutherAI from GPT-J fame

class ParallelTransformerBlock(nn.Module):
    """PaLM block whose attention softmax and layer norm run on Triton kernels.

    Attention and feedforward are computed in parallel from one pre-normalized
    input. Queries have ``heads`` heads, while keys and values are single
    shared ``dim_head``-wide heads. Rotary embeddings are applied to q and k.
    """

    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.norm = LayerNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        # fused projection layout: queries | key | value | feedforward (value + gate)
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

        self.heads = heads
        self.scale = dim_head**-0.5
        self.rotary_emb = RotaryEmbedding(dim_head)

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
        # this nn.Sequential was missing its closing parenthesis (SyntaxError)
        self.ff_out = nn.Sequential(SwiGLU(), nn.Linear(ff_inner_dim, dim, bias=False))

        # for caching of rotary embeddings
        self.register_buffer("pos_emb", None, persistent=False)

    def get_rotary_embedding(self, n, device):
        """Rotary angles for the first ``n`` positions; cached and sliced."""
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]
        pos_emb = self.rotary_emb(n, device=device)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb

    def forward(self, x):
        n, device, h = x.shape[1], x.device, self.heads

        # pre layernorm
        x = self.norm(x)

        # attention queries, keys, values, and feedforward inner
        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # split heads (queries only; k and v are shared single heads)
        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        # rotary embeddings
        positions = self.get_rotary_embedding(n, device)
        q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))

        # scale
        q = q * self.scale

        # similarity
        sim = einsum("b h i d, b j d -> b h i j", q, k)

        # attention: no explicit mask here — presumably causal_softmax applies
        # the causal mask itself (it is defined elsewhere; confirm)
        attn = causal_softmax(sim)

        # aggregate values
        out = einsum("b h i j, b j d -> b h i d", attn, v)

        # merge heads
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.attn_out(out) + self.ff_out(ff)
# transformer

def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
    """Assemble the PaLM decoder as an ``nn.Sequential``.

    The pipeline is:

    1. token embedding;
    2. ``depth`` residual ``ParallelTransformerBlock`` layers;
    3. final ``LayerNorm``;
    4. bias-free projection to ``num_tokens`` logits.

    The logit projection shares its weight with the embedding table. That
    shared matrix is initialised from a normal distribution with std 0.02.
    """
    token_emb = nn.Embedding(num_tokens, dim)
    layers = [
        Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))
        for _ in range(depth)
    ]
    final_norm = LayerNorm(dim)
    to_logits = nn.Linear(dim, num_tokens, bias=False)

    net = nn.Sequential(token_emb, *layers, final_norm, to_logits)

    # weight tying: Linear(dim, num_tokens).weight is (num_tokens, dim), the embedding table's shape
    to_logits.weight = token_emb.weight
    nn.init.normal_(token_emb.weight, std=0.02)

    return net
.\lucidrains\PaLM-pytorch\palm_pytorch\triton\softmax.py
# 导入 torch 库
import torch
# 从 torch 库中导入 autograd 模块
from torch import autograd
# 从 torch.nn.functional 模块中导入 F 函数
import torch.nn.functional as F
# 导入 triton 库
import triton
# 从 triton.language 模块中导入 tl
import triton.language as tl
# 从 triton_transformer.utils 模块中导入 calc_num_warps 函数
from triton_transformer.utils import calc_num_warps
# Triton kernel: causal softmax of a single row per program instance.
# Launched by `_softmax.forward` on a 1-D grid of `n_rows` programs. Program `row_idx`
# reads input row `row_idx`, masks out columns > (row_idx % n_cols), and writes the
# softmax of the row to the matching output row.
# BLOCK_SIZE is the compile-time power of two >= n_cols chosen by the caller, so a
# whole row is processed in one block.
@triton.jit
def softmax_kernel_forward(
    output_ptr,
    input_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr
):
    # each program instance handles one row
    row_idx = tl.program_id(0)
    # pointer to the first element of this row
    row_start_ptr = input_ptr + row_idx * input_row_stride
    # one lane per column; columns are addressed with unit stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    # lanes beyond the end of the row are padding
    mask = col_offsets < n_cols
    # padding lanes read -inf, so they contribute exp(-inf) = 0 to the sum
    row = tl.load(input_ptrs, mask = mask, other = -float('inf'))
    # causal mask: row_idx % n_cols is taken as the query position inside its attention
    # matrix, which assumes each matrix has exactly n_cols rows (square). Column 0 is never
    # masked, so the row max below stays finite.
    causal_mask = col_offsets > (row_idx % n_cols)
    row = row + tl.where(causal_mask, -float('inf'), 0.)
    # subtract the row max for numerical stability
    row_minus_max = row - tl.max(row, axis=0)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    # write the result to the corresponding output row, skipping padding lanes
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask = mask)
# Triton kernel: softmax backward for a single row per program instance.
# Given saved probabilities p (input_ptr) and upstream gradient g (grad_ptr), it writes
#     dx = p * g - p * sum(p * g)
# which is the softmax Jacobian-vector product. Masked (causal) positions have p = 0,
# so their gradient is 0. Launched by `_softmax.backward` on a 1-D grid of `n_rows`
# programs with BLOCK_SIZE a power of two >= n_cols.
@triton.jit
def softmax_kernel_backward(
    output_ptr,
    input_ptr,
    grad_ptr,
    grad_row_stride,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr
):
    # each program instance handles one row
    row_idx = tl.program_id(0)
    # start of this row in the probabilities and in the gradient
    row_start_ptr = input_ptr + row_idx * input_row_stride
    grad_row_start_ptr = grad_ptr + row_idx * grad_row_stride
    # one lane per column; columns are addressed with unit stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    grad_ptrs = grad_row_start_ptr + col_offsets
    # lanes beyond the end of the row are padding and read as 0
    mask = col_offsets < n_cols
    probs_row = tl.load(input_ptrs, mask = mask, other = 0.)
    grad_row = tl.load(grad_ptrs, mask = mask, other = 0.)
    # elementwise p * g
    dxhat = probs_row * grad_row
    # p * g - p * <p, g>
    softmax_grad_output = dxhat - probs_row * tl.sum(dxhat, axis = 0)
    # write the gradient row, skipping padding lanes
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_grad_output, mask = mask)
# autograd wrapper around the Triton causal-softmax kernels

class _softmax(autograd.Function):
    """Causal softmax over the last axis, computed by Triton kernels.

    The input is viewed as a 2-D ``(rows, n_cols)`` matrix over its last axis.
    Row ``r`` keeps only columns ``<= r % n_cols`` (see
    ``softmax_kernel_forward``), which is a causal mask when the input stacks
    square ``(n, n)`` score matrices.
    """

    @staticmethod
    def forward(ctx, x):
        shape = x.shape
        # decide before reshaping, so a copy made by `.contiguous()` cannot affect it
        needs_grad = x.requires_grad
        # The kernels address columns as `row_ptr + col_offsets`, i.e. they require unit
        # stride on the last axis. `.view` alone can succeed on tensors that violate this,
        # which would give silently wrong results. `.contiguous()` is a no-op otherwise.
        x = x.contiguous().view(-1, shape[-1])
        n_rows, n_cols = x.shape

        # one power-of-two block covers a whole row
        BLOCK_SIZE = triton.next_power_of_2(n_cols)
        num_warps = calc_num_warps(BLOCK_SIZE)

        y = torch.empty_like(x)

        softmax_kernel_forward[(n_rows,)](
            y,
            x,
            x.stride(0),
            y.stride(0),
            n_cols,
            num_warps = num_warps,
            BLOCK_SIZE = BLOCK_SIZE,
        )

        # backward only needs the softmax output
        if needs_grad:
            ctx.save_for_backward(y)
        return y.view(*shape)

    @staticmethod
    def backward(ctx, grad_probs):
        shape = grad_probs.shape
        probs, = ctx.saved_tensors

        # incoming gradients can be non-contiguous (e.g. expanded or transposed),
        # but the kernel requires unit stride on the last axis
        grad_probs = grad_probs.contiguous().view(-1, shape[-1])
        n_rows, n_cols = grad_probs.shape

        BLOCK_SIZE = triton.next_power_of_2(n_cols)
        num_warps = calc_num_warps(BLOCK_SIZE)

        dx = torch.empty_like(probs)

        softmax_kernel_backward[(n_rows,)](
            dx,
            probs,
            grad_probs,
            grad_probs.stride(0),
            probs.stride(0),
            dx.stride(0),
            n_cols,
            num_warps = num_warps,
            BLOCK_SIZE = BLOCK_SIZE
        )

        # forward has a single tensor input, hence a single gradient
        return dx.view(*shape)

# functional entry point used by the model
causal_softmax = _softmax.apply
.\lucidrains\PaLM-pytorch\palm_pytorch\triton\__init__.py
# 从 palm_pytorch.triton.palm 模块中导入 PaLM 类
from palm_pytorch.triton.palm import PaLM
.\lucidrains\PaLM-pytorch\palm_pytorch\__init__.py
# 从 palm_pytorch 模块中导入 PaLM 类
from palm_pytorch.palm_pytorch import PaLM

PaLM - Pytorch
Implementation of the specific Transformer architecture from PaLM: Scaling Language Modeling with Pathways, in fewer than 200 lines of code.
This model is essentially state of the art on every language task. See Yannic Kilcher's explanation for an overview.
It obviously will not scale; it is intended purely for educational purposes, to show the public how simple it all really is.
Install
$ pip install PaLM-pytorch
Usage
import torch
from palm_pytorch import PaLM
palm = PaLM(
num_tokens = 20000,
dim = 512,
depth = 12,
heads = 8,
dim_head = 64,
)
tokens = torch.randint(0, 20000, (1, 2048))
logits = palm(tokens) # (1, 2048, 20000)
The 540B-parameter PaLM from the paper would be:
palm = PaLM(
num_tokens = 256000,
dim = 18432,
depth = 118,
heads = 48,
dim_head = 256
)
Test on Enwik8
$ python train.py
Citations
@inproceedings{Chowdhery2022PaLMSL,
title = {PaLM: Scaling Language Modeling with Pathways},
author = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
year = {2022}
}
@article{Tillet2019TritonAI,
title = {Triton: an intermediate language and compiler for tiled neural network computations},
author = {Philippe Tillet and H. T. Kung and David D. Cox},
journal = {Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages},
year = {2019}
}
.\lucidrains\PaLM-pytorch\setup.py
# 导入必要的模块
from setuptools import find_packages, setup
# Package metadata for PaLM-pytorch, collected in one mapping and passed to setuptools.
_SETUP_KWARGS = {
    "name": "PaLM-pytorch",
    "packages": find_packages(exclude=[]),
    "version": "0.2.2",
    "license": "MIT",
    "description": "PaLM: Scaling Language Modeling with Pathways - Pytorch",
    "author": "Phil Wang",
    "author_email": "lucidrains@gmail.com",
    "long_description_content_type": 'text/markdown',
    "url": "https://github.com/lucidrains/PaLM-pytorch",
    "keywords": [
        "artificial general intelligence",
        "deep learning",
        "transformers",
        "attention mechanism",
    ],
    # runtime dependencies
    "install_requires": [
        "einops>=0.4",
        "torch>=1.6",
        "triton>=2.0dev",
    ],
    "classifiers": [
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.6",
    ],
}

setup(**_SETUP_KWARGS)
.\lucidrains\PaLM-pytorch\train.py
# 导入所需的库
import gzip
import random
import numpy as np
import torch
import torch.optim as optim
import tqdm
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# 导入自定义的类和函数
from palm_pytorch.triton import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
# training hyperparameters

NUM_BATCHES = 100_000            # optimizer steps (outer training iterations)
BATCH_SIZE = 4                   # sequences per micro-batch
GRADIENT_ACCUMULATE_EVERY = 4    # micro-batches per optimizer step
LEARNING_RATE = 2e-4

VALIDATE_EVERY = 100             # steps between validation losses
GENERATE_EVERY = 500             # steps between sample generations
GENERATE_LENGTH = 512            # tokens to sample per generation

SEQ_LEN = 1024                   # training context length in tokens
# helpers

def cycle(loader):
    """Yield items from ``loader`` forever, restarting iteration each time it is exhausted."""
    while True:
        yield from loader
def decode_token(token):
    """Map a byte token to a character; codes below 32 (control characters) become a space."""
    code = token if token > 32 else 32
    return chr(code)
def decode_tokens(tokens):
    """Render a sequence of byte tokens as a string, one character per token."""
    return "".join(decode_token(token) for token in tokens)
# Build a 256-token (byte-level) PaLM, wrap it with AutoregressiveWrapper
# (which provides the training loss and `generate`), and move it to the GPU.
model = AutoregressiveWrapper(
    PaLM(num_tokens=256, dim=512, depth=8),
    max_seq_len=SEQ_LEN,
)
model.cuda()
# Prepare enwik8. Read the first 95M bytes; the first 90M are the training split
# and the remainder of that 95M slice is the validation split.
with gzip.open("./examples/enwik8_deepspeed/data/enwik8.gz") as file:
    # Binary-mode np.fromstring is deprecated; np.frombuffer is its replacement.
    # frombuffer returns a read-only view of the bytes, so copy it into a writable
    # array that torch.from_numpy can wrap without a non-writable-array warning.
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
# dataset of random token windows

class TextSamplerDataset(Dataset):
    """Random fixed-length windows over a token tensor (indexed along dim 0).

    ``__getitem__`` ignores ``index``. Every access draws a fresh random start
    offset, slices ``seq_len + 1`` consecutive tokens, casts them to int64 and
    moves them to the current CUDA device. ``__len__`` is the number of whole
    ``seq_len`` windows that fit in ``data``.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        self.data, self.seq_len = data, seq_len

    def __getitem__(self, index):
        total = self.data.shape[0]
        start = torch.randint(0, total - self.seq_len, (1,))
        window = self.data[start : start + self.seq_len + 1]
        return window.long().cuda()

    def __len__(self):
        return self.data.shape[0] // self.seq_len
# Endless iterators over random windows of the training / validation splits.
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)

train_loader, val_loader = (
    cycle(DataLoader(dataset, batch_size=BATCH_SIZE))
    for dataset in (train_dataset, val_dataset)
)

# Adam over all model parameters (note: this name shadows the `optim` module import).
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        # Average over micro-batches so the accumulated gradient is that of one
        # batch of BATCH_SIZE * GRADIENT_ACCUMULATE_EVERY samples, not their sum.
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    # reports the last micro-batch's (unscaled) loss
    print(f"training loss: {loss.item()}")
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f"validation loss: {loss.item()}")

    if i % GENERATE_EVERY == 0:
        model.eval()
        # prime with a random validation window, dropping its final token
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print("%s \n\n %s" % (prime, "*" * 100))

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str)
Data source
The enwik8 data was downloaded from the Hutter Prize page: http://prize.hutter1.net/
.\lucidrains\PaLM-rlhf-pytorch\palm_rlhf_pytorch\attention.py
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 torch 库中导入 nn.functional 模块,并重命名为 F
import torch.nn.functional as F
# 从 collections 模块中导入 namedtuple 类
from collections import namedtuple
# 从 functools 模块中导入 wraps 函数
from functools import wraps
# 从 packaging 模块中导入 version 类
from packaging import version
# 从 einops 库中导入 rearrange 函数
from einops import rearrange
# Flags handed to torch.backends.cuda.sdp_kernel, selecting which scaled-dot-product
# attention backends (flash / math / memory-efficient) may be used.
Config = namedtuple(
    'EfficientAttentionConfig',
    'enable_flash enable_math enable_mem_efficient',
)
# helpers

def exists(val):
    """Return True for any value other than None (falsy values such as 0 count as existing)."""
    return not (val is None)
# decorator: run the wrapped function only on its first call

def once(fn):
    """Wrap ``fn`` so that only its first call executes.

    The first call forwards all positional and keyword arguments and returns
    ``fn``'s result. Every later call does nothing and returns ``None``.
    """
    called = False

    @wraps(fn)
    def inner(*args, **kwargs):
        nonlocal called
        if called:
            return
        called = True
        return fn(*args, **kwargs)

    return inner

# `print` that emits only the first message it receives (used for one-off GPU notices)
print_once = once(print)
# main class

class Attention(nn.Module):
    """Scaled dot-product attention with one key/value head shared by all query heads.

    Queries carry a head axis ``(b, h, n, d)``. Keys and values do not
    ``(b, j, d)`` and are broadcast across heads. With ``use_flash_attn`` the
    computation is routed through ``F.scaled_dot_product_attention``, which
    requires PyTorch >= 2.0.

    Args:
        dropout: dropout probability on the attention weights, applied only in training mode.
        causal: forbid attending to later positions.
        use_flash_attn: use ``F.scaled_dot_product_attention`` instead of the einsum path.
    """

    def __init__(
        self,
        dropout = 0.,
        causal = False,
        use_flash_attn = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        # cached boolean causal mask (True = masked out); excluded from state_dict
        self.register_buffer("mask", None, persistent=False)

        self.use_flash_attn = use_flash_attn
        assert not (use_flash_attn and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # backend-selection flags for torch.backends.cuda.sdp_kernel, per device type
        self.cpu_config = Config(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not use_flash_attn:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        # compute capability 8.0 (A100): allow only the flash backend
        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = Config(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = Config(False, True, True)

    def get_mask(self, n, device):
        """Return an ``(n, n)`` boolean mask, True strictly above the diagonal (cached and grown on demand)."""
        if exists(self.mask) and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def flash_attn(self, q, k, v, mask = None):
        """Attention via ``F.scaled_dot_product_attention``.

        ``mask``, if given, is a ``(b, j)`` boolean key-padding mask (True = keep).
        """
        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

        # broadcast the single key/value head across all query heads; `expand` (rather
        # than `expand_as(q)`) allows the key length to differ from the query length
        k = rearrange(k, 'b ... -> b 1 ...').expand(-1, heads, -1, -1)
        v = rearrange(v, 'b ... -> b 1 ...').expand(-1, heads, -1, -1)

        causal = self.causal

        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

            # scaled_dot_product_attention raises if both attn_mask and is_causal are set,
            # so fold the causal constraint into the boolean mask. The causal mask is
            # top-left aligned, matching is_causal's semantics.
            if causal:
                causal_mask = torch.ones((q_len, k_len), device = q.device, dtype = torch.bool).triu(1)
                mask = mask & ~causal_mask
                causal = False

        # pick backend flags according to where the inputs live
        config = self.cuda_config if is_cuda else self.cpu_config

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.,
                is_causal = causal
            )

        return out

    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        n, device = q.shape[-2], q.device

        # 1 / sqrt(d) scaling, the same default scaled_dot_product_attention uses
        scale = q.shape[-1] ** -0.5

        if self.use_flash_attn:
            return self.flash_attn(q, k, v, mask = mask)

        # similarity
        sim = einsum("b h i d, b j d -> b h i j", q, k) * scale

        # key padding mask (True = keep)
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # causal mask
        if self.causal:
            causal_mask = self.get_mask(n, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum("b h i j, b j d -> b h i d", attn, v)
        return out
.\lucidrains\PaLM-rlhf-pytorch\palm_rlhf_pytorch\lora.py
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# helpers

def exists(val):
    """True for any value other than None."""
    return not (val is None)

def default(val, d):
    """Return ``val`` unless it is None, in which case return the fallback ``d``."""
    if exists(val):
        return val
    return d
# LoRA - https://arxiv.org/abs/2106.09685

class LoRA(nn.Module):
    """Low-rank adapter: a ``dim x dim_out`` update factored as ``A @ B`` with rank ``r``.

    The product is scaled by ``alpha / r``, where ``alpha`` defaults to ``r``
    and so gives scale 1. ``A`` is Gaussian-initialised and ``B`` starts at
    zero, so the adapter initially contributes nothing.
    """

    def __init__(
        self,
        dim,
        dim_out,
        r = 8,
        alpha = None
    ):
        super().__init__()
        alpha = default(alpha, r)
        self.scale = alpha / r

        self.A = nn.Parameter(torch.randn(dim, r))
        self.B = nn.Parameter(torch.zeros(r, dim_out))

    @property
    def weight(self):
        """Dense ``(dim, dim_out)`` update matrix ``(A @ B) * scale``."""
        low_rank = torch.matmul(self.A, self.B)
        return low_rank * self.scale

    def forward(self, x):
        # project the last axis of x from dim to dim_out
        update = self.weight
        return x @ update