pallatom.architecture.main_trunk
================================

.. py:module:: pallatom.architecture.main_trunk

.. autoapi-nested-parse::

   MainTrunk: pure PyTorch implementation.

   Based on Algorithm 2 from the AlphaFold 3 paper.

   Reuses components imported from previously implemented modules:

   - ``atom_attention_decoder.py`` (``LinearNoBias``, ``AtomTransformer``)
   - ``atom_feature_encoder.py`` (``AtomFeatureEncoder``)
   - ``template_embedder.py`` (``TemplateEmbedder``)

   Inputs
   ------

   batch : FeaturizedBatch
       All tensors carry a leading batch dimension ``B``.

   Outputs
   -------

   r_denoised : (B, N_atom, 3)
       Denoised atom positions.
   f_seq_logits : (B, N_res, n_amino)
       Amino-acid sequence logits (``n_amino`` defaults to 20).

   ``forward`` additionally returns residue-level distogram outputs of shape
   ``(B, N_res, N_res, n_bins)``, atom-level distogram outputs of shape
   ``(B, N_atom, K, n_atom_bins)``, and lists of intermediate atom positions
   and intermediate sequence logits (see its signature below).

Classes
-------

.. autoapisummary::

   pallatom.architecture.main_trunk.MainTrunk

Module Contents
---------------

.. py:class:: MainTrunk(f_ref_dim: int = 35, n_bins: int = 38, n_atom_bins: int = 22, c_atom: int = 128, c_pair: int = 128, c_res: int = 256, c_atompair: int = 16, n_blocks: int = 2, n_heads: int = 4, sigma_data: float = 16.0, K_unit: int = 3, n_amino: int = 20)

   Bases: :py:obj:`torch.nn.Module`

   :param f_ref_dim: Per-atom :math:`f^{\mathrm{ref}}` feature size (3 + ``element_dim`` after tiling).
   :type f_ref_dim: int
   :param n_bins: Number of distogram bins for ``TemplateEmbedder``.
   :type n_bins: int
   :param n_atom_bins: Number of atom-level distogram bins (last dimension of the atom distogram output).
   :type n_atom_bins: int
   :param c_atom: Atom single-representation dimension (default 128).
   :type c_atom: int
   :param c_pair: Trunk pair-representation dimension (default 128).
   :type c_pair: int
   :param c_res: Trunk single (residue) representation dimension (default 256).
   :type c_res: int
   :param c_atompair: Atom-pair representation dimension (default 16).
   :type c_atompair: int
   :param sigma_data: Data noise level :math:`\sigma_{\mathrm{data}}` (default 16.0).
   :type sigma_data: float
   :param K_unit: Number of decoder units (default 3).
   :type K_unit: int
   :param n_amino: Amino-acid vocabulary size (default 20).
   :type n_amino: int

   .. py:method:: forward(batch: helpers.featurize.FeaturizedBatch) -> tuple[jaxtyping.Float[torch.Tensor, 'B N_atom 3'], jaxtyping.Float[torch.Tensor, 'B N_res n_amino'], jaxtyping.Float[torch.Tensor, 'B N_res N_res n_bins'], jaxtyping.Float[torch.Tensor, 'B N_atom K n_atom_bins'], list[jaxtyping.Float[torch.Tensor, 'B N_atom 3']], list[jaxtyping.Float[torch.Tensor, 'B N_res n_amino']]]

   .. py:attribute:: K_unit
      :value: 3

   .. py:attribute:: aa_embedding

   .. py:attribute:: atom_decoders

   .. py:attribute:: atom_distogram_head

   .. py:attribute:: atom_encoder

   .. py:attribute:: inter_proj_seq

   .. py:attribute:: inter_seq_logits

   .. py:attribute:: node_updates

   .. py:attribute:: norm_s_init

   .. py:attribute:: pair_updates

   .. py:attribute:: proj_residue_idx

   .. py:attribute:: proj_s_init

   .. py:attribute:: proj_seq

   .. py:attribute:: rel_pos_enc

   .. py:attribute:: residue_distogram_head

   .. py:attribute:: seq_logits

   .. py:attribute:: sigma_data
      :value: 16.0

   .. py:attribute:: template_embedder

   .. py:attribute:: time_fourier
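
The ``sigma_data`` and ``time_fourier`` attributes suggest that ``MainTrunk`` conditions on the diffusion noise level in the AlphaFold 3 (EDM) style, but this page does not confirm it. The sketch below assumes it does. The network sees inputs scaled by :math:`c_{\mathrm{in}} = 1/\sqrt{t^2 + \sigma_{\mathrm{data}}^2}`, and the denoised positions are :math:`c_{\mathrm{skip}}\,x_{\mathrm{noisy}} + c_{\mathrm{out}}\,F(c_{\mathrm{in}}\,x_{\mathrm{noisy}})`. The noise level itself is embedded with frozen random Fourier features of :math:`c_{\mathrm{noise}} = \tfrac{1}{4}\log(t/\sigma_{\mathrm{data}})`. ``FourierNoiseEmbedding``, ``edm_preconditioning`` and ``denoise`` are hypothetical names, not part of this module.

.. code-block:: python

   # Hypothetical sketch: assumes AlphaFold 3 / EDM-style noise conditioning.
   # None of these names come from pallatom.
   import math
   from collections.abc import Callable

   import torch
   from torch import nn


   class FourierNoiseEmbedding(nn.Module):
       """Frozen random Fourier features: cos(2*pi*(c_noise * w + b))."""

       def __init__(self, dim: int = 256, seed: int = 0) -> None:
           super().__init__()
           gen = torch.Generator().manual_seed(seed)
           # Buffers: saved with the model's state_dict but never trained.
           self.register_buffer("w", torch.randn(dim, generator=gen))
           self.register_buffer("b", torch.randn(dim, generator=gen))

       def forward(self, c_noise: torch.Tensor) -> torch.Tensor:
           # c_noise: (B,) -> embedding: (B, dim)
           return torch.cos(2 * math.pi * (c_noise[:, None] * self.w + self.b))


   def edm_preconditioning(
       t: torch.Tensor, sigma_data: float = 16.0
   ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
       """Scalings for noise level t of shape (B,); each output has shape (B,)."""
       var = t**2 + sigma_data**2
       c_in = 1.0 / torch.sqrt(var)                 # scales the noisy input
       c_skip = sigma_data**2 / var                  # weight on the skip path
       c_out = sigma_data * t / torch.sqrt(var)      # weight on the network update
       c_noise = 0.25 * torch.log(t / sigma_data)    # scalar fed to the embedding
       return c_in, c_skip, c_out, c_noise


   def denoise(
       x_noisy: torch.Tensor,
       t: torch.Tensor,
       network: Callable[[torch.Tensor], torch.Tensor],
       sigma_data: float = 16.0,
   ) -> torch.Tensor:
       """x_noisy: (B, N_atom, 3), t: (B,) -> denoised positions (B, N_atom, 3)."""
       c_in, c_skip, c_out, _ = edm_preconditioning(t, sigma_data)
       # Broadcast the per-sample scalars over atoms and xyz.
       c_in, c_skip, c_out = (c[:, None, None] for c in (c_in, c_skip, c_out))
       r_update = network(c_in * x_noisy)
       return c_skip * x_noisy + c_out * r_update

With a network that always predicts zero and :math:`t = \sigma_{\mathrm{data}}`, :math:`c_{\mathrm{skip}} = 1/2`, so ``denoise`` returns half of the noisy input. As :math:`t \to 0`, :math:`c_{\mathrm{skip}} \to 1` and :math:`c_{\mathrm{out}} \to 0`, so the input passes through unchanged.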