Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
* None.

### Enhancements
* None.
* `QueryNetworkStateClient.reportBatchStatus` can be used to send status responses for batches returned from the service via
`QueryNetworkStateClient.getCurrentStates`.

### Fixes
* Specify typing_extensions as a dependency to fix support for Python 3.9 and 3.10
Expand Down
14 changes: 11 additions & 3 deletions docs/docs/query-network-state-client.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ The `QueryNetworkStateClient` will allow you to interact with a server running t
wrapper for the gRPC library, with the ability to retrieve information about the state of the network. This is done with the following 3 steps:

1. Create a gRPC connection to the server.
1. Create an instance of the `QueryNetworkStateClient` using your gRPC connection.
2. Use your `QueryNetworkStateClient` to retrieve the state of the network.
2. Create an instance of the `QueryNetworkStateClient` using your gRPC connection.
3. Use your `QueryNetworkStateClient` to retrieve the state of the network.
4. Use your `QueryNetworkStateClient` to report the status of applying the state of the network.

## Creating a gRPC channel

Expand Down Expand Up @@ -43,10 +44,17 @@ Now that you have a client, you can use it to query the state of the network on

The current state of the network between two date/times can be retrieved using the `get_current_states` function on the `QueryNetworkStateClient`.


```python
from datetime import datetime, timedelta

async for events in client.get_current_states(1, datetime.now() - timedelta(days=1), datetime.now()):
# process the list of events here.
```

### Sending current network state statuses

When applying the current state of the network, you should send a status response to report how the update went.

```python
client.report_batch_status(BatchSuccessful(1))
```
41 changes: 35 additions & 6 deletions docs/docs/query-network-state-service.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,44 @@ current state events that occurred between those date/times (inclusive)
from datetime import datetime
from typing import AsyncGenerator, Iterable

from zepben.evolve import CurrentStateEvent
from zepben.evolve import CurrentStateEventBatch


async def on_get_current_states(from_datetime: datetime, to_datetime: datetime) -> AsyncGenerator[Iterable[CurrentStateEvent], None]:
async def on_get_current_states(from_datetime: datetime, to_datetime: datetime) -> AsyncGenerator[CurrentStateEventBatch, None]:
events = []
# build the batch of events
yield events
```

### onCurrentStatesStatus

The `onCurrentStatesStatus` callback is triggered for each status response sent by the client. You should expect to receive one of these for every batch
returned from `onGetCurrentStates`.

```python
from zepben.evolve import SetCurrentStatesStatus


def on_current_states_status(event_status: SetCurrentStatesStatus):
    # Do something with the `event_status`.
