Architecture

Visual overview of the pallatom denoising network. Diagrams follow the forward() call order; tensor shapes use the named-dimension conventions documented in pallatom/CLAUDE.md (e.g. N_atom, N_res, c_pair, K).

MainTrunk (Algorithm 2)

MainTrunk takes a FeaturizedBatch and returns denoised atom positions plus auxiliary sequence and distogram logits.

        flowchart TD
    classDef inp  fill:#dbeafe,stroke:#3b82f6,color:#1e293b,font-size:13px,padding:8px
    classDef proc fill:#d1fae5,stroke:#059669,color:#1e293b,font-size:13px,padding:8px
    classDef skip fill:#fef3c7,stroke:#d97706,color:#1e293b,font-size:13px,padding:8px
    classDef out  fill:#fce7f3,stroke:#db2777,color:#1e293b,font-size:13px,padding:8px

    I1(["r_input  [B, N_atom, 3]"]):::inp
    I2(["f_residue_idx  [B, N_res, c_res]"]):::inp
    I3(["f_distogram [B, N_res, N_res, n_templ_bins]  ·  pseudo_beta_mask [B, N_res]"]):::inp
    I4(["ref_pos [B, N_atom, 3]  ·  ref_element [B, N_atom, E]  ·  ref_uid [B, N_atom]"]):::inp
    I5(["t_hat  ·  t_normalized"]):::inp

    I1  --> P1["r_scaled = r_input / sqrt(sigma_data^2 + t_hat^2)  [B, N_atom, 3]"]:::proc
    I2  --> P2["s_init = Linear(f_residue_idx)  [B, N_res, c_res]"]:::proc
    I5  --> P3["t_i = TimeFourierEmbedding(0.25 * log(t_hat / sigma_data))  [B, N_res, c_res]"]:::proc
    P2  --> P4["s_init += t_i  [B, N_res, c_res]"]:::proc
    P3  --> P4

    P4  --> P5["z_ij = RelativePositionEncoding  [B, N_res, N_res, c_pair]"]:::proc
    I3  --> P6["z_ij += TemplateEmbedder(f_distogram, z_ij, t)  [B, N_res, N_res, c_pair]"]:::proc
    P5  --> P6

    P1  --> P7["AtomFeatureEncoder  (Algorithm 4)"]:::proc
    I4  --> P7
    P4  --> P7
    P6  --> P7
    P7  --> SK1(["s_i  [B, N_res, c_res]"]):::skip
    P7  --> SK2(["q_skip  [B, N_atom, c_atom]  ·  c_skip  [B, N_atom, c_atom]"]):::skip
    P7  --> SK3(["p_skip  [B, N_atom, K, c_atompair]"]):::skip
    P7  --> SK4(["c_l  [B, N_atom, c_atom]"]):::skip

    SK1 --> P8["s_i += proj(LN(s_init))  [B, N_res, c_res]"]:::proc
    P4  --> P8

    P8  --> DL["Decoder Loop x K_unit"]:::proc
    SK2 --> DL
    SK3 --> DL
    SK4 --> DL
    P6  --> DL
    P3  --> DL
    I1  --> DL

    DL  --> O1(["r_denoised  [B, N_atom, 3]"]):::out
    DL  --> O2(["f_seq_logits  [B, N_res, n_amino]"]):::out
    DL  --> O3(["residue_distogram_logits  [B, N_res, N_res, n_bins]"]):::out
    DL  --> O4(["atom_distogram_logits  [B, N_atom, K, n_atom_bins]"]):::out
    

Decoder Loop (x K_unit)

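Each iteration of the K_unit loop refines all three representations: residue single s_i, pair z_ij, and atom context c_l. Each iteration also adds the decoder's coordinate update to the running sum r_updates. It then blends that sum with the noisy input r_input using the EDM blend formula, whose weights depend on t_hat. The resulting intermediate structure r_denoised_k is recorded, and its per-residue center atoms are used to update z_ij before the next iteration.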

        flowchart TD
    classDef proc fill:#d1fae5,stroke:#059669,color:#1e293b,font-size:13px,padding:8px
    classDef acc  fill:#fef3c7,stroke:#d97706,color:#1e293b,font-size:13px,padding:8px
    classDef out  fill:#fce7f3,stroke:#db2777,color:#1e293b,font-size:13px,padding:8px

    IN(["s_i [B,N_res,c_res] · z_ij [B,N_res,N_res,c_pair] · t_i [B,N_res,c_res] · q_skip · c_skip · c_l [B,N_atom,c_atom] · p_skip [B,N_atom,K,c_atompair] · r_input · r_updates [B,N_atom,3]"]):::acc

    IN  --> NU["NodeUpdate: s_i = Update(s_i, t_i, z_ij)  [B, N_res, c_res]"]:::proc
    NU  --> AD["AtomAttentionDecoder: q_update [B,N_atom,c_atom], p_update [B,N_atom,K,c_atompair], r_update [B,N_atom,3], c_l [B,N_atom,c_atom] = Decode(q_skip, p_skip, c_skip, c_l, s_i, z_ij)"]:::proc
    AD  --> RU["r_updates += r_update  [B, N_atom, 3]"]:::acc
    RU  --> RD["r_denoised = sigma_data^2 / (sigma_data^2 + t_hat^2) * r_input + sigma_data * t_hat / sqrt(sigma_data^2 + t_hat^2) * r_updates  [B, N_atom, 3]"]:::acc
    RD  --> IH(["intermediate stack: r_denoised_k [B,N_atom,3] · aa_logits_k [B,N_res,n_amino]"]):::out
    RD  --> RC["r_center = r_denoised at center_uid  [B, N_res, 3]"]:::acc
    RC  --> PU["PairUpdate: z_ij = Update(z_ij, r_center)  [B, N_res, N_res, c_pair]"]:::proc
    PU  -.->|"repeat for next k"| NU

    PU  --> FN(["final: r_denoised [B,N_atom,3] · z_ij [B,N_res,N_res,c_pair] · q_update [B,N_atom,c_atom]"]):::acc
    

AtomFeatureEncoder (Algorithm 4)

Builds per-atom and per-residue embeddings using a sparse AtomTransformer. All atom-pair tensors are [B, N_atom, K, *] (each atom attends only to its K neighbors), so the N_atom x N_atom dense grid is never materialised.

        flowchart TD
    classDef inp  fill:#dbeafe,stroke:#3b82f6,color:#1e293b,font-size:13px,padding:8px
    classDef proc fill:#d1fae5,stroke:#059669,color:#1e293b,font-size:13px,padding:8px
    classDef skip fill:#fef3c7,stroke:#d97706,color:#1e293b,font-size:13px,padding:8px
    classDef out  fill:#fce7f3,stroke:#db2777,color:#1e293b,font-size:13px,padding:8px

    A1(["ref_pos [B, N_atom, 3]  ·  ref_element [B, N_atom, E]"]):::inp
    A2(["s_input  [B, N_res, c_res]"]):::inp
    A3(["z_input  [B, N_res, N_res, c_pair]"]):::inp
    A4(["r_scaled  [B, N_atom, 3]"]):::inp
    A5(["tok_idx  [B, N_atom]"]):::inp

    A5  --> SP["Build sparse pairs: neighbor_idx [N_atom, K]  ·  valid_mask [B, N_atom, K]"]:::proc

    A1  --> FR["f_ref = tile(ref_pos, ref_element) per atom  [B, N_atom, f_ref_dim]"]:::proc
    FR  --> CL["c_l = Linear(f_ref)  [B, N_atom, c_atom]"]:::proc
    CL  --> CS(["c_skip saved  [B, N_atom, c_atom]"]):::skip

    CL & A4 --> QS["q_skip = c_l + proj(r_scaled)  [B, N_atom, c_atom]"]:::proc
    QS  --> QSS(["q_skip saved  [B, N_atom, c_atom]"]):::skip

    SP & CL --> PM["p_lm = atom-pair features (distance · chain validity · c_l projections)  [B, N_atom, K, c_atompair]"]:::proc
    A3  --> PML
    A5  --> CL2 & PML

    A2 & CL --> CL2["c_l += proj(LN(s_input[tok_idx]))  [B, N_atom, c_atom]"]:::proc
    PM  --> PML["p_lm += proj(LN(z_input[tok_l, tok_m]))  then  p_lm += MLP(p_lm)  [B, N_atom, K, c_atompair]"]:::proc
    PML --> PS(["p_skip saved  [B, N_atom, K, c_atompair]"]):::skip

    QSS & CL2 & PML & SP --> AT["AtomTransformer: 3x AtomTransformerBlock (sparse K-neighbor attention)  [B, N_atom, c_atom]"]:::proc

    AT  --> MP["s_i = mean-pool(ReLU(proj(q_skip)), tok_idx)  [B, N_res, c_res]"]:::proc

    MP  --> O1(["s_i  [B, N_res, c_res]"]):::out
    AT  --> O2(["q_skip  [B, N_atom, c_atom]"]):::out
    CS  --> O3(["c_skip  [B, N_atom, c_atom]"]):::out
    PS  --> O4(["p_skip  [B, N_atom, K, c_atompair]"]):::out
    CL2 --> O5(["c_l  [B, N_atom, c_atom]"]):::out