Skip to content
Merged
4 changes: 2 additions & 2 deletions .env_local_example
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ OPENAI_CHAT_MODEL="gpt-4o"

##############
# The below GLOBAL_MEMORY_LABELS will be applied to all prompts sent via attacks and can be altered whenever needed.
# Example recommended labels are shown below: `username`, `op_name`. Others that may be useful include:
# Example recommended labels are shown below: `operator`, `operation`. Others that may be useful include:
# `language`, `harm_category`, `stage`, or `technique`. For the above labels, please stick to the exact spelling,
# spacing, and casing for better standardization throughout the database.
##############
GLOBAL_MEMORY_LABELS = {"username": "username"}
GLOBAL_MEMORY_LABELS = {"operator": "operator", "operation": "operation"}

##############
# Set optional OPENAI_CHAT_ADDITIONAL_REQUEST_HEADERS to include additional HTTP headers in a dictionary format for API requests, e.g., {'key1': 'value1'}.
Expand Down
2 changes: 1 addition & 1 deletion build_scripts/env_local_integration_test
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ DEFAULT_OPENAI_FRONTEND_ENDPOINT=${AZURE_OPENAI_INTEGRATION_TEST_ENDPOINT}
DEFAULT_OPENAI_FRONTEND_KEY=${AZURE_OPENAI_INTEGRATION_TEST_KEY}
DEFAULT_OPENAI_FRONTEND_MODEL=${AZURE_OPENAI_INTEGRATION_TEST_MODEL}

GLOBAL_MEMORY_LABELS={"username": "integration-test", "op_name": "integration-test"}
GLOBAL_MEMORY_LABELS={"operator": "integration-test", "operation": "integration-test"}

##############
# Set optional OPENAI_CHAT_ADDITIONAL_REQUEST_HEADERS to include additional HTTP headers in a dictionary format for API requests, e.g., {'key1': 'value1'}.
Expand Down
2 changes: 1 addition & 1 deletion doc/code/memory/4_manually_working_with_memory.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ This is especially nice with scoring. There are countless ways to do this, but t
![scoring_2.png](../../../assets/scoring_3_pivot.png)

## Using AzureSQL Query Editor to Query and Export Data
If you are using an AzureSQL Database, you can use the Query Editor to run SQL queries to retrieve desired data. Memory labels (`labels`) may be an especially useful column to query on for finding data pertaining to a specific operation, user, harm_category, etc. Memory labels are a free-from dictionary for tagging prompts with whatever information you'd like (e.g. `op_name`, `username`, `harm_category`). (For more information on memory labels, see the [Advanced Memory Guide](../memory/5_advanced_memory.ipynb).) An example is shown below:
If you are using an AzureSQL Database, you can use the Query Editor to run SQL queries to retrieve desired data. Memory labels (`labels`) may be an especially useful column to query on for finding data pertaining to a specific operation, user, harm_category, etc. Memory labels are a free-form dictionary for tagging prompts with whatever information you'd like (e.g. `operation`, `operator`, `harm_category`). (For more information on memory labels, see the [Advanced Memory Guide](../memory/5_advanced_memory.ipynb).) An example is shown below:

1. Write a SQL query in the Query Editor. You can either write these manually or use the "Open Query" option to load one in. The image below shows a query that gathers prompt entries with their corresponding scores for a specific operation (using the `labels` column) with a "float_scale" `score_type`.

