Lucidrains 系列项目源码解析(九十六)
.\lucidrains\stylegan2-pytorch\stylegan2_pytorch\version.py
# 定义变量 __version__,赋值为字符串 '1.8.9'
__version__ = '1.8.9'
.\lucidrains\stylegan2-pytorch\stylegan2_pytorch\__init__.py
# 从 stylegan2_pytorch.stylegan2_pytorch 模块中导入 Trainer, StyleGAN2, NanException, ModelLoader 类
from stylegan2_pytorch.stylegan2_pytorch import Trainer, StyleGAN2, NanException, ModelLoader

Tab Transformer
Implementation of Tab Transformer, an attention network for tabular data, in Pytorch. This simple architecture came within a hair's breadth of GBDT performance.
Update: Amazon AI claims to have beaten GBDT with Attention on a real-world tabular dataset (predicting shipping cost).
Install
$ pip install tab-transformer-pytorch
Usage
import torch
import torch.nn as nn
from tab_transformer_pytorch import TabTransformer
cont_mean_std = torch.randn(10, 2)
model = TabTransformer(
categories = (10, 5, 6, 5, 8), # tuple containing the number of unique values within each category
num_continuous = 10, # number of continuous values
dim = 32, # dimension, paper set at 32
dim_out = 1, # binary prediction, but could be anything
depth = 6, # depth, paper recommended 6
heads = 8, # heads, paper recommends 8
attn_dropout = 0.1, # post-attention dropout
ff_dropout = 0.1, # feed forward dropout
mlp_hidden_mults = (4, 2), # relative multiples of each hidden dimension of the last mlp to logits
mlp_act = nn.ReLU(), # activation for final mlp, defaults to relu, but could be anything else (selu etc)
continuous_mean_std = cont_mean_std # (optional) - normalize the continuous values before layer norm
)
x_categ = torch.randint(0, 5, (1, 5)) # category values, from 0 - max number of categories, in the order as passed into the constructor above
x_cont = torch.randn(1, 10) # assume continuous values are already normalized individually
pred = model(x_categ, x_cont) # (1, 1)
FT Transformer

This paper from Yandex improves on Tab Transformer by using a simpler scheme for embedding the continuous numerical values, as shown in the diagram in the original README (courtesy of a reddit post).
Included in this repository just for convenient comparison to Tab Transformer
import torch
from tab_transformer_pytorch import FTTransformer
model = FTTransformer(
categories = (10, 5, 6, 5, 8), # tuple containing the number of unique values within each category
num_continuous = 10, # number of continuous values
dim = 32, # dimension, paper set at 32
dim_out = 1, # binary prediction, but could be anything
depth = 6, # depth, paper recommended 6
heads = 8, # heads, paper recommends 8
attn_dropout = 0.1, # post-attention dropout
ff_dropout = 0.1 # feed forward dropout
)
x_categ = torch.randint(0, 5, (1, 5)) # category values, from 0 - max number of categories, in the order as passed into the constructor above
x_numer = torch.randn(1, 10) # numerical value
pred = model(x_categ, x_numer) # (1, 1)
Unsupervised Training
To perform the type of unsupervised training described in the paper, first convert your category tokens to the appropriate unique ids, and then train `model.transformer` with Electra.
Todo
- consider arxiv.org/abs/2203.05…
Citations
@misc{huang2020tabtransformer,
title = {TabTransformer: Tabular Data Modeling Using Contextual Embeddings},
author = {Xin Huang and Ashish Khetan and Milan Cvitkovic and Zohar Karnin},
year = {2020},
eprint = {2012.06678},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@article{Gorishniy2021RevisitingDL,
title = {Revisiting Deep Learning Models for Tabular Data},
author = {Yu. V. Gorishniy and Ivan Rubachev and Valentin Khrulkov and Artem Babenko},
journal = {ArXiv},
year = {2021},
volume = {abs/2106.11959}
}
.\lucidrains\tab-transformer-pytorch\setup.py
# 导入设置安装和查找包的函数
from setuptools import setup, find_packages
# Package metadata for the tab-transformer-pytorch distribution.
setup(
    name='tab-transformer-pytorch',
    version='0.3.0',
    license='MIT',
    description='Tab Transformer - Pytorch',
    long_description_content_type='text/markdown',
    author='Phil Wang',
    author_email='lucidrains@gmail.com',
    url='https://github.com/lucidrains/tab-transformer-pytorch',
    # Discover every package under the project root.
    packages=find_packages(),
    keywords=[
        'artificial intelligence',
        'transformers',
        'attention mechanism',
        'tabular data',
    ],
    # Runtime dependencies.
    install_requires=[
        'einops>=0.3',
        'torch>=1.6',
    ],
    # PyPI trove classifiers.
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\tab-transformer-pytorch\tab_transformer_pytorch\ft_transformer.py
# 导入 torch 库
import torch
# 导入 torch 中的函数库
import torch.nn.functional as F
# 从 torch 中导入 nn 和 einsum 模块
from torch import nn, einsum
# 从 einops 中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# feedforward and attention
# 定义 GEGLU 类,继承自 nn.Module
class GEGLU(nn.Module):
# 前向传播函数
def forward(self, x):
# 将输入 x 按照最后一个维度分成两部分
x, gates = x.chunk(2, dim = -1)
# 返回 x 乘以 gates 经过 gelu 激活函数的结果
return x * F.gelu(gates)
# Pre-normalized GEGLU feed-forward sub-layer
def FeedForward(dim, mult = 4, dropout = 0.):
    """Build ``LayerNorm -> Linear -> GEGLU -> Dropout -> Linear``.

    The first projection outputs ``2 * dim * mult`` features because GEGLU
    halves the last dimension; the second maps ``dim * mult`` back to ``dim``.

    Args:
        dim: input and output feature size.
        mult: hidden-width multiplier.
        dropout: dropout probability applied after the gating.

    Returns:
        An ``nn.Sequential`` mapping ``(..., dim) -> (..., dim)``.
    """
    hidden = dim * mult
    # Module order is fixed: it determines the state_dict keys (0.*, 1.*, 4.*).
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden * 2),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, dim),
    ]
    return nn.Sequential(*layers)
# Pre-normalized multi-head self-attention
class Attention(nn.Module):
    """Multi-head self-attention that layer-normalizes its own input.

    Args:
        dim: token feature size.
        heads: number of attention heads.
        dim_head: feature size of each head.
        dropout: dropout applied to the attention weights before mixing values.

    ``forward(x)`` returns ``(out, attn)``. ``out`` has the same last dimension
    as ``x``. ``attn`` is the post-softmax attention map *before* dropout,
    shaped ``(batch, heads, n, n)``.
    """

    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        # Creation order and attribute names are kept: they define the state_dict.
        self.norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        num_heads = self.heads
        normed = self.norm(x)

        projections = self.to_qkv(normed).chunk(3, dim = -1)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = num_heads) for t in projections)

        # Queries are scaled before the dot product.
        logits = einsum('b h i d, b h j d -> b h i j', q * self.scale, k)
        attn = logits.softmax(dim = -1)

        mixed = einsum('b h i j, b h j d -> b h i d', self.dropout(attn), v)
        merged = rearrange(mixed, 'b h n d -> b n (h d)', h = num_heads)
        return self.to_out(merged), attn
# transformer
# Stack of residual (Attention, FeedForward) layers
class Transformer(nn.Module):
    """``depth`` residual layers, each an Attention then a FeedForward sub-layer.

    Both sub-layers layer-normalize their own input internally, so the
    residual additions here give a pre-norm transformer.
    """

    def __init__(
        self,
        dim,
        depth,
        heads,
        dim_head,
        attn_dropout,
        ff_dropout
    ):
        super().__init__()
        # One ModuleList([attention, feedforward]) per layer, built in order.
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout),
                FeedForward(dim, dropout = ff_dropout),
            ])
            for _ in range(depth)
        ])

    def forward(self, x, return_attn = False):
        """Apply every layer.

        Returns ``x``, or ``(x, attns)`` when ``return_attn`` is true. ``attns``
        stacks each layer's post-softmax attention map along a new leading axis.
        """
        attn_maps = []
        for attn_layer, ff_layer in self.layers:
            attn_out, attn_map = attn_layer(x)
            attn_maps.append(attn_map)
            x = x + attn_out
            x = x + ff_layer(x)

        if return_attn:
            return x, torch.stack(attn_maps)
        return x
