pallatom.architecture.pair_update

PairUpdate: a pure PyTorch implementation based on Algorithm 7 from the AlphaFold 3 paper.

Replaces the simplified PairUpdate stub in main_trunk.py.

Steps

  1. d_ij = ||r_i^center − r_j^center||: scalar pairwise distances

  2. b_ij = LinearNoBias(Transform_RBF(d_ij)) ∈ R^c: RBF-discretized distance bias

  3. z_ij += DropoutRowwise_0.25(TriangleAttentionStartingNodeWithBias(z_ij, b_ij))

  4. z_ij += DropoutColumnwise_0.25(TriangleAttentionEndingNodeWithBias(z_ij, b_ij))

  5. z_ij += Transition(z_ij)
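Step 1 is a plain Euclidean distance matrix over the residue centres. A minimal sketch, assuming r_center has shape (B, N_res, 3) as in PairUpdate.forward; `pairwise_distances` is a hypothetical helper, not part of this module.

```python
import torch


def pairwise_distances(r_center: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: d_ij = ||r_i - r_j|| for r_center of shape (B, N_res, 3)."""
    # Broadcast (B, N, 1, 3) - (B, 1, N, 3), then take the norm over the xyz axis.
    diff = r_center[:, :, None, :] - r_center[:, None, :, :]
    return diff.norm(dim=-1)  # (B, N_res, N_res)


r = torch.tensor([[[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]]])  # B=1, N_res=2
d = pairwise_distances(r)  # d[0, 0, 1] == d[0, 1, 0] == 5.0 (a 3-4-5 triangle)
```

`torch.cdist(r_center, r_center)` computes the same matrix in a single call.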

Classes

DropoutColumnwise

Drops entire columns (dim 2) with probability p during training.

DropoutRowwise

Drops entire rows (dim 1) with probability p during training.

PairUpdate

TransformRBF

Converts a scalar distance d into a fixed-size RBF feature vector, then projects to R^c via LinearNoBias.

Transition

TriangleAttentionEndingNodeWithBias

For each column j, attend over all i using queries/keys/values from z_ij, biased by b_ij.

TriangleAttentionStartingNodeWithBias

For each row i, attend over all j using queries/keys/values from z_ij, with an additive pair bias b_ij projected to per-head scalars.

Module Contents

class pallatom.architecture.pair_update.DropoutColumnwise(p: float = 0.25)

Bases: torch.nn.Module

Drops entire columns (dim 2) with probability p during training.

forward(x: torch.Tensor) → torch.Tensor
p = 0.25
class pallatom.architecture.pair_update.DropoutRowwise(p: float = 0.25)

Bases: torch.nn.Module

Drops entire rows (dim 1) with probability p during training.

forward(x: torch.Tensor) → torch.Tensor
p = 0.25
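DropoutRowwise and DropoutColumnwise differ only in which axis shares the dropout mask. A minimal sketch covering both, assuming one Bernoulli draw per (batch, row) or (batch, column) that is broadcast over the other residue axis and the channels, with survivors rescaled by 1/(1 − p) as in standard inverted dropout. `SharedMaskDropout` is a hypothetical name; whether the real modules rescale or share the mask across channels is not documented here.

```python
import torch
from torch import nn


class SharedMaskDropout(nn.Module):
    """Hypothetical sketch: zero whole rows (dim=1) or whole columns (dim=2)
    of a (B, N_res, N_res, c) tensor with probability p during training."""

    def __init__(self, p: float = 0.25, dim: int = 1):
        super().__init__()
        assert dim in (1, 2), "dim 1 = rows (i), dim 2 = columns (j)"
        self.p = p
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.p == 0.0:
            return x  # identity at inference time
        if self.p >= 1.0:
            return torch.zeros_like(x)
        # Mask shape (B, N, 1, 1) for rows or (B, 1, N, 1) for columns.
        shape = [x.shape[0], 1, 1, 1]
        shape[self.dim] = x.shape[self.dim]
        keep = torch.bernoulli(
            torch.full(shape, 1.0 - self.p, device=x.device, dtype=x.dtype)
        )
        # Inverted dropout: rescale survivors so the expected value is unchanged.
        return x * keep / (1.0 - self.p)


drop_row = SharedMaskDropout(0.25, dim=1)  # analogue of DropoutRowwise
drop_col = SharedMaskDropout(0.25, dim=2)  # analogue of DropoutColumnwise
```

In training mode, every row (or column) of the output is either all zeros or the input scaled by 1/(1 − 0.25) = 4/3.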
class pallatom.architecture.pair_update.PairUpdate(c: int = 128, n_rbf: int = 16, n_heads: int = 4, dropout: float = 0.25)

Bases: torch.nn.Module

Parameters:
  • c – pair embedding dimension (default 128)

  • n_rbf – number of RBF centres (default 16)

  • n_heads – number of attention heads (default 4)

  • dropout – rowwise/columnwise dropout probability (default 0.25, as in the paper)

forward(z: jaxtyping.Float[torch.Tensor, "B N_res N_res c_pair"], r_center: jaxtyping.Float[torch.Tensor, "B N_res 3"]) → jaxtyping.Float[torch.Tensor, "B N_res N_res c_pair"]
drop_col
drop_row
rbf
transition
tri_end
tri_start
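The attributes above map one-to-one onto steps 1 to 5 of the module docstring. The sketch below shows only that data flow, with each submodule passed in as a callable; `pair_update_forward` is a hypothetical name, and using `torch.cdist` for step 1 is an assumption rather than the module's confirmed implementation.

```python
from typing import Callable

import torch

Tensor = torch.Tensor


def pair_update_forward(
    z: Tensor,  # (B, N_res, N_res, c_pair)
    r_center: Tensor,  # (B, N_res, 3)
    rbf: Callable[[Tensor], Tensor],
    tri_start: Callable[[Tensor, Tensor], Tensor],
    tri_end: Callable[[Tensor, Tensor], Tensor],
    transition: Callable[[Tensor], Tensor],
    drop_row: Callable[[Tensor], Tensor],
    drop_col: Callable[[Tensor], Tensor],
) -> Tensor:
    """Hypothetical sketch of the PairUpdate forward pass (steps 1-5)."""
    d = torch.cdist(r_center, r_center)  # step 1: (B, N, N) distances
    b = rbf(d)  # step 2: (B, N, N, c_pair) distance bias
    z = z + drop_row(tri_start(z, b))  # step 3: starting-node attention
    z = z + drop_col(tri_end(z, b))  # step 4: ending-node attention
    z = z + transition(z)  # step 5: residual transition
    return z
```

Each update is residual, so every submodule returns an increment for z rather than a replacement.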
class pallatom.architecture.pair_update.TransformRBF(c: int, n_rbf: int = 16, d_min: float = 0.0, d_max: float = 22.0)

Bases: torch.nn.Module

Converts a scalar distance d into a fixed-size RBF feature vector, then projects to R^c via LinearNoBias.

Centers are evenly spaced in [d_min, d_max]; width = spacing.

forward(d: jaxtyping.Float[torch.Tensor, "B N_res N_res"]) → jaxtyping.Float[torch.Tensor, "B N_res N_res c_pair"]
centers: torch.Tensor
proj
sigma = 1.375
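With the defaults, the documented sigma = 1.375 equals (d_max − d_min) / n_rbf = 22 / 16. The sketch below is consistent with that number and with "width = spacing", assuming the centres sit at the midpoints of n_rbf equal bins and the features are Gaussians exp(−((d − μ)/σ)²). Both the centre placement and the Gaussian form are assumptions; `TransformRBFSketch` is a hypothetical name.

```python
import torch
from torch import nn


class TransformRBFSketch(nn.Module):
    """Hypothetical sketch: distance -> n_rbf Gaussian features -> LinearNoBias to c."""

    def __init__(self, c: int, n_rbf: int = 16, d_min: float = 0.0, d_max: float = 22.0):
        super().__init__()
        spacing = (d_max - d_min) / n_rbf  # 22 / 16 = 1.375 for the defaults
        # Assumed: centres at bin midpoints, so neighbouring centres are `spacing` apart.
        centers = d_min + (torch.arange(n_rbf, dtype=torch.float32) + 0.5) * spacing
        self.register_buffer("centers", centers)
        self.sigma = spacing  # width = spacing
        self.proj = nn.Linear(n_rbf, c, bias=False)  # LinearNoBias

    def forward(self, d: torch.Tensor) -> torch.Tensor:
        # (B, N, N) -> (B, N, N, n_rbf): one Gaussian bump per centre.
        feats = torch.exp(-(((d[..., None] - self.centers) / self.sigma) ** 2))
        return self.proj(feats)  # (B, N, N, c)


rbf = TransformRBFSketch(c=128)
b = rbf(torch.rand(1, 10, 10) * 22.0)  # (1, 10, 10, 128)
```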
class pallatom.architecture.pair_update.Transition(c: int, expansion: int = 4)

Bases: torch.nn.Module

forward(z: torch.Tensor) → torch.Tensor
ff1
ff2
norm
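Transition exposes norm, ff1, ff2 and an expansion factor. A minimal sketch, assuming a pre-LayerNorm two-layer MLP with hidden width expansion·c and a ReLU between the layers (the actual activation is not documented here). It returns only the update, since step 5 adds it to z. `TransitionSketch` is a hypothetical name.

```python
import torch
from torch import nn


class TransitionSketch(nn.Module):
    """Hypothetical sketch: LayerNorm -> Linear(c, e*c) -> ReLU -> Linear(e*c, c)."""

    def __init__(self, c: int, expansion: int = 4):
        super().__init__()
        self.norm = nn.LayerNorm(c)
        self.ff1 = nn.Linear(c, expansion * c)
        self.ff2 = nn.Linear(expansion * c, c)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # Returns the increment only; the caller does z = z + transition(z).
        return self.ff2(torch.relu(self.ff1(self.norm(z))))


transition = TransitionSketch(c=128)
dz = transition(torch.randn(1, 10, 10, 128))  # same shape as the input
```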
class pallatom.architecture.pair_update.TriangleAttentionEndingNodeWithBias(c: int, n_heads: int = 4)

Bases: torch.nn.Module

For each column j, attend over all i using queries/keys/values from z_ij, biased by b_ij.

forward(z: jaxtyping.Float[torch.Tensor, "B N_res N_res c_pair"], b: jaxtyping.Float[torch.Tensor, "B N_res N_res c_pair"]) → jaxtyping.Float[torch.Tensor, "B N_res N_res c_pair"]
head_dim
n_heads = 4
norm
norm_b
to_b
to_g
to_k
to_out
to_q
to_v
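Written out per head h, the ending-node attention can be sketched as below. This follows the AlphaFold 2 ending-node indexing (query at (i, j), keys and values at (k, j), bias at (k, i)) and assumes the per-head bias β comes only from the distance bias b via norm_b and to_b; d is head_dim. The exact bias indexing used by this module is not documented here.

```latex
\begin{aligned}
q_{ij}^h,\; k_{ij}^h,\; v_{ij}^h &= \mathrm{Linear}\big(\mathrm{LayerNorm}(z_{ij})\big) \\
\beta_{ij}^h &= \mathrm{LinearNoBias}\big(\mathrm{LayerNorm}(b_{ij})\big) \\
g_{ij}^h &= \mathrm{sigmoid}\big(\mathrm{Linear}(\mathrm{LayerNorm}(z_{ij}))\big) \\
a_{ijk}^h &= \mathrm{softmax}_k\Big(\tfrac{1}{\sqrt{d}}\, {q_{ij}^h}^{\top} k_{kj}^h + \beta_{ki}^h\Big) \\
o_{ij}^h &= g_{ij}^h \odot \textstyle\sum_k a_{ijk}^h\, v_{kj}^h \\
\Delta z_{ij} &= \mathrm{Linear}\big(\mathrm{concat}_h\, o_{ij}^h\big)
\end{aligned}
```

Swapping the roles of i and j (keys at (i, k), bias at (j, k)) gives the starting-node variant.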
class pallatom.architecture.pair_update.TriangleAttentionStartingNodeWithBias(c: int, n_heads: int = 4)

Bases: torch.nn.Module

For each row i, attend over all j using queries/keys/values from z_ij, with an additive pair bias b_ij projected to per-head scalars. The attention output is gated by a sigmoid computed from z_ij.

forward(z: jaxtyping.Float[torch.Tensor, "B N_res N_res c_pair"], b: jaxtyping.Float[torch.Tensor, "B N_res N_res c_pair"]) → jaxtyping.Float[torch.Tensor, "B N_res N_res c_pair"]
head_dim
n_heads = 4
norm
norm_b
to_b
to_g
to_k
to_out
to_q
to_v
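The starting- and ending-node classes share the same attribute set, so one sketch can cover both. It assumes AlphaFold 2 style indexing (starting node: query (i, j) attends over keys (i, k) with bias β_jk), a bias taken only from b via norm_b and to_b, and the common trick of computing the ending node as the starting node on the transposed pair tensor. `TriangleAttentionWithBiasSketch` and its `starting` flag are hypothetical; the real classes may differ in bias indexing, biases on the projections, or chunking.

```python
import math

import torch
from torch import nn


class TriangleAttentionWithBiasSketch(nn.Module):
    """Hypothetical sketch of triangle attention with an external pair bias b.

    starting=True : row i fixed, query (i, j) attends over keys (i, k), bias beta_jk.
    starting=False: column j fixed, query (i, j) attends over keys (k, j), bias beta_ki.
    """

    def __init__(self, c: int, n_heads: int = 4, starting: bool = True):
        super().__init__()
        assert c % n_heads == 0, "c must be divisible by n_heads"
        self.n_heads, self.head_dim, self.starting = n_heads, c // n_heads, starting
        self.norm = nn.LayerNorm(c)  # normalises z
        self.norm_b = nn.LayerNorm(c)  # normalises the distance bias b
        self.to_q = nn.Linear(c, c, bias=False)
        self.to_k = nn.Linear(c, c, bias=False)
        self.to_v = nn.Linear(c, c, bias=False)
        self.to_b = nn.Linear(c, n_heads, bias=False)  # one scalar bias per head
        self.to_g = nn.Linear(c, c)  # sigmoid gate
        self.to_out = nn.Linear(c, c)

    def forward(self, z: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        if not self.starting:
            # Ending node == starting node with the i and j axes swapped.
            z, b = z.transpose(1, 2), b.transpose(1, 2)
        B, N, _, C = z.shape
        H, D = self.n_heads, self.head_dim
        zn = self.norm(z)
        q = self.to_q(zn).reshape(B, N, N, H, D)
        k = self.to_k(zn).reshape(B, N, N, H, D)
        v = self.to_v(zn).reshape(B, N, N, H, D)
        beta = self.to_b(self.norm_b(b)).permute(0, 3, 1, 2)  # (B, H, N_j, N_k)
        # logits[b, h, i, j, k] = q_ij . k_ik / sqrt(D) + beta_jk
        logits = torch.einsum("bijhd,bikhd->bhijk", q, k) / math.sqrt(D)
        attn = (logits + beta[:, :, None]).softmax(dim=-1)  # normalise over k
        o = torch.einsum("bhijk,bikhd->bijhd", attn, v)
        g = torch.sigmoid(self.to_g(zn)).reshape(B, N, N, H, D)
        out = self.to_out((g * o).reshape(B, N, N, C))
        return out if self.starting else out.transpose(1, 2)
```

The logits tensor has shape (B, H, N, N, N), so memory grows cubically in N_res; production implementations usually chunk this computation.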