From 253c845fbefc8c3c62c04285f94df29ab22ecf49 Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Sun, 31 May 2026 09:21:20 -0700 Subject: [PATCH 01/20] feat: add local backend for built-in nemo guardrails Signed-off-by: Alex Fournier --- .../src/plugins/nemo_guardrails/component.rs | 20 +- .../core/src/plugins/nemo_guardrails/local.rs | 51 ++ .../nemo_guardrails/component_tests.rs | 48 +- crates/python/src/lib.rs | 83 +++ crates/python/src/py_plugin.rs | 39 +- .../python/tests/coverage/coverage_tests.rs | 648 +++++++++++++++++- docs/about-nemo-relay/concepts/plugins.mdx | 7 +- docs/build-plugins/nemoguardrails.mdx | 1 - docs/nemo-guardrails-plugin/about.mdx | 108 ++- docs/nemo-guardrails-plugin/configuration.mdx | 205 ++++-- python/nemo_relay/_guardrails_local.py | 589 ++++++++++++++++ 11 files changed, 1654 insertions(+), 145 deletions(-) create mode 100644 crates/core/src/plugins/nemo_guardrails/local.rs create mode 100644 python/nemo_relay/_guardrails_local.py diff --git a/crates/core/src/plugins/nemo_guardrails/component.rs b/crates/core/src/plugins/nemo_guardrails/component.rs index 13695405..28decfbe 100644 --- a/crates/core/src/plugins/nemo_guardrails/component.rs +++ b/crates/core/src/plugins/nemo_guardrails/component.rs @@ -17,9 +17,13 @@ use crate::plugin::{ register_plugin, }; +#[path = "local.rs"] +mod local; #[cfg(all(feature = "guardrails-remote", not(target_arch = "wasm32")))] #[path = "remote.rs"] mod remote; +use local::register_local_backend; +pub use local::{clear_local_backend_provider, register_local_backend_provider}; #[cfg(all(feature = "guardrails-remote", not(target_arch = "wasm32")))] use remote::register_remote_backend; @@ -447,9 +451,7 @@ fn register_nemo_guardrails_backend( ) -> PluginResult<()> { match config.mode.as_str() { "remote" => register_remote_backend(config, ctx), - "local" => Err(PluginError::RegistrationFailed( - "built-in NeMo Guardrails local backend is not implemented yet".to_string(), - )), + "local" => register_local_backend(config, ctx), other => Err(PluginError::InvalidConfig(format!( "unsupported NeMo Guardrails mode '{other}'" ))), @@ -955,6 +957,18 @@ fn validate_request_defaults( return; }; + if config.mode == "local" { + push_policy_diag( + diagnostics, + policy.unsupported_value, + "nemo_guardrails.unsupported_value", + Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()), + Some("request_defaults".to_string()), + "local mode does not currently support request_defaults".to_string(), + ); + return; + } + validate_json_object_field( diagnostics, policy, diff --git a/crates/core/src/plugins/nemo_guardrails/local.rs b/crates/core/src/plugins/nemo_guardrails/local.rs new file mode 100644 index 00000000..31f4e1c8 --- /dev/null +++ b/crates/core/src/plugins/nemo_guardrails/local.rs @@ -0,0 +1,51 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::sync::{Arc, LazyLock, Mutex, MutexGuard}; + +use crate::plugin::{PluginError, PluginRegistrationContext, Result as PluginResult}; + +use super::NeMoGuardrailsConfig; + +type LocalBackendProvider = Arc< + dyn Fn(NeMoGuardrailsConfig, &mut PluginRegistrationContext) -> PluginResult<()> + Send + Sync, +>; + +static LOCAL_BACKEND_PROVIDER: LazyLock>> = + LazyLock::new(|| Mutex::new(None)); + +fn local_backend_provider_guard() -> PluginResult>> { + LOCAL_BACKEND_PROVIDER.lock().map_err(|e| { + PluginError::Internal(format!( + "NeMo Guardrails local backend provider lock poisoned: {e}" + )) + }) +} + +#[doc(hidden)] +pub fn register_local_backend_provider(provider: LocalBackendProvider) -> PluginResult<()> { + let mut guard = local_backend_provider_guard()?; + *guard = Some(provider); + Ok(()) +} + +#[doc(hidden)] +pub fn clear_local_backend_provider() -> PluginResult<()> { + let mut guard = local_backend_provider_guard()?; + *guard = None; + Ok(()) +} + +pub(super) fn register_local_backend( + config: NeMoGuardrailsConfig, + ctx: &mut PluginRegistrationContext, +) -> PluginResult<()> { + let provider = local_backend_provider_guard()?.clone(); + + match provider { + Some(provider) => provider(config, ctx), + None => Err(PluginError::RegistrationFailed( + "built-in NeMo Guardrails local backend is unavailable in this runtime".to_string(), + )), + } +} diff --git a/crates/core/tests/unit/plugins/nemo_guardrails/component_tests.rs b/crates/core/tests/unit/plugins/nemo_guardrails/component_tests.rs index 852b8928..0823bbac 100644 --- a/crates/core/tests/unit/plugins/nemo_guardrails/component_tests.rs +++ b/crates/core/tests/unit/plugins/nemo_guardrails/component_tests.rs @@ -42,6 +42,7 @@ const TEST_TIMEOUT: Duration = Duration::from_secs(5); fn reset_runtime() { let _ = clear_plugin_configuration(); + crate::plugins::nemo_guardrails::component::clear_local_backend_provider().unwrap(); crate::shared_runtime::reset_runtime_owner_for_tests(); let context = global_context(); *context.write().unwrap() = NemoRelayContextState::new(); @@ -789,6 +790,22 @@ fn invalid_shapes_and_values_are_reported() { .any(|diag| diag.field.as_deref() == Some("local.python_module")) ); + let local_request_defaults = validate_plugin_config(&plugin_config(json!({ + "mode": "local", + "codec": "openai_chat", + "config_path": "./rails", + "request_defaults": { + "context": {"tenant": "demo"} + } + }))); + assert!(local_request_defaults.has_errors()); + assert!(local_request_defaults.diagnostics.iter().any(|diag| { + diag.field.as_deref() == Some("request_defaults") + && diag + .message + .contains("local mode does not currently support request_defaults") + })); + let invalid_request_defaults = validate_plugin_config(&plugin_config(json!({ "mode": "remote", "codec": "openai_chat", @@ -975,7 +992,7 @@ fn enabled_local_initialization_fails_fast_until_backend_exists() { match error { crate::plugin::PluginError::RegistrationFailed(message) => { - assert!(message.contains("local backend")); + assert!(message.contains("unavailable in this runtime")); } other => panic!("unexpected error: {other}"), } @@ -1007,5 +1024,34 @@ fn enabled_unknown_mode_initialization_fails_fast_when_policy_ignores_validation } } +#[test] +fn enabled_local_initialization_dispatches_through_installed_provider() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + + let provider_called = Arc::new(AtomicBool::new(false)); + let provider_called_clone = Arc::clone(&provider_called); + crate::plugins::nemo_guardrails::component::register_local_backend_provider(Arc::new( + move |config, _ctx| { + provider_called_clone.store(true, Ordering::SeqCst); + assert_eq!(config.mode, "local"); + assert_eq!(config.config_path.as_deref(), Some("./rails")); + Ok(()) + }, + )) + .unwrap(); + + futures::executor::block_on(initialize_plugins(plugin_config(json!({ + "mode": "local", + "codec": "openai_chat", + "config_path": "./rails" + })))) + .unwrap(); + + assert!(provider_called.load(Ordering::SeqCst)); +} + #[path = "remote_tests.rs"] mod remote_tests; diff --git a/crates/python/src/lib.rs b/crates/python/src/lib.rs index d11df353..13d0c29f 100644 --- a/crates/python/src/lib.rs +++ b/crates/python/src/lib.rs @@ -20,9 +20,16 @@ //! - `py_adaptive` — Python-facing adaptive helpers (`set_latency_sensitivity`) //! - `py_plugin` — Python-facing generic plugin config/registration helpers //! - `convert` — JSON ↔ Python conversion utilities +use nemo_relay::plugin::{PluginRegistrationContext, Result as PluginResult}; +use nemo_relay::plugins::nemo_guardrails::component::{ + NeMoGuardrailsConfig, register_local_backend_provider, +}; use nemo_relay::shared_runtime::initialize_shared_runtime_binding; use nemo_relay_adaptive::plugin_component::register_adaptive_component; use pyo3::prelude::*; +use serde_json::Value as Json; +use std::path::{Path, PathBuf}; +use std::sync::Arc; mod convert; #[doc(hidden)] @@ -52,6 +59,13 @@ fn _native(m: &Bound<'_, PyModule>) -> PyResult<()> { "failed to register adaptive plugin component: {e}" )) })?; + register_local_backend_provider(Arc::new(register_python_local_guardrails_backend)).map_err( + |e| { + pyo3::exceptions::PyRuntimeError::new_err(format!( + "failed to register NeMo Guardrails local backend provider: {e}" + )) + }, + )?; py_types::register(m)?; py_api::register(m)?; py_plugin::register(m)?; @@ -59,6 +73,75 @@ fn _native(m: &Bound<'_, PyModule>) -> PyResult<()> { Ok(()) } +fn register_python_local_guardrails_backend( + config: NeMoGuardrailsConfig, + ctx: &mut PluginRegistrationContext, +) -> PluginResult<()> { + let plugin_config = match serde_json::to_value(config) { + Ok(Json::Object(config)) => config, + Ok(_) => { + return Err(nemo_relay::plugin::PluginError::Internal( + "NeMo Guardrails local config did not serialize to a JSON object".to_string(), + )); + } + Err(err) => { + return Err(nemo_relay::plugin::PluginError::Internal(format!( + "failed to serialize NeMo Guardrails local config: {err}" + ))); + } + }; + + let registrations = Python::attach(|py| { + let register_fn = load_guardrails_local_register_fn(py)?; + let namespace_prefix = ctx.qualify_name(""); + crate::py_plugin::invoke_python_plugin_register( + py, + "nemo_guardrails", + ®ister_fn, + &plugin_config, + namespace_prefix, + ) + }) + .map_err(|err| nemo_relay::plugin::PluginError::RegistrationFailed(err.to_string()))?; + + ctx.extend_registrations(registrations); + Ok(()) +} + +fn load_guardrails_local_register_fn(py: Python<'_>) -> PyResult> { + let module = match py.import("nemo_relay._guardrails_local") { + Ok(module) => module, + Err(err) => { + let source_python_dir = guardrails_local_source_python_dir(); + if !source_python_dir.exists() { + return Err(err); + } + + prepend_python_path_if_missing(py, &source_python_dir)?; + py.import("nemo_relay._guardrails_local")? + } + }; + module.getattr("register_local_backend") +} + +fn guardrails_local_source_python_dir() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../python") +} + +fn prepend_python_path_if_missing(py: Python<'_>, path: &Path) -> PyResult<()> { + let sys = py.import("sys")?; + let sys_path = sys.getattr("path")?; + let path_str = path.to_string_lossy(); + + if !sys_path.contains(path_str.as_ref())? { + // Source-tree fallback for local development and in-repo tests where the + // Python package has not been installed into the active environment yet. + sys_path.call_method1("insert", (0, path_str.as_ref()))?; + } + + Ok(()) +} + #[cfg(test)] #[path = "../tests/coverage/coverage_tests.rs"] mod coverage_tests; diff --git a/crates/python/src/py_plugin.rs b/crates/python/src/py_plugin.rs index d483375b..ee668ea1 100644 --- a/crates/python/src/py_plugin.rs +++ b/crates/python/src/py_plugin.rs @@ -160,6 +160,27 @@ fn new_py_plugin_context( ) } +pub(crate) fn invoke_python_plugin_register( + py: Python<'_>, + plugin_kind: &str, + register_fn: &Bound<'_, PyAny>, + plugin_config: &Map, + namespace_prefix: String, +) -> PyResult> { + let py_ctx = new_py_plugin_context( + py, + plugin_kind, + Arc::new(Mutex::new(vec![])), + namespace_prefix, + )?; + let plugin_config_py = plugin_config_to_py(py, plugin_kind, plugin_config)?; + register_fn.call1((plugin_config_py, py_ctx.clone_ref(py)))?; + { + let py_ctx_ref = py_ctx.bind(py).borrow(); + py_ctx_ref.drain_registrations() + } +} + #[pyclass(name = "PluginContext")] pub struct PyPluginContext { registrations: Arc>>, @@ -695,22 +716,14 @@ impl Plugin for PyPlugin { let plugin_config = plugin_config.clone(); Box::pin(async move { let registrations = Python::attach(|py| -> PyResult> { - let py_ctx = new_py_plugin_context( + let register_fn = self.plugin.getattr(py, "register")?.into_bound(py); + invoke_python_plugin_register( py, &self.plugin_kind, - Arc::new(Mutex::new(vec![])), + ®ister_fn, + &plugin_config, namespace_prefix, - )?; - let plugin_config_py = json_to_py(py, &Json::Object(plugin_config.clone()))?; - self.plugin.call_method1( - py, - "register", - (plugin_config_py, py_ctx.clone_ref(py)), - )?; - { - let py_ctx_ref = py_ctx.bind(py).borrow(); - py_ctx_ref.drain_registrations() - } + ) }) .map_err(|err| PluginError::RegistrationFailed(err.to_string()))?; diff --git a/crates/python/tests/coverage/coverage_tests.rs b/crates/python/tests/coverage/coverage_tests.rs index 6c3205e0..3e553341 100644 --- a/crates/python/tests/coverage/coverage_tests.rs +++ b/crates/python/tests/coverage/coverage_tests.rs @@ -4,11 +4,13 @@ //! Coverage tests for coverage in the NeMo Relay Python crate. use std::ffi::CString; +use std::path::PathBuf; use std::pin::Pin; use std::sync::Arc; +use pyo3::ffi::c_str; use pyo3::prelude::*; -use pyo3::types::PyModule; +use pyo3::types::{IntoPyDict, PyModule}; use serde_json::{Value as Json, json}; use tokio_stream::Stream; use tokio_stream::StreamExt; @@ -24,7 +26,13 @@ use crate::py_callable::{ }; use nemo_relay::api::event::{BaseEvent, Event, EventCategory, ScopeCategory, ScopeEvent}; use nemo_relay::api::llm::LlmRequest; -use nemo_relay::api::runtime::{LlmExecutionNextFn, LlmStreamExecutionNextFn, ToolExecutionNextFn}; +use nemo_relay::api::runtime::{ + LlmExecutionNextFn, LlmStreamExecutionNextFn, NemoRelayContextState, ToolExecutionNextFn, + global_context, +}; +use nemo_relay::plugin::{ + PluginComponentSpec, PluginConfig, clear_plugin_configuration, initialize_plugins, +}; fn load_module<'py>(py: Python<'py>, code: &str) -> Bound<'py, PyModule> { let code = CString::new(code).unwrap(); @@ -65,6 +73,13 @@ fn with_event_loop(py: Python<'_>, f: impl FnOnce(Bound<'_, PyAny>) -> T) -> result } +fn reset_runtime_state() { + let _ = clear_plugin_configuration(); + nemo_relay::plugins::nemo_guardrails::component::clear_local_backend_provider().unwrap(); + let context = global_context(); + *context.write().unwrap() = NemoRelayContextState::new(); +} + #[test] fn test_native_module_registers_types_and_api_functions() { let _python = crate::test_support::init_python_test(); @@ -94,6 +109,635 @@ fn test_native_pymodule_entrypoint_registers_bindings() { }); } +#[test] +fn test_native_pymodule_entrypoint_installs_nemo_guardrails_local_provider() { + let _python = crate::test_support::init_python_test(); + Python::attach(|py| { + let module = PyModule::new(py, "_native_guardrails_provider").unwrap(); + crate::_native(&module).unwrap(); + }); + + let runtime = tokio::runtime::Runtime::new().unwrap(); + let error = runtime + .block_on(initialize_plugins(PluginConfig { + version: 1, + components: vec![PluginComponentSpec { + kind: "nemo_guardrails".to_string(), + enabled: true, + config: serde_json::from_value(json!({ + "mode": "local", + "codec": "openai_chat", + "config_path": "./rails" + })) + .unwrap(), + }], + policy: Default::default(), + })) + .unwrap_err(); + + let _ = clear_plugin_configuration(); + match error { + nemo_relay::plugin::PluginError::RegistrationFailed(message) => { + assert!( + message.contains( + "NeMo Guardrails is required for the built-in NeMo Guardrails local backend" + ), + "unexpected message: {message}" + ); + } + other => panic!("unexpected error: {other}"), + } +} + +#[test] +fn test_guardrails_local_helper_registers_and_enforces_llm_and_tool_checks() { + let _python = crate::test_support::init_python_test(); + Python::attach(|py| { + let python_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../python"); + let module = load_module( + py, + &format!( + r#" +import pathlib +import sys +import types + +sys.path.insert(0, {python_dir:?}) + +MODULE_NAME = "fake_guardrails_local_helper" + +fake_root = types.ModuleType(MODULE_NAME) +fake_options = types.ModuleType(MODULE_NAME + ".rails.llm.options") + +class Result: + def __init__(self, status, content=None, rail=None): + self.status = status + self.content = content + self.rail = rail + +class RailType: + INPUT = "input" + OUTPUT = "output" + +class RailStatus: + BLOCKED = "blocked" + MODIFIED = "modified" + PASSED = "passed" + +class RailsConfig: + @staticmethod + def from_content(*, colang_content=None, yaml_content=None): + return {{"yaml": yaml_content, "colang": colang_content}} + + @staticmethod + def from_path(path): + return {{"path": path}} + +check_results = [] +check_calls = [] + +class LLMRails: + def __init__(self, config): + self.config = config + + async def check_async(self, messages, rail_types): + check_calls.append((messages, rail_types)) + return check_results.pop(0) + +fake_root.RailsConfig = RailsConfig +fake_root.LLMRails = LLMRails +fake_options.RailType = RailType +fake_options.RailStatus = RailStatus + +sys.modules[MODULE_NAME] = fake_root +sys.modules[MODULE_NAME + ".rails"] = types.ModuleType(MODULE_NAME + ".rails") +sys.modules[MODULE_NAME + ".rails.llm"] = types.ModuleType(MODULE_NAME + ".rails.llm") +sys.modules[MODULE_NAME + ".rails.llm.options"] = fake_options + +from nemo_relay._native import LLMRequest +from nemo_relay._guardrails_local import register_local_backend + +class Context: + def register_llm_execution_intercept(self, name, priority, callback): + self.llm = callback + + def register_llm_stream_execution_intercept(self, name, priority, callback): + self.stream = callback + + def register_tool_execution_intercept(self, name, priority, callback): + self.tool = callback + +async def run_case(): + ctx = Context() + event_log = [] + register_local_backend( + {{ + "mode": "local", + "codec": "openai_chat", + "config_yaml": "models: []", + "input": True, + "output": True, + "tool_input": True, + "tool_output": True, + "local": {{"python_module": MODULE_NAME}}, + }}, + ctx, + ) + + request = LLMRequest( + {{}}, + {{ + "model": "gpt-4o-mini", + "messages": [{{"role": "user", "content": "unsafe"}}], + }}, + ) + seen_request_messages = [] + + async def next_call(req): + seen_request_messages.append(req.content["messages"][-1]["content"]) + return {{ + "choices": [{{"message": {{"role": "assistant", "content": "safe reply"}}}}], + "id": "resp_1", + "model": "gpt-4o-mini", + }} + + check_results.extend( + [ + Result(RailStatus.MODIFIED, content="sanitized user"), + Result(RailStatus.PASSED), + ] + ) + llm_result = await ctx.llm("demo", request, next_call) + + seen_tool_args = [] + + async def next_tool(args): + seen_tool_args.append(args) + return {{"raw": True}} + + check_results.extend( + [ + Result(RailStatus.MODIFIED, content='{{"arguments": {{"city": "Boston"}}}}'), + Result(RailStatus.MODIFIED, content='{{"result": {{"ok": true}}}}'), + ] + ) + tool_result = await ctx.tool("weather_lookup", {{"city": "Phoenix"}}, next_tool) + + return {{ + "llm_result": llm_result, + "tool_result": tool_result, + "seen_request_messages": seen_request_messages, + "seen_tool_args": seen_tool_args, + "check_calls": check_calls, + }} +"#, + python_dir = python_dir.display(), + ), + ); + + let result_json = with_event_loop(py, |event_loop| { + let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); + let result = event_loop + .call_method1("run_until_complete", (coroutine,)) + .unwrap(); + crate::convert::py_to_json(&result).unwrap() + }); + + assert_eq!( + result_json["seen_request_messages"][0], + json!("sanitized user") + ); + assert_eq!(result_json["tool_result"], json!({ "ok": true })); + assert_eq!( + result_json["seen_tool_args"][0], + json!({ "city": "Boston" }) + ); + assert_eq!( + result_json["llm_result"]["choices"][0]["message"]["content"], + json!("safe reply") + ); + assert_eq!(result_json["check_calls"].as_array().unwrap().len(), 4); + }); +} + +#[test] +fn test_guardrails_local_helper_enforces_streamed_output_rails() { + let _python = crate::test_support::init_python_test(); + Python::attach(|py| { + let native_module = PyModule::new(py, "_native_guardrails_streaming").unwrap(); + crate::_native(&native_module).unwrap(); + let sys = py.import("sys").unwrap(); + let modules = sys.getattr("modules").unwrap(); + modules + .set_item("nemo_relay._native", native_module.clone()) + .unwrap(); + + let python_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../python"); + let module = load_module( + py, + &format!( + r#" +import sys +import types + +sys.path.insert(0, {python_dir:?}) + +MODULE_NAME = "fake_guardrails_streaming" + +fake_root = types.ModuleType(MODULE_NAME) +fake_options = types.ModuleType(MODULE_NAME + ".rails.llm.options") + +class Result: + def __init__(self, status, content=None, rail=None): + self.status = status + self.content = content + self.rail = rail + +class RailType: + INPUT = "input" + OUTPUT = "output" + +class RailStatus: + BLOCKED = "blocked" + MODIFIED = "modified" + PASSED = "passed" + +class RailsConfig: + @staticmethod + def from_content(*, colang_content=None, yaml_content=None): + return {{"yaml": yaml_content}} + +stream_results = [] +event_log = [] + +class LLMRails: + def __init__(self, config): + self.config = types.SimpleNamespace( + rails=types.SimpleNamespace( + output=types.SimpleNamespace( + flows=["self check output"], + streaming=types.SimpleNamespace(enabled=True, stream_first=True), + ) + ) + ) + + async def check_async(self, messages, rail_types): + return Result(RailStatus.PASSED) + + def stream_async(self, *, messages=None, generator=None, include_metadata=False): + async def _run(): + outcome = stream_results.pop(0) + async for chunk in generator: + event_log.append(f"guardrails-sees:{{chunk}}") + if outcome == "pass": + yield chunk + if outcome == "block": + yield '{{"error": {{"message": "Blocked by output rails: output-policy", "type": "guardrails_violation"}}}}' + return _run() + +fake_root.RailsConfig = RailsConfig +fake_root.LLMRails = LLMRails +fake_options.RailType = RailType +fake_options.RailStatus = RailStatus + +sys.modules[MODULE_NAME] = fake_root +sys.modules[MODULE_NAME + ".rails"] = types.ModuleType(MODULE_NAME + ".rails") +sys.modules[MODULE_NAME + ".rails.llm"] = types.ModuleType(MODULE_NAME + ".rails.llm") +sys.modules[MODULE_NAME + ".rails.llm.options"] = fake_options + +from nemo_relay._native import LLMRequest +from nemo_relay._guardrails_local import register_local_backend + +class Context: + def register_llm_execution_intercept(self, name, priority, callback): + self.llm = callback + + def register_llm_stream_execution_intercept(self, name, priority, callback): + self.stream = callback + + def register_tool_execution_intercept(self, name, priority, callback): + self.tool = callback + +async def run_case(): + ctx = Context() + event_log.clear() + register_local_backend( + {{ + "mode": "local", + "codec": "openai_chat", + "config_yaml": "models: []", + "input": False, + "output": True, + "local": {{"python_module": MODULE_NAME}}, + }}, + ctx, + ) + + request = LLMRequest( + {{}}, + {{ + "model": "gpt-4o-mini", + "messages": [{{"role": "user", "content": "hello"}}], + }}, + ) + + async def next_call(req): + async def _stream(): + event_log.append("source:hello") + yield {{"choices": [{{"delta": {{"content": "hello"}}}}]}} + event_log.append("source:world") + yield {{"choices": [{{"delta": {{"content": "world"}}}}]}} + return _stream() + + stream_results.append("pass") + allowed_stream = await ctx.stream(request, next_call) + allowed_chunks = [] + async for chunk in allowed_stream: + event_log.append(f"yield:{{chunk['choices'][0]['delta']['content']}}") + allowed_chunks.append(chunk) + + stream_results.append("block") + try: + blocked_stream = await ctx.stream(request, next_call) + async for _chunk in blocked_stream: + pass + except RuntimeError as error: + blocked = str(error) + else: + raise AssertionError("expected streamed output block") + + ctx_stream_first_false = Context() + fake_root.LLMRails = lambda config: types.SimpleNamespace( + config=types.SimpleNamespace( + rails=types.SimpleNamespace( + output=types.SimpleNamespace( + flows=["self check output"], + streaming=types.SimpleNamespace(enabled=True, stream_first=False), + ) + ) + ), + check_async=LLMRails(config).check_async, + stream_async=LLMRails(config).stream_async, + ) + register_local_backend( + {{ + "mode": "local", + "codec": "openai_chat", + "config_yaml": "models: []", + "input": False, + "output": True, + "local": {{"python_module": MODULE_NAME}}, + }}, + ctx_stream_first_false, + ) + try: + failing_stream = await ctx_stream_first_false.stream(request, next_call) + async for _chunk in failing_stream: + pass + except RuntimeError as error: + modified = str(error) + else: + raise AssertionError("expected stream_first=false error") + + return {{ + "allowed_chunks": allowed_chunks, + "blocked": blocked, + "event_log": event_log, + "modified": modified, + }} +"#, + python_dir = python_dir.display(), + ), + ); + + let result = with_event_loop(py, |event_loop| { + let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); + let result = event_loop + .call_method1("run_until_complete", (coroutine,)) + .unwrap(); + crate::convert::py_to_json(&result).unwrap() + }); + assert_eq!( + result["allowed_chunks"], + json!([ + {"choices": [{"delta": {"content": "hello"}}]}, + {"choices": [{"delta": {"content": "world"}}]} + ]) + ); + let event_log = result["event_log"].as_array().unwrap(); + assert_eq!( + &event_log[..6], + json!([ + "source:hello", + "yield:hello", + "source:world", + "yield:world", + "guardrails-sees:hello", + "guardrails-sees:world", + ]) + .as_array() + .unwrap() + ); + assert!( + result["blocked"] + .as_str() + .unwrap() + .contains("output rail blocked the LLM call") + ); + assert!( + result["modified"] + .as_str() + .unwrap() + .contains("stream_first = true") + ); + }); +} + +#[test] +fn test_local_guardrails_provider_initializes_and_enforces_managed_core_calls() { + let _python = crate::test_support::init_python_test(); + reset_runtime_state(); + + Python::attach(|py| { + let native_module = PyModule::new(py, "_native_guardrails_e2e").unwrap(); + crate::_native(&native_module).unwrap(); + let sys = py.import("sys").unwrap(); + let modules = sys.getattr("modules").unwrap(); + let module_names = py + .eval( + c_str!("list(sys.modules.keys())"), + None, + Some(&[(c_str!("sys"), sys)].into_py_dict(py).unwrap()), + ) + .unwrap() + .extract::>() + .unwrap(); + for name in module_names { + if name == "nemo_relay" || name.starts_with("nemo_relay.") { + modules.del_item(name).unwrap(); + } + } + modules + .set_item("nemo_relay._native", native_module.clone()) + .unwrap(); + + let python_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../python"); + let module = load_module( + py, + &format!( + r#" +import sys +import types + +sys.path.insert(0, {python_dir:?}) + +MODULE_NAME = "fake_guardrails_local_e2e" + +fake_root = types.ModuleType(MODULE_NAME) +fake_options = types.ModuleType(MODULE_NAME + ".rails.llm.options") + +class Result: + def __init__(self, status, content=None, rail=None): + self.status = status + self.content = content + self.rail = rail + +class RailType: + INPUT = "input" + OUTPUT = "output" + +class RailStatus: + BLOCKED = "blocked" + MODIFIED = "modified" + PASSED = "passed" + +class RailsConfig: + @staticmethod + def from_content(*, colang_content=None, yaml_content=None): + return {{"yaml": yaml_content}} + +check_results = [] + +class LLMRails: + def __init__(self, config): + self.config = config + + async def check_async(self, messages, rail_types): + return check_results.pop(0) + +fake_root.RailsConfig = RailsConfig +fake_root.LLMRails = LLMRails +fake_options.RailType = RailType +fake_options.RailStatus = RailStatus + +sys.modules[MODULE_NAME] = fake_root +sys.modules[MODULE_NAME + ".rails"] = types.ModuleType(MODULE_NAME + ".rails") +sys.modules[MODULE_NAME + ".rails.llm"] = types.ModuleType(MODULE_NAME + ".rails.llm") +sys.modules[MODULE_NAME + ".rails.llm.options"] = fake_options + +import nemo_relay + +async def run_case(): + stack = nemo_relay.create_scope_stack() + nemo_relay.set_thread_scope_stack(stack) + + await nemo_relay.plugin.initialize( + {{ + "version": 1, + "components": [ + {{ + "kind": "nemo_guardrails", + "enabled": True, + "config": {{ + "mode": "local", + "codec": "openai_chat", + "config_yaml": "models: []", + "input": True, + "output": True, + "tool_input": True, + "tool_output": True, + "local": {{"python_module": MODULE_NAME}}, + }}, + }} + ], + }} + ) + + check_results.extend( + [ + Result(RailStatus.MODIFIED, content="sanitized user"), + Result(RailStatus.PASSED), + Result(RailStatus.MODIFIED, content='{{"arguments": {{"city": "Boston"}}}}'), + Result(RailStatus.MODIFIED, content='{{"result": {{"ok": true}}}}'), + ] + ) + + request = nemo_relay.LLMRequest( + {{}}, + {{ + "model": "gpt-4o-mini", + "messages": [{{"role": "user", "content": "unsafe"}}], + }}, + ) + + seen_request_messages = [] + async def llm_impl(req): + seen_request_messages.append(req.content["messages"][-1]["content"]) + return {{ + "choices": [{{"message": {{"role": "assistant", "content": "safe reply"}}}}], + "id": "resp_1", + "model": req.content["model"], + }} + + llm_result = await nemo_relay.llm.execute( + "demo", + request, + llm_impl, + response_codec=nemo_relay.codecs.OpenAIChatCodec(), + ) + + seen_tool_args = [] + async def tool_impl(args): + seen_tool_args.append(args) + return {{"raw": True}} + + tool_result = await nemo_relay.tools.execute("weather_lookup", {{"city": "Phoenix"}}, tool_impl) + return {{ + "llm_result": llm_result, + "tool_result": tool_result, + "seen_request_messages": seen_request_messages, + "seen_tool_args": seen_tool_args, + }} +"#, + python_dir = python_dir.display(), + ), + ); + let result_json = with_event_loop(py, |event_loop| { + let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); + let result = event_loop + .call_method1("run_until_complete", (coroutine,)) + .unwrap(); + crate::convert::py_to_json(&result).unwrap() + }); + + assert_eq!( + result_json["llm_result"]["choices"][0]["message"]["content"], + json!("safe reply") + ); + assert_eq!(result_json["tool_result"], json!({ "ok": true })); + assert_eq!( + result_json["seen_request_messages"][0], + json!("sanitized user") + ); + assert_eq!( + result_json["seen_tool_args"][0], + json!({ "city": "Boston" }) + ); + }); + + reset_runtime_state(); +} + #[test] fn test_python_test_guard_restores_existing_runtime_env() { let lock = crate::test_support::lock_python_test(); diff --git a/docs/about-nemo-relay/concepts/plugins.mdx b/docs/about-nemo-relay/concepts/plugins.mdx index b9c412e9..065b4b96 100644 --- a/docs/about-nemo-relay/concepts/plugins.mdx +++ b/docs/about-nemo-relay/concepts/plugins.mdx @@ -171,9 +171,10 @@ The core crate also ships a built-in `nemo_guardrails` plugin component. It is the first-party Guardrails integration point that NeMo Relay owns through the shared plugin system. -The current shipped user-facing lane is the remote backend. It gives NeMo Relay -one canonical plugin kind and config shape for Guardrails-backed managed LLM -and tool checks while broader backend parity work remains separate. +The current shipped user-facing lanes are: + +- the remote backend for Guardrails-service integration +- the Python-backed local backend for in-process `nemoguardrails` integration Detailed Guardrails plugin configuration belongs in [NeMo Guardrails Configuration](/nemo-guardrails-plugin/configuration). diff --git a/docs/build-plugins/nemoguardrails.mdx b/docs/build-plugins/nemoguardrails.mdx index e5517612..a347c3f7 100644 --- a/docs/build-plugins/nemoguardrails.mdx +++ b/docs/build-plugins/nemoguardrails.mdx @@ -15,7 +15,6 @@ first-party `nemo_guardrails` component, see [NeMo Guardrails Plugin](/nemo-guardrails-plugin/about). - The example lives under `examples/nemoguardrails`. The single-file plugin implementation, runnable agent, and Guardrails config artifacts are under `example`. diff --git a/docs/nemo-guardrails-plugin/about.mdx b/docs/nemo-guardrails-plugin/about.mdx index aa1c6925..5c0cd2f0 100644 --- a/docs/nemo-guardrails-plugin/about.mdx +++ b/docs/nemo-guardrails-plugin/about.mdx @@ -17,12 +17,11 @@ first-party NeMo Relay plugin. The plugin is designed around backend modes: - `remote` - - Implemented now. - Calls a Guardrails service over HTTP(S), including streaming over the same remote contract. - `local` - - Planned. - - Reserved for a future in-process Python `nemoguardrails` backend. + - Calls `nemoguardrails` in process through the Python runtime instead of a + separate Guardrails service. ## Use This Plugin When @@ -30,39 +29,43 @@ Start here when you need to: - Apply Guardrails input and output checks around managed `llm.execute(...)` calls. -- Apply Guardrails policy around managed tool execution, including the current - remote managed `tool_output` lane. +- Apply Guardrails policy around managed tool execution. - Configure Guardrails behavior through the same plugin config surface used by other first-party NeMo Relay components. -- Keep Guardrails behavior in a reusable process-level config document instead - of wiring provider-specific checks into each application call site. +- Keep Guardrails policy authoring in Guardrails-native config while NeMo Relay + owns when those checks run around managed execution. ## Current Scope -The current shipped user-facing lane is the built-in `remote` backend. +The built-in plugin currently exposes two user-facing modes with +different boundaries. -That lane supports: +| Area | `remote` | `local` | +|---|---|---| +| Managed non-streaming LLM `input` | Supported | Supported | +| Managed non-streaming LLM `output` | Supported | Supported | +| Managed streaming LLM execution | Supported over the remote HTTP(S) contract | Supported for managed input checks and Guardrails-native output streaming when `rails.output.streaming.enabled = true`; with `stream_first = true`, output rails can stop the stream after some chunks have already been delivered; `stream_first = false` is not supported yet | +| Managed `tool_input` | Not supported against the stock Guardrails remote contract | Supported | +| Managed `tool_output` | Supported | Supported | +| `request_defaults` | Supported as backend pass-through request semantics | Not supported | +| Codec support | `openai_chat` | `openai_chat`, `openai_responses`, `anthropic_messages` | +| Runtime availability | Any runtime that includes the remote backend | Python-enabled runtimes that can import `nemoguardrails` | -- Managed non-streaming LLM `input` checks. -- Managed non-streaming LLM `output` checks. -- Managed streaming LLM execution over the remote HTTP(S) path. -- Managed tool-result checks through `tool_output`. -- Request-time Guardrails defaults passed through to the remote backend. - -The current built-in remote backend does not support: - -- Managed `tool_input` checks against the stock Guardrails remote contract. -- `local` mode. -- Remote managed LLM parity beyond `codec = "openai_chat"`. +The `local` backend is a Python-backed runtime feature, not a universal +cross-binding backend. Runtimes that do not install the local backend provider +report `local` mode as unavailable during plugin initialization. ## Managed Surfaces Versus Request Defaults -The NeMo Guardrails plugin model uses two different concepts: +Both `remote` mode and `local` mode share the same top-level plugin model, but +they do not implement every part of that model in the same way. + +At the plugin-model level, NeMo Guardrails uses two different concepts: -- Currently supported managed NeMo Relay execution surfaces in the shipped - remote backend: +- Top-level managed NeMo Relay execution surfaces: - `input` - `output` + - `tool_input` - `tool_output` - Guardrails backend request defaults: - `request_defaults.context` @@ -78,62 +81,43 @@ This distinction matters: - Managed surfaces wrap real NeMo Relay execution boundaries such as `llm.execute(...)` and `tools.execute(...)`. -- Managed surfaces let NeMo Relay enforce behavior around those boundaries. - Depending on the surface, Relay can block work, allow it, or apply managed - request or result handling before the application sees the final outcome. +- Managed surfaces give NeMo Relay an owned enforcement point around a known + runtime step. Depending on the backend and surface, Relay can block work, + allow it, or apply managed request or result handling before the application + sees the outcome. - Managed surfaces also give NeMo Relay a stable runtime boundary for its own - middleware ordering, lifecycle behavior, and observability marks. Relay knows - exactly which step is being wrapped and can attach policy and telemetry to - that step directly. -- `request_defaults` fields are forwarded to the selected Guardrails backend as - request semantics. They do not create new NeMo Relay-native execution - surfaces. -- `request_defaults` can still influence Guardrails behavior, but they do not - give NeMo Relay a new local runtime step to wrap. Relay is passing backend - options along with a request, not creating a new middleware boundary of its - own. -- `request_defaults` are also backend-contract dependent. A selected Guardrails - backend can use them when evaluating a request, but the exact effect depends - on what that backend supports. Relay is not creating a separate local - retrieval, dialog, or tool boundary just because those fields exist in the - request. - -In practice, the tradeoff is: - -- Managed surfaces give you a Relay-owned enforcement point around a known - runtime step, with Relay-owned enforcement, ordering, and marks around that - step. -- `request_defaults` give you backend-level configuration for a request, but - not a separate Relay-owned interception point, runtime boundary, or - middleware surface. - -Another way to think about it: + middleware ordering, lifecycle behavior, and observability marks. + +In practice: - Managed surfaces are places where NeMo Relay is holding the steering wheel. -- `request_defaults` are notes that NeMo Relay passes to the Guardrails backend - with a request. -Top-level `tool_input` is still part of the built-in plugin contract, but it is -not supported by the current stock-remote backend. +The forwarded request-default side is more mode-specific: + +- In `remote` mode, `request_defaults` fields are forwarded to the selected + Guardrails backend as request semantics. They do not create new NeMo + Relay-native execution surfaces. +- In `local` mode, `request_defaults` is rejected instead of passed through. -The overlap in names is important: +The overlap in names is important in `remote` mode: - Top-level `input` is a managed NeMo Relay execution surface. - `request_defaults.rails.input` is a backend pass-through option. - Top-level `output` is a managed NeMo Relay execution surface. - `request_defaults.rails.output` is a backend pass-through option. -- Top-level `tool_input` is part of the built-in plugin model, but the current - stock-remote backend rejects it. +- Top-level `tool_input` is a managed NeMo Relay execution surface in the + plugin contract. The current stock-remote backend rejects it, while the local + backend supports it. - `request_defaults.rails.tool_input` is a backend pass-through option. - Top-level `tool_output` is a managed NeMo Relay execution surface. - `request_defaults.rails.tool_output` is a backend pass-through option. In particular, `request_defaults.rails.dialog` and -`request_defaults.rails.retrieval` are simple pass-through options. They are -not separate managed middleware surfaces in NeMo Relay. +`request_defaults.rails.retrieval` are pass-through options. They are not +separate managed middleware surfaces in NeMo Relay. ## Pages - [NeMo Guardrails Configuration](/nemo-guardrails-plugin/configuration) - documents the built-in component shape, remote-mode boundaries, and current + documents the built-in component shape, mode boundaries, and current support matrix. diff --git a/docs/nemo-guardrails-plugin/configuration.mdx b/docs/nemo-guardrails-plugin/configuration.mdx index ddaa1fb0..b1554e1c 100644 --- a/docs/nemo-guardrails-plugin/configuration.mdx +++ b/docs/nemo-guardrails-plugin/configuration.mdx @@ -10,9 +10,6 @@ SPDX-License-Identifier: Apache-2.0 */} Use this page when you want to configure the built-in NeMo Guardrails plugin component. The component kind is `nemo_guardrails`. -The current shipped user-facing backend is `mode = "remote"`. `local` remains -part of the config model, but it is not yet a finished user-facing backend. - For plugin file discovery, precedence, merge behavior, editor controls, and gateway conflict rules, see [Plugin Configuration Files](/build-plugins/plugin-configuration-files). @@ -37,32 +34,36 @@ The top-level NeMo Guardrails object contains: | `codec` | Managed LLM provider codec. | | `input` | Enables managed LLM input checks. | | `output` | Enables managed LLM output checks. | -| `tool_input` | Part of the built-in plugin model for managed tool-argument checks before execution. The current stock-remote backend rejects it. | +| `tool_input` | Enables managed tool-argument checks before execution. | | `tool_output` | Enables managed tool-result checks after execution. | | `priority` | Middleware priority for installed execution intercepts. | | `remote` | Remote backend settings. | -| `local` | Local backend settings for future local mode. | -| `request_defaults` | Default request-time Guardrails semantics passed to the backend. | +| `local` | Local backend settings. | +| `request_defaults` | Default request-time Guardrails semantics passed to the remote backend. | | `policy` | Component-local handling for unknown fields and unsupported values. | At least one managed Guardrails surface must be enabled. -## Current Remote Support +## Backend Support -The current built-in remote backend supports: +| Area | `remote` | `local` | +|---|---|---| +| Built-in component kind and config validation | Supported | Supported | +| Managed LLM `input` | Supported | Supported | +| Managed LLM `output` | Supported | Supported | +| Managed streaming LLM execution | Supported over the remote HTTP(S) contract | Supported for managed input checks and Guardrails-native output streaming when `rails.output.streaming.enabled = true`; with `stream_first = true`, output rails can stop the stream after some chunks have already been delivered; `stream_first = false` is not supported yet | +| Managed `tool_input` | Not supported against the stock Guardrails remote contract | Supported | +| Managed `tool_output` | Supported | Supported | +| `request_defaults` pass-through | Supported | Not supported | +| Codec support | `openai_chat` | `openai_chat`, `openai_responses`, `anthropic_messages` | +| Runtime availability | Any runtime that includes the remote backend | Python-enabled runtimes that can import `nemoguardrails` | -| Area | Support | -|---|---| -| Built-in component kind and config validation | Supported | -| Managed LLM `input` | Supported | -| Managed LLM `output` | Supported | -| Managed streaming LLM execution over the remote HTTP(S) contract | Supported | -| Managed `tool_output` | Supported | -| Managed `tool_input` | Not supported against the stock Guardrails remote contract | -| `request_defaults` pass-through | Supported | -| `local` mode | Not implemented yet | +## Remote Mode + +Use `remote` mode when NeMo Relay should call a Guardrails service over +HTTP(S). -## Remote Requirements +### Requirements To use `mode = "remote"`, the configured `remote.endpoint` must point at a Guardrails service that NeMo Relay can reach from the running process and that @@ -73,7 +74,7 @@ Guardrails service still owns the actual policy content. In practice, NeMo Relay decides when managed checks run, while the Guardrails config decides what to block, allow, or rewrite. -## `plugins.toml` Example +### `plugins.toml` Example ```toml version = 1 @@ -108,32 +109,12 @@ unknown_field = "warn" unsupported_value = "error" ``` -This example configures the built-in remote backend for a Guardrails service -that uses `codec = "openai_chat"`, managed LLM `input` and `output`, managed +This example configures the built-in remote mode for a Guardrails service that +uses `codec = "openai_chat"`, managed LLM `input` and `output`, managed `tool_output`, and request-default pass-through for backend context plus backend `input` and `output` rail selection. -In that setup, the NeMo Relay plugin chose the managed surfaces to wrap, while -the Guardrails config defined the actual blocking policy, such as rejecting -secret-seeking prompts, bypass attempts, specific blocked tokens, or -private-key-like output. - -For example, the Guardrails-side policy can look like this: - -```yaml -rails: - input: - flows: - - self check input - output: - flows: - - self check output -``` - -This Guardrails-side config defines the policy logic. The NeMo Relay plugin -config decides when those checks run. - -## Remote Mode Rules +### Rules When `mode = "remote"`: @@ -146,24 +127,24 @@ When `mode = "remote"`: ### Codec Boundary -The current built-in remote backend supports managed LLM execution only with: +The current built-in remote mode supports managed LLM execution only with: - `openai_chat` -## Managed Tool Boundary +### Managed Tool Boundary -The current remote backend supports managed `tool_output`. +The current remote mode supports managed `tool_output`. -The current remote backend rejects managed `tool_input` explicitly because the +The current remote mode rejects managed `tool_input` explicitly because the stock Guardrails remote contract does not activate pre-execution tool-call rails from externally submitted `/v1/chat/completions` history. NeMo Relay rejects `tool_input` in remote mode rather than leaving a silent non-enforcing path. -## Request Defaults +### Request Defaults `request_defaults` lets the built-in plugin pass request-time semantics through -to the selected backend. +to the selected remote backend. Supported request-default fields are: @@ -201,11 +182,12 @@ The `rails` section can include: - `tool_output` - `tool_input` -Those values are forwarded to the backend as request semantics. They do not -mean NeMo Relay owns separate managed retrieval or dialog execution surfaces. -`dialog` and `retrieval` are pass-through request options only. Likewise, -`request_defaults.rails.tool_input` is only a backend pass-through selector. It -does not make managed remote `tool_input` supported in the stock-remote lane. +Those values are forwarded to the remote backend as request semantics. They do +not mean NeMo Relay owns separate managed retrieval or dialog execution +surfaces. `dialog` and `retrieval` are pass-through request options only. +Likewise, `request_defaults.rails.tool_input` is only a backend pass-through +selector. It does not make managed remote `tool_input` supported in the +stock-remote lane. For more targeted request-time pass-through, the remote backend also forwards selectors like these: @@ -219,11 +201,7 @@ dialog = true tool_output = ["validate_tool_output"] ``` -This richer selector shape demonstrates how request-time Guardrails semantics -can be forwarded even when NeMo Relay does not own a separate native managed -surface for that category. - -## Observability +### Observability The current remote backend emits coarse backend-level marks for remote Guardrails activity: @@ -232,4 +210,111 @@ Guardrails activity: - `nemo_guardrails.remote.end` - `nemo_guardrails.remote.error` -These marks cover managed LLM remote execution and managed tool-result checks. +## Local Mode + +Use `local` mode when NeMo Relay should call `nemoguardrails` in process +through the Python runtime instead of a separate Guardrails service. + +### Requirements + +To use `mode = "local"`, the running Python environment must be able to import +`nemoguardrails`. + +The built-in local backend is installed by the Python binding and runs +Guardrails in process. Use it when the runtime has direct access to the Python +Guardrails dependency and configuration files rather than a separate Guardrails +service. + +The same ownership boundary still applies: + +- NeMo Relay decides when managed checks run. +- Guardrails-native config still decides what to block, allow, or rewrite. + +### `plugins.toml` Example + +```toml +version = 1 + +[[components]] +kind = "nemo_guardrails" +enabled = true + +[components.config] +version = 1 +mode = "local" +codec = "openai_chat" +input = true +output = true +tool_input = true +tool_output = true +config_path = "./rails" + +[components.config.policy] +unknown_component = "warn" +unknown_field = "warn" +unsupported_value = "error" +``` + +This example configures the built-in local mode for a Python-enabled runtime +that can import `nemoguardrails` and read a native Guardrails config directory +from `./rails`. + +For example, the Guardrails-side policy can look like this: + +```yaml +rails: + input: + flows: + - self check input + output: + flows: + - self check output +``` + +This Guardrails-side config defines the policy logic. The NeMo Relay plugin +config decides when those checks run. + +### Rules + +When `mode = "local"`: + +- Exactly one of `config_path` or `config_yaml` is required. +- `colang_content` can only be used with `config_yaml`. +- `remote` settings cannot be present. +- `request_defaults` is rejected. +- `local.python_module` is optional and only needed when the runtime should + import the Guardrails dependency from a custom Python module path instead of + the default `nemoguardrails` package. + +### Codec Boundary + +The current built-in local mode supports managed LLM execution with: + +- `openai_chat` +- `openai_responses` +- `anthropic_messages` + +### Managed Tool Boundary + +The current local mode supports both: + +- managed `tool_input` +- managed `tool_output` + +### Streaming Boundary + +The current local mode supports streaming LLM input checks before the stream +callback runs. + +When output rails are configured, the current local mode uses Guardrails-native +streaming output rails instead of buffering the full provider stream. That +requires `rails.output.streaming.enabled = true` in the Guardrails config. + +The current local mode supports the `stream_first = true` streaming semantics: +provider chunks can still flow to the caller while Guardrails evaluates the +stream in parallel. If Guardrails later blocks the stream, the call fails at +that point even though some chunks may already have been delivered. + +The current local mode does not support `rails.output.streaming.stream_first = false` +yet, because that would require converting guarded text chunks back into valid +provider-shaped stream chunks. diff --git a/python/nemo_relay/_guardrails_local.py b/python/nemo_relay/_guardrails_local.py new file mode 100644 index 00000000..5f30eb49 --- /dev/null +++ b/python/nemo_relay/_guardrails_local.py @@ -0,0 +1,589 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Internal helpers for the built-in NeMo Guardrails local backend.""" + +from __future__ import annotations + +import asyncio +import importlib +import json +from collections.abc import Callable +from typing import Any, Protocol, cast + +from nemo_relay import Json, LLMRequest +from nemo_relay.codecs import ( + AnthropicMessagesCodec, + LlmCodec, + LlmResponseCodec, + OpenAIChatCodec, + OpenAIResponsesCodec, +) +from nemo_relay.plugin import PluginContext + +_DEFAULT_PRIORITY = 100 + + +class NeMoGuardrailsDependencyError(RuntimeError): + """Raised when the optional ``nemoguardrails`` dependency is unavailable.""" + + +class NeMoGuardrailsViolation(RuntimeError): + """Raised when NeMo Guardrails blocks or cannot safely apply a rail result.""" + + def __init__( + self, + message: str, + *, + rail_type: str, + rail: str | None = None, + content: str | None = None, + ) -> None: + super().__init__(message) + self.rail_type = rail_type + self.rail = rail + self.content = content + + +class _GuardrailsCodec(LlmCodec, LlmResponseCodec, Protocol): + """Codec shape required by the local backend.""" + + +_CODECS: dict[str, Callable[[], _GuardrailsCodec]] = { + "openai_chat": OpenAIChatCodec, + "openai_responses": OpenAIResponsesCodec, + "anthropic_messages": AnthropicMessagesCodec, +} + + +def _load_nemoguardrails(module_name: str | None): + root_module = module_name or "nemoguardrails" + try: + guardrails = cast(Any, importlib.import_module(root_module)) + options = cast(Any, importlib.import_module(f"{root_module}.rails.llm.options")) + except ImportError as error: + if error.name == root_module: + raise NeMoGuardrailsDependencyError( + "NeMo Guardrails is required for the built-in NeMo Guardrails local backend. " + "Install it with: pip install nemoguardrails" + ) from error + raise NeMoGuardrailsDependencyError( + "NeMo Guardrails local backend could not import a required dependency: " + f"{error.name or error}. Install the full NeMo Guardrails runtime dependencies." + ) from error + + return ( + guardrails.RailsConfig, + guardrails.LLMRails, + options.RailType, + options.RailStatus, + ) + + +def _status_value(status: Any) -> str: + return str(getattr(status, "value", status)).lower() + + +def _messages_from_annotated(annotated: Any) -> list[dict[str, Any]]: + return [dict(message) for message in annotated.messages] + + +async def _apply_input_rails( + rails: Any, + rail_type: Any, + rail_status: Any, + codec: _GuardrailsCodec, + request: LLMRequest, +) -> tuple[LLMRequest, list[dict[str, Any]]]: + annotated_request = codec.decode(request) + messages = _messages_from_annotated(annotated_request) + input_result = await rails.check_async(messages, rail_types=[rail_type.INPUT]) + input_status = _status_value(input_result.status) + if input_status == _status_value(rail_status.BLOCKED): + _raise_blocked(input_result, "input") + if input_status == _status_value(rail_status.MODIFIED): + input_content = getattr(input_result, "content", "") + annotated_request.messages = _replace_last_role_content( + messages, + "user", + "" if input_content is None else str(input_content), + ) + request = codec.encode(annotated_request, request) + messages = _messages_from_annotated(annotated_request) + return request, messages + + +def _replace_last_role_content(messages: list[dict[str, Any]], role: str, content: str) -> list[dict[str, Any]]: + updated = [dict(message) for message in messages] + for index in range(len(updated) - 1, -1, -1): + if updated[index].get("role") == role: + updated[index]["content"] = content + return updated + raise NeMoGuardrailsViolation( + f"NeMo Guardrails returned modified {role} content but no {role} message was present.", + rail_type="input" if role == "user" else "output", + content=content, + ) + + +def _tool_input_content(name: str, args: Json) -> str: + return json.dumps( + { + "tool_name": name, + "arguments": args, + }, + sort_keys=True, + separators=(",", ":"), + ) + + +def _tool_output_content(name: str, args: Json, result: Json) -> str: + return json.dumps( + { + "tool_name": name, + "arguments": args, + "result": result, + }, + sort_keys=True, + separators=(",", ":"), + ) + + +def _modified_tool_payload(content: str, field: str) -> Json: + try: + value = json.loads(content) + except json.JSONDecodeError as error: + raise NeMoGuardrailsViolation( + f"NeMo Guardrails returned modified tool {field} content that is not valid JSON.", + rail_type=f"tool_{field}", + content=content, + ) from error + + if not isinstance(value, dict) or field not in value: + raise NeMoGuardrailsViolation( + f"NeMo Guardrails returned modified tool {field} content without a '{field}' field.", + rail_type=f"tool_{field}", + content=content, + ) + return cast(Json, value[field]) + + +def _raise_modified_output_not_supported(result: Any) -> None: + output_content = getattr(result, "content", "") + output_rail = getattr(result, "rail", None) + raise NeMoGuardrailsViolation( + "NeMo Guardrails output rail returned modified content, but the local backend " + "does not rewrite provider responses yet.", + rail_type="output", + rail=None if output_rail is None else str(output_rail), + content="" if output_content is None else str(output_content), + ) + + +async def _check_output_rails( + rails: Any, + rail_type: Any, + rail_status: Any, + messages: list[dict[str, Any]], + response_text: str | None, +) -> None: + if response_text is None: + return + + output_messages = [*messages, {"role": "assistant", "content": response_text}] + output_result = await rails.check_async(output_messages, rail_types=[rail_type.OUTPUT]) + output_status = _status_value(output_result.status) + if output_status == _status_value(rail_status.BLOCKED): + _raise_blocked(output_result, "output") + if output_status == _status_value(rail_status.MODIFIED): + _raise_modified_output_not_supported(output_result) + + +def _has_streaming_output_rails(rails: Any) -> bool: + return bool(getattr(rails.config.rails.output, "flows", [])) + + +def _output_streaming_config(rails: Any) -> Any | None: + return getattr(rails.config.rails.output, "streaming", None) + + +def _guardrails_streaming_enabled(rails: Any) -> bool: + streaming = _output_streaming_config(rails) + return bool(streaming is not None and getattr(streaming, "enabled", False)) + + +def _extract_stream_text(codec_name: str, chunk: Json) -> str | None: + if not isinstance(chunk, dict): + return None + + if codec_name == "openai_chat": + choices = chunk.get("choices") + if not isinstance(choices, list): + return None + parts: list[str] = [] + for choice in choices: + if not isinstance(choice, dict): + continue + delta = choice.get("delta") + if not isinstance(delta, dict): + continue + content = delta.get("content") + if isinstance(content, str) and content: + parts.append(content) + return "".join(parts) if parts else None + + if codec_name == "openai_responses": + if chunk.get("type") == "response.output_text.delta": + delta = chunk.get("delta") + return delta if isinstance(delta, str) and delta else None + return None + + if codec_name == "anthropic_messages": + if chunk.get("type") != "content_block_delta": + return None + delta = chunk.get("delta") + if not isinstance(delta, dict): + return None + if delta.get("type") != "text_delta": + return None + text = delta.get("text") + return text if isinstance(text, str) and text else None + + return None + + +def _guardrails_stream_error_message(chunk: str) -> str | None: + try: + payload = json.loads(chunk) + except json.JSONDecodeError: + return None + if not isinstance(payload, dict): + return None + error = payload.get("error") + if not isinstance(error, dict): + return None + if error.get("type") != "guardrails_violation": + return None + message = error.get("message") + return message if isinstance(message, str) and message else "Blocked by output rails." + + +async def _queue_string_stream(queue: "asyncio.Queue[str | None]"): + while True: + item = await queue.get() + if item is None: + return + yield item + + +async def _monitor_streaming_output_rails( + *, + rails: Any, + messages: list[dict[str, Any]], + text_queue: "asyncio.Queue[str | None]", + blocked: dict[str, str | None], +) -> None: + guarded_stream = rails.stream_async( + messages=messages, + generator=_queue_string_stream(text_queue), + include_metadata=False, + ) + async for chunk in guarded_stream: + if isinstance(chunk, str): + message = _guardrails_stream_error_message(chunk) + if message is not None: + blocked["message"] = message + return + + +def _raise_streaming_output_blocked(blocked_message: str) -> None: + raise NeMoGuardrailsViolation( + f"NeMo Guardrails output rail blocked the LLM call: {blocked_message}", + rail_type="output", + content=blocked_message, + ) + + +def _build_guardrails_config(config: dict[str, Any], rails_config_cls: Any) -> Any: + if config.get("config_path") is not None: + return rails_config_cls.from_path(cast(str, config["config_path"])) + return rails_config_cls.from_content( + colang_content=cast(str | None, config.get("colang_content")), + yaml_content=cast(str, config["config_yaml"]), + ) + + +def _resolve_codec(config: dict[str, Any]) -> tuple[str, _GuardrailsCodec]: + codec_name = cast(str | None, config.get("codec")) + if codec_name is None or codec_name not in _CODECS: + raise RuntimeError("local NeMo Guardrails backend requires a supported codec") + return codec_name, _CODECS[codec_name]() + + +async def _check_tool_input( + rails: Any, + rail_type: Any, + rail_status: Any, + tool_name: str, + args: Json, +) -> Json: + input_result = await rails.check_async( + [{"role": "user", "content": _tool_input_content(tool_name, args)}], + rail_types=[rail_type.INPUT], + ) + input_status = _status_value(input_result.status) + if input_status == _status_value(rail_status.BLOCKED): + _raise_blocked(input_result, "tool_input") + if input_status == _status_value(rail_status.MODIFIED): + input_content = getattr(input_result, "content", "") + return _modified_tool_payload( + "" if input_content is None else str(input_content), + "arguments", + ) + return args + + +async def _check_tool_output( + rails: Any, + rail_type: Any, + rail_status: Any, + tool_name: str, + args: Json, + result: Json, +) -> Json: + output_result = await rails.check_async( + [ + {"role": "user", "content": _tool_input_content(tool_name, args)}, + { + "role": "assistant", + "content": _tool_output_content(tool_name, args, result), + }, + ], + rail_types=[rail_type.OUTPUT], + ) + output_status = _status_value(output_result.status) + if output_status == _status_value(rail_status.BLOCKED): + _raise_blocked(output_result, "tool_output") + if output_status == _status_value(rail_status.MODIFIED): + output_content = getattr(output_result, "content", "") + return _modified_tool_payload( + "" if output_content is None else str(output_content), + "result", + ) + return result + + +def _make_llm_intercept( + *, + rails: Any, + rail_type: Any, + rail_status: Any, + codec: _GuardrailsCodec, + enable_input: bool, + enable_output: bool, +): + async def intercept(_name: str, request: LLMRequest, next_call): + current_request = request + messages = _messages_from_annotated(codec.decode(current_request)) + + if enable_input: + current_request, messages = await _apply_input_rails( + rails, + rail_type, + rail_status, + codec, + current_request, + ) + + response = await next_call(current_request) + if not enable_output: + return response + + annotated_response = codec.decode_response(response) + await _check_output_rails( + rails, + rail_type, + rail_status, + messages, + annotated_response.response_text(), + ) + return response + + return intercept + + +def _make_llm_stream_intercept( + *, + rails: Any, + rail_type: Any, + rail_status: Any, + codec_name: str, + codec: _GuardrailsCodec, + enable_input: bool, + enable_output: bool, +): + async def stream_intercept(request: LLMRequest, next_call): + current_request = request + messages = _messages_from_annotated(codec.decode(current_request)) + if enable_input: + current_request, messages = await _apply_input_rails( + rails, + rail_type, + rail_status, + codec, + current_request, + ) + + stream = await next_call(current_request) + if not enable_output: + return stream + if not _has_streaming_output_rails(rails): + return stream + if not _guardrails_streaming_enabled(rails): + raise RuntimeError( + "local NeMo Guardrails streaming output rails require " + "rails.output.streaming.enabled = true in the Guardrails config." + ) + + streaming_config = _output_streaming_config(rails) + if streaming_config is None or not getattr(streaming_config, "stream_first", True): + raise RuntimeError( + "local NeMo Guardrails streaming output rails currently require " + "rails.output.streaming.stream_first = true." + ) + + text_queue: asyncio.Queue[str | None] = asyncio.Queue() + blocked: dict[str, str | None] = {"message": None} + monitor = asyncio.create_task( + _monitor_streaming_output_rails( + rails=rails, + messages=messages, + text_queue=text_queue, + blocked=blocked, + ) + ) + + async def guarded_provider_stream(): + try: + async for chunk in stream: + if blocked["message"] is not None: + _raise_streaming_output_blocked(blocked["message"]) + + text = _extract_stream_text(codec_name, chunk) + if text is not None: + await text_queue.put(text) + + yield chunk + + if blocked["message"] is not None: + _raise_streaming_output_blocked(blocked["message"]) + finally: + await text_queue.put(None) + await monitor + if blocked["message"] is not None: + _raise_streaming_output_blocked(blocked["message"]) + + return guarded_provider_stream() + + return stream_intercept + + +def _make_tool_intercept( + *, + rails: Any, + rail_type: Any, + rail_status: Any, + enable_tool_input: bool, + enable_tool_output: bool, +): + async def tool_intercept(tool_name: str, args: Json, next_call): + current_args = args + + if enable_tool_input: + current_args = await _check_tool_input( + rails, + rail_type, + rail_status, + tool_name, + current_args, + ) + + tool_result = await next_call(current_args) + if not enable_tool_output: + return tool_result + + return await _check_tool_output( + rails, + rail_type, + rail_status, + tool_name, + current_args, + tool_result, + ) + + return tool_intercept + + +def _raise_blocked(result: Any, rail_type: str) -> None: + rail_value = getattr(result, "rail", None) + rail = None if rail_value is None else str(rail_value) + content = getattr(result, "content", "") + detail = f" by rail '{rail}'" if rail else "" + subject = "LLM call" if rail_type in {"input", "output"} else "tool call" + raise NeMoGuardrailsViolation( + f"NeMo Guardrails {rail_type} rail blocked the {subject}{detail}.", + rail_type=rail_type, + rail=rail, + content="" if content is None else str(content), + ) + + +def register_local_backend(config: dict[str, Any], context: PluginContext) -> None: + """Install the built-in NeMo Guardrails local backend.""" + + local = cast(dict[str, Any], config.get("local") or {}) + module_name = cast(str | None, local.get("python_module")) + RailsConfig, LLMRails, RailType, RailStatus = _load_nemoguardrails(module_name) + guardrails_config = _build_guardrails_config(config, RailsConfig) + rails = LLMRails(guardrails_config) + enable_input = bool(config.get("input", True)) + enable_output = bool(config.get("output", True)) + enable_tool_input = bool(config.get("tool_input", False)) + enable_tool_output = bool(config.get("tool_output", False)) + priority = int(config.get("priority", _DEFAULT_PRIORITY)) + + if enable_input or enable_output: + codec_name, codec = _resolve_codec(config) + intercept = _make_llm_intercept( + rails=rails, + rail_type=RailType, + rail_status=RailStatus, + codec=codec, + enable_input=enable_input, + enable_output=enable_output, + ) + stream_intercept = _make_llm_stream_intercept( + rails=rails, + rail_type=RailType, + rail_status=RailStatus, + codec_name=codec_name, + codec=codec, + enable_input=enable_input, + enable_output=enable_output, + ) + context.register_llm_execution_intercept("nemo_guardrails_local", priority, intercept) + context.register_llm_stream_execution_intercept( + "nemo_guardrails_local_stream", + priority, + stream_intercept, + ) + + if enable_tool_input or enable_tool_output: + tool_intercept = _make_tool_intercept( + rails=rails, + rail_type=RailType, + rail_status=RailStatus, + enable_tool_input=enable_tool_input, + enable_tool_output=enable_tool_output, + ) + context.register_tool_execution_intercept("nemo_guardrails_local", priority, tool_intercept) From ec49259df6bc42e2c70c344c4741c4c322693171 Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Mon, 1 Jun 2026 06:48:08 -0700 Subject: [PATCH 02/20] docs: refine local guardrails mode docs Signed-off-by: Alex Fournier --- docs/nemo-guardrails-plugin/about.mdx | 27 +++++++------ docs/nemo-guardrails-plugin/configuration.mdx | 39 +++++++++++++++---- 2 files changed, 45 insertions(+), 21 deletions(-) diff --git a/docs/nemo-guardrails-plugin/about.mdx b/docs/nemo-guardrails-plugin/about.mdx index 5c0cd2f0..d346fcb0 100644 --- a/docs/nemo-guardrails-plugin/about.mdx +++ b/docs/nemo-guardrails-plugin/about.mdx @@ -37,19 +37,18 @@ Start here when you need to: ## Current Scope -The built-in plugin currently exposes two user-facing modes with -different boundaries. - -| Area | `remote` | `local` | -|---|---|---| -| Managed non-streaming LLM `input` | Supported | Supported | -| Managed non-streaming LLM `output` | Supported | Supported | -| Managed streaming LLM execution | Supported over the remote HTTP(S) contract | Supported for managed input checks and Guardrails-native output streaming when `rails.output.streaming.enabled = true`; with `stream_first = true`, output rails can stop the stream after some chunks have already been delivered; `stream_first = false` is not supported yet | -| Managed `tool_input` | Not supported against the stock Guardrails remote contract | Supported | -| Managed `tool_output` | Supported | Supported | -| `request_defaults` | Supported as backend pass-through request semantics | Not supported | -| Codec support | `openai_chat` | `openai_chat`, `openai_responses`, `anthropic_messages` | -| Runtime availability | Any runtime that includes the remote backend | Python-enabled runtimes that can import `nemoguardrails` | +The built-in plugin currently exposes two user-facing modes: + +- `remote` for Guardrails-service integration over HTTP(S) +- `local` for in-process `nemoguardrails` integration through the Python runtime + +Both modes support managed LLM `input` and `output`. The current mode-specific +differences are: + +- `remote` supports `request_defaults` pass-through but does not support managed + `tool_input` +- `local` supports managed `tool_input` and broader LLM codec coverage, but it + does not support `request_defaults` The `local` backend is a Python-backed runtime feature, not a universal cross-binding backend. Runtimes that do not install the local backend provider @@ -119,5 +118,5 @@ separate managed middleware surfaces in NeMo Relay. ## Pages - [NeMo Guardrails Configuration](/nemo-guardrails-plugin/configuration) - documents the built-in component shape, mode boundaries, and current + documents the built-in component shape, mode boundaries, and the detailed support matrix. diff --git a/docs/nemo-guardrails-plugin/configuration.mdx b/docs/nemo-guardrails-plugin/configuration.mdx index b1554e1c..24cc12c6 100644 --- a/docs/nemo-guardrails-plugin/configuration.mdx +++ b/docs/nemo-guardrails-plugin/configuration.mdx @@ -61,7 +61,9 @@ At least one managed Guardrails surface must be enabled. ## Remote Mode Use `remote` mode when NeMo Relay should call a Guardrails service over -HTTP(S). +HTTP(S), especially when Guardrails must be shared across runtimes, used from +non-Python environments, or deployed independently from the application +process. ### Requirements @@ -76,6 +78,11 @@ to block, allow, or rewrite. ### `plugins.toml` Example +You can write this config directly in `plugins.toml`, or create and edit it +through the CLI with `nemo-relay plugins edit`. For plugin file discovery, +precedence, merge behavior, and editor controls, see +[Plugin Configuration Files](/build-plugins/plugin-configuration-files). + ```toml version = 1 @@ -232,6 +239,11 @@ The same ownership boundary still applies: ### `plugins.toml` Example +You can write this config directly in `plugins.toml`, or create and edit it +through the CLI with `nemo-relay plugins edit`. For plugin file discovery, +precedence, merge behavior, and editor controls, see +[Plugin Configuration Files](/build-plugins/plugin-configuration-files). + ```toml version = 1 @@ -310,11 +322,24 @@ When output rails are configured, the current local mode uses Guardrails-native streaming output rails instead of buffering the full provider stream. That requires `rails.output.streaming.enabled = true` in the Guardrails config. -The current local mode supports the `stream_first = true` streaming semantics: -provider chunks can still flow to the caller while Guardrails evaluates the -stream in parallel. If Guardrails later blocks the stream, the call fails at -that point even though some chunks may already have been delivered. +Guardrails calls the main streaming-output switch +`rails.output.streaming.stream_first`. + +When `stream_first = true`, the current local mode uses pass-through-first +streaming semantics: + +- provider chunks can flow to the caller immediately +- Guardrails evaluates the streamed text in parallel +- if Guardrails later blocks the stream, the call fails at that point even + though some chunks may already have been delivered The current local mode does not support `rails.output.streaming.stream_first = false` -yet, because that would require converting guarded text chunks back into valid -provider-shaped stream chunks. +yet. That mode would be Guardrails-first streaming semantics: + +- Guardrails would need to evaluate streamed text before chunks are released to + the caller +- the local backend would then need to convert Guardrails-approved text back + into valid provider-shaped stream chunks + +That guarded-text-to-provider-chunk adapter does not exist yet in the current +local backend. From 98d49155906b6ccab70710c854faa22e719660b0 Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Mon, 1 Jun 2026 07:03:26 -0700 Subject: [PATCH 03/20] test: factor local guardrails coverage fixtures Signed-off-by: Alex Fournier --- .../python/tests/coverage/coverage_tests.rs | 256 ++++++++---------- 1 file changed, 112 insertions(+), 144 deletions(-) diff --git a/crates/python/tests/coverage/coverage_tests.rs b/crates/python/tests/coverage/coverage_tests.rs index 3e553341..90b792f8 100644 --- a/crates/python/tests/coverage/coverage_tests.rs +++ b/crates/python/tests/coverage/coverage_tests.rs @@ -41,6 +41,80 @@ fn load_module<'py>(py: Python<'py>, code: &str) -> Bound<'py, PyModule> { PyModule::from_code(py, &code, &file_name, &module_name).unwrap() } +fn python_package_dir() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../python") +} + +fn fake_guardrails_module_prelude(module_name: &str, python_dir: &str) -> String { + format!( + r#" +import sys +import types + +sys.path.insert(0, {python_dir:?}) + +MODULE_NAME = {module_name:?} + +fake_root = types.ModuleType(MODULE_NAME) +fake_options = types.ModuleType(MODULE_NAME + ".rails.llm.options") + +class Result: + def __init__(self, status, content=None, rail=None): + self.status = status + self.content = content + self.rail = rail + +class RailType: + INPUT = "input" + OUTPUT = "output" + +class RailStatus: + BLOCKED = "blocked" + MODIFIED = "modified" + PASSED = "passed" + +class RailsConfig: + @staticmethod + def from_content(*, colang_content=None, yaml_content=None): + return {{"yaml": yaml_content, "colang": colang_content}} + + @staticmethod + def from_path(path): + return {{"path": path}} +"#, + python_dir = python_dir, + module_name = module_name, + ) +} + +fn register_fake_guardrails_module_epilogue() -> &'static str { + r#" +fake_root.RailsConfig = RailsConfig +fake_root.LLMRails = LLMRails +fake_options.RailType = RailType +fake_options.RailStatus = RailStatus + +sys.modules[MODULE_NAME] = fake_root +sys.modules[MODULE_NAME + ".rails"] = types.ModuleType(MODULE_NAME + ".rails") +sys.modules[MODULE_NAME + ".rails.llm"] = types.ModuleType(MODULE_NAME + ".rails.llm") +sys.modules[MODULE_NAME + ".rails.llm.options"] = fake_options +"# +} + +fn local_plugin_context_python() -> &'static str { + r#" +class Context: + def register_llm_execution_intercept(self, name, priority, callback): + self.llm = callback + + def register_llm_stream_execution_intercept(self, name, priority, callback): + self.stream = callback + + def register_tool_execution_intercept(self, name, priority, callback): + self.tool = callback +"# +} + fn make_request() -> LlmRequest { LlmRequest { headers: serde_json::Map::from_iter([("x-trace".into(), json!("1"))]), @@ -153,45 +227,24 @@ fn test_native_pymodule_entrypoint_installs_nemo_guardrails_local_provider() { fn test_guardrails_local_helper_registers_and_enforces_llm_and_tool_checks() { let _python = crate::test_support::init_python_test(); Python::attach(|py| { - let python_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../python"); + let native_module = PyModule::new(py, "_native_guardrails_helper").unwrap(); + crate::_native(&native_module).unwrap(); + let sys = py.import("sys").unwrap(); + let modules = sys.getattr("modules").unwrap(); + modules + .set_item("nemo_relay._native", native_module.clone()) + .unwrap(); + + let python_dir = python_package_dir(); + let prelude = + fake_guardrails_module_prelude("fake_guardrails_local_helper", &python_dir.display().to_string()); + let epilogue = register_fake_guardrails_module_epilogue(); + let context_class = local_plugin_context_python(); let module = load_module( py, &format!( r#" -import pathlib -import sys -import types - -sys.path.insert(0, {python_dir:?}) - -MODULE_NAME = "fake_guardrails_local_helper" - -fake_root = types.ModuleType(MODULE_NAME) -fake_options = types.ModuleType(MODULE_NAME + ".rails.llm.options") - -class Result: - def __init__(self, status, content=None, rail=None): - self.status = status - self.content = content - self.rail = rail - -class RailType: - INPUT = "input" - OUTPUT = "output" - -class RailStatus: - BLOCKED = "blocked" - MODIFIED = "modified" - PASSED = "passed" - -class RailsConfig: - @staticmethod - def from_content(*, colang_content=None, yaml_content=None): - return {{"yaml": yaml_content, "colang": colang_content}} - - @staticmethod - def from_path(path): - return {{"path": path}} +{prelude} check_results = [] check_calls = [] @@ -204,32 +257,15 @@ class LLMRails: check_calls.append((messages, rail_types)) return check_results.pop(0) -fake_root.RailsConfig = RailsConfig -fake_root.LLMRails = LLMRails -fake_options.RailType = RailType -fake_options.RailStatus = RailStatus - -sys.modules[MODULE_NAME] = fake_root -sys.modules[MODULE_NAME + ".rails"] = types.ModuleType(MODULE_NAME + ".rails") -sys.modules[MODULE_NAME + ".rails.llm"] = types.ModuleType(MODULE_NAME + ".rails.llm") -sys.modules[MODULE_NAME + ".rails.llm.options"] = fake_options +{epilogue} from nemo_relay._native import LLMRequest from nemo_relay._guardrails_local import register_local_backend -class Context: - def register_llm_execution_intercept(self, name, priority, callback): - self.llm = callback - - def register_llm_stream_execution_intercept(self, name, priority, callback): - self.stream = callback - - def register_tool_execution_intercept(self, name, priority, callback): - self.tool = callback +{context_class} async def run_case(): ctx = Context() - event_log = [] register_local_backend( {{ "mode": "local", @@ -291,7 +327,9 @@ async def run_case(): "check_calls": check_calls, }} "#, - python_dir = python_dir.display(), + prelude = prelude, + epilogue = epilogue, + context_class = context_class, ), ); @@ -332,40 +370,16 @@ fn test_guardrails_local_helper_enforces_streamed_output_rails() { .set_item("nemo_relay._native", native_module.clone()) .unwrap(); - let python_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../python"); + let python_dir = python_package_dir(); + let prelude = + fake_guardrails_module_prelude("fake_guardrails_streaming", &python_dir.display().to_string()); + let epilogue = register_fake_guardrails_module_epilogue(); + let context_class = local_plugin_context_python(); let module = load_module( py, &format!( r#" -import sys -import types - -sys.path.insert(0, {python_dir:?}) - -MODULE_NAME = "fake_guardrails_streaming" - -fake_root = types.ModuleType(MODULE_NAME) -fake_options = types.ModuleType(MODULE_NAME + ".rails.llm.options") - -class Result: - def __init__(self, status, content=None, rail=None): - self.status = status - self.content = content - self.rail = rail - -class RailType: - INPUT = "input" - OUTPUT = "output" - -class RailStatus: - BLOCKED = "blocked" - MODIFIED = "modified" - PASSED = "passed" - -class RailsConfig: - @staticmethod - def from_content(*, colang_content=None, yaml_content=None): - return {{"yaml": yaml_content}} +{prelude} stream_results = [] event_log = [] @@ -395,28 +409,12 @@ class LLMRails: yield '{{"error": {{"message": "Blocked by output rails: output-policy", "type": "guardrails_violation"}}}}' return _run() -fake_root.RailsConfig = RailsConfig -fake_root.LLMRails = LLMRails -fake_options.RailType = RailType -fake_options.RailStatus = RailStatus - -sys.modules[MODULE_NAME] = fake_root -sys.modules[MODULE_NAME + ".rails"] = types.ModuleType(MODULE_NAME + ".rails") -sys.modules[MODULE_NAME + ".rails.llm"] = types.ModuleType(MODULE_NAME + ".rails.llm") -sys.modules[MODULE_NAME + ".rails.llm.options"] = fake_options +{epilogue} from nemo_relay._native import LLMRequest from nemo_relay._guardrails_local import register_local_backend -class Context: - def register_llm_execution_intercept(self, name, priority, callback): - self.llm = callback - - def register_llm_stream_execution_intercept(self, name, priority, callback): - self.stream = callback - - def register_tool_execution_intercept(self, name, priority, callback): - self.tool = callback +{context_class} async def run_case(): ctx = Context() @@ -506,7 +504,9 @@ async def run_case(): "modified": modified, }} "#, - python_dir = python_dir.display(), + prelude = prelude, + epilogue = epilogue, + context_class = context_class, ), ); @@ -581,40 +581,15 @@ fn test_local_guardrails_provider_initializes_and_enforces_managed_core_calls() .set_item("nemo_relay._native", native_module.clone()) .unwrap(); - let python_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../python"); + let python_dir = python_package_dir(); + let prelude = + fake_guardrails_module_prelude("fake_guardrails_local_e2e", &python_dir.display().to_string()); + let epilogue = register_fake_guardrails_module_epilogue(); let module = load_module( py, &format!( r#" -import sys -import types - -sys.path.insert(0, {python_dir:?}) - -MODULE_NAME = "fake_guardrails_local_e2e" - -fake_root = types.ModuleType(MODULE_NAME) -fake_options = types.ModuleType(MODULE_NAME + ".rails.llm.options") - -class Result: - def __init__(self, status, content=None, rail=None): - self.status = status - self.content = content - self.rail = rail - -class RailType: - INPUT = "input" - OUTPUT = "output" - -class RailStatus: - BLOCKED = "blocked" - MODIFIED = "modified" - PASSED = "passed" - -class RailsConfig: - @staticmethod - def from_content(*, colang_content=None, yaml_content=None): - return {{"yaml": yaml_content}} +{prelude} check_results = [] @@ -625,15 +600,7 @@ class LLMRails: async def check_async(self, messages, rail_types): return check_results.pop(0) -fake_root.RailsConfig = RailsConfig -fake_root.LLMRails = LLMRails -fake_options.RailType = RailType -fake_options.RailStatus = RailStatus - -sys.modules[MODULE_NAME] = fake_root -sys.modules[MODULE_NAME + ".rails"] = types.ModuleType(MODULE_NAME + ".rails") -sys.modules[MODULE_NAME + ".rails.llm"] = types.ModuleType(MODULE_NAME + ".rails.llm") -sys.modules[MODULE_NAME + ".rails.llm.options"] = fake_options +{epilogue} import nemo_relay @@ -709,7 +676,8 @@ async def run_case(): "seen_tool_args": seen_tool_args, }} "#, - python_dir = python_dir.display(), + prelude = prelude, + epilogue = epilogue, ), ); let result_json = with_event_loop(py, |event_loop| { From 244f29f022929b395fe64ace2e4b0dc428ccea5d Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Mon, 1 Jun 2026 09:16:28 -0700 Subject: [PATCH 04/20] style: apply rustfmt for local guardrails tests Signed-off-by: Alex Fournier --- .../core/src/plugins/nemo_guardrails/local.rs | 3 ++- crates/python/tests/coverage/coverage_tests.rs | 18 ++++++++++++------ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/crates/core/src/plugins/nemo_guardrails/local.rs b/crates/core/src/plugins/nemo_guardrails/local.rs index 31f4e1c8..240ed186 100644 --- a/crates/core/src/plugins/nemo_guardrails/local.rs +++ b/crates/core/src/plugins/nemo_guardrails/local.rs @@ -14,7 +14,8 @@ type LocalBackendProvider = Arc< static LOCAL_BACKEND_PROVIDER: LazyLock>> = LazyLock::new(|| Mutex::new(None)); -fn local_backend_provider_guard() -> PluginResult>> { +fn local_backend_provider_guard() -> PluginResult>> +{ LOCAL_BACKEND_PROVIDER.lock().map_err(|e| { PluginError::Internal(format!( "NeMo Guardrails local backend provider lock poisoned: {e}" diff --git a/crates/python/tests/coverage/coverage_tests.rs b/crates/python/tests/coverage/coverage_tests.rs index 90b792f8..029eee58 100644 --- a/crates/python/tests/coverage/coverage_tests.rs +++ b/crates/python/tests/coverage/coverage_tests.rs @@ -236,8 +236,10 @@ fn test_guardrails_local_helper_registers_and_enforces_llm_and_tool_checks() { .unwrap(); let python_dir = python_package_dir(); - let prelude = - fake_guardrails_module_prelude("fake_guardrails_local_helper", &python_dir.display().to_string()); + let prelude = fake_guardrails_module_prelude( + "fake_guardrails_local_helper", + &python_dir.display().to_string(), + ); let epilogue = register_fake_guardrails_module_epilogue(); let context_class = local_plugin_context_python(); let module = load_module( @@ -371,8 +373,10 @@ fn test_guardrails_local_helper_enforces_streamed_output_rails() { .unwrap(); let python_dir = python_package_dir(); - let prelude = - fake_guardrails_module_prelude("fake_guardrails_streaming", &python_dir.display().to_string()); + let prelude = fake_guardrails_module_prelude( + "fake_guardrails_streaming", + &python_dir.display().to_string(), + ); let epilogue = register_fake_guardrails_module_epilogue(); let context_class = local_plugin_context_python(); let module = load_module( @@ -582,8 +586,10 @@ fn test_local_guardrails_provider_initializes_and_enforces_managed_core_calls() .unwrap(); let python_dir = python_package_dir(); - let prelude = - fake_guardrails_module_prelude("fake_guardrails_local_e2e", &python_dir.display().to_string()); + let prelude = fake_guardrails_module_prelude( + "fake_guardrails_local_e2e", + &python_dir.display().to_string(), + ); let epilogue = register_fake_guardrails_module_epilogue(); let module = load_module( py, From ffa88dcee41143e3bfb4131162155a57ef922876 Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Mon, 1 Jun 2026 09:24:48 -0700 Subject: [PATCH 05/20] refactor: name local guardrails imports Signed-off-by: Alex Fournier --- python/nemo_relay/_guardrails_local.py | 41 ++++++++++++++++---------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/python/nemo_relay/_guardrails_local.py b/python/nemo_relay/_guardrails_local.py index 5f30eb49..86f5946c 100644 --- a/python/nemo_relay/_guardrails_local.py +++ b/python/nemo_relay/_guardrails_local.py @@ -9,7 +9,7 @@ import importlib import json from collections.abc import Callable -from typing import Any, Protocol, cast +from typing import Any, NamedTuple, Protocol, cast from nemo_relay import Json, LLMRequest from nemo_relay.codecs import ( @@ -49,6 +49,15 @@ class _GuardrailsCodec(LlmCodec, LlmResponseCodec, Protocol): """Codec shape required by the local backend.""" +class _GuardrailsRuntimeImports(NamedTuple): + """Resolved Python symbols required by the local Guardrails backend.""" + + rails_config_cls: Any + llm_rails_cls: Any + rail_type: Any + rail_status: Any + + _CODECS: dict[str, Callable[[], _GuardrailsCodec]] = { "openai_chat": OpenAIChatCodec, "openai_responses": OpenAIResponsesCodec, @@ -56,7 +65,7 @@ class _GuardrailsCodec(LlmCodec, LlmResponseCodec, Protocol): } -def _load_nemoguardrails(module_name: str | None): +def _load_nemoguardrails(module_name: str | None) -> _GuardrailsRuntimeImports: root_module = module_name or "nemoguardrails" try: guardrails = cast(Any, importlib.import_module(root_module)) @@ -72,11 +81,11 @@ def _load_nemoguardrails(module_name: str | None): f"{error.name or error}. Install the full NeMo Guardrails runtime dependencies." ) from error - return ( - guardrails.RailsConfig, - guardrails.LLMRails, - options.RailType, - options.RailStatus, + return _GuardrailsRuntimeImports( + rails_config_cls=guardrails.RailsConfig, + llm_rails_cls=guardrails.LLMRails, + rail_type=options.RailType, + rail_status=options.RailStatus, ) @@ -543,9 +552,9 @@ def register_local_backend(config: dict[str, Any], context: PluginContext) -> No local = cast(dict[str, Any], config.get("local") or {}) module_name = cast(str | None, local.get("python_module")) - RailsConfig, LLMRails, RailType, RailStatus = _load_nemoguardrails(module_name) - guardrails_config = _build_guardrails_config(config, RailsConfig) - rails = LLMRails(guardrails_config) + runtime_imports = _load_nemoguardrails(module_name) + guardrails_config = _build_guardrails_config(config, runtime_imports.rails_config_cls) + rails = runtime_imports.llm_rails_cls(guardrails_config) enable_input = bool(config.get("input", True)) enable_output = bool(config.get("output", True)) enable_tool_input = bool(config.get("tool_input", False)) @@ -556,16 +565,16 @@ def register_local_backend(config: dict[str, Any], context: PluginContext) -> No codec_name, codec = _resolve_codec(config) intercept = _make_llm_intercept( rails=rails, - rail_type=RailType, - rail_status=RailStatus, + rail_type=runtime_imports.rail_type, + rail_status=runtime_imports.rail_status, codec=codec, enable_input=enable_input, enable_output=enable_output, ) stream_intercept = _make_llm_stream_intercept( rails=rails, - rail_type=RailType, - rail_status=RailStatus, + rail_type=runtime_imports.rail_type, + rail_status=runtime_imports.rail_status, codec_name=codec_name, codec=codec, enable_input=enable_input, @@ -581,8 +590,8 @@ def register_local_backend(config: dict[str, Any], context: PluginContext) -> No if enable_tool_input or enable_tool_output: tool_intercept = _make_tool_intercept( rails=rails, - rail_type=RailType, - rail_status=RailStatus, + rail_type=runtime_imports.rail_type, + rail_status=runtime_imports.rail_status, enable_tool_input=enable_tool_input, enable_tool_output=enable_tool_output, ) From f8dead5c41a81bfc58b6c895e396181674b2833a Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Mon, 1 Jun 2026 09:36:47 -0700 Subject: [PATCH 06/20] fix: address local guardrails review nits Signed-off-by: Alex Fournier --- crates/python/src/lib.rs | 21 + crates/python/src/py_plugin.rs | 18 +- .../python/tests/coverage/coverage_tests.rs | 370 ++++++++++-------- .../coverage/py_plugin_coverage_tests.rs | 45 +++ python/nemo_relay/_guardrails_local.py | 30 +- 5 files changed, 307 insertions(+), 177 deletions(-) diff --git a/crates/python/src/lib.rs b/crates/python/src/lib.rs index 13d0c29f..4a40eaf8 100644 --- a/crates/python/src/lib.rs +++ b/crates/python/src/lib.rs @@ -112,6 +112,10 @@ fn load_guardrails_local_register_fn(py: Python<'_>) -> PyResult module, Err(err) => { + if !is_missing_guardrails_local_module(py, &err)? { + return Err(err); + } + let source_python_dir = guardrails_local_source_python_dir(); if !source_python_dir.exists() { return Err(err); @@ -124,6 +128,23 @@ fn load_guardrails_local_register_fn(py: Python<'_>) -> PyResult, err: &PyErr) -> PyResult { + if !err.is_instance_of::(py) { + return Ok(false); + } + + let err_value = err.value(py); + let module_name = err_value + .getattr("name") + .ok() + .and_then(|name| name.extract::().ok()); + + Ok(matches!( + module_name.as_deref(), + Some("nemo_relay") | Some("nemo_relay._guardrails_local") + )) +} + fn guardrails_local_source_python_dir() -> PathBuf { PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../python") } diff --git a/crates/python/src/py_plugin.rs b/crates/python/src/py_plugin.rs index ee668ea1..09209b03 100644 --- a/crates/python/src/py_plugin.rs +++ b/crates/python/src/py_plugin.rs @@ -29,7 +29,8 @@ use nemo_relay::api::subscriber::{deregister_subscriber, register_subscriber}; use nemo_relay::plugin::{ ConfigDiagnostic, DiagnosticLevel, Plugin, PluginConfig, PluginError, PluginRegistration, PluginRegistrationContext, active_plugin_report, clear_plugin_configuration, deregister_plugin, - initialize_plugins, list_plugin_kinds, register_plugin, validate_plugin_config, + initialize_plugins, list_plugin_kinds, register_plugin, rollback_registrations, + validate_plugin_config, }; use crate::convert::{json_to_py, py_to_json}; @@ -174,10 +175,17 @@ pub(crate) fn invoke_python_plugin_register( namespace_prefix, )?; let plugin_config_py = plugin_config_to_py(py, plugin_kind, plugin_config)?; - register_fn.call1((plugin_config_py, py_ctx.clone_ref(py)))?; - { - let py_ctx_ref = py_ctx.bind(py).borrow(); - py_ctx_ref.drain_registrations() + match register_fn.call1((plugin_config_py, py_ctx.clone_ref(py))) { + Ok(_) => { + let py_ctx_ref = py_ctx.bind(py).borrow(); + py_ctx_ref.drain_registrations() + } + Err(err) => { + if let Ok(mut registrations) = py_ctx.bind(py).borrow().drain_registrations() { + rollback_registrations(&mut registrations); + } + Err(err) + } } } diff --git a/crates/python/tests/coverage/coverage_tests.rs b/crates/python/tests/coverage/coverage_tests.rs index 029eee58..a104d68c 100644 --- a/crates/python/tests/coverage/coverage_tests.rs +++ b/crates/python/tests/coverage/coverage_tests.rs @@ -4,13 +4,13 @@ //! Coverage tests for coverage in the NeMo Relay Python crate. use std::ffi::CString; +use std::panic::{AssertUnwindSafe, catch_unwind}; use std::path::PathBuf; use std::pin::Pin; use std::sync::Arc; -use pyo3::ffi::c_str; use pyo3::prelude::*; -use pyo3::types::{IntoPyDict, PyModule}; +use pyo3::types::{PyDict, PyModule}; use serde_json::{Value as Json, json}; use tokio_stream::Stream; use tokio_stream::StreamExt; @@ -115,6 +115,59 @@ class Context: "# } +fn with_isolated_nemo_relay_modules( + py: Python<'_>, + native_module: &Bound<'_, PyModule>, + f: impl FnOnce() -> T, +) -> T { + let sys = py.import("sys").unwrap(); + let modules = sys + .getattr("modules") + .unwrap() + .cast_into::() + .unwrap(); + let saved_modules = modules + .iter() + .filter_map(|(name, module)| { + let name = name.extract::().ok()?; + if name == "nemo_relay" || name.starts_with("nemo_relay.") { + Some((name, module.unbind())) + } else { + None + } + }) + .collect::>(); + + clear_nemo_relay_modules(&modules); + modules + .set_item("nemo_relay._native", native_module.clone()) + .unwrap(); + + let result = catch_unwind(AssertUnwindSafe(f)); + + clear_nemo_relay_modules(&modules); + for (name, module) in saved_modules { + modules.set_item(name, module).unwrap(); + } + + match result { + Ok(value) => value, + Err(payload) => std::panic::resume_unwind(payload), + } +} + +fn clear_nemo_relay_modules(modules: &Bound<'_, PyDict>) { + let module_names = modules + .iter() + .filter_map(|(name, _)| name.extract::().ok()) + .filter(|name| name == "nemo_relay" || name.starts_with("nemo_relay.")) + .collect::>(); + + for name in module_names { + modules.del_item(name).unwrap(); + } +} + fn make_request() -> LlmRequest { LlmRequest { headers: serde_json::Map::from_iter([("x-trace".into(), json!("1"))]), @@ -229,23 +282,19 @@ fn test_guardrails_local_helper_registers_and_enforces_llm_and_tool_checks() { Python::attach(|py| { let native_module = PyModule::new(py, "_native_guardrails_helper").unwrap(); crate::_native(&native_module).unwrap(); - let sys = py.import("sys").unwrap(); - let modules = sys.getattr("modules").unwrap(); - modules - .set_item("nemo_relay._native", native_module.clone()) - .unwrap(); - let python_dir = python_package_dir(); - let prelude = fake_guardrails_module_prelude( - "fake_guardrails_local_helper", - &python_dir.display().to_string(), - ); - let epilogue = register_fake_guardrails_module_epilogue(); - let context_class = local_plugin_context_python(); - let module = load_module( - py, - &format!( - r#" + with_isolated_nemo_relay_modules(py, &native_module, || { + let python_dir = python_package_dir(); + let prelude = fake_guardrails_module_prelude( + "fake_guardrails_local_helper", + &python_dir.display().to_string(), + ); + let epilogue = register_fake_guardrails_module_epilogue(); + let context_class = local_plugin_context_python(); + let module = load_module( + py, + &format!( + r#" {prelude} check_results = [] @@ -329,34 +378,61 @@ async def run_case(): "check_calls": check_calls, }} "#, - prelude = prelude, - epilogue = epilogue, - context_class = context_class, - ), - ); + prelude = prelude, + epilogue = epilogue, + context_class = context_class, + ), + ); - let result_json = with_event_loop(py, |event_loop| { - let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); - let result = event_loop - .call_method1("run_until_complete", (coroutine,)) - .unwrap(); - crate::convert::py_to_json(&result).unwrap() - }); + let result_json = with_event_loop(py, |event_loop| { + let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); + let result = event_loop + .call_method1("run_until_complete", (coroutine,)) + .unwrap(); + crate::convert::py_to_json(&result).unwrap() + }); - assert_eq!( - result_json["seen_request_messages"][0], - json!("sanitized user") - ); - assert_eq!(result_json["tool_result"], json!({ "ok": true })); - assert_eq!( - result_json["seen_tool_args"][0], - json!({ "city": "Boston" }) - ); - assert_eq!( - result_json["llm_result"]["choices"][0]["message"]["content"], - json!("safe reply") - ); - assert_eq!(result_json["check_calls"].as_array().unwrap().len(), 4); + assert_eq!( + result_json["seen_request_messages"][0], + json!("sanitized user") + ); + assert_eq!(result_json["tool_result"], json!({ "ok": true })); + assert_eq!( + result_json["seen_tool_args"][0], + json!({ "city": "Boston" }) + ); + assert_eq!( + result_json["llm_result"]["choices"][0]["message"]["content"], + json!("safe reply") + ); + assert_eq!( + result_json["check_calls"], + json!([ + [ + [{"role": "user", "content": "unsafe"}], + ["input"] + ], + [ + [ + {"role": "user", "content": "sanitized user"}, + {"role": "assistant", "content": "safe reply"} + ], + ["output"] + ], + [ + [{"role": "user", "content": "{\"arguments\":{\"city\":\"Phoenix\"},\"tool_name\":\"weather_lookup\"}"}], + ["input"] + ], + [ + [ + {"role": "user", "content": "{\"arguments\":{\"city\":\"Boston\"},\"tool_name\":\"weather_lookup\"}"}, + {"role": "assistant", "content": "{\"arguments\":{\"city\":\"Boston\"},\"result\":{\"raw\":true},\"tool_name\":\"weather_lookup\"}"} + ], + ["output"] + ] + ]) + ); + }); }); } @@ -366,23 +442,19 @@ fn test_guardrails_local_helper_enforces_streamed_output_rails() { Python::attach(|py| { let native_module = PyModule::new(py, "_native_guardrails_streaming").unwrap(); crate::_native(&native_module).unwrap(); - let sys = py.import("sys").unwrap(); - let modules = sys.getattr("modules").unwrap(); - modules - .set_item("nemo_relay._native", native_module.clone()) - .unwrap(); - let python_dir = python_package_dir(); - let prelude = fake_guardrails_module_prelude( - "fake_guardrails_streaming", - &python_dir.display().to_string(), - ); - let epilogue = register_fake_guardrails_module_epilogue(); - let context_class = local_plugin_context_python(); - let module = load_module( - py, - &format!( - r#" + with_isolated_nemo_relay_modules(py, &native_module, || { + let python_dir = python_package_dir(); + let prelude = fake_guardrails_module_prelude( + "fake_guardrails_streaming", + &python_dir.display().to_string(), + ); + let epilogue = register_fake_guardrails_module_epilogue(); + let context_class = local_plugin_context_python(); + let module = load_module( + py, + &format!( + r#" {prelude} stream_results = [] @@ -508,52 +580,53 @@ async def run_case(): "modified": modified, }} "#, - prelude = prelude, - epilogue = epilogue, - context_class = context_class, - ), - ); + prelude = prelude, + epilogue = epilogue, + context_class = context_class, + ), + ); - let result = with_event_loop(py, |event_loop| { - let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); - let result = event_loop - .call_method1("run_until_complete", (coroutine,)) - .unwrap(); - crate::convert::py_to_json(&result).unwrap() - }); - assert_eq!( - result["allowed_chunks"], - json!([ - {"choices": [{"delta": {"content": "hello"}}]}, - {"choices": [{"delta": {"content": "world"}}]} - ]) - ); - let event_log = result["event_log"].as_array().unwrap(); - assert_eq!( - &event_log[..6], - json!([ - "source:hello", - "yield:hello", - "source:world", - "yield:world", - "guardrails-sees:hello", - "guardrails-sees:world", - ]) - .as_array() - .unwrap() - ); - assert!( - result["blocked"] - .as_str() - .unwrap() - .contains("output rail blocked the LLM call") - ); - assert!( - result["modified"] - .as_str() + let result = with_event_loop(py, |event_loop| { + let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); + let result = event_loop + .call_method1("run_until_complete", (coroutine,)) + .unwrap(); + crate::convert::py_to_json(&result).unwrap() + }); + assert_eq!( + result["allowed_chunks"], + json!([ + {"choices": [{"delta": {"content": "hello"}}]}, + {"choices": [{"delta": {"content": "world"}}]} + ]) + ); + let event_log = result["event_log"].as_array().unwrap(); + assert_eq!( + &event_log[..6], + json!([ + "source:hello", + "yield:hello", + "source:world", + "yield:world", + "guardrails-sees:hello", + "guardrails-sees:world", + ]) + .as_array() .unwrap() - .contains("stream_first = true") - ); + ); + assert!( + result["blocked"] + .as_str() + .unwrap() + .contains("output rail blocked the LLM call") + ); + assert!( + result["modified"] + .as_str() + .unwrap() + .contains("stream_first = true") + ); + }); }); } @@ -565,36 +638,18 @@ fn test_local_guardrails_provider_initializes_and_enforces_managed_core_calls() Python::attach(|py| { let native_module = PyModule::new(py, "_native_guardrails_e2e").unwrap(); crate::_native(&native_module).unwrap(); - let sys = py.import("sys").unwrap(); - let modules = sys.getattr("modules").unwrap(); - let module_names = py - .eval( - c_str!("list(sys.modules.keys())"), - None, - Some(&[(c_str!("sys"), sys)].into_py_dict(py).unwrap()), - ) - .unwrap() - .extract::>() - .unwrap(); - for name in module_names { - if name == "nemo_relay" || name.starts_with("nemo_relay.") { - modules.del_item(name).unwrap(); - } - } - modules - .set_item("nemo_relay._native", native_module.clone()) - .unwrap(); - let python_dir = python_package_dir(); - let prelude = fake_guardrails_module_prelude( - "fake_guardrails_local_e2e", - &python_dir.display().to_string(), - ); - let epilogue = register_fake_guardrails_module_epilogue(); - let module = load_module( - py, - &format!( - r#" + with_isolated_nemo_relay_modules(py, &native_module, || { + let python_dir = python_package_dir(); + let prelude = fake_guardrails_module_prelude( + "fake_guardrails_local_e2e", + &python_dir.display().to_string(), + ); + let epilogue = register_fake_guardrails_module_epilogue(); + let module = load_module( + py, + &format!( + r#" {prelude} check_results = [] @@ -682,31 +737,32 @@ async def run_case(): "seen_tool_args": seen_tool_args, }} "#, - prelude = prelude, - epilogue = epilogue, - ), - ); - let result_json = with_event_loop(py, |event_loop| { - let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); - let result = event_loop - .call_method1("run_until_complete", (coroutine,)) - .unwrap(); - crate::convert::py_to_json(&result).unwrap() - }); + prelude = prelude, + epilogue = epilogue, + ), + ); + let result_json = with_event_loop(py, |event_loop| { + let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); + let result = event_loop + .call_method1("run_until_complete", (coroutine,)) + .unwrap(); + crate::convert::py_to_json(&result).unwrap() + }); - assert_eq!( - result_json["llm_result"]["choices"][0]["message"]["content"], - json!("safe reply") - ); - assert_eq!(result_json["tool_result"], json!({ "ok": true })); - assert_eq!( - result_json["seen_request_messages"][0], - json!("sanitized user") - ); - assert_eq!( - result_json["seen_tool_args"][0], - json!({ "city": "Boston" }) - ); + assert_eq!( + result_json["llm_result"]["choices"][0]["message"]["content"], + json!("safe reply") + ); + assert_eq!(result_json["tool_result"], json!({ "ok": true })); + assert_eq!( + result_json["seen_request_messages"][0], + json!("sanitized user") + ); + assert_eq!( + result_json["seen_tool_args"][0], + json!({ "city": "Boston" }) + ); + }); }); reset_runtime_state(); diff --git a/crates/python/tests/coverage/py_plugin_coverage_tests.rs b/crates/python/tests/coverage/py_plugin_coverage_tests.rs index f774d5ea..dbb8a21b 100644 --- a/crates/python/tests/coverage/py_plugin_coverage_tests.rs +++ b/crates/python/tests/coverage/py_plugin_coverage_tests.rs @@ -792,6 +792,51 @@ async def initialize_plugins(module, config): }); } +#[test] +fn invoke_python_plugin_register_rolls_back_partial_registrations_on_error() { + let _python = crate::test_support::init_python_test(); + Python::attach(|py| { + let helpers = load_module( + py, + r#" +def subscriber(event): + return None + +class FailingPlugin: + def register(self, plugin_config, context): + context.register_subscriber("sub", subscriber) + raise RuntimeError("boom") +"#, + ); + + let plugin = helpers.getattr("FailingPlugin").unwrap().call0().unwrap(); + let register_fn = plugin.getattr("register").unwrap(); + let namespace_prefix = "rollback.".to_string(); + + for _ in 0..2 { + let err = invoke_python_plugin_register( + py, + "demo.rollback", + ®ister_fn, + &serde_json::Map::new(), + namespace_prefix.clone(), + ) + .unwrap_err(); + assert!(err.to_string().contains("boom"), "{err}"); + + let context = PyPluginContext { + registrations: Arc::new(Mutex::new(vec![])), + namespace_prefix: namespace_prefix.clone(), + }; + context + .register_subscriber("sub", helpers.getattr("subscriber").unwrap().unbind()) + .unwrap(); + let mut registrations = context.drain_registrations().unwrap(); + rollback_registrations(&mut registrations); + } + }); +} + #[test] fn plugin_context_lock_poisoning_covers_error_paths() { let _python = crate::test_support::init_python_test(); diff --git a/python/nemo_relay/_guardrails_local.py b/python/nemo_relay/_guardrails_local.py index 86f5946c..f16f839a 100644 --- a/python/nemo_relay/_guardrails_local.py +++ b/python/nemo_relay/_guardrails_local.py @@ -462,21 +462,21 @@ async def stream_intercept(request: LLMRequest, next_call): ) text_queue: asyncio.Queue[str | None] = asyncio.Queue() - blocked: dict[str, str | None] = {"message": None} - monitor = asyncio.create_task( - _monitor_streaming_output_rails( - rails=rails, - messages=messages, - text_queue=text_queue, - blocked=blocked, - ) - ) + block_state: dict[str, str | None] = {"message": None} async def guarded_provider_stream(): + monitor = asyncio.create_task( + _monitor_streaming_output_rails( + rails=rails, + messages=messages, + text_queue=text_queue, + blocked=block_state, + ) + ) try: async for chunk in stream: - if blocked["message"] is not None: - _raise_streaming_output_blocked(blocked["message"]) + if block_state["message"] is not None: + _raise_streaming_output_blocked(block_state["message"]) text = _extract_stream_text(codec_name, chunk) if text is not None: @@ -484,13 +484,13 @@ async def guarded_provider_stream(): yield chunk - if blocked["message"] is not None: - _raise_streaming_output_blocked(blocked["message"]) + if block_state["message"] is not None: + _raise_streaming_output_blocked(block_state["message"]) finally: await text_queue.put(None) await monitor - if blocked["message"] is not None: - _raise_streaming_output_blocked(blocked["message"]) + if block_state["message"] is not None: + _raise_streaming_output_blocked(block_state["message"]) return guarded_provider_stream() From 67fd1b912e4da86c8362f6d93e0437933591dd20 Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Mon, 1 Jun 2026 12:11:06 -0700 Subject: [PATCH 07/20] test: extend local guardrails cli coverage Signed-off-by: Alex Fournier --- crates/cli/tests/coverage/plugins_tests.rs | 243 ++++++++++++++++++++- 1 file changed, 242 insertions(+), 1 deletion(-) diff --git a/crates/cli/tests/coverage/plugins_tests.rs b/crates/cli/tests/coverage/plugins_tests.rs index 28bf0fd2..502e7b07 100644 --- a/crates/cli/tests/coverage/plugins_tests.rs +++ b/crates/cli/tests/coverage/plugins_tests.rs @@ -7,7 +7,7 @@ use nemo_relay::config_editor::{EditorConfig, EditorSchema}; use nemo_relay::observability::plugin_component::{OBSERVABILITY_PLUGIN_KIND, ObservabilityConfig}; use nemo_relay::plugin::{ConfigPolicy, PluginComponentSpec, PluginConfig}; use nemo_relay::plugins::nemo_guardrails::component::{ - NEMO_GUARDRAILS_PLUGIN_KIND, NeMoGuardrailsConfig, RemoteBackendConfig, + LocalBackendConfig, NEMO_GUARDRAILS_PLUGIN_KIND, NeMoGuardrailsConfig, RemoteBackendConfig, }; use nemo_relay_adaptive::AdaptiveConfig; use nemo_relay_adaptive::plugin_component::ADAPTIVE_PLUGIN_KIND; @@ -50,6 +50,40 @@ fn guardrails_component_config(config_id: &str) -> serde_json::Map serde_json::Map { + json!({ + "mode": "local", + "input": false, + "output": false, + "config_path": config_path, + "tool_input": true, + "tool_output": true, + "local": { + "python_module": "custom_guardrails" + } + }) + .as_object() + .unwrap() + .clone() +} + +fn local_llm_guardrails_component_config(config_yaml: &str) -> serde_json::Map { + json!({ + "mode": "local", + "codec": "openai_chat", + "input": true, + "output": true, + "config_yaml": config_yaml, + "colang_content": "define flow noop\n pass", + "local": { + "python_module": "custom_guardrails" + } + }) + .as_object() + .unwrap() + .clone() +} + #[test] fn target_scope_defaults_to_user_and_rejects_conflicts() { assert_eq!( @@ -160,6 +194,24 @@ fn typed_editor_model_contains_nemo_guardrails_options() { EditorFieldKind::StringMap ); + let local = schema.field("local").unwrap().schema().unwrap(); + assert_eq!( + local.field("python_module").unwrap().kind, + EditorFieldKind::String + ); + assert_eq!( + schema.field("config_path").unwrap().kind, + EditorFieldKind::String + ); + assert_eq!( + schema.field("config_yaml").unwrap().kind, + EditorFieldKind::String + ); + assert_eq!( + schema.field("colang_content").unwrap().kind, + EditorFieldKind::String + ); + let request_defaults = schema.field("request_defaults").unwrap().schema().unwrap(); let rails = request_defaults.field("rails").unwrap().schema().unwrap(); assert_eq!( @@ -1137,6 +1189,98 @@ fn validate_config_accepts_nemo_guardrails_component() { validate_config(&config).unwrap(); } +#[test] +fn validate_config_accepts_local_tool_only_nemo_guardrails_component() { + let config = PluginConfig { + components: vec![PluginComponentSpec { + kind: NEMO_GUARDRAILS_PLUGIN_KIND.to_string(), + enabled: true, + config: local_guardrails_component_config("./rails"), + }], + ..PluginConfig::default() + }; + + validate_config(&config).unwrap(); +} + +#[test] +fn validate_config_rejects_local_nemo_guardrails_request_defaults() { + let config = PluginConfig { + components: vec![PluginComponentSpec { + kind: NEMO_GUARDRAILS_PLUGIN_KIND.to_string(), + enabled: true, + config: json!({ + "mode": "local", + "codec": "openai_chat", + "input": true, + "output": true, + "config_yaml": "models: []", + "request_defaults": { + "context": {"tenant": "demo"} + } + }) + .as_object() + .unwrap() + .clone(), + }], + ..PluginConfig::default() + }; + + let error = validate_config(&config).unwrap_err().to_string(); + assert!(error.contains("request_defaults"), "error was: {error}"); + assert!(error.contains("local mode"), "error was: {error}"); +} + +#[test] +fn validate_config_rejects_local_nemo_guardrails_multiple_config_sources() { + let config = PluginConfig { + components: vec![PluginComponentSpec { + kind: NEMO_GUARDRAILS_PLUGIN_KIND.to_string(), + enabled: true, + config: json!({ + "mode": "local", + "config_path": "./rails", + "config_yaml": "models: []" + }) + .as_object() + .unwrap() + .clone(), + }], + ..PluginConfig::default() + }; + + let error = validate_config(&config).unwrap_err().to_string(); + assert!( + error.contains("exactly one of config_path or config_yaml"), + "error was: {error}" + ); +} + +#[test] +fn validate_config_rejects_local_nemo_guardrails_colang_without_yaml() { + let config = PluginConfig { + components: vec![PluginComponentSpec { + kind: NEMO_GUARDRAILS_PLUGIN_KIND.to_string(), + enabled: true, + config: json!({ + "mode": "local", + "config_path": "./rails", + "colang_content": "define flow noop\n pass" + }) + .as_object() + .unwrap() + .clone(), + }], + ..PluginConfig::default() + }; + + let error = validate_config(&config).unwrap_err().to_string(); + assert!( + error.contains("colang_content can only be used with config_yaml"), + "error was: {error}" + ); +} + #[test] fn nemo_guardrails_config_map_prunes_default_version() { let map = nemo_guardrails_config_map(&NeMoGuardrailsConfig { @@ -1155,6 +1299,103 @@ fn nemo_guardrails_config_map_prunes_default_version() { assert_eq!(map["remote"]["config_id"], json!("default")); } +#[test] +fn write_plugin_config_round_trips_local_nemo_guardrails_component() { + let temp = tempfile::tempdir().unwrap(); + let path = temp.path().join("plugins.toml"); + let config = PluginConfig { + components: vec![PluginComponentSpec { + kind: NEMO_GUARDRAILS_PLUGIN_KIND.to_string(), + enabled: true, + config: local_guardrails_component_config("./rails"), + }], + ..PluginConfig::default() + }; + + write_plugin_config(&path, &config).unwrap(); + + let rendered = std::fs::read_to_string(&path).unwrap(); + assert!(rendered.contains("mode = \"local\"")); + assert!(rendered.contains("config_path = \"./rails\"")); + assert!(rendered.contains("tool_input = true")); + assert!(rendered.contains("python_module = \"custom_guardrails\"")); + + let round_tripped = read_plugin_config(&path).unwrap(); + let guardrails = round_tripped + .components + .iter() + .find(|component| component.kind == NEMO_GUARDRAILS_PLUGIN_KIND) + .unwrap(); + assert!(guardrails.enabled); + assert_eq!(guardrails.config["mode"], json!("local")); + assert_eq!(guardrails.config["config_path"], json!("./rails")); + assert_eq!(guardrails.config["tool_input"], json!(true)); + assert_eq!( + guardrails.config["local"]["python_module"], + json!("custom_guardrails") + ); +} + +#[test] +fn nemo_guardrails_config_map_serializes_local_mode_fields() { + let map = nemo_guardrails_config_map(&NeMoGuardrailsConfig { + mode: "local".into(), + config_path: Some("./rails".into()), + tool_input: true, + tool_output: true, + local: Some(LocalBackendConfig { + python_module: Some("custom_guardrails".into()), + }), + ..NeMoGuardrailsConfig::default() + }) + .unwrap(); + + assert!(!map.contains_key("version")); + assert_eq!(map.get("mode"), Some(&json!("local"))); + assert_eq!(map.get("config_path"), Some(&json!("./rails"))); + assert_eq!(map.get("tool_input"), Some(&json!(true))); + assert_eq!(map["local"]["python_module"], json!("custom_guardrails")); +} + +#[test] +fn write_plugin_config_round_trips_local_llm_nemo_guardrails_component() { + let temp = tempfile::tempdir().unwrap(); + let path = temp.path().join("plugins.toml"); + let config = PluginConfig { + components: vec![PluginComponentSpec { + kind: NEMO_GUARDRAILS_PLUGIN_KIND.to_string(), + enabled: true, + config: local_llm_guardrails_component_config("models: []"), + }], + ..PluginConfig::default() + }; + + write_plugin_config(&path, &config).unwrap(); + + let rendered = std::fs::read_to_string(&path).unwrap(); + assert!(rendered.contains("mode = \"local\"")); + assert!(rendered.contains("codec = \"openai_chat\"")); + assert!(rendered.contains("input = true")); + assert!(rendered.contains("output = true")); + assert!(rendered.contains("config_yaml = \"models: []\"")); + + let round_tripped = read_plugin_config(&path).unwrap(); + let guardrails = round_tripped + .components + .iter() + .find(|component| component.kind == NEMO_GUARDRAILS_PLUGIN_KIND) + .unwrap(); + assert_eq!(guardrails.config["mode"], json!("local")); + assert_eq!(guardrails.config["codec"], json!("openai_chat")); + assert_eq!(guardrails.config["input"], json!(true)); + assert_eq!(guardrails.config["output"], json!(true)); + assert_eq!(guardrails.config["config_yaml"], json!("models: []")); + assert_eq!( + guardrails.config["colang_content"], + json!("define flow noop\n pass") + ); +} + #[test] fn display_helpers_render_scalars_json_and_defaults() { assert_eq!(display_value(&json!("logs")), "logs"); From e86ae58ef5e8b4d5dffab6b710db1b08d42614fc Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Wed, 3 Jun 2026 17:11:46 -0700 Subject: [PATCH 08/20] refactor: embed local guardrails helper snapshot Signed-off-by: Alex Fournier --- .../embedded_python}/_guardrails_local.py | 11 +- crates/python/src/lib.rs | 80 +++++++---- .../python/tests/coverage/coverage_tests.rs | 128 ++++++++++++++++++ docs/nemo-guardrails-plugin/configuration.mdx | 5 +- 4 files changed, 191 insertions(+), 33 deletions(-) rename {python/nemo_relay => crates/python/embedded_python}/_guardrails_local.py (97%) diff --git a/python/nemo_relay/_guardrails_local.py b/crates/python/embedded_python/_guardrails_local.py similarity index 97% rename from python/nemo_relay/_guardrails_local.py rename to crates/python/embedded_python/_guardrails_local.py index f16f839a..9f93367c 100644 --- a/python/nemo_relay/_guardrails_local.py +++ b/crates/python/embedded_python/_guardrails_local.py @@ -22,6 +22,7 @@ from nemo_relay.plugin import PluginContext _DEFAULT_PRIORITY = 100 +_SUPPORTED_NEMOGUARDRAILS_VERSION = "0.22.0" class NeMoGuardrailsDependencyError(RuntimeError): @@ -74,13 +75,21 @@ def _load_nemoguardrails(module_name: str | None) -> _GuardrailsRuntimeImports: if error.name == root_module: raise NeMoGuardrailsDependencyError( "NeMo Guardrails is required for the built-in NeMo Guardrails local backend. " - "Install it with: pip install nemoguardrails" + "Install it with: pip install nemoguardrails==0.22.0" ) from error raise NeMoGuardrailsDependencyError( "NeMo Guardrails local backend could not import a required dependency: " f"{error.name or error}. Install the full NeMo Guardrails runtime dependencies." ) from error + version = getattr(guardrails, "__version__", None) + if version != _SUPPORTED_NEMOGUARDRAILS_VERSION: + raise NeMoGuardrailsDependencyError( + "NeMo Guardrails local backend requires nemoguardrails==" + f"{_SUPPORTED_NEMOGUARDRAILS_VERSION}, but found {version!r}. " + "Install it with: pip install nemoguardrails==0.22.0" + ) + return _GuardrailsRuntimeImports( rails_config_cls=guardrails.RailsConfig, llm_rails_cls=guardrails.LLMRails, diff --git a/crates/python/src/lib.rs b/crates/python/src/lib.rs index 4a40eaf8..8328c265 100644 --- a/crates/python/src/lib.rs +++ b/crates/python/src/lib.rs @@ -27,7 +27,9 @@ use nemo_relay::plugins::nemo_guardrails::component::{ use nemo_relay::shared_runtime::initialize_shared_runtime_binding; use nemo_relay_adaptive::plugin_component::register_adaptive_component; use pyo3::prelude::*; +use pyo3::types::{PyDict, PyModule}; use serde_json::Value as Json; +use std::ffi::CString; use std::path::{Path, PathBuf}; use std::sync::Arc; @@ -46,6 +48,11 @@ pub mod py_types; #[cfg(test)] mod test_support; +const EMBEDDED_GUARDRAILS_LOCAL_MODULE_NAME: &str = "nemo_relay._guardrails_local"; +const EMBEDDED_GUARDRAILS_LOCAL_FILENAME: &str = "nemo_relay/_guardrails_local.py"; +const EMBEDDED_GUARDRAILS_LOCAL_SOURCE: &str = + include_str!("../embedded_python/_guardrails_local.py"); + /// The `_native` PyO3 module entry point. Registers all types and functions. #[pymodule] fn _native(m: &Bound<'_, PyModule>) -> PyResult<()> { @@ -70,6 +77,18 @@ fn _native(m: &Bound<'_, PyModule>) -> PyResult<()> { py_api::register(m)?; py_plugin::register(m)?; py_adaptive::register(m)?; + install_native_module_alias(m)?; + Ok(()) +} + +fn install_native_module_alias(m: &Bound<'_, PyModule>) -> PyResult<()> { + let py = m.py(); + let sys = py.import("sys")?; + let modules = sys.getattr("modules")?.cast_into::()?; + modules.set_item("nemo_relay._native", m)?; + if let Ok(package) = py.import("nemo_relay") { + let _ = package.setattr("_native", m); + } Ok(()) } @@ -109,43 +128,46 @@ fn register_python_local_guardrails_backend( } fn load_guardrails_local_register_fn(py: Python<'_>) -> PyResult> { - let module = match py.import("nemo_relay._guardrails_local") { - Ok(module) => module, - Err(err) => { - if !is_missing_guardrails_local_module(py, &err)? { - return Err(err); - } + let module = load_embedded_guardrails_local_module(py)?; + module.getattr("register_local_backend") +} - let source_python_dir = guardrails_local_source_python_dir(); - if !source_python_dir.exists() { - return Err(err); - } +fn load_embedded_guardrails_local_module(py: Python<'_>) -> PyResult> { + ensure_nemo_relay_package_importable(py)?; - prepend_python_path_if_missing(py, &source_python_dir)?; - py.import("nemo_relay._guardrails_local")? - } - }; - module.getattr("register_local_backend") + let sys = py.import("sys")?; + let modules = sys.getattr("modules")?.cast_into::()?; + if let Some(existing) = modules.get_item(EMBEDDED_GUARDRAILS_LOCAL_MODULE_NAME)? { + return Ok(existing.cast_into::()?); + } + + let source = CString::new(EMBEDDED_GUARDRAILS_LOCAL_SOURCE).unwrap(); + let filename = CString::new(EMBEDDED_GUARDRAILS_LOCAL_FILENAME).unwrap(); + let module_name = CString::new(EMBEDDED_GUARDRAILS_LOCAL_MODULE_NAME).unwrap(); + let module = PyModule::from_code(py, &source, &filename, &module_name)?; + modules.set_item(EMBEDDED_GUARDRAILS_LOCAL_MODULE_NAME, &module)?; + if let Ok(package) = py.import("nemo_relay") { + let _ = package.setattr("_guardrails_local", &module); + } + Ok(module) } -fn is_missing_guardrails_local_module(py: Python<'_>, err: &PyErr) -> PyResult { - if !err.is_instance_of::(py) { - return Ok(false); +fn ensure_nemo_relay_package_importable(py: Python<'_>) -> PyResult<()> { + if py.import("nemo_relay").is_ok() { + return Ok(()); } - let err_value = err.value(py); - let module_name = err_value - .getattr("name") - .ok() - .and_then(|name| name.extract::().ok()); + let source_python_dir = embedded_guardrails_source_python_dir(); + if !source_python_dir.exists() { + return Ok(()); + } - Ok(matches!( - module_name.as_deref(), - Some("nemo_relay") | Some("nemo_relay._guardrails_local") - )) + prepend_python_path_if_missing(py, &source_python_dir)?; + let _ = py.import("nemo_relay")?; + Ok(()) } -fn guardrails_local_source_python_dir() -> PathBuf { +fn embedded_guardrails_source_python_dir() -> PathBuf { PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../python") } @@ -155,8 +177,6 @@ fn prepend_python_path_if_missing(py: Python<'_>, path: &Path) -> PyResult<()> { let path_str = path.to_string_lossy(); if !sys_path.contains(path_str.as_ref())? { - // Source-tree fallback for local development and in-repo tests where the - // Python package has not been installed into the active environment yet. sys_path.call_method1("insert", (0, path_str.as_ref()))?; } diff --git a/crates/python/tests/coverage/coverage_tests.rs b/crates/python/tests/coverage/coverage_tests.rs index a104d68c..75a38ea1 100644 --- a/crates/python/tests/coverage/coverage_tests.rs +++ b/crates/python/tests/coverage/coverage_tests.rs @@ -45,6 +45,10 @@ fn python_package_dir() -> PathBuf { PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../python") } +fn embedded_guardrails_local_source_path() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("embedded_python/_guardrails_local.py") +} + fn fake_guardrails_module_prelude(module_name: &str, python_dir: &str) -> String { format!( r#" @@ -115,6 +119,34 @@ class Context: "# } +fn embedded_guardrails_local_loader_python(source_path: &str) -> String { + format!( + r#" +import pathlib +import sys +import types + +import nemo_relay + +GUARDRAILS_LOCAL_SOURCE_PATH = pathlib.Path({source_path:?}) +guardrails_local_module = types.ModuleType("nemo_relay._guardrails_local") +guardrails_local_module.__file__ = str(GUARDRAILS_LOCAL_SOURCE_PATH) +guardrails_local_module.__package__ = "nemo_relay" +sys.modules["nemo_relay._guardrails_local"] = guardrails_local_module +setattr(nemo_relay, "_guardrails_local", guardrails_local_module) +exec( + compile( + GUARDRAILS_LOCAL_SOURCE_PATH.read_text(), + str(GUARDRAILS_LOCAL_SOURCE_PATH), + "exec", + ), + guardrails_local_module.__dict__, +) +"#, + source_path = source_path, + ) +} + fn with_isolated_nemo_relay_modules( py: Python<'_>, native_module: &Bound<'_, PyModule>, @@ -291,6 +323,11 @@ fn test_guardrails_local_helper_registers_and_enforces_llm_and_tool_checks() { ); let epilogue = register_fake_guardrails_module_epilogue(); let context_class = local_plugin_context_python(); + let embedded_loader = embedded_guardrails_local_loader_python( + &embedded_guardrails_local_source_path() + .display() + .to_string(), + ); let module = load_module( py, &format!( @@ -310,6 +347,8 @@ class LLMRails: {epilogue} +{embedded_loader} + from nemo_relay._native import LLMRequest from nemo_relay._guardrails_local import register_local_backend @@ -381,6 +420,7 @@ async def run_case(): prelude = prelude, epilogue = epilogue, context_class = context_class, + embedded_loader = embedded_loader, ), ); @@ -436,6 +476,86 @@ async def run_case(): }); } +#[test] +fn test_guardrails_local_helper_rejects_unsupported_nemoguardrails_version() { + let _python = crate::test_support::init_python_test(); + Python::attach(|py| { + let native_module = PyModule::new(py, "_native_guardrails_version").unwrap(); + crate::_native(&native_module).unwrap(); + + with_isolated_nemo_relay_modules(py, &native_module, || { + let python_dir = python_package_dir(); + let prelude = fake_guardrails_module_prelude( + "fake_guardrails_bad_version", + &python_dir.display().to_string(), + ); + let epilogue = register_fake_guardrails_module_epilogue(); + let context_class = local_plugin_context_python(); + let embedded_loader = embedded_guardrails_local_loader_python( + &embedded_guardrails_local_source_path() + .display() + .to_string(), + ); + let module = load_module( + py, + &format!( + r#" +{prelude} + +fake_root.__version__ = "0.21.0" + +class LLMRails: + def __init__(self, config): + self.config = config + + async def check_async(self, messages, rail_types): + return Result(RailStatus.PASSED) + +{epilogue} + +{embedded_loader} + +from nemo_relay._guardrails_local import register_local_backend + +{context_class} + +async def run_case(): + ctx = Context() + register_local_backend( + {{ + "mode": "local", + "codec": "openai_chat", + "config_yaml": "models: []", + "input": True, + "local": {{"python_module": MODULE_NAME}}, + }}, + ctx, + ) +"#, + prelude = prelude, + epilogue = epilogue, + embedded_loader = embedded_loader, + context_class = context_class, + ), + ); + + let error = with_event_loop(py, |event_loop| { + let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); + event_loop + .call_method1("run_until_complete", (coroutine,)) + .unwrap_err() + .to_string() + }); + + assert!( + error.contains("requires nemoguardrails==0.22.0"), + "unexpected error: {error}" + ); + assert!(error.contains("0.21.0"), "unexpected error: {error}"); + }); + }); +} + #[test] fn test_guardrails_local_helper_enforces_streamed_output_rails() { let _python = crate::test_support::init_python_test(); @@ -451,6 +571,11 @@ fn test_guardrails_local_helper_enforces_streamed_output_rails() { ); let epilogue = register_fake_guardrails_module_epilogue(); let context_class = local_plugin_context_python(); + let embedded_loader = embedded_guardrails_local_loader_python( + &embedded_guardrails_local_source_path() + .display() + .to_string(), + ); let module = load_module( py, &format!( @@ -487,6 +612,8 @@ class LLMRails: {epilogue} +{embedded_loader} + from nemo_relay._native import LLMRequest from nemo_relay._guardrails_local import register_local_backend @@ -583,6 +710,7 @@ async def run_case(): prelude = prelude, epilogue = epilogue, context_class = context_class, + embedded_loader = embedded_loader, ), ); diff --git a/docs/nemo-guardrails-plugin/configuration.mdx b/docs/nemo-guardrails-plugin/configuration.mdx index 24cc12c6..16245f24 100644 --- a/docs/nemo-guardrails-plugin/configuration.mdx +++ b/docs/nemo-guardrails-plugin/configuration.mdx @@ -225,12 +225,13 @@ through the Python runtime instead of a separate Guardrails service. ### Requirements To use `mode = "local"`, the running Python environment must be able to import -`nemoguardrails`. +`nemoguardrails==0.22.0`. The built-in local backend is installed by the Python binding and runs Guardrails in process. Use it when the runtime has direct access to the Python Guardrails dependency and configuration files rather than a separate Guardrails -service. +service. Install the tested local-mode Guardrails dependency with +`pip install nemoguardrails==0.22.0`. The same ownership boundary still applies: From 3c7eecdb7a5549f08b96ddd181d1da25ffbdacb0 Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Thu, 4 Jun 2026 08:42:02 -0700 Subject: [PATCH 09/20] refactor: move local guardrails backend into core Signed-off-by: Alex Fournier --- Cargo.lock | 3 + crates/cli/Cargo.toml | 2 +- crates/cli/src/config.rs | 2 +- crates/core/Cargo.toml | 10 +- .../src/plugins/nemo_guardrails/component.rs | 10 +- .../embedded_python/_guardrails_local.py | 12 +- .../core/src/plugins/nemo_guardrails/local.rs | 48 +- .../src/plugins/nemo_guardrails/python.rs | 842 ++++++++++++++++++ .../nemo_guardrails/component_tests.rs | 54 +- crates/python/Cargo.toml | 2 +- crates/python/src/lib.rs | 111 --- .../python/tests/coverage/coverage_tests.rs | 68 +- 12 files changed, 948 insertions(+), 216 deletions(-) rename crates/{python => core/src/plugins/nemo_guardrails}/embedded_python/_guardrails_local.py (98%) create mode 100644 crates/core/src/plugins/nemo_guardrails/python.rs diff --git a/Cargo.lock b/Cargo.lock index 0246c835..488fab08 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1327,6 +1327,9 @@ dependencies = [ "opentelemetry-http", "opentelemetry-otlp", "opentelemetry_sdk", + "pyo3", + "pyo3-async-runtimes", + "pythonize", "reqwest", "rustls", "schemars", diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index ca1d49d2..16bc94a2 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -21,7 +21,7 @@ pkg-fmt = "bin" workspace = true [dependencies] -nemo-relay = { workspace = true, features = ["guardrails-remote", "object-store", "openinference"] } +nemo-relay = { workspace = true, features = ["guardrails-remote", "object-store", "openinference", "python"] } nemo-relay-adaptive = { workspace = true, features = ["redis-backend"] } async-stream = "0.3" axum = "0.8" diff --git a/crates/cli/src/config.rs b/crates/cli/src/config.rs index 2307065f..9e5b281e 100644 --- a/crates/cli/src/config.rs +++ b/crates/cli/src/config.rs @@ -158,7 +158,7 @@ pub(crate) struct PluginsCommand { /// Plugin configuration subcommands. #[derive(Debug, Clone, Subcommand)] pub(crate) enum PluginsSubcommand { - /// Interactively create or edit the Observability plugin in `plugins.toml`. + /// Interactively create or edit built-in plugin configuration in `plugins.toml`. Edit(PluginsEditCommand), } diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index e205b83c..c69c8226 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -14,12 +14,17 @@ readme = "README.md" workspace = true [features] -default = ["otel", "openinference", "guardrails-remote", "object-store"] +default = ["otel", "openinference", "guardrails-remote", "object-store", "python"] schema = ["dep:schemars"] guardrails-remote = [ "dep:reqwest", "dep:rustls", ] +python = [ + "dep:pyo3", + "dep:pyo3-async-runtimes", + "dep:pythonize", +] object-store = [ "dep:object_store", "tokio/net", @@ -78,6 +83,9 @@ opentelemetry-http = { version = "0.31", default-features = false, optional = tr wasm-bindgen = { version = "0.2", optional = true } wasm-bindgen-futures = { version = "0.4", optional = true } web-sys = { version = "0.3", features = ["Headers", "Request", "RequestInit", "Response", "Window", "console"], optional = true } +pyo3 = { version = "0.28.2", features = ["auto-initialize"], optional = true } +pyo3-async-runtimes = { version = "0.28.0", features = ["tokio-runtime"], optional = true } +pythonize = { version = "0.28.0", optional = true } [dev-dependencies] tokio = { version = "1", features = ["rt", "macros", "sync", "test-util", "rt-multi-thread", "time"] } diff --git a/crates/core/src/plugins/nemo_guardrails/component.rs b/crates/core/src/plugins/nemo_guardrails/component.rs index 28decfbe..37255738 100644 --- a/crates/core/src/plugins/nemo_guardrails/component.rs +++ b/crates/core/src/plugins/nemo_guardrails/component.rs @@ -23,7 +23,6 @@ mod local; #[path = "remote.rs"] mod remote; use local::register_local_backend; -pub use local::{clear_local_backend_provider, register_local_backend_provider}; #[cfg(all(feature = "guardrails-remote", not(target_arch = "wasm32")))] use remote::register_remote_backend; @@ -745,6 +744,15 @@ fn validate_local_config_shape( config: &NeMoGuardrailsConfig, flags: &ConfigShapeFlags, ) { + #[cfg(not(feature = "python"))] + push_config_shape_diag( + diagnostics, + policy.unsupported_value, + "nemo_guardrails.unavailable_backend", + Some("mode"), + "local mode requires a build with the 'python' feature enabled", + ); + if flags.has_config_path == flags.has_config_yaml { push_config_shape_diag( diagnostics, diff --git a/crates/python/embedded_python/_guardrails_local.py b/crates/core/src/plugins/nemo_guardrails/embedded_python/_guardrails_local.py similarity index 98% rename from crates/python/embedded_python/_guardrails_local.py rename to crates/core/src/plugins/nemo_guardrails/embedded_python/_guardrails_local.py index 9f93367c..fd863640 100644 --- a/crates/python/embedded_python/_guardrails_local.py +++ b/crates/core/src/plugins/nemo_guardrails/embedded_python/_guardrails_local.py @@ -11,15 +11,15 @@ from collections.abc import Callable from typing import Any, NamedTuple, Protocol, cast -from nemo_relay import Json, LLMRequest -from nemo_relay.codecs import ( +from _nemo_guardrails_local_runtime import ( AnthropicMessagesCodec, - LlmCodec, - LlmResponseCodec, + LLMRequest, OpenAIChatCodec, OpenAIResponsesCodec, + PluginContext, ) -from nemo_relay.plugin import PluginContext + +Json = Any _DEFAULT_PRIORITY = 100 _SUPPORTED_NEMOGUARDRAILS_VERSION = "0.22.0" @@ -46,7 +46,7 @@ def __init__( self.content = content -class _GuardrailsCodec(LlmCodec, LlmResponseCodec, Protocol): +class _GuardrailsCodec(Protocol): """Codec shape required by the local backend.""" diff --git a/crates/core/src/plugins/nemo_guardrails/local.rs b/crates/core/src/plugins/nemo_guardrails/local.rs index 240ed186..f46dea2f 100644 --- a/crates/core/src/plugins/nemo_guardrails/local.rs +++ b/crates/core/src/plugins/nemo_guardrails/local.rs @@ -1,52 +1,24 @@ // SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -use std::sync::{Arc, LazyLock, Mutex, MutexGuard}; - use crate::plugin::{PluginError, PluginRegistrationContext, Result as PluginResult}; use super::NeMoGuardrailsConfig; -type LocalBackendProvider = Arc< - dyn Fn(NeMoGuardrailsConfig, &mut PluginRegistrationContext) -> PluginResult<()> + Send + Sync, ->; - -static LOCAL_BACKEND_PROVIDER: LazyLock>> = - LazyLock::new(|| Mutex::new(None)); - -fn local_backend_provider_guard() -> PluginResult>> -{ - LOCAL_BACKEND_PROVIDER.lock().map_err(|e| { - PluginError::Internal(format!( - "NeMo Guardrails local backend provider lock poisoned: {e}" - )) - }) -} - -#[doc(hidden)] -pub fn register_local_backend_provider(provider: LocalBackendProvider) -> PluginResult<()> { - let mut guard = local_backend_provider_guard()?; - *guard = Some(provider); - Ok(()) -} - -#[doc(hidden)] -pub fn clear_local_backend_provider() -> PluginResult<()> { - let mut guard = local_backend_provider_guard()?; - *guard = None; - Ok(()) -} +#[cfg(feature = "python")] +mod python; pub(super) fn register_local_backend( config: NeMoGuardrailsConfig, ctx: &mut PluginRegistrationContext, ) -> PluginResult<()> { - let provider = local_backend_provider_guard()?.clone(); - - match provider { - Some(provider) => provider(config, ctx), - None => Err(PluginError::RegistrationFailed( - "built-in NeMo Guardrails local backend is unavailable in this runtime".to_string(), - )), + #[cfg(feature = "python")] + { + return python::register_local_backend(config, ctx); } + + #[allow(unreachable_code)] + Err(PluginError::RegistrationFailed( + "built-in NeMo Guardrails local backend is unavailable in this build".to_string(), + )) } diff --git a/crates/core/src/plugins/nemo_guardrails/python.rs b/crates/core/src/plugins/nemo_guardrails/python.rs new file mode 100644 index 00000000..c782fa38 --- /dev/null +++ b/crates/core/src/plugins/nemo_guardrails/python.rs @@ -0,0 +1,842 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::ffi::CString; +use std::future::Future; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; + +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyModule}; + +use crate::api::llm::LlmRequest; +use crate::api::registry::{ + deregister_llm_execution_intercept, deregister_llm_stream_execution_intercept, + deregister_tool_execution_intercept, register_llm_execution_intercept, + register_llm_stream_execution_intercept, register_tool_execution_intercept, +}; +use crate::api::runtime::{LlmExecutionNextFn, LlmStreamExecutionNextFn, ToolExecutionNextFn}; +use crate::codec::anthropic::AnthropicMessagesCodec; +use crate::codec::openai_chat::OpenAIChatCodec; +use crate::codec::openai_responses::OpenAIResponsesCodec; +use crate::codec::request::{AnnotatedLlmRequest, Message}; +use crate::codec::response::AnnotatedLlmResponse; +use crate::codec::traits::{LlmCodec, LlmResponseCodec}; +use crate::error::{FlowError, Result as FlowResult}; +use crate::json::Json; +use crate::plugin::{ + PluginError, PluginRegistration, PluginRegistrationContext, Result as PluginResult, + rollback_registrations, +}; + +use super::NeMoGuardrailsConfig; + +const SUPPORT_MODULE_NAME: &str = "_nemo_guardrails_local_runtime"; +const HELPER_MODULE_NAME: &str = "_nemo_guardrails_local"; +const HELPER_FILENAME: &str = "_nemo_guardrails_local.py"; +const HELPER_SOURCE: &str = include_str!("embedded_python/_guardrails_local.py"); + +pub(super) fn register_local_backend( + config: NeMoGuardrailsConfig, + ctx: &mut PluginRegistrationContext, +) -> PluginResult<()> { + Python::initialize(); + + let plugin_config = match serde_json::to_value(config) { + Ok(Json::Object(config)) => config, + Ok(_) => { + return Err(PluginError::Internal( + "NeMo Guardrails local config did not serialize to a JSON object".to_string(), + )); + } + Err(err) => { + return Err(PluginError::Internal(format!( + "failed to serialize NeMo Guardrails local config: {err}" + ))); + } + }; + + let registrations = Python::attach(|py| { + let register_fn = load_guardrails_local_register_fn(py)?; + invoke_embedded_plugin_register(py, ®ister_fn, &plugin_config, ctx.qualify_name("")) + }) + .map_err(|err| PluginError::RegistrationFailed(err.to_string()))?; + + ctx.extend_registrations(registrations); + Ok(()) +} + +fn invoke_embedded_plugin_register( + py: Python<'_>, + register_fn: &Bound<'_, PyAny>, + plugin_config: &serde_json::Map, + namespace_prefix: String, +) -> PyResult> { + let context = Py::new( + py, + PyLocalPluginContext { + registrations: Arc::new(Mutex::new(vec![])), + namespace_prefix, + }, + )?; + let plugin_config_py = json_to_py(py, &Json::Object(plugin_config.clone()))?; + + match register_fn.call1((plugin_config_py, context.clone_ref(py))) { + Ok(_) => context.bind(py).borrow().drain_registrations(), + Err(err) => { + if let Ok(mut registrations) = context.bind(py).borrow().drain_registrations() { + rollback_registrations(&mut registrations); + } + Err(err) + } + } +} + +fn load_guardrails_local_register_fn(py: Python<'_>) -> PyResult> { + install_support_module(py)?; + let module = load_guardrails_local_module(py)?; + module.getattr("register_local_backend") +} + +fn install_support_module(py: Python<'_>) -> PyResult> { + let sys = py.import("sys")?; + let modules = sys.getattr("modules")?.cast_into::()?; + if let Some(existing) = modules.get_item(SUPPORT_MODULE_NAME)? { + return Ok(existing.cast_into::()?); + } + + let module = PyModule::new(py, SUPPORT_MODULE_NAME)?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + modules.set_item(SUPPORT_MODULE_NAME, &module)?; + Ok(module) +} + +fn load_guardrails_local_module(py: Python<'_>) -> PyResult> { + let sys = py.import("sys")?; + let modules = sys.getattr("modules")?.cast_into::()?; + if let Some(existing) = modules.get_item(HELPER_MODULE_NAME)? { + return Ok(existing.cast_into::()?); + } + + let source = CString::new(HELPER_SOURCE).unwrap(); + let filename = CString::new(HELPER_FILENAME).unwrap(); + let module_name = CString::new(HELPER_MODULE_NAME).unwrap(); + let module = PyModule::from_code(py, &source, &filename, &module_name)?; + modules.set_item(HELPER_MODULE_NAME, &module)?; + Ok(module) +} + +fn py_to_json(obj: &Bound<'_, PyAny>) -> PyResult { + pythonize::depythonize(obj).map_err(|e| { + PyErr::new::(format!("Failed to convert to JSON: {e}")) + }) +} + +fn json_to_py(py: Python<'_>, value: &Json) -> PyResult> { + let obj: Bound<'_, PyAny> = pythonize::pythonize(py, value).map_err(|e| { + PyErr::new::(format!("Failed to convert from JSON: {e}")) + })?; + Ok(obj.unbind()) +} + +fn messages_to_json(messages: &[Message]) -> PyResult { + serde_json::to_value(messages).map_err(|e| { + PyErr::new::(format!( + "Failed to serialize messages: {e}" + )) + }) +} + +#[pyclass(name = "LLMRequest", from_py_object)] +#[derive(Clone)] +struct PyLLMRequest { + inner: LlmRequest, +} + +#[pymethods] +impl PyLLMRequest { + #[new] + #[pyo3(signature = (headers, content), text_signature = "(headers: dict[str, str], content: object)")] + fn new(headers: &Bound<'_, PyAny>, content: &Bound<'_, PyAny>) -> PyResult { + let headers_json = py_to_json(headers)?; + let Json::Object(headers_map) = headers_json else { + return Err(PyErr::new::( + "headers must be a dict", + )); + }; + let content_json = py_to_json(content)?; + Ok(Self { + inner: LlmRequest { + headers: headers_map, + content: content_json, + }, + }) + } + + #[getter] + fn headers(&self, py: Python<'_>) -> PyResult> { + json_to_py(py, &Json::Object(self.inner.headers.clone())) + } + + #[getter] + fn content(&self, py: Python<'_>) -> PyResult> { + json_to_py(py, &self.inner.content) + } + + fn __repr__(&self) -> String { + "LLMRequest(...)".to_string() + } +} + +#[pyclass(name = "AnnotatedLLMRequest", from_py_object)] +#[derive(Clone)] +struct PyAnnotatedLLMRequest { + inner: AnnotatedLlmRequest, +} + +#[pymethods] +impl PyAnnotatedLLMRequest { + #[getter] + fn messages(&self, py: Python<'_>) -> PyResult> { + json_to_py(py, &messages_to_json(&self.inner.messages)?) + } + + #[setter] + fn set_messages(&mut self, value: &Bound<'_, PyAny>) -> PyResult<()> { + self.inner.messages = pythonize::depythonize(value).map_err(|e| { + PyErr::new::(format!("invalid messages: {e}")) + })?; + Ok(()) + } +} + +#[pyclass(name = "AnnotatedLLMResponse", skip_from_py_object)] +#[derive(Clone)] +struct PyAnnotatedLLMResponse { + inner: AnnotatedLlmResponse, +} + +#[pymethods] +impl PyAnnotatedLLMResponse { + fn response_text(&self) -> Option { + self.inner.response_text().map(str::to_string) + } +} + +#[pyclass(name = "OpenAIChatCodec")] +struct PyOpenAIChatCodec; + +#[pymethods] +impl PyOpenAIChatCodec { + #[new] + fn new() -> Self { + Self + } + + fn decode(&self, request: &PyLLMRequest) -> PyResult { + OpenAIChatCodec + .decode(&request.inner) + .map(|inner| PyAnnotatedLLMRequest { inner }) + .map_err(flow_to_py_err) + } + + fn encode( + &self, + annotated: &PyAnnotatedLLMRequest, + original: &PyLLMRequest, + ) -> PyResult { + OpenAIChatCodec + .encode(&annotated.inner, &original.inner) + .map(|inner| PyLLMRequest { inner }) + .map_err(flow_to_py_err) + } + + fn decode_response(&self, response: &Bound<'_, PyAny>) -> PyResult { + let response = py_to_json(response)?; + OpenAIChatCodec + .decode_response(&response) + .map(|inner| PyAnnotatedLLMResponse { inner }) + .map_err(flow_to_py_err) + } +} + +#[pyclass(name = "OpenAIResponsesCodec")] +struct PyOpenAIResponsesCodec; + +#[pymethods] +impl PyOpenAIResponsesCodec { + #[new] + fn new() -> Self { + Self + } + + fn decode(&self, request: &PyLLMRequest) -> PyResult { + OpenAIResponsesCodec + .decode(&request.inner) + .map(|inner| PyAnnotatedLLMRequest { inner }) + .map_err(flow_to_py_err) + } + + fn encode( + &self, + annotated: &PyAnnotatedLLMRequest, + original: &PyLLMRequest, + ) -> PyResult { + OpenAIResponsesCodec + .encode(&annotated.inner, &original.inner) + .map(|inner| PyLLMRequest { inner }) + .map_err(flow_to_py_err) + } + + fn decode_response(&self, response: &Bound<'_, PyAny>) -> PyResult { + let response = py_to_json(response)?; + OpenAIResponsesCodec + .decode_response(&response) + .map(|inner| PyAnnotatedLLMResponse { inner }) + .map_err(flow_to_py_err) + } +} + +#[pyclass(name = "AnthropicMessagesCodec")] +struct PyAnthropicMessagesCodec; + +#[pymethods] +impl PyAnthropicMessagesCodec { + #[new] + fn new() -> Self { + Self + } + + fn decode(&self, request: &PyLLMRequest) -> PyResult { + AnthropicMessagesCodec + .decode(&request.inner) + .map(|inner| PyAnnotatedLLMRequest { inner }) + .map_err(flow_to_py_err) + } + + fn encode( + &self, + annotated: &PyAnnotatedLLMRequest, + original: &PyLLMRequest, + ) -> PyResult { + AnthropicMessagesCodec + .encode(&annotated.inner, &original.inner) + .map(|inner| PyLLMRequest { inner }) + .map_err(flow_to_py_err) + } + + fn decode_response(&self, response: &Bound<'_, PyAny>) -> PyResult { + let response = py_to_json(response)?; + AnthropicMessagesCodec + .decode_response(&response) + .map(|inner| PyAnnotatedLLMResponse { inner }) + .map_err(flow_to_py_err) + } +} + +#[pyclass(name = "PluginContext")] +struct PyLocalPluginContext { + registrations: Arc>>, + namespace_prefix: String, +} + +impl PyLocalPluginContext { + fn qualify_name(&self, name: &str) -> String { + format!("{}{}", self.namespace_prefix, name) + } + + fn drain_registrations(&self) -> PyResult> { + let mut guard = self.registrations.lock().map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("plugin context lock poisoned: {e}")) + })?; + Ok(std::mem::take(&mut *guard)) + } +} + +#[pymethods] +impl PyLocalPluginContext { + #[pyo3(signature = (name, priority, callback), text_signature = "(name: str, priority: int, callback: object) -> None")] + fn register_llm_execution_intercept( + &self, + name: &str, + priority: i32, + callback: Py, + ) -> PyResult<()> { + let qualified_name = self.qualify_name(name); + register_llm_execution_intercept( + &qualified_name, + priority, + wrap_py_llm_exec_intercept_fn(callback), + ) + .map_err(plugin_to_py_err)?; + self.push_registration(qualified_name, RegistrationKind::Llm) + } + + #[pyo3(signature = (name, priority, callback), text_signature = "(name: str, priority: int, callback: object) -> None")] + fn register_llm_stream_execution_intercept( + &self, + name: &str, + priority: i32, + callback: Py, + ) -> PyResult<()> { + let qualified_name = self.qualify_name(name); + register_llm_stream_execution_intercept( + &qualified_name, + priority, + wrap_py_llm_stream_exec_intercept_fn(callback), + ) + .map_err(plugin_to_py_err)?; + self.push_registration(qualified_name, RegistrationKind::LlmStream) + } + + #[pyo3(signature = (name, priority, callback), text_signature = "(name: str, priority: int, callback: object) -> None")] + fn register_tool_execution_intercept( + &self, + name: &str, + priority: i32, + callback: Py, + ) -> PyResult<()> { + let qualified_name = self.qualify_name(name); + register_tool_execution_intercept( + &qualified_name, + priority, + wrap_py_tool_exec_intercept_fn(callback), + ) + .map_err(plugin_to_py_err)?; + self.push_registration(qualified_name, RegistrationKind::Tool) + } + + fn __repr__(&self) -> String { + "".to_string() + } +} + +impl PyLocalPluginContext { + fn push_registration(&self, name: String, kind: RegistrationKind) -> PyResult<()> { + let mut guard = self.registrations.lock().map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("plugin context lock poisoned: {e}")) + })?; + guard.push(PluginRegistration::new( + "plugin", + name.clone(), + Box::new(move || match kind { + RegistrationKind::Llm => deregister_llm_execution_intercept(&name) + .map(|_| ()) + .map_err(registration_failure), + RegistrationKind::LlmStream => deregister_llm_stream_execution_intercept(&name) + .map(|_| ()) + .map_err(registration_failure), + RegistrationKind::Tool => deregister_tool_execution_intercept(&name) + .map(|_| ()) + .map_err(registration_failure), + }), + )); + Ok(()) + } +} + +enum RegistrationKind { + Llm, + LlmStream, + Tool, +} + +fn registration_failure(err: FlowError) -> PluginError { + PluginError::RegistrationFailed(err.to_string()) +} + +fn plugin_to_py_err(err: FlowError) -> PyErr { + pyo3::exceptions::PyRuntimeError::new_err(err.to_string()) +} + +fn flow_to_py_err(err: FlowError) -> PyErr { + pyo3::exceptions::PyRuntimeError::new_err(err.to_string()) +} + +type PyValueFuture = Pin>> + Send>>; +type ToolExecIntercept = Arc< + dyn Fn( + &str, + Json, + ToolExecutionNextFn, + ) -> Pin> + Send>> + + Send + + Sync, +>; +type LlmExecIntercept = Arc< + dyn Fn( + &str, + LlmRequest, + LlmExecutionNextFn, + ) -> Pin> + Send>> + + Send + + Sync, +>; +type LlmStreamIntercept = Arc< + dyn Fn( + &str, + LlmRequest, + LlmStreamExecutionNextFn, + ) -> Pin< + Box< + dyn Future< + Output = FlowResult< + Pin> + Send>>, + >, + > + Send, + >, + > + Send + + Sync, +>; + +fn split_py_object_or_future( + py: Python<'_>, + result: Py, +) -> FlowResult, PyValueFuture>> { + let bound = result.bind(py); + if bound.getattr("__await__").is_ok() { + let future = pyo3_async_runtimes::tokio::into_future(result.into_bound(py)) + .map_err(|e| FlowError::Internal(e.to_string()))?; + Ok(Err(Box::pin(future) as PyValueFuture)) + } else { + Ok(Ok(result)) + } +} + +async fn resolve_py_object_or_future( + outcome: FlowResult, PyValueFuture>>, +) -> FlowResult> { + match outcome? { + Ok(value) => Ok(value), + Err(future) => future.await.map_err(|e| FlowError::Internal(e.to_string())), + } +} + +fn next_async_iter_coro(async_iter: &Arc>) -> FlowResult>> { + Python::attach(|py| { + let iter = async_iter.bind(py); + match iter.call_method0("__anext__") { + Ok(coro) => Ok(Some(coro.unbind())), + Err(error) => { + if error.is_instance_of::(py) { + Ok(None) + } else { + Err(FlowError::Internal(error.to_string())) + } + } + } + }) +} + +async fn await_async_iter_value(coro: Py) -> FlowResult> { + let future = Python::attach(|py| { + pyo3_async_runtimes::tokio::into_future(coro.into_bound(py)) + .map_err(|e| FlowError::Internal(e.to_string())) + })?; + + match future.await { + Ok(result) => Python::attach(|py| { + py_to_json(result.bind(py)) + .map(Some) + .map_err(|e| FlowError::Internal(e.to_string())) + }), + Err(error) => Python::attach(|py| { + if error.is_instance_of::(py) { + Ok(None) + } else { + Err(FlowError::Internal(error.to_string())) + } + }), + } +} + +async fn forward_async_iter( + async_iter: Arc>, + tx: tokio::sync::mpsc::Sender>, +) { + loop { + let next_value = match next_async_iter_coro(&async_iter) { + Ok(None) => break, + Ok(Some(coro)) => await_async_iter_value(coro).await, + Err(error) => Err(error), + }; + + match next_value { + Ok(Some(value)) => { + if tx.send(Ok(value)).await.is_err() { + break; + } + } + Ok(None) => break, + Err(error) => { + let _ = tx.send(Err(error)).await; + break; + } + } + } +} + +fn stream_from_async_iter( + async_iter: Py, +) -> FlowResult> + Send>>> { + let (tx, rx) = tokio::sync::mpsc::channel::>(32); + let task_locals = Python::attach(|py| { + pyo3_async_runtimes::tokio::get_current_locals(py) + .map_err(|e: pyo3::PyErr| FlowError::Internal(e.to_string())) + })?; + + let async_iter = Arc::new(async_iter); + tokio::spawn(pyo3_async_runtimes::tokio::scope(task_locals, async move { + forward_async_iter(async_iter, tx).await; + })); + + Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx))) +} + +#[pyclass] +struct PyToolNextFn { + inner: ToolExecutionNextFn, +} + +#[pymethods] +impl PyToolNextFn { + fn __call__<'py>( + &self, + py: Python<'py>, + args: Bound<'_, PyAny>, + ) -> PyResult> { + let args = py_to_json(&args)?; + let next = self.inner.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let result = next(args) + .await + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + Python::attach(|py| json_to_py(py, &result)) + }) + } +} + +#[pyclass] +struct PyLlmNextFn { + inner: LlmExecutionNextFn, +} + +#[pymethods] +impl PyLlmNextFn { + fn __call__<'py>(&self, py: Python<'py>, request: PyLLMRequest) -> PyResult> { + let next = self.inner.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let result = next(request.inner) + .await + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + Python::attach(|py| json_to_py(py, &result)) + }) + } +} + +#[pyclass(name = "LlmStream")] +struct PyLlmStream { + receiver: tokio::sync::Mutex>>, +} + +#[pymethods] +impl PyLlmStream { + fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __anext__<'py>(&self, py: Python<'py>) -> PyResult> { + let receiver_ptr = &self.receiver + as *const tokio::sync::Mutex>>; + let receiver_ref = unsafe { &*receiver_ptr }; + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let mut guard = receiver_ref.lock().await; + match guard.recv().await { + None => Err(PyErr::new::( + "stream exhausted", + )), + Some(Ok(value)) => Python::attach(|py| json_to_py(py, &value)), + Some(Err(err)) => Err(PyErr::new::( + err.to_string(), + )), + } + }) + } +} + +#[pyclass] +struct PyLlmStreamNextFn { + inner: LlmStreamExecutionNextFn, +} + +#[pymethods] +impl PyLlmStreamNextFn { + fn __call__<'py>(&self, py: Python<'py>, request: PyLLMRequest) -> PyResult> { + let next = self.inner.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let rust_stream = next(request.inner) + .await + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + let (tx, rx) = tokio::sync::mpsc::channel::>(32); + tokio::spawn(async move { + use tokio_stream::StreamExt; + let mut stream = rust_stream; + while let Some(item) = stream.next().await { + if tx.send(item).await.is_err() { + break; + } + } + }); + Ok(PyLlmStream { + receiver: tokio::sync::Mutex::new(rx), + }) + }) + } +} + +fn wrap_py_tool_exec_intercept_fn(py_fn: Py) -> ToolExecIntercept { + let py_fn = Arc::new(py_fn); + Arc::new(move |name: &str, args: Json, next: ToolExecutionNextFn| { + let py_fn = py_fn.clone(); + let name = name.to_string(); + Box::pin(async move { + let outcome: FlowResult> = Python::attach(|py| { + let py_args = + json_to_py(py, &args).map_err(|e: PyErr| FlowError::Internal(e.to_string()))?; + let py_next = PyToolNextFn { inner: next }; + let result = py_fn + .call1( + py, + ( + &name, + py_args, + py_next + .into_pyobject(py) + .map_err(|e| FlowError::Internal(e.to_string()))? + .into_any(), + ), + ) + .map_err(|e: PyErr| FlowError::Internal(e.to_string()))?; + let bound = result.bind(py); + if bound.getattr("__await__").is_ok() { + let future = pyo3_async_runtimes::tokio::into_future(result.into_bound(py)) + .map_err(|e| FlowError::Internal(e.to_string()))?; + Ok(Err(Box::pin(future) as PyValueFuture)) + } else { + let json = + py_to_json(bound).map_err(|e: PyErr| FlowError::Internal(e.to_string()))?; + Ok(Ok(json)) + } + }); + + match outcome? { + Ok(json) => Ok(json), + Err(future) => { + let py_result = future + .await + .map_err(|e| FlowError::Internal(e.to_string()))?; + Python::attach(|py| { + py_to_json(py_result.bind(py)) + .map_err(|e: PyErr| FlowError::Internal(e.to_string())) + }) + } + } + }) + }) +} + +fn wrap_py_llm_exec_intercept_fn(py_fn: Py) -> LlmExecIntercept { + let py_fn = Arc::new(py_fn); + Arc::new( + move |name: &str, request: LlmRequest, next: LlmExecutionNextFn| { + let py_fn = py_fn.clone(); + let name = name.to_string(); + Box::pin(async move { + let outcome: FlowResult> = Python::attach(|py| { + let py_req = PyLLMRequest { inner: request }; + let py_next = PyLlmNextFn { inner: next }; + let result = py_fn + .call1( + py, + ( + &name, + py_req + .into_pyobject(py) + .map_err(|e| FlowError::Internal(e.to_string()))? + .into_any(), + py_next + .into_pyobject(py) + .map_err(|e| FlowError::Internal(e.to_string()))? + .into_any(), + ), + ) + .map_err(|e: PyErr| FlowError::Internal(e.to_string()))?; + let bound = result.bind(py); + if bound.getattr("__await__").is_ok() { + let future = pyo3_async_runtimes::tokio::into_future(result.into_bound(py)) + .map_err(|e| FlowError::Internal(e.to_string()))?; + Ok(Err(Box::pin(future) as PyValueFuture)) + } else { + let json = py_to_json(bound) + .map_err(|e: PyErr| FlowError::Internal(e.to_string()))?; + Ok(Ok(json)) + } + }); + + match outcome? { + Ok(json) => Ok(json), + Err(future) => { + let py_result = future + .await + .map_err(|e| FlowError::Internal(e.to_string()))?; + Python::attach(|py| { + py_to_json(py_result.bind(py)) + .map_err(|e: PyErr| FlowError::Internal(e.to_string())) + }) + } + } + }) + }, + ) +} + +fn wrap_py_llm_stream_exec_intercept_fn(py_fn: Py) -> LlmStreamIntercept { + let py_fn = Arc::new(py_fn); + Arc::new( + move |_name: &str, request: LlmRequest, next: LlmStreamExecutionNextFn| { + let py_fn = py_fn.clone(); + Box::pin(async move { + let async_iter = resolve_py_object_or_future(Python::attach(|py| { + let py_req = PyLLMRequest { inner: request }; + let py_next = PyLlmStreamNextFn { inner: next }; + let result = py_fn + .call1( + py, + ( + py_req + .into_pyobject(py) + .map_err(|e: PyErr| FlowError::Internal(e.to_string()))? + .into_any(), + py_next + .into_pyobject(py) + .map_err(|e: PyErr| FlowError::Internal(e.to_string()))? + .into_any(), + ), + ) + .map_err(|e: PyErr| FlowError::Internal(e.to_string()))?; + split_py_object_or_future(py, result) + })) + .await?; + + stream_from_async_iter(async_iter) + }) + }, + ) +} diff --git a/crates/core/tests/unit/plugins/nemo_guardrails/component_tests.rs b/crates/core/tests/unit/plugins/nemo_guardrails/component_tests.rs index 0823bbac..80ae0c2b 100644 --- a/crates/core/tests/unit/plugins/nemo_guardrails/component_tests.rs +++ b/crates/core/tests/unit/plugins/nemo_guardrails/component_tests.rs @@ -42,7 +42,6 @@ const TEST_TIMEOUT: Duration = Duration::from_secs(5); fn reset_runtime() { let _ = clear_plugin_configuration(); - crate::plugins::nemo_guardrails::component::clear_local_backend_provider().unwrap(); crate::shared_runtime::reset_runtime_owner_for_tests(); let context = global_context(); *context.write().unwrap() = NemoRelayContextState::new(); @@ -976,26 +975,20 @@ fn unknown_fields_follow_policy() { assert!(ignored.diagnostics.is_empty()); } +#[cfg(not(feature = "python"))] #[test] -fn enabled_local_initialization_fails_fast_until_backend_exists() { - let _guard = crate::plugins::nemo_guardrails::test_mutex() - .lock() - .unwrap_or_else(|err| err.into_inner()); - reset_runtime(); - - let error = futures::executor::block_on(initialize_plugins(plugin_config(json!({ +fn local_mode_validation_reports_missing_python_feature() { + let diagnostics = validate_plugin_config(&plugin_config(json!({ "mode": "local", "codec": "openai_chat", "config_path": "./rails" - })))) - .unwrap_err(); + }))) + .unwrap(); - match error { - crate::plugin::PluginError::RegistrationFailed(message) => { - assert!(message.contains("unavailable in this runtime")); - } - other => panic!("unexpected error: {other}"), - } + assert!(diagnostics.diagnostics.iter().any(|diag| { + diag.message + .contains("local mode requires a build with the 'python' feature enabled") + })); } #[test] @@ -1024,34 +1017,5 @@ fn enabled_unknown_mode_initialization_fails_fast_when_policy_ignores_validation } } -#[test] -fn enabled_local_initialization_dispatches_through_installed_provider() { - let _guard = crate::plugins::nemo_guardrails::test_mutex() - .lock() - .unwrap_or_else(|err| err.into_inner()); - reset_runtime(); - - let provider_called = Arc::new(AtomicBool::new(false)); - let provider_called_clone = Arc::clone(&provider_called); - crate::plugins::nemo_guardrails::component::register_local_backend_provider(Arc::new( - move |config, _ctx| { - provider_called_clone.store(true, Ordering::SeqCst); - assert_eq!(config.mode, "local"); - assert_eq!(config.config_path.as_deref(), Some("./rails")); - Ok(()) - }, - )) - .unwrap(); - - futures::executor::block_on(initialize_plugins(plugin_config(json!({ - "mode": "local", - "codec": "openai_chat", - "config_path": "./rails" - })))) - .unwrap(); - - assert!(provider_called.load(Ordering::SeqCst)); -} - #[path = "remote_tests.rs"] mod remote_tests; diff --git a/crates/python/Cargo.toml b/crates/python/Cargo.toml index 6f20de35..34655038 100644 --- a/crates/python/Cargo.toml +++ b/crates/python/Cargo.toml @@ -18,7 +18,7 @@ name = "_native" crate-type = ["cdylib", "rlib"] [dependencies] -nemo-relay = { workspace = true, features = ["otel", "openinference"] } +nemo-relay = { workspace = true, features = ["otel", "openinference", "python"] } nemo-relay-adaptive = { workspace = true, features = ["redis-backend"] } pyo3 = { version = "0.28.2", features = ["abi3", "abi3-py311", "experimental-inspect", "macros"] } pyo3-async-runtimes = { version = "0.28.0", features = ["tokio-runtime"] } diff --git a/crates/python/src/lib.rs b/crates/python/src/lib.rs index 8328c265..c6878f41 100644 --- a/crates/python/src/lib.rs +++ b/crates/python/src/lib.rs @@ -20,18 +20,10 @@ //! - `py_adaptive` — Python-facing adaptive helpers (`set_latency_sensitivity`) //! - `py_plugin` — Python-facing generic plugin config/registration helpers //! - `convert` — JSON ↔ Python conversion utilities -use nemo_relay::plugin::{PluginRegistrationContext, Result as PluginResult}; -use nemo_relay::plugins::nemo_guardrails::component::{ - NeMoGuardrailsConfig, register_local_backend_provider, -}; use nemo_relay::shared_runtime::initialize_shared_runtime_binding; use nemo_relay_adaptive::plugin_component::register_adaptive_component; use pyo3::prelude::*; use pyo3::types::{PyDict, PyModule}; -use serde_json::Value as Json; -use std::ffi::CString; -use std::path::{Path, PathBuf}; -use std::sync::Arc; mod convert; #[doc(hidden)] @@ -48,11 +40,6 @@ pub mod py_types; #[cfg(test)] mod test_support; -const EMBEDDED_GUARDRAILS_LOCAL_MODULE_NAME: &str = "nemo_relay._guardrails_local"; -const EMBEDDED_GUARDRAILS_LOCAL_FILENAME: &str = "nemo_relay/_guardrails_local.py"; -const EMBEDDED_GUARDRAILS_LOCAL_SOURCE: &str = - include_str!("../embedded_python/_guardrails_local.py"); - /// The `_native` PyO3 module entry point. Registers all types and functions. #[pymodule] fn _native(m: &Bound<'_, PyModule>) -> PyResult<()> { @@ -66,13 +53,6 @@ fn _native(m: &Bound<'_, PyModule>) -> PyResult<()> { "failed to register adaptive plugin component: {e}" )) })?; - register_local_backend_provider(Arc::new(register_python_local_guardrails_backend)).map_err( - |e| { - pyo3::exceptions::PyRuntimeError::new_err(format!( - "failed to register NeMo Guardrails local backend provider: {e}" - )) - }, - )?; py_types::register(m)?; py_api::register(m)?; py_plugin::register(m)?; @@ -92,97 +72,6 @@ fn install_native_module_alias(m: &Bound<'_, PyModule>) -> PyResult<()> { Ok(()) } -fn register_python_local_guardrails_backend( - config: NeMoGuardrailsConfig, - ctx: &mut PluginRegistrationContext, -) -> PluginResult<()> { - let plugin_config = match serde_json::to_value(config) { - Ok(Json::Object(config)) => config, - Ok(_) => { - return Err(nemo_relay::plugin::PluginError::Internal( - "NeMo Guardrails local config did not serialize to a JSON object".to_string(), - )); - } - Err(err) => { - return Err(nemo_relay::plugin::PluginError::Internal(format!( - "failed to serialize NeMo Guardrails local config: {err}" - ))); - } - }; - - let registrations = Python::attach(|py| { - let register_fn = load_guardrails_local_register_fn(py)?; - let namespace_prefix = ctx.qualify_name(""); - crate::py_plugin::invoke_python_plugin_register( - py, - "nemo_guardrails", - ®ister_fn, - &plugin_config, - namespace_prefix, - ) - }) - .map_err(|err| nemo_relay::plugin::PluginError::RegistrationFailed(err.to_string()))?; - - ctx.extend_registrations(registrations); - Ok(()) -} - -fn load_guardrails_local_register_fn(py: Python<'_>) -> PyResult> { - let module = load_embedded_guardrails_local_module(py)?; - module.getattr("register_local_backend") -} - -fn load_embedded_guardrails_local_module(py: Python<'_>) -> PyResult> { - ensure_nemo_relay_package_importable(py)?; - - let sys = py.import("sys")?; - let modules = sys.getattr("modules")?.cast_into::()?; - if let Some(existing) = modules.get_item(EMBEDDED_GUARDRAILS_LOCAL_MODULE_NAME)? { - return Ok(existing.cast_into::()?); - } - - let source = CString::new(EMBEDDED_GUARDRAILS_LOCAL_SOURCE).unwrap(); - let filename = CString::new(EMBEDDED_GUARDRAILS_LOCAL_FILENAME).unwrap(); - let module_name = CString::new(EMBEDDED_GUARDRAILS_LOCAL_MODULE_NAME).unwrap(); - let module = PyModule::from_code(py, &source, &filename, &module_name)?; - modules.set_item(EMBEDDED_GUARDRAILS_LOCAL_MODULE_NAME, &module)?; - if let Ok(package) = py.import("nemo_relay") { - let _ = package.setattr("_guardrails_local", &module); - } - Ok(module) -} - -fn ensure_nemo_relay_package_importable(py: Python<'_>) -> PyResult<()> { - if py.import("nemo_relay").is_ok() { - return Ok(()); - } - - let source_python_dir = embedded_guardrails_source_python_dir(); - if !source_python_dir.exists() { - return Ok(()); - } - - prepend_python_path_if_missing(py, &source_python_dir)?; - let _ = py.import("nemo_relay")?; - Ok(()) -} - -fn embedded_guardrails_source_python_dir() -> PathBuf { - PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../python") -} - -fn prepend_python_path_if_missing(py: Python<'_>, path: &Path) -> PyResult<()> { - let sys = py.import("sys")?; - let sys_path = sys.getattr("path")?; - let path_str = path.to_string_lossy(); - - if !sys_path.contains(path_str.as_ref())? { - sys_path.call_method1("insert", (0, path_str.as_ref()))?; - } - - Ok(()) -} - #[cfg(test)] #[path = "../tests/coverage/coverage_tests.rs"] mod coverage_tests; diff --git a/crates/python/tests/coverage/coverage_tests.rs b/crates/python/tests/coverage/coverage_tests.rs index 75a38ea1..98d00c92 100644 --- a/crates/python/tests/coverage/coverage_tests.rs +++ b/crates/python/tests/coverage/coverage_tests.rs @@ -46,7 +46,8 @@ fn python_package_dir() -> PathBuf { } fn embedded_guardrails_local_source_path() -> PathBuf { - PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("embedded_python/_guardrails_local.py") + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("../core/src/plugins/nemo_guardrails/embedded_python/_guardrails_local.py") } fn fake_guardrails_module_prelude(module_name: &str, python_dir: &str) -> String { @@ -60,6 +61,7 @@ sys.path.insert(0, {python_dir:?}) MODULE_NAME = {module_name:?} fake_root = types.ModuleType(MODULE_NAME) +fake_root.__version__ = "0.22.0" fake_options = types.ModuleType(MODULE_NAME + ".rails.llm.options") class Result: @@ -126,14 +128,59 @@ import pathlib import sys import types -import nemo_relay +from nemo_relay._native import LLMRequest + +class _AnnotatedRequest: + def __init__(self, messages): + self.messages = [dict(message) for message in messages] + +class _AnnotatedResponse: + def __init__(self, response): + self._response = response + + def response_text(self): + try: + return self._response["choices"][0]["message"]["content"] + except Exception: + return None + +class _BaseCodec: + def decode(self, request): + return _AnnotatedRequest(request.content.get("messages", [])) + + def encode(self, annotated, original): + content = dict(original.content) + content["messages"] = annotated.messages + return LLMRequest(original.headers, content) + + def decode_response(self, response): + return _AnnotatedResponse(response) + +class OpenAIChatCodec(_BaseCodec): + pass + +class OpenAIResponsesCodec(_BaseCodec): + pass + +class AnthropicMessagesCodec(_BaseCodec): + pass + +class PluginContext: + pass + +runtime_module = types.ModuleType("_nemo_guardrails_local_runtime") +runtime_module.LLMRequest = LLMRequest +runtime_module.OpenAIChatCodec = OpenAIChatCodec +runtime_module.OpenAIResponsesCodec = OpenAIResponsesCodec +runtime_module.AnthropicMessagesCodec = AnthropicMessagesCodec +runtime_module.PluginContext = PluginContext +sys.modules["_nemo_guardrails_local_runtime"] = runtime_module GUARDRAILS_LOCAL_SOURCE_PATH = pathlib.Path({source_path:?}) -guardrails_local_module = types.ModuleType("nemo_relay._guardrails_local") +guardrails_local_module = types.ModuleType("_nemo_guardrails_local") guardrails_local_module.__file__ = str(GUARDRAILS_LOCAL_SOURCE_PATH) -guardrails_local_module.__package__ = "nemo_relay" -sys.modules["nemo_relay._guardrails_local"] = guardrails_local_module -setattr(nemo_relay, "_guardrails_local", guardrails_local_module) +guardrails_local_module.__package__ = "" +sys.modules["_nemo_guardrails_local"] = guardrails_local_module exec( compile( GUARDRAILS_LOCAL_SOURCE_PATH.read_text(), @@ -234,7 +281,6 @@ fn with_event_loop(py: Python<'_>, f: impl FnOnce(Bound<'_, PyAny>) -> T) -> fn reset_runtime_state() { let _ = clear_plugin_configuration(); - nemo_relay::plugins::nemo_guardrails::component::clear_local_backend_provider().unwrap(); let context = global_context(); *context.write().unwrap() = NemoRelayContextState::new(); } @@ -269,7 +315,7 @@ fn test_native_pymodule_entrypoint_registers_bindings() { } #[test] -fn test_native_pymodule_entrypoint_installs_nemo_guardrails_local_provider() { +fn test_native_pymodule_entrypoint_registers_bindings_without_local_provider_install() { let _python = crate::test_support::init_python_test(); Python::attach(|py| { let module = PyModule::new(py, "_native_guardrails_provider").unwrap(); @@ -350,7 +396,7 @@ class LLMRails: {embedded_loader} from nemo_relay._native import LLMRequest -from nemo_relay._guardrails_local import register_local_backend +from _nemo_guardrails_local import register_local_backend {context_class} @@ -515,7 +561,7 @@ class LLMRails: {embedded_loader} -from nemo_relay._guardrails_local import register_local_backend +from _nemo_guardrails_local import register_local_backend {context_class} @@ -615,7 +661,7 @@ class LLMRails: {embedded_loader} from nemo_relay._native import LLMRequest -from nemo_relay._guardrails_local import register_local_backend +from _nemo_guardrails_local import register_local_backend {context_class} From e2492389d33f0e581f73b2f233557b345b7c0d8e Mon Sep 17 00:00:00 2001 From: Will Killian Date: Thu, 4 Jun 2026 16:49:25 -0400 Subject: [PATCH 10/20] refactor: own local NeMo Guardrails runtime in Rust Signed-off-by: Will Killian --- .../embedded_python/_guardrails_local.py | 607 ------- .../src/plugins/nemo_guardrails/python.rs | 1543 +++++++++-------- .../nemo_guardrails/local_python_tests.rs | 217 +++ crates/python/src/lib.rs | 4 + .../python/tests/coverage/coverage_tests.rs | 852 +-------- .../nemo_guardrails_coverage_tests.rs | 819 +++++++++ 6 files changed, 1904 insertions(+), 2138 deletions(-) delete mode 100644 crates/core/src/plugins/nemo_guardrails/embedded_python/_guardrails_local.py create mode 100644 crates/core/tests/unit/plugins/nemo_guardrails/local_python_tests.rs create mode 100644 crates/python/tests/coverage/nemo_guardrails_coverage_tests.rs diff --git a/crates/core/src/plugins/nemo_guardrails/embedded_python/_guardrails_local.py b/crates/core/src/plugins/nemo_guardrails/embedded_python/_guardrails_local.py deleted file mode 100644 index fd863640..00000000 --- a/crates/core/src/plugins/nemo_guardrails/embedded_python/_guardrails_local.py +++ /dev/null @@ -1,607 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -"""Internal helpers for the built-in NeMo Guardrails local backend.""" - -from __future__ import annotations - -import asyncio -import importlib -import json -from collections.abc import Callable -from typing import Any, NamedTuple, Protocol, cast - -from _nemo_guardrails_local_runtime import ( - AnthropicMessagesCodec, - LLMRequest, - OpenAIChatCodec, - OpenAIResponsesCodec, - PluginContext, -) - -Json = Any - -_DEFAULT_PRIORITY = 100 -_SUPPORTED_NEMOGUARDRAILS_VERSION = "0.22.0" - - -class NeMoGuardrailsDependencyError(RuntimeError): - """Raised when the optional ``nemoguardrails`` dependency is unavailable.""" - - -class NeMoGuardrailsViolation(RuntimeError): - """Raised when NeMo Guardrails blocks or cannot safely apply a rail result.""" - - def __init__( - self, - message: str, - *, - rail_type: str, - rail: str | None = None, - content: str | None = None, - ) -> None: - super().__init__(message) - self.rail_type = rail_type - self.rail = rail - self.content = content - - -class _GuardrailsCodec(Protocol): - """Codec shape required by the local backend.""" - - -class _GuardrailsRuntimeImports(NamedTuple): - """Resolved Python symbols required by the local Guardrails backend.""" - - rails_config_cls: Any - llm_rails_cls: Any - rail_type: Any - rail_status: Any - - -_CODECS: dict[str, Callable[[], _GuardrailsCodec]] = { - "openai_chat": OpenAIChatCodec, - "openai_responses": OpenAIResponsesCodec, - "anthropic_messages": AnthropicMessagesCodec, -} - - -def _load_nemoguardrails(module_name: str | None) -> _GuardrailsRuntimeImports: - root_module = module_name or "nemoguardrails" - try: - guardrails = cast(Any, importlib.import_module(root_module)) - options = cast(Any, importlib.import_module(f"{root_module}.rails.llm.options")) - except ImportError as error: - if error.name == root_module: - raise NeMoGuardrailsDependencyError( - "NeMo Guardrails is required for the built-in NeMo Guardrails local backend. " - "Install it with: pip install nemoguardrails==0.22.0" - ) from error - raise NeMoGuardrailsDependencyError( - "NeMo Guardrails local backend could not import a required dependency: " - f"{error.name or error}. Install the full NeMo Guardrails runtime dependencies." - ) from error - - version = getattr(guardrails, "__version__", None) - if version != _SUPPORTED_NEMOGUARDRAILS_VERSION: - raise NeMoGuardrailsDependencyError( - "NeMo Guardrails local backend requires nemoguardrails==" - f"{_SUPPORTED_NEMOGUARDRAILS_VERSION}, but found {version!r}. " - "Install it with: pip install nemoguardrails==0.22.0" - ) - - return _GuardrailsRuntimeImports( - rails_config_cls=guardrails.RailsConfig, - llm_rails_cls=guardrails.LLMRails, - rail_type=options.RailType, - rail_status=options.RailStatus, - ) - - -def _status_value(status: Any) -> str: - return str(getattr(status, "value", status)).lower() - - -def _messages_from_annotated(annotated: Any) -> list[dict[str, Any]]: - return [dict(message) for message in annotated.messages] - - -async def _apply_input_rails( - rails: Any, - rail_type: Any, - rail_status: Any, - codec: _GuardrailsCodec, - request: LLMRequest, -) -> tuple[LLMRequest, list[dict[str, Any]]]: - annotated_request = codec.decode(request) - messages = _messages_from_annotated(annotated_request) - input_result = await rails.check_async(messages, rail_types=[rail_type.INPUT]) - input_status = _status_value(input_result.status) - if input_status == _status_value(rail_status.BLOCKED): - _raise_blocked(input_result, "input") - if input_status == _status_value(rail_status.MODIFIED): - input_content = getattr(input_result, "content", "") - annotated_request.messages = _replace_last_role_content( - messages, - "user", - "" if input_content is None else str(input_content), - ) - request = codec.encode(annotated_request, request) - messages = _messages_from_annotated(annotated_request) - return request, messages - - -def _replace_last_role_content(messages: list[dict[str, Any]], role: str, content: str) -> list[dict[str, Any]]: - updated = [dict(message) for message in messages] - for index in range(len(updated) - 1, -1, -1): - if updated[index].get("role") == role: - updated[index]["content"] = content - return updated - raise NeMoGuardrailsViolation( - f"NeMo Guardrails returned modified {role} content but no {role} message was present.", - rail_type="input" if role == "user" else "output", - content=content, - ) - - -def _tool_input_content(name: str, args: Json) -> str: - return json.dumps( - { - "tool_name": name, - "arguments": args, - }, - sort_keys=True, - separators=(",", ":"), - ) - - -def _tool_output_content(name: str, args: Json, result: Json) -> str: - return json.dumps( - { - "tool_name": name, - "arguments": args, - "result": result, - }, - sort_keys=True, - separators=(",", ":"), - ) - - -def _modified_tool_payload(content: str, field: str) -> Json: - try: - value = json.loads(content) - except json.JSONDecodeError as error: - raise NeMoGuardrailsViolation( - f"NeMo Guardrails returned modified tool {field} content that is not valid JSON.", - rail_type=f"tool_{field}", - content=content, - ) from error - - if not isinstance(value, dict) or field not in value: - raise NeMoGuardrailsViolation( - f"NeMo Guardrails returned modified tool {field} content without a '{field}' field.", - rail_type=f"tool_{field}", - content=content, - ) - return cast(Json, value[field]) - - -def _raise_modified_output_not_supported(result: Any) -> None: - output_content = getattr(result, "content", "") - output_rail = getattr(result, "rail", None) - raise NeMoGuardrailsViolation( - "NeMo Guardrails output rail returned modified content, but the local backend " - "does not rewrite provider responses yet.", - rail_type="output", - rail=None if output_rail is None else str(output_rail), - content="" if output_content is None else str(output_content), - ) - - -async def _check_output_rails( - rails: Any, - rail_type: Any, - rail_status: Any, - messages: list[dict[str, Any]], - response_text: str | None, -) -> None: - if response_text is None: - return - - output_messages = [*messages, {"role": "assistant", "content": response_text}] - output_result = await rails.check_async(output_messages, rail_types=[rail_type.OUTPUT]) - output_status = _status_value(output_result.status) - if output_status == _status_value(rail_status.BLOCKED): - _raise_blocked(output_result, "output") - if output_status == _status_value(rail_status.MODIFIED): - _raise_modified_output_not_supported(output_result) - - -def _has_streaming_output_rails(rails: Any) -> bool: - return bool(getattr(rails.config.rails.output, "flows", [])) - - -def _output_streaming_config(rails: Any) -> Any | None: - return getattr(rails.config.rails.output, "streaming", None) - - -def _guardrails_streaming_enabled(rails: Any) -> bool: - streaming = _output_streaming_config(rails) - return bool(streaming is not None and getattr(streaming, "enabled", False)) - - -def _extract_stream_text(codec_name: str, chunk: Json) -> str | None: - if not isinstance(chunk, dict): - return None - - if codec_name == "openai_chat": - choices = chunk.get("choices") - if not isinstance(choices, list): - return None - parts: list[str] = [] - for choice in choices: - if not isinstance(choice, dict): - continue - delta = choice.get("delta") - if not isinstance(delta, dict): - continue - content = delta.get("content") - if isinstance(content, str) and content: - parts.append(content) - return "".join(parts) if parts else None - - if codec_name == "openai_responses": - if chunk.get("type") == "response.output_text.delta": - delta = chunk.get("delta") - return delta if isinstance(delta, str) and delta else None - return None - - if codec_name == "anthropic_messages": - if chunk.get("type") != "content_block_delta": - return None - delta = chunk.get("delta") - if not isinstance(delta, dict): - return None - if delta.get("type") != "text_delta": - return None - text = delta.get("text") - return text if isinstance(text, str) and text else None - - return None - - -def _guardrails_stream_error_message(chunk: str) -> str | None: - try: - payload = json.loads(chunk) - except json.JSONDecodeError: - return None - if not isinstance(payload, dict): - return None - error = payload.get("error") - if not isinstance(error, dict): - return None - if error.get("type") != "guardrails_violation": - return None - message = error.get("message") - return message if isinstance(message, str) and message else "Blocked by output rails." - - -async def _queue_string_stream(queue: "asyncio.Queue[str | None]"): - while True: - item = await queue.get() - if item is None: - return - yield item - - -async def _monitor_streaming_output_rails( - *, - rails: Any, - messages: list[dict[str, Any]], - text_queue: "asyncio.Queue[str | None]", - blocked: dict[str, str | None], -) -> None: - guarded_stream = rails.stream_async( - messages=messages, - generator=_queue_string_stream(text_queue), - include_metadata=False, - ) - async for chunk in guarded_stream: - if isinstance(chunk, str): - message = _guardrails_stream_error_message(chunk) - if message is not None: - blocked["message"] = message - return - - -def _raise_streaming_output_blocked(blocked_message: str) -> None: - raise NeMoGuardrailsViolation( - f"NeMo Guardrails output rail blocked the LLM call: {blocked_message}", - rail_type="output", - content=blocked_message, - ) - - -def _build_guardrails_config(config: dict[str, Any], rails_config_cls: Any) -> Any: - if config.get("config_path") is not None: - return rails_config_cls.from_path(cast(str, config["config_path"])) - return rails_config_cls.from_content( - colang_content=cast(str | None, config.get("colang_content")), - yaml_content=cast(str, config["config_yaml"]), - ) - - -def _resolve_codec(config: dict[str, Any]) -> tuple[str, _GuardrailsCodec]: - codec_name = cast(str | None, config.get("codec")) - if codec_name is None or codec_name not in _CODECS: - raise RuntimeError("local NeMo Guardrails backend requires a supported codec") - return codec_name, _CODECS[codec_name]() - - -async def _check_tool_input( - rails: Any, - rail_type: Any, - rail_status: Any, - tool_name: str, - args: Json, -) -> Json: - input_result = await rails.check_async( - [{"role": "user", "content": _tool_input_content(tool_name, args)}], - rail_types=[rail_type.INPUT], - ) - input_status = _status_value(input_result.status) - if input_status == _status_value(rail_status.BLOCKED): - _raise_blocked(input_result, "tool_input") - if input_status == _status_value(rail_status.MODIFIED): - input_content = getattr(input_result, "content", "") - return _modified_tool_payload( - "" if input_content is None else str(input_content), - "arguments", - ) - return args - - -async def _check_tool_output( - rails: Any, - rail_type: Any, - rail_status: Any, - tool_name: str, - args: Json, - result: Json, -) -> Json: - output_result = await rails.check_async( - [ - {"role": "user", "content": _tool_input_content(tool_name, args)}, - { - "role": "assistant", - "content": _tool_output_content(tool_name, args, result), - }, - ], - rail_types=[rail_type.OUTPUT], - ) - output_status = _status_value(output_result.status) - if output_status == _status_value(rail_status.BLOCKED): - _raise_blocked(output_result, "tool_output") - if output_status == _status_value(rail_status.MODIFIED): - output_content = getattr(output_result, "content", "") - return _modified_tool_payload( - "" if output_content is None else str(output_content), - "result", - ) - return result - - -def _make_llm_intercept( - *, - rails: Any, - rail_type: Any, - rail_status: Any, - codec: _GuardrailsCodec, - enable_input: bool, - enable_output: bool, -): - async def intercept(_name: str, request: LLMRequest, next_call): - current_request = request - messages = _messages_from_annotated(codec.decode(current_request)) - - if enable_input: - current_request, messages = await _apply_input_rails( - rails, - rail_type, - rail_status, - codec, - current_request, - ) - - response = await next_call(current_request) - if not enable_output: - return response - - annotated_response = codec.decode_response(response) - await _check_output_rails( - rails, - rail_type, - rail_status, - messages, - annotated_response.response_text(), - ) - return response - - return intercept - - -def _make_llm_stream_intercept( - *, - rails: Any, - rail_type: Any, - rail_status: Any, - codec_name: str, - codec: _GuardrailsCodec, - enable_input: bool, - enable_output: bool, -): - async def stream_intercept(request: LLMRequest, next_call): - current_request = request - messages = _messages_from_annotated(codec.decode(current_request)) - if enable_input: - current_request, messages = await _apply_input_rails( - rails, - rail_type, - rail_status, - codec, - current_request, - ) - - stream = await next_call(current_request) - if not enable_output: - return stream - if not _has_streaming_output_rails(rails): - return stream - if not _guardrails_streaming_enabled(rails): - raise RuntimeError( - "local NeMo Guardrails streaming output rails require " - "rails.output.streaming.enabled = true in the Guardrails config." - ) - - streaming_config = _output_streaming_config(rails) - if streaming_config is None or not getattr(streaming_config, "stream_first", True): - raise RuntimeError( - "local NeMo Guardrails streaming output rails currently require " - "rails.output.streaming.stream_first = true." - ) - - text_queue: asyncio.Queue[str | None] = asyncio.Queue() - block_state: dict[str, str | None] = {"message": None} - - async def guarded_provider_stream(): - monitor = asyncio.create_task( - _monitor_streaming_output_rails( - rails=rails, - messages=messages, - text_queue=text_queue, - blocked=block_state, - ) - ) - try: - async for chunk in stream: - if block_state["message"] is not None: - _raise_streaming_output_blocked(block_state["message"]) - - text = _extract_stream_text(codec_name, chunk) - if text is not None: - await text_queue.put(text) - - yield chunk - - if block_state["message"] is not None: - _raise_streaming_output_blocked(block_state["message"]) - finally: - await text_queue.put(None) - await monitor - if block_state["message"] is not None: - _raise_streaming_output_blocked(block_state["message"]) - - return guarded_provider_stream() - - return stream_intercept - - -def _make_tool_intercept( - *, - rails: Any, - rail_type: Any, - rail_status: Any, - enable_tool_input: bool, - enable_tool_output: bool, -): - async def tool_intercept(tool_name: str, args: Json, next_call): - current_args = args - - if enable_tool_input: - current_args = await _check_tool_input( - rails, - rail_type, - rail_status, - tool_name, - current_args, - ) - - tool_result = await next_call(current_args) - if not enable_tool_output: - return tool_result - - return await _check_tool_output( - rails, - rail_type, - rail_status, - tool_name, - current_args, - tool_result, - ) - - return tool_intercept - - -def _raise_blocked(result: Any, rail_type: str) -> None: - rail_value = getattr(result, "rail", None) - rail = None if rail_value is None else str(rail_value) - content = getattr(result, "content", "") - detail = f" by rail '{rail}'" if rail else "" - subject = "LLM call" if rail_type in {"input", "output"} else "tool call" - raise NeMoGuardrailsViolation( - f"NeMo Guardrails {rail_type} rail blocked the {subject}{detail}.", - rail_type=rail_type, - rail=rail, - content="" if content is None else str(content), - ) - - -def register_local_backend(config: dict[str, Any], context: PluginContext) -> None: - """Install the built-in NeMo Guardrails local backend.""" - - local = cast(dict[str, Any], config.get("local") or {}) - module_name = cast(str | None, local.get("python_module")) - runtime_imports = _load_nemoguardrails(module_name) - guardrails_config = _build_guardrails_config(config, runtime_imports.rails_config_cls) - rails = runtime_imports.llm_rails_cls(guardrails_config) - enable_input = bool(config.get("input", True)) - enable_output = bool(config.get("output", True)) - enable_tool_input = bool(config.get("tool_input", False)) - enable_tool_output = bool(config.get("tool_output", False)) - priority = int(config.get("priority", _DEFAULT_PRIORITY)) - - if enable_input or enable_output: - codec_name, codec = _resolve_codec(config) - intercept = _make_llm_intercept( - rails=rails, - rail_type=runtime_imports.rail_type, - rail_status=runtime_imports.rail_status, - codec=codec, - enable_input=enable_input, - enable_output=enable_output, - ) - stream_intercept = _make_llm_stream_intercept( - rails=rails, - rail_type=runtime_imports.rail_type, - rail_status=runtime_imports.rail_status, - codec_name=codec_name, - codec=codec, - enable_input=enable_input, - enable_output=enable_output, - ) - context.register_llm_execution_intercept("nemo_guardrails_local", priority, intercept) - context.register_llm_stream_execution_intercept( - "nemo_guardrails_local_stream", - priority, - stream_intercept, - ) - - if enable_tool_input or enable_tool_output: - tool_intercept = _make_tool_intercept( - rails=rails, - rail_type=runtime_imports.rail_type, - rail_status=runtime_imports.rail_status, - enable_tool_input=enable_tool_input, - enable_tool_output=enable_tool_output, - ) - context.register_tool_execution_intercept("nemo_guardrails_local", priority, tool_intercept) diff --git a/crates/core/src/plugins/nemo_guardrails/python.rs b/crates/core/src/plugins/nemo_guardrails/python.rs index c782fa38..2f28d706 100644 --- a/crates/core/src/plugins/nemo_guardrails/python.rs +++ b/crates/core/src/plugins/nemo_guardrails/python.rs @@ -1,521 +1,948 @@ // SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -use std::ffi::CString; -use std::future::Future; -use std::pin::Pin; use std::sync::{Arc, Mutex}; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyModule}; +use pyo3::types::{PyDict, PyList}; +use serde_json::json; +use tokio::sync::mpsc; +use tokio::task::JoinHandle; +use tokio_stream::StreamExt; +use tokio_stream::wrappers::ReceiverStream; use crate::api::llm::LlmRequest; -use crate::api::registry::{ - deregister_llm_execution_intercept, deregister_llm_stream_execution_intercept, - deregister_tool_execution_intercept, register_llm_execution_intercept, - register_llm_stream_execution_intercept, register_tool_execution_intercept, -}; -use crate::api::runtime::{LlmExecutionNextFn, LlmStreamExecutionNextFn, ToolExecutionNextFn}; +use crate::api::runtime::{LlmExecutionFn, LlmJsonStream, LlmStreamExecutionFn, ToolExecutionFn}; use crate::codec::anthropic::AnthropicMessagesCodec; use crate::codec::openai_chat::OpenAIChatCodec; use crate::codec::openai_responses::OpenAIResponsesCodec; -use crate::codec::request::{AnnotatedLlmRequest, Message}; -use crate::codec::response::AnnotatedLlmResponse; +use crate::codec::request::{AnnotatedLlmRequest, Message, MessageContent}; use crate::codec::traits::{LlmCodec, LlmResponseCodec}; use crate::error::{FlowError, Result as FlowResult}; use crate::json::Json; -use crate::plugin::{ - PluginError, PluginRegistration, PluginRegistrationContext, Result as PluginResult, - rollback_registrations, -}; +use crate::plugin::{PluginError, PluginRegistrationContext, Result as PluginResult}; use super::NeMoGuardrailsConfig; -const SUPPORT_MODULE_NAME: &str = "_nemo_guardrails_local_runtime"; -const HELPER_MODULE_NAME: &str = "_nemo_guardrails_local"; -const HELPER_FILENAME: &str = "_nemo_guardrails_local.py"; -const HELPER_SOURCE: &str = include_str!("embedded_python/_guardrails_local.py"); +const DEFAULT_MODULE_NAME: &str = "nemoguardrails"; +const SUPPORTED_NEMOGUARDRAILS_VERSION: &str = "0.22.0"; pub(super) fn register_local_backend( config: NeMoGuardrailsConfig, ctx: &mut PluginRegistrationContext, ) -> PluginResult<()> { - Python::initialize(); - - let plugin_config = match serde_json::to_value(config) { - Ok(Json::Object(config)) => config, - Ok(_) => { - return Err(PluginError::Internal( - "NeMo Guardrails local config did not serialize to a JSON object".to_string(), - )); - } - Err(err) => { - return Err(PluginError::Internal(format!( - "failed to serialize NeMo Guardrails local config: {err}" - ))); - } - }; + let runtime = Arc::new(LocalGuardrailsRuntime::new(&config)?); + + if config.input || config.output { + let llm_runtime = Arc::clone(&runtime); + let enable_input = config.input; + let enable_output = config.output; + let llm_execution: LlmExecutionFn = Arc::new(move |_name, request, next| { + let runtime = Arc::clone(&llm_runtime); + Box::pin(async move { + runtime + .execute_llm(request, next, enable_input, enable_output) + .await + }) + }); + ctx.register_llm_execution_intercept( + "nemo_guardrails_local", + config.priority, + llm_execution, + )?; + + let stream_runtime = Arc::clone(&runtime); + let enable_input = config.input; + let enable_output = config.output; + let llm_stream_execution: LlmStreamExecutionFn = Arc::new(move |_name, request, next| { + let runtime = Arc::clone(&stream_runtime); + Box::pin(async move { + runtime + .execute_llm_stream(request, next, enable_input, enable_output) + .await + }) + }); + ctx.register_llm_stream_execution_intercept( + "nemo_guardrails_local_stream", + config.priority, + llm_stream_execution, + )?; + } + + if config.tool_input || config.tool_output { + let tool_runtime = Arc::clone(&runtime); + let enable_tool_input = config.tool_input; + let enable_tool_output = config.tool_output; + let tool_execution: ToolExecutionFn = Arc::new(move |tool_name, args, next| { + let runtime = Arc::clone(&tool_runtime); + let tool_name = tool_name.to_string(); + Box::pin(async move { + let current_args = if enable_tool_input { + runtime.check_tool_input(&tool_name, &args).await? + } else { + args + }; - let registrations = Python::attach(|py| { - let register_fn = load_guardrails_local_register_fn(py)?; - invoke_embedded_plugin_register(py, ®ister_fn, &plugin_config, ctx.qualify_name("")) - }) - .map_err(|err| PluginError::RegistrationFailed(err.to_string()))?; + let tool_result = next(current_args.clone()).await?; + if !enable_tool_output { + return Ok(tool_result); + } + + runtime + .check_tool_output(&tool_name, ¤t_args, &tool_result) + .await + }) + }); + ctx.register_tool_execution_intercept( + "nemo_guardrails_local", + config.priority, + tool_execution, + )?; + } - ctx.extend_registrations(registrations); Ok(()) } -fn invoke_embedded_plugin_register( - py: Python<'_>, - register_fn: &Bound<'_, PyAny>, - plugin_config: &serde_json::Map, - namespace_prefix: String, -) -> PyResult> { - let context = Py::new( - py, - PyLocalPluginContext { - registrations: Arc::new(Mutex::new(vec![])), - namespace_prefix, - }, - )?; - let plugin_config_py = json_to_py(py, &Json::Object(plugin_config.clone()))?; - - match register_fn.call1((plugin_config_py, context.clone_ref(py))) { - Ok(_) => context.bind(py).borrow().drain_registrations(), - Err(err) => { - if let Ok(mut registrations) = context.bind(py).borrow().drain_registrations() { - rollback_registrations(&mut registrations); +struct LocalGuardrailsRuntime { + bridge: LocalGuardrailsBridge, + codec: Option, +} + +impl LocalGuardrailsRuntime { + fn new(config: &NeMoGuardrailsConfig) -> PluginResult { + Python::initialize(); + Ok(Self { + bridge: LocalGuardrailsBridge::new(config)?, + codec: resolve_codec(config)?, + }) + } + + async fn execute_llm( + &self, + request: LlmRequest, + next: crate::api::runtime::LlmExecutionNextFn, + enable_input: bool, + enable_output: bool, + ) -> FlowResult { + let (request, messages) = self.prepare_llm_request(request, enable_input).await?; + let response = next(request).await?; + + if enable_output { + let annotated_response = self.codec()?.decode_response(&response)?; + if let Some(response_text) = annotated_response.response_text() { + self.check_output_rails(&messages, response_text).await?; } - Err(err) } + + Ok(response) } -} -fn load_guardrails_local_register_fn(py: Python<'_>) -> PyResult> { - install_support_module(py)?; - let module = load_guardrails_local_module(py)?; - module.getattr("register_local_backend") -} + async fn execute_llm_stream( + &self, + request: LlmRequest, + next: crate::api::runtime::LlmStreamExecutionNextFn, + enable_input: bool, + enable_output: bool, + ) -> FlowResult { + let (request, messages) = self.prepare_llm_request(request, enable_input).await?; + let provider_stream = next(request).await?; + + if !enable_output || !self.bridge.has_streaming_output_rails()? { + return Ok(provider_stream); + } -fn install_support_module(py: Python<'_>) -> PyResult> { - let sys = py.import("sys")?; - let modules = sys.getattr("modules")?.cast_into::()?; - if let Some(existing) = modules.get_item(SUPPORT_MODULE_NAME)? { - return Ok(existing.cast_into::()?); + self.bridge.ensure_streaming_output_supported()?; + self.guard_provider_stream(messages, provider_stream) } - let module = PyModule::new(py, SUPPORT_MODULE_NAME)?; - module.add_class::()?; - module.add_class::()?; - module.add_class::()?; - module.add_class::()?; - module.add_class::()?; - module.add_class::()?; - module.add_class::()?; - modules.set_item(SUPPORT_MODULE_NAME, &module)?; - Ok(module) -} + async fn prepare_llm_request( + &self, + request: LlmRequest, + enable_input: bool, + ) -> FlowResult<(LlmRequest, Vec)> { + let codec = self.codec()?; + let mut current_request = request; + let mut annotated = codec.decode(¤t_request)?; + let mut messages = messages_from_annotated(&annotated)?; + + if enable_input { + match self + .bridge + .check(messages.clone(), LocalRailKind::Input) + .await? + { + LocalCheckOutcome::Passed => {} + LocalCheckOutcome::Blocked { rail, .. } => { + return Err(blocked_error("input", rail.as_deref())); + } + LocalCheckOutcome::Modified { content, .. } => { + replace_last_role_content(&mut annotated, "user", content)?; + current_request = codec.encode(&annotated, ¤t_request)?; + messages = messages_from_annotated(&annotated)?; + } + } + } -fn load_guardrails_local_module(py: Python<'_>) -> PyResult> { - let sys = py.import("sys")?; - let modules = sys.getattr("modules")?.cast_into::()?; - if let Some(existing) = modules.get_item(HELPER_MODULE_NAME)? { - return Ok(existing.cast_into::()?); + Ok((current_request, messages)) } - let source = CString::new(HELPER_SOURCE).unwrap(); - let filename = CString::new(HELPER_FILENAME).unwrap(); - let module_name = CString::new(HELPER_MODULE_NAME).unwrap(); - let module = PyModule::from_code(py, &source, &filename, &module_name)?; - modules.set_item(HELPER_MODULE_NAME, &module)?; - Ok(module) -} + async fn check_output_rails(&self, messages: &[Json], response_text: &str) -> FlowResult<()> { + let mut output_messages = messages.to_vec(); + output_messages.push(json!({ + "role": "assistant", + "content": response_text, + })); -fn py_to_json(obj: &Bound<'_, PyAny>) -> PyResult { - pythonize::depythonize(obj).map_err(|e| { - PyErr::new::(format!("Failed to convert to JSON: {e}")) - }) -} + match self + .bridge + .check(output_messages, LocalRailKind::Output) + .await? + { + LocalCheckOutcome::Passed => Ok(()), + LocalCheckOutcome::Blocked { rail, .. } => { + Err(blocked_error("output", rail.as_deref())) + } + LocalCheckOutcome::Modified { .. } => Err(local_violation( + "NeMo Guardrails output rail returned modified content, but the local backend \ + does not rewrite provider responses yet.", + )), + } + } -fn json_to_py(py: Python<'_>, value: &Json) -> PyResult> { - let obj: Bound<'_, PyAny> = pythonize::pythonize(py, value).map_err(|e| { - PyErr::new::(format!("Failed to convert from JSON: {e}")) - })?; - Ok(obj.unbind()) -} + async fn check_tool_input(&self, tool_name: &str, args: &Json) -> FlowResult { + let messages = vec![json!({ + "role": "user", + "content": tool_input_content(tool_name, args)?, + })]; -fn messages_to_json(messages: &[Message]) -> PyResult { - serde_json::to_value(messages).map_err(|e| { - PyErr::new::(format!( - "Failed to serialize messages: {e}" - )) - }) -} + match self.bridge.check(messages, LocalRailKind::Input).await? { + LocalCheckOutcome::Passed => Ok(args.clone()), + LocalCheckOutcome::Blocked { rail, .. } => { + Err(blocked_error("tool_input", rail.as_deref())) + } + LocalCheckOutcome::Modified { content, .. } => { + modified_tool_payload(&content, "arguments") + } + } + } -#[pyclass(name = "LLMRequest", from_py_object)] -#[derive(Clone)] -struct PyLLMRequest { - inner: LlmRequest, -} + async fn check_tool_output( + &self, + tool_name: &str, + args: &Json, + result: &Json, + ) -> FlowResult { + let messages = vec![ + json!({ + "role": "user", + "content": tool_input_content(tool_name, args)?, + }), + json!({ + "role": "assistant", + "content": tool_output_content(tool_name, args, result)?, + }), + ]; -#[pymethods] -impl PyLLMRequest { - #[new] - #[pyo3(signature = (headers, content), text_signature = "(headers: dict[str, str], content: object)")] - fn new(headers: &Bound<'_, PyAny>, content: &Bound<'_, PyAny>) -> PyResult { - let headers_json = py_to_json(headers)?; - let Json::Object(headers_map) = headers_json else { - return Err(PyErr::new::( - "headers must be a dict", - )); - }; - let content_json = py_to_json(content)?; - Ok(Self { - inner: LlmRequest { - headers: headers_map, - content: content_json, - }, + match self.bridge.check(messages, LocalRailKind::Output).await? { + LocalCheckOutcome::Passed => Ok(result.clone()), + LocalCheckOutcome::Blocked { rail, .. } => { + Err(blocked_error("tool_output", rail.as_deref())) + } + LocalCheckOutcome::Modified { content, .. } => { + modified_tool_payload(&content, "result") + } + } + } + + fn guard_provider_stream( + &self, + messages: Vec, + provider_stream: LlmJsonStream, + ) -> FlowResult { + let (text_tx, text_rx) = mpsc::channel::>(32); + let (chunk_tx, chunk_rx) = mpsc::channel::>(32); + let blocked = Arc::new(Mutex::new(None)); + let monitor = self + .bridge + .spawn_stream_monitor(messages, text_rx, Arc::clone(&blocked))?; + let codec = *self.codec()?; + + tokio::spawn(async move { + forward_guarded_provider_stream( + provider_stream, + codec, + text_tx, + chunk_tx, + monitor, + blocked, + ) + .await; + }); + + Ok(Box::pin(ReceiverStream::new(chunk_rx)) as LlmJsonStream) + } + + fn codec(&self) -> FlowResult<&LocalGuardrailsCodec> { + self.codec.as_ref().ok_or_else(|| { + FlowError::Internal( + "local NeMo Guardrails backend requires a supported codec".to_string(), + ) }) } +} - #[getter] - fn headers(&self, py: Python<'_>) -> PyResult> { - json_to_py(py, &Json::Object(self.inner.headers.clone())) +struct LocalGuardrailsBridge { + rails: Py, + input_rail: Py, + output_rail: Py, + blocked_status: String, + modified_status: String, +} + +impl LocalGuardrailsBridge { + fn new(config: &NeMoGuardrailsConfig) -> PluginResult { + Python::attach(|py| { + let imports = load_nemoguardrails( + py, + config.local.as_ref().and_then(|l| { + l.python_module + .as_deref() + .filter(|module| !module.trim().is_empty()) + }), + )?; + let guardrails_config = build_guardrails_config(py, config, &imports.rails_config_cls)?; + let rails = imports.llm_rails_cls.call1(py, (guardrails_config,))?; + let input_rail = imports.rail_type.getattr(py, "INPUT")?; + let output_rail = imports.rail_type.getattr(py, "OUTPUT")?; + let blocked = imports.rail_status.getattr(py, "BLOCKED")?; + let modified = imports.rail_status.getattr(py, "MODIFIED")?; + let blocked_status = py_status_value(blocked.bind(py))?; + let modified_status = py_status_value(modified.bind(py))?; + + Ok::(Self { + rails, + input_rail, + output_rail, + blocked_status, + modified_status, + }) + }) + .map_err(|err| PluginError::RegistrationFailed(err.to_string())) } - #[getter] - fn content(&self, py: Python<'_>) -> PyResult> { - json_to_py(py, &self.inner.content) + async fn check( + &self, + messages: Vec, + kind: LocalRailKind, + ) -> FlowResult { + let future = Python::attach(|py| { + let messages = json_to_py(py, &Json::Array(messages)) + .map_err(|err| FlowError::Internal(err.to_string()))?; + let rail_type = match kind { + LocalRailKind::Input => self.input_rail.clone_ref(py), + LocalRailKind::Output => self.output_rail.clone_ref(py), + }; + let rail_types = + PyList::new(py, [rail_type]).map_err(|err| FlowError::Internal(err.to_string()))?; + let kwargs = PyDict::new(py); + kwargs + .set_item("rail_types", rail_types) + .map_err(|err| FlowError::Internal(err.to_string()))?; + let result = self + .rails + .bind(py) + .call_method("check_async", (messages,), Some(&kwargs)) + .map_err(|err| FlowError::Internal(err.to_string()))?; + pyo3_async_runtimes::tokio::into_future(result.unbind().into_bound(py)) + .map_err(|err| FlowError::Internal(err.to_string())) + })?; + + let result = future + .await + .map_err(|err| FlowError::Internal(err.to_string()))?; + + Python::attach(|py| { + self.parse_check_result(result.bind(py)) + .map_err(|err| FlowError::Internal(err.to_string())) + }) } - fn __repr__(&self) -> String { - "LLMRequest(...)".to_string() + fn has_streaming_output_rails(&self) -> FlowResult { + Python::attach(|py| { + let Some(output) = self.output_rails_config(py)? else { + return Ok(false); + }; + match output.getattr("flows") { + Ok(flows) => flows + .is_truthy() + .map_err(|err| FlowError::Internal(err.to_string())), + Err(_) => Ok(false), + } + }) } -} -#[pyclass(name = "AnnotatedLLMRequest", from_py_object)] -#[derive(Clone)] -struct PyAnnotatedLLMRequest { - inner: AnnotatedLlmRequest, -} + fn ensure_streaming_output_supported(&self) -> FlowResult<()> { + Python::attach(|py| { + let Some(output) = self.output_rails_config(py)? else { + return Ok(()); + }; + let streaming = output.getattr("streaming").map_err(|_| { + FlowError::Internal( + "local NeMo Guardrails streaming output rails require \ + rails.output.streaming.enabled = true in the Guardrails config." + .to_string(), + ) + })?; + let enabled = streaming + .getattr("enabled") + .and_then(|value| value.is_truthy()) + .unwrap_or(false); + if !enabled { + return Err(FlowError::Internal( + "local NeMo Guardrails streaming output rails require \ + rails.output.streaming.enabled = true in the Guardrails config." + .to_string(), + )); + } -#[pymethods] -impl PyAnnotatedLLMRequest { - #[getter] - fn messages(&self, py: Python<'_>) -> PyResult> { - json_to_py(py, &messages_to_json(&self.inner.messages)?) + let stream_first = streaming + .getattr("stream_first") + .and_then(|value| value.is_truthy()) + .unwrap_or(true); + if !stream_first { + return Err(FlowError::Internal( + "local NeMo Guardrails streaming output rails currently require \ + rails.output.streaming.stream_first = true." + .to_string(), + )); + } + + Ok(()) + }) } - #[setter] - fn set_messages(&mut self, value: &Bound<'_, PyAny>) -> PyResult<()> { - self.inner.messages = pythonize::depythonize(value).map_err(|e| { - PyErr::new::(format!("invalid messages: {e}")) + fn spawn_stream_monitor( + &self, + messages: Vec, + text_rx: mpsc::Receiver>, + blocked: Arc>>, + ) -> FlowResult>> { + let (async_iter, task_locals) = Python::attach(|py| { + let generator = Py::new( + py, + PyStringStream { + receiver: Arc::new(tokio::sync::Mutex::new(text_rx)), + }, + ) + .map_err(|err| FlowError::Internal(err.to_string()))?; + let messages = json_to_py(py, &Json::Array(messages)) + .map_err(|err| FlowError::Internal(err.to_string()))?; + let kwargs = PyDict::new(py); + kwargs + .set_item("messages", messages) + .map_err(|err| FlowError::Internal(err.to_string()))?; + kwargs + .set_item("generator", generator) + .map_err(|err| FlowError::Internal(err.to_string()))?; + kwargs + .set_item("include_metadata", false) + .map_err(|err| FlowError::Internal(err.to_string()))?; + let async_iter = self + .rails + .bind(py) + .call_method("stream_async", (), Some(&kwargs)) + .map_err(|err| FlowError::Internal(err.to_string()))? + .unbind(); + let task_locals = pyo3_async_runtimes::tokio::get_current_locals(py) + .map_err(|err| FlowError::Internal(err.to_string()))?; + Ok((async_iter, task_locals)) })?; - Ok(()) + + let async_iter = Arc::new(async_iter); + Ok(tokio::spawn(pyo3_async_runtimes::tokio::scope( + task_locals, + async move { monitor_guardrails_stream(async_iter, blocked).await }, + ))) } -} -#[pyclass(name = "AnnotatedLLMResponse", skip_from_py_object)] -#[derive(Clone)] -struct PyAnnotatedLLMResponse { - inner: AnnotatedLlmResponse, -} + fn output_rails_config<'py>(&self, py: Python<'py>) -> FlowResult>> { + let rails = self.rails.bind(py); + let config = match rails.getattr("config") { + Ok(config) => config, + Err(_) => return Ok(None), + }; + let rails_config = match config.getattr("rails") { + Ok(rails_config) => rails_config, + Err(_) => return Ok(None), + }; + match rails_config.getattr("output") { + Ok(output) => Ok(Some(output)), + Err(_) => Ok(None), + } + } -#[pymethods] -impl PyAnnotatedLLMResponse { - fn response_text(&self) -> Option { - self.inner.response_text().map(str::to_string) + fn parse_check_result(&self, result: &Bound<'_, PyAny>) -> PyResult { + let status = py_status_value(&result.getattr("status")?)?; + let rail = optional_string_attr(result, "rail")?; + let content = string_attr_or_empty(result, "content")?; + + if status == self.blocked_status { + return Ok(LocalCheckOutcome::Blocked { rail }); + } + if status == self.modified_status { + return Ok(LocalCheckOutcome::Modified { content }); + } + Ok(LocalCheckOutcome::Passed) } } -#[pyclass(name = "OpenAIChatCodec")] -struct PyOpenAIChatCodec; +struct GuardrailsRuntimeImports { + rails_config_cls: Py, + llm_rails_cls: Py, + rail_type: Py, + rail_status: Py, +} -#[pymethods] -impl PyOpenAIChatCodec { - #[new] - fn new() -> Self { - Self +fn load_nemoguardrails( + py: Python<'_>, + module_name: Option<&str>, +) -> PyResult { + let root_module = module_name.unwrap_or(DEFAULT_MODULE_NAME); + let importlib = py.import("importlib")?; + let import_module = importlib.getattr("import_module")?; + let guardrails = import_module + .call1((root_module,)) + .map_err(|err| import_dependency_error(py, err, root_module))?; + let options_module_name = format!("{root_module}.rails.llm.options"); + let options = import_module + .call1((options_module_name.as_str(),)) + .map_err(|err| import_dependency_error(py, err, root_module))?; + + let version = guardrails + .getattr("__version__") + .ok() + .and_then(|value| value.extract::().ok()); + if version.as_deref() != Some(SUPPORTED_NEMOGUARDRAILS_VERSION) { + let found = version + .map(|version| format!("{version:?}")) + .unwrap_or_else(|| "None".to_string()); + return Err(pyo3::exceptions::PyRuntimeError::new_err(format!( + "NeMo Guardrails local backend requires nemoguardrails==\ + {SUPPORTED_NEMOGUARDRAILS_VERSION}, but found {found}. \ + Install it with: pip install nemoguardrails=={SUPPORTED_NEMOGUARDRAILS_VERSION}" + ))); + } + + Ok(GuardrailsRuntimeImports { + rails_config_cls: guardrails.getattr("RailsConfig")?.unbind(), + llm_rails_cls: guardrails.getattr("LLMRails")?.unbind(), + rail_type: options.getattr("RailType")?.unbind(), + rail_status: options.getattr("RailStatus")?.unbind(), + }) +} + +fn import_dependency_error(py: Python<'_>, err: PyErr, root_module: &str) -> PyErr { + if !err.is_instance_of::(py) { + return err; } - fn decode(&self, request: &PyLLMRequest) -> PyResult { - OpenAIChatCodec - .decode(&request.inner) - .map(|inner| PyAnnotatedLLMRequest { inner }) - .map_err(flow_to_py_err) + let name = err.value(py).getattr("name").ok().and_then(|name| { + if name.is_none() { + None + } else { + name.extract::().ok() + } + }); + + if name.as_deref() == Some(root_module) { + return pyo3::exceptions::PyRuntimeError::new_err(format!( + "NeMo Guardrails is required for the built-in NeMo Guardrails local backend. \ + Install it with: pip install nemoguardrails=={SUPPORTED_NEMOGUARDRAILS_VERSION}" + )); } - fn encode( - &self, - annotated: &PyAnnotatedLLMRequest, - original: &PyLLMRequest, - ) -> PyResult { - OpenAIChatCodec - .encode(&annotated.inner, &original.inner) - .map(|inner| PyLLMRequest { inner }) - .map_err(flow_to_py_err) + pyo3::exceptions::PyRuntimeError::new_err(format!( + "NeMo Guardrails local backend could not import a required dependency: {}. \ + Install the full NeMo Guardrails runtime dependencies.", + name.unwrap_or_else(|| err.to_string()) + )) +} + +fn build_guardrails_config( + py: Python<'_>, + config: &NeMoGuardrailsConfig, + rails_config_cls: &Py, +) -> PyResult> { + let rails_config_cls = rails_config_cls.bind(py); + if let Some(config_path) = config.config_path.as_deref() { + return rails_config_cls + .call_method1("from_path", (config_path,)) + .map(Bound::unbind); + } + + let config_yaml = config.config_yaml.as_deref().ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err( + "config_yaml is required when config_path is not provided", + ) + })?; + let kwargs = PyDict::new(py); + kwargs.set_item("colang_content", config.colang_content.as_deref())?; + kwargs.set_item("yaml_content", config_yaml)?; + rails_config_cls + .call_method("from_content", (), Some(&kwargs)) + .map(Bound::unbind) +} + +fn py_status_value(status: &Bound<'_, PyAny>) -> PyResult { + let value = status.getattr("value").unwrap_or_else(|_| status.clone()); + Ok(value.str()?.extract::()?.to_lowercase()) +} + +fn optional_string_attr(obj: &Bound<'_, PyAny>, attr: &str) -> PyResult> { + match obj.getattr(attr) { + Ok(value) if !value.is_none() => Ok(Some(value.str()?.extract::()?)), + Ok(_) | Err(_) => Ok(None), } +} - fn decode_response(&self, response: &Bound<'_, PyAny>) -> PyResult { - let response = py_to_json(response)?; - OpenAIChatCodec - .decode_response(&response) - .map(|inner| PyAnnotatedLLMResponse { inner }) - .map_err(flow_to_py_err) +fn string_attr_or_empty(obj: &Bound<'_, PyAny>, attr: &str) -> PyResult { + match optional_string_attr(obj, attr)? { + Some(value) => Ok(value), + None => Ok(String::new()), } } -#[pyclass(name = "OpenAIResponsesCodec")] -struct PyOpenAIResponsesCodec; +fn json_to_py(py: Python<'_>, value: &Json) -> PyResult> { + let obj: Bound<'_, PyAny> = pythonize::pythonize(py, value).map_err(|e| { + PyErr::new::(format!("Failed to convert from JSON: {e}")) + })?; + Ok(obj.unbind()) +} -#[pymethods] -impl PyOpenAIResponsesCodec { - #[new] - fn new() -> Self { - Self - } +#[derive(Clone, Copy)] +enum LocalGuardrailsCodec { + OpenAIChat, + OpenAIResponses, + AnthropicMessages, +} - fn decode(&self, request: &PyLLMRequest) -> PyResult { - OpenAIResponsesCodec - .decode(&request.inner) - .map(|inner| PyAnnotatedLLMRequest { inner }) - .map_err(flow_to_py_err) +impl LocalGuardrailsCodec { + fn decode(&self, request: &LlmRequest) -> FlowResult { + match self { + Self::OpenAIChat => OpenAIChatCodec.decode(request), + Self::OpenAIResponses => OpenAIResponsesCodec.decode(request), + Self::AnthropicMessages => AnthropicMessagesCodec.decode(request), + } } fn encode( &self, - annotated: &PyAnnotatedLLMRequest, - original: &PyLLMRequest, - ) -> PyResult { - OpenAIResponsesCodec - .encode(&annotated.inner, &original.inner) - .map(|inner| PyLLMRequest { inner }) - .map_err(flow_to_py_err) + annotated: &AnnotatedLlmRequest, + original: &LlmRequest, + ) -> FlowResult { + match self { + Self::OpenAIChat => OpenAIChatCodec.encode(annotated, original), + Self::OpenAIResponses => OpenAIResponsesCodec.encode(annotated, original), + Self::AnthropicMessages => AnthropicMessagesCodec.encode(annotated, original), + } } - fn decode_response(&self, response: &Bound<'_, PyAny>) -> PyResult { - let response = py_to_json(response)?; - OpenAIResponsesCodec - .decode_response(&response) - .map(|inner| PyAnnotatedLLMResponse { inner }) - .map_err(flow_to_py_err) + fn decode_response( + &self, + response: &Json, + ) -> FlowResult { + match self { + Self::OpenAIChat => OpenAIChatCodec.decode_response(response), + Self::OpenAIResponses => OpenAIResponsesCodec.decode_response(response), + Self::AnthropicMessages => AnthropicMessagesCodec.decode_response(response), + } } } -#[pyclass(name = "AnthropicMessagesCodec")] -struct PyAnthropicMessagesCodec; - -#[pymethods] -impl PyAnthropicMessagesCodec { - #[new] - fn new() -> Self { - Self +fn resolve_codec(config: &NeMoGuardrailsConfig) -> PluginResult> { + if !(config.input || config.output) { + return Ok(None); } - fn decode(&self, request: &PyLLMRequest) -> PyResult { - AnthropicMessagesCodec - .decode(&request.inner) - .map(|inner| PyAnnotatedLLMRequest { inner }) - .map_err(flow_to_py_err) + match config.codec.as_deref() { + Some("openai_chat") => Ok(Some(LocalGuardrailsCodec::OpenAIChat)), + Some("openai_responses") => Ok(Some(LocalGuardrailsCodec::OpenAIResponses)), + Some("anthropic_messages") => Ok(Some(LocalGuardrailsCodec::AnthropicMessages)), + Some(other) => Err(PluginError::InvalidConfig(format!( + "unsupported local NeMo Guardrails codec '{other}'" + ))), + None => Err(PluginError::InvalidConfig( + "local NeMo Guardrails backend requires a supported codec".to_string(), + )), } +} - fn encode( - &self, - annotated: &PyAnnotatedLLMRequest, - original: &PyLLMRequest, - ) -> PyResult { - AnthropicMessagesCodec - .encode(&annotated.inner, &original.inner) - .map(|inner| PyLLMRequest { inner }) - .map_err(flow_to_py_err) +enum LocalCheckOutcome { + Passed, + Blocked { rail: Option }, + Modified { content: String }, +} + +#[derive(Clone, Copy)] +enum LocalRailKind { + Input, + Output, +} + +fn messages_from_annotated(annotated: &AnnotatedLlmRequest) -> FlowResult> { + match serde_json::to_value(&annotated.messages) + .map_err(|err| FlowError::Internal(format!("failed to serialize messages: {err}")))? + { + Json::Array(messages) => Ok(messages), + _ => Err(FlowError::Internal( + "serialized messages were not a JSON array".to_string(), + )), } +} - fn decode_response(&self, response: &Bound<'_, PyAny>) -> PyResult { - let response = py_to_json(response)?; - AnthropicMessagesCodec - .decode_response(&response) - .map(|inner| PyAnnotatedLLMResponse { inner }) - .map_err(flow_to_py_err) +fn replace_last_role_content( + annotated: &mut AnnotatedLlmRequest, + role: &str, + content: String, +) -> FlowResult<()> { + for message in annotated.messages.iter_mut().rev() { + match (role, message) { + ( + "user", + Message::User { + content: target, .. + }, + ) => { + *target = MessageContent::Text(content); + return Ok(()); + } + ( + "assistant", + Message::Assistant { + content: target, .. + }, + ) => { + *target = Some(MessageContent::Text(content)); + return Ok(()); + } + _ => {} + } } + + Err(local_violation(format!( + "NeMo Guardrails returned modified {role} content but no {role} message was present." + ))) } -#[pyclass(name = "PluginContext")] -struct PyLocalPluginContext { - registrations: Arc>>, - namespace_prefix: String, +fn tool_input_content(name: &str, args: &Json) -> FlowResult { + serde_json::to_string(&json!({ + "tool_name": name, + "arguments": args, + })) + .map_err(|err| FlowError::Internal(format!("failed to serialize tool input: {err}"))) } -impl PyLocalPluginContext { - fn qualify_name(&self, name: &str) -> String { - format!("{}{}", self.namespace_prefix, name) - } +fn tool_output_content(name: &str, args: &Json, result: &Json) -> FlowResult { + serde_json::to_string(&json!({ + "tool_name": name, + "arguments": args, + "result": result, + })) + .map_err(|err| FlowError::Internal(format!("failed to serialize tool output: {err}"))) +} - fn drain_registrations(&self) -> PyResult> { - let mut guard = self.registrations.lock().map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!("plugin context lock poisoned: {e}")) - })?; - Ok(std::mem::take(&mut *guard)) - } +fn modified_tool_payload(content: &str, field: &str) -> FlowResult { + let value: Json = serde_json::from_str(content).map_err(|_| { + local_violation(format!( + "NeMo Guardrails returned modified tool {field} content that is not valid JSON." + )) + })?; + + let Json::Object(object) = value else { + return Err(local_violation(format!( + "NeMo Guardrails returned modified tool {field} content without a '{field}' field." + ))); + }; + object.get(field).cloned().ok_or_else(|| { + local_violation(format!( + "NeMo Guardrails returned modified tool {field} content without a '{field}' field." + )) + }) } -#[pymethods] -impl PyLocalPluginContext { - #[pyo3(signature = (name, priority, callback), text_signature = "(name: str, priority: int, callback: object) -> None")] - fn register_llm_execution_intercept( - &self, - name: &str, - priority: i32, - callback: Py, - ) -> PyResult<()> { - let qualified_name = self.qualify_name(name); - register_llm_execution_intercept( - &qualified_name, - priority, - wrap_py_llm_exec_intercept_fn(callback), - ) - .map_err(plugin_to_py_err)?; - self.push_registration(qualified_name, RegistrationKind::Llm) - } +fn blocked_error(rail_type: &str, rail: Option<&str>) -> FlowError { + let detail = rail + .filter(|rail| !rail.is_empty()) + .map(|rail| format!(" by rail '{rail}'")) + .unwrap_or_default(); + let subject = if matches!(rail_type, "input" | "output") { + "LLM call" + } else { + "tool call" + }; + local_violation(format!( + "NeMo Guardrails {rail_type} rail blocked the {subject}{detail}." + )) +} - #[pyo3(signature = (name, priority, callback), text_signature = "(name: str, priority: int, callback: object) -> None")] - fn register_llm_stream_execution_intercept( - &self, - name: &str, - priority: i32, - callback: Py, - ) -> PyResult<()> { - let qualified_name = self.qualify_name(name); - register_llm_stream_execution_intercept( - &qualified_name, - priority, - wrap_py_llm_stream_exec_intercept_fn(callback), - ) - .map_err(plugin_to_py_err)?; - self.push_registration(qualified_name, RegistrationKind::LlmStream) +fn local_violation(message: impl Into) -> FlowError { + FlowError::Internal(message.into()) +} + +async fn forward_guarded_provider_stream( + mut provider_stream: LlmJsonStream, + codec: LocalGuardrailsCodec, + text_tx: mpsc::Sender>, + chunk_tx: mpsc::Sender>, + monitor: JoinHandle>, + blocked: Arc>>, +) { + while let Some(item) = provider_stream.next().await { + let chunk = match item { + Ok(chunk) => chunk, + Err(err) => { + let _ = chunk_tx.send(Err(err)).await; + let _ = text_tx.send(None).await; + let _ = monitor.await; + return; + } + }; + + if let Some(message) = blocked_message(&blocked) { + let _ = chunk_tx.send(Err(streaming_output_blocked(message))).await; + let _ = text_tx.send(None).await; + let _ = monitor.await; + return; + } + + let text = extract_stream_text(codec, &chunk); + + if chunk_tx.send(Ok(chunk)).await.is_err() { + let _ = text_tx.send(None).await; + let _ = monitor.await; + return; + } + tokio::task::yield_now().await; + + if let Some(text) = text { + let _ = text_tx.send(Some(text)).await; + } + + if let Some(message) = blocked_message(&blocked) { + let _ = chunk_tx.send(Err(streaming_output_blocked(message))).await; + let _ = text_tx.send(None).await; + let _ = monitor.await; + return; + } } - #[pyo3(signature = (name, priority, callback), text_signature = "(name: str, priority: int, callback: object) -> None")] - fn register_tool_execution_intercept( - &self, - name: &str, - priority: i32, - callback: Py, - ) -> PyResult<()> { - let qualified_name = self.qualify_name(name); - register_tool_execution_intercept( - &qualified_name, - priority, - wrap_py_tool_exec_intercept_fn(callback), - ) - .map_err(plugin_to_py_err)?; - self.push_registration(qualified_name, RegistrationKind::Tool) + let _ = text_tx.send(None).await; + match monitor.await { + Ok(Ok(())) => {} + Ok(Err(err)) => { + let _ = chunk_tx.send(Err(err)).await; + return; + } + Err(err) => { + let _ = chunk_tx + .send(Err(FlowError::Internal(format!( + "nemo_guardrails stream monitor task failed: {err}" + )))) + .await; + return; + } } - fn __repr__(&self) -> String { - "".to_string() + if let Some(message) = blocked_message(&blocked) { + let _ = chunk_tx.send(Err(streaming_output_blocked(message))).await; } } -impl PyLocalPluginContext { - fn push_registration(&self, name: String, kind: RegistrationKind) -> PyResult<()> { - let mut guard = self.registrations.lock().map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!("plugin context lock poisoned: {e}")) - })?; - guard.push(PluginRegistration::new( - "plugin", - name.clone(), - Box::new(move || match kind { - RegistrationKind::Llm => deregister_llm_execution_intercept(&name) - .map(|_| ()) - .map_err(registration_failure), - RegistrationKind::LlmStream => deregister_llm_stream_execution_intercept(&name) - .map(|_| ()) - .map_err(registration_failure), - RegistrationKind::Tool => deregister_tool_execution_intercept(&name) - .map(|_| ()) - .map_err(registration_failure), - }), - )); - Ok(()) - } -} - -enum RegistrationKind { - Llm, - LlmStream, - Tool, -} - -fn registration_failure(err: FlowError) -> PluginError { - PluginError::RegistrationFailed(err.to_string()) -} - -fn plugin_to_py_err(err: FlowError) -> PyErr { - pyo3::exceptions::PyRuntimeError::new_err(err.to_string()) -} - -fn flow_to_py_err(err: FlowError) -> PyErr { - pyo3::exceptions::PyRuntimeError::new_err(err.to_string()) -} - -type PyValueFuture = Pin>> + Send>>; -type ToolExecIntercept = Arc< - dyn Fn( - &str, - Json, - ToolExecutionNextFn, - ) -> Pin> + Send>> - + Send - + Sync, ->; -type LlmExecIntercept = Arc< - dyn Fn( - &str, - LlmRequest, - LlmExecutionNextFn, - ) -> Pin> + Send>> - + Send - + Sync, ->; -type LlmStreamIntercept = Arc< - dyn Fn( - &str, - LlmRequest, - LlmStreamExecutionNextFn, - ) -> Pin< - Box< - dyn Future< - Output = FlowResult< - Pin> + Send>>, - >, - > + Send, - >, - > + Send - + Sync, ->; - -fn split_py_object_or_future( - py: Python<'_>, - result: Py, -) -> FlowResult, PyValueFuture>> { - let bound = result.bind(py); - if bound.getattr("__await__").is_ok() { - let future = pyo3_async_runtimes::tokio::into_future(result.into_bound(py)) - .map_err(|e| FlowError::Internal(e.to_string()))?; - Ok(Err(Box::pin(future) as PyValueFuture)) - } else { - Ok(Ok(result)) +fn blocked_message(blocked: &Arc>>) -> Option { + blocked.lock().ok().and_then(|guard| guard.clone()) +} + +fn streaming_output_blocked(message: String) -> FlowError { + local_violation(format!( + "NeMo Guardrails output rail blocked the LLM call: {message}" + )) +} + +fn extract_stream_text(codec: LocalGuardrailsCodec, chunk: &Json) -> Option { + let chunk = chunk.as_object()?; + match codec { + LocalGuardrailsCodec::OpenAIChat => { + let choices = chunk.get("choices")?.as_array()?; + let mut parts = vec![]; + for choice in choices { + let content = choice + .get("delta") + .and_then(Json::as_object) + .and_then(|delta| delta.get("content")) + .and_then(Json::as_str); + if let Some(content) = content + && !content.is_empty() + { + parts.push(content); + } + } + (!parts.is_empty()).then(|| parts.join("")) + } + LocalGuardrailsCodec::OpenAIResponses => { + if chunk.get("type").and_then(Json::as_str) == Some("response.output_text.delta") { + chunk + .get("delta") + .and_then(Json::as_str) + .filter(|delta| !delta.is_empty()) + .map(str::to_string) + } else { + None + } + } + LocalGuardrailsCodec::AnthropicMessages => { + if chunk.get("type").and_then(Json::as_str) != Some("content_block_delta") { + return None; + } + let delta = chunk.get("delta")?.as_object()?; + if delta.get("type").and_then(Json::as_str) != Some("text_delta") { + return None; + } + delta + .get("text") + .and_then(Json::as_str) + .filter(|text| !text.is_empty()) + .map(str::to_string) + } } } -async fn resolve_py_object_or_future( - outcome: FlowResult, PyValueFuture>>, -) -> FlowResult> { - match outcome? { - Ok(value) => Ok(value), - Err(future) => future.await.map_err(|e| FlowError::Internal(e.to_string())), +async fn monitor_guardrails_stream( + async_iter: Arc>, + blocked: Arc>>, +) -> FlowResult<()> { + loop { + let Some(coro) = next_async_iter_coro(&async_iter)? else { + break; + }; + let Some(value) = await_async_iter_value(coro).await? else { + break; + }; + Python::attach(|py| { + if let Ok(chunk) = value.bind(py).extract::() + && let Some(message) = guardrails_stream_error_message(&chunk) + { + let mut guard = blocked.lock().map_err(|err| { + FlowError::Internal(format!("stream block state lock poisoned: {err}")) + })?; + *guard = Some(message); + } + Ok::<(), FlowError>(()) + })?; + if blocked_message(&blocked).is_some() { + break; + } } + Ok(()) } fn next_async_iter_coro(async_iter: &Arc>) -> FlowResult>> { @@ -534,18 +961,14 @@ fn next_async_iter_coro(async_iter: &Arc>) -> FlowResult) -> FlowResult> { +async fn await_async_iter_value(coro: Py) -> FlowResult>> { let future = Python::attach(|py| { pyo3_async_runtimes::tokio::into_future(coro.into_bound(py)) - .map_err(|e| FlowError::Internal(e.to_string())) + .map_err(|err| FlowError::Internal(err.to_string())) })?; match future.await { - Ok(result) => Python::attach(|py| { - py_to_json(result.bind(py)) - .map(Some) - .map_err(|e| FlowError::Internal(e.to_string())) - }), + Ok(result) => Ok(Some(result)), Err(error) => Python::attach(|py| { if error.is_instance_of::(py) { Ok(None) @@ -556,287 +979,45 @@ async fn await_async_iter_value(coro: Py) -> FlowResult> { } } -async fn forward_async_iter( - async_iter: Arc>, - tx: tokio::sync::mpsc::Sender>, -) { - loop { - let next_value = match next_async_iter_coro(&async_iter) { - Ok(None) => break, - Ok(Some(coro)) => await_async_iter_value(coro).await, - Err(error) => Err(error), - }; - - match next_value { - Ok(Some(value)) => { - if tx.send(Ok(value)).await.is_err() { - break; - } - } - Ok(None) => break, - Err(error) => { - let _ = tx.send(Err(error)).await; - break; - } - } - } -} - -fn stream_from_async_iter( - async_iter: Py, -) -> FlowResult> + Send>>> { - let (tx, rx) = tokio::sync::mpsc::channel::>(32); - let task_locals = Python::attach(|py| { - pyo3_async_runtimes::tokio::get_current_locals(py) - .map_err(|e: pyo3::PyErr| FlowError::Internal(e.to_string())) - })?; - - let async_iter = Arc::new(async_iter); - tokio::spawn(pyo3_async_runtimes::tokio::scope(task_locals, async move { - forward_async_iter(async_iter, tx).await; - })); - - Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx))) -} - -#[pyclass] -struct PyToolNextFn { - inner: ToolExecutionNextFn, -} - -#[pymethods] -impl PyToolNextFn { - fn __call__<'py>( - &self, - py: Python<'py>, - args: Bound<'_, PyAny>, - ) -> PyResult> { - let args = py_to_json(&args)?; - let next = self.inner.clone(); - pyo3_async_runtimes::tokio::future_into_py(py, async move { - let result = next(args) - .await - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; - Python::attach(|py| json_to_py(py, &result)) - }) +fn guardrails_stream_error_message(chunk: &str) -> Option { + let payload: Json = serde_json::from_str(chunk).ok()?; + let error = payload.get("error")?.as_object()?; + if error.get("type").and_then(Json::as_str) != Some("guardrails_violation") { + return None; } + error + .get("message") + .and_then(Json::as_str) + .filter(|message| !message.is_empty()) + .map(str::to_string) + .or_else(|| Some("Blocked by output rails.".to_string())) } -#[pyclass] -struct PyLlmNextFn { - inner: LlmExecutionNextFn, +#[pyclass(name = "StringStream")] +struct PyStringStream { + receiver: Arc>>>, } #[pymethods] -impl PyLlmNextFn { - fn __call__<'py>(&self, py: Python<'py>, request: PyLLMRequest) -> PyResult> { - let next = self.inner.clone(); - pyo3_async_runtimes::tokio::future_into_py(py, async move { - let result = next(request.inner) - .await - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; - Python::attach(|py| json_to_py(py, &result)) - }) - } -} - -#[pyclass(name = "LlmStream")] -struct PyLlmStream { - receiver: tokio::sync::Mutex>>, -} - -#[pymethods] -impl PyLlmStream { +impl PyStringStream { fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { slf } fn __anext__<'py>(&self, py: Python<'py>) -> PyResult> { - let receiver_ptr = &self.receiver - as *const tokio::sync::Mutex>>; - let receiver_ref = unsafe { &*receiver_ptr }; - + let receiver = Arc::clone(&self.receiver); pyo3_async_runtimes::tokio::future_into_py(py, async move { - let mut guard = receiver_ref.lock().await; + let mut guard = receiver.lock().await; match guard.recv().await { - None => Err(PyErr::new::( + Some(Some(value)) => Ok(value), + Some(None) | None => Err(PyErr::new::( "stream exhausted", )), - Some(Ok(value)) => Python::attach(|py| json_to_py(py, &value)), - Some(Err(err)) => Err(PyErr::new::( - err.to_string(), - )), } }) } } -#[pyclass] -struct PyLlmStreamNextFn { - inner: LlmStreamExecutionNextFn, -} - -#[pymethods] -impl PyLlmStreamNextFn { - fn __call__<'py>(&self, py: Python<'py>, request: PyLLMRequest) -> PyResult> { - let next = self.inner.clone(); - pyo3_async_runtimes::tokio::future_into_py(py, async move { - let rust_stream = next(request.inner) - .await - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; - let (tx, rx) = tokio::sync::mpsc::channel::>(32); - tokio::spawn(async move { - use tokio_stream::StreamExt; - let mut stream = rust_stream; - while let Some(item) = stream.next().await { - if tx.send(item).await.is_err() { - break; - } - } - }); - Ok(PyLlmStream { - receiver: tokio::sync::Mutex::new(rx), - }) - }) - } -} - -fn wrap_py_tool_exec_intercept_fn(py_fn: Py) -> ToolExecIntercept { - let py_fn = Arc::new(py_fn); - Arc::new(move |name: &str, args: Json, next: ToolExecutionNextFn| { - let py_fn = py_fn.clone(); - let name = name.to_string(); - Box::pin(async move { - let outcome: FlowResult> = Python::attach(|py| { - let py_args = - json_to_py(py, &args).map_err(|e: PyErr| FlowError::Internal(e.to_string()))?; - let py_next = PyToolNextFn { inner: next }; - let result = py_fn - .call1( - py, - ( - &name, - py_args, - py_next - .into_pyobject(py) - .map_err(|e| FlowError::Internal(e.to_string()))? - .into_any(), - ), - ) - .map_err(|e: PyErr| FlowError::Internal(e.to_string()))?; - let bound = result.bind(py); - if bound.getattr("__await__").is_ok() { - let future = pyo3_async_runtimes::tokio::into_future(result.into_bound(py)) - .map_err(|e| FlowError::Internal(e.to_string()))?; - Ok(Err(Box::pin(future) as PyValueFuture)) - } else { - let json = - py_to_json(bound).map_err(|e: PyErr| FlowError::Internal(e.to_string()))?; - Ok(Ok(json)) - } - }); - - match outcome? { - Ok(json) => Ok(json), - Err(future) => { - let py_result = future - .await - .map_err(|e| FlowError::Internal(e.to_string()))?; - Python::attach(|py| { - py_to_json(py_result.bind(py)) - .map_err(|e: PyErr| FlowError::Internal(e.to_string())) - }) - } - } - }) - }) -} - -fn wrap_py_llm_exec_intercept_fn(py_fn: Py) -> LlmExecIntercept { - let py_fn = Arc::new(py_fn); - Arc::new( - move |name: &str, request: LlmRequest, next: LlmExecutionNextFn| { - let py_fn = py_fn.clone(); - let name = name.to_string(); - Box::pin(async move { - let outcome: FlowResult> = Python::attach(|py| { - let py_req = PyLLMRequest { inner: request }; - let py_next = PyLlmNextFn { inner: next }; - let result = py_fn - .call1( - py, - ( - &name, - py_req - .into_pyobject(py) - .map_err(|e| FlowError::Internal(e.to_string()))? - .into_any(), - py_next - .into_pyobject(py) - .map_err(|e| FlowError::Internal(e.to_string()))? - .into_any(), - ), - ) - .map_err(|e: PyErr| FlowError::Internal(e.to_string()))?; - let bound = result.bind(py); - if bound.getattr("__await__").is_ok() { - let future = pyo3_async_runtimes::tokio::into_future(result.into_bound(py)) - .map_err(|e| FlowError::Internal(e.to_string()))?; - Ok(Err(Box::pin(future) as PyValueFuture)) - } else { - let json = py_to_json(bound) - .map_err(|e: PyErr| FlowError::Internal(e.to_string()))?; - Ok(Ok(json)) - } - }); - - match outcome? { - Ok(json) => Ok(json), - Err(future) => { - let py_result = future - .await - .map_err(|e| FlowError::Internal(e.to_string()))?; - Python::attach(|py| { - py_to_json(py_result.bind(py)) - .map_err(|e: PyErr| FlowError::Internal(e.to_string())) - }) - } - } - }) - }, - ) -} - -fn wrap_py_llm_stream_exec_intercept_fn(py_fn: Py) -> LlmStreamIntercept { - let py_fn = Arc::new(py_fn); - Arc::new( - move |_name: &str, request: LlmRequest, next: LlmStreamExecutionNextFn| { - let py_fn = py_fn.clone(); - Box::pin(async move { - let async_iter = resolve_py_object_or_future(Python::attach(|py| { - let py_req = PyLLMRequest { inner: request }; - let py_next = PyLlmStreamNextFn { inner: next }; - let result = py_fn - .call1( - py, - ( - py_req - .into_pyobject(py) - .map_err(|e: PyErr| FlowError::Internal(e.to_string()))? - .into_any(), - py_next - .into_pyobject(py) - .map_err(|e: PyErr| FlowError::Internal(e.to_string()))? - .into_any(), - ), - ) - .map_err(|e: PyErr| FlowError::Internal(e.to_string()))?; - split_py_object_or_future(py, result) - })) - .await?; - - stream_from_async_iter(async_iter) - }) - }, - ) -} +#[cfg(test)] +#[path = "../../../tests/unit/plugins/nemo_guardrails/local_python_tests.rs"] +mod tests; diff --git a/crates/core/tests/unit/plugins/nemo_guardrails/local_python_tests.rs b/crates/core/tests/unit/plugins/nemo_guardrails/local_python_tests.rs new file mode 100644 index 00000000..cf48592c --- /dev/null +++ b/crates/core/tests/unit/plugins/nemo_guardrails/local_python_tests.rs @@ -0,0 +1,217 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::ffi::CString; + +use pyo3::prelude::*; +use pyo3::types::PyModule; +use serde_json::json; + +use super::*; +use crate::plugins::nemo_guardrails::component::LocalBackendConfig; + +fn local_config(module_name: &str) -> NeMoGuardrailsConfig { + NeMoGuardrailsConfig { + mode: "local".to_string(), + codec: Some("openai_chat".to_string()), + config_yaml: Some("models: []".to_string()), + colang_content: Some("define flow noop\n pass".to_string()), + local: Some(LocalBackendConfig { + python_module: Some(module_name.to_string()), + }), + ..NeMoGuardrailsConfig::default() + } +} + +fn install_fake_guardrails(py: Python<'_>, module_name: &str, version: &str, llm_rails_init: &str) { + let code = format!( + r#" +import sys +import types + +MODULE_NAME = {module_name:?} + +fake_root = types.ModuleType(MODULE_NAME) +fake_root.__version__ = {version:?} +fake_options = types.ModuleType(MODULE_NAME + ".rails.llm.options") + +class Result: + def __init__(self, status, content=None, rail=None): + self.status = status + self.content = content + self.rail = rail + +class RailType: + INPUT = "input" + OUTPUT = "output" + +class RailStatus: + BLOCKED = "blocked" + MODIFIED = "modified" + PASSED = "passed" + +class RailsConfig: + @staticmethod + def from_content(*, colang_content=None, yaml_content=None): + return {{"yaml": yaml_content, "colang": colang_content}} + + @staticmethod + def from_path(path): + return {{"path": path}} + +class LLMRails: + instances = [] + + def __init__(self, config): + LLMRails.instances.append(self) +{llm_rails_init} + +fake_root.Result = Result +fake_root.RailStatus = RailStatus +fake_root.RailsConfig = RailsConfig +fake_root.LLMRails = LLMRails +fake_options.RailType = RailType +fake_options.RailStatus = RailStatus + +sys.modules[MODULE_NAME] = fake_root +sys.modules[MODULE_NAME + ".rails"] = types.ModuleType(MODULE_NAME + ".rails") +sys.modules[MODULE_NAME + ".rails.llm"] = types.ModuleType(MODULE_NAME + ".rails.llm") +sys.modules[MODULE_NAME + ".rails.llm.options"] = fake_options +"# + ); + let code = CString::new(code).unwrap(); + let file_name = CString::new("fake_guardrails.py").unwrap(); + let module_name = CString::new(format!("{module_name}_installer")).unwrap(); + PyModule::from_code(py, &code, &file_name, &module_name).unwrap(); +} + +fn py_to_json(obj: &Bound<'_, PyAny>) -> Json { + pythonize::depythonize(obj).unwrap() +} + +#[test] +fn bridge_loads_inline_guardrails_config() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + Python::attach(|py| { + let module_name = "fake_guardrails_bridge_config"; + install_fake_guardrails(py, module_name, "0.22.0", " self.config = config"); + + let bridge = LocalGuardrailsBridge::new(&local_config(module_name)).unwrap(); + let config = bridge.rails.bind(py).getattr("config").unwrap(); + assert_eq!( + py_to_json(&config), + json!({"yaml": "models: []", "colang": "define flow noop\n pass"}) + ); + }); +} + +#[test] +fn bridge_parses_pass_block_and_modify_outcomes() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + Python::attach(|py| { + let module_name = "fake_guardrails_bridge_outcomes"; + install_fake_guardrails(py, module_name, "0.22.0", " self.config = config"); + let bridge = LocalGuardrailsBridge::new(&local_config(module_name)).unwrap(); + let root = py.import(module_name).unwrap(); + let result_cls = root.getattr("Result").unwrap(); + let status = root.getattr("RailStatus").unwrap(); + + let passed = result_cls + .call1((status.getattr("PASSED").unwrap(),)) + .unwrap(); + assert!(matches!( + bridge.parse_check_result(&passed).unwrap(), + LocalCheckOutcome::Passed + )); + + let blocked = result_cls + .call1((status.getattr("BLOCKED").unwrap(), "stop", "policy")) + .unwrap(); + match bridge.parse_check_result(&blocked).unwrap() { + LocalCheckOutcome::Blocked { rail } => assert_eq!(rail.as_deref(), Some("policy")), + _ => panic!("expected blocked outcome"), + } + + let modified = result_cls + .call1((status.getattr("MODIFIED").unwrap(), "rewritten")) + .unwrap(); + match bridge.parse_check_result(&modified).unwrap() { + LocalCheckOutcome::Modified { content } => assert_eq!(content, "rewritten"), + _ => panic!("expected modified outcome"), + } + }); +} + +#[test] +fn modified_tool_payload_rejects_malformed_content() { + let error = modified_tool_payload("not-json", "arguments").unwrap_err(); + assert!( + error + .to_string() + .contains("modified tool arguments content that is not valid JSON") + ); + + let error = modified_tool_payload(r#"{"tool_name":"demo"}"#, "result").unwrap_err(); + assert!( + error + .to_string() + .contains("modified tool result content without a 'result' field") + ); +} + +#[test] +fn streaming_support_rejects_stream_first_false() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + Python::attach(|py| { + let module_name = "fake_guardrails_bridge_streaming"; + install_fake_guardrails( + py, + module_name, + "0.22.0", + r#" self.config = types.SimpleNamespace( + rails=types.SimpleNamespace( + output=types.SimpleNamespace( + flows=["self check output"], + streaming=types.SimpleNamespace(enabled=True, stream_first=False), + ) + ) + )"#, + ); + + let bridge = LocalGuardrailsBridge::new(&local_config(module_name)).unwrap(); + assert!(bridge.has_streaming_output_rails().unwrap()); + let error = bridge.ensure_streaming_output_supported().unwrap_err(); + assert!(error.to_string().contains("stream_first = true")); + }); +} + +#[test] +fn stream_text_extraction_handles_supported_codecs() { + assert_eq!( + extract_stream_text( + LocalGuardrailsCodec::OpenAIChat, + &json!({"choices": [{"delta": {"content": "hel"}}, {"delta": {"content": "lo"}}]}) + ), + Some("hello".to_string()) + ); + assert_eq!( + extract_stream_text( + LocalGuardrailsCodec::OpenAIResponses, + &json!({"type": "response.output_text.delta", "delta": "hello"}) + ), + Some("hello".to_string()) + ); + assert_eq!( + extract_stream_text( + LocalGuardrailsCodec::AnthropicMessages, + &json!({"type": "content_block_delta", "delta": {"type": "text_delta", "text": "hello"}}) + ), + Some("hello".to_string()) + ); +} diff --git a/crates/python/src/lib.rs b/crates/python/src/lib.rs index c6878f41..b3377d28 100644 --- a/crates/python/src/lib.rs +++ b/crates/python/src/lib.rs @@ -75,3 +75,7 @@ fn install_native_module_alias(m: &Bound<'_, PyModule>) -> PyResult<()> { #[cfg(test)] #[path = "../tests/coverage/coverage_tests.rs"] mod coverage_tests; + +#[cfg(test)] +#[path = "../tests/coverage/nemo_guardrails_coverage_tests.rs"] +mod nemo_guardrails_coverage_tests; diff --git a/crates/python/tests/coverage/coverage_tests.rs b/crates/python/tests/coverage/coverage_tests.rs index 98d00c92..6c3205e0 100644 --- a/crates/python/tests/coverage/coverage_tests.rs +++ b/crates/python/tests/coverage/coverage_tests.rs @@ -4,13 +4,11 @@ //! Coverage tests for coverage in the NeMo Relay Python crate. use std::ffi::CString; -use std::panic::{AssertUnwindSafe, catch_unwind}; -use std::path::PathBuf; use std::pin::Pin; use std::sync::Arc; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyModule}; +use pyo3::types::PyModule; use serde_json::{Value as Json, json}; use tokio_stream::Stream; use tokio_stream::StreamExt; @@ -26,13 +24,7 @@ use crate::py_callable::{ }; use nemo_relay::api::event::{BaseEvent, Event, EventCategory, ScopeCategory, ScopeEvent}; use nemo_relay::api::llm::LlmRequest; -use nemo_relay::api::runtime::{ - LlmExecutionNextFn, LlmStreamExecutionNextFn, NemoRelayContextState, ToolExecutionNextFn, - global_context, -}; -use nemo_relay::plugin::{ - PluginComponentSpec, PluginConfig, clear_plugin_configuration, initialize_plugins, -}; +use nemo_relay::api::runtime::{LlmExecutionNextFn, LlmStreamExecutionNextFn, ToolExecutionNextFn}; fn load_module<'py>(py: Python<'py>, code: &str) -> Bound<'py, PyModule> { let code = CString::new(code).unwrap(); @@ -41,212 +33,6 @@ fn load_module<'py>(py: Python<'py>, code: &str) -> Bound<'py, PyModule> { PyModule::from_code(py, &code, &file_name, &module_name).unwrap() } -fn python_package_dir() -> PathBuf { - PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../python") -} - -fn embedded_guardrails_local_source_path() -> PathBuf { - PathBuf::from(env!("CARGO_MANIFEST_DIR")) - .join("../core/src/plugins/nemo_guardrails/embedded_python/_guardrails_local.py") -} - -fn fake_guardrails_module_prelude(module_name: &str, python_dir: &str) -> String { - format!( - r#" -import sys -import types - -sys.path.insert(0, {python_dir:?}) - -MODULE_NAME = {module_name:?} - -fake_root = types.ModuleType(MODULE_NAME) -fake_root.__version__ = "0.22.0" -fake_options = types.ModuleType(MODULE_NAME + ".rails.llm.options") - -class Result: - def __init__(self, status, content=None, rail=None): - self.status = status - self.content = content - self.rail = rail - -class RailType: - INPUT = "input" - OUTPUT = "output" - -class RailStatus: - BLOCKED = "blocked" - MODIFIED = "modified" - PASSED = "passed" - -class RailsConfig: - @staticmethod - def from_content(*, colang_content=None, yaml_content=None): - return {{"yaml": yaml_content, "colang": colang_content}} - - @staticmethod - def from_path(path): - return {{"path": path}} -"#, - python_dir = python_dir, - module_name = module_name, - ) -} - -fn register_fake_guardrails_module_epilogue() -> &'static str { - r#" -fake_root.RailsConfig = RailsConfig -fake_root.LLMRails = LLMRails -fake_options.RailType = RailType -fake_options.RailStatus = RailStatus - -sys.modules[MODULE_NAME] = fake_root -sys.modules[MODULE_NAME + ".rails"] = types.ModuleType(MODULE_NAME + ".rails") -sys.modules[MODULE_NAME + ".rails.llm"] = types.ModuleType(MODULE_NAME + ".rails.llm") -sys.modules[MODULE_NAME + ".rails.llm.options"] = fake_options -"# -} - -fn local_plugin_context_python() -> &'static str { - r#" -class Context: - def register_llm_execution_intercept(self, name, priority, callback): - self.llm = callback - - def register_llm_stream_execution_intercept(self, name, priority, callback): - self.stream = callback - - def register_tool_execution_intercept(self, name, priority, callback): - self.tool = callback -"# -} - -fn embedded_guardrails_local_loader_python(source_path: &str) -> String { - format!( - r#" -import pathlib -import sys -import types - -from nemo_relay._native import LLMRequest - -class _AnnotatedRequest: - def __init__(self, messages): - self.messages = [dict(message) for message in messages] - -class _AnnotatedResponse: - def __init__(self, response): - self._response = response - - def response_text(self): - try: - return self._response["choices"][0]["message"]["content"] - except Exception: - return None - -class _BaseCodec: - def decode(self, request): - return _AnnotatedRequest(request.content.get("messages", [])) - - def encode(self, annotated, original): - content = dict(original.content) - content["messages"] = annotated.messages - return LLMRequest(original.headers, content) - - def decode_response(self, response): - return _AnnotatedResponse(response) - -class OpenAIChatCodec(_BaseCodec): - pass - -class OpenAIResponsesCodec(_BaseCodec): - pass - -class AnthropicMessagesCodec(_BaseCodec): - pass - -class PluginContext: - pass - -runtime_module = types.ModuleType("_nemo_guardrails_local_runtime") -runtime_module.LLMRequest = LLMRequest -runtime_module.OpenAIChatCodec = OpenAIChatCodec -runtime_module.OpenAIResponsesCodec = OpenAIResponsesCodec -runtime_module.AnthropicMessagesCodec = AnthropicMessagesCodec -runtime_module.PluginContext = PluginContext -sys.modules["_nemo_guardrails_local_runtime"] = runtime_module - -GUARDRAILS_LOCAL_SOURCE_PATH = pathlib.Path({source_path:?}) -guardrails_local_module = types.ModuleType("_nemo_guardrails_local") -guardrails_local_module.__file__ = str(GUARDRAILS_LOCAL_SOURCE_PATH) -guardrails_local_module.__package__ = "" -sys.modules["_nemo_guardrails_local"] = guardrails_local_module -exec( - compile( - GUARDRAILS_LOCAL_SOURCE_PATH.read_text(), - str(GUARDRAILS_LOCAL_SOURCE_PATH), - "exec", - ), - guardrails_local_module.__dict__, -) -"#, - source_path = source_path, - ) -} - -fn with_isolated_nemo_relay_modules( - py: Python<'_>, - native_module: &Bound<'_, PyModule>, - f: impl FnOnce() -> T, -) -> T { - let sys = py.import("sys").unwrap(); - let modules = sys - .getattr("modules") - .unwrap() - .cast_into::() - .unwrap(); - let saved_modules = modules - .iter() - .filter_map(|(name, module)| { - let name = name.extract::().ok()?; - if name == "nemo_relay" || name.starts_with("nemo_relay.") { - Some((name, module.unbind())) - } else { - None - } - }) - .collect::>(); - - clear_nemo_relay_modules(&modules); - modules - .set_item("nemo_relay._native", native_module.clone()) - .unwrap(); - - let result = catch_unwind(AssertUnwindSafe(f)); - - clear_nemo_relay_modules(&modules); - for (name, module) in saved_modules { - modules.set_item(name, module).unwrap(); - } - - match result { - Ok(value) => value, - Err(payload) => std::panic::resume_unwind(payload), - } -} - -fn clear_nemo_relay_modules(modules: &Bound<'_, PyDict>) { - let module_names = modules - .iter() - .filter_map(|(name, _)| name.extract::().ok()) - .filter(|name| name == "nemo_relay" || name.starts_with("nemo_relay.")) - .collect::>(); - - for name in module_names { - modules.del_item(name).unwrap(); - } -} - fn make_request() -> LlmRequest { LlmRequest { headers: serde_json::Map::from_iter([("x-trace".into(), json!("1"))]), @@ -279,12 +65,6 @@ fn with_event_loop(py: Python<'_>, f: impl FnOnce(Bound<'_, PyAny>) -> T) -> result } -fn reset_runtime_state() { - let _ = clear_plugin_configuration(); - let context = global_context(); - *context.write().unwrap() = NemoRelayContextState::new(); -} - #[test] fn test_native_module_registers_types_and_api_functions() { let _python = crate::test_support::init_python_test(); @@ -314,634 +94,6 @@ fn test_native_pymodule_entrypoint_registers_bindings() { }); } -#[test] -fn test_native_pymodule_entrypoint_registers_bindings_without_local_provider_install() { - let _python = crate::test_support::init_python_test(); - Python::attach(|py| { - let module = PyModule::new(py, "_native_guardrails_provider").unwrap(); - crate::_native(&module).unwrap(); - }); - - let runtime = tokio::runtime::Runtime::new().unwrap(); - let error = runtime - .block_on(initialize_plugins(PluginConfig { - version: 1, - components: vec![PluginComponentSpec { - kind: "nemo_guardrails".to_string(), - enabled: true, - config: serde_json::from_value(json!({ - "mode": "local", - "codec": "openai_chat", - "config_path": "./rails" - })) - .unwrap(), - }], - policy: Default::default(), - })) - .unwrap_err(); - - let _ = clear_plugin_configuration(); - match error { - nemo_relay::plugin::PluginError::RegistrationFailed(message) => { - assert!( - message.contains( - "NeMo Guardrails is required for the built-in NeMo Guardrails local backend" - ), - "unexpected message: {message}" - ); - } - other => panic!("unexpected error: {other}"), - } -} - -#[test] -fn test_guardrails_local_helper_registers_and_enforces_llm_and_tool_checks() { - let _python = crate::test_support::init_python_test(); - Python::attach(|py| { - let native_module = PyModule::new(py, "_native_guardrails_helper").unwrap(); - crate::_native(&native_module).unwrap(); - - with_isolated_nemo_relay_modules(py, &native_module, || { - let python_dir = python_package_dir(); - let prelude = fake_guardrails_module_prelude( - "fake_guardrails_local_helper", - &python_dir.display().to_string(), - ); - let epilogue = register_fake_guardrails_module_epilogue(); - let context_class = local_plugin_context_python(); - let embedded_loader = embedded_guardrails_local_loader_python( - &embedded_guardrails_local_source_path() - .display() - .to_string(), - ); - let module = load_module( - py, - &format!( - r#" -{prelude} - -check_results = [] -check_calls = [] - -class LLMRails: - def __init__(self, config): - self.config = config - - async def check_async(self, messages, rail_types): - check_calls.append((messages, rail_types)) - return check_results.pop(0) - -{epilogue} - -{embedded_loader} - -from nemo_relay._native import LLMRequest -from _nemo_guardrails_local import register_local_backend - -{context_class} - -async def run_case(): - ctx = Context() - register_local_backend( - {{ - "mode": "local", - "codec": "openai_chat", - "config_yaml": "models: []", - "input": True, - "output": True, - "tool_input": True, - "tool_output": True, - "local": {{"python_module": MODULE_NAME}}, - }}, - ctx, - ) - - request = LLMRequest( - {{}}, - {{ - "model": "gpt-4o-mini", - "messages": [{{"role": "user", "content": "unsafe"}}], - }}, - ) - seen_request_messages = [] - - async def next_call(req): - seen_request_messages.append(req.content["messages"][-1]["content"]) - return {{ - "choices": [{{"message": {{"role": "assistant", "content": "safe reply"}}}}], - "id": "resp_1", - "model": "gpt-4o-mini", - }} - - check_results.extend( - [ - Result(RailStatus.MODIFIED, content="sanitized user"), - Result(RailStatus.PASSED), - ] - ) - llm_result = await ctx.llm("demo", request, next_call) - - seen_tool_args = [] - - async def next_tool(args): - seen_tool_args.append(args) - return {{"raw": True}} - - check_results.extend( - [ - Result(RailStatus.MODIFIED, content='{{"arguments": {{"city": "Boston"}}}}'), - Result(RailStatus.MODIFIED, content='{{"result": {{"ok": true}}}}'), - ] - ) - tool_result = await ctx.tool("weather_lookup", {{"city": "Phoenix"}}, next_tool) - - return {{ - "llm_result": llm_result, - "tool_result": tool_result, - "seen_request_messages": seen_request_messages, - "seen_tool_args": seen_tool_args, - "check_calls": check_calls, - }} -"#, - prelude = prelude, - epilogue = epilogue, - context_class = context_class, - embedded_loader = embedded_loader, - ), - ); - - let result_json = with_event_loop(py, |event_loop| { - let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); - let result = event_loop - .call_method1("run_until_complete", (coroutine,)) - .unwrap(); - crate::convert::py_to_json(&result).unwrap() - }); - - assert_eq!( - result_json["seen_request_messages"][0], - json!("sanitized user") - ); - assert_eq!(result_json["tool_result"], json!({ "ok": true })); - assert_eq!( - result_json["seen_tool_args"][0], - json!({ "city": "Boston" }) - ); - assert_eq!( - result_json["llm_result"]["choices"][0]["message"]["content"], - json!("safe reply") - ); - assert_eq!( - result_json["check_calls"], - json!([ - [ - [{"role": "user", "content": "unsafe"}], - ["input"] - ], - [ - [ - {"role": "user", "content": "sanitized user"}, - {"role": "assistant", "content": "safe reply"} - ], - ["output"] - ], - [ - [{"role": "user", "content": "{\"arguments\":{\"city\":\"Phoenix\"},\"tool_name\":\"weather_lookup\"}"}], - ["input"] - ], - [ - [ - {"role": "user", "content": "{\"arguments\":{\"city\":\"Boston\"},\"tool_name\":\"weather_lookup\"}"}, - {"role": "assistant", "content": "{\"arguments\":{\"city\":\"Boston\"},\"result\":{\"raw\":true},\"tool_name\":\"weather_lookup\"}"} - ], - ["output"] - ] - ]) - ); - }); - }); -} - -#[test] -fn test_guardrails_local_helper_rejects_unsupported_nemoguardrails_version() { - let _python = crate::test_support::init_python_test(); - Python::attach(|py| { - let native_module = PyModule::new(py, "_native_guardrails_version").unwrap(); - crate::_native(&native_module).unwrap(); - - with_isolated_nemo_relay_modules(py, &native_module, || { - let python_dir = python_package_dir(); - let prelude = fake_guardrails_module_prelude( - "fake_guardrails_bad_version", - &python_dir.display().to_string(), - ); - let epilogue = register_fake_guardrails_module_epilogue(); - let context_class = local_plugin_context_python(); - let embedded_loader = embedded_guardrails_local_loader_python( - &embedded_guardrails_local_source_path() - .display() - .to_string(), - ); - let module = load_module( - py, - &format!( - r#" -{prelude} - -fake_root.__version__ = "0.21.0" - -class LLMRails: - def __init__(self, config): - self.config = config - - async def check_async(self, messages, rail_types): - return Result(RailStatus.PASSED) - -{epilogue} - -{embedded_loader} - -from _nemo_guardrails_local import register_local_backend - -{context_class} - -async def run_case(): - ctx = Context() - register_local_backend( - {{ - "mode": "local", - "codec": "openai_chat", - "config_yaml": "models: []", - "input": True, - "local": {{"python_module": MODULE_NAME}}, - }}, - ctx, - ) -"#, - prelude = prelude, - epilogue = epilogue, - embedded_loader = embedded_loader, - context_class = context_class, - ), - ); - - let error = with_event_loop(py, |event_loop| { - let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); - event_loop - .call_method1("run_until_complete", (coroutine,)) - .unwrap_err() - .to_string() - }); - - assert!( - error.contains("requires nemoguardrails==0.22.0"), - "unexpected error: {error}" - ); - assert!(error.contains("0.21.0"), "unexpected error: {error}"); - }); - }); -} - -#[test] -fn test_guardrails_local_helper_enforces_streamed_output_rails() { - let _python = crate::test_support::init_python_test(); - Python::attach(|py| { - let native_module = PyModule::new(py, "_native_guardrails_streaming").unwrap(); - crate::_native(&native_module).unwrap(); - - with_isolated_nemo_relay_modules(py, &native_module, || { - let python_dir = python_package_dir(); - let prelude = fake_guardrails_module_prelude( - "fake_guardrails_streaming", - &python_dir.display().to_string(), - ); - let epilogue = register_fake_guardrails_module_epilogue(); - let context_class = local_plugin_context_python(); - let embedded_loader = embedded_guardrails_local_loader_python( - &embedded_guardrails_local_source_path() - .display() - .to_string(), - ); - let module = load_module( - py, - &format!( - r#" -{prelude} - -stream_results = [] -event_log = [] - -class LLMRails: - def __init__(self, config): - self.config = types.SimpleNamespace( - rails=types.SimpleNamespace( - output=types.SimpleNamespace( - flows=["self check output"], - streaming=types.SimpleNamespace(enabled=True, stream_first=True), - ) - ) - ) - - async def check_async(self, messages, rail_types): - return Result(RailStatus.PASSED) - - def stream_async(self, *, messages=None, generator=None, include_metadata=False): - async def _run(): - outcome = stream_results.pop(0) - async for chunk in generator: - event_log.append(f"guardrails-sees:{{chunk}}") - if outcome == "pass": - yield chunk - if outcome == "block": - yield '{{"error": {{"message": "Blocked by output rails: output-policy", "type": "guardrails_violation"}}}}' - return _run() - -{epilogue} - -{embedded_loader} - -from nemo_relay._native import LLMRequest -from _nemo_guardrails_local import register_local_backend - -{context_class} - -async def run_case(): - ctx = Context() - event_log.clear() - register_local_backend( - {{ - "mode": "local", - "codec": "openai_chat", - "config_yaml": "models: []", - "input": False, - "output": True, - "local": {{"python_module": MODULE_NAME}}, - }}, - ctx, - ) - - request = LLMRequest( - {{}}, - {{ - "model": "gpt-4o-mini", - "messages": [{{"role": "user", "content": "hello"}}], - }}, - ) - - async def next_call(req): - async def _stream(): - event_log.append("source:hello") - yield {{"choices": [{{"delta": {{"content": "hello"}}}}]}} - event_log.append("source:world") - yield {{"choices": [{{"delta": {{"content": "world"}}}}]}} - return _stream() - - stream_results.append("pass") - allowed_stream = await ctx.stream(request, next_call) - allowed_chunks = [] - async for chunk in allowed_stream: - event_log.append(f"yield:{{chunk['choices'][0]['delta']['content']}}") - allowed_chunks.append(chunk) - - stream_results.append("block") - try: - blocked_stream = await ctx.stream(request, next_call) - async for _chunk in blocked_stream: - pass - except RuntimeError as error: - blocked = str(error) - else: - raise AssertionError("expected streamed output block") - - ctx_stream_first_false = Context() - fake_root.LLMRails = lambda config: types.SimpleNamespace( - config=types.SimpleNamespace( - rails=types.SimpleNamespace( - output=types.SimpleNamespace( - flows=["self check output"], - streaming=types.SimpleNamespace(enabled=True, stream_first=False), - ) - ) - ), - check_async=LLMRails(config).check_async, - stream_async=LLMRails(config).stream_async, - ) - register_local_backend( - {{ - "mode": "local", - "codec": "openai_chat", - "config_yaml": "models: []", - "input": False, - "output": True, - "local": {{"python_module": MODULE_NAME}}, - }}, - ctx_stream_first_false, - ) - try: - failing_stream = await ctx_stream_first_false.stream(request, next_call) - async for _chunk in failing_stream: - pass - except RuntimeError as error: - modified = str(error) - else: - raise AssertionError("expected stream_first=false error") - - return {{ - "allowed_chunks": allowed_chunks, - "blocked": blocked, - "event_log": event_log, - "modified": modified, - }} -"#, - prelude = prelude, - epilogue = epilogue, - context_class = context_class, - embedded_loader = embedded_loader, - ), - ); - - let result = with_event_loop(py, |event_loop| { - let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); - let result = event_loop - .call_method1("run_until_complete", (coroutine,)) - .unwrap(); - crate::convert::py_to_json(&result).unwrap() - }); - assert_eq!( - result["allowed_chunks"], - json!([ - {"choices": [{"delta": {"content": "hello"}}]}, - {"choices": [{"delta": {"content": "world"}}]} - ]) - ); - let event_log = result["event_log"].as_array().unwrap(); - assert_eq!( - &event_log[..6], - json!([ - "source:hello", - "yield:hello", - "source:world", - "yield:world", - "guardrails-sees:hello", - "guardrails-sees:world", - ]) - .as_array() - .unwrap() - ); - assert!( - result["blocked"] - .as_str() - .unwrap() - .contains("output rail blocked the LLM call") - ); - assert!( - result["modified"] - .as_str() - .unwrap() - .contains("stream_first = true") - ); - }); - }); -} - -#[test] -fn test_local_guardrails_provider_initializes_and_enforces_managed_core_calls() { - let _python = crate::test_support::init_python_test(); - reset_runtime_state(); - - Python::attach(|py| { - let native_module = PyModule::new(py, "_native_guardrails_e2e").unwrap(); - crate::_native(&native_module).unwrap(); - - with_isolated_nemo_relay_modules(py, &native_module, || { - let python_dir = python_package_dir(); - let prelude = fake_guardrails_module_prelude( - "fake_guardrails_local_e2e", - &python_dir.display().to_string(), - ); - let epilogue = register_fake_guardrails_module_epilogue(); - let module = load_module( - py, - &format!( - r#" -{prelude} - -check_results = [] - -class LLMRails: - def __init__(self, config): - self.config = config - - async def check_async(self, messages, rail_types): - return check_results.pop(0) - -{epilogue} - -import nemo_relay - -async def run_case(): - stack = nemo_relay.create_scope_stack() - nemo_relay.set_thread_scope_stack(stack) - - await nemo_relay.plugin.initialize( - {{ - "version": 1, - "components": [ - {{ - "kind": "nemo_guardrails", - "enabled": True, - "config": {{ - "mode": "local", - "codec": "openai_chat", - "config_yaml": "models: []", - "input": True, - "output": True, - "tool_input": True, - "tool_output": True, - "local": {{"python_module": MODULE_NAME}}, - }}, - }} - ], - }} - ) - - check_results.extend( - [ - Result(RailStatus.MODIFIED, content="sanitized user"), - Result(RailStatus.PASSED), - Result(RailStatus.MODIFIED, content='{{"arguments": {{"city": "Boston"}}}}'), - Result(RailStatus.MODIFIED, content='{{"result": {{"ok": true}}}}'), - ] - ) - - request = nemo_relay.LLMRequest( - {{}}, - {{ - "model": "gpt-4o-mini", - "messages": [{{"role": "user", "content": "unsafe"}}], - }}, - ) - - seen_request_messages = [] - async def llm_impl(req): - seen_request_messages.append(req.content["messages"][-1]["content"]) - return {{ - "choices": [{{"message": {{"role": "assistant", "content": "safe reply"}}}}], - "id": "resp_1", - "model": req.content["model"], - }} - - llm_result = await nemo_relay.llm.execute( - "demo", - request, - llm_impl, - response_codec=nemo_relay.codecs.OpenAIChatCodec(), - ) - - seen_tool_args = [] - async def tool_impl(args): - seen_tool_args.append(args) - return {{"raw": True}} - - tool_result = await nemo_relay.tools.execute("weather_lookup", {{"city": "Phoenix"}}, tool_impl) - return {{ - "llm_result": llm_result, - "tool_result": tool_result, - "seen_request_messages": seen_request_messages, - "seen_tool_args": seen_tool_args, - }} -"#, - prelude = prelude, - epilogue = epilogue, - ), - ); - let result_json = with_event_loop(py, |event_loop| { - let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); - let result = event_loop - .call_method1("run_until_complete", (coroutine,)) - .unwrap(); - crate::convert::py_to_json(&result).unwrap() - }); - - assert_eq!( - result_json["llm_result"]["choices"][0]["message"]["content"], - json!("safe reply") - ); - assert_eq!(result_json["tool_result"], json!({ "ok": true })); - assert_eq!( - result_json["seen_request_messages"][0], - json!("sanitized user") - ); - assert_eq!( - result_json["seen_tool_args"][0], - json!({ "city": "Boston" }) - ); - }); - }); - - reset_runtime_state(); -} - #[test] fn test_python_test_guard_restores_existing_runtime_env() { let lock = crate::test_support::lock_python_test(); diff --git a/crates/python/tests/coverage/nemo_guardrails_coverage_tests.rs b/crates/python/tests/coverage/nemo_guardrails_coverage_tests.rs new file mode 100644 index 00000000..7c1c3893 --- /dev/null +++ b/crates/python/tests/coverage/nemo_guardrails_coverage_tests.rs @@ -0,0 +1,819 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Coverage tests for Python-facing local NeMo Guardrails integration. + +use std::ffi::CString; +use std::panic::{AssertUnwindSafe, catch_unwind}; +use std::path::PathBuf; + +use nemo_relay::api::runtime::{NemoRelayContextState, global_context}; +use nemo_relay::plugin::{ + PluginComponentSpec, PluginConfig, clear_plugin_configuration, initialize_plugins, +}; +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyModule}; +use serde_json::json; + +fn load_module<'py>(py: Python<'py>, code: &str) -> Bound<'py, PyModule> { + let code = CString::new(code).unwrap(); + let file_name = CString::new("nemo_guardrails_coverage_tests.py").unwrap(); + let module_name = CString::new("nemo_guardrails_coverage_tests").unwrap(); + PyModule::from_code(py, &code, &file_name, &module_name).unwrap() +} + +fn python_package_dir() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../python") +} + +fn fake_guardrails_module_prelude(module_name: &str, python_dir: &str) -> String { + format!( + r#" +import sys +import types + +sys.path.insert(0, {python_dir:?}) + +MODULE_NAME = {module_name:?} + +fake_root = types.ModuleType(MODULE_NAME) +fake_root.__version__ = "0.22.0" +fake_options = types.ModuleType(MODULE_NAME + ".rails.llm.options") + +class Result: + def __init__(self, status, content=None, rail=None): + self.status = status + self.content = content + self.rail = rail + +class RailType: + INPUT = "input" + OUTPUT = "output" + +class RailStatus: + BLOCKED = "blocked" + MODIFIED = "modified" + PASSED = "passed" + +class RailsConfig: + @staticmethod + def from_content(*, colang_content=None, yaml_content=None): + return {{"yaml": yaml_content, "colang": colang_content}} + + @staticmethod + def from_path(path): + return {{"path": path}} +"#, + python_dir = python_dir, + module_name = module_name, + ) +} + +fn register_fake_guardrails_module_epilogue() -> &'static str { + r#" +fake_root.RailsConfig = RailsConfig +fake_root.LLMRails = LLMRails +fake_options.RailType = RailType +fake_options.RailStatus = RailStatus + +sys.modules[MODULE_NAME] = fake_root +sys.modules[MODULE_NAME + ".rails"] = types.ModuleType(MODULE_NAME + ".rails") +sys.modules[MODULE_NAME + ".rails.llm"] = types.ModuleType(MODULE_NAME + ".rails.llm") +sys.modules[MODULE_NAME + ".rails.llm.options"] = fake_options +"# +} + +fn with_isolated_nemo_relay_modules( + py: Python<'_>, + native_module: &Bound<'_, PyModule>, + f: impl FnOnce() -> T, +) -> T { + let sys = py.import("sys").unwrap(); + let modules = sys + .getattr("modules") + .unwrap() + .cast_into::() + .unwrap(); + let saved_modules = modules + .iter() + .filter_map(|(name, module)| { + let name = name.extract::().ok()?; + if name == "nemo_relay" || name.starts_with("nemo_relay.") { + Some((name, module.unbind())) + } else { + None + } + }) + .collect::>(); + + clear_nemo_relay_modules(&modules); + modules + .set_item("nemo_relay._native", native_module.clone()) + .unwrap(); + + let result = catch_unwind(AssertUnwindSafe(f)); + + clear_nemo_relay_modules(&modules); + for (name, module) in saved_modules { + modules.set_item(name, module).unwrap(); + } + + match result { + Ok(value) => value, + Err(payload) => std::panic::resume_unwind(payload), + } +} + +fn clear_nemo_relay_modules(modules: &Bound<'_, PyDict>) { + let module_names = modules + .iter() + .filter_map(|(name, _)| name.extract::().ok()) + .filter(|name| name == "nemo_relay" || name.starts_with("nemo_relay.")) + .collect::>(); + + for name in module_names { + modules.del_item(name).unwrap(); + } +} + +fn with_event_loop(py: Python<'_>, f: impl FnOnce(Bound<'_, PyAny>) -> T) -> T { + let asyncio = py.import("asyncio").unwrap(); + #[cfg(windows)] + { + let policy = asyncio + .getattr("WindowsSelectorEventLoopPolicy") + .unwrap() + .call0() + .unwrap(); + asyncio + .call_method1("set_event_loop_policy", (policy,)) + .unwrap(); + } + let event_loop = asyncio.call_method0("new_event_loop").unwrap(); + asyncio + .call_method1("set_event_loop", (&event_loop,)) + .unwrap(); + let result = f(event_loop.clone().into_any()); + asyncio + .call_method1("set_event_loop", (py.None(),)) + .unwrap(); + event_loop.call_method0("close").unwrap(); + result +} + +fn reset_runtime_state() { + let _ = clear_plugin_configuration(); + let context = global_context(); + *context.write().unwrap() = NemoRelayContextState::new(); +} + +#[test] +fn test_native_pymodule_entrypoint_registers_bindings_without_local_provider_install() { + let _python = crate::test_support::init_python_test(); + reset_runtime_state(); + Python::attach(|py| { + let module = PyModule::new(py, "_native_guardrails_provider").unwrap(); + crate::_native(&module).unwrap(); + }); + + let runtime = tokio::runtime::Runtime::new().unwrap(); + let error = runtime + .block_on(initialize_plugins(PluginConfig { + version: 1, + components: vec![PluginComponentSpec { + kind: "nemo_guardrails".to_string(), + enabled: true, + config: serde_json::from_value(json!({ + "mode": "local", + "codec": "openai_chat", + "config_path": "./rails" + })) + .unwrap(), + }], + policy: Default::default(), + })) + .unwrap_err(); + + reset_runtime_state(); + match error { + nemo_relay::plugin::PluginError::RegistrationFailed(message) => { + assert!( + message.contains( + "NeMo Guardrails is required for the built-in NeMo Guardrails local backend" + ), + "unexpected message: {message}" + ); + } + other => panic!("unexpected error: {other}"), + } +} + +#[test] +fn test_guardrails_local_runtime_registers_and_enforces_llm_and_tool_checks() { + let _python = crate::test_support::init_python_test(); + reset_runtime_state(); + + Python::attach(|py| { + let native_module = PyModule::new(py, "_native_guardrails_local_runtime").unwrap(); + crate::_native(&native_module).unwrap(); + + with_isolated_nemo_relay_modules(py, &native_module, || { + let python_dir = python_package_dir(); + let prelude = fake_guardrails_module_prelude( + "fake_guardrails_local_runtime", + &python_dir.display().to_string(), + ); + let epilogue = register_fake_guardrails_module_epilogue(); + let module = load_module( + py, + &format!( + r#" +{prelude} + +check_results = [] +check_calls = [] + +class LLMRails: + def __init__(self, config): + self.config = config + + async def check_async(self, messages, rail_types): + check_calls.append((messages, rail_types)) + return check_results.pop(0) + +{epilogue} + +import nemo_relay + +async def run_case(): + stack = nemo_relay.create_scope_stack() + nemo_relay.set_thread_scope_stack(stack) + await nemo_relay.plugin.initialize( + {{ + "version": 1, + "components": [ + {{ + "kind": "nemo_guardrails", + "enabled": True, + "config": {{ + "mode": "local", + "codec": "openai_chat", + "config_yaml": "models: []", + "input": True, + "output": True, + "tool_input": True, + "tool_output": True, + "local": {{"python_module": MODULE_NAME}}, + }}, + }} + ], + }} + ) + + request = nemo_relay.LLMRequest( + {{}}, + {{ + "model": "gpt-4o-mini", + "messages": [{{"role": "user", "content": "unsafe"}}], + }}, + ) + seen_request_messages = [] + + async def next_call(req): + seen_request_messages.append(req.content["messages"][-1]["content"]) + return {{ + "choices": [{{"message": {{"role": "assistant", "content": "safe reply"}}}}], + "id": "resp_1", + "model": "gpt-4o-mini", + }} + + check_results.extend( + [ + Result(RailStatus.MODIFIED, content="sanitized user"), + Result(RailStatus.PASSED), + ] + ) + llm_result = await nemo_relay.llm.execute( + "demo", + request, + next_call, + response_codec=nemo_relay.codecs.OpenAIChatCodec(), + ) + + seen_tool_args = [] + + async def next_tool(args): + seen_tool_args.append(args) + return {{"raw": True}} + + check_results.extend( + [ + Result(RailStatus.MODIFIED, content='{{"arguments": {{"city": "Boston"}}}}'), + Result(RailStatus.MODIFIED, content='{{"result": {{"ok": true}}}}'), + ] + ) + tool_result = await nemo_relay.tools.execute("weather_lookup", {{"city": "Phoenix"}}, next_tool) + + return {{ + "llm_result": llm_result, + "tool_result": tool_result, + "seen_request_messages": seen_request_messages, + "seen_tool_args": seen_tool_args, + "check_calls": check_calls, + }} +"#, + prelude = prelude, + epilogue = epilogue, + ), + ); + + let result_json = with_event_loop(py, |event_loop| { + let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); + let result = event_loop + .call_method1("run_until_complete", (coroutine,)) + .unwrap(); + crate::convert::py_to_json(&result).unwrap() + }); + + assert_eq!( + result_json["seen_request_messages"][0], + json!("sanitized user") + ); + assert_eq!(result_json["tool_result"], json!({ "ok": true })); + assert_eq!( + result_json["seen_tool_args"][0], + json!({ "city": "Boston" }) + ); + assert_eq!( + result_json["llm_result"]["choices"][0]["message"]["content"], + json!("safe reply") + ); + assert_eq!( + result_json["check_calls"], + json!([ + [ + [{"role": "user", "content": "unsafe"}], + ["input"] + ], + [ + [ + {"role": "user", "content": "sanitized user"}, + {"role": "assistant", "content": "safe reply"} + ], + ["output"] + ], + [ + [{"role": "user", "content": "{\"arguments\":{\"city\":\"Phoenix\"},\"tool_name\":\"weather_lookup\"}"}], + ["input"] + ], + [ + [ + {"role": "user", "content": "{\"arguments\":{\"city\":\"Boston\"},\"tool_name\":\"weather_lookup\"}"}, + {"role": "assistant", "content": "{\"arguments\":{\"city\":\"Boston\"},\"result\":{\"raw\":true},\"tool_name\":\"weather_lookup\"}"} + ], + ["output"] + ] + ]) + ); + }); + }); + + reset_runtime_state(); +} + +#[test] +fn test_guardrails_local_runtime_rejects_unsupported_nemoguardrails_version() { + let _python = crate::test_support::init_python_test(); + reset_runtime_state(); + + Python::attach(|py| { + let native_module = PyModule::new(py, "_native_guardrails_version").unwrap(); + crate::_native(&native_module).unwrap(); + + with_isolated_nemo_relay_modules(py, &native_module, || { + let python_dir = python_package_dir(); + let prelude = fake_guardrails_module_prelude( + "fake_guardrails_bad_version", + &python_dir.display().to_string(), + ); + let epilogue = register_fake_guardrails_module_epilogue(); + let module = load_module( + py, + &format!( + r#" +{prelude} + +fake_root.__version__ = "0.21.0" + +class LLMRails: + def __init__(self, config): + self.config = config + + async def check_async(self, messages, rail_types): + return Result(RailStatus.PASSED) + +{epilogue} + +import nemo_relay + +async def run_case(): + await nemo_relay.plugin.initialize( + {{ + "version": 1, + "components": [ + {{ + "kind": "nemo_guardrails", + "enabled": True, + "config": {{ + "mode": "local", + "codec": "openai_chat", + "config_yaml": "models: []", + "input": True, + "local": {{"python_module": MODULE_NAME}}, + }}, + }} + ], + }} + ) +"#, + prelude = prelude, + epilogue = epilogue, + ), + ); + + let error = with_event_loop(py, |event_loop| { + let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); + event_loop + .call_method1("run_until_complete", (coroutine,)) + .unwrap_err() + .to_string() + }); + + assert!( + error.contains("requires nemoguardrails==0.22.0"), + "unexpected error: {error}" + ); + assert!(error.contains("0.21.0"), "unexpected error: {error}"); + }); + }); + + reset_runtime_state(); +} + +#[test] +fn test_guardrails_local_runtime_enforces_streamed_output_rails() { + let _python = crate::test_support::init_python_test(); + reset_runtime_state(); + + Python::attach(|py| { + let native_module = PyModule::new(py, "_native_guardrails_streaming").unwrap(); + crate::_native(&native_module).unwrap(); + + with_isolated_nemo_relay_modules(py, &native_module, || { + let python_dir = python_package_dir(); + let prelude = fake_guardrails_module_prelude( + "fake_guardrails_streaming", + &python_dir.display().to_string(), + ); + let epilogue = register_fake_guardrails_module_epilogue(); + let module = load_module( + py, + &format!( + r#" +{prelude} + +stream_results = [] +event_log = [] + +class LLMRails: + def __init__(self, config): + self.config = types.SimpleNamespace( + rails=types.SimpleNamespace( + output=types.SimpleNamespace( + flows=["self check output"], + streaming=types.SimpleNamespace(enabled=True, stream_first=True), + ) + ) + ) + + async def check_async(self, messages, rail_types): + return Result(RailStatus.PASSED) + + def stream_async(self, *, messages=None, generator=None, include_metadata=False): + async def _run(): + outcome = stream_results.pop(0) + async for chunk in generator: + event_log.append(f"guardrails-sees:{{chunk}}") + if outcome == "pass": + yield chunk + if outcome == "block": + yield '{{"error": {{"message": "Blocked by output rails: output-policy", "type": "guardrails_violation"}}}}' + return _run() + +{epilogue} + +import nemo_relay + +def plugin_config(): + return {{ + "version": 1, + "components": [ + {{ + "kind": "nemo_guardrails", + "enabled": True, + "config": {{ + "mode": "local", + "codec": "openai_chat", + "config_yaml": "models: []", + "input": False, + "output": True, + "local": {{"python_module": MODULE_NAME}}, + }}, + }} + ], + }} + +async def run_stream(request): + collected = [] + + def next_call(req): + async def _stream(): + event_log.append("source:hello") + yield {{"choices": [{{"delta": {{"content": "hello"}}}}]}} + event_log.append("source:world") + yield {{"choices": [{{"delta": {{"content": "world"}}}}]}} + return _stream() + + stream = await nemo_relay.llm.stream_execute( + "demo", + request, + next_call, + collected.append, + lambda: {{"chunks": collected}}, + response_codec=nemo_relay.codecs.OpenAIChatCodec(), + ) + chunks = [] + async for chunk in stream: + event_log.append(f"yield:{{chunk['choices'][0]['delta']['content']}}") + chunks.append(chunk) + return chunks + +async def run_case(): + stack = nemo_relay.create_scope_stack() + nemo_relay.set_thread_scope_stack(stack) + event_log.clear() + await nemo_relay.plugin.initialize(plugin_config()) + + request = nemo_relay.LLMRequest( + {{}}, + {{ + "model": "gpt-4o-mini", + "messages": [{{"role": "user", "content": "hello"}}], + }}, + ) + + stream_results.append("pass") + allowed_chunks = await run_stream(request) + + stream_results.append("block") + try: + await run_stream(request) + except RuntimeError as error: + blocked = str(error) + else: + raise AssertionError("expected streamed output block") + + nemo_relay.plugin.clear() + fake_root.LLMRails = lambda config: types.SimpleNamespace( + config=types.SimpleNamespace( + rails=types.SimpleNamespace( + output=types.SimpleNamespace( + flows=["self check output"], + streaming=types.SimpleNamespace(enabled=True, stream_first=False), + ) + ) + ), + check_async=LLMRails(config).check_async, + stream_async=LLMRails(config).stream_async, + ) + await nemo_relay.plugin.initialize(plugin_config()) + try: + await run_stream(request) + except RuntimeError as error: + modified = str(error) + else: + raise AssertionError("expected stream_first=false error") + + return {{ + "allowed_chunks": allowed_chunks, + "blocked": blocked, + "event_log": event_log, + "modified": modified, + }} +"#, + prelude = prelude, + epilogue = epilogue, + ), + ); + + let result = with_event_loop(py, |event_loop| { + let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); + let result = event_loop + .call_method1("run_until_complete", (coroutine,)) + .unwrap(); + crate::convert::py_to_json(&result).unwrap() + }); + assert_eq!( + result["allowed_chunks"], + json!([ + {"choices": [{"delta": {"content": "hello"}}]}, + {"choices": [{"delta": {"content": "world"}}]} + ]) + ); + let event_log = result["event_log"].as_array().unwrap(); + for expected in [ + "source:hello", + "source:world", + "yield:hello", + "yield:world", + "guardrails-sees:hello", + "guardrails-sees:world", + ] { + assert!( + event_log.iter().any(|event| event == expected), + "missing event {expected}: {event_log:?}" + ); + } + let yield_hello = event_log + .iter() + .position(|event| event == "yield:hello") + .unwrap(); + let guardrails_hello = event_log + .iter() + .position(|event| event == "guardrails-sees:hello") + .unwrap(); + let yield_world = event_log + .iter() + .position(|event| event == "yield:world") + .unwrap(); + let guardrails_world = event_log + .iter() + .position(|event| event == "guardrails-sees:world") + .unwrap(); + assert!(yield_hello < guardrails_hello); + assert!(yield_world < guardrails_world); + assert!( + result["blocked"] + .as_str() + .unwrap() + .contains("output rail blocked the LLM call") + ); + assert!( + result["modified"] + .as_str() + .unwrap() + .contains("stream_first = true") + ); + }); + }); + + reset_runtime_state(); +} + +#[test] +fn test_local_guardrails_provider_initializes_and_enforces_managed_core_calls() { + let _python = crate::test_support::init_python_test(); + reset_runtime_state(); + + Python::attach(|py| { + let native_module = PyModule::new(py, "_native_guardrails_e2e").unwrap(); + crate::_native(&native_module).unwrap(); + + with_isolated_nemo_relay_modules(py, &native_module, || { + let python_dir = python_package_dir(); + let prelude = fake_guardrails_module_prelude( + "fake_guardrails_local_e2e", + &python_dir.display().to_string(), + ); + let epilogue = register_fake_guardrails_module_epilogue(); + let module = load_module( + py, + &format!( + r#" +{prelude} + +check_results = [] + +class LLMRails: + def __init__(self, config): + self.config = config + + async def check_async(self, messages, rail_types): + return check_results.pop(0) + +{epilogue} + +import nemo_relay + +async def run_case(): + stack = nemo_relay.create_scope_stack() + nemo_relay.set_thread_scope_stack(stack) + + await nemo_relay.plugin.initialize( + {{ + "version": 1, + "components": [ + {{ + "kind": "nemo_guardrails", + "enabled": True, + "config": {{ + "mode": "local", + "codec": "openai_chat", + "config_yaml": "models: []", + "input": True, + "output": True, + "tool_input": True, + "tool_output": True, + "local": {{"python_module": MODULE_NAME}}, + }}, + }} + ], + }} + ) + + check_results.extend( + [ + Result(RailStatus.MODIFIED, content="sanitized user"), + Result(RailStatus.PASSED), + Result(RailStatus.MODIFIED, content='{{"arguments": {{"city": "Boston"}}}}'), + Result(RailStatus.MODIFIED, content='{{"result": {{"ok": true}}}}'), + ] + ) + + request = nemo_relay.LLMRequest( + {{}}, + {{ + "model": "gpt-4o-mini", + "messages": [{{"role": "user", "content": "unsafe"}}], + }}, + ) + + seen_request_messages = [] + async def llm_impl(req): + seen_request_messages.append(req.content["messages"][-1]["content"]) + return {{ + "choices": [{{"message": {{"role": "assistant", "content": "safe reply"}}}}], + "id": "resp_1", + "model": req.content["model"], + }} + + llm_result = await nemo_relay.llm.execute( + "demo", + request, + llm_impl, + response_codec=nemo_relay.codecs.OpenAIChatCodec(), + ) + + seen_tool_args = [] + async def tool_impl(args): + seen_tool_args.append(args) + return {{"raw": True}} + + tool_result = await nemo_relay.tools.execute("weather_lookup", {{"city": "Phoenix"}}, tool_impl) + return {{ + "llm_result": llm_result, + "tool_result": tool_result, + "seen_request_messages": seen_request_messages, + "seen_tool_args": seen_tool_args, + }} +"#, + prelude = prelude, + epilogue = epilogue, + ), + ); + let result_json = with_event_loop(py, |event_loop| { + let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); + let result = event_loop + .call_method1("run_until_complete", (coroutine,)) + .unwrap(); + crate::convert::py_to_json(&result).unwrap() + }); + + assert_eq!( + result_json["llm_result"]["choices"][0]["message"]["content"], + json!("safe reply") + ); + assert_eq!(result_json["tool_result"], json!({ "ok": true })); + assert_eq!( + result_json["seen_request_messages"][0], + json!("sanitized user") + ); + assert_eq!( + result_json["seen_tool_args"][0], + json!({ "city": "Boston" }) + ); + }); + }); + + reset_runtime_state(); +} From 7e0b3af454a868b1f08beda666b6f7ae4a1d3644 Mon Sep 17 00:00:00 2001 From: Will Killian Date: Fri, 5 Jun 2026 09:00:11 -0400 Subject: [PATCH 11/20] refactor: gate fn impl on cfg rather than branch Signed-off-by: Will Killian --- .../core/src/plugins/nemo_guardrails/local.rs | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/crates/core/src/plugins/nemo_guardrails/local.rs b/crates/core/src/plugins/nemo_guardrails/local.rs index f46dea2f..0bc18480 100644 --- a/crates/core/src/plugins/nemo_guardrails/local.rs +++ b/crates/core/src/plugins/nemo_guardrails/local.rs @@ -1,23 +1,29 @@ // SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -use crate::plugin::{PluginError, PluginRegistrationContext, Result as PluginResult}; +use crate::plugin::{PluginRegistrationContext, Result as PluginResult}; + +#[cfg(not(feature = "python"))] +use crate::plugin::PluginError; use super::NeMoGuardrailsConfig; #[cfg(feature = "python")] mod python; +#[cfg(feature = "python")] pub(super) fn register_local_backend( config: NeMoGuardrailsConfig, ctx: &mut PluginRegistrationContext, ) -> PluginResult<()> { - #[cfg(feature = "python")] - { - return python::register_local_backend(config, ctx); - } + python::register_local_backend(config, ctx) +} - #[allow(unreachable_code)] +#[cfg(not(feature = "python"))] +pub(super) fn register_local_backend( + config: NeMoGuardrailsConfig, + ctx: &mut PluginRegistrationContext, +) -> PluginResult<()> { Err(PluginError::RegistrationFailed( "built-in NeMo Guardrails local backend is unavailable in this build".to_string(), )) From 9ad985ee371f379b8a2242c26df5df853089134a Mon Sep 17 00:00:00 2001 From: Will Killian Date: Fri, 5 Jun 2026 11:28:25 -0400 Subject: [PATCH 12/20] Centralize Python Rust dependency versions Signed-off-by: Will Killian --- Cargo.toml | 3 +++ crates/core/Cargo.toml | 6 +++--- crates/python/Cargo.toml | 6 +++--- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bf54af16..e45741b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,9 @@ nemo-relay = { version = "0.4.0", path = "crates/core", default-features = false nemo-relay-adaptive = { version = "0.4.0", path = "crates/adaptive" } nemo-relay-ffi = { version = "0.4.0", path = "crates/ffi" } nemo-relay-cli = { version = "0.4.0", path = "crates/cli" } +pyo3 = "0.28.2" +pyo3-async-runtimes = "0.28.0" +pythonize = "0.28.0" uuid = "=1.18.1" [workspace.lints.rust] diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index c69c8226..e5ca5b93 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -83,9 +83,9 @@ opentelemetry-http = { version = "0.31", default-features = false, optional = tr wasm-bindgen = { version = "0.2", optional = true } wasm-bindgen-futures = { version = "0.4", optional = true } web-sys = { version = "0.3", features = ["Headers", "Request", "RequestInit", "Response", "Window", "console"], optional = true } -pyo3 = { version = "0.28.2", features = ["auto-initialize"], optional = true } -pyo3-async-runtimes = { version = "0.28.0", features = ["tokio-runtime"], optional = true } -pythonize = { version = "0.28.0", optional = true } +pyo3 = { workspace = true, features = ["auto-initialize"], optional = true } +pyo3-async-runtimes = { workspace = true, features = ["tokio-runtime"], optional = true } +pythonize = { workspace = true, optional = true } [dev-dependencies] tokio = { version = "1", features = ["rt", "macros", "sync", "test-util", "rt-multi-thread", "time"] } diff --git a/crates/python/Cargo.toml b/crates/python/Cargo.toml index 34655038..1978031c 100644 --- a/crates/python/Cargo.toml +++ b/crates/python/Cargo.toml @@ -20,9 +20,9 @@ crate-type = ["cdylib", "rlib"] [dependencies] nemo-relay = { workspace = true, features = ["otel", "openinference", "python"] } nemo-relay-adaptive = { workspace = true, features = ["redis-backend"] } -pyo3 = { version = "0.28.2", features = ["abi3", "abi3-py311", "experimental-inspect", "macros"] } -pyo3-async-runtimes = { version = "0.28.0", features = ["tokio-runtime"] } -pythonize = "0.28.0" +pyo3 = { workspace = true, features = ["abi3", "abi3-py311", "experimental-inspect", "macros"] } +pyo3-async-runtimes = { workspace = true, features = ["tokio-runtime"] } +pythonize = { workspace = true } serde_json = "1" serde = "1" uuid = { workspace = true, features = ["v7"] } From 5dba153993302f27d5d6e537dd76fdc7eab73df8 Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Fri, 5 Jun 2026 09:01:15 -0700 Subject: [PATCH 13/20] fix: stabilize local guardrails branch checks Signed-off-by: Alex Fournier --- .../core/src/plugins/nemo_guardrails/local.rs | 17 +++++++---- .../python/tests/coverage/coverage_tests.rs | 30 +++++++++++-------- 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/crates/core/src/plugins/nemo_guardrails/local.rs b/crates/core/src/plugins/nemo_guardrails/local.rs index f46dea2f..6f1aa5a8 100644 --- a/crates/core/src/plugins/nemo_guardrails/local.rs +++ b/crates/core/src/plugins/nemo_guardrails/local.rs @@ -1,23 +1,28 @@ // SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -use crate::plugin::{PluginError, PluginRegistrationContext, Result as PluginResult}; +#[cfg(not(feature = "python"))] +use crate::plugin::PluginError; +use crate::plugin::{PluginRegistrationContext, Result as PluginResult}; use super::NeMoGuardrailsConfig; #[cfg(feature = "python")] mod python; +#[cfg(feature = "python")] pub(super) fn register_local_backend( config: NeMoGuardrailsConfig, ctx: &mut PluginRegistrationContext, ) -> PluginResult<()> { - #[cfg(feature = "python")] - { - return python::register_local_backend(config, ctx); - } + python::register_local_backend(config, ctx) +} - #[allow(unreachable_code)] +#[cfg(not(feature = "python"))] +pub(super) fn register_local_backend( + _config: NeMoGuardrailsConfig, + _ctx: &mut PluginRegistrationContext, +) -> PluginResult<()> { Err(PluginError::RegistrationFailed( "built-in NeMo Guardrails local backend is unavailable in this build".to_string(), )) diff --git a/crates/python/tests/coverage/coverage_tests.rs b/crates/python/tests/coverage/coverage_tests.rs index 98d00c92..eb92f510 100644 --- a/crates/python/tests/coverage/coverage_tests.rs +++ b/crates/python/tests/coverage/coverage_tests.rs @@ -775,19 +775,23 @@ async def run_case(): ]) ); let event_log = result["event_log"].as_array().unwrap(); - assert_eq!( - &event_log[..6], - json!([ - "source:hello", - "yield:hello", - "source:world", - "yield:world", - "guardrails-sees:hello", - "guardrails-sees:world", - ]) - .as_array() - .unwrap() - ); + let source_hello = event_log.iter().position(|value| value == "source:hello").unwrap(); + let source_world = event_log.iter().position(|value| value == "source:world").unwrap(); + let yield_hello = event_log.iter().position(|value| value == "yield:hello").unwrap(); + let yield_world = event_log.iter().position(|value| value == "yield:world").unwrap(); + let guardrails_hello = event_log + .iter() + .position(|value| value == "guardrails-sees:hello") + .unwrap(); + let guardrails_world = event_log + .iter() + .position(|value| value == "guardrails-sees:world") + .unwrap(); + assert!(source_hello < source_world); + assert!(source_hello < yield_hello); + assert!(source_world < yield_world); + assert!(yield_hello < yield_world); + assert!(guardrails_hello < guardrails_world); assert!( result["blocked"] .as_str() From 4aee5540e148d1c08f224bb5fb94b2584cdfead5 Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Fri, 5 Jun 2026 10:15:23 -0700 Subject: [PATCH 14/20] fix: restore python guardrails coverage split Signed-off-by: Alex Fournier --- .../python/tests/coverage/coverage_tests.rs | 632 ------------------ .../nemo_guardrails_coverage_tests.rs | 23 +- 2 files changed, 17 insertions(+), 638 deletions(-) diff --git a/crates/python/tests/coverage/coverage_tests.rs b/crates/python/tests/coverage/coverage_tests.rs index 173a38af..6c3205e0 100644 --- a/crates/python/tests/coverage/coverage_tests.rs +++ b/crates/python/tests/coverage/coverage_tests.rs @@ -94,638 +94,6 @@ fn test_native_pymodule_entrypoint_registers_bindings() { }); } -#[test] -fn test_native_pymodule_entrypoint_registers_bindings_without_local_provider_install() { - let _python = crate::test_support::init_python_test(); - Python::attach(|py| { - let module = PyModule::new(py, "_native_guardrails_provider").unwrap(); - crate::_native(&module).unwrap(); - }); - - let runtime = tokio::runtime::Runtime::new().unwrap(); - let error = runtime - .block_on(initialize_plugins(PluginConfig { - version: 1, - components: vec![PluginComponentSpec { - kind: "nemo_guardrails".to_string(), - enabled: true, - config: serde_json::from_value(json!({ - "mode": "local", - "codec": "openai_chat", - "config_path": "./rails" - })) - .unwrap(), - }], - policy: Default::default(), - })) - .unwrap_err(); - - let _ = clear_plugin_configuration(); - match error { - nemo_relay::plugin::PluginError::RegistrationFailed(message) => { - assert!( - message.contains( - "NeMo Guardrails is required for the built-in NeMo Guardrails local backend" - ), - "unexpected message: {message}" - ); - } - other => panic!("unexpected error: {other}"), - } -} - -#[test] -fn test_guardrails_local_helper_registers_and_enforces_llm_and_tool_checks() { - let _python = crate::test_support::init_python_test(); - Python::attach(|py| { - let native_module = PyModule::new(py, "_native_guardrails_helper").unwrap(); - crate::_native(&native_module).unwrap(); - - with_isolated_nemo_relay_modules(py, &native_module, || { - let python_dir = python_package_dir(); - let prelude = fake_guardrails_module_prelude( - "fake_guardrails_local_helper", - &python_dir.display().to_string(), - ); - let epilogue = register_fake_guardrails_module_epilogue(); - let context_class = local_plugin_context_python(); - let embedded_loader = embedded_guardrails_local_loader_python( - &embedded_guardrails_local_source_path() - .display() - .to_string(), - ); - let module = load_module( - py, - &format!( - r#" -{prelude} - -check_results = [] -check_calls = [] - -class LLMRails: - def __init__(self, config): - self.config = config - - async def check_async(self, messages, rail_types): - check_calls.append((messages, rail_types)) - return check_results.pop(0) - -{epilogue} - -{embedded_loader} - -from nemo_relay._native import LLMRequest -from _nemo_guardrails_local import register_local_backend - -{context_class} - -async def run_case(): - ctx = Context() - register_local_backend( - {{ - "mode": "local", - "codec": "openai_chat", - "config_yaml": "models: []", - "input": True, - "output": True, - "tool_input": True, - "tool_output": True, - "local": {{"python_module": MODULE_NAME}}, - }}, - ctx, - ) - - request = LLMRequest( - {{}}, - {{ - "model": "gpt-4o-mini", - "messages": [{{"role": "user", "content": "unsafe"}}], - }}, - ) - seen_request_messages = [] - - async def next_call(req): - seen_request_messages.append(req.content["messages"][-1]["content"]) - return {{ - "choices": [{{"message": {{"role": "assistant", "content": "safe reply"}}}}], - "id": "resp_1", - "model": "gpt-4o-mini", - }} - - check_results.extend( - [ - Result(RailStatus.MODIFIED, content="sanitized user"), - Result(RailStatus.PASSED), - ] - ) - llm_result = await ctx.llm("demo", request, next_call) - - seen_tool_args = [] - - async def next_tool(args): - seen_tool_args.append(args) - return {{"raw": True}} - - check_results.extend( - [ - Result(RailStatus.MODIFIED, content='{{"arguments": {{"city": "Boston"}}}}'), - Result(RailStatus.MODIFIED, content='{{"result": {{"ok": true}}}}'), - ] - ) - tool_result = await ctx.tool("weather_lookup", {{"city": "Phoenix"}}, next_tool) - - return {{ - "llm_result": llm_result, - "tool_result": tool_result, - "seen_request_messages": seen_request_messages, - "seen_tool_args": seen_tool_args, - "check_calls": check_calls, - }} -"#, - prelude = prelude, - epilogue = epilogue, - context_class = context_class, - embedded_loader = embedded_loader, - ), - ); - - let result_json = with_event_loop(py, |event_loop| { - let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); - let result = event_loop - .call_method1("run_until_complete", (coroutine,)) - .unwrap(); - crate::convert::py_to_json(&result).unwrap() - }); - - assert_eq!( - result_json["seen_request_messages"][0], - json!("sanitized user") - ); - assert_eq!(result_json["tool_result"], json!({ "ok": true })); - assert_eq!( - result_json["seen_tool_args"][0], - json!({ "city": "Boston" }) - ); - assert_eq!( - result_json["llm_result"]["choices"][0]["message"]["content"], - json!("safe reply") - ); - assert_eq!( - result_json["check_calls"], - json!([ - [ - [{"role": "user", "content": "unsafe"}], - ["input"] - ], - [ - [ - {"role": "user", "content": "sanitized user"}, - {"role": "assistant", "content": "safe reply"} - ], - ["output"] - ], - [ - [{"role": "user", "content": "{\"arguments\":{\"city\":\"Phoenix\"},\"tool_name\":\"weather_lookup\"}"}], - ["input"] - ], - [ - [ - {"role": "user", "content": "{\"arguments\":{\"city\":\"Boston\"},\"tool_name\":\"weather_lookup\"}"}, - {"role": "assistant", "content": "{\"arguments\":{\"city\":\"Boston\"},\"result\":{\"raw\":true},\"tool_name\":\"weather_lookup\"}"} - ], - ["output"] - ] - ]) - ); - }); - }); -} - -#[test] -fn test_guardrails_local_helper_rejects_unsupported_nemoguardrails_version() { - let _python = crate::test_support::init_python_test(); - Python::attach(|py| { - let native_module = PyModule::new(py, "_native_guardrails_version").unwrap(); - crate::_native(&native_module).unwrap(); - - with_isolated_nemo_relay_modules(py, &native_module, || { - let python_dir = python_package_dir(); - let prelude = fake_guardrails_module_prelude( - "fake_guardrails_bad_version", - &python_dir.display().to_string(), - ); - let epilogue = register_fake_guardrails_module_epilogue(); - let context_class = local_plugin_context_python(); - let embedded_loader = embedded_guardrails_local_loader_python( - &embedded_guardrails_local_source_path() - .display() - .to_string(), - ); - let module = load_module( - py, - &format!( - r#" -{prelude} - -fake_root.__version__ = "0.21.0" - -class LLMRails: - def __init__(self, config): - self.config = config - - async def check_async(self, messages, rail_types): - return Result(RailStatus.PASSED) - -{epilogue} - -{embedded_loader} - -from _nemo_guardrails_local import register_local_backend - -{context_class} - -async def run_case(): - ctx = Context() - register_local_backend( - {{ - "mode": "local", - "codec": "openai_chat", - "config_yaml": "models: []", - "input": True, - "local": {{"python_module": MODULE_NAME}}, - }}, - ctx, - ) -"#, - prelude = prelude, - epilogue = epilogue, - embedded_loader = embedded_loader, - context_class = context_class, - ), - ); - - let error = with_event_loop(py, |event_loop| { - let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); - event_loop - .call_method1("run_until_complete", (coroutine,)) - .unwrap_err() - .to_string() - }); - - assert!( - error.contains("requires nemoguardrails==0.22.0"), - "unexpected error: {error}" - ); - assert!(error.contains("0.21.0"), "unexpected error: {error}"); - }); - }); -} - -#[test] -fn test_guardrails_local_helper_enforces_streamed_output_rails() { - let _python = crate::test_support::init_python_test(); - Python::attach(|py| { - let native_module = PyModule::new(py, "_native_guardrails_streaming").unwrap(); - crate::_native(&native_module).unwrap(); - - with_isolated_nemo_relay_modules(py, &native_module, || { - let python_dir = python_package_dir(); - let prelude = fake_guardrails_module_prelude( - "fake_guardrails_streaming", - &python_dir.display().to_string(), - ); - let epilogue = register_fake_guardrails_module_epilogue(); - let context_class = local_plugin_context_python(); - let embedded_loader = embedded_guardrails_local_loader_python( - &embedded_guardrails_local_source_path() - .display() - .to_string(), - ); - let module = load_module( - py, - &format!( - r#" -{prelude} - -stream_results = [] -event_log = [] - -class LLMRails: - def __init__(self, config): - self.config = types.SimpleNamespace( - rails=types.SimpleNamespace( - output=types.SimpleNamespace( - flows=["self check output"], - streaming=types.SimpleNamespace(enabled=True, stream_first=True), - ) - ) - ) - - async def check_async(self, messages, rail_types): - return Result(RailStatus.PASSED) - - def stream_async(self, *, messages=None, generator=None, include_metadata=False): - async def _run(): - outcome = stream_results.pop(0) - async for chunk in generator: - event_log.append(f"guardrails-sees:{{chunk}}") - if outcome == "pass": - yield chunk - if outcome == "block": - yield '{{"error": {{"message": "Blocked by output rails: output-policy", "type": "guardrails_violation"}}}}' - return _run() - -{epilogue} - -{embedded_loader} - -from nemo_relay._native import LLMRequest -from _nemo_guardrails_local import register_local_backend - -{context_class} - -async def run_case(): - ctx = Context() - event_log.clear() - register_local_backend( - {{ - "mode": "local", - "codec": "openai_chat", - "config_yaml": "models: []", - "input": False, - "output": True, - "local": {{"python_module": MODULE_NAME}}, - }}, - ctx, - ) - - request = LLMRequest( - {{}}, - {{ - "model": "gpt-4o-mini", - "messages": [{{"role": "user", "content": "hello"}}], - }}, - ) - - async def next_call(req): - async def _stream(): - event_log.append("source:hello") - yield {{"choices": [{{"delta": {{"content": "hello"}}}}]}} - event_log.append("source:world") - yield {{"choices": [{{"delta": {{"content": "world"}}}}]}} - return _stream() - - stream_results.append("pass") - allowed_stream = await ctx.stream(request, next_call) - allowed_chunks = [] - async for chunk in allowed_stream: - event_log.append(f"yield:{{chunk['choices'][0]['delta']['content']}}") - allowed_chunks.append(chunk) - - stream_results.append("block") - try: - blocked_stream = await ctx.stream(request, next_call) - async for _chunk in blocked_stream: - pass - except RuntimeError as error: - blocked = str(error) - else: - raise AssertionError("expected streamed output block") - - ctx_stream_first_false = Context() - fake_root.LLMRails = lambda config: types.SimpleNamespace( - config=types.SimpleNamespace( - rails=types.SimpleNamespace( - output=types.SimpleNamespace( - flows=["self check output"], - streaming=types.SimpleNamespace(enabled=True, stream_first=False), - ) - ) - ), - check_async=LLMRails(config).check_async, - stream_async=LLMRails(config).stream_async, - ) - register_local_backend( - {{ - "mode": "local", - "codec": "openai_chat", - "config_yaml": "models: []", - "input": False, - "output": True, - "local": {{"python_module": MODULE_NAME}}, - }}, - ctx_stream_first_false, - ) - try: - failing_stream = await ctx_stream_first_false.stream(request, next_call) - async for _chunk in failing_stream: - pass - except RuntimeError as error: - modified = str(error) - else: - raise AssertionError("expected stream_first=false error") - - return {{ - "allowed_chunks": allowed_chunks, - "blocked": blocked, - "event_log": event_log, - "modified": modified, - }} -"#, - prelude = prelude, - epilogue = epilogue, - context_class = context_class, - embedded_loader = embedded_loader, - ), - ); - - let result = with_event_loop(py, |event_loop| { - let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); - let result = event_loop - .call_method1("run_until_complete", (coroutine,)) - .unwrap(); - crate::convert::py_to_json(&result).unwrap() - }); - assert_eq!( - result["allowed_chunks"], - json!([ - {"choices": [{"delta": {"content": "hello"}}]}, - {"choices": [{"delta": {"content": "world"}}]} - ]) - ); - let event_log = result["event_log"].as_array().unwrap(); - let source_hello = event_log.iter().position(|value| value == "source:hello").unwrap(); - let source_world = event_log.iter().position(|value| value == "source:world").unwrap(); - let yield_hello = event_log.iter().position(|value| value == "yield:hello").unwrap(); - let yield_world = event_log.iter().position(|value| value == "yield:world").unwrap(); - let guardrails_hello = event_log - .iter() - .position(|value| value == "guardrails-sees:hello") - .unwrap(); - let guardrails_world = event_log - .iter() - .position(|value| value == "guardrails-sees:world") - .unwrap(); - assert!(source_hello < source_world); - assert!(source_hello < yield_hello); - assert!(source_world < yield_world); - assert!(yield_hello < yield_world); - assert!(guardrails_hello < guardrails_world); - assert!( - result["blocked"] - .as_str() - .unwrap() - .contains("output rail blocked the LLM call") - ); - assert!( - result["modified"] - .as_str() - .unwrap() - .contains("stream_first = true") - ); - }); - }); -} - -#[test] -fn test_local_guardrails_provider_initializes_and_enforces_managed_core_calls() { - let _python = crate::test_support::init_python_test(); - reset_runtime_state(); - - Python::attach(|py| { - let native_module = PyModule::new(py, "_native_guardrails_e2e").unwrap(); - crate::_native(&native_module).unwrap(); - - with_isolated_nemo_relay_modules(py, &native_module, || { - let python_dir = python_package_dir(); - let prelude = fake_guardrails_module_prelude( - "fake_guardrails_local_e2e", - &python_dir.display().to_string(), - ); - let epilogue = register_fake_guardrails_module_epilogue(); - let module = load_module( - py, - &format!( - r#" -{prelude} - -check_results = [] - -class LLMRails: - def __init__(self, config): - self.config = config - - async def check_async(self, messages, rail_types): - return check_results.pop(0) - -{epilogue} - -import nemo_relay - -async def run_case(): - stack = nemo_relay.create_scope_stack() - nemo_relay.set_thread_scope_stack(stack) - - await nemo_relay.plugin.initialize( - {{ - "version": 1, - "components": [ - {{ - "kind": "nemo_guardrails", - "enabled": True, - "config": {{ - "mode": "local", - "codec": "openai_chat", - "config_yaml": "models: []", - "input": True, - "output": True, - "tool_input": True, - "tool_output": True, - "local": {{"python_module": MODULE_NAME}}, - }}, - }} - ], - }} - ) - - check_results.extend( - [ - Result(RailStatus.MODIFIED, content="sanitized user"), - Result(RailStatus.PASSED), - Result(RailStatus.MODIFIED, content='{{"arguments": {{"city": "Boston"}}}}'), - Result(RailStatus.MODIFIED, content='{{"result": {{"ok": true}}}}'), - ] - ) - - request = nemo_relay.LLMRequest( - {{}}, - {{ - "model": "gpt-4o-mini", - "messages": [{{"role": "user", "content": "unsafe"}}], - }}, - ) - - seen_request_messages = [] - async def llm_impl(req): - seen_request_messages.append(req.content["messages"][-1]["content"]) - return {{ - "choices": [{{"message": {{"role": "assistant", "content": "safe reply"}}}}], - "id": "resp_1", - "model": req.content["model"], - }} - - llm_result = await nemo_relay.llm.execute( - "demo", - request, - llm_impl, - response_codec=nemo_relay.codecs.OpenAIChatCodec(), - ) - - seen_tool_args = [] - async def tool_impl(args): - seen_tool_args.append(args) - return {{"raw": True}} - - tool_result = await nemo_relay.tools.execute("weather_lookup", {{"city": "Phoenix"}}, tool_impl) - return {{ - "llm_result": llm_result, - "tool_result": tool_result, - "seen_request_messages": seen_request_messages, - "seen_tool_args": seen_tool_args, - }} -"#, - prelude = prelude, - epilogue = epilogue, - ), - ); - let result_json = with_event_loop(py, |event_loop| { - let coroutine = module.getattr("run_case").unwrap().call0().unwrap(); - let result = event_loop - .call_method1("run_until_complete", (coroutine,)) - .unwrap(); - crate::convert::py_to_json(&result).unwrap() - }); - - assert_eq!( - result_json["llm_result"]["choices"][0]["message"]["content"], - json!("safe reply") - ); - assert_eq!(result_json["tool_result"], json!({ "ok": true })); - assert_eq!( - result_json["seen_request_messages"][0], - json!("sanitized user") - ); - assert_eq!( - result_json["seen_tool_args"][0], - json!({ "city": "Boston" }) - ); - }); - }); - - reset_runtime_state(); -} - #[test] fn test_python_test_guard_restores_existing_runtime_env() { let lock = crate::test_support::lock_python_test(); diff --git a/crates/python/tests/coverage/nemo_guardrails_coverage_tests.rs b/crates/python/tests/coverage/nemo_guardrails_coverage_tests.rs index 7c1c3893..89becc2d 100644 --- a/crates/python/tests/coverage/nemo_guardrails_coverage_tests.rs +++ b/crates/python/tests/coverage/nemo_guardrails_coverage_tests.rs @@ -644,24 +644,35 @@ async def run_case(): "missing event {expected}: {event_log:?}" ); } - let yield_hello = event_log + let source_hello = event_log .iter() - .position(|event| event == "yield:hello") + .position(|event| event == "source:hello") .unwrap(); - let guardrails_hello = event_log + let source_world = event_log .iter() - .position(|event| event == "guardrails-sees:hello") + .position(|event| event == "source:world") + .unwrap(); + let yield_hello = event_log + .iter() + .position(|event| event == "yield:hello") .unwrap(); let yield_world = event_log .iter() .position(|event| event == "yield:world") .unwrap(); + let guardrails_hello = event_log + .iter() + .position(|event| event == "guardrails-sees:hello") + .unwrap(); let guardrails_world = event_log .iter() .position(|event| event == "guardrails-sees:world") .unwrap(); - assert!(yield_hello < guardrails_hello); - assert!(yield_world < guardrails_world); + assert!(source_hello < source_world); + assert!(source_hello < yield_hello); + assert!(source_world < yield_world); + assert!(yield_hello < yield_world); + assert!(guardrails_hello < guardrails_world); assert!( result["blocked"] .as_str() From aba629c34303fefc502fcec957ac54f45a4f706e Mon Sep 17 00:00:00 2001 From: Will Killian Date: Fri, 5 Jun 2026 17:30:52 -0400 Subject: [PATCH 15/20] fix: run local guardrails through python subprocess Signed-off-by: Will Killian --- Cargo.lock | 3 - crates/cli/Cargo.toml | 2 +- crates/cli/tests/coverage/plugins_tests.rs | 9 + crates/core/Cargo.toml | 9 +- .../src/plugins/nemo_guardrails/component.rs | 29 +- .../core/src/plugins/nemo_guardrails/local.rs | 15 - .../plugins/nemo_guardrails/local_worker.py | 268 ++++++ .../src/plugins/nemo_guardrails/python.rs | 772 +++++++++--------- .../nemo_guardrails/component_tests.rs | 25 +- .../nemo_guardrails/local_python_tests.rs | 383 ++++++--- docs/about-nemo-relay/concepts/plugins.mdx | 3 +- docs/nemo-guardrails-plugin/about.mdx | 12 +- docs/nemo-guardrails-plugin/configuration.mdx | 29 +- 13 files changed, 987 insertions(+), 572 deletions(-) create mode 100644 crates/core/src/plugins/nemo_guardrails/local_worker.py diff --git a/Cargo.lock b/Cargo.lock index c2b3de48..14a9f718 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1343,9 +1343,6 @@ dependencies = [ "opentelemetry-http", "opentelemetry-otlp", "opentelemetry_sdk", - "pyo3", - "pyo3-async-runtimes", - "pythonize", "reqwest", "rustls", "schemars", diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index d1130a8e..f500c03a 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -25,7 +25,7 @@ default = ["atof-streaming"] atof-streaming = ["nemo-relay/atof-streaming"] [dependencies] -nemo-relay = { workspace = true, features = ["guardrails-remote", "object-store", "openinference", "python"] } +nemo-relay = { workspace = true, features = ["guardrails-remote", "object-store", "openinference"] } nemo-relay-adaptive = { workspace = true, features = ["redis-backend"] } async-stream = "0.3" axum = "0.8" diff --git a/crates/cli/tests/coverage/plugins_tests.rs b/crates/cli/tests/coverage/plugins_tests.rs index 502e7b07..dbde4e07 100644 --- a/crates/cli/tests/coverage/plugins_tests.rs +++ b/crates/cli/tests/coverage/plugins_tests.rs @@ -199,6 +199,10 @@ fn typed_editor_model_contains_nemo_guardrails_options() { local.field("python_module").unwrap().kind, EditorFieldKind::String ); + assert_eq!( + local.field("python_executable").unwrap().kind, + EditorFieldKind::String + ); assert_eq!( schema.field("config_path").unwrap().kind, EditorFieldKind::String @@ -1345,6 +1349,7 @@ fn nemo_guardrails_config_map_serializes_local_mode_fields() { tool_output: true, local: Some(LocalBackendConfig { python_module: Some("custom_guardrails".into()), + python_executable: Some("/opt/python/bin/python3".into()), }), ..NeMoGuardrailsConfig::default() }) @@ -1355,6 +1360,10 @@ fn nemo_guardrails_config_map_serializes_local_mode_fields() { assert_eq!(map.get("config_path"), Some(&json!("./rails"))); assert_eq!(map.get("tool_input"), Some(&json!(true))); assert_eq!(map["local"]["python_module"], json!("custom_guardrails")); + assert_eq!( + map["local"]["python_executable"], + json!("/opt/python/bin/python3") + ); } #[test] diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index d1bcf414..7da6f8b3 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -36,11 +36,7 @@ guardrails-remote = [ "dep:reqwest", "dep:rustls", ] -python = [ - "dep:pyo3", - "dep:pyo3-async-runtimes", - "dep:pythonize", -] +python = [] object-store = [ "dep:object_store", "dep:reqwest", @@ -102,9 +98,6 @@ opentelemetry-http = { version = "0.31", default-features = false, optional = tr wasm-bindgen = { version = "0.2", optional = true } wasm-bindgen-futures = { version = "0.4", optional = true } web-sys = { version = "0.3", features = ["Headers", "Request", "RequestInit", "Response", "Window", "console"], optional = true } -pyo3 = { workspace = true, features = ["auto-initialize"], optional = true } -pyo3-async-runtimes = { workspace = true, features = ["tokio-runtime"], optional = true } -pythonize = { workspace = true, optional = true } [dev-dependencies] tokio = { version = "1", features = ["rt", "macros", "sync", "test-util", "rt-multi-thread", "time"] } diff --git a/crates/core/src/plugins/nemo_guardrails/component.rs b/crates/core/src/plugins/nemo_guardrails/component.rs index 37255738..062b1ad6 100644 --- a/crates/core/src/plugins/nemo_guardrails/component.rs +++ b/crates/core/src/plugins/nemo_guardrails/component.rs @@ -192,6 +192,9 @@ pub struct LocalBackendConfig { /// Optional import path for the Python runtime module. #[serde(default, skip_serializing_if = "Option::is_none")] pub python_module: Option, + /// Optional Python executable used to run the local Guardrails worker. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub python_executable: Option, } /// Default request semantics applied by the selected Guardrails backend. @@ -326,6 +329,7 @@ crate::editor_config! { crate::editor_config! { impl LocalBackendConfig { python_module => { label: "python_module", kind: String, optional: true }, + python_executable => { label: "python_executable", kind: String, optional: true }, } } @@ -526,7 +530,7 @@ fn validate_nemo_guardrails_plugin_config( &config.policy, plugin_config, "local", - &["python_module"], + &["python_module", "python_executable"], ); validate_section_fields( &mut diagnostics, @@ -694,6 +698,20 @@ fn validate_non_empty_strings( "local.python_module must not be empty".to_string(), ); } + + if let Some(local) = &config.local + && let Some(python_executable) = &local.python_executable + && python_executable.trim().is_empty() + { + push_policy_diag( + diagnostics, + policy.unsupported_value, + "nemo_guardrails.unsupported_value", + Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()), + Some("local.python_executable".to_string()), + "local.python_executable must not be empty".to_string(), + ); + } } fn validate_config_shape( @@ -744,15 +762,6 @@ fn validate_local_config_shape( config: &NeMoGuardrailsConfig, flags: &ConfigShapeFlags, ) { - #[cfg(not(feature = "python"))] - push_config_shape_diag( - diagnostics, - policy.unsupported_value, - "nemo_guardrails.unavailable_backend", - Some("mode"), - "local mode requires a build with the 'python' feature enabled", - ); - if flags.has_config_path == flags.has_config_yaml { push_config_shape_diag( diagnostics, diff --git a/crates/core/src/plugins/nemo_guardrails/local.rs b/crates/core/src/plugins/nemo_guardrails/local.rs index d5ac28b9..e1618836 100644 --- a/crates/core/src/plugins/nemo_guardrails/local.rs +++ b/crates/core/src/plugins/nemo_guardrails/local.rs @@ -3,28 +3,13 @@ use crate::plugin::{PluginRegistrationContext, Result as PluginResult}; -#[cfg(not(feature = "python"))] -use crate::plugin::PluginError; - use super::NeMoGuardrailsConfig; -#[cfg(feature = "python")] mod python; -#[cfg(feature = "python")] pub(super) fn register_local_backend( config: NeMoGuardrailsConfig, ctx: &mut PluginRegistrationContext, ) -> PluginResult<()> { python::register_local_backend(config, ctx) } - -#[cfg(not(feature = "python"))] -pub(super) fn register_local_backend( - _config: NeMoGuardrailsConfig, - _ctx: &mut PluginRegistrationContext, -) -> PluginResult<()> { - Err(PluginError::RegistrationFailed( - "built-in NeMo Guardrails local backend is unavailable in this build".to_string(), - )) -} diff --git a/crates/core/src/plugins/nemo_guardrails/local_worker.py b/crates/core/src/plugins/nemo_guardrails/local_worker.py new file mode 100644 index 00000000..40fac320 --- /dev/null +++ b/crates/core/src/plugins/nemo_guardrails/local_worker.py @@ -0,0 +1,268 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import importlib +import json +import sys +import traceback + +DEFAULT_MODULE_NAME = "nemoguardrails" +SUPPORTED_NEMOGUARDRAILS_VERSION = "0.22.0" + +_PROTOCOL_STDOUT = sys.stdout +sys.stdout = sys.stderr + + +def send(message): + _PROTOCOL_STDOUT.write(json.dumps(message, separators=(",", ":")) + "\n") + _PROTOCOL_STDOUT.flush() + + +def response(request_id, result=None): + payload = {"id": request_id, "ok": True} + if result is not None: + payload["result"] = result + send(payload) + + +def error_response(request_id, error): + send({"id": request_id, "ok": False, "error": str(error)}) + + +def stream_event(request_id, event, **fields): + payload = {"id": request_id, "ok": True, "event": event} + payload.update(fields) + send(payload) + + +def stream_error(request_id, error): + send({"id": request_id, "ok": False, "event": "error", "error": str(error)}) + + +def status_value(status): + value = getattr(status, "value", status) + return str(value).lower() + + +def optional_string_attr(obj, attr): + value = getattr(obj, attr, None) + if value is None: + return None + return str(value) + + +def string_attr_or_empty(obj, attr): + return optional_string_attr(obj, attr) or "" + + +def guardrails_stream_error_message(chunk): + try: + payload = json.loads(chunk) + except Exception: + return None + error = payload.get("error") + if not isinstance(error, dict): + return None + if error.get("type") != "guardrails_violation": + return None + return error.get("message") or "Blocked by output rails." + + +class AsyncTextStream: + def __init__(self, queue): + self._queue = queue + + def __aiter__(self): + return self + + async def __anext__(self): + value = await self._queue.get() + if value is None: + raise StopAsyncIteration + return value + + +class GuardrailsWorker: + def __init__(self, config): + if sys.version_info < (3, 11): + raise RuntimeError("NeMo Guardrails local backend requires python3 >= 3.11") + + local = config.get("local") or {} + root_module = (local.get("python_module") or DEFAULT_MODULE_NAME).strip() + guardrails = self._import_dependency(root_module, root_module) + options = self._import_dependency(f"{root_module}.rails.llm.options", root_module) + + version = getattr(guardrails, "__version__", None) + if version != SUPPORTED_NEMOGUARDRAILS_VERSION: + raise RuntimeError( + "NeMo Guardrails local backend requires " + f"nemoguardrails=={SUPPORTED_NEMOGUARDRAILS_VERSION}, but found {version!r}. " + f"Install it with: pip install nemoguardrails=={SUPPORTED_NEMOGUARDRAILS_VERSION}" + ) + + self._rail_type = options.RailType + self._rail_status = options.RailStatus + guardrails_config = self._build_guardrails_config(guardrails.RailsConfig, config) + self._rails = guardrails.LLMRails(guardrails_config) + + def _import_dependency(self, module_name, root_module): + try: + return importlib.import_module(module_name) + except ImportError as err: + missing = getattr(err, "name", None) + if missing == root_module: + raise RuntimeError( + "NeMo Guardrails is required for the built-in NeMo Guardrails local backend. " + f"Install it with: pip install nemoguardrails=={SUPPORTED_NEMOGUARDRAILS_VERSION}" + ) from err + raise RuntimeError( + "NeMo Guardrails local backend could not import a required dependency: " + f"{missing or err}. Install the full NeMo Guardrails runtime dependencies." + ) from err + + def _build_guardrails_config(self, rails_config_cls, config): + config_path = config.get("config_path") + if config_path: + return rails_config_cls.from_path(config_path) + + config_yaml = config.get("config_yaml") + if config_yaml is None: + raise ValueError("config_yaml is required when config_path is not provided") + return rails_config_cls.from_content( + colang_content=config.get("colang_content"), + yaml_content=config_yaml, + ) + + def _rail_kind(self, rail_type): + if rail_type == "input": + return self._rail_type.INPUT + if rail_type == "output": + return self._rail_type.OUTPUT + raise ValueError(f"unsupported rail_type {rail_type!r}") + + async def check(self, messages, rail_type): + result = await self._rails.check_async( + messages, + rail_types=[self._rail_kind(rail_type)], + ) + return { + "status": status_value(result.status), + "content": string_attr_or_empty(result, "content"), + "rail": optional_string_attr(result, "rail"), + } + + def has_streaming_output_rails(self): + output = self._output_rails_config() + flows = getattr(output, "flows", None) if output is not None else None + return bool(flows) + + def ensure_streaming_output_supported(self): + output = self._output_rails_config() + if output is None: + return + + streaming = getattr(output, "streaming", None) + if streaming is None or not bool(getattr(streaming, "enabled", False)): + raise RuntimeError( + "local NeMo Guardrails streaming output rails require " + "rails.output.streaming.enabled = true in the Guardrails config." + ) + + if not bool(getattr(streaming, "stream_first", True)): + raise RuntimeError( + "local NeMo Guardrails streaming output rails currently require " + "rails.output.streaming.stream_first = true." + ) + + def _output_rails_config(self): + config = getattr(self._rails, "config", None) + rails = getattr(config, "rails", None) + return getattr(rails, "output", None) + + async def monitor_stream(self, request_id, messages, queue, streams): + try: + async for chunk in self._rails.stream_async( + messages=messages, + generator=AsyncTextStream(queue), + include_metadata=False, + ): + if not isinstance(chunk, str): + continue + message = guardrails_stream_error_message(chunk) + if message: + stream_event(request_id, "blocked", message=message) + return + stream_event(request_id, "done") + except Exception as err: + stream_error(request_id, err) + finally: + streams.pop(request_id, None) + + +worker = None +streams = {} + + +async def handle_message(message): + global worker + + request_id = str(message.get("id", "")) + command = message.get("command") + try: + if command == "init": + worker = GuardrailsWorker(message.get("config") or {}) + response( + request_id, + { + "python": sys.executable, + "version": ".".join(str(part) for part in sys.version_info[:3]), + }, + ) + elif worker is None: + raise RuntimeError("NeMo Guardrails local Python worker is not initialized") + elif command == "check": + response( + request_id, + await worker.check(message.get("messages") or [], message.get("rail_type")), + ) + elif command == "has_streaming_output_rails": + response(request_id, {"enabled": worker.has_streaming_output_rails()}) + elif command == "ensure_streaming_output_supported": + worker.ensure_streaming_output_supported() + response(request_id) + elif command == "stream_start": + queue = asyncio.Queue() + streams[request_id] = queue + asyncio.create_task(worker.monitor_stream(request_id, message.get("messages") or [], queue, streams)) + elif command == "stream_text": + queue = streams.get(request_id) + if queue is not None: + await queue.put(message.get("text") or "") + elif command == "stream_end": + queue = streams.get(request_id) + if queue is not None: + await queue.put(None) + else: + raise RuntimeError(f"unknown worker command {command!r}") + except Exception as err: + if command and command.startswith("stream_"): + stream_error(request_id, err) + else: + error_response(request_id, err) + + +async def main(): + while True: + line = await asyncio.to_thread(sys.stdin.readline) + if not line: + return + try: + message = json.loads(line) + except Exception: + traceback.print_exc(file=sys.stderr) + continue + asyncio.create_task(handle_message(message)) + + +asyncio.run(main()) diff --git a/crates/core/src/plugins/nemo_guardrails/python.rs b/crates/core/src/plugins/nemo_guardrails/python.rs index 2f28d706..7d20c7eb 100644 --- a/crates/core/src/plugins/nemo_guardrails/python.rs +++ b/crates/core/src/plugins/nemo_guardrails/python.rs @@ -1,10 +1,16 @@ // SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -use std::sync::{Arc, Mutex}; - -use pyo3::prelude::*; -use pyo3::types::{PyDict, PyList}; +use std::collections::HashMap; +use std::env; +use std::io::{BufRead, BufReader, Write}; +use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, Mutex, mpsc as std_mpsc}; +use std::thread; +use std::time::Duration; + +use serde::Deserialize; use serde_json::json; use tokio::sync::mpsc; use tokio::task::JoinHandle; @@ -24,8 +30,10 @@ use crate::plugin::{PluginError, PluginRegistrationContext, Result as PluginResu use super::NeMoGuardrailsConfig; -const DEFAULT_MODULE_NAME: &str = "nemoguardrails"; -const SUPPORTED_NEMOGUARDRAILS_VERSION: &str = "0.22.0"; +const DEFAULT_PYTHON_EXECUTABLE: &str = "python3"; +const PYTHON_EXECUTABLE_ENV: &str = "NEMO_RELAY_PYTHON"; +const WORKER_INIT_TIMEOUT: Duration = Duration::from_secs(30); +const WORKER_SCRIPT: &str = include_str!("local_worker.py"); pub(super) fn register_local_backend( config: NeMoGuardrailsConfig, @@ -110,7 +118,6 @@ struct LocalGuardrailsRuntime { impl LocalGuardrailsRuntime { fn new(config: &NeMoGuardrailsConfig) -> PluginResult { - Python::initialize(); Ok(Self { bridge: LocalGuardrailsBridge::new(config)?, codec: resolve_codec(config)?, @@ -147,12 +154,12 @@ impl LocalGuardrailsRuntime { let (request, messages) = self.prepare_llm_request(request, enable_input).await?; let provider_stream = next(request).await?; - if !enable_output || !self.bridge.has_streaming_output_rails()? { + if !enable_output || !self.bridge.has_streaming_output_rails().await? { return Ok(provider_stream); } - self.bridge.ensure_streaming_output_supported()?; - self.guard_provider_stream(messages, provider_stream) + self.bridge.ensure_streaming_output_supported().await?; + self.guard_provider_stream(messages, provider_stream).await } async fn prepare_llm_request( @@ -254,7 +261,7 @@ impl LocalGuardrailsRuntime { } } - fn guard_provider_stream( + async fn guard_provider_stream( &self, messages: Vec, provider_stream: LlmJsonStream, @@ -292,42 +299,14 @@ impl LocalGuardrailsRuntime { } struct LocalGuardrailsBridge { - rails: Py, - input_rail: Py, - output_rail: Py, - blocked_status: String, - modified_status: String, + worker: Arc, } impl LocalGuardrailsBridge { fn new(config: &NeMoGuardrailsConfig) -> PluginResult { - Python::attach(|py| { - let imports = load_nemoguardrails( - py, - config.local.as_ref().and_then(|l| { - l.python_module - .as_deref() - .filter(|module| !module.trim().is_empty()) - }), - )?; - let guardrails_config = build_guardrails_config(py, config, &imports.rails_config_cls)?; - let rails = imports.llm_rails_cls.call1(py, (guardrails_config,))?; - let input_rail = imports.rail_type.getattr(py, "INPUT")?; - let output_rail = imports.rail_type.getattr(py, "OUTPUT")?; - let blocked = imports.rail_status.getattr(py, "BLOCKED")?; - let modified = imports.rail_status.getattr(py, "MODIFIED")?; - let blocked_status = py_status_value(blocked.bind(py))?; - let modified_status = py_status_value(modified.bind(py))?; - - Ok::(Self { - rails, - input_rail, - output_rail, - blocked_status, - modified_status, - }) + Ok(Self { + worker: LocalGuardrailsWorker::start(config)?, }) - .map_err(|err| PluginError::RegistrationFailed(err.to_string())) } async fn check( @@ -335,90 +314,33 @@ impl LocalGuardrailsBridge { messages: Vec, kind: LocalRailKind, ) -> FlowResult { - let future = Python::attach(|py| { - let messages = json_to_py(py, &Json::Array(messages)) - .map_err(|err| FlowError::Internal(err.to_string()))?; - let rail_type = match kind { - LocalRailKind::Input => self.input_rail.clone_ref(py), - LocalRailKind::Output => self.output_rail.clone_ref(py), - }; - let rail_types = - PyList::new(py, [rail_type]).map_err(|err| FlowError::Internal(err.to_string()))?; - let kwargs = PyDict::new(py); - kwargs - .set_item("rail_types", rail_types) - .map_err(|err| FlowError::Internal(err.to_string()))?; - let result = self - .rails - .bind(py) - .call_method("check_async", (messages,), Some(&kwargs)) - .map_err(|err| FlowError::Internal(err.to_string()))?; - pyo3_async_runtimes::tokio::into_future(result.unbind().into_bound(py)) - .map_err(|err| FlowError::Internal(err.to_string())) - })?; - - let result = future - .await - .map_err(|err| FlowError::Internal(err.to_string()))?; - - Python::attach(|py| { - self.parse_check_result(result.bind(py)) - .map_err(|err| FlowError::Internal(err.to_string())) - }) + let result = self + .worker + .request(json!({ + "command": "check", + "messages": messages, + "rail_type": kind.as_str(), + })) + .await?; + parse_check_result(result) } - fn has_streaming_output_rails(&self) -> FlowResult { - Python::attach(|py| { - let Some(output) = self.output_rails_config(py)? else { - return Ok(false); - }; - match output.getattr("flows") { - Ok(flows) => flows - .is_truthy() - .map_err(|err| FlowError::Internal(err.to_string())), - Err(_) => Ok(false), - } - }) + async fn has_streaming_output_rails(&self) -> FlowResult { + let result = self + .worker + .request(json!({ "command": "has_streaming_output_rails" })) + .await?; + result + .get("enabled") + .and_then(Json::as_bool) + .ok_or_else(|| FlowError::Internal("worker returned invalid streaming probe".into())) } - fn ensure_streaming_output_supported(&self) -> FlowResult<()> { - Python::attach(|py| { - let Some(output) = self.output_rails_config(py)? else { - return Ok(()); - }; - let streaming = output.getattr("streaming").map_err(|_| { - FlowError::Internal( - "local NeMo Guardrails streaming output rails require \ - rails.output.streaming.enabled = true in the Guardrails config." - .to_string(), - ) - })?; - let enabled = streaming - .getattr("enabled") - .and_then(|value| value.is_truthy()) - .unwrap_or(false); - if !enabled { - return Err(FlowError::Internal( - "local NeMo Guardrails streaming output rails require \ - rails.output.streaming.enabled = true in the Guardrails config." - .to_string(), - )); - } - - let stream_first = streaming - .getattr("stream_first") - .and_then(|value| value.is_truthy()) - .unwrap_or(true); - if !stream_first { - return Err(FlowError::Internal( - "local NeMo Guardrails streaming output rails currently require \ - rails.output.streaming.stream_first = true." - .to_string(), - )); - } - - Ok(()) - }) + async fn ensure_streaming_output_supported(&self) -> FlowResult<()> { + self.worker + .request(json!({ "command": "ensure_streaming_output_supported" })) + .await + .map(|_| ()) } fn spawn_stream_monitor( @@ -427,196 +349,343 @@ impl LocalGuardrailsBridge { text_rx: mpsc::Receiver>, blocked: Arc>>, ) -> FlowResult>> { - let (async_iter, task_locals) = Python::attach(|py| { - let generator = Py::new( - py, - PyStringStream { - receiver: Arc::new(tokio::sync::Mutex::new(text_rx)), - }, + let (stream_id, event_rx) = self.worker.start_stream(messages)?; + let worker = Arc::clone(&self.worker); + Ok(tokio::spawn(async move { + monitor_guardrails_stream(worker, stream_id, text_rx, event_rx, blocked).await + })) + } +} + +struct LocalGuardrailsWorker { + stdin: Mutex, + child: Mutex, + waiters: Arc>>>, + stream_events: Arc>>>, + next_id: AtomicU64, +} + +impl LocalGuardrailsWorker { + fn start(config: &NeMoGuardrailsConfig) -> PluginResult> { + let python = python_executable(config); + let mut command = Command::new(&python); + command + .arg("-u") + .arg("-c") + .arg(WORKER_SCRIPT) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::inherit()); + + let mut child = command.spawn().map_err(|err| { + PluginError::RegistrationFailed(format!( + "failed to start NeMo Guardrails local Python worker with {python:?}: {err}" + )) + })?; + let stdin = child.stdin.take().ok_or_else(|| { + PluginError::RegistrationFailed( + "failed to open stdin for NeMo Guardrails local Python worker".to_string(), + ) + })?; + let stdout = child.stdout.take().ok_or_else(|| { + PluginError::RegistrationFailed( + "failed to open stdout for NeMo Guardrails local Python worker".to_string(), ) - .map_err(|err| FlowError::Internal(err.to_string()))?; - let messages = json_to_py(py, &Json::Array(messages)) - .map_err(|err| FlowError::Internal(err.to_string()))?; - let kwargs = PyDict::new(py); - kwargs - .set_item("messages", messages) - .map_err(|err| FlowError::Internal(err.to_string()))?; - kwargs - .set_item("generator", generator) - .map_err(|err| FlowError::Internal(err.to_string()))?; - kwargs - .set_item("include_metadata", false) - .map_err(|err| FlowError::Internal(err.to_string()))?; - let async_iter = self - .rails - .bind(py) - .call_method("stream_async", (), Some(&kwargs)) - .map_err(|err| FlowError::Internal(err.to_string()))? - .unbind(); - let task_locals = pyo3_async_runtimes::tokio::get_current_locals(py) - .map_err(|err| FlowError::Internal(err.to_string()))?; - Ok((async_iter, task_locals)) })?; - let async_iter = Arc::new(async_iter); - Ok(tokio::spawn(pyo3_async_runtimes::tokio::scope( - task_locals, - async move { monitor_guardrails_stream(async_iter, blocked).await }, - ))) + let worker = Arc::new(Self { + stdin: Mutex::new(stdin), + child: Mutex::new(child), + waiters: Arc::new(Mutex::new(HashMap::new())), + stream_events: Arc::new(Mutex::new(HashMap::new())), + next_id: AtomicU64::new(1), + }); + worker.spawn_reader(stdout); + worker.initialize(config)?; + Ok(worker) } - fn output_rails_config<'py>(&self, py: Python<'py>) -> FlowResult>> { - let rails = self.rails.bind(py); - let config = match rails.getattr("config") { - Ok(config) => config, - Err(_) => return Ok(None), - }; - let rails_config = match config.getattr("rails") { - Ok(rails_config) => rails_config, - Err(_) => return Ok(None), - }; - match rails_config.getattr("output") { - Ok(output) => Ok(Some(output)), - Err(_) => Ok(None), + fn spawn_reader(&self, stdout: ChildStdout) { + let waiters = Arc::clone(&self.waiters); + let stream_events = Arc::clone(&self.stream_events); + thread::spawn(move || { + let reader = BufReader::new(stdout); + for line in reader.lines() { + let line = match line { + Ok(line) => line, + Err(err) => { + notify_worker_closed(&waiters, &stream_events, err.to_string()); + return; + } + }; + if line.trim().is_empty() { + continue; + } + let envelope = match serde_json::from_str::(&line) { + Ok(envelope) => envelope, + Err(err) => { + notify_worker_closed( + &waiters, + &stream_events, + format!("invalid worker response: {err}"), + ); + return; + } + }; + dispatch_worker_envelope(&waiters, &stream_events, envelope); + } + notify_worker_closed(&waiters, &stream_events, "worker exited".to_string()); + }); + } + + fn initialize(&self, config: &NeMoGuardrailsConfig) -> PluginResult<()> { + let response = self + .request_blocking( + json!({ + "command": "init", + "config": config, + }), + WORKER_INIT_TIMEOUT, + ) + .map_err(|err| PluginError::RegistrationFailed(err.to_string()))?; + if response.ok { + Ok(()) + } else { + Err(PluginError::RegistrationFailed( + response + .error + .unwrap_or_else(|| "NeMo Guardrails local Python worker failed".to_string()), + )) } } - fn parse_check_result(&self, result: &Bound<'_, PyAny>) -> PyResult { - let status = py_status_value(&result.getattr("status")?)?; - let rail = optional_string_attr(result, "rail")?; - let content = string_attr_or_empty(result, "content")?; + async fn request(&self, mut payload: Json) -> FlowResult { + let receiver = self.send_request(&mut payload)?; + let envelope = tokio::task::spawn_blocking(move || receiver.recv()) + .await + .map_err(|err| FlowError::Internal(format!("worker response task failed: {err}")))? + .map_err(|err| FlowError::Internal(format!("worker response channel closed: {err}")))?; + worker_result(envelope) + } + + fn request_blocking(&self, mut payload: Json, timeout: Duration) -> FlowResult { + let receiver = self.send_request(&mut payload)?; + receiver + .recv_timeout(timeout) + .map_err(|err| FlowError::Internal(format!("worker did not initialize: {err}"))) + } - if status == self.blocked_status { - return Ok(LocalCheckOutcome::Blocked { rail }); + fn send_request(&self, payload: &mut Json) -> FlowResult> { + let id = self.next_request_id(); + set_request_id(payload, &id)?; + let (tx, rx) = std_mpsc::channel(); + self.waiters + .lock() + .map_err(|err| FlowError::Internal(format!("worker waiter lock poisoned: {err}")))? + .insert(id.clone(), tx); + if let Err(err) = self.write_command(payload) { + let _ = self.waiters.lock().map(|mut waiters| waiters.remove(&id)); + return Err(err); } - if status == self.modified_status { - return Ok(LocalCheckOutcome::Modified { content }); + Ok(rx) + } + + fn start_stream( + &self, + messages: Vec, + ) -> FlowResult<(String, mpsc::UnboundedReceiver)> { + let id = self.next_request_id(); + let (tx, rx) = mpsc::unbounded_channel(); + self.stream_events + .lock() + .map_err(|err| FlowError::Internal(format!("worker stream lock poisoned: {err}")))? + .insert(id.clone(), tx); + let payload = json!({ + "id": id, + "command": "stream_start", + "messages": messages, + }); + if let Err(err) = self.write_command(&payload) { + self.forget_stream(&id); + return Err(err); } - Ok(LocalCheckOutcome::Passed) - } -} - -struct GuardrailsRuntimeImports { - rails_config_cls: Py, - llm_rails_cls: Py, - rail_type: Py, - rail_status: Py, -} - -fn load_nemoguardrails( - py: Python<'_>, - module_name: Option<&str>, -) -> PyResult { - let root_module = module_name.unwrap_or(DEFAULT_MODULE_NAME); - let importlib = py.import("importlib")?; - let import_module = importlib.getattr("import_module")?; - let guardrails = import_module - .call1((root_module,)) - .map_err(|err| import_dependency_error(py, err, root_module))?; - let options_module_name = format!("{root_module}.rails.llm.options"); - let options = import_module - .call1((options_module_name.as_str(),)) - .map_err(|err| import_dependency_error(py, err, root_module))?; - - let version = guardrails - .getattr("__version__") - .ok() - .and_then(|value| value.extract::().ok()); - if version.as_deref() != Some(SUPPORTED_NEMOGUARDRAILS_VERSION) { - let found = version - .map(|version| format!("{version:?}")) - .unwrap_or_else(|| "None".to_string()); - return Err(pyo3::exceptions::PyRuntimeError::new_err(format!( - "NeMo Guardrails local backend requires nemoguardrails==\ - {SUPPORTED_NEMOGUARDRAILS_VERSION}, but found {found}. \ - Install it with: pip install nemoguardrails=={SUPPORTED_NEMOGUARDRAILS_VERSION}" - ))); + Ok((id, rx)) } - Ok(GuardrailsRuntimeImports { - rails_config_cls: guardrails.getattr("RailsConfig")?.unbind(), - llm_rails_cls: guardrails.getattr("LLMRails")?.unbind(), - rail_type: options.getattr("RailType")?.unbind(), - rail_status: options.getattr("RailStatus")?.unbind(), - }) -} + fn send_stream_text(&self, id: &str, text: String) -> FlowResult<()> { + self.write_command(&json!({ + "id": id, + "command": "stream_text", + "text": text, + })) + } -fn import_dependency_error(py: Python<'_>, err: PyErr, root_module: &str) -> PyErr { - if !err.is_instance_of::(py) { - return err; + fn send_stream_end(&self, id: &str) -> FlowResult<()> { + self.write_command(&json!({ + "id": id, + "command": "stream_end", + })) } - let name = err.value(py).getattr("name").ok().and_then(|name| { - if name.is_none() { - None - } else { - name.extract::().ok() - } - }); + fn forget_stream(&self, id: &str) { + let _ = self + .stream_events + .lock() + .map(|mut streams| streams.remove(id)); + } - if name.as_deref() == Some(root_module) { - return pyo3::exceptions::PyRuntimeError::new_err(format!( - "NeMo Guardrails is required for the built-in NeMo Guardrails local backend. \ - Install it with: pip install nemoguardrails=={SUPPORTED_NEMOGUARDRAILS_VERSION}" - )); + fn next_request_id(&self) -> String { + self.next_id.fetch_add(1, Ordering::Relaxed).to_string() } - pyo3::exceptions::PyRuntimeError::new_err(format!( - "NeMo Guardrails local backend could not import a required dependency: {}. \ - Install the full NeMo Guardrails runtime dependencies.", - name.unwrap_or_else(|| err.to_string()) - )) + fn write_command(&self, payload: &Json) -> FlowResult<()> { + let line = serde_json::to_string(payload).map_err(|err| { + FlowError::Internal(format!("failed to serialize worker command: {err}")) + })?; + let mut stdin = self + .stdin + .lock() + .map_err(|err| FlowError::Internal(format!("worker stdin lock poisoned: {err}")))?; + writeln!(stdin, "{line}") + .and_then(|_| stdin.flush()) + .map_err(|err| FlowError::Internal(format!("failed to write worker command: {err}"))) + } } -fn build_guardrails_config( - py: Python<'_>, - config: &NeMoGuardrailsConfig, - rails_config_cls: &Py, -) -> PyResult> { - let rails_config_cls = rails_config_cls.bind(py); - if let Some(config_path) = config.config_path.as_deref() { - return rails_config_cls - .call_method1("from_path", (config_path,)) - .map(Bound::unbind); +impl Drop for LocalGuardrailsWorker { + fn drop(&mut self) { + if let Ok(mut child) = self.child.lock() { + let _ = child.kill(); + let _ = child.wait(); + } } +} + +#[derive(Debug, Clone, Deserialize)] +struct WorkerEnvelope { + id: String, + ok: bool, + #[serde(default)] + result: Option, + #[serde(default)] + error: Option, + #[serde(default)] + event: Option, + #[serde(default)] + message: Option, +} + +#[derive(Deserialize)] +struct WorkerCheckResult { + status: String, + #[serde(default)] + content: Option, + #[serde(default)] + rail: Option, +} + +fn python_executable(config: &NeMoGuardrailsConfig) -> String { + config + .local + .as_ref() + .and_then(|local| local.python_executable.as_deref()) + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(str::to_string) + .or_else(|| { + env::var(PYTHON_EXECUTABLE_ENV) + .ok() + .map(|value| value.trim().to_string()) + .filter(|value| !value.is_empty()) + }) + .unwrap_or_else(|| DEFAULT_PYTHON_EXECUTABLE.to_string()) +} - let config_yaml = config.config_yaml.as_deref().ok_or_else(|| { - pyo3::exceptions::PyValueError::new_err( - "config_yaml is required when config_path is not provided", - ) +fn set_request_id(payload: &mut Json, id: &str) -> FlowResult<()> { + let object = payload.as_object_mut().ok_or_else(|| { + FlowError::Internal("worker command payload must be a JSON object".to_string()) })?; - let kwargs = PyDict::new(py); - kwargs.set_item("colang_content", config.colang_content.as_deref())?; - kwargs.set_item("yaml_content", config_yaml)?; - rails_config_cls - .call_method("from_content", (), Some(&kwargs)) - .map(Bound::unbind) + object.insert("id".to_string(), Json::String(id.to_string())); + Ok(()) } -fn py_status_value(status: &Bound<'_, PyAny>) -> PyResult { - let value = status.getattr("value").unwrap_or_else(|_| status.clone()); - Ok(value.str()?.extract::()?.to_lowercase()) +fn dispatch_worker_envelope( + waiters: &Arc>>>, + stream_events: &Arc>>>, + envelope: WorkerEnvelope, +) { + if envelope.event.is_some() { + let sender = stream_events + .lock() + .ok() + .and_then(|streams| streams.get(&envelope.id).cloned()); + if let Some(sender) = sender { + let _ = sender.send(envelope); + } + return; + } + + let sender = waiters + .lock() + .ok() + .and_then(|mut waiters| waiters.remove(&envelope.id)); + if let Some(sender) = sender { + let _ = sender.send(envelope); + } } -fn optional_string_attr(obj: &Bound<'_, PyAny>, attr: &str) -> PyResult> { - match obj.getattr(attr) { - Ok(value) if !value.is_none() => Ok(Some(value.str()?.extract::()?)), - Ok(_) | Err(_) => Ok(None), +fn notify_worker_closed( + waiters: &Arc>>>, + stream_events: &Arc>>>, + message: String, +) { + if let Ok(mut waiters) = waiters.lock() { + for (id, sender) in waiters.drain() { + let _ = sender.send(WorkerEnvelope { + id, + ok: false, + result: None, + error: Some(message.clone()), + event: None, + message: None, + }); + } + } + if let Ok(mut streams) = stream_events.lock() { + for (id, sender) in streams.drain() { + let _ = sender.send(WorkerEnvelope { + id, + ok: false, + result: None, + error: Some(message.clone()), + event: Some("error".to_string()), + message: None, + }); + } } } -fn string_attr_or_empty(obj: &Bound<'_, PyAny>, attr: &str) -> PyResult { - match optional_string_attr(obj, attr)? { - Some(value) => Ok(value), - None => Ok(String::new()), +fn worker_result(envelope: WorkerEnvelope) -> FlowResult { + if envelope.ok { + Ok(envelope.result.unwrap_or(Json::Null)) + } else { + Err(FlowError::Internal(envelope.error.unwrap_or_else(|| { + "NeMo Guardrails local Python worker failed".to_string() + }))) } } -fn json_to_py(py: Python<'_>, value: &Json) -> PyResult> { - let obj: Bound<'_, PyAny> = pythonize::pythonize(py, value).map_err(|e| { - PyErr::new::(format!("Failed to convert from JSON: {e}")) +fn parse_check_result(result: Json) -> FlowResult { + let result: WorkerCheckResult = serde_json::from_value(result).map_err(|err| { + FlowError::Internal(format!("worker returned invalid check result: {err}")) })?; - Ok(obj.unbind()) + match result.status.as_str() { + "blocked" => Ok(LocalCheckOutcome::Blocked { rail: result.rail }), + "modified" => Ok(LocalCheckOutcome::Modified { + content: result.content.unwrap_or_default(), + }), + _ => Ok(LocalCheckOutcome::Passed), + } } #[derive(Clone, Copy)] @@ -689,6 +758,15 @@ enum LocalRailKind { Output, } +impl LocalRailKind { + fn as_str(self) -> &'static str { + match self { + Self::Input => "input", + Self::Output => "output", + } + } +} + fn messages_from_annotated(annotated: &AnnotatedLlmRequest) -> FlowResult> { match serde_json::to_value(&annotated.messages) .map_err(|err| FlowError::Internal(format!("failed to serialize messages: {err}")))? @@ -917,104 +995,62 @@ fn extract_stream_text(codec: LocalGuardrailsCodec, chunk: &Json) -> Option>, + worker: Arc, + stream_id: String, + mut text_rx: mpsc::Receiver>, + mut event_rx: mpsc::UnboundedReceiver, blocked: Arc>>, ) -> FlowResult<()> { + let mut input_closed = false; loop { - let Some(coro) = next_async_iter_coro(&async_iter)? else { - break; - }; - let Some(value) = await_async_iter_value(coro).await? else { - break; - }; - Python::attach(|py| { - if let Ok(chunk) = value.bind(py).extract::() - && let Some(message) = guardrails_stream_error_message(&chunk) - { - let mut guard = blocked.lock().map_err(|err| { - FlowError::Internal(format!("stream block state lock poisoned: {err}")) - })?; - *guard = Some(message); + tokio::select! { + maybe_text = text_rx.recv(), if !input_closed => { + match maybe_text { + Some(Some(text)) => worker.send_stream_text(&stream_id, text)?, + Some(None) | None => { + worker.send_stream_end(&stream_id)?; + input_closed = true; + } + } } - Ok::<(), FlowError>(()) - })?; - if blocked_message(&blocked).is_some() { - break; - } - } - Ok(()) -} - -fn next_async_iter_coro(async_iter: &Arc>) -> FlowResult>> { - Python::attach(|py| { - let iter = async_iter.bind(py); - match iter.call_method0("__anext__") { - Ok(coro) => Ok(Some(coro.unbind())), - Err(error) => { - if error.is_instance_of::(py) { - Ok(None) - } else { - Err(FlowError::Internal(error.to_string())) + maybe_event = event_rx.recv() => { + let Some(event) = maybe_event else { + worker.forget_stream(&stream_id); + return Err(FlowError::Internal( + "NeMo Guardrails local Python worker stream closed unexpectedly".to_string(), + )); + }; + if !event.ok { + worker.forget_stream(&stream_id); + return Err(FlowError::Internal(event.error.unwrap_or_else(|| { + "NeMo Guardrails local Python worker stream failed".to_string() + }))); + } + match event.event.as_deref() { + Some("blocked") => { + if let Some(message) = event.message { + let mut guard = blocked.lock().map_err(|err| { + FlowError::Internal(format!("stream block state lock poisoned: {err}")) + })?; + *guard = Some(message); + } + worker.forget_stream(&stream_id); + return Ok(()); + } + Some("done") => { + worker.forget_stream(&stream_id); + return Ok(()); + } + Some(other) => { + worker.forget_stream(&stream_id); + return Err(FlowError::Internal(format!( + "NeMo Guardrails local Python worker returned unknown stream event '{other}'" + ))); + } + None => {} } } } - }) -} - -async fn await_async_iter_value(coro: Py) -> FlowResult>> { - let future = Python::attach(|py| { - pyo3_async_runtimes::tokio::into_future(coro.into_bound(py)) - .map_err(|err| FlowError::Internal(err.to_string())) - })?; - - match future.await { - Ok(result) => Ok(Some(result)), - Err(error) => Python::attach(|py| { - if error.is_instance_of::(py) { - Ok(None) - } else { - Err(FlowError::Internal(error.to_string())) - } - }), - } -} - -fn guardrails_stream_error_message(chunk: &str) -> Option { - let payload: Json = serde_json::from_str(chunk).ok()?; - let error = payload.get("error")?.as_object()?; - if error.get("type").and_then(Json::as_str) != Some("guardrails_violation") { - return None; - } - error - .get("message") - .and_then(Json::as_str) - .filter(|message| !message.is_empty()) - .map(str::to_string) - .or_else(|| Some("Blocked by output rails.".to_string())) -} - -#[pyclass(name = "StringStream")] -struct PyStringStream { - receiver: Arc>>>, -} - -#[pymethods] -impl PyStringStream { - fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { - slf - } - - fn __anext__<'py>(&self, py: Python<'py>) -> PyResult> { - let receiver = Arc::clone(&self.receiver); - pyo3_async_runtimes::tokio::future_into_py(py, async move { - let mut guard = receiver.lock().await; - match guard.recv().await { - Some(Some(value)) => Ok(value), - Some(None) | None => Err(PyErr::new::( - "stream exhausted", - )), - } - }) } } diff --git a/crates/core/tests/unit/plugins/nemo_guardrails/component_tests.rs b/crates/core/tests/unit/plugins/nemo_guardrails/component_tests.rs index 80ae0c2b..8f6b40ff 100644 --- a/crates/core/tests/unit/plugins/nemo_guardrails/component_tests.rs +++ b/crates/core/tests/unit/plugins/nemo_guardrails/component_tests.rs @@ -361,6 +361,7 @@ fn schema_contains_every_supported_nemo_guardrails_option() { "headers", "timeout_millis", "python_module", + "python_executable", "context", "thread_id", "state", @@ -761,7 +762,7 @@ fn invalid_shapes_and_values_are_reported() { "config_yaml": "", "colang_content": "", "codec": "openai_chat", - "local": {"python_module": ""} + "local": {"python_module": "", "python_executable": ""} }))); assert!(local_empty_fields.has_errors()); assert!( @@ -788,6 +789,12 @@ fn invalid_shapes_and_values_are_reported() { .iter() .any(|diag| diag.field.as_deref() == Some("local.python_module")) ); + assert!( + local_empty_fields + .diagnostics + .iter() + .any(|diag| diag.field.as_deref() == Some("local.python_executable")) + ); let local_request_defaults = validate_plugin_config(&plugin_config(json!({ "mode": "local", @@ -975,22 +982,6 @@ fn unknown_fields_follow_policy() { assert!(ignored.diagnostics.is_empty()); } -#[cfg(not(feature = "python"))] -#[test] -fn local_mode_validation_reports_missing_python_feature() { - let diagnostics = validate_plugin_config(&plugin_config(json!({ - "mode": "local", - "codec": "openai_chat", - "config_path": "./rails" - }))) - .unwrap(); - - assert!(diagnostics.diagnostics.iter().any(|diag| { - diag.message - .contains("local mode requires a build with the 'python' feature enabled") - })); -} - #[test] fn enabled_unknown_mode_initialization_fails_fast_when_policy_ignores_validation() { let _guard = crate::plugins::nemo_guardrails::test_mutex() diff --git a/crates/core/tests/unit/plugins/nemo_guardrails/local_python_tests.rs b/crates/core/tests/unit/plugins/nemo_guardrails/local_python_tests.rs index cf48592c..859c297d 100644 --- a/crates/core/tests/unit/plugins/nemo_guardrails/local_python_tests.rs +++ b/crates/core/tests/unit/plugins/nemo_guardrails/local_python_tests.rs @@ -1,46 +1,111 @@ // SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -use std::ffi::CString; +#[cfg(unix)] +use std::fs; +#[cfg(unix)] +use std::os::unix::fs::PermissionsExt; +#[cfg(unix)] +use std::path::{Path, PathBuf}; +#[cfg(unix)] +use std::process::Command; +#[cfg(unix)] +use std::sync::atomic::{AtomicUsize, Ordering}; +#[cfg(unix)] +use std::sync::{Arc, Mutex}; -use pyo3::prelude::*; -use pyo3::types::PyModule; use serde_json::json; use super::*; +#[cfg(unix)] use crate::plugins::nemo_guardrails::component::LocalBackendConfig; -fn local_config(module_name: &str) -> NeMoGuardrailsConfig { - NeMoGuardrailsConfig { - mode: "local".to_string(), - codec: Some("openai_chat".to_string()), - config_yaml: Some("models: []".to_string()), - colang_content: Some("define flow noop\n pass".to_string()), - local: Some(LocalBackendConfig { - python_module: Some(module_name.to_string()), - }), - ..NeMoGuardrailsConfig::default() - } +#[cfg(unix)] +static NEXT_FIXTURE_ID: AtomicUsize = AtomicUsize::new(1); + +#[cfg(unix)] +struct FakeGuardrails { + root: PathBuf, + module_name: String, + python: PathBuf, } -fn install_fake_guardrails(py: Python<'_>, module_name: &str, version: &str, llm_rails_init: &str) { - let code = format!( - r#" -import sys -import types +#[cfg(unix)] +impl FakeGuardrails { + fn new(version: &str) -> Self { + let id = NEXT_FIXTURE_ID.fetch_add(1, Ordering::Relaxed); + let module_name = format!("fake_guardrails_{id}"); + let root = std::env::temp_dir().join(format!( + "nemo_relay_fake_guardrails_{}_{}", + std::process::id(), + id + )); + let package = root.join(&module_name); + fs::create_dir_all(package.join("rails/llm")).unwrap(); + fs::write(package.join("rails/__init__.py"), "").unwrap(); + fs::write(package.join("rails/llm/__init__.py"), "").unwrap(); + fs::write(package.join("rails/llm/options.py"), fake_options_module()).unwrap(); + fs::write(package.join("__init__.py"), fake_root_module(version)).unwrap(); -MODULE_NAME = {module_name:?} + let python = root.join("python-wrapper"); + fs::write( + &python, + format!( + "#!/bin/sh\nPYTHONPATH='{}' exec python3 \"$@\"\n", + shell_single_quote(&root) + ), + ) + .unwrap(); + let mut permissions = fs::metadata(&python).unwrap().permissions(); + permissions.set_mode(0o755); + fs::set_permissions(&python, permissions).unwrap(); -fake_root = types.ModuleType(MODULE_NAME) -fake_root.__version__ = {version:?} -fake_options = types.ModuleType(MODULE_NAME + ".rails.llm.options") + Self { + root, + module_name, + python, + } + } -class Result: - def __init__(self, status, content=None, rail=None): - self.status = status - self.content = content - self.rail = rail + fn config(&self) -> NeMoGuardrailsConfig { + NeMoGuardrailsConfig { + mode: "local".to_string(), + codec: Some("openai_chat".to_string()), + config_yaml: Some("models: []".to_string()), + colang_content: Some("define flow noop\n pass".to_string()), + local: Some(LocalBackendConfig { + python_module: Some(self.module_name.clone()), + python_executable: Some(self.python.to_string_lossy().into_owned()), + }), + ..NeMoGuardrailsConfig::default() + } + } +} +#[cfg(unix)] +impl Drop for FakeGuardrails { + fn drop(&mut self) { + let _ = fs::remove_dir_all(&self.root); + } +} + +#[cfg(unix)] +fn python3_available() -> bool { + Command::new("python3") + .arg("--version") + .output() + .map(|output| output.status.success()) + .unwrap_or(false) +} + +#[cfg(unix)] +fn shell_single_quote(path: &Path) -> String { + path.to_string_lossy().replace('\'', "'\\''") +} + +#[cfg(unix)] +fn fake_options_module() -> &'static str { + r#" class RailType: INPUT = "input" OUTPUT = "output" @@ -49,101 +114,185 @@ class RailStatus: BLOCKED = "blocked" MODIFIED = "modified" PASSED = "passed" +"# +} + +#[cfg(unix)] +fn fake_root_module(version: &str) -> String { + format!( + r#" +import json +import types +from .rails.llm.options import RailStatus + +__version__ = {version:?} + +class Result: + def __init__(self, status, content=None, rail=None): + self.status = status + self.content = content + self.rail = rail class RailsConfig: @staticmethod def from_content(*, colang_content=None, yaml_content=None): - return {{"yaml": yaml_content, "colang": colang_content}} + stream_first = "stream_first_false" not in (yaml_content or "") + flows = [] if "no_stream" in (yaml_content or "") else ["self check output"] + return types.SimpleNamespace( + yaml=yaml_content, + colang=colang_content, + rails=types.SimpleNamespace( + output=types.SimpleNamespace( + flows=flows, + streaming=types.SimpleNamespace(enabled=True, stream_first=stream_first), + ) + ) + ) @staticmethod def from_path(path): - return {{"path": path}} + return types.SimpleNamespace( + path=path, + rails=types.SimpleNamespace( + output=types.SimpleNamespace( + flows=["self check output"], + streaming=types.SimpleNamespace(enabled=True, stream_first=True), + ) + ) + ) class LLMRails: - instances = [] - def __init__(self, config): - LLMRails.instances.append(self) -{llm_rails_init} - -fake_root.Result = Result -fake_root.RailStatus = RailStatus -fake_root.RailsConfig = RailsConfig -fake_root.LLMRails = LLMRails -fake_options.RailType = RailType -fake_options.RailStatus = RailStatus - -sys.modules[MODULE_NAME] = fake_root -sys.modules[MODULE_NAME + ".rails"] = types.ModuleType(MODULE_NAME + ".rails") -sys.modules[MODULE_NAME + ".rails.llm"] = types.ModuleType(MODULE_NAME + ".rails.llm") -sys.modules[MODULE_NAME + ".rails.llm.options"] = fake_options + self.config = config + + async def check_async(self, messages, rail_types=None): + content = " ".join(str(message.get("content", "")) for message in messages) + if "block" in content: + return Result(RailStatus.BLOCKED, "", "policy") + if "modify-tool" in content: + return Result(RailStatus.MODIFIED, '{{"arguments":{{"safe":true}},"result":{{"ok":true}}}}') + if "modify" in content: + return Result(RailStatus.MODIFIED, "rewritten") + return Result(RailStatus.PASSED, "") + + async def stream_async(self, *, messages=None, generator=None, include_metadata=False): + async for text in generator: + if "stream-block" in text: + yield json.dumps({{"error": {{"type": "guardrails_violation", "message": "blocked stream"}}}}) + return + yield json.dumps({{"ok": True}}) "# - ); - let code = CString::new(code).unwrap(); - let file_name = CString::new("fake_guardrails.py").unwrap(); - let module_name = CString::new(format!("{module_name}_installer")).unwrap(); - PyModule::from_code(py, &code, &file_name, &module_name).unwrap(); + ) } -fn py_to_json(obj: &Bound<'_, PyAny>) -> Json { - pythonize::depythonize(obj).unwrap() +#[cfg(unix)] +#[tokio::test(flavor = "current_thread")] +async fn bridge_checks_pass_block_and_modify_outcomes() { + if !python3_available() { + return; + } + + let fixture = FakeGuardrails::new("0.22.0"); + let bridge = LocalGuardrailsBridge::new(&fixture.config()).unwrap(); + + assert!(matches!( + bridge + .check( + vec![json!({"role": "user", "content": "hello"})], + LocalRailKind::Input, + ) + .await + .unwrap(), + LocalCheckOutcome::Passed + )); + + match bridge + .check( + vec![json!({"role": "user", "content": "block this"})], + LocalRailKind::Input, + ) + .await + .unwrap() + { + LocalCheckOutcome::Blocked { rail } => assert_eq!(rail.as_deref(), Some("policy")), + _ => panic!("expected blocked outcome"), + } + + match bridge + .check( + vec![json!({"role": "user", "content": "modify this"})], + LocalRailKind::Input, + ) + .await + .unwrap() + { + LocalCheckOutcome::Modified { content } => assert_eq!(content, "rewritten"), + _ => panic!("expected modified outcome"), + } } +#[cfg(unix)] #[test] -fn bridge_loads_inline_guardrails_config() { - let _guard = crate::plugins::nemo_guardrails::test_mutex() - .lock() - .unwrap_or_else(|err| err.into_inner()); - Python::attach(|py| { - let module_name = "fake_guardrails_bridge_config"; - install_fake_guardrails(py, module_name, "0.22.0", " self.config = config"); - - let bridge = LocalGuardrailsBridge::new(&local_config(module_name)).unwrap(); - let config = bridge.rails.bind(py).getattr("config").unwrap(); - assert_eq!( - py_to_json(&config), - json!({"yaml": "models: []", "colang": "define flow noop\n pass"}) - ); - }); +fn bridge_rejects_unsupported_guardrails_version() { + if !python3_available() { + return; + } + + let fixture = FakeGuardrails::new("0.21.0"); + let error = match LocalGuardrailsBridge::new(&fixture.config()) { + Ok(_) => panic!("expected unsupported version error"), + Err(error) => error, + }; + assert!(error.to_string().contains("nemoguardrails==0.22.0")); } -#[test] -fn bridge_parses_pass_block_and_modify_outcomes() { - let _guard = crate::plugins::nemo_guardrails::test_mutex() - .lock() - .unwrap_or_else(|err| err.into_inner()); - Python::attach(|py| { - let module_name = "fake_guardrails_bridge_outcomes"; - install_fake_guardrails(py, module_name, "0.22.0", " self.config = config"); - let bridge = LocalGuardrailsBridge::new(&local_config(module_name)).unwrap(); - let root = py.import(module_name).unwrap(); - let result_cls = root.getattr("Result").unwrap(); - let status = root.getattr("RailStatus").unwrap(); - - let passed = result_cls - .call1((status.getattr("PASSED").unwrap(),)) - .unwrap(); - assert!(matches!( - bridge.parse_check_result(&passed).unwrap(), - LocalCheckOutcome::Passed - )); +#[cfg(unix)] +#[tokio::test(flavor = "current_thread")] +async fn streaming_support_rejects_stream_first_false() { + if !python3_available() { + return; + } - let blocked = result_cls - .call1((status.getattr("BLOCKED").unwrap(), "stop", "policy")) - .unwrap(); - match bridge.parse_check_result(&blocked).unwrap() { - LocalCheckOutcome::Blocked { rail } => assert_eq!(rail.as_deref(), Some("policy")), - _ => panic!("expected blocked outcome"), - } + let fixture = FakeGuardrails::new("0.22.0"); + let mut config = fixture.config(); + config.config_yaml = Some("stream_first_false".to_string()); + let bridge = LocalGuardrailsBridge::new(&config).unwrap(); - let modified = result_cls - .call1((status.getattr("MODIFIED").unwrap(), "rewritten")) - .unwrap(); - match bridge.parse_check_result(&modified).unwrap() { - LocalCheckOutcome::Modified { content } => assert_eq!(content, "rewritten"), - _ => panic!("expected modified outcome"), - } - }); + assert!(bridge.has_streaming_output_rails().await.unwrap()); + let error = bridge + .ensure_streaming_output_supported() + .await + .unwrap_err(); + assert!(error.to_string().contains("stream_first = true")); +} + +#[cfg(unix)] +#[tokio::test(flavor = "current_thread")] +async fn stream_monitor_records_blocked_message() { + if !python3_available() { + return; + } + + let fixture = FakeGuardrails::new("0.22.0"); + let bridge = LocalGuardrailsBridge::new(&fixture.config()).unwrap(); + let (text_tx, text_rx) = mpsc::channel(8); + let blocked = Arc::new(Mutex::new(None)); + let monitor = bridge + .spawn_stream_monitor( + vec![json!({"role": "user", "content": "hello"})], + text_rx, + Arc::clone(&blocked), + ) + .unwrap(); + + text_tx + .send(Some("stream-block".to_string())) + .await + .unwrap(); + text_tx.send(None).await.unwrap(); + monitor.await.unwrap().unwrap(); + + assert_eq!(blocked.lock().unwrap().as_deref(), Some("blocked stream")); } #[test] @@ -163,34 +312,6 @@ fn modified_tool_payload_rejects_malformed_content() { ); } -#[test] -fn streaming_support_rejects_stream_first_false() { - let _guard = crate::plugins::nemo_guardrails::test_mutex() - .lock() - .unwrap_or_else(|err| err.into_inner()); - Python::attach(|py| { - let module_name = "fake_guardrails_bridge_streaming"; - install_fake_guardrails( - py, - module_name, - "0.22.0", - r#" self.config = types.SimpleNamespace( - rails=types.SimpleNamespace( - output=types.SimpleNamespace( - flows=["self check output"], - streaming=types.SimpleNamespace(enabled=True, stream_first=False), - ) - ) - )"#, - ); - - let bridge = LocalGuardrailsBridge::new(&local_config(module_name)).unwrap(); - assert!(bridge.has_streaming_output_rails().unwrap()); - let error = bridge.ensure_streaming_output_supported().unwrap_err(); - assert!(error.to_string().contains("stream_first = true")); - }); -} - #[test] fn stream_text_extraction_handles_supported_codecs() { assert_eq!( diff --git a/docs/about-nemo-relay/concepts/plugins.mdx b/docs/about-nemo-relay/concepts/plugins.mdx index 065b4b96..7034855a 100644 --- a/docs/about-nemo-relay/concepts/plugins.mdx +++ b/docs/about-nemo-relay/concepts/plugins.mdx @@ -174,7 +174,8 @@ shared plugin system. The current shipped user-facing lanes are: - the remote backend for Guardrails-service integration -- the Python-backed local backend for in-process `nemoguardrails` integration +- the Python-backed local backend for `nemoguardrails` integration through a + subprocess worker Detailed Guardrails plugin configuration belongs in [NeMo Guardrails Configuration](/nemo-guardrails-plugin/configuration). diff --git a/docs/nemo-guardrails-plugin/about.mdx b/docs/nemo-guardrails-plugin/about.mdx index d346fcb0..6e922710 100644 --- a/docs/nemo-guardrails-plugin/about.mdx +++ b/docs/nemo-guardrails-plugin/about.mdx @@ -20,8 +20,8 @@ The plugin is designed around backend modes: - Calls a Guardrails service over HTTP(S), including streaming over the same remote contract. - `local` - - Calls `nemoguardrails` in process through the Python runtime instead of a - separate Guardrails service. + - Calls `nemoguardrails` through a local `python3` worker subprocess instead + of a separate Guardrails service. ## Use This Plugin When @@ -40,7 +40,7 @@ Start here when you need to: The built-in plugin currently exposes two user-facing modes: - `remote` for Guardrails-service integration over HTTP(S) -- `local` for in-process `nemoguardrails` integration through the Python runtime +- `local` for `nemoguardrails` integration through a local Python worker Both modes support managed LLM `input` and `output`. The current mode-specific differences are: @@ -50,9 +50,9 @@ differences are: - `local` supports managed `tool_input` and broader LLM codec coverage, but it does not support `request_defaults` -The `local` backend is a Python-backed runtime feature, not a universal -cross-binding backend. Runtimes that do not install the local backend provider -report `local` mode as unavailable during plugin initialization. +The `local` backend requires a `python3 >= 3.11` executable that can import the +tested `nemoguardrails` dependency. It does not embed Python into the NeMo Relay +binary. ## Managed Surfaces Versus Request Defaults diff --git a/docs/nemo-guardrails-plugin/configuration.mdx b/docs/nemo-guardrails-plugin/configuration.mdx index 16245f24..c4eb3715 100644 --- a/docs/nemo-guardrails-plugin/configuration.mdx +++ b/docs/nemo-guardrails-plugin/configuration.mdx @@ -56,7 +56,7 @@ At least one managed Guardrails surface must be enabled. | Managed `tool_output` | Supported | Supported | | `request_defaults` pass-through | Supported | Not supported | | Codec support | `openai_chat` | `openai_chat`, `openai_responses`, `anthropic_messages` | -| Runtime availability | Any runtime that includes the remote backend | Python-enabled runtimes that can import `nemoguardrails` | +| Runtime availability | Any runtime that includes the remote backend | Runtimes that can start `python3 >= 3.11` with `nemoguardrails==0.22.0` installed | ## Remote Mode @@ -219,18 +219,18 @@ Guardrails activity: ## Local Mode -Use `local` mode when NeMo Relay should call `nemoguardrails` in process -through the Python runtime instead of a separate Guardrails service. +Use `local` mode when NeMo Relay should call `nemoguardrails` through a local +Python worker subprocess instead of a separate Guardrails service. ### Requirements -To use `mode = "local"`, the running Python environment must be able to import -`nemoguardrails==0.22.0`. +To use `mode = "local"`, NeMo Relay must be able to start a `python3 >= 3.11` +executable that can import `nemoguardrails==0.22.0`. -The built-in local backend is installed by the Python binding and runs -Guardrails in process. Use it when the runtime has direct access to the Python -Guardrails dependency and configuration files rather than a separate Guardrails -service. Install the tested local-mode Guardrails dependency with +The built-in local backend starts a Python worker process and sends Guardrails +checks over a JSON-lines protocol. Use it when the runtime has direct access to +the Python Guardrails dependency and configuration files rather than a separate +Guardrails service. Install the tested local-mode Guardrails dependency with `pip install nemoguardrails==0.22.0`. The same ownership boundary still applies: @@ -262,15 +262,18 @@ tool_input = true tool_output = true config_path = "./rails" +[components.config.local] +python_executable = "python3" + [components.config.policy] unknown_component = "warn" unknown_field = "warn" unsupported_value = "error" ``` -This example configures the built-in local mode for a Python-enabled runtime -that can import `nemoguardrails` and read a native Guardrails config directory -from `./rails`. +This example configures the built-in local mode for a runtime that can start +`python3`, import `nemoguardrails`, and read a native Guardrails config +directory from `./rails`. For example, the Guardrails-side policy can look like this: @@ -298,6 +301,8 @@ When `mode = "local"`: - `local.python_module` is optional and only needed when the runtime should import the Guardrails dependency from a custom Python module path instead of the default `nemoguardrails` package. +- `local.python_executable` is optional and defaults to the + `NEMO_RELAY_PYTHON` environment variable when set, otherwise `python3`. ### Codec Boundary From cdbd1a03f57478a3000f9763cb7fc88202e7b30c Mon Sep 17 00:00:00 2001 From: Will Killian Date: Fri, 5 Jun 2026 18:16:20 -0400 Subject: [PATCH 16/20] test: adapt guardrails coverage to subprocess backend Signed-off-by: Will Killian --- .../tests/unit/observability/atof_tests.rs | 2 +- .../nemo_guardrails_coverage_tests.rs | 403 ++++++++++-------- 2 files changed, 216 insertions(+), 189 deletions(-) diff --git a/crates/core/tests/unit/observability/atof_tests.rs b/crates/core/tests/unit/observability/atof_tests.rs index 47ba4dd8..e0bd4f96 100644 --- a/crates/core/tests/unit/observability/atof_tests.rs +++ b/crates/core/tests/unit/observability/atof_tests.rs @@ -300,7 +300,7 @@ fn start_http_capture_server(expected_requests: usize) -> (String, Arc(py: Python<'py>, code: &str) -> Bound<'py, PyModule> { let code = CString::new(code).unwrap(); let file_name = CString::new("nemo_guardrails_coverage_tests.py").unwrap(); @@ -26,26 +31,69 @@ fn python_package_dir() -> PathBuf { PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../python") } -fn fake_guardrails_module_prelude(module_name: &str, python_dir: &str) -> String { +struct FakeGuardrailsPackage { + root: PathBuf, + module_name: String, + python_executable: PathBuf, +} + +impl FakeGuardrailsPackage { + fn new(_py: Python<'_>, module_name: &str, version: &str, implementation: &str) -> Self { + let id = NEXT_FAKE_GUARDRAILS_ID.fetch_add(1, Ordering::Relaxed); + let root = std::env::temp_dir().join(format!( + "nemo_relay_python_fake_guardrails_{}_{}", + process::id(), + id + )); + let package = root.join(module_name); + fs::create_dir_all(package.join("rails/llm")).unwrap(); + fs::write(package.join("rails/__init__.py"), "").unwrap(); + fs::write(package.join("rails/llm/__init__.py"), "").unwrap(); + fs::write(package.join("rails/llm/options.py"), fake_options_module()).unwrap(); + fs::write( + package.join("__init__.py"), + fake_root_module(version, implementation), + ) + .unwrap(); + + let python_executable = write_python_wrapper(&root); + + Self { + root, + module_name: module_name.to_string(), + python_executable, + } + } +} + +impl Drop for FakeGuardrailsPackage { + fn drop(&mut self) { + let _ = fs::remove_dir_all(&self.root); + } +} + +fn fake_guardrails_module_prelude( + module_name: &str, + python_dir: &str, + python_executable: &str, +) -> String { format!( r#" import sys -import types sys.path.insert(0, {python_dir:?}) MODULE_NAME = {module_name:?} +PYTHON_EXECUTABLE = {python_executable:?} +"#, + python_dir = python_dir, + module_name = module_name, + python_executable = python_executable, + ) +} -fake_root = types.ModuleType(MODULE_NAME) -fake_root.__version__ = "0.22.0" -fake_options = types.ModuleType(MODULE_NAME + ".rails.llm.options") - -class Result: - def __init__(self, status, content=None, rail=None): - self.status = status - self.content = content - self.rail = rail - +fn fake_options_module() -> &'static str { + r#" class RailType: INPUT = "input" OUTPUT = "output" @@ -54,6 +102,22 @@ class RailStatus: BLOCKED = "blocked" MODIFIED = "modified" PASSED = "passed" +"# +} + +fn fake_root_module(version: &str, implementation: &str) -> String { + format!( + r#" +import types +from .rails.llm.options import RailStatus + +__version__ = {version:?} + +class Result: + def __init__(self, status, content=None, rail=None): + self.status = status + self.content = content + self.rail = rail class RailsConfig: @staticmethod @@ -63,26 +127,99 @@ class RailsConfig: @staticmethod def from_path(path): return {{"path": path}} -"#, - python_dir = python_dir, - module_name = module_name, + +{implementation} +"# ) } -fn register_fake_guardrails_module_epilogue() -> &'static str { +fn check_sequence_guardrails() -> &'static str { + r#" +class LLMRails: + def __init__(self, config): + self.config = config + self._check_results = [ + Result(RailStatus.MODIFIED, content="sanitized user"), + Result(RailStatus.PASSED), + Result(RailStatus.MODIFIED, content='{"arguments": {"city": "Boston"}}'), + Result(RailStatus.MODIFIED, content='{"result": {"ok": true}}'), + ] + + async def check_async(self, messages, rail_types): + return self._check_results.pop(0) +"# +} + +fn streaming_guardrails() -> &'static str { r#" -fake_root.RailsConfig = RailsConfig -fake_root.LLMRails = LLMRails -fake_options.RailType = RailType -fake_options.RailStatus = RailStatus - -sys.modules[MODULE_NAME] = fake_root -sys.modules[MODULE_NAME + ".rails"] = types.ModuleType(MODULE_NAME + ".rails") -sys.modules[MODULE_NAME + ".rails.llm"] = types.ModuleType(MODULE_NAME + ".rails.llm") -sys.modules[MODULE_NAME + ".rails.llm.options"] = fake_options +class LLMRails: + def __init__(self, config): + yaml = str(config.get("yaml", "")) + stream_first = "stream_first_false" not in yaml + self.config = types.SimpleNamespace( + rails=types.SimpleNamespace( + output=types.SimpleNamespace( + flows=["self check output"], + streaming=types.SimpleNamespace(enabled=True, stream_first=stream_first), + ) + ) + ) + self._stream_calls = 0 + + async def check_async(self, messages, rail_types): + return Result(RailStatus.PASSED) + + def stream_async(self, *, messages=None, generator=None, include_metadata=False): + async def _run(): + self._stream_calls += 1 + call_index = self._stream_calls + async for chunk in generator: + if call_index == 1: + yield chunk + if call_index > 1: + yield '{"error": {"message": "Blocked by output rails: output-policy", "type": "guardrails_violation"}}' + return _run() "# } +#[cfg(unix)] +fn write_python_wrapper(root: &Path) -> PathBuf { + use std::os::unix::fs::PermissionsExt; + + let wrapper = root.join("python-wrapper"); + fs::write( + &wrapper, + format!( + "#!/bin/sh\nPYTHONPATH='{}' exec python3 \"$@\"\n", + shell_single_quote(root) + ), + ) + .unwrap(); + let mut permissions = fs::metadata(&wrapper).unwrap().permissions(); + permissions.set_mode(0o755); + fs::set_permissions(&wrapper, permissions).unwrap(); + wrapper +} + +#[cfg(windows)] +fn write_python_wrapper(root: &Path) -> PathBuf { + let wrapper = root.join("python-wrapper.cmd"); + fs::write( + &wrapper, + format!( + "@echo off\r\nset \"PYTHONPATH={};%PYTHONPATH%\"\r\npython3 %*\r\n", + root.display() + ), + ) + .unwrap(); + wrapper +} + +#[cfg(unix)] +fn shell_single_quote(path: &Path) -> String { + path.to_string_lossy().replace('\'', "'\\''") +} + fn with_isolated_nemo_relay_modules( py: Python<'_>, native_module: &Bound<'_, PyModule>, @@ -218,31 +355,24 @@ fn test_guardrails_local_runtime_registers_and_enforces_llm_and_tool_checks() { crate::_native(&native_module).unwrap(); with_isolated_nemo_relay_modules(py, &native_module, || { + let fake = FakeGuardrailsPackage::new( + py, + "fake_guardrails_local_runtime", + "0.22.0", + check_sequence_guardrails(), + ); let python_dir = python_package_dir(); let prelude = fake_guardrails_module_prelude( - "fake_guardrails_local_runtime", + &fake.module_name, &python_dir.display().to_string(), + &fake.python_executable.display().to_string(), ); - let epilogue = register_fake_guardrails_module_epilogue(); let module = load_module( py, &format!( r#" {prelude} -check_results = [] -check_calls = [] - -class LLMRails: - def __init__(self, config): - self.config = config - - async def check_async(self, messages, rail_types): - check_calls.append((messages, rail_types)) - return check_results.pop(0) - -{epilogue} - import nemo_relay async def run_case(): @@ -263,7 +393,10 @@ async def run_case(): "output": True, "tool_input": True, "tool_output": True, - "local": {{"python_module": MODULE_NAME}}, + "local": {{ + "python_module": MODULE_NAME, + "python_executable": PYTHON_EXECUTABLE, + }}, }}, }} ], @@ -287,12 +420,6 @@ async def run_case(): "model": "gpt-4o-mini", }} - check_results.extend( - [ - Result(RailStatus.MODIFIED, content="sanitized user"), - Result(RailStatus.PASSED), - ] - ) llm_result = await nemo_relay.llm.execute( "demo", request, @@ -306,12 +433,6 @@ async def run_case(): seen_tool_args.append(args) return {{"raw": True}} - check_results.extend( - [ - Result(RailStatus.MODIFIED, content='{{"arguments": {{"city": "Boston"}}}}'), - Result(RailStatus.MODIFIED, content='{{"result": {{"ok": true}}}}'), - ] - ) tool_result = await nemo_relay.tools.execute("weather_lookup", {{"city": "Phoenix"}}, next_tool) return {{ @@ -319,11 +440,9 @@ async def run_case(): "tool_result": tool_result, "seen_request_messages": seen_request_messages, "seen_tool_args": seen_tool_args, - "check_calls": check_calls, }} "#, prelude = prelude, - epilogue = epilogue, ), ); @@ -348,33 +467,6 @@ async def run_case(): result_json["llm_result"]["choices"][0]["message"]["content"], json!("safe reply") ); - assert_eq!( - result_json["check_calls"], - json!([ - [ - [{"role": "user", "content": "unsafe"}], - ["input"] - ], - [ - [ - {"role": "user", "content": "sanitized user"}, - {"role": "assistant", "content": "safe reply"} - ], - ["output"] - ], - [ - [{"role": "user", "content": "{\"arguments\":{\"city\":\"Phoenix\"},\"tool_name\":\"weather_lookup\"}"}], - ["input"] - ], - [ - [ - {"role": "user", "content": "{\"arguments\":{\"city\":\"Boston\"},\"tool_name\":\"weather_lookup\"}"}, - {"role": "assistant", "content": "{\"arguments\":{\"city\":\"Boston\"},\"result\":{\"raw\":true},\"tool_name\":\"weather_lookup\"}"} - ], - ["output"] - ] - ]) - ); }); }); @@ -391,29 +483,24 @@ fn test_guardrails_local_runtime_rejects_unsupported_nemoguardrails_version() { crate::_native(&native_module).unwrap(); with_isolated_nemo_relay_modules(py, &native_module, || { + let fake = FakeGuardrailsPackage::new( + py, + "fake_guardrails_bad_version", + "0.21.0", + check_sequence_guardrails(), + ); let python_dir = python_package_dir(); let prelude = fake_guardrails_module_prelude( - "fake_guardrails_bad_version", + &fake.module_name, &python_dir.display().to_string(), + &fake.python_executable.display().to_string(), ); - let epilogue = register_fake_guardrails_module_epilogue(); let module = load_module( py, &format!( r#" {prelude} -fake_root.__version__ = "0.21.0" - -class LLMRails: - def __init__(self, config): - self.config = config - - async def check_async(self, messages, rail_types): - return Result(RailStatus.PASSED) - -{epilogue} - import nemo_relay async def run_case(): @@ -429,7 +516,10 @@ async def run_case(): "codec": "openai_chat", "config_yaml": "models: []", "input": True, - "local": {{"python_module": MODULE_NAME}}, + "local": {{ + "python_module": MODULE_NAME, + "python_executable": PYTHON_EXECUTABLE, + }}, }}, }} ], @@ -437,7 +527,6 @@ async def run_case(): ) "#, prelude = prelude, - epilogue = epilogue, ), ); @@ -470,51 +559,29 @@ fn test_guardrails_local_runtime_enforces_streamed_output_rails() { crate::_native(&native_module).unwrap(); with_isolated_nemo_relay_modules(py, &native_module, || { + let fake = FakeGuardrailsPackage::new( + py, + "fake_guardrails_streaming", + "0.22.0", + streaming_guardrails(), + ); let python_dir = python_package_dir(); let prelude = fake_guardrails_module_prelude( - "fake_guardrails_streaming", + &fake.module_name, &python_dir.display().to_string(), + &fake.python_executable.display().to_string(), ); - let epilogue = register_fake_guardrails_module_epilogue(); let module = load_module( py, &format!( r#" {prelude} -stream_results = [] event_log = [] -class LLMRails: - def __init__(self, config): - self.config = types.SimpleNamespace( - rails=types.SimpleNamespace( - output=types.SimpleNamespace( - flows=["self check output"], - streaming=types.SimpleNamespace(enabled=True, stream_first=True), - ) - ) - ) - - async def check_async(self, messages, rail_types): - return Result(RailStatus.PASSED) - - def stream_async(self, *, messages=None, generator=None, include_metadata=False): - async def _run(): - outcome = stream_results.pop(0) - async for chunk in generator: - event_log.append(f"guardrails-sees:{{chunk}}") - if outcome == "pass": - yield chunk - if outcome == "block": - yield '{{"error": {{"message": "Blocked by output rails: output-policy", "type": "guardrails_violation"}}}}' - return _run() - -{epilogue} - import nemo_relay -def plugin_config(): +def plugin_config(config_yaml="models: []"): return {{ "version": 1, "components": [ @@ -524,10 +591,13 @@ def plugin_config(): "config": {{ "mode": "local", "codec": "openai_chat", - "config_yaml": "models: []", + "config_yaml": config_yaml, "input": False, "output": True, - "local": {{"python_module": MODULE_NAME}}, + "local": {{ + "python_module": MODULE_NAME, + "python_executable": PYTHON_EXECUTABLE, + }}, }}, }} ], @@ -572,10 +642,8 @@ async def run_case(): }}, ) - stream_results.append("pass") allowed_chunks = await run_stream(request) - stream_results.append("block") try: await run_stream(request) except RuntimeError as error: @@ -584,19 +652,7 @@ async def run_case(): raise AssertionError("expected streamed output block") nemo_relay.plugin.clear() - fake_root.LLMRails = lambda config: types.SimpleNamespace( - config=types.SimpleNamespace( - rails=types.SimpleNamespace( - output=types.SimpleNamespace( - flows=["self check output"], - streaming=types.SimpleNamespace(enabled=True, stream_first=False), - ) - ) - ), - check_async=LLMRails(config).check_async, - stream_async=LLMRails(config).stream_async, - ) - await nemo_relay.plugin.initialize(plugin_config()) + await nemo_relay.plugin.initialize(plugin_config("stream_first_false")) try: await run_stream(request) except RuntimeError as error: @@ -612,7 +668,6 @@ async def run_case(): }} "#, prelude = prelude, - epilogue = epilogue, ), ); @@ -631,14 +686,7 @@ async def run_case(): ]) ); let event_log = result["event_log"].as_array().unwrap(); - for expected in [ - "source:hello", - "source:world", - "yield:hello", - "yield:world", - "guardrails-sees:hello", - "guardrails-sees:world", - ] { + for expected in ["source:hello", "source:world", "yield:hello", "yield:world"] { assert!( event_log.iter().any(|event| event == expected), "missing event {expected}: {event_log:?}" @@ -660,19 +708,10 @@ async def run_case(): .iter() .position(|event| event == "yield:world") .unwrap(); - let guardrails_hello = event_log - .iter() - .position(|event| event == "guardrails-sees:hello") - .unwrap(); - let guardrails_world = event_log - .iter() - .position(|event| event == "guardrails-sees:world") - .unwrap(); assert!(source_hello < source_world); assert!(source_hello < yield_hello); assert!(source_world < yield_world); assert!(yield_hello < yield_world); - assert!(guardrails_hello < guardrails_world); assert!( result["blocked"] .as_str() @@ -701,29 +740,24 @@ fn test_local_guardrails_provider_initializes_and_enforces_managed_core_calls() crate::_native(&native_module).unwrap(); with_isolated_nemo_relay_modules(py, &native_module, || { + let fake = FakeGuardrailsPackage::new( + py, + "fake_guardrails_local_e2e", + "0.22.0", + check_sequence_guardrails(), + ); let python_dir = python_package_dir(); let prelude = fake_guardrails_module_prelude( - "fake_guardrails_local_e2e", + &fake.module_name, &python_dir.display().to_string(), + &fake.python_executable.display().to_string(), ); - let epilogue = register_fake_guardrails_module_epilogue(); let module = load_module( py, &format!( r#" {prelude} -check_results = [] - -class LLMRails: - def __init__(self, config): - self.config = config - - async def check_async(self, messages, rail_types): - return check_results.pop(0) - -{epilogue} - import nemo_relay async def run_case(): @@ -745,22 +779,16 @@ async def run_case(): "output": True, "tool_input": True, "tool_output": True, - "local": {{"python_module": MODULE_NAME}}, + "local": {{ + "python_module": MODULE_NAME, + "python_executable": PYTHON_EXECUTABLE, + }}, }}, }} ], }} ) - check_results.extend( - [ - Result(RailStatus.MODIFIED, content="sanitized user"), - Result(RailStatus.PASSED), - Result(RailStatus.MODIFIED, content='{{"arguments": {{"city": "Boston"}}}}'), - Result(RailStatus.MODIFIED, content='{{"result": {{"ok": true}}}}'), - ] - ) - request = nemo_relay.LLMRequest( {{}}, {{ @@ -799,7 +827,6 @@ async def run_case(): }} "#, prelude = prelude, - epilogue = epilogue, ), ); let result_json = with_event_loop(py, |event_loop| { From dfee75bbe939e77da3ba44767f2297af7cd0ca55 Mon Sep 17 00:00:00 2001 From: Will Killian Date: Fri, 5 Jun 2026 18:44:55 -0400 Subject: [PATCH 17/20] chore: restore cargo manifests Signed-off-by: Will Killian --- Cargo.toml | 3 --- crates/core/Cargo.toml | 2 -- crates/python/Cargo.toml | 8 ++++---- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e45741b2..bf54af16 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,9 +26,6 @@ nemo-relay = { version = "0.4.0", path = "crates/core", default-features = false nemo-relay-adaptive = { version = "0.4.0", path = "crates/adaptive" } nemo-relay-ffi = { version = "0.4.0", path = "crates/ffi" } nemo-relay-cli = { version = "0.4.0", path = "crates/cli" } -pyo3 = "0.28.2" -pyo3-async-runtimes = "0.28.0" -pythonize = "0.28.0" uuid = "=1.18.1" [workspace.lints.rust] diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 7da6f8b3..72139869 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -19,7 +19,6 @@ default = [ "openinference", "guardrails-remote", "object-store", - "python", ] atof-streaming = [ "dep:futures-util", @@ -36,7 +35,6 @@ guardrails-remote = [ "dep:reqwest", "dep:rustls", ] -python = [] object-store = [ "dep:object_store", "dep:reqwest", diff --git a/crates/python/Cargo.toml b/crates/python/Cargo.toml index 5db1264d..4861d109 100644 --- a/crates/python/Cargo.toml +++ b/crates/python/Cargo.toml @@ -18,11 +18,11 @@ name = "_native" crate-type = ["cdylib", "rlib"] [dependencies] -nemo-relay = { workspace = true, features = ["atof-streaming", "otel", "openinference", "python"] } +nemo-relay = { workspace = true, features = ["atof-streaming", "otel", "openinference"] } nemo-relay-adaptive = { workspace = true, features = ["redis-backend"] } -pyo3 = { workspace = true, features = ["abi3", "abi3-py311", "experimental-inspect", "macros"] } -pyo3-async-runtimes = { workspace = true, features = ["tokio-runtime"] } -pythonize = { workspace = true } +pyo3 = { version = "0.28.2", features = ["abi3", "abi3-py311", "experimental-inspect", "macros"] } +pyo3-async-runtimes = { version = "0.28.0", features = ["tokio-runtime"] } +pythonize = "0.28.0" serde_json = "1" serde = "1" uuid = { workspace = true, features = ["v7"] } From 527b1f8771170f946872ef553374accc6105711e Mon Sep 17 00:00:00 2001 From: Will Killian Date: Fri, 5 Jun 2026 19:45:23 -0400 Subject: [PATCH 18/20] fix: address local guardrails review feedback Signed-off-by: Will Killian --- .../plugins/nemo_guardrails/local_worker.py | 3 +- .../src/plugins/nemo_guardrails/python.rs | 33 +++-- crates/python/src/lib.rs | 14 +- .../nemo_guardrails_coverage_tests.rs | 132 +++++++++++++----- docs/nemo-guardrails-plugin/about.mdx | 5 +- 5 files changed, 121 insertions(+), 66 deletions(-) diff --git a/crates/core/src/plugins/nemo_guardrails/local_worker.py b/crates/core/src/plugins/nemo_guardrails/local_worker.py index 40fac320..90d331aa 100644 --- a/crates/core/src/plugins/nemo_guardrails/local_worker.py +++ b/crates/core/src/plugins/nemo_guardrails/local_worker.py @@ -9,6 +9,7 @@ DEFAULT_MODULE_NAME = "nemoguardrails" SUPPORTED_NEMOGUARDRAILS_VERSION = "0.22.0" +STREAM_QUEUE_MAXSIZE = 32 _PROTOCOL_STDOUT = sys.stdout sys.stdout = sys.stderr @@ -232,7 +233,7 @@ async def handle_message(message): worker.ensure_streaming_output_supported() response(request_id) elif command == "stream_start": - queue = asyncio.Queue() + queue = asyncio.Queue(maxsize=STREAM_QUEUE_MAXSIZE) streams[request_id] = queue asyncio.create_task(worker.monitor_stream(request_id, message.get("messages") or [], queue, streams)) elif command == "stream_text": diff --git a/crates/core/src/plugins/nemo_guardrails/python.rs b/crates/core/src/plugins/nemo_guardrails/python.rs index 7d20c7eb..6b7d76e6 100644 --- a/crates/core/src/plugins/nemo_guardrails/python.rs +++ b/crates/core/src/plugins/nemo_guardrails/python.rs @@ -895,19 +895,22 @@ async fn forward_guarded_provider_stream( let text = extract_stream_text(codec, &chunk); - if chunk_tx.send(Ok(chunk)).await.is_err() { - let _ = text_tx.send(None).await; - let _ = monitor.await; - return; - } - tokio::task::yield_now().await; - if let Some(text) = text { - let _ = text_tx.send(Some(text)).await; + if text_tx.send(Some(text)).await.is_err() { + finish_stream_monitor(monitor, &chunk_tx, &blocked).await; + return; + } + tokio::task::yield_now().await; + + if let Some(message) = blocked_message(&blocked) { + let _ = chunk_tx.send(Err(streaming_output_blocked(message))).await; + let _ = text_tx.send(None).await; + let _ = monitor.await; + return; + } } - if let Some(message) = blocked_message(&blocked) { - let _ = chunk_tx.send(Err(streaming_output_blocked(message))).await; + if chunk_tx.send(Ok(chunk)).await.is_err() { let _ = text_tx.send(None).await; let _ = monitor.await; return; @@ -915,6 +918,14 @@ async fn forward_guarded_provider_stream( } let _ = text_tx.send(None).await; + finish_stream_monitor(monitor, &chunk_tx, &blocked).await; +} + +async fn finish_stream_monitor( + monitor: JoinHandle>, + chunk_tx: &mpsc::Sender>, + blocked: &Arc>>, +) { match monitor.await { Ok(Ok(())) => {} Ok(Err(err)) => { @@ -931,7 +942,7 @@ async fn forward_guarded_provider_stream( } } - if let Some(message) = blocked_message(&blocked) { + if let Some(message) = blocked_message(blocked) { let _ = chunk_tx.send(Err(streaming_output_blocked(message))).await; } } diff --git a/crates/python/src/lib.rs b/crates/python/src/lib.rs index b3377d28..0c7c1998 100644 --- a/crates/python/src/lib.rs +++ b/crates/python/src/lib.rs @@ -23,7 +23,7 @@ use nemo_relay::shared_runtime::initialize_shared_runtime_binding; use nemo_relay_adaptive::plugin_component::register_adaptive_component; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyModule}; +use pyo3::types::PyModule; mod convert; #[doc(hidden)] @@ -57,18 +57,6 @@ fn _native(m: &Bound<'_, PyModule>) -> PyResult<()> { py_api::register(m)?; py_plugin::register(m)?; py_adaptive::register(m)?; - install_native_module_alias(m)?; - Ok(()) -} - -fn install_native_module_alias(m: &Bound<'_, PyModule>) -> PyResult<()> { - let py = m.py(); - let sys = py.import("sys")?; - let modules = sys.getattr("modules")?.cast_into::()?; - modules.set_item("nemo_relay._native", m)?; - if let Ok(package) = py.import("nemo_relay") { - let _ = package.setattr("_native", m); - } Ok(()) } diff --git a/crates/python/tests/coverage/nemo_guardrails_coverage_tests.rs b/crates/python/tests/coverage/nemo_guardrails_coverage_tests.rs index 6a6af0e5..00805085 100644 --- a/crates/python/tests/coverage/nemo_guardrails_coverage_tests.rs +++ b/crates/python/tests/coverage/nemo_guardrails_coverage_tests.rs @@ -7,8 +7,11 @@ use std::ffi::CString; use std::fs; use std::panic::{AssertUnwindSafe, catch_unwind}; use std::path::{Path, PathBuf}; -use std::process; -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::process::{self, Command, Stdio}; +use std::sync::{ + Mutex, + atomic::{AtomicUsize, Ordering}, +}; use nemo_relay::api::runtime::{NemoRelayContextState, global_context}; use nemo_relay::plugin::{ @@ -19,6 +22,7 @@ use pyo3::types::{PyDict, PyModule}; use serde_json::json; static NEXT_FAKE_GUARDRAILS_ID: AtomicUsize = AtomicUsize::new(1); +static SERIAL_TEST_MUTEX: Mutex<()> = Mutex::new(()); fn load_module<'py>(py: Python<'py>, code: &str) -> Bound<'py, PyModule> { let code = CString::new(code).unwrap(); @@ -38,7 +42,7 @@ struct FakeGuardrailsPackage { } impl FakeGuardrailsPackage { - fn new(_py: Python<'_>, module_name: &str, version: &str, implementation: &str) -> Self { + fn new(py: Python<'_>, module_name: &str, version: &str, implementation: &str) -> Self { let id = NEXT_FAKE_GUARDRAILS_ID.fetch_add(1, Ordering::Relaxed); let root = std::env::temp_dir().join(format!( "nemo_relay_python_fake_guardrails_{}_{}", @@ -56,7 +60,7 @@ impl FakeGuardrailsPackage { ) .unwrap(); - let python_executable = write_python_wrapper(&root); + let python_executable = write_python_wrapper(&root, &python_executable_for_worker(py)); Self { root, @@ -66,6 +70,27 @@ impl FakeGuardrailsPackage { } } +fn python_executable_for_worker(py: Python<'_>) -> String { + let executable = py + .import("sys") + .and_then(|sys| sys.getattr("executable")) + .and_then(|executable| executable.extract::()) + .unwrap_or_else(|_| "python3".to_string()); + if Command::new(&executable) + .arg("-c") + .arg("import sys") + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .status() + .map(|status| status.success()) + .unwrap_or(false) + { + executable + } else { + "python3".to_string() + } +} + impl Drop for FakeGuardrailsPackage { fn drop(&mut self) { let _ = fs::remove_dir_all(&self.root); @@ -135,6 +160,23 @@ class RailsConfig: fn check_sequence_guardrails() -> &'static str { r#" +class LLMRails: + def __init__(self, config): + self.config = config + self._check_results = [ + Result(RailStatus.MODIFIED, content="sanitized user"), + Result(RailStatus.BLOCKED, rail="output-policy"), + Result(RailStatus.MODIFIED, content='{"arguments": {"city": "Boston"}}'), + Result(RailStatus.MODIFIED, content='{"result": {"ok": true}}'), + ] + + async def check_async(self, messages, rail_types): + return self._check_results.pop(0) +"# +} + +fn tool_sequence_guardrails() -> &'static str { + r#" class LLMRails: def __init__(self, config): self.config = config @@ -183,15 +225,16 @@ class LLMRails: } #[cfg(unix)] -fn write_python_wrapper(root: &Path) -> PathBuf { +fn write_python_wrapper(root: &Path, python_executable: &str) -> PathBuf { use std::os::unix::fs::PermissionsExt; let wrapper = root.join("python-wrapper"); fs::write( &wrapper, format!( - "#!/bin/sh\nPYTHONPATH='{}' exec python3 \"$@\"\n", - shell_single_quote(root) + "#!/bin/sh\nPYTHONPATH='{}' exec '{}' \"$@\"\n", + shell_single_quote(root), + shell_single_quote(Path::new(python_executable)) ), ) .unwrap(); @@ -202,13 +245,14 @@ fn write_python_wrapper(root: &Path) -> PathBuf { } #[cfg(windows)] -fn write_python_wrapper(root: &Path) -> PathBuf { +fn write_python_wrapper(root: &Path, python_executable: &str) -> PathBuf { let wrapper = root.join("python-wrapper.cmd"); fs::write( &wrapper, format!( - "@echo off\r\nset \"PYTHONPATH={};%PYTHONPATH%\"\r\npython3 %*\r\n", - root.display() + "@echo off\r\nset \"PYTHONPATH={};%PYTHONPATH%\"\r\n\"{}\" %*\r\n", + root.display(), + python_executable.replace('"', "\"\"") ), ) .unwrap(); @@ -225,6 +269,7 @@ fn with_isolated_nemo_relay_modules( native_module: &Bound<'_, PyModule>, f: impl FnOnce() -> T, ) -> T { + let _serial_guard = SERIAL_TEST_MUTEX.lock().unwrap(); let sys = py.import("sys").unwrap(); let modules = sys .getattr("modules") @@ -254,6 +299,7 @@ fn with_isolated_nemo_relay_modules( for (name, module) in saved_modules { modules.set_item(name, module).unwrap(); } + reset_runtime_state(); match result { Ok(value) => value, @@ -290,12 +336,19 @@ fn with_event_loop(py: Python<'_>, f: impl FnOnce(Bound<'_, PyAny>) -> T) -> asyncio .call_method1("set_event_loop", (&event_loop,)) .unwrap(); - let result = f(event_loop.clone().into_any()); + let result = catch_unwind(AssertUnwindSafe(|| f(event_loop.clone().into_any()))); asyncio .call_method1("set_event_loop", (py.None(),)) .unwrap(); event_loop.call_method0("close").unwrap(); - result + #[cfg(windows)] + asyncio + .call_method1("set_event_loop_policy", (py.None(),)) + .unwrap(); + match result { + Ok(value) => value, + Err(payload) => std::panic::resume_unwind(payload), + } } fn reset_runtime_state() { @@ -307,6 +360,7 @@ fn reset_runtime_state() { #[test] fn test_native_pymodule_entrypoint_registers_bindings_without_local_provider_install() { let _python = crate::test_support::init_python_test(); + let _serial_guard = SERIAL_TEST_MUTEX.lock().unwrap(); reset_runtime_state(); Python::attach(|py| { let module = PyModule::new(py, "_native_guardrails_provider").unwrap(); @@ -346,7 +400,7 @@ fn test_native_pymodule_entrypoint_registers_bindings_without_local_provider_ins } #[test] -fn test_guardrails_local_runtime_registers_and_enforces_llm_and_tool_checks() { +fn test_guardrails_local_runtime_enforces_llm_input_and_output_checks() { let _python = crate::test_support::init_python_test(); reset_runtime_state(); @@ -420,26 +474,21 @@ async def run_case(): "model": "gpt-4o-mini", }} - llm_result = await nemo_relay.llm.execute( - "demo", - request, - next_call, - response_codec=nemo_relay.codecs.OpenAIChatCodec(), - ) - - seen_tool_args = [] - - async def next_tool(args): - seen_tool_args.append(args) - return {{"raw": True}} - - tool_result = await nemo_relay.tools.execute("weather_lookup", {{"city": "Phoenix"}}, next_tool) + try: + await nemo_relay.llm.execute( + "demo", + request, + next_call, + response_codec=nemo_relay.codecs.OpenAIChatCodec(), + ) + except RuntimeError as error: + llm_error = str(error) + else: + raise AssertionError("expected output rail block") return {{ - "llm_result": llm_result, - "tool_result": tool_result, + "llm_error": llm_error, "seen_request_messages": seen_request_messages, - "seen_tool_args": seen_tool_args, }} "#, prelude = prelude, @@ -458,14 +507,21 @@ async def run_case(): result_json["seen_request_messages"][0], json!("sanitized user") ); - assert_eq!(result_json["tool_result"], json!({ "ok": true })); - assert_eq!( - result_json["seen_tool_args"][0], - json!({ "city": "Boston" }) + assert!( + result_json["llm_error"] + .as_str() + .unwrap() + .contains("output rail blocked the LLM call"), + "unexpected error: {}", + result_json["llm_error"] ); - assert_eq!( - result_json["llm_result"]["choices"][0]["message"]["content"], - json!("safe reply") + assert!( + result_json["llm_error"] + .as_str() + .unwrap() + .contains("output-policy"), + "unexpected error: {}", + result_json["llm_error"] ); }); }); @@ -744,7 +800,7 @@ fn test_local_guardrails_provider_initializes_and_enforces_managed_core_calls() py, "fake_guardrails_local_e2e", "0.22.0", - check_sequence_guardrails(), + tool_sequence_guardrails(), ); let python_dir = python_package_dir(); let prelude = fake_guardrails_module_prelude( diff --git a/docs/nemo-guardrails-plugin/about.mdx b/docs/nemo-guardrails-plugin/about.mdx index 6e922710..42c3c739 100644 --- a/docs/nemo-guardrails-plugin/about.mdx +++ b/docs/nemo-guardrails-plugin/about.mdx @@ -50,9 +50,8 @@ differences are: - `local` supports managed `tool_input` and broader LLM codec coverage, but it does not support `request_defaults` -The `local` backend requires a `python3 >= 3.11` executable that can import the -tested `nemoguardrails` dependency. It does not embed Python into the NeMo Relay -binary. +The `local` backend requires a `python3 >= 3.11` executable that can import +`nemoguardrails==0.22.0`. It does not embed Python into the NeMo Relay binary. ## Managed Surfaces Versus Request Defaults From 6d68272530f735660d74a70e170b258cf07a5316 Mon Sep 17 00:00:00 2001 From: Will Killian Date: Fri, 5 Jun 2026 20:12:59 -0400 Subject: [PATCH 19/20] fix: harden local guardrails worker flow Signed-off-by: Will Killian --- .../plugins/nemo_guardrails/local_worker.py | 5 +- .../src/plugins/nemo_guardrails/python.rs | 140 ++++++++++++++---- .../nemo_guardrails/local_python_tests.rs | 60 +++++++- docs/nemo-guardrails-plugin/configuration.mdx | 19 ++- 4 files changed, 185 insertions(+), 39 deletions(-) diff --git a/crates/core/src/plugins/nemo_guardrails/local_worker.py b/crates/core/src/plugins/nemo_guardrails/local_worker.py index 90d331aa..bd519655 100644 --- a/crates/core/src/plugins/nemo_guardrails/local_worker.py +++ b/crates/core/src/plugins/nemo_guardrails/local_worker.py @@ -263,7 +263,10 @@ async def main(): except Exception: traceback.print_exc(file=sys.stderr) continue - asyncio.create_task(handle_message(message)) + if str(message.get("command", "")).startswith("stream_"): + await handle_message(message) + else: + asyncio.create_task(handle_message(message)) asyncio.run(main()) diff --git a/crates/core/src/plugins/nemo_guardrails/python.rs b/crates/core/src/plugins/nemo_guardrails/python.rs index 6b7d76e6..86cd2773 100644 --- a/crates/core/src/plugins/nemo_guardrails/python.rs +++ b/crates/core/src/plugins/nemo_guardrails/python.rs @@ -33,6 +33,7 @@ use super::NeMoGuardrailsConfig; const DEFAULT_PYTHON_EXECUTABLE: &str = "python3"; const PYTHON_EXECUTABLE_ENV: &str = "NEMO_RELAY_PYTHON"; const WORKER_INIT_TIMEOUT: Duration = Duration::from_secs(30); +const WORKER_RPC_TIMEOUT: Duration = Duration::from_secs(30); const WORKER_SCRIPT: &str = include_str!("local_worker.py"); pub(super) fn register_local_backend( @@ -358,7 +359,7 @@ impl LocalGuardrailsBridge { } struct LocalGuardrailsWorker { - stdin: Mutex, + writer: Mutex>, child: Mutex, waiters: Arc>>>, stream_events: Arc>>>, @@ -394,7 +395,7 @@ impl LocalGuardrailsWorker { })?; let worker = Arc::new(Self { - stdin: Mutex::new(stdin), + writer: Mutex::new(Some(WorkerCommandWriter::spawn(stdin))), child: Mutex::new(child), waiters: Arc::new(Mutex::new(HashMap::new())), stream_events: Arc::new(Mutex::new(HashMap::new())), @@ -461,10 +462,21 @@ impl LocalGuardrailsWorker { async fn request(&self, mut payload: Json) -> FlowResult { let receiver = self.send_request(&mut payload)?; - let envelope = tokio::task::spawn_blocking(move || receiver.recv()) - .await - .map_err(|err| FlowError::Internal(format!("worker response task failed: {err}")))? - .map_err(|err| FlowError::Internal(format!("worker response channel closed: {err}")))?; + let response_task = tokio::task::spawn_blocking(move || receiver.recv()); + let envelope = match tokio::time::timeout(WORKER_RPC_TIMEOUT, response_task).await { + Ok(result) => result + .map_err(|err| FlowError::Internal(format!("worker response task failed: {err}")))? + .map_err(|err| { + FlowError::Internal(format!("worker response channel closed: {err}")) + })?, + Err(_) => { + self.shutdown(); + return Err(FlowError::Internal(format!( + "worker request timed out after {} seconds", + WORKER_RPC_TIMEOUT.as_secs() + ))); + } + }; worker_result(envelope) } @@ -542,22 +554,86 @@ impl LocalGuardrailsWorker { let line = serde_json::to_string(payload).map_err(|err| { FlowError::Internal(format!("failed to serialize worker command: {err}")) })?; - let mut stdin = self - .stdin + let writer = self + .writer .lock() - .map_err(|err| FlowError::Internal(format!("worker stdin lock poisoned: {err}")))?; - writeln!(stdin, "{line}") - .and_then(|_| stdin.flush()) - .map_err(|err| FlowError::Internal(format!("failed to write worker command: {err}"))) + .map_err(|err| FlowError::Internal(format!("worker writer lock poisoned: {err}")))?; + writer + .as_ref() + .ok_or_else(|| FlowError::Internal("worker command writer is closed".to_string()))? + .send(line) } -} -impl Drop for LocalGuardrailsWorker { - fn drop(&mut self) { + fn shutdown(&self) { + let writer = self.writer.lock().ok().and_then(|mut writer| writer.take()); if let Ok(mut child) = self.child.lock() { let _ = child.kill(); let _ = child.wait(); } + if let Some(writer) = writer { + writer.join(); + } + } +} + +impl Drop for LocalGuardrailsWorker { + fn drop(&mut self) { + self.shutdown(); + } +} + +struct WorkerCommandWriter { + sender: std_mpsc::Sender, + error: Arc>>, + handle: Option>, +} + +impl WorkerCommandWriter { + fn spawn(mut stdin: ChildStdin) -> Self { + let (sender, receiver) = std_mpsc::channel::(); + let error = Arc::new(Mutex::new(None)); + let writer_error = Arc::clone(&error); + let handle = thread::spawn(move || { + for line in receiver { + if let Err(err) = writeln!(stdin, "{line}").and_then(|_| stdin.flush()) { + if let Ok(mut stored_error) = writer_error.lock() { + *stored_error = Some(err.to_string()); + } + return; + } + } + let _ = stdin.flush(); + }); + Self { + sender, + error, + handle: Some(handle), + } + } + + fn send(&self, line: String) -> FlowResult<()> { + if let Some(error) = self + .error + .lock() + .map_err(|err| { + FlowError::Internal(format!("worker writer error lock poisoned: {err}")) + })? + .clone() + { + return Err(FlowError::Internal(format!( + "failed to write worker command: {error}" + ))); + } + self.sender.send(line).map_err(|err| { + FlowError::Internal(format!("worker command writer channel closed: {err}")) + }) + } + + fn join(mut self) { + drop(self.sender); + if let Some(handle) = self.handle.take() { + let _ = handle.join(); + } } } @@ -684,7 +760,10 @@ fn parse_check_result(result: Json) -> FlowResult { "modified" => Ok(LocalCheckOutcome::Modified { content: result.content.unwrap_or_default(), }), - _ => Ok(LocalCheckOutcome::Passed), + "passed" => Ok(LocalCheckOutcome::Passed), + unexpected => Err(FlowError::Internal(format!( + "unexpected worker check status: {unexpected}" + ))), } } @@ -875,6 +954,7 @@ async fn forward_guarded_provider_stream( monitor: JoinHandle>, blocked: Arc>>, ) { + let mut buffered_chunks = Vec::new(); while let Some(item) = provider_stream.next().await { let chunk = match item { Ok(chunk) => chunk, @@ -897,10 +977,9 @@ async fn forward_guarded_provider_stream( if let Some(text) = text { if text_tx.send(Some(text)).await.is_err() { - finish_stream_monitor(monitor, &chunk_tx, &blocked).await; + send_stream_monitor_error(monitor, &chunk_tx, &blocked).await; return; } - tokio::task::yield_now().await; if let Some(message) = blocked_message(&blocked) { let _ = chunk_tx.send(Err(streaming_output_blocked(message))).await; @@ -910,27 +989,31 @@ async fn forward_guarded_provider_stream( } } + buffered_chunks.push(chunk); + } + + let _ = text_tx.send(None).await; + if send_stream_monitor_error(monitor, &chunk_tx, &blocked).await { + return; + } + + for chunk in buffered_chunks { if chunk_tx.send(Ok(chunk)).await.is_err() { - let _ = text_tx.send(None).await; - let _ = monitor.await; return; } } - - let _ = text_tx.send(None).await; - finish_stream_monitor(monitor, &chunk_tx, &blocked).await; } -async fn finish_stream_monitor( +async fn send_stream_monitor_error( monitor: JoinHandle>, chunk_tx: &mpsc::Sender>, blocked: &Arc>>, -) { +) -> bool { match monitor.await { Ok(Ok(())) => {} Ok(Err(err)) => { let _ = chunk_tx.send(Err(err)).await; - return; + return true; } Err(err) => { let _ = chunk_tx @@ -938,13 +1021,16 @@ async fn finish_stream_monitor( "nemo_guardrails stream monitor task failed: {err}" )))) .await; - return; + return true; } } if let Some(message) = blocked_message(blocked) { let _ = chunk_tx.send(Err(streaming_output_blocked(message))).await; + return true; } + + false } fn blocked_message(blocked: &Arc>>) -> Option { diff --git a/crates/core/tests/unit/plugins/nemo_guardrails/local_python_tests.rs b/crates/core/tests/unit/plugins/nemo_guardrails/local_python_tests.rs index 859c297d..7ef5dbee 100644 --- a/crates/core/tests/unit/plugins/nemo_guardrails/local_python_tests.rs +++ b/crates/core/tests/unit/plugins/nemo_guardrails/local_python_tests.rs @@ -11,7 +11,6 @@ use std::path::{Path, PathBuf}; use std::process::Command; #[cfg(unix)] use std::sync::atomic::{AtomicUsize, Ordering}; -#[cfg(unix)] use std::sync::{Arc, Mutex}; use serde_json::json; @@ -295,6 +294,65 @@ async fn stream_monitor_records_blocked_message() { assert_eq!(blocked.lock().unwrap().as_deref(), Some("blocked stream")); } +#[tokio::test(flavor = "current_thread")] +async fn guarded_provider_stream_reports_block_before_buffered_chunks() { + let provider_stream: LlmJsonStream = Box::pin(tokio_stream::iter(vec![Ok(json!({ + "choices": [{"delta": {"content": "blocked"}}], + }))])); + let (text_tx, mut text_rx) = mpsc::channel::>(8); + let (chunk_tx, mut chunk_rx) = mpsc::channel(8); + let blocked = Arc::new(Mutex::new(None)); + let monitor_blocked = Arc::clone(&blocked); + let monitor = tokio::spawn(async move { + while let Some(item) = text_rx.recv().await { + match item { + Some(text) if text.contains("blocked") => { + *monitor_blocked.lock().unwrap() = Some("blocked stream".to_string()); + } + Some(_) => {} + None => break, + } + } + Ok(()) + }); + + forward_guarded_provider_stream( + provider_stream, + LocalGuardrailsCodec::OpenAIChat, + text_tx, + chunk_tx, + monitor, + blocked, + ) + .await; + + let error = chunk_rx.recv().await.unwrap().unwrap_err(); + assert!( + error.to_string().contains("blocked stream"), + "unexpected error: {error}" + ); + assert!(chunk_rx.recv().await.is_none()); +} + +#[test] +fn parse_check_result_rejects_unknown_status() { + assert!(matches!( + parse_check_result(json!({"status": "passed"})).unwrap(), + LocalCheckOutcome::Passed + )); + + let error = match parse_check_result(json!({"status": "surprising"})) { + Ok(_) => panic!("expected unknown status to fail"), + Err(error) => error, + }; + assert!( + error + .to_string() + .contains("unexpected worker check status: surprising"), + "unexpected error: {error}" + ); +} + #[test] fn modified_tool_payload_rejects_malformed_content() { let error = modified_tool_payload("not-json", "arguments").unwrap_err(); diff --git a/docs/nemo-guardrails-plugin/configuration.mdx b/docs/nemo-guardrails-plugin/configuration.mdx index c4eb3715..c7ad06fb 100644 --- a/docs/nemo-guardrails-plugin/configuration.mdx +++ b/docs/nemo-guardrails-plugin/configuration.mdx @@ -51,7 +51,7 @@ At least one managed Guardrails surface must be enabled. | Built-in component kind and config validation | Supported | Supported | | Managed LLM `input` | Supported | Supported | | Managed LLM `output` | Supported | Supported | -| Managed streaming LLM execution | Supported over the remote HTTP(S) contract | Supported for managed input checks and Guardrails-native output streaming when `rails.output.streaming.enabled = true`; with `stream_first = true`, output rails can stop the stream after some chunks have already been delivered; `stream_first = false` is not supported yet | +| Managed streaming LLM execution | Supported over the remote HTTP(S) contract | Supported for managed input checks and Guardrails-native output streaming when `rails.output.streaming.enabled = true`; provider chunks are released only after the local output rail monitor clears the stream; `stream_first = false` is not supported yet | | Managed `tool_input` | Not supported against the stock Guardrails remote contract | Supported | | Managed `tool_output` | Supported | Supported | | `request_defaults` pass-through | Supported | Not supported | @@ -325,22 +325,21 @@ The current local mode supports streaming LLM input checks before the stream callback runs. When output rails are configured, the current local mode uses Guardrails-native -streaming output rails instead of buffering the full provider stream. That -requires `rails.output.streaming.enabled = true` in the Guardrails config. +streaming output rails and buffers provider chunks until the local output rail +monitor clears the stream. That requires `rails.output.streaming.enabled = true` +in the Guardrails config. Guardrails calls the main streaming-output switch `rails.output.streaming.stream_first`. -When `stream_first = true`, the current local mode uses pass-through-first -streaming semantics: +When `stream_first = true`, the current local mode keeps provider-shaped chunks +buffered while Guardrails evaluates the streamed text: -- provider chunks can flow to the caller immediately -- Guardrails evaluates the streamed text in parallel -- if Guardrails later blocks the stream, the call fails at that point even - though some chunks may already have been delivered +- provider chunks are not released to the caller until the monitor finishes +- if Guardrails blocks the stream, the call fails without delivering those chunks The current local mode does not support `rails.output.streaming.stream_first = false` -yet. That mode would be Guardrails-first streaming semantics: +yet. That mode would require Guardrails-first chunk reconstruction: - Guardrails would need to evaluate streamed text before chunks are released to the caller From 33df99fa1c22023db8ac8a618f67080847c54574 Mon Sep 17 00:00:00 2001 From: Will Killian Date: Fri, 5 Jun 2026 20:56:37 -0400 Subject: [PATCH 20/20] fix: track local guardrails worker tasks Signed-off-by: Will Killian --- .../plugins/nemo_guardrails/local_worker.py | 49 +++++++++++++------ docs/nemo-guardrails-plugin/configuration.mdx | 2 +- 2 files changed, 35 insertions(+), 16 deletions(-) diff --git a/crates/core/src/plugins/nemo_guardrails/local_worker.py b/crates/core/src/plugins/nemo_guardrails/local_worker.py index bd519655..937fef88 100644 --- a/crates/core/src/plugins/nemo_guardrails/local_worker.py +++ b/crates/core/src/plugins/nemo_guardrails/local_worker.py @@ -205,7 +205,13 @@ async def monitor_stream(self, request_id, messages, queue, streams): streams = {} -async def handle_message(message): +def track_task(pending_tasks, task): + pending_tasks.add(task) + task.add_done_callback(pending_tasks.discard) + return task + + +async def handle_message(message, pending_tasks): global worker request_id = str(message.get("id", "")) @@ -235,7 +241,10 @@ async def handle_message(message): elif command == "stream_start": queue = asyncio.Queue(maxsize=STREAM_QUEUE_MAXSIZE) streams[request_id] = queue - asyncio.create_task(worker.monitor_stream(request_id, message.get("messages") or [], queue, streams)) + track_task( + pending_tasks, + asyncio.create_task(worker.monitor_stream(request_id, message.get("messages") or [], queue, streams)), + ) elif command == "stream_text": queue = streams.get(request_id) if queue is not None: @@ -254,19 +263,29 @@ async def handle_message(message): async def main(): - while True: - line = await asyncio.to_thread(sys.stdin.readline) - if not line: - return - try: - message = json.loads(line) - except Exception: - traceback.print_exc(file=sys.stderr) - continue - if str(message.get("command", "")).startswith("stream_"): - await handle_message(message) - else: - asyncio.create_task(handle_message(message)) + pending_tasks = set() + try: + while True: + line = await asyncio.to_thread(sys.stdin.readline) + if not line: + return + try: + message = json.loads(line) + except Exception: + traceback.print_exc(file=sys.stderr) + continue + if str(message.get("command", "")).startswith("stream_"): + await handle_message(message, pending_tasks) + else: + track_task( + pending_tasks, + asyncio.create_task(handle_message(message, pending_tasks)), + ) + finally: + for task in tuple(pending_tasks): + task.cancel() + if pending_tasks: + await asyncio.gather(*pending_tasks, return_exceptions=True) asyncio.run(main()) diff --git a/docs/nemo-guardrails-plugin/configuration.mdx b/docs/nemo-guardrails-plugin/configuration.mdx index c7ad06fb..dd604995 100644 --- a/docs/nemo-guardrails-plugin/configuration.mdx +++ b/docs/nemo-guardrails-plugin/configuration.mdx @@ -51,7 +51,7 @@ At least one managed Guardrails surface must be enabled. | Built-in component kind and config validation | Supported | Supported | | Managed LLM `input` | Supported | Supported | | Managed LLM `output` | Supported | Supported | -| Managed streaming LLM execution | Supported over the remote HTTP(S) contract | Supported for managed input checks and Guardrails-native output streaming when `rails.output.streaming.enabled = true`; provider chunks are released only after the local output rail monitor clears the stream; `stream_first = false` is not supported yet | +| Managed streaming LLM execution | Supported over the remote HTTP(S) contract | Supported; see [Streaming Boundary](#streaming-boundary) | | Managed `tool_input` | Not supported against the stock Guardrails remote contract | Supported | | Managed `tool_output` | Supported | Supported | | `request_defaults` pass-through | Supported | Not supported |