pallatom.architecture.main_trunk

MainTrunk: a pure PyTorch implementation based on Algorithm 2 from the AlphaFold 3 paper.

Imports from previously implemented modules:
  • atom_attention_decoder.py (LinearNoBias, AtomTransformer)

  • atom_feature_encoder.py (AtomFeatureEncoder)

  • template_embedder.py (TemplateEmbedder)

Inputs

batch : FeaturizedBatch (every tensor has a leading batch dimension B)

Outputs

r_denoised : (B, N_atom, 3), the denoised atom positions

f_seq_logits : (B, N_token, 20), the amino-acid sequence logits
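As a usage illustration, f_seq_logits can be decoded greedily (argmax per token) into one-letter sequences. This is a sketch only: AMINO_ORDER and decode_sequences are hypothetical names, and the alphabetical index-to-residue mapping is an assumption because this page does not document the vocabulary order.

```python
import torch

# HYPOTHETICAL vocabulary order: the mapping from logit index to residue
# is not documented on this page; alphabetical one-letter codes are assumed.
AMINO_ORDER = "ACDEFGHIKLMNPQRSTVWY"


def decode_sequences(f_seq_logits: torch.Tensor) -> list[str]:
    """Greedy (argmax) decoding of (B, N_token, 20) logits into strings."""
    if f_seq_logits.ndim != 3 or f_seq_logits.shape[-1] != len(AMINO_ORDER):
        raise ValueError(f"expected (B, N_token, 20), got {tuple(f_seq_logits.shape)}")
    idx = f_seq_logits.argmax(dim=-1)  # (B, N_token) integer class indices
    return ["".join(AMINO_ORDER[i] for i in row.tolist()) for row in idx]


# Usage: one 3-residue sequence whose logits peak at M, K and V.
logits = torch.full((1, 3, 20), -1.0)
for pos, aa in enumerate("MKV"):
    logits[0, pos, AMINO_ORDER.index(aa)] = 5.0
print(decode_sequences(logits))  # ['MKV']
```

Greedy decoding is the simplest choice; sampling from softmax(logits) with a temperature would be an alternative if diversity matters.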

Module Contents

class pallatom.architecture.main_trunk.MainTrunk(f_ref_dim: int = 35, n_bins: int = 38, n_atom_bins: int = 22, c_atom: int = 128, c_pair: int = 128, c_res: int = 256, c_atompair: int = 16, n_blocks: int = 2, n_heads: int = 4, sigma_data: float = 16.0, K_unit: int = 3, n_amino: int = 20)

Bases: torch.nn.Module

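Parameters:
  • f_ref_dim (int): size of the per-atom reference feature f^ref, i.e. 3 + element_dim after tiling (default 35)

  • n_bins (int): number of distogram bins for the TemplateEmbedder and the residue-level distogram output (default 38)

  • n_atom_bins (int): number of bins of the atom-level distogram output (default 22)

  • c_atom (int): per-atom single representation dimension (default 128)

  • c_pair (int): trunk pair representation dimension (default 128)

  • c_res (int): trunk single (per-residue) representation dimension (default 256)

  • c_atompair (int): atom-pair representation dimension (default 16)

  • sigma_data (float): data noise level σ_data (default 16.0)

  • K_unit (int): number of decoder units (default 3)

  • n_amino (int): size of the amino-acid vocabulary (default 20)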
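The sigma_data parameter suggests EDM-style preconditioning like that of the AlphaFold 3 diffusion module, where the network sees coordinates rescaled by c_in and its output is blended with the noisy input through c_skip and c_out. A minimal sketch, assuming MainTrunk follows that scheme (this page does not confirm it); edm_scalings and precondition are hypothetical helpers, not part of pallatom.

```python
import torch


def edm_scalings(t: torch.Tensor, sigma_data: float = 16.0):
    """Return (c_in, c_skip, c_out) for noise level(s) t, AF3/EDM-style."""
    denom = t**2 + sigma_data**2
    c_in = 1.0 / torch.sqrt(denom)             # rescales input to roughly unit variance
    c_skip = sigma_data**2 / denom             # weight on the noisy input
    c_out = sigma_data * t / torch.sqrt(denom) # weight on the network update
    return c_in, c_skip, c_out


def precondition(x_noisy: torch.Tensor, t: torch.Tensor, network, sigma_data: float = 16.0):
    """x_noisy: (B, N_atom, 3); t: (B,) noise levels; network: callable on scaled coords."""
    # Broadcast the per-sample scalars over the atom and xyz axes.
    c_in, c_skip, c_out = (c[:, None, None] for c in edm_scalings(t, sigma_data))
    r_update = network(c_in * x_noisy)
    return c_skip * x_noisy + c_out * r_update


# Usage: at t = 0 the output is exactly the input, whatever the network returns.
x = torch.randn(2, 5, 3)
out = precondition(x, torch.zeros(2), lambda z: torch.randn_like(z))
print(torch.allclose(out, x))  # True
```

At t = σ_data the skip weight is 0.5, so input and network update contribute comparably; for t much larger than σ_data the output is dominated by the network.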

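forward(batch: helpers.featurize.FeaturizedBatch) -> tuple[
    jaxtyping.Float[torch.Tensor, "B N_atom 3"],
    jaxtyping.Float[torch.Tensor, "B N_res n_amino"],
    jaxtyping.Float[torch.Tensor, "B N_res N_res n_bins"],
    jaxtyping.Float[torch.Tensor, "B N_atom K n_atom_bins"],
    list[jaxtyping.Float[torch.Tensor, "B N_atom 3"]],
    list[jaxtyping.Float[torch.Tensor, "B N_res n_amino"]],
]

Returns the denoised atom positions, the sequence logits, the residue-level and atom-level distogram outputs, and two lists of tensors with the same shapes as the first two outputs.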
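Because forward returns a six-element tuple, a shape check against the annotations can catch wiring mistakes early. A minimal sketch: check_trunk_outputs and _expect are hypothetical helpers, the unpacked names other than r_denoised and f_seq_logits are illustrative, and the K axis of the atom distogram is left unchecked because its meaning is not documented here.

```python
import torch


def _expect(t: torch.Tensor, shape: tuple, name: str) -> None:
    # None entries in `shape` are free (unchecked) dimensions.
    ok = t.ndim == len(shape) and all(s is None or d == s for d, s in zip(t.shape, shape))
    if not ok:
        raise ValueError(f"{name}: expected {shape}, got {tuple(t.shape)}")


def check_trunk_outputs(outputs, B, N_atom, N_res, n_amino=20, n_bins=38, n_atom_bins=22):
    """Validate the tuple returned by MainTrunk.forward against its annotations."""
    r_denoised, f_seq_logits, res_dist, atom_dist, r_list, seq_list = outputs
    _expect(r_denoised, (B, N_atom, 3), "r_denoised")
    _expect(f_seq_logits, (B, N_res, n_amino), "f_seq_logits")
    _expect(res_dist, (B, N_res, N_res, n_bins), "residue distogram")
    _expect(atom_dist, (B, N_atom, None, n_atom_bins), "atom distogram")  # K is free
    for i, r in enumerate(r_list):
        _expect(r, (B, N_atom, 3), f"positions list[{i}]")
    for i, s in enumerate(seq_list):
        _expect(s, (B, N_res, n_amino), f"sequence-logits list[{i}]")


# Usage with dummy tensors standing in for a real forward pass.
B, N_atom, N_res = 2, 7, 4
dummy = (
    torch.zeros(B, N_atom, 3),
    torch.zeros(B, N_res, 20),
    torch.zeros(B, N_res, N_res, 38),
    torch.zeros(B, N_atom, 5, 22),
    [torch.zeros(B, N_atom, 3)] * 3,
    [torch.zeros(B, N_res, 20)] * 3,
)
check_trunk_outputs(dummy, B=B, N_atom=N_atom, N_res=N_res)  # passes silently
```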
K_unit = 3
aa_embedding
atom_decoders
atom_distogram_head
atom_encoder
inter_proj_seq
inter_seq_logits
node_updates
norm_s_init
pair_updates
proj_residue_idx
proj_s_init
proj_seq
rel_pos_enc
residue_distogram_head
seq_logits
sigma_data = 16.0
template_embedder
time_fourier
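The time_fourier attribute is not documented above. If it follows the FourierEmbedding used in AlphaFold 3's diffusion conditioning (an assumption), it embeds a scalar noise feature as cos(2π(t·w + b)) with w and b drawn once from N(0, 1) and kept fixed, where AF3 feeds t = ¼·log(σ/σ_data). A minimal sketch; FourierEmbedding and noise_feature here are hypothetical stand-ins, not the pallatom classes.

```python
import math

import torch
from torch import nn


class FourierEmbedding(nn.Module):
    """Fixed random Fourier features: cos(2*pi*(t*w + b)), w and b ~ N(0, 1)."""

    def __init__(self, dim: int = 256, seed: int = 0):
        super().__init__()
        g = torch.Generator().manual_seed(seed)
        # Buffers, not parameters: sampled once and never trained.
        self.register_buffer("w", torch.randn(dim, generator=g))
        self.register_buffer("b", torch.randn(dim, generator=g))

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # t: (B,) -> (B, dim)
        return torch.cos(2 * math.pi * (t[:, None] * self.w + self.b))


def noise_feature(sigma: torch.Tensor, sigma_data: float = 16.0) -> torch.Tensor:
    # ASSUMPTION: AF3-style log-scaled noise level; 0 when sigma == sigma_data.
    return 0.25 * torch.log(sigma / sigma_data)


# Usage: embed two noise levels into 256-dimensional conditioning vectors.
emb = FourierEmbedding(dim=256)
features = emb(noise_feature(torch.tensor([16.0, 160.0])))
print(features.shape)  # torch.Size([2, 256])
```

Keeping w and b fixed makes the embedding a deterministic function of the noise level; if the trunk instead conditions on a flow-matching time in [0, 1], the noise_feature step would not apply.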