sampling
========

.. py:module:: sampling

.. autoapi-nested-parse::

   EDM sampling from a trained MainTrunk denoising network.


Attributes
----------

.. autoapisummary::

   sampling.ATOM5_TO_ATOM37
   sampling.NATOM
   sampling.log
   sampling.parser


Classes
-------

.. autoapisummary::

   sampling.AllAtomContext
   sampling.EDMPrecond
   sampling.EDMSampler
   sampling.TemplateContext


Functions
---------

.. autoapisummary::

   sampling.atom5_to_atom37
   sampling.build_AA_context
   sampling.build_sampling_context
   sampling.build_template_context


Module Contents
---------------

.. py:class:: AllAtomContext

   .. py:attribute:: aa_indices
      :type: jaxtyping.Int[torch.Tensor, B N_res]

   .. py:attribute:: atom5_mask
      :type: jaxtyping.Bool[torch.Tensor, B N_atom]

   .. py:attribute:: f_residue_idx
      :type: jaxtyping.Float[torch.Tensor, B N_res c_res]

   .. py:attribute:: gt_atom_distogram_mask_sparse
      :type: jaxtyping.Bool[torch.Tensor, B N_atom K]

   .. py:attribute:: gt_atom_distogram_sparse
      :type: jaxtyping.Float[torch.Tensor, B N_atom K n_atom_bins]

   .. py:attribute:: r_gt
      :type: jaxtyping.Float[torch.Tensor, B N_atom 3]

   .. py:attribute:: residue_mask
      :type: jaxtyping.Bool[torch.Tensor, B N_res]


.. py:class:: EDMPrecond(model: architecture.main_trunk.MainTrunk, context: helpers.featurize.FeaturizedBatch, sigma_min: float = 0.002, sigma_max: float = 80.0)

   Bases: :py:obj:`torch.nn.Module`

   Wraps MainTrunk as an EDM-compatible denoiser D_θ(r_noisy, σ) → r_denoised.

   :param model: trained MainTrunk
   :type model: architecture.main_trunk.MainTrunk
   :param context: FeaturizedBatch with the static fields filled in; ``r_input`` and ``t_hat`` are replaced at every forward call
   :type context: helpers.featurize.FeaturizedBatch
   :param sigma_min: lower σ bound, used only to compute ``t_normalized``
   :type sigma_min: float
   :param sigma_max: upper σ bound, used only to compute ``t_normalized``
   :type sigma_max: float

   .. py:method:: forward(r_input: jaxtyping.Float[torch.Tensor, B N_atom 3], t_hat: float) -> jaxtyping.Float[torch.Tensor, B N_atom 3]

   .. py:attribute:: context

   .. py:attribute:: model

   .. py:attribute:: sigma_max
      :value: 80.0

   .. py:attribute:: sigma_min
      :value: 0.002


.. py:class:: EDMSampler(denoiser: EDMPrecond, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: float = 7.0, S_churn: float = 0.0, S_tmin: float = 0.0, S_tmax: float = float('inf'), S_noise: float = 1.003)

   Heun (second-order) sampler from Karras et al. (2022). It is deterministic with the default ``S_churn = 0``. A positive ``S_churn`` enables stochastic noise injection.

   :param denoiser: EDMPrecond wrapping a trained MainTrunk
   :type denoiser: EDMPrecond
   :param sigma_min: smallest noise level (paper: 0.002)
   :type sigma_min: float
   :param sigma_max: largest noise level (paper: 80.0)
   :type sigma_max: float
   :param rho: exponent of the noise schedule (paper: 7.0)
   :type rho: float
   :param S_churn: amount of stochastic noise injected during sampling (0 = deterministic)
   :type S_churn: float
   :param S_tmin: lower end of the σ range in which noise is injected
   :type S_tmin: float
   :param S_tmax: upper end of the σ range in which noise is injected
   :type S_tmax: float
   :param S_noise: scale factor applied to the injected noise
   :type S_noise: float

   .. py:method:: sample(shape: tuple[int, int, int], steps: int = 40, device: torch.device | str = 'cpu') -> jaxtyping.Float[torch.Tensor, B N_atom 3]

      :param shape: ``(B, N_atom, 3)``, the shape of the batch of atom coordinates to sample
      :param steps: number of ODE steps (40 is usually sufficient)

   .. py:attribute:: S_churn
      :value: 0.0

   .. py:attribute:: S_noise
      :value: 1.003

   .. py:attribute:: S_tmax

   .. py:attribute:: S_tmin
      :value: 0.0

   .. py:attribute:: denoiser

   .. py:attribute:: rho
      :value: 7.0

   .. py:attribute:: sigma_max
      :value: 80.0

   .. py:attribute:: sigma_min
      :value: 0.002


.. py:class:: TemplateContext

   .. py:attribute:: f_pseudo_beta_mask
      :type: jaxtyping.Int[torch.Tensor, B N_res]

   .. py:attribute:: f_template_distogram
      :type: jaxtyping.Int[torch.Tensor, B N_res N_res n_templ_bins]


.. py:function:: atom5_to_atom37(coords_5: jaxtyping.Float[numpy.ndarray, N_res 5 3], mask_5: jaxtyping.Float[numpy.ndarray, N_res 5] | None = None) -> tuple[jaxtyping.Float[numpy.ndarray, N_res 37 3], jaxtyping.Float[numpy.ndarray, N_res 37]]

   Map atom5 coordinates back into the full atom37 layout.

   :returns: * **x_37** (*(N_res, 37, 3)*)
             * **mask_37** (*(N_res, 37)*)


.. py:function:: build_AA_context(atom_37_coordinate_tensor: jaxtyping.Float[torch.Tensor, N_res 37 3], atom_37_mask: jaxtyping.Float[torch.Tensor, N_res 37], residue_index: jaxtyping.Float[torch.Tensor, N_res], aa_sequence: str, atom_distogram_fn: helpers.featurize.Distogram, batch_size: int, device: str, c_res: int) -> AllAtomContext


.. py:function:: build_sampling_context(atom_positions: jaxtyping.Float[torch.Tensor, N_res 37 3], atom_mask: jaxtyping.Float[torch.Tensor, N_res 37], residue_index: jaxtyping.Float[torch.Tensor, N_res], seq: str, pdb_files: list[str], atom_distogram_fn: helpers.featurize.Distogram, templ_distogram_fn: helpers.featurize.Distogram, c_res: int, batch_size: int = 1, device: str = 'cpu') -> helpers.featurize.FeaturizedBatch

   Build the static context FeaturizedBatch for conditional or unconditional sampling.

   :param atom_positions: ``(N_res, 37, 3)`` reference atom coordinates in atom37 layout
   :param atom_mask: ``(N_res, 37)`` float mask; 1 where the atom is present
   :param residue_index: ``(N_res,)`` per-residue position index; the sinusoidal encoding is applied internally
   :param seq: amino-acid sequence string of length N_res
   :param pdb_files: PDB paths used as templates; an empty list gives unconditioned templates
   :param atom_distogram_fn: Distogram for atom-level pairwise distances
   :param templ_distogram_fn: Distogram for template Cβ pairwise distances
   :param c_res: residue embedding dimension
   :param batch_size: B, the number of parallel samples; all samples share the same context
   :param device: torch device string


.. py:function:: build_template_context(ls_of_proteins: list[helpers.atom_utils.Protein], distogram_fn: helpers.featurize.Distogram, device: str = 'cpu') -> TemplateContext


.. py:data:: ATOM5_TO_ATOM37

.. py:data:: NATOM
   :value: 5

.. py:data:: log

.. py:data:: parser
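
``EDMSampler`` takes ``sigma_min``, ``sigma_max`` and ``rho``. These are the three parameters of the noise discretization in Karras et al. (2022): :math:`\sigma_i = \left(\sigma_{max}^{1/\rho} + \frac{i}{N-1}\left(\sigma_{min}^{1/\rho} - \sigma_{max}^{1/\rho}\right)\right)^{\rho}` for :math:`i = 0, \dots, N-1`, with :math:`\sigma_N = 0` appended. The following is a minimal plain-Python sketch of that schedule. It assumes that ``EDMSampler.sample`` uses the paper's schedule with ``steps`` as N, which this page does not state. ``karras_sigmas`` is a hypothetical name.

.. code-block:: python

   def karras_sigmas(steps: int, sigma_min: float = 0.002,
                     sigma_max: float = 80.0, rho: float = 7.0) -> list[float]:
       """Hypothetical helper: Karras et al. (2022) noise levels, high to low.

       Returns ``steps + 1`` values: ``steps`` levels interpolated in
       sigma**(1/rho) space, followed by a final 0.0.
       """
       if steps < 2:
           raise ValueError("steps must be at least 2")
       hi = sigma_max ** (1.0 / rho)
       lo = sigma_min ** (1.0 / rho)
       # Linear interpolation in sigma**(1/rho) space, then raise back to rho.
       sigmas = [(hi + i / (steps - 1) * (lo - hi)) ** rho for i in range(steps)]
       return sigmas + [0.0]

A larger ``rho`` shortens the steps near ``sigma_min`` and lengthens the steps near ``sigma_max``. Karras et al. recommend ``rho = 7``.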
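
``EDMSampler`` is named after the Heun update. Algorithm 2 of Karras et al. (2022) describes that update together with the optional churn controlled by ``S_churn``, ``S_tmin``, ``S_tmax`` and ``S_noise``. The following is a minimal sketch of that algorithm that operates on a flat list of floats instead of ``(B, N_atom, 3)`` tensors. It is not the code behind ``EDMSampler.sample``. ``heun_sample`` and ``Denoiser`` are hypothetical names, and ``denoise(x, sigma)`` stands in for ``EDMPrecond.forward``.

.. code-block:: python

   import math
   import random
   from typing import Callable, Optional, Sequence

   # Hypothetical type: a denoiser maps (noisy x, sigma) to a denoised x.
   Denoiser = Callable[[list[float], float], list[float]]


   def heun_sample(
       denoise: Denoiser,
       x: list[float],
       sigmas: Sequence[float],
       S_churn: float = 0.0,
       S_tmin: float = 0.0,
       S_tmax: float = float("inf"),
       S_noise: float = 1.003,
       rng: Optional[random.Random] = None,
   ) -> list[float]:
       """Sketch of Karras et al. (2022) Algorithm 2 on a flat list of floats.

       ``x`` is the initial sample at noise level ``sigmas[0]``; ``sigmas``
       decreases and ends with 0.0.
       """
       rng = rng if rng is not None else random.Random(0)
       n = len(sigmas) - 1  # number of steps N
       for i in range(n):
           s_cur, s_next = sigmas[i], sigmas[i + 1]
           # Churn: temporarily raise the noise level from s_cur to s_hat.
           in_range = S_tmin <= s_cur <= S_tmax
           gamma = min(S_churn / n, math.sqrt(2.0) - 1.0) if in_range else 0.0
           s_hat = s_cur * (1.0 + gamma)
           if gamma > 0.0:
               std = math.sqrt(s_hat ** 2 - s_cur ** 2) * S_noise
               x = [xi + std * rng.gauss(0.0, 1.0) for xi in x]
           # Euler step along the ODE dx/dsigma = (x - D(x, sigma)) / sigma.
           d = [(xi - di) / s_hat for xi, di in zip(x, denoise(x, s_hat))]
           x_next = [xi + (s_next - s_hat) * dxi for xi, dxi in zip(x, d)]
           # Heun correction: average both slopes (skipped when s_next is 0).
           if s_next > 0.0:
               d2 = [(xi - di) / s_next
                     for xi, di in zip(x_next, denoise(x_next, s_next))]
               x_next = [xi + (s_next - s_hat) * 0.5 * (a + b)
                         for xi, a, b in zip(x, d, d2)]
           x = x_next
       return x

To run a quick check, use a 1-D Gaussian with unit variance, whose exact denoiser is ``D(x, σ) = x / (1 + σ²)``. Start from ``x = 80`` at ``σ = 80``. The result should then be close to ``80 / sqrt(6401)``, which is where the exact ODE solution ends.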
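
``atom5_to_atom37`` maps the five per-residue atom slots back into the 37-slot atom37 layout, and the remaining slots are presumably left empty. The following is a minimal sketch of such a scatter using NumPy fancy indexing. The index array is an assumption. It supposes an atom5 order of (N, CA, C, O, CB) and uses AlphaFold's atom37 order, which begins N, CA, C, CB, O. The actual value of ``ATOM5_TO_ATOM37`` is not shown on this page, and ``scatter_atom5_to_atom37`` is a hypothetical name.

.. code-block:: python

   from typing import Optional

   import numpy as np

   # Hypothetical mapping: atom5 slot k goes to atom37 slot INDEX[k].
   # Assumes atom5 = (N, CA, C, O, CB) and atom37 = (N, CA, C, CB, O, ...).
   HYPOTHETICAL_ATOM5_TO_ATOM37 = np.array([0, 1, 2, 4, 3])


   def scatter_atom5_to_atom37(
       coords_5: np.ndarray,
       mask_5: Optional[np.ndarray] = None,
       index: np.ndarray = HYPOTHETICAL_ATOM5_TO_ATOM37,
   ) -> tuple[np.ndarray, np.ndarray]:
       """Scatter (N_res, 5, 3) coordinates into zero-filled (N_res, 37, 3)."""
       n_res = coords_5.shape[0]
       if mask_5 is None:
           # Assumption: without a mask, every atom5 slot counts as present.
           mask_5 = np.ones(coords_5.shape[:2], dtype=np.float32)
       x_37 = np.zeros((n_res, 37, 3), dtype=coords_5.dtype)
       mask_37 = np.zeros((n_res, 37), dtype=mask_5.dtype)
       # Zero the coordinates of masked-out atoms, then write all 5 slots at once.
       x_37[:, index] = coords_5 * mask_5[..., None]
       mask_37[:, index] = mask_5
       return x_37, mask_37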
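
``TemplateContext.f_template_distogram`` is an integer tensor of shape ``(B, N_res, N_res, n_templ_bins)``, and ``templ_distogram_fn`` is described as a Distogram over template Cβ pairwise distances. One common way to build such a feature for a single template is to bin the pairwise pseudo-β distances and one-hot encode the bin index. Pairs that involve a residue without a pseudo-β atom are zeroed. The sketch below follows that approach. The bin edges, the masking rule, and the name ``one_hot_distogram`` are assumptions, and ``helpers.featurize.Distogram`` may bin the distances differently.

.. code-block:: python

   import numpy as np


   def one_hot_distogram(points: np.ndarray, mask: np.ndarray,
                         edges: np.ndarray) -> np.ndarray:
       """Hypothetical sketch: one-hot binned pairwise distances.

       points: (N_res, 3) pseudo-beta coordinates; mask: (N_res,) 0/1 presence;
       edges: increasing inner bin edges, giving len(edges) + 1 bins.
       Returns an int64 array of shape (N_res, N_res, len(edges) + 1).
       """
       dist = np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)
       bins = np.digitize(dist, edges)  # bin index in [0, len(edges)]
       onehot = np.eye(len(edges) + 1, dtype=np.int64)[bins]
       pair_mask = (mask[:, None] * mask[None, :]).astype(np.int64)
       # Pairs involving a masked residue get an all-zero bin vector.
       return onehot * pair_mask[..., None]

To get the batch dimension, stack one such array per sample along a new leading axis.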