pallatom.helpers.featurize

Classes

Distogram

Pairwise distogram module.

FeaturizedBatch

ProteinBatch

Functions

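apply_conditioning_dropout(batch, p_distogram, p_atom, p_seq, device) → FeaturizedBatch

featurize_batch(batch, tcfg, c_beta_distogram_fn, atom_distogram_fn, device) → FeaturizedBatch

sinusoidal_encoding(positions, dim) → jaxtyping.Float[torch.Tensor, batch N_res dim]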

Sinusoidal positional encoding. positions: (batch, N_res) → (batch, N_res, dim)

Module Contents

class pallatom.helpers.featurize.Distogram(n_bins: int, min_dist: float = 2.0, max_dist: float = 22.0, overflow_bin: bool = False)

Bases: torch.nn.Module

Pairwise distogram module.

Precomputes bin edges once at init; forward() maps per-residue coordinates → one-hot distogram + validity mask. Accepts either:

  • (..., N, 3) — single atom per residue (e.g. pseudo-Cβ), auto-expanded to (..., N, 1, 3)

  • (..., N, A, 3) — A atoms per residue (e.g. atom5, atom37)

Parameters:
  • n_bins – Number of distance bins.

  • min_dist – Lower edge of first bin in Ångströms (default 2.0).

  • max_dist – Upper edge of last bin in Ångströms (default 22.0).

  • overflow_bin – If True, adds one extra bin capturing distances > max_dist, making the output shape (…, n_bins + 1) instead of (…, n_bins).
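To illustrate what these parameters control, the sketch below bins pairwise distances into a one-hot distogram. It is a standalone function (`distogram_sketch` is a hypothetical name, not part of `pallatom`). It assumes evenly spaced bin edges from `torch.linspace` and assumes that distances below `min_dist` fall into the first bin. The actual `Distogram` may place edges or handle out-of-range distances differently, and this sketch only covers the single-atom `(..., N, 3)` input.

```python
import torch
import torch.nn.functional as F


def distogram_sketch(
    coords: torch.Tensor,
    n_bins: int,
    min_dist: float = 2.0,
    max_dist: float = 22.0,
    overflow_bin: bool = False,
) -> torch.Tensor:
    """One-hot distogram: coords (..., N, 3) -> (..., N, N, n_bins [+1]).

    Hypothetical sketch; the bin-edge layout is an assumption.
    """
    # Assumption: n_bins equal-width bins whose outer edges are min_dist and max_dist.
    edges = torch.linspace(
        min_dist, max_dist, n_bins + 1, dtype=coords.dtype, device=coords.device
    )
    # Pairwise Euclidean distances, shape (..., N, N).
    diff = coords[..., :, None, :] - coords[..., None, :, :]
    dist = torch.linalg.vector_norm(diff, dim=-1)
    # Bucketize against the interior edges only, so indices run 0 .. n_bins - 1:
    # distances below min_dist land in bin 0, distances above max_dist in bin n_bins - 1.
    idx = torch.bucketize(dist, edges[1:-1])
    n_out = n_bins
    if overflow_bin:
        # Route distances beyond max_dist to one extra bin at index n_bins.
        idx = torch.where(dist > max_dist, torch.full_like(idx, n_bins), idx)
        n_out = n_bins + 1
    return F.one_hot(idx, num_classes=n_out).to(coords.dtype)
```

With `n_bins=4` the assumed edges are 2, 7, 12, 17 and 22 Å, so a 10 Å pair lands in bin 1. A 30 Å pair lands in the last regular bin, or in the extra bin at index 4 when `overflow_bin=True`.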

extra_repr() → str
forward(coords: jaxtyping.Float[torch.Tensor, ... total_atom_count 3], coords_mask: jaxtyping.Bool[torch.Tensor, ... total_atom_count] | None = None) → tuple[jaxtyping.Float[torch.Tensor, ... total_atom_count total_atom_count n_bins], jaxtyping.Bool[torch.Tensor, ... total_atom_count total_atom_count]]
Parameters:
  • coords – (…, total_atom_count, 3)

  • coords_mask – (…, total_atom_count) bool, True where the atom is valid; treated as all-True if None.

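Returns:
  • f_distogram – (…, total_atom_count, total_atom_count, n_bins [+1]) one-hot bin assignment; the last bin is the overflow bin when overflow_bin=True.

  • f_pair_mask – (…, total_atom_count, total_atom_count) bool, True where the pair is valid. With overflow_bin=True, this covers valid atom pairs only; with overflow_bin=False, valid atom pairs AND dist <= max_dist.

Return type:

tuple[jaxtyping.Float[torch.Tensor, ... total_atom_count total_atom_count n_bins], jaxtyping.Bool[torch.Tensor, ... total_atom_count total_atom_count]]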
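The documented f_pair_mask rule can be written out directly. This is a minimal sketch that assumes `coords_mask` is a per-atom boolean mask and that a pair counts as valid when both of its atoms are valid. `pair_mask_sketch` is a hypothetical helper that takes precomputed distances rather than coordinates.

```python
from typing import Optional

import torch


def pair_mask_sketch(
    dist: torch.Tensor,
    coords_mask: Optional[torch.Tensor] = None,
    max_dist: float = 22.0,
    overflow_bin: bool = False,
) -> torch.Tensor:
    """Pair-validity mask (..., N, N) from distances (..., N, N) and a per-atom mask (..., N)."""
    if coords_mask is None:
        # Documented default: every atom counts as valid.
        coords_mask = torch.ones(dist.shape[:-1], dtype=torch.bool, device=dist.device)
    coords_mask = coords_mask.bool()
    # Assumption: a pair is valid only when both of its atoms are valid (outer AND).
    mask = coords_mask[..., :, None] & coords_mask[..., None, :]
    if not overflow_bin:
        # No bin represents distances beyond max_dist, so those pairs are masked out.
        mask = mask & (dist <= max_dist)
    return mask
```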

max_dist = 22.0
min_dist = 2.0
n_bins
overflow_bin = False
class pallatom.helpers.featurize.FeaturizedBatch
aa_indices: jaxtyping.Int[torch.Tensor, B N_res]
atom5_mask: jaxtyping.Bool[torch.Tensor, B N_atom]
center_uid: jaxtyping.Int[torch.Tensor, B N_res]
f_pseudo_beta_mask: jaxtyping.Int[torch.Tensor, B N_res]
f_residue_idx: jaxtyping.Float[torch.Tensor, B N_res c_res]
gt_atom_distogram_mask_sparse: jaxtyping.Bool[torch.Tensor, B N_atom K]
gt_atom_distogram_sparse: jaxtyping.Float[torch.Tensor, B N_atom K n_atom_bins]
gt_res_distogram: jaxtyping.Int[torch.Tensor, B N_res N_res n_templ_bins]
r_gt: jaxtyping.Float[torch.Tensor, B N_atom 3]
r_input: jaxtyping.Float[torch.Tensor, B N_atom 3]
ref_element: jaxtyping.Float[torch.Tensor, B N_atom 4]
ref_pos: jaxtyping.Float[torch.Tensor, B N_atom 3]
ref_space_uid: jaxtyping.Int[torch.Tensor, B N_atom]
residue_mask: jaxtyping.Bool[torch.Tensor, B N_res]
t_hat: float
t_normalized: float
tok_idx: jaxtyping.Int[torch.Tensor, B N_atom]
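Some FeaturizedBatch fields are per-residue (`N_res`) and others are per-atom (`N_atom`). A common way to move from one to the other is a gather along the residue axis. This sketch assumes `tok_idx[b, a]` holds the index of the residue that owns atom `a`; that reading is inferred from the field name and shape and is not confirmed by this page. `residue_to_atom` is a hypothetical helper.

```python
import torch


def residue_to_atom(res_feat: torch.Tensor, tok_idx: torch.Tensor) -> torch.Tensor:
    """Broadcast per-residue features (B, N_res, C) to per-atom features (B, N_atom, C).

    Assumes tok_idx[b, a] is the residue index (in [0, N_res)) that owns atom a.
    """
    B, N_atom = tok_idx.shape
    C = res_feat.shape[-1]
    # Repeat each atom's residue index across the channel dim so gather picks whole rows.
    index = tok_idx.long().unsqueeze(-1).expand(B, N_atom, C)
    # out[b, a, c] = res_feat[b, tok_idx[b, a], c]
    return torch.gather(res_feat, 1, index)
```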
class pallatom.helpers.featurize.ProteinBatch
atom_mask: jaxtyping.Float[torch.Tensor, B N_res 37]
atom_positions: jaxtyping.Float[torch.Tensor, B N_res 37 3]
residue_index: jaxtyping.Float[torch.Tensor, B N_res]
seq: list[str]
pallatom.helpers.featurize.apply_conditioning_dropout(batch: FeaturizedBatch, p_distogram: float, p_atom: float, p_seq: float, device: str) → FeaturizedBatch
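apply_conditioning_dropout has no docstring on this page. One plausible reading of `p_distogram`, `p_atom` and `p_seq` is an independent per-sample drop probability for each conditioning signal. The sketch below shows only that generic per-sample dropout on a single tensor. The helper name, the choice to zero out dropped samples, and which FeaturizedBatch fields each probability would apply to are all assumptions.

```python
from typing import Optional

import torch


def drop_per_sample(
    x: torch.Tensor, p: float, generator: Optional[torch.Generator] = None
) -> torch.Tensor:
    """Zero out all of x[b] with probability p, independently for each batch index b.

    Hypothetical helper; not the pallatom implementation.
    """
    # One uniform draw per batch element; keep when the draw is >= p,
    # so p=0.0 keeps everything and p=1.0 drops everything.
    draws = torch.rand(x.shape[0], generator=generator, device=x.device)
    keep = draws >= p
    # Reshape to (B, 1, ..., 1) so the flag broadcasts over the non-batch dims.
    keep = keep.view(-1, *([1] * (x.dim() - 1)))
    # torch.where works for float, int and bool tensors alike.
    return torch.where(keep, x, torch.zeros_like(x))
```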
pallatom.helpers.featurize.featurize_batch(batch: ProteinBatch, tcfg: train.train_config.TrainConfig, c_beta_distogram_fn: Distogram, atom_distogram_fn: Distogram, device: str = 'cuda' if torch.cuda.is_available() else 'cpu') → FeaturizedBatch
pallatom.helpers.featurize.sinusoidal_encoding(positions: jaxtyping.Float[torch.Tensor, batch N_res], dim: int = 32) → jaxtyping.Float[torch.Tensor, batch N_res dim]

Sinusoidal positional encoding. positions: (batch, N_res) → (batch, N_res, dim)
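Below is a minimal sketch of a standard Transformer-style sinusoidal encoding with the same input and output shapes. The frequency base of 10000 and the channel layout (all sine channels followed by all cosine channels, rather than interleaved) are assumptions. The `pallatom` implementation may differ in either choice. `sinusoidal_encoding_sketch` is a hypothetical name.

```python
import math

import torch


def sinusoidal_encoding_sketch(
    positions: torch.Tensor, dim: int = 32, base: float = 10000.0
) -> torch.Tensor:
    """Map positions (batch, N_res) to encodings (batch, N_res, dim).

    Hypothetical sketch; base and sin/cos layout are assumptions.
    """
    if dim % 2 != 0:
        raise ValueError("dim must be even: half sine channels, half cosine channels")
    half = dim // 2
    # Geometric ladder of frequencies: base ** (-k / half) for k = 0 .. half - 1.
    k = torch.arange(half, dtype=torch.float32, device=positions.device)
    freqs = torch.exp(-math.log(base) * k / half)
    # (batch, N_res, 1) * (half,) -> (batch, N_res, half)
    angles = positions.to(torch.float32).unsqueeze(-1) * freqs
    # Assumed layout: all sines first, then all cosines.
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
```

The lowest channel pair oscillates once per 2π residues, and each further pair is slower by a constant factor. Nearby positions therefore get similar encodings, while distant ones remain distinguishable.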