pallatom.architecture.node_update
Classes
AttentionPairBias | Self-attention on node embeddings s_i biased by pair embeddings z_ij.
NodeUpdate |
Module Contents
- class pallatom.architecture.node_update.AttentionPairBias(c: int, c_pair: int, n_heads: int = 8)
Bases: torch.nn.Module
Self-attention on node embeddings s_i biased by pair embeddings z_ij.
- Parameters:
c (single embedding dim; required in this signature, 256 is the NodeUpdate default)
c_pair (pair embedding dim)
n_heads (number of attention heads (default 8 per Algorithm 6))
- forward(s: jaxtyping.Float[torch.Tensor, "B N_res c_res"], t: jaxtyping.Float[torch.Tensor, "B N_res c_res"], z: jaxtyping.Float[torch.Tensor, "B N_res N_res c_pair"]) -> jaxtyping.Float[torch.Tensor, "B N_res c_res"]
- head_dim
- n_heads = 8
- norm_s
- norm_t
- norm_z
- proj_t
- to_bias
- to_g
- to_k
- to_out
- to_q
- to_v
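The following is a minimal sketch of pair-biased self-attention, included only to illustrate what the attribute names above suggest. It is not the pallatom implementation: the class name `PairBiasedAttentionSketch`, the layer choices (bias flags, sigmoid gating, scaling by the square root of the head dimension), and the tensor layout are all assumptions. The conditioning input `t` (and with it `norm_t` and `proj_t`) is left out because this page does not say how it is combined with `s`.

```python
import math

import torch
import torch.nn as nn


class PairBiasedAttentionSketch(nn.Module):
    """Hypothetical sketch of attention with pair bias (not pallatom's code)."""

    def __init__(self, c: int, c_pair: int, n_heads: int = 8):
        super().__init__()
        assert c % n_heads == 0, "c must be divisible by n_heads"
        self.n_heads = n_heads
        self.head_dim = c // n_heads
        self.norm_s = nn.LayerNorm(c)
        self.norm_z = nn.LayerNorm(c_pair)
        self.to_q = nn.Linear(c, c)
        self.to_k = nn.Linear(c, c, bias=False)
        self.to_v = nn.Linear(c, c, bias=False)
        # One scalar bias per head for every residue pair (i, j).
        self.to_bias = nn.Linear(c_pair, n_heads, bias=False)
        self.to_g = nn.Linear(c, c)
        self.to_out = nn.Linear(c, c, bias=False)

    def forward(self, s: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        # s: (B, N_res, c), z: (B, N_res, N_res, c_pair)
        B, N, _ = s.shape
        s = self.norm_s(s)

        def split_heads(x: torch.Tensor) -> torch.Tensor:
            # (B, N, c) -> (B, n_heads, N, head_dim)
            return x.view(B, N, self.n_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.to_q(s))
        k = split_heads(self.to_k(s))
        v = split_heads(self.to_v(s))

        # Pair embeddings become an additive bias on the attention logits.
        bias = self.to_bias(self.norm_z(z)).permute(0, 3, 1, 2)  # (B, H, N, N)
        logits = q @ k.transpose(-1, -2) / math.sqrt(self.head_dim) + bias
        attn = logits.softmax(dim=-1)

        out = (attn @ v).transpose(1, 2).reshape(B, N, -1)  # (B, N, c)
        gate = torch.sigmoid(self.to_g(s))
        return self.to_out(gate * out)
```

Calling the sketch with `s` of shape `(B, N_res, c)` and `z` of shape `(B, N_res, N_res, c_pair)` returns a tensor with the same shape as `s`, matching the output annotation of `forward` above.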
- class pallatom.architecture.node_update.NodeUpdate(c: int = 256, c_pair: int = 128, n_heads: int = 8, dropout: float = 0.25)
Bases: torch.nn.Module
- Parameters:
c (single embedding dim (default 256))
c_pair (pair embedding dim)
n_heads (attention heads (default 8, per Algorithm 6))
dropout (rowwise dropout prob (default 0.25, per Algorithm 6))
- forward(s: jaxtyping.Float[torch.Tensor, "B N_res c_res"], t: jaxtyping.Float[torch.Tensor, "B N_res c_res"], z: jaxtyping.Float[torch.Tensor, "B N_res N_res c_pair"]) -> jaxtyping.Float[torch.Tensor, "B N_res c_res"]
- attn_pair_bias
- dropout_row
- transition
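As a rough illustration of how the three NodeUpdate attributes might fit together, the sketch below chains an attention step, a shared-mask dropout, and a transition MLP with residual connections. Everything here is an assumption rather than the documented behaviour: the residual wiring, the 4x expansion in the transition, the axis along which the dropout mask is shared (here the residue axis), and the names `SharedAxisDropout`, `StandInAttention`, and `NodeUpdateSketch` are all hypothetical. `StandInAttention` is only a placeholder with the documented `(s, t, z)` call shape and performs no attention.

```python
import torch
import torch.nn as nn


class SharedAxisDropout(nn.Module):
    """Dropout whose mask is broadcast along one axis (hypothetical helper)."""

    def __init__(self, p: float, shared_dim: int):
        super().__init__()
        self.shared_dim = shared_dim
        self.dropout = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Build a mask with size 1 along the shared axis, then broadcast it.
        mask_shape = list(x.shape)
        mask_shape[self.shared_dim] = 1
        mask = self.dropout(x.new_ones(mask_shape))
        return x * mask


class StandInAttention(nn.Module):
    """Placeholder taking (s, t, z); it ignores z and is not real attention."""

    def __init__(self, c: int):
        super().__init__()
        self.proj = nn.Linear(c, c)

    def forward(self, s, t, z):
        return self.proj(s + t)


class NodeUpdateSketch(nn.Module):
    """Assumed wiring: residual attention with dropout, then residual MLP."""

    def __init__(self, attn: nn.Module, c: int = 256, dropout: float = 0.25):
        super().__init__()
        self.attn_pair_bias = attn
        self.dropout_row = SharedAxisDropout(dropout, shared_dim=1)
        self.transition = nn.Sequential(
            nn.LayerNorm(c),
            nn.Linear(c, 4 * c),
            nn.ReLU(),
            nn.Linear(4 * c, c),
        )

    def forward(self, s, t, z):
        s = s + self.dropout_row(self.attn_pair_bias(s, t, z))
        return s + self.transition(s)
```

In training mode the dropout mask is identical for every position along the shared axis; in evaluation mode `nn.Dropout` is the identity, so the update is deterministic.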