Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 29 additions & 23 deletions src/agent_memory/storage/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,30 +111,34 @@ def initialize_database(db_path: Path | str) -> None:
connection.execute(
"CREATE INDEX IF NOT EXISTS idx_episodes_status_scope_importance ON episodes(status, scope, importance_score)"
)
connection.execute(
"""
CREATE TABLE IF NOT EXISTS retrieval_observations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
surface TEXT NOT NULL,
query_sha256 TEXT NOT NULL,
query_preview TEXT,
preferred_scope TEXT,
limit_value INTEGER NOT NULL,
statuses_json TEXT NOT NULL DEFAULT '["approved"]',
retrieved_memory_refs_json TEXT NOT NULL DEFAULT '[]',
top_memory_ref TEXT,
response_mode TEXT CHECK (response_mode IN ('direct', 'cautious', 'verify_first')),
metadata_json TEXT NOT NULL DEFAULT '{}'
)
"""
)
connection.execute(
"CREATE INDEX IF NOT EXISTS idx_retrieval_observations_created_at ON retrieval_observations(created_at, id)"
)
connection.execute(
"CREATE INDEX IF NOT EXISTS idx_retrieval_observations_surface ON retrieval_observations(surface, created_at)"
_ensure_retrieval_observations_schema(connection)


def _ensure_retrieval_observations_schema(connection: sqlite3.Connection) -> None:
connection.execute(
"""
CREATE TABLE IF NOT EXISTS retrieval_observations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
surface TEXT NOT NULL,
query_sha256 TEXT NOT NULL,
query_preview TEXT,
preferred_scope TEXT,
limit_value INTEGER NOT NULL,
statuses_json TEXT NOT NULL DEFAULT '["approved"]',
retrieved_memory_refs_json TEXT NOT NULL DEFAULT '[]',
top_memory_ref TEXT,
response_mode TEXT CHECK (response_mode IN ('direct', 'cautious', 'verify_first')),
metadata_json TEXT NOT NULL DEFAULT '{}'
)
"""
)
connection.execute(
"CREATE INDEX IF NOT EXISTS idx_retrieval_observations_created_at ON retrieval_observations(created_at, id)"
)
connection.execute(
"CREATE INDEX IF NOT EXISTS idx_retrieval_observations_surface ON retrieval_observations(surface, created_at)"
)


def _ensure_memory_table_columns(
Expand Down Expand Up @@ -829,6 +833,7 @@ def record_retrieval_observation(
top_memory_ref = retrieved_memory_refs[0] if retrieved_memory_refs else None
query_sha256 = hashlib.sha256(query.encode("utf-8")).hexdigest()
with connect(db_path) as connection:
_ensure_retrieval_observations_schema(connection)
cursor = connection.execute(
"""
INSERT INTO retrieval_observations (
Expand Down Expand Up @@ -864,6 +869,7 @@ def record_retrieval_observation(

def list_retrieval_observations(db_path: Path | str, *, limit: int = 50) -> list[RetrievalObservation]:
with connect(db_path) as connection:
_ensure_retrieval_observations_schema(connection)
rows = connection.execute(
"""
SELECT *
Expand Down
26 changes: 26 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,32 @@ def test_python_module_cli_retrieve_observe_records_secret_safe_local_observatio
assert "abc123" not in list_result.stdout


def test_python_module_cli_observations_list_migrates_existing_database_without_observation_table(tmp_path: Path) -> None:
db_path = tmp_path / "legacy-observation.db"
initialize_database(db_path)
import sqlite3

with sqlite3.connect(db_path) as connection:
connection.execute("DROP TABLE retrieval_observations")

result = subprocess.run(
[
sys.executable,
"-m",
"agent_memory.api.cli",
"observations",
"list",
str(db_path),
],
text=True,
capture_output=True,
check=True,
)

payload = json.loads(result.stdout)
assert payload["observations"] == []


def test_python_module_cli_retrieve_defaults_to_approved_and_hides_disputed_content(tmp_path: Path) -> None:
db_path = tmp_path / "retrieve-approved-only.db"
initialize_database(db_path)
Expand Down