```

### onProcessingError

The `onProcessingError` callback is triggered for any errors in your `onCurrentStatesStatus` callback, or if any `SetCurrentStatesResponse` is for an unknown
event status.

```python
def on_processing_error(error: Exception):
# Do something with the `error`.
```

## Registering callbacks

Registering the callbacks with the service is as simple as passing them into the `QueryNetworkStateService` constructor.

```python
service = QueryNetworkStateService(on_get_current_states)
service = QueryNetworkStateService(on_get_current_states, on_current_states_status, on_processing_error)
```

## Registering the service
Expand Down Expand Up @@ -73,18 +96,24 @@ from typing import AsyncGenerator, Iterable
import grpc
from zepben.protobuf.ns.network_state_pb2_grpc import add_QueryNetworkStateServiceServicer_to_server

from zepben.evolve import CurrentStateEvent, QueryNetworkStateService
from zepben.evolve import CurrentStateEventBatch, QueryNetworkStateService, SetCurrentStatesStatus

class QueryNetworkStateServiceImpl:

def __init__(self):
self.service = QueryNetworkStateService(self.on_get_current_states)
self.service = QueryNetworkStateService(self.on_get_current_states, self.on_current_states_status, self.on_processing_error)

async def on_get_current_states(self, from_datetime: datetime, to_datetime: datetime) -> AsyncGenerator[Iterable[CurrentStateEvent], None]:
async def on_get_current_states(self, from_datetime: datetime, to_datetime: datetime) -> AsyncGenerator[CurrentStateEventBatch, None]:
events = []
# build the batch of events
yield events

def on_current_states_status(self, event_status: SetCurrentStatesStatus):
    # Do something with the `event_status`.

def on_processing_error(self, error: Exception):
    # Do something with the `error`.

async def main():
server = grpc.aio.server()
host = 'localhost:50051'
Expand Down
8 changes: 8 additions & 0 deletions src/zepben/evolve/streaming/get/query_network_state_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from zepben.evolve.streaming.data.current_state_event import CurrentStateEvent
from zepben.evolve.streaming.data.current_state_event_batch import CurrentStateEventBatch
from zepben.evolve.streaming.data.set_current_states_status import SetCurrentStatesStatus
from zepben.evolve.streaming.grpc.grpc import GrpcClient
from zepben.evolve.util import datetime_to_timestamp

Expand Down Expand Up @@ -53,3 +54,10 @@ async def get_current_states(self, query_id: int, from_datetime: datetime, to_da
)
async for response in self._stub.getCurrentStates(request):
yield CurrentStateEventBatch(response.messageId, [CurrentStateEvent.from_pb(event) for event in response.event])

def report_batch_status(self, status: SetCurrentStatesStatus):
    """
    Report back to the server how we handled a batch returned from a previous `getCurrentStates` request.

    :param status: The batch status to send to the server.
    """
    # The stub expects a stream of responses, so wrap the single protobuf message in an iterator.
    pb_status = status.to_pb()
    self._stub.reportBatchStatus(iter((pb_status,)))
122 changes: 110 additions & 12 deletions test/streaming/get/mock_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@ class StreamGrpc:
function: str
processors: List[Callable[[GrpcRequest], Generator[GrpcResponse, None, None]]]
"""
The processors to run in order for this function StreamGrpc.
For example if you expect getIdentifiedObjects to be called twice in a row, you could provide two processors for getIdentifiedObjects here, rather
than two separate StreamGrpcs
Stream of requests and matching stream of responses.

The processors to run in order for this function StreamGrpc. There should be an entry in this list for every request sent
via the same stream, with the expected responses to that request.

If you expect multiple requests to be made by subsequent calls to a new stream, they should be specified by two separate
StreamGrpc instances, rather than providing two processors here.
"""

force_timeout: bool = False
Expand All @@ -34,6 +38,37 @@ class StreamGrpc:
class UnaryGrpc:
    # NOTE(review): presumably decorated with @dataclass like the sibling interaction classes — the decorator
    # line is not visible in this view; confirm against the full file.
    # The gRPC service method this interaction mocks (looked up via the service descriptor's methods_by_name).
    function: str
    processor: Callable[[GrpcRequest], Generator[GrpcResponse, None, None]]
    """
    Unary request and matching unary response.

    Will cause errors in the processing if the generator produces more than one response.
    """
    # When True the mocked server terminates the RPC with DEADLINE_EXCEEDED instead of producing a response.
    force_timeout: bool = False


@dataclass
class StreamUnaryGrpc:
    # The gRPC service method this interaction mocks (looked up via the service descriptor's methods_by_name).
    function: str
    request_validators: List[Callable[[GrpcRequest], None]]
    """
    Stream of requests.

    Requires one per request sent via the same stream. For multiple requests on different streams use multiple StreamUnaryGrpc instances.
    """
    response: GrpcResponse
    """
    Unary response to send after request stream is closed.
    """
    # When True the mocked server terminates the RPC with DEADLINE_EXCEEDED after taking the first request.
    force_timeout: bool = False


@dataclass
class UnaryStreamGrpc:
    # The gRPC service method this interaction mocks (looked up via the service descriptor's methods_by_name).
    function: str
    processor: Callable[[GrpcRequest], Generator[GrpcResponse, None, None]]
    """
    Unary request and matching stream of responses.
    """
    # When True the mocked server terminates the RPC with DEADLINE_EXCEEDED instead of streaming responses.
    force_timeout: bool = False


Expand Down Expand Up @@ -64,17 +99,23 @@ def __init__(self, channel: grpc_testing.Channel, grpc_service: ServiceDescripto
self.channel: grpc_testing.Channel = channel
self.grpc_service: ServiceDescriptor = grpc_service

async def validate(self, client_test: Callable[[], Awaitable[None]], interactions: List[Union[StreamGrpc, UnaryGrpc]]):
async def validate(self, client_test: Callable[[], Awaitable[None]], interactions: List[Union[StreamGrpc, UnaryGrpc, StreamUnaryGrpc, UnaryStreamGrpc]]):
"""
Run a server that mocks RPC requests by invoking the provided `interactions` in order.

:param client_test: The test code to call.
:param interactions: An ordered list of interactions expected for this server.
"""
# Run the server logic in another thread.
# noinspection PyTypeChecker
server = CatchingThread(target=self._run_server_logic, args=[interactions])
server.start()

# Send the client requests.
await client_test()

# Wait for the server to finish. If this times out your test, it indicates that not all expected requests were received, or the request stream
# wasn't closed/completed.
server.join()

if server.exception:
Expand All @@ -86,37 +127,94 @@ def _run_server_logic(self, interactions: List[Union[StreamGrpc, UnaryGrpc]]):
self._run_stream_server_logic(i)
elif isinstance(i, UnaryGrpc):
self._run_unary_server_logic(i)
elif isinstance(i, StreamUnaryGrpc):
self._run_stream_unary_server_logic(i)
elif isinstance(i, UnaryStreamGrpc):
self._run_unary_stream_server_logic(i)
else:
raise NotImplementedError(f"No server logic has been configured for {type(i)}")

def _run_stream_server_logic(self, interaction: StreamGrpc):
for processor in interaction.processors:
_, rpc = self.channel.take_stream_stream(self.grpc_service.methods_by_name[interaction.function])
rpc.send_initial_metadata(())
# Take the expected StreamStream message.
_, rpc = self.channel.take_stream_stream(self.grpc_service.methods_by_name[interaction.function])
rpc.send_initial_metadata(())

try:
try:
for processor in interaction.processors:
# Take a request from the stream.
request = rpc.take_request()
if interaction.force_timeout:
rpc.terminate(None, (), grpc.StatusCode.DEADLINE_EXCEEDED, '')
rpc.terminate((), grpc.StatusCode.DEADLINE_EXCEEDED, '')
return

# Send each yielded item to the response stream.
for response in processor(request):
rpc.send_response(response)

rpc.requests_closed()
finally:
rpc.terminate((), grpc.StatusCode.OK, '')
# Ensure the requests stream is closed.
rpc.requests_closed()
finally:
# Terminate the message regardless of exceptions or completion.
rpc.terminate((), grpc.StatusCode.OK, '')

def _run_unary_server_logic(self, interaction: UnaryGrpc):
    """
    Serve a single unary-unary RPC: take the request, then terminate the RPC with the processor's response.

    :param interaction: The expected unary interaction, including the processor that produces the response.
    """
    # Take the expected UnaryUnary message, which includes the request.
    _, request, rpc = self.channel.take_unary_unary(self.grpc_service.methods_by_name[interaction.function])
    rpc.send_initial_metadata(())

    # The timeout check sits before the try so the handlers below can't send a second terminate after this one.
    if interaction.force_timeout:
        rpc.terminate(None, (), grpc.StatusCode.DEADLINE_EXCEEDED, '')
        return

    try:
        #
        # NOTE: Although this is looping, the processor should only yield a single entry (it is Unary after all). Yielding multiple items will
        # cause multiple calls to `terminate` which will cause errors.
        #
        for response in interaction.processor(request):
            rpc.terminate(response, (), grpc.StatusCode.OK, '')
    except Exception as e:
        # If there were errors in the processor, send a blank response.
        # Re-raised so the thread running the server logic surfaces the failure to the test.
        rpc.terminate((), (), grpc.StatusCode.OK, '')
        raise e

def _run_stream_unary_server_logic(self, interaction: StreamUnaryGrpc):
    """
    Serve a single stream-unary RPC: run one validator against each request taken from the stream, then
    terminate the RPC with the configured unary response.

    :param interaction: The expected stream-unary interaction, with one validator per expected request and the response to send.
    """
    _, rpc = self.channel.take_stream_unary(self.grpc_service.methods_by_name[interaction.function])
    rpc.send_initial_metadata(())

    try:
        for validator in interaction.request_validators:
            request = rpc.take_request()
            if interaction.force_timeout:
                rpc.terminate(None, (), grpc.StatusCode.DEADLINE_EXCEEDED, '')
                return

            validator(request)

        rpc.requests_closed()

        # NOTE: We don't want to move this to a finally block, otherwise it will send a second terminate if it times out above.
        rpc.terminate(interaction.response, (), grpc.StatusCode.OK, '')
    except Exception as e:
        # A validator (or take_request) failed: still terminate with the configured response — presumably so the
        # client isn't left blocked — then re-raise so the server thread surfaces the failure to the test.
        rpc.terminate(interaction.response, (), grpc.StatusCode.OK, '')
        raise e

def _run_unary_stream_server_logic(self, interaction: UnaryStreamGrpc):
    """
    Serve a single unary-stream RPC: take the request, stream out each response yielded by the processor,
    then terminate the RPC.

    :param interaction: The expected unary-stream interaction, including the processor that produces the response stream.
    """
    # Take the expected UnaryStream message, which includes the request.
    _, request, rpc = self.channel.take_unary_stream(self.grpc_service.methods_by_name[interaction.function])
    rpc.send_initial_metadata(())

    # Check the timeout BEFORE entering the try block. Previously this check was inside the try, so the
    # `return` still ran the finally and sent a second terminate(OK) after the DEADLINE_EXCEEDED one —
    # the same double-terminate issue the stream-unary logic's NOTE warns about.
    if interaction.force_timeout:
        rpc.terminate((), grpc.StatusCode.DEADLINE_EXCEEDED, '')
        return

    try:
        # Send each yielded item to the response stream.
        for response in interaction.processor(request):
            rpc.send_response(response)

        # NOTE(review): the request side of this RPC is unary, so there is no request stream to close —
        # confirm `requests_closed` exists on a unary-stream rpc in grpc_testing; this looks copied from
        # the stream-stream logic.
        rpc.requests_closed()
    finally:
        # Terminate the message regardless of exceptions or completion.
        rpc.terminate((), grpc.StatusCode.OK, '')
Loading