Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies = [
"openai>=1.0",
"huggingface_hub>=0.20",
"pyarrow>=15.0",
"pyyaml>=6.0",
]

[project.optional-dependencies]
Expand Down
61 changes: 46 additions & 15 deletions src/ocelgen/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,16 @@
from ocelgen.export.normative import write_normative_model
from ocelgen.export.ocel_json import write_ocel_json
from ocelgen.generation.engine import PATTERN_REGISTRY, generate
from ocelgen.scenarios.domain import DomainScenario
from ocelgen.validation.schema import validate_ocel_file


def _get_registry(config: Path | None) -> dict[str, DomainScenario]:
    """Build the scenario registry, merging custom YAML config if provided.

    Failures raised while loading *config* are reported the same way as the
    other CLI errors in this module (a red message followed by exit code 1)
    instead of surfacing as a traceback.

    Args:
        config: YAML file or directory with custom domain definitions, or
            ``None`` to use only the built-in scenarios.

    Returns:
        Mapping of domain name to scenario.

    Raises:
        typer.Exit: With code 1 when the config cannot be read or fails
            validation.
    """
    from ocelgen.scenarios.loader import build_registry

    try:
        return build_registry(config)
    except (ValueError, OSError) as exc:
        from rich.markup import escape

        # The message embeds a user-supplied path and YAML values; escape it
        # so stray "[...]" sequences are not interpreted as Rich markup.
        console.print(f"[red]Invalid --config: {escape(str(exc))}[/red]")
        raise typer.Exit(1) from exc

