train_loop
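Attributes
- log
- parser

Functions
- evaluate(model, loader, tcfg, distogram_res, distogram_atom, device): Full-dataset evaluation pass. Returns mean loss per metric.
- evaluate_ddp(rank, world_size, model, loader, tcfg, distogram_res, distogram_atom, device): Distributed evaluation. Each rank processes its shard; metrics are all-reduced.
- train(model, tcfg, train_loader, test_loader, distogram_res, distogram_atom, device)
- train_ddp(rank, local_rank, world_size, model, tcfg, train_loader, test_loader, distogram_res, distogram_atom[, device]): DDP training loop. Launched via torchrun, one process per GPU.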
Module Contents
- train_loop.evaluate(model: architecture.main_trunk.MainTrunk, loader, tcfg: train.train_config.TrainConfig, distogram_res: helpers.featurize.Distogram, distogram_atom: helpers.featurize.Distogram, device: str) -> dict[str, float]
Full-dataset evaluation pass: evaluates model on every batch in loader and returns the mean loss for each metric, keyed by metric name.
- train_loop.evaluate_ddp(rank: int, world_size: int, model: torch.nn.Module, loader, tcfg: train.train_config.TrainConfig, distogram_res: helpers.featurize.Distogram, distogram_atom: helpers.featurize.Distogram, device: str) -> dict[str, float]
Distributed evaluation. Each rank evaluates its own shard of the data, and the per-metric results are all-reduced across ranks, so every rank returns the same aggregated dict.
- train_loop.train(model: architecture.main_trunk.MainTrunk, tcfg: train.train_config.TrainConfig, train_loader: torch.utils.data.DataLoader, test_loader: torch.utils.data.DataLoader, distogram_res: helpers.featurize.Distogram, distogram_atom: helpers.featurize.Distogram, device: str) -> None
- train_loop.train_ddp(rank: int, local_rank: int, world_size: int, model: architecture.main_trunk.MainTrunk, tcfg: train.train_config.TrainConfig, train_loader: torch.utils.data.DataLoader, test_loader: torch.utils.data.DataLoader, distogram_res: helpers.featurize.Distogram, distogram_atom: helpers.featurize.Distogram, device: str | None = None) -> None
DDP (DistributedDataParallel) training loop. Launched via torchrun, with one process per GPU; rank is the global process index and local_rank is the process index within its node.
- train_loop.log
- train_loop.parser
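The "mean loss per metric" that evaluate returns can be illustrated with a small, framework-agnostic sketch of the aggregation step. It assumes that each batch yields a dict of per-batch mean losses plus the batch size, and that batches are weighted by size so a short final batch is not over-counted. The function name mean_per_metric and the metric names used below are hypothetical; the real evaluate may simply average over batches instead.

```python
from collections import defaultdict
from typing import Iterable, Mapping


def mean_per_metric(
    batch_metrics: Iterable[tuple[Mapping[str, float], int]],
) -> dict[str, float]:
    """Hypothetical helper: dataset-level mean of each metric.

    Each element is (metrics, n): `metrics` maps a metric name to that
    batch's mean loss, and `n` is the number of examples in the batch.
    """
    sums: dict[str, float] = defaultdict(float)
    counts: dict[str, int] = defaultdict(int)
    for metrics, n in batch_metrics:
        for name, value in metrics.items():
            # Undo the per-batch mean so batches are weighted by size.
            sums[name] += value * n
            counts[name] += n
    # Skip metrics that never saw any examples to avoid dividing by zero.
    return {name: sums[name] / counts[name] for name in sums if counts[name] > 0}
```

For example, batches ({"res_loss": 2.0}, 4) and ({"res_loss": 1.0}, 2) give a dataset mean of 10 / 6, not the unweighted 1.5.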
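The all-reduce described for evaluate_ddp is easiest to get right by reducing per-metric sums and example counts rather than per-rank means, because shards can differ in size. The sketch below simulates that reduction in plain Python so it runs without GPUs; with torch.distributed the same idea is to pack the values into one tensor and call torch.distributed.all_reduce(t, op=torch.distributed.ReduceOp.SUM) on every rank. Whether evaluate_ddp reduces sums and counts or averages means is not stated in the source, and the helper names here are hypothetical.

```python
import math


def all_reduce_sum(per_rank: list[list[float]]) -> list[float]:
    """Stand-in for an all-reduce with a SUM op: element-wise sum across ranks."""
    return [sum(vals) for vals in zip(*per_rank)]


def global_means(
    per_rank_stats: list[dict[str, tuple[float, int]]],
    metric_names: list[str],
) -> dict[str, float]:
    """Hypothetical helper: combine each rank's (loss_sum, count) per metric."""
    # Pack (sum, count) pairs into one flat vector per rank so a single
    # reduction covers every metric; all ranks must use the same order.
    vectors = []
    for stats in per_rank_stats:
        vec: list[float] = []
        for name in metric_names:
            s, n = stats.get(name, (0.0, 0))
            vec.extend([s, float(n)])
        vectors.append(vec)

    reduced = all_reduce_sum(vectors)  # every rank would receive this result

    out: dict[str, float] = {}
    for i, name in enumerate(metric_names):
        total, count = reduced[2 * i], reduced[2 * i + 1]
        out[name] = total / count if count > 0 else math.nan
    return out
```

With rank 0 holding a loss sum of 6.0 over 3 examples and rank 1 a sum of 1.0 over 1 example, the global mean is 7 / 4 = 1.75, whereas averaging the two per-rank means would give 1.5.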
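Under torchrun (for example torchrun --nproc_per_node=4 <script>), each worker process finds its identity in the RANK, LOCAL_RANK and WORLD_SIZE environment variables, which map naturally onto the rank, local_rank and world_size arguments of train_ddp. The sketch below reads them with single-process defaults. The names DistEnv, read_torchrun_env and default_device are hypothetical, and the "cuda:<local_rank>" fallback is only an assumption about what device=None might resolve to; the source does not say.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class DistEnv:
    rank: int        # global process index across all nodes
    local_rank: int  # process index within this node; usually picks the GPU
    world_size: int  # total number of processes in the job


def read_torchrun_env(env: Optional[Mapping[str, str]] = None) -> DistEnv:
    """Read the variables torchrun sets; default to a single local process."""
    env = os.environ if env is None else env
    return DistEnv(
        rank=int(env.get("RANK", "0")),
        local_rank=int(env.get("LOCAL_RANK", "0")),
        world_size=int(env.get("WORLD_SIZE", "1")),
    )


def default_device(local_rank: int) -> str:
    # Assumption: one GPU per process, selected by the node-local rank.
    return f"cuda:{local_rank}"
```

On the second GPU of the second node in a 2 x 4 job, this would yield DistEnv(rank=5, local_rank=1, world_size=8) and the device string "cuda:1".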