pallatom.helpers.featurize
Classes
Distogram: Pairwise distogram module.
Functions
sinusoidal_encoding: Sinusoidal positional encoding. positions: (batch, N_res) → (batch, N_res, dim)
Module Contents
- class pallatom.helpers.featurize.Distogram(n_bins: int, min_dist: float = 2.0, max_dist: float = 22.0, overflow_bin: bool = False)
Bases: torch.nn.Module
Pairwise distogram module.
Precomputes bin edges once at init; forward() maps per-residue coordinates → one-hot distogram + validity mask. Accepts either:
- (..., N, 3): single atom per residue (e.g. pseudo-Cβ), auto-expanded to (..., N, 1, 3)
- (..., N, A, 3): A atoms per residue (e.g. atom5, atom37)
- Parameters:
n_bins – Number of distance bins.
min_dist – Lower edge of first bin in Ångströms (default 2.0).
max_dist – Upper edge of last bin in Ångströms (default 22.0).
overflow_bin – If True, adds one extra bin capturing distances > max_dist, making the output shape (…, n_bins + 1) instead of (…, n_bins).
- forward(coords: jaxtyping.Float[torch.Tensor, ... total_atom_count 3], coords_mask: jaxtyping.Bool[torch.Tensor, ... total_atom_count] | None = None) → tuple[jaxtyping.Float[torch.Tensor, ... total_atom_count total_atom_count n_bins], jaxtyping.Bool[torch.Tensor, ... total_atom_count total_atom_count]]
- Parameters:
coords – (…, total_atom_count, 3)
coords_mask – (…, total_atom_count) bool, True where valid; treated as all True if None.
- Returns:
- f_distogram: (…, total_atom_count, total_atom_count, n_bins [+1]) float, one-hot bin assignment; the last bin is the overflow bin when overflow_bin=True.
- f_pair_mask: (…, total_atom_count, total_atom_count) bool, True where the pair is valid. With overflow_bin=True: valid atom pairs only. With overflow_bin=False: valid atom pairs AND dist <= max_dist.
- Return type:
tuple (f_distogram, f_pair_mask)
- max_dist = 22.0
- min_dist = 2.0
- n_bins
- overflow_bin = False
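The binning and masking rules described for forward() can be illustrated with a minimal sketch. This is not pallatom's implementation: `distogram_sketch` is a hypothetical stand-in, and it assumes equal-width bins between min_dist and max_dist, places distances below min_dist in the first bin, zeroes the one-hot rows of invalid pairs, and handles only the flat (…, N, 3) layout.

```python
import torch
import torch.nn.functional as F


def distogram_sketch(coords, n_bins, min_dist=2.0, max_dist=22.0,
                     overflow_bin=False, coords_mask=None):
    """Hypothetical stand-in for Distogram.forward; not pallatom's code."""
    if coords_mask is None:
        coords_mask = torch.ones(coords.shape[:-1], dtype=torch.bool,
                                 device=coords.device)
    # Pairwise Euclidean distances: (..., N, 3) -> (..., N, N)
    dist = (coords[..., :, None, :] - coords[..., None, :, :]).norm(dim=-1)
    # Assumption: n_bins equal-width bins between min_dist and max_dist.
    edges = torch.linspace(min_dist, max_dist, n_bins + 1,
                           dtype=coords.dtype, device=coords.device)
    # Interior edges yield indices 0..n_bins-1; distances below min_dist
    # land in bin 0 (an assumption of this sketch).
    idx = torch.bucketize(dist, edges[1:-1])
    beyond = dist > max_dist
    # A pair is valid only if both of its atoms are valid.
    pair_mask = coords_mask[..., :, None] & coords_mask[..., None, :]
    if overflow_bin:
        # Route out-of-range distances to the extra last bin.
        idx = torch.where(beyond, torch.full_like(idx, n_bins), idx)
        n_classes = n_bins + 1
    else:
        # Without an overflow bin, out-of-range pairs are masked out.
        pair_mask = pair_mask & ~beyond
        n_classes = n_bins
    one_hot = F.one_hot(idx, n_classes).to(coords.dtype)
    return one_hot * pair_mask[..., None], pair_mask


# Three atoms on a line at x = 0, 3 and 30 Å.
coords = torch.tensor([[0.0, 0.0, 0.0], [3.0, 0.0, 0.0], [30.0, 0.0, 0.0]])
f_closed, m_closed = distogram_sketch(coords, n_bins=4)
f_open, m_open = distogram_sketch(coords, n_bins=4, overflow_bin=True)
```

In this sketch the 30 Å pair is masked out when overflow_bin=False and assigned to the extra last bin when overflow_bin=True, which matches the mask semantics documented above.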
- class pallatom.helpers.featurize.FeaturizedBatch
- aa_indices: jaxtyping.Int[torch.Tensor, B N_res]
- atom5_mask: jaxtyping.Bool[torch.Tensor, B N_atom]
- center_uid: jaxtyping.Int[torch.Tensor, B N_res]
- f_pseudo_beta_mask: jaxtyping.Int[torch.Tensor, B N_res]
- f_residue_idx: jaxtyping.Float[torch.Tensor, B N_res c_res]
- gt_atom_distogram_mask_sparse: jaxtyping.Bool[torch.Tensor, B N_atom K]
- gt_atom_distogram_sparse: jaxtyping.Float[torch.Tensor, B N_atom K n_atom_bins]
- gt_res_distogram: jaxtyping.Int[torch.Tensor, B N_res N_res n_templ_bins]
- r_gt: jaxtyping.Float[torch.Tensor, B N_atom 3]
- r_input: jaxtyping.Float[torch.Tensor, B N_atom 3]
- ref_element: jaxtyping.Float[torch.Tensor, B N_atom 4]
- ref_pos: jaxtyping.Float[torch.Tensor, B N_atom 3]
- ref_space_uid: jaxtyping.Int[torch.Tensor, B N_atom]
- residue_mask: jaxtyping.Bool[torch.Tensor, B N_res]
- tok_idx: jaxtyping.Int[torch.Tensor, B N_atom]
- class pallatom.helpers.featurize.ProteinBatch
- atom_mask: jaxtyping.Float[torch.Tensor, B N_res 37]
- atom_positions: jaxtyping.Float[torch.Tensor, B N_res 37 3]
- residue_index: jaxtyping.Float[torch.Tensor, B N_res]
- pallatom.helpers.featurize.apply_conditioning_dropout(batch: FeaturizedBatch, p_distogram: float, p_atom: float, p_seq: float, device: str) → FeaturizedBatch
- pallatom.helpers.featurize.featurize_batch(batch: ProteinBatch, tcfg: train.train_config.TrainConfig, c_beta_distogram_fn: Distogram, atom_distogram_fn: Distogram, device: str = 'cuda' if torch.cuda.is_available() else 'cpu') → FeaturizedBatch
- pallatom.helpers.featurize.sinusoidal_encoding(positions: jaxtyping.Float[torch.Tensor, batch N_res], dim: int = 32) → jaxtyping.Float[torch.Tensor, batch N_res dim]
Sinusoidal positional encoding. positions: (batch, N_res) → (batch, N_res, dim)
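The shape contract above can be illustrated with a common form of sinusoidal encoding. This is a sketch only: `sinusoidal_encoding_sketch` is a hypothetical name, and the `max_period` of 10000, the even `dim`, and the layout with sines followed by cosines are assumptions that may differ from what pallatom actually does.

```python
import math

import torch


def sinusoidal_encoding_sketch(positions, dim=32, max_period=10000.0):
    """Hypothetical stand-in; (batch, N_res) -> (batch, N_res, dim)."""
    half = dim // 2  # assumes dim is even
    # Geometrically spaced frequencies from 1 down towards 1 / max_period.
    exponents = torch.arange(half, dtype=torch.float32) / half
    freqs = torch.exp(-math.log(max_period) * exponents)
    # Broadcast positions against frequencies: (batch, N_res, half)
    angles = positions[..., None].float() * freqs
    # Assumption: sines first, then cosines, concatenated on the last axis.
    return torch.cat([angles.sin(), angles.cos()], dim=-1)


positions = torch.zeros(2, 5)
encoded = sinusoidal_encoding_sketch(positions, dim=32)
```

At position 0 every sine component is 0 and every cosine component is 1, which gives a quick sanity check on the layout.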