pallatom.architecture.losses

Functions

atom_loss(→ jaxtyping.Float[torch.Tensor, ...])

Kabsch-aligned atom-coordinate MSE loss.

distogram_loss_atom(→ jaxtyping.Float[torch.Tensor, ...])

Atomic-level local distogram cross-entropy loss (L_dist_atom).

distogram_loss_residue(→ jaxtyping.Float[torch.Tensor, ...])

Residue-level distogram cross-entropy loss (L_dist_res).

med_loss(→ jaxtyping.Float[torch.Tensor, ""])

Full intermediate supervision loss averaged over K decoder blocks with exponential weight decay that up-weights later blocks.

med_loss_per_block(→ jaxtyping.Float[torch.Tensor, ...])

Per-decoder-block intermediate loss L^k_med.

smooth_lddt_loss(→ jaxtyping.Float[torch.Tensor, ""])

Smooth lDDT loss — Algorithm 8 (simplified AF3 version for all-atom design).

Module Contents

pallatom.architecture.losses.atom_loss(r_denoised: jaxtyping.Float[torch.Tensor, ... N_res 3], r_gt: jaxtyping.Float[torch.Tensor, ... N_res 3], mask: Optional[jaxtyping.Bool[torch.Tensor, ... N_res]] = None) jaxtyping.Float[torch.Tensor, ...]

Kabsch-aligned atom-coordinate MSE loss.

Rigidly aligns the ground-truth structure onto the denoised structure (i.e. the prediction is held fixed; the GT is rotated/translated to match it), then computes the mean squared deviation per coordinate:

L_atom = ||r_denoised − r_aligned||² / (3L)

where L is the number of (unmasked) residues and the factor 3 accounts for the x, y and z coordinates.

Parameters:
  • r_denoised – (…, N_res, 3) — model output / denoised coordinates.

  • r_gt – (…, N_res, 3) — ground-truth coordinates r̄⁰.

  • mask – (…, N_res) — boolean or float mask; 1 = valid residue. When None, all N_res residues are used.

Returns:

(…,) scalar loss per batch element.

Notes

  • Alignment is done with float weights derived from mask, so missing / padded residues do not pollute the rotation estimate.

  • Gradients flow through r_denoised (r_aligned is treated as a constant frame — the GT is being moved, not the prediction).
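A minimal PyTorch sketch of the computation described above, assuming a standard weighted Kabsch alignment via SVD and treating the aligned ground truth as a detached constant. The name kabsch_mse_sketch and its eps argument are hypothetical; this illustrates the technique and is not the library's implementation of atom_loss.

```python
import torch


def kabsch_mse_sketch(r_denoised, r_gt, mask=None, eps=1e-8):
    """Hypothetical Kabsch-aligned MSE (illustration only, not the library code).

    r_denoised, r_gt: (..., N, 3); mask: (..., N) bool/float, 1 = valid.
    Returns (...,): sum_i w_i ||r_denoised_i - r_aligned_i||^2 / (3 * sum_i w_i).
    """
    if mask is None:
        mask = torch.ones(r_gt.shape[:-1], dtype=r_gt.dtype, device=r_gt.device)
    w = mask.to(r_gt.dtype)                    # (..., N) float weights
    w_sum = w.sum(-1).clamp_min(eps)           # (...,) number of valid residues

    # The aligned GT is a constant target: only r_denoised receives gradients.
    with torch.no_grad():
        pred = r_denoised.detach()
        # Weighted centroids, so padded residues do not shift the frame.
        c_pred = (w[..., None] * pred).sum(-2) / w_sum[..., None]  # (..., 3)
        c_gt = (w[..., None] * r_gt).sum(-2) / w_sum[..., None]    # (..., 3)
        x = r_gt - c_gt[..., None, :]          # GT, centred (this is moved)
        y = pred - c_pred[..., None, :]        # prediction, centred (held fixed)
        # Weighted cross-covariance H = sum_i w_i x_i y_i^T, shape (..., 3, 3).
        h = torch.einsum("...n,...ni,...nj->...ij", w, x, y)
        u, _, vh = torch.linalg.svd(h)
        v = vh.transpose(-1, -2)
        # d = -1 flags a reflection; flipping the last axis keeps R a proper rotation.
        d = torch.sign(torch.linalg.det(v @ u.transpose(-1, -2)))
        signs = torch.stack([torch.ones_like(d), torch.ones_like(d), d], dim=-1)
        rot = v @ torch.diag_embed(signs) @ u.transpose(-1, -2)    # R = V D U^T
        # Rotate the GT (row vectors, hence R^T) and move it onto the prediction.
        r_aligned = x @ rot.transpose(-1, -2) + c_pred[..., None, :]

    sq_dev = ((r_denoised - r_aligned) ** 2).sum(-1)               # (..., N)
    return (w * sq_dev).sum(-1) / (3.0 * w_sum)
```

The sign correction keeps R a proper rotation, so a mirror-image ground truth cannot be "aligned" by reflecting it; a rigidly moved copy of the prediction gives a loss of zero.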

pallatom.architecture.losses.distogram_loss_atom(q: jaxtyping.Float[torch.Tensor, ... N_atom K n_bins], y: torch.Tensor, local_mask: Optional[jaxtyping.Bool[torch.Tensor, ... N_atom K]] = None) jaxtyping.Float[torch.Tensor, ...]

Atomic-level local distogram cross-entropy loss (L_dist_atom).

Supervises predicted local inter-atom distance bin probabilities against one-hot encoded ground-truth bins within a local attention window:

L_dist_atom = -1/(N_atom·K) · Σ_{n,k ∈ local} Σ_{b=1}^{n_bins} y^b_{nk} · log q^b_{nk}

where q^b_{nk} denotes the softmax probability of bin b computed from the predicted logits q.

Parameters:
  • q – (…, N_atom, K, n_bins) — predicted bin logits for K sparse neighbours.

  • y – (…, N_atom, K, n_bins) — one-hot targets, OR (…, N_atom, K) integer bin indices.

  • local_mask – (…, N_atom, K) — boolean validity mask (optional).

Returns:

(…,) scalar loss per batch element.
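The cross-entropy above can be sketched as follows, assuming q holds raw logits (so a log-softmax is applied first) and that masked entries are excluded from the average, which reduces to the 1/(N_atom·K) factor when no mask is given. The name local_distogram_ce_sketch is hypothetical, and the normalisation under masking is an assumption that this page does not confirm.

```python
import torch
import torch.nn.functional as F


def local_distogram_ce_sketch(q_logits, y, local_mask=None, eps=1e-8):
    """Hypothetical sketch of L_dist_atom.

    q_logits:   (..., N_atom, K, n_bins) predicted bin logits.
    y:          (..., N_atom, K, n_bins) one-hot floats, or (..., N_atom, K) int bins.
    local_mask: (..., N_atom, K) optional bool validity mask.
    """
    n_bins = q_logits.shape[-1]
    log_q = F.log_softmax(q_logits, dim=-1)      # logits -> log-probabilities
    if not torch.is_floating_point(y):           # integer bin indices -> one-hot
        y = F.one_hot(y.long(), num_classes=n_bins).to(log_q.dtype)
    ce = -(y * log_q).sum(-1)                    # (..., N_atom, K)
    if local_mask is None:
        local_mask = torch.ones_like(ce, dtype=torch.bool)
    m = local_mask.to(ce.dtype)
    # Mean over valid (n, k) entries; equals 1/(N_atom*K) when nothing is masked.
    return (ce * m).sum((-2, -1)) / m.sum((-2, -1)).clamp_min(eps)
```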

pallatom.architecture.losses.distogram_loss_residue(p: jaxtyping.Float[torch.Tensor, ... N_res N_res n_bins], y: torch.Tensor, mask: Optional[jaxtyping.Bool[torch.Tensor, ... N_res]] = None) jaxtyping.Float[torch.Tensor, ...]

Residue-level distogram cross-entropy loss (L_dist_res).

Supervises predicted inter-residue distance bin probabilities against one-hot encoded ground-truth distance bins:

L_dist_res = -1/N_res² · Σ_{i,j} Σ_{b=1}^{n_bins} y^b_{ij} · log p^b_{ij}

where p^b_{ij} denotes the softmax probability of bin b computed from the predicted logits p.

Parameters:
  • p – (…, N_res, N_res, n_bins) — predicted bin logits.

  • y – (…, N_res, N_res, n_bins) — one-hot target bin assignments, OR (…, N_res, N_res) integer bin indices.

  • mask – (…, N_res) — boolean/float residue mask (optional).

Returns:

(…,) scalar loss per batch element.
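A hedged sketch of the residue-level version, assuming the residue mask is expanded to a pair mask (pair (i, j) counts only when both residues are valid) and the sum is normalised by the number of valid pairs. It relies on torch.nn.functional.cross_entropy, which accepts either class indices or per-class probability targets (PyTorch 1.10 and later); residue_distogram_ce_sketch is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def residue_distogram_ce_sketch(p_logits, y, mask=None, eps=1e-8):
    """Hypothetical sketch of L_dist_res.

    p_logits: (..., N, N, n_bins) predicted bin logits.
    y:        (..., N, N, n_bins) one-hot floats, or (..., N, N) int bins.
    mask:     (..., N) optional residue mask.
    """
    *batch, n, _, n_bins = p_logits.shape
    flat_logits = p_logits.reshape(-1, n_bins)
    if torch.is_floating_point(y):
        flat_target = y.reshape(-1, n_bins)    # probability (one-hot) targets
    else:
        flat_target = y.reshape(-1).long()     # class-index targets
    ce = F.cross_entropy(flat_logits, flat_target, reduction="none")
    ce = ce.reshape(*batch, n, n)              # (..., N, N)
    if mask is None:
        mask = torch.ones(*batch, n, dtype=torch.bool, device=p_logits.device)
    m = mask.to(ce.dtype)
    pair = m[..., :, None] * m[..., None, :]   # (i, j) valid iff both residues valid
    return (ce * pair).sum((-2, -1)) / pair.sum((-2, -1)).clamp_min(eps)
```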

pallatom.architecture.losses.med_loss(r_denoised_blocks: list, r_gt: jaxtyping.Float[torch.Tensor, ... N_res 3], logits_aa_blocks: list, aa_gt: jaxtyping.Int[torch.Tensor, ... N_res], lam: float, alpha_0: float, gamma: float = 0.99, mask: Optional[jaxtyping.Bool[torch.Tensor, ... N_res]] = None) jaxtyping.Float[torch.Tensor, ""]

Full intermediate supervision loss averaged over K decoder blocks with exponential weight decay that up-weights later blocks:

L_med = (1/K) · Σ_{k=1}^{K} γ^(K−k) · L^k_med

With γ < 1, blocks k = 1, …, K receive weights γ^(K−1), γ^(K−2), …, γ^0 = 1, so the final block always has weight 1 and earlier blocks are progressively discounted.

Parameters:
  • r_denoised_blocks – list of K tensors, each (…, N_res, 3).

  • r_gt – (…, N_res, 3) — ground-truth coordinates r̄⁰.

  • logits_aa_blocks – list of K tensors, each (…, N_res, n_amino).

  • aa_gt – (…, N_res) — ground-truth amino-acid indices a⁰.

  • lam – float — λ(t), structure loss weight.

  • alpha_0 – float — α₀, sequence loss weight.

  • gamma – float — decay factor γ (default 0.99).

  • mask – (…, N_res) — optional residue mask.

Returns:

Scalar (mean over batch) total intermediate loss L_med.

Raises:

ValueError – if the two block lists have different lengths.
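The block weighting can be illustrated with a small helper that takes K per-block losses L^k_med (each of shape (…,)), applies γ^(K−k), divides by K and averages over the batch. The helper weighted_block_average_sketch is hypothetical: the documented med_loss instead takes the raw block outputs and raises ValueError when the two block lists differ in length.

```python
import torch


def weighted_block_average_sketch(block_losses, gamma=0.99):
    """Hypothetical: (1/K) * sum_k gamma^(K-k) * L^k_med, then mean over batch.

    block_losses: list of K tensors of shape (...,), ordered k = 1..K.
    """
    K = len(block_losses)
    if K == 0:
        raise ValueError("need at least one decoder block")
    stacked = torch.stack(block_losses, dim=0)             # (K, ...)
    # Exponents K-1, K-2, ..., 0, so the final block gets weight gamma^0 = 1.
    exponents = torch.arange(K - 1, -1, -1, dtype=stacked.dtype, device=stacked.device)
    weights = gamma ** exponents                           # (K,)
    weights = weights.reshape(K, *([1] * (stacked.dim() - 1)))
    per_example = (weights * stacked).sum(0) / K           # (...,)
    return per_example.mean()                              # scalar
```

With K = 3 and γ = 0.99, the weights are 0.9801, 0.99 and 1.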

pallatom.architecture.losses.med_loss_per_block(r_denoised_k: jaxtyping.Float[torch.Tensor, ... N_res 3], r_gt: jaxtyping.Float[torch.Tensor, ... N_res 3], logits_aa_k: jaxtyping.Float[torch.Tensor, ... N_res n_amino], aa_gt: jaxtyping.Int[torch.Tensor, ... N_res], lam: float, alpha_0: float, mask: Optional[jaxtyping.Bool[torch.Tensor, ... N_res]] = None) jaxtyping.Float[torch.Tensor, ...]

Per-decoder-block intermediate loss L^k_med.

L^k_med = λ(t) · ||r̄^denoised_(k) − r̄^aligned||² / (3L) + α₀ · CE(â_(k), a⁰)

The structure term reuses atom_loss (Kabsch-aligned MSE). The sequence term is standard cross-entropy over amino-acid logits.

Parameters:
  • r_denoised_k – (…, N_res, 3) — denoised coordinates from block k.

  • r_gt – (…, N_res, 3) — ground-truth coordinates r̄⁰.

  • logits_aa_k – (…, N_res, n_amino) — amino-acid logits â_(k).

  • aa_gt – (…, N_res) — ground-truth amino-acid class indices a⁰.

  • lam – float — λ(t), noise-level weight for the structure term.

  • alpha_0 – float — α₀, fixed weight for the sequence CE term.

  • mask – (…, N_res) — boolean/float residue mask (optional).

Returns:

(…,) per-batch loss for block k.
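A sketch of how the two terms of L^k_med might be combined, assuming the structure term has already been computed (for example by atom_loss) and that the sequence cross-entropy is averaged over valid residues. block_loss_sketch and its struct_term argument are hypothetical and do not match the documented signature of med_loss_per_block.

```python
import torch
import torch.nn.functional as F


def block_loss_sketch(struct_term, logits_aa, aa_gt, lam, alpha_0, mask=None, eps=1e-8):
    """Hypothetical: L^k_med = lam * struct_term + alpha_0 * CE(logits_aa, aa_gt).

    struct_term: (...,) Kabsch-aligned MSE for block k.
    logits_aa:   (..., N_res, n_amino) amino-acid logits.
    aa_gt:       (..., N_res) integer class indices.
    mask:        (..., N_res) optional residue mask.
    """
    n_amino = logits_aa.shape[-1]
    ce = F.cross_entropy(
        logits_aa.reshape(-1, n_amino), aa_gt.reshape(-1).long(), reduction="none"
    ).reshape(aa_gt.shape)                                   # (..., N_res)
    if mask is None:
        mask = torch.ones_like(aa_gt, dtype=torch.bool)
    m = mask.to(ce.dtype)
    seq_term = (ce * m).sum(-1) / m.sum(-1).clamp_min(eps)   # mean CE over valid residues
    return lam * struct_term + alpha_0 * seq_term
```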

pallatom.architecture.losses.smooth_lddt_loss(r_pred: jaxtyping.Float[torch.Tensor, ... N_atom 3], r_true: jaxtyping.Float[torch.Tensor, ... N_atom 3], mask: Optional[jaxtyping.Bool[torch.Tensor, ... N_atom]] = None, cutoff: float = 15.0) jaxtyping.Float[torch.Tensor, ""]

Smooth lDDT loss — Algorithm 8 (simplified AF3 version for all-atom design).

Algorithm:

1. δr_lm    = ||r_l − r_m||           (predicted pairwise distances)
2. δr_lm_GT = ||r_l_GT − r_m_GT||     (GT pairwise distances)
3. δ_lm     = |δr_lm_GT − δr_lm|      (absolute distance difference)
4. ε_lm     = ¼ [σ(½ − δ_lm) + σ(1 − δ_lm) + σ(2 − δ_lm) + σ(4 − δ_lm)]   (smooth score ∈ (0, 1))
5. c_lm     = 1(δr_lm_GT < cutoff) for l ≠ m   (local-neighbourhood mask; cutoff = 15 Å by default)
6. lddt     = Σ_{l≠m} c_lm · ε_lm / Σ_{l≠m} c_lm
7. return 1 − lddt

Parameters:
  • r_pred – (…, N_atom, 3) — predicted atom coordinates r̄_l.

  • r_true – (…, N_atom, 3) — ground-truth coordinates r̄_l^GT.

  • mask – (…, N_atom) — float/bool validity mask (optional). Pairs (l, m) are included only when both atoms are valid.

  • cutoff – float — local neighbourhood radius in Å (default 15.0).

Returns:

Scalar loss = 1 − lddt_smooth (averaged over batch).
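The steps listed above can be sketched in PyTorch as follows. This assumes the mask removes every pair that touches an invalid atom, and it adds a small epsilon under the square root so distances stay differentiable at l = m; smooth_lddt_loss_sketch and pairwise_dist are hypothetical names, not the library's code.

```python
import torch


def pairwise_dist(r, eps=1e-10):
    # eps keeps sqrt differentiable at zero distance (the l == m diagonal).
    diff = r[..., :, None, :] - r[..., None, :, :]
    return (diff.pow(2).sum(-1) + eps).sqrt()              # (..., N, N)


def smooth_lddt_loss_sketch(r_pred, r_true, mask=None, cutoff=15.0):
    """Hypothetical sketch: r_pred, r_true (..., N, 3); mask (..., N). Returns a scalar."""
    d_pred = pairwise_dist(r_pred)                         # step 1: δr_lm
    d_true = pairwise_dist(r_true)                         # step 2: δr_lm_GT
    delta = (d_true - d_pred).abs()                        # step 3: δ_lm
    # Step 4: average of four sigmoids centred at 0.5, 1, 2 and 4 Å.
    score = 0.25 * sum(torch.sigmoid(t - delta) for t in (0.5, 1.0, 2.0, 4.0))
    n = r_pred.shape[-2]
    not_self = ~torch.eye(n, dtype=torch.bool, device=r_pred.device)
    c = (d_true < cutoff) & not_self                       # step 5: c_lm, l != m
    if mask is not None:
        m = mask.bool()
        c = c & m[..., :, None] & m[..., None, :]          # both atoms must be valid
    c = c.to(score.dtype)
    lddt = (c * score).sum((-2, -1)) / c.sum((-2, -1)).clamp_min(1e-8)  # step 6
    return (1.0 - lddt).mean()                             # step 7, averaged over batch
```

With the four sigmoids in step 4, a perfect prediction (δ_lm = 0) scores ε_lm ≈ 0.804 rather than 1, so this loss bottoms out near 0.196; for training, only its gradient matters.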