Lucidrains 系列项目源码解析(九十七)
.\lucidrains\tf-bind-transformer\scripts\download_experiments.py
import json
import tqdm
import requests
# NCBI taxonomy ids for the two species the ReMap API queries below support
NCBI_TAX_ID = dict(
    human = 9606,
    mouse = 10090
)

# species whose experiments are downloaded when this script runs
SPECIES = 'human'

# base URL of the ReMap REST API
API_URL = 'https://remap.univ-amu.fr/api/v1/'
def get_json(url, params = dict()):
    """GET `url` with the given query params and return the decoded JSON body."""
    response = requests.get(url, params = params, headers = dict(Accept = 'application/json'))
    return response.json()
def get_experiments(species):
    """Fetch the full experiment listing for `species` ('human' or 'mouse') from ReMap."""
    assert species in NCBI_TAX_ID
    tax_id = NCBI_TAX_ID[species]
    return get_json(f'{API_URL}list/experiments/taxid={tax_id}')
def get_experiment(experiment_id, species):
    """
    Fetch detailed info for one ReMap experiment by accession id.

    Fix: builds the URL from the module-level API_URL constant, consistent
    with get_experiments (the original inlined a second, http-only copy of
    the same base URL).
    """
    assert species in NCBI_TAX_ID
    taxid = NCBI_TAX_ID[species]
    experiment = get_json(f'{API_URL}info/byExperiment/experiment={experiment_id}&taxid={taxid}')
    return experiment
# download the experiment list, then enrich every entry with its detailed info
experiments = get_experiments(SPECIES)

for experiment in tqdm.tqdm(experiments['experiments']):
    experiment_details = get_experiment(experiment['accession'], SPECIES)
    experiment['details'] = experiment_details

# persist everything as one pretty-printed json file
# NOTE(review): assumes a ./data directory already exists - confirm
with open('data/experiments.json', 'w+') as f:
    contents = json.dumps(experiments, indent = 4, sort_keys = True)
    f.write(contents)

print('success')
.\lucidrains\tf-bind-transformer\scripts\fetch_factor_fastas.py
import requests
from pathlib import Path
import click
import polars as pl
from tqdm import tqdm
from tf_bind_transformer.gene_utils import parse_gene_name
from tf_bind_transformer.data import read_bed
# base URL for the (legacy) UniProt REST endpoints used below
UNIPROT_URL = 'http://www.uniprot.org'

# default ReMap CRM bed file per species, keyed by upper-cased species name
DEFAULT_REMAP_PATH = dict(
    HUMAN = './remap2022_crm_macs2_hg38_v1_0.bed',
    MOUSE = './remap2022_crm_macs2_mm10_v1_0.bed',
)

# gene names that the UniProt mapping service cannot resolve; hard-coded ids
GENE_NAME_TO_ID_OVERRIDE = {
    'SS18-SSX': ['Q8IZH1'],
    'TFIIIC': ['A6ZV34']
}
def uniprot_mapping(fromtype, totype, identifier):
    """Query the UniProt id-mapping service; returns the raw tab-separated response text."""
    payload = {
        'from': fromtype,
        'to': totype,
        'format': 'tab',
        'query': identifier,
    }
    return requests.get(f'{UNIPROT_URL}/mapping', params = payload).text
@click.command()
@click.option('--species', help = 'Species', default = 'human', type = click.Choice(['human', 'mouse']))
@click.option('--remap-bed-path', help = 'Path to species specific remap file')
@click.option('--fasta-folder', help = 'Path to factor fastas', default = './tfactor.fastas')
def fetch_factors(
    species,
    remap_bed_path,
    fasta_folder
):
    """Download one fasta per transcription factor named in a ReMap bed file."""
    species = species.upper()

    if remap_bed_path is None:
        remap_bed_path = DEFAULT_REMAP_PATH[species]

    remap_bed_path = Path(remap_bed_path)
    assert remap_bed_path.exists(), f'remap file does not exist at {str(remap_bed_path)}'

    df = read_bed(remap_bed_path)

    # column 4 holds comma-separated target names; collect the unique set
    genes = set([target for targets in df[:, 3] for target in targets.split(',')])

    print(f'{len(genes)} factors found')

    # NOTE(review): this resumability scan looks at the current directory,
    # not fasta_folder - confirm that is intended
    fasta_files = [str(path) for path in Path('./').glob('*.fasta')]
    processed_genes = set([*map(lambda t: str(t).split('.')[0], fasta_files)])

    results_folder = Path(fasta_folder)
    results_folder.mkdir(exist_ok = True, parents = True)

    for unparsed_gene_name in tqdm(genes):
        for gene_name in parse_gene_name(unparsed_gene_name):
            if gene_name in processed_genes:
                continue

            # resolve gene name -> uniprot id(s), manual override table first
            if gene_name not in GENE_NAME_TO_ID_OVERRIDE:
                uniprot_resp = uniprot_mapping('GENENAME', 'ID', gene_name)

                # keep entries of the requested species, take the id column
                entries = list(filter(lambda t: f'_{species}' in t, uniprot_resp.split('\n')))
                entries = list(map(lambda t: t.split('\t')[1], entries))
            else:
                entries = GENE_NAME_TO_ID_OVERRIDE[gene_name]

            if len(entries) == 0:
                print(f'no entries found for {gene_name}')
                continue

            # fetch and save one fasta per resolved uniprot entry
            for entry in entries:
                response = requests.get(f'{UNIPROT_URL}/uniprot/{entry}.fasta')

                if response.status_code != 200:
                    print(f'<{response.status_code}> error fetching fasta file from gene {gene_name} {entry}')
                    continue

                fasta_path = str(results_folder / f'{gene_name}.{entry}.fasta')

                with open(fasta_path, 'w') as f:
                    f.write(response.text)

            print(f'gene {gene_name} written')
# click entrypoint
if __name__ == '__main__':
    fetch_factors()
.\lucidrains\tf-bind-transformer\scripts\negative_peak_to_bool_npy.py
import polars as pl
import numpy as np
from pathlib import Path
import sys
# usage: python negative_peak_to_bool_npy.py <negative-peak-bed> <num-remap-rows>
#
# converts a bed file of negative peak row ids into a boolean numpy mask
# over all NUMROWS rows of the corresponding remap bed file

NEGATIVE_PEAK_PATH = sys.argv[1]
NUMROWS = int(sys.argv[2])

# column holding the 1-based row id of the negative peak in the remap file
ID_COLUMN = 'column_6'

df = pl.read_csv(NEGATIVE_PEAK_PATH, sep = '\t', has_headers = False)
np_array = df.get_column(ID_COLUMN).to_numpy()

# start all-False, then flag every row referenced by a negative peak
to_save = np.full((NUMROWS,), False)
to_save[np_array - 1] = True  # ids are 1-based

p = Path(NEGATIVE_PEAK_PATH)
filename = f'{p.stem}.bool'
np.save(filename, to_save)

# fix: the original printed a literal '(unknown) saved' - the f-string had
# no placeholder; report the actual output filename
print(f'{filename} saved')
.\lucidrains\tf-bind-transformer\scripts\remap_to_separate_exp_target_cell_beds.py
import polars as pl
from pathlib import Path
from tf_bind_transformer.data import read_bed, save_bed
def generate_separate_exp_target_cell_beds(
    remap_file,
    *,
    output_folder = './negative-peaks-per-target',
    exp_target_cell_type_col = 'column_4'
):
    """
    Split a remap bed file into one bed file per unique
    experiment-target-cell-type value found in the given column.
    """
    out_dir = Path(output_folder)
    out_dir.mkdir(exist_ok = True, parents = True)

    df = read_bed(remap_file)
    unique_values = df.get_column(exp_target_cell_type_col).unique().to_list()

    for value in unique_values:
        subset = df.filter(pl.col(exp_target_cell_type_col) == value)
        save_bed(subset, str(out_dir / f'{value}.bed'))

