Add native TEST phase to TorchTNT framework #1053

Open
zedsdead01 wants to merge 1 commit into meta-pytorch:master from
zedsdead01:export-D96203908

@zedsdead01
Contributor

Summary:
Adds a native TEST phase to TorchTNT as a first-class citizen, following the existing patterns for TRAIN, EVALUATE, and PREDICT phases. This eliminates the need for workarounds that repurpose PREDICT for testing (e.g., Vizard's D88975254).

The TEST phase is semantically distinct from PREDICT, which only runs forward passes: TEST evaluates held-out test data and computes loss and metrics, similar to EVALUATE.

Changes made:

State and Enums:

  • Added EntryPoint.TEST and ActivePhase.TEST to state.py
  • Added Phase.TEST to utils/checkpoint.py
  • Added test_state property to State class
  • Updated active_phase_state() and into_phase() to handle TEST
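
As a rough illustration of the state changes listed above, the sketch below shows how a TEST member and a `test_state` property might sit beside the existing phases. It is not the code from `state.py` or `utils/checkpoint.py`: `PhaseState` is a stripped-down stand-in, and the enum values, the `State` constructor signature, and the error type are assumptions made for the example.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional


class Phase(Enum):
    """Stand-in for the checkpoint-level phase enum (values are assumed)."""
    TRAIN = "train"
    EVALUATE = "evaluate"
    PREDICT = "predict"
    TEST = "test"  # member added by this PR


class ActivePhase(Enum):
    """Stand-in for the enum that tracks which loop is currently running."""
    TRAIN = auto()
    EVALUATE = auto()
    PREDICT = auto()
    TEST = auto()  # member added by this PR

    def into_phase(self) -> Phase:
        # Map the active phase onto the checkpoint phase of the same name.
        return Phase[self.name]


@dataclass
class PhaseState:
    """Hypothetical, minimal per-phase bookkeeping."""
    max_steps_per_epoch: Optional[int] = None


class State:
    """Hypothetical State holding only the pieces needed for TEST."""

    def __init__(
        self, active_phase: ActivePhase, test_state: Optional[PhaseState] = None
    ) -> None:
        self.active_phase = active_phase
        self._test_state = test_state

    @property
    def test_state(self) -> Optional[PhaseState]:
        return self._test_state

    def active_phase_state(self) -> PhaseState:
        # Only the TEST branch is sketched; the other phases are omitted.
        if self.active_phase is ActivePhase.TEST and self._test_state is not None:
            return self._test_state
        raise ValueError(f"No phase state available for {self.active_phase}")
```

With this shape, code that already calls `active_phase_state()` or `into_phase()` for the other phases would need no special casing for TEST.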

TestUnit Interface (unit.py):

  • Added TestUnit ABC with test_step, lifecycle hooks (on_test_start, on_test_epoch_start, on_test_epoch_end, on_test_end), and get_next_test_batch
  • Added TTestData TypeVar and TTestUnit type alias

Callback Hooks (callback.py, _callback_handler.py):

  • Added 10 on_test_* callback hooks mirroring the eval/predict patterns
  • Added corresponding dispatcher methods in CallbackHandler

Test Loop (test.py - new file):

  • Created test() entry point and _test_impl() loop, modeled after predict.py
  • Uses torch.no_grad() context (like evaluate, not inference_mode like predict)
  • Handles module eval mode, callback dispatch, progress tracking, and exception handling
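
To make the loop structure above concrete, here is a simplified, framework-free sketch of one test epoch. It is not `_test_impl()`: callbacks, progress tracking, module eval mode, and exception handling are left out, and `RecordingUnit`, `run_test_epoch`, and the `grad_context` parameter are hypothetical names. The real loop is described as running under `torch.no_grad()`; a no-op context stands in for it here so the example has no dependencies.

```python
from contextlib import nullcontext
from typing import Any, Callable, ContextManager, Iterable, List, Optional


class RecordingUnit:
    """Hypothetical unit that records the order in which it is called."""

    def __init__(self) -> None:
        self.events: List[str] = []

    def on_test_start(self) -> None:
        self.events.append("start")

    def on_test_epoch_start(self) -> None:
        self.events.append("epoch_start")

    def test_step(self, batch: Any) -> None:
        self.events.append(f"step:{batch}")

    def on_test_epoch_end(self) -> None:
        self.events.append("epoch_end")

    def on_test_end(self) -> None:
        self.events.append("end")


def run_test_epoch(
    unit: RecordingUnit,
    dataloader: Iterable[Any],
    max_steps_per_epoch: Optional[int] = None,
    grad_context: Callable[[], ContextManager[Any]] = nullcontext,
) -> int:
    """Run a single pass over the test data and return the number of steps."""
    steps = 0
    # With PyTorch available, grad_context would be torch.no_grad.
    with grad_context():
        unit.on_test_start()
        unit.on_test_epoch_start()
        for batch in dataloader:
            # Stop early once the optional step budget is used up.
            if max_steps_per_epoch is not None and steps >= max_steps_per_epoch:
                break
            unit.test_step(batch)
            steps += 1
        unit.on_test_epoch_end()
        unit.on_test_end()
    return steps
```

Calling `run_test_epoch(RecordingUnit(), [10, 20, 30], max_steps_per_epoch=2)` processes two batches and still fires the epoch-end and end hooks.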

AutoUnit Integration (auto_unit.py):

  • Added TestUnit[TData] to AutoUnit's inheritance chain
  • Added test_step() (calls compute_loss, like eval_step)
  • Added on_test_step_end() hook
  • Added get_next_test_batch() with prefetch support
  • Updated prefetch dicts to include ActivePhase.TEST

fit() Integration (fit.py):

  • Added optional test_dataloader and max_test_steps_per_epoch parameters
  • Test phase runs after training completes (not interleaved)
  • Added type check for TestUnit when test_dataloader is provided
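
The ordering and type check described above can be pictured with the toy function below. It is only a sketch of the control flow, not the `fit()` implementation: `fit_sketch`, the placeholder unit classes, the returned log, and the choice of `TypeError` are all assumptions, and evaluation interleaving is omitted.

```python
from typing import Any, Iterable, List, Optional


class TestUnit:
    """Placeholder marker class standing in for the real TestUnit."""


class TrainOnlyUnit:
    """Hypothetical unit that does not implement the test interface."""


class FitTestUnit(TestUnit):
    """Hypothetical unit that supports both training and testing."""


def fit_sketch(
    unit: Any,
    train_dataloader: Iterable[Any],
    max_epochs: int,
    test_dataloader: Optional[Iterable[Any]] = None,
    max_test_steps_per_epoch: Optional[int] = None,
) -> List[str]:
    """Return a log showing the order in which phases would run."""
    # Reject a test dataloader when the unit cannot run the test phase.
    if test_dataloader is not None and not isinstance(unit, TestUnit):
        raise TypeError("test_dataloader was provided but unit is not a TestUnit")

    log: List[str] = []
    for epoch in range(max_epochs):
        for _ in train_dataloader:
            log.append(f"train:{epoch}")

    # The test phase runs once, only after every training epoch has finished.
    if test_dataloader is not None:
        for step, _ in enumerate(test_dataloader):
            if max_test_steps_per_epoch is not None and step >= max_test_steps_per_epoch:
                break
            log.append("test")
    return log
```

In this sketch, every `test` entry in the log comes after the last `train` entry, matching the "not interleaved" behaviour described in the list.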

Exports and Test Utilities:

  • Updated __init__.py to export test, TestUnit, TTestUnit
  • Added DummyTestUnit, DummyFitTestUnit, and get_dummy_test_state to _test_utils.py

Build System:

  • Added test library target to framework/BUCK
  • Updated framework and fit targets to depend on test
  • Added test_test test target to tests/framework/BUCK

Differential Revision: D96203908

@meta-cla meta-cla bot added the cla signed label Mar 17, 2026

meta-codesync bot commented Mar 17, 2026

@zedsdead01 has exported this pull request. If you are a Meta employee, you can view the originating Diff in D96203908.
