sampling

EDM sampling from a trained MainTrunk denoising network.

Attributes

Classes

AllAtomContext

EDMPrecond

Wraps MainTrunk as an EDM-compatible denoiser D_θ(r_noisy, σ) → r_denoised.

EDMSampler

Karras et al. 2022 Heun (second-order) sampler; deterministic when S_churn = 0, which is the default.

TemplateContext

Functions

atom5_to_atom37(→ tuple[jaxtyping.Float[numpy.ndarray, ...)

Map atom5 coordinates back into the full atom37 layout.

build_AA_context(→ AllAtomContext)

build_sampling_context(→ helpers.featurize.FeaturizedBatch)

Build the static context FeaturizedBatch for conditional or unconditional sampling.

build_template_context(→ TemplateContext)

Module Contents

class sampling.AllAtomContext
aa_indices: jaxtyping.Int[torch.Tensor, B N_res]
atom5_mask: jaxtyping.Bool[torch.Tensor, B N_atom]
f_residue_idx: jaxtyping.Float[torch.Tensor, B N_res c_res]
gt_atom_distogram_mask_sparse: jaxtyping.Bool[torch.Tensor, B N_atom K]
gt_atom_distogram_sparse: jaxtyping.Float[torch.Tensor, B N_atom K n_atom_bins]
r_gt: jaxtyping.Float[torch.Tensor, B N_atom 3]
residue_mask: jaxtyping.Bool[torch.Tensor, B N_res]
class sampling.EDMPrecond(model: architecture.main_trunk.MainTrunk, context: helpers.featurize.FeaturizedBatch, sigma_min: float = 0.002, sigma_max: float = 80.0)

Bases: torch.nn.Module

Wraps MainTrunk as an EDM-compatible denoiser D_θ(r_noisy, σ) → r_denoised.

Parameters:
  • model (trained MainTrunk)

  • context (FeaturizedBatch with static fields filled in; r_input and t_hat are replaced at every forward call)

  • sigma_min (lower σ bound, used only to compute t_normalized)

  • sigma_max (upper σ bound, used only to compute t_normalized)

forward(r_input: jaxtyping.Float[torch.Tensor, B N_atom 3], t_hat: float) → jaxtyping.Float[torch.Tensor, B N_atom 3]
context
model
sigma_max = 80.0
sigma_min = 0.002
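
The "static context plus two per-call fields" pattern described for EDMPrecond can be sketched with the standard library alone. This is an illustrative sketch, not the module's implementation: `ToyBatch`, `ToyPrecond`, and `toy_model` are hypothetical stand-ins, and it assumes the batch behaves like a dataclass, which this page does not confirm for FeaturizedBatch.

```python
from dataclasses import dataclass, replace
from typing import Callable, List, Optional


@dataclass(frozen=True)
class ToyBatch:
    """Hypothetical stand-in for FeaturizedBatch (assumption: dataclass-like)."""
    static_feature: str                    # filled in once, shared by every call
    r_input: Optional[List[float]] = None  # replaced at every call
    t_hat: Optional[float] = None          # replaced at every call


class ToyPrecond:
    """Holds a model and a static context; swaps in r_input and t_hat per call."""

    def __init__(self, model: Callable[[ToyBatch], List[float]], context: ToyBatch):
        self.model = model
        self.context = context

    def __call__(self, r_input: List[float], t_hat: float) -> List[float]:
        # replace() returns a new batch, so the stored context is never mutated.
        batch = replace(self.context, r_input=r_input, t_hat=t_hat)
        return self.model(batch)


def toy_model(batch: ToyBatch) -> List[float]:
    """Hypothetical model: shrinks the input more strongly at higher noise."""
    return [x / (1.0 + batch.t_hat) for x in batch.r_input]


denoiser = ToyPrecond(toy_model, ToyBatch(static_feature="sequence features"))
denoised = denoiser([2.0, 4.0], t_hat=1.0)
```

Because the wrapper builds a fresh batch on each call, the same static context can be reused across all sampler steps.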
class sampling.EDMSampler(denoiser: EDMPrecond, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: float = 7.0, S_churn: float = 0.0, S_tmin: float = 0.0, S_tmax: float = float('inf'), S_noise: float = 1.003)

Karras et al. 2022 Heun (second-order) sampler; deterministic when S_churn = 0, which is the default.

Parameters:
  • denoiser (EDMPrecond wrapping a trained MainTrunk)

  • sigma_min (float) – smallest noise level (paper: 0.002)

  • sigma_max (float) – largest noise level (paper: 80.0)

  • rho (float) – schedule exponent (paper: 7.0)

  • S_churn (float) – stochastic noise injected per step (0 = deterministic)

  • S_tmin (float) – lower bound of the noise-level range [S_tmin, S_tmax] in which noise is injected

  • S_tmax (float) – upper bound of that range

  • S_noise (float) – scaling of the injected noise
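
As a rough illustration of how these parameters interact, the sketch below follows the noise schedule and churn rule from Karras et al. 2022. Whether EDMSampler computes them exactly this way is not stated on this page; the helper names `karras_sigmas` and `churn_gamma` are hypothetical.

```python
import math
from typing import List


def karras_sigmas(steps: int, sigma_min: float = 0.002,
                  sigma_max: float = 80.0, rho: float = 7.0) -> List[float]:
    """Noise levels from sigma_max down to sigma_min, followed by a final 0."""
    max_root = sigma_max ** (1.0 / rho)
    min_root = sigma_min ** (1.0 / rho)
    sigmas = []
    for i in range(steps):
        frac = i / max(steps - 1, 1)  # 0 at the first step, 1 at the last
        sigmas.append((max_root + frac * (min_root - max_root)) ** rho)
    return sigmas + [0.0]


def churn_gamma(sigma: float, steps: int, S_churn: float = 0.0,
                S_tmin: float = 0.0, S_tmax: float = float("inf")) -> float:
    """Relative noise increase for one step; 0 outside [S_tmin, S_tmax]."""
    if S_tmin <= sigma <= S_tmax:
        return min(S_churn / steps, math.sqrt(2.0) - 1.0)
    return 0.0


sigmas = karras_sigmas(40)
gamma_default = churn_gamma(sigmas[0], steps=40)            # S_churn = 0
gamma_churned = churn_gamma(sigmas[0], steps=40, S_churn=80.0)
```

In the paper, a step with gamma > 0 raises the noise level to sigma * (1 + gamma) by adding Gaussian noise scaled by S_noise before the Heun update. With the default S_churn = 0, gamma is always 0 and the sampler stays deterministic.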

sample(shape: tuple[int, int, int], steps: int = 40, device: torch.device | str = 'cpu') → jaxtyping.Float[torch.Tensor, B N_atom 3]

Parameters:
  • shape ((B, N_atom, 3) shape of the coordinate tensor to sample)

  • steps (number of ODE steps; the default of 40 is usually sufficient)
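
The deterministic Heun loop from Karras et al. 2022 can be sketched on a scalar toy problem. This is a minimal sketch under stated assumptions, not the code behind `sample`: it works on a single float instead of a (B, N_atom, 3) tensor, and `toy_denoiser`, `heun_sample`, and `karras_sigmas` are hypothetical names. The toy data distribution is a point mass at 1.5, for which the ideal denoiser returns 1.5 at every noise level.

```python
import random
from typing import Callable, List


def karras_sigmas(steps: int, sigma_min: float = 0.002,
                  sigma_max: float = 80.0, rho: float = 7.0) -> List[float]:
    """Karras et al. 2022 schedule: sigma_max down to sigma_min, then 0."""
    max_root, min_root = sigma_max ** (1 / rho), sigma_min ** (1 / rho)
    sigmas = [(max_root + i / max(steps - 1, 1) * (min_root - max_root)) ** rho
              for i in range(steps)]
    return sigmas + [0.0]


def heun_sample(denoise: Callable[[float, float], float],
                x: float, sigmas: List[float]) -> float:
    """Integrate dx/dsigma = (x - D(x, sigma)) / sigma with Heun's method."""
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        d = (x - denoise(x, sigma)) / sigma      # slope at the current level
        x_euler = x + (sigma_next - sigma) * d   # Euler predictor
        if sigma_next > 0:
            # Second-order correction: average the slopes at both ends.
            d_next = (x_euler - denoise(x_euler, sigma_next)) / sigma_next
            x = x + (sigma_next - sigma) * 0.5 * (d + d_next)
        else:
            x = x_euler                          # last step to sigma = 0 is plain Euler
    return x


def toy_denoiser(x: float, sigma: float) -> float:
    """Ideal denoiser for data concentrated at the single point 1.5."""
    return 1.5


random.seed(0)
sigmas = karras_sigmas(steps=40)
x_init = sigmas[0] * random.gauss(0.0, 1.0)      # start from pure noise
sample = heun_sample(toy_denoiser, x_init, sigmas)
```

Each step except the last calls the denoiser twice, so 40 steps cost 79 denoiser evaluations in this sketch.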

S_churn = 0.0
S_noise = 1.003
S_tmax
S_tmin = 0.0
denoiser
rho = 7.0
sigma_max = 80.0
sigma_min = 0.002
class sampling.TemplateContext
f_pseudo_beta_mask: jaxtyping.Int[torch.Tensor, B N_res]
f_template_distogram: jaxtyping.Int[torch.Tensor, B N_res N_res n_templ_bins]
sampling.atom5_to_atom37(coords_5: jaxtyping.Float[numpy.ndarray, N_res 5 3], mask_5: jaxtyping.Float[numpy.ndarray, N_res 5] | None = None) → tuple[jaxtyping.Float[numpy.ndarray, N_res 37 3], jaxtyping.Float[numpy.ndarray, N_res 37]]

Map atom5 coordinates back into the full atom37 layout.

Returns:

  • x_37 ((N_res, 37, 3))

  • mask_37 ((N_res, 37))
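
A scatter of this kind might look like the sketch below. It is not the module's implementation: the index table is an assumption (atom5 ordered as N, CA, C, O, CB, placed into the AlphaFold-style atom37 slots N=0, CA=1, C=2, CB=3, O=4), the actual contents of `ATOM5_TO_ATOM37` are not shown on this page, and treating a missing `mask_5` as "all atoms present" is also assumed.

```python
import numpy as np

# Assumed mapping from atom5 order (N, CA, C, O, CB) to atom37 slot indices.
ATOM5_TO_ATOM37_SKETCH = np.array([0, 1, 2, 4, 3])


def atom5_to_atom37_sketch(coords_5, mask_5=None):
    """Scatter (N_res, 5, 3) coordinates into a zero-filled (N_res, 37, 3) array."""
    n_res = coords_5.shape[0]
    if mask_5 is None:
        # Assumption: with no mask given, all five atoms count as present.
        mask_5 = np.ones((n_res, 5), dtype=coords_5.dtype)
    x_37 = np.zeros((n_res, 37, 3), dtype=coords_5.dtype)
    mask_37 = np.zeros((n_res, 37), dtype=coords_5.dtype)
    # Zero out coordinates of absent atoms, then write into the atom37 slots.
    x_37[:, ATOM5_TO_ATOM37_SKETCH] = coords_5 * mask_5[..., None]
    mask_37[:, ATOM5_TO_ATOM37_SKETCH] = mask_5
    return x_37, mask_37


coords = np.arange(2 * 5 * 3, dtype=np.float64).reshape(2, 5, 3)
x_37, mask_37 = atom5_to_atom37_sketch(coords)
```

The 32 slots that have no atom5 counterpart stay at zero in both outputs.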

sampling.build_AA_context(atom_37_coordinate_tensor: jaxtyping.Float[torch.Tensor, N_res 37 3], atom_37_mask: jaxtyping.Float[torch.Tensor, N_res 37], residue_index: jaxtyping.Float[torch.Tensor, N_res], aa_sequence: str, atom_distogram_fn: helpers.featurize.Distogram, batch_size: int, device: str, c_res: int) → AllAtomContext
sampling.build_sampling_context(atom_positions: jaxtyping.Float[torch.Tensor, N_res 37 3], atom_mask: jaxtyping.Float[torch.Tensor, N_res 37], residue_index: jaxtyping.Float[torch.Tensor, N_res], seq: str, pdb_files: list[str], atom_distogram_fn: helpers.featurize.Distogram, templ_distogram_fn: helpers.featurize.Distogram, c_res: int, batch_size: int = 1, device: str = 'cpu') → helpers.featurize.FeaturizedBatch

Build the static context FeaturizedBatch for conditional or unconditional sampling.

Parameters:
  • atom_positions ((N_res, 37, 3) reference atom coordinates in atom37 layout)

  • atom_mask ((N_res, 37) float mask; 1 where atom is present)

  • residue_index ((N_res,) per-residue position index (sinusoidal encoding applied internally))

  • seq (amino-acid sequence string of length N_res)

  • pdb_files (PDB paths used as templates; empty list → unconditioned templates)

  • atom_distogram_fn (Distogram for atom-level pairwise distances)

  • templ_distogram_fn (Distogram for template Cβ pairwise distances)

  • c_res (residue embedding dimension)

  • batch_size (B — number of parallel samples; all share the same context)

  • device (torch device string)
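
The parameter list says a sinusoidal encoding is applied to residue_index internally but does not give the formula. The sketch below assumes the common Transformer-style encoding with a period base of 10000 and interleaved sine and cosine channels; the function name `sinusoidal_encoding` and the `max_period` argument are hypothetical, and the module may use a different variant.

```python
import numpy as np


def sinusoidal_encoding(residue_index, c_res, max_period=10000.0):
    """Map (N_res,) positions to (N_res, c_res) sine/cosine features."""
    if c_res % 2 != 0:
        raise ValueError("this sketch assumes an even embedding dimension")
    pos = np.asarray(residue_index, dtype=np.float64)[:, None]  # (N_res, 1)
    # One frequency per sine/cosine pair, decreasing geometrically.
    freqs = max_period ** (-np.arange(0, c_res, 2) / c_res)     # (c_res/2,)
    angles = pos * freqs[None, :]                               # (N_res, c_res/2)
    enc = np.empty((pos.shape[0], c_res))
    enc[:, 0::2] = np.sin(angles)  # even channels
    enc[:, 1::2] = np.cos(angles)  # odd channels
    return enc


enc = sinusoidal_encoding(np.arange(4), c_res=8)
# All B samples share one context, so the encoding can be broadcast over the batch.
batched = np.broadcast_to(enc[None], (3, 4, 8))
```

The broadcast result has the B N_res c_res layout listed for `f_residue_idx` in AllAtomContext.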

sampling.build_template_context(ls_of_proteins: list[helpers.atom_utils.Protein], distogram_fn: helpers.featurize.Distogram, device: str = 'cpu') → TemplateContext
sampling.ATOM5_TO_ATOM37
sampling.NATOM = 5
sampling.log
sampling.parser