From f39d8f5279625f1c087fba0c3287771e81cde741 Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Mon, 9 Jun 2025 20:20:12 +0545 Subject: [PATCH 01/25] Move modules, remove examples --- cbor_rpc/__init__.py | 2 +- cbor_rpc/{server.py => rpc_server.py} | 0 .../tcp/__init__.py | 0 cbor_rpc/{ => tcp}/tcp.py | 4 +- examples/json_transformer_example.py | 229 ------------------ examples/tcp_rpc_example.py | 168 ------------- 6 files changed, 3 insertions(+), 400 deletions(-) rename cbor_rpc/{server.py => rpc_server.py} (100%) rename examples/rpc_backends_example.py => cbor_rpc/tcp/__init__.py (100%) rename cbor_rpc/{ => tcp}/tcp.py (99%) delete mode 100644 examples/json_transformer_example.py delete mode 100644 examples/tcp_rpc_example.py diff --git a/cbor_rpc/__init__.py b/cbor_rpc/__init__.py index 60c384f..8ce4df7 100644 --- a/cbor_rpc/__init__.py +++ b/cbor_rpc/__init__.py @@ -7,7 +7,7 @@ from .transformer import Transformer from .promise import DeferredPromise from .client import RpcClient, RpcAuthorizedClient, RpcV1 -from .server import RpcServer, RpcV1Server +from .rpc_server import RpcServer, RpcV1Server from .server_base import Server from .tcp import TcpPipe, TcpServer from .json_transformer import JsonTransformer diff --git a/cbor_rpc/server.py b/cbor_rpc/rpc_server.py similarity index 100% rename from cbor_rpc/server.py rename to cbor_rpc/rpc_server.py diff --git a/examples/rpc_backends_example.py b/cbor_rpc/tcp/__init__.py similarity index 100% rename from examples/rpc_backends_example.py rename to cbor_rpc/tcp/__init__.py diff --git a/cbor_rpc/tcp.py b/cbor_rpc/tcp/tcp.py similarity index 99% rename from cbor_rpc/tcp.py rename to cbor_rpc/tcp/tcp.py index dbdefcc..d28c538 100644 --- a/cbor_rpc/tcp.py +++ b/cbor_rpc/tcp/tcp.py @@ -1,8 +1,8 @@ import asyncio import socket from typing import Any, Callable, Optional, Tuple, Union -from .async_pipe import Pipe -from .server_base import Server +from cbor_rpc.async_pipe import Pipe +from cbor_rpc.server_base import Server class TcpPipe(Pipe[bytes, bytes]): diff --git a/examples/json_transformer_example.py b/examples/json_transformer_example.py deleted file mode 100644 index 10ca713..0000000 --- a/examples/json_transformer_example.py +++ /dev/null @@ -1,229 +0,0 @@ -""" -Example demonstrating JsonTransformer usage with different server types. -""" - -import asyncio -import json -from typing import Any, List -from cbor_rpc import ( - Server, TcpServer, TcpPipe, JsonTransformer, - RpcV1, RpcV1Server, Pipe -) - - -class JsonRpcServer(RpcV1Server): - """RPC server that works with JSON-transformed data.""" - - def __init__(self): - super().__init__() - self._message_log = [] - - async def handle_method_call(self, connection_id: str, method: str, args: List[Any]) -> Any: - """Handle RPC method calls.""" - self._message_log.append(f"JSON RPC call: {method} from {connection_id}") - - if method == "echo": - return {"echo": args[0] if args else None, "timestamp": "2024-01-01"} - elif method == "add_numbers": - return {"result": sum(args), "operation": "addition"} - elif method == "get_log": - return {"log": self._message_log.copy()} - elif method == "process_data": - # Process complex JSON data - data = args[0] if args else {} - return { - "processed": True, - "input_keys": list(data.keys()) if isinstance(data, dict) else [], - "input_type": type(data).__name__ - } - else: - raise Exception(f"Unknown method: {method}") - - async def validate_event_broadcast(self, connection_id: str, topic: str, message: Any) -> bool: - """Validate event broadcasts.""" - self._message_log.append(f"JSON event: {topic} from {connection_id}") - return True - - -async def demonstrate_json_over_tcp(): - """Demonstrate JSON RPC over TCP with JsonTransformer.""" - print("=== JSON RPC over TCP Example ===") - - # Create TCP server - tcp_server: Server[TcpPipe] = await TcpServer.create('127.0.0.1', 0) - rpc_server = JsonRpcServer() - - async def on_tcp_connection(tcp_pipe: TcpPipe): - # Wrap TCP pipe with JSON transformer - json_pipe = JsonTransformer(tcp_pipe) - - # Create connection ID - peer_info = tcp_pipe.get_peer_info() - conn_id = f"json-tcp-{peer_info[0]}:{peer_info[1]}" if peer_info else "unknown" - - # Add JSON-transformed connection to RPC server - await rpc_server.add_connection(conn_id, json_pipe) - print(f"New JSON RPC client connected: {conn_id}") - - tcp_server.on_connection(on_tcp_connection) - - try: - # Get server address - host, port = tcp_server.get_address() - print(f"JSON RPC server listening on {host}:{port}") - - # Create client with JSON transformer - tcp_client_pipe = await TcpPipe.create_connection(host, port) - json_client_pipe = JsonTransformer(tcp_client_pipe) - - # Create RPC client - rpc_client = RpcV1.make_rpc_v1( - json_client_pipe, - "json-client", - lambda m, a: {"client_response": f"Handled {m}"}, - lambda t, m: print(f"Client received event: {t} -> {m}") - ) - - await asyncio.sleep(0.1) # Let connection establish - - # Test JSON RPC calls - print("\n--- Testing JSON RPC calls ---") - - # Simple echo - result = await rpc_client.call_method("echo", "Hello JSON World!") - print(f"Echo result: {json.dumps(result, indent=2)}") - - # Math operation - result = await rpc_client.call_method("add_numbers", 10, 20, 30) - print(f"Add result: {json.dumps(result, indent=2)}") - - # Complex data processing - complex_data = { - "users": [ - {"name": "Alice", "age": 30}, - {"name": "Bob", "age": 25} - ], - "metadata": { - "version": "1.0", - "timestamp": "2024-01-01T00:00:00Z" - } - } - result = await rpc_client.call_method("process_data", complex_data) - print(f"Process data result: {json.dumps(result, indent=2)}") - - # Get server log - result = await rpc_client.call_method("get_log") - print(f"Server log: {json.dumps(result, indent=2)}") - - # Test event emission - await rpc_client.emit("user_action", { - "action": "login", - "user": "alice", - "timestamp": "2024-01-01T12:00:00Z" - }) - print("Emitted JSON event") - - await tcp_client_pipe.terminate() - - finally: - await tcp_server.stop() - print("JSON RPC server stopped") - - -async def demonstrate_json_transformer_pair(): - """Demonstrate direct JsonTransformer pair communication.""" - print("\n=== Direct JSON Transformer Pair Example ===") - - # Create JSON transformer pair - json_transformer1, json_transformer2 = JsonTransformer.create_pair() - - # Set up data handlers - transformer1_received = [] - transformer2_received = [] - - json_transformer1.on("data", lambda data: transformer1_received.append(data)) - json_transformer2.on("data", lambda data: transformer2_received.append(data)) - - # Test bidirectional JSON communication - test_data_1 = { - "message": "Hello from transformer 1", - "data": [1, 2, 3, {"nested": True}] - } - - test_data_2 = { - "response": "Hello from transformer 2", - "status": "success", - "metadata": {"processed_at": "2024-01-01"} - } - - await json_transformer1.write(test_data_1) - await json_transformer2.write(test_data_2) - await asyncio.sleep(0.1) - - print("Transformer 1 sent:", json.dumps(test_data_1, indent=2)) - print("Transformer 2 received:", json.dumps(transformer2_received[0], indent=2)) - print() - print("Transformer 2 sent:", json.dumps(test_data_2, indent=2)) - print("Transformer 1 received:", json.dumps(transformer1_received[0], indent=2)) - - # Verify data integrity - assert transformer2_received[0] == test_data_1 - assert transformer1_received[0] == test_data_2 - print("\nāœ… JSON data integrity verified!") - - -async def demonstrate_json_error_handling(): - """Demonstrate JSON transformer error handling.""" - print("\n=== JSON Error Handling Example ===") - - pipe1, pipe2 = Pipe.create_pair() - json_transformer = JsonTransformer(pipe1) - - errors = [] - json_transformer.on("error", lambda err: errors.append(str(err))) - - # Test encoding error - class NonSerializable: - def __repr__(self): - return "NonSerializable()" - - print("Testing encoding error...") - result = await json_transformer.write(NonSerializable()) - print(f"Write result: {result}") - await asyncio.sleep(0.01) - - if errors: - print(f"Encoding error caught: {errors[-1]}") - - # Test decoding error - print("\nTesting decoding error...") - await pipe2.write(b'{"invalid": json syntax}') - await asyncio.sleep(0.01) - - if len(errors) > 1: - print(f"Decoding error caught: {errors[-1]}") - - # Test recovery - print("\nTesting error recovery...") - valid_data = {"message": "Recovery successful"} - await pipe2.write(json.dumps(valid_data).encode('utf-8')) - - received_data = [] - json_transformer.on("data", lambda data: received_data.append(data)) - await asyncio.sleep(0.01) - - if received_data: - print(f"Recovery successful: {json.dumps(received_data[0], indent=2)}") - - print(f"\nTotal errors handled: {len(errors)}") - - -async def main(): - """Run all JSON transformer demonstrations.""" - await demonstrate_json_over_tcp() - await demonstrate_json_transformer_pair() - await demonstrate_json_error_handling() - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/examples/tcp_rpc_example.py b/examples/tcp_rpc_example.py deleted file mode 100644 index 216f489..0000000 --- a/examples/tcp_rpc_example.py +++ /dev/null @@ -1,168 +0,0 @@ -""" -Example demonstrating how to use TcpPipe with CBOR-RPC. -This example shows both client and server implementations using TCP transport. -""" - -import asyncio -import json -from typing import Any, List -from cbor_rpc import TcpPipe, TcpServer, RpcV1, RpcV1Server - - -class TcpRpcServer(RpcV1Server): - """An RPC server that accepts TCP connections.""" - - def __init__(self, host: str = '127.0.0.1', port: int = 0): - super().__init__() - self.host = host - self.port = port - self.tcp_server: TcpServer = None - - async def start(self): - """Start the TCP server and begin accepting connections.""" - self.tcp_server = await TcpServer.create(self.host, self.port) - - async def on_connection(tcp_duplex: TcpPipe): - # Create a unique connection ID - peer_info = tcp_duplex.get_peer_info() - conn_id = f"{peer_info[0]}:{peer_info[1]}" if peer_info else "unknown" - - # Add the TCP connection as an RPC client - await self.add_connection(conn_id, tcp_duplex) - print(f"New RPC client connected: {conn_id}") - - self.tcp_server.on_connection(on_connection) - - actual_host, actual_port = self.tcp_server.get_address() - print(f"RPC server listening on {actual_host}:{actual_port}") - return actual_host, actual_port - - async def stop(self): - """Stop the TCP server.""" - if self.tcp_server: - await self.tcp_server.close() - - async def handle_method_call(self, connection_id: str, method: str, args: List[Any]) -> Any: - """Handle RPC method calls.""" - print(f"Method call from {connection_id}: {method}({args})") - - if method == "echo": - return args[0] if args else None - elif method == "add": - return sum(args) if args else 0 - elif method == "get_server_info": - return { - "server": "TcpRpcServer", - "connections": len(self.active_connections), - "address": self.tcp_server.get_address() if self.tcp_server else None - } - else: - raise Exception(f"Unknown method: {method}") - - async def validate_event_broadcast(self, connection_id: str, topic: str, message: Any) -> bool: - """Validate whether an event should be broadcasted.""" - # Allow all events for this example - return True - - -async def run_server(): - """Run the RPC server example.""" - server = TcpRpcServer() - - try: - host, port = await server.start() - print(f"Server started on {host}:{port}") - - # Keep the server running - while True: - await asyncio.sleep(1) - - except KeyboardInterrupt: - print("Shutting down server...") - finally: - await server.stop() - - -async def run_client(host: str, port: int): - """Run the RPC client example.""" - try: - # Create TCP connection - tcp_duplex = await TcpPipe.create_connection(host, port) - print(f"Connected to server at {host}:{port}") - - # Create RPC client - def method_handler(method: str, args: List[Any]) -> Any: - print(f"Server called method: {method}({args})") - return f"Client response to {method}" - - async def event_handler(topic: str, message: Any) -> None: - print(f"Received event: {topic} -> {message}") - - rpc_client = RpcV1.make_rpc_v1(tcp_duplex, "client", method_handler, event_handler) - - # Test method calls - print("\n--- Testing RPC method calls ---") - - result = await rpc_client.call_method("echo", "Hello, Server!") - print(f"Echo result: {result}") - - result = await rpc_client.call_method("add", 1, 2, 3, 4, 5) - print(f"Add result: {result}") - - result = await rpc_client.call_method("get_server_info") - print(f"Server info: {json.dumps(result, indent=2)}") - - # Test fire method (no response expected) - await rpc_client.fire_method("echo", "Fire and forget message") - print("Fired method (no response)") - - # Test event emission - await rpc_client.emit("client_event", {"message": "Hello from client", "timestamp": "2024-01-01"}) - print("Emitted event") - - # Keep connection alive for a bit - await asyncio.sleep(2) - - except Exception as e: - print(f"Client error: {e}") - finally: - if 'tcp_duplex' in locals(): - await tcp_duplex.terminate() - print("Client disconnected") - - -async def main(): - """Main example function.""" - import sys - - if len(sys.argv) > 1 and sys.argv[1] == "server": - await run_server() - elif len(sys.argv) > 3 and sys.argv[1] == "client": - host = sys.argv[2] - port = int(sys.argv[3]) - await run_client(host, port) - else: - print("Usage:") - print(" python tcp_rpc_example.py server") - print(" python tcp_rpc_example.py client ") - print("\nRunning integrated example...") - - # Run integrated example - server = TcpRpcServer() - - try: - # Start server - host, port = await server.start() - - # Give server time to start - await asyncio.sleep(0.1) - - # Run client - await run_client(host, port) - - finally: - await server.stop() - - -if __name__ == "__main__": - asyncio.run(main()) From cbcc9a8282190c5a1f0db92af8aaf3a11a33c91b Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Mon, 9 Jun 2025 22:36:35 +0545 Subject: [PATCH 02/25] WIP Refactor packages --- .vscode/settings.json | 7 ++ README.md | 2 + cbor_rpc/__init__.py | 20 ++-- cbor_rpc/event/__init__.py | 0 cbor_rpc/{ => event}/emitter.py | 26 ++++-- cbor_rpc/pipe/__init__.py | 0 .../{async_pipe.py => pipe/event_pipe.py} | 10 +- cbor_rpc/{sync_pipe.py => pipe/pipe.py} | 12 +-- cbor_rpc/{ => pipe}/server_base.py | 16 +++- cbor_rpc/promise.py | 2 +- cbor_rpc/rpc/__init__.py | 0 cbor_rpc/rpc/rpc_base.py | 68 ++++++++++++++ cbor_rpc/{ => rpc}/rpc_server.py | 47 ++-------- cbor_rpc/{client.py => rpc/rpc_v1.py} | 42 ++------- cbor_rpc/tcp/__init__.py | 3 + cbor_rpc/tcp/tcp.py | 37 ++------ cbor_rpc/transformer.py | 3 - .../{ => transformer}/json_transformer.py | 10 +- cbor_rpc/transformer/transformer.py | 12 +-- examples/fs_rpc/__init__.py | 0 examples/fs_rpc/filesystem_client.py | 57 ++++++++++++ examples/fs_rpc/filesystem_server.py | 91 +++++++++++++++++++ setup.py | 23 +++++ tests/helpers/simple_pipe.py | 4 +- tests/test_async_pipe.py | 14 +-- tests/test_event_emitter.py | 2 +- tests/test_json_transformer.py | 16 ++-- tests/test_rpc_v1.py | 4 +- tests/test_server_generics.py | 6 +- tests/test_sync_pipe.py | 20 ++-- tests/test_transformer.py | 12 +-- 31 files changed, 377 insertions(+), 189 deletions(-) create mode 100644 .vscode/settings.json create mode 100644 README.md create mode 100644 cbor_rpc/event/__init__.py rename cbor_rpc/{ => event}/emitter.py (70%) create mode 100644 cbor_rpc/pipe/__init__.py rename cbor_rpc/{async_pipe.py => pipe/event_pipe.py} (91%) rename cbor_rpc/{sync_pipe.py => pipe/pipe.py} (93%) rename cbor_rpc/{ => pipe}/server_base.py (87%) create mode 100644 cbor_rpc/rpc/__init__.py create mode 100644 cbor_rpc/rpc/rpc_base.py rename cbor_rpc/{ => rpc}/rpc_server.py (74%) rename cbor_rpc/{client.py => rpc/rpc_v1.py} (85%) delete mode 100644 cbor_rpc/transformer.py rename cbor_rpc/{ => transformer}/json_transformer.py (90%) create mode 100644 examples/fs_rpc/__init__.py create mode 100644 examples/fs_rpc/filesystem_client.py create mode 100644 examples/fs_rpc/filesystem_server.py create mode 100644 setup.py diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..1be6a53 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..fde5239 --- /dev/null +++ b/README.md @@ -0,0 +1,2 @@ +cbor-rpc +======== \ No newline at end of file diff --git a/cbor_rpc/__init__.py b/cbor_rpc/__init__.py index 8ce4df7..4bdc98e 100644 --- a/cbor_rpc/__init__.py +++ b/cbor_rpc/__init__.py @@ -2,28 +2,28 @@ CBOR-RPC: An async-compatible CBOR-based RPC system """ -from .emitter import AbstractEmitter -from .async_pipe import Pipe +from .event.emitter import AbstractEmitter +from .pipe.event_pipe import EventPipe from .transformer import Transformer -from .promise import DeferredPromise -from .client import RpcClient, RpcAuthorizedClient, RpcV1 -from .rpc_server import RpcServer, RpcV1Server -from .server_base import Server +from .promise import TimedPromise +from .rpc.rpc_base import RpcClient, RpcAuthorizedClient, RpcV1 +from .rpc.rpc_server import RpcServer, RpcV1Server +from .pipe.server_base import Server from .tcp import TcpPipe, TcpServer -from .json_transformer import JsonTransformer -from .sync_pipe import SyncPipe +from .transformer.json_transformer import JsonTransformer +from .pipe.pipe import Pipe __all__ = [ # Emitter 'AbstractEmitter', # Pipe classes + 'EventPipe', 'Pipe', - 'SyncPipe', 'Transformer', # Promise - 'DeferredPromise', + 'TimedPromise', # Client classes 'RpcClient', diff --git a/cbor_rpc/event/__init__.py b/cbor_rpc/event/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cbor_rpc/emitter.py b/cbor_rpc/event/emitter.py similarity index 70% rename from cbor_rpc/emitter.py rename to cbor_rpc/event/emitter.py index 98a5506..89c59b5 100644 --- a/cbor_rpc/emitter.py +++ b/cbor_rpc/event/emitter.py @@ -2,6 +2,8 @@ from abc import ABC, abstractmethod import asyncio import inspect +import traceback +import warnings class AbstractEmitter(ABC): def __init__(self): @@ -17,16 +19,26 @@ def unsubscribe(self, event: str, handler: Callable) -> None: def replace_on_handler(self, event_type: str, handler: Callable) -> None: self._subscribers[event_type] = [handler] - async def _emit(self, event_type: str, *args: Any) -> None: + def _run_background_task(self, coro: Callable[..., Any], *args: Any) -> None: + async def runner(): + try: + await coro(*args) + except Exception as e: + traceback.print_exc() + warnings.warn(f"Background task error in handler: {e}", RuntimeWarning) + + asyncio.create_task(runner()) + + def _emit(self, event_type: str, *args: Any) -> None: for sub in self._subscribers.get(event_type, []): try: if inspect.iscoroutinefunction(sub): - await sub(*args) + self._run_background_task(sub, *args) else: sub(*args) except Exception as e: - # We should log the exception but not propagate it - print(f"Error in event handler: {e}") + traceback.print_exc() + warnings.warn(f"Synchronous error in handler: {e}", RuntimeWarning) async def _notify(self, event_type: str, *args: Any) -> None: tasks = [] @@ -39,17 +51,17 @@ async def _notify(self, event_type: str, *args: Any) -> None: try: pipeline(*args) except Exception as e: - await self._emit("error", e) + self._emit("error", e) raise e if tasks: results = await asyncio.gather(*tasks, return_exceptions=True) for result in results: if isinstance(result, Exception): - await self._emit("error", result) + self._emit("error", result) raise result - await self._emit(event_type, *args) + self._emit(event_type, *args) def on(self, event: str, handler: Callable) -> None: self._subscribers.setdefault(event, []).append(handler) diff --git a/cbor_rpc/pipe/__init__.py b/cbor_rpc/pipe/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cbor_rpc/async_pipe.py b/cbor_rpc/pipe/event_pipe.py similarity index 91% rename from cbor_rpc/async_pipe.py rename to cbor_rpc/pipe/event_pipe.py index b136337..713e464 100644 --- a/cbor_rpc/async_pipe.py +++ b/cbor_rpc/pipe/event_pipe.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod import asyncio import inspect -from .emitter import AbstractEmitter +from ..event.emitter import AbstractEmitter import queue import threading from typing import Union @@ -12,7 +12,7 @@ T2 = TypeVar('T2') -class Pipe(AbstractEmitter, Generic[T1, T2]): +class EventPipe(AbstractEmitter, Generic[T1, T2]): """ Async Pipe or simply Pipe are event based way for read/write. You cannot directly read from a Pipe. You have to use a on("data") handler registration. @@ -26,7 +26,7 @@ async def terminate(self, *args: Any) -> None: pass @staticmethod - def attach(source: 'Pipe[Any, Any]', destination: 'Pipe[Any, Any]') -> None: + def attach(source: 'EventPipe[Any, Any]', destination: 'EventPipe[Any, Any]') -> None: async def source_to_destination(chunk: Any): await destination.write(chunk) @@ -41,14 +41,14 @@ async def close_handler(*args: Any): source.on("close", close_handler) @staticmethod - def create_pair() -> Tuple['Pipe[Any, Any]', 'Pipe[Any, Any]']: + def create_pair() -> Tuple['EventPipe[Any, Any]', 'EventPipe[Any, Any]']: """ Create a pair of connected pipes for bidirectional communication. Returns: A tuple of (pipe1, pipe2) where data written to pipe1 is emitted on pipe2 and vice versa. """ - class ConnectedPipe(Pipe[Any, Any]): + class ConnectedPipe(EventPipe[Any, Any]): def __init__(self): super().__init__() self.connected_pipe: Optional['ConnectedPipe'] = None diff --git a/cbor_rpc/sync_pipe.py b/cbor_rpc/pipe/pipe.py similarity index 93% rename from cbor_rpc/sync_pipe.py rename to cbor_rpc/pipe/pipe.py index 55a2935..b4cdd04 100644 --- a/cbor_rpc/sync_pipe.py +++ b/cbor_rpc/pipe/pipe.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod import asyncio import inspect -from .emitter import AbstractEmitter +from ..event.emitter import AbstractEmitter import queue import threading from typing import Union @@ -13,7 +13,7 @@ -class SyncPipe(Generic[T1, T2]): +class Pipe(Generic[T1, T2]): """ Synchronous pipe uses read/write methods instead of events. Explicit read/write is used for writing protocols that have multiple steps. @@ -23,7 +23,7 @@ def __init__(self): self._closed = False self._queue: queue.Queue = queue.Queue() self._sent_data: queue.Queue = queue.Queue() # Track data sent to other pipe for size reporting - self._connected_pipe: Optional['SyncPipe'] = None + self._connected_pipe: Optional['Pipe'] = None def read(self, timeout: Optional[float] = None) -> Optional[T2]: """ @@ -122,15 +122,15 @@ def available(self) -> int: return local_size @staticmethod - def create_pair() -> Tuple['SyncPipe[Any, Any]', 'SyncPipe[Any, Any]']: + def create_pair() -> Tuple['Pipe[Any, Any]', 'Pipe[Any, Any]']: """ Create a pair of connected sync pipes for bidirectional communication. Returns: A tuple of (pipe1, pipe2) where data written to pipe1 can be read from pipe2 and vice versa. """ - pipe1 = SyncPipe[Any, Any]() - pipe2 = SyncPipe[Any, Any]() + pipe1 = Pipe[Any, Any]() + pipe2 = Pipe[Any, Any]() # Connect them bidirectionally pipe1._connected_pipe = pipe2 diff --git a/cbor_rpc/server_base.py b/cbor_rpc/pipe/server_base.py similarity index 87% rename from cbor_rpc/server_base.py rename to cbor_rpc/pipe/server_base.py index cc208cb..a6570ed 100644 --- a/cbor_rpc/server_base.py +++ b/cbor_rpc/pipe/server_base.py @@ -1,11 +1,11 @@ from typing import Any, Callable, Optional, Set, TypeVar, Generic from abc import ABC, abstractmethod import asyncio -from .emitter import AbstractEmitter -from .async_pipe import Pipe +from ..event.emitter import AbstractEmitter +from .event_pipe import EventPipe # Generic type variable for pipe types -P = TypeVar('P', bound=Pipe) +P = TypeVar('P', bound=EventPipe) class Server(AbstractEmitter, Generic[P]): @@ -35,6 +35,10 @@ async def stop(self) -> None: """Stop the server and clean up resources.""" pass + @abstractmethod + async def accept(self,pipe:P) -> bool: + pass + async def _add_connection(self, pipe: P) -> None: """ Add a new connection and emit a connection event. @@ -42,15 +46,17 @@ async def _add_connection(self, pipe: P) -> None: Args: pipe: The pipe representing the connection """ + if not self.accept(pipe): + pipe.terminate() + return self._connections.add(pipe) - # Set up cleanup when connection closes async def cleanup(*args): self._connections.discard(pipe) pipe.on("close", cleanup) # Emit connection event - await self._emit("connection", pipe) + await self._notify("connection", pipe) def get_connections(self) -> Set[P]: """Get all active connections.""" diff --git a/cbor_rpc/promise.py b/cbor_rpc/promise.py index 66411dc..696f6e2 100644 --- a/cbor_rpc/promise.py +++ b/cbor_rpc/promise.py @@ -2,7 +2,7 @@ import asyncio -class DeferredPromise: +class TimedPromise: def __init__(self, timeout_ms: int, timeout_cb: Optional[Callable[[], None]] = None, message: str = "Timeout on RPC call"): self._timeout_ms = timeout_ms diff --git a/cbor_rpc/rpc/__init__.py b/cbor_rpc/rpc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cbor_rpc/rpc/rpc_base.py b/cbor_rpc/rpc/rpc_base.py new file mode 100644 index 0000000..71b6b19 --- /dev/null +++ b/cbor_rpc/rpc/rpc_base.py @@ -0,0 +1,68 @@ +from typing import Any, Dict, List, Optional, Callable +from abc import ABC, abstractmethod +import asyncio +import inspect +from ..pipe.event_pipe import EventPipe +from ..promise import TimedPromise + + +class RpcClient(ABC): + @abstractmethod + async def emit(self, topic: str, message: Any) -> None: + pass + + @abstractmethod + async def call_method(self, method: str, *args: Any) -> Any: + pass + + @abstractmethod + async def fire_method(self, method: str, *args: Any) -> None: + pass + + @abstractmethod + def set_timeout(self, milliseconds: int) -> None: + pass + + +class RpcAuthorizedClient(RpcClient): + @abstractmethod + def get_id(self) -> str: + pass + + +class RpcServer(ABC): + @abstractmethod + async def emit(self, connection_id: str, topic: str, message: Any) -> None: + pass + + @abstractmethod + async def broadcast(self, topic: str, message: Any) -> None: + pass + + @abstractmethod + async def call_method(self, connection_id: str, method: str, *args: Any) -> Any: + pass + + @abstractmethod + async def fire_method(self, connection_id: str, method: str, *args: Any) -> None: + pass + + @abstractmethod + async def disconnect(self, connection_id: str, reason: Optional[str] = None) -> None: + pass + + @abstractmethod + def get_client(self, connection_id: str) -> Optional[RpcAuthorizedClient]: + pass + + @abstractmethod + def with_client(self, connection_id: str, action: Callable) -> bool: + pass + + @abstractmethod + def set_timeout(self, milliseconds: int) -> None: + pass + + @abstractmethod + def is_active(self, connection_id: str) -> bool: + pass diff --git a/cbor_rpc/rpc_server.py b/cbor_rpc/rpc/rpc_server.py similarity index 74% rename from cbor_rpc/rpc_server.py rename to cbor_rpc/rpc/rpc_server.py index 22749a3..55456c2 100644 --- a/cbor_rpc/rpc_server.py +++ b/cbor_rpc/rpc/rpc_server.py @@ -1,54 +1,19 @@ from typing import Any, Dict, List, Optional, Callable from abc import ABC, abstractmethod import asyncio -from .client import RpcClient, RpcAuthorizedClient, RpcV1 -from .async_pipe import Pipe - -class RpcServer(ABC): - @abstractmethod - async def emit(self, connection_id: str, topic: str, message: Any) -> None: - pass - - @abstractmethod - async def broadcast(self, topic: str, message: Any) -> None: - pass - - @abstractmethod - async def call_method(self, connection_id: str, method: str, *args: Any) -> Any: - pass - - @abstractmethod - async def fire_method(self, connection_id: str, method: str, *args: Any) -> None: - pass - - @abstractmethod - async def disconnect(self, connection_id: str, reason: Optional[str] = None) -> None: - pass - - @abstractmethod - def get_client(self, connection_id: str) -> Optional[RpcAuthorizedClient]: - pass - - @abstractmethod - def with_client(self, connection_id: str, action: Callable) -> bool: - pass - - @abstractmethod - def set_timeout(self, milliseconds: int) -> None: - pass - - @abstractmethod - def is_active(self, connection_id: str) -> bool: - pass +from cbor_rpc.pipe.server_base import Server +from .rpc_base import RpcClient, RpcAuthorizedClient, RpcServer +from .rpc_v1 import RpcV1 +from cbor_rpc.pipe.event_pipe import EventPipe class RpcV1Server(RpcServer): - def __init__(self): + def __init__(self,server:Server): self.active_connections: Dict[str, RpcV1] = {} self.timeout = 30000 - async def add_connection(self, conn_id: str, rpc_client: Pipe[Any, Any]) -> None: + async def add_connection(self, conn_id: str, rpc_client: EventPipe[Any, Any]) -> None: def method_handler(method: str, args: List[Any]) -> Any: return self.handle_method_call(conn_id, method, args) diff --git a/cbor_rpc/client.py b/cbor_rpc/rpc/rpc_v1.py similarity index 85% rename from cbor_rpc/client.py rename to cbor_rpc/rpc/rpc_v1.py index 9392980..7608efe 100644 --- a/cbor_rpc/client.py +++ b/cbor_rpc/rpc/rpc_v1.py @@ -2,41 +2,19 @@ from abc import ABC, abstractmethod import asyncio import inspect -from .async_pipe import Pipe -from .promise import DeferredPromise - -class RpcClient(ABC): - @abstractmethod - async def emit(self, topic: str, message: Any) -> None: - pass - - @abstractmethod - async def call_method(self, method: str, *args: Any) -> Any: - pass - - @abstractmethod - async def fire_method(self, method: str, *args: Any) -> None: - pass - - @abstractmethod - def set_timeout(self, milliseconds: int) -> None: - pass - - -class RpcAuthorizedClient(RpcClient): - @abstractmethod - def get_id(self) -> str: - pass +from .rpc_base import RpcClient +from cbor_rpc.pipe.event_pipe import EventPipe +from cbor_rpc.promise import TimedPromise class RpcV1(RpcClient): - def __init__(self, pipe: Pipe[Any, Any]): + def __init__(self, pipe: EventPipe[Any, Any]): self.pipe = pipe self._counter = 0 - self._promises: Dict[int, DeferredPromise] = {} + self._promises: Dict[int, TimedPromise] = {} self._timeout = 30000 - self._waiters: Dict[str, DeferredPromise] = {} + self._waiters: Dict[str, TimedPromise] = {} async def resolve_result(result: Any) -> Any: """Recursively resolve coroutines or nested coroutines.""" @@ -108,7 +86,7 @@ async def call_method(self, method: str, *args: Any) -> Any: def timeout_callback(): self._promises.pop(counter, None) - promise = DeferredPromise(self._timeout, timeout_callback) + promise = TimedPromise(self._timeout, timeout_callback) self._promises[counter] = promise await self.pipe.write([1, 0, counter, method, list(args)]) return await promise.promise @@ -142,7 +120,7 @@ async def wait_next_event(self, topic: str, timeout_ms: Optional[int] = None) -> def timeout_callback(): self._waiters.pop(topic, None) - waiter = DeferredPromise( + waiter = TimedPromise( timeout_ms or self._timeout, timeout_callback, f"Timeout Waiting for Event on: {topic}" @@ -159,7 +137,7 @@ async def on_event(self, topic: str, message: Any) -> None: pass @staticmethod - def make_rpc_v1(pipe: Pipe[Any, Any], id_: str, method_handler: Callable, event_handler: Callable) -> 'RpcV1': + def make_rpc_v1(pipe: EventPipe[Any, Any], id_: str, method_handler: Callable, event_handler: Callable) -> 'RpcV1': class ConcreteRpcV1(RpcV1): def get_id(self) -> str: return id_ @@ -175,7 +153,7 @@ async def on_event(self, topic: str, message: Any) -> None: return ConcreteRpcV1(pipe) @staticmethod - def read_only_client(pipe: Pipe[Any, Any]) -> 'RpcV1': + def read_only_client(pipe: EventPipe[Any, Any]) -> 'RpcV1': def method_handler(method: str, args: List[Any]) -> Any: raise Exception("Client Only Implementation") diff --git a/cbor_rpc/tcp/__init__.py b/cbor_rpc/tcp/__init__.py index e69de29..fafafd4 100644 --- a/cbor_rpc/tcp/__init__.py +++ b/cbor_rpc/tcp/__init__.py @@ -0,0 +1,3 @@ +from .tcp import TcpPipe, TcpServer + +__all__ = ['TcpPipe', 'TcpServer'] diff --git a/cbor_rpc/tcp/tcp.py b/cbor_rpc/tcp/tcp.py index d28c538..7f31c1c 100644 --- a/cbor_rpc/tcp/tcp.py +++ b/cbor_rpc/tcp/tcp.py @@ -1,11 +1,12 @@ +from abc import abstractmethod import asyncio import socket from typing import Any, Callable, Optional, Tuple, Union -from cbor_rpc.async_pipe import Pipe -from cbor_rpc.server_base import Server +from cbor_rpc.pipe.event_pipe import EventPipe +from cbor_rpc.pipe.server_base import Server -class TcpPipe(Pipe[bytes, bytes]): +class TcpPipe(EventPipe[bytes, bytes]): """ A TCP duplex pipe that implements Pipe for network communication. Provides both client and server functionality for TCP connections. @@ -109,22 +110,6 @@ async def on_connection(pipe: TcpPipe): await server.stop() raise - @classmethod - def from_socket(cls, sock: socket.socket) -> 'TcpPipe': - """ - Create a TcpPipe from an existing socket. - - Args: - sock: An existing connected socket - - Returns: - A TcpPipe instance wrapping the socket - """ - # This will be set up when the connection is established - tcp_duplex = cls() - tcp_duplex._socket = sock - return tcp_duplex - async def connect(self, host: str, port: int, timeout: Optional[float] = None) -> None: """ Connect to a remote TCP server. @@ -370,17 +355,11 @@ def get_address(self) -> Tuple[str, int]: if self._server and self._server.sockets: return self._server.sockets[0].getsockname()[:2] return ("", 0) - - # Legacy methods for backward compatibility - def on_connection(self, handler: Callable[[TcpPipe], None]) -> None: - """ - Set a handler for new connections. - - Args: - handler: A function that takes a TcpPipe as argument - """ - super().on_connection(handler) + @abstractmethod + async def accept(self,pipe:TcpPipe) -> bool: + pass + async def close(self) -> None: """Legacy method - use stop() instead.""" await self.stop() diff --git a/cbor_rpc/transformer.py b/cbor_rpc/transformer.py deleted file mode 100644 index 3e713c7..0000000 --- a/cbor_rpc/transformer.py +++ /dev/null @@ -1,3 +0,0 @@ -from .transformer.transformer import Transformer - -__all__ = ['Transformer'] diff --git a/cbor_rpc/json_transformer.py b/cbor_rpc/transformer/json_transformer.py similarity index 90% rename from cbor_rpc/json_transformer.py rename to cbor_rpc/transformer/json_transformer.py index f01ef42..f0b73fe 100644 --- a/cbor_rpc/json_transformer.py +++ b/cbor_rpc/transformer/json_transformer.py @@ -1,15 +1,15 @@ import json from typing import Any, Union -from .emitter import AbstractEmitter -from .transformer import Transformer -from .async_pipe import Pipe +from ..event.emitter import AbstractEmitter +from . import Transformer +from ..pipe.event_pipe import EventPipe class JsonTransformer(AbstractEmitter, Transformer[Any, Any]): """ A transformer that encodes Python objects to JSON strings and decodes JSON strings back to Python objects. """ - def __init__(self, underlying_pipe: Pipe[Any, Any], encoding: str = 'utf-8'): + def __init__(self, underlying_pipe: EventPipe[Any, Any], encoding: str = 'utf-8'): """ Initialize the JSON transformer. @@ -76,7 +76,7 @@ def create_pair(cls, encoding: str = 'utf-8') -> tuple['JsonTransformer', 'JsonT Returns: A tuple of (transformer1, transformer2) """ - pipe1, pipe2 = Pipe.create_pair() + pipe1, pipe2 = EventPipe.create_pair() transformer1 = cls(pipe1, encoding) transformer2 = cls(pipe2, encoding) return transformer1, transformer2 diff --git a/cbor_rpc/transformer/transformer.py b/cbor_rpc/transformer/transformer.py index 9e1e691..1ab023b 100644 --- a/cbor_rpc/transformer/transformer.py +++ b/cbor_rpc/transformer/transformer.py @@ -2,8 +2,8 @@ from abc import ABC, abstractmethod import asyncio import inspect -from cbor_rpc.async_pipe import Pipe -from cbor_rpc.sync_pipe import SyncPipe +from cbor_rpc.pipe.event_pipe import EventPipe +from cbor_rpc.pipe.pipe import Pipe import queue import threading from typing import Union @@ -18,10 +18,10 @@ class Transformer(Generic[T1, T2]): Encodes data when writing and decodes data when reading/emitting events. """ - def __init__(self, underlying_pipe: Union[Pipe[Any, Any], SyncPipe[Any, Any]]): + def __init__(self, underlying_pipe: Union[EventPipe[Any, Any], Pipe[Any, Any]]): self.underlying_pipe = underlying_pipe self._closed = False - self._is_sync_pipe = isinstance(underlying_pipe, SyncPipe) + self._is_sync_pipe = isinstance(underlying_pipe, Pipe) if not self._is_sync_pipe: # For async pipes, set up event handlers @@ -162,9 +162,9 @@ def create_pair(encoder1: Callable, decoder1: Callable, A tuple of (transformer1, transformer2) """ if use_sync: - pipe1, pipe2 = SyncPipe.create_pair() - else: pipe1, pipe2 = Pipe.create_pair() + else: + pipe1, pipe2 = EventPipe.create_pair() class ConcreteTransformer1(Transformer): async def encode(self, data): diff --git a/examples/fs_rpc/__init__.py b/examples/fs_rpc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/fs_rpc/filesystem_client.py b/examples/fs_rpc/filesystem_client.py new file mode 100644 index 0000000..b318082 --- /dev/null +++ b/examples/fs_rpc/filesystem_client.py @@ -0,0 +1,57 @@ +import asyncio +from cbor_rpc import RpcV1 +from cbor_rpc.tcp import TcpPipe +from cbor_rpc.transformer.json_transformer import JsonTransformer +from cbor_rpc.pipe.event_pipe import EventPipe # Keep this import for clarity, though not directly instantiated + +async def main(): + # Connect to the RPC server + tcp_pipe = await TcpPipe.create_connection("localhost", 8000) # Use port 8002 + + # Create a JSON transformer and attach the TCP pipe to it + json_transformer = JsonTransformer(tcp_pipe) + + # The RpcV1 client needs a pipe that it can write to and read from. + # The json_transformer now handles both incoming and outgoing data. + rpc_client = RpcV1.read_only_client(json_transformer) + + # Example usage of filesystem RPC methods + + # List files in current directory + files = await rpc_client.call_method("list_files", ".") + print("Files in current directory:", files) + + # Create a test file + create_success = await rpc_client.call_method("create_file", "test.txt") + print("File creation successful:", create_success) + + # Read the test file (should be empty) + content = await rpc_client.call_method("read_file", "test.txt") + print("File content:", content.decode()) + + # Write to the test file + write_success = await rpc_client.call_method( + "create_file", + "test.txt", + b"Hello, world!" + ) + print("Write successful:", write_success) + + # Read the updated file + content = await rpc_client.call_method("read_file", "test.txt") + print("Updated file content:", content.decode()) + + # Rename the file + rename_success = await rpc_client.call_method( + "rename_file", + "test.txt", + "renamed_test.txt" + ) + print("Rename successful:", rename_success) + + # Delete the renamed file + delete_success = await rpc_client.call_method("delete_file", "renamed_test.txt") + print("Delete successful:", delete_success) + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/fs_rpc/filesystem_server.py b/examples/fs_rpc/filesystem_server.py new file mode 100644 index 0000000..df97609 --- /dev/null +++ b/examples/fs_rpc/filesystem_server.py @@ -0,0 +1,91 @@ +import os +from typing import List, Optional, Any +from cbor_rpc import RpcV1Server + +class FilesystemRpcServer(RpcV1Server): + async def validate_event_broadcast(self, connection_id, topic, message): + return False + async def handle_method_call(self, connection_id: str, method: str, args: List[Any]) -> Any: + if method == "list_files": + return self.list_files(*args) + elif method == "read_file": + return self.read_file(*args) + elif method == "create_file": + return self.create_file(*args) + elif method == "delete_file": + return self.delete_file(*args) + elif method == "rename_file": + return self.rename_file(*args) + else: + raise Exception(f"Unknown method: {method}") + + def list_files(self, directory: str) -> List[str]: + """Lists files and directories in the given path.""" + try: + return os.listdir(directory) + except Exception as e: + return f"Error listing files: {str(e)}" + + def read_file(self, path: str, chunk_size: int = 4096, offset: int = 0) -> bytes: + """Reads a file in chunks.""" + try: + with open(path, 'rb') as f: + f.seek(offset) + return f.read(chunk_size) + except Exception as e: + return f"Error reading file: {str(e)}".encode() + + def create_file(self, path: str, content: Optional[bytes] = None) -> bool: + """Creates a file with optional initial content.""" + try: + with open(path, 'wb') as f: + if content: + f.write(content) + return True + except Exception as e: + print(f"Error creating file: {str(e)}") + return False + + def delete_file(self, path: str) -> bool: + """Deletes a file.""" + try: + os.remove(path) + return True + except Exception as e: + print(f"Error deleting file: {str(e)}") + return False + + def rename_file(self, src: str, dest: str) -> bool: + """Renames/moves a file.""" + try: + os.rename(src, dest) + return True + except Exception as e: + print(f"Error renaming file: {str(e)}") + return False + +if __name__ == "__main__": + import asyncio + from cbor_rpc.tcp import TcpPipe, TcpServer + from cbor_rpc.transformer.json_transformer import JsonTransformer + async def main(): + rpc_id=1 + # Create a TCP server that handles connections, using JsonTransformer for RPC messages + tcp_server = await TcpServer.create("localhost", 8000) + print("Server running on port 8000") + + # Set up event handlers for new connections + async def handle_connection( rpc_pipe): + server = FilesystemRpcServer() + await server.add_connection(str(rpc_id), rpc_pipe) + + tcp_server.on_connection(handle_connection) + + # Just run until manually stopped + try: + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + await tcp_server.stop() + + asyncio.run(main()) diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..ce8d943 --- /dev/null +++ b/setup.py @@ -0,0 +1,23 @@ +from setuptools import setup, find_packages + +setup( + name="cbor-rpc", + version="0.1.0", + description="An async-compatible CBOR-based RPC system", + author="Sudip Bhattarai", + author_email="sudip@bhattarai.me", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + # url="https://github.com/your_username/cbor-rpc", # Replace with your project's URL + packages=find_packages(exclude=["cbor_rpc"]), + install_requires=[ + "pytest>=8.3.2", + "pytest-asyncio>=0.24.0" + ], + python_requires=">=3.8", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + ], +) diff --git a/tests/helpers/simple_pipe.py b/tests/helpers/simple_pipe.py index a4b4a2c..32a324d 100644 --- a/tests/helpers/simple_pipe.py +++ b/tests/helpers/simple_pipe.py @@ -1,11 +1,11 @@ from typing import Any, Generic, TypeVar -from cbor_rpc import Pipe, RpcV1, DeferredPromise +from cbor_rpc import EventPipe, RpcV1, TimedPromise # Generic type variables T1 = TypeVar('T1') -class SimplePipe(Pipe[T1, T1], Generic[T1]): +class SimplePipe(EventPipe[T1, T1], Generic[T1]): def __init__(self): super().__init__() self._closed = False diff --git a/tests/test_async_pipe.py b/tests/test_async_pipe.py index 03037e0..d64d0c2 100644 --- a/tests/test_async_pipe.py +++ b/tests/test_async_pipe.py @@ -1,9 +1,9 @@ import pytest import asyncio from typing import Any, Tuple -from cbor_rpc import Pipe +from cbor_rpc import EventPipe -class SimplePipe(Pipe): +class SimplePipe(EventPipe): def __init__(self): super().__init__() self._closed = False @@ -28,9 +28,9 @@ def pipe(): @pytest.mark.asyncio async def test_create_pair(): # Positive case: Creating a pair of async pipes - pipe1, pipe2 = Pipe.create_pair() - assert isinstance(pipe1, Pipe) - assert isinstance(pipe2, Pipe) + pipe1, pipe2 = EventPipe.create_pair() + assert isinstance(pipe1, EventPipe) + assert isinstance(pipe2, EventPipe) @pytest.mark.asyncio async def test_write_success(pipe): @@ -59,7 +59,7 @@ async def pipeline_handler(chunk: Any) -> None: @pytest.mark.asyncio async def test_attach_pipes(): # Positive case: Attaching two pipes - pipe1, pipe2 = Pipe.create_pair() + pipe1, pipe2 = EventPipe.create_pair() called = False async def handler(chunk: Any) -> None: @@ -73,7 +73,7 @@ async def handler(chunk: Any) -> None: @pytest.mark.asyncio async def test_write_after_terminate(): # Negative case: Writing to a terminated pipe - pipe1, _ = Pipe.create_pair() + pipe1, _ = EventPipe.create_pair() await pipe1.terminate() result = await pipe1.write("test_chunk") diff --git a/tests/test_event_emitter.py b/tests/test_event_emitter.py index 1573bbd..168ee12 100644 --- a/tests/test_event_emitter.py +++ b/tests/test_event_emitter.py @@ -1,7 +1,7 @@ import pytest import asyncio from typing import Any, Callable -from cbor_rpc.emitter import AbstractEmitter +from cbor_rpc.event.emitter import AbstractEmitter @pytest.mark.asyncio async def test_on_and_emit(): diff --git a/tests/test_json_transformer.py b/tests/test_json_transformer.py index d035158..acff4b1 100644 --- a/tests/test_json_transformer.py +++ b/tests/test_json_transformer.py @@ -3,12 +3,12 @@ import json from typing import Any, Dict, List from cbor_rpc import JsonTransformer -from cbor_rpc import Pipe +from cbor_rpc import EventPipe @pytest.mark.asyncio async def test_json_transformer_basic_encoding_decoding(): """Test basic JSON encoding and decoding.""" - pipe1, pipe2 = Pipe.create_pair() + pipe1, pipe2 = EventPipe.create_pair() transformer = JsonTransformer(pipe1) received_data = [] @@ -84,7 +84,7 @@ async def test_json_transformer_different_data_types(): @pytest.mark.asyncio async def test_json_transformer_encoding_errors(): """Test JSON transformer encoding error handling.""" - pipe1, pipe2 = Pipe.create_pair() + pipe1, pipe2 = EventPipe.create_pair() transformer = JsonTransformer(pipe1) errors = [] transformer.on("error", lambda err: errors.append(str(err))) @@ -107,7 +107,7 @@ class NonSerializable: @pytest.mark.asyncio async def test_json_transformer_decoding_errors(): """Test JSON transformer decoding error handling.""" - pipe1, pipe2 = Pipe.create_pair() + pipe1, pipe2 = EventPipe.create_pair() transformer = JsonTransformer(pipe1) errors = [] transformer.on("error", lambda err: errors.append(str(err))) @@ -127,7 +127,7 @@ async def test_json_transformer_decoding_errors(): @pytest.mark.asyncio async def test_json_transformer_string_input(): """Test JSON transformer with string input (not just bytes).""" - pipe1, pipe2 = Pipe.create_pair() + pipe1, pipe2 = EventPipe.create_pair() transformer = JsonTransformer(pipe1) received_data = [] transformer.on("data", lambda chunk: received_data.append(chunk)) @@ -142,7 +142,7 @@ async def test_json_transformer_string_input(): @pytest.mark.asyncio async def test_json_transformer_custom_encoding(): """Test JSON transformer with custom text encoding.""" - pipe1, pipe2 = Pipe.create_pair() + pipe1, pipe2 = EventPipe.create_pair() transformer = JsonTransformer(pipe1, encoding='latin1') received_data = [] transformer.on("data", lambda chunk: received_data.append(chunk)) @@ -157,7 +157,7 @@ async def test_json_transformer_custom_encoding(): @pytest.mark.asyncio async def test_json_transformer_termination(): """Test JSON transformer termination.""" - pipe1, pipe2 = Pipe.create_pair() + pipe1, pipe2 = EventPipe.create_pair() transformer = JsonTransformer(pipe1) close_events = [] transformer.on("close", lambda *args: close_events.append(args)) @@ -217,7 +217,7 @@ async def send_message(id: int): @pytest.mark.asyncio async def test_json_transformer_error_recovery(): """Test that JSON transformer can recover from errors.""" - pipe1, pipe2 = Pipe.create_pair() + pipe1, pipe2 = EventPipe.create_pair() transformer = JsonTransformer(pipe1) received_data = [] errors = [] diff --git a/tests/test_rpc_v1.py b/tests/test_rpc_v1.py index 427bc14..58ed565 100644 --- a/tests/test_rpc_v1.py +++ b/tests/test_rpc_v1.py @@ -3,7 +3,7 @@ from typing import Any, Generic, List from unittest.mock import AsyncMock, MagicMock -from cbor_rpc import Pipe, RpcV1, DeferredPromise +from cbor_rpc import EventPipe, RpcV1, TimedPromise from tests.helpers import SimplePipe @@ -114,7 +114,7 @@ async def test_wait_next_event_timeout(rpc): @pytest.mark.asyncio async def test_wait_next_event_already_waiting(rpc): - rpc._waiters["test_topic"] = DeferredPromise(1000) + rpc._waiters["test_topic"] = TimedPromise(1000) with pytest.raises(Exception) as exc_info: await rpc.wait_next_event("test_topic") assert str(exc_info.value) == "Already waiting for event" diff --git a/tests/test_server_generics.py b/tests/test_server_generics.py index 6d86474..df25e5b 100644 --- a/tests/test_server_generics.py +++ b/tests/test_server_generics.py @@ -5,10 +5,10 @@ import pytest import asyncio from typing import Any, Set -from cbor_rpc import Server, Pipe, TcpServer, TcpPipe +from cbor_rpc import Server, EventPipe, TcpServer, TcpPipe -class MockPipe(Pipe[str, str]): +class MockPipe(EventPipe[str, str]): """A mock pipe for testing.""" def __init__(self, name: str): @@ -126,7 +126,7 @@ def on_tcp_connection(pipe: TcpPipe): connection_events.append(pipe) # Verify the pipe is the correct type assert isinstance(pipe, TcpPipe) - assert isinstance(pipe, Pipe) + assert isinstance(pipe, EventPipe) tcp_server.on_connection(on_tcp_connection) diff --git a/tests/test_sync_pipe.py b/tests/test_sync_pipe.py index 6dbd15a..4ae825a 100644 --- a/tests/test_sync_pipe.py +++ b/tests/test_sync_pipe.py @@ -1,23 +1,23 @@ import pytest from typing import Any, Tuple -from cbor_rpc.sync_pipe import SyncPipe +from cbor_rpc.pipe.pipe import Pipe def test_create_pair(): # Positive case: Creating a pair of sync pipes - pipe1, pipe2 = SyncPipe.create_pair() - assert isinstance(pipe1, SyncPipe) - assert isinstance(pipe2, SyncPipe) + pipe1, pipe2 = Pipe.create_pair() + assert isinstance(pipe1, Pipe) + assert isinstance(pipe2, Pipe) def test_write_read(): # Positive case: Writing and reading a chunk successfully - pipe1, pipe2 = SyncPipe.create_pair() + pipe1, pipe2 = Pipe.create_pair() assert pipe1.write("test_chunk") is True assert pipe2.read() == "test_chunk" def test_close_pipe(): # Positive case: Closing the pipe - pipe1, pipe2 = SyncPipe.create_pair() + pipe1, pipe2 = Pipe.create_pair() pipe1.close() with pytest.raises(Exception): @@ -27,20 +27,20 @@ def test_close_pipe(): def test_write_after_close(): # Negative case: Writing to a closed pipe - pipe1, pipe2 = SyncPipe.create_pair() + pipe1, pipe2 = Pipe.create_pair() pipe1.close() assert pipe1.write("test_chunk") is False def test_read_timeout(): # Positive case: Reading with timeout - pipe1, _ = SyncPipe.create_pair() + pipe1, _ = Pipe.create_pair() assert pipe1.read(timeout=0.1) is None def test_bidirectional_communication(): # Positive case: Bidirectional communication between pipes - pipe1, pipe2 = SyncPipe.create_pair() + pipe1, pipe2 = Pipe.create_pair() assert pipe1.write("test_chunk") is True assert pipe2.read() == "test_chunk" @@ -50,7 +50,7 @@ def test_bidirectional_communication(): def test_queue_size(): # Positive case: Checking queue size - pipe1, _ = SyncPipe.create_pair() + pipe1, _ = Pipe.create_pair() pipe1.write("chunk1") pipe1.write("chunk2") diff --git a/tests/test_transformer.py b/tests/test_transformer.py index 37baffa..4a91995 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -1,9 +1,9 @@ import pytest import asyncio from typing import Any, Dict, List -from cbor_rpc.async_pipe import Pipe +from cbor_rpc.pipe.event_pipe import EventPipe from cbor_rpc import Transformer -from cbor_rpc import SyncPipe +from cbor_rpc import Pipe from cbor_rpc import AbstractEmitter # Existing tests... @@ -11,7 +11,7 @@ @pytest.mark.asyncio async def test_async_transformer_basic(): """Test basic asynchronous transformer functionality.""" - pipe1, pipe2 = Pipe.create_pair() + pipe1, pipe2 = EventPipe.create_pair() class MockTransformer(Transformer[str, str], AbstractEmitter): async def encode(self, data: str) -> str: @@ -41,7 +41,7 @@ def handler(data: str) -> None: @pytest.mark.asyncio async def test_sync_transformer_basic(): """Test basic synchronous transformer functionality.""" - pipe1, pipe2 = SyncPipe.create_pair() + pipe1, pipe2 = Pipe.create_pair() class MockSyncTransformer(Transformer[str, str]): def encode_sync(self, data: str) -> str: @@ -65,7 +65,7 @@ def decode_sync(self, data: Any) -> str: @pytest.mark.asyncio async def test_transformer_close_propagation(): """Test close propagation in transformers.""" - pipe1, pipe2 = Pipe.create_pair() + pipe1, pipe2 = EventPipe.create_pair() class MockTransformer(Transformer[str, str]): async def encode(self, data: str) -> str: @@ -90,7 +90,7 @@ async def decode(self, data: Any) -> str: @pytest.mark.asyncio async def test_transformer_exception_handling(): """Test exception handling in transformers.""" - pipe1, pipe2 = Pipe.create_pair() + pipe1, pipe2 = EventPipe.create_pair() class FaultyTransformer(Transformer[str, str], AbstractEmitter): async def encode(self, data: str) -> str: From 45aabe82b2b431a56ab2da0b65bcb1053d37af56 Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Wed, 11 Jun 2025 22:38:33 +0545 Subject: [PATCH 03/25] Wip Transformers --- cbor_rpc/__init__.py | 5 +- cbor_rpc/pipe/event_pipe.py | 69 +----- cbor_rpc/pipe/pipe.py | 218 ++++++++---------- cbor_rpc/transformer/__init__.py | 2 +- cbor_rpc/transformer/base/__init__.py | 20 ++ .../transformer/base/async_transformer.py | 55 +++++ cbor_rpc/transformer/base/base_exception.py | 2 + cbor_rpc/transformer/base/sync_transformer.py | 80 +++++++ cbor_rpc/transformer/base/transformer_base.py | 78 +++++++ cbor_rpc/transformer/json_transformer.py | 2 +- cbor_rpc/transformer/transformer.py | 214 ----------------- tests/test_event_emitter.py | 14 +- ...{test_async_pipe.py => test_event_pipe.py} | 7 +- tests/{test_sync_pipe.py => test_pipe.py} | 0 tests/test_transformer.py | 4 +- 15 files changed, 356 insertions(+), 414 deletions(-) create mode 100644 cbor_rpc/transformer/base/__init__.py create mode 100644 cbor_rpc/transformer/base/async_transformer.py create mode 100644 cbor_rpc/transformer/base/base_exception.py create mode 100644 cbor_rpc/transformer/base/sync_transformer.py create mode 100644 cbor_rpc/transformer/base/transformer_base.py delete mode 100644 cbor_rpc/transformer/transformer.py rename tests/{test_async_pipe.py => test_event_pipe.py} (95%) rename tests/{test_sync_pipe.py => test_pipe.py} (100%) diff --git a/cbor_rpc/__init__.py b/cbor_rpc/__init__.py index 4bdc98e..08cc010 100644 --- a/cbor_rpc/__init__.py +++ b/cbor_rpc/__init__.py @@ -6,8 +6,9 @@ from .pipe.event_pipe import EventPipe from .transformer import Transformer from .promise import TimedPromise -from .rpc.rpc_base import RpcClient, RpcAuthorizedClient, RpcV1 -from .rpc.rpc_server import RpcServer, RpcV1Server +from .rpc.rpc_base import RpcClient, RpcAuthorizedClient,RpcServer +from .rpc.rpc_v1 import RpcV1 +from .rpc.rpc_server import RpcV1Server from .pipe.server_base import Server from .tcp import TcpPipe, TcpServer from .transformer.json_transformer import JsonTransformer diff --git a/cbor_rpc/pipe/event_pipe.py b/cbor_rpc/pipe/event_pipe.py index 713e464..f0f3609 100644 --- a/cbor_rpc/pipe/event_pipe.py +++ b/cbor_rpc/pipe/event_pipe.py @@ -3,9 +3,6 @@ import asyncio import inspect from ..event.emitter import AbstractEmitter -import queue -import threading -from typing import Union # Generic type variables T1 = TypeVar('T1') @@ -14,8 +11,8 @@ class EventPipe(AbstractEmitter, Generic[T1, T2]): """ - Async Pipe or simply Pipe are event based way for read/write. - You cannot directly read from a Pipe. You have to use a on("data") handler registration. + Event Pipe or are event based way for read/write. + You cannot directly read from a Pipe. You have to use a pipeline("data") to register a function to read data. """ @abstractmethod async def write(self, chunk: T1) -> bool: @@ -25,21 +22,6 @@ async def write(self, chunk: T1) -> bool: async def terminate(self, *args: Any) -> None: pass - @staticmethod - def attach(source: 'EventPipe[Any, Any]', destination: 'EventPipe[Any, Any]') -> None: - async def source_to_destination(chunk: Any): - await destination.write(chunk) - - async def destination_to_source(chunk: Any): - await source.write(chunk) - - async def close_handler(*args: Any): - await destination._emit("close", *args) - - source.on("data", source_to_destination) - destination.on("data", destination_to_source) - source.on("close", close_handler) - @staticmethod def create_pair() -> Tuple['EventPipe[Any, Any]', 'EventPipe[Any, Any]']: """ @@ -59,60 +41,21 @@ def connect_to(self, other: 'ConnectedPipe'): other.connected_pipe = self async def write(self, chunk: Any) -> bool: - if self._closed: + if self._closed or not self.connected_pipe or self.connected_pipe._closed: return False - - # Forward to connected pipe - if self.connected_pipe and not self.connected_pipe._closed: - await self.connected_pipe._emit("data", chunk) - + await self.connected_pipe._notify("data", chunk) return True async def terminate(self, *args: Any) -> None: if self._closed: return - self._closed = True - await self._emit("close", *args) - - # Notify connected pipe + self._emit("close", *args) if self.connected_pipe and not self.connected_pipe._closed: - await self.connected_pipe._emit("close", *args) - - async def read(self, timeout: Optional[float] = None) -> Optional[Any]: - """Read data from the pipe with a timeout. - - Args: - timeout: Maximum time to wait for data (None = no timeout) - - Returns: - Data from the pipe or None if timeout/closed - """ - if self._closed: - return None - - # Wait for data event or timeout - read_future = asyncio.Future() - - def handle_data(chunk: Any): - read_future.set_result(chunk) - self.off("data", handle_data) - - self.on("data", handle_data) - - try: - if timeout is not None and timeout > 0: - await asyncio.wait_for(read_future, timeout) - else: - # Wait indefinitely for data - await read_future - return read_future.result() - except asyncio.TimeoutError: - return None + self.connected_pipe._emit("close", *args) pipe1 = ConnectedPipe() pipe2 = ConnectedPipe() pipe1.connect_to(pipe2) return pipe1, pipe2 - diff --git a/cbor_rpc/pipe/pipe.py b/cbor_rpc/pipe/pipe.py index b4cdd04..8dcfb5d 100644 --- a/cbor_rpc/pipe/pipe.py +++ b/cbor_rpc/pipe/pipe.py @@ -1,140 +1,116 @@ -from typing import Any, TypeVar, Generic, Callable, Tuple, Optional from abc import ABC, abstractmethod +from typing import Any, TypeVar, Generic, Optional, Tuple import asyncio -import inspect + +from cbor_rpc.pipe.event_pipe import EventPipe from ..event.emitter import AbstractEmitter -import queue -import threading -from typing import Union -# Generic type variables T1 = TypeVar('T1') T2 = TypeVar('T2') - -class Pipe(Generic[T1, T2]): +class Pipe(AbstractEmitter, Generic[T1, T2], ABC): """ - Synchronous pipe uses read/write methods instead of events. - Explicit read/write is used for writing protocols that have multiple steps. + Abstract Pipe defining async event-based read/write/terminate interface. """ - def __init__(self): - self._closed = False - self._queue: queue.Queue = queue.Queue() - self._sent_data: queue.Queue = queue.Queue() # Track data sent to other pipe for size reporting - self._connected_pipe: Optional['Pipe'] = None - - def read(self, timeout: Optional[float] = None) -> Optional[T2]: - """ - Read data from the pipe. - - Args: - timeout: Maximum time to wait for data (None = block indefinitely) - - Returns: - Data from the pipe or None if timeout/closed - - Raises: - Exception: If pipe is closed - """ - if self._closed: - raise Exception("Pipe is closed") - - try: - return self._queue.get(timeout=timeout) - except queue.Empty: - return None - - def write(self, chunk: T1) -> bool: - """ - Write data to the pipe. - - Args: - chunk: Data to write - - Returns: - True if successful, False if pipe is closed - """ - if self._closed: - return False - - # Forward to connected pipe only (not to our own queue) - try: - if self._connected_pipe and not self._connected_pipe._closed: - self._sent_data.put(chunk) # Track for size reporting - self._connected_pipe._queue.put(chunk) + @abstractmethod + async def write(self, chunk: T1) -> bool: + pass + + @abstractmethod + async def read(self, timeout: Optional[float] = None) -> Optional[T2]: + pass + + @abstractmethod + async def terminate(self, *args: Any) -> None: + pass + + @staticmethod + def create_pair() -> Tuple['Pipe[Any, Any]', 'Pipe[Any, Any]']: + class InMemoryPipe(Pipe[Any, Any]): + def __init__(self): + super().__init__() + self._closed = False + self.connected_pipe: Optional['InMemoryPipe'] = None + + async def write(self, chunk: Any) -> bool: + if self._closed or not self.connected_pipe or self.connected_pipe._closed: + return False + await self.connected_pipe._notify("data", chunk) return True - except: - return False - return True - - def close(self) -> None: - """Close the pipe.""" - if self._closed: - return - - self._closed = True - - # Notify connected pipe - if self._connected_pipe and not self._connected_pipe._closed: - self._connected_pipe.close() - - def is_closed(self) -> bool: - """Check if the pipe is closed.""" - return self._closed - - def available(self) -> int: - """Get number of items available to read.""" - if self._closed: - return 0 + async def read(self, timeout: Optional[float] = None) -> Optional[Any]: + if self._closed: + return None + + loop = asyncio.get_running_loop() + future = loop.create_future() - # Return combined size of local queue and sent data (data we wrote that hasn't been read by the other pipe) - local_size = self._queue.qsize() + def handler(chunk: Any): + if not future.done(): + future.set_result(chunk) - # If we have a connected pipe, check how much of our sent data has already been read - if self._connected_pipe: - # Get all items in _sent_data queue - sent_items = list(self._sent_data.queue) + self.on("data", handler) - # Remove items that are still in the other pipe's queue - for item in list(sent_items): try: - # Try to peek at what's in the connected pipe's queue - # We need a copy of the other pipe's queue to avoid modifying it during iteration - connected_queue = list(self._connected_pipe._queue.queue) - if item in connected_queue: - # Item hasn't been read yet, so count it - continue + if timeout and timeout > 0: + await asyncio.wait_for(future, timeout) else: - # Item has been read, remove from our sent tracking - self._sent_data.get() - except: - # If we can't access the other pipe's queue or there's an error, - # just use all items in _sent_data as a fallback - pass - - # Return total of local data and unread sent data - return local_size + len(list(self._sent_data.queue)) - - # No connected pipe, just return local size - return local_size + await future + return future.result() + except asyncio.TimeoutError: + return None + finally: + self.unsubscribe("data", handler) + + async def terminate(self, *args: Any) -> None: + if self._closed: + return + self._closed = True + await self._notify("close", *args) + if self.connected_pipe and not self.connected_pipe._closed: + await self.connected_pipe._notify("close", *args) + + a = InMemoryPipe() + b = InMemoryPipe() + a.connected_pipe = b + b.connected_pipe = a + return a, b - @staticmethod - def create_pair() -> Tuple['Pipe[Any, Any]', 'Pipe[Any, Any]']: - """ - Create a pair of connected sync pipes for bidirectional communication. - - Returns: - A tuple of (pipe1, pipe2) where data written to pipe1 can be read from pipe2 and vice versa. - """ - pipe1 = Pipe[Any, Any]() - pipe2 = Pipe[Any, Any]() - - # Connect them bidirectionally - pipe1._connected_pipe = pipe2 - pipe2._connected_pipe = pipe1 - - return pipe1, pipe2 - + def make_event_based(self) -> 'EventPipe[T1, T2]': + parent = self + + class PipeToEvent(EventPipe[T1, T2]): + def __init__(self): + super().__init__() + self._closed = False + + # Spawn a task to pump data from parent.read into events + self._pump_task = asyncio.create_task(self._pump()) + + async def _pump(self): + while not self._closed: + chunk = await parent.read() + if chunk is None: + await self.terminate() + break + await self._notify("data", chunk) + + async def write(self, chunk: T1) -> bool: + return await parent.write(chunk) + + async def terminate(self, *args: Any) -> None: + if self._closed: + return + self._closed = True + await parent.terminate(*args) + self._emit("close", *args) + if self._pump_task: + self._pump_task.cancel() + try: + await self._pump_task + except asyncio.CancelledError: + pass + + return PipeToEvent() \ No newline at end of file diff --git a/cbor_rpc/transformer/__init__.py b/cbor_rpc/transformer/__init__.py index 6ff14d5..0fe31b2 100644 --- a/cbor_rpc/transformer/__init__.py +++ b/cbor_rpc/transformer/__init__.py @@ -1 +1 @@ -from .transformer import Transformer \ No newline at end of file +from .base.transformer_base import Transformer diff --git a/cbor_rpc/transformer/base/__init__.py b/cbor_rpc/transformer/base/__init__.py new file mode 100644 index 0000000..eead791 --- /dev/null +++ b/cbor_rpc/transformer/base/__init__.py @@ -0,0 +1,20 @@ +from typing import overload, Union +from cbor_rpc.pipe.event_pipe import EventPipe +from cbor_rpc.pipe.pipe import Pipe +from cbor_rpc.transformer.base.async_transformer import EventTransformerPipe +from cbor_rpc.transformer.base.sync_transformer import TransformerPipe +from cbor_rpc.transformer.base.transformer_base import Transformer + + +@overload +def applyTransformer(pipe: Pipe, transformer: Transformer) -> TransformerPipe: ... +@overload +def applyTransformer(pipe: EventPipe, transformer: Transformer) -> EventTransformerPipe: ... + +def applyTransformer(pipe: Union[Pipe, EventPipe], transformer: Transformer) -> Union[TransformerPipe, EventTransformerPipe]: + if isinstance(pipe, EventPipe): + return EventTransformerPipe(pipe, transformer) + elif isinstance(pipe, Pipe): + return TransformerPipe(pipe, transformer) + else: + raise TypeError("Invalid pipe type") \ No newline at end of file diff --git a/cbor_rpc/transformer/base/async_transformer.py b/cbor_rpc/transformer/base/async_transformer.py new file mode 100644 index 0000000..7a91bb3 --- /dev/null +++ b/cbor_rpc/transformer/base/async_transformer.py @@ -0,0 +1,55 @@ +import asyncio +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from .transformer_base import Transformer +from typing import Any, Awaitable, Callable, TypeVar + +from cbor_rpc.pipe.event_pipe import EventPipe + +T1 = TypeVar("T1") # Output type after decoding +T2 = TypeVar("T2") # Input type before decoding (pipe input/output type) + +class EventTransformerPipe(EventPipe[T1, T2]): + encode: Callable[[T1], Awaitable[T2]] + decode: Callable[[T2], Awaitable[T1]] + + def __init__(self, pipe: EventPipe[T2, T2], transformer: 'Transformer'): + super().__init__() + self.pipe = pipe + self.pipe.pipeline("data", self._handle_data) + self.pipe.on("close", self._on_close) + self.pipe.on("error", self._on_error) + if transformer: + self.encode = transformer.encode + self.decode = transformer.decode + else: + raise ValueError("A transformer must be provided or encode/decode must be overridden") + + async def _handle_data(self, data: T2): + try: + decoded = await self.decode(data) + self._emit("data", decoded) + except Exception as e: + self._emit("error", e) + + def _on_close(self, *args: Any): + self._emit("close", *args) + + def _on_error(self, error: Exception): + self._emit("error", error) + + async def write(self, chunk: T1) -> bool: + if self._closed: + return False + try: + encoded = await self.encode(chunk) + return await self.pipe.write(encoded) + except Exception as e: + self._emit("error", e) + return False + + async def terminate(self, *args: Any) -> None: + if self._closed: + return + self._closed = True + await self.pipe.terminate(*args) diff --git a/cbor_rpc/transformer/base/base_exception.py b/cbor_rpc/transformer/base/base_exception.py new file mode 100644 index 0000000..92a1fc7 --- /dev/null +++ b/cbor_rpc/transformer/base/base_exception.py @@ -0,0 +1,2 @@ +class NeedsMoreDataException(Exception): + pass diff --git a/cbor_rpc/transformer/base/sync_transformer.py b/cbor_rpc/transformer/base/sync_transformer.py new file mode 100644 index 0000000..26b9cc6 --- /dev/null +++ b/cbor_rpc/transformer/base/sync_transformer.py @@ -0,0 +1,80 @@ +from typing import Any, Awaitable, Optional, TypeVar, Callable +from typing import TYPE_CHECKING +import time + +from cbor_rpc.transformer.base.base_exception import NeedsMoreDataException +if TYPE_CHECKING: + from .transformer_base import Transformer +from cbor_rpc.pipe.pipe import Pipe + +T1 = TypeVar("T1") +T2 = TypeVar("T2") + +class TransformerPipe(Pipe[T1, T2]): + encode: Callable[[T1], Awaitable[T2]] + decode: Callable[[T2], Awaitable[T1]] + + def __init__(self, pipe: Pipe[Any, Any], transformer: 'Optional[Transformer[T1, T2]]' ): + super().__init__() + self.pipe = pipe + + self.encode = transformer.encode + self.decode = transformer.decode + + def _handle_error(*args): + if not self._closed: + self._closed = True + self.pipe.terminate() + self._emit('error', *args) + + def _handle_close(*args): + if not self._closed: + self._closed = True + self._emit('close', *args) + + self.pipe.on('error', _handle_error) + self.pipe.on('close', _handle_close) + + async def write(self, chunk: T1) -> bool: + if self._closed: + return False + try: + encoded = await self.encode(chunk) + return self.pipe.write(encoded) + except Exception: + self._emit('error', chunk) + return False + + async def read(self, timeout: Optional[float] = None) -> Optional[T1]: + if self._closed: + return None + + start = time.monotonic() + remaining = timeout + + try: + while True: + raw = self.pipe.read(remaining) + if raw is None: + return None + + try: + return await self.decode(raw) + except NeedsMoreDataException: + if timeout is not None: + elapsed = time.monotonic() - start + remaining = max(0, timeout - elapsed) + if remaining == 0: + return None + except Exception as e: + self._emit('error', e) + return None + + def terminate(self) -> None: + if self._closed: + return + self._closed = True + self.pipe.terminate() + + def _propagate_error(self, *args): + self.pipe._emit('error', *args) diff --git a/cbor_rpc/transformer/base/transformer_base.py b/cbor_rpc/transformer/base/transformer_base.py new file mode 100644 index 0000000..31f335d --- /dev/null +++ b/cbor_rpc/transformer/base/transformer_base.py @@ -0,0 +1,78 @@ +from abc import abstractmethod +from typing import Any, Generic, TypeVar, Union, overload + +from cbor_rpc.pipe.event_pipe import EventPipe +from cbor_rpc.pipe.pipe import Pipe +from cbor_rpc.transformer.base.async_transformer import EventTransformerPipe +from cbor_rpc.transformer.base.base_exception import NeedsMoreDataException +from cbor_rpc.transformer.base.sync_transformer import TransformerPipe + +T1 = TypeVar("T1") +T2 = TypeVar("T2") + +# Sync Transformer (no async methods) +class Transformer(Generic[T1, T2]): + def __init__(self): + super().__init__() + self._closed = False + + def is_closed(self) -> bool: + return self._closed + + @abstractmethod + def encode(self, data: T1) -> Any: + pass + + @abstractmethod + def decode(self, data: Any) -> T2: + pass + + def wait_next_data(self): + raise NeedsMoreDataException() + + @overload + def bind(self, pipe: Pipe) -> TransformerPipe: ... + @overload + def bind(self, pipe: EventPipe) -> EventTransformerPipe: ... + + def applyTransformer(self, pipe: Union[Pipe, EventPipe], transformer: 'Transformer') -> Union[TransformerPipe, EventTransformerPipe]: + if isinstance(pipe, EventPipe): + return EventTransformerPipe(pipe, transformer) + elif isinstance(pipe, Pipe): + return TransformerPipe(pipe, transformer) + else: + raise TypeError("Invalid pipe type") + + +# Async Transformer (async methods) +class AsyncTransformer(Generic[T1, T2]): + def __init__(self): + super().__init__() + self._closed = False + + def is_closed(self) -> bool: + return self._closed + + @abstractmethod + async def encode(self, data: T1) -> Any: + pass + + @abstractmethod + async def decode(self, data: Any) -> T2: + pass + + def wait_next_data(self): + raise NeedsMoreDataException() + + @overload + def bind(self, pipe: Pipe) -> TransformerPipe: ... + @overload + def bind(self, pipe: EventPipe) -> EventTransformerPipe: ... + + def applyTransformer(self, pipe: Union[Pipe, EventPipe], transformer: 'AsyncTransformer') -> Union[TransformerPipe, EventTransformerPipe]: + if isinstance(pipe, EventPipe): + return EventTransformerPipe(pipe, transformer) + elif isinstance(pipe, Pipe): + return TransformerPipe(pipe, transformer) + else: + raise TypeError("Invalid pipe type") diff --git a/cbor_rpc/transformer/json_transformer.py b/cbor_rpc/transformer/json_transformer.py index f0b73fe..84f077b 100644 --- a/cbor_rpc/transformer/json_transformer.py +++ b/cbor_rpc/transformer/json_transformer.py @@ -4,7 +4,7 @@ from . import Transformer from ..pipe.event_pipe import EventPipe -class JsonTransformer(AbstractEmitter, Transformer[Any, Any]): +class JsonTransformer(Transformer[Any, Any]): """ A transformer that encodes Python objects to JSON strings and decodes JSON strings back to Python objects. """ diff --git a/cbor_rpc/transformer/transformer.py b/cbor_rpc/transformer/transformer.py deleted file mode 100644 index 1ab023b..0000000 --- a/cbor_rpc/transformer/transformer.py +++ /dev/null @@ -1,214 +0,0 @@ -from typing import Any, TypeVar, Generic, Callable, Tuple, Optional -from abc import ABC, abstractmethod -import asyncio -import inspect -from cbor_rpc.pipe.event_pipe import EventPipe -from cbor_rpc.pipe.pipe import Pipe -import queue -import threading -from typing import Union - -# Generic type variables -T1 = TypeVar('T1') -T2 = TypeVar('T2') - -class Transformer(Generic[T1, T2]): - """ - Abstract transformer that can wrap both sync and async pipes. - Encodes data when writing and decodes data when reading/emitting events. - """ - - def __init__(self, underlying_pipe: Union[EventPipe[Any, Any], Pipe[Any, Any]]): - self.underlying_pipe = underlying_pipe - self._closed = False - self._is_sync_pipe = isinstance(underlying_pipe, Pipe) - - if not self._is_sync_pipe: - # For async pipes, set up event handlers - super().__init__() - - # Forward events from underlying pipe, but decode data events - async def on_underlying_data(data: Any): - try: - decoded_data = await self.decode(data) - await self._emit("data", decoded_data) - except Exception as e: - await self._emit("error", e) - - async def on_underlying_close(*args): - await self._emit("close", *args) - - async def on_underlying_error(error): - await self._emit("error", error) - - self.underlying_pipe.on("data", on_underlying_data) - self.underlying_pipe.on("close", on_underlying_close) - self.underlying_pipe.on("error", on_underlying_error) - - # Async methods for async pipes - async def write(self, chunk: T1) -> bool: - """Write data after encoding it (async version).""" - if self._closed: - return False - - if self._is_sync_pipe: - raise RuntimeError("Use write_sync() for SyncPipe") - - try: - encoded_chunk = await self.encode(chunk) - return await self.underlying_pipe.write(encoded_chunk) - except Exception as e: - await self._emit("error", e) - return False - - async def terminate(self, *args: Any) -> None: - """Terminate the underlying pipe (async version).""" - if self._closed: - return - self._closed = True - - if not self._is_sync_pipe: - await self.underlying_pipe.terminate(*args) - else: - self.underlying_pipe.close() - - # Sync methods for sync pipes - def write_sync(self, chunk: T1) -> bool: - """Write data after encoding it (sync version).""" - if self._closed: - return False - - if not self._is_sync_pipe: - raise RuntimeError("Use write() for async Pipe") - - try: - encoded_chunk = self.encode_sync(chunk) - return self.underlying_pipe.write(encoded_chunk) - except Exception as e: - return False - - def read_sync(self, timeout: Optional[float] = None) -> Optional[T2]: - """Read and decode data (sync version).""" - if self._closed: - return None - - if not self._is_sync_pipe: - raise RuntimeError("Use event handlers for async Pipe") - - try: - raw_data = self.underlying_pipe.read(timeout) - if raw_data is None: - return None - return self.decode_sync(raw_data) - except Exception as e: - return None - - def close_sync(self) -> None: - """Close the transformer (sync version).""" - if self._closed: - return - self._closed = True - - if self._is_sync_pipe: - self.underlying_pipe.close() - - def is_sync_pipe(self) -> bool: - """Check if this transformer wraps a sync pipe.""" - return self._is_sync_pipe - - # Abstract methods - async versions - @abstractmethod - async def encode(self, data: T1) -> Any: - """Encode data before writing to the underlying pipe (async version).""" - pass - - @abstractmethod - async def decode(self, data: Any) -> T2: - """Decode data received from the underlying pipe (async version).""" - pass - - # Abstract methods - sync versions (with default implementations that call async versions) - def encode_sync(self, data: T1) -> Any: - """Encode data before writing to the underlying pipe (sync version).""" - # Default implementation for backwards compatibility - # Subclasses should override this for true sync operation - if asyncio.iscoroutinefunction(self.encode): - raise NotImplementedError("Sync encoding not implemented for this transformer") - return asyncio.run(self.encode(data)) - - def decode_sync(self, data: Any) -> T2: - """Decode data received from the underlying pipe (sync version).""" - # Default implementation for backwards compatibility - # Subclasses should override this for true sync operation - if asyncio.iscoroutinefunction(self.decode): - raise NotImplementedError("Sync decoding not implemented for this transformer") - return asyncio.run(self.decode(data)) - - @staticmethod - def create_pair(encoder1: Callable, decoder1: Callable, - encoder2: Callable, decoder2: Callable, - use_sync: bool = False) -> Tuple['Transformer', 'Transformer']: - """ - Create a pair of connected transformer pipes. - - Args: - encoder1: Encoder function for the first transformer - decoder1: Decoder function for the first transformer - encoder2: Encoder function for the second transformer - decoder2: Decoder function for the second transformer - use_sync: If True, use SyncPipe; if False, use async Pipe - - Returns: - A tuple of (transformer1, transformer2) - """ - if use_sync: - pipe1, pipe2 = Pipe.create_pair() - else: - pipe1, pipe2 = EventPipe.create_pair() - - class ConcreteTransformer1(Transformer): - async def encode(self, data): - if asyncio.iscoroutinefunction(encoder1): - return await encoder1(data) - return encoder1(data) - - async def decode(self, data): - if asyncio.iscoroutinefunction(decoder1): - return await decoder1(data) - return decoder1(data) - - def encode_sync(self, data): - if asyncio.iscoroutinefunction(encoder1): - raise NotImplementedError("Encoder is async, cannot use sync method") - return encoder1(data) - - def decode_sync(self, data): - if asyncio.iscoroutinefunction(decoder1): - raise NotImplementedError("Decoder is async, cannot use sync method") - return decoder1(data) - - class ConcreteTransformer2(Transformer): - async def encode(self, data): - if asyncio.iscoroutinefunction(encoder2): - return await encoder2(data) - return encoder2(data) - - async def decode(self, data): - if asyncio.iscoroutinefunction(decoder2): - return await decoder2(data) - return decoder2(data) - - def encode_sync(self, data): - if asyncio.iscoroutinefunction(encoder2): - raise NotImplementedError("Encoder is async, cannot use sync method") - return encoder2(data) - - def decode_sync(self, data): - if asyncio.iscoroutinefunction(decoder2): - raise NotImplementedError("Decoder is async, cannot use sync method") - return decoder2(data) - - transformer1 = ConcreteTransformer1(pipe1) - transformer2 = ConcreteTransformer2(pipe2) - - return transformer1, transformer2 diff --git a/tests/test_event_emitter.py b/tests/test_event_emitter.py index 168ee12..6a2b22b 100644 --- a/tests/test_event_emitter.py +++ b/tests/test_event_emitter.py @@ -29,7 +29,7 @@ async def async_handler2(data: Any): emitter.on("test", async_handler2) # Emit event - await emitter._emit("test", "event1") + emitter._emit("test", "event1") await asyncio.sleep(0.02) # Allow async handlers to complete # Verify all subscribers ran (order may vary due to concurrency) @@ -117,7 +117,7 @@ def sync_handler1(data: Any): emitter.unsubscribe("test", async_handler1) # Emit event - await emitter._emit("test", "event3") + emitter._emit("test", "event3") await asyncio.sleep(0.02) # Allow async handlers to complete (none in this case) # Verify only remaining subscriber ran @@ -145,7 +145,7 @@ def sync_handler1(data: Any): emitter.replace_on_handler("test", async_handler1) # Emit event - await emitter._emit("test", "event4") + emitter._emit("test", "event4") await asyncio.sleep(0.02) # Allow async handler to complete # Verify only the replaced handler ran @@ -228,7 +228,7 @@ def sync_pipeline_b(data: Any): # Test _emit for event_a events.clear() - await emitter._emit("event_a", "data_a") + emitter._emit("event_a", "data_a") await asyncio.sleep(0.02) # Allow async handlers to complete expected = [f"async_handler_a_data_a", f"sync_handler_a_data_a"] assert sorted(events) == sorted(expected), f"Expected {expected}, got {events}" @@ -248,7 +248,7 @@ def sync_pipeline_b(data: Any): # Test _emit for event_b events.clear() - await emitter._emit("event_b", "data_b") + emitter._emit("event_b", "data_b") await asyncio.sleep(0.02) # Allow async handlers to complete expected = [f"async_handler_b_data_b", f"sync_handler_b_data_b"] assert sorted(events) == sorted(expected), f"Expected {expected}, got {events}" @@ -293,7 +293,7 @@ def sync_handler(data: Any): emitter.on("test", sync_handler) # Emit event - await emitter._emit("test", "event6") + emitter._emit("test", "event6") await asyncio.sleep(0.02) # Allow async handlers to complete # Verify all subscribers ran despite the failure @@ -332,7 +332,7 @@ def fast_notify_pipeline(data: Any): emitter.pipeline("test_notify", fast_notify_pipeline) # Start _emit but don't wait for it to finish - asyncio.create_task(emitter._emit("test_emit", "data_emit")) + emitter._emit("test_emit", "data_emit") # Wait briefly before triggering _notify await asyncio.sleep(0.1) diff --git a/tests/test_async_pipe.py b/tests/test_event_pipe.py similarity index 95% rename from tests/test_async_pipe.py rename to tests/test_event_pipe.py index d64d0c2..a0a63a1 100644 --- a/tests/test_async_pipe.py +++ b/tests/test_event_pipe.py @@ -19,7 +19,7 @@ async def terminate(self, *args: Any) -> None: if self._closed: return self._closed = True - await self._emit("close", *args) + self._emit("close", *args) @pytest.fixture def pipe(): @@ -57,7 +57,7 @@ async def pipeline_handler(chunk: Any) -> None: assert called is True @pytest.mark.asyncio -async def test_attach_pipes(): +async def test_pipe_pair(): # Positive case: Attaching two pipes pipe1, pipe2 = EventPipe.create_pair() @@ -66,8 +66,9 @@ async def handler(chunk: Any) -> None: nonlocal called called = True - pipe2.on("data", handler) + pipe2.pipeline("data", handler) await pipe1.write("test_chunk") + assert called is True @pytest.mark.asyncio diff --git a/tests/test_sync_pipe.py b/tests/test_pipe.py similarity index 100% rename from tests/test_sync_pipe.py rename to tests/test_pipe.py diff --git a/tests/test_transformer.py b/tests/test_transformer.py index 4a91995..6bed41d 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -44,10 +44,10 @@ async def test_sync_transformer_basic(): pipe1, pipe2 = Pipe.create_pair() class MockSyncTransformer(Transformer[str, str]): - def encode_sync(self, data: str) -> str: + def encode(self, data: str) -> str: return f"encoded_{data}" - def decode_sync(self, data: Any) -> str: + def decode(self, data: Any) -> str: if isinstance(data, str) and data.startswith("encoded_"): return data[8:] # Remove "encoded_" prefix raise ValueError("Invalid format") From b2fd6ac8b753bfab62bcf683ed025e40c1bcc30c Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Wed, 11 Jun 2025 23:14:13 +0545 Subject: [PATCH 04/25] Fix tests for pipe --- cbor_rpc/pipe/pipe.py | 39 +++-- cbor_rpc/transformer/json_transformer.py | 18 +-- tests/helpers/simple_pipe.py | 4 +- tests/test_event_pipe.py | 195 ++++++++++++++++++----- tests/test_pipe.py | 163 ++++++++++++++----- tests/test_server_generics.py | 195 ----------------------- tests/test_transformer.py | 120 -------------- 7 files changed, 301 insertions(+), 433 deletions(-) delete mode 100644 tests/test_server_generics.py delete mode 100644 tests/test_transformer.py diff --git a/cbor_rpc/pipe/pipe.py b/cbor_rpc/pipe/pipe.py index 8dcfb5d..4cca8d3 100644 --- a/cbor_rpc/pipe/pipe.py +++ b/cbor_rpc/pipe/pipe.py @@ -32,45 +32,44 @@ class InMemoryPipe(Pipe[Any, Any]): def __init__(self): super().__init__() self._closed = False + self._buffer: asyncio.Queue[Optional[Any]] = asyncio.Queue() self.connected_pipe: Optional['InMemoryPipe'] = None async def write(self, chunk: Any) -> bool: if self._closed or not self.connected_pipe or self.connected_pipe._closed: return False - await self.connected_pipe._notify("data", chunk) + await self.connected_pipe._buffer.put(chunk) return True async def read(self, timeout: Optional[float] = None) -> Optional[Any]: if self._closed: return None - - loop = asyncio.get_running_loop() - future = loop.create_future() - - def handler(chunk: Any): - if not future.done(): - future.set_result(chunk) - - self.on("data", handler) - try: - if timeout and timeout > 0: - await asyncio.wait_for(future, timeout) + if timeout is not None and timeout > 0: + return await asyncio.wait_for(self._buffer.get(), timeout) else: - await future - return future.result() + return await self._buffer.get() except asyncio.TimeoutError: return None - finally: - self.unsubscribe("data", handler) + except asyncio.CancelledError: + # If read is cancelled, put back the None if it was a termination signal + if not self._buffer.empty(): + item = self._buffer.get_nowait() + if item is None: + await self._buffer.put(None) + raise async def terminate(self, *args: Any) -> None: if self._closed: return self._closed = True - await self._notify("close", *args) + # Signal termination to any pending reads + await self._buffer.put(None) + await self._notify("close", *args) # Notify external listeners + if self.connected_pipe and not self.connected_pipe._closed: - await self.connected_pipe._notify("close", *args) + await self.connected_pipe._buffer.put(None) # Signal termination to connected pipe + await self.connected_pipe.terminate(*args) # Recursively terminate connected pipe a = InMemoryPipe() b = InMemoryPipe() @@ -113,4 +112,4 @@ async def terminate(self, *args: Any) -> None: except asyncio.CancelledError: pass - return PipeToEvent() \ No newline at end of file + return PipeToEvent() diff --git a/cbor_rpc/transformer/json_transformer.py b/cbor_rpc/transformer/json_transformer.py index 84f077b..797905f 100644 --- a/cbor_rpc/transformer/json_transformer.py +++ b/cbor_rpc/transformer/json_transformer.py @@ -63,20 +63,4 @@ async def decode(self, data: Union[bytes, str, None]) -> Any: else: raise TypeError(f"Expected bytes or str, got {type(data)}") - return json.loads(json_str) - - @classmethod - def create_pair(cls, encoding: str = 'utf-8') -> tuple['JsonTransformer', 'JsonTransformer']: - """ - Create a pair of connected JSON transformers. - - Args: - encoding: Text encoding to use - - Returns: - A tuple of (transformer1, transformer2) - """ - pipe1, pipe2 = EventPipe.create_pair() - transformer1 = cls(pipe1, encoding) - transformer2 = cls(pipe2, encoding) - return transformer1, transformer2 + return json.loads(json_str) \ No newline at end of file diff --git a/tests/helpers/simple_pipe.py b/tests/helpers/simple_pipe.py index 32a324d..10b196f 100644 --- a/tests/helpers/simple_pipe.py +++ b/tests/helpers/simple_pipe.py @@ -13,11 +13,11 @@ def __init__(self): async def write(self, chunk: T1) -> bool: if self._closed: return False - await self._emit("data", chunk) + await self._notify("data", chunk) return True async def terminate(self, *args: Any) -> None: if self._closed: return self._closed = True - await self._emit("close", *args) + self._emit("close", *args) diff --git a/tests/test_event_pipe.py b/tests/test_event_pipe.py index a0a63a1..4bdbe09 100644 --- a/tests/test_event_pipe.py +++ b/tests/test_event_pipe.py @@ -2,83 +2,192 @@ import asyncio from typing import Any, Tuple from cbor_rpc import EventPipe +import pytest_asyncio -class SimplePipe(EventPipe): - def __init__(self): - super().__init__() - self._closed = False - - async def write(self, chunk: Any) -> bool: - if self._closed: - return False - # Simulate writing data asynchronously - await asyncio.sleep(0.1) - return True - - async def terminate(self, *args: Any) -> None: - if self._closed: - return - self._closed = True - self._emit("close", *args) - -@pytest.fixture -def pipe(): - return SimplePipe() +@pytest_asyncio.fixture +async def event_pipe_pair(): + pipe1, pipe2 = EventPipe.create_pair() + yield pipe1, pipe2 + await pipe1.terminate() + await pipe2.terminate() @pytest.mark.asyncio -async def test_create_pair(): +async def test_create_pair(event_pipe_pair: Tuple[EventPipe, EventPipe]): # Positive case: Creating a pair of async pipes - pipe1, pipe2 = EventPipe.create_pair() + pipe1, pipe2 = event_pipe_pair assert isinstance(pipe1, EventPipe) assert isinstance(pipe2, EventPipe) @pytest.mark.asyncio -async def test_write_success(pipe): +async def test_write_success(event_pipe_pair: Tuple[EventPipe, EventPipe]): # Positive case: Writing a chunk successfully - result = await pipe.write("test_chunk") + pipe1, pipe2 = event_pipe_pair + result = await pipe1.write("test_chunk") assert result is True @pytest.mark.asyncio -async def test_terminate_success(pipe): +async def test_terminate_success(event_pipe_pair: Tuple[EventPipe, EventPipe]): # Positive case: Terminating the pipe - await pipe.terminate() + pipe1, pipe2 = event_pipe_pair + await pipe1.terminate() # No exception should be raised @pytest.mark.asyncio -async def test_pipeline_execution(pipe): +async def test_pipeline_execution(event_pipe_pair: Tuple[EventPipe, EventPipe]): # Positive case: Adding and executing a pipeline - called = False + pipe1, _ = event_pipe_pair + received_chunk = None + event = asyncio.Event() + async def pipeline_handler(chunk: Any) -> None: - nonlocal called - called = True + nonlocal received_chunk + received_chunk = chunk + event.set() - pipe.pipeline("data", pipeline_handler) - await pipe._notify("data", "test_chunk") - assert called is True + pipe1.pipeline("data", pipeline_handler) + await pipe1._notify("data", "test_chunk") + await asyncio.wait_for(event.wait(), timeout=1) # Wait for the handler to be called + assert received_chunk == "test_chunk" @pytest.mark.asyncio -async def test_pipe_pair(): +async def test_pipe_pair(event_pipe_pair: Tuple[EventPipe, EventPipe]): # Positive case: Attaching two pipes - pipe1, pipe2 = EventPipe.create_pair() + pipe1, pipe2 = event_pipe_pair + received_chunk = None + event = asyncio.Event() - called = False async def handler(chunk: Any) -> None: - nonlocal called - called = True + nonlocal received_chunk + received_chunk = chunk + event.set() pipe2.pipeline("data", handler) await pipe1.write("test_chunk") - - assert called is True + await asyncio.wait_for(event.wait(), timeout=1) # Wait for the handler to be called + assert received_chunk == "test_chunk" @pytest.mark.asyncio -async def test_write_after_terminate(): +async def test_write_after_terminate(event_pipe_pair: Tuple[EventPipe, EventPipe]): # Negative case: Writing to a terminated pipe - pipe1, _ = EventPipe.create_pair() + pipe1, _ = event_pipe_pair await pipe1.terminate() result = await pipe1.write("test_chunk") assert result is False +@pytest.mark.asyncio +async def test_parallel_event_writes(event_pipe_pair: Tuple[EventPipe, EventPipe]): + # Test case: Multiple coroutines writing to one end of the EventPipe + pipe1, pipe2 = event_pipe_pair + num_writes = 10 + test_chunks = [f"chunk_{i}" for i in range(num_writes)] + received_chunks = [] + + # Use a lock to protect shared list in concurrent access + lock = asyncio.Lock() + + # Use a queue to signal when all chunks are received + received_queue = asyncio.Queue() + + async def pipeline_handler(chunk: Any) -> None: + async with lock: + received_chunks.append(chunk) + if len(received_chunks) == num_writes: + received_queue.put_nowait(True) + + pipe2.pipeline("data", pipeline_handler) + + async def writer(chunk): + return await pipe1.write(chunk) + + # Concurrently write all chunks + results = await asyncio.gather(*[writer(chunk) for chunk in test_chunks]) + assert all(results) # All writes should be successful + + # Wait for all chunks to be received by the handler + await asyncio.wait_for(received_queue.get(), timeout=5) + + # Verify all chunks are received + assert sorted(received_chunks) == sorted(test_chunks) + +@pytest.mark.asyncio +async def test_parallel_event_processing(event_pipe_pair: Tuple[EventPipe, EventPipe]): + # Test case: One pipe writes multiple chunks, and the other pipe's handler processes them + pipe1, pipe2 = event_pipe_pair + num_chunks = 10 + test_chunks = [f"data_{i}" for i in range(num_chunks)] + processed_chunks = [] + + lock = asyncio.Lock() + processed_queue = asyncio.Queue() + + async def processing_handler(chunk: Any) -> None: + # Simulate some async processing + await asyncio.sleep(0.01) + async with lock: + processed_chunks.append(chunk) + if len(processed_chunks) == num_chunks: + processed_queue.put_nowait(True) + + pipe2.pipeline("data", processing_handler) + + # Write all chunks sequentially (or in parallel, EventPipe handles internal queueing) + write_tasks = [pipe1.write(chunk) for chunk in test_chunks] + await asyncio.gather(*write_tasks) + + # Wait for all chunks to be processed + await asyncio.wait_for(processed_queue.get(), timeout=5) + + # Verify all chunks are processed + assert sorted(processed_chunks) == sorted(test_chunks) + +@pytest.mark.asyncio +async def test_concurrent_bidirectional_event_communication(event_pipe_pair: Tuple[EventPipe, EventPipe]): + # Test case: Concurrent writes and event processing from both ends + pipe1, pipe2 = event_pipe_pair + num_messages = 5 + + client_sent_msgs = [] + client_received_responses = [] + server_received_msgs = [] + server_sent_responses = [] + + client_done_event = asyncio.Event() + server_done_event = asyncio.Event() + + async def client_handler(response: Any) -> None: + client_received_responses.append(response) + if len(client_received_responses) == num_messages: + client_done_event.set() + + async def server_handler(msg: Any) -> None: + server_received_msgs.append(msg) + response = f"server_response_to_{msg}" + await pipe2.write(response) + server_sent_responses.append(response) + if len(server_sent_responses) == num_messages: + server_done_event.set() + + pipe1.pipeline("data", client_handler) # Client listens for responses on the 'data' pipeline + pipe2.pipeline("data", server_handler) # Server listens for client messages + + async def client_writer_task(): + for i in range(num_messages): + msg = f"client_msg_{i}" + await pipe1.write(msg) + client_sent_msgs.append(msg) + # Client expects a response, but it's handled by client_handler + + asyncio.create_task(client_writer_task()) + + # Wait for both sides to complete their communication + await asyncio.wait_for(asyncio.gather(client_done_event.wait(), server_done_event.wait()), timeout=5) + + # Verify client sent messages are received by server + assert sorted([f"client_msg_{i}" for i in range(num_messages)]) == sorted(server_received_msgs) + + # Verify server sent messages are received by client + assert sorted([f"server_response_to_client_msg_{i}" for i in range(num_messages)]) == sorted(client_received_responses) + if __name__ == "__main__": pytest.main() diff --git a/tests/test_pipe.py b/tests/test_pipe.py index 4ae825a..97f6fba 100644 --- a/tests/test_pipe.py +++ b/tests/test_pipe.py @@ -1,61 +1,152 @@ import pytest +import asyncio from typing import Any, Tuple from cbor_rpc.pipe.pipe import Pipe +import pytest_asyncio -def test_create_pair(): +@pytest_asyncio.fixture +async def pipe_pair(): + pipe1, pipe2 = Pipe.create_pair() + yield pipe1, pipe2 + +@pytest.mark.asyncio +async def test_create_pair(): # Positive case: Creating a pair of sync pipes pipe1, pipe2 = Pipe.create_pair() assert isinstance(pipe1, Pipe) assert isinstance(pipe2, Pipe) + await pipe1.terminate() + await pipe2.terminate() -def test_write_read(): +@pytest.mark.asyncio +async def test_write_read(pipe_pair:Tuple[Pipe,Pipe]): # Positive case: Writing and reading a chunk successfully - pipe1, pipe2 = Pipe.create_pair() + pipe1, pipe2 = pipe_pair - assert pipe1.write("test_chunk") is True - assert pipe2.read() == "test_chunk" + assert await pipe1.write("test_chunk") is True + await asyncio.sleep(0) # Allow event loop to process the write + assert await pipe2.read() == "test_chunk" -def test_close_pipe(): +@pytest.mark.asyncio +async def test_close_pipe(pipe_pair:Tuple[Pipe,Pipe]): # Positive case: Closing the pipe - pipe1, pipe2 = Pipe.create_pair() - pipe1.close() - - with pytest.raises(Exception): - pipe1.read() + pipe1, pipe2 = pipe_pair + await pipe1.terminate() - assert pipe2.is_closed() is True + assert await pipe1.read() is None + assert pipe1._closed is True + assert pipe2._closed is True -def test_write_after_close(): +@pytest.mark.asyncio +async def test_write_after_close(pipe_pair:Tuple[Pipe,Pipe]): # Negative case: Writing to a closed pipe - pipe1, pipe2 = Pipe.create_pair() - pipe1.close() + pipe1, pipe2 = pipe_pair + await pipe1.terminate() - assert pipe1.write("test_chunk") is False + assert await pipe1.write("test_chunk") is False -def test_read_timeout(): +@pytest.mark.asyncio +async def test_read_timeout(pipe_pair): # Positive case: Reading with timeout - pipe1, _ = Pipe.create_pair() + pipe1, _ = pipe_pair - assert pipe1.read(timeout=0.1) is None + assert await pipe1.read(timeout=0.1) is None -def test_bidirectional_communication(): +@pytest.mark.asyncio +async def test_bidirectional_communication(pipe_pair:Tuple[Pipe,Pipe]): # Positive case: Bidirectional communication between pipes - pipe1, pipe2 = Pipe.create_pair() - - assert pipe1.write("test_chunk") is True - assert pipe2.read() == "test_chunk" - - assert pipe2.write("response_chunk") is True - assert pipe1.read() == "response_chunk" - -def test_queue_size(): - # Positive case: Checking queue size - pipe1, _ = Pipe.create_pair() - - pipe1.write("chunk1") - pipe1.write("chunk2") - - assert pipe1.available() == 2 + pipe1, pipe2 = pipe_pair + + assert await pipe1.write("test_chunk") is True + await asyncio.sleep(0) # Allow event loop to process the write + assert await pipe2.read() == "test_chunk" + + assert await pipe2.write("response_chunk") is True + await asyncio.sleep(0) # Allow event loop to process the write + assert await pipe1.read() == "response_chunk" + +@pytest.mark.asyncio +async def test_parallel_writes(pipe_pair: Tuple[Pipe, Pipe]): + # Test case: Multiple coroutines writing to one end of the pipe + pipe1, pipe2 = pipe_pair + num_writes = 10 + test_chunks = [f"chunk_{i}" for i in range(num_writes)] + + async def writer(chunk): + return await pipe1.write(chunk) + + # Concurrently write all chunks + results = await asyncio.gather(*[writer(chunk) for chunk in test_chunks]) + assert all(results) # All writes should be successful + + # Read all chunks from the other end + received_chunks = [] + for _ in range(num_writes): + received_chunks.append(await pipe2.read()) + + # Verify all chunks are received and in correct order (or at least all present) + assert sorted(received_chunks) == sorted(test_chunks) + +@pytest.mark.asyncio +async def test_parallel_reads(pipe_pair: Tuple[Pipe, Pipe]): + # Test case: Multiple coroutines reading from one end of the pipe + pipe1, pipe2 = pipe_pair + num_reads = 20 + test_chunks = [f"data_{i}" for i in range(num_reads)] + + # Write all chunks first + for chunk in test_chunks: + await pipe1.write(chunk) + await asyncio.sleep(0) # Allow event loop to process the write + + async def reader(): + return await pipe2.read() + + # Concurrently read all chunks + received_chunks = await asyncio.gather(*[reader() for _ in range(num_reads)]) + + # Verify all chunks are received + assert sorted(received_chunks) == sorted(test_chunks) + +@pytest.mark.asyncio +async def test_concurrent_bidirectional_communication(pipe_pair: Tuple[Pipe, Pipe]): + # Test case: Concurrent writes and reads from both ends + pipe1, pipe2 = pipe_pair + num_messages = 20 + + async def client_task(): + sent = [] + received = [] + for i in range(num_messages): + msg = f"client_msg_{i}" + await pipe1.write(msg) + sent.append(msg) + response = await pipe1.read() + received.append(response) + return sent, received + + async def server_task(): + sent = [] + received = [] + for i in range(num_messages): + msg = await pipe2.read() + received.append(msg) + response = f"server_response_to_{msg}" + await pipe2.write(response) + sent.append(response) + return sent, received + + client_future = asyncio.create_task(client_task()) + server_future = asyncio.create_task(server_task()) + + client_sent, client_received = await client_future + server_sent, server_received = await server_future + + # Verify client sent messages are received by server + assert sorted([f"client_msg_{i}" for i in range(num_messages)]) == sorted(server_received) + + # Verify server sent messages are received by client + assert sorted([f"server_response_to_client_msg_{i}" for i in range(num_messages)]) == sorted(client_received) if __name__ == "__main__": pytest.main() diff --git a/tests/test_server_generics.py b/tests/test_server_generics.py deleted file mode 100644 index df25e5b..0000000 --- a/tests/test_server_generics.py +++ /dev/null @@ -1,195 +0,0 @@ -""" -Tests for the generic Server class and its type safety. -""" - -import pytest -import asyncio -from typing import Any, Set -from cbor_rpc import Server, EventPipe, TcpServer, TcpPipe - - -class MockPipe(EventPipe[str, str]): - """A mock pipe for testing.""" - - def __init__(self, name: str): - super().__init__() - self.name = name - self._closed = False - - async def write(self, chunk: str) -> bool: - if self._closed: - return False - await self._emit("data", chunk) - return True - - async def terminate(self, *args: Any) -> None: - if self._closed: - return - self._closed = True - await self._emit("close", *args) - - -class MockServer(Server[MockPipe]): - """A mock server for testing generic functionality.""" - - def __init__(self): - super().__init__() - self._started = False - - async def start(self) -> str: - self._started = True - self._running = True - return "mock-server-started" - - async def stop(self) -> None: - if not self._running: - return - self._started = False - self._running = False - await self.close_all_connections() - - async def add_mock_connection(self, pipe: MockPipe) -> None: - """Add a mock connection for testing.""" - await self._add_connection(pipe) - - -@pytest.mark.asyncio -async def test_generic_server_typing(): - """Test that Server generic typing works correctly.""" - server: Server[MockPipe] = MockServer() - - # Test server initialization - assert not server.is_running() - assert len(server.get_connections()) == 0 - - # Start server - result = await server.start() - assert result == "mock-server-started" - assert server.is_running() - - # Test connection handling - connection_events = [] - - def on_connection(pipe: MockPipe): - connection_events.append(pipe) - assert isinstance(pipe, MockPipe) - - server.on_connection(on_connection) - - # Add mock connections - pipe1 = MockPipe("test-pipe-1") - pipe2 = MockPipe("test-pipe-2") - - await server.add_mock_connection(pipe1) - await server.add_mock_connection(pipe2) - - # Verify connections were added - assert len(connection_events) == 2 - assert connection_events[0] == pipe1 - assert connection_events[1] == pipe2 - - # Verify get_connections returns correct type - connections: Set[MockPipe] = server.get_connections() - assert len(connections) == 2 - assert pipe1 in connections - assert pipe2 in connections - - # Test connection cleanup on close - await pipe1.terminate() - await asyncio.sleep(0.01) # Allow cleanup - - connections = server.get_connections() - assert len(connections) == 1 - assert pipe2 in connections - assert pipe1 not in connections - - # Stop server - await server.stop() - assert not server.is_running() - assert len(server.get_connections()) == 0 - - -@pytest.mark.asyncio -async def test_tcp_server_generic_typing(): - """Test that TcpServer properly implements Server[TcpPipe].""" - # Create TCP server - tcp_server: Server[TcpPipe] = await TcpServer.create('127.0.0.1', 0) - - # Verify it's the correct type - assert isinstance(tcp_server, Server) - assert isinstance(tcp_server, TcpServer) - - # Test connection event typing - connection_events = [] - - def on_tcp_connection(pipe: TcpPipe): - connection_events.append(pipe) - # Verify the pipe is the correct type - assert isinstance(pipe, TcpPipe) - assert isinstance(pipe, EventPipe) - - tcp_server.on_connection(on_tcp_connection) - - try: - # Create client connection - host, port = tcp_server.get_address() - client = await TcpPipe.create_connection(host, port) - - await asyncio.sleep(0.1) - - # Verify connection event was emitted with correct type - assert len(connection_events) == 1 - assert isinstance(connection_events[0], TcpPipe) - - # Verify get_connections returns TcpPipe instances - connections: Set[TcpPipe] = tcp_server.get_connections() - assert len(connections) == 1 - for conn in connections: - assert isinstance(conn, TcpPipe) - - await client.terminate() - - finally: - await tcp_server.stop() - - -@pytest.mark.asyncio -async def test_server_context_manager(): - """Test server context manager functionality.""" - async with MockServer() as server: - await server.start() - assert server.is_running() - - pipe = MockPipe("context-test") - await server.add_mock_connection(pipe) - assert len(server.get_connections()) == 1 - - # Server should be stopped automatically - assert not server.is_running() - - -@pytest.mark.asyncio -async def test_server_close_all_connections(): - """Test that close_all_connections properly closes all connections.""" - server = MockServer() - await server.start() - - # Add multiple connections - pipes = [MockPipe(f"pipe-{i}") for i in range(3)] - for pipe in pipes: - await server.add_mock_connection(pipe) - - assert len(server.get_connections()) == 3 - - # Close all connections - await server.close_all_connections() - await asyncio.sleep(0.01) # Allow cleanup - - # All connections should be closed - assert len(server.get_connections()) == 0 - - await server.stop() - - -if __name__ == "__main__": - pytest.main(["-v", __file__]) diff --git a/tests/test_transformer.py b/tests/test_transformer.py deleted file mode 100644 index 6bed41d..0000000 --- a/tests/test_transformer.py +++ /dev/null @@ -1,120 +0,0 @@ -import pytest -import asyncio -from typing import Any, Dict, List -from cbor_rpc.pipe.event_pipe import EventPipe -from cbor_rpc import Transformer -from cbor_rpc import Pipe -from cbor_rpc import AbstractEmitter - -# Existing tests... - -@pytest.mark.asyncio -async def test_async_transformer_basic(): - """Test basic asynchronous transformer functionality.""" - pipe1, pipe2 = EventPipe.create_pair() - - class MockTransformer(Transformer[str, str], AbstractEmitter): - async def encode(self, data: str) -> str: - return f"encoded_{data}" - - async def decode(self, data: Any) -> str: - if isinstance(data, str) and data.startswith("encoded_"): - return data[8:] # Remove "encoded_" prefix - raise ValueError("Invalid format") - - transformer = MockTransformer(pipe1) - - # Set up event handler on pipe2 to receive data - received_data = None - - def handler(data: str) -> None: - nonlocal received_data - received_data = data - - pipe2.on("data", handler) - - # Write through transformer, should be received by pipe2 - assert await transformer.write("test_data") is True - await asyncio.sleep(0.1) # Give time for events to propagate - assert received_data == "encoded_test_data" - -@pytest.mark.asyncio -async def test_sync_transformer_basic(): - """Test basic synchronous transformer functionality.""" - pipe1, pipe2 = Pipe.create_pair() - - class MockSyncTransformer(Transformer[str, str]): - def encode(self, data: str) -> str: - return f"encoded_{data}" - - def decode(self, data: Any) -> str: - if isinstance(data, str) and data.startswith("encoded_"): - return data[8:] # Remove "encoded_" prefix - raise ValueError("Invalid format") - - transformer = MockSyncTransformer(pipe1) - - # Write through transformer, should be readable from pipe2 - assert transformer.write_sync("test_data") is True - - # Directly read from the connected sync pipe to verify data transfer - encoded_data = pipe2.read(timeout=1.0) - decoded_data = transformer.decode_sync(encoded_data) if encoded_data else None - assert decoded_data == "test_data" - -@pytest.mark.asyncio -async def test_transformer_close_propagation(): - """Test close propagation in transformers.""" - pipe1, pipe2 = EventPipe.create_pair() - - class MockTransformer(Transformer[str, str]): - async def encode(self, data: str) -> str: - return f"encoded_{data}" - - async def decode(self, data: Any) -> str: - if isinstance(data, str) and data.startswith("encoded_"): - return data[8:] # Remove "encoded_" prefix - raise ValueError("Invalid format") - - transformer = MockTransformer(pipe1) - - # Close transformer and verify pipe is also closed - await transformer.terminate() - - # Try to write after closing - should fail - assert await transformer.write("test_data") is False - - # Verify pipe is closed by trying to write directly (should fail) - assert await pipe1.write("raw_data") is False - -@pytest.mark.asyncio -async def test_transformer_exception_handling(): - """Test exception handling in transformers.""" - pipe1, pipe2 = EventPipe.create_pair() - - class FaultyTransformer(Transformer[str, str], AbstractEmitter): - async def encode(self, data: str) -> str: - raise ValueError("Encoding error") - - async def decode(self, data: Any) -> str: - if isinstance(data, str) and data.startswith("encoded_"): - return data[8:] # Remove "encoded_" prefix - raise ValueError("Invalid format") - - transformer = FaultyTransformer(pipe1) - - # Set up error handler to verify exception is caught on the transformer itself - error_caught = False - - def on_error(err: Exception) -> None: - nonlocal error_caught - error_caught = True - - transformer.on("error", on_error) - - # Write through transformer - should trigger encoding error - assert await transformer.write("test_data") is False - assert error_caught is True - -if __name__ == "__main__": - pytest.main() From 51b2d0a8a33e9a501b996271939814d3ff684125 Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Wed, 11 Jun 2025 23:42:21 +0545 Subject: [PATCH 05/25] Refactor and structure cbor_rpc --- cbor_rpc/__init__.py | 25 +++++---- cbor_rpc/event/__init__.py | 1 + cbor_rpc/pipe/__init__.py | 2 + cbor_rpc/rpc/rpc_base.py | 2 +- cbor_rpc/rpc/rpc_server.py | 2 +- cbor_rpc/rpc/rpc_v1.py | 2 +- cbor_rpc/{pipe => rpc}/server_base.py | 2 +- cbor_rpc/tcp/tcp.py | 2 +- cbor_rpc/{promise.py => timed_promise.py} | 0 cbor_rpc/transformer/__init__.py | 1 + cbor_rpc/transformer/base/__init__.py | 22 +------- ...ansformer.py => event_transformer_pipe.py} | 2 +- cbor_rpc/transformer/base/transformer_base.py | 36 +++++++++---- ...ync_transformer.py => transformer_pipe.py} | 2 +- cbor_rpc/transformer/json_transformer.py | 51 +++---------------- 15 files changed, 59 insertions(+), 93 deletions(-) rename cbor_rpc/{pipe => rpc}/server_base.py (98%) rename cbor_rpc/{promise.py => timed_promise.py} (100%) rename cbor_rpc/transformer/base/{async_transformer.py => event_transformer_pipe.py} (97%) rename cbor_rpc/transformer/base/{sync_transformer.py => transformer_pipe.py} (96%) diff --git a/cbor_rpc/__init__.py b/cbor_rpc/__init__.py index 08cc010..b68db41 100644 --- a/cbor_rpc/__init__.py +++ b/cbor_rpc/__init__.py @@ -5,35 +5,37 @@ from .event.emitter import AbstractEmitter from .pipe.event_pipe import EventPipe from .transformer import Transformer -from .promise import TimedPromise +from .timed_promise import TimedPromise from .rpc.rpc_base import RpcClient, RpcAuthorizedClient,RpcServer from .rpc.rpc_v1 import RpcV1 from .rpc.rpc_server import RpcV1Server -from .pipe.server_base import Server +from .rpc.server_base import Server from .tcp import TcpPipe, TcpServer from .transformer.json_transformer import JsonTransformer from .pipe.pipe import Pipe __all__ = [ + # Promise + 'TimedPromise', + # Emitter 'AbstractEmitter', - # Pipe classes + + # Pipe abstract classes 'EventPipe', 'Pipe', - 'Transformer', - # Promise - 'TimedPromise', + # Server abstract classes + 'Server', - # Client classes + # Rpc abstract classes 'RpcClient', 'RpcAuthorizedClient', - 'RpcV1', - - # Server classes - 'Server', 'RpcServer', + + # Rpc base implementation + 'RpcV1', 'RpcV1Server', # TCP classes @@ -41,6 +43,7 @@ 'TcpServer', # Transformers + 'Transformer', 'JsonTransformer', ] diff --git a/cbor_rpc/event/__init__.py b/cbor_rpc/event/__init__.py index e69de29..809afca 100644 --- a/cbor_rpc/event/__init__.py +++ b/cbor_rpc/event/__init__.py @@ -0,0 +1 @@ +from .emitter import AbstractEmitter \ No newline at end of file diff --git a/cbor_rpc/pipe/__init__.py b/cbor_rpc/pipe/__init__.py index e69de29..bb4d5ca 100644 --- a/cbor_rpc/pipe/__init__.py +++ b/cbor_rpc/pipe/__init__.py @@ -0,0 +1,2 @@ +from .event_pipe import EventPipe +from .pipe import Pipe diff --git a/cbor_rpc/rpc/rpc_base.py b/cbor_rpc/rpc/rpc_base.py index 71b6b19..efbb895 100644 --- a/cbor_rpc/rpc/rpc_base.py +++ b/cbor_rpc/rpc/rpc_base.py @@ -3,7 +3,7 @@ import asyncio import inspect from ..pipe.event_pipe import EventPipe -from ..promise import TimedPromise +from ..timed_promise import TimedPromise class RpcClient(ABC): diff --git a/cbor_rpc/rpc/rpc_server.py b/cbor_rpc/rpc/rpc_server.py index 55456c2..4786391 100644 --- a/cbor_rpc/rpc/rpc_server.py +++ b/cbor_rpc/rpc/rpc_server.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod import asyncio -from cbor_rpc.pipe.server_base import Server +from cbor_rpc.rpc.server_base import Server from .rpc_base import RpcClient, RpcAuthorizedClient, RpcServer from .rpc_v1 import RpcV1 from cbor_rpc.pipe.event_pipe import EventPipe diff --git a/cbor_rpc/rpc/rpc_v1.py b/cbor_rpc/rpc/rpc_v1.py index 7608efe..8010b27 100644 --- a/cbor_rpc/rpc/rpc_v1.py +++ b/cbor_rpc/rpc/rpc_v1.py @@ -5,7 +5,7 @@ from .rpc_base import RpcClient from cbor_rpc.pipe.event_pipe import EventPipe -from cbor_rpc.promise import TimedPromise +from cbor_rpc.timed_promise import TimedPromise class RpcV1(RpcClient): diff --git a/cbor_rpc/pipe/server_base.py b/cbor_rpc/rpc/server_base.py similarity index 98% rename from cbor_rpc/pipe/server_base.py rename to cbor_rpc/rpc/server_base.py index a6570ed..cfdd412 100644 --- a/cbor_rpc/pipe/server_base.py +++ b/cbor_rpc/rpc/server_base.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod import asyncio from ..event.emitter import AbstractEmitter -from .event_pipe import EventPipe +from ..pipe import EventPipe # Generic type variable for pipe types P = TypeVar('P', bound=EventPipe) diff --git a/cbor_rpc/tcp/tcp.py b/cbor_rpc/tcp/tcp.py index 7f31c1c..5058895 100644 --- a/cbor_rpc/tcp/tcp.py +++ b/cbor_rpc/tcp/tcp.py @@ -3,7 +3,7 @@ import socket from typing import Any, Callable, Optional, Tuple, Union from cbor_rpc.pipe.event_pipe import EventPipe -from cbor_rpc.pipe.server_base import Server +from cbor_rpc.rpc.server_base import Server class TcpPipe(EventPipe[bytes, bytes]): diff --git a/cbor_rpc/promise.py b/cbor_rpc/timed_promise.py similarity index 100% rename from cbor_rpc/promise.py rename to cbor_rpc/timed_promise.py diff --git a/cbor_rpc/transformer/__init__.py b/cbor_rpc/transformer/__init__.py index 0fe31b2..977529a 100644 --- a/cbor_rpc/transformer/__init__.py +++ b/cbor_rpc/transformer/__init__.py @@ -1 +1,2 @@ from .base.transformer_base import Transformer +from .json_transformer import JsonTransformer \ No newline at end of file diff --git a/cbor_rpc/transformer/base/__init__.py b/cbor_rpc/transformer/base/__init__.py index eead791..db7165b 100644 --- a/cbor_rpc/transformer/base/__init__.py +++ b/cbor_rpc/transformer/base/__init__.py @@ -1,20 +1,2 @@ -from typing import overload, Union -from cbor_rpc.pipe.event_pipe import EventPipe -from cbor_rpc.pipe.pipe import Pipe -from cbor_rpc.transformer.base.async_transformer import EventTransformerPipe -from cbor_rpc.transformer.base.sync_transformer import TransformerPipe -from cbor_rpc.transformer.base.transformer_base import Transformer - - -@overload -def applyTransformer(pipe: Pipe, transformer: Transformer) -> TransformerPipe: ... -@overload -def applyTransformer(pipe: EventPipe, transformer: Transformer) -> EventTransformerPipe: ... - -def applyTransformer(pipe: Union[Pipe, EventPipe], transformer: Transformer) -> Union[TransformerPipe, EventTransformerPipe]: - if isinstance(pipe, EventPipe): - return EventTransformerPipe(pipe, transformer) - elif isinstance(pipe, Pipe): - return TransformerPipe(pipe, transformer) - else: - raise TypeError("Invalid pipe type") \ No newline at end of file +from .transformer_base import Transformer,AsyncTransformer +from base_exception import NeedsMoreDataException \ No newline at end of file diff --git a/cbor_rpc/transformer/base/async_transformer.py b/cbor_rpc/transformer/base/event_transformer_pipe.py similarity index 97% rename from cbor_rpc/transformer/base/async_transformer.py rename to cbor_rpc/transformer/base/event_transformer_pipe.py index 7a91bb3..645ed44 100644 --- a/cbor_rpc/transformer/base/async_transformer.py +++ b/cbor_rpc/transformer/base/event_transformer_pipe.py @@ -4,7 +4,7 @@ from .transformer_base import Transformer from typing import Any, Awaitable, Callable, TypeVar -from cbor_rpc.pipe.event_pipe import EventPipe +from cbor_rpc.pipe import EventPipe T1 = TypeVar("T1") # Output type after decoding T2 = TypeVar("T2") # Input type before decoding (pipe input/output type) diff --git a/cbor_rpc/transformer/base/transformer_base.py b/cbor_rpc/transformer/base/transformer_base.py index 31f335d..0529e3a 100644 --- a/cbor_rpc/transformer/base/transformer_base.py +++ b/cbor_rpc/transformer/base/transformer_base.py @@ -1,11 +1,11 @@ from abc import abstractmethod from typing import Any, Generic, TypeVar, Union, overload -from cbor_rpc.pipe.event_pipe import EventPipe -from cbor_rpc.pipe.pipe import Pipe -from cbor_rpc.transformer.base.async_transformer import EventTransformerPipe -from cbor_rpc.transformer.base.base_exception import NeedsMoreDataException -from cbor_rpc.transformer.base.sync_transformer import TransformerPipe +from cbor_rpc.pipe import EventPipe +from cbor_rpc.pipe import Pipe +from .event_transformer_pipe import EventTransformerPipe +from .base_exception import NeedsMoreDataException +from .transformer_pipe import TransformerPipe T1 = TypeVar("T1") T2 = TypeVar("T2") @@ -35,14 +35,28 @@ def bind(self, pipe: Pipe) -> TransformerPipe: ... @overload def bind(self, pipe: EventPipe) -> EventTransformerPipe: ... - def applyTransformer(self, pipe: Union[Pipe, EventPipe], transformer: 'Transformer') -> Union[TransformerPipe, EventTransformerPipe]: + def applyTransformer(self, pipe: Union[Pipe, EventPipe]) -> Union[TransformerPipe, EventTransformerPipe]: if isinstance(pipe, EventPipe): - return EventTransformerPipe(pipe, transformer) + return EventTransformerPipe(pipe, self.to_async()) elif isinstance(pipe, Pipe): - return TransformerPipe(pipe, transformer) + return TransformerPipe(pipe, self.to_async()) else: raise TypeError("Invalid pipe type") + def to_async(self) -> 'AsyncTransformer[T1, T2]': + parent = self + + class WrappedAsyncTransformer(AsyncTransformer[T1, T2]): + async def encode(self, data: T1) -> Any: + return parent.encode(data) + + async def decode(self, data: Any) -> T2: + return parent.decode(data) + + def is_closed(self) -> bool: + return parent.is_closed() + + return WrappedAsyncTransformer() # Async Transformer (async methods) class AsyncTransformer(Generic[T1, T2]): @@ -69,10 +83,10 @@ def bind(self, pipe: Pipe) -> TransformerPipe: ... @overload def bind(self, pipe: EventPipe) -> EventTransformerPipe: ... - def applyTransformer(self, pipe: Union[Pipe, EventPipe], transformer: 'AsyncTransformer') -> Union[TransformerPipe, EventTransformerPipe]: + def applyTransformer(self, pipe: Union[Pipe, EventPipe]) -> Union[TransformerPipe, EventTransformerPipe]: if isinstance(pipe, EventPipe): - return EventTransformerPipe(pipe, transformer) + return EventTransformerPipe(pipe) elif isinstance(pipe, Pipe): - return TransformerPipe(pipe, transformer) + return TransformerPipe(pipe) else: raise TypeError("Invalid pipe type") diff --git a/cbor_rpc/transformer/base/sync_transformer.py b/cbor_rpc/transformer/base/transformer_pipe.py similarity index 96% rename from cbor_rpc/transformer/base/sync_transformer.py rename to cbor_rpc/transformer/base/transformer_pipe.py index 26b9cc6..0561956 100644 --- a/cbor_rpc/transformer/base/sync_transformer.py +++ b/cbor_rpc/transformer/base/transformer_pipe.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING import time -from cbor_rpc.transformer.base.base_exception import NeedsMoreDataException +from .base_exception import NeedsMoreDataException if TYPE_CHECKING: from .transformer_base import Transformer from cbor_rpc.pipe.pipe import Pipe diff --git a/cbor_rpc/transformer/json_transformer.py b/cbor_rpc/transformer/json_transformer.py index 797905f..4931c04 100644 --- a/cbor_rpc/transformer/json_transformer.py +++ b/cbor_rpc/transformer/json_transformer.py @@ -1,58 +1,21 @@ import json from typing import Any, Union -from ..event.emitter import AbstractEmitter -from . import Transformer -from ..pipe.event_pipe import EventPipe +from .base import Transformer class JsonTransformer(Transformer[Any, Any]): """ A transformer that encodes Python objects to JSON strings and decodes JSON strings back to Python objects. """ - def __init__(self, underlying_pipe: EventPipe[Any, Any], encoding: str = 'utf-8'): - """ - Initialize the JSON transformer. - - Args: - underlying_pipe: The underlying pipe to transform - encoding: Text encoding to use (default: 'utf-8') - """ - AbstractEmitter.__init__(self) - Transformer.__init__(self, underlying_pipe) + def __init__(self, encoding: str = 'utf-8'): + super().__init__() self.encoding = encoding - async def encode(self, data: Any) -> bytes: - """ - Encode Python object to JSON bytes. - - Args: - data: Python object to encode - - Returns: - JSON-encoded bytes - - Raises: - TypeError: If data is not JSON serializable - UnicodeEncodeError: If encoding fails - """ - json_str = json.dumps(data, ensure_ascii=False, separators=(',', ':')) + def encode(self, data: Any) -> bytes: + json_str = json.dumps(data, ensure_ascii=False) return json_str.encode(self.encoding) - async def decode(self, data: Union[bytes, str, None]) -> Any: - """ - Decode JSON bytes/string to Python object. - - Args: - data: JSON bytes or string to decode - - Returns: - Decoded Python object - - Raises: - json.JSONDecodeError: If data is not valid JSON - UnicodeDecodeError: If bytes cannot be decoded - TypeError: If data is None or of invalid type - """ + def decode(self, data: Union[bytes, str, None]) -> Any: if data is None: raise TypeError("Expected bytes or str, got None") @@ -63,4 +26,4 @@ async def decode(self, data: Union[bytes, str, None]) -> Any: else: raise TypeError(f"Expected bytes or str, got {type(data)}") - return json.loads(json_str) \ No newline at end of file + return json.loads(json_str) From 05b3cf7c1bdfcad1cdad224533bfcfc0047ffceb Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Thu, 12 Jun 2025 00:01:51 +0545 Subject: [PATCH 06/25] Fix refactoring wip fix json transformer test --- cbor_rpc/__init__.py | 13 +- cbor_rpc/rpc/__init__.py | 7 + cbor_rpc/transformer/__init__.py | 2 +- cbor_rpc/transformer/base/__init__.py | 2 +- cbor_rpc/transformer/json_transformer.py | 2 +- tests/test_json_transformer.py | 441 +++++++++++------------ 6 files changed, 221 insertions(+), 246 deletions(-) diff --git a/cbor_rpc/__init__.py b/cbor_rpc/__init__.py index b68db41..c8e48b4 100644 --- a/cbor_rpc/__init__.py +++ b/cbor_rpc/__init__.py @@ -2,17 +2,12 @@ CBOR-RPC: An async-compatible CBOR-based RPC system """ -from .event.emitter import AbstractEmitter -from .pipe.event_pipe import EventPipe -from .transformer import Transformer +from .event import AbstractEmitter +from .pipe import EventPipe,Pipe from .timed_promise import TimedPromise -from .rpc.rpc_base import RpcClient, RpcAuthorizedClient,RpcServer -from .rpc.rpc_v1 import RpcV1 -from .rpc.rpc_server import RpcV1Server -from .rpc.server_base import Server +from .rpc import RpcClient, RpcAuthorizedClient,RpcServer,RpcV1,RpcV1Server,Server from .tcp import TcpPipe, TcpServer -from .transformer.json_transformer import JsonTransformer -from .pipe.pipe import Pipe +from .transformer import JsonTransformer,Transformer __all__ = [ # Promise diff --git a/cbor_rpc/rpc/__init__.py b/cbor_rpc/rpc/__init__.py index e69de29..b19f2a9 100644 --- a/cbor_rpc/rpc/__init__.py +++ b/cbor_rpc/rpc/__init__.py @@ -0,0 +1,7 @@ +from .server_base import Server + + +from .rpc_base import RpcClient,RpcServer,RpcAuthorizedClient +from .rpc_v1 import RpcV1 +from .rpc_server import RpcV1Server + diff --git a/cbor_rpc/transformer/__init__.py b/cbor_rpc/transformer/__init__.py index 977529a..e2e7e22 100644 --- a/cbor_rpc/transformer/__init__.py +++ b/cbor_rpc/transformer/__init__.py @@ -1,2 +1,2 @@ -from .base.transformer_base import Transformer +from .base import Transformer from .json_transformer import JsonTransformer \ No newline at end of file diff --git a/cbor_rpc/transformer/base/__init__.py b/cbor_rpc/transformer/base/__init__.py index db7165b..7c020f0 100644 --- a/cbor_rpc/transformer/base/__init__.py +++ b/cbor_rpc/transformer/base/__init__.py @@ -1,2 +1,2 @@ from .transformer_base import Transformer,AsyncTransformer -from base_exception import NeedsMoreDataException \ No newline at end of file +from .base_exception import NeedsMoreDataException \ No newline at end of file diff --git a/cbor_rpc/transformer/json_transformer.py b/cbor_rpc/transformer/json_transformer.py index 4931c04..dc6fc53 100644 --- a/cbor_rpc/transformer/json_transformer.py +++ b/cbor_rpc/transformer/json_transformer.py @@ -12,7 +12,7 @@ def __init__(self, encoding: str = 'utf-8'): self.encoding = encoding def encode(self, data: Any) -> bytes: - json_str = json.dumps(data, ensure_ascii=False) + json_str = json.dumps(data, ensure_ascii=(self.encoding == 'ascii')) return json_str.encode(self.encoding) def decode(self, data: Union[bytes, str, None]) -> Any: diff --git a/tests/test_json_transformer.py b/tests/test_json_transformer.py index acff4b1..c185937 100644 --- a/tests/test_json_transformer.py +++ b/tests/test_json_transformer.py @@ -1,238 +1,211 @@ import pytest -import asyncio import json -from typing import Any, Dict, List -from cbor_rpc import JsonTransformer -from cbor_rpc import EventPipe - -@pytest.mark.asyncio -async def test_json_transformer_basic_encoding_decoding(): - """Test basic JSON encoding and decoding.""" - pipe1, pipe2 = EventPipe.create_pair() - transformer = JsonTransformer(pipe1) - - received_data = [] - transformer.on("data", lambda chunk: received_data.append(chunk)) - - # Test encoding: write JSON data through pipe2 (will be received by transformer) - test_data = {"message": "hello", "number": 42, "array": [1, 2, 3]} - json_bytes = json.dumps(test_data).encode('utf-8') - await pipe2.write(json_bytes) - - # Wait for data to be processed - await asyncio.sleep(0.1) - assert test_data == received_data[0] if received_data else None - - # Test decoding: write data through transformer (will be encoded) - received_raw = [] - pipe2.on("data", lambda chunk: received_raw.append(chunk)) - await transformer.write({"response": "world", "success": True}) - - # Read the encoded response from pipe2 - raw_response = received_raw[0] if received_raw else None - decoded_response = json.loads(raw_response.decode('utf-8')) if raw_response else None - assert decoded_response == {"response": "world", "success": True} - -@pytest.mark.asyncio -async def test_json_transformer_create_pair(): - """Test creating a pair of JSON transformers.""" - transformer1, transformer2 = JsonTransformer.create_pair() - - # Test communication from transformer1 to transformer2 - received_data = [] - transformer2.on("data", lambda chunk: received_data.append(chunk)) - test_data = {"message": "Hello World", "timestamp": 1234567890} - await transformer1.write(test_data) - - assert received_data == [test_data] - - # Test communication from transformer2 to transformer1 - received_data.clear() - transformer1.on("data", lambda chunk: received_data.append(chunk)) - response_data = {"reply": "Hello Back", "status": "ok"} - await transformer2.write(response_data) - - assert received_data == [response_data] - -@pytest.mark.asyncio -async def test_json_transformer_different_data_types(): - """Test JSON transformer with different Python data types.""" - transformer1, transformer2 = JsonTransformer.create_pair() - received_data = [] - transformer2.on("data", lambda chunk: received_data.append(chunk)) - - # Test various data types - test_cases = [ - "simple string", - 42, - 3.14159, - True, - False, - None, - [1, 2, 3, "mixed", True], - {"nested": {"object": {"with": "values"}}}, - {"unicode": "Hello äø–ē•Œ šŸŒ"}, - [], - {}, - ] - - for test_data in test_cases: - await transformer1.write(test_data) - - assert received_data == test_cases - -@pytest.mark.asyncio -async def test_json_transformer_encoding_errors(): - """Test JSON transformer encoding error handling.""" - pipe1, pipe2 = EventPipe.create_pair() - transformer = JsonTransformer(pipe1) - errors = [] - transformer.on("error", lambda err: errors.append(str(err))) - - # Test encoding non-serializable object - class NonSerializable: - pass - - result = await transformer.write(NonSerializable()) - assert result is False # Write should return False on error - assert len(errors) == 1 - - # Test encoding circular reference - circular = {} - circular['self'] = circular - result = await transformer.write(circular) - assert result is False - assert len(errors) == 2 - -@pytest.mark.asyncio -async def test_json_transformer_decoding_errors(): - """Test JSON transformer decoding error handling.""" - pipe1, pipe2 = EventPipe.create_pair() - transformer = JsonTransformer(pipe1) - errors = [] - transformer.on("error", lambda err: errors.append(str(err))) - - # Test invalid JSON - await pipe2.write(b'{"invalid": json}') - assert len(errors) == 1 - - # Test invalid UTF-8 bytes - await pipe2.write(b'\xff\xfe\xfd') - assert len(errors) == 2 - - # Test wrong data type - await pipe2.write(123) # Not bytes or string - assert len(errors) == 3 - -@pytest.mark.asyncio -async def test_json_transformer_string_input(): - """Test JSON transformer with string input (not just bytes).""" - pipe1, pipe2 = EventPipe.create_pair() - transformer = JsonTransformer(pipe1) - received_data = [] - transformer.on("data", lambda chunk: received_data.append(chunk)) - - # Test with JSON string input - test_data = {"message": "from string", "value": 123} - json_string = json.dumps(test_data) - await pipe2.write(json_string) - - assert received_data == [test_data] - -@pytest.mark.asyncio -async def test_json_transformer_custom_encoding(): - """Test JSON transformer with custom text encoding.""" - pipe1, pipe2 = EventPipe.create_pair() - transformer = JsonTransformer(pipe1, encoding='latin1') - received_data = [] - transformer.on("data", lambda chunk: received_data.append(chunk)) - - # Test with latin1 encoding - test_data = {"message": "cafĆ©"} # Contains non-ASCII character - json_bytes = json.dumps(test_data).encode('latin1') - await pipe2.write(json_bytes) - - assert received_data == [test_data] - -@pytest.mark.asyncio -async def test_json_transformer_termination(): - """Test JSON transformer termination.""" - pipe1, pipe2 = EventPipe.create_pair() - transformer = JsonTransformer(pipe1) - close_events = [] - transformer.on("close", lambda *args: close_events.append(args)) - - await transformer.terminate("test_reason") - assert len(close_events) == 1 - assert close_events[0] == ("test_reason",) - -@pytest.mark.asyncio -async def test_json_transformer_large_data(): - """Test JSON transformer with large data structures.""" - transformer1, transformer2 = JsonTransformer.create_pair() - received_data = [] - transformer2.on("data", lambda chunk: received_data.append(chunk)) - - # Create a large nested data structure - large_data = { - "users": [ - { - "id": i, - "name": f"User {i}", - "email": f"user{i}@example.com", - "metadata": { - "created": f"2024-01-{i:02d}", - "tags": [f"tag{j}" for j in range(5)], - "settings": { - "theme": "dark" if i % 2 else "light", - "notifications": True, - "features": [f"feature{k}" for k in range(3)] - } - } - } - for i in range(10) - ] - } - - await transformer1.write(large_data) - assert len(received_data) == 1 - assert received_data[0] == large_data - -@pytest.mark.asyncio -async def test_json_transformer_concurrent_operations(): - """Test JSON transformer with concurrent read/write operations.""" - transformer1, transformer2 = JsonTransformer.create_pair() - received_data = [] - transformer2.on("data", lambda chunk: received_data.append(chunk)) - - # Send multiple messages concurrently - async def send_message(id: int): - await transformer1.write({"id": id, "message": f"Message {id}"}) - - tasks = [send_message(i) for i in range(5)] - await asyncio.gather(*tasks) - - assert len(received_data) == 5 +import asyncio +from cbor_rpc.transformer.json_transformer import JsonTransformer +from cbor_rpc.pipe.event_pipe import EventPipe +from cbor_rpc.transformer.base.base_exception import NeedsMoreDataException +from cbor_rpc.transformer.base.event_transformer_pipe import EventTransformerPipe @pytest.mark.asyncio -async def test_json_transformer_error_recovery(): - """Test that JSON transformer can recover from errors.""" - pipe1, pipe2 = EventPipe.create_pair() - transformer = JsonTransformer(pipe1) - received_data = [] - errors = [] - transformer.on("data", lambda chunk: received_data.append(chunk)) - transformer.on("error", lambda err: errors.append(str(err))) - - # Send invalid JSON - await pipe2.write(b'invalid json') - - # Send valid JSON after error - valid_data = {"message": "recovery test"} - await pipe2.write(json.dumps(valid_data).encode('utf-8')) - - assert len(received_data) == 1 - assert received_data[0] == valid_data - -if __name__ == "__main__": - pytest.main(["-v", __file__]) +class TestJsonTransformerPipeInteraction: + + async def test_json_transformer_end_to_end_simple_dict(self): + # Create a pair of event pipes + client_raw_pipe, server_raw_pipe = EventPipe.create_pair() + + # Instantiate the JSON transformer + json_transformer = JsonTransformer() + + # Apply the transformer to the client side of the raw pipe + # This creates an EventTransformerPipe that encodes/decodes data + client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) + assert isinstance(client_transformed_pipe, EventTransformerPipe) + + # Use a queue to capture data emitted by the server_raw_pipe + received_data_queue = asyncio.Queue() + server_raw_pipe.on("data", received_data_queue.put_nowait) + + # Data to send + original_data = {"message": "Hello, world!", "number": 123} + + # Write original_data to the transformed client pipe + # This should encode the data and send it through client_raw_pipe to server_raw_pipe + await client_transformed_pipe.write(original_data) + + # Wait for the encoded data to arrive at the server_raw_pipe + # The server_raw_pipe receives the *encoded* data (bytes) + encoded_data_received_by_server = await received_data_queue.get() + + # Manually decode the data received by the server_raw_pipe to verify it's JSON bytes + decoded_by_server = json.loads(encoded_data_received_by_server.decode('utf-8')) + assert decoded_by_server == original_data + + # Now, let's test the reverse: server sends data, client receives decoded data + # Use a queue to capture data emitted by the client_transformed_pipe + client_received_data_queue = asyncio.Queue() + client_transformed_pipe.on("data", client_received_data_queue.put_nowait) + + # Data to send from server + response_data = {"status": "success", "code": 200} + + # Server_raw_pipe writes the *encoded* data (as if it received it from a client and is sending a response) + # This data will go through client_raw_pipe and then be decoded by client_transformed_pipe + await server_raw_pipe.write(json.dumps(response_data).encode('utf-8')) + + # Wait for the decoded data to arrive at the client_transformed_pipe + decoded_data_received_by_client = await client_received_data_queue.get() + assert decoded_data_received_by_client == response_data + + # Clean up + await client_raw_pipe.terminate() + await server_raw_pipe.terminate() + + async def test_json_transformer_end_to_end_unicode_characters(self): + client_raw_pipe, server_raw_pipe = EventPipe.create_pair() + json_transformer = JsonTransformer() + client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) + + received_data_queue = asyncio.Queue() + server_raw_pipe.on("data", received_data_queue.put_nowait) + + original_data = {"message": "ä½ å„½äø–ē•Œ šŸ‘‹"} + await client_transformed_pipe.write(original_data) + encoded_data_received_by_server = await received_data_queue.get() + decoded_by_server = json.loads(encoded_data_received_by_server.decode('utf-8')) + assert decoded_by_server == original_data + + client_received_data_queue = asyncio.Queue() + client_transformed_pipe.on("data", client_received_data_queue.put_nowait) + response_data = {"greeting": "こんにごは"} + await server_raw_pipe.write(json.dumps(response_data, ensure_ascii=False).encode('utf-8')) + decoded_data_received_by_client = await client_received_data_queue.get() + assert decoded_data_received_by_client == response_data + + await client_raw_pipe.terminate() + await server_raw_pipe.terminate() + + async def test_json_transformer_encoding_error_on_write(self): + client_raw_pipe, server_raw_pipe = EventPipe.create_pair() + # Use an encoding that cannot handle certain characters + json_transformer = JsonTransformer(encoding='ascii') + client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) + + original_data = {"message": "Hello, world! šŸ‘‹"} # Contains non-ASCII character + + # Expect an encoding error when writing + with pytest.raises(UnicodeEncodeError): + await client_transformed_pipe.write(original_data) + + await client_raw_pipe.terminate() + await server_raw_pipe.terminate() + + async def test_json_transformer_decoding_error_on_read(self): + client_raw_pipe, server_raw_pipe = EventPipe.create_pair() + json_transformer = JsonTransformer() + client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) + + # Use a queue to capture errors emitted by the transformed pipe + error_queue = asyncio.Queue() + client_transformed_pipe.on("error", error_queue.put_nowait) + + # Simulate server sending invalid JSON bytes + invalid_json_bytes = b'{"key": "value",}' # Invalid JSON + await server_raw_pipe.write(invalid_json_bytes) + + # The transformed pipe should emit an error when trying to decode + error = await asyncio.wait_for(error_queue.get(), timeout=1) + assert isinstance(error, json.JSONDecodeError) + + await client_raw_pipe.terminate() + await server_raw_pipe.terminate() + + async def test_json_transformer_decoding_type_error_on_read(self): + client_raw_pipe, server_raw_pipe = EventPipe.create_pair() + json_transformer = JsonTransformer() + client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) + + error_queue = asyncio.Queue() + client_transformed_pipe.on("error", error_queue.put_nowait) + + # Simulate server sending non-bytes/str data (e.g., an int) + non_string_data = 12345 + await server_raw_pipe.write(non_string_data) # This will pass through raw pipe as is + + # The transformed pipe should emit a TypeError when trying to decode + error = await asyncio.wait_for(error_queue.get(), timeout=1) + assert isinstance(error, TypeError) + assert "Expected bytes or str" in str(error) + + await client_raw_pipe.terminate() + await server_raw_pipe.terminate() + + async def test_json_transformer_non_json_serializable_data(self): + client_raw_pipe, server_raw_pipe = EventPipe.create_pair() + json_transformer = JsonTransformer() + client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) + + # Data that is not JSON serializable + non_serializable_data = {"set_data": {1, 2, 3}} + + # Writing this data should raise a TypeError + with pytest.raises(TypeError): + await client_transformed_pipe.write(non_serializable_data) + + await client_raw_pipe.terminate() + await server_raw_pipe.terminate() + + async def test_json_transformer_pipe_termination(self): + client_raw_pipe, server_raw_pipe = EventPipe.create_pair() + json_transformer = JsonTransformer() + client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) + + # Listen for close event on the transformed pipe + close_event_received = asyncio.Event() + client_transformed_pipe.on("close", lambda: close_event_received.set()) + + # Terminate the underlying raw pipe + await client_raw_pipe.terminate() + + # The transformed pipe should also terminate and emit a close event + await asyncio.wait_for(close_event_received.wait(), timeout=1) + await server_raw_pipe.terminate() # Ensure the other end is also terminated + + async def test_json_transformer_pipe_write_after_termination(self): + client_raw_pipe, server_raw_pipe = EventPipe.create_pair() + json_transformer = JsonTransformer() + client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) + + await client_raw_pipe.terminate() + + # Writing to a terminated transformed pipe should return False + result = await client_transformed_pipe.write({"test": "data"}) + assert result is False + + await server_raw_pipe.terminate() + + async def test_json_transformer_pipe_read_after_termination(self): + client_raw_pipe, server_raw_pipe = EventPipe.create_pair() + json_transformer = JsonTransformer() + client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) + + # Listen for data on the transformed pipe + data_queue = asyncio.Queue() + client_transformed_pipe.on("data", data_queue.put_nowait) + + # Terminate the server_raw_pipe, which should cause the client_transformed_pipe to terminate + await server_raw_pipe.terminate() + + # The transformed pipe should eventually close and not emit new data + close_event_received = asyncio.Event() + client_transformed_pipe.on("close", lambda: close_event_received.set()) + await asyncio.wait_for(close_event_received.wait(), timeout=1) + + # Try to write to the raw pipe from the server side after termination + # This data should not be processed by the transformed pipe + await server_raw_pipe.write(b'{"should": "not_receive"}') + + # Ensure no data is received by the transformed pipe after termination + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(data_queue.get(), timeout=0.1) + + await client_raw_pipe.terminate() From ff162bba2322eb87e350e67751c3e014b3a9b76e Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Thu, 28 Aug 2025 19:53:21 +0000 Subject: [PATCH 07/25] Fix tcp server tests --- cbor_rpc/pipe/event_pipe.py | 2 +- cbor_rpc/tcp/tcp.py | 8 ++--- cbor_rpc/timed_promise.py | 3 ++ .../base/event_transformer_pipe.py | 5 --- cbor_rpc/transformer/base/transformer_pipe.py | 4 +-- cbor_rpc/transformer/json_transformer.py | 8 +++-- tests/helpers/simple_tcp_server.py | 11 +++++++ tests/test_json_transformer.py | 31 ++++++++++++++----- tests/test_tcp.py | 21 +++++++------ 9 files changed, 61 insertions(+), 32 deletions(-) create mode 100644 tests/helpers/simple_tcp_server.py diff --git a/cbor_rpc/pipe/event_pipe.py b/cbor_rpc/pipe/event_pipe.py index f0f3609..674135e 100644 --- a/cbor_rpc/pipe/event_pipe.py +++ b/cbor_rpc/pipe/event_pipe.py @@ -12,7 +12,7 @@ class EventPipe(AbstractEmitter, Generic[T1, T2]): """ Event Pipe or are event based way for read/write. - You cannot directly read from a Pipe. You have to use a pipeline("data") to register a function to read data. + You cannot directly read from a Pipe. You have to use a pipeline("data") to register one or more functions to read data. """ @abstractmethod async def write(self, chunk: T1) -> bool: diff --git a/cbor_rpc/tcp/tcp.py b/cbor_rpc/tcp/tcp.py index 5058895..2732eec 100644 --- a/cbor_rpc/tcp/tcp.py +++ b/cbor_rpc/tcp/tcp.py @@ -166,16 +166,16 @@ async def _read_loop(self) -> None: break # Emit data event - await self._emit("data", data) + self._emit("data", data) except asyncio.CancelledError: break except Exception as e: - await self._emit("error", e) + self._emit("error", e) break except Exception as e: - await self._emit("error", e) + self._emit("error", e) finally: if not self._closed: await self._close_connection() @@ -248,7 +248,7 @@ async def _close_connection(self, *args: Any) -> None: pass # Ignore errors during cleanup # Emit close event - await self._emit("close", *args) + self._emit("close", *args) def is_connected(self) -> bool: """Check if the TCP connection is active.""" diff --git a/cbor_rpc/timed_promise.py b/cbor_rpc/timed_promise.py index 696f6e2..da54d50 100644 --- a/cbor_rpc/timed_promise.py +++ b/cbor_rpc/timed_promise.py @@ -51,3 +51,6 @@ def _on_timeout(self) -> None: self._future.set_exception(Exception(error_data)) if self._timeout_cb: self._timeout_cb() + + def __await__(self): + return self._future.__await__() diff --git a/cbor_rpc/transformer/base/event_transformer_pipe.py b/cbor_rpc/transformer/base/event_transformer_pipe.py index 645ed44..7defd03 100644 --- a/cbor_rpc/transformer/base/event_transformer_pipe.py +++ b/cbor_rpc/transformer/base/event_transformer_pipe.py @@ -39,8 +39,6 @@ def _on_error(self, error: Exception): self._emit("error", error) async def write(self, chunk: T1) -> bool: - if self._closed: - return False try: encoded = await self.encode(chunk) return await self.pipe.write(encoded) @@ -49,7 +47,4 @@ async def write(self, chunk: T1) -> bool: return False async def terminate(self, *args: Any) -> None: - if self._closed: - return - self._closed = True await self.pipe.terminate(*args) diff --git a/cbor_rpc/transformer/base/transformer_pipe.py b/cbor_rpc/transformer/base/transformer_pipe.py index 0561956..1c35d6a 100644 --- a/cbor_rpc/transformer/base/transformer_pipe.py +++ b/cbor_rpc/transformer/base/transformer_pipe.py @@ -41,8 +41,8 @@ async def write(self, chunk: T1) -> bool: try: encoded = await self.encode(chunk) return self.pipe.write(encoded) - except Exception: - self._emit('error', chunk) + except Exception as e: + self._emit('error', e) return False async def read(self, timeout: Optional[float] = None) -> Optional[T1]: diff --git a/cbor_rpc/transformer/json_transformer.py b/cbor_rpc/transformer/json_transformer.py index dc6fc53..e79f251 100644 --- a/cbor_rpc/transformer/json_transformer.py +++ b/cbor_rpc/transformer/json_transformer.py @@ -12,8 +12,12 @@ def __init__(self, encoding: str = 'utf-8'): self.encoding = encoding def encode(self, data: Any) -> bytes: - json_str = json.dumps(data, ensure_ascii=(self.encoding == 'ascii')) - return json_str.encode(self.encoding) + try: + json_str = json.dumps(data, ensure_ascii=False) # Always allow non-ASCII characters to pass through json.dumps + return json_str.encode(self.encoding) + except Exception as e: + # Removed print statement as it was for debugging + raise # Re-raise to be caught by EventTransformerPipe def decode(self, data: Union[bytes, str, None]) -> Any: if data is None: diff --git a/tests/helpers/simple_tcp_server.py b/tests/helpers/simple_tcp_server.py new file mode 100644 index 0000000..bfbc081 --- /dev/null +++ b/tests/helpers/simple_tcp_server.py @@ -0,0 +1,11 @@ +from cbor_rpc.tcp.tcp import TcpServer, TcpPipe + +class SimpleTcpServer(TcpServer): + """ + A simple TCP server implementation for testing purposes that accepts all connections. + """ + async def accept(self, pipe: TcpPipe) -> bool: + """ + Accepts all incoming TCP connections. + """ + return True diff --git a/tests/test_json_transformer.py b/tests/test_json_transformer.py index c185937..3c885ed 100644 --- a/tests/test_json_transformer.py +++ b/tests/test_json_transformer.py @@ -92,10 +92,17 @@ async def test_json_transformer_encoding_error_on_write(self): original_data = {"message": "Hello, world! šŸ‘‹"} # Contains non-ASCII character - # Expect an encoding error when writing - with pytest.raises(UnicodeEncodeError): - await client_transformed_pipe.write(original_data) + # Use a queue to capture errors emitted by the transformed pipe + error_queue = asyncio.Queue() + client_transformed_pipe.on("error", error_queue.put_nowait) + + # Writing this data should cause an encoding error to be emitted + await client_transformed_pipe.write(original_data) + # Assert that a UnicodeEncodeError is received + error = await asyncio.wait_for(error_queue.get(), timeout=1) + assert isinstance(error, UnicodeEncodeError) + await client_raw_pipe.terminate() await server_raw_pipe.terminate() @@ -147,10 +154,17 @@ async def test_json_transformer_non_json_serializable_data(self): # Data that is not JSON serializable non_serializable_data = {"set_data": {1, 2, 3}} - # Writing this data should raise a TypeError - with pytest.raises(TypeError): - await client_transformed_pipe.write(non_serializable_data) + # Use a queue to capture errors emitted by the transformed pipe + error_queue = asyncio.Queue() + client_transformed_pipe.on("error", error_queue.put_nowait) + + # Writing this data should cause a TypeError to be emitted + await client_transformed_pipe.write(non_serializable_data) + # Assert that a TypeError is received + error = await asyncio.wait_for(error_queue.get(), timeout=1) + assert isinstance(error, TypeError) + await client_raw_pipe.terminate() await server_raw_pipe.terminate() @@ -190,14 +204,15 @@ async def test_json_transformer_pipe_read_after_termination(self): # Listen for data on the transformed pipe data_queue = asyncio.Queue() - client_transformed_pipe.on("data", data_queue.put_nowait) + client_transformed_pipe.pipeline("data", data_queue.put_nowait) # Terminate the server_raw_pipe, which should cause the client_transformed_pipe to terminate - await server_raw_pipe.terminate() # The transformed pipe should eventually close and not emit new data close_event_received = asyncio.Event() client_transformed_pipe.on("close", lambda: close_event_received.set()) + + await server_raw_pipe.terminate() await asyncio.wait_for(close_event_received.wait(), timeout=1) # Try to write to the raw pipe from the server side after termination diff --git a/tests/test_tcp.py b/tests/test_tcp.py index e90123d..255f1e1 100644 --- a/tests/test_tcp.py +++ b/tests/test_tcp.py @@ -1,14 +1,15 @@ import pytest import asyncio from typing import List -from cbor_rpc import TcpPipe, TcpServer +from cbor_rpc import TcpPipe +from tests.helpers.simple_tcp_server import SimpleTcpServer @pytest.mark.asyncio async def test_tcp_client_server_connection(): """Test basic TCP client-server connection.""" # Start a server - server = await TcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create('127.0.0.1', 0) server_host, server_port = server.get_address() connections = [] @@ -45,7 +46,7 @@ def on_connection(tcp_pipe: TcpPipe): @pytest.mark.asyncio async def test_tcp_data_exchange(): """Test bidirectional data exchange over TCP.""" - server = await TcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create('127.0.0.1', 0) server_host, server_port = server.get_address() server_received = [] @@ -128,7 +129,7 @@ async def test_tcp_connection_errors(): await client.write(b"test") # Test double connection - server = await TcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create('127.0.0.1', 0) server_host, server_port = server.get_address() try: @@ -146,7 +147,7 @@ async def test_tcp_connection_errors(): @pytest.mark.asyncio async def test_tcp_connection_events(): """Test TCP connection events (connect, close, error).""" - server = await TcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create('127.0.0.1', 0) server_host, server_port = server.get_address() events = [] @@ -195,7 +196,7 @@ async def on_client_error(error): @pytest.mark.asyncio async def test_tcp_client_connection_tracking(): """Test handling multiple simultaneous TCP connections.""" - server = await TcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create('127.0.0.1', 0) server_host, server_port = server.get_address() @@ -232,7 +233,7 @@ async def test_tcp_client_connection_tracking(): @pytest.mark.asyncio async def test_tcp_client_connection_tracking_self(): """Test handling multiple simultaneous TCP connections.""" - server = await TcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create('127.0.0.1', 0) server_host, server_port = server.get_address() @@ -267,7 +268,7 @@ async def test_tcp_client_connection_tracking_self(): @pytest.mark.asyncio async def test_tcp_large_data_transfer(): """Test transferring large amounts of data over TCP.""" - server = await TcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create('127.0.0.1', 0) server_host, server_port = server.get_address() received_data = bytearray() @@ -311,7 +312,7 @@ async def on_data(data: bytes): @pytest.mark.asyncio async def test_tcp_server_context_manager(): """Test using TcpServer as a context manager.""" - async with await TcpServer.create('127.0.0.1', 0) as server: + async with await SimpleTcpServer.create('127.0.0.1', 0) as server: server_host, server_port = server.get_address() client = await TcpPipe.create_connection(server_host, server_port) @@ -325,7 +326,7 @@ async def test_tcp_server_context_manager(): @pytest.mark.asyncio async def test_tcp_invalid_data_types(): """Test error handling for invalid data types.""" - server = await TcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create('127.0.0.1', 0) server_host, server_port = server.get_address() try: From f13627ebbfb18923da11193aa9ae02bfcd538844 Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Thu, 4 Sep 2025 23:31:30 +0545 Subject: [PATCH 08/25] Adds generic aio_pipe for writer/reader pair --- cbor_rpc/pipe/__init__.py | 1 + cbor_rpc/pipe/aio_pipe.py | 180 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 181 insertions(+) create mode 100644 cbor_rpc/pipe/aio_pipe.py diff --git a/cbor_rpc/pipe/__init__.py b/cbor_rpc/pipe/__init__.py index bb4d5ca..a82c6d3 100644 --- a/cbor_rpc/pipe/__init__.py +++ b/cbor_rpc/pipe/__init__.py @@ -1,2 +1,3 @@ from .event_pipe import EventPipe from .pipe import Pipe +from .aio_pipe import AioPipe diff --git a/cbor_rpc/pipe/aio_pipe.py b/cbor_rpc/pipe/aio_pipe.py new file mode 100644 index 0000000..cb9c051 --- /dev/null +++ b/cbor_rpc/pipe/aio_pipe.py @@ -0,0 +1,180 @@ +from typing import Any, Optional, TypeVar, Generic, Union +import asyncio +from abc import ABC +from .event_pipe import EventPipe # Assuming EventPipe is in a separate module + +# Constrain T1 and T2 to bytes or bytearray for type safety with asyncio streams +T1 = TypeVar('T1', bound=Union[bytes, bytearray]) +T2 = TypeVar('T2', bound=Union[bytes, bytearray]) + +class AioPipe(EventPipe[T1, T2], ABC): + """ + Abstract base class for asynchronous pipes that wrap asyncio.StreamReader and asyncio.StreamWriter. + Extends EventPipe to provide event-based communication for async I/O. + + Attributes: + DEFAULT_READ_CHUNK_SIZE (int): Default size of chunks to read from the stream (8192 bytes). + """ + DEFAULT_READ_CHUNK_SIZE = 8192 + + def __init__(self, + reader: Optional[asyncio.StreamReader] = None, + writer: Optional[asyncio.StreamWriter] = None, + chunk_size: int = DEFAULT_READ_CHUNK_SIZE): + """ + Initialize the AioPipe with optional reader, writer, and chunk size. + + Args: + reader: Optional asyncio.StreamReader for reading data. + writer: Optional asyncio.StreamWriter for writing data. + chunk_size: Size of chunks to read from the stream (default: 8192 bytes). + + Raises: + ValueError: If only one of reader or writer is provided. + """ + super().__init__() + if (reader is None) != (writer is None): # Both must be None or both non-None + raise ValueError("Both reader and writer must be provided or neither") + self._reader = reader + self._writer = writer + self._chunk_size = chunk_size + self._connected = False + self._closed = False + self._read_task: Optional[asyncio.Task] = None + + async def _setup_connection(self) -> None: + """ + Set up the connection and start reading data. + + Raises: + RuntimeError: If reader or writer is not initialized. + """ + if not self._reader or not self._writer: + raise RuntimeError("Reader or writer not initialized") + + self._connected = True + self._closed = False + self._read_task = asyncio.create_task(self._read_loop()) + + try: + await self._notify("connect") + except Exception as e: + self._emit("error", e) # Synchronous _emit + await self._close_connection() + raise + + async def _read_loop(self) -> None: + """ + Continuously read data from the connection and emit data events. + + Stops on EOF, cancellation, or error. Closes the connection in case of errors. + """ + try: + while self._connected and not self._closed and self._reader: + try: + print("AioPipe: Waiting to read data...") + data = await self._reader.read(self._chunk_size) + if not data: # EOF reached + print("AioPipe: EOF reached on read.") + break + try: + await self._notify("data", data) + except Exception as e: + self._emit("error", e) # Synchronous _emit + break + except asyncio.CancelledError: + break + except Exception as e: + self._emit("error", e) # Synchronous _emit + break + except Exception as e: + self._emit("error", e) # Synchronous _emit + finally: + if not self._closed: + await self._close_connection() + + async def _close_connection(self, *args: Any) -> None: + """ + Close the connection and clean up resources. + + Args: + *args: Optional arguments to pass to the 'close' event. + """ + if self._closed: + return + + self._closed = True + self._connected = False + + # Cancel the read task + if self._read_task and not self._read_task.done(): + task, self._read_task = self._read_task, None + try: + task.cancel() + await task + except asyncio.CancelledError: + pass + except Exception as e: + self._emit("error", e) # Synchronous _emit + + # Close the writer + if self._writer: + writer, self._writer = self._writer, None + try: + writer.close() + await writer.wait_closed() + except Exception as e: + self._emit("error", e) # Synchronous _emit + + # Emit close event + try: + self._emit("close", *args) # Synchronous _emit + except Exception as e: + print(f"AioPipe: Failed to emit close event: {e}") + + def is_connected(self) -> bool: + """ + Check if the connection is active. + + Returns: + bool: True if connected and not closed, False otherwise. + """ + return self._connected and not self._closed + + async def write(self, chunk: T1) -> bool: + """ + Write data to the connection. + + Args: + chunk: The data to write (bytes or bytearray). + + Returns: + bool: True if the write was successful, False otherwise. + + Raises: + ConnectionError: If not connected or writer is unavailable. + TypeError: If chunk is not bytes or bytearray. + """ + if not self._connected or self._closed: + raise ConnectionError("Not connected") + if not self._writer: + raise ConnectionError("Writer not available") + if not isinstance(chunk, (bytes, bytearray)): + raise TypeError(f"Expected bytes or bytearray, got {type(chunk).__name__}") + + try: + self._writer.write(chunk) + await self._writer.drain() + return True + except Exception as e: + self._emit("error", e) # Synchronoous _emit + return False + + async def terminate(self, *args: Any) -> None: + """ + Terminate the connection. + + Args: + *args: Optional arguments (e.g., code, reason) to pass to the close event. + """ + await self._close_connection(*args) \ No newline at end of file From a9f269128a98b062524c7606990b886cbbbb9376 Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Thu, 4 Sep 2025 23:33:03 +0545 Subject: [PATCH 09/25] Update tcp pipe to extend AioPipe --- cbor_rpc/tcp/tcp.py | 134 ++++++-------------------------------------- 1 file changed, 17 insertions(+), 117 deletions(-) diff --git a/cbor_rpc/tcp/tcp.py b/cbor_rpc/tcp/tcp.py index 2732eec..ec1954b 100644 --- a/cbor_rpc/tcp/tcp.py +++ b/cbor_rpc/tcp/tcp.py @@ -2,11 +2,11 @@ import asyncio import socket from typing import Any, Callable, Optional, Tuple, Union -from cbor_rpc.pipe.event_pipe import EventPipe +from cbor_rpc.pipe.aio_pipe import AioPipe from cbor_rpc.rpc.server_base import Server -class TcpPipe(EventPipe[bytes, bytes]): +class TcpPipe(AioPipe[bytes, bytes]): """ A TCP duplex pipe that implements Pipe for network communication. Provides both client and server functionality for TCP connections. @@ -14,12 +14,7 @@ class TcpPipe(EventPipe[bytes, bytes]): def __init__(self, reader: Optional[asyncio.StreamReader] = None, writer: Optional[asyncio.StreamWriter] = None): - super().__init__() - self._reader = reader - self._writer = writer - self._connected = False - self._closed = False - self._read_task: Optional[asyncio.Task] = None + super().__init__(reader,writer) @classmethod async def create_connection(cls, host: str, port: int, @@ -140,119 +135,24 @@ async def connect(self, host: str, port: int, timeout: Optional[float] = None) - except Exception as e: raise ConnectionError(f"Failed to connect to {host}:{port}: {e}") - async def _setup_connection(self) -> None: - """Set up the connection and start reading data.""" - if not self._reader or not self._writer: - raise RuntimeError("Reader or writer not initialized") - - self._connected = True - self._closed = False - - # Start the read loop - self._read_task = asyncio.create_task(self._read_loop()) - - # Emit connection event - await self._notify("connect") - - async def _read_loop(self) -> None: - """Continuously read data from the TCP connection and emit data events.""" - try: - while self._connected and not self._closed and self._reader: - try: - # Read data in chunks - data = await self._reader.read(8192) - if not data: - # Connection closed by remote - break - - # Emit data event - self._emit("data", data) - - except asyncio.CancelledError: - break - except Exception as e: - self._emit("error", e) - break - - except Exception as e: - self._emit("error", e) - finally: - if not self._closed: - await self._close_connection() - - async def write(self, chunk: bytes) -> bool: - """ - Write data to the TCP connection. - - Args: - chunk: The bytes to write - - Returns: - True if the write was successful - - Raises: - ConnectionError: If not connected - TypeError: If chunk is not bytes - """ - if not self._connected or self._closed: - raise ConnectionError("Not connected") - - if not isinstance(chunk, (bytes, bytearray)): - raise TypeError("Chunk must be bytes or bytearray") - - if not self._writer: - raise ConnectionError("Writer not available") - - try: - self._writer.write(chunk) - await self._writer.drain() - return True - - except Exception as e: - await self._emit("error", e) - return False - - async def terminate(self, *args: Any) -> None: - """ - Terminate the TCP connection. - - Args: - *args: Optional arguments (code, reason) - """ - await self._close_connection(*args) - async def _close_connection(self, *args: Any) -> None: - """Close the TCP connection and clean up resources.""" - if self._closed: - return - - self._closed = True - self._connected = False - - # Cancel the read task - if self._read_task and not self._read_task.done(): - task, self._read_task = self._read_task, None - if task.cancel(): - try: - await task - except asyncio.CancelledError: - pass - - # Close the writer - if self._writer: - writer, self._writer = self._writer, None + def get_peer_info(self) -> Optional[Tuple[str, int]]: + """Get the remote peer's address and port.""" + if self._writer and self._connected: try: - writer.close() - await writer.wait_closed() + return self._writer.get_extra_info('peername') except Exception: - pass # Ignore errors during cleanup - - # Emit close event - self._emit("close", *args) + pass + return None - def is_connected(self) -> bool: - """Check if the TCP connection is active.""" - return self._connected and not self._closed + def get_local_info(self) -> Optional[Tuple[str, int]]: + """Get the local socket's address and port.""" + if self._writer and self._connected: + try: + return self._writer.get_extra_info('sockname') + except Exception: + pass + return None def get_peer_info(self) -> Optional[Tuple[str, int]]: """Get the remote peer's address and port.""" From e1db0951176defb84a21fbbf2305821f204d165f Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Thu, 4 Sep 2025 23:35:18 +0545 Subject: [PATCH 10/25] Add pipe for stdio communication --- cbor_rpc/stdio/stdio_pipe.py | 72 ++++++++++++++++++++++++++++++ tests/helpers/stdio_test_script.py | 19 ++++++++ tests/test_stdio_rpc.py | 41 +++++++++++++++++ 3 files changed, 132 insertions(+) create mode 100644 cbor_rpc/stdio/stdio_pipe.py create mode 100644 tests/helpers/stdio_test_script.py create mode 100644 tests/test_stdio_rpc.py diff --git a/cbor_rpc/stdio/stdio_pipe.py b/cbor_rpc/stdio/stdio_pipe.py new file mode 100644 index 0000000..5a1fd66 --- /dev/null +++ b/cbor_rpc/stdio/stdio_pipe.py @@ -0,0 +1,72 @@ +import asyncio +import sys +from typing import Any, List, Optional, Tuple, TypeVar, Generic + +from cbor_rpc.pipe.aio_pipe import AioPipe +from cbor_rpc.pipe.pipe import Pipe +from cbor_rpc.event.emitter import AbstractEmitter + +T1 = TypeVar('T1') +T2 = TypeVar('T2') + +class StdioPipe(AioPipe[bytes, bytes]): + """ + A Pipe implementation that works with asyncio.StreamReader and asyncio.StreamWriter + typically obtained from a subprocess's stdin/stdout. + """ + + def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, process: Optional[asyncio.subprocess.Process] = None): + super().__init__(reader,writer) + self._process = process + + async def _setup(self): + await self._setup_connection() + + @classmethod + async def open(cls) -> 'StdioPipe': + """ + Creates a StdioPipe from the process's stdin and stdout. + """ + loop = asyncio.get_running_loop() + reader = asyncio.StreamReader() + protocol = asyncio.StreamReaderProtocol(reader) + await loop.connect_read_pipe(lambda: protocol, sys.stdin) + writer_transport, writer_protocol = await loop.connect_write_pipe(asyncio.streams.FlowControlMixin, sys.stdout) + writer = asyncio.StreamWriter(writer_transport, writer_protocol, reader, loop) + pipe = cls(reader, writer) + await pipe._setup() + return pipe + + @classmethod + async def start_process(cls, *args: str) -> 'StdioPipe': + """ + Starts a process and returns a StdioPipe for it. + """ + process = await asyncio.create_subprocess_exec( + *args, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=sys.stderr + ) + pipe = cls(process.stdout, process.stdin, process) + await pipe._setup() + return pipe + + async def wait_for_process_termination(self) -> int: + """ + Waits for the started subprocess to terminate and returns its exit code. + Raises RuntimeError if no process was started by this pipe. + """ + if not self._process: + raise RuntimeError("No subprocess associated with this StdioPipe instance.") + return await self._process.wait() + + + def terminate(self): + """ + Terminates the started subprocess if one exists. + Raises RuntimeError if no process was started by this pipe. + """ + if not self._process: + raise RuntimeError("No subprocess associated with this StdioPipe instance.") + self._process.terminate() diff --git a/tests/helpers/stdio_test_script.py b/tests/helpers/stdio_test_script.py new file mode 100644 index 0000000..eb2453d --- /dev/null +++ b/tests/helpers/stdio_test_script.py @@ -0,0 +1,19 @@ +import sys + +def main(): + print("stdio_test_script: Started.", file=sys.stderr) + + while True: + print("stdio_test_script: Blocked on read", file=sys.stderr) + data = sys.stdin.read(1024) # Read up to 1024 bytes + print(f"stdio_test_script: Read {len(data)} bytes from stdin.", file=sys.stderr) + sys.stderr.flush() + if not data: # EOF reached + print("stdio_test_script: EOF reached, exiting.", file=sys.stderr) + sys.stderr.flush() + break + sys.stdout.write(data) + sys.stdout.flush() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/test_stdio_rpc.py b/tests/test_stdio_rpc.py new file mode 100644 index 0000000..24e23d7 --- /dev/null +++ b/tests/test_stdio_rpc.py @@ -0,0 +1,41 @@ +import asyncio +import pytest +import sys +from cbor_rpc.stdio.stdio_pipe import StdioPipe + +@pytest.mark.asyncio +async def test_stdtio_read_write(): + """ + Tests the StdioPipe.start_process method by writing and reading data 10 times. + """ + pipe = await StdioPipe.start_process("/bin/bash", "-c", "cat -") + + # List to collect received data + received_data = [] + future = asyncio.Future() + + def on_data(data): + received_data.append(data) + # Complete the future after receiving 10 data events + if len(received_data) == 10: + future.set_result(None) + + pipe.pipeline('data', on_data) + + # Write 10 unique data chunks + test_data = [f"Test data {i}\n".encode('utf-8') for i in range(10)] + for data in test_data: + await pipe.write(data) + # Brief sleep to allow the subprocess to process the data + await asyncio.sleep(0.01) + + # Wait for all 10 data events + await future + + # Assert that all received data matches the sent data + assert len(received_data) == 10, f"Expected 10 data events, got {len(received_data)}" + for i, (sent, received) in enumerate(zip(test_data, received_data)): + assert received == sent, f"Mismatch at index {i}: expected {sent!r}, got {received!r}" + + # Terminate the pipe + pipe.terminate() \ No newline at end of file From 0b61a7b6f0f8a53b8b2516ec113fd7461bc7370b Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Thu, 4 Sep 2025 23:36:15 +0545 Subject: [PATCH 11/25] Use stderr for warning logging --- cbor_rpc/rpc/rpc_v1.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/cbor_rpc/rpc/rpc_v1.py b/cbor_rpc/rpc/rpc_v1.py index 8010b27..29c0212 100644 --- a/cbor_rpc/rpc/rpc_v1.py +++ b/cbor_rpc/rpc/rpc_v1.py @@ -1,3 +1,4 @@ +import sys from typing import Any, Dict, List, Optional, Callable from abc import ABC, abstractmethod import asyncio @@ -25,13 +26,14 @@ async def resolve_result(result: Any) -> Any: async def on_data(data: List[Any]) -> None: try: if not isinstance(data, list) or len(data) != 5: - print(f"RpcV1: Invalid message format: {data}") + print(f"RpcV1: Invalid message format: {data}",file=sys.stderr) return version, direction, id_, method, params = data if version != 1: - print(f"RpcV1: Unsupported version: {data}") + print(f"RpcV1: Unsupported version: {data}",file=sys.stderr) return + print("RpvV1: Received", data,file=sys.stderr) if direction < 2: # Method call (0) or fire (1) try: @@ -48,7 +50,7 @@ async def handle_response(): if direction == 0: await self.pipe.write([1, 2, id_, False, str(e)]) else: - print(f"Fired method error: {method}, params={params}, error={e}") + print(f"Fired method error: {method}, params={params}, error={e}",file=sys.stderr) # Create task to handle response asyncio.create_task(handle_response()) @@ -57,7 +59,7 @@ async def handle_response(): if direction == 0: asyncio.create_task(self.pipe.write([1, 2, id_, False, str(e)])) else: - print(f"Fired method error: {method}, params={params}, error={e}") + print(f"Fired method error: {method}, params={params}, error={e}",file=sys.stderr) elif direction == 2: # Response promise = self._promises.pop(id_, None) @@ -67,15 +69,15 @@ async def handle_response(): else: # Error await promise.reject(params) else: - print(f"Received rpc reply for expired request id: {id_}, success={method}, data={params}") + print(f"Received rpc reply for expired request id: {id_}, success={method}, data={params}",file=sys.stderr) elif direction == 3: # Event await self._on_event(method, params) else: - print(f"RpcV1: Invalid direction: {direction}") + print(f"RpcV1: Invalid direction: {direction}",file=sys.stderr) except Exception as e: - print(f"Error processing RPC message: {e}") + print(f"Error processing RPC message: {e}",file=sys.stderr) self.pipe.on("data", on_data) @@ -158,5 +160,5 @@ def method_handler(method: str, args: List[Any]) -> Any: raise Exception("Client Only Implementation") async def event_handler(topic: str, message: Any) -> None: - print(f"Rpc Event dropped {topic} {message}") + print(f"Rpc Event dropped {topic} {message}",file=sys.stderr) return RpcV1.make_rpc_v1(pipe, '', method_handler, event_handler) From 81b150a1127b4562fb88e965c72f1e274b5ee915 Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Sat, 20 Sep 2025 00:49:09 +0545 Subject: [PATCH 12/25] Working implementation of ssh-pipe --- cbor_rpc/pipe/aio_pipe.py | 8 +- cbor_rpc/rpc/server_base.py | 4 +- cbor_rpc/ssh/__init__.py | 1 + cbor_rpc/ssh/ssh_pipe.py | 115 +++++++ pyproject.toml | 3 +- tests/docker/sshd-python/Dockerfile | 13 + tests/docker/sshd-python/binary_emitter.py | 11 + tests/docker/sshd-python/echo_back.py | 28 ++ tests/test_ssh_docker_pipe.py | 345 +++++++++++++++++++++ 9 files changed, 520 insertions(+), 8 deletions(-) create mode 100644 cbor_rpc/ssh/__init__.py create mode 100644 cbor_rpc/ssh/ssh_pipe.py create mode 100644 tests/docker/sshd-python/Dockerfile create mode 100644 tests/docker/sshd-python/binary_emitter.py create mode 100644 tests/docker/sshd-python/echo_back.py create mode 100644 tests/test_ssh_docker_pipe.py diff --git a/cbor_rpc/pipe/aio_pipe.py b/cbor_rpc/pipe/aio_pipe.py index cb9c051..16bb935 100644 --- a/cbor_rpc/pipe/aio_pipe.py +++ b/cbor_rpc/pipe/aio_pipe.py @@ -72,10 +72,8 @@ async def _read_loop(self) -> None: try: while self._connected and not self._closed and self._reader: try: - print("AioPipe: Waiting to read data...") data = await self._reader.read(self._chunk_size) if not data: # EOF reached - print("AioPipe: EOF reached on read.") break try: await self._notify("data", data) @@ -84,10 +82,10 @@ async def _read_loop(self) -> None: break except asyncio.CancelledError: break - except Exception as e: + except Exception as e: # Catch BaseException for GeneratorExit/other BaseExceptions self._emit("error", e) # Synchronous _emit break - except Exception as e: + except Exception as e: # Catch BaseException for GeneratorExit/other BaseExceptions self._emit("error", e) # Synchronous _emit finally: if not self._closed: @@ -177,4 +175,4 @@ async def terminate(self, *args: Any) -> None: Args: *args: Optional arguments (e.g., code, reason) to pass to the close event. """ - await self._close_connection(*args) \ No newline at end of file + await self._close_connection(*args) diff --git a/cbor_rpc/rpc/server_base.py b/cbor_rpc/rpc/server_base.py index cfdd412..e0c8eeb 100644 --- a/cbor_rpc/rpc/server_base.py +++ b/cbor_rpc/rpc/server_base.py @@ -46,8 +46,8 @@ async def _add_connection(self, pipe: P) -> None: Args: pipe: The pipe representing the connection """ - if not self.accept(pipe): - pipe.terminate() + if not await self.accept(pipe): + await pipe.terminate() return self._connections.add(pipe) # Set up cleanup when connection closes diff --git a/cbor_rpc/ssh/__init__.py b/cbor_rpc/ssh/__init__.py new file mode 100644 index 0000000..edfc2c3 --- /dev/null +++ b/cbor_rpc/ssh/__init__.py @@ -0,0 +1 @@ +from .ssh_pipe import SshPipe diff --git a/cbor_rpc/ssh/ssh_pipe.py b/cbor_rpc/ssh/ssh_pipe.py new file mode 100644 index 0000000..357a104 --- /dev/null +++ b/cbor_rpc/ssh/ssh_pipe.py @@ -0,0 +1,115 @@ +import asyncio +import asyncssh +from typing import Optional, Tuple, Union + +from cbor_rpc.pipe.aio_pipe import AioPipe + +class SshPipe(AioPipe[bytes, bytes]): + """ + A Pipe implementation that works over an SSH connection. + It uses asyncssh to establish and manage the SSH session and channels. + """ + + def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, + ssh_client: asyncssh.SSHClientConnection, + ssh_channel: asyncssh.SSHClientChannel): + super().__init__(reader, writer) + self._ssh_client = ssh_client + self._ssh_channel = ssh_channel + + + @classmethod + async def connect(cls, + host: str, + port: int = 22, + username: str = 'root', + password: Optional[str] = None, + ssh_key_content: Optional[str] = None, + ssh_key_passphrase: Optional[str] = None, + known_hosts: Optional[Union[str, list]] = None, + timeout: Optional[float] = None, + command: str = 'sh -l') -> 'SshPipe': + """ + Establishes an SSH connection and opens a session channel, returning an SshPipe. + + Args: + host: The hostname or IP address of the SSH server. + port: The port number for the SSH connection (default: 22). + username: The username for SSH authentication. + password: The password for password-based authentication (optional). + ssh_key_content: The content of an SSH private key for key-based authentication (optional). + known_hosts: Path to a known_hosts file or a list of host keys (optional). + If None, host key checking is disabled (use with caution). + timeout: Optional timeout for the connection attempt. + command: The command to execute on the remote host to establish the pipe (default: 'sh -l'). + + Returns: + An SshPipe instance connected over SSH. + + Raises: + asyncssh.Error: If the SSH connection or channel establishment fails. + asyncio.TimeoutError: If the connection times out. + """ + client_keys = None + if ssh_key_content: + client_keys = [asyncssh.read_private_key(ssh_key_content, passphrase=ssh_key_passphrase)] + + try: + conn = await asyncio.wait_for( + asyncssh.connect( + host, + port=port, + options=asyncssh.SSHClientConnectionOptions( + username=username, + password=password, + client_keys=client_keys, + passphrase=ssh_key_passphrase, # Passphrase for encrypted client keys + ignore_encrypted=False # Do not ignore encrypted keys + ), + known_hosts=known_hosts, + ), + timeout=timeout + ) + + # Create a process on the SSH connection to get stdin/stdout streams. + # Use the provided 'command' argument. + # Explicitly set encoding=None to ensure raw bytes are handled. + # Set term_type=None and stdin=asyncssh.PIPE to ensure stdin is always a writable stream. + channel = await conn.create_process(command, term_type=None, encoding=None, stdin=asyncssh.PIPE) + + reader = channel.stdout + writer = channel.stdin + + pipe = cls(reader, writer, conn, channel) + await pipe._setup_connection() + return pipe + + except asyncssh.Error as e: + # asyncssh.Error expects a 'reason' argument + raise asyncssh.Error(str(e), reason=str(e)) + except asyncio.TimeoutError: + raise asyncio.TimeoutError(f"SSH connection to {host}:{port} timed out.") + + async def terminate(self) -> None: + """ + Closes the SSH channel and the underlying SSH connection. + """ + # Close the SSH channel and client connection + if self._ssh_channel and not self._ssh_channel.is_closing(): + self._ssh_channel.close() + await self._ssh_channel.wait_closed() + if self._ssh_client and not self._ssh_client.is_closed(): + self._ssh_client.close() + await self._ssh_client.wait_closed() + + # Terminate the AioPipe, which handles reader/writer tasks + await super().terminate() + + async def write_eof(self) -> None: + """ + Signals the end of the write stream to the remote process. + For SSH, this means closing the stdin channel. + """ + if self._writer: + self._writer.close() + await self._writer.wait_closed() diff --git a/pyproject.toml b/pyproject.toml index 1a422e5..92a85dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,8 @@ readme = "README.md" requires-python = ">=3.8" dependencies = [ "pytest>=8.3.2", - "pytest-asyncio>=0.24.0" + "pytest-asyncio>=0.24.0", + "asyncssh>=2.14.0" ] [build-system] diff --git a/tests/docker/sshd-python/Dockerfile b/tests/docker/sshd-python/Dockerfile new file mode 100644 index 0000000..28aed83 --- /dev/null +++ b/tests/docker/sshd-python/Dockerfile @@ -0,0 +1,13 @@ +# Use the linuxserver/openssh-server as base image +FROM linuxserver/openssh-server + +# Install python3 using apk +RUN apk add --no-cache python3 + +# Copy the binary_emitter.py script into the container +COPY binary_emitter.py /usr/local/bin/binary_emitter.py +RUN chmod +x /usr/local/bin/binary_emitter.py + +# Copy the echo_back.py script into the container +COPY echo_back.py /usr/local/bin/echo_back.py +RUN chmod +x /usr/local/bin/echo_back.py diff --git a/tests/docker/sshd-python/binary_emitter.py b/tests/docker/sshd-python/binary_emitter.py new file mode 100644 index 0000000..6f4e6b7 --- /dev/null +++ b/tests/docker/sshd-python/binary_emitter.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python3 +import os +import time +import binascii + +hex_data = "DEADBEEF0001020380FF7F" # Removed \x0A\x0D (LF and CR) +test_data = binascii.unhexlify(hex_data) + +while True: + os.write(1, test_data) + time.sleep(0.1) diff --git a/tests/docker/sshd-python/echo_back.py b/tests/docker/sshd-python/echo_back.py new file mode 100644 index 0000000..890afd8 --- /dev/null +++ b/tests/docker/sshd-python/echo_back.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 +import sys +import os + +# Read from stdin and write to stdout +# Ensure binary mode for consistent byte handling +stdin_fd = sys.stdin.buffer.fileno() +stdout_fd = sys.stdout.buffer.fileno() + +while True: + try: + # Read a chunk of data from stdin + data = os.read(stdin_fd, 4096) # Read up to 4096 bytes + if not data: + # EOF reached on stdin + break + # Write the data to stdout + os.write(stdout_fd, data) + # Explicitly flush stdout to ensure data is sent immediately + sys.stdout.buffer.flush() + except BrokenPipeError: + # stdout pipe was closed by the reader + break + except Exception as e: + # Log any other errors to stderr + sys.stderr.write(f"Error in echo_back.py: {e}\n".encode('utf-8')) + sys.stderr.flush() + break diff --git a/tests/test_ssh_docker_pipe.py b/tests/test_ssh_docker_pipe.py new file mode 100644 index 0000000..edded2a --- /dev/null +++ b/tests/test_ssh_docker_pipe.py @@ -0,0 +1,345 @@ +import asyncio +import pytest +import docker +import time +import asyncssh +import os +import re + +from cbor_rpc.ssh.ssh_pipe import SshPipe + +# Define a test user and password for the SSHD container +TEST_SSH_USER = "testuser" +TEST_SSH_PASSWORD = "testpassword" +SSHD_IMAGE_NAME = "cbor-rpc-py-sshd-python" # Custom image name +SSHD_CONTAINER_NAME = "test-sshd-container" +SSHD_DOCKERFILE_PATH = "./tests/docker/sshd-python" # Path to the Dockerfile + +@pytest.fixture(scope="session") +def docker_client(): + """Provides a Docker client instance.""" + client = docker.from_env() + yield client + client.close() + +@pytest.fixture(scope="session") +async def sshd_container(docker_client: docker.DockerClient, docker_host_ip): + """ + Starts an SSHD Docker container, configures it, and yields its connection details. + """ + container = None + host_port = None + + # Ensure previous container is stopped and removed + try: + existing_container = docker_client.containers.get(SSHD_CONTAINER_NAME) + print(f"Found existing container '{SSHD_CONTAINER_NAME}'. Stopping and removing...") + existing_container.stop() + existing_container.remove() + print(f"Removed existing container '{SSHD_CONTAINER_NAME}'.") + except docker.errors.NotFound: + print(f"No existing container '{SSHD_CONTAINER_NAME}' found. Proceeding.") + except Exception as e: + print(f"Error cleaning up existing container: {e}") + # Do not raise, try to proceed with run + + # Build the custom Docker image + print(f"\nBuilding custom Docker image '{SSHD_IMAGE_NAME}' from {SSHD_DOCKERFILE_PATH}...") + try: + # Use path to Dockerfile directory as context + image, build_logs = docker_client.images.build( + path=SSHD_DOCKERFILE_PATH, + tag=SSHD_IMAGE_NAME, + rm=True # Remove intermediate containers + ) + for chunk in build_logs: + if 'stream' in chunk: + print(chunk['stream'], end='') + print(f"Successfully built custom image '{SSHD_IMAGE_NAME}'.") + except docker.errors.BuildError as e: + print(f"Error building Docker image: {e}") + for line in e.build_log: + if 'stream' in line: + print(line['stream'], end='') + raise + except Exception as e: + print(f"An unexpected error occurred during Docker image build: {e}") + raise + + print(f"Starting {SSHD_CONTAINER_NAME} container...") + try: + container = docker_client.containers.run( + SSHD_IMAGE_NAME, # Use the custom image + detach=True, + ports={'2222/tcp': None}, # Map container's SSH port (2222) to a random host port + name=SSHD_CONTAINER_NAME, + environment={ + "PUID": os.getuid(), + "PGID": os.getgid(), + "USER_PASSWORD": TEST_SSH_PASSWORD, # Correct variable for password + "USER_NAME": TEST_SSH_USER, # Correct variable for username + "PUBLIC_KEY": "", # No public key for password auth + "TZ": "UTC", + "SUDO_ACCESS": "true", # Allow sudo for testuser if needed + "EXPOSE_SSH_PORT": "2222", # Explicitly expose port 2222 + "PASSWORD_ACCESS": "true" # Enable password authentication + }, + restart_policy={"Name": "no"} + ) + + # Wait for port mapping to be available + max_retries = 10 + for attempt in range(max_retries): + container.reload() + if '2222/tcp' in container.ports and container.ports['2222/tcp']: + host_port = container.ports['2222/tcp'][0]['HostPort'] + print(f"Port 2222/tcp mapped to host port {host_port} after {attempt+1} attempts.") + break + print(f"Attempt {attempt+1}/{max_retries}: Port 2222/tcp not yet mapped. Retrying...") + time.sleep(1) + else: + print(f"Error: Port 2222/tcp not mapped on host after {max_retries} attempts. Container ports: {container.ports}") + print("Container logs:") + print(container.logs().decode('utf-8')) + raise RuntimeError("Failed to map SSHD port.") + + print(f"SSHD container running on {docker_host_ip}:{host_port}") + print(f"SSHD container running on {docker_host_ip}:{host_port}") + + # Wait for SSHD to be ready + ready = False + for i in range(60): # Wait up to 60 seconds + try: + # Try to connect using asyncssh to check if the server is up + # Use a short timeout for the readiness check + conn = await asyncio.wait_for( + asyncssh.connect( + docker_host_ip, + port=int(host_port), + options=asyncssh.SSHClientConnectionOptions( + username=TEST_SSH_USER, + password=TEST_SSH_PASSWORD, + known_hosts=None # Disable host key checking for test container + ) + ), + timeout=5 + ) + conn.close() # Close the temporary connection used for readiness check + await conn.wait_closed() + print(f"SSHD is ready after {i+1} seconds.") + ready = True + break + except (asyncssh.Error, asyncio.TimeoutError, ConnectionRefusedError) as e: + # print(f"SSHD not ready yet: {e}") # Uncomment for verbose debugging + pass + time.sleep(1) + + if not ready: + print("\nSSHD container did not become ready in time. Container logs:") + if container: + print(container.logs().decode('utf-8')) + raise RuntimeError("SSHD container did not become ready in time.") + + # Python3 is now pre-installed in the custom Docker image, so no runtime installation needed. + print("Python3 is pre-installed in the custom Docker image.") + + yield docker_host_ip, int(host_port), TEST_SSH_USER, TEST_SSH_PASSWORD + + finally: + if container: + print(f"Stopping and removing {SSHD_CONTAINER_NAME}...") + container.stop() + container.remove() + print(f"Removed {SSHD_CONTAINER_NAME}.") + +@pytest.fixture(scope="session") +def docker_host_ip(): + """Determines the Docker host IP for connecting to containers.""" + docker_host = os.environ.get("DOCKER_HOST") + + # Regex to match tcp://hostname:port, unix://socket, or ip:port + regex = r'^(?:(tcp|unix)://)?([a-zA-Z0-9.-]+)(?::\d+)?$' + + if docker_host: + match = re.match(regex, docker_host) + if match: + protocol, host = match.groups() + if protocol == "unix": + # For unix sockets, connections are typically local, but asyncssh needs an IP + # In this case, 'localhost' is usually appropriate for host-to-container communication + return "localhost" + return host # Return the IP or hostname + return "localhost" # Default for local Docker setup + +@pytest.mark.asyncio +async def test_ssh_pipe_with_hello_world_emitter(sshd_container): + """ + Tests SshPipe by connecting to an SSHD container that emits "hello world" every second. + """ + host, port, username, password = sshd_container + + print(f"\nAttempting SshPipe connection to {host}:{port} as user {username}...") + pipe = None + try: + emitter_command = "sh -c 'while true; do echo \"hello world\"; sleep 1; done'" + pipe = await SshPipe.connect( + host=host, + port=port, + username=username, + password=password, + known_hosts=None, + timeout=10, + command=emitter_command + ) + + received_event = asyncio.Event() + received_data = [] + + def on_data_callback(data): + print(f"test_ssh_pipe_with_hello_world_emitter: Received data: {data!r}") + received_data.append(data) + # The data should be bytes, so compare directly with bytes literal + if b"hello world" in data: + received_event.set() + + pipe.pipeline('data', on_data_callback) + + try: + await asyncio.wait_for(received_event.wait(), timeout=10) + except asyncio.TimeoutError: + pytest.fail("Did not receive 'hello world' within 10 seconds.") + + full_received_data = b"".join(received_data) + assert b"hello world" in full_received_data + + finally: + if pipe: + await pipe.terminate() + print("SshPipe closed.") + + +@pytest.mark.asyncio +async def test_ssh_pipe_with_echo_back_command(sshd_container): + """ + Tests SshPipe by connecting to an SSHD container and running 'echo_back.py' to echo back input. + """ + host, port, username, password = sshd_container + + print(f"\nAttempting SshPipe connection to {host}:{port} as user {username} for echo-back test (using echo_back.py)...") + pipe = None + try: + # Use the custom Python echo-back script + echo_back_command = "python3 /usr/local/bin/echo_back.py" + pipe = await SshPipe.connect( + host=host, + port=port, + username=username, + password=password, + known_hosts=None, # Disable host key checking for test container + timeout=10, + command=echo_back_command + ) + + received_data_chunks = [] + received_event = asyncio.Event() + + def on_data_callback(data): + print(f"test_ssh_pipe_with_echo_back_command: Received data chunk: {data!r}") + received_data_chunks.append(data) + # For echo-back, we expect the full message to be returned + # We'll set the event once we receive some data. + received_event.set() + + pipe.pipeline('data', on_data_callback) + + test_message = b"This is a test message for echo back\n" + print(f"Writing data to pipe: {test_message!r}") + await pipe.write(test_message) + await asyncio.sleep(1) + + try: + await asyncio.wait_for(received_event.wait(), timeout=10) + except asyncio.TimeoutError: + pytest.fail("Did not receive any data within 10 seconds.") + + full_received_data = b"".join(received_data_chunks) + print(f"Received data from pipe: {full_received_data!r}") + + # The Python echo_back.py script should echo back exactly what it receives. + # Strip any potential carriage returns or newlines added by the shell/terminal. + assert full_received_data.strip() == test_message.strip(), \ + f"Received data {full_received_data!r} should exactly match sent data {test_message!r}" + print("Verification successful: Data echoed correctly by 'echo_back.py' script.") + await pipe.write_eof() # Signal EOF to the remote process + + except asyncssh.Error as e: + pytest.fail(f"SSH connection or command failed: {e}") + except asyncio.TimeoutError: + pytest.fail(f"SSH connection timed out.") + except Exception as e: + pytest.fail(f"An unexpected error occurred: {e}") + finally: + if pipe: + await pipe.terminate() + print("SshPipe closed.") +@pytest.mark.asyncio +async def test_ssh_pipe_with_binary_data(sshd_container): + """ + Tests SshPipe by connecting to an SSHD container that emits raw binary data. + """ + host, port, username, password = sshd_container + + print(f"\nAttempting SshPipe connection to {host}:{port} as user {username} for binary data emitter test...") + pipe = None + try: + # Python script to continuously emit binary data + # Using os.write(1, ...) to write raw bytes to stdout + emitter_binary_command = ( + "python3 /usr/local/bin/binary_emitter.py" + ) + pipe = await SshPipe.connect( + host=host, + port=port, + username=username, + password=password, + known_hosts=None, # Disable host key checking for test container + timeout=10, + command=emitter_binary_command + ) + + received_event = asyncio.Event() + received_data_chunks = [] + expected_binary_pattern = b'\xDE\xAD\xBE\xEF\x00\x01\x02\x03\x80\xFF\x7F' # Updated pattern without newlines + + def on_data_callback(data): + print(f"test_ssh_pipe_with_binary_data: Received data chunk: {data!r}") + # Strip any potential carriage returns or newlines added by the shell + cleaned_data = data.replace(b'\r', b'').replace(b'\n', b'') + received_data_chunks.append(cleaned_data) + if expected_binary_pattern in cleaned_data: + received_event.set() + + pipe.pipeline('data', on_data_callback) + + try: + await asyncio.wait_for(received_event.wait(), timeout=10) + except asyncio.TimeoutError: + pytest.fail("Did not receive expected binary pattern within 10 seconds.") + + full_received_data = b"".join(received_data_chunks) + print(f"Received total {len(full_received_data)} bytes. First 50 bytes: {full_received_data[:50]!r}") + + assert expected_binary_pattern in full_received_data, \ + f"Expected binary pattern {expected_binary_pattern!r} not found in received data {full_received_data!r}" + print("Verification successful: Binary data emitted and received correctly.") + + except asyncssh.Error as e: + pytest.fail(f"SSH connection or command failed: {e}") + except asyncio.TimeoutError: + pytest.fail(f"SSH connection timed out.") + except Exception as e: + pytest.fail(f"An unexpected error occurred: {e}") + finally: + if pipe: + await pipe.terminate() + print("SshPipe closed.") From b1dcb61c10d690874666e4f88c49c12c02f64099 Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Sat, 20 Sep 2025 21:04:30 +0545 Subject: [PATCH 13/25] Make ssh pipe work with all authentication methods. --- cbor_rpc/ssh/ssh_pipe.py | 2 +- pyproject.toml | 3 +- tests/test_ssh_docker_pipe.py | 384 +++++++++++++++++++++++++--------- 3 files changed, 287 insertions(+), 102 deletions(-) diff --git a/cbor_rpc/ssh/ssh_pipe.py b/cbor_rpc/ssh/ssh_pipe.py index 357a104..81911a9 100644 --- a/cbor_rpc/ssh/ssh_pipe.py +++ b/cbor_rpc/ssh/ssh_pipe.py @@ -52,7 +52,7 @@ async def connect(cls, """ client_keys = None if ssh_key_content: - client_keys = [asyncssh.read_private_key(ssh_key_content, passphrase=ssh_key_passphrase)] + client_keys = [asyncssh.import_private_key(ssh_key_content, passphrase=ssh_key_passphrase)] try: conn = await asyncio.wait_for( diff --git a/pyproject.toml b/pyproject.toml index 92a85dc..88db9b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,8 @@ requires-python = ">=3.8" dependencies = [ "pytest>=8.3.2", "pytest-asyncio>=0.24.0", - "asyncssh>=2.14.0" + "asyncssh>=2.14.0", + "bcrypt" ] [build-system] diff --git a/tests/test_ssh_docker_pipe.py b/tests/test_ssh_docker_pipe.py index edded2a..93d55f6 100644 --- a/tests/test_ssh_docker_pipe.py +++ b/tests/test_ssh_docker_pipe.py @@ -5,6 +5,7 @@ import asyncssh import os import re +import asyncssh.public_key from cbor_rpc.ssh.ssh_pipe import SshPipe @@ -15,6 +16,20 @@ SSHD_CONTAINER_NAME = "test-sshd-container" SSHD_DOCKERFILE_PATH = "./tests/docker/sshd-python" # Path to the Dockerfile +@pytest.fixture(scope="session") +def ssh_keys(): + """Generates SSH keys (plain and encrypted) and a passphrase for testing.""" + private_key_obj = asyncssh.generate_private_key('ssh-rsa') + passphrase = "test_passphrase" + + return { + "unencrypted_private": private_key_obj.export_private_key().decode(), + "unencrypted_public": private_key_obj.export_public_key().decode(), + "encrypted_private": private_key_obj.export_private_key(passphrase=passphrase).decode(), + "encrypted_public": private_key_obj.export_public_key().decode(), + "passphrase": passphrase + } + @pytest.fixture(scope="session") def docker_client(): """Provides a Docker client instance.""" @@ -23,134 +38,135 @@ def docker_client(): client.close() @pytest.fixture(scope="session") -async def sshd_container(docker_client: docker.DockerClient, docker_host_ip): - """ - Starts an SSHD Docker container, configures it, and yields its connection details. - """ - container = None - host_port = None +def test_network(docker_client: docker.DockerClient): + """Provides a Docker network for containers to communicate.""" + network_name = "test-ssh-network" + try: + network = docker_client.networks.get(network_name) + network.remove() # Clean up existing network if it exists + except docker.errors.NotFound: + pass + + network = docker_client.networks.create(network_name, driver="bridge") + yield network + network.remove() + +@pytest.fixture(scope="module") # Changed scope to module as requested +async def ssh_container_combined_auth(docker_client: docker.DockerClient, test_network, docker_host_ip, ssh_keys): + container_name = "ssh-test-container-combined-auth" + ssh_user = TEST_SSH_USER + ssh_password = TEST_SSH_PASSWORD + public_key = ssh_keys["unencrypted_public"] # Use the unencrypted public key # Ensure previous container is stopped and removed try: - existing_container = docker_client.containers.get(SSHD_CONTAINER_NAME) - print(f"Found existing container '{SSHD_CONTAINER_NAME}'. Stopping and removing...") + existing_container = docker_client.containers.get(container_name) + print(f"Found existing container '{container_name}'. Stopping and removing...") existing_container.stop() existing_container.remove() - print(f"Removed existing container '{SSHD_CONTAINER_NAME}'.") + print(f"Removed existing container '{container_name}'.") except docker.errors.NotFound: - print(f"No existing container '{SSHD_CONTAINER_NAME}' found. Proceeding.") + print(f"No existing container '{container_name}' found. Proceeding.") except Exception as e: print(f"Error cleaning up existing container: {e}") - # Do not raise, try to proceed with run # Build the custom Docker image - print(f"\nBuilding custom Docker image '{SSHD_IMAGE_NAME}' from {SSHD_DOCKERFILE_PATH}...") + print(f"\nBuilding Docker image '{SSHD_IMAGE_NAME}' from '{SSHD_DOCKERFILE_PATH}'...") try: - # Use path to Dockerfile directory as context - image, build_logs = docker_client.images.build( + docker_client.images.build( path=SSHD_DOCKERFILE_PATH, tag=SSHD_IMAGE_NAME, rm=True # Remove intermediate containers ) - for chunk in build_logs: - if 'stream' in chunk: - print(chunk['stream'], end='') - print(f"Successfully built custom image '{SSHD_IMAGE_NAME}'.") + print(f"Docker image '{SSHD_IMAGE_NAME}' built successfully.") except docker.errors.BuildError as e: - print(f"Error building Docker image: {e}") - for line in e.build_log: - if 'stream' in line: - print(line['stream'], end='') - raise - except Exception as e: - print(f"An unexpected error occurred during Docker image build: {e}") - raise + print(f"Failed to build Docker image: {e}") + raise RuntimeError(f"Failed to build Docker image '{SSHD_IMAGE_NAME}'.") - print(f"Starting {SSHD_CONTAINER_NAME} container...") + print(f"\nStarting {SSHD_IMAGE_NAME} container with combined auth...") + container = None try: container = docker_client.containers.run( - SSHD_IMAGE_NAME, # Use the custom image + SSHD_IMAGE_NAME, # Use the custom image name detach=True, - ports={'2222/tcp': None}, # Map container's SSH port (2222) to a random host port - name=SSHD_CONTAINER_NAME, + ports={'2222/tcp': None}, # Map container SSH port to a random host port + network=test_network.name, + name=container_name, environment={ - "PUID": os.getuid(), - "PGID": os.getgid(), - "USER_PASSWORD": TEST_SSH_PASSWORD, # Correct variable for password - "USER_NAME": TEST_SSH_USER, # Correct variable for username - "PUBLIC_KEY": "", # No public key for password auth - "TZ": "UTC", - "SUDO_ACCESS": "true", # Allow sudo for testuser if needed - "EXPOSE_SSH_PORT": "2222", # Explicitly expose port 2222 - "PASSWORD_ACCESS": "true" # Enable password authentication + "PUID": "1000", + "PGID": "1000", + "TZ": "Etc/UTC", + "PASSWORD_ACCESS": "true", # Password access enabled + "USER_NAME": ssh_user, + "USER_PASSWORD": ssh_password, + "PUBLIC_KEY": public_key # Add one public key }, restart_policy={"Name": "no"} ) - - # Wait for port mapping to be available - max_retries = 10 - for attempt in range(max_retries): + + container.reload() + host_port = None + for _ in range(30): # Wait up to 30 seconds for port mapping container.reload() if '2222/tcp' in container.ports and container.ports['2222/tcp']: host_port = container.ports['2222/tcp'][0]['HostPort'] - print(f"Port 2222/tcp mapped to host port {host_port} after {attempt+1} attempts.") break - print(f"Attempt {attempt+1}/{max_retries}: Port 2222/tcp not yet mapped. Retrying...") time.sleep(1) - else: - print(f"Error: Port 2222/tcp not mapped on host after {max_retries} attempts. Container ports: {container.ports}") - print("Container logs:") - print(container.logs().decode('utf-8')) - raise RuntimeError("Failed to map SSHD port.") - - print(f"SSHD container running on {docker_host_ip}:{host_port}") - print(f"SSHD container running on {docker_host_ip}:{host_port}") - # Wait for SSHD to be ready + if host_port is None: + raise RuntimeError("Failed to get host port for SSH container within 30 seconds.") + + print(f"SSH container with combined auth running on host port: {host_port}") + + # Wait for SSH server to be ready ready = False - for i in range(60): # Wait up to 60 seconds + for i in range(60): # wait up to 60 seconds try: - # Try to connect using asyncssh to check if the server is up - # Use a short timeout for the readiness check - conn = await asyncio.wait_for( - asyncssh.connect( - docker_host_ip, - port=int(host_port), - options=asyncssh.SSHClientConnectionOptions( - username=TEST_SSH_USER, - password=TEST_SSH_PASSWORD, - known_hosts=None # Disable host key checking for test container - ) - ), - timeout=5 - ) - conn.close() # Close the temporary connection used for readiness check - await conn.wait_closed() - print(f"SSHD is ready after {i+1} seconds.") - ready = True - break - except (asyncssh.Error, asyncio.TimeoutError, ConnectionRefusedError) as e: - # print(f"SSHD not ready yet: {e}") # Uncomment for verbose debugging + async def check_ssh_combined(): + # Check password authentication + try: + conn_pw = await asyncssh.connect(docker_host_ip, port=int(host_port), username=ssh_user, password=ssh_password, known_hosts=None) + conn_pw.close() + print("Password auth check successful.") + except (asyncssh.Error, OSError) as e: + print(f"Password auth check failed: {e}") + return False + + # Check public key authentication (unencrypted) + try: + conn_key = await asyncssh.connect(docker_host_ip, port=int(host_port), username=ssh_user, client_keys=[asyncssh.import_private_key(ssh_keys["unencrypted_private"])], known_hosts=None) + conn_key.close() + print("Public key auth check successful.") + except (asyncssh.Error, OSError) as e: + print(f"Public key auth check failed: {e}") + return False + + return True + + if await check_ssh_combined(): + print(f"SSH server with combined auth is ready after {i+1} seconds.") + ready = True + break + except Exception as e: + print(f"Error during SSH readiness check: {e}") pass time.sleep(1) - + if not ready: - print("\nSSHD container did not become ready in time. Container logs:") + print("\nSSH server with combined auth did not become ready in time. Container logs:") if container: print(container.logs().decode('utf-8')) - raise RuntimeError("SSHD container did not become ready in time.") - - # Python3 is now pre-installed in the custom Docker image, so no runtime installation needed. - print("Python3 is pre-installed in the custom Docker image.") - - yield docker_host_ip, int(host_port), TEST_SSH_USER, TEST_SSH_PASSWORD - + raise RuntimeError("SSH server with combined auth did not become ready in time.") + + yield container, docker_host_ip, host_port, ssh_user, ssh_password, ssh_keys["unencrypted_private"] finally: if container: - print(f"Stopping and removing {SSHD_CONTAINER_NAME}...") + print("Stopping and removing ssh-test-container-combined-auth...") + print("=========================== Container Logs Start ===========================") + print(container.logs().decode('utf-8')) + print("=========================== Container Logs End ===========================") container.stop() container.remove() - print(f"Removed {SSHD_CONTAINER_NAME}.") @pytest.fixture(scope="session") def docker_host_ip(): @@ -172,13 +188,14 @@ def docker_host_ip(): return "localhost" # Default for local Docker setup @pytest.mark.asyncio -async def test_ssh_pipe_with_hello_world_emitter(sshd_container): +async def test_ssh_pipe_with_hello_world_emitter(ssh_container_combined_auth): """ Tests SshPipe by connecting to an SSHD container that emits "hello world" every second. + This test uses password authentication. """ - host, port, username, password = sshd_container + container, host, port, username, password, _ = ssh_container_combined_auth - print(f"\nAttempting SshPipe connection to {host}:{port} as user {username}...") + print(f"\nAttempting SshPipe connection to {host}:{port} as user {username} with password authentication...") pipe = None try: emitter_command = "sh -c 'while true; do echo \"hello world\"; sleep 1; done'" @@ -219,13 +236,183 @@ def on_data_callback(data): @pytest.mark.asyncio -async def test_ssh_pipe_with_echo_back_command(sshd_container): +async def test_ssh_pipe_with_password_authentication(ssh_container_combined_auth): + """ + Tests SshPipe using username and password authentication. + This is a dedicated test for password authentication, ensuring it works as expected. + """ + container, host, port, username, password, _ = ssh_container_combined_auth + + print(f"\nAttempting SshPipe connection to {host}:{port} as user {username} with password authentication...") + pipe = None + try: + # Use a simple command to verify connection, e.g., 'echo' + test_command = "echo 'Password auth successful!'" + pipe = await SshPipe.connect( + host=host, + port=port, + username=username, + password=password, + known_hosts=None, + timeout=10, + command=test_command + ) + + received_event = asyncio.Event() + received_data = [] + + def on_data_callback(data): + print(f"test_ssh_pipe_with_password_authentication: Received data: {data!r}") + received_data.append(data) + if b"Password auth successful!" in data: + received_event.set() + + pipe.pipeline('data', on_data_callback) + + try: + await asyncio.wait_for(received_event.wait(), timeout=10) + except asyncio.TimeoutError: + pytest.fail("Did not receive expected output within 10 seconds for password authentication test.") + + full_received_data = b"".join(received_data) + assert b"Password auth successful!" in full_received_data + print("Password authentication successful.") + + except asyncssh.Error as e: + pytest.fail(f"SSH connection or command failed with password authentication: {e}") + except asyncio.TimeoutError: + pytest.fail(f"SSH connection timed out with password authentication.") + except Exception as e: + pytest.fail(f"An unexpected error occurred during password authentication test: {e}") + finally: + if pipe: + await pipe.terminate() + print("SshPipe closed for password authentication test.") + + +@pytest.mark.asyncio +async def test_ssh_pipe_with_plain_key_authentication(ssh_container_combined_auth, ssh_keys): + """ + Tests SshPipe using a plain (unencrypted) SSH key for authentication. + This test is designed to run when sshd_container is configured for plain key auth. + """ + container, host, port, username, _, unencrypted_private_key_content = ssh_container_combined_auth + + print(f"\nAttempting SshPipe connection to {host}:{port} as user {username} with plain key authentication...") + + pipe = None + try: + test_command = "echo 'Plain key auth successful!'" + pipe = await SshPipe.connect( + host=host, + port=port, + username=username, + ssh_key_content=unencrypted_private_key_content, # Use the private key content for authentication + known_hosts=None, + timeout=10, + command=test_command + ) + + received_event = asyncio.Event() + received_data = [] + + def on_data_callback(data): + print(f"test_ssh_pipe_with_plain_key_authentication: Received data: {data!r}") + received_data.append(data) + if b"Plain key auth successful!" in data: + received_event.set() + + pipe.pipeline('data', on_data_callback) + + try: + await asyncio.wait_for(received_event.wait(), timeout=10) + except asyncio.TimeoutError: + pytest.fail("Did not receive expected output within 10 seconds for plain key authentication test.") + + full_received_data = b"".join(received_data) + assert b"Plain key auth successful!" in full_received_data + print("Plain key authentication successful.") + + except asyncssh.Error as e: + pytest.fail(f"SSH connection or command failed with plain key authentication: {e}") + except asyncio.TimeoutError: + pytest.fail(f"SSH connection timed out with plain key authentication.") + except Exception as e: + pytest.fail(f"An unexpected error occurred during plain key authentication test: {e}") + finally: + if pipe: + await pipe.terminate() + print("SshPipe closed for plain key authentication test.") + + +@pytest.mark.asyncio +async def test_ssh_pipe_with_encrypted_key_authentication(ssh_container_combined_auth, ssh_keys): + """ + Tests SshPipe using an encrypted SSH key with a passphrase for authentication. + This test is designed to run when sshd_container is configured for encrypted key auth. + """ + container, host, port, username, _, _ = ssh_container_combined_auth + + private_key_content = ssh_keys['encrypted_private'] + passphrase = ssh_keys['passphrase'] + + print(f"\nAttempting SshPipe connection to {host}:{port} as user {username} with encrypted key authentication...") + + pipe = None + try: + test_command = "echo 'Encrypted key auth successful!'" + pipe = await SshPipe.connect( + host=host, + port=port, + username=username, + ssh_key_content=private_key_content, # Use the private key content for authentication + ssh_key_passphrase=passphrase, # Pass the passphrase + known_hosts=None, + timeout=10, + command=test_command + ) + + received_event = asyncio.Event() + received_data = [] + + def on_data_callback(data): + print(f"test_ssh_pipe_with_encrypted_key_authentication: Received data: {data!r}") + received_data.append(data) + if b"Encrypted key auth successful!" in data: + received_event.set() + + pipe.pipeline('data', on_data_callback) + + try: + await asyncio.wait_for(received_event.wait(), timeout=10) + except asyncio.TimeoutError: + pytest.fail("Did not receive expected output within 10 seconds for encrypted key authentication test.") + + full_received_data = b"".join(received_data) + assert b"Encrypted key auth successful!" in full_received_data + print("Encrypted key authentication successful.") + + except asyncssh.Error as e: + pytest.fail(f"SSH connection or command failed with encrypted key authentication: {e}") + except asyncio.TimeoutError: + pytest.fail(f"SSH connection timed out.") + except Exception as e: + pytest.fail(f"An unexpected error occurred during encrypted key authentication test: {e}") + finally: + if pipe: + await pipe.terminate() + print("SshPipe closed for encrypted key authentication test.") + + +@pytest.mark.asyncio +async def test_ssh_pipe_with_echo_back_command(ssh_container_combined_auth): """ Tests SshPipe by connecting to an SSHD container and running 'echo_back.py' to echo back input. + This test uses password authentication. """ - host, port, username, password = sshd_container + container, host, port, username, password, _ = ssh_container_combined_auth - print(f"\nAttempting SshPipe connection to {host}:{port} as user {username} for echo-back test (using echo_back.py)...") + print(f"\nAttempting SshPipe connection to {host}:{port} as user {username} for echo-back test (using echo_back.py) with password authentication...") pipe = None try: # Use the custom Python echo-back script @@ -263,10 +450,6 @@ def on_data_callback(data): pytest.fail("Did not receive any data within 10 seconds.") full_received_data = b"".join(received_data_chunks) - print(f"Received data from pipe: {full_received_data!r}") - - # The Python echo_back.py script should echo back exactly what it receives. - # Strip any potential carriage returns or newlines added by the shell/terminal. assert full_received_data.strip() == test_message.strip(), \ f"Received data {full_received_data!r} should exactly match sent data {test_message!r}" print("Verification successful: Data echoed correctly by 'echo_back.py' script.") @@ -283,13 +466,14 @@ def on_data_callback(data): await pipe.terminate() print("SshPipe closed.") @pytest.mark.asyncio -async def test_ssh_pipe_with_binary_data(sshd_container): +async def test_ssh_pipe_with_binary_data(ssh_container_combined_auth): """ Tests SshPipe by connecting to an SSHD container that emits raw binary data. + This test uses password authentication. """ - host, port, username, password = sshd_container + container, host, port, username, password, _ = ssh_container_combined_auth - print(f"\nAttempting SshPipe connection to {host}:{port} as user {username} for binary data emitter test...") + print(f"\nAttempting SshPipe connection to {host}:{port} as user {username} for binary data emitter test with password authentication...") pipe = None try: # Python script to continuously emit binary data From 33b081302472b89735e5252f264040644d804d29 Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Fri, 6 Feb 2026 22:47:26 +0545 Subject: [PATCH 14/25] Add cbor transformer --- cbor_rpc/event/emitter.py | 1 + cbor_rpc/pipe/event_pipe.py | 2 +- cbor_rpc/pipe/pipe.py | 2 + cbor_rpc/rpc/rpc_server.py | 1 + cbor_rpc/tcp/tcp.py | 28 +- cbor_rpc/transformer/__init__.py | 3 +- .../base/event_transformer_pipe.py | 8 +- cbor_rpc/transformer/base/transformer_base.py | 12 +- cbor_rpc/transformer/cbor_transformer.py | 104 ++++++ pyproject.toml | 3 +- tests/helpers/timeout_queue.py | 13 + tests/test_cbor_transformer.py | 322 ++++++++++++++++++ tests/test_event_pipe.py | 3 +- tests/test_json_transformer.py | 175 ++++++---- tests/test_tcp.py | 79 +++++ 15 files changed, 667 insertions(+), 89 deletions(-) create mode 100644 cbor_rpc/transformer/cbor_transformer.py create mode 100644 tests/helpers/timeout_queue.py create mode 100644 tests/test_cbor_transformer.py diff --git a/cbor_rpc/event/emitter.py b/cbor_rpc/event/emitter.py index 89c59b5..37538cb 100644 --- a/cbor_rpc/event/emitter.py +++ b/cbor_rpc/event/emitter.py @@ -59,6 +59,7 @@ async def _notify(self, event_type: str, *args: Any) -> None: for result in results: if isinstance(result, Exception): self._emit("error", result) + # Must raise error to notify caller and block normal execution raise result self._emit(event_type, *args) diff --git a/cbor_rpc/pipe/event_pipe.py b/cbor_rpc/pipe/event_pipe.py index 674135e..11577e5 100644 --- a/cbor_rpc/pipe/event_pipe.py +++ b/cbor_rpc/pipe/event_pipe.py @@ -23,7 +23,7 @@ async def terminate(self, *args: Any) -> None: pass @staticmethod - def create_pair() -> Tuple['EventPipe[Any, Any]', 'EventPipe[Any, Any]']: + def create_inmemory_pair() -> Tuple['EventPipe[Any, Any]', 'EventPipe[Any, Any]']: """ Create a pair of connected pipes for bidirectional communication. diff --git a/cbor_rpc/pipe/pipe.py b/cbor_rpc/pipe/pipe.py index 4cca8d3..659fe67 100644 --- a/cbor_rpc/pipe/pipe.py +++ b/cbor_rpc/pipe/pipe.py @@ -92,6 +92,7 @@ async def _pump(self): while not self._closed: chunk = await parent.read() if chunk is None: + print("PipeToEvent: Received termination signal, terminating event pipe.") await self.terminate() break await self._notify("data", chunk) @@ -103,6 +104,7 @@ async def terminate(self, *args: Any) -> None: if self._closed: return self._closed = True + print("PipeToEvent: Terminating event pipe.") await parent.terminate(*args) self._emit("close", *args) if self._pump_task: diff --git a/cbor_rpc/rpc/rpc_server.py b/cbor_rpc/rpc/rpc_server.py index 4786391..f778ec5 100644 --- a/cbor_rpc/rpc/rpc_server.py +++ b/cbor_rpc/rpc/rpc_server.py @@ -37,6 +37,7 @@ async def cleanup(*args): async def disconnect(self, connection_id: str, reason: Optional[str] = None) -> None: client = self.active_connections.pop(connection_id, None) if client: + print("RpcV1Server: Disconnecting client:", connection_id) await client.pipe.terminate(1000, reason or "Server terminated connection") def set_timeout(self, milliseconds: int) -> None: diff --git a/cbor_rpc/tcp/tcp.py b/cbor_rpc/tcp/tcp.py index ec1954b..3174ad4 100644 --- a/cbor_rpc/tcp/tcp.py +++ b/cbor_rpc/tcp/tcp.py @@ -67,15 +67,18 @@ async def create_server(cls, host: str = '0.0.0.0', port: int = 0, return await TcpServer.create(host, port, backlog) @staticmethod - async def create_pair() -> Tuple['TcpPipe', 'TcpPipe']: + async def create_inmemory_pair() -> Tuple['TcpPipe', 'TcpPipe']: """ Create a pair of connected TCP pipes using a local server. Returns: A tuple of (client_pipe, server_pipe) connected via TCP """ + class SimpleTcpServer(TcpServer): + async def accept(self, pipe: TcpPipe) -> bool: + return True # Create a temporary server - server = await TcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create('127.0.0.1', 0) host, port = server.get_address() # Set up to capture the server-side connection @@ -96,8 +99,8 @@ async def on_connection(pipe: TcpPipe): # Wait for server connection await connection_ready.wait() - # Close the server but keep the connections - await server.stop() + # Stop accepting new connections but keep the active pipes + await server.shutdown() return client_pipe, server_pipe @@ -199,17 +202,17 @@ async def create(cls, host: str = '0.0.0.0', port: int = 0, """ tcp_server = cls.__new__(cls) Server.__init__(tcp_server) - - async def client_connected_cb(reader: asyncio.StreamReader, + + async def client_connected_cb(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: tcp_pipe = TcpPipe(reader, writer) await tcp_pipe._setup_connection() await tcp_server._add_connection(tcp_pipe) - + server = await asyncio.start_server( client_connected_cb, host, port, backlog=backlog ) - + tcp_server._server = server tcp_server._running = True return tcp_server @@ -249,6 +252,15 @@ async def stop(self) -> None: if self._server: self._server.close() await self._server.wait_closed() + + async def shutdown(self) -> None: + """Stop accepting new connections while keeping active connections open.""" + if not self._server: + return + + self._running = False + self._server.close() + await self._server.wait_closed() def get_address(self) -> Tuple[str, int]: """Get the server's listening address and port.""" diff --git a/cbor_rpc/transformer/__init__.py b/cbor_rpc/transformer/__init__.py index e2e7e22..3ba6de9 100644 --- a/cbor_rpc/transformer/__init__.py +++ b/cbor_rpc/transformer/__init__.py @@ -1,2 +1,3 @@ from .base import Transformer -from .json_transformer import JsonTransformer \ No newline at end of file +from .json_transformer import JsonTransformer +from .cbor_transformer import CborTransformer, CborStreamTransformer diff --git a/cbor_rpc/transformer/base/event_transformer_pipe.py b/cbor_rpc/transformer/base/event_transformer_pipe.py index 7defd03..1700a57 100644 --- a/cbor_rpc/transformer/base/event_transformer_pipe.py +++ b/cbor_rpc/transformer/base/event_transformer_pipe.py @@ -5,6 +5,7 @@ from typing import Any, Awaitable, Callable, TypeVar from cbor_rpc.pipe import EventPipe +from .base_exception import NeedsMoreDataException T1 = TypeVar("T1") # Output type after decoding T2 = TypeVar("T2") # Input type before decoding (pipe input/output type) @@ -29,8 +30,13 @@ async def _handle_data(self, data: T2): try: decoded = await self.decode(data) self._emit("data", decoded) + except NeedsMoreDataException: + # If more data is needed, simply return and wait for the next chunk + return except Exception as e: - self._emit("error", e) + # Let the exception propagate up to AbstractEmitter._notify, + # which will catch it and emit the "error" event. + raise e def _on_close(self, *args: Any): self._emit("close", *args) diff --git a/cbor_rpc/transformer/base/transformer_base.py b/cbor_rpc/transformer/base/transformer_base.py index 0529e3a..c84562d 100644 --- a/cbor_rpc/transformer/base/transformer_base.py +++ b/cbor_rpc/transformer/base/transformer_base.py @@ -35,6 +35,10 @@ def bind(self, pipe: Pipe) -> TransformerPipe: ... @overload def bind(self, pipe: EventPipe) -> EventTransformerPipe: ... + @overload + def applyTransformer(self, pipe: Pipe) -> TransformerPipe: ... + @overload + def applyTransformer(self, pipe: EventPipe) -> EventTransformerPipe: ... def applyTransformer(self, pipe: Union[Pipe, EventPipe]) -> Union[TransformerPipe, EventTransformerPipe]: if isinstance(pipe, EventPipe): return EventTransformerPipe(pipe, self.to_async()) @@ -83,10 +87,14 @@ def bind(self, pipe: Pipe) -> TransformerPipe: ... @overload def bind(self, pipe: EventPipe) -> EventTransformerPipe: ... + @overload + def applyTransformer(self, pipe: Pipe) -> TransformerPipe: ... + @overload + def applyTransformer(self, pipe: EventPipe) -> EventTransformerPipe: ... def applyTransformer(self, pipe: Union[Pipe, EventPipe]) -> Union[TransformerPipe, EventTransformerPipe]: if isinstance(pipe, EventPipe): - return EventTransformerPipe(pipe) + return EventTransformerPipe(pipe, self) elif isinstance(pipe, Pipe): - return TransformerPipe(pipe) + return TransformerPipe(pipe, self) else: raise TypeError("Invalid pipe type") diff --git a/cbor_rpc/transformer/cbor_transformer.py b/cbor_rpc/transformer/cbor_transformer.py new file mode 100644 index 0000000..b8a4235 --- /dev/null +++ b/cbor_rpc/transformer/cbor_transformer.py @@ -0,0 +1,104 @@ +import cbor2 +from io import BytesIO +from typing import Any, Union +from .base import Transformer, AsyncTransformer +from cbor_rpc.transformer.base.base_exception import NeedsMoreDataException + +class CborTransformer(Transformer[Any, Any]): + """ + A transformer that encodes Python objects to CBOR bytes and decodes CBOR bytes back to Python objects. + """ + + def __init__(self): + super().__init__() + + def encode(self, data: Any) -> bytes: + try: + return cbor2.dumps(data) + except Exception: + raise + + def decode(self, data: Union[bytes, None]) -> Any: + if data is None: + raise TypeError("Expected bytes, got None") + + if not isinstance(data, bytes): + raise TypeError(f"Expected bytes, got {type(data)}") + + try: + # cbor2.loads can decode a single CBOR object from bytes + return cbor2.loads(data) + except cbor2.CBORDecodeEOF as e: + # For a non-stream transformer, incomplete data is a decoding error + raise cbor2.CBORDecodeError("Incomplete CBOR data for non-stream transformer") from e + except cbor2.CBORDecodeError: + # Re-raise other decoding errors + raise + +class CborStreamTransformer(AsyncTransformer[Any, Any]): + """ + An async transformer that decodes a stream of concatenated CBOR objects. + This is similar to how a JSON stream decoder would work, reading one object at a time. + """ + def __init__(self): + super().__init__() + self._buffer = bytearray() + + async def encode(self, data: Any) -> bytes: + try: + return cbor2.dumps(data) + except Exception: + raise + + async def decode(self, data: Union[bytes, None]) -> Any: + if data is not None: + if not isinstance(data, bytes): + raise TypeError(f"Expected bytes or None, got {type(data)}") + self._buffer.extend(data) + + if not self._buffer: + raise NeedsMoreDataException() + + try: + stream = BytesIO(self._buffer) + decoder = cbor2.CBORDecoder(stream) + decoded_data = decoder.decode() + + bytes_consumed = stream.tell() + self._buffer = self._buffer[bytes_consumed:] + + return decoded_data + except cbor2.CBORDecodeEOF: + raise NeedsMoreDataException() + except cbor2.CBORDecodeError as e: + original_exception = e + # Discard bytes from the buffer until a valid CBOR object can be decoded or the buffer is exhausted. + # This loop attempts to find the start of the next valid CBOR object. + while self._buffer: + # Discard one byte and try again + self._buffer = self._buffer[1:] + if not self._buffer: + break # Buffer is empty, cannot recover further + + try: + stream = BytesIO(self._buffer) + decoder = cbor2.CBORDecoder(stream) + decoded_data = decoder.decode() + + bytes_consumed = stream.tell() + self._buffer = self._buffer[bytes_consumed:] + + return decoded_data # Successfully decoded a new object + except cbor2.CBORDecodeEOF: + # If we need more data after discarding some, it means we might be in the middle of a valid object + raise NeedsMoreDataException() + except cbor2.CBORDecodeError: + # Still an error after discarding one byte, continue the loop to discard another + continue + # If we reach here, the buffer is exhausted or we couldn't recover. + # Re-raise the original exception as no valid CBOR object could be found. + raise original_exception + except Exception as e: + # For other unexpected errors, clear the buffer and re-raise + self._buffer = bytearray() + raise e diff --git a/pyproject.toml b/pyproject.toml index 88db9b4..911dfc8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,8 @@ dependencies = [ "pytest>=8.3.2", "pytest-asyncio>=0.24.0", "asyncssh>=2.14.0", - "bcrypt" + "bcrypt", + "cbor2" ] [build-system] diff --git a/tests/helpers/timeout_queue.py b/tests/helpers/timeout_queue.py new file mode 100644 index 0000000..7c35e44 --- /dev/null +++ b/tests/helpers/timeout_queue.py @@ -0,0 +1,13 @@ +import asyncio +from typing import Any, Optional + + +class TimeoutQueue(asyncio.Queue): + def __init__(self, *args: Any, default_timeout: float = 1.0, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._default_timeout = default_timeout + + async def get(self, timeout: Optional[float] = None) -> Any: + if timeout is None: + timeout = self._default_timeout + return await asyncio.wait_for(super().get(), timeout=timeout) diff --git a/tests/test_cbor_transformer.py b/tests/test_cbor_transformer.py new file mode 100644 index 0000000..e5991c0 --- /dev/null +++ b/tests/test_cbor_transformer.py @@ -0,0 +1,322 @@ +import pytest +import pytest_asyncio +import asyncio +import cbor2 +from cbor_rpc.transformer.cbor_transformer import CborTransformer, CborStreamTransformer +from cbor_rpc.pipe.event_pipe import EventPipe +from cbor_rpc.transformer.base.base_exception import NeedsMoreDataException +from cbor_rpc.transformer.base.event_transformer_pipe import EventTransformerPipe +from tests.helpers.timeout_queue import TimeoutQueue + + +@pytest_asyncio.fixture +async def _raw_pipe_pair(): + client_raw_pipe, server_raw_pipe = EventPipe.create_inmemory_pair() + yield client_raw_pipe, server_raw_pipe + await client_raw_pipe.terminate() + await server_raw_pipe.terminate() + + +@pytest.fixture +def client_raw(_raw_pipe_pair): + client_raw_pipe, _ = _raw_pipe_pair + return client_raw_pipe + + +@pytest.fixture +def server_raw(_raw_pipe_pair): + _, server_raw_pipe = _raw_pipe_pair + return server_raw_pipe + + +@pytest.fixture +def client_cbor(client_raw): + cbor_transformer = CborTransformer() + return cbor_transformer.applyTransformer(client_raw) + +@pytest.mark.asyncio +class TestCborTransformer: + + async def test_cbor_transformer_end_to_end_simple_dict(self, client_raw, server_raw, client_cbor): + client_transformed_pipe = client_cbor + + received_data_queue = TimeoutQueue() + server_raw.on("data", received_data_queue.put_nowait) + + original_data = {"message": "Hello, CBOR!", "number": 456, "list": [1, 2, 3]} + await client_transformed_pipe.write(original_data) + encoded_data_received_by_server = await received_data_queue.get() + + # Verify the raw bytes received by the server are valid CBOR + decoded_by_server = cbor2.loads(encoded_data_received_by_server) + assert decoded_by_server == original_data + + client_received_data_queue = TimeoutQueue() + client_transformed_pipe.on("data", client_received_data_queue.put_nowait) + + response_data = {"status": "cbor_success", "code": 200} + await server_raw.write(cbor2.dumps(response_data)) + decoded_data_received_by_client = await client_received_data_queue.get() + assert decoded_data_received_by_client == response_data + + async def test_cbor_transformer_decoding_error_on_read(self, server_raw, client_cbor): + client_transformed_pipe = client_cbor + + error_queue = TimeoutQueue() + client_transformed_pipe.on("error", error_queue.put_nowait) + + # Simulate server sending incomplete CBOR bytes + incomplete_cbor_bytes = b'\x83\x01\x02' # Incomplete array, missing one element + with pytest.raises(cbor2.CBORDecodeError): + await server_raw.write(incomplete_cbor_bytes) + + # The CborTransformer should now raise CBORDecodeError for incomplete data + error = await asyncio.wait_for(error_queue.get(), timeout=1) + assert isinstance(error, cbor2.CBORDecodeError) + assert "Incomplete CBOR data for non-stream transformer" in str(error) + + # Send truly invalid data + truly_invalid_cbor = b'\x1f' # Unknown unsigned integer subtype + with pytest.raises(cbor2.CBORDecodeError): + await server_raw.write(truly_invalid_cbor) + error = await asyncio.wait_for(error_queue.get(), timeout=1) + assert isinstance(error, cbor2.CBORDecodeError) + assert "unknown unsigned integer subtype" in str(error) + + async def test_cbor_transformer_non_bytes_data(self, server_raw, client_cbor): + client_transformed_pipe = client_cbor + + error_queue = TimeoutQueue() + client_transformed_pipe.on("error", error_queue.put_nowait) + + # Simulate server sending non-bytes data + non_bytes_data = "not cbor" + with pytest.raises(TypeError): + await server_raw.write(non_bytes_data) + + error = await asyncio.wait_for(error_queue.get(), timeout=1) + assert isinstance(error, TypeError) + assert "Expected bytes" in str(error) + + async def test_cbor_transformer_none_data(self, server_raw, client_cbor): + client_transformed_pipe = client_cbor + + error_queue = TimeoutQueue() + client_transformed_pipe.on("error", error_queue.put_nowait) + + with pytest.raises(TypeError): + await server_raw.write(None) + + error = await asyncio.wait_for(error_queue.get(), timeout=1) + assert isinstance(error, TypeError) + assert "Expected bytes" in str(error) + + async def test_cbor_transformer_multiple_separate_writes(self, server_raw, client_cbor): + client_transformed_pipe = client_cbor + + received_data_queue = TimeoutQueue() + client_transformed_pipe.on("data", received_data_queue.put_nowait) + + await server_raw.write(cbor2.dumps({"a": 1})) + await server_raw.write(cbor2.dumps({"b": 2})) + + decoded1 = await received_data_queue.get() + decoded2 = await received_data_queue.get() + + assert decoded1 == {"a": 1} + assert decoded2 == {"b": 2} + + async def test_cbor_transformer_encode_error_on_write(self, server_raw, client_cbor): + client_transformed_pipe = client_cbor + + error_queue = TimeoutQueue() + client_transformed_pipe.on("error", error_queue.put_nowait) + + received_data_queue = TimeoutQueue() + server_raw.on("data", received_data_queue.put_nowait) + + unserializable = {"func": lambda x: x} + result = await client_transformed_pipe.write(unserializable) + + assert result is False + error = await asyncio.wait_for(error_queue.get(), timeout=1) + assert isinstance(error, Exception) + + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(received_data_queue.get(), timeout=0.1) + + async def test_cbor_transformer_close_propagation_and_write_after_close(self, client_raw, client_cbor): + client_transformed_pipe = client_cbor + + close_queue = TimeoutQueue() + client_transformed_pipe.on("close", lambda *args: close_queue.put_nowait(True)) + + await client_raw.terminate() + + await close_queue.get() + + result = await client_transformed_pipe.write({"after": "close"}) + assert result is False + +@pytest.mark.asyncio +class TestCborStreamTransformer: + + async def test_cbor_stream_transformer_single_object(self, client_raw, server_raw): + cbor_stream_transformer = CborStreamTransformer() + client_transformed_pipe = cbor_stream_transformer.applyTransformer(client_raw) + + received_data_queue = TimeoutQueue() + client_transformed_pipe.on("data", received_data_queue.put_nowait) + + original_data = {"key": "value", "id": 1} + await server_raw.write(cbor2.dumps(original_data)) + + decoded_data = await received_data_queue.get() + assert decoded_data == original_data + + async def test_cbor_stream_transformer_concatenated_objects(self, client_raw, server_raw): + cbor_stream_transformer = CborStreamTransformer() + client_transformed_pipe = cbor_stream_transformer.applyTransformer(client_raw) + + received_data_queue = TimeoutQueue() + client_transformed_pipe.on("data", received_data_queue.put_nowait) + + obj1 = {"a": 1} + obj2 = {"b": 2, "c": [3, 4]} + obj3 = "hello" + + concatenated_cbor = cbor2.dumps(obj1) + cbor2.dumps(obj2) + cbor2.dumps(obj3) + + # Send all concatenated data at once + await server_raw.write(concatenated_cbor) + + # Trigger additional decode passes to drain buffered objects + await server_raw.write(b"") + await server_raw.write(b"") + + # Expect to receive objects one by one + decoded1 = await received_data_queue.get() + decoded2 = await received_data_queue.get() + decoded3 = await received_data_queue.get() + + assert decoded1 == obj1 + assert decoded2 == obj2 + assert decoded3 == obj3 + + async def test_cbor_stream_transformer_fragmented_objects(self, client_raw, server_raw): + cbor_stream_transformer = CborStreamTransformer() + client_transformed_pipe = cbor_stream_transformer.applyTransformer(client_raw) + + received_data_queue = TimeoutQueue() + client_transformed_pipe.on("data", received_data_queue.put_nowait) + error_queue = TimeoutQueue() + client_transformed_pipe.on("error", error_queue.put_nowait) + + obj = {"long_message": "a" * 100} + cbor_bytes = cbor2.dumps(obj) + + # Send in fragments + await server_raw.write(cbor_bytes[:10]) + # Should raise NeedsMoreDataException internally, but not emit an error + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(received_data_queue.get(), timeout=0.1) + assert error_queue.empty() + + await server_raw.write(cbor_bytes[10:50]) + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(received_data_queue.get(), timeout=0.1) + assert error_queue.empty() + + await server_raw.write(cbor_bytes[50:]) + decoded = await received_data_queue.get() + assert decoded == obj + assert error_queue.empty() + + async def test_cbor_stream_transformer_mixed_fragmented_and_concatenated(self, client_raw, server_raw): + cbor_stream_transformer = CborStreamTransformer() + client_transformed_pipe = cbor_stream_transformer.applyTransformer(client_raw) + + received_data_queue = TimeoutQueue() + client_transformed_pipe.on("data", received_data_queue.put_nowait) + + obj1 = {"id": 1, "data": "first"} + obj2 = {"id": 2, "data": "second"} + obj3 = {"id": 3, "data": "third"} + + cbor_bytes1 = cbor2.dumps(obj1) + cbor_bytes2 = cbor2.dumps(obj2) + cbor_bytes3 = cbor2.dumps(obj3) + + # Send first object fragmented + await server_raw.write(cbor_bytes1[:5]) + await server_raw.write(cbor_bytes1[5:]) + decoded1 = await received_data_queue.get() + assert decoded1 == obj1 + + # Send second and third concatenated + await server_raw.write(cbor_bytes2 + cbor_bytes3) + + # Trigger an extra decode pass to drain buffered third object + await server_raw.write(b"") + + decoded2 = await received_data_queue.get() + decoded3 = await received_data_queue.get() + assert decoded2 == obj2 + assert decoded3 == obj3 + + async def test_cbor_stream_transformer_invalid_data_in_stream(self, client_raw, server_raw): + cbor_stream_transformer = CborStreamTransformer() + client_transformed_pipe = cbor_stream_transformer.applyTransformer(client_raw) + + received_data_queue = TimeoutQueue() + client_transformed_pipe.on("data", received_data_queue.put_nowait) + error_queue = TimeoutQueue() + client_transformed_pipe.on("error", error_queue.put_nowait) + + obj1 = {"valid": True} + invalid_bytes = b'\x1f' # Unknown unsigned integer subtype + obj2 = {"another": "valid"} + + await server_raw.write(cbor2.dumps(obj1)) + await server_raw.write(invalid_bytes + cbor2.dumps(obj2)) + + # First valid object should be decoded + decoded1 = await received_data_queue.get() + assert decoded1 == obj1 + + # The transformer should recover and decode the next valid object without emitting an error + decoded2 = await received_data_queue.get() + assert decoded2 == obj2 + + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(error_queue.get(), timeout=0.1) + + async def test_cbor_stream_transformer_non_bytes_data(self, client_raw, server_raw): + cbor_stream_transformer = CborStreamTransformer() + client_transformed_pipe = cbor_stream_transformer.applyTransformer(client_raw) + + error_queue = TimeoutQueue() + client_transformed_pipe.on("error", error_queue.put_nowait) + + # Simulate server sending non-bytes data + non_bytes_data = "not cbor" + with pytest.raises(TypeError): + await server_raw.write(non_bytes_data) + + error = await asyncio.wait_for(error_queue.get(), timeout=1) + assert isinstance(error, TypeError) + assert "Expected bytes" in str(error) + + async def test_cbor_stream_transformer_close_propagation_and_write_after_close(self, client_raw): + cbor_stream_transformer = CborStreamTransformer() + client_transformed_pipe = cbor_stream_transformer.applyTransformer(client_raw) + + close_queue = TimeoutQueue() + client_transformed_pipe.on("close", lambda *args: close_queue.put_nowait(True)) + + await client_raw.terminate() + + await close_queue.get() + + result = await client_transformed_pipe.write({"after": "close"}) + assert result is False diff --git a/tests/test_event_pipe.py b/tests/test_event_pipe.py index 4bdbe09..2d84a22 100644 --- a/tests/test_event_pipe.py +++ b/tests/test_event_pipe.py @@ -6,8 +6,9 @@ @pytest_asyncio.fixture async def event_pipe_pair(): - pipe1, pipe2 = EventPipe.create_pair() + pipe1, pipe2 = EventPipe.create_inmemory_pair() yield pipe1, pipe2 + # Terminate is an async method, so it needs to be awaited await pipe1.terminate() await pipe2.terminate() diff --git a/tests/test_json_transformer.py b/tests/test_json_transformer.py index 3c885ed..8f12ab0 100644 --- a/tests/test_json_transformer.py +++ b/tests/test_json_transformer.py @@ -1,28 +1,57 @@ import pytest import json import asyncio +from cbor_rpc.tcp.tcp import TcpPipe from cbor_rpc.transformer.json_transformer import JsonTransformer from cbor_rpc.pipe.event_pipe import EventPipe +from cbor_rpc.pipe.aio_pipe import AioPipe +from tests.helpers.simple_pipe import SimplePipe from cbor_rpc.transformer.base.base_exception import NeedsMoreDataException from cbor_rpc.transformer.base.event_transformer_pipe import EventTransformerPipe +DEFAULT_TIMEOUT = 2.0 + +@pytest.fixture(params=[ + (EventPipe.create_inmemory_pair, "InmemoryPipe"), + (AioPipe.create_inmemory_pair, "AioPipe"), + (TcpPipe.create_inmemory_pair, "TcpPipe"), + +], ids=lambda param: param[1]) +async def pipe_pair(request): + create_pair_func, _ = request.param + if asyncio.iscoroutinefunction(create_pair_func): + client_pipe, server_pipe = await create_pair_func() + else: + client_pipe, server_pipe = create_pair_func() + yield client_pipe, server_pipe + await client_pipe.terminate() + await server_pipe.terminate() + +@pytest.fixture +async def json_pipe(pipe_pair): + client_raw_pipe, server_raw_pipe = pipe_pair + json_transformer = JsonTransformer() + client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) + return client_raw_pipe, server_raw_pipe, client_transformed_pipe, json_transformer + +@pytest.fixture +async def json_pipe_ascii(pipe_pair): + client_raw_pipe, server_raw_pipe = pipe_pair + json_transformer = JsonTransformer(encoding='ascii') + client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) + return client_raw_pipe, server_raw_pipe, client_transformed_pipe, json_transformer + @pytest.mark.asyncio class TestJsonTransformerPipeInteraction: - async def test_json_transformer_end_to_end_simple_dict(self): - # Create a pair of event pipes - client_raw_pipe, server_raw_pipe = EventPipe.create_pair() - - # Instantiate the JSON transformer - json_transformer = JsonTransformer() - - # Apply the transformer to the client side of the raw pipe - # This creates an EventTransformerPipe that encodes/decodes data - client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) + async def test_json_transformer_end_to_end_simple_dict(self, json_pipe): + client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe assert isinstance(client_transformed_pipe, EventTransformerPipe) # Use a queue to capture data emitted by the server_raw_pipe received_data_queue = asyncio.Queue() + server_raw_pipe.on("error", lambda e: print("Server raw pipe error:", e)) + server_raw_pipe.on("data", received_data_queue.put_nowait) # Data to send @@ -34,7 +63,12 @@ async def test_json_transformer_end_to_end_simple_dict(self): # Wait for the encoded data to arrive at the server_raw_pipe # The server_raw_pipe receives the *encoded* data (bytes) - encoded_data_received_by_server = await received_data_queue.get() + try: + # Wait for data for up to 5 seconds + encoded_data_received_by_server = await asyncio.wait_for(received_data_queue.get(), timeout=2.0) + except asyncio.TimeoutError: + print("No data received within 5 seconds") + assert False, "Test failed due to timeout waiting for data" # Manually decode the data received by the server_raw_pipe to verify it's JSON bytes decoded_by_server = json.loads(encoded_data_received_by_server.decode('utf-8')) @@ -53,24 +87,29 @@ async def test_json_transformer_end_to_end_simple_dict(self): await server_raw_pipe.write(json.dumps(response_data).encode('utf-8')) # Wait for the decoded data to arrive at the client_transformed_pipe - decoded_data_received_by_client = await client_received_data_queue.get() + decoded_data_received_by_client = await asyncio.wait_for( + client_received_data_queue.get(), + timeout=2.0, + ) assert decoded_data_received_by_client == response_data # Clean up await client_raw_pipe.terminate() await server_raw_pipe.terminate() - async def test_json_transformer_end_to_end_unicode_characters(self): - client_raw_pipe, server_raw_pipe = EventPipe.create_pair() - json_transformer = JsonTransformer() - client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) + async def test_json_transformer_end_to_end_unicode_characters(self, json_pipe): + ## TODO: This is taking forever when using TCP pipes, investigate why. It works fine with EventPipe and AioPipe. + client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe received_data_queue = asyncio.Queue() server_raw_pipe.on("data", received_data_queue.put_nowait) original_data = {"message": "ä½ å„½äø–ē•Œ šŸ‘‹"} await client_transformed_pipe.write(original_data) - encoded_data_received_by_server = await received_data_queue.get() + encoded_data_received_by_server = await asyncio.wait_for( + received_data_queue.get(), + timeout=2.0, + ) decoded_by_server = json.loads(encoded_data_received_by_server.decode('utf-8')) assert decoded_by_server == original_data @@ -78,17 +117,15 @@ async def test_json_transformer_end_to_end_unicode_characters(self): client_transformed_pipe.on("data", client_received_data_queue.put_nowait) response_data = {"greeting": "こんにごは"} await server_raw_pipe.write(json.dumps(response_data, ensure_ascii=False).encode('utf-8')) - decoded_data_received_by_client = await client_received_data_queue.get() + decoded_data_received_by_client = await asyncio.wait_for( + client_received_data_queue.get(), + timeout=2.0, + ) assert decoded_data_received_by_client == response_data - await client_raw_pipe.terminate() - await server_raw_pipe.terminate() - - async def test_json_transformer_encoding_error_on_write(self): - client_raw_pipe, server_raw_pipe = EventPipe.create_pair() + async def test_json_transformer_encoding_error_on_write(self, json_pipe_ascii): # Use an encoding that cannot handle certain characters - json_transformer = JsonTransformer(encoding='ascii') - client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) + client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe_ascii original_data = {"message": "Hello, world! šŸ‘‹"} # Contains non-ASCII character @@ -100,56 +137,50 @@ async def test_json_transformer_encoding_error_on_write(self): await client_transformed_pipe.write(original_data) # Assert that a UnicodeEncodeError is received - error = await asyncio.wait_for(error_queue.get(), timeout=1) + error = await asyncio.wait_for(error_queue.get(), timeout=DEFAULT_TIMEOUT) assert isinstance(error, UnicodeEncodeError) - await client_raw_pipe.terminate() - await server_raw_pipe.terminate() - - async def test_json_transformer_decoding_error_on_read(self): - client_raw_pipe, server_raw_pipe = EventPipe.create_pair() - json_transformer = JsonTransformer() - client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) + async def test_json_transformer_decoding_error_on_read(self, json_pipe): + client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe # Use a queue to capture errors emitted by the transformed pipe error_queue = asyncio.Queue() client_transformed_pipe.on("error", error_queue.put_nowait) # Simulate server sending invalid JSON bytes - invalid_json_bytes = b'{"key": "value",}' # Invalid JSON - await server_raw_pipe.write(invalid_json_bytes) + invalid_json_bytes = b'{,"key": "value",}' # Invalid JSON + try: + await server_raw_pipe.write(invalid_json_bytes) + except json.JSONDecodeError: + # EventPipe/AioPipe may raise from the pipeline while still emitting the error. + pass # The transformed pipe should emit an error when trying to decode - error = await asyncio.wait_for(error_queue.get(), timeout=1) + error = await asyncio.wait_for(error_queue.get(), timeout=DEFAULT_TIMEOUT) assert isinstance(error, json.JSONDecodeError) - await client_raw_pipe.terminate() - await server_raw_pipe.terminate() - - async def test_json_transformer_decoding_type_error_on_read(self): - client_raw_pipe, server_raw_pipe = EventPipe.create_pair() - json_transformer = JsonTransformer() - client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) + async def test_json_transformer_decoding_type_error_on_read(self, json_pipe): + client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe error_queue = asyncio.Queue() client_transformed_pipe.on("error", error_queue.put_nowait) # Simulate server sending non-bytes/str data (e.g., an int) non_string_data = 12345 - await server_raw_pipe.write(non_string_data) # This will pass through raw pipe as is + try: + await server_raw_pipe.write(non_string_data) # This will pass through raw pipe as is + except TypeError as exc: + # TcpPipe enforces bytes-only writes; no error will be emitted by the transformer. + assert isinstance(exc, TypeError) + return # The transformed pipe should emit a TypeError when trying to decode - error = await asyncio.wait_for(error_queue.get(), timeout=1) + error = await asyncio.wait_for(error_queue.get(), timeout=DEFAULT_TIMEOUT) assert isinstance(error, TypeError) assert "Expected bytes or str" in str(error) - await client_raw_pipe.terminate() - await server_raw_pipe.terminate() - - async def test_json_transformer_non_json_serializable_data(self): - client_raw_pipe, server_raw_pipe = EventPipe.create_pair() - json_transformer = JsonTransformer() - client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) + async def test_json_transformer_non_json_serializable_data(self, json_pipe): + client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe # Data that is not JSON serializable non_serializable_data = {"set_data": {1, 2, 3}} @@ -162,16 +193,11 @@ async def test_json_transformer_non_json_serializable_data(self): await client_transformed_pipe.write(non_serializable_data) # Assert that a TypeError is received - error = await asyncio.wait_for(error_queue.get(), timeout=1) + error = await asyncio.wait_for(error_queue.get(), timeout=DEFAULT_TIMEOUT) assert isinstance(error, TypeError) - await client_raw_pipe.terminate() - await server_raw_pipe.terminate() - - async def test_json_transformer_pipe_termination(self): - client_raw_pipe, server_raw_pipe = EventPipe.create_pair() - json_transformer = JsonTransformer() - client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) + async def test_json_transformer_pipe_termination(self, json_pipe): + client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe # Listen for close event on the transformed pipe close_event_received = asyncio.Event() @@ -181,13 +207,11 @@ async def test_json_transformer_pipe_termination(self): await client_raw_pipe.terminate() # The transformed pipe should also terminate and emit a close event - await asyncio.wait_for(close_event_received.wait(), timeout=1) - await server_raw_pipe.terminate() # Ensure the other end is also terminated + await asyncio.wait_for(close_event_received.wait(), timeout=DEFAULT_TIMEOUT) + # server_raw_pipe is terminated by the fixture - async def test_json_transformer_pipe_write_after_termination(self): - client_raw_pipe, server_raw_pipe = EventPipe.create_pair() - json_transformer = JsonTransformer() - client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) + async def test_json_transformer_pipe_write_after_termination(self, json_pipe): + client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe await client_raw_pipe.terminate() @@ -195,12 +219,10 @@ async def test_json_transformer_pipe_write_after_termination(self): result = await client_transformed_pipe.write({"test": "data"}) assert result is False - await server_raw_pipe.terminate() + # server_raw_pipe is terminated by the fixture - async def test_json_transformer_pipe_read_after_termination(self): - client_raw_pipe, server_raw_pipe = EventPipe.create_pair() - json_transformer = JsonTransformer() - client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) + async def test_json_transformer_pipe_read_after_termination(self, json_pipe): + client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe # Listen for data on the transformed pipe data_queue = asyncio.Queue() @@ -213,14 +235,19 @@ async def test_json_transformer_pipe_read_after_termination(self): client_transformed_pipe.on("close", lambda: close_event_received.set()) await server_raw_pipe.terminate() - await asyncio.wait_for(close_event_received.wait(), timeout=1) + await asyncio.wait_for(close_event_received.wait(), timeout=DEFAULT_TIMEOUT) # Try to write to the raw pipe from the server side after termination # This data should not be processed by the transformed pipe - await server_raw_pipe.write(b'{"should": "not_receive"}') + try: + result = await server_raw_pipe.write(b'{"should": "not_receive"}') + assert result is False + except ConnectionError: + # TcpPipe raises if not connected after termination. + pass # Ensure no data is received by the transformed pipe after termination with pytest.raises(asyncio.TimeoutError): await asyncio.wait_for(data_queue.get(), timeout=0.1) - await client_raw_pipe.terminate() + # client_raw_pipe is terminated by the fixture diff --git a/tests/test_tcp.py b/tests/test_tcp.py index 255f1e1..a372797 100644 --- a/tests/test_tcp.py +++ b/tests/test_tcp.py @@ -4,6 +4,8 @@ from cbor_rpc import TcpPipe from tests.helpers.simple_tcp_server import SimpleTcpServer +DEFAULT_TIMEOUT = 1.0 # we are doing everything on same machine. everything should be fast + @pytest.mark.asyncio async def test_tcp_client_server_connection(): @@ -349,5 +351,82 @@ async def test_tcp_invalid_data_types(): await server.close() +@pytest.mark.asyncio +async def test_tcp_inmemory_pair_bidirectional_exchange(): + """Test TcpPipe.create_inmemory_pair produces connected pipes that can exchange data.""" + client_pipe, server_pipe = await TcpPipe.create_inmemory_pair() + + try: + assert client_pipe.is_connected() + assert server_pipe.is_connected() + + client_received = asyncio.Queue() + server_received = asyncio.Queue() + + client_pipe.on("data", client_received.put_nowait) + server_pipe.on("data", server_received.put_nowait) + + await client_pipe.write(b"ping") + server_data = await asyncio.wait_for(server_received.get(), timeout=DEFAULT_TIMEOUT) + assert server_data == b"ping" + + await server_pipe.write(b"pong") + client_data = await asyncio.wait_for(client_received.get(), timeout=DEFAULT_TIMEOUT) + assert client_data == b"pong" + + finally: + await client_pipe.terminate() + await server_pipe.terminate() + + +@pytest.mark.asyncio +async def test_tcp_shutdown_keeps_active_connections(): + """Test shutting down the listener doesn't drop existing connections.""" + server = await SimpleTcpServer.create('127.0.0.1', 0) + server_host, server_port = server.get_address() + + server_connection = None + + async def on_connection(tcp_pipe: TcpPipe): + nonlocal server_connection + server_connection = tcp_pipe + + server.on_connection(on_connection) + + try: + client = await TcpPipe.create_connection(server_host, server_port) + await asyncio.wait_for(asyncio.sleep(0.1), timeout=DEFAULT_TIMEOUT) + assert server_connection is not None + + await server.shutdown() + + # Existing connection should still be usable + client_received = asyncio.Queue() + server_received = asyncio.Queue() + + client.on("data", client_received.put_nowait) + server_connection.on("data", server_received.put_nowait) + + await client.write(b"still-alive") + server_data = await asyncio.wait_for(server_received.get(), timeout=DEFAULT_TIMEOUT) + assert server_data == b"still-alive" + + await server_connection.write(b"still-alive-2") + client_data = await asyncio.wait_for(client_received.get(), timeout=DEFAULT_TIMEOUT) + assert client_data == b"still-alive-2" + + # New connections should fail while listener is shut down + with pytest.raises(ConnectionError) as exc_info: + await TcpPipe.create_connection(server_host, server_port, timeout=0.2) + error_text = str(exc_info.value).lower() + assert "refused" in error_text or "connect call failed" in error_text + + await client.terminate() + await server_connection.terminate() + + finally: + await server.stop() + + if __name__ == "__main__": pytest.main(["-v", __file__]) From acfe7f65c9be186cd9ff0a951914a0ff78374c7a Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Fri, 6 Feb 2026 22:51:37 +0545 Subject: [PATCH 15/25] Setup and apply black formatter --- cbor_rpc/__init__.py | 44 ++-- cbor_rpc/event/__init__.py | 2 +- cbor_rpc/event/emitter.py | 1 + cbor_rpc/pipe/aio_pipe.py | 20 +- cbor_rpc/pipe/event_pipe.py | 12 +- cbor_rpc/pipe/pipe.py | 18 +- cbor_rpc/rpc/__init__.py | 5 +- cbor_rpc/rpc/rpc_server.py | 16 +- cbor_rpc/rpc/rpc_v1.py | 62 +++--- cbor_rpc/rpc/server_base.py | 20 +- cbor_rpc/ssh/ssh_pipe.py | 40 ++-- cbor_rpc/stdio/stdio_pipe.py | 21 +- cbor_rpc/tcp/__init__.py | 2 +- cbor_rpc/tcp/tcp.py | 134 ++++++------ cbor_rpc/timed_promise.py | 17 +- cbor_rpc/transformer/base/__init__.py | 4 +- .../base/event_transformer_pipe.py | 4 +- cbor_rpc/transformer/base/transformer_base.py | 14 +- cbor_rpc/transformer/base/transformer_pipe.py | 18 +- cbor_rpc/transformer/cbor_transformer.py | 15 +- cbor_rpc/transformer/json_transformer.py | 9 +- examples/fs_rpc/filesystem_client.py | 20 +- examples/fs_rpc/filesystem_server.py | 12 +- pyproject.toml | 3 + setup.py | 5 +- tests/docker/sshd-python/binary_emitter.py | 2 +- tests/docker/sshd-python/echo_back.py | 4 +- tests/helpers/simple_pipe.py | 2 +- tests/helpers/simple_tcp_server.py | 2 + tests/helpers/stdio_test_script.py | 4 +- tests/test_cbor_transformer.py | 10 +- tests/test_event_emitter.py | 64 +++--- tests/test_event_pipe.py | 41 ++-- tests/test_json_transformer.py | 44 ++-- tests/test_pipe.py | 31 ++- tests/test_rpc_v1.py | 33 ++- tests/test_ssh_docker_pipe.py | 171 +++++++++------- tests/test_stdio_rpc.py | 7 +- tests/test_tcp.py | 190 +++++++++--------- 39 files changed, 623 insertions(+), 500 deletions(-) diff --git a/cbor_rpc/__init__.py b/cbor_rpc/__init__.py index c8e48b4..e4cf8a2 100644 --- a/cbor_rpc/__init__.py +++ b/cbor_rpc/__init__.py @@ -3,43 +3,35 @@ """ from .event import AbstractEmitter -from .pipe import EventPipe,Pipe +from .pipe import EventPipe, Pipe from .timed_promise import TimedPromise -from .rpc import RpcClient, RpcAuthorizedClient,RpcServer,RpcV1,RpcV1Server,Server +from .rpc import RpcClient, RpcAuthorizedClient, RpcServer, RpcV1, RpcV1Server, Server from .tcp import TcpPipe, TcpServer -from .transformer import JsonTransformer,Transformer +from .transformer import JsonTransformer, Transformer __all__ = [ # Promise - 'TimedPromise', - + "TimedPromise", # Emitter - 'AbstractEmitter', - - + "AbstractEmitter", # Pipe abstract classes - 'EventPipe', - 'Pipe', - + "EventPipe", + "Pipe", # Server abstract classes - 'Server', - - # Rpc abstract classes - 'RpcClient', - 'RpcAuthorizedClient', - 'RpcServer', - + "Server", + # Rpc abstract classes + "RpcClient", + "RpcAuthorizedClient", + "RpcServer", # Rpc base implementation - 'RpcV1', - 'RpcV1Server', - + "RpcV1", + "RpcV1Server", # TCP classes - 'TcpPipe', - 'TcpServer', - + "TcpPipe", + "TcpServer", # Transformers - 'Transformer', - 'JsonTransformer', + "Transformer", + "JsonTransformer", ] __version__ = "0.1.0" diff --git a/cbor_rpc/event/__init__.py b/cbor_rpc/event/__init__.py index 809afca..e288aa2 100644 --- a/cbor_rpc/event/__init__.py +++ b/cbor_rpc/event/__init__.py @@ -1 +1 @@ -from .emitter import AbstractEmitter \ No newline at end of file +from .emitter import AbstractEmitter diff --git a/cbor_rpc/event/emitter.py b/cbor_rpc/event/emitter.py index 37538cb..e5727c1 100644 --- a/cbor_rpc/event/emitter.py +++ b/cbor_rpc/event/emitter.py @@ -5,6 +5,7 @@ import traceback import warnings + class AbstractEmitter(ABC): def __init__(self): self._pipelines: Dict[str, List[Callable]] = {} diff --git a/cbor_rpc/pipe/aio_pipe.py b/cbor_rpc/pipe/aio_pipe.py index 16bb935..c866c15 100644 --- a/cbor_rpc/pipe/aio_pipe.py +++ b/cbor_rpc/pipe/aio_pipe.py @@ -4,8 +4,9 @@ from .event_pipe import EventPipe # Assuming EventPipe is in a separate module # Constrain T1 and T2 to bytes or bytearray for type safety with asyncio streams -T1 = TypeVar('T1', bound=Union[bytes, bytearray]) -T2 = TypeVar('T2', bound=Union[bytes, bytearray]) +T1 = TypeVar("T1", bound=Union[bytes, bytearray]) +T2 = TypeVar("T2", bound=Union[bytes, bytearray]) + class AioPipe(EventPipe[T1, T2], ABC): """ @@ -15,12 +16,15 @@ class AioPipe(EventPipe[T1, T2], ABC): Attributes: DEFAULT_READ_CHUNK_SIZE (int): Default size of chunks to read from the stream (8192 bytes). """ + DEFAULT_READ_CHUNK_SIZE = 8192 - def __init__(self, - reader: Optional[asyncio.StreamReader] = None, - writer: Optional[asyncio.StreamWriter] = None, - chunk_size: int = DEFAULT_READ_CHUNK_SIZE): + def __init__( + self, + reader: Optional[asyncio.StreamReader] = None, + writer: Optional[asyncio.StreamWriter] = None, + chunk_size: int = DEFAULT_READ_CHUNK_SIZE, + ): """ Initialize the AioPipe with optional reader, writer, and chunk size. @@ -82,10 +86,10 @@ async def _read_loop(self) -> None: break except asyncio.CancelledError: break - except Exception as e: # Catch BaseException for GeneratorExit/other BaseExceptions + except Exception as e: # Catch BaseException for GeneratorExit/other BaseExceptions self._emit("error", e) # Synchronous _emit break - except Exception as e: # Catch BaseException for GeneratorExit/other BaseExceptions + except Exception as e: # Catch BaseException for GeneratorExit/other BaseExceptions self._emit("error", e) # Synchronous _emit finally: if not self._closed: diff --git a/cbor_rpc/pipe/event_pipe.py b/cbor_rpc/pipe/event_pipe.py index 11577e5..2b7d66b 100644 --- a/cbor_rpc/pipe/event_pipe.py +++ b/cbor_rpc/pipe/event_pipe.py @@ -5,8 +5,8 @@ from ..event.emitter import AbstractEmitter # Generic type variables -T1 = TypeVar('T1') -T2 = TypeVar('T2') +T1 = TypeVar("T1") +T2 = TypeVar("T2") class EventPipe(AbstractEmitter, Generic[T1, T2]): @@ -14,6 +14,7 @@ class EventPipe(AbstractEmitter, Generic[T1, T2]): Event Pipe or are event based way for read/write. You cannot directly read from a Pipe. You have to use a pipeline("data") to register one or more functions to read data. """ + @abstractmethod async def write(self, chunk: T1) -> bool: pass @@ -23,20 +24,21 @@ async def terminate(self, *args: Any) -> None: pass @staticmethod - def create_inmemory_pair() -> Tuple['EventPipe[Any, Any]', 'EventPipe[Any, Any]']: + def create_inmemory_pair() -> Tuple["EventPipe[Any, Any]", "EventPipe[Any, Any]"]: """ Create a pair of connected pipes for bidirectional communication. Returns: A tuple of (pipe1, pipe2) where data written to pipe1 is emitted on pipe2 and vice versa. """ + class ConnectedPipe(EventPipe[Any, Any]): def __init__(self): super().__init__() - self.connected_pipe: Optional['ConnectedPipe'] = None + self.connected_pipe: Optional["ConnectedPipe"] = None self._closed = False - def connect_to(self, other: 'ConnectedPipe'): + def connect_to(self, other: "ConnectedPipe"): self.connected_pipe = other other.connected_pipe = self diff --git a/cbor_rpc/pipe/pipe.py b/cbor_rpc/pipe/pipe.py index 659fe67..7cc3150 100644 --- a/cbor_rpc/pipe/pipe.py +++ b/cbor_rpc/pipe/pipe.py @@ -5,8 +5,8 @@ from cbor_rpc.pipe.event_pipe import EventPipe from ..event.emitter import AbstractEmitter -T1 = TypeVar('T1') -T2 = TypeVar('T2') +T1 = TypeVar("T1") +T2 = TypeVar("T2") class Pipe(AbstractEmitter, Generic[T1, T2], ABC): @@ -27,13 +27,13 @@ async def terminate(self, *args: Any) -> None: pass @staticmethod - def create_pair() -> Tuple['Pipe[Any, Any]', 'Pipe[Any, Any]']: + def create_pair() -> Tuple["Pipe[Any, Any]", "Pipe[Any, Any]"]: class InMemoryPipe(Pipe[Any, Any]): def __init__(self): super().__init__() self._closed = False self._buffer: asyncio.Queue[Optional[Any]] = asyncio.Queue() - self.connected_pipe: Optional['InMemoryPipe'] = None + self.connected_pipe: Optional["InMemoryPipe"] = None async def write(self, chunk: Any) -> bool: if self._closed or not self.connected_pipe or self.connected_pipe._closed: @@ -65,19 +65,19 @@ async def terminate(self, *args: Any) -> None: self._closed = True # Signal termination to any pending reads await self._buffer.put(None) - await self._notify("close", *args) # Notify external listeners + await self._notify("close", *args) # Notify external listeners if self.connected_pipe and not self.connected_pipe._closed: - await self.connected_pipe._buffer.put(None) # Signal termination to connected pipe - await self.connected_pipe.terminate(*args) # Recursively terminate connected pipe + await self.connected_pipe._buffer.put(None) # Signal termination to connected pipe + await self.connected_pipe.terminate(*args) # Recursively terminate connected pipe a = InMemoryPipe() b = InMemoryPipe() a.connected_pipe = b b.connected_pipe = a return a, b - - def make_event_based(self) -> 'EventPipe[T1, T2]': + + def make_event_based(self) -> "EventPipe[T1, T2]": parent = self class PipeToEvent(EventPipe[T1, T2]): diff --git a/cbor_rpc/rpc/__init__.py b/cbor_rpc/rpc/__init__.py index b19f2a9..431c063 100644 --- a/cbor_rpc/rpc/__init__.py +++ b/cbor_rpc/rpc/__init__.py @@ -1,7 +1,6 @@ from .server_base import Server -from .rpc_base import RpcClient,RpcServer,RpcAuthorizedClient +from .rpc_base import RpcClient, RpcServer, RpcAuthorizedClient from .rpc_v1 import RpcV1 -from .rpc_server import RpcV1Server - +from .rpc_server import RpcV1Server diff --git a/cbor_rpc/rpc/rpc_server.py b/cbor_rpc/rpc/rpc_server.py index f778ec5..69f8bcd 100644 --- a/cbor_rpc/rpc/rpc_server.py +++ b/cbor_rpc/rpc/rpc_server.py @@ -9,29 +9,25 @@ class RpcV1Server(RpcServer): - def __init__(self,server:Server): + def __init__(self, server: Server): self.active_connections: Dict[str, RpcV1] = {} self.timeout = 30000 async def add_connection(self, conn_id: str, rpc_client: EventPipe[Any, Any]) -> None: def method_handler(method: str, args: List[Any]) -> Any: return self.handle_method_call(conn_id, method, args) - + async def event_handler(topic: str, data: Any) -> None: await self._handle_event(conn_id, topic, data) - - client_rpc = RpcV1.make_rpc_v1( - rpc_client, - conn_id, - method_handler, - event_handler - ) + + client_rpc = RpcV1.make_rpc_v1(rpc_client, conn_id, method_handler, event_handler) client_rpc.set_timeout(self.timeout) self.active_connections[conn_id] = client_rpc - + # Set up cleanup on close async def cleanup(*args): self.active_connections.pop(conn_id, None) + rpc_client.on("close", cleanup) async def disconnect(self, connection_id: str, reason: Optional[str] = None) -> None: diff --git a/cbor_rpc/rpc/rpc_v1.py b/cbor_rpc/rpc/rpc_v1.py index 29c0212..fc78d1f 100644 --- a/cbor_rpc/rpc/rpc_v1.py +++ b/cbor_rpc/rpc/rpc_v1.py @@ -26,20 +26,20 @@ async def resolve_result(result: Any) -> Any: async def on_data(data: List[Any]) -> None: try: if not isinstance(data, list) or len(data) != 5: - print(f"RpcV1: Invalid message format: {data}",file=sys.stderr) + print(f"RpcV1: Invalid message format: {data}", file=sys.stderr) return version, direction, id_, method, params = data if version != 1: - print(f"RpcV1: Unsupported version: {data}",file=sys.stderr) + print(f"RpcV1: Unsupported version: {data}", file=sys.stderr) return - print("RpvV1: Received", data,file=sys.stderr) + print("RpvV1: Received", data, file=sys.stderr) if direction < 2: # Method call (0) or fire (1) try: # Call the method and get the result result = self.handle_method_call(method, params) - + # Handle the response asynchronously async def handle_response(): try: @@ -50,17 +50,23 @@ async def handle_response(): if direction == 0: await self.pipe.write([1, 2, id_, False, str(e)]) else: - print(f"Fired method error: {method}, params={params}, error={e}",file=sys.stderr) - + print( + f"Fired method error: {method}, params={params}, error={e}", + file=sys.stderr, + ) + # Create task to handle response asyncio.create_task(handle_response()) - + except Exception as e: if direction == 0: asyncio.create_task(self.pipe.write([1, 2, id_, False, str(e)])) else: - print(f"Fired method error: {method}, params={params}, error={e}",file=sys.stderr) - + print( + f"Fired method error: {method}, params={params}, error={e}", + file=sys.stderr, + ) + elif direction == 2: # Response promise = self._promises.pop(id_, None) if promise: @@ -69,25 +75,28 @@ async def handle_response(): else: # Error await promise.reject(params) else: - print(f"Received rpc reply for expired request id: {id_}, success={method}, data={params}",file=sys.stderr) - + print( + f"Received rpc reply for expired request id: {id_}, success={method}, data={params}", + file=sys.stderr, + ) + elif direction == 3: # Event await self._on_event(method, params) else: - print(f"RpcV1: Invalid direction: {direction}",file=sys.stderr) - + print(f"RpcV1: Invalid direction: {direction}", file=sys.stderr) + except Exception as e: - print(f"Error processing RPC message: {e}",file=sys.stderr) + print(f"Error processing RPC message: {e}", file=sys.stderr) self.pipe.on("data", on_data) async def call_method(self, method: str, *args: Any) -> Any: counter = self._counter self._counter += 1 - + def timeout_callback(): self._promises.pop(counter, None) - + promise = TimedPromise(self._timeout, timeout_callback) self._promises[counter] = promise await self.pipe.write([1, 0, counter, method, list(args)]) @@ -118,14 +127,14 @@ def get_id(self) -> str: async def wait_next_event(self, topic: str, timeout_ms: Optional[int] = None) -> Any: if topic in self._waiters: raise Exception("Already waiting for event") - + def timeout_callback(): self._waiters.pop(topic, None) - + waiter = TimedPromise( timeout_ms or self._timeout, timeout_callback, - f"Timeout Waiting for Event on: {topic}" + f"Timeout Waiting for Event on: {topic}", ) self._waiters[topic] = waiter return await waiter.promise @@ -139,7 +148,12 @@ async def on_event(self, topic: str, message: Any) -> None: pass @staticmethod - def make_rpc_v1(pipe: EventPipe[Any, Any], id_: str, method_handler: Callable, event_handler: Callable) -> 'RpcV1': + def make_rpc_v1( + pipe: EventPipe[Any, Any], + id_: str, + method_handler: Callable, + event_handler: Callable, + ) -> "RpcV1": class ConcreteRpcV1(RpcV1): def get_id(self) -> str: return id_ @@ -152,13 +166,15 @@ async def on_event(self, topic: str, message: Any) -> None: await event_handler(topic, message) else: event_handler(topic, message) + return ConcreteRpcV1(pipe) @staticmethod - def read_only_client(pipe: EventPipe[Any, Any]) -> 'RpcV1': + def read_only_client(pipe: EventPipe[Any, Any]) -> "RpcV1": def method_handler(method: str, args: List[Any]) -> Any: raise Exception("Client Only Implementation") async def event_handler(topic: str, message: Any) -> None: - print(f"Rpc Event dropped {topic} {message}",file=sys.stderr) - return RpcV1.make_rpc_v1(pipe, '', method_handler, event_handler) + print(f"Rpc Event dropped {topic} {message}", file=sys.stderr) + + return RpcV1.make_rpc_v1(pipe, "", method_handler, event_handler) diff --git a/cbor_rpc/rpc/server_base.py b/cbor_rpc/rpc/server_base.py index e0c8eeb..9f214d2 100644 --- a/cbor_rpc/rpc/server_base.py +++ b/cbor_rpc/rpc/server_base.py @@ -5,16 +5,16 @@ from ..pipe import EventPipe # Generic type variable for pipe types -P = TypeVar('P', bound=EventPipe) +P = TypeVar("P", bound=EventPipe) class Server(AbstractEmitter, Generic[P]): """ Abstract server class that manages connections and emits connection events. - + Type parameter P specifies the type of Pipe that this server handles. """ - + def __init__(self): super().__init__() self._connections: Set[P] = set() @@ -24,7 +24,7 @@ def __init__(self): async def start(self, *args, **kwargs) -> Any: """ Start the server. - + Returns: Server-specific information (e.g., address, port) """ @@ -36,25 +36,27 @@ async def stop(self) -> None: pass @abstractmethod - async def accept(self,pipe:P) -> bool: + async def accept(self, pipe: P) -> bool: pass async def _add_connection(self, pipe: P) -> None: """ Add a new connection and emit a connection event. - + Args: pipe: The pipe representing the connection """ if not await self.accept(pipe): await pipe.terminate() - return + return self._connections.add(pipe) + # Set up cleanup when connection closes async def cleanup(*args): self._connections.discard(pipe) + pipe.on("close", cleanup) - + # Emit connection event await self._notify("connection", pipe) @@ -75,7 +77,7 @@ async def close_all_connections(self) -> None: def on_connection(self, handler: Callable[[P], None]) -> None: """ Register a handler for new connections. - + Args: handler: Function that takes a Pipe of type P as argument """ diff --git a/cbor_rpc/ssh/ssh_pipe.py b/cbor_rpc/ssh/ssh_pipe.py index 81911a9..67cab1d 100644 --- a/cbor_rpc/ssh/ssh_pipe.py +++ b/cbor_rpc/ssh/ssh_pipe.py @@ -4,31 +4,37 @@ from cbor_rpc.pipe.aio_pipe import AioPipe + class SshPipe(AioPipe[bytes, bytes]): """ A Pipe implementation that works over an SSH connection. It uses asyncssh to establish and manage the SSH session and channels. """ - def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, - ssh_client: asyncssh.SSHClientConnection, - ssh_channel: asyncssh.SSHClientChannel): + def __init__( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ssh_client: asyncssh.SSHClientConnection, + ssh_channel: asyncssh.SSHClientChannel, + ): super().__init__(reader, writer) self._ssh_client = ssh_client self._ssh_channel = ssh_channel - @classmethod - async def connect(cls, - host: str, - port: int = 22, - username: str = 'root', - password: Optional[str] = None, - ssh_key_content: Optional[str] = None, - ssh_key_passphrase: Optional[str] = None, - known_hosts: Optional[Union[str, list]] = None, - timeout: Optional[float] = None, - command: str = 'sh -l') -> 'SshPipe': + async def connect( + cls, + host: str, + port: int = 22, + username: str = "root", + password: Optional[str] = None, + ssh_key_content: Optional[str] = None, + ssh_key_passphrase: Optional[str] = None, + known_hosts: Optional[Union[str, list]] = None, + timeout: Optional[float] = None, + command: str = "sh -l", + ) -> "SshPipe": """ Establishes an SSH connection and opens a session channel, returning an SshPipe. @@ -63,12 +69,12 @@ async def connect(cls, username=username, password=password, client_keys=client_keys, - passphrase=ssh_key_passphrase, # Passphrase for encrypted client keys - ignore_encrypted=False # Do not ignore encrypted keys + passphrase=ssh_key_passphrase, # Passphrase for encrypted client keys + ignore_encrypted=False, # Do not ignore encrypted keys ), known_hosts=known_hosts, ), - timeout=timeout + timeout=timeout, ) # Create a process on the SSH connection to get stdin/stdout streams. diff --git a/cbor_rpc/stdio/stdio_pipe.py b/cbor_rpc/stdio/stdio_pipe.py index 5a1fd66..c1d1d56 100644 --- a/cbor_rpc/stdio/stdio_pipe.py +++ b/cbor_rpc/stdio/stdio_pipe.py @@ -6,8 +6,9 @@ from cbor_rpc.pipe.pipe import Pipe from cbor_rpc.event.emitter import AbstractEmitter -T1 = TypeVar('T1') -T2 = TypeVar('T2') +T1 = TypeVar("T1") +T2 = TypeVar("T2") + class StdioPipe(AioPipe[bytes, bytes]): """ @@ -15,15 +16,20 @@ class StdioPipe(AioPipe[bytes, bytes]): typically obtained from a subprocess's stdin/stdout. """ - def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, process: Optional[asyncio.subprocess.Process] = None): - super().__init__(reader,writer) + def __init__( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + process: Optional[asyncio.subprocess.Process] = None, + ): + super().__init__(reader, writer) self._process = process async def _setup(self): await self._setup_connection() @classmethod - async def open(cls) -> 'StdioPipe': + async def open(cls) -> "StdioPipe": """ Creates a StdioPipe from the process's stdin and stdout. """ @@ -38,7 +44,7 @@ async def open(cls) -> 'StdioPipe': return pipe @classmethod - async def start_process(cls, *args: str) -> 'StdioPipe': + async def start_process(cls, *args: str) -> "StdioPipe": """ Starts a process and returns a StdioPipe for it. """ @@ -46,7 +52,7 @@ async def start_process(cls, *args: str) -> 'StdioPipe': *args, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, - stderr=sys.stderr + stderr=sys.stderr, ) pipe = cls(process.stdout, process.stdin, process) await pipe._setup() @@ -61,7 +67,6 @@ async def wait_for_process_termination(self) -> int: raise RuntimeError("No subprocess associated with this StdioPipe instance.") return await self._process.wait() - def terminate(self): """ Terminates the started subprocess if one exists. diff --git a/cbor_rpc/tcp/__init__.py b/cbor_rpc/tcp/__init__.py index fafafd4..25f514f 100644 --- a/cbor_rpc/tcp/__init__.py +++ b/cbor_rpc/tcp/__init__.py @@ -1,3 +1,3 @@ from .tcp import TcpPipe, TcpServer -__all__ = ['TcpPipe', 'TcpServer'] +__all__ = ["TcpPipe", "TcpServer"] diff --git a/cbor_rpc/tcp/tcp.py b/cbor_rpc/tcp/tcp.py index 3174ad4..9a61803 100644 --- a/cbor_rpc/tcp/tcp.py +++ b/cbor_rpc/tcp/tcp.py @@ -11,166 +11,164 @@ class TcpPipe(AioPipe[bytes, bytes]): A TCP duplex pipe that implements Pipe for network communication. Provides both client and server functionality for TCP connections. """ - - def __init__(self, reader: Optional[asyncio.StreamReader] = None, - writer: Optional[asyncio.StreamWriter] = None): - super().__init__(reader,writer) - + + def __init__( + self, + reader: Optional[asyncio.StreamReader] = None, + writer: Optional[asyncio.StreamWriter] = None, + ): + super().__init__(reader, writer) + @classmethod - async def create_connection(cls, host: str, port: int, - timeout: Optional[float] = None) -> 'TcpPipe': + async def create_connection(cls, host: str, port: int, timeout: Optional[float] = None) -> "TcpPipe": """ Create a TCP client connection to the specified host and port. - + Args: host: The hostname or IP address to connect to port: The port number to connect to timeout: Optional timeout for the connection attempt - + Returns: A connected TcpPipe instance - + Raises: ConnectionError: If the connection fails asyncio.TimeoutError: If the connection times out """ try: if timeout: - reader, writer = await asyncio.wait_for( - asyncio.open_connection(host, port), - timeout=timeout - ) + reader, writer = await asyncio.wait_for(asyncio.open_connection(host, port), timeout=timeout) else: reader, writer = await asyncio.open_connection(host, port) - + tcp_duplex = cls(reader, writer) await tcp_duplex._setup_connection() return tcp_duplex - + except Exception as e: raise ConnectionError(f"Failed to connect to {host}:{port}: {e}") - + @classmethod - async def create_server(cls, host: str = '0.0.0.0', port: int = 0, - backlog: int = 100) -> 'TcpServer': + async def create_server(cls, host: str = "0.0.0.0", port: int = 0, backlog: int = 100) -> "TcpServer": """ Create a TCP server that listens for incoming connections. - + Args: host: The hostname or IP address to bind to (default: '0.0.0.0') port: The port number to bind to (default: 0 for auto-assignment) backlog: The maximum number of queued connections - + Returns: A TcpServer instance """ return await TcpServer.create(host, port, backlog) @staticmethod - async def create_inmemory_pair() -> Tuple['TcpPipe', 'TcpPipe']: + async def create_inmemory_pair() -> Tuple["TcpPipe", "TcpPipe"]: """ Create a pair of connected TCP pipes using a local server. - + Returns: A tuple of (client_pipe, server_pipe) connected via TCP """ + class SimpleTcpServer(TcpServer): async def accept(self, pipe: TcpPipe) -> bool: return True + # Create a temporary server - server = await SimpleTcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create("127.0.0.1", 0) host, port = server.get_address() - + # Set up to capture the server-side connection server_pipe = None connection_ready = asyncio.Event() - + async def on_connection(pipe: TcpPipe): nonlocal server_pipe server_pipe = pipe connection_ready.set() - + server.on_connection(on_connection) - + try: # Create client connection client_pipe = await TcpPipe.create_connection(host, port) - + # Wait for server connection await connection_ready.wait() - + # Stop accepting new connections but keep the active pipes await server.shutdown() - + return client_pipe, server_pipe - + except Exception: await server.stop() raise - + async def connect(self, host: str, port: int, timeout: Optional[float] = None) -> None: """ Connect to a remote TCP server. - + Args: host: The hostname or IP address to connect to port: The port number to connect to timeout: Optional timeout for the connection attempt - + Raises: ConnectionError: If already connected or connection fails asyncio.TimeoutError: If the connection times out """ if self._connected: raise ConnectionError("Already connected") - + try: if timeout: self._reader, self._writer = await asyncio.wait_for( - asyncio.open_connection(host, port), - timeout=timeout + asyncio.open_connection(host, port), timeout=timeout ) else: self._reader, self._writer = await asyncio.open_connection(host, port) - + await self._setup_connection() - + except Exception as e: raise ConnectionError(f"Failed to connect to {host}:{port}: {e}") - - + def get_peer_info(self) -> Optional[Tuple[str, int]]: """Get the remote peer's address and port.""" if self._writer and self._connected: try: - return self._writer.get_extra_info('peername') + return self._writer.get_extra_info("peername") except Exception: pass return None - + def get_local_info(self) -> Optional[Tuple[str, int]]: """Get the local socket's address and port.""" if self._writer and self._connected: try: - return self._writer.get_extra_info('sockname') + return self._writer.get_extra_info("sockname") except Exception: pass return None - + def get_peer_info(self) -> Optional[Tuple[str, int]]: """Get the remote peer's address and port.""" if self._writer and self._connected: try: - return self._writer.get_extra_info('peername') + return self._writer.get_extra_info("peername") except Exception: pass return None - + def get_local_info(self) -> Optional[Tuple[str, int]]: """Get the local socket's address and port.""" if self._writer and self._connected: try: - return self._writer.get_extra_info('sockname') + return self._writer.get_extra_info("sockname") except Exception: pass return None @@ -181,57 +179,53 @@ class TcpServer(Server[TcpPipe]): A TCP server that creates TcpPipe instances for incoming connections. Extends Server[TcpPipe] to provide type-safe TCP-specific functionality. """ - + def __init__(self, server: asyncio.Server): super().__init__() self._server = server - + @classmethod - async def create(cls, host: str = '0.0.0.0', port: int = 0, - backlog: int = 100) -> 'TcpServer': + async def create(cls, host: str = "0.0.0.0", port: int = 0, backlog: int = 100) -> "TcpServer": """ Create and start a TCP server. - + Args: host: The hostname or IP address to bind to port: The port number to bind to (0 for auto-assignment) backlog: The maximum number of queued connections - + Returns: A started TcpServer instance """ tcp_server = cls.__new__(cls) Server.__init__(tcp_server) - async def client_connected_cb(reader: asyncio.StreamReader, - writer: asyncio.StreamWriter) -> None: + async def client_connected_cb(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: tcp_pipe = TcpPipe(reader, writer) await tcp_pipe._setup_connection() await tcp_server._add_connection(tcp_pipe) - server = await asyncio.start_server( - client_connected_cb, host, port, backlog=backlog - ) + server = await asyncio.start_server(client_connected_cb, host, port, backlog=backlog) tcp_server._server = server tcp_server._running = True return tcp_server - async def start(self, host: str = '0.0.0.0', port: int = 0, backlog: int = 100) -> Tuple[str, int]: + async def start(self, host: str = "0.0.0.0", port: int = 0, backlog: int = 100) -> Tuple[str, int]: """ Start the TCP server (if not already started). - + Args: host: The hostname or IP address to bind to port: The port number to bind to backlog: The maximum number of queued connections - + Returns: A tuple of (host, port) where the server is listening """ if self._running: return self.get_address() - + # This method is for compatibility; typically create() is used instead new_server = await TcpServer.create(host, port, backlog) self._server = new_server._server @@ -242,12 +236,12 @@ async def stop(self) -> None: """Stop the TCP server and close all connections.""" if not self._running: return - + self._running = False - + # Close all connections await self.close_all_connections() - + # Close the server if self._server: self._server.close() @@ -261,15 +255,15 @@ async def shutdown(self) -> None: self._running = False self._server.close() await self._server.wait_closed() - + def get_address(self) -> Tuple[str, int]: """Get the server's listening address and port.""" if self._server and self._server.sockets: return self._server.sockets[0].getsockname()[:2] return ("", 0) - + @abstractmethod - async def accept(self,pipe:TcpPipe) -> bool: + async def accept(self, pipe: TcpPipe) -> bool: pass async def close(self) -> None: diff --git a/cbor_rpc/timed_promise.py b/cbor_rpc/timed_promise.py index da54d50..c96bd9a 100644 --- a/cbor_rpc/timed_promise.py +++ b/cbor_rpc/timed_promise.py @@ -3,21 +3,22 @@ class TimedPromise: - def __init__(self, timeout_ms: int, timeout_cb: Optional[Callable[[], None]] = None, - message: str = "Timeout on RPC call"): + def __init__( + self, + timeout_ms: int, + timeout_cb: Optional[Callable[[], None]] = None, + message: str = "Timeout on RPC call", + ): self._timeout_ms = timeout_ms self._timeout_cb = timeout_cb self._message = message self._future = asyncio.get_event_loop().create_future() self._timeout_handle = None self._resolved = False - + # Set up timeout if timeout_ms > 0: - self._timeout_handle = asyncio.get_event_loop().call_later( - timeout_ms / 1000.0, - self._on_timeout - ) + self._timeout_handle = asyncio.get_event_loop().call_later(timeout_ms / 1000.0, self._on_timeout) @property def promise(self) -> asyncio.Future: @@ -46,7 +47,7 @@ def _on_timeout(self) -> None: error_data = { "timeout": True, "timeoutPeriod": self._timeout_ms, - "message": self._message + "message": self._message, } self._future.set_exception(Exception(error_data)) if self._timeout_cb: diff --git a/cbor_rpc/transformer/base/__init__.py b/cbor_rpc/transformer/base/__init__.py index 7c020f0..09c712e 100644 --- a/cbor_rpc/transformer/base/__init__.py +++ b/cbor_rpc/transformer/base/__init__.py @@ -1,2 +1,2 @@ -from .transformer_base import Transformer,AsyncTransformer -from .base_exception import NeedsMoreDataException \ No newline at end of file +from .transformer_base import Transformer, AsyncTransformer +from .base_exception import NeedsMoreDataException diff --git a/cbor_rpc/transformer/base/event_transformer_pipe.py b/cbor_rpc/transformer/base/event_transformer_pipe.py index 1700a57..55a3da0 100644 --- a/cbor_rpc/transformer/base/event_transformer_pipe.py +++ b/cbor_rpc/transformer/base/event_transformer_pipe.py @@ -1,5 +1,6 @@ import asyncio from typing import TYPE_CHECKING + if TYPE_CHECKING: from .transformer_base import Transformer from typing import Any, Awaitable, Callable, TypeVar @@ -10,11 +11,12 @@ T1 = TypeVar("T1") # Output type after decoding T2 = TypeVar("T2") # Input type before decoding (pipe input/output type) + class EventTransformerPipe(EventPipe[T1, T2]): encode: Callable[[T1], Awaitable[T2]] decode: Callable[[T2], Awaitable[T1]] - def __init__(self, pipe: EventPipe[T2, T2], transformer: 'Transformer'): + def __init__(self, pipe: EventPipe[T2, T2], transformer: "Transformer"): super().__init__() self.pipe = pipe self.pipe.pipeline("data", self._handle_data) diff --git a/cbor_rpc/transformer/base/transformer_base.py b/cbor_rpc/transformer/base/transformer_base.py index c84562d..4fd67cb 100644 --- a/cbor_rpc/transformer/base/transformer_base.py +++ b/cbor_rpc/transformer/base/transformer_base.py @@ -10,6 +10,7 @@ T1 = TypeVar("T1") T2 = TypeVar("T2") + # Sync Transformer (no async methods) class Transformer(Generic[T1, T2]): def __init__(self): @@ -18,7 +19,7 @@ def __init__(self): def is_closed(self) -> bool: return self._closed - + @abstractmethod def encode(self, data: T1) -> Any: pass @@ -28,8 +29,8 @@ def decode(self, data: Any) -> T2: pass def wait_next_data(self): - raise NeedsMoreDataException() - + raise NeedsMoreDataException() + @overload def bind(self, pipe: Pipe) -> TransformerPipe: ... @overload @@ -47,7 +48,7 @@ def applyTransformer(self, pipe: Union[Pipe, EventPipe]) -> Union[TransformerPip else: raise TypeError("Invalid pipe type") - def to_async(self) -> 'AsyncTransformer[T1, T2]': + def to_async(self) -> "AsyncTransformer[T1, T2]": parent = self class WrappedAsyncTransformer(AsyncTransformer[T1, T2]): @@ -62,6 +63,7 @@ def is_closed(self) -> bool: return WrappedAsyncTransformer() + # Async Transformer (async methods) class AsyncTransformer(Generic[T1, T2]): def __init__(self): @@ -70,7 +72,7 @@ def __init__(self): def is_closed(self) -> bool: return self._closed - + @abstractmethod async def encode(self, data: T1) -> Any: pass @@ -80,7 +82,7 @@ async def decode(self, data: Any) -> T2: pass def wait_next_data(self): - raise NeedsMoreDataException() + raise NeedsMoreDataException() @overload def bind(self, pipe: Pipe) -> TransformerPipe: ... diff --git a/cbor_rpc/transformer/base/transformer_pipe.py b/cbor_rpc/transformer/base/transformer_pipe.py index 1c35d6a..21ac1e1 100644 --- a/cbor_rpc/transformer/base/transformer_pipe.py +++ b/cbor_rpc/transformer/base/transformer_pipe.py @@ -3,6 +3,7 @@ import time from .base_exception import NeedsMoreDataException + if TYPE_CHECKING: from .transformer_base import Transformer from cbor_rpc.pipe.pipe import Pipe @@ -10,11 +11,12 @@ T1 = TypeVar("T1") T2 = TypeVar("T2") + class TransformerPipe(Pipe[T1, T2]): encode: Callable[[T1], Awaitable[T2]] decode: Callable[[T2], Awaitable[T1]] - def __init__(self, pipe: Pipe[Any, Any], transformer: 'Optional[Transformer[T1, T2]]' ): + def __init__(self, pipe: Pipe[Any, Any], transformer: "Optional[Transformer[T1, T2]]"): super().__init__() self.pipe = pipe @@ -25,15 +27,15 @@ def _handle_error(*args): if not self._closed: self._closed = True self.pipe.terminate() - self._emit('error', *args) + self._emit("error", *args) def _handle_close(*args): if not self._closed: self._closed = True - self._emit('close', *args) + self._emit("close", *args) - self.pipe.on('error', _handle_error) - self.pipe.on('close', _handle_close) + self.pipe.on("error", _handle_error) + self.pipe.on("close", _handle_close) async def write(self, chunk: T1) -> bool: if self._closed: @@ -42,7 +44,7 @@ async def write(self, chunk: T1) -> bool: encoded = await self.encode(chunk) return self.pipe.write(encoded) except Exception as e: - self._emit('error', e) + self._emit("error", e) return False async def read(self, timeout: Optional[float] = None) -> Optional[T1]: @@ -67,7 +69,7 @@ async def read(self, timeout: Optional[float] = None) -> Optional[T1]: if remaining == 0: return None except Exception as e: - self._emit('error', e) + self._emit("error", e) return None def terminate(self) -> None: @@ -77,4 +79,4 @@ def terminate(self) -> None: self.pipe.terminate() def _propagate_error(self, *args): - self.pipe._emit('error', *args) + self.pipe._emit("error", *args) diff --git a/cbor_rpc/transformer/cbor_transformer.py b/cbor_rpc/transformer/cbor_transformer.py index b8a4235..dee4db0 100644 --- a/cbor_rpc/transformer/cbor_transformer.py +++ b/cbor_rpc/transformer/cbor_transformer.py @@ -4,6 +4,7 @@ from .base import Transformer, AsyncTransformer from cbor_rpc.transformer.base.base_exception import NeedsMoreDataException + class CborTransformer(Transformer[Any, Any]): """ A transformer that encodes Python objects to CBOR bytes and decodes CBOR bytes back to Python objects. @@ -35,11 +36,13 @@ def decode(self, data: Union[bytes, None]) -> Any: # Re-raise other decoding errors raise + class CborStreamTransformer(AsyncTransformer[Any, Any]): """ An async transformer that decodes a stream of concatenated CBOR objects. This is similar to how a JSON stream decoder would work, reading one object at a time. """ + def __init__(self): super().__init__() self._buffer = bytearray() @@ -63,10 +66,10 @@ async def decode(self, data: Union[bytes, None]) -> Any: stream = BytesIO(self._buffer) decoder = cbor2.CBORDecoder(stream) decoded_data = decoder.decode() - + bytes_consumed = stream.tell() self._buffer = self._buffer[bytes_consumed:] - + return decoded_data except cbor2.CBORDecodeEOF: raise NeedsMoreDataException() @@ -78,17 +81,17 @@ async def decode(self, data: Union[bytes, None]) -> Any: # Discard one byte and try again self._buffer = self._buffer[1:] if not self._buffer: - break # Buffer is empty, cannot recover further + break # Buffer is empty, cannot recover further try: stream = BytesIO(self._buffer) decoder = cbor2.CBORDecoder(stream) decoded_data = decoder.decode() - + bytes_consumed = stream.tell() self._buffer = self._buffer[bytes_consumed:] - - return decoded_data # Successfully decoded a new object + + return decoded_data # Successfully decoded a new object except cbor2.CBORDecodeEOF: # If we need more data after discarding some, it means we might be in the middle of a valid object raise NeedsMoreDataException() diff --git a/cbor_rpc/transformer/json_transformer.py b/cbor_rpc/transformer/json_transformer.py index e79f251..0626c6e 100644 --- a/cbor_rpc/transformer/json_transformer.py +++ b/cbor_rpc/transformer/json_transformer.py @@ -2,22 +2,25 @@ from typing import Any, Union from .base import Transformer + class JsonTransformer(Transformer[Any, Any]): """ A transformer that encodes Python objects to JSON strings and decodes JSON strings back to Python objects. """ - def __init__(self, encoding: str = 'utf-8'): + def __init__(self, encoding: str = "utf-8"): super().__init__() self.encoding = encoding def encode(self, data: Any) -> bytes: try: - json_str = json.dumps(data, ensure_ascii=False) # Always allow non-ASCII characters to pass through json.dumps + json_str = json.dumps( + data, ensure_ascii=False + ) # Always allow non-ASCII characters to pass through json.dumps return json_str.encode(self.encoding) except Exception as e: # Removed print statement as it was for debugging - raise # Re-raise to be caught by EventTransformerPipe + raise # Re-raise to be caught by EventTransformerPipe def decode(self, data: Union[bytes, str, None]) -> Any: if data is None: diff --git a/examples/fs_rpc/filesystem_client.py b/examples/fs_rpc/filesystem_client.py index b318082..4e717fc 100644 --- a/examples/fs_rpc/filesystem_client.py +++ b/examples/fs_rpc/filesystem_client.py @@ -1,8 +1,11 @@ import asyncio -from cbor_rpc import RpcV1 +from cbor_rpc import RpcV1 from cbor_rpc.tcp import TcpPipe from cbor_rpc.transformer.json_transformer import JsonTransformer -from cbor_rpc.pipe.event_pipe import EventPipe # Keep this import for clarity, though not directly instantiated +from cbor_rpc.pipe.event_pipe import ( + EventPipe, +) # Keep this import for clarity, though not directly instantiated + async def main(): # Connect to the RPC server @@ -30,11 +33,7 @@ async def main(): print("File content:", content.decode()) # Write to the test file - write_success = await rpc_client.call_method( - "create_file", - "test.txt", - b"Hello, world!" - ) + write_success = await rpc_client.call_method("create_file", "test.txt", b"Hello, world!") print("Write successful:", write_success) # Read the updated file @@ -42,16 +41,13 @@ async def main(): print("Updated file content:", content.decode()) # Rename the file - rename_success = await rpc_client.call_method( - "rename_file", - "test.txt", - "renamed_test.txt" - ) + rename_success = await rpc_client.call_method("rename_file", "test.txt", "renamed_test.txt") print("Rename successful:", rename_success) # Delete the renamed file delete_success = await rpc_client.call_method("delete_file", "renamed_test.txt") print("Delete successful:", delete_success) + if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/fs_rpc/filesystem_server.py b/examples/fs_rpc/filesystem_server.py index df97609..97982aa 100644 --- a/examples/fs_rpc/filesystem_server.py +++ b/examples/fs_rpc/filesystem_server.py @@ -2,9 +2,11 @@ from typing import List, Optional, Any from cbor_rpc import RpcV1Server + class FilesystemRpcServer(RpcV1Server): async def validate_event_broadcast(self, connection_id, topic, message): return False + async def handle_method_call(self, connection_id: str, method: str, args: List[Any]) -> Any: if method == "list_files": return self.list_files(*args) @@ -29,7 +31,7 @@ def list_files(self, directory: str) -> List[str]: def read_file(self, path: str, chunk_size: int = 4096, offset: int = 0) -> bytes: """Reads a file in chunks.""" try: - with open(path, 'rb') as f: + with open(path, "rb") as f: f.seek(offset) return f.read(chunk_size) except Exception as e: @@ -38,7 +40,7 @@ def read_file(self, path: str, chunk_size: int = 4096, offset: int = 0) -> bytes def create_file(self, path: str, content: Optional[bytes] = None) -> bool: """Creates a file with optional initial content.""" try: - with open(path, 'wb') as f: + with open(path, "wb") as f: if content: f.write(content) return True @@ -64,18 +66,20 @@ def rename_file(self, src: str, dest: str) -> bool: print(f"Error renaming file: {str(e)}") return False + if __name__ == "__main__": import asyncio from cbor_rpc.tcp import TcpPipe, TcpServer from cbor_rpc.transformer.json_transformer import JsonTransformer + async def main(): - rpc_id=1 + rpc_id = 1 # Create a TCP server that handles connections, using JsonTransformer for RPC messages tcp_server = await TcpServer.create("localhost", 8000) print("Server running on port 8000") # Set up event handlers for new connections - async def handle_connection( rpc_pipe): + async def handle_connection(rpc_pipe): server = FilesystemRpcServer() await server.add_connection(str(rpc_id), rpc_pipe) diff --git a/pyproject.toml b/pyproject.toml index 911dfc8..3044564 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,3 +27,6 @@ addopts = "-v" testpaths = ["tests"] python_files = ["test_*.py"] pythonpath = ["cbor_rpc"] + +[tool.black] +line-length = 120 diff --git a/setup.py b/setup.py index ce8d943..9785ff0 100644 --- a/setup.py +++ b/setup.py @@ -10,10 +10,7 @@ long_description_content_type="text/markdown", # url="https://github.com/your_username/cbor-rpc", # Replace with your project's URL packages=find_packages(exclude=["cbor_rpc"]), - install_requires=[ - "pytest>=8.3.2", - "pytest-asyncio>=0.24.0" - ], + install_requires=["pytest>=8.3.2", "pytest-asyncio>=0.24.0"], python_requires=">=3.8", classifiers=[ "Programming Language :: Python :: 3", diff --git a/tests/docker/sshd-python/binary_emitter.py b/tests/docker/sshd-python/binary_emitter.py index 6f4e6b7..9b987c8 100644 --- a/tests/docker/sshd-python/binary_emitter.py +++ b/tests/docker/sshd-python/binary_emitter.py @@ -3,7 +3,7 @@ import time import binascii -hex_data = "DEADBEEF0001020380FF7F" # Removed \x0A\x0D (LF and CR) +hex_data = "DEADBEEF0001020380FF7F" # Removed \x0A\x0D (LF and CR) test_data = binascii.unhexlify(hex_data) while True: diff --git a/tests/docker/sshd-python/echo_back.py b/tests/docker/sshd-python/echo_back.py index 890afd8..c47446a 100644 --- a/tests/docker/sshd-python/echo_back.py +++ b/tests/docker/sshd-python/echo_back.py @@ -10,7 +10,7 @@ while True: try: # Read a chunk of data from stdin - data = os.read(stdin_fd, 4096) # Read up to 4096 bytes + data = os.read(stdin_fd, 4096) # Read up to 4096 bytes if not data: # EOF reached on stdin break @@ -23,6 +23,6 @@ break except Exception as e: # Log any other errors to stderr - sys.stderr.write(f"Error in echo_back.py: {e}\n".encode('utf-8')) + sys.stderr.write(f"Error in echo_back.py: {e}\n".encode("utf-8")) sys.stderr.flush() break diff --git a/tests/helpers/simple_pipe.py b/tests/helpers/simple_pipe.py index 10b196f..eca1647 100644 --- a/tests/helpers/simple_pipe.py +++ b/tests/helpers/simple_pipe.py @@ -2,7 +2,7 @@ from cbor_rpc import EventPipe, RpcV1, TimedPromise # Generic type variables -T1 = TypeVar('T1') +T1 = TypeVar("T1") class SimplePipe(EventPipe[T1, T1], Generic[T1]): diff --git a/tests/helpers/simple_tcp_server.py b/tests/helpers/simple_tcp_server.py index bfbc081..233064f 100644 --- a/tests/helpers/simple_tcp_server.py +++ b/tests/helpers/simple_tcp_server.py @@ -1,9 +1,11 @@ from cbor_rpc.tcp.tcp import TcpServer, TcpPipe + class SimpleTcpServer(TcpServer): """ A simple TCP server implementation for testing purposes that accepts all connections. """ + async def accept(self, pipe: TcpPipe) -> bool: """ Accepts all incoming TCP connections. diff --git a/tests/helpers/stdio_test_script.py b/tests/helpers/stdio_test_script.py index eb2453d..42a9f33 100644 --- a/tests/helpers/stdio_test_script.py +++ b/tests/helpers/stdio_test_script.py @@ -1,5 +1,6 @@ import sys + def main(): print("stdio_test_script: Started.", file=sys.stderr) @@ -15,5 +16,6 @@ def main(): sys.stdout.write(data) sys.stdout.flush() + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/test_cbor_transformer.py b/tests/test_cbor_transformer.py index e5991c0..dbb4ce6 100644 --- a/tests/test_cbor_transformer.py +++ b/tests/test_cbor_transformer.py @@ -34,6 +34,7 @@ def client_cbor(client_raw): cbor_transformer = CborTransformer() return cbor_transformer.applyTransformer(client_raw) + @pytest.mark.asyncio class TestCborTransformer: @@ -46,7 +47,7 @@ async def test_cbor_transformer_end_to_end_simple_dict(self, client_raw, server_ original_data = {"message": "Hello, CBOR!", "number": 456, "list": [1, 2, 3]} await client_transformed_pipe.write(original_data) encoded_data_received_by_server = await received_data_queue.get() - + # Verify the raw bytes received by the server are valid CBOR decoded_by_server = cbor2.loads(encoded_data_received_by_server) assert decoded_by_server == original_data @@ -66,7 +67,7 @@ async def test_cbor_transformer_decoding_error_on_read(self, server_raw, client_ client_transformed_pipe.on("error", error_queue.put_nowait) # Simulate server sending incomplete CBOR bytes - incomplete_cbor_bytes = b'\x83\x01\x02' # Incomplete array, missing one element + incomplete_cbor_bytes = b"\x83\x01\x02" # Incomplete array, missing one element with pytest.raises(cbor2.CBORDecodeError): await server_raw.write(incomplete_cbor_bytes) @@ -76,7 +77,7 @@ async def test_cbor_transformer_decoding_error_on_read(self, server_raw, client_ assert "Incomplete CBOR data for non-stream transformer" in str(error) # Send truly invalid data - truly_invalid_cbor = b'\x1f' # Unknown unsigned integer subtype + truly_invalid_cbor = b"\x1f" # Unknown unsigned integer subtype with pytest.raises(cbor2.CBORDecodeError): await server_raw.write(truly_invalid_cbor) error = await asyncio.wait_for(error_queue.get(), timeout=1) @@ -158,6 +159,7 @@ async def test_cbor_transformer_close_propagation_and_write_after_close(self, cl result = await client_transformed_pipe.write({"after": "close"}) assert result is False + @pytest.mark.asyncio class TestCborStreamTransformer: @@ -274,7 +276,7 @@ async def test_cbor_stream_transformer_invalid_data_in_stream(self, client_raw, client_transformed_pipe.on("error", error_queue.put_nowait) obj1 = {"valid": True} - invalid_bytes = b'\x1f' # Unknown unsigned integer subtype + invalid_bytes = b"\x1f" # Unknown unsigned integer subtype obj2 = {"another": "valid"} await server_raw.write(cbor2.dumps(obj1)) diff --git a/tests/test_event_emitter.py b/tests/test_event_emitter.py index 6a2b22b..e1a270b 100644 --- a/tests/test_event_emitter.py +++ b/tests/test_event_emitter.py @@ -3,9 +3,11 @@ from typing import Any, Callable from cbor_rpc.event.emitter import AbstractEmitter + @pytest.mark.asyncio async def test_on_and_emit(): """Test that subscribers registered with 'on' are called by '_emit' in registration order.""" + class TestEmitter(AbstractEmitter): pass @@ -36,13 +38,15 @@ async def async_handler2(data: Any): expected = [ f"async_handler1_event1", f"sync_handler1_event1", - f"async_handler2_event1" + f"async_handler2_event1", ] assert sorted(events) == sorted(expected), f"Expected {expected}, got {events}" + @pytest.mark.asyncio async def test_pipeline_and_notify(): """Test that pipelines run before subscribers in '_notify', respecting registration order.""" + class TestEmitter(AbstractEmitter): pass @@ -82,22 +86,23 @@ async def async_pipeline2(data: Any): expected_pipelines = [ f"async_pipeline1_event2", f"sync_pipeline1_event2", - f"async_pipeline2_event2" - ] - expected_subscribers = [ - f"async_handler1_event2", - f"sync_handler1_event2" + f"async_pipeline2_event2", ] + expected_subscribers = [f"async_handler1_event2", f"sync_handler1_event2"] pipeline_indices = [events.index(e) for e in expected_pipelines if e in events] subscriber_indices = [events.index(e) for e in expected_subscribers if e in events] - assert all(p < s for p in pipeline_indices for s in subscriber_indices), \ - f"Pipelines {expected_pipelines} should precede subscribers {expected_subscribers} in {events}" - assert sorted(events) == sorted(expected_pipelines + expected_subscribers), \ - f"Expected {expected_pipelines + expected_subscribers}, got {events}" + assert all( + p < s for p in pipeline_indices for s in subscriber_indices + ), f"Pipelines {expected_pipelines} should precede subscribers {expected_subscribers} in {events}" + assert sorted(events) == sorted( + expected_pipelines + expected_subscribers + ), f"Expected {expected_pipelines + expected_subscribers}, got {events}" + @pytest.mark.asyncio async def test_unsubscribe(): """Test that unsubscribing a handler removes it from the subscriber list.""" + class TestEmitter(AbstractEmitter): pass @@ -124,9 +129,11 @@ def sync_handler1(data: Any): expected = [f"sync_handler1_event3"] assert events == expected, f"Expected {expected}, got {events}" + @pytest.mark.asyncio async def test_replace_on_handler(): """Test that replace_on_handler sets only the new handler for the event.""" + class TestEmitter(AbstractEmitter): pass @@ -152,9 +159,11 @@ def sync_handler1(data: Any): expected = [f"async_handler1_event4"] assert events == expected, f"Expected {expected}, got {events}" + @pytest.mark.asyncio async def test_pipeline_failure(): """Test that '_notify' raises an exception if a pipeline fails and doesn't call subscribers.""" + class TestEmitter(AbstractEmitter): pass @@ -184,9 +193,11 @@ def sync_handler1(data: Any): expected = [f"async_pipeline1_event5"] assert events == expected, f"Expected {expected}, got {events}" + @pytest.mark.asyncio async def test_multiple_event_types(): """Test that only handlers for the triggered event type are called.""" + class TestEmitter(AbstractEmitter): pass @@ -241,10 +252,12 @@ def sync_pipeline_b(data: Any): expected_subscribers = [f"async_handler_a_data_a2", f"sync_handler_a_data_a2"] pipeline_indices = [events.index(e) for e in expected_pipelines if e in events] subscriber_indices = [events.index(e) for e in expected_subscribers if e in events] - assert all(p < s for p in pipeline_indices for s in subscriber_indices), \ - f"Pipelines {expected_pipelines} should precede subscribers {expected_subscribers} in {events}" - assert sorted(events) == sorted(expected_pipelines + expected_subscribers), \ - f"Expected {expected_pipelines + expected_subscribers}, got {events}" + assert all( + p < s for p in pipeline_indices for s in subscriber_indices + ), f"Pipelines {expected_pipelines} should precede subscribers {expected_subscribers} in {events}" + assert sorted(events) == sorted( + expected_pipelines + expected_subscribers + ), f"Expected {expected_pipelines + expected_subscribers}, got {events}" # Test _emit for event_b events.clear() @@ -261,14 +274,18 @@ def sync_pipeline_b(data: Any): expected_subscribers = [f"async_handler_b_data_b2", f"sync_handler_b_data_b2"] pipeline_indices = [events.index(e) for e in expected_pipelines if e in events] subscriber_indices = [events.index(e) for e in expected_subscribers if e in events] - assert all(p < s for p in pipeline_indices for s in subscriber_indices), \ - f"Pipelines {expected_pipelines} should precede subscribers {expected_subscribers} in {events}" - assert sorted(events) == sorted(expected_pipelines + expected_subscribers), \ - f"Expected {expected_pipelines + expected_subscribers}, got {events}" + assert all( + p < s for p in pipeline_indices for s in subscriber_indices + ), f"Pipelines {expected_pipelines} should precede subscribers {expected_subscribers} in {events}" + assert sorted(events) == sorted( + expected_pipelines + expected_subscribers + ), f"Expected {expected_pipelines + expected_subscribers}, got {events}" + @pytest.mark.asyncio async def test_background_task_failure(): """Test that background task failures in '_emit' don't affect other subscribers.""" + class TestEmitter(AbstractEmitter): pass @@ -300,10 +317,11 @@ def sync_handler(data: Any): expected = [ f"async_handler1_event6", f"async_handler2_event6", - f"sync_handler_event6" + f"sync_handler_event6", ] assert sorted(events) == sorted(expected), f"Expected {expected}, got {events}" + @pytest.mark.asyncio async def test_slow_emit_does_not_block_notify(): """Test that a slow handler in _emit does not block a subsequent _notify call.""" @@ -356,13 +374,13 @@ def fast_notify_pipeline(data: Any): pipeline_index = events.index("fast_notify_pipeline_data_notify") handler_index = events.index("fast_notify_handler_data_notify") - assert pipeline_index < slow_index and handler_index < slow_index, ( - f"_notify handlers [{pipeline_index}, {handler_index}] should run before slow _emit [{slow_index}]" - ) + assert ( + pipeline_index < slow_index and handler_index < slow_index + ), f"_notify handlers [{pipeline_index}, {handler_index}] should run before slow _emit [{slow_index}]" expected = { "fast_notify_pipeline_data_notify", "fast_notify_handler_data_notify", - "slow_handler_data_emit" + "slow_handler_data_emit", } assert set(events) == expected, f"Expected {expected}, got {set(events)}" diff --git a/tests/test_event_pipe.py b/tests/test_event_pipe.py index 2d84a22..a456211 100644 --- a/tests/test_event_pipe.py +++ b/tests/test_event_pipe.py @@ -4,6 +4,7 @@ from cbor_rpc import EventPipe import pytest_asyncio + @pytest_asyncio.fixture async def event_pipe_pair(): pipe1, pipe2 = EventPipe.create_inmemory_pair() @@ -12,6 +13,7 @@ async def event_pipe_pair(): await pipe1.terminate() await pipe2.terminate() + @pytest.mark.asyncio async def test_create_pair(event_pipe_pair: Tuple[EventPipe, EventPipe]): # Positive case: Creating a pair of async pipes @@ -19,6 +21,7 @@ async def test_create_pair(event_pipe_pair: Tuple[EventPipe, EventPipe]): assert isinstance(pipe1, EventPipe) assert isinstance(pipe2, EventPipe) + @pytest.mark.asyncio async def test_write_success(event_pipe_pair: Tuple[EventPipe, EventPipe]): # Positive case: Writing a chunk successfully @@ -26,6 +29,7 @@ async def test_write_success(event_pipe_pair: Tuple[EventPipe, EventPipe]): result = await pipe1.write("test_chunk") assert result is True + @pytest.mark.asyncio async def test_terminate_success(event_pipe_pair: Tuple[EventPipe, EventPipe]): # Positive case: Terminating the pipe @@ -33,6 +37,7 @@ async def test_terminate_success(event_pipe_pair: Tuple[EventPipe, EventPipe]): await pipe1.terminate() # No exception should be raised + @pytest.mark.asyncio async def test_pipeline_execution(event_pipe_pair: Tuple[EventPipe, EventPipe]): # Positive case: Adding and executing a pipeline @@ -47,9 +52,10 @@ async def pipeline_handler(chunk: Any) -> None: pipe1.pipeline("data", pipeline_handler) await pipe1._notify("data", "test_chunk") - await asyncio.wait_for(event.wait(), timeout=1) # Wait for the handler to be called + await asyncio.wait_for(event.wait(), timeout=1) # Wait for the handler to be called assert received_chunk == "test_chunk" + @pytest.mark.asyncio async def test_pipe_pair(event_pipe_pair: Tuple[EventPipe, EventPipe]): # Positive case: Attaching two pipes @@ -64,9 +70,10 @@ async def handler(chunk: Any) -> None: pipe2.pipeline("data", handler) await pipe1.write("test_chunk") - await asyncio.wait_for(event.wait(), timeout=1) # Wait for the handler to be called + await asyncio.wait_for(event.wait(), timeout=1) # Wait for the handler to be called assert received_chunk == "test_chunk" + @pytest.mark.asyncio async def test_write_after_terminate(event_pipe_pair: Tuple[EventPipe, EventPipe]): # Negative case: Writing to a terminated pipe @@ -76,6 +83,7 @@ async def test_write_after_terminate(event_pipe_pair: Tuple[EventPipe, EventPipe result = await pipe1.write("test_chunk") assert result is False + @pytest.mark.asyncio async def test_parallel_event_writes(event_pipe_pair: Tuple[EventPipe, EventPipe]): # Test case: Multiple coroutines writing to one end of the EventPipe @@ -83,10 +91,10 @@ async def test_parallel_event_writes(event_pipe_pair: Tuple[EventPipe, EventPipe num_writes = 10 test_chunks = [f"chunk_{i}" for i in range(num_writes)] received_chunks = [] - + # Use a lock to protect shared list in concurrent access - lock = asyncio.Lock() - + lock = asyncio.Lock() + # Use a queue to signal when all chunks are received received_queue = asyncio.Queue() @@ -111,6 +119,7 @@ async def writer(chunk): # Verify all chunks are received assert sorted(received_chunks) == sorted(test_chunks) + @pytest.mark.asyncio async def test_parallel_event_processing(event_pipe_pair: Tuple[EventPipe, EventPipe]): # Test case: One pipe writes multiple chunks, and the other pipe's handler processes them @@ -118,13 +127,13 @@ async def test_parallel_event_processing(event_pipe_pair: Tuple[EventPipe, Event num_chunks = 10 test_chunks = [f"data_{i}" for i in range(num_chunks)] processed_chunks = [] - + lock = asyncio.Lock() processed_queue = asyncio.Queue() async def processing_handler(chunk: Any) -> None: # Simulate some async processing - await asyncio.sleep(0.01) + await asyncio.sleep(0.01) async with lock: processed_chunks.append(chunk) if len(processed_chunks) == num_chunks: @@ -142,12 +151,15 @@ async def processing_handler(chunk: Any) -> None: # Verify all chunks are processed assert sorted(processed_chunks) == sorted(test_chunks) + @pytest.mark.asyncio -async def test_concurrent_bidirectional_event_communication(event_pipe_pair: Tuple[EventPipe, EventPipe]): +async def test_concurrent_bidirectional_event_communication( + event_pipe_pair: Tuple[EventPipe, EventPipe], +): # Test case: Concurrent writes and event processing from both ends pipe1, pipe2 = event_pipe_pair num_messages = 5 - + client_sent_msgs = [] client_received_responses = [] server_received_msgs = [] @@ -169,8 +181,8 @@ async def server_handler(msg: Any) -> None: if len(server_sent_responses) == num_messages: server_done_event.set() - pipe1.pipeline("data", client_handler) # Client listens for responses on the 'data' pipeline - pipe2.pipeline("data", server_handler) # Server listens for client messages + pipe1.pipeline("data", client_handler) # Client listens for responses on the 'data' pipeline + pipe2.pipeline("data", server_handler) # Server listens for client messages async def client_writer_task(): for i in range(num_messages): @@ -186,9 +198,12 @@ async def client_writer_task(): # Verify client sent messages are received by server assert sorted([f"client_msg_{i}" for i in range(num_messages)]) == sorted(server_received_msgs) - + # Verify server sent messages are received by client - assert sorted([f"server_response_to_client_msg_{i}" for i in range(num_messages)]) == sorted(client_received_responses) + assert sorted([f"server_response_to_client_msg_{i}" for i in range(num_messages)]) == sorted( + client_received_responses + ) + if __name__ == "__main__": pytest.main() diff --git a/tests/test_json_transformer.py b/tests/test_json_transformer.py index 8f12ab0..7ae6b6f 100644 --- a/tests/test_json_transformer.py +++ b/tests/test_json_transformer.py @@ -11,12 +11,15 @@ DEFAULT_TIMEOUT = 2.0 -@pytest.fixture(params=[ - (EventPipe.create_inmemory_pair, "InmemoryPipe"), - (AioPipe.create_inmemory_pair, "AioPipe"), - (TcpPipe.create_inmemory_pair, "TcpPipe"), - -], ids=lambda param: param[1]) + +@pytest.fixture( + params=[ + (EventPipe.create_inmemory_pair, "InmemoryPipe"), + (AioPipe.create_inmemory_pair, "AioPipe"), + (TcpPipe.create_inmemory_pair, "TcpPipe"), + ], + ids=lambda param: param[1], +) async def pipe_pair(request): create_pair_func, _ = request.param if asyncio.iscoroutinefunction(create_pair_func): @@ -27,6 +30,7 @@ async def pipe_pair(request): await client_pipe.terminate() await server_pipe.terminate() + @pytest.fixture async def json_pipe(pipe_pair): client_raw_pipe, server_raw_pipe = pipe_pair @@ -34,13 +38,15 @@ async def json_pipe(pipe_pair): client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) return client_raw_pipe, server_raw_pipe, client_transformed_pipe, json_transformer + @pytest.fixture async def json_pipe_ascii(pipe_pair): client_raw_pipe, server_raw_pipe = pipe_pair - json_transformer = JsonTransformer(encoding='ascii') + json_transformer = JsonTransformer(encoding="ascii") client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) return client_raw_pipe, server_raw_pipe, client_transformed_pipe, json_transformer + @pytest.mark.asyncio class TestJsonTransformerPipeInteraction: @@ -71,7 +77,7 @@ async def test_json_transformer_end_to_end_simple_dict(self, json_pipe): assert False, "Test failed due to timeout waiting for data" # Manually decode the data received by the server_raw_pipe to verify it's JSON bytes - decoded_by_server = json.loads(encoded_data_received_by_server.decode('utf-8')) + decoded_by_server = json.loads(encoded_data_received_by_server.decode("utf-8")) assert decoded_by_server == original_data # Now, let's test the reverse: server sends data, client receives decoded data @@ -81,10 +87,10 @@ async def test_json_transformer_end_to_end_simple_dict(self, json_pipe): # Data to send from server response_data = {"status": "success", "code": 200} - + # Server_raw_pipe writes the *encoded* data (as if it received it from a client and is sending a response) # This data will go through client_raw_pipe and then be decoded by client_transformed_pipe - await server_raw_pipe.write(json.dumps(response_data).encode('utf-8')) + await server_raw_pipe.write(json.dumps(response_data).encode("utf-8")) # Wait for the decoded data to arrive at the client_transformed_pipe decoded_data_received_by_client = await asyncio.wait_for( @@ -110,13 +116,13 @@ async def test_json_transformer_end_to_end_unicode_characters(self, json_pipe): received_data_queue.get(), timeout=2.0, ) - decoded_by_server = json.loads(encoded_data_received_by_server.decode('utf-8')) + decoded_by_server = json.loads(encoded_data_received_by_server.decode("utf-8")) assert decoded_by_server == original_data client_received_data_queue = asyncio.Queue() client_transformed_pipe.on("data", client_received_data_queue.put_nowait) response_data = {"greeting": "こんにごは"} - await server_raw_pipe.write(json.dumps(response_data, ensure_ascii=False).encode('utf-8')) + await server_raw_pipe.write(json.dumps(response_data, ensure_ascii=False).encode("utf-8")) decoded_data_received_by_client = await asyncio.wait_for( client_received_data_queue.get(), timeout=2.0, @@ -127,7 +133,7 @@ async def test_json_transformer_encoding_error_on_write(self, json_pipe_ascii): # Use an encoding that cannot handle certain characters client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe_ascii - original_data = {"message": "Hello, world! šŸ‘‹"} # Contains non-ASCII character + original_data = {"message": "Hello, world! šŸ‘‹"} # Contains non-ASCII character # Use a queue to capture errors emitted by the transformed pipe error_queue = asyncio.Queue() @@ -135,7 +141,7 @@ async def test_json_transformer_encoding_error_on_write(self, json_pipe_ascii): # Writing this data should cause an encoding error to be emitted await client_transformed_pipe.write(original_data) - + # Assert that a UnicodeEncodeError is received error = await asyncio.wait_for(error_queue.get(), timeout=DEFAULT_TIMEOUT) assert isinstance(error, UnicodeEncodeError) @@ -148,7 +154,7 @@ async def test_json_transformer_decoding_error_on_read(self, json_pipe): client_transformed_pipe.on("error", error_queue.put_nowait) # Simulate server sending invalid JSON bytes - invalid_json_bytes = b'{,"key": "value",}' # Invalid JSON + invalid_json_bytes = b'{,"key": "value",}' # Invalid JSON try: await server_raw_pipe.write(invalid_json_bytes) except json.JSONDecodeError: @@ -168,7 +174,7 @@ async def test_json_transformer_decoding_type_error_on_read(self, json_pipe): # Simulate server sending non-bytes/str data (e.g., an int) non_string_data = 12345 try: - await server_raw_pipe.write(non_string_data) # This will pass through raw pipe as is + await server_raw_pipe.write(non_string_data) # This will pass through raw pipe as is except TypeError as exc: # TcpPipe enforces bytes-only writes; no error will be emitted by the transformer. assert isinstance(exc, TypeError) @@ -191,7 +197,7 @@ async def test_json_transformer_non_json_serializable_data(self, json_pipe): # Writing this data should cause a TypeError to be emitted await client_transformed_pipe.write(non_serializable_data) - + # Assert that a TypeError is received error = await asyncio.wait_for(error_queue.get(), timeout=DEFAULT_TIMEOUT) assert isinstance(error, TypeError) @@ -214,7 +220,7 @@ async def test_json_transformer_pipe_write_after_termination(self, json_pipe): client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe await client_raw_pipe.terminate() - + # Writing to a terminated transformed pipe should return False result = await client_transformed_pipe.write({"test": "data"}) assert result is False @@ -245,7 +251,7 @@ async def test_json_transformer_pipe_read_after_termination(self, json_pipe): except ConnectionError: # TcpPipe raises if not connected after termination. pass - + # Ensure no data is received by the transformed pipe after termination with pytest.raises(asyncio.TimeoutError): await asyncio.wait_for(data_queue.get(), timeout=0.1) diff --git a/tests/test_pipe.py b/tests/test_pipe.py index 97f6fba..ae4f6f9 100644 --- a/tests/test_pipe.py +++ b/tests/test_pipe.py @@ -4,11 +4,13 @@ from cbor_rpc.pipe.pipe import Pipe import pytest_asyncio + @pytest_asyncio.fixture async def pipe_pair(): pipe1, pipe2 = Pipe.create_pair() yield pipe1, pipe2 + @pytest.mark.asyncio async def test_create_pair(): # Positive case: Creating a pair of sync pipes @@ -18,17 +20,19 @@ async def test_create_pair(): await pipe1.terminate() await pipe2.terminate() + @pytest.mark.asyncio -async def test_write_read(pipe_pair:Tuple[Pipe,Pipe]): +async def test_write_read(pipe_pair: Tuple[Pipe, Pipe]): # Positive case: Writing and reading a chunk successfully pipe1, pipe2 = pipe_pair assert await pipe1.write("test_chunk") is True - await asyncio.sleep(0) # Allow event loop to process the write + await asyncio.sleep(0) # Allow event loop to process the write assert await pipe2.read() == "test_chunk" + @pytest.mark.asyncio -async def test_close_pipe(pipe_pair:Tuple[Pipe,Pipe]): +async def test_close_pipe(pipe_pair: Tuple[Pipe, Pipe]): # Positive case: Closing the pipe pipe1, pipe2 = pipe_pair await pipe1.terminate() @@ -37,14 +41,16 @@ async def test_close_pipe(pipe_pair:Tuple[Pipe,Pipe]): assert pipe1._closed is True assert pipe2._closed is True + @pytest.mark.asyncio -async def test_write_after_close(pipe_pair:Tuple[Pipe,Pipe]): +async def test_write_after_close(pipe_pair: Tuple[Pipe, Pipe]): # Negative case: Writing to a closed pipe pipe1, pipe2 = pipe_pair await pipe1.terminate() assert await pipe1.write("test_chunk") is False + @pytest.mark.asyncio async def test_read_timeout(pipe_pair): # Positive case: Reading with timeout @@ -52,19 +58,21 @@ async def test_read_timeout(pipe_pair): assert await pipe1.read(timeout=0.1) is None + @pytest.mark.asyncio -async def test_bidirectional_communication(pipe_pair:Tuple[Pipe,Pipe]): +async def test_bidirectional_communication(pipe_pair: Tuple[Pipe, Pipe]): # Positive case: Bidirectional communication between pipes pipe1, pipe2 = pipe_pair assert await pipe1.write("test_chunk") is True - await asyncio.sleep(0) # Allow event loop to process the write + await asyncio.sleep(0) # Allow event loop to process the write assert await pipe2.read() == "test_chunk" assert await pipe2.write("response_chunk") is True - await asyncio.sleep(0) # Allow event loop to process the write + await asyncio.sleep(0) # Allow event loop to process the write assert await pipe1.read() == "response_chunk" + @pytest.mark.asyncio async def test_parallel_writes(pipe_pair: Tuple[Pipe, Pipe]): # Test case: Multiple coroutines writing to one end of the pipe @@ -87,6 +95,7 @@ async def writer(chunk): # Verify all chunks are received and in correct order (or at least all present) assert sorted(received_chunks) == sorted(test_chunks) + @pytest.mark.asyncio async def test_parallel_reads(pipe_pair: Tuple[Pipe, Pipe]): # Test case: Multiple coroutines reading from one end of the pipe @@ -97,7 +106,7 @@ async def test_parallel_reads(pipe_pair: Tuple[Pipe, Pipe]): # Write all chunks first for chunk in test_chunks: await pipe1.write(chunk) - await asyncio.sleep(0) # Allow event loop to process the write + await asyncio.sleep(0) # Allow event loop to process the write async def reader(): return await pipe2.read() @@ -108,12 +117,13 @@ async def reader(): # Verify all chunks are received assert sorted(received_chunks) == sorted(test_chunks) + @pytest.mark.asyncio async def test_concurrent_bidirectional_communication(pipe_pair: Tuple[Pipe, Pipe]): # Test case: Concurrent writes and reads from both ends pipe1, pipe2 = pipe_pair num_messages = 20 - + async def client_task(): sent = [] received = [] @@ -144,9 +154,10 @@ async def server_task(): # Verify client sent messages are received by server assert sorted([f"client_msg_{i}" for i in range(num_messages)]) == sorted(server_received) - + # Verify server sent messages are received by client assert sorted([f"server_response_to_client_msg_{i}" for i in range(num_messages)]) == sorted(client_received) + if __name__ == "__main__": pytest.main() diff --git a/tests/test_rpc_v1.py b/tests/test_rpc_v1.py index 58ed565..1323e75 100644 --- a/tests/test_rpc_v1.py +++ b/tests/test_rpc_v1.py @@ -7,20 +7,23 @@ from tests.helpers import SimplePipe - async def sleep_method(seconds: float) -> None: await asyncio.sleep(seconds) return f"Slept for {seconds} seconds" + def add_method(a: int, b: int) -> int: return a + b + def multiply_method(a: int, b: int) -> int: return a * b + def throw_error_method(message: str) -> None: raise Exception(message) + # Method handler for RPC def method_handler(method: str, args: List[Any]) -> Any: if method == "sleep": @@ -33,43 +36,52 @@ def method_handler(method: str, args: List[Any]) -> Any: return throw_error_method(*args) raise Exception(f"Unknown method: {method}") + # Event handler for RPC async def event_handler(topic: str, message: Any) -> None: pass # No-op for testing + @pytest.fixture def pipe(): return SimplePipe() + @pytest.fixture def rpc(pipe): return RpcV1.make_rpc_v1(pipe, "test_id", method_handler, event_handler) + @pytest.mark.asyncio async def test_get_id(rpc): assert rpc.get_id() == "test_id" + @pytest.mark.asyncio async def test_add_method_success(rpc): result = await rpc.call_method("add", 3, 4) assert result == 7 + @pytest.mark.asyncio async def test_multiply_method_success(rpc): result = await rpc.call_method("multiply", 5, 6) assert result == 30 + @pytest.mark.asyncio async def test_sleep_method_success(rpc): result = await rpc.call_method("sleep", 0.1) assert result == "Slept for 0.1 seconds" + @pytest.mark.asyncio async def test_throw_error_method(rpc): with pytest.raises(Exception) as exc_info: await rpc.call_method("throwError", "Test error") assert str(exc_info.value) == "Test error" + @pytest.mark.asyncio async def test_call_method_timeout(rpc, pipe): rpc.set_timeout(100) # Set short timeout @@ -78,22 +90,26 @@ async def test_call_method_timeout(rpc, pipe): assert exc_info.value.args[0]["timeout"] is True assert exc_info.value.args[0]["timeoutPeriod"] == 100 + @pytest.mark.asyncio async def test_call_method_unknown_method(rpc): with pytest.raises(Exception) as exc_info: await rpc.call_method("unknown", 1) assert str(exc_info.value) == "Unknown method: unknown" + @pytest.mark.asyncio async def test_fire_method(rpc): await rpc.fire_method("add", 1, 2) assert rpc._counter == 1 # Verify message was sent + @pytest.mark.asyncio async def test_emit_event(rpc): await rpc.emit("test_topic", {"data": "test"}) # No assertion needed; just verify no crash + @pytest.mark.asyncio async def test_wait_next_event_success(rpc, pipe): async def simulate_event(): @@ -105,6 +121,7 @@ async def simulate_event(): assert result == {"data": "test"} await task + @pytest.mark.asyncio async def test_wait_next_event_timeout(rpc): with pytest.raises(Exception) as exc_info: @@ -112,6 +129,7 @@ async def test_wait_next_event_timeout(rpc): assert exc_info.value.args[0]["timeout"] is True assert exc_info.value.args[0]["timeoutPeriod"] == 100 + @pytest.mark.asyncio async def test_wait_next_event_already_waiting(rpc): rpc._waiters["test_topic"] = TimedPromise(1000) @@ -119,33 +137,34 @@ async def test_wait_next_event_already_waiting(rpc): await rpc.wait_next_event("test_topic") assert str(exc_info.value) == "Already waiting for event" + @pytest.mark.asyncio async def test_invalid_message_format(rpc, pipe): await pipe.write([1, 2, 3]) # Invalid message await asyncio.sleep(0.1) # Allow processing # No crash means test passes + @pytest.mark.asyncio async def test_unsupported_version(rpc, pipe): await pipe.write([2, 0, 0, "add", [1, 2]]) # Wrong version await asyncio.sleep(0.1) # Allow processing # No crash means test passes + @pytest.mark.asyncio async def test_concurrent_method_calls(rpc): - tasks = [ - rpc.call_method("add", i, i) - for i in range(3) - ] + tasks = [rpc.call_method("add", i, i) for i in range(3)] results = await asyncio.gather(*tasks) assert results == [0, 2, 4] + @pytest.mark.asyncio async def test_read_only_client(pipe): read_only = RpcV1.read_only_client(SimplePipe()) - + # We need to directly call the handle_method_call method to test it with pytest.raises(Exception) as exc_info: read_only.handle_method_call("add", [1, 2]) - + assert str(exc_info.value) == "Client Only Implementation" diff --git a/tests/test_ssh_docker_pipe.py b/tests/test_ssh_docker_pipe.py index 93d55f6..4fb22aa 100644 --- a/tests/test_ssh_docker_pipe.py +++ b/tests/test_ssh_docker_pipe.py @@ -12,14 +12,15 @@ # Define a test user and password for the SSHD container TEST_SSH_USER = "testuser" TEST_SSH_PASSWORD = "testpassword" -SSHD_IMAGE_NAME = "cbor-rpc-py-sshd-python" # Custom image name +SSHD_IMAGE_NAME = "cbor-rpc-py-sshd-python" # Custom image name SSHD_CONTAINER_NAME = "test-sshd-container" -SSHD_DOCKERFILE_PATH = "./tests/docker/sshd-python" # Path to the Dockerfile +SSHD_DOCKERFILE_PATH = "./tests/docker/sshd-python" # Path to the Dockerfile + @pytest.fixture(scope="session") def ssh_keys(): """Generates SSH keys (plain and encrypted) and a passphrase for testing.""" - private_key_obj = asyncssh.generate_private_key('ssh-rsa') + private_key_obj = asyncssh.generate_private_key("ssh-rsa") passphrase = "test_passphrase" return { @@ -27,9 +28,10 @@ def ssh_keys(): "unencrypted_public": private_key_obj.export_public_key().decode(), "encrypted_private": private_key_obj.export_private_key(passphrase=passphrase).decode(), "encrypted_public": private_key_obj.export_public_key().decode(), - "passphrase": passphrase + "passphrase": passphrase, } + @pytest.fixture(scope="session") def docker_client(): """Provides a Docker client instance.""" @@ -37,26 +39,28 @@ def docker_client(): yield client client.close() + @pytest.fixture(scope="session") def test_network(docker_client: docker.DockerClient): """Provides a Docker network for containers to communicate.""" network_name = "test-ssh-network" try: network = docker_client.networks.get(network_name) - network.remove() # Clean up existing network if it exists + network.remove() # Clean up existing network if it exists except docker.errors.NotFound: pass - + network = docker_client.networks.create(network_name, driver="bridge") yield network network.remove() -@pytest.fixture(scope="module") # Changed scope to module as requested + +@pytest.fixture(scope="module") # Changed scope to module as requested async def ssh_container_combined_auth(docker_client: docker.DockerClient, test_network, docker_host_ip, ssh_keys): container_name = "ssh-test-container-combined-auth" ssh_user = TEST_SSH_USER ssh_password = TEST_SSH_PASSWORD - public_key = ssh_keys["unencrypted_public"] # Use the unencrypted public key + public_key = ssh_keys["unencrypted_public"] # Use the unencrypted public key # Ensure previous container is stopped and removed try: @@ -76,7 +80,7 @@ async def ssh_container_combined_auth(docker_client: docker.DockerClient, test_n docker_client.images.build( path=SSHD_DOCKERFILE_PATH, tag=SSHD_IMAGE_NAME, - rm=True # Remove intermediate containers + rm=True, # Remove intermediate containers ) print(f"Docker image '{SSHD_IMAGE_NAME}' built successfully.") except docker.errors.BuildError as e: @@ -87,62 +91,75 @@ async def ssh_container_combined_auth(docker_client: docker.DockerClient, test_n container = None try: container = docker_client.containers.run( - SSHD_IMAGE_NAME, # Use the custom image name + SSHD_IMAGE_NAME, # Use the custom image name detach=True, - ports={'2222/tcp': None}, # Map container SSH port to a random host port + ports={"2222/tcp": None}, # Map container SSH port to a random host port network=test_network.name, name=container_name, environment={ "PUID": "1000", "PGID": "1000", "TZ": "Etc/UTC", - "PASSWORD_ACCESS": "true", # Password access enabled + "PASSWORD_ACCESS": "true", # Password access enabled "USER_NAME": ssh_user, "USER_PASSWORD": ssh_password, - "PUBLIC_KEY": public_key # Add one public key + "PUBLIC_KEY": public_key, # Add one public key }, - restart_policy={"Name": "no"} + restart_policy={"Name": "no"}, ) - + container.reload() host_port = None - for _ in range(30): # Wait up to 30 seconds for port mapping + for _ in range(30): # Wait up to 30 seconds for port mapping container.reload() - if '2222/tcp' in container.ports and container.ports['2222/tcp']: - host_port = container.ports['2222/tcp'][0]['HostPort'] + if "2222/tcp" in container.ports and container.ports["2222/tcp"]: + host_port = container.ports["2222/tcp"][0]["HostPort"] break time.sleep(1) if host_port is None: raise RuntimeError("Failed to get host port for SSH container within 30 seconds.") - + print(f"SSH container with combined auth running on host port: {host_port}") # Wait for SSH server to be ready ready = False - for i in range(60): # wait up to 60 seconds + for i in range(60): # wait up to 60 seconds try: + async def check_ssh_combined(): # Check password authentication try: - conn_pw = await asyncssh.connect(docker_host_ip, port=int(host_port), username=ssh_user, password=ssh_password, known_hosts=None) + conn_pw = await asyncssh.connect( + docker_host_ip, + port=int(host_port), + username=ssh_user, + password=ssh_password, + known_hosts=None, + ) conn_pw.close() print("Password auth check successful.") except (asyncssh.Error, OSError) as e: print(f"Password auth check failed: {e}") return False - + # Check public key authentication (unencrypted) try: - conn_key = await asyncssh.connect(docker_host_ip, port=int(host_port), username=ssh_user, client_keys=[asyncssh.import_private_key(ssh_keys["unencrypted_private"])], known_hosts=None) + conn_key = await asyncssh.connect( + docker_host_ip, + port=int(host_port), + username=ssh_user, + client_keys=[asyncssh.import_private_key(ssh_keys["unencrypted_private"])], + known_hosts=None, + ) conn_key.close() print("Public key auth check successful.") except (asyncssh.Error, OSError) as e: print(f"Public key auth check failed: {e}") return False - + return True - + if await check_ssh_combined(): print(f"SSH server with combined auth is ready after {i+1} seconds.") ready = True @@ -151,31 +168,32 @@ async def check_ssh_combined(): print(f"Error during SSH readiness check: {e}") pass time.sleep(1) - + if not ready: print("\nSSH server with combined auth did not become ready in time. Container logs:") if container: - print(container.logs().decode('utf-8')) + print(container.logs().decode("utf-8")) raise RuntimeError("SSH server with combined auth did not become ready in time.") - + yield container, docker_host_ip, host_port, ssh_user, ssh_password, ssh_keys["unencrypted_private"] finally: if container: print("Stopping and removing ssh-test-container-combined-auth...") print("=========================== Container Logs Start ===========================") - print(container.logs().decode('utf-8')) + print(container.logs().decode("utf-8")) print("=========================== Container Logs End ===========================") container.stop() container.remove() + @pytest.fixture(scope="session") def docker_host_ip(): """Determines the Docker host IP for connecting to containers.""" docker_host = os.environ.get("DOCKER_HOST") - + # Regex to match tcp://hostname:port, unix://socket, or ip:port - regex = r'^(?:(tcp|unix)://)?([a-zA-Z0-9.-]+)(?::\d+)?$' - + regex = r"^(?:(tcp|unix)://)?([a-zA-Z0-9.-]+)(?::\d+)?$" + if docker_host: match = re.match(regex, docker_host) if match: @@ -183,9 +201,10 @@ def docker_host_ip(): if protocol == "unix": # For unix sockets, connections are typically local, but asyncssh needs an IP # In this case, 'localhost' is usually appropriate for host-to-container communication - return "localhost" + return "localhost" return host # Return the IP or hostname - return "localhost" # Default for local Docker setup + return "localhost" # Default for local Docker setup + @pytest.mark.asyncio async def test_ssh_pipe_with_hello_world_emitter(ssh_container_combined_auth): @@ -194,7 +213,7 @@ async def test_ssh_pipe_with_hello_world_emitter(ssh_container_combined_auth): This test uses password authentication. """ container, host, port, username, password, _ = ssh_container_combined_auth - + print(f"\nAttempting SshPipe connection to {host}:{port} as user {username} with password authentication...") pipe = None try: @@ -206,7 +225,7 @@ async def test_ssh_pipe_with_hello_world_emitter(ssh_container_combined_auth): password=password, known_hosts=None, timeout=10, - command=emitter_command + command=emitter_command, ) received_event = asyncio.Event() @@ -219,7 +238,7 @@ def on_data_callback(data): if b"hello world" in data: received_event.set() - pipe.pipeline('data', on_data_callback) + pipe.pipeline("data", on_data_callback) try: await asyncio.wait_for(received_event.wait(), timeout=10) @@ -242,7 +261,7 @@ async def test_ssh_pipe_with_password_authentication(ssh_container_combined_auth This is a dedicated test for password authentication, ensuring it works as expected. """ container, host, port, username, password, _ = ssh_container_combined_auth - + print(f"\nAttempting SshPipe connection to {host}:{port} as user {username} with password authentication...") pipe = None try: @@ -255,7 +274,7 @@ async def test_ssh_pipe_with_password_authentication(ssh_container_combined_auth password=password, known_hosts=None, timeout=10, - command=test_command + command=test_command, ) received_event = asyncio.Event() @@ -267,7 +286,7 @@ def on_data_callback(data): if b"Password auth successful!" in data: received_event.set() - pipe.pipeline('data', on_data_callback) + pipe.pipeline("data", on_data_callback) try: await asyncio.wait_for(received_event.wait(), timeout=10) @@ -297,7 +316,7 @@ async def test_ssh_pipe_with_plain_key_authentication(ssh_container_combined_aut This test is designed to run when sshd_container is configured for plain key auth. """ container, host, port, username, _, unencrypted_private_key_content = ssh_container_combined_auth - + print(f"\nAttempting SshPipe connection to {host}:{port} as user {username} with plain key authentication...") pipe = None @@ -307,10 +326,10 @@ async def test_ssh_pipe_with_plain_key_authentication(ssh_container_combined_aut host=host, port=port, username=username, - ssh_key_content=unencrypted_private_key_content, # Use the private key content for authentication + ssh_key_content=unencrypted_private_key_content, # Use the private key content for authentication known_hosts=None, timeout=10, - command=test_command + command=test_command, ) received_event = asyncio.Event() @@ -322,7 +341,7 @@ def on_data_callback(data): if b"Plain key auth successful!" in data: received_event.set() - pipe.pipeline('data', on_data_callback) + pipe.pipeline("data", on_data_callback) try: await asyncio.wait_for(received_event.wait(), timeout=10) @@ -352,10 +371,10 @@ async def test_ssh_pipe_with_encrypted_key_authentication(ssh_container_combined This test is designed to run when sshd_container is configured for encrypted key auth. """ container, host, port, username, _, _ = ssh_container_combined_auth - - private_key_content = ssh_keys['encrypted_private'] - passphrase = ssh_keys['passphrase'] - + + private_key_content = ssh_keys["encrypted_private"] + passphrase = ssh_keys["passphrase"] + print(f"\nAttempting SshPipe connection to {host}:{port} as user {username} with encrypted key authentication...") pipe = None @@ -365,11 +384,11 @@ async def test_ssh_pipe_with_encrypted_key_authentication(ssh_container_combined host=host, port=port, username=username, - ssh_key_content=private_key_content, # Use the private key content for authentication - ssh_key_passphrase=passphrase, # Pass the passphrase + ssh_key_content=private_key_content, # Use the private key content for authentication + ssh_key_passphrase=passphrase, # Pass the passphrase known_hosts=None, timeout=10, - command=test_command + command=test_command, ) received_event = asyncio.Event() @@ -381,7 +400,7 @@ def on_data_callback(data): if b"Encrypted key auth successful!" in data: received_event.set() - pipe.pipeline('data', on_data_callback) + pipe.pipeline("data", on_data_callback) try: await asyncio.wait_for(received_event.wait(), timeout=10) @@ -411,8 +430,10 @@ async def test_ssh_pipe_with_echo_back_command(ssh_container_combined_auth): This test uses password authentication. """ container, host, port, username, password, _ = ssh_container_combined_auth - - print(f"\nAttempting SshPipe connection to {host}:{port} as user {username} for echo-back test (using echo_back.py) with password authentication...") + + print( + f"\nAttempting SshPipe connection to {host}:{port} as user {username} for echo-back test (using echo_back.py) with password authentication..." + ) pipe = None try: # Use the custom Python echo-back script @@ -422,9 +443,9 @@ async def test_ssh_pipe_with_echo_back_command(ssh_container_combined_auth): port=port, username=username, password=password, - known_hosts=None, # Disable host key checking for test container + known_hosts=None, # Disable host key checking for test container timeout=10, - command=echo_back_command + command=echo_back_command, ) received_data_chunks = [] @@ -437,7 +458,7 @@ def on_data_callback(data): # We'll set the event once we receive some data. received_event.set() - pipe.pipeline('data', on_data_callback) + pipe.pipeline("data", on_data_callback) test_message = b"This is a test message for echo back\n" print(f"Writing data to pipe: {test_message!r}") @@ -450,10 +471,11 @@ def on_data_callback(data): pytest.fail("Did not receive any data within 10 seconds.") full_received_data = b"".join(received_data_chunks) - assert full_received_data.strip() == test_message.strip(), \ - f"Received data {full_received_data!r} should exactly match sent data {test_message!r}" + assert ( + full_received_data.strip() == test_message.strip() + ), f"Received data {full_received_data!r} should exactly match sent data {test_message!r}" print("Verification successful: Data echoed correctly by 'echo_back.py' script.") - await pipe.write_eof() # Signal EOF to the remote process + await pipe.write_eof() # Signal EOF to the remote process except asyncssh.Error as e: pytest.fail(f"SSH connection or command failed: {e}") @@ -465,6 +487,8 @@ def on_data_callback(data): if pipe: await pipe.terminate() print("SshPipe closed.") + + @pytest.mark.asyncio async def test_ssh_pipe_with_binary_data(ssh_container_combined_auth): """ @@ -472,38 +496,38 @@ async def test_ssh_pipe_with_binary_data(ssh_container_combined_auth): This test uses password authentication. """ container, host, port, username, password, _ = ssh_container_combined_auth - - print(f"\nAttempting SshPipe connection to {host}:{port} as user {username} for binary data emitter test with password authentication...") + + print( + f"\nAttempting SshPipe connection to {host}:{port} as user {username} for binary data emitter test with password authentication..." + ) pipe = None try: # Python script to continuously emit binary data # Using os.write(1, ...) to write raw bytes to stdout - emitter_binary_command = ( - "python3 /usr/local/bin/binary_emitter.py" - ) + emitter_binary_command = "python3 /usr/local/bin/binary_emitter.py" pipe = await SshPipe.connect( host=host, port=port, username=username, password=password, - known_hosts=None, # Disable host key checking for test container + known_hosts=None, # Disable host key checking for test container timeout=10, - command=emitter_binary_command + command=emitter_binary_command, ) received_event = asyncio.Event() received_data_chunks = [] - expected_binary_pattern = b'\xDE\xAD\xBE\xEF\x00\x01\x02\x03\x80\xFF\x7F' # Updated pattern without newlines + expected_binary_pattern = b"\xde\xad\xbe\xef\x00\x01\x02\x03\x80\xff\x7f" # Updated pattern without newlines def on_data_callback(data): print(f"test_ssh_pipe_with_binary_data: Received data chunk: {data!r}") # Strip any potential carriage returns or newlines added by the shell - cleaned_data = data.replace(b'\r', b'').replace(b'\n', b'') + cleaned_data = data.replace(b"\r", b"").replace(b"\n", b"") received_data_chunks.append(cleaned_data) if expected_binary_pattern in cleaned_data: received_event.set() - pipe.pipeline('data', on_data_callback) + pipe.pipeline("data", on_data_callback) try: await asyncio.wait_for(received_event.wait(), timeout=10) @@ -512,9 +536,10 @@ def on_data_callback(data): full_received_data = b"".join(received_data_chunks) print(f"Received total {len(full_received_data)} bytes. First 50 bytes: {full_received_data[:50]!r}") - - assert expected_binary_pattern in full_received_data, \ - f"Expected binary pattern {expected_binary_pattern!r} not found in received data {full_received_data!r}" + + assert ( + expected_binary_pattern in full_received_data + ), f"Expected binary pattern {expected_binary_pattern!r} not found in received data {full_received_data!r}" print("Verification successful: Binary data emitted and received correctly.") except asyncssh.Error as e: diff --git a/tests/test_stdio_rpc.py b/tests/test_stdio_rpc.py index 24e23d7..46b599d 100644 --- a/tests/test_stdio_rpc.py +++ b/tests/test_stdio_rpc.py @@ -3,6 +3,7 @@ import sys from cbor_rpc.stdio.stdio_pipe import StdioPipe + @pytest.mark.asyncio async def test_stdtio_read_write(): """ @@ -20,10 +21,10 @@ def on_data(data): if len(received_data) == 10: future.set_result(None) - pipe.pipeline('data', on_data) + pipe.pipeline("data", on_data) # Write 10 unique data chunks - test_data = [f"Test data {i}\n".encode('utf-8') for i in range(10)] + test_data = [f"Test data {i}\n".encode("utf-8") for i in range(10)] for data in test_data: await pipe.write(data) # Brief sleep to allow the subprocess to process the data @@ -38,4 +39,4 @@ def on_data(data): assert received == sent, f"Mismatch at index {i}: expected {sent!r}, got {received!r}" # Terminate the pipe - pipe.terminate() \ No newline at end of file + pipe.terminate() diff --git a/tests/test_tcp.py b/tests/test_tcp.py index a372797..321fe45 100644 --- a/tests/test_tcp.py +++ b/tests/test_tcp.py @@ -4,43 +4,43 @@ from cbor_rpc import TcpPipe from tests.helpers.simple_tcp_server import SimpleTcpServer -DEFAULT_TIMEOUT = 1.0 # we are doing everything on same machine. everything should be fast +DEFAULT_TIMEOUT = 1.0 # we are doing everything on same machine. everything should be fast @pytest.mark.asyncio async def test_tcp_client_server_connection(): """Test basic TCP client-server connection.""" # Start a server - server = await SimpleTcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() - + connections = [] - + def on_connection(tcp_pipe: TcpPipe): connections.append(tcp_pipe) - + server.on_connection(on_connection) - + try: # Create a client connection client = await TcpPipe.create_connection(server_host, server_port) - + # Wait for server to register the connection await asyncio.sleep(0.1) - + assert len(connections) == 1 assert client.is_connected() assert connections[0].is_connected() - + # Test peer info client_peer = client.get_peer_info() server_conn_peer = connections[0].get_peer_info() - + assert client_peer == (server_host, server_port) assert server_conn_peer is not None - + await client.terminate() - + finally: await server.close() @@ -48,72 +48,72 @@ def on_connection(tcp_pipe: TcpPipe): @pytest.mark.asyncio async def test_tcp_data_exchange(): """Test bidirectional data exchange over TCP.""" - server = await SimpleTcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() - + server_received = [] client_received = [] server_connection = None - + async def on_connection(tcp_pipe: TcpPipe): nonlocal server_connection server_connection = tcp_pipe - + async def on_server_data(data: bytes): server_received.append(data) - + tcp_pipe.on("data", on_server_data) - + server.on_connection(on_connection) - + try: # Create client client = await TcpPipe.create_connection(server_host, server_port) - + async def on_client_data(data: bytes): client_received.append(data) - + client.on("data", on_client_data) - + # Wait for connection to be established await asyncio.sleep(0.1) assert server_connection is not None - + # Send data from client to server await client.write(b"Hello from client") await asyncio.sleep(0.1) assert server_received == [b"Hello from client"] - + # Send data from server to client await server_connection.write(b"Hello from server") await asyncio.sleep(0.1) assert client_received == [b"Hello from server"] - + # Send multiple messages separately server_received.clear() client_received.clear() - + await client.write(b"Message 1") await asyncio.sleep(0.05) # Small delay between messages await client.write(b"Message 2") await asyncio.sleep(0.1) - + await server_connection.write(b"Response 1") await asyncio.sleep(0.05) # Small delay between messages await server_connection.write(b"Response 2") await asyncio.sleep(0.1) - + # Check that messages were received (they might be combined due to TCP buffering) server_data = b"".join(server_received) client_data = b"".join(client_received) - + assert b"Message 1" in server_data assert b"Message 2" in server_data assert b"Response 1" in client_data assert b"Response 2" in client_data - + await client.terminate() - + finally: await server.close() @@ -123,25 +123,25 @@ async def test_tcp_connection_errors(): """Test TCP connection error handling.""" # Test connection to non-existent server with pytest.raises(ConnectionError): - await TcpPipe.create_connection('127.0.0.1', 12345, timeout=0.1) - + await TcpPipe.create_connection("127.0.0.1", 12345, timeout=0.1) + # Test writing to disconnected client client = TcpPipe() with pytest.raises(ConnectionError): await client.write(b"test") - + # Test double connection - server = await SimpleTcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() - + try: client = await TcpPipe.create_connection(server_host, server_port) - + with pytest.raises(ConnectionError): await client.connect(server_host, server_port) - + await client.terminate() - + finally: await server.close() @@ -149,48 +149,48 @@ async def test_tcp_connection_errors(): @pytest.mark.asyncio async def test_tcp_connection_events(): """Test TCP connection events (connect, close, error).""" - server = await SimpleTcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() - + events = [] - - server.on_connection(lambda conn: events.append("server_connect")) - + + server.on_connection(lambda conn: events.append("server_connect")) + try: # Create client with event handlers client = TcpPipe() - + async def on_client_connect(): events.append("client_connect") - + async def on_client_close(*args): events.append(("client_close", args)) - + async def on_client_error(error): events.append(("client_error", str(error))) - + client.on("connect", on_client_connect) client.on("close", on_client_close) client.on("error", on_client_error) - + # Connect await client.connect(server_host, server_port) await asyncio.sleep(0.2) # Give more time for events to propagate - + assert "client_connect" in events print(events) assert "server_connect" in events - + # Close connection events_before_close = len(events) await client.terminate("test_reason") await asyncio.sleep(0.2) # Give time for close events - + # Check that close event was added assert len(events) > events_before_close close_events = [e for e in events if isinstance(e, tuple) and e[0] == "client_close"] assert len(close_events) > 0 - + finally: await server.close() @@ -198,11 +198,9 @@ async def on_client_error(error): @pytest.mark.asyncio async def test_tcp_client_connection_tracking(): """Test handling multiple simultaneous TCP connections.""" - server = await SimpleTcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() - - - + try: # Create multiple clients clients: List[TcpPipe] = [] @@ -211,34 +209,30 @@ async def test_tcp_client_connection_tracking(): client = await TcpPipe.create_connection(server_host, server_port) client.on("close", lambda: print(f"Connection[{i}] closed")) clients.append(client) - - + await asyncio.sleep(0.5) # Give time for all connections to be registered - + # Check that all connections are registered assert len(server.get_connections()) == 5 - + # Close all clients for client in clients: await client.terminate() - - await asyncio.sleep(0.2) + await asyncio.sleep(0.2) assert len(server.get_connections()) == 0, "Connections not clean uped" - - finally: await server.close() + @pytest.mark.asyncio async def test_tcp_client_connection_tracking_self(): """Test handling multiple simultaneous TCP connections.""" - server = await SimpleTcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() - - + try: # Create multiple clients clients: List[TcpPipe] = [] @@ -247,54 +241,52 @@ async def test_tcp_client_connection_tracking_self(): client = await TcpPipe.create_connection(server_host, server_port) client.on("close", lambda: print(f"Connection[{i}] closed")) clients.append(client) - - + await asyncio.sleep(0.5) # Give time for all connections to be registered - + # Check that all connections are registered assert len(server.get_connections()) == 5 - + # Close all clients for duplex in server.get_connections(): await duplex.terminate() - await asyncio.sleep(0.2) + await asyncio.sleep(0.2) assert len(server.get_connections()) == 0, "Connections not clean uped" - - finally: await server.close() + @pytest.mark.asyncio async def test_tcp_large_data_transfer(): """Test transferring large amounts of data over TCP.""" - server = await SimpleTcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() - + received_data = bytearray() server_connection = None - + async def on_connection(tcp_pipe: TcpPipe): nonlocal server_connection server_connection = tcp_pipe - + async def on_data(data: bytes): received_data.extend(data) - + tcp_pipe.on("data", on_data) - + server.on_connection(on_connection) - + try: client = await TcpPipe.create_connection(server_host, server_port) await asyncio.sleep(0.1) - + # Send large data (100KB instead of 1MB for faster testing) large_data = b"x" * (100 * 1024 * 1024) await client.write(large_data) - + # Wait for all data to be received timeout = 5.0 # 5 second timeout start_time = asyncio.get_event_loop().time() @@ -302,11 +294,11 @@ async def on_data(data: bytes): if asyncio.get_event_loop().time() - start_time > timeout: break await asyncio.sleep(0.1) - + assert bytes(received_data) == large_data - + await client.terminate() - + finally: await server.close() @@ -314,39 +306,39 @@ async def on_data(data: bytes): @pytest.mark.asyncio async def test_tcp_server_context_manager(): """Test using TcpServer as a context manager.""" - async with await SimpleTcpServer.create('127.0.0.1', 0) as server: + async with await SimpleTcpServer.create("127.0.0.1", 0) as server: server_host, server_port = server.get_address() - + client = await TcpPipe.create_connection(server_host, server_port) assert client.is_connected() - + await client.terminate() - + # Server should be closed automatically @pytest.mark.asyncio async def test_tcp_invalid_data_types(): """Test error handling for invalid data types.""" - server = await SimpleTcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() - + try: client = await TcpPipe.create_connection(server_host, server_port) - + # Test writing non-bytes data with pytest.raises(TypeError): await client.write("string data") # Should be bytes - + with pytest.raises(TypeError): await client.write(123) # Should be bytes - + # Test writing valid data types await client.write(b"bytes data") # Should work await client.write(bytearray(b"bytearray data")) # Should work - + await client.terminate() - + finally: await server.close() @@ -382,7 +374,7 @@ async def test_tcp_inmemory_pair_bidirectional_exchange(): @pytest.mark.asyncio async def test_tcp_shutdown_keeps_active_connections(): """Test shutting down the listener doesn't drop existing connections.""" - server = await SimpleTcpServer.create('127.0.0.1', 0) + server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() server_connection = None From aeb9f4f007b1c4357f7f32e666040cbbdbd706a8 Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Fri, 6 Feb 2026 23:11:08 +0545 Subject: [PATCH 16/25] Update dependencies in setup.py --- cbor_rpc/__init__.py | 4 +++- requirements.txt | 26 ++++++++++++++++++++++++-- setup.py | 7 +++++-- 3 files changed, 32 insertions(+), 5 deletions(-) diff --git a/cbor_rpc/__init__.py b/cbor_rpc/__init__.py index e4cf8a2..d3cb71a 100644 --- a/cbor_rpc/__init__.py +++ b/cbor_rpc/__init__.py @@ -7,7 +7,7 @@ from .timed_promise import TimedPromise from .rpc import RpcClient, RpcAuthorizedClient, RpcServer, RpcV1, RpcV1Server, Server from .tcp import TcpPipe, TcpServer -from .transformer import JsonTransformer, Transformer +from .transformer import CborStreamTransformer, CborTransformer, JsonTransformer, Transformer __all__ = [ # Promise @@ -32,6 +32,8 @@ # Transformers "Transformer", "JsonTransformer", + "CborTransformer", + "CborStreamTransformer", ] __version__ = "0.1.0" diff --git a/requirements.txt b/requirements.txt index 65b7223..b3e5e48 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,30 @@ # -# This file is autogenerated by pip-compile with Python 3.11 +# This file is autogenerated by pip-compile with Python 3.10 # by the following command: # -# pip-compile +# pip install pip-tools +# pip-compile --output-file=requirements.txt pyproject.toml # +asyncssh==2.22.0 + # via cbor-rpc (pyproject.toml) +bcrypt==5.0.0 + # via cbor-rpc (pyproject.toml) +cbor2==5.8.0 + # via cbor-rpc (pyproject.toml) +cffi==2.0.0 + # via cryptography +cryptography==46.0.4 + # via asyncssh +exceptiongroup==1.3.1 + # via pytest iniconfig==2.1.0 # via pytest packaging==25.0 # via pytest pluggy==1.6.0 # via pytest +pycparser==3.0 + # via cffi pygments==2.19.1 # via pytest pytest==8.4.0 @@ -18,3 +33,10 @@ pytest==8.4.0 # pytest-asyncio pytest-asyncio==1.0.0 # via cbor-rpc (pyproject.toml) +tomli==2.4.0 + # via pytest +typing-extensions==4.15.0 + # via + # asyncssh + # cryptography + # exceptiongroup diff --git a/setup.py b/setup.py index 9785ff0..ed868ab 100644 --- a/setup.py +++ b/setup.py @@ -8,9 +8,12 @@ author_email="sudip@bhattarai.me", long_description=open("README.md").read(), long_description_content_type="text/markdown", - # url="https://github.com/your_username/cbor-rpc", # Replace with your project's URL + url="https://github.com/mesudip/cbor-rpc-py ", packages=find_packages(exclude=["cbor_rpc"]), - install_requires=["pytest>=8.3.2", "pytest-asyncio>=0.24.0"], + install_requires=["asyncssh>=2.14.0", "bcrypt", "cbor2"], + extras_require={ + "test": ["pytest>=8.3.2", "pytest-asyncio>=0.24.0"], + }, python_requires=">=3.8", classifiers=[ "Programming Language :: Python :: 3", From db38f374087273b26496833da194fb045a913d34 Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Sun, 8 Feb 2026 01:31:47 +0545 Subject: [PATCH 17/25] Add logger to the rpc system --- .github/workflows/tests.yml | 53 ++++ README.md | 3 +- cbor_rpc/__init__.py | 15 +- cbor_rpc/event/emitter.py | 8 +- cbor_rpc/pipe/aio_pipe.py | 9 +- cbor_rpc/rpc/__init__.py | 2 + cbor_rpc/rpc/context.py | 6 + cbor_rpc/rpc/logging.py | 28 +++ cbor_rpc/rpc/rpc_base.py | 15 -- cbor_rpc/rpc/rpc_server.py | 73 ++---- cbor_rpc/rpc/rpc_v1.py | 308 +++++++++++++++-------- cbor_rpc/transformer/json_transformer.py | 13 +- examples/fs_rpc/filesystem_server.py | 9 +- pyproject.toml | 9 + requirements.txt | 4 + setup.py | 2 +- tests/test_rpc_v1.py | 53 ++-- 17 files changed, 408 insertions(+), 202 deletions(-) create mode 100644 .github/workflows/tests.yml create mode 100644 cbor_rpc/rpc/context.py create mode 100644 cbor_rpc/rpc/logging.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..b2c12b2 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,53 @@ +name: Run Tests + +on: + push: + branches: + - main + - master + - dev + - develop + pull_request: + branches: + - main + - master + - dev + - develop + +jobs: + test: + runs-on: ubuntu-latest + permissions: + contents: read + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install -e . + + - name: Run tests with coverage + run: | + pytest --cov=cbor_rpc --cov-report=xml --junitxml=junit.xml -o junit_family=legacy + + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: coverage.xml + + - name: Upload test results to Codecov + if: ${{ !cancelled() }} + uses: codecov/test-results-action@v1 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: junit.xml diff --git a/README.md b/README.md index fde5239..8790aec 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,3 @@ cbor-rpc -======== \ No newline at end of file +======== +[![codecov](https://codecov.io/github/mesudip/cbor-rpc-py/graph/badge.svg)](https://codecov.io/github/mesudip/cbor-rpc-py) \ No newline at end of file diff --git a/cbor_rpc/__init__.py b/cbor_rpc/__init__.py index d3cb71a..2d29e51 100644 --- a/cbor_rpc/__init__.py +++ b/cbor_rpc/__init__.py @@ -5,9 +5,18 @@ from .event import AbstractEmitter from .pipe import EventPipe, Pipe from .timed_promise import TimedPromise -from .rpc import RpcClient, RpcAuthorizedClient, RpcServer, RpcV1, RpcV1Server, Server from .tcp import TcpPipe, TcpServer from .transformer import CborStreamTransformer, CborTransformer, JsonTransformer, Transformer +from .rpc import ( + RpcClient, + RpcAuthorizedClient, + RpcServer, + RpcV1, + RpcV1Server, + Server, + RpcCallContext, + RpcLogger, +) __all__ = [ # Promise @@ -26,7 +35,9 @@ # Rpc base implementation "RpcV1", "RpcV1Server", - # TCP classes + # Rpc high level + "RpcCallContext", + "RpcLogger", # TCP classes "TcpPipe", "TcpServer", # Transformers diff --git a/cbor_rpc/event/emitter.py b/cbor_rpc/event/emitter.py index e5727c1..970ceaa 100644 --- a/cbor_rpc/event/emitter.py +++ b/cbor_rpc/event/emitter.py @@ -28,7 +28,13 @@ async def runner(): traceback.print_exc() warnings.warn(f"Background task error in handler: {e}", RuntimeWarning) - asyncio.create_task(runner()) + try: + loop = asyncio.get_running_loop() + except RuntimeError: + warnings.warn("Background task skipped: no running event loop", RuntimeWarning) + return + + loop.create_task(runner()) def _emit(self, event_type: str, *args: Any) -> None: for sub in self._subscribers.get(event_type, []): diff --git a/cbor_rpc/pipe/aio_pipe.py b/cbor_rpc/pipe/aio_pipe.py index c866c15..45c8989 100644 --- a/cbor_rpc/pipe/aio_pipe.py +++ b/cbor_rpc/pipe/aio_pipe.py @@ -86,11 +86,14 @@ async def _read_loop(self) -> None: break except asyncio.CancelledError: break - except Exception as e: # Catch BaseException for GeneratorExit/other BaseExceptions + except BaseException as e: + if isinstance(e, GeneratorExit): + break self._emit("error", e) # Synchronous _emit break - except Exception as e: # Catch BaseException for GeneratorExit/other BaseExceptions - self._emit("error", e) # Synchronous _emit + except BaseException as e: + if not isinstance(e, (asyncio.CancelledError, GeneratorExit)): + self._emit("error", e) # Synchronous _emit finally: if not self._closed: await self._close_connection() diff --git a/cbor_rpc/rpc/__init__.py b/cbor_rpc/rpc/__init__.py index 431c063..e734bb8 100644 --- a/cbor_rpc/rpc/__init__.py +++ b/cbor_rpc/rpc/__init__.py @@ -3,4 +3,6 @@ from .rpc_base import RpcClient, RpcServer, RpcAuthorizedClient from .rpc_v1 import RpcV1 +from .context import RpcCallContext +from .logging import RpcLogger from .rpc_server import RpcV1Server diff --git a/cbor_rpc/rpc/context.py b/cbor_rpc/rpc/context.py new file mode 100644 index 0000000..7cf0ee5 --- /dev/null +++ b/cbor_rpc/rpc/context.py @@ -0,0 +1,6 @@ +from .logging import RpcLogger + + +class RpcCallContext: + def __init__(self, logger: RpcLogger): + self.logger = logger diff --git a/cbor_rpc/rpc/logging.py b/cbor_rpc/rpc/logging.py new file mode 100644 index 0000000..87bb1c0 --- /dev/null +++ b/cbor_rpc/rpc/logging.py @@ -0,0 +1,28 @@ +from typing import Any, Callable + + +class RpcLogger: + def __init__( + self, + send_log: Callable[[int, int, Any, Any], None], + ref_proto: int, + ref_id: Callable[[], Any], + ): + self._send_log = send_log + self._ref_proto = ref_proto + self._ref_id = ref_id + + def log(self, content: Any) -> None: + self._send_log(3, self._ref_proto, self._ref_id(), content) + + def warn(self, content: Any) -> None: + self._send_log(2, self._ref_proto, self._ref_id(), content) + + def crit(self, content: Any) -> None: + self._send_log(1, self._ref_proto, self._ref_id(), content) + + def verbose(self, content: Any) -> None: + self._send_log(4, self._ref_proto, self._ref_id(), content) + + def debug(self, content: Any) -> None: + self._send_log(5, self._ref_proto, self._ref_id(), content) diff --git a/cbor_rpc/rpc/rpc_base.py b/cbor_rpc/rpc/rpc_base.py index efbb895..fb889f2 100644 --- a/cbor_rpc/rpc/rpc_base.py +++ b/cbor_rpc/rpc/rpc_base.py @@ -1,16 +1,9 @@ from typing import Any, Dict, List, Optional, Callable from abc import ABC, abstractmethod -import asyncio -import inspect from ..pipe.event_pipe import EventPipe -from ..timed_promise import TimedPromise class RpcClient(ABC): - @abstractmethod - async def emit(self, topic: str, message: Any) -> None: - pass - @abstractmethod async def call_method(self, method: str, *args: Any) -> Any: pass @@ -31,14 +24,6 @@ def get_id(self) -> str: class RpcServer(ABC): - @abstractmethod - async def emit(self, connection_id: str, topic: str, message: Any) -> None: - pass - - @abstractmethod - async def broadcast(self, topic: str, message: Any) -> None: - pass - @abstractmethod async def call_method(self, connection_id: str, method: str, *args: Any) -> Any: pass diff --git a/cbor_rpc/rpc/rpc_server.py b/cbor_rpc/rpc/rpc_server.py index 69f8bcd..e7eca90 100644 --- a/cbor_rpc/rpc/rpc_server.py +++ b/cbor_rpc/rpc/rpc_server.py @@ -1,96 +1,77 @@ from typing import Any, Dict, List, Optional, Callable from abc import ABC, abstractmethod -import asyncio from cbor_rpc.rpc.server_base import Server -from .rpc_base import RpcClient, RpcAuthorizedClient, RpcServer +from .rpc_base import RpcAuthorizedClient, RpcServer from .rpc_v1 import RpcV1 +from .context import RpcCallContext from cbor_rpc.pipe.event_pipe import EventPipe class RpcV1Server(RpcServer): def __init__(self, server: Server): - self.active_connections: Dict[str, RpcV1] = {} + self.active_connections: Dict[str, EventPipe[Any, Any]] = {} + self.rpc_clients: Dict[str, RpcV1] = {} self.timeout = 30000 async def add_connection(self, conn_id: str, rpc_client: EventPipe[Any, Any]) -> None: - def method_handler(method: str, args: List[Any]) -> Any: - return self.handle_method_call(conn_id, method, args) + def method_handler(context: RpcCallContext, method: str, args: List[Any]) -> Any: + return self.handle_method_call(conn_id, context, method, args) - async def event_handler(topic: str, data: Any) -> None: - await self._handle_event(conn_id, topic, data) - - client_rpc = RpcV1.make_rpc_v1(rpc_client, conn_id, method_handler, event_handler) + client_rpc = RpcV1.make_rpc_v1(rpc_client, conn_id, method_handler) client_rpc.set_timeout(self.timeout) - self.active_connections[conn_id] = client_rpc + + self.active_connections[conn_id] = rpc_client + self.rpc_clients[conn_id] = client_rpc # Set up cleanup on close async def cleanup(*args): self.active_connections.pop(conn_id, None) + self.rpc_clients.pop(conn_id, None) rpc_client.on("close", cleanup) async def disconnect(self, connection_id: str, reason: Optional[str] = None) -> None: - client = self.active_connections.pop(connection_id, None) - if client: + base = self.active_connections.pop(connection_id, None) + self.rpc_clients.pop(connection_id, None) + if base: print("RpcV1Server: Disconnecting client:", connection_id) - await client.pipe.terminate(1000, reason or "Server terminated connection") + await base.terminate(1000, reason or "Server terminated connection") def set_timeout(self, milliseconds: int) -> None: self.timeout = milliseconds def is_active(self, connection_id: str) -> bool: - return connection_id in self.active_connections - - async def broadcast(self, topic: str, message: Any) -> None: - tasks = [] - for client in self.active_connections.values(): - tasks.append(client.emit(topic, message)) - if tasks: - await asyncio.gather(*tasks, return_exceptions=True) + return connection_id in self.rpc_clients def get_client(self, connection_id: str) -> Optional[RpcAuthorizedClient]: - return self.active_connections.get(connection_id) + return self.rpc_clients.get(connection_id) async def call_method(self, connection_id: str, method: str, *args: Any) -> Any: - client = self.active_connections.get(connection_id) + client = self.rpc_clients.get(connection_id) if client: return await client.call_method(method, *args) raise Exception("Client is not active") - async def emit(self, connection_id: str, topic: str, message: Any) -> None: - client = self.active_connections.get(connection_id) - if client: - await client.emit(topic, message) - else: - raise Exception("Client is not active") - - async def _handle_event(self, connection_id: str, topic: str, message: Any) -> None: - if await self.validate_event_broadcast(connection_id, topic, message): - tasks = [] - for key, client in self.active_connections.items(): - if key != connection_id: - tasks.append(client.emit(topic, message)) - if tasks: - await asyncio.gather(*tasks, return_exceptions=True) - async def fire_method(self, connection_id: str, method: str, *args: Any) -> None: - client = self.active_connections.get(connection_id) + client = self.rpc_clients.get(connection_id) if client: await client.fire_method(method, *args) else: raise Exception("Client is not active") @abstractmethod - async def handle_method_call(self, connection_id: str, method: str, args: List[Any]) -> Any: - pass - - @abstractmethod - async def validate_event_broadcast(self, connection_id: str, topic: str, message: Any) -> bool: + async def handle_method_call( + self, + connection_id: str, + context: RpcCallContext, + method: str, + args: List[Any], + ) -> Any: pass def with_client(self, connection_id: str, action: Callable) -> bool: - client = self.active_connections.get(connection_id) + client = self.rpc_clients.get(connection_id) if client: action(client) return True diff --git a/cbor_rpc/rpc/rpc_v1.py b/cbor_rpc/rpc/rpc_v1.py index fc78d1f..1c77e7c 100644 --- a/cbor_rpc/rpc/rpc_v1.py +++ b/cbor_rpc/rpc/rpc_v1.py @@ -1,21 +1,25 @@ import sys from typing import Any, Dict, List, Optional, Callable -from abc import ABC, abstractmethod +from abc import abstractmethod import asyncio -import inspect from .rpc_base import RpcClient +from .context import RpcCallContext +from .logging import RpcLogger from cbor_rpc.pipe.event_pipe import EventPipe from cbor_rpc.timed_promise import TimedPromise -class RpcV1(RpcClient): +class RpcCore(RpcClient): + protocol_id = 1 def __init__(self, pipe: EventPipe[Any, Any]): self.pipe = pipe self._counter = 0 self._promises: Dict[int, TimedPromise] = {} self._timeout = 30000 - self._waiters: Dict[str, TimedPromise] = {} + self._peer_log_level = 0 + self._desired_log_level = 0 + self.logger = RpcLogger(self._send_log, 1, self._get_default_ref_id) async def resolve_result(result: Any) -> Any: """Recursively resolve coroutines or nested coroutines.""" @@ -23,72 +27,133 @@ async def resolve_result(result: Any) -> Any: result = await result return result + self._resolve_result = resolve_result + async def on_data(data: List[Any]) -> None: try: - if not isinstance(data, list) or len(data) != 5: - print(f"RpcV1: Invalid message format: {data}", file=sys.stderr) + if not isinstance(data, list) or len(data) < 2: + print(f"RpcCore: Invalid message format: {data}", file=sys.stderr) return - version, direction, id_, method, params = data - if version != 1: - print(f"RpcV1: Unsupported version: {data}", file=sys.stderr) - return - print("RpvV1: Received", data, file=sys.stderr) - - if direction < 2: # Method call (0) or fire (1) - try: - # Call the method and get the result - result = self.handle_method_call(method, params) - - # Handle the response asynchronously - async def handle_response(): - try: - resolved_result = await resolve_result(result) - if direction == 0: # Only respond to method calls, not fire calls - await self.pipe.write([1, 2, id_, True, resolved_result]) - except Exception as e: - if direction == 0: - await self.pipe.write([1, 2, id_, False, str(e)]) - else: - print( - f"Fired method error: {method}, params={params}, error={e}", - file=sys.stderr, - ) - - # Create task to handle response - asyncio.create_task(handle_response()) - - except Exception as e: - if direction == 0: - asyncio.create_task(self.pipe.write([1, 2, id_, False, str(e)])) - else: + protocol_id = data[0] + if protocol_id == 1: + await self.handle_proto_1(data) + elif protocol_id == 2: + await self.handle_proto_2(data) + elif protocol_id == 3: + await self.handle_proto_3(data) + else: + print(f"RpcCore: Unsupported protocol: {data}", file=sys.stderr) + + except Exception as e: + print(f"Error processing RPC message: {e}", file=sys.stderr) + + self.pipe.on("data", on_data) + + async def handle_proto_1(self, data: List[Any]) -> None: + """Handle Protocol 1 (RPC) messages.""" + if len(data) < 3: + print(f"RpcCore [Proto 1]: Invalid format: {data}", file=sys.stderr) + return + + sub_proto_id = data[1] + + # Responses: [1, 0, id, result] (Success) or [1, 1, id, error] (Error) + if sub_proto_id <2: + if len(data) < 4: + print(f"RpcCore [Proto 1]: Invalid response format: {data}", file=sys.stderr) + return + + id_ = data[2] + payload = data[3] + + promise = self._promises.pop(id_, None) + if promise: + if sub_proto_id == 0: # Success + await promise.resolve(payload) + else: # Error + await promise.reject(payload) + else: + print( + f"Received rpc reply for expired request id: {id_}, success={sub_proto_id==0}, data={payload}", + file=sys.stderr, + ) + + # Method Call (2) or Fire (3): [1, 2/3, id, method, params] + elif sub_proto_id == 2 or sub_proto_id == 3: + if len(data) < 5: + print(f"RpcCore [Proto 1]: Invalid call format: {data}", file=sys.stderr) + return + + id_ = data[2] + method = data[3] + params = data[4] + + try: + # Call the method + context = RpcCallContext(self.logger) + result = self.handle_method_call(context, method, params) + + if sub_proto_id == 2: # Expect Response + async def handle_response() -> None: + try: + resolved_result = await self._resolve_result(result) + # Send Success: [1, 0, id, result] + await self.pipe.write([1, 0, id_, resolved_result]) + except Exception as e: + # Send Error: [1, 1, id, error] + await self.pipe.write([1, 1, id_, str(e)]) + + asyncio.create_task(handle_response()) + else: + # Fire (3) + async def handle_fire() -> None: + try: + await self._resolve_result(result) + except Exception as e: print( f"Fired method error: {method}, params={params}, error={e}", file=sys.stderr, ) + asyncio.create_task(handle_fire()) - elif direction == 2: # Response - promise = self._promises.pop(id_, None) - if promise: - if method is True: # Success - await promise.resolve(params) - else: # Error - await promise.reject(params) - else: - print( - f"Received rpc reply for expired request id: {id_}, success={method}, data={params}", - file=sys.stderr, - ) - - elif direction == 3: # Event - await self._on_event(method, params) + except Exception as e: + if sub_proto_id == 2: + asyncio.create_task(self.pipe.write([1, 1, id_, str(e)])) else: - print(f"RpcV1: Invalid direction: {direction}", file=sys.stderr) + print( + f"Fired method error: {method}, params={params}, error={e}", + file=sys.stderr, + ) + else: + print(f"RpcCore [Proto 1]: Unknown sub-protocol: {sub_proto_id}", file=sys.stderr) - except Exception as e: - print(f"Error processing RPC message: {e}", file=sys.stderr) - self.pipe.on("data", on_data) + async def handle_proto_2(self, data: List[Any]) -> None: + """Handle Protocol 2 (Logging) messages.""" + # Format: [2, log_level, ref_proto, ref_id, content] + if len(data) >= 3 and data[1] == 0: + self._peer_log_level = data[2] + return + + if len(data) < 5: + print(f"RpcCore [Proto 2]: Invalid format: {data}", file=sys.stderr) + return + + log_level = data[1] + + ref_proto = data[2] + ref_id = data[3] + content = data[4] + + level_map = {1: "CRITICAL", 2: "WARN", 3: "INFO", 4: "VERBOSE", 5: "DEBUG"} + level_str = level_map.get(log_level, f"LEVEL-{log_level}") + + print(f"[RemoteLog:{level_str}] p{ref_proto}:{ref_id} {content}", file=sys.stderr) + + async def handle_proto_3(self, data: List[Any]) -> None: + print(f"RpcCore [Proto 3]: Unsupported event message: {data}", file=sys.stderr) + async def call_method(self, method: str, *args: Any) -> Any: counter = self._counter @@ -99,30 +164,98 @@ def timeout_callback(): promise = TimedPromise(self._timeout, timeout_callback) self._promises[counter] = promise - await self.pipe.write([1, 0, counter, method, list(args)]) + # New: [1, 2, counter, method, args] + await self.pipe.write([1, 2, counter, method, list(args)]) return await promise.promise async def fire_method(self, method: str, *args: Any) -> None: counter = self._counter self._counter += 1 - await self.pipe.write([1, 1, counter, method, list(args)]) - - async def emit(self, topic: str, args: Any) -> None: - await self.pipe.write([1, 3, 0, topic, args]) + # New: [1, 3, counter, method, args] + await self.pipe.write([1, 3, counter, method, list(args)]) def set_timeout(self, milliseconds: int) -> None: self._timeout = milliseconds + async def set_log_level(self, level: int) -> None: + """Send a request to set the remote log level.""" + # Protocol 2, Sub-protocol 0: [2, 0, level] + self._desired_log_level = level + await self.pipe.write([2, 0, level]) + + def _get_default_ref_id(self) -> int: + return 0 + + def _send_log(self, level: int, ref_proto: int, ref_id: Any, content: Any) -> None: + if self._peer_log_level <= 0: + return + if level > self._peer_log_level: + return + asyncio.create_task(self.pipe.write([2, level, ref_proto, ref_id, content])) + + @abstractmethod + def get_id(self) -> str: + pass + + @abstractmethod + def handle_method_call(self, context: RpcCallContext, method: str, args: List[Any]) -> Any: + pass + + @classmethod + def make_rpc_v1( + cls, + pipe: EventPipe[Any, Any], + id_: str, + method_handler: Callable, + ) -> "RpcCore": + class ConcreteRpcV1(cls): + def get_id(self) -> str: + return id_ + + def handle_method_call(self, context: RpcCallContext, method: str, args: List[Any]) -> Any: + return method_handler(context, method, args) + + async def on_event(self, context: RpcCallContext, topic: str, message: Any) -> None: + return None + + return ConcreteRpcV1(pipe) + + @classmethod + def read_only_client(cls, pipe: EventPipe[Any, Any]) -> "RpcCore": + def method_handler(context: RpcCallContext, method: str, args: List[Any]) -> Any: + raise Exception("Client Only Implementation") + + return cls.make_rpc_v1(pipe, "", method_handler) + + +class RpcV1(RpcCore): + def __init__(self, pipe: EventPipe[Any, Any]): + super().__init__(pipe) + self._waiters: Dict[str, TimedPromise] = {} + self._last_event_topic: Optional[str] = None + self.event_logger = RpcLogger(self._send_log, 3, self._get_last_event_topic) + + async def handle_proto_3(self, data: List[Any]) -> None: + if len(data) < 4: + print(f"RpcV1 [Proto 3]: Invalid event format: {data}", file=sys.stderr) + return + sub_proto_id = data[1] + if sub_proto_id != 0: + print(f"RpcV1 [Proto 3]: Unknown sub-protocol: {sub_proto_id}", file=sys.stderr) + return + await self._on_event(data[2], data[3]) + + async def emit(self, topic: str, args: Any) -> None: + await self.pipe.write([3, 0, topic, args]) + async def _on_event(self, method: str, message: Any) -> None: + self._last_event_topic = method waiter = self._waiters.pop(method, None) if waiter: await waiter.resolve(message) else: - await self.on_event(method, message) - - @abstractmethod - def get_id(self) -> str: - pass + context = RpcCallContext(self.event_logger) + await self.on_event(context, method, message) async def wait_next_event(self, topic: str, timeout_ms: Optional[int] = None) -> Any: if topic in self._waiters: @@ -140,41 +273,8 @@ def timeout_callback(): return await waiter.promise @abstractmethod - def handle_method_call(self, method: str, args: List[Any]) -> Any: + async def on_event(self, context: RpcCallContext, topic: str, message: Any) -> None: pass - @abstractmethod - async def on_event(self, topic: str, message: Any) -> None: - pass - - @staticmethod - def make_rpc_v1( - pipe: EventPipe[Any, Any], - id_: str, - method_handler: Callable, - event_handler: Callable, - ) -> "RpcV1": - class ConcreteRpcV1(RpcV1): - def get_id(self) -> str: - return id_ - - def handle_method_call(self, method: str, args: List[Any]) -> Any: - return method_handler(method, args) - - async def on_event(self, topic: str, message: Any) -> None: - if inspect.iscoroutinefunction(event_handler): - await event_handler(topic, message) - else: - event_handler(topic, message) - - return ConcreteRpcV1(pipe) - - @staticmethod - def read_only_client(pipe: EventPipe[Any, Any]) -> "RpcV1": - def method_handler(method: str, args: List[Any]) -> Any: - raise Exception("Client Only Implementation") - - async def event_handler(topic: str, message: Any) -> None: - print(f"Rpc Event dropped {topic} {message}", file=sys.stderr) - - return RpcV1.make_rpc_v1(pipe, "", method_handler, event_handler) + def _get_last_event_topic(self) -> Any: + return self._last_event_topic diff --git a/cbor_rpc/transformer/json_transformer.py b/cbor_rpc/transformer/json_transformer.py index 0626c6e..ffb745d 100644 --- a/cbor_rpc/transformer/json_transformer.py +++ b/cbor_rpc/transformer/json_transformer.py @@ -13,14 +13,11 @@ def __init__(self, encoding: str = "utf-8"): self.encoding = encoding def encode(self, data: Any) -> bytes: - try: - json_str = json.dumps( - data, ensure_ascii=False - ) # Always allow non-ASCII characters to pass through json.dumps - return json_str.encode(self.encoding) - except Exception as e: - # Removed print statement as it was for debugging - raise # Re-raise to be caught by EventTransformerPipe + json_str = json.dumps( + data, ensure_ascii=False + ) # Always allow non-ASCII characters to pass through json.dumps + return json_str.encode(self.encoding) + def decode(self, data: Union[bytes, str, None]) -> Any: if data is None: diff --git a/examples/fs_rpc/filesystem_server.py b/examples/fs_rpc/filesystem_server.py index 97982aa..06a76fe 100644 --- a/examples/fs_rpc/filesystem_server.py +++ b/examples/fs_rpc/filesystem_server.py @@ -1,5 +1,6 @@ import os from typing import List, Optional, Any +from cbor_rpc.rpc.context import RpcCallContext from cbor_rpc import RpcV1Server @@ -7,7 +8,13 @@ class FilesystemRpcServer(RpcV1Server): async def validate_event_broadcast(self, connection_id, topic, message): return False - async def handle_method_call(self, connection_id: str, method: str, args: List[Any]) -> Any: + async def handle_method_call( + self, + connection_id: str, + context: RpcCallContext, + method: str, + args: List[Any], + ) -> Any: if method == "list_files": return self.list_files(*args) elif method == "read_file": diff --git a/pyproject.toml b/pyproject.toml index 3044564..71ab03d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ requires-python = ">=3.8" dependencies = [ "pytest>=8.3.2", "pytest-asyncio>=0.24.0", + "pytest-cov>=5.0.0", "asyncssh>=2.14.0", "bcrypt", "cbor2" @@ -30,3 +31,11 @@ pythonpath = ["cbor_rpc"] [tool.black] line-length = 120 + +[tool.coverage.run] +branch = true +source = ["cbor_rpc"] + +[tool.coverage.report] +show_missing = true +skip_covered = true diff --git a/requirements.txt b/requirements.txt index b3e5e48..8908293 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,6 +11,8 @@ bcrypt==5.0.0 # via cbor-rpc (pyproject.toml) cbor2==5.8.0 # via cbor-rpc (pyproject.toml) +coverage==7.6.4 + # via pytest-cov cffi==2.0.0 # via cryptography cryptography==46.0.4 @@ -31,6 +33,8 @@ pytest==8.4.0 # via # cbor-rpc (pyproject.toml) # pytest-asyncio +pytest-cov==5.0.0 + # via cbor-rpc (pyproject.toml) pytest-asyncio==1.0.0 # via cbor-rpc (pyproject.toml) tomli==2.4.0 diff --git a/setup.py b/setup.py index ed868ab..7ada36a 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ packages=find_packages(exclude=["cbor_rpc"]), install_requires=["asyncssh>=2.14.0", "bcrypt", "cbor2"], extras_require={ - "test": ["pytest>=8.3.2", "pytest-asyncio>=0.24.0"], + "test": ["pytest>=8.3.2", "pytest-asyncio>=0.24.0", "pytest-cov>=5.0.0"], }, python_requires=">=3.8", classifiers=[ diff --git a/tests/test_rpc_v1.py b/tests/test_rpc_v1.py index 1323e75..8f3495f 100644 --- a/tests/test_rpc_v1.py +++ b/tests/test_rpc_v1.py @@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, MagicMock from cbor_rpc import EventPipe, RpcV1, TimedPromise +from cbor_rpc.rpc.context import RpcCallContext from tests.helpers import SimplePipe @@ -25,7 +26,7 @@ def throw_error_method(message: str) -> None: # Method handler for RPC -def method_handler(method: str, args: List[Any]) -> Any: +def method_handler(context: RpcCallContext, method: str, args: List[Any]) -> Any: if method == "sleep": return sleep_method(*args) elif method == "add": @@ -37,9 +38,16 @@ def method_handler(method: str, args: List[Any]) -> Any: raise Exception(f"Unknown method: {method}") -# Event handler for RPC -async def event_handler(topic: str, message: Any) -> None: - pass # No-op for testing + +class EventRpcHelper(RpcV1): + def get_id(self) -> str: + return "event_rpc" + + def handle_method_call(self, context: RpcCallContext, method: str, args: List[Any]) -> Any: + raise Exception("Event-only RPC") + + async def on_event(self, context: RpcCallContext, topic: str, payload: Any) -> None: + pass @pytest.fixture @@ -47,9 +55,14 @@ def pipe(): return SimplePipe() +@pytest.fixture +def event_rpc(pipe): + return EventRpcHelper(pipe) + + @pytest.fixture def rpc(pipe): - return RpcV1.make_rpc_v1(pipe, "test_id", method_handler, event_handler) + return RpcV1.make_rpc_v1(pipe, "test_id", method_handler) @pytest.mark.asyncio @@ -105,49 +118,48 @@ async def test_fire_method(rpc): @pytest.mark.asyncio -async def test_emit_event(rpc): - await rpc.emit("test_topic", {"data": "test"}) - # No assertion needed; just verify no crash +async def test_emit_event(event_rpc): + await event_rpc.emit("test_topic", {"data": "test"}) @pytest.mark.asyncio -async def test_wait_next_event_success(rpc, pipe): +async def test_wait_next_event_success(event_rpc, pipe): async def simulate_event(): await asyncio.sleep(0.1) - await pipe.write([1, 3, 0, "test_topic", {"data": "test"}]) + await pipe.write([3, 0, "test_topic", {"data": "test"}]) task = asyncio.create_task(simulate_event()) - result = await rpc.wait_next_event("test_topic", 1000) + result = await event_rpc.wait_next_event("test_topic", 1000) assert result == {"data": "test"} await task @pytest.mark.asyncio -async def test_wait_next_event_timeout(rpc): +async def test_wait_next_event_timeout(event_rpc): with pytest.raises(Exception) as exc_info: - await rpc.wait_next_event("test_topic", 100) + await event_rpc.wait_next_event("test_topic", 100) assert exc_info.value.args[0]["timeout"] is True assert exc_info.value.args[0]["timeoutPeriod"] == 100 @pytest.mark.asyncio -async def test_wait_next_event_already_waiting(rpc): - rpc._waiters["test_topic"] = TimedPromise(1000) +async def test_wait_next_event_already_waiting(event_rpc): + event_rpc._waiters["test_topic"] = TimedPromise(1000) with pytest.raises(Exception) as exc_info: - await rpc.wait_next_event("test_topic") + await event_rpc.wait_next_event("test_topic") assert str(exc_info.value) == "Already waiting for event" @pytest.mark.asyncio async def test_invalid_message_format(rpc, pipe): - await pipe.write([1, 2, 3]) # Invalid message + await pipe.write([1]) # Invalid message await asyncio.sleep(0.1) # Allow processing # No crash means test passes @pytest.mark.asyncio -async def test_unsupported_version(rpc, pipe): - await pipe.write([2, 0, 0, "add", [1, 2]]) # Wrong version +async def test_unsupported_protocol(rpc, pipe): + await pipe.write([99, 0, 0, "add", [1, 2]]) # Unsupported protocol await asyncio.sleep(0.1) # Allow processing # No crash means test passes @@ -165,6 +177,7 @@ async def test_read_only_client(pipe): # We need to directly call the handle_method_call method to test it with pytest.raises(Exception) as exc_info: - read_only.handle_method_call("add", [1, 2]) + context = RpcCallContext(read_only.logger) + read_only.handle_method_call(context, "add", [1, 2]) assert str(exc_info.value) == "Client Only Implementation" From 8fc81f01dc61436c04d01fc6f7461114ceaff299 Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Sun, 8 Feb 2026 15:55:36 +0545 Subject: [PATCH 18/25] Refactor tests --- .coveragerc | 10 + .gitignore | 1 + cbor_rpc/rpc/rpc_server.py | 2 +- cbor_rpc/rpc/rpc_v1.py | 42 +- cbor_rpc/stdio/stdio_pipe.py | 9 +- cbor_rpc/tcp/tcp.py | 8 - cbor_rpc/transformer/base/transformer_base.py | 12 +- cbor_rpc/transformer/base/transformer_pipe.py | 21 +- cbor_rpc/transformer/cbor_transformer.py | 2 + examples/fs_rpc/filesystem_server.py | 10 +- setup.py | 2 +- tests/event/test_emitter.py | 284 +++++++++++++ tests/helpers/stream_pair.py | 12 + tests/misc/test_timed_promise.py | 26 ++ tests/pipe/test_aio_pipe.py | 150 +++++++ tests/{ => pipe}/test_event_pipe.py | 56 +-- tests/{ => pipe}/test_pipe.py | 50 +-- tests/pipe/test_pipe_extra.py | 94 +++++ tests/rpc/test_rpc_base.py | 56 +++ tests/rpc/test_rpc_logging.py | 150 +++++++ tests/{ => rpc}/test_rpc_v1.py | 31 +- tests/rpc/test_rpc_v1_extra.py | 224 ++++++++++ tests/rpc/test_server_base.py | 69 ++++ tests/{ => ssh}/test_ssh_docker_pipe.py | 106 ++--- tests/ssh/test_ssh_pipe.py | 43 ++ tests/stdio/test_stdio_pipe.py | 60 +++ tests/{ => tcp}/test_tcp.py | 92 ++--- tests/tcp/test_tcp_pipe_errors.py | 33 ++ tests/test_event_emitter.py | 386 ------------------ tests/test_stdio_rpc.py | 42 -- .../test_cbor_transformer.py | 77 ++-- .../test_event_transformer_pipe.py | 83 ++++ .../test_json_transformer.py | 94 ++--- tests/transformer/test_transformer_base.py | 62 +++ tests/transformer/test_transformer_pipe.py | 216 ++++++++++ 35 files changed, 1791 insertions(+), 824 deletions(-) create mode 100644 .coveragerc create mode 100644 tests/event/test_emitter.py create mode 100644 tests/helpers/stream_pair.py create mode 100644 tests/misc/test_timed_promise.py create mode 100644 tests/pipe/test_aio_pipe.py rename tests/{ => pipe}/test_event_pipe.py (71%) rename tests/{ => pipe}/test_pipe.py (70%) create mode 100644 tests/pipe/test_pipe_extra.py create mode 100644 tests/rpc/test_rpc_base.py create mode 100644 tests/rpc/test_rpc_logging.py rename tests/{ => rpc}/test_rpc_v1.py (85%) create mode 100644 tests/rpc/test_rpc_v1_extra.py create mode 100644 tests/rpc/test_server_base.py rename tests/{ => ssh}/test_ssh_docker_pipe.py (81%) create mode 100644 tests/ssh/test_ssh_pipe.py create mode 100644 tests/stdio/test_stdio_pipe.py rename tests/{ => tcp}/test_tcp.py (77%) create mode 100644 tests/tcp/test_tcp_pipe_errors.py delete mode 100644 tests/test_event_emitter.py delete mode 100644 tests/test_stdio_rpc.py rename tests/{ => transformer}/test_cbor_transformer.py (85%) create mode 100644 tests/transformer/test_event_transformer_pipe.py rename tests/{ => transformer}/test_json_transformer.py (63%) create mode 100644 tests/transformer/test_transformer_base.py create mode 100644 tests/transformer/test_transformer_pipe.py diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..bf9bd99 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,10 @@ +[run] +branch = True +source = + cbor_rpc + +[report] +show_missing = True +skip_covered = True +omit = + tests/* diff --git a/.gitignore b/.gitignore index 72416f8..79cb3b1 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ venv/ __pycache__/ .pytest_cache/ cbor_rpc.egg-info/ +.coverage diff --git a/cbor_rpc/rpc/rpc_server.py b/cbor_rpc/rpc/rpc_server.py index e7eca90..6d26c78 100644 --- a/cbor_rpc/rpc/rpc_server.py +++ b/cbor_rpc/rpc/rpc_server.py @@ -9,7 +9,7 @@ class RpcV1Server(RpcServer): - def __init__(self, server: Server): + def __init__(self): self.active_connections: Dict[str, EventPipe[Any, Any]] = {} self.rpc_clients: Dict[str, RpcV1] = {} self.timeout = 30000 diff --git a/cbor_rpc/rpc/rpc_v1.py b/cbor_rpc/rpc/rpc_v1.py index 1c77e7c..596b85c 100644 --- a/cbor_rpc/rpc/rpc_v1.py +++ b/cbor_rpc/rpc/rpc_v1.py @@ -1,4 +1,5 @@ import sys +import logging from typing import Any, Dict, List, Optional, Callable from abc import abstractmethod import asyncio @@ -9,6 +10,8 @@ from cbor_rpc.pipe.event_pipe import EventPipe from cbor_rpc.timed_promise import TimedPromise +logger = logging.getLogger(__name__) + class RpcCore(RpcClient): protocol_id = 1 @@ -32,7 +35,7 @@ async def resolve_result(result: Any) -> Any: async def on_data(data: List[Any]) -> None: try: if not isinstance(data, list) or len(data) < 2: - print(f"RpcCore: Invalid message format: {data}", file=sys.stderr) + logger.warning(f"RpcCore: Invalid message format: {data}") return protocol_id = data[0] @@ -43,17 +46,17 @@ async def on_data(data: List[Any]) -> None: elif protocol_id == 3: await self.handle_proto_3(data) else: - print(f"RpcCore: Unsupported protocol: {data}", file=sys.stderr) + logger.warning(f"RpcCore: Unsupported protocol: {data}") except Exception as e: - print(f"Error processing RPC message: {e}", file=sys.stderr) + logger.error(f"Error processing RPC message: {e}") self.pipe.on("data", on_data) async def handle_proto_1(self, data: List[Any]) -> None: """Handle Protocol 1 (RPC) messages.""" if len(data) < 3: - print(f"RpcCore [Proto 1]: Invalid format: {data}", file=sys.stderr) + logger.warning(f"RpcCore [Proto 1]: Invalid format: {data}") return sub_proto_id = data[1] @@ -61,7 +64,7 @@ async def handle_proto_1(self, data: List[Any]) -> None: # Responses: [1, 0, id, result] (Success) or [1, 1, id, error] (Error) if sub_proto_id <2: if len(data) < 4: - print(f"RpcCore [Proto 1]: Invalid response format: {data}", file=sys.stderr) + logger.warning(f"RpcCore [Proto 1]: Invalid response format: {data}") return id_ = data[2] @@ -74,15 +77,12 @@ async def handle_proto_1(self, data: List[Any]) -> None: else: # Error await promise.reject(payload) else: - print( - f"Received rpc reply for expired request id: {id_}, success={sub_proto_id==0}, data={payload}", - file=sys.stderr, - ) + logger.warning(f"Received rpc reply for expired request id: {id_}, success={sub_proto_id==0}, data={payload}") # Method Call (2) or Fire (3): [1, 2/3, id, method, params] elif sub_proto_id == 2 or sub_proto_id == 3: if len(data) < 5: - print(f"RpcCore [Proto 1]: Invalid call format: {data}", file=sys.stderr) + logger.warning(f"RpcCore [Proto 1]: Invalid call format: {data}") return id_ = data[2] @@ -111,22 +111,16 @@ async def handle_fire() -> None: try: await self._resolve_result(result) except Exception as e: - print( - f"Fired method error: {method}, params={params}, error={e}", - file=sys.stderr, - ) + logger.error(f"Fired method error: {method}, params={params}, error={e}") asyncio.create_task(handle_fire()) except Exception as e: if sub_proto_id == 2: asyncio.create_task(self.pipe.write([1, 1, id_, str(e)])) else: - print( - f"Fired method error: {method}, params={params}, error={e}", - file=sys.stderr, - ) + logger.error(f"Fired method error: {method}, params={params}, error={e}") else: - print(f"RpcCore [Proto 1]: Unknown sub-protocol: {sub_proto_id}", file=sys.stderr) + logger.warning(f"RpcCore [Proto 1]: Unknown sub-protocol: {sub_proto_id}") async def handle_proto_2(self, data: List[Any]) -> None: @@ -137,7 +131,7 @@ async def handle_proto_2(self, data: List[Any]) -> None: return if len(data) < 5: - print(f"RpcCore [Proto 2]: Invalid format: {data}", file=sys.stderr) + logger.warning(f"RpcCore [Proto 2]: Invalid format: {data}") return log_level = data[1] @@ -149,10 +143,10 @@ async def handle_proto_2(self, data: List[Any]) -> None: level_map = {1: "CRITICAL", 2: "WARN", 3: "INFO", 4: "VERBOSE", 5: "DEBUG"} level_str = level_map.get(log_level, f"LEVEL-{log_level}") - print(f"[RemoteLog:{level_str}] p{ref_proto}:{ref_id} {content}", file=sys.stderr) + logger.info(f"[RemoteLog:{level_str}] p{ref_proto}:{ref_id} {content}") async def handle_proto_3(self, data: List[Any]) -> None: - print(f"RpcCore [Proto 3]: Unsupported event message: {data}", file=sys.stderr) + logger.warning(f"RpcCore [Proto 3]: Unsupported event message: {data}") async def call_method(self, method: str, *args: Any) -> Any: @@ -237,11 +231,11 @@ def __init__(self, pipe: EventPipe[Any, Any]): async def handle_proto_3(self, data: List[Any]) -> None: if len(data) < 4: - print(f"RpcV1 [Proto 3]: Invalid event format: {data}", file=sys.stderr) + logger.warning(f"RpcV1 [Proto 3]: Invalid event format: {data}") return sub_proto_id = data[1] if sub_proto_id != 0: - print(f"RpcV1 [Proto 3]: Unknown sub-protocol: {sub_proto_id}", file=sys.stderr) + logger.warning(f"RpcV1 [Proto 3]: Unknown sub-protocol: {sub_proto_id}") return await self._on_event(data[2], data[3]) diff --git a/cbor_rpc/stdio/stdio_pipe.py b/cbor_rpc/stdio/stdio_pipe.py index c1d1d56..97ab01d 100644 --- a/cbor_rpc/stdio/stdio_pipe.py +++ b/cbor_rpc/stdio/stdio_pipe.py @@ -67,11 +67,10 @@ async def wait_for_process_termination(self) -> int: raise RuntimeError("No subprocess associated with this StdioPipe instance.") return await self._process.wait() - def terminate(self): + async def terminate(self, *args: Any): """ Terminates the started subprocess if one exists. - Raises RuntimeError if no process was started by this pipe. """ - if not self._process: - raise RuntimeError("No subprocess associated with this StdioPipe instance.") - self._process.terminate() + if self._process and self._process.returncode is None: + self._process.terminate() + await super().terminate(*args) diff --git a/cbor_rpc/tcp/tcp.py b/cbor_rpc/tcp/tcp.py index 9a61803..b8e87b2 100644 --- a/cbor_rpc/tcp/tcp.py +++ b/cbor_rpc/tcp/tcp.py @@ -261,11 +261,3 @@ def get_address(self) -> Tuple[str, int]: if self._server and self._server.sockets: return self._server.sockets[0].getsockname()[:2] return ("", 0) - - @abstractmethod - async def accept(self, pipe: TcpPipe) -> bool: - pass - - async def close(self) -> None: - """Legacy method - use stop() instead.""" - await self.stop() diff --git a/cbor_rpc/transformer/base/transformer_base.py b/cbor_rpc/transformer/base/transformer_base.py index 4fd67cb..7799c00 100644 --- a/cbor_rpc/transformer/base/transformer_base.py +++ b/cbor_rpc/transformer/base/transformer_base.py @@ -37,10 +37,10 @@ def bind(self, pipe: Pipe) -> TransformerPipe: ... def bind(self, pipe: EventPipe) -> EventTransformerPipe: ... @overload - def applyTransformer(self, pipe: Pipe) -> TransformerPipe: ... + def apply_transformer(self, pipe: Pipe) -> TransformerPipe: ... @overload - def applyTransformer(self, pipe: EventPipe) -> EventTransformerPipe: ... - def applyTransformer(self, pipe: Union[Pipe, EventPipe]) -> Union[TransformerPipe, EventTransformerPipe]: + def apply_transformer(self, pipe: EventPipe) -> EventTransformerPipe: ... + def apply_transformer(self, pipe: Union[Pipe, EventPipe]) -> Union[TransformerPipe, EventTransformerPipe]: if isinstance(pipe, EventPipe): return EventTransformerPipe(pipe, self.to_async()) elif isinstance(pipe, Pipe): @@ -90,10 +90,10 @@ def bind(self, pipe: Pipe) -> TransformerPipe: ... def bind(self, pipe: EventPipe) -> EventTransformerPipe: ... @overload - def applyTransformer(self, pipe: Pipe) -> TransformerPipe: ... + def apply_transformer(self, pipe: Pipe) -> TransformerPipe: ... @overload - def applyTransformer(self, pipe: EventPipe) -> EventTransformerPipe: ... - def applyTransformer(self, pipe: Union[Pipe, EventPipe]) -> Union[TransformerPipe, EventTransformerPipe]: + def apply_transformer(self, pipe: EventPipe) -> EventTransformerPipe: ... + def apply_transformer(self, pipe: Union[Pipe, EventPipe]) -> Union[TransformerPipe, EventTransformerPipe]: if isinstance(pipe, EventPipe): return EventTransformerPipe(pipe, self) elif isinstance(pipe, Pipe): diff --git a/cbor_rpc/transformer/base/transformer_pipe.py b/cbor_rpc/transformer/base/transformer_pipe.py index 21ac1e1..714c3a7 100644 --- a/cbor_rpc/transformer/base/transformer_pipe.py +++ b/cbor_rpc/transformer/base/transformer_pipe.py @@ -1,5 +1,6 @@ from typing import Any, Awaitable, Optional, TypeVar, Callable from typing import TYPE_CHECKING +import asyncio import time from .base_exception import NeedsMoreDataException @@ -19,14 +20,19 @@ class TransformerPipe(Pipe[T1, T2]): def __init__(self, pipe: Pipe[Any, Any], transformer: "Optional[Transformer[T1, T2]]"): super().__init__() self.pipe = pipe + self._closed = False + self._terminated = False + + if transformer is None: + raise ValueError("A transformer must be provided") self.encode = transformer.encode self.decode = transformer.decode - def _handle_error(*args): + async def _handle_error(*args): if not self._closed: self._closed = True - self.pipe.terminate() + await self.terminate() self._emit("error", *args) def _handle_close(*args): @@ -42,7 +48,7 @@ async def write(self, chunk: T1) -> bool: return False try: encoded = await self.encode(chunk) - return self.pipe.write(encoded) + return await self.pipe.write(encoded) except Exception as e: self._emit("error", e) return False @@ -56,7 +62,7 @@ async def read(self, timeout: Optional[float] = None) -> Optional[T1]: try: while True: - raw = self.pipe.read(remaining) + raw = await self.pipe.read(remaining) if raw is None: return None @@ -72,11 +78,12 @@ async def read(self, timeout: Optional[float] = None) -> Optional[T1]: self._emit("error", e) return None - def terminate(self) -> None: - if self._closed: + async def terminate(self) -> None: + if self._terminated: return + self._terminated = True self._closed = True - self.pipe.terminate() + await self.pipe.terminate() def _propagate_error(self, *args): self.pipe._emit("error", *args) diff --git a/cbor_rpc/transformer/cbor_transformer.py b/cbor_rpc/transformer/cbor_transformer.py index dee4db0..2db27bf 100644 --- a/cbor_rpc/transformer/cbor_transformer.py +++ b/cbor_rpc/transformer/cbor_transformer.py @@ -66,6 +66,8 @@ async def decode(self, data: Union[bytes, None]) -> Any: stream = BytesIO(self._buffer) decoder = cbor2.CBORDecoder(stream) decoded_data = decoder.decode() + if decoded_data is cbor2.break_marker: + raise cbor2.CBORDecodeError("Unexpected break marker") bytes_consumed = stream.tell() self._buffer = self._buffer[bytes_consumed:] diff --git a/examples/fs_rpc/filesystem_server.py b/examples/fs_rpc/filesystem_server.py index 06a76fe..0b458f6 100644 --- a/examples/fs_rpc/filesystem_server.py +++ b/examples/fs_rpc/filesystem_server.py @@ -79,14 +79,20 @@ def rename_file(self, src: str, dest: str) -> bool: from cbor_rpc.tcp import TcpPipe, TcpServer from cbor_rpc.transformer.json_transformer import JsonTransformer + class SimpleTcpServer(TcpServer): + async def accept(self, pipe: TcpPipe) -> bool: + return True + async def main(): rpc_id = 1 # Create a TCP server that handles connections, using JsonTransformer for RPC messages - tcp_server = await TcpServer.create("localhost", 8000) + tcp_server = await SimpleTcpServer.create("localhost", 8000) print("Server running on port 8000") # Set up event handlers for new connections - async def handle_connection(rpc_pipe): + async def handle_connection(pipe): + json_transformer = JsonTransformer() + rpc_pipe = json_transformer.apply_transformer(pipe) server = FilesystemRpcServer() await server.add_connection(str(rpc_id), rpc_pipe) diff --git a/setup.py b/setup.py index 7ada36a..e290e9e 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ long_description=open("README.md").read(), long_description_content_type="text/markdown", url="https://github.com/mesudip/cbor-rpc-py ", - packages=find_packages(exclude=["cbor_rpc"]), + packages=find_packages(include=["cbor_rpc", "cbor_rpc.*"]), install_requires=["asyncssh>=2.14.0", "bcrypt", "cbor2"], extras_require={ "test": ["pytest>=8.3.2", "pytest-asyncio>=0.24.0", "pytest-cov>=5.0.0"], diff --git a/tests/event/test_emitter.py b/tests/event/test_emitter.py new file mode 100644 index 0000000..4eb3331 --- /dev/null +++ b/tests/event/test_emitter.py @@ -0,0 +1,284 @@ +import asyncio +from typing import Any + +import pytest + +from cbor_rpc.event.emitter import AbstractEmitter + + +@pytest.mark.asyncio +async def test_on_and_emit(): + class TestEmitter(AbstractEmitter): + pass + + emitter = TestEmitter() + events = [] + + async def async_handler1(data: Any): + await asyncio.sleep(0.01) + events.append(f"async_handler1_{data}") + + def sync_handler1(data: Any): + events.append(f"sync_handler1_{data}") + + async def async_handler2(data: Any): + await asyncio.sleep(0.01) + events.append(f"async_handler2_{data}") + + emitter.on("test", async_handler1) + emitter.on("test", sync_handler1) + emitter.on("test", async_handler2) + + emitter._emit("test", "event1") + await asyncio.sleep(0.02) + + expected = [ + "async_handler1_event1", + "sync_handler1_event1", + "async_handler2_event1", + ] + assert sorted(events) == sorted(expected), f"Expected {expected}, got {events}" + + +@pytest.mark.asyncio +async def test_pipeline_and_notify(): + class TestEmitter(AbstractEmitter): + pass + + emitter = TestEmitter() + events = [] + + async def async_handler1(data: Any): + await asyncio.sleep(0.01) + events.append(f"async_handler1_{data}") + + def sync_handler1(data: Any): + events.append(f"sync_handler1_{data}") + + async def async_pipeline1(data: Any): + await asyncio.sleep(0.01) + events.append(f"async_pipeline1_{data}") + + def sync_pipeline1(data: Any): + events.append(f"sync_pipeline1_{data}") + + async def async_pipeline2(data: Any): + await asyncio.sleep(0.01) + events.append(f"async_pipeline2_{data}") + + emitter.on("test", async_handler1) + emitter.on("test", sync_handler1) + emitter.pipeline("test", async_pipeline1) + emitter.pipeline("test", sync_pipeline1) + emitter.pipeline("test", async_pipeline2) + + await emitter._notify("test", "event2") + await asyncio.sleep(0.02) + + expected_pipelines = [ + "async_pipeline1_event2", + "sync_pipeline1_event2", + "async_pipeline2_event2", + ] + expected_subscribers = ["async_handler1_event2", "sync_handler1_event2"] + pipeline_indices = [events.index(e) for e in expected_pipelines if e in events] + subscriber_indices = [events.index(e) for e in expected_subscribers if e in events] + assert all( + p < s for p in pipeline_indices for s in subscriber_indices + ), f"Pipelines {expected_pipelines} should precede subscribers {expected_subscribers} in {events}" + assert sorted(events) == sorted( + expected_pipelines + expected_subscribers + ), f"Expected {expected_pipelines + expected_subscribers}, got {events}" + + +@pytest.mark.asyncio +async def test_unsubscribe(): + class TestEmitter(AbstractEmitter): + pass + + emitter = TestEmitter() + events = [] + + async def async_handler1(data: Any): + await asyncio.sleep(0.01) + events.append(f"async_handler1_{data}") + + def sync_handler1(data: Any): + events.append(f"sync_handler1_{data}") + + emitter.on("test", async_handler1) + emitter.on("test", sync_handler1) + emitter.unsubscribe("test", async_handler1) + + emitter._emit("test", "event3") + await asyncio.sleep(0.02) + + expected = ["sync_handler1_event3"] + assert events == expected, f"Expected {expected}, got {events}" + + +@pytest.mark.asyncio +async def test_replace_on_handler(): + class TestEmitter(AbstractEmitter): + pass + + emitter = TestEmitter() + events = [] + + async def async_handler1(data: Any): + await asyncio.sleep(0.01) + events.append(f"async_handler1_{data}") + + def sync_handler1(data: Any): + events.append(f"sync_handler1_{data}") + + emitter.on("test", sync_handler1) + emitter.replace_on_handler("test", async_handler1) + + emitter._emit("test", "event4") + await asyncio.sleep(0.02) + + expected = ["async_handler1_event4"] + assert events == expected, f"Expected {expected}, got {events}" + + +@pytest.mark.asyncio +async def test_pipeline_failure(): + class TestEmitter(AbstractEmitter): + pass + + emitter = TestEmitter() + events = [] + + async def async_pipeline1(data: Any): + await asyncio.sleep(0.01) + events.append(f"async_pipeline1_{data}") + raise ValueError("Pipeline failed") + + def sync_handler1(data: Any): + events.append(f"sync_handler1_{data}") + + emitter.pipeline("test", async_pipeline1) + emitter.on("test", sync_handler1) + + with pytest.raises(ValueError): + await emitter._notify("test", "event5") + + expected = ["async_pipeline1_event5"] + assert events == expected, f"Expected {expected}, got {events}" + + +@pytest.mark.asyncio +async def test_multiple_event_types(): + class TestEmitter(AbstractEmitter): + pass + + emitter = TestEmitter() + events = [] + + async def async_handler_a(data: Any): + await asyncio.sleep(0.01) + events.append(f"async_handler_a_{data}") + + def sync_handler_a(data: Any): + events.append(f"sync_handler_a_{data}") + + async def async_handler_b(data: Any): + await asyncio.sleep(0.01) + events.append(f"async_handler_b_{data}") + + def sync_handler_b(data: Any): + events.append(f"sync_handler_b_{data}") + + async def async_pipeline_a(data: Any): + await asyncio.sleep(0.01) + events.append(f"async_pipeline_a_{data}") + + def sync_pipeline_b(data: Any): + events.append(f"sync_pipeline_b_{data}") + + emitter.on("event_a", async_handler_a) + emitter.on("event_a", sync_handler_a) + emitter.pipeline("event_a", async_pipeline_a) + emitter.on("event_b", async_handler_b) + emitter.on("event_b", sync_handler_b) + emitter.pipeline("event_b", sync_pipeline_b) + + events.clear() + emitter._emit("event_a", "data_a") + await asyncio.sleep(0.02) + expected = ["async_handler_a_data_a", "sync_handler_a_data_a"] + assert sorted(events) == sorted(expected), f"Expected {expected}, got {events}" + + events.clear() + await emitter._notify("event_a", "data_a2") + await asyncio.sleep(0.02) + expected_pipelines = ["async_pipeline_a_data_a2"] + expected_subscribers = ["async_handler_a_data_a2", "sync_handler_a_data_a2"] + pipeline_indices = [events.index(e) for e in expected_pipelines if e in events] + subscriber_indices = [events.index(e) for e in expected_subscribers if e in events] + assert all( + p < s for p in pipeline_indices for s in subscriber_indices + ), f"Pipelines {expected_pipelines} should precede subscribers {expected_subscribers} in {events}" + assert sorted(events) == sorted( + expected_pipelines + expected_subscribers + ), f"Expected {expected_pipelines + expected_subscribers}, got {events}" + + +class DummyEmitter(AbstractEmitter): + pass + + +def test_emitter_no_running_loop_warning(): + emitter = DummyEmitter() + with pytest.warns(RuntimeWarning): + emitter._run_background_task(lambda: None) + + +def test_emitter_emit_sync_error_warning(): + emitter = DummyEmitter() + + def bad_handler(_data: Any) -> None: + raise ValueError("boom") + + emitter.on("evt", bad_handler) + with pytest.warns(RuntimeWarning): + emitter._emit("evt", "data") + + +@pytest.mark.asyncio +async def test_emitter_notify_pipeline_errors(): + emitter = DummyEmitter() + errors = [] + + def on_error(err: Exception) -> None: + errors.append(err) + + def bad_pipeline(_data: Any) -> None: + raise ValueError("boom") + + emitter.on("error", on_error) + emitter.pipeline("evt", bad_pipeline) + + with pytest.raises(ValueError): + await emitter._notify("evt", "data") + assert errors + + +@pytest.mark.asyncio +async def test_emitter_notify_async_pipeline_error(): + emitter = DummyEmitter() + errors = [] + + async def bad_pipeline(_data: Any) -> None: + raise ValueError("boom") + + def on_error(err: Exception) -> None: + errors.append(err) + + emitter.on("error", on_error) + emitter.pipeline("evt", bad_pipeline) + + with pytest.raises(ValueError): + await emitter._notify("evt", "data") + assert errors diff --git a/tests/helpers/stream_pair.py b/tests/helpers/stream_pair.py new file mode 100644 index 0000000..7e3a7fc --- /dev/null +++ b/tests/helpers/stream_pair.py @@ -0,0 +1,12 @@ +import asyncio +from typing import Tuple + + +async def create_stream_pair() -> Tuple[asyncio.AbstractServer, asyncio.StreamReader, asyncio.StreamWriter]: + async def handler(_reader: asyncio.StreamReader, _writer: asyncio.StreamWriter) -> None: + await asyncio.sleep(0.2) + + server = await asyncio.start_server(handler, "127.0.0.1", 0) + host, port = server.sockets[0].getsockname()[:2] + reader, writer = await asyncio.open_connection(host, port) + return server, reader, writer diff --git a/tests/misc/test_timed_promise.py b/tests/misc/test_timed_promise.py new file mode 100644 index 0000000..c3cdbf8 --- /dev/null +++ b/tests/misc/test_timed_promise.py @@ -0,0 +1,26 @@ +import pytest + +from cbor_rpc.timed_promise import TimedPromise + + +@pytest.mark.asyncio +async def test_timed_promise_resolve_reject_and_timeout(): + resolved = TimedPromise(10) + await resolved.resolve("ok") + assert await resolved.promise == "ok" + + rejected = TimedPromise(10) + await rejected.reject("bad") + with pytest.raises(Exception): + await rejected.promise + + timeout_called = [] + + def on_timeout(): + timeout_called.append(True) + + timed = TimedPromise(1, on_timeout) + with pytest.raises(Exception) as exc_info: + await timed.promise + assert exc_info.value.args[0]["timeout"] is True + assert timeout_called == [True] diff --git a/tests/pipe/test_aio_pipe.py b/tests/pipe/test_aio_pipe.py new file mode 100644 index 0000000..42ed86f --- /dev/null +++ b/tests/pipe/test_aio_pipe.py @@ -0,0 +1,150 @@ +import asyncio +from typing import Any, List, Optional + +import pytest + +from cbor_rpc.pipe.aio_pipe import AioPipe + + +class FakeReader: + def __init__(self, responses: List[Any]): + self._responses = list(responses) + + async def read(self, _size: int) -> bytes: + if not self._responses: + return b"" + item = self._responses.pop(0) + if isinstance(item, Exception): + raise item + return item + + +class FakeWriter: + def __init__(self, drain_error: Optional[Exception] = None, close_error: Optional[Exception] = None): + self._drain_error = drain_error + self._close_error = close_error + self._closed = False + + def write(self, _chunk: bytes) -> None: + return None + + async def drain(self) -> None: + if self._drain_error: + raise self._drain_error + + def close(self) -> None: + if self._close_error: + raise self._close_error + self._closed = True + + async def wait_closed(self) -> None: + return None + + +class TestAioPipe(AioPipe[bytes, bytes]): + pass + + +class NotifyFailPipe(TestAioPipe): + async def _notify(self, event_type: str, *args: Any) -> None: + raise RuntimeError(f"notify:{event_type}") + + +class NotifyOncePipe(TestAioPipe): + def __init__(self, reader: FakeReader, writer: FakeWriter): + super().__init__(reader, writer) + self._notified = False + + async def _notify(self, event_type: str, *args: Any) -> None: + if event_type == "data" and not self._notified: + self._notified = True + raise RuntimeError("notify error") + return await super()._notify(event_type, *args) + + +@pytest.mark.asyncio +async def test_aio_pipe_init_requires_reader_and_writer(): + with pytest.raises(ValueError): + TestAioPipe(reader=FakeReader([]), writer=None) + + +@pytest.mark.asyncio +async def test_aio_pipe_setup_without_reader_writer(): + pipe = TestAioPipe() + with pytest.raises(RuntimeError): + await pipe._setup_connection() + + +@pytest.mark.asyncio +async def test_aio_pipe_setup_notify_error_closes(): + pipe = NotifyFailPipe(FakeReader([b"data", b""]), FakeWriter()) + + with pytest.raises(RuntimeError): + await pipe._setup_connection() + + assert pipe.is_connected() is False + + +@pytest.mark.asyncio +async def test_aio_pipe_read_loop_notify_error_emits_close(): + pipe = NotifyOncePipe(FakeReader([b"data", b""]), FakeWriter()) + closed = asyncio.Event() + + def on_close(*_args: Any) -> None: + closed.set() + + pipe.on("close", on_close) + await pipe._setup_connection() + await asyncio.wait_for(closed.wait(), timeout=1) + + +@pytest.mark.asyncio +async def test_aio_pipe_read_loop_reader_error_emits_error(): + pipe = TestAioPipe(FakeReader([ValueError("boom")]), FakeWriter()) + errors: List[Any] = [] + + def on_error(err: Exception) -> None: + errors.append(err) + + pipe.on("error", on_error) + await pipe._setup_connection() + await asyncio.sleep(0.05) + assert errors + + +@pytest.mark.asyncio +async def test_aio_pipe_write_errors(): + pipe = TestAioPipe(FakeReader([]), FakeWriter()) + + with pytest.raises(ConnectionError): + await pipe.write(b"data") + + pipe._connected = True + with pytest.raises(TypeError): + await pipe.write("data") + + errors: List[Any] = [] + + def on_error(err: Exception) -> None: + errors.append(err) + + pipe.on("error", on_error) + pipe._writer = FakeWriter(drain_error=RuntimeError("drain")) + ok = await pipe.write(b"data") + assert ok is False + assert errors + + +@pytest.mark.asyncio +async def test_aio_pipe_close_writer_exception_emits_error(): + pipe = TestAioPipe(FakeReader([]), FakeWriter()) + errors: List[Any] = [] + + def on_error(err: Exception) -> None: + errors.append(err) + + pipe.on("error", on_error) + pipe._connected = True + pipe._writer = FakeWriter(close_error=RuntimeError("close")) + await pipe._close_connection() + assert errors diff --git a/tests/test_event_pipe.py b/tests/pipe/test_event_pipe.py similarity index 71% rename from tests/test_event_pipe.py rename to tests/pipe/test_event_pipe.py index a456211..d5020d7 100644 --- a/tests/test_event_pipe.py +++ b/tests/pipe/test_event_pipe.py @@ -1,22 +1,22 @@ -import pytest import asyncio from typing import Any, Tuple -from cbor_rpc import EventPipe + +import pytest import pytest_asyncio +from cbor_rpc import EventPipe + @pytest_asyncio.fixture async def event_pipe_pair(): pipe1, pipe2 = EventPipe.create_inmemory_pair() yield pipe1, pipe2 - # Terminate is an async method, so it needs to be awaited await pipe1.terminate() await pipe2.terminate() @pytest.mark.asyncio async def test_create_pair(event_pipe_pair: Tuple[EventPipe, EventPipe]): - # Positive case: Creating a pair of async pipes pipe1, pipe2 = event_pipe_pair assert isinstance(pipe1, EventPipe) assert isinstance(pipe2, EventPipe) @@ -24,24 +24,20 @@ async def test_create_pair(event_pipe_pair: Tuple[EventPipe, EventPipe]): @pytest.mark.asyncio async def test_write_success(event_pipe_pair: Tuple[EventPipe, EventPipe]): - # Positive case: Writing a chunk successfully - pipe1, pipe2 = event_pipe_pair + pipe1, _pipe2 = event_pipe_pair result = await pipe1.write("test_chunk") assert result is True @pytest.mark.asyncio async def test_terminate_success(event_pipe_pair: Tuple[EventPipe, EventPipe]): - # Positive case: Terminating the pipe - pipe1, pipe2 = event_pipe_pair + pipe1, _pipe2 = event_pipe_pair await pipe1.terminate() - # No exception should be raised @pytest.mark.asyncio async def test_pipeline_execution(event_pipe_pair: Tuple[EventPipe, EventPipe]): - # Positive case: Adding and executing a pipeline - pipe1, _ = event_pipe_pair + pipe1, _pipe2 = event_pipe_pair received_chunk = None event = asyncio.Event() @@ -52,13 +48,12 @@ async def pipeline_handler(chunk: Any) -> None: pipe1.pipeline("data", pipeline_handler) await pipe1._notify("data", "test_chunk") - await asyncio.wait_for(event.wait(), timeout=1) # Wait for the handler to be called + await asyncio.wait_for(event.wait(), timeout=1) assert received_chunk == "test_chunk" @pytest.mark.asyncio async def test_pipe_pair(event_pipe_pair: Tuple[EventPipe, EventPipe]): - # Positive case: Attaching two pipes pipe1, pipe2 = event_pipe_pair received_chunk = None event = asyncio.Event() @@ -70,14 +65,13 @@ async def handler(chunk: Any) -> None: pipe2.pipeline("data", handler) await pipe1.write("test_chunk") - await asyncio.wait_for(event.wait(), timeout=1) # Wait for the handler to be called + await asyncio.wait_for(event.wait(), timeout=1) assert received_chunk == "test_chunk" @pytest.mark.asyncio async def test_write_after_terminate(event_pipe_pair: Tuple[EventPipe, EventPipe]): - # Negative case: Writing to a terminated pipe - pipe1, _ = event_pipe_pair + pipe1, _pipe2 = event_pipe_pair await pipe1.terminate() result = await pipe1.write("test_chunk") @@ -86,16 +80,12 @@ async def test_write_after_terminate(event_pipe_pair: Tuple[EventPipe, EventPipe @pytest.mark.asyncio async def test_parallel_event_writes(event_pipe_pair: Tuple[EventPipe, EventPipe]): - # Test case: Multiple coroutines writing to one end of the EventPipe pipe1, pipe2 = event_pipe_pair num_writes = 10 test_chunks = [f"chunk_{i}" for i in range(num_writes)] received_chunks = [] - # Use a lock to protect shared list in concurrent access lock = asyncio.Lock() - - # Use a queue to signal when all chunks are received received_queue = asyncio.Queue() async def pipeline_handler(chunk: Any) -> None: @@ -109,20 +99,16 @@ async def pipeline_handler(chunk: Any) -> None: async def writer(chunk): return await pipe1.write(chunk) - # Concurrently write all chunks results = await asyncio.gather(*[writer(chunk) for chunk in test_chunks]) - assert all(results) # All writes should be successful + assert all(results) - # Wait for all chunks to be received by the handler await asyncio.wait_for(received_queue.get(), timeout=5) - # Verify all chunks are received assert sorted(received_chunks) == sorted(test_chunks) @pytest.mark.asyncio async def test_parallel_event_processing(event_pipe_pair: Tuple[EventPipe, EventPipe]): - # Test case: One pipe writes multiple chunks, and the other pipe's handler processes them pipe1, pipe2 = event_pipe_pair num_chunks = 10 test_chunks = [f"data_{i}" for i in range(num_chunks)] @@ -132,7 +118,6 @@ async def test_parallel_event_processing(event_pipe_pair: Tuple[EventPipe, Event processed_queue = asyncio.Queue() async def processing_handler(chunk: Any) -> None: - # Simulate some async processing await asyncio.sleep(0.01) async with lock: processed_chunks.append(chunk) @@ -141,14 +126,11 @@ async def processing_handler(chunk: Any) -> None: pipe2.pipeline("data", processing_handler) - # Write all chunks sequentially (or in parallel, EventPipe handles internal queueing) write_tasks = [pipe1.write(chunk) for chunk in test_chunks] await asyncio.gather(*write_tasks) - # Wait for all chunks to be processed await asyncio.wait_for(processed_queue.get(), timeout=5) - # Verify all chunks are processed assert sorted(processed_chunks) == sorted(test_chunks) @@ -156,11 +138,9 @@ async def processing_handler(chunk: Any) -> None: async def test_concurrent_bidirectional_event_communication( event_pipe_pair: Tuple[EventPipe, EventPipe], ): - # Test case: Concurrent writes and event processing from both ends pipe1, pipe2 = event_pipe_pair num_messages = 5 - client_sent_msgs = [] client_received_responses = [] server_received_msgs = [] server_sent_responses = [] @@ -181,29 +161,19 @@ async def server_handler(msg: Any) -> None: if len(server_sent_responses) == num_messages: server_done_event.set() - pipe1.pipeline("data", client_handler) # Client listens for responses on the 'data' pipeline - pipe2.pipeline("data", server_handler) # Server listens for client messages + pipe1.pipeline("data", client_handler) + pipe2.pipeline("data", server_handler) async def client_writer_task(): for i in range(num_messages): msg = f"client_msg_{i}" await pipe1.write(msg) - client_sent_msgs.append(msg) - # Client expects a response, but it's handled by client_handler asyncio.create_task(client_writer_task()) - # Wait for both sides to complete their communication await asyncio.wait_for(asyncio.gather(client_done_event.wait(), server_done_event.wait()), timeout=5) - # Verify client sent messages are received by server assert sorted([f"client_msg_{i}" for i in range(num_messages)]) == sorted(server_received_msgs) - - # Verify server sent messages are received by client assert sorted([f"server_response_to_client_msg_{i}" for i in range(num_messages)]) == sorted( client_received_responses ) - - -if __name__ == "__main__": - pytest.main() diff --git a/tests/test_pipe.py b/tests/pipe/test_pipe.py similarity index 70% rename from tests/test_pipe.py rename to tests/pipe/test_pipe.py index ae4f6f9..01242be 100644 --- a/tests/test_pipe.py +++ b/tests/pipe/test_pipe.py @@ -1,9 +1,11 @@ -import pytest import asyncio -from typing import Any, Tuple -from cbor_rpc.pipe.pipe import Pipe +from typing import Tuple + +import pytest import pytest_asyncio +from cbor_rpc.pipe.pipe import Pipe + @pytest_asyncio.fixture async def pipe_pair(): @@ -13,7 +15,6 @@ async def pipe_pair(): @pytest.mark.asyncio async def test_create_pair(): - # Positive case: Creating a pair of sync pipes pipe1, pipe2 = Pipe.create_pair() assert isinstance(pipe1, Pipe) assert isinstance(pipe2, Pipe) @@ -23,17 +24,15 @@ async def test_create_pair(): @pytest.mark.asyncio async def test_write_read(pipe_pair: Tuple[Pipe, Pipe]): - # Positive case: Writing and reading a chunk successfully pipe1, pipe2 = pipe_pair assert await pipe1.write("test_chunk") is True - await asyncio.sleep(0) # Allow event loop to process the write + await asyncio.sleep(0) assert await pipe2.read() == "test_chunk" @pytest.mark.asyncio async def test_close_pipe(pipe_pair: Tuple[Pipe, Pipe]): - # Positive case: Closing the pipe pipe1, pipe2 = pipe_pair await pipe1.terminate() @@ -44,38 +43,34 @@ async def test_close_pipe(pipe_pair: Tuple[Pipe, Pipe]): @pytest.mark.asyncio async def test_write_after_close(pipe_pair: Tuple[Pipe, Pipe]): - # Negative case: Writing to a closed pipe - pipe1, pipe2 = pipe_pair + pipe1, _pipe2 = pipe_pair await pipe1.terminate() assert await pipe1.write("test_chunk") is False @pytest.mark.asyncio -async def test_read_timeout(pipe_pair): - # Positive case: Reading with timeout - pipe1, _ = pipe_pair +async def test_read_timeout(pipe_pair: Tuple[Pipe, Pipe]): + pipe1, _pipe2 = pipe_pair assert await pipe1.read(timeout=0.1) is None @pytest.mark.asyncio async def test_bidirectional_communication(pipe_pair: Tuple[Pipe, Pipe]): - # Positive case: Bidirectional communication between pipes pipe1, pipe2 = pipe_pair assert await pipe1.write("test_chunk") is True - await asyncio.sleep(0) # Allow event loop to process the write + await asyncio.sleep(0) assert await pipe2.read() == "test_chunk" assert await pipe2.write("response_chunk") is True - await asyncio.sleep(0) # Allow event loop to process the write + await asyncio.sleep(0) assert await pipe1.read() == "response_chunk" @pytest.mark.asyncio async def test_parallel_writes(pipe_pair: Tuple[Pipe, Pipe]): - # Test case: Multiple coroutines writing to one end of the pipe pipe1, pipe2 = pipe_pair num_writes = 10 test_chunks = [f"chunk_{i}" for i in range(num_writes)] @@ -83,44 +78,36 @@ async def test_parallel_writes(pipe_pair: Tuple[Pipe, Pipe]): async def writer(chunk): return await pipe1.write(chunk) - # Concurrently write all chunks results = await asyncio.gather(*[writer(chunk) for chunk in test_chunks]) - assert all(results) # All writes should be successful + assert all(results) - # Read all chunks from the other end received_chunks = [] for _ in range(num_writes): received_chunks.append(await pipe2.read()) - # Verify all chunks are received and in correct order (or at least all present) assert sorted(received_chunks) == sorted(test_chunks) @pytest.mark.asyncio async def test_parallel_reads(pipe_pair: Tuple[Pipe, Pipe]): - # Test case: Multiple coroutines reading from one end of the pipe pipe1, pipe2 = pipe_pair num_reads = 20 test_chunks = [f"data_{i}" for i in range(num_reads)] - # Write all chunks first for chunk in test_chunks: await pipe1.write(chunk) - await asyncio.sleep(0) # Allow event loop to process the write + await asyncio.sleep(0) async def reader(): return await pipe2.read() - # Concurrently read all chunks received_chunks = await asyncio.gather(*[reader() for _ in range(num_reads)]) - # Verify all chunks are received assert sorted(received_chunks) == sorted(test_chunks) @pytest.mark.asyncio async def test_concurrent_bidirectional_communication(pipe_pair: Tuple[Pipe, Pipe]): - # Test case: Concurrent writes and reads from both ends pipe1, pipe2 = pipe_pair num_messages = 20 @@ -149,15 +136,8 @@ async def server_task(): client_future = asyncio.create_task(client_task()) server_future = asyncio.create_task(server_task()) - client_sent, client_received = await client_future - server_sent, server_received = await server_future + _client_sent, client_received = await client_future + _server_sent, server_received = await server_future - # Verify client sent messages are received by server assert sorted([f"client_msg_{i}" for i in range(num_messages)]) == sorted(server_received) - - # Verify server sent messages are received by client assert sorted([f"server_response_to_client_msg_{i}" for i in range(num_messages)]) == sorted(client_received) - - -if __name__ == "__main__": - pytest.main() diff --git a/tests/pipe/test_pipe_extra.py b/tests/pipe/test_pipe_extra.py new file mode 100644 index 0000000..94ce372 --- /dev/null +++ b/tests/pipe/test_pipe_extra.py @@ -0,0 +1,94 @@ +import asyncio +from typing import Any, List + +import pytest + +from cbor_rpc.pipe.pipe import Pipe + + +@pytest.mark.asyncio +async def test_pipe_read_timeout_returns_none(): + a, _b = Pipe.create_pair() + result = await a.read(timeout=0.01) + assert result is None + + +@pytest.mark.asyncio +async def test_pipe_cancelled_read_requeues_termination_signal(): + a, _b = Pipe.create_pair() + read_task = asyncio.create_task(a.read(timeout=1)) + await asyncio.sleep(0) + read_task.cancel() + with pytest.raises(asyncio.CancelledError): + await read_task + a._buffer.put_nowait(None) + result = await a.read(timeout=0.01) + assert result is None + + +@pytest.mark.asyncio +async def test_pipe_terminate_closes_both_ends(): + a, b = Pipe.create_pair() + await a.terminate("done") + result = await b.read(timeout=0.01) + assert result is None + + +@pytest.mark.asyncio +async def test_pipe_make_event_based_emits_data_and_close(): + a, b = Pipe.create_pair() + event_pipe = a.make_event_based() + received: List[Any] = [] + closed = asyncio.Event() + + async def on_data(data: Any) -> None: + received.append(data) + + def on_close(*_args: Any) -> None: + closed.set() + + event_pipe.pipeline("data", on_data) + event_pipe.on("close", on_close) + + await b.write("hello") + await asyncio.sleep(0.01) + assert received == ["hello"] + + await b.terminate("bye") + await asyncio.wait_for(closed.wait(), timeout=1) + + +@pytest.mark.asyncio +async def test_pipe_write_fails_when_peer_closed(): + a, b = Pipe.create_pair() + await b.terminate("done") + result = await a.write("data") + assert result is False + + +@pytest.mark.asyncio +async def test_pipe_read_with_zero_timeout_reads_buffered_item(): + a, b = Pipe.create_pair() + await b.write("data") + result = await a.read(timeout=0) + assert result == "data" + + +@pytest.mark.asyncio +async def test_pipe_make_event_based_write_roundtrip(): + a, b = Pipe.create_pair() + event_pipe = a.make_event_based() + await event_pipe.write("out") + result = await b.read(timeout=0.1) + assert result == "out" + await event_pipe.terminate("done") + + +@pytest.mark.asyncio +async def test_pipe_make_event_based_terminate_is_idempotent(capsys): + a, _b = Pipe.create_pair() + event_pipe = a.make_event_based() + await event_pipe.terminate("done") + await event_pipe.terminate("done") + captured = capsys.readouterr().out + assert "PipeToEvent: Terminating event pipe." in captured diff --git a/tests/rpc/test_rpc_base.py b/tests/rpc/test_rpc_base.py new file mode 100644 index 0000000..1ec5522 --- /dev/null +++ b/tests/rpc/test_rpc_base.py @@ -0,0 +1,56 @@ +from typing import Any, Optional + +import pytest + +from cbor_rpc.rpc.rpc_base import RpcClient, RpcServer + + +class DummyClient(RpcClient): + async def call_method(self, method: str, *args: Any) -> Any: + return await RpcClient.call_method(self, method, *args) + + async def fire_method(self, method: str, *args: Any) -> None: + return await RpcClient.fire_method(self, method, *args) + + def set_timeout(self, milliseconds: int) -> None: + return RpcClient.set_timeout(self, milliseconds) + + +class DummyServerBase(RpcServer): + async def call_method(self, connection_id: str, method: str, *args: Any) -> Any: + return await RpcServer.call_method(self, connection_id, method, *args) + + async def fire_method(self, connection_id: str, method: str, *args: Any) -> None: + return await RpcServer.fire_method(self, connection_id, method, *args) + + async def disconnect(self, connection_id: str, reason: Optional[str] = None) -> None: + return await RpcServer.disconnect(self, connection_id, reason) + + def get_client(self, connection_id: str): + return RpcServer.get_client(self, connection_id) + + def with_client(self, connection_id: str, action): + return RpcServer.with_client(self, connection_id, action) + + def set_timeout(self, milliseconds: int) -> None: + return RpcServer.set_timeout(self, milliseconds) + + def is_active(self, connection_id: str) -> bool: + return RpcServer.is_active(self, connection_id) + + +@pytest.mark.asyncio +async def test_rpc_base_methods_execute(): + client = DummyClient() + assert await client.call_method("m") is None + assert await client.fire_method("m") is None + assert client.set_timeout(1) is None + + server = DummyServerBase() + assert await server.call_method("id", "m") is None + assert await server.fire_method("id", "m") is None + assert await server.disconnect("id") is None + assert server.get_client("id") is None + assert server.with_client("id", lambda _c: None) is None + assert server.set_timeout(1) is None + assert server.is_active("id") is None diff --git a/tests/rpc/test_rpc_logging.py b/tests/rpc/test_rpc_logging.py new file mode 100644 index 0000000..2e923dc --- /dev/null +++ b/tests/rpc/test_rpc_logging.py @@ -0,0 +1,150 @@ +import asyncio +from typing import Any, List + +import pytest + +from cbor_rpc import EventPipe, RpcV1 +from cbor_rpc.rpc.context import RpcCallContext + + +LOG_LEVELS = [ + ("crit", 1), + ("warn", 2), + ("info", 3), + ("verbose", 4), + ("debug", 5), +] + + +class LoggingServerRpc(RpcV1): + def __init__(self, pipe: EventPipe[Any, Any], id_: str): + super().__init__(pipe) + self._id = id_ + + def get_id(self) -> str: + return self._id + + def handle_method_call(self, context: RpcCallContext, method: str, args: List[Any]) -> Any: + if method == "log_levels": + for level_name, _level_value in LOG_LEVELS: + self._log_for_level(context, f"method:{level_name}", level_name) + return "ok" + raise Exception(f"Unknown method: {method}") + + async def on_event(self, context: RpcCallContext, topic: str, message: Any) -> None: + if message == "levels": + for level_name, _level_value in LOG_LEVELS: + self._log_for_level(context, f"event:{topic}:{level_name}", level_name) + return + context.logger.warn(f"event:{topic}:{message}") + + def _log_for_level(self, context: RpcCallContext, content: str, level_name: str) -> None: + if level_name == "crit": + context.logger.crit(content) + elif level_name == "warn": + context.logger.warn(content) + elif level_name == "info": + context.logger.log(content) + elif level_name == "verbose": + context.logger.verbose(content) + elif level_name == "debug": + context.logger.debug(content) + else: + raise Exception(f"Unknown log level: {level_name}") + + +class LoggingClientRpc(RpcV1): + def __init__(self, pipe: EventPipe[Any, Any], id_: str): + super().__init__(pipe) + self._id = id_ + + def get_id(self) -> str: + return self._id + + def handle_method_call(self, context: RpcCallContext, method: str, args: List[Any]) -> Any: + raise Exception("Client Only Implementation") + + async def on_event(self, context: RpcCallContext, topic: str, message: Any) -> None: + return None + + +async def _collect_logs(queue: asyncio.Queue, expected_count: int) -> List[List[Any]]: + logs: List[List[Any]] = [] + for _ in range(expected_count): + log = await asyncio.wait_for(queue.get(), timeout=1) + logs.append(log) + return logs + + +def _attach_log_listener(pipe: EventPipe[Any, Any], queue: asyncio.Queue) -> None: + async def handler(chunk: Any) -> None: + if not isinstance(chunk, list) or len(chunk) < 2: + return + if chunk[0] != 2 or chunk[1] == 0: + return + await queue.put(chunk) + + pipe.pipeline("data", handler) + + +@pytest.mark.asyncio +async def test_rpc_logging_all_levels_method_and_event(): + pipe_a, pipe_b = EventPipe.create_inmemory_pair() + server = LoggingServerRpc(pipe_a, "server") + client = LoggingClientRpc(pipe_b, "client") + + log_queue: asyncio.Queue = asyncio.Queue() + _attach_log_listener(pipe_b, log_queue) + + await client.set_log_level(5) + await asyncio.sleep(0.01) + + result = await client.call_method("log_levels") + assert result == "ok" + await client.emit("topic1", "levels") + + logs = await _collect_logs(log_queue, expected_count=10) + + expected = set() + for level_name, level_value in LOG_LEVELS: + expected.add((level_value, 1, 0, f"method:{level_name}")) + expected.add((level_value, 3, "topic1", f"event:topic1:{level_name}")) + + actual = set((log[1], log[2], log[3], log[4]) for log in logs) + assert actual == expected + + await pipe_a.terminate() + await pipe_b.terminate() + + +@pytest.mark.asyncio +async def test_rpc_logging_respects_remote_level_setting(): + pipe_a, pipe_b = EventPipe.create_inmemory_pair() + server = LoggingServerRpc(pipe_a, "server") + client = LoggingClientRpc(pipe_b, "client") + + log_queue: asyncio.Queue = asyncio.Queue() + _attach_log_listener(pipe_b, log_queue) + + await client.set_log_level(2) + await asyncio.sleep(0.01) + + result = await client.call_method("log_levels") + assert result == "ok" + await client.emit("topic2", "levels") + + logs = await _collect_logs(log_queue, expected_count=4) + actual_levels = {log[1] for log in logs} + assert actual_levels == {1, 2} + + actual = set((log[1], log[2], log[3], log[4]) for log in logs) + expected = { + (1, 1, 0, "method:crit"), + (2, 1, 0, "method:warn"), + (1, 3, "topic2", "event:topic2:crit"), + (2, 3, "topic2", "event:topic2:warn"), + } + assert actual == expected + + await pipe_a.terminate() + await pipe_b.terminate() diff --git a/tests/test_rpc_v1.py b/tests/rpc/test_rpc_v1.py similarity index 85% rename from tests/test_rpc_v1.py rename to tests/rpc/test_rpc_v1.py index 8f3495f..a27bd95 100644 --- a/tests/test_rpc_v1.py +++ b/tests/rpc/test_rpc_v1.py @@ -1,7 +1,7 @@ -import pytest import asyncio -from typing import Any, Generic, List -from unittest.mock import AsyncMock, MagicMock +from typing import Any, List + +import pytest from cbor_rpc import EventPipe, RpcV1, TimedPromise from cbor_rpc.rpc.context import RpcCallContext @@ -25,20 +25,18 @@ def throw_error_method(message: str) -> None: raise Exception(message) -# Method handler for RPC def method_handler(context: RpcCallContext, method: str, args: List[Any]) -> Any: if method == "sleep": return sleep_method(*args) - elif method == "add": + if method == "add": return add_method(*args) - elif method == "multiply": + if method == "multiply": return multiply_method(*args) - elif method == "throwError": + if method == "throwError": return throw_error_method(*args) raise Exception(f"Unknown method: {method}") - class EventRpcHelper(RpcV1): def get_id(self) -> str: return "event_rpc" @@ -97,9 +95,9 @@ async def test_throw_error_method(rpc): @pytest.mark.asyncio async def test_call_method_timeout(rpc, pipe): - rpc.set_timeout(100) # Set short timeout + rpc.set_timeout(100) with pytest.raises(Exception) as exc_info: - await rpc.call_method("sleep", 1) # Long sleep to trigger timeout + await rpc.call_method("sleep", 1) assert exc_info.value.args[0]["timeout"] is True assert exc_info.value.args[0]["timeoutPeriod"] == 100 @@ -114,7 +112,7 @@ async def test_call_method_unknown_method(rpc): @pytest.mark.asyncio async def test_fire_method(rpc): await rpc.fire_method("add", 1, 2) - assert rpc._counter == 1 # Verify message was sent + assert rpc._counter == 1 @pytest.mark.asyncio @@ -152,16 +150,14 @@ async def test_wait_next_event_already_waiting(event_rpc): @pytest.mark.asyncio async def test_invalid_message_format(rpc, pipe): - await pipe.write([1]) # Invalid message - await asyncio.sleep(0.1) # Allow processing - # No crash means test passes + await pipe.write([1]) + await asyncio.sleep(0.1) @pytest.mark.asyncio async def test_unsupported_protocol(rpc, pipe): - await pipe.write([99, 0, 0, "add", [1, 2]]) # Unsupported protocol - await asyncio.sleep(0.1) # Allow processing - # No crash means test passes + await pipe.write([99, 0, 0, "add", [1, 2]]) + await asyncio.sleep(0.1) @pytest.mark.asyncio @@ -175,7 +171,6 @@ async def test_concurrent_method_calls(rpc): async def test_read_only_client(pipe): read_only = RpcV1.read_only_client(SimplePipe()) - # We need to directly call the handle_method_call method to test it with pytest.raises(Exception) as exc_info: context = RpcCallContext(read_only.logger) read_only.handle_method_call(context, "add", [1, 2]) diff --git a/tests/rpc/test_rpc_v1_extra.py b/tests/rpc/test_rpc_v1_extra.py new file mode 100644 index 0000000..6407d17 --- /dev/null +++ b/tests/rpc/test_rpc_v1_extra.py @@ -0,0 +1,224 @@ +import asyncio +from typing import Any, List + +import pytest + +from cbor_rpc.pipe.event_pipe import EventPipe +from cbor_rpc.rpc.context import RpcCallContext +from cbor_rpc.rpc.rpc_v1 import RpcV1, RpcCore +from cbor_rpc.rpc.rpc_server import RpcV1Server + + +class TestRpcServer(RpcV1Server): + async def handle_method_call( + self, + connection_id: str, + context: RpcCallContext, + method: str, + args: List[Any], + ) -> Any: + if method == "ping": + return f"pong:{connection_id}:{args[0]}" + raise Exception("Unknown method") + + +def _noop_handler(context: RpcCallContext, method: str, args: List[Any]) -> Any: + if method == "fire": + return None + return "ok" + + +class CoreOnlyRpc(RpcCore): + def get_id(self) -> str: + return "core" + + def handle_method_call(self, context: RpcCallContext, method: str, args: List[Any]) -> Any: + if method == "boom": + raise Exception("boom") + if method == "nested": + async def inner() -> str: + return "ok" + + async def outer(): + return inner() + + return outer() + return "ok" + + +@pytest.mark.asyncio +async def test_rpc_v1_proto_validation_and_logging(caplog): + import logging + caplog.set_level(logging.INFO) + pipe_a, pipe_b = EventPipe.create_inmemory_pair() + rpc = RpcV1.make_rpc_v1(pipe_a, "id", _noop_handler) + + await pipe_b.write([1, 0, 1]) + await pipe_b.write([1, 2, 1]) + await pipe_b.write([1, 9, 1, "x", []]) + await pipe_b.write([1, 0, 999, "ok"]) + await pipe_b.write([2, 1, 2]) + await pipe_b.write([2, 0, 3]) + await pipe_b.write([2, 99, 1, 2, "content"]) + await pipe_b.write([3, 0]) + await pipe_b.write([3, 1, "topic", "msg"]) + + await asyncio.sleep(0.05) + + assert rpc._peer_log_level == 3 + + logs = [r.message for r in caplog.records] + assert any("Invalid response format" in log for log in logs) + assert any("Invalid call format" in log for log in logs) + assert any("Unknown sub-protocol" in log for log in logs) + assert any("expired request id" in log for log in logs) + assert any("Invalid format" in log for log in logs) + assert any("[RemoteLog:LEVEL-99]" in log for log in logs) + assert any("Invalid event format" in log for log in logs) + + await pipe_a.terminate() + await pipe_b.terminate() + + +@pytest.mark.asyncio +async def test_rpc_v1_send_log_filters_by_peer_level(): + pipe_a, pipe_b = EventPipe.create_inmemory_pair() + rpc = RpcV1.make_rpc_v1(pipe_a, "id", _noop_handler) + + received: List[List[Any]] = [] + + async def on_data(data: Any) -> None: + if isinstance(data, list) and data and data[0] == 2: + received.append(data) + + pipe_b.pipeline("data", on_data) + + rpc.logger.log("skip") + await asyncio.sleep(0.01) + assert received == [] + + rpc._peer_log_level = 2 + rpc.logger.debug("skip") + await asyncio.sleep(0.01) + assert received == [] + + rpc.logger.warn("keep") + await asyncio.sleep(0.01) + assert received[-1][1] == 2 + + await pipe_a.terminate() + await pipe_b.terminate() + + +@pytest.mark.asyncio +async def test_rpc_v1_server_connection_lifecycle(): + server = TestRpcServer() + pipe_a, pipe_b = EventPipe.create_inmemory_pair() + + client_called = asyncio.Event() + + def client_handler(context: RpcCallContext, method: str, args: List[Any]) -> Any: + if method == "ping": + return f"client:{args[0]}" + if method == "fire": + client_called.set() + return None + raise Exception("Unknown method") + + client = RpcV1.make_rpc_v1(pipe_b, "client", client_handler) + + await server.add_connection("c1", pipe_a) + assert server.is_active("c1") + + result = await server.call_method("c1", "ping", "hi") + assert result == "client:hi" + + await server.fire_method("c1", "fire") + await asyncio.wait_for(client_called.wait(), timeout=1) + + called: List[bool] = [] + + def action(_client: Any) -> None: + called.append(True) + + assert server.with_client("c1", action) is True + assert called == [True] + + await server.disconnect("c1", "bye") + assert server.is_active("c1") is False + assert server.with_client("missing", action) is False + + await pipe_a.terminate() + await pipe_b.terminate() + + +@pytest.mark.asyncio +async def test_rpc_v1_server_inactive_client_errors(): + server = TestRpcServer() + + with pytest.raises(Exception) as exc_info: + await server.call_method("missing", "ping") + assert str(exc_info.value) == "Client is not active" + + with pytest.raises(Exception) as exc_info: + await server.fire_method("missing", "ping") + assert str(exc_info.value) == "Client is not active" + + +@pytest.mark.asyncio +async def test_rpc_core_fire_error_and_unsupported_event(caplog): + pipe_a, pipe_b = EventPipe.create_inmemory_pair() + _rpc = CoreOnlyRpc(pipe_a) + + await pipe_b.write([3, 0, "topic", "msg"]) + await pipe_b.write([1, 3, 1, "boom", []]) + await asyncio.sleep(0.05) + + logs = [r.message for r in caplog.records] + assert any("Unsupported event message" in log for log in logs) + assert any("Fired method error" in log for log in logs) + + await pipe_a.terminate() + await pipe_b.terminate() + + +@pytest.mark.asyncio +async def test_rpc_core_nested_result_and_error_response(): + pipe_a, pipe_b = EventPipe.create_inmemory_pair() + _rpc = CoreOnlyRpc(pipe_a) + + responses: List[List[Any]] = [] + + async def on_data(data: Any) -> None: + if isinstance(data, list) and data and data[0] == 1 and data[1] in (0, 1): + responses.append(data) + + pipe_b.pipeline("data", on_data) + + await pipe_b.write([1, 2, 1, "nested", []]) + await pipe_b.write([1, 2, 2, "boom", []]) + await asyncio.sleep(0.05) + + assert [1, 0, 1, "ok"] in responses + assert [1, 1, 2, "boom"] in responses + + await pipe_a.terminate() + await pipe_b.terminate() + + +@pytest.mark.asyncio +async def test_rpc_v1_server_close_cleanup_and_timeout_applied(): + server = TestRpcServer() + server.set_timeout(123) + pipe_a, _pipe_b = EventPipe.create_inmemory_pair() + await server.add_connection("c1", pipe_a) + assert server.rpc_clients["c1"]._timeout == 123 + + await pipe_a.terminate("done") + await asyncio.sleep(0.01) + assert server.is_active("c1") is False + + +def test_rpc_v1_server_get_client_and_disconnect_missing(): + server = TestRpcServer() + assert server.get_client("missing") is None diff --git a/tests/rpc/test_server_base.py b/tests/rpc/test_server_base.py new file mode 100644 index 0000000..e84ffa8 --- /dev/null +++ b/tests/rpc/test_server_base.py @@ -0,0 +1,69 @@ +import asyncio +from typing import Any, List + +import pytest + +from cbor_rpc.pipe.event_pipe import EventPipe +from cbor_rpc.rpc.server_base import Server + + +class DummyServer(Server[EventPipe]): + def __init__(self, accept_connections: bool = True): + super().__init__() + self._accept_connections = accept_connections + self.stopped = False + + async def start(self, *args, **kwargs) -> Any: + self._running = True + return "started" + + async def stop(self) -> None: + self._running = False + self.stopped = True + + async def accept(self, pipe: EventPipe) -> bool: + return self._accept_connections + + +@pytest.mark.asyncio +async def test_server_base_add_connection_accepts_and_cleans(): + server = DummyServer(accept_connections=True) + pipe_a, _pipe_b = EventPipe.create_inmemory_pair() + + connected: List[EventPipe] = [] + + def on_connection(pipe: EventPipe) -> None: + connected.append(pipe) + + server.on_connection(on_connection) + await server._add_connection(pipe_a) + assert pipe_a in server.get_connections() + assert connected == [pipe_a] + + await pipe_a.terminate("bye") + await asyncio.sleep(0) + assert pipe_a not in server.get_connections() + + +@pytest.mark.asyncio +async def test_server_base_rejects_connection(): + server = DummyServer(accept_connections=False) + pipe_a, _pipe_b = EventPipe.create_inmemory_pair() + + await server._add_connection(pipe_a) + assert pipe_a not in server.get_connections() + + +@pytest.mark.asyncio +async def test_server_base_close_all_connections_and_context(): + server = DummyServer(accept_connections=True) + pipe_a, _pipe_b = EventPipe.create_inmemory_pair() + await server._add_connection(pipe_a) + + await server.close_all_connections() + await asyncio.sleep(0) + assert server.get_connections() == set() + + async with server: + assert server.is_running() is False + assert server.stopped is True diff --git a/tests/test_ssh_docker_pipe.py b/tests/ssh/test_ssh_docker_pipe.py similarity index 81% rename from tests/test_ssh_docker_pipe.py rename to tests/ssh/test_ssh_docker_pipe.py index 4fb22aa..b6a7276 100644 --- a/tests/test_ssh_docker_pipe.py +++ b/tests/ssh/test_ssh_docker_pipe.py @@ -1,25 +1,24 @@ import asyncio -import pytest -import docker -import time -import asyncssh import os import re +import time + +import asyncssh import asyncssh.public_key +import docker +import pytest from cbor_rpc.ssh.ssh_pipe import SshPipe -# Define a test user and password for the SSHD container TEST_SSH_USER = "testuser" TEST_SSH_PASSWORD = "testpassword" -SSHD_IMAGE_NAME = "cbor-rpc-py-sshd-python" # Custom image name +SSHD_IMAGE_NAME = "cbor-rpc-py-sshd-python" SSHD_CONTAINER_NAME = "test-sshd-container" -SSHD_DOCKERFILE_PATH = "./tests/docker/sshd-python" # Path to the Dockerfile +SSHD_DOCKERFILE_PATH = "./tests/docker/sshd-python" @pytest.fixture(scope="session") def ssh_keys(): - """Generates SSH keys (plain and encrypted) and a passphrase for testing.""" private_key_obj = asyncssh.generate_private_key("ssh-rsa") passphrase = "test_passphrase" @@ -34,7 +33,6 @@ def ssh_keys(): @pytest.fixture(scope="session") def docker_client(): - """Provides a Docker client instance.""" client = docker.from_env() yield client client.close() @@ -42,11 +40,10 @@ def docker_client(): @pytest.fixture(scope="session") def test_network(docker_client: docker.DockerClient): - """Provides a Docker network for containers to communicate.""" network_name = "test-ssh-network" try: network = docker_client.networks.get(network_name) - network.remove() # Clean up existing network if it exists + network.remove() except docker.errors.NotFound: pass @@ -55,14 +52,13 @@ def test_network(docker_client: docker.DockerClient): network.remove() -@pytest.fixture(scope="module") # Changed scope to module as requested +@pytest.fixture(scope="module") async def ssh_container_combined_auth(docker_client: docker.DockerClient, test_network, docker_host_ip, ssh_keys): container_name = "ssh-test-container-combined-auth" ssh_user = TEST_SSH_USER ssh_password = TEST_SSH_PASSWORD - public_key = ssh_keys["unencrypted_public"] # Use the unencrypted public key + public_key = ssh_keys["unencrypted_public"] - # Ensure previous container is stopped and removed try: existing_container = docker_client.containers.get(container_name) print(f"Found existing container '{container_name}'. Stopping and removing...") @@ -74,13 +70,12 @@ async def ssh_container_combined_auth(docker_client: docker.DockerClient, test_n except Exception as e: print(f"Error cleaning up existing container: {e}") - # Build the custom Docker image print(f"\nBuilding Docker image '{SSHD_IMAGE_NAME}' from '{SSHD_DOCKERFILE_PATH}'...") try: docker_client.images.build( path=SSHD_DOCKERFILE_PATH, tag=SSHD_IMAGE_NAME, - rm=True, # Remove intermediate containers + rm=True, ) print(f"Docker image '{SSHD_IMAGE_NAME}' built successfully.") except docker.errors.BuildError as e: @@ -91,26 +86,26 @@ async def ssh_container_combined_auth(docker_client: docker.DockerClient, test_n container = None try: container = docker_client.containers.run( - SSHD_IMAGE_NAME, # Use the custom image name + SSHD_IMAGE_NAME, detach=True, - ports={"2222/tcp": None}, # Map container SSH port to a random host port + ports={"2222/tcp": None}, network=test_network.name, name=container_name, environment={ "PUID": "1000", "PGID": "1000", "TZ": "Etc/UTC", - "PASSWORD_ACCESS": "true", # Password access enabled + "PASSWORD_ACCESS": "true", "USER_NAME": ssh_user, "USER_PASSWORD": ssh_password, - "PUBLIC_KEY": public_key, # Add one public key + "PUBLIC_KEY": public_key, }, restart_policy={"Name": "no"}, ) container.reload() host_port = None - for _ in range(30): # Wait up to 30 seconds for port mapping + for _ in range(30): container.reload() if "2222/tcp" in container.ports and container.ports["2222/tcp"]: host_port = container.ports["2222/tcp"][0]["HostPort"] @@ -122,13 +117,11 @@ async def ssh_container_combined_auth(docker_client: docker.DockerClient, test_n print(f"SSH container with combined auth running on host port: {host_port}") - # Wait for SSH server to be ready ready = False - for i in range(60): # wait up to 60 seconds + for i in range(60): try: async def check_ssh_combined(): - # Check password authentication try: conn_pw = await asyncssh.connect( docker_host_ip, @@ -143,7 +136,6 @@ async def check_ssh_combined(): print(f"Password auth check failed: {e}") return False - # Check public key authentication (unencrypted) try: conn_key = await asyncssh.connect( docker_host_ip, @@ -188,10 +180,8 @@ async def check_ssh_combined(): @pytest.fixture(scope="session") def docker_host_ip(): - """Determines the Docker host IP for connecting to containers.""" docker_host = os.environ.get("DOCKER_HOST") - # Regex to match tcp://hostname:port, unix://socket, or ip:port regex = r"^(?:(tcp|unix)://)?([a-zA-Z0-9.-]+)(?::\d+)?$" if docker_host: @@ -199,19 +189,13 @@ def docker_host_ip(): if match: protocol, host = match.groups() if protocol == "unix": - # For unix sockets, connections are typically local, but asyncssh needs an IP - # In this case, 'localhost' is usually appropriate for host-to-container communication return "localhost" - return host # Return the IP or hostname - return "localhost" # Default for local Docker setup + return host + return "localhost" @pytest.mark.asyncio async def test_ssh_pipe_with_hello_world_emitter(ssh_container_combined_auth): - """ - Tests SshPipe by connecting to an SSHD container that emits "hello world" every second. - This test uses password authentication. - """ container, host, port, username, password, _ = ssh_container_combined_auth print(f"\nAttempting SshPipe connection to {host}:{port} as user {username} with password authentication...") @@ -234,7 +218,6 @@ async def test_ssh_pipe_with_hello_world_emitter(ssh_container_combined_auth): def on_data_callback(data): print(f"test_ssh_pipe_with_hello_world_emitter: Received data: {data!r}") received_data.append(data) - # The data should be bytes, so compare directly with bytes literal if b"hello world" in data: received_event.set() @@ -256,16 +239,11 @@ def on_data_callback(data): @pytest.mark.asyncio async def test_ssh_pipe_with_password_authentication(ssh_container_combined_auth): - """ - Tests SshPipe using username and password authentication. - This is a dedicated test for password authentication, ensuring it works as expected. - """ container, host, port, username, password, _ = ssh_container_combined_auth print(f"\nAttempting SshPipe connection to {host}:{port} as user {username} with password authentication...") pipe = None try: - # Use a simple command to verify connection, e.g., 'echo' test_command = "echo 'Password auth successful!'" pipe = await SshPipe.connect( host=host, @@ -300,7 +278,7 @@ def on_data_callback(data): except asyncssh.Error as e: pytest.fail(f"SSH connection or command failed with password authentication: {e}") except asyncio.TimeoutError: - pytest.fail(f"SSH connection timed out with password authentication.") + pytest.fail("SSH connection timed out with password authentication.") except Exception as e: pytest.fail(f"An unexpected error occurred during password authentication test: {e}") finally: @@ -311,10 +289,6 @@ def on_data_callback(data): @pytest.mark.asyncio async def test_ssh_pipe_with_plain_key_authentication(ssh_container_combined_auth, ssh_keys): - """ - Tests SshPipe using a plain (unencrypted) SSH key for authentication. - This test is designed to run when sshd_container is configured for plain key auth. - """ container, host, port, username, _, unencrypted_private_key_content = ssh_container_combined_auth print(f"\nAttempting SshPipe connection to {host}:{port} as user {username} with plain key authentication...") @@ -326,7 +300,7 @@ async def test_ssh_pipe_with_plain_key_authentication(ssh_container_combined_aut host=host, port=port, username=username, - ssh_key_content=unencrypted_private_key_content, # Use the private key content for authentication + ssh_key_content=unencrypted_private_key_content, known_hosts=None, timeout=10, command=test_command, @@ -355,7 +329,7 @@ def on_data_callback(data): except asyncssh.Error as e: pytest.fail(f"SSH connection or command failed with plain key authentication: {e}") except asyncio.TimeoutError: - pytest.fail(f"SSH connection timed out with plain key authentication.") + pytest.fail("SSH connection timed out with plain key authentication.") except Exception as e: pytest.fail(f"An unexpected error occurred during plain key authentication test: {e}") finally: @@ -366,10 +340,6 @@ def on_data_callback(data): @pytest.mark.asyncio async def test_ssh_pipe_with_encrypted_key_authentication(ssh_container_combined_auth, ssh_keys): - """ - Tests SshPipe using an encrypted SSH key with a passphrase for authentication. - This test is designed to run when sshd_container is configured for encrypted key auth. - """ container, host, port, username, _, _ = ssh_container_combined_auth private_key_content = ssh_keys["encrypted_private"] @@ -384,8 +354,8 @@ async def test_ssh_pipe_with_encrypted_key_authentication(ssh_container_combined host=host, port=port, username=username, - ssh_key_content=private_key_content, # Use the private key content for authentication - ssh_key_passphrase=passphrase, # Pass the passphrase + ssh_key_content=private_key_content, + ssh_key_passphrase=passphrase, known_hosts=None, timeout=10, command=test_command, @@ -414,7 +384,7 @@ def on_data_callback(data): except asyncssh.Error as e: pytest.fail(f"SSH connection or command failed with encrypted key authentication: {e}") except asyncio.TimeoutError: - pytest.fail(f"SSH connection timed out.") + pytest.fail("SSH connection timed out.") except Exception as e: pytest.fail(f"An unexpected error occurred during encrypted key authentication test: {e}") finally: @@ -425,10 +395,6 @@ def on_data_callback(data): @pytest.mark.asyncio async def test_ssh_pipe_with_echo_back_command(ssh_container_combined_auth): - """ - Tests SshPipe by connecting to an SSHD container and running 'echo_back.py' to echo back input. - This test uses password authentication. - """ container, host, port, username, password, _ = ssh_container_combined_auth print( @@ -436,14 +402,13 @@ async def test_ssh_pipe_with_echo_back_command(ssh_container_combined_auth): ) pipe = None try: - # Use the custom Python echo-back script echo_back_command = "python3 /usr/local/bin/echo_back.py" pipe = await SshPipe.connect( host=host, port=port, username=username, password=password, - known_hosts=None, # Disable host key checking for test container + known_hosts=None, timeout=10, command=echo_back_command, ) @@ -454,8 +419,6 @@ async def test_ssh_pipe_with_echo_back_command(ssh_container_combined_auth): def on_data_callback(data): print(f"test_ssh_pipe_with_echo_back_command: Received data chunk: {data!r}") received_data_chunks.append(data) - # For echo-back, we expect the full message to be returned - # We'll set the event once we receive some data. received_event.set() pipe.pipeline("data", on_data_callback) @@ -475,12 +438,12 @@ def on_data_callback(data): full_received_data.strip() == test_message.strip() ), f"Received data {full_received_data!r} should exactly match sent data {test_message!r}" print("Verification successful: Data echoed correctly by 'echo_back.py' script.") - await pipe.write_eof() # Signal EOF to the remote process + await pipe.write_eof() except asyncssh.Error as e: pytest.fail(f"SSH connection or command failed: {e}") except asyncio.TimeoutError: - pytest.fail(f"SSH connection timed out.") + pytest.fail("SSH connection timed out.") except Exception as e: pytest.fail(f"An unexpected error occurred: {e}") finally: @@ -491,10 +454,6 @@ def on_data_callback(data): @pytest.mark.asyncio async def test_ssh_pipe_with_binary_data(ssh_container_combined_auth): - """ - Tests SshPipe by connecting to an SSHD container that emits raw binary data. - This test uses password authentication. - """ container, host, port, username, password, _ = ssh_container_combined_auth print( @@ -502,26 +461,23 @@ async def test_ssh_pipe_with_binary_data(ssh_container_combined_auth): ) pipe = None try: - # Python script to continuously emit binary data - # Using os.write(1, ...) to write raw bytes to stdout emitter_binary_command = "python3 /usr/local/bin/binary_emitter.py" pipe = await SshPipe.connect( host=host, port=port, username=username, password=password, - known_hosts=None, # Disable host key checking for test container + known_hosts=None, timeout=10, command=emitter_binary_command, ) received_event = asyncio.Event() received_data_chunks = [] - expected_binary_pattern = b"\xde\xad\xbe\xef\x00\x01\x02\x03\x80\xff\x7f" # Updated pattern without newlines + expected_binary_pattern = b"\xde\xad\xbe\xef\x00\x01\x02\x03\x80\xff\x7f" def on_data_callback(data): print(f"test_ssh_pipe_with_binary_data: Received data chunk: {data!r}") - # Strip any potential carriage returns or newlines added by the shell cleaned_data = data.replace(b"\r", b"").replace(b"\n", b"") received_data_chunks.append(cleaned_data) if expected_binary_pattern in cleaned_data: @@ -545,7 +501,7 @@ def on_data_callback(data): except asyncssh.Error as e: pytest.fail(f"SSH connection or command failed: {e}") except asyncio.TimeoutError: - pytest.fail(f"SSH connection timed out.") + pytest.fail("SSH connection timed out.") except Exception as e: pytest.fail(f"An unexpected error occurred: {e}") finally: diff --git a/tests/ssh/test_ssh_pipe.py b/tests/ssh/test_ssh_pipe.py new file mode 100644 index 0000000..4b11c13 --- /dev/null +++ b/tests/ssh/test_ssh_pipe.py @@ -0,0 +1,43 @@ +import pytest + +from cbor_rpc.ssh.ssh_pipe import SshPipe +from tests.helpers.stream_pair import create_stream_pair + + +@pytest.mark.asyncio +async def test_ssh_pipe_terminate_and_write_eof(): + server, reader, writer = await create_stream_pair() + + class DummyChannel: + def __init__(self): + self._closed = False + + def is_closing(self) -> bool: + return self._closed + + def close(self) -> None: + self._closed = True + + async def wait_closed(self) -> None: + return None + + class DummyClient: + def __init__(self): + self._closed = False + + def is_closed(self) -> bool: + return self._closed + + def close(self) -> None: + self._closed = True + + async def wait_closed(self) -> None: + return None + + pipe = SshPipe(reader, writer, DummyClient(), DummyChannel()) + + await pipe.write_eof() + await pipe.terminate() + + server.close() + await server.wait_closed() diff --git a/tests/stdio/test_stdio_pipe.py b/tests/stdio/test_stdio_pipe.py new file mode 100644 index 0000000..b2c9cdd --- /dev/null +++ b/tests/stdio/test_stdio_pipe.py @@ -0,0 +1,60 @@ +import asyncio +import sys + +import pytest + +from cbor_rpc.stdio.stdio_pipe import StdioPipe +from tests.helpers.stream_pair import create_stream_pair + + +@pytest.mark.asyncio +async def test_stdio_pipe_errors_without_process(): + server, reader, writer = await create_stream_pair() + pipe = StdioPipe(reader, writer) + + with pytest.raises(RuntimeError): + await pipe.wait_for_process_termination() + + # Should not raise + await pipe.terminate() + + writer.close() + await writer.wait_closed() + server.close() + await server.wait_closed() + + +@pytest.mark.asyncio +async def test_stdio_pipe_start_process_and_terminate(): + pipe = await StdioPipe.start_process(sys.executable, "-c", "import time; time.sleep(0.2)") + pipe.terminate() + code = await pipe.wait_for_process_termination() + assert isinstance(code, int) + + +@pytest.mark.asyncio +async def test_stdio_pipe_read_write(): + pipe = await StdioPipe.start_process("/bin/bash", "-c", "cat -") + + received_data = [] + future = asyncio.Future() + + def on_data(data): + received_data.append(data) + if len(received_data) == 10: + future.set_result(None) + + pipe.pipeline("data", on_data) + + test_data = [f"Test data {i}\n".encode("utf-8") for i in range(10)] + for data in test_data: + await pipe.write(data) + await asyncio.sleep(0.01) + + await future + + assert len(received_data) == 10 + for i, (sent, received) in enumerate(zip(test_data, received_data)): + assert received == sent, f"Mismatch at index {i}: expected {sent!r}, got {received!r}" + + pipe.terminate() diff --git a/tests/test_tcp.py b/tests/tcp/test_tcp.py similarity index 77% rename from tests/test_tcp.py rename to tests/tcp/test_tcp.py index 321fe45..8c537da 100644 --- a/tests/test_tcp.py +++ b/tests/tcp/test_tcp.py @@ -1,16 +1,16 @@ -import pytest import asyncio from typing import List + +import pytest + from cbor_rpc import TcpPipe from tests.helpers.simple_tcp_server import SimpleTcpServer -DEFAULT_TIMEOUT = 1.0 # we are doing everything on same machine. everything should be fast +DEFAULT_TIMEOUT = 1.0 @pytest.mark.asyncio async def test_tcp_client_server_connection(): - """Test basic TCP client-server connection.""" - # Start a server server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() @@ -22,17 +22,13 @@ def on_connection(tcp_pipe: TcpPipe): server.on_connection(on_connection) try: - # Create a client connection client = await TcpPipe.create_connection(server_host, server_port) - - # Wait for server to register the connection await asyncio.sleep(0.1) assert len(connections) == 1 assert client.is_connected() assert connections[0].is_connected() - # Test peer info client_peer = client.get_peer_info() server_conn_peer = connections[0].get_peer_info() @@ -42,12 +38,11 @@ def on_connection(tcp_pipe: TcpPipe): await client.terminate() finally: - await server.close() + await server.stop() @pytest.mark.asyncio async def test_tcp_data_exchange(): - """Test bidirectional data exchange over TCP.""" server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() @@ -67,7 +62,6 @@ async def on_server_data(data: bytes): server.on_connection(on_connection) try: - # Create client client = await TcpPipe.create_connection(server_host, server_port) async def on_client_data(data: bytes): @@ -75,35 +69,30 @@ async def on_client_data(data: bytes): client.on("data", on_client_data) - # Wait for connection to be established await asyncio.sleep(0.1) assert server_connection is not None - # Send data from client to server await client.write(b"Hello from client") await asyncio.sleep(0.1) assert server_received == [b"Hello from client"] - # Send data from server to client await server_connection.write(b"Hello from server") await asyncio.sleep(0.1) assert client_received == [b"Hello from server"] - # Send multiple messages separately server_received.clear() client_received.clear() await client.write(b"Message 1") - await asyncio.sleep(0.05) # Small delay between messages + await asyncio.sleep(0.05) await client.write(b"Message 2") await asyncio.sleep(0.1) await server_connection.write(b"Response 1") - await asyncio.sleep(0.05) # Small delay between messages + await asyncio.sleep(0.05) await server_connection.write(b"Response 2") await asyncio.sleep(0.1) - # Check that messages were received (they might be combined due to TCP buffering) server_data = b"".join(server_received) client_data = b"".join(client_received) @@ -115,22 +104,18 @@ async def on_client_data(data: bytes): await client.terminate() finally: - await server.close() + await server.stop() @pytest.mark.asyncio async def test_tcp_connection_errors(): - """Test TCP connection error handling.""" - # Test connection to non-existent server with pytest.raises(ConnectionError): await TcpPipe.create_connection("127.0.0.1", 12345, timeout=0.1) - # Test writing to disconnected client client = TcpPipe() with pytest.raises(ConnectionError): await client.write(b"test") - # Test double connection server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() @@ -143,12 +128,11 @@ async def test_tcp_connection_errors(): await client.terminate() finally: - await server.close() + await server.stop() @pytest.mark.asyncio async def test_tcp_connection_events(): - """Test TCP connection events (connect, close, error).""" server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() @@ -157,7 +141,6 @@ async def test_tcp_connection_events(): server.on_connection(lambda conn: events.append("server_connect")) try: - # Create client with event handlers client = TcpPipe() async def on_client_connect(): @@ -173,36 +156,30 @@ async def on_client_error(error): client.on("close", on_client_close) client.on("error", on_client_error) - # Connect await client.connect(server_host, server_port) - await asyncio.sleep(0.2) # Give more time for events to propagate + await asyncio.sleep(0.2) assert "client_connect" in events - print(events) assert "server_connect" in events - # Close connection events_before_close = len(events) await client.terminate("test_reason") - await asyncio.sleep(0.2) # Give time for close events + await asyncio.sleep(0.2) - # Check that close event was added assert len(events) > events_before_close close_events = [e for e in events if isinstance(e, tuple) and e[0] == "client_close"] assert len(close_events) > 0 finally: - await server.close() + await server.stop() @pytest.mark.asyncio async def test_tcp_client_connection_tracking(): - """Test handling multiple simultaneous TCP connections.""" server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() try: - # Create multiple clients clients: List[TcpPipe] = [] server.on_connection(lambda conn: print(f"New connection: {conn}")) for i in range(5): @@ -210,12 +187,10 @@ async def test_tcp_client_connection_tracking(): client.on("close", lambda: print(f"Connection[{i}] closed")) clients.append(client) - await asyncio.sleep(0.5) # Give time for all connections to be registered + await asyncio.sleep(0.5) - # Check that all connections are registered assert len(server.get_connections()) == 5 - # Close all clients for client in clients: await client.terminate() @@ -224,17 +199,15 @@ async def test_tcp_client_connection_tracking(): assert len(server.get_connections()) == 0, "Connections not clean uped" finally: - await server.close() + await server.stop() @pytest.mark.asyncio async def test_tcp_client_connection_tracking_self(): - """Test handling multiple simultaneous TCP connections.""" server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() try: - # Create multiple clients clients: List[TcpPipe] = [] server.on_connection(lambda conn: print(f"New connection: {conn}")) for i in range(5): @@ -242,12 +215,10 @@ async def test_tcp_client_connection_tracking_self(): client.on("close", lambda: print(f"Connection[{i}] closed")) clients.append(client) - await asyncio.sleep(0.5) # Give time for all connections to be registered + await asyncio.sleep(0.5) - # Check that all connections are registered assert len(server.get_connections()) == 5 - # Close all clients for duplex in server.get_connections(): await duplex.terminate() @@ -256,12 +227,11 @@ async def test_tcp_client_connection_tracking_self(): assert len(server.get_connections()) == 0, "Connections not clean uped" finally: - await server.close() + await server.stop() @pytest.mark.asyncio async def test_tcp_large_data_transfer(): - """Test transferring large amounts of data over TCP.""" server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() @@ -283,12 +253,10 @@ async def on_data(data: bytes): client = await TcpPipe.create_connection(server_host, server_port) await asyncio.sleep(0.1) - # Send large data (100KB instead of 1MB for faster testing) large_data = b"x" * (100 * 1024 * 1024) await client.write(large_data) - # Wait for all data to be received - timeout = 5.0 # 5 second timeout + timeout = 5.0 start_time = asyncio.get_event_loop().time() while len(received_data) < len(large_data): if asyncio.get_event_loop().time() - start_time > timeout: @@ -300,12 +268,11 @@ async def on_data(data: bytes): await client.terminate() finally: - await server.close() + await server.stop() @pytest.mark.asyncio async def test_tcp_server_context_manager(): - """Test using TcpServer as a context manager.""" async with await SimpleTcpServer.create("127.0.0.1", 0) as server: server_host, server_port = server.get_address() @@ -314,38 +281,32 @@ async def test_tcp_server_context_manager(): await client.terminate() - # Server should be closed automatically - @pytest.mark.asyncio async def test_tcp_invalid_data_types(): - """Test error handling for invalid data types.""" server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() try: client = await TcpPipe.create_connection(server_host, server_port) - # Test writing non-bytes data with pytest.raises(TypeError): - await client.write("string data") # Should be bytes + await client.write("string data") with pytest.raises(TypeError): - await client.write(123) # Should be bytes + await client.write(123) - # Test writing valid data types - await client.write(b"bytes data") # Should work - await client.write(bytearray(b"bytearray data")) # Should work + await client.write(b"bytes data") + await client.write(bytearray(b"bytearray data")) await client.terminate() finally: - await server.close() + await server.stop() @pytest.mark.asyncio async def test_tcp_inmemory_pair_bidirectional_exchange(): - """Test TcpPipe.create_inmemory_pair produces connected pipes that can exchange data.""" client_pipe, server_pipe = await TcpPipe.create_inmemory_pair() try: @@ -373,7 +334,6 @@ async def test_tcp_inmemory_pair_bidirectional_exchange(): @pytest.mark.asyncio async def test_tcp_shutdown_keeps_active_connections(): - """Test shutting down the listener doesn't drop existing connections.""" server = await SimpleTcpServer.create("127.0.0.1", 0) server_host, server_port = server.get_address() @@ -392,7 +352,6 @@ async def on_connection(tcp_pipe: TcpPipe): await server.shutdown() - # Existing connection should still be usable client_received = asyncio.Queue() server_received = asyncio.Queue() @@ -407,7 +366,6 @@ async def on_connection(tcp_pipe: TcpPipe): client_data = await asyncio.wait_for(client_received.get(), timeout=DEFAULT_TIMEOUT) assert client_data == b"still-alive-2" - # New connections should fail while listener is shut down with pytest.raises(ConnectionError) as exc_info: await TcpPipe.create_connection(server_host, server_port, timeout=0.2) error_text = str(exc_info.value).lower() @@ -418,7 +376,3 @@ async def on_connection(tcp_pipe: TcpPipe): finally: await server.stop() - - -if __name__ == "__main__": - pytest.main(["-v", __file__]) diff --git a/tests/tcp/test_tcp_pipe_errors.py b/tests/tcp/test_tcp_pipe_errors.py new file mode 100644 index 0000000..43d1940 --- /dev/null +++ b/tests/tcp/test_tcp_pipe_errors.py @@ -0,0 +1,33 @@ +import pytest + +from cbor_rpc.tcp.tcp import TcpPipe +from tests.helpers.simple_tcp_server import SimpleTcpServer + + +@pytest.mark.asyncio +async def test_tcp_pipe_error_paths(): + assert TcpPipe().get_peer_info() is None + assert TcpPipe().get_local_info() is None + + class BadWriter: + def get_extra_info(self, _name: str): + raise RuntimeError("boom") + + pipe = TcpPipe() + pipe._connected = True + pipe._writer = BadWriter() + assert pipe.get_peer_info() is None + assert pipe.get_local_info() is None + + server = await SimpleTcpServer.create("127.0.0.1", 0) + host, port = server.get_address() + await server.stop() + + with pytest.raises(ConnectionError): + await TcpPipe.create_connection(host, port) + + client, _server_pipe = await TcpPipe.create_inmemory_pair() + with pytest.raises(ConnectionError): + await client.connect("127.0.0.1", port) + + await client.terminate() diff --git a/tests/test_event_emitter.py b/tests/test_event_emitter.py deleted file mode 100644 index e1a270b..0000000 --- a/tests/test_event_emitter.py +++ /dev/null @@ -1,386 +0,0 @@ -import pytest -import asyncio -from typing import Any, Callable -from cbor_rpc.event.emitter import AbstractEmitter - - -@pytest.mark.asyncio -async def test_on_and_emit(): - """Test that subscribers registered with 'on' are called by '_emit' in registration order.""" - - class TestEmitter(AbstractEmitter): - pass - - emitter = TestEmitter() - events = [] - - async def async_handler1(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_handler1_{data}") - - def sync_handler1(data: Any): - events.append(f"sync_handler1_{data}") - - async def async_handler2(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_handler2_{data}") - - # Register subscribers - emitter.on("test", async_handler1) - emitter.on("test", sync_handler1) - emitter.on("test", async_handler2) - - # Emit event - emitter._emit("test", "event1") - await asyncio.sleep(0.02) # Allow async handlers to complete - - # Verify all subscribers ran (order may vary due to concurrency) - expected = [ - f"async_handler1_event1", - f"sync_handler1_event1", - f"async_handler2_event1", - ] - assert sorted(events) == sorted(expected), f"Expected {expected}, got {events}" - - -@pytest.mark.asyncio -async def test_pipeline_and_notify(): - """Test that pipelines run before subscribers in '_notify', respecting registration order.""" - - class TestEmitter(AbstractEmitter): - pass - - emitter = TestEmitter() - events = [] - - async def async_handler1(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_handler1_{data}") - - def sync_handler1(data: Any): - events.append(f"sync_handler1_{data}") - - async def async_pipeline1(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_pipeline1_{data}") - - def sync_pipeline1(data: Any): - events.append(f"sync_pipeline1_{data}") - - async def async_pipeline2(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_pipeline2_{data}") - - # Register handlers and pipelines - emitter.on("test", async_handler1) - emitter.on("test", sync_handler1) - emitter.pipeline("test", async_pipeline1) - emitter.pipeline("test", sync_pipeline1) - emitter.pipeline("test", async_pipeline2) - - # Notify event - await emitter._notify("test", "event2") - await asyncio.sleep(0.02) # Allow async pipelines and handlers to complete - - # Verify pipelines run before subscribers - expected_pipelines = [ - f"async_pipeline1_event2", - f"sync_pipeline1_event2", - f"async_pipeline2_event2", - ] - expected_subscribers = [f"async_handler1_event2", f"sync_handler1_event2"] - pipeline_indices = [events.index(e) for e in expected_pipelines if e in events] - subscriber_indices = [events.index(e) for e in expected_subscribers if e in events] - assert all( - p < s for p in pipeline_indices for s in subscriber_indices - ), f"Pipelines {expected_pipelines} should precede subscribers {expected_subscribers} in {events}" - assert sorted(events) == sorted( - expected_pipelines + expected_subscribers - ), f"Expected {expected_pipelines + expected_subscribers}, got {events}" - - -@pytest.mark.asyncio -async def test_unsubscribe(): - """Test that unsubscribing a handler removes it from the subscriber list.""" - - class TestEmitter(AbstractEmitter): - pass - - emitter = TestEmitter() - events = [] - - async def async_handler1(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_handler1_{data}") - - def sync_handler1(data: Any): - events.append(f"sync_handler1_{data}") - - # Register and unsubscribe - emitter.on("test", async_handler1) - emitter.on("test", sync_handler1) - emitter.unsubscribe("test", async_handler1) - - # Emit event - emitter._emit("test", "event3") - await asyncio.sleep(0.02) # Allow async handlers to complete (none in this case) - - # Verify only remaining subscriber ran - expected = [f"sync_handler1_event3"] - assert events == expected, f"Expected {expected}, got {events}" - - -@pytest.mark.asyncio -async def test_replace_on_handler(): - """Test that replace_on_handler sets only the new handler for the event.""" - - class TestEmitter(AbstractEmitter): - pass - - emitter = TestEmitter() - events = [] - - async def async_handler1(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_handler1_{data}") - - def sync_handler1(data: Any): - events.append(f"sync_handler1_{data}") - - # Register and replace - emitter.on("test", sync_handler1) - emitter.replace_on_handler("test", async_handler1) - - # Emit event - emitter._emit("test", "event4") - await asyncio.sleep(0.02) # Allow async handler to complete - - # Verify only the replaced handler ran - expected = [f"async_handler1_event4"] - assert events == expected, f"Expected {expected}, got {events}" - - -@pytest.mark.asyncio -async def test_pipeline_failure(): - """Test that '_notify' raises an exception if a pipeline fails and doesn't call subscribers.""" - - class TestEmitter(AbstractEmitter): - pass - - emitter = TestEmitter() - events = [] - - async def async_pipeline1(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_pipeline1_{data}") - raise ValueError("Pipeline failed") - - def sync_handler1(data: Any): - events.append(f"sync_handler1_{data}") - - # Register pipeline and subscriber - emitter.pipeline("test", async_pipeline1) - emitter.on("test", sync_handler1) - - # Notify with failing pipeline - try: - await emitter._notify("test", "event5") - assert False, "Exception not thrown" - except ValueError: - pass - - # Verify only the pipeline ran, not the subscriber - expected = [f"async_pipeline1_event5"] - assert events == expected, f"Expected {expected}, got {events}" - - -@pytest.mark.asyncio -async def test_multiple_event_types(): - """Test that only handlers for the triggered event type are called.""" - - class TestEmitter(AbstractEmitter): - pass - - emitter = TestEmitter() - events = [] - - # Handlers for event_a - async def async_handler_a(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_handler_a_{data}") - - def sync_handler_a(data: Any): - events.append(f"sync_handler_a_{data}") - - # Handlers for event_b - async def async_handler_b(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_handler_b_{data}") - - def sync_handler_b(data: Any): - events.append(f"sync_handler_b_{data}") - - # Pipelines for event_a - async def async_pipeline_a(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_pipeline_a_{data}") - - # Pipelines for event_b - def sync_pipeline_b(data: Any): - events.append(f"sync_pipeline_b_{data}") - - # Register handlers and pipelines for different event types - emitter.on("event_a", async_handler_a) - emitter.on("event_a", sync_handler_a) - emitter.pipeline("event_a", async_pipeline_a) - emitter.on("event_b", async_handler_b) - emitter.on("event_b", sync_handler_b) - emitter.pipeline("event_b", sync_pipeline_b) - - # Test _emit for event_a - events.clear() - emitter._emit("event_a", "data_a") - await asyncio.sleep(0.02) # Allow async handlers to complete - expected = [f"async_handler_a_data_a", f"sync_handler_a_data_a"] - assert sorted(events) == sorted(expected), f"Expected {expected}, got {events}" - - # Test _notify for event_a - events.clear() - await emitter._notify("event_a", "data_a2") - await asyncio.sleep(0.02) # Allow async pipelines and handlers to complete - expected_pipelines = [f"async_pipeline_a_data_a2"] - expected_subscribers = [f"async_handler_a_data_a2", f"sync_handler_a_data_a2"] - pipeline_indices = [events.index(e) for e in expected_pipelines if e in events] - subscriber_indices = [events.index(e) for e in expected_subscribers if e in events] - assert all( - p < s for p in pipeline_indices for s in subscriber_indices - ), f"Pipelines {expected_pipelines} should precede subscribers {expected_subscribers} in {events}" - assert sorted(events) == sorted( - expected_pipelines + expected_subscribers - ), f"Expected {expected_pipelines + expected_subscribers}, got {events}" - - # Test _emit for event_b - events.clear() - emitter._emit("event_b", "data_b") - await asyncio.sleep(0.02) # Allow async handlers to complete - expected = [f"async_handler_b_data_b", f"sync_handler_b_data_b"] - assert sorted(events) == sorted(expected), f"Expected {expected}, got {events}" - - # Test _notify for event_b - events.clear() - await emitter._notify("event_b", "data_b2") - await asyncio.sleep(0.02) # Allow async handlers to complete - expected_pipelines = [f"sync_pipeline_b_data_b2"] - expected_subscribers = [f"async_handler_b_data_b2", f"sync_handler_b_data_b2"] - pipeline_indices = [events.index(e) for e in expected_pipelines if e in events] - subscriber_indices = [events.index(e) for e in expected_subscribers if e in events] - assert all( - p < s for p in pipeline_indices for s in subscriber_indices - ), f"Pipelines {expected_pipelines} should precede subscribers {expected_subscribers} in {events}" - assert sorted(events) == sorted( - expected_pipelines + expected_subscribers - ), f"Expected {expected_pipelines + expected_subscribers}, got {events}" - - -@pytest.mark.asyncio -async def test_background_task_failure(): - """Test that background task failures in '_emit' don't affect other subscribers.""" - - class TestEmitter(AbstractEmitter): - pass - - emitter = TestEmitter() - events = [] - - async def async_handler1(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_handler1_{data}") - - async def async_handler2(data: Any): - await asyncio.sleep(0.01) - events.append(f"async_handler2_{data}") - raise ValueError("async_handler2 failed") - - def sync_handler(data: Any): - events.append(f"sync_handler_{data}") - - # Register subscribers, including one that fails - emitter.on("test", async_handler1) - emitter.on("test", async_handler2) - emitter.on("test", sync_handler) - - # Emit event - emitter._emit("test", "event6") - await asyncio.sleep(0.02) # Allow async handlers to complete - - # Verify all subscribers ran despite the failure - expected = [ - f"async_handler1_event6", - f"async_handler2_event6", - f"sync_handler_event6", - ] - assert sorted(events) == sorted(expected), f"Expected {expected}, got {events}" - - -@pytest.mark.asyncio -async def test_slow_emit_does_not_block_notify(): - """Test that a slow handler in _emit does not block a subsequent _notify call.""" - - class TestEmitter(AbstractEmitter): - pass - - emitter = TestEmitter() - events = [] - - # Slow async handler for _emit - async def slow_handler(data: Any): - await asyncio.sleep(0.8) - events.append(f"slow_handler_{data}") - - # Fast handler for _notify - def fast_notify_handler(data: Any): - events.append(f"fast_notify_handler_{data}") - - def fast_notify_pipeline(data: Any): - events.append(f"fast_notify_pipeline_{data}") - - # Register handlers - emitter.on("test_emit", slow_handler) - emitter.on("test_notify", fast_notify_handler) - emitter.pipeline("test_notify", fast_notify_pipeline) - - # Start _emit but don't wait for it to finish - emitter._emit("test_emit", "data_emit") - - # Wait briefly before triggering _notify - await asyncio.sleep(0.1) - - # Trigger and await _notify - await emitter._notify("test_notify", "data_notify") - - # Give time for _notify to complete before slow_handler finishes - await asyncio.sleep(0.4) - - # āœ… Confirm that notify events are already present - assert "fast_notify_pipeline_data_notify" in events - assert "fast_notify_handler_data_notify" in events - assert "slow_handler_data_emit" not in events # not yet complete - - # Wait for slow handler to complete - await asyncio.sleep(0.5) - - # Now validate the full event order - slow_index = events.index("slow_handler_data_emit") - pipeline_index = events.index("fast_notify_pipeline_data_notify") - handler_index = events.index("fast_notify_handler_data_notify") - - assert ( - pipeline_index < slow_index and handler_index < slow_index - ), f"_notify handlers [{pipeline_index}, {handler_index}] should run before slow _emit [{slow_index}]" - - expected = { - "fast_notify_pipeline_data_notify", - "fast_notify_handler_data_notify", - "slow_handler_data_emit", - } - assert set(events) == expected, f"Expected {expected}, got {set(events)}" diff --git a/tests/test_stdio_rpc.py b/tests/test_stdio_rpc.py deleted file mode 100644 index 46b599d..0000000 --- a/tests/test_stdio_rpc.py +++ /dev/null @@ -1,42 +0,0 @@ -import asyncio -import pytest -import sys -from cbor_rpc.stdio.stdio_pipe import StdioPipe - - -@pytest.mark.asyncio -async def test_stdtio_read_write(): - """ - Tests the StdioPipe.start_process method by writing and reading data 10 times. - """ - pipe = await StdioPipe.start_process("/bin/bash", "-c", "cat -") - - # List to collect received data - received_data = [] - future = asyncio.Future() - - def on_data(data): - received_data.append(data) - # Complete the future after receiving 10 data events - if len(received_data) == 10: - future.set_result(None) - - pipe.pipeline("data", on_data) - - # Write 10 unique data chunks - test_data = [f"Test data {i}\n".encode("utf-8") for i in range(10)] - for data in test_data: - await pipe.write(data) - # Brief sleep to allow the subprocess to process the data - await asyncio.sleep(0.01) - - # Wait for all 10 data events - await future - - # Assert that all received data matches the sent data - assert len(received_data) == 10, f"Expected 10 data events, got {len(received_data)}" - for i, (sent, received) in enumerate(zip(test_data, received_data)): - assert received == sent, f"Mismatch at index {i}: expected {sent!r}, got {received!r}" - - # Terminate the pipe - pipe.terminate() diff --git a/tests/test_cbor_transformer.py b/tests/transformer/test_cbor_transformer.py similarity index 85% rename from tests/test_cbor_transformer.py rename to tests/transformer/test_cbor_transformer.py index dbb4ce6..77f23bb 100644 --- a/tests/test_cbor_transformer.py +++ b/tests/transformer/test_cbor_transformer.py @@ -1,11 +1,13 @@ -import pytest -import pytest_asyncio import asyncio + import cbor2 -from cbor_rpc.transformer.cbor_transformer import CborTransformer, CborStreamTransformer +import pytest +import pytest_asyncio + from cbor_rpc.pipe.event_pipe import EventPipe from cbor_rpc.transformer.base.base_exception import NeedsMoreDataException from cbor_rpc.transformer.base.event_transformer_pipe import EventTransformerPipe +from cbor_rpc.transformer.cbor_transformer import CborTransformer, CborStreamTransformer from tests.helpers.timeout_queue import TimeoutQueue @@ -32,12 +34,11 @@ def server_raw(_raw_pipe_pair): @pytest.fixture def client_cbor(client_raw): cbor_transformer = CborTransformer() - return cbor_transformer.applyTransformer(client_raw) + return cbor_transformer.apply_transformer(client_raw) @pytest.mark.asyncio class TestCborTransformer: - async def test_cbor_transformer_end_to_end_simple_dict(self, client_raw, server_raw, client_cbor): client_transformed_pipe = client_cbor @@ -48,7 +49,6 @@ async def test_cbor_transformer_end_to_end_simple_dict(self, client_raw, server_ await client_transformed_pipe.write(original_data) encoded_data_received_by_server = await received_data_queue.get() - # Verify the raw bytes received by the server are valid CBOR decoded_by_server = cbor2.loads(encoded_data_received_by_server) assert decoded_by_server == original_data @@ -66,18 +66,15 @@ async def test_cbor_transformer_decoding_error_on_read(self, server_raw, client_ error_queue = TimeoutQueue() client_transformed_pipe.on("error", error_queue.put_nowait) - # Simulate server sending incomplete CBOR bytes - incomplete_cbor_bytes = b"\x83\x01\x02" # Incomplete array, missing one element + incomplete_cbor_bytes = b"\x83\x01\x02" with pytest.raises(cbor2.CBORDecodeError): await server_raw.write(incomplete_cbor_bytes) - # The CborTransformer should now raise CBORDecodeError for incomplete data error = await asyncio.wait_for(error_queue.get(), timeout=1) assert isinstance(error, cbor2.CBORDecodeError) assert "Incomplete CBOR data for non-stream transformer" in str(error) - # Send truly invalid data - truly_invalid_cbor = b"\x1f" # Unknown unsigned integer subtype + truly_invalid_cbor = b"\x1f" with pytest.raises(cbor2.CBORDecodeError): await server_raw.write(truly_invalid_cbor) error = await asyncio.wait_for(error_queue.get(), timeout=1) @@ -90,7 +87,6 @@ async def test_cbor_transformer_non_bytes_data(self, server_raw, client_cbor): error_queue = TimeoutQueue() client_transformed_pipe.on("error", error_queue.put_nowait) - # Simulate server sending non-bytes data non_bytes_data = "not cbor" with pytest.raises(TypeError): await server_raw.write(non_bytes_data) @@ -162,10 +158,9 @@ async def test_cbor_transformer_close_propagation_and_write_after_close(self, cl @pytest.mark.asyncio class TestCborStreamTransformer: - async def test_cbor_stream_transformer_single_object(self, client_raw, server_raw): cbor_stream_transformer = CborStreamTransformer() - client_transformed_pipe = cbor_stream_transformer.applyTransformer(client_raw) + client_transformed_pipe = cbor_stream_transformer.apply_transformer(client_raw) received_data_queue = TimeoutQueue() client_transformed_pipe.on("data", received_data_queue.put_nowait) @@ -178,7 +173,7 @@ async def test_cbor_stream_transformer_single_object(self, client_raw, server_ra async def test_cbor_stream_transformer_concatenated_objects(self, client_raw, server_raw): cbor_stream_transformer = CborStreamTransformer() - client_transformed_pipe = cbor_stream_transformer.applyTransformer(client_raw) + client_transformed_pipe = cbor_stream_transformer.apply_transformer(client_raw) received_data_queue = TimeoutQueue() client_transformed_pipe.on("data", received_data_queue.put_nowait) @@ -189,14 +184,11 @@ async def test_cbor_stream_transformer_concatenated_objects(self, client_raw, se concatenated_cbor = cbor2.dumps(obj1) + cbor2.dumps(obj2) + cbor2.dumps(obj3) - # Send all concatenated data at once await server_raw.write(concatenated_cbor) - # Trigger additional decode passes to drain buffered objects await server_raw.write(b"") await server_raw.write(b"") - # Expect to receive objects one by one decoded1 = await received_data_queue.get() decoded2 = await received_data_queue.get() decoded3 = await received_data_queue.get() @@ -207,7 +199,7 @@ async def test_cbor_stream_transformer_concatenated_objects(self, client_raw, se async def test_cbor_stream_transformer_fragmented_objects(self, client_raw, server_raw): cbor_stream_transformer = CborStreamTransformer() - client_transformed_pipe = cbor_stream_transformer.applyTransformer(client_raw) + client_transformed_pipe = cbor_stream_transformer.apply_transformer(client_raw) received_data_queue = TimeoutQueue() client_transformed_pipe.on("data", received_data_queue.put_nowait) @@ -217,9 +209,7 @@ async def test_cbor_stream_transformer_fragmented_objects(self, client_raw, serv obj = {"long_message": "a" * 100} cbor_bytes = cbor2.dumps(obj) - # Send in fragments await server_raw.write(cbor_bytes[:10]) - # Should raise NeedsMoreDataException internally, but not emit an error with pytest.raises(asyncio.TimeoutError): await asyncio.wait_for(received_data_queue.get(), timeout=0.1) assert error_queue.empty() @@ -236,7 +226,7 @@ async def test_cbor_stream_transformer_fragmented_objects(self, client_raw, serv async def test_cbor_stream_transformer_mixed_fragmented_and_concatenated(self, client_raw, server_raw): cbor_stream_transformer = CborStreamTransformer() - client_transformed_pipe = cbor_stream_transformer.applyTransformer(client_raw) + client_transformed_pipe = cbor_stream_transformer.apply_transformer(client_raw) received_data_queue = TimeoutQueue() client_transformed_pipe.on("data", received_data_queue.put_nowait) @@ -249,16 +239,13 @@ async def test_cbor_stream_transformer_mixed_fragmented_and_concatenated(self, c cbor_bytes2 = cbor2.dumps(obj2) cbor_bytes3 = cbor2.dumps(obj3) - # Send first object fragmented await server_raw.write(cbor_bytes1[:5]) await server_raw.write(cbor_bytes1[5:]) decoded1 = await received_data_queue.get() assert decoded1 == obj1 - # Send second and third concatenated await server_raw.write(cbor_bytes2 + cbor_bytes3) - # Trigger an extra decode pass to drain buffered third object await server_raw.write(b"") decoded2 = await received_data_queue.get() @@ -268,7 +255,7 @@ async def test_cbor_stream_transformer_mixed_fragmented_and_concatenated(self, c async def test_cbor_stream_transformer_invalid_data_in_stream(self, client_raw, server_raw): cbor_stream_transformer = CborStreamTransformer() - client_transformed_pipe = cbor_stream_transformer.applyTransformer(client_raw) + client_transformed_pipe = cbor_stream_transformer.apply_transformer(client_raw) received_data_queue = TimeoutQueue() client_transformed_pipe.on("data", received_data_queue.put_nowait) @@ -276,17 +263,15 @@ async def test_cbor_stream_transformer_invalid_data_in_stream(self, client_raw, client_transformed_pipe.on("error", error_queue.put_nowait) obj1 = {"valid": True} - invalid_bytes = b"\x1f" # Unknown unsigned integer subtype + invalid_bytes = b"\x1f" obj2 = {"another": "valid"} await server_raw.write(cbor2.dumps(obj1)) await server_raw.write(invalid_bytes + cbor2.dumps(obj2)) - # First valid object should be decoded decoded1 = await received_data_queue.get() assert decoded1 == obj1 - # The transformer should recover and decode the next valid object without emitting an error decoded2 = await received_data_queue.get() assert decoded2 == obj2 @@ -295,12 +280,11 @@ async def test_cbor_stream_transformer_invalid_data_in_stream(self, client_raw, async def test_cbor_stream_transformer_non_bytes_data(self, client_raw, server_raw): cbor_stream_transformer = CborStreamTransformer() - client_transformed_pipe = cbor_stream_transformer.applyTransformer(client_raw) + client_transformed_pipe = cbor_stream_transformer.apply_transformer(client_raw) error_queue = TimeoutQueue() client_transformed_pipe.on("error", error_queue.put_nowait) - # Simulate server sending non-bytes data non_bytes_data = "not cbor" with pytest.raises(TypeError): await server_raw.write(non_bytes_data) @@ -311,7 +295,7 @@ async def test_cbor_stream_transformer_non_bytes_data(self, client_raw, server_r async def test_cbor_stream_transformer_close_propagation_and_write_after_close(self, client_raw): cbor_stream_transformer = CborStreamTransformer() - client_transformed_pipe = cbor_stream_transformer.applyTransformer(client_raw) + client_transformed_pipe = cbor_stream_transformer.apply_transformer(client_raw) close_queue = TimeoutQueue() client_transformed_pipe.on("close", lambda *args: close_queue.put_nowait(True)) @@ -322,3 +306,32 @@ async def test_cbor_stream_transformer_close_propagation_and_write_after_close(s result = await client_transformed_pipe.write({"after": "close"}) assert result is False + + +@pytest.mark.asyncio +async def test_cbor_stream_transformer_paths(monkeypatch): + transformer = CborStreamTransformer() + + with pytest.raises(NeedsMoreDataException): + await transformer.decode(None) + + with pytest.raises(TypeError): + await transformer.decode("bad") + + good = cbor2.dumps({"a": 1}) + recovered = await transformer.decode(b"\xff" + good) + assert recovered == {"a": 1} + + with pytest.raises(cbor2.CBORDecodeError): + await transformer.decode(b"\xff") + + class BadDecoder: + def __init__(self, _stream): + pass + + def decode(self): + raise ValueError("boom") + + monkeypatch.setattr(cbor2, "CBORDecoder", BadDecoder) + with pytest.raises(ValueError): + await transformer.decode(good) diff --git a/tests/transformer/test_event_transformer_pipe.py b/tests/transformer/test_event_transformer_pipe.py new file mode 100644 index 0000000..3376c2f --- /dev/null +++ b/tests/transformer/test_event_transformer_pipe.py @@ -0,0 +1,83 @@ +import asyncio +from typing import Any, List + +import pytest + +from cbor_rpc.pipe.event_pipe import EventPipe +from cbor_rpc.transformer.base.base_exception import NeedsMoreDataException +from cbor_rpc.transformer.base.event_transformer_pipe import EventTransformerPipe +from cbor_rpc.transformer.base.transformer_base import AsyncTransformer + + +class DummyAsyncTransformer(AsyncTransformer[Any, Any]): + async def encode(self, data: Any) -> Any: + if data == "encode_error": + raise ValueError("encode boom") + return f"enc:{data}" + + async def decode(self, data: Any) -> Any: + if asyncio.iscoroutine(data): + data = await data + if data == "need_more": + raise NeedsMoreDataException() + if data == "decode_error": + raise ValueError("decode boom") + return f"dec:{data}" + + +@pytest.mark.asyncio +async def test_event_transformer_pipe_data_and_errors(): + base_a, base_b = EventPipe.create_inmemory_pair() + + class EventTransformer(DummyAsyncTransformer): + async def decode(self, data: Any) -> Any: + if data == "need_more": + raise NeedsMoreDataException() + if data == "decode_error": + raise ValueError("boom") + return await super().decode(data) + + transformer = EventTransformer() + tpipe = EventTransformerPipe(base_a, transformer) + + received: List[Any] = [] + errors: List[Any] = [] + + def on_data(data: Any) -> None: + received.append(data) + + def on_error(err: Exception) -> None: + errors.append(err) + + tpipe.on("data", on_data) + tpipe.on("error", on_error) + + await base_b.write("need_more") + await asyncio.sleep(0.01) + assert received == [] + + await base_b.write("ok") + await asyncio.sleep(0.01) + assert received == ["dec:ok"] + + with pytest.raises(ValueError): + await base_b.write("decode_error") + await asyncio.sleep(0.01) + assert errors + + +@pytest.mark.asyncio +async def test_event_transformer_pipe_write_error(): + base_a, _base_b = EventPipe.create_inmemory_pair() + transformer = DummyAsyncTransformer() + tpipe = EventTransformerPipe(base_a, transformer) + + errors: List[Any] = [] + + def on_error(err: Exception) -> None: + errors.append(err) + + tpipe.on("error", on_error) + ok = await tpipe.write("encode_error") + assert ok is False + assert errors diff --git a/tests/test_json_transformer.py b/tests/transformer/test_json_transformer.py similarity index 63% rename from tests/test_json_transformer.py rename to tests/transformer/test_json_transformer.py index 7ae6b6f..f6142aa 100644 --- a/tests/test_json_transformer.py +++ b/tests/transformer/test_json_transformer.py @@ -1,13 +1,14 @@ -import pytest -import json import asyncio -from cbor_rpc.tcp.tcp import TcpPipe -from cbor_rpc.transformer.json_transformer import JsonTransformer -from cbor_rpc.pipe.event_pipe import EventPipe +import json + +import pytest + from cbor_rpc.pipe.aio_pipe import AioPipe -from tests.helpers.simple_pipe import SimplePipe +from cbor_rpc.pipe.event_pipe import EventPipe +from cbor_rpc.tcp.tcp import TcpPipe from cbor_rpc.transformer.base.base_exception import NeedsMoreDataException from cbor_rpc.transformer.base.event_transformer_pipe import EventTransformerPipe +from cbor_rpc.transformer.json_transformer import JsonTransformer DEFAULT_TIMEOUT = 2.0 @@ -35,7 +36,7 @@ async def pipe_pair(request): async def json_pipe(pipe_pair): client_raw_pipe, server_raw_pipe = pipe_pair json_transformer = JsonTransformer() - client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) + client_transformed_pipe = json_transformer.apply_transformer(client_raw_pipe) return client_raw_pipe, server_raw_pipe, client_transformed_pipe, json_transformer @@ -43,68 +44,44 @@ async def json_pipe(pipe_pair): async def json_pipe_ascii(pipe_pair): client_raw_pipe, server_raw_pipe = pipe_pair json_transformer = JsonTransformer(encoding="ascii") - client_transformed_pipe = json_transformer.applyTransformer(client_raw_pipe) + client_transformed_pipe = json_transformer.apply_transformer(client_raw_pipe) return client_raw_pipe, server_raw_pipe, client_transformed_pipe, json_transformer @pytest.mark.asyncio class TestJsonTransformerPipeInteraction: - async def test_json_transformer_end_to_end_simple_dict(self, json_pipe): client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe assert isinstance(client_transformed_pipe, EventTransformerPipe) - # Use a queue to capture data emitted by the server_raw_pipe received_data_queue = asyncio.Queue() - server_raw_pipe.on("error", lambda e: print("Server raw pipe error:", e)) - server_raw_pipe.on("data", received_data_queue.put_nowait) - # Data to send original_data = {"message": "Hello, world!", "number": 123} - # Write original_data to the transformed client pipe - # This should encode the data and send it through client_raw_pipe to server_raw_pipe await client_transformed_pipe.write(original_data) - # Wait for the encoded data to arrive at the server_raw_pipe - # The server_raw_pipe receives the *encoded* data (bytes) - try: - # Wait for data for up to 5 seconds - encoded_data_received_by_server = await asyncio.wait_for(received_data_queue.get(), timeout=2.0) - except asyncio.TimeoutError: - print("No data received within 5 seconds") - assert False, "Test failed due to timeout waiting for data" + encoded_data_received_by_server = await asyncio.wait_for(received_data_queue.get(), timeout=2.0) - # Manually decode the data received by the server_raw_pipe to verify it's JSON bytes decoded_by_server = json.loads(encoded_data_received_by_server.decode("utf-8")) assert decoded_by_server == original_data - # Now, let's test the reverse: server sends data, client receives decoded data - # Use a queue to capture data emitted by the client_transformed_pipe client_received_data_queue = asyncio.Queue() client_transformed_pipe.on("data", client_received_data_queue.put_nowait) - # Data to send from server response_data = {"status": "success", "code": 200} - - # Server_raw_pipe writes the *encoded* data (as if it received it from a client and is sending a response) - # This data will go through client_raw_pipe and then be decoded by client_transformed_pipe await server_raw_pipe.write(json.dumps(response_data).encode("utf-8")) - # Wait for the decoded data to arrive at the client_transformed_pipe decoded_data_received_by_client = await asyncio.wait_for( client_received_data_queue.get(), timeout=2.0, ) assert decoded_data_received_by_client == response_data - # Clean up await client_raw_pipe.terminate() await server_raw_pipe.terminate() async def test_json_transformer_end_to_end_unicode_characters(self, json_pipe): - ## TODO: This is taking forever when using TCP pipes, investigate why. It works fine with EventPipe and AioPipe. client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe received_data_queue = asyncio.Queue() @@ -130,38 +107,30 @@ async def test_json_transformer_end_to_end_unicode_characters(self, json_pipe): assert decoded_data_received_by_client == response_data async def test_json_transformer_encoding_error_on_write(self, json_pipe_ascii): - # Use an encoding that cannot handle certain characters client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe_ascii - original_data = {"message": "Hello, world! šŸ‘‹"} # Contains non-ASCII character + original_data = {"message": "Hello, world! šŸ‘‹"} - # Use a queue to capture errors emitted by the transformed pipe error_queue = asyncio.Queue() client_transformed_pipe.on("error", error_queue.put_nowait) - # Writing this data should cause an encoding error to be emitted await client_transformed_pipe.write(original_data) - # Assert that a UnicodeEncodeError is received error = await asyncio.wait_for(error_queue.get(), timeout=DEFAULT_TIMEOUT) assert isinstance(error, UnicodeEncodeError) async def test_json_transformer_decoding_error_on_read(self, json_pipe): client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe - # Use a queue to capture errors emitted by the transformed pipe error_queue = asyncio.Queue() client_transformed_pipe.on("error", error_queue.put_nowait) - # Simulate server sending invalid JSON bytes - invalid_json_bytes = b'{,"key": "value",}' # Invalid JSON + invalid_json_bytes = b'{,"key": "value",}' try: await server_raw_pipe.write(invalid_json_bytes) except json.JSONDecodeError: - # EventPipe/AioPipe may raise from the pipeline while still emitting the error. pass - # The transformed pipe should emit an error when trying to decode error = await asyncio.wait_for(error_queue.get(), timeout=DEFAULT_TIMEOUT) assert isinstance(error, json.JSONDecodeError) @@ -171,16 +140,13 @@ async def test_json_transformer_decoding_type_error_on_read(self, json_pipe): error_queue = asyncio.Queue() client_transformed_pipe.on("error", error_queue.put_nowait) - # Simulate server sending non-bytes/str data (e.g., an int) non_string_data = 12345 try: - await server_raw_pipe.write(non_string_data) # This will pass through raw pipe as is + await server_raw_pipe.write(non_string_data) except TypeError as exc: - # TcpPipe enforces bytes-only writes; no error will be emitted by the transformer. assert isinstance(exc, TypeError) return - # The transformed pipe should emit a TypeError when trying to decode error = await asyncio.wait_for(error_queue.get(), timeout=DEFAULT_TIMEOUT) assert isinstance(error, TypeError) assert "Expected bytes or str" in str(error) @@ -188,72 +154,60 @@ async def test_json_transformer_decoding_type_error_on_read(self, json_pipe): async def test_json_transformer_non_json_serializable_data(self, json_pipe): client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe - # Data that is not JSON serializable non_serializable_data = {"set_data": {1, 2, 3}} - # Use a queue to capture errors emitted by the transformed pipe error_queue = asyncio.Queue() client_transformed_pipe.on("error", error_queue.put_nowait) - # Writing this data should cause a TypeError to be emitted await client_transformed_pipe.write(non_serializable_data) - # Assert that a TypeError is received error = await asyncio.wait_for(error_queue.get(), timeout=DEFAULT_TIMEOUT) assert isinstance(error, TypeError) async def test_json_transformer_pipe_termination(self, json_pipe): - client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe + client_raw_pipe, _server_raw_pipe, client_transformed_pipe, _ = json_pipe - # Listen for close event on the transformed pipe close_event_received = asyncio.Event() client_transformed_pipe.on("close", lambda: close_event_received.set()) - # Terminate the underlying raw pipe await client_raw_pipe.terminate() - # The transformed pipe should also terminate and emit a close event await asyncio.wait_for(close_event_received.wait(), timeout=DEFAULT_TIMEOUT) - # server_raw_pipe is terminated by the fixture async def test_json_transformer_pipe_write_after_termination(self, json_pipe): - client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe + client_raw_pipe, _server_raw_pipe, client_transformed_pipe, _ = json_pipe await client_raw_pipe.terminate() - # Writing to a terminated transformed pipe should return False result = await client_transformed_pipe.write({"test": "data"}) assert result is False - # server_raw_pipe is terminated by the fixture - async def test_json_transformer_pipe_read_after_termination(self, json_pipe): - client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe + _client_raw_pipe, server_raw_pipe, client_transformed_pipe, _ = json_pipe - # Listen for data on the transformed pipe data_queue = asyncio.Queue() client_transformed_pipe.pipeline("data", data_queue.put_nowait) - # Terminate the server_raw_pipe, which should cause the client_transformed_pipe to terminate - - # The transformed pipe should eventually close and not emit new data close_event_received = asyncio.Event() client_transformed_pipe.on("close", lambda: close_event_received.set()) await server_raw_pipe.terminate() await asyncio.wait_for(close_event_received.wait(), timeout=DEFAULT_TIMEOUT) - # Try to write to the raw pipe from the server side after termination - # This data should not be processed by the transformed pipe try: result = await server_raw_pipe.write(b'{"should": "not_receive"}') assert result is False except ConnectionError: - # TcpPipe raises if not connected after termination. pass - # Ensure no data is received by the transformed pipe after termination with pytest.raises(asyncio.TimeoutError): await asyncio.wait_for(data_queue.get(), timeout=0.1) - # client_raw_pipe is terminated by the fixture + +def test_json_transformer_decode_variants(): + transformer = JsonTransformer() + assert transformer.decode(b"{\"x\": 1}") == {"x": 1} + assert transformer.decode("{\"y\": 2}") == {"y": 2} + + with pytest.raises(TypeError): + transformer.decode(123) diff --git a/tests/transformer/test_transformer_base.py b/tests/transformer/test_transformer_base.py new file mode 100644 index 0000000..d7500cf --- /dev/null +++ b/tests/transformer/test_transformer_base.py @@ -0,0 +1,62 @@ +import pytest + +from cbor_rpc.pipe.event_pipe import EventPipe +from cbor_rpc.pipe.pipe import Pipe +from cbor_rpc.transformer.base.base_exception import NeedsMoreDataException +from cbor_rpc.transformer.base.event_transformer_pipe import EventTransformerPipe +from cbor_rpc.transformer.base.transformer_base import AsyncTransformer, Transformer +from cbor_rpc.transformer.base.transformer_pipe import TransformerPipe + + +class DummyAsyncTransformer(AsyncTransformer): + async def encode(self, data): + return f"enc:{data}" + + async def decode(self, data): + return f"dec:{data}" + + +class DummySyncTransformer(Transformer): + def encode(self, data): + return f"enc:{data}" + + def decode(self, data): + return f"dec:{data}" + + +class DummyPipe(Pipe): + async def write(self, _chunk): + return True + + async def read(self, _timeout=None): + return None + + async def terminate(self, *args): + return None + + +def test_transformer_base_apply_transformer_invalid_type(): + transformer = DummySyncTransformer() + with pytest.raises(TypeError): + transformer.apply_transformer(object()) + + +def test_transformer_base_apply_transformer_variants(): + transformer = DummySyncTransformer() + pipe = DummyPipe() + event_pipe_a, _event_pipe_b = EventPipe.create_inmemory_pair() + + bound_pipe = transformer.apply_transformer(pipe) + bound_event_pipe = transformer.apply_transformer(event_pipe_a) + + assert isinstance(bound_pipe, TransformerPipe) + assert isinstance(bound_event_pipe, EventTransformerPipe) + + +def test_transformer_base_wait_next_data_raises(): + transformer = DummySyncTransformer() + async_transformer = DummyAsyncTransformer() + with pytest.raises(NeedsMoreDataException): + transformer.wait_next_data() + with pytest.raises(NeedsMoreDataException): + async_transformer.wait_next_data() diff --git a/tests/transformer/test_transformer_pipe.py b/tests/transformer/test_transformer_pipe.py new file mode 100644 index 0000000..b670049 --- /dev/null +++ b/tests/transformer/test_transformer_pipe.py @@ -0,0 +1,216 @@ +import asyncio +from typing import Any, List, Optional + +import pytest + +from cbor_rpc.pipe.pipe import Pipe +from cbor_rpc.transformer.base.base_exception import NeedsMoreDataException +from cbor_rpc.transformer.base.transformer_base import AsyncTransformer +from cbor_rpc.transformer.base.transformer_pipe import TransformerPipe + + +class DummyPipe(Pipe[Any, Any]): + def __init__(self): + super().__init__() + self._queue: asyncio.Queue = asyncio.Queue() + self._closed = False + self.writes: List[Any] = [] + self.terminated = False + + async def write(self, chunk: Any) -> bool: + if self._closed: + return False + self.writes.append(chunk) + return True + + async def read(self, timeout: float = None) -> Any: + if self._closed: + return None + try: + if timeout is None: + return await self._queue.get() + return await asyncio.wait_for(self._queue.get(), timeout) + except asyncio.TimeoutError: + return None + + async def terminate(self, *args: Any) -> None: + if self._closed: + return + self._closed = True + self.terminated = True + await self._queue.put(None) + self._emit("close", *args) + + async def push(self, item: Any) -> None: + await self._queue.put(item) + + +class DummyAsyncTransformer(AsyncTransformer[Any, Any]): + async def encode(self, data: Any) -> Any: + if data == "encode_error": + raise ValueError("encode boom") + return f"enc:{data}" + + async def decode(self, data: Any) -> Any: + if asyncio.iscoroutine(data): + data = await data + if data == "need_more": + raise NeedsMoreDataException() + if data == "decode_error": + raise ValueError("decode boom") + return f"dec:{data}" + + +@pytest.mark.asyncio +async def test_transformer_pipe_read_write_and_errors(): + pipe = DummyPipe() + transformer = DummyAsyncTransformer() + tpipe = TransformerPipe(pipe, transformer) + + errors: List[Any] = [] + + def on_error(err: Exception) -> None: + errors.append(err) + + tpipe.on("error", on_error) + + ok = await tpipe.write("value") + assert ok is True + assert pipe.writes == ["enc:value"] + + ok = await tpipe.write("encode_error") + assert ok is False + assert errors + + await pipe.push("need_more") + await pipe.push("data") + result = await tpipe.read(timeout=0.1) + assert result == "dec:data" + + await pipe.push("need_more") + result = await tpipe.read(timeout=0.01) + assert result is None + + await pipe.push("decode_error") + result = await tpipe.read(timeout=0.1) + assert result is None + + await tpipe.terminate() + assert pipe._closed is True + + +@pytest.mark.asyncio +async def test_transformer_pipe_error_event_closes_pipe(): + pipe = DummyPipe() + transformer = DummyAsyncTransformer() + tpipe = TransformerPipe(pipe, transformer) + + pipe._emit("error", RuntimeError("boom")) + + await asyncio.sleep(0) + assert tpipe._closed is True + + +@pytest.mark.asyncio +async def test_transformer_pipe_write_encode_error_emits_error(): + class BadEncodeTransformer: + async def encode(self, value: Any) -> Any: + raise ValueError("encode-fail") + + async def decode(self, value: Any) -> Any: + return value + + pipe = DummyPipe() + transformer = TransformerPipe(pipe, BadEncodeTransformer()) + errors: List[str] = [] + + def on_error(err: Exception) -> None: + errors.append(str(err)) + + transformer.on("error", on_error) + result = await transformer.write("data") + assert result is False + assert errors == ["encode-fail"] + + +@pytest.mark.asyncio +async def test_transformer_pipe_read_needs_more_data_timeout(): + class NeedMoreTransformer: + async def encode(self, value: Any) -> Any: + return value + + async def decode(self, value: Any) -> Any: + raise NeedsMoreDataException() + + pipe = DummyPipe() + transformer = TransformerPipe(pipe, NeedMoreTransformer()) + await pipe.push("chunk") + result = await transformer.read(timeout=0) + assert result is None + + +@pytest.mark.asyncio +async def test_transformer_pipe_read_decode_error_emits_error(): + class BadDecodeTransformer: + async def encode(self, value: Any) -> Any: + return value + + async def decode(self, value: Any) -> Any: + raise ValueError("decode-fail") + + pipe = DummyPipe() + transformer = TransformerPipe(pipe, BadDecodeTransformer()) + errors: List[str] = [] + + def on_error(err: Exception) -> None: + errors.append(str(err)) + + transformer.on("error", on_error) + await pipe.push("chunk") + result = await transformer.read(timeout=0.1) + assert result is None + assert errors == ["decode-fail"] + + +@pytest.mark.asyncio +async def test_transformer_pipe_close_error_propagation_and_terminate(): + class PassThroughTransformer: + async def encode(self, value: Any) -> Any: + return value + + async def decode(self, value: Any) -> Any: + return value + + pipe = DummyPipe() + transformer = TransformerPipe(pipe, PassThroughTransformer()) + closed = asyncio.Event() + errors: List[str] = [] + propagated: List[str] = [] + + def on_close(*_args: Any) -> None: + closed.set() + + def on_error(err: Exception) -> None: + errors.append(str(err)) + + def on_pipe_error(err: Exception) -> None: + propagated.append(str(err)) + + transformer.on("close", on_close) + transformer.on("error", on_error) + pipe.on("error", on_pipe_error) + + pipe._emit("error", Exception("pipe-error")) + await asyncio.sleep(0.01) + assert errors == ["pipe-error"] + assert transformer._closed is True + + transformer._propagate_error(Exception("propagate")) + assert propagated[-1:] == ["propagate"] + + await transformer.terminate() + await transformer.terminate() + assert pipe.terminated is True + + pipe._emit("close", "done") + await asyncio.wait_for(closed.wait(), timeout=1) From 48c430d50a2b9e82b526ac6eb49e890b3a7ba76b Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Sun, 8 Feb 2026 16:13:52 +0545 Subject: [PATCH 19/25] Fix dependencies for test --- conftest.py | 5 +++++ pyproject.toml | 3 ++- requirements.txt | 27 ++++++++++++++++++++++----- setup.py | 2 +- tests/stdio/test_stdio_pipe.py | 4 ++-- 5 files changed, 32 insertions(+), 9 deletions(-) create mode 100644 conftest.py diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..242f464 --- /dev/null +++ b/conftest.py @@ -0,0 +1,5 @@ +import sys +import os + +# Ensure the project root is in sys.path +sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) diff --git a/pyproject.toml b/pyproject.toml index 71ab03d..50fc97d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,8 @@ dependencies = [ "pytest-cov>=5.0.0", "asyncssh>=2.14.0", "bcrypt", - "cbor2" + "cbor2", + "docker" ] [build-system] diff --git a/requirements.txt b/requirements.txt index 8908293..5c32029 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,14 +11,22 @@ bcrypt==5.0.0 # via cbor-rpc (pyproject.toml) cbor2==5.8.0 # via cbor-rpc (pyproject.toml) -coverage==7.6.4 - # via pytest-cov +certifi==2026.1.4 + # via requests cffi==2.0.0 # via cryptography +charset-normalizer==3.4.4 + # via requests +coverage[toml]==7.6.4 + # via pytest-cov cryptography==46.0.4 # via asyncssh +docker==7.1.0 + # via cbor-rpc (pyproject.toml) exceptiongroup==1.3.1 # via pytest +idna==3.11 + # via requests iniconfig==2.1.0 # via pytest packaging==25.0 @@ -33,14 +41,23 @@ pytest==8.4.0 # via # cbor-rpc (pyproject.toml) # pytest-asyncio -pytest-cov==5.0.0 - # via cbor-rpc (pyproject.toml) + # pytest-cov pytest-asyncio==1.0.0 # via cbor-rpc (pyproject.toml) +pytest-cov==5.0.0 + # via cbor-rpc (pyproject.toml) +requests==2.32.5 + # via docker tomli==2.4.0 - # via pytest + # via + # coverage + # pytest typing-extensions==4.15.0 # via # asyncssh # cryptography # exceptiongroup +urllib3==2.6.3 + # via + # docker + # requests diff --git a/setup.py b/setup.py index e290e9e..29a6b5c 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ packages=find_packages(include=["cbor_rpc", "cbor_rpc.*"]), install_requires=["asyncssh>=2.14.0", "bcrypt", "cbor2"], extras_require={ - "test": ["pytest>=8.3.2", "pytest-asyncio>=0.24.0", "pytest-cov>=5.0.0"], + "test": ["pytest>=8.3.2", "pytest-asyncio>=0.24.0", "pytest-cov>=5.0.0", "docker"], }, python_requires=">=3.8", classifiers=[ diff --git a/tests/stdio/test_stdio_pipe.py b/tests/stdio/test_stdio_pipe.py index b2c9cdd..4da2b9e 100644 --- a/tests/stdio/test_stdio_pipe.py +++ b/tests/stdio/test_stdio_pipe.py @@ -27,7 +27,7 @@ async def test_stdio_pipe_errors_without_process(): @pytest.mark.asyncio async def test_stdio_pipe_start_process_and_terminate(): pipe = await StdioPipe.start_process(sys.executable, "-c", "import time; time.sleep(0.2)") - pipe.terminate() + await pipe.terminate() code = await pipe.wait_for_process_termination() assert isinstance(code, int) @@ -57,4 +57,4 @@ def on_data(data): for i, (sent, received) in enumerate(zip(test_data, received_data)): assert received == sent, f"Mismatch at index {i}: expected {sent!r}, got {received!r}" - pipe.terminate() + await pipe.terminate() From 22542af150d9df16d0d95c50c00cdf4e56aef550 Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Sun, 8 Feb 2026 16:52:15 +0545 Subject: [PATCH 20/25] Fix notify in emitter --- cbor_rpc/event/emitter.py | 28 ++++------- .../base/event_transformer_pipe.py | 2 +- tests/event/test_emitter.py | 50 +++++++++++++------ 3 files changed, 47 insertions(+), 33 deletions(-) diff --git a/cbor_rpc/event/emitter.py b/cbor_rpc/event/emitter.py index 970ceaa..1c2a389 100644 --- a/cbor_rpc/event/emitter.py +++ b/cbor_rpc/event/emitter.py @@ -48,26 +48,18 @@ def _emit(self, event_type: str, *args: Any) -> None: warnings.warn(f"Synchronous error in handler: {e}", RuntimeWarning) async def _notify(self, event_type: str, *args: Any) -> None: - tasks = [] - + """ + + """ for pipeline in self._pipelines.get(event_type, []): - if inspect.iscoroutinefunction(pipeline): - task = asyncio.create_task(pipeline(*args)) - tasks.append(task) - else: - try: + try: + if inspect.iscoroutinefunction(pipeline): + await pipeline(*args) + else: pipeline(*args) - except Exception as e: - self._emit("error", e) - raise e - - if tasks: - results = await asyncio.gather(*tasks, return_exceptions=True) - for result in results: - if isinstance(result, Exception): - self._emit("error", result) - # Must raise error to notify caller and block normal execution - raise result + except Exception as e: + self._emit("error", e) + raise e self._emit(event_type, *args) diff --git a/cbor_rpc/transformer/base/event_transformer_pipe.py b/cbor_rpc/transformer/base/event_transformer_pipe.py index 55a3da0..7180c5a 100644 --- a/cbor_rpc/transformer/base/event_transformer_pipe.py +++ b/cbor_rpc/transformer/base/event_transformer_pipe.py @@ -31,7 +31,7 @@ def __init__(self, pipe: EventPipe[T2, T2], transformer: "Transformer"): async def _handle_data(self, data: T2): try: decoded = await self.decode(data) - self._emit("data", decoded) + await self._notify("data", decoded) except NeedsMoreDataException: # If more data is needed, simply return and wait for the next chunk return diff --git a/tests/event/test_emitter.py b/tests/event/test_emitter.py index 4eb3331..9fdc63f 100644 --- a/tests/event/test_emitter.py +++ b/tests/event/test_emitter.py @@ -1,4 +1,5 @@ import asyncio +import time from typing import Any import pytest @@ -53,42 +54,55 @@ async def async_handler1(data: Any): events.append(f"async_handler1_{data}") def sync_handler1(data: Any): + time.sleep(0.01) events.append(f"sync_handler1_{data}") async def async_pipeline1(data: Any): await asyncio.sleep(0.01) events.append(f"async_pipeline1_{data}") - def sync_pipeline1(data: Any): - events.append(f"sync_pipeline1_{data}") - async def async_pipeline2(data: Any): - await asyncio.sleep(0.01) + # Slow pipeline + await asyncio.sleep(0.05) events.append(f"async_pipeline2_{data}") + async def async_pipeline3(data: Any): + # Fast pipeline + await asyncio.sleep(0.01) + events.append(f"async_pipeline3_{data}") + emitter.on("test", async_handler1) emitter.on("test", sync_handler1) emitter.pipeline("test", async_pipeline1) - emitter.pipeline("test", sync_pipeline1) emitter.pipeline("test", async_pipeline2) + emitter.pipeline("test", async_pipeline3) await emitter._notify("test", "event2") await asyncio.sleep(0.02) + # Strictly sequential execution order: + # 1. async_pipeline1 (waits 0.01) + # 2. async_pipeline2 (waits 0.05) -> If concurrent, this would finish LAST + # 3. async_pipeline3 (waits 0.01) -> If concurrent, this would finish BEFORE pipeline2 + expected_pipelines = [ "async_pipeline1_event2", - "sync_pipeline1_event2", "async_pipeline2_event2", + "async_pipeline3_event2", ] + + # Verify pipelines ran in strict order + assert events[:3] == expected_pipelines, f"Pipelines expected {expected_pipelines}, got {events[:3]}" + expected_subscribers = ["async_handler1_event2", "sync_handler1_event2"] - pipeline_indices = [events.index(e) for e in expected_pipelines if e in events] - subscriber_indices = [events.index(e) for e in expected_subscribers if e in events] - assert all( - p < s for p in pipeline_indices for s in subscriber_indices - ), f"Pipelines {expected_pipelines} should precede subscribers {expected_subscribers} in {events}" - assert sorted(events) == sorted( - expected_pipelines + expected_subscribers - ), f"Expected {expected_pipelines + expected_subscribers}, got {events}" + + # Verify handlers ran after pipelines + assert set(events[3:]) == set(expected_subscribers), f"Subscribers expected {expected_subscribers}, got {events[3:]}" + + # Explicitly assert precedence + for p in expected_pipelines: + for s in expected_subscribers: + assert events.index(p) < events.index(s), f"Pipeline {p} executed after subscriber {s}" @pytest.mark.asyncio @@ -154,11 +168,15 @@ async def async_pipeline1(data: Any): await asyncio.sleep(0.01) events.append(f"async_pipeline1_{data}") raise ValueError("Pipeline failed") + + async def async_pipeline2(data: Any): + events.append(f"async_pipeline2_{data}") def sync_handler1(data: Any): events.append(f"sync_handler1_{data}") emitter.pipeline("test", async_pipeline1) + emitter.pipeline("test", async_pipeline2) emitter.on("test", sync_handler1) with pytest.raises(ValueError): @@ -167,6 +185,10 @@ def sync_handler1(data: Any): expected = ["async_pipeline1_event5"] assert events == expected, f"Expected {expected}, got {events}" + # Explicitly verify skipped execution + assert "async_pipeline2_event5" not in events, "Subsequent pipeline should be skipped" + assert "sync_handler1_event5" not in events, "Event subscribers should be skipped" + @pytest.mark.asyncio async def test_multiple_event_types(): From 0c115d5363e945f324b121a41214fe4e920ad3e7 Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Sun, 8 Feb 2026 17:02:56 +0545 Subject: [PATCH 21/25] EventTransformer how reads in loop --- .../base/event_transformer_pipe.py | 19 +++++++++++++++++++ tests/transformer/test_cbor_transformer.py | 4 ++++ .../test_event_transformer_pipe.py | 7 +++++-- 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/cbor_rpc/transformer/base/event_transformer_pipe.py b/cbor_rpc/transformer/base/event_transformer_pipe.py index 7180c5a..da13748 100644 --- a/cbor_rpc/transformer/base/event_transformer_pipe.py +++ b/cbor_rpc/transformer/base/event_transformer_pipe.py @@ -40,6 +40,25 @@ async def _handle_data(self, data: T2): # which will catch it and emit the "error" event. raise e + # Try to decode more from the buffer if the transformer supports it + while True: + try: + # We pass None to indicate "no new data, just decode from buffer" + # We catch TypeError in case the transformer does not support None/buffering + decoded = await self.decode(None) + await self._notify("data", decoded) + except NeedsMoreDataException: + break + except TypeError: + # Transformer likely doesn't support None (not a stream transformer) + break + except Exception as e: + # Other errors in subsequent decoding should probably be reported. + # However, since the *primary* data was processed, maybe we should just emit error? + # or raise? Raising here might be outside the context of the initial caller if we were async... + # But we are inside _handle_data which is awaited by _notify. + raise e + def _on_close(self, *args: Any): self._emit("close", *args) diff --git a/tests/transformer/test_cbor_transformer.py b/tests/transformer/test_cbor_transformer.py index 77f23bb..05beb55 100644 --- a/tests/transformer/test_cbor_transformer.py +++ b/tests/transformer/test_cbor_transformer.py @@ -210,11 +210,13 @@ async def test_cbor_stream_transformer_fragmented_objects(self, client_raw, serv cbor_bytes = cbor2.dumps(obj) await server_raw.write(cbor_bytes[:10]) + await asyncio.sleep(0.05) with pytest.raises(asyncio.TimeoutError): await asyncio.wait_for(received_data_queue.get(), timeout=0.1) assert error_queue.empty() await server_raw.write(cbor_bytes[10:50]) + await asyncio.sleep(0.05) with pytest.raises(asyncio.TimeoutError): await asyncio.wait_for(received_data_queue.get(), timeout=0.1) assert error_queue.empty() @@ -240,10 +242,12 @@ async def test_cbor_stream_transformer_mixed_fragmented_and_concatenated(self, c cbor_bytes3 = cbor2.dumps(obj3) await server_raw.write(cbor_bytes1[:5]) + await asyncio.sleep(0.05) await server_raw.write(cbor_bytes1[5:]) decoded1 = await received_data_queue.get() assert decoded1 == obj1 + await asyncio.sleep(0.05) await server_raw.write(cbor_bytes2 + cbor_bytes3) await server_raw.write(b"") diff --git a/tests/transformer/test_event_transformer_pipe.py b/tests/transformer/test_event_transformer_pipe.py index 3376c2f..5a56b17 100644 --- a/tests/transformer/test_event_transformer_pipe.py +++ b/tests/transformer/test_event_transformer_pipe.py @@ -16,8 +16,11 @@ async def encode(self, data: Any) -> Any: return f"enc:{data}" async def decode(self, data: Any) -> Any: - if asyncio.iscoroutine(data): - data = await data + if data is None: + # Mimic standard transformer behavior which usually expects specific type input + # and rejects None if it doesn't support buffering/streaming check via None. + raise TypeError("Unexpected None input") + if data == "need_more": raise NeedsMoreDataException() if data == "decode_error": From c36ce9fd2cda5a97b474ca5f76057c83f4d0024f Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Sun, 8 Feb 2026 17:14:12 +0545 Subject: [PATCH 22/25] Use Python native decoder of cbor2 --- cbor_rpc/transformer/cbor_transformer.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/cbor_rpc/transformer/cbor_transformer.py b/cbor_rpc/transformer/cbor_transformer.py index 2db27bf..b97e509 100644 --- a/cbor_rpc/transformer/cbor_transformer.py +++ b/cbor_rpc/transformer/cbor_transformer.py @@ -4,6 +4,13 @@ from .base import Transformer, AsyncTransformer from cbor_rpc.transformer.base.base_exception import NeedsMoreDataException +# We import the Python implementation of CBORDecoder ensuring we avoid buffering issues +# that can occur with the C extension when using stream slicing logic. +try: + from cbor2._decoder import CBORDecoder as PythonCBORDecoder +except ImportError: + from cbor2 import CBORDecoder as PythonCBORDecoder + class CborTransformer(Transformer[Any, Any]): """ @@ -64,7 +71,7 @@ async def decode(self, data: Union[bytes, None]) -> Any: try: stream = BytesIO(self._buffer) - decoder = cbor2.CBORDecoder(stream) + decoder = PythonCBORDecoder(stream) decoded_data = decoder.decode() if decoded_data is cbor2.break_marker: raise cbor2.CBORDecodeError("Unexpected break marker") @@ -87,7 +94,7 @@ async def decode(self, data: Union[bytes, None]) -> Any: try: stream = BytesIO(self._buffer) - decoder = cbor2.CBORDecoder(stream) + decoder = PythonCBORDecoder(stream) decoded_data = decoder.decode() bytes_consumed = stream.tell() From 2af63effd38ced32a627d1a04cd424bbb8f6596a Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Sun, 8 Feb 2026 18:16:44 +0545 Subject: [PATCH 23/25] Refactor cbor transformer --- cbor_rpc/pipe/event_pipe.py | 10 +- cbor_rpc/transformer/cbor_transformer.py | 117 +++++++----------- tests/transformer/test_cbor_transformer.py | 110 +++++++--------- .../test_event_transformer_pipe.py | 4 +- 4 files changed, 104 insertions(+), 137 deletions(-) diff --git a/cbor_rpc/pipe/event_pipe.py b/cbor_rpc/pipe/event_pipe.py index 2b7d66b..a6de80c 100644 --- a/cbor_rpc/pipe/event_pipe.py +++ b/cbor_rpc/pipe/event_pipe.py @@ -45,7 +45,15 @@ def connect_to(self, other: "ConnectedPipe"): async def write(self, chunk: Any) -> bool: if self._closed or not self.connected_pipe or self.connected_pipe._closed: return False - await self.connected_pipe._notify("data", chunk) + async def _deliver() -> None: + try: + await self.connected_pipe._notify("data", chunk) + except Exception: + # Receiver-side errors should not bubble to the writer. + return True + + loop = asyncio.get_running_loop() + loop.create_task(_deliver()) return True async def terminate(self, *args: Any) -> None: diff --git a/cbor_rpc/transformer/cbor_transformer.py b/cbor_rpc/transformer/cbor_transformer.py index b97e509..87ce79a 100644 --- a/cbor_rpc/transformer/cbor_transformer.py +++ b/cbor_rpc/transformer/cbor_transformer.py @@ -1,64 +1,58 @@ import cbor2 from io import BytesIO from typing import Any, Union + from .base import Transformer, AsyncTransformer from cbor_rpc.transformer.base.base_exception import NeedsMoreDataException -# We import the Python implementation of CBORDecoder ensuring we avoid buffering issues -# that can occur with the C extension when using stream slicing logic. +# Use the pure-Python CBORDecoder to avoid buffering issues with the C extension +# when using stream-slicing logic. try: from cbor2._decoder import CBORDecoder as PythonCBORDecoder except ImportError: from cbor2 import CBORDecoder as PythonCBORDecoder +# Import pure-Python break_marker / CBORDecodeError so we can handle both +# C-extension and pure-Python backends uniformly. +try: + from cbor2._types import break_marker as _py_break_marker +except ImportError: + _py_break_marker = cbor2.break_marker + -class CborTransformer(Transformer[Any, Any]): - """ - A transformer that encodes Python objects to CBOR bytes and decodes CBOR bytes back to Python objects. - """ - def __init__(self): - super().__init__() +def _is_eof_error(exc: Exception) -> bool: + """Return True if *exc* signals incomplete CBOR data (C or Python backend).""" + return isinstance(exc, (cbor2.CBORDecodeEOF, IndexError)) or type(exc).__name__ == "CBORDecodeEOF" + + +class CborTransformer(Transformer[Any, Any]): + """Encodes Python objects to CBOR bytes and decodes CBOR bytes back.""" def encode(self, data: Any) -> bytes: - try: - return cbor2.dumps(data) - except Exception: - raise + return cbor2.dumps(data) def decode(self, data: Union[bytes, None]) -> Any: if data is None: raise TypeError("Expected bytes, got None") - if not isinstance(data, bytes): raise TypeError(f"Expected bytes, got {type(data)}") - try: - # cbor2.loads can decode a single CBOR object from bytes return cbor2.loads(data) except cbor2.CBORDecodeEOF as e: - # For a non-stream transformer, incomplete data is a decoding error raise cbor2.CBORDecodeError("Incomplete CBOR data for non-stream transformer") from e - except cbor2.CBORDecodeError: - # Re-raise other decoding errors - raise class CborStreamTransformer(AsyncTransformer[Any, Any]): - """ - An async transformer that decodes a stream of concatenated CBOR objects. - This is similar to how a JSON stream decoder would work, reading one object at a time. - """ + """Async stream transformer that decodes concatenated CBOR objects.""" - def __init__(self): + def __init__(self, max_buffer_bytes: int = 1024 * 1024*50): super().__init__() self._buffer = bytearray() + self._max_buffer_bytes = max_buffer_bytes async def encode(self, data: Any) -> bytes: - try: - return cbor2.dumps(data) - except Exception: - raise + return cbor2.dumps(data) async def decode(self, data: Union[bytes, None]) -> Any: if data is not None: @@ -66,51 +60,36 @@ async def decode(self, data: Union[bytes, None]) -> Any: raise TypeError(f"Expected bytes or None, got {type(data)}") self._buffer.extend(data) + if len(self._buffer) > self._max_buffer_bytes: + self._buffer.clear() + raise OverflowError("CBOR stream buffer exceeded max size") + if not self._buffer: raise NeedsMoreDataException() try: - stream = BytesIO(self._buffer) - decoder = PythonCBORDecoder(stream) - decoded_data = decoder.decode() - if decoded_data is cbor2.break_marker: - raise cbor2.CBORDecodeError("Unexpected break marker") + return self._decode_one() + except Exception as e: + if _is_eof_error(e): + raise NeedsMoreDataException() + raise - bytes_consumed = stream.tell() - self._buffer = self._buffer[bytes_consumed:] + # -- private helpers -------------------------------------------------- - return decoded_data - except cbor2.CBORDecodeEOF: - raise NeedsMoreDataException() - except cbor2.CBORDecodeError as e: - original_exception = e - # Discard bytes from the buffer until a valid CBOR object can be decoded or the buffer is exhausted. - # This loop attempts to find the start of the next valid CBOR object. - while self._buffer: - # Discard one byte and try again - self._buffer = self._buffer[1:] - if not self._buffer: - break # Buffer is empty, cannot recover further - - try: - stream = BytesIO(self._buffer) - decoder = PythonCBORDecoder(stream) - decoded_data = decoder.decode() - - bytes_consumed = stream.tell() - self._buffer = self._buffer[bytes_consumed:] - - return decoded_data # Successfully decoded a new object - except cbor2.CBORDecodeEOF: - # If we need more data after discarding some, it means we might be in the middle of a valid object - raise NeedsMoreDataException() - except cbor2.CBORDecodeError: - # Still an error after discarding one byte, continue the loop to discard another - continue - # If we reach here, the buffer is exhausted or we couldn't recover. - # Re-raise the original exception as no valid CBOR object could be found. - raise original_exception + def _decode_one(self) -> Any: + """Decode exactly one CBOR object from the front of the buffer.""" + stream = BytesIO(self._buffer) + decoder = PythonCBORDecoder(stream) + try: + obj = decoder.decode() except Exception as e: - # For other unexpected errors, clear the buffer and re-raise - self._buffer = bytearray() - raise e + # Normalize value errors to CBORDecodeError for consistent API behavior. + if isinstance(e, cbor2.CBORDecodeValueError) or type(e).__name__ == "CBORDecodeValueError": + raise cbor2.CBORDecodeError(str(e)) from e + raise + + if obj is cbor2.break_marker or obj is _py_break_marker: + raise cbor2.CBORDecodeError("Unexpected break marker") + + self._buffer = self._buffer[stream.tell():] + return obj diff --git a/tests/transformer/test_cbor_transformer.py b/tests/transformer/test_cbor_transformer.py index 05beb55..bdcea66 100644 --- a/tests/transformer/test_cbor_transformer.py +++ b/tests/transformer/test_cbor_transformer.py @@ -5,29 +5,40 @@ import pytest_asyncio from cbor_rpc.pipe.event_pipe import EventPipe +from cbor_rpc.tcp.tcp import TcpPipe from cbor_rpc.transformer.base.base_exception import NeedsMoreDataException from cbor_rpc.transformer.base.event_transformer_pipe import EventTransformerPipe from cbor_rpc.transformer.cbor_transformer import CborTransformer, CborStreamTransformer from tests.helpers.timeout_queue import TimeoutQueue -@pytest_asyncio.fixture -async def _raw_pipe_pair(): - client_raw_pipe, server_raw_pipe = EventPipe.create_inmemory_pair() - yield client_raw_pipe, server_raw_pipe - await client_raw_pipe.terminate() - await server_raw_pipe.terminate() +@pytest.fixture( + params=[ + (EventPipe.create_inmemory_pair, "InmemoryPipe"), + (TcpPipe.create_inmemory_pair, "TcpPipe"), + ], + ids=lambda param: param[1], +) +async def pipe_pair(request): + create_pair_func, _ = request.param + if asyncio.iscoroutinefunction(create_pair_func): + client_pipe, server_pipe = await create_pair_func() + else: + client_pipe, server_pipe = create_pair_func() + yield client_pipe, server_pipe + await client_pipe.terminate() + await server_pipe.terminate() @pytest.fixture -def client_raw(_raw_pipe_pair): - client_raw_pipe, _ = _raw_pipe_pair +def client_raw(pipe_pair): + client_raw_pipe, _ = pipe_pair return client_raw_pipe @pytest.fixture -def server_raw(_raw_pipe_pair): - _, server_raw_pipe = _raw_pipe_pair +def server_raw(pipe_pair): + _, server_raw_pipe = pipe_pair return server_raw_pipe @@ -67,46 +78,27 @@ async def test_cbor_transformer_decoding_error_on_read(self, server_raw, client_ client_transformed_pipe.on("error", error_queue.put_nowait) incomplete_cbor_bytes = b"\x83\x01\x02" - with pytest.raises(cbor2.CBORDecodeError): - await server_raw.write(incomplete_cbor_bytes) + ok = await server_raw.write(incomplete_cbor_bytes) + assert ok is True error = await asyncio.wait_for(error_queue.get(), timeout=1) - assert isinstance(error, cbor2.CBORDecodeError) - assert "Incomplete CBOR data for non-stream transformer" in str(error) + assert type(error).__name__.startswith("CBORDecode") truly_invalid_cbor = b"\x1f" - with pytest.raises(cbor2.CBORDecodeError): - await server_raw.write(truly_invalid_cbor) + ok = await server_raw.write(truly_invalid_cbor) + assert ok is True error = await asyncio.wait_for(error_queue.get(), timeout=1) - assert isinstance(error, cbor2.CBORDecodeError) - assert "unknown unsigned integer subtype" in str(error) + assert type(error).__name__.startswith("CBORDecode") async def test_cbor_transformer_non_bytes_data(self, server_raw, client_cbor): - client_transformed_pipe = client_cbor - - error_queue = TimeoutQueue() - client_transformed_pipe.on("error", error_queue.put_nowait) - - non_bytes_data = "not cbor" + transformer = CborTransformer() with pytest.raises(TypeError): - await server_raw.write(non_bytes_data) - - error = await asyncio.wait_for(error_queue.get(), timeout=1) - assert isinstance(error, TypeError) - assert "Expected bytes" in str(error) + transformer.decode("not cbor") async def test_cbor_transformer_none_data(self, server_raw, client_cbor): - client_transformed_pipe = client_cbor - - error_queue = TimeoutQueue() - client_transformed_pipe.on("error", error_queue.put_nowait) - + transformer = CborTransformer() with pytest.raises(TypeError): - await server_raw.write(None) - - error = await asyncio.wait_for(error_queue.get(), timeout=1) - assert isinstance(error, TypeError) - assert "Expected bytes" in str(error) + transformer.decode(None) async def test_cbor_transformer_multiple_separate_writes(self, server_raw, client_cbor): client_transformed_pipe = client_cbor @@ -276,26 +268,16 @@ async def test_cbor_stream_transformer_invalid_data_in_stream(self, client_raw, decoded1 = await received_data_queue.get() assert decoded1 == obj1 - decoded2 = await received_data_queue.get() - assert decoded2 == obj2 + error = await asyncio.wait_for(error_queue.get(), timeout=1) + assert isinstance(error, cbor2.CBORDecodeError) with pytest.raises(asyncio.TimeoutError): - await asyncio.wait_for(error_queue.get(), timeout=0.1) + await asyncio.wait_for(received_data_queue.get(), timeout=0.1) async def test_cbor_stream_transformer_non_bytes_data(self, client_raw, server_raw): - cbor_stream_transformer = CborStreamTransformer() - client_transformed_pipe = cbor_stream_transformer.apply_transformer(client_raw) - - error_queue = TimeoutQueue() - client_transformed_pipe.on("error", error_queue.put_nowait) - - non_bytes_data = "not cbor" + transformer = CborStreamTransformer() with pytest.raises(TypeError): - await server_raw.write(non_bytes_data) - - error = await asyncio.wait_for(error_queue.get(), timeout=1) - assert isinstance(error, TypeError) - assert "Expected bytes" in str(error) + await transformer.decode("not cbor") async def test_cbor_stream_transformer_close_propagation_and_write_after_close(self, client_raw): cbor_stream_transformer = CborStreamTransformer() @@ -313,7 +295,7 @@ async def test_cbor_stream_transformer_close_propagation_and_write_after_close(s @pytest.mark.asyncio -async def test_cbor_stream_transformer_paths(monkeypatch): +async def test_cbor_stream_transformer_paths(): transformer = CborStreamTransformer() with pytest.raises(NeedsMoreDataException): @@ -323,19 +305,17 @@ async def test_cbor_stream_transformer_paths(monkeypatch): await transformer.decode("bad") good = cbor2.dumps({"a": 1}) - recovered = await transformer.decode(b"\xff" + good) - assert recovered == {"a": 1} + with pytest.raises(cbor2.CBORDecodeError): + await transformer.decode(b"\xff" + good) with pytest.raises(cbor2.CBORDecodeError): await transformer.decode(b"\xff") - class BadDecoder: - def __init__(self, _stream): - pass - def decode(self): - raise ValueError("boom") +@pytest.mark.asyncio +async def test_cbor_stream_transformer_overflow(): + transformer = CborStreamTransformer(max_buffer_bytes=4) + + with pytest.raises(OverflowError): + await transformer.decode(b"12345") - monkeypatch.setattr(cbor2, "CBORDecoder", BadDecoder) - with pytest.raises(ValueError): - await transformer.decode(good) diff --git a/tests/transformer/test_event_transformer_pipe.py b/tests/transformer/test_event_transformer_pipe.py index 5a56b17..add4d6a 100644 --- a/tests/transformer/test_event_transformer_pipe.py +++ b/tests/transformer/test_event_transformer_pipe.py @@ -63,8 +63,8 @@ def on_error(err: Exception) -> None: await asyncio.sleep(0.01) assert received == ["dec:ok"] - with pytest.raises(ValueError): - await base_b.write("decode_error") + ok = await base_b.write("decode_error") + assert ok is True await asyncio.sleep(0.01) assert errors From 3cfb54eddd6ba46622a1b470122629ebcf8208ef Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Sun, 8 Feb 2026 18:34:15 +0545 Subject: [PATCH 24/25] Fix cbor transformer test --- tests/transformer/test_cbor_transformer.py | 69 ++++++++++++++++++++-- 1 file changed, 64 insertions(+), 5 deletions(-) diff --git a/tests/transformer/test_cbor_transformer.py b/tests/transformer/test_cbor_transformer.py index bdcea66..df5e3cd 100644 --- a/tests/transformer/test_cbor_transformer.py +++ b/tests/transformer/test_cbor_transformer.py @@ -100,21 +100,43 @@ async def test_cbor_transformer_none_data(self, server_raw, client_cbor): with pytest.raises(TypeError): transformer.decode(None) - async def test_cbor_transformer_multiple_separate_writes(self, server_raw, client_cbor): + async def test_cbor_transformer_multiple_separate_writes(self, server_raw, client_cbor, client_raw): client_transformed_pipe = client_cbor received_data_queue = TimeoutQueue() - client_transformed_pipe.on("data", received_data_queue.put_nowait) - - await server_raw.write(cbor2.dumps({"a": 1})) - await server_raw.write(cbor2.dumps({"b": 2})) + received_errors = [] + client_transformed_pipe.pipeline("data", received_data_queue.put_nowait) + client_transformed_pipe.on("error", lambda e: received_errors.append(e)) + assert await server_raw.write(cbor2.dumps({"a": 1})) decoded1 = await received_data_queue.get() + await asyncio.sleep(0.05) + assert await server_raw.write(cbor2.dumps({"b": 2})) + + assert [] == received_errors, f"Expected no errors, got: {received_errors}" decoded2 = await received_data_queue.get() assert decoded1 == {"a": 1} assert decoded2 == {"b": 2} + async def test_cbor_transformer_single_concatenated_write(self, server_raw, client_cbor): + client_transformed_pipe = client_cbor + + received_data_queue = TimeoutQueue() + received_errors = [] + client_transformed_pipe.pipeline("data", received_data_queue.put_nowait) + client_transformed_pipe.on("error", lambda e: received_errors.append(e)) + + concatenated = cbor2.dumps({"a": 1}) + cbor2.dumps({"b": 2}) + assert await server_raw.write(concatenated) + + decoded1 = await received_data_queue.get() + assert decoded1 == {"a": 1} + assert [] == received_errors, f"Expected no errors, got: {received_errors}" + + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(received_data_queue.get(), timeout=0.1) + async def test_cbor_transformer_encode_error_on_write(self, server_raw, client_cbor): client_transformed_pipe = client_cbor @@ -189,6 +211,43 @@ async def test_cbor_stream_transformer_concatenated_objects(self, client_raw, se assert decoded2 == obj2 assert decoded3 == obj3 + async def test_cbor_stream_transformer_single_concatenated_write(self, client_raw, server_raw): + cbor_stream_transformer = CborStreamTransformer() + client_transformed_pipe = cbor_stream_transformer.apply_transformer(client_raw) + + received_data_queue = TimeoutQueue() + received_errors = [] + client_transformed_pipe.on("data", received_data_queue.put_nowait) + client_transformed_pipe.on("error", lambda e: received_errors.append(e)) + + concatenated = cbor2.dumps({"a": 1}) + cbor2.dumps({"b": 2}) + await server_raw.write(concatenated) + + decoded1 = await received_data_queue.get() + decoded2 = await received_data_queue.get() + assert decoded1 == {"a": 1} + assert decoded2 == {"b": 2} + assert [] == received_errors, f"Expected no errors, got: {received_errors}" + + async def test_cbor_stream_transformer_delayed_separate_writes(self, client_raw, server_raw): + cbor_stream_transformer = CborStreamTransformer() + client_transformed_pipe = cbor_stream_transformer.apply_transformer(client_raw) + + received_data_queue = TimeoutQueue() + received_errors = [] + client_transformed_pipe.on("data", received_data_queue.put_nowait) + client_transformed_pipe.on("error", lambda e: received_errors.append(e)) + + assert await server_raw.write(cbor2.dumps({"a": 1})) + decoded1 = await received_data_queue.get() + await asyncio.sleep(0.05) + assert await server_raw.write(cbor2.dumps({"b": 2})) + + decoded2 = await received_data_queue.get() + assert decoded1 == {"a": 1} + assert decoded2 == {"b": 2} + assert [] == received_errors, f"Expected no errors, got: {received_errors}" + async def test_cbor_stream_transformer_fragmented_objects(self, client_raw, server_raw): cbor_stream_transformer = CborStreamTransformer() client_transformed_pipe = cbor_stream_transformer.apply_transformer(client_raw) From 3f33e15584f7e65f8f33f9715bea1da4cd39b723 Mon Sep 17 00:00:00 2001 From: Sudip Bhattarai Date: Sun, 8 Feb 2026 18:55:32 +0545 Subject: [PATCH 25/25] Fix error propagation in inmemory pipe --- cbor_rpc/pipe/event_pipe.py | 5 ++++- cbor_rpc/pipe/pipe.py | 3 +++ cbor_rpc/transformer/base/transformer_base.py | 4 ++++ cbor_rpc/transformer/cbor_transformer.py | 2 +- tests/transformer/test_cbor_transformer.py | 15 ++++++--------- 5 files changed, 18 insertions(+), 11 deletions(-) diff --git a/cbor_rpc/pipe/event_pipe.py b/cbor_rpc/pipe/event_pipe.py index a6de80c..559e1ef 100644 --- a/cbor_rpc/pipe/event_pipe.py +++ b/cbor_rpc/pipe/event_pipe.py @@ -50,7 +50,9 @@ async def _deliver() -> None: await self.connected_pipe._notify("data", chunk) except Exception: # Receiver-side errors should not bubble to the writer. - return True + # Mirror AioPipe behavior by closing both ends on pipeline errors. + await self.connected_pipe.terminate() + return loop = asyncio.get_running_loop() loop.create_task(_deliver()) @@ -62,6 +64,7 @@ async def terminate(self, *args: Any) -> None: self._closed = True self._emit("close", *args) if self.connected_pipe and not self.connected_pipe._closed: + self.connected_pipe._closed = True self.connected_pipe._emit("close", *args) pipe1 = ConnectedPipe() diff --git a/cbor_rpc/pipe/pipe.py b/cbor_rpc/pipe/pipe.py index 7cc3150..bcb77df 100644 --- a/cbor_rpc/pipe/pipe.py +++ b/cbor_rpc/pipe/pipe.py @@ -12,6 +12,9 @@ class Pipe(AbstractEmitter, Generic[T1, T2], ABC): """ Abstract Pipe defining async event-based read/write/terminate interface. + We write T1 + We read T2. + Normally pipes will have T1=T2. if T1!=T2, then you need to use a transformer to convert between them. """ @abstractmethod diff --git a/cbor_rpc/transformer/base/transformer_base.py b/cbor_rpc/transformer/base/transformer_base.py index 7799c00..957f22d 100644 --- a/cbor_rpc/transformer/base/transformer_base.py +++ b/cbor_rpc/transformer/base/transformer_base.py @@ -13,6 +13,10 @@ # Sync Transformer (no async methods) class Transformer(Generic[T1, T2]): + """ + Transformer converts input of type T1 to output of type T2 and vice versa. + Write expects T1, read produces T1. The underlying pipe works with T2. + """ def __init__(self): super().__init__() self._closed = False diff --git a/cbor_rpc/transformer/cbor_transformer.py b/cbor_rpc/transformer/cbor_transformer.py index 87ce79a..1a77dc4 100644 --- a/cbor_rpc/transformer/cbor_transformer.py +++ b/cbor_rpc/transformer/cbor_transformer.py @@ -26,7 +26,7 @@ def _is_eof_error(exc: Exception) -> bool: return isinstance(exc, (cbor2.CBORDecodeEOF, IndexError)) or type(exc).__name__ == "CBORDecodeEOF" -class CborTransformer(Transformer[Any, Any]): +class CborTransformer(Transformer[Any, bytes]): """Encodes Python objects to CBOR bytes and decodes CBOR bytes back.""" def encode(self, data: Any) -> bytes: diff --git a/tests/transformer/test_cbor_transformer.py b/tests/transformer/test_cbor_transformer.py index df5e3cd..5b57a3b 100644 --- a/tests/transformer/test_cbor_transformer.py +++ b/tests/transformer/test_cbor_transformer.py @@ -71,24 +71,21 @@ async def test_cbor_transformer_end_to_end_simple_dict(self, client_raw, server_ decoded_data_received_by_client = await client_received_data_queue.get() assert decoded_data_received_by_client == response_data - async def test_cbor_transformer_decoding_error_on_read(self, server_raw, client_cbor): + async def test_cbor_transformer_decoding_error_on_read(self, server_raw, client_cbor,client_raw): client_transformed_pipe = client_cbor error_queue = TimeoutQueue() client_transformed_pipe.on("error", error_queue.put_nowait) incomplete_cbor_bytes = b"\x83\x01\x02" - ok = await server_raw.write(incomplete_cbor_bytes) - assert ok is True + assert await server_raw.write(incomplete_cbor_bytes) error = await asyncio.wait_for(error_queue.get(), timeout=1) assert type(error).__name__.startswith("CBORDecode") - - truly_invalid_cbor = b"\x1f" - ok = await server_raw.write(truly_invalid_cbor) - assert ok is True - error = await asyncio.wait_for(error_queue.get(), timeout=1) - assert type(error).__name__.startswith("CBORDecode") + # sleep 2 + await asyncio.sleep(0.2) + assert server_raw._closed is True, "Pipe should be closed on decode error" + assert server_raw._closed is True, "Pipe should be closed on decode error" async def test_cbor_transformer_non_bytes_data(self, server_raw, client_cbor): transformer = CborTransformer()