app = typer.Typer(
name="ocelgen",
help="Mock OCEL 2.0 event log generator for LangChain multi-agent runs.",
Expand Down Expand Up @@ -129,9 +137,14 @@ def list_patterns() -> None:


@app.command("list-domains")
def list_domains() -> None:
def list_domains(
config: Annotated[
Path | None,
typer.Option("--config", "-c", help="YAML file or directory with custom domain definitions"),
] = None,
) -> None:
"""List available domain scenarios for enriched generation."""
from ocelgen.scenarios.registry import SCENARIO_REGISTRY
registry = _get_registry(config)

table = Table(title="Available Domain Scenarios")
table.add_column("Name", style="bold", no_wrap=True)
Expand All @@ -140,7 +153,7 @@ def list_domains() -> None:
table.add_column("Noise", justify="right")
table.add_column("Description")

for name, scenario in SCENARIO_REGISTRY.items():
for name, scenario in registry.items():
table.add_row(
name,
scenario.pattern,
Expand All @@ -150,7 +163,7 @@ def list_domains() -> None:
)

console.print(table)
console.print(f"\n[bold]{len(SCENARIO_REGISTRY)}[/bold] domains available.")
console.print(f"\n[bold]{len(registry)}[/bold] domains available.")


@app.command("enrich")
Expand All @@ -163,6 +176,10 @@ def enrich_cmd(
output: Annotated[
Path | None, typer.Option("-o", "--output", help="Output path")
] = None,
config: Annotated[
Path | None,
typer.Option("--config", "-c", help="YAML file or directory with custom domain definitions"),
] = None,
) -> None:
"""Enrich an OCEL 2.0 trace with LLM-generated content."""
from pydantic import TypeAdapter
Expand All @@ -172,17 +189,19 @@ def enrich_cmd(
from ocelgen.enrichment.enricher import enrich_log
from ocelgen.export.ocel_json import ocel_log_to_dict
from ocelgen.models.ocel import OcelLog
from ocelgen.scenarios.registry import SCENARIO_REGISTRY, get_scenario
from ocelgen.scenarios.registry import get_scenario

registry = _get_registry(config)

if not path.exists():
console.print(f"[red]File not found: {path}[/red]")
raise typer.Exit(1)

if domain not in SCENARIO_REGISTRY:
if domain not in registry:
console.print(f"[red]Unknown domain '{domain}'. Use 'list-domains' to see available domains.[/red]")
raise typer.Exit(1)

scenario = get_scenario(domain)
scenario = get_scenario(domain, registry=registry)

console.print(f"Loading [bold]{path}[/bold]...")
with open(path, encoding="utf-8") as f:
Expand Down Expand Up @@ -213,28 +232,34 @@ def upload_cmd(
collection: Annotated[
str, typer.Option("--collection", help="Collection slug")
] = "open-agent-traces",
config: Annotated[
Path | None,
typer.Option("--config", "-c", help="YAML file or directory with custom domain definitions"),
] = None,
) -> None:
"""Upload an enriched trace to Hugging Face Hub."""
from pydantic import TypeAdapter

from ocelgen.models.ocel import OcelLog
from ocelgen.scenarios.registry import SCENARIO_REGISTRY, get_scenario
from ocelgen.scenarios.registry import get_scenario
from ocelgen.upload.flatten import flatten_log
from ocelgen.upload.hf_upload import build_repo_name, prepare_upload_files, upload_to_hub

registry = _get_registry(config)

if not path.exists():
console.print(f"[red]File not found: {path}[/red]")
raise typer.Exit(1)

if domain not in SCENARIO_REGISTRY:
if domain not in registry:
console.print(f"[red]Unknown domain '{domain}'.[/red]")
raise typer.Exit(1)

if not namespace:
console.print("[red]--namespace is required.[/red]")
raise typer.Exit(1)

scenario = get_scenario(domain)
scenario = get_scenario(domain, registry=registry)

console.print(f"Loading [bold]{path}[/bold]...")
with open(path, encoding="utf-8") as f:
Expand Down Expand Up @@ -278,7 +303,7 @@ def pipeline_cmd(
str | None, typer.Option("--domain", "-d", help="Single domain to process")
] = None,
all_domains: Annotated[
bool, typer.Option("--all", help="Process all 10 domains")
bool, typer.Option("--all", help="Process all domains")
] = False,
namespace: Annotated[str, typer.Option("--namespace", "-n", help="HF namespace")] = "",
model: Annotated[
Expand All @@ -290,13 +315,17 @@ def pipeline_cmd(
skip_upload: Annotated[
bool, typer.Option("--skip-upload", help="Generate and enrich but don't upload")
] = False,
config: Annotated[
Path | None,
typer.Option("--config", "-c", help="YAML file or directory with custom domain definitions"),
] = None,
) -> None:
"""End-to-end pipeline: generate, enrich, and upload agent trace datasets."""
from rich.progress import Progress

from ocelgen.enrichment.client import LLMClient
from ocelgen.enrichment.enricher import enrich_log
from ocelgen.scenarios.registry import SCENARIO_REGISTRY, get_scenario
from ocelgen.scenarios.registry import get_scenario
from ocelgen.upload.flatten import flatten_log
from ocelgen.upload.hf_upload import (
build_repo_name,
Expand All @@ -305,6 +334,8 @@ def pipeline_cmd(
upload_to_hub,
)

registry = _get_registry(config)

if not namespace:
console.print("[red]--namespace is required.[/red]")
raise typer.Exit(1)
Expand All @@ -313,18 +344,18 @@ def pipeline_cmd(
console.print("[red]Specify --domain <name> or --all.[/red]")
raise typer.Exit(1)

domains: list[str] = list(SCENARIO_REGISTRY.keys()) if all_domains else [domain] # type: ignore[list-item]
domains: list[str] = list(registry.keys()) if all_domains else [domain] # type: ignore[list-item]

for d in domains:
if d not in SCENARIO_REGISTRY:
if d not in registry:
console.print(f"[red]Unknown domain '{d}'.[/red]")
raise typer.Exit(1)

client = LLMClient(model=model)
uploaded_repos: list[str] = []

for d in domains:
scenario = get_scenario(d)
scenario = get_scenario(d, registry=registry)
console.rule(f"[bold]{scenario.name}[/bold] ({scenario.pattern})")

console.print(f"Generating {scenario.runs} runs...")
Expand Down
9 changes: 8 additions & 1 deletion src/ocelgen/scenarios/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
"""Domain scenario definitions for LLM-enriched trace generation."""

from ocelgen.scenarios.domain import DomainScenario
from ocelgen.scenarios.loader import build_registry, load_domains_from_yaml
from ocelgen.scenarios.registry import SCENARIO_REGISTRY, get_scenario

__all__ = ["DomainScenario", "SCENARIO_REGISTRY", "get_scenario"]
__all__ = [
"DomainScenario",
"SCENARIO_REGISTRY",
"build_registry",
"get_scenario",
"load_domains_from_yaml",
]
115 changes: 115 additions & 0 deletions src/ocelgen/scenarios/loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Load domain scenarios from YAML configuration files."""

from __future__ import annotations

from pathlib import Path

import yaml

from ocelgen.scenarios.domain import DomainScenario

_VALID_PATTERNS = {"sequential", "supervisor", "parallel"}
_REQUIRED_FIELDS = {"name", "description", "pattern", "runs", "noise", "seed"}


def load_domains_from_yaml(path: Path) -> dict[str, DomainScenario]:
    """Read a YAML file and return validated domain scenarios keyed by name.

    The file is expected to hold a top-level mapping with a ``domains`` key
    containing a list of domain definitions whose fields match
    :class:`DomainScenario`. A document that is empty, is not a mapping, or
    has no ``domains`` key contributes no scenarios.

    Args:
        path: YAML file to read (decoded as UTF-8).

    Returns:
        Mapping of domain name to scenario. If a name occurs more than once
        in the file, the last entry wins.

    Raises:
        ValueError: If the YAML cannot be parsed, ``domains`` is not a list,
            or an entry is malformed (not a mapping, missing a required
            field, or carrying a field of the wrong type or range).
    """
    try:
        raw = yaml.safe_load(path.read_text(encoding="utf-8"))
    except yaml.YAMLError as exc:
        raise ValueError(f"Failed to parse YAML in {path}: {exc}") from exc

    # safe_load may return a scalar or a list; test the type first so that
    # `in` and indexing are only ever applied to a mapping.
    if not isinstance(raw, dict) or "domains" not in raw:
        return {}

    domains_list = raw["domains"]
    if not isinstance(domains_list, list):
        raise ValueError(f"Expected 'domains' to be a list in {path}")

    result: dict[str, DomainScenario] = {}
    for idx, entry in enumerate(domains_list):
        if not isinstance(entry, dict):
            raise ValueError(f"Domain entry {idx} in {path} is not a mapping")

        missing = _REQUIRED_FIELDS - entry.keys()
        if missing:
            raise ValueError(
                f"Domain entry {idx} in {path} missing required fields: {sorted(missing)}"
            )

        name = entry["name"]
        if not isinstance(name, str) or not name:
            raise ValueError(
                f"Domain entry {idx} in {path}: 'name' must be a non-empty string"
            )

        pattern = entry["pattern"]
        # The isinstance guard keeps unhashable values (lists, mappings) from
        # raising TypeError in the set membership test.
        if not isinstance(pattern, str) or pattern not in _VALID_PATTERNS:
            raise ValueError(
                f"Domain '{name}' in {path}: 'pattern' must be one of "
                f"{sorted(_VALID_PATTERNS)} (got '{pattern}')"
            )

        # bool is a subclass of int, so YAML `true`/`false` must be rejected
        # explicitly for every numeric field below.
        runs = entry["runs"]
        if isinstance(runs, bool) or not isinstance(runs, int) or runs <= 0:
            raise ValueError(f"Domain '{name}' in {path}: 'runs' must be a positive integer")

        noise = entry["noise"]
        if (
            isinstance(noise, bool)
            or not isinstance(noise, (int, float))
            or not (0.0 <= noise <= 1.0)
        ):
            raise ValueError(f"Domain '{name}' in {path}: 'noise' must be between 0.0 and 1.0")

        seed = entry["seed"]
        if isinstance(seed, bool) or not isinstance(seed, int):
            raise ValueError(f"Domain '{name}' in {path}: 'seed' must be an integer")

        # A key written with no value (`user_queries:`) loads as None; treat
        # that the same as the key being absent.
        user_queries = entry.get("user_queries")
        agent_personas = entry.get("agent_personas")
        tool_descriptions = entry.get("tool_descriptions")

        result[name] = DomainScenario(
            name=name,
            description=entry["description"],
            pattern=pattern,
            runs=runs,
            noise=float(noise),
            seed=seed,
            user_queries=[] if user_queries is None else user_queries,
            agent_personas={} if agent_personas is None else agent_personas,
            tool_descriptions={} if tool_descriptions is None else tool_descriptions,
        )

    return result


def load_domains_from_dir(dir_path: Path) -> dict[str, DomainScenario]:
    """Load every ``*.yaml`` / ``*.yml`` file found directly in a directory.

    Only regular files are considered, so a sub-directory whose name happens
    to end in ``.yaml`` or ``.yml`` is skipped rather than read. Files are
    processed in sorted path order; later files override earlier ones for
    domains with the same name.

    Args:
        dir_path: Directory to scan (not recursive).

    Returns:
        Mapping of domain name to scenario, merged across all files.

    Raises:
        ValueError: Propagated from :func:`load_domains_from_yaml` when a
            file cannot be parsed or fails validation.
    """
    yaml_files = sorted(
        p
        for p in dir_path.iterdir()
        if p.suffix in {".yaml", ".yml"} and p.is_file()
    )

    result: dict[str, DomainScenario] = {}
    for yaml_file in yaml_files:
        result.update(load_domains_from_yaml(yaml_file))
    return result


def build_registry(config: Path | None = None) -> dict[str, DomainScenario]:
    """Return the built-in scenarios merged with any custom domains.

    The result is a fresh copy of ``SCENARIO_REGISTRY``. When *config* is
    given it may point at a single YAML file or at a directory of YAML
    files; domains loaded from it are layered on top, so a custom domain
    replaces a built-in one that has the same name.

    Raises:
        ValueError: If *config* is neither an existing file nor an existing
            directory.
    """
    # Imported here rather than at module level, as in the original.
    from ocelgen.scenarios.registry import SCENARIO_REGISTRY

    merged = dict(SCENARIO_REGISTRY)
    if config is None:
        return merged

    if config.is_file():
        custom = load_domains_from_yaml(config)
    elif config.is_dir():
        custom = load_domains_from_dir(config)
    else:
        raise ValueError(f"Config path does not exist: {config}")

    merged.update(custom)
    return merged
8 changes: 6 additions & 2 deletions src/ocelgen/scenarios/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,10 @@
}


def get_scenario(name: str) -> DomainScenario:
def get_scenario(
name: str,
registry: dict[str, DomainScenario] | None = None,
) -> DomainScenario:
"""Look up a domain scenario by name. Raises KeyError if not found."""
return SCENARIO_REGISTRY[name]
reg = registry if registry is not None else SCENARIO_REGISTRY
return reg[name]
25 changes: 25 additions & 0 deletions tests/test_cli_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,31 @@ def test_enrich_requires_valid_domain(self, tmp_path) -> None:
assert result.exit_code != 0


class TestConfigFlag:
# CLI tests for the --config option of the list-domains command.
# Writes a YAML config with one custom domain, then checks that list-domains
# shows both that domain and a built-in one.
def test_list_domains_with_config_shows_custom(self, tmp_path) -> None:
config = tmp_path / "custom.yaml"
config.write_text(
"""\
domains:
- name: "my-custom-domain"
description: "Custom"
pattern: "sequential"
runs: 10
noise: 0.1
seed: 42
"""
)
result = runner.invoke(app, ["list-domains", "--config", str(config)])
assert result.exit_code == 0
assert "my-custom-domain" in result.output
assert "customer-support-triage" in result.output
# Presumably the domain count line: 10 built-ins + 1 custom (TODO confirm the built-in count).
assert "11" in result.output

# A --config path that does not exist must make the command exit non-zero.
def test_list_domains_with_nonexistent_config(self) -> None:
result = runner.invoke(app, ["list-domains", "--config", "/nonexistent/path.yaml"])
assert result.exit_code != 0


class TestPipelineCommand:
def test_pipeline_requires_namespace(self) -> None:
result = runner.invoke(app, ["pipeline", "--domain", "customer-support-triage"])
Expand Down
Loading
Loading