Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Zen for Python 3 is compatible with:
* βœ… [`asyncpg`](https://pypi.org/project/asyncpg) ^0.27
* βœ… [`motor`](https://pypi.org/project/motor/) (See `pymongo` version)
* βœ… [`clickhouse-driver`](https://pypi.org/project/clickhouse-driver)
* βœ… [`sqlite3`](https://docs.python.org/3/library/sqlite3.html)

### AI SDKs
Zen instruments the following AI SDKs to track which models are used and how many tokens are consumed, allowing you to monitor your AI usage and costs:
Expand Down
1 change: 1 addition & 0 deletions aikido_zen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def protect(mode="daemon", token=""):
import aikido_zen.sinks.psycopg
import aikido_zen.sinks.asyncpg
import aikido_zen.sinks.clickhouse_driver
import aikido_zen.sinks.sqlite3

import aikido_zen.sinks.builtins
import aikido_zen.sinks.os
Expand Down
17 changes: 17 additions & 0 deletions aikido_zen/helpers/modify_arguments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Exports modify_arguments"""


def modify_arguments(args, kwargs, pos, name, value):
"""
Returns (new_args, new_kwargs) with `value` injected as keyword argument
`name`. If a positional argument exists at index `pos` or beyond, it is
removed from args so the call is not duplicated.
"""
if len(args) > pos:
new_args = args[:pos] + (value,) + args[pos + 1 :]
new_kwargs = dict(kwargs)
else:
new_args = args
new_kwargs = dict(kwargs)
new_kwargs[name] = value
return new_args, new_kwargs
58 changes: 58 additions & 0 deletions aikido_zen/helpers/modify_arguments_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import pytest
from .modify_arguments import modify_arguments


def test_injects_value_as_kwarg():
args, kwargs = modify_arguments((), {}, 0, "key", "val")
assert kwargs["key"] == "val"
assert args == ()


def test_overwrites_positional_arg_at_pos():
args, kwargs = modify_arguments(("a", "b", "c"), {}, 2, "key", "new")
assert args == ("a", "b", "new")
assert "key" not in kwargs


def test_overwrites_positional_arg_keeps_trailing_args():
args, kwargs = modify_arguments(("a", "b", "c", "d"), {}, 2, "key", "new")
assert args == ("a", "b", "new", "d")
assert "key" not in kwargs


def test_injects_as_kwarg_when_pos_not_in_args():
args, kwargs = modify_arguments(("a", "b"), {}, 5, "key", "new")
assert args == ("a", "b")
assert kwargs["key"] == "new"


def test_overwrites_existing_kwarg():
args, kwargs = modify_arguments((), {"key": "old"}, 0, "key", "new")
assert kwargs["key"] == "new"


def test_does_not_mutate_original_kwargs():
original = {"other": 1}
_, kwargs = modify_arguments((), original, 0, "key", "val")
assert "key" not in original
assert kwargs["other"] == 1


def test_does_not_mutate_original_args():
original = ("a", "b", "c")
new_args, _ = modify_arguments(original, {}, 1, "key", "val")
assert original == ("a", "b", "c")
assert new_args == ("a", "val", "c")


def test_empty_args_and_kwargs():
args, kwargs = modify_arguments((), {}, 3, "key", 42)
assert args == ()
assert kwargs == {"key": 42}


def test_preserves_other_kwargs():
args, kwargs = modify_arguments((), {"a": 1, "b": 2}, 5, "c", 3)
assert kwargs["a"] == 1
assert kwargs["b"] == 2
assert kwargs["c"] == 3
19 changes: 19 additions & 0 deletions aikido_zen/sinks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,22 @@ async def decorator(func, instance, args, kwargs):
return return_value

return decorator


def patch_immutable_class(base_cls, method_patches):
modifiable_attributes = {}
for name in method_patches:
modifiable_attributes[name] = getattr(base_cls, name)

cls = type(
base_cls.__name__,
(base_cls,),
# this modifiable_attributes object contains a python (not c) map of functions, so we can apply the
# patch_function to these attributes of our new class.
modifiable_attributes,
)

for name, wrapper in method_patches.items():
patch_function(cls, name, wrapper)

return cls
71 changes: 71 additions & 0 deletions aikido_zen/sinks/sqlite3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import sqlite3 as _sqlite3
import sys
from aikido_zen.helpers.get_argument import get_argument
from aikido_zen.helpers.modify_arguments import modify_arguments
import aikido_zen.vulnerabilities as vulns
from aikido_zen.helpers.register_call import register_call
from aikido_zen.sinks import (
patch_function,
on_import,
before,
patch_immutable_class,
)


@before
def _execute(func, instance, args, kwargs):
op = f"sqlite3.{type(instance).__name__}.{func.__name__}"
query = get_argument(args, kwargs, 0, "sql")
register_call(op, "sql_op")
vulns.run_vulnerability_scan(kind="sql_injection", op=op, args=(query, "sqlite"))


@before
def _executescript(func, instance, args, kwargs):
op = f"sqlite3.{type(instance).__name__}.{func.__name__}"
query = get_argument(args, kwargs, 0, "sql_script")
register_call(op, "sql_op")
vulns.run_vulnerability_scan(kind="sql_injection", op=op, args=(query, "sqlite"))


def _cursor_patch(func, instance, args, kwargs):
factory = get_argument(args, kwargs, 0, "factory") or _sqlite3.Cursor
patched_factory = patch_immutable_class(
factory,
{
"execute": _execute,
"executemany": _execute,
"executescript": _executescript,
},
)

new_args, new_kwargs = modify_arguments(args, kwargs, 0, "factory", patched_factory)
return func(*new_args, **new_kwargs)


def _connect(func, instance, args, kwargs):
factory = get_argument(args, kwargs, 5, "factory") or _sqlite3.Connection
connection_patches = {"cursor": _cursor_patch}

# In Python 3.11, the sqlite3 module was fully moved to C. Hence the extra patches
if sys.version_info >= (3, 11):
connection_patches.update(
{
"execute": _execute,
"executemany": _execute,
"executescript": _executescript,
}
)

patched_factory = patch_immutable_class(factory, connection_patches)
new_args, new_kwargs = modify_arguments(args, kwargs, 5, "factory", patched_factory)
return func(*new_args, **new_kwargs)


@on_import("sqlite3")
def patch(m):
"""
patches sqlite3, a c library; the "connect" function is not c, after that we use patch_immutable_class to
patch the factory parameter of the connect function. In this factory we patch the cursor function.
"""
patch_function(m, "connect", _connect)
Loading