Add native TEST phase to TorchTNT framework #1053
Open
zedsdead01 wants to merge 1 commit into meta-pytorch:master from
Summary:

Adds a native TEST phase to TorchTNT as a first-class citizen, following the existing patterns for TRAIN, EVALUATE, and PREDICT phases. This eliminates the need for workarounds that repurpose PREDICT for testing (e.g., Vizard's D88975254).

The TEST phase is semantically distinct from PREDICT, which only runs forward passes: it runs evaluation on held-out test data with loss/metrics computation, similar to EVALUATE.

Changes made:

**State and Enums:**
- Added `EntryPoint.TEST` and `ActivePhase.TEST` to `state.py`
- Added `Phase.TEST` to `utils/checkpoint.py`
- Added `test_state` property to `State` class
- Updated `active_phase_state()` and `into_phase()` to handle TEST

**TestUnit Interface (`unit.py`):**
- Added `TestUnit` ABC with `test_step`, lifecycle hooks (`on_test_start`, `on_test_epoch_start`, `on_test_epoch_end`, `on_test_end`), and `get_next_test_batch`
- Added `TTestData` TypeVar and `TTestUnit` type alias

**Callback Hooks (`callback.py`, `_callback_handler.py`):**
- Added 10 `on_test_*` callback hooks mirroring the eval/predict patterns
- Added corresponding dispatcher methods in `CallbackHandler`

**Test Loop (`test.py` - new file):**
- Created `test()` entry point and `_test_impl()` loop, modeled after `predict.py`
- Uses `torch.no_grad()` context (like evaluate, not `inference_mode` like predict)
- Handles module eval mode, callback dispatch, progress tracking, and exception handling

**AutoUnit Integration (`auto_unit.py`):**
- Added `TestUnit[TData]` to `AutoUnit`'s inheritance chain
- Added `test_step()` (calls `compute_loss`, like `eval_step`)
- Added `on_test_step_end()` hook
- Added `get_next_test_batch()` with prefetch support
- Updated prefetch dicts to include `ActivePhase.TEST`

**fit() Integration (`fit.py`):**
- Added optional `test_dataloader` and `max_test_steps_per_epoch` parameters
- Test phase runs after training completes (not interleaved)
- Added type check for `TestUnit` when `test_dataloader` is provided

**Exports and Test Utilities:**
- Updated `__init__.py` to export `test`, `TestUnit`, `TTestUnit`
- Added `DummyTestUnit`, `DummyFitTestUnit`, and `get_dummy_test_state` to `_test_utils.py`

**Build System:**
- Added `test` library target to `framework/BUCK`
- Updated `framework` and `fit` targets to depend on `test`
- Added `test_test` test target to `tests/framework/BUCK`

Differential Revision: D96203908
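For readers who have not seen the diff, the following is a rough, framework-free sketch of how the `TestUnit` hooks and the test loop described above might fit together. It is not the code from this PR: `SketchState`, `SketchTestUnit`, `sketch_test`, and `SquaredErrorUnit` are hypothetical stand-ins, the method signatures are assumed to mirror the existing eval/predict units, and the `torch.no_grad()` context, module eval mode, callback dispatch, and exception handling mentioned in the summary are deliberately left out so the example runs with the standard library alone.

```python
from abc import ABC, abstractmethod
from typing import Any, Generic, Iterable, Iterator, List, Optional, Tuple, TypeVar

TTestData = TypeVar("TTestData")


class SketchState:
    """Hypothetical stand-in for the framework State; tracks progress only."""

    def __init__(self) -> None:
        self.num_steps_completed = 0
        self.num_epochs_completed = 0


class SketchTestUnit(ABC, Generic[TTestData]):
    """Illustrative interface using the hook names listed in the summary.

    The signatures are assumptions, not the PR's actual definitions.
    """

    def on_test_start(self, state: SketchState) -> None:
        pass

    def on_test_epoch_start(self, state: SketchState) -> None:
        pass

    def on_test_epoch_end(self, state: SketchState) -> None:
        pass

    def on_test_end(self, state: SketchState) -> None:
        pass

    def get_next_test_batch(
        self, state: SketchState, data_iter: Iterator[TTestData]
    ) -> TTestData:
        # Default: pull the next batch; a prefetching unit could override this.
        return next(data_iter)

    @abstractmethod
    def test_step(self, state: SketchState, data: TTestData) -> Any:
        ...


def sketch_test(
    unit: SketchTestUnit[TTestData],
    dataloader: Iterable[TTestData],
    max_steps_per_epoch: Optional[int] = None,
) -> SketchState:
    """Run one pass over the test data, firing lifecycle hooks in order.

    A real loop would also disable gradients and put modules in eval mode.
    """
    state = SketchState()
    unit.on_test_start(state)
    unit.on_test_epoch_start(state)
    data_iter = iter(dataloader)
    while max_steps_per_epoch is None or state.num_steps_completed < max_steps_per_epoch:
        try:
            batch = unit.get_next_test_batch(state, data_iter)
        except StopIteration:
            break
        unit.test_step(state, batch)
        state.num_steps_completed += 1
    state.num_epochs_completed += 1
    unit.on_test_epoch_end(state)
    unit.on_test_end(state)
    return state


class SquaredErrorUnit(SketchTestUnit[Tuple[float, float]]):
    """Toy unit: each batch is a (prediction, target) pair; the loss is squared error."""

    def __init__(self) -> None:
        self.losses: List[float] = []
        self.events: List[str] = []

    def on_test_start(self, state: SketchState) -> None:
        self.events.append("start")

    def on_test_end(self, state: SketchState) -> None:
        self.events.append("end")

    def test_step(self, state: SketchState, data: Tuple[float, float]) -> float:
        prediction, target = data
        loss = (prediction - target) ** 2
        self.losses.append(loss)
        return loss


unit = SquaredErrorUnit()
final_state = sketch_test(
    unit, [(1.0, 1.0), (2.0, 0.0), (3.0, 3.0)], max_steps_per_epoch=2
)
mean_loss = sum(unit.losses) / len(unit.losses)
```

The point of the sketch is the contrast drawn in the summary: unlike a predict-style step, `test_step` here computes a loss per batch, and the loop stops either when the data is exhausted or when the step cap is reached.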
@zedsdead01 has exported this pull request. If you are a Meta employee, you can view the originating Diff in D96203908.