pallatom.architecture.pairformer_stack
======================================

.. py:module:: pallatom.architecture.pairformer_stack


Classes
-------

.. autoapisummary::

   pallatom.architecture.pairformer_stack.PairformerBlock
   pallatom.architecture.pairformer_stack.PairformerStack


Module Contents
---------------

.. py:class:: PairformerBlock(c: int, n_heads: int = 4)

   Bases: :py:obj:`torch.nn.Module`

   One block of the Pairformer, operating purely on the pair embeddings
   ``v_ij`` of shape ``(B, N_res, N_res, c)``.

   Each block uses:

   - Row-wise gated self-attention (triangle attention over rows)
   - Column-wise gated self-attention (triangle attention over columns)
   - A pair transition feed-forward network (FFN)


   .. py:method:: forward(v: jaxtyping.Float[torch.Tensor, 'B N_res N_res c']) -> jaxtyping.Float[torch.Tensor, 'B N_res N_res c']


   .. py:attribute:: ff1


   .. py:attribute:: ff2


   .. py:attribute:: g_col


   .. py:attribute:: g_row


   .. py:attribute:: head_dim


   .. py:attribute:: k_col


   .. py:attribute:: k_row


   .. py:attribute:: n_heads
      :value: 4


   .. py:attribute:: norm_col


   .. py:attribute:: norm_ff


   .. py:attribute:: norm_row


   .. py:attribute:: out_col


   .. py:attribute:: out_row


   .. py:attribute:: q_col


   .. py:attribute:: q_row


   .. py:attribute:: v_col


   .. py:attribute:: v_row


.. py:class:: PairformerStack(c: int, n_blocks: int = 2, n_heads: int = 4)

   Bases: :py:obj:`torch.nn.Module`


   .. py:method:: forward(v: jaxtyping.Float[torch.Tensor, 'B N_res N_res c']) -> jaxtyping.Float[torch.Tensor, 'B N_res N_res c']


   .. py:attribute:: blocks
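
The docstrings list the sub-layers of ``PairformerBlock`` but not how they are wired together. The following is a minimal PyTorch sketch, not the pallatom implementation. The class names ``SketchPairformerBlock`` and ``SketchPairformerStack`` are hypothetical. Attribute names mirror the documented ones (``q_row``, ``g_col``, ``norm_ff``, ``ff1``, ``blocks``, and so on). Several choices are assumptions: pre-LayerNorm, residual additions, sigmoid gating, a 4x ReLU transition, and sequential application of ``blocks``. The sketch implements plain gated axial attention. It omits the pair-derived bias term used in AlphaFold-style triangle attention, because no bias projection appears among the documented attributes.

.. code-block:: python

   import torch
   import torch.nn as nn


   class SketchPairformerBlock(nn.Module):
       """Hypothetical sketch: documented attribute names, assumed wiring."""

       def __init__(self, c: int, n_heads: int = 4):
           super().__init__()
           assert c % n_heads == 0, "c must be divisible by n_heads"
           self.n_heads = n_heads
           self.head_dim = c // n_heads
           # Row-wise gated attention (assumed: LayerNorm, then Q/K/V, gate, output projections)
           self.norm_row = nn.LayerNorm(c)
           self.q_row, self.k_row, self.v_row = (nn.Linear(c, c, bias=False) for _ in range(3))
           self.g_row, self.out_row = nn.Linear(c, c), nn.Linear(c, c)
           # Column-wise gated attention, same structure
           self.norm_col = nn.LayerNorm(c)
           self.q_col, self.k_col, self.v_col = (nn.Linear(c, c, bias=False) for _ in range(3))
           self.g_col, self.out_col = nn.Linear(c, c), nn.Linear(c, c)
           # Pair transition FFN (assumed 4x expansion with ReLU)
           self.norm_ff = nn.LayerNorm(c)
           self.ff1, self.ff2 = nn.Linear(c, 4 * c), nn.Linear(4 * c, c)

       def _gated_attention(self, x, norm, q_proj, k_proj, v_proj, g_proj, out_proj):
           # x: (B, N, N, c); for each fixed first index i, positions j attend to each other
           B, N, _, c = x.shape
           h = norm(x)

           def split(t):
               # (B, N, N, c) -> (B, N, H, N, d)
               return t.reshape(B, N, N, self.n_heads, self.head_dim).transpose(2, 3)

           q, k, v = split(q_proj(h)), split(k_proj(h)), split(v_proj(h))
           logits = q @ k.transpose(-1, -2) / self.head_dim ** 0.5  # (B, N, H, N, N)
           o = torch.softmax(logits, dim=-1) @ v                     # (B, N, H, N, d)
           o = o.transpose(2, 3).reshape(B, N, N, c)
           # Sigmoid gate computed from the normalized input, then output projection
           return out_proj(torch.sigmoid(g_proj(h)) * o)

       def forward(self, v: torch.Tensor) -> torch.Tensor:
           # Row-wise attention with a residual connection
           v = v + self._gated_attention(
               v, self.norm_row, self.q_row, self.k_row, self.v_row, self.g_row, self.out_row)
           # Column-wise attention: swap residue axes, attend, swap back
           vt = v.transpose(1, 2)
           v = v + self._gated_attention(
               vt, self.norm_col, self.q_col, self.k_col, self.v_col, self.g_col, self.out_col
           ).transpose(1, 2)
           # Pair transition with a residual connection
           return v + self.ff2(torch.relu(self.ff1(self.norm_ff(v))))


   class SketchPairformerStack(nn.Module):
       """Hypothetical sketch: n_blocks blocks applied one after another."""

       def __init__(self, c: int, n_blocks: int = 2, n_heads: int = 4):
           super().__init__()
           self.blocks = nn.ModuleList(
               [SketchPairformerBlock(c, n_heads) for _ in range(n_blocks)])

       def forward(self, v: torch.Tensor) -> torch.Tensor:
           for block in self.blocks:
               v = block(v)
           return v


   torch.manual_seed(0)
   stack = SketchPairformerStack(c=32, n_blocks=2, n_heads=4)
   out = stack(torch.randn(2, 10, 10, 32))
   print(out.shape)  # torch.Size([2, 10, 10, 32])

Every sub-layer maps ``(B, N_res, N_res, c)`` back to the same shape, so residual additions and stacking need no reshaping. Column-wise attention reuses the row-wise routine on the axis-swapped tensor instead of duplicating the attention code.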