def modify_arguments(args, kwargs, pos, name, value):
    """
    Return ``(new_args, new_kwargs)`` in which the argument identified by
    positional index `pos` / keyword `name` is set to `value`.

    - If the caller supplied the argument positionally (``len(args) > pos``),
      the positional slot at `pos` is overwritten in place. No keyword is
      added, so the argument is never passed twice; arguments before and
      after `pos` are kept as they were.
    - Otherwise `value` is stored under keyword `name`, replacing any existing
      entry with that name, and `args` is returned unchanged.

    Neither input is mutated: `new_kwargs` is always a fresh dict and the
    positional branch builds a new tuple. `args` is expected to be a tuple.
    """
    # Always hand back a copy so callers can never mutate the original kwargs.
    new_kwargs = dict(kwargs)
    if len(args) > pos:
        return args[:pos] + (value,) + args[pos + 1 :], new_kwargs
    new_kwargs[name] = value
    return args, new_kwargs
def patch_immutable_class(base_cls, method_patches):
    """
    Build a subclass of `base_cls` whose methods named in `method_patches`
    are wrapped, and return that subclass.

    `method_patches` maps a method name to the wrapper that is handed to
    `patch_function` for that name. The named methods are first copied into
    the subclass namespace: that namespace is an ordinary Python mapping (not
    a C-level one), so `patch_function` can be applied to the new class even
    when `base_cls` itself cannot be modified.

    The subclass keeps the base class' `__name__`, `__module__`,
    `__qualname__` and `__doc__`, so reprs and module-based type checks still
    see the original identity rather than the module that built the wrapper.

    Raises AttributeError if `base_cls` lacks one of the named methods.
    """
    namespace = {name: getattr(base_cls, name) for name in method_patches}
    # Without these, type() would stamp the subclass with the module that
    # called it and leave __doc__ as None.
    namespace.setdefault("__module__", base_cls.__module__)
    namespace.setdefault("__qualname__", base_cls.__qualname__)
    namespace.setdefault("__doc__", base_cls.__doc__)

    cls = type(base_cls.__name__, (base_cls,), namespace)

    for name, wrapper in method_patches.items():
        patch_function(cls, name, wrapper)

    return cls
def _statement_patches():
    """Return the name -> wrapper map for the statement-executing methods."""
    return {
        "execute": _execute,
        "executemany": _execute,
        "executescript": _executescript,
    }


def _cursor_patch(func, instance, args, kwargs):
    """
    Wrapper around ``Connection.cursor``.

    Swaps the cursor factory (first positional argument or ``factory=``,
    defaulting to ``sqlite3.Cursor``) for a patched subclass whose
    execute/executemany/executescript run the SQL-injection scan first, then
    calls the wrapped `func` with the substituted factory.
    """
    factory = get_argument(args, kwargs, 0, "factory") or _sqlite3.Cursor
    if not isinstance(factory, type):
        # A plain callable (function, partial, ...) cannot be subclassed.
        # Leave the call untouched instead of breaking the application;
        # cursors created this way are not instrumented.
        return func(*args, **kwargs)

    patched_factory = patch_immutable_class(factory, _statement_patches())
    new_args, new_kwargs = modify_arguments(args, kwargs, 0, "factory", patched_factory)
    return func(*new_args, **new_kwargs)


def _connect(func, instance, args, kwargs):
    """
    Wrapper around ``sqlite3.connect``.

    Swaps the connection factory (positional index 5 or ``factory=``,
    defaulting to ``sqlite3.Connection``) for a patched subclass whose
    ``cursor`` method goes through `_cursor_patch`, then calls the wrapped
    `func` with the substituted factory.
    """
    factory = get_argument(args, kwargs, 5, "factory") or _sqlite3.Connection
    if not isinstance(factory, type):
        # Same reasoning as in _cursor_patch: never turn an unusual but
        # working factory into an exception raised by the agent.
        return func(*args, **kwargs)

    connection_patches = {"cursor": _cursor_patch}

    # In Python 3.11, the sqlite3 module was fully moved to C. Hence the
    # connection-level shortcuts are patched directly as well.
    if sys.version_info >= (3, 11):
        connection_patches.update(_statement_patches())

    patched_factory = patch_immutable_class(factory, connection_patches)
    new_args, new_kwargs = modify_arguments(args, kwargs, 5, "factory", patched_factory)
    return func(*new_args, **new_kwargs)