Expand Down
2 changes: 2 additions & 0 deletions pyrit/setup/configuration_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class ConfigurationLoader(YamlLoadable):
env_files: List of environment file paths to load.
None means "use defaults (.env, .env.local)", [] means "load nothing".
silent: Whether to suppress initialization messages.
operator: Name for the current operator, e.g. a team or username.
operation: Name for the current operation.
Example YAML configuration:
memory_db_type: sqlite
Expand Down
44 changes: 44 additions & 0 deletions pyrit/setup/initializers/airt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@
AIRT configuration including converters, scorers, and targets using Azure OpenAI.
"""

import json
import os
from collections.abc import Callable

import yaml

from pyrit.auth import get_azure_openai_auth, get_azure_token_provider
from pyrit.common.apply_defaults import set_default_value, set_global_variable
from pyrit.common.path import DEFAULT_CONFIG_PATH
from pyrit.executor.attack import (
AttackAdversarialConfig,
AttackScoringConfig,
Expand Down Expand Up @@ -43,12 +47,15 @@ class AIRTInitializer(PyRITInitializer):
- Converter targets with Azure OpenAI configuration
- Composite harm and objective scorers
- Adversarial target configurations for attacks
- Use of an Azure SQL database

Required Environment Variables:
- AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT: Azure OpenAI endpoint for converters and targets
- AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL: Azure OpenAI model name for converters and targets
- AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2: Azure OpenAI endpoint for scoring
- AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2: Azure OpenAI model name for scoring
- AZURE_SQL_DB_CONNECTION_STRING: Azure SQL database connection string
- AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL: Azure SQL database location

Optional Environment Variables:
- AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY: API key for converter endpoint. If not set, Entra ID auth is used.
Expand Down Expand Up @@ -90,6 +97,8 @@ def required_env_vars(self) -> list[str]:
"AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2",
"AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2",
"AZURE_CONTENT_SAFETY_API_ENDPOINT",
"AZURE_SQL_DB_CONNECTION_STRING",
"AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL",
]

async def initialize_async(self) -> None:
Expand All @@ -102,6 +111,9 @@ async def initialize_async(self) -> None:
3. Adversarial target configurations
4. Default values for all attack types
"""
# Ensure operator and operation are set in .pyrit_conf and merged into GLOBAL_MEMORY_LABELS.
self._validate_operation_fields()

# Get environment variables (validated by validate() method)
converter_endpoint = os.getenv("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT")
converter_model_name = os.getenv("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL")
Expand Down Expand Up @@ -255,3 +267,35 @@ def _setup_adversarial_targets(self, *, endpoint: str, api_key: str, model_name:
parameter_name="attack_adversarial_config",
value=adversarial_config,
)

def _validate_operation_fields(self) -> None:
    """
    Check that the mandatory global memory labels (operator, operation) are populated.

    Reads ``operator`` and ``operation`` from the config file at
    ``DEFAULT_CONFIG_PATH`` (.pyrit_conf) and copies them into the
    ``GLOBAL_MEMORY_LABELS`` environment variable (a JSON object), unless
    that variable already defines those keys.

    Raises:
        ValueError: If ``operator`` or ``operation`` is missing or empty in .pyrit_conf.
    """
    with open(DEFAULT_CONFIG_PATH, encoding="utf-8") as f:
        loaded = yaml.safe_load(f)

    # An empty file parses to None and a scalar/list document is not a mapping.
    # Treat both as "no fields set" so the caller gets the ValueError below
    # instead of a TypeError from the lookup.
    data = loaded if isinstance(loaded, dict) else {}

    required_fields = ("operator", "operation")
    for field in required_fields:
        # A key with no value (e.g. `operator:`) parses to None and is not populated.
        if data.get(field) in (None, ""):
            raise ValueError(
                f"Error: `{field}` was not set in .pyrit_conf. "
                "This is a required value for the AIRTInitializer."
            )

    raw_labels = os.environ.get("GLOBAL_MEMORY_LABELS")
    labels = dict(json.loads(raw_labels)) if raw_labels else {}

    # Labels already present in the environment take precedence over .pyrit_conf.
    for field in required_fields:
        labels.setdefault(field, data[field])

    os.environ["GLOBAL_MEMORY_LABELS"] = json.dumps(labels)
56 changes: 53 additions & 3 deletions tests/unit/setup/test_airt_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,21 @@
from unittest.mock import patch

import pytest
import yaml

from pyrit.common.apply_defaults import reset_default_values
from pyrit.setup.initializers import AIRTInitializer


@pytest.fixture
def patch_pyrit_conf(tmp_path):
    """Point DEFAULT_CONFIG_PATH at a throwaway .pyrit_conf that defines operator and operation."""
    conf_contents = yaml.dump({"operator": "test_user", "operation": "test_op"})
    temp_conf = tmp_path / ".pyrit_conf"
    temp_conf.write_text(conf_contents)

    # The patch stays active for the duration of the test that requests this fixture.
    config_patch = patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", temp_conf)
    with config_patch:
        yield


