diff --git a/crates/cli/src/config.rs b/crates/cli/src/config.rs index 2307065f..9e5b281e 100644 --- a/crates/cli/src/config.rs +++ b/crates/cli/src/config.rs @@ -158,7 +158,7 @@ pub(crate) struct PluginsCommand { /// Plugin configuration subcommands. #[derive(Debug, Clone, Subcommand)] pub(crate) enum PluginsSubcommand { - /// Interactively create or edit the Observability plugin in `plugins.toml`. + /// Interactively create or edit built-in plugin configuration in `plugins.toml`. Edit(PluginsEditCommand), } diff --git a/crates/cli/tests/coverage/plugins_tests.rs b/crates/cli/tests/coverage/plugins_tests.rs index 28bf0fd2..dbde4e07 100644 --- a/crates/cli/tests/coverage/plugins_tests.rs +++ b/crates/cli/tests/coverage/plugins_tests.rs @@ -7,7 +7,7 @@ use nemo_relay::config_editor::{EditorConfig, EditorSchema}; use nemo_relay::observability::plugin_component::{OBSERVABILITY_PLUGIN_KIND, ObservabilityConfig}; use nemo_relay::plugin::{ConfigPolicy, PluginComponentSpec, PluginConfig}; use nemo_relay::plugins::nemo_guardrails::component::{ - NEMO_GUARDRAILS_PLUGIN_KIND, NeMoGuardrailsConfig, RemoteBackendConfig, + LocalBackendConfig, NEMO_GUARDRAILS_PLUGIN_KIND, NeMoGuardrailsConfig, RemoteBackendConfig, }; use nemo_relay_adaptive::AdaptiveConfig; use nemo_relay_adaptive::plugin_component::ADAPTIVE_PLUGIN_KIND; @@ -50,6 +50,40 @@ fn guardrails_component_config(config_id: &str) -> serde_json::Map serde_json::Map { + json!({ + "mode": "local", + "input": false, + "output": false, + "config_path": config_path, + "tool_input": true, + "tool_output": true, + "local": { + "python_module": "custom_guardrails" + } + }) + .as_object() + .unwrap() + .clone() +} + +fn local_llm_guardrails_component_config(config_yaml: &str) -> serde_json::Map { + json!({ + "mode": "local", + "codec": "openai_chat", + "input": true, + "output": true, + "config_yaml": config_yaml, + "colang_content": "define flow noop\n pass", + "local": { + "python_module": "custom_guardrails" + } + }) + .as_object() + .unwrap() + .clone() +} + #[test] fn target_scope_defaults_to_user_and_rejects_conflicts() { assert_eq!( @@ -160,6 +194,28 @@ fn typed_editor_model_contains_nemo_guardrails_options() { EditorFieldKind::StringMap ); + let local = schema.field("local").unwrap().schema().unwrap(); + assert_eq!( + local.field("python_module").unwrap().kind, + EditorFieldKind::String + ); + assert_eq!( + local.field("python_executable").unwrap().kind, + EditorFieldKind::String + ); + assert_eq!( + schema.field("config_path").unwrap().kind, + EditorFieldKind::String + ); + assert_eq!( + schema.field("config_yaml").unwrap().kind, + EditorFieldKind::String + ); + assert_eq!( + schema.field("colang_content").unwrap().kind, + EditorFieldKind::String + ); + let request_defaults = schema.field("request_defaults").unwrap().schema().unwrap(); let rails = request_defaults.field("rails").unwrap().schema().unwrap(); assert_eq!( @@ -1137,6 +1193,98 @@ fn validate_config_accepts_nemo_guardrails_component() { validate_config(&config).unwrap(); } +#[test] +fn validate_config_accepts_local_tool_only_nemo_guardrails_component() { + let config = PluginConfig { + components: vec![PluginComponentSpec { + kind: NEMO_GUARDRAILS_PLUGIN_KIND.to_string(), + enabled: true, + config: local_guardrails_component_config("./rails"), + }], + ..PluginConfig::default() + }; + + validate_config(&config).unwrap(); +} + +#[test] +fn validate_config_rejects_local_nemo_guardrails_request_defaults() { + let config = PluginConfig { + components: vec![PluginComponentSpec { + kind: NEMO_GUARDRAILS_PLUGIN_KIND.to_string(), + enabled: true, + config: json!({ + "mode": "local", + "codec": "openai_chat", + "input": true, + "output": true, + "config_yaml": "models: []", + "request_defaults": { + "context": {"tenant": "demo"} + } + }) + .as_object() + .unwrap() + .clone(), + }], + ..PluginConfig::default() + }; + + let error = validate_config(&config).unwrap_err().to_string(); + assert!(error.contains("request_defaults"), "error was: {error}"); + assert!(error.contains("local mode"), "error was: {error}"); +} + +#[test] +fn validate_config_rejects_local_nemo_guardrails_multiple_config_sources() { + let config = PluginConfig { + components: vec![PluginComponentSpec { + kind: NEMO_GUARDRAILS_PLUGIN_KIND.to_string(), + enabled: true, + config: json!({ + "mode": "local", + "config_path": "./rails", + "config_yaml": "models: []" + }) + .as_object() + .unwrap() + .clone(), + }], + ..PluginConfig::default() + }; + + let error = validate_config(&config).unwrap_err().to_string(); + assert!( + error.contains("exactly one of config_path or config_yaml"), + "error was: {error}" + ); +} + +#[test] +fn validate_config_rejects_local_nemo_guardrails_colang_without_yaml() { + let config = PluginConfig { + components: vec![PluginComponentSpec { + kind: NEMO_GUARDRAILS_PLUGIN_KIND.to_string(), + enabled: true, + config: json!({ + "mode": "local", + "config_path": "./rails", + "colang_content": "define flow noop\n pass" + }) + .as_object() + .unwrap() + .clone(), + }], + ..PluginConfig::default() + }; + + let error = validate_config(&config).unwrap_err().to_string(); + assert!( + error.contains("colang_content can only be used with config_yaml"), + "error was: {error}" + ); +} + #[test] fn nemo_guardrails_config_map_prunes_default_version() { let map = nemo_guardrails_config_map(&NeMoGuardrailsConfig { @@ -1155,6 +1303,108 @@ fn nemo_guardrails_config_map_prunes_default_version() { assert_eq!(map["remote"]["config_id"], json!("default")); } +#[test] +fn write_plugin_config_round_trips_local_nemo_guardrails_component() { + let temp = tempfile::tempdir().unwrap(); + let path = temp.path().join("plugins.toml"); + let config = PluginConfig { + components: vec![PluginComponentSpec { + kind: NEMO_GUARDRAILS_PLUGIN_KIND.to_string(), + enabled: true, + config: local_guardrails_component_config("./rails"), + }], + ..PluginConfig::default() + }; + + write_plugin_config(&path, &config).unwrap(); + + let rendered = std::fs::read_to_string(&path).unwrap(); + assert!(rendered.contains("mode = \"local\"")); + assert!(rendered.contains("config_path = \"./rails\"")); + assert!(rendered.contains("tool_input = true")); + assert!(rendered.contains("python_module = \"custom_guardrails\"")); + + let round_tripped = read_plugin_config(&path).unwrap(); + let guardrails = round_tripped + .components + .iter() + .find(|component| component.kind == NEMO_GUARDRAILS_PLUGIN_KIND) + .unwrap(); + assert!(guardrails.enabled); + assert_eq!(guardrails.config["mode"], json!("local")); + assert_eq!(guardrails.config["config_path"], json!("./rails")); + assert_eq!(guardrails.config["tool_input"], json!(true)); + assert_eq!( + guardrails.config["local"]["python_module"], + json!("custom_guardrails") + ); +} + +#[test] +fn nemo_guardrails_config_map_serializes_local_mode_fields() { + let map = nemo_guardrails_config_map(&NeMoGuardrailsConfig { + mode: "local".into(), + config_path: Some("./rails".into()), + tool_input: true, + tool_output: true, + local: Some(LocalBackendConfig { + python_module: Some("custom_guardrails".into()), + python_executable: Some("/opt/python/bin/python3".into()), + }), + ..NeMoGuardrailsConfig::default() + }) + .unwrap(); + + assert!(!map.contains_key("version")); + assert_eq!(map.get("mode"), Some(&json!("local"))); + assert_eq!(map.get("config_path"), Some(&json!("./rails"))); + assert_eq!(map.get("tool_input"), Some(&json!(true))); + assert_eq!(map["local"]["python_module"], json!("custom_guardrails")); + assert_eq!( + map["local"]["python_executable"], + json!("/opt/python/bin/python3") + ); +} + +#[test] +fn write_plugin_config_round_trips_local_llm_nemo_guardrails_component() { + let temp = tempfile::tempdir().unwrap(); + let path = temp.path().join("plugins.toml"); + let config = PluginConfig { + components: vec![PluginComponentSpec { + kind: NEMO_GUARDRAILS_PLUGIN_KIND.to_string(), + enabled: true, + config: local_llm_guardrails_component_config("models: []"), + }], + ..PluginConfig::default() + }; + + write_plugin_config(&path, &config).unwrap(); + + let rendered = std::fs::read_to_string(&path).unwrap(); + assert!(rendered.contains("mode = \"local\"")); + assert!(rendered.contains("codec = \"openai_chat\"")); + assert!(rendered.contains("input = true")); + assert!(rendered.contains("output = true")); + assert!(rendered.contains("config_yaml = \"models: []\"")); + + let round_tripped = read_plugin_config(&path).unwrap(); + let guardrails = round_tripped + .components + .iter() + .find(|component| component.kind == NEMO_GUARDRAILS_PLUGIN_KIND) + .unwrap(); + assert_eq!(guardrails.config["mode"], json!("local")); + assert_eq!(guardrails.config["codec"], json!("openai_chat")); + assert_eq!(guardrails.config["input"], json!(true)); + assert_eq!(guardrails.config["output"], json!(true)); + assert_eq!(guardrails.config["config_yaml"], json!("models: []")); + assert_eq!( + guardrails.config["colang_content"], + json!("define flow noop\n pass") + ); +} + #[test] fn display_helpers_render_scalars_json_and_defaults() { assert_eq!(display_value(&json!("logs")), "logs"); diff --git a/crates/core/src/plugins/nemo_guardrails/component.rs b/crates/core/src/plugins/nemo_guardrails/component.rs index 13695405..062b1ad6 100644 --- a/crates/core/src/plugins/nemo_guardrails/component.rs +++ b/crates/core/src/plugins/nemo_guardrails/component.rs @@ -17,9 +17,12 @@ use crate::plugin::{ register_plugin, }; +#[path = "local.rs"] +mod local; #[cfg(all(feature = "guardrails-remote", not(target_arch = "wasm32")))] #[path = "remote.rs"] mod remote; +use local::register_local_backend; #[cfg(all(feature = "guardrails-remote", not(target_arch = "wasm32")))] use remote::register_remote_backend; @@ -189,6 +192,9 @@ pub struct LocalBackendConfig { /// Optional import path for the Python runtime module. #[serde(default, skip_serializing_if = "Option::is_none")] pub python_module: Option, + /// Optional Python executable used to run the local Guardrails worker. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub python_executable: Option, } /// Default request semantics applied by the selected Guardrails backend. @@ -323,6 +329,7 @@ crate::editor_config! { crate::editor_config! { impl LocalBackendConfig { python_module => { label: "python_module", kind: String, optional: true }, + python_executable => { label: "python_executable", kind: String, optional: true }, } } @@ -447,9 +454,7 @@ fn register_nemo_guardrails_backend( ) -> PluginResult<()> { match config.mode.as_str() { "remote" => register_remote_backend(config, ctx), - "local" => Err(PluginError::RegistrationFailed( - "built-in NeMo Guardrails local backend is not implemented yet".to_string(), - )), + "local" => register_local_backend(config, ctx), other => Err(PluginError::InvalidConfig(format!( "unsupported NeMo Guardrails mode '{other}'" ))), @@ -525,7 +530,7 @@ fn validate_nemo_guardrails_plugin_config( &config.policy, plugin_config, "local", - &["python_module"], + &["python_module", "python_executable"], ); validate_section_fields( &mut diagnostics, @@ -693,6 +698,20 @@ fn validate_non_empty_strings( "local.python_module must not be empty".to_string(), ); } + + if let Some(local) = &config.local + && let Some(python_executable) = &local.python_executable + && python_executable.trim().is_empty() + { + push_policy_diag( + diagnostics, + policy.unsupported_value, + "nemo_guardrails.unsupported_value", + Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()), + Some("local.python_executable".to_string()), + "local.python_executable must not be empty".to_string(), + ); + } } fn validate_config_shape( @@ -955,6 +974,18 @@ fn validate_request_defaults( return; }; + if config.mode == "local" { + push_policy_diag( + diagnostics, + policy.unsupported_value, + "nemo_guardrails.unsupported_value", + Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()), + Some("request_defaults".to_string()), + "local mode does not currently support request_defaults".to_string(), + ); + return; + } + validate_json_object_field( diagnostics, policy, diff --git a/crates/core/src/plugins/nemo_guardrails/local.rs b/crates/core/src/plugins/nemo_guardrails/local.rs new file mode 100644 index 00000000..e1618836 --- /dev/null +++ b/crates/core/src/plugins/nemo_guardrails/local.rs @@ -0,0 +1,15 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::plugin::{PluginRegistrationContext, Result as PluginResult}; + +use super::NeMoGuardrailsConfig; + +mod python; + +pub(super) fn register_local_backend( + config: NeMoGuardrailsConfig, + ctx: &mut PluginRegistrationContext, +) -> PluginResult<()> { + python::register_local_backend(config, ctx) +} diff --git a/crates/core/src/plugins/nemo_guardrails/local_worker.py b/crates/core/src/plugins/nemo_guardrails/local_worker.py new file mode 100644 index 00000000..937fef88 --- /dev/null +++ b/crates/core/src/plugins/nemo_guardrails/local_worker.py @@ -0,0 +1,291 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import importlib +import json +import sys +import traceback + +DEFAULT_MODULE_NAME = "nemoguardrails" +SUPPORTED_NEMOGUARDRAILS_VERSION = "0.22.0" +STREAM_QUEUE_MAXSIZE = 32 + +_PROTOCOL_STDOUT = sys.stdout +sys.stdout = sys.stderr + + +def send(message): + _PROTOCOL_STDOUT.write(json.dumps(message, separators=(",", ":")) + "\n") + _PROTOCOL_STDOUT.flush() + + +def response(request_id, result=None): + payload = {"id": request_id, "ok": True} + if result is not None: + payload["result"] = result + send(payload) + + +def error_response(request_id, error): + send({"id": request_id, "ok": False, "error": str(error)}) + + +def stream_event(request_id, event, **fields): + payload = {"id": request_id, "ok": True, "event": event} + payload.update(fields) + send(payload) + + +def stream_error(request_id, error): + send({"id": request_id, "ok": False, "event": "error", "error": str(error)}) + + +def status_value(status): + value = getattr(status, "value", status) + return str(value).lower() + + +def optional_string_attr(obj, attr): + value = getattr(obj, attr, None) + if value is None: + return None + return str(value) + + +def string_attr_or_empty(obj, attr): + return optional_string_attr(obj, attr) or "" + + +def guardrails_stream_error_message(chunk): + try: + payload = json.loads(chunk) + except Exception: + return None + error = payload.get("error") + if not isinstance(error, dict): + return None + if error.get("type") != "guardrails_violation": + return None + return error.get("message") or "Blocked by output rails." + + +class AsyncTextStream: + def __init__(self, queue): + self._queue = queue + + def __aiter__(self): + return self + + async def __anext__(self): + value = await self._queue.get() + if value is None: + raise StopAsyncIteration + return value + + +class GuardrailsWorker: + def __init__(self, config): + if sys.version_info < (3, 11): + raise RuntimeError("NeMo Guardrails local backend requires python3 >= 3.11") + + local = config.get("local") or {} + root_module = (local.get("python_module") or DEFAULT_MODULE_NAME).strip() + guardrails = self._import_dependency(root_module, root_module) + options = self._import_dependency(f"{root_module}.rails.llm.options", root_module) + + version = getattr(guardrails, "__version__", None) + if version != SUPPORTED_NEMOGUARDRAILS_VERSION: + raise RuntimeError( + "NeMo Guardrails local backend requires " + f"nemoguardrails=={SUPPORTED_NEMOGUARDRAILS_VERSION}, but found {version!r}. " + f"Install it with: pip install nemoguardrails=={SUPPORTED_NEMOGUARDRAILS_VERSION}" + ) + + self._rail_type = options.RailType + self._rail_status = options.RailStatus + guardrails_config = self._build_guardrails_config(guardrails.RailsConfig, config) + self._rails = guardrails.LLMRails(guardrails_config) + + def _import_dependency(self, module_name, root_module): + try: + return importlib.import_module(module_name) + except ImportError as err: + missing = getattr(err, "name", None) + if missing == root_module: + raise RuntimeError( + "NeMo Guardrails is required for the built-in NeMo Guardrails local backend. " + f"Install it with: pip install nemoguardrails=={SUPPORTED_NEMOGUARDRAILS_VERSION}" + ) from err + raise RuntimeError( + "NeMo Guardrails local backend could not import a required dependency: " + f"{missing or err}. Install the full NeMo Guardrails runtime dependencies." + ) from err + + def _build_guardrails_config(self, rails_config_cls, config): + config_path = config.get("config_path") + if config_path: + return rails_config_cls.from_path(config_path) + + config_yaml = config.get("config_yaml") + if config_yaml is None: + raise ValueError("config_yaml is required when config_path is not provided") + return rails_config_cls.from_content( + colang_content=config.get("colang_content"), + yaml_content=config_yaml, + ) + + def _rail_kind(self, rail_type): + if rail_type == "input": + return self._rail_type.INPUT + if rail_type == "output": + return self._rail_type.OUTPUT + raise ValueError(f"unsupported rail_type {rail_type!r}") + + async def check(self, messages, rail_type): + result = await self._rails.check_async( + messages, + rail_types=[self._rail_kind(rail_type)], + ) + return { + "status": status_value(result.status), + "content": string_attr_or_empty(result, "content"), + "rail": optional_string_attr(result, "rail"), + } + + def has_streaming_output_rails(self): + output = self._output_rails_config() + flows = getattr(output, "flows", None) if output is not None else None + return bool(flows) + + def ensure_streaming_output_supported(self): + output = self._output_rails_config() + if output is None: + return + + streaming = getattr(output, "streaming", None) + if streaming is None or not bool(getattr(streaming, "enabled", False)): + raise RuntimeError( + "local NeMo Guardrails streaming output rails require " + "rails.output.streaming.enabled = true in the Guardrails config." + ) + + if not bool(getattr(streaming, "stream_first", True)): + raise RuntimeError( + "local NeMo Guardrails streaming output rails currently require " + "rails.output.streaming.stream_first = true." + ) + + def _output_rails_config(self): + config = getattr(self._rails, "config", None) + rails = getattr(config, "rails", None) + return getattr(rails, "output", None) + + async def monitor_stream(self, request_id, messages, queue, streams): + try: + async for chunk in self._rails.stream_async( + messages=messages, + generator=AsyncTextStream(queue), + include_metadata=False, + ): + if not isinstance(chunk, str): + continue + message = guardrails_stream_error_message(chunk) + if message: + stream_event(request_id, "blocked", message=message) + return + stream_event(request_id, "done") + except Exception as err: + stream_error(request_id, err) + finally: + streams.pop(request_id, None) + + +worker = None +streams = {} + + +def track_task(pending_tasks, task): + pending_tasks.add(task) + task.add_done_callback(pending_tasks.discard) + return task + + +async def handle_message(message, pending_tasks): + global worker + + request_id = str(message.get("id", "")) + command = message.get("command") + try: + if command == "init": + worker = GuardrailsWorker(message.get("config") or {}) + response( + request_id, + { + "python": sys.executable, + "version": ".".join(str(part) for part in sys.version_info[:3]), + }, + ) + elif worker is None: + raise RuntimeError("NeMo Guardrails local Python worker is not initialized") + elif command == "check": + response( + request_id, + await worker.check(message.get("messages") or [], message.get("rail_type")), + ) + elif command == "has_streaming_output_rails": + response(request_id, {"enabled": worker.has_streaming_output_rails()}) + elif command == "ensure_streaming_output_supported": + worker.ensure_streaming_output_supported() + response(request_id) + elif command == "stream_start": + queue = asyncio.Queue(maxsize=STREAM_QUEUE_MAXSIZE) + streams[request_id] = queue + track_task( + pending_tasks, + asyncio.create_task(worker.monitor_stream(request_id, message.get("messages") or [], queue, streams)), + ) + elif command == "stream_text": + queue = streams.get(request_id) + if queue is not None: + await queue.put(message.get("text") or "") + elif command == "stream_end": + queue = streams.get(request_id) + if queue is not None: + await queue.put(None) + else: + raise RuntimeError(f"unknown worker command {command!r}") + except Exception as err: + if command and command.startswith("stream_"): + stream_error(request_id, err) + else: + error_response(request_id, err) + + +async def main(): + pending_tasks = set() + try: + while True: + line = await asyncio.to_thread(sys.stdin.readline) + if not line: + return + try: + message = json.loads(line) + except Exception: + traceback.print_exc(file=sys.stderr) + continue + if str(message.get("command", "")).startswith("stream_"): + await handle_message(message, pending_tasks) + else: + track_task( + pending_tasks, + asyncio.create_task(handle_message(message, pending_tasks)), + ) + finally: + for task in tuple(pending_tasks): + task.cancel() + if pending_tasks: + await asyncio.gather(*pending_tasks, return_exceptions=True) + + +asyncio.run(main()) diff --git a/crates/core/src/plugins/nemo_guardrails/python.rs b/crates/core/src/plugins/nemo_guardrails/python.rs new file mode 100644 index 00000000..86cd2773 --- /dev/null +++ b/crates/core/src/plugins/nemo_guardrails/python.rs @@ -0,0 +1,1156 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::collections::HashMap; +use std::env; +use std::io::{BufRead, BufReader, Write}; +use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, Mutex, mpsc as std_mpsc}; +use std::thread; +use std::time::Duration; + +use serde::Deserialize; +use serde_json::json; +use tokio::sync::mpsc; +use tokio::task::JoinHandle; +use tokio_stream::StreamExt; +use tokio_stream::wrappers::ReceiverStream; + +use crate::api::llm::LlmRequest; +use crate::api::runtime::{LlmExecutionFn, LlmJsonStream, LlmStreamExecutionFn, ToolExecutionFn}; +use crate::codec::anthropic::AnthropicMessagesCodec; +use crate::codec::openai_chat::OpenAIChatCodec; +use crate::codec::openai_responses::OpenAIResponsesCodec; +use crate::codec::request::{AnnotatedLlmRequest, Message, MessageContent}; +use crate::codec::traits::{LlmCodec, LlmResponseCodec}; +use crate::error::{FlowError, Result as FlowResult}; +use crate::json::Json; +use crate::plugin::{PluginError, PluginRegistrationContext, Result as PluginResult}; + +use super::NeMoGuardrailsConfig; + +const DEFAULT_PYTHON_EXECUTABLE: &str = "python3"; +const PYTHON_EXECUTABLE_ENV: &str = "NEMO_RELAY_PYTHON"; +const WORKER_INIT_TIMEOUT: Duration = Duration::from_secs(30); +const WORKER_RPC_TIMEOUT: Duration = Duration::from_secs(30); +const WORKER_SCRIPT: &str = include_str!("local_worker.py"); + +pub(super) fn register_local_backend( + config: NeMoGuardrailsConfig, + ctx: &mut PluginRegistrationContext, +) -> PluginResult<()> { + let runtime = Arc::new(LocalGuardrailsRuntime::new(&config)?); + + if config.input || config.output { + let llm_runtime = Arc::clone(&runtime); + let enable_input = config.input; + let enable_output = config.output; + let llm_execution: LlmExecutionFn = Arc::new(move |_name, request, next| { + let runtime = Arc::clone(&llm_runtime); + Box::pin(async move { + runtime + .execute_llm(request, next, enable_input, enable_output) + .await + }) + }); + ctx.register_llm_execution_intercept( + "nemo_guardrails_local", + config.priority, + llm_execution, + )?; + + let stream_runtime = Arc::clone(&runtime); + let enable_input = config.input; + let enable_output = config.output; + let llm_stream_execution: LlmStreamExecutionFn = Arc::new(move |_name, request, next| { + let runtime = Arc::clone(&stream_runtime); + Box::pin(async move { + runtime + .execute_llm_stream(request, next, enable_input, enable_output) + .await + }) + }); + ctx.register_llm_stream_execution_intercept( + "nemo_guardrails_local_stream", + config.priority, + llm_stream_execution, + )?; + } + + if config.tool_input || config.tool_output { + let tool_runtime = Arc::clone(&runtime); + let enable_tool_input = config.tool_input; + let enable_tool_output = config.tool_output; + let tool_execution: ToolExecutionFn = Arc::new(move |tool_name, args, next| { + let runtime = Arc::clone(&tool_runtime); + let tool_name = tool_name.to_string(); + Box::pin(async move { + let current_args = if enable_tool_input { + runtime.check_tool_input(&tool_name, &args).await? + } else { + args + }; + + let tool_result = next(current_args.clone()).await?; + if !enable_tool_output { + return Ok(tool_result); + } + + runtime + .check_tool_output(&tool_name, ¤t_args, &tool_result) + .await + }) + }); + ctx.register_tool_execution_intercept( + "nemo_guardrails_local", + config.priority, + tool_execution, + )?; + } + + Ok(()) +} + +struct LocalGuardrailsRuntime { + bridge: LocalGuardrailsBridge, + codec: Option, +} + +impl LocalGuardrailsRuntime { + fn new(config: &NeMoGuardrailsConfig) -> PluginResult { + Ok(Self { + bridge: LocalGuardrailsBridge::new(config)?, + codec: resolve_codec(config)?, + }) + } + + async fn execute_llm( + &self, + request: LlmRequest, + next: crate::api::runtime::LlmExecutionNextFn, + enable_input: bool, + enable_output: bool, + ) -> FlowResult { + let (request, messages) = self.prepare_llm_request(request, enable_input).await?; + let response = next(request).await?; + + if enable_output { + let annotated_response = self.codec()?.decode_response(&response)?; + if let Some(response_text) = annotated_response.response_text() { + self.check_output_rails(&messages, response_text).await?; + } + } + + Ok(response) + } + + async fn execute_llm_stream( + &self, + request: LlmRequest, + next: crate::api::runtime::LlmStreamExecutionNextFn, + enable_input: bool, + enable_output: bool, + ) -> FlowResult { + let (request, messages) = self.prepare_llm_request(request, enable_input).await?; + let provider_stream = next(request).await?; + + if !enable_output || !self.bridge.has_streaming_output_rails().await? { + return Ok(provider_stream); + } + + self.bridge.ensure_streaming_output_supported().await?; + self.guard_provider_stream(messages, provider_stream).await + } + + async fn prepare_llm_request( + &self, + request: LlmRequest, + enable_input: bool, + ) -> FlowResult<(LlmRequest, Vec)> { + let codec = self.codec()?; + let mut current_request = request; + let mut annotated = codec.decode(¤t_request)?; + let mut messages = messages_from_annotated(&annotated)?; + + if enable_input { + match self + .bridge + .check(messages.clone(), LocalRailKind::Input) + .await? + { + LocalCheckOutcome::Passed => {} + LocalCheckOutcome::Blocked { rail, .. } => { + return Err(blocked_error("input", rail.as_deref())); + } + LocalCheckOutcome::Modified { content, .. } => { + replace_last_role_content(&mut annotated, "user", content)?; + current_request = codec.encode(&annotated, ¤t_request)?; + messages = messages_from_annotated(&annotated)?; + } + } + } + + Ok((current_request, messages)) + } + + async fn check_output_rails(&self, messages: &[Json], response_text: &str) -> FlowResult<()> { + let mut output_messages = messages.to_vec(); + output_messages.push(json!({ + "role": "assistant", + "content": response_text, + })); + + match self + .bridge + .check(output_messages, LocalRailKind::Output) + .await? + { + LocalCheckOutcome::Passed => Ok(()), + LocalCheckOutcome::Blocked { rail, .. } => { + Err(blocked_error("output", rail.as_deref())) + } + LocalCheckOutcome::Modified { .. } => Err(local_violation( + "NeMo Guardrails output rail returned modified content, but the local backend \ + does not rewrite provider responses yet.", + )), + } + } + + async fn check_tool_input(&self, tool_name: &str, args: &Json) -> FlowResult { + let messages = vec![json!({ + "role": "user", + "content": tool_input_content(tool_name, args)?, + })]; + + match self.bridge.check(messages, LocalRailKind::Input).await? { + LocalCheckOutcome::Passed => Ok(args.clone()), + LocalCheckOutcome::Blocked { rail, .. } => { + Err(blocked_error("tool_input", rail.as_deref())) + } + LocalCheckOutcome::Modified { content, .. } => { + modified_tool_payload(&content, "arguments") + } + } + } + + async fn check_tool_output( + &self, + tool_name: &str, + args: &Json, + result: &Json, + ) -> FlowResult { + let messages = vec![ + json!({ + "role": "user", + "content": tool_input_content(tool_name, args)?, + }), + json!({ + "role": "assistant", + "content": tool_output_content(tool_name, args, result)?, + }), + ]; + + match self.bridge.check(messages, LocalRailKind::Output).await? { + LocalCheckOutcome::Passed => Ok(result.clone()), + LocalCheckOutcome::Blocked { rail, .. } => { + Err(blocked_error("tool_output", rail.as_deref())) + } + LocalCheckOutcome::Modified { content, .. } => { + modified_tool_payload(&content, "result") + } + } + } + + async fn guard_provider_stream( + &self, + messages: Vec, + provider_stream: LlmJsonStream, + ) -> FlowResult { + let (text_tx, text_rx) = mpsc::channel::>(32); + let (chunk_tx, chunk_rx) = mpsc::channel::>(32); + let blocked = Arc::new(Mutex::new(None)); + let monitor = self + .bridge + .spawn_stream_monitor(messages, text_rx, Arc::clone(&blocked))?; + let codec = *self.codec()?; + + tokio::spawn(async move { + forward_guarded_provider_stream( + provider_stream, + codec, + text_tx, + chunk_tx, + monitor, + blocked, + ) + .await; + }); + + Ok(Box::pin(ReceiverStream::new(chunk_rx)) as LlmJsonStream) + } + + fn codec(&self) -> FlowResult<&LocalGuardrailsCodec> { + self.codec.as_ref().ok_or_else(|| { + FlowError::Internal( + "local NeMo Guardrails backend requires a supported codec".to_string(), + ) + }) + } +} + +struct LocalGuardrailsBridge { + worker: Arc, +} + +impl LocalGuardrailsBridge { + fn new(config: &NeMoGuardrailsConfig) -> PluginResult { + Ok(Self { + worker: LocalGuardrailsWorker::start(config)?, + }) + } + + async fn check( + &self, + messages: Vec, + kind: LocalRailKind, + ) -> FlowResult { + let result = self + .worker + .request(json!({ + "command": "check", + "messages": messages, + "rail_type": kind.as_str(), + })) + .await?; + parse_check_result(result) + } + + async fn has_streaming_output_rails(&self) -> FlowResult { + let result = self + .worker + .request(json!({ "command": "has_streaming_output_rails" })) + .await?; + result + .get("enabled") + .and_then(Json::as_bool) + .ok_or_else(|| FlowError::Internal("worker returned invalid streaming probe".into())) + } + + async fn ensure_streaming_output_supported(&self) -> FlowResult<()> { + self.worker + .request(json!({ "command": "ensure_streaming_output_supported" })) + .await + .map(|_| ()) + } + + fn spawn_stream_monitor( + &self, + messages: Vec, + text_rx: mpsc::Receiver>, + blocked: Arc>>, + ) -> FlowResult>> { + let (stream_id, event_rx) = self.worker.start_stream(messages)?; + let worker = Arc::clone(&self.worker); + Ok(tokio::spawn(async move { + monitor_guardrails_stream(worker, stream_id, text_rx, event_rx, blocked).await + })) + } +} + +struct LocalGuardrailsWorker { + writer: Mutex>, + child: Mutex, + waiters: Arc>>>, + stream_events: Arc>>>, + next_id: AtomicU64, +} + +impl LocalGuardrailsWorker { + fn start(config: &NeMoGuardrailsConfig) -> PluginResult> { + let python = python_executable(config); + let mut command = Command::new(&python); + command + .arg("-u") + .arg("-c") + .arg(WORKER_SCRIPT) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::inherit()); + + let mut child = command.spawn().map_err(|err| { + PluginError::RegistrationFailed(format!( + "failed to start NeMo Guardrails local Python worker with {python:?}: {err}" + )) + })?; + let stdin = child.stdin.take().ok_or_else(|| { + PluginError::RegistrationFailed( + "failed to open stdin for NeMo Guardrails local Python worker".to_string(), + ) + })?; + let stdout = child.stdout.take().ok_or_else(|| { + PluginError::RegistrationFailed( + "failed to open stdout for NeMo Guardrails local Python worker".to_string(), + ) + })?; + + let worker = Arc::new(Self { + writer: Mutex::new(Some(WorkerCommandWriter::spawn(stdin))), + child: Mutex::new(child), + waiters: Arc::new(Mutex::new(HashMap::new())), + stream_events: Arc::new(Mutex::new(HashMap::new())), + next_id: AtomicU64::new(1), + }); + worker.spawn_reader(stdout); + worker.initialize(config)?; + Ok(worker) + } + + fn spawn_reader(&self, stdout: ChildStdout) { + let waiters = Arc::clone(&self.waiters); + let stream_events = Arc::clone(&self.stream_events); + thread::spawn(move || { + let reader = BufReader::new(stdout); + for line in reader.lines() { + let line = match line { + Ok(line) => line, + Err(err) => { + notify_worker_closed(&waiters, &stream_events, err.to_string()); + return; + } + }; + if line.trim().is_empty() { + continue; + } + let envelope = match serde_json::from_str::(&line) { + Ok(envelope) => envelope, + Err(err) => { + notify_worker_closed( + &waiters, + &stream_events, + format!("invalid worker response: {err}"), + ); + return; + } + }; + dispatch_worker_envelope(&waiters, &stream_events, envelope); + } + notify_worker_closed(&waiters, &stream_events, "worker exited".to_string()); + }); + } + + fn initialize(&self, config: &NeMoGuardrailsConfig) -> PluginResult<()> { + let response = self + .request_blocking( + json!({ + "command": "init", + "config": config, + }), + WORKER_INIT_TIMEOUT, + ) + .map_err(|err| PluginError::RegistrationFailed(err.to_string()))?; + if response.ok { + Ok(()) + } else { + Err(PluginError::RegistrationFailed( + response + .error + .unwrap_or_else(|| "NeMo Guardrails local Python worker failed".to_string()), + )) + } + } + + async fn request(&self, mut payload: Json) -> FlowResult { + let receiver = self.send_request(&mut payload)?; + let response_task = tokio::task::spawn_blocking(move || receiver.recv()); + let envelope = match tokio::time::timeout(WORKER_RPC_TIMEOUT, response_task).await { + Ok(result) => result + .map_err(|err| FlowError::Internal(format!("worker response task failed: {err}")))? + .map_err(|err| { + FlowError::Internal(format!("worker response channel closed: {err}")) + })?, + Err(_) => { + self.shutdown(); + return Err(FlowError::Internal(format!( + "worker request timed out after {} seconds", + WORKER_RPC_TIMEOUT.as_secs() + ))); + } + }; + worker_result(envelope) + } + + fn request_blocking(&self, mut payload: Json, timeout: Duration) -> FlowResult { + let receiver = self.send_request(&mut payload)?; + receiver + .recv_timeout(timeout) + .map_err(|err| FlowError::Internal(format!("worker did not initialize: {err}"))) + } + + fn send_request(&self, payload: &mut Json) -> FlowResult> { + let id = self.next_request_id(); + set_request_id(payload, &id)?; + let (tx, rx) = std_mpsc::channel(); + self.waiters + .lock() + .map_err(|err| FlowError::Internal(format!("worker waiter lock poisoned: {err}")))? + .insert(id.clone(), tx); + if let Err(err) = self.write_command(payload) { + let _ = self.waiters.lock().map(|mut waiters| waiters.remove(&id)); + return Err(err); + } + Ok(rx) + } + + fn start_stream( + &self, + messages: Vec, + ) -> FlowResult<(String, mpsc::UnboundedReceiver)> { + let id = self.next_request_id(); + let (tx, rx) = mpsc::unbounded_channel(); + self.stream_events + .lock() + .map_err(|err| FlowError::Internal(format!("worker stream lock poisoned: {err}")))? + .insert(id.clone(), tx); + let payload = json!({ + "id": id, + "command": "stream_start", + "messages": messages, + }); + if let Err(err) = self.write_command(&payload) { + self.forget_stream(&id); + return Err(err); + } + Ok((id, rx)) + } + + fn send_stream_text(&self, id: &str, text: String) -> FlowResult<()> { + self.write_command(&json!({ + "id": id, + "command": "stream_text", + "text": text, + })) + } + + fn send_stream_end(&self, id: &str) -> FlowResult<()> { + self.write_command(&json!({ + "id": id, + "command": "stream_end", + })) + } + + fn forget_stream(&self, id: &str) { + let _ = self + .stream_events + .lock() + .map(|mut streams| streams.remove(id)); + } + + fn next_request_id(&self) -> String { + self.next_id.fetch_add(1, Ordering::Relaxed).to_string() + } + + fn write_command(&self, payload: &Json) -> FlowResult<()> { + let line = serde_json::to_string(payload).map_err(|err| { + FlowError::Internal(format!("failed to serialize worker command: {err}")) + })?; + let writer = self + .writer + .lock() + .map_err(|err| FlowError::Internal(format!("worker writer lock poisoned: {err}")))?; + writer + .as_ref() + .ok_or_else(|| FlowError::Internal("worker command writer is closed".to_string()))? + .send(line) + } + + fn shutdown(&self) { + let writer = self.writer.lock().ok().and_then(|mut writer| writer.take()); + if let Ok(mut child) = self.child.lock() { + let _ = child.kill(); + let _ = child.wait(); + } + if let Some(writer) = writer { + writer.join(); + } + } +} + +impl Drop for LocalGuardrailsWorker { + fn drop(&mut self) { + self.shutdown(); + } +} + +struct WorkerCommandWriter { + sender: std_mpsc::Sender, + error: Arc>>, + handle: Option>, +} + +impl WorkerCommandWriter { + fn spawn(mut stdin: ChildStdin) -> Self { + let (sender, receiver) = std_mpsc::channel::(); + let error = Arc::new(Mutex::new(None)); + let writer_error = Arc::clone(&error); + let handle = thread::spawn(move || { + for line in receiver { + if let Err(err) = writeln!(stdin, "{line}").and_then(|_| stdin.flush()) { + if let Ok(mut stored_error) = writer_error.lock() { + *stored_error = Some(err.to_string()); + } + return; + } + } + let _ = stdin.flush(); + }); + Self { + sender, + error, + handle: Some(handle), + } + } + + fn send(&self, line: String) -> FlowResult<()> { + if let Some(error) = self + .error + .lock() + .map_err(|err| { + FlowError::Internal(format!("worker writer error lock poisoned: {err}")) + })? + .clone() + { + return Err(FlowError::Internal(format!( + "failed to write worker command: {error}" + ))); + } + self.sender.send(line).map_err(|err| { + FlowError::Internal(format!("worker command writer channel closed: {err}")) + }) + } + + fn join(mut self) { + drop(self.sender); + if let Some(handle) = self.handle.take() { + let _ = handle.join(); + } + } +} + +#[derive(Debug, Clone, Deserialize)] +struct WorkerEnvelope { + id: String, + ok: bool, + #[serde(default)] + result: Option, + #[serde(default)] + error: Option, + #[serde(default)] + event: Option, + #[serde(default)] + message: Option, +} + +#[derive(Deserialize)] +struct WorkerCheckResult { + status: String, + #[serde(default)] + content: Option, + #[serde(default)] + rail: Option, +} + +fn python_executable(config: &NeMoGuardrailsConfig) -> String { + config + .local + .as_ref() + .and_then(|local| local.python_executable.as_deref()) + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(str::to_string) + .or_else(|| { + env::var(PYTHON_EXECUTABLE_ENV) + .ok() + .map(|value| value.trim().to_string()) + .filter(|value| !value.is_empty()) + }) + .unwrap_or_else(|| DEFAULT_PYTHON_EXECUTABLE.to_string()) +} + +fn set_request_id(payload: &mut Json, id: &str) -> FlowResult<()> { + let object = payload.as_object_mut().ok_or_else(|| { + FlowError::Internal("worker command payload must be a JSON object".to_string()) + })?; + object.insert("id".to_string(), Json::String(id.to_string())); + Ok(()) +} + +fn dispatch_worker_envelope( + waiters: &Arc>>>, + stream_events: &Arc>>>, + envelope: WorkerEnvelope, +) { + if envelope.event.is_some() { + let sender = stream_events + .lock() + .ok() + .and_then(|streams| streams.get(&envelope.id).cloned()); + if let Some(sender) = sender { + let _ = sender.send(envelope); + } + return; + } + + let sender = waiters + .lock() + .ok() + .and_then(|mut waiters| waiters.remove(&envelope.id)); + if let Some(sender) = sender { + let _ = sender.send(envelope); + } +} + +fn notify_worker_closed( + waiters: &Arc>>>, + stream_events: &Arc>>>, + message: String, +) { + if let Ok(mut waiters) = waiters.lock() { + for (id, sender) in waiters.drain() { + let _ = sender.send(WorkerEnvelope { + id, + ok: false, + result: None, + error: Some(message.clone()), + event: None, + message: None, + }); + } + } + if let Ok(mut streams) = stream_events.lock() { + for (id, sender) in streams.drain() { + let _ = sender.send(WorkerEnvelope { + id, + ok: false, + result: None, + error: Some(message.clone()), + event: Some("error".to_string()), + message: None, + }); + } + } +} + +fn worker_result(envelope: WorkerEnvelope) -> FlowResult { + if envelope.ok { + Ok(envelope.result.unwrap_or(Json::Null)) + } else { + Err(FlowError::Internal(envelope.error.unwrap_or_else(|| { + "NeMo Guardrails local Python worker failed".to_string() + }))) + } +} + +fn parse_check_result(result: Json) -> FlowResult { + let result: WorkerCheckResult = serde_json::from_value(result).map_err(|err| { + FlowError::Internal(format!("worker returned invalid check result: {err}")) + })?; + match result.status.as_str() { + "blocked" => Ok(LocalCheckOutcome::Blocked { rail: result.rail }), + "modified" => Ok(LocalCheckOutcome::Modified { + content: result.content.unwrap_or_default(), + }), + "passed" => Ok(LocalCheckOutcome::Passed), + unexpected => Err(FlowError::Internal(format!( + "unexpected worker check status: {unexpected}" + ))), + } +} + +#[derive(Clone, Copy)] +enum LocalGuardrailsCodec { + OpenAIChat, + OpenAIResponses, + AnthropicMessages, +} + +impl LocalGuardrailsCodec { + fn decode(&self, request: &LlmRequest) -> FlowResult { + match self { + Self::OpenAIChat => OpenAIChatCodec.decode(request), + Self::OpenAIResponses => OpenAIResponsesCodec.decode(request), + Self::AnthropicMessages => AnthropicMessagesCodec.decode(request), + } + } + + fn encode( + &self, + annotated: &AnnotatedLlmRequest, + original: &LlmRequest, + ) -> FlowResult { + match self { + Self::OpenAIChat => OpenAIChatCodec.encode(annotated, original), + Self::OpenAIResponses => OpenAIResponsesCodec.encode(annotated, original), + Self::AnthropicMessages => AnthropicMessagesCodec.encode(annotated, original), + } + } + + fn decode_response( + &self, + response: &Json, + ) -> FlowResult { + match self { + Self::OpenAIChat => OpenAIChatCodec.decode_response(response), + Self::OpenAIResponses => OpenAIResponsesCodec.decode_response(response), + Self::AnthropicMessages => AnthropicMessagesCodec.decode_response(response), + } + } +} + +fn resolve_codec(config: &NeMoGuardrailsConfig) -> PluginResult> { + if !(config.input || config.output) { + return Ok(None); + } + + match config.codec.as_deref() { + Some("openai_chat") => Ok(Some(LocalGuardrailsCodec::OpenAIChat)), + Some("openai_responses") => Ok(Some(LocalGuardrailsCodec::OpenAIResponses)), + Some("anthropic_messages") => Ok(Some(LocalGuardrailsCodec::AnthropicMessages)), + Some(other) => Err(PluginError::InvalidConfig(format!( + "unsupported local NeMo Guardrails codec '{other}'" + ))), + None => Err(PluginError::InvalidConfig( + "local NeMo Guardrails backend requires a supported codec".to_string(), + )), + } +} + +enum LocalCheckOutcome { + Passed, + Blocked { rail: Option }, + Modified { content: String }, +} + +#[derive(Clone, Copy)] +enum LocalRailKind { + Input, + Output, +} + +impl LocalRailKind { + fn as_str(self) -> &'static str { + match self { + Self::Input => "input", + Self::Output => "output", + } + } +} + +fn messages_from_annotated(annotated: &AnnotatedLlmRequest) -> FlowResult> { + match serde_json::to_value(&annotated.messages) + .map_err(|err| FlowError::Internal(format!("failed to serialize messages: {err}")))? + { + Json::Array(messages) => Ok(messages), + _ => Err(FlowError::Internal( + "serialized messages were not a JSON array".to_string(), + )), + } +} + +fn replace_last_role_content( + annotated: &mut AnnotatedLlmRequest, + role: &str, + content: String, +) -> FlowResult<()> { + for message in annotated.messages.iter_mut().rev() { + match (role, message) { + ( + "user", + Message::User { + content: target, .. + }, + ) => { + *target = MessageContent::Text(content); + return Ok(()); + } + ( + "assistant", + Message::Assistant { + content: target, .. + }, + ) => { + *target = Some(MessageContent::Text(content)); + return Ok(()); + } + _ => {} + } + } + + Err(local_violation(format!( + "NeMo Guardrails returned modified {role} content but no {role} message was present." + ))) +} + +fn tool_input_content(name: &str, args: &Json) -> FlowResult { + serde_json::to_string(&json!({ + "tool_name": name, + "arguments": args, + })) + .map_err(|err| FlowError::Internal(format!("failed to serialize tool input: {err}"))) +} + +fn tool_output_content(name: &str, args: &Json, result: &Json) -> FlowResult { + serde_json::to_string(&json!({ + "tool_name": name, + "arguments": args, + "result": result, + })) + .map_err(|err| FlowError::Internal(format!("failed to serialize tool output: {err}"))) +} + +fn modified_tool_payload(content: &str, field: &str) -> FlowResult { + let value: Json = serde_json::from_str(content).map_err(|_| { + local_violation(format!( + "NeMo Guardrails returned modified tool {field} content that is not valid JSON." + )) + })?; + + let Json::Object(object) = value else { + return Err(local_violation(format!( + "NeMo Guardrails returned modified tool {field} content without a '{field}' field." + ))); + }; + object.get(field).cloned().ok_or_else(|| { + local_violation(format!( + "NeMo Guardrails returned modified tool {field} content without a '{field}' field." + )) + }) +} + +fn blocked_error(rail_type: &str, rail: Option<&str>) -> FlowError { + let detail = rail + .filter(|rail| !rail.is_empty()) + .map(|rail| format!(" by rail '{rail}'")) + .unwrap_or_default(); + let subject = if matches!(rail_type, "input" | "output") { + "LLM call" + } else { + "tool call" + }; + local_violation(format!( + "NeMo Guardrails {rail_type} rail blocked the {subject}{detail}." + )) +} + +fn local_violation(message: impl Into) -> FlowError { + FlowError::Internal(message.into()) +} + +async fn forward_guarded_provider_stream( + mut provider_stream: LlmJsonStream, + codec: LocalGuardrailsCodec, + text_tx: mpsc::Sender>, + chunk_tx: mpsc::Sender>, + monitor: JoinHandle>, + blocked: Arc>>, +) { + let mut buffered_chunks = Vec::new(); + while let Some(item) = provider_stream.next().await { + let chunk = match item { + Ok(chunk) => chunk, + Err(err) => { + let _ = chunk_tx.send(Err(err)).await; + let _ = text_tx.send(None).await; + let _ = monitor.await; + return; + } + }; + + if let Some(message) = blocked_message(&blocked) { + let _ = chunk_tx.send(Err(streaming_output_blocked(message))).await; + let _ = text_tx.send(None).await; + let _ = monitor.await; + return; + } + + let text = extract_stream_text(codec, &chunk); + + if let Some(text) = text { + if text_tx.send(Some(text)).await.is_err() { + send_stream_monitor_error(monitor, &chunk_tx, &blocked).await; + return; + } + + if let Some(message) = blocked_message(&blocked) { + let _ = chunk_tx.send(Err(streaming_output_blocked(message))).await; + let _ = text_tx.send(None).await; + let _ = monitor.await; + return; + } + } + + buffered_chunks.push(chunk); + } + + let _ = text_tx.send(None).await; + if send_stream_monitor_error(monitor, &chunk_tx, &blocked).await { + return; + } + + for chunk in buffered_chunks { + if chunk_tx.send(Ok(chunk)).await.is_err() { + return; + } + } +} + +async fn send_stream_monitor_error( + monitor: JoinHandle>, + chunk_tx: &mpsc::Sender>, + blocked: &Arc>>, +) -> bool { + match monitor.await { + Ok(Ok(())) => {} + Ok(Err(err)) => { + let _ = chunk_tx.send(Err(err)).await; + return true; + } + Err(err) => { + let _ = chunk_tx + .send(Err(FlowError::Internal(format!( + "nemo_guardrails stream monitor task failed: {err}" + )))) + .await; + return true; + } + } + + if let Some(message) = blocked_message(blocked) { + let _ = chunk_tx.send(Err(streaming_output_blocked(message))).await; + return true; + } + + false +} + +fn blocked_message(blocked: &Arc>>) -> Option { + blocked.lock().ok().and_then(|guard| guard.clone()) +} + +fn streaming_output_blocked(message: String) -> FlowError { + local_violation(format!( + "NeMo Guardrails output rail blocked the LLM call: {message}" + )) +} + +fn extract_stream_text(codec: LocalGuardrailsCodec, chunk: &Json) -> Option { + let chunk = chunk.as_object()?; + match codec { + LocalGuardrailsCodec::OpenAIChat => { + let choices = chunk.get("choices")?.as_array()?; + let mut parts = vec![]; + for choice in choices { + let content = choice + .get("delta") + .and_then(Json::as_object) + .and_then(|delta| delta.get("content")) + .and_then(Json::as_str); + if let Some(content) = content + && !content.is_empty() + { + parts.push(content); + } + } + (!parts.is_empty()).then(|| parts.join("")) + } + LocalGuardrailsCodec::OpenAIResponses => { + if chunk.get("type").and_then(Json::as_str) == Some("response.output_text.delta") { + chunk + .get("delta") + .and_then(Json::as_str) + .filter(|delta| !delta.is_empty()) + .map(str::to_string) + } else { + None + } + } + LocalGuardrailsCodec::AnthropicMessages => { + if chunk.get("type").and_then(Json::as_str) != Some("content_block_delta") { + return None; + } + let delta = chunk.get("delta")?.as_object()?; + if delta.get("type").and_then(Json::as_str) != Some("text_delta") { + return None; + } + delta + .get("text") + .and_then(Json::as_str) + .filter(|text| !text.is_empty()) + .map(str::to_string) + } + } +} + +async fn monitor_guardrails_stream( + worker: Arc, + stream_id: String, + mut text_rx: mpsc::Receiver>, + mut event_rx: mpsc::UnboundedReceiver, + blocked: Arc>>, +) -> FlowResult<()> { + let mut input_closed = false; + loop { + tokio::select! { + maybe_text = text_rx.recv(), if !input_closed => { + match maybe_text { + Some(Some(text)) => worker.send_stream_text(&stream_id, text)?, + Some(None) | None => { + worker.send_stream_end(&stream_id)?; + input_closed = true; + } + } + } + maybe_event = event_rx.recv() => { + let Some(event) = maybe_event else { + worker.forget_stream(&stream_id); + return Err(FlowError::Internal( + "NeMo Guardrails local Python worker stream closed unexpectedly".to_string(), + )); + }; + if !event.ok { + worker.forget_stream(&stream_id); + return Err(FlowError::Internal(event.error.unwrap_or_else(|| { + "NeMo Guardrails local Python worker stream failed".to_string() + }))); + } + match event.event.as_deref() { + Some("blocked") => { + if let Some(message) = event.message { + let mut guard = blocked.lock().map_err(|err| { + FlowError::Internal(format!("stream block state lock poisoned: {err}")) + })?; + *guard = Some(message); + } + worker.forget_stream(&stream_id); + return Ok(()); + } + Some("done") => { + worker.forget_stream(&stream_id); + return Ok(()); + } + Some(other) => { + worker.forget_stream(&stream_id); + return Err(FlowError::Internal(format!( + "NeMo Guardrails local Python worker returned unknown stream event '{other}'" + ))); + } + None => {} + } + } + } + } +} + +#[cfg(test)] +#[path = "../../../tests/unit/plugins/nemo_guardrails/local_python_tests.rs"] +mod tests; diff --git a/crates/core/tests/unit/observability/atof_tests.rs b/crates/core/tests/unit/observability/atof_tests.rs index 47ba4dd8..e0bd4f96 100644 --- a/crates/core/tests/unit/observability/atof_tests.rs +++ b/crates/core/tests/unit/observability/atof_tests.rs @@ -300,7 +300,7 @@ fn start_http_capture_server(expected_requests: usize) -> (String, Arc { - assert!(message.contains("local backend")); - } - other => panic!("unexpected error: {other}"), - } -} - #[test] fn enabled_unknown_mode_initialization_fails_fast_when_policy_ignores_validation() { let _guard = crate::plugins::nemo_guardrails::test_mutex() diff --git a/crates/core/tests/unit/plugins/nemo_guardrails/local_python_tests.rs b/crates/core/tests/unit/plugins/nemo_guardrails/local_python_tests.rs new file mode 100644 index 00000000..7ef5dbee --- /dev/null +++ b/crates/core/tests/unit/plugins/nemo_guardrails/local_python_tests.rs @@ -0,0 +1,396 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#[cfg(unix)] +use std::fs; +#[cfg(unix)] +use std::os::unix::fs::PermissionsExt; +#[cfg(unix)] +use std::path::{Path, PathBuf}; +#[cfg(unix)] +use std::process::Command; +#[cfg(unix)] +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; + +use serde_json::json; + +use super::*; +#[cfg(unix)] +use crate::plugins::nemo_guardrails::component::LocalBackendConfig; + +#[cfg(unix)] +static NEXT_FIXTURE_ID: AtomicUsize = AtomicUsize::new(1); + +#[cfg(unix)] +struct FakeGuardrails { + root: PathBuf, + module_name: String, + python: PathBuf, +} + +#[cfg(unix)] +impl FakeGuardrails { + fn new(version: &str) -> Self { + let id = NEXT_FIXTURE_ID.fetch_add(1, Ordering::Relaxed); + let module_name = format!("fake_guardrails_{id}"); + let root = std::env::temp_dir().join(format!( + "nemo_relay_fake_guardrails_{}_{}", + std::process::id(), + id + )); + let package = root.join(&module_name); + fs::create_dir_all(package.join("rails/llm")).unwrap(); + fs::write(package.join("rails/__init__.py"), "").unwrap(); + fs::write(package.join("rails/llm/__init__.py"), "").unwrap(); + fs::write(package.join("rails/llm/options.py"), fake_options_module()).unwrap(); + fs::write(package.join("__init__.py"), fake_root_module(version)).unwrap(); + + let python = root.join("python-wrapper"); + fs::write( + &python, + format!( + "#!/bin/sh\nPYTHONPATH='{}' exec python3 \"$@\"\n", + shell_single_quote(&root) + ), + ) + .unwrap(); + let mut permissions = fs::metadata(&python).unwrap().permissions(); + permissions.set_mode(0o755); + fs::set_permissions(&python, permissions).unwrap(); + + Self { + root, + module_name, + python, + } + } + + fn config(&self) -> NeMoGuardrailsConfig { + NeMoGuardrailsConfig { + mode: "local".to_string(), + codec: Some("openai_chat".to_string()), + config_yaml: Some("models: []".to_string()), + colang_content: Some("define flow noop\n pass".to_string()), + local: Some(LocalBackendConfig { + python_module: Some(self.module_name.clone()), + python_executable: Some(self.python.to_string_lossy().into_owned()), + }), + ..NeMoGuardrailsConfig::default() + } + } +} + +#[cfg(unix)] +impl Drop for FakeGuardrails { + fn drop(&mut self) { + let _ = fs::remove_dir_all(&self.root); + } +} + +#[cfg(unix)] +fn python3_available() -> bool { + Command::new("python3") + .arg("--version") + .output() + .map(|output| output.status.success()) + .unwrap_or(false) +} + +#[cfg(unix)] +fn shell_single_quote(path: &Path) -> String { + path.to_string_lossy().replace('\'', "'\\''") +} + +#[cfg(unix)] +fn fake_options_module() -> &'static str { + r#" +class RailType: + INPUT = "input" + OUTPUT = "output" + +class RailStatus: + BLOCKED = "blocked" + MODIFIED = "modified" + PASSED = "passed" +"# +} + +#[cfg(unix)] +fn fake_root_module(version: &str) -> String { + format!( + r#" +import json +import types +from .rails.llm.options import RailStatus + +__version__ = {version:?} + +class Result: + def __init__(self, status, content=None, rail=None): + self.status = status + self.content = content + self.rail = rail + +class RailsConfig: + @staticmethod + def from_content(*, colang_content=None, yaml_content=None): + stream_first = "stream_first_false" not in (yaml_content or "") + flows = [] if "no_stream" in (yaml_content or "") else ["self check output"] + return types.SimpleNamespace( + yaml=yaml_content, + colang=colang_content, + rails=types.SimpleNamespace( + output=types.SimpleNamespace( + flows=flows, + streaming=types.SimpleNamespace(enabled=True, stream_first=stream_first), + ) + ) + ) + + @staticmethod + def from_path(path): + return types.SimpleNamespace( + path=path, + rails=types.SimpleNamespace( + output=types.SimpleNamespace( + flows=["self check output"], + streaming=types.SimpleNamespace(enabled=True, stream_first=True), + ) + ) + ) + +class LLMRails: + def __init__(self, config): + self.config = config + + async def check_async(self, messages, rail_types=None): + content = " ".join(str(message.get("content", "")) for message in messages) + if "block" in content: + return Result(RailStatus.BLOCKED, "", "policy") + if "modify-tool" in content: + return Result(RailStatus.MODIFIED, '{{"arguments":{{"safe":true}},"result":{{"ok":true}}}}') + if "modify" in content: + return Result(RailStatus.MODIFIED, "rewritten") + return Result(RailStatus.PASSED, "") + + async def stream_async(self, *, messages=None, generator=None, include_metadata=False): + async for text in generator: + if "stream-block" in text: + yield json.dumps({{"error": {{"type": "guardrails_violation", "message": "blocked stream"}}}}) + return + yield json.dumps({{"ok": True}}) +"# + ) +} + +#[cfg(unix)] +#[tokio::test(flavor = "current_thread")] +async fn bridge_checks_pass_block_and_modify_outcomes() { + if !python3_available() { + return; + } + + let fixture = FakeGuardrails::new("0.22.0"); + let bridge = LocalGuardrailsBridge::new(&fixture.config()).unwrap(); + + assert!(matches!( + bridge + .check( + vec![json!({"role": "user", "content": "hello"})], + LocalRailKind::Input, + ) + .await + .unwrap(), + LocalCheckOutcome::Passed + )); + + match bridge + .check( + vec![json!({"role": "user", "content": "block this"})], + LocalRailKind::Input, + ) + .await + .unwrap() + { + LocalCheckOutcome::Blocked { rail } => assert_eq!(rail.as_deref(), Some("policy")), + _ => panic!("expected blocked outcome"), + } + + match bridge + .check( + vec![json!({"role": "user", "content": "modify this"})], + LocalRailKind::Input, + ) + .await + .unwrap() + { + LocalCheckOutcome::Modified { content } => assert_eq!(content, "rewritten"), + _ => panic!("expected modified outcome"), + } +} + +#[cfg(unix)] +#[test] +fn bridge_rejects_unsupported_guardrails_version() { + if !python3_available() { + return; + } + + let fixture = FakeGuardrails::new("0.21.0"); + let error = match LocalGuardrailsBridge::new(&fixture.config()) { + Ok(_) => panic!("expected unsupported version error"), + Err(error) => error, + }; + assert!(error.to_string().contains("nemoguardrails==0.22.0")); +} + +#[cfg(unix)] +#[tokio::test(flavor = "current_thread")] +async fn streaming_support_rejects_stream_first_false() { + if !python3_available() { + return; + } + + let fixture = FakeGuardrails::new("0.22.0"); + let mut config = fixture.config(); + config.config_yaml = Some("stream_first_false".to_string()); + let bridge = LocalGuardrailsBridge::new(&config).unwrap(); + + assert!(bridge.has_streaming_output_rails().await.unwrap()); + let error = bridge + .ensure_streaming_output_supported() + .await + .unwrap_err(); + assert!(error.to_string().contains("stream_first = true")); +} + +#[cfg(unix)] +#[tokio::test(flavor = "current_thread")] +async fn stream_monitor_records_blocked_message() { + if !python3_available() { + return; + } + + let fixture = FakeGuardrails::new("0.22.0"); + let bridge = LocalGuardrailsBridge::new(&fixture.config()).unwrap(); + let (text_tx, text_rx) = mpsc::channel(8); + let blocked = Arc::new(Mutex::new(None)); + let monitor = bridge + .spawn_stream_monitor( + vec![json!({"role": "user", "content": "hello"})], + text_rx, + Arc::clone(&blocked), + ) + .unwrap(); + + text_tx + .send(Some("stream-block".to_string())) + .await + .unwrap(); + text_tx.send(None).await.unwrap(); + monitor.await.unwrap().unwrap(); + + assert_eq!(blocked.lock().unwrap().as_deref(), Some("blocked stream")); +} + +#[tokio::test(flavor = "current_thread")] +async fn guarded_provider_stream_reports_block_before_buffered_chunks() { + let provider_stream: LlmJsonStream = Box::pin(tokio_stream::iter(vec![Ok(json!({ + "choices": [{"delta": {"content": "blocked"}}], + }))])); + let (text_tx, mut text_rx) = mpsc::channel::>(8); + let (chunk_tx, mut chunk_rx) = mpsc::channel(8); + let blocked = Arc::new(Mutex::new(None)); + let monitor_blocked = Arc::clone(&blocked); + let monitor = tokio::spawn(async move { + while let Some(item) = text_rx.recv().await { + match item { + Some(text) if text.contains("blocked") => { + *monitor_blocked.lock().unwrap() = Some("blocked stream".to_string()); + } + Some(_) => {} + None => break, + } + } + Ok(()) + }); + + forward_guarded_provider_stream( + provider_stream, + LocalGuardrailsCodec::OpenAIChat, + text_tx, + chunk_tx, + monitor, + blocked, + ) + .await; + + let error = chunk_rx.recv().await.unwrap().unwrap_err(); + assert!( + error.to_string().contains("blocked stream"), + "unexpected error: {error}" + ); + assert!(chunk_rx.recv().await.is_none()); +} + +#[test] +fn parse_check_result_rejects_unknown_status() { + assert!(matches!( + parse_check_result(json!({"status": "passed"})).unwrap(), + LocalCheckOutcome::Passed + )); + + let error = match parse_check_result(json!({"status": "surprising"})) { + Ok(_) => panic!("expected unknown status to fail"), + Err(error) => error, + }; + assert!( + error + .to_string() + .contains("unexpected worker check status: surprising"), + "unexpected error: {error}" + ); +} + +#[test] +fn modified_tool_payload_rejects_malformed_content() { + let error = modified_tool_payload("not-json", "arguments").unwrap_err(); + assert!( + error + .to_string() + .contains("modified tool arguments content that is not valid JSON") + ); + + let error = modified_tool_payload(r#"{"tool_name":"demo"}"#, "result").unwrap_err(); + assert!( + error + .to_string() + .contains("modified tool result content without a 'result' field") + ); +} + +#[test] +fn stream_text_extraction_handles_supported_codecs() { + assert_eq!( + extract_stream_text( + LocalGuardrailsCodec::OpenAIChat, + &json!({"choices": [{"delta": {"content": "hel"}}, {"delta": {"content": "lo"}}]}) + ), + Some("hello".to_string()) + ); + assert_eq!( + extract_stream_text( + LocalGuardrailsCodec::OpenAIResponses, + &json!({"type": "response.output_text.delta", "delta": "hello"}) + ), + Some("hello".to_string()) + ); + assert_eq!( + extract_stream_text( + LocalGuardrailsCodec::AnthropicMessages, + &json!({"type": "content_block_delta", "delta": {"type": "text_delta", "text": "hello"}}) + ), + Some("hello".to_string()) + ); +} diff --git a/crates/python/src/lib.rs b/crates/python/src/lib.rs index d11df353..0c7c1998 100644 --- a/crates/python/src/lib.rs +++ b/crates/python/src/lib.rs @@ -23,6 +23,7 @@ use nemo_relay::shared_runtime::initialize_shared_runtime_binding; use nemo_relay_adaptive::plugin_component::register_adaptive_component; use pyo3::prelude::*; +use pyo3::types::PyModule; mod convert; #[doc(hidden)] @@ -62,3 +63,7 @@ fn _native(m: &Bound<'_, PyModule>) -> PyResult<()> { #[cfg(test)] #[path = "../tests/coverage/coverage_tests.rs"] mod coverage_tests; + +#[cfg(test)] +#[path = "../tests/coverage/nemo_guardrails_coverage_tests.rs"] +mod nemo_guardrails_coverage_tests; diff --git a/crates/python/src/py_plugin.rs b/crates/python/src/py_plugin.rs index d483375b..09209b03 100644 --- a/crates/python/src/py_plugin.rs +++ b/crates/python/src/py_plugin.rs @@ -29,7 +29,8 @@ use nemo_relay::api::subscriber::{deregister_subscriber, register_subscriber}; use nemo_relay::plugin::{ ConfigDiagnostic, DiagnosticLevel, Plugin, PluginConfig, PluginError, PluginRegistration, PluginRegistrationContext, active_plugin_report, clear_plugin_configuration, deregister_plugin, - initialize_plugins, list_plugin_kinds, register_plugin, validate_plugin_config, + initialize_plugins, list_plugin_kinds, register_plugin, rollback_registrations, + validate_plugin_config, }; use crate::convert::{json_to_py, py_to_json}; @@ -160,6 +161,34 @@ fn new_py_plugin_context( ) } +pub(crate) fn invoke_python_plugin_register( + py: Python<'_>, + plugin_kind: &str, + register_fn: &Bound<'_, PyAny>, + plugin_config: &Map, + namespace_prefix: String, +) -> PyResult> { + let py_ctx = new_py_plugin_context( + py, + plugin_kind, + Arc::new(Mutex::new(vec![])), + namespace_prefix, + )?; + let plugin_config_py = plugin_config_to_py(py, plugin_kind, plugin_config)?; + match register_fn.call1((plugin_config_py, py_ctx.clone_ref(py))) { + Ok(_) => { + let py_ctx_ref = py_ctx.bind(py).borrow(); + py_ctx_ref.drain_registrations() + } + Err(err) => { + if let Ok(mut registrations) = py_ctx.bind(py).borrow().drain_registrations() { + rollback_registrations(&mut registrations); + } + Err(err) + } + } +} + #[pyclass(name = "PluginContext")] pub struct PyPluginContext { registrations: Arc>>, @@ -695,22 +724,14 @@ impl Plugin for PyPlugin { let plugin_config = plugin_config.clone(); Box::pin(async move { let registrations = Python::attach(|py| -> PyResult> { - let py_ctx = new_py_plugin_context( + let register_fn = self.plugin.getattr(py, "register")?.into_bound(py); + invoke_python_plugin_register( py, &self.plugin_kind, - Arc::new(Mutex::new(vec![])), + ®ister_fn, + &plugin_config, namespace_prefix, - )?; - let plugin_config_py = json_to_py(py, &Json::Object(plugin_config.clone()))?; - self.plugin.call_method1( - py, - "register", - (plugin_config_py, py_ctx.clone_ref(py)), - )?; - { - let py_ctx_ref = py_ctx.bind(py).borrow(); - py_ctx_ref.drain_registrations() - } + ) }) .map_err(|err| PluginError::RegistrationFailed(err.to_string()))?; diff --git a/crates/python/tests/coverage/nemo_guardrails_coverage_tests.rs b/crates/python/tests/coverage/nemo_guardrails_coverage_tests.rs new file mode 100644 index 00000000..00805085 --- /dev/null +++ b/crates/python/tests/coverage/nemo_guardrails_coverage_tests.rs @@ -0,0 +1,913 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Coverage tests for Python-facing local NeMo Guardrails integration. + +use std::ffi::CString; +use std::fs; +use std::panic::{AssertUnwindSafe, catch_unwind}; +use std::path::{Path, PathBuf}; +use std::process::{self, Command, Stdio}; +use std::sync::{ + Mutex, + atomic::{AtomicUsize, Ordering}, +}; + +use nemo_relay::api::runtime::{NemoRelayContextState, global_context}; +use nemo_relay::plugin::{ + PluginComponentSpec, PluginConfig, clear_plugin_configuration, initialize_plugins, +}; +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyModule}; +use serde_json::json; + +static NEXT_FAKE_GUARDRAILS_ID: AtomicUsize = AtomicUsize::new(1); +static SERIAL_TEST_MUTEX: Mutex<()> = Mutex::new(()); + +fn load_module<'py>(py: Python<'py>, code: &str) -> Bound<'py, PyModule> { + let code = CString::new(code).unwrap(); + let file_name = CString::new("nemo_guardrails_coverage_tests.py").unwrap(); + let module_name = CString::new("nemo_guardrails_coverage_tests").unwrap(); + PyModule::from_code(py, &code, &file_name, &module_name).unwrap() +} + +fn python_package_dir() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../python") +} + +struct FakeGuardrailsPackage { + root: PathBuf, + module_name: String, + python_executable: PathBuf, +} + +impl FakeGuardrailsPackage { + fn new(py: Python<'_>, module_name: &str, version: &str, implementation: &str) -> Self { + let id = NEXT_FAKE_GUARDRAILS_ID.fetch_add(1, Ordering::Relaxed); + let root = std::env::temp_dir().join(format!( + "nemo_relay_python_fake_guardrails_{}_{}", + process::id(), + id + )); + let package = root.join(module_name); + fs::create_dir_all(package.join("rails/llm")).unwrap(); + fs::write(package.join("rails/__init__.py"), "").unwrap(); + fs::write(package.join("rails/llm/__init__.py"), "").unwrap(); + fs::write(package.join("rails/llm/options.py"), fake_options_module()).unwrap(); + fs::write( + package.join("__init__.py"), + fake_root_module(version, implementation), + ) + .unwrap(); + + let python_executable = write_python_wrapper(&root, &python_executable_for_worker(py)); + + Self { + root, + module_name: module_name.to_string(), + python_executable, + } + } +} + +fn python_executable_for_worker(py: Python<'_>) -> String { + let executable = py + .import("sys") + .and_then(|sys| sys.getattr("executable")) + .and_then(|executable| executable.extract::()) + .unwrap_or_else(|_| "python3".to_string()); + if Command::new(&executable) + .arg("-c") + .arg("import sys") + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .status() + .map(|status| status.success()) + .unwrap_or(false) + { + executable + } else { + "python3".to_string() + } +} + +impl Drop for FakeGuardrailsPackage { + fn drop(&mut self) { + let _ = fs::remove_dir_all(&self.root); + } +} + +fn fake_guardrails_module_prelude( + module_name: &str, + python_dir: &str, + python_executable: &str, +) -> String { + format!( + r#" +import sys + +sys.path.insert(0, {python_dir:?}) + +MODULE_NAME = {module_name:?} +PYTHON_EXECUTABLE = {python_executable:?} +"#, + python_dir = python_dir, + module_name = module_name, + python_executable = python_executable, + ) +} + +fn fake_options_module() -> &'static str { + r#" +class RailType: + INPUT = "input" + OUTPUT = "output" + +class RailStatus: + BLOCKED = "blocked" + MODIFIED = "modified" + PASSED = "passed" +"# +} + +fn fake_root_module(version: &str, implementation: &str) -> String { + format!( + r#" +import types +from .rails.llm.options import RailStatus + +__version__ = {version:?} + +class Result: + def __init__(self, status, content=None, rail=None): + self.status = status + self.content = content + self.rail = rail + +class RailsConfig: + @staticmethod + def from_content(*, colang_content=None, yaml_content=None): + return {{"yaml": yaml_content, "colang": colang_content}} + + @staticmethod + def from_path(path): + return {{"path": path}} + +{implementation} +"# + ) +} + +fn check_sequence_guardrails() -> &'static str { + r#" +class LLMRails: + def __init__(self, config): + self.config = config + self._check_results = [ + Result(RailStatus.MODIFIED, content="sanitized user"), + Result(RailStatus.BLOCKED, rail="output-policy"), + Result(RailStatus.MODIFIED, content='{"arguments": {"city": "Boston"}}'), + Result(RailStatus.MODIFIED, content='{"result": {"ok": true}}'), + ] + + async def check_async(self, messages, rail_types): + return self._check_results.pop(0) +"# +} + +fn tool_sequence_guardrails() -> &'static str { + r#" +class LLMRails: + def __init__(self, config): + self.config = config + self._check_results = [ + Result(RailStatus.MODIFIED, content="sanitized user"), + Result(RailStatus.PASSED), + Result(RailStatus.MODIFIED, content='{"arguments": {"city": "Boston"}}'), + Result(RailStatus.MODIFIED, content='{"result": {"ok": true}}'), + ] + + async def check_async(self, messages, rail_types): + return self._check_results.pop(0) +"# +} + +fn streaming_guardrails() -> &'static str { + r#" +class LLMRails: + def __init__(self, config): + yaml = str(config.get("yaml", "")) + stream_first = "stream_first_false" not in yaml + self.config = types.SimpleNamespace( + rails=types.SimpleNamespace( + output=types.SimpleNamespace( + flows=["self check output"], + streaming=types.SimpleNamespace(enabled=True, stream_first=stream_first), + ) + ) + ) + self._stream_calls = 0 + + async def check_async(self, messages, rail_types): + return Result(RailStatus.PASSED) + + def stream_async(self, *, messages=None, generator=None, include_metadata=False): + async def _run(): + self._stream_calls += 1 + call_index = self._stream_calls + async for chunk in generator: + if call_index == 1: + yield chunk + if call_index > 1: + yield '{"error": {"message": "Blocked by output rails: output-policy", "type": "guardrails_violation"}}' + return _run() +"# +} + +#[cfg(unix)] +fn write_python_wrapper(root: &Path, python_executable: &str) -> PathBuf { + use std::os::unix::fs::PermissionsExt; + + let wrapper = root.join("python-wrapper"); + fs::write( + &wrapper, + format!( + "#!/bin/sh\nPYTHONPATH='{}' exec '{}' \"$@\"\n", + shell_single_quote(root), + shell_single_quote(Path::new(python_executable)) + ), + ) + .unwrap(); + let mut permissions = fs::metadata(&wrapper).unwrap().permissions(); + permissions.set_mode(0o755); + fs::set_permissions(&wrapper, permissions).unwrap(); + wrapper +} + +#[cfg(windows)] +fn write_python_wrapper(root: &Path, python_executable: &str) -> PathBuf { + let wrapper = root.join("python-wrapper.cmd"); + fs::write( + &wrapper, + format!( + "@echo off\r\nset \"PYTHONPATH={};%PYTHONPATH%\"\r\n\"{}\" %*\r\n", + root.display(), + python_executable.replace('"', "\"\"") + ), + ) + .unwrap(); + wrapper +} + +#[cfg(unix)] +fn shell_single_quote(path: &Path) -> String { + path.to_string_lossy().replace('\'', "'\\''") +} + +fn with_isolated_nemo_relay_modules( + py: Python<'_>, + native_module: &Bound<'_, PyModule>, + f: impl FnOnce() -> T, +) -> T { + let _serial_guard = SERIAL_TEST_MUTEX.lock().unwrap(); + let sys = py.import("sys").unwrap(); + let modules = sys + .getattr("modules") + .unwrap() + .cast_into::() + .unwrap(); + let saved_modules = modules + .iter() + .filter_map(|(name, module)| { + let name = name.extract::().ok()?; + if name == "nemo_relay" || name.starts_with("nemo_relay.") { + Some((name, module.unbind())) + } else { + None + } + }) + .collect::>(); + + clear_nemo_relay_modules(&modules); + modules + .set_item("nemo_relay._native", native_module.clone()) + .unwrap(); + + let result = catch_unwind(AssertUnwindSafe(f)); + + clear_nemo_relay_modules(&modules); + for (name, module) in saved_modules { + modules.set_item(name, module).unwrap(); + } + reset_runtime_state(); + + match result { + Ok(value) => value, + Err(payload) => std::panic::resume_unwind(payload), + } +} + +fn clear_nemo_relay_modules(modules: &Bound<'_, PyDict>) { + let module_names = modules + .iter() + .filter_map(|(name, _)| name.extract::().ok()) + .filter(|name| name == "nemo_relay" || name.starts_with("nemo_relay.")) + .collect::>(); + + for name in module_names { + modules.del_item(name).unwrap(); + } +} + +fn with_event_loop(py: Python<'_>, f: impl FnOnce(Bound<'_, PyAny>) -> T) -> T { + let asyncio = py.import("asyncio").unwrap(); + #[cfg(windows)] + { + let policy = asyncio + .getattr("WindowsSelectorEventLoopPolicy") + .unwrap() + .call0() + .unwrap(); + asyncio + .call_method1("set_event_loop_policy", (policy,)) + .unwrap(); + } + let event_loop = asyncio.call_method0("new_event_loop").unwrap(); + asyncio + .call_method1("set_event_loop", (&event_loop,)) + .unwrap(); + let result = catch_unwind(AssertUnwindSafe(|| f(event_loop.clone().into_any()))); + asyncio + .call_method1("set_event_loop", (py.None(),)) + .unwrap(); + event_loop.call_method0("close").unwrap(); + #[cfg(windows)] + asyncio + .call_method1("set_event_loop_policy", (py.None(),)) + .unwrap(); + match result { + Ok(value) => value, + Err(payload) => std::panic::resume_unwind(payload), + } +} + +fn reset_runtime_state() { + let _ = clear_plugin_configuration(); + let context = global_context(); + *context.write().unwrap() = NemoRelayContextState::new(); +} + +#[test] +fn test_native_pymodule_entrypoint_registers_bindings_without_local_provider_install() { + let _python = crate::test_support::init_python_test(); + let _serial_guard = SERIAL_TEST_MUTEX.lock().unwrap(); + reset_runtime_state(); + Python::attach(|py| { + let module = PyModule::new(py, "_native_guardrails_provider").unwrap(); + crate::_native(&module).unwrap(); + }); + + let runtime = tokio::runtime::Runtime::new().unwrap(); + let error = runtime + .block_on(initialize_plugins(PluginConfig { + version: 1, + components: vec![PluginComponentSpec { + kind: "nemo_guardrails".to_string(), + enabled: true, + config: serde_json::from_value(json!({ + "mode": "local", + "codec": "openai_chat", + "config_path": "./rails" + })) + .unwrap(), + }], + policy: Default::default(), + })) + .unwrap_err(); + + reset_runtime_state(); + match error { + nemo_relay::plugin::PluginError::RegistrationFailed(message) => { + assert!( + message.contains( + "NeMo Guardrails is required for the built-in NeMo Guardrails local backend" + ), + "unexpected message: {message}" + ); + } + other => panic!("unexpected error: {other}"), + } +} + +#[test] +fn test_guardrails_local_runtime_enforces_llm_input_and_output_checks() { + let _python = crate::test_support::init_python_test(); + reset_runtime_state(); + + Python::attach(|py| { + let native_module = PyModule::new(py, "_native_guardrails_local_runtime").unwrap(); + crate::_native(&native_module).unwrap(); + + with_isolated_nemo_relay_modules(py, &native_module, || { + let fake = FakeGuardrailsPackage::new( + py, + "fake_guardrails_local_runtime", + "0.22.0", + check_sequence_guardrails(), + ); + let python_dir = python_package_dir(); + let prelude = fake_guardrails_module_prelude( + &fake.module_name, + &python_dir.display().to_string(), + &fake.python_executable.display().to_string(), + ); + let module = load_module( + py, + &format!( + r#" +{prelude} + +import nemo_relay + +async def run_case(): + stack = nemo_relay.create_scope_stack() + nemo_relay.set_thread_scope_stack(stack) + await nemo_relay.plugin.initialize( + {{ + "version": 1, + "components": [ + {{ + "kind": "nemo_guardrails", + "enabled": True, + "config": {{ + "mode": "local", + "codec": "openai_chat", + "config_yaml": "models: []", + "input": True, + "output": True, + "tool_input": True, + "tool_output": True, + "local": {{ + "python_module": MODULE_NAME, + "python_executable": PYTHON_EXECUTABLE, + }}, + }}, + }} + ], + }} + ) + + request = nemo_relay.LLMRequest( + {{}}, + {{ + "model": "gpt-4o-mini", + "messages": [{{"role": "user", "content": "unsafe"}}], + }}, + ) + seen_request_messages = [] + + async def next_call(req): + seen_request_messages.append(req.content["messages"][-1]["content"]) + return {{ + "choices": [{{"message": {{"role": "assistant", "content": "safe reply"}}}}], + "id": "resp_1", + "model": "gpt-4o-mini", + }} + + try: + await nemo_relay.llm.execute( + "demo", + request, + next_call, + response_codec=nemo_relay.codecs.OpenAIChatCodec(), + ) + except RuntimeError as error: + llm_error = str(error) + else: + raise AssertionError("expected output rail block") + + return {{ + "llm_error": llm_error, + "seen_request_messages": seen_request_messages, + }} +"#, + prelude = prelude, + ), + ); + + let result_json = with_event_loop(py, |event_loop| { + let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); + let result = event_loop + .call_method1("run_until_complete", (coroutine,)) + .unwrap(); + crate::convert::py_to_json(&result).unwrap() + }); + + assert_eq!( + result_json["seen_request_messages"][0], + json!("sanitized user") + ); + assert!( + result_json["llm_error"] + .as_str() + .unwrap() + .contains("output rail blocked the LLM call"), + "unexpected error: {}", + result_json["llm_error"] + ); + assert!( + result_json["llm_error"] + .as_str() + .unwrap() + .contains("output-policy"), + "unexpected error: {}", + result_json["llm_error"] + ); + }); + }); + + reset_runtime_state(); +} + +#[test] +fn test_guardrails_local_runtime_rejects_unsupported_nemoguardrails_version() { + let _python = crate::test_support::init_python_test(); + reset_runtime_state(); + + Python::attach(|py| { + let native_module = PyModule::new(py, "_native_guardrails_version").unwrap(); + crate::_native(&native_module).unwrap(); + + with_isolated_nemo_relay_modules(py, &native_module, || { + let fake = FakeGuardrailsPackage::new( + py, + "fake_guardrails_bad_version", + "0.21.0", + check_sequence_guardrails(), + ); + let python_dir = python_package_dir(); + let prelude = fake_guardrails_module_prelude( + &fake.module_name, + &python_dir.display().to_string(), + &fake.python_executable.display().to_string(), + ); + let module = load_module( + py, + &format!( + r#" +{prelude} + +import nemo_relay + +async def run_case(): + await nemo_relay.plugin.initialize( + {{ + "version": 1, + "components": [ + {{ + "kind": "nemo_guardrails", + "enabled": True, + "config": {{ + "mode": "local", + "codec": "openai_chat", + "config_yaml": "models: []", + "input": True, + "local": {{ + "python_module": MODULE_NAME, + "python_executable": PYTHON_EXECUTABLE, + }}, + }}, + }} + ], + }} + ) +"#, + prelude = prelude, + ), + ); + + let error = with_event_loop(py, |event_loop| { + let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); + event_loop + .call_method1("run_until_complete", (coroutine,)) + .unwrap_err() + .to_string() + }); + + assert!( + error.contains("requires nemoguardrails==0.22.0"), + "unexpected error: {error}" + ); + assert!(error.contains("0.21.0"), "unexpected error: {error}"); + }); + }); + + reset_runtime_state(); +} + +#[test] +fn test_guardrails_local_runtime_enforces_streamed_output_rails() { + let _python = crate::test_support::init_python_test(); + reset_runtime_state(); + + Python::attach(|py| { + let native_module = PyModule::new(py, "_native_guardrails_streaming").unwrap(); + crate::_native(&native_module).unwrap(); + + with_isolated_nemo_relay_modules(py, &native_module, || { + let fake = FakeGuardrailsPackage::new( + py, + "fake_guardrails_streaming", + "0.22.0", + streaming_guardrails(), + ); + let python_dir = python_package_dir(); + let prelude = fake_guardrails_module_prelude( + &fake.module_name, + &python_dir.display().to_string(), + &fake.python_executable.display().to_string(), + ); + let module = load_module( + py, + &format!( + r#" +{prelude} + +event_log = [] + +import nemo_relay + +def plugin_config(config_yaml="models: []"): + return {{ + "version": 1, + "components": [ + {{ + "kind": "nemo_guardrails", + "enabled": True, + "config": {{ + "mode": "local", + "codec": "openai_chat", + "config_yaml": config_yaml, + "input": False, + "output": True, + "local": {{ + "python_module": MODULE_NAME, + "python_executable": PYTHON_EXECUTABLE, + }}, + }}, + }} + ], + }} + +async def run_stream(request): + collected = [] + + def next_call(req): + async def _stream(): + event_log.append("source:hello") + yield {{"choices": [{{"delta": {{"content": "hello"}}}}]}} + event_log.append("source:world") + yield {{"choices": [{{"delta": {{"content": "world"}}}}]}} + return _stream() + + stream = await nemo_relay.llm.stream_execute( + "demo", + request, + next_call, + collected.append, + lambda: {{"chunks": collected}}, + response_codec=nemo_relay.codecs.OpenAIChatCodec(), + ) + chunks = [] + async for chunk in stream: + event_log.append(f"yield:{{chunk['choices'][0]['delta']['content']}}") + chunks.append(chunk) + return chunks + +async def run_case(): + stack = nemo_relay.create_scope_stack() + nemo_relay.set_thread_scope_stack(stack) + event_log.clear() + await nemo_relay.plugin.initialize(plugin_config()) + + request = nemo_relay.LLMRequest( + {{}}, + {{ + "model": "gpt-4o-mini", + "messages": [{{"role": "user", "content": "hello"}}], + }}, + ) + + allowed_chunks = await run_stream(request) + + try: + await run_stream(request) + except RuntimeError as error: + blocked = str(error) + else: + raise AssertionError("expected streamed output block") + + nemo_relay.plugin.clear() + await nemo_relay.plugin.initialize(plugin_config("stream_first_false")) + try: + await run_stream(request) + except RuntimeError as error: + modified = str(error) + else: + raise AssertionError("expected stream_first=false error") + + return {{ + "allowed_chunks": allowed_chunks, + "blocked": blocked, + "event_log": event_log, + "modified": modified, + }} +"#, + prelude = prelude, + ), + ); + + let result = with_event_loop(py, |event_loop| { + let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); + let result = event_loop + .call_method1("run_until_complete", (coroutine,)) + .unwrap(); + crate::convert::py_to_json(&result).unwrap() + }); + assert_eq!( + result["allowed_chunks"], + json!([ + {"choices": [{"delta": {"content": "hello"}}]}, + {"choices": [{"delta": {"content": "world"}}]} + ]) + ); + let event_log = result["event_log"].as_array().unwrap(); + for expected in ["source:hello", "source:world", "yield:hello", "yield:world"] { + assert!( + event_log.iter().any(|event| event == expected), + "missing event {expected}: {event_log:?}" + ); + } + let source_hello = event_log + .iter() + .position(|event| event == "source:hello") + .unwrap(); + let source_world = event_log + .iter() + .position(|event| event == "source:world") + .unwrap(); + let yield_hello = event_log + .iter() + .position(|event| event == "yield:hello") + .unwrap(); + let yield_world = event_log + .iter() + .position(|event| event == "yield:world") + .unwrap(); + assert!(source_hello < source_world); + assert!(source_hello < yield_hello); + assert!(source_world < yield_world); + assert!(yield_hello < yield_world); + assert!( + result["blocked"] + .as_str() + .unwrap() + .contains("output rail blocked the LLM call") + ); + assert!( + result["modified"] + .as_str() + .unwrap() + .contains("stream_first = true") + ); + }); + }); + + reset_runtime_state(); +} + +#[test] +fn test_local_guardrails_provider_initializes_and_enforces_managed_core_calls() { + let _python = crate::test_support::init_python_test(); + reset_runtime_state(); + + Python::attach(|py| { + let native_module = PyModule::new(py, "_native_guardrails_e2e").unwrap(); + crate::_native(&native_module).unwrap(); + + with_isolated_nemo_relay_modules(py, &native_module, || { + let fake = FakeGuardrailsPackage::new( + py, + "fake_guardrails_local_e2e", + "0.22.0", + tool_sequence_guardrails(), + ); + let python_dir = python_package_dir(); + let prelude = fake_guardrails_module_prelude( + &fake.module_name, + &python_dir.display().to_string(), + &fake.python_executable.display().to_string(), + ); + let module = load_module( + py, + &format!( + r#" +{prelude} + +import nemo_relay + +async def run_case(): + stack = nemo_relay.create_scope_stack() + nemo_relay.set_thread_scope_stack(stack) + + await nemo_relay.plugin.initialize( + {{ + "version": 1, + "components": [ + {{ + "kind": "nemo_guardrails", + "enabled": True, + "config": {{ + "mode": "local", + "codec": "openai_chat", + "config_yaml": "models: []", + "input": True, + "output": True, + "tool_input": True, + "tool_output": True, + "local": {{ + "python_module": MODULE_NAME, + "python_executable": PYTHON_EXECUTABLE, + }}, + }}, + }} + ], + }} + ) + + request = nemo_relay.LLMRequest( + {{}}, + {{ + "model": "gpt-4o-mini", + "messages": [{{"role": "user", "content": "unsafe"}}], + }}, + ) + + seen_request_messages = [] + async def llm_impl(req): + seen_request_messages.append(req.content["messages"][-1]["content"]) + return {{ + "choices": [{{"message": {{"role": "assistant", "content": "safe reply"}}}}], + "id": "resp_1", + "model": req.content["model"], + }} + + llm_result = await nemo_relay.llm.execute( + "demo", + request, + llm_impl, + response_codec=nemo_relay.codecs.OpenAIChatCodec(), + ) + + seen_tool_args = [] + async def tool_impl(args): + seen_tool_args.append(args) + return {{"raw": True}} + + tool_result = await nemo_relay.tools.execute("weather_lookup", {{"city": "Phoenix"}}, tool_impl) + return {{ + "llm_result": llm_result, + "tool_result": tool_result, + "seen_request_messages": seen_request_messages, + "seen_tool_args": seen_tool_args, + }} +"#, + prelude = prelude, + ), + ); + let result_json = with_event_loop(py, |event_loop| { + let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); + let result = event_loop + .call_method1("run_until_complete", (coroutine,)) + .unwrap(); + crate::convert::py_to_json(&result).unwrap() + }); + + assert_eq!( + result_json["llm_result"]["choices"][0]["message"]["content"], + json!("safe reply") + ); + assert_eq!(result_json["tool_result"], json!({ "ok": true })); + assert_eq!( + result_json["seen_request_messages"][0], + json!("sanitized user") + ); + assert_eq!( + result_json["seen_tool_args"][0], + json!({ "city": "Boston" }) + ); + }); + }); + + reset_runtime_state(); +} diff --git a/crates/python/tests/coverage/py_plugin_coverage_tests.rs b/crates/python/tests/coverage/py_plugin_coverage_tests.rs index f774d5ea..dbb8a21b 100644 --- a/crates/python/tests/coverage/py_plugin_coverage_tests.rs +++ b/crates/python/tests/coverage/py_plugin_coverage_tests.rs @@ -792,6 +792,51 @@ async def initialize_plugins(module, config): }); } +#[test] +fn invoke_python_plugin_register_rolls_back_partial_registrations_on_error() { + let _python = crate::test_support::init_python_test(); + Python::attach(|py| { + let helpers = load_module( + py, + r#" +def subscriber(event): + return None + +class FailingPlugin: + def register(self, plugin_config, context): + context.register_subscriber("sub", subscriber) + raise RuntimeError("boom") +"#, + ); + + let plugin = helpers.getattr("FailingPlugin").unwrap().call0().unwrap(); + let register_fn = plugin.getattr("register").unwrap(); + let namespace_prefix = "rollback.".to_string(); + + for _ in 0..2 { + let err = invoke_python_plugin_register( + py, + "demo.rollback", + ®ister_fn, + &serde_json::Map::new(), + namespace_prefix.clone(), + ) + .unwrap_err(); + assert!(err.to_string().contains("boom"), "{err}"); + + let context = PyPluginContext { + registrations: Arc::new(Mutex::new(vec![])), + namespace_prefix: namespace_prefix.clone(), + }; + context + .register_subscriber("sub", helpers.getattr("subscriber").unwrap().unbind()) + .unwrap(); + let mut registrations = context.drain_registrations().unwrap(); + rollback_registrations(&mut registrations); + } + }); +} + #[test] fn plugin_context_lock_poisoning_covers_error_paths() { let _python = crate::test_support::init_python_test(); diff --git a/docs/about-nemo-relay/concepts/plugins.mdx b/docs/about-nemo-relay/concepts/plugins.mdx index b9c412e9..7034855a 100644 --- a/docs/about-nemo-relay/concepts/plugins.mdx +++ b/docs/about-nemo-relay/concepts/plugins.mdx @@ -171,9 +171,11 @@ The core crate also ships a built-in `nemo_guardrails` plugin component. It is the first-party Guardrails integration point that NeMo Relay owns through the shared plugin system. -The current shipped user-facing lane is the remote backend. It gives NeMo Relay -one canonical plugin kind and config shape for Guardrails-backed managed LLM -and tool checks while broader backend parity work remains separate. +The current shipped user-facing lanes are: + +- the remote backend for Guardrails-service integration +- the Python-backed local backend for `nemoguardrails` integration through a + subprocess worker Detailed Guardrails plugin configuration belongs in [NeMo Guardrails Configuration](/nemo-guardrails-plugin/configuration). diff --git a/docs/build-plugins/nemoguardrails.mdx b/docs/build-plugins/nemoguardrails.mdx index e5517612..a347c3f7 100644 --- a/docs/build-plugins/nemoguardrails.mdx +++ b/docs/build-plugins/nemoguardrails.mdx @@ -15,7 +15,6 @@ first-party `nemo_guardrails` component, see [NeMo Guardrails Plugin](/nemo-guardrails-plugin/about). - The example lives under `examples/nemoguardrails`. The single-file plugin implementation, runnable agent, and Guardrails config artifacts are under `example`. diff --git a/docs/nemo-guardrails-plugin/about.mdx b/docs/nemo-guardrails-plugin/about.mdx index aa1c6925..42c3c739 100644 --- a/docs/nemo-guardrails-plugin/about.mdx +++ b/docs/nemo-guardrails-plugin/about.mdx @@ -17,12 +17,11 @@ first-party NeMo Relay plugin. The plugin is designed around backend modes: - `remote` - - Implemented now. - Calls a Guardrails service over HTTP(S), including streaming over the same remote contract. - `local` - - Planned. - - Reserved for a future in-process Python `nemoguardrails` backend. + - Calls `nemoguardrails` through a local `python3` worker subprocess instead + of a separate Guardrails service. ## Use This Plugin When @@ -30,39 +29,41 @@ Start here when you need to: - Apply Guardrails input and output checks around managed `llm.execute(...)` calls. -- Apply Guardrails policy around managed tool execution, including the current - remote managed `tool_output` lane. +- Apply Guardrails policy around managed tool execution. - Configure Guardrails behavior through the same plugin config surface used by other first-party NeMo Relay components. -- Keep Guardrails behavior in a reusable process-level config document instead - of wiring provider-specific checks into each application call site. +- Keep Guardrails policy authoring in Guardrails-native config while NeMo Relay + owns when those checks run around managed execution. ## Current Scope -The current shipped user-facing lane is the built-in `remote` backend. +The built-in plugin currently exposes two user-facing modes: -That lane supports: +- `remote` for Guardrails-service integration over HTTP(S) +- `local` for `nemoguardrails` integration through a local Python worker -- Managed non-streaming LLM `input` checks. -- Managed non-streaming LLM `output` checks. -- Managed streaming LLM execution over the remote HTTP(S) path. -- Managed tool-result checks through `tool_output`. -- Request-time Guardrails defaults passed through to the remote backend. +Both modes support managed LLM `input` and `output`. The current mode-specific +differences are: -The current built-in remote backend does not support: +- `remote` supports `request_defaults` pass-through but does not support managed + `tool_input` +- `local` supports managed `tool_input` and broader LLM codec coverage, but it + does not support `request_defaults` -- Managed `tool_input` checks against the stock Guardrails remote contract. -- `local` mode. -- Remote managed LLM parity beyond `codec = "openai_chat"`. +The `local` backend requires a `python3 >= 3.11` executable that can import +`nemoguardrails==0.22.0`. It does not embed Python into the NeMo Relay binary. ## Managed Surfaces Versus Request Defaults -The NeMo Guardrails plugin model uses two different concepts: +Both `remote` mode and `local` mode share the same top-level plugin model, but +they do not implement every part of that model in the same way. -- Currently supported managed NeMo Relay execution surfaces in the shipped - remote backend: +At the plugin-model level, NeMo Guardrails uses two different concepts: + +- Top-level managed NeMo Relay execution surfaces: - `input` - `output` + - `tool_input` - `tool_output` - Guardrails backend request defaults: - `request_defaults.context` @@ -78,62 +79,43 @@ This distinction matters: - Managed surfaces wrap real NeMo Relay execution boundaries such as `llm.execute(...)` and `tools.execute(...)`. -- Managed surfaces let NeMo Relay enforce behavior around those boundaries. - Depending on the surface, Relay can block work, allow it, or apply managed - request or result handling before the application sees the final outcome. +- Managed surfaces give NeMo Relay an owned enforcement point around a known + runtime step. Depending on the backend and surface, Relay can block work, + allow it, or apply managed request or result handling before the application + sees the outcome. - Managed surfaces also give NeMo Relay a stable runtime boundary for its own - middleware ordering, lifecycle behavior, and observability marks. Relay knows - exactly which step is being wrapped and can attach policy and telemetry to - that step directly. -- `request_defaults` fields are forwarded to the selected Guardrails backend as - request semantics. They do not create new NeMo Relay-native execution - surfaces. -- `request_defaults` can still influence Guardrails behavior, but they do not - give NeMo Relay a new local runtime step to wrap. Relay is passing backend - options along with a request, not creating a new middleware boundary of its - own. -- `request_defaults` are also backend-contract dependent. A selected Guardrails - backend can use them when evaluating a request, but the exact effect depends - on what that backend supports. Relay is not creating a separate local - retrieval, dialog, or tool boundary just because those fields exist in the - request. - -In practice, the tradeoff is: - -- Managed surfaces give you a Relay-owned enforcement point around a known - runtime step, with Relay-owned enforcement, ordering, and marks around that - step. -- `request_defaults` give you backend-level configuration for a request, but - not a separate Relay-owned interception point, runtime boundary, or - middleware surface. - -Another way to think about it: + middleware ordering, lifecycle behavior, and observability marks. + +In practice: - Managed surfaces are places where NeMo Relay is holding the steering wheel. -- `request_defaults` are notes that NeMo Relay passes to the Guardrails backend - with a request. -Top-level `tool_input` is still part of the built-in plugin contract, but it is -not supported by the current stock-remote backend. +The forwarded request-default side is more mode-specific: + +- In `remote` mode, `request_defaults` fields are forwarded to the selected + Guardrails backend as request semantics. They do not create new NeMo + Relay-native execution surfaces. +- In `local` mode, `request_defaults` is rejected instead of passed through. -The overlap in names is important: +The overlap in names is important in `remote` mode: - Top-level `input` is a managed NeMo Relay execution surface. - `request_defaults.rails.input` is a backend pass-through option. - Top-level `output` is a managed NeMo Relay execution surface. - `request_defaults.rails.output` is a backend pass-through option. -- Top-level `tool_input` is part of the built-in plugin model, but the current - stock-remote backend rejects it. +- Top-level `tool_input` is a managed NeMo Relay execution surface in the + plugin contract. The current stock-remote backend rejects it, while the local + backend supports it. - `request_defaults.rails.tool_input` is a backend pass-through option. - Top-level `tool_output` is a managed NeMo Relay execution surface. - `request_defaults.rails.tool_output` is a backend pass-through option. In particular, `request_defaults.rails.dialog` and -`request_defaults.rails.retrieval` are simple pass-through options. They are -not separate managed middleware surfaces in NeMo Relay. +`request_defaults.rails.retrieval` are pass-through options. They are not +separate managed middleware surfaces in NeMo Relay. ## Pages - [NeMo Guardrails Configuration](/nemo-guardrails-plugin/configuration) - documents the built-in component shape, remote-mode boundaries, and current + documents the built-in component shape, mode boundaries, and the detailed support matrix. diff --git a/docs/nemo-guardrails-plugin/configuration.mdx b/docs/nemo-guardrails-plugin/configuration.mdx index ddaa1fb0..dd604995 100644 --- a/docs/nemo-guardrails-plugin/configuration.mdx +++ b/docs/nemo-guardrails-plugin/configuration.mdx @@ -10,9 +10,6 @@ SPDX-License-Identifier: Apache-2.0 */} Use this page when you want to configure the built-in NeMo Guardrails plugin component. The component kind is `nemo_guardrails`. -The current shipped user-facing backend is `mode = "remote"`. `local` remains -part of the config model, but it is not yet a finished user-facing backend. - For plugin file discovery, precedence, merge behavior, editor controls, and gateway conflict rules, see [Plugin Configuration Files](/build-plugins/plugin-configuration-files). @@ -37,32 +34,38 @@ The top-level NeMo Guardrails object contains: | `codec` | Managed LLM provider codec. | | `input` | Enables managed LLM input checks. | | `output` | Enables managed LLM output checks. | -| `tool_input` | Part of the built-in plugin model for managed tool-argument checks before execution. The current stock-remote backend rejects it. | +| `tool_input` | Enables managed tool-argument checks before execution. | | `tool_output` | Enables managed tool-result checks after execution. | | `priority` | Middleware priority for installed execution intercepts. | | `remote` | Remote backend settings. | -| `local` | Local backend settings for future local mode. | -| `request_defaults` | Default request-time Guardrails semantics passed to the backend. | +| `local` | Local backend settings. | +| `request_defaults` | Default request-time Guardrails semantics passed to the remote backend. | | `policy` | Component-local handling for unknown fields and unsupported values. | At least one managed Guardrails surface must be enabled. -## Current Remote Support +## Backend Support -The current built-in remote backend supports: +| Area | `remote` | `local` | +|---|---|---| +| Built-in component kind and config validation | Supported | Supported | +| Managed LLM `input` | Supported | Supported | +| Managed LLM `output` | Supported | Supported | +| Managed streaming LLM execution | Supported over the remote HTTP(S) contract | Supported; see [Streaming Boundary](#streaming-boundary) | +| Managed `tool_input` | Not supported against the stock Guardrails remote contract | Supported | +| Managed `tool_output` | Supported | Supported | +| `request_defaults` pass-through | Supported | Not supported | +| Codec support | `openai_chat` | `openai_chat`, `openai_responses`, `anthropic_messages` | +| Runtime availability | Any runtime that includes the remote backend | Runtimes that can start `python3 >= 3.11` with `nemoguardrails==0.22.0` installed | -| Area | Support | -|---|---| -| Built-in component kind and config validation | Supported | -| Managed LLM `input` | Supported | -| Managed LLM `output` | Supported | -| Managed streaming LLM execution over the remote HTTP(S) contract | Supported | -| Managed `tool_output` | Supported | -| Managed `tool_input` | Not supported against the stock Guardrails remote contract | -| `request_defaults` pass-through | Supported | -| `local` mode | Not implemented yet | +## Remote Mode + +Use `remote` mode when NeMo Relay should call a Guardrails service over +HTTP(S), especially when Guardrails must be shared across runtimes, used from +non-Python environments, or deployed independently from the application +process. -## Remote Requirements +### Requirements To use `mode = "remote"`, the configured `remote.endpoint` must point at a Guardrails service that NeMo Relay can reach from the running process and that @@ -73,7 +76,12 @@ Guardrails service still owns the actual policy content. In practice, NeMo Relay decides when managed checks run, while the Guardrails config decides what to block, allow, or rewrite. -## `plugins.toml` Example +### `plugins.toml` Example + +You can write this config directly in `plugins.toml`, or create and edit it +through the CLI with `nemo-relay plugins edit`. For plugin file discovery, +precedence, merge behavior, and editor controls, see +[Plugin Configuration Files](/build-plugins/plugin-configuration-files). ```toml version = 1 @@ -108,32 +116,12 @@ unknown_field = "warn" unsupported_value = "error" ``` -This example configures the built-in remote backend for a Guardrails service -that uses `codec = "openai_chat"`, managed LLM `input` and `output`, managed +This example configures the built-in remote mode for a Guardrails service that +uses `codec = "openai_chat"`, managed LLM `input` and `output`, managed `tool_output`, and request-default pass-through for backend context plus backend `input` and `output` rail selection. -In that setup, the NeMo Relay plugin chose the managed surfaces to wrap, while -the Guardrails config defined the actual blocking policy, such as rejecting -secret-seeking prompts, bypass attempts, specific blocked tokens, or -private-key-like output. - -For example, the Guardrails-side policy can look like this: - -```yaml -rails: - input: - flows: - - self check input - output: - flows: - - self check output -``` - -This Guardrails-side config defines the policy logic. The NeMo Relay plugin -config decides when those checks run. - -## Remote Mode Rules +### Rules When `mode = "remote"`: @@ -146,24 +134,24 @@ When `mode = "remote"`: ### Codec Boundary -The current built-in remote backend supports managed LLM execution only with: +The current built-in remote mode supports managed LLM execution only with: - `openai_chat` -## Managed Tool Boundary +### Managed Tool Boundary -The current remote backend supports managed `tool_output`. +The current remote mode supports managed `tool_output`. -The current remote backend rejects managed `tool_input` explicitly because the +The current remote mode rejects managed `tool_input` explicitly because the stock Guardrails remote contract does not activate pre-execution tool-call rails from externally submitted `/v1/chat/completions` history. NeMo Relay rejects `tool_input` in remote mode rather than leaving a silent non-enforcing path. -## Request Defaults +### Request Defaults `request_defaults` lets the built-in plugin pass request-time semantics through -to the selected backend. +to the selected remote backend. Supported request-default fields are: @@ -201,11 +189,12 @@ The `rails` section can include: - `tool_output` - `tool_input` -Those values are forwarded to the backend as request semantics. They do not -mean NeMo Relay owns separate managed retrieval or dialog execution surfaces. -`dialog` and `retrieval` are pass-through request options only. Likewise, -`request_defaults.rails.tool_input` is only a backend pass-through selector. It -does not make managed remote `tool_input` supported in the stock-remote lane. +Those values are forwarded to the remote backend as request semantics. They do +not mean NeMo Relay owns separate managed retrieval or dialog execution +surfaces. `dialog` and `retrieval` are pass-through request options only. +Likewise, `request_defaults.rails.tool_input` is only a backend pass-through +selector. It does not make managed remote `tool_input` supported in the +stock-remote lane. For more targeted request-time pass-through, the remote backend also forwards selectors like these: @@ -219,11 +208,7 @@ dialog = true tool_output = ["validate_tool_output"] ``` -This richer selector shape demonstrates how request-time Guardrails semantics -can be forwarded even when NeMo Relay does not own a separate native managed -surface for that category. - -## Observability +### Observability The current remote backend emits coarse backend-level marks for remote Guardrails activity: @@ -232,4 +217,134 @@ Guardrails activity: - `nemo_guardrails.remote.end` - `nemo_guardrails.remote.error` -These marks cover managed LLM remote execution and managed tool-result checks. +## Local Mode + +Use `local` mode when NeMo Relay should call `nemoguardrails` through a local +Python worker subprocess instead of a separate Guardrails service. + +### Requirements + +To use `mode = "local"`, NeMo Relay must be able to start a `python3 >= 3.11` +executable that can import `nemoguardrails==0.22.0`. + +The built-in local backend starts a Python worker process and sends Guardrails +checks over a JSON-lines protocol. Use it when the runtime has direct access to +the Python Guardrails dependency and configuration files rather than a separate +Guardrails service. Install the tested local-mode Guardrails dependency with +`pip install nemoguardrails==0.22.0`. + +The same ownership boundary still applies: + +- NeMo Relay decides when managed checks run. +- Guardrails-native config still decides what to block, allow, or rewrite. + +### `plugins.toml` Example + +You can write this config directly in `plugins.toml`, or create and edit it +through the CLI with `nemo-relay plugins edit`. For plugin file discovery, +precedence, merge behavior, and editor controls, see +[Plugin Configuration Files](/build-plugins/plugin-configuration-files). + +```toml +version = 1 + +[[components]] +kind = "nemo_guardrails" +enabled = true + +[components.config] +version = 1 +mode = "local" +codec = "openai_chat" +input = true +output = true +tool_input = true +tool_output = true +config_path = "./rails" + +[components.config.local] +python_executable = "python3" + +[components.config.policy] +unknown_component = "warn" +unknown_field = "warn" +unsupported_value = "error" +``` + +This example configures the built-in local mode for a runtime that can start +`python3`, import `nemoguardrails`, and read a native Guardrails config +directory from `./rails`. + +For example, the Guardrails-side policy can look like this: + +```yaml +rails: + input: + flows: + - self check input + output: + flows: + - self check output +``` + +This Guardrails-side config defines the policy logic. The NeMo Relay plugin +config decides when those checks run. + +### Rules + +When `mode = "local"`: + +- Exactly one of `config_path` or `config_yaml` is required. +- `colang_content` can only be used with `config_yaml`. +- `remote` settings cannot be present. +- `request_defaults` is rejected. +- `local.python_module` is optional and only needed when the runtime should + import the Guardrails dependency from a custom Python module path instead of + the default `nemoguardrails` package. +- `local.python_executable` is optional and defaults to the + `NEMO_RELAY_PYTHON` environment variable when set, otherwise `python3`. + +### Codec Boundary + +The current built-in local mode supports managed LLM execution with: + +- `openai_chat` +- `openai_responses` +- `anthropic_messages` + +### Managed Tool Boundary + +The current local mode supports both: + +- managed `tool_input` +- managed `tool_output` + +### Streaming Boundary + +The current local mode supports streaming LLM input checks before the stream +callback runs. + +When output rails are configured, the current local mode uses Guardrails-native +streaming output rails and buffers provider chunks until the local output rail +monitor clears the stream. That requires `rails.output.streaming.enabled = true` +in the Guardrails config. + +Guardrails calls the main streaming-output switch +`rails.output.streaming.stream_first`. + +When `stream_first = true`, the current local mode keeps provider-shaped chunks +buffered while Guardrails evaluates the streamed text: + +- provider chunks are not released to the caller until the monitor finishes +- if Guardrails blocks the stream, the call fails without delivering those chunks + +The current local mode does not support `rails.output.streaming.stream_first = false` +yet. That mode would require Guardrails-first chunk reconstruction: + +- Guardrails would need to evaluate streamed text before chunks are released to + the caller +- the local backend would then need to convert Guardrails-approved text back + into valid provider-shaped stream chunks + +That guarded-text-to-provider-chunk adapter does not exist yet in the current +local backend.