@@ -9,12 +9,13 @@

from llama_stack.core.utils.model_utils import model_local_dir
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.inference.prompt_adapter import (
    interleaved_content_as_str,
)
from llama_stack.providers.utils.safety import ShieldToModerationMixin
from llama_stack_api import (
    GetShieldRequest,
    ModerationObject,
    OpenAIMessageParam,
    RunModerationRequest,
    RunShieldRequest,
    RunShieldResponse,
    Safety,
@@ -32,7 +33,7 @@
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"


class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
class PromptGuardSafetyImpl(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate):
    shield_store: ShieldStore

    def __init__(self, config: PromptGuardConfig, _deps) -> None:
@@ -59,9 +60,6 @@ async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:

        return await self.shield.run(request.messages)

    async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
        raise NotImplementedError("run_moderation is not implemented for Prompt Guard")


class PromptGuardShield:
    def __init__(
8 changes: 2 additions & 6 deletions src/llama_stack/providers/remote/safety/bedrock/bedrock.py
@@ -8,10 +8,9 @@

from llama_stack.log import get_logger
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from llama_stack.providers.utils.safety import ShieldToModerationMixin
from llama_stack_api import (
    GetShieldRequest,
    ModerationObject,
    RunModerationRequest,
    RunShieldRequest,
    RunShieldResponse,
    Safety,
@@ -26,7 +25,7 @@
logger = get_logger(name=__name__, category="safety::bedrock")


class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
class BedrockSafetyAdapter(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate):
    def __init__(self, config: BedrockSafetyConfig) -> None:
        self.config = config
        self.registered_shields = []
@@ -93,6 +92,3 @@ async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
            )

        return RunShieldResponse()

    async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
        raise NotImplementedError("Bedrock safety provider currently does not implement run_moderation")
8 changes: 2 additions & 6 deletions src/llama_stack/providers/remote/safety/nvidia/nvidia.py
@@ -9,11 +9,10 @@
import requests

from llama_stack.log import get_logger
from llama_stack.providers.utils.safety import ShieldToModerationMixin
from llama_stack_api import (
    GetShieldRequest,
    ModerationObject,
    OpenAIMessageParam,
    RunModerationRequest,
    RunShieldRequest,
    RunShieldResponse,
    Safety,
@@ -28,7 +27,7 @@
logger = get_logger(name=__name__, category="safety::nvidia")


class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
class NVIDIASafetyAdapter(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate):
    def __init__(self, config: NVIDIASafetyConfig) -> None:
        """
        Initialize the NVIDIASafetyAdapter with a given safety configuration.
@@ -60,9 +59,6 @@ async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
        self.shield = NeMoGuardrails(self.config, shield.shield_id)
        return await self.shield.run(request.messages)

    async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
        raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation")


class NeMoGuardrails:
"""
12 changes: 5 additions & 7 deletions src/llama_stack/providers/remote/safety/sambanova/sambanova.py
@@ -9,10 +9,9 @@

from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.safety import ShieldToModerationMixin
from llama_stack_api import (
    GetShieldRequest,
    ModerationObject,
    RunModerationRequest,
    RunShieldRequest,
    RunShieldResponse,
    Safety,
@@ -29,7 +28,7 @@
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"


class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
class SambaNovaSafetyAdapter(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
    def __init__(self, config: SambaNovaSafetyConfig) -> None:
        self.config = config
        self.environment_available_models = []
@@ -79,7 +78,9 @@ async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
logger.debug(f"run_shield::{shield_params}::messages={request.messages}")

response = litellm.completion(
model=shield.provider_resource_id, messages=request.messages, api_key=self._get_api_key()
model=shield.provider_resource_id,
messages=request.messages,
api_key=self._get_api_key(),
)
shield_message = response.choices[0].message.content

@@ -97,6 +98,3 @@ async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
            )

        return RunShieldResponse()

    async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
        raise NotImplementedError("SambaNova safety provider currently does not implement run_moderation")
114 changes: 114 additions & 0 deletions src/llama_stack/providers/utils/safety.py
@@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import uuid
from typing import TYPE_CHECKING

from llama_stack_api import (
    ModerationObject,
    ModerationObjectResults,
    OpenAIUserMessageParam,
    RunModerationRequest,
    RunShieldRequest,
    RunShieldResponse,
)

if TYPE_CHECKING:
    # Type stub for mypy - actual implementation provided by provider class
    class _RunShieldProtocol:
        async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse: ...


class ShieldToModerationMixin:
    """
    Mixin that provides run_moderation implementation by delegating to run_shield.

    Providers must implement run_shield(request: RunShieldRequest) for this mixin to work.
    Providers with custom run_moderation implementations will override this automatically.
    """

    if TYPE_CHECKING:
        # Type hint for mypy - run_shield is provided by the mixed-in class
        async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse: ...

    async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
        """
        Run moderation by converting input to messages and delegating to run_shield.

        Args:
            request: RunModerationRequest with input and model

        Returns:
            ModerationObject with results for each input

        Raises:
            ValueError: If model is None
        """
        if request.model is None:
            raise ValueError(f"{self.__class__.__name__} moderation requires a model identifier")

        inputs = request.input if isinstance(request.input, list) else [request.input]
        results = []

        for text_input in inputs:
            # Convert string to OpenAI message format
            message = OpenAIUserMessageParam(content=text_input)

            # Call run_shield (must be implemented by the provider)
            shield_request = RunShieldRequest(
                shield_id=request.model,
                messages=[message],
            )
            shield_response = await self.run_shield(shield_request)

            # Convert RunShieldResponse to ModerationObjectResults
            results.append(self._shield_response_to_moderation_result(shield_response))

        return ModerationObject(
            id=f"modr-{uuid.uuid4()}",
            model=request.model,
            results=results,
        )

    def _shield_response_to_moderation_result(
        self,
        shield_response: RunShieldResponse,
    ) -> ModerationObjectResults:
        """Convert RunShieldResponse to ModerationObjectResults.

        Args:
            shield_response: The response from run_shield

        Returns:
            ModerationObjectResults with appropriate fields set
        """
        if shield_response.violation is None:
            # Safe content
            return ModerationObjectResults(
                flagged=False,
                categories={},
                category_scores={},
                category_applied_input_types={},
                user_message=None,
                metadata={},
            )

        # Unsafe content - extract violation details
        v = shield_response.violation
        violation_type = v.metadata.get("violation_type", "unsafe")

        # Ensure violation_type is a string (metadata values can be Any)
        if not isinstance(violation_type, str):
            violation_type = "unsafe"

        return ModerationObjectResults(
            flagged=True,
            categories={violation_type: True},
            category_scores={violation_type: 1.0},
            category_applied_input_types={violation_type: ["text"]},
            user_message=v.user_message,
            metadata=v.metadata,
        )
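
For reviewers, a minimal usage sketch (not part of this diff) of how a provider adopts the mixin: implement run_shield and inherit run_moderation. The ExampleShieldProvider class, the main() helper, and the literal model/input values are invented for illustration, and the keyword arguments to RunModerationRequest are assumed to match the fields the mixin reads (model, input).

import asyncio

from llama_stack.providers.utils.safety import ShieldToModerationMixin
from llama_stack_api import (
    ModerationObject,
    RunModerationRequest,
    RunShieldRequest,
    RunShieldResponse,
)


class ExampleShieldProvider(ShieldToModerationMixin):
    """Hypothetical provider: it implements only run_shield; run_moderation comes from the mixin."""

    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
        # A real provider would evaluate request.messages against its shield here.
        # A response with no violation is what the mixin maps to flagged=False.
        return RunShieldResponse()


async def main() -> None:
    provider = ExampleShieldProvider()
    moderation: ModerationObject = await provider.run_moderation(
        RunModerationRequest(model="example-shield-id", input=["hello world"])
    )
    # The mixin wrapped the string in an OpenAIUserMessageParam, called run_shield
    # with shield_id="example-shield-id", and converted the shield response.
    assert moderation.results[0].flagged is False


asyncio.run(main())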