class TestAIRTInitializer:
"""Tests for AIRTInitializer class - basic functionality."""

Expand Down Expand Up @@ -41,6 +51,11 @@ def setup_method(self) -> None:
os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2"] = "https://test-scorer.openai.azure.com"
os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2"] = "gpt-4"
os.environ["AZURE_CONTENT_SAFETY_API_ENDPOINT"] = "https://test-safety.cognitiveservices.azure.com"
os.environ["AZURE_SQL_DB_CONNECTION_STRING"] = "Server=test.database.windows.net;Database=testdb"
os.environ["AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL"] = "https://teststorage.blob.core.windows.net/data"
os.environ["GLOBAL_MEMORY_LABELS"] = (
'{"operation": "test_op", "operator": "test_user", "email": "test@test.com"}'
)
# Clean up globals
for attr in [
"default_converter_target",
Expand All @@ -61,6 +76,9 @@ def teardown_method(self) -> None:
"AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2",
"AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2",
"AZURE_CONTENT_SAFETY_API_ENDPOINT",
"AZURE_SQL_DB_CONNECTION_STRING",
"AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL",
"GLOBAL_MEMORY_LABELS",
]:
if var in os.environ:
del os.environ[var]
Expand All @@ -75,7 +93,7 @@ def teardown_method(self) -> None:
delattr(sys.modules["__main__"], attr)

@pytest.mark.asyncio
async def test_initialize_runs_without_error(self):
async def test_initialize_runs_without_error(self, patch_pyrit_conf):
"""Test that initialize runs without errors when no API keys are set (Entra auth fallback)."""
init = AIRTInitializer()
with (
Expand All @@ -85,7 +103,7 @@ async def test_initialize_runs_without_error(self):
await init.initialize_async()

@pytest.mark.asyncio
async def test_initialize_uses_api_keys_when_set(self):
async def test_initialize_uses_api_keys_when_set(self, patch_pyrit_conf):
"""Test that initialize uses API keys from env vars when they are set."""
os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY"] = "converter-key"
os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY2"] = "scorer-key"
Expand All @@ -110,7 +128,7 @@ async def test_initialize_uses_api_keys_when_set(self):
del os.environ[var]

@pytest.mark.asyncio
async def test_get_info_after_initialize_has_populated_data(self):
async def test_get_info_after_initialize_has_populated_data(self, patch_pyrit_conf):
"""Test that get_info_async() returns populated data after initialization."""
init = AIRTInitializer()
with (
Expand Down Expand Up @@ -174,6 +192,38 @@ def test_validate_missing_multiple_env_vars_raises_error(self):
assert "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT" in error_message
assert "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL" in error_message

def test_validate_missing_operator_raises_error(self, tmp_path):
    """A .pyrit_conf without `operator` makes _validate_operation_fields raise ValueError."""
    incomplete_conf = tmp_path / ".pyrit_conf"
    incomplete_conf.write_text(yaml.dump({"operation": "test_op"}))
    initializer = AIRTInitializer()

    with patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", incomplete_conf):
        with pytest.raises(ValueError, match="operator"):
            initializer._validate_operation_fields()

def test_validate_missing_operation_raises_error(self, tmp_path):
    """A .pyrit_conf without `operation` makes _validate_operation_fields raise ValueError."""
    incomplete_conf = tmp_path / ".pyrit_conf"
    incomplete_conf.write_text(yaml.dump({"operator": "test_user"}))
    initializer = AIRTInitializer()

    with patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", incomplete_conf):
        with pytest.raises(ValueError, match="operation"):
            initializer._validate_operation_fields()

def test_validate_db_connection_raises_error(self):
    """validate() names AZURE_SQL_DB_CONNECTION_STRING in its error when that variable is unset."""
    missing_var = "AZURE_SQL_DB_CONNECTION_STRING"
    del os.environ[missing_var]
    initializer = AIRTInitializer()

    with pytest.raises(ValueError) as raised:
        initializer.validate()

    assert missing_var in str(raised.value)


class TestAIRTInitializerGetInfo:
"""Tests for AIRTInitializer.get_info method - basic functionality."""
Expand Down
Loading