Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .jules/bolt.md

This file was deleted.

27 changes: 4 additions & 23 deletions action_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,21 +73,16 @@ def rstate(user_id=None):
if not devices_data:
return {"devices": {"states": {}}}

devices = list(devices_data.keys())
payload = {
"devices": {
"states": {}
}
}
for device, raw_data in devices_data.items():
for device in devices:
device = str(device)
logger.debug('Getting Device status from: %s', device)

# Use pre-fetched state if available in snapshot, else fallback
if isinstance(raw_data, dict) and 'states' in raw_data:
state_data = raw_data['states']
else:
state_data = rquery(device, user_id=user_id)

state_data = rquery(device, user_id=user_id)
if state_data:
payload['devices']['states'][device] = state_data
logger.debug('Device state: %s', state_data)
Expand Down Expand Up @@ -204,24 +199,10 @@ def onQuery(body, user_id=None):
payload = {
"devices": {},
}

# Optimize: bulk fetch device states to avoid N+1 queries
devices_data = _get_scoped_snapshot(user_id) or {}

for i in body['inputs']:
for device in i['payload']['devices']:
deviceId = device['id']

# Try getting from cache first
data = None
raw_data = devices_data.get(deviceId)
if isinstance(raw_data, dict) and 'states' in raw_data:
data = raw_data['states']

# Fallback to single fetch and validation logic in rquery
if data is None:
data = rquery(deviceId, user_id=user_id)

data = rquery(deviceId, user_id=user_id)
payload['devices'][deviceId] = data
return payload
except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion firebase_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def reference(user_id=None):
if user_scope:
return db.reference(f'/users/{user_scope}/devices')
return db.reference('/devices')
except ValueError as e:
except Exception as e:
# Firebase is installed but not initialized (e.g. missing credentials in dev)
logger.warning("Firebase not initialized, falling back to mock data: %s", e)
return MockRef()
Expand Down
10 changes: 6 additions & 4 deletions my_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from authlib.oauth2.rfc6750 import BearerTokenValidator
from flask import session
from flask_login import current_user
from sqlalchemy import select, delete
from sqlalchemy import select
from models import db
from models import Client, Token, Grant, User

