diff --git a/privacy_guard/analysis/llm_judge/llm_judge_analysis_input.py b/privacy_guard/analysis/llm_judge/llm_judge_analysis_input.py new file mode 100644 index 0000000..f71ffbe --- /dev/null +++ b/privacy_guard/analysis/llm_judge/llm_judge_analysis_input.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-strict + +import logging + +import pandas as pd +from privacy_guard.analysis.base_analysis_input import BaseAnalysisInput +from privacy_guard.analysis.llm_judge.llm_judge_config import LLMJudgeConfig + + +logger: logging.Logger = logging.getLogger(__name__) + + +class LLMJudgeAnalysisInput(BaseAnalysisInput): + """Input for LLM-as-judge evaluation. + + Takes a single dataframe containing at minimum a ``prompt`` column and a + ``generation`` column. A ``reference_text`` column is **optional** — when + absent the judge evaluates solely based on the configured scoring criteria. + + Args: + generation_df: DataFrame with prompt/generation (and optionally + reference_text) columns. + config: ``LLMJudgeConfig`` specifying the provider, model, eval + prompt template, and scoring criteria. + prompt_key: Column name for the input prompt. + generation_key: Column name for the model-generated text. + reference_key: Column name for the ground-truth reference text. + Set to ``None`` when no reference text is available. + """ + + REQUIRED_COLUMNS: list[str] = ["prompt", "generation"] + + def __init__( + self, + generation_df: pd.DataFrame, + config: LLMJudgeConfig, + prompt_key: str = "prompt", + generation_key: str = "generation", + reference_key: str | None = "reference_text", + ) -> None: + columns = generation_df.columns.tolist() + assert prompt_key in columns, ( + f"Prompt key '{prompt_key}' not found in dataframe columns {columns}" + ) + assert generation_key in columns, ( + f"Generation key '{generation_key}' not found in " + f"dataframe columns {columns}" + ) + if reference_key is not None and reference_key not in columns: + logger.warning( + f"Reference key '{reference_key}' not found in dataframe " + f"columns {columns}. Proceeding without reference text." + ) + reference_key = None + + self.prompt_key = prompt_key + self.generation_key = generation_key + self.reference_key = reference_key + self.config = config + + super().__init__(df_train_user=generation_df, df_test_user=pd.DataFrame()) + + @property + def generation_df(self) -> pd.DataFrame: + """Property accessor for the generation dataframe.""" + return self._df_train_user + + @property + def has_reference(self) -> bool: + """Whether reference text is available for evaluation.""" + return self.reference_key is not None diff --git a/privacy_guard/analysis/llm_judge/llm_judge_analysis_node.py b/privacy_guard/analysis/llm_judge/llm_judge_analysis_node.py new file mode 100644 index 0000000..02185f5 --- /dev/null +++ b/privacy_guard/analysis/llm_judge/llm_judge_analysis_node.py @@ -0,0 +1,431 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-strict + +import json +import logging +import os +import time +from dataclasses import dataclass, field +from typing import Any, Callable, cast + +import pandas as pd +import requests +from privacy_guard.analysis.base_analysis_node import BaseAnalysisNode +from privacy_guard.analysis.base_analysis_output import BaseAnalysisOutput +from privacy_guard.analysis.llm_judge.llm_judge_analysis_input import ( + LLMJudgeAnalysisInput, +) +from privacy_guard.analysis.llm_judge.llm_judge_config import ( + LLMJudgeConfig, + LLMProvider, +) +from tqdm import tqdm + + +logger: logging.Logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Default prompt template +# --------------------------------------------------------------------------- + +_DEFAULT_EVAL_PROMPT_WITH_REFERENCE = """You are an impartial judge evaluating the quality of an AI-generated response. + +## Input Prompt +{prompt} + +## AI-Generated Response +{generation} + +## Reference Text +{reference_text} + +## Evaluation Criteria +{criteria} + +Evaluate the AI-generated response against the reference text on each criterion. +For each criterion, assign an integer score from 1 (worst) to 5 (best). + +Return your evaluation as JSON with this exact structure: +```json +{{ + "scores": {{"criterion_name": score, ...}}, + "overall_score": , + "reasoning": "" +}} +```""" + +_DEFAULT_EVAL_PROMPT_WITHOUT_REFERENCE = """You are an impartial judge evaluating the quality of an AI-generated response. + +## Input Prompt +{prompt} + +## AI-Generated Response +{generation} + +## Evaluation Criteria +{criteria} + +Evaluate the AI-generated response on each criterion. +For each criterion, assign an integer score from 1 (worst) to 5 (best). + +Return your evaluation as JSON with this exact structure: +```json +{{ + "scores": {{"criterion_name": score, ...}}, + "overall_score": , + "reasoning": "" +}} +```""" + +_DEFAULT_CRITERIA = ["accuracy", "relevance", "fluency", "completeness"] + + +# --------------------------------------------------------------------------- +# Output dataclass +# --------------------------------------------------------------------------- + + +@dataclass +class LLMJudgeAnalysisOutput(BaseAnalysisOutput): + """Encapsulates the outputs of LLMJudgeAnalysisNode.""" + + num_samples: int + avg_overall_score: float + per_sample_overall_scores: list[float] + per_criteria_avg_scores: dict[str, float] + per_sample_criteria_scores: list[dict[str, float]] + per_sample_reasoning: list[str] + num_failed: int + provider: str + model: str + augmented_output_dataset: pd.DataFrame = field(repr=False) + + +# --------------------------------------------------------------------------- +# API callers — one per provider +# --------------------------------------------------------------------------- + + +def _get_api_key(config: LLMJudgeConfig) -> str: + """Retrieve the API key from the environment. + + The API key is expected to be stored in the following environment variables: + ANTHROPIC_API_KEY, + OPENAI_API_KEY, or + GEMINI_API_KEY + depending on the provider. + """ + key = os.environ.get(config.api_key_env_var, None) + if not key: + raise ValueError( + f"API key not found. Set the '{config.api_key_env_var}' " + f"environment variable." + ) + return key + + +def _call_anthropic( + prompt: str, + config: LLMJudgeConfig, + api_key: str, +) -> dict[str, Any]: + """Call the Anthropic Messages API and return parsed JSON.""" + response = requests.post( + "https://api.anthropic.com/v1/messages", + headers={ + "x-api-key": api_key, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + }, + json={ + "model": config.model, + "max_tokens": config.max_tokens, + "temperature": config.temperature, + "messages": [{"role": "user", "content": prompt}], + }, + timeout=120, + ) + response.raise_for_status() + data = response.json() + text: str = data["content"][0]["text"] + return _parse_json_response(text) + + +def _call_openai( + prompt: str, + config: LLMJudgeConfig, + api_key: str, +) -> dict[str, Any]: + """Call the OpenAI Chat Completions API and return parsed JSON.""" + response = requests.post( + "https://api.openai.com/v1/chat/completions", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + }, + json={ + "model": config.model, + "max_completion_tokens": config.max_tokens, + "temperature": config.temperature, + "messages": [ + { + "role": "system", + "content": "You are an evaluation judge. Always respond with valid JSON.", + }, + {"role": "user", "content": prompt}, + ], + }, + timeout=120, + ) + response.raise_for_status() + data = response.json() + text: str = data["choices"][0]["message"]["content"] + return _parse_json_response(text) + + +def _call_gemini( + prompt: str, + config: LLMJudgeConfig, + api_key: str, +) -> dict[str, Any]: + """Call the Gemini generateContent API and return parsed JSON.""" + url = ( + f"https://generativelanguage.googleapis.com/v1/models/" + f"{config.model}:generateContent" + ) + response = requests.post( + url, + headers={ + "Content-Type": "application/json", + "x-goog-api-key": api_key, + }, + json={ + "contents": [{"parts": [{"text": prompt}]}], + "generationConfig": { + "temperature": config.temperature, + "maxOutputTokens": config.max_tokens, + }, + }, + timeout=120, + ) + response.raise_for_status() + data = response.json() + text: str = data["candidates"][0]["content"]["parts"][0]["text"] + return _parse_json_response(text) + + +_PROVIDER_CALLERS: dict[ + LLMProvider, + Callable[[str, LLMJudgeConfig, str], dict[str, Any]], +] = { + LLMProvider.ANTHROPIC: _call_anthropic, + LLMProvider.OPENAI: _call_openai, + LLMProvider.GEMINI: _call_gemini, +} + + +def _parse_json_response(text: str) -> dict[str, Any]: + """Extract and parse JSON from a model response that may contain markdown fences.""" + cleaned = text.strip() + if "```json" in cleaned: + cleaned = cleaned.split("```json", 1)[1] + if "```" in cleaned: + cleaned = cleaned.split("```", 1)[0] + cleaned = cleaned.strip() + result: dict[str, Any] = json.loads(cleaned) + return result + + +def call_llm_judge( + prompt: str, + config: LLMJudgeConfig, + max_retries: int = 3, +) -> dict[str, Any]: + """Dispatch a judge evaluation call to the configured provider. + + Retries on transient failures with exponential back-off. + """ + api_key = _get_api_key(config) + caller = _PROVIDER_CALLERS[config.provider] + + last_error: Exception | None = None + for attempt in range(max_retries): + try: + return caller(prompt, config, api_key) + except ( + requests.RequestException, + json.JSONDecodeError, + KeyError, + IndexError, + ) as exc: + last_error = exc + wait = min(2**attempt, 8) + logger.warning( + f"Judge call attempt {attempt + 1}/{max_retries} failed: {exc}. " + f"Retrying in {wait}s." + ) + time.sleep(wait) + + raise RuntimeError( + f"Judge evaluation failed after {max_retries} attempts. " + f"Last error: {last_error}" + ) + + +# --------------------------------------------------------------------------- +# Analysis node +# --------------------------------------------------------------------------- + + +class LLMJudgeAnalysisNode(BaseAnalysisNode): + """Evaluates generation quality using an LLM-as-judge. + + For each row in the input dataframe the node: + 1. Builds an evaluation prompt from the configured template. + 2. Sends the prompt to the configured LLM provider. + 3. Parses the structured JSON verdict (per-criteria scores + reasoning). + 4. Aggregates per-sample scores into summary metrics. + """ + + def __init__(self, analysis_input: LLMJudgeAnalysisInput) -> None: + tqdm.pandas() + self._config: LLMJudgeConfig = analysis_input.config + super().__init__(analysis_input=analysis_input) + + def _build_eval_prompt(self, row: pd.Series) -> str: + """Build the evaluation prompt for a single sample.""" + analysis_input = cast(LLMJudgeAnalysisInput, self.analysis_input) + config = self._config + + prompt_text: str = str(row[analysis_input.prompt_key]) + generation_text: str = str(row[analysis_input.generation_key]) + reference_text: str = "" + if analysis_input.has_reference and analysis_input.reference_key is not None: + reference_text = str(row[analysis_input.reference_key]) + + criteria = config.scoring_criteria or _DEFAULT_CRITERIA + criteria_text = "\n".join(f"- {c}" for c in criteria) + + if config.eval_prompt: + return config.eval_prompt.format( + prompt=prompt_text, + generation=generation_text, + reference_text=reference_text, + criteria=criteria_text, + ) + + if analysis_input.has_reference: + return _DEFAULT_EVAL_PROMPT_WITH_REFERENCE.format( + prompt=prompt_text, + generation=generation_text, + reference_text=reference_text, + criteria=criteria_text, + ) + + return _DEFAULT_EVAL_PROMPT_WITHOUT_REFERENCE.format( + prompt=prompt_text, + generation=generation_text, + criteria=criteria_text, + ) + + def _evaluate_single(self, row: pd.Series) -> dict[str, Any]: + """Run judge evaluation on a single row and return parsed result.""" + eval_prompt = self._build_eval_prompt(row) + try: + return call_llm_judge(eval_prompt, self._config) + except RuntimeError as exc: + logger.error(f"Judge evaluation failed for row: {exc}") + criteria = self._config.scoring_criteria or _DEFAULT_CRITERIA + return { + "scores": dict.fromkeys(criteria, 0), + "overall_score": 0.0, + "reasoning": f"Evaluation failed: {exc}", + } + + def run_analysis(self) -> LLMJudgeAnalysisOutput: + """Execute judge evaluation across all samples.""" + analysis_input = cast(LLMJudgeAnalysisInput, self.analysis_input) + df = analysis_input.generation_df.copy() + + logger.info( + f"Starting LLM judge evaluation: {len(df)} samples, " + f"provider={self._config.provider.value}, " + f"model={self._config.model}" + ) + + results: list[dict[str, Any]] = df.progress_apply( + self._evaluate_single, axis=1 + ).tolist() + + # Extract per-sample metrics + per_sample_overall: list[float] = [] + per_sample_criteria: list[dict[str, float]] = [] + per_sample_reasoning: list[str] = [] + num_failed = 0 + + for result in results: + score = float(result.get("overall_score", 0.0)) + per_sample_overall.append(score) + per_sample_criteria.append(result.get("scores", {})) + per_sample_reasoning.append(result.get("reasoning", "")) + if score == 0.0: + num_failed += 1 + + # Aggregate per-criteria averages (excluding failed evaluations) + criteria = self._config.scoring_criteria or _DEFAULT_CRITERIA + criteria_totals: dict[str, float] = dict.fromkeys(criteria, 0.0) + criteria_counts: dict[str, int] = {c: 0 for c in criteria} # noqa: C420 + + for sample_scores in per_sample_criteria: + for c in criteria: + val = sample_scores.get(c, 0) + if val > 0: + criteria_totals[c] += float(val) + criteria_counts[c] += 1 + + per_criteria_avg: dict[str, float] = { + c: ( + criteria_totals[c] / criteria_counts[c] + if criteria_counts[c] > 0 + else 0.0 + ) + for c in criteria + } + + valid_scores = [s for s in per_sample_overall if s > 0] + avg_overall = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0 + + # Augment the output dataset with judge results + df["judge_overall_score"] = per_sample_overall + df["judge_reasoning"] = per_sample_reasoning + for c in criteria: + df[f"judge_{c}_score"] = [s.get(c, 0) for s in per_sample_criteria] + + logger.info( + f"Judge evaluation complete: avg_score={avg_overall:.2f}, " + f"failed={num_failed}/{len(df)}" + ) + + return LLMJudgeAnalysisOutput( + num_samples=len(df), + avg_overall_score=avg_overall, + per_sample_overall_scores=per_sample_overall, + per_criteria_avg_scores=per_criteria_avg, + per_sample_criteria_scores=per_sample_criteria, + per_sample_reasoning=per_sample_reasoning, + num_failed=num_failed, + provider=self._config.provider.value, + model=self._config.model, + augmented_output_dataset=df, + ) diff --git a/privacy_guard/analysis/llm_judge/llm_judge_config.py b/privacy_guard/analysis/llm_judge/llm_judge_config.py new file mode 100644 index 0000000..e7688d9 --- /dev/null +++ b/privacy_guard/analysis/llm_judge/llm_judge_config.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-strict + +from dataclasses import dataclass, field +from enum import Enum + + +class LLMProvider(Enum): + """Supported LLM providers for judge evaluation.""" + + ANTHROPIC = "anthropic" + OPENAI = "openai" + GEMINI = "gemini" + + +class AnthropicModel(Enum): + """Available Anthropic models.""" + + CLAUDE_4_OPUS = "claude-opus-4-20250514" + CLAUDE_4_SONNET = "claude-sonnet-4-20250514" + CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219" + CLAUDE_3_5_HAIKU = "claude-3-5-haiku-20241022" + + +class OpenAIModel(Enum): + """Available OpenAI models.""" + + GPT_4O = "gpt-4o" + GPT_4O_MINI = "gpt-4o-mini" + GPT_4_1 = "gpt-4.1" + GPT_4_1_MINI = "gpt-4.1-mini" + O3 = "o3" + O3_MINI = "o3-mini" + O4_MINI = "o4-mini" + + +class GeminiModel(Enum): + """Available Gemini models.""" + + GEMINI_2_5_PRO = "gemini-2.5-pro-preview-05-06" + GEMINI_2_5_FLASH = "gemini-2.5-flash-preview-04-17" + GEMINI_2_0_FLASH = "gemini-2.0-flash" + + +# Map provider to its default model +DEFAULT_MODELS: dict[LLMProvider, str] = { + LLMProvider.ANTHROPIC: AnthropicModel.CLAUDE_4_SONNET.value, + LLMProvider.OPENAI: OpenAIModel.GPT_4O.value, + LLMProvider.GEMINI: GeminiModel.GEMINI_2_5_PRO.value, +} + +# Lookup table: all valid model strings grouped by provider +VALID_MODELS: dict[LLMProvider, set[str]] = { + LLMProvider.ANTHROPIC: {m.value for m in AnthropicModel}, + LLMProvider.OPENAI: {m.value for m in OpenAIModel}, + LLMProvider.GEMINI: {m.value for m in GeminiModel}, +} + + +@dataclass +class LLMJudgeConfig: + """Configuration for an LLM-as-judge evaluation run. + + Attributes: + provider: Which LLM provider to use. + model: Model identifier string. Must belong to the chosen provider. + When left as empty string the provider's default model is used. + eval_prompt: The prompt template sent to the judge. Use placeholders + ``{prompt}``, ``{generation}``, and optionally ``{reference_text}`` + and ``{criteria}`` which will be filled at evaluation time. + Leave empty to use the built-in default prompt. + scoring_criteria: List of criteria the judge should evaluate + (e.g. ["accuracy", "fluency", "relevance"]). + temperature: Sampling temperature for the judge model. + max_tokens: Maximum tokens in the judge response. + api_key_env_var: Name of the environment variable holding the API key + for the chosen provider. + """ + + provider: LLMProvider = LLMProvider.ANTHROPIC + model: str = "" + eval_prompt: str = "" + scoring_criteria: list[str] = field(default_factory=list) + temperature: float = 0.0 + max_tokens: int = 1024 + api_key_env_var: str = "" + + def __post_init__(self) -> None: + # Resolve default model when none is specified + if not self.model: + self.model = DEFAULT_MODELS[self.provider] + + # Validate that the model belongs to the chosen provider + if self.model not in VALID_MODELS[self.provider]: + valid = ", ".join(sorted(VALID_MODELS[self.provider])) + raise ValueError( + f"Model '{self.model}' is not valid for provider " + f"'{self.provider.value}'. Valid models: {valid}" + ) + + # Resolve default env var name when none is specified + if not self.api_key_env_var: + self.api_key_env_var = { + LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY", + LLMProvider.OPENAI: "OPENAI_API_KEY", + LLMProvider.GEMINI: "GEMINI_API_KEY", + }[self.provider] diff --git a/privacy_guard/analysis/tests/test_llm_judge_analysis_input.py b/privacy_guard/analysis/tests/test_llm_judge_analysis_input.py new file mode 100644 index 0000000..d35bc56 --- /dev/null +++ b/privacy_guard/analysis/tests/test_llm_judge_analysis_input.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-strict + + +import unittest + +import pandas as pd +from privacy_guard.analysis.llm_judge.llm_judge_analysis_input import ( + LLMJudgeAnalysisInput, +) +from privacy_guard.analysis.llm_judge.llm_judge_config import LLMJudgeConfig + + +class TestLLMJudgeAnalysisInput(unittest.TestCase): + def setUp(self) -> None: + self.df = pd.DataFrame( + { + "prompt": ["What is AI?", "Explain ML"], + "generation": ["AI is...", "ML is..."], + "reference_text": [ + "Artificial intelligence is...", + "Machine learning is...", + ], + } + ) + self.config = LLMJudgeConfig() + super().setUp() + + def test_init_with_valid_data(self) -> None: + analysis_input = LLMJudgeAnalysisInput( + generation_df=self.df, config=self.config + ) + self.assertEqual(analysis_input.prompt_key, "prompt") + self.assertEqual(analysis_input.generation_key, "generation") + self.assertEqual(analysis_input.reference_key, "reference_text") + self.assertIs(analysis_input.config, self.config) + + def test_generation_df_property(self) -> None: + analysis_input = LLMJudgeAnalysisInput( + generation_df=self.df, config=self.config + ) + pd.testing.assert_frame_equal(analysis_input.generation_df, self.df) + + def test_has_reference_true(self) -> None: + analysis_input = LLMJudgeAnalysisInput( + generation_df=self.df, config=self.config + ) + self.assertTrue(analysis_input.has_reference) + + def test_has_reference_false_when_key_is_none(self) -> None: + analysis_input = LLMJudgeAnalysisInput( + generation_df=self.df, config=self.config, reference_key=None + ) + self.assertFalse(analysis_input.has_reference) + self.assertIsNone(analysis_input.reference_key) + + def test_missing_reference_column_warns_and_sets_none(self) -> None: + df_no_ref = self.df[["prompt", "generation"]].copy() + with self.assertLogs( + "privacy_guard.analysis.llm_judge.llm_judge_analysis_input", + level="WARNING", + ) as cm: + analysis_input = LLMJudgeAnalysisInput( + generation_df=df_no_ref, config=self.config + ) + self.assertTrue(any("reference_text" in msg for msg in cm.output)) + self.assertIsNone(analysis_input.reference_key) + self.assertFalse(analysis_input.has_reference) + + def test_missing_prompt_key_raises(self) -> None: + df_bad = pd.DataFrame({"question": ["What is AI?"], "generation": ["AI is..."]}) + with self.assertRaises(AssertionError): + LLMJudgeAnalysisInput(generation_df=df_bad, config=self.config) + + def test_missing_generation_key_raises(self) -> None: + df_bad = pd.DataFrame({"prompt": ["What is AI?"], "response": ["AI is..."]}) + with self.assertRaises(AssertionError): + LLMJudgeAnalysisInput(generation_df=df_bad, config=self.config) + + def test_custom_column_names(self) -> None: + df_custom = pd.DataFrame( + { + "question": ["What is AI?"], + "response": ["AI is..."], + "gold": ["Artificial intelligence is..."], + } + ) + analysis_input = LLMJudgeAnalysisInput( + generation_df=df_custom, + config=self.config, + prompt_key="question", + generation_key="response", + reference_key="gold", + ) + self.assertEqual(analysis_input.prompt_key, "question") + self.assertEqual(analysis_input.generation_key, "response") + self.assertEqual(analysis_input.reference_key, "gold") + self.assertTrue(analysis_input.has_reference) + + def test_stores_dataframe_via_base_class(self) -> None: + analysis_input = LLMJudgeAnalysisInput( + generation_df=self.df, config=self.config + ) + pd.testing.assert_frame_equal(analysis_input.df_train_user, self.df) + self.assertTrue(analysis_input.df_test_user.empty) diff --git a/privacy_guard/analysis/tests/test_llm_judge_analysis_node.py b/privacy_guard/analysis/tests/test_llm_judge_analysis_node.py new file mode 100644 index 0000000..0f4bbab --- /dev/null +++ b/privacy_guard/analysis/tests/test_llm_judge_analysis_node.py @@ -0,0 +1,649 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-strict + + +import json +import unittest +from unittest.mock import patch + +import pandas as pd +from privacy_guard.analysis.llm_judge.llm_judge_analysis_input import ( + LLMJudgeAnalysisInput, +) +from privacy_guard.analysis.llm_judge.llm_judge_analysis_node import ( + _call_anthropic, + _call_gemini, + _call_openai, + _parse_json_response, + call_llm_judge, + LLMJudgeAnalysisNode, + LLMJudgeAnalysisOutput, +) +from privacy_guard.analysis.llm_judge.llm_judge_config import ( + LLMJudgeConfig, + LLMProvider, +) + + +def _make_judge_response( + criteria: list[str] | None = None, +) -> dict[str, object]: + """Build a realistic judge response dict for testing.""" + criteria = criteria or ["accuracy", "relevance", "fluency", "completeness"] + scores = {c: 4 for c in criteria} + return { + "scores": scores, + "overall_score": 4.0, + "reasoning": "Good quality response.", + } + + +class TestParseJsonResponse(unittest.TestCase): + def test_clean_json(self) -> None: + raw = json.dumps({"scores": {"accuracy": 5}, "overall_score": 5.0}) + result = _parse_json_response(raw) + self.assertEqual(result["overall_score"], 5.0) + self.assertEqual(result["scores"]["accuracy"], 5) + + def test_json_with_markdown_fences(self) -> None: + raw = '```json\n{"overall_score": 3.5}\n```' + result = _parse_json_response(raw) + self.assertEqual(result["overall_score"], 3.5) + + def test_json_with_surrounding_whitespace(self) -> None: + raw = ' \n {"overall_score": 2.0} \n ' + result = _parse_json_response(raw) + self.assertEqual(result["overall_score"], 2.0) + + def test_invalid_json_raises(self) -> None: + with self.assertRaises(json.JSONDecodeError): + _parse_json_response("not valid json") + + +class TestGetApiKey(unittest.TestCase): + @patch.dict("os.environ", {"ANTHROPIC_API_KEY": "test-key-123"}) + def test_returns_key_when_present(self) -> None: + from privacy_guard.analysis.llm_judge.llm_judge_analysis_node import ( + _get_api_key, + ) + + config = LLMJudgeConfig(provider=LLMProvider.ANTHROPIC) + self.assertEqual(_get_api_key(config), "test-key-123") + + @patch.dict("os.environ", {}, clear=True) + def test_raises_when_key_missing(self) -> None: + from privacy_guard.analysis.llm_judge.llm_judge_analysis_node import ( + _get_api_key, + ) + + config = LLMJudgeConfig( + provider=LLMProvider.ANTHROPIC, api_key_env_var="MISSING_KEY" + ) + with self.assertRaises(ValueError): + _get_api_key(config) + + +class TestCallLLMJudge(unittest.TestCase): + def setUp(self) -> None: + self.config = LLMJudgeConfig(provider=LLMProvider.ANTHROPIC) + self.expected = _make_judge_response() + super().setUp() + + @patch( + "privacy_guard.analysis.llm_judge.llm_judge_analysis_node._get_api_key", + return_value="fake-key", + ) + @patch("privacy_guard.analysis.llm_judge.llm_judge_analysis_node._PROVIDER_CALLERS") + def test_success_on_first_attempt( + self, + mock_callers: unittest.mock.MagicMock, + _mock_key: unittest.mock.MagicMock, + ) -> None: + mock_callers.__getitem__ = unittest.mock.MagicMock( + return_value=unittest.mock.MagicMock(return_value=self.expected) + ) + result = call_llm_judge("test prompt", self.config) + self.assertEqual(result["overall_score"], 4.0) + + @patch( + "privacy_guard.analysis.llm_judge.llm_judge_analysis_node._get_api_key", + return_value="fake-key", + ) + @patch("privacy_guard.analysis.llm_judge.llm_judge_analysis_node._PROVIDER_CALLERS") + @patch("privacy_guard.analysis.llm_judge.llm_judge_analysis_node.time.sleep") + def test_retries_on_transient_failure( + self, + _mock_sleep: unittest.mock.MagicMock, + mock_callers: unittest.mock.MagicMock, + _mock_key: unittest.mock.MagicMock, + ) -> None: + caller = unittest.mock.MagicMock( + side_effect=[json.JSONDecodeError("err", "", 0), self.expected] + ) + mock_callers.__getitem__ = unittest.mock.MagicMock(return_value=caller) + result = call_llm_judge("test prompt", self.config, max_retries=3) + self.assertEqual(result["overall_score"], 4.0) + self.assertEqual(caller.call_count, 2) + + @patch( + "privacy_guard.analysis.llm_judge.llm_judge_analysis_node._get_api_key", + return_value="fake-key", + ) + @patch("privacy_guard.analysis.llm_judge.llm_judge_analysis_node._PROVIDER_CALLERS") + @patch("privacy_guard.analysis.llm_judge.llm_judge_analysis_node.time.sleep") + def test_raises_after_all_retries_fail( + self, + _mock_sleep: unittest.mock.MagicMock, + mock_callers: unittest.mock.MagicMock, + _mock_key: unittest.mock.MagicMock, + ) -> None: + caller = unittest.mock.MagicMock(side_effect=json.JSONDecodeError("err", "", 0)) + mock_callers.__getitem__ = unittest.mock.MagicMock(return_value=caller) + with self.assertRaises(RuntimeError): + call_llm_judge("test prompt", self.config, max_retries=2) + self.assertEqual(caller.call_count, 2) + + +class TestCallAnthropic(unittest.TestCase): + """Tests for _call_anthropic with mocked requests.post returning realistic API responses.""" + + def setUp(self) -> None: + self.config = LLMJudgeConfig(provider=LLMProvider.ANTHROPIC) + self.api_key = "fake-anthropic-key" + super().setUp() + + @patch("privacy_guard.analysis.llm_judge.llm_judge_analysis_node.requests.post") + def test_successful_response(self, mock_post: unittest.mock.MagicMock) -> None: + mock_response = unittest.mock.MagicMock() + mock_response.json.return_value = { + "id": "msg_01XFDUDYJgAACzvnptvVoYEL", + "type": "message", + "role": "assistant", + "content": [ + { + "type": "text", + "text": json.dumps( + { + "scores": { + "accuracy": 4, + "relevance": 5, + "fluency": 4, + "completeness": 3, + }, + "overall_score": 4.0, + "reasoning": "Good quality response.", + } + ), + } + ], + "model": "claude-sonnet-4-20250514", + "stop_reason": "end_turn", + "stop_sequence": None, + "usage": {"input_tokens": 25, "output_tokens": 150}, + } + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + result = _call_anthropic("test prompt", self.config, self.api_key) + + self.assertEqual(result["overall_score"], 4.0) + self.assertEqual(result["scores"]["accuracy"], 4) + self.assertEqual(result["scores"]["relevance"], 5) + self.assertEqual(result["reasoning"], "Good quality response.") + + mock_post.assert_called_once() + call_kwargs = mock_post.call_args + self.assertEqual(call_kwargs.args[0], "https://api.anthropic.com/v1/messages") + self.assertEqual(call_kwargs.kwargs["headers"]["x-api-key"], self.api_key) + self.assertEqual(call_kwargs.kwargs["json"]["model"], self.config.model) + + @patch("privacy_guard.analysis.llm_judge.llm_judge_analysis_node.requests.post") + def test_response_with_markdown_fences( + self, mock_post: unittest.mock.MagicMock + ) -> None: + judge_json = json.dumps( + { + "scores": { + "accuracy": 5, + "relevance": 5, + "fluency": 5, + "completeness": 5, + }, + "overall_score": 5.0, + "reasoning": "Excellent response.", + } + ) + mock_response = unittest.mock.MagicMock() + mock_response.json.return_value = { + "id": "msg_02ABC", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": f"```json\n{judge_json}\n```"}], + "model": "claude-sonnet-4-20250514", + "stop_reason": "end_turn", + "stop_sequence": None, + "usage": {"input_tokens": 30, "output_tokens": 160}, + } + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + result = _call_anthropic("test prompt", self.config, self.api_key) + self.assertEqual(result["overall_score"], 5.0) + + @patch("privacy_guard.analysis.llm_judge.llm_judge_analysis_node.requests.post") + def test_http_error_raises(self, mock_post: unittest.mock.MagicMock) -> None: + import requests as req + + mock_response = unittest.mock.MagicMock() + mock_response.raise_for_status.side_effect = req.HTTPError("401 Unauthorized") + mock_post.return_value = mock_response + + with self.assertRaises(req.HTTPError): + _call_anthropic("test prompt", self.config, self.api_key) + + +class TestCallOpenAI(unittest.TestCase): + """Tests for _call_openai with mocked requests.post returning realistic API responses.""" + + def setUp(self) -> None: + self.config = LLMJudgeConfig(provider=LLMProvider.OPENAI) + self.api_key = "fake-openai-key" + super().setUp() + + @patch("privacy_guard.analysis.llm_judge.llm_judge_analysis_node.requests.post") + def test_successful_response(self, mock_post: unittest.mock.MagicMock) -> None: + mock_response = unittest.mock.MagicMock() + mock_response.json.return_value = { + "id": "chatcmpl-abc123", + "object": "chat.completion", + "created": 1677858242, + "model": "gpt-4o", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": json.dumps( + { + "scores": { + "accuracy": 4, + "relevance": 5, + "fluency": 4, + "completeness": 3, + }, + "overall_score": 4.0, + "reasoning": "Good quality response.", + } + ), + }, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 25, + "completion_tokens": 150, + "total_tokens": 175, + }, + } + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + result = _call_openai("test prompt", self.config, self.api_key) + + self.assertEqual(result["overall_score"], 4.0) + self.assertEqual(result["scores"]["accuracy"], 4) + self.assertEqual(result["scores"]["relevance"], 5) + self.assertEqual(result["reasoning"], "Good quality response.") + + mock_post.assert_called_once() + call_kwargs = mock_post.call_args + self.assertEqual( + call_kwargs.args[0], + "https://api.openai.com/v1/chat/completions", + ) + self.assertEqual( + call_kwargs.kwargs["headers"]["Authorization"], + f"Bearer {self.api_key}", + ) + self.assertEqual(call_kwargs.kwargs["json"]["model"], self.config.model) + + @patch("privacy_guard.analysis.llm_judge.llm_judge_analysis_node.requests.post") + def test_response_with_markdown_fences( + self, mock_post: unittest.mock.MagicMock + ) -> None: + judge_json = json.dumps( + { + "scores": { + "accuracy": 3, + "relevance": 4, + "fluency": 3, + "completeness": 4, + }, + "overall_score": 3.5, + "reasoning": "Decent response.", + } + ) + mock_response = unittest.mock.MagicMock() + mock_response.json.return_value = { + "id": "chatcmpl-def456", + "object": "chat.completion", + "created": 1677858300, + "model": "gpt-4o", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": f"```json\n{judge_json}\n```", + }, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 30, + "completion_tokens": 160, + "total_tokens": 190, + }, + } + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + result = _call_openai("test prompt", self.config, self.api_key) + self.assertEqual(result["overall_score"], 3.5) + + @patch("privacy_guard.analysis.llm_judge.llm_judge_analysis_node.requests.post") + def test_http_error_raises(self, mock_post: unittest.mock.MagicMock) -> None: + import requests as req + + mock_response = unittest.mock.MagicMock() + mock_response.raise_for_status.side_effect = req.HTTPError( + "429 Too Many Requests" + ) + mock_post.return_value = mock_response + + with self.assertRaises(req.HTTPError): + _call_openai("test prompt", self.config, self.api_key) + + +class TestCallGemini(unittest.TestCase): + """Tests for _call_gemini with mocked requests.post returning realistic API responses.""" + + def setUp(self) -> None: + self.config = LLMJudgeConfig(provider=LLMProvider.GEMINI) + self.api_key = "fake-gemini-key" + super().setUp() + + @patch("privacy_guard.analysis.llm_judge.llm_judge_analysis_node.requests.post") + def test_successful_response(self, mock_post: unittest.mock.MagicMock) -> None: + mock_response = unittest.mock.MagicMock() + mock_response.json.return_value = { + "candidates": [ + { + "content": { + "parts": [ + { + "text": json.dumps( + { + "scores": { + "accuracy": 4, + "relevance": 5, + "fluency": 4, + "completeness": 3, + }, + "overall_score": 4.0, + "reasoning": "Good quality response.", + } + ) + } + ], + "role": "model", + }, + "finishReason": "STOP", + "index": 0, + } + ], + "usageMetadata": { + "promptTokenCount": 25, + "candidatesTokenCount": 150, + "totalTokenCount": 175, + }, + } + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + result = _call_gemini("test prompt", self.config, self.api_key) + + self.assertEqual(result["overall_score"], 4.0) + self.assertEqual(result["scores"]["accuracy"], 4) + self.assertEqual(result["scores"]["relevance"], 5) + self.assertEqual(result["reasoning"], "Good quality response.") + + mock_post.assert_called_once() + call_kwargs = mock_post.call_args + self.assertIn(self.config.model, call_kwargs.args[0]) + self.assertIn("generateContent", call_kwargs.args[0]) + self.assertEqual(call_kwargs.kwargs["headers"]["x-goog-api-key"], self.api_key) + + @patch("privacy_guard.analysis.llm_judge.llm_judge_analysis_node.requests.post") + def test_response_with_markdown_fences( + self, mock_post: unittest.mock.MagicMock + ) -> None: + judge_json = json.dumps( + { + "scores": { + "accuracy": 2, + "relevance": 3, + "fluency": 2, + "completeness": 2, + }, + "overall_score": 2.25, + "reasoning": "Poor response.", + } + ) + mock_response = unittest.mock.MagicMock() + mock_response.json.return_value = { + "candidates": [ + { + "content": { + "parts": [{"text": f"```json\n{judge_json}\n```"}], + "role": "model", + }, + "finishReason": "STOP", + "index": 0, + } + ], + "usageMetadata": { + "promptTokenCount": 30, + "candidatesTokenCount": 160, + "totalTokenCount": 190, + }, + } + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + result = _call_gemini("test prompt", self.config, self.api_key) + self.assertEqual(result["overall_score"], 2.25) + + @patch("privacy_guard.analysis.llm_judge.llm_judge_analysis_node.requests.post") + def test_http_error_raises(self, mock_post: unittest.mock.MagicMock) -> None: + import requests as req + + mock_response = unittest.mock.MagicMock() + mock_response.raise_for_status.side_effect = req.HTTPError( + "500 Internal Server Error" + ) + mock_post.return_value = mock_response + + with self.assertRaises(req.HTTPError): + _call_gemini("test prompt", self.config, self.api_key) + + +class TestLLMJudgeAnalysisNode(unittest.TestCase): + def setUp(self) -> None: + self.df = pd.DataFrame( + { + "prompt": ["What is AI?", "Explain ML"], + "generation": ["AI is...", "ML is..."], + "reference_text": [ + "Artificial intelligence is...", + "Machine learning is...", + ], + } + ) + self.config = LLMJudgeConfig( + provider=LLMProvider.ANTHROPIC, + scoring_criteria=["accuracy", "fluency"], + ) + self.analysis_input = LLMJudgeAnalysisInput( + generation_df=self.df, config=self.config + ) + self.node = LLMJudgeAnalysisNode(analysis_input=self.analysis_input) + super().setUp() + + def test_build_eval_prompt_with_reference(self) -> None: + row = self.df.iloc[0] + prompt_text = self.node._build_eval_prompt(row) + self.assertIn("What is AI?", prompt_text) + self.assertIn("AI is...", prompt_text) + self.assertIn("Artificial intelligence is...", prompt_text) + self.assertIn("accuracy", prompt_text) + self.assertIn("fluency", prompt_text) + + def test_build_eval_prompt_without_reference(self) -> None: + df_no_ref = self.df[["prompt", "generation"]].copy() + config = LLMJudgeConfig( + provider=LLMProvider.ANTHROPIC, + scoring_criteria=["accuracy"], + ) + analysis_input = LLMJudgeAnalysisInput( + generation_df=df_no_ref, config=config, reference_key=None + ) + node = LLMJudgeAnalysisNode(analysis_input=analysis_input) + prompt_text = node._build_eval_prompt(df_no_ref.iloc[0]) + self.assertIn("What is AI?", prompt_text) + self.assertIn("AI is...", prompt_text) + self.assertNotIn("Reference Text", prompt_text) + + def test_build_eval_prompt_custom_template(self) -> None: + config = LLMJudgeConfig( + provider=LLMProvider.ANTHROPIC, + eval_prompt="P: {prompt} G: {generation} R: {reference_text} C: {criteria}", + scoring_criteria=["clarity"], + ) + analysis_input = LLMJudgeAnalysisInput(generation_df=self.df, config=config) + node = LLMJudgeAnalysisNode(analysis_input=analysis_input) + prompt_text = node._build_eval_prompt(self.df.iloc[0]) + self.assertIn("P: What is AI?", prompt_text) + self.assertIn("G: AI is...", prompt_text) + self.assertIn("R: Artificial intelligence is...", prompt_text) + self.assertIn("- clarity", prompt_text) + + @patch("privacy_guard.analysis.llm_judge.llm_judge_analysis_node.call_llm_judge") + def test_evaluate_single_success(self, mock_judge: unittest.mock.MagicMock) -> None: + expected = _make_judge_response(["accuracy", "fluency"]) + mock_judge.return_value = expected + result = self.node._evaluate_single(self.df.iloc[0]) + self.assertEqual(result["overall_score"], 4.0) + self.assertIn("accuracy", result["scores"]) + + @patch( + "privacy_guard.analysis.llm_judge.llm_judge_analysis_node.call_llm_judge", + side_effect=RuntimeError("API down"), + ) + def test_evaluate_single_failure_returns_fallback( + self, _mock_judge: unittest.mock.MagicMock + ) -> None: + result = self.node._evaluate_single(self.df.iloc[0]) + self.assertEqual(result["overall_score"], 0.0) + self.assertIn("failed", result["reasoning"].lower()) + + @patch("privacy_guard.analysis.llm_judge.llm_judge_analysis_node.call_llm_judge") + def test_run_analysis_output_structure( + self, mock_judge: unittest.mock.MagicMock + ) -> None: + mock_judge.return_value = _make_judge_response(["accuracy", "fluency"]) + output = self.node.run_analysis() + self.assertIsInstance(output, LLMJudgeAnalysisOutput) + self.assertEqual(output.num_samples, 2) + self.assertEqual(output.num_failed, 0) + self.assertEqual(output.provider, "anthropic") + self.assertEqual(len(output.per_sample_overall_scores), 2) + self.assertEqual(len(output.per_sample_criteria_scores), 2) + self.assertEqual(len(output.per_sample_reasoning), 2) + self.assertIn("accuracy", output.per_criteria_avg_scores) + self.assertIn("fluency", output.per_criteria_avg_scores) + + @patch("privacy_guard.analysis.llm_judge.llm_judge_analysis_node.call_llm_judge") + def test_run_analysis_avg_overall_score( + self, mock_judge: unittest.mock.MagicMock + ) -> None: + mock_judge.return_value = _make_judge_response(["accuracy", "fluency"]) + output = self.node.run_analysis() + self.assertAlmostEqual(output.avg_overall_score, 4.0) + + @patch("privacy_guard.analysis.llm_judge.llm_judge_analysis_node.call_llm_judge") + def test_run_analysis_augmented_dataset( + self, mock_judge: unittest.mock.MagicMock + ) -> None: + mock_judge.return_value = _make_judge_response(["accuracy", "fluency"]) + output = self.node.run_analysis() + aug_df = output.augmented_output_dataset + self.assertIn("judge_overall_score", aug_df.columns) + self.assertIn("judge_reasoning", aug_df.columns) + self.assertIn("judge_accuracy_score", aug_df.columns) + self.assertIn("judge_fluency_score", aug_df.columns) + self.assertEqual(len(aug_df), 2) + + @patch("privacy_guard.analysis.llm_judge.llm_judge_analysis_node.call_llm_judge") + def test_run_analysis_with_failures( + self, mock_judge: unittest.mock.MagicMock + ) -> None: + success = _make_judge_response(["accuracy", "fluency"]) + mock_judge.side_effect = [success, RuntimeError("fail")] + output = self.node.run_analysis() + self.assertEqual(output.num_failed, 1) + self.assertEqual(output.per_sample_overall_scores[0], 4.0) + self.assertEqual(output.per_sample_overall_scores[1], 0.0) + + @patch("privacy_guard.analysis.llm_judge.llm_judge_analysis_node.call_llm_judge") + def test_compute_outputs_returns_dict( + self, mock_judge: unittest.mock.MagicMock + ) -> None: + mock_judge.return_value = _make_judge_response(["accuracy", "fluency"]) + outputs = self.node.compute_outputs() + self.assertIsInstance(outputs, dict) + self.assertIn("avg_overall_score", outputs) + self.assertIn("num_samples", outputs) + self.assertIn("num_failed", outputs) + + @patch("privacy_guard.analysis.llm_judge.llm_judge_analysis_node.call_llm_judge") + def test_run_analysis_per_criteria_averages( + self, mock_judge: unittest.mock.MagicMock + ) -> None: + response_1: dict[str, object] = { + "scores": {"accuracy": 5, "fluency": 3}, + "overall_score": 4.0, + "reasoning": "Good.", + } + response_2: dict[str, object] = { + "scores": {"accuracy": 3, "fluency": 5}, + "overall_score": 4.0, + "reasoning": "Good.", + } + mock_judge.side_effect = [response_1, response_2] + output = self.node.run_analysis() + self.assertAlmostEqual(output.per_criteria_avg_scores["accuracy"], 4.0) + self.assertAlmostEqual(output.per_criteria_avg_scores["fluency"], 4.0) diff --git a/privacy_guard/attacks/extraction/ip_attack.py b/privacy_guard/attacks/extraction/ip_attack.py new file mode 100644 index 0000000..cdc8691 --- /dev/null +++ b/privacy_guard/attacks/extraction/ip_attack.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-strict + +import logging +from typing import Any, Dict, Literal + +import pandas as pd +from privacy_guard.analysis.llm_judge.llm_judge_analysis_input import ( + LLMJudgeAnalysisInput, +) +from privacy_guard.analysis.llm_judge.llm_judge_config import LLMJudgeConfig +from privacy_guard.attacks.base_attack import BaseAttack +from privacy_guard.attacks.extraction.predictors.base_predictor import BasePredictor +from privacy_guard.attacks.extraction.utils.data_utils import load_data, save_results + + +def setup_logger() -> logging.Logger: + """Set up the logger for the script.""" + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + + logger.handlers.clear() + logger.propagate = False + + handler = logging.StreamHandler() + handler.setLevel(logging.INFO) + + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + handler.setFormatter(formatter) + + logger.addHandler(handler) + + return logger + + +logger: logging.Logger = setup_logger() + + +FormatType = Literal["jsonl", "csv", "json"] + + +class IPAttack(BaseAttack): + """Attack for evaluating IP content generation risk using an LLM judge. + + Given a set of prompts designed to probe a target model for IP content + reproduction, this attack: + 1. Generates text from the target model via a ``BasePredictor``. + 2. Returns an ``LLMJudgeAnalysisInput`` that can be fed into + ``LLMJudgeAnalysisNode`` for scoring the generations against + IP-related criteria (e.g. verbatim reproduction, paraphrasing). + + Args: + input_file: Path to the input file containing probing prompts. + output_file: Optional path to save generations. When ``None``, + results are only returned in the analysis input. + predictor: Predictor instance used to generate text from the + target model. + judge_config: ``LLMJudgeConfig`` specifying the judge provider, + model, and scoring criteria for IP evaluation. + input_format: Format of the input file. + output_format: Format of the output file. + prompt_key: Column name for the input prompt. + generation_key: Column name for the generated text (written + into the output dataframe). + reference_key: Column name for reference/ground-truth text. + Set to ``None`` when no reference text is available. + batch_size: Batch size for generation. + **generation_kwargs: Additional generation parameters forwarded + to the predictor (temperature, top_k, top_p, etc.). + """ + + def __init__( + self, + input_file: str, + output_file: str | None, + predictor: BasePredictor, + judge_config: LLMJudgeConfig, + input_format: FormatType = "jsonl", + output_format: FormatType = "jsonl", + prompt_key: str = "prompt", + generation_key: str = "generation", + reference_key: str = "reference_text", + batch_size: int = 1, + **generation_kwargs: Any, + ) -> None: + self.input_file = input_file + self.output_file = output_file + self.input_format = input_format + self.output_format = output_format + self.predictor: BasePredictor = predictor + self.judge_config = judge_config + self.prompt_key = prompt_key + self.generation_key = generation_key + self.reference_key = reference_key + self.batch_size = batch_size + self.generation_kwargs: Dict[str, Any] = generation_kwargs + + logger.info(f"Loading data from {input_file}") + self.input_df: pd.DataFrame = load_data(input_file, format=input_format) + logger.info(f"Loaded {len(self.input_df)} rows") + + required_columns = {self.prompt_key} + missing_columns = required_columns - set(self.input_df.columns) + if missing_columns: + raise ValueError(f"Missing required columns: {missing_columns}") + + logger.info("IP attack is ready to run") + + def run_attack(self) -> LLMJudgeAnalysisInput: + """Execute the IP content probing attack. + + Generates text from the target model and returns an analysis + input ready for LLM judge evaluation. + """ + logger.info("Executing IP content attack") + + prompts = self.input_df[self.prompt_key].tolist() + + logger.info(f"Generating text for {len(prompts)} prompts") + generations = self.predictor.generate( + prompts=prompts, + batch_size=self.batch_size, + **self.generation_kwargs, + ) + + processed_df = self.input_df.copy() + processed_df[self.generation_key] = generations + + logger.info("Generation complete") + + if self.output_file is not None: + output_path: str = self.output_file + output_format: str = str(self.output_format) + logger.info(f"Saving results to {output_path}") + save_results(df=processed_df, output_path=output_path, format=output_format) + logger.info("Results saved successfully") + else: + logger.info("No output file specified, not saving results to disk") + + return LLMJudgeAnalysisInput( + generation_df=processed_df, + config=self.judge_config, + prompt_key=self.prompt_key, + generation_key=self.generation_key, + reference_key=self.reference_key, + ) diff --git a/privacy_guard/attacks/extraction/tests/test_ip_attack.py b/privacy_guard/attacks/extraction/tests/test_ip_attack.py new file mode 100644 index 0000000..b116c65 --- /dev/null +++ b/privacy_guard/attacks/extraction/tests/test_ip_attack.py @@ -0,0 +1,219 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-strict + +import os +import tempfile +import unittest +from unittest.mock import MagicMock + +from privacy_guard.analysis.llm_judge.llm_judge_analysis_input import ( + LLMJudgeAnalysisInput, +) +from privacy_guard.analysis.llm_judge.llm_judge_config import ( + LLMJudgeConfig, + LLMProvider, +) +from privacy_guard.attacks.extraction.ip_attack import IPAttack + + +class TestIPAttack(unittest.TestCase): + def setUp(self) -> None: + """Set up test data and mocks.""" + self.input_file = tempfile.NamedTemporaryFile(suffix=".jsonl") + self.input_file_name = self.input_file.name + + with open(self.input_file_name, "w") as f: + f.write( + '{"prompt": "Recite chapter 1 of Book X", ' + '"reference_text": "It was a dark and stormy night..."}\n' + ) + f.write( + '{"prompt": "Continue this passage from Book Y", ' + '"reference_text": "Call me Ishmael..."}\n' + ) + + self.output_file = tempfile.NamedTemporaryFile(suffix=".jsonl") + self.output_file_name = self.output_file.name + + self.mock_predictor = MagicMock() + self.mock_predictor.generate.return_value = [ + "It was a dark and stormy night, the wind howled...", + "Call me Ishmael. Some years ago...", + ] + + self.judge_config = LLMJudgeConfig( + provider=LLMProvider.ANTHROPIC, + scoring_criteria=[ + "ip_similarity", + "verbatim_reproduction", + "paraphrasing", + "originality", + ], + ) + + def test_ip_attack_basic(self) -> None: + """Test basic attack returns LLMJudgeAnalysisInput with correct columns.""" + attack = IPAttack( + input_file=self.input_file_name, + output_file=None, + predictor=self.mock_predictor, + judge_config=self.judge_config, + ) + + result = attack.run_attack() + + self.assertIsInstance(result, LLMJudgeAnalysisInput) + self.assertEqual(len(result.generation_df), 2) + self.assertIn("prompt", result.generation_df.columns) + self.assertIn("generation", result.generation_df.columns) + self.assertIs(result.config, self.judge_config) + + def test_ip_attack_with_output_file(self) -> None: + """Test that output file is written when specified.""" + attack = IPAttack( + input_file=self.input_file_name, + output_file=self.output_file_name, + predictor=self.mock_predictor, + judge_config=self.judge_config, + ) + + attack.run_attack() + + self.assertTrue(os.path.getsize(self.output_file_name) > 0) + + def test_ip_attack_without_output_file(self) -> None: + """Test attack works without saving to disk.""" + attack = IPAttack( + input_file=self.input_file_name, + output_file=None, + predictor=self.mock_predictor, + judge_config=self.judge_config, + ) + + result = attack.run_attack() + + self.assertIsInstance(result, LLMJudgeAnalysisInput) + self.mock_predictor.generate.assert_called_once() + + def test_ip_attack_custom_columns(self) -> None: + """Test attack with custom column names.""" + custom_input = tempfile.NamedTemporaryFile(suffix=".jsonl") + with open(custom_input.name, "w") as f: + f.write( + '{"question": "Recite chapter 1", ' + '"gold": "It was a dark and stormy night..."}\n' + ) + + self.mock_predictor.generate.return_value = ["generated text"] + + attack = IPAttack( + input_file=custom_input.name, + output_file=None, + predictor=self.mock_predictor, + judge_config=self.judge_config, + prompt_key="question", + generation_key="response", + reference_key="gold", + ) + + result = attack.run_attack() + + self.assertEqual(result.prompt_key, "question") + self.assertEqual(result.generation_key, "response") + self.assertEqual(result.reference_key, "gold") + self.assertIn("response", result.generation_df.columns) + custom_input.close() + + def test_ip_attack_missing_prompt_column(self) -> None: + """Test that ValueError is raised when prompt column is missing.""" + bad_input = tempfile.NamedTemporaryFile(suffix=".jsonl") + with open(bad_input.name, "w") as f: + f.write('{"other_column": "value"}\n') + + with self.assertRaises(ValueError) as context: + IPAttack( + input_file=bad_input.name, + output_file=None, + predictor=self.mock_predictor, + judge_config=self.judge_config, + ) + + self.assertIn("Missing required columns", str(context.exception)) + bad_input.close() + + def test_ip_attack_with_reference_text(self) -> None: + """Test that reference text is passed through to analysis input.""" + attack = IPAttack( + input_file=self.input_file_name, + output_file=None, + predictor=self.mock_predictor, + judge_config=self.judge_config, + ) + + result = attack.run_attack() + + self.assertTrue(result.has_reference) + self.assertEqual(result.reference_key, "reference_text") + self.assertIn("reference_text", result.generation_df.columns) + + def test_ip_attack_without_reference_text(self) -> None: + """Test attack works when no reference column exists.""" + no_ref_input = tempfile.NamedTemporaryFile(suffix=".jsonl") + with open(no_ref_input.name, "w") as f: + f.write('{"prompt": "Tell me about Book X"}\n') + + self.mock_predictor.generate.return_value = ["Some generated text"] + + attack = IPAttack( + input_file=no_ref_input.name, + output_file=None, + predictor=self.mock_predictor, + judge_config=self.judge_config, + ) + + result = attack.run_attack() + + self.assertFalse(result.has_reference) + no_ref_input.close() + + def test_ip_attack_generation_kwargs_forwarded(self) -> None: + """Test that generation kwargs are forwarded to the predictor.""" + attack = IPAttack( + input_file=self.input_file_name, + output_file=None, + predictor=self.mock_predictor, + judge_config=self.judge_config, + batch_size=4, + temperature=0.7, + top_k=50, + ) + + attack.run_attack() + + self.mock_predictor.generate.assert_called_once_with( + prompts=[ + "Recite chapter 1 of Book X", + "Continue this passage from Book Y", + ], + batch_size=4, + temperature=0.7, + top_k=50, + ) + + def tearDown(self) -> None: + """Clean up temporary files.""" + self.input_file.close() + self.output_file.close() diff --git a/pyproject.toml b/pyproject.toml index 79066ec..6b9cffe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "pytest>=4.6", "pytest-cov", "pytest-timeout", + "requests", "torch", 'tqdm', 'textdistance',