print('success')
.\lucidrains\tf-bind-transformer\setup.py
from setuptools import setup, find_packages
# packaging configuration for the tf-bind-transformer library
setup(
    name = 'tf-bind-transformer',
    packages = find_packages(exclude=[]),
    version = '0.0.118',
    license='MIT',
    description = 'Transformer for Transcription Factor Binding',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/tf-bind-transformer',
    long_description_content_type = 'text/markdown',
    keywords = [
        'artificial intelligence',
        'deep learning',
        'attention mechanism',
        'transformers',
        'transcription factors',
        'gene expression'
    ],
    # runtime dependencies (see the modules in tf_bind_transformer/)
    install_requires=[
        'bidirectional-cross-attention',
        'biopython',
        'click',
        'einops>=0.3',
        'enformer-pytorch>=0.5',
        'fair-esm',
        'logavgexp-pytorch',
        'polars',
        'python-dotenv',
        'sentencepiece',
        'torch>=1.6',
        'transformers>=4.0',
        'tqdm'
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\tf-bind-transformer\tf_bind_transformer\attention.py
import torch
from torch import nn
from einops import rearrange
from torch import einsum
def exists(val):
    """Return True when `val` holds a value (i.e. is not None)."""
    return not (val is None)
def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if val is None:
        return d
    return val
def FeedForward(dim, mult = 4, dropout = 0.):
    """Pre-norm MLP: LayerNorm -> Linear(dim*mult) -> GELU -> Dropout -> Linear(dim)."""
    hidden_dim = dim * mult
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, dim)
    ]
    return nn.Sequential(*layers)
class SelfAttention(nn.Module):
    """Multi-head self-attention with pre-LayerNorm and an optional key padding mask."""

    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = heads * dim_head

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x,
        mask = None,
    ):
        normed = self.norm(x)

        # single projection, then split into query / key / value
        qkv = self.to_qkv(normed).chunk(3, dim = -1)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in qkv)

        # scaled dot-product similarity
        sim = einsum('b h i d, b h j d -> b h i j', q * self.scale, k)

        if mask is not None:
            # exclude padded key positions from the softmax
            key_mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~key_mask, -torch.finfo(sim.dtype).max)

        attn = self.dropout(sim.softmax(dim = -1))

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
class SelfAttentionBlock(nn.Module):
    """Residual transformer block: self-attention then feedforward, each with a skip connection."""

    def __init__(
        self,
        *,
        dim,
        dropout = 0.,
        ff_mult = 4,
        **kwargs
    ):
        super().__init__()
        self.attn = SelfAttention(dim = dim, dropout = dropout, **kwargs)
        self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout)

    def forward(self, x, mask = None):
        x = x + self.attn(x, mask = mask)
        x = x + self.ff(x)
        return x
