pallatom.architecture.pairformer_stack
Classes
PairformerBlock: One block of the Pairformer operating purely on pair embeddings v_ij.
Module Contents
- class pallatom.architecture.pairformer_stack.PairformerBlock(c: int, n_heads: int = 4)
Bases: torch.nn.Module
One block of the Pairformer operating purely on pair embeddings v_ij.
Uses:
- Row-wise gated self-attention (triangle attention, rows)
- Column-wise gated self-attention (triangle attention, cols)
- Pair transition FFN
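Read together with the attribute list of this class (q_row, k_row, v_row, g_row, out_row, norm_row, their _col counterparts, and norm_ff, ff1, ff2), the three sub-layers can plausibly be written as below. This is a sketch under assumptions this page does not confirm: pre-LayerNorm, a sigmoid gate, residual updates, d = c / n_heads, and a ReLU transition. AlphaFold-style triangle attention also adds a pair-derived bias b^h_{jk} to the logits; no attribute for such a projection is listed, so it is omitted here. To avoid a clash with the pair embedding v, the value projection is written u.

```latex
% Row-wise gated self-attention, head h, with \hat{v}_{ij} = \mathrm{LayerNorm}_{\text{row}}(v_{ij})
q^h_{ij} = W^{q,h}_{\text{row}} \hat{v}_{ij}, \quad
k^h_{ij} = W^{k,h}_{\text{row}} \hat{v}_{ij}, \quad
u^h_{ij} = W^{v,h}_{\text{row}} \hat{v}_{ij}, \quad
g^h_{ij} = \sigma\!\left(W^{g,h}_{\text{row}} \hat{v}_{ij} + b^{g,h}_{\text{row}}\right)

a^h_{ijk} = \operatorname{softmax}_k\!\left( \frac{q^h_{ij} \cdot k^h_{ik}}{\sqrt{d}} \right), \qquad
o^h_{ij} = g^h_{ij} \odot \sum_k a^h_{ijk}\, u^h_{ik}

v_{ij} \leftarrow v_{ij} + W^{o}_{\text{row}} \operatorname{concat}_h\!\left(o^h_{ij}\right)

% Column-wise: same form with the _col weights and the two pair indices swapped
a^h_{ijk} = \operatorname{softmax}_k\!\left( \frac{q^h_{ij} \cdot k^h_{kj}}{\sqrt{d}} \right), \qquad
o^h_{ij} = g^h_{ij} \odot \sum_k a^h_{ijk}\, u^h_{kj}

% Pair transition
v_{ij} \leftarrow v_{ij} + W_2\, \mathrm{ReLU}\!\left(W_1\, \mathrm{LayerNorm}_{\text{ff}}(v_{ij})\right)
```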
- forward(v: jaxtyping.Float[torch.Tensor, 'B N_res N_res c']) -> jaxtyping.Float[torch.Tensor, 'B N_res N_res c']
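In the forward signature, the jaxtyping shape string 'B N_res N_res c' says the input is a batch of square pair grids (the repeated name N_res means both pair axes have the same length) and that the output has exactly the same shape. jaxtyping only enforces this at runtime when a type checker is attached; otherwise the annotation serves as documentation. A plain-PyTorch sketch of the equivalent check, using a hypothetical helper check_pair_shape that is not part of pallatom:

```python
import torch


def check_pair_shape(v: torch.Tensor, c: int) -> tuple[int, int]:
    """Hypothetical helper: validate a 'B N_res N_res c' pair tensor.

    Returns (B, N_res) so a caller can compare input and output sizes.
    """
    if v.ndim != 4:
        raise ValueError(f"expected 4 dims (B, N_res, N_res, c), got {v.ndim}")
    batch, n_i, n_j, channels = v.shape
    # The repeated name N_res requires both pair axes to have equal length.
    if n_i != n_j:
        raise ValueError(f"pair axes must match (N_res), got {n_i} and {n_j}")
    if channels != c:
        raise ValueError(f"expected c={c} channels, got {channels}")
    return batch, n_i
```

A caller could run it on both the input and the output of forward and compare the returned (B, N_res) pairs.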
- ff1
- ff2
- g_col
- g_row
- head_dim
- k_col
- k_row
- n_heads = 4
- norm_col
- norm_ff
- norm_row
- out_col
- out_row
- q_col
- q_row
- v_col
- v_row
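The attribute names above suggest how the block is wired. The following is a minimal sketch under stated assumptions: pre-LayerNorm on each sub-layer, residual connections, sigmoid gates, head_dim = c // n_heads, a 4x transition expansion with ReLU, and no pair-derived attention bias. It reuses the listed attribute names for readability but is not the pallatom implementation; the class name PairformerBlockSketch and the expansion parameter are hypothetical.

```python
import torch
import torch.nn as nn


class PairformerBlockSketch(nn.Module):
    """Hypothetical pair-only Pairformer block (assumed design, not pallatom's code)."""

    def __init__(self, c: int, n_heads: int = 4, expansion: int = 4):
        super().__init__()
        assert c % n_heads == 0, "assumption: c is divisible by n_heads"
        self.n_heads = n_heads
        self.head_dim = c // n_heads
        for axis in ("row", "col"):
            # nn.Module.__setattr__ registers each layer as a submodule
            setattr(self, f"norm_{axis}", nn.LayerNorm(c))
            for proj in ("q", "k", "v"):
                setattr(self, f"{proj}_{axis}", nn.Linear(c, c, bias=False))
            setattr(self, f"g_{axis}", nn.Linear(c, c))    # gate logits
            setattr(self, f"out_{axis}", nn.Linear(c, c))  # output projection
        self.norm_ff = nn.LayerNorm(c)
        self.ff1 = nn.Linear(c, expansion * c)
        self.ff2 = nn.Linear(expansion * c, c)

    def _attend(self, v: torch.Tensor, axis: str) -> torch.Tensor:
        """Gated multi-head self-attention within each row of v, shape (B, I, J, c)."""
        B, I, J, c = v.shape
        x = getattr(self, f"norm_{axis}")(v)
        heads = (B, I, J, self.n_heads, self.head_dim)
        q = getattr(self, f"q_{axis}")(x).reshape(heads)
        k = getattr(self, f"k_{axis}")(x).reshape(heads)
        val = getattr(self, f"v_{axis}")(x).reshape(heads)
        # logits between positions j and k of the same row i: (B, I, H, J, J)
        logits = torch.einsum("bijhd,bikhd->bihjk", q, k) / self.head_dim ** 0.5
        o = torch.einsum("bihjk,bikhd->bijhd", logits.softmax(dim=-1), val)
        o = torch.sigmoid(getattr(self, f"g_{axis}")(x)) * o.reshape(B, I, J, c)
        return getattr(self, f"out_{axis}")(o)

    def forward(self, v: torch.Tensor) -> torch.Tensor:
        # v: (B, N_res, N_res, c); every sub-layer is a residual update, so shape is kept
        v = v + self._attend(v, "row")
        # column pass: swap the two pair axes, attend along rows, swap back
        v = v + self._attend(v.transpose(1, 2), "col").transpose(1, 2)
        v = v + self.ff2(torch.relu(self.ff1(self.norm_ff(v))))
        return v
```

For example, PairformerBlockSketch(c=128)(torch.randn(1, 64, 64, 128)) returns a tensor of shape (1, 64, 64, 128). Implementing the column pass as a transpose around the row routine keeps one attention code path; separate _row and _col weights still let the two axes learn different projections.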