Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions cosmos_framework/callbacks/hf_export_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: OpenMDW-1.1

"""Hermetic tests for HF export worker error propagation."""

from __future__ import annotations

import importlib
import sys
import types

import pytest

pytestmark = [pytest.mark.level(0), pytest.mark.gpus(0)]


@pytest.fixture
def hf_export_module(monkeypatch: pytest.MonkeyPatch):
module_name = "cosmos_framework.callbacks.hf_export"

fake_torch = types.ModuleType("torch")
fake_torch.float32 = "float32"
fake_torch.float16 = "float16"
fake_torch.bfloat16 = "bfloat16"
fake_torch.float64 = "float64"
fake_torch.dtype = object
fake_torch.Tensor = object
fake_torch.distributed = types.SimpleNamespace(tensor=types.SimpleNamespace(DTensor=type("DTensor", (), {})))

fake_log = types.ModuleType("cosmos_framework.utils.log")
fake_log.error = lambda *args, **kwargs: None
fake_log.info = lambda *args, **kwargs: None
fake_log.warning = lambda *args, **kwargs: None

fake_callback = types.ModuleType("cosmos_framework.utils.callback")
fake_callback.Callback = type("Callback", (), {})

fake_distributed = types.ModuleType("cosmos_framework.utils.distributed")
fake_distributed.is_rank0 = lambda: True

monkeypatch.setitem(sys.modules, "torch", fake_torch)
monkeypatch.setitem(sys.modules, "cosmos_framework.utils.log", fake_log)
monkeypatch.setitem(sys.modules, "cosmos_framework.utils.callback", fake_callback)
monkeypatch.setitem(sys.modules, "cosmos_framework.utils.distributed", fake_distributed)
sys.modules.pop(module_name, None)

module = importlib.import_module(module_name)
yield module

sys.modules.pop(module_name, None)


def test_save_and_upload_stores_worker_exception_when_export_fails(hf_export_module) -> None:
callback = hf_export_module.HFExportCallback()

def _raise(*args, **kwargs) -> None:
raise RuntimeError("worker failed")

callback._do_save_and_upload = _raise

callback._save_and_upload([], {}, 0, None, "model", "/tmp/export", 12)

assert isinstance(callback._worker_exception, RuntimeError)
assert str(callback._worker_exception) == "worker failed"
61 changes: 45 additions & 16 deletions cosmos_framework/utils/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,36 +108,65 @@ def _rank0_only_filter(record: Any) -> bool:
return not is_rank0


