Architecture
============

Visual overview of the **pallatom** denoising network. Diagrams follow the
``forward()`` call order. Tensor shapes use the named-dimension conventions
documented in ``pallatom/CLAUDE.md`` (e.g. ``N_atom``, ``N_res``, ``c_pair``,
``K``).

MainTrunk (Algorithm 2)
-----------------------

:class:`~pallatom.architecture.main_trunk.MainTrunk` takes a
:class:`~pallatom.helpers.featurize.FeaturizedBatch` and returns denoised atom
positions, plus auxiliary sequence and distogram logits.

.. mermaid::

   flowchart TD
       classDef inp fill:#dbeafe,stroke:#3b82f6,color:#1e293b,font-size:13px,padding:8px
       classDef proc fill:#d1fae5,stroke:#059669,color:#1e293b,font-size:13px,padding:8px
       classDef skip fill:#fef3c7,stroke:#d97706,color:#1e293b,font-size:13px,padding:8px
       classDef out fill:#fce7f3,stroke:#db2777,color:#1e293b,font-size:13px,padding:8px

       I1(["r_input [B, N_atom, 3]"]):::inp
       I2(["f_residue_idx [B, N_res, c_res]"]):::inp
       I3(["f_distogram [B, N_res, N_res, n_templ_bins] · pseudo_beta_mask [B, N_res]"]):::inp
       I4(["ref_pos [B, N_atom, 3] · ref_element [B, N_atom, E] · ref_uid [B, N_atom]"]):::inp
       I5(["t_hat · t_normalized"]):::inp

       I1 --> P1["r_scaled = r_input / sqrt(sigma_data^2 + t_hat^2) [B, N_atom, 3]"]:::proc
       I2 --> P2["s_init = Linear(f_residue_idx) [B, N_res, c_res]"]:::proc
       I5 --> P3["t_i = TimeFourierEmbedding(0.25 * log(t_hat / sigma_data)) [B, N_res, c_res]"]:::proc
       P2 --> P4["s_init += t_i [B, N_res, c_res]"]:::proc
       P3 --> P4
       P4 --> P5["z_ij = RelativePositionEncoding [B, N_res, N_res, c_pair]"]:::proc
       I3 --> P6["z_ij += TemplateEmbedder(f_distogram, z_ij, t) [B, N_res, N_res, c_pair]"]:::proc
       P5 --> P6
       P1 --> P7["AtomFeatureEncoder (Algorithm 4)"]:::proc
       I4 --> P7
       P4 --> P7
       P6 --> P7
       P7 --> SK1(["s_i [B, N_res, c_res]"]):::skip
       P7 --> SK2(["q_skip [B, N_atom, c_atom] · c_skip [B, N_atom, c_atom]"]):::skip
       P7 --> SK3(["p_skip [B, N_atom, K, c_atompair]"]):::skip
       P7 --> SK4(["c_l [B, N_atom, c_atom]"]):::skip
       SK1 --> P8["s_i += proj(LN(s_init)) [B, N_res, c_res]"]:::proc
       P4 --> P8
       P8 --> DL["Decoder Loop x K_unit"]:::proc
       SK2 --> DL
       SK3 --> DL
       SK4 --> DL
       P6 --> DL
       DL --> O1(["r_denoised [B, N_atom, 3]"]):::out
       DL --> O2(["f_seq_logits [B, N_res, n_amino]"]):::out
       DL --> O3(["residue_distogram_logits [B, N_res, N_res, n_bins]"]):::out
       DL --> O4(["atom_distogram_logits [B, N_atom, K, n_atom_bins]"]):::out

----

Decoder Loop (x K\_unit)
------------------------

Each iteration of the ``K_unit`` loop refines all three representations: the
residue single representation ``s_i``, the pair representation ``z_ij``, and
the atom context ``c_l``. The EDM blend formula recombines the noisy input
``r_input`` with the accumulated ``r_updates`` using weights set by ``t_hat``
and ``sigma_data``. The weight on ``r_input`` tends to 1 as ``t_hat`` tends
to 0, so at low noise the output stays close to the input, while at higher
noise levels the accumulated updates carry relatively more weight.

.. mermaid::

   flowchart TD
       classDef proc fill:#d1fae5,stroke:#059669,color:#1e293b,font-size:13px,padding:8px
       classDef acc fill:#fef3c7,stroke:#d97706,color:#1e293b,font-size:13px,padding:8px
       classDef out fill:#fce7f3,stroke:#db2777,color:#1e293b,font-size:13px,padding:8px

       IN(["s_i [B,N_res,c_res] · z_ij [B,N_res,N_res,c_pair] · t_i [B,N_res,c_res] · q_skip · c_skip · c_l [B,N_atom,c_atom] · p_skip [B,N_atom,K,c_atompair] · r_input · r_updates [B,N_atom,3]"]):::acc
       IN --> NU["NodeUpdate: s_i = Update(s_i, t_i, z_ij) [B, N_res, c_res]"]:::proc
       NU --> AD["AtomAttentionDecoder: q_update [B,N_atom,c_atom], p_update [B,N_atom,K,c_atompair], r_update [B,N_atom,3], c_l [B,N_atom,c_atom] = Decode(q_skip, p_skip, c_skip, c_l, s_i, z_ij)"]:::proc
       AD --> RU["r_updates += r_update [B, N_atom, 3]"]:::acc
       RU --> RD["r_denoised = (sigma_data^2 * r_input + sigma_data * t_hat * r_updates) / (sigma_data^2 + t_hat^2) [B, N_atom, 3]"]:::acc
       RD --> IH(["intermediate stack: r_denoised_k [B,N_atom,3] · aa_logits_k [B,N_res,n_amino]"]):::out
       RD --> RC["r_center = r_denoised at center_uid [B, N_res, 3]"]:::acc
       RC --> PU["PairUpdate: z_ij = Update(z_ij, r_center) [B, N_res, N_res, c_pair]"]:::proc
       PU -.->|"repeat for next k"| NU
       PU --> FN(["final: r_denoised [B,N_atom,3] · z_ij [B,N_res,N_res,c_pair] · q_update [B,N_atom,c_atom]"]):::acc

----

AtomFeatureEncoder (Algorithm 4)
--------------------------------

Builds per-atom and per-residue embeddings using a sparse
:class:`~pallatom.architecture.atom_transformers.AtomTransformer`. All pair
tensors are ``[B, N_atom, K, *]``; the dense ``N_atom x N_atom`` grid is never
materialised.

.. mermaid::

   flowchart TD
       classDef inp fill:#dbeafe,stroke:#3b82f6,color:#1e293b,font-size:13px,padding:8px
       classDef proc fill:#d1fae5,stroke:#059669,color:#1e293b,font-size:13px,padding:8px
       classDef skip fill:#fef3c7,stroke:#d97706,color:#1e293b,font-size:13px,padding:8px
       classDef out fill:#fce7f3,stroke:#db2777,color:#1e293b,font-size:13px,padding:8px

       A1(["ref_pos [B, N_atom, 3] · ref_element [B, N_atom, E]"]):::inp
       A2(["s_input [B, N_res, c_res]"]):::inp
       A3(["z_input [B, N_res, N_res, c_pair]"]):::inp
       A4(["r_scaled [B, N_atom, 3]"]):::inp
       A5(["tok_idx [B, N_atom]"]):::inp

       A5 --> SP["Build sparse pairs: neighbor_idx [N_atom, K] · valid_mask [B, N_atom, K]"]:::proc
       A1 --> FR["f_ref = tile(ref_pos, ref_element) per atom [B, N_atom, f_ref_dim]"]:::proc
       FR --> CL["c_l = Linear(f_ref) [B, N_atom, c_atom]"]:::proc
       CL --> CS(["c_skip saved [B, N_atom, c_atom]"]):::skip
       CL & A4 --> QS["q_skip = c_l + proj(r_scaled) [B, N_atom, c_atom]"]:::proc
       QS --> QSS(["q_skip saved [B, N_atom, c_atom]"]):::skip
       SP & CL --> PM["p_lm = atom-pair features (distance · chain validity · c_l projections) [B, N_atom, K, c_atompair]"]:::proc
       A3 --> PM
       A2 & CL --> CL2["c_l += proj(LN(s_input[tok_idx])) [B, N_atom, c_atom]"]:::proc
       PM --> PML["p_lm += proj(LN(z_input[tok_l, tok_m])) then p_lm += MLP(p_lm) [B, N_atom, K, c_atompair]"]:::proc
       PML --> PS(["p_skip saved [B, N_atom, K, c_atompair]"]):::skip
       QSS & CL2 & PML & SP --> AT["AtomTransformer: 3x AtomTransformerBlock (sparse K-neighbor attention) [B, N_atom, c_atom]"]:::proc
       AT --> MP["s_i = mean-pool(ReLU(proj(q_skip)), tok_idx) [B, N_res, c_res]"]:::proc
       MP --> O1(["s_i [B, N_res, c_res]"]):::out
       AT --> O2(["q_skip [B, N_atom, c_atom]"]):::out
       CS --> O3(["c_skip [B, N_atom, c_atom]"]):::out
       PS --> O4(["p_skip [B, N_atom, K, c_atompair]"]):::out
       CL2 --> O5(["c_l [B, N_atom, c_atom]"]):::out
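
The "Build sparse pairs" step and the ``z_input[tok_l, tok_m]`` lookup can be
sketched in plain Python to show why memory scales with ``N_atom x K`` rather
than ``N_atom x N_atom``. This is a minimal sketch under an assumed neighbour
scheme: each atom pairs with a window of ``K`` atoms roughly centred on itself,
clamped at the sequence ends, with out-of-range and padded slots marked
invalid. pallatom's actual selection rule may differ, and the helpers
``build_sparse_pairs`` and ``gather_pair_features`` are hypothetical.

.. code-block:: python

   from typing import List, Tuple


   def build_sparse_pairs(
       atom_mask: List[List[bool]], k: int
   ) -> Tuple[List[List[int]], List[List[List[bool]]]]:
       """Assumed sequence-local neighbour scheme (illustrative, not pallatom's).

       atom_mask: [B][N_atom] padding mask (True = real atom).
       Returns neighbor_idx [N_atom][K], shared across the batch, and
       valid_mask [B][N_atom][K].
       """
       n_atom = len(atom_mask[0])
       half = k // 2
       neighbor_idx: List[List[int]] = []
       in_range: List[List[bool]] = []
       for l in range(n_atom):
           row, ok = [], []
           for j in range(k):
               m = l - half + j  # window of K atoms roughly centred on atom l
               ok.append(0 <= m < n_atom)
               row.append(min(max(m, 0), n_atom - 1))  # clamp so gathers stay in bounds
           neighbor_idx.append(row)
           in_range.append(ok)
       valid_mask = [
           [
               [in_range[l][j] and mask[l] and mask[neighbor_idx[l][j]] for j in range(k)]
               for l in range(n_atom)
           ]
           for mask in atom_mask
       ]
       return neighbor_idx, valid_mask


   def gather_pair_features(
       z: List[List[List[float]]], tok_idx: List[int], neighbor_idx: List[List[int]]
   ) -> List[List[List[float]]]:
       """p[l][j] = z[tok_idx[l]][tok_idx[neighbor_idx[l][j]]] for one batch element.

       z is [N_res][N_res][c_pair]; the result is [N_atom][K][c_pair]. Only
       N_atom * K residue pairs are read, so no N_atom x N_atom grid is built.
       """
       return [
           [z[tok_idx[l]][tok_idx[m]] for m in neighbor_idx[l]]
           for l in range(len(tok_idx))
       ]


   # Five atoms in two residues; the last atom is padding. K = 3.
   tok_idx = [0, 0, 1, 1, 1]
   atom_mask = [[True, True, True, True, False]]
   neighbor_idx, valid_mask = build_sparse_pairs(atom_mask, k=3)
   z = [[[10.0 * i + j] for j in range(2)] for i in range(2)]  # c_pair = 1
   p = gather_pair_features(z, tok_idx, neighbor_idx)
   print(neighbor_idx[0], valid_mask[0][0])  # [0, 0, 1] [False, True, True]
   print(p[2])  # [[10.0], [11.0], [11.0]]

Clamping keeps every index in bounds so the gather never fails; ``valid_mask``
is what keeps the clamped and padded slots out of downstream computation, for
example as an attention mask.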
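
The final ``mean-pool(..., tok_idx)`` step aggregates atom features into
residues. A minimal plain-Python sketch follows, assuming the learned ``proj``
has already been applied (it is omitted here) and that padded atoms should be
skipped. ``mean_pool_atoms`` and its ``atom_mask`` argument are hypothetical
names, not pallatom's API.

.. code-block:: python

   from typing import List, Optional


   def relu(x: List[float]) -> List[float]:
       return [max(v, 0.0) for v in x]


   def mean_pool_atoms(
       atom_feats: List[List[float]],
       tok_idx: List[int],
       n_res: int,
       atom_mask: Optional[List[bool]] = None,
   ) -> List[List[float]]:
       """Average [N_atom][c] atom features into [n_res][c] residue features.

       tok_idx[l] is the residue that atom l belongs to. Residues with no
       (unpadded) atoms receive a zero vector instead of a division by zero.
       """
       c = len(atom_feats[0])
       if atom_mask is None:
           atom_mask = [True] * len(atom_feats)
       sums = [[0.0] * c for _ in range(n_res)]
       counts = [0] * n_res
       for feat, i, keep in zip(atom_feats, tok_idx, atom_mask):
           if not keep:
               continue  # padded atoms do not contribute to the average
           counts[i] += 1
           for d in range(c):
               sums[i][d] += feat[d]
       return [
           [v / counts[i] for v in sums[i]] if counts[i] else [0.0] * c
           for i in range(n_res)
       ]


   # Atoms 0 and 1 belong to residue 0, atom 2 to residue 1; residue 2 has
   # no atoms. The rows stand in for proj(q_skip).
   q_proj = [[1.0, -2.0], [3.0, 4.0], [-1.0, 5.0]]
   s = mean_pool_atoms([relu(q) for q in q_proj], tok_idx=[0, 0, 1], n_res=3)
   print(s)  # [[2.0, 2.0], [0.0, 5.0], [0.0, 0.0]]

Because each residue averages over its own atoms, the pooled ``s_i`` does not
depend on how many atoms a residue has, which keeps residues with different
atom counts on a comparable scale.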