class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries from `x`, keys/values from `context`."""

    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        context_dim = None,
        dropout = 0.
    ):
        super().__init__()
        if context_dim is None:
            context_dim = dim

        inner_dim = heads * dim_head

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.context_norm = nn.LayerNorm(context_dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x,
        context,
        mask = None,
        context_mask = None
    ):
        # note: `mask` (over x) is accepted but unused, matching the
        # original interface - only the context can be masked here
        normed = self.norm(x)
        normed_context = self.context_norm(context)

        q = self.to_q(normed)
        k, v = self.to_kv(normed_context).chunk(2, dim = -1)

        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in (q, k, v))

        sim = einsum('b h i d, b h j d -> b h i j', q * self.scale, k)

        if context_mask is not None:
            # exclude padded context positions from the softmax
            key_mask = rearrange(context_mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~key_mask, -torch.finfo(sim.dtype).max)

        attn = self.dropout(sim.softmax(dim = -1))

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
class JointCrossAttentionBlock(nn.Module):
    """
    Bidirectional cross-attention block: attends x <-> context jointly, then
    applies a separate residual feedforward to each stream.

    NOTE(review): BidirectionalCrossAttention is not imported in this
    module's visible imports - presumably provided by the
    `bidirectional-cross-attention` package declared in setup.py; confirm.
    """
    def __init__(
        self,
        *,
        dim,
        context_dim = None,
        ff_mult = 4,
        dropout = 0.,
        **kwargs
    ):
        super().__init__()
        context_dim = default(context_dim, dim)

        self.attn = BidirectionalCrossAttention(dim = dim, context_dim = context_dim, dropout = dropout, prenorm = True, **kwargs)

        self.ff = FeedForward(dim, mult = ff_mult, dropout = dropout)
        self.context_ff = FeedForward(context_dim, mult = ff_mult, dropout = dropout)

    def forward(
        self,
        x,
        context,
        mask = None,
        context_mask = None
    ):
        # joint attention returns an update for each stream
        attn_out, context_attn_out = self.attn(x, context, mask = mask, context_mask = context_mask)

        # residual connections for both streams
        x = x + attn_out
        context = context + context_attn_out

        x = self.ff(x) + x
        context = self.context_ff(context) + context

        return x, context
.\lucidrains\tf-bind-transformer\tf_bind_transformer\cache_utils.py
import os
from shutil import rmtree
import torch
import hashlib
from functools import wraps
from pathlib import Path
def exists(val):
    """Check that `val` is not None."""
    if val is None:
        return False
    return True
# on-disk cache location, overridable through the TF_BIND_CACHE_PATH env var
CACHE_PATH = Path(os.getenv('TF_BIND_CACHE_PATH', os.path.expanduser('~/.cache.tf.bind.transformer')))
CACHE_PATH.mkdir(exist_ok=True, parents=True)

# env-var toggles: any non-empty value enables them
CLEAR_CACHE = exists(os.getenv('CLEAR_CACHE', None))
VERBOSE = exists(os.getenv('VERBOSE', None))
def log(s):
    """Print `s` only when the VERBOSE env toggle is set."""
    if VERBOSE:
        print(s)
def md5_hash_fn(s):
    """Hex md5 digest of the utf-8 encoding of `s` (cache key, not security)."""
    return hashlib.md5(s.encode('utf-8')).hexdigest()
# records which global ids have already run, shared across decorated functions
GLOBAL_RUN_RECORDS = dict()

def run_once(global_id=None):
    """
    Decorator factory: the wrapped function executes once; its result is
    memoized and returned on all later calls.

    With a `global_id`, the "has run" flag is shared across every function
    decorated with that id (via GLOBAL_RUN_RECORDS); otherwise the flag is
    local to the decorated function.

    Fix: the original assigned a fresh local `has_ran = True` instead of
    updating `has_ran_local`, so locally-scoped functions were never marked
    as run and re-executed on every call.
    """
    def outer(fn):
        has_ran_local = False
        output = None

        @wraps(fn)
        def inner(*args, **kwargs):
            nonlocal has_ran_local
            nonlocal output

            has_ran = GLOBAL_RUN_RECORDS.get(global_id, False) if global_id is not None else has_ran_local

            if has_ran:
                return output

            output = fn(*args, **kwargs)

            if global_id is not None:
                GLOBAL_RUN_RECORDS[global_id] = True
            else:
                has_ran_local = True

            return output
        return inner
    return outer
def cache_fn(
    fn,
    path='',
    hash_fn=md5_hash_fn,
    clear=False or CLEAR_CACHE,
    should_cache=True
):
    """
    Wrap `fn` so its torch-serializable outputs are cached on disk under
    CACHE_PATH / path, keyed by a hash of the first argument (or of the
    explicit `__cache_key` keyword when given).

    When `clear` is truthy (or the CLEAR_CACHE env toggle is set), the
    cache folder is wiped once per `path` before first use.
    """
    if not should_cache:
        return fn

    (CACHE_PATH / path).mkdir(parents=True, exist_ok=True)

    # executed at most once per `path` (run_once keyed globally on the path)
    @run_once(path)
    def clear_cache_folder_():
        cache_path = rmtree(str(CACHE_PATH / path))
        (CACHE_PATH / path).mkdir(parents=True, exist_ok=True)

    @wraps(fn)
    def inner(t, *args, __cache_key=None, **kwargs):
        if clear:
            clear_cache_folder_()

        # cache key defaults to the first positional argument (assumed str)
        cache_str = __cache_key if exists(__cache_key) else t
        key = hash_fn(cache_str)

        entry_path = CACHE_PATH / path / f'{key}.pt'

        if entry_path.exists():
            log(f'cache hit: fetching {t} from {str(entry_path)}')
            return torch.load(str(entry_path))

        out = fn(t, *args, **kwargs)

        log(f'saving: {t} to {str(entry_path)}')
        torch.save(out, str(entry_path))
        return out
    return inner
.\lucidrains\tf-bind-transformer\tf_bind_transformer\context_utils.py
import torch
import os
import logging
from transformers import AutoTokenizer, AutoModelForMaskedLM, logging
from tf_bind_transformer.cache_utils import cache_fn, run_once
logging.set_verbosity_error()
def exists(val):
    """Whether `val` is something other than None."""
    return None is not val
def map_values(fn, dictionary):
    """Apply `fn` to every value of `dictionary`, keeping keys unchanged."""
    out = dict()
    for key, value in dictionary.items():
        out[key] = fn(value)
    return out
# any non-empty value forces context embeddings to be computed on cpu
CONTEXT_EMBED_USE_CPU = os.getenv('CONTEXT_EMBED_USE_CPU', None) is not None

if CONTEXT_EMBED_USE_CPU:
    print('calculating context embed only on cpu')

# available text encoders: embedding dim and huggingface checkpoint path
MODELS = dict(
    pubmed = dict(
        dim = 768,
        path = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract',
    )
)

# lazily-initialized singleton model + tokenizer (see init_transformer)
GLOBAL_VARIABLES = dict(model = None, tokenizer = None)
def get_contextual_dim(model_name):
    """Embedding dimension of the named context model."""
    assert model_name in MODELS
    config = MODELS[model_name]
    return config['dim']
# build the tokenizer and masked-LM once per process (run_once keyed globally)
@run_once('init_transformer')
def init_transformer(model_name):
    path = MODELS[model_name]['path']
    GLOBAL_VARIABLES['tokenizer'] = AutoTokenizer.from_pretrained(path)

    model = AutoModelForMaskedLM.from_pretrained(path)

    # move to gpu unless the cpu-only toggle is set
    if not CONTEXT_EMBED_USE_CPU:
        model = model.cuda()

    GLOBAL_VARIABLES['model'] = model
@torch.no_grad()
def tokenize_text(
    text,
    max_length = 256,
    model_name = 'pubmed',
    hidden_state_index = -1,
    return_cls_token = True
):
    """
    Embed a single string with the chosen masked-LM.

    Returns the CLS-position vector, or the mean over all token positions,
    of the selected hidden-state layer.
    """
    init_transformer(model_name)

    model = GLOBAL_VARIABLES['model']
    tokenizer = GLOBAL_VARIABLES['tokenizer']

    encoding = tokenizer.batch_encode_plus(
        [text],
        add_special_tokens = True,
        padding = True,
        truncation = True,
        max_length = max_length,
        return_attention_mask = True,
        return_tensors = 'pt'
    )

    if not CONTEXT_EMBED_USE_CPU:
        encoding = map_values(lambda t: t.cuda(), encoding)

    model.eval()
    with torch.no_grad():
        outputs = model(**encoding, output_hidden_states = True)

    # [0] drops the singleton batch dimension
    hidden_state = outputs.hidden_states[hidden_state_index][0]

    if return_cls_token:
        # first token position is CLS
        return hidden_state[0]

    return hidden_state.mean(dim = 0)
def get_text_repr(
    texts,
    *,
    device,
    max_length = 256,
    model_name = 'pubmed',
    hidden_state_index = -1,
    return_cls_token = True,
):
    """
    Embed one or more context strings, with per-string disk caching, and
    return the stacked representations moved to `device`.
    """
    assert model_name in MODELS, f'{model_name} not found in available text transformers to use'

    if isinstance(texts, str):
        texts = [texts]

    # cache embeddings on disk, namespaced per model
    get_context_repr_fn = cache_fn(tokenize_text, path = f'contexts/{model_name}')

    representations = [get_context_repr_fn(text, max_length = max_length, model_name = model_name, hidden_state_index = hidden_state_index, return_cls_token = return_cls_token) for text in texts]

    return torch.stack(representations).to(device)
.\lucidrains\tf-bind-transformer\tf_bind_transformer\data.py
from Bio import SeqIO
from random import choice, randrange
from pathlib import Path
import functools
import polars as pl
from collections import defaultdict
import os
import json
import shutil
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tf_bind_transformer.gene_utils import parse_gene_name
from enformer_pytorch import FastaInterval
from pyfaidx import Fasta
import pybedtools
def exists(val):
    """True unless `val` is None."""
    return not (val is None)
def default(val, d):
    """Fallback helper: `val` unless it is None, in which case `d`."""
    return d if val is None else val
def find_first_index(cond, arr):
    """Index of the first element satisfying `cond`, or -1 when none does."""
    matches = (ind for ind, el in enumerate(arr) if cond(el))
    return next(matches, -1)
def cast_list(val = None):
    """Normalize to a list: None -> [], tuples/lists returned unchanged, scalars wrapped."""
    if val is None:
        return []
    if isinstance(val, (tuple, list)):
        return val
    return [val]
def read_bed(path):
    """Read a tab-separated, headerless bed file into a polars DataFrame."""
    # NOTE(review): `has_headers` is the old polars keyword (newer versions
    # use `has_header`, as save_bed below does) - confirm pinned polars version
    return pl.read_csv(path, sep = '\t', has_headers = False)
def save_bed(df, path):
    """Write a polars DataFrame back out as a headerless, tab-separated bed file."""
    df.to_csv(path, sep = '\t', has_header = False)
def parse_exp_target_cell(exp_target_cell):
    """Split '<experiment>.<target>.<cell type>'; the cell type itself may contain dots."""
    experiment, target, *rest = exp_target_cell.split('.')
    return experiment, target, '.'.join(rest)
def fetch_experiments_index(path):
    """
    Load the experiments json (as produced by the download script) and build
    a mapping from dataset name to its peaks_NR count. Returns an empty dict
    when no path is given; entries without details or datasets are skipped.
    """
    if path is None:
        return dict()

    exp_path = Path(path)
    assert exp_path.exists(), 'path to experiments json must exist'

    root_json = json.loads(exp_path.read_text())

    index = dict()
    for experiment in root_json['experiments']:
        if 'details' not in experiment:
            continue

        details = experiment['details']
        if 'datasets' not in details:
            continue

        for dataset in details['datasets']:
            index[dataset['dataset_name']] = dataset['peaks_NR']

    return index
class FactorProteinDatasetByUniprotID(Dataset):
    """
    Map-style dataset keyed by uniprot id: returns the protein sequence
    parsed from the matching `<gene>.<uniprotid>.fasta` file in `folder`,
    or None when the id is unknown.
    """

    def __init__(
        self,
        folder,
        species_priority = ['human', 'mouse']
    ):
        super().__init__()
        paths = [*Path(folder).glob('*.fasta')]
        assert len(paths) > 0, f'no fasta files found at {folder}'

        self.paths = paths

        # filenames look like <gene>.<uniprotid>[...].fasta
        self.index_by_id = dict()
        for fasta_path in paths:
            _, uniprotid, *_ = fasta_path.stem.split('.')
            self.index_by_id[uniprotid] = fasta_path

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, uid):
        fasta_path = self.index_by_id.get(uid, None)

        if fasta_path is None:
            return None

        record = SeqIO.read(fasta_path, 'fasta')
        return str(record.seq)
class FactorProteinDataset(Dataset):
    """
    Map-style dataset keyed by (possibly compound) gene name: returns the
    protein sequence(s) of the transcription factor, sampled from the fasta
    files in `folder`.

    When fastas for multiple species exist for a gene, only those of the
    highest-priority species present are kept (priority order given by
    `species_priority`; species inferred from a trailing `_<species>` in the
    filename stem, 'unknown' when absent).

    Fix: the original source had a garbled, syntactically invalid
    constructor (duplicated `def __init__(` with no closing signature);
    reconstructed here with the parameters its call sites use.
    """

    def __init__(
        self,
        folder,
        species_priority = ['human', 'mouse', 'unknown'],
        return_tuple_only = False
    ):
        super().__init__()
        fasta_paths = [*Path(folder).glob('*.fasta')]
        assert len(fasta_paths) > 0, f'no fasta files found at {folder}'

        self.paths = fasta_paths

        index_by_gene = defaultdict(list)
        # when a single gene is looked up, still return a tuple if set
        self.return_tuple_only = return_tuple_only

        # filenames look like <gene>.<uniprotid>[...].fasta
        for path in fasta_paths:
            gene, uniprotid, *_ = path.stem.split('.')
            index_by_gene[gene].append(path)

        # species is encoded as a trailing _<species> in the stem, if at all
        get_species_from_path = lambda p: p.stem.split('_')[-1].lower() if '_' in p.stem else 'unknown'

        filtered_index_by_gene = defaultdict(list)

        for gene, gene_paths in index_by_gene.items():
            # count fastas per species, in priority order
            species_count = list(map(lambda specie: len(list(filter(lambda p: get_species_from_path(p) == specie, gene_paths))), species_priority))
            species_ind_non_zero = find_first_index(lambda t: t > 0, species_count)

            if species_ind_non_zero == -1:
                continue

            # keep only fastas of the highest-priority species present
            species = species_priority[species_ind_non_zero]
            filtered_index_by_gene[gene] = list(filter(lambda p: get_species_from_path(p) == species, gene_paths))

        self.index_by_gene = filtered_index_by_gene

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, unparsed_gene_name):
        index = self.index_by_gene
        genes = parse_gene_name(unparsed_gene_name)
        seqs = []

        for gene in genes:
            entry = index[gene]

            if len(entry) == 0:
                print(f'no entries for {gene}')
                continue

            # choose randomly among the remaining fastas for this gene
            path = choice(entry) if isinstance(entry, list) else entry

            fasta = SeqIO.read(path, 'fasta')
            seqs.append(str(fasta.seq))

        seqs = tuple(seqs)

        if len(seqs) == 1 and not self.return_tuple_only:
            return seqs[0]

        return seqs
def get_chr_names(ids):
    """Map chromosome ids (e.g. 1 or 'X') to UCSC-style names ('chr1', 'chrX')."""
    return {f'chr{ident}' for ident in ids}
# autosomes 1-22 plus X (no Y or mitochondrial)
CHR_IDS = set([*range(1, 23), 'X'])
CHR_NAMES = get_chr_names(CHR_IDS)
def remap_df_add_experiment_target_cell(df, col = 'column_4'):
    """
    Derive 'experiment', 'target' and 'cell_type' columns from the combined
    '<experiment>.<target>.<cell type>' column and insert them at index 3.
    """
    df = df.clone()

    # leading token up to the first dot
    exp_id = df.select([pl.col(col).str.extract(r"^([\w\-]+)\.*")])
    exp_id = exp_id.rename({col: 'experiment'}).to_series(0)
    df.insert_at_idx(3, exp_id)

    # middle token between the first and second dots
    targets = df.select([pl.col(col).str.extract(r"[\w\-]+\.([\w\-]+)\.[\w\-]+")])
    targets = targets.rename({col: 'target'}).to_series(0)
    df.insert_at_idx(3, targets)

    # trailing token after the last dot
    # NOTE(review): parse_exp_target_cell treats everything after the second
    # dot as the cell type, while this regex only captures the final
    # dot-free token - confirm the two agree for your data
    cell_type = df.select([pl.col(col).str.extract(r"^.*\.([\w\-]+)$")])
    cell_type = cell_type.rename({col: 'cell_type'}).to_series(0)
    df.insert_at_idx(3, cell_type)

    return df
def pl_isin(col, arr):
    """Build a polars expression equivalent to `col IN arr` (an OR-chain of equalities)."""
    exprs = [pl.col(col) == el for el in arr]
    return functools.reduce(lambda a, b: a | b, exprs)
def pl_notin(col, arr):
    """Build a polars expression equivalent to `col NOT IN arr` (an AND-chain of inequalities)."""
    exprs = [pl.col(col) != el for el in arr]
    return functools.reduce(lambda a, b: a & b, exprs)
def filter_by_col_isin(df, col, arr, chunk_size = 25):
    """
    polars appears to have a bug where it freezes once the OR-condition
    chain built by pl_isin exceeds ~25 clauses, so filter in chunks of
    `chunk_size` and concatenate the results.
    """
    dataframes = []
    for i in range(0, len(arr), chunk_size):
        sub_arr = arr[i:(i + chunk_size)]
        filtered_df = df.filter(pl_isin(col, sub_arr))
        dataframes.append(filtered_df)
    return pl.concat(dataframes)
def filter_bed_file_by_(bed_file_1, bed_file_2, output_file):
    """Write to `output_file` the intervals of bed_file_1 that do NOT overlap bed_file_2 (bedtools intersect -v)."""
    bed_file_1_bedtool = pybedtools.BedTool(bed_file_1)
    bed_file_2_bedtool = pybedtools.BedTool(bed_file_2)
    bed_file_1_bedtool_intersect_bed_file_2_bedtool = bed_file_1_bedtool.intersect(bed_file_2_bedtool, v = True)
    bed_file_1_bedtool_intersect_bed_file_2_bedtool.saveas(output_file)
def filter_df_by_tfactor_fastas(df, folder):
    """
    Drop rows whose 'target' cannot be fully resolved to fasta files in
    `folder` (every sub-gene of a compound target must have a fasta).
    """
    files = [*Path(folder).glob('**/*.fasta')]
    # fasta stems look like <gene>.<uniprotid>; keep the gene part
    present_target_names = set([f.stem.split('.')[0] for f in files])
    all_df_targets = df.get_column('target').unique().to_list()

    all_df_targets_with_parsed_name = [(target, parse_gene_name(target)) for target in all_df_targets]
    # a target is unknown when any of its parsed sub-genes lacks a fasta
    unknown_targets = [target for target, parsed_target_name in all_df_targets_with_parsed_name for parsed_target_name_sub_el in parsed_target_name if parsed_target_name_sub_el not in present_target_names]

    if len(unknown_targets) > 0:
        df = df.filter(pl_notin('target', unknown_targets))
    return df
def generate_random_ranges_from_fasta(
    fasta_file,
    *,
    output_filename = 'random-ranges.bed',
    context_length,
    filter_bed_files = [],
    num_entries_per_key = 10,
    keys = None,  # NOTE(review): currently unused - confirm whether per-key selection was intended
):
    """
    Sample `num_entries_per_key` random fixed-length windows from each
    chromosome of the fasta, optionally removing any that overlap the given
    bed files, and write the result to ./`output_filename`.
    """
    fasta = Fasta(fasta_file)
    tmp_file = f'/tmp/{output_filename}'

    with open(tmp_file, 'w') as f:
        for chr_name in sorted(CHR_NAMES):
            print(f'generating ranges for {chr_name}')

            if chr_name not in fasta:
                print(f'{chr_name} not found in fasta file')
                continue

            chromosome = fasta[chr_name]
            chromosome_length = len(chromosome)

            # random window starts, leaving room for a full window
            start = np.random.randint(0, chromosome_length - context_length, (num_entries_per_key,))
            end = start + context_length
            start_and_end = np.stack((start, end), axis = -1)

            for row in start_and_end.tolist():
                start, end = row
                f.write('\t'.join((chr_name, str(start), str(end))) + '\n')

    # NOTE(review): reads and writes the same tmp_file in one call - relies
    # on pybedtools materializing results before saveas; confirm
    for file in filter_bed_files:
        filter_bed_file_by_(tmp_file, file, tmp_file)

    shutil.move(tmp_file, f'./{output_filename}')
    print('success')
class ContextDataset(Dataset):
    """
    Maps a cell-type / biotype string to the context string fed to the text
    encoder. When biotype metadata is enabled, extra metadata columns looked
    up from a csv/tsv are joined onto the biotype with
    `biotypes_metadata_delimiter`.

    Fixes vs. the garbled original: reconstructed the syntactically invalid,
    duplicated constructor, and gave `__len__` its missing `self` parameter.
    """

    def __init__(
        self,
        biotypes_metadata_path = None,
        include_biotypes_metadata_in_context = False,
        include_biotypes_metadata_columns = [],
        biotypes_metadata_delimiter = ' | '
    ):
        self.include_biotypes_metadata_in_context = include_biotypes_metadata_in_context
        self.include_biotypes_metadata_columns = include_biotypes_metadata_columns
        self.biotypes_metadata_delimiter = biotypes_metadata_delimiter

        if include_biotypes_metadata_in_context:
            assert len(self.include_biotypes_metadata_columns) > 0, 'must have more than one biotype metadata column to include'
            assert exists(biotypes_metadata_path), 'biotypes metadata path must be supplied if to be included in context string'

            p = Path(biotypes_metadata_path)

            # infer the delimiter from the file extension
            if p.suffix == '.csv':
                sep = ','
            elif p.suffix == '.tsv':
                sep = '\t'
            else:
                raise ValueError(f'invalid suffix {p.suffix} for biotypes')

            self.df = pl.read_csv(str(p), sep = sep)

    def __len__(self):
        # -1 signals "not applicable" when metadata is disabled
        return len(self.df) if self.include_biotypes_metadata_in_context else -1

    def __getitem__(self, biotype):
        if not self.include_biotypes_metadata_in_context:
            return biotype

        col_indices = list(map(self.df.columns.index, self.include_biotypes_metadata_columns))
        filtered = self.df.filter(pl.col('biotype') == biotype)

        if len(filtered) == 0:
            print(f'no rows found for {biotype} in biotype metadata file')
            return biotype

        row = filtered.row(0)
        columns = list(map(lambda t: row[t], col_indices))

        context_string = self.biotypes_metadata_delimiter.join([biotype, *columns])
        return context_string
class RemapAllPeakDataset(Dataset):
    """
    Positive-example dataset over remap peaks. Each item yields:
        seq         - genomic sequence for the peak interval
        aa_seq      - protein sequence of the target factor
        context_str - cell-type context string (optionally with biotype metadata)
        peaks_nr    - experiment's peak count (0. when unknown)
        read_value  - the peak's read value
        label       - always 1. (positive)

    With `balance_sampling_by_target`, indexing is by target rather than by
    row, and a random peak of that target is drawn on each access.
    """
    def __init__(
        self,
        *,
        factor_fasta_folder,
        bed_file = None,
        remap_df = None,
        filter_chromosome_ids = None,
        exclude_targets = None,
        include_targets = None,
        exclude_cell_types = None,
        include_cell_types = None,
        remap_df_frac = 1.,
        experiments_json_path = None,
        include_biotypes_metadata_in_context = False,
        biotypes_metadata_path = None,
        include_biotypes_metadata_columns = [],
        biotypes_metadata_delimiter = ' | ',
        balance_sampling_by_target = False,
        **kwargs
    ):
        super().__init__()
        assert exists(remap_df) ^ exists(bed_file), 'either remap bed file or remap dataframe must be passed in'

        if not exists(remap_df):
            remap_df = read_bed(bed_file)

        # optionally subsample the dataframe
        if remap_df_frac < 1:
            remap_df = remap_df.sample(frac = remap_df_frac)

        # restrict to the requested chromosomes
        dataset_chr_ids = CHR_IDS

        if exists(filter_chromosome_ids):
            dataset_chr_ids = dataset_chr_ids.intersection(set(filter_chromosome_ids))

        remap_df = remap_df.filter(pl_isin('column_1', get_chr_names(dataset_chr_ids)))

        # drop rows whose target has no fasta on disk
        remap_df = filter_df_by_tfactor_fastas(remap_df, factor_fasta_folder)

        self.factor_ds = FactorProteinDataset(factor_fasta_folder)

        # include / exclude filters on targets and cell types
        include_targets = cast_list(include_targets)
        exclude_targets = cast_list(exclude_targets)

        if include_targets:
            remap_df = remap_df.filter(pl_isin('target', include_targets))

        if exclude_targets:
            remap_df = remap_df.filter(pl_notin('target', exclude_targets))

        include_cell_types = cast_list(include_cell_types)
        exclude_cell_types = cast_list(exclude_cell_types)

        if include_cell_types:
            remap_df = remap_df.filter(pl_isin('cell_type', include_cell_types))

        if exclude_cell_types:
            remap_df = remap_df.filter(pl_notin('cell_type', exclude_cell_types))

        assert len(remap_df) > 0, 'dataset is empty by filter criteria'

        self.df = remap_df
        self.fasta = FastaInterval(**kwargs)

        self.experiments_index = fetch_experiments_index(experiments_json_path)

        # pre-split the dataframe per target for balanced sampling
        self.balance_sampling_by_target = balance_sampling_by_target

        if self.balance_sampling_by_target:
            self.df_indexed_by_target = []

            for target in self.df.get_column('target').unique().to_list():
                df_by_target = self.df.filter(pl.col('target') == target)
                self.df_indexed_by_target.append(df_by_target)

        self.context_ds = ContextDataset(
            include_biotypes_metadata_in_context = include_biotypes_metadata_in_context,
            biotypes_metadata_path = biotypes_metadata_path,
            include_biotypes_metadata_columns = include_biotypes_metadata_columns,
            biotypes_metadata_delimiter = biotypes_metadata_delimiter
        )

    def __len__(self):
        # one entry per target when balancing, else one per peak row
        if self.balance_sampling_by_target:
            return len(self.df_indexed_by_target)
        else:
            return len(self.df)

    def __getitem__(self, ind):
        # when balancing by target, pick a random peak of the indexed target
        if self.balance_sampling_by_target:
            filtered_df = self.df_indexed_by_target[ind]
            rand_ind = randrange(0, len(filtered_df))
            sample = filtered_df.row(rand_ind)
        else:
            sample = self.df.row(ind)

        chr_name, begin, end, _, _, _, experiment_target_cell_type, reading, *_ = sample
        experiment, target, cell_type = parse_exp_target_cell(experiment_target_cell_type)

        seq = self.fasta(chr_name, begin, end)
        aa_seq = self.factor_ds[target]
        context_str = self.context_ds[cell_type]

        read_value = torch.Tensor([reading])

        peaks_nr = self.experiments_index.get(experiment_target_cell_type, 0.)
        peaks_nr = torch.Tensor([peaks_nr])

        label = torch.Tensor([1.])

        return seq, aa_seq, context_str, peaks_nr, read_value, label
def filter_exp_target_cell(
    arr,
    *,
    exclude_targets = None,
    include_targets = None,
    exclude_cell_types = None,
    include_cell_types = None,
):
    """
    Keep only the '<experiment>.<target>.<cell type>' strings passing the
    include / exclude filters. Empty include lists mean "no restriction".
    """
    kept = []

    for el in arr:
        _, target, *rest = el.split('.')
        cell_type = '.'.join(rest)

        has_include_targets = include_targets is not None and len(include_targets) > 0
        if has_include_targets and target not in include_targets:
            continue

        if exclude_targets is not None and target in exclude_targets:
            continue

        has_include_cells = include_cell_types is not None and len(include_cell_types) > 0
        if has_include_cells and cell_type not in include_cell_types:
            continue

        if exclude_cell_types is not None and cell_type in exclude_cell_types:
            continue

        kept.append(el)

    return kept
class ScopedNegativePeakDataset(Dataset):
    """
    Scoped-negative dataset: for each experiment-target-cell triple there is
    a boolean numpy mask over all remap rows marking its negative peaks. An
    item picks a (pseudo) random masked row, restricted to the chosen
    chromosomes, and returns it with a 0. label.

    Fix: the experiment-target-cell name was recovered from the filename
    with str.rstrip(exts), which strips a *character set* rather than the
    suffix, corrupting names whose cell type ends in any of the letters of
    '.bedbolnpy'; the exact suffix is now removed instead.
    """
    def __init__(
        self,
        *,
        fasta_file,
        factor_fasta_folder,
        numpy_folder_with_scoped_negatives,
        exts = '.bed.bool.npy',
        remap_bed_file = None,
        remap_df = None,
        filter_chromosome_ids = None,
        experiments_json_path = None,
        exclude_targets = None,
        include_targets = None,
        exclude_cell_types = None,
        include_cell_types = None,
        include_biotypes_metadata_in_context = False,
        biotypes_metadata_path = None,
        include_biotypes_metadata_columns = [],
        biotypes_metadata_delimiter = ' | ',
        balance_sampling_by_target = False,
        **kwargs
    ):
        super().__init__()
        assert exists(remap_df) ^ exists(remap_bed_file), 'either remap bed file or remap dataframe must be passed in'

        if not exists(remap_df):
            remap_df = read_bed(remap_bed_file)

        # boolean mask of remap rows on the requested chromosomes
        dataset_chr_ids = CHR_IDS

        if exists(filter_chromosome_ids):
            dataset_chr_ids = dataset_chr_ids.intersection(set(filter_chromosome_ids))

        filter_map_df = remap_df.with_column(pl.when(pl_isin('column_1', get_chr_names(dataset_chr_ids))).then(True).otherwise(False).alias('mask'))
        mask = filter_map_df.get_column('mask').to_numpy()

        num_scoped_negs = mask.sum()

        print(f'{num_scoped_negs} scoped negative rows found for training')

        assert num_scoped_negs > 0, 'all remap rows filtered out for scoped negative peak dataset'

        self.df = remap_df
        self.chromosome_mask = mask

        # one boolean npy per experiment-target-cell triple
        npys_paths = [*Path(numpy_folder_with_scoped_negatives).glob('**/*.npy')]

        def strip_exts(name):
            # remove the exact suffix (str.rstrip would strip characters)
            return name[:-len(exts)] if name.endswith(exts) else name

        exp_target_cell_negatives = [(strip_exts(path.name), path) for path in npys_paths]

        exp_target_cells = [el[0] for el in exp_target_cell_negatives]

        exp_target_cells = filter_exp_target_cell(
            exp_target_cells,
            include_targets = include_targets,
            exclude_targets = exclude_targets,
            include_cell_types = include_cell_types,
            exclude_cell_types = exclude_cell_types
        )

        filtered_exp_target_cell_negatives = list(filter(lambda el: el[0] in exp_target_cells, exp_target_cell_negatives))

        self.exp_target_cell_negatives = filtered_exp_target_cell_negatives
        assert len(self.exp_target_cell_negatives) > 0, 'no experiment-target-cell scoped negatives to select from after filtering'

        # group triples per target for balanced sampling
        self.balance_sampling_by_target = balance_sampling_by_target

        if balance_sampling_by_target:
            self.exp_target_cell_by_target = defaultdict(list)

            for exp_target_cell, filepath in self.exp_target_cell_negatives:
                _, target, *_ = parse_exp_target_cell(exp_target_cell)
                self.exp_target_cell_by_target[target].append((exp_target_cell, filepath))

        self.factor_ds = FactorProteinDataset(factor_fasta_folder)
        self.fasta = FastaInterval(fasta_file = fasta_file, **kwargs)

        self.experiments_index = fetch_experiments_index(experiments_json_path)

        self.context_ds = ContextDataset(
            include_biotypes_metadata_in_context = include_biotypes_metadata_in_context,
            biotypes_metadata_path = biotypes_metadata_path,
            include_biotypes_metadata_columns = include_biotypes_metadata_columns,
            biotypes_metadata_delimiter = biotypes_metadata_delimiter
        )

    def __len__(self):
        if self.balance_sampling_by_target:
            return len(self.exp_target_cell_by_target)
        else:
            return len(self.exp_target_cell_negatives)

    def __getitem__(self, idx):
        if self.balance_sampling_by_target:
            negatives = list(self.exp_target_cell_by_target.values())[idx]
            sample = choice(negatives)
        else:
            sample = self.exp_target_cell_negatives[idx]

        exp_target_cell, bool_numpy_path = sample
        experiment, target, cell_type = parse_exp_target_cell(exp_target_cell)

        # add small noise so argmax selects a pseudo-random True position,
        # then zero out rows on excluded chromosomes
        np_arr = np.load(str(bool_numpy_path))
        np_arr_noised = np_arr.astype(np.float32) + np.random.uniform(low=-1e-1, high=1e-1, size=np_arr.shape[0])
        np_arr_noised *= self.chromosome_mask.astype(np.float32)

        random_neg_peak_index = np_arr_noised.argmax()

        chr_name, begin, end, *_ = self.df.row(random_neg_peak_index)
        seq = self.fasta(chr_name, begin, end)

        aa_seq = self.factor_ds[target]
        context_str = self.context_ds[cell_type]

        peaks_nr = self.experiments_index.get(exp_target_cell, 0.)
        peaks_nr = torch.Tensor([peaks_nr])

        read_value = torch.Tensor([0.])
        label = torch.Tensor([0.])

        return seq, aa_seq, context_str, peaks_nr, read_value, label
class NegativePeakDataset(Dataset):
def __init__(
self,
*,
factor_fasta_folder,
negative_bed_file = None,
remap_bed_file = None,
remap_df = None,
negative_df = None,
filter_chromosome_ids = None,
exclude_targets = None,
include_targets = None,
exclude_cell_types = None,
include_cell_types = None,
exp_target_cell_column = 'column_4',
experiments_json_path = None,
include_biotypes_metadata_in_context = False,
biotypes_metadata_path = None,
include_biotypes_metadata_columns = [],
biotypes_metadata_delimiter = ' | ',
balance_sampling_by_target = False,
**kwargs
):
super().__init__()
assert exists(remap_df) ^ exists(remap_bed_file), 'either remap bed file or remap dataframe must be passed in'
assert exists(negative_df) ^ exists(negative_bed_file), 'either negative bed file or negative dataframe must be passed in'
if not exists(remap_df):
remap_df = read_bed(remap_bed_file)
neg_df = negative_df
if not exists(negative_df):
neg_df = read_bed(negative_bed_file)
remap_df = filter_df_by_tfactor_fastas(remap_df, factor_fasta_folder)
dataset_chr_ids = CHR_IDS
if exists(filter_chromosome_ids):
dataset_chr_ids = dataset_chr_ids.intersection(set(filter_chromosome_ids))
neg_df = neg_df.filter(pl_isin('column_1', get_chr_names(dataset_chr_ids)))
assert len(neg_df) > 0, 'dataset is empty by filter criteria'
self.neg_df = neg_df
exp_target_cells = remap_df.get_column(exp_target_cell_column).unique().to_list()
self.filtered_exp_target_cells = filter_exp_target_cell(
exp_target_cells,
include_targets = include_targets,
exclude_targets = exclude_targets,
include_cell_types = include_cell_types,
exclude_cell_types = exclude_cell_types
)
assert len(self.filtered_exp_target_cells), 'no experiment-target-cell left for hard negative set'
self.balance_sampling_by_target = balance_sampling_by_target
if balance_sampling_by_target:
self.exp_target_cell_by_target = defaultdict(list)
for exp_target_cell in self.filtered_exp_target_cells:
_, target, *_ = parse_exp_target_cell(exp_target_cell)
self.exp_target_cell_by_target[target].append(exp_target_cell)
self.factor_ds = FactorProteinDataset(factor_fasta_folder)
self.fasta = FastaInterval(**kwargs)
self.experiments_index = fetch_experiments_index(experiments_json_path)
self.context_ds = ContextDataset(
include_biotypes_metadata_in_context = include_biotypes_metadata_in_context,
biotypes_metadata_path = biotypes_metadata_path,
include_biotypes_metadata_columns = include_biotypes_metadata_columns,
biotypes_metadata_delimiter = biotypes_metadata_delimiter
)
def __len__(self):
    """Number of negative (non-binding) genomic intervals in this dataset."""
    return self.neg_df.shape[0]
def __getitem__(self, ind):
    """Return one hard-negative sample.

    Returns (seq, aa_seq, context_str, peaks_nr, read_value, label) where
    read_value and label are always zero tensors, since every interval in
    `neg_df` is a negative.
    """
    chr_name, begin, end = self.neg_df.row(ind)

    # pick an experiment-target-cell triple, optionally balanced across targets
    if self.balance_sampling_by_target:
        # first choose a target uniformly, then a triple within that target
        group_ind = randrange(0, len(self.exp_target_cell_by_target))
        target_groups = list(self.exp_target_cell_by_target.values())
        sampled_triple = choice(target_groups[group_ind])
    else:
        sampled_triple = choice(self.filtered_exp_target_cells)

    experiment, target, cell_type = parse_exp_target_cell(sampled_triple)

    seq = self.fasta(chr_name, begin, end)
    aa_seq = self.factor_ds[target]
    context_str = self.context_ds[cell_type]

    # number of peaks for the sampled experiment, 0 if not indexed
    peaks_nr = torch.Tensor([self.experiments_index.get(sampled_triple, 0.)])

    # negatives always carry zero read value and zero binding label
    read_value = torch.Tensor([0.])
    label = torch.Tensor([0.])

    return seq, aa_seq, context_str, peaks_nr, read_value, label
def collate_fn(data):
    """Collate samples into a batch: stack tensor fields, keep string fields
    as tuples, and concatenate the scalar labels into one 1-d tensor."""
    seqs, aa_seqs, contexts, peaks, reads, labels = zip(*data)

    batched_seqs = torch.stack(seqs)
    batched_peaks = torch.stack(peaks, dim = 0)
    batched_reads = torch.stack(reads, dim = 0)
    batched_labels = torch.cat(labels, dim = 0)

    return batched_seqs, tuple(aa_seqs), tuple(contexts), batched_peaks, batched_reads, batched_labels
def collate_dl_outputs(*dl_outputs):
    """Merge already-collated batches from multiple dataloaders into one batch.

    Tensor fields are concatenated along the batch dimension; non-tensor
    fields (tuples of strings) are flattened into a single tuple.
    """
    ret = []
    for entry in zip(*dl_outputs):
        if isinstance(entry[0], torch.Tensor):
            entry = torch.cat(entry, dim = 0)
        else:
            # materialize as a tuple (previously a one-shot generator) so the
            # field can be re-iterated and indexed just like the tensor fields
            entry = tuple(sub_el for el in entry for sub_el in el)
        ret.append(entry)
    return tuple(ret)
def cycle(loader):
    """Yield items from `loader` forever, restarting it on exhaustion."""
    while True:
        yield from loader
def get_dataloader(ds, cycle_iter=False, **kwargs):
    """Build a DataLoader over `ds` using the module's `collate_fn`.

    Drops the last incomplete batch only when the dataset is larger than one
    batch, so tiny datasets are never emptied out entirely.

    Returns an infinite iterator when `cycle_iter` is True, else a plain one.
    """
    dataset_len = len(ds)
    # fall back to DataLoader's own default batch size of 1 when the caller
    # omits it; the original compared dataset_len against None -> TypeError
    batch_size = kwargs.get('batch_size', 1)
    drop_last = dataset_len > batch_size
    dl = DataLoader(ds, collate_fn=collate_fn, drop_last=drop_last, **kwargs)
    wrapper = cycle if cycle_iter else iter
    return wrapper(dl)
.\lucidrains\tf-bind-transformer\tf_bind_transformer\data_bigwig.py
from pathlib import Path
import polars as pl
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tf_bind_transformer.data import FactorProteinDataset, ContextDataset, cast_list, filter_df_by_tfactor_fastas
from tf_bind_transformer.data import pl_isin, pl_notin, fetch_experiments_index, parse_exp_target_cell, read_bed, cycle, filter_by_col_isin
from tf_bind_transformer.data import CHR_IDS, CHR_NAMES, get_chr_names
from enformer_pytorch import FastaInterval
# pyBigWig is a hard dependency of this module; fail fast with an install hint
try:
    import pyBigWig
except ImportError:
    print('pyBigWig needs to be installed - conda install pyBigWig')
    # raise SystemExit directly with a nonzero status: the site-provided
    # exit() helper is not guaranteed outside interactive sessions, and it
    # would exit with status 0 on this error path
    raise SystemExit(1)
def exists(val):
    """Return True unless `val` is None."""
    return not (val is None)
def chip_atlas_add_experiment_target_cell(
    df,
    col_target = 'column_4',
    col_cell_type = 'column_5'
):
    """Return a copy of the ChIP-Atlas annotation frame with derived columns.

    Inserts an uppercased 'target' column and a 'cell_type' column, both at
    index 2 — since both use the same index, 'cell_type' ends up before
    'target' in the final frame.
    """
    df = df.clone()

    # uppercased transcription factor name, derived from the raw column
    target_series = df.select(col_target).to_series(0).str.to_uppercase().rename('target')
    df.insert_at_idx(2, target_series)

    # cell type, inserted at the same index so it lands ahead of 'target'
    cell_type_series = df.select(col_cell_type).rename({col_cell_type: 'cell_type'}).to_series(0)
    df.insert_at_idx(2, cell_type_series)

    return df
class BigWigDataset(Dataset):
    """Dataset over (enformer locus x ChIP-Atlas experiment) pairs.

    Each item yields (genomic seq, TF amino-acid seq, cell-type context
    string, downsampled bigwig track label). Item index cycles fastest over
    loci and slowest over experiments.
    """

    def __init__(
        self,
        *,
        factor_fasta_folder,
        bigwig_folder,
        enformer_loci_path,
        fasta_file,
        annot_file = None,
        filter_chromosome_ids = None,
        exclude_targets = None,
        include_targets = None,
        exclude_cell_types = None,
        include_cell_types = None,
        df_frac = 1.,
        experiments_json_path = None,
        include_biotypes_metadata_in_context = False,
        biotypes_metadata_path = None,
        filter_sequences_by = None,
        include_biotypes_metadata_columns = [],
        biotypes_metadata_delimiter = ' | ',
        only_ref = ['mm10', 'hg38'],
        factor_species_priority = ['human', 'mouse'],
        downsample_factor = 128,
        target_length = 896,
        bigwig_reduction_type = 'sum',
        **kwargs
    ):  # fixed: a stray duplicated `def __init__(` sat here in place of the closing `):`
        super().__init__()
        assert exists(annot_file)

        # a missing bigwig folder marks the dataset invalid (length 0) rather than erroring
        if not exists(bigwig_folder):
            self.invalid = True
            self.ntargets = 0
            return

        bigwig_folder = Path(bigwig_folder)
        assert bigwig_folder.exists(), 'bigwig folder does not exist'

        bw_experiments = [p.stem for p in bigwig_folder.glob('*.bw')]
        assert len(bw_experiments) > 0, 'no bigwig files found in bigwig folder'

        loci = read_bed(enformer_loci_path)

        annot_df = pl.read_csv(annot_file, sep = "\t", has_headers = False, columns = list(map(lambda i: f'column_{i + 1}', range(17))))
        # keep only annotations for the requested reference genomes with a bigwig present
        annot_df = annot_df.filter(pl_isin('column_2', only_ref))
        annot_df = filter_by_col_isin(annot_df, 'column_1', bw_experiments)

        if df_frac < 1:
            annot_df = annot_df.sample(frac = df_frac)

        dataset_chr_ids = CHR_IDS
        if exists(filter_chromosome_ids):
            dataset_chr_ids = dataset_chr_ids.intersection(set(filter_chromosome_ids))

        loci = loci.filter(pl_isin('column_1', get_chr_names(dataset_chr_ids)))

        if exists(filter_sequences_by):
            col_name, col_val = filter_sequences_by
            loci = loci.filter(pl.col(col_name) == col_val)

        self.factor_ds = FactorProteinDataset(factor_fasta_folder, species_priority = factor_species_priority)

        exp_ids = set(annot_df.get_column('column_1').to_list())

        annot_df = chip_atlas_add_experiment_target_cell(annot_df)
        annot_df = filter_df_by_tfactor_fastas(annot_df, factor_fasta_folder)

        # report experiments dropped for lack of a transcription factor fasta
        filtered_exp_ids = set(annot_df.get_column('column_1').to_list())
        filtered_out_exp_ids = exp_ids - filtered_exp_ids
        print(f'{", ".join(only_ref)} - {len(filtered_out_exp_ids)} experiments filtered out by lack of transcription factor fastas', filtered_out_exp_ids)

        # inclusion / exclusion filters on targets and cell types
        include_targets = cast_list(include_targets)
        exclude_targets = cast_list(exclude_targets)

        if include_targets:
            annot_df = annot_df.filter(pl_isin('target', include_targets))

        if exclude_targets:
            annot_df = annot_df.filter(pl_notin('target', exclude_targets))

        include_cell_types = cast_list(include_cell_types)
        exclude_cell_types = cast_list(exclude_cell_types)

        if include_cell_types:
            annot_df = annot_df.filter(pl_isin('cell_type', include_cell_types))

        if exclude_cell_types:
            annot_df = annot_df.filter(pl_notin('cell_type', exclude_cell_types))

        self.fasta = FastaInterval(fasta_file = fasta_file, **kwargs)

        self.df = loci
        self.annot = annot_df
        self.ntargets = self.annot.shape[0]

        # hoisted out of __getitem__: these series are invariant per instance
        self.targets = self.annot.select('target').to_series(0)
        self.cell_types = self.annot.select('cell_type').to_series(0)

        # one open bigwig handle per annotation row, in row order
        self.bigwigs = [pyBigWig.open(str(bigwig_folder / f'{str(i)}.bw')) for i in self.annot.get_column("column_1")]

        self.downsample_factor = downsample_factor
        self.target_length = target_length
        self.bigwig_reduction_type = bigwig_reduction_type
        self.invalid = False

    def __len__(self):
        if self.invalid:
            return 0
        return len(self.df) * self.ntargets

    def __getitem__(self, ind):
        # locus index cycles fastest, experiment index slowest
        chr_name, begin, end, _ = self.df.row(ind % self.df.shape[0])
        ix_target = ind // self.df.shape[0]

        target = self.targets[ix_target]
        context_str = self.cell_types[ix_target]
        exp_bw = self.bigwigs[ix_target]

        aa_seq = self.factor_ds[target]
        seq = self.fasta(chr_name, begin, end)

        # downsample raw per-base bigwig values into fixed-width bins
        output = np.array(exp_bw.values(chr_name, begin, end))
        output = output.reshape((-1, self.downsample_factor))

        if self.bigwig_reduction_type == 'mean':
            om = np.nanmean(output, axis = 1)
        elif self.bigwig_reduction_type == 'sum':
            om = np.nansum(output, axis = 1)
        else:
            raise ValueError(f'unknown reduction type {self.bigwig_reduction_type}')

        output_length = output.shape[0]

        # fixed: was `assert f'...'`, a no-op since a non-empty string is truthy
        if output_length < self.target_length:
            raise ValueError(f'output length {output_length} cannot be less than target length {self.target_length}')

        # fixed: when no trimming is needed, `om[0:-0]` would yield an empty array
        trim = (output_length - self.target_length) // 2
        if trim > 0:
            om = om[trim:-trim]

        np.nan_to_num(om, copy = False)

        label = torch.Tensor(om)
        return seq, aa_seq, context_str, label
class BigWigTracksOnlyDataset(Dataset):
    """Dataset over enformer loci returning (seq, stacked bigwig tracks).

    Unlike BigWigDataset, every item carries ALL tracks for one locus, so the
    dataset length equals the number of loci (when any track exists).
    """

    def __init__(
        self,
        *,
        bigwig_folder,
        enformer_loci_path,
        fasta_file,
        ref,
        annot_file = None,
        filter_chromosome_ids = None,
        downsample_factor = 128,
        target_length = 896,
        bigwig_reduction_type = 'sum',
        filter_sequences_by = None,
        **kwargs
    ):
        super().__init__()
        assert exists(annot_file)

        # a missing bigwig folder marks the dataset invalid (length 0) rather than erroring
        if not exists(bigwig_folder):
            self.invalid = True
            self.ntargets = 0
            return

        bigwig_folder = Path(bigwig_folder)
        assert bigwig_folder.exists(), 'bigwig folder does not exist'

        bw_experiments = [p.stem for p in bigwig_folder.glob('*.bw')]
        assert len(bw_experiments) > 0, 'no bigwig files found in bigwig folder'

        loci = read_bed(enformer_loci_path)

        annot_df = pl.read_csv(annot_file, sep = "\t", has_headers = False, columns = list(map(lambda i: f'column_{i + 1}', range(17))))
        # restrict annotations to the single requested reference genome with bigwigs present
        annot_df = annot_df.filter(pl.col('column_2') == ref)
        annot_df = filter_by_col_isin(annot_df, 'column_1', bw_experiments)

        dataset_chr_ids = CHR_IDS
        if exists(filter_chromosome_ids):
            dataset_chr_ids = dataset_chr_ids.intersection(set(filter_chromosome_ids))

        loci = loci.filter(pl_isin('column_1', get_chr_names(dataset_chr_ids)))

        if exists(filter_sequences_by):
            col_name, col_val = filter_sequences_by
            loci = loci.filter(pl.col(col_name) == col_val)

        self.fasta = FastaInterval(fasta_file = fasta_file, **kwargs)

        self.df = loci
        self.annot = annot_df
        self.ntargets = self.annot.shape[0]

        # keep (experiment id, open handle) so error messages can name the file
        self.bigwigs = [(str(i), pyBigWig.open(str(bigwig_folder / f'{str(i)}.bw'))) for i in self.annot.get_column("column_1")]

        self.downsample_factor = downsample_factor
        self.target_length = target_length
        self.bigwig_reduction_type = bigwig_reduction_type
        self.invalid = False

    def __len__(self):
        if self.invalid:
            return 0
        # one item per locus; zero when no tracks are available
        return len(self.df) * int(self.ntargets > 0)

    def __getitem__(self, ind):
        chr_name, begin, end, _ = self.df.row(ind)
        seq = self.fasta(chr_name, begin, end)

        all_bw_values = []
        for bw_path, bw in self.bigwigs:
            try:
                bw_values = bw.values(chr_name, begin, end)
                all_bw_values.append(bw_values)
            except Exception as exc:
                # fixed: previously a bare `except:` that printed and called exit(),
                # killing the whole process (and dataloader workers); raise with
                # context instead so the caller can handle or report it
                raise ValueError(f'hitting invalid range for {bw_path} - ({chr_name}, {begin}, {end})') from exc

        # stack tracks last, then downsample per-base values into bins
        output = np.stack(all_bw_values, axis = -1)
        output = output.reshape((-1, self.downsample_factor, self.ntargets))

        if self.bigwig_reduction_type == 'mean':
            om = np.nanmean(output, axis = 1)
        elif self.bigwig_reduction_type == 'sum':
            om = np.nansum(output, axis = 1)
        else:
            raise ValueError(f'unknown reduction type {self.bigwig_reduction_type}')

        output_length = output.shape[0]

        # fixed: was `assert f'...'`, a no-op since a non-empty string is truthy
        if output_length < self.target_length:
            raise ValueError(f'output length {output_length} cannot be less than target length {self.target_length}')

        # fixed: trim of 0 would slice om[0:-0] into an empty array
        trim = (output_length - self.target_length) // 2
        if trim > 0:
            om = om[trim:-trim]

        np.nan_to_num(om, copy = False)

        label = torch.Tensor(om)
        return seq, label
def bigwig_collate_fn(data):
    """Batch (seq, aa_seq, context_str, label) samples from BigWigDataset."""
    seqs, aa_seqs, contexts, labels = zip(*data)
    return torch.stack(seqs), tuple(aa_seqs), tuple(contexts), torch.stack(labels)
def get_bigwig_dataloader(ds, cycle_iter = False, **kwargs):
    """Build a DataLoader over `ds` using `bigwig_collate_fn`.

    Drops the last incomplete batch only when the dataset is larger than one
    batch. Returns an infinite iterator when `cycle_iter` is True.
    """
    dataset_len = len(ds)
    # fall back to DataLoader's own default batch size of 1 when omitted;
    # the original compared dataset_len against None -> TypeError
    batch_size = kwargs.get('batch_size', 1)
    drop_last = dataset_len > batch_size
    dl = DataLoader(ds, collate_fn = bigwig_collate_fn, drop_last = drop_last, **kwargs)
    wrapper = cycle if cycle_iter else iter
    return wrapper(dl)
def get_bigwig_tracks_dataloader(ds, cycle_iter = False, **kwargs):
    """Build a DataLoader over `ds` with the default collate function.

    Drops the last incomplete batch only when the dataset is larger than one
    batch. Returns an infinite iterator when `cycle_iter` is True.
    """
    dataset_len = len(ds)
    # fall back to DataLoader's own default batch size of 1 when omitted;
    # the original compared dataset_len against None -> TypeError
    batch_size = kwargs.get('batch_size', 1)
    drop_last = dataset_len > batch_size
    dl = DataLoader(ds, drop_last = drop_last, **kwargs)
    wrapper = cycle if cycle_iter else iter
    return wrapper(dl)