From e62a1eb167e75942fca4edc4b36439393e52025d Mon Sep 17 00:00:00 2001
From: Minh Vu
Date: Sun, 28 Jun 2026 18:03:59 +0200
Subject: [PATCH] Fail closed in Qwen3Guard

---
Reviewer notes (this section is ignored by git am):
- Line structure was restored from a whitespace-collapsed copy of this patch.
  Context-line indentation and the position of the blank context line in the
  first qwen3guard.py hunk are inferred from the hunk header counts --
  TODO confirm against blob 89d63df before applying.
- qwen3guard_test.py: the fixture now removes the target module with
  monkeypatch.delitem(..., raising=False) instead of sys.modules.pop(...), so
  a previously imported real module is restored at teardown. The blob id
  4b0d8b2 on the index line predates this one-line edit.

 .../guardrail/qwen3guard/qwen3guard.py        |   7 +-
 .../guardrail/qwen3guard/qwen3guard_test.py   | 123 ++++++++++++++++++
 2 files changed, 128 insertions(+), 2 deletions(-)
 create mode 100644 cosmos_framework/auxiliary/guardrail/qwen3guard/qwen3guard_test.py

diff --git a/cosmos_framework/auxiliary/guardrail/qwen3guard/qwen3guard.py b/cosmos_framework/auxiliary/guardrail/qwen3guard/qwen3guard.py
index 89d63df..396a6e3 100644
--- a/cosmos_framework/auxiliary/guardrail/qwen3guard/qwen3guard.py
+++ b/cosmos_framework/auxiliary/guardrail/qwen3guard/qwen3guard.py
@@ -53,9 +53,12 @@ def extract_label_and_categories(self, prompt):
         output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()
 
         content = self.tokenizer.decode(output_ids, skip_special_tokens=True)
-        safe_label_match = re.search(safe_pattern, content)
+        safe_label_match = re.search(safe_pattern, content, flags=re.IGNORECASE)
         label = safe_label_match.group(1) if safe_label_match else None
         categories = re.findall(category_pattern, content)
+        if label is None:
+            log.warning("Unable to parse safety label from Qwen3Guard output; treating prompt as unsafe")
+            return False, "Prompt blocked by Qwen3Guard. Unable to determine safety label from moderation output."
         if label.lower() == "unsafe":
             return False, f"Prompt blocked by Qwen3Guard. Safety: {label}, Categories: {categories}"
         else:
@@ -67,7 +70,7 @@ def is_safe(self, prompt: str) -> tuple[bool, str]:
             return self.extract_label_and_categories(prompt)
         except Exception as e:
             log.error(f"Unexpected error occurred when running Qwen3Guard guardrail: {e}")
-            return True, "Unexpected error occurred when running Qwen3Guard guardrail."
+            return False, "Prompt blocked by Qwen3Guard because the guardrail could not complete safety evaluation."
 
 
 def parse_args():