Expand Down Expand Up @@ -51,12 +51,14 @@ def load_client(client_id):
def save_token(token_data, request):
logger.debug("token setter")
# make sure that every client has only one token connected to a user
db.session.execute(
delete(Token).filter_by(
existing_tokens = db.session.execute(
select(Token).filter_by(
client_id=request.client.client_id,
user_id=request.user.id,
)
)
).scalars()
for t in existing_tokens:
db.session.delete(t)

raw_expires_in = token_data.get('expires_in')
try:
Expand Down
60 changes: 27 additions & 33 deletions notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,29 +86,6 @@ def handle_disconnect(_client, _userdata, rc):
_append_mqtt_log('system', f'Unexpected disconnect (rc={rc})', 'Disconnected')


def _handle_status_message(user_id, device_id, payload):
"""Update Firebase device status based on MQTT payload."""
# Guard JSON decoding to avoid noisy errors when payloads are already decoded or non-JSON
state_updates = None

if isinstance(payload, dict):
state_updates = payload
elif isinstance(payload, str):
try:
state_updates = json.loads(payload)
except (ValueError, TypeError):
logger.debug("Non-JSON status payload for %s/%s; skipping Firebase update", user_id, device_id)

if state_updates is not None:
try:
ref = _get_user_device_states_ref(user_id, device_id)
if ref:
ref.update(state_updates)
logger.debug("Updated Firebase status for %s/%s", user_id, device_id)
except Exception as e:
logger.error("Failed to update Firebase from MQTT: %s", e)


@mqtt.on_message()
def handle_messages(_client, _userdata, message):
payload = _decode_payload(message.payload)
Expand All @@ -121,16 +98,33 @@ def handle_messages(_client, _userdata, message):
# {user_id}/{device_id}/status

parts = topic.split('/')
if len(parts) < 3:
_append_mqtt_log(topic, payload, 'Received', user_id=None)
return

user_id = parts[0]
device_id = parts[1]
msg_type = parts[2]

if msg_type == 'status':
_handle_status_message(user_id, device_id, payload)
user_id = None
if len(parts) >= 3:
user_id = parts[0]
device_id = parts[1]
msg_type = parts[2]

# Try to update Firebase if it's a status message
if msg_type == 'status':
# Guard JSON decoding to avoid noisy errors when payloads are already decoded or non-JSON
state_updates = None

if isinstance(payload, dict):
state_updates = payload
elif isinstance(payload, str):
try:
state_updates = json.loads(payload)
except (ValueError, TypeError):
logger.debug("Non-JSON status payload for %s/%s; skipping Firebase update", user_id, device_id)

if state_updates is not None:
try:
ref = _get_user_device_states_ref(user_id, device_id)
if ref:
ref.update(state_updates)
logger.debug("Updated Firebase status for %s/%s", user_id, device_id)
except Exception as e:
logger.error("Failed to update Firebase from MQTT: %s", e)

_append_mqtt_log(topic, payload, 'Received', user_id=user_id)

Expand Down
47 changes: 20 additions & 27 deletions routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,39 +19,32 @@

def _oauth_error_response(exc):
"""Return OAuth errors in a client-friendly way when redirect URI is valid."""
status_code = getattr(exc, 'status_code', 400) or 400
error_response = jsonify(error=exc.error, error_description=exc.description), status_code

client_id = request.values.get('client_id')
redirect_uri = request.values.get('redirect_uri')
state = request.values.get('state')

client = load_client(client_id) if client_id else None
if not client:
return error_response

if not redirect_uri:
redirect_uri = client.get_default_redirect_uri()

if not redirect_uri or not client.check_redirect_uri(redirect_uri):
return error_response
if client:
if not redirect_uri:
redirect_uri = client.get_default_redirect_uri()
if redirect_uri and client.check_redirect_uri(redirect_uri):
parsed = urlparse(redirect_uri)
if parsed.scheme and parsed.netloc:
params = {'error': exc.error}
if exc.description:
params['error_description'] = exc.description
if state:
params['state'] = state

existing_query = dict(parse_qsl(parsed.query, keep_blank_values=True))
existing_query.update(params)
safe_redirect_uri = urlunparse(
parsed._replace(query=urlencode(existing_query))
)
return redirect(safe_redirect_uri)

parsed = urlparse(redirect_uri)
if not (parsed.scheme and parsed.netloc):
return error_response

params = {'error': exc.error}
if exc.description:
params['error_description'] = exc.description
if state:
params['state'] = state

existing_query = dict(parse_qsl(parsed.query, keep_blank_values=True))
existing_query.update(params)
safe_redirect_uri = urlunparse(
parsed._replace(query=urlencode(existing_query))
)
return redirect(safe_redirect_uri)
status_code = getattr(exc, 'status_code', 400) or 400
return jsonify(error=exc.error, error_description=exc.description), status_code


def _resolve_smarthome_user_scope(req):
Expand Down
120 changes: 92 additions & 28 deletions tests/test_firebase_utils.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,109 @@
import unittest
from unittest.mock import patch, MagicMock
from unittest.mock import MagicMock, patch
import firebase_utils


class TestFirebaseUtils(unittest.TestCase):
def test_reference_value_error_fallback(self):
# We must clear the module to reload it with our mocked FIREBASE_AVAILABLE state
import sys
if 'firebase_utils' in sys.modules:
del sys.modules['firebase_utils']

# Mock firebase_admin and db to simulate an uninitialized state where db.reference raises ValueError
def test_reference_root_success(self):
# Mock db as it might not be imported if firebase_admin is missing
mock_db = MagicMock()
with patch('firebase_utils.FIREBASE_AVAILABLE', True), \
patch('firebase_utils.db', mock_db, create=True):
ref = firebase_utils.reference()
mock_db.reference.assert_called_with('/devices')
self.assertNotIsInstance(ref, firebase_utils.MockRef)

def test_reference_user_success(self):
mock_db = MagicMock()
mock_db.reference.side_effect = ValueError("The default Firebase app does not exist.")
with patch('firebase_utils.FIREBASE_AVAILABLE', True), \
patch('firebase_utils.db', mock_db, create=True):
ref = firebase_utils.reference(user_id="user123")
mock_db.reference.assert_called_with('/users/user123/devices')
self.assertNotIsInstance(ref, firebase_utils.MockRef)

with patch.dict(sys.modules, {'firebase_admin': MagicMock(db=mock_db)}):
import firebase_utils
def test_reference_exception_path(self):
# Ensure FIREBASE_AVAILABLE is True for this test
mock_db = MagicMock()
mock_db.reference.side_effect = Exception("Firebase initialization error")
with patch('firebase_utils.FIREBASE_AVAILABLE', True), \
patch('firebase_utils.db', mock_db, create=True), \
patch('firebase_utils.logger.warning') as mock_warning:
ref = firebase_utils.reference()

# FIREBASE_AVAILABLE should be True because the import succeeds (mocked)
self.assertTrue(firebase_utils.FIREBASE_AVAILABLE)
# Verify it returns a MockRef instance
self.assertIsInstance(ref, firebase_utils.MockRef)

with patch('firebase_utils.logger.warning') as mock_logger:
ref = firebase_utils.reference()
# Verify the warning was logged
mock_warning.assert_called()
self.assertIn("Firebase not initialized", mock_warning.call_args[0][0])

# Should fallback to MockRef
self.assertIsInstance(ref, firebase_utils.MockRef)
mock_logger.assert_called_once()
self.assertIn("Firebase not initialized", mock_logger.call_args[0][0])
def test_reference_firebase_not_available(self):
with patch('firebase_utils.FIREBASE_AVAILABLE', False):
ref = firebase_utils.reference()
self.assertIsInstance(ref, firebase_utils.MockRef)

def test_reference_other_exception_bubbles_up(self):
import sys
if 'firebase_utils' in sys.modules:
del sys.modules['firebase_utils']
def test_normalize_user_scope(self):
self.assertEqual(firebase_utils._normalize_user_scope("user123"), "user123")
self.assertEqual(firebase_utils._normalize_user_scope(123), "123")
self.assertEqual(firebase_utils._normalize_user_scope(" user123 "), "user123")
self.assertIsNone(firebase_utils._normalize_user_scope(None))
self.assertIsNone(firebase_utils._normalize_user_scope(""))
self.assertIsNone(firebase_utils._normalize_user_scope(" "))
self.assertIsNone(firebase_utils._normalize_user_scope("user/123"))
self.assertIsNone(firebase_utils._normalize_user_scope("user\\123"))
self.assertIsNone(firebase_utils._normalize_user_scope("user..123"))

def test_get_user_device_states_ref_valid(self):
Comment on lines +46 to +57

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Add integration tests that exercise _normalize_user_scope via reference() to prove the normalized value is actually used in paths.

Currently _normalize_user_scope is only tested in isolation. Please add an integration-style test for reference(user_id=...) that asserts the Firebase path uses the normalized value, e.g.:

  • reference(user_id=" user123 ") calls mock_db.reference("/users/user123/devices")
  • reference(user_id=123) calls mock_db.reference("/users/123/devices")
    This guards against regressions where reference stops using _normalize_user_scope correctly.
Suggested change
def test_normalize_user_scope(self):
self.assertEqual(firebase_utils._normalize_user_scope("user123"), "user123")
self.assertEqual(firebase_utils._normalize_user_scope(123), "123")
self.assertEqual(firebase_utils._normalize_user_scope(" user123 "), "user123")
self.assertIsNone(firebase_utils._normalize_user_scope(None))
self.assertIsNone(firebase_utils._normalize_user_scope(""))
self.assertIsNone(firebase_utils._normalize_user_scope(" "))
self.assertIsNone(firebase_utils._normalize_user_scope("user/123"))
self.assertIsNone(firebase_utils._normalize_user_scope("user\\123"))
self.assertIsNone(firebase_utils._normalize_user_scope("user..123"))
def test_get_user_device_states_ref_valid(self):
def test_normalize_user_scope(self):
self.assertEqual(firebase_utils._normalize_user_scope("user123"), "user123")
self.assertEqual(firebase_utils._normalize_user_scope(123), "123")
self.assertEqual(firebase_utils._normalize_user_scope(" user123 "), "user123")
self.assertIsNone(firebase_utils._normalize_user_scope(None))
self.assertIsNone(firebase_utils._normalize_user_scope(""))
self.assertIsNone(firebase_utils._normalize_user_scope(" "))
self.assertIsNone(firebase_utils._normalize_user_scope("user/123"))
self.assertIsNone(firebase_utils._normalize_user_scope("user\\123"))
self.assertIsNone(firebase_utils._normalize_user_scope("user..123"))
def test_reference_uses_normalized_user_scope_str_input(self):
with patch('firebase_utils.FIREBASE_AVAILABLE', True):
# Ensure reference() routes to the real db mock and uses the normalized user_id in the path
firebase_utils.db = mock_db
mock_db.reference.reset_mock()
firebase_utils.reference(user_id=" user123 ")
mock_db.reference.assert_called_once_with("/users/user123/devices")
def test_reference_uses_normalized_user_scope_int_input(self):
with patch('firebase_utils.FIREBASE_AVAILABLE', True):
# Ensure integer user_id is normalized to its string representation in the path
firebase_utils.db = mock_db
mock_db.reference.reset_mock()
firebase_utils.reference(user_id=123)
mock_db.reference.assert_called_once_with("/users/123/devices")
def test_get_user_device_states_ref_valid(self):

mock_db = MagicMock()
# Some other error like network permission denied
mock_db.reference.side_effect = PermissionError("Permission denied.")
mock_ref = MagicMock()
mock_db.reference.return_value = mock_ref
mock_device_ref = MagicMock()
mock_ref.child.return_value = mock_device_ref
mock_states_ref = MagicMock()
mock_device_ref.child.return_value = mock_states_ref

with patch('firebase_utils.FIREBASE_AVAILABLE', True), \
patch('firebase_utils.db', mock_db, create=True):

ref = firebase_utils._get_user_device_states_ref("user123", "device1")

mock_db.reference.assert_called_with('/users/user123/devices')
mock_ref.child.assert_called_with('device1')
mock_device_ref.child.assert_called_with('states')
self.assertEqual(ref, mock_states_ref)

def test_get_user_device_states_ref_invalid(self):
self.assertIsNone(firebase_utils._get_user_device_states_ref(None, "device1"))
self.assertIsNone(firebase_utils._get_user_device_states_ref("user/123", "device1"))
self.assertIsNone(firebase_utils._get_user_device_states_ref("user1", None))
self.assertIsNone(firebase_utils._get_user_device_states_ref("user1", "device/1"))

def test_mock_ref_and_child(self):
ref = firebase_utils.MockRef()
data = ref.get()
self.assertEqual(data, firebase_utils.MOCK_DEVICES)

child = ref.child("test-light-1")
self.assertIsInstance(child, firebase_utils.MockChild)
self.assertEqual(child.get(), firebase_utils.MOCK_DEVICES["test-light-1"])

grandchild = child.child("name")
self.assertEqual(grandchild.get(), firebase_utils.MOCK_DEVICES["test-light-1"]["name"])

# Test non-existent path
self.assertIsNone(ref.child("non-existent").get())

with patch.dict(sys.modules, {'firebase_admin': MagicMock(db=mock_db)}):
import firebase_utils
# Test update
# Careful as MOCK_DEVICES is shared.
original_states = firebase_utils.MOCK_DEVICES["test-light-1"]["states"].copy()
try:
update_data = {"on": not original_states["on"]}
child.child("states").update(update_data)
self.assertEqual(firebase_utils.MOCK_DEVICES["test-light-1"]["states"]["on"], not original_states["on"])
finally:
firebase_utils.MOCK_DEVICES["test-light-1"]["states"] = original_states

with self.assertRaises(PermissionError):
firebase_utils.reference()

if __name__ == '__main__':
unittest.main()
Loading