Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions scripts/tests/test_harvest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import importlib.util
import os
import sys
import pytest
from pathlib import Path
from unittest.mock import patch

# Load harvest.py
script_path = Path(__file__).parent.parent / "harvest.py"
spec = importlib.util.spec_from_file_location("harvest", script_path)
harvest = importlib.util.module_from_spec(spec)
sys.modules["harvest"] = harvest
spec.loader.exec_module(harvest)

def test_extract_session_id_from_event_session_id():
event = {"session_id": "test-session"}
assert harvest.extract_session_id(event) == "test-session"

def test_extract_session_id_from_event_sessionId_fallback():
event = {"sessionId": "test-session-camel"}
assert harvest.extract_session_id(event) == "test-session-camel"

def test_extract_session_id_from_env_claude_session_id():
with patch.dict(os.environ, {"CLAUDE_SESSION_ID": "env-session"}):
# Ensure other env vars don't interfere
if "CLAUDE_CODE_SESSION_ID" in os.environ:
del os.environ["CLAUDE_CODE_SESSION_ID"]
assert harvest.extract_session_id({}) == "env-session"

def test_extract_session_id_from_env_claude_code_session_id_fallback():
with patch.dict(os.environ, {"CLAUDE_CODE_SESSION_ID": "code-env-session"}):
# Ensure CLAUDE_SESSION_ID is not present
with patch.dict(os.environ, {}):
if "CLAUDE_SESSION_ID" in os.environ:
del os.environ["CLAUDE_SESSION_ID"]
assert harvest.extract_session_id({}) == "code-env-session"

def test_extract_session_id_default_unknown():
with patch.dict(os.environ, {}, clear=True):
assert harvest.extract_session_id({}) == "unknown-session"

def test_extract_session_id_normalization():
event = {"session_id": "session/with/slashes and spaces!!!"}
# Expected: "session-with-slashes-and-spaces"
# Logic: re.sub(r"[^A-Za-z0-9._-]", "-", str(raw)).strip("-")
assert harvest.extract_session_id(event) == "session-with-slashes-and-spaces"

def test_extract_session_id_priority():
event = {"session_id": "event-session"}
with patch.dict(os.environ, {"CLAUDE_SESSION_ID": "env-session"}):
assert harvest.extract_session_id(event) == "event-session"

def test_extract_session_id_empty_after_normalization():
event = {"session_id": "!!!"}
assert harvest.extract_session_id(event) == "unknown-session"

def test_extract_session_id_non_string_input():
event = {"session_id": 12345}
assert harvest.extract_session_id(event) == "12345"
33 changes: 14 additions & 19 deletions scripts/tests/test_index_vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from unittest.mock import MagicMock, patch
import time
import sqlite3
import json

# Load index-vault.py
script_path = Path(__file__).parent.parent / "index-vault.py"
Expand Down Expand Up @@ -218,19 +219,19 @@ def test_query_index_single_keyword():
args, _ = mock_conn.execute.call_args
sql, params = args[0], args[1]

# Verify specific parts of the SQL
assert "WHERE (title LIKE ? OR summary LIKE ? OR tags LIKE ? OR rel_path LIKE ?)" in sql
assert "ORDER BY score DESC, body_chars DESC" in sql
assert "LIMIT ?" in sql
# Verify that json_each is used
assert "json_each(?)" in sql
assert "WHERE score > 0" in sql
assert "ORDER BY score DESC" in sql

# Verify the parameter binding
expected_params = ["%test%"] * 8 + [10] # 4 for score, 4 for where, 1 for limit
assert params == expected_params
assert params[0] == json.dumps(["test"])
assert params[1] == 10 # default limit
mock_conn.close.assert_called_once()


def test_query_index_multiple_keywords():
"""Test that multiple keywords use OR between the where clauses and sum the score parts."""
"""Test that multiple keywords are passed as a JSON list to the SQL query."""
vault = pathlib.Path("dummy")
with patch.object(index_vault, "get_db") as mock_get_db:
mock_conn = MagicMock()
Expand All @@ -244,15 +245,10 @@ def test_query_index_multiple_keywords():

mock_conn.execute.assert_called_once()
args, _ = mock_conn.execute.call_args
sql, params = args[0], args[1]

# Two score parts joined by '+'
assert " + " in sql.split("AS score")[0]
# Two where parts joined by 'OR'
assert "WHERE (title LIKE ? OR summary LIKE ? OR tags LIKE ? OR rel_path LIKE ?) OR (title LIKE ? OR summary LIKE ? OR tags LIKE ? OR rel_path LIKE ?)" in sql
params = args[1]

expected_params = ["%foo%"] * 8 + ["%bar%"] * 8 + [5]
assert params == expected_params
assert params[0] == json.dumps(["foo", "bar"])
assert params[1] == 5


def test_query_index_respects_limit():
Expand All @@ -274,7 +270,7 @@ def test_query_index_respects_limit():


def test_query_index_special_characters():
"""Test that special characters in keywords are properly bound using parameters to prevent SQL injection."""
"""Test that special characters in keywords are properly bound via JSON to prevent SQL injection."""
vault = pathlib.Path("dummy")
with patch.object(index_vault, "get_db") as mock_get_db:
mock_conn = MagicMock()
Expand All @@ -294,9 +290,8 @@ def test_query_index_special_characters():
# Check that the keyword itself does not appear raw in the SQL query
assert keyword not in sql

# Check that it appears in parameters with % wrapped
expected_bound_value = f"%{keyword}%"
assert all(p == expected_bound_value for p in params[:-1])
# Check that it is correctly encoded in the JSON parameter
assert keyword in json.loads(params[0])


def test_query_index_row_to_dict():
Expand Down
Loading