pallatom.architecture.template_embedder

Classes

Module Contents

class pallatom.architecture.template_embedder.TemplateEmbedder(n_bins: int, c_z: int, c: int = 64, d: int = 128, n_blocks: int = 2, n_heads: int = 4)

Bases: torch.nn.Module

Parameters:
  • n_bins (int): number of distogram bins (last dim of f_distogram)

  • c_z (int): trunk pair embedding dim (last dim of z_ij)

  • c (int): internal pair dim (default 64)

  • d (int): output pair dim (default 128)

  • n_blocks (int): PairformerStack depth (default 2)

  • n_heads (int): number of attention heads (default 4)

forward(f_distogram: jaxtyping.Float[torch.Tensor, "B N_res N_res n_bins"], f_pseudo_beta_mask: jaxtyping.Float[torch.Tensor, "B N_res"], z_ij: jaxtyping.Float[torch.Tensor, "B N_res N_res c_z"], t: float) -> jaxtyping.Float[torch.Tensor, "B N_res N_res d"]
norm_v
norm_z
pairformer
proj_a
proj_out
proj_z
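
The shape annotations on forward tie each tensor's last dimension to a constructor argument. The sketch below spells out that contract in plain Python. The helper template_embedder_shapes is hypothetical and not part of pallatom, and the sizes in the usage lines are arbitrary. It only restates the documented annotations and does not run the model.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def template_embedder_shapes(
    batch: int, n_res: int, n_bins: int, c_z: int, d: int = 128
) -> Dict[str, Shape]:
    """Hypothetical helper: tensor shapes implied by forward()'s annotations.

    `t` is omitted because the signature declares it as a plain float,
    not a tensor.
    """
    return {
        # Inputs
        "f_distogram": (batch, n_res, n_res, n_bins),
        "f_pseudo_beta_mask": (batch, n_res),
        "z_ij": (batch, n_res, n_res, c_z),
        # Output: pairwise tensor whose last dim is the constructor's `d`
        "output": (batch, n_res, n_res, d),
    }


# Arbitrary illustrative sizes, not values taken from pallatom.
shapes = template_embedder_shapes(batch=2, n_res=16, n_bins=39, c_z=128)
for name, shape in shapes.items():
    print(f"{name}: {shape}")
```

Under these annotations the output keeps the batch and residue-pair axes of z_ij and changes only the channel dimension from c_z to d.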