pallatom.architecture.main_trunk
MainTrunk: a pure PyTorch implementation based on Algorithm 2 from the AlphaFold 3 paper.
- Imports from previously implemented modules:
  - atom_attention_decoder.py (LinearNoBias, AtomTransformer)
  - atom_feature_encoder.py (AtomFeatureEncoder)
  - template_embedder.py (TemplateEmbedder)
Inputs
batch : FeaturizedBatch (every tensor has a leading batch dimension B)
Outputs
r_denoised : (B, N_atom, 3), denoised atom positions
f_seq_logits : (B, N_token, 20), amino-acid sequence logits
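f_seq_logits carries one score per amino-acid class for every token. A minimal sketch of greedy decoding into one-letter sequences follows. The index-to-residue order in AA_ALPHABET and the helper decode_sequences are hypothetical; pallatom's actual vocabulary order is not documented on this page.

```python
import torch

# Hypothetical residue ordering: the real index-to-residue mapping used by
# pallatom's featurizer is not documented here.
AA_ALPHABET = "ACDEFGHIKLMNPQRSTVWY"


def decode_sequences(f_seq_logits: torch.Tensor) -> list[str]:
    """Greedy-decode (B, N_token, 20) logits into one-letter sequences."""
    if f_seq_logits.shape[-1] != len(AA_ALPHABET):
        raise ValueError("expected 20 amino-acid logits per token")
    indices = f_seq_logits.argmax(dim=-1)  # (B, N_token) integer class ids
    return ["".join(AA_ALPHABET[i] for i in row.tolist()) for row in indices]


# Toy example: batch of 1, three tokens with a single dominant logit each.
logits = torch.zeros(1, 3, 20)
logits[0, 0, 0] = 5.0   # index 0  -> 'A'
logits[0, 1, 4] = 5.0   # index 4  -> 'F'
logits[0, 2, 19] = 5.0  # index 19 -> 'Y'
print(decode_sequences(logits))  # ['AFY']
```

Greedy argmax is only one choice; sampling from softmax(f_seq_logits) would give sequence diversity instead.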
Module Contents
- class pallatom.architecture.main_trunk.MainTrunk(f_ref_dim: int = 35, n_bins: int = 38, n_atom_bins: int = 22, c_atom: int = 128, c_pair: int = 128, c_res: int = 256, c_atompair: int = 16, n_blocks: int = 2, n_heads: int = 4, sigma_data: float = 16.0, K_unit: int = 3, n_amino: int = 20)
Bases: torch.nn.Module
- Parameters:
  f_ref_dim: per-atom f^ref feature size (3 + element_dim after tiling)
  n_bins: number of distogram bins for TemplateEmbedder; also the bin count of the residue distogram output
  n_atom_bins: number of bins in the atom-level distogram output (default 22)
  c_atom: atom single representation dim (default 128)
  c_pair: trunk pair representation dim (default 128)
  c_res: trunk single/residue representation dim (default 256)
  c_atompair: atom-pair representation dim (default 16)
  sigma_data: data noise level (default 16)
  K_unit: number of decoder units (default 3)
  n_amino: amino-acid vocabulary size (default 20)
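sigma_data sets the data noise level. Assuming MainTrunk conditions on noise the way the AlphaFold 3 diffusion module does (this page does not say so explicitly), the noisy coordinates are scaled by 1/sqrt(t^2 + sigma_data^2), the noise level t enters through a random Fourier embedding of log(t/sigma_data)/4, and the output mixes a skip term with the network update. The attribute time_fourier hints at such an embedding, but FourierEmbedding, edm_scales and denoise below are hypothetical names for an EDM-style sketch, not pallatom's code.

```python
import math

import torch
import torch.nn as nn


class FourierEmbedding(nn.Module):
    """Random Fourier features of a scalar noise level (AF3-style sketch)."""

    def __init__(self, dim: int, seed: int = 0) -> None:
        super().__init__()
        g = torch.Generator().manual_seed(seed)
        # w and b are drawn once from N(0, 1) and kept frozen as buffers.
        self.register_buffer("w", torch.randn(dim, generator=g))
        self.register_buffer("b", torch.randn(dim, generator=g))

    def forward(self, c_noise: torch.Tensor) -> torch.Tensor:
        # c_noise: (B,) -> (B, dim)
        return torch.cos(2 * math.pi * (c_noise[:, None] * self.w + self.b))


def edm_scales(t: torch.Tensor, sigma_data: float = 16.0):
    """Preconditioning coefficients for noise levels t of shape (B,)."""
    var = t**2 + sigma_data**2
    c_in = 1.0 / torch.sqrt(var)                # rescale input to ~unit variance
    c_skip = sigma_data**2 / var                # weight on the noisy input
    c_out = sigma_data * t / torch.sqrt(var)    # weight on the network update
    c_noise = 0.25 * torch.log(t / sigma_data)  # scalar conditioning signal
    return c_in, c_skip, c_out, c_noise


def denoise(x_noisy, t, network, fourier, sigma_data: float = 16.0):
    """x_noisy: (B, N_atom, 3); network(r, t_emb) must return (B, N_atom, 3)."""
    c_in, c_skip, c_out, c_noise = edm_scales(t, sigma_data)
    r_noisy = c_in[:, None, None] * x_noisy
    r_update = network(r_noisy, fourier(c_noise))
    return c_skip[:, None, None] * x_noisy + c_out[:, None, None] * r_update


# Usage with a placeholder network that predicts a zero update.
x = torch.randn(2, 7, 3)
t = torch.full((2,), 16.0)
out = denoise(x, t, lambda r, emb: torch.zeros_like(r), FourierEmbedding(8))
# At t == sigma_data, c_skip is 0.5, so out equals 0.5 * x here.
```

The design keeps the network's input and target at roughly unit scale for every noise level, which is why sigma_data appears in all four coefficients.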
- forward(batch: helpers.featurize.FeaturizedBatch) -> tuple[
      jaxtyping.Float[torch.Tensor, 'B N_atom 3'],
      jaxtyping.Float[torch.Tensor, 'B N_res n_amino'],
      jaxtyping.Float[torch.Tensor, 'B N_res N_res n_bins'],
      jaxtyping.Float[torch.Tensor, 'B N_atom K n_atom_bins'],
      list[jaxtyping.Float[torch.Tensor, 'B N_atom 3']],
      list[jaxtyping.Float[torch.Tensor, 'B N_res n_amino']]]
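The third element of the tuple returned by forward has shape (B, N_res, N_res, n_bins), i.e. residue-level distogram logits. A common way to supervise such logits in AlphaFold-style models is to bin the true pairwise distances and apply cross-entropy. The sketch below assumes one representative coordinate per residue and n_bins - 1 evenly spaced bin edges between hypothetical values of 2 Å and 22 Å; pallatom's real bin edges, representative atom and loss are not stated on this page, and distogram_targets / distogram_loss are illustrative names.

```python
import torch
import torch.nn.functional as F


def distogram_targets(coords: torch.Tensor, n_bins: int = 38,
                      d_min: float = 2.0, d_max: float = 22.0) -> torch.Tensor:
    """Bin pairwise distances of (B, N_res, 3) coords into (B, N_res, N_res) ids."""
    dist = torch.cdist(coords, coords)  # (B, N_res, N_res) Euclidean distances
    # n_bins - 1 edges give n_bins classes: class 0 holds distances <= d_min,
    # the last class holds every distance beyond d_max.
    edges = torch.linspace(d_min, d_max, n_bins - 1, device=coords.device)
    return torch.bucketize(dist, edges)  # int64 ids in [0, n_bins - 1]


def distogram_loss(logits: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
    """Mean cross-entropy between (B, N_res, N_res, n_bins) logits and true bins."""
    n_bins = logits.shape[-1]
    targets = distogram_targets(coords, n_bins)
    return F.cross_entropy(logits.reshape(-1, n_bins), targets.reshape(-1))


coords = 10.0 * torch.randn(2, 5, 3)
logits = torch.zeros(2, 5, 5, 38)  # uniform logits give a loss of log(38)
loss = distogram_loss(logits, coords)
```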
- K_unit = 3
- aa_embedding
- atom_decoders
- atom_distogram_head
- atom_encoder
- inter_proj_seq
- inter_seq_logits
- node_updates
- norm_s_init
- pair_updates
- proj_residue_idx
- proj_s_init
- proj_seq
- rel_pos_enc
- residue_distogram_head
- seq_logits
- sigma_data = 16.0
- template_embedder
- time_fourier
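The rel_pos_enc attribute suggests a relative position encoding feeding the pair representation. The sketch below is a simplified single-chain variant of AlphaFold 3's scheme: residue offsets i - j are clipped to [-r_max, r_max], one-hot encoded, and linearly projected to c_pair channels. The class name RelPosEncoding and r_max = 32 (AF3's value) are assumptions; pallatom's clipping range and its handling of multiple chains are not documented here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RelPosEncoding(nn.Module):
    """Single-chain relative position encoding (hypothetical sketch)."""

    def __init__(self, c_pair: int = 128, r_max: int = 32) -> None:
        super().__init__()
        self.r_max = r_max
        # Bias-free projection from 2 * r_max + 1 offset classes to c_pair.
        self.proj = nn.Linear(2 * r_max + 1, c_pair, bias=False)

    def forward(self, residue_index: torch.Tensor) -> torch.Tensor:
        # residue_index: (B, N_res) integer positions along the chain.
        offset = residue_index[:, :, None] - residue_index[:, None, :]
        d = torch.clamp(offset + self.r_max, 0, 2 * self.r_max)  # (B, N, N)
        one_hot = F.one_hot(d, num_classes=2 * self.r_max + 1).float()
        return self.proj(one_hot)  # (B, N_res, N_res, c_pair)


enc = RelPosEncoding(c_pair=128)
idx = torch.arange(10)[None]  # (1, 10) residue indices 0..9
z = enc(idx)                  # (1, 10, 10, 128) pair features
```

Clipping keeps the number of classes fixed regardless of chain length, so residues far apart along the sequence share one "distant" embedding.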