diff --git a/homeassistant_api/models/domains.py b/homeassistant_api/models/domains.py index 9c8c950..467defe 100644 --- a/homeassistant_api/models/domains.py +++ b/homeassistant_api/models/domains.py @@ -383,9 +383,9 @@ class ServiceFieldSelectorObject(BaseModel): class ServiceFieldSelectorQRCode(BaseModel): data: str scale: Optional[Union[int, float]] = None - error_correction_level: Optional[ - ServiceFieldSelectorQRCodeErrorCorrectionLevel - ] = None + error_correction_level: Optional[ServiceFieldSelectorQRCodeErrorCorrectionLevel] = ( + None + ) center_image: Optional[str] = None @@ -614,14 +614,13 @@ def trigger( async def async_trigger( self, **service_data - ) -> Union[Tuple[State, ...], Tuple[Tuple[State, ...], dict[str, JSONType]]]: + ) -> Union[ + Tuple[State, ...], + None, + dict[str, JSONType], + tuple[tuple[State, ...], dict[str, JSONType]], + ]: """Triggers the service associated with this object.""" - from homeassistant_api import WebsocketClient # prevent circular import - - if isinstance(self.domain._client, WebsocketClient): - raise NotImplementedError( - "WebsocketClient does not support async/await syntax." - ) try: return await self.domain._client.async_trigger_service_with_response( self.domain.domain_id, @@ -647,7 +646,12 @@ def __call__( Coroutine[ Any, Any, - Union[Tuple[State, ...], Tuple[Tuple[State, ...], dict[str, JSONType]]], + Union[ + Tuple[State, ...], + Tuple[Tuple[State, ...], dict[str, JSONType]], + dict[str, JSONType], + None, + ], ], ]: """ diff --git a/homeassistant_api/rawasyncwebsocket.py b/homeassistant_api/rawasyncwebsocket.py new file mode 100644 index 0000000..6c99ff6 --- /dev/null +++ b/homeassistant_api/rawasyncwebsocket.py @@ -0,0 +1,667 @@ +import contextlib +import json +import logging +import time +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Dict, + Optional, + Tuple, + Union, + cast, +) + +import websockets.asyncio.client as ws +from pydantic import ValidationError + +from homeassistant_api.errors import ( + ReceivingError, + ResponseError, + UnauthorizedError, +) +from homeassistant_api.models import ( + ConfigEntry, + ConfigEntryEvent, + ConfigSubEntry, + Domain, + Entity, + Group, + State, +) +from homeassistant_api.models.config_entries import DisableEnableResult, FlowResult +from homeassistant_api.models.states import Context +from homeassistant_api.models.websocket import ( + AuthInvalid, + AuthOk, + AuthRequired, + EventResponse, + FiredEvent, + FiredTrigger, + PingResponse, + ResultResponse, + TemplateEvent, +) +from homeassistant_api.rawbasewebsocket import RawBaseWebsocketClient +from homeassistant_api.utils import JSONType, prepare_entity_id + +if TYPE_CHECKING: + from homeassistant_api import WebsocketClient +else: + WebsocketClient = None # pylint: disable=invalid-name + +logger = logging.getLogger(__name__) + + +class RawAsyncWebsocketClient(RawBaseWebsocketClient): + _async_conn: Optional[ws.ClientConnection] + + def __init__(self, api_url: str, token: str) -> None: + super().__init__(api_url, token) + self._async_conn = None + + async def __aenter__(self): + self._async_conn = await ws.connect(self.api_url) + await self._async_conn.__aenter__() + okay = await self.async_authentication_phase() + logging.info("Authenticated with Home Assistant (%s)", okay.ha_version) + await self.async_supported_features_phase() + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + if not self._async_conn: + raise ReceivingError("Connection is not open!") + await self._async_conn.__aexit__(exc_type, exc_value, traceback) + self._async_conn = None + + async def _async_send(self, data: dict[str, JSONType]) -> None: + """Send a message to the websocket server.""" + logger.debug(f"Sending message: {data}") + if self._async_conn is None: + raise ReceivingError("Connection is not open!") + await self._async_conn.send(json.dumps(data)) + + async def _async_recv(self) -> dict[str, JSONType]: + """Receive a message from the websocket server.""" + if self._async_conn is None: + raise ReceivingError("Connection is not open!") + _bytes = await self._async_conn.recv() + logger.debug("Received message: %s", _bytes) + return cast(dict[str, JSONType], json.loads(_bytes)) + + async def async_send(self, type: str, include_id: bool = True, **data: Any) -> int: + """ + Send a command message to the websocket server and wait for a "result" response. + + Returns the id of the message sent. + """ + if include_id: # auth messages don't have an id + data["id"] = self._request_id() + + data["type"] = type + await self._async_send(data) + + if "id" in data: + assert isinstance(data["id"], int) + if data["type"] == "ping": + self._ping_responses[data["id"]] = PingResponse( + start=time.perf_counter_ns(), + id=data["id"], + type="pong", + ) + else: + self._event_responses[data["id"]] = [] + self._result_responses[data["id"]] = None + return data["id"] + return -1 # non-command messages don't have an id + + async def async_recv( + self, id: int + ) -> Union[EventResponse, ResultResponse, PingResponse]: + """Receive a response to a message from the websocket server.""" + while True: + ## have we received a message with the id we're looking for? + if self._result_responses.get(id) is not None: + return cast(dict[int, ResultResponse], self._result_responses).pop( + id + ) # ughhh why can't mypy figure this out + if self._event_responses.get(id, []): + return self._event_responses[id].pop(0) + if self._ping_responses.get(id) is not None: + if self._ping_responses[id].end is not None: + return self._ping_responses.pop(id) + + ## if not, keep receiving messages until we do + self.handle_recv(await self._async_recv()) + + async def async_authentication_phase(self) -> AuthOk: + """Authenticate with the websocket server.""" + # Capture the first message from the server saying we need to authenticate + try: + welcome = AuthRequired.model_validate(await self._async_recv()) + logger.debug(f"Received welcome message: {welcome}") + except ValidationError as e: + raise ResponseError("Unexpected response during authentication") from e + + # Send our authentication token + await self.async_send("auth", access_token=self.token, include_id=False) + logger.debug("Sent auth message") + + # Check the response + resp = await self._async_recv() + try: + return AuthOk.model_validate(resp) + except ValidationError as e: + error_resp = AuthInvalid.model_validate(resp) + raise UnauthorizedError(error_resp.message) from e + except Exception as e: + raise ResponseError( + "Unexpected response during authentication", resp["message"] + ) from e + + async def async_supported_features_phase(self) -> None: + """Get the supported features from the websocket server.""" + resp = await self.async_recv( + await self.async_send( + "supported_features", + features={ + # "coalesce_messages": 42, # including this key sets it to True + }, + ) + ) + assert cast(ResultResponse, resp).result is None + + async def async_ping_latency(self) -> float: + """Get the latency (in milliseconds) of the connection by sending a ping message.""" + pong = cast(PingResponse, await self.async_recv(await self.async_send("ping"))) + assert pong.end is not None + return (pong.end - pong.start) / 1_000_000 + + async def async_get_rendered_template(self, template: str) -> str: + """ + Renders a Jinja2 template with Home Assistant context data. + See https://www.home-assistant.io/docs/configuration/templating. + + Sends command :code:`{"type": "render_template", ...}`. + """ + id = await self.async_send( + "render_template", template=template, report_errors=True + ) + first = await self.async_recv(id) + assert cast(ResultResponse, first).result is None + second = await self.async_recv(id) + await self._async_unsubscribe(id) + return cast(TemplateEvent, cast(EventResponse, second).event).result + + async def async_get_config(self) -> dict[str, JSONType]: + """ + Get the Home Assistant configuration. + + Sends command :code:`{"type": "get_config", ...}`. + """ + return cast( + dict[str, JSONType], + cast( + ResultResponse, + await self.async_recv(await self.async_send("get_config")), + ).result, + ) + + async def async_get_states(self) -> Tuple[State, ...]: + """ + Get a list of states. + + Sends command :code:`{"type": "get_states", ...}`. + """ + return tuple( + State.from_json(state) + for state in cast( + list[dict[str, JSONType]], + cast( + ResultResponse, + await self.async_recv(await self.async_send("get_states")), + ).result, + ) + ) + + async def async_get_state( # pylint: disable=duplicate-code + self, + *, + entity_id: Optional[str] = None, + group_id: Optional[str] = None, + slug: Optional[str] = None, + ) -> State: + """ + Just calls the :py:meth:`get_states` method and filters the result. + + Please tell home-assistant/core to add a :code:`{"type": "get_state", ...}` command to the WS API! + There is a lot of disappointment and frustration in the community because this is not available. + """ + entity_id = prepare_entity_id( + group_id=group_id, + slug=slug, + entity_id=entity_id, + ) + + for state in await self.async_get_states(): + if state.entity_id == entity_id: + return state + raise ValueError(f"Entity {entity_id} not found!") + + async def async_get_entities(self) -> Dict[str, Group]: + """ + Fetches all entities from the Websocket API and returns them as a dictionary of :py:class:`Group`'s. + For example :code:`light.living_room` would be in the group :code:`light` (i.e. :code:`get_entities()["light"].living_room`). + """ + entities: Dict[str, Group] = {} + for state in await self.async_get_states(): + group_id, entity_slug = state.entity_id.split(".") + if group_id not in entities: + entities[group_id] = Group( + group_id=group_id, + _client=self, # type: ignore[arg-type] + ) + entities[group_id]._add_entity(entity_slug, state) + return entities + + async def async_get_entity( + self, + group_id: Optional[str] = None, + slug: Optional[str] = None, + entity_id: Optional[str] = None, + ) -> Optional[Entity]: + """ + Returns an :py:class:`Entity` model for an :code:`entity_id`. + + Calls :py:meth:`get_states` under the hood. + + Please tell home-assistant/core to add a :code:`{"type": "get_state", ...}` command to the WS API! + There is a lot of disappointment and frustration in the community because this is not available. + """ + if group_id is not None and slug is not None: + state = await self.async_get_state(group_id=group_id, slug=slug) + elif entity_id is not None: + state = await self.async_get_state(entity_id=entity_id) + else: + help_msg = ( + "Use keyword arguments to pass entity_id. " + "Or you can pass the group_id and slug instead" + ) + raise ValueError( + f"Neither group_id and slug or entity_id provided. {help_msg}" + ) + split_group_id, split_slug = state.entity_id.split(".") + group = Group( + group_id=split_group_id, + _client=self, # type: ignore[arg-type] + ) + group._add_entity(split_slug, state) + return group.get_entity(split_slug) + + async def async_get_domains(self) -> dict[str, Domain]: + """ + Get a list of services that Home Assistant offers (organized into a dictionary of service domains). + + For example, the service :code:`light.turn_on` would be in the domain :code:`light`. + + Sends command :code:`{"type": "get_services", ...}`. + """ + resp = await self.async_recv(await self.async_send("get_services")) + domains = map( + lambda item: Domain.from_json_with_client( + {"domain": item[0], "services": item[1]}, + client=cast(WebsocketClient, self), + ), + cast(dict[str, JSONType], cast(ResultResponse, resp).result).items(), + ) + return {domain.domain_id: domain for domain in domains} + + async def async_get_domain(self, domain: str) -> Domain: + """Get a domain. + + Note: This is not a method in the WS API client... yet. + + Please tell home-assistant/core to add a `get_domain` command to the WS API! + + For now, just call the :py:meth":`get_domains` method and parsing the result. + """ + return (await self.async_get_domains())[domain] + + async def async_trigger_service( + self, + domain: str, + service: str, + entity_id: Optional[str] = None, + **service_data, + ) -> None: + """ + Trigger a service (that doesn't return a response). + + Sends command :code:`{"type": "call_service", ...}`. + """ + params = { + "domain": domain, + "service": service, + "service_data": service_data, + "return_response": False, + } + if entity_id is not None: + params["target"] = {"entity_id": entity_id} + + data = await self.async_recv( + await self.async_send("call_service", include_id=True, **params) + ) + + # TODO: handle data["result"]["context"] ? + + assert ( + cast( + dict[str, JSONType], + cast(ResultResponse, data).result, + ).get("response") + is None + ) # should always be None for services without a response + + async def async_trigger_service_with_response( + self, + domain: str, + service: str, + entity_id: Optional[str] = None, + **service_data, + ) -> dict[str, JSONType]: + """ + Trigger a service (that returns a response) and return the response. + + Sends command :code:`{"type": "call_service", ...}`. + """ + params = { + "domain": domain, + "service": service, + "service_data": service_data, + "return_response": True, + } + if entity_id is not None: + params["target"] = {"entity_id": entity_id} + + data = await self.async_recv( + await self.async_send("call_service", include_id=True, **params) + ) + + return cast(dict[str, dict[str, JSONType]], cast(ResultResponse, data).result)[ + "response" + ] + + @contextlib.asynccontextmanager + async def async_listen_events( + self, + event_type: Optional[str] = None, + ) -> AsyncGenerator[AsyncGenerator[FiredEvent, None], None]: + """ + Listen for all events of a certain type. + + For example, to listen for all events of type `test_event`: + + .. code-block:: python + + async with ws_client.listen_events("test_event") as events: + async for i, event in zip(range(2), events): # to only wait for two events to be received + print(event) + """ + subscription = await self._async_subscribe_events(event_type) + yield cast(AsyncGenerator[FiredEvent, None], self._async_wait_for(subscription)) + await self._async_unsubscribe(subscription) + + async def _async_subscribe_events(self, event_type: Optional[str]) -> int: + """ + Subscribe to all events of a certain type. + + + Sends command :code:`{"type": "subscribe_events", ...}`. + """ + params = {"event_type": event_type} if event_type else {} + return ( + await self.async_recv( + await self.async_send("subscribe_events", include_id=True, **params) + ) + ).id + + @contextlib.asynccontextmanager + async def async_listen_trigger( + self, trigger: str, **trigger_fields + ) -> AsyncGenerator[AsyncGenerator[dict[str, JSONType], None], None]: + """ + Listen to a Home Assistant trigger. + Allows additional trigger keyword parameters with :code:`**kwargs` (i.e. passing :code:`tag_id=...` for NFC tag triggers). + + For example, in Home Assistant Automations we can subscribe to a state trigger for a light entity with YAML: + + .. code-block:: yaml + + triggers: + # ... + - trigger: state + entity_id: light.kitchen + + To subscribe to that same state trigger with :py:class:`AsyncWebsocketClient` instead + + .. code-block:: python + + async with ws_client.listen_trigger("state", entity_id="light.kitchen") as trigger: + async for event in trigger: # will iterate until we manually break out of the loop + print(event) + if : + break + # exiting the context manager unsubscribes from the trigger + + Woohoo! We can now listen to triggers in Python code! + """ + subscription = await self._async_subscribe_trigger(trigger, **trigger_fields) + yield ( + fired_trigger.variables + async for fired_trigger in cast( + AsyncGenerator[FiredTrigger, None], + self._async_wait_for(subscription), + ) + ) + await self._async_unsubscribe(subscription) + + async def _async_subscribe_trigger(self, trigger: str, **trigger_fields) -> int: + """ + Return the subscription id of the trigger we subscribe to. + + Sends command :code:`{"type": "subscribe_trigger", ...}`. + """ + return ( + await self.async_recv( + await self.async_send( + "subscribe_trigger", trigger={"platform": trigger, **trigger_fields} + ) + ) + ).id + + async def _async_wait_for( + self, subscription_id: int + ) -> AsyncGenerator[Union[FiredEvent, FiredTrigger], None]: + """ + An iterator that waits for events of a certain type. + """ + while True: + yield cast( + Union[ + FiredEvent, FiredTrigger + ], # we can cast this because TemplateEvent is only used for rendering templates + cast(EventResponse, await self.async_recv(subscription_id)).event, + ) + + async def _async_unsubscribe(self, subcription_id: int) -> None: + """ + Unsubscribe from all events of a certain type. + + Sends command :code:`{"type": "unsubscribe_events", ...}`. + """ + resp = await self.async_recv( + await self.async_send("unsubscribe_events", subscription=subcription_id) + ) + assert cast(ResultResponse, resp).result is None + self._event_responses.pop(subcription_id) + + async def async_get_config_entries(self) -> Tuple[ConfigEntry, ...]: + """ + Get all config entries. + + Sends command :code:`{"type": "config_entries/get", ...}`. + """ + resp = await self.async_recv(await self.async_send("config_entries/get")) + return tuple( + ConfigEntry.from_json(entry) + for entry in cast( + list[dict[str, JSONType]], + cast(ResultResponse, resp).result, + ) + ) + + async def async_disable_config_entry(self, entry_id: str) -> DisableEnableResult: + """ + Disable a config entry. + + Sends command :code:`{"type": "config_entries/disable", ...}`. + """ + resp = await self.async_recv( + await self.async_send( + "config_entries/disable", + entry_id=entry_id, + disabled_by="user", + ) + ) + return DisableEnableResult.from_json( + cast(dict[str, JSONType], cast(ResultResponse, resp).result) + ) + + async def async_enable_config_entry(self, entry_id: str) -> DisableEnableResult: + """ + Enable a config entry. + + Sends command :code:`{"type": "config_entries/disable", ...}`. + """ + resp = await self.async_recv( + await self.async_send( + "config_entries/disable", + entry_id=entry_id, + disabled_by=None, + ) + ) + return DisableEnableResult.from_json( + cast(dict[str, JSONType], cast(ResultResponse, resp).result) + ) + + async def async_ignore_config_flow(self, flow_id: str, title: str) -> None: + """ + Ignore a config flow. + + Sends command :code:`{"type": "config_entries/ignore_flow", ...}`. + """ + await self.async_recv( + await self.async_send( + "config_entries/ignore_flow", + flow_id=flow_id, + title=title, + ) + ) + + async def async_get_nonuser_flows_in_progress(self) -> Tuple[FlowResult, ...]: + """ + Get non-user config flows in progress. + + Sends command :code:`{"type": "config_entries/flow/progress", ...}`. + """ + resp = await self.async_recv( + await self.async_send("config_entries/flow/progress") + ) + return tuple( + FlowResult.from_json(flow) + for flow in cast( + list[dict[str, JSONType]], + cast(ResultResponse, resp).result, + ) + ) + + async def async_get_entry_subentries( + self, entry_id: str + ) -> Tuple[ConfigSubEntry, ...]: + """ + Get subentries for a config entry. + + Sends command :code:`{"type": "config_entries/subentries/list", ...}`. + """ + resp = await self.async_recv( + await self.async_send("config_entries/subentries/list", entry_id=entry_id) + ) + return tuple( + ConfigSubEntry.from_json(subentry) + for subentry in cast( + list[dict[str, JSONType]], + cast(ResultResponse, resp).result, + ) + ) + + async def async_delete_entry_subentry( + self, entry_id: str, subentry_id: str + ) -> None: + """ + Delete a subentry from a config entry. + + Sends command :code:`{"type": "config_entries/subentries/delete", ...}`. + """ + await self.async_recv( + await self.async_send( + "config_entries/subentries/delete", + entry_id=entry_id, + subentry_id=subentry_id, + ) + ) + + @contextlib.asynccontextmanager + async def async_listen_config_entries( + self, + ) -> AsyncGenerator[AsyncGenerator[list[ConfigEntryEvent], None], None]: + """ + Listen for config entry changes. + + Sends command :code:`{"type": "config_entries/subscribe", ...}`. + """ + subscription = ( + await self.async_recv(await self.async_send("config_entries/subscribe")) + ).id + yield self._async_wait_for_config_entries(subscription) + await self._async_unsubscribe(subscription) + + async def _async_wait_for_config_entries( + self, subscription_id: int + ) -> AsyncGenerator[list[ConfigEntryEvent], None]: + """An async iterator that waits for config entry events.""" + while True: + event_resp = cast(EventResponse, await self.async_recv(subscription_id)) + entries = cast(list[dict[str, JSONType]], event_resp.event) + yield [ConfigEntryEvent.from_json(entry) for entry in entries] + + async def async_fire_event(self, event_type: str, **event_data) -> Context: + """ + Fire an event. + + Sends command :code:`{"type": "fire_event", ...}`. + """ + params: dict[str, JSONType] = {"event_type": event_type} + if event_data: + params["event_data"] = event_data + return Context.from_json( + cast( + dict[str, dict[str, JSONType]], + cast( + ResultResponse, + await self.async_recv( + await self.async_send("fire_event", include_id=True, **params) + ), + ).result, + )["context"] + ) diff --git a/homeassistant_api/rawbasewebsocket.py b/homeassistant_api/rawbasewebsocket.py new file mode 100644 index 0000000..f959d50 --- /dev/null +++ b/homeassistant_api/rawbasewebsocket.py @@ -0,0 +1,82 @@ +import logging +import time +from typing import Optional, cast + +from pydantic import ValidationError + +from homeassistant_api.errors import ( + ReceivingError, + RequestError, +) +from homeassistant_api.models.websocket import ( + ErrorResponse, + EventResponse, + PingResponse, + ResultResponse, +) +from homeassistant_api.utils import JSONType + +logger = logging.getLogger(__name__) + + +class RawBaseWebsocketClient: + """Shared methods for Websocket clients.""" + + api_url: str + token: str + _id_counter: int + _result_responses: dict[int, Optional[ResultResponse]] + _event_responses: dict[int, list[EventResponse]] + _ping_responses: dict[int, PingResponse] + + def __init__(self, api_url: str, token: str) -> None: + self.api_url = api_url + self.token = token.strip() + + self._id_counter = 0 + self._result_responses = {} # id -> response + self._event_responses = {} # id -> [response, ...] + self._ping_responses = {} # id -> (sent, received) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.api_url!r})" + + def _request_id(self) -> int: + """Get a unique id for a message.""" + self._id_counter += 1 + return self._id_counter + + def check_success(self, data: dict[str, JSONType]) -> None: + """Check if a command message was successful.""" + try: + error_resp = ErrorResponse.model_validate(data) + raise RequestError(error_resp.error.code, error_resp.error.message) + except ValidationError: + pass + + def handle_recv(self, data: dict[str, JSONType]) -> None: + """Handle a received message.""" + if "id" not in data: + raise ReceivingError( + "Received a message without an id outside the auth phase." + ) + self.check_success(data) + self.parse_response(data) + + def parse_response(self, data: dict[str, JSONType]) -> None: + data_id = cast(int, data["id"]) + if data.get("type") == "pong": + logger.info("Received pong message") + self._ping_responses[data_id].end = time.perf_counter_ns() + elif data.get("type") == "result": + logger.info("Received result message") + if data.get("success"): + self._result_responses[data_id] = ResultResponse.model_validate(data) + else: + error_resp = ErrorResponse.model_validate(data) + raise RequestError(error_resp.error.code, error_resp.error.message) + elif data.get("type") == "event": + logger.info("Received event message %s", data["event"]) + self._event_responses[data_id].append(EventResponse.model_validate(data)) + else: + raise ReceivingError(f"Received unexpected message type: {data}") diff --git a/homeassistant_api/rawwebsocket.py b/homeassistant_api/rawwebsocket.py index 09e702f..3e71d85 100644 --- a/homeassistant_api/rawwebsocket.py +++ b/homeassistant_api/rawwebsocket.py @@ -1,43 +1,55 @@ +import contextlib import json import logging import time -from typing import Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple, Union, cast import websockets.sync.client as ws from pydantic import ValidationError from homeassistant_api.errors import ( ReceivingError, - RequestError, ResponseError, UnauthorizedError, ) +from homeassistant_api.models import ( + ConfigEntry, + ConfigEntryEvent, + ConfigSubEntry, + Domain, + Entity, + Group, + State, +) +from homeassistant_api.models.config_entries import DisableEnableResult, FlowResult +from homeassistant_api.models.states import Context from homeassistant_api.models.websocket import ( AuthInvalid, AuthOk, AuthRequired, - ErrorResponse, EventResponse, + FiredEvent, + FiredTrigger, PingResponse, ResultResponse, + TemplateEvent, ) -from homeassistant_api.utils import JSONType +from homeassistant_api.rawbasewebsocket import RawBaseWebsocketClient +from homeassistant_api.utils import JSONType, prepare_entity_id + +if TYPE_CHECKING: + from homeassistant_api import WebsocketClient +else: + WebsocketClient = None # pylint: disable=invalid-name logger = logging.getLogger(__name__) -class RawWebsocketClient: - api_url: str - token: str +class RawWebsocketClient(RawBaseWebsocketClient): _conn: Optional[ws.ClientConnection] - def __init__( - self, - api_url: str, - token: str, - ) -> None: - self.api_url = api_url - self.token = token.strip() + def __init__(self, api_url: str, token: str) -> None: + super().__init__(api_url, token) self._conn = None self._id_counter = 0 @@ -66,11 +78,6 @@ def __exit__(self, exc_type, exc_value, traceback): self._conn.__exit__(exc_type, exc_value, traceback) self._conn = None - def _request_id(self) -> int: - """Get a unique id for a message.""" - self._id_counter += 1 - return self._id_counter - def _send(self, data: dict[str, JSONType]) -> None: """Send a message to the websocket server.""" logger.debug(f"Sending message: {data}") @@ -112,41 +119,6 @@ def send(self, type: str, include_id: bool = True, **data: Any) -> int: return data["id"] return -1 # non-command messages don't have an id - def check_success(self, data: dict[str, JSONType]) -> None: - """Check if a command message was successful.""" - try: - error_resp = ErrorResponse.model_validate(data) - raise RequestError(error_resp.error.code, error_resp.error.message) - except ValidationError: - pass - - def handle_recv(self, data: dict[str, JSONType]) -> None: - """Handle a received message.""" - if "id" not in data: - raise ReceivingError( - "Received a message without an id outside the auth phase." - ) - self.check_success(data) - self.parse_response(data) - - def parse_response(self, data: dict[str, JSONType]) -> None: - data_id = cast(int, data["id"]) - if data.get("type") == "pong": - logger.info("Received pong message") - self._ping_responses[data_id].end = time.perf_counter_ns() - elif data.get("type") == "result": - logger.info("Received result message") - if data.get("success"): - self._result_responses[data_id] = ResultResponse.model_validate(data) - else: - error_resp = ErrorResponse.model_validate(data) - raise RequestError(error_resp.error.code, error_resp.error.message) - elif data.get("type") == "event": - logger.info("Received event message %s", data["event"]) - self._event_responses[data_id].append(EventResponse.model_validate(data)) - else: - raise ReceivingError(f"Received unexpected message type: {data}") - def recv(self, id: int) -> Union[EventResponse, ResultResponse, PingResponse]: """Receive a response to a message from the websocket server.""" while True: @@ -206,3 +178,462 @@ def ping_latency(self) -> float: pong = cast(PingResponse, self.recv(self.send("ping"))) assert pong.end is not None return (pong.end - pong.start) / 1_000_000 + + def get_rendered_template(self, template: str) -> str: + """ + Renders a Jinja2 template with Home Assistant context data. + See https://www.home-assistant.io/docs/configuration/templating. + + Sends command :code:`{"type": "render_template", ...}`. + """ + id = self.send("render_template", template=template, report_errors=True) + first = self.recv(id) + assert cast(ResultResponse, first).result is None + second = self.recv(id) + self._unsubscribe(id) + return cast(TemplateEvent, cast(EventResponse, second).event).result + + def get_config(self) -> dict[str, JSONType]: + """ + Get the Home Assistant configuration. + + Sends command :code:`{"type": "get_config", ...}`. + """ + return cast( + dict[str, JSONType], + cast( + ResultResponse, + self.recv(self.send("get_config")), + ).result, + ) + + def get_states(self) -> Tuple[State, ...]: + """ + Get a list of states. + + Sends command :code:`{"type": "get_states", ...}`. + """ + return tuple( + State.from_json(state) + for state in cast( + list[dict[str, JSONType]], + cast(ResultResponse, self.recv(self.send("get_states"))).result, + ) + ) + + def get_state( # pylint: disable=duplicate-code + self, + *, + entity_id: Optional[str] = None, + group_id: Optional[str] = None, + slug: Optional[str] = None, + ) -> State: + """ + Just calls the :py:meth:`get_states` method and filters the result. + + Please tell home-assistant/core to add a :code:`{"type": "get_state", ...}` command to the WS API! + There is a lot of disappointment and frustration in the community because this is not available. + """ + entity_id = prepare_entity_id( + group_id=group_id, + slug=slug, + entity_id=entity_id, + ) + + for state in self.get_states(): + if state.entity_id == entity_id: + return state + raise ValueError(f"Entity {entity_id} not found!") + + def get_entities(self) -> Dict[str, Group]: + """ + Fetches all entities from the Websocket API and returns them as a dictionary of :py:class:`Group`'s. + For example :code:`light.living_room` would be in the group :code:`light` (i.e. :code:`get_entities()["light"].living_room`). + """ + entities: Dict[str, Group] = {} + for state in self.get_states(): + group_id, entity_slug = state.entity_id.split(".") + if group_id not in entities: + entities[group_id] = Group( + group_id=group_id, + _client=self, # type: ignore[arg-type] + ) + entities[group_id]._add_entity(entity_slug, state) + return entities + + def get_entity( + self, + group_id: Optional[str] = None, + slug: Optional[str] = None, + entity_id: Optional[str] = None, + ) -> Optional[Entity]: + """ + Returns an :py:class:`Entity` model for an :code:`entity_id`. + + Calls :py:meth:`get_states` under the hood. + + Please tell home-assistant/core to add a :code:`{"type": "get_state", ...}` command to the WS API! + There is a lot of disappointment and frustration in the community because this is not available. + """ + if group_id is not None and slug is not None: + state = self.get_state(group_id=group_id, slug=slug) + elif entity_id is not None: + state = self.get_state(entity_id=entity_id) + else: + help_msg = ( + "Use keyword arguments to pass entity_id. " + "Or you can pass the group_id and slug instead" + ) + raise ValueError( + f"Neither group_id and slug or entity_id provided. {help_msg}" + ) + split_group_id, split_slug = state.entity_id.split(".") + group = Group( + group_id=split_group_id, + _client=self, # type: ignore[arg-type] + ) + group._add_entity(split_slug, state) + return group.get_entity(split_slug) + + def get_domains(self) -> dict[str, Domain]: + """ + Get a list of services that Home Assistant offers (organized into a dictionary of service domains). + + For example, the service :code:`light.turn_on` would be in the domain :code:`light`. + + Sends command :code:`{"type": "get_services", ...}`. + """ + resp = self.recv(self.send("get_services")) + domains = map( + lambda item: Domain.from_json_with_client( + {"domain": item[0], "services": item[1]}, + client=cast(WebsocketClient, self), + ), + cast(dict[str, JSONType], cast(ResultResponse, resp).result).items(), + ) + return {domain.domain_id: domain for domain in domains} + + def get_domain(self, domain: str) -> Domain: + """Get a domain. + + Note: This is not a method in the WS API client... yet. + + Please tell home-assistant/core to add a `get_domain` command to the WS API! + + For now, just call the :py:meth":`get_domains` method and parsing the result. + """ + return self.get_domains()[domain] + + def trigger_service( + self, + domain: str, + service: str, + entity_id: Optional[str] = None, + **service_data, + ) -> None: + """ + Trigger a service (that doesn't return a response). + + Sends command :code:`{"type": "call_service", ...}`. + """ + params = { + "domain": domain, + "service": service, + "service_data": service_data, + "return_response": False, + } + if entity_id is not None: + params["target"] = {"entity_id": entity_id} + + data = self.recv(self.send("call_service", include_id=True, **params)) + + # TODO: handle data["result"]["context"] ? + + assert ( + cast( + dict[str, JSONType], + cast(ResultResponse, data).result, + ).get("response") + is None + ) # should always be None for services without a response + + def trigger_service_with_response( + self, + domain: str, + service: str, + entity_id: Optional[str] = None, + **service_data, + ) -> dict[str, JSONType]: + """ + Trigger a service (that returns a response) and return the response. + + Sends command :code:`{"type": "call_service", ...}`. + """ + params = { + "domain": domain, + "service": service, + "service_data": service_data, + "return_response": True, + } + if entity_id is not None: + params["target"] = {"entity_id": entity_id} + + data = self.recv(self.send("call_service", include_id=True, **params)) + + return cast(dict[str, dict[str, JSONType]], cast(ResultResponse, data).result)[ + "response" + ] + + @contextlib.contextmanager + def listen_events( + self, + event_type: Optional[str] = None, + ) -> Generator[Generator[FiredEvent, None, None], None, None]: + """ + Listen for all events of a certain type. + + For example, to listen for all events of type `test_event`: + + .. code-block:: python + + with ws_client.listen_events("test_event") as events: + for i, event in zip(range(2), events): # to only wait for two events to be received + print(event) + """ + subscription = self._subscribe_events(event_type) + yield cast(Generator[FiredEvent, None, None], self._wait_for(subscription)) + self._unsubscribe(subscription) + + def _subscribe_events(self, event_type: Optional[str]) -> int: + """ + Subscribe to all events of a certain type. + + + Sends command :code:`{"type": "subscribe_events", ...}`. + """ + params = {"event_type": event_type} if event_type else {} + return self.recv(self.send("subscribe_events", include_id=True, **params)).id + + @contextlib.contextmanager + def listen_trigger( + self, trigger: str, **trigger_fields + ) -> Generator[Generator[dict[str, JSONType], None, None], None, None]: + """ + Listen to a Home Assistant trigger. + Allows additional trigger keyword parameters with :code:`**kwargs` (i.e. passing :code:`tag_id=...` for NFC tag triggers). + + For example, in Home Assistant Automations we can subscribe to a state trigger for a light entity with YAML: + + .. code-block:: yaml + + triggers: + # ... + - trigger: state + entity_id: light.kitchen + + To subscribe to that same state trigger with :py:class:`WebsocketClient` instead + + .. code-block:: python + + with ws_client.listen_trigger("state", entity_id="light.kitchen") as trigger: + for event in trigger: # will iterate until we manually break out of the loop + print(event) + if : + break + # exiting the context manager unsubscribes from the trigger + + Woohoo! We can now listen to triggers in Python code! + """ + subscription = self._subscribe_trigger(trigger, **trigger_fields) + yield ( + fired_trigger.variables + for fired_trigger in cast( + Generator[FiredTrigger, None, None], + self._wait_for(subscription), + ) + ) + self._unsubscribe(subscription) + + def _subscribe_trigger(self, trigger: str, **trigger_fields) -> int: + """ + Return the subscription id of the trigger we subscribe to. + + Sends command :code:`{"type": "subscribe_trigger", ...}`. + """ + return self.recv( + self.send( + "subscribe_trigger", trigger={"platform": trigger, **trigger_fields} + ) + ).id + + def _wait_for( + self, subscription_id: int + ) -> Generator[Union[FiredEvent, FiredTrigger], None, None]: + """ + An iterator that waits for events of a certain type. + """ + while True: + yield cast( + Union[ + FiredEvent, FiredTrigger + ], # we can cast this because TemplateEvent is only used for rendering templates + cast(EventResponse, self.recv(subscription_id)).event, + ) + + def _unsubscribe(self, subcription_id: int) -> None: + """ + Unsubscribe from all events of a certain type. + + Sends command :code:`{"type": "unsubscribe_events", ...}`. + """ + resp = self.recv(self.send("unsubscribe_events", subscription=subcription_id)) + assert cast(ResultResponse, resp).result is None + self._event_responses.pop(subcription_id) + + def get_config_entries(self) -> Tuple[ConfigEntry, ...]: + """ + Get all config entries. + + Sends command :code:`{"type": "config_entries/get", ...}`. + """ + resp = self.recv(self.send("config_entries/get")) + return tuple( + ConfigEntry.from_json(entry) + for entry in cast( + list[dict[str, JSONType]], + cast(ResultResponse, resp).result, + ) + ) + + def disable_config_entry(self, entry_id: str) -> DisableEnableResult: + """ + Disable a config entry. + + Sends command :code:`{"type": "config_entries/disable", ...}`. + """ + resp = self.recv( + self.send( + "config_entries/disable", + entry_id=entry_id, + disabled_by="user", + ) + ) + return DisableEnableResult.from_json( + cast(dict[str, JSONType], cast(ResultResponse, resp).result) + ) + + def enable_config_entry(self, entry_id: str) -> DisableEnableResult: + """ + Enable a config entry. + + Sends command :code:`{"type": "config_entries/disable", ...}`. + """ + resp = self.recv( + self.send( + "config_entries/disable", + entry_id=entry_id, + disabled_by=None, + ) + ) + return DisableEnableResult.from_json( + cast(dict[str, JSONType], cast(ResultResponse, resp).result) + ) + + def ignore_config_flow(self, flow_id: str, title: str) -> None: + """ + Ignore a config flow. + + Sends command :code:`{"type": "config_entries/ignore_flow", ...}`. + """ + self.recv( + self.send( + "config_entries/ignore_flow", + flow_id=flow_id, + title=title, + ) + ) + + def get_nonuser_flows_in_progress(self) -> Tuple[FlowResult, ...]: + """ + Get non-user config flows in progress. + + Sends command :code:`{"type": "config_entries/flow/progress", ...}`. + """ + resp = self.recv(self.send("config_entries/flow/progress")) + return tuple( + FlowResult.from_json(flow) + for flow in cast( + list[dict[str, JSONType]], + cast(ResultResponse, resp).result, + ) + ) + + def get_entry_subentries(self, entry_id: str) -> Tuple[ConfigSubEntry, ...]: + """ + Get subentries for a config entry. + + Sends command :code:`{"type": "config_entries/subentries/list", ...}`. + """ + resp = self.recv(self.send("config_entries/subentries/list", entry_id=entry_id)) + return tuple( + ConfigSubEntry.from_json(subentry) + for subentry in cast( + list[dict[str, JSONType]], + cast(ResultResponse, resp).result, + ) + ) + + def delete_entry_subentry(self, entry_id: str, subentry_id: str) -> None: + """ + Delete a subentry from a config entry. + + Sends command :code:`{"type": "config_entries/subentries/delete", ...}`. + """ + self.recv( + self.send( + "config_entries/subentries/delete", + entry_id=entry_id, + subentry_id=subentry_id, + ) + ) + + @contextlib.contextmanager + def listen_config_entries( + self, + ) -> Generator[Generator[list[ConfigEntryEvent], None, None], None, None]: + """ + Listen for config entry changes. + + Sends command :code:`{"type": "config_entries/subscribe", ...}`. + """ + subscription = self.recv(self.send("config_entries/subscribe")).id + yield self._wait_for_config_entries(subscription) + self._unsubscribe(subscription) + + def _wait_for_config_entries( + self, subscription_id: int + ) -> Generator[list[ConfigEntryEvent], None, None]: + """An iterator that waits for config entry events.""" + while True: + event_resp = cast(EventResponse, self.recv(subscription_id)) + entries = cast(list[dict[str, JSONType]], event_resp.event) + yield [ConfigEntryEvent.from_json(entry) for entry in entries] + + def fire_event(self, event_type: str, **event_data) -> Context: + """ + Fire an event. + + Sends command :code:`{"type": "fire_event", ...}`. + """ + params: dict[str, JSONType] = {"event_type": event_type} + if event_data: + params["event_data"] = event_data + return Context.from_json( + cast( + dict[str, dict[str, JSONType]], + cast( + ResultResponse, + self.recv(self.send("fire_event", include_id=True, **params)), + ).result, + )["context"] + ) diff --git a/homeassistant_api/websocket.py b/homeassistant_api/websocket.py index eb9d2c7..5a4d0d0 100644 --- a/homeassistant_api/websocket.py +++ b/homeassistant_api/websocket.py @@ -1,39 +1,18 @@ -import contextlib +"""Module containing the primary Client class.""" + import logging import urllib.parse as urlparse -from typing import Dict, Generator, List, Optional, Tuple, Union, cast - -from homeassistant_api.models import ( - ConfigEntry, - ConfigEntryEvent, - ConfigSubEntry, - DisableEnableResult, - Domain, - Entity, - FlowResult, - Group, - IntegrationTypes, - State, -) -from homeassistant_api.models.states import Context -from homeassistant_api.models.websocket import ( - EventResponse, - FiredEvent, - FiredTrigger, - ResultResponse, - TemplateEvent, -) -from homeassistant_api.utils import JSONType, prepare_entity_id +from .rawasyncwebsocket import RawAsyncWebsocketClient from .rawwebsocket import RawWebsocketClient logger = logging.getLogger(__name__) -class WebsocketClient(RawWebsocketClient): +class WebsocketClient(RawWebsocketClient, RawAsyncWebsocketClient): """ - The main class for interactign with the Home Assistant WebSocket API client. + The main class for interacting with the Home Assistant WebSocket API client. Here's a quick example of how to use the :py:class:`WebsocketClient` class: @@ -48,496 +27,19 @@ class WebsocketClient(RawWebsocketClient): light = ws_client.trigger_service('light', 'turn_on', entity_id="light.living_room") """ - def __init__( - self, - api_url: str, - token: str, - ) -> None: + def __init__(self, api_url: str, token: str, use_async: bool = False) -> None: parsed = urlparse.urlparse(api_url) - if parsed.scheme not in {"ws", "wss"}: - raise ValueError(f"Unknown scheme {parsed.scheme} in {api_url}") - super().__init__(api_url, token) - logger.debug(f"WebSocketClient initialized with api_url: {api_url}") - - def get_rendered_template(self, template: str) -> str: - """ - Renders a Jinja2 template with Home Assistant context data. - See https://www.home-assistant.io/docs/configuration/templating. - - Sends command :code:`{"type": "render_template", ...}`. - """ - id = self.send("render_template", template=template, report_errors=True) - first = self.recv(id) - assert cast(ResultResponse, first).result is None - second = self.recv(id) - self._unsubscribe(id) - return cast(TemplateEvent, cast(EventResponse, second).event).result - - def get_config(self) -> dict[str, JSONType]: - """ - Get the Home Assistant configuration. - - Sends command :code:`{"type": "get_config", ...}`. - """ - return cast( - dict[str, JSONType], - cast( - ResultResponse, - self.recv(self.send("get_config")), - ).result, - ) - - def get_states(self) -> Tuple[State, ...]: - """ - Get a list of states. - - Sends command :code:`{"type": "get_states", ...}`. - """ - return tuple( - State.from_json(state) - for state in cast( - list[dict[str, JSONType]], - cast(ResultResponse, self.recv(self.send("get_states"))).result, - ) - ) - - def get_state( # pylint: disable=duplicate-code - self, - *, - entity_id: Optional[str] = None, - group_id: Optional[str] = None, - slug: Optional[str] = None, - ) -> State: - """ - Just calls the :py:meth:`get_states` method and filters the result. - - Please tell home-assistant/core to add a :code:`{"type": "get_state", ...}` command to the WS API! - There is a lot of disappointment and frustration in the community because this is not available. - """ - entity_id = prepare_entity_id( - group_id=group_id, - slug=slug, - entity_id=entity_id, - ) - - for state in self.get_states(): - if state.entity_id == entity_id: - return state - raise ValueError(f"Entity {entity_id} not found!") - - def get_entities(self) -> Dict[str, Group]: - """ - Fetches all entities from the Websocket API and returns them as a dictionary of :py:class:`Group`'s. - For example :code:`light.living_room` would be in the group :code:`light` (i.e. :code:`get_entities()["light"].living_room`). - """ - entities: Dict[str, Group] = {} - for state in self.get_states(): - group_id, entity_slug = state.entity_id.split(".") - if group_id not in entities: - entities[group_id] = Group( - group_id=group_id, - _client=self, # type: ignore[arg-type] - ) - entities[group_id]._add_entity(entity_slug, state) - return entities - - def get_entity( - self, - group_id: Optional[str] = None, - slug: Optional[str] = None, - entity_id: Optional[str] = None, - ) -> Optional[Entity]: - """ - Returns an :py:class:`Entity` model for an :code:`entity_id`. - - Calls :py:meth:`get_states` under the hood. - - Please tell home-assistant/core to add a :code:`{"type": "get_state", ...}` command to the WS API! - There is a lot of disappointment and frustration in the community because this is not available. - """ - if group_id is not None and slug is not None: - state = self.get_state(group_id=group_id, slug=slug) - elif entity_id is not None: - state = self.get_state(entity_id=entity_id) + if parsed.scheme in {"ws", "wss"}: + if use_async: + RawAsyncWebsocketClient.__init__(self, api_url, token) + client_type = "Async" + else: + RawWebsocketClient.__init__(self, api_url, token) + client_type = "" else: - help_msg = ( - "Use keyword arguments to pass entity_id. " - "Or you can pass the group_id and slug instead" - ) - raise ValueError( - f"Neither group_id and slug or entity_id provided. {help_msg}" - ) - split_group_id, split_slug = state.entity_id.split(".") - group = Group( - group_id=split_group_id, - _client=self, # type: ignore[arg-type] - ) - group._add_entity(split_slug, state) - return group.get_entity(split_slug) - - def get_domains(self) -> dict[str, Domain]: - """ - Get a list of services that Home Assistant offers (organized into a dictionary of service domains). - - For example, the service :code:`light.turn_on` would be in the domain :code:`light`. - - Sends command :code:`{"type": "get_services", ...}`. - """ - resp = self.recv(self.send("get_services")) - domains = map( - lambda item: Domain.from_json_with_client( - {"domain": item[0], "services": item[1]}, - client=self, - ), - cast(dict[str, JSONType], cast(ResultResponse, resp).result).items(), - ) - return {domain.domain_id: domain for domain in domains} - - def get_domain(self, domain: str) -> Domain: - """Get a domain. - - Note: This is not a method in the WS API client... yet. - - Please tell home-assistant/core to add a `get_domain` command to the WS API! - - For now, just call the :py:meth":`get_domains` method and parsing the result. - """ - return self.get_domains()[domain] - - # config_entries.py - - def get_nonuser_flows_in_progress(self) -> Tuple[FlowResult, ...]: - """ - Get config entries that are in progress but not initiated by a user. - - Sends command :code:`{"type": "config_entries/flow/progress"}`. - """ - return tuple( - FlowResult.from_json(flow_result) - for flow_result in cast( - list[dict[str, JSONType]], - cast( - ResultResponse, self.recv(self.send("config_entries/flow/progress")) - ).result, - ) - ) - - def disable_config_entry(self, entry_id: str) -> DisableEnableResult: - """ - Disable a config entry. - - Sends command :code:`{"type": "config_entries/disable", disabled_by="user", ...}`. - """ - return DisableEnableResult.from_json( - cast( - ResultResponse, - self.recv( - self.send( - "config_entries/disable", entry_id=entry_id, disabled_by="user" - ) - ), - ).result, - ) - - def enable_config_entry(self, entry_id: str) -> DisableEnableResult: - """Enable a config entry. - - Sends command :code:`{"type": "config_entries/disable", disabled_by=None, ...}`. - - """ - return DisableEnableResult.from_json( - cast( - ResultResponse, - self.recv( - self.send( - "config_entries/disable", entry_id=entry_id, disabled_by=None - ) - ), - ).result, - ) - - def ignore_config_flow(self, flow_id: str, title: str) -> None: - """ - Ignore an active config flow. - - Sends command :code:`{"type": "config_entries/ignore_flow", ...}`. - """ - self.recv(self.send("config_entries/ignore_flow", flow_id=flow_id, title=title)) - - def get_config_entries( - self, type_filter: List[IntegrationTypes] = [], domain: str = "" - ) -> Tuple[ConfigEntry, ...]: - """ - Get filtered config entries. - - Sends command :code:`{"type": "config_entries/get", ...}`. - """ - return tuple( - ConfigEntry.from_json(config_entry) - for config_entry in cast( - list[dict[str, JSONType]], - cast( - ResultResponse, - self.recv( - self.send( - "config_entries/get", type_filter=type_filter, domain=domain - ) - ), - ).result, - ) - ) - - def _subscribe_config_entries(self) -> int: - """ - Subscribe to config entry flows. - - Sends command :code:`{"type": "config_entries/subscribe"}`. - """ - - return self.recv(self.send("config_entries/subscribe")).id - - @contextlib.contextmanager - def listen_config_entries( - self, disconnect_client: bool = True - ) -> Generator[Generator[List[ConfigEntryEvent], None, None], None, None]: - """ - Listen to all config entry flow events. - - For example: - - .. code-block:: python - - with ws_client.listen_config_entries() as flows: - for i, flow in zip(range(2), flows): # to only wait for two flows to be received - print(flow) - """ - subscription = self._subscribe_config_entries() - yield cast( - Generator[List[ConfigEntryEvent], None, None], self._wait_for(subscription) - ) - # There is no "unsubscribe" method available for these events. - # Provide the ability to "unsubscribe" by disconnecting and reconnecting the Websocket client. - if disconnect_client: - logger.info("Reloading Websocket Client. Undefined behavior may occur.") - self.__exit__(None, None, None) - self.__enter__() - - def get_entry_subentries(self, entry_id: str) -> Tuple[ConfigSubEntry, ...]: - """ - Get an entry's sub-entries. - - Sends command :code:`{"type": "config_entries/subentries/list", ...}`. - """ - return tuple( - ConfigSubEntry.from_json(sub_entry) - for sub_entry in cast( - list[dict[str, JSONType]], - cast( - ResultResponse, - self.recv( - self.send("config_entries/subentries/list", entry_id=entry_id) - ), - ).result, - ) - ) - - # UNTESTED - def delete_entry_subentry(self, entry_id: str, subentry_id: str) -> None: - """ - Delete an entry's sub-entry. - - Sends command :code:`{"type": "config_entries/subentries/delete", ...}`. - """ - self.recv( - self.send( - "config_entries/subentries/delete", - entry_id=entry_id, - subentry_id=subentry_id, - ) - ) - - def trigger_service( - self, - domain: str, - service: str, - entity_id: Optional[str] = None, - **service_data, - ) -> None: - """ - Trigger a service (that doesn't return a response). - - Sends command :code:`{"type": "call_service", ...}`. - """ - params = { - "domain": domain, - "service": service, - "service_data": service_data, - "return_response": False, - } - if entity_id is not None: - params["target"] = {"entity_id": entity_id} - - data = self.recv(self.send("call_service", include_id=True, **params)) - - # TODO: handle data["result"]["context"] ? - - assert ( - cast( - dict[str, JSONType], - cast(ResultResponse, data).result, - ).get("response") - is None - ) # should always be None for services without a response - - def trigger_service_with_response( - self, - domain: str, - service: str, - entity_id: Optional[str] = None, - **service_data, - ) -> dict[str, JSONType]: - """ - Trigger a service (that returns a response) and return the response. - - Sends command :code:`{"type": "call_service", ...}`. - """ - params = { - "domain": domain, - "service": service, - "service_data": service_data, - "return_response": True, - } - if entity_id is not None: - params["target"] = {"entity_id": entity_id} - - data = self.recv(self.send("call_service", include_id=True, **params)) - - return cast(dict[str, dict[str, JSONType]], cast(ResultResponse, data).result)[ - "response" - ] - - @contextlib.contextmanager - def listen_events( - self, - event_type: Optional[str] = None, - ) -> Generator[Generator[FiredEvent, None, None], None, None]: - """ - Listen for all events of a certain type. - - For example, to listen for all events of type `test_event`: - - .. code-block:: python - - with ws_client.listen_events("test_event") as events: - for i, event in zip(range(2), events): # to only wait for two events to be received - print(event) - """ - subscription = self._subscribe_events(event_type) - yield cast(Generator[FiredEvent, None, None], self._wait_for(subscription)) - self._unsubscribe(subscription) - - def _subscribe_events(self, event_type: Optional[str]) -> int: - """ - Subscribe to all events of a certain type. - - - Sends command :code:`{"type": "subscribe_events", ...}`. - """ - params = {"event_type": event_type} if event_type else {} - return self.recv(self.send("subscribe_events", include_id=True, **params)).id - - @contextlib.contextmanager - def listen_trigger( - self, trigger: str, **trigger_fields - ) -> Generator[Generator[dict[str, JSONType], None, None], None, None]: - """ - Listen to a Home Assistant trigger. - Allows additional trigger keyword parameters with :code:`**kwargs` (i.e. passing :code:`tag_id=...` for NFC tag triggers). - - For example, in Home Assistant Automations we can subscribe to a state trigger for a light entity with YAML: - - .. code-block:: yaml - - triggers: - # ... - - trigger: state - entity_id: light.kitchen - - To subscribe to that same state trigger with :py:class:`WebsocketClient` instead - - .. code-block:: python - - with ws_client.listen_trigger("state", entity_id="light.kitchen") as trigger: - for event in trigger: # will iterate until we manually break out of the loop - print(event) - if : - break - # exiting the context manager unsubscribes from the trigger - - Woohoo! We can now listen to triggers in Python code! - """ - subscription = self._subscribe_trigger(trigger, **trigger_fields) - yield ( - fired_trigger.variables - for fired_trigger in cast( - Generator[FiredTrigger, None, None], - self._wait_for(subscription), - ) - ) - self._unsubscribe(subscription) - - def _subscribe_trigger(self, trigger: str, **trigger_fields) -> int: - """ - Return the subscription id of the trigger we subscribe to. - - Sends command :code:`{"type": "subscribe_trigger", ...}`. - """ - return self.recv( - self.send( - "subscribe_trigger", trigger={"platform": trigger, **trigger_fields} - ) - ).id - - def _wait_for( - self, subscription_id: int - ) -> Generator[Union[FiredEvent, FiredTrigger, List[ConfigEntryEvent]], None, None]: - """ - An iterator that waits for events of a certain type. - """ - while True: - yield cast( - Union[ - FiredEvent, FiredTrigger, List[ConfigEntryEvent] - ], # we can cast this because TemplateEvent is only used for rendering templates - cast(EventResponse, self.recv(subscription_id)).event, - ) - - def _unsubscribe(self, subcription_id: int) -> None: - """ - Unsubscribe from all events of a certain type. - - Sends command :code:`{"type": "unsubscribe_events", ...}`. - """ - resp = self.recv(self.send("unsubscribe_events", subscription=subcription_id)) - assert cast(ResultResponse, resp).result is None - self._event_responses.pop(subcription_id) - - def fire_event(self, event_type: str, **event_data) -> Context: - """ - Fire an event. + raise ValueError(f"Unknown scheme {parsed.scheme} in {api_url}") - Sends command :code:`{"type": "fire_event", ...}`. - """ - params: dict[str, JSONType] = {"event_type": event_type} - if event_data: - params["event_data"] = event_data - return Context.from_json( - cast( - dict[str, dict[str, JSONType]], - cast( - ResultResponse, - self.recv(self.send("fire_event", include_id=True, **params)), - ).result, - )["context"] + logger.debug( + f"{client_type}WebSocketClient initialized with api_url: {api_url}" ) diff --git a/tests/conftest.py b/tests/conftest.py index 16c26a7..78ef237 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -57,3 +57,16 @@ def setup_websocket_client( os.environ["HOMEASSISTANTAPI_TOKEN"], ) as client: yield client + + +@pytest.fixture(name="async_websocket_client", scope="session") +async def setup_async_websocket_client( + wait_for_server: Literal[None], +) -> AsyncGenerator[WebsocketClient, None]: + """Initializes the Client and enters an async WebSocket session.""" + async with WebsocketClient( + os.environ["HOMEASSISTANTAPI_WS_URL"], + os.environ["HOMEASSISTANTAPI_TOKEN"], + use_async=True, + ) as client: + yield client diff --git a/tests/test_client.py b/tests/test_client.py index 22c3a6d..6a11939 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -55,3 +55,12 @@ def test_websocket_client_ping() -> None: os.environ["HOMEASSISTANTAPI_TOKEN"], ) as client: assert client.ping_latency() > 0 + + +async def test_async_websocket_client_ping() -> None: + async with WebsocketClient( + os.environ["HOMEASSISTANTAPI_WS_URL"], + os.environ["HOMEASSISTANTAPI_TOKEN"], + use_async=True, + ) as client: + assert (await client.async_ping_latency()) > 0 diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index f712ac9..0745600 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -129,6 +129,19 @@ def test_websocket_get_rendered_template(websocket_client: WebsocketClient) -> N } +async def test_async_websocket_get_rendered_template( + async_websocket_client: WebsocketClient, +) -> None: + """Tests the `"type": "render_template"` websocket command.""" + rendered_template = await async_websocket_client.async_get_rendered_template( + 'The sun is {{ states("sun.sun").replace("_", " the ") }}.' + ) + assert rendered_template in { + "The sun is above the horizon.", + "The sun is below the horizon.", + } + + def test_check_api_config(cached_client: Client) -> None: """Tests the `POST /api/config/core/check_config` endpoint.""" assert cached_client.check_api_config() @@ -158,6 +171,15 @@ def test_websocket_get_config(websocket_client: WebsocketClient) -> None: assert config.get("state") in {"RUNNING", "NOT_RUNNING"} +async def test_async_websocket_get_config( + async_websocket_client: WebsocketClient, +) -> None: + """Tests the `"type": "get_config"` websocket command.""" + config = await async_websocket_client.async_get_config() + assert isinstance(config, dict) + assert config.get("state") in {"RUNNING", "NOT_RUNNING"} + + def test_websocket_get_state(websocket_client: WebsocketClient) -> None: """Tests WebsocketClient.get_state with entity_id.""" state = websocket_client.get_state(entity_id="sun.sun") @@ -187,6 +209,53 @@ def test_websocket_get_entity_no_args(websocket_client: WebsocketClient) -> None websocket_client.get_entity() +async def test_async_websocket_get_state( + async_websocket_client: WebsocketClient, +) -> None: + """Tests async WebsocketClient.async_get_state with entity_id.""" + state = await async_websocket_client.async_get_state(entity_id="sun.sun") + assert state.entity_id == "sun.sun" + assert state.state in {"above_horizon", "below_horizon"} + + +async def test_async_websocket_get_entity_by_group_slug( + async_websocket_client: WebsocketClient, +) -> None: + """Tests async WebsocketClient.async_get_entity with group_id and slug.""" + entity = await async_websocket_client.async_get_entity(group_id="sun", slug="sun") + assert entity is not None + assert entity.entity_id == "sun.sun" + + +async def test_async_websocket_get_entity_by_entity_id( + async_websocket_client: WebsocketClient, +) -> None: + """Tests async WebsocketClient.async_get_entity with entity_id.""" + entity = await async_websocket_client.async_get_entity(entity_id="sun.sun") + assert entity is not None + assert entity.entity_id == "sun.sun" + + +async def test_async_websocket_get_entity_no_args( + async_websocket_client: WebsocketClient, +) -> None: + """Tests async WebsocketClient.async_get_entity raises ValueError with no arguments.""" + with pytest.raises( + ValueError, match="Neither group_id and slug or entity_id provided" + ): + await async_websocket_client.async_get_entity() + + +async def test_async_websocket_get_state_not_found( + async_websocket_client: WebsocketClient, +) -> None: + """Tests async WebsocketClient.async_get_state raises ValueError for nonexistent entity.""" + with pytest.raises(ValueError, match="not found"): + await async_websocket_client.async_get_state( + entity_id="fake.nonexistent_entity_12345" + ) + + def test_websocket_get_state_not_found(websocket_client: WebsocketClient) -> None: """Tests WebsocketClient.get_state raises ValueError for nonexistent entity.""" with pytest.raises(ValueError, match="not found"): @@ -199,6 +268,14 @@ def test_websocket_get_entities(websocket_client: WebsocketClient) -> None: assert "sun" in entities +async def test_async_websocket_get_entities( + async_websocket_client: WebsocketClient, +) -> None: + """Tests the `"type": "get_entities"` websocket command.""" + entities = await async_websocket_client.async_get_entities() + assert "sun" in entities + + def test_get_domains(cached_client: Client) -> None: """Tests the `GET /api/services` endpoint.""" domains = cached_client.get_domains() @@ -217,6 +294,14 @@ def test_websocket_get_domains(websocket_client: WebsocketClient) -> None: assert "homeassistant" in domains +async def test_async_websocket_get_domains( + async_websocket_client: WebsocketClient, +) -> None: + """Tests the `"type": "get_domains"` websocket command.""" + domains = await async_websocket_client.async_get_domains() + assert "homeassistant" in domains + + def test_get_domain(cached_client: Client) -> None: """Tests the `GET /api/services` endpoint.""" domain = cached_client.get_domain("homeassistant") @@ -238,6 +323,15 @@ def test_websocket_get_domain(websocket_client: WebsocketClient) -> None: assert domain.services +async def test_async_websocket_get_domain( + async_websocket_client: WebsocketClient, +) -> None: + """Tests the `"type": "get_domain"` websocket command.""" + domain = await async_websocket_client.async_get_domain("homeassistant") + assert domain is not None + assert domain.services + + def test_get_nonuser_flows_in_progress(websocket_client: WebsocketClient) -> None: """Tests the `"type": "config_entries/flow/progress"` websocket command.""" # No flows in progress @@ -245,6 +339,14 @@ def test_get_nonuser_flows_in_progress(websocket_client: WebsocketClient) -> Non assert not flows +async def test_async_get_nonuser_flows_in_progress( + async_websocket_client: WebsocketClient, +) -> None: + """Tests the `"type": "config_entries/flow/progress"` websocket command.""" + flows = await async_websocket_client.async_get_nonuser_flows_in_progress() + assert not flows + + def test_disable_enable_config_entry(websocket_client: WebsocketClient) -> None: """Tests the `"type": "config_entries/disable"` websocket command.""" # Get sun entry @@ -261,21 +363,42 @@ def test_disable_enable_config_entry(websocket_client: WebsocketClient) -> None: # Re-enable websocket_client.enable_config_entry(entry.entry_id) - # Check that it was enable + # Check that it was enabled enabled_entry = websocket_client.get_config_entries()[0] assert enabled_entry.disabled_by is None +async def test_async_disable_enable_config_entry( + async_websocket_client: WebsocketClient, +) -> None: + """Tests the `"type": "config_entries/disable"` websocket command.""" + entry = (await async_websocket_client.async_get_config_entries())[0] + assert entry.disabled_by is None + + await async_websocket_client.async_disable_config_entry(entry.entry_id) + + disabled_entry = (await async_websocket_client.async_get_config_entries())[0] + assert disabled_entry.disabled_by is ConfigEntryDisabler.USER + + await async_websocket_client.async_enable_config_entry(entry.entry_id) + + enabled_entry = (await async_websocket_client.async_get_config_entries())[0] + assert enabled_entry.disabled_by is None + + def test_ignore_config_flow(websocket_client: WebsocketClient) -> None: """Tests the `"type": "config_entries/ignore_flow"` websocket command.""" # Currently not able to test as no flows are in progress. Send invalid parameters and handle that error - try: + with pytest.raises(RequestError, match="Config entry not found"): websocket_client.ignore_config_flow("", "") - except RequestError as error: - assert ( - error.__str__() - == "An error occurred while making the request to 'Config entry not found' with data: 'not_found'" - ) + + +async def test_async_ignore_config_flow( + async_websocket_client: WebsocketClient, +) -> None: + """Tests the `"type": "config_entries/ignore_flow"` websocket command.""" + with pytest.raises(RequestError, match="Config entry not found"): + await async_websocket_client.async_ignore_config_flow("", "") def test_get_config_entries(websocket_client: WebsocketClient) -> None: @@ -307,6 +430,20 @@ def test_get_config_entries(websocket_client: WebsocketClient) -> None: assert sun.num_subentries == 0 +async def test_async_get_config_entries( + async_websocket_client: WebsocketClient, +) -> None: + """Tests the `"type": "config_entries/get"` websocket command.""" + entries = await async_websocket_client.async_get_config_entries() + assert len(entries) == 4 + + sun = entries[0] + assert sun.entry_id == "5f8426fa502435857743f302651753c9" + assert sun.domain == "sun" + assert sun.title == "Sun" + assert sun.disabled_by is None + + def test_get_entry_subentries(websocket_client: WebsocketClient) -> None: """Tests the `"type": "config_entries/subentries/list"` websocket command.""" # Currently not able to test as no entries with subentries available @@ -318,16 +455,28 @@ def test_get_entry_subentries(websocket_client: WebsocketClient) -> None: assert not websocket_client.get_entry_subentries(sun.entry_id) +async def test_async_get_entry_subentries( + async_websocket_client: WebsocketClient, +) -> None: + """Tests the `"type": "config_entries/subentries/list"` websocket command.""" + sun = (await async_websocket_client.async_get_config_entries())[0] + assert sun + assert not await async_websocket_client.async_get_entry_subentries(sun.entry_id) + + def test_delete_entry_subentry(websocket_client: WebsocketClient) -> None: """Tests the `"type": "config_entries/subentries/delete"` websocket command.""" # Currently not able to test as no entries with subentries available. Send invalid parameters and handle that error - try: + with pytest.raises(RequestError, match="Config entry not found"): websocket_client.delete_entry_subentry("", "") - except RequestError as error: - assert ( - error.__str__() - == "An error occurred while making the request to 'Config entry not found' with data: 'not_found'" - ) + + +async def test_async_delete_entry_subentry( + async_websocket_client: WebsocketClient, +) -> None: + """Tests the `"type": "config_entries/subentries/delete"` websocket command.""" + with pytest.raises(RequestError, match="Config entry not found"): + await async_websocket_client.async_delete_entry_subentry("", "") def test_trigger_service(cached_client: Client) -> None: @@ -364,6 +513,19 @@ def test_websocket_trigger_service(websocket_client: WebsocketClient) -> None: assert resp is None +async def test_async_websocket_trigger_service( + async_websocket_client: WebsocketClient, +) -> None: + """Tests the `"type": "trigger_service"` websocket command.""" + notify = await async_websocket_client.async_get_domain("notify") + assert notify is not None + resp = await notify.persistent_notification( + message="Your API Test Suite just said hello!", title="Test Suite Notifcation" + ) + # Websocket API doesnt return changed states so we check for None + assert resp is None + + def test_websocket_trigger_service_with_entity_id( websocket_client: WebsocketClient, ) -> None: @@ -416,6 +578,20 @@ def test_websocket_trigger_service_with_response( assert data is not None +async def test_async_websocket_trigger_service_with_response( + async_websocket_client: WebsocketClient, +) -> None: + """Tests the `"type": "trigger_service_with_response"` websocket command.""" + weather = await async_websocket_client.async_get_domain("weather") + assert weather is not None + data = await weather.get_forecasts( + entity_id="weather.forecast_home", + type="hourly", + ) + # Websocket API doesnt return changed states so we check data is not None because we expect a response + assert data is not None + + def test_get_states(cached_client: Client) -> None: """Tests the `GET /api/states` endpoint.""" states = cached_client.get_states() @@ -437,6 +613,15 @@ def test_websocket_get_states(websocket_client: WebsocketClient) -> None: assert isinstance(state, State) +async def test_async_websocket_get_states( + async_websocket_client: WebsocketClient, +) -> None: + """Tests the `"type": "get_states"` websocket command.""" + states = await async_websocket_client.async_get_states() + for state in states: + assert isinstance(state, State) + + def test_get_state(cached_client: Client) -> None: """Tests the `GET /api/states/` endpoint.""" state = cached_client.get_state(entity_id="sun.sun") diff --git a/tests/test_errors.py b/tests/test_errors.py index f892a01..471716d 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -45,6 +45,16 @@ def test_websocket_unauthorized() -> None: pass +async def test_async_websocket_unauthorized() -> None: + with pytest.raises(UnauthorizedError): + async with WebsocketClient( + os.environ["HOMEASSISTANTAPI_WS_URL"], + "lolthisisawrongtokenforsure", + use_async=True, + ): + pass + + async def test_async_unauthorized() -> None: with pytest.raises(UnauthorizedError): async with Client( diff --git a/tests/test_events.py b/tests/test_events.py index 458e334..a8bf6aa 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -15,10 +15,24 @@ def test_listen_events(websocket_client: WebsocketClient) -> None: websocket_client.fire_event( "test_event", message="Triggered by websocket client" ) - for _, event in zip(range(1), events): + for event in events: assert event.origin == "LOCAL" assert event.event_type == "test_event" assert event.data["message"] == "Triggered by websocket client" + break + + +async def test_async_listen_events(async_websocket_client: WebsocketClient) -> None: + async with async_websocket_client.async_listen_events("async_test_event") as events: + await async_websocket_client.async_fire_event( + "async_test_event", message="Triggered by async websocket client" + ) + # Typing breaks when using zip in an async context, so break instead + async for event in events: + assert event.origin == "LOCAL" + assert event.event_type == "async_test_event" + assert event.data["message"] == "Triggered by async websocket client" + break def test_listen_trigger(websocket_client: WebsocketClient) -> None: @@ -28,11 +42,12 @@ def test_listen_trigger(websocket_client: WebsocketClient) -> None: with websocket_client.listen_trigger( "time", at=future.strftime("%H:%M:%S") ) as triggers: - for _, trigger in zip(range(1), triggers): + for trigger in triggers: assert trigger["trigger"]["platform"] == "time" assert datetime.fromisoformat( trigger["trigger"]["now"] ).timestamp() == pytest.approx(future.timestamp(), abs=1) + break def test_listen_config_entries(websocket_client: WebsocketClient) -> None: @@ -70,3 +85,67 @@ def test_listen_config_entries(websocket_client: WebsocketClient) -> None: assert flow[0].type == ConfigEntryChange.UPDATED assert flow[0].entry.disabled_by is None assert flow[0].entry.state == ConfigEntryState.LOADED + + +async def test_async_listen_config_entries( + async_websocket_client: WebsocketClient, +) -> None: + async with async_websocket_client.async_listen_config_entries() as flows: + i = 0 + async for flow in flows: + if i == 0: + # The first "events" are currently available entries + assert flow[0].type is None + assert flow[0].entry.disabled_by is None + assert flow[0].entry.state == ConfigEntryState.LOADED + + # Trigger an "updated" event + await async_websocket_client.async_disable_config_entry( + flow[0].entry.entry_id + ) + + if i == 1: + assert flow[0].type == ConfigEntryChange.UPDATED + assert flow[0].entry.disabled_by == ConfigEntryDisabler.USER + assert flow[0].entry.state == ConfigEntryState.UNLOAD_IN_PROGRESS + + if i == 2: + assert flow[0].type == ConfigEntryChange.UPDATED + assert flow[0].entry.disabled_by == ConfigEntryDisabler.USER + assert flow[0].entry.state == ConfigEntryState.NOT_LOADED + + # Restore original state + await async_websocket_client.async_enable_config_entry( + flow[0].entry.entry_id + ) + + if i == 3: + assert flow[0].type == ConfigEntryChange.UPDATED + assert flow[0].entry.disabled_by is None + assert flow[0].entry.state == ConfigEntryState.SETUP_IN_PROGRESS + + if i == 4: + assert flow[0].type == ConfigEntryChange.UPDATED + assert flow[0].entry.disabled_by is None + assert flow[0].entry.state == ConfigEntryState.LOADED + break + + i += 1 + + +async def test_async_listen_trigger(async_websocket_client: WebsocketClient) -> None: + future = datetime.fromisoformat( + await async_websocket_client.async_get_rendered_template( + "{{ (now() + timedelta(seconds=1)) }}" + ) + ) + async with async_websocket_client.async_listen_trigger( + "time", at=future.strftime("%H:%M:%S") + ) as triggers: + # Typing breaks when using zip in an async context, so break instead + async for trigger in triggers: + assert trigger["trigger"]["platform"] == "time" + assert datetime.fromisoformat( + trigger["trigger"]["now"] + ).timestamp() == pytest.approx(future.timestamp(), abs=1) + break diff --git a/tests/test_websocket.py b/tests/test_websocket.py index ccd0757..e86f4d0 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -1,8 +1,9 @@ -"""Unit tests for RawWebsocketClient and WebsocketClient error paths.""" +"""Unit tests for RawWebsocketClient, RawAsyncWebsocketClient, and WebsocketClient error paths.""" import pytest from homeassistant_api.errors import ReceivingError, RequestError, ResponseError +from homeassistant_api.rawasyncwebsocket import RawAsyncWebsocketClient from homeassistant_api.rawwebsocket import RawWebsocketClient from homeassistant_api.models import websocket as ws_models @@ -12,6 +13,11 @@ def make_raw_client() -> RawWebsocketClient: return RawWebsocketClient("ws://localhost:8123/api/websocket", "fake_token") +def make_raw_async_client() -> RawAsyncWebsocketClient: + """Create a RawAsyncWebsocketClient without connecting.""" + return RawAsyncWebsocketClient("ws://localhost:8123/api/websocket", "fake_token") + + def test_exit_without_connection() -> None: """Tests __exit__ raises ReceivingError when connection is not open.""" client = make_raw_client() @@ -98,3 +104,70 @@ def raise_runtime_error(*args, **kwargs): ResponseError, match="Unexpected response during authentication" ): client.authentication_phase() + + +async def test_async_aexit_without_connection() -> None: + """Tests __aexit__ raises ReceivingError when connection is not open.""" + client = make_raw_async_client() + with pytest.raises(ReceivingError, match="Connection is not open"): + await client.__aexit__(None, None, None) + + +async def test_async_send_without_connection() -> None: + """Tests _async_send raises ReceivingError when connection is not open.""" + client = make_raw_async_client() + with pytest.raises(ReceivingError, match="Connection is not open"): + await client._async_send({"type": "test"}) + + +async def test_async_recv_without_connection() -> None: + """Tests _async_recv raises ReceivingError when connection is not open.""" + client = make_raw_async_client() + with pytest.raises(ReceivingError, match="Connection is not open"): + await client._async_recv() + + +async def test_async_authentication_phase_invalid_welcome(monkeypatch) -> None: + """Tests async_authentication_phase raises ResponseError on invalid welcome message.""" + client = make_raw_async_client() + + async def fake_recv(): + return {"type": "not_auth_required"} + + monkeypatch.setattr(client, "_async_recv", fake_recv) + with pytest.raises( + ResponseError, match="Unexpected response during authentication" + ): + await client.async_authentication_phase() + + +async def test_async_authentication_phase_unexpected_auth_response( + monkeypatch, +) -> None: + """Tests async_authentication_phase raises ResponseError when AuthOk.model_validate raises a non-ValidationError.""" + call_count = 0 + + async def fake_recv(): + nonlocal call_count + call_count += 1 + if call_count == 1: + return {"type": "auth_required", "ha_version": "2024.1.0"} + return {"type": "auth_ok", "ha_version": "2024.1.0", "message": "unexpected"} + + client = make_raw_async_client() + monkeypatch.setattr(client, "_async_recv", fake_recv) + + async def fake_send(data): + pass + + monkeypatch.setattr(client, "_async_send", fake_send) + + def raise_runtime_error(*args, **kwargs): + raise RuntimeError("something went wrong") + + monkeypatch.setattr(ws_models.AuthOk, "model_validate", raise_runtime_error) + + with pytest.raises( + ResponseError, match="Unexpected response during authentication" + ): + await client.async_authentication_phase()