Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 11 additions & 16 deletions pyrit/datasets/seed_datasets/remote/beaver_tails_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

import logging

from jinja2 import TemplateSyntaxError

from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import (
_RemoteDatasetLoader,
)
Expand Down Expand Up @@ -101,21 +99,18 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset:

harm_categories = [k for k, v in item["category"].items() if v]

try:
seed_prompts.append(
SeedPrompt(
value=item["prompt"],
data_type="text",
dataset_name=self.dataset_name,
harm_categories=harm_categories,
description=description,
source=source_url,
authors=authors,
groups=groups,
)
seed_prompts.append(
SeedPrompt(
value=item["prompt"],
data_type="text",
dataset_name=self.dataset_name,
harm_categories=harm_categories,
description=description,
source=source_url,
authors=authors,
groups=groups,
)
except TemplateSyntaxError:
logger.warning("Skipping BeaverTails prompt due to Jinja2 template syntax error in prompt text")
)

logger.info(f"Successfully loaded {len(seed_prompts)} prompts from BeaverTails dataset")

Expand Down
50 changes: 16 additions & 34 deletions pyrit/datasets/seed_datasets/remote/toxic_chat_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import logging
from typing import Any

from jinja2 import TemplateSyntaxError

from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import (
_RemoteDatasetLoader,
)
Expand Down Expand Up @@ -122,42 +120,26 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset:
source_url = f"https://huggingface.co/datasets/{self.HF_DATASET_NAME}"
groups = ["UC San Diego"]

raw_prefix = "{% raw %}"
raw_suffix = "{% endraw %}"

seed_prompts: list[SeedPrompt] = []
for item in data:
user_input = item["user_input"]
harm_categories = self._extract_harm_categories(item)
try:
prompt = SeedPrompt(
value=user_input,
data_type="text",
dataset_name=self.dataset_name,
description=description,
source=source_url,
authors=authors,
groups=groups,
harm_categories=harm_categories,
metadata={
"toxicity": str(item.get("toxicity", "")),
"jailbreaking": str(item.get("jailbreaking", "")),
"human_annotation": str(item.get("human_annotation", "")),
},
)

# If user_input contains Jinja2 control structures (e.g., {% for %}),
# render_template_value_silent may skip rendering and leave the raw wrapper.
if prompt.value.startswith(raw_prefix) and prompt.value.endswith(raw_suffix):
prompt.value = prompt.value[len(raw_prefix) : -len(raw_suffix)]

seed_prompts.append(prompt)
except TemplateSyntaxError:
conv_id = item.get("conv_id", "unknown")
logger.debug(
f"Skipping entry with conv_id={conv_id}: failed to parse as Jinja2 template",
exc_info=True,
)
prompt = SeedPrompt(
value=user_input,
data_type="text",
dataset_name=self.dataset_name,
description=description,
source=source_url,
authors=authors,
groups=groups,
harm_categories=harm_categories,
metadata={
"toxicity": str(item.get("toxicity", "")),
"jailbreaking": str(item.get("jailbreaking", "")),
"human_annotation": str(item.get("human_annotation", "")),
},
)
seed_prompts.append(prompt)

logger.info(f"Successfully loaded {len(seed_prompts)} prompts from ToxicChat dataset")

Expand Down
6 changes: 3 additions & 3 deletions pyrit/models/seeds/seed_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ def __post_init__(self) -> None:
"""
if self.is_general_technique:
raise ValueError("SeedObjective cannot be a general technique.")
if not self.is_jinja_template:
self.value = self.escape_for_jinja(self.value)
self.value = super().render_template_value_silent(**PATHS_DICT)
# Only trusted templates are rendered through Jinja — see seed_prompt.py for details.
if self.is_jinja_template:
self.value = super().render_template_value_silent(**PATHS_DICT)

@classmethod
def from_yaml_with_required_parameters(
Expand Down
9 changes: 6 additions & 3 deletions pyrit/models/seeds/seed_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,12 @@ def __post_init__(self) -> None:
ValueError: If file-based data type cannot be inferred from extension.

"""
if not self.is_jinja_template:
self.value = self.escape_for_jinja(self.value)
self.value = self.render_template_value_silent(**PATHS_DICT)
# Only trusted templates (is_jinja_template=True, e.g. from YAML files) are rendered
# through Jinja. Untrusted text (e.g. from remote datasets) must NOT be rendered — a
# crafted payload containing "{% endraw %}" can escape the raw wrapper and execute
# arbitrary Jinja expressions. See seed_objective.py for the same pattern.
if self.is_jinja_template:
self.value = self.render_template_value_silent(**PATHS_DICT)

if not self.data_type:
# If data_type is not provided, infer it from the value
Expand Down
13 changes: 7 additions & 6 deletions tests/unit/datasets/test_beaver_tails_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,14 @@ def test_dataset_name(self):
assert loader.dataset_name == "beaver_tails"

@pytest.mark.asyncio
async def test_fetch_dataset_skips_prompt_with_template_syntax_error(self):
"""Test that prompts causing TemplateSyntaxError are skipped gracefully."""
async def test_fetch_dataset_preserves_prompt_with_jinja_syntax(self):
"""Test that prompts containing Jinja2 syntax are preserved as literal text."""

class MockDataset:
def __init__(self):
self._data = [
{
"prompt": "This contains {% endraw %} which breaks Jinja2",
"prompt": "This contains {% endraw %} which is Jinja2 syntax",
"response": "response",
"category": {"animal_abuse": True},
"is_safe": False,
Expand All @@ -124,6 +124,7 @@ def __iter__(self):

with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=MockDataset())):
dataset = await loader.fetch_dataset()
# The broken prompt should be skipped, only the normal one remains
assert len(dataset.seeds) == 1
assert dataset.seeds[0].value == "Normal unsafe prompt"
# Both prompts should be preserved — untrusted text is never passed through Jinja
assert len(dataset.seeds) == 2
assert dataset.seeds[0].value == "This contains {% endraw %} which is Jinja2 syntax"
assert dataset.seeds[1].value == "Normal unsafe prompt"
12 changes: 7 additions & 5 deletions tests/unit/datasets/test_toxic_chat_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ async def test_fetch_dataset_preserves_jinja2_content(self):
assert dataset.seeds[1].value == "<!DOCTYPE html>{%block%}broken"

@pytest.mark.asyncio
async def test_fetch_dataset_skips_jinja2_incompatible_entries(self):
"""Test that entries with Jinja2-incompatible content are skipped."""
async def test_fetch_dataset_preserves_jinja2_syntax_in_entries(self):
"""Test that entries with Jinja2 syntax are preserved as literal text."""
data_with_endraw = [
{
"conv_id": "good1",
Expand All @@ -105,7 +105,7 @@ async def test_fetch_dataset_skips_jinja2_incompatible_entries(self):
"openai_moderation": "[]",
},
{
"conv_id": "bad1",
"conv_id": "jinja1",
"user_input": "This has {% endraw %} in it",
"model_output": "N/A",
"human_annotation": "False",
Expand All @@ -128,9 +128,11 @@ async def test_fetch_dataset_skips_jinja2_incompatible_entries(self):
with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=data_with_endraw)):
dataset = await loader.fetch_dataset()

assert len(dataset.seeds) == 2
# All entries are preserved — untrusted text is never passed through Jinja
assert len(dataset.seeds) == 3
assert dataset.seeds[0].value == "Normal question"
assert dataset.seeds[1].value == "Another normal question"
assert dataset.seeds[1].value == "This has {% endraw %} in it"
assert dataset.seeds[2].value == "Another normal question"

@pytest.mark.asyncio
async def test_fetch_dataset_preserves_for_loop_content(self):
Expand Down
Loading