Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 28 additions & 31 deletions langfun/core/eval/v2/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import inspect
import os
import re
from typing import Annotated, Any, Callable, Literal, Optional
from typing import Annotated, Any, Callable, ClassVar, Literal, Optional, Type

import langfun.core as lf
from langfun.core.eval.v2 import example as example_lib
Expand Down Expand Up @@ -111,21 +111,22 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):

# Experiment Registration and Lookup

Experiments can be registered by setting a class-level NAME attribute.
Users can then retrieve a registered experiment using Experiment.find(name).
Experiments can be registered by setting a class-level ID attribute
(e.g. a path-like string). Users can then retrieve a registered experiment
using `Experiment.find(id_or_regex)`.

For example:

```
class MyEval(lf.eval.v2.Evaluation):
NAME = 'my_eval'
ID = 'my_eval'

class MyEvalVariation1(MyEval):
NAME = 'my_eval/gemini'
ID = 'my_eval/gemini'
lm = pg.oneof([lf.llms.GeminiPro(), lf.llms.GeminiFlash(), ...])

class MyEvalVariation2(MyEval):
NAME = 'my_eval/openai'
ID = 'my_eval/openai'
lm = pg.oneof([lf.llms.Gpt4o(), lf.llms.Gpt4Turbo(), ...])

# Run all experiments with "gemini" in their ID.
Expand Down Expand Up @@ -179,42 +180,42 @@ class MyEvalVariation2(MyEval):
# Class-level functionalities.
#

# An global unique str as a well-known name for an experiment,
# which can be retrieved by `Experiment.find(name)`. If None, the experiment
# does not have a well-known name, thus users need to create the experiment
# by constructing it explicitly.
NAME = None
# A globally unique string that serves as a well-known ID for an experiment,
# which can be retrieved by `Experiment.find(id_or_regex)`.
# If None, the experiment does not have a well-known ID, so users need to
# create the experiment by constructing it explicitly.
ID: ClassVar[str | None] = None

# Global registry for experiment classes with GLOBAL_ID.
_NAME_TO_CLASS = {}
# Global registry for experiment classes with ID.
_ID_TO_CLASS: ClassVar[dict[str, Type['Experiment']]] = {}

def __init_subclass__(cls):
super().__init_subclass__()

if inspect.isabstract(cls):
return

if cls.NAME is not None:
cls._NAME_TO_CLASS[cls.NAME] = cls
if cls.ID is not None:
cls._ID_TO_CLASS[cls.ID] = cls

@classmethod
def find(cls, pattern: str) -> 'Experiment':
def find(cls, id_or_regex: str) -> 'Experiment':
"""Finds an experiment by global name.

Args:
pattern: A regular expression to match the global names of registered
id_or_regex: An exact ID, or a regular expression to match the IDs of registered
experiments.

Returns:
An experiment object. If multiple experiments are found, a
`Suite` of matched experiments will be returned. If no experiment is
found, an empty `Suite` will be returned.
"""
if pattern in cls._NAME_TO_CLASS:
return cls._NAME_TO_CLASS[pattern]()
regex = re.compile(pattern)
if id_or_regex in cls._ID_TO_CLASS:
return cls._ID_TO_CLASS[id_or_regex]()
regex = re.compile(id_or_regex)
selected = []
for cls_name, exp_cls in cls._NAME_TO_CLASS.items():
for cls_name, exp_cls in cls._ID_TO_CLASS.items():
if regex.match(cls_name):
selected.append(exp_cls())
return selected[0] if len(selected) == 1 else Suite(selected)
Expand Down Expand Up @@ -420,8 +421,8 @@ def run(
create a new run based on the current time if no previous run exists.
If `latest`, it will use the latest run ID under the root directory.
If `new`, it will create a new run ID based on the current time.
runner: The runner to use. If None, it will use the default runner for
the experiment.
runner: The ID of the runner to use. If None, it will use the parallel
runner by default.
warm_start_from: The ID of the previous run to warm start from. If None,
it will continue the experiment identified by `id` from where it left
off. Otherwise, it will create a new experiment run by warm-starting from it.
Expand Down Expand Up @@ -941,8 +942,9 @@ def examples_to_load_metadata(self, experiment: Experiment) -> set[int]:
class Runner(pg.Object):
"""Interface for experiment runner."""

# Class-level variable for registering the runner.
NAME = None
# The ID for the runner, which is referred to by the `runner` argument of the
# `Experiment.run()` method.
ID: ClassVar[str]

_REGISTRY = {}

Expand All @@ -960,12 +962,7 @@ def __init_subclass__(cls):
super().__init_subclass__()
if inspect.isabstract(cls):
return
if cls.NAME is None:
raise ValueError(
'Runner class must define a NAME constant. '
'Please use the same constant in the runner class.'
)
cls._REGISTRY[cls.NAME] = cls
cls._REGISTRY[cls.ID] = cls

@abc.abstractmethod
def run(self) -> None:
Expand Down
11 changes: 2 additions & 9 deletions langfun/core/eval/v2/experiment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def sample_inputs(num_examples: int = 1):


class MyEvaluation(Evaluation):
NAME = 'my_eval'
ID = 'my_eval'
RUN_ARGS = dict(
runner='test'
)
Expand Down Expand Up @@ -394,7 +394,7 @@ class RunnerTest(unittest.TestCase):
def test_basic(self):

class TestRunner(Runner):
NAME = 'test'
ID = 'test'

def run(self):
pass
Expand All @@ -421,13 +421,6 @@ def run(self):
root_dir=root_dir, id='20241101_1'
)

with self.assertRaisesRegex(
ValueError, 'Runner class must define a NAME constant'
):
class AnotherRunner(Runner): # pylint: disable=unused-variable
def run(self):
pass


if __name__ == '__main__':
unittest.main()
6 changes: 3 additions & 3 deletions langfun/core/eval/v2/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ class SequentialRunner(RunnerBase):
exceptions thrown from the background tasks, making it easier to debug.
"""

NAME = 'sequential'
ID = 'sequential'

def background_run(
self, func: Callable[..., Any], *args: Any, **kwargs: Any
Expand All @@ -402,7 +402,7 @@ def _evaluate_items(
class DebugRunner(SequentialRunner):
"""Debug runner."""

NAME = 'debug'
ID = 'debug'

# Do not use the checkpointer for debug runner.
plugins = []
Expand All @@ -420,7 +420,7 @@ def _save_run_manifest(self) -> None:
class ParallelRunner(RunnerBase):
"""Parallel runner."""

NAME = 'parallel'
ID = 'parallel'

timeout: Annotated[
int | None,
Expand Down
Loading