diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..bf9bd99 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,10 @@ +[run] +branch = True +source = + cbor_rpc + +[report] +show_missing = True +skip_covered = True +omit = + tests/* diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..b2c12b2 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,53 @@ +name: Run Tests + +on: + push: + branches: + - main + - master + - dev + - develop + pull_request: + branches: + - main + - master + - dev + - develop + +jobs: + test: + runs-on: ubuntu-latest + permissions: + contents: read + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install -e . + + - name: Run tests with coverage + run: | + pytest --cov=cbor_rpc --cov-report=xml --junitxml=junit.xml -o junit_family=legacy + + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: coverage.xml + + - name: Upload test results to Codecov + if: ${{ !cancelled() }} + uses: codecov/test-results-action@v1 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: junit.xml diff --git a/.gitignore b/.gitignore index 72416f8..79cb3b1 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ venv/ __pycache__/ .pytest_cache/ cbor_rpc.egg-info/ +.coverage diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..1be6a53 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..8790aec --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +cbor-rpc +======== +[![codecov](https://codecov.io/github/mesudip/cbor-rpc-py/graph/badge.svg)](https://codecov.io/github/mesudip/cbor-rpc-py) \ No newline at end of file diff --git a/cbor_rpc/__init__.py b/cbor_rpc/__init__.py index 60c384f..2d29e51 100644 --- a/cbor_rpc/__init__.py +++ b/cbor_rpc/__init__.py @@ -2,45 +2,49 @@ CBOR-RPC: An async-compatible CBOR-based RPC system """ -from .emitter import AbstractEmitter -from .async_pipe import Pipe -from .transformer import Transformer -from .promise import DeferredPromise -from .client import RpcClient, RpcAuthorizedClient, RpcV1 -from .server import RpcServer, RpcV1Server -from .server_base import Server +from .event import AbstractEmitter +from .pipe import EventPipe, Pipe +from .timed_promise import TimedPromise from .tcp import TcpPipe, TcpServer -from .json_transformer import JsonTransformer -from .sync_pipe import SyncPipe +from .transformer import CborStreamTransformer, CborTransformer, JsonTransformer, Transformer +from .rpc import ( + RpcClient, + RpcAuthorizedClient, + RpcServer, + RpcV1, + RpcV1Server, + Server, + RpcCallContext, + RpcLogger, +) __all__ = [ - # Emitter - 'AbstractEmitter', - - # Pipe classes - 'Pipe', - 'SyncPipe', - 'Transformer', - # Promise - 'DeferredPromise', - - # Client classes - 'RpcClient', - 'RpcAuthorizedClient', - 'RpcV1', - - # Server classes - 'Server', - 'RpcServer', - 'RpcV1Server', - - # TCP classes - 'TcpPipe', - 'TcpServer', - + "TimedPromise", + # Emitter + "AbstractEmitter", + # Pipe abstract classes + "EventPipe", + "Pipe", + # Server abstract classes + "Server", + # Rpc abstract classes + "RpcClient", + "RpcAuthorizedClient", + "RpcServer", + # Rpc base implementation + "RpcV1", + "RpcV1Server", + # Rpc high level + "RpcCallContext", + "RpcLogger", # TCP classes + "TcpPipe", + "TcpServer", # Transformers - 'JsonTransformer', + "Transformer", + "JsonTransformer", + "CborTransformer", + "CborStreamTransformer", ] __version__ = "0.1.0" diff --git a/cbor_rpc/async_pipe.py b/cbor_rpc/async_pipe.py deleted file mode 100644 index b136337..0000000 --- a/cbor_rpc/async_pipe.py +++ /dev/null @@ -1,118 +0,0 @@ -from typing import Any, TypeVar, Generic, Callable, Tuple, Optional -from abc import ABC, abstractmethod -import asyncio -import inspect -from .emitter import AbstractEmitter -import queue -import threading -from typing import Union - -# Generic type variables -T1 = TypeVar('T1') -T2 = TypeVar('T2') - - -class Pipe(AbstractEmitter, Generic[T1, T2]): - """ - Async Pipe or simply Pipe are event based way for read/write. - You cannot directly read from a Pipe. You have to use a on("data") handler registration. - """ - @abstractmethod - async def write(self, chunk: T1) -> bool: - pass - - @abstractmethod - async def terminate(self, *args: Any) -> None: - pass - - @staticmethod - def attach(source: 'Pipe[Any, Any]', destination: 'Pipe[Any, Any]') -> None: - async def source_to_destination(chunk: Any): - await destination.write(chunk) - - async def destination_to_source(chunk: Any): - await source.write(chunk) - - async def close_handler(*args: Any): - await destination._emit("close", *args) - - source.on("data", source_to_destination) - destination.on("data", destination_to_source) - source.on("close", close_handler) - - @staticmethod - def create_pair() -> Tuple['Pipe[Any, Any]', 'Pipe[Any, Any]']: - """ - Create a pair of connected pipes for bidirectional communication. - - Returns: - A tuple of (pipe1, pipe2) where data written to pipe1 is emitted on pipe2 and vice versa. - """ - class ConnectedPipe(Pipe[Any, Any]): - def __init__(self): - super().__init__() - self.connected_pipe: Optional['ConnectedPipe'] = None - self._closed = False - - def connect_to(self, other: 'ConnectedPipe'): - self.connected_pipe = other - other.connected_pipe = self - - async def write(self, chunk: Any) -> bool: - if self._closed: - return False - - # Forward to connected pipe - if self.connected_pipe and not self.connected_pipe._closed: - await self.connected_pipe._emit("data", chunk) - - return True - - async def terminate(self, *args: Any) -> None: - if self._closed: - return - - self._closed = True - await self._emit("close", *args) - - # Notify connected pipe - if self.connected_pipe and not self.connected_pipe._closed: - await self.connected_pipe._emit("close", *args) - - async def read(self, timeout: Optional[float] = None) -> Optional[Any]: - """Read data from the pipe with a timeout. - - Args: - timeout: Maximum time to wait for data (None = no timeout) - - Returns: - Data from the pipe or None if timeout/closed - """ - if self._closed: - return None - - # Wait for data event or timeout - read_future = asyncio.Future() - - def handle_data(chunk: Any): - read_future.set_result(chunk) - self.off("data", handle_data) - - self.on("data", handle_data) - - try: - if timeout is not None and timeout > 0: - await asyncio.wait_for(read_future, timeout) - else: - # Wait indefinitely for data - await read_future - return read_future.result() - except asyncio.TimeoutError: - return None - - pipe1 = ConnectedPipe() - pipe2 = ConnectedPipe() - pipe1.connect_to(pipe2) - - return pipe1, pipe2 - diff --git a/cbor_rpc/client.py b/cbor_rpc/client.py deleted file mode 100644 index 9392980..0000000 --- a/cbor_rpc/client.py +++ /dev/null @@ -1,184 +0,0 @@ -from typing import Any, Dict, List, Optional, Callable -from abc import ABC, abstractmethod -import asyncio -import inspect -from .async_pipe import Pipe -from .promise import DeferredPromise - - -class RpcClient(ABC): - @abstractmethod - async def emit(self, topic: str, message: Any) -> None: - pass - - @abstractmethod - async def call_method(self, method: str, *args: Any) -> Any: - pass - - @abstractmethod - async def fire_method(self, method: str, *args: Any) -> None: - pass - - @abstractmethod - def set_timeout(self, milliseconds: int) -> None: - pass - - -class RpcAuthorizedClient(RpcClient): - @abstractmethod - def get_id(self) -> str: - pass - - -class RpcV1(RpcClient): - def __init__(self, pipe: Pipe[Any, Any]): - self.pipe = pipe - self._counter = 0 - self._promises: Dict[int, DeferredPromise] = {} - self._timeout = 30000 - self._waiters: Dict[str, DeferredPromise] = {} - - async def resolve_result(result: Any) -> Any: - """Recursively resolve coroutines or nested coroutines.""" - while asyncio.iscoroutine(result): - result = await result - return result - - async def on_data(data: List[Any]) -> None: - try: - if not isinstance(data, list) or len(data) != 5: - print(f"RpcV1: Invalid message format: {data}") - return - - version, direction, id_, method, params = data - if version != 1: - print(f"RpcV1: Unsupported version: {data}") - return - - if direction < 2: # Method call (0) or fire (1) - try: - # Call the method and get the result - result = self.handle_method_call(method, params) - - # Handle the response asynchronously - async def handle_response(): - try: - resolved_result = await resolve_result(result) - if direction == 0: # Only respond to method calls, not fire calls - await self.pipe.write([1, 2, id_, True, resolved_result]) - except Exception as e: - if direction == 0: - await self.pipe.write([1, 2, id_, False, str(e)]) - else: - print(f"Fired method error: {method}, params={params}, error={e}") - - # Create task to handle response - asyncio.create_task(handle_response()) - - except Exception as e: - if direction == 0: - asyncio.create_task(self.pipe.write([1, 2, id_, False, str(e)])) - else: - print(f"Fired method error: {method}, params={params}, error={e}") - - elif direction == 2: # Response - promise = self._promises.pop(id_, None) - if promise: - if method is True: # Success - await promise.resolve(params) - else: # Error - await promise.reject(params) - else: - print(f"Received rpc reply for expired request id: {id_}, success={method}, data={params}") - - elif direction == 3: # Event - await self._on_event(method, params) - else: - print(f"RpcV1: Invalid direction: {direction}") - - except Exception as e: - print(f"Error processing RPC message: {e}") - - self.pipe.on("data", on_data) - - async def call_method(self, method: str, *args: Any) -> Any: - counter = self._counter - self._counter += 1 - - def timeout_callback(): - self._promises.pop(counter, None) - - promise = DeferredPromise(self._timeout, timeout_callback) - self._promises[counter] = promise - await self.pipe.write([1, 0, counter, method, list(args)]) - return await promise.promise - - async def fire_method(self, method: str, *args: Any) -> None: - counter = self._counter - self._counter += 1 - await self.pipe.write([1, 1, counter, method, list(args)]) - - async def emit(self, topic: str, args: Any) -> None: - await self.pipe.write([1, 3, 0, topic, args]) - - def set_timeout(self, milliseconds: int) -> None: - self._timeout = milliseconds - - async def _on_event(self, method: str, message: Any) -> None: - waiter = self._waiters.pop(method, None) - if waiter: - await waiter.resolve(message) - else: - await self.on_event(method, message) - - @abstractmethod - def get_id(self) -> str: - pass - - async def wait_next_event(self, topic: str, timeout_ms: Optional[int] = None) -> Any: - if topic in self._waiters: - raise Exception("Already waiting for event") - - def timeout_callback(): - self._waiters.pop(topic, None) - - waiter = DeferredPromise( - timeout_ms or self._timeout, - timeout_callback, - f"Timeout Waiting for Event on: {topic}" - ) - self._waiters[topic] = waiter - return await waiter.promise - - @abstractmethod - def handle_method_call(self, method: str, args: List[Any]) -> Any: - pass - - @abstractmethod - async def on_event(self, topic: str, message: Any) -> None: - pass - - @staticmethod - def make_rpc_v1(pipe: Pipe[Any, Any], id_: str, method_handler: Callable, event_handler: Callable) -> 'RpcV1': - class ConcreteRpcV1(RpcV1): - def get_id(self) -> str: - return id_ - - def handle_method_call(self, method: str, args: List[Any]) -> Any: - return method_handler(method, args) - - async def on_event(self, topic: str, message: Any) -> None: - if inspect.iscoroutinefunction(event_handler): - await event_handler(topic, message) - else: - event_handler(topic, message) - return ConcreteRpcV1(pipe) - - @staticmethod - def read_only_client(pipe: Pipe[Any, Any]) -> 'RpcV1': - def method_handler(method: str, args: List[Any]) -> Any: - raise Exception("Client Only Implementation") - - async def event_handler(topic: str, message: Any) -> None: - print(f"Rpc Event dropped {topic} {message}") - return RpcV1.make_rpc_v1(pipe, '', method_handler, event_handler) diff --git a/cbor_rpc/event/__init__.py b/cbor_rpc/event/__init__.py new file mode 100644 index 0000000..e288aa2 --- /dev/null +++ b/cbor_rpc/event/__init__.py @@ -0,0 +1 @@ +from .emitter import AbstractEmitter diff --git a/cbor_rpc/emitter.py b/cbor_rpc/event/emitter.py similarity index 54% rename from cbor_rpc/emitter.py rename to cbor_rpc/event/emitter.py index 98a5506..1c2a389 100644 --- a/cbor_rpc/emitter.py +++ b/cbor_rpc/event/emitter.py @@ -2,6 +2,9 @@ from abc import ABC, abstractmethod import asyncio import inspect +import traceback +import warnings + class AbstractEmitter(ABC): def __init__(self): @@ -17,39 +20,48 @@ def unsubscribe(self, event: str, handler: Callable) -> None: def replace_on_handler(self, event_type: str, handler: Callable) -> None: self._subscribers[event_type] = [handler] - async def _emit(self, event_type: str, *args: Any) -> None: + def _run_background_task(self, coro: Callable[..., Any], *args: Any) -> None: + async def runner(): + try: + await coro(*args) + except Exception as e: + traceback.print_exc() + warnings.warn(f"Background task error in handler: {e}", RuntimeWarning) + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + warnings.warn("Background task skipped: no running event loop", RuntimeWarning) + return + + loop.create_task(runner()) + + def _emit(self, event_type: str, *args: Any) -> None: for sub in self._subscribers.get(event_type, []): try: if inspect.iscoroutinefunction(sub): - await sub(*args) + self._run_background_task(sub, *args) else: sub(*args) except Exception as e: - # We should log the exception but not propagate it - print(f"Error in event handler: {e}") + traceback.print_exc() + warnings.warn(f"Synchronous error in handler: {e}", RuntimeWarning) async def _notify(self, event_type: str, *args: Any) -> None: - tasks = [] - + """ + + """ for pipeline in self._pipelines.get(event_type, []): - if inspect.iscoroutinefunction(pipeline): - task = asyncio.create_task(pipeline(*args)) - tasks.append(task) - else: - try: + try: + if inspect.iscoroutinefunction(pipeline): + await pipeline(*args) + else: pipeline(*args) - except Exception as e: - await self._emit("error", e) - raise e - - if tasks: - results = await asyncio.gather(*tasks, return_exceptions=True) - for result in results: - if isinstance(result, Exception): - await self._emit("error", result) - raise result - - await self._emit(event_type, *args) + except Exception as e: + self._emit("error", e) + raise e + + self._emit(event_type, *args) def on(self, event: str, handler: Callable) -> None: self._subscribers.setdefault(event, []).append(handler) diff --git a/cbor_rpc/json_transformer.py b/cbor_rpc/json_transformer.py deleted file mode 100644 index f01ef42..0000000 --- a/cbor_rpc/json_transformer.py +++ /dev/null @@ -1,82 +0,0 @@ -import json -from typing import Any, Union -from .emitter import AbstractEmitter -from .transformer import Transformer -from .async_pipe import Pipe - -class JsonTransformer(AbstractEmitter, Transformer[Any, Any]): - """ - A transformer that encodes Python objects to JSON strings and decodes JSON strings back to Python objects. - """ - - def __init__(self, underlying_pipe: Pipe[Any, Any], encoding: str = 'utf-8'): - """ - Initialize the JSON transformer. - - Args: - underlying_pipe: The underlying pipe to transform - encoding: Text encoding to use (default: 'utf-8') - """ - AbstractEmitter.__init__(self) - Transformer.__init__(self, underlying_pipe) - self.encoding = encoding - - async def encode(self, data: Any) -> bytes: - """ - Encode Python object to JSON bytes. - - Args: - data: Python object to encode - - Returns: - JSON-encoded bytes - - Raises: - TypeError: If data is not JSON serializable - UnicodeEncodeError: If encoding fails - """ - json_str = json.dumps(data, ensure_ascii=False, separators=(',', ':')) - return json_str.encode(self.encoding) - - async def decode(self, data: Union[bytes, str, None]) -> Any: - """ - Decode JSON bytes/string to Python object. - - Args: - data: JSON bytes or string to decode - - Returns: - Decoded Python object - - Raises: - json.JSONDecodeError: If data is not valid JSON - UnicodeDecodeError: If bytes cannot be decoded - TypeError: If data is None or of invalid type - """ - if data is None: - raise TypeError("Expected bytes or str, got None") - - if isinstance(data, bytes): - json_str = data.decode(self.encoding) - elif isinstance(data, str): - json_str = data - else: - raise TypeError(f"Expected bytes or str, got {type(data)}") - - return json.loads(json_str) - - @classmethod - def create_pair(cls, encoding: str = 'utf-8') -> tuple['JsonTransformer', 'JsonTransformer']: - """ - Create a pair of connected JSON transformers. - - Args: - encoding: Text encoding to use - - Returns: - A tuple of (transformer1, transformer2) - """ - pipe1, pipe2 = Pipe.create_pair() - transformer1 = cls(pipe1, encoding) - transformer2 = cls(pipe2, encoding) - return transformer1, transformer2 diff --git a/cbor_rpc/pipe/__init__.py b/cbor_rpc/pipe/__init__.py new file mode 100644 index 0000000..a82c6d3 --- /dev/null +++ b/cbor_rpc/pipe/__init__.py @@ -0,0 +1,3 @@ +from .event_pipe import EventPipe +from .pipe import Pipe +from .aio_pipe import AioPipe diff --git a/cbor_rpc/pipe/aio_pipe.py b/cbor_rpc/pipe/aio_pipe.py new file mode 100644 index 0000000..45c8989 --- /dev/null +++ b/cbor_rpc/pipe/aio_pipe.py @@ -0,0 +1,185 @@ +from typing import Any, Optional, TypeVar, Generic, Union +import asyncio +from abc import ABC +from .event_pipe import EventPipe # Assuming EventPipe is in a separate module + +# Constrain T1 and T2 to bytes or bytearray for type safety with asyncio streams +T1 = TypeVar("T1", bound=Union[bytes, bytearray]) +T2 = TypeVar("T2", bound=Union[bytes, bytearray]) + + +class AioPipe(EventPipe[T1, T2], ABC): + """ + Abstract base class for asynchronous pipes that wrap asyncio.StreamReader and asyncio.StreamWriter. + Extends EventPipe to provide event-based communication for async I/O. + + Attributes: + DEFAULT_READ_CHUNK_SIZE (int): Default size of chunks to read from the stream (8192 bytes). + """ + + DEFAULT_READ_CHUNK_SIZE = 8192 + + def __init__( + self, + reader: Optional[asyncio.StreamReader] = None, + writer: Optional[asyncio.StreamWriter] = None, + chunk_size: int = DEFAULT_READ_CHUNK_SIZE, + ): + """ + Initialize the AioPipe with optional reader, writer, and chunk size. + + Args: + reader: Optional asyncio.StreamReader for reading data. + writer: Optional asyncio.StreamWriter for writing data. + chunk_size: Size of chunks to read from the stream (default: 8192 bytes). + + Raises: + ValueError: If only one of reader or writer is provided. + """ + super().__init__() + if (reader is None) != (writer is None): # Both must be None or both non-None + raise ValueError("Both reader and writer must be provided or neither") + self._reader = reader + self._writer = writer + self._chunk_size = chunk_size + self._connected = False + self._closed = False + self._read_task: Optional[asyncio.Task] = None + + async def _setup_connection(self) -> None: + """ + Set up the connection and start reading data. + + Raises: + RuntimeError: If reader or writer is not initialized. + """ + if not self._reader or not self._writer: + raise RuntimeError("Reader or writer not initialized") + + self._connected = True + self._closed = False + self._read_task = asyncio.create_task(self._read_loop()) + + try: + await self._notify("connect") + except Exception as e: + self._emit("error", e) # Synchronous _emit + await self._close_connection() + raise + + async def _read_loop(self) -> None: + """ + Continuously read data from the connection and emit data events. + + Stops on EOF, cancellation, or error. Closes the connection in case of errors. + """ + try: + while self._connected and not self._closed and self._reader: + try: + data = await self._reader.read(self._chunk_size) + if not data: # EOF reached + break + try: + await self._notify("data", data) + except Exception as e: + self._emit("error", e) # Synchronous _emit + break + except asyncio.CancelledError: + break + except BaseException as e: + if isinstance(e, GeneratorExit): + break + self._emit("error", e) # Synchronous _emit + break + except BaseException as e: + if not isinstance(e, (asyncio.CancelledError, GeneratorExit)): + self._emit("error", e) # Synchronous _emit + finally: + if not self._closed: + await self._close_connection() + + async def _close_connection(self, *args: Any) -> None: + """ + Close the connection and clean up resources. + + Args: + *args: Optional arguments to pass to the 'close' event. + """ + if self._closed: + return + + self._closed = True + self._connected = False + + # Cancel the read task + if self._read_task and not self._read_task.done(): + task, self._read_task = self._read_task, None + try: + task.cancel() + await task + except asyncio.CancelledError: + pass + except Exception as e: + self._emit("error", e) # Synchronous _emit + + # Close the writer + if self._writer: + writer, self._writer = self._writer, None + try: + writer.close() + await writer.wait_closed() + except Exception as e: + self._emit("error", e) # Synchronous _emit + + # Emit close event + try: + self._emit("close", *args) # Synchronous _emit + except Exception as e: + print(f"AioPipe: Failed to emit close event: {e}") + + def is_connected(self) -> bool: + """ + Check if the connection is active. + + Returns: + bool: True if connected and not closed, False otherwise. + """ + return self._connected and not self._closed + + async def write(self, chunk: T1) -> bool: + """ + Write data to the connection. + + Args: + chunk: The data to write (bytes or bytearray). + + Returns: + bool: True if the write was successful, False otherwise. + + Raises: + ConnectionError: If not connected or writer is unavailable. + TypeError: If chunk is not bytes or bytearray. + """ + if not self._connected or self._closed: + raise ConnectionError("Not connected") + if not self._writer: + raise ConnectionError("Writer not available") + if not isinstance(chunk, (bytes, bytearray)): + raise TypeError(f"Expected bytes or bytearray, got {type(chunk).__name__}") + + try: + self._writer.write(chunk) + await self._writer.drain() + return True + except Exception as e: + self._emit("error", e) # Synchronoous _emit + return False + + async def terminate(self, *args: Any) -> None: + """ + Terminate the connection. + + Args: + *args: Optional arguments (e.g., code, reason) to pass to the close event. + """ + await self._close_connection(*args) diff --git a/cbor_rpc/pipe/event_pipe.py b/cbor_rpc/pipe/event_pipe.py new file mode 100644 index 0000000..559e1ef --- /dev/null +++ b/cbor_rpc/pipe/event_pipe.py @@ -0,0 +1,74 @@ +from typing import Any, TypeVar, Generic, Callable, Tuple, Optional +from abc import ABC, abstractmethod +import asyncio +import inspect +from ..event.emitter import AbstractEmitter + +# Generic type variables +T1 = TypeVar("T1") +T2 = TypeVar("T2") + + +class EventPipe(AbstractEmitter, Generic[T1, T2]): + """ + Event Pipe or are event based way for read/write. + You cannot directly read from a Pipe. You have to use a pipeline("data") to register one or more functions to read data. + """ + + @abstractmethod + async def write(self, chunk: T1) -> bool: + pass + + @abstractmethod + async def terminate(self, *args: Any) -> None: + pass + + @staticmethod + def create_inmemory_pair() -> Tuple["EventPipe[Any, Any]", "EventPipe[Any, Any]"]: + """ + Create a pair of connected pipes for bidirectional communication. + + Returns: + A tuple of (pipe1, pipe2) where data written to pipe1 is emitted on pipe2 and vice versa. + """ + + class ConnectedPipe(EventPipe[Any, Any]): + def __init__(self): + super().__init__() + self.connected_pipe: Optional["ConnectedPipe"] = None + self._closed = False + + def connect_to(self, other: "ConnectedPipe"): + self.connected_pipe = other + other.connected_pipe = self + + async def write(self, chunk: Any) -> bool: + if self._closed or not self.connected_pipe or self.connected_pipe._closed: + return False + async def _deliver() -> None: + try: + await self.connected_pipe._notify("data", chunk) + except Exception: + # Receiver-side errors should not bubble to the writer. + # Mirror AioPipe behavior by closing both ends on pipeline errors. + await self.connected_pipe.terminate() + return + + loop = asyncio.get_running_loop() + loop.create_task(_deliver()) + return True + + async def terminate(self, *args: Any) -> None: + if self._closed: + return + self._closed = True + self._emit("close", *args) + if self.connected_pipe and not self.connected_pipe._closed: + self.connected_pipe._closed = True + self.connected_pipe._emit("close", *args) + + pipe1 = ConnectedPipe() + pipe2 = ConnectedPipe() + pipe1.connect_to(pipe2) + + return pipe1, pipe2 diff --git a/cbor_rpc/pipe/pipe.py b/cbor_rpc/pipe/pipe.py new file mode 100644 index 0000000..bcb77df --- /dev/null +++ b/cbor_rpc/pipe/pipe.py @@ -0,0 +1,120 @@ +from abc import ABC, abstractmethod +from typing import Any, TypeVar, Generic, Optional, Tuple +import asyncio + +from cbor_rpc.pipe.event_pipe import EventPipe +from ..event.emitter import AbstractEmitter + +T1 = TypeVar("T1") +T2 = TypeVar("T2") + + +class Pipe(AbstractEmitter, Generic[T1, T2], ABC): + """ + Abstract Pipe defining async event-based read/write/terminate interface. + We write T1 + We read T2. + Normally pipes will have T1=T2. if T1!=T2, then you need to use a transformer to convert between them. + """ + + @abstractmethod + async def write(self, chunk: T1) -> bool: + pass + + @abstractmethod + async def read(self, timeout: Optional[float] = None) -> Optional[T2]: + pass + + @abstractmethod + async def terminate(self, *args: Any) -> None: + pass + + @staticmethod + def create_pair() -> Tuple["Pipe[Any, Any]", "Pipe[Any, Any]"]: + class InMemoryPipe(Pipe[Any, Any]): + def __init__(self): + super().__init__() + self._closed = False + self._buffer: asyncio.Queue[Optional[Any]] = asyncio.Queue() + self.connected_pipe: Optional["InMemoryPipe"] = None + + async def write(self, chunk: Any) -> bool: + if self._closed or not self.connected_pipe or self.connected_pipe._closed: + return False + await self.connected_pipe._buffer.put(chunk) + return True + + async def read(self, timeout: Optional[float] = None) -> Optional[Any]: + if self._closed: + return None + try: + if timeout is not None and timeout > 0: + return await asyncio.wait_for(self._buffer.get(), timeout) + else: + return await self._buffer.get() + except asyncio.TimeoutError: + return None + except asyncio.CancelledError: + # If read is cancelled, put back the None if it was a termination signal + if not self._buffer.empty(): + item = self._buffer.get_nowait() + if item is None: + await self._buffer.put(None) + raise + + async def terminate(self, *args: Any) -> None: + if self._closed: + return + self._closed = True + # Signal termination to any pending reads + await self._buffer.put(None) + await self._notify("close", *args) # Notify external listeners + + if self.connected_pipe and not self.connected_pipe._closed: + await self.connected_pipe._buffer.put(None) # Signal termination to connected pipe + await self.connected_pipe.terminate(*args) # Recursively terminate connected pipe + + a = InMemoryPipe() + b = InMemoryPipe() + a.connected_pipe = b + b.connected_pipe = a + return a, b + + def make_event_based(self) -> "EventPipe[T1, T2]": + parent = self + + class PipeToEvent(EventPipe[T1, T2]): + def __init__(self): + super().__init__() + self._closed = False + + # Spawn a task to pump data from parent.read into events + self._pump_task = asyncio.create_task(self._pump()) + + async def _pump(self): + while not self._closed: + chunk = await parent.read() + if chunk is None: + print("PipeToEvent: Received termination signal, terminating event pipe.") + await self.terminate() + break + await self._notify("data", chunk) + + async def write(self, chunk: T1) -> bool: + return await parent.write(chunk) + + async def terminate(self, *args: Any) -> None: + if self._closed: + return + self._closed = True + print("PipeToEvent: Terminating event pipe.") + await parent.terminate(*args) + self._emit("close", *args) + if self._pump_task: + self._pump_task.cancel() + try: + await self._pump_task + except asyncio.CancelledError: + pass + + return PipeToEvent() diff --git a/cbor_rpc/rpc/__init__.py b/cbor_rpc/rpc/__init__.py new file mode 100644 index 0000000..e734bb8 --- /dev/null +++ b/cbor_rpc/rpc/__init__.py @@ -0,0 +1,8 @@ +from .server_base import Server + + +from .rpc_base import RpcClient, RpcServer, RpcAuthorizedClient +from .rpc_v1 import RpcV1 +from .context import RpcCallContext +from .logging import RpcLogger +from .rpc_server import RpcV1Server diff --git a/cbor_rpc/rpc/context.py b/cbor_rpc/rpc/context.py new file mode 100644 index 0000000..7cf0ee5 --- /dev/null +++ b/cbor_rpc/rpc/context.py @@ -0,0 +1,6 @@ +from .logging import RpcLogger + + +class RpcCallContext: + def __init__(self, logger: RpcLogger): + self.logger = logger diff --git a/cbor_rpc/rpc/logging.py b/cbor_rpc/rpc/logging.py new file mode 100644 index 0000000..87bb1c0 --- /dev/null +++ b/cbor_rpc/rpc/logging.py @@ -0,0 +1,28 @@ +from typing import Any, Callable + + +class RpcLogger: + def __init__( + self, + send_log: Callable[[int, int, Any, Any], None], + ref_proto: int, + ref_id: Callable[[], Any], + ): + self._send_log = send_log + self._ref_proto = ref_proto + self._ref_id = ref_id + + def log(self, content: Any) -> None: + self._send_log(3, self._ref_proto, self._ref_id(), content) + + def warn(self, content: Any) -> None: + self._send_log(2, self._ref_proto, self._ref_id(), content) + + def crit(self, content: Any) -> None: + self._send_log(1, self._ref_proto, self._ref_id(), content) + + def verbose(self, content: Any) -> None: + self._send_log(4, self._ref_proto, self._ref_id(), content) + + def debug(self, content: Any) -> None: + self._send_log(5, self._ref_proto, self._ref_id(), content) diff --git a/cbor_rpc/rpc/rpc_base.py b/cbor_rpc/rpc/rpc_base.py new file mode 100644 index 0000000..fb889f2 --- /dev/null +++ b/cbor_rpc/rpc/rpc_base.py @@ -0,0 +1,53 @@ +from typing import Any, Dict, List, Optional, Callable +from abc import ABC, abstractmethod +from ..pipe.event_pipe import EventPipe + + +class RpcClient(ABC): + @abstractmethod + async def call_method(self, method: str, *args: Any) -> Any: + pass + + @abstractmethod + async def fire_method(self, method: str, *args: Any) -> None: + pass + + @abstractmethod + def set_timeout(self, milliseconds: int) -> None: + pass + + +class RpcAuthorizedClient(RpcClient): + @abstractmethod + def get_id(self) -> str: + pass + + +class RpcServer(ABC): + @abstractmethod + async def call_method(self, connection_id: str, method: str, *args: Any) -> Any: + pass + + @abstractmethod + async def fire_method(self, connection_id: str, method: str, *args: Any) -> None: + pass + + @abstractmethod + async def disconnect(self, connection_id: str, reason: Optional[str] = None) -> None: + pass + + @abstractmethod + def get_client(self, connection_id: str) -> Optional[RpcAuthorizedClient]: + pass + + @abstractmethod + def with_client(self, connection_id: str, action: Callable) -> bool: + pass + + @abstractmethod + def set_timeout(self, milliseconds: int) -> None: + pass + + @abstractmethod + def is_active(self, connection_id: str) -> bool: + pass diff --git a/cbor_rpc/rpc/rpc_server.py b/cbor_rpc/rpc/rpc_server.py new file mode 100644 index 0000000..6d26c78 --- /dev/null +++ b/cbor_rpc/rpc/rpc_server.py @@ -0,0 +1,78 @@ +from typing import Any, Dict, List, Optional, Callable +from abc import ABC, abstractmethod + +from cbor_rpc.rpc.server_base import Server +from .rpc_base import RpcAuthorizedClient, RpcServer +from .rpc_v1 import RpcV1 +from .context import RpcCallContext +from cbor_rpc.pipe.event_pipe import EventPipe + + +class RpcV1Server(RpcServer): + def __init__(self): + self.active_connections: Dict[str, EventPipe[Any, Any]] = {} + self.rpc_clients: Dict[str, RpcV1] = {} + self.timeout = 30000 + + async def add_connection(self, conn_id: str, rpc_client: EventPipe[Any, Any]) -> None: + def method_handler(context: RpcCallContext, method: str, args: List[Any]) -> Any: + return self.handle_method_call(conn_id, context, method, args) + + client_rpc = RpcV1.make_rpc_v1(rpc_client, conn_id, method_handler) + client_rpc.set_timeout(self.timeout) + + self.active_connections[conn_id] = rpc_client + self.rpc_clients[conn_id] = client_rpc + + # Set up cleanup on close + async def cleanup(*args): + self.active_connections.pop(conn_id, None) + self.rpc_clients.pop(conn_id, None) + + rpc_client.on("close", cleanup) + + async def disconnect(self, connection_id: str, reason: Optional[str] = None) -> None: + base = self.active_connections.pop(connection_id, None) + self.rpc_clients.pop(connection_id, None) + if base: + print("RpcV1Server: Disconnecting client:", connection_id) + await base.terminate(1000, reason or "Server terminated connection") + + def set_timeout(self, milliseconds: int) -> None: + self.timeout = milliseconds + + def is_active(self, connection_id: str) -> bool: + return connection_id in self.rpc_clients + + def get_client(self, connection_id: str) -> Optional[RpcAuthorizedClient]: + return self.rpc_clients.get(connection_id) + + async def call_method(self, connection_id: str, method: str, *args: Any) -> Any: + client = self.rpc_clients.get(connection_id) + if client: + return await client.call_method(method, *args) + raise Exception("Client is not active") + + async def fire_method(self, connection_id: str, method: str, *args: Any) -> None: + client = self.rpc_clients.get(connection_id) + if client: + await client.fire_method(method, *args) + else: + raise Exception("Client is not active") + + @abstractmethod + async def handle_method_call( + self, + connection_id: str, + context: RpcCallContext, + method: str, + args: List[Any], + ) -> Any: + pass + + def with_client(self, connection_id: str, action: Callable) -> bool: + client = self.rpc_clients.get(connection_id) + if client: + action(client) + return True + return False diff --git a/cbor_rpc/rpc/rpc_v1.py b/cbor_rpc/rpc/rpc_v1.py new file mode 100644 index 0000000..596b85c --- /dev/null +++ b/cbor_rpc/rpc/rpc_v1.py @@ -0,0 +1,274 @@ +import sys +import logging +from typing import Any, Dict, List, Optional, Callable +from abc import abstractmethod +import asyncio + +from .rpc_base import RpcClient +from .context import RpcCallContext +from .logging import RpcLogger +from cbor_rpc.pipe.event_pipe import EventPipe +from cbor_rpc.timed_promise import TimedPromise + +logger = logging.getLogger(__name__) + + +class RpcCore(RpcClient): + protocol_id = 1 + def __init__(self, pipe: EventPipe[Any, Any]): + self.pipe = pipe + self._counter = 0 + self._promises: Dict[int, TimedPromise] = {} + self._timeout = 30000 + self._peer_log_level = 0 + self._desired_log_level = 0 + self.logger = RpcLogger(self._send_log, 1, self._get_default_ref_id) + + async def resolve_result(result: Any) -> Any: + """Recursively resolve coroutines or nested coroutines.""" + while asyncio.iscoroutine(result): + result = await result + return result + + self._resolve_result = resolve_result + + async def on_data(data: List[Any]) -> None: + try: + if not isinstance(data, list) or len(data) < 2: + logger.warning(f"RpcCore: Invalid message format: {data}") + return + + protocol_id = data[0] + if protocol_id == 1: + await self.handle_proto_1(data) + elif protocol_id == 2: + await self.handle_proto_2(data) + elif protocol_id == 3: + await self.handle_proto_3(data) + else: + logger.warning(f"RpcCore: Unsupported protocol: {data}") + + except Exception as e: + logger.error(f"Error processing RPC message: {e}") + + self.pipe.on("data", on_data) + + async def handle_proto_1(self, data: List[Any]) -> None: + """Handle Protocol 1 (RPC) messages.""" + if len(data) < 3: + logger.warning(f"RpcCore [Proto 1]: Invalid format: {data}") + return + + sub_proto_id = data[1] + + # Responses: [1, 0, id, result] (Success) or [1, 1, id, error] (Error) + if sub_proto_id <2: + if len(data) < 4: + logger.warning(f"RpcCore [Proto 1]: Invalid response format: {data}") + return + + id_ = data[2] + payload = data[3] + + promise = self._promises.pop(id_, None) + if promise: + if sub_proto_id == 0: # Success + await promise.resolve(payload) + else: # Error + await promise.reject(payload) + else: + logger.warning(f"Received rpc reply for expired request id: {id_}, success={sub_proto_id==0}, data={payload}") + + # Method Call (2) or Fire (3): [1, 2/3, id, method, params] + elif sub_proto_id == 2 or sub_proto_id == 3: + if len(data) < 5: + logger.warning(f"RpcCore [Proto 1]: Invalid call format: {data}") + return + + id_ = data[2] + method = data[3] + params = data[4] + + try: + # Call the method + context = RpcCallContext(self.logger) + result = self.handle_method_call(context, method, params) + + if sub_proto_id == 2: # Expect Response + async def handle_response() -> None: + try: + resolved_result = await self._resolve_result(result) + # Send Success: [1, 0, id, result] + await self.pipe.write([1, 0, id_, resolved_result]) + except Exception as e: + # Send Error: [1, 1, id, error] + await self.pipe.write([1, 1, id_, str(e)]) + + asyncio.create_task(handle_response()) + else: + # Fire (3) + async def handle_fire() -> None: + try: + await self._resolve_result(result) + except Exception as e: + logger.error(f"Fired method error: {method}, params={params}, error={e}") + asyncio.create_task(handle_fire()) + + except Exception as e: + if sub_proto_id == 2: + asyncio.create_task(self.pipe.write([1, 1, id_, str(e)])) + else: + logger.error(f"Fired method error: {method}, params={params}, error={e}") + else: + logger.warning(f"RpcCore [Proto 1]: Unknown sub-protocol: {sub_proto_id}") + + + async def handle_proto_2(self, data: List[Any]) -> None: + """Handle Protocol 2 (Logging) messages.""" + # Format: [2, log_level, ref_proto, ref_id, content] + if len(data) >= 3 and data[1] == 0: + self._peer_log_level = data[2] + return + + if len(data) < 5: + logger.warning(f"RpcCore [Proto 2]: Invalid format: {data}") + return + + log_level = data[1] + + ref_proto = data[2] + ref_id = data[3] + content = data[4] + + level_map = {1: "CRITICAL", 2: "WARN", 3: "INFO", 4: "VERBOSE", 5: "DEBUG"} + level_str = level_map.get(log_level, f"LEVEL-{log_level}") + + logger.info(f"[RemoteLog:{level_str}] p{ref_proto}:{ref_id} {content}") + + async def handle_proto_3(self, data: List[Any]) -> None: + logger.warning(f"RpcCore [Proto 3]: Unsupported event message: {data}") + + + async def call_method(self, method: str, *args: Any) -> Any: + counter = self._counter + self._counter += 1 + + def timeout_callback(): + self._promises.pop(counter, None) + + promise = TimedPromise(self._timeout, timeout_callback) + self._promises[counter] = promise + # New: [1, 2, counter, method, args] + await self.pipe.write([1, 2, counter, method, list(args)]) + return await promise.promise + + async def fire_method(self, method: str, *args: Any) -> None: + counter = self._counter + self._counter += 1 + # New: [1, 3, counter, method, args] + await self.pipe.write([1, 3, counter, method, list(args)]) + + def set_timeout(self, milliseconds: int) -> None: + self._timeout = milliseconds + + async def set_log_level(self, level: int) -> None: + """Send a request to set the remote log level.""" + # Protocol 2, Sub-protocol 0: [2, 0, level] + self._desired_log_level = level + await self.pipe.write([2, 0, level]) + + def _get_default_ref_id(self) -> int: + return 0 + + def _send_log(self, level: int, ref_proto: int, ref_id: Any, content: Any) -> None: + if self._peer_log_level <= 0: + return + if level > self._peer_log_level: + return + asyncio.create_task(self.pipe.write([2, level, ref_proto, ref_id, content])) + + @abstractmethod + def get_id(self) -> str: + pass + + @abstractmethod + def handle_method_call(self, context: RpcCallContext, method: str, args: List[Any]) -> Any: + pass + + @classmethod + def make_rpc_v1( + cls, + pipe: EventPipe[Any, Any], + id_: str, + method_handler: Callable, + ) -> "RpcCore": + class ConcreteRpcV1(cls): + def get_id(self) -> str: + return id_ + + def handle_method_call(self, context: RpcCallContext, method: str, args: List[Any]) -> Any: + return method_handler(context, method, args) + + async def on_event(self, context: RpcCallContext, topic: str, message: Any) -> None: + return None + + return ConcreteRpcV1(pipe) + + @classmethod + def read_only_client(cls, pipe: EventPipe[Any, Any]) -> "RpcCore": + def method_handler(context: RpcCallContext, method: str, args: List[Any]) -> Any: + raise Exception("Client Only Implementation") + + return cls.make_rpc_v1(pipe, "", method_handler) + + +class RpcV1(RpcCore): + def __init__(self, pipe: EventPipe[Any, Any]): + super().__init__(pipe) + self._waiters: Dict[str, TimedPromise] = {} + self._last_event_topic: Optional[str] = None + self.event_logger = RpcLogger(self._send_log, 3, self._get_last_event_topic) + + async def handle_proto_3(self, data: List[Any]) -> None: + if len(data) < 4: + logger.warning(f"RpcV1 [Proto 3]: Invalid event format: {data}") + return + sub_proto_id = data[1] + if sub_proto_id != 0: + logger.warning(f"RpcV1 [Proto 3]: Unknown sub-protocol: {sub_proto_id}") + return + await self._on_event(data[2], data[3]) + + async def emit(self, topic: str, args: Any) -> None: + await self.pipe.write([3, 0, topic, args]) + + async def _on_event(self, method: str, message: Any) -> None: + self._last_event_topic = method + waiter = self._waiters.pop(method, None) + if waiter: + await waiter.resolve(message) + else: + context = RpcCallContext(self.event_logger) + await self.on_event(context, method, message) + + async def wait_next_event(self, topic: str, timeout_ms: Optional[int] = None) -> Any: + if topic in self._waiters: + raise Exception("Already waiting for event") + + def timeout_callback(): + self._waiters.pop(topic, None) + + waiter = TimedPromise( + timeout_ms or self._timeout, + timeout_callback, + f"Timeout Waiting for Event on: {topic}", + ) + self._waiters[topic] = waiter + return await waiter.promise + + @abstractmethod + async def on_event(self, context: RpcCallContext, topic: str, message: Any) -> None: + pass + + def _get_last_event_topic(self) -> Any: + return self._last_event_topic diff --git a/cbor_rpc/server_base.py b/cbor_rpc/rpc/server_base.py similarity index 86% rename from cbor_rpc/server_base.py rename to cbor_rpc/rpc/server_base.py index cc208cb..9f214d2 100644 --- a/cbor_rpc/server_base.py +++ b/cbor_rpc/rpc/server_base.py @@ -1,20 +1,20 @@ from typing import Any, Callable, Optional, Set, TypeVar, Generic from abc import ABC, abstractmethod import asyncio -from .emitter import AbstractEmitter -from .async_pipe import Pipe +from ..event.emitter import AbstractEmitter +from ..pipe import EventPipe # Generic type variable for pipe types -P = TypeVar('P', bound=Pipe) +P = TypeVar("P", bound=EventPipe) class Server(AbstractEmitter, Generic[P]): """ Abstract server class that manages connections and emits connection events. - + Type parameter P specifies the type of Pipe that this server handles. """ - + def __init__(self): super().__init__() self._connections: Set[P] = set() @@ -24,7 +24,7 @@ def __init__(self): async def start(self, *args, **kwargs) -> Any: """ Start the server. - + Returns: Server-specific information (e.g., address, port) """ @@ -35,22 +35,30 @@ async def stop(self) -> None: """Stop the server and clean up resources.""" pass + @abstractmethod + async def accept(self, pipe: P) -> bool: + pass + async def _add_connection(self, pipe: P) -> None: """ Add a new connection and emit a connection event. - + Args: pipe: The pipe representing the connection """ + if not await self.accept(pipe): + await pipe.terminate() + return self._connections.add(pipe) - + # Set up cleanup when connection closes async def cleanup(*args): self._connections.discard(pipe) + pipe.on("close", cleanup) - + # Emit connection event - await self._emit("connection", pipe) + await self._notify("connection", pipe) def get_connections(self) -> Set[P]: """Get all active connections.""" @@ -69,7 +77,7 @@ async def close_all_connections(self) -> None: def on_connection(self, handler: Callable[[P], None]) -> None: """ Register a handler for new connections. - + Args: handler: Function that takes a Pipe of type P as argument """ diff --git a/cbor_rpc/server.py b/cbor_rpc/server.py deleted file mode 100644 index 22749a3..0000000 --- a/cbor_rpc/server.py +++ /dev/null @@ -1,135 +0,0 @@ -from typing import Any, Dict, List, Optional, Callable -from abc import ABC, abstractmethod -import asyncio -from .client import RpcClient, RpcAuthorizedClient, RpcV1 -from .async_pipe import Pipe - - -class RpcServer(ABC): - @abstractmethod - async def emit(self, connection_id: str, topic: str, message: Any) -> None: - pass - - @abstractmethod - async def broadcast(self, topic: str, message: Any) -> None: - pass - - @abstractmethod - async def call_method(self, connection_id: str, method: str, *args: Any) -> Any: - pass - - @abstractmethod - async def fire_method(self, connection_id: str, method: str, *args: Any) -> None: - pass - - @abstractmethod - async def disconnect(self, connection_id: str, reason: Optional[str] = None) -> None: - pass - - @abstractmethod - def get_client(self, connection_id: str) -> Optional[RpcAuthorizedClient]: - pass - - @abstractmethod - def with_client(self, connection_id: str, action: Callable) -> bool: - pass - - @abstractmethod - def set_timeout(self, milliseconds: int) -> None: - pass - - @abstractmethod - def is_active(self, connection_id: str) -> bool: - pass - - -class RpcV1Server(RpcServer): - def __init__(self): - self.active_connections: Dict[str, RpcV1] = {} - self.timeout = 30000 - - async def add_connection(self, conn_id: str, rpc_client: Pipe[Any, Any]) -> None: - def method_handler(method: str, args: List[Any]) -> Any: - return self.handle_method_call(conn_id, method, args) - - async def event_handler(topic: str, data: Any) -> None: - await self._handle_event(conn_id, topic, data) - - client_rpc = RpcV1.make_rpc_v1( - rpc_client, - conn_id, - method_handler, - event_handler - ) - client_rpc.set_timeout(self.timeout) - self.active_connections[conn_id] = client_rpc - - # Set up cleanup on close - async def cleanup(*args): - self.active_connections.pop(conn_id, None) - rpc_client.on("close", cleanup) - - async def disconnect(self, connection_id: str, reason: Optional[str] = None) -> None: - client = self.active_connections.pop(connection_id, None) - if client: - await client.pipe.terminate(1000, reason or "Server terminated connection") - - def set_timeout(self, milliseconds: int) -> None: - self.timeout = milliseconds - - def is_active(self, connection_id: str) -> bool: - return connection_id in self.active_connections - - async def broadcast(self, topic: str, message: Any) -> None: - tasks = [] - for client in self.active_connections.values(): - tasks.append(client.emit(topic, message)) - if tasks: - await asyncio.gather(*tasks, return_exceptions=True) - - def get_client(self, connection_id: str) -> Optional[RpcAuthorizedClient]: - return self.active_connections.get(connection_id) - - async def call_method(self, connection_id: str, method: str, *args: Any) -> Any: - client = self.active_connections.get(connection_id) - if client: - return await client.call_method(method, *args) - raise Exception("Client is not active") - - async def emit(self, connection_id: str, topic: str, message: Any) -> None: - client = self.active_connections.get(connection_id) - if client: - await client.emit(topic, message) - else: - raise Exception("Client is not active") - - async def _handle_event(self, connection_id: str, topic: str, message: Any) -> None: - if await self.validate_event_broadcast(connection_id, topic, message): - tasks = [] - for key, client in self.active_connections.items(): - if key != connection_id: - tasks.append(client.emit(topic, message)) - if tasks: - await asyncio.gather(*tasks, return_exceptions=True) - - async def fire_method(self, connection_id: str, method: str, *args: Any) -> None: - client = self.active_connections.get(connection_id) - if client: - await client.fire_method(method, *args) - else: - raise Exception("Client is not active") - - @abstractmethod - async def handle_method_call(self, connection_id: str, method: str, args: List[Any]) -> Any: - pass - - @abstractmethod - async def validate_event_broadcast(self, connection_id: str, topic: str, message: Any) -> bool: - pass - - def with_client(self, connection_id: str, action: Callable) -> bool: - client = self.active_connections.get(connection_id) - if client: - action(client) - return True - return False diff --git a/cbor_rpc/ssh/__init__.py b/cbor_rpc/ssh/__init__.py new file mode 100644 index 0000000..edfc2c3 --- /dev/null +++ b/cbor_rpc/ssh/__init__.py @@ -0,0 +1 @@ +from .ssh_pipe import SshPipe diff --git a/cbor_rpc/ssh/ssh_pipe.py b/cbor_rpc/ssh/ssh_pipe.py new file mode 100644 index 0000000..67cab1d --- /dev/null +++ b/cbor_rpc/ssh/ssh_pipe.py @@ -0,0 +1,121 @@ +import asyncio +import asyncssh +from typing import Optional, Tuple, Union + +from cbor_rpc.pipe.aio_pipe import AioPipe + + +class SshPipe(AioPipe[bytes, bytes]): + """ + A Pipe implementation that works over an SSH connection. + It uses asyncssh to establish and manage the SSH session and channels. + """ + + def __init__( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ssh_client: asyncssh.SSHClientConnection, + ssh_channel: asyncssh.SSHClientChannel, + ): + super().__init__(reader, writer) + self._ssh_client = ssh_client + self._ssh_channel = ssh_channel + + @classmethod + async def connect( + cls, + host: str, + port: int = 22, + username: str = "root", + password: Optional[str] = None, + ssh_key_content: Optional[str] = None, + ssh_key_passphrase: Optional[str] = None, + known_hosts: Optional[Union[str, list]] = None, + timeout: Optional[float] = None, + command: str = "sh -l", + ) -> "SshPipe": + """ + Establishes an SSH connection and opens a session channel, returning an SshPipe. + + Args: + host: The hostname or IP address of the SSH server. + port: The port number for the SSH connection (default: 22). + username: The username for SSH authentication. + password: The password for password-based authentication (optional). + ssh_key_content: The content of an SSH private key for key-based authentication (optional). + known_hosts: Path to a known_hosts file or a list of host keys (optional). + If None, host key checking is disabled (use with caution). + timeout: Optional timeout for the connection attempt. + command: The command to execute on the remote host to establish the pipe (default: 'sh -l'). + + Returns: + An SshPipe instance connected over SSH. + + Raises: + asyncssh.Error: If the SSH connection or channel establishment fails. + asyncio.TimeoutError: If the connection times out. + """ + client_keys = None + if ssh_key_content: + client_keys = [asyncssh.import_private_key(ssh_key_content, passphrase=ssh_key_passphrase)] + + try: + conn = await asyncio.wait_for( + asyncssh.connect( + host, + port=port, + options=asyncssh.SSHClientConnectionOptions( + username=username, + password=password, + client_keys=client_keys, + passphrase=ssh_key_passphrase, # Passphrase for encrypted client keys + ignore_encrypted=False, # Do not ignore encrypted keys + ), + known_hosts=known_hosts, + ), + timeout=timeout, + ) + + # Create a process on the SSH connection to get stdin/stdout streams. + # Use the provided 'command' argument. + # Explicitly set encoding=None to ensure raw bytes are handled. + # Set term_type=None and stdin=asyncssh.PIPE to ensure stdin is always a writable stream. + channel = await conn.create_process(command, term_type=None, encoding=None, stdin=asyncssh.PIPE) + + reader = channel.stdout + writer = channel.stdin + + pipe = cls(reader, writer, conn, channel) + await pipe._setup_connection() + return pipe + + except asyncssh.Error as e: + # asyncssh.Error expects a 'reason' argument + raise asyncssh.Error(str(e), reason=str(e)) + except asyncio.TimeoutError: + raise asyncio.TimeoutError(f"SSH connection to {host}:{port} timed out.") + + async def terminate(self) -> None: + """ + Closes the SSH channel and the underlying SSH connection. + """ + # Close the SSH channel and client connection + if self._ssh_channel and not self._ssh_channel.is_closing(): + self._ssh_channel.close() + await self._ssh_channel.wait_closed() + if self._ssh_client and not self._ssh_client.is_closed(): + self._ssh_client.close() + await self._ssh_client.wait_closed() + + # Terminate the AioPipe, which handles reader/writer tasks + await super().terminate() + + async def write_eof(self) -> None: + """ + Signals the end of the write stream to the remote process. + For SSH, this means closing the stdin channel. + """ + if self._writer: + self._writer.close() + await self._writer.wait_closed() diff --git a/cbor_rpc/stdio/stdio_pipe.py b/cbor_rpc/stdio/stdio_pipe.py new file mode 100644 index 0000000..97ab01d --- /dev/null +++ b/cbor_rpc/stdio/stdio_pipe.py @@ -0,0 +1,76 @@ +import asyncio +import sys +from typing import Any, List, Optional, Tuple, TypeVar, Generic + +from cbor_rpc.pipe.aio_pipe import AioPipe +from cbor_rpc.pipe.pipe import Pipe +from cbor_rpc.event.emitter import AbstractEmitter + +T1 = TypeVar("T1") +T2 = TypeVar("T2") + + +class StdioPipe(AioPipe[bytes, bytes]): + """ + A Pipe implementation that works with asyncio.StreamReader and asyncio.StreamWriter + typically obtained from a subprocess's stdin/stdout. + """ + + def __init__( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + process: Optional[asyncio.subprocess.Process] = None, + ): + super().__init__(reader, writer) + self._process = process + + async def _setup(self): + await self._setup_connection() + + @classmethod + async def open(cls) -> "StdioPipe": + """ + Creates a StdioPipe from the process's stdin and stdout. + """ + loop = asyncio.get_running_loop() + reader = asyncio.StreamReader() + protocol = asyncio.StreamReaderProtocol(reader) + await loop.connect_read_pipe(lambda: protocol, sys.stdin) + writer_transport, writer_protocol = await loop.connect_write_pipe(asyncio.streams.FlowControlMixin, sys.stdout) + writer = asyncio.StreamWriter(writer_transport, writer_protocol, reader, loop) + pipe = cls(reader, writer) + await pipe._setup() + return pipe + + @classmethod + async def start_process(cls, *args: str) -> "StdioPipe": + """ + Starts a process and returns a StdioPipe for it. + """ + process = await asyncio.create_subprocess_exec( + *args, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=sys.stderr, + ) + pipe = cls(process.stdout, process.stdin, process) + await pipe._setup() + return pipe + + async def wait_for_process_termination(self) -> int: + """ + Waits for the started subprocess to terminate and returns its exit code. + Raises RuntimeError if no process was started by this pipe. + """ + if not self._process: + raise RuntimeError("No subprocess associated with this StdioPipe instance.") + return await self._process.wait() + + async def terminate(self, *args: Any): + """ + Terminates the started subprocess if one exists. + """ + if self._process and self._process.returncode is None: + self._process.terminate() + await super().terminate(*args) diff --git a/cbor_rpc/sync_pipe.py b/cbor_rpc/sync_pipe.py deleted file mode 100644 index 55a2935..0000000 --- a/cbor_rpc/sync_pipe.py +++ /dev/null @@ -1,140 +0,0 @@ -from typing import Any, TypeVar, Generic, Callable, Tuple, Optional -from abc import ABC, abstractmethod -import asyncio -import inspect -from .emitter import AbstractEmitter -import queue -import threading -from typing import Union - -# Generic type variables -T1 = TypeVar('T1') -T2 = TypeVar('T2') - - - -class SyncPipe(Generic[T1, T2]): - """ - Synchronous pipe uses read/write methods instead of events. - Explicit read/write is used for writing protocols that have multiple steps. - """ - - def __init__(self): - self._closed = False - self._queue: queue.Queue = queue.Queue() - self._sent_data: queue.Queue = queue.Queue() # Track data sent to other pipe for size reporting - self._connected_pipe: Optional['SyncPipe'] = None - - def read(self, timeout: Optional[float] = None) -> Optional[T2]: - """ - Read data from the pipe. - - Args: - timeout: Maximum time to wait for data (None = block indefinitely) - - Returns: - Data from the pipe or None if timeout/closed - - Raises: - Exception: If pipe is closed - """ - if self._closed: - raise Exception("Pipe is closed") - - try: - return self._queue.get(timeout=timeout) - except queue.Empty: - return None - - def write(self, chunk: T1) -> bool: - """ - Write data to the pipe. - - Args: - chunk: Data to write - - Returns: - True if successful, False if pipe is closed - """ - if self._closed: - return False - - # Forward to connected pipe only (not to our own queue) - try: - if self._connected_pipe and not self._connected_pipe._closed: - self._sent_data.put(chunk) # Track for size reporting - self._connected_pipe._queue.put(chunk) - return True - except: - return False - - return True - - def close(self) -> None: - """Close the pipe.""" - if self._closed: - return - - self._closed = True - - # Notify connected pipe - if self._connected_pipe and not self._connected_pipe._closed: - self._connected_pipe.close() - - def is_closed(self) -> bool: - """Check if the pipe is closed.""" - return self._closed - - def available(self) -> int: - """Get number of items available to read.""" - if self._closed: - return 0 - - # Return combined size of local queue and sent data (data we wrote that hasn't been read by the other pipe) - local_size = self._queue.qsize() - - # If we have a connected pipe, check how much of our sent data has already been read - if self._connected_pipe: - # Get all items in _sent_data queue - sent_items = list(self._sent_data.queue) - - # Remove items that are still in the other pipe's queue - for item in list(sent_items): - try: - # Try to peek at what's in the connected pipe's queue - # We need a copy of the other pipe's queue to avoid modifying it during iteration - connected_queue = list(self._connected_pipe._queue.queue) - if item in connected_queue: - # Item hasn't been read yet, so count it - continue - else: - # Item has been read, remove from our sent tracking - self._sent_data.get() - except: - # If we can't access the other pipe's queue or there's an error, - # just use all items in _sent_data as a fallback - pass - - # Return total of local data and unread sent data - return local_size + len(list(self._sent_data.queue)) - - # No connected pipe, just return local size - return local_size - - @staticmethod - def create_pair() -> Tuple['SyncPipe[Any, Any]', 'SyncPipe[Any, Any]']: - """ - Create a pair of connected sync pipes for bidirectional communication. - - Returns: - A tuple of (pipe1, pipe2) where data written to pipe1 can be read from pipe2 and vice versa. - """ - pipe1 = SyncPipe[Any, Any]() - pipe2 = SyncPipe[Any, Any]() - - # Connect them bidirectionally - pipe1._connected_pipe = pipe2 - pipe2._connected_pipe = pipe1 - - return pipe1, pipe2 - diff --git a/cbor_rpc/tcp.py b/cbor_rpc/tcp.py deleted file mode 100644 index dbdefcc..0000000 --- a/cbor_rpc/tcp.py +++ /dev/null @@ -1,386 +0,0 @@ -import asyncio -import socket -from typing import Any, Callable, Optional, Tuple, Union -from .async_pipe import Pipe -from .server_base import Server - - -class TcpPipe(Pipe[bytes, bytes]): - """ - A TCP duplex pipe that implements Pipe for network communication. - Provides both client and server functionality for TCP connections. - """ - - def __init__(self, reader: Optional[asyncio.StreamReader] = None, - writer: Optional[asyncio.StreamWriter] = None): - super().__init__() - self._reader = reader - self._writer = writer - self._connected = False - self._closed = False - self._read_task: Optional[asyncio.Task] = None - - @classmethod - async def create_connection(cls, host: str, port: int, - timeout: Optional[float] = None) -> 'TcpPipe': - """ - Create a TCP client connection to the specified host and port. - - Args: - host: The hostname or IP address to connect to - port: The port number to connect to - timeout: Optional timeout for the connection attempt - - Returns: - A connected TcpPipe instance - - Raises: - ConnectionError: If the connection fails - asyncio.TimeoutError: If the connection times out - """ - try: - if timeout: - reader, writer = await asyncio.wait_for( - asyncio.open_connection(host, port), - timeout=timeout - ) - else: - reader, writer = await asyncio.open_connection(host, port) - - tcp_duplex = cls(reader, writer) - await tcp_duplex._setup_connection() - return tcp_duplex - - except Exception as e: - raise ConnectionError(f"Failed to connect to {host}:{port}: {e}") - - @classmethod - async def create_server(cls, host: str = '0.0.0.0', port: int = 0, - backlog: int = 100) -> 'TcpServer': - """ - Create a TCP server that listens for incoming connections. - - Args: - host: The hostname or IP address to bind to (default: '0.0.0.0') - port: The port number to bind to (default: 0 for auto-assignment) - backlog: The maximum number of queued connections - - Returns: - A TcpServer instance - """ - return await TcpServer.create(host, port, backlog) - - @staticmethod - async def create_pair() -> Tuple['TcpPipe', 'TcpPipe']: - """ - Create a pair of connected TCP pipes using a local server. - - Returns: - A tuple of (client_pipe, server_pipe) connected via TCP - """ - # Create a temporary server - server = await TcpServer.create('127.0.0.1', 0) - host, port = server.get_address() - - # Set up to capture the server-side connection - server_pipe = None - connection_ready = asyncio.Event() - - async def on_connection(pipe: TcpPipe): - nonlocal server_pipe - server_pipe = pipe - connection_ready.set() - - server.on_connection(on_connection) - - try: - # Create client connection - client_pipe = await TcpPipe.create_connection(host, port) - - # Wait for server connection - await connection_ready.wait() - - # Close the server but keep the connections - await server.stop() - - return client_pipe, server_pipe - - except Exception: - await server.stop() - raise - - @classmethod - def from_socket(cls, sock: socket.socket) -> 'TcpPipe': - """ - Create a TcpPipe from an existing socket. - - Args: - sock: An existing connected socket - - Returns: - A TcpPipe instance wrapping the socket - """ - # This will be set up when the connection is established - tcp_duplex = cls() - tcp_duplex._socket = sock - return tcp_duplex - - async def connect(self, host: str, port: int, timeout: Optional[float] = None) -> None: - """ - Connect to a remote TCP server. - - Args: - host: The hostname or IP address to connect to - port: The port number to connect to - timeout: Optional timeout for the connection attempt - - Raises: - ConnectionError: If already connected or connection fails - asyncio.TimeoutError: If the connection times out - """ - if self._connected: - raise ConnectionError("Already connected") - - try: - if timeout: - self._reader, self._writer = await asyncio.wait_for( - asyncio.open_connection(host, port), - timeout=timeout - ) - else: - self._reader, self._writer = await asyncio.open_connection(host, port) - - await self._setup_connection() - - except Exception as e: - raise ConnectionError(f"Failed to connect to {host}:{port}: {e}") - - async def _setup_connection(self) -> None: - """Set up the connection and start reading data.""" - if not self._reader or not self._writer: - raise RuntimeError("Reader or writer not initialized") - - self._connected = True - self._closed = False - - # Start the read loop - self._read_task = asyncio.create_task(self._read_loop()) - - # Emit connection event - await self._notify("connect") - - async def _read_loop(self) -> None: - """Continuously read data from the TCP connection and emit data events.""" - try: - while self._connected and not self._closed and self._reader: - try: - # Read data in chunks - data = await self._reader.read(8192) - if not data: - # Connection closed by remote - break - - # Emit data event - await self._emit("data", data) - - except asyncio.CancelledError: - break - except Exception as e: - await self._emit("error", e) - break - - except Exception as e: - await self._emit("error", e) - finally: - if not self._closed: - await self._close_connection() - - async def write(self, chunk: bytes) -> bool: - """ - Write data to the TCP connection. - - Args: - chunk: The bytes to write - - Returns: - True if the write was successful - - Raises: - ConnectionError: If not connected - TypeError: If chunk is not bytes - """ - if not self._connected or self._closed: - raise ConnectionError("Not connected") - - if not isinstance(chunk, (bytes, bytearray)): - raise TypeError("Chunk must be bytes or bytearray") - - if not self._writer: - raise ConnectionError("Writer not available") - - try: - self._writer.write(chunk) - await self._writer.drain() - return True - - except Exception as e: - await self._emit("error", e) - return False - - async def terminate(self, *args: Any) -> None: - """ - Terminate the TCP connection. - - Args: - *args: Optional arguments (code, reason) - """ - await self._close_connection(*args) - - async def _close_connection(self, *args: Any) -> None: - """Close the TCP connection and clean up resources.""" - if self._closed: - return - - self._closed = True - self._connected = False - - # Cancel the read task - if self._read_task and not self._read_task.done(): - task, self._read_task = self._read_task, None - if task.cancel(): - try: - await task - except asyncio.CancelledError: - pass - - # Close the writer - if self._writer: - writer, self._writer = self._writer, None - try: - writer.close() - await writer.wait_closed() - except Exception: - pass # Ignore errors during cleanup - - # Emit close event - await self._emit("close", *args) - - def is_connected(self) -> bool: - """Check if the TCP connection is active.""" - return self._connected and not self._closed - - def get_peer_info(self) -> Optional[Tuple[str, int]]: - """Get the remote peer's address and port.""" - if self._writer and self._connected: - try: - return self._writer.get_extra_info('peername') - except Exception: - pass - return None - - def get_local_info(self) -> Optional[Tuple[str, int]]: - """Get the local socket's address and port.""" - if self._writer and self._connected: - try: - return self._writer.get_extra_info('sockname') - except Exception: - pass - return None - - -class TcpServer(Server[TcpPipe]): - """ - A TCP server that creates TcpPipe instances for incoming connections. - Extends Server[TcpPipe] to provide type-safe TCP-specific functionality. - """ - - def __init__(self, server: asyncio.Server): - super().__init__() - self._server = server - - @classmethod - async def create(cls, host: str = '0.0.0.0', port: int = 0, - backlog: int = 100) -> 'TcpServer': - """ - Create and start a TCP server. - - Args: - host: The hostname or IP address to bind to - port: The port number to bind to (0 for auto-assignment) - backlog: The maximum number of queued connections - - Returns: - A started TcpServer instance - """ - tcp_server = cls.__new__(cls) - Server.__init__(tcp_server) - - async def client_connected_cb(reader: asyncio.StreamReader, - writer: asyncio.StreamWriter) -> None: - tcp_pipe = TcpPipe(reader, writer) - await tcp_pipe._setup_connection() - await tcp_server._add_connection(tcp_pipe) - - server = await asyncio.start_server( - client_connected_cb, host, port, backlog=backlog - ) - - tcp_server._server = server - tcp_server._running = True - return tcp_server - - async def start(self, host: str = '0.0.0.0', port: int = 0, backlog: int = 100) -> Tuple[str, int]: - """ - Start the TCP server (if not already started). - - Args: - host: The hostname or IP address to bind to - port: The port number to bind to - backlog: The maximum number of queued connections - - Returns: - A tuple of (host, port) where the server is listening - """ - if self._running: - return self.get_address() - - # This method is for compatibility; typically create() is used instead - new_server = await TcpServer.create(host, port, backlog) - self._server = new_server._server - self._running = True - return self.get_address() - - async def stop(self) -> None: - """Stop the TCP server and close all connections.""" - if not self._running: - return - - self._running = False - - # Close all connections - await self.close_all_connections() - - # Close the server - if self._server: - self._server.close() - await self._server.wait_closed() - - def get_address(self) -> Tuple[str, int]: - """Get the server's listening address and port.""" - if self._server and self._server.sockets: - return self._server.sockets[0].getsockname()[:2] - return ("", 0) - - # Legacy methods for backward compatibility - def on_connection(self, handler: Callable[[TcpPipe], None]) -> None: - """ - Set a handler for new connections. - - Args: - handler: A function that takes a TcpPipe as argument - """ - super().on_connection(handler) - - async def close(self) -> None: - """Legacy method - use stop() instead.""" - await self.stop() diff --git a/cbor_rpc/tcp/__init__.py b/cbor_rpc/tcp/__init__.py new file mode 100644 index 0000000..25f514f --- /dev/null +++ b/cbor_rpc/tcp/__init__.py @@ -0,0 +1,3 @@ +from .tcp import TcpPipe, TcpServer + +__all__ = ["TcpPipe", "TcpServer"] diff --git a/cbor_rpc/tcp/tcp.py b/cbor_rpc/tcp/tcp.py new file mode 100644 index 0000000..b8e87b2 --- /dev/null +++ b/cbor_rpc/tcp/tcp.py @@ -0,0 +1,263 @@ +from abc import abstractmethod +import asyncio +import socket +from typing import Any, Callable, Optional, Tuple, Union +from cbor_rpc.pipe.aio_pipe import AioPipe +from cbor_rpc.rpc.server_base import Server + + +class TcpPipe(AioPipe[bytes, bytes]): + """ + A TCP duplex pipe that implements Pipe for network communication. + Provides both client and server functionality for TCP connections. + """ + + def __init__( + self, + reader: Optional[asyncio.StreamReader] = None, + writer: Optional[asyncio.StreamWriter] = None, + ): + super().__init__(reader, writer) + + @classmethod + async def create_connection(cls, host: str, port: int, timeout: Optional[float] = None) -> "TcpPipe": + """ + Create a TCP client connection to the specified host and port. + + Args: + host: The hostname or IP address to connect to + port: The port number to connect to + timeout: Optional timeout for the connection attempt + + Returns: + A connected TcpPipe instance + + Raises: + ConnectionError: If the connection fails + asyncio.TimeoutError: If the connection times out + """ + try: + if timeout: + reader, writer = await asyncio.wait_for(asyncio.open_connection(host, port), timeout=timeout) + else: + reader, writer = await asyncio.open_connection(host, port) + + tcp_duplex = cls(reader, writer) + await tcp_duplex._setup_connection() + return tcp_duplex + + except Exception as e: + raise ConnectionError(f"Failed to connect to {host}:{port}: {e}") + + @classmethod + async def create_server(cls, host: str = "0.0.0.0", port: int = 0, backlog: int = 100) -> "TcpServer": + """ + Create a TCP server that listens for incoming connections. + + Args: + host: The hostname or IP address to bind to (default: '0.0.0.0') + port: The port number to bind to (default: 0 for auto-assignment) + backlog: The maximum number of queued connections + + Returns: + A TcpServer instance + """ + return await TcpServer.create(host, port, backlog) + + @staticmethod + async def create_inmemory_pair() -> Tuple["TcpPipe", "TcpPipe"]: + """ + Create a pair of connected TCP pipes using a local server. + + Returns: + A tuple of (client_pipe, server_pipe) connected via TCP + """ + + class SimpleTcpServer(TcpServer): + async def accept(self, pipe: TcpPipe) -> bool: + return True + + # Create a temporary server + server = await SimpleTcpServer.create("127.0.0.1", 0) + host, port = server.get_address() + + # Set up to capture the server-side connection + server_pipe = None + connection_ready = asyncio.Event() + + async def on_connection(pipe: TcpPipe): + nonlocal server_pipe + server_pipe = pipe + connection_ready.set() + + server.on_connection(on_connection) + + try: + # Create client connection + client_pipe = await TcpPipe.create_connection(host, port) + + # Wait for server connection + await connection_ready.wait() + + # Stop accepting new connections but keep the active pipes + await server.shutdown() + + return client_pipe, server_pipe + + except Exception: + await server.stop() + raise + + async def connect(self, host: str, port: int, timeout: Optional[float] = None) -> None: + """ + Connect to a remote TCP server. + + Args: + host: The hostname or IP address to connect to + port: The port number to connect to + timeout: Optional timeout for the connection attempt + + Raises: + ConnectionError: If already connected or connection fails + asyncio.TimeoutError: If the connection times out + """ + if self._connected: + raise ConnectionError("Already connected") + + try: + if timeout: + self._reader, self._writer = await asyncio.wait_for( + asyncio.open_connection(host, port), timeout=timeout + ) + else: + self._reader, self._writer = await asyncio.open_connection(host, port) + + await self._setup_connection() + + except Exception as e: + raise ConnectionError(f"Failed to connect to {host}:{port}: {e}") + + def get_peer_info(self) -> Optional[Tuple[str, int]]: + """Get the remote peer's address and port.""" + if self._writer and self._connected: + try: + return self._writer.get_extra_info("peername") + except Exception: + pass + return None + + def get_local_info(self) -> Optional[Tuple[str, int]]: + """Get the local socket's address and port.""" + if self._writer and self._connected: + try: + return self._writer.get_extra_info("sockname") + except Exception: + pass + return None + + def get_peer_info(self) -> Optional[Tuple[str, int]]: + """Get the remote peer's address and port.""" + if self._writer and self._connected: + try: + return self._writer.get_extra_info("peername") + except Exception: + pass + return None + + def get_local_info(self) -> Optional[Tuple[str, int]]: + """Get the local socket's address and port.""" + if self._writer and self._connected: + try: + return self._writer.get_extra_info("sockname") + except Exception: + pass + return None + + +class TcpServer(Server[TcpPipe]): + """ + A TCP server that creates TcpPipe instances for incoming connections. + Extends Server[TcpPipe] to provide type-safe TCP-specific functionality. + """ + + def __init__(self, server: asyncio.Server): + super().__init__() + self._server = server + + @classmethod + async def create(cls, host: str = "0.0.0.0", port: int = 0, backlog: int = 100) -> "TcpServer": + """ + Create and start a TCP server. + + Args: + host: The hostname or IP address to bind to + port: The port number to bind to (0 for auto-assignment) + backlog: The maximum number of queued connections + + Returns: + A started TcpServer instance + """ + tcp_server = cls.__new__(cls) + Server.__init__(tcp_server) + + async def client_connected_cb(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: + tcp_pipe = TcpPipe(reader, writer) + await tcp_pipe._setup_connection() + await tcp_server._add_connection(tcp_pipe) + + server = await asyncio.start_server(client_connected_cb, host, port, backlog=backlog) + + tcp_server._server = server + tcp_server._running = True + return tcp_server + + async def start(self, host: str = "0.0.0.0", port: int = 0, backlog: int = 100) -> Tuple[str, int]: + """ + Start the TCP server (if not already started). + + Args: + host: The hostname or IP address to bind to + port: The port number to bind to + backlog: The maximum number of queued connections + + Returns: + A tuple of (host, port) where the server is listening + """ + if self._running: + return self.get_address() + + # This method is for compatibility; typically create() is used instead + new_server = await TcpServer.create(host, port, backlog) + self._server = new_server._server + self._running = True + return self.get_address() + + async def stop(self) -> None: + """Stop the TCP server and close all connections.""" + if not self._running: + return + + self._running = False + + # Close all connections + await self.close_all_connections() + + # Close the server + if self._server: + self._server.close() + await self._server.wait_closed() + + async def shutdown(self) -> None: + """Stop accepting new connections while keeping active connections open.""" + if not self._server: + return + + self._running = False + self._server.close() + await self._server.wait_closed() + + def get_address(self) -> Tuple[str, int]: + """Get the server's listening address and port.""" + if self._server and self._server.sockets: + return self._server.sockets[0].getsockname()[:2] + return ("", 0) diff --git a/cbor_rpc/promise.py b/cbor_rpc/timed_promise.py similarity index 81% rename from cbor_rpc/promise.py rename to cbor_rpc/timed_promise.py index 66411dc..c96bd9a 100644 --- a/cbor_rpc/promise.py +++ b/cbor_rpc/timed_promise.py @@ -2,22 +2,23 @@ import asyncio -class DeferredPromise: - def __init__(self, timeout_ms: int, timeout_cb: Optional[Callable[[], None]] = None, - message: str = "Timeout on RPC call"): +class TimedPromise: + def __init__( + self, + timeout_ms: int, + timeout_cb: Optional[Callable[[], None]] = None, + message: str = "Timeout on RPC call", + ): self._timeout_ms = timeout_ms self._timeout_cb = timeout_cb self._message = message self._future = asyncio.get_event_loop().create_future() self._timeout_handle = None self._resolved = False - + # Set up timeout if timeout_ms > 0: - self._timeout_handle = asyncio.get_event_loop().call_later( - timeout_ms / 1000.0, - self._on_timeout - ) + self._timeout_handle = asyncio.get_event_loop().call_later(timeout_ms / 1000.0, self._on_timeout) @property def promise(self) -> asyncio.Future: @@ -46,8 +47,11 @@ def _on_timeout(self) -> None: error_data = { "timeout": True, "timeoutPeriod": self._timeout_ms, - "message": self._message + "message": self._message, } self._future.set_exception(Exception(error_data)) if self._timeout_cb: self._timeout_cb() + + def __await__(self): + return self._future.__await__() diff --git a/cbor_rpc/transformer.py b/cbor_rpc/transformer.py deleted file mode 100644 index 3e713c7..0000000 --- a/cbor_rpc/transformer.py +++ /dev/null @@ -1,3 +0,0 @@ -from .transformer.transformer import Transformer - -__all__ = ['Transformer'] diff --git a/cbor_rpc/transformer/__init__.py b/cbor_rpc/transformer/__init__.py index 6ff14d5..3ba6de9 100644 --- a/cbor_rpc/transformer/__init__.py +++ b/cbor_rpc/transformer/__init__.py @@ -1 +1,3 @@ -from .transformer import Transformer \ No newline at end of file +from .base import Transformer +from .json_transformer import JsonTransformer +from .cbor_transformer import CborTransformer, CborStreamTransformer diff --git a/cbor_rpc/transformer/base/__init__.py b/cbor_rpc/transformer/base/__init__.py new file mode 100644 index 0000000..09c712e --- /dev/null +++ b/cbor_rpc/transformer/base/__init__.py @@ -0,0 +1,2 @@ +from .transformer_base import Transformer, AsyncTransformer +from .base_exception import NeedsMoreDataException diff --git a/cbor_rpc/transformer/base/base_exception.py b/cbor_rpc/transformer/base/base_exception.py new file mode 100644 index 0000000..92a1fc7 --- /dev/null +++ b/cbor_rpc/transformer/base/base_exception.py @@ -0,0 +1,2 @@ +class NeedsMoreDataException(Exception): + pass diff --git a/cbor_rpc/transformer/base/event_transformer_pipe.py b/cbor_rpc/transformer/base/event_transformer_pipe.py new file mode 100644 index 0000000..da13748 --- /dev/null +++ b/cbor_rpc/transformer/base/event_transformer_pipe.py @@ -0,0 +1,77 @@ +import asyncio +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .transformer_base import Transformer +from typing import Any, Awaitable, Callable, TypeVar + +from cbor_rpc.pipe import EventPipe +from .base_exception import NeedsMoreDataException + +T1 = TypeVar("T1") # Output type after decoding +T2 = TypeVar("T2") # Input type before decoding (pipe input/output type) + + +class EventTransformerPipe(EventPipe[T1, T2]): + encode: Callable[[T1], Awaitable[T2]] + decode: Callable[[T2], Awaitable[T1]] + + def __init__(self, pipe: EventPipe[T2, T2], transformer: "Transformer"): + super().__init__() + self.pipe = pipe + self.pipe.pipeline("data", self._handle_data) + self.pipe.on("close", self._on_close) + self.pipe.on("error", self._on_error) + if transformer: + self.encode = transformer.encode + self.decode = transformer.decode + else: + raise ValueError("A transformer must be provided or encode/decode must be overridden") + + async def _handle_data(self, data: T2): + try: + decoded = await self.decode(data) + await self._notify("data", decoded) + except NeedsMoreDataException: + # If more data is needed, simply return and wait for the next chunk + return + except Exception as e: + # Let the exception propagate up to AbstractEmitter._notify, + # which will catch it and emit the "error" event. + raise e + + # Try to decode more from the buffer if the transformer supports it + while True: + try: + # We pass None to indicate "no new data, just decode from buffer" + # We catch TypeError in case the transformer does not support None/buffering + decoded = await self.decode(None) + await self._notify("data", decoded) + except NeedsMoreDataException: + break + except TypeError: + # Transformer likely doesn't support None (not a stream transformer) + break + except Exception as e: + # Other errors in subsequent decoding should probably be reported. + # However, since the *primary* data was processed, maybe we should just emit error? + # or raise? Raising here might be outside the context of the initial caller if we were async... + # But we are inside _handle_data which is awaited by _notify. + raise e + + def _on_close(self, *args: Any): + self._emit("close", *args) + + def _on_error(self, error: Exception): + self._emit("error", error) + + async def write(self, chunk: T1) -> bool: + try: + encoded = await self.encode(chunk) + return await self.pipe.write(encoded) + except Exception as e: + self._emit("error", e) + return False + + async def terminate(self, *args: Any) -> None: + await self.pipe.terminate(*args) diff --git a/cbor_rpc/transformer/base/transformer_base.py b/cbor_rpc/transformer/base/transformer_base.py new file mode 100644 index 0000000..957f22d --- /dev/null +++ b/cbor_rpc/transformer/base/transformer_base.py @@ -0,0 +1,106 @@ +from abc import abstractmethod +from typing import Any, Generic, TypeVar, Union, overload + +from cbor_rpc.pipe import EventPipe +from cbor_rpc.pipe import Pipe +from .event_transformer_pipe import EventTransformerPipe +from .base_exception import NeedsMoreDataException +from .transformer_pipe import TransformerPipe + +T1 = TypeVar("T1") +T2 = TypeVar("T2") + + +# Sync Transformer (no async methods) +class Transformer(Generic[T1, T2]): + """ + Transformer converts input of type T1 to output of type T2 and vice versa. + Write expects T1, read produces T1. The underlying pipe works with T2. + """ + def __init__(self): + super().__init__() + self._closed = False + + def is_closed(self) -> bool: + return self._closed + + @abstractmethod + def encode(self, data: T1) -> Any: + pass + + @abstractmethod + def decode(self, data: Any) -> T2: + pass + + def wait_next_data(self): + raise NeedsMoreDataException() + + @overload + def bind(self, pipe: Pipe) -> TransformerPipe: ... + @overload + def bind(self, pipe: EventPipe) -> EventTransformerPipe: ... + + @overload + def apply_transformer(self, pipe: Pipe) -> TransformerPipe: ... + @overload + def apply_transformer(self, pipe: EventPipe) -> EventTransformerPipe: ... + def apply_transformer(self, pipe: Union[Pipe, EventPipe]) -> Union[TransformerPipe, EventTransformerPipe]: + if isinstance(pipe, EventPipe): + return EventTransformerPipe(pipe, self.to_async()) + elif isinstance(pipe, Pipe): + return TransformerPipe(pipe, self.to_async()) + else: + raise TypeError("Invalid pipe type") + + def to_async(self) -> "AsyncTransformer[T1, T2]": + parent = self + + class WrappedAsyncTransformer(AsyncTransformer[T1, T2]): + async def encode(self, data: T1) -> Any: + return parent.encode(data) + + async def decode(self, data: Any) -> T2: + return parent.decode(data) + + def is_closed(self) -> bool: + return parent.is_closed() + + return WrappedAsyncTransformer() + + +# Async Transformer (async methods) +class AsyncTransformer(Generic[T1, T2]): + def __init__(self): + super().__init__() + self._closed = False + + def is_closed(self) -> bool: + return self._closed + + @abstractmethod + async def encode(self, data: T1) -> Any: + pass + + @abstractmethod + async def decode(self, data: Any) -> T2: + pass + + def wait_next_data(self): + raise NeedsMoreDataException() + + @overload + def bind(self, pipe: Pipe) -> TransformerPipe: ... + @overload + def bind(self, pipe: EventPipe) -> EventTransformerPipe: ... + + @overload + def apply_transformer(self, pipe: Pipe) -> TransformerPipe: ... + @overload + def apply_transformer(self, pipe: EventPipe) -> EventTransformerPipe: ... + def apply_transformer(self, pipe: Union[Pipe, EventPipe]) -> Union[TransformerPipe, EventTransformerPipe]: + if isinstance(pipe, EventPipe): + return EventTransformerPipe(pipe, self) + elif isinstance(pipe, Pipe): + return TransformerPipe(pipe, self) + else: + raise TypeError("Invalid pipe type") diff --git a/cbor_rpc/transformer/base/transformer_pipe.py b/cbor_rpc/transformer/base/transformer_pipe.py new file mode 100644 index 0000000..714c3a7 --- /dev/null +++ b/cbor_rpc/transformer/base/transformer_pipe.py @@ -0,0 +1,89 @@ +from typing import Any, Awaitable, Optional, TypeVar, Callable +from typing import TYPE_CHECKING +import asyncio +import time + +from .base_exception import NeedsMoreDataException + +if TYPE_CHECKING: + from .transformer_base import Transformer +from cbor_rpc.pipe.pipe import Pipe + +T1 = TypeVar("T1") +T2 = TypeVar("T2") + + +class TransformerPipe(Pipe[T1, T2]): + encode: Callable[[T1], Awaitable[T2]] + decode: Callable[[T2], Awaitable[T1]] + + def __init__(self, pipe: Pipe[Any, Any], transformer: "Optional[Transformer[T1, T2]]"): + super().__init__() + self.pipe = pipe + self._closed = False + self._terminated = False + + if transformer is None: + raise ValueError("A transformer must be provided") + + self.encode = transformer.encode + self.decode = transformer.decode + + async def _handle_error(*args): + if not self._closed: + self._closed = True + await self.terminate() + self._emit("error", *args) + + def _handle_close(*args): + if not self._closed: + self._closed = True + self._emit("close", *args) + + self.pipe.on("error", _handle_error) + self.pipe.on("close", _handle_close) + + async def write(self, chunk: T1) -> bool: + if self._closed: + return False + try: + encoded = await self.encode(chunk) + return await self.pipe.write(encoded) + except Exception as e: + self._emit("error", e) + return False + + async def read(self, timeout: Optional[float] = None) -> Optional[T1]: + if self._closed: + return None + + start = time.monotonic() + remaining = timeout + + try: + while True: + raw = await self.pipe.read(remaining) + if raw is None: + return None + + try: + return await self.decode(raw) + except NeedsMoreDataException: + if timeout is not None: + elapsed = time.monotonic() - start + remaining = max(0, timeout - elapsed) + if remaining == 0: + return None + except Exception as e: + self._emit("error", e) + return None + + async def terminate(self) -> None: + if self._terminated: + return + self._terminated = True + self._closed = True + await self.pipe.terminate() + + def _propagate_error(self, *args): + self.pipe._emit("error", *args) diff --git a/cbor_rpc/transformer/cbor_transformer.py b/cbor_rpc/transformer/cbor_transformer.py new file mode 100644 index 0000000..1a77dc4 --- /dev/null +++ b/cbor_rpc/transformer/cbor_transformer.py @@ -0,0 +1,95 @@ +import cbor2 +from io import BytesIO +from typing import Any, Union + +from .base import Transformer, AsyncTransformer +from cbor_rpc.transformer.base.base_exception import NeedsMoreDataException + +# Use the pure-Python CBORDecoder to avoid buffering issues with the C extension +# when using stream-slicing logic. +try: + from cbor2._decoder import CBORDecoder as PythonCBORDecoder +except ImportError: + from cbor2 import CBORDecoder as PythonCBORDecoder + +# Import pure-Python break_marker / CBORDecodeError so we can handle both +# C-extension and pure-Python backends uniformly. +try: + from cbor2._types import break_marker as _py_break_marker +except ImportError: + _py_break_marker = cbor2.break_marker + + + +def _is_eof_error(exc: Exception) -> bool: + """Return True if *exc* signals incomplete CBOR data (C or Python backend).""" + return isinstance(exc, (cbor2.CBORDecodeEOF, IndexError)) or type(exc).__name__ == "CBORDecodeEOF" + + +class CborTransformer(Transformer[Any, bytes]): + """Encodes Python objects to CBOR bytes and decodes CBOR bytes back.""" + + def encode(self, data: Any) -> bytes: + return cbor2.dumps(data) + + def decode(self, data: Union[bytes, None]) -> Any: + if data is None: + raise TypeError("Expected bytes, got None") + if not isinstance(data, bytes): + raise TypeError(f"Expected bytes, got {type(data)}") + try: + return cbor2.loads(data) + except cbor2.CBORDecodeEOF as e: + raise cbor2.CBORDecodeError("Incomplete CBOR data for non-stream transformer") from e + + +class CborStreamTransformer(AsyncTransformer[Any, Any]): + """Async stream transformer that decodes concatenated CBOR objects.""" + + def __init__(self, max_buffer_bytes: int = 1024 * 1024*50): + super().__init__() + self._buffer = bytearray() + self._max_buffer_bytes = max_buffer_bytes + + async def encode(self, data: Any) -> bytes: + return cbor2.dumps(data) + + async def decode(self, data: Union[bytes, None]) -> Any: + if data is not None: + if not isinstance(data, bytes): + raise TypeError(f"Expected bytes or None, got {type(data)}") + self._buffer.extend(data) + + if len(self._buffer) > self._max_buffer_bytes: + self._buffer.clear() + raise OverflowError("CBOR stream buffer exceeded max size") + + if not self._buffer: + raise NeedsMoreDataException() + + try: + return self._decode_one() + except Exception as e: + if _is_eof_error(e): + raise NeedsMoreDataException() + raise + + # -- private helpers -------------------------------------------------- + + def _decode_one(self) -> Any: + """Decode exactly one CBOR object from the front of the buffer.""" + stream = BytesIO(self._buffer) + decoder = PythonCBORDecoder(stream) + try: + obj = decoder.decode() + except Exception as e: + # Normalize value errors to CBORDecodeError for consistent API behavior. + if isinstance(e, cbor2.CBORDecodeValueError) or type(e).__name__ == "CBORDecodeValueError": + raise cbor2.CBORDecodeError(str(e)) from e + raise + + if obj is cbor2.break_marker or obj is _py_break_marker: + raise cbor2.CBORDecodeError("Unexpected break marker") + + self._buffer = self._buffer[stream.tell():] + return obj diff --git a/cbor_rpc/transformer/json_transformer.py b/cbor_rpc/transformer/json_transformer.py new file mode 100644 index 0000000..ffb745d --- /dev/null +++ b/cbor_rpc/transformer/json_transformer.py @@ -0,0 +1,33 @@ +import json +from typing import Any, Union +from .base import Transformer + + +class JsonTransformer(Transformer[Any, Any]): + """ + A transformer that encodes Python objects to JSON strings and decodes JSON strings back to Python objects. + """ + + def __init__(self, encoding: str = "utf-8"): + super().__init__() + self.encoding = encoding + + def encode(self, data: Any) -> bytes: + json_str = json.dumps( + data, ensure_ascii=False + ) # Always allow non-ASCII characters to pass through json.dumps + return json_str.encode(self.encoding) + + + def decode(self, data: Union[bytes, str, None]) -> Any: + if data is None: + raise TypeError("Expected bytes or str, got None") + + if isinstance(data, bytes): + json_str = data.decode(self.encoding) + elif isinstance(data, str): + json_str = data + else: + raise TypeError(f"Expected bytes or str, got {type(data)}") + + return json.loads(json_str) diff --git a/cbor_rpc/transformer/transformer.py b/cbor_rpc/transformer/transformer.py deleted file mode 100644 index 9e1e691..0000000 --- a/cbor_rpc/transformer/transformer.py +++ /dev/null @@ -1,214 +0,0 @@ -from typing import Any, TypeVar, Generic, Callable, Tuple, Optional -from abc import ABC, abstractmethod -import asyncio -import inspect -from cbor_rpc.async_pipe import Pipe -from cbor_rpc.sync_pipe import SyncPipe -import queue -import threading -from typing import Union - -# Generic type variables -T1 = TypeVar('T1') -T2 = TypeVar('T2') - -class Transformer(Generic[T1, T2]): - """ - Abstract transformer that can wrap both sync and async pipes. - Encodes data when writing and decodes data when reading/emitting events. - """ - - def __init__(self, underlying_pipe: Union[Pipe[Any, Any], SyncPipe[Any, Any]]): - self.underlying_pipe = underlying_pipe - self._closed = False - self._is_sync_pipe = isinstance(underlying_pipe, SyncPipe) - - if not self._is_sync_pipe: - # For async pipes, set up event handlers - super().__init__() - - # Forward events from underlying pipe, but decode data events - async def on_underlying_data(data: Any): - try: - decoded_data = await self.decode(data) - await self._emit("data", decoded_data) - except Exception as e: - await self._emit("error", e) - - async def on_underlying_close(*args): - await self._emit("close", *args) - - async def on_underlying_error(error): - await self._emit("error", error) - - self.underlying_pipe.on("data", on_underlying_data) - self.underlying_pipe.on("close", on_underlying_close) - self.underlying_pipe.on("error", on_underlying_error) - - # Async methods for async pipes - async def write(self, chunk: T1) -> bool: - """Write data after encoding it (async version).""" - if self._closed: - return False - - if self._is_sync_pipe: - raise RuntimeError("Use write_sync() for SyncPipe") - - try: - encoded_chunk = await self.encode(chunk) - return await self.underlying_pipe.write(encoded_chunk) - except Exception as e: - await self._emit("error", e) - return False - - async def terminate(self, *args: Any) -> None: - """Terminate the underlying pipe (async version).""" - if self._closed: - return - self._closed = True - - if not self._is_sync_pipe: - await self.underlying_pipe.terminate(*args) - else: - self.underlying_pipe.close() - - # Sync methods for sync pipes - def write_sync(self, chunk: T1) -> bool: - """Write data after encoding it (sync version).""" - if self._closed: - return False - - if not self._is_sync_pipe: - raise RuntimeError("Use write() for async Pipe") - - try: - encoded_chunk = self.encode_sync(chunk) - return self.underlying_pipe.write(encoded_chunk) - except Exception as e: - return False - - def read_sync(self, timeout: Optional[float] = None) -> Optional[T2]: - """Read and decode data (sync version).""" - if self._closed: - return None - - if not self._is_sync_pipe: - raise RuntimeError("Use event handlers for async Pipe") - - try: - raw_data = self.underlying_pipe.read(timeout) - if raw_data is None: - return None - return self.decode_sync(raw_data) - except Exception as e: - return None - - def close_sync(self) -> None: - """Close the transformer (sync version).""" - if self._closed: - return - self._closed = True - - if self._is_sync_pipe: - self.underlying_pipe.close() - - def is_sync_pipe(self) -> bool: - """Check if this transformer wraps a sync pipe.""" - return self._is_sync_pipe - - # Abstract methods - async versions - @abstractmethod - async def encode(self, data: T1) -> Any: - """Encode data before writing to the underlying pipe (async version).""" - pass - - @abstractmethod - async def decode(self, data: Any) -> T2: - """Decode data received from the underlying pipe (async version).""" - pass - - # Abstract methods - sync versions (with default implementations that call async versions) - def encode_sync(self, data: T1) -> Any: - """Encode data before writing to the underlying pipe (sync version).""" - # Default implementation for backwards compatibility - # Subclasses should override this for true sync operation - if asyncio.iscoroutinefunction(self.encode): - raise NotImplementedError("Sync encoding not implemented for this transformer") - return asyncio.run(self.encode(data)) - - def decode_sync(self, data: Any) -> T2: - """Decode data received from the underlying pipe (sync version).""" - # Default implementation for backwards compatibility - # Subclasses should override this for true sync operation - if asyncio.iscoroutinefunction(self.decode): - raise NotImplementedError("Sync decoding not implemented for this transformer") - return asyncio.run(self.decode(data)) - - @staticmethod - def create_pair(encoder1: Callable, decoder1: Callable, - encoder2: Callable, decoder2: Callable, - use_sync: bool = False) -> Tuple['Transformer', 'Transformer']: - """ - Create a pair of connected transformer pipes. - - Args: - encoder1: Encoder function for the first transformer - decoder1: Decoder function for the first transformer - encoder2: Encoder function for the second transformer - decoder2: Decoder function for the second transformer - use_sync: If True, use SyncPipe; if False, use async Pipe - - Returns: - A tuple of (transformer1, transformer2) - """ - if use_sync: - pipe1, pipe2 = SyncPipe.create_pair() - else: - pipe1, pipe2 = Pipe.create_pair() - - class ConcreteTransformer1(Transformer): - async def encode(self, data): - if asyncio.iscoroutinefunction(encoder1): - return await encoder1(data) - return encoder1(data) - - async def decode(self, data): - if asyncio.iscoroutinefunction(decoder1): - return await decoder1(data) - return decoder1(data) - - def encode_sync(self, data): - if asyncio.iscoroutinefunction(encoder1): - raise NotImplementedError("Encoder is async, cannot use sync method") - return encoder1(data) - - def decode_sync(self, data): - if asyncio.iscoroutinefunction(decoder1): - raise NotImplementedError("Decoder is async, cannot use sync method") - return decoder1(data) - - class ConcreteTransformer2(Transformer): - async def encode(self, data): - if asyncio.iscoroutinefunction(encoder2): - return await encoder2(data) - return encoder2(data) - - async def decode(self, data): - if asyncio.iscoroutinefunction(decoder2): - return await decoder2(data) - return decoder2(data) - - def encode_sync(self, data): - if asyncio.iscoroutinefunction(encoder2): - raise NotImplementedError("Encoder is async, cannot use sync method") - return encoder2(data) - - def decode_sync(self, data): - if asyncio.iscoroutinefunction(decoder2): - raise NotImplementedError("Decoder is async, cannot use sync method") - return decoder2(data) - - transformer1 = ConcreteTransformer1(pipe1) - transformer2 = ConcreteTransformer2(pipe2) - - return transformer1, transformer2 diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..242f464 --- /dev/null +++ b/conftest.py @@ -0,0 +1,5 @@ +import sys +import os + +# Ensure the project root is in sys.path +sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) diff --git a/examples/rpc_backends_example.py b/examples/fs_rpc/__init__.py similarity index 100% rename from examples/rpc_backends_example.py rename to examples/fs_rpc/__init__.py diff --git a/examples/fs_rpc/filesystem_client.py b/examples/fs_rpc/filesystem_client.py new file mode 100644 index 0000000..4e717fc --- /dev/null +++ b/examples/fs_rpc/filesystem_client.py @@ -0,0 +1,53 @@ +import asyncio +from cbor_rpc import RpcV1 +from cbor_rpc.tcp import TcpPipe +from cbor_rpc.transformer.json_transformer import JsonTransformer +from cbor_rpc.pipe.event_pipe import ( + EventPipe, +) # Keep this import for clarity, though not directly instantiated + + +async def main(): + # Connect to the RPC server + tcp_pipe = await TcpPipe.create_connection("localhost", 8000) # Use port 8002 + + # Create a JSON transformer and attach the TCP pipe to it + json_transformer = JsonTransformer(tcp_pipe) + + # The RpcV1 client needs a pipe that it can write to and read from. + # The json_transformer now handles both incoming and outgoing data. + rpc_client = RpcV1.read_only_client(json_transformer) + + # Example usage of filesystem RPC methods + + # List files in current directory + files = await rpc_client.call_method("list_files", ".") + print("Files in current directory:", files) + + # Create a test file + create_success = await rpc_client.call_method("create_file", "test.txt") + print("File creation successful:", create_success) + + # Read the test file (should be empty) + content = await rpc_client.call_method("read_file", "test.txt") + print("File content:", content.decode()) + + # Write to the test file + write_success = await rpc_client.call_method("create_file", "test.txt", b"Hello, world!") + print("Write successful:", write_success) + + # Read the updated file + content = await rpc_client.call_method("read_file", "test.txt") + print("Updated file content:", content.decode()) + + # Rename the file + rename_success = await rpc_client.call_method("rename_file", "test.txt", "renamed_test.txt") + print("Rename successful:", rename_success) + + # Delete the renamed file + delete_success = await rpc_client.call_method("delete_file", "renamed_test.txt") + print("Delete successful:", delete_success) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/fs_rpc/filesystem_server.py b/examples/fs_rpc/filesystem_server.py new file mode 100644 index 0000000..0b458f6 --- /dev/null +++ b/examples/fs_rpc/filesystem_server.py @@ -0,0 +1,108 @@ +import os +from typing import List, Optional, Any +from cbor_rpc.rpc.context import RpcCallContext +from cbor_rpc import RpcV1Server + + +class FilesystemRpcServer(RpcV1Server): + async def validate_event_broadcast(self, connection_id, topic, message): + return False + + async def handle_method_call( + self, + connection_id: str, + context: RpcCallContext, + method: str, + args: List[Any], + ) -> Any: + if method == "list_files": + return self.list_files(*args) + elif method == "read_file": + return self.read_file(*args) + elif method == "create_file": + return self.create_file(*args) + elif method == "delete_file": + return self.delete_file(*args) + elif method == "rename_file": + return self.rename_file(*args) + else: + raise Exception(f"Unknown method: {method}") + + def list_files(self, directory: str) -> List[str]: + """Lists files and directories in the given path.""" + try: + return os.listdir(directory) + except Exception as e: + return f"Error listing files: {str(e)}" + + def read_file(self, path: str, chunk_size: int = 4096, offset: int = 0) -> bytes: + """Reads a file in chunks.""" + try: + with open(path, "rb") as f: + f.seek(offset) + return f.read(chunk_size) + except Exception as e: + return f"Error reading file: {str(e)}".encode() + + def create_file(self, path: str, content: Optional[bytes] = None) -> bool: + """Creates a file with optional initial content.""" + try: + with open(path, "wb") as f: + if content: + f.write(content) + return True + except Exception as e: + print(f"Error creating file: {str(e)}") + return False + + def delete_file(self, path: str) -> bool: + """Deletes a file.""" + try: + os.remove(path) + return True + except Exception as e: + print(f"Error deleting file: {str(e)}") + return False + + def rename_file(self, src: str, dest: str) -> bool: + """Renames/moves a file.""" + try: + os.rename(src, dest) + return True + except Exception as e: + print(f"Error renaming file: {str(e)}") + return False + + +if __name__ == "__main__": + import asyncio + from cbor_rpc.tcp import TcpPipe, TcpServer + from cbor_rpc.transformer.json_transformer import JsonTransformer + + class SimpleTcpServer(TcpServer): + async def accept(self, pipe: TcpPipe) -> bool: + return True + + async def main(): + rpc_id = 1 + # Create a TCP server that handles connections, using JsonTransformer for RPC messages + tcp_server = await SimpleTcpServer.create("localhost", 8000) + print("Server running on port 8000") + + # Set up event handlers for new connections + async def handle_connection(pipe): + json_transformer = JsonTransformer() + rpc_pipe = json_transformer.apply_transformer(pipe) + server = FilesystemRpcServer() + await server.add_connection(str(rpc_id), rpc_pipe) + + tcp_server.on_connection(handle_connection) + + # Just run until manually stopped + try: + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + await tcp_server.stop() + + asyncio.run(main()) diff --git a/examples/json_transformer_example.py b/examples/json_transformer_example.py deleted file mode 100644 index 10ca713..0000000 --- a/examples/json_transformer_example.py +++ /dev/null @@ -1,229 +0,0 @@ -""" -Example demonstrating JsonTransformer usage with different server types. -""" - -import asyncio -import json -from typing import Any, List -from cbor_rpc import ( - Server, TcpServer, TcpPipe, JsonTransformer, - RpcV1, RpcV1Server, Pipe -) - - -class JsonRpcServer(RpcV1Server): - """RPC server that works with JSON-transformed data.""" - - def __init__(self): - super().__init__() - self._message_log = [] - - async def handle_method_call(self, connection_id: str, method: str, args: List[Any]) -> Any: - """Handle RPC method calls.""" - self._message_log.append(f"JSON RPC call: {method} from {connection_id}") - - if method == "echo": - return {"echo": args[0] if args else None, "timestamp": "2024-01-01"} - elif method == "add_numbers": - return {"result": sum(args), "operation": "addition"} - elif method == "get_log": - return {"log": self._message_log.copy()} - elif method == "process_data": - # Process complex JSON data - data = args[0] if args else {} - return { - "processed": True, - "input_keys": list(data.keys()) if isinstance(data, dict) else [], - "input_type": type(data).__name__ - } - else: - raise Exception(f"Unknown method: {method}") - - async def validate_event_broadcast(self, connection_id: str, topic: str, message: Any) -> bool: - """Validate event broadcasts.""" - self._message_log.append(f"JSON event: {topic} from {connection_id}") - return True - - -async def demonstrate_json_over_tcp(): - """Demonstrate JSON RPC over TCP with JsonTransformer.""" - print("=== JSON RPC over TCP Example ===") - - # Create TCP server - tcp_server: Server[TcpPipe] = await TcpServer.create('127.0.0.1', 0) - rpc_server = JsonRpcServer() - - async def on_tcp_connection(tcp_pipe: TcpPipe): - # Wrap TCP pipe with JSON transformer - json_pipe = JsonTransformer(tcp_pipe) - - # Create connection ID - peer_info = tcp_pipe.get_peer_info() - conn_id = f"json-tcp-{peer_info[0]}:{peer_info[1]}" if peer_info else "unknown" - - # Add JSON-transformed connection to RPC server - await rpc_server.add_connection(conn_id, json_pipe) - print(f"New JSON RPC client connected: {conn_id}") - - tcp_server.on_connection(on_tcp_connection) - - try: - # Get server address - host, port = tcp_server.get_address() - print(f"JSON RPC server listening on {host}:{port}") - - # Create client with JSON transformer - tcp_client_pipe = await TcpPipe.create_connection(host, port) - json_client_pipe = JsonTransformer(tcp_client_pipe) - - # Create RPC client - rpc_client = RpcV1.make_rpc_v1( - json_client_pipe, - "json-client", - lambda m, a: {"client_response": f"Handled {m}"}, - lambda t, m: print(f"Client received event: {t} -> {m}") - ) - - await asyncio.sleep(0.1) # Let connection establish - - # Test JSON RPC calls - print("\n--- Testing JSON RPC calls ---") - - # Simple echo - result = await rpc_client.call_method("echo", "Hello JSON World!") - print(f"Echo result: {json.dumps(result, indent=2)}") - - # Math operation - result = await rpc_client.call_method("add_numbers", 10, 20, 30) - print(f"Add result: {json.dumps(result, indent=2)}") - - # Complex data processing - complex_data = { - "users": [ - {"name": "Alice", "age": 30}, - {"name": "Bob", "age": 25} - ], - "metadata": { - "version": "1.0", - "timestamp": "2024-01-01T00:00:00Z" - } - } - result = await rpc_client.call_method("process_data", complex_data) - print(f"Process data result: {json.dumps(result, indent=2)}") - - # Get server log - result = await rpc_client.call_method("get_log") - print(f"Server log: {json.dumps(result, indent=2)}") - - # Test event emission - await rpc_client.emit("user_action", { - "action": "login", - "user": "alice", - "timestamp": "2024-01-01T12:00:00Z" - }) - print("Emitted JSON event") - - await tcp_client_pipe.terminate() - - finally: - await tcp_server.stop() - print("JSON RPC server stopped") - - -async def demonstrate_json_transformer_pair(): - """Demonstrate direct JsonTransformer pair communication.""" - print("\n=== Direct JSON Transformer Pair Example ===") - - # Create JSON transformer pair - json_transformer1, json_transformer2 = JsonTransformer.create_pair() - - # Set up data handlers - transformer1_received = [] - transformer2_received = [] - - json_transformer1.on("data", lambda data: transformer1_received.append(data)) - json_transformer2.on("data", lambda data: transformer2_received.append(data)) - - # Test bidirectional JSON communication - test_data_1 = { - "message": "Hello from transformer 1", - "data": [1, 2, 3, {"nested": True}] - } - - test_data_2 = { - "response": "Hello from transformer 2", - "status": "success", - "metadata": {"processed_at": "2024-01-01"} - } - - await json_transformer1.write(test_data_1) - await json_transformer2.write(test_data_2) - await asyncio.sleep(0.1) - - print("Transformer 1 sent:", json.dumps(test_data_1, indent=2)) - print("Transformer 2 received:", json.dumps(transformer2_received[0], indent=2)) - print() - print("Transformer 2 sent:", json.dumps(test_data_2, indent=2)) - print("Transformer 1 received:", json.dumps(transformer1_received[0], indent=2)) - - # Verify data integrity - assert transformer2_received[0] == test_data_1 - assert transformer1_received[0] == test_data_2 - print("\nβœ… JSON data integrity verified!") - - -async def demonstrate_json_error_handling(): - """Demonstrate JSON transformer error handling.""" - print("\n=== JSON Error Handling Example ===") - - pipe1, pipe2 = Pipe.create_pair() - json_transformer = JsonTransformer(pipe1) - - errors = [] - json_transformer.on("error", lambda err: errors.append(str(err))) - - # Test encoding error - class NonSerializable: - def __repr__(self): - return "NonSerializable()" - - print("Testing encoding error...") - result = await json_transformer.write(NonSerializable()) - print(f"Write result: {result}") - await asyncio.sleep(0.01) - - if errors: - print(f"Encoding error caught: {errors[-1]}") - - # Test decoding error - print("\nTesting decoding error...") - await pipe2.write(b'{"invalid": json syntax}') - await asyncio.sleep(0.01) - - if len(errors) > 1: - print(f"Decoding error caught: {errors[-1]}") - - # Test recovery - print("\nTesting error recovery...") - valid_data = {"message": "Recovery successful"} - await pipe2.write(json.dumps(valid_data).encode('utf-8')) - - received_data = [] - json_transformer.on("data", lambda data: received_data.append(data)) - await asyncio.sleep(0.01) - - if received_data: - print(f"Recovery successful: {json.dumps(received_data[0], indent=2)}") - - print(f"\nTotal errors handled: {len(errors)}") - - -async def main(): - """Run all JSON transformer demonstrations.""" - await demonstrate_json_over_tcp() - await demonstrate_json_transformer_pair() - await demonstrate_json_error_handling() - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/examples/tcp_rpc_example.py b/examples/tcp_rpc_example.py deleted file mode 100644 index 216f489..0000000 --- a/examples/tcp_rpc_example.py +++ /dev/null @@ -1,168 +0,0 @@ -""" -Example demonstrating how to use TcpPipe with CBOR-RPC. -This example shows both client and server implementations using TCP transport. -""" - -import asyncio -import json -from typing import Any, List -from cbor_rpc import TcpPipe, TcpServer, RpcV1, RpcV1Server - - -class TcpRpcServer(RpcV1Server): - """An RPC server that accepts TCP connections.""" - - def __init__(self, host: str = '127.0.0.1', port: int = 0): - super().__init__() - self.host = host - self.port = port - self.tcp_server: TcpServer = None - - async def start(self): - """Start the TCP server and begin accepting connections.""" - self.tcp_server = await TcpServer.create(self.host, self.port) - - async def on_connection(tcp_duplex: TcpPipe): - # Create a unique connection ID - peer_info = tcp_duplex.get_peer_info() - conn_id = f"{peer_info[0]}:{peer_info[1]}" if peer_info else "unknown" - - # Add the TCP connection as an RPC client - await self.add_connection(conn_id, tcp_duplex) - print(f"New RPC client connected: {conn_id}") - - self.tcp_server.on_connection(on_connection) - - actual_host, actual_port = self.tcp_server.get_address() - print(f"RPC server listening on {actual_host}:{actual_port}") - return actual_host, actual_port - - async def stop(self): - """Stop the TCP server.""" - if self.tcp_server: - await self.tcp_server.close() - - async def handle_method_call(self, connection_id: str, method: str, args: List[Any]) -> Any: - """Handle RPC method calls.""" - print(f"Method call from {connection_id}: {method}({args})") - - if method == "echo": - return args[0] if args else None - elif method == "add": - return sum(args) if args else 0 - elif method == "get_server_info": - return { - "server": "TcpRpcServer", - "connections": len(self.active_connections), - "address": self.tcp_server.get_address() if self.tcp_server else None - } - else: - raise Exception(f"Unknown method: {method}") - - async def validate_event_broadcast(self, connection_id: str, topic: str, message: Any) -> bool: - """Validate whether an event should be broadcasted.""" - # Allow all events for this example - return True - - -async def run_server(): - """Run the RPC server example.""" - server = TcpRpcServer() - - try: - host, port = await server.start() - print(f"Server started on {host}:{port}") - - # Keep the server running - while True: - await asyncio.sleep(1) - - except KeyboardInterrupt: - print("Shutting down server...") - finally: - await server.stop() - - -async def run_client(host: str, port: int): - """Run the RPC client example.""" - try: - # Create TCP connection - tcp_duplex = await TcpPipe.create_connection(host, port) - print(f"Connected to server at {host}:{port}") - - # Create RPC client - def method_handler(method: str, args: List[Any]) -> Any: - print(f"Server called method: {method}({args})") - return f"Client response to {method}" - - async def event_handler(topic: str, message: Any) -> None: - print(f"Received event: {topic} -> {message}") - - rpc_client = RpcV1.make_rpc_v1(tcp_duplex, "client", method_handler, event_handler) - - # Test method calls - print("\n--- Testing RPC method calls ---") - - result = await rpc_client.call_method("echo", "Hello, Server!") - print(f"Echo result: {result}") - - result = await rpc_client.call_method("add", 1, 2, 3, 4, 5) - print(f"Add result: {result}") - - result = await rpc_client.call_method("get_server_info") - print(f"Server info: {json.dumps(result, indent=2)}") - - # Test fire method (no response expected) - await rpc_client.fire_method("echo", "Fire and forget message") - print("Fired method (no response)") - - # Test event emission - await rpc_client.emit("client_event", {"message": "Hello from client", "timestamp": "2024-01-01"}) - print("Emitted event") - - # Keep connection alive for a bit - await asyncio.sleep(2) - - except Exception as e: - print(f"Client error: {e}") - finally: - if 'tcp_duplex' in locals(): - await tcp_duplex.terminate() - print("Client disconnected") - - -async def main(): - """Main example function.""" - import sys - - if len(sys.argv) > 1 and sys.argv[1] == "server": - await run_server() - elif len(sys.argv) > 3 and sys.argv[1] == "client": - host = sys.argv[2] - port = int(sys.argv[3]) - await run_client(host, port) - else: - print("Usage:") - print(" python tcp_rpc_example.py server") - print(" python tcp_rpc_example.py client ") - print("\nRunning integrated example...") - - # Run integrated example - server = TcpRpcServer() - - try: - # Start server - host, port = await server.start() - - # Give server time to start - await asyncio.sleep(0.1) - - # Run client - await run_client(host, port) - - finally: - await server.stop() - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index 1a422e5..50fc97d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,12 @@ readme = "README.md" requires-python = ">=3.8" dependencies = [ "pytest>=8.3.2", - "pytest-asyncio>=0.24.0" + "pytest-asyncio>=0.24.0", + "pytest-cov>=5.0.0", + "asyncssh>=2.14.0", + "bcrypt", + "cbor2", + "docker" ] [build-system] @@ -24,3 +29,14 @@ addopts = "-v" testpaths = ["tests"] python_files = ["test_*.py"] pythonpath = ["cbor_rpc"] + +[tool.black] +line-length = 120 + +[tool.coverage.run] +branch = true +source = ["cbor_rpc"] + +[tool.coverage.report] +show_missing = true +skip_covered = true diff --git a/requirements.txt b/requirements.txt index 65b7223..5c32029 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,20 +1,63 @@ # -# This file is autogenerated by pip-compile with Python 3.11 +# This file is autogenerated by pip-compile with Python 3.10 # by the following command: # -# pip-compile +# pip install pip-tools +# pip-compile --output-file=requirements.txt pyproject.toml # +asyncssh==2.22.0 + # via cbor-rpc (pyproject.toml) +bcrypt==5.0.0 + # via cbor-rpc (pyproject.toml) +cbor2==5.8.0 + # via cbor-rpc (pyproject.toml) +certifi==2026.1.4 + # via requests +cffi==2.0.0 + # via cryptography +charset-normalizer==3.4.4 + # via requests +coverage[toml]==7.6.4 + # via pytest-cov +cryptography==46.0.4 + # via asyncssh +docker==7.1.0 + # via cbor-rpc (pyproject.toml) +exceptiongroup==1.3.1 + # via pytest +idna==3.11 + # via requests iniconfig==2.1.0 # via pytest packaging==25.0 # via pytest pluggy==1.6.0 # via pytest +pycparser==3.0 + # via cffi pygments==2.19.1 # via pytest pytest==8.4.0 # via # cbor-rpc (pyproject.toml) # pytest-asyncio + # pytest-cov pytest-asyncio==1.0.0 # via cbor-rpc (pyproject.toml) +pytest-cov==5.0.0 + # via cbor-rpc (pyproject.toml) +requests==2.32.5 + # via docker +tomli==2.4.0 + # via + # coverage + # pytest +typing-extensions==4.15.0 + # via + # asyncssh + # cryptography + # exceptiongroup +urllib3==2.6.3 + # via + # docker + # requests diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..29a6b5c --- /dev/null +++ b/setup.py @@ -0,0 +1,23 @@ +from setuptools import setup, find_packages + +setup( + name="cbor-rpc", + version="0.1.0", + description="An async-compatible CBOR-based RPC system", + author="Sudip Bhattarai", + author_email="sudip@bhattarai.me", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + url="https://github.com/mesudip/cbor-rpc-py ", + packages=find_packages(include=["cbor_rpc", "cbor_rpc.*"]), + install_requires=["asyncssh>=2.14.0", "bcrypt", "cbor2"], + extras_require={ + "test": ["pytest>=8.3.2", "pytest-asyncio>=0.24.0", "pytest-cov>=5.0.0", "docker"], + }, + python_requires=">=3.8", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + ], +) diff --git a/tests/docker/sshd-python/Dockerfile b/tests/docker/sshd-python/Dockerfile new file mode 100644 index 0000000..28aed83 --- /dev/null +++ b/tests/docker/sshd-python/Dockerfile @@ -0,0 +1,13 @@ +# Use the linuxserver/openssh-server as base image +FROM linuxserver/openssh-server + +# Install python3 using apk +RUN apk add --no-cache python3 + +# Copy the binary_emitter.py script into the container +COPY binary_emitter.py /usr/local/bin/binary_emitter.py +RUN chmod +x /usr/local/bin/binary_emitter.py + +# Copy the echo_back.py script into the container +COPY echo_back.py /usr/local/bin/echo_back.py +RUN chmod +x /usr/local/bin/echo_back.py diff --git a/tests/docker/sshd-python/binary_emitter.py b/tests/docker/sshd-python/binary_emitter.py new file mode 100644 index 0000000..9b987c8 --- /dev/null +++ b/tests/docker/sshd-python/binary_emitter.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python3 +import os +import time +import binascii + +hex_data = "DEADBEEF0001020380FF7F" # Removed \x0A\x0D (LF and CR) +test_data = binascii.unhexlify(hex_data) + +while True: + os.write(1, test_data) + time.sleep(0.1) diff --git a/tests/docker/sshd-python/echo_back.py b/tests/docker/sshd-python/echo_back.py new file mode 100644 index 0000000..c47446a --- /dev/null +++ b/tests/docker/sshd-python/echo_back.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 +import sys +import os + +# Read from stdin and write to stdout +# Ensure binary mode for consistent byte handling +stdin_fd = sys.stdin.buffer.fileno() +stdout_fd = sys.stdout.buffer.fileno() + +while True: + try: + # Read a chunk of data from stdin + data = os.read(stdin_fd, 4096) # Read up to 4096 bytes + if not data: + # EOF reached on stdin + break + # Write the data to stdout + os.write(stdout_fd, data) + # Explicitly flush stdout to ensure data is sent immediately + sys.stdout.buffer.flush() + except BrokenPipeError: + # stdout pipe was closed by the reader + break + except Exception as e: + # Log any other errors to stderr + sys.stderr.write(f"Error in echo_back.py: {e}\n".encode("utf-8")) + sys.stderr.flush() + break diff --git a/tests/event/test_emitter.py b/tests/event/test_emitter.py new file mode 100644 index 0000000..9fdc63f --- /dev/null +++ b/tests/event/test_emitter.py @@ -0,0 +1,306 @@ +import asyncio +import time +from typing import Any + +import pytest + +from cbor_rpc.event.emitter import AbstractEmitter + + +@pytest.mark.asyncio +async def test_on_and_emit(): + class TestEmitter(AbstractEmitter): + pass + + emitter = TestEmitter() + events = [] + + async def async_handler1(data: Any): + await asyncio.sleep(0.01) + events.append(f"async_handler1_{data}") + + def sync_handler1(data: Any): + events.append(f"sync_handler1_{data}") + + async def async_handler2(data: Any): + await asyncio.sleep(0.01) + events.append(f"async_handler2_{data}") + + emitter.on("test", async_handler1) + emitter.on("test", sync_handler1) + emitter.on("test", async_handler2) + + emitter._emit("test", "event1") + await asyncio.sleep(0.02) + + expected = [ + "async_handler1_event1", + "sync_handler1_event1", + "async_handler2_event1", + ] + assert sorted(events) == sorted(expected), f"Expected {expected}, got {events}" + + +@pytest.mark.asyncio +async def test_pipeline_and_notify(): + class TestEmitter(AbstractEmitter): + pass + + emitter = TestEmitter() + events = [] + + async def async_handler1(data: Any): + await asyncio.sleep(0.01) + events.append(f"async_handler1_{data}") + + def sync_handler1(data: Any): + time.sleep(0.01) + events.append(f"sync_handler1_{data}") + + async def async_pipeline1(data: Any): + await asyncio.sleep(0.01) + events.append(f"async_pipeline1_{data}") + + async def async_pipeline2(data: Any): + # Slow pipeline + await asyncio.sleep(0.05) + events.append(f"async_pipeline2_{data}") + + async def async_pipeline3(data: Any): + # Fast pipeline + await asyncio.sleep(0.01) + events.append(f"async_pipeline3_{data}") + + emitter.on("test", async_handler1) + emitter.on("test", sync_handler1) + emitter.pipeline("test", async_pipeline1) + emitter.pipeline("test", async_pipeline2) + emitter.pipeline("test", async_pipeline3) + + await emitter._notify("test", "event2") + await asyncio.sleep(0.02) + + # Strictly sequential execution order: + # 1. async_pipeline1 (waits 0.01) + # 2. async_pipeline2 (waits 0.05) -> If concurrent, this would finish LAST + # 3. async_pipeline3 (waits 0.01) -> If concurrent, this would finish BEFORE pipeline2 + + expected_pipelines = [ + "async_pipeline1_event2", + "async_pipeline2_event2", + "async_pipeline3_event2", + ] + + # Verify pipelines ran in strict order + assert events[:3] == expected_pipelines, f"Pipelines expected {expected_pipelines}, got {events[:3]}" + + expected_subscribers = ["async_handler1_event2", "sync_handler1_event2"] + + # Verify handlers ran after pipelines + assert set(events[3:]) == set(expected_subscribers), f"Subscribers expected {expected_subscribers}, got {events[3:]}" + + # Explicitly assert precedence + for p in expected_pipelines: + for s in expected_subscribers: + assert events.index(p) < events.index(s), f"Pipeline {p} executed after subscriber {s}" + + +@pytest.mark.asyncio +async def test_unsubscribe(): + class TestEmitter(AbstractEmitter): + pass + + emitter = TestEmitter() + events = [] + + async def async_handler1(data: Any): + await asyncio.sleep(0.01) + events.append(f"async_handler1_{data}") + + def sync_handler1(data: Any): + events.append(f"sync_handler1_{data}") + + emitter.on("test", async_handler1) + emitter.on("test", sync_handler1) + emitter.unsubscribe("test", async_handler1) + + emitter._emit("test", "event3") + await asyncio.sleep(0.02) + + expected = ["sync_handler1_event3"] + assert events == expected, f"Expected {expected}, got {events}" + + +@pytest.mark.asyncio +async def test_replace_on_handler(): + class TestEmitter(AbstractEmitter): + pass + + emitter = TestEmitter() + events = [] + + async def async_handler1(data: Any): + await asyncio.sleep(0.01) + events.append(f"async_handler1_{data}") + + def sync_handler1(data: Any): + events.append(f"sync_handler1_{data}") + + emitter.on("test", sync_handler1) + emitter.replace_on_handler("test", async_handler1) + + emitter._emit("test", "event4") + await asyncio.sleep(0.02) + + expected = ["async_handler1_event4"] + assert events == expected, f"Expected {expected}, got {events}" + + +@pytest.mark.asyncio +async def test_pipeline_failure(): + class TestEmitter(AbstractEmitter): + pass + + emitter = TestEmitter() + events = [] + + async def async_pipeline1(data: Any): + await asyncio.sleep(0.01) + events.append(f"async_pipeline1_{data}") + raise ValueError("Pipeline failed") + + async def async_pipeline2(data: Any): + events.append(f"async_pipeline2_{data}") + + def sync_handler1(data: Any): + events.append(f"sync_handler1_{data}") + + emitter.pipeline("test", async_pipeline1) + emitter.pipeline("test", async_pipeline2) + emitter.on("test", sync_handler1) + + with pytest.raises(ValueError): + await emitter._notify("test", "event5") + + expected = ["async_pipeline1_event5"] + assert events == expected, f"Expected {expected}, got {events}" + + # Explicitly verify skipped execution + assert "async_pipeline2_event5" not in events, "Subsequent pipeline should be skipped" + assert "sync_handler1_event5" not in events, "Event subscribers should be skipped" + + +@pytest.mark.asyncio +async def test_multiple_event_types(): + class TestEmitter(AbstractEmitter): + pass + + emitter = TestEmitter() + events = [] + + async def async_handler_a(data: Any): + await asyncio.sleep(0.01) + events.append(f"async_handler_a_{data}") + + def sync_handler_a(data: Any): + events.append(f"sync_handler_a_{data}") + + async def async_handler_b(data: Any): + await asyncio.sleep(0.01) + events.append(f"async_handler_b_{data}") + + def sync_handler_b(data: Any): + events.append(f"sync_handler_b_{data}") + + async def async_pipeline_a(data: Any): + await asyncio.sleep(0.01) + events.append(f"async_pipeline_a_{data}") + + def sync_pipeline_b(data: Any): + events.append(f"sync_pipeline_b_{data}") + + emitter.on("event_a", async_handler_a) + emitter.on("event_a", sync_handler_a) + emitter.pipeline("event_a", async_pipeline_a) + emitter.on("event_b", async_handler_b) + emitter.on("event_b", sync_handler_b) + emitter.pipeline("event_b", sync_pipeline_b) + + events.clear() + emitter._emit("event_a", "data_a") + await asyncio.sleep(0.02) + expected = ["async_handler_a_data_a", "sync_handler_a_data_a"] + assert sorted(events) == sorted(expected), f"Expected {expected}, got {events}" + + events.clear() + await emitter._notify("event_a", "data_a2") + await asyncio.sleep(0.02) + expected_pipelines = ["async_pipeline_a_data_a2"] + expected_subscribers = ["async_handler_a_data_a2", "sync_handler_a_data_a2"] + pipeline_indices = [events.index(e) for e in expected_pipelines if e in events] + subscriber_indices = [events.index(e) for e in expected_subscribers if e in events] + assert all( + p < s for p in pipeline_indices for s in subscriber_indices + ), f"Pipelines {expected_pipelines} should precede subscribers {expected_subscribers} in {events}" + assert sorted(events) == sorted( + expected_pipelines + expected_subscribers + ), f"Expected {expected_pipelines + expected_subscribers}, got {events}" + + +class DummyEmitter(AbstractEmitter): + pass + + +def test_emitter_no_running_loop_warning(): + emitter = DummyEmitter() + with pytest.warns(RuntimeWarning): + emitter._run_background_task(lambda: None) + + +def test_emitter_emit_sync_error_warning(): + emitter = DummyEmitter() + + def bad_handler(_data: Any) -> None: + raise ValueError("boom") + + emitter.on("evt", bad_handler) + with pytest.warns(RuntimeWarning): + emitter._emit("evt", "data") + + +@pytest.mark.asyncio +async def test_emitter_notify_pipeline_errors(): + emitter = DummyEmitter() + errors = [] + + def on_error(err: Exception) -> None: + errors.append(err) + + def bad_pipeline(_data: Any) -> None: + raise ValueError("boom") + + emitter.on("error", on_error) + emitter.pipeline("evt", bad_pipeline) + + with pytest.raises(ValueError): + await emitter._notify("evt", "data") + assert errors + + +@pytest.mark.asyncio +async def test_emitter_notify_async_pipeline_error(): + emitter = DummyEmitter() + errors = [] + + async def bad_pipeline(_data: Any) -> None: + raise ValueError("boom") + + def on_error(err: Exception) -> None: + errors.append(err) + + emitter.on("error", on_error) + emitter.pipeline("evt", bad_pipeline) + + with pytest.raises(ValueError): + await emitter._notify("evt", "data") + assert errors diff --git a/tests/helpers/simple_pipe.py b/tests/helpers/simple_pipe.py index a4b4a2c..eca1647 100644 --- a/tests/helpers/simple_pipe.py +++ b/tests/helpers/simple_pipe.py @@ -1,11 +1,11 @@ from typing import Any, Generic, TypeVar -from cbor_rpc import Pipe, RpcV1, DeferredPromise +from cbor_rpc import EventPipe, RpcV1, TimedPromise # Generic type variables -T1 = TypeVar('T1') +T1 = TypeVar("T1") -class SimplePipe(Pipe[T1, T1], Generic[T1]): +class SimplePipe(EventPipe[T1, T1], Generic[T1]): def __init__(self): super().__init__() self._closed = False @@ -13,11 +13,11 @@ def __init__(self): async def write(self, chunk: T1) -> bool: if self._closed: return False - await self._emit("data", chunk) + await self._notify("data", chunk) return True async def terminate(self, *args: Any) -> None: if self._closed: return self._closed = True - await self._emit("close", *args) + self._emit("close", *args) diff --git a/tests/helpers/simple_tcp_server.py b/tests/helpers/simple_tcp_server.py new file mode 100644 index 0000000..233064f --- /dev/null +++ b/tests/helpers/simple_tcp_server.py @@ -0,0 +1,13 @@ +from cbor_rpc.tcp.tcp import TcpServer, TcpPipe + + +class SimpleTcpServer(TcpServer): + """ + A simple TCP server implementation for testing purposes that accepts all connections. + """ + + async def accept(self, pipe: TcpPipe) -> bool: + """ + Accepts all incoming TCP connections. + """ + return True diff --git a/tests/helpers/stdio_test_script.py b/tests/helpers/stdio_test_script.py new file mode 100644 index 0000000..42a9f33 --- /dev/null +++ b/tests/helpers/stdio_test_script.py @@ -0,0 +1,21 @@ +import sys + + +def main(): + print("stdio_test_script: Started.", file=sys.stderr) + + while True: + print("stdio_test_script: Blocked on read", file=sys.stderr) + data = sys.stdin.read(1024) # Read up to 1024 bytes + print(f"stdio_test_script: Read {len(data)} bytes from stdin.", file=sys.stderr) + sys.stderr.flush() + if not data: # EOF reached + print("stdio_test_script: EOF reached, exiting.", file=sys.stderr) + sys.stderr.flush() + break + sys.stdout.write(data) + sys.stdout.flush() + + +if __name__ == "__main__": + main() diff --git a/tests/helpers/stream_pair.py b/tests/helpers/stream_pair.py new file mode 100644 index 0000000..7e3a7fc --- /dev/null +++ b/tests/helpers/stream_pair.py @@ -0,0 +1,12 @@ +import asyncio +from typing import Tuple + + +async def create_stream_pair() -> Tuple[asyncio.AbstractServer, asyncio.StreamReader, asyncio.StreamWriter]: + async def handler(_reader: asyncio.StreamReader, _writer: asyncio.StreamWriter) -> None: + await asyncio.sleep(0.2) + + server = await asyncio.start_server(handler, "127.0.0.1", 0) + host, port = server.sockets[0].getsockname()[:2] + reader, writer = await asyncio.open_connection(host, port) + return server, reader, writer diff --git a/tests/helpers/timeout_queue.py b/tests/helpers/timeout_queue.py new file mode 100644 index 0000000..7c35e44 --- /dev/null +++ b/tests/helpers/timeout_queue.py @@ -0,0 +1,13 @@ +import asyncio +from typing import Any, Optional + + +class TimeoutQueue(asyncio.Queue): + def __init__(self, *args: Any, default_timeout: float = 1.0, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._default_timeout = default_timeout + + async def get(self, timeout: Optional[float] = None) -> Any: + if timeout is None: + timeout = self._default_timeout + return await asyncio.wait_for(super().get(), timeout=timeout) diff --git a/tests/misc/test_timed_promise.py b/tests/misc/test_timed_promise.py new file mode 100644 index 0000000..c3cdbf8 --- /dev/null +++ b/tests/misc/test_timed_promise.py @@ -0,0 +1,26 @@ +import pytest + +from cbor_rpc.timed_promise import TimedPromise + + +@pytest.mark.asyncio +async def test_timed_promise_resolve_reject_and_timeout(): + resolved = TimedPromise(10) + await resolved.resolve("ok") + assert await resolved.promise == "ok" + + rejected = TimedPromise(10) + await rejected.reject("bad") + with pytest.raises(Exception): + await rejected.promise + + timeout_called = [] + + def on_timeout(): + timeout_called.append(True) + + timed = TimedPromise(1, on_timeout) + with pytest.raises(Exception) as exc_info: + await timed.promise + assert exc_info.value.args[0]["timeout"] is True + assert timeout_called == [True] diff --git a/tests/pipe/test_aio_pipe.py b/tests/pipe/test_aio_pipe.py new file mode 100644 index 0000000..42ed86f --- /dev/null +++ b/tests/pipe/test_aio_pipe.py @@ -0,0 +1,150 @@ +import asyncio +from typing import Any, List, Optional + +import pytest + +from cbor_rpc.pipe.aio_pipe import AioPipe + + +class FakeReader: + def __init__(self, responses: List[Any]): + self._responses = list(responses) + + async def read(self, _size: int) -> bytes: + if not self._responses: + return b"" + item = self._responses.pop(0) + if isinstance(item, Exception): + raise item + return item + + +class FakeWriter: + def __init__(self, drain_error: Optional[Exception] = None, close_error: Optional[Exception] = None): + self._drain_error = drain_error + self._close_error = close_error + self._closed = False + + def write(self, _chunk: bytes) -> None: + return None + + async def drain(self) -> None: + if self._drain_error: + raise self._drain_error + + def close(self) -> None: + if self._close_error: + raise self._close_error + self._closed = True + + async def wait_closed(self) -> None: + return None + + +class TestAioPipe(AioPipe[bytes, bytes]): + pass + + +class NotifyFailPipe(TestAioPipe): + async def _notify(self, event_type: str, *args: Any) -> None: + raise RuntimeError(f"notify:{event_type}") + + +class NotifyOncePipe(TestAioPipe): + def __init__(self, reader: FakeReader, writer: FakeWriter): + super().__init__(reader, writer) + self._notified = False + + async def _notify(self, event_type: str, *args: Any) -> None: + if event_type == "data" and not self._notified: + self._notified = True + raise RuntimeError("notify error") + return await super()._notify(event_type, *args) + + +@pytest.mark.asyncio +async def test_aio_pipe_init_requires_reader_and_writer(): + with pytest.raises(ValueError): + TestAioPipe(reader=FakeReader([]), writer=None) + + +@pytest.mark.asyncio +async def test_aio_pipe_setup_without_reader_writer(): + pipe = TestAioPipe() + with pytest.raises(RuntimeError): + await pipe._setup_connection() + + +@pytest.mark.asyncio +async def test_aio_pipe_setup_notify_error_closes(): + pipe = NotifyFailPipe(FakeReader([b"data", b""]), FakeWriter()) + + with pytest.raises(RuntimeError): + await pipe._setup_connection() + + assert pipe.is_connected() is False + + +@pytest.mark.asyncio +async def test_aio_pipe_read_loop_notify_error_emits_close(): + pipe = NotifyOncePipe(FakeReader([b"data", b""]), FakeWriter()) + closed = asyncio.Event() + + def on_close(*_args: Any) -> None: + closed.set() + + pipe.on("close", on_close) + await pipe._setup_connection() + await asyncio.wait_for(closed.wait(), timeout=1) + + +@pytest.mark.asyncio +async def test_aio_pipe_read_loop_reader_error_emits_error(): + pipe = TestAioPipe(FakeReader([ValueError("boom")]), FakeWriter()) + errors: List[Any] = [] + + def on_error(err: Exception) -> None: + errors.append(err) + + pipe.on("error", on_error) + await pipe._setup_connection() + await asyncio.sleep(0.05) + assert errors + + +@pytest.mark.asyncio +async def test_aio_pipe_write_errors(): + pipe = TestAioPipe(FakeReader([]), FakeWriter()) + + with pytest.raises(ConnectionError): + await pipe.write(b"data") + + pipe._connected = True + with pytest.raises(TypeError): + await pipe.write("data") + + errors: List[Any] = [] + + def on_error(err: Exception) -> None: + errors.append(err) + + pipe.on("error", on_error) + pipe._writer = FakeWriter(drain_error=RuntimeError("drain")) + ok = await pipe.write(b"data") + assert ok is False + assert errors + + +@pytest.mark.asyncio +async def test_aio_pipe_close_writer_exception_emits_error(): + pipe = TestAioPipe(FakeReader([]), FakeWriter()) + errors: List[Any] = [] + + def on_error(err: Exception) -> None: + errors.append(err) + + pipe.on("error", on_error) + pipe._connected = True + pipe._writer = FakeWriter(close_error=RuntimeError("close")) + await pipe._close_connection() + assert errors diff --git a/tests/pipe/test_event_pipe.py b/tests/pipe/test_event_pipe.py new file mode 100644 index 0000000..d5020d7 --- /dev/null +++ b/tests/pipe/test_event_pipe.py @@ -0,0 +1,179 @@ +import asyncio +from typing import Any, Tuple + +import pytest +import pytest_asyncio + +from cbor_rpc import EventPipe + + +@pytest_asyncio.fixture +async def event_pipe_pair(): + pipe1, pipe2 = EventPipe.create_inmemory_pair() + yield pipe1, pipe2 + await pipe1.terminate() + await pipe2.terminate() + + +@pytest.mark.asyncio +async def test_create_pair(event_pipe_pair: Tuple[EventPipe, EventPipe]): + pipe1, pipe2 = event_pipe_pair + assert isinstance(pipe1, EventPipe) + assert isinstance(pipe2, EventPipe) + + +@pytest.mark.asyncio +async def test_write_success(event_pipe_pair: Tuple[EventPipe, EventPipe]): + pipe1, _pipe2 = event_pipe_pair + result = await pipe1.write("test_chunk") + assert result is True + + +@pytest.mark.asyncio +async def test_terminate_success(event_pipe_pair: Tuple[EventPipe, EventPipe]): + pipe1, _pipe2 = event_pipe_pair + await pipe1.terminate() + + +@pytest.mark.asyncio +async def test_pipeline_execution(event_pipe_pair: Tuple[EventPipe, EventPipe]): + pipe1, _pipe2 = event_pipe_pair + received_chunk = None + event = asyncio.Event() + + async def pipeline_handler(chunk: Any) -> None: + nonlocal received_chunk + received_chunk = chunk + event.set() + + pipe1.pipeline("data", pipeline_handler) + await pipe1._notify("data", "test_chunk") + await asyncio.wait_for(event.wait(), timeout=1) + assert received_chunk == "test_chunk" + + +@pytest.mark.asyncio +async def test_pipe_pair(event_pipe_pair: Tuple[EventPipe, EventPipe]): + pipe1, pipe2 = event_pipe_pair + received_chunk = None + event = asyncio.Event() + + async def handler(chunk: Any) -> None: + nonlocal received_chunk + received_chunk = chunk + event.set() + + pipe2.pipeline("data", handler) + await pipe1.write("test_chunk") + await asyncio.wait_for(event.wait(), timeout=1) + assert received_chunk == "test_chunk" + + +@pytest.mark.asyncio +async def test_write_after_terminate(event_pipe_pair: Tuple[EventPipe, EventPipe]): + pipe1, _pipe2 = event_pipe_pair + await pipe1.terminate() + + result = await pipe1.write("test_chunk") + assert result is False + + +@pytest.mark.asyncio +async def test_parallel_event_writes(event_pipe_pair: Tuple[EventPipe, EventPipe]): + pipe1, pipe2 = event_pipe_pair + num_writes = 10 + test_chunks = [f"chunk_{i}" for i in range(num_writes)] + received_chunks = [] + + lock = asyncio.Lock() + received_queue = asyncio.Queue() + + async def pipeline_handler(chunk: Any) -> None: + async with lock: + received_chunks.append(chunk) + if len(received_chunks) == num_writes: + received_queue.put_nowait(True) + + pipe2.pipeline("data", pipeline_handler) + + async def writer(chunk): + return await pipe1.write(chunk) + + results = await asyncio.gather(*[writer(chunk) for chunk in test_chunks]) + assert all(results) + + await asyncio.wait_for(received_queue.get(), timeout=5) + + assert sorted(received_chunks) == sorted(test_chunks) + + +@pytest.mark.asyncio +async def test_parallel_event_processing(event_pipe_pair: Tuple[EventPipe, EventPipe]): + pipe1, pipe2 = event_pipe_pair + num_chunks = 10 + test_chunks = [f"data_{i}" for i in range(num_chunks)] + processed_chunks = [] + + lock = asyncio.Lock() + processed_queue = asyncio.Queue() + + async def processing_handler(chunk: Any) -> None: + await asyncio.sleep(0.01) + async with lock: + processed_chunks.append(chunk) + if len(processed_chunks) == num_chunks: + processed_queue.put_nowait(True) + + pipe2.pipeline("data", processing_handler) + + write_tasks = [pipe1.write(chunk) for chunk in test_chunks] + await asyncio.gather(*write_tasks) + + await asyncio.wait_for(processed_queue.get(), timeout=5) + + assert sorted(processed_chunks) == sorted(test_chunks) + + +@pytest.mark.asyncio +async def test_concurrent_bidirectional_event_communication( + event_pipe_pair: Tuple[EventPipe, EventPipe], +): + pipe1, pipe2 = event_pipe_pair + num_messages = 5 + + client_received_responses = [] + server_received_msgs = [] + server_sent_responses = [] + + client_done_event = asyncio.Event() + server_done_event = asyncio.Event() + + async def client_handler(response: Any) -> None: + client_received_responses.append(response) + if len(client_received_responses) == num_messages: + client_done_event.set() + + async def server_handler(msg: Any) -> None: + server_received_msgs.append(msg) + response = f"server_response_to_{msg}" + await pipe2.write(response) + server_sent_responses.append(response) + if len(server_sent_responses) == num_messages: + server_done_event.set() + + pipe1.pipeline("data", client_handler) + pipe2.pipeline("data", server_handler) + + async def client_writer_task(): + for i in range(num_messages): + msg = f"client_msg_{i}" + await pipe1.write(msg) + + asyncio.create_task(client_writer_task()) + + await asyncio.wait_for(asyncio.gather(client_done_event.wait(), server_done_event.wait()), timeout=5) + + assert sorted([f"client_msg_{i}" for i in range(num_messages)]) == sorted(server_received_msgs) + assert sorted([f"server_response_to_client_msg_{i}" for i in range(num_messages)]) == sorted( + client_received_responses + ) diff --git a/tests/pipe/test_pipe.py b/tests/pipe/test_pipe.py new file mode 100644 index 0000000..01242be --- /dev/null +++ b/tests/pipe/test_pipe.py @@ -0,0 +1,143 @@ +import asyncio +from typing import Tuple + +import pytest +import pytest_asyncio + +from cbor_rpc.pipe.pipe import Pipe + + +@pytest_asyncio.fixture +async def pipe_pair(): + pipe1, pipe2 = Pipe.create_pair() + yield pipe1, pipe2 + + +@pytest.mark.asyncio +async def test_create_pair(): + pipe1, pipe2 = Pipe.create_pair() + assert isinstance(pipe1, Pipe) + assert isinstance(pipe2, Pipe) + await pipe1.terminate() + await pipe2.terminate() + + +@pytest.mark.asyncio +async def test_write_read(pipe_pair: Tuple[Pipe, Pipe]): + pipe1, pipe2 = pipe_pair + + assert await pipe1.write("test_chunk") is True + await asyncio.sleep(0) + assert await pipe2.read() == "test_chunk" + + +@pytest.mark.asyncio +async def test_close_pipe(pipe_pair: Tuple[Pipe, Pipe]): + pipe1, pipe2 = pipe_pair + await pipe1.terminate() + + assert await pipe1.read() is None + assert pipe1._closed is True + assert pipe2._closed is True + + +@pytest.mark.asyncio +async def test_write_after_close(pipe_pair: Tuple[Pipe, Pipe]): + pipe1, _pipe2 = pipe_pair + await pipe1.terminate() + + assert await pipe1.write("test_chunk") is False + + +@pytest.mark.asyncio +async def test_read_timeout(pipe_pair: Tuple[Pipe, Pipe]): + pipe1, _pipe2 = pipe_pair + + assert await pipe1.read(timeout=0.1) is None + + +@pytest.mark.asyncio +async def test_bidirectional_communication(pipe_pair: Tuple[Pipe, Pipe]): + pipe1, pipe2 = pipe_pair + + assert await pipe1.write("test_chunk") is True + await asyncio.sleep(0) + assert await pipe2.read() == "test_chunk" + + assert await pipe2.write("response_chunk") is True + await asyncio.sleep(0) + assert await pipe1.read() == "response_chunk" + + +@pytest.mark.asyncio +async def test_parallel_writes(pipe_pair: Tuple[Pipe, Pipe]): + pipe1, pipe2 = pipe_pair + num_writes = 10 + test_chunks = [f"chunk_{i}" for i in range(num_writes)] + + async def writer(chunk): + return await pipe1.write(chunk) + + results = await asyncio.gather(*[writer(chunk) for chunk in test_chunks]) + assert all(results) + + received_chunks = [] + for _ in range(num_writes): + received_chunks.append(await pipe2.read()) + + assert sorted(received_chunks) == sorted(test_chunks) + + +@pytest.mark.asyncio +async def test_parallel_reads(pipe_pair: Tuple[Pipe, Pipe]): + pipe1, pipe2 = pipe_pair + num_reads = 20 + test_chunks = [f"data_{i}" for i in range(num_reads)] + + for chunk in test_chunks: + await pipe1.write(chunk) + await asyncio.sleep(0) + + async def reader(): + return await pipe2.read() + + received_chunks = await asyncio.gather(*[reader() for _ in range(num_reads)]) + + assert sorted(received_chunks) == sorted(test_chunks) + + +@pytest.mark.asyncio +async def test_concurrent_bidirectional_communication(pipe_pair: Tuple[Pipe, Pipe]): + pipe1, pipe2 = pipe_pair + num_messages = 20 + + async def client_task(): + sent = [] + received = [] + for i in range(num_messages): + msg = f"client_msg_{i}" + await pipe1.write(msg) + sent.append(msg) + response = await pipe1.read() + received.append(response) + return sent, received + + async def server_task(): + sent = [] + received = [] + for i in range(num_messages): + msg = await pipe2.read() + received.append(msg) + response = f"server_response_to_{msg}" + await pipe2.write(response) + sent.append(response) + return sent, received + + client_future = asyncio.create_task(client_task()) + server_future = asyncio.create_task(server_task()) + + _client_sent, client_received = await client_future + _server_sent, server_received = await server_future + + assert sorted([f"client_msg_{i}" for i in range(num_messages)]) == sorted(server_received) + assert sorted([f"server_response_to_client_msg_{i}" for i in range(num_messages)]) == sorted(client_received) diff --git a/tests/pipe/test_pipe_extra.py b/tests/pipe/test_pipe_extra.py new file mode 100644 index 0000000..94ce372 --- /dev/null +++ b/tests/pipe/test_pipe_extra.py @@ -0,0 +1,94 @@ +import asyncio +from typing import Any, List + +import pytest + +from cbor_rpc.pipe.pipe import Pipe + + +@pytest.mark.asyncio +async def test_pipe_read_timeout_returns_none(): + a, _b = Pipe.create_pair() + result = await a.read(timeout=0.01) + assert result is None + + +@pytest.mark.asyncio +async def test_pipe_cancelled_read_requeues_termination_signal(): + a, _b = Pipe.create_pair() + read_task = asyncio.create_task(a.read(timeout=1)) + await asyncio.sleep(0) + read_task.cancel() + with pytest.raises(asyncio.CancelledError): + await read_task + a._buffer.put_nowait(None) + result = await a.read(timeout=0.01) + assert result is None + + +@pytest.mark.asyncio +async def test_pipe_terminate_closes_both_ends(): + a, b = Pipe.create_pair() + await a.terminate("done") + result = await b.read(timeout=0.01) + assert result is None + + +@pytest.mark.asyncio +async def test_pipe_make_event_based_emits_data_and_close(): + a, b = Pipe.create_pair() + event_pipe = a.make_event_based() + received: List[Any] = [] + closed = asyncio.Event() + + async def on_data(data: Any) -> None: + received.append(data) + + def on_close(*_args: Any) -> None: + closed.set() + + event_pipe.pipeline("data", on_data) + event_pipe.on("close", on_close) + + await b.write("hello") + await asyncio.sleep(0.01) + assert received == ["hello"] + + await b.terminate("bye") + await asyncio.wait_for(closed.wait(), timeout=1) + + +@pytest.mark.asyncio +async def test_pipe_write_fails_when_peer_closed(): + a, b = Pipe.create_pair() + await b.terminate("done") + result = await a.write("data") + assert result is False + + +@pytest.mark.asyncio +async def test_pipe_read_with_zero_timeout_reads_buffered_item(): + a, b = Pipe.create_pair() + await b.write("data") + result = await a.read(timeout=0) + assert result == "data" + + +@pytest.mark.asyncio +async def test_pipe_make_event_based_write_roundtrip(): + a, b = Pipe.create_pair() + event_pipe = a.make_event_based() + await event_pipe.write("out") + result = await b.read(timeout=0.1) + assert result == "out" + await event_pipe.terminate("done") + + +@pytest.mark.asyncio +async def test_pipe_make_event_based_terminate_is_idempotent(capsys): + a, _b = Pipe.create_pair() + event_pipe = a.make_event_based() + await event_pipe.terminate("done") + await event_pipe.terminate("done") + captured = capsys.readouterr().out + assert "PipeToEvent: Terminating event pipe." in captured diff --git a/tests/rpc/test_rpc_base.py b/tests/rpc/test_rpc_base.py new file mode 100644 index 0000000..1ec5522 --- /dev/null +++ b/tests/rpc/test_rpc_base.py @@ -0,0 +1,56 @@ +from typing import Any, Optional + +import pytest + +from cbor_rpc.rpc.rpc_base import RpcClient, RpcServer + + +class DummyClient(RpcClient): + async def call_method(self, method: str, *args: Any) -> Any: + return await RpcClient.call_method(self, method, *args) + + async def fire_method(self, method: str, *args: Any) -> None: + return await RpcClient.fire_method(self, method, *args) + + def set_timeout(self, milliseconds: int) -> None: + return RpcClient.set_timeout(self, milliseconds) + + +class DummyServerBase(RpcServer): + async def call_method(self, connection_id: str, method: str, *args: Any) -> Any: + return await RpcServer.call_method(self, connection_id, method, *args) + + async def fire_method(self, connection_id: str, method: str, *args: Any) -> None: + return await RpcServer.fire_method(self, connection_id, method, *args) + + async def disconnect(self, connection_id: str, reason: Optional[str] = None) -> None: + return await RpcServer.disconnect(self, connection_id, reason) + + def get_client(self, connection_id: str): + return RpcServer.get_client(self, connection_id) + + def with_client(self, connection_id: str, action): + return RpcServer.with_client(self, connection_id, action) + + def set_timeout(self, milliseconds: int) -> None: + return RpcServer.set_timeout(self, milliseconds) + + def is_active(self, connection_id: str) -> bool: + return RpcServer.is_active(self, connection_id) + + +@pytest.mark.asyncio +async def test_rpc_base_methods_execute(): + client = DummyClient() + assert await client.call_method("m") is None + assert await client.fire_method("m") is None + assert client.set_timeout(1) is None + + server = DummyServerBase() + assert await server.call_method("id", "m") is None + assert await server.fire_method("id", "m") is None + assert await server.disconnect("id") is None + assert server.get_client("id") is None + assert server.with_client("id", lambda _c: None) is None + assert server.set_timeout(1) is None + assert server.is_active("id") is None diff --git a/tests/rpc/test_rpc_logging.py b/tests/rpc/test_rpc_logging.py new file mode 100644 index 0000000..2e923dc --- /dev/null +++ b/tests/rpc/test_rpc_logging.py @@ -0,0 +1,150 @@ +import asyncio +from typing import Any, List + +import pytest + +from cbor_rpc import EventPipe, RpcV1 +from cbor_rpc.rpc.context import RpcCallContext + + +LOG_LEVELS = [ + ("crit", 1), + ("warn", 2), + ("info", 3), + ("verbose", 4), + ("debug", 5), +] + + +class LoggingServerRpc(RpcV1): + def __init__(self, pipe: EventPipe[Any, Any], id_: str): + super().__init__(pipe) + self._id = id_ + + def get_id(self) -> str: + return self._id + + def handle_method_call(self, context: RpcCallContext, method: str, args: List[Any]) -> Any: + if method == "log_levels": + for level_name, _level_value in LOG_LEVELS: + self._log_for_level(context, f"method:{level_name}", level_name) + return "ok" + raise Exception(f"Unknown method: {method}") + + async def on_event(self, context: RpcCallContext, topic: str, message: Any) -> None: + if message == "levels": + for level_name, _level_value in LOG_LEVELS: + self._log_for_level(context, f"event:{topic}:{level_name}", level_name) + return + context.logger.warn(f"event:{topic}:{message}") + + def _log_for_level(self, context: RpcCallContext, content: str, level_name: str) -> None: + if level_name == "crit": + context.logger.crit(content) + elif level_name == "warn": + context.logger.warn(content) + elif level_name == "info": + context.logger.log(content) + elif level_name == "verbose": + context.logger.verbose(content) + elif level_name == "debug": + context.logger.debug(content) + else: + raise Exception(f"Unknown log level: {level_name}") + + +class LoggingClientRpc(RpcV1): + def __init__(self, pipe: EventPipe[Any, Any], id_: str): + super().__init__(pipe) + self._id = id_ + + def get_id(self) -> str: + return self._id + + def handle_method_call(self, context: RpcCallContext, method: str, args: List[Any]) -> Any: + raise Exception("Client Only Implementation") + + async def on_event(self, context: RpcCallContext, topic: str, message: Any) -> None: + return None + + +async def _collect_logs(queue: asyncio.Queue, expected_count: int) -> List[List[Any]]: + logs: List[List[Any]] = [] + for _ in range(expected_count): + log = await asyncio.wait_for(queue.get(), timeout=1) + logs.append(log) + return logs + + +def _attach_log_listener(pipe: EventPipe[Any, Any], queue: asyncio.Queue) -> None: + async def handler(chunk: Any) -> None: + if not isinstance(chunk, list) or len(chunk) < 2: + return + if chunk[0] != 2 or chunk[1] == 0: + return + await queue.put(chunk) + + pipe.pipeline("data", handler) + + +@pytest.mark.asyncio +async def test_rpc_logging_all_levels_method_and_event(): + pipe_a, pipe_b = EventPipe.create_inmemory_pair() + server = LoggingServerRpc(pipe_a, "server") + client = LoggingClientRpc(pipe_b, "client") + + log_queue: asyncio.Queue = asyncio.Queue() + _attach_log_listener(pipe_b, log_queue) + + await client.set_log_level(5) + await asyncio.sleep(0.01) + + result = await client.call_method("log_levels") + assert result == "ok" + await client.emit("topic1", "levels") + + logs = await _collect_logs(log_queue, expected_count=10) + + expected = set() + for level_name, level_value in LOG_LEVELS: + expected.add((level_value, 1, 0, f"method:{level_name}")) + expected.add((level_value, 3, "topic1", f"event:topic1:{level_name}")) + + actual = set((log[1], log[2], log[3], log[4]) for log in logs) + assert actual == expected + + await pipe_a.terminate() + await pipe_b.terminate() + + +@pytest.mark.asyncio +async def test_rpc_logging_respects_remote_level_setting(): + pipe_a, pipe_b = EventPipe.create_inmemory_pair() + server = LoggingServerRpc(pipe_a, "server") + client = LoggingClientRpc(pipe_b, "client") + + log_queue: asyncio.Queue = asyncio.Queue() + _attach_log_listener(pipe_b, log_queue) + + await client.set_log_level(2) + await asyncio.sleep(0.01) + + result = await client.call_method("log_levels") + assert result == "ok" + await client.emit("topic2", "levels") + + logs = await _collect_logs(log_queue, expected_count=4) + actual_levels = {log[1] for log in logs} + assert actual_levels == {1, 2} + + actual = set((log[1], log[2], log[3], log[4]) for log in logs) + expected = { + (1, 1, 0, "method:crit"), + (2, 1, 0, "method:warn"), + (1, 3, "topic2", "event:topic2:crit"), + (2, 3, "topic2", "event:topic2:warn"), + } + assert actual == expected + + await pipe_a.terminate() + await pipe_b.terminate() diff --git a/tests/test_rpc_v1.py b/tests/rpc/test_rpc_v1.py similarity index 62% rename from tests/test_rpc_v1.py rename to tests/rpc/test_rpc_v1.py index 427bc14..a27bd95 100644 --- a/tests/test_rpc_v1.py +++ b/tests/rpc/test_rpc_v1.py @@ -1,151 +1,178 @@ -import pytest import asyncio -from typing import Any, Generic, List -from unittest.mock import AsyncMock, MagicMock +from typing import Any, List -from cbor_rpc import Pipe, RpcV1, DeferredPromise -from tests.helpers import SimplePipe +import pytest +from cbor_rpc import EventPipe, RpcV1, TimedPromise +from cbor_rpc.rpc.context import RpcCallContext +from tests.helpers import SimplePipe async def sleep_method(seconds: float) -> None: await asyncio.sleep(seconds) return f"Slept for {seconds} seconds" + def add_method(a: int, b: int) -> int: return a + b + def multiply_method(a: int, b: int) -> int: return a * b + def throw_error_method(message: str) -> None: raise Exception(message) -# Method handler for RPC -def method_handler(method: str, args: List[Any]) -> Any: + +def method_handler(context: RpcCallContext, method: str, args: List[Any]) -> Any: if method == "sleep": return sleep_method(*args) - elif method == "add": + if method == "add": return add_method(*args) - elif method == "multiply": + if method == "multiply": return multiply_method(*args) - elif method == "throwError": + if method == "throwError": return throw_error_method(*args) raise Exception(f"Unknown method: {method}") -# Event handler for RPC -async def event_handler(topic: str, message: Any) -> None: - pass # No-op for testing + +class EventRpcHelper(RpcV1): + def get_id(self) -> str: + return "event_rpc" + + def handle_method_call(self, context: RpcCallContext, method: str, args: List[Any]) -> Any: + raise Exception("Event-only RPC") + + async def on_event(self, context: RpcCallContext, topic: str, payload: Any) -> None: + pass + @pytest.fixture def pipe(): return SimplePipe() + +@pytest.fixture +def event_rpc(pipe): + return EventRpcHelper(pipe) + + @pytest.fixture def rpc(pipe): - return RpcV1.make_rpc_v1(pipe, "test_id", method_handler, event_handler) + return RpcV1.make_rpc_v1(pipe, "test_id", method_handler) + @pytest.mark.asyncio async def test_get_id(rpc): assert rpc.get_id() == "test_id" + @pytest.mark.asyncio async def test_add_method_success(rpc): result = await rpc.call_method("add", 3, 4) assert result == 7 + @pytest.mark.asyncio async def test_multiply_method_success(rpc): result = await rpc.call_method("multiply", 5, 6) assert result == 30 + @pytest.mark.asyncio async def test_sleep_method_success(rpc): result = await rpc.call_method("sleep", 0.1) assert result == "Slept for 0.1 seconds" + @pytest.mark.asyncio async def test_throw_error_method(rpc): with pytest.raises(Exception) as exc_info: await rpc.call_method("throwError", "Test error") assert str(exc_info.value) == "Test error" + @pytest.mark.asyncio async def test_call_method_timeout(rpc, pipe): - rpc.set_timeout(100) # Set short timeout + rpc.set_timeout(100) with pytest.raises(Exception) as exc_info: - await rpc.call_method("sleep", 1) # Long sleep to trigger timeout + await rpc.call_method("sleep", 1) assert exc_info.value.args[0]["timeout"] is True assert exc_info.value.args[0]["timeoutPeriod"] == 100 + @pytest.mark.asyncio async def test_call_method_unknown_method(rpc): with pytest.raises(Exception) as exc_info: await rpc.call_method("unknown", 1) assert str(exc_info.value) == "Unknown method: unknown" + @pytest.mark.asyncio async def test_fire_method(rpc): await rpc.fire_method("add", 1, 2) - assert rpc._counter == 1 # Verify message was sent + assert rpc._counter == 1 + @pytest.mark.asyncio -async def test_emit_event(rpc): - await rpc.emit("test_topic", {"data": "test"}) - # No assertion needed; just verify no crash +async def test_emit_event(event_rpc): + await event_rpc.emit("test_topic", {"data": "test"}) + @pytest.mark.asyncio -async def test_wait_next_event_success(rpc, pipe): +async def test_wait_next_event_success(event_rpc, pipe): async def simulate_event(): await asyncio.sleep(0.1) - await pipe.write([1, 3, 0, "test_topic", {"data": "test"}]) + await pipe.write([3, 0, "test_topic", {"data": "test"}]) task = asyncio.create_task(simulate_event()) - result = await rpc.wait_next_event("test_topic", 1000) + result = await event_rpc.wait_next_event("test_topic", 1000) assert result == {"data": "test"} await task + @pytest.mark.asyncio -async def test_wait_next_event_timeout(rpc): +async def test_wait_next_event_timeout(event_rpc): with pytest.raises(Exception) as exc_info: - await rpc.wait_next_event("test_topic", 100) + await event_rpc.wait_next_event("test_topic", 100) assert exc_info.value.args[0]["timeout"] is True assert exc_info.value.args[0]["timeoutPeriod"] == 100 + @pytest.mark.asyncio -async def test_wait_next_event_already_waiting(rpc): - rpc._waiters["test_topic"] = DeferredPromise(1000) +async def test_wait_next_event_already_waiting(event_rpc): + event_rpc._waiters["test_topic"] = TimedPromise(1000) with pytest.raises(Exception) as exc_info: - await rpc.wait_next_event("test_topic") + await event_rpc.wait_next_event("test_topic") assert str(exc_info.value) == "Already waiting for event" + @pytest.mark.asyncio async def test_invalid_message_format(rpc, pipe): - await pipe.write([1, 2, 3]) # Invalid message - await asyncio.sleep(0.1) # Allow processing - # No crash means test passes + await pipe.write([1]) + await asyncio.sleep(0.1) + @pytest.mark.asyncio -async def test_unsupported_version(rpc, pipe): - await pipe.write([2, 0, 0, "add", [1, 2]]) # Wrong version - await asyncio.sleep(0.1) # Allow processing - # No crash means test passes +async def test_unsupported_protocol(rpc, pipe): + await pipe.write([99, 0, 0, "add", [1, 2]]) + await asyncio.sleep(0.1) + @pytest.mark.asyncio async def test_concurrent_method_calls(rpc): - tasks = [ - rpc.call_method("add", i, i) - for i in range(3) - ] + tasks = [rpc.call_method("add", i, i) for i in range(3)] results = await asyncio.gather(*tasks) assert results == [0, 2, 4] + @pytest.mark.asyncio async def test_read_only_client(pipe): read_only = RpcV1.read_only_client(SimplePipe()) - - # We need to directly call the handle_method_call method to test it + with pytest.raises(Exception) as exc_info: - read_only.handle_method_call("add", [1, 2]) - + context = RpcCallContext(read_only.logger) + read_only.handle_method_call(context, "add", [1, 2]) + assert str(exc_info.value) == "Client Only Implementation" diff --git a/tests/rpc/test_rpc_v1_extra.py b/tests/rpc/test_rpc_v1_extra.py new file mode 100644 index 0000000..6407d17 --- /dev/null +++ b/tests/rpc/test_rpc_v1_extra.py @@ -0,0 +1,224 @@ +import asyncio +from typing import Any, List + +import pytest + +from cbor_rpc.pipe.event_pipe import EventPipe +from cbor_rpc.rpc.context import RpcCallContext +from cbor_rpc.rpc.rpc_v1 import RpcV1, RpcCore +from cbor_rpc.rpc.rpc_server import RpcV1Server + + +class TestRpcServer(RpcV1Server): + async def handle_method_call( + self, + connection_id: str, + context: RpcCallContext, + method: str, + args: List[Any], + ) -> Any: + if method == "ping": + return f"pong:{connection_id}:{args[0]}" + raise Exception("Unknown method") + + +def _noop_handler(context: RpcCallContext, method: str, args: List[Any]) -> Any: + if method == "fire": + return None + return "ok" + + +class CoreOnlyRpc(RpcCore): + def get_id(self) -> str: + return "core" + + def handle_method_call(self, context: RpcCallContext, method: str, args: List[Any]) -> Any: + if method == "boom": + raise Exception("boom") + if method == "nested": + async def inner() -> str: + return "ok" + + async def outer(): + return inner() + + return outer() + return "ok" + + +@pytest.mark.asyncio +async def test_rpc_v1_proto_validation_and_logging(caplog): + import logging + caplog.set_level(logging.INFO) + pipe_a, pipe_b = EventPipe.create_inmemory_pair() + rpc = RpcV1.make_rpc_v1(pipe_a, "id", _noop_handler) + + await pipe_b.write([1, 0, 1]) + await pipe_b.write([1, 2, 1]) + await pipe_b.write([1, 9, 1, "x", []]) + await pipe_b.write([1, 0, 999, "ok"]) + await pipe_b.write([2, 1, 2]) + await pipe_b.write([2, 0, 3]) + await pipe_b.write([2, 99, 1, 2, "content"]) + await pipe_b.write([3, 0]) + await pipe_b.write([3, 1, "topic", "msg"]) + + await asyncio.sleep(0.05) + + assert rpc._peer_log_level == 3 + + logs = [r.message for r in caplog.records] + assert any("Invalid response format" in log for log in logs) + assert any("Invalid call format" in log for log in logs) + assert any("Unknown sub-protocol" in log for log in logs) + assert any("expired request id" in log for log in logs) + assert any("Invalid format" in log for log in logs) + assert any("[RemoteLog:LEVEL-99]" in log for log in logs) + assert any("Invalid event format" in log for log in logs) + + await pipe_a.terminate() + await pipe_b.terminate() + + +@pytest.mark.asyncio +async def test_rpc_v1_send_log_filters_by_peer_level(): + pipe_a, pipe_b = EventPipe.create_inmemory_pair() + rpc = RpcV1.make_rpc_v1(pipe_a, "id", _noop_handler) + + received: List[List[Any]] = [] + + async def on_data(data: Any) -> None: + if isinstance(data, list) and data and data[0] == 2: + received.append(data) + + pipe_b.pipeline("data", on_data) + + rpc.logger.log("skip") + await asyncio.sleep(0.01) + assert received == [] + + rpc._peer_log_level = 2 + rpc.logger.debug("skip") + await asyncio.sleep(0.01) + assert received == [] + + rpc.logger.warn("keep") + await asyncio.sleep(0.01) + assert received[-1][1] == 2 + + await pipe_a.terminate() + await pipe_b.terminate() + + +@pytest.mark.asyncio +async def test_rpc_v1_server_connection_lifecycle(): + server = TestRpcServer() + pipe_a, pipe_b = EventPipe.create_inmemory_pair() + + client_called = asyncio.Event() + + def client_handler(context: RpcCallContext, method: str, args: List[Any]) -> Any: + if method == "ping": + return f"client:{args[0]}" + if method == "fire": + client_called.set() + return None + raise Exception("Unknown method") + + client = RpcV1.make_rpc_v1(pipe_b, "client", client_handler) + + await server.add_connection("c1", pipe_a) + assert server.is_active("c1") + + result = await server.call_method("c1", "ping", "hi") + assert result == "client:hi" + + await server.fire_method("c1", "fire") + await asyncio.wait_for(client_called.wait(), timeout=1) + + called: List[bool] = [] + + def action(_client: Any) -> None: + called.append(True) + + assert server.with_client("c1", action) is True + assert called == [True] + + await server.disconnect("c1", "bye") + assert server.is_active("c1") is False + assert server.with_client("missing", action) is False + + await pipe_a.terminate() + await pipe_b.terminate() + + +@pytest.mark.asyncio +async def test_rpc_v1_server_inactive_client_errors(): + server = TestRpcServer() + + with pytest.raises(Exception) as exc_info: + await server.call_method("missing", "ping") + assert str(exc_info.value) == "Client is not active" + + with pytest.raises(Exception) as exc_info: + await server.fire_method("missing", "ping") + assert str(exc_info.value) == "Client is not active" + + +@pytest.mark.asyncio +async def test_rpc_core_fire_error_and_unsupported_event(caplog): + pipe_a, pipe_b = EventPipe.create_inmemory_pair() + _rpc = CoreOnlyRpc(pipe_a) + + await pipe_b.write([3, 0, "topic", "msg"]) + await pipe_b.write([1, 3, 1, "boom", []]) + await asyncio.sleep(0.05) + + logs = [r.message for r in caplog.records] + assert any("Unsupported event message" in log for log in logs) + assert any("Fired method error" in log for log in logs) + + await pipe_a.terminate() + await pipe_b.terminate() + + +@pytest.mark.asyncio +async def test_rpc_core_nested_result_and_error_response(): + pipe_a, pipe_b = EventPipe.create_inmemory_pair() + _rpc = CoreOnlyRpc(pipe_a) + + responses: List[List[Any]] = [] + + async def on_data(data: Any) -> None: + if isinstance(data, list) and data and data[0] == 1 and data[1] in (0, 1): + responses.append(data) + + pipe_b.pipeline("data", on_data) + + await pipe_b.write([1, 2, 1, "nested", []]) + await pipe_b.write([1, 2, 2, "boom", []]) + await asyncio.sleep(0.05) + + assert [1, 0, 1, "ok"] in responses + assert [1, 1, 2, "boom"] in responses + + await pipe_a.terminate() + await pipe_b.terminate() + + +@pytest.mark.asyncio +async def test_rpc_v1_server_close_cleanup_and_timeout_applied(): + server = TestRpcServer() + server.set_timeout(123) + pipe_a, _pipe_b = EventPipe.create_inmemory_pair() + await server.add_connection("c1", pipe_a) + assert server.rpc_clients["c1"]._timeout == 123 + + await pipe_a.terminate("done") + await asyncio.sleep(0.01) + assert server.is_active("c1") is False + + +def test_rpc_v1_server_get_client_and_disconnect_missing(): + server = TestRpcServer() + assert server.get_client("missing") is None diff --git a/tests/rpc/test_server_base.py b/tests/rpc/test_server_base.py new file mode 100644 index 0000000..e84ffa8 --- /dev/null +++ b/tests/rpc/test_server_base.py @@ -0,0 +1,69 @@ +import asyncio +from typing import Any, List + +import pytest + +from cbor_rpc.pipe.event_pipe import EventPipe +from cbor_rpc.rpc.server_base import Server + + +class DummyServer(Server[EventPipe]): + def __init__(self, accept_connections: bool = True): + super().__init__() + self._accept_connections = accept_connections + self.stopped = False + + async def start(self, *args, **kwargs) -> Any: + self._running = True + return "started" + + async def stop(self) -> None: + self._running = False + self.stopped = True + + async def accept(self, pipe: EventPipe) -> bool: + return self._accept_connections + + +@pytest.mark.asyncio +async def test_server_base_add_connection_accepts_and_cleans(): + server = DummyServer(accept_connections=True) + pipe_a, _pipe_b = EventPipe.create_inmemory_pair() + + connected: List[EventPipe] = [] + + def on_connection(pipe: EventPipe) -> None: + connected.append(pipe) + + server.on_connection(on_connection) + await server._add_connection(pipe_a) + assert pipe_a in server.get_connections() + assert connected == [pipe_a] + + await pipe_a.terminate("bye") + await asyncio.sleep(0) + assert pipe_a not in server.get_connections() + + +@pytest.mark.asyncio +async def test_server_base_rejects_connection(): + server = DummyServer(accept_connections=False) + pipe_a, _pipe_b = EventPipe.create_inmemory_pair() + + await server._add_connection(pipe_a) + assert pipe_a not in server.get_connections() + + +@pytest.mark.asyncio +async def test_server_base_close_all_connections_and_context(): + server = DummyServer(accept_connections=True) + pipe_a, _pipe_b = EventPipe.create_inmemory_pair() + await server._add_connection(pipe_a) + + await server.close_all_connections() + await asyncio.sleep(0) + assert server.get_connections() == set() + + async with server: + assert server.is_running() is False + assert server.stopped is True diff --git a/tests/ssh/test_ssh_docker_pipe.py b/tests/ssh/test_ssh_docker_pipe.py new file mode 100644 index 0000000..b6a7276 --- /dev/null +++ b/tests/ssh/test_ssh_docker_pipe.py @@ -0,0 +1,510 @@ +import asyncio +import os +import re +import time + +import asyncssh +import asyncssh.public_key +import docker +import pytest + +from cbor_rpc.ssh.ssh_pipe import SshPipe + +TEST_SSH_USER = "testuser" +TEST_SSH_PASSWORD = "testpassword" +SSHD_IMAGE_NAME = "cbor-rpc-py-sshd-python" +SSHD_CONTAINER_NAME = "test-sshd-container" +SSHD_DOCKERFILE_PATH = "./tests/docker/sshd-python" + + +@pytest.fixture(scope="session") +def ssh_keys(): + private_key_obj = asyncssh.generate_private_key("ssh-rsa") + passphrase = "test_passphrase" + + return { + "unencrypted_private": private_key_obj.export_private_key().decode(), + "unencrypted_public": private_key_obj.export_public_key().decode(), + "encrypted_private": private_key_obj.export_private_key(passphrase=passphrase).decode(), + "encrypted_public": private_key_obj.export_public_key().decode(), + "passphrase": passphrase, + } + + +@pytest.fixture(scope="session") +def docker_client(): + client = docker.from_env() + yield client + client.close() + + +@pytest.fixture(scope="session") +def test_network(docker_client: docker.DockerClient): + network_name = "test-ssh-network" + try: + network = docker_client.networks.get(network_name) + network.remove() + except docker.errors.NotFound: + pass + + network = docker_client.networks.create(network_name, driver="bridge") + yield network + network.remove() + + +@pytest.fixture(scope="module") +async def ssh_container_combined_auth(docker_client: docker.DockerClient, test_network, docker_host_ip, ssh_keys): + container_name = "ssh-test-container-combined-auth" + ssh_user = TEST_SSH_USER + ssh_password = TEST_SSH_PASSWORD + public_key = ssh_keys["unencrypted_public"] + + try: + existing_container = docker_client.containers.get(container_name) + print(f"Found existing container '{container_name}'. Stopping and removing...") + existing_container.stop() + existing_container.remove() + print(f"Removed existing container '{container_name}'.") + except docker.errors.NotFound: + print(f"No existing container '{container_name}' found. Proceeding.") + except Exception as e: + print(f"Error cleaning up existing container: {e}") + + print(f"\nBuilding Docker image '{SSHD_IMAGE_NAME}' from '{SSHD_DOCKERFILE_PATH}'...") + try: + docker_client.images.build( + path=SSHD_DOCKERFILE_PATH, + tag=SSHD_IMAGE_NAME, + rm=True, + ) + print(f"Docker image '{SSHD_IMAGE_NAME}' built successfully.") + except docker.errors.BuildError as e: + print(f"Failed to build Docker image: {e}") + raise RuntimeError(f"Failed to build Docker image '{SSHD_IMAGE_NAME}'.") + + print(f"\nStarting {SSHD_IMAGE_NAME} container with combined auth...") + container = None + try: + container = docker_client.containers.run( + SSHD_IMAGE_NAME, + detach=True, + ports={"2222/tcp": None}, + network=test_network.name, + name=container_name, + environment={ + "PUID": "1000", + "PGID": "1000", + "TZ": "Etc/UTC", + "PASSWORD_ACCESS": "true", + "USER_NAME": ssh_user, + "USER_PASSWORD": ssh_password, + "PUBLIC_KEY": public_key, + }, + restart_policy={"Name": "no"}, + ) + + container.reload() + host_port = None + for _ in range(30): + container.reload() + if "2222/tcp" in container.ports and container.ports["2222/tcp"]: + host_port = container.ports["2222/tcp"][0]["HostPort"] + break + time.sleep(1) + + if host_port is None: + raise RuntimeError("Failed to get host port for SSH container within 30 seconds.") + + print(f"SSH container with combined auth running on host port: {host_port}") + + ready = False + for i in range(60): + try: + + async def check_ssh_combined(): + try: + conn_pw = await asyncssh.connect( + docker_host_ip, + port=int(host_port), + username=ssh_user, + password=ssh_password, + known_hosts=None, + ) + conn_pw.close() + print("Password auth check successful.") + except (asyncssh.Error, OSError) as e: + print(f"Password auth check failed: {e}") + return False + + try: + conn_key = await asyncssh.connect( + docker_host_ip, + port=int(host_port), + username=ssh_user, + client_keys=[asyncssh.import_private_key(ssh_keys["unencrypted_private"])], + known_hosts=None, + ) + conn_key.close() + print("Public key auth check successful.") + except (asyncssh.Error, OSError) as e: + print(f"Public key auth check failed: {e}") + return False + + return True + + if await check_ssh_combined(): + print(f"SSH server with combined auth is ready after {i+1} seconds.") + ready = True + break + except Exception as e: + print(f"Error during SSH readiness check: {e}") + pass + time.sleep(1) + + if not ready: + print("\nSSH server with combined auth did not become ready in time. Container logs:") + if container: + print(container.logs().decode("utf-8")) + raise RuntimeError("SSH server with combined auth did not become ready in time.") + + yield container, docker_host_ip, host_port, ssh_user, ssh_password, ssh_keys["unencrypted_private"] + finally: + if container: + print("Stopping and removing ssh-test-container-combined-auth...") + print("=========================== Container Logs Start ===========================") + print(container.logs().decode("utf-8")) + print("=========================== Container Logs End ===========================") + container.stop() + container.remove() + + +@pytest.fixture(scope="session") +def docker_host_ip(): + docker_host = os.environ.get("DOCKER_HOST") + + regex = r"^(?:(tcp|unix)://)?([a-zA-Z0-9.-]+)(?::\d+)?$" + + if docker_host: + match = re.match(regex, docker_host) + if match: + protocol, host = match.groups() + if protocol == "unix": + return "localhost" + return host + return "localhost" + + +@pytest.mark.asyncio +async def test_ssh_pipe_with_hello_world_emitter(ssh_container_combined_auth): + container, host, port, username, password, _ = ssh_container_combined_auth + + print(f"\nAttempting SshPipe connection to {host}:{port} as user {username} with password authentication...") + pipe = None + try: + emitter_command = "sh -c 'while true; do echo \"hello world\"; sleep 1; done'" + pipe = await SshPipe.connect( + host=host, + port=port, + username=username, + password=password, + known_hosts=None, + timeout=10, + command=emitter_command, + ) + + received_event = asyncio.Event() + received_data = [] + + def on_data_callback(data): + print(f"test_ssh_pipe_with_hello_world_emitter: Received data: {data!r}") + received_data.append(data) + if b"hello world" in data: + received_event.set() + + pipe.pipeline("data", on_data_callback) + + try: + await asyncio.wait_for(received_event.wait(), timeout=10) + except asyncio.TimeoutError: + pytest.fail("Did not receive 'hello world' within 10 seconds.") + + full_received_data = b"".join(received_data) + assert b"hello world" in full_received_data + + finally: + if pipe: + await pipe.terminate() + print("SshPipe closed.") + + +@pytest.mark.asyncio +async def test_ssh_pipe_with_password_authentication(ssh_container_combined_auth): + container, host, port, username, password, _ = ssh_container_combined_auth + + print(f"\nAttempting SshPipe connection to {host}:{port} as user {username} with password authentication...") + pipe = None + try: + test_command = "echo 'Password auth successful!'" + pipe = await SshPipe.connect( + host=host, + port=port, + username=username, + password=password, + known_hosts=None, + timeout=10, + command=test_command, + ) + + received_event = asyncio.Event() + received_data = [] + + def on_data_callback(data): + print(f"test_ssh_pipe_with_password_authentication: Received data: {data!r}") + received_data.append(data) + if b"Password auth successful!" in data: + received_event.set() + + pipe.pipeline("data", on_data_callback) + + try: + await asyncio.wait_for(received_event.wait(), timeout=10) + except asyncio.TimeoutError: + pytest.fail("Did not receive expected output within 10 seconds for password authentication test.") + + full_received_data = b"".join(received_data) + assert b"Password auth successful!" in full_received_data + print("Password authentication successful.") + + except asyncssh.Error as e: + pytest.fail(f"SSH connection or command failed with password authentication: {e}") + except asyncio.TimeoutError: + pytest.fail("SSH connection timed out with password authentication.") + except Exception as e: + pytest.fail(f"An unexpected error occurred during password authentication test: {e}") + finally: + if pipe: + await pipe.terminate() + print("SshPipe closed for password authentication test.") + + +@pytest.mark.asyncio +async def test_ssh_pipe_with_plain_key_authentication(ssh_container_combined_auth, ssh_keys): + container, host, port, username, _, unencrypted_private_key_content = ssh_container_combined_auth + + print(f"\nAttempting SshPipe connection to {host}:{port} as user {username} with plain key authentication...") + + pipe = None + try: + test_command = "echo 'Plain key auth successful!'" + pipe = await SshPipe.connect( + host=host, + port=port, + username=username, + ssh_key_content=unencrypted_private_key_content, + known_hosts=None, + timeout=10, + command=test_command, + ) + + received_event = asyncio.Event() + received_data = [] + + def on_data_callback(data): + print(f"test_ssh_pipe_with_plain_key_authentication: Received data: {data!r}") + received_data.append(data) + if b"Plain key auth successful!" in data: + received_event.set() + + pipe.pipeline("data", on_data_callback) + + try: + await asyncio.wait_for(received_event.wait(), timeout=10) + except asyncio.TimeoutError: + pytest.fail("Did not receive expected output within 10 seconds for plain key authentication test.") + + full_received_data = b"".join(received_data) + assert b"Plain key auth successful!" in full_received_data + print("Plain key authentication successful.") + + except asyncssh.Error as e: + pytest.fail(f"SSH connection or command failed with plain key authentication: {e}") + except asyncio.TimeoutError: + pytest.fail("SSH connection timed out with plain key authentication.") + except Exception as e: + pytest.fail(f"An unexpected error occurred during plain key authentication test: {e}") + finally: + if pipe: + await pipe.terminate() + print("SshPipe closed for plain key authentication test.") + + +@pytest.mark.asyncio +async def test_ssh_pipe_with_encrypted_key_authentication(ssh_container_combined_auth, ssh_keys): + container, host, port, username, _, _ = ssh_container_combined_auth + + private_key_content = ssh_keys["encrypted_private"] + passphrase = ssh_keys["passphrase"] + + print(f"\nAttempting SshPipe connection to {host}:{port} as user {username} with encrypted key authentication...") + + pipe = None + try: + test_command = "echo 'Encrypted key auth successful!'" + pipe = await SshPipe.connect( + host=host, + port=port, + username=username, + ssh_key_content=private_key_content, + ssh_key_passphrase=passphrase, + known_hosts=None, + timeout=10, + command=test_command, + ) + + received_event = asyncio.Event() + received_data = [] + + def on_data_callback(data): + print(f"test_ssh_pipe_with_encrypted_key_authentication: Received data: {data!r}") + received_data.append(data) + if b"Encrypted key auth successful!" in data: + received_event.set() + + pipe.pipeline("data", on_data_callback) + + try: + await asyncio.wait_for(received_event.wait(), timeout=10) + except asyncio.TimeoutError: + pytest.fail("Did not receive expected output within 10 seconds for encrypted key authentication test.") + + full_received_data = b"".join(received_data) + assert b"Encrypted key auth successful!" in full_received_data + print("Encrypted key authentication successful.") + + except asyncssh.Error as e: + pytest.fail(f"SSH connection or command failed with encrypted key authentication: {e}") + except asyncio.TimeoutError: + pytest.fail("SSH connection timed out.") + except Exception as e: + pytest.fail(f"An unexpected error occurred during encrypted key authentication test: {e}") + finally: + if pipe: + await pipe.terminate() + print("SshPipe closed for encrypted key authentication test.") + + +@pytest.mark.asyncio +async def test_ssh_pipe_with_echo_back_command(ssh_container_combined_auth): + container, host, port, username, password, _ = ssh_container_combined_auth + + print( + f"\nAttempting SshPipe connection to {host}:{port} as user {username} for echo-back test (using echo_back.py) with password authentication..." + ) + pipe = None + try: + echo_back_command = "python3 /usr/local/bin/echo_back.py" + pipe = await SshPipe.connect( + host=host, + port=port, + username=username, + password=password, + known_hosts=None, + timeout=10, + command=echo_back_command, + ) + + received_data_chunks = [] + received_event = asyncio.Event() + + def on_data_callback(data): + print(f"test_ssh_pipe_with_echo_back_command: Received data chunk: {data!r}") + received_data_chunks.append(data) + received_event.set() + + pipe.pipeline("data", on_data_callback) + + test_message = b"This is a test message for echo back\n" + print(f"Writing data to pipe: {test_message!r}") + await pipe.write(test_message) + await asyncio.sleep(1) + + try: + await asyncio.wait_for(received_event.wait(), timeout=10) + except asyncio.TimeoutError: + pytest.fail("Did not receive any data within 10 seconds.") + + full_received_data = b"".join(received_data_chunks) + assert ( + full_received_data.strip() == test_message.strip() + ), f"Received data {full_received_data!r} should exactly match sent data {test_message!r}" + print("Verification successful: Data echoed correctly by 'echo_back.py' script.") + await pipe.write_eof() + + except asyncssh.Error as e: + pytest.fail(f"SSH connection or command failed: {e}") + except asyncio.TimeoutError: + pytest.fail("SSH connection timed out.") + except Exception as e: + pytest.fail(f"An unexpected error occurred: {e}") + finally: + if pipe: + await pipe.terminate() + print("SshPipe closed.") + + +@pytest.mark.asyncio +async def test_ssh_pipe_with_binary_data(ssh_container_combined_auth): + container, host, port, username, password, _ = ssh_container_combined_auth + + print( + f"\nAttempting SshPipe connection to {host}:{port} as user {username} for binary data emitter test with password authentication..." + ) + pipe = None + try: + emitter_binary_command = "python3 /usr/local/bin/binary_emitter.py" + pipe = await SshPipe.connect( + host=host, + port=port, + username=username, + password=password, + known_hosts=None, + timeout=10, + command=emitter_binary_command, + ) + + received_event = asyncio.Event() + received_data_chunks = [] + expected_binary_pattern = b"\xde\xad\xbe\xef\x00\x01\x02\x03\x80\xff\x7f" + + def on_data_callback(data): + print(f"test_ssh_pipe_with_binary_data: Received data chunk: {data!r}") + cleaned_data = data.replace(b"\r", b"").replace(b"\n", b"") + received_data_chunks.append(cleaned_data) + if expected_binary_pattern in cleaned_data: + received_event.set() + + pipe.pipeline("data", on_data_callback) + + try: + await asyncio.wait_for(received_event.wait(), timeout=10) + except asyncio.TimeoutError: + pytest.fail("Did not receive expected binary pattern within 10 seconds.") + + full_received_data = b"".join(received_data_chunks) + print(f"Received total {len(full_received_data)} bytes. First 50 bytes: {full_received_data[:50]!r}") + + assert ( + expected_binary_pattern in full_received_data + ), f"Expected binary pattern {expected_binary_pattern!r} not found in received data {full_received_data!r}" + print("Verification successful: Binary data emitted and received correctly.") + + except asyncssh.Error as e: + pytest.fail(f"SSH connection or command failed: {e}") + except asyncio.TimeoutError: + pytest.fail("SSH connection timed out.") + except Exception as e: + pytest.fail(f"An unexpected error occurred: {e}") + finally: + if pipe: + await pipe.terminate() + print("SshPipe closed.") diff --git a/tests/ssh/test_ssh_pipe.py b/tests/ssh/test_ssh_pipe.py new file mode 100644 index 0000000..4b11c13 --- /dev/null +++ b/tests/ssh/test_ssh_pipe.py @@ -0,0 +1,43 @@ +import pytest + +from cbor_rpc.ssh.ssh_pipe import SshPipe +from tests.helpers.stream_pair import create_stream_pair + + +@pytest.mark.asyncio +async def test_ssh_pipe_terminate_and_write_eof(): + server, reader, writer = await create_stream_pair() + + class DummyChannel: + def __init__(self): + self._closed = False + + def is_closing(self) -> bool: + return self._closed + + def close(self) -> None: + self._closed = True + + async def wait_closed(self) -> None: + return None + + class DummyClient: + def __init__(self): + self._closed = False + + def is_closed(self) -> bool: + return self._closed + + def close(self) -> None: + self._closed = True + + async def wait_closed(self) -> None: + return None + + pipe = SshPipe(reader, writer, DummyClient(), DummyChannel()) + + await pipe.write_eof() + await pipe.terminate() + + server.close() + await server.wait_closed() diff --git a/tests/stdio/test_stdio_pipe.py b/tests/stdio/test_stdio_pipe.py new file mode 100644 index 0000000..4da2b9e --- /dev/null +++ b/tests/stdio/test_stdio_pipe.py @@ -0,0 +1,60 @@ +import asyncio +import sys + +import pytest + +from cbor_rpc.stdio.stdio_pipe import StdioPipe +from tests.helpers.stream_pair import create_stream_pair + + +@pytest.mark.asyncio +async def test_stdio_pipe_errors_without_process(): + server, reader, writer = await create_stream_pair() + pipe = StdioPipe(reader, writer) + + with pytest.raises(RuntimeError): + await pipe.wait_for_process_termination() + + # Should not raise + await pipe.terminate() + + writer.close() + await writer.wait_closed() + server.close() + await server.wait_closed() + + +@pytest.mark.asyncio +async def test_stdio_pipe_start_process_and_terminate(): + pipe = await StdioPipe.start_process(sys.executable, "-c", "import time; time.sleep(0.2)") + await pipe.terminate() + code = await pipe.wait_for_process_termination() + assert isinstance(code, int) + + +@pytest.mark.asyncio +async def test_stdio_pipe_read_write(): + pipe = await StdioPipe.start_process("/bin/bash", "-c", "cat -") + + received_data = [] + future = asyncio.Future() + + def on_data(data): + received_data.append(data) + if len(received_data) == 10: + future.set_result(None) + + pipe.pipeline("data", on_data) + + test_data = [f"Test data {i}\n".encode("utf-8") for i in range(10)] + for data in test_data: + await pipe.write(data) + await asyncio.sleep(0.01) + + await future + + assert len(received_data) == 10 + for i, (sent, received) in enumerate(zip(test_data, received_data)): + assert received == sent, f"Mismatch at index {i}: expected {sent!r}, got {received!r}" + + await pipe.terminate() diff --git a/tests/test_tcp.py b/tests/tcp/test_tcp.py similarity index 63% rename from tests/test_tcp.py rename to tests/tcp/test_tcp.py index e90123d..8c537da 100644 --- a/tests/test_tcp.py +++ b/tests/tcp/test_tcp.py @@ -1,352 +1,378 @@ -import pytest import asyncio from typing import List -from cbor_rpc import TcpPipe, TcpServer + +import pytest + +from cbor_rpc import TcpPipe +from tests.helpers.simple_tcp_server import SimpleTcpServer + +DEFAULT_TIMEOUT = 1.0 @pytest.mark.asyncio async def test_tcp_client_server_connection(): - """Test basic TCP client-server connection.""" - # Start a server - server = await TcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() - + connections = [] - + def on_connection(tcp_pipe: TcpPipe): connections.append(tcp_pipe) - + server.on_connection(on_connection) - + try: - # Create a client connection client = await TcpPipe.create_connection(server_host, server_port) - - # Wait for server to register the connection await asyncio.sleep(0.1) - + assert len(connections) == 1 assert client.is_connected() assert connections[0].is_connected() - - # Test peer info + client_peer = client.get_peer_info() server_conn_peer = connections[0].get_peer_info() - + assert client_peer == (server_host, server_port) assert server_conn_peer is not None - + await client.terminate() - + finally: - await server.close() + await server.stop() @pytest.mark.asyncio async def test_tcp_data_exchange(): - """Test bidirectional data exchange over TCP.""" - server = await TcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() - + server_received = [] client_received = [] server_connection = None - + async def on_connection(tcp_pipe: TcpPipe): nonlocal server_connection server_connection = tcp_pipe - + async def on_server_data(data: bytes): server_received.append(data) - + tcp_pipe.on("data", on_server_data) - + server.on_connection(on_connection) - + try: - # Create client client = await TcpPipe.create_connection(server_host, server_port) - + async def on_client_data(data: bytes): client_received.append(data) - + client.on("data", on_client_data) - - # Wait for connection to be established + await asyncio.sleep(0.1) assert server_connection is not None - - # Send data from client to server + await client.write(b"Hello from client") await asyncio.sleep(0.1) assert server_received == [b"Hello from client"] - - # Send data from server to client + await server_connection.write(b"Hello from server") await asyncio.sleep(0.1) assert client_received == [b"Hello from server"] - - # Send multiple messages separately + server_received.clear() client_received.clear() - + await client.write(b"Message 1") - await asyncio.sleep(0.05) # Small delay between messages + await asyncio.sleep(0.05) await client.write(b"Message 2") await asyncio.sleep(0.1) - + await server_connection.write(b"Response 1") - await asyncio.sleep(0.05) # Small delay between messages + await asyncio.sleep(0.05) await server_connection.write(b"Response 2") await asyncio.sleep(0.1) - - # Check that messages were received (they might be combined due to TCP buffering) + server_data = b"".join(server_received) client_data = b"".join(client_received) - + assert b"Message 1" in server_data assert b"Message 2" in server_data assert b"Response 1" in client_data assert b"Response 2" in client_data - + await client.terminate() - + finally: - await server.close() + await server.stop() @pytest.mark.asyncio async def test_tcp_connection_errors(): - """Test TCP connection error handling.""" - # Test connection to non-existent server with pytest.raises(ConnectionError): - await TcpPipe.create_connection('127.0.0.1', 12345, timeout=0.1) - - # Test writing to disconnected client + await TcpPipe.create_connection("127.0.0.1", 12345, timeout=0.1) + client = TcpPipe() with pytest.raises(ConnectionError): await client.write(b"test") - - # Test double connection - server = await TcpServer.create('127.0.0.1', 0) + + server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() - + try: client = await TcpPipe.create_connection(server_host, server_port) - + with pytest.raises(ConnectionError): await client.connect(server_host, server_port) - + await client.terminate() - + finally: - await server.close() + await server.stop() @pytest.mark.asyncio async def test_tcp_connection_events(): - """Test TCP connection events (connect, close, error).""" - server = await TcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() - + events = [] - - server.on_connection(lambda conn: events.append("server_connect")) - + + server.on_connection(lambda conn: events.append("server_connect")) + try: - # Create client with event handlers client = TcpPipe() - + async def on_client_connect(): events.append("client_connect") - + async def on_client_close(*args): events.append(("client_close", args)) - + async def on_client_error(error): events.append(("client_error", str(error))) - + client.on("connect", on_client_connect) client.on("close", on_client_close) client.on("error", on_client_error) - - # Connect + await client.connect(server_host, server_port) - await asyncio.sleep(0.2) # Give more time for events to propagate - + await asyncio.sleep(0.2) + assert "client_connect" in events - print(events) assert "server_connect" in events - - # Close connection + events_before_close = len(events) await client.terminate("test_reason") - await asyncio.sleep(0.2) # Give time for close events - - # Check that close event was added + await asyncio.sleep(0.2) + assert len(events) > events_before_close close_events = [e for e in events if isinstance(e, tuple) and e[0] == "client_close"] assert len(close_events) > 0 - + finally: - await server.close() + await server.stop() @pytest.mark.asyncio async def test_tcp_client_connection_tracking(): - """Test handling multiple simultaneous TCP connections.""" - server = await TcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() - - - + try: - # Create multiple clients clients: List[TcpPipe] = [] server.on_connection(lambda conn: print(f"New connection: {conn}")) for i in range(5): client = await TcpPipe.create_connection(server_host, server_port) client.on("close", lambda: print(f"Connection[{i}] closed")) clients.append(client) - - - await asyncio.sleep(0.5) # Give time for all connections to be registered - - # Check that all connections are registered + + await asyncio.sleep(0.5) + assert len(server.get_connections()) == 5 - - # Close all clients + for client in clients: await client.terminate() - - await asyncio.sleep(0.2) + await asyncio.sleep(0.2) assert len(server.get_connections()) == 0, "Connections not clean uped" - - finally: - await server.close() + await server.stop() + @pytest.mark.asyncio async def test_tcp_client_connection_tracking_self(): - """Test handling multiple simultaneous TCP connections.""" - server = await TcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() - - + try: - # Create multiple clients clients: List[TcpPipe] = [] server.on_connection(lambda conn: print(f"New connection: {conn}")) for i in range(5): client = await TcpPipe.create_connection(server_host, server_port) client.on("close", lambda: print(f"Connection[{i}] closed")) clients.append(client) - - - await asyncio.sleep(0.5) # Give time for all connections to be registered - - # Check that all connections are registered + + await asyncio.sleep(0.5) + assert len(server.get_connections()) == 5 - - # Close all clients + for duplex in server.get_connections(): await duplex.terminate() - await asyncio.sleep(0.2) + await asyncio.sleep(0.2) assert len(server.get_connections()) == 0, "Connections not clean uped" - - finally: - await server.close() + await server.stop() + @pytest.mark.asyncio async def test_tcp_large_data_transfer(): - """Test transferring large amounts of data over TCP.""" - server = await TcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() - + received_data = bytearray() server_connection = None - + async def on_connection(tcp_pipe: TcpPipe): nonlocal server_connection server_connection = tcp_pipe - + async def on_data(data: bytes): received_data.extend(data) - + tcp_pipe.on("data", on_data) - + server.on_connection(on_connection) - + try: client = await TcpPipe.create_connection(server_host, server_port) await asyncio.sleep(0.1) - - # Send large data (100KB instead of 1MB for faster testing) + large_data = b"x" * (100 * 1024 * 1024) await client.write(large_data) - - # Wait for all data to be received - timeout = 5.0 # 5 second timeout + + timeout = 5.0 start_time = asyncio.get_event_loop().time() while len(received_data) < len(large_data): if asyncio.get_event_loop().time() - start_time > timeout: break await asyncio.sleep(0.1) - + assert bytes(received_data) == large_data - + await client.terminate() - + finally: - await server.close() + await server.stop() @pytest.mark.asyncio async def test_tcp_server_context_manager(): - """Test using TcpServer as a context manager.""" - async with await TcpServer.create('127.0.0.1', 0) as server: + async with await SimpleTcpServer.create("127.0.0.1", 0) as server: server_host, server_port = server.get_address() - + client = await TcpPipe.create_connection(server_host, server_port) assert client.is_connected() - + await client.terminate() - - # Server should be closed automatically @pytest.mark.asyncio async def test_tcp_invalid_data_types(): - """Test error handling for invalid data types.""" - server = await TcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() - + try: client = await TcpPipe.create_connection(server_host, server_port) - - # Test writing non-bytes data + with pytest.raises(TypeError): - await client.write("string data") # Should be bytes - + await client.write("string data") + with pytest.raises(TypeError): - await client.write(123) # Should be bytes - - # Test writing valid data types - await client.write(b"bytes data") # Should work - await client.write(bytearray(b"bytearray data")) # Should work - + await client.write(123) + + await client.write(b"bytes data") + await client.write(bytearray(b"bytearray data")) + await client.terminate() - + + finally: + await server.stop() + + +@pytest.mark.asyncio +async def test_tcp_inmemory_pair_bidirectional_exchange(): + client_pipe, server_pipe = await TcpPipe.create_inmemory_pair() + + try: + assert client_pipe.is_connected() + assert server_pipe.is_connected() + + client_received = asyncio.Queue() + server_received = asyncio.Queue() + + client_pipe.on("data", client_received.put_nowait) + server_pipe.on("data", server_received.put_nowait) + + await client_pipe.write(b"ping") + server_data = await asyncio.wait_for(server_received.get(), timeout=DEFAULT_TIMEOUT) + assert server_data == b"ping" + + await server_pipe.write(b"pong") + client_data = await asyncio.wait_for(client_received.get(), timeout=DEFAULT_TIMEOUT) + assert client_data == b"pong" + finally: - await server.close() + await client_pipe.terminate() + await server_pipe.terminate() + +@pytest.mark.asyncio +async def test_tcp_shutdown_keeps_active_connections(): + server = await SimpleTcpServer.create("127.0.0.1", 0) + server_host, server_port = server.get_address() -if __name__ == "__main__": - pytest.main(["-v", __file__]) + server_connection = None + + async def on_connection(tcp_pipe: TcpPipe): + nonlocal server_connection + server_connection = tcp_pipe + + server.on_connection(on_connection) + + try: + client = await TcpPipe.create_connection(server_host, server_port) + await asyncio.wait_for(asyncio.sleep(0.1), timeout=DEFAULT_TIMEOUT) + assert server_connection is not None + + await server.shutdown() + + client_received = asyncio.Queue() + server_received = asyncio.Queue() + + client.on("data", client_received.put_nowait) + server_connection.on("data", server_received.put_nowait) + + await client.write(b"still-alive") + server_data = await asyncio.wait_for(server_received.get(), timeout=DEFAULT_TIMEOUT) + assert server_data == b"still-alive" + + await server_connection.write(b"still-alive-2") + client_data = await asyncio.wait_for(client_received.get(), timeout=DEFAULT_TIMEOUT) + assert client_data == b"still-alive-2" + + with pytest.raises(ConnectionError) as exc_info: + await TcpPipe.create_connection(server_host, server_port, timeout=0.2) + error_text = str(exc_info.value).lower() + assert "refused" in error_text or "connect call failed" in error_text + + await client.terminate() + await server_connection.terminate() + + finally: + await server.stop() diff --git a/tests/tcp/test_tcp_pipe_errors.py b/tests/tcp/test_tcp_pipe_errors.py new file mode 100644 index 0000000..43d1940 --- /dev/null +++ b/tests/tcp/test_tcp_pipe_errors.py @@ -0,0 +1,33 @@ +import pytest + +from cbor_rpc.tcp.tcp import TcpPipe +from tests.helpers.simple_tcp_server import SimpleTcpServer + + +@pytest.mark.asyncio +async def test_tcp_pipe_error_paths(): + assert TcpPipe().get_peer_info() is None + assert TcpPipe().get_local_info() is None + + class BadWriter: + def get_extra_info(self, _name: str): + raise RuntimeError("boom") + + pipe = TcpPipe() + pipe._connected = True + pipe._writer = BadWriter() + assert pipe.get_peer_info() is None + assert pipe.get_local_info() is None + + server = await SimpleTcpServer.create("127.0.0.1", 0) + host, port = server.get_address() + await server.stop() + + with pytest.raises(ConnectionError): + await TcpPipe.create_connection(host, port) + + client, _server_pipe = await TcpPipe.create_inmemory_pair() + with pytest.raises(ConnectionError): + await client.connect("127.0.0.1", port) + + await client.terminate() diff --git a/tests/test_async_pipe.py b/tests/test_async_pipe.py deleted file mode 100644 index 03037e0..0000000 --- a/tests/test_async_pipe.py +++ /dev/null @@ -1,83 +0,0 @@ -import pytest -import asyncio -from typing import Any, Tuple -from cbor_rpc import Pipe - -class SimplePipe(Pipe): - def __init__(self): - super().__init__() - self._closed = False - - async def write(self, chunk: Any) -> bool: - if self._closed: - return False - # Simulate writing data asynchronously - await asyncio.sleep(0.1) - return True - - async def terminate(self, *args: Any) -> None: - if self._closed: - return - self._closed = True - await self._emit("close", *args) - -@pytest.fixture -def pipe(): - return SimplePipe() - -@pytest.mark.asyncio -async def test_create_pair(): - # Positive case: Creating a pair of async pipes - pipe1, pipe2 = Pipe.create_pair() - assert isinstance(pipe1, Pipe) - assert isinstance(pipe2, Pipe) - -@pytest.mark.asyncio -async def test_write_success(pipe): - # Positive case: Writing a chunk successfully - result = await pipe.write("test_chunk") - assert result is True - -@pytest.mark.asyncio -async def test_terminate_success(pipe): - # Positive case: Terminating the pipe - await pipe.terminate() - # No exception should be raised - -@pytest.mark.asyncio -async def test_pipeline_execution(pipe): - # Positive case: Adding and executing a pipeline - called = False - async def pipeline_handler(chunk: Any) -> None: - nonlocal called - called = True - - pipe.pipeline("data", pipeline_handler) - await pipe._notify("data", "test_chunk") - assert called is True - -@pytest.mark.asyncio -async def test_attach_pipes(): - # Positive case: Attaching two pipes - pipe1, pipe2 = Pipe.create_pair() - - called = False - async def handler(chunk: Any) -> None: - nonlocal called - called = True - - pipe2.on("data", handler) - await pipe1.write("test_chunk") - assert called is True - -@pytest.mark.asyncio -async def test_write_after_terminate(): - # Negative case: Writing to a terminated pipe - pipe1, _ = Pipe.create_pair() - await pipe1.terminate() - - result = await pipe1.write("test_chunk") - assert result is False - -if __name__ == "__main__": - pytest.main() diff --git a/tests/test_event_emitter.py b/tests/test_event_emitter.py deleted file mode 100644 index 1573bbd..0000000 --- a/tests/test_event_emitter.py +++ /dev/null @@ -1,368 +0,0 @@ -import pytest -import asyncio -from typing import Any, Callable -from cbor_rpc.emitter import AbstractEmitter - -@pytest.mark.asyncio -async def test_on_and_emit(): - """Test that subscribers registered with 'on' are called by '_emit' in registration order.""" - class TestEmitter(AbstractEmitter): - pass - - emitter = TestEmitter() - events = [] - - async def async_handler1(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_handler1_{data}") - - def sync_handler1(data: Any): - events.append(f"sync_handler1_{data}") - - async def async_handler2(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_handler2_{data}") - - # Register subscribers - emitter.on("test", async_handler1) - emitter.on("test", sync_handler1) - emitter.on("test", async_handler2) - - # Emit event - await emitter._emit("test", "event1") - await asyncio.sleep(0.02) # Allow async handlers to complete - - # Verify all subscribers ran (order may vary due to concurrency) - expected = [ - f"async_handler1_event1", - f"sync_handler1_event1", - f"async_handler2_event1" - ] - assert sorted(events) == sorted(expected), f"Expected {expected}, got {events}" - -@pytest.mark.asyncio -async def test_pipeline_and_notify(): - """Test that pipelines run before subscribers in '_notify', respecting registration order.""" - class TestEmitter(AbstractEmitter): - pass - - emitter = TestEmitter() - events = [] - - async def async_handler1(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_handler1_{data}") - - def sync_handler1(data: Any): - events.append(f"sync_handler1_{data}") - - async def async_pipeline1(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_pipeline1_{data}") - - def sync_pipeline1(data: Any): - events.append(f"sync_pipeline1_{data}") - - async def async_pipeline2(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_pipeline2_{data}") - - # Register handlers and pipelines - emitter.on("test", async_handler1) - emitter.on("test", sync_handler1) - emitter.pipeline("test", async_pipeline1) - emitter.pipeline("test", sync_pipeline1) - emitter.pipeline("test", async_pipeline2) - - # Notify event - await emitter._notify("test", "event2") - await asyncio.sleep(0.02) # Allow async pipelines and handlers to complete - - # Verify pipelines run before subscribers - expected_pipelines = [ - f"async_pipeline1_event2", - f"sync_pipeline1_event2", - f"async_pipeline2_event2" - ] - expected_subscribers = [ - f"async_handler1_event2", - f"sync_handler1_event2" - ] - pipeline_indices = [events.index(e) for e in expected_pipelines if e in events] - subscriber_indices = [events.index(e) for e in expected_subscribers if e in events] - assert all(p < s for p in pipeline_indices for s in subscriber_indices), \ - f"Pipelines {expected_pipelines} should precede subscribers {expected_subscribers} in {events}" - assert sorted(events) == sorted(expected_pipelines + expected_subscribers), \ - f"Expected {expected_pipelines + expected_subscribers}, got {events}" - -@pytest.mark.asyncio -async def test_unsubscribe(): - """Test that unsubscribing a handler removes it from the subscriber list.""" - class TestEmitter(AbstractEmitter): - pass - - emitter = TestEmitter() - events = [] - - async def async_handler1(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_handler1_{data}") - - def sync_handler1(data: Any): - events.append(f"sync_handler1_{data}") - - # Register and unsubscribe - emitter.on("test", async_handler1) - emitter.on("test", sync_handler1) - emitter.unsubscribe("test", async_handler1) - - # Emit event - await emitter._emit("test", "event3") - await asyncio.sleep(0.02) # Allow async handlers to complete (none in this case) - - # Verify only remaining subscriber ran - expected = [f"sync_handler1_event3"] - assert events == expected, f"Expected {expected}, got {events}" - -@pytest.mark.asyncio -async def test_replace_on_handler(): - """Test that replace_on_handler sets only the new handler for the event.""" - class TestEmitter(AbstractEmitter): - pass - - emitter = TestEmitter() - events = [] - - async def async_handler1(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_handler1_{data}") - - def sync_handler1(data: Any): - events.append(f"sync_handler1_{data}") - - # Register and replace - emitter.on("test", sync_handler1) - emitter.replace_on_handler("test", async_handler1) - - # Emit event - await emitter._emit("test", "event4") - await asyncio.sleep(0.02) # Allow async handler to complete - - # Verify only the replaced handler ran - expected = [f"async_handler1_event4"] - assert events == expected, f"Expected {expected}, got {events}" - -@pytest.mark.asyncio -async def test_pipeline_failure(): - """Test that '_notify' raises an exception if a pipeline fails and doesn't call subscribers.""" - class TestEmitter(AbstractEmitter): - pass - - emitter = TestEmitter() - events = [] - - async def async_pipeline1(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_pipeline1_{data}") - raise ValueError("Pipeline failed") - - def sync_handler1(data: Any): - events.append(f"sync_handler1_{data}") - - # Register pipeline and subscriber - emitter.pipeline("test", async_pipeline1) - emitter.on("test", sync_handler1) - - # Notify with failing pipeline - try: - await emitter._notify("test", "event5") - assert False, "Exception not thrown" - except ValueError: - pass - - # Verify only the pipeline ran, not the subscriber - expected = [f"async_pipeline1_event5"] - assert events == expected, f"Expected {expected}, got {events}" - -@pytest.mark.asyncio -async def test_multiple_event_types(): - """Test that only handlers for the triggered event type are called.""" - class TestEmitter(AbstractEmitter): - pass - - emitter = TestEmitter() - events = [] - - # Handlers for event_a - async def async_handler_a(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_handler_a_{data}") - - def sync_handler_a(data: Any): - events.append(f"sync_handler_a_{data}") - - # Handlers for event_b - async def async_handler_b(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_handler_b_{data}") - - def sync_handler_b(data: Any): - events.append(f"sync_handler_b_{data}") - - # Pipelines for event_a - async def async_pipeline_a(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_pipeline_a_{data}") - - # Pipelines for event_b - def sync_pipeline_b(data: Any): - events.append(f"sync_pipeline_b_{data}") - - # Register handlers and pipelines for different event types - emitter.on("event_a", async_handler_a) - emitter.on("event_a", sync_handler_a) - emitter.pipeline("event_a", async_pipeline_a) - emitter.on("event_b", async_handler_b) - emitter.on("event_b", sync_handler_b) - emitter.pipeline("event_b", sync_pipeline_b) - - # Test _emit for event_a - events.clear() - await emitter._emit("event_a", "data_a") - await asyncio.sleep(0.02) # Allow async handlers to complete - expected = [f"async_handler_a_data_a", f"sync_handler_a_data_a"] - assert sorted(events) == sorted(expected), f"Expected {expected}, got {events}" - - # Test _notify for event_a - events.clear() - await emitter._notify("event_a", "data_a2") - await asyncio.sleep(0.02) # Allow async pipelines and handlers to complete - expected_pipelines = [f"async_pipeline_a_data_a2"] - expected_subscribers = [f"async_handler_a_data_a2", f"sync_handler_a_data_a2"] - pipeline_indices = [events.index(e) for e in expected_pipelines if e in events] - subscriber_indices = [events.index(e) for e in expected_subscribers if e in events] - assert all(p < s for p in pipeline_indices for s in subscriber_indices), \ - f"Pipelines {expected_pipelines} should precede subscribers {expected_subscribers} in {events}" - assert sorted(events) == sorted(expected_pipelines + expected_subscribers), \ - f"Expected {expected_pipelines + expected_subscribers}, got {events}" - - # Test _emit for event_b - events.clear() - await emitter._emit("event_b", "data_b") - await asyncio.sleep(0.02) # Allow async handlers to complete - expected = [f"async_handler_b_data_b", f"sync_handler_b_data_b"] - assert sorted(events) == sorted(expected), f"Expected {expected}, got {events}" - - # Test _notify for event_b - events.clear() - await emitter._notify("event_b", "data_b2") - await asyncio.sleep(0.02) # Allow async handlers to complete - expected_pipelines = [f"sync_pipeline_b_data_b2"] - expected_subscribers = [f"async_handler_b_data_b2", f"sync_handler_b_data_b2"] - pipeline_indices = [events.index(e) for e in expected_pipelines if e in events] - subscriber_indices = [events.index(e) for e in expected_subscribers if e in events] - assert all(p < s for p in pipeline_indices for s in subscriber_indices), \ - f"Pipelines {expected_pipelines} should precede subscribers {expected_subscribers} in {events}" - assert sorted(events) == sorted(expected_pipelines + expected_subscribers), \ - f"Expected {expected_pipelines + expected_subscribers}, got {events}" - -@pytest.mark.asyncio -async def test_background_task_failure(): - """Test that background task failures in '_emit' don't affect other subscribers.""" - class TestEmitter(AbstractEmitter): - pass - - emitter = TestEmitter() - events = [] - - async def async_handler1(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_handler1_{data}") - - async def async_handler2(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_handler2_{data}") - raise ValueError("async_handler2 failed") - - def sync_handler(data: Any): - events.append(f"sync_handler_{data}") - - # Register subscribers, including one that fails - emitter.on("test", async_handler1) - emitter.on("test", async_handler2) - emitter.on("test", sync_handler) - - # Emit event - await emitter._emit("test", "event6") - await asyncio.sleep(0.02) # Allow async handlers to complete - - # Verify all subscribers ran despite the failure - expected = [ - f"async_handler1_event6", - f"async_handler2_event6", - f"sync_handler_event6" - ] - assert sorted(events) == sorted(expected), f"Expected {expected}, got {events}" - -@pytest.mark.asyncio -async def test_slow_emit_does_not_block_notify(): - """Test that a slow handler in _emit does not block a subsequent _notify call.""" - - class TestEmitter(AbstractEmitter): - pass - - emitter = TestEmitter() - events = [] - - # Slow async handler for _emit - async def slow_handler(data: Any): - await asyncio.sleep(0.8) - events.append(f"slow_handler_{data}") - - # Fast handler for _notify - def fast_notify_handler(data: Any): - events.append(f"fast_notify_handler_{data}") - - def fast_notify_pipeline(data: Any): - events.append(f"fast_notify_pipeline_{data}") - - # Register handlers - emitter.on("test_emit", slow_handler) - emitter.on("test_notify", fast_notify_handler) - emitter.pipeline("test_notify", fast_notify_pipeline) - - # Start _emit but don't wait for it to finish - asyncio.create_task(emitter._emit("test_emit", "data_emit")) - - # Wait briefly before triggering _notify - await asyncio.sleep(0.1) - - # Trigger and await _notify - await emitter._notify("test_notify", "data_notify") - - # Give time for _notify to complete before slow_handler finishes - await asyncio.sleep(0.4) - - # βœ… Confirm that notify events are already present - assert "fast_notify_pipeline_data_notify" in events - assert "fast_notify_handler_data_notify" in events - assert "slow_handler_data_emit" not in events # not yet complete - - # Wait for slow handler to complete - await asyncio.sleep(0.5) - - # Now validate the full event order - slow_index = events.index("slow_handler_data_emit") - pipeline_index = events.index("fast_notify_pipeline_data_notify") - handler_index = events.index("fast_notify_handler_data_notify") - - assert pipeline_index < slow_index and handler_index < slow_index, ( - f"_notify handlers [{pipeline_index}, {handler_index}] should run before slow _emit [{slow_index}]" - ) - - expected = { - "fast_notify_pipeline_data_notify", - "fast_notify_handler_data_notify", - "slow_handler_data_emit" - } - assert set(events) == expected, f"Expected {expected}, got {set(events)}" diff --git a/tests/test_json_transformer.py b/tests/test_json_transformer.py deleted file mode 100644 index d035158..0000000 --- a/tests/test_json_transformer.py +++ /dev/null @@ -1,238 +0,0 @@ -import pytest -import asyncio -import json -from typing import Any, Dict, List -from cbor_rpc import JsonTransformer -from cbor_rpc import Pipe - -@pytest.mark.asyncio -async def test_json_transformer_basic_encoding_decoding(): - """Test basic JSON encoding and decoding.""" - pipe1, pipe2 = Pipe.create_pair() - transformer = JsonTransformer(pipe1) - - received_data = [] - transformer.on("data", lambda chunk: received_data.append(chunk)) - - # Test encoding: write JSON data through pipe2 (will be received by transformer) - test_data = {"message": "hello", "number": 42, "array": [1, 2, 3]} - json_bytes = json.dumps(test_data).encode('utf-8') - await pipe2.write(json_bytes) - - # Wait for data to be processed - await asyncio.sleep(0.1) - assert test_data == received_data[0] if received_data else None - - # Test decoding: write data through transformer (will be encoded) - received_raw = [] - pipe2.on("data", lambda chunk: received_raw.append(chunk)) - await transformer.write({"response": "world", "success": True}) - - # Read the encoded response from pipe2 - raw_response = received_raw[0] if received_raw else None - decoded_response = json.loads(raw_response.decode('utf-8')) if raw_response else None - assert decoded_response == {"response": "world", "success": True} - -@pytest.mark.asyncio -async def test_json_transformer_create_pair(): - """Test creating a pair of JSON transformers.""" - transformer1, transformer2 = JsonTransformer.create_pair() - - # Test communication from transformer1 to transformer2 - received_data = [] - transformer2.on("data", lambda chunk: received_data.append(chunk)) - test_data = {"message": "Hello World", "timestamp": 1234567890} - await transformer1.write(test_data) - - assert received_data == [test_data] - - # Test communication from transformer2 to transformer1 - received_data.clear() - transformer1.on("data", lambda chunk: received_data.append(chunk)) - response_data = {"reply": "Hello Back", "status": "ok"} - await transformer2.write(response_data) - - assert received_data == [response_data] - -@pytest.mark.asyncio -async def test_json_transformer_different_data_types(): - """Test JSON transformer with different Python data types.""" - transformer1, transformer2 = JsonTransformer.create_pair() - received_data = [] - transformer2.on("data", lambda chunk: received_data.append(chunk)) - - # Test various data types - test_cases = [ - "simple string", - 42, - 3.14159, - True, - False, - None, - [1, 2, 3, "mixed", True], - {"nested": {"object": {"with": "values"}}}, - {"unicode": "Hello δΈ–η•Œ 🌍"}, - [], - {}, - ] - - for test_data in test_cases: - await transformer1.write(test_data) - - assert received_data == test_cases - -@pytest.mark.asyncio -async def test_json_transformer_encoding_errors(): - """Test JSON transformer encoding error handling.""" - pipe1, pipe2 = Pipe.create_pair() - transformer = JsonTransformer(pipe1) - errors = [] - transformer.on("error", lambda err: errors.append(str(err))) - - # Test encoding non-serializable object - class NonSerializable: - pass - - result = await transformer.write(NonSerializable()) - assert result is False # Write should return False on error - assert len(errors) == 1 - - # Test encoding circular reference - circular = {} - circular['self'] = circular - result = await transformer.write(circular) - assert result is False - assert len(errors) == 2 - -@pytest.mark.asyncio -async def test_json_transformer_decoding_errors(): - """Test JSON transformer decoding error handling.""" - pipe1, pipe2 = Pipe.create_pair() - transformer = JsonTransformer(pipe1) - errors = [] - transformer.on("error", lambda err: errors.append(str(err))) - - # Test invalid JSON - await pipe2.write(b'{"invalid": json}') - assert len(errors) == 1 - - # Test invalid UTF-8 bytes - await pipe2.write(b'\xff\xfe\xfd') - assert len(errors) == 2 - - # Test wrong data type - await pipe2.write(123) # Not bytes or string - assert len(errors) == 3 - -@pytest.mark.asyncio -async def test_json_transformer_string_input(): - """Test JSON transformer with string input (not just bytes).""" - pipe1, pipe2 = Pipe.create_pair() - transformer = JsonTransformer(pipe1) - received_data = [] - transformer.on("data", lambda chunk: received_data.append(chunk)) - - # Test with JSON string input - test_data = {"message": "from string", "value": 123} - json_string = json.dumps(test_data) - await pipe2.write(json_string) - - assert received_data == [test_data] - -@pytest.mark.asyncio -async def test_json_transformer_custom_encoding(): - """Test JSON transformer with custom text encoding.""" - pipe1, pipe2 = Pipe.create_pair() - transformer = JsonTransformer(pipe1, encoding='latin1') - received_data = [] - transformer.on("data", lambda chunk: received_data.append(chunk)) - - # Test with latin1 encoding - test_data = {"message": "cafΓ©"} # Contains non-ASCII character - json_bytes = json.dumps(test_data).encode('latin1') - await pipe2.write(json_bytes) - - assert received_data == [test_data] - -@pytest.mark.asyncio -async def test_json_transformer_termination(): - """Test JSON transformer termination.""" - pipe1, pipe2 = Pipe.create_pair() - transformer = JsonTransformer(pipe1) - close_events = [] - transformer.on("close", lambda *args: close_events.append(args)) - - await transformer.terminate("test_reason") - assert len(close_events) == 1 - assert close_events[0] == ("test_reason",) - -@pytest.mark.asyncio -async def test_json_transformer_large_data(): - """Test JSON transformer with large data structures.""" - transformer1, transformer2 = JsonTransformer.create_pair() - received_data = [] - transformer2.on("data", lambda chunk: received_data.append(chunk)) - - # Create a large nested data structure - large_data = { - "users": [ - { - "id": i, - "name": f"User {i}", - "email": f"user{i}@example.com", - "metadata": { - "created": f"2024-01-{i:02d}", - "tags": [f"tag{j}" for j in range(5)], - "settings": { - "theme": "dark" if i % 2 else "light", - "notifications": True, - "features": [f"feature{k}" for k in range(3)] - } - } - } - for i in range(10) - ] - } - - await transformer1.write(large_data) - assert len(received_data) == 1 - assert received_data[0] == large_data - -@pytest.mark.asyncio -async def test_json_transformer_concurrent_operations(): - """Test JSON transformer with concurrent read/write operations.""" - transformer1, transformer2 = JsonTransformer.create_pair() - received_data = [] - transformer2.on("data", lambda chunk: received_data.append(chunk)) - - # Send multiple messages concurrently - async def send_message(id: int): - await transformer1.write({"id": id, "message": f"Message {id}"}) - - tasks = [send_message(i) for i in range(5)] - await asyncio.gather(*tasks) - - assert len(received_data) == 5 - -@pytest.mark.asyncio -async def test_json_transformer_error_recovery(): - """Test that JSON transformer can recover from errors.""" - pipe1, pipe2 = Pipe.create_pair() - transformer = JsonTransformer(pipe1) - received_data = [] - errors = [] - transformer.on("data", lambda chunk: received_data.append(chunk)) - transformer.on("error", lambda err: errors.append(str(err))) - - # Send invalid JSON - await pipe2.write(b'invalid json') - - # Send valid JSON after error - valid_data = {"message": "recovery test"} - await pipe2.write(json.dumps(valid_data).encode('utf-8')) - - assert len(received_data) == 1 - assert received_data[0] == valid_data - -if __name__ == "__main__": - pytest.main(["-v", __file__]) diff --git a/tests/test_server_generics.py b/tests/test_server_generics.py deleted file mode 100644 index 6d86474..0000000 --- a/tests/test_server_generics.py +++ /dev/null @@ -1,195 +0,0 @@ -""" -Tests for the generic Server class and its type safety. -""" - -import pytest -import asyncio -from typing import Any, Set -from cbor_rpc import Server, Pipe, TcpServer, TcpPipe - - -class MockPipe(Pipe[str, str]): - """A mock pipe for testing.""" - - def __init__(self, name: str): - super().__init__() - self.name = name - self._closed = False - - async def write(self, chunk: str) -> bool: - if self._closed: - return False - await self._emit("data", chunk) - return True - - async def terminate(self, *args: Any) -> None: - if self._closed: - return - self._closed = True - await self._emit("close", *args) - - -class MockServer(Server[MockPipe]): - """A mock server for testing generic functionality.""" - - def __init__(self): - super().__init__() - self._started = False - - async def start(self) -> str: - self._started = True - self._running = True - return "mock-server-started" - - async def stop(self) -> None: - if not self._running: - return - self._started = False - self._running = False - await self.close_all_connections() - - async def add_mock_connection(self, pipe: MockPipe) -> None: - """Add a mock connection for testing.""" - await self._add_connection(pipe) - - -@pytest.mark.asyncio -async def test_generic_server_typing(): - """Test that Server generic typing works correctly.""" - server: Server[MockPipe] = MockServer() - - # Test server initialization - assert not server.is_running() - assert len(server.get_connections()) == 0 - - # Start server - result = await server.start() - assert result == "mock-server-started" - assert server.is_running() - - # Test connection handling - connection_events = [] - - def on_connection(pipe: MockPipe): - connection_events.append(pipe) - assert isinstance(pipe, MockPipe) - - server.on_connection(on_connection) - - # Add mock connections - pipe1 = MockPipe("test-pipe-1") - pipe2 = MockPipe("test-pipe-2") - - await server.add_mock_connection(pipe1) - await server.add_mock_connection(pipe2) - - # Verify connections were added - assert len(connection_events) == 2 - assert connection_events[0] == pipe1 - assert connection_events[1] == pipe2 - - # Verify get_connections returns correct type - connections: Set[MockPipe] = server.get_connections() - assert len(connections) == 2 - assert pipe1 in connections - assert pipe2 in connections - - # Test connection cleanup on close - await pipe1.terminate() - await asyncio.sleep(0.01) # Allow cleanup - - connections = server.get_connections() - assert len(connections) == 1 - assert pipe2 in connections - assert pipe1 not in connections - - # Stop server - await server.stop() - assert not server.is_running() - assert len(server.get_connections()) == 0 - - -@pytest.mark.asyncio -async def test_tcp_server_generic_typing(): - """Test that TcpServer properly implements Server[TcpPipe].""" - # Create TCP server - tcp_server: Server[TcpPipe] = await TcpServer.create('127.0.0.1', 0) - - # Verify it's the correct type - assert isinstance(tcp_server, Server) - assert isinstance(tcp_server, TcpServer) - - # Test connection event typing - connection_events = [] - - def on_tcp_connection(pipe: TcpPipe): - connection_events.append(pipe) - # Verify the pipe is the correct type - assert isinstance(pipe, TcpPipe) - assert isinstance(pipe, Pipe) - - tcp_server.on_connection(on_tcp_connection) - - try: - # Create client connection - host, port = tcp_server.get_address() - client = await TcpPipe.create_connection(host, port) - - await asyncio.sleep(0.1) - - # Verify connection event was emitted with correct type - assert len(connection_events) == 1 - assert isinstance(connection_events[0], TcpPipe) - - # Verify get_connections returns TcpPipe instances - connections: Set[TcpPipe] = tcp_server.get_connections() - assert len(connections) == 1 - for conn in connections: - assert isinstance(conn, TcpPipe) - - await client.terminate() - - finally: - await tcp_server.stop() - - -@pytest.mark.asyncio -async def test_server_context_manager(): - """Test server context manager functionality.""" - async with MockServer() as server: - await server.start() - assert server.is_running() - - pipe = MockPipe("context-test") - await server.add_mock_connection(pipe) - assert len(server.get_connections()) == 1 - - # Server should be stopped automatically - assert not server.is_running() - - -@pytest.mark.asyncio -async def test_server_close_all_connections(): - """Test that close_all_connections properly closes all connections.""" - server = MockServer() - await server.start() - - # Add multiple connections - pipes = [MockPipe(f"pipe-{i}") for i in range(3)] - for pipe in pipes: - await server.add_mock_connection(pipe) - - assert len(server.get_connections()) == 3 - - # Close all connections - await server.close_all_connections() - await asyncio.sleep(0.01) # Allow cleanup - - # All connections should be closed - assert len(server.get_connections()) == 0 - - await server.stop() - - -if __name__ == "__main__": - pytest.main(["-v", __file__]) diff --git a/tests/test_sync_pipe.py b/tests/test_sync_pipe.py deleted file mode 100644 index 6dbd15a..0000000 --- a/tests/test_sync_pipe.py +++ /dev/null @@ -1,61 +0,0 @@ -import pytest -from typing import Any, Tuple -from cbor_rpc.sync_pipe import SyncPipe - -def test_create_pair(): - # Positive case: Creating a pair of sync pipes - pipe1, pipe2 = SyncPipe.create_pair() - assert isinstance(pipe1, SyncPipe) - assert isinstance(pipe2, SyncPipe) - -def test_write_read(): - # Positive case: Writing and reading a chunk successfully - pipe1, pipe2 = SyncPipe.create_pair() - - assert pipe1.write("test_chunk") is True - assert pipe2.read() == "test_chunk" - -def test_close_pipe(): - # Positive case: Closing the pipe - pipe1, pipe2 = SyncPipe.create_pair() - pipe1.close() - - with pytest.raises(Exception): - pipe1.read() - - assert pipe2.is_closed() is True - -def test_write_after_close(): - # Negative case: Writing to a closed pipe - pipe1, pipe2 = SyncPipe.create_pair() - pipe1.close() - - assert pipe1.write("test_chunk") is False - -def test_read_timeout(): - # Positive case: Reading with timeout - pipe1, _ = SyncPipe.create_pair() - - assert pipe1.read(timeout=0.1) is None - -def test_bidirectional_communication(): - # Positive case: Bidirectional communication between pipes - pipe1, pipe2 = SyncPipe.create_pair() - - assert pipe1.write("test_chunk") is True - assert pipe2.read() == "test_chunk" - - assert pipe2.write("response_chunk") is True - assert pipe1.read() == "response_chunk" - -def test_queue_size(): - # Positive case: Checking queue size - pipe1, _ = SyncPipe.create_pair() - - pipe1.write("chunk1") - pipe1.write("chunk2") - - assert pipe1.available() == 2 - -if __name__ == "__main__": - pytest.main() diff --git a/tests/test_transformer.py b/tests/test_transformer.py deleted file mode 100644 index 37baffa..0000000 --- a/tests/test_transformer.py +++ /dev/null @@ -1,120 +0,0 @@ -import pytest -import asyncio -from typing import Any, Dict, List -from cbor_rpc.async_pipe import Pipe -from cbor_rpc import Transformer -from cbor_rpc import SyncPipe -from cbor_rpc import AbstractEmitter - -# Existing tests... - -@pytest.mark.asyncio -async def test_async_transformer_basic(): - """Test basic asynchronous transformer functionality.""" - pipe1, pipe2 = Pipe.create_pair() - - class MockTransformer(Transformer[str, str], AbstractEmitter): - async def encode(self, data: str) -> str: - return f"encoded_{data}" - - async def decode(self, data: Any) -> str: - if isinstance(data, str) and data.startswith("encoded_"): - return data[8:] # Remove "encoded_" prefix - raise ValueError("Invalid format") - - transformer = MockTransformer(pipe1) - - # Set up event handler on pipe2 to receive data - received_data = None - - def handler(data: str) -> None: - nonlocal received_data - received_data = data - - pipe2.on("data", handler) - - # Write through transformer, should be received by pipe2 - assert await transformer.write("test_data") is True - await asyncio.sleep(0.1) # Give time for events to propagate - assert received_data == "encoded_test_data" - -@pytest.mark.asyncio -async def test_sync_transformer_basic(): - """Test basic synchronous transformer functionality.""" - pipe1, pipe2 = SyncPipe.create_pair() - - class MockSyncTransformer(Transformer[str, str]): - def encode_sync(self, data: str) -> str: - return f"encoded_{data}" - - def decode_sync(self, data: Any) -> str: - if isinstance(data, str) and data.startswith("encoded_"): - return data[8:] # Remove "encoded_" prefix - raise ValueError("Invalid format") - - transformer = MockSyncTransformer(pipe1) - - # Write through transformer, should be readable from pipe2 - assert transformer.write_sync("test_data") is True - - # Directly read from the connected sync pipe to verify data transfer - encoded_data = pipe2.read(timeout=1.0) - decoded_data = transformer.decode_sync(encoded_data) if encoded_data else None - assert decoded_data == "test_data" - -@pytest.mark.asyncio -async def test_transformer_close_propagation(): - """Test close propagation in transformers.""" - pipe1, pipe2 = Pipe.create_pair() - - class MockTransformer(Transformer[str, str]): - async def encode(self, data: str) -> str: - return f"encoded_{data}" - - async def decode(self, data: Any) -> str: - if isinstance(data, str) and data.startswith("encoded_"): - return data[8:] # Remove "encoded_" prefix - raise ValueError("Invalid format") - - transformer = MockTransformer(pipe1) - - # Close transformer and verify pipe is also closed - await transformer.terminate() - - # Try to write after closing - should fail - assert await transformer.write("test_data") is False - - # Verify pipe is closed by trying to write directly (should fail) - assert await pipe1.write("raw_data") is False - -@pytest.mark.asyncio -async def test_transformer_exception_handling(): - """Test exception handling in transformers.""" - pipe1, pipe2 = Pipe.create_pair() - - class FaultyTransformer(Transformer[str, str], AbstractEmitter): - async def encode(self, data: str) -> str: - raise ValueError("Encoding error") - - async def decode(self, data: Any) -> str: - if isinstance(data, str) and data.startswith("encoded_"): - return data[8:] # Remove "encoded_" prefix - raise ValueError("Invalid format") - - transformer = FaultyTransformer(pipe1) - - # Set up error handler to verify exception is caught on the transformer itself - error_caught = False - - def on_error(err: Exception) -> None: - nonlocal error_caught - error_caught = True - - transformer.on("error", on_error) - - # Write through transformer - should trigger encoding error - assert await transformer.write("test_data") is False - assert error_caught is True - -if __name__ == "__main__": - pytest.main() diff --git a/tests/transformer/test_cbor_transformer.py b/tests/transformer/test_cbor_transformer.py new file mode 100644 index 0000000..5b57a3b --- /dev/null +++ b/tests/transformer/test_cbor_transformer.py @@ -0,0 +1,377 @@ +import asyncio + +import cbor2 +import pytest +import pytest_asyncio + +from cbor_rpc.pipe.event_pipe import EventPipe +from cbor_rpc.tcp.tcp import TcpPipe +from cbor_rpc.transformer.base.base_exception import NeedsMoreDataException +from cbor_rpc.transformer.base.event_transformer_pipe import EventTransformerPipe +from cbor_rpc.transformer.cbor_transformer import CborTransformer, CborStreamTransformer +from tests.helpers.timeout_queue import TimeoutQueue + + +@pytest.fixture( + params=[ + (EventPipe.create_inmemory_pair, "InmemoryPipe"), + (TcpPipe.create_inmemory_pair, "TcpPipe"), + ], + ids=lambda param: param[1], +) +async def pipe_pair(request): + create_pair_func, _ = request.param + if asyncio.iscoroutinefunction(create_pair_func): + client_pipe, server_pipe = await create_pair_func() + else: + client_pipe, server_pipe = create_pair_func() + yield client_pipe, server_pipe + await client_pipe.terminate() + await server_pipe.terminate() + + +@pytest.fixture +def client_raw(pipe_pair): + client_raw_pipe, _ = pipe_pair + return client_raw_pipe + + +@pytest.fixture +def server_raw(pipe_pair): + _, server_raw_pipe = pipe_pair + return server_raw_pipe + + +@pytest.fixture +def client_cbor(client_raw): + cbor_transformer = CborTransformer() + return cbor_transformer.apply_transformer(client_raw) + + +@pytest.mark.asyncio +class TestCborTransformer: + async def test_cbor_transformer_end_to_end_simple_dict(self, client_raw, server_raw, client_cbor): + client_transformed_pipe = client_cbor + + received_data_queue = TimeoutQueue() + server_raw.on("data", received_data_queue.put_nowait) + + original_data = {"message": "Hello, CBOR!", "number": 456, "list": [1, 2, 3]} + await client_transformed_pipe.write(original_data) + encoded_data_received_by_server = await received_data_queue.get() + + decoded_by_server = cbor2.loads(encoded_data_received_by_server) + assert decoded_by_server == original_data + + client_received_data_queue = TimeoutQueue() + client_transformed_pipe.on("data", client_received_data_queue.put_nowait) + + response_data = {"status": "cbor_success", "code": 200} + await server_raw.write(cbor2.dumps(response_data)) + decoded_data_received_by_client = await client_received_data_queue.get() + assert decoded_data_received_by_client == response_data + + async def test_cbor_transformer_decoding_error_on_read(self, server_raw, client_cbor,client_raw): + client_transformed_pipe = client_cbor + + error_queue = TimeoutQueue() + client_transformed_pipe.on("error", error_queue.put_nowait) + + incomplete_cbor_bytes = b"\x83\x01\x02" + assert await server_raw.write(incomplete_cbor_bytes) + + error = await asyncio.wait_for(error_queue.get(), timeout=1) + assert type(error).__name__.startswith("CBORDecode") + # sleep 2 + await asyncio.sleep(0.2) + assert server_raw._closed is True, "Pipe should be closed on decode error" + assert server_raw._closed is True, "Pipe should be closed on decode error" + + async def test_cbor_transformer_non_bytes_data(self, server_raw, client_cbor): + transformer = CborTransformer() + with pytest.raises(TypeError): + transformer.decode("not cbor") + + async def test_cbor_transformer_none_data(self, server_raw, client_cbor): + transformer = CborTransformer() + with pytest.raises(TypeError): + transformer.decode(None) + + async def test_cbor_transformer_multiple_separate_writes(self, server_raw, client_cbor, client_raw): + client_transformed_pipe = client_cbor + + received_data_queue = TimeoutQueue() + received_errors = [] + client_transformed_pipe.pipeline("data", received_data_queue.put_nowait) + client_transformed_pipe.on("error", lambda e: received_errors.append(e)) + + assert await server_raw.write(cbor2.dumps({"a": 1})) + decoded1 = await received_data_queue.get() + await asyncio.sleep(0.05) + assert await server_raw.write(cbor2.dumps({"b": 2})) + + assert [] == received_errors, f"Expected no errors, got: {received_errors}" + decoded2 = await received_data_queue.get() + + assert decoded1 == {"a": 1} + assert decoded2 == {"b": 2} + + async def test_cbor_transformer_single_concatenated_write(self, server_raw, client_cbor): + client_transformed_pipe = client_cbor + + received_data_queue = TimeoutQueue() + received_errors = [] + client_transformed_pipe.pipeline("data", received_data_queue.put_nowait) + client_transformed_pipe.on("error", lambda e: received_errors.append(e)) + + concatenated = cbor2.dumps({"a": 1}) + cbor2.dumps({"b": 2}) + assert await server_raw.write(concatenated) + + decoded1 = await received_data_queue.get() + assert decoded1 == {"a": 1} + assert [] == received_errors, f"Expected no errors, got: {received_errors}" + + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(received_data_queue.get(), timeout=0.1) + + async def test_cbor_transformer_encode_error_on_write(self, server_raw, client_cbor): + client_transformed_pipe = client_cbor + + error_queue = TimeoutQueue() + client_transformed_pipe.on("error", error_queue.put_nowait) + + received_data_queue = TimeoutQueue() + server_raw.on("data", received_data_queue.put_nowait) + + unserializable = {"func": lambda x: x} + result = await client_transformed_pipe.write(unserializable) + + assert result is False + error = await asyncio.wait_for(error_queue.get(), timeout=1) + assert isinstance(error, Exception) + + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(received_data_queue.get(), timeout=0.1) + + async def test_cbor_transformer_close_propagation_and_write_after_close(self, client_raw, client_cbor): + client_transformed_pipe = client_cbor + + close_queue = TimeoutQueue() + client_transformed_pipe.on("close", lambda *args: close_queue.put_nowait(True)) + + await client_raw.terminate() + + await close_queue.get() + + result = await client_transformed_pipe.write({"after": "close"}) + assert result is False + + +@pytest.mark.asyncio +class TestCborStreamTransformer: + async def test_cbor_stream_transformer_single_object(self, client_raw, server_raw): + cbor_stream_transformer = CborStreamTransformer() + client_transformed_pipe = cbor_stream_transformer.apply_transformer(client_raw) + + received_data_queue = TimeoutQueue() + client_transformed_pipe.on("data", received_data_queue.put_nowait) + + original_data = {"key": "value", "id": 1} + await server_raw.write(cbor2.dumps(original_data)) + + decoded_data = await received_data_queue.get() + assert decoded_data == original_data + + async def test_cbor_stream_transformer_concatenated_objects(self, client_raw, server_raw): + cbor_stream_transformer = CborStreamTransformer() + client_transformed_pipe = cbor_stream_transformer.apply_transformer(client_raw) + + received_data_queue = TimeoutQueue() + client_transformed_pipe.on("data", received_data_queue.put_nowait) + + obj1 = {"a": 1} + obj2 = {"b": 2, "c": [3, 4]} + obj3 = "hello" + + concatenated_cbor = cbor2.dumps(obj1) + cbor2.dumps(obj2) + cbor2.dumps(obj3) + + await server_raw.write(concatenated_cbor) + + await server_raw.write(b"") + await server_raw.write(b"") + + decoded1 = await received_data_queue.get() + decoded2 = await received_data_queue.get() + decoded3 = await received_data_queue.get() + + assert decoded1 == obj1 + assert decoded2 == obj2 + assert decoded3 == obj3 + + async def test_cbor_stream_transformer_single_concatenated_write(self, client_raw, server_raw): + cbor_stream_transformer = CborStreamTransformer() + client_transformed_pipe = cbor_stream_transformer.apply_transformer(client_raw) + + received_data_queue = TimeoutQueue() + received_errors = [] + client_transformed_pipe.on("data", received_data_queue.put_nowait) + client_transformed_pipe.on("error", lambda e: received_errors.append(e)) + + concatenated = cbor2.dumps({"a": 1}) + cbor2.dumps({"b": 2}) + await server_raw.write(concatenated) + + decoded1 = await received_data_queue.get() + decoded2 = await received_data_queue.get() + assert decoded1 == {"a": 1} + assert decoded2 == {"b": 2} + assert [] == received_errors, f"Expected no errors, got: {received_errors}" + + async def test_cbor_stream_transformer_delayed_separate_writes(self, client_raw, server_raw): + cbor_stream_transformer = CborStreamTransformer() + client_transformed_pipe = cbor_stream_transformer.apply_transformer(client_raw) + + received_data_queue = TimeoutQueue() + received_errors = [] + client_transformed_pipe.on("data", received_data_queue.put_nowait) + client_transformed_pipe.on("error", lambda e: received_errors.append(e)) + + assert await server_raw.write(cbor2.dumps({"a": 1})) + decoded1 = await received_data_queue.get() + await asyncio.sleep(0.05) + assert await server_raw.write(cbor2.dumps({"b": 2})) + + decoded2 = await received_data_queue.get() + assert decoded1 == {"a": 1} + assert decoded2 == {"b": 2} + assert [] == received_errors, f"Expected no errors, got: {received_errors}" + + async def test_cbor_stream_transformer_fragmented_objects(self, client_raw, server_raw): + cbor_stream_transformer = CborStreamTransformer() + client_transformed_pipe = cbor_stream_transformer.apply_transformer(client_raw) + + received_data_queue = TimeoutQueue() + client_transformed_pipe.on("data", received_data_queue.put_nowait) + error_queue = TimeoutQueue() + client_transformed_pipe.on("error", error_queue.put_nowait) + + obj = {"long_message": "a" * 100} + cbor_bytes = cbor2.dumps(obj) + + await server_raw.write(cbor_bytes[:10]) + await asyncio.sleep(0.05) + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(received_data_queue.get(), timeout=0.1) + assert error_queue.empty() + + await server_raw.write(cbor_bytes[10:50]) + await asyncio.sleep(0.05) + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(received_data_queue.get(), timeout=0.1) + assert error_queue.empty() + + await server_raw.write(cbor_bytes[50:]) + decoded = await received_data_queue.get() + assert decoded == obj + assert error_queue.empty() + + async def test_cbor_stream_transformer_mixed_fragmented_and_concatenated(self, client_raw, server_raw): + cbor_stream_transformer = CborStreamTransformer() + client_transformed_pipe = cbor_stream_transformer.apply_transformer(client_raw) + + received_data_queue = TimeoutQueue() + client_transformed_pipe.on("data", received_data_queue.put_nowait) + + obj1 = {"id": 1, "data": "first"} + obj2 = {"id": 2, "data": "second"} + obj3 = {"id": 3, "data": "third"} + + cbor_bytes1 = cbor2.dumps(obj1) + cbor_bytes2 = cbor2.dumps(obj2) + cbor_bytes3 = cbor2.dumps(obj3) + + await server_raw.write(cbor_bytes1[:5]) + await asyncio.sleep(0.05) + await server_raw.write(cbor_bytes1[5:]) + decoded1 = await received_data_queue.get() + assert decoded1 == obj1 + + await asyncio.sleep(0.05) + await server_raw.write(cbor_bytes2 + cbor_bytes3) + + await server_raw.write(b"") + + decoded2 = await received_data_queue.get() + decoded3 = await received_data_queue.get() + assert decoded2 == obj2 + assert decoded3 == obj3 + + async def test_cbor_stream_transformer_invalid_data_in_stream(self, client_raw, server_raw): + cbor_stream_transformer = CborStreamTransformer() + client_transformed_pipe = cbor_stream_transformer.apply_transformer(client_raw) + + received_data_queue = TimeoutQueue() + client_transformed_pipe.on("data", received_data_queue.put_nowait) + error_queue = TimeoutQueue() + client_transformed_pipe.on("error", error_queue.put_nowait) + + obj1 = {"valid": True} + invalid_bytes = b"\x1f" + obj2 = {"another": "valid"} + + await server_raw.write(cbor2.dumps(obj1)) + await server_raw.write(invalid_bytes + cbor2.dumps(obj2)) + + decoded1 = await received_data_queue.get() + assert decoded1 == obj1 + + error = await asyncio.wait_for(error_queue.get(), timeout=1) + assert isinstance(error, cbor2.CBORDecodeError) + + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(received_data_queue.get(), timeout=0.1) + + async def test_cbor_stream_transformer_non_bytes_data(self, client_raw, server_raw): + transformer = CborStreamTransformer() + with pytest.raises(TypeError): + await transformer.decode("not cbor") + + async def test_cbor_stream_transformer_close_propagation_and_write_after_close(self, client_raw): + cbor_stream_transformer = CborStreamTransformer() + client_transformed_pipe = cbor_stream_transformer.apply_transformer(client_raw) + + close_queue = TimeoutQueue() + client_transformed_pipe.on("close", lambda *args: close_queue.put_nowait(True)) + + await client_raw.terminate() + + await close_queue.get() + + result = await client_transformed_pipe.write({"after": "close"}) + assert result is False + + +@pytest.mark.asyncio +async def test_cbor_stream_transformer_paths(): + transformer = CborStreamTransformer() + + with pytest.raises(NeedsMoreDataException): + await transformer.decode(None) + + with pytest.raises(TypeError): + await transformer.decode("bad") + + good = cbor2.dumps({"a": 1}) + with pytest.raises(cbor2.CBORDecodeError): + await transformer.decode(b"\xff" + good) + + with pytest.raises(cbor2.CBORDecodeError): + await transformer.decode(b"\xff") + + +@pytest.mark.asyncio +async def test_cbor_stream_transformer_overflow(): + transformer = CborStreamTransformer(max_buffer_bytes=4) + + with pytest.raises(OverflowError): + await transformer.decode(b"12345") + diff --git a/tests/transformer/test_event_transformer_pipe.py b/tests/transformer/test_event_transformer_pipe.py new file mode 100644 index 0000000..add4d6a --- /dev/null +++ b/tests/transformer/test_event_transformer_pipe.py @@ -0,0 +1,86 @@ +import asyncio +from typing import Any, List + +import pytest + +from cbor_rpc.pipe.event_pipe import EventPipe +from cbor_rpc.transformer.base.base_exception import NeedsMoreDataException +from cbor_rpc.transformer.base.event_transformer_pipe import EventTransformerPipe +from cbor_rpc.transformer.base.transformer_base import AsyncTransformer + + +class DummyAsyncTransformer(AsyncTransformer[Any, Any]): + async def encode(self, data: Any) -> Any: + if data == "encode_error": + raise ValueError("encode boom") + return f"enc:{data}" + + async def decode(self, data: Any) -> Any: + if data is None: + # Mimic standard transformer behavior which usually expects specific type input + # and rejects None if it doesn't support buffering/streaming check via None. + raise TypeError("Unexpected None input") + + if data == "need_more": + raise NeedsMoreDataException() + if data == "decode_error": + raise ValueError("decode boom") + return f"dec:{data}" + + +@pytest.mark.asyncio +async def test_event_transformer_pipe_data_and_errors(): + base_a, base_b = EventPipe.create_inmemory_pair() + + class EventTransformer(DummyAsyncTransformer): + async def decode(self, data: Any) -> Any: + if data == "need_more": + raise NeedsMoreDataException() + if data == "decode_error": + raise ValueError("boom") + return await super().decode(data) + + transformer = EventTransformer() + tpipe = EventTransformerPipe(base_a, transformer) + + received: List[Any] = [] + errors: List[Any] = [] + + def on_data(data: Any) -> None: + received.append(data) + + def on_error(err: Exception) -> None: + errors.append(err) + + tpipe.on("data", on_data) + tpipe.on("error", on_error) + + await base_b.write("need_more") + await asyncio.sleep(0.01) + assert received == [] + + await base_b.write("ok") + await asyncio.sleep(0.01) + assert received == ["dec:ok"] + + ok = await base_b.write("decode_error") + assert ok is True + await asyncio.sleep(0.01) + assert errors + + +@pytest.mark.asyncio +async def test_event_transformer_pipe_write_error(): + base_a, _base_b = EventPipe.create_inmemory_pair() + transformer = DummyAsyncTransformer() + tpipe = EventTransformerPipe(base_a, transformer) + + errors: List[Any] = [] + + def on_error(err: Exception) -> None: + errors.append(err) + + tpipe.on("error", on_error) + ok = await tpipe.write("encode_error") + assert ok is False + assert errors diff --git a/tests/transformer/test_json_transformer.py b/tests/transformer/test_json_transformer.py new file mode 100644 index 0000000..f6142aa --- /dev/null +++ b/tests/transformer/test_json_transformer.py @@ -0,0 +1,213 @@ +import asyncio +import json + +import pytest + +from cbor_rpc.pipe.aio_pipe import AioPipe +from cbor_rpc.pipe.event_pipe import EventPipe +from cbor_rpc.tcp.tcp import TcpPipe +from cbor_rpc.transformer.base.base_exception import NeedsMoreDataException +from cbor_rpc.transformer.base.event_transformer_pipe import EventTransformerPipe +from cbor_rpc.transformer.json_transformer import JsonTransformer + +DEFAULT_TIMEOUT = 2.0 + + +@pytest.fixture( + params=[ + (EventPipe.create_inmemory_pair, "InmemoryPipe"), + (AioPipe.create_inmemory_pair, "AioPipe"), + (TcpPipe.create_inmemory_pair, "TcpPipe"), + ], + ids=lambda param: param[1], +) +async def pipe_pair(request): + create_pair_func, _ = request.param + if asyncio.iscoroutinefunction(create_pair_func): + client_pipe, server_pipe = await create_pair_func() + else: + client_pipe, server_pipe = create_pair_func() + yield client_pipe, server_pipe + await client_pipe.terminate() + await server_pipe.terminate() + + +@pytest.fixture +async def json_pipe(pipe_pair): + client_raw_pipe, server_raw_pipe = pipe_pair + json_transformer = JsonTransformer() + client_transformed_pipe = json_transformer.apply_transformer(client_raw_pipe) + return client_raw_pipe, server_raw_pipe, client_transformed_pipe, json_transformer + + +@pytest.fixture +async def json_pipe_ascii(pipe_pair): + client_raw_pipe, server_raw_pipe = pipe_pair + json_transformer = JsonTransformer(encoding="ascii") + client_transformed_pipe = json_transformer.apply_transformer(client_raw_pipe) + return client_raw_pipe, server_raw_pipe, client_transformed_pipe, json_transformer + + +@pytest.mark.asyncio +class TestJsonTransformerPipeInteraction: + async def test_json_transformer_end_to_end_simple_dict(self, json_pipe): + client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe + assert isinstance(client_transformed_pipe, EventTransformerPipe) + + received_data_queue = asyncio.Queue() + server_raw_pipe.on("data", received_data_queue.put_nowait) + + original_data = {"message": "Hello, world!", "number": 123} + + await client_transformed_pipe.write(original_data) + + encoded_data_received_by_server = await asyncio.wait_for(received_data_queue.get(), timeout=2.0) + + decoded_by_server = json.loads(encoded_data_received_by_server.decode("utf-8")) + assert decoded_by_server == original_data + + client_received_data_queue = asyncio.Queue() + client_transformed_pipe.on("data", client_received_data_queue.put_nowait) + + response_data = {"status": "success", "code": 200} + await server_raw_pipe.write(json.dumps(response_data).encode("utf-8")) + + decoded_data_received_by_client = await asyncio.wait_for( + client_received_data_queue.get(), + timeout=2.0, + ) + assert decoded_data_received_by_client == response_data + + await client_raw_pipe.terminate() + await server_raw_pipe.terminate() + + async def test_json_transformer_end_to_end_unicode_characters(self, json_pipe): + client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe + + received_data_queue = asyncio.Queue() + server_raw_pipe.on("data", received_data_queue.put_nowait) + + original_data = {"message": "δ½ ε₯½δΈ–η•Œ πŸ‘‹"} + await client_transformed_pipe.write(original_data) + encoded_data_received_by_server = await asyncio.wait_for( + received_data_queue.get(), + timeout=2.0, + ) + decoded_by_server = json.loads(encoded_data_received_by_server.decode("utf-8")) + assert decoded_by_server == original_data + + client_received_data_queue = asyncio.Queue() + client_transformed_pipe.on("data", client_received_data_queue.put_nowait) + response_data = {"greeting": "こんにけは"} + await server_raw_pipe.write(json.dumps(response_data, ensure_ascii=False).encode("utf-8")) + decoded_data_received_by_client = await asyncio.wait_for( + client_received_data_queue.get(), + timeout=2.0, + ) + assert decoded_data_received_by_client == response_data + + async def test_json_transformer_encoding_error_on_write(self, json_pipe_ascii): + client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe_ascii + + original_data = {"message": "Hello, world! πŸ‘‹"} + + error_queue = asyncio.Queue() + client_transformed_pipe.on("error", error_queue.put_nowait) + + await client_transformed_pipe.write(original_data) + + error = await asyncio.wait_for(error_queue.get(), timeout=DEFAULT_TIMEOUT) + assert isinstance(error, UnicodeEncodeError) + + async def test_json_transformer_decoding_error_on_read(self, json_pipe): + client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe + + error_queue = asyncio.Queue() + client_transformed_pipe.on("error", error_queue.put_nowait) + + invalid_json_bytes = b'{,"key": "value",}' + try: + await server_raw_pipe.write(invalid_json_bytes) + except json.JSONDecodeError: + pass + + error = await asyncio.wait_for(error_queue.get(), timeout=DEFAULT_TIMEOUT) + assert isinstance(error, json.JSONDecodeError) + + async def test_json_transformer_decoding_type_error_on_read(self, json_pipe): + client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe + + error_queue = asyncio.Queue() + client_transformed_pipe.on("error", error_queue.put_nowait) + + non_string_data = 12345 + try: + await server_raw_pipe.write(non_string_data) + except TypeError as exc: + assert isinstance(exc, TypeError) + return + + error = await asyncio.wait_for(error_queue.get(), timeout=DEFAULT_TIMEOUT) + assert isinstance(error, TypeError) + assert "Expected bytes or str" in str(error) + + async def test_json_transformer_non_json_serializable_data(self, json_pipe): + client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe + + non_serializable_data = {"set_data": {1, 2, 3}} + + error_queue = asyncio.Queue() + client_transformed_pipe.on("error", error_queue.put_nowait) + + await client_transformed_pipe.write(non_serializable_data) + + error = await asyncio.wait_for(error_queue.get(), timeout=DEFAULT_TIMEOUT) + assert isinstance(error, TypeError) + + async def test_json_transformer_pipe_termination(self, json_pipe): + client_raw_pipe, _server_raw_pipe, client_transformed_pipe, _ = json_pipe + + close_event_received = asyncio.Event() + client_transformed_pipe.on("close", lambda: close_event_received.set()) + + await client_raw_pipe.terminate() + + await asyncio.wait_for(close_event_received.wait(), timeout=DEFAULT_TIMEOUT) + + async def test_json_transformer_pipe_write_after_termination(self, json_pipe): + client_raw_pipe, _server_raw_pipe, client_transformed_pipe, _ = json_pipe + + await client_raw_pipe.terminate() + + result = await client_transformed_pipe.write({"test": "data"}) + assert result is False + + async def test_json_transformer_pipe_read_after_termination(self, json_pipe): + _client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe + + data_queue = asyncio.Queue() + client_transformed_pipe.pipeline("data", data_queue.put_nowait) + + close_event_received = asyncio.Event() + client_transformed_pipe.on("close", lambda: close_event_received.set()) + + await server_raw_pipe.terminate() + await asyncio.wait_for(close_event_received.wait(), timeout=DEFAULT_TIMEOUT) + + try: + result = await server_raw_pipe.write(b'{"should": "not_receive"}') + assert result is False + except ConnectionError: + pass + + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(data_queue.get(), timeout=0.1) + + +def test_json_transformer_decode_variants(): + transformer = JsonTransformer() + assert transformer.decode(b"{\"x\": 1}") == {"x": 1} + assert transformer.decode("{\"y\": 2}") == {"y": 2} + + with pytest.raises(TypeError): + transformer.decode(123) diff --git a/tests/transformer/test_transformer_base.py b/tests/transformer/test_transformer_base.py new file mode 100644 index 0000000..d7500cf --- /dev/null +++ b/tests/transformer/test_transformer_base.py @@ -0,0 +1,62 @@ +import pytest + +from cbor_rpc.pipe.event_pipe import EventPipe +from cbor_rpc.pipe.pipe import Pipe +from cbor_rpc.transformer.base.base_exception import NeedsMoreDataException +from cbor_rpc.transformer.base.event_transformer_pipe import EventTransformerPipe +from cbor_rpc.transformer.base.transformer_base import AsyncTransformer, Transformer +from cbor_rpc.transformer.base.transformer_pipe import TransformerPipe + + +class DummyAsyncTransformer(AsyncTransformer): + async def encode(self, data): + return f"enc:{data}" + + async def decode(self, data): + return f"dec:{data}" + + +class DummySyncTransformer(Transformer): + def encode(self, data): + return f"enc:{data}" + + def decode(self, data): + return f"dec:{data}" + + +class DummyPipe(Pipe): + async def write(self, _chunk): + return True + + async def read(self, _timeout=None): + return None + + async def terminate(self, *args): + return None + + +def test_transformer_base_apply_transformer_invalid_type(): + transformer = DummySyncTransformer() + with pytest.raises(TypeError): + transformer.apply_transformer(object()) + + +def test_transformer_base_apply_transformer_variants(): + transformer = DummySyncTransformer() + pipe = DummyPipe() + event_pipe_a, _event_pipe_b = EventPipe.create_inmemory_pair() + + bound_pipe = transformer.apply_transformer(pipe) + bound_event_pipe = transformer.apply_transformer(event_pipe_a) + + assert isinstance(bound_pipe, TransformerPipe) + assert isinstance(bound_event_pipe, EventTransformerPipe) + + +def test_transformer_base_wait_next_data_raises(): + transformer = DummySyncTransformer() + async_transformer = DummyAsyncTransformer() + with pytest.raises(NeedsMoreDataException): + transformer.wait_next_data() + with pytest.raises(NeedsMoreDataException): + async_transformer.wait_next_data() diff --git a/tests/transformer/test_transformer_pipe.py b/tests/transformer/test_transformer_pipe.py new file mode 100644 index 0000000..b670049 --- /dev/null +++ b/tests/transformer/test_transformer_pipe.py @@ -0,0 +1,216 @@ +import asyncio +from typing import Any, List, Optional + +import pytest + +from cbor_rpc.pipe.pipe import Pipe +from cbor_rpc.transformer.base.base_exception import NeedsMoreDataException +from cbor_rpc.transformer.base.transformer_base import AsyncTransformer +from cbor_rpc.transformer.base.transformer_pipe import TransformerPipe + + +class DummyPipe(Pipe[Any, Any]): + def __init__(self): + super().__init__() + self._queue: asyncio.Queue = asyncio.Queue() + self._closed = False + self.writes: List[Any] = [] + self.terminated = False + + async def write(self, chunk: Any) -> bool: + if self._closed: + return False + self.writes.append(chunk) + return True + + async def read(self, timeout: float = None) -> Any: + if self._closed: + return None + try: + if timeout is None: + return await self._queue.get() + return await asyncio.wait_for(self._queue.get(), timeout) + except asyncio.TimeoutError: + return None + + async def terminate(self, *args: Any) -> None: + if self._closed: + return + self._closed = True + self.terminated = True + await self._queue.put(None) + self._emit("close", *args) + + async def push(self, item: Any) -> None: + await self._queue.put(item) + + +class DummyAsyncTransformer(AsyncTransformer[Any, Any]): + async def encode(self, data: Any) -> Any: + if data == "encode_error": + raise ValueError("encode boom") + return f"enc:{data}" + + async def decode(self, data: Any) -> Any: + if asyncio.iscoroutine(data): + data = await data + if data == "need_more": + raise NeedsMoreDataException() + if data == "decode_error": + raise ValueError("decode boom") + return f"dec:{data}" + + +@pytest.mark.asyncio +async def test_transformer_pipe_read_write_and_errors(): + pipe = DummyPipe() + transformer = DummyAsyncTransformer() + tpipe = TransformerPipe(pipe, transformer) + + errors: List[Any] = [] + + def on_error(err: Exception) -> None: + errors.append(err) + + tpipe.on("error", on_error) + + ok = await tpipe.write("value") + assert ok is True + assert pipe.writes == ["enc:value"] + + ok = await tpipe.write("encode_error") + assert ok is False + assert errors + + await pipe.push("need_more") + await pipe.push("data") + result = await tpipe.read(timeout=0.1) + assert result == "dec:data" + + await pipe.push("need_more") + result = await tpipe.read(timeout=0.01) + assert result is None + + await pipe.push("decode_error") + result = await tpipe.read(timeout=0.1) + assert result is None + + await tpipe.terminate() + assert pipe._closed is True + + +@pytest.mark.asyncio +async def test_transformer_pipe_error_event_closes_pipe(): + pipe = DummyPipe() + transformer = DummyAsyncTransformer() + tpipe = TransformerPipe(pipe, transformer) + + pipe._emit("error", RuntimeError("boom")) + + await asyncio.sleep(0) + assert tpipe._closed is True + + +@pytest.mark.asyncio +async def test_transformer_pipe_write_encode_error_emits_error(): + class BadEncodeTransformer: + async def encode(self, value: Any) -> Any: + raise ValueError("encode-fail") + + async def decode(self, value: Any) -> Any: + return value + + pipe = DummyPipe() + transformer = TransformerPipe(pipe, BadEncodeTransformer()) + errors: List[str] = [] + + def on_error(err: Exception) -> None: + errors.append(str(err)) + + transformer.on("error", on_error) + result = await transformer.write("data") + assert result is False + assert errors == ["encode-fail"] + + +@pytest.mark.asyncio +async def test_transformer_pipe_read_needs_more_data_timeout(): + class NeedMoreTransformer: + async def encode(self, value: Any) -> Any: + return value + + async def decode(self, value: Any) -> Any: + raise NeedsMoreDataException() + + pipe = DummyPipe() + transformer = TransformerPipe(pipe, NeedMoreTransformer()) + await pipe.push("chunk") + result = await transformer.read(timeout=0) + assert result is None + + +@pytest.mark.asyncio +async def test_transformer_pipe_read_decode_error_emits_error(): + class BadDecodeTransformer: + async def encode(self, value: Any) -> Any: + return value + + async def decode(self, value: Any) -> Any: + raise ValueError("decode-fail") + + pipe = DummyPipe() + transformer = TransformerPipe(pipe, BadDecodeTransformer()) + errors: List[str] = [] + + def on_error(err: Exception) -> None: + errors.append(str(err)) + + transformer.on("error", on_error) + await pipe.push("chunk") + result = await transformer.read(timeout=0.1) + assert result is None + assert errors == ["decode-fail"] + + +@pytest.mark.asyncio +async def test_transformer_pipe_close_error_propagation_and_terminate(): + class PassThroughTransformer: + async def encode(self, value: Any) -> Any: + return value + + async def decode(self, value: Any) -> Any: + return value + + pipe = DummyPipe() + transformer = TransformerPipe(pipe, PassThroughTransformer()) + closed = asyncio.Event() + errors: List[str] = [] + propagated: List[str] = [] + + def on_close(*_args: Any) -> None: + closed.set() + + def on_error(err: Exception) -> None: + errors.append(str(err)) + + def on_pipe_error(err: Exception) -> None: + propagated.append(str(err)) + + transformer.on("close", on_close) + transformer.on("error", on_error) + pipe.on("error", on_pipe_error) + + pipe._emit("error", Exception("pipe-error")) + await asyncio.sleep(0.01) + assert errors == ["pipe-error"] + assert transformer._closed is True + + transformer._propagate_error(Exception("propagate")) + assert propagated[-1:] == ["propagate"] + + await transformer.terminate() + await transformer.terminate() + assert pipe.terminated is True + + pipe._emit("close", "done") + await asyncio.wait_for(closed.wait(), timeout=1)