Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions app/activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any

from fastapi import FastAPI, Query, Request
from fastapi import FastAPI, HTTPException, Query, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session
Expand All @@ -11,8 +11,16 @@
from app.serializers import activity_to_dict


def _normalized_activity_search_query(query: str | None) -> str | None:
if query is None:
return None
if any(ord(char) < 32 or ord(char) == 127 for char in query):
raise HTTPException(status_code=400, detail="q must not contain control characters")
return query.strip()


def activity_context(session: Session, query: str | None = None) -> dict[str, Any]:
    """Build the activity payload for an optional search query.

    The query is passed through ``_normalized_activity_search_query`` before
    ``activity_to_dict`` is called, so a query containing control characters
    is rejected instead of being forwarded to the serializer.

    Args:
        session: Database session handed to ``activity_to_dict``.
        query: Raw search text, or ``None`` for no search.

    Returns:
        Whatever ``activity_to_dict`` returns for the normalized query.
    """
    # Only the normalizing call is kept; an earlier un-normalized return in
    # front of it would make the validation unreachable.
    return activity_to_dict(session, _normalized_activity_search_query(query))


def register_activity_routes(app: FastAPI, *, db_url: str, templates: Jinja2Templates) -> None:
Expand Down
45 changes: 26 additions & 19 deletions app/bounty_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ def _ledger_http_error(exc: LedgerError) -> HTTPException:
return HTTPException(status_code=400, detail=detail)


def _normalized_bounty_search_query(query_text: str | None) -> str | None:
if query_text is None:
return None
if CONTROL_CHAR_RE.search(query_text):
raise HTTPException(status_code=400, detail="q must not contain control characters")
return query_text.strip()


