pallatom.architecture.pair_update
PairUpdate: a pure PyTorch implementation, based on Algorithm 7 from the AlphaFold 3 paper.
Replaces the simplified PairUpdate stub in main_trunk.py.
Steps
d_ij = ||r_i^center − r_j^center|| (scalar pairwise distances)
b_ij = LinearNoBias(Transform_RBF(d_ij)) (RBF-discretized distance bias ∈ R^c)
z_ij += DropoutRowwise_0.25(TriangleAttentionStartingNodeWithBias(z_ij, b_ij))
z_ij += DropoutColumnwise_0.25(TriangleAttentionEndingNodeWithBias(z_ij, b_ij))
z_ij += Transition(z_ij)
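The first step above can be sketched in a few lines. This is a minimal illustration, not the module's own code: the helper name `pairwise_center_distances` is hypothetical, and it assumes `r_center` has shape [B, N_res, 3] as in the `PairUpdate.forward` signature.

```python
import torch


def pairwise_center_distances(r_center: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: [B, N_res, 3] coordinates -> [B, N_res, N_res] distances."""
    # Broadcast to all (i, j) pairs: [B, N, 1, 3] - [B, 1, N, 3] -> [B, N, N, 3]
    diff = r_center[:, :, None, :] - r_center[:, None, :, :]
    # Euclidean norm over the coordinate axis gives d_ij
    return diff.norm(dim=-1)


# One batch element with two residue centers, 5 units apart (3-4-5 triangle)
r = torch.tensor([[[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]]])
d = pairwise_center_distances(r)
```

The result is symmetric with a zero diagonal, so any bias derived from it pointwise is symmetric in (i, j) as well.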
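Classes
DropoutColumnwise | Drops entire columns (dim 2) with probability p during training.
DropoutRowwise | Drops entire rows (dim 1) with probability p during training.
TransformRBF | Converts a scalar distance d into a fixed-size RBF feature vector, then projects to R^c via LinearNoBias.
TriangleAttentionEndingNodeWithBias | For each column j, attend over all i using queries/keys/values from z_ij, biased by b_ij.
TriangleAttentionStartingNodeWithBias | For each row i, attend over all j using queries/keys/values from z_ij, with an additive pair bias b_ij projected to per-head scalars.
Module Contents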
- class pallatom.architecture.pair_update.DropoutColumnwise(p: float = 0.25)
Bases: torch.nn.Module
Drops entire columns (dim 2) with probability p during training.
- forward(x: torch.Tensor) → torch.Tensor
- p = 0.25
- class pallatom.architecture.pair_update.DropoutRowwise(p: float = 0.25)
Bases: torch.nn.Module
Drops entire rows (dim 1) with probability p during training.
- forward(x: torch.Tensor) → torch.Tensor
- p = 0.25
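Row-wise and column-wise dropout differ only in which axis of the mask is kept at full size. The sketch below is one way to get that behaviour, not the module's actual code: the function name `shared_dropout`, the mask shape (shared across channels), and the 1 / (1 − p) rescaling are all assumptions. It expects p < 1.

```python
import torch


def shared_dropout(x: torch.Tensor, p: float, drop_dim: int,
                   training: bool = True) -> torch.Tensor:
    """Hypothetical helper: zero whole slices of x ([B, N, N, c]) along drop_dim.

    drop_dim=1 drops entire rows, drop_dim=2 drops entire columns.
    """
    if not training or p == 0.0:
        return x  # identity at inference time
    # Mask has size 1 everywhere except batch and the dropped axis,
    # so one Bernoulli draw decides the fate of a whole row (or column).
    mask_shape = [x.shape[0], 1, 1, 1]
    mask_shape[drop_dim] = x.shape[drop_dim]
    keep = torch.bernoulli(
        torch.full(mask_shape, 1.0 - p, device=x.device, dtype=x.dtype)
    )
    # Rescale survivors so the expected value is unchanged (assumed convention)
    return x * keep / (1.0 - p)


torch.manual_seed(0)
x = torch.ones(2, 5, 5, 8)
y_row = shared_dropout(x, p=0.25, drop_dim=1)   # rows are all-zero or all-kept
y_col = shared_dropout(x, p=0.25, drop_dim=2)   # columns are all-zero or all-kept
y_eval = shared_dropout(x, p=0.25, drop_dim=1, training=False)
```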
- class pallatom.architecture.pair_update.PairUpdate(c: int = 128, n_rbf: int = 16, n_heads: int = 4, dropout: float = 0.25)
Bases: torch.nn.Module
- Parameters:
c: pair embedding dimension (default 128)
n_rbf: number of RBF centres
n_heads: number of attention heads
dropout: rowwise/columnwise dropout probability (0.25 per paper)
- forward(z: jaxtyping.Float[torch.Tensor, B N_res N_res c_pair], r_center: jaxtyping.Float[torch.Tensor, B N_res 3]) → jaxtyping.Float[torch.Tensor, B N_res N_res c_pair]
- drop_col
- drop_row
- rbf
- transition
- tri_end
- tri_start
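The order of the three residual updates can be made explicit with a small composition sketch. The function `pair_update_sketch` and the stand-in callables are hypothetical; they only mirror the attribute names listed above and say nothing about how the real sub-modules are implemented.

```python
import torch
from typing import Callable

PairFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
UnaryFn = Callable[[torch.Tensor], torch.Tensor]


def pair_update_sketch(z: torch.Tensor, b: torch.Tensor,
                       tri_start: PairFn, tri_end: PairFn, transition: UnaryFn,
                       drop_row: UnaryFn = lambda x: x,
                       drop_col: UnaryFn = lambda x: x) -> torch.Tensor:
    """Hypothetical wiring of the residual steps; b is the RBF distance bias."""
    z = z + drop_row(tri_start(z, b))   # starting-node attention, row-wise dropout
    z = z + drop_col(tri_end(z, b))     # ending-node attention, column-wise dropout
    z = z + transition(z)               # transition sees the already-updated z
    return z


# Trivial stand-ins, chosen only so the residual sum is easy to check by hand
z0 = torch.zeros(1, 3, 3, 4)
b0 = torch.ones(1, 3, 3, 4)
out = pair_update_sketch(
    z0, b0,
    tri_start=lambda z, b: b,
    tri_end=lambda z, b: 2 * b,
    transition=lambda z: torch.zeros_like(z),
)
```

Each step reads the output of the previous one, so the updates are sequential rather than computed from the same input z.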
- class pallatom.architecture.pair_update.TransformRBF(c: int, n_rbf: int = 16, d_min: float = 0.0, d_max: float = 22.0)
Bases: torch.nn.Module
Converts a scalar distance d into a fixed-size RBF feature vector, then projects to R^c via LinearNoBias.
Centers are evenly spaced in [d_min, d_max]; the width sigma is set to the center spacing (sigma = 1.375 with the default arguments).
- forward(d: jaxtyping.Float[torch.Tensor, B N_res N_res]) → jaxtyping.Float[torch.Tensor, B N_res N_res c_pair]
- centers: torch.Tensor
- proj
- sigma = 1.375
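A minimal sketch of an RBF distance embedding with the same constructor arguments follows. Several details are assumptions rather than facts about `TransformRBF`: the Gaussian kernel form, the use of `torch.linspace` for the centers, and the width formula (d_max − d_min) / n_rbf, which was picked because it reproduces the listed value 1.375 for the defaults. The class name `RBFSketch` is hypothetical.

```python
import torch
import torch.nn as nn


class RBFSketch(nn.Module):
    """Hypothetical RBF embedding: distance -> n_rbf Gaussian features -> R^c."""

    def __init__(self, c: int, n_rbf: int = 16,
                 d_min: float = 0.0, d_max: float = 22.0):
        super().__init__()
        # Fixed (non-trainable) centers, stored as a buffer so they follow .to(device)
        self.register_buffer("centers", torch.linspace(d_min, d_max, n_rbf))
        # Assumed width formula; 22 / 16 = 1.375 for the default arguments
        self.sigma = (d_max - d_min) / n_rbf
        # Linear projection without bias, matching "LinearNoBias"
        self.proj = nn.Linear(n_rbf, c, bias=False)

    def forward(self, d: torch.Tensor) -> torch.Tensor:
        # d: [B, N, N] -> features: [B, N, N, n_rbf]
        feats = torch.exp(-(((d[..., None] - self.centers) / self.sigma) ** 2))
        return self.proj(feats)  # [B, N, N, c]


rbf = RBFSketch(c=8)
d_example = torch.rand(2, 5, 5) * 22.0
b_example = rbf(d_example)
```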
- class pallatom.architecture.pair_update.Transition(c: int, expansion: int = 4)
Bases: torch.nn.Module
- forward(z: torch.Tensor) → torch.Tensor
- ff1
- ff2
- norm
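The page does not describe what `Transition` computes, so the following is only one plausible layout that uses the three listed attributes (`norm`, `ff1`, `ff2`). It assumes a SwiGLU-style gated feed-forward block in which `ff1` produces both the gate and the value halves; the actual module may use a different activation or layer widths. The class name `TransitionSketch` is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TransitionSketch(nn.Module):
    """Hypothetical transition block: LayerNorm -> gated expansion -> projection."""

    def __init__(self, c: int, expansion: int = 4):
        super().__init__()
        self.norm = nn.LayerNorm(c)
        # Assumption: one linear layer emits gate and value, split in forward()
        self.ff1 = nn.Linear(c, 2 * expansion * c, bias=False)
        self.ff2 = nn.Linear(expansion * c, c, bias=False)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        gate, value = self.ff1(self.norm(z)).chunk(2, dim=-1)
        # SiLU (swish) on the gate half, elementwise product, then project back to c
        return self.ff2(F.silu(gate) * value)


transition = TransitionSketch(c=8)
z_in = torch.randn(2, 5, 5, 8)
z_delta = transition(z_in)  # added to z as a residual by the caller
```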
- class pallatom.architecture.pair_update.TriangleAttentionEndingNodeWithBias(c: int, n_heads: int = 4)
Bases: torch.nn.Module
For each column j, attend over all i using queries/keys/values from z_ij, biased by b_ij.
- forward(z: jaxtyping.Float[torch.Tensor, B N_res N_res c_pair], b: jaxtyping.Float[torch.Tensor, B N_res N_res c_pair]) → jaxtyping.Float[torch.Tensor, B N_res N_res c_pair]
- head_dim
- n_heads = 4
- norm
- norm_b
- to_b
- to_g
- to_k
- to_out
- to_q
- to_v
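Column-wise ("ending node") attention is the row-wise ("starting node") operation applied to the transposed pair tensor. The sketch below shows that relationship with a hypothetical wrapper `ending_from_starting` and a toy stand-in for the row-wise op; whether this module is implemented via transposition or with its own indexing is not stated on this page.

```python
import torch
from typing import Callable

PairFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def ending_from_starting(start_fn: PairFn, z: torch.Tensor,
                         b: torch.Tensor) -> torch.Tensor:
    """Hypothetical wrapper: run a row-wise op on the transposed pair tensors."""
    zt = z.transpose(1, 2)   # swap the i and j axes of [B, N, N, c]
    bt = b.transpose(1, 2)
    return start_fn(zt, bt).transpose(1, 2)  # swap back


def row_mean(z: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Toy row-wise op: replace every entry with the mean over j of its row (ignores b)."""
    return z.mean(dim=2, keepdim=True).expand_as(z)


z_pair = torch.randn(1, 4, 4, 8)
b_pair = torch.zeros_like(z_pair)
col_out = ending_from_starting(row_mean, z_pair, b_pair)
# col_out[b, i, j] is now the mean over i of column j
```

When b is derived pointwise from symmetric distances d_ij, it is itself symmetric in (i, j), so transposing it changes nothing in that case.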
- class pallatom.architecture.pair_update.TriangleAttentionStartingNodeWithBias(c: int, n_heads: int = 4)
Bases: torch.nn.Module
For each row i, attend over all j using queries/keys/values from z_ij, with an additive pair bias b_ij projected to per-head scalars. The output is gated by a sigmoid of z_ij.
- forward(z: jaxtyping.Float[torch.Tensor, B N_res N_res c_pair], b: jaxtyping.Float[torch.Tensor, B N_res N_res c_pair]) → jaxtyping.Float[torch.Tensor, B N_res N_res c_pair]
- head_dim
- n_heads = 4
- norm
- norm_b
- to_b
- to_g
- to_k
- to_out
- to_q
- to_v