+ """ + patch_function(m, "connect", _connect) diff --git a/aikido_zen/sinks/tests/sqlite3_test.py b/aikido_zen/sinks/tests/sqlite3_test.py new file mode 100644 index 00000000..9ce689ef --- /dev/null +++ b/aikido_zen/sinks/tests/sqlite3_test.py @@ -0,0 +1,416 @@ +import pytest +from unittest.mock import patch +import aikido_zen.sinks.sqlite3 +from aikido_zen.background_process.comms import reset_comms + +kind = "sql_injection" + + +@pytest.fixture +def database_conn(): + import sqlite3 + + conn = sqlite3.connect(":memory:") + conn.execute( + "CREATE TABLE dogs (id INTEGER PRIMARY KEY, dog_name TEXT, isAdmin INTEGER)" + ) + conn.commit() + return conn + + +def test_cursor_execute(database_conn): + reset_comms() + with patch( + "aikido_zen.vulnerabilities.run_vulnerability_scan" + ) as mock_run_vulnerability_scan: + cursor = database_conn.cursor() + query = "SELECT * FROM dogs" + cursor.execute(query) + + called_with_args = mock_run_vulnerability_scan.call_args[1]["args"] + assert called_with_args[0] == query + assert called_with_args[1] == "sqlite" + mock_run_vulnerability_scan.assert_called_once() + + cursor.fetchall() + mock_run_vulnerability_scan.assert_called_once() + + cursor.close() + database_conn.close() + mock_run_vulnerability_scan.assert_called_once() + + +def test_cursor_execute_parameterized(database_conn): + reset_comms() + with patch( + "aikido_zen.vulnerabilities.run_vulnerability_scan" + ) as mock_run_vulnerability_scan: + cursor = database_conn.cursor() + query = "INSERT INTO dogs (dog_name, isAdmin) VALUES (?, ?)" + cursor.execute(query, ("doggo", 0)) + + called_with_args = mock_run_vulnerability_scan.call_args[1]["args"] + assert called_with_args[0] == query + assert called_with_args[1] == "sqlite" + mock_run_vulnerability_scan.assert_called_once() + + database_conn.commit() + cursor.close() + database_conn.close() + mock_run_vulnerability_scan.assert_called_once() + + +def test_cursor_execute_no_args(database_conn): + reset_comms() + with patch( 
+ "aikido_zen.vulnerabilities.run_vulnerability_scan" + ) as mock_run_vulnerability_scan: + cursor = database_conn.cursor() + dogname = "Doggo" + isadmin = 1 + query = f"INSERT INTO dogs (dog_name, isAdmin) VALUES ('{dogname}', {isadmin})" + cursor.execute(query) + + called_with_args = mock_run_vulnerability_scan.call_args[1]["args"] + assert ( + called_with_args[0] + == "INSERT INTO dogs (dog_name, isAdmin) VALUES ('Doggo', 1)" + ) + assert called_with_args[1] == "sqlite" + mock_run_vulnerability_scan.assert_called_once() + + cursor.close() + database_conn.close() + mock_run_vulnerability_scan.assert_called_once() + + +def test_cursor_execute_with_fstring(database_conn): + reset_comms() + with patch( + "aikido_zen.vulnerabilities.run_vulnerability_scan" + ) as mock_run_vulnerability_scan: + cursor = database_conn.cursor() + table_name = "dogs" + value_2 = "1" + cursor.execute( + f"INSERT INTO {table_name} (dog_name, isAdmin) VALUES (?, {value_2})", + ("doggy",), + ) + + called_with_args = mock_run_vulnerability_scan.call_args[1]["args"] + assert ( + called_with_args[0] == "INSERT INTO dogs (dog_name, isAdmin) VALUES (?, 1)" + ) + assert called_with_args[1] == "sqlite" + mock_run_vulnerability_scan.assert_called_once() + + database_conn.commit() + cursor.close() + database_conn.close() + mock_run_vulnerability_scan.assert_called_once() + + +def test_cursor_executemany(database_conn): + reset_comms() + with patch( + "aikido_zen.vulnerabilities.run_vulnerability_scan" + ) as mock_run_vulnerability_scan: + cursor = database_conn.cursor() + data = [ + ("Doggy", 0), + ("Doggy 2", 1), + ("Dogski", 1), + ] + cursor.executemany("INSERT INTO dogs (dog_name, isAdmin) VALUES (?, ?)", data) + + called_with_args = mock_run_vulnerability_scan.call_args[1]["args"] + assert ( + called_with_args[0] == "INSERT INTO dogs (dog_name, isAdmin) VALUES (?, ?)" + ) + assert called_with_args[1] == "sqlite" + mock_run_vulnerability_scan.assert_called_once() + + database_conn.commit() + 
cursor.close() + database_conn.close() + mock_run_vulnerability_scan.assert_called_once() + + +def test_cursor_executescript(database_conn): + reset_comms() + with patch( + "aikido_zen.vulnerabilities.run_vulnerability_scan" + ) as mock_run_vulnerability_scan: + cursor = database_conn.cursor() + script = """ + INSERT INTO dogs (dog_name, isAdmin) VALUES ('Fido', 0); + INSERT INTO dogs (dog_name, isAdmin) VALUES ('Rex', 1); + """ + cursor.executescript(script) + + called_with_args = mock_run_vulnerability_scan.call_args[1]["args"] + assert called_with_args[0] == script + assert called_with_args[1] == "sqlite" + mock_run_vulnerability_scan.assert_called_once() + + cursor.close() + database_conn.close() + mock_run_vulnerability_scan.assert_called_once() + + +def test_connection_execute(database_conn): + reset_comms() + with patch( + "aikido_zen.vulnerabilities.run_vulnerability_scan" + ) as mock_run_vulnerability_scan: + query = "SELECT * FROM dogs" + database_conn.execute(query) + + called_with_args = mock_run_vulnerability_scan.call_args[1]["args"] + assert called_with_args[0] == query + assert called_with_args[1] == "sqlite" + mock_run_vulnerability_scan.assert_called_once() + + database_conn.close() + mock_run_vulnerability_scan.assert_called_once() + + +def test_connection_execute_parameterized(database_conn): + reset_comms() + with patch( + "aikido_zen.vulnerabilities.run_vulnerability_scan" + ) as mock_run_vulnerability_scan: + query = "INSERT INTO dogs (dog_name, isAdmin) VALUES (?, ?)" + database_conn.execute(query, ("doggo", 0)) + + called_with_args = mock_run_vulnerability_scan.call_args[1]["args"] + assert called_with_args[0] == query + assert called_with_args[1] == "sqlite" + mock_run_vulnerability_scan.assert_called_once() + + database_conn.commit() + database_conn.close() + mock_run_vulnerability_scan.assert_called_once() + + +def test_connection_executemany(database_conn): + reset_comms() + with patch( + 
"aikido_zen.vulnerabilities.run_vulnerability_scan" + ) as mock_run_vulnerability_scan: + data = [ + ("Doggy", 0), + ("Doggy 2", 1), + ("Dogski", 1), + ] + database_conn.executemany( + "INSERT INTO dogs (dog_name, isAdmin) VALUES (?, ?)", data + ) + + called_with_args = mock_run_vulnerability_scan.call_args[1]["args"] + assert ( + called_with_args[0] == "INSERT INTO dogs (dog_name, isAdmin) VALUES (?, ?)" + ) + assert called_with_args[1] == "sqlite" + mock_run_vulnerability_scan.assert_called_once() + + database_conn.commit() + database_conn.close() + mock_run_vulnerability_scan.assert_called_once() + + +def test_connection_executescript(database_conn): + reset_comms() + with patch( + "aikido_zen.vulnerabilities.run_vulnerability_scan" + ) as mock_run_vulnerability_scan: + script = """ + INSERT INTO dogs (dog_name, isAdmin) VALUES ('Fido', 0); + INSERT INTO dogs (dog_name, isAdmin) VALUES ('Rex', 1); + """ + database_conn.executescript(script) + + called_with_args = mock_run_vulnerability_scan.call_args[1]["args"] + assert called_with_args[0] == script + assert called_with_args[1] == "sqlite" + mock_run_vulnerability_scan.assert_called_once() + + database_conn.close() + mock_run_vulnerability_scan.assert_called_once() + + +# Functional tests — verify sqlite3 behavior is not broken by patching + + +def test_cursor_execute_returns_results(database_conn): + cursor = database_conn.cursor() + cursor.execute("INSERT INTO dogs (dog_name, isAdmin) VALUES (?, ?)", ("Fido", 1)) + database_conn.commit() + + cursor.execute("SELECT * FROM dogs WHERE dog_name = ?", ("Fido",)) + rows = cursor.fetchall() + assert len(rows) == 1 + assert rows[0][1] == "Fido" + assert rows[0][2] == 1 + cursor.close() + + +def test_cursor_fetchone(database_conn): + cursor = database_conn.cursor() + cursor.execute("INSERT INTO dogs (dog_name, isAdmin) VALUES (?, ?)", ("Rex", 0)) + database_conn.commit() + + cursor.execute("SELECT * FROM dogs") + row = cursor.fetchone() + assert row is not None + 
assert row[1] == "Rex" + cursor.close() + + +def test_cursor_fetchmany(database_conn): + cursor = database_conn.cursor() + data = [("Dog1", 0), ("Dog2", 1), ("Dog3", 0)] + cursor.executemany("INSERT INTO dogs (dog_name, isAdmin) VALUES (?, ?)", data) + database_conn.commit() + + cursor.execute("SELECT * FROM dogs") + rows = cursor.fetchmany(2) + assert len(rows) == 2 + cursor.close() + + +def test_cursor_rowcount(database_conn): + cursor = database_conn.cursor() + cursor.execute("INSERT INTO dogs (dog_name, isAdmin) VALUES (?, ?)", ("Buddy", 0)) + assert cursor.rowcount == 1 + cursor.close() + + +def test_cursor_lastrowid(database_conn): + cursor = database_conn.cursor() + cursor.execute("INSERT INTO dogs (dog_name, isAdmin) VALUES (?, ?)", ("Max", 0)) + assert cursor.lastrowid is not None + assert cursor.lastrowid > 0 + cursor.close() + + +def test_cursor_description(database_conn): + cursor = database_conn.cursor() + cursor.execute("SELECT * FROM dogs") + assert cursor.description is not None + col_names = [col[0] for col in cursor.description] + assert col_names == ["id", "dog_name", "isAdmin"] + cursor.close() + + +def test_executemany_inserts_all_rows(database_conn): + cursor = database_conn.cursor() + data = [("Dog1", 0), ("Dog2", 1), ("Dog3", 0)] + cursor.executemany("INSERT INTO dogs (dog_name, isAdmin) VALUES (?, ?)", data) + database_conn.commit() + + cursor.execute("SELECT COUNT(*) FROM dogs") + count = cursor.fetchone()[0] + assert count == 3 + cursor.close() + + +def test_executescript_runs_all_statements(database_conn): + cursor = database_conn.cursor() + script = """ + INSERT INTO dogs (dog_name, isAdmin) VALUES ('Script1', 0); + INSERT INTO dogs (dog_name, isAdmin) VALUES ('Script2', 1); + """ + cursor.executescript(script) + + cursor.execute("SELECT COUNT(*) FROM dogs") + count = cursor.fetchone()[0] + assert count == 2 + cursor.close() + + +def test_connection_as_context_manager(): + import sqlite3 + + with sqlite3.connect(":memory:") as conn: + 
conn.execute("CREATE TABLE test (id INTEGER PRIMARY KEY, val TEXT)") + conn.execute("INSERT INTO test (val) VALUES (?)", ("hello",)) + row = conn.execute("SELECT val FROM test").fetchone() + assert row[0] == "hello" + + +def test_connect_with_custom_factory(): + import sqlite3 + + class CustomConnection(sqlite3.Connection): + def custom_method(self): + return "custom" + + conn = sqlite3.connect(":memory:", factory=CustomConnection) + assert isinstance(conn, CustomConnection) + assert conn.custom_method() == "custom" + + # Verify cursor operations still work through a custom factory + conn.execute("CREATE TABLE t (v TEXT)") + conn.execute("INSERT INTO t VALUES (?)", ("x",)) + row = conn.execute("SELECT v FROM t").fetchone() + assert row[0] == "x" + conn.close() + + +def test_connect_with_keyword_database(): + import sqlite3 + + conn = sqlite3.connect(database=":memory:") + conn.execute("CREATE TABLE kw_test (id INTEGER PRIMARY KEY, val TEXT)") + conn.execute("INSERT INTO kw_test (val) VALUES (?)", ("kw",)) + row = conn.execute("SELECT val FROM kw_test").fetchone() + assert row[0] == "kw" + conn.close() + + +def test_row_factory(database_conn): + import sqlite3 + + database_conn.row_factory = sqlite3.Row + cursor = database_conn.cursor() + cursor.execute("INSERT INTO dogs (dog_name, isAdmin) VALUES (?, ?)", ("RowFido", 1)) + database_conn.commit() + + cursor.execute("SELECT * FROM dogs WHERE dog_name = ?", ("RowFido",)) + row = cursor.fetchone() + assert row["dog_name"] == "RowFido" + assert row["isAdmin"] == 1 + cursor.close() + + +def test_cursor_with_custom_factory(database_conn): + import sqlite3 + + class CustomCursor(sqlite3.Cursor): + pass + + reset_comms() + with patch( + "aikido_zen.vulnerabilities.run_vulnerability_scan" + ) as mock_run_vulnerability_scan: + cursor = database_conn.cursor(factory=CustomCursor) + assert isinstance(cursor, CustomCursor) + + cursor.execute( + "INSERT INTO dogs (dog_name, isAdmin) VALUES (?, ?)", ("FactoryFido", 1) + ) + 
database_conn.commit() + + called_with_args = mock_run_vulnerability_scan.call_args[1]["args"] + assert ( + called_with_args[0] == "INSERT INTO dogs (dog_name, isAdmin) VALUES (?, ?)" + ) + assert called_with_args[1] == "sqlite" + mock_run_vulnerability_scan.assert_called_once() + + cursor.execute("SELECT * FROM dogs WHERE dog_name = ?", ("FactoryFido",)) + rows = cursor.fetchall() + assert len(rows) == 1 + assert rows[0][1] == "FactoryFido" + cursor.close()