diff --git a/src/ipsdk/connection.py b/src/ipsdk/connection.py index c71dcdb..fa71804 100644 --- a/src/ipsdk/connection.py +++ b/src/ipsdk/connection.py @@ -146,6 +146,20 @@ async def fetch_devices(): class ConnectionBase: + __slots__ = ( + "_auth_lock", + "_auth_timestamp", + "_ttl_enabled", + "authenticated", + "client", + "client_id", + "client_secret", + "password", + "token", + "ttl", + "user", + ) + client: httpx.Client | httpx.AsyncClient @logging.trace @@ -203,6 +217,7 @@ def __init__( self._auth_lock: Any | None = None self._auth_timestamp: float | None = None self.ttl = ttl + self._ttl_enabled = ttl > 0 # Cache this check for performance self.client = self.__init_client__( base_url=self._make_base_url(host, port, base_path, use_tls), @@ -238,13 +253,13 @@ def _make_base_url( None """ if port == 0: - port = 443 if use_tls is True else 80 + port = 443 if use_tls else 80 if port not in (None, 80, 443): host = f"{host}:{port}" base_path = "" if base_path is None else base_path - proto = "https" if use_tls is True else "http" + proto = "https" if use_tls else "http" return urllib.parse.urlunsplit((proto, host, base_path, None, None)) @@ -313,7 +328,7 @@ def _validate_request_args( method: HTTPMethod, path: str, params: dict[str, Any | None] | None = None, - json: str | bytes | dict | (list | None) = None, + json: str | bytes | dict | list | None = None, ) -> None: """ Validate request arguments to ensure they have correct types. @@ -336,19 +351,19 @@ def _validate_request_args( IpsdkError: If method is not HTTPMethod type, params is not dict, json is not dict/list, or path is not string """ - if isinstance(method, HTTPMethod) is False: + if not isinstance(method, HTTPMethod): msg = "method must be of type `HTTPMethod`" raise exceptions.IpsdkError(msg) - if all((params is not None, isinstance(params, dict) is False)): + if params is not None and not isinstance(params, dict): msg = "params must be of type `dict`" raise exceptions.IpsdkError(msg) - if all((json is not None, isinstance(json, (list, dict)) is False)): + if json is not None and not isinstance(json, (list, dict)): msg = "json must be of type `dict` or `list`" raise exceptions.IpsdkError(msg) - if isinstance(path, str) is False: + if not isinstance(path, str): msg = "path must be of type `str`" raise exceptions.IpsdkError(msg) @@ -475,19 +490,25 @@ def _send_request( RequestError: Network or connection errors occurred. HTTPStatusError: Server returned an HTTP error status (4xx, 5xx). """ - # Check if reauthentication is needed due to timeout - if self._needs_reauthentication(): - logging.info("Forcing reauthentication due to timeout") - self.authenticated = False - self.token = None - - if self.authenticated is False: - assert self._auth_lock is not None + # Check authentication status and handle TTL-based reauthentication + if self.authenticated is False or self._ttl_enabled: + if self._auth_lock is None: + msg = "Authentication lock not initialized" + raise exceptions.IpsdkError(msg) + with self._auth_lock: - if self.authenticated is False: - self.authenticate() - self.authenticated = True - self._auth_timestamp = time.time() + # Double-check pattern with TTL check inside lock + # to prevent race conditions + if self.authenticated is False or self._needs_reauthentication(): + if self._needs_reauthentication(): + logging.info("Forcing reauthentication due to timeout") + self.authenticated = False + self.token = None + + if self.authenticated is False: + self.authenticate() + self.authenticated = True + self._auth_timestamp = time.time() request = self._build_request( method=method, @@ -693,19 +714,25 @@ async def _send_request( RequestError: Network or connection errors occurred. HTTPStatusError: Server returned an HTTP error status (4xx, 5xx). """ - # Check if reauthentication is needed due to timeout - if self._needs_reauthentication(): - logging.info("Forcing reauthentication due to timeout") - self.authenticated = False - self.token = None - - if self.authenticated is False: - assert self._auth_lock is not None + # Check authentication status and handle TTL-based reauthentication + if self.authenticated is False or self._ttl_enabled: + if self._auth_lock is None: + msg = "Authentication lock not initialized" + raise exceptions.IpsdkError(msg) + async with self._auth_lock: - if self.authenticated is False: - await self.authenticate() - self.authenticated = True - self._auth_timestamp = time.time() + # Double-check pattern with TTL check inside lock + # to prevent race conditions + if self.authenticated is False or self._needs_reauthentication(): + if self._needs_reauthentication(): + logging.info("Forcing reauthentication due to timeout") + self.authenticated = False + self.token = None + + if self.authenticated is False: + await self.authenticate() + self.authenticated = True + self._auth_timestamp = time.time() request = self._build_request( method=method, diff --git a/src/ipsdk/exceptions.py b/src/ipsdk/exceptions.py index 5f07e20..76aa290 100644 --- a/src/ipsdk/exceptions.py +++ b/src/ipsdk/exceptions.py @@ -128,6 +128,8 @@ def request(self) -> Any: The httpx.Request object associated with this error, or None if no httpx exception was provided during initialization. """ + if self._exc is None: + return None return self._exc.request @property @@ -140,7 +142,9 @@ def response(self) -> Any: no httpx exception was provided or if the error occurred before receiving a response. """ - return self._exc.response + if self._exc is None: + return None + return getattr(self._exc, "response", None) class RequestError(IpsdkError): diff --git a/src/ipsdk/gateway.py b/src/ipsdk/gateway.py index 77d9f00..c3f7fbe 100644 --- a/src/ipsdk/gateway.py +++ b/src/ipsdk/gateway.py @@ -152,6 +152,7 @@ async def get_devices(): print(f"Network error: {e}") """ +from typing import TYPE_CHECKING from typing import Any import httpx @@ -275,12 +276,21 @@ async def authenticate(self) -> None: raise exceptions.RequestError(exc) -Gateway = type("Gateway", (AuthMixin, connection.Connection), {}) -AsyncGateway = type("AsyncGateway", (AsyncAuthMixin, connection.AsyncConnection), {}) +# Define dynamically created classes for runtime and type checking +if TYPE_CHECKING: + # For type checkers: provide explicit class definitions + class Gateway(AuthMixin, connection.Connection): + """Synchronous Gateway client with authentication.""" -# Type aliases for mypy -GatewayType = Gateway -AsyncGatewayType = AsyncGateway + class AsyncGateway(AsyncAuthMixin, connection.AsyncConnection): + """Asynchronous Gateway client with authentication.""" + +else: + # For runtime: use dynamic type creation for flexibility + Gateway = type("Gateway", (AuthMixin, connection.Connection), {}) + AsyncGateway = type( + "AsyncGateway", (AsyncAuthMixin, connection.AsyncConnection), {} + ) @logging.trace @@ -335,7 +345,7 @@ def gateway_factory( Returns: An initialized connection instance """ - factory = AsyncGateway if want_async is True else Gateway + factory = AsyncGateway if want_async else Gateway return factory( host=host, port=port, diff --git a/src/ipsdk/heuristics.py b/src/ipsdk/heuristics.py index 50ee546..ccc8d11 100644 --- a/src/ipsdk/heuristics.py +++ b/src/ipsdk/heuristics.py @@ -35,6 +35,8 @@ class Scanner: _instance: Scanner | None = None _initialized: bool = False + _default_patterns: dict[str, Pattern] | None = None + _default_redactions: dict[str, Callable[[str], str]] | None = None def __new__(cls, _custom_patterns: dict[str, str | None] | None = None) -> Scanner: """Create or return the singleton instance. @@ -73,10 +75,7 @@ def __init__(self, custom_patterns: dict[str, str | None] | None = None) -> None """ # Only initialize once due to Singleton pattern if not self._initialized: - self._patterns: dict[str, Pattern] = {} - self._redaction_functions: dict[str, Callable[[str], str]] = {} - - # Initialize default patterns + # Initialize default patterns (copies from class-level cache) self._init_default_patterns() # Add custom patterns if provided @@ -91,7 +90,8 @@ def _init_default_patterns(self) -> None: """Initialize default sensitive data patterns. Sets up regex patterns for common sensitive data types including API keys, - passwords, tokens, credit card numbers, and other PII. + passwords, tokens, credit card numbers, and other PII. Patterns are compiled + once at class level and reused across all instances for performance. Returns: None @@ -99,52 +99,61 @@ def _init_default_patterns(self) -> None: Raises: None """ - # API Keys and tokens (various formats) - self.add_pattern( - "api_key", - r"(?i)\b(?:api[_-]?key|apikey)\s*[=:]\s*[\"']?([a-zA-Z0-9_\-]{16,})[\"']?", - ) - self.add_pattern("bearer_token", r"(?i)\bbearer\s+([a-zA-Z0-9_\-\.]{20,})") - self.add_pattern( - "jwt_token", - r"\b(eyJ[a-zA-Z0-9_\-]+\.eyJ[a-zA-Z0-9_\-]+\.[a-zA-Z0-9_\-]+)\b", - ) - self.add_pattern( - "access_token", - r"(?i)\b(?:access[_-]?token|accesstoken)\s*[=:]\s*[\"']?([a-zA-Z0-9_\-]{20,})[\"']?", - ) - - # Password patterns - self.add_pattern( - "password", - r"(?i)\b(?:password|passwd|pwd)\s*[=:]\s*[\"']?([^\s\"']{6,})[\"']?", - ) - self.add_pattern( - "secret", - r"(?i)\b(?:secret|client_secret)\s*[=:]\s*[\"']?([a-zA-Z0-9_\-]{16,})[\"']?", - ) - - # URLs with authentication (check before email patterns) - self.add_pattern("auth_url", r"https?://[a-zA-Z0-9_\-]+:[a-zA-Z0-9_\-]+@[^\s]+") - - # Basic email pattern (when used in sensitive contexts) - self.add_pattern( - "email_in_auth", - r"(?i)(?:username|user|email)\s*[=:]\s*[\"']?([a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,})[\"']?", - ) - - # Database connection strings - self.add_pattern( - "db_connection", - r"(?i)\b(?:mongodb|mysql|postgresql|postgres)://[^\s]+:[^\s]+@[^\s]+", - ) - - # Private keys (basic detection) - self.add_pattern( - "private_key", - r"-----BEGIN (?:RSA )?PRIVATE KEY-----[\s\S]*?" - r"-----END (?:RSA )?PRIVATE KEY-----", - ) + # Only compile patterns once at class level + if Scanner._default_patterns is None: + Scanner._default_patterns = {} + Scanner._default_redactions = {} + + # Compile all default patterns once + patterns_to_compile = { + "api_key": ( + r"(?i)\b(?:api[_-]?key|apikey)\s*[=:]\s*[\"']?" + r"([a-zA-Z0-9_\-]{16,})[\"']?" + ), + "bearer_token": r"(?i)\bbearer\s+([a-zA-Z0-9_\-\.]{20,})", + "jwt_token": ( + r"\b(eyJ[a-zA-Z0-9_\-]+\.eyJ[a-zA-Z0-9_\-]+" + r"\.[a-zA-Z0-9_\-]+)\b" + ), + "access_token": ( + r"(?i)\b(?:access[_-]?token|accesstoken)\s*[=:]\s*[\"']?" + r"([a-zA-Z0-9_\-]{20,})[\"']?" + ), + "password": ( + r"(?i)\b(?:password|passwd|pwd)\s*[=:]\s*[\"']?" + r"([^\s\"']{6,})[\"']?" + ), + "secret": ( + r"(?i)\b(?:secret|client_secret)\s*[=:]\s*[\"']?" + r"([a-zA-Z0-9_\-]{16,})[\"']?" + ), + "auth_url": r"https?://[a-zA-Z0-9_\-]+:[a-zA-Z0-9_\-]+@[^\s]+", + "email_in_auth": ( + r"(?i)(?:username|user|email)\s*[=:]\s*[\"']?" + r"([a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,})[\"']?" + ), + "db_connection": ( + r"(?i)\b(?:mongodb|mysql|postgresql|postgres)://" + r"[^\s]+:[^\s]+@[^\s]+" + ), + "private_key": ( + r"-----BEGIN (?:RSA )?PRIVATE KEY-----[\s\S]*?" + r"-----END (?:RSA )?PRIVATE KEY-----" + ), + } + + for name, pattern_str in patterns_to_compile.items(): + compiled_pattern = re.compile(pattern_str) + Scanner._default_patterns[name] = compiled_pattern + # Use default argument to capture current value of name + # (avoid late-binding closure issue) + Scanner._default_redactions[name] = ( + lambda _, n=name: f"[REDACTED_{n.upper()}]" + ) + + # Copy pre-compiled patterns to instance + self._patterns = Scanner._default_patterns.copy() + self._redaction_functions = Scanner._default_redactions.copy() def add_pattern( self, diff --git a/src/ipsdk/http.py b/src/ipsdk/http.py index 140af1a..62bedc0 100644 --- a/src/ipsdk/http.py +++ b/src/ipsdk/http.py @@ -77,6 +77,8 @@ class Request: ValueError: If required parameters are missing or invalid """ + __slots__ = ("headers", "json", "method", "params", "path") + @logging.trace def __init__( self, @@ -84,7 +86,7 @@ def __init__( path: str, params: dict[str, Any | None] | None = None, headers: dict[str, str | None] | None = None, - json: str | bytes | dict | (list | None) = None, + json: str | bytes | dict | list | None = None, ) -> None: self.method = method self.path = path @@ -130,6 +132,8 @@ class Response: ValueError: If the httpx_response is None or invalid """ + __slots__ = ("_response",) + @logging.trace def __init__(self, httpx_response: httpx.Response) -> None: if httpx_response is None: diff --git a/src/ipsdk/logging.py b/src/ipsdk/logging.py index 4e4c488..d99dc7a 100644 --- a/src/ipsdk/logging.py +++ b/src/ipsdk/logging.py @@ -104,6 +104,7 @@ def process_data(data): logging.set_level(logging.NONE) """ +import gc import inspect import logging import sys @@ -148,9 +149,12 @@ def process_data(data): logging.getLogger(metadata.name).setLevel(NONE) # Thread-safe configuration for sensitive data filtering -_filtering_lock = threading.RLock() +_filtering_lock = threading.Lock() _sensitive_data_filtering_enabled = False +# Thread-safe logger cache access +_logger_cache_lock = threading.Lock() + def log(lvl: int, msg: str) -> None: """Send the log message with the specified level. @@ -313,9 +317,8 @@ def _get_loggers() -> set[logging.Logger]: dependencies (ipsdk, FastMCP). Results are cached to improve performance on subsequent calls. - This function is thread-safe. It creates a snapshot of logger names before - iteration to prevent issues if the logger dictionary is modified by other - threads during iteration. + This function is thread-safe. It uses a lock to protect logger dictionary + access and creates a snapshot to prevent race conditions during iteration. Note: The cached result may not immediately reflect loggers created after @@ -325,13 +328,19 @@ def _get_loggers() -> set[logging.Logger]: Returns: set[logging.Logger]: Set of logger instances for the application and dependencies. """ - loggers = set() - # Create a snapshot of logger names to prevent race conditions during iteration - logger_names = list(logging.Logger.manager.loggerDict.keys()) - for name in logger_names: - if name.startswith((metadata.name, "httpx")): - loggers.add(logging.getLogger(name)) - return loggers + with _logger_cache_lock: + loggers = set() + # Create a copy of the logger dictionary for thread safety + logger_dict_copy = logging.Logger.manager.loggerDict.copy() + + for name in logger_dict_copy: + # Verify logger still exists and matches our namespace + if ( + name.startswith((metadata.name, "httpx")) + and name in logging.Logger.manager.loggerDict + ): + loggers.add(logging.getLogger(name)) + return loggers def get_logger() -> logging.Logger: @@ -381,7 +390,7 @@ def set_level(lvl: int | str, *, propagate: bool = False) -> None: logger.log(logging.INFO, f"{metadata.name} version {metadata.version}") logger.log(logging.INFO, f"Logging level set to {lvl}") - if propagate is True: + if propagate: # Clear cache to ensure we get all current loggers including httpx _get_loggers.cache_clear() for logger in _get_loggers(): @@ -543,6 +552,10 @@ def initialize() -> None: function while other threads are actively logging may result in lost log messages or exceptions. + This function should only be called once during application startup. + Repeated calls may cause memory leaks if loggers are created between + invocations. + Returns: None @@ -565,3 +578,6 @@ def initialize() -> None: logger.addHandler(stream_handler) logger.setLevel(NONE) logger.propagate = False + + # Force garbage collection of closed handlers to prevent memory leaks + gc.collect() diff --git a/src/ipsdk/platform.py b/src/ipsdk/platform.py index a02f0bd..41880d0 100644 --- a/src/ipsdk/platform.py +++ b/src/ipsdk/platform.py @@ -222,6 +222,8 @@ async def create_workflow(name): print(f"Request failed with status {response.status_code}") """ +from typing import TYPE_CHECKING + import httpx from . import connection @@ -592,13 +594,21 @@ async def authenticate_oauth(self) -> None: raise exceptions.RequestError(exc) -# Define type aliases for the dynamically created classes -Platform = type("Platform", (AuthMixin, connection.Connection), {}) -AsyncPlatform = type("AsyncPlatform", (AsyncAuthMixin, connection.AsyncConnection), {}) +# Define dynamically created classes for runtime and type checking +if TYPE_CHECKING: + # For type checkers: provide explicit class definitions + class Platform(AuthMixin, connection.Connection): + """Synchronous Platform client with authentication.""" + + class AsyncPlatform(AsyncAuthMixin, connection.AsyncConnection): + """Asynchronous Platform client with authentication.""" -# Type aliases for mypy -PlatformType = Platform -AsyncPlatformType = AsyncPlatform +else: + # For runtime: use dynamic type creation for flexibility + Platform = type("Platform", (AuthMixin, connection.Connection), {}) + AsyncPlatform = type( + "AsyncPlatform", (AsyncAuthMixin, connection.AsyncConnection), {} + ) @logging.trace diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index c6d7c1c..d96970e 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -60,16 +60,16 @@ def test_inheritance(self): def test_request_property_without_exception(self): """Test request property when no httpx exception was provided.""" exc = exceptions.IpsdkError("Test error") - # Accessing request without an httpx exception will raise AttributeError - with pytest.raises(AttributeError): - _ = exc.request + # Accessing request without an httpx exception now returns None + # (safer than AttributeError) + assert exc.request is None def test_response_property_without_exception(self): """Test response property when no httpx exception was provided.""" exc = exceptions.IpsdkError("Test error") - # Accessing response without an httpx exception will raise AttributeError - with pytest.raises(AttributeError): - _ = exc.response + # Accessing response without an httpx exception now returns None + # (safer than AttributeError) + assert exc.response is None class TestRequestError: diff --git a/tests/test_gateway.py b/tests/test_gateway.py index def1828..05eb402 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -12,10 +12,9 @@ from ipsdk import exceptions from ipsdk.connection import AsyncConnection from ipsdk.gateway import AsyncAuthMixin -from ipsdk.gateway import AsyncGatewayType +from ipsdk.gateway import AsyncGateway from ipsdk.gateway import AuthMixin from ipsdk.gateway import Gateway -from ipsdk.gateway import GatewayType from ipsdk.gateway import _make_body from ipsdk.gateway import _make_headers from ipsdk.gateway import _make_path @@ -219,11 +218,11 @@ def test_gateway_factory_async_with_all_parameters(): def test_gateway_type_aliases(): - """Test that gateway type aliases are correctly defined.""" + """Test that gateway classes are correctly defined.""" - # Verify type aliases exist and are not None - assert GatewayType is not None - assert AsyncGatewayType is not None + # Verify gateway classes exist and are not None + assert Gateway is not None + assert AsyncGateway is not None def test_make_body_empty_strings(): diff --git a/tests/test_platform.py b/tests/test_platform.py index ef0f5ab..7bb7486 100644 --- a/tests/test_platform.py +++ b/tests/test_platform.py @@ -18,10 +18,9 @@ from ipsdk.platform import _OAUTH_HEADERS from ipsdk.platform import _OAUTH_PATH from ipsdk.platform import AsyncAuthMixin -from ipsdk.platform import AsyncPlatformType +from ipsdk.platform import AsyncPlatform from ipsdk.platform import AuthMixin from ipsdk.platform import Platform -from ipsdk.platform import PlatformType from ipsdk.platform import _make_basicauth_body from ipsdk.platform import _make_oauth_body from ipsdk.platform import platform_factory @@ -328,11 +327,11 @@ def test_platform_oauth_token_handling(): def test_platform_type_aliases(): - """Test that platform type aliases are correctly defined.""" + """Test that platform classes are correctly defined.""" - # Verify type aliases exist - assert PlatformType is not None - assert AsyncPlatformType is not None + # Verify platform classes exist + assert Platform is not None + assert AsyncPlatform is not None def test_platform_factory_with_all_parameters():