train_loop
==========

.. py:module:: train_loop


Attributes
----------

.. autoapisummary::

   train_loop.log
   train_loop.parser


Functions
---------

.. autoapisummary::

   train_loop.evaluate
   train_loop.evaluate_ddp
   train_loop.train
   train_loop.train_ddp


Module Contents
---------------

.. py:function:: evaluate(model: architecture.main_trunk.MainTrunk, loader, tcfg: train.train_config.TrainConfig, distogram_res: helpers.featurize.Distogram, distogram_atom: helpers.featurize.Distogram, device: str) -> dict[str, float]

   Full-dataset evaluation pass. Returns mean loss per metric.


.. py:function:: evaluate_ddp(rank: int, world_size: int, model: torch.nn.Module, loader, tcfg: train.train_config.TrainConfig, distogram_res: helpers.featurize.Distogram, distogram_atom: helpers.featurize.Distogram, device: str) -> dict[str, float]

   Distributed evaluation. Each rank processes its shard; metrics are all-reduced.


.. py:function:: train(model: architecture.main_trunk.MainTrunk, tcfg: train.train_config.TrainConfig, train_loader: torch.utils.data.DataLoader, test_loader: torch.utils.data.DataLoader, distogram_res: helpers.featurize.Distogram, distogram_atom: helpers.featurize.Distogram, device: str) -> None

.. py:function:: train_ddp(rank: int, local_rank: int, world_size: int, model: architecture.main_trunk.MainTrunk, tcfg: train.train_config.TrainConfig, train_loader: torch.utils.data.DataLoader, test_loader: torch.utils.data.DataLoader, distogram_res: helpers.featurize.Distogram, distogram_atom: helpers.featurize.Distogram, device: str | None = None) -> None

   DDP training loop. Launched via torchrun; one process per GPU.


.. py:data:: log

.. py:data:: parser
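

The descriptions of ``evaluate`` ("mean loss per metric") and ``evaluate_ddp`` ("each rank processes its shard; metrics are all-reduced") can be illustrated with a small pure-Python sketch. It is not the module's implementation: the helper names ``local_totals`` and ``reduce_mean`` are hypothetical, the per-batch weighting is an assumption, and a plain loop over ranks stands in for a SUM all-reduce. It shows one common way to obtain a global mean when shards differ in size, namely reducing sums and counts rather than per-rank means.

```python
from collections import defaultdict


def local_totals(batch_metrics):
    """Sum each metric over one rank's batches and count the batches.

    Hypothetical helper: `batch_metrics` is a list of {metric_name: value}
    dicts, one per batch seen by this rank.
    """
    sums = defaultdict(float)
    for metrics in batch_metrics:
        for name, value in metrics.items():
            sums[name] += value
    return dict(sums), len(batch_metrics)


def reduce_mean(per_rank_totals):
    """Stand-in for a SUM all-reduce followed by a division.

    Adds the per-rank sums and batch counts, then divides once, so ranks
    with more batches contribute proportionally more to the mean.
    """
    total_sums = defaultdict(float)
    total_count = 0
    for sums, count in per_rank_totals:
        total_count += count
        for name, value in sums.items():
            total_sums[name] += value
    if total_count == 0:
        return {}  # nothing was evaluated on any rank
    return {name: s / total_count for name, s in total_sums.items()}


# Two simulated ranks with unevenly sized shards (3 batches vs. 1 batch).
rank0 = local_totals([{"loss": 1.0}, {"loss": 3.0}, {"loss": 5.0}])
rank1 = local_totals([{"loss": 9.0}])
means = reduce_mean([rank0, rank1])
print(means)  # {'loss': 4.5}
```

Averaging the two per-rank means instead would give (3.0 + 9.0) / 2 = 6.0, which differs from the mean over all four batches (4.5). Whether ``evaluate_ddp`` weights by batch, by sample, or in some other way is not stated in this reference.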