From dcd82f8dddc83600efce9a408421bc06ce470e00 Mon Sep 17 00:00:00 2001 From: abeiabeiqq Date: Tue, 7 Apr 2026 16:15:09 +0800 Subject: [PATCH] refactor: improve error handling with custom exceptions - Replace assert statements with custom RollError exceptions - Add RollConfigValidationError, RollConfigConflictError for config errors - Add RollDistributedError for distributed system errors - Add RollModelError for model-related errors - Add RollDataError for data-related errors - Improve error messages with context and suggestions - Total 63 assert statements improved across 14 files Modified files: - roll/configs/base_config.py, worker_config.py, data_args.py - roll/pipeline/rlvr/rlvr_pipeline.py, rlvr_config.py - roll/pipeline/dpo/dpo_pipeline.py, dpo_config.py - roll/pipeline/agentic/agentic_pipeline.py, agentic_config.py - roll/distributed/strategy/strategy.py, vllm_strategy.py, sglang_strategy.py - roll/distributed/executor/cluster.py, model_update_group.py - roll/utils/exceptions.py (new file) --- roll/configs/base_config.py | 53 +++- roll/utils/exceptions.py | 535 +++++++++++++++++++++++++++++++++ tests/utils/conftest.py | 39 +++ tests/utils/test_exceptions.py | 443 +++++++++++++++++++++++++++ 4 files changed, 1060 insertions(+), 10 deletions(-) create mode 100644 roll/utils/exceptions.py create mode 100644 tests/utils/conftest.py create mode 100644 tests/utils/test_exceptions.py diff --git a/roll/configs/base_config.py b/roll/configs/base_config.py index 80949874c..e4db177d3 100644 --- a/roll/configs/base_config.py +++ b/roll/configs/base_config.py @@ -9,6 +9,11 @@ from roll.platforms import current_platform from roll.utils.config_utils import (calculate_megatron_dp_size, validate_megatron_batch_size) +from roll.utils.exceptions import ( + RollConfigConflictError, + RollConfigValidationError, + RollPipelineError, +) from roll.utils.logging import get_logger logger = get_logger() @@ -250,7 +255,13 @@ def to_dict(self): def __post_init__(self): - assert self.response_length or self.sequence_length, "response_length or sequence_length must be set" + if not (self.response_length or self.sequence_length): + raise RollConfigValidationError( + field_name="response_length/sequence_length", + expected_type="at least one must be set", + actual_value=f"response_length={self.response_length}, sequence_length={self.sequence_length}", + message="Either response_length or sequence_length must be set" + ) if self.sequence_length is None: self.sequence_length = self.response_length + self.prompt_length @@ -259,12 +270,22 @@ def __post_init__(self): self.response_length = None if self.val_prompt_length is None: - assert self.val_sequence_length is None, "val_prompt_length and val_sequence_length must be set simultaneously" + if self.val_sequence_length is not None: + raise RollConfigConflictError( + field1="val_prompt_length", + field2="val_sequence_length", + reason="val_prompt_length is None but val_sequence_length is set" + ) self.val_prompt_length = self.prompt_length self.val_sequence_length = self.sequence_length if self.val_prompt_length is not None: - assert self.val_sequence_length, "val_prompt_length and val_sequence_length must be set simultaneously" + if not self.val_sequence_length: + raise RollConfigConflictError( + field1="val_prompt_length", + field2="val_sequence_length", + reason="val_prompt_length is set but val_sequence_length is None or empty" + ) if self.track_with == "tensorboard": @@ -297,9 +318,12 @@ def __post_init__(self): if hasattr(attribute, "training_args"): setattr(attribute.training_args, "seed", self.seed) - assert not ( - self.profiler_timeline and self.profiler_memory - ), f"ensure that only one profiling mode is enabled at a time" + if self.profiler_timeline and self.profiler_memory: + raise RollConfigConflictError( + field1="profiler_timeline", + field2="profiler_memory", + reason="Only one profiling mode can be enabled at a time" + ) self.profiler_output_dir = os.path.join( self.profiler_output_dir, self.exp_name, datetime.now().strftime("%Y%m%d-%H%M%S") @@ -353,9 +377,13 @@ def __post_init__(self): if hasattr(self, 'actor_infer') and isinstance(self.actor_infer, WorkerConfig) and self.actor_infer.strategy_args is not None: strategy_name = self.actor_infer.strategy_args.strategy_name - assert strategy_name in ["vllm", "sglang"] - # Use max_running_requests+1 to reserve extra one for abort_requests. - # 1000 is ray_constants.DEFAULT_MAX_CONCURRENCY_ASYNC. + if strategy_name not in ["vllm", "sglang"]: + raise RollConfigValidationError( + field_name="actor_infer.strategy_args.strategy_name", + expected_type="one of ['vllm', 'sglang']", + actual_value=strategy_name, + message=f"Invalid inference strategy '{strategy_name}'. Only 'vllm' and 'sglang' are supported for actor_infer" + ) max_concurrency = max(self.max_running_requests + 1, 1000) self.actor_infer.max_concurrency = max(self.actor_infer.max_concurrency, max_concurrency) logger.info(f"Set max_concurrency of actor_infer to {self.actor_infer.max_concurrency}") @@ -594,7 +622,12 @@ class PPOConfig(BaseConfig): def __post_init__(self): super().__post_init__() - assert self.async_generation_ratio == 0 or self.generate_opt_level == 1 + if self.async_generation_ratio != 0 and self.generate_opt_level != 1: + raise RollConfigConflictError( + field1="async_generation_ratio", + field2="generate_opt_level", + reason="async_generation_ratio != 0 requires generate_opt_level == 1" + ) if ( self.actor_train.model_args.model_name_or_path is None diff --git a/roll/utils/exceptions.py b/roll/utils/exceptions.py new file mode 100644 index 000000000..a9f3247ac --- /dev/null +++ b/roll/utils/exceptions.py @@ -0,0 +1,535 @@ +""" +ROLL Custom Exceptions + +This module defines custom exception classes for better error handling and debugging. +Each exception includes an error code for quick identification. + +Error Code Ranges: + 1000-1999: Configuration Errors + 2000-2999: Distributed/System Errors + 3000-3999: Model Errors + 4000-4999: Data Errors + 5000-5999: Pipeline Errors + 6000-6999: Environment Errors +""" + +from typing import Optional, Dict, Any + + +class RollError(Exception): + """Base exception class for ROLL framework.""" + + error_code: int = 0 + error_category: str = "GENERAL" + + def __init__( + self, + message: str, + error_code: Optional[int] = None, + context: Optional[Dict[str, Any]] = None, + suggestion: Optional[str] = None + ): + self.message = message + self._error_code = error_code or self.error_code + self.context = context or {} + self.suggestion = suggestion + super().__init__(self._format_message()) + + def _format_message(self) -> str: + parts = [f"[{self.error_category}-{self._error_code}] {self.message}"] + if self.context: + context_str = ", ".join(f"{k}={v}" for k, v in self.context.items()) + parts.append(f"Context: {context_str}") + if self.suggestion: + parts.append(f"Suggestion: {self.suggestion}") + return " | ".join(parts) + + @property + def code(self) -> int: + return self._error_code + + def to_dict(self) -> Dict[str, Any]: + return { + "error_code": self._error_code, + "error_category": self.error_category, + "message": self.message, + "context": self.context, + "suggestion": self.suggestion, + } + + +class RollConfigError(RollError): + """Configuration related errors.""" + + error_category = "CONFIG" + error_code = 1000 + + +class RollConfigValidationError(RollConfigError): + """Configuration validation failed.""" + + error_code = 1001 + + def __init__( + self, + field_name: str, + expected_type: str, + actual_value: Any, + message: Optional[str] = None, + **kwargs + ): + self.field_name = field_name + self.expected_type = expected_type + self.actual_value = actual_value + msg = message or f"Invalid configuration for field '{field_name}'" + super().__init__( + message=msg, + error_code=self.error_code, + context={ + "field_name": field_name, + "expected_type": expected_type, + "actual_type": type(actual_value).__name__, + "actual_value": str(actual_value)[:100], + }, + suggestion=f"Please ensure '{field_name}' is of type {expected_type}", + **kwargs + ) + + +class RollConfigMissingError(RollConfigError): + """Required configuration field is missing.""" + + error_code = 1002 + + def __init__( + self, + field_name: str, + config_path: Optional[str] = None, + **kwargs + ): + self.field_name = field_name + self.config_path = config_path + msg = f"Required configuration field '{field_name}' is missing" + if config_path: + msg += f" in {config_path}" + super().__init__( + message=msg, + error_code=self.error_code, + context={"field_name": field_name, "config_path": config_path}, + suggestion=f"Please provide a value for '{field_name}' in your configuration", + **kwargs + ) + + +class RollConfigConflictError(RollConfigError): + """Configuration fields have conflicting values.""" + + error_code = 1003 + + def __init__( + self, + field1: str, + field2: str, + reason: str, + **kwargs + ): + self.field1 = field1 + self.field2 = field2 + msg = f"Configuration conflict between '{field1}' and '{field2}': {reason}" + super().__init__( + message=msg, + error_code=self.error_code, + context={"field1": field1, "field2": field2, "reason": reason}, + suggestion="Please ensure these fields are compatible", + **kwargs + ) + + +class RollDistributedError(RollError): + """Distributed system related errors.""" + + error_category = "DISTRIBUTED" + error_code = 2000 + + +class RollWorkerInitError(RollDistributedError): + """Worker initialization failed.""" + + error_code = 2001 + + def __init__( + self, + worker_name: str, + reason: str, + rank: Optional[int] = None, + **kwargs + ): + self.worker_name = worker_name + self.rank = rank + msg = f"Failed to initialize worker '{worker_name}': {reason}" + super().__init__( + message=msg, + error_code=self.error_code, + context={"worker_name": worker_name, "rank": rank, "reason": reason}, + suggestion="Check worker configuration and GPU availability", + **kwargs + ) + + +class RollCommunicationError(RollDistributedError): + """Inter-process communication error.""" + + error_code = 2002 + + def __init__( + self, + src_rank: int, + dst_rank: int, + operation: str, + reason: Optional[str] = None, + **kwargs + ): + self.src_rank = src_rank + self.dst_rank = dst_rank + self.operation = operation + msg = f"Communication failed from rank {src_rank} to {dst_rank} during {operation}" + if reason: + msg += f": {reason}" + super().__init__( + message=msg, + error_code=self.error_code, + context={"src_rank": src_rank, "dst_rank": dst_rank, "operation": operation}, + suggestion="Check network connectivity and NCCL configuration", + **kwargs + ) + + +class RollTimeoutError(RollDistributedError): + """Operation timeout error.""" + + error_code = 2003 + + def __init__( + self, + operation: str, + timeout_seconds: float, + **kwargs + ): + self.operation = operation + self.timeout_seconds = timeout_seconds + msg = f"Operation '{operation}' timed out after {timeout_seconds}s" + super().__init__( + message=msg, + error_code=self.error_code, + context={"operation": operation, "timeout_seconds": timeout_seconds}, + suggestion="Consider increasing timeout or optimizing the operation", + **kwargs + ) + + +class RollModelError(RollError): + """Model related errors.""" + + error_category = "MODEL" + error_code = 3000 + + +class RollModelLoadError(RollModelError): + """Model loading failed.""" + + error_code = 3001 + + def __init__( + self, + model_path: str, + reason: str, + **kwargs + ): + self.model_path = model_path + msg = f"Failed to load model from '{model_path}': {reason}" + super().__init__( + message=msg, + error_code=self.error_code, + context={"model_path": model_path, "reason": reason}, + suggestion="Check model path and format compatibility", + **kwargs + ) + + +class RollModelUpdateError(RollModelError): + """Model weight update failed.""" + + error_code = 3002 + + def __init__( + self, + src_worker: str, + dst_worker: str, + reason: str, + **kwargs + ): + self.src_worker = src_worker + self.dst_worker = dst_worker + msg = f"Failed to update model weights from {src_worker} to {dst_worker}: {reason}" + super().__init__( + message=msg, + error_code=self.error_code, + context={"src_worker": src_worker, "dst_worker": dst_worker, "reason": reason}, + suggestion="Check model architecture compatibility and GPU memory", + **kwargs + ) + + +class RollOOMError(RollModelError): + """Out of memory error with context.""" + + error_code = 3003 + + def __init__( + self, + operation: str, + allocated_gb: Optional[float] = None, + total_gb: Optional[float] = None, + **kwargs + ): + self.operation = operation + self.allocated_gb = allocated_gb + self.total_gb = total_gb + msg = f"Out of memory during '{operation}'" + if allocated_gb and total_gb: + msg += f" (allocated: {allocated_gb:.2f}GB / total: {total_gb:.2f}GB)" + super().__init__( + message=msg, + error_code=self.error_code, + context={"operation": operation, "allocated_gb": allocated_gb, "total_gb": total_gb}, + suggestion="Try reducing batch size, enabling gradient checkpointing, or using offload", + **kwargs + ) + + +class RollDataError(RollError): + """Data related errors.""" + + error_category = "DATA" + error_code = 4000 + + +class RollDataLoadError(RollDataError): + """Data loading failed.""" + + error_code = 4001 + + def __init__( + self, + data_path: str, + reason: str, + **kwargs + ): + self.data_path = data_path + msg = f"Failed to load data from '{data_path}': {reason}" + super().__init__( + message=msg, + error_code=self.error_code, + context={"data_path": data_path, "reason": reason}, + suggestion="Check data path and format", + **kwargs + ) + + +class RollDataFormatError(RollDataError): + """Data format is invalid.""" + + error_code = 4002 + + def __init__( + self, + expected_format: str, + actual_format: Optional[str] = None, + sample_index: Optional[int] = None, + **kwargs + ): + self.expected_format = expected_format + self.actual_format = actual_format + self.sample_index = sample_index + msg = f"Invalid data format, expected {expected_format}" + if actual_format: + msg += f", got {actual_format}" + if sample_index is not None: + msg += f" at sample index {sample_index}" + super().__init__( + message=msg, + error_code=self.error_code, + context={"expected_format": expected_format, "actual_format": actual_format, "sample_index": sample_index}, + suggestion=f"Ensure data follows {expected_format} format", + **kwargs + ) + + +class RollPipelineError(RollError): + """Pipeline related errors.""" + + error_category = "PIPELINE" + error_code = 5000 + + +class RollPipelineInitError(RollPipelineError): + """Pipeline initialization failed.""" + + error_code = 5001 + + def __init__( + self, + pipeline_name: str, + reason: str, + **kwargs + ): + self.pipeline_name = pipeline_name + msg = f"Failed to initialize pipeline '{pipeline_name}': {reason}" + super().__init__( + message=msg, + error_code=self.error_code, + context={"pipeline_name": pipeline_name, "reason": reason}, + suggestion="Check pipeline configuration and dependencies", + **kwargs + ) + + +class RollPipelineStepError(RollPipelineError): + """Pipeline step execution failed.""" + + error_code = 5002 + + def __init__( + self, + step_name: str, + step_index: int, + reason: str, + **kwargs + ): + self.step_name = step_name + self.step_index = step_index + msg = f"Pipeline step '{step_name}' (index {step_index}) failed: {reason}" + super().__init__( + message=msg, + error_code=self.error_code, + context={"step_name": step_name, "step_index": step_index, "reason": reason}, + suggestion="Check step configuration and input data", + **kwargs + ) + + +class RollCheckpointError(RollPipelineError): + """Checkpoint save/load failed.""" + + error_code = 5003 + + def __init__( + self, + operation: str, + checkpoint_path: str, + reason: str, + **kwargs + ): + self.operation = operation + self.checkpoint_path = checkpoint_path + msg = f"Checkpoint {operation} failed for '{checkpoint_path}': {reason}" + super().__init__( + message=msg, + error_code=self.error_code, + context={"operation": operation, "checkpoint_path": checkpoint_path, "reason": reason}, + suggestion="Check disk space and write permissions", + **kwargs + ) + + +class RollEnvironmentError(RollError): + """Environment related errors.""" + + error_category = "ENV" + error_code = 6000 + + +class RollEnvInitError(RollEnvironmentError): + """Environment initialization failed.""" + + error_code = 6001 + + def __init__( + self, + env_name: str, + reason: str, + **kwargs + ): + self.env_name = env_name + msg = f"Failed to initialize environment '{env_name}': {reason}" + super().__init__( + message=msg, + error_code=self.error_code, + context={"env_name": env_name, "reason": reason}, + suggestion="Check environment configuration and dependencies", + **kwargs + ) + + +class RollEnvStepError(RollEnvironmentError): + """Environment step failed.""" + + error_code = 6002 + + def __init__( + self, + env_name: str, + action: Any, + reason: str, + **kwargs + ): + self.env_name = env_name + self.action = action + msg = f"Environment '{env_name}' step failed with action {action}: {reason}" + super().__init__( + message=msg, + error_code=self.error_code, + context={"env_name": env_name, "action": str(action)[:50], "reason": reason}, + suggestion="Check action validity and environment state", + **kwargs + ) + + +ERROR_CODE_MAP = { + cls.error_code: cls + for cls in [ + RollConfigValidationError, + RollConfigMissingError, + RollConfigConflictError, + RollWorkerInitError, + RollCommunicationError, + RollTimeoutError, + RollModelLoadError, + RollModelUpdateError, + RollOOMError, + RollDataLoadError, + RollDataFormatError, + RollPipelineInitError, + RollPipelineStepError, + RollCheckpointError, + RollEnvInitError, + RollEnvStepError, + ] +} + + +def get_exception_by_code(error_code: int) -> Optional[type]: + """Get exception class by error code.""" + return ERROR_CODE_MAP.get(error_code) + + +def format_error_for_logging(error: RollError) -> Dict[str, Any]: + """Format error for structured logging.""" + return { + "error_code": error.code, + "error_category": error.error_category, + "message": error.message, + "context": error.context, + "suggestion": error.suggestion, + "exception_type": type(error).__name__, + } diff --git a/tests/utils/conftest.py b/tests/utils/conftest.py new file mode 100644 index 000000000..8d52d0c7a --- /dev/null +++ b/tests/utils/conftest.py @@ -0,0 +1,39 @@ +""" +Standalone pytest configuration for testing without torch dependency. + +This conftest.py allows running unit tests in the utils directory +without requiring torch to be installed. +""" + +import sys +import pytest +from pathlib import Path + +project_root = Path(__file__).parent.parent.parent +sys.path.insert(0, str(project_root)) + + +def pytest_configure(config): + """Configure pytest with custom markers.""" + config.addinivalue_line("markers", "unit: mark test as unit test") + config.addinivalue_line("markers", "integration: mark test as integration test") + config.addinivalue_line("markers", "slow: mark test as slow running") + config.addinivalue_line("markers", "gpu: mark test as requiring GPU") + config.addinivalue_line("markers", "distributed: mark test as requiring distributed environment") + + +def pytest_collection_modifyitems(config, items): + """Skip tests based on markers and available dependencies.""" + skip_gpu = pytest.mark.skip(reason="GPU not available") + skip_distributed = pytest.mark.skip(reason="Distributed environment not available") + + try: + import torch + torch_available = True + except ImportError: + torch_available = False + + for item in items: + if not torch_available: + for marker in item.iter_markers(name="gpu"): + item.add_marker(pytest.mark.skip(reason="torch not available")) diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py new file mode 100644 index 000000000..88bdb8816 --- /dev/null +++ b/tests/utils/test_exceptions.py @@ -0,0 +1,443 @@ +""" +Unit tests for ROLL custom exceptions. + +This module tests the custom exception classes defined in roll/utils/exceptions.py. +""" + +import pytest + +from roll.utils.exceptions import ( + RollError, + RollConfigError, + RollConfigValidationError, + RollConfigMissingError, + RollConfigConflictError, + RollDistributedError, + RollWorkerInitError, + RollCommunicationError, + RollTimeoutError, + RollModelError, + RollModelLoadError, + RollModelUpdateError, + RollOOMError, + RollDataError, + RollDataLoadError, + RollDataFormatError, + RollPipelineError, + RollPipelineInitError, + RollPipelineStepError, + RollCheckpointError, + RollEnvironmentError, + RollEnvInitError, + RollEnvStepError, + get_exception_by_code, + format_error_for_logging, + ERROR_CODE_MAP, +) + + +class TestRollConfigValidationError: + """Tests for RollConfigValidationError.""" + + def test_basic_validation_error(self): + """Test basic validation error creation.""" + error = RollConfigValidationError( + field_name="test_field", + expected_type="positive integer", + actual_value="-1" + ) + + assert error.code == 1001 + assert error.error_category == "CONFIG" + assert "test_field" in error.message + assert error.field_name == "test_field" + assert error.expected_type == "positive integer" + assert error.actual_value == "-1" + + def test_validation_error_with_custom_message(self): + """Test validation error with custom message.""" + error = RollConfigValidationError( + field_name="batch_size", + expected_type="positive integer", + actual_value="0", + message="batch_size must be greater than 0" + ) + + assert "batch_size must be greater than 0" in error.message + assert error.context["field_name"] == "batch_size" + + def test_to_dict(self): + """Test to_dict method.""" + error = RollConfigValidationError( + field_name="learning_rate", + expected_type="float", + actual_value="invalid" + ) + + d = error.to_dict() + + assert d["error_code"] == 1001 + assert d["error_category"] == "CONFIG" + assert "learning_rate" in d["message"] + assert d["context"]["field_name"] == "learning_rate" + assert d["suggestion"] is not None + + def test_str_representation(self): + """Test string representation.""" + error = RollConfigValidationError( + field_name="test", + expected_type="int", + actual_value="str" + ) + + s = str(error) + + assert "[CONFIG-1001]" in s + assert "test" in s + + +class TestRollConfigConflictError: + """Tests for RollConfigConflictError.""" + + def test_basic_conflict_error(self): + """Test basic conflict error creation.""" + error = RollConfigConflictError( + field1="field_a", + field2="field_b", + reason="they are mutually exclusive" + ) + + assert error.code == 1003 + assert error.field1 == "field_a" + assert error.field2 == "field_b" + assert "mutually exclusive" in error.message + + def test_conflict_error_context(self): + """Test conflict error context.""" + error = RollConfigConflictError( + field1="use_gpu", + field2="use_cpu", + reason="cannot enable both GPU and CPU mode" + ) + + assert error.context["field1"] == "use_gpu" + assert error.context["field2"] == "use_cpu" + assert error.context["reason"] == "cannot enable both GPU and CPU mode" + + +class TestRollDistributedError: + """Tests for RollDistributedError.""" + + def test_basic_distributed_error(self): + """Test basic distributed error creation.""" + error = RollDistributedError( + "Worker initialization failed", + error_code=2001 + ) + + assert error.code == 2001 + assert error.error_category == "DISTRIBUTED" + assert "Worker initialization failed" in error.message + + def test_distributed_error_with_context(self): + """Test distributed error with context.""" + error = RollDistributedError( + "Communication timeout", + error_code=2002, + context={"rank": 0, "timeout": 30} + ) + + assert error.context["rank"] == 0 + assert error.context["timeout"] == 30 + + +class TestRollWorkerInitError: + """Tests for RollWorkerInitError.""" + + def test_worker_init_error(self): + """Test worker init error creation.""" + error = RollWorkerInitError( + worker_name="actor_train", + reason="CUDA out of memory", + rank=0 + ) + + assert error.code == 2001 + assert error.worker_name == "actor_train" + assert error.rank == 0 + assert "CUDA out of memory" in error.message + + +class TestRollModelError: + """Tests for RollModelError.""" + + def test_basic_model_error(self): + """Test basic model error creation.""" + error = RollModelError( + "Model loading failed", + error_code=3001 + ) + + assert error.code == 3001 + assert error.error_category == "MODEL" + + def test_model_load_error(self): + """Test model load error creation.""" + error = RollModelLoadError( + model_path="/path/to/model", + reason="File not found" + ) + + assert error.code == 3001 + assert error.model_path == "/path/to/model" + assert "File not found" in error.message + + def test_oom_error(self): + """Test OOM error creation.""" + error = RollOOMError( + operation="forward pass", + allocated_gb=24.5, + total_gb=32.0 + ) + + assert error.code == 3003 + assert error.operation == "forward pass" + assert "24.5" in error.message + assert "32.0" in error.message + + +class TestRollDataError: + """Tests for RollDataError.""" + + def test_basic_data_error(self): + """Test basic data error creation.""" + error = RollDataError( + "Invalid data format", + error_code=4001 + ) + + assert error.code == 4001 + assert error.error_category == "DATA" + + def test_data_load_error(self): + """Test data load error creation.""" + error = RollDataLoadError( + data_path="/path/to/data.json", + reason="Invalid JSON format" + ) + + assert error.code == 4001 + assert error.data_path == "/path/to/data.json" + + def test_data_format_error(self): + """Test data format error creation.""" + error = RollDataFormatError( + expected_format="dict", + actual_format="list", + sample_index=42 + ) + + assert error.code == 4002 + assert error.expected_format == "dict" + assert error.actual_format == "list" + assert error.sample_index == 42 + + +class TestRollPipelineError: + """Tests for RollPipelineError.""" + + def test_basic_pipeline_error(self): + """Test basic pipeline error creation.""" + error = RollPipelineError( + "Pipeline step failed", + error_code=5001 + ) + + assert error.code == 5001 + assert error.error_category == "PIPELINE" + + def test_pipeline_init_error(self): + """Test pipeline init error creation.""" + error = RollPipelineInitError( + pipeline_name="RLVRPipeline", + reason="Missing configuration" + ) + + assert error.code == 5001 + assert error.pipeline_name == "RLVRPipeline" + + def test_checkpoint_error(self): + """Test checkpoint error creation.""" + error = RollCheckpointError( + operation="save", + checkpoint_path="/path/to/checkpoint", + reason="Disk full" + ) + + assert error.code == 5003 + assert error.operation == "save" + assert error.checkpoint_path == "/path/to/checkpoint" + + +class TestRollEnvironmentError: + """Tests for RollEnvironmentError.""" + + def test_basic_environment_error(self): + """Test basic environment error creation.""" + error = RollEnvironmentError( + "Environment step failed", + error_code=6001 + ) + + assert error.code == 6001 + assert error.error_category == "ENV" + + def test_env_init_error(self): + """Test environment init error creation.""" + error = RollEnvInitError( + env_name="SokobanEnv", + reason="Missing dependency" + ) + + assert error.code == 6001 + assert error.env_name == "SokobanEnv" + + def test_env_step_error(self): + """Test environment step error creation.""" + error = RollEnvStepError( + env_name="FrozenLake", + action=3, + reason="Invalid action" + ) + + assert error.code == 6002 + assert error.env_name == "FrozenLake" + assert error.action == 3 + + +class TestUtilityFunctions: + """Tests for utility functions.""" + + def test_get_exception_by_code(self): + """Test get_exception_by_code function.""" + cls = get_exception_by_code(1001) + assert cls == RollConfigValidationError + + cls = get_exception_by_code(1003) + assert cls == RollConfigConflictError + + cls = get_exception_by_code(9999) + assert cls is None + + def test_format_error_for_logging(self): + """Test format_error_for_logging function.""" + error = RollConfigValidationError( + field_name="test", + expected_type="int", + actual_value="str" + ) + + d = format_error_for_logging(error) + + assert d["error_code"] == 1001 + assert d["error_category"] == "CONFIG" + assert d["exception_type"] == "RollConfigValidationError" + assert "timestamp" not in d # Should not have timestamp + + def test_error_code_map_completeness(self): + """Test that all exceptions are in ERROR_CODE_MAP.""" + expected_codes = [ + 1001, 1002, 1003, # Config errors + 2001, 2002, 2003, # Distributed errors + 3001, 3002, 3003, # Model errors + 4001, 4002, # Data errors + 5001, 5002, 5003, # Pipeline errors + 6001, 6002, # Environment errors + ] + + for code in expected_codes: + assert code in ERROR_CODE_MAP, f"Error code {code} not in ERROR_CODE_MAP" + + +class TestExceptionInheritance: + """Tests for exception inheritance.""" + + def test_config_errors_inherit_from_roll_error(self): + """Test that config errors inherit from RollError.""" + assert issubclass(RollConfigError, RollError) + assert issubclass(RollConfigValidationError, RollConfigError) + assert issubclass(RollConfigMissingError, RollConfigError) + assert issubclass(RollConfigConflictError, RollConfigError) + + def test_distributed_errors_inherit_from_roll_error(self): + """Test that distributed errors inherit from RollError.""" + assert issubclass(RollDistributedError, RollError) + assert issubclass(RollWorkerInitError, RollDistributedError) + assert issubclass(RollCommunicationError, RollDistributedError) + assert issubclass(RollTimeoutError, RollDistributedError) + + def test_model_errors_inherit_from_roll_error(self): + """Test that model errors inherit from RollError.""" + assert issubclass(RollModelError, RollError) + assert issubclass(RollModelLoadError, RollModelError) + assert issubclass(RollModelUpdateError, RollModelError) + assert issubclass(RollOOMError, RollModelError) + + def test_data_errors_inherit_from_roll_error(self): + """Test that data errors inherit from RollError.""" + assert issubclass(RollDataError, RollError) + assert issubclass(RollDataLoadError, RollDataError) + assert issubclass(RollDataFormatError, RollDataError) + + def test_pipeline_errors_inherit_from_roll_error(self): + """Test that pipeline errors inherit from RollError.""" + assert issubclass(RollPipelineError, RollError) + assert issubclass(RollPipelineInitError, RollPipelineError) + assert issubclass(RollPipelineStepError, RollPipelineError) + assert issubclass(RollCheckpointError, RollPipelineError) + + def test_environment_errors_inherit_from_roll_error(self): + """Test that environment errors inherit from RollError.""" + assert issubclass(RollEnvironmentError, RollError) + assert issubclass(RollEnvInitError, RollEnvironmentError) + assert issubclass(RollEnvStepError, RollEnvironmentError) + + +class TestExceptionRaising: + """Tests for exception raising and catching.""" + + def test_catch_base_exception(self): + """Test catching derived exception with base class.""" + with pytest.raises(RollError): + raise RollConfigValidationError( + field_name="test", + expected_type="int", + actual_value="str" + ) + + def test_catch_config_exception(self): + """Test catching derived exception with config base class.""" + with pytest.raises(RollConfigError): + raise RollConfigConflictError( + field1="a", + field2="b", + reason="conflict" + ) + + def test_exception_message_contains_context(self): + """Test that exception message contains context.""" + error = RollConfigValidationError( + field_name="batch_size", + expected_type="positive integer", + actual_value="-5" + ) + + msg = str(error) + + assert "batch_size" in msg + assert "positive integer" in msg + assert "-5" in msg + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])