def trace(message: str, rank0_only: bool = True) -> None:
logger.opt(depth=1).bind(rank0_only=rank0_only).trace(message)
def _prepare_log_message(
message: str, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
if args:
try:
return message % args, (), kwargs
except (TypeError, ValueError):
pass
if kwargs:
try:
return message % kwargs, (), {}
except (KeyError, TypeError, ValueError):
pass
return message, args, kwargs


def debug(message: str, rank0_only: bool = True) -> None:
logger.opt(depth=1).bind(rank0_only=rank0_only).debug(message)
def _log_with_rank(
level: str,
message: str,
*args: Any,
rank0_only: bool = True,
exc_info: Any = None,
**kwargs: Any,
) -> None:
message, args, kwargs = _prepare_log_message(message, args, kwargs)
bound_logger = logger.opt(depth=1, exception=exc_info).bind(rank0_only=rank0_only)
getattr(bound_logger, level)(message, *args, **kwargs)


def info(message: str, rank0_only: bool = True) -> None:
logger.opt(depth=1).bind(rank0_only=rank0_only).info(message)
def trace(message: str, *args: Any, rank0_only: bool = True, exc_info: Any = None, **kwargs: Any) -> None:
_log_with_rank("trace", message, *args, rank0_only=rank0_only, exc_info=exc_info, **kwargs)


def success(message: str, rank0_only: bool = True) -> None:
logger.opt(depth=1).bind(rank0_only=rank0_only).success(message)
def debug(message: str, *args: Any, rank0_only: bool = True, exc_info: Any = None, **kwargs: Any) -> None:
_log_with_rank("debug", message, *args, rank0_only=rank0_only, exc_info=exc_info, **kwargs)


def warning(message: str, rank0_only: bool = True) -> None:
logger.opt(depth=1).bind(rank0_only=rank0_only).warning(message)
def info(message: str, *args: Any, rank0_only: bool = True, exc_info: Any = None, **kwargs: Any) -> None:
_log_with_rank("info", message, *args, rank0_only=rank0_only, exc_info=exc_info, **kwargs)


def error(message: str, rank0_only: bool = True) -> None:
logger.opt(depth=1).bind(rank0_only=rank0_only).error(message)
def success(message: str, *args: Any, rank0_only: bool = True, exc_info: Any = None, **kwargs: Any) -> None:
_log_with_rank("success", message, *args, rank0_only=rank0_only, exc_info=exc_info, **kwargs)


def critical(message: str, rank0_only: bool = True) -> None:
logger.opt(depth=1).bind(rank0_only=rank0_only).critical(message)
def warning(message: str, *args: Any, rank0_only: bool = True, exc_info: Any = None, **kwargs: Any) -> None:
_log_with_rank("warning", message, *args, rank0_only=rank0_only, exc_info=exc_info, **kwargs)


def exception(message: str, rank0_only: bool = True) -> None:
logger.opt(depth=1).bind(rank0_only=rank0_only).exception(message)
def error(message: str, *args: Any, rank0_only: bool = True, exc_info: Any = None, **kwargs: Any) -> None:
_log_with_rank("error", message, *args, rank0_only=rank0_only, exc_info=exc_info, **kwargs)


def critical(message: str, *args: Any, rank0_only: bool = True, exc_info: Any = None, **kwargs: Any) -> None:
_log_with_rank("critical", message, *args, rank0_only=rank0_only, exc_info=exc_info, **kwargs)


def exception(message: str, *args: Any, rank0_only: bool = True, exc_info: Any = True, **kwargs: Any) -> None:
_log_with_rank("exception", message, *args, rank0_only=rank0_only, exc_info=exc_info, **kwargs)


# Execute at import time.
Expand Down
143 changes: 143 additions & 0 deletions cosmos_framework/utils/log_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: OpenMDW-1.1

"""Hermetic tests for the project logging wrappers."""

from __future__ import annotations

import importlib
import sys
import types

import pytest

pytestmark = [pytest.mark.level(0), pytest.mark.gpus(0)]


class _CaptureLogger:
def __init__(self, *args, events: list[dict] | None = None, **kwargs) -> None:
self.events = [] if events is None else events
self._options = (None, None, [], {})
self._exception = None
self._rank0_only = True

def remove(self, *args, **kwargs) -> None:
return None

def add(self, *args, **kwargs) -> int:
return 0

def opt(self, *, depth=None, exception=None):
clone = _CaptureLogger(events=self.events)
clone._exception = exception
clone._rank0_only = self._rank0_only
return clone

def bind(self, **kwargs):
clone = _CaptureLogger(events=self.events)
clone._exception = self._exception
clone._rank0_only = kwargs.get("rank0_only", self._rank0_only)
return clone

def _record(self, level: str, message: str, *args, **kwargs) -> None:
self.events.append(
{
"level": level,
"message": message,
"args": args,
"kwargs": kwargs,
"exception": self._exception,
"rank0_only": self._rank0_only,
}
)

def trace(self, message: str, *args, **kwargs) -> None:
self._record("trace", message, *args, **kwargs)

def debug(self, message: str, *args, **kwargs) -> None:
self._record("debug", message, *args, **kwargs)

def info(self, message: str, *args, **kwargs) -> None:
self._record("info", message, *args, **kwargs)

def success(self, message: str, *args, **kwargs) -> None:
self._record("success", message, *args, **kwargs)

def warning(self, message: str, *args, **kwargs) -> None:
self._record("warning", message, *args, **kwargs)

def error(self, message: str, *args, **kwargs) -> None:
self._record("error", message, *args, **kwargs)

def critical(self, message: str, *args, **kwargs) -> None:
self._record("critical", message, *args, **kwargs)

def exception(self, message: str, *args, **kwargs) -> None:
self._record("exception", message, *args, **kwargs)


@pytest.fixture
def log_module(monkeypatch: pytest.MonkeyPatch):
module_name = "cosmos_framework.utils.log"

fake_torch = types.ModuleType("torch")
fake_dist = types.ModuleType("torch.distributed")
fake_dist.is_available = lambda: False
fake_dist.is_initialized = lambda: False
fake_torch.distributed = fake_dist

fake_loguru = types.ModuleType("loguru")
fake_loguru_logger = types.ModuleType("loguru._logger")
fake_loguru_logger.Core = type("Core", (), {})
fake_loguru_logger.Logger = _CaptureLogger

monkeypatch.setitem(sys.modules, "torch", fake_torch)
monkeypatch.setitem(sys.modules, "torch.distributed", fake_dist)
monkeypatch.setitem(sys.modules, "loguru", fake_loguru)
monkeypatch.setitem(sys.modules, "loguru._logger", fake_loguru_logger)
sys.modules.pop(module_name, None)

module = importlib.import_module(module_name)
module.logger = _CaptureLogger()
yield module

sys.modules.pop(module_name, None)


def test_info_supports_percent_style_formatting(log_module) -> None:
log_module.info("Wrote %d upsampled prompts to %s", 3, "output.json")

assert log_module.logger.events == [
{
"level": "info",
"message": "Wrote 3 upsampled prompts to output.json",
"args": (),
"kwargs": {},
"exception": None,
"rank0_only": True,
}
]


def test_error_supports_exc_info_and_rank0_override(log_module) -> None:
try:
raise ValueError("boom")
except ValueError:
log_module.error(
"[HFExportCallback] Export worker for iter %d raised an exception: %s",
7,
"boom",
exc_info=True,
rank0_only=False,
)

assert log_module.logger.events == [
{
"level": "error",
"message": "[HFExportCallback] Export worker for iter 7 raised an exception: boom",
"args": (),
"kwargs": {},
"exception": True,
"rank0_only": False,
}
]