sampling
EDM sampling from a trained MainTrunk denoising network.
Attributes
- ATOM5_TO_ATOM37
- NATOM
- log
- parser
Classes
- AllAtomContext
- EDMPrecond: wraps MainTrunk as an EDM-compatible denoiser D_θ(r_noisy, σ) → r_denoised.
- EDMSampler: Karras et al. 2022 deterministic (Heun) sampler.
- TemplateContext
Functions
- atom5_to_atom37: map atom5 coordinates back into the full atom37 layout.
- build_AA_context
- build_sampling_context: build the static context FeaturizedBatch for conditional or unconditional sampling.
- build_template_context
Module Contents
- class sampling.AllAtomContext
- aa_indices: jaxtyping.Int[torch.Tensor, 'B N_res']
- atom5_mask: jaxtyping.Bool[torch.Tensor, 'B N_atom']
- f_residue_idx: jaxtyping.Float[torch.Tensor, 'B N_res c_res']
- gt_atom_distogram_mask_sparse: jaxtyping.Bool[torch.Tensor, 'B N_atom K']
- gt_atom_distogram_sparse: jaxtyping.Float[torch.Tensor, 'B N_atom K n_atom_bins']
- r_gt: jaxtyping.Float[torch.Tensor, 'B N_atom 3']
- residue_mask: jaxtyping.Bool[torch.Tensor, 'B N_res']
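AllAtomContext stores the ground-truth atom distogram in a sparse layout with K entries per atom, plus a matching mask. One plausible way to build such tensors is sketched below. It assumes, without confirmation from this page, that the K entries are each atom's K nearest present neighbours and that distances are one-hot binned against fixed edges; the function name `sparse_atom_distogram` and the bin edges are hypothetical, and the batch axis B is omitted.

```python
import numpy as np


def sparse_atom_distogram(coords, atom_mask, k, bin_edges):
    """Hypothetical sketch: one-hot distogram over each atom's k nearest present neighbours.

    coords:    (N_atom, 3) float array
    atom_mask: (N_atom,) bool array, True where the atom is present
    k:         neighbours kept per atom (assumed k <= N_atom)
    bin_edges: (n_bins - 1,) increasing distance thresholds (assumed values)
    Returns dgram (N_atom, k, n_bins) and pair_mask (N_atom, k).
    """
    n_bins = len(bin_edges) + 1
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    valid = atom_mask[:, None] & atom_mask[None, :]
    np.fill_diagonal(valid, False)                      # exclude self-pairs
    masked = np.where(valid, dist, np.inf)              # invalid pairs sort last
    nbr = np.argsort(masked, axis=-1)[:, :k]            # (N_atom, k) neighbour indices
    nbr_dist = np.take_along_axis(dist, nbr, axis=-1)
    pair_mask = np.take_along_axis(valid, nbr, axis=-1)
    bins = np.digitize(nbr_dist, bin_edges)             # bin index in [0, n_bins - 1]
    dgram = np.eye(n_bins, dtype=np.float32)[bins]      # one-hot over distance bins
    dgram *= pair_mask[..., None]                       # zero out masked pairs
    return dgram, pair_mask
```

Stacking B copies of the two outputs along a new leading axis gives tensors shaped like `gt_atom_distogram_sparse` and `gt_atom_distogram_mask_sparse`.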
- class sampling.EDMPrecond(model: architecture.main_trunk.MainTrunk, context: helpers.featurize.FeaturizedBatch, sigma_min: float = 0.002, sigma_max: float = 80.0)
Bases: torch.nn.Module
Wraps MainTrunk as an EDM-compatible denoiser D_θ(r_noisy, σ) → r_denoised.
- Parameters:
model (MainTrunk) – the trained denoising network
context (FeaturizedBatch) – batch with the static fields filled in; r_input and t_hat are replaced at every forward call
sigma_min (float) – lower σ bound, used only to compute t_normalized
sigma_max (float) – upper σ bound, used only to compute t_normalized
- forward(r_input: jaxtyping.Float[torch.Tensor, 'B N_atom 3'], t_hat: float) → jaxtyping.Float[torch.Tensor, 'B N_atom 3']
- context
- model
- sigma_max = 80.0
- sigma_min = 0.002
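EDMPrecond keeps one fixed context and swaps in r_input and t_hat on every call, using the σ bounds only to derive t_normalized. The pattern is sketched below with plain Python objects standing in for torch modules. `ToyBatch`, `ToyPrecond` and the log-scale mapping of σ onto [0, 1] are hypothetical: this page does not say how t_normalized is computed, where the model reads it from, or whether EDMPrecond also applies the EDM c_skip/c_out/c_in scalings.

```python
import dataclasses
import math
from typing import Optional

import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyBatch:
    """Hypothetical stand-in for FeaturizedBatch: static features plus per-call fields."""
    residue_mask: np.ndarray
    r_input: Optional[np.ndarray] = None
    t_hat: Optional[float] = None
    t_normalized: Optional[float] = None


class ToyPrecond:
    """Sketch of the wrapping pattern: fixed context, fresh r_input/t_hat on every call."""

    def __init__(self, model, context, sigma_min=0.002, sigma_max=80.0):
        self.model = model
        self.context = context
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def __call__(self, r_input, t_hat):
        # Assumed normalisation: position of log(sigma) between log(sigma_min) and log(sigma_max).
        t_norm = (math.log(t_hat) - math.log(self.sigma_min)) / (
            math.log(self.sigma_max) - math.log(self.sigma_min)
        )
        # dataclasses.replace builds a new batch, so the stored context is never mutated.
        batch = dataclasses.replace(
            self.context, r_input=r_input, t_hat=t_hat, t_normalized=t_norm
        )
        return self.model(batch)
```

Rebuilding the batch per call keeps the static context reusable across every step of a sampling run.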
- class sampling.EDMSampler(denoiser: EDMPrecond, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: float = 7.0, S_churn: float = 0.0, S_tmin: float = 0.0, S_tmax: float = float('inf'), S_noise: float = 1.003)
Karras et al. 2022 Heun sampler; deterministic with the default S_churn = 0, stochastic when S_churn > 0.
- Parameters:
denoiser (EDMPrecond) – wrapper around a trained MainTrunk
sigma_min (float) – smallest noise level (paper: 0.002)
sigma_max (float) – largest noise level (paper: 80.0)
rho (float) – schedule exponent (paper: 7.0)
S_churn (float) – amount of stochastic noise injected per step (0 = deterministic)
S_tmin (float) – noise is injected only while σ lies in [S_tmin, S_tmax]
S_tmax (float) – upper end of that σ range
S_noise (float) – scaling of the injected noise
- sample(shape: tuple[int, int, int], steps: int = 40, device: torch.device | str = 'cpu') → jaxtyping.Float[torch.Tensor, 'B N_atom 3']
- Parameters:
shape (tuple[int, int, int]) – (B, N_atom, 3), the shape of the batch of atom coordinate tensors to generate
steps (int) – number of ODE steps (40 is usually plenty)
- S_churn = 0.0
- S_noise = 1.003
- S_tmax
- S_tmin = 0.0
- denoiser
- rho = 7.0
- sigma_max = 80.0
- sigma_min = 0.002
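For reference, a minimal NumPy sketch of the sampler from Karras et al. (2022): the ρ-spaced noise schedule with a final σ = 0, Euler steps with a Heun correction, and optional churn gated by S_tmin/S_tmax. This is the paper's Algorithm 2, which reduces to the deterministic Algorithm 1 when S_churn = 0. It calls a generic `denoise(x, sigma)` function instead of EDMPrecond, and `karras_sigmas`, `edm_sample` and the `seed` argument are illustrative names, not this module's API.

```python
import numpy as np


def karras_sigmas(steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels from sigma_max down to sigma_min (steps >= 2), then a final 0."""
    i = np.arange(steps)
    hi, lo = sigma_max ** (1 / rho), sigma_min ** (1 / rho)
    sigmas = (hi + i / (steps - 1) * (lo - hi)) ** rho
    return np.append(sigmas, 0.0)


def edm_sample(denoise, shape, steps=40, sigma_min=0.002, sigma_max=80.0, rho=7.0,
               S_churn=0.0, S_tmin=0.0, S_tmax=float("inf"), S_noise=1.003, seed=0):
    """Heun sampler (Algorithm 2 of the paper); deterministic when S_churn == 0."""
    rng = np.random.default_rng(seed)
    t = karras_sigmas(steps, sigma_min, sigma_max, rho)
    x = rng.standard_normal(shape) * t[0]                 # x_0 ~ N(0, sigma_max^2 I)
    for i in range(steps):
        t_cur, t_next = t[i], t[i + 1]
        # Churn: temporarily raise the noise level to t_hat >= t_cur inside [S_tmin, S_tmax].
        gamma = min(S_churn / steps, np.sqrt(2) - 1) if S_tmin <= t_cur <= S_tmax else 0.0
        t_hat = t_cur + gamma * t_cur
        x_hat = x + np.sqrt(t_hat**2 - t_cur**2) * S_noise * rng.standard_normal(shape)
        # Euler step along dx/dsigma = (x - D(x, sigma)) / sigma.
        d = (x_hat - denoise(x_hat, t_hat)) / t_hat
        x = x_hat + (t_next - t_hat) * d
        # Heun (trapezoidal) correction, skipped on the last step where t_next == 0.
        if t_next > 0:
            d_prime = (x - denoise(x, t_next)) / t_next
            x = x_hat + (t_next - t_hat) * (0.5 * d + 0.5 * d_prime)
    return x
```

In this sketch every step except the last calls the denoiser twice, so `steps=40` costs 79 network evaluations.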
- class sampling.TemplateContext
- f_pseudo_beta_mask: jaxtyping.Int[torch.Tensor, 'B N_res']
- f_template_distogram: jaxtyping.Int[torch.Tensor, 'B N_res N_res n_templ_bins']
- sampling.atom5_to_atom37(coords_5: jaxtyping.Float[numpy.ndarray, 'N_res 5 3'], mask_5: jaxtyping.Float[numpy.ndarray, 'N_res 5'] | None = None) → tuple[jaxtyping.Float[numpy.ndarray, 'N_res 37 3'], jaxtyping.Float[numpy.ndarray, 'N_res 37']]
Map atom5 coordinates back into the full atom37 layout.
- Returns:
x_37 ((N_res, 37, 3) array) – coordinates in atom37 layout
mask_37 ((N_res, 37) array) – atom37 presence mask
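Mapping atom5 back to atom37 amounts to scattering five slots per residue into a zero-initialised 37-slot array. The sketch below assumes that atom5 holds (N, CA, C, O, CB) and uses the AlphaFold atom37 order, where N, CA, C, CB, O occupy indices 0 to 4; the real contents of `ATOM5_TO_ATOM37` are not shown on this page, so `ATOM5_TO_ATOM37_GUESS` and `atom5_to_atom37_sketch` are hypothetical.

```python
import numpy as np

# Hypothetical mapping, atom5 slot -> atom37 index, assuming atom5 = (N, CA, C, O, CB)
# and the AlphaFold atom37 order (N=0, CA=1, C=2, CB=3, O=4).
ATOM5_TO_ATOM37_GUESS = np.array([0, 1, 2, 4, 3])


def atom5_to_atom37_sketch(coords_5, mask_5=None, index=ATOM5_TO_ATOM37_GUESS):
    """Scatter (N_res, 5, 3) coordinates into (N_res, 37, 3) plus an (N_res, 37) mask."""
    n_res = coords_5.shape[0]
    if mask_5 is None:
        # No mask given: treat every atom5 slot as present.
        mask_5 = np.ones(coords_5.shape[:2], dtype=np.float32)
    x_37 = np.zeros((n_res, 37, 3), dtype=coords_5.dtype)
    mask_37 = np.zeros((n_res, 37), dtype=np.float32)
    x_37[:, index] = coords_5 * mask_5[..., None]   # absent atoms stay at the origin
    mask_37[:, index] = mask_5
    return x_37, mask_37
```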
- sampling.build_AA_context(atom_37_coordinate_tensor: jaxtyping.Float[torch.Tensor, 'N_res 37 3'], atom_37_mask: jaxtyping.Float[torch.Tensor, 'N_res 37'], residue_index: jaxtyping.Float[torch.Tensor, 'N_res'], aa_sequence: str, atom_distogram_fn: helpers.featurize.Distogram, batch_size: int, device: str, c_res: int) → AllAtomContext
- sampling.build_sampling_context(atom_positions: jaxtyping.Float[torch.Tensor, 'N_res 37 3'], atom_mask: jaxtyping.Float[torch.Tensor, 'N_res 37'], residue_index: jaxtyping.Float[torch.Tensor, 'N_res'], seq: str, pdb_files: list[str], atom_distogram_fn: helpers.featurize.Distogram, templ_distogram_fn: helpers.featurize.Distogram, c_res: int, batch_size: int = 1, device: str = 'cpu') → helpers.featurize.FeaturizedBatch
Build the static context FeaturizedBatch for conditional or unconditional sampling.
- Parameters:
atom_positions ((N_res, 37, 3) tensor) – reference atom coordinates in atom37 layout
atom_mask ((N_res, 37) tensor) – float mask; 1 where the atom is present
residue_index ((N_res,) tensor) – per-residue position index; a sinusoidal encoding is applied internally
seq (str) – amino-acid sequence of length N_res
pdb_files (list[str]) – PDB paths used as templates; an empty list gives unconditioned templates
atom_distogram_fn (Distogram) – distogram for atom-level pairwise distances
templ_distogram_fn (Distogram) – distogram for template Cβ pairwise distances
c_res (int) – residue embedding dimension
batch_size (int) – B, the number of parallel samples; all share the same context
device (str) – torch device string
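Two steps described above, the internal sinusoidal encoding of residue_index and the sharing of one context across all B samples, can be sketched as follows. The Transformer-style sine/cosine layout, the `max_period` constant and the even-c_res requirement are assumptions, since the actual encoding is not documented here; `sinusoidal_residue_encoding` and `tile_to_batch` are hypothetical helpers.

```python
import numpy as np


def sinusoidal_residue_encoding(residue_index, c_res, max_period=10000.0):
    """Assumed Transformer-style encoding: (N_res,) indices -> (N_res, c_res), c_res even."""
    half = c_res // 2
    freqs = max_period ** (-np.arange(half) / half)          # geometric frequency ladder
    angles = residue_index[:, None].astype(np.float64) * freqs[None, :]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)


def tile_to_batch(x, batch_size):
    """Give each of the B samples an identical copy of a per-protein feature."""
    return np.broadcast_to(x, (batch_size, *x.shape)).copy()
```

Applied in sequence, the two helpers produce a (B, N_res, c_res) array, the shape of `AllAtomContext.f_residue_idx`.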
- sampling.build_template_context(ls_of_proteins: list[helpers.atom_utils.Protein], distogram_fn: helpers.featurize.Distogram, device: str = 'cpu') → TemplateContext
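TemplateContext carries a one-hot template distogram and a pseudo-β mask. A sketch of how one template's features could be derived follows. It assumes the AlphaFold pseudo-β convention (Cβ, or Cα for glycine, at atom37 indices 3 and 1) and fixed bin edges in place of the module's Distogram object; `pseudo_beta` and `template_distogram` are hypothetical names.

```python
import numpy as np


def pseudo_beta(atom37, mask37, is_gly):
    """Assumed convention: CB position (atom37 index 3), or CA (index 1) for glycine."""
    pos = np.where(is_gly[:, None], atom37[:, 1], atom37[:, 3])
    mask = np.where(is_gly, mask37[:, 1], mask37[:, 3])
    return pos, mask


def template_distogram(pb_pos, pb_mask, bin_edges):
    """One-hot (N_res, N_res, n_bins) distogram of pseudo-beta distances for one template."""
    n_bins = len(bin_edges) + 1
    dist = np.linalg.norm(pb_pos[:, None, :] - pb_pos[None, :, :], axis=-1)
    bins = np.digitize(dist, bin_edges)                       # bin index per residue pair
    dgram = np.eye(n_bins, dtype=np.int64)[bins]
    pair_mask = (pb_mask[:, None] * pb_mask[None, :]).astype(np.int64)
    return dgram * pair_mask[..., None]                       # zero out masked pairs
```

Adding a leading batch axis to the distogram and to the pseudo-β mask gives arrays shaped like `f_template_distogram` and `f_pseudo_beta_mask`.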
- sampling.ATOM5_TO_ATOM37
- sampling.NATOM = 5
- sampling.log
- sampling.parser