pallatom.architecture.pairformer_stack

Classes

PairformerBlock

One block of the Pairformer operating purely on pair embeddings v_ij.

PairformerStack

Stack of n_blocks PairformerBlock modules (stored in blocks) operating on pair embeddings v_ij.

Module Contents

class pallatom.architecture.pairformer_stack.PairformerBlock(c: int, n_heads: int = 4)

Bases: torch.nn.Module

One block of the Pairformer operating purely on pair embeddings v_ij.

Uses:

  • Row-wise gated self-attention (triangle attention, rows)

  • Column-wise gated self-attention (triangle attention, cols)

  • Pair transition FFN
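For reference, the row-wise update can be written out as follows. This is a sketch under stated assumptions, not the documented formula: pre-LayerNorm, a residual update, H heads of width d = c / H (matching head_dim), a sigmoid gate per head, and u denoting the value projection so it is not confused with the pair embedding v. If the implementation follows AlphaFold 2's triangle attention, a bias term b^h_{jk} projected from the pair embedding would be added inside the softmax. The documented attributes list no such projection, so it is omitted here.

```latex
\begin{aligned}
\hat v_{ij} &= \mathrm{LN}(v_{ij}) \\
q^h_{ij} &= W^h_q \hat v_{ij}, \quad k^h_{ij} = W^h_k \hat v_{ij}, \quad u^h_{ij} = W^h_v \hat v_{ij}, \quad g^h_{ij} = \sigma\!\left(W^h_g \hat v_{ij}\right) \\
a^h_{ijk} &= \operatorname{softmax}_k\!\left(\frac{q^h_{ij} \cdot k^h_{ik}}{\sqrt{d}}\right) \\
o^h_{ij} &= g^h_{ij} \odot \sum_k a^h_{ijk}\, u^h_{ik} \\
v_{ij} &\leftarrow v_{ij} + W_o \left[o^1_{ij}; \dots; o^H_{ij}\right]
\end{aligned}
```

The column-wise update has the same form with its own LayerNorm and weights, with k^h_{ik} and u^h_{ik} replaced by k^h_{kj} and u^h_{kj}. Entry (i, j) then attends over column j instead of row i.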

forward(v: jaxtyping.Float[torch.Tensor, 'B N_res N_res c']) -> jaxtyping.Float[torch.Tensor, 'B N_res N_res c']
ff1
ff2
g_col
g_row
head_dim
k_col
k_row
n_heads = 4
norm_col
norm_ff
norm_row
out_col
out_row
q_col
q_row
v_col
v_row
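The following is a minimal PyTorch sketch of how the listed attributes might fit together. It is a hypothetical reconstruction, not the pallatom implementation. The following are all assumptions:

- the class name PairformerBlockSketch;
- bias-free query/key/value projections;
- the pre-norm residual order row, then column, then transition;
- a ReLU transition of width transition_mult * c.

Column-wise attention is implemented as row-wise attention on the transposed pair map.

```python
import torch
from torch import nn


class PairformerBlockSketch(nn.Module):
    """Hypothetical PairformerBlock rebuilt from the documented attribute names.
    Layer shapes, update order, and the ReLU transition are assumptions."""

    def __init__(self, c: int, n_heads: int = 4, transition_mult: int = 4):
        super().__init__()
        if c % n_heads != 0:
            raise ValueError("c must be divisible by n_heads")
        self.n_heads = n_heads
        self.head_dim = c // n_heads

        # Row-wise gated self-attention.
        self.norm_row = nn.LayerNorm(c)
        self.q_row = nn.Linear(c, c, bias=False)
        self.k_row = nn.Linear(c, c, bias=False)
        self.v_row = nn.Linear(c, c, bias=False)
        self.g_row = nn.Linear(c, c)
        self.out_row = nn.Linear(c, c)

        # Column-wise gated self-attention (separate weights).
        self.norm_col = nn.LayerNorm(c)
        self.q_col = nn.Linear(c, c, bias=False)
        self.k_col = nn.Linear(c, c, bias=False)
        self.v_col = nn.Linear(c, c, bias=False)
        self.g_col = nn.Linear(c, c)
        self.out_col = nn.Linear(c, c)

        # Pair transition: two-layer MLP on the channel dimension.
        self.norm_ff = nn.LayerNorm(c)
        self.ff1 = nn.Linear(c, transition_mult * c)
        self.ff2 = nn.Linear(transition_mult * c, c)

    def _attend_rows(self, x, q, k, v, g, out):
        # x: (B, N, M, c); entry (i, j) attends over all entries (i, k) of row i.
        B, N, M, c = x.shape
        H, d = self.n_heads, self.head_dim

        def split_heads(t):
            # (B, N, M, c) -> (B, N, H, M, d)
            return t.reshape(B, N, M, H, d).transpose(2, 3)

        qh, kh, vh = split_heads(q(x)), split_heads(k(x)), split_heads(v(x))
        logits = qh @ kh.transpose(-1, -2) / d**0.5  # (B, N, H, M, M)
        o = torch.softmax(logits, dim=-1) @ vh  # (B, N, H, M, d)
        o = o.transpose(2, 3).reshape(B, N, M, c)  # merge heads back to c
        return out(torch.sigmoid(g(x)) * o)  # gate, then output projection

    def forward(self, v: torch.Tensor) -> torch.Tensor:
        # v: (B, N_res, N_res, c); every sub-layer is pre-norm + residual.
        v = v + self._attend_rows(
            self.norm_row(v), self.q_row, self.k_row, self.v_row, self.g_row, self.out_row
        )
        # Column-wise attention = row-wise attention on the transposed pair map.
        vt = v.transpose(1, 2)
        vt = vt + self._attend_rows(
            self.norm_col(vt), self.q_col, self.k_col, self.v_col, self.g_col, self.out_col
        )
        v = vt.transpose(1, 2)
        return v + self.ff2(torch.relu(self.ff1(self.norm_ff(v))))


block = PairformerBlockSketch(c=32, n_heads=4)
pair = torch.randn(1, 8, 8, 32)  # (B, N_res, N_res, c)
print(block(pair).shape)  # torch.Size([1, 8, 8, 32])
```

The block uses no positional information. Permuting residues on both axes of v therefore permutes the output the same way, and the shape contract of forward (B N_res N_res c in and out) holds for any N_res.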

class pallatom.architecture.pairformer_stack.PairformerStack(c: int, n_blocks: int = 2, n_heads: int = 4)

Bases: torch.nn.Module

forward(v: jaxtyping.Float[torch.Tensor, 'B N_res N_res c']) -> jaxtyping.Float[torch.Tensor, 'B N_res N_res c']
blocks
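The constructor arguments (c, n_blocks, n_heads) and the blocks attribute suggest a container that applies n_blocks identical-signature blocks in sequence. The sketch below assumes that. ToyPairBlock is a hypothetical stand-in (a pre-norm residual MLP) used only to keep the example self-contained. In practice block_cls would be PairformerBlock. The name PairformerStackSketch and the block_cls parameter are also illustrative, not part of the documented API.

```python
import torch
from torch import nn


class ToyPairBlock(nn.Module):
    """Hypothetical stand-in for PairformerBlock: pre-norm residual MLP on c."""

    def __init__(self, c: int, n_heads: int = 4):
        super().__init__()
        self.norm = nn.LayerNorm(c)
        self.ff = nn.Sequential(nn.Linear(c, 2 * c), nn.ReLU(), nn.Linear(2 * c, c))

    def forward(self, v: torch.Tensor) -> torch.Tensor:
        return v + self.ff(self.norm(v))


class PairformerStackSketch(nn.Module):
    """Assumed behavior: apply n_blocks shape-preserving blocks one after another."""

    def __init__(self, c: int, n_blocks: int = 2, n_heads: int = 4, block_cls=ToyPairBlock):
        super().__init__()
        # nn.ModuleList registers every block's parameters with the stack.
        self.blocks = nn.ModuleList([block_cls(c, n_heads=n_heads) for _ in range(n_blocks)])

    def forward(self, v: torch.Tensor) -> torch.Tensor:
        # v: (B, N_res, N_res, c); each block preserves this shape.
        for block in self.blocks:
            v = block(v)
        return v


stack = PairformerStackSketch(c=16, n_blocks=2)
print(stack(torch.randn(1, 6, 6, 16)).shape)  # torch.Size([1, 6, 6, 16])
```

Holding the blocks in nn.ModuleList rather than a plain Python list means their parameters appear in stack.parameters(), move with .to(device), and are saved in the state dict.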