pallatom.architecture.template_embedder
Classes
Module Contents
- class pallatom.architecture.template_embedder.TemplateEmbedder(n_bins: int, c_z: int, c: int = 64, d: int = 128, n_blocks: int = 2, n_heads: int = 4)
Bases: torch.nn.Module
- Parameters:
n_bins (int): number of distogram bins (last dimension of f_distogram)
c_z (int): trunk pair embedding dimension (last dimension of z_ij)
c (int): internal pair dimension (default 64)
d (int): output pair dimension (default 128)
n_blocks (int): PairformerStack depth (default 2)
n_heads (int): number of attention heads (default 4)
- forward(f_distogram: jaxtyping.Float[torch.Tensor, "B N_res N_res n_bins"], f_pseudo_beta_mask: jaxtyping.Float[torch.Tensor, "B N_res"], z_ij: jaxtyping.Float[torch.Tensor, "B N_res N_res c_z"], t: float) -> jaxtyping.Float[torch.Tensor, "B N_res N_res d"]
- norm_v
- norm_z
- pairformer
- proj_a
- proj_out
- proj_z
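The shape annotations on forward() can be summarized as a small shape contract. The sketch below is a hypothetical helper (template_embedder_shapes is not part of pallatom) that derives the expected tensor shapes from the constructor arguments, which may be handy when building dummy inputs for a smoke test. It uses only the dimensions named in the signature above; the example values (batch=2, n_res=16, n_bins=39, c_z=128) are arbitrary assumptions, not library defaults.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def template_embedder_shapes(
    batch: int, n_res: int, n_bins: int, c_z: int, d: int = 128
) -> Dict[str, Shape]:
    """Return the shapes implied by the documented forward() annotations.

    Hypothetical helper for illustration only; it is not part of pallatom.
    """
    dims = {"batch": batch, "n_res": n_res, "n_bins": n_bins, "c_z": c_z, "d": d}
    for name, value in dims.items():
        if value <= 0:
            raise ValueError(f"{name} must be positive, got {value}")
    return {
        # Pairwise distogram features: one bin vector per residue pair.
        "f_distogram": (batch, n_res, n_res, n_bins),
        # Per-residue pseudo-beta mask.
        "f_pseudo_beta_mask": (batch, n_res),
        # Trunk pair embedding.
        "z_ij": (batch, n_res, n_res, c_z),
        # Documented return shape: the last dimension is the output pair dim d.
        "output": (batch, n_res, n_res, d),
    }


# Arbitrary example sizes (assumed values, chosen only for illustration).
shapes = template_embedder_shapes(batch=2, n_res=16, n_bins=39, c_z=128)
```

With PyTorch installed, each input shape can be passed to torch.randn(*shape) to create a placeholder tensor, and the result of forward() can then be compared against shapes["output"]. The scalar argument t is typed as float in the signature; its meaning is not documented here, so the helper leaves it out.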