pallatom.architecture.atom_transformers
Attributes
WINDOW_SIZE
Classes
AtomAttentionDecoder | Decodes trunk embeddings back to per-atom position updates.
AtomFeatureEncoder | Encodes per-atom reference features into per-residue embeddings.
AtomTransformer | Stack of n_blocks AtomTransformerBlocks, all operating on sparse [B, N, K, *] pairs.
AtomTransformerBlock | Single block of the AtomTransformer with a 32-residue local attention window.
LinearNoBias | Linear layer with no bias term.
Functions
build_sparse_pairs | For each atom l, collect indices of all atoms m within the 32-residue window.
Module Contents
- class pallatom.architecture.atom_transformers.AtomAttentionDecoder(c_token: int, c_pair: int, c_atom: int, c_atompair: int, n_blocks: int = 3, n_heads: int = 4, window_size: int = WINDOW_SIZE)
Bases: torch.nn.Module
Decodes trunk embeddings back to per-atom position updates.
Like the encoder, all pair tensors are sparse [B, N, K, *]. Builds its own neighbour index from tok_idx; no dense mask required.
- Parameters:
c_token – trunk single dim (s)
c_pair – trunk pair dim (z)
c_atom – atom single dim (q, c)
c_atompair – atom-pair dim (p)
n_blocks – AtomTransformer blocks (default 3)
n_heads – attention heads (default 4)
window_size – local window in residues (default 32)
- Weight shapes:
proj_s_q / proj_s_c : [c_token, c_atom] (no bias)
proj_z : [c_pair, c_atompair] (no bias)
proj_r : [c_atom, 3] (no bias)
mlp_p layers : [c_atompair, c_atompair] each (no bias)
- forward(q_skip: jaxtyping.Float[torch.Tensor, 'B N_atom c_atom'], p_skip: jaxtyping.Float[torch.Tensor, 'B N_atom K c_atompair'], c_skip: jaxtyping.Float[torch.Tensor, 'B N_atom c_atom'], c: jaxtyping.Float[torch.Tensor, 'B N_atom c_atom'], s: jaxtyping.Float[torch.Tensor, 'B N_res c_token'], z: jaxtyping.Float[torch.Tensor, 'B N_res N_res c_pair'], tok_idx: jaxtyping.Int[torch.Tensor, 'B N_atom']) → tuple[jaxtyping.Float[torch.Tensor, 'B N_atom c_atom'], jaxtyping.Float[torch.Tensor, 'B N_atom K c_atompair'], jaxtyping.Float[torch.Tensor, 'B N_atom 3'], jaxtyping.Float[torch.Tensor, 'B N_atom c_atom']]
- Parameters:
q_skip – atom skip queries [B, N_atom, c_atom]
p_skip – sparse pair skip [B, N_atom, K, c_atompair] (from encoder)
c_skip – atom skip context [B, N_atom, c_atom]
c – atom context [B, N_atom, c_atom]
s – trunk single embeds [B, N_res, c_token]
z – trunk pair embeds [B, N_res, N_res, c_pair]
tok_idx – residue index per atom [B, N_atom]
- Returns:
q : atom query embeddings [B, N_atom, c_atom]
p : sparse pair embeddings [B, N_atom, K, c_atompair]
r_update : per-atom position update [B, N_atom, 3]
c_out : updated atom context [B, N_atom, c_atom]
- Return type:
tuple of four tensors (q, p, r_update, c_out)
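A minimal sketch of the residue-to-atom direction of the decoder, assuming the trunk single s is normalised, projected with proj_s_q, copied to each atom through tok_idx and added to q_skip, and that r_update is read out through norm_q_out and proj_r. Only the weight shapes come from the documentation above; DecoderReadoutSketch and broadcast_to_atoms are hypothetical names, and the AtomTransformer pass between the two steps is omitted.

```python
import torch
import torch.nn as nn


def broadcast_to_atoms(x_res: torch.Tensor, tok_idx: torch.Tensor) -> torch.Tensor:
    """Copy per-residue rows [B, N_res, C] to per-atom rows [B, N_atom, C] via tok_idx."""
    idx = tok_idx.unsqueeze(-1).expand(-1, -1, x_res.shape[-1])  # [B, N_atom, C]
    return torch.gather(x_res, 1, idx)


class DecoderReadoutSketch(nn.Module):
    """Hypothetical readout: trunk single s -> atom queries q -> r_update."""

    def __init__(self, c_token: int, c_atom: int):
        super().__init__()
        self.norm_s_q = nn.LayerNorm(c_token)
        self.proj_s_q = nn.Linear(c_token, c_atom, bias=False)  # documented as [c_token, c_atom]
        self.norm_q_out = nn.LayerNorm(c_atom)
        self.proj_r = nn.Linear(c_atom, 3, bias=False)          # documented as [c_atom, 3]

    def forward(self, q_skip, s, tok_idx):
        # Project each residue's trunk embedding, copy it to that residue's atoms,
        # and add the encoder's skip queries.
        q = q_skip + broadcast_to_atoms(self.proj_s_q(self.norm_s_q(s)), tok_idx)
        # The real decoder runs its AtomTransformer stack on q here; omitted in this sketch.
        r_update = self.proj_r(self.norm_q_out(q))  # [B, N_atom, 3]
        return q, r_update


torch.manual_seed(0)
B, N_res, N_atom, c_token, c_atom = 2, 5, 12, 16, 8
tok_idx = torch.div(torch.arange(N_atom), 3, rounding_mode="floor").expand(B, -1)  # 3 atoms/residue
q, r_update = DecoderReadoutSketch(c_token, c_atom)(
    torch.randn(B, N_atom, c_atom), torch.randn(B, N_res, c_token), tok_idx)
print(q.shape, r_update.shape)  # torch.Size([2, 12, 8]) torch.Size([2, 12, 3])
```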
- mlp_p
- norm_q_out
- norm_s_c
- norm_s_q
- norm_z
- proj_r
- proj_s_c
- proj_s_q
- proj_z
- transformer
- window_size = 32
- class pallatom.architecture.atom_transformers.AtomFeatureEncoder(f_ref_dim: int, c_token: int, c_pair: int, c: int, d: int, m: int, n_blocks: int, n_heads: int, window_size: int = WINDOW_SIZE)
Bases: torch.nn.Module
Encodes per-atom reference features into per-residue embeddings.
All N×N pair tensors (d_lm, p_lm, z_gathered) are replaced by sparse [B, N, K, *] tensors with one slot per (atom, window neighbour) pair inside the 32-residue window; slots beyond an atom's actual neighbour count are padding and are masked out. K is the maximum number of window neighbours any atom has.
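One way to produce the sparse z_gathered tensor without an N_atom × N_atom intermediate is to look up each slot's residue pair directly. The sketch below assumes slot (l, k) takes z[b, tok_idx[b, l], tok_idx[b, neighbor_idx[l, k]]]; that rule and the helper name gather_pair_to_sparse are assumptions consistent with the documented shapes, not the module's code.

```python
import torch


def gather_pair_to_sparse(z: torch.Tensor, tok_idx: torch.Tensor,
                          neighbor_idx: torch.Tensor) -> torch.Tensor:
    """z: [B, N_res, N_res, C]; tok_idx: [B, N]; neighbor_idx: [N, K] -> [B, N, K, C].

    Slot (l, k) receives z[b, tok_idx[b, l], tok_idx[b, neighbor_idx[l, k]]].
    """
    B, N = tok_idx.shape
    K = neighbor_idx.shape[1]
    res_l = tok_idx[:, :, None].expand(B, N, K)           # residue of query atom l
    res_m = tok_idx[:, neighbor_idx]                      # residue of each neighbour m, [B, N, K]
    b = torch.arange(B, device=z.device)[:, None, None]   # broadcasts to [B, N, K]
    return z[b, res_l, res_m]                             # advanced indexing -> [B, N, K, C]
```

Padding slots (neighbor_idx = 0) pick up the features of atom 0's residue; this is harmless as long as valid_mask excludes them downstream.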
- Parameters:
f_ref_dim – per-atom f^ref feature size (ref_pos_dim + ref_element_dim)
c_token – trunk single dim (s_input)
c_pair – trunk pair dim (z_input)
c – output per-residue dim
d – atom-pair embedding dim
m – atom single embedding dim
n_blocks – AtomTransformer blocks
n_heads – AtomTransformer heads
window_size – local attention window in residues (default 32)
- Weight shapes:
proj_fref_c : [f_ref_dim, m]
proj_d_vec : [3, d]
proj_inv_sq : [1, d]
proj_v : [1, d]
proj_cl_pair / proj_cm_pair : [m, d]
proj_r_scaled : [3, m]
proj_s_init : [c_token, m]
proj_z_init : [c_pair, d]
mlp_p layers : [d, d] each
proj_agg : [m, c]
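The weight shapes proj_d_vec : [3, d], proj_inv_sq : [1, d] and proj_v : [1, d] match the pairwise reference-geometry features of the AlphaFold 3 atom-attention encoder. A sketch of those terms on sparse slots, assuming the same formula (offset vector, inverse squared distance and same-space flag, each gated by that flag); SparsePairFeatures is a hypothetical name, and the proj_cl_pair / proj_cm_pair context terms are left out.

```python
import torch
import torch.nn as nn


class SparsePairFeatures(nn.Module):
    """Hypothetical AF3-style atom-pair features computed on [B, N, K, d] sparse slots."""

    def __init__(self, d: int):
        super().__init__()
        self.proj_d_vec = nn.Linear(3, d, bias=False)   # documented as [3, d]
        self.proj_inv_sq = nn.Linear(1, d, bias=False)  # documented as [1, d]
        self.proj_v = nn.Linear(1, d, bias=False)       # documented as [1, d]

    def forward(self, ref_pos, ref_space_uid, neighbor_idx, valid_mask):
        # ref_pos: [B, N, 3]; ref_space_uid: [B, N]; neighbor_idx: [N, K]; valid_mask: [B, N, K]
        d_lm = ref_pos[:, :, None, :] - ref_pos[:, neighbor_idx]            # [B, N, K, 3]
        same_space = ref_space_uid[:, :, None] == ref_space_uid[:, neighbor_idx]
        v_lm = (same_space & valid_mask).unsqueeze(-1).to(ref_pos.dtype)    # [B, N, K, 1]
        inv_sq = 1.0 / (1.0 + d_lm.pow(2).sum(-1, keepdim=True))           # [B, N, K, 1]
        # Every term is gated by v_lm, so padding and cross-space slots are exactly 0.
        p = self.proj_d_vec(d_lm) * v_lm
        p = p + self.proj_inv_sq(inv_sq) * v_lm
        p = p + self.proj_v(v_lm) * v_lm
        return p                                                          # [B, N, K, d]
```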
- forward(ref_pos: jaxtyping.Float[torch.Tensor, 'B N_atom 3'], ref_element: jaxtyping.Float[torch.Tensor, 'B N_atom E'], ref_space_uid: jaxtyping.Int[torch.Tensor, 'B N_atom'], s_input: jaxtyping.Float[torch.Tensor, 'B N_res c_token'], z_input: jaxtyping.Float[torch.Tensor, 'B N_res N_res c_pair'], r_scaled: jaxtyping.Float[torch.Tensor, 'B N_atom 3'], tok_idx: jaxtyping.Int[torch.Tensor, 'B N_atom']) → tuple[jaxtyping.Float[torch.Tensor, 'B N_res c_res'], jaxtyping.Float[torch.Tensor, 'B N_atom c_atom'], jaxtyping.Float[torch.Tensor, 'B N_atom c_atom'], jaxtyping.Float[torch.Tensor, 'B N_atom K c_atompair'], jaxtyping.Float[torch.Tensor, 'B N_atom c_atom']]
- Parameters:
ref_pos – [B, N_atom, 3] reference atom positions
ref_element – [B, N_atom, E] element one-hot features
ref_space_uid – [B, N_atom] chain/space identifier per atom
s_input – [B, N_res, c_token] trunk single embeddings
z_input – [B, N_res, N_res, c_pair] trunk pair embeddings
r_scaled – [B, N_atom, 3] scaled noisy positions
tok_idx – [B, N_atom] residue index per atom (0-based)
- Returns (all sparse where applicable):
s_i : [B, N_res, c] aggregated residue embeddings
q_skip : [B, N_atom, m] atom skip queries (post-transformer)
c_skip : [B, N_atom, m] atom skip context
p_skip : [B, N_atom, K, d] sparse atom-pair skip embeddings
c_l : [B, N_atom, m] updated atom context
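s_i is produced with proj_agg : [m, c] and has one row per residue, so per-atom embeddings have to be pooled by tok_idx. A sketch assuming mean pooling of relu(proj_agg(q)), as in the AlphaFold 3 encoder; the pooling used here may differ, and aggregate_atoms_to_residues is a hypothetical helper.

```python
import torch
import torch.nn as nn


def aggregate_atoms_to_residues(x_atom: torch.Tensor, tok_idx: torch.Tensor,
                                n_res: int) -> torch.Tensor:
    """Average per-atom rows [B, N_atom, C] over the atoms of each residue -> [B, n_res, C]."""
    B, _, C = x_atom.shape
    idx = tok_idx.unsqueeze(-1).expand(-1, -1, C)                    # [B, N_atom, C]
    sums = x_atom.new_zeros(B, n_res, C).scatter_add_(1, idx, x_atom)
    counts = x_atom.new_zeros(B, n_res, 1).scatter_add_(
        1, tok_idx.unsqueeze(-1), torch.ones_like(x_atom[..., :1]))  # atoms per residue
    return sums / counts.clamp(min=1.0)                              # empty residues stay 0


# Hypothetical use with a bias-free projection shaped like proj_agg : [m, c].
torch.manual_seed(0)
B, N_res, N_atom, m, c = 2, 4, 10, 8, 6
proj_agg = nn.Linear(m, c, bias=False)
tok_idx = torch.randint(0, N_res, (B, N_atom))
q = torch.randn(B, N_atom, m)
s_i = aggregate_atoms_to_residues(torch.relu(proj_agg(q)), tok_idx, N_res)  # [B, N_res, c]
```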
- m
- mlp_p
- norm_s_init
- norm_z_init
- proj_agg
- proj_cl_pair
- proj_cm_pair
- proj_d_vec
- proj_fref_c
- proj_inv_sq
- proj_r_scaled
- proj_s_init
- proj_v
- proj_z_init
- transformer
- window_size = 32
- class pallatom.architecture.atom_transformers.AtomTransformer(c_atom: int, c_atompair: int, n_blocks: int = 3, n_heads: int = 4, window_size: int = WINDOW_SIZE)
Bases: torch.nn.Module
Stack of n_blocks AtomTransformerBlocks, all operating on sparse [B, N, K, *] pairs.
- forward(q: jaxtyping.Float[torch.Tensor, 'B N c_atom'], c: jaxtyping.Float[torch.Tensor, 'B N c_atom'], p: jaxtyping.Float[torch.Tensor, 'B N K c_atompair'], neighbor_idx: jaxtyping.Int[torch.Tensor, 'N K'], valid_mask: jaxtyping.Bool[torch.Tensor, 'B N K']) → jaxtyping.Float[torch.Tensor, 'B N c_atom']
- Parameters:
q – atom query embeddings [B, N, c_atom]
c – atom context embeddings [B, N, c_atom]
p – sparse pair embeddings [B, N, K, c_atompair]
neighbor_idx – neighbour atom indices [N, K]
valid_mask – live-pair mask [B, N, K]
- Returns:
q : updated atom embeddings [B, N, c_atom]
- Return type:
jaxtyping.Float[torch.Tensor, 'B N c_atom']
- blocks
- window_size = 32
- class pallatom.architecture.atom_transformers.AtomTransformerBlock(c_atom: int, c_atompair: int, n_heads: int, window_size: int = WINDOW_SIZE)
Bases: torch.nn.Module
Single block of the AtomTransformer with a 32-residue local attention window.
All pair tensors are sparse [B, N, K, *] — no N² allocation.
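To put the memory saving in numbers, the sketch below compares a dense [B, N, N, c] pair tensor with the sparse [B, N, K, c] layout for an illustrative 500-residue protein at 14 atoms per residue, float32 storage and a hypothetical c_atompair = 16. K is derived from the window rule |tok_idx[l] - tok_idx[m]| < window_size // 2 documented for build_sparse_pairs; pair_tensor_bytes is a hypothetical helper.

```python
from typing import Optional


def pair_tensor_bytes(B: int, N: int, c: int, K: Optional[int] = None,
                      bytes_per_el: int = 4) -> int:
    """Bytes for a dense [B, N, N, c] pair tensor, or a sparse [B, N, K, c] one if K is given."""
    cols = N if K is None else K
    return B * N * cols * c * bytes_per_el


window_size, atoms_per_res, n_res = 32, 14, 500   # illustrative sizes
N = atoms_per_res * n_res                          # 7000 atoms
# |tok_idx[l] - tok_idx[m]| < 16 allows residue offsets -15..15, i.e. 31 residues.
K = (2 * (window_size // 2) - 1) * atoms_per_res   # 434 slots per atom
c_atompair = 16                                    # hypothetical pair width

dense = pair_tensor_bytes(1, N, c_atompair)        # 3_136_000_000 bytes (~3.1 GB)
sparse = pair_tensor_bytes(1, N, c_atompair, K)    # 194_432_000 bytes (~194 MB)
print(dense / sparse)                              # ~16.13, i.e. N / K
```

The saving grows linearly with protein length, because K stays fixed by the window while N keeps growing.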
- Weight shapes:
to_q / to_k / to_v / to_out : [c_atom, c_atom] (no bias)
pair_bias : [c_atompair, n_heads] (no bias)
ff1 : [c_atom, c_atom * 4] (no bias)
ff2 : [c_atom * 4, c_atom] (no bias)
- Sparse attention shapes (window_size=32, K = neighbour slots per atom, i.e. the maximum window-neighbour count):
Q, K_proj, V : [B, N, n_heads, head_dim]
K_nbr, V_nbr : [B, N, K, n_heads, head_dim] (gathered over the K neighbours)
scores : [B, N, K, n_heads] (O(N × K), not O(N²))
pair bias : [B, N, K, n_heads]
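The shapes above describe attention restricted to K gathered neighbours. A minimal sketch of that core, assuming standard scaled dot-product logits plus the pair bias and a softmax over the K slots with invalid slots set to -inf; the q/c conditioning, to_out and the feed-forward of the real block are omitted, and sparse_local_attention is a hypothetical name.

```python
import torch


def sparse_local_attention(Q, K_proj, V, bias, neighbor_idx, valid_mask):
    """Attention over K gathered neighbours per atom.

    Q, K_proj, V : [B, N, H, D]; bias : [B, N, K, H];
    neighbor_idx : [N, K]; valid_mask : [B, N, K] (True = live pair).
    Returns [B, N, H, D].
    """
    D = Q.shape[-1]
    K_nbr = K_proj[:, neighbor_idx]                 # [B, N, K, H, D]
    V_nbr = V[:, neighbor_idx]                      # [B, N, K, H, D]
    scores = torch.einsum("bnhd,bnkhd->bnkh", Q, K_nbr) / D ** 0.5 + bias  # [B, N, K, H]
    scores = scores.masked_fill(~valid_mask[..., None], float("-inf"))
    attn = torch.softmax(scores, dim=2)             # normalise over the K slots
    return torch.einsum("bnkh,bnkhd->bnhd", attn, V_nbr)
```

Each atom needs at least one live slot (its own, under the window rule); otherwise its softmax row is all -inf and yields NaN.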
- forward(q: jaxtyping.Float[torch.Tensor, 'B N c_atom'], c: jaxtyping.Float[torch.Tensor, 'B N c_atom'], p: jaxtyping.Float[torch.Tensor, 'B N K c_atompair'], neighbor_idx: jaxtyping.Int[torch.Tensor, 'N K'], valid_mask: jaxtyping.Bool[torch.Tensor, 'B N K']) → jaxtyping.Float[torch.Tensor, 'B N c_atom']
- Parameters:
q – atom query embeddings [B, N, c_atom]
c – atom context embeddings [B, N, c_atom]
p – sparse pair embeddings [B, N, K, c_atompair]; p[b, l, k] holds the features for pair (l, neighbor_idx[l, k])
neighbor_idx – window neighbour indices [N, K]
valid_mask – live-pair mask [B, N, K]; True marks a live pair. Encodes both the window and the chain constraints.
- Returns:
q : updated atom embeddings [B, N, c_atom]
- Return type:
jaxtyping.Float[torch.Tensor, 'B N c_atom']
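The signatures pair a batch-free neighbor_idx [N, K] with a batched valid_mask [B, N, K]. A minimal sketch of one way that split can arise, assuming the window part comes from build_sparse_pairs and the chain constraint means a live pair must share a ref_space_uid; combine_masks is a hypothetical helper, and the actual constraint may be defined differently.

```python
import torch


def combine_masks(window_valid: torch.Tensor, space_uid: torch.Tensor,
                  neighbor_idx: torch.Tensor) -> torch.Tensor:
    """window_valid: [N, K]; space_uid: [B, N]; neighbor_idx: [N, K] -> valid_mask [B, N, K]."""
    B = space_uid.shape[0]
    window = window_valid.unsqueeze(0).expand(B, -1, -1)               # shared across the batch
    same_space = space_uid[:, :, None] == space_uid[:, neighbor_idx]  # [B, N, K], per example
    return window & same_space
```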
- ff1
- ff2
- head_dim
- n_heads
- norm_c
- norm_ff
- norm_pair
- norm_q
- pair_bias
- to_k
- to_out
- to_q
- to_v
- window_size = 32
- class pallatom.architecture.atom_transformers.LinearNoBias(in_features: int, out_features: int)
Bases: torch.nn.Linear
Linear layer with no bias term.
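LinearNoBias is documented only as a bias-free torch.nn.Linear; a minimal equivalent, assuming no other behaviour, is sketched below. The Weight shapes listed in this module read as [in_features, out_features] (for example proj_r : [c_atom, 3]), whereas torch.nn.Linear stores weight transposed, as [out_features, in_features].

```python
import torch
import torch.nn as nn


class LinearNoBias(nn.Linear):
    """nn.Linear with bias disabled; weight is stored as [out_features, in_features]."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__(in_features, out_features, bias=False)
```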
- pallatom.architecture.atom_transformers.build_sparse_pairs(tok_idx: jaxtyping.Int[torch.Tensor, 'N'], window_size: int = WINDOW_SIZE) → tuple[jaxtyping.Int[torch.Tensor, 'N K'], jaxtyping.Bool[torch.Tensor, 'N K']]
For each atom l, collect indices of all atoms m within the 32-residue window.
Uses an O(N²) boolean intermediate (one byte per pair) to build the index, then returns O(N × K) long and bool tensors; K << N for large proteins.
- Parameters:
tok_idx – [N] residue index (0-based) for each atom
window_size – total window span in residues (default 32); atom l attends to atom m iff |tok_idx[l] - tok_idx[m]| < window_size // 2
- Returns:
neighbor_idx : [N, K] atom indices of each neighbour; padding slots are set to 0
valid_mask : [N, K] True where the slot is a real neighbour (not padding)
- Return type:
tuple[jaxtyping.Int[torch.Tensor, 'N K'], jaxtyping.Bool[torch.Tensor, 'N K']]
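K = maximum number of neighbours any single atom has; it varies with sequence length and atoms per residue. With window_size = 32 the strict inequality admits residue offsets from -15 to 15 (31 residues), so with ~14 atoms/residue K ≈ 31 × 14 = 434 (each atom counts itself).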