train_loop

Attributes

Functions

evaluate(→ dict[str, float])

Full-dataset evaluation pass. Returns mean loss per metric.

evaluate_ddp(→ dict[str, float])

Distributed evaluation. Each rank processes its shard; metrics are all-reduced.

train(→ None)

train_ddp(→ None)

DDP training loop. Launched via torchrun, one process per GPU.

Module Contents

train_loop.evaluate(model: architecture.main_trunk.MainTrunk, loader, tcfg: train.train_config.TrainConfig, distogram_res: helpers.featurize.Distogram, distogram_atom: helpers.featurize.Distogram, device: str) → dict[str, float]

Full-dataset evaluation pass. Returns mean loss per metric.
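
"Mean loss per metric" can be read in more than one way. The sketch below shows one plausible reading: a sample-weighted average of each named metric across batches. The helper name `mean_per_metric` and the `(metrics, batch_size)` input shape are hypothetical, chosen for illustration; the source does not say whether `evaluate` weights by batch size.

```python
from collections import defaultdict


def mean_per_metric(batch_metrics):
    """Average each named metric over all samples seen.

    batch_metrics: iterable of (metrics_dict, batch_size) pairs, where
    metrics_dict maps a metric name to that batch's mean value.
    (Hypothetical input shape, not taken from train_loop.)
    """
    totals = defaultdict(float)
    n_samples = 0
    for metrics, batch_size in batch_metrics:
        for name, value in metrics.items():
            # Undo the per-batch mean so that small batches do not
            # count as much as full ones.
            totals[name] += value * batch_size
        n_samples += batch_size
    if n_samples == 0:
        return {}
    return {name: total / n_samples for name, total in totals.items()}
```

Weighting by batch size matters mainly when the last batch is smaller than the rest; with equal-sized batches it reduces to a plain mean of the per-batch values.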

train_loop.evaluate_ddp(rank: int, world_size: int, model: torch.nn.Module, loader, tcfg: train.train_config.TrainConfig, distogram_res: helpers.featurize.Distogram, distogram_atom: helpers.featurize.Distogram, device: str) → dict[str, float]

Distributed evaluation. Each rank processes its shard; metrics are all-reduced.
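
The shard-then-all-reduce pattern can be illustrated without any GPUs. The sketch below simulates a SUM all-reduce in plain Python: each rank contributes a loss sum and a sample count, and every rank ends up with the same global mean. The function names are hypothetical, and the source does not state which reduction `evaluate_ddp` uses; in a real run, the simulated step would typically correspond to `torch.distributed.all_reduce` with `ReduceOp.SUM`.

```python
def all_reduce_sum(per_rank_values):
    """Simulate a SUM all-reduce.

    per_rank_values: one list of numbers per rank, all the same length.
    Returns one list per rank, each holding the element-wise sum.
    """
    summed = [sum(column) for column in zip(*per_rank_values)]
    return [list(summed) for _ in per_rank_values]


def global_mean(per_rank_loss_sum, per_rank_count):
    """Compute the mean loss over all shards, as seen by each rank."""
    # Reduce sums and counts together, then divide. Averaging the
    # per-rank means instead would be wrong for unequal shard sizes.
    payload = [[s, c] for s, c in zip(per_rank_loss_sum, per_rank_count)]
    reduced = all_reduce_sum(payload)
    return [total / count for total, count in reduced]
```

For example, with shard sums `[2.0, 10.0]` and counts `[2, 4]`, every rank obtains 12.0 / 6 = 2.0, whereas a naive mean of the per-rank means (1.0 and 2.5) would give 1.75.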

train_loop.train(model: architecture.main_trunk.MainTrunk, tcfg: train.train_config.TrainConfig, train_loader: torch.utils.data.DataLoader, test_loader: torch.utils.data.DataLoader, distogram_res: helpers.featurize.Distogram, distogram_atom: helpers.featurize.Distogram, device: str) → None

train_loop.train_ddp(rank: int, local_rank: int, world_size: int, model: architecture.main_trunk.MainTrunk, tcfg: train.train_config.TrainConfig, train_loader: torch.utils.data.DataLoader, test_loader: torch.utils.data.DataLoader, distogram_res: helpers.featurize.Distogram, distogram_atom: helpers.featurize.Distogram, device: str | None = None) → None

DDP training loop. Launched via torchrun, one process per GPU.
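
torchrun passes each worker its position through the `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` environment variables. As a minimal sketch, the values for the `rank`, `local_rank`, and `world_size` arguments could be read as shown below. The helper `read_torchrun_env` is hypothetical, and deriving the device string from `local_rank` is an assumption about how a `device=None` default might be resolved, not something the source confirms.

```python
import os


def read_torchrun_env(env=None):
    """Read the process layout that torchrun exports to each worker.

    env: mapping to read from; defaults to os.environ. Passing a plain
    dict makes the helper easy to exercise without torchrun.
    """
    env = os.environ if env is None else env
    rank = int(env.get("RANK", 0))              # global index of this process
    local_rank = int(env.get("LOCAL_RANK", 0))  # index within this node
    world_size = int(env.get("WORLD_SIZE", 1))  # total number of processes
    # Assumption: one GPU per process, selected by local rank.
    device = f"cuda:{local_rank}"
    return rank, local_rank, world_size, device
```

The fallbacks of 0, 0, and 1 let the same entry point run as a single process when it is started without torchrun.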

train_loop.log
train_loop.parser