# numerical embedder
# Per-feature affine embedding of scalar numerical values
class NumericalEmbedder(nn.Module):
    """Embed each scalar numerical feature ``i`` as ``value * w_i + b_i``.

    Every feature has its own learned weight and bias vectors of size ``dim``,
    so a ``(batch, num_numerical_types)`` input becomes
    ``(batch, num_numerical_types, dim)``.
    """

    def __init__(self, dim, num_numerical_types):
        super().__init__()
        # Weights are created before biases (keeps seeded initialization unchanged).
        self.weights = nn.Parameter(torch.randn(num_numerical_types, dim))
        self.biases = nn.Parameter(torch.randn(num_numerical_types, dim))

    def forward(self, x):
        # (b, n) -> (b, n, 1) so each scalar broadcasts against its (n, dim) row.
        column = rearrange(x, 'b n -> b n 1')
        scaled = column * self.weights
        return scaled + self.biases
# main class
# FT-Transformer main class
class FTTransformer(nn.Module):
    """FT-Transformer for tabular data ("Revisiting Deep Learning Models for
    Tabular Data", Gorishniy et al. 2021).

    Categorical columns are embedded through one shared table, with each
    column's ids shifted by a per-column offset. Numerical columns are embedded
    by ``NumericalEmbedder``. A learned CLS token is prepended, the sequence is
    passed through a transformer, and the CLS output is mapped to logits via
    LayerNorm -> ReLU -> Linear.

    Args:
        categories: number of unique values of each categorical column
            (each must be positive).
        num_continuous: number of numerical columns.
        dim: token embedding size.
        depth: number of transformer layers.
        heads: attention heads per layer.
        dim_head: feature size of each head.
        dim_out: number of output logits.
        num_special_tokens: ids reserved at the start of the embedding table,
            before the first column's ids.
        attn_dropout: dropout on attention weights.
        ff_dropout: dropout inside the feed-forward sub-layers.
    """

    def __init__(
        self,
        *,
        categories,
        num_continuous,
        dim,
        depth,
        heads,
        dim_head = 16,
        dim_out = 1,
        num_special_tokens = 2,
        attn_dropout = 0.,
        ff_dropout = 0.
    ):
        super().__init__()
        assert all(n > 0 for n in categories), 'number of each category must be positive'
        assert len(categories) + num_continuous > 0, 'input shape must not be null'

        # --- categorical columns ---
        self.num_categories = len(categories)
        self.num_unique_categories = sum(categories)
        self.num_special_tokens = num_special_tokens
        total_tokens = self.num_unique_categories + num_special_tokens

        if self.num_unique_categories > 0:
            # Start offset of each column inside the shared table:
            # cumsum([num_special_tokens, c0, c1, ...]) with the final total dropped.
            starts = F.pad(torch.tensor(list(categories)), (1, 0), value = num_special_tokens)
            self.register_buffer('categories_offset', starts.cumsum(dim = -1)[:-1])
            self.categorical_embeds = nn.Embedding(total_tokens, dim)

        # --- numerical columns ---
        self.num_continuous = num_continuous
        if self.num_continuous > 0:
            self.numerical_embedder = NumericalEmbedder(dim, self.num_continuous)

        # Learned CLS token, prepended to every row's token sequence.
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            heads = heads,
            dim_head = dim_head,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout
        )

        # Output head from the paper: linear(relu(layernorm(cls))).
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.ReLU(),
            nn.Linear(dim, dim_out)
        )

    def forward(self, x_categ, x_numer, return_attn = False):
        """Compute logits for a batch of rows.

        Args:
            x_categ: integer category ids of shape ``(batch, num_categories)``.
                Ids are local to their column; offsets are added here.
            x_numer: numerical values of shape ``(batch, num_continuous)``.
            return_attn: also return the stacked per-layer attention maps.

        Returns:
            ``logits`` of shape ``(batch, dim_out)``, or ``(logits, attns)``.
        """
        assert x_categ.shape[-1] == self.num_categories, f'you must pass in {self.num_categories} values for your categories input'

        tokens = []
        if self.num_unique_categories > 0:
            tokens.append(self.categorical_embeds(x_categ + self.categories_offset))
        if self.num_continuous > 0:
            tokens.append(self.numerical_embedder(x_numer))

        # (b, n_categ + n_numer, dim), then prepend the CLS token.
        x = torch.cat(tokens, dim = 1)
        cls = repeat(self.cls_token, '1 1 d -> b 1 d', b = x.shape[0])
        x = torch.cat((cls, x), dim = 1)

        x, attns = self.transformer(x, return_attn = True)
        logits = self.to_logits(x[:, 0])

        return (logits, attns) if return_attn else logits
.\lucidrains\tab-transformer-pytorch\tab_transformer_pytorch\tab_transformer_pytorch.py
# 导入 PyTorch 库
import torch
import torch.nn.functional as F
from torch import nn, einsum
# 导入 einops 库中的 rearrange 和 repeat 函数
from einops import rearrange, repeat
# helper functions

def exists(val):
    """Return True for any value other than ``None`` (falsy values included)."""
    return not (val is None)
# fallback helper
def default(val, d):
    """Return ``val``, or ``d`` when ``val`` is ``None``."""
    if exists(val):
        return val
    return d
# classes

# Skip-connection wrapper
class Residual(nn.Module):
    """Wrap ``fn`` so that ``forward(x, **kwargs)`` returns ``fn(x, **kwargs) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        branch = self.fn(x, **kwargs)
        return branch + x
# Pre-layer-normalization wrapper
class PreNorm(nn.Module):
    """Apply ``LayerNorm(dim)`` to the input, then call ``fn`` on the result.

    Keyword arguments are forwarded to ``fn``, and its return value (which may
    be a tuple) is passed through unchanged.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
# attention

# GELU-gated linear unit
class GEGLU(nn.Module):
    """Split the last dim in two halves ``(value, gate)``; return ``value * gelu(gate)``."""

    def forward(self, x):
        value, gate = x.chunk(2, dim = -1)
        activated_gate = F.gelu(gate)
        return value * activated_gate
# GEGLU feed-forward network (normalization is applied by the caller, e.g. PreNorm)
class FeedForward(nn.Module):
    """``Linear(dim, 2*dim*mult) -> GEGLU -> Dropout -> Linear(dim*mult, dim)``.

    The first projection is twice as wide because GEGLU halves the last dim.
    """

    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden = dim * mult
        self.net = nn.Sequential(
            nn.Linear(dim, hidden * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim)
        )

    def forward(self, x, **kwargs):
        # Extra keyword arguments are accepted for interface parity and ignored.
        return self.net(x)
# Multi-head self-attention (no internal normalization)
class Attention(nn.Module):
    """Multi-head self-attention.

    Args:
        dim: token feature size.
        heads: number of heads.
        dim_head: feature size of each head.
        dropout: dropout on attention weights before mixing values.

    ``forward(x)`` returns ``(out, attn)``. ``attn`` is the post-softmax map
    (before dropout) of shape ``(batch, heads, n, n)``.
    """

    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 16,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        num_heads = self.heads
        projections = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = num_heads) for t in projections)

        # Logits are scaled after the q.k product.
        logits = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = logits.softmax(dim = -1)

        mixed = einsum('b h i j, b h j d -> b h i d', self.dropout(attn), v)
        merged = rearrange(mixed, 'b h n d -> b n (h d)', h = num_heads)
        return self.to_out(merged), attn
# Transformer: stack of pre-norm residual (attention, feed-forward) layers
class Transformer(nn.Module):
    """``depth`` layers of ``x + PreNorm(Attention)(x)`` then ``x + PreNorm(FeedForward)(x)``."""

    def __init__(
        self,
        dim,
        depth,
        heads,
        dim_head,
        attn_dropout,
        ff_dropout
    ):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
                PreNorm(dim, FeedForward(dim, dropout = ff_dropout)),
            ])
            for _ in range(depth)
        ])

    def forward(self, x, return_attn = False):
        """Return ``x``, or ``(x, attns)`` with per-layer attention maps stacked on dim 0."""
        attn_maps = []
        for attn_layer, ff_layer in self.layers:
            attn_out, attn_map = attn_layer(x)
            attn_maps.append(attn_map)
            x = x + attn_out
            x = x + ff_layer(x)

        if return_attn:
            return x, torch.stack(attn_maps)
        return x
# Multi-layer perceptron
class MLP(nn.Module):
    """Linear layers between consecutive sizes in ``dims``, with ``act`` between
    them (not after the last one).

    Note: a single activation instance (``act``, or one default ``nn.ReLU()``)
    is inserted at every position. An activation with learnable parameters
    would therefore share them across layers.
    """

    def __init__(self, dims, act = None):
        super().__init__()
        pairs = list(zip(dims[:-1], dims[1:]))
        last_index = len(pairs) - 1

        layers = []
        for index, (dim_in, dim_out) in enumerate(pairs):
            layers.append(nn.Linear(dim_in, dim_out))
            if index < last_index:
                act = default(act, nn.ReLU())
                layers.append(act)

        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)
