diff --git a/cosmos_framework/callbacks/hf_export_test.py b/cosmos_framework/callbacks/hf_export_test.py new file mode 100644 index 0000000..3754482 --- /dev/null +++ b/cosmos_framework/callbacks/hf_export_test.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Hermetic tests for HF export worker error propagation.""" + +from __future__ import annotations + +import importlib +import sys +import types + +import pytest + +pytestmark = [pytest.mark.level(0), pytest.mark.gpus(0)] + + +@pytest.fixture +def hf_export_module(monkeypatch: pytest.MonkeyPatch): + module_name = "cosmos_framework.callbacks.hf_export" + + fake_torch = types.ModuleType("torch") + fake_torch.float32 = "float32" + fake_torch.float16 = "float16" + fake_torch.bfloat16 = "bfloat16" + fake_torch.float64 = "float64" + fake_torch.dtype = object + fake_torch.Tensor = object + fake_torch.distributed = types.SimpleNamespace(tensor=types.SimpleNamespace(DTensor=type("DTensor", (), {}))) + + fake_log = types.ModuleType("cosmos_framework.utils.log") + fake_log.error = lambda *args, **kwargs: None + fake_log.info = lambda *args, **kwargs: None + fake_log.warning = lambda *args, **kwargs: None + + fake_callback = types.ModuleType("cosmos_framework.utils.callback") + fake_callback.Callback = type("Callback", (), {}) + + fake_distributed = types.ModuleType("cosmos_framework.utils.distributed") + fake_distributed.is_rank0 = lambda: True + + monkeypatch.setitem(sys.modules, "torch", fake_torch) + monkeypatch.setitem(sys.modules, "cosmos_framework.utils.log", fake_log) + monkeypatch.setitem(sys.modules, "cosmos_framework.utils.callback", fake_callback) + monkeypatch.setitem(sys.modules, "cosmos_framework.utils.distributed", fake_distributed) + sys.modules.pop(module_name, None) + + module = importlib.import_module(module_name) + yield module + + sys.modules.pop(module_name, None) + + +def test_save_and_upload_stores_worker_exception_when_export_fails(hf_export_module) -> None: + callback = hf_export_module.HFExportCallback() + + def _raise(*args, **kwargs) -> None: + raise RuntimeError("worker failed") + + callback._do_save_and_upload = _raise + + callback._save_and_upload([], {}, 0, None, "model", "/tmp/export", 12) + + assert isinstance(callback._worker_exception, RuntimeError) + assert str(callback._worker_exception) == "worker failed" diff --git a/cosmos_framework/utils/log.py b/cosmos_framework/utils/log.py index 7736ef2..d82f70d 100644 --- a/cosmos_framework/utils/log.py +++ b/cosmos_framework/utils/log.py @@ -108,36 +108,65 @@ def _rank0_only_filter(record: Any) -> bool: return not is_rank0 -def trace(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).trace(message) +def _prepare_log_message( + message: str, args: tuple[Any, ...], kwargs: dict[str, Any] +) -> tuple[str, tuple[Any, ...], dict[str, Any]]: + if args: + try: + return message % args, (), kwargs + except (TypeError, ValueError): + pass + if kwargs: + try: + return message % kwargs, (), {} + except (KeyError, TypeError, ValueError): + pass + return message, args, kwargs -def debug(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).debug(message) +def _log_with_rank( + level: str, + message: str, + *args: Any, + rank0_only: bool = True, + exc_info: Any = None, + **kwargs: Any, +) -> None: + message, args, kwargs = _prepare_log_message(message, args, kwargs) + bound_logger = logger.opt(depth=1, exception=exc_info).bind(rank0_only=rank0_only) + getattr(bound_logger, level)(message, *args, **kwargs) -def info(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).info(message) +def trace(message: str, *args: Any, rank0_only: bool = True, exc_info: Any = None, **kwargs: Any) -> None: + _log_with_rank("trace", message, *args, rank0_only=rank0_only, exc_info=exc_info, **kwargs) -def success(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).success(message) +def debug(message: str, *args: Any, rank0_only: bool = True, exc_info: Any = None, **kwargs: Any) -> None: + _log_with_rank("debug", message, *args, rank0_only=rank0_only, exc_info=exc_info, **kwargs) -def warning(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).warning(message) +def info(message: str, *args: Any, rank0_only: bool = True, exc_info: Any = None, **kwargs: Any) -> None: + _log_with_rank("info", message, *args, rank0_only=rank0_only, exc_info=exc_info, **kwargs) -def error(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).error(message) +def success(message: str, *args: Any, rank0_only: bool = True, exc_info: Any = None, **kwargs: Any) -> None: + _log_with_rank("success", message, *args, rank0_only=rank0_only, exc_info=exc_info, **kwargs) -def critical(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).critical(message) +def warning(message: str, *args: Any, rank0_only: bool = True, exc_info: Any = None, **kwargs: Any) -> None: + _log_with_rank("warning", message, *args, rank0_only=rank0_only, exc_info=exc_info, **kwargs) -def exception(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).exception(message) +def error(message: str, *args: Any, rank0_only: bool = True, exc_info: Any = None, **kwargs: Any) -> None: + _log_with_rank("error", message, *args, rank0_only=rank0_only, exc_info=exc_info, **kwargs) + + +def critical(message: str, *args: Any, rank0_only: bool = True, exc_info: Any = None, **kwargs: Any) -> None: + _log_with_rank("critical", message, *args, rank0_only=rank0_only, exc_info=exc_info, **kwargs) + + +def exception(message: str, *args: Any, rank0_only: bool = True, exc_info: Any = True, **kwargs: Any) -> None: + _log_with_rank("exception", message, *args, rank0_only=rank0_only, exc_info=exc_info, **kwargs) # Execute at import time. diff --git a/cosmos_framework/utils/log_test.py b/cosmos_framework/utils/log_test.py new file mode 100644 index 0000000..04fa5c2 --- /dev/null +++ b/cosmos_framework/utils/log_test.py @@ -0,0 +1,143 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Hermetic tests for the project logging wrappers.""" + +from __future__ import annotations + +import importlib +import sys +import types + +import pytest + +pytestmark = [pytest.mark.level(0), pytest.mark.gpus(0)] + + +class _CaptureLogger: + def __init__(self, *args, events: list[dict] | None = None, **kwargs) -> None: + self.events = [] if events is None else events + self._options = (None, None, [], {}) + self._exception = None + self._rank0_only = True + + def remove(self, *args, **kwargs) -> None: + return None + + def add(self, *args, **kwargs) -> int: + return 0 + + def opt(self, *, depth=None, exception=None): + clone = _CaptureLogger(events=self.events) + clone._exception = exception + clone._rank0_only = self._rank0_only + return clone + + def bind(self, **kwargs): + clone = _CaptureLogger(events=self.events) + clone._exception = self._exception + clone._rank0_only = kwargs.get("rank0_only", self._rank0_only) + return clone + + def _record(self, level: str, message: str, *args, **kwargs) -> None: + self.events.append( + { + "level": level, + "message": message, + "args": args, + "kwargs": kwargs, + "exception": self._exception, + "rank0_only": self._rank0_only, + } + ) + + def trace(self, message: str, *args, **kwargs) -> None: + self._record("trace", message, *args, **kwargs) + + def debug(self, message: str, *args, **kwargs) -> None: + self._record("debug", message, *args, **kwargs) + + def info(self, message: str, *args, **kwargs) -> None: + self._record("info", message, *args, **kwargs) + + def success(self, message: str, *args, **kwargs) -> None: + self._record("success", message, *args, **kwargs) + + def warning(self, message: str, *args, **kwargs) -> None: + self._record("warning", message, *args, **kwargs) + + def error(self, message: str, *args, **kwargs) -> None: + self._record("error", message, *args, **kwargs) + + def critical(self, message: str, *args, **kwargs) -> None: + self._record("critical", message, *args, **kwargs) + + def exception(self, message: str, *args, **kwargs) -> None: + self._record("exception", message, *args, **kwargs) + + +@pytest.fixture +def log_module(monkeypatch: pytest.MonkeyPatch): + module_name = "cosmos_framework.utils.log" + + fake_torch = types.ModuleType("torch") + fake_dist = types.ModuleType("torch.distributed") + fake_dist.is_available = lambda: False + fake_dist.is_initialized = lambda: False + fake_torch.distributed = fake_dist + + fake_loguru = types.ModuleType("loguru") + fake_loguru_logger = types.ModuleType("loguru._logger") + fake_loguru_logger.Core = type("Core", (), {}) + fake_loguru_logger.Logger = _CaptureLogger + + monkeypatch.setitem(sys.modules, "torch", fake_torch) + monkeypatch.setitem(sys.modules, "torch.distributed", fake_dist) + monkeypatch.setitem(sys.modules, "loguru", fake_loguru) + monkeypatch.setitem(sys.modules, "loguru._logger", fake_loguru_logger) + sys.modules.pop(module_name, None) + + module = importlib.import_module(module_name) + module.logger = _CaptureLogger() + yield module + + sys.modules.pop(module_name, None) + + +def test_info_supports_percent_style_formatting(log_module) -> None: + log_module.info("Wrote %d upsampled prompts to %s", 3, "output.json") + + assert log_module.logger.events == [ + { + "level": "info", + "message": "Wrote 3 upsampled prompts to output.json", + "args": (), + "kwargs": {}, + "exception": None, + "rank0_only": True, + } + ] + + +def test_error_supports_exc_info_and_rank0_override(log_module) -> None: + try: + raise ValueError("boom") + except ValueError: + log_module.error( + "[HFExportCallback] Export worker for iter %d raised an exception: %s", + 7, + "boom", + exc_info=True, + rank0_only=False, + ) + + assert log_module.logger.events == [ + { + "level": "error", + "message": "[HFExportCallback] Export worker for iter 7 raised an exception: boom", + "args": (), + "kwargs": {}, + "exception": True, + "rank0_only": False, + } + ]