Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion app/api/v2/handlers/payload_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pathlib
import re
from io import IOBase
from typing import Optional

import aiohttp_apispec
from aiohttp import web
Expand All @@ -27,14 +28,16 @@ def add_routes(self, app: web.Application):

@aiohttp_apispec.docs(tags=['payloads'],
summary='Retrieve payloads',
description='Retrieves all stored payloads.')
description='Retrieves all stored payloads. Supports optional filtering by name '
'(case-insensitive substring match via the `name` query parameter).')
@aiohttp_apispec.querystring_schema(PayloadQuerySchema)
@aiohttp_apispec.response_schema(PayloadSchema(),
description='Returns a list of all payloads in PayloadSchema format.')
Comment on lines +31 to 35
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The OpenAPI/apispec response schema for this endpoint is declared as PayloadSchema (an object with a payloads list), but the handler returns web.json_response(payloads) where payloads is a raw JSON array of strings. This makes the generated API spec inaccurate for clients. Update the response schema to match the actual array response (or wrap the response in the schema shape).

Copilot uses AI. Check for mistakes.
async def get_payloads(self, request: web.Request):
sort: bool = request['querystring'].get('sort')
exclude_plugins: bool = request['querystring'].get('exclude_plugins')
add_path: bool = request['querystring'].get('add_path')
name_filter: Optional[str] = request['querystring'].get('name')

cwd = pathlib.Path.cwd()
payload_dirs = [cwd / 'data' / 'payloads']
Expand All @@ -52,6 +55,11 @@ async def get_payloads(self, request: web.Request):
}

payloads = list(payloads)

if name_filter:
name_filter_lower = name_filter.lower()
payloads = [p for p in payloads if name_filter_lower in pathlib.PurePath(p).name.lower()]

Comment on lines +59 to +62
if sort:
payloads.sort()

Expand Down
1 change: 1 addition & 0 deletions app/api/v2/schemas/payload_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ class PayloadQuerySchema(schema.Schema):
sort = fields.Boolean(required=False, load_default=False)
exclude_plugins = fields.Boolean(required=False, load_default=False)
add_path = fields.Boolean(required=False, load_default=False)
name = fields.String(required=False, load_default=None, allow_none=True)


class PayloadSchema(schema.Schema):
Expand Down
30 changes: 30 additions & 0 deletions tests/api/v2/handlers/test_payloads_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import pathlib
import tempfile
from http import HTTPStatus

Expand Down Expand Up @@ -49,6 +50,35 @@ async def test_get_payloads(self, api_v2_client, api_cookies, expected_payload_f

assert filtered_payload_file_names == expected_payload_file_names

@pytest.mark.parametrize('query_name', ['payload_', 'PAYLOAD_'])
async def test_get_payloads_name_filter(self, api_v2_client, api_cookies, expected_payload_file_names, query_name):
resp = await api_v2_client.get(f'/api/v2/payloads?name={query_name}', cookies=api_cookies)
assert resp.status == HTTPStatus.OK
payload_file_names = await resp.json()

# All expected payloads should be present
assert expected_payload_file_names <= set(payload_file_names)
# Every returned payload must match the filter (no false positives)
assert all('payload_' in pathlib.PurePath(p).name.lower() for p in payload_file_names)
Comment on lines +61 to +62
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test hard-codes 'payload_' in the match assertion instead of using query_name (or query_name.lower()), which makes the parametrization less meaningful and can hide issues if the query value is changed/extended later. Use the parametrized value in the assertion so the test validates the filter criteria being requested.

Suggested change
# Every returned payload must match the filter (no false positives)
assert all('payload_' in pathlib.PurePath(p).name.lower() for p in payload_file_names)
# Every returned payload must match the requested filter (no false positives)
assert all(query_name.lower() in pathlib.PurePath(p).name.lower() for p in payload_file_names)

Copilot uses AI. Check for mistakes.

async def test_get_payloads_name_filter_no_match(self, api_v2_client, api_cookies):
resp = await api_v2_client.get('/api/v2/payloads?name=__no_match_xyzzy__', cookies=api_cookies)
assert resp.status == HTTPStatus.OK
assert await resp.json() == []

async def test_get_payloads_name_filter_with_sort_and_add_path(
self, api_v2_client, api_cookies, expected_payload_file_names):
resp = await api_v2_client.get('/api/v2/payloads?name=payload_&sort=true&add_path=true', cookies=api_cookies)
assert resp.status == HTTPStatus.OK
payload_paths = await resp.json()

# Results should be sorted
assert payload_paths == sorted(payload_paths)
# Every returned path's filename must match the filter
assert all('payload_' in pathlib.PurePath(p).name.lower() for p in payload_paths)
# Results should contain paths (not bare filenames)
Comment on lines +76 to +79
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the previous test, this assertion hard-codes 'payload_' instead of checking against the actual query value used for the request. Using the request's filter value in the assertion would keep the test aligned with what it's exercising.

Copilot uses AI. Check for mistakes.
assert all(os.sep in p or '/' in p for p in payload_paths)

async def test_unauthorized_get_payloads(self, api_v2_client):
resp = await api_v2_client.get('/api/v2/payloads')
assert resp.status == HTTPStatus.UNAUTHORIZED
Loading