diff --git a/frontend/e2e/config.spec.ts b/frontend/e2e/config.spec.ts index 51f3edb135..4178f792b2 100644 --- a/frontend/e2e/config.spec.ts +++ b/frontend/e2e/config.spec.ts @@ -52,11 +52,14 @@ test.describe("Target Configuration Page", () => { await goToConfig(page); - // Table should appear with both targets - await expect(page.getByText("gpt-4o")).toBeVisible({ timeout: 10000 }); - await expect(page.getByText("dall-e-3")).toBeVisible(); - await expect(page.getByText("OpenAIChatTarget")).toBeVisible(); + // Section headings should appear (collapsed by default) + await expect(page.getByText("OpenAIChatTarget")).toBeVisible({ timeout: 10000 }); await expect(page.getByText("OpenAIImageTarget")).toBeVisible(); + + // Expand all to see target data + await page.getByRole("button", { name: /expand all/i }).click(); + await expect(page.getByText("gpt-4o")).toBeVisible(); + await expect(page.getByText("dall-e-3")).toBeVisible(); }); test("should show empty state when no targets exist", async ({ page }) => { @@ -86,9 +89,11 @@ test.describe("Target Configuration Page", () => { }); await goToConfig(page); - await expect(page.getByText("gpt-4o")).toBeVisible({ timeout: 10000 }); + await expect(page.getByText("OpenAIChatTarget")).toBeVisible({ timeout: 10000 }); + + // Expand section to reveal Set Active buttons + await page.getByRole("button", { name: /OpenAIChatTarget/i }).click(); - // Both rows should have a "Set Active" button initially const setActiveBtns = page.getByRole("button", { name: /set active/i }); await expect(setActiveBtns.first()).toBeVisible(); await setActiveBtns.first().click(); @@ -123,16 +128,16 @@ test.describe("Target Configuration Page", () => { }); await goToConfig(page); - // First load shows one target - await expect(page.getByText("gpt-4o")).toBeVisible({ timeout: 10000 }); - await expect(page.getByText("dall-e-3")).not.toBeVisible(); + // First load shows one section heading + await expect(page.getByText("OpenAIChatTarget")).toBeVisible({ timeout: 10000 }); + await expect(page.getByText("OpenAIImageTarget")).not.toBeVisible(); // Flip the flag and click refresh showExtra = true; await page.getByRole("button", { name: /refresh/i }).click(); - // Second target should now appear - await expect(page.getByText("dall-e-3")).toBeVisible({ timeout: 10000 }); + // Second target type section should now appear + await expect(page.getByText("OpenAIImageTarget")).toBeVisible({ timeout: 10000 }); }); }); @@ -177,7 +182,7 @@ test.describe("Create Target Dialog", () => { await dialog.getByPlaceholder("https://your-resource.openai.azure.com/").fill("https://my-endpoint.openai.azure.com/"); // Fill model name - await dialog.getByPlaceholder("e.g. gpt-4o, dall-e-3").fill("gpt-4o-test"); + await dialog.getByPlaceholder("e.g. gpt-4o, my-deployment").fill("gpt-4o-test"); // Click Create Target await dialog.getByRole("button", { name: "Create Target" }).click(); @@ -225,9 +230,10 @@ test.describe("Target Config ↔ Chat Navigation", () => { }); await goToConfig(page); - await expect(page.getByText("gpt-4o")).toBeVisible({ timeout: 10000 }); + await expect(page.getByText("OpenAIChatTarget")).toBeVisible({ timeout: 10000 }); - // Set first target active + // Expand section and set first target active + await page.getByRole("button", { name: /OpenAIChatTarget/i }).click(); await page.getByRole("button", { name: /set active/i }).first().click(); // Navigate back to chat @@ -248,9 +254,10 @@ test.describe("Target Config ↔ Chat Navigation", () => { await page.goto("/"); await expect(page.getByTestId("no-target-banner")).toBeVisible(); - // Go to config, set a target + // Go to config, expand section, set a target await page.getByTitle("Configuration").click(); - await expect(page.getByText("gpt-4o")).toBeVisible({ timeout: 10000 }); + await expect(page.getByText("OpenAIChatTarget")).toBeVisible({ timeout: 10000 }); + await page.getByRole("button", { name: /OpenAIChatTarget/i }).click(); await page.getByRole("button", { name: /set active/i }).first().click(); // Return to chat — send should be enabled when there's text diff --git a/frontend/src/components/Chat/ChatWindow.tsx b/frontend/src/components/Chat/ChatWindow.tsx index 89d32be4b1..a63283b366 100644 --- a/frontend/src/components/Chat/ChatWindow.tsx +++ b/frontend/src/components/Chat/ChatWindow.tsx @@ -520,7 +520,7 @@ export default function ChatWindow({ {activeTarget.target_type} - {activeTarget.model_name ? ` (${activeTarget.model_name})` : ''} + {activeTarget.deployment_name ? ` (${activeTarget.deployment_name})` : activeTarget.model_name ? ` (${activeTarget.model_name})` : ''} diff --git a/frontend/src/components/Config/CreateTargetDialog.test.tsx b/frontend/src/components/Config/CreateTargetDialog.test.tsx index 5c8abe6b9b..017e7025b1 100644 --- a/frontend/src/components/Config/CreateTargetDialog.test.tsx +++ b/frontend/src/components/Config/CreateTargetDialog.test.tsx @@ -112,7 +112,7 @@ describe("CreateTargetDialog", () => { fireEvent.change(endpointInput, { target: { value: "https://api.openai.com" } }); // Fill model name — use fireEvent.change for consistency (same reason as endpoint) - const modelInput = screen.getByPlaceholderText("e.g. gpt-4o, dall-e-3"); + const modelInput = screen.getByPlaceholderText("e.g. gpt-4o, my-deployment"); fireEvent.change(modelInput, { target: { value: "gpt-4" } }); // Submit diff --git a/frontend/src/components/Config/CreateTargetDialog.tsx b/frontend/src/components/Config/CreateTargetDialog.tsx index b52317161a..5607b7e238 100644 --- a/frontend/src/components/Config/CreateTargetDialog.tsx +++ b/frontend/src/components/Config/CreateTargetDialog.tsx @@ -10,6 +10,8 @@ import { Input, Label, Select, + Switch, + Text, tokens, Field, MessageBar, @@ -38,6 +40,8 @@ export default function CreateTargetDialog({ open, onClose, onCreated }: CreateT const [targetType, setTargetType] = useState('') const [endpoint, setEndpoint] = useState('') const [modelName, setModelName] = useState('') + const [hasDifferentUnderlying, setHasDifferentUnderlying] = useState(false) + const [underlyingModel, setUnderlyingModel] = useState('') const [apiKey, setApiKey] = useState('') const [submitting, setSubmitting] = useState(false) const [error, setError] = useState(null) @@ -47,6 +51,8 @@ export default function CreateTargetDialog({ open, onClose, onCreated }: CreateT setTargetType('') setEndpoint('') setModelName('') + setHasDifferentUnderlying(false) + setUnderlyingModel('') setApiKey('') setError(null) setFieldErrors({}) @@ -75,6 +81,7 @@ export default function CreateTargetDialog({ open, onClose, onCreated }: CreateT endpoint, } if (modelName) params.model_name = modelName + if (hasDifferentUnderlying && underlyingModel) params.underlying_model = underlyingModel if (apiKey) params.api_key = apiKey await targetsApi.createTarget({ @@ -139,12 +146,36 @@ export default function CreateTargetDialog({ open, onClose, onCreated }: CreateT setModelName(data.value)} /> +
+ { + setHasDifferentUnderlying(data.checked) + if (!data.checked) setUnderlyingModel('') + }} + label="Underlying model differs from deployment name" + /> + + On Azure, the deployment name (e.g. my-gpt4-deployment) may differ from the actual model (e.g. gpt-4o). + +
+ + {hasDifferentUnderlying && ( + + setUnderlyingModel(data.value)} + /> + + )} + { expect(screen.getByText("OpenAIChatTarget")).toBeInTheDocument(); }); + // Expand section to reveal Set Active buttons + await userEvent.click(screen.getByRole("button", { name: /OpenAIChatTarget/i })); + const setActiveButtons = screen.getAllByText("Set Active"); await userEvent.click(setActiveButtons[0]); @@ -272,6 +275,12 @@ describe("TargetConfig", () => { await waitFor(() => { expect(screen.getByText("OpenAIChatTarget")).toBeInTheDocument(); + }); + + // Expand section to reveal target data + await userEvent.click(screen.getByRole("button", { name: /OpenAIChatTarget/i })); + + await waitFor(() => { expect(screen.getByText("gpt-4")).toBeInTheDocument(); expect( screen.getAllByText("https://api.openai.com").length @@ -307,7 +316,12 @@ describe("TargetConfig", () => { await waitFor(() => { expect(screen.getByText("OpenAIResponseTarget")).toBeInTheDocument(); - // formatParams renders as "key: value, key: value" + }); + + // Expand section to reveal params + await userEvent.click(screen.getByRole("button", { name: /OpenAIResponseTarget/i })); + + await waitFor(() => { expect(screen.getByText(/reasoning_effort: high/)).toBeInTheDocument(); expect(screen.getByText(/reasoning_summary: auto/)).toBeInTheDocument(); expect(screen.getByText(/max_output_tokens: 4096/)).toBeInTheDocument(); @@ -339,6 +353,9 @@ describe("TargetConfig", () => { expect(screen.getByText("TextTarget")).toBeInTheDocument(); }); + // Expand section + await userEvent.click(screen.getByRole("button", { name: /TextTarget/i })); + // No reasoning or other special params should be displayed expect(screen.queryByText(/reasoning_effort/)).not.toBeInTheDocument(); }); diff --git a/frontend/src/components/Config/TargetTable.styles.ts b/frontend/src/components/Config/TargetTable.styles.ts index 76634e0e47..c60fab8b00 100644 --- a/frontend/src/components/Config/TargetTable.styles.ts +++ b/frontend/src/components/Config/TargetTable.styles.ts @@ -17,4 +17,8 @@ export const useTargetTableStyles = makeStyles({ textOverflow: 'ellipsis', whiteSpace: 'nowrap', }, + paramsCell: { + whiteSpace: 'pre-line', + wordBreak: 'break-word', + }, }) diff --git a/frontend/src/components/Config/TargetTable.test.tsx b/frontend/src/components/Config/TargetTable.test.tsx index 4e6dcff821..59149d5fe4 100644 --- a/frontend/src/components/Config/TargetTable.test.tsx +++ b/frontend/src/components/Config/TargetTable.test.tsx @@ -17,18 +17,21 @@ const sampleTargets: TargetInstance[] = [ target_type: 'OpenAIChatTarget', endpoint: 'https://api.openai.com', model_name: 'gpt-4', + deployment_name: 'gpt-4', }, { target_registry_name: 'azure_image_dalle', target_type: 'AzureImageTarget', endpoint: 'https://azure.openai.com', model_name: 'dall-e-3', + deployment_name: 'dall-e-3', }, { target_registry_name: 'text_target_basic', target_type: 'TextTarget', endpoint: null, model_name: null, + deployment_name: null, }, ] @@ -43,47 +46,105 @@ describe('TargetTable', () => { jest.clearAllMocks() }) - it('should render table with target rows', () => { + it('should render section headings alphabetically by target type', () => { render( ) - expect(screen.getByRole('table')).toBeInTheDocument() - expect(screen.getByText('OpenAIChatTarget')).toBeInTheDocument() - expect(screen.getByText('AzureImageTarget')).toBeInTheDocument() - expect(screen.getByText('TextTarget')).toBeInTheDocument() + const buttons = screen.getAllByRole('button', { expanded: false }) + const headings = buttons.map(b => b.textContent).filter(t => t && !t.includes('Set Active')) + // AzureImageTarget, OpenAIChatTarget, TextTarget (alphabetical) + expect(headings[0]).toContain('AzureImageTarget') + expect(headings[1]).toContain('OpenAIChatTarget') + expect(headings[2]).toContain('TextTarget') }) - it('should display target type, endpoint, and model name columns', () => { + it('should show item count in section headings', () => { + const targets = [ + ...sampleTargets, + { target_registry_name: 'chat2', target_type: 'OpenAIChatTarget', endpoint: 'https://x.com', model_name: 'gpt-5' }, + ] + + render( + + + + ) + + // OpenAIChatTarget has 2 items + expect(screen.getByText('(2)')).toBeInTheDocument() + }) + + it('should start with all sections collapsed when no active target', () => { render( ) - // Header cells - expect(screen.getByText('Type')).toBeInTheDocument() - expect(screen.getByText('Model')).toBeInTheDocument() - expect(screen.getByText('Endpoint')).toBeInTheDocument() + // No tables visible when all collapsed + expect(screen.queryByRole('table')).not.toBeInTheDocument() + expect(screen.queryByText('Set Active')).not.toBeInTheDocument() + }) - // Data cells + it('should expand the section containing the active target', () => { + render( + + + + ) + + // OpenAIChatTarget section should be expanded expect(screen.getByText('gpt-4')).toBeInTheDocument() + expect(screen.getByText('Active')).toBeInTheDocument() + // Other sections should still be collapsed + expect(screen.queryByText('dall-e-3')).not.toBeInTheDocument() + }) + + it('should toggle section expand/collapse on click', () => { + render( + + + + ) + + // Click on AzureImageTarget heading to expand + const sectionButton = screen.getByRole('button', { name: /AzureImageTarget/i }) + fireEvent.click(sectionButton) + + // Table should now be visible with target data expect(screen.getByText('dall-e-3')).toBeInTheDocument() - expect(screen.getByText('https://api.openai.com')).toBeInTheDocument() expect(screen.getByText('https://azure.openai.com')).toBeInTheDocument() + + // Click again to collapse + fireEvent.click(sectionButton) + expect(screen.queryByText('dall-e-3')).not.toBeInTheDocument() + }) + + it('should display endpoint and model name when section is expanded', () => { + render( + + + + ) + + // OpenAIChatTarget section is expanded (has active target) + expect(screen.getByText('gpt-4')).toBeInTheDocument() + expect(screen.getByText('https://api.openai.com')).toBeInTheDocument() }) it('should show "Set Active" button for non-active targets', () => { render( - + ) - const setActiveButtons = screen.getAllByText('Set Active') - expect(setActiveButtons).toHaveLength(3) + // Only OpenAIChatTarget section is expanded, and it has the active target + // So no "Set Active" buttons visible (only 1 target in that group, and it's active) + expect(screen.queryByText('Set Active')).not.toBeInTheDocument() }) it('should show "Active" badge for the active target', () => { @@ -94,9 +155,6 @@ describe('TargetTable', () => { ) expect(screen.getByText('Active')).toBeInTheDocument() - // The other two should still have "Set Active" - const setActiveButtons = screen.getAllByText('Set Active') - expect(setActiveButtons).toHaveLength(2) }) it('should call onSetActiveTarget when "Set Active" is clicked', () => { @@ -108,8 +166,11 @@ describe('TargetTable', () => { ) - const setActiveButtons = screen.getAllByText('Set Active') - fireEvent.click(setActiveButtons[1]) + // Expand AzureImageTarget section + fireEvent.click(screen.getByRole('button', { name: /AzureImageTarget/i })) + + const setActiveButton = screen.getByText('Set Active') + fireEvent.click(setActiveButton) expect(onSetActiveTarget).toHaveBeenCalledTimes(1) expect(onSetActiveTarget).toHaveBeenCalledWith(sampleTargets[1]) @@ -122,7 +183,7 @@ describe('TargetTable', () => { ) - expect(screen.getByRole('table')).toBeInTheDocument() + expect(screen.queryByRole('table')).not.toBeInTheDocument() expect(screen.queryByText('Set Active')).not.toBeInTheDocument() }) @@ -133,19 +194,26 @@ describe('TargetTable', () => { ) + // Expand TextTarget section + fireEvent.click(screen.getByRole('button', { name: /TextTarget/i })) + // TextTarget has null model_name and endpoint; should render "—" const dashes = screen.getAllByText('—') expect(dashes.length).toBeGreaterThanOrEqual(2) }) - it('should display Parameters column header', () => { + it('should display Parameters column header when expanded', () => { render( ) - expect(screen.getByText('Parameters')).toBeInTheDocument() + // Expand a section first + fireEvent.click(screen.getByRole('button', { name: /OpenAIChatTarget/i })) + + const paramHeaders = screen.getAllByText('Parameters') + expect(paramHeaders.length).toBeGreaterThanOrEqual(1) }) it('should display target_specific_params when present', () => { @@ -168,7 +236,102 @@ describe('TargetTable', () => { ) + // Expand the section + fireEvent.click(screen.getByRole('button', { name: /OpenAIResponseTarget/i })) + expect(screen.getByText(/reasoning_effort: high/)).toBeInTheDocument() expect(screen.getByText(/max_output_tokens: 4096/)).toBeInTheDocument() }) + + it('should show tooltip for model with different underlying model', () => { + const targetWithUnderlying: TargetInstance[] = [ + { + target_registry_name: 'azure_deployment', + target_type: 'OpenAIChatTarget', + endpoint: 'https://azure.openai.com', + model_name: 'gpt-4o', + deployment_name: 'my-gpt4o-deployment', + }, + ] + + render( + + + + ) + + // Expand the section + fireEvent.click(screen.getByRole('button', { name: /OpenAIChatTarget/i })) + + // Deployment name should be displayed with dotted underline + const modelText = screen.getByText('my-gpt4o-deployment') + expect(modelText).toHaveStyle({ textDecoration: 'underline dotted' }) + }) + + it('should keep multiple sections expanded independently', () => { + render( + + + + ) + + // Expand two sections + fireEvent.click(screen.getByRole('button', { name: /OpenAIChatTarget/i })) + fireEvent.click(screen.getByRole('button', { name: /AzureImageTarget/i })) + + // Both should be visible + expect(screen.getByText('gpt-4')).toBeInTheDocument() + expect(screen.getByText('dall-e-3')).toBeInTheDocument() + + // Collapse one — the other should remain + fireEvent.click(screen.getByRole('button', { name: /OpenAIChatTarget/i })) + expect(screen.queryByText('gpt-4')).not.toBeInTheDocument() + expect(screen.getByText('dall-e-3')).toBeInTheDocument() + }) + + it('should expand all sections when "Expand All" is clicked', () => { + render( + + + + ) + + // All collapsed initially + expect(screen.queryByRole('table')).not.toBeInTheDocument() + + fireEvent.click(screen.getByRole('button', { name: /expand all/i })) + + // All sections expanded — all data visible + expect(screen.getByText('gpt-4')).toBeInTheDocument() + expect(screen.getByText('dall-e-3')).toBeInTheDocument() + const tables = screen.getAllByRole('table') + expect(tables.length).toBe(3) + }) + + it('should collapse all sections when "Collapse All" is clicked', () => { + render( + + + + ) + + // Expand all first + fireEvent.click(screen.getByRole('button', { name: /expand all/i })) + expect(screen.getAllByRole('table').length).toBe(3) + + // Now collapse all + fireEvent.click(screen.getByRole('button', { name: /collapse all/i })) + expect(screen.queryByRole('table')).not.toBeInTheDocument() + }) + + it('should not show Expand All button when targets list is empty', () => { + render( + + + + ) + + expect(screen.queryByRole('button', { name: /expand all/i })).not.toBeInTheDocument() + expect(screen.queryByRole('button', { name: /collapse all/i })).not.toBeInTheDocument() + }) }) diff --git a/frontend/src/components/Config/TargetTable.tsx b/frontend/src/components/Config/TargetTable.tsx index 389347a62b..41e53d9652 100644 --- a/frontend/src/components/Config/TargetTable.tsx +++ b/frontend/src/components/Config/TargetTable.tsx @@ -1,3 +1,4 @@ +import { useState, useMemo, useEffect } from 'react' import { Table, TableHeader, @@ -8,8 +9,9 @@ import { Badge, Button, Text, + Tooltip, } from '@fluentui/react-components' -import { CheckmarkRegular } from '@fluentui/react-icons' +import { CheckmarkRegular, ChevronDownRegular, ChevronRightRegular } from '@fluentui/react-icons' import type { TargetInstance } from '../../types' import { useTargetTableStyles } from './TargetTable.styles' @@ -34,68 +36,184 @@ function formatParams(params?: Record | null): string { parts.push(`${key}: ${typeof val === 'object' ? JSON.stringify(val) : String(val)}`) } } - return parts.join(', ') + return parts.join('\n') +} + +/** Group targets by target_type, sorted alphabetically by type name. */ +function groupByType(targets: TargetInstance[]): Array<[string, TargetInstance[]]> { + const groups = new Map() + for (const target of targets) { + const list = groups.get(target.target_type) ?? [] + list.push(target) + groups.set(target.target_type, list) + } + return Array.from(groups.entries()).sort(([a], [b]) => a.localeCompare(b)) +} + +/** Render the model cell with a tooltip when underlying model differs. */ +function ModelCell({ target }: { target: TargetInstance }) { + const displayName = target.deployment_name || target.model_name || '—' + const hasUnderlying = target.model_name + && target.deployment_name + && target.model_name !== target.deployment_name + + if (hasUnderlying) { + return ( + + + {displayName} + + + ) + } + + return {displayName} +} + +/** Find which type group contains the active target (if any). */ +function findActiveGroup( + groups: Array<[string, TargetInstance[]]>, + activeTarget: TargetInstance | null, +): string | null { + if (!activeTarget) return null + for (const [typeName, targets] of groups) { + if (targets.some(t => t.target_registry_name === activeTarget.target_registry_name)) { + return typeName + } + } + return null } export default function TargetTable({ targets, activeTarget, onSetActiveTarget }: TargetTableProps) { const styles = useTargetTableStyles() + const grouped = useMemo(() => groupByType(targets), [targets]) + const activeGroup = useMemo(() => findActiveGroup(grouped, activeTarget), [grouped, activeTarget]) + + const [expandedSections, setExpandedSections] = useState>(() => { + // Start with only the active target's section expanded + return activeGroup ? new Set([activeGroup]) : new Set() + }) + + // When active target changes, ensure its section is expanded + useEffect(() => { + if (activeGroup) { + setExpandedSections(prev => { + if (prev.has(activeGroup)) return prev + return new Set([...prev, activeGroup]) + }) + } + }, [activeGroup]) + + const toggleSection = (typeName: string) => { + setExpandedSections(prev => { + const next = new Set(prev) + if (next.has(typeName)) { + next.delete(typeName) + } else { + next.add(typeName) + } + return next + }) + } + + const allTypeNames = useMemo(() => grouped.map(([name]) => name), [grouped]) + const allExpanded = allTypeNames.length > 0 && allTypeNames.every(n => expandedSections.has(n)) + + const toggleAll = () => { + if (allExpanded) { + setExpandedSections(new Set()) + } else { + setExpandedSections(new Set(allTypeNames)) + } + } + const isActive = (target: TargetInstance): boolean => activeTarget?.target_registry_name === target.target_registry_name return (
- - - - - Type - Model - Endpoint - Parameters - - - - {targets.map((target) => ( - 0 && ( +
+ +
+ )} + {grouped.map(([typeName, groupTargets]) => { + const isExpanded = expandedSections.has(typeName) + return ( +
+ - )} - - - {target.target_type} - - - {target.model_name || '—'} - - - - {target.endpoint || '—'} - - - - - {formatParams(target.target_specific_params) || '—'} - - - - ))} - -
+ + {typeName} + + + ({groupTargets.length}) + + + {isExpanded && ( + + + + + Model + Endpoint + Parameters + + + + {groupTargets.map((target) => ( + + + {isActive(target) ? ( + }> + Active + + ) : ( + + )} + + + + + + + {target.endpoint || '—'} + + + + + {formatParams(target.target_specific_params) || '—'} + + + + ))} + +
+ )} +
+ ) + })} ) } diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index c66d327cb8..09dc2abb42 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -58,6 +58,7 @@ export interface TargetInstance { target_type: string endpoint?: string | null model_name?: string | null + deployment_name?: string | null temperature?: number | null top_p?: number | null max_requests_per_minute?: number | null diff --git a/pyrit/backend/mappers/target_mappers.py b/pyrit/backend/mappers/target_mappers.py index 84fcf1c68c..4c14f7e582 100644 --- a/pyrit/backend/mappers/target_mappers.py +++ b/pyrit/backend/mappers/target_mappers.py @@ -30,6 +30,7 @@ def target_object_to_instance(target_registry_name: str, target_obj: PromptTarge extracted_keys = { "endpoint", "model_name", + "deployment_name", "temperature", "top_p", "max_requests_per_minute", @@ -47,6 +48,7 @@ def target_object_to_instance(target_registry_name: str, target_obj: PromptTarge target_type=identifier.class_name, endpoint=params.get("endpoint") or None, model_name=params.get("model_name") or None, + deployment_name=params.get("deployment_name") or None, temperature=params.get("temperature"), top_p=params.get("top_p"), max_requests_per_minute=params.get("max_requests_per_minute"), diff --git a/pyrit/backend/models/targets.py b/pyrit/backend/models/targets.py index 43c4bb8190..543c6639a9 100644 --- a/pyrit/backend/models/targets.py +++ b/pyrit/backend/models/targets.py @@ -29,7 +29,8 @@ class TargetInstance(BaseModel): target_registry_name: str = Field(..., description="Target registry key (e.g., 'azure_openai_chat')") target_type: str = Field(..., description="Target class name (e.g., 'OpenAIChatTarget')") endpoint: Optional[str] = Field(None, description="Target endpoint URL") - model_name: Optional[str] = Field(None, description="Model or deployment name") + model_name: Optional[str] = Field(None, description="Underlying model name (e.g., 'gpt-4o')") + deployment_name: Optional[str] = Field(None, description="Deployment or model name used in API calls") temperature: Optional[float] = Field(None, description="Temperature parameter for generation") top_p: Optional[float] = Field(None, description="Top-p parameter for generation") max_requests_per_minute: Optional[int] = Field(None, description="Maximum requests per minute") diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index ecaf5e1c39..6dbdad7077 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -179,6 +179,7 @@ def _create_identifier( all_params: dict[str, Any] = { "endpoint": self._endpoint, "model_name": model_name, + "deployment_name": self._model_name or "", "max_requests_per_minute": self._max_requests_per_minute, "supports_multi_turn": self.capabilities.supports_multi_turn, } diff --git a/pyrit/prompt_target/openai/openai_target.py b/pyrit/prompt_target/openai/openai_target.py index 2c9a48fa21..e338362052 100644 --- a/pyrit/prompt_target/openai/openai_target.py +++ b/pyrit/prompt_target/openai/openai_target.py @@ -124,10 +124,19 @@ def __init__( env_var_name=self.endpoint_environment_variable, passed_value=endpoint ) - # Get underlying_model from passed value or environment variable - underlying_model_value = default_values.get_non_required_value( - env_var_name=self.underlying_model_environment_variable, passed_value=underlying_model - ) + # Get underlying_model from passed value or environment variable. + # Only fall back to the env var when model_name was also resolved from env vars + # (i.e., no explicit model_name was passed). This prevents the generic + # OPENAI_CHAT_UNDERLYING_MODEL env var from overriding an explicitly provided + # model_name for unrelated targets (e.g., deepseek, gemini, mistral). + if underlying_model is not None: + underlying_model_value = underlying_model + elif model_name is None: + underlying_model_value = default_values.get_non_required_value( + env_var_name=self.underlying_model_environment_variable, passed_value=None + ) + else: + underlying_model_value = None # Initialize parent with endpoint and model_name PromptTarget.__init__( diff --git a/tests/unit/backend/test_target_service.py b/tests/unit/backend/test_target_service.py index ee12045129..13a6fc32c4 100644 --- a/tests/unit/backend/test_target_service.py +++ b/tests/unit/backend/test_target_service.py @@ -5,7 +5,8 @@ Tests for backend target service. """ -from unittest.mock import MagicMock +import os +from unittest.mock import MagicMock, patch import pytest @@ -277,6 +278,27 @@ async def test_create_target_registers_in_registry(self, sqlite_instance) -> Non target_obj = service.get_target_object(target_registry_name=result.target_registry_name) assert target_obj is not None + @pytest.mark.asyncio + async def test_create_target_model_name_not_overridden_by_env_var(self, sqlite_instance) -> None: + """Test that explicit model_name is not overridden by underlying_model env var.""" + with patch.dict(os.environ, {"OPENAI_CHAT_UNDERLYING_MODEL": "gpt-4o"}): + service = TargetService() + + request = CreateTargetRequest( + type="OpenAIChatTarget", + params={ + "model_name": "claude-sonnet-4-6", + "endpoint": "https://test.openai.azure.com/", + "api_key": "test-key", + }, + ) + + result = await service.create_target_async(request=request) + + assert result.deployment_name == "claude-sonnet-4-6" + # model_name in identifier should also be claude-sonnet-4-6, not gpt-4o + assert result.model_name == "claude-sonnet-4-6" + class TestTargetServiceSingleton: """Tests for get_target_service singleton function.""" diff --git a/tests/unit/target/test_openai_chat_target.py b/tests/unit/target/test_openai_chat_target.py index 56f61b8733..98e55f2c78 100644 --- a/tests/unit/target/test_openai_chat_target.py +++ b/tests/unit/target/test_openai_chat_target.py @@ -1115,7 +1115,22 @@ def test_get_identifier_uses_underlying_model_when_provided_as_param(patch_centr def test_get_identifier_uses_underlying_model_from_env_var(patch_central_database): - """Test that get_identifier uses underlying_model from environment variable.""" + """Test that get_identifier uses underlying_model env var only when model_name is NOT explicitly passed.""" + with patch.dict(os.environ, {"OPENAI_CHAT_UNDERLYING_MODEL": "gpt-4o", "OPENAI_CHAT_MODEL": "my-deployment"}): + # When model_name is resolved from env var (not passed explicitly), + # underlying_model env var should be used + target = OpenAIChatTarget( + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + + identifier = target.get_identifier() + + assert identifier.params["model_name"] == "gpt-4o" + + +def test_get_identifier_ignores_underlying_model_env_var_when_model_name_explicit(patch_central_database): + """Test that underlying_model env var is NOT used when model_name is explicitly passed.""" with patch.dict(os.environ, {"OPENAI_CHAT_UNDERLYING_MODEL": "gpt-4o"}): target = OpenAIChatTarget( model_name="my-deployment", @@ -1125,7 +1140,9 @@ def test_get_identifier_uses_underlying_model_from_env_var(patch_central_databas identifier = target.get_identifier() - assert identifier.params["model_name"] == "gpt-4o" + # model_name was explicit, so underlying_model env var should be ignored + assert identifier.params["model_name"] == "my-deployment" + assert identifier.params["deployment_name"] == "my-deployment" def test_underlying_model_param_takes_precedence_over_env_var(patch_central_database): @@ -1141,6 +1158,7 @@ def test_underlying_model_param_takes_precedence_over_env_var(patch_central_data identifier = target.get_identifier() assert identifier.params["model_name"] == "gpt-4o-from-param" + assert identifier.params["deployment_name"] == "my-deployment" def test_get_identifier_includes_endpoint(patch_central_database): diff --git a/tests/unit/target/test_openai_response_target.py b/tests/unit/target/test_openai_response_target.py index 507f8f0935..e7cadde9ac 100644 --- a/tests/unit/target/test_openai_response_target.py +++ b/tests/unit/target/test_openai_response_target.py @@ -1287,3 +1287,19 @@ def test_build_identifier_includes_reasoning_params(patch_central_database): identifier = target._build_identifier() assert identifier.params["reasoning_effort"] == "low" assert identifier.params["reasoning_summary"] == "concise" + + +def test_get_identifier_ignores_underlying_model_env_var_when_model_name_explicit(patch_central_database): + """Test that underlying_model env var is NOT used when model_name is explicitly passed.""" + with patch.dict(os.environ, {"OPENAI_RESPONSES_UNDERLYING_MODEL": "gpt-4o"}): + target = OpenAIResponseTarget( + model_name="gpt-4.1", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + + identifier = target.get_identifier() + + # model_name was explicit, so underlying_model env var should be ignored + assert identifier.params["model_name"] == "gpt-4.1" + assert identifier.params["deployment_name"] == "gpt-4.1"