diff --git a/scripts/tests/test_harvest.py b/scripts/tests/test_harvest.py new file mode 100644 index 0000000..b98c35d --- /dev/null +++ b/scripts/tests/test_harvest.py @@ -0,0 +1,59 @@ +import importlib.util +import os +import sys +import pytest +from pathlib import Path +from unittest.mock import patch + +# Load harvest.py +script_path = Path(__file__).parent.parent / "harvest.py" +spec = importlib.util.spec_from_file_location("harvest", script_path) +harvest = importlib.util.module_from_spec(spec) +sys.modules["harvest"] = harvest +spec.loader.exec_module(harvest) + +def test_extract_session_id_from_event_session_id(): + event = {"session_id": "test-session"} + assert harvest.extract_session_id(event) == "test-session" + +def test_extract_session_id_from_event_sessionId_fallback(): + event = {"sessionId": "test-session-camel"} + assert harvest.extract_session_id(event) == "test-session-camel" + +def test_extract_session_id_from_env_claude_session_id(): + with patch.dict(os.environ, {"CLAUDE_SESSION_ID": "env-session"}): + # Ensure other env vars don't interfere + if "CLAUDE_CODE_SESSION_ID" in os.environ: + del os.environ["CLAUDE_CODE_SESSION_ID"] + assert harvest.extract_session_id({}) == "env-session" + +def test_extract_session_id_from_env_claude_code_session_id_fallback(): + with patch.dict(os.environ, {"CLAUDE_CODE_SESSION_ID": "code-env-session"}): + # Ensure CLAUDE_SESSION_ID is not present + with patch.dict(os.environ, {}): + if "CLAUDE_SESSION_ID" in os.environ: + del os.environ["CLAUDE_SESSION_ID"] + assert harvest.extract_session_id({}) == "code-env-session" + +def test_extract_session_id_default_unknown(): + with patch.dict(os.environ, {}, clear=True): + assert harvest.extract_session_id({}) == "unknown-session" + +def test_extract_session_id_normalization(): + event = {"session_id": "session/with/slashes and spaces!!!"} + # Expected: "session-with-slashes-and-spaces" + # Logic: re.sub(r"[^A-Za-z0-9._-]", "-", str(raw)).strip("-") + assert harvest.extract_session_id(event) == "session-with-slashes-and-spaces" + +def test_extract_session_id_priority(): + event = {"session_id": "event-session"} + with patch.dict(os.environ, {"CLAUDE_SESSION_ID": "env-session"}): + assert harvest.extract_session_id(event) == "event-session" + +def test_extract_session_id_empty_after_normalization(): + event = {"session_id": "!!!"} + assert harvest.extract_session_id(event) == "unknown-session" + +def test_extract_session_id_non_string_input(): + event = {"session_id": 12345} + assert harvest.extract_session_id(event) == "12345" diff --git a/scripts/tests/test_index_vault.py b/scripts/tests/test_index_vault.py index ea249ae..3a78f63 100644 --- a/scripts/tests/test_index_vault.py +++ b/scripts/tests/test_index_vault.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch import time import sqlite3 +import json # Load index-vault.py script_path = Path(__file__).parent.parent / "index-vault.py" @@ -218,19 +219,19 @@ def test_query_index_single_keyword(): args, _ = mock_conn.execute.call_args sql, params = args[0], args[1] - # Verify specific parts of the SQL - assert "WHERE (title LIKE ? OR summary LIKE ? OR tags LIKE ? OR rel_path LIKE ?)" in sql - assert "ORDER BY score DESC, body_chars DESC" in sql - assert "LIMIT ?" in sql + # Verify that json_each is used + assert "json_each(?)" in sql + assert "WHERE score > 0" in sql + assert "ORDER BY score DESC" in sql # Verify the parameter binding - expected_params = ["%test%"] * 8 + [10] # 4 for score, 4 for where, 1 for limit - assert params == expected_params + assert params[0] == json.dumps(["test"]) + assert params[1] == 10 # default limit mock_conn.close.assert_called_once() def test_query_index_multiple_keywords(): - """Test that multiple keywords use OR between the where clauses and sum the score parts.""" + """Test that multiple keywords are passed as a JSON list to the SQL query.""" vault = pathlib.Path("dummy") with patch.object(index_vault, "get_db") as mock_get_db: mock_conn = MagicMock() @@ -244,15 +245,10 @@ def test_query_index_multiple_keywords(): mock_conn.execute.assert_called_once() args, _ = mock_conn.execute.call_args - sql, params = args[0], args[1] - - # Two score parts joined by '+' - assert " + " in sql.split("AS score")[0] - # Two where parts joined by 'OR' - assert "WHERE (title LIKE ? OR summary LIKE ? OR tags LIKE ? OR rel_path LIKE ?) OR (title LIKE ? OR summary LIKE ? OR tags LIKE ? OR rel_path LIKE ?)" in sql + params = args[1] - expected_params = ["%foo%"] * 8 + ["%bar%"] * 8 + [5] - assert params == expected_params + assert params[0] == json.dumps(["foo", "bar"]) + assert params[1] == 5 def test_query_index_respects_limit(): @@ -274,7 +270,7 @@ def test_query_index_respects_limit(): def test_query_index_special_characters(): - """Test that special characters in keywords are properly bound using parameters to prevent SQL injection.""" + """Test that special characters in keywords are properly bound via JSON to prevent SQL injection.""" vault = pathlib.Path("dummy") with patch.object(index_vault, "get_db") as mock_get_db: mock_conn = MagicMock() @@ -294,9 +290,8 @@ def test_query_index_special_characters(): # Check that the keyword itself does not appear raw in the SQL query assert keyword not in sql - # Check that it appears in parameters with % wrapped - expected_bound_value = f"%{keyword}%" - assert all(p == expected_bound_value for p in params[:-1]) + # Check that it is correctly encoded in the JSON parameter + assert keyword in json.loads(params[0]) def test_query_index_row_to_dict():