From cb1b6889d5a9a6eb084ac44b0f73dbce2829a39b Mon Sep 17 00:00:00 2001 From: Ashwin U Sokke Date: Mon, 13 Apr 2026 08:12:36 -0400 Subject: [PATCH] fix(mcp): use POST search and validate store size --- tests/test_mcp.py | 31 +++++++++++++++++++++++-------- tools/mcp/server.py | 13 ++++++++----- 2 files changed, 31 insertions(+), 13 deletions(-) diff --git a/tests/test_mcp.py b/tests/test_mcp.py index a096974..9036dc4 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -107,6 +107,13 @@ def test_store_error(mcp_server): assert "Embedding failed" in result +def test_store_rejects_oversized_content_without_http(mcp_server): + with patch.object(mcp_server.urllib.request, "urlopen") as mocked: + result = mcp_server.memory_store("x" * (mcp_server._MAX_COMMIT_TEXT_CHARS + 1)) + assert "content too long" in result + mocked.assert_not_called() + + # -- memory_search ---------------------------------------------------------- @@ -149,20 +156,24 @@ def test_search_empty(mcp_server): assert result == "No matching memories found." -def test_search_limit_clamped(mcp_server): - captured_url = {} +def test_search_limit_clamped_and_uses_post_body(mcp_server): + captured = {} def fake_urlopen(req, **kwargs): - captured_url["url"] = req.full_url + captured["url"] = req.full_url + captured["method"] = req.get_method() + captured["data"] = json.loads(req.data.decode()) return _make_response({"results": [], "elapsed_ms": 0}) with patch.object(mcp_server.urllib.request, "urlopen", side_effect=fake_urlopen): mcp_server.memory_search("q", limit=999) - assert "limit=30" in captured_url["url"] + assert captured["url"] == "http://test:7777/search" + assert captured["method"] == "POST" + assert captured["data"]["limit"] == 30 with patch.object(mcp_server.urllib.request, "urlopen", side_effect=fake_urlopen): mcp_server.memory_search("q", limit=-5) - assert "limit=1" in captured_url["url"] + assert captured["data"]["limit"] == 1 def test_search_score_fallback_chain(mcp_server): @@ -375,12 +386,16 @@ def test_collection_params_custom_bank(mcp_server, monkeypatch): def test_search_includes_collection_param(mcp_server, monkeypatch): monkeypatch.setattr(mcp_server, "BANK_ID", "custom-bank") - captured_url = {} + captured = {} def fake_urlopen(req, **kwargs): - captured_url["url"] = req.full_url + captured["url"] = req.full_url + captured["method"] = req.get_method() + captured["data"] = json.loads(req.data.decode()) return _make_response({"results": [], "elapsed_ms": 0}) with patch.object(mcp_server.urllib.request, "urlopen", side_effect=fake_urlopen): mcp_server.memory_search("test") - assert "collection=custom-bank" in captured_url["url"] + assert captured["url"] == "http://test:7777/search" + assert captured["method"] == "POST" + assert captured["data"]["collection"] == "custom-bank" diff --git a/tools/mcp/server.py b/tools/mcp/server.py index 8400b53..c57aada 100644 --- a/tools/mcp/server.py +++ b/tools/mcp/server.py @@ -15,7 +15,6 @@ import json import logging import os -import urllib.parse import urllib.request from typing import Any @@ -23,6 +22,8 @@ logger = logging.getLogger("rasputin.mcp") +_MAX_COMMIT_TEXT_CHARS = 8000 + RASPUTIN_URL = os.environ.get("RASPUTIN_URL", "http://127.0.0.1:7777") RASPUTIN_TOKEN = os.environ.get("RASPUTIN_TOKEN", "") BANK_ID = os.environ.get("RASPUTIN_BANK_ID", "") @@ -91,6 +92,9 @@ def memory_store( importance: Priority 0-100. Default 60. Use 80+ for critical decisions, 40 for background context. """ + if len(content) > _MAX_COMMIT_TEXT_CHARS: + return f"Error: content too long (max {_MAX_COMMIT_TEXT_CHARS} characters)." + payload: dict[str, Any] = { "text": content, "source": source, @@ -133,10 +137,9 @@ def memory_search( limit: Maximum results to return (1-30, default 10). """ limit = max(1, min(30, limit)) - qs: dict[str, Any] = {"q": query, "limit": limit} - qs.update(_collection_params()) - params = urllib.parse.urlencode(qs) - result = _api(f"/search?{params}") + payload: dict[str, Any] = {"q": query, "limit": limit} + payload.update(_collection_params()) + result = _api("/search", method="POST", data=payload) results = result.get("results", []) if not results: return "No matching memories found."