Transformers Source Code Analysis (Part 47)
.\models\esm\openfold_utils\chunk_utils.py
import logging
import math
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import torch
from .tensor_utils import tensor_tree_map, tree_map
def _fetch_dims(tree: Union[dict, list, tuple, torch.Tensor]) -> List[Tuple[int, ...]]:
shapes = []
if isinstance(tree, dict):
for v in tree.values():
shapes.extend(_fetch_dims(v))
elif isinstance(tree, (list, tuple)):
for t in tree:
shapes.extend(_fetch_dims(t))
elif isinstance(tree, torch.Tensor):
shapes.append(tree.shape)
else:
raise ValueError("Not supported")
return shapes
@torch.jit.ignore
def _flat_idx_to_idx(flat_idx: int, dims: Tuple[int, ...]) -> Tuple[int, ...]:
idx = []
for d in reversed(dims):
idx.append(flat_idx % d)
flat_idx = flat_idx // d
return tuple(reversed(idx))
@torch.jit.ignore
def _get_minimal_slice_set(
start: Sequence[int],
end: Sequence[int],
dims: Sequence[int],
start_edges: Optional[Sequence[bool]] = None,
end_edges: Optional[Sequence[bool]] = None,
) -> List[Tuple[slice, ...]]:
"""
Produces an ordered sequence of tensor slices that, when used in sequence on a tensor with shape dims, yields
tensors that contain every leaf in the contiguous range [start, end]. Care is taken to yield a short sequence of
slices, and perhaps even the shortest possible (I'm pretty sure it's the latter).
end is INCLUSIVE.
"""
if start_edges is None:
start_edges = [s == 0 for s in start]
reduce_edge_list(start_edges)
if end_edges is None:
end_edges = [e == (d - 1) for e, d in zip(end, dims)]
reduce_edge_list(end_edges)
if len(start) == 0:
return [()]
elif len(start) == 1:
return [(slice(start[0], end[0] + 1),)]
slices: List[Tuple[slice, ...]] = []
path_list: List[slice] = []
for s, e in zip(start, end):
if s == e:
path_list.append(slice(s, s + 1))
else:
break
path: Tuple[slice, ...] = tuple(path_list)
divergence_idx = len(path)
if divergence_idx == len(dims):
return [path]
def upper() -> Tuple[Tuple[slice, ...], ...]:
assert start_edges is not None
assert end_edges is not None
sdi = start[divergence_idx]
return tuple(
path + (slice(sdi, sdi + 1),) + s
for s in _get_minimal_slice_set(
start[divergence_idx + 1 :],
[d - 1 for d in dims[divergence_idx + 1 :]],
dims[divergence_idx + 1 :],
start_edges=start_edges[divergence_idx + 1 :],
end_edges=[True for _ in end_edges[divergence_idx + 1 :]],
)
)
def lower() -> Tuple[Tuple[slice, ...], ...]:
assert start_edges is not None
assert end_edges is not None
edi = end[divergence_idx]
return tuple(
path + (slice(edi, edi + 1),) + s
for s in _get_minimal_slice_set(
[0 for _ in start[divergence_idx + 1 :]],
end[divergence_idx + 1 :],
dims[divergence_idx + 1 :],
start_edges=[True for _ in start_edges[divergence_idx + 1 :]],
end_edges=end_edges[divergence_idx + 1 :],
)
)
if start_edges[divergence_idx] and end_edges[divergence_idx]:
slices.append(path + (slice(start[divergence_idx], end[divergence_idx] + 1),))
elif start_edges[divergence_idx]:
slices.append(path + (slice(start[divergence_idx], end[divergence_idx]),))
slices.extend(lower())
elif end_edges[divergence_idx]:
slices.extend(upper())
slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),))
else:
slices.extend(upper())
middle_ground = end[divergence_idx] - start[divergence_idx]
if middle_ground > 1:
slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx]),))
slices.extend(lower())
return slices
@torch.jit.ignore
def _chunk_slice(t: torch.Tensor, flat_start: int, flat_end: int, no_batch_dims: int) -> torch.Tensor:
"""
Equivalent to
t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
but without the need for the initial reshape call, which can be memory-intensive in certain situations. The only
reshape operations in this function are performed on sub-tensors that scale with (flat_end - flat_start), the chunk
size.
"""
batch_dims = t.shape[:no_batch_dims]
start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))
slices = _get_minimal_slice_set(
start_idx,
end_idx,
batch_dims,
)
sliced_tensors = [t[s] for s in slices]
return torch.cat([s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors])
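The docstring above claims equivalence with `t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]`. Below is a minimal check of that claim; the import path assumes the file lives under `transformers.models.esm.openfold_utils.chunk_utils`, and `_chunk_slice` is a private helper, so the import is for illustration only.

```py
import torch

from transformers.models.esm.openfold_utils.chunk_utils import _chunk_slice

t = torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5)
no_batch_dims = 3  # treat the leading (2, 3, 4) dims as a flattened batch of 24 sub-batches

flat_start, flat_end = 5, 17
reference = t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
chunked = _chunk_slice(t, flat_start=flat_start, flat_end=flat_end, no_batch_dims=no_batch_dims)

assert torch.equal(reference, chunked)  # same values, without reshaping the full tensor first
```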
def chunk_layer(
layer: Callable,
inputs: Dict[str, Any],
chunk_size: int,
no_batch_dims: int,
low_mem: bool = False,
_out: Any = None,
_add_into_out: bool = False,
) -> Any:
"""
Implements the "chunking" procedure described in section 1.11.8.
Layer outputs and inputs are assumed to be simple "pytrees," consisting only of (arbitrarily nested) lists, tuples,
and dicts with torch.Tensor leaves.
Args:
layer:
The layer to be applied chunk-wise
inputs:
A (non-nested) dictionary of keyworded inputs. All leaves must be tensors and must share the same batch
dimensions.
chunk_size:
The number of sub-batches per chunk. If multiple batch dimensions are specified, a "sub-batch" is defined
as a single indexing of all batch dimensions simultaneously (s.t. the number of sub-batches is the product
of the batch dimensions).
no_batch_dims:
How many of the initial dimensions of each input tensor can be considered batch dimensions.
low_mem:
Avoids flattening potentially large input tensors. Unnecessary in most cases, and is ever so slightly
slower than the default setting.
Returns:
The reassembled output of the layer on the inputs.
"""
if not (len(inputs) > 0):
raise ValueError("Must provide at least one input")
initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
def _prep_inputs(t: torch.Tensor) -> torch.Tensor:
if not low_mem:
if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
t = t.reshape(-1, *t.shape[no_batch_dims:])
else:
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
return t
prepped_inputs: Dict[str, Any] = tensor_tree_map(_prep_inputs, inputs)
prepped_outputs = None
if _out is not None:
prepped_outputs = tensor_tree_map(lambda t: t.view([-1] + list(t.shape[no_batch_dims:])), _out)
flat_batch_dim = 1
for d in orig_batch_dims:
flat_batch_dim *= d
no_chunks = flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0)
def _select_chunk(t: torch.Tensor) -> torch.Tensor:
return t[i : i + chunk_size] if t.shape[0] != 1 else t
i = 0
out = prepped_outputs
for _ in range(no_chunks):
if not low_mem:
select_chunk = _select_chunk
else:
select_chunk = partial(
_chunk_slice,
flat_start=i,
flat_end=min(flat_batch_dim, i + chunk_size),
no_batch_dims=len(orig_batch_dims),
)
chunks: Dict[str, Any] = tensor_tree_map(select_chunk, prepped_inputs)
output_chunk = layer(**chunks)
if out is None:
out = tensor_tree_map(lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:]), output_chunk)
if isinstance(output_chunk, dict):
def assign(d1: dict, d2: dict) -> None:
for k, v in d1.items():
if isinstance(v, dict):
assign(v, d2[k])
else:
if _add_into_out:
v[i : i + chunk_size] += d2[k]
else:
v[i : i + chunk_size] = d2[k]
assign(out, output_chunk)
elif isinstance(output_chunk, tuple):
for x1, x2 in zip(out, output_chunk):
if _add_into_out:
x1[i : i + chunk_size] += x2
else:
x1[i : i + chunk_size] = x2
elif isinstance(output_chunk, torch.Tensor):
if _add_into_out:
out[i : i + chunk_size] += output_chunk
else:
out[i : i + chunk_size] = output_chunk
else:
raise ValueError("Not supported")
i += chunk_size
out = tensor_tree_map(lambda t: t.view(orig_batch_dims + t.shape[1:]), out)
return out
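A small usage sketch for `chunk_layer`: the toy layer and shapes below are made up for illustration, and the import path assumes the module layout shown above.

```py
import torch

from transformers.models.esm.openfold_utils.chunk_utils import chunk_layer

def toy_layer(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Any callable over keyword tensor inputs works; here a simple elementwise op.
    return x * 2.0 + bias

x = torch.randn(4, 8, 16)     # batch dims (4, 8), i.e. 32 sub-batches
bias = torch.randn(4, 8, 16)  # all inputs must share the same batch dims

out = chunk_layer(
    toy_layer,
    {"x": x, "bias": bias},
    chunk_size=6,     # sub-batches processed per chunk
    no_batch_dims=2,  # the first two dims are batch dims
)
assert torch.allclose(out, toy_layer(x, bias))  # chunked result matches the direct call
```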
class ChunkSizeTuner:
def __init__(
self,
max_chunk_size: int = 512,
):
self.max_chunk_size = max_chunk_size
self.cached_chunk_size: Optional[int] = None
self.cached_arg_data: Optional[tuple] = None
def _determine_favorable_chunk_size(self, fn: Callable, args: tuple, min_chunk_size: int) -> int:
logging.info("Tuning chunk size...")
if min_chunk_size >= self.max_chunk_size:
return min_chunk_size
candidates: List[int] = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)]
candidates = [c for c in candidates if c > min_chunk_size]
candidates = [min_chunk_size] + candidates
candidates[-1] += 4
def test_chunk_size(chunk_size: int) -> bool:
try:
with torch.no_grad():
fn(*args, chunk_size=chunk_size)
return True
except RuntimeError:
return False
min_viable_chunk_size_index = 0
i = len(candidates) - 1
while i > min_viable_chunk_size_index:
viable = test_chunk_size(candidates[i])
if not viable:
i = (min_viable_chunk_size_index + i) // 2
else:
min_viable_chunk_size_index = i
i = (i + len(candidates) - 1) // 2
return candidates[min_viable_chunk_size_index]
def _compare_arg_caches(self, ac1: Iterable, ac2: Iterable) -> bool:
consistent = True
for a1, a2 in zip(ac1, ac2):
assert type(ac1) == type(ac2)
if isinstance(ac1, (list, tuple)):
consistent &= self._compare_arg_caches(a1, a2)
elif isinstance(ac1, dict):
a1_items = [v for _, v in sorted(a1.items(), key=lambda x: x[0])]
a2_items = [v for _, v in sorted(a2.items(), key=lambda x: x[0])]
consistent &= self._compare_arg_caches(a1_items, a2_items)
else:
consistent &= a1 == a2
return consistent
def tune_chunk_size(
self,
representative_fn: Callable,
args: tuple,
min_chunk_size: int,
) -> int:
consistent = True
arg_data: tuple = tree_map(lambda a: a.shape if isinstance(a, torch.Tensor) else a, args, object)
if self.cached_arg_data is not None:
assert len(self.cached_arg_data) == len(arg_data)
consistent = self._compare_arg_caches(self.cached_arg_data, arg_data)
else:
consistent = False
if not consistent:
self.cached_chunk_size = self._determine_favorable_chunk_size(
representative_fn,
args,
min_chunk_size,
)
self.cached_arg_data = arg_data
assert self.cached_chunk_size is not None
return self.cached_chunk_size
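How `ChunkSizeTuner` might be driven in practice; `representative_fn` below is a hypothetical stand-in for a chunked forward pass (it only needs to accept a `chunk_size` keyword and raise `RuntimeError`, e.g. an OOM, when a size is not viable), and the import path is assumed as above.

```py
import torch

from transformers.models.esm.openfold_utils.chunk_utils import ChunkSizeTuner

def representative_fn(x: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # A real caller would run chunk_layer(...) here with the given chunk_size.
    return x + 1.0

tuner = ChunkSizeTuner(max_chunk_size=128)
x = torch.randn(64, 32)

chunk_size = tuner.tune_chunk_size(representative_fn, args=(x,), min_chunk_size=4)
# The result is cached: further calls with the same argument shapes skip the search.
assert tuner.tune_chunk_size(representative_fn, args=(x,), min_chunk_size=4) == chunk_size
```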
.\models\esm\openfold_utils\data_transforms.py
from typing import Dict
import numpy as np
import torch
from . import residue_constants as rc
from .tensor_utils import tensor_tree_map, tree_map
def make_atom14_masks(protein: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""构建更密集的原子位置掩码(14维而非37维)。"""
restype_atom14_to_atom37_list = []
restype_atom37_to_atom14_list = []
restype_atom14_mask_list = []
for rt in rc.restypes:
atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]
restype_atom14_to_atom37_list.append([(rc.atom_order[name] if name else 0) for name in atom_names])
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14_list.append(
[(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) for name in rc.atom_types]
)
restype_atom14_mask_list.append([(1.0 if name else 0.0) for name in atom_names])
restype_atom14_to_atom37_list.append([0] * 14)
restype_atom37_to_atom14_list.append([0] * 37)
restype_atom14_mask_list.append([0.0] * 14)
restype_atom14_to_atom37 = torch.tensor(
restype_atom14_to_atom37_list,
dtype=torch.int32,
device=protein["aatype"].device,
)
restype_atom37_to_atom14 = torch.tensor(
restype_atom37_to_atom14_list,
dtype=torch.int32,
device=protein["aatype"].device,
)
restype_atom14_mask = torch.tensor(
restype_atom14_mask_list,
dtype=torch.float32,
device=protein["aatype"].device,
)
protein_aatype = protein["aatype"].to(torch.long)
residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype]
residx_atom14_mask = restype_atom14_mask[protein_aatype]
protein["atom14_atom_exists"] = residx_atom14_mask
protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long()
residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype]
protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long()
restype_atom37_mask = torch.zeros([21, 37], dtype=torch.float32, device=protein["aatype"].device)
for restype, restype_letter in enumerate(rc.restypes):
restype_name = rc.restype_1to3[restype_letter]
atom_names = rc.residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = rc.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1
residx_atom37_mask = restype_atom37_mask[protein_aatype]
protein["atom37_atom_exists"] = residx_atom37_mask
return protein
def make_atom14_masks_np(batch: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:
batch = tree_map(lambda n: torch.tensor(n, device=batch["aatype"].device), batch, np.ndarray)
out = tensor_tree_map(lambda t: np.array(t), make_atom14_masks(batch))
return out
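A quick sketch of what `make_atom14_masks` produces for a three-residue toy input (the restype indices follow `rc.restypes`, so 0/7/4 are ALA/GLY/CYS); the import path is the internal one assumed throughout.

```py
import torch

from transformers.models.esm.openfold_utils.data_transforms import make_atom14_masks

protein = {"aatype": torch.tensor([0, 7, 4])}  # ALA, GLY, CYS
protein = make_atom14_masks(protein)

print(protein["atom14_atom_exists"].shape)       # torch.Size([3, 14])
print(protein["residx_atom14_to_atom37"].shape)  # torch.Size([3, 14])
print(protein["atom37_atom_exists"].shape)       # torch.Size([3, 37])
```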
.\models\esm\openfold_utils\feats.py
from typing import Dict, Tuple, overload
import torch
import torch.types
from torch import nn
from . import residue_constants as rc
from .rigid_utils import Rigid, Rotation
from .tensor_utils import batched_gather
@overload
def pseudo_beta_fn(aatype: torch.Tensor, all_atom_positions: torch.Tensor, all_atom_masks: None) -> torch.Tensor:
...
@overload
def pseudo_beta_fn(
aatype: torch.Tensor, all_atom_positions: torch.Tensor, all_atom_masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
...
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
is_gly = aatype == rc.restype_order["G"]
ca_idx = rc.atom_order["CA"]
cb_idx = rc.atom_order["CB"]
pseudo_beta = torch.where(
is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3),
all_atom_positions[..., ca_idx, :],
all_atom_positions[..., cb_idx, :],
)
if all_atom_masks is not None:
pseudo_beta_mask = torch.where(
is_gly,
all_atom_masks[..., ca_idx],
all_atom_masks[..., cb_idx],
)
return pseudo_beta, pseudo_beta_mask
else:
return pseudo_beta
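`pseudo_beta_fn` substitutes CA for CB on glycine (which has no CB atom) and leaves every other residue on CB. A minimal sketch with random coordinates, import paths assumed:

```py
import torch

from transformers.models.esm.openfold_utils import residue_constants as rc
from transformers.models.esm.openfold_utils.feats import pseudo_beta_fn

aatype = torch.tensor([rc.restype_order["G"], rc.restype_order["A"]])  # GLY, ALA
positions = torch.randn(2, rc.atom_type_num, 3)  # [N, 37, 3]

pseudo_beta = pseudo_beta_fn(aatype, positions, None)
assert torch.equal(pseudo_beta[0], positions[0, rc.atom_order["CA"]])  # GLY -> CA
assert torch.equal(pseudo_beta[1], positions[1, rc.atom_order["CB"]])  # ALA -> CB
```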
def atom14_to_atom37(atom14: torch.Tensor, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
atom37_data = batched_gather(
atom14,
batch["residx_atom37_to_atom14"],
dim=-2,
no_batch_dims=len(atom14.shape[:-2]),
)
atom37_data = atom37_data * batch["atom37_atom_exists"][..., None]
return atom37_data
def build_template_angle_feat(template_feats: Dict[str, torch.Tensor]) -> torch.Tensor:
template_aatype = template_feats["template_aatype"]
torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"]
alt_torsion_angles_sin_cos = template_feats["template_alt_torsion_angles_sin_cos"]
torsion_angles_mask = template_feats["template_torsion_angles_mask"]
template_angle_feat = torch.cat(
[
nn.functional.one_hot(template_aatype, 22),
torsion_angles_sin_cos.reshape(*torsion_angles_sin_cos.shape[:-2], 14),
alt_torsion_angles_sin_cos.reshape(*alt_torsion_angles_sin_cos.shape[:-2], 14),
torsion_angles_mask,
],
dim=-1,
)
return template_angle_feat
def build_template_pair_feat(
batch: Dict[str, torch.Tensor],
min_bin: torch.types.Number,
max_bin: torch.types.Number,
no_bins: int,
use_unit_vector: bool = False,
eps: float = 1e-20,
inf: float = 1e8,
def torsion_angles_to_frames(
r: Rigid,
alpha: torch.Tensor,
aatype: torch.Tensor,
rrgdf: torch.Tensor,
) -> Rigid:
default_4x4 = rrgdf[aatype, ...]
default_r = r.from_tensor_4x4(default_4x4)
bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
bb_rot[..., 1] = 1
alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2)
all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
all_rots[..., 0, 0] = 1
all_rots[..., 1, 1] = alpha[..., 1]
all_rots[..., 1, 2] = -alpha[..., 0]
all_rots[..., 2, 1:] = alpha
all_frames = default_r.compose(Rigid(Rotation(rot_mats=all_rots), None))
chi2_frame_to_frame = all_frames[..., 5]
chi3_frame_to_frame = all_frames[..., 6]
chi4_frame_to_frame = all_frames[..., 7]
chi1_frame_to_bb = all_frames[..., 4]
chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
all_frames_to_bb = Rigid.cat(
[
all_frames[..., :5],
chi2_frame_to_bb.unsqueeze(-1),
chi3_frame_to_bb.unsqueeze(-1),
chi4_frame_to_bb.unsqueeze(-1),
],
dim=-1,
)
all_frames_to_global = r[..., None].compose(all_frames_to_bb)
return all_frames_to_global
def frames_and_literature_positions_to_atom14_pos(
    r: Rigid,
    aatype: torch.Tensor,
    default_frames: torch.Tensor,
    group_idx: torch.Tensor,
    atom_mask: torch.Tensor,
    lit_positions: torch.Tensor,
) -> torch.Tensor:
    group_mask = group_idx[aatype, ...]
group_mask_one_hot: torch.LongTensor = nn.functional.one_hot(
group_mask,
num_classes=default_frames.shape[-3],
)
t_atoms_to_global = r[..., None, :] * group_mask_one_hot
t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1))
atom_mask = atom_mask[aatype, ...].unsqueeze(-1)
lit_positions = lit_positions[aatype, ...]
pred_positions = t_atoms_to_global.apply(lit_positions)
pred_positions = pred_positions * atom_mask
return pred_positions
.\models\esm\openfold_utils\loss.py
from typing import Dict, Optional, Tuple
import torch
def _calculate_bin_centers(boundaries: torch.Tensor) -> torch.Tensor:
step = boundaries[1] - boundaries[0]
bin_centers = boundaries + step / 2
bin_centers = torch.cat([bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0)
return bin_centers
def _calculate_expected_aligned_error(
alignment_confidence_breaks: torch.Tensor,
aligned_distance_error_probs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
return (
torch.sum(aligned_distance_error_probs * bin_centers, dim=-1),
bin_centers[-1],
)
def compute_predicted_aligned_error(
logits: torch.Tensor,
max_bin: int = 31,
no_bins: int = 64,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""从对数输出计算对齐信心度度量。
Args:
logits: [*, num_res, num_res, num_bins] PredictedAlignedErrorHead 输出的对数。
max_bin: 最大 bin 值
no_bins: bin 的数量
Returns:
aligned_confidence_probs: [*, num_res, num_res, num_bins] 每个残基对的预测对齐误差概率。
predicted_aligned_error: [*, num_res, num_res] 每对残基的预期对齐距离误差。
max_predicted_aligned_error: [*] 可能的最大预测误差。
"""
boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device)
aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1)
predicted_aligned_error, max_predicted_aligned_error = _calculate_expected_aligned_error(
alignment_confidence_breaks=boundaries,
aligned_distance_error_probs=aligned_confidence_probs,
)
return {
"aligned_confidence_probs": aligned_confidence_probs,
"predicted_aligned_error": predicted_aligned_error,
"max_predicted_aligned_error": max_predicted_aligned_error,
}
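A shape-level sketch of `compute_predicted_aligned_error` on random logits (not real model output); the import path is the internal one assumed above.

```py
import torch

from transformers.models.esm.openfold_utils.loss import compute_predicted_aligned_error

num_res, no_bins = 10, 64
logits = torch.randn(num_res, num_res, no_bins)

out = compute_predicted_aligned_error(logits, max_bin=31, no_bins=no_bins)
print(out["aligned_confidence_probs"].shape)     # torch.Size([10, 10, 64])
print(out["predicted_aligned_error"].shape)      # torch.Size([10, 10])
print(out["max_predicted_aligned_error"].shape)  # torch.Size([])
```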
def compute_tm(
logits: torch.Tensor,
residue_weights: Optional[torch.Tensor] = None,
max_bin: int = 31,
no_bins: int = 64,
eps: float = 1e-8,
**kwargs,
) -> torch.Tensor:
if residue_weights is None:
residue_weights = logits.new_ones(logits.shape[-2])
boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device)
bin_centers = _calculate_bin_centers(boundaries)
torch.sum(residue_weights)
n = logits.shape[-2]
clipped_n = max(n, 19)
d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8
probs = torch.nn.functional.softmax(logits, dim=-1)
tm_per_bin = 1.0 / (1 + (bin_centers**2) / (d0**2))
predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)
normed_residue_mask = residue_weights / (eps + residue_weights.sum())
per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)
weighted = per_alignment * residue_weights
argmax = (weighted == torch.max(weighted)).nonzero()[0]
return per_alignment[tuple(argmax)]
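`compute_tm` turns the same kind of binned error logits into a predicted TM-score: `d0` is the standard TM-score normalization `d0(N) = 1.24 * (N - 15)**(1/3) - 1.8`, with `N` clipped to at least 19 so `d0` stays positive for short chains, and the final value is the best per-residue alignment score. A shape-level sketch on random logits:

```py
import torch

from transformers.models.esm.openfold_utils.loss import compute_tm

logits = torch.randn(10, 10, 64)
ptm = compute_tm(logits, max_bin=31, no_bins=64)
print(float(ptm))  # a scalar in (0, 1)
```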
.\models\esm\openfold_utils\protein.py
import dataclasses
import re
import string
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple
import numpy as np
from . import residue_constants
FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any]
PICO_TO_ANGSTROM = 0.01
@dataclasses.dataclass(frozen=True)
class Protein:
"""蛋白质结构的表示类。"""
atom_positions: np.ndarray
aatype: np.ndarray
atom_mask: np.ndarray
residue_index: np.ndarray
b_factors: np.ndarray
chain_index: Optional[np.ndarray] = None
remark: Optional[str] = None
parents: Optional[Sequence[str]] = None
parents_chain_index: Optional[Sequence[int]] = None
def from_proteinnet_string(proteinnet_str: str) -> Protein:
tag_re = r"(\[[A-Z]+\]\n)"
tags: List[str] = [tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0]
groups: Iterator[Tuple[str, List[str]]] = zip(tags[0::2], [l.split("\n") for l in tags[1::2]])
atoms: List[str] = ["N", "CA", "C"]
aatype = None
atom_positions = None
atom_mask = None
for g in groups:
if "[PRIMARY]" == g[0]:
seq = g[1][0].strip()
for i in range(len(seq)):
if seq[i] not in residue_constants.restypes:
seq[i] = "X"
aatype = np.array(
[residue_constants.restype_order.get(res_symbol, residue_constants.restype_num) for res_symbol in seq]
)
elif "[TERTIARY]" == g[0]:
tertiary: List[List[float]] = []
for axis in range(3):
tertiary.append(list(map(float, g[1][axis].split())))
tertiary_np = np.array(tertiary)
atom_positions = np.zeros((len(tertiary[0]) // 3, residue_constants.atom_type_num, 3)).astype(np.float32)
for i, atom in enumerate(atoms):
atom_positions[:, residue_constants.atom_order[atom], :] = np.transpose(tertiary_np[:, i::3])
atom_positions *= PICO_TO_ANGSTROM
elif "[MASK]" == g[0]:
mask = np.array(list(map({"-": 0, "+": 1}.get, g[1][0].strip())))
atom_mask = np.zeros(
(
len(mask),
residue_constants.atom_type_num,
)
).astype(np.float32)
for i, atom in enumerate(atoms):
atom_mask[:, residue_constants.atom_order[atom]] = 1
atom_mask *= mask[..., None]
assert aatype is not None
return Protein(
atom_positions=atom_positions,
atom_mask=atom_mask,
aatype=aatype,
residue_index=np.arange(len(aatype)),
b_factors=None,
)
def get_pdb_headers(prot: Protein, chain_id: int = 0) -> List[str]:
pdb_headers: List[str] = []
remark = prot.remark
if remark is not None:
pdb_headers.append(f"REMARK {remark}")
parents = prot.parents
parents_chain_index = prot.parents_chain_index
if parents is not None and parents_chain_index is not None:
parents = [p for i, p in zip(parents_chain_index, parents) if i == chain_id]
if parents is None or len(parents) == 0:
parents = ["N/A"]
pdb_headers.append(f"PARENT {' '.join(parents)}")
return pdb_headers
def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
"""Add pdb headers to an existing PDB string. Useful during multi-chain
recycling
"""
out_pdb_lines: List[str] = []
lines = pdb_str.split("\n")
remark = prot.remark
if remark is not None:
out_pdb_lines.append(f"REMARK {remark}")
parents_per_chain: List[List[str]]
if prot.parents is not None and len(prot.parents) > 0:
parents_per_chain = []
if prot.parents_chain_index is not None:
parent_dict: Dict[str, List[str]] = {}
for p, i in zip(prot.parents, prot.parents_chain_index):
parent_dict.setdefault(str(i), [])
parent_dict[str(i)].append(p)
max_idx = max([int(chain_idx) for chain_idx in parent_dict])
for i in range(max_idx + 1):
chain_parents = parent_dict.get(str(i), ["N/A"])
parents_per_chain.append(chain_parents)
else:
parents_per_chain.append(list(prot.parents))
else:
parents_per_chain = [["N/A"]]
def make_parent_line(p: Sequence[str]) -> str:
return f"PARENT {' '.join(p)}"
out_pdb_lines.append(make_parent_line(parents_per_chain[0]))
chain_counter = 0
for i, l in enumerate(lines):
if "PARENT" not in l and "REMARK" not in l:
out_pdb_lines.append(l)
if "TER" in l and "END" not in lines[i + 1]:
chain_counter += 1
if not chain_counter >= len(parents_per_chain):
chain_parents = parents_per_chain[chain_counter]
else:
chain_parents = ["N/A"]
out_pdb_lines.append(make_parent_line(chain_parents))
return "\n".join(out_pdb_lines)
def to_pdb(prot: Protein) -> str:
"""Converts a `Protein` instance to a PDB string.
Args:
prot: The protein to convert to PDB.
Returns:
PDB string.
"""
restypes = residue_constants.restypes + ["X"]
def res_1to3(r: int) -> str:
return residue_constants.restype_1to3.get(restypes[r], "UNK")
atom_types = residue_constants.atom_types
pdb_lines: List[str] = []
atom_mask = prot.atom_mask
aatype = prot.aatype
atom_positions = prot.atom_positions
residue_index = prot.residue_index.astype(np.int32)
b_factors = prot.b_factors
chain_index = prot.chain_index
if np.any(aatype > residue_constants.restype_num):
raise ValueError("Invalid aatypes.")
headers = get_pdb_headers(prot)
if len(headers) > 0:
pdb_lines.extend(headers)
n = aatype.shape[0]
atom_index = 1
prev_chain_index = 0
chain_tags = string.ascii_uppercase
chain_tag = None
for i in range(n):
res_name_3 = res_1to3(aatype[i])
for atom_name, pos, mask, b_factor in zip(atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
if mask < 0.5:
continue
record_type = "ATOM"
name = atom_name if len(atom_name) == 4 else f" {atom_name}"
alt_loc = ""
insertion_code = ""
occupancy = 1.00
element = atom_name[0]
charge = ""
chain_tag = "A"
if chain_index is not None:
chain_tag = chain_tags[chain_index[i]]
atom_line = (
f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
f"{res_name_3:>3} {chain_tag:>1}"
f"{residue_index[i]:>4}{insertion_code:>1} "
f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
f"{occupancy:>6.2f}{b_factor:>6.2f} "
f"{element:>2}{charge:>2}"
)
pdb_lines.append(atom_line)
atom_index += 1
should_terminate = i == n - 1
if chain_index is not None:
if i != n - 1 and chain_index[i + 1] != prev_chain_index:
should_terminate = True
prev_chain_index = chain_index[i + 1]
if should_terminate:
chain_end = "TER"
chain_termination_line = (
f"{chain_end:<6}{atom_index:>5} {res_1to3(aatype[i]):>3} {chain_tag:>1}{residue_index[i]:>4}"
)
pdb_lines.append(chain_termination_line)
atom_index += 1
if i != n - 1:
pdb_lines.extend(get_pdb_headers(prot, prev_chain_index))
pdb_lines.append("END")
pdb_lines.append("")
return "\n".join(pdb_lines)
def ideal_atom_mask(prot: Protein) -> np.ndarray:
"""Computes an ideal atom mask.
`Protein.atom_mask` typically is defined according to the atoms that are reported in the PDB. This function
computes a mask according to heavy atoms that should be present in the given sequence of amino acids.
Args:
prot: `Protein` whose fields are `numpy.ndarray` objects.
Returns:
An ideal atom mask.
"""
return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
def from_prediction(
features: FeatureDict,
result: ModelOutput,
b_factors: Optional[np.ndarray] = None,
chain_index: Optional[np.ndarray] = None,
remark: Optional[str] = None,
parents: Optional[Sequence[str]] = None,
parents_chain_index: Optional[Sequence[int]] = None,
) -> Protein:
"""Assembles a protein from a prediction.
Args:
features: Dictionary holding model inputs.
result: Dictionary holding model outputs.
b_factors: (Optional) B-factors to use for the protein.
chain_index: (Optional) Chain indices for multi-chain predictions
remark: (Optional) Remark about the prediction
parents: (Optional) List of template names
Returns:
A protein instance.
"""
return Protein(
aatype=features["aatype"],
atom_positions=result["final_atom_positions"],
atom_mask=result["final_atom_mask"],
residue_index=features["residue_index"] + 1,
b_factors=b_factors if b_factors is not None else np.zeros_like(result["final_atom_mask"]),
chain_index=chain_index,
remark=remark,
parents=parents,
parents_chain_index=parents_chain_index,
)
.\models\esm\openfold_utils\residue_constants.py
"""Constants used in AlphaFold."""
import collections
import copy
import functools
from importlib import resources
from typing import Dict, List, Mapping, Sequence, Tuple
import numpy as np
ca_ca = 3.80209737096
chi_angles_atoms: Dict[str, List[List[str]]] = {
"ALA": [],
"ARG": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"], ["CB", "CG", "CD", "NE"], ["CG", "CD", "NE", "CZ"]],
"ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
"ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
"CYS": [["N", "CA", "CB", "SG"]],
"GLN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"], ["CB", "CG", "CD", "OE1"]],
"GLU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"], ["CB", "CG", "CD", "OE1"]],
"GLY": [],
"HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]],
"ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]],
"LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"LYS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"], ["CB", "CG", "CD", "CE"], ["CG", "CD", "CE", "NZ"]],
"MET": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "SD"], ["CB", "CG", "SD", "CE"]],
"PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]],
"SER": [["N", "CA", "CB", "OG"]],
"THR": [["N", "CA", "CB", "OG1"]],
"TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"VAL": [["N", "CA", "CB", "CG1"]],
}
chi_angles_mask: List[List[float]] = [
[0.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 0.0, 0.0],
[1.0, 1.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 0.0],
[1.0, 1.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 0.0, 0.0],
[1.0, 1.0, 0.0, 0.0],
[1.0, 1.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 0.0],
[1.0, 1.0, 0.0, 0.0],
[1.0, 1.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 0.0, 0.0],
[1.0, 1.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0],
]
chi_pi_periodic: List[List[float]] = [
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
]
rigid_group_atom_positions: Dict[str, List[Tuple[str, int, Tuple[float, float, float]]]] = {
"ALA": [
("N", 0, (-0.525, 1.363, 0.000)),
("CA", 0, (0.000, 0.000, 0.000)),
("C", 0, (1.526, -0.000, -0.000)),
("CB", 0, (-0.529, -0.774, -1.205)),
("O", 3, (0.627, 1.062, 0.000)),
],
"ARG": [
("N", 0, (-0.524, 1.362, -0.000)),
("CA", 0, (0.000, 0.000, 0.000)),
("C", 0, (1.525, -0.000, -0.000)),
("CB", 0, (-0.524, -0.778, -1.209)),
("O", 3, (0.626, 1.062, 0.000)),
("CG", 4, (0.616, 1.390, -0.000)),
("CD", 5, (0.564, 1.414, 0.000)),
("NE", 6, (0.539, 1.357, -0.000)),
("NH1", 7, (0.206, 2.301, 0.000)),
("NH2", 7, (2.078, 0.978, -0.000)),
("CZ", 7, (0.758, 1.093, -0.000)),
],
"ASN": [
("N", 0, (-0.536, 1.357, 0.000)),
("CA", 0, (0.000, 0.000, 0.000)),
("C", 0, (1.526, -0.000, -0.000)),
("CB", 0, (-0.531, -0.787, -1.200)),
("O", 3, (0.625, 1.062, 0.000)),
("CG", 4, (0.584, 1.399, 0.000)),
("ND2", 5, (0.593, -1.188, 0.001)),
("OD1", 5, (0.633, 1.059, 0.000)),
],
"ASP": [
("N", 0, (-0.525, 1.362, -0.000)),
("CA", 0, (0.000, 0.000, 0.000)),
("C", 0, (1.527, 0.000, -0.000)),
("CB", 0, (-0.526, -0.778, -1.208)),
("O", 3, (0.626, 1.062, -0.000)),
("CG", 4, (0.593, 1.398, -0.000)),
("OD1", 5, (0.610, 1.091, 0.000)),
("OD2", 5, (0.592, -1.101, -0.003)),
],
"CYS": [
("N", 0, (-0.522, 1.362, -0.000)),
("CA", 0, (0.000, 0.000, 0.000)),
("C", 0, (1.524, 0.000, 0.000)),
("CB", 0, (-0.519, -0.773, -1.212)),
("O", 3, (0.625, 1.062, -0.000)),
("SG", 4, (0.728, 1.653, 0.000)),
],
"GLN": [
("N", 0, (-0.526, 1.361, -0.000)),
("CA", 0, (0.000, 0.000, 0.000)),
("C", 0, (1.526, 0.000, 0.000)),
("CB", 0, (-0.525, -0.779, -1.207)),
("O", 3, (0.626, 1.062, -0.000)),
("CG", 4, (0.615, 1.393, 0.000)),
("CD", 5, (0.587, 1.399, -0.000)),
("NE2", 6, (0.593, -1.189, -0.001)),
("OE1", 6, (0.634, 1.060, 0.000)),
],
"GLU": [
("N", 0, (-0.528, 1.361, 0.000)),
("CA", 0, (0.000, 0.000, 0.000)),
("C", 0, (1.526, -0.000, -0.000)),
("CB", 0, (-0.526, -0.781, -1.207)),
("O", 3, (0.626, 1.062, 0.000)),
("CG", 4, (0.615, 1.392, 0.000)),
("CD", 5, (0.600, 1.397, 0.000)),
("OE1", 6, (0.607, 1.095, -0.000)),
("OE2", 6, (0.589, -1.104, -0.001)),
],
"GLY": [
("N", 0, (-0.572, 1.337, 0.000)),
("CA", 0, (0.000, 0.000, 0.000)),
("C", 0, (1.517, -0.000, -0.000)),
("O", 3, (0.626, 1.062, -0.000)),
],
"HIS": [
("N", 0, (-0.527, 1.360, 0.000)),
("CA", 0, (0.000, 0.000, 0.000)),
("C", 0, (1.525, 0.000, 0.000)),
("CB", 0, (-0.525, -0.778, -1.208)),
("O", 3, (0.625, 1.063, 0.000)),
("CG", 4, (0.600, 1.370, -0.000)),
("CD2", 5, (0.889, -1.021, 0.003)),
("ND1", 5, (0.744, 1.160, -0.000)),
("CE1", 5, (2.030, 0.851, 0.002)),
("NE2", 5, (2.145, -0.466, 0.004)),
],
"ILE": [
("N", 0, (-0.493, 1.373, -0.000)),
("CA", 0, (0.000, 0.000, 0.000)),
("C", 0, (1.527, -0.000, -0.000)),
("CB", 0, (-0.536, -0.793, -1.213)),
("O", 3, (0.627, 1.062, -0.000)),
("CG1", 4, (0.534, 1.437, -0.000)),
("CG2", 4, (0.540, -0.785, -1.199)),
("CD1", 5, (0.619, 1.391, 0.000)),
],
"LEU": [
("N", 0, (-0.520, 1.363, 0.000)),
("CA", 0, (0.000, 0.000, 0.000)),
("C", 0, (1.525, -0.000, -0.000)),
("CB", 0, (-0.522, -0.773, -1.214)),
("O", 3, (0.625, 1.063, -0.000)),
("CG", 4, (0.678, 1.371, 0.000)),
("CD1", 5, (0.530, 1.430, -0.000)),
("CD2", 5, (0.535, -0.774, 1.200)),
],
"LYS": [
("N", 0, (-0.526, 1.362, -0.000)),
("CA", 0, (0.000, 0.000, 0.000)),
("C", 0, (1.526, 0.000, 0.000)),
("CB", 0, (-0.524, -0.778, -1.208)),
("O", 3, (0.626, 1.062, -0.000)),
("CG", 4, (0.619, 1.390, 0.000)),
("CD", 5, (0.559, 1.417, 0.000)),
("CE", 6, (0.560, 1.416, 0.000)),
("NZ", 7, (0.554, 1.387, 0.000)),
],
"MET": [
("N", 0, (-0.521, 1.364, -0.000)),
("CA", 0, (0.000, 0.000, 0.000)),
("C", 0, (1.525, 0.000, 0.000)),
("CB", 0, (-0.523, -0.776, -1.210)),
("O", 3, (0.625, 1.062, -0.000)),
("CG", 4, (0.613, 1.391, -0.000)),
("SD", 5, (0.703, 1.695, 0.000)),
("CE", 6, (0.320, 1.786, -0.000)),
],
"PHE": [
("N", 0, (-0.518, 1.363, 0.000)),
("CA", 0, (0.000, 0.000, 0.000)),
("C", 0, (1.524, 0.000, -0.000)),
("CB", 0, (-0.525, -0.776, -1.212)),
("O", 3, (0.626, 1.062, -0.000)),
("CG", 4, (0.607, 1.377, 0.000)),
("CD1", 5, (0.709, 1.195, -0.000)),
("CD2", 5, (0.706, -1.196, 0.000)),
("CE1", 5, (2.102, 1.198, -0.000)),
("CE2", 5, (2.098, -1.201, -0.000)),
("CZ", 5, (2.794, -0.003, -0.001)),
],
"PRO": [
("N", 0, (-0.566, 1.351, -0.000)),
("CA", 0, (0.000, 0.000, 0.000)),
("C", 0, (1.527, -0.000, 0.000)),
("CB", 0, (-0.546, -0.611, -1.293)),
("O", 3, (0.621, 1.066, 0.000)),
("CG", 4, (0.382, 1.445, 0.0)),
("CD", 5, (0.477, 1.424, 0.0)),
"TYR": [
("N", 0, (-0.522, 1.362, 0.000)),
("CA", 0, (0.000, 0.000, 0.000)),
("C", 0, (1.524, -0.000, -0.000)),
("CB", 0, (-0.522, -0.776, -1.213)),
("O", 3, (0.627, 1.062, -0.000)),
("CG", 4, (0.607, 1.382, -0.000)),
("CD1", 5, (0.716, 1.195, -0.000)),
("CD2", 5, (0.713, -1.194, -0.001)),
("CE1", 5, (2.107, 1.200, -0.002)),
("CE2", 5, (2.104, -1.201, -0.003)),
("OH", 5, (4.168, -0.002, -0.005)),
("CZ", 5, (2.791, -0.001, -0.003)),
],
"VAL": [
("N", 0, (-0.494, 1.373, -0.000)),
("CA", 0, (0.000, 0.000, 0.000)),
("C", 0, (1.527, -0.000, -0.000)),
("CB", 0, (-0.533, -0.795, -1.213)),
("O", 3, (0.627, 1.062, -0.000)),
("CG1", 4, (0.540, 1.429, -0.000)),
("CG2", 4, (0.533, -0.776, 1.203)),
    ],
}
residue_atoms: Dict[str, List[str]] = {
"ALA": ["C", "CA", "CB", "N", "O"],
"ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"],
"ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"],
"ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"],
"CYS": ["C", "CA", "CB", "N", "O", "SG"],
"GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"],
"GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"],
"GLY": ["C", "CA", "N", "O"],
"HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"],
"ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"],
"LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"],
"LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"],
"MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"],
"PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"],
"PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"],
"SER": ["C", "CA", "CB", "N", "O", "OG"],
"THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"],
"TRP": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE2", "CE3", "CZ2", "CZ3", "CH2", "N", "NE1", "O"],
"TYR": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O", "OH"],
"VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"],
}
residue_atom_renaming_swaps: Dict[str, Dict[str, str]] = {
"ASP": {"OD1": "OD2"},
"GLU": {"OE1": "OE2"},
"PHE": {"CD1": "CD2", "CE1": "CE2"},
"TYR": {"CD1": "CD2", "CE1": "CE2"},
}
van_der_waals_radius: Dict[str, float] = {
"C": 1.7,
"N": 1.55,
"O": 1.52,
"S": 1.8,
}
Bond = collections.namedtuple("Bond", ["atom1_name", "atom2_name", "length", "stddev"])
BondAngle = collections.namedtuple(
"BondAngle",
["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"],
)
def map_structure_with_atom_order(in_list: list, first_call: bool = True) -> list:
if first_call:
in_list = copy.deepcopy(in_list)
for i in range(len(in_list)):
if isinstance(in_list[i], list):
in_list[i] = map_structure_with_atom_order(in_list[i], first_call=False)
elif isinstance(in_list[i], str):
in_list[i] = atom_order[in_list[i]]
else:
raise ValueError("Unexpected type when mapping nested lists!")
return in_list
@functools.lru_cache(maxsize=None)
def load_stereo_chemical_props() -> Tuple[
    Mapping[str, List[Bond]],
    Mapping[str, List[Bond]],
    Mapping[str, List[BondAngle]],
]:
stereo_chemical_props = resources.read_text("openfold.resources", "stereo_chemical_props.txt")
lines_iter = iter(stereo_chemical_props.splitlines())
residue_bonds: Dict[str, List[Bond]] = {}
next(lines_iter)
for line in lines_iter:
if line.strip() == "-":
break
bond, resname, bond_length, stddev = line.split()
atom1, atom2 = bond.split("-")
if resname not in residue_bonds:
residue_bonds[resname] = []
residue_bonds[resname].append(Bond(atom1, atom2, float(bond_length), float(stddev)))
residue_bonds["UNK"] = []
residue_bond_angles: Dict[str, List[BondAngle]] = {}
next(lines_iter)
next(lines_iter)
for line in lines_iter:
if line.strip() == "-":
break
bond, resname, angle_degree, stddev_degree = line.split()
atom1, atom2, atom3 = bond.split("-")
if resname not in residue_bond_angles:
residue_bond_angles[resname] = []
residue_bond_angles[resname].append(
BondAngle(
atom1,
atom2,
atom3,
float(angle_degree) / 180.0 * np.pi,
float(stddev_degree) / 180.0 * np.pi,
)
)
residue_bond_angles["UNK"] = []
def make_bond_key(atom1_name: str, atom2_name: str) -> str:
"""创建用于查找键长的唯一键值。"""
return "-".join(sorted([atom1_name, atom2_name]))
residue_virtual_bonds: Dict[str, List[Bond]] = {}
for resname, bond_angles in residue_bond_angles.items():
bond_cache: Dict[str, Bond] = {}
for b in residue_bonds[resname]:
bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b
residue_virtual_bonds[resname] = []
for ba in bond_angles:
bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]
gamma = ba.angle_rad
length = np.sqrt(bond1.length**2 + bond2.length**2 - 2 * bond1.length * bond2.length * np.cos(gamma))
dl_outer = 0.5 / length
dl_dgamma = (2 * bond1.length * bond2.length * np.sin(gamma)) * dl_outer
dl_db1 = (2 * bond1.length - 2 * bond2.length * np.cos(gamma)) * dl_outer
dl_db2 = (2 * bond2.length - 2 * bond1.length * np.cos(gamma)) * dl_outer
stddev = np.sqrt(
(dl_dgamma * ba.stddev) ** 2 + (dl_db1 * bond1.stddev) ** 2 + (dl_db2 * bond2.stddev) ** 2
)
residue_virtual_bonds[resname].append(Bond(ba.atom1_name, ba.atom3name, length, stddev))
return (residue_bonds, residue_virtual_bonds, residue_bond_angles)
between_res_bond_length_c_n: Tuple[float, float] = (1.329, 1.341)
between_res_bond_length_stddev_c_n: Tuple[float, float] = (0.014, 0.016)
between_res_cos_angles_c_n_ca: Tuple[float, float] = (-0.5203, 0.0353)
between_res_cos_angles_ca_c_n: Tuple[float, float] = (-0.4473, 0.0311)
atom_types: List[str] = [
"N", "CA", "C", "CB", "O", "CG", "CG1", "CG2", "OG", "OG1", "SG", "CD",
"CD1", "CD2", "ND1", "ND2", "OD1", "OD2", "SD", "CE", "CE1", "CE2", "CE3",
"NE", "NE1", "NE2", "OE1", "OE2", "CH2", "NH1", "NH2", "OH", "CZ", "CZ2",
"CZ3", "NZ", "OXT",
]
atom_order: Dict[str, int] = {atom_type: i for i, atom_type in enumerate(atom_types)}
atom_type_num = len(atom_types)
restype_name_to_atom14_names: Dict[str, List[str]] = {
"ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""],
"ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2", "", "", ""],
"ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2", "", "", "", "", "", ""],
"ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2", "", "", "", "", "", ""],
"CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""],
"GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2", "", "", "", "", ""],
"GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2", "", "", "", "", ""],
"GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""],
"HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2", "", "", "", ""],
"ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1", "", "", "", "", "", ""],
"LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "", "", "", "", "", ""],
"LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ", "", "", "", "", ""],
"MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE", "", "", "", "", "", ""],
"PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "", "", ""],
"PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""],
"SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""],
"THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2", "", "", "", "", "", "", ""],
"TRP": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "NE1", "CE2", "CE3", "CZ2", "CZ3", "CH2"],
"TYR": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH", "", ""],
"VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "", "", "", "", "", "", ""],
"UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
restypes: List[str] = [
"A",
"R",
"N",
"D",
"C",
"Q",
"E",
"G",
"H",
"I",
"L",
"K",
"M",
"F",
"P",
"S",
"T",
"W",
"Y",
"V",
]
restype_order: Dict[str, int] = {restype: i for i, restype in enumerate(restypes)}
restype_num = len(restypes)
unk_restype_index = restype_num
restypes_with_x: List[str] = restypes + ["X"]
restype_order_with_x: Dict[str, int] = {restype: i for i, restype in enumerate(restypes_with_x)}
def sequence_to_onehot(sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False) -> np.ndarray:
"""Maps the given sequence into a one-hot encoded matrix.
Args:
sequence: An amino acid sequence.
mapping: A dictionary mapping amino acids to integers.
map_unknown_to_x: If True, any amino acid that is not in the mapping will be
mapped to the unknown amino acid 'X'. If the mapping doesn't contain amino acid 'X', an error will be thrown.
If False, any amino acid not in the mapping will throw an error.
Returns:
A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of the sequence.
Raises:
ValueError: If the mapping doesn't contain values from 0 to
num_unique_aas - 1 without any gaps.
"""
num_entries = max(mapping.values()) + 1
if sorted(set(mapping.values())) != list(range(num_entries)):
raise ValueError(
"The mapping must have values from 0 to num_unique_aas-1 without any gaps. Got: %s"
% sorted(mapping.values())
)
one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32)
for aa_index, aa_type in enumerate(sequence):
if map_unknown_to_x:
if aa_type.isalpha() and aa_type.isupper():
aa_id = mapping.get(aa_type, mapping["X"])
else:
raise ValueError(f"Invalid character in the sequence: {aa_type}")
else:
aa_id = mapping[aa_type]
one_hot_arr[aa_index, aa_id] = 1
return one_hot_arr
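A quick sketch of `sequence_to_onehot` using the `restype_order_with_x` mapping defined a few lines above (the import path is the internal module assumed throughout):

```py
from transformers.models.esm.openfold_utils.residue_constants import (
    restype_order_with_x,
    sequence_to_onehot,
)

one_hot = sequence_to_onehot("ACDX", restype_order_with_x, map_unknown_to_x=True)
print(one_hot.shape)            # (4, 21)
print(one_hot.argmax(axis=-1))  # [ 0  4  3 20] -> A, C, D and the unknown type X
```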
restype_1to3: Dict[str, str] = {
"A": "ALA",
"R": "ARG",
"N": "ASN",
"D": "ASP",
"C": "CYS",
"Q": "GLN",
"E": "GLU",
"G": "GLY",
"H": "HIS",
"I": "ILE",
"L": "LEU",
"K": "LYS",
"M": "MET",
"F": "PHE",
"P": "PRO",
"S": "SER",
"T": "THR",
"W": "TRP",
"Y": "TYR",
"V": "VAL",
}
restype_3to1: Dict[str, str] = {v: k for k, v in restype_1to3.items()}
unk_restype = "UNK"
resnames: List[str] = [restype_1to3[r] for r in restypes] + [unk_restype]
resname_to_idx: Dict[str, int] = {resname: i for i, resname in enumerate(resnames)}
HHBLITS_AA_TO_ID: Dict[str, int] = {
"A": 0,
"B": 2,
"C": 1,
"D": 2,
"E": 3,
"F": 4,
"G": 5,
"H": 6,
"I": 7,
"J": 20,
"K": 8,
"L": 9,
"M": 10,
"N": 11,
"O": 20,
"P": 12,
"Q": 13,
"R": 14,
"S": 15,
"T": 16,
"U": 1,
"V": 17,
"W": 18,
"X": 20,
"Y": 19,
"Z": 3,
"-": 21,
}
ID_TO_HHBLITS_AA: Dict[int, str] = {
0: "A",
1: "C",
2: "D",
3: "E",
4: "F",
5: "G",
6: "H",
7: "I",
8: "K",
9: "L",
10: "M",
11: "N",
12: "P",
13: "Q",
14: "R",
15: "S",
16: "T",
17: "V",
18: "W",
19: "Y",
20: "X",
21: "-",
}
restypes_with_x_and_gap: List[str] = restypes + ["X", "-"]
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE: Tuple[int, ...] = tuple(
restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i]) for i in range(len(restypes_with_x_and_gap))
)
def _make_standard_atom_mask() -> np.ndarray:
"""Returns [num_res_types, num_atom_types] mask array."""
mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32)
for restype, restype_letter in enumerate(restypes):
restype_name = restype_1to3[restype_letter]
atom_names = residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = atom_order[atom_name]
mask[restype, atom_type] = 1
return mask
STANDARD_ATOM_MASK = _make_standard_atom_mask()
def chi_angle_atom(atom_index: int) -> np.ndarray:
"""Define chi-angle rigid groups via one-hot representations."""
chi_angles_index = {}
one_hots = []
for k, v in chi_angles_atoms.items():
indices = [atom_types.index(s[atom_index]) for s in v]
indices.extend([-1] * (4 - len(indices)))
chi_angles_index[k] = indices
for r in restypes:
res3 = restype_1to3[r]
one_hot = np.eye(atom_type_num)[chi_angles_index[res3]]
one_hots.append(one_hot)
one_hots.append(np.zeros([4, atom_type_num]))
one_hot = np.stack(one_hots, axis=0)
one_hot = np.transpose(one_hot, [0, 2, 1])
return one_hot
chi_atom_1_one_hot = chi_angle_atom(1)
chi_atom_2_one_hot = chi_angle_atom(2)
chi_angles_atom_indices_list: List[List[List[str]]] = [chi_angles_atoms[restype_1to3[r]] for r in restypes]
chi_angles_atom_indices_ours: list = map_structure_with_atom_order(chi_angles_atom_indices_list)
chi_angles_atom_indices = np.array(
[chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms))) for chi_atoms in chi_angles_atom_indices_list]
)
chi_groups_for_atom: Dict[Tuple[str, str], List[Tuple[int, int]]] = collections.defaultdict(list)
for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items():
for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
for atom_i, atom in enumerate(chi_group):
chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i))
chi_groups_for_atom = dict(chi_groups_for_atom)
def _make_rigid_transformation_4x4(ex: np.ndarray, ey: np.ndarray, translation: np.ndarray) -> np.ndarray:
"""Create a rigid 4x4 transformation matrix from two axes and transl."""
ex_normalized = ex / np.linalg.norm(ex)
ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
ey_normalized /= np.linalg.norm(ey_normalized)
eznorm = np.cross(ex_normalized, ey_normalized)
m = np.stack([ex_normalized, ey_normalized, eznorm, translation]).transpose()
m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0)
return m
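The rotation part of the returned 4x4 matrix is an orthonormal frame built from `ex` and the component of `ey` orthogonal to it, with the translation in the last column. A small sanity-check sketch (importing the private helper purely for illustration):

```py
import numpy as np

from transformers.models.esm.openfold_utils.residue_constants import _make_rigid_transformation_4x4

m = _make_rigid_transformation_4x4(
    ex=np.array([1.0, 1.0, 0.0]),
    ey=np.array([0.0, 1.0, 0.0]),
    translation=np.array([2.0, 0.0, 0.0]),
)
rot, trans = m[:3, :3], m[:3, 3]
assert np.allclose(rot @ rot.T, np.eye(3))      # rotation block is orthonormal
assert np.allclose(np.linalg.det(rot), 1.0)     # and a proper rotation
assert np.allclose(trans, [2.0, 0.0, 0.0])      # translation sits in the last column
assert np.allclose(m[3], [0.0, 0.0, 0.0, 1.0])  # homogeneous bottom row
```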
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
def _make_rigid_group_constants() -> None:
"""Fill the arrays above."""
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
for atomname, group_idx, atom_position in rigid_group_atom_positions[resname]:
atomtype = atom_order[atomname]
restype_atom37_to_rigid_group[restype, atomtype] = group_idx
restype_atom37_mask[restype, atomtype] = 1
restype_atom37_rigid_group_positions[restype, atomtype, :] = atom_position
atom14idx = restype_name_to_atom14_names[resname].index(atomname)
restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
restype_atom14_mask[restype, atom14idx] = 1
restype_atom14_rigid_group_positions[restype, atom14idx, :] = atom_position
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
atom_positions: Dict[str, np.ndarray] = {
name: np.array(pos) for name, _, pos in rigid_group_atom_positions[resname]
}
restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)
restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)
mat = _make_rigid_transformation_4x4(
ex=atom_positions["N"] - atom_positions["CA"],
ey=np.array([1.0, 0.0, 0.0]),
translation=atom_positions["N"],
)
restype_rigid_group_default_frame[restype, 2, :, :] = mat
mat = _make_rigid_transformation_4x4(
ex=atom_positions["C"] - atom_positions["CA"],
ey=atom_positions["CA"] - atom_positions["N"],
translation=atom_positions["C"],
)
restype_rigid_group_default_frame[restype, 3, :, :] = mat
if chi_angles_mask[restype][0]:
base_atom_names = chi_angles_atoms[resname][0]
base_atom_positions = [atom_positions[name] for name in base_atom_names]
mat = _make_rigid_transformation_4x4(
ex=base_atom_positions[2] - base_atom_positions[1],
ey=base_atom_positions[0] - base_atom_positions[1],
translation=base_atom_positions[2],
)
restype_rigid_group_default_frame[restype, 4, :, :] = mat
for chi_idx in range(1, 4):
if chi_angles_mask[restype][chi_idx]:
axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
axis_end_atom_position = atom_positions[axis_end_atom_name]
mat = _make_rigid_transformation_4x4(
ex=axis_end_atom_position,
ey=np.array([-1.0, 0.0, 0.0]),
translation=axis_end_atom_position,
)
restype_rigid_group_default_frame[restype, 4 + chi_idx, :, :] = mat
_make_rigid_group_constants()
def make_atom14_dists_bounds(
overlap_tolerance: float = 1.5,
bond_length_tolerance_factor: int = 15,
) -> Dict[str, np.ndarray]:
"""compute upper and lower bounds for bonds to assess violations."""
restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32)
restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32)
restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32)
residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
atom_list = restype_name_to_atom14_names[resname]
for atom1_idx, atom1_name in enumerate(atom_list):
if not atom1_name:
continue
atom1_radius = van_der_waals_radius[atom1_name[0]]
for atom2_idx, atom2_name in enumerate(atom_list):
if (not atom2_name) or atom1_idx == atom2_idx:
continue
atom2_radius = van_der_waals_radius[atom2_name[0]]
lower = atom1_radius + atom2_radius - overlap_tolerance
upper = 1e10
restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
for b in residue_bonds[resname] + residue_virtual_bonds[resname]:
atom1_idx = atom_list.index(b.atom1_name)
atom2_idx = atom_list.index(b.atom2_name)
lower = b.length - bond_length_tolerance_factor * b.stddev
upper = b.length + bond_length_tolerance_factor * b.stddev
restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev
restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev
return {
"lower_bound": restype_atom14_bond_lower_bound,
"upper_bound": restype_atom14_bond_upper_bound,
"stddev": restype_atom14_bond_stddev,
}
restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32)
restype_atom14_ambiguous_atoms_swap_idx: np.ndarray = np.tile(np.arange(14, dtype=int), (21, 1))
def _make_atom14_ambiguity_feats() -> None:
for res, pairs in residue_atom_renaming_swaps.items():
res_idx = restype_order[restype_3to1[res]]
for atom1, atom2 in pairs.items():
atom1_idx = restype_name_to_atom14_names[res].index(atom1)
atom2_idx = restype_name_to_atom14_names[res].index(atom2)
restype_atom14_ambiguous_atoms[res_idx, atom1_idx] = 1
restype_atom14_ambiguous_atoms[res_idx, atom2_idx] = 1
restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom1_idx] = atom2_idx
restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom2_idx] = atom1_idx
_make_atom14_ambiguity_feats()
def aatype_to_str_sequence(aatype: Sequence[int]) -> str:
return "".join([restypes_with_x[aatype[i]] for i in range(len(aatype))])
.\models\esm\openfold_utils\rigid_utils.py
from __future__ import annotations
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
import numpy as np
import torch
def rot_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""
执行两个旋转矩阵张量的矩阵乘法。手动编写以避免 AMP 下转换。
Args:
a: [*, 3, 3] 左乘数
b: [*, 3, 3] 右乘数
Returns:
乘积 ab
"""
def row_mul(i: int) -> torch.Tensor:
return torch.stack(
[
a[..., i, 0] * b[..., 0, 0] + a[..., i, 1] * b[..., 1, 0] + a[..., i, 2] * b[..., 2, 0],
a[..., i, 0] * b[..., 0, 1] + a[..., i, 1] * b[..., 1, 1] + a[..., i, 2] * b[..., 2, 1],
a[..., i, 0] * b[..., 0, 2] + a[..., i, 1] * b[..., 1, 2] + a[..., i, 2] * b[..., 2, 2],
],
dim=-1,
)
return torch.stack(
[
row_mul(0),
row_mul(1),
row_mul(2),
],
dim=-2,
)
def rot_vec_mul(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""
对向量施加旋转。手动编写以避免 AMP 下转换。
Args:
r: [*, 3, 3] 旋转矩阵
t: [*, 3] 坐标张量
Returns:
[*, 3] 旋转后的坐标
"""
x, y, z = torch.unbind(t, dim=-1)
return torch.stack(
[
r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z,
r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z,
],
dim=-1,
)
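`rot_matmul` and `rot_vec_mul` unroll the 3x3 products element by element so autocast cannot downcast them. A small sketch with a 90-degree rotation about z (import path assumed):

```py
import torch

from transformers.models.esm.openfold_utils.rigid_utils import rot_matmul, rot_vec_mul

rz_90 = torch.tensor([
    [0.0, -1.0, 0.0],
    [1.0,  0.0, 0.0],
    [0.0,  0.0, 1.0],
])
v = torch.tensor([1.0, 0.0, 0.0])

print(rot_vec_mul(rz_90, v))                     # -> [0., 1., 0.]
print(rot_vec_mul(rot_matmul(rz_90, rz_90), v))  # two 90-degree turns -> [-1., 0., 0.]
```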
@lru_cache(maxsize=None)
def identity_rot_mats(
batch_dims: Tuple[int, ...],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
) -> torch.Tensor:
"""
返回指定批次维度下的单位旋转矩阵张量。
Args:
batch_dims: 批次维度的元组
dtype: 张量数据类型,默认为 None
device: 张量的设备,默认为 None
requires_grad: 是否需要梯度,默认为 True
Returns:
torch.Tensor: 单位旋转矩阵张量
"""
rots = torch.eye(3, dtype=dtype, device=device, requires_grad=requires_grad)
rots = rots.view(*((1,) * len(batch_dims)), 3, 3)
rots = rots.expand(*batch_dims, -1, -1)
rots = rots.contiguous()
return rots
@lru_cache(maxsize=None)
def identity_trans(
batch_dims: Tuple[int, ...],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
) -> torch.Tensor:
"""
返回指定批次维度下的单位平移张量。
Args:
batch_dims: 批次维度的元组
dtype: 张量数据类型,默认为 None
device: 张量的设备,默认为 None
requires_grad: 是否需要梯度,默认为 True
Returns:
torch.Tensor: 单位平移张量
"""
trans = torch.zeros(3, dtype=dtype, device=device, requires_grad=requires_grad)
trans = trans.view(*((1,) * len(batch_dims)), 3)
trans = trans.expand(*batch_dims, -1)
trans = trans.contiguous()
return trans
@lru_cache(maxsize=None)
def identity_quats(
batch_dims: Tuple[int, ...],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
) -> torch.Tensor:
quat = torch.zeros((*batch_dims, 4), dtype=dtype, device=device, requires_grad=requires_grad)
with torch.no_grad():
quat[..., 0] = 1
return quat
_quat_elements: List[str] = ["a", "b", "c", "d"]
_qtr_keys: List[str] = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
_qtr_ind_dict: Dict[str, int] = {key: ind for ind, key in enumerate(_qtr_keys)}
def _to_mat(pairs: List[Tuple[str, int]]) -> np.ndarray:
mat = np.zeros((4, 4))
for key, value in pairs:
ind = _qtr_ind_dict[key]
mat[ind // 4][ind % 4] = value
return mat
_QTR_MAT = np.zeros((4, 4, 3, 3))
_QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
_QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
_QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
_QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
_QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
_QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
_QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
_QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
_QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])
def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
"""
Converts a quaternion to a rotation matrix.
Args:
quat: [*, 4] quaternions
Returns:
[*, 3, 3] rotation matrices
"""
quat = quat[..., None] * quat[..., None, :]
mat = _get_quat("_QTR_MAT", dtype=quat.dtype, device=quat.device)
shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
quat = quat[..., None, None] * shaped_qtr_mat
return torch.sum(quat, dim=(-3, -4))
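A quick sanity sketch for the quaternion helpers: the identity quaternion maps to the identity rotation, and `rot_to_quat` recovers it up to sign (import path assumed):

```py
import torch

from transformers.models.esm.openfold_utils.rigid_utils import quat_to_rot, rot_to_quat

q = torch.tensor([1.0, 0.0, 0.0, 0.0])  # identity quaternion (w, x, y, z)
r = quat_to_rot(q)
assert torch.allclose(r, torch.eye(3))

q_back = rot_to_quat(r)                 # defined only up to an overall sign
assert torch.allclose(q_back.abs(), q.abs(), atol=1e-6)
```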
def rot_to_quat(rot: torch.Tensor) -> torch.Tensor:
if rot.shape[-2:] != (3, 3):
raise ValueError("Input rotation is incorrectly shaped")
[[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = [[rot[..., i, j] for j in range(3)] for i in range(3)]
k = [
[
xx + yy + zz,
zy - yz,
xz - zx,
yx - xy,
],
[
zy - yz,
xx - yy - zz,
xy + yx,
xz + zx,
],
[
xz - zx,
xy + yx,
yy - xx - zz,
yz + zy,
],
[
yx - xy,
xz + zx,
yz + zy,
zz - xx - yy,
],
]
_, vectors = torch.linalg.eigh((1.0 / 3.0) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2))
return vectors[..., -1]
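As a concrete illustration of the two conversions above, the following sketch (not part of the original file) round-trips a 90-degree rotation about the z axis. Note that rot_to_quat recovers the quaternion only up to sign, since q and -q encode the same rotation.
import math
import torch
from transformers.models.esm.openfold_utils.rigid_utils import quat_to_rot, rot_to_quat

angle = math.pi / 2
quat = torch.tensor([math.cos(angle / 2), 0.0, 0.0, math.sin(angle / 2)])  # (w, x, y, z)
rot = quat_to_rot(quat)                                   # [3, 3] rotation about z
recovered = rot_to_quat(rot)                              # leading eigenvector of the K matrix
print(torch.min((recovered - quat).norm(), (recovered + quat).norm()))    # ~0 (sign ambiguity)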
_QUAT_MULTIPLY = np.zeros((4, 4, 4))
_QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, -1]]
_QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, -1, 0]]
_QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0], [0, 0, 0, -1], [1, 0, 0, 0], [0, 1, 0, 0]]
_QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0], [1, 0, 0, 0]]
_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
_CACHED_QUATS: Dict[str, np.ndarray] = {
"_QTR_MAT": _QTR_MAT,
"_QUAT_MULTIPLY": _QUAT_MULTIPLY,
"_QUAT_MULTIPLY_BY_VEC": _QUAT_MULTIPLY_BY_VEC,
}
@lru_cache(maxsize=None)
def _get_quat(quat_key: str, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
return torch.tensor(_CACHED_QUATS[quat_key], dtype=dtype, device=device)
def quat_multiply(quat1: torch.Tensor, quat2: torch.Tensor) -> torch.Tensor:
"""Multiply a quaternion by another quaternion."""
mat = _get_quat("_QUAT_MULTIPLY", dtype=quat1.dtype, device=quat1.device)
reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)
return torch.sum(reshaped_mat * quat1[..., :, None, None] * quat2[..., None, :, None], dim=(-3, -2))
def quat_multiply_by_vec(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
"""Multiply a quaternion by a pure-vector quaternion."""
mat = _get_quat("_QUAT_MULTIPLY_BY_VEC", dtype=quat.dtype, device=quat.device)
reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
return torch.sum(reshaped_mat * quat[..., :, None, None] * vec[..., None, :, None], dim=(-3, -2))
def invert_rot_mat(rot_mat: torch.Tensor) -> torch.Tensor:
return rot_mat.transpose(-1, -2)
def invert_quat(quat: torch.Tensor) -> torch.Tensor:
quat_prime = quat.clone()
quat_prime[..., 1:] *= -1
inv = quat_prime / torch.sum(quat**2, dim=-1, keepdim=True)
return inv
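A small illustrative check (not from the original file) that quat_multiply and invert_quat behave as expected: composing a unit quaternion with its inverse gives back the identity quaternion (1, 0, 0, 0).
import torch
from transformers.models.esm.openfold_utils.rigid_utils import invert_quat, quat_multiply

q = torch.tensor([0.8, 0.1, -0.3, 0.5])
q = q / torch.linalg.norm(q)            # normalize to a unit quaternion
prod = quat_multiply(q, invert_quat(q))
print(prod)                             # approximately tensor([1., 0., 0., 0.])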
class Rotation:
"""
A 3D rotation. Depending on how the object is initialized, the rotation is represented by either a rotation matrix
or a quaternion, though both formats are made available by helper functions. To simplify gradient computation, the
underlying format of the rotation cannot be changed in-place. Like Rigid, the class is designed to mimic the
behavior of a torch Tensor, almost as if each Rotation object were a tensor of rotations, in one format or another.
"""
def __init__(
self,
rot_mats: Optional[torch.Tensor] = None,
quats: Optional[torch.Tensor] = None,
normalize_quats: bool = True,
"""
Args:
rot_mats:
A [*, 3, 3] rotation matrix tensor. Mutually exclusive with quats
quats:
A [*, 4] quaternion. Mutually exclusive with rot_mats. If normalize_quats is not True, must be a unit
quaternion
normalize_quats:
If quats is specified, whether to normalize quats
"""
if (rot_mats is None and quats is None) or (rot_mats is not None and quats is not None):
raise ValueError("Exactly one input argument must be specified")
if (rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or (quats is not None and quats.shape[-1] != 4):
raise ValueError("Incorrectly shaped rotation matrix or quaternion")
if quats is not None:
quats = quats.to(dtype=torch.float32)
if rot_mats is not None:
rot_mats = rot_mats.to(dtype=torch.float32)
if quats is not None and normalize_quats:
quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
self._rot_mats = rot_mats
self._quats = quats
@staticmethod
def identity(
shape,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
fmt: str = "quat",
) -> Rotation:
"""
Returns an identity Rotation.
Args:
shape:
The "shape" of the resulting Rotation object. See documentation for the shape property
dtype:
The torch dtype for the rotation
device:
The torch device for the new rotation
requires_grad:
Whether the underlying tensors in the new rotation object should require gradient computation
fmt:
One of "quat" or "rot_mat". Determines the underlying format of the new object's rotation
Returns:
A new identity rotation
"""
if fmt == "rot_mat":
rot_mats = identity_rot_mats(
shape,
dtype,
device,
requires_grad,
)
return Rotation(rot_mats=rot_mats, quats=None)
elif fmt == "quat":
quats = identity_quats(shape, dtype, device, requires_grad)
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
else:
raise ValueError(f"Invalid format: f{fmt}")
def __getitem__(self, index: Any) -> Rotation:
"""
Allows torch-style indexing over the virtual shape of the rotation object. See documentation for the shape
property.
Args:
index:
A torch index. E.g. (1, 3, 2), or (slice(None,))
Returns:
The indexed rotation
"""
if type(index) != tuple:
index = (index,)
if self._rot_mats is not None:
rot_mats = self._rot_mats[index + (slice(None), slice(None))]
return Rotation(rot_mats=rot_mats)
elif self._quats is not None:
quats = self._quats[index + (slice(None),)]
return Rotation(quats=quats, normalize_quats=False)
else:
raise ValueError("Both rotations are None")
def __mul__(self, right: torch.Tensor) -> Rotation:
"""
Pointwise left multiplication of the rotation with a tensor. Can be used to e.g. mask the Rotation.
Args:
right:
The tensor multiplicand
Returns:
The product
"""
if not (isinstance(right, torch.Tensor)):
raise TypeError("The other multiplicand must be a Tensor")
if self._rot_mats is not None:
rot_mats = self._rot_mats * right[..., None, None]
return Rotation(rot_mats=rot_mats, quats=None)
elif self._quats is not None:
quats = self._quats * right[..., None]
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
else:
raise ValueError("Both rotations are None")
def __rmul__(self, left: torch.Tensor) -> Rotation:
"""
Reverse pointwise multiplication of the rotation with a tensor.
Args:
left:
The left multiplicand
Returns:
The product
"""
return self.__mul__(left)
@property
def shape(self) -> torch.Size:
"""
Returns the virtual shape of the rotation object. This shape is defined as the batch dimensions of the
underlying rotation matrix or quaternion. If the Rotation was initialized with a [10, 3, 3] rotation matrix
tensor, for example, the resulting shape would be [10].
Returns:
The virtual shape of the rotation object
"""
if self._rot_mats is not None:
return self._rot_mats.shape[:-2]
elif self._quats is not None:
return self._quats.shape[:-1]
else:
raise ValueError("Both rotations are None")
@property
def dtype(self) -> torch.dtype:
"""
Returns the dtype of the underlying rotation.
Returns:
The dtype of the underlying rotation
"""
if self._rot_mats is not None:
return self._rot_mats.dtype
elif self._quats is not None:
return self._quats.dtype
else:
raise ValueError("Both rotations are None")
@property
def device(self) -> torch.device:
"""
The device of the underlying rotation
Returns:
The device of the underlying rotation
"""
if self._rot_mats is not None:
return self._rot_mats.device
elif self._quats is not None:
return self._quats.device
else:
raise ValueError("Both rotations are None")
@property
def requires_grad(self) -> bool:
"""
Returns the requires_grad property of the underlying rotation
Returns:
The requires_grad property of the underlying tensor
"""
if self._rot_mats is not None:
return self._rot_mats.requires_grad
elif self._quats is not None:
return self._quats.requires_grad
else:
raise ValueError("Both rotations are None")
def get_rot_mats(self) -> torch.Tensor:
"""
Returns the underlying rotation as a rotation matrix tensor.
Returns:
The rotation as a rotation matrix tensor
"""
if self._rot_mats is not None:
return self._rot_mats
elif self._quats is not None:
return quat_to_rot(self._quats)
else:
raise ValueError("Both rotations are None")
def get_quats(self) -> torch.Tensor:
"""
Returns the underlying rotation as a quaternion tensor.
Depending on whether the Rotation was initialized with a quaternion, this function may call torch.linalg.eigh.
Returns:
The rotation as a quaternion tensor.
"""
if self._rot_mats is not None:
return rot_to_quat(self._rot_mats)
elif self._quats is not None:
return self._quats
else:
raise ValueError("Both rotations are None")
def get_cur_rot(self) -> torch.Tensor:
"""
Return the underlying rotation in its current form
Returns:
The stored rotation
"""
if self._rot_mats is not None:
return self._rot_mats
elif self._quats is not None:
return self._quats
else:
raise ValueError("Both rotations are None")
def compose_q_update_vec(self, q_update_vec: torch.Tensor, normalize_quats: bool = True) -> Rotation:
"""
Returns a new quaternion Rotation after updating the current object's underlying rotation with a quaternion
update, formatted as a [*, 3] tensor whose final three columns represent x, y, z such that (1, x, y, z) is the
desired (not necessarily unit) quaternion update.
Args:
q_update_vec:
A [*, 3] quaternion update tensor
normalize_quats:
Whether to normalize the output quaternion
Returns:
An updated Rotation
"""
quats = self.get_quats()
new_quats = quats + quat_multiply_by_vec(quats, q_update_vec)
return Rotation(
rot_mats=None,
quats=new_quats,
normalize_quats=normalize_quats,
)
def compose_r(self, r: Rotation) -> Rotation:
"""
Compose the rotation matrices of the current Rotation object with those of another.
Args:
r:
An update rotation object
Returns:
An updated rotation object
"""
r1 = self.get_rot_mats()
r2 = r.get_rot_mats()
new_rot_mats = rot_matmul(r1, r2)
return Rotation(rot_mats=new_rot_mats, quats=None)
def compose_q(self, r: Rotation, normalize_quats: bool = True) -> Rotation:
"""
Compose the quaternions of the current Rotation object with those of another.
Depending on whether either Rotation was initialized with quaternions, this function may call
torch.linalg.eigh.
Args:
r:
An update rotation object
Returns:
An updated rotation object
"""
q1 = self.get_quats()
q2 = r.get_quats()
new_quats = quat_multiply(q1, q2)
return Rotation(rot_mats=None, quats=new_quats, normalize_quats=normalize_quats)
def apply(self, pts: torch.Tensor) -> torch.Tensor:
"""
Apply the current Rotation as a rotation matrix to a set of 3D coordinates.
Args:
pts:
A [*, 3] set of points
Returns:
[*, 3] rotated points
"""
rot_mats = self.get_rot_mats()
return rot_vec_mul(rot_mats, pts)
def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
"""
The inverse of the apply() method.
Args:
pts:
A [*, 3] set of points
Returns:
[*, 3] inverse-rotated points
"""
rot_mats = self.get_rot_mats()
inv_rot_mats = invert_rot_mat(rot_mats)
return rot_vec_mul(inv_rot_mats, pts)
def invert(self) -> Rotation:
"""
Returns the inverse of the current Rotation.
Returns:
The inverse of the current Rotation
"""
if self._rot_mats is not None:
return Rotation(rot_mats=invert_rot_mat(self._rot_mats), quats=None)
elif self._quats is not None:
return Rotation(
rot_mats=None,
quats=invert_quat(self._quats),
normalize_quats=False,
)
else:
raise ValueError("Both rotations are None")
def unsqueeze(self, dim: int) -> Rotation:
"""
Analogous to torch.unsqueeze. The dimension is relative to the shape of the Rotation object.
Args:
dim: A positive or negative dimension index.
Returns:
The unsqueezed Rotation.
"""
if dim >= len(self.shape):
raise ValueError("Invalid dimension")
if self._rot_mats is not None:
rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2)
return Rotation(rot_mats=rot_mats, quats=None)
elif self._quats is not None:
quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1)
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
else:
raise ValueError("Both rotations are None")
@staticmethod
def cat(rs: Sequence[Rotation], dim: int) -> Rotation:
"""
Concatenates rotations along one of the batch dimensions. Analogous to torch.cat().
Note that the output of this operation is always a rotation matrix, regardless of the format of input
rotations.
Args:
rs:
A list of rotation objects
dim:
The dimension along which the rotations should be concatenated
Returns:
A concatenated Rotation object in rotation matrix format
"""
rot_mats = torch.cat(
[r.get_rot_mats() for r in rs],
dim=dim if dim >= 0 else dim - 2,
)
return Rotation(rot_mats=rot_mats, quats=None)
def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rotation:
"""
Apply a Tensor -> Tensor function to underlying rotation tensors, mapping over the rotation dimension(s). Can
be used e.g. to sum out a one-hot batch dimension.
Args:
fn:
A Tensor -> Tensor function to be mapped over the Rotation
Returns:
The transformed Rotation object
"""
if self._rot_mats is not None:
rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))
rot_mats = torch.stack(list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1)
rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
return Rotation(rot_mats=rot_mats, quats=None)
elif self._quats is not None:
quats = torch.stack(list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1)
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
else:
raise ValueError("Both rotations are None")
def cuda(self) -> Rotation:
"""
Analogous to the cuda() method of torch Tensors
Returns:
A copy of the Rotation in CUDA memory
"""
if self._rot_mats is not None:
return Rotation(rot_mats=self._rot_mats.cuda(), quats=None)
elif self._quats is not None:
return Rotation(rot_mats=None, quats=self._quats.cuda(), normalize_quats=False)
else:
raise ValueError("Both rotations are None")
def to(self, device: Optional[torch.device], dtype: Optional[torch.dtype]) -> Rotation:
"""
Analogous to the to() method of torch Tensors
Args:
device:
A torch device
dtype:
A torch dtype
Returns:
A copy of the Rotation using the new device and dtype
"""
if self._rot_mats is not None:
return Rotation(
rot_mats=self._rot_mats.to(device=device, dtype=dtype),
quats=None,
)
elif self._quats is not None:
return Rotation(
rot_mats=None,
quats=self._quats.to(device=device, dtype=dtype),
normalize_quats=False,
)
else:
raise ValueError("Both rotations are None")
def detach(self) -> Rotation:
"""
Returns a copy of the Rotation whose underlying Tensor has been detached from its torch graph.
Returns:
A copy of the Rotation whose underlying Tensor has been detached from its torch graph
"""
if self._rot_mats is not None:
return Rotation(rot_mats=self._rot_mats.detach(), quats=None)
elif self._quats is not None:
return Rotation(
rot_mats=None,
quats=self._quats.detach(),
normalize_quats=False,
)
else:
raise ValueError("Both rotations are None")
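To make the Rotation API above concrete, here is a short illustrative sketch (not part of the original file) exercising construction, application to points, format switching, composition and inversion.
import torch
from transformers.models.esm.openfold_utils.rigid_utils import Rotation

r = Rotation.identity((2, 4), requires_grad=False, fmt="quat")
print(r.shape)                                 # torch.Size([2, 4])
pts = torch.randn(2, 4, 3)
print(torch.allclose(r.apply(pts), pts))       # identity rotation leaves points unchanged
r_mat = Rotation(rot_mats=r.get_rot_mats(), quats=None)   # same rotation, matrix format
composed = r_mat.compose_r(r_mat.invert())                # R @ R^{-1} = I
print(torch.allclose(composed.get_rot_mats(), torch.eye(3).expand(2, 4, 3, 3)))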
class Rigid:
"""
A class representing a rigid transformation. Little more than a wrapper around two objects: a Rotation object and a
[*, 3] translation. Designed to behave approximately like a single torch tensor with the shape of the shared batch
dimensions of its component parts.
"""
def __init__(self, rots: Optional[Rotation], trans: Optional[torch.Tensor]):
"""
Args:
rots: A [*, 3, 3] rotation tensor
trans: A corresponding [*, 3] translation tensor
"""
batch_dims, dtype, device, requires_grad = None, None, None, None
if trans is not None:
batch_dims = trans.shape[:-1]
dtype = trans.dtype
device = trans.device
requires_grad = trans.requires_grad
elif rots is not None:
batch_dims = rots.shape
dtype = rots.dtype
device = rots.device
requires_grad = rots.requires_grad
else:
raise ValueError("At least one input argument must be specified")
if rots is None:
rots = Rotation.identity(
batch_dims,
dtype,
device,
requires_grad,
)
elif trans is None:
trans = identity_trans(
batch_dims,
dtype,
device,
requires_grad,
)
assert rots is not None
assert trans is not None
if (rots.shape != trans.shape[:-1]) or (rots.device != trans.device):
raise ValueError("Rots and trans incompatible")
trans = trans.to(dtype=torch.float32)
self._rots = rots
self._trans = trans
@staticmethod
def identity(
shape: Tuple[int, ...],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
fmt: str = "quat",
) -> Rigid:
"""
Constructs an identity transformation.
Args:
shape:
The desired shape
dtype:
The dtype of both internal tensors
device:
The device of both internal tensors
requires_grad:
Whether grad should be enabled for the internal tensors
Returns:
The identity transformation
"""
return Rigid(
Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt),
identity_trans(shape, dtype, device, requires_grad),
)
def __getitem__(self, index: Any) -> Rigid:
"""
Indexes the affine transformation with PyTorch-style indices. The index is applied to the shared dimensions of
both the rotation and the translation.
E.g.::
r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None)
t = Rigid(r, torch.rand(10, 10, 3))
indexed = t[3, 4:6]
assert(indexed.shape == (2,))
assert(indexed.get_rots().shape == (2,))
assert(indexed.get_trans().shape == (2, 3))
Args:
index: A standard torch tensor index. E.g. 8, (10, None, 3),
or (3, slice(0, 1, None))
Returns:
The indexed tensor
"""
if type(index) != tuple:
index = (index,)
return Rigid(
self._rots[index],
self._trans[index + (slice(None),)],
)
def __mul__(self, right: torch.Tensor) -> Rigid:
"""
Pointwise left multiplication of the transformation with a tensor. Can be used to e.g. mask the Rigid.
Args:
right:
The tensor multiplicand
Returns:
The product
"""
if not (isinstance(right, torch.Tensor)):
raise TypeError("The other multiplicand must be a Tensor")
new_rots = self._rots * right
new_trans = self._trans * right[..., None]
return Rigid(new_rots, new_trans)
def __rmul__(self, left: torch.Tensor) -> Rigid:
"""
Reverse pointwise multiplication of the transformation with a tensor.
Args:
left:
The left multiplicand
Returns:
The product
"""
return self.__mul__(left)
@property
def shape(self) -> torch.Size:
"""
Returns the shape of the shared dimensions of the rotation and the translation.
Returns:
The shape of the transformation
"""
return self._trans.shape[:-1]
@property
def device(self) -> torch.device:
"""
Returns the device on which the Rigid's tensors are located.
Returns:
The device on which the Rigid's tensors are located
"""
return self._trans.device
def get_rots(self) -> Rotation:
"""
Getter for the rotation.
Returns:
The rotation object
"""
return self._rots
def get_trans(self) -> torch.Tensor:
"""
Getter for the translation.
Returns:
The stored translation
"""
return self._trans
def compose_q_update_vec(self, q_update_vec: torch.Tensor) -> Rigid:
"""
Composes the transformation with a quaternion update vector of shape [*, 6], whose first three columns
represent the x, y, and z values of a quaternion of the form (1, x, y, z) and whose last three columns
represent a 3D translation.
Args:
q_update_vec: The quaternion update vector.
Returns:
The composed transformation.
"""
q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:]
new_rots = self._rots.compose_q_update_vec(q_vec)
trans_update = self._rots.apply(t_vec)
new_translation = self._trans + trans_update
return Rigid(new_rots, new_translation)
def compose(self, r: Rigid) -> Rigid:
"""
Composes the current rigid object with another.
Args:
r:
Another Rigid object
Returns:
The composition of the two transformations
"""
new_rot = self._rots.compose_r(r._rots)
new_trans = self._rots.apply(r._trans) + self._trans
return Rigid(new_rot, new_trans)
def apply(self, pts: torch.Tensor) -> torch.Tensor:
"""
Applies the transformation to a coordinate tensor.
Args:
pts: A [*, 3] coordinate tensor.
Returns:
The transformed points.
"""
rotated = self._rots.apply(pts)
return rotated + self._trans
def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
"""
Applies the inverse of the transformation to a coordinate tensor.
Args:
pts: A [*, 3] coordinate tensor
Returns:
The transformed points.
"""
pts = pts - self._trans
return self._rots.invert_apply(pts)
def invert(self) -> Rigid:
"""
Inverts the transformation.
Returns:
The inverse transformation.
"""
rot_inv = self._rots.invert()
trn_inv = rot_inv.apply(self._trans)
return Rigid(rot_inv, -1 * trn_inv)
def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rigid:
"""
Apply a Tensor -> Tensor function to underlying translation and rotation tensors, mapping over the
translation/rotation dimensions respectively.
Args:
fn:
A Tensor -> Tensor function to be mapped over the Rigid
Returns:
The transformed Rigid object
"""
new_rots = self._rots.map_tensor_fn(fn)
new_trans = torch.stack(list(map(fn, torch.unbind(self._trans, dim=-1))), dim=-1)
return Rigid(new_rots, new_trans)
def to_tensor_4x4(self) -> torch.Tensor:
"""
Converts a transformation to a homogenous transformation tensor.
Returns:
A [*, 4, 4] homogenous transformation tensor
"""
tensor = self._trans.new_zeros((*self.shape, 4, 4))
tensor[..., :3, :3] = self._rots.get_rot_mats()
tensor[..., :3, 3] = self._trans
tensor[..., 3, 3] = 1
return tensor
@staticmethod
def from_tensor_4x4(t: torch.Tensor) -> Rigid:
"""
Constructs a transformation from a homogenous transformation tensor.
Args:
t: [*, 4, 4] homogenous transformation tensor
Returns:
T object with shape [*]
"""
if t.shape[-2:] != (4, 4):
raise ValueError("Incorrectly shaped input tensor")
rots = Rotation(rot_mats=t[..., :3, :3], quats=None)
trans = t[..., :3, 3]
return Rigid(rots, trans)
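An illustrative round trip (not part of the original file) through the 4x4 homogeneous representation defined by to_tensor_4x4 / from_tensor_4x4.
import torch
from transformers.models.esm.openfold_utils.rigid_utils import Rigid, Rotation

rigid = Rigid(
    Rotation(rot_mats=torch.eye(3).expand(5, 3, 3), quats=None),
    torch.randn(5, 3),
)
mat4 = rigid.to_tensor_4x4()                 # [5, 4, 4], last row is (0, 0, 0, 1)
recovered = Rigid.from_tensor_4x4(mat4)
print(torch.allclose(recovered.get_trans(), rigid.get_trans()))  # True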
def to_tensor_7(self) -> torch.Tensor:
"""
Converts a transformation to a tensor with 7 final columns, four for the quaternion followed by three for the
translation.
Returns:
A [*, 7] tensor representation of the transformation
"""
tensor = self._trans.new_zeros((*self.shape, 7))
tensor[..., :4] = self._rots.get_quats()
tensor[..., 4:] = self._trans
return tensor
@staticmethod
def from_tensor_7(t: torch.Tensor, normalize_quats: bool = False) -> Rigid:
if t.shape[-1] != 7:
raise ValueError("Incorrectly shaped input tensor")
quats, trans = t[..., :4], t[..., 4:]
rots = Rotation(rot_mats=None, quats=quats, normalize_quats=normalize_quats)
return Rigid(rots, trans)
@staticmethod
def from_3_points(
p_neg_x_axis: torch.Tensor, origin: torch.Tensor, p_xy_plane: torch.Tensor, eps: float = 1e-8
) -> Rigid:
"""
Implements algorithm 21. Constructs transformations from sets of 3 points using the Gram-Schmidt algorithm.
Args:
p_neg_x_axis: [*, 3] coordinates
Coordinates of points defining the negative x-axis direction
origin: [*, 3] coordinates used as frame origins
Coordinates of points defining the origin of the frame
p_xy_plane: [*, 3] coordinates
Coordinates of points defining the xy-plane orientation
eps: Small epsilon value
Small value added to avoid division by zero
Returns:
A transformation object of shape [*]
"""
p_neg_x_axis_unbound = torch.unbind(p_neg_x_axis, dim=-1)
origin_unbound = torch.unbind(origin, dim=-1)
p_xy_plane_unbound = torch.unbind(p_xy_plane, dim=-1)
e0 = [c1 - c2 for c1, c2 in zip(origin_unbound, p_neg_x_axis_unbound)]
e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane_unbound, origin_unbound)]
denom = torch.sqrt(sum(c * c for c in e0) + eps * torch.ones_like(e0[0]))
e0 = [c / denom for c in e0]
dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
denom = torch.sqrt(sum((c * c for c in e1)) + eps * torch.ones_like(e1[0]))
e1 = [c / denom for c in e1]
e2 = [
e0[1] * e1[2] - e0[2] * e1[1],
e0[2] * e1[0] - e0[0] * e1[2],
e0[0] * e1[1] - e0[1] * e1[0],
]
rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
rots = rots.reshape(rots.shape[:-1] + (3, 3))
rot_obj = Rotation(rot_mats=rots, quats=None)
return Rigid(rot_obj, torch.stack(origin_unbound, dim=-1))
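The following sketch (illustrative only, not part of the original file) shows from_3_points building a batch of frames from three reference points per example; the second argument becomes the frame origin, i.e. the translation of the returned Rigid.
import torch
from transformers.models.esm.openfold_utils.rigid_utils import Rigid

p_neg_x = torch.randn(10, 3)
origin = torch.randn(10, 3)
p_xy = torch.randn(10, 3)
frames = Rigid.from_3_points(p_neg_x, origin, p_xy)
print(frames.shape)                                  # torch.Size([10])
print(torch.allclose(frames.get_trans(), origin))    # the origin is the translation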
def unsqueeze(self, dim: int) -> Rigid:
"""
Analogous to torch.unsqueeze. The dimension is relative to the shared dimensions of the rotation/translation.
Args:
dim: A positive or negative dimension index.
Returns:
The unsqueezed transformation.
"""
if dim >= len(self.shape):
raise ValueError("Invalid dimension")
rots = self._rots.unsqueeze(dim)
trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1)
return Rigid(rots, trans)
@staticmethod
def cat(ts: Sequence[Rigid], dim: int) -> Rigid:
"""
Concatenates transformations along a new dimension.
Args:
ts:
A list of T objects
dim:
The dimension along which the transformations should be concatenated
Returns:
A concatenated transformation object
"""
rots = Rotation.cat([t._rots for t in ts], dim)
trans = torch.cat([t._trans for t in ts], dim=dim if dim >= 0 else dim - 1)
return Rigid(rots, trans)
def apply_rot_fn(self, fn: Callable[[Rotation], Rotation]) -> Rigid:
"""
Applies a Rotation -> Rotation function to the stored rotation object.
Args:
fn: A function of type Rotation -> Rotation
Returns:
A transformation object with a transformed rotation.
"""
return Rigid(fn(self._rots), self._trans)
def apply_trans_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rigid:
"""
Applies a Tensor -> Tensor function to the stored translation.
Args:
fn:
A function of type Tensor -> Tensor to be applied to the translation
Returns:
A transformation object with a transformed translation.
"""
return Rigid(self._rots, fn(self._trans))
def scale_translation(self, trans_scale_factor: float) -> Rigid:
"""
Scales the translation by a constant factor.
Args:
trans_scale_factor:
The constant factor
Returns:
A transformation object with a scaled translation.
"""
return self.apply_trans_fn(lambda t: t * trans_scale_factor)
def stop_rot_gradient(self) -> Rigid:
"""
Detaches the underlying rotation object
Returns:
A transformation object with detached rotations
"""
return self.apply_rot_fn(lambda r: r.detach())
@staticmethod
def make_transform_from_reference(
n_xyz: torch.Tensor, ca_xyz: torch.Tensor, c_xyz: torch.Tensor, eps: float = 1e-20
) -> Rigid:
"""
Constructs a transformation object based on reference points.
Args:
n_xyz:
Tensor representing N atom coordinates
ca_xyz:
Tensor representing C-alpha atom coordinates
c_xyz:
Tensor representing C atom coordinates
eps:
Small value to avoid division by zero (default: 1e-20)
Returns:
A transformation object initialized with the given reference points.
"""
) -> Rigid:
"""
Returns a transformation object from reference coordinates.
Note that this method does not take care of symmetries. If you provide the atom positions in the non-standard
way, the N atom will end up not at [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
need to take care of such cases in your code.
Args:
n_xyz: A [*, 3] tensor of nitrogen xyz coordinates.
ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates.
c_xyz: A [*, 3] tensor of carbon xyz coordinates.
Returns:
A transformation object. After applying the translation and rotation to the reference backbone, the
coordinates will approximately equal to the input coordinates.
"""
translation = -1 * ca_xyz
n_xyz = n_xyz + translation
c_xyz = c_xyz + translation
c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
norm = torch.sqrt(eps + c_x**2 + c_y**2)
sin_c1 = -c_y / norm
cos_c1 = c_x / norm
c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
c1_rots[..., 0, 0] = cos_c1
c1_rots[..., 0, 1] = -1 * sin_c1
c1_rots[..., 1, 0] = sin_c1
c1_rots[..., 1, 1] = cos_c1
c1_rots[..., 2, 2] = 1
norm = torch.sqrt(eps + c_x**2 + c_y**2 + c_z**2)
sin_c2 = c_z / norm
cos_c2 = torch.sqrt(c_x**2 + c_y**2) / norm
c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
c2_rots[..., 0, 0] = cos_c2
c2_rots[..., 0, 2] = sin_c2
c2_rots[..., 1, 1] = 1
c2_rots[..., 2, 0] = -1 * sin_c2
c2_rots[..., 2, 2] = cos_c2
c_rots = rot_matmul(c2_rots, c1_rots)
n_xyz = rot_vec_mul(c_rots, n_xyz)
_, n_y, n_z = [n_xyz[..., i] for i in range(3)]
norm = torch.sqrt(eps + n_y**2 + n_z**2)
sin_n = -n_z / norm
cos_n = n_y / norm
n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
n_rots[..., 0, 0] = 1
n_rots[..., 1, 1] = cos_n
n_rots[..., 1, 2] = -1 * sin_n
n_rots[..., 2, 1] = sin_n
n_rots[..., 2, 2] = cos_n
rots = rot_matmul(n_rots, c_rots)
rots = rots.transpose(-1, -2)
translation = -1 * translation
rot_obj = Rotation(rot_mats=rots, quats=None)
return Rigid(rot_obj, translation)
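An illustrative call (not from the original file) to make_transform_from_reference; whatever the rotation works out to be, the CA coordinates come back as the translation of the frame.
import torch
from transformers.models.esm.openfold_utils.rigid_utils import Rigid

n, ca, c = torch.randn(4, 3), torch.randn(4, 3), torch.randn(4, 3)
frame = Rigid.make_transform_from_reference(n_xyz=n, ca_xyz=ca, c_xyz=c)
print(frame.shape)                              # torch.Size([4])
print(torch.allclose(frame.get_trans(), ca))    # translation equals the CA positions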
def cuda(self) -> Rigid:
"""
Moves the transformation object to GPU memory
Returns:
A version of the transformation on GPU
"""
return Rigid(self._rots.cuda(), self._trans.cuda())
.\models\esm\openfold_utils\tensor_utils.py
from functools import partial
from typing import Any, Callable, Dict, List, Type, TypeVar, Union, overload
import torch
import torch.nn as nn
import torch.types
def add(m1: torch.Tensor, m2: torch.Tensor, inplace: bool) -> torch.Tensor:
if not inplace:
m1 = m1 + m2
else:
m1 += m2
return m1
def permute_final_dims(tensor: torch.Tensor, inds: List[int]) -> torch.Tensor:
zero_index = -1 * len(inds)
first_inds = list(range(len(tensor.shape[:zero_index])))
return tensor.permute(first_inds + [zero_index + i for i in inds])
def flatten_final_dims(t: torch.Tensor, no_dims: int) -> torch.Tensor:
return t.reshape(t.shape[:-no_dims] + (-1,))
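A quick shape illustration (not part of the original file) for the two reshaping helpers above.
import torch
from transformers.models.esm.openfold_utils.tensor_utils import flatten_final_dims, permute_final_dims

x = torch.randn(2, 7, 16, 32)                  # e.g. [batch, seq, heads, dim]
print(permute_final_dims(x, [2, 0, 1]).shape)  # torch.Size([2, 32, 7, 16])
print(flatten_final_dims(x, 2).shape)          # torch.Size([2, 7, 512])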
def masked_mean(mask: torch.Tensor, value: torch.Tensor, dim: int, eps: float = 1e-4) -> torch.Tensor:
mask = mask.expand(*value.shape)
return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
def pts_to_distogram(
pts: torch.Tensor, min_bin: torch.types.Number = 2.3125, max_bin: torch.types.Number = 21.6875, no_bins: int = 64
) -> torch.Tensor:
boundaries = torch.linspace(min_bin, max_bin, no_bins - 1, device=pts.device)
dists = torch.sqrt(torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1))
return torch.bucketize(dists, boundaries)
def dict_multimap(fn: Callable[[list], Any], dicts: List[dict]) -> dict:
first = dicts[0]
new_dict = {}
for k, v in first.items():
all_v = [d[k] for d in dicts]
if isinstance(v, dict):
new_dict[k] = dict_multimap(fn, all_v)
else:
new_dict[k] = fn(all_v)
return new_dict
def one_hot(x: torch.Tensor, v_bins: torch.Tensor) -> torch.Tensor:
reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
diffs = x[..., None] - reshaped_bins
am = torch.argmin(torch.abs(diffs), dim=-1)
return nn.functional.one_hot(am, num_classes=len(v_bins)).float()
def batched_gather(data: torch.Tensor, inds: torch.Tensor, dim: int = 0, no_batch_dims: int = 0) -> torch.Tensor:
ranges: List[Union[slice, torch.Tensor]] = []
for i, s in enumerate(data.shape[:no_batch_dims]):
r = torch.arange(s)
r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
ranges.append(r)
remaining_dims: List[Union[slice, torch.Tensor]] = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
ranges.extend(remaining_dims)
return data[tuple(ranges)]
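An illustrative use of batched_gather (not from the original file): select per-batch positions along a non-batch dimension.
import torch
from transformers.models.esm.openfold_utils.tensor_utils import batched_gather

data = torch.arange(2 * 5 * 3).view(2, 5, 3)   # [batch=2, length=5, feat=3]
inds = torch.tensor([[0, 4], [1, 2]])          # two positions per batch element
out = batched_gather(data, inds, dim=1, no_batch_dims=1)
print(out.shape)                               # torch.Size([2, 2, 3])
print(torch.equal(out[0, 1], data[0, 4]))      # True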
T = TypeVar("T")
def dict_map(
fn: Callable[[T], Any], dic: Dict[Any, Union[dict, list, tuple, T]], leaf_type: Type[T]
) -> Dict[Any, Union[dict, list, tuple, Any]]:
new_dict: Dict[Any, Union[dict, list, tuple, Any]] = {}
for k, v in dic.items():
if isinstance(v, dict):
new_dict[k] = dict_map(fn, v, leaf_type)
else:
new_dict[k] = tree_map(fn, v, leaf_type)
return new_dict
@overload
def tree_map(fn: Callable[[T], Any], tree: T, leaf_type: Type[T]) -> Any:
...
@overload
def tree_map(fn: Callable[[T], Any], tree: dict, leaf_type: Type[T]) -> dict:
...
@overload
def tree_map(fn: Callable[[T], Any], tree: list, leaf_type: Type[T]) -> list:
...
@overload
def tree_map(fn: Callable[[T], Any], tree: tuple, leaf_type: Type[T]) -> tuple:
...
def tree_map(fn, tree, leaf_type):
if isinstance(tree, dict):
return dict_map(fn, tree, leaf_type)
elif isinstance(tree, list):
return [tree_map(fn, x, leaf_type) for x in tree]
elif isinstance(tree, tuple):
return tuple(tree_map(fn, x, leaf_type) for x in tree)
elif isinstance(tree, leaf_type):
return fn(tree)
else:
print(type(tree))
raise ValueError("Not supported")
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
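As an illustration (not part of the original file), tensor_tree_map applies a function to every tensor leaf of an arbitrarily nested dict/list/tuple while preserving the container structure.
import torch
from transformers.models.esm.openfold_utils.tensor_utils import tensor_tree_map

batch = {"x": torch.zeros(3), "nested": [torch.ones(2), (torch.ones(1),)]}
doubled = tensor_tree_map(lambda t: t * 2, batch)
print(doubled["nested"][0])         # tensor([2., 2.])
print(type(doubled["nested"][1]))   # <class 'tuple'> -- containers keep their types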
.\models\esm\openfold_utils\__init__.py
from .chunk_utils import chunk_layer
from .data_transforms import make_atom14_masks
from .feats import atom14_to_atom37, frames_and_literature_positions_to_atom14_pos, torsion_angles_to_frames
from .loss import compute_predicted_aligned_error, compute_tm
from .protein import Protein as OFProtein
from .protein import to_pdb
from .rigid_utils import Rigid, Rotation
from .tensor_utils import dict_multimap, flatten_final_dims, permute_final_dims
.\models\esm\tokenization_esm.py
"""Tokenization classes for ESM."""
import os
from typing import List, Optional
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"facebook/esm2_t6_8M_UR50D": "https://huggingface.co/facebook/esm2_t6_8M_UR50D/resolve/main/vocab.txt",
"facebook/esm2_t12_35M_UR50D": "https://huggingface.co/facebook/esm2_t12_35M_UR50D/resolve/main/vocab.txt",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"facebook/esm2_t6_8M_UR50D": 1024,
"facebook/esm2_t12_35M_UR50D": 1024,
}
def load_vocab_file(vocab_file):
with open(vocab_file, "r") as f:
lines = f.read().splitlines()
return [l.strip() for l in lines]
class EsmTokenizer(PreTrainedTokenizer):
"""
Constructs an ESM tokenizer.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file,
unk_token="<unk>",
cls_token="<cls>",
pad_token="<pad>",
mask_token="<mask>",
eos_token="<eos>",
**kwargs,
):
self.all_tokens = load_vocab_file(vocab_file)
self._id_to_token = dict(enumerate(self.all_tokens))
self._token_to_id = {tok: ind for ind, tok in enumerate(self.all_tokens)}
super().__init__(
unk_token=unk_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
eos_token=eos_token,
**kwargs,
)
self.unique_no_split_tokens = self.all_tokens
self._update_trie(self.unique_no_split_tokens)
def _convert_id_to_token(self, index: int) -> str:
return self._id_to_token.get(index, self.unk_token)
def _convert_token_to_id(self, token: str) -> int:
return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))
def _tokenize(self, text, **kwargs):
return text.split()
def get_vocab(self):
base_vocab = self._token_to_id.copy()
base_vocab.update(self.added_tokens_encoder)
return base_vocab
def token_to_id(self, token: str) -> int:
return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))
def id_to_token(self, index: int) -> str:
return self._id_to_token.get(index, self.unk_token)
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
cls = [self.cls_token_id]
sep = [self.eos_token_id]
if token_ids_1 is None:
if self.eos_token_id is None:
return cls + token_ids_0
else:
return cls + token_ids_0 + sep
elif self.eos_token_id is None:
raise ValueError("Cannot tokenize multiple sequences when EOS token is not set!")
return cls + token_ids_0 + sep + token_ids_1 + sep
def get_special_tokens_mask(
self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer's `prepare_for_model` or `encode_plus` methods.
Args:
token_ids_0 (`List[int]`):
List of ids of the first sequence.
token_ids_1 (`List[int]`, *optional*):
List of ids of the second sequence.
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError(
"You should not supply a second sequence if the provided sequence of "
"ids is already formatted with special tokens for the model."
)
return [1 if token in self.all_special_ids else 0 for token in token_ids_0]
mask = [1] + ([0] * len(token_ids_0)) + [1]
if token_ids_1 is not None:
mask += [0] * len(token_ids_1) + [1]
return mask
def save_vocabulary(self, save_directory, filename_prefix):
vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.txt")
with open(vocab_file, "w") as f:
f.write("\n".join(self.all_tokens))
return (vocab_file,)
@property
def vocab_size(self) -> int:
return len(self.all_tokens)
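A minimal illustrative usage of EsmTokenizer (downloading the vocabulary requires network access; the checkpoint name is one of those listed above).
from transformers import EsmTokenizer

tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
enc = tokenizer("MKTAYIAKQR", return_tensors="pt")   # residues are split via the token trie
print(enc["input_ids"].shape)                        # [1, sequence length + <cls>/<eos>]
print(tokenizer.convert_ids_to_tokens(enc["input_ids"][0].tolist())[:3])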
.\models\esm\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
_import_structure = {
"configuration_esm": ["ESM_PRETRAINED_CONFIG_ARCHIVE_MAP", "EsmConfig"],
"tokenization_esm": ["EsmTokenizer"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_esm"] = [
"ESM_PRETRAINED_MODEL_ARCHIVE_LIST",
"EsmForMaskedLM",
"EsmForSequenceClassification",
"EsmForTokenClassification",
"EsmModel",
"EsmPreTrainedModel",
]
_import_structure["modeling_esmfold"] = ["EsmForProteinFolding", "EsmFoldPreTrainedModel"]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_esm"] = [
"TF_ESM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFEsmForMaskedLM",
"TFEsmForSequenceClassification",
"TFEsmForTokenClassification",
"TFEsmModel",
"TFEsmPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_esm import ESM_PRETRAINED_CONFIG_ARCHIVE_MAP, EsmConfig
from .tokenization_esm import EsmTokenizer
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_esm import (
ESM_PRETRAINED_MODEL_ARCHIVE_LIST,
EsmForMaskedLM,
EsmForSequenceClassification,
EsmForTokenClassification,
EsmModel,
EsmPreTrainedModel,
)
from .modeling_esmfold import EsmFoldPreTrainedModel, EsmForProteinFolding
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_esm import (
TF_ESM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFEsmForMaskedLM,
TFEsmForSequenceClassification,
TFEsmForTokenClassification,
TFEsmModel,
TFEsmPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
.\models\falcon\configuration_falcon.py
"""
Falcon configuration
"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
FALCON_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"tiiuae/falcon-40b": "https://huggingface.co/tiiuae/falcon-40b/resolve/main/config.json",
"tiiuae/falcon-7b": "https://huggingface.co/tiiuae/falcon-7b/resolve/main/config.json",
}
class FalconConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`FalconModel`]. It is used to instantiate a Falcon
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the
[tiiuae/falcon-7b](https://huggingface.co/tiiuae/falcon-7b) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```
>>> from transformers import FalconModel, FalconConfig
>>> # Initializing a small (2-layer) Falcon configuration
>>> configuration = FalconConfig(num_hidden_layers=2)
>>> # Initializing a model from the small configuration
>>> model = FalconModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "falcon"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=65024,
hidden_size=4544,
num_hidden_layers=32,
num_attention_heads=71,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
use_cache=True,
hidden_dropout=0.0,
attention_dropout=0.0,
num_kv_heads=None,
alibi=False,
new_decoder_architecture=False,
multi_query=True,
parallel_attn=True,
bias=False,
max_position_embeddings=2048,
rope_theta=10000.0,
rope_scaling=None,
bos_token_id=11,
eos_token_id=11,
**kwargs,
):
self.vocab_size = vocab_size
n_embed = kwargs.pop("n_embed", None)
self.hidden_size = hidden_size if n_embed is None else n_embed
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.use_cache = use_cache
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.num_kv_heads = num_attention_heads if num_kv_heads is None else num_kv_heads
self.alibi = alibi
self.new_decoder_architecture = new_decoder_architecture
self.multi_query = multi_query
self.parallel_attn = parallel_attn
self.bias = bias
self.max_position_embeddings = max_position_embeddings
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@property
def head_dim(self):
return self.hidden_size // self.num_attention_heads
@property
def rotary(self):
return not self.alibi
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if self.alibi:
raise ValueError("`rope_scaling` is not supported when `alibi` is `True`.")
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
.\models\falcon\convert_custom_code_checkpoint.py
import json
from argparse import ArgumentParser
from pathlib import Path
"""
This script converts Falcon custom code checkpoints to modern Falcon checkpoints that use code in the Transformers
library. After conversion, performance (especially for generation) should improve and the checkpoint can be loaded
without needing trust_remote_code=True.
"""
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument(
"--checkpoint_dir",
type=Path,
required=True,
help="Directory containing a custom code checkpoint to convert to a modern Falcon checkpoint.",
)
args = parser.parse_args()
if not args.checkpoint_dir.is_dir():
raise ValueError("--checkpoint_dir argument should be a directory!")
if (
not (args.checkpoint_dir / "configuration_RW.py").is_file()
or not (args.checkpoint_dir / "modelling_RW.py").is_file()
):
raise ValueError(
"The model directory should contain configuration_RW.py and modelling_RW.py files! Are you sure this is a custom code checkpoint?"
)
(args.checkpoint_dir / "configuration_RW.py").unlink()
(args.checkpoint_dir / "modelling_RW.py").unlink()
config = args.checkpoint_dir / "config.json"
text = config.read_text()
text = text.replace("RWForCausalLM", "FalconForCausalLM")
text = text.replace("RefinedWebModel", "falcon")
text = text.replace("RefinedWeb", "falcon")
json_config = json.loads(text)
del json_config["auto_map"]
if "n_head" in json_config:
json_config["num_attention_heads"] = json_config.pop("n_head")
if "n_layer" in json_config:
json_config["num_hidden_layers"] = json_config.pop("n_layer")
if "n_head_kv" in json_config:
json_config["num_kv_heads"] = json_config.pop("n_head_kv")
json_config["new_decoder_architecture"] = True
else:
json_config["new_decoder_architecture"] = False
bos_token_id = json_config.get("bos_token_id", 1)
eos_token_id = json_config.get("eos_token_id", 2)
config.unlink()
config.write_text(json.dumps(json_config, indent=2, sort_keys=True))
tokenizer_config = args.checkpoint_dir / "tokenizer_config.json"
if tokenizer_config.is_file():
text = tokenizer_config.read_text()
json_config = json.loads(text)
if json_config["tokenizer_class"] == "PreTrainedTokenizerFast":
json_config["model_input_names"] = ["input_ids", "attention_mask"]
tokenizer_config.unlink()
tokenizer_config.write_text(json.dumps(json_config, indent=2, sort_keys=True))
generation_config_path = args.checkpoint_dir / "generation_config.json"
generation_dict = {
"_from_model_config": True,
"bos_token_id": bos_token_id,
"eos_token_id": eos_token_id,
"transformers_version": "4.33.0.dev0",
}
generation_config_path.write_text(json.dumps(generation_dict, indent=2, sort_keys=True))
print("Done! Please double-check that the new checkpoint works as expected.")
.\models\falcon\modeling_falcon.py
"""PyTorch Falcon model."""
import math
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from torch.nn import functional as F
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import is_torch_greater_or_equal_than_2_0
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
)
from .configuration_falcon import FalconConfig
if TYPE_CHECKING:
from ...configuration_utils import PretrainedConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
logger = logging.get_logger(__name__)
FALCON_PRETRAINED_MODEL_ARCHIVE_LIST = [
"tiiuae/falcon-40b",
"tiiuae/falcon-40b-instruct",
"tiiuae/falcon-7b",
"tiiuae/falcon-7b-instruct",
"tiiuae/falcon-rw-7b",
"tiiuae/falcon-rw-1b",
]
_CHECKPOINT_FOR_DOC = "Rocketknight1/falcon-rw-1b"
_CONFIG_FOR_DOC = "FalconConfig"
class FalconLinear(nn.Linear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
hidden_states = input @ self.weight.T
if self.bias is None:
return hidden_states
return hidden_states + self.bias
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""对查询张量和键张量应用旋转位置嵌入。
Args:
q (`torch.Tensor`): 查询张量。
k (`torch.Tensor`): 键张量。
cos (`torch.Tensor`): 旋转嵌入的余弦部分。
sin (`torch.Tensor`): 旋转嵌入的正弦部分。
position_ids (`torch.Tensor`):
对应于查询和键张量的标记位置索引。例如,当使用KV缓存时,可以传递偏移的位置ID。
unsqueeze_dim (`int`, *optional*, defaults to 1):
'unsqueeze_dim' 参数指定沿其进行展开的维度,以便将 cos[position_ids] 和 sin[position_ids] 正确广播到 q 和 k 的维度。
例如,注意 cos[position_ids] 和 sin[position_ids] 的形状为 [batch_size, seq_len, head_dim]。然后,
如果 q 和 k 的形状为 [batch_size, heads, seq_len, head_dim],设置 unsqueeze_dim=1 使得 cos[position_ids] 和 sin[position_ids]
可以广播到 q 和 k 的形状。类似地,如果 q 和 k 的形状为 [batch_size, seq_len, heads, head_dim],则设置 unsqueeze_dim=2。
Returns:
`tuple(torch.Tensor)`: 包含使用旋转位置嵌入旋转后的查询和键张量。
"""
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def _get_unpad_data(attention_mask):
"""获取未填充数据。
Args:
attention_mask (`torch.Tensor`): 注意力掩码张量。
Returns:
`tuple`: 包含以下三个元素的元组:
- `torch.Tensor`: 指示非填充位置索引的张量。
- `torch.Tensor`: 指示累积序列长度的张量,用于填充。
- `int`: 批次中最大序列长度。
"""
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
class FalconRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, seq_len=None):
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
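The snippet below (illustrative only) wires FalconRotaryEmbedding together with apply_rotary_pos_emb to show the expected shapes; both names are module-level in the file shown here.
import torch
from transformers.models.falcon.modeling_falcon import FalconRotaryEmbedding, apply_rotary_pos_emb

batch, heads, seq, head_dim = 1, 4, 6, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
rotary = FalconRotaryEmbedding(head_dim)
cos, sin = rotary(q, seq_len=seq)                       # each of shape [seq, head_dim]
position_ids = torch.arange(seq)[None, :]               # [batch, seq]
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print(q_rot.shape, k_rot.shape)                         # both [1, 4, 6, 8]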
class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
"""使用线性缩放扩展的FalconRotaryEmbedding。由Reddit用户/u/kaiokendev贡献"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
t = t / self.scaling_factor
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
"""使用动态NTK缩放扩展的FalconRotaryEmbedding。由Reddit用户/u/bloc97和/u/emozilla贡献"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
if seq_len > self.max_position_embeddings:
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
batch_size, seq_length = attention_mask.shape
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
)
powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
alibi = slopes[..., None].bfloat16() * arange_tensor
return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
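An illustrative check (not from the original file) of the ALiBi bias shape produced by build_alibi_tensor.
import torch
from transformers.models.falcon.modeling_falcon import build_alibi_tensor

attention_mask = torch.ones(2, 10)                            # [batch, seq]
alibi = build_alibi_tensor(attention_mask, num_heads=8, dtype=torch.float32)
print(alibi.shape)                                            # torch.Size([16, 1, 10]) == [batch * heads, 1, seq]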
def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
"""
Dropout add function
Args:
x (`torch.tensor`, *required*):
input tensor
residual (`torch.tensor`, *required*):
residual tensor
prob (`float`, *required*):
dropout probability
training (`bool`, *required*):
training mode
"""
out = F.dropout(x, p=prob, training=training)
out = residual + out
return out
class FalconAttention(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.split_size = self.hidden_size
self.hidden_dropout = config.hidden_dropout
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self._use_sdpa = config._attn_implementation == "sdpa"
if self.head_dim * self.num_heads != self.hidden_size:
raise ValueError(
f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
f" {self.num_heads})."
)
if config.rotary:
self._init_rope()
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
self.beta = self.inv_norm_factor
if config.new_decoder_architecture:
qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
elif config.multi_query:
qkv_out_dim = self.hidden_size + 2 * self.head_dim
else:
qkv_out_dim = 3 * self.hidden_size
self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
self.new_decoder_architecture = config.new_decoder_architecture
self.multi_query = config.multi_query
self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
self.attention_dropout = nn.Dropout(config.attention_dropout)
self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = FalconRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = FalconLinearScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "dynamic":
self.rotary_emb = FalconDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv`
Args:
fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
Returns:
query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
value: [batch_size, seq_length, num_heads, head_dim]
"""
if self.new_decoder_architecture:
batch, seq_len, _ = fused_qkv.shape
qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv_heads + 2, self.head_dim)
query = qkv[:, :, :, :-2]
key = qkv[:, :, :, [-2]]
value = qkv[:, :, :, [-1]]
key = torch.broadcast_to(key, query.shape)
value = torch.broadcast_to(value, query.shape)
query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
return query, key, value
elif not self.multi_query:
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
else:
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
"""
Merge heads together over the last dimension
Args:
x (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
Returns:
torch.tensor: [batch_size, seq_length, num_heads * head_dim]
"""
batch_size_and_num_heads, seq_length, _ = x.shape
batch_size = batch_size_and_num_heads // self.num_heads
x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
x = x.permute(0, 2, 1, 3)
return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
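# Editor's sketch (not part of the Falcon source): shape bookkeeping for the fused-QKV layouts
# handled by `_split_heads`, with small hypothetical sizes.
import torch

batch, seq, num_heads, num_kv_heads, head_dim = 2, 5, 8, 2, 16

# new_decoder_architecture: each group of (num_heads // num_kv_heads) query heads shares one K and one V head.
qkv = torch.randn(batch, seq, (num_kv_heads * 2 + num_heads) * head_dim)
qkv = qkv.view(batch, seq, -1, num_heads // num_kv_heads + 2, head_dim)
q, k, v = qkv[:, :, :, :-2], qkv[:, :, :, [-2]], qkv[:, :, :, [-1]]
k, v = torch.broadcast_to(k, q.shape), torch.broadcast_to(v, q.shape)
q, k, v = (x.flatten(2, 3) for x in (q, k, v))
assert q.shape == k.shape == v.shape == (batch, seq, num_heads, head_dim)

# multi_query: a single K head and a single V head sit after the query heads.
fused = torch.randn(batch, seq, num_heads * head_dim + 2 * head_dim).view(batch, seq, num_heads + 2, head_dim)
assert fused[..., :-2, :].shape == (batch, seq, num_heads, head_dim)   # queries
assert fused[..., [-2], :].shape == (batch, seq, 1, head_dim)          # shared key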
class FalconFlashAttention2(FalconAttention):
    # Falcon attention that calls the flash-attention kernels; the weights are identical to FalconAttention.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Flash Attention < 2.1 uses a top-left aligned causal mask, while >= 2.1 aligns it bottom-right.
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
    Attention dropout probability.
softmax_scale (`float`, *optional*):
    The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
causal = self.is_causal and query_length != 1
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
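# Editor's sketch (not part of the Falcon source): the computation the flash-attention call above
# performs, written with PyTorch's built-in SDPA for readability (no unpadding, torch>=2.0 assumed).
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 5, 16)   # (batch, heads, query_len, head_dim)
k = torch.randn(2, 8, 5, 16)
v = torch.randn(2, 8, 5, 16)
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
assert out.shape == q.shape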
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
)
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
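# Editor's sketch (not part of the Falcon source): the unpadding metadata consumed above.
# `_get_unpad_data` is a library helper; the lines below mirror its typical behaviour, assuming
# the mask is 1 for real tokens and 0 for padding.
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])                  # batch of 2, padded to length 4
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                      # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # flat positions of real tokens
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 3, 5])
max_seqlen = int(seqlens.max())                                              # 3
# flash_attn_varlen_func then works on the flattened, un-padded tokens plus cu_seqlens / max_seqlen.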
class FalconMLP(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
hidden_size = config.hidden_size
self.dense_h_to_4h = FalconLinear(hidden_size, 4 * hidden_size, bias=config.bias)
self.act = nn.GELU()
self.dense_4h_to_h = FalconLinear(4 * hidden_size, hidden_size, bias=config.bias)
self.hidden_dropout = config.hidden_dropout
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.act(self.dense_h_to_4h(x))
x = self.dense_4h_to_h(x)
return x
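# Editor's sketch (not part of the Falcon source): FalconMLP is a plain 4x expansion with GELU;
# an equivalent stack for illustration.
import torch
from torch import nn

hidden = 32
mlp = nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Linear(4 * hidden, hidden))
assert mlp(torch.randn(2, 5, hidden)).shape == (2, 5, hidden)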
FALCON_ATTENTION_CLASSES = {
"eager": FalconAttention,
"sdpa": FalconAttention,
"flash_attention_2": FalconFlashAttention2,
}
class FalconDecoderLayer(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.self_attention = FALCON_ATTENTION_CLASSES[config._attn_implementation](config)
self.mlp = FalconMLP(config)
self.hidden_dropout = config.hidden_dropout
self.config = config
if config.new_decoder_architecture:
self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
else:
self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
if not config.parallel_attn:
self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
def forward(
self,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
**kwargs,
):
if "padding_mask" in kwargs:
warnings.warn(
    "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
)
residual = hidden_states
if self.config.new_decoder_architecture:
attention_layernorm_out = self.ln_attn(hidden_states)
mlp_layernorm_out = self.ln_mlp(hidden_states)
else:
attention_layernorm_out = self.input_layernorm(hidden_states)
attn_outputs = self.self_attention(
attention_layernorm_out,
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
**kwargs,
)
attention_output = attn_outputs[0]
if not self.config.new_decoder_architecture:
if self.config.parallel_attn:
mlp_layernorm_out = attention_layernorm_out
else:
residual = dropout_add(
attention_output, residual, self.config.attention_dropout, training=self.training
)
mlp_layernorm_out = self.post_attention_layernorm(residual)
outputs = attn_outputs[1:]
mlp_output = self.mlp(mlp_layernorm_out)
if self.config.new_decoder_architecture or self.config.parallel_attn:
mlp_output += attention_output
output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
if use_cache:
outputs = (output,) + outputs
else:
outputs = (output,) + outputs[1:]
return outputs
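# Editor's sketch (not part of the Falcon source): the residual layouts implemented by
# FalconDecoderLayer, written with placeholder callables (dropout omitted for clarity).
import torch

def parallel_block(x, ln_attn, ln_mlp, attn, mlp):
    # Parallel layout (`parallel_attn` or `new_decoder_architecture`): attention and MLP both read
    # a LayerNorm of the same input and their outputs are summed onto one residual stream.
    # In the old architecture ln_attn and ln_mlp are the same module; in the new one they differ.
    return x + attn(ln_attn(x)) + mlp(ln_mlp(x))

def sequential_block(x, ln1, ln2, attn, mlp):
    # Classic GPT-style layout used when `parallel_attn=False`: the attention output is added to
    # the residual first, then a second LayerNorm feeds the MLP.
    x = x + attn(ln1(x))
    return x + mlp(ln2(x))

identity = lambda t: t
x = torch.randn(2, 5, 32)
assert parallel_block(x, identity, identity, identity, identity).shape == x.shape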
FALCON_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, etc.).
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`FalconConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
FALCON_INPUTS_DOCSTRING = r"""
"""
class FalconPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = FalconConfig
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
_no_split_modules = ["FalconDecoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
def _init_weights(self, module: nn.Module):
"""Initialize the weights."""
if isinstance(module, nn.Linear) or isinstance(module, FalconLinear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> "PretrainedConfig":
if hard_check_only:
if not is_torch_greater_or_equal_than_2_0:
raise ImportError("PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.0.")
if not is_torch_greater_or_equal_than_2_0:
return config
_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
if _is_bettertransformer:
return config
if not hard_check_only:
config._attn_implementation = "sdpa"
return config
@add_start_docstrings(
"The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.",
FALCON_START_DOCSTRING,
)
class FalconModel(FalconPreTrainedModel):
def __init__(self, config: FalconConfig):
super().__init__(config)
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.use_alibi = config.alibi
self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.gradient_checkpointing = False
self.post_init()
def get_input_embeddings(self):
return self.word_embeddings
def set_input_embeddings(self, new_embeddings: torch.Tensor):
self.word_embeddings = new_embeddings
@add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPastAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
@add_start_docstrings(
"The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).",
FALCON_START_DOCSTRING,
)
class FalconForCausalLM(FalconPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: FalconConfig):
super().__init__(config)
self.transformer = FalconModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings: torch.Tensor):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
**kwargs,
) -> dict:
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
if not self.transformer.use_alibi and attention_mask is not None and position_ids is None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
return {
"input_ids": input_ids,
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
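# Editor's sketch (not part of the Falcon source): the cumsum trick used above builds position ids
# that skip left padding, assuming a mask of 1 for real tokens and 0 for padding.
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])        # one sequence, left-padded by two tokens
position_ids = attention_mask.long().cumsum(-1) - 1     # tensor([[-1, -1, 0, 1, 2]])
position_ids.masked_fill_(attention_mask == 0, 1)       # tensor([[ 1,  1, 0, 1, 2]]); pad slots get a dummy id
# With a populated cache only the ids of the freshly appended tokens are kept:
position_ids_for_step = position_ids[:, -1:]            # tensor([[2]])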
@add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutputWithCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
    Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
    `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`; all labels set to `-100`
    are ignored (masked), and the loss is only computed for labels in `[0, ..., config.vocab_size]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
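# Editor's sketch (not part of the Falcon source): the shift performed above, so that the logits
# at position t are scored against the token at position t + 1.
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 11
lm_logits = torch.randn(1, 4, vocab_size)             # (batch, seq, vocab)
labels = torch.tensor([[3, 7, 2, 9]])                 # labels == input_ids; the model shifts internally
shift_logits = lm_logits[..., :-1, :].contiguous()    # predictions for positions 0..2
shift_labels = labels[..., 1:].contiguous()           # targets are the next tokens, positions 1..3
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))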
    def _reorder_cache(
        self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.
Output shares the same memory storage as `past`.
"""
device_to_beam_idx = {
past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
}
reordered_past = tuple(
(
layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
)
for layer_past in past
)
return reordered_past
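# Editor's sketch (not part of the Falcon source): what reordering the cache along the batch
# dimension does during beam search, shown for a single toy layer.
import torch

num_beams, heads, past_len, head_dim = 3, 2, 4, 8
layer_past = (torch.randn(num_beams, heads, past_len, head_dim),   # cached keys
              torch.randn(num_beams, heads, past_len, head_dim))   # cached values
beam_idx = torch.tensor([2, 0, 0])  # beam 0 now continues beam 2; beams 1 and 2 continue beam 0
reordered = tuple(t.index_select(0, beam_idx) for t in layer_past)
assert torch.equal(reordered[0][1], layer_past[0][0])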
@add_start_docstrings(
"""
The Falcon Model transformer with a sequence classification head on top (linear layer).
[`FalconForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-1) do.
Since it does classification on the last token, it needs to know the position of that token. If a `pad_token_id`
is defined in the configuration, it finds the last token that is not a padding token in each row. If no
`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (takes the last value in
each row of the batch). A small sketch of this last-token lookup follows this class's `__init__` below.
""",
FALCON_START_DOCSTRING,
)
class FalconForSequenceClassification(FalconPreTrainedModel):
def __init__(self, config: FalconConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.transformer = FalconModel(config)
self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
self.post_init()
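# Editor's sketch (not part of the Falcon source): how a causal-LM classifier typically locates
# the last non-padding token per row, as described in the class docstring above.
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0], [8, 9, 0, 0, 0]])
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1   # tensor([2, 1])
logits = torch.randn(2, 5, 3)                                               # (batch, seq, num_labels)
pooled_logits = logits[torch.arange(2), sequence_lengths]                   # one logits row per example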
@add_start_docstrings(
"""
Falcon Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
FALCON_START_DOCSTRING,
)
class FalconForTokenClassification(FalconPreTrainedModel):
def __init__(self, config: FalconConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.transformer = FalconModel(config)
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.post_init()
@add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
    Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.classifier(hidden_states)
loss = None
if labels is not None:
batch_size, seq_length = labels.shape
loss_fct = CrossEntropyLoss()
loss = loss_fct(
logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
)
if not return_dict:
output = (logits,) + transformer_outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@add_start_docstrings(
"""
The Falcon Model transformer with a span classification head on top for extractive question-answering tasks like
SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
FALCON_START_DOCSTRING,
)
class FalconForQuestionAnswering(FalconPreTrainedModel):
"""
Falcon model for question answering tasks, extending FalconPreTrainedModel.
Inherits from FalconPreTrainedModel and implements a transformer with a span classification head
for tasks such as SQuAD. It includes linear layers to compute `span start logits` and `span end logits`.
"""
def __init__(self, config):
"""
Initializes the FalconForQuestionAnswering model.
Args:
config (FalconConfig): Configuration object specifying the model architecture and parameters.
"""
super().__init__(config)
self.transformer = FalconModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
self.post_init()
@add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
    Labels for the position (index) of the start of the labelled span, used to compute the token classification
    loss. Positions are clamped to the length of the sequence (`sequence_length`); positions outside of the
    sequence are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
    Labels for the position (index) of the end of the labelled span, used to compute the token classification
    loss. Positions are clamped to the length of the sequence (`sequence_length`); positions outside of the
    sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
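# Editor's sketch (not part of the Falcon source): why start/end positions are clamped to
# `ignored_index` above: targets that fall outside the sequence then equal the loss's
# ignore_index and contribute nothing.
import torch
from torch.nn import CrossEntropyLoss

seq_len = 6
start_logits = torch.randn(2, seq_len)
start_positions = torch.tensor([3, 50])                    # the second answer lies outside this window
ignored_index = start_logits.size(1)                       # 6
start_positions = start_positions.clamp(0, ignored_index)  # tensor([3, 6])
loss = CrossEntropyLoss(ignore_index=ignored_index)(start_logits, start_positions)
# Only the first example contributes to the loss.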