# main class: TabTransformer
class TabTransformer(nn.Module):
    """TabTransformer for tabular data (Huang et al., 2020).

    Categorical columns are embedded through one table using per-column id
    offsets. Optionally, a learned per-column "shared" embedding of size
    ``dim // shared_categ_dim_divisor`` is concatenated, in which case the
    per-value embedding uses the remaining features. The categorical tokens are
    contextualized by a transformer and flattened. Continuous columns are
    optionally standardized with ``continuous_mean_std`` and then
    layer-normalized. Both parts are concatenated and fed to an MLP that
    produces ``dim_out`` logits.

    Args:
        categories: number of unique values of each categorical column
            (each must be positive).
        num_continuous: number of continuous columns.
        dim: categorical token size.
        depth: number of transformer layers.
        heads: attention heads per layer.
        dim_head: feature size of each head.
        dim_out: number of output logits.
        mlp_hidden_mults: hidden widths of the final MLP, as multiples of its input size.
        mlp_act: activation module for the MLP (defaults to ReLU there).
        num_special_tokens: ids reserved at the start of the embedding table.
        continuous_mean_std: optional tensor of shape ``(num_continuous, 2)``
            holding each column's mean and standard deviation.
        attn_dropout: dropout on attention weights.
        ff_dropout: dropout inside the feed-forward sub-layers.
        use_shared_categ_embed: whether to add the shared per-column embedding.
        shared_categ_dim_divisor: the paper reserves 1/8 of ``dim`` for the
            shared embedding.
    """

    def __init__(
        self,
        *,
        categories,
        num_continuous,
        dim,
        depth,
        heads,
        dim_head = 16,
        dim_out = 1,
        mlp_hidden_mults = (4, 2),
        mlp_act = None,
        num_special_tokens = 2,
        continuous_mean_std = None,
        attn_dropout = 0.,
        ff_dropout = 0.,
        use_shared_categ_embed = True,
        shared_categ_dim_divisor = 8
    ):
        super().__init__()
        assert all(map(lambda n: n > 0, categories)), 'number of each category must be positive'
        assert len(categories) + num_continuous > 0, 'input shape must not be null'

        # --- categorical columns ---
        self.num_categories = len(categories)
        self.num_unique_categories = sum(categories)

        self.num_special_tokens = num_special_tokens
        total_tokens = self.num_unique_categories + num_special_tokens

        shared_embed_dim = 0 if not use_shared_categ_embed else int(dim // shared_categ_dim_divisor)
        # Per-value embedding uses whatever width the shared embedding leaves free.
        self.category_embed = nn.Embedding(total_tokens, dim - shared_embed_dim)

        self.use_shared_categ_embed = use_shared_categ_embed
        if use_shared_categ_embed:
            self.shared_category_embed = nn.Parameter(torch.zeros(self.num_categories, shared_embed_dim))
            nn.init.normal_(self.shared_category_embed, std = 0.02)

        # Offset that moves each column's local ids into its slice of the shared table.
        if self.num_unique_categories > 0:
            categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value = num_special_tokens)
            categories_offset = categories_offset.cumsum(dim = -1)[:-1]
            self.register_buffer('categories_offset', categories_offset)

        # --- continuous columns ---
        self.num_continuous = num_continuous
        if self.num_continuous > 0:
            if exists(continuous_mean_std):
                assert continuous_mean_std.shape == (num_continuous, 2), f'continuous_mean_std must have a shape of ({num_continuous}, 2) where the last dimension contains the mean and standard deviation respectively'
            # Registered even when None, so forward can always read the attribute.
            self.register_buffer('continuous_mean_std', continuous_mean_std)
            self.norm = nn.LayerNorm(num_continuous)

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            heads = heads,
            dim_head = dim_head,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout
        )

        # MLP over [flattened categorical tokens, normalized continuous values].
        input_size = (dim * self.num_categories) + num_continuous
        hidden_dimensions = [input_size * t for t in mlp_hidden_mults]
        all_dimensions = [input_size, *hidden_dimensions, dim_out]
        self.mlp = MLP(all_dimensions, act = mlp_act)

    def forward(self, x_categ, x_cont, return_attn = False):
        """Compute logits for a batch of rows.

        Args:
            x_categ: integer ids, last dim == number of categorical columns.
                Ids are local to their column; offsets are added here.
            x_cont: continuous values of shape ``(batch, num_continuous)``.
                May be ``None`` when the model has no continuous columns.
            return_attn: also return the transformer's stacked attention maps.

        Returns:
            ``logits``, or ``(logits, attns)``. ``attns`` is ``None`` when there
            are no categorical columns, because the transformer only runs over
            categorical embeddings.
        """
        xs = []
        # Without categorical columns the transformer never runs; avoids UnboundLocalError.
        attns = None

        assert x_categ.shape[-1] == self.num_categories, f'you must pass in {self.num_categories} values for your categories input'

        if self.num_unique_categories > 0:
            x_categ = x_categ + self.categories_offset
            categ_embed = self.category_embed(x_categ)

            if self.use_shared_categ_embed:
                shared_categ_embed = repeat(self.shared_category_embed, 'n d -> b n d', b = categ_embed.shape[0])
                categ_embed = torch.cat((categ_embed, shared_categ_embed), dim = -1)

            x, attns = self.transformer(categ_embed, return_attn = True)
            flat_categ = rearrange(x, 'b ... -> b (...)')
            xs.append(flat_categ)

        # Only validate x_cont when it is expected or was actually supplied.
        if self.num_continuous > 0 or exists(x_cont):
            assert x_cont.shape[1] == self.num_continuous, f'you must pass in {self.num_continuous} values for your continuous input'

        if self.num_continuous > 0:
            if exists(self.continuous_mean_std):
                mean, std = self.continuous_mean_std.unbind(dim = -1)
                x_cont = (x_cont - mean) / std

            normed_cont = self.norm(x_cont)
            xs.append(normed_cont)

        x = torch.cat(xs, dim = -1)
        logits = self.mlp(x)

        if not return_attn:
            return logits

        return logits, attns
.\lucidrains\tab-transformer-pytorch\tab_transformer_pytorch\__init__.py
# 从 tab_transformer_pytorch 库中导入 TabTransformer 类
from tab_transformer_pytorch.tab_transformer_pytorch import TabTransformer
# 从 tab_transformer_pytorch 库中导入 FTTransformer 类
from tab_transformer_pytorch.ft_transformer import FTTransformer

Tableformer - Pytorch (wip)
Implementation of TableFormer: Robust Transformer Modeling for Table-Text Encoding, in Pytorch. The paper claims that, through attentional biases, transformers can be made more robust to perturbations of the table in question. It shows improved results compared to TAPAS.
Citations
@article{Yang2022TableFormerRT,
title = {TableFormer: Robust Transformer Modeling for Table-Text Encoding},
author = {Jingfeng Yang and Aditya Gupta and Shyam Upadhyay and Luheng He and Rahul Goel and Shachi Paul},
journal = {ArXiv},
year = {2022},
volume = {abs/2203.00274}
}

