pallatom.architecture.atom_transformers

Attributes

WINDOW_SIZE

Default local attention window span, in residues (32).

Classes

AtomAttentionDecoder

Decodes trunk embeddings back to per-atom position updates.

AtomFeatureEncoder

Encodes per-atom reference features into per-residue embeddings.

AtomTransformer

Stack of n_blocks AtomTransformerBlocks, all operating on sparse [B, N, K, *] pairs.

AtomTransformerBlock

Single block of the AtomTransformer with a 32-residue local attention window.

LinearNoBias

Linear layer with no bias term.

Functions

build_sparse_pairs(→ tuple[jaxtyping.Int[torch.Tensor, ...)

For each atom l, collect indices of all atoms m within the 32-residue window.

Module Contents

class pallatom.architecture.atom_transformers.AtomAttentionDecoder(c_token: int, c_pair: int, c_atom: int, c_atompair: int, n_blocks: int = 3, n_heads: int = 4, window_size: int = WINDOW_SIZE)

Bases: torch.nn.Module

Decodes trunk embeddings back to per-atom position updates.

Like the encoder, all pair tensors are sparse [B, N, K, *]. Builds its own neighbour index from tok_idx; no dense mask required.

Parameters:
  • c_token – trunk single dim (s)

  • c_pair – trunk pair dim (z)

  • c_atom – atom single dim (q, c)

  • c_atompair – atom-pair dim (p)

  • n_blocks – AtomTransformer blocks (default 3)

  • n_heads – attention heads (default 4)

  • window_size – local window in residues (default 32)

Weight shapes:

  • proj_s_q / proj_s_c – [c_token, c_atom] (no bias)

  • proj_z – [c_pair, c_atompair] (no bias)

  • proj_r – [c_atom, 3] (no bias)

  • mlp_p layers – [c_atompair, c_atompair] each (no bias)

forward(q_skip: jaxtyping.Float[torch.Tensor, B N_atom c_atom], p_skip: jaxtyping.Float[torch.Tensor, B N_atom K c_atompair], c_skip: jaxtyping.Float[torch.Tensor, B N_atom c_atom], c: jaxtyping.Float[torch.Tensor, B N_atom c_atom], s: jaxtyping.Float[torch.Tensor, B N_res c_token], z: jaxtyping.Float[torch.Tensor, B N_res N_res c_pair], tok_idx: jaxtyping.Int[torch.Tensor, B N_atom]) → tuple[jaxtyping.Float[torch.Tensor, B N_atom c_atom], jaxtyping.Float[torch.Tensor, B N_atom K c_atompair], jaxtyping.Float[torch.Tensor, B N_atom 3], jaxtyping.Float[torch.Tensor, B N_atom c_atom]]
Parameters:
  • q_skip – atom skip queries [B, N_atom, c_atom]

  • p_skip – sparse pair skip [B, N_atom, K, c_atompair] (from encoder)

  • c_skip – atom skip context [B, N_atom, c_atom]

  • c – atom context [B, N_atom, c_atom]

  • s – trunk single embeds [B, N_res, c_token]

  • z – trunk pair embeds [B, N_res, N_res, c_pair]

  • tok_idx – residue index per atom [B, N_atom]

Returns:
  • q – atom query embeddings [B, N_atom, c_atom]

  • p – sparse pair embeddings [B, N_atom, K, c_atompair]

  • r_update – per-atom position update [B, N_atom, 3]

  • c_out – updated atom context [B, N_atom, c_atom]

mlp_p
norm_q_out
norm_s_c
norm_s_q
norm_z
proj_r
proj_s_c
proj_s_q
proj_z
transformer
window_size = 32
class pallatom.architecture.atom_transformers.AtomFeatureEncoder(f_ref_dim: int, c_token: int, c_pair: int, c: int, d: int, m: int, n_blocks: int, n_heads: int, window_size: int = WINDOW_SIZE)

Bases: torch.nn.Module

Encodes per-atom reference features into per-residue embeddings.

All N×N pair tensors (d_lm, p_lm, z_gathered) are replaced by sparse [B, N, K, *] tensors indexed over the N × K live pairs within the 32-residue window. K is the maximum number of window-neighbours any atom has.
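
The page does not say how z_gathered is produced from the dense residue-pair input. One plausible way to map a [B, N_res, N_res, c_pair] tensor onto sparse atom pairs is advanced indexing with tok_idx and a neighbour index. The helper name gather_pair_to_sparse and its argument layout are assumptions made for this sketch, not part of the documented API.

```python
import torch


def gather_pair_to_sparse(z, tok_idx, neighbor_idx):
    """Hypothetical helper: gather dense residue-pair features onto sparse atom pairs.

    z            : [B, N_res, N_res, c_pair]
    tok_idx      : [B, N_atom]  residue index of each atom
    neighbor_idx : [N_atom, K]  neighbour atom indices
    returns      : [B, N_atom, K, c_pair]
    """
    batch = torch.arange(z.shape[0], device=z.device)[:, None, None]  # [B, 1, 1]
    res_l = tok_idx[:, :, None]        # [B, N_atom, 1] residue of the query atom
    res_m = tok_idx[:, neighbor_idx]   # [B, N_atom, K] residue of each neighbour
    # The three index tensors broadcast to [B, N_atom, K]; the channel axis is kept.
    return z[batch, res_l, res_m]


z = torch.randn(2, 3, 3, 5)
tok_idx = torch.tensor([[0, 0, 1, 2], [0, 1, 1, 2]])
neighbor_idx = torch.tensor([[0, 1], [0, 2], [1, 3], [3, 0]])
z_sparse = gather_pair_to_sparse(z, tok_idx, neighbor_idx)
print(z_sparse.shape)  # torch.Size([2, 4, 2, 5])
```

Entry z_sparse[b, l, k] is then z[b, tok_idx[b, l], tok_idx[b, neighbor_idx[l, k]]], so no [B, N_atom, N_atom, c_pair] tensor is ever materialised.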

Parameters:
  • f_ref_dim – per-atom f^ref feature size (ref_pos_dim + ref_element_dim)

  • c_token – trunk single dim (s_input)

  • c_pair – trunk pair dim (z_input)

  • c – output per-residue dim

  • d – atom-pair embedding dim

  • m – atom single embedding dim

  • n_blocks – AtomTransformer blocks

  • n_heads – AtomTransformer heads

  • window_size – local attention window in residues (default 32)

Weight shapes:

  • proj_fref_c – [f_ref_dim, m]

  • proj_d_vec – [3, d]

  • proj_inv_sq – [1, d]

  • proj_v – [1, d]

  • proj_cl_pair / proj_cm_pair – [m, d]

  • proj_r_scaled – [3, m]

  • proj_s_init – [c_token, m]

  • proj_z_init – [c_pair, d]

  • mlp_p layers – [d, d] each

  • proj_agg – [m, c]

forward(ref_pos: jaxtyping.Float[torch.Tensor, B N_atom 3], ref_element: jaxtyping.Float[torch.Tensor, B N_atom E], ref_space_uid: jaxtyping.Int[torch.Tensor, B N_atom], s_input: jaxtyping.Float[torch.Tensor, B N_res c_token], z_input: jaxtyping.Float[torch.Tensor, B N_res N_res c_pair], r_scaled: jaxtyping.Float[torch.Tensor, B N_atom 3], tok_idx: jaxtyping.Int[torch.Tensor, B N_atom]) → tuple[jaxtyping.Float[torch.Tensor, B N_res c_res], jaxtyping.Float[torch.Tensor, B N_atom c_atom], jaxtyping.Float[torch.Tensor, B N_atom c_atom], jaxtyping.Float[torch.Tensor, B N_atom K c_atompair], jaxtyping.Float[torch.Tensor, B N_atom c_atom]]
Parameters:
  • ref_pos – [B, N_atom, 3] reference atom positions

  • ref_element – [B, N_atom, E] element one-hot features

  • ref_space_uid – [B, N_atom] chain/space identifier per atom

  • s_input – [B, N_res, c_token] trunk single embeddings

  • z_input – [B, N_res, N_res, c_pair] trunk pair embeddings

  • r_scaled – [B, N_atom, 3] scaled noisy positions

  • tok_idx – [B, N_atom] residue index per atom (0-based)

Returns (all sparse where applicable):
  • s_i – [B, N_res, c] aggregated residue embeddings

  • q_skip – [B, N_atom, m] atom skip queries (post-transformer)

  • c_skip – [B, N_atom, m] atom skip context

  • p_skip – [B, N_atom, K, d] sparse atom-pair skip embeddings

  • c_l – [B, N_atom, m] updated atom context

m
mlp_p
norm_s_init
norm_z_init
proj_agg
proj_cl_pair
proj_cm_pair
proj_d_vec
proj_fref_c
proj_inv_sq
proj_r_scaled
proj_s_init
proj_v
proj_z_init
transformer
window_size = 32
class pallatom.architecture.atom_transformers.AtomTransformer(c_atom: int, c_atompair: int, n_blocks: int = 3, n_heads: int = 4, window_size: int = WINDOW_SIZE)

Bases: torch.nn.Module

Stack of n_blocks AtomTransformerBlocks, all operating on sparse [B, N, K, *] pairs.

forward(q: jaxtyping.Float[torch.Tensor, B N c_atom], c: jaxtyping.Float[torch.Tensor, B N c_atom], p: jaxtyping.Float[torch.Tensor, B N K c_atompair], neighbor_idx: jaxtyping.Int[torch.Tensor, N K], valid_mask: jaxtyping.Bool[torch.Tensor, B N K]) → jaxtyping.Float[torch.Tensor, B N c_atom]
Parameters:
  • q – atom query embeddings [B, N, c_atom]

  • c – atom context embeddings [B, N, c_atom]

  • p – sparse pair embeddings [B, N, K, c_atompair]

  • neighbor_idx – neighbour atom indices [N, K]

  • valid_mask – live-pair mask [B, N, K]

Returns:
  • q – updated atom embeddings [B, N, c_atom]

blocks
window_size = 32
class pallatom.architecture.atom_transformers.AtomTransformerBlock(c_atom: int, c_atompair: int, n_heads: int, window_size: int = WINDOW_SIZE)

Bases: torch.nn.Module

Single block of the AtomTransformer with a 32-residue local attention window.

All pair tensors are sparse [B, N, K, *] — no N² allocation.
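
To give a sense of scale, the back-of-the-envelope sketch below compares element counts for one dense [N, N, c_atompair] pair tensor against its sparse [N, K, c_atompair] counterpart. The sizes used (7000 atoms, K = 448, 16 pair channels) are illustrative assumptions, not values taken from the library.

```python
def pair_elements(n_atoms: int, k: int, c_atompair: int) -> tuple[int, int]:
    """Return (dense, sparse) element counts for one pair tensor per batch item."""
    dense = n_atoms * n_atoms * c_atompair
    sparse = n_atoms * k * c_atompair
    return dense, sparse


# Assumed example: 500 residues at ~14 atoms each, K = 448, 16 pair channels.
dense, sparse = pair_elements(n_atoms=7000, k=448, c_atompair=16)
print(dense)           # 784000000
print(sparse)          # 50176000
print(dense / sparse)  # 15.625
```

The ratio is simply N / K, so the saving grows linearly with the number of atoms while K stays bounded by the window.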

Weight shapes:

  • to_q / to_k / to_v / to_out – [c_atom, c_atom] (no bias)

  • pair_bias – [c_atompair, n_heads] (no bias)

  • ff1 – [c_atom, c_atom * 4] (no bias)

  • ff2 – [c_atom * 4, c_atom] (no bias)

Sparse attention shapes (window_size=32, K = live neighbours per atom):

  • Q, K_proj, V – [B, N, n_heads, head_dim]

  • K_nbr, V_nbr – [B, N, K, n_heads, head_dim], gathered over K neighbours

  • scores – [B, N, K, n_heads], i.e. O(N × K) rather than O(N²)

  • pair bias – [B, N, K, n_heads]
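
The shapes above can be traced through a minimal sketch of gather-based attention. This is an illustration under stated assumptions, not the block's actual forward pass: the function name sparse_window_attention is hypothetical, the normalisation layers and output projection are omitted, and every atom is assumed to have at least one valid neighbour (otherwise the softmax row would be all -inf).

```python
import torch


def sparse_window_attention(q_proj, k_proj, v_proj, bias, neighbor_idx, valid_mask):
    """Illustrative attention over K gathered neighbours per atom.

    q_proj, k_proj, v_proj : [B, N, H, D]
    bias                   : [B, N, K, H]
    neighbor_idx           : [N, K]     (long)
    valid_mask             : [B, N, K]  (bool)
    returns                : [B, N, H, D]
    """
    head_dim = q_proj.shape[-1]
    # Indexing the atom axis with an [N, K] tensor gathers neighbours: [B, N, K, H, D].
    k_nbr = k_proj[:, neighbor_idx]
    v_nbr = v_proj[:, neighbor_idx]
    # One score per (atom, neighbour slot, head): [B, N, K, H].
    scores = torch.einsum("bnhd,bnkhd->bnkh", q_proj, k_nbr) / head_dim ** 0.5
    scores = scores + bias
    # Padding slots get -inf so they receive zero weight after the softmax.
    scores = scores.masked_fill(~valid_mask[..., None], float("-inf"))
    weights = torch.softmax(scores, dim=2)
    # Weighted sum over the K neighbour slots.
    return torch.einsum("bnkh,bnkhd->bnhd", weights, v_nbr)


B, N, H, D = 2, 5, 4, 8
offsets = torch.tensor([-1, 0, 1])
raw = torch.arange(N)[:, None] + offsets          # [N, 3] candidate neighbours
valid = (raw >= 0) & (raw < N)                    # out-of-range slots are padding
neighbor_idx = raw.clamp(0, N - 1)
valid_mask = valid[None].expand(B, -1, -1)
out = sparse_window_attention(
    torch.randn(B, N, H, D), torch.randn(B, N, H, D), torch.randn(B, N, H, D),
    torch.zeros(B, N, 3, H), neighbor_idx, valid_mask,
)
print(out.shape)  # torch.Size([2, 5, 4, 8])
```

No tensor in this sketch has two atom axes of length N, which is the property the shape list describes.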

forward(q: jaxtyping.Float[torch.Tensor, B N c_atom], c: jaxtyping.Float[torch.Tensor, B N c_atom], p: jaxtyping.Float[torch.Tensor, B N K c_atompair], neighbor_idx: jaxtyping.Int[torch.Tensor, N K], valid_mask: jaxtyping.Bool[torch.Tensor, B N K]) → jaxtyping.Float[torch.Tensor, B N c_atom]
Parameters:
  • q – atom query embeddings [B, N, c_atom]

  • c – atom context embeddings [B, N, c_atom]

  • p – sparse pair embeddings [B, N, K, c_atompair]; p[b, l, k] holds the features for pair (l, neighbor_idx[l, k])

  • neighbor_idx – window neighbour indices [N, K]

  • valid_mask – live-pair mask [B, N, K]; True marks a live pair, and the mask encodes both window and chain constraints

Returns:
  • q – updated atom embeddings [B, N, c_atom]

ff1
ff2
head_dim
n_heads
norm_c
norm_ff
norm_pair
norm_q
pair_bias
to_k
to_out
to_q
to_v
window_size = 32
class pallatom.architecture.atom_transformers.LinearNoBias(in_features: int, out_features: int)

Bases: torch.nn.Linear

Linear layer with no bias term.
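
The page gives only the constructor signature and the base class. A minimal sketch consistent with both would forward to torch.nn.Linear with bias=False; the class name LinearNoBiasSketch is used here to make clear this is an assumed reconstruction, not the library source.

```python
import torch
from torch import nn


class LinearNoBiasSketch(nn.Linear):
    """Assumed shape of the class: nn.Linear with the bias switched off."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__(in_features, out_features, bias=False)


layer = LinearNoBiasSketch(8, 3)
print(layer.bias)          # None
print(layer.weight.shape)  # torch.Size([3, 8])
```

torch.nn.Linear stores its weight as [out_features, in_features], so the [in, out] shapes quoted elsewhere on this page appear to describe the mapping direction rather than the stored tensor layout.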

pallatom.architecture.atom_transformers.build_sparse_pairs(tok_idx: jaxtyping.Int[torch.Tensor, N], window_size: int = WINDOW_SIZE) → tuple[jaxtyping.Int[torch.Tensor, N K], jaxtyping.Bool[torch.Tensor, N K]]

For each atom l, collect indices of all atoms m within the 32-residue window.

Uses an O(N²) boolean intermediate (one byte per pair) to build the index, then returns O(N × K) long + bool tensors where K << N for large proteins.

Parameters:
  • tok_idx – [N] residue index (0-based) for each atom

  • window_size – total window span in residues (default 32); atom l attends to atom m iff |tok_idx[l] - tok_idx[m]| < window_size // 2

Returns:
  • neighbor_idx – [N, K] atom indices of each neighbour; padding slots are set to 0

  • valid_mask – [N, K] True where the slot is a real neighbour (not padding)

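K is the maximum number of neighbours of any single atom; it varies with sequence length and atoms per residue. For a 32-residue window with ~14 atoms/residue, K ≈ 32 × 14 = 448. Under the strict inequality above, residue offsets run from -15 to 15 (31 residues), so 448 is a slight overestimate for that rule.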
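
The documented behaviour (dense boolean intermediate, then padded [N, K] outputs) can be sketched as follows. This is a re-implementation written from the description above, under the assumption that neighbours are stored in ascending atom order; the name build_sparse_pairs_sketch is hypothetical and the real function may differ in detail.

```python
import torch


def build_sparse_pairs_sketch(tok_idx: torch.Tensor, window_size: int = 32):
    """Sketch of the documented behaviour, not the library code.

    tok_idx : [N] residue index of each atom
    returns : neighbor_idx [N, K] (long), valid_mask [N, K] (bool)
    """
    n_atoms = tok_idx.shape[0]
    # Dense [N, N] boolean intermediate: True where the residue offset is in the window.
    diff = (tok_idx[:, None] - tok_idx[None, :]).abs()
    in_window = diff < (window_size // 2)

    # K is the largest neighbour count over all atoms.
    k_max = int(in_window.sum(dim=1).max())

    # For every True entry, its rank within the row gives the output slot.
    slot = in_window.cumsum(dim=1) - 1
    rows, cols = in_window.nonzero(as_tuple=True)

    # Unfilled slots keep index 0 and mask False, matching the padding convention.
    neighbor_idx = torch.zeros(n_atoms, k_max, dtype=torch.long, device=tok_idx.device)
    valid_mask = torch.zeros(n_atoms, k_max, dtype=torch.bool, device=tok_idx.device)
    neighbor_idx[rows, slot[rows, cols]] = cols
    valid_mask[rows, slot[rows, cols]] = True
    return neighbor_idx, valid_mask


# Four residues (0, 1, 2, 5) with two atoms each, and a small window for readability.
tok_idx = torch.tensor([0, 0, 1, 1, 2, 2, 5, 5])
neighbor_idx, valid_mask = build_sparse_pairs_sketch(tok_idx, window_size=4)
print(neighbor_idx.shape)      # torch.Size([8, 6])
print(neighbor_idx[6].tolist())  # [6, 7, 0, 0, 0, 0]
```

With window_size=4 the rule is |Δ| < 2, so the atoms of residue 1 see residues 0 to 2 (six atoms, which sets K), while the atoms of the isolated residue 5 see only each other and the remaining slots are padding.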

pallatom.architecture.atom_transformers.WINDOW_SIZE: int = 32