diff --git a/.env_local_example b/.env_local_example index 5126c95c9..626476ca1 100644 --- a/.env_local_example +++ b/.env_local_example @@ -13,11 +13,11 @@ OPENAI_CHAT_MODEL="gpt-4o" ############## # The below GLOBAL_MEMORY_LABELS will be applied to all prompts sent via attacks and can be altered whenever needed. -# Example recommended labels are shown below: `username`, `op_name`. Others that may be useful include: +# Example recommended labels are shown below: `operator`, `operation`. Others that may be useful include: # `language`, `harm_category`, `stage`, or `technique. For the above labels, please stick to the exact spelling, # spacing, and casing for better standardization throughout the database. ############## -GLOBAL_MEMORY_LABELS = {"username": "username"} +GLOBAL_MEMORY_LABELS = {"operator": "operator", "operation": "operation"} ############## # Set optional OPENAI_CHAT_ADDITIONAL_REQUEST_HEADERS to include additional HTTP headers in a dictionary format for API requests, e.g., {'key1': 'value1'}. diff --git a/build_scripts/env_local_integration_test b/build_scripts/env_local_integration_test index d9873e435..b61cfe7d2 100644 --- a/build_scripts/env_local_integration_test +++ b/build_scripts/env_local_integration_test @@ -22,7 +22,7 @@ DEFAULT_OPENAI_FRONTEND_ENDPOINT=${AZURE_OPENAI_INTEGRATION_TEST_ENDPOINT} DEFAULT_OPENAI_FRONTEND_KEY=${AZURE_OPENAI_INTEGRATION_TEST_KEY} DEFAULT_OPENAI_FRONTEND_MODEL=${AZURE_OPENAI_INTEGRATION_TEST_MODEL} -GLOBAL_MEMORY_LABELS={"username": "integration-test", "op_name": "integration-test"} +GLOBAL_MEMORY_LABELS={"operator": "integration-test", "operation": "integration-test"} ############## # Set optional OPENAI_CHAT_ADDITIONAL_REQUEST_HEADERS to include additional HTTP headers in a dictionary format for API requests, e.g., {'key1': 'value1'}. 
diff --git a/doc/code/memory/4_manually_working_with_memory.md b/doc/code/memory/4_manually_working_with_memory.md index 8beccb10f..f20af7a0e 100644 --- a/doc/code/memory/4_manually_working_with_memory.md +++ b/doc/code/memory/4_manually_working_with_memory.md @@ -32,7 +32,7 @@ This is especially nice with scoring. There are countless ways to do this, but t ![scoring_2.png](../../../assets/scoring_3_pivot.png) ## Using AzureSQL Query Editor to Query and Export Data -If you are using an AzureSQL Database, you can use the Query Editor to run SQL queries to retrieve desired data. Memory labels (`labels`) may be an especially useful column to query on for finding data pertaining to a specific operation, user, harm_category, etc. Memory labels are a free-from dictionary for tagging prompts with whatever information you'd like (e.g. `op_name`, `username`, `harm_category`). (For more information on memory labels, see the [Advanced Memory Guide](../memory/5_advanced_memory.ipynb).) An example is shown below: +If you are using an AzureSQL Database, you can use the Query Editor to run SQL queries to retrieve desired data. Memory labels (`labels`) may be an especially useful column to query on for finding data pertaining to a specific operation, user, harm_category, etc. Memory labels are a free-form dictionary for tagging prompts with whatever information you'd like (e.g. `operation`, `operator`, `harm_category`). (For more information on memory labels, see the [Advanced Memory Guide](../memory/5_advanced_memory.ipynb).) An example is shown below: 1. Write a SQL query in the Query Editor. You can either write these manually or use the "Open Query" option to load one in. The image below shows a query that gathers prompt entries with their corresponding scores for a specific operation (using the `labels` column) with a "float_scale" `score_type`. 
diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index 5e2ce65b9..6f720b303 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -70,6 +70,8 @@ class ConfigurationLoader(YamlLoadable): env_files: List of environment file paths to load. None means "use defaults (.env, .env.local)", [] means "load nothing". silent: Whether to suppress initialization messages. + operator: Name for the current operator, e.g. a team or username. + operation: Name for the current operation. Example YAML configuration: memory_db_type: sqlite diff --git a/pyrit/setup/initializers/airt.py b/pyrit/setup/initializers/airt.py index 96740565d..c899797b9 100644 --- a/pyrit/setup/initializers/airt.py +++ b/pyrit/setup/initializers/airt.py @@ -8,11 +8,15 @@ AIRT configuration including converters, scorers, and targets using Azure OpenAI. """ +import json import os from collections.abc import Callable +import yaml + from pyrit.auth import get_azure_openai_auth, get_azure_token_provider from pyrit.common.apply_defaults import set_default_value, set_global_variable +from pyrit.common.path import DEFAULT_CONFIG_PATH from pyrit.executor.attack import ( AttackAdversarialConfig, AttackScoringConfig, @@ -43,12 +47,15 @@ class AIRTInitializer(PyRITInitializer): - Converter targets with Azure OpenAI configuration - Composite harm and objective scorers - Adversarial target configurations for attacks + - Use of an Azure SQL database Required Environment Variables: - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT: Azure OpenAI endpoint for converters and targets - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL: Azure OpenAI model name for converters and targets - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2: Azure OpenAI endpoint for scoring - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2: Azure OpenAI model name for scoring + - AZURE_SQL_DB_CONNECTION_STRING: Azure SQL database connection string + - AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL: Azure SQL 
database location Optional Environment Variables: - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY: API key for converter endpoint. If not set, Entra ID auth is used. @@ -90,6 +97,8 @@ def required_env_vars(self) -> list[str]: "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2", "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2", "AZURE_CONTENT_SAFETY_API_ENDPOINT", + "AZURE_SQL_DB_CONNECTION_STRING", + "AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL", ] async def initialize_async(self) -> None: @@ -102,6 +111,9 @@ async def initialize_async(self) -> None: 3. Adversarial target configurations 4. Default values for all attack types """ + # Require operator and operation in .pyrit_conf and copy them into GLOBAL_MEMORY_LABELS when absent there. + self._validate_operation_fields() + # Get environment variables (validated by validate() method) converter_endpoint = os.getenv("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT") converter_model_name = os.getenv("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL") @@ -255,3 +267,35 @@ def _setup_adversarial_targets(self, *, endpoint: str, api_key: str, model_name: parameter_name="attack_adversarial_config", value=adversarial_config, ) + + def _validate_operation_fields(self) -> None: + """ + Check that mandatory global memory labels (operation, operator) + are populated. + + Raises: + ValueError: If mandatory global memory labels are missing. + """ + with open(DEFAULT_CONFIG_PATH) as f: + data = yaml.load(f, Loader=yaml.SafeLoader) + + if "operator" not in data: + raise ValueError( + "Error: `operator` was not set in .pyrit_conf. This is a required value for the AIRTInitializer." + ) + + if "operation" not in data: + raise ValueError( + "Error: `operation` was not set in .pyrit_conf. This is a required value for the AIRTInitializer." 
+ ) + + raw_labels = os.environ.get("GLOBAL_MEMORY_LABELS") + labels = dict(json.loads(raw_labels)) if raw_labels else {} + + if "operator" not in labels: + labels["operator"] = data["operator"] + + if "operation" not in labels: + labels["operation"] = data["operation"] + + os.environ["GLOBAL_MEMORY_LABELS"] = json.dumps(labels) diff --git a/tests/unit/setup/test_airt_initializer.py b/tests/unit/setup/test_airt_initializer.py index 2a8606cde..61f74cbe5 100644 --- a/tests/unit/setup/test_airt_initializer.py +++ b/tests/unit/setup/test_airt_initializer.py @@ -6,11 +6,21 @@ from unittest.mock import patch import pytest +import yaml from pyrit.common.apply_defaults import reset_default_values from pyrit.setup.initializers import AIRTInitializer +@pytest.fixture +def patch_pyrit_conf(tmp_path): + """Create a temporary .pyrit_conf file and patch DEFAULT_CONFIG_PATH to point to it.""" + conf_file = tmp_path / ".pyrit_conf" + conf_file.write_text(yaml.dump({"operator": "test_user", "operation": "test_op"})) + with patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", conf_file): + yield + + class TestAIRTInitializer: """Tests for AIRTInitializer class - basic functionality.""" @@ -41,6 +51,11 @@ def setup_method(self) -> None: os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2"] = "https://test-scorer.openai.azure.com" os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2"] = "gpt-4" os.environ["AZURE_CONTENT_SAFETY_API_ENDPOINT"] = "https://test-safety.cognitiveservices.azure.com" + os.environ["AZURE_SQL_DB_CONNECTION_STRING"] = "Server=test.database.windows.net;Database=testdb" + os.environ["AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL"] = "https://teststorage.blob.core.windows.net/data" + os.environ["GLOBAL_MEMORY_LABELS"] = ( + '{"operation": "test_op", "operator": "test_user", "email": "test@test.com"}' + ) # Clean up globals for attr in [ "default_converter_target", @@ -61,6 +76,9 @@ def teardown_method(self) -> None: "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2", 
"AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2", "AZURE_CONTENT_SAFETY_API_ENDPOINT", + "AZURE_SQL_DB_CONNECTION_STRING", + "AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL", + "GLOBAL_MEMORY_LABELS", ]: if var in os.environ: del os.environ[var] @@ -75,7 +93,7 @@ def teardown_method(self) -> None: delattr(sys.modules["__main__"], attr) @pytest.mark.asyncio - async def test_initialize_runs_without_error(self): + async def test_initialize_runs_without_error(self, patch_pyrit_conf): """Test that initialize runs without errors when no API keys are set (Entra auth fallback).""" init = AIRTInitializer() with ( @@ -85,7 +103,7 @@ async def test_initialize_runs_without_error(self): await init.initialize_async() @pytest.mark.asyncio - async def test_initialize_uses_api_keys_when_set(self): + async def test_initialize_uses_api_keys_when_set(self, patch_pyrit_conf): """Test that initialize uses API keys from env vars when they are set.""" os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY"] = "converter-key" os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY2"] = "scorer-key" @@ -110,7 +128,7 @@ async def test_initialize_uses_api_keys_when_set(self): del os.environ[var] @pytest.mark.asyncio - async def test_get_info_after_initialize_has_populated_data(self): + async def test_get_info_after_initialize_has_populated_data(self, patch_pyrit_conf): """Test that get_info_async() returns populated data after initialization.""" init = AIRTInitializer() with ( @@ -174,6 +192,38 @@ def test_validate_missing_multiple_env_vars_raises_error(self): assert "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT" in error_message assert "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL" in error_message + def test_validate_missing_operator_raises_error(self, tmp_path): + """Test that _validate_operation_fields raises error when operator is missing from .pyrit_conf.""" + conf_file = tmp_path / ".pyrit_conf" + conf_file.write_text(yaml.dump({"operation": "test_op"})) + init = AIRTInitializer() + with ( + 
patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", conf_file), + pytest.raises(ValueError, match="operator"), + ): + init._validate_operation_fields() + + def test_validate_missing_operation_raises_error(self, tmp_path): + """Test that _validate_operation_fields raises error when operation is missing from .pyrit_conf.""" + conf_file = tmp_path / ".pyrit_conf" + conf_file.write_text(yaml.dump({"operator": "test_user"})) + init = AIRTInitializer() + with ( + patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", conf_file), + pytest.raises(ValueError, match="operation"), + ): + init._validate_operation_fields() + + def test_validate_db_connection_raises_error(self): + """Test that validate raises error when AZURE_SQL_DB_CONNECTION_STRING is missing.""" + del os.environ["AZURE_SQL_DB_CONNECTION_STRING"] + init = AIRTInitializer() + with pytest.raises(ValueError) as exc_info: + init.validate() + + error_message = str(exc_info.value) + assert "AZURE_SQL_DB_CONNECTION_STRING" in error_message + class TestAIRTInitializerGetInfo: """Tests for AIRTInitializer.get_info method - basic functionality."""