pallatom.helpers.alignment
kabsch.py: rigid protein structure alignment via the Kabsch algorithm. Pure PyTorch implementation; supports batched inputs and optional masking.
References
Kabsch, W. (1976). A solution for the best rotation to relate two sets of vectors. Acta Crystallographica, A32, 922-923. https://doi.org/10.1107/S0567739476001873
Functions
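- apply_transform(coords, R, t_from, t_to): Apply a previously computed Kabsch rigid transform to a new set of coordinates.
- kabsch_align(mobile, target[, weights, return_transform]): Rigidly align mobile onto target using the Kabsch algorithm.
- kabsch_rmsd(mobile, target[, weights, mask]): Align mobile to target, then return the RMSD.
- kabsch_rotation(P, Q[, weights]): Compute the optimal rotation matrix R that minimises RMSD between P and Q.
- rmsd(P, Q[, weights, mask]): Root-mean-square deviation between two (already aligned) coordinate sets.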
Module Contents
- pallatom.helpers.alignment.apply_transform(coords: jaxtyping.Float[torch.Tensor, '... M 3'], R: jaxtyping.Float[torch.Tensor, '... 3 3'], t_from: jaxtyping.Float[torch.Tensor, '... 1 3'], t_to: jaxtyping.Float[torch.Tensor, '... 1 3']) → jaxtyping.Float[torch.Tensor, '... M 3']
Apply a previously computed Kabsch rigid transform to a new set of coordinates (e.g. all atoms after fitting on Cα only).
Transform: coords_aligned = (coords − t_from) @ R^T + t_to
- Parameters:
coords – (…, M, 3) — coordinates to transform (can differ from the N atoms used to compute the transform).
R – (…, 3, 3) — rotation matrix from kabsch_align.
t_from – (…, 1, 3) — centroid of the mobile set (c_mobile).
t_to – (…, 1, 3) — centroid of the target set (c_target).
- Returns:
(…, M, 3) — transformed coordinates.
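The transform formula can be sketched directly in PyTorch. This is a minimal illustration, not pallatom's source: `apply_transform_sketch` is a hypothetical name, and the rotation `R_true` is built by hand rather than obtained from kabsch_align. It shows the "fit on a subset, move every atom" use case: the centroids come from the first four atoms only (standing in for Cα), yet the transform reproduces the target for all ten atoms.

```python
import math

import torch


def apply_transform_sketch(coords, R, t_from, t_to):
    """Hypothetical re-implementation of the documented formula
    coords_aligned = (coords - t_from) @ R^T + t_to.
    coords: (..., M, 3); R: (..., 3, 3); t_from, t_to: (..., 1, 3).
    """
    return (coords - t_from) @ R.transpose(-1, -2) + t_to


# Toy structure: 10 atoms; the first 4 stand in for the C-alpha fitting subset.
torch.manual_seed(0)
mobile_all = torch.randn(10, 3, dtype=torch.float64)
ca_idx = torch.arange(4)

# Known rigid motion: 90 degree rotation about z, followed by a shift.
c, s = math.cos(math.pi / 2), math.sin(math.pi / 2)
R_true = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=torch.float64)
shift = torch.tensor([1.0, -2.0, 0.5], dtype=torch.float64)
target_all = mobile_all @ R_true.T + shift

# Centroids are computed over the fitting subset only...
c_mobile = mobile_all[ca_idx].mean(dim=0, keepdim=True)  # (1, 3)
c_target = target_all[ca_idx].mean(dim=0, keepdim=True)  # (1, 3)

# ...but the transform is applied to every atom.
moved_all = apply_transform_sketch(mobile_all, R_true, c_mobile, c_target)
print(torch.allclose(moved_all, target_all))  # True
```

Because the centroid of the rotated subset is itself rotated and shifted, the centroid terms cancel and every atom ends up at `x @ R^T + shift`, regardless of which atoms were used for fitting.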
- pallatom.helpers.alignment.kabsch_align(mobile: jaxtyping.Float[torch.Tensor, '... N 3'], target: jaxtyping.Float[torch.Tensor, '... N 3'], weights: Optional[jaxtyping.Float[torch.Tensor, '... N']] = None, return_transform: bool = False) → Tuple[torch.Tensor, ...]
Rigidly align mobile onto target using the Kabsch algorithm.
- Pipeline:
1. Compute (weighted) centroids of both structures.
2. Centre both structures.
3. Compute optimal rotation R via SVD (Kabsch).
4. Apply: aligned = (mobile − c_mobile) @ R^T + c_target
- Parameters:
mobile – (…, N, 3) — coordinates to be aligned.
target – (…, N, 3) — reference coordinates.
weights – (…, N) — optional per-residue weights (e.g. 1 for structured, 0 for missing).
return_transform – bool — if True, also return (R, t_mobile, t_target).
- Returns:
aligned – (…, N, 3): mobile after optimal rigid alignment to target.
R – (…, 3, 3): rotation matrix [only if return_transform].
t_mobile – (…, 1, 3): centroid of mobile [only if return_transform].
t_target – (…, 1, 3): centroid of target [only if return_transform].
- Return type:
Tuple[torch.Tensor, ...]
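The four pipeline steps can be sketched as follows. This is a minimal illustration under stated assumptions, not pallatom's source: `kabsch_align_sketch` is a hypothetical name, weights default to uniform when None, and the transform is always returned (kabsch_align's exact return convention for return_transform=False is not reproduced). Zero weights remove the corresponding residues from both the centroids and the covariance, so their coordinates cannot influence the fit.

```python
import torch


def kabsch_align_sketch(mobile, target, weights=None):
    """Hypothetical sketch of the documented pipeline (not pallatom's source).
    mobile, target: (..., N, 3); weights: (..., N), uniform if None.
    Returns aligned (..., N, 3), R (..., 3, 3), c_mobile and c_target (..., 1, 3).
    """
    if weights is None:
        weights = torch.ones(mobile.shape[:-1], dtype=mobile.dtype, device=mobile.device)
    w = weights.to(mobile.dtype).unsqueeze(-1)                 # (..., N, 1)
    w_sum = w.sum(dim=-2, keepdim=True).clamp_min(1e-8)        # (..., 1, 1)

    # 1. Weighted centroids.
    c_mobile = (w * mobile).sum(dim=-2, keepdim=True) / w_sum  # (..., 1, 3)
    c_target = (w * target).sum(dim=-2, keepdim=True) / w_sum

    # 2. Centre both structures.
    P = mobile - c_mobile
    Q = target - c_target

    # 3. Rotation from the SVD of the weighted covariance H = P^T W Q.
    U, S, Vh = torch.linalg.svd(P.transpose(-1, -2) @ (w * Q))
    V = Vh.transpose(-1, -2)
    det = torch.linalg.det(V @ U.transpose(-1, -2))
    d = torch.where(det < 0, -torch.ones_like(det), torch.ones_like(det))
    D = torch.diag_embed(torch.stack([torch.ones_like(d), torch.ones_like(d), d], dim=-1))
    R = V @ D @ U.transpose(-1, -2)                            # det(R) = +1

    # 4. aligned = (mobile - c_mobile) @ R^T + c_target
    aligned = P @ R.transpose(-1, -2) + c_target
    return aligned, R, c_mobile, c_target


torch.manual_seed(0)
mobile = torch.randn(2, 20, 3, dtype=torch.float64)            # batch of two structures
R_true, _ = torch.linalg.qr(torch.randn(2, 3, 3, dtype=torch.float64))
R_true = R_true * torch.sign(torch.linalg.det(R_true))[:, None, None]  # force det = +1
target = mobile @ R_true.transpose(-1, -2) + torch.randn(2, 1, 3, dtype=torch.float64)

aligned, R, _, _ = kabsch_align_sketch(mobile, target)
print(torch.allclose(aligned, target, atol=1e-8))              # True
```

Multiplying a 3×3 orthogonal matrix by −1 flips the sign of its determinant, which is why the QR factor can be turned into a proper rotation with a single sign change.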
- pallatom.helpers.alignment.kabsch_rmsd(mobile: jaxtyping.Float[torch.Tensor, '... N 3'], target: jaxtyping.Float[torch.Tensor, '... N 3'], weights: Optional[jaxtyping.Float[torch.Tensor, '... N']] = None, mask: Optional[jaxtyping.Bool[torch.Tensor, '... N']] = None) → torch.Tensor
Align mobile to target, then return the RMSD.
- Parameters:
mobile – (…, N, 3) — coordinates to align.
target – (…, N, 3) — reference coordinates.
weights – (…, N) — per-residue weights (optional).
mask – (…, N) — boolean residue mask (optional).
- Returns:
(…,) RMSD after optimal rigid alignment, in the units of the input coordinates (typically Ångströms).
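The optimal-superposition RMSD can also be obtained without materialising the aligned coordinates, via the closed form MSD = (Σ w|p|² + Σ w|q|² − 2(s₁ + s₂ + d·s₃)) / Σ w, where s₁ ≥ s₂ ≥ s₃ are the singular values of H = Pᵀ W Q for the centred sets and d = sign(det H). This is a swapped-in technique shown for illustration: kabsch_rmsd is documented as aligning first and then computing the RMSD, and may not use it. The name `kabsch_rmsd_sketch` and the choice to fold mask into the weights are assumptions.

```python
import torch


def kabsch_rmsd_sketch(mobile, target, weights=None, mask=None):
    """Hypothetical sketch: RMSD after optimal rigid superposition, computed
    from the singular values of H = P^T W Q instead of rotating mobile.
    mobile, target: (..., N, 3); weights: (..., N); mask: (..., N) bool.
    """
    dtype, device = mobile.dtype, mobile.device
    if weights is None:
        w = torch.ones(mobile.shape[:-1], dtype=dtype, device=device)
    else:
        w = weights.to(dtype)
    if mask is not None:
        w = w * mask.to(dtype)                                  # assumption: mask=False -> weight 0
    w = w.unsqueeze(-1)                                         # (..., N, 1)
    w_sum = w.sum(dim=-2, keepdim=True).clamp_min(1e-8)         # (..., 1, 1)

    # Centre both sets on their weighted centroids.
    P = mobile - (w * mobile).sum(dim=-2, keepdim=True) / w_sum
    Q = target - (w * target).sum(dim=-2, keepdim=True) / w_sum

    H = P.transpose(-1, -2) @ (w * Q)                           # (..., 3, 3)
    S = torch.linalg.svdvals(H)                                 # (..., 3), descending
    det = torch.linalg.det(H)
    d = torch.where(det < 0, -torch.ones_like(det), torch.ones_like(det))
    max_trace = S[..., 0] + S[..., 1] + d * S[..., 2]           # max tr(R H) over proper R

    e0 = (w * (P ** 2 + Q ** 2)).sum(dim=(-2, -1))              # sum_i w_i (|p_i|^2 + |q_i|^2)
    msd = (e0 - 2.0 * max_trace).clamp_min(0.0) / w_sum[..., 0, 0]
    return msd.sqrt()                                           # (...,)


torch.manual_seed(0)
mobile = torch.randn(30, 3, dtype=torch.float64)
R_true, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))
R_true = R_true * torch.sign(torch.linalg.det(R_true))         # make it a proper rotation
target = mobile @ R_true.T + torch.tensor([5.0, 0.0, -3.0], dtype=torch.float64)
print(kabsch_rmsd_sketch(mobile, target))                      # close to 0 for a rigid copy
```

The `clamp_min(0.0)` guards against tiny negative values from floating-point cancellation when the two structures are (nearly) identical up to a rigid motion.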
- pallatom.helpers.alignment.kabsch_rotation(P: jaxtyping.Float[torch.Tensor, '... N 3'], Q: jaxtyping.Float[torch.Tensor, '... N 3'], weights: Optional[jaxtyping.Float[torch.Tensor, '... N']] = None) → jaxtyping.Float[torch.Tensor, '... 3 3']
Compute the optimal rotation matrix R that minimises the RMSD between the already-centred coordinate sets P and Q, using singular value decomposition (SVD).
- The algorithm finds the proper rotation R (det R = +1) such that:
||W^(1/2) (P @ R^T − Q)||_F is minimised, where W = diag(weights) (uniform weights if weights is None).
- Parameters:
P – (…, N, 3) — mobile structure (will be rotated).
Q – (…, N, 3) — target/reference structure.
weights – (…, N) — per-residue weights (optional, non-negative). Useful for masking missing residues or down-weighting flexible loops.
- Returns:
R – (…, 3, 3): rotation matrix (det = +1, i.e. a proper rotation).
- Return type:
jaxtyping.Float[torch.Tensor, '... 3 3']
Notes
Inputs must already be centred (mean-subtracted, using the weighted mean when weights are given) when this function is called directly. Use kabsch_align for the full pipeline (centre → rotate → translate).
Reflections are handled by flipping the sign of the column of V corresponding to the smallest singular value when det(V @ U^T) < 0, where U S V^T is the SVD of the weighted covariance matrix H = P^T W Q.
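The reflection handling can be seen on a mirror image: for Q = P · diag(−1, 1, 1), the unconstrained least-squares solution V U^T is the reflection itself (det −1), and flipping the last column of V yields the best proper rotation instead. `kabsch_rotation_sketch` is a hypothetical stand-in, assuming H = P^T W Q and relying on `torch.linalg.svd` returning singular values in descending order, so the last column of V is the one paired with the smallest singular value.

```python
import torch


def kabsch_rotation_sketch(P, Q, weights=None):
    """Hypothetical sketch of the rotation step (not pallatom's source).
    P, Q: already-centred (..., N, 3); weights: (..., N), uniform if None.
    Returns R (..., 3, 3), det(R) = +1, minimising ||W^(1/2) (P @ R^T - Q)||_F.
    """
    if weights is None:
        weights = torch.ones(P.shape[:-1], dtype=P.dtype, device=P.device)
    H = P.transpose(-1, -2) @ (weights.to(P.dtype).unsqueeze(-1) * Q)  # H = P^T W Q
    U, S, Vh = torch.linalg.svd(H)              # H = U diag(S) Vh, S in descending order
    V = Vh.transpose(-1, -2)
    det = torch.linalg.det(V @ U.transpose(-1, -2))
    d = torch.where(det < 0, -torch.ones_like(det), torch.ones_like(det))
    # Reflection fix: flip the last column of V (paired with the smallest singular value).
    V = torch.cat([V[..., :2], V[..., 2:] * d[..., None, None]], dim=-1)
    return V @ U.transpose(-1, -2)


torch.manual_seed(0)
P = torch.randn(12, 3, dtype=torch.float64)
P = P - P.mean(dim=0, keepdim=True)
Q = P * torch.tensor([-1.0, 1.0, 1.0], dtype=torch.float64)   # mirror image (x -> -x)

R = kabsch_rotation_sketch(P, Q)
print(torch.linalg.det(R))   # approximately +1: a proper rotation, not diag(-1, 1, 1)
```

Flipping one column of V is equivalent to the usual formulation R = V diag(1, 1, d) U^T; it sacrifices the least-significant direction of agreement to keep R a proper rotation.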
- pallatom.helpers.alignment.rmsd(P: jaxtyping.Float[torch.Tensor, '... N 3'], Q: jaxtyping.Float[torch.Tensor, '... N 3'], weights: Optional[jaxtyping.Float[torch.Tensor, '... N']] = None, mask: Optional[jaxtyping.Bool[torch.Tensor, '... N']] = None) → torch.Tensor
Root-mean-square deviation between two (already aligned) coordinate sets.
- Parameters:
P – (…, N, 3) — predicted / mobile coordinates.
Q – (…, N, 3) — reference coordinates.
weights – (…, N) — per-residue weights (optional).
mask – (…, N) — boolean mask; True = include residue (optional). Applied on top of weights.
- Returns:
(…,) tensor of RMSD values in the same units as the input coordinates (typically Ångströms for protein structures).
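A minimal sketch of a weighted, masked RMSD, assuming the mask simply zeroes the weights of excluded residues and the result is the weighted mean squared distance over the included ones (the exact normalisation used by rmsd is not specified above). `rmsd_sketch` is a hypothetical name. The worked numbers use two residues whose squared distances are 25 and 0.

```python
import torch


def rmsd_sketch(P, Q, weights=None, mask=None):
    """Hypothetical sketch of a weighted, masked RMSD (not pallatom's source).
    P, Q: (..., N, 3), already aligned; weights: (..., N); mask: (..., N) bool.
    Assumes the mask multiplies the weights and the result is a weighted mean.
    """
    sq_dist = ((P - Q) ** 2).sum(dim=-1)                      # (..., N)
    w = torch.ones_like(sq_dist) if weights is None else weights.to(sq_dist.dtype)
    if mask is not None:
        w = w * mask.to(sq_dist.dtype)                        # mask=False -> weight 0
    msd = (w * sq_dist).sum(dim=-1) / w.sum(dim=-1).clamp_min(1e-8)
    return msd.sqrt()                                         # (...,)


P = torch.zeros(2, 3)
Q = torch.tensor([[3.0, 4.0, 0.0], [0.0, 0.0, 0.0]])         # squared distances: 25 and 0
print(rmsd_sketch(P, Q))                                      # sqrt(12.5), about 3.5355
print(rmsd_sketch(P, Q, mask=torch.tensor([True, False])))    # 5.0
print(rmsd_sketch(P, Q, weights=torch.tensor([1.0, 3.0])))    # sqrt(25 / 4) = 2.5
```

The `clamp_min` on the weight sum avoids a division by zero when every residue is masked out; in that case the sketch returns 0 rather than NaN.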