diff --git a/app/activity.py b/app/activity.py index 4cb3dce..324915f 100644 --- a/app/activity.py +++ b/app/activity.py @@ -2,7 +2,7 @@ from typing import Any -from fastapi import FastAPI, Query, Request +from fastapi import FastAPI, HTTPException, Query, Request from fastapi.responses import HTMLResponse from fastapi.templating import Jinja2Templates from sqlalchemy.orm import Session @@ -11,8 +11,16 @@ from app.serializers import activity_to_dict +def _normalized_activity_search_query(query: str | None) -> str | None: + if query is None: + return None + if any(ord(char) < 32 or ord(char) == 127 for char in query): + raise HTTPException(status_code=400, detail="q must not contain control characters") + return query.strip() + + def activity_context(session: Session, query: str | None = None) -> dict[str, Any]: - return activity_to_dict(session, query) + return activity_to_dict(session, _normalized_activity_search_query(query)) def register_activity_routes(app: FastAPI, *, db_url: str, templates: Jinja2Templates) -> None: diff --git a/app/bounty_api.py b/app/bounty_api.py index 04f2365..3db1643 100644 --- a/app/bounty_api.py +++ b/app/bounty_api.py @@ -81,6 +81,14 @@ def _ledger_http_error(exc: LedgerError) -> HTTPException: return HTTPException(status_code=400, detail=detail) +def _normalized_bounty_search_query(query_text: str | None) -> str | None: + if query_text is None: + return None + if CONTROL_CHAR_RE.search(query_text): + raise HTTPException(status_code=400, detail="q must not contain control characters") + return query_text.strip() + + def register_bounty_api_routes( app: FastAPI, *, @@ -114,25 +122,24 @@ def _list_bounties_by_status( status_code=400, detail="status must be one of: open, paid, closed" ) query = query.where(Bounty.status == normalized_status) - if query_text is not None: - normalized_query = query_text.strip() - if normalized_query: - escaped_query = ( - normalized_query.lower() - .replace("\\", "\\\\") - .replace("%", "\\%") - .replace("_", "\\_") - ) - like_query = f"%{escaped_query}%" - issue_number = issue_number_search_value(normalized_query) - text_filter = or_( - func.lower(Bounty.repo).like(like_query, escape="\\"), - func.lower(Bounty.title).like(like_query, escape="\\"), - func.lower(Bounty.acceptance).like(like_query, escape="\\"), - ) - if issue_number is not None: - text_filter = or_(text_filter, Bounty.issue_number == issue_number) - query = query.where(text_filter) + normalized_query = _normalized_bounty_search_query(query_text) + if normalized_query: + escaped_query = ( + normalized_query.lower() + .replace("\\", "\\\\") + .replace("%", "\\%") + .replace("_", "\\_") + ) + like_query = f"%{escaped_query}%" + issue_number = issue_number_search_value(normalized_query) + text_filter = or_( + func.lower(Bounty.repo).like(like_query, escape="\\"), + func.lower(Bounty.title).like(like_query, escape="\\"), + func.lower(Bounty.acceptance).like(like_query, escape="\\"), + ) + if issue_number is not None: + text_filter = or_(text_filter, Bounty.issue_number == issue_number) + query = query.where(text_filter) bounties = session.scalars(query.order_by(Bounty.id.desc())).all() sorted_bounties = sort_bounties( [bounty_to_dict(bounty) for bounty in bounties], normalized_sort diff --git a/app/mcp_tools.py b/app/mcp_tools.py index 0a0bd37..fc03aa6 100644 --- a/app/mcp_tools.py +++ b/app/mcp_tools.py @@ -27,6 +27,10 @@ def call_mcp_tool(database_url: str, name: str, args: dict[str, Any]) -> str | dict[str, Any]: + def reject_control_chars(field: str, value: str) -> None: + if any(ord(char) < 32 or ord(char) == 127 for char in value): + raise ValueError(f"{field} must not contain control characters") + def int_arg(field: str) -> int: value = args[field] if isinstance(value, bool): @@ -79,6 +83,16 @@ def optional_clean_str_arg(field: str) -> str | None: clean = value.strip() return clean or None + def optional_search_query_arg(field: str) -> str | None: + value = args.get(field) + if value is None: + return None + if not isinstance(value, str): + raise ValueError(f"{field} must be a string") + reject_control_chars(field, value) + clean = value.strip() + return clean or None + def output_format_arg() -> str: value = args.get("format", "text") if value is None: @@ -121,7 +135,7 @@ def optional_bool_arg(field: str, default: bool = False) -> bool: if normalized_status not in {"open", "paid", "closed"}: raise ValueError("status must be one of: open, paid, closed") query = select(Bounty).where(Bounty.status == normalized_status) - query_text = optional_clean_str_arg("q") + query_text = optional_search_query_arg("q") if query_text: escaped_query = ( query_text.lower().replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") diff --git a/tests/test_activity.py b/tests/test_activity.py index e383524..68bc12d 100644 --- a/tests/test_activity.py +++ b/tests/test_activity.py @@ -192,6 +192,28 @@ def test_activity_api_filters_accepted_work_by_query(sqlite_url: str) -> None: assert invalid_hash_query["recent"] == [] +def test_activity_rejects_control_character_search_query(sqlite_url: str) -> None: + create_schema(sqlite_url) + with session_scope(sqlite_url) as session: + ensure_genesis(session) + + client = TestClient(create_app(database_url=sqlite_url, webhook_secret="secret")) + + api_response = client.get("/api/v1/activity", params={"q": "\x00"}) + page_response = client.get("/activity", params={"q": "\x00"}) + leading_tab_response = client.get("/api/v1/activity", params={"q": "\talice"}) + trailing_newline_response = client.get("/activity", params={"q": "alice\n"}) + + assert api_response.status_code == 400 + assert api_response.json() == {"detail": "q must not contain control characters"} + assert page_response.status_code == 400 + assert page_response.json() == {"detail": "q must not contain control characters"} + assert leading_tab_response.status_code == 400 + assert leading_tab_response.json() == {"detail": "q must not contain control characters"} + assert trailing_newline_response.status_code == 400 + assert trailing_newline_response.json() == {"detail": "q must not contain control characters"} + + def test_activity_page_renders_empty_and_paid_states(sqlite_url: str) -> None: create_schema(sqlite_url) client = TestClient(create_app(database_url=sqlite_url, webhook_secret="secret")) diff --git a/tests/test_api_mcp.py b/tests/test_api_mcp.py index 8343864..6d21dde 100644 --- a/tests/test_api_mcp.py +++ b/tests/test_api_mcp.py @@ -540,6 +540,9 @@ def test_mcp_list_bounties_honors_sort_argument(sqlite_url: str) -> None: ({"limit": 0}, 34), ({"limit": 101}, 35), ({"sort": "invalid"}, 36), + ({"q": "\x00"}, 37), + ({"q": "\tDocs"}, 38), + ({"q": "Docs\n"}, 39), ], ) def test_mcp_list_bounties_rejects_invalid_filters( diff --git a/tests/test_bounty_api_routes.py b/tests/test_bounty_api_routes.py index 71a7441..47438f2 100644 --- a/tests/test_bounty_api_routes.py +++ b/tests/test_bounty_api_routes.py @@ -313,3 +313,40 @@ def test_bounty_api_limit_rejects_out_of_range_values(sqlite_url: str) -> None: assert client.get("/api/v1/bounties?limit=201").status_code == 422 assert client.get("/api/v1/bounties/summary?limit=0").status_code == 422 assert client.get("/api/v1/bounties/summary?limit=201").status_code == 422 + + +def test_bounty_api_rejects_control_character_search_queries(sqlite_url: str) -> None: + create_schema(sqlite_url) + with session_scope(sqlite_url) as session: + ensure_genesis(session) + create_bounty( + session, + repo="ramimbo/mergework", + issue_number=53, + issue_url="https://github.com/ramimbo/mergework/issues/53", + title="Control character bounty search", + reward_mrwk="5", + acceptance="Control characters should not widen bounty search results.", + ) + + client = TestClient(create_app(database_url=sqlite_url, webhook_secret="secret")) + + list_response = client.get("/api/v1/bounties?q=%00") + summary_response = client.get("/api/v1/bounties/summary?q=%00") + del_list_response = client.get("/api/v1/bounties", params={"q": "\x7f"}) + del_summary_response = client.get("/api/v1/bounties/summary", params={"q": "\x7f"}) + leading_tab_response = client.get("/api/v1/bounties", params={"q": "\tControl"}) + trailing_newline_response = client.get("/api/v1/bounties/summary", params={"q": "Control\n"}) + + assert list_response.status_code == 400 + assert list_response.json()["detail"] == "q must not contain control characters" + assert summary_response.status_code == 400 + assert summary_response.json()["detail"] == "q must not contain control characters" + assert del_list_response.status_code == 400 + assert del_list_response.json()["detail"] == "q must not contain control characters" + assert del_summary_response.status_code == 400 + assert del_summary_response.json()["detail"] == "q must not contain control characters" + assert leading_tab_response.status_code == 400 + assert leading_tab_response.json()["detail"] == "q must not contain control characters" + assert trailing_newline_response.status_code == 400 + assert trailing_newline_response.json()["detail"] == "q must not contain control characters" diff --git a/tests/test_bounty_pages.py b/tests/test_bounty_pages.py index 55b9f8b..eee8312 100644 --- a/tests/test_bounty_pages.py +++ b/tests/test_bounty_pages.py @@ -282,6 +282,14 @@ def test_bounties_page_and_api_search_by_text_and_issue_number(sqlite_url: str) assert backslash_search.status_code == 200 assert [row["issue_number"] for row in backslash_search.json()] == [66] + control_char_page_search = client.get("/bounties", params={"q": "\x00"}) + assert control_char_page_search.status_code == 400 + assert control_char_page_search.json()["detail"] == "q must not contain control characters" + + leading_tab_page_search = client.get("/bounties", params={"q": "\tDocs"}) + assert leading_tab_page_search.status_code == 400 + assert leading_tab_page_search.json()["detail"] == "q must not contain control characters" + def test_bounties_page_and_api_sort_public_rows(sqlite_url: str) -> None: create_schema(sqlite_url)