Architecture
Visual overview of the pallatom denoising network.
Diagrams follow the forward() call order; tensor shapes use the
named-dimension conventions documented in pallatom/CLAUDE.md
(e.g. N_atom, N_res, c_pair, K).
MainTrunk (Algorithm 2)
MainTrunk takes a
FeaturizedBatch and returns denoised atom
positions plus auxiliary sequence and distogram logits.
flowchart TD
classDef inp fill:#dbeafe,stroke:#3b82f6,color:#1e293b,font-size:13px,padding:8px
classDef proc fill:#d1fae5,stroke:#059669,color:#1e293b,font-size:13px,padding:8px
classDef skip fill:#fef3c7,stroke:#d97706,color:#1e293b,font-size:13px,padding:8px
classDef out fill:#fce7f3,stroke:#db2777,color:#1e293b,font-size:13px,padding:8px
I1(["r_input [B, N_atom, 3]"]):::inp
I2(["f_residue_idx [B, N_res, c_res]"]):::inp
I3(["f_distogram [B, N_res, N_res, n_templ_bins] · pseudo_beta_mask [B, N_res]"]):::inp
I4(["ref_pos [B, N_atom, 3] · ref_element [B, N_atom, E] · ref_uid [B, N_atom]"]):::inp
I5(["t_hat · t_normalized"]):::inp
I1 --> P1["r_scaled = r_input / sqrt(sigma_data^2 + t_hat^2) [B, N_atom, 3]"]:::proc
I2 --> P2["s_init = Linear(f_residue_idx) [B, N_res, c_res]"]:::proc
I5 --> P3["t_i = TimeFourierEmbedding(0.25 * log(t_hat / sigma_data)) [B, N_res, c_res]"]:::proc
P2 --> P4["s_init += t_i [B, N_res, c_res]"]:::proc
P3 --> P4
P4 --> P5["z_ij = RelativePositionEncoding [B, N_res, N_res, c_pair]"]:::proc
I3 --> P6["z_ij += TemplateEmbedder(f_distogram, z_ij, t) [B, N_res, N_res, c_pair]"]:::proc
P5 --> P6
P1 --> P7["AtomFeatureEncoder (Algorithm 4)"]:::proc
I4 --> P7
P4 --> P7
P6 --> P7
P7 --> SK1(["s_i [B, N_res, c_res]"]):::skip
P7 --> SK2(["q_skip [B, N_atom, c_atom] · c_skip [B, N_atom, c_atom]"]):::skip
P7 --> SK3(["p_skip [B, N_atom, K, c_atompair]"]):::skip
P7 --> SK4(["c_l [B, N_atom, c_atom]"]):::skip
SK1 --> P8["s_i += proj(LN(s_init)) [B, N_res, c_res]"]:::proc
P4 --> P8
P8 --> DL["Decoder Loop x K_unit"]:::proc
SK2 --> DL
SK4 --> DL
P6 --> DL
DL --> O1(["r_denoised [B, N_atom, 3]"]):::out
DL --> O2(["f_seq_logits [B, N_res, n_amino]"]):::out
DL --> O3(["residue_distogram_logits [B, N_res, N_res, n_bins]"]):::out
DL --> O4(["atom_distogram_logits [B, N_atom, K, n_atom_bins]"]):::out
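A minimal NumPy sketch of the two input-side noise quantities in the diagram above: the coordinate scaling r_scaled and the conditioning scalar 0.25 * log(t_hat / sigma_data) that feeds TimeFourierEmbedding. The Fourier embedding below follows the AlphaFold 3 style (fixed random frequencies and phases, cosine features). Whether pallatom's TimeFourierEmbedding uses this exact form is an assumption, as is the default sigma_data = 16.0 (AlphaFold 3's value). All helper names are hypothetical.

```python
import numpy as np

# Assumption: AlphaFold 3 uses sigma_data = 16.0; check pallatom's config.
SIGMA_DATA = 16.0


def scale_input(r_input: np.ndarray, t_hat: np.ndarray, sigma_data: float = SIGMA_DATA) -> np.ndarray:
    """EDM c_in scaling: r_scaled = r_input / sqrt(sigma_data^2 + t_hat^2).

    r_input: [B, N_atom, 3]; t_hat: [B] (one noise level per example).
    """
    t = np.asarray(t_hat, dtype=float).reshape(-1, 1, 1)  # [B, 1, 1] for broadcasting
    return r_input / np.sqrt(sigma_data**2 + t**2)


def noise_conditioning(t_hat: np.ndarray, sigma_data: float = SIGMA_DATA) -> np.ndarray:
    """EDM c_noise: 0.25 * log(t_hat / sigma_data), shape [B]. Requires t_hat > 0."""
    return 0.25 * np.log(np.asarray(t_hat, dtype=float) / sigma_data)


def fourier_embedding(x: np.ndarray, dim: int, seed: int = 0) -> np.ndarray:
    """Hypothetical TimeFourierEmbedding: cos(2*pi*(x*w + b)) with fixed random w, b.

    x: [B] -> [B, dim]. w and b are sampled once from a fixed seed and never trained.
    """
    rng = np.random.default_rng(seed)
    w = rng.standard_normal(dim)
    b = rng.standard_normal(dim)
    return np.cos(2.0 * np.pi * (np.asarray(x)[:, None] * w + b))


def time_embedding_per_residue(t_hat: np.ndarray, n_res: int, c_res: int) -> np.ndarray:
    """t_i [B, N_res, c_res]: the per-example embedding repeated over residues."""
    emb = fourier_embedding(noise_conditioning(t_hat), c_res)  # [B, c_res]
    return np.broadcast_to(emb[:, None, :], (emb.shape[0], n_res, c_res))


r_scaled = scale_input(np.ones((2, 10, 3)), np.array([0.5, 80.0]))
t_i = time_embedding_per_residue(np.array([0.5, 80.0]), n_res=4, c_res=8)
```

Broadcasting the per-example embedding over N_res yields the [B, N_res, c_res] shape shown for t_i, so it can be added to s_init elementwise.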
Decoder Loop (x K_unit)
Each iteration of the K_unit loop refines all three representations:
residue single s_i, pair z_ij, and atom context c_l.
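After each iteration, the accumulated updates r_updates are blended with the
noisy input r_input using EDM preconditioning weights that depend on the noise
level t_hat: at low noise the output stays close to r_input, and at high noise
it is dominated by the network's prediction.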
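Assuming pallatom follows the standard EDM preconditioning (the same form AlphaFold 3 uses), the blend can be written with a skip weight on the noisy input and an output weight on the accumulated updates, where sigma_data is the data standard deviation and t_hat the current noise level:

```latex
c_{\text{skip}}(\hat t) = \frac{\sigma_{\text{data}}^2}{\sigma_{\text{data}}^2 + \hat t^2},
\qquad
c_{\text{out}}(\hat t) = \frac{\sigma_{\text{data}}\,\hat t}{\sqrt{\sigma_{\text{data}}^2 + \hat t^2}}

r_{\text{denoised}} = c_{\text{skip}}(\hat t)\, r_{\text{input}} + c_{\text{out}}(\hat t)\, r_{\text{updates}}
```

The limits make each term's role concrete: as t_hat goes to 0, c_skip goes to 1 and c_out to 0, so r_denoised approaches r_input; as t_hat grows, c_skip goes to 0 and c_out to sigma_data, so the output approaches sigma_data * r_updates.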
flowchart TD
classDef proc fill:#d1fae5,stroke:#059669,color:#1e293b,font-size:13px,padding:8px
classDef acc fill:#fef3c7,stroke:#d97706,color:#1e293b,font-size:13px,padding:8px
classDef out fill:#fce7f3,stroke:#db2777,color:#1e293b,font-size:13px,padding:8px
IN(["s_i [B,N_res,c_res] · z_ij [B,N_res,N_res,c_pair] · t_i [B,N_res,c_res] · q_skip · c_skip · c_l [B,N_atom,c_atom] · p_skip [B,N_atom,K,c_atompair] · r_input · r_updates [B,N_atom,3]"]):::acc
IN --> NU["NodeUpdate: s_i = Update(s_i, t_i, z_ij) [B, N_res, c_res]"]:::proc
NU --> AD["AtomAttentionDecoder: q_update [B,N_atom,c_atom], p_update [B,N_atom,K,c_atompair], r_update [B,N_atom,3], c_l [B,N_atom,c_atom] = Decode(q_skip, p_skip, c_skip, c_l, s_i, z_ij)"]:::proc
AD --> RU["r_updates += r_update [B, N_atom, 3]"]:::acc
RU --> RD["r_denoised = sigma_data^2 / (sigma_data^2 + t_hat^2) * r_input + sigma_data*t_hat / sqrt(sigma_data^2 + t_hat^2) * r_updates [B, N_atom, 3]"]:::acc
RD --> IH(["intermediate stack: r_denoised_k [B,N_atom,3] · aa_logits_k [B,N_res,n_amino]"]):::out
RD --> RC["r_center = r_denoised at center_uid [B, N_res, 3]"]:::acc
RC --> PU["PairUpdate: z_ij = Update(z_ij, r_center) [B, N_res, N_res, c_pair]"]:::proc
PU -.->|"repeat for next k"| NU
PU --> FN(["final: r_denoised [B,N_atom,3] · z_ij [B,N_res,N_res,c_pair] · q_update [B,N_atom,c_atom]"]):::acc
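The loop's control flow can be sketched in NumPy with the learned modules stubbed out as callables. This is a simplification under stated assumptions: node_update, atom_decode and pair_update are hypothetical stand-ins (the real modules also take t_i, q_skip, c_skip, c_l and p_skip and return more outputs), center_idx is a hypothetical [B, N_res] array holding the atom index of each residue's center_uid atom, and sigma_data = 16.0 is AlphaFold 3's value rather than a confirmed pallatom setting. The blend uses the standard EDM coefficients.

```python
import numpy as np


def edm_blend(r_input, r_updates, t_hat, sigma_data=16.0):
    """Standard EDM output: c_skip * r_input + c_out * r_updates.

    r_input, r_updates: [B, N_atom, 3]; t_hat: [B].
    sigma_data=16.0 is AlphaFold 3's value, assumed here.
    """
    t = np.asarray(t_hat, dtype=float).reshape(-1, 1, 1)  # [B, 1, 1]
    denom = sigma_data**2 + t**2
    c_skip = sigma_data**2 / denom
    c_out = sigma_data * t / np.sqrt(denom)
    return c_skip * r_input + c_out * r_updates


def decoder_loop(s, z, r_input, t_hat, center_idx, k_unit,
                 node_update, atom_decode, pair_update, sigma_data=16.0):
    """Control-flow sketch of the K_unit loop with learned modules as callables.

    center_idx: [B, N_res] hypothetical atom index of each residue's center atom.
    Returns the final r_denoised, the final z, and the per-iteration stack.
    """
    if k_unit < 1:
        raise ValueError("k_unit must be >= 1")
    batch = np.arange(r_input.shape[0])[:, None]     # [B, 1] for gathering
    r_updates = np.zeros_like(r_input)
    stack = []
    for _ in range(k_unit):
        s = node_update(s, z)                        # residue single update
        r_updates = r_updates + atom_decode(s, z)    # accumulate coordinate updates
        r_denoised = edm_blend(r_input, r_updates, t_hat, sigma_data)
        stack.append(r_denoised)                     # intermediate stack entry
        r_center = r_denoised[batch, center_idx]     # [B, N_res, 3]
        z = pair_update(z, r_center)                 # pair update from new geometry
    return r_denoised, z, np.stack(stack, axis=0)    # stack: [K_unit, B, N_atom, 3]
```

Because r_updates is accumulated rather than replaced, each decoder unit contributes a residual correction on top of the earlier units, and the blend is recomputed from the same r_input on every iteration.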
AtomFeatureEncoder (Algorithm 4)
Builds per-atom and per-residue embeddings using a sparse
AtomTransformer.
All pair tensors have shape [B, N_atom, K, *]: each atom keeps only K
neighbours, so the dense N_atom x N_atom grid is never materialised.
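A NumPy sketch of how [B, N_atom, K, *] pair tensors can be built and consumed without an N_atom x N_atom grid. The neighbour rule here (a window of K consecutive atom indices centred on each atom, clipped at the ends) is an assumption for illustration; this page does not say how pallatom selects neighbours. build_sparse_pairs, gather_residue_pairs and sparse_attention are hypothetical helpers, and the attention is single-head with no gating.

```python
import numpy as np


def build_sparse_pairs(atom_mask: np.ndarray, K: int):
    """Assumed sequence-local neighbour rule.

    atom_mask: [B, N_atom] bool (True = real atom, False = padding).
    Returns neighbor_idx [N_atom, K] and valid_mask [B, N_atom, K].
    """
    n_atom = atom_mask.shape[1]
    offsets = np.arange(K) - K // 2                       # K=4 -> [-2, -1, 0, 1]
    raw = np.arange(n_atom)[:, None] + offsets[None, :]   # [N_atom, K], may leave [0, N_atom)
    in_range = (raw >= 0) & (raw < n_atom)
    neighbor_idx = np.clip(raw, 0, n_atom - 1)            # always a safe gather index
    # A pair is valid only if the neighbour exists and both atoms are real.
    valid_mask = in_range[None] & atom_mask[:, :, None] & atom_mask[:, neighbor_idx]
    return neighbor_idx, valid_mask


def gather_residue_pairs(z: np.ndarray, tok_idx: np.ndarray, neighbor_idx: np.ndarray) -> np.ndarray:
    """z_input[tok_l, tok_m] for every sparse pair: [B, N_res, N_res, c] -> [B, N_atom, K, c]."""
    b = np.arange(z.shape[0])[:, None, None]   # [B, 1, 1]
    tok_l = tok_idx[:, :, None]                # [B, N_atom, 1]: residue of atom l
    tok_m = tok_idx[:, neighbor_idx]           # [B, N_atom, K]: residue of neighbour m
    return z[b, tok_l, tok_m]


def sparse_attention(q, k, v, bias, neighbor_idx, valid_mask):
    """Single-head attention over each atom's K neighbours.

    q, k, v: [B, N_atom, c]; bias: [B, N_atom, K] (e.g. projected from p_lm).
    """
    k_nb = k[:, neighbor_idx]                                     # [B, N_atom, K, c]
    v_nb = v[:, neighbor_idx]                                     # [B, N_atom, K, c]
    logits = np.einsum("bnc,bnkc->bnk", q, k_nb) / np.sqrt(q.shape[-1]) + bias
    logits = np.where(valid_mask, logits, -1e9)                   # mask padding / out-of-range
    logits -= logits.max(axis=-1, keepdims=True)                  # numerically stable softmax
    w = np.exp(logits)
    w /= w.sum(axis=-1, keepdims=True)
    return np.einsum("bnk,bnkc->bnc", w, v_nb)
```

Pair memory scales as N_atom x K rather than N_atom^2: with illustrative values N_atom = 2,000 and K = 128, that is 256,000 pairs instead of 4,000,000.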
flowchart TD
classDef inp fill:#dbeafe,stroke:#3b82f6,color:#1e293b,font-size:13px,padding:8px
classDef proc fill:#d1fae5,stroke:#059669,color:#1e293b,font-size:13px,padding:8px
classDef skip fill:#fef3c7,stroke:#d97706,color:#1e293b,font-size:13px,padding:8px
classDef out fill:#fce7f3,stroke:#db2777,color:#1e293b,font-size:13px,padding:8px
A1(["ref_pos [B, N_atom, 3] · ref_element [B, N_atom, E]"]):::inp
A2(["s_input [B, N_res, c_res]"]):::inp
A3(["z_input [B, N_res, N_res, c_pair]"]):::inp
A4(["r_scaled [B, N_atom, 3]"]):::inp
A5(["tok_idx [B, N_atom]"]):::inp
A5 --> SP["Build sparse pairs: neighbor_idx [N_atom, K] · valid_mask [B, N_atom, K]"]:::proc
A1 --> FR["f_ref = concat(ref_pos, ref_element) per atom [B, N_atom, f_ref_dim]"]:::proc
FR --> CL["c_l = Linear(f_ref) [B, N_atom, c_atom]"]:::proc
CL --> CS(["c_skip saved [B, N_atom, c_atom]"]):::skip
CL & A4 --> QS["q_skip = c_l + proj(r_scaled) [B, N_atom, c_atom]"]:::proc
QS --> QSS(["q_skip saved [B, N_atom, c_atom]"]):::skip
SP & CL --> PM["p_lm = atom-pair features (distance · chain validity · c_l projections) [B, N_atom, K, c_atompair]"]:::proc
A3 --> PM
A2 & CL --> CL2["c_l += proj(LN(s_input[tok_idx])) [B, N_atom, c_atom]"]:::proc
PM --> PML["p_lm += proj(LN(z_input[tok_l, tok_m])) then p_lm += MLP(p_lm) [B, N_atom, K, c_atompair]"]:::proc
PML --> PS(["p_skip saved [B, N_atom, K, c_atompair]"]):::skip
QSS & CL2 & PML & SP --> AT["AtomTransformer: 3x AtomTransformerBlock (sparse K-neighbor attention) [B, N_atom, c_atom]"]:::proc
AT --> MP["s_i = mean-pool(ReLU(proj(q_skip)), tok_idx) [B, N_res, c_res]"]:::proc
MP --> O1(["s_i [B, N_res, c_res]"]):::out
AT --> O2(["q_skip [B, N_atom, c_atom]"]):::out
CS --> O3(["c_skip [B, N_atom, c_atom]"]):::out
PS --> O4(["p_skip [B, N_atom, K, c_atompair]"]):::out
CL2 --> O5(["c_l [B, N_atom, c_atom]"]):::out
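The residue-to-atom gather (s_input[tok_idx]) and the atom-to-residue mean-pool that produces s_i are the two index operations linking the atom and residue tracks. A minimal NumPy sketch, assuming tok_idx holds each atom's residue index and a boolean atom_mask marks padding atoms; W is a hypothetical stand-in for the learned projection, and the function names are illustrative.

```python
import numpy as np


def residues_to_atoms(s: np.ndarray, tok_idx: np.ndarray) -> np.ndarray:
    """s_input[tok_idx]: copy each residue's features to its atoms.

    s: [B, N_res, c]; tok_idx: [B, N_atom] -> [B, N_atom, c].
    """
    b = np.arange(s.shape[0])[:, None]  # [B, 1], pairs with tok_idx per batch element
    return s[b, tok_idx]


def atoms_to_residues(q, W, tok_idx, n_res, atom_mask):
    """s_i = mean-pool(ReLU(q @ W), tok_idx), ignoring padding atoms.

    q: [B, N_atom, c_atom]; W: [c_atom, c_res]; atom_mask: [B, N_atom] bool.
    Returns [B, N_res, c_res]; residues with no real atoms stay zero.
    """
    h = np.maximum(q @ W, 0.0)                          # ReLU(proj(q))
    B, n_atom, c_res = h.shape
    w = atom_mask.astype(h.dtype)[..., None]            # [B, N_atom, 1]
    b = np.broadcast_to(np.arange(B)[:, None], (B, n_atom))
    sums = np.zeros((B, n_res, c_res), dtype=h.dtype)
    counts = np.zeros((B, n_res, 1), dtype=h.dtype)
    np.add.at(sums, (b, tok_idx), h * w)                # unbuffered scatter-add
    np.add.at(counts, (b, tok_idx), w)
    return sums / np.maximum(counts, 1.0)
```

np.add.at is used instead of sums[b, tok_idx] += ... because fancy-index += does not accumulate repeated indices, and several atoms share each residue.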