
feat: Add training loop skeleton with logging hooks and checkpoint save/load #24

Open

KrishanYadav333 wants to merge 1 commit into ML4SCI:main from KrishanYadav333:feat/training-loop

Conversation

@KrishanYadav333

Part of pre-GSoC groundwork for the EXXA DDPM denoising pipeline.

Adds a minimal, model-agnostic Trainer class that drives the training loop.

Changes

  • src/training/trainer.py — Trainer class with:
    • train_one_epoch() — dataloader iteration, forward pass, loss, optimizer step
    • log_fn hook — called with (epoch, step, loss) after each step
    • save_checkpoint() / load_checkpoint() — full state dict round-trip
  • src/training/__init__.py — exports Trainer
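
The pieces listed above might fit together roughly like this. This is an illustrative sketch, not the actual file: the method names and the `log_fn(epoch, step, loss)` signature come from the PR description, while the constructor arguments, checkpoint keys, and loop internals are assumptions.

```python
import torch
from torch import nn


class Trainer:
    """Minimal, model-agnostic training-loop driver (sketch).

    The model can be any nn.Module exposing training_loss(batch) -> Tensor.
    """

    def __init__(self, model, optimizer, device="cpu", log_fn=None):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.device = device
        self.log_fn = log_fn  # called as log_fn(epoch, step, loss) after each step
        self.epoch = 0

    def train_one_epoch(self, dataloader):
        """Iterate the dataloader once: forward pass, loss, optimizer step."""
        self.model.train()
        total, num_steps = 0.0, 0
        for step, batch in enumerate(dataloader):
            batch = batch.to(self.device)
            loss = self.model.training_loss(batch)  # model owns the forward/loss
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            total += loss.item()
            num_steps += 1
            if self.log_fn is not None:
                self.log_fn(self.epoch, step, loss.item())
        self.epoch += 1
        return total / max(num_steps, 1)

    def save_checkpoint(self, path):
        # Full state-dict round-trip: epoch counter, model, and optimizer.
        torch.save({
            "epoch": self.epoch,
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }, path)

    def load_checkpoint(self, path):
        ckpt = torch.load(path, map_location=self.device)
        self.epoch = ckpt["epoch"]
        self.model.load_state_dict(ckpt["model"])
        self.optimizer.load_state_dict(ckpt["optimizer"])
```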

Design

Trainer is model-agnostic — it expects any nn.Module with a training_loss(batch) -> Tensor method. This means it works today with the toy model used in the tests and will plug directly into the DDPM model once that is implemented.
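
A minimal sketch of that contract, using a hypothetical toy model (the class name and internals are illustrative; only the `training_loss(batch) -> Tensor` interface comes from this PR):

```python
import torch
from torch import nn


class ToyDenoiser(nn.Module):
    """Hypothetical model satisfying the Trainer contract:
    an nn.Module with training_loss(batch) -> scalar Tensor."""

    def __init__(self, dim=8):
        super().__init__()
        self.net = nn.Linear(dim, dim)

    def training_loss(self, batch):
        # Toy objective: reconstruct the batch with MSE. A DDPM model
        # would instead return its noise-prediction loss here, and the
        # Trainer would not need to change at all.
        return ((self.net(batch) - batch) ** 2).mean()
```

Because the loss lives on the model, the Trainer never needs to know whether it is driving a toy regressor or a diffusion model.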

Tests

18 tests in tests/test_trainer.py covering:

  • Instantiation and device placement
  • train_one_epoch return type, loss positivity, epoch counter
  • Loss descent over multiple epochs on a toy synthetic batch
  • Logging hook call count and argument types
  • Checkpoint save/load restoring epoch, model weights, and inference consistency

All 18 pass.
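
For a flavor of the checkpoint coverage, the round-trip check might look roughly like the sketch below. This is not the actual test file — it exercises the same save/perturb/reload pattern directly with torch.save on an in-memory buffer, with checkpoint keys that are assumptions:

```python
import io
import torch
from torch import nn


def test_checkpoint_round_trip_sketch():
    # Save a full state dict, perturb the model to simulate further
    # training, reload, and check that epoch and weights are restored.
    model = nn.Linear(4, 2)
    opt = torch.optim.Adam(model.parameters())
    buf = io.BytesIO()
    torch.save({"epoch": 3,
                "model": model.state_dict(),
                "optimizer": opt.state_dict()}, buf)
    original = model.weight.detach().clone()

    with torch.no_grad():
        model.weight.add_(1.0)  # weights now differ from the checkpoint

    buf.seek(0)
    ckpt = torch.load(buf)
    model.load_state_dict(ckpt["model"])

    assert ckpt["epoch"] == 3
    assert torch.equal(model.weight, original)
```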
