pallatom.helpers.alignment

kabsch.py: rigid protein structure alignment via the Kabsch algorithm. Pure PyTorch implementation; supports batched inputs and optional masking.

References

Kabsch, W. (1976). A solution for the best rotation to relate two sets of vectors. Acta Crystallographica, A32, 922-923. https://doi.org/10.1107/S0567739476001873

Functions

apply_transform(→ jaxtyping.Float[torch.Tensor, ... M 3])

Apply a previously computed Kabsch rigid transform to a new set of coordinates.

kabsch_align(→ Tuple[torch.Tensor, Ellipsis])

Rigidly align mobile onto target using the Kabsch algorithm.

kabsch_rmsd(→ torch.Tensor)

Align mobile to target, then return the RMSD.

kabsch_rotation(→ jaxtyping.Float[torch.Tensor, ... 3 3])

Compute the optimal rotation matrix R that minimises RMSD between P and Q after centring.

rmsd(→ torch.Tensor)

Root-mean-square deviation between two (already aligned) coordinate sets.

Module Contents

pallatom.helpers.alignment.apply_transform(coords: jaxtyping.Float[torch.Tensor, ... M 3], R: jaxtyping.Float[torch.Tensor, ... 3 3], t_from: jaxtyping.Float[torch.Tensor, ... 1 3], t_to: jaxtyping.Float[torch.Tensor, ... 1 3]) → jaxtyping.Float[torch.Tensor, ... M 3]

Apply a previously computed Kabsch rigid transform to a new set of coordinates (e.g. all atoms after fitting on Cα only).

Transform: coords_aligned = (coords − t_from) @ R^T + t_to
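As a quick sanity check of this formula, the minimal sketch below re-implements it as a hypothetical standalone function, apply_rigid_transform (not pallatom's code), and applies a 90° rotation about z to a single point. Coordinates are stored as rows, which is why the rotation appears as @ R^T.

```python
import torch


def apply_rigid_transform(coords: torch.Tensor, R: torch.Tensor,
                          t_from: torch.Tensor, t_to: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of: coords_aligned = (coords - t_from) @ R^T + t_to."""
    # Points are rows, so R (which acts on column vectors) is applied as @ R^T.
    return (coords - t_from) @ R.transpose(-2, -1) + t_to


# 90 degree rotation about the z axis (maps x -> y for column vectors).
R = torch.tensor([[0.0, -1.0, 0.0],
                  [1.0,  0.0, 0.0],
                  [0.0,  0.0, 1.0]])
coords = torch.tensor([[2.0, 0.0, 0.0]])   # (M=1, 3)
t_from = torch.tensor([[1.0, 0.0, 0.0]])   # (1, 3) mobile centroid
t_to = torch.tensor([[0.0, 0.0, 5.0]])     # (1, 3) target centroid

out = apply_rigid_transform(coords, R, t_from, t_to)
print(out.tolist())  # [[0.0, 1.0, 5.0]]
```

The point is first expressed relative to the mobile centroid (giving [1, 0, 0]), rotated onto the y axis, and finally placed relative to the target centroid.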

Parameters:
  • coords – (…, M, 3) — coordinates to transform (can differ from the N atoms used to compute the transform).

  • R – (…, 3, 3) — rotation matrix from kabsch_align.

  • t_from – (…, 1, 3) — centroid of the mobile set (c_mobile).

  • t_to – (…, 1, 3) — centroid of the target set (c_target).

Returns:

(…, M, 3) — transformed coordinates.
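A typical use, as described above, is fitting on a subset (e.g. Cα atoms) and then moving every atom with the same transform. The sketch below is self-contained and does not call pallatom: fit_rigid is a hypothetical helper that mirrors the centroid and SVD steps of kabsch_align (unweighted here), and the "structure" is random synthetic points.

```python
import torch


def fit_rigid(mobile: torch.Tensor, target: torch.Tensor):
    """Hypothetical unweighted Kabsch fit returning (R, c_mobile, c_target)."""
    c_m = mobile.mean(dim=-2, keepdim=True)                  # (..., 1, 3)
    c_t = target.mean(dim=-2, keepdim=True)                  # (..., 1, 3)
    H = (mobile - c_m).transpose(-2, -1) @ (target - c_t)   # (..., 3, 3) covariance
    U, _, Vh = torch.linalg.svd(H)
    V, Ut = Vh.transpose(-2, -1), U.transpose(-2, -1)
    d = torch.linalg.det(V @ Ut)
    sign = 1.0 - 2.0 * (d < 0).to(d.dtype)                  # -1 for a reflection, else +1
    ones = torch.ones_like(d)
    D = torch.diag_embed(torch.stack([ones, ones, sign], dim=-1))
    return V @ D @ Ut, c_m, c_t


torch.manual_seed(0)
all_atoms = torch.randn(12, 3, dtype=torch.float64)   # synthetic "all-atom" coordinates
ca_idx = torch.tensor([0, 3, 6, 9])                    # pretend these rows are the C-alpha atoms

# Target = the same structure in a different pose (random proper rotation + shift).
rot, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))
rot = rot * torch.linalg.det(rot)                      # force det = +1
shift = torch.tensor([[1.0, -2.0, 0.5]], dtype=torch.float64)
target_all = all_atoms @ rot.T + shift

# Fit on the C-alpha subset only, then move every atom with the same transform.
R, c_mobile, c_target = fit_rigid(all_atoms[ca_idx], target_all[ca_idx])
moved = (all_atoms - c_mobile) @ R.T + c_target
print(torch.allclose(moved, target_all, atol=1e-8))  # True
```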

pallatom.helpers.alignment.kabsch_align(mobile: jaxtyping.Float[torch.Tensor, ... N 3], target: jaxtyping.Float[torch.Tensor, ... N 3], weights: Optional[jaxtyping.Float[torch.Tensor, ... N]] = None, return_transform: bool = False) → Tuple[torch.Tensor, Ellipsis]

Rigidly align mobile onto target using the Kabsch algorithm.

Pipeline:
  1. Compute (weighted) centroids of both structures.

  2. Centre both structures.

  3. Compute optimal rotation R via SVD (Kabsch).

  4. Apply: aligned = (mobile − c_mobile) @ R^T + c_target
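The four pipeline steps can be sketched in a few lines of batched PyTorch. This is a hypothetical re-implementation (kabsch_align_sketch is not pallatom's function); it assumes the weighted covariance H = P^T W Q and the usual determinant correction, and shows that a zero-weighted "missing" residue with corrupted coordinates does not disturb the fit.

```python
import math

import torch


def kabsch_align_sketch(mobile, target, weights=None):
    """Hypothetical sketch of the four pipeline steps (not pallatom's code)."""
    if weights is None:
        weights = torch.ones(mobile.shape[:-1], dtype=mobile.dtype)
    w = weights.unsqueeze(-1)                                    # (..., N, 1)
    # 1. Weighted centroids, (..., 1, 3).
    c_m = (w * mobile).sum(-2, keepdim=True) / w.sum(-2, keepdim=True)
    c_t = (w * target).sum(-2, keepdim=True) / w.sum(-2, keepdim=True)
    # 2. Centre both structures.
    P, Q = mobile - c_m, target - c_t
    # 3. Rotation from the SVD of the weighted covariance H = P^T W Q.
    U, _, Vh = torch.linalg.svd(P.transpose(-2, -1) @ (w * Q))
    V, Ut = Vh.transpose(-2, -1), U.transpose(-2, -1)
    d = torch.linalg.det(V @ Ut)
    sign = 1.0 - 2.0 * (d < 0).to(d.dtype)                      # avoid reflections
    ones = torch.ones_like(d)
    R = V @ torch.diag_embed(torch.stack([ones, ones, sign], dim=-1)) @ Ut
    # 4. Apply: (mobile - c_mobile) @ R^T + c_target.
    return (mobile - c_m) @ R.transpose(-2, -1) + c_t


torch.manual_seed(0)
mobile = torch.randn(2, 8, 3, dtype=torch.float64)              # batch of 2 structures
c, s = math.cos(0.7), math.sin(0.7)
Rz = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=torch.float64)
target = mobile @ Rz.T + 3.0
target[:, -1] = 100.0                 # corrupt the last point ("missing" residue)
weights = torch.ones(2, 8, dtype=torch.float64)
weights[:, -1] = 0.0                  # ...and give it zero weight

aligned = kabsch_align_sketch(mobile, target, weights)
print(torch.allclose(aligned[:, :-1], target[:, :-1], atol=1e-8))  # True
```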

Parameters:
  • mobile – (…, N, 3) — coordinates to be aligned.

  • target – (…, N, 3) — reference coordinates.

  • weights – (…, N) — optional per-residue weights (e.g. 1 for structured, 0 for missing).

  • return_transform – bool — if True, also return (R, t_mobile, t_target).

Returns:

  • aligned – (…, N, 3): mobile after optimal rigid alignment to target.

  • R – (…, 3, 3): rotation matrix [only if return_transform].

  • t_mobile – (…, 1, 3): centroid of mobile [only if return_transform].

  • t_target – (…, 1, 3): centroid of target [only if return_transform].

Return type:

Tuple[torch.Tensor, Ellipsis]

pallatom.helpers.alignment.kabsch_rmsd(mobile: jaxtyping.Float[torch.Tensor, ... N 3], target: jaxtyping.Float[torch.Tensor, ... N 3], weights: Optional[jaxtyping.Float[torch.Tensor, ... N]] = None, mask: Optional[jaxtyping.Bool[torch.Tensor, ... N]] = None) → torch.Tensor

Align mobile to target, then return the RMSD.

Parameters:
  • mobile – (…, N, 3) — coordinates to align.

  • target – (…, N, 3) — reference coordinates.

  • weights – (…, N) — per-residue weights (optional).

  • mask – (…, N) — boolean residue mask (optional).

Returns:

(…,) RMSD after optimal rigid alignment, in the same units as the input coordinates (Ångströms for typical protein structures).
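The point of aligning before measuring is that RMSD then reflects shape differences only, not pose. The minimal sketch below illustrates this with a hypothetical, unweighted kabsch_rmsd_sketch (not pallatom's implementation): a rigidly moved copy has a large raw RMSD but an aligned RMSD of essentially zero.

```python
import torch


def kabsch_rmsd_sketch(mobile, target):
    """Hypothetical, unweighted: centre, rotate via SVD, then compute RMSD."""
    P = mobile - mobile.mean(-2, keepdim=True)
    Q = target - target.mean(-2, keepdim=True)
    U, _, Vh = torch.linalg.svd(P.transpose(-2, -1) @ Q)
    V, Ut = Vh.transpose(-2, -1), U.transpose(-2, -1)
    d = torch.linalg.det(V @ Ut)
    ones = torch.ones_like(d)
    sign = 1.0 - 2.0 * (d < 0).to(d.dtype)                  # proper rotation only
    R = V @ torch.diag_embed(torch.stack([ones, ones, sign], -1)) @ Ut
    diff = P @ R.transpose(-2, -1) - Q
    return diff.pow(2).sum(-1).mean(-1).sqrt()              # (...,)


torch.manual_seed(0)
mobile = torch.randn(20, 3, dtype=torch.float64)
rot, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))
rot = rot * torch.linalg.det(rot)          # force det = +1 (proper rotation)
target = mobile @ rot.T + 10.0             # same shape, different pose

naive = (mobile - target).pow(2).sum(-1).mean().sqrt()
print(naive.item() > 1.0, kabsch_rmsd_sketch(mobile, target).item() < 1e-8)  # True True
```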

pallatom.helpers.alignment.kabsch_rotation(P: jaxtyping.Float[torch.Tensor, ... N 3], Q: jaxtyping.Float[torch.Tensor, ... N 3], weights: Optional[jaxtyping.Float[torch.Tensor, ... N]] = None) → jaxtyping.Float[torch.Tensor, ... 3 3]

Compute the optimal rotation matrix R that minimises the RMSD between P and Q after centring, using singular value decomposition (SVD).

The algorithm finds R such that:

||W^(1/2) (P @ R^T − Q)||_F is minimised, where W = diag(weights).
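For reference, the standard closed-form minimiser of this objective (Kabsch, 1976) is written out below in the same row-vector convention. Expanding the squared norm shows that only the cross term depends on R, so the problem reduces to maximising a trace. Σ holds the singular values in descending order, so d multiplies the column for the smallest one. Whether pallatom's code uses exactly these symbol names is an assumption.

```latex
\begin{aligned}
\|W^{1/2}(P R^\top - Q)\|_F^2
  &= \operatorname{tr}(P^\top W P) + \operatorname{tr}(Q^\top W Q) - 2\,\operatorname{tr}(R H),
  \qquad H = P^\top W Q = U \Sigma V^\top, \\
R^\star &= \arg\max_{R \in SO(3)} \operatorname{tr}(R H)
  = V \operatorname{diag}(1, 1, d)\, U^\top,
  \qquad d = \operatorname{sign}\det(V U^\top).
\end{aligned}
```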

Parameters:
  • P – (…, N, 3) — mobile structure (will be rotated).

  • Q – (…, N, 3) — target/reference structure.

  • weights – (…, N) — per-residue weights (optional, non-negative). Useful for masking missing residues or down-weighting flexible loops.

Returns:

R – (…, 3, 3): rotation matrix with det = +1 (a proper rotation, no reflection).

Return type:

jaxtyping.Float[torch.Tensor, ... 3 3]

Notes

  • Inputs must already be centred (mean-subtracted) when this function is called directly. Use kabsch_align for the full pipeline (centre → rotate → translate).

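  • Handles reflections: U and V are the singular vectors of the 3 × 3 covariance matrix. When det(V @ U^T) < 0, the unconstrained SVD solution would be a reflection, so the sign of the column of V corresponding to the smallest singular value is flipped, guaranteeing det(R) = +1.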
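Both notes can be seen in a small experiment: the sketch below centres a random point cloud, mirrors it in z, and compares the unconstrained SVD solution V @ U^T with the sign-corrected one. It is a standalone illustration in plain PyTorch, not a call into pallatom. The last column is the one flipped because torch.linalg.svd returns singular values in descending order.

```python
import torch

torch.manual_seed(0)
P = torch.randn(10, 3, dtype=torch.float64)
P = P - P.mean(dim=0, keepdim=True)          # centre first, as required above
mirror = torch.diag(torch.tensor([1.0, 1.0, -1.0], dtype=torch.float64))
Q = P @ mirror                               # mirror image of P: no proper rotation maps P onto Q

U, S, Vh = torch.linalg.svd(P.T @ Q)         # S is sorted in descending order
V = Vh.T
naive = V @ U.T                              # unconstrained optimum over all orthogonal matrices
print(round(torch.linalg.det(naive).item()))  # -1 (a reflection)

d = torch.linalg.det(V @ U.T)
sign = 1.0 - 2.0 * (d < 0).to(d.dtype)
D = torch.diag(torch.stack([torch.ones_like(d), torch.ones_like(d), sign]))
R = V @ D @ U.T                              # flip V's column for the smallest singular value
print(round(torch.linalg.det(R).item()))     # 1 (a proper rotation)
```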

pallatom.helpers.alignment.rmsd(P: jaxtyping.Float[torch.Tensor, ... N 3], Q: jaxtyping.Float[torch.Tensor, ... N 3], weights: Optional[jaxtyping.Float[torch.Tensor, ... N]] = None, mask: Optional[jaxtyping.Bool[torch.Tensor, ... N]] = None) → torch.Tensor

Root-mean-square deviation between two (already aligned) coordinate sets.

Parameters:
  • P – (…, N, 3) — predicted / mobile coordinates.

  • Q – (…, N, 3) — reference coordinates.

  • weights – (…, N) — per-residue weights (optional).

  • mask – (…, N) — boolean mask; True = include residue (optional). Applied on top of weights.

Returns:

(…,) tensor of RMSD values in the same units as the input coordinates (typically Ångströms for protein structures).
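One common way to combine weights and mask is sqrt(Σ w·m·d² / Σ w·m), with the mask multiplied into the weights. The sketch below assumes that formula; rmsd_sketch is a hypothetical name, and pallatom's exact normalisation is not confirmed here.

```python
import torch


def rmsd_sketch(P, Q, weights=None, mask=None):
    """Hypothetical weighted, masked RMSD over the N axis; returns shape (...,)."""
    sq = ((P - Q) ** 2).sum(dim=-1)                       # (..., N) squared distances
    w = torch.ones_like(sq) if weights is None else weights.to(sq.dtype)
    if mask is not None:
        w = w * mask.to(sq.dtype)                         # mask applied on top of weights
    return torch.sqrt((w * sq).sum(dim=-1) / w.sum(dim=-1))


P = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 5.0, 5.0]])
Q = torch.tensor([[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0]])
mask = torch.tensor([True, True, False])                  # exclude the third residue
value = rmsd_sketch(P, Q, mask=mask).item()
print(value)  # ≈ 1.4142, i.e. sqrt((0 + 4) / 2)
```

Without the mask, the distant third residue (squared distance 75) would dominate the result.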