Taylor Series Linear Attention
Explorations into the Taylor Series Linear Attention proposed in the paper Zoology: Measuring and Improving Recall in Efficient Language Models
This repository will offer full self attention, cross attention, and autoregressive attention via the CUDA kernel from pytorch-fast-transformers.
Be aware that in linear attention, the quadratic cost is pushed to the attention head dimension. With the second-order Taylor expansion, this becomes O(D^3), so more research is needed.
Update: It works! Strongest formulation of linear attention I've come across in the literature
Appreciation
- A16Z Open Source AI Grant Program and 🤗 Huggingface for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research
Install
$ pip install taylor-series-linear-attention
Usage
import torch
from taylor_series_linear_attention import TaylorSeriesLinearAttn
attn = TaylorSeriesLinearAttn(
dim = 512,
dim_head = 16,
heads = 16
)
x = torch.randn(1, 4096, 512)
mask = torch.ones((1, 4096)).bool()
out = attn(x, mask = mask)
assert x.shape == out.shape
Cross attention
import torch
from taylor_series_linear_attention import TaylorSeriesLinearAttn
attn = TaylorSeriesLinearAttn(
dim = 512,
dim_head = 16,
heads = 16
)
x = torch.randn(1, 1024, 512)
context = torch.randn(1, 65536, 512)
context_mask = torch.ones((1, 65536)).bool()
out = attn(x, context = context, mask = context_mask)
assert x.shape == out.shape
For autoregressive attention, first `pip install pytorch-fast-transformers`, then set `causal = True`.
import torch
from taylor_series_linear_attention import TaylorSeriesLinearAttn
attn = TaylorSeriesLinearAttn(
dim = 512,
dim_head = 16,
heads = 16,
causal = True, # set this to True
rotary_emb = True # rotary embeddings
)
x = torch.randn(1, 8192, 512)
out = attn(x)
assert x.shape == out.shape
Todo
- take care of caching for causal variant
Citations
@inproceedings{Arora2023ZoologyMA,
title = {Zoology: Measuring and Improving Recall in Efficient Language Models},
author = {Simran Arora and Sabri Eyuboglu and Aman Timalsina and Isys Johnson and Michael Poli and James Zou and Atri Rudra and Christopher R{\'e}},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:266149332}
}
@inproceedings{Keles2022OnTC,
title = {On The Computational Complexity of Self-Attention},
author = {Feyza Duman Keles and Pruthuvi Maheshakya Wijewardena and Chinmay Hegde},
booktitle = {International Conference on Algorithmic Learning Theory},
year = {2022},
url = {https://api.semanticscholar.org/CorpusID:252198880}
}
@article{Shazeer2019FastTD,
title = {Fast Transformer Decoding: One Write-Head is All You Need},
author = {Noam M. Shazeer},
journal = {ArXiv},
year = {2019},
volume = {abs/1911.02150}
}
@inproceedings{Peng2023RWKVRR,
title = {RWKV: Reinventing RNNs for the Transformer Era},
author = {Bo Peng and Eric Alcaide and Quentin G. Anthony and Alon Albalak and Samuel Arcadinho and Stella Biderman and Huanqi Cao and Xin Cheng and Michael Chung and Matteo Grella and G Kranthikiran and Xuming He and Haowen Hou and Przemyslaw Kazienko and Jan Kocoń and Jiaming Kong and Bartlomiej Koptyra and Hayden Lau and Krishna Sri Ipsit Mantri and Ferdinand Mom and Atsushi Saito and Xiangru Tang and Bolun Wang and Johan Sokrates Wind and Stansilaw Wozniak and Ruichong Zhang and Zhenyuan Zhang and Qihang Zhao and Peng Zhou and Jian Zhu and Rui Zhu},
booktitle = {Conference on Empirical Methods in Natural Language Processing},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:258832459}
}
@inproceedings{Katharopoulos2020TransformersAR,
title = {Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention},
author = {Angelos Katharopoulos and Apoorv Vyas and Nikolaos Pappas and Fran{\c{c}}ois Fleuret},
booktitle = {International Conference on Machine Learning},
year = {2020},
url = {https://api.semanticscholar.org/CorpusID:220250819}
}
The greatest shortcoming of the human race is man’s inability to understand the exponential function. - Albert A. Bartlett
.\lucidrains\taylor-series-linear-attention\setup.py
# 导入设置和查找包的函数
from setuptools import setup, find_packages
# Package metadata for the taylor-series-linear-attention distribution.
setup(
    name='taylor-series-linear-attention',
    version='0.1.9',
    license='MIT',
    description='Taylor Series Linear Attention',
    long_description_content_type='text/markdown',
    author='Phil Wang',
    author_email='lucidrains@gmail.com',
    url='https://github.com/lucidrains/taylor-series-linear-attention',
    # Discover every package under the project root.
    packages=find_packages(exclude=[]),
    keywords=[
        'artificial intelligence',
        'deep learning',
        'attention mechanism',
    ],
    # Runtime dependencies.
    install_requires=[
        'einops>=0.7.0',
        'einx',
        'rotary-embedding-torch>=0.5.3',
        'torch>=2.0',
        'torchtyping',
    ],
    # PyPI trove classifiers.
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\taylor-series-linear-attention\taylor_series_linear_attention\attention.py
# 导入必要的库
import importlib
from functools import partial
from collections import namedtuple
import torch
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from torch import nn, einsum, Tensor
from einops import rearrange, pack, unpack
from einops.layers.torch import Rearrange
from typing import Optional
from torchtyping import TensorType
from rotary_embedding_torch import RotaryEmbedding
# constants

# Record type for cached attention state. The code that fills these fields
# (TaylorSeriesLinearAttn.forward) is not shown in this listing. The names
# suggest running sums of key/value products and of keys (TODO confirm).
Cache = namedtuple('Cache', 'seq_len last_token kv_cumsum k_cumsum')
# functions

def exists(v):
    """Return True for any value other than ``None``."""
    return not (v is None)
# Fallback helper
def default(v, d):
    """Return ``v`` unless it is ``None``; otherwise return ``d``."""
    if exists(v):
        return v
    return d
# Token shift (zero-filled, not circular)
def shift(t):
    """Delay the second half of the last dimension by one step along dim -2.

    The first half passes through unchanged. For the second half, a zero row is
    padded at the front of dim -2 and the last row is dropped (negative padding),
    so the output shape equals the input shape.
    """
    keep, delayed = t.chunk(2, dim = -1)
    # Pad spec: last dim (0, 0); dim -2: +1 at the front, -1 (crop) at the back.
    delayed = F.pad(delayed, (0, 0, 1, -1), value = 0.)
    return torch.cat((keep, delayed), dim = -1)
# pre-normalization

class RMSNorm(Module):
    """Root-mean-square normalization over the last dimension with a learned gain.

    ``F.normalize`` divides by the L2 norm (with its default eps clamp).
    Multiplying by ``sqrt(dim)`` turns that into division by the RMS, and
    ``gamma`` (initialized to ones) rescales each feature.
    """

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        unit = F.normalize(x, dim = -1)
        return self.gamma * unit * self.scale
# Second-order Taylor feature map for exp(q . k)
def second_taylor_expansion(x: Tensor):
    """Map ``x`` of shape ``(..., D)`` to features of shape ``(..., 1 + D + D*D)``.

    The features are ``[1, x, vec(x x^T) / sqrt(2)]``. The dot product of the
    features of ``q`` and ``k`` therefore equals ``1 + q.k + (q.k)^2 / 2``,
    the second-order Taylor expansion of ``exp(q.k)``. The feature size grows
    quadratically with ``D``, which is why small head dimensions are used.
    """
    # Flatten all leading dims into one row axis.
    flat, packed_shape = pack([x], '* d')
    rows = flat.shape[0]

    constant = flat.new_ones((rows,))
    quadratic = einsum('... i, ... j -> ... i j', flat, flat) * (0.5 ** 0.5)

    features, _ = pack([constant, flat, quadratic], 'b *')
    features, = unpack(features, packed_shape, '* d')
    return features
# 主类
# 泰勒级数线性注意力模块
class TaylorSeriesLinearAttn(Module):
def __init__(
    self,
    dim,
    *,
    dim_head = 16,
    heads = 8,
    causal = False,
    one_headed_kv = False,
    rotary_emb = False,
    combine_heads = True,
    gate_value_heads = False,
    prenorm = False,
    shift_tokens = False,
    dropout = 0.
):
    """Configure the projections for Taylor-series linear attention.

    Args:
        dim: input feature size.
        dim_head: feature size of each head.
        heads: number of query heads.
        causal: use the causal linear-attention kernel (``CausalDotProduct``)
            from pytorch-fast-transformers.
        one_headed_kv: project keys/values to a single shared head.
        rotary_emb: create a ``RotaryEmbedding(dim_head)``.
        combine_heads: if true, merged heads are projected back to ``dim``
            (followed by dropout); otherwise ``to_out`` is the identity.
        gate_value_heads: create per-head sigmoid gates computed from the input.
        prenorm: apply ``RMSNorm`` (otherwise ``nn.Identity``) as ``self.norm``.
        shift_tokens: stored as ``self.shift_tokens``; presumably consumed by
            ``forward`` (not visible here).
        dropout: dropout after the output projection (only when ``combine_heads``).

    Raises:
        ImportError: if ``causal`` is true but pytorch-fast-transformers is missing.
    """
    super().__init__()
    self.scale = dim_head ** -0.5
    dim_inner = dim_head * heads

    self.shift_tokens = shift_tokens
    self.norm = RMSNorm(dim) if prenorm else nn.Identity()

    self.heads = heads
    self.dim_hidden = dim_inner

    self.causal = causal
    self.causal_linear_attn_fn = None

    if causal:
        # `import importlib` alone does not guarantee the `importlib.util`
        # submodule is loaded, so import it explicitly.
        import importlib.util

        if not exists(importlib.util.find_spec('fast_transformers')):
            # Raise instead of calling exit(): a library must not terminate the host process.
            raise ImportError('pytorch-fast-transformers must be installed. `pip install pytorch-fast-transformers` first')

        from fast_transformers.causal_product import CausalDotProduct
        self.causal_linear_attn_fn = CausalDotProduct.apply

    kv_heads = heads if not one_headed_kv else 1
    dim_kv_inner = dim_head * kv_heads

    self.rotary_emb = RotaryEmbedding(dim_head) if rotary_emb else None

    self.one_headed_kv = one_headed_kv

    # queries: (b, n, h * d) -> (b, h, n, d)
    self.to_q = nn.Sequential(
        nn.Linear(dim, dim_inner, bias = False),
        Rearrange('b n (h d) -> b h n d', h = heads)
    )

    # keys and values: (b, n, 2 * kv_h * d) -> (2, b, kv_h, n, d)
    self.to_kv = nn.Sequential(
        nn.Linear(dim, dim_kv_inner * 2, bias = False),
        Rearrange('b n (kv h d) -> kv b h n d', kv = 2, h = kv_heads)
    )

    # optional per-head value gates in (0, 1), shaped (b, h, n, 1)
    self.to_v_gates = nn.Sequential(
        nn.Linear(dim, heads, bias = False),
        nn.Sigmoid(),
        Rearrange('b n h -> b h n 1')
    ) if gate_value_heads else None

    self.merge_heads = Rearrange('b h n d -> b n (h d)')

    self.to_out = nn.Identity()

    if combine_heads:
        self.to_out = nn.Sequential(
            nn.Linear(dim_inner, dim, bias = False),
            nn.Dropout(dropout)
        )
# 定义一个方法用于前向传播
def forward(
# 输入参数 x,类型为张量,形状为 ['batch', 'seq', 'dim'],数据类型为 float
x: TensorType['batch', 'seq', 'dim', float],
# 可选参数 mask,类型为张量,形状为 ['batch', 'seq'],数据类型为 bool,默认为 None
mask: Optional[TensorType['batch', 'seq', bool]] = None,
# 可选参数 context,类型为张量,形状为 ['batch', 'target_seq', 'dim'],数据类型为 float,默认为 None
context: Optional[TensorType['batch', 'target_seq', 'dim', float]] = None,
# 参数 eps,数据类型为 float,默认值为 1e-5
eps: float = 1e-5,
# 可选参数 cache,类型为 Cache 对象,默认为 None
cache: Optional[Cache] = None,
# 参数 return_cache,数据类型为 bool,默认值为 False
return_cache = False
# Channel-first wrapper for images and video
class ChannelFirstTaylorSeriesLinearAttn(Module):
    """Run ``TaylorSeriesLinearAttn`` on channel-first tensors.

    The input is ``(batch, channels, *spatial)``. All spatial positions are
    flattened into one sequence axis and attended over. The result is reshaped
    to the original spatial shape and moved back to channel-first layout.
    Constructor arguments are forwarded unchanged to ``TaylorSeriesLinearAttn``.
    """

    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__()
        self.attn = TaylorSeriesLinearAttn(*args, **kwargs)

    def forward(
        self,
        x: Tensor
    ):
        # channels last: (b, c, *spatial) -> (b, *spatial, c)
        channels_last = rearrange(x, 'b c ... -> b ... c')
        # flatten spatial dims into one sequence axis: (b, n, c)
        sequence, spatial_shape = pack([channels_last], 'b * c')

        attended = self.attn(sequence)

        restored, = unpack(attended, spatial_shape, 'b * c')
        return rearrange(restored, 'b ... c -> b c ...')
.\lucidrains\taylor-series-linear-attention\taylor_series_linear_attention\vit.py
# 从 math 模块中导入 sqrt 函数
from math import sqrt
# 导入 torch 库
import torch
from torch import nn, einsum
from torch.nn import Module, ModuleList
import torch.nn.functional as F
# 导入 einops 库中的 rearrange 和 repeat 函数
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
# 导入自定义的注意力模块
from taylor_series_linear_attention.attention import (
TaylorSeriesLinearAttn,
ChannelFirstTaylorSeriesLinearAttn
)
# 2D sine-cosine positional embeddings
def posemb_sincos_2d(
    h, w,
    dim,
    temperature: int = 10000,
    dtype = torch.float32
):
    """Return fixed 2D sin-cos positional embeddings of shape ``(h * w, dim)``.

    Grid positions are enumerated row-major. The feature dimension is split
    into four equal parts: ``sin(x * omega)``, ``cos(x * omega)``,
    ``sin(y * omega)`` and ``cos(y * omega)``. ``omega`` decreases geometrically
    from ``1`` to ``1 / temperature``.

    Args:
        h, w: grid height and width.
        dim: embedding size; must be a multiple of 4.
        temperature: base of the frequency progression.
        dtype: dtype of the returned tensor.

    Raises:
        AssertionError: if ``dim`` is not a multiple of 4.
    """
    # Validate before doing any work.
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    quarter = dim // 4

    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing = "ij")

    # Frequency exponents spread over [0, 1]. The denominator is guarded so that
    # dim == 4 (a single frequency) gives omega == 1 instead of 0 / 0 == NaN.
    omega = torch.arange(quarter) / max(quarter - 1, 1)
    omega = temperature ** -omega

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    return pe.type(dtype)
# Depthwise-separable 2D convolution
def DepthWiseConv2d(
    dim_in,
    dim_out,
    kernel_size,
    padding,
    stride = 1,
    bias = True
):
    """Per-channel ``kernel_size`` conv (``groups=dim_in``) then a 1x1 pointwise conv.

    ``padding`` and ``stride`` apply only to the depthwise stage. The pointwise
    stage maps ``dim_in`` to ``dim_out`` channels. ``bias`` applies to both.
    """
    depthwise = nn.Conv2d(
        dim_in, dim_in,
        kernel_size = kernel_size,
        padding = padding,
        groups = dim_in,
        stride = stride,
        bias = bias
    )
    pointwise = nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
    return nn.Sequential(depthwise, pointwise)
# Convolutional feed-forward block
class FeedForward(Module):
    """Feed-forward block that also mixes neighbouring tokens with a depthwise conv.

    Tokens ``(batch, n, dim)`` are reshaped into a ``sqrt(n) x sqrt(n)`` map,
    which assumes ``n`` is a perfect square. The map goes through
    1x1 conv -> Hardswish -> depthwise-separable 3x3 conv -> Hardswish ->
    Dropout -> 1x1 conv -> Dropout and is flattened back to ``(batch, n, dim)``.
    """

    def __init__(
        self,
        dim,
        mult = 4,
        dropout = 0.
    ):
        super().__init__()
        hidden = int(dim * mult)
        self.net = nn.Sequential(
            nn.Conv2d(dim, hidden, 1),
            nn.Hardswish(),
            DepthWiseConv2d(hidden, hidden, 3, padding = 1),
            nn.Hardswish(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        side = int(sqrt(x.shape[-2]))
        feature_map = rearrange(x, 'b (h w) c -> b c h w', h = side, w = side)
        feature_map = self.net(feature_map)
        return rearrange(feature_map, 'b c h w -> b (h w) c')
# Pre-norm transformer with Taylor-series linear attention
class Transformer(Module):
    """``depth`` layers of ``x + attn(LN(x))`` followed by ``x + ff(LN(x))``."""

    def __init__(
        self,
        dim,
        depth,
        heads,
        dim_head,
        ff_mult,
        dropout = 0.
    ):
        super().__init__()
        # Each layer: [attn_norm, attn, ff_norm, ff], built in this order.
        self.layers = ModuleList([
            nn.ModuleList([
                nn.LayerNorm(dim),
                TaylorSeriesLinearAttn(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                nn.LayerNorm(dim),
                FeedForward(dim, ff_mult, dropout = dropout)
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        for attn_norm, attn, ff_norm, ff in self.layers:
            x = x + attn(attn_norm(x))
            x = x + ff(ff_norm(x))
        return x
# Vision Transformer built on Taylor-series linear attention
class ViT(Module):
    """Image classifier: patchify -> embed -> add fixed sin-cos positions ->
    transformer -> mean-pool over patches -> LayerNorm -> Linear.

    Args:
        image_size: side length of the (square) input image.
        patch_size: side length of each square patch; must divide ``image_size``.
        num_classes: number of output logits.
        dim: token embedding size.
        depth: number of transformer layers.
        ff_mult: feed-forward width multiplier.
        heads: attention heads per layer.
        channels: number of input image channels.
        dim_head: feature size of each head.
        dropout: dropout inside the transformer layers.
        emb_dropout: dropout applied to the patch embeddings.
    """

    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        ff_mult = 4,
        heads = 16,
        channels = 3,
        dim_head = 8,
        dropout = 0.,
        emb_dropout = 0.
    ):
        super().__init__()
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size.'

        patches_per_side = image_size // patch_size
        patch_dim = channels * patch_size ** 2

        # (b, c, H, W) -> (b, num_patches, dim)
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # Fixed positional table; non-persistent, so it is excluded from state_dict.
        pos_embedding = posemb_sincos_2d(
            h = patches_per_side,
            w = patches_per_side,
            dim = dim,
        )
        self.register_buffer('pos_embedding', pos_embedding, persistent = False)

        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, ff_mult, dropout)

        self.mlp_head = nn.Sequential(
            Reduce('b n d -> b d', 'mean'),
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        tokens = self.to_patch_embedding(img) + self.pos_embedding
        tokens = self.transformer(self.dropout(tokens))
        return self.mlp_head(tokens)
.\lucidrains\taylor-series-linear-attention\taylor_series_linear_attention\__init__.py
# 从taylor_series_linear_attention.attention模块中导入TaylorSeriesLinearAttn和ChannelFirstTaylorSeriesLinearAttn类
from taylor_series_linear_attention.attention import (
TaylorSeriesLinearAttn,
ChannelFirstTaylorSeriesLinearAttn
)
# 从taylor_series_linear_attention.vit模块中导入ViT类
from taylor_series_linear_attention.vit import ViT
.\lucidrains\tf-bind-transformer\finetune_binary_pred.py
# Fine-tune a bind / no-bind AdapterModel on top of an Enformer trunk
# instantiated from hyperparameters, taking gradient steps forever.

# load_dotenv reads environment variables from a .env file
from dotenv import load_dotenv

# set the cache path in a .env file, then uncomment the next line
# load_dotenv()

from enformer_pytorch import Enformer
from tf_bind_transformer import AdapterModel, Trainer

# ---- optimization constants ----

BATCH_SIZE = 2
GRAD_ACCUM_STEPS = 8            # effective batch size = BATCH_SIZE * GRAD_ACCUM_STEPS = 16
VALIDATE_EVERY = 250
GRAD_CLIP_MAX_NORM = 1.5

# ---- data locations ----

REMAP_FILE_PATH = './remap2022_all.bed'
TFACTOR_FOLDER = './tfactor.fastas'
FASTA_FILE_PATH = './hg38.ml.fa'
NON_PEAK_PATH = './generated-non-peaks.bed'
CONTEXT_LENGTH = 4096

SCOPED_NEGS_REMAP_PATH = './neg-npy/remap2022.bed'
SCOPED_NEGS_PATH = './neg-npy'

# ---- splits ----

TRAIN_CHROMOSOMES = [*range(1, 24, 2), 'X']     # train on odd chromosomes (+ X)
VALID_CHROMOSOMES = [*range(2, 24, 2)]          # validate on even chromosomes
HELD_OUT_TARGET = ['AFF4']

# ---- model: Enformer trunk wrapped by the adapter ----

enformer = Enformer.from_hparams(
    dim = 768,
    depth = 4,
    heads = 8,
    target_length = -1,
    use_convnext = True,
    num_downsamples = 6     # resolution of 2 ^ 6 == 64bp
)

model = AdapterModel(
    enformer = enformer,
    use_aa_embeds = True,
    use_free_text_context = True,
    free_text_embed_method = 'mean_pool',
    binary_target = True,
    target_mse_loss = False,
    use_squeeze_excite = True,
    aa_embed_encoder = 'protalbert'
).cuda()

# ---- trainer used for fine-tuning ----

trainer = Trainer(
    model,
    context_length = CONTEXT_LENGTH,
    batch_size = BATCH_SIZE,
    validate_every = VALIDATE_EVERY,
    grad_clip_norm = GRAD_CLIP_MAX_NORM,
    grad_accum_every = GRAD_ACCUM_STEPS,
    remap_bed_file = REMAP_FILE_PATH,
    negative_bed_file = NON_PEAK_PATH,
    factor_fasta_folder = TFACTOR_FOLDER,
    fasta_file = FASTA_FILE_PATH,
    train_chromosome_ids = TRAIN_CHROMOSOMES,
    valid_chromosome_ids = VALID_CHROMOSOMES,
    held_out_targets = HELD_OUT_TARGET,
    include_scoped_negs = True,
    scoped_negs_remap_bed_path = SCOPED_NEGS_REMAP_PATH,
    scoped_negs_path = SCOPED_NEGS_PATH,
)

# take gradient steps indefinitely (stop the process manually)
while True:
    trainer(finetune_enformer_ln_only = False)
.\lucidrains\tf-bind-transformer\finetune_track.py
# Fine-tune a pretrained Enformer + AdapterModel on human and mouse BigWig
# tracks, taking gradient steps forever.

# load_dotenv reads environment variables from a .env file
from dotenv import load_dotenv

# set the cache path in a .env file, then uncomment the next line
# load_dotenv()

from enformer_pytorch import Enformer
from tf_bind_transformer import AdapterModel, BigWigTrainer

# ---- optimization constants ----

BATCH_SIZE = 1
GRAD_ACCUM_STEPS = 8            # effective batch size = BATCH_SIZE * GRAD_ACCUM_STEPS = 8
LEARNING_RATE = 1e-4            # Deepmind used 1e-4 when fine-tuning Enformer
VALIDATE_EVERY = 250
GRAD_CLIP_MAX_NORM = 1.5

# ---- data locations ----

TFACTOR_FOLDER = './tfactor.fastas'               # transcription factor fastas

HUMAN_FASTA_FILE_PATH = './hg38.ml.fa'            # human genome
MOUSE_FASTA_FILE_PATH = './mm10.ml.fa'            # mouse genome

HUMAN_LOCI_PATH = './chip_atlas/human_sequences.bed'
MOUSE_LOCI_PATH = './chip_atlas/mouse_sequences.bed'

BIGWIG_PATH = './chip_atlas/bigwig'
BIGWIG_TRACKS_ONLY_PATH = './chip_atlas/bigwig_tracks_only'
ANNOT_FILE_PATH = './chip_atlas/annot.tab'

TARGET_LENGTH = 896
HELD_OUT_TARGET = ['GATA2']

# ---- model: pretrained Enformer wrapped by the adapter ----

enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough', target_length = TARGET_LENGTH)

model = AdapterModel(
    enformer = enformer,
    use_aa_embeds = True,
    use_free_text_context = True,
    free_text_embed_method = 'mean_pool',
    aa_embed_encoder = 'esm',
    finetune_output_heads = dict(
        human = 12,
        mouse = 24
    )
).cuda()

# ---- trainer used for fine-tuning ----

trainer = BigWigTrainer(
    model,
    human_loci_path = HUMAN_LOCI_PATH,
    mouse_loci_path = MOUSE_LOCI_PATH,
    human_fasta_file = HUMAN_FASTA_FILE_PATH,
    mouse_fasta_file = MOUSE_FASTA_FILE_PATH,
    bigwig_folder_path = BIGWIG_PATH,
    bigwig_tracks_only_folder_path = BIGWIG_TRACKS_ONLY_PATH,
    annot_file_path = ANNOT_FILE_PATH,
    target_length = TARGET_LENGTH,
    lr = LEARNING_RATE,
    batch_size = BATCH_SIZE,
    shuffle = True,
    validate_every = VALIDATE_EVERY,
    grad_clip_norm = GRAD_CLIP_MAX_NORM,
    grad_accum_every = GRAD_ACCUM_STEPS,
    human_factor_fasta_folder = TFACTOR_FOLDER,
    mouse_factor_fasta_folder = TFACTOR_FOLDER,
    held_out_targets = HELD_OUT_TARGET
)

# take gradient steps indefinitely (stop the process manually)
while True:
    trainer()
.\lucidrains\tf-bind-transformer\precache_proteins.py
# 导入需要的库
import click # 用于创建命令行接口
from tqdm import tqdm # 用于显示进度条
from pathlib import Path # 用于处理文件路径
from Bio import SeqIO # 用于处理生物信息学数据
from tf_bind_transformer.protein_utils import get_protein_embedder # 从自定义模块中导入函数
# Command-line entry point that runs a protein embedder over a folder of factor fastas
@click.command()
@click.option('--model-name', default = 'protalbert', help = 'Protein model name')
@click.option('--fasta-folder', help = 'Path to factor fastas', required = True)
def cache_embeddings(
    model_name,
    fasta_folder
):
    """Embed every .fasta file found (recursively) under --fasta-folder with the chosen protein model."""
    # embedding function for the requested protein model
    fn = get_protein_embedder(model_name)['fn']

    fastas = [*Path(fasta_folder).glob('**/*.fasta')]

    # explicit error rather than `assert`, which is stripped under `python -O`
    # (the command would then silently do nothing); click prints it as a usage error
    if not fastas:
        raise click.ClickException(f'no fasta files found at {fasta_folder}')

    for fasta in tqdm(fastas):
        # SeqIO.read expects exactly one record per file
        seq = SeqIO.read(fasta, 'fasta')
        # return value is discarded: this relies on the embedder caching its
        # results as a side effect — NOTE(review): confirm in protein_utils
        fn([str(seq.seq)], device = 'cpu')

if __name__ == '__main__':
    cache_embeddings()
Transcription Factor binding predictions with Attention and Transformers
A repository exploring the use of transformers to predict DNA ↔ transcription factor binding.
Install
Run the following at the project root to download dependencies
$ python setup.py install --user
Then you must install pybedtools as well as pyBigWig
$ conda install --channel conda-forge --channel bioconda pybedtools pyBigWig
Usage
import torch
from tf_bind_transformer import AdapterModel
# instantiate enformer or load pretrained
from enformer_pytorch import Enformer
enformer = Enformer.from_hparams(
dim = 1536,
depth = 2,
target_length = 256
)
# instantiate model wrapper that takes in enformer
model = AdapterModel(
enformer = enformer,
aa_embed_dim = 512,
contextual_embed_dim = 256
).cuda()
# mock data
seq = torch.randint(0, 4, (1, 196_608 // 2)).cuda() # for ACGT
aa_embed = torch.randn(1, 1024, 512).cuda()
aa_mask = torch.ones(1, 1024).bool().cuda()
contextual_embed = torch.randn(1, 256).cuda() # contextual embeddings, including cell type, species, experimental parameter embeddings
target = torch.randn(1, 256).cuda()
# train
loss = model(
seq,
aa_embed = aa_embed,
aa_mask = aa_mask,
contextual_embed = contextual_embed,
target = target
)
loss.backward()
# after a lot of training
corr_coef = model(
seq,
aa_embed = aa_embed,
aa_mask = aa_mask,
contextual_embed = contextual_embed,
target = target,
return_corr_coef = True
)
Using ESM or ProtAlbert to fetch transcription factor protein embeddings
import torch
from enformer_pytorch import Enformer
from tf_bind_transformer import AdapterModel
enformer = Enformer.from_hparams(
dim = 1536,
depth = 2,
target_length = 256
)
model = AdapterModel(
enformer = enformer,
use_aa_embeds = True, # set this to True
aa_embed_encoder = 'esm', # by default, will use esm, but can be set to 'protalbert', which has a longer context length of 2048 (vs esm's 1024)
contextual_embed_dim = 256
).cuda()
# mock data
seq = torch.randint(0, 4, (1, 196_608 // 2)).cuda()
tf_aa = torch.randint(0, 21, (1, 4)).cuda() # transcription factor amino acid sequence, from 0 to 20
contextual_embed = torch.randn(1, 256).cuda()
target = torch.randn(1, 256).cuda()
# train
loss = model(
seq,
aa = tf_aa,
contextual_embed = contextual_embed,
target = target
)
loss.backward()
- add alphafold2
Context passed in as free text
One can also pass the context (cell type, experimental parameters) directly as free text, which will be encoded by a text transformer trained on pubmed abstracts.
import torch
from tf_bind_transformer import AdapterModel
# instantiate enformer or load pretrained
from enformer_pytorch import Enformer
enformer = Enformer.from_hparams(
dim = 1536,
depth = 2,
target_length = 256
)
# instantiate model wrapper that takes in enformer
model = AdapterModel(
enformer = enformer,
use_aa_embeds = True,
use_free_text_context = True, # this must be set to True
free_text_embed_method = 'mean_pool' # allow for mean pooling of embeddings, instead of using CLS token
).cuda()
# mock data
seq = torch.randint(0, 4, (2, 196_608 // 2)).cuda() # for ACGT
target = torch.randn(2, 256).cuda()
tf_aa = [
'KVFGRCELAA', # single protein
('AMKRHGLDNY', 'YNDLGHRKMA') # complex, representations will be concatted together
]
contextual_texts = [
'cell type: GM12878 | dual cross-linked',
'cell type: H1-hESC'
]
# train
loss = model(
seq,
target = target,
aa = tf_aa,
contextual_free_text = contextual_texts,
)
loss.backward()
Binary prediction
To predict a binary outcome (bind or not bind), set binary_target = True when initializing the adapter
ex.
import torch
from tf_bind_transformer import AdapterModel
from enformer_pytorch import Enformer
# instantiate enformer or load pretrained
enformer = Enformer.from_hparams(
dim = 1536,
depth = 2,
target_length = 256
)
# instantiate model wrapper that takes in enformer
model = AdapterModel(
enformer = enformer,
use_aa_embeds = True,
use_free_text_context = True,
free_text_embed_method = 'mean_pool',
use_squeeze_excite = True,
binary_target = True, # set this to True
target_mse_loss = False # whether to use MSE loss with target value
).cuda()
# mock data
seq = torch.randint(0, 4, (2, 196_608 // 2)).cuda() # for ACGT, batch of 2 to match the targets, proteins and texts below
binary_target = torch.randint(0, 2, (2,)).cuda() # bind or not bind
tf_aa = [
'KVFGRCELAA',
('AMKRHGLDNY', 'YNDLGHRKMA')
]
contextual_texts = [
'cell type: GM12878 | chip-seq dual cross-linked',
'cell type: H1-hESC | chip-seq single cross-linked'
]
# train
loss = model(
seq,
target = binary_target,
aa = tf_aa,
contextual_free_text = contextual_texts,
)
loss.backward()
Predicting Tracks from BigWig
from pathlib import Path
import torch
from enformer_pytorch import Enformer
from tf_bind_transformer import AdapterModel
from tf_bind_transformer.data_bigwig import BigWigDataset, get_bigwig_dataloader
# constants
ROOT = Path('.')
TFACTOR_TF = str(ROOT / 'tfactor.fastas')
ENFORMER_DATA = str(ROOT / 'chip_atlas' / 'sequences.bed')
FASTA_FILE_PATH = str(ROOT / 'hg38.ml.fa')
BIGWIG_PATH = str(ROOT / 'chip_atlas')
ANNOT_FILE_PATH = str(ROOT / 'chip_atlas' / 'annot.tab')
# bigwig dataset and dataloader
ds = BigWigDataset(
factor_fasta_folder = TFACTOR_TF,
bigwig_folder = BIGWIG_PATH,
enformer_loci_path = ENFORMER_DATA,
annot_file = ANNOT_FILE_PATH,
fasta_file = FASTA_FILE_PATH
)
dl = get_bigwig_dataloader(ds, batch_size = 2)
# enformer
enformer = Enformer.from_hparams(
dim = 384,
depth = 1,
target_length = 896
)
model = AdapterModel(
enformer = enformer,
use_aa_embeds = True,
use_free_text_context = True
).cuda()
# mock data
seq, tf_aa, context_str, target = next(dl)
seq, target = seq.cuda(), target.cuda()
# train
loss = model(
seq,
aa = tf_aa,
contextual_free_text = context_str,
target = target
)
loss.backward()
Data
The data needed for training is at this download page.
Transcription factors for Human and Mouse
To download the protein sequences for both species, first download the ReMap CRM bed files, from which all the targets will be extracted; the corresponding fastas are then downloaded from UniProt.
Download human ReMap CRMs
$ wget https://remap.univ-amu.fr/storage/remap2022/hg38/MACS2/remap2022_crm_macs2_hg38_v1_0.bed.gz
$ gzip -d remap2022_crm_macs2_hg38_v1_0.bed.gz
Download mouse ReMap CRMs
$ wget https://remap.univ-amu.fr/storage/remap2022/mm10/MACS2/remap2022_crm_macs2_mm10_v1_0.bed.gz
$ gzip -d remap2022_crm_macs2_mm10_v1_0.bed.gz
Downloading all human transcription factors
$ python script/fetch_factor_fastas.py --species human
For mouse transcription factors
$ python script/fetch_factor_fastas.py --species mouse
Generating Negatives
Generating Hard Negatives
For starters, the RemapAllPeakDataset will allow you to load data easily from the full remap peaks bed file for training.
Firstly you'll need to generate the non-peaks dataset by running the following function
from tf_bind_transformer.data import generate_random_ranges_from_fasta
generate_random_ranges_from_fasta(
'./hg38.ml.fa',
output_filename = './path/to/generated-non-peaks.bed', # path to output file
context_length = 4096,
num_entries_per_key = 1_000_000, # number of negative samples
filter_bed_files = [
'./remap_all.bed', # filter out by all peak ranges (todo, allow filtering namespaced to experiment and target)
'./hg38.blacklist.rep.bed' # further filtering by blacklisted regions (gs://basenji_barnyard/hg38.blacklist.rep.bed)
]
)
Generating Scoped Negatives - Negatives per Dataset (experiment + target + cell type)
Todo
Simple Trainer class for fine-tuning
working fine-tuning training loop for bind / no-bind prediction
import torch
from enformer_pytorch import Enformer
from tf_bind_transformer import AdapterModel, Trainer
# instantiate enformer or load pretrained
enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough', target_length = -1)
# instantiate model wrapper that takes in enformer
model = AdapterModel(
enformer = enformer,
use_aa_embeds = True,
use_free_text_context = True,
free_text_embed_method = 'mean_pool',
binary_target = True,
target_mse_loss = True,
use_squeeze_excite = True,
aux_read_value_loss = True # use auxiliary read value loss, can be turned off
).cuda()
# pass the model (adapter + enformer) to the Trainer
trainer = Trainer(
model,
batch_size = 2, # batch size
context_length = 4096, # genetic sequence length
grad_accum_every = 8, # gradient accumulation steps
grad_clip_norm = 2.0, # gradient clipping
validate_every = 250,
remap_bed_file = './remap2022_all.bed', # path to remap bed peaks
negative_bed_file = './generated-non-peaks.bed', # path to generated non-peaks
factor_fasta_folder = './tfactor.fastas', # path to factor fasta files
fasta_file = './hg38.ml.fa', # human genome sequences
train_chromosome_ids = [*range(1, 24, 2), 'X'], # chromosomes to train on
valid_chromosome_ids = [*range(2, 24, 2)], # chromosomes to validate on
held_out_targets = ['AFF4'], # targets to hold out for validation
experiments_json_path = './data/experiments.json' # path to all experiments data, at this path relative to the project root, if repository is git cloned
)
while True:
_ = trainer()
working fine-tuning script for training on new enformer tracks, with cross-attending transcription factor protein embeddings and cell type conditioning
from dotenv import load_dotenv
# set path to cache in .env and unset the next comment
# load_dotenv()
from enformer_pytorch import Enformer
from tf_bind_transformer import AdapterModel, BigWigTrainer
# training constants
BATCH_SIZE = 1
GRAD_ACCUM_STEPS = 8
# effective batch size of BATCH_SIZE * GRAD_ACCUM_STEPS = 8
VALIDATE_EVERY = 250
GRAD_CLIP_MAX_NORM = 1.5
TFACTOR_FOLDER = './tfactor.fastas'
FASTA_FILE_PATH = './hg38.ml.fa'
LOCI_PATH = './sequences.bed'
BIGWIG_PATH = './bigwig_folder'
ANNOT_FILE_PATH = './experiments.tab'
TARGET_LENGTH = 896
TRAIN_CHROMOSOMES = [*range(1, 24, 2), 'X'] # train on odd chromosomes
VALID_CHROMOSOMES = [*range(2, 24, 2)] # validate on even
HELD_OUT_TARGET = ['SOX2']
# instantiate enformer or load pretrained
enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough', target_length = TARGET_LENGTH)
# instantiate model wrapper that takes in enformer
model = AdapterModel(
enformer = enformer,
use_aa_embeds = True,
use_free_text_context = True,
free_text_embed_method = 'mean_pool',
aa_embed_encoder = 'protalbert'
).cuda()
# trainer class for fine-tuning
trainer = BigWigTrainer(
model,
loci_path = LOCI_PATH,
bigwig_folder_path = BIGWIG_PATH,
annot_file_path = ANNOT_FILE_PATH,
target_length = TARGET_LENGTH,
batch_size = BATCH_SIZE,
validate_every = VALIDATE_EVERY,
grad_clip_norm = GRAD_CLIP_MAX_NORM,
grad_accum_every = GRAD_ACCUM_STEPS,
factor_fasta_folder = TFACTOR_FOLDER,
fasta_file = FASTA_FILE_PATH,
train_chromosome_ids = TRAIN_CHROMOSOMES,
valid_chromosome_ids = VALID_CHROMOSOMES,
held_out_targets = HELD_OUT_TARGET
)
# do gradient steps in a while loop
while True:
_ = trainer()
Resources
If you are low on GPU memory, you can save some by making sure the protein and contextual embeddings are computed on the CPU
CONTEXT_EMBED_USE_CPU=1 PROTEIN_EMBED_USE_CPU=1 python train.py
Data
Transcription factor dataset
from tf_bind_transformer.data import FactorProteinDataset
ds = FactorProteinDataset(
folder = 'path/to/tfactor/fastas'
)
# single factor
ds['ETV1'] # <seq>
# multi-complexes
ds['PAX3-FOXO1'] # (<seq1>, <seq2>)
Preprocessing (wip)
Get a copy of the hg38 blacklist bed file from Calico
$ gsutil cp gs://basenji_barnyard/hg38.blacklist.rep.bed ./
Use bedtools to filter out the blacklisted (repetitive) regions of the genome
$ bedtools intersect -v -a ./remap2022_all_macs2_hg38_v1_0.bed -b ./hg38.blacklist.rep.bed > remap2022_all_filtered.bed
Caching
During training, protein sequences and contextual strings are cached to the ~/.cache.tf.bind.transformer directory. To verify that caching is working, run your training script with VERBOSE=1
ex.
$ VERBOSE=1 python train.py
You can also force the cache to be cleared
$ CLEAR_CACHE=1 python train.py
Todo
- ESM and AF2 embedding fetching integrations
- HF transformers integration for conditioning on free text
- allow for fine-tuning layernorms of Enformer easily
- add caching for external embeddings
- figure out a way for external models (ESM, transformers) to be omitted from state dictionary on saving (use singletons)
- take care of caching genetic sequences when enformer is frozen
- offer a fully transformer variant with cross-attention with shared attention matrix and FiLM conditioning with contextual embed
- also offer using pooled genetic / protein sequence concatted with context -> project -> squeeze excitation type conditioning
- use checkpointing when fine-tuning enformer
- take care of prepping dataframe with proper chromosome and training / validation split
- use basenji blacklist bed file for filtering out rows in remap
- filter remap dataframe based on tfactor fasta folder
- filter remap dataframe with hg38 blacklist
- handle targets with modifications from remap with all peaks (underscore in name)
- grad clipping
- add a safe initialization whereby rows of dataframe with targets not found in the tfactor fasta folder will be filtered out
- add accuracy metric to fine tune script
- master trainer class that handles both training / validation splitting, efficient instantiation of dataframe, filtering etc
- write a simple trainer class that takes care of the training loop
- create faster protein and context embedding derivation by optionally moving model to gpu and back to cpu when done
- use ProtTrans for longer context proteins, look into AF2
- make protalbert usable with one flag
- log auxiliary losses appropriately (read value)
- write fine-tuning script for finetuning on merged genomic track(s) from remap
- support for custom transformers other than enformer
- warmup in training loop
- mixed precision
- use wandb for experiment tracking
- cleanup tech debt in data and protein_utils
- explore protein model fine-tuning of layernorm
- auto-auroc calc
- k-fold cross validation
- output attention intermediates (or convolution output for hypertransformer), for interpreting binding site
- use prefect.io to manage downloading of tfactors fastas, remap scoped negative peaks, blacklist filtering etc
Appreciation
This work was generously sponsored by Jeff Hsu so that it could be done completely open source.
Citations
@article {Avsec2021.04.07.438649,
author = {Avsec, {\v Z}iga and Agarwal, Vikram and Visentin, Daniel and Ledsam, Joseph R. and Grabska-Barwinska, Agnieszka and Taylor, Kyle R. and Assael, Yannis and Jumper, John and Kohli, Pushmeet and Kelley, David R.},
title = {Effective gene expression prediction from sequence by integrating long-range interactions},
elocation-id = {2021.04.07.438649},
year = {2021},
doi = {10.1101/2021.04.07.438649},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649},
eprint = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649.full.pdf},
journal = {bioRxiv}
}
@misc{yao2021filip,
title = {FILIP: Fine-grained Interactive Language-Image Pre-Training},
author = {Lewei Yao and Runhui Huang and Lu Hou and Guansong Lu and Minzhe Niu and Hang Xu and Xiaodan Liang and Zhenguo Li and Xin Jiang and Chunjing Xu},
year = {2021},
eprint = {2111.07783},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@misc{tay2020hypergrid,
title = {HyperGrid: Efficient Multi-Task Transformers with Grid-wise Decomposable Hyper Projections},
author = {Yi Tay and Zhe Zhao and Dara Bahri and Donald Metzler and Da-Cheng Juan},
year = {2020},
eprint = {2007.05891},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}
@misc{lowe2021logavgexp,
title = {LogAvgExp Provides a Principled and Performant Global Pooling Operator},
author = {Scott C. Lowe and Thomas Trappenberg and Sageev Oore},
year = {2021},
eprint = {2111.01742},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@article{10.1093/nar/gkab996,
author = {Hammal, Fayrouz and de Langen, Pierre and Bergon, Aurélie and Lopez, Fabrice and Ballester, Benoit},
title = "{ReMap 2022: a database of Human, Mouse, Drosophila and Arabidopsis regulatory regions from an integrative analysis of DNA-binding sequencing experiments}",
journal = {Nucleic Acids Research},
issn = {0305-1048},
doi = {10.1093/nar/gkab996},
url = {https://doi.org/10.1093/nar/gkab996},
eprint = {https://academic.oup.com/nar/article-pdf/50/D1/D316/42058627/gkab996.pdf},
}