pallatom.architecture.node_update

Classes

AttentionPairBias

Self-attention on node embeddings s_i biased by pair embeddings z_ij.

NodeUpdate

Module Contents

class pallatom.architecture.node_update.AttentionPairBias(c: int, c_pair: int, n_heads: int = 8)

Bases: torch.nn.Module

Self-attention on node embeddings s_i biased by pair embeddings z_ij.
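For orientation, the usual AlphaFold-style form of pair-biased self-attention is written out per head h below. This is a sketch and is not confirmed by the source. The sigmoid gate and the layer normalization of z_ij are assumptions inferred from the attribute names to_g and norm_z listed further down. Here s_i denotes the already-normalized node features and d = c / n_heads.

```latex
\begin{aligned}
q_i^h &= W_q^h s_i, \quad k_j^h = W_k^h s_j, \quad v_j^h = W_v^h s_j, \quad b_{ij}^h = {w_b^h}^{\top} \mathrm{LN}(z_{ij}) \\
a_{ij}^h &= \operatorname{softmax}_j\!\left(\frac{{q_i^h}^{\top} k_j^h}{\sqrt{d}} + b_{ij}^h\right) \\
o_i^h &= \sigma\!\left(W_g^h s_i\right) \odot \sum_j a_{ij}^h v_j^h, \qquad
s_i' = W_o \, \mathrm{concat}_h\!\left(o_i^h\right)
\end{aligned}
```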

Parameters:
  • c (int) – single embedding dimension (no default here; NodeUpdate passes 256 by default)

  • c_pair (int) – pair embedding dimension

  • n_heads (int) – number of attention heads (default 8, per Algorithm 6)

forward(s: jaxtyping.Float[torch.Tensor, "B N_res c_res"], t: jaxtyping.Float[torch.Tensor, "B N_res c_res"], z: jaxtyping.Float[torch.Tensor, "B N_res N_res c_pair"]) → jaxtyping.Float[torch.Tensor, "B N_res c_res"]
head_dim
n_heads = 8
norm_s
norm_t
norm_z
proj_t
to_bias
to_g
to_k
to_out
to_q
to_v
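The attribute names above suggest the usual projections of pair-biased attention. The class below is a hypothetical re-implementation that reuses those names to show how the pieces could fit together. It is not the pallatom source. Several choices are assumptions: t is treated as a conditioning signal added to the normalized s through proj_t, to_bias maps each z_ij to one scalar per head, and to_g produces a sigmoid output gate.

```python
import torch
import torch.nn as nn


class AttentionPairBiasSketch(nn.Module):
    """Hypothetical sketch reusing the attribute names of AttentionPairBias.

    Assumptions (not confirmed by the source): additive conditioning on t via
    proj_t, one bias scalar per head from z via to_bias, sigmoid gate via to_g.
    """

    def __init__(self, c: int, c_pair: int, n_heads: int = 8):
        super().__init__()
        assert c % n_heads == 0, "c must be divisible by n_heads"
        self.n_heads = n_heads
        self.head_dim = c // n_heads
        self.norm_s = nn.LayerNorm(c)
        self.norm_t = nn.LayerNorm(c)
        self.norm_z = nn.LayerNorm(c_pair)
        self.proj_t = nn.Linear(c, c)
        self.to_q = nn.Linear(c, c)
        self.to_k = nn.Linear(c, c, bias=False)
        self.to_v = nn.Linear(c, c, bias=False)
        self.to_bias = nn.Linear(c_pair, n_heads, bias=False)
        self.to_g = nn.Linear(c, c)
        self.to_out = nn.Linear(c, c)

    def forward(self, s: torch.Tensor, t: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        B, N, _ = s.shape
        H, D = self.n_heads, self.head_dim
        # Normalize s and add the (assumed) conditioning signal derived from t.
        x = self.norm_s(s) + self.proj_t(self.norm_t(t))
        # Per-head queries, keys, values: [B, N, c] -> [B, H, N, D].
        q = self.to_q(x).view(B, N, H, D).transpose(1, 2)
        k = self.to_k(x).view(B, N, H, D).transpose(1, 2)
        v = self.to_v(x).view(B, N, H, D).transpose(1, 2)
        # Pair bias b_ij per head: [B, N, N, H] -> [B, H, N, N].
        bias = self.to_bias(self.norm_z(z)).permute(0, 3, 1, 2)
        # Scaled dot-product logits plus pair bias, softmax over keys j.
        attn = (q @ k.transpose(-1, -2) / D**0.5 + bias).softmax(dim=-1)
        # Weighted sum of values, heads merged back: [B, N, c].
        o = (attn @ v).transpose(1, 2).reshape(B, N, H * D)
        # Sigmoid gate computed from the conditioned features.
        o = torch.sigmoid(self.to_g(x)) * o
        return self.to_out(o)


torch.manual_seed(0)
attn = AttentionPairBiasSketch(c=64, c_pair=32, n_heads=8)
s = torch.randn(2, 16, 64)       # [B, N_res, c]
t = torch.randn(2, 16, 64)       # [B, N_res, c]
z = torch.randn(2, 16, 16, 32)   # [B, N_res, N_res, c_pair]
out = attn(s, t, z)              # [B, N_res, c]
```

Writing the logits out explicitly keeps the pair-bias term visible. torch.nn.functional.scaled_dot_product_attention also accepts a floating-point attn_mask that is added to the attention scores, so passing the bias there would be an equivalent and usually faster choice.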

class pallatom.architecture.node_update.NodeUpdate(c: int = 256, c_pair: int = 128, n_heads: int = 8, dropout: float = 0.25)

Bases: torch.nn.Module

Parameters:
  • c (int) – single embedding dimension (default 256)

  • c_pair (int) – pair embedding dimension (default 128)

  • n_heads (int) – number of attention heads (default 8, per Algorithm 6)

  • dropout (float) – row-wise dropout probability (default 0.25, per Algorithm 6)

forward(s: jaxtyping.Float[torch.Tensor, "B N_res c_res"], t: jaxtyping.Float[torch.Tensor, "B N_res c_res"], z: jaxtyping.Float[torch.Tensor, "B N_res N_res c_pair"]) → jaxtyping.Float[torch.Tensor, "B N_res c_res"]
attn_pair_bias
dropout_row
transition
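The three submodules attn_pair_bias, dropout_row and transition suggest a residual block of the form s ← s + dropout_row(attn_pair_bias(s, t, z)) followed by s ← s + transition(s). The sketch below is a hypothetical composition under that assumption, not the pallatom implementation. The pre-norm layout, the additive use of t, the axis along which the dropout mask is shared (here: shared by every residue) and the 4x MLP transition are all assumptions. To stay short, it stands in nn.MultiheadAttention for attn_pair_bias and feeds the pair bias through its float attn_mask.

```python
import torch
import torch.nn as nn


class NodeUpdateSketch(nn.Module):
    """Hypothetical composition of attn_pair_bias, dropout_row and transition.

    Assumed (not confirmed by the source): pre-norm residual layout, additive
    conditioning on t, a dropout mask shared across residues, 4x MLP transition.
    """

    def __init__(self, c: int = 256, c_pair: int = 128, n_heads: int = 8, dropout: float = 0.25):
        super().__init__()
        assert 0.0 <= dropout < 1.0
        self.n_heads = n_heads
        self.p = dropout
        self.norm_s = nn.LayerNorm(c)
        self.proj_t = nn.Linear(c, c)
        self.to_bias = nn.Sequential(nn.LayerNorm(c_pair), nn.Linear(c_pair, n_heads, bias=False))
        self.mha = nn.MultiheadAttention(c, n_heads, batch_first=True)
        self.transition = nn.Sequential(
            nn.LayerNorm(c), nn.Linear(c, 4 * c), nn.ReLU(), nn.Linear(4 * c, c)
        )

    def attn_pair_bias(self, s, t, z):
        B, N, _ = s.shape
        x = self.norm_s(s) + self.proj_t(t)
        # [B, N, N, H] -> [B*H, N, N]: the float attn_mask layout MultiheadAttention expects.
        bias = self.to_bias(z).permute(0, 3, 1, 2).reshape(B * self.n_heads, N, N)
        out, _ = self.mha(x, x, x, attn_mask=bias, need_weights=False)
        return out

    def dropout_row(self, x):
        if not self.training or self.p == 0.0:
            return x
        # One Bernoulli mask per (batch, channel), broadcast over all N_res rows.
        keep = torch.bernoulli(x.new_full((x.shape[0], 1, x.shape[2]), 1.0 - self.p))
        return x * keep / (1.0 - self.p)

    def forward(self, s, t, z):
        s = s + self.dropout_row(self.attn_pair_bias(s, t, z))
        return s + self.transition(s)


torch.manual_seed(0)
model = NodeUpdateSketch(c=64, c_pair=32, n_heads=8, dropout=0.25).eval()
s = torch.randn(2, 16, 64)       # [B, N_res, c]
t = torch.randn(2, 16, 64)       # [B, N_res, c]
z = torch.randn(2, 16, 16, 32)   # [B, N_res, N_res, c_pair]
out = model(s, t, z)             # [B, N_res, c]
```

Routing the bias through attn_mask works because PyTorch adds a floating-point mask of shape (B * n_heads, N, N) to the attention logits before the softmax. In eval mode dropout_row is the identity, so repeated calls give the same output.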