From bfa1fbf1e6cc9ee3ba52156e285e6d2014202b88 Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Fri, 24 Jan 2025 13:21:00 -0800 Subject: [PATCH] `lf.eval.v2.Experiment`: Rename 'NAME' to 'ID' for clarity. PiperOrigin-RevId: 719412721 --- langfun/core/eval/v2/experiment.py | 59 ++++++++++++------------- langfun/core/eval/v2/experiment_test.py | 11 +---- langfun/core/eval/v2/runners.py | 6 +-- 3 files changed, 33 insertions(+), 43 deletions(-) diff --git a/langfun/core/eval/v2/experiment.py b/langfun/core/eval/v2/experiment.py index 39639487..7bf1f5c6 100644 --- a/langfun/core/eval/v2/experiment.py +++ b/langfun/core/eval/v2/experiment.py @@ -20,7 +20,7 @@ import inspect import os import re -from typing import Annotated, Any, Callable, Literal, Optional +from typing import Annotated, Any, Callable, ClassVar, Literal, Optional, Type import langfun.core as lf from langfun.core.eval.v2 import example as example_lib @@ -111,21 +111,22 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension): # Experiment Registration and Lookup - Experiments can be registered by setting a class-level NAME attribute. - Users can then retrieve a registered experiment using Experiment.find(name). + Experiments can be registered by setting a class-level ID attribute + (e.g. a path-like string). Users can then retrieve a registered experiment + using `Experiment.find(id_or_regex)`. For example: ``` class MyEval(lf.eval.v2.Evaluation): - NAME = 'my_eval' + ID = 'my_eval' class MyEvalVariation1(MyEval): - NAME = 'my_eval/gemini' + ID = 'my_eval/gemini' lm = pg.oneof([lf.llms.GeminiPro(), lf.llms.GeminiFlash(), ...]) class MyEvalVariation2(MyEval): - NAME = 'my_eval/openai' + ID = 'my_eval/openai' lm = pg.oneof([lf.llms.Gpt4o(), lf.llms.Gpt4Turbo(), ...]) # Run all experiments with "gemini" in their name. @@ -179,14 +180,14 @@ class MyEvalVariation2(MyEval): # Class-level functionalities. 
# - # An global unique str as a well-known name for an experiment, - # which can be retrieved by `Experiment.find(name)`. If None, the experiment - # does not have a well-known name, thus users need to create the experiment - # by constructing it explicitly. - NAME = None + # A globally unique str as a well-known ID for an experiment, + # which can be retrieved by `Experiment.find(id_or_regex)`. + # If None, the experiment does not have a well-known ID, thus users need to + # create the experiment by constructing it explicitly. + ID: ClassVar[str | None] = None - # Global registry for experiment classes with GLOBAL_ID. - _NAME_TO_CLASS = {} + # Global registry for experiment classes with ID. + _ID_TO_CLASS: ClassVar[dict[str, Type['Experiment']]] = {} def __init_subclass__(cls): super().__init_subclass__() @@ -194,15 +195,15 @@ def __init_subclass__(cls): if inspect.isabstract(cls): return - if cls.NAME is not None: - cls._NAME_TO_CLASS[cls.NAME] = cls + if cls.ID is not None: + cls._ID_TO_CLASS[cls.ID] = cls @classmethod - def find(cls, pattern: str) -> 'Experiment': + def find(cls, id_or_regex: str) -> 'Experiment': """Finds an experiment by global name. Args: - pattern: A regular expression to match the global names of registered + id_or_regex: An ID or a regular expression to match the IDs of registered experiments. Returns: @@ -210,11 +211,11 @@ def find(cls, pattern: str) -> 'Experiment': `Suite` of matched experiments will be returned. If no experiment is found, an empty `Suite` will be returned. 
""" - if pattern in cls._NAME_TO_CLASS: - return cls._NAME_TO_CLASS[pattern]() - regex = re.compile(pattern) + if id_or_regex in cls._ID_TO_CLASS: + return cls._ID_TO_CLASS[id_or_regex]() + regex = re.compile(id_or_regex) selected = [] - for cls_name, exp_cls in cls._NAME_TO_CLASS.items(): + for cls_name, exp_cls in cls._ID_TO_CLASS.items(): if regex.match(cls_name): selected.append(exp_cls()) return selected[0] if len(selected) == 1 else Suite(selected) @@ -420,8 +421,8 @@ def run( create a new run based on the current time if no previous run exists. If `latest`, it will use the latest run ID under the root directory. If `new`, it will create a new run ID based on the current time. - runner: The runner to use. If None, it will use the default runner for - the experiment. + runner: The ID of the runner to use. If None, it will use the parallel + runner by default. warm_start_from: The ID of the previous run to warm start from. If None, it will continue the experiment identified by `id` from where it left off. Otherwise, it will create a new experiment run by warming start. @@ -941,8 +942,9 @@ def examples_to_load_metadata(self, experiment: Experiment) -> set[int]: class Runner(pg.Object): """Interface for experiment runner.""" - # Class-level variable for registering the runner. - NAME = None + # The ID for the runner, which will be referred to by the `runner` argument of + # the `Experiment.run()` method. + ID: ClassVar[str] _REGISTRY = {} @@ -960,12 +962,7 @@ def __init_subclass__(cls): super().__init_subclass__() if inspect.isabstract(cls): return - if cls.NAME is None: - raise ValueError( - 'Runner class must define a NAME constant. ' - 'Please use the same constant in the runner class.' 
- ) - cls._REGISTRY[cls.NAME] = cls + cls._REGISTRY[cls.ID] = cls @abc.abstractmethod def run(self) -> None: diff --git a/langfun/core/eval/v2/experiment_test.py b/langfun/core/eval/v2/experiment_test.py index 5075f4b8..440df3e1 100644 --- a/langfun/core/eval/v2/experiment_test.py +++ b/langfun/core/eval/v2/experiment_test.py @@ -38,7 +38,7 @@ def sample_inputs(num_examples: int = 1): class MyEvaluation(Evaluation): - NAME = 'my_eval' + ID = 'my_eval' RUN_ARGS = dict( runner='test' ) @@ -394,7 +394,7 @@ class RunnerTest(unittest.TestCase): def test_basic(self): class TestRunner(Runner): - NAME = 'test' + ID = 'test' def run(self): pass @@ -421,13 +421,6 @@ def run(self): root_dir=root_dir, id='20241101_1' ) - with self.assertRaisesRegex( - ValueError, 'Runner class must define a NAME constant' - ): - class AnotherRunner(Runner): # pylint: disable=unused-variable - def run(self): - pass - if __name__ == '__main__': unittest.main() diff --git a/langfun/core/eval/v2/runners.py b/langfun/core/eval/v2/runners.py index df2e034a..0d867636 100644 --- a/langfun/core/eval/v2/runners.py +++ b/langfun/core/eval/v2/runners.py @@ -378,7 +378,7 @@ class SequentialRunner(RunnerBase): exceptions thrown from the background tasks, making it easier to debug. """ - NAME = 'sequential' + ID = 'sequential' def background_run( self, func: Callable[..., Any], *args: Any, **kwargs: Any @@ -402,7 +402,7 @@ def _evaluate_items( class DebugRunner(SequentialRunner): """Debug runner.""" - NAME = 'debug' + ID = 'debug' # Do not use the checkpointer for debug runner. plugins = [] @@ -420,7 +420,7 @@ def _save_run_manifest(self) -> None: class ParallelRunner(RunnerBase): """Parallel runner.""" - NAME = 'parallel' + ID = 'parallel' timeout: Annotated[ int | None,