diff --git a/.jules/bolt.md b/.jules/bolt.md deleted file mode 100644 index e61cb3e..0000000 --- a/.jules/bolt.md +++ /dev/null @@ -1,3 +0,0 @@ -## 2024-05-18 - [Fix N+1 query] -**Learning:** Fixed N+1 queries in `rstate` and `onQuery` by bulk fetching data instead of using `rquery` inside a loop. -**Action:** Always check for repeated fetches in loops that can be optimized to a bulk fetch. diff --git a/action_devices.py b/action_devices.py index 2f660d0..09e8575 100644 --- a/action_devices.py +++ b/action_devices.py @@ -73,21 +73,16 @@ def rstate(user_id=None): if not devices_data: return {"devices": {"states": {}}} + devices = list(devices_data.keys()) payload = { "devices": { "states": {} } } - for device, raw_data in devices_data.items(): + for device in devices: device = str(device) logger.debug('Getting Device status from: %s', device) - - # Use pre-fetched state if available in snapshot, else fallback - if isinstance(raw_data, dict) and 'states' in raw_data: - state_data = raw_data['states'] - else: - state_data = rquery(device, user_id=user_id) - + state_data = rquery(device, user_id=user_id) if state_data: payload['devices']['states'][device] = state_data logger.debug('Device state: %s', state_data) @@ -204,24 +199,10 @@ def onQuery(body, user_id=None): payload = { "devices": {}, } - - # Optimize: bulk fetch device states to avoid N+1 queries - devices_data = _get_scoped_snapshot(user_id) or {} - for i in body['inputs']: for device in i['payload']['devices']: deviceId = device['id'] - - # Try getting from cache first - data = None - raw_data = devices_data.get(deviceId) - if isinstance(raw_data, dict) and 'states' in raw_data: - data = raw_data['states'] - - # Fallback to single fetch and validation logic in rquery - if data is None: - data = rquery(deviceId, user_id=user_id) - + data = rquery(deviceId, user_id=user_id) payload['devices'][deviceId] = data return payload except Exception as e: diff --git a/firebase_utils.py b/firebase_utils.py index b99c9cf..cb20d1a 100644 --- a/firebase_utils.py +++ b/firebase_utils.py @@ -93,7 +93,7 @@ def reference(user_id=None): if user_scope: return db.reference(f'/users/{user_scope}/devices') return db.reference('/devices') - except ValueError as e: + except Exception as e: # Firebase is installed but not initialized (e.g. missing credentials in dev) logger.warning("Firebase not initialized, falling back to mock data: %s", e) return MockRef() diff --git a/my_oauth.py b/my_oauth.py index 535b7d6..c69c9e5 100644 --- a/my_oauth.py +++ b/my_oauth.py @@ -8,7 +8,7 @@ from authlib.oauth2.rfc6750 import BearerTokenValidator from flask import session from flask_login import current_user -from sqlalchemy import select, delete +from sqlalchemy import select from models import db from models import Client, Token, Grant, User @@ -51,12 +51,14 @@ def load_client(client_id): def save_token(token_data, request): logger.debug("token setter") # make sure that every client has only one token connected to a user - db.session.execute( - delete(Token).filter_by( + existing_tokens = db.session.execute( + select(Token).filter_by( client_id=request.client.client_id, user_id=request.user.id, ) - ) + ).scalars() + for t in existing_tokens: + db.session.delete(t) raw_expires_in = token_data.get('expires_in') try: diff --git a/notifications.py b/notifications.py index 2bd9516..2d18957 100644 --- a/notifications.py +++ b/notifications.py @@ -86,29 +86,6 @@ def handle_disconnect(_client, _userdata, rc): _append_mqtt_log('system', f'Unexpected disconnect (rc={rc})', 'Disconnected') -def _handle_status_message(user_id, device_id, payload): - """Update Firebase device status based on MQTT payload.""" - # Guard JSON decoding to avoid noisy errors when payloads are already decoded or non-JSON - state_updates = None - - if isinstance(payload, dict): - state_updates = payload - elif isinstance(payload, str): - try: - state_updates = json.loads(payload) - except (ValueError, TypeError): - logger.debug("Non-JSON status payload for %s/%s; skipping Firebase update", user_id, device_id) - - if state_updates is not None: - try: - ref = _get_user_device_states_ref(user_id, device_id) - if ref: - ref.update(state_updates) - logger.debug("Updated Firebase status for %s/%s", user_id, device_id) - except Exception as e: - logger.error("Failed to update Firebase from MQTT: %s", e) - - @mqtt.on_message() def handle_messages(_client, _userdata, message): payload = _decode_payload(message.payload) @@ -121,16 +98,33 @@ def handle_messages(_client, _userdata, message): # {user_id}/{device_id}/status parts = topic.split('/') - if len(parts) < 3: - _append_mqtt_log(topic, payload, 'Received', user_id=None) - return - - user_id = parts[0] - device_id = parts[1] - msg_type = parts[2] - - if msg_type == 'status': - _handle_status_message(user_id, device_id, payload) + user_id = None + if len(parts) >= 3: + user_id = parts[0] + device_id = parts[1] + msg_type = parts[2] + + # Try to update Firebase if it's a status message + if msg_type == 'status': + # Guard JSON decoding to avoid noisy errors when payloads are already decoded or non-JSON + state_updates = None + + if isinstance(payload, dict): + state_updates = payload + elif isinstance(payload, str): + try: + state_updates = json.loads(payload) + except (ValueError, TypeError): + logger.debug("Non-JSON status payload for %s/%s; skipping Firebase update", user_id, device_id) + + if state_updates is not None: + try: + ref = _get_user_device_states_ref(user_id, device_id) + if ref: + ref.update(state_updates) + logger.debug("Updated Firebase status for %s/%s", user_id, device_id) + except Exception as e: + logger.error("Failed to update Firebase from MQTT: %s", e) _append_mqtt_log(topic, payload, 'Received', user_id=user_id) diff --git a/routes.py b/routes.py index c3c63fa..6d0efe5 100644 --- a/routes.py +++ b/routes.py @@ -19,39 +19,32 @@ def _oauth_error_response(exc): """Return OAuth errors in a client-friendly way when redirect URI is valid.""" - status_code = getattr(exc, 'status_code', 400) or 400 - error_response = jsonify(error=exc.error, error_description=exc.description), status_code - client_id = request.values.get('client_id') redirect_uri = request.values.get('redirect_uri') state = request.values.get('state') client = load_client(client_id) if client_id else None - if not client: - return error_response - - if not redirect_uri: - redirect_uri = client.get_default_redirect_uri() - - if not redirect_uri or not client.check_redirect_uri(redirect_uri): - return error_response + if client: + if not redirect_uri: + redirect_uri = client.get_default_redirect_uri() + if redirect_uri and client.check_redirect_uri(redirect_uri): + parsed = urlparse(redirect_uri) + if parsed.scheme and parsed.netloc: + params = {'error': exc.error} + if exc.description: + params['error_description'] = exc.description + if state: + params['state'] = state + + existing_query = dict(parse_qsl(parsed.query, keep_blank_values=True)) + existing_query.update(params) + safe_redirect_uri = urlunparse( + parsed._replace(query=urlencode(existing_query)) + ) + return redirect(safe_redirect_uri) - parsed = urlparse(redirect_uri) - if not (parsed.scheme and parsed.netloc): - return error_response - - params = {'error': exc.error} - if exc.description: - params['error_description'] = exc.description - if state: - params['state'] = state - - existing_query = dict(parse_qsl(parsed.query, keep_blank_values=True)) - existing_query.update(params) - safe_redirect_uri = urlunparse( - parsed._replace(query=urlencode(existing_query)) - ) - return redirect(safe_redirect_uri) + status_code = getattr(exc, 'status_code', 400) or 400 + return jsonify(error=exc.error, error_description=exc.description), status_code def _resolve_smarthome_user_scope(req): diff --git a/tests/test_firebase_utils.py b/tests/test_firebase_utils.py index 3045435..966e1ae 100644 --- a/tests/test_firebase_utils.py +++ b/tests/test_firebase_utils.py @@ -1,45 +1,109 @@ import unittest -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch +import firebase_utils + class TestFirebaseUtils(unittest.TestCase): - def test_reference_value_error_fallback(self): - # We must clear the module to reload it with our mocked FIREBASE_AVAILABLE state - import sys - if 'firebase_utils' in sys.modules: - del sys.modules['firebase_utils'] - # Mock firebase_admin and db to simulate an uninitialized state where db.reference raises ValueError + def test_reference_root_success(self): + # Mock db as it might not be imported if firebase_admin is missing + mock_db = MagicMock() + with patch('firebase_utils.FIREBASE_AVAILABLE', True), \ + patch('firebase_utils.db', mock_db, create=True): + ref = firebase_utils.reference() + mock_db.reference.assert_called_with('/devices') + self.assertNotIsInstance(ref, firebase_utils.MockRef) + + def test_reference_user_success(self): mock_db = MagicMock() - mock_db.reference.side_effect = ValueError("The default Firebase app does not exist.") + with patch('firebase_utils.FIREBASE_AVAILABLE', True), \ + patch('firebase_utils.db', mock_db, create=True): + ref = firebase_utils.reference(user_id="user123") + mock_db.reference.assert_called_with('/users/user123/devices') + self.assertNotIsInstance(ref, firebase_utils.MockRef) - with patch.dict(sys.modules, {'firebase_admin': MagicMock(db=mock_db)}): - import firebase_utils + def test_reference_exception_path(self): + # Ensure FIREBASE_AVAILABLE is True for this test + mock_db = MagicMock() + mock_db.reference.side_effect = Exception("Firebase initialization error") + with patch('firebase_utils.FIREBASE_AVAILABLE', True), \ + patch('firebase_utils.db', mock_db, create=True), \ + patch('firebase_utils.logger.warning') as mock_warning: + ref = firebase_utils.reference() - # FIREBASE_AVAILABLE should be True because the import succeeds (mocked) - self.assertTrue(firebase_utils.FIREBASE_AVAILABLE) + # Verify it returns a MockRef instance + self.assertIsInstance(ref, firebase_utils.MockRef) - with patch('firebase_utils.logger.warning') as mock_logger: - ref = firebase_utils.reference() + # Verify the warning was logged + mock_warning.assert_called() + self.assertIn("Firebase not initialized", mock_warning.call_args[0][0]) - # Should fallback to MockRef - self.assertIsInstance(ref, firebase_utils.MockRef) - mock_logger.assert_called_once() - self.assertIn("Firebase not initialized", mock_logger.call_args[0][0]) + def test_reference_firebase_not_available(self): + with patch('firebase_utils.FIREBASE_AVAILABLE', False): + ref = firebase_utils.reference() + self.assertIsInstance(ref, firebase_utils.MockRef) - def test_reference_other_exception_bubbles_up(self): - import sys - if 'firebase_utils' in sys.modules: - del sys.modules['firebase_utils'] + def test_normalize_user_scope(self): + self.assertEqual(firebase_utils._normalize_user_scope("user123"), "user123") + self.assertEqual(firebase_utils._normalize_user_scope(123), "123") + self.assertEqual(firebase_utils._normalize_user_scope(" user123 "), "user123") + self.assertIsNone(firebase_utils._normalize_user_scope(None)) + self.assertIsNone(firebase_utils._normalize_user_scope("")) + self.assertIsNone(firebase_utils._normalize_user_scope(" ")) + self.assertIsNone(firebase_utils._normalize_user_scope("user/123")) + self.assertIsNone(firebase_utils._normalize_user_scope("user\\123")) + self.assertIsNone(firebase_utils._normalize_user_scope("user..123")) + def test_get_user_device_states_ref_valid(self): mock_db = MagicMock() - # Some other error like network permission denied - mock_db.reference.side_effect = PermissionError("Permission denied.") + mock_ref = MagicMock() + mock_db.reference.return_value = mock_ref + mock_device_ref = MagicMock() + mock_ref.child.return_value = mock_device_ref + mock_states_ref = MagicMock() + mock_device_ref.child.return_value = mock_states_ref + + with patch('firebase_utils.FIREBASE_AVAILABLE', True), \ + patch('firebase_utils.db', mock_db, create=True): + + ref = firebase_utils._get_user_device_states_ref("user123", "device1") + + mock_db.reference.assert_called_with('/users/user123/devices') + mock_ref.child.assert_called_with('device1') + mock_device_ref.child.assert_called_with('states') + self.assertEqual(ref, mock_states_ref) + + def test_get_user_device_states_ref_invalid(self): + self.assertIsNone(firebase_utils._get_user_device_states_ref(None, "device1")) + self.assertIsNone(firebase_utils._get_user_device_states_ref("user/123", "device1")) + self.assertIsNone(firebase_utils._get_user_device_states_ref("user1", None)) + self.assertIsNone(firebase_utils._get_user_device_states_ref("user1", "device/1")) + + def test_mock_ref_and_child(self): + ref = firebase_utils.MockRef() + data = ref.get() + self.assertEqual(data, firebase_utils.MOCK_DEVICES) + + child = ref.child("test-light-1") + self.assertIsInstance(child, firebase_utils.MockChild) + self.assertEqual(child.get(), firebase_utils.MOCK_DEVICES["test-light-1"]) + + grandchild = child.child("name") + self.assertEqual(grandchild.get(), firebase_utils.MOCK_DEVICES["test-light-1"]["name"]) + + # Test non-existent path + self.assertIsNone(ref.child("non-existent").get()) - with patch.dict(sys.modules, {'firebase_admin': MagicMock(db=mock_db)}): - import firebase_utils + # Test update + # Careful as MOCK_DEVICES is shared. + original_states = firebase_utils.MOCK_DEVICES["test-light-1"]["states"].copy() + try: + update_data = {"on": not original_states["on"]} + child.child("states").update(update_data) + self.assertEqual(firebase_utils.MOCK_DEVICES["test-light-1"]["states"]["on"], not original_states["on"]) + finally: + firebase_utils.MOCK_DEVICES["test-light-1"]["states"] = original_states - with self.assertRaises(PermissionError): - firebase_utils.reference() if __name__ == '__main__': unittest.main() diff --git a/tests/test_multi_tenant.py b/tests/test_multi_tenant.py index de9e937..8adb362 100644 --- a/tests/test_multi_tenant.py +++ b/tests/test_multi_tenant.py @@ -19,12 +19,12 @@ def setUp(self): def tearDown(self): self.ctx.pop() - @patch('action_devices.reference') + @patch('firebase_utils.db') @patch('firebase_utils.FIREBASE_AVAILABLE', True) - def test_on_sync_scoped_path(self, mock_reference): + def test_on_sync_scoped_path(self, mock_db): # Mock Firebase response for user 123 mock_ref = MagicMock() - mock_reference.return_value = mock_ref + mock_db.reference.return_value = mock_ref mock_ref.get.return_value = { "device1": {"name": {"name": "Light 1"}, "type": "light"} } @@ -33,15 +33,15 @@ def test_on_sync_scoped_path(self, mock_reference): response = onSync(user_id="123") # Verify correct Firebase path was hit - mock_reference.assert_called_with('123') + mock_db.reference.assert_called_with('/users/123/devices') self.assertEqual(response['agentUserId'], "123") self.assertEqual(len(response['devices']), 1) self.assertEqual(response['devices'][0]['id'], "device1") @patch('action_devices.mqtt') - @patch('action_devices.reference') + @patch('firebase_utils.db') @patch('firebase_utils.FIREBASE_AVAILABLE', True) - def test_on_execute_mqtt_topic(self, mock_reference, mock_mqtt): + def test_on_execute_mqtt_topic(self, mock_db, mock_mqtt): # Mock request for user 456 req = { "requestId": "req1", @@ -61,7 +61,7 @@ def test_on_execute_mqtt_topic(self, mock_reference, mock_mqtt): # Mock Firebase update mock_ref = MagicMock() - mock_reference.return_value = mock_ref + mock_db.reference.return_value = mock_ref # Call actions actions(req, user_id="456") @@ -72,24 +72,31 @@ def test_on_execute_mqtt_topic(self, mock_reference, mock_mqtt): call_args = mock_mqtt.publish.call_args self.assertEqual(call_args.kwargs['topic'], expected_topic) - @patch('notifications._get_user_device_states_ref') + @patch('firebase_utils.db') @patch('firebase_utils.FIREBASE_AVAILABLE', True) - def test_mqtt_status_update_firebase_path(self, mock_reference): + def test_mqtt_status_update_firebase_path(self, mock_db): # Mock MQTT message: user 789 reports status for lamp1 mock_message = MagicMock() mock_message.topic = "789/lamp1/status" mock_message.payload = b'{"on": false, "online": true}' # Mock reference chain + mock_user_ref = MagicMock() + mock_device_ref = MagicMock() + mock_states_ref = MagicMock() + mock_db.reference.return_value = mock_user_ref + mock_user_ref.child.return_value = mock_device_ref + mock_device_ref.child.return_value = mock_states_ref # Call handle_messages handle_messages(None, None, mock_message) # Verify Firebase update path is scoped correctly - mock_reference.assert_called_with('789', 'lamp1') - - mock_reference.return_value.update.assert_called_with({"on": False, "online": True}) + mock_db.reference.assert_called_with('/users/789/devices') + mock_user_ref.child.assert_called_with('lamp1') + mock_device_ref.child.assert_called_with('states') + mock_states_ref.update.assert_called_with({"on": False, "online": True}) def test_mqtt_log_filtering(self): # Clear logs (since they are in-memory deque) diff --git a/tests/test_oauth_optimization.py b/tests/test_oauth_optimization.py deleted file mode 100644 index be99c4c..0000000 --- a/tests/test_oauth_optimization.py +++ /dev/null @@ -1,69 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch -import sys -import os - -# Add root directory to path -sys.path.insert(0, os.path.abspath(os.curdir)) - -class TestOAuthOptimization(unittest.TestCase): - def test_save_token_uses_bulk_delete(self): - """ - Verify that save_token uses a bulk delete instead of a loop. - """ - # Always use a fully mocked approach to avoid issues with missing dependencies - # or partial environments in both CI and local. - - mock_db = MagicMock() - mock_models = MagicMock() - mock_models.db = mock_db - mock_token_cls = MagicMock() - mock_models.Token = mock_token_cls - - mock_sqlalchemy = MagicMock() - mock_delete_query = MagicMock() - mock_sqlalchemy.delete.return_value = mock_delete_query - mock_delete_query.filter_by.return_value = mock_delete_query - - # Mock all possible dependencies to prevent ImportErrors during the test - mock_modules = { - 'flask': MagicMock(), - 'flask.debughelpers': MagicMock(), - 'flask_login': MagicMock(), - 'authlib.integrations.flask_oauth2': MagicMock(), - 'authlib.oauth2.rfc6749': MagicMock(), - 'authlib.oauth2.rfc6750': MagicMock(), - 'sqlalchemy': mock_sqlalchemy, - 'models': mock_models - } - - # Using patch.dict on sys.modules is safe and isolates the test - with patch.dict(sys.modules, mock_modules): - # Ensure my_oauth is reloaded within this mocked context - if 'my_oauth' in sys.modules: - del sys.modules['my_oauth'] - import my_oauth - - mock_request = MagicMock() - mock_request.client.client_id = 'test_client' - mock_request.user.id = 123 - - token_data = { - 'access_token': 'new_token', - 'expires_in': 3600, - 'token_type': 'Bearer', - 'scope': 'profile' - } - - # Call the function - my_oauth.save_token(token_data, mock_request) - - # Verify bulk delete was called - mock_db.session.execute.assert_any_call(mock_delete_query) - mock_sqlalchemy.delete.assert_called_once_with(mock_token_cls) - - # Verify the old iterative delete is NOT called - self.assertFalse(mock_db.session.delete.called) - -if __name__ == '__main__': - unittest.main()