From b96e044fcc9e3a118d826bd7228448410295a0d6 Mon Sep 17 00:00:00 2001 From: m-misiura Date: Wed, 21 Jan 2026 09:31:15 +0000 Subject: [PATCH] feat: add ShieldToModerationMixin to bridge Safety shields to OpenAI Moderation API Implements a reusable mixin that provides run_moderation functionality for all safety providers by delegating to their existing run_shield implementations. Fixes #4605 Signed-off-by: m-misiura --- .../safety/prompt_guard/prompt_guard.py | 12 +- .../remote/safety/bedrock/bedrock.py | 8 +- .../providers/remote/safety/nvidia/nvidia.py | 8 +- .../remote/safety/sambanova/sambanova.py | 12 +- src/llama_stack/providers/utils/safety.py | 114 +++++++++++ .../unit/providers/utils/test_safety_mixin.py | 184 ++++++++++++++++++ 6 files changed, 312 insertions(+), 26 deletions(-) create mode 100644 src/llama_stack/providers/utils/safety.py create mode 100644 tests/unit/providers/utils/test_safety_mixin.py diff --git a/src/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/src/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index fdfa0c7ba3..17aa36b390 100644 --- a/src/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/src/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -9,12 +9,13 @@ from llama_stack.core.utils.model_utils import model_local_dir from llama_stack.log import get_logger -from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str +from llama_stack.providers.utils.inference.prompt_adapter import ( + interleaved_content_as_str, +) +from llama_stack.providers.utils.safety import ShieldToModerationMixin from llama_stack_api import ( GetShieldRequest, - ModerationObject, OpenAIMessageParam, - RunModerationRequest, RunShieldRequest, RunShieldResponse, Safety, @@ -32,7 +33,7 @@ PROMPT_GUARD_MODEL = "Prompt-Guard-86M" -class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate): +class PromptGuardSafetyImpl(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate): shield_store: ShieldStore def __init__(self, config: PromptGuardConfig, _deps) -> None: @@ -59,9 +60,6 @@ async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse: return await self.shield.run(request.messages) - async def run_moderation(self, request: RunModerationRequest) -> ModerationObject: - raise NotImplementedError("run_moderation is not implemented for Prompt Guard") - class PromptGuardShield: def __init__( diff --git a/src/llama_stack/providers/remote/safety/bedrock/bedrock.py b/src/llama_stack/providers/remote/safety/bedrock/bedrock.py index 8086c637c8..2afdcb7fb8 100644 --- a/src/llama_stack/providers/remote/safety/bedrock/bedrock.py +++ b/src/llama_stack/providers/remote/safety/bedrock/bedrock.py @@ -8,10 +8,9 @@ from llama_stack.log import get_logger from llama_stack.providers.utils.bedrock.client import create_bedrock_client +from llama_stack.providers.utils.safety import ShieldToModerationMixin from llama_stack_api import ( GetShieldRequest, - ModerationObject, - RunModerationRequest, RunShieldRequest, RunShieldResponse, Safety, @@ -26,7 +25,7 @@ logger = get_logger(name=__name__, category="safety::bedrock") -class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): +class BedrockSafetyAdapter(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate): def __init__(self, config: BedrockSafetyConfig) -> None: self.config = config self.registered_shields = [] @@ -93,6 +92,3 @@ async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse: ) return RunShieldResponse() - 
- async def run_moderation(self, request: RunModerationRequest) -> ModerationObject: - raise NotImplementedError("Bedrock safety provider currently does not implement run_moderation") diff --git a/src/llama_stack/providers/remote/safety/nvidia/nvidia.py b/src/llama_stack/providers/remote/safety/nvidia/nvidia.py index 7a3801915e..7926365f3f 100644 --- a/src/llama_stack/providers/remote/safety/nvidia/nvidia.py +++ b/src/llama_stack/providers/remote/safety/nvidia/nvidia.py @@ -9,11 +9,10 @@ import requests from llama_stack.log import get_logger +from llama_stack.providers.utils.safety import ShieldToModerationMixin from llama_stack_api import ( GetShieldRequest, - ModerationObject, OpenAIMessageParam, - RunModerationRequest, RunShieldRequest, RunShieldResponse, Safety, @@ -28,7 +27,7 @@ logger = get_logger(name=__name__, category="safety::nvidia") -class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): +class NVIDIASafetyAdapter(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate): def __init__(self, config: NVIDIASafetyConfig) -> None: """ Initialize the NVIDIASafetyAdapter with a given safety configuration. @@ -60,9 +59,6 @@ async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse: self.shield = NeMoGuardrails(self.config, shield.shield_id) return await self.shield.run(request.messages) - async def run_moderation(self, request: RunModerationRequest) -> ModerationObject: - raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation") - class NeMoGuardrails: """ diff --git a/src/llama_stack/providers/remote/safety/sambanova/sambanova.py b/src/llama_stack/providers/remote/safety/sambanova/sambanova.py index f0c3229372..83252a4752 100644 --- a/src/llama_stack/providers/remote/safety/sambanova/sambanova.py +++ b/src/llama_stack/providers/remote/safety/sambanova/sambanova.py @@ -9,10 +9,9 @@ from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger +from llama_stack.providers.utils.safety import ShieldToModerationMixin from llama_stack_api import ( GetShieldRequest, - ModerationObject, - RunModerationRequest, RunShieldRequest, RunShieldResponse, Safety, @@ -29,7 +28,7 @@ CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" 
-class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData): +class SambaNovaSafetyAdapter(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate, NeedsRequestProviderData): def __init__(self, config: SambaNovaSafetyConfig) -> None: self.config = config self.environment_available_models = [] @@ -79,7 +78,9 @@ async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse: logger.debug(f"run_shield::{shield_params}::messages={request.messages}") response = litellm.completion( - model=shield.provider_resource_id, messages=request.messages, api_key=self._get_api_key() + model=shield.provider_resource_id, + messages=request.messages, + api_key=self._get_api_key(), ) shield_message = response.choices[0].message.content @@ -97,6 +98,3 @@ async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse: ) return RunShieldResponse() - - async def run_moderation(self, request: RunModerationRequest) -> ModerationObject: - raise NotImplementedError("SambaNova safety provider currently does not implement run_moderation") diff --git a/src/llama_stack/providers/utils/safety.py b/src/llama_stack/providers/utils/safety.py new file mode 100644 index 0000000000..4d6157375a --- /dev/null +++ b/src/llama_stack/providers/utils/safety.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import uuid +from typing import TYPE_CHECKING + +from llama_stack_api import ( + ModerationObject, + ModerationObjectResults, + OpenAIUserMessageParam, + RunModerationRequest, + RunShieldRequest, + RunShieldResponse, +) + +if TYPE_CHECKING: + # Type stub for mypy - actual implementation provided by provider class + class _RunShieldProtocol: + async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse: ... + + +class ShieldToModerationMixin: + """ + Mixin that provides run_moderation implementation by delegating to run_shield. + + Providers must implement run_shield(request: RunShieldRequest) for this mixin to work. + Providers with custom run_moderation implementations will override this automatically. + """ + + if TYPE_CHECKING: + # Type hint for mypy - run_shield is provided by the mixed-in class + async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse: ... + + async def run_moderation(self, request: RunModerationRequest) -> ModerationObject: + """ + Run moderation by converting input to messages and delegating to run_shield. 
+ + Args: + request: RunModerationRequest with input and model + + Returns: + ModerationObject with results for each input + + Raises: + ValueError: If model is None + """ + if request.model is None: + raise ValueError(f"{self.__class__.__name__} moderation requires a model identifier") + + inputs = request.input if isinstance(request.input, list) else [request.input] + results = [] + + for text_input in inputs: + # Convert string to OpenAI message format + message = OpenAIUserMessageParam(content=text_input) + + # Call run_shield (must be implemented by the provider) + shield_request = RunShieldRequest( + shield_id=request.model, + messages=[message], + ) + shield_response = await self.run_shield(shield_request) + + # Convert RunShieldResponse to ModerationObjectResults + results.append(self._shield_response_to_moderation_result(shield_response)) + + return ModerationObject( + id=f"modr-{uuid.uuid4()}", + model=request.model, + results=results, + ) + + def _shield_response_to_moderation_result( + self, + shield_response: RunShieldResponse, + ) -> ModerationObjectResults: + """Convert RunShieldResponse to ModerationObjectResults. + + Args: + shield_response: The response from run_shield + + Returns: + ModerationObjectResults with appropriate fields set + """ + if shield_response.violation is None: + # Safe content + return ModerationObjectResults( + flagged=False, + categories={}, + category_scores={}, + category_applied_input_types={}, + user_message=None, + metadata={}, + ) + + # Unsafe content - extract violation details + v = shield_response.violation + violation_type = v.metadata.get("violation_type", "unsafe") + + # Ensure violation_type is a string (metadata values can be Any) + if not isinstance(violation_type, str): + violation_type = "unsafe" + + return ModerationObjectResults( + flagged=True, + categories={violation_type: True}, + category_scores={violation_type: 1.0}, + category_applied_input_types={violation_type: ["text"]}, + user_message=v.user_message, + metadata=v.metadata, + ) diff --git a/tests/unit/providers/utils/test_safety_mixin.py b/tests/unit/providers/utils/test_safety_mixin.py new file mode 100644 index 0000000000..350a0456b5 --- /dev/null +++ b/tests/unit/providers/utils/test_safety_mixin.py @@ -0,0 +1,184 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from unittest.mock import AsyncMock + +import pytest + +from llama_stack.providers.inline.safety.prompt_guard.prompt_guard import ( + PromptGuardSafetyImpl, +) +from llama_stack.providers.remote.safety.bedrock.bedrock import BedrockSafetyAdapter +from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter +from llama_stack.providers.remote.safety.sambanova.sambanova import ( + SambaNovaSafetyAdapter, +) +from llama_stack.providers.utils.safety import ShieldToModerationMixin +from llama_stack_api import ( + OpenAIUserMessageParam, + RunModerationRequest, + RunShieldRequest, + RunShieldResponse, + Safety, + SafetyViolation, + ShieldsProtocolPrivate, + ViolationLevel, +) + + +@pytest.mark.parametrize( + "provider_class", + [ + NVIDIASafetyAdapter, + BedrockSafetyAdapter, + SambaNovaSafetyAdapter, + PromptGuardSafetyImpl, + ], +) +def test_providers_use_mixin(provider_class): + """Providers should use mixin for run_moderation .""" + for cls in provider_class.__mro__: + if "run_moderation" in cls.__dict__: + assert cls.__name__ == "ShieldToModerationMixin" + return + pytest.fail(f"{provider_class.__name__} has no run_moderation") + + +async def test_safe_content(): + """Safe content returns unflagged result.""" + + class Provider(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate): + def __init__(self): + self.run_shield = AsyncMock(return_value=RunShieldResponse(violation=None)) + + request = RunModerationRequest(input="safe", model="test") + result = await Provider().run_moderation(request) + + assert result is not None + assert result.results[0].flagged is False + + +async def test_unsafe_content(): + """Unsafe content returns flagged with violation_type as category.""" + + class Provider(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate): + def __init__(self): + self.run_shield = AsyncMock( + return_value=RunShieldResponse( + violation=SafetyViolation( + violation_level=ViolationLevel.ERROR, + user_message="Blocked", + metadata={"violation_type": "harmful"}, + ) + ) + ) + + request = RunModerationRequest(input="unsafe", model="test") + result = await Provider().run_moderation(request) + + assert result.results[0].flagged is True + assert result.results[0].categories == {"harmful": True} + assert result.results[0].category_scores == {"harmful": 1.0} + + +async def test_missing_violation_type_defaults_to_unsafe(): + """When violation_type missing in metadata, defaults to 'unsafe'.""" + + class Provider(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate): + def __init__(self): + self.run_shield = AsyncMock( + return_value=RunShieldResponse( + violation=SafetyViolation( + violation_level=ViolationLevel.ERROR, + user_message="Bad", + metadata={}, # No violation_type + ) + ) + ) + + request = RunModerationRequest(input="test", model="test") + result = await Provider().run_moderation(request) + + assert result.results[0].categories == {"unsafe": True} + + +async def test_non_string_violation_type_defaults_to_unsafe(): + """Non-string violation_type defaults to 'unsafe'""" + + class Provider(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate): + def __init__(self): + self.run_shield = AsyncMock( + return_value=RunShieldResponse( + violation=SafetyViolation( + violation_level=ViolationLevel.ERROR, + user_message="Bad", + metadata={"violation_type": 12345}, # int, not string + ) + ) + ) + + request = RunModerationRequest(input="test", model="test") + result = await Provider().run_moderation(request) + + assert result.results[0].categories == {"unsafe": 
True} + + +async def test_multiple_inputs(): + """List input produces multiple results.""" + + class Provider(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate): + def __init__(self): + self.run_shield = AsyncMock() + + provider = Provider() + provider.run_shield.side_effect = [ + RunShieldResponse(violation=None), + RunShieldResponse( + violation=SafetyViolation( + violation_level=ViolationLevel.ERROR, + user_message="Bad", + metadata={"violation_type": "bad"}, + ) + ), + ] + + request = RunModerationRequest(input=["safe", "unsafe"], model="test") + result = await provider.run_moderation(request) + + assert len(result.results) == 2 + assert result.results[0].flagged is False + assert result.results[1].flagged is True + + +async def test_run_shield_receives_correct_params(): + """Verify run_shield called with RunShieldRequest containing shield_id and messages.""" + + class Provider(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate): + def __init__(self): + self.run_shield = AsyncMock(return_value=RunShieldResponse(violation=None)) + + provider = Provider() + request = RunModerationRequest(input="test input", model="my-shield") + await provider.run_moderation(request) + + call_args = provider.run_shield.call_args.args + assert len(call_args) == 1 + shield_request = call_args[0] + assert isinstance(shield_request, RunShieldRequest) + assert shield_request.shield_id == "my-shield" + assert isinstance(shield_request.messages[0], OpenAIUserMessageParam) + assert shield_request.messages[0].content == "test input" + + +async def test_model_none_raises_error(): + """Model parameter is required (cannot be None).""" + + class Provider(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate): + pass + + request = RunModerationRequest(input="test", model=None) + with pytest.raises(ValueError, match="moderation requires a model identifier"): + await Provider().run_moderation(request)
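
Usage note (not part of the patch): the sketch below shows how a provider picks up run_moderation from ShieldToModerationMixin by implementing only run_shield. It is a minimal illustration that assumes the llama_stack_api types exercised in the tests above; KeywordSafetyProvider, its keyword check, and the "keyword-shield" identifier are hypothetical names introduced only for this example.

import asyncio

from llama_stack.providers.utils.safety import ShieldToModerationMixin
from llama_stack_api import (
    RunModerationRequest,
    RunShieldRequest,
    RunShieldResponse,
    Safety,
    SafetyViolation,
    ShieldsProtocolPrivate,
    ViolationLevel,
)


class KeywordSafetyProvider(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate):
    """Toy provider: run_shield flags any message containing the word 'attack'."""

    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
        for message in request.messages:
            if "attack" in str(message.content).lower():
                return RunShieldResponse(
                    violation=SafetyViolation(
                        violation_level=ViolationLevel.ERROR,
                        user_message="Blocked by keyword shield",
                        metadata={"violation_type": "keyword"},
                    )
                )
        return RunShieldResponse()  # violation=None means the content is safe


async def main() -> None:
    provider = KeywordSafetyProvider()
    # run_moderation comes from the mixin: each input string is wrapped in an
    # OpenAI-style user message, run through run_shield once, and converted
    # into one ModerationObjectResults entry.
    moderation = await provider.run_moderation(
        RunModerationRequest(
            input=["hello there", "how to attack a server"],
            model="keyword-shield",
        )
    )
    for result in moderation.results:
        print(result.flagged, result.categories)
    # First result: flagged=False with empty categories.
    # Second result: flagged=True with categories {"keyword": True}.


asyncio.run(main())

Listing the mixin before Safety in the base-class order means run_moderation resolves to the mixin for these adapters, while a provider that defines its own run_moderation in its class body still takes precedence, as the mixin docstring notes and test_providers_use_mixin asserts.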