From d21b8c0623df7ba522cf33d17eb040067172f6d4 Mon Sep 17 00:00:00 2001 From: Julien Simon Date: Sun, 29 Mar 2026 13:37:59 +0200 Subject: [PATCH] feat: make domain scenarios configurable via YAML files Add YAML-based configuration for domain scenarios, allowing users to define custom domains or override built-ins via --config flag on CLI commands (list-domains, enrich, upload, pipeline). Co-Authored-By: Claude Opus 4.6 (1M context) --- pyproject.toml | 1 + src/ocelgen/cli.py | 61 +++++-- src/ocelgen/scenarios/__init__.py | 9 +- src/ocelgen/scenarios/loader.py | 115 +++++++++++++ src/ocelgen/scenarios/registry.py | 8 +- tests/test_cli_new.py | 25 +++ tests/test_yaml_loader.py | 276 ++++++++++++++++++++++++++++++ 7 files changed, 477 insertions(+), 18 deletions(-) create mode 100644 src/ocelgen/scenarios/loader.py create mode 100644 tests/test_yaml_loader.py diff --git a/pyproject.toml b/pyproject.toml index 37350a6..5bef251 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "openai>=1.0", "huggingface_hub>=0.20", "pyarrow>=15.0", + "pyyaml>=6.0", ] [project.optional-dependencies] diff --git a/src/ocelgen/cli.py b/src/ocelgen/cli.py index 76811cc..4174244 100644 --- a/src/ocelgen/cli.py +++ b/src/ocelgen/cli.py @@ -15,8 +15,16 @@ from ocelgen.export.normative import write_normative_model from ocelgen.export.ocel_json import write_ocel_json from ocelgen.generation.engine import PATTERN_REGISTRY, generate +from ocelgen.scenarios.domain import DomainScenario from ocelgen.validation.schema import validate_ocel_file + +def _get_registry(config: Path | None) -> dict[str, DomainScenario]: + """Build the scenario registry, merging custom YAML config if provided.""" + from ocelgen.scenarios.loader import build_registry + + return build_registry(config) + app = typer.Typer( name="ocelgen", help="Mock OCEL 2.0 event log generator for LangChain multi-agent runs.", @@ -129,9 +137,14 @@ def list_patterns() -> None: @app.command("list-domains") -def list_domains() -> None: +def list_domains( + config: Annotated[ + Path | None, + typer.Option("--config", "-c", help="YAML file or directory with custom domain definitions"), + ] = None, +) -> None: """List available domain scenarios for enriched generation.""" - from ocelgen.scenarios.registry import SCENARIO_REGISTRY + registry = _get_registry(config) table = Table(title="Available Domain Scenarios") table.add_column("Name", style="bold", no_wrap=True) @@ -140,7 +153,7 @@ def list_domains() -> None: table.add_column("Noise", justify="right") table.add_column("Description") - for name, scenario in SCENARIO_REGISTRY.items(): + for name, scenario in registry.items(): table.add_row( name, scenario.pattern, @@ -150,7 +163,7 @@ def list_domains() -> None: ) console.print(table) - console.print(f"\n[bold]{len(SCENARIO_REGISTRY)}[/bold] domains available.") + console.print(f"\n[bold]{len(registry)}[/bold] domains available.") @app.command("enrich") @@ -163,6 +176,10 @@ def enrich_cmd( output: Annotated[ Path | None, typer.Option("-o", "--output", help="Output path") ] = None, + config: Annotated[ + Path | None, + typer.Option("--config", "-c", help="YAML file or directory with custom domain definitions"), + ] = None, ) -> None: """Enrich an OCEL 2.0 trace with LLM-generated content.""" from pydantic import TypeAdapter @@ -172,17 +189,19 @@ def enrich_cmd( from ocelgen.enrichment.enricher import enrich_log from ocelgen.export.ocel_json import ocel_log_to_dict from ocelgen.models.ocel import OcelLog - from ocelgen.scenarios.registry import SCENARIO_REGISTRY, get_scenario + from ocelgen.scenarios.registry import get_scenario + + registry = _get_registry(config) if not path.exists(): console.print(f"[red]File not found: {path}[/red]") raise typer.Exit(1) - if domain not in SCENARIO_REGISTRY: + if domain not in registry: console.print(f"[red]Unknown domain '{domain}'. Use 'list-domains' to see available domains.[/red]") raise typer.Exit(1) - scenario = get_scenario(domain) + scenario = get_scenario(domain, registry=registry) console.print(f"Loading [bold]{path}[/bold]...") with open(path, encoding="utf-8") as f: @@ -213,20 +232,26 @@ def upload_cmd( collection: Annotated[ str, typer.Option("--collection", help="Collection slug") ] = "open-agent-traces", + config: Annotated[ + Path | None, + typer.Option("--config", "-c", help="YAML file or directory with custom domain definitions"), + ] = None, ) -> None: """Upload an enriched trace to Hugging Face Hub.""" from pydantic import TypeAdapter from ocelgen.models.ocel import OcelLog - from ocelgen.scenarios.registry import SCENARIO_REGISTRY, get_scenario + from ocelgen.scenarios.registry import get_scenario from ocelgen.upload.flatten import flatten_log from ocelgen.upload.hf_upload import build_repo_name, prepare_upload_files, upload_to_hub + registry = _get_registry(config) + if not path.exists(): console.print(f"[red]File not found: {path}[/red]") raise typer.Exit(1) - if domain not in SCENARIO_REGISTRY: + if domain not in registry: console.print(f"[red]Unknown domain '{domain}'.[/red]") raise typer.Exit(1) @@ -234,7 +259,7 @@ def upload_cmd( console.print("[red]--namespace is required.[/red]") raise typer.Exit(1) - scenario = get_scenario(domain) + scenario = get_scenario(domain, registry=registry) console.print(f"Loading [bold]{path}[/bold]...") with open(path, encoding="utf-8") as f: @@ -278,7 +303,7 @@ def pipeline_cmd( str | None, typer.Option("--domain", "-d", help="Single domain to process") ] = None, all_domains: Annotated[ - bool, typer.Option("--all", help="Process all 10 domains") + bool, typer.Option("--all", help="Process all domains") ] = False, namespace: Annotated[str, typer.Option("--namespace", "-n", help="HF namespace")] = "", model: Annotated[ @@ -290,13 +315,17 @@ def pipeline_cmd( skip_upload: Annotated[ bool, typer.Option("--skip-upload", help="Generate and enrich but don't upload") ] = False, + config: Annotated[ + Path | None, + typer.Option("--config", "-c", help="YAML file or directory with custom domain definitions"), + ] = None, ) -> None: """End-to-end pipeline: generate, enrich, and upload agent trace datasets.""" from rich.progress import Progress from ocelgen.enrichment.client import LLMClient from ocelgen.enrichment.enricher import enrich_log - from ocelgen.scenarios.registry import SCENARIO_REGISTRY, get_scenario + from ocelgen.scenarios.registry import get_scenario from ocelgen.upload.flatten import flatten_log from ocelgen.upload.hf_upload import ( build_repo_name, @@ -305,6 +334,8 @@ def pipeline_cmd( upload_to_hub, ) + registry = _get_registry(config) + if not namespace: console.print("[red]--namespace is required.[/red]") raise typer.Exit(1) @@ -313,10 +344,10 @@ def pipeline_cmd( console.print("[red]Specify --domain or --all.[/red]") raise typer.Exit(1) - domains: list[str] = list(SCENARIO_REGISTRY.keys()) if all_domains else [domain] # type: ignore[list-item] + domains: list[str] = list(registry.keys()) if all_domains else [domain] # type: ignore[list-item] for d in domains: - if d not in SCENARIO_REGISTRY: + if d not in registry: console.print(f"[red]Unknown domain '{d}'.[/red]") raise typer.Exit(1) @@ -324,7 +355,7 @@ def pipeline_cmd( uploaded_repos: list[str] = [] for d in domains: - scenario = get_scenario(d) + scenario = get_scenario(d, registry=registry) console.rule(f"[bold]{scenario.name}[/bold] ({scenario.pattern})") console.print(f"Generating {scenario.runs} runs...") diff --git a/src/ocelgen/scenarios/__init__.py b/src/ocelgen/scenarios/__init__.py index ef7a92e..ff4a3fd 100644 --- a/src/ocelgen/scenarios/__init__.py +++ b/src/ocelgen/scenarios/__init__.py @@ -1,6 +1,13 @@ """Domain scenario definitions for LLM-enriched trace generation.""" from ocelgen.scenarios.domain import DomainScenario +from ocelgen.scenarios.loader import build_registry, load_domains_from_yaml from ocelgen.scenarios.registry import SCENARIO_REGISTRY, get_scenario -__all__ = ["DomainScenario", "SCENARIO_REGISTRY", "get_scenario"] +__all__ = [ + "DomainScenario", + "SCENARIO_REGISTRY", + "build_registry", + "get_scenario", + "load_domains_from_yaml", +] diff --git a/src/ocelgen/scenarios/loader.py b/src/ocelgen/scenarios/loader.py new file mode 100644 index 0000000..1e4de3c --- /dev/null +++ b/src/ocelgen/scenarios/loader.py @@ -0,0 +1,115 @@ +"""Load domain scenarios from YAML configuration files.""" + +from __future__ import annotations + +from pathlib import Path + +import yaml + +from ocelgen.scenarios.domain import DomainScenario + +_VALID_PATTERNS = {"sequential", "supervisor", "parallel"} +_REQUIRED_FIELDS = {"name", "description", "pattern", "runs", "noise", "seed"} + + +def load_domains_from_yaml(path: Path) -> dict[str, DomainScenario]: + """Read a YAML file and return validated domain scenarios keyed by name. + + The file must have a top-level ``domains`` key containing a list of + domain definitions whose fields match :class:`DomainScenario`. + """ + try: + raw = yaml.safe_load(path.read_text(encoding="utf-8")) + except yaml.YAMLError as exc: + raise ValueError(f"Failed to parse YAML in {path}: {exc}") from exc + + if not raw or "domains" not in raw: + return {} + + domains_list = raw["domains"] + if not isinstance(domains_list, list): + raise ValueError(f"Expected 'domains' to be a list in {path}") + + result: dict[str, DomainScenario] = {} + for idx, entry in enumerate(domains_list): + if not isinstance(entry, dict): + raise ValueError(f"Domain entry {idx} in {path} is not a mapping") + + missing = _REQUIRED_FIELDS - entry.keys() + if missing: + raise ValueError( + f"Domain entry {idx} in {path} missing required fields: {sorted(missing)}" + ) + + name = entry["name"] + pattern = entry["pattern"] + if pattern not in _VALID_PATTERNS: + raise ValueError( + f"Domain '{name}' in {path}: 'pattern' must be one of " + f"{sorted(_VALID_PATTERNS)} (got '{pattern}')" + ) + + runs = entry["runs"] + if not isinstance(runs, int) or runs <= 0: + raise ValueError(f"Domain '{name}' in {path}: 'runs' must be a positive integer") + + noise = entry["noise"] + if not isinstance(noise, (int, float)) or not (0.0 <= noise <= 1.0): + raise ValueError(f"Domain '{name}' in {path}: 'noise' must be between 0.0 and 1.0") + + seed = entry["seed"] + if not isinstance(seed, int): + raise ValueError(f"Domain '{name}' in {path}: 'seed' must be an integer") + + result[name] = DomainScenario( + name=name, + description=entry["description"], + pattern=pattern, + runs=runs, + noise=float(noise), + seed=seed, + user_queries=entry.get("user_queries", []), + agent_personas=entry.get("agent_personas", {}), + tool_descriptions=entry.get("tool_descriptions", {}), + ) + + return result + + +def load_domains_from_dir(dir_path: Path) -> dict[str, DomainScenario]: + """Load all ``*.yaml`` and ``*.yml`` files from a directory. + + Files are processed in alphabetical order; later files override + earlier ones for domains with the same name. + """ + result: dict[str, DomainScenario] = {} + yaml_files = sorted( + p for p in dir_path.iterdir() if p.suffix in {".yaml", ".yml"} + ) + for f in yaml_files: + result.update(load_domains_from_yaml(f)) + return result + + +def build_registry(config: Path | None = None) -> dict[str, DomainScenario]: + """Build a merged scenario registry. + + Starts with the 10 built-in scenarios and merges in any domains from + the given *config* path (a single YAML file or a directory of them). + Custom domains with the same name as a built-in override the built-in. + """ + from ocelgen.scenarios.registry import SCENARIO_REGISTRY + + registry = dict(SCENARIO_REGISTRY) + + if config is None: + return registry + + if config.is_file(): + registry.update(load_domains_from_yaml(config)) + elif config.is_dir(): + registry.update(load_domains_from_dir(config)) + else: + raise ValueError(f"Config path does not exist: {config}") + + return registry diff --git a/src/ocelgen/scenarios/registry.py b/src/ocelgen/scenarios/registry.py index 99c3d28..b60b5f2 100644 --- a/src/ocelgen/scenarios/registry.py +++ b/src/ocelgen/scenarios/registry.py @@ -346,6 +346,10 @@ } -def get_scenario(name: str) -> DomainScenario: +def get_scenario( + name: str, + registry: dict[str, DomainScenario] | None = None, +) -> DomainScenario: """Look up a domain scenario by name. Raises KeyError if not found.""" - return SCENARIO_REGISTRY[name] + reg = registry if registry is not None else SCENARIO_REGISTRY + return reg[name] diff --git a/tests/test_cli_new.py b/tests/test_cli_new.py index 58d8fdf..517bb9b 100644 --- a/tests/test_cli_new.py +++ b/tests/test_cli_new.py @@ -30,6 +30,31 @@ def test_enrich_requires_valid_domain(self, tmp_path) -> None: assert result.exit_code != 0 +class TestConfigFlag: + def test_list_domains_with_config_shows_custom(self, tmp_path) -> None: + config = tmp_path / "custom.yaml" + config.write_text( + """\ +domains: + - name: "my-custom-domain" + description: "Custom" + pattern: "sequential" + runs: 10 + noise: 0.1 + seed: 42 +""" + ) + result = runner.invoke(app, ["list-domains", "--config", str(config)]) + assert result.exit_code == 0 + assert "my-custom-domain" in result.output + assert "customer-support-triage" in result.output + assert "11" in result.output + + def test_list_domains_with_nonexistent_config(self) -> None: + result = runner.invoke(app, ["list-domains", "--config", "/nonexistent/path.yaml"]) + assert result.exit_code != 0 + + class TestPipelineCommand: def test_pipeline_requires_namespace(self) -> None: result = runner.invoke(app, ["pipeline", "--domain", "customer-support-triage"]) diff --git a/tests/test_yaml_loader.py b/tests/test_yaml_loader.py new file mode 100644 index 0000000..399b889 --- /dev/null +++ b/tests/test_yaml_loader.py @@ -0,0 +1,276 @@ +"""Tests for YAML-based domain scenario loading.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from ocelgen.scenarios.domain import DomainScenario +from ocelgen.scenarios.loader import build_registry, load_domains_from_dir, load_domains_from_yaml + +_VALID_DOMAIN_YAML = """\ +domains: + - name: "test-domain" + description: "A test domain" + pattern: "sequential" + runs: 10 + noise: 0.1 + seed: 42 + user_queries: + - "Query one" + - "Query two" + agent_personas: + researcher: "You are a test researcher" + tool_descriptions: + web_search: "Search the web" +""" + + +class TestLoadDomainsFromYaml: + def test_load_single_domain(self, tmp_path: Path) -> None: + f = tmp_path / "domains.yaml" + f.write_text(_VALID_DOMAIN_YAML) + + result = load_domains_from_yaml(f) + + assert len(result) == 1 + d = result["test-domain"] + assert isinstance(d, DomainScenario) + assert d.name == "test-domain" + assert d.pattern == "sequential" + assert d.runs == 10 + assert d.noise == 0.1 + assert d.seed == 42 + assert d.user_queries == ["Query one", "Query two"] + assert d.agent_personas == {"researcher": "You are a test researcher"} + assert d.tool_descriptions == {"web_search": "Search the web"} + + def test_load_multiple_domains(self, tmp_path: Path) -> None: + f = tmp_path / "domains.yaml" + f.write_text( + """\ +domains: + - name: "domain-a" + description: "A" + pattern: "sequential" + runs: 10 + noise: 0.1 + seed: 1 + - name: "domain-b" + description: "B" + pattern: "parallel" + runs: 20 + noise: 0.2 + seed: 2 + - name: "domain-c" + description: "C" + pattern: "supervisor" + runs: 30 + noise: 0.3 + seed: 3 +""" + ) + result = load_domains_from_yaml(f) + assert len(result) == 3 + assert set(result.keys()) == {"domain-a", "domain-b", "domain-c"} + + def test_optional_fields_default_to_empty(self, tmp_path: Path) -> None: + f = tmp_path / "domains.yaml" + f.write_text( + """\ +domains: + - name: "minimal" + description: "Minimal domain" + pattern: "sequential" + runs: 5 + noise: 0.0 + seed: 1 +""" + ) + result = load_domains_from_yaml(f) + d = result["minimal"] + assert d.user_queries == [] + assert d.agent_personas == {} + assert d.tool_descriptions == {} + + def test_invalid_pattern_raises(self, tmp_path: Path) -> None: + f = tmp_path / "domains.yaml" + f.write_text( + """\ +domains: + - name: "bad" + description: "Bad" + pattern: "unknown" + runs: 10 + noise: 0.1 + seed: 1 +""" + ) + with pytest.raises(ValueError, match="pattern"): + load_domains_from_yaml(f) + + def test_missing_required_field_raises(self, tmp_path: Path) -> None: + f = tmp_path / "domains.yaml" + f.write_text( + """\ +domains: + - description: "No name" + pattern: "sequential" + runs: 10 + noise: 0.1 + seed: 1 +""" + ) + with pytest.raises(ValueError, match="missing required fields"): + load_domains_from_yaml(f) + + def test_invalid_noise_range_raises(self, tmp_path: Path) -> None: + f = tmp_path / "domains.yaml" + f.write_text( + """\ +domains: + - name: "bad-noise" + description: "Bad" + pattern: "sequential" + runs: 10 + noise: 1.5 + seed: 1 +""" + ) + with pytest.raises(ValueError, match="noise"): + load_domains_from_yaml(f) + + def test_malformed_yaml_raises(self, tmp_path: Path) -> None: + f = tmp_path / "domains.yaml" + f.write_text("domains:\n - name: [invalid yaml structure\n") + with pytest.raises(ValueError, match="Failed to parse YAML"): + load_domains_from_yaml(f) + + def test_empty_file_returns_empty(self, tmp_path: Path) -> None: + f = tmp_path / "domains.yaml" + f.write_text("") + result = load_domains_from_yaml(f) + assert result == {} + + def test_no_domains_key_returns_empty(self, tmp_path: Path) -> None: + f = tmp_path / "domains.yaml" + f.write_text("other_key: value\n") + result = load_domains_from_yaml(f) + assert result == {} + + def test_negative_runs_raises(self, tmp_path: Path) -> None: + f = tmp_path / "domains.yaml" + f.write_text( + """\ +domains: + - name: "bad-runs" + description: "Bad" + pattern: "sequential" + runs: -5 + noise: 0.1 + seed: 1 +""" + ) + with pytest.raises(ValueError, match="runs"): + load_domains_from_yaml(f) + + +class TestLoadDomainsFromDir: + def test_load_from_directory(self, tmp_path: Path) -> None: + (tmp_path / "a.yaml").write_text( + """\ +domains: + - name: "from-a" + description: "A" + pattern: "sequential" + runs: 10 + noise: 0.1 + seed: 1 +""" + ) + (tmp_path / "b.yml").write_text( + """\ +domains: + - name: "from-b" + description: "B" + pattern: "parallel" + runs: 20 + noise: 0.2 + seed: 2 +""" + ) + result = load_domains_from_dir(tmp_path) + assert len(result) == 2 + assert "from-a" in result + assert "from-b" in result + + def test_later_file_overrides_earlier(self, tmp_path: Path) -> None: + (tmp_path / "01.yaml").write_text( + """\ +domains: + - name: "shared" + description: "First" + pattern: "sequential" + runs: 10 + noise: 0.1 + seed: 1 +""" + ) + (tmp_path / "02.yaml").write_text( + """\ +domains: + - name: "shared" + description: "Second" + pattern: "parallel" + runs: 99 + noise: 0.5 + seed: 2 +""" + ) + result = load_domains_from_dir(tmp_path) + assert len(result) == 1 + assert result["shared"].description == "Second" + assert result["shared"].runs == 99 + + +class TestBuildRegistry: + def test_no_config_returns_builtins(self) -> None: + result = build_registry(None) + assert len(result) == 10 + assert "customer-support-triage" in result + + def test_merge_adds_new_domain(self, tmp_path: Path) -> None: + f = tmp_path / "custom.yaml" + f.write_text(_VALID_DOMAIN_YAML) + result = build_registry(f) + assert len(result) == 11 + assert "test-domain" in result + assert "customer-support-triage" in result + + def test_override_builtin(self, tmp_path: Path) -> None: + f = tmp_path / "override.yaml" + f.write_text( + """\ +domains: + - name: "customer-support-triage" + description: "Custom override" + pattern: "sequential" + runs: 999 + noise: 0.05 + seed: 1001 +""" + ) + result = build_registry(f) + assert len(result) == 10 + assert result["customer-support-triage"].runs == 999 + assert result["customer-support-triage"].description == "Custom override" + + def test_config_directory(self, tmp_path: Path) -> None: + (tmp_path / "extra.yaml").write_text(_VALID_DOMAIN_YAML) + result = build_registry(tmp_path) + assert len(result) == 11 + + def test_nonexistent_path_raises(self, tmp_path: Path) -> None: + with pytest.raises(ValueError, match="does not exist"): + build_registry(tmp_path / "nonexistent")