def register_bounty_api_routes(
app: FastAPI,
*,
Expand Down Expand Up @@ -114,25 +122,24 @@ def _list_bounties_by_status(
status_code=400, detail="status must be one of: open, paid, closed"
)
query = query.where(Bounty.status == normalized_status)
if query_text is not None:
normalized_query = query_text.strip()
if normalized_query:
escaped_query = (
normalized_query.lower()
.replace("\\", "\\\\")
.replace("%", "\\%")
.replace("_", "\\_")
)
like_query = f"%{escaped_query}%"
issue_number = issue_number_search_value(normalized_query)
text_filter = or_(
func.lower(Bounty.repo).like(like_query, escape="\\"),
func.lower(Bounty.title).like(like_query, escape="\\"),
func.lower(Bounty.acceptance).like(like_query, escape="\\"),
)
if issue_number is not None:
text_filter = or_(text_filter, Bounty.issue_number == issue_number)
query = query.where(text_filter)
normalized_query = _normalized_bounty_search_query(query_text)
if normalized_query:
escaped_query = (
normalized_query.lower()
.replace("\\", "\\\\")
.replace("%", "\\%")
.replace("_", "\\_")
)
like_query = f"%{escaped_query}%"
issue_number = issue_number_search_value(normalized_query)
text_filter = or_(
func.lower(Bounty.repo).like(like_query, escape="\\"),
func.lower(Bounty.title).like(like_query, escape="\\"),
func.lower(Bounty.acceptance).like(like_query, escape="\\"),
)
if issue_number is not None:
text_filter = or_(text_filter, Bounty.issue_number == issue_number)
query = query.where(text_filter)
bounties = session.scalars(query.order_by(Bounty.id.desc())).all()
sorted_bounties = sort_bounties(
[bounty_to_dict(bounty) for bounty in bounties], normalized_sort
Expand Down
16 changes: 15 additions & 1 deletion app/mcp_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@


def call_mcp_tool(database_url: str, name: str, args: dict[str, Any]) -> str | dict[str, Any]:
def reject_control_chars(field: str, value: str) -> None:
    """Raise if ``value`` contains an ASCII control character.

    Args:
        field: Argument name, used only to build the error message.
        value: Text to inspect.

    Raises:
        ValueError: When any character of ``value`` has a code point below 32
            or equal to 127 (DEL).
    """
    if any(ord(char) < 32 or ord(char) == 127 for char in value):
        raise ValueError(f"{field} must not contain control characters")

def int_arg(field: str) -> int:
value = args[field]
if isinstance(value, bool):
Expand Down Expand Up @@ -79,6 +83,16 @@ def optional_clean_str_arg(field: str) -> str | None:
clean = value.strip()
return clean or None

def optional_search_query_arg(field: str) -> str | None:
    """Read an optional search string from ``args`` and validate it.

    ``args`` is the tool-argument mapping of the enclosing ``call_mcp_tool``.

    Args:
        field: Key to look up in ``args``.

    Returns:
        ``None`` when the key is absent, its value is ``None``, or the value
        is empty after trimming; otherwise the trimmed string.

    Raises:
        ValueError: When the value is not a string, or when it contains a
            control character (checked by ``reject_control_chars``).
    """
    value = args.get(field)
    if value is None:
        return None
    if not isinstance(value, str):
        raise ValueError(f"{field} must be a string")
    # Check the raw value first: strip() would otherwise hide a leading tab or
    # trailing newline from the control-character check.
    reject_control_chars(field, value)
    clean = value.strip()
    return clean or None

def output_format_arg() -> str:
value = args.get("format", "text")
if value is None:
Expand Down Expand Up @@ -121,7 +135,7 @@ def optional_bool_arg(field: str, default: bool = False) -> bool:
if normalized_status not in {"open", "paid", "closed"}:
raise ValueError("status must be one of: open, paid, closed")
query = select(Bounty).where(Bounty.status == normalized_status)
query_text = optional_clean_str_arg("q")
query_text = optional_search_query_arg("q")
if query_text:
escaped_query = (
query_text.lower().replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
Expand Down
22 changes: 22 additions & 0 deletions tests/test_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,28 @@ def test_activity_api_filters_accepted_work_by_query(sqlite_url: str) -> None:
assert invalid_hash_query["recent"] == []


def test_activity_rejects_control_character_search_query(sqlite_url: str) -> None:
    """Activity API and page answer 400 when ``q`` holds a control character."""
    create_schema(sqlite_url)
    with session_scope(sqlite_url) as session:
        ensure_genesis(session)

    client = TestClient(create_app(database_url=sqlite_url, webhook_secret="secret"))

    # Leading-tab and trailing-newline cases are included because a plain
    # strip() would remove them before validation could see them.
    rejected_requests = [
        ("/api/v1/activity", "\x00"),
        ("/activity", "\x00"),
        ("/api/v1/activity", "\talice"),
        ("/activity", "alice\n"),
    ]
    for path, raw_query in rejected_requests:
        response = client.get(path, params={"q": raw_query})
        assert response.status_code == 400, (path, raw_query)
        assert response.json() == {"detail": "q must not contain control characters"}, (
            path,
            raw_query,
        )


def test_activity_page_renders_empty_and_paid_states(sqlite_url: str) -> None:
create_schema(sqlite_url)
client = TestClient(create_app(database_url=sqlite_url, webhook_secret="secret"))
Expand Down
3 changes: 3 additions & 0 deletions tests/test_api_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,9 @@ def test_mcp_list_bounties_honors_sort_argument(sqlite_url: str) -> None:
({"limit": 0}, 34),
({"limit": 101}, 35),
({"sort": "invalid"}, 36),
({"q": "\x00"}, 37),
({"q": "\tDocs"}, 38),
({"q": "Docs\n"}, 39),
],
)
def test_mcp_list_bounties_rejects_invalid_filters(
Expand Down
37 changes: 37 additions & 0 deletions tests/test_bounty_api_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,40 @@ def test_bounty_api_limit_rejects_out_of_range_values(sqlite_url: str) -> None:
assert client.get("/api/v1/bounties?limit=201").status_code == 422
assert client.get("/api/v1/bounties/summary?limit=0").status_code == 422
assert client.get("/api/v1/bounties/summary?limit=201").status_code == 422


def test_bounty_api_rejects_control_character_search_queries(sqlite_url: str) -> None:
    """Bounty list and summary endpoints answer 400 for control characters in ``q``."""
    create_schema(sqlite_url)
    with session_scope(sqlite_url) as session:
        ensure_genesis(session)
        # Seed one bounty so the search endpoints have a row they could match.
        create_bounty(
            session,
            repo="ramimbo/mergework",
            issue_number=53,
            issue_url="https://github.com/ramimbo/mergework/issues/53",
            title="Control character bounty search",
            reward_mrwk="5",
            acceptance="Control characters should not widen bounty search results.",
        )

    client = TestClient(create_app(database_url=sqlite_url, webhook_secret="secret"))

    # The NUL cases are sent pre-encoded in the URL; the remaining cases go
    # through ``params`` so the client performs the encoding.
    rejected_requests = [
        ("/api/v1/bounties?q=%00", None),
        ("/api/v1/bounties/summary?q=%00", None),
        ("/api/v1/bounties", {"q": "\x7f"}),
        ("/api/v1/bounties/summary", {"q": "\x7f"}),
        ("/api/v1/bounties", {"q": "\tControl"}),
        ("/api/v1/bounties/summary", {"q": "Control\n"}),
    ]
    for url, params in rejected_requests:
        response = client.get(url, params=params)
        assert response.status_code == 400, (url, params)
        assert response.json()["detail"] == "q must not contain control characters", (
            url,
            params,
        )
Comment thread
coderabbitai[bot] marked this conversation as resolved.
8 changes: 8 additions & 0 deletions tests/test_bounty_pages.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,14 @@ def test_bounties_page_and_api_search_by_text_and_issue_number(sqlite_url: str)
assert backslash_search.status_code == 200
assert [row["issue_number"] for row in backslash_search.json()] == [66]

control_char_page_search = client.get("/bounties", params={"q": "\x00"})
assert control_char_page_search.status_code == 400
assert control_char_page_search.json()["detail"] == "q must not contain control characters"

leading_tab_page_search = client.get("/bounties", params={"q": "\tDocs"})
assert leading_tab_page_search.status_code == 400
assert leading_tab_page_search.json()["detail"] == "q must not contain control characters"


def test_bounties_page_and_api_sort_public_rows(sqlite_url: str) -> None:
create_schema(sqlite_url)
Expand Down
Loading