diff --git a/app/api/v2/handlers/payload_api.py b/app/api/v2/handlers/payload_api.py index 1b034a332..6c6350348 100644 --- a/app/api/v2/handlers/payload_api.py +++ b/app/api/v2/handlers/payload_api.py @@ -4,6 +4,7 @@ import pathlib import re from io import IOBase +from typing import Optional import aiohttp_apispec from aiohttp import web @@ -27,7 +28,8 @@ def add_routes(self, app: web.Application): @aiohttp_apispec.docs(tags=['payloads'], summary='Retrieve payloads', - description='Retrieves all stored payloads.') + description='Retrieves all stored payloads. Supports optional filtering by name ' + '(case-insensitive substring match via the `name` query parameter).') @aiohttp_apispec.querystring_schema(PayloadQuerySchema) @aiohttp_apispec.response_schema(PayloadSchema(), description='Returns a list of all payloads in PayloadSchema format.') @@ -35,6 +37,7 @@ async def get_payloads(self, request: web.Request): sort: bool = request['querystring'].get('sort') exclude_plugins: bool = request['querystring'].get('exclude_plugins') add_path: bool = request['querystring'].get('add_path') + name_filter: Optional[str] = request['querystring'].get('name') cwd = pathlib.Path.cwd() payload_dirs = [cwd / 'data' / 'payloads'] @@ -52,6 +55,11 @@ async def get_payloads(self, request: web.Request): } payloads = list(payloads) + + if name_filter: + name_filter_lower = name_filter.lower() + payloads = [p for p in payloads if name_filter_lower in pathlib.PurePath(p).name.lower()] + if sort: payloads.sort() diff --git a/app/api/v2/schemas/payload_schemas.py b/app/api/v2/schemas/payload_schemas.py index 22d8701f4..a03af81a6 100644 --- a/app/api/v2/schemas/payload_schemas.py +++ b/app/api/v2/schemas/payload_schemas.py @@ -5,6 +5,7 @@ class PayloadQuerySchema(schema.Schema): sort = fields.Boolean(required=False, load_default=False) exclude_plugins = fields.Boolean(required=False, load_default=False) add_path = fields.Boolean(required=False, load_default=False) + name = fields.String(required=False, load_default=None, allow_none=True) class PayloadSchema(schema.Schema): diff --git a/tests/api/v2/handlers/test_payloads_api.py b/tests/api/v2/handlers/test_payloads_api.py index 8570f35ba..4ffa34eb7 100644 --- a/tests/api/v2/handlers/test_payloads_api.py +++ b/tests/api/v2/handlers/test_payloads_api.py @@ -1,4 +1,5 @@ import os +import pathlib import tempfile from http import HTTPStatus @@ -49,6 +50,35 @@ async def test_get_payloads(self, api_v2_client, api_cookies, expected_payload_f assert filtered_payload_file_names == expected_payload_file_names + @pytest.mark.parametrize('query_name', ['payload_', 'PAYLOAD_']) + async def test_get_payloads_name_filter(self, api_v2_client, api_cookies, expected_payload_file_names, query_name): + resp = await api_v2_client.get(f'/api/v2/payloads?name={query_name}', cookies=api_cookies) + assert resp.status == HTTPStatus.OK + payload_file_names = await resp.json() + + # All expected payloads should be present + assert expected_payload_file_names <= set(payload_file_names) + # Every returned payload must match the filter (no false positives) + assert all('payload_' in pathlib.PurePath(p).name.lower() for p in payload_file_names) + + async def test_get_payloads_name_filter_no_match(self, api_v2_client, api_cookies): + resp = await api_v2_client.get('/api/v2/payloads?name=__no_match_xyzzy__', cookies=api_cookies) + assert resp.status == HTTPStatus.OK + assert await resp.json() == [] + + async def test_get_payloads_name_filter_with_sort_and_add_path( + self, api_v2_client, api_cookies, expected_payload_file_names): + resp = await api_v2_client.get('/api/v2/payloads?name=payload_&sort=true&add_path=true', cookies=api_cookies) + assert resp.status == HTTPStatus.OK + payload_paths = await resp.json() + + # Results should be sorted + assert payload_paths == sorted(payload_paths) + # Every returned path's filename must match the filter + assert all('payload_' in pathlib.PurePath(p).name.lower() for p in payload_paths) + # Results should contain paths (not bare filenames) + assert all(os.sep in p or '/' in p for p in payload_paths) + async def test_unauthorized_get_payloads(self, api_v2_client): resp = await api_v2_client.get('/api/v2/payloads') assert resp.status == HTTPStatus.UNAUTHORIZED