diff --git a/changelog.md b/changelog.md index 3891291fa..272dd07bd 100644 --- a/changelog.md +++ b/changelog.md @@ -7,7 +7,8 @@ * None. ### Enhancements -* None. +* `QueryNetworkStateClient.reportBatchStatus` can be used to send status responses for batches returned from the service via + `QueryNetworkStateClient.getCurrentStates`. ### Fixes * Specify typing_extensions as a dependency to fix support for Python 3.9 and 3.10 diff --git a/docs/docs/query-network-state-client.mdx b/docs/docs/query-network-state-client.mdx index 1c329e661..676ef50f4 100644 --- a/docs/docs/query-network-state-client.mdx +++ b/docs/docs/query-network-state-client.mdx @@ -10,8 +10,9 @@ The `QueryNetworkStateClient` will allow you to interact with a server running t wrapper for the gRPC library, with the ability to retrieve information about the state of the network. This is done with the following 3 steps: 1. Create a gRPC connection to the server. -1. Create an instance of the `QueryNetworkStateClient` using your gRPC connection. -2. Use your `QueryNetworkStateClient` to retrieve the state of the network. +2. Create an instance of the `QueryNetworkStateClient` using your gRPC connection. +3. Use your `QueryNetworkStateClient` to retrieve the state of the network. +4. Use your `QueryNetworkStateClient` to report the status of applying the state of the network. ## Creating a gRPC channel @@ -43,10 +44,17 @@ Now that you have a client, you can use it to query the state of the network on The current state of the network between two date/times can be retrieved using the `get_current_states` function on the `QueryNetworkStateClient`. - ```python from datetime import datetime, timedelta async for events in client.get_current_states(1, datetime.now() - timedelta(days=1), datetime.now()): # process the list of events here. ``` + +### Sending current network state statuses + +When applying the current state of the network, you should send a status response to report how the update went. + +```python +client.report_batch_status(BatchSuccessful(1)) +``` diff --git a/docs/docs/query-network-state-service.mdx b/docs/docs/query-network-state-service.mdx index a08099d69..88cc811bf 100644 --- a/docs/docs/query-network-state-service.mdx +++ b/docs/docs/query-network-state-service.mdx @@ -26,21 +26,44 @@ current state events that occurred between those date/times (inclusive) from datetime import datetime from typing import AsyncGenerator, Iterable -from zepben.evolve import CurrentStateEvent +from zepben.evolve import CurrentStateEventBatch -async def on_get_current_states(from_datetime: datetime, to_datetime: datetime) -> AsyncGenerator[Iterable[CurrentStateEvent], None]: +async def on_get_current_states(from_datetime: datetime, to_datetime: datetime) -> AsyncGenerator[CurrentStateEventBatch, None]: events = [] # build the batch of events yield events ``` +### onCurrentStatesStatus + +The `onCurrentStatesStatus` callback is triggered for each status response sent by the client. You should expect to receive one of these for every batch +returned from `onGetCurrentStates`. + +```python +from zepben.evolve import SetCurrentStatesStatus + + +def on_current_states_status(event_status: SetCurrentStatesStatus): + # Do something with the `eventStatus`. +``` + +### onProcessingError + +The `onProcessingError` callback is triggered for any errors in your `onCurrentStatesStatus` callback, or if any [SetCurrentStatesResponse] is for an unknown +event status. + +```python +def on_processing_error(error: Exception): + # Do something with the `error`. +``` + ## Registering callbacks Registering the callbacks with the service is as simple as passing them into the `QueryNetworkStateService` constructor. ```python -service = QueryNetworkStateService(on_get_current_states) +service = QueryNetworkStateService(on_get_current_states, on_current_states_status, on_processing_error) ``` ## Registering the service @@ -73,18 +96,24 @@ from typing import AsyncGenerator, Iterable import grpc from zepben.protobuf.ns.network_state_pb2_grpc import add_QueryNetworkStateServiceServicer_to_server -from zepben.evolve import CurrentStateEvent, QueryNetworkStateService +from zepben.evolve import CurrentStateEventBatch, QueryNetworkStateService, SetCurrentStatesStatus class QueryNetworkStateServiceImpl: def __init__(self): - self.service = QueryNetworkStateService(self.on_get_current_states) + self.service = QueryNetworkStateService(self.on_get_current_states, self.on_current_states_status, self.on_processing_error) - async def on_get_current_states(self, from_datetime: datetime, to_datetime: datetime) -> AsyncGenerator[Iterable[CurrentStateEvent], None]: + async def on_get_current_states(self, from_datetime: datetime, to_datetime: datetime) -> AsyncGenerator[CurrentStateEventBatch, None]: events = [] # build the batch of events yield events + def on_current_states_status(event_status: SetCurrentStatesStatus): + # Do something with the `eventStatus`. + + def on_processing_error(error: Exception): + # Do something with the `error`. + async def main(): server = grpc.aio.server() host = 'localhost:50051' diff --git a/src/zepben/evolve/streaming/get/query_network_state_client.py b/src/zepben/evolve/streaming/get/query_network_state_client.py index 9dad8ae3e..0f29fbeee 100644 --- a/src/zepben/evolve/streaming/get/query_network_state_client.py +++ b/src/zepben/evolve/streaming/get/query_network_state_client.py @@ -12,6 +12,7 @@ from zepben.evolve.streaming.data.current_state_event import CurrentStateEvent from zepben.evolve.streaming.data.current_state_event_batch import CurrentStateEventBatch +from zepben.evolve.streaming.data.set_current_states_status import SetCurrentStatesStatus from zepben.evolve.streaming.grpc.grpc import GrpcClient from zepben.evolve.util import datetime_to_timestamp @@ -53,3 +54,10 @@ async def get_current_states(self, query_id: int, from_datetime: datetime, to_da ) async for response in self._stub.getCurrentStates(request): yield CurrentStateEventBatch(response.messageId, [CurrentStateEvent.from_pb(event) for event in response.event]) + + def report_batch_status(self, status: SetCurrentStatesStatus): + """ + Send a response to a previous `getCurrentStates` request to let the server know how we handled its response. + :param status: The batch status to report. + """ + self._stub.reportBatchStatus(iter([status.to_pb()])) diff --git a/test/streaming/get/mock_server.py b/test/streaming/get/mock_server.py index 554d74b28..7230b7ae7 100644 --- a/test/streaming/get/mock_server.py +++ b/test/streaming/get/mock_server.py @@ -22,9 +22,13 @@ class StreamGrpc: function: str processors: List[Callable[[GrpcRequest], Generator[GrpcResponse, None, None]]] """ - The processors to run in order for this function StreamGrpc. - For example if you expect getIdentifiedObjects to be called twice in a row, you could provide two processors for getIdentifiedObjects here, rather - than two separate StreamGrpcs + Stream of requests and matching stream of responses. + + The processors to run in order for this function StreamGrpc. There should be an entry in this list for every request sent + via the same stream, with the expected responses to that request. + + If you expect multiple requests to be made by subsequent calls to a new stream, they should be specified by two separate + StreamGrpc instances, rather than providing two processors here. """ force_timeout: bool = False @@ -34,6 +38,37 @@ class StreamGrpc: class UnaryGrpc: function: str processor: Callable[[GrpcRequest], Generator[GrpcResponse, None, None]] + """ + Unary request and matching unary response. + + Will cause errors in the processing if the generator produces more than one response. + """ + force_timeout: bool = False + + +@dataclass +class StreamUnaryGrpc: + function: str + request_validators: List[Callable[[GrpcRequest], None]] + """ + Stream of requests. + + Requires one per request sent via the same stream. For multiple requests on different streams use multiple StreamUnaryGrpc instances. + """ + response: GrpcResponse + """ + Unary response to send after request stream is closed. + """ + force_timeout: bool = False + + +@dataclass +class UnaryStreamGrpc: + function: str + processor: Callable[[GrpcRequest], Generator[GrpcResponse, None, None]] + """ + Unary request and matching stream of responses. + """ force_timeout: bool = False @@ -64,17 +99,23 @@ def __init__(self, channel: grpc_testing.Channel, grpc_service: ServiceDescripto self.channel: grpc_testing.Channel = channel self.grpc_service: ServiceDescriptor = grpc_service - async def validate(self, client_test: Callable[[], Awaitable[None]], interactions: List[Union[StreamGrpc, UnaryGrpc]]): + async def validate(self, client_test: Callable[[], Awaitable[None]], interactions: List[Union[StreamGrpc, UnaryGrpc, StreamUnaryGrpc, UnaryStreamGrpc]]): """ Run a server that mocks RPC requests by invoking the provided `interactions` in order. :param client_test: The test code to call. :param interactions: An ordered list of interactions expected for this server. """ + # Run the server logic in another thread. + # noinspection PyTypeChecker server = CatchingThread(target=self._run_server_logic, args=[interactions]) server.start() + # Send the client requests. await client_test() + + # Wait for the server to finish. If this times out your test, it indicates that not all expected requests were received, or the request stream + # wasn't closed/completed. server.join() if server.exception: @@ -86,37 +127,94 @@ def _run_server_logic(self, interactions: List[Union[StreamGrpc, UnaryGrpc]]): self._run_stream_server_logic(i) elif isinstance(i, UnaryGrpc): self._run_unary_server_logic(i) + elif isinstance(i, StreamUnaryGrpc): + self._run_stream_unary_server_logic(i) + elif isinstance(i, UnaryStreamGrpc): + self._run_unary_stream_server_logic(i) else: raise NotImplementedError(f"No server logic has been configured for {type(i)}") def _run_stream_server_logic(self, interaction: StreamGrpc): - for processor in interaction.processors: - _, rpc = self.channel.take_stream_stream(self.grpc_service.methods_by_name[interaction.function]) - rpc.send_initial_metadata(()) + # Take the expected StreamStream message. + _, rpc = self.channel.take_stream_stream(self.grpc_service.methods_by_name[interaction.function]) + rpc.send_initial_metadata(()) - try: + try: + for processor in interaction.processors: + # Take a request from the stream. request = rpc.take_request() if interaction.force_timeout: - rpc.terminate(None, (), grpc.StatusCode.DEADLINE_EXCEEDED, '') + rpc.terminate((), grpc.StatusCode.DEADLINE_EXCEEDED, '') return + # Send each yielded item to the response stream. for response in processor(request): rpc.send_response(response) - rpc.requests_closed() - finally: - rpc.terminate((), grpc.StatusCode.OK, '') + # Ensure the requests stream is closed. + rpc.requests_closed() + finally: + # Terminate the message regardless of exceptions or completion. + rpc.terminate((), grpc.StatusCode.OK, '') def _run_unary_server_logic(self, interaction: UnaryGrpc): + # Take the expected UnaryUnary message, which includes the request. _, request, rpc = self.channel.take_unary_unary(self.grpc_service.methods_by_name[interaction.function]) rpc.send_initial_metadata(()) + if interaction.force_timeout: rpc.terminate(None, (), grpc.StatusCode.DEADLINE_EXCEEDED, '') return try: + # + # NOTE: Although this is looping, the processor should only yield a single entry (it is Unary after all). Yielding multiple items will + # cause multiple calls to `terminate` which will cause errors. + # for response in interaction.processor(request): rpc.terminate(response, (), grpc.StatusCode.OK, '') except Exception as e: + # If there were errors in the processor, send a blank response. rpc.terminate((), (), grpc.StatusCode.OK, '') raise e + + def _run_stream_unary_server_logic(self, interaction: StreamUnaryGrpc): + _, rpc = self.channel.take_stream_unary(self.grpc_service.methods_by_name[interaction.function]) + rpc.send_initial_metadata(()) + + try: + for validator in interaction.request_validators: + request = rpc.take_request() + if interaction.force_timeout: + rpc.terminate(None, (), grpc.StatusCode.DEADLINE_EXCEEDED, '') + return + + validator(request) + + rpc.requests_closed() + + # NOTE: We don't want to move this to a finally block, otherwise it will send a second terminate if it times out above. + rpc.terminate(interaction.response, (), grpc.StatusCode.OK, '') + except Exception as e: + rpc.terminate(interaction.response, (), grpc.StatusCode.OK, '') + raise e + + def _run_unary_stream_server_logic(self, interaction: UnaryStreamGrpc): + # Take the expected UnaryStream message, which includes the request. + _, request, rpc = self.channel.take_unary_stream(self.grpc_service.methods_by_name[interaction.function]) + rpc.send_initial_metadata(()) + + try: + if interaction.force_timeout: + rpc.terminate((), grpc.StatusCode.DEADLINE_EXCEEDED, '') + return + + # Send each yielded item to the response stream. + for response in interaction.processor(request): + rpc.send_response(response) + + # Ensure the requests stream is closed. + rpc.requests_closed() + finally: + # Terminate the message regardless of exceptions or completion. + rpc.terminate((), grpc.StatusCode.OK, '') diff --git a/test/streaming/get/test_network_consumer.py b/test/streaming/get/test_network_consumer.py index efa5bb968..4a831f22b 100644 --- a/test/streaming/get/test_network_consumer.py +++ b/test/streaming/get/test_network_consumer.py @@ -2,20 +2,15 @@ # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at https://mozilla.org/MPL/2.0/. +from time import sleep from typing import Dict, Iterable, TypeVar, Generator, Callable, Optional from unittest.mock import MagicMock import grpc_testing import pytest -# noinspection PyPackageRequirements +# noinspection PyPackageRequirements,PyUnresolvedReferences from google.protobuf.any_pb2 import Any from hypothesis import given, settings, Phase - -from zepben.evolve import NetworkConsumerClient, NetworkService, IdentifiedObject, CableInfo, AcLineSegment, Breaker, EnergySource, \ - EnergySourcePhase, Junction, PowerTransformer, PowerTransformerEnd, ConnectivityNode, Feeder, Location, OverheadWireInfo, PerLengthSequenceImpedance, \ - Substation, Terminal, EquipmentContainer, Equipment, BaseService, OperationalRestriction, TransformerStarImpedance, GeographicalRegion, \ - SubGeographicalRegion, Circuit, Loop, Diagram, UnsupportedOperationException, LvFeeder, TestNetworkBuilder, PerLengthPhaseImpedance, BatteryControl, \ - PanDemandResponseFunction, BatteryUnit, StaticVarCompensator from zepben.protobuf.nc import nc_pb2 from zepben.protobuf.nc.nc_data_pb2 import NetworkIdentifiedObject from zepben.protobuf.nc.nc_requests_pb2 import GetIdentifiedObjectsRequest, GetEquipmentForContainersRequest, \ @@ -24,13 +19,17 @@ from zepben.protobuf.nc.nc_responses_pb2 import GetIdentifiedObjectsResponse, GetEquipmentForContainersResponse, \ GetEquipmentForRestrictionResponse, GetTerminalsForNodeResponse, GetNetworkHierarchyResponse -from time import sleep -from streaming.get.data.metadata import create_metadata, create_metadata_response -from streaming.get.grpcio_aio_testing.mock_async_channel import async_testing_channel -from streaming.get.pb_creators import network_identified_objects, ac_line_segment from streaming.get.data.hierarchy import create_hierarchy_network from streaming.get.data.loops import create_loops_network +from streaming.get.data.metadata import create_metadata, create_metadata_response +from streaming.get.grpcio_aio_testing.mock_async_channel import async_testing_channel from streaming.get.mock_server import MockServer, StreamGrpc, UnaryGrpc, stream_from_fixed, unary_from_fixed +from streaming.get.pb_creators import network_identified_objects, ac_line_segment +from zepben.evolve import NetworkConsumerClient, NetworkService, IdentifiedObject, CableInfo, AcLineSegment, Breaker, EnergySource, \ + EnergySourcePhase, Junction, PowerTransformer, PowerTransformerEnd, ConnectivityNode, Feeder, Location, OverheadWireInfo, PerLengthSequenceImpedance, \ + Substation, Terminal, EquipmentContainer, Equipment, BaseService, OperationalRestriction, TransformerStarImpedance, GeographicalRegion, \ + SubGeographicalRegion, Circuit, Loop, Diagram, UnsupportedOperationException, LvFeeder, TestNetworkBuilder, PerLengthPhaseImpedance, BatteryControl, \ + PanDemandResponseFunction, BatteryUnit, StaticVarCompensator PBRequest = TypeVar('PBRequest') GrpcResponse = TypeVar('GrpcResponse') @@ -237,14 +236,16 @@ async def client_test(): [ UnaryGrpc('getNetworkHierarchy', unary_from_fixed(None, _create_hierarchy_response(feeder_network))), StreamGrpc('getEquipmentForContainers', [_create_container_responses(feeder_network)]), - StreamGrpc('getIdentifiedObjects', [object_responses, object_responses]) + StreamGrpc('getIdentifiedObjects', [object_responses]), + StreamGrpc('getIdentifiedObjects', [object_responses]) ]) @pytest.mark.asyncio async def test_resolve_references_skips_resolvers_referencing_equipment_containers(self): """ - lvf5:[ c1 { ] c3 }:lvf6 - lvf5:[tx0------{b2]------tx4}:lvf6 + lvf5:[------------] + tx0 --c1-- b2 --c3-- tx4 + lvf6: [------------] """ lv_feeders_with_open_point = (await TestNetworkBuilder() .from_power_transformer() # tx0 @@ -277,7 +278,8 @@ async def client_test(): UnaryGrpc('getNetworkHierarchy', unary_from_fixed(None, _create_hierarchy_response(lv_feeders_with_open_point))), StreamGrpc('getIdentifiedObjects', [object_responses]), StreamGrpc('getEquipmentForContainers', [_create_container_equipment_responses(lv_feeders_with_open_point)]), - StreamGrpc('getIdentifiedObjects', [object_responses, object_responses]) + StreamGrpc('getIdentifiedObjects', [object_responses]), + StreamGrpc('getIdentifiedObjects', [object_responses]) ]) @pytest.mark.asyncio @@ -319,7 +321,8 @@ async def client_test(): network_state=NetworkState.ALL_NETWORK_STATE ) ]), - StreamGrpc('getIdentifiedObjects', [object_responses, object_responses]) + StreamGrpc('getIdentifiedObjects', [object_responses]), + StreamGrpc('getIdentifiedObjects', [object_responses]) ]) @pytest.mark.asyncio @@ -357,7 +360,12 @@ async def client_test(): expected_include_energized_containers=include_energized_containers, network_state=network_state) - await self.mock_server.validate(client_test, [StreamGrpc('getEquipmentForContainers', [response, response])]) + await self.mock_server.validate( + client_test, + [ + StreamGrpc('getEquipmentForContainers', [response]), + StreamGrpc('getEquipmentForContainers', [response]) + ]) @pytest.mark.asyncio async def test_get_equipment_for_containers(self, feeder_network: NetworkService): @@ -414,7 +422,8 @@ async def client_test(): await self.mock_server.validate(client_test, [ UnaryGrpc('getNetworkHierarchy', unary_from_fixed(None, _create_hierarchy_response(ns))), - StreamGrpc('getEquipmentForContainers', [_create_container_equipment_responses(ns, loop_containers, network_state=network_state)]), + StreamGrpc('getEquipmentForContainers', + [_create_container_equipment_responses(ns, loop_containers, network_state=network_state)]), StreamGrpc('getIdentifiedObjects', [_create_object_responses(ns, assoc_objs)]) ]) @@ -436,7 +445,8 @@ async def client_test(): await self.mock_server.validate(client_test, [ UnaryGrpc('getNetworkHierarchy', unary_from_fixed(None, _create_hierarchy_response(ns))), - StreamGrpc('getEquipmentForContainers', [_create_container_equipment_responses(ns, loop_containers, network_state=network_state)]), + StreamGrpc('getEquipmentForContainers', + [_create_container_equipment_responses(ns, loop_containers, network_state=network_state)]), StreamGrpc('getIdentifiedObjects', [_create_object_responses(ns, assoc_objs)]) ]) @@ -613,6 +623,7 @@ def responses(request: GetTerminalsForNodeRequest) -> Generator[GetTerminalsForN return responses + # noinspection PyUnresolvedReferences def _create_hierarchy_response_with_sleep(service: NetworkService, sleep_time: int) -> GetNetworkHierarchyResponse: sleep(sleep_time) diff --git a/test/streaming/get/test_query_network_state_client.py b/test/streaming/get/test_query_network_state_client.py index 60140c0ae..fbcdb3b02 100644 --- a/test/streaming/get/test_query_network_state_client.py +++ b/test/streaming/get/test_query_network_state_client.py @@ -7,12 +7,13 @@ import grpc_testing import pytest +from google.protobuf import empty_pb2 from zepben.protobuf.ns import network_state_pb2 from zepben.protobuf.ns.network_state_responses_pb2 import GetCurrentStatesResponse from streaming.get.grpcio_aio_testing.mock_async_channel import async_testing_channel -from streaming.get.mock_server import MockServer, GrpcRequest, GrpcResponse, StreamGrpc -from zepben.evolve import PhaseCode, datetime_to_timestamp, SwitchStateEvent, SwitchAction, CurrentStateEventBatch, QueryNetworkStateClient +from streaming.get.mock_server import MockServer, GrpcRequest, GrpcResponse, StreamGrpc, StreamUnaryGrpc +from zepben.evolve import PhaseCode, datetime_to_timestamp, SwitchStateEvent, SwitchAction, CurrentStateEventBatch, QueryNetworkStateClient, BatchSuccessful def _current_state_batch_to_pb(batch: CurrentStateEventBatch) -> GetCurrentStatesResponse: @@ -33,7 +34,7 @@ def setup(self): ] @pytest.mark.asyncio - async def test_get_current_states(self): + async def test_get_current_states(self, caplog): query_id = 1 from_datetime = datetime.now() to_datetime = datetime.now() + timedelta(days=1) @@ -55,3 +56,19 @@ async def client_test(): assert results == self.batches await self.mock_server.validate(client_test, [StreamGrpc('getCurrentStates', mock_service())]) + + @pytest.mark.asyncio + async def test_can_report_batch_status(self): + status = BatchSuccessful(1234) + + def mock_service() -> List[Callable[[GrpcRequest], None]]: + def validate_request(request: GrpcRequest): + assert request.messageId == status.batch_id + + return [validate_request] + + async def client_test(): + self.client.report_batch_status(status) + + # noinspection PyUnresolvedReferences + await self.mock_server.validate(client_test, [StreamUnaryGrpc('reportBatchStatus', mock_service(), empty_pb2.Empty())])