diff --git a/cosmos_framework/auxiliary/guardrail/qwen3guard/qwen3guard_test.py b/cosmos_framework/auxiliary/guardrail/qwen3guard/qwen3guard_test.py
new file mode 100644
index 0000000..4b0d8b2
--- /dev/null
+++ b/cosmos_framework/auxiliary/guardrail/qwen3guard/qwen3guard_test.py
@@ -0,0 +1,123 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: OpenMDW-1.1
+
+"""Unit tests for Qwen3Guard safety-fallback behavior."""
+
+from __future__ import annotations
+
+import importlib
+import sys
+import types
+
+import pytest
+
+pytestmark = [pytest.mark.level(0), pytest.mark.gpus(0)]
+
+
+@pytest.fixture
+def qwen3guard_module(monkeypatch: pytest.MonkeyPatch):
+    module_name = "cosmos_framework.auxiliary.guardrail.qwen3guard.qwen3guard"
+    fake_torch = types.ModuleType("torch")
+    fake_torch.bfloat16 = "bfloat16"
+    fake_transformers = types.ModuleType("transformers")
+    fake_transformers.AutoModelForCausalLM = object
+    fake_transformers.AutoTokenizer = object
+    fake_core = types.ModuleType("cosmos_framework.auxiliary.guardrail.common.core")
+    fake_core.ContentSafetyGuardrail = type("ContentSafetyGuardrail", (), {})
+    fake_core.GuardrailRunner = type("GuardrailRunner", (), {})
+    fake_log = types.ModuleType("cosmos_framework.utils.log")
+    fake_log.debug = lambda *args, **kwargs: None
+    fake_log.warning = lambda *args, **kwargs: None
+    fake_log.error = lambda *args, **kwargs: None
+    fake_misc = types.ModuleType("cosmos_framework.utils.misc")
+
+    class _Color:
+        @staticmethod
+        def green(value: str) -> str:
+            return value
+
+        @staticmethod
+        def red(value: str) -> str:
+            return value
+
+    fake_misc.Color = _Color
+
+    monkeypatch.setitem(sys.modules, "torch", fake_torch)
+    monkeypatch.setitem(sys.modules, "transformers", fake_transformers)
+    monkeypatch.setitem(sys.modules, "cosmos_framework.auxiliary.guardrail.common.core", fake_core)
+    monkeypatch.setitem(sys.modules, "cosmos_framework.utils.log", fake_log)
+    monkeypatch.setitem(sys.modules, "cosmos_framework.utils.misc", fake_misc)
+    monkeypatch.delitem(sys.modules, module_name, raising=False)  # any real import is restored at teardown
+
+    module = importlib.import_module(module_name)
+    yield module
+
+    sys.modules.pop(module_name, None)  # drop the copy that was imported against the stubs
+
+
+class _DummySequence:
+    def __init__(self, values: list[int]) -> None:
+        self.values = values
+
+    def __getitem__(self, item):
+        if isinstance(item, slice):
+            return _DummySequence(self.values[item])
+        return self.values[item]
+
+    def tolist(self) -> list[int]:
+        return list(self.values)
+
+
+class _DummyInputs(dict):
+    def __init__(self) -> None:
+        super().__init__(input_ids=[[1, 2, 3]])
+        self.input_ids = [[1, 2, 3]]
+
+    def to(self, device: str):
+        return self
+
+
+class _DummyTokenizer:
+    def __init__(self, moderation_output: str) -> None:
+        self.moderation_output = moderation_output
+
+    def apply_chat_template(self, messages, tokenize: bool = False) -> str:
+        return "prompt"
+
+    def __call__(self, texts, return_tensors: str = "pt") -> _DummyInputs:
+        return _DummyInputs()
+
+    def decode(self, output_ids, skip_special_tokens: bool = True) -> str:
+        return self.moderation_output
+
+
+class _DummyModel:
+    device = "cpu"
+
+    def generate(self, **kwargs):
+        return [_DummySequence([1, 2, 3, 4, 5])]
+
+
+def test_extract_label_and_categories_blocks_unparseable_output(qwen3guard_module) -> None:
+    guard = object.__new__(qwen3guard_module.Qwen3Guard)
+    guard.tokenizer = _DummyTokenizer("moderation output without a recognized safety label")
+    guard.model = _DummyModel()
+
+    safe, message = guard.extract_label_and_categories("hello")
+
+    assert safe is False
+    assert message == "Prompt blocked by Qwen3Guard. Unable to determine safety label from moderation output."
+
+
+def test_is_safe_blocks_when_guardrail_raises(qwen3guard_module) -> None:
+    guard = object.__new__(qwen3guard_module.Qwen3Guard)
+
+    def _raise(prompt: str) -> tuple[bool, str]:
+        raise RuntimeError("boom")
+
+    guard.extract_label_and_categories = _raise
+
+    safe, message = guard.is_safe("hello")
+
+    assert safe is False
+    assert message == "Prompt blocked by Qwen3Guard because the guardrail could not complete safety evaluation."