diff --git a/packages/jumpstarter-driver-corellium/examples/exporter.yml b/packages/jumpstarter-driver-corellium/examples/exporter.yml index e87370263..e5e2a1972 100644 --- a/packages/jumpstarter-driver-corellium/examples/exporter.yml +++ b/packages/jumpstarter-driver-corellium/examples/exporter.yml @@ -16,3 +16,4 @@ export: # optional device_os: "1.1.1" device_build: "Critical Application Monitor (Baremetal)" + console_name: "Primary Compute Non-Secure" diff --git a/packages/jumpstarter-driver-corellium/fixtures/http/get-instance-200.json b/packages/jumpstarter-driver-corellium/fixtures/http/get-instance-200.json new file mode 100644 index 000000000..4e39d2177 --- /dev/null +++ b/packages/jumpstarter-driver-corellium/fixtures/http/get-instance-200.json @@ -0,0 +1,84 @@ +{ + "id": "7f4f241c-821f-4219-905f-c3b50b0db5dd", + "name": "my-instance", + "key": "key", + "flavor": "kronos", + "flavorName": "RD-1AE", + "type": "iot", + "project": "778f00af-5e9b-40e6-8e7f-c4f14b632e9c", + "domainName": "my-domain", + "state": "on", + "stateChanged": "9999-12-31T00:00:00.000Z", + "startedAt": "9999-12-31T00:00:00.000Z", + "userTask": null, + "taskState": "none", + "restoreStatus": {}, + "error": null, + "bootOptions": { + "bootArgs": "", + "uuid": "uuid", + "ecid": "ecid", + "noSnapshotMount": true + }, + "serviceIp": "0.0.0.0", + "wifiIp": "0.0.0.0", + "wifiMac": "mac", + "secondaryIp": null, + "port-gdb": "1234", + "port-usbmuxd": "27015", + "port-1-cons": "4001", + "port-2-cons": "4002", + "consoles": [ + { + "name": "Console 1", + "id": "port-1-cons", + "port": "4001", + "hash": "post-1-cons-hash", + "info": "port-1-cons-info" + }, + { + "name": "Console 2", + "id": "port-2-cons", + "port": "4002", + "hash": "port-2-cons-hash", + "info": "port-2-cons-info" + } + ], + "services": { + "vpn": { + "proxy": [], + "listeners": [] + } + }, + "panicked": false, + "created": "9999-12-31T00:00:00.000Z", + "os": "1.1", + "agent": null, + "netmon": { + "hash": "netmon-hash", + "info": "netmon-info", + "enabled": false + }, + "netdump": { + "hash": "netdump-hash", + "info": "netdump-info", + "enabled": false + }, + "coreTrace": { + "enabled": null + }, + "hyperTrace": { + "enabled": null + }, + "exposePort": null, + "fault": null, + "patches": [ + "jailbroken" + ], + "createdBy": { + "id": "2eadc0a4-bd23-44dd-9a2c-0aef784d0c43", + "username": "some@user.email", + "label": "full name" + }, + "mast": false +} diff --git a/packages/jumpstarter-driver-corellium/fixtures/http/get-instance-state-404.json b/packages/jumpstarter-driver-corellium/fixtures/http/get-instance-404.json similarity index 100% rename from packages/jumpstarter-driver-corellium/fixtures/http/get-instance-state-404.json rename to packages/jumpstarter-driver-corellium/fixtures/http/get-instance-404.json diff --git a/packages/jumpstarter-driver-corellium/fixtures/http/get-instance-console-url-200.json b/packages/jumpstarter-driver-corellium/fixtures/http/get-instance-console-url-200.json new file mode 100644 index 000000000..15c5cc724 --- /dev/null +++ b/packages/jumpstarter-driver-corellium/fixtures/http/get-instance-console-url-200.json @@ -0,0 +1,3 @@ +{ + "url": "wss://api-host/port-cons-1" +} diff --git a/packages/jumpstarter-driver-corellium/fixtures/http/get-instance-console-url-400.json b/packages/jumpstarter-driver-corellium/fixtures/http/get-instance-console-url-400.json new file mode 100644 index 000000000..b8265f291 --- /dev/null +++ b/packages/jumpstarter-driver-corellium/fixtures/http/get-instance-console-url-400.json @@ -0,0 +1,5 @@ +{ + "error":"not a valid type", + "errorID":"UserError", + "field":"type" +} diff --git a/packages/jumpstarter-driver-corellium/jumpstarter_driver_corellium/corellium/api.py b/packages/jumpstarter-driver-corellium/jumpstarter_driver_corellium/corellium/api.py index ec21fc99a..44d85ae27 100644 --- a/packages/jumpstarter-driver-corellium/jumpstarter_driver_corellium/corellium/api.py +++ b/packages/jumpstarter-driver-corellium/jumpstarter_driver_corellium/corellium/api.py @@ -38,16 +38,19 @@ def login(self) -> None: It uses the global requests objects so a new session can be generated. """ - data = { + data = None + req_data = { 'apiToken': self.token } try: - res = requests.post(f'{self.baseurl}/v1/auth/login', json=data) + res = requests.post(f'{self.baseurl}/v1/auth/login', json=req_data) data = res.json() res.raise_for_status() except (requests.exceptions.RequestException, requests.exceptions.HTTPError) as e: - raise CorelliumApiException(data.get('error', str(e))) from e + msgerr = data.get('error') if data is not None else str(e) + + raise CorelliumApiException(msgerr) from e self.session = Session(**data) self.req.headers.update(self.session.as_header()) @@ -56,12 +59,16 @@ def get_project(self, project_ref: str = 'Default Project') -> Optional[Project] """ Retrieve a project based on project_ref, which is either its id or name. """ + data = None + try: res = self.req.get(f'{self.baseurl}/v1/projects') data = res.json() res.raise_for_status() except requests.exceptions.RequestException as e: - raise CorelliumApiException(data.get('error', str(e))) from e + msgerr = data.get('error') if data is not None else str(e) + + raise CorelliumApiException(msgerr) from e for project in data: if project['name'] == project_ref or project['id'] == project_ref: @@ -75,12 +82,16 @@ def get_device(self, model: str) -> Optional[Device]: A device object is used to create a new virtual instance. """ + data = None + try: res = self.req.get(f'{self.baseurl}/v1/models') data = res.json() res.raise_for_status() except requests.exceptions.RequestException as e: - raise CorelliumApiException(data.get('error', str(e))) from e + msgerr = data.get('error') if data is not None else str(e) + + raise CorelliumApiException(msgerr) from e for device in data: if device['model'] == model: @@ -92,7 +103,8 @@ def create_instance(self, name: str, project: Project, device: Device, os_versio """ Create a new virtual instance from a device spec. """ - data = { + data = None + req_data = { 'name': name, 'project': project.id, 'flavor': device.flavor, @@ -101,11 +113,13 @@ def create_instance(self, name: str, project: Project, device: Device, os_versio } try: - res = self.req.post(f'{self.baseurl}/v1/instances', json=data) + res = self.req.post(f'{self.baseurl}/v1/instances', json=req_data) data = res.json() res.raise_for_status() except requests.exceptions.RequestException as e: - raise CorelliumApiException(data.get('error', str(e))) from e + msgerr = data.get('error') if data is not None else str(e) + + raise CorelliumApiException(msgerr) from e return Instance(**data) @@ -115,12 +129,16 @@ def get_instance(self, instance_ref: str) -> Optional[Instance]: Return None if it does not exist. """ + data = None + try: res = self.req.get(f'{self.baseurl}/v1/instances') data = res.json() res.raise_for_status() except requests.exceptions.RequestException as e: - raise CorelliumApiException(data.get('error', str(e))) from e + msgerr = data.get('error') if data is not None else str(e) + + raise CorelliumApiException(msgerr) from e for instance in data: if instance['name'] == instance_ref or instance['id'] == instance_ref: @@ -144,16 +162,17 @@ def set_instance_state(self, instance: Instance, instance_state: str) -> None: - rebooting - error """ - data = { + data = None + req_data = { 'state': instance_state } try: - res = self.req.put(f'{self.baseurl}/v1/instances/{instance.id}/state', json=data) + res = self.req.put(f'{self.baseurl}/v1/instances/{instance.id}/state', json=req_data) data = res.json() if res.status_code != 204 else None res.raise_for_status() except requests.exceptions.RequestException as e: - msgerr = data if data is not None else str(e) + msgerr = data.get('error') if data is not None else str(e) raise CorelliumApiException(msgerr) from e @@ -168,6 +187,47 @@ def destroy_instance(self, instance: Instance) -> None: data = res.json() if res.status_code != 204 else None res.raise_for_status() except requests.exceptions.RequestException as e: - msgerr = data if data is not None else str(e) + msgerr = data.get('error') if data is not None else str(e) raise CorelliumApiException(msgerr) from e + + def get_instance_console_id(self, instance: Instance, console_name: str) -> Optional[str]: + """ + Retrieve an instance's console id by its name. + + Return None if it does not exist. + """ + data = None + + try: + res = self.req.get(f'{self.baseurl}/v1/instances/{instance.id}') + data = res.json() + res.raise_for_status() + except requests.exceptions.RequestException as e: + msgerr = data.get('error') if data is not None else str(e) + + raise CorelliumApiException(msgerr) from e + + for console in data.get('consoles', []): + if console['name'] == console_name: + return console['id'] + + return None + + def get_instance_console_url(self, instance: Instance, console_id: str) -> Optional[str]: + """ + Get a a console URL (websocket) to stream logs from. + """ + data = None + + try: + res = self.req.get(f'{self.baseurl}/v1/instances/{instance.id}/console', + params={'type': console_id.replace('port-', '')}) + data = res.json() + res.raise_for_status() + except requests.exceptions.RequestException as e: + msgerr = data.get('error') if data is not None else str(e) + + raise CorelliumApiException(msgerr) from e + + return data['url'] diff --git a/packages/jumpstarter-driver-corellium/jumpstarter_driver_corellium/corellium/api_test.py b/packages/jumpstarter-driver-corellium/jumpstarter_driver_corellium/corellium/api_test.py index b342b2ee9..7825b2460 100644 --- a/packages/jumpstarter-driver-corellium/jumpstarter_driver_corellium/corellium/api_test.py +++ b/packages/jumpstarter-driver-corellium/jumpstarter_driver_corellium/corellium/api_test.py @@ -2,6 +2,7 @@ import pytest +# import websockets from .api import ApiClient from .exceptions import CorelliumApiException from .types import Device, Instance, Project, Session @@ -167,7 +168,7 @@ def test_destroy_instance_state_ok(requests_mock): 'status_code,data,msg', [ (403, fixture('http/403.json'), 'Invalid or missing authorization token'), - (404, fixture('http/get-instance-state-404.json'), 'No instance associated with this value'), + (404, fixture('http/get-instance-404.json'), 'No instance associated with this value'), ]) def test_destroy_instance_error(requests_mock, status_code, data, msg): instance = Instance(id='d59db33d-27bd-4b22-878d-49e4758a648e') @@ -180,3 +181,77 @@ def test_destroy_instance_error(requests_mock, status_code, data, msg): api.destroy_instance(instance) assert msg in str(e.value) + + +@pytest.mark.parametrize( + 'console_name,console_id', + [ + ('Console 1', 'port-1-cons',), + ('Console 10', None,), + ('Console 2', 'port-2-cons',), + ]) +def test_get_instance_console_id_ok(requests_mock, console_name, console_id): + data = fixture('http/get-instance-200.json') + instance = Instance(id='d59db33d-27bd-4b22-878d-49e4758a648e') + requests_mock.get(f'https://api-host/api/v1/instances/{instance.id}', status_code=200, text=data) + api = ApiClient('api-host', 'api-token') + api.session = Session('session-token', '2022-03-20T01:50:10.000Z') + + current = api.get_instance_console_id(instance, console_name) + + assert console_id == current + + +@pytest.mark.parametrize( + 'status_code,data,msg', + [ + (403, fixture('http/403.json'), 'Invalid or missing authorization token'), + (404, fixture('http/get-instance-404.json'), 'No instance associated with this value'), + ]) +def test_get_instance_console_id_error(requests_mock, status_code, data, msg): + instance = Instance(id='d59db33d-27bd-4b22-878d-49e4758a648e') + requests_mock.get(f'https://api-host/api/v1/instances/{instance.id}', + status_code=status_code, text=data) + api = ApiClient('api-host', 'api-token') + api.session = Session('session-token', '2022-03-20T01:50:10.000Z') + + with pytest.raises(CorelliumApiException) as e: + api.get_instance_console_id(instance, 'Console 1') + + assert msg in str(e.value) + + +def test_get_instance_console_url_ok(requests_mock): + console_id = 'port-1-cons' + data = fixture('http/get-instance-console-url-200.json') + instance = Instance(id='d59db33d-27bd-4b22-878d-49e4758a648e') + requests_mock.get(f'https://api-host/api/v1/instances/{instance.id}/console?type=1-cons', + status_code=200, text=data) + api = ApiClient('api-host', 'api-token') + api.session = Session('session-token', '2022-03-20T01:50:10.000Z') + + current = api.get_instance_console_url(instance, console_id) + expected = 'wss://api-host/port-cons-1' + + assert expected == current + + +@pytest.mark.parametrize( + 'status_code,data,msg', + [ + (400, fixture('http/get-instance-console-url-400.json'), 'not a valid type'), + (403, fixture('http/403.json'), 'Invalid or missing authorization token'), + (404, fixture('http/get-instance-404.json'), 'No instance associated with this value'), + ]) +def test_get_instance_console_url_error(requests_mock, status_code, data, msg): + console_id = 'port-1-cons' + instance = Instance(id='d59db33d-27bd-4b22-878d-49e4758a648e') + requests_mock.get(f'https://api-host/api/v1/instances/{instance.id}/console?type=1-cons', + status_code=status_code, text=data) + api = ApiClient('api-host', 'api-token') + api.session = Session('session-token', '2022-03-20T01:50:10.000Z') + + with pytest.raises(CorelliumApiException) as e: + api.get_instance_console_url(instance, console_id) + + assert msg in str(e.value) diff --git a/packages/jumpstarter-driver-corellium/jumpstarter_driver_corellium/driver.py b/packages/jumpstarter-driver-corellium/jumpstarter_driver_corellium/driver.py index fbb76332e..8058f2001 100644 --- a/packages/jumpstarter-driver-corellium/jumpstarter_driver_corellium/driver.py +++ b/packages/jumpstarter-driver-corellium/jumpstarter_driver_corellium/driver.py @@ -8,6 +8,7 @@ from datetime import datetime, timedelta from typing import Dict, Optional +from jumpstarter_driver_network.driver import WebsocketNetwork from jumpstarter_driver_power.driver import PowerReading, VirtualPowerInterface from .corellium.api import ApiClient @@ -27,6 +28,7 @@ class Corellium(Driver): device_flavor: str device_os: str = field(default='1.1.1') device_build: str = field(default='Critical Application Monitor (Baremetal)') + console_name: str = field(default='Primary Compute Non-Secure') @classmethod def client(cls) -> str: @@ -59,6 +61,7 @@ def __post_init__(self) -> None: self._api = ApiClient(api_host, api_token) self.children['power'] = CorelliumPower(parent=self) + self.children['serial'] = CorelliumConsole(parent=self, url='') def get_env_var(self, name: str) -> str: """ @@ -201,3 +204,49 @@ def off(self, destroy: bool = False) -> None: @export def read(self) -> AsyncGenerator[PowerReading, None]: pass + + +@dataclass(kw_only=True) +class CorelliumConsole(WebsocketNetwork): + """ + A serial console driver that uses a network websocket to connect + to the remote virtual instance console. + """ + parent: Corellium + baudrate: int = field(default=115200) + + @classmethod + def client(cls) -> str: + """ + Use pyserial client to re-use its console implementation. + """ + return "jumpstarter_driver_pyserial.client.PySerialClient" + + @property + def url(self) -> str: + """ + Retrieve console url from Corellium's API to be used by + other drivers, such as the PySerial one. + + Overwrites the parent's url property so the required + API calls are only invoked when the property is used. + """ + project = self.parent.api.get_project(self.parent.project_id) + if project is None: + raise ValueError(f"Unable to fetch project: {self.parent.project_id}") + + # get instance and fail if instance does not exist + instance = self.parent.api.get_instance(self.parent.device_name) + if instance is None: + raise ValueError("Instance does not exist or is powered off") + + console_id = self.parent.api.get_instance_console_id(instance, self.parent.console_name) + if console_id is None: + raise ValueError("Console ID not found for \"{self.patent.console_name}\"") + console_url = self.parent.api.get_instance_console_url(instance, console_id) + + return console_url + + @url.setter + def url(self, value: str) -> None: + pass diff --git a/packages/jumpstarter-driver-corellium/jumpstarter_driver_corellium/driver_test.py b/packages/jumpstarter-driver-corellium/jumpstarter_driver_corellium/driver_test.py index 2feb97c57..acc55c4a9 100644 --- a/packages/jumpstarter-driver-corellium/jumpstarter_driver_corellium/driver_test.py +++ b/packages/jumpstarter-driver-corellium/jumpstarter_driver_corellium/driver_test.py @@ -4,7 +4,7 @@ from .corellium.exceptions import CorelliumApiException from .corellium.types import Device, Instance, Project, Session -from .driver import Corellium, CorelliumPower +from .driver import Corellium, CorelliumConsole, CorelliumPower from jumpstarter.common import exceptions as jmp_exceptions @@ -154,3 +154,50 @@ def test_driver_power_off_error(monkeypatch, mock_data): patch.object(root._api, 'destroy_instance', **mock_data.get('destroy_instance', {'return_value': instance}))): power.off() + + +def test_driver_console_get_url_ok(monkeypatch): + monkeypatch.setenv('CORELLIUM_API_HOST', 'api-host') + monkeypatch.setenv('CORELLIUM_API_TOKEN', 'api-token') + + project = Project('1', 'Default Project') + instance = Instance(id='7f4f241c-821f-4219-905f-c3b50b0db5dd', state='on') + root = Corellium(project_id='1', device_name='jmp', device_flavor='kronos', device_os='1.0') + console = CorelliumConsole(parent=root, url='') + + with (patch.object(root._api, 'login', return_value=None), + patch.object(root._api, 'get_project', return_value=project), + patch.object(root._api, 'get_instance', return_value=instance), + patch.object(root._api, 'get_instance_console_id', return_value='uart7-cons'), + patch.object(root._api, 'get_instance_console_url', return_value='wss://mock')): + assert 'wss://mock' == console.url + + +@pytest.mark.parametrize('mock_data',[ + ({'login': {'side_effect': CorelliumApiException('login error')}}), + ({'get_project': {'return_value': None}}), + ({'get_instance': {'return_value': None}}), + ({'get_instance_console_id': {'side_effect': ValueError('x')}}), + ({'get_instance_console_url': {'side_effect': ValueError('x')}}) +]) +def test_driver_console_get_url_error(monkeypatch, mock_data): + monkeypatch.setenv('CORELLIUM_API_HOST', 'api-host') + monkeypatch.setenv('CORELLIUM_API_TOKEN', 'api-token') + + project = Project('1', 'Default Project') + instance = Instance(id='7f4f241c-821f-4219-905f-c3b50b0db5dd', state='on') + root = Corellium(project_id='1', device_name='jmp', device_flavor='kronos', device_os='1.0') + console = CorelliumConsole(parent=root, url='') + + with pytest.raises((CorelliumApiException, ValueError)): + with (patch.object(root._api, 'login', + **mock_data.get('login', {'return_value': None})), + patch.object(root._api, 'get_project', + **mock_data.get('get_project', {'return_value': project})), + patch.object(root._api, 'get_instance', + **mock_data.get('get_instance', {'side_effect': [instance, None]})), + patch.object(root._api, 'get_instance_console_id', + **mock_data.get('get_instance_console_id', {'return_value': 'uart7-cons'})), + patch.object(root._api, 'get_instance_console_id', + **mock_data.get('get_instance_console_url', {'return_value': 'ws://mock'}))): + assert console.url diff --git a/packages/jumpstarter-driver-corellium/pyproject.toml b/packages/jumpstarter-driver-corellium/pyproject.toml index 65b13a09f..c702a84e1 100644 --- a/packages/jumpstarter-driver-corellium/pyproject.toml +++ b/packages/jumpstarter-driver-corellium/pyproject.toml @@ -10,7 +10,8 @@ dependencies = [ "jumpstarter", "jumpstarter-driver-composite", "jumpstarter-driver-power", - "pyserial>=3.5", + "jumpstarter-driver-network", + "jumpstarter-driver-pyserial", "asyncclick>=8.1.7.2", ] @@ -18,7 +19,7 @@ dependencies = [ Corellium = "jumpstarter_driver_corellium.driver:Corellium" [dependency-groups] -dev = ["pytest>=8.3.2", "pytest-cov>=5.0.0", "trio>=0.28.0", "requests_mock"] +dev = ["pytest>=8.3.2", "pytest-cov>=5.0.0", "trio>=0.28.0", "requests_mock", "pytest-asyncio>=0.25.3"] [tool.hatch.metadata.hooks.vcs.urls] Homepage = "https://jumpstarter.dev" diff --git a/packages/jumpstarter-driver-network/jumpstarter_driver_network/driver.py b/packages/jumpstarter-driver-network/jumpstarter_driver_network/driver.py index 2db866463..f49825249 100644 --- a/packages/jumpstarter-driver-network/jumpstarter_driver_network/driver.py +++ b/packages/jumpstarter-driver-network/jumpstarter_driver_network/driver.py @@ -7,6 +7,7 @@ from os import getenv, getuid from typing import ClassVar, Literal +import websockets from anyio import ( connect_tcp, connect_unix, @@ -16,6 +17,7 @@ from anyio._backends._asyncio import SocketStream, StreamProtocol from anyio.streams.stapled import StapledObjectStream +from .streams.websocket import WebsocketClientStream from jumpstarter.driver import Driver, exportstream @@ -235,3 +237,25 @@ async def connect(self): self.logger.debug("Connecting Echo") async with StapledObjectStream(tx, rx) as stream: yield stream + + +@dataclass(kw_only=True) +class WebsocketNetwork(NetworkInterface, Driver): + ''' + Handles websocket connections from a given url. + ''' + url: str + + @exportstream + @asynccontextmanager + async def connect(self): + ''' + Create a websocket connection to `self.url` and srreams its output. + ''' + self.logger.info("Connecting to %s", self.url) + + async with websockets.connect(self.url) as websocket: + async with WebsocketClientStream(conn=websocket) as stream: + yield stream + + self.logger.info("Disconnected from %s", self.url) diff --git a/packages/jumpstarter-driver-network/jumpstarter_driver_network/driver_test.py b/packages/jumpstarter-driver-network/jumpstarter_driver_network/driver_test.py index d1ca48226..8bbd6fd65 100644 --- a/packages/jumpstarter-driver-network/jumpstarter_driver_network/driver_test.py +++ b/packages/jumpstarter-driver-network/jumpstarter_driver_network/driver_test.py @@ -3,12 +3,13 @@ import subprocess import sys from shutil import which +from unittest.mock import AsyncMock, patch import pytest from anyio.from_thread import start_blocking_portal from .adapters import TcpPortforwardAdapter, UnixPortforwardAdapter -from .driver import DbusNetwork, TcpNetwork, UdpNetwork, UnixNetwork +from .driver import DbusNetwork, TcpNetwork, UdpNetwork, UnixNetwork, WebsocketNetwork from jumpstarter.common import TemporaryUnixListener from jumpstarter.common.utils import serve @@ -141,3 +142,14 @@ def test_dbus_network_session(monkeypatch): stderr=subprocess.PIPE, ) assert oldvar == os.getenv("DBUS_SESSION_BUS_ADDRESS") + + +@pytest.mark.asyncio +async def test_websocket_network_connect(): + ws = AsyncMock() + ws.__aenter__.return_value = ws + + with patch("websockets.connect", return_value=ws) as m: + client = WebsocketNetwork(url="ws://localhost/something") + async with client.connect(): + m.assert_called_once_with("ws://localhost/something") diff --git a/packages/jumpstarter-driver-network/jumpstarter_driver_network/streams/websocket.py b/packages/jumpstarter-driver-network/jumpstarter_driver_network/streams/websocket.py index e79ed07c8..48085fd80 100644 --- a/packages/jumpstarter-driver-network/jumpstarter_driver_network/streams/websocket.py +++ b/packages/jumpstarter-driver-network/jumpstarter_driver_network/streams/websocket.py @@ -5,6 +5,7 @@ from anyio import BrokenResourceError, WouldBlock, create_memory_object_stream from anyio.abc import AnyByteStream, ObjectStream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from websockets.asyncio.client import ClientConnection as WSSClientConnection from wsproto import ConnectionType, WSConnection from wsproto.connection import ConnectionState from wsproto.events import ( @@ -67,3 +68,24 @@ async def aclose(self): with suppress(LocalProtocolError): await self.stream.send(self.ws.send(CloseConnection(code=CloseReason.NORMAL_CLOSURE))) await self.stream.aclose() + + +@dataclass(kw_only=True) +class WebsocketClientStream(ObjectStream[bytes]): + ''' + Websocket client streaming. + ''' + conn: WSSClientConnection + + async def send(self, data: bytes) -> None: + await self.conn.send(data) + + async def receive(self) -> bytes: + return await self.conn.recv() + + async def send_eof(self): + pass + + async def aclose(self): + await self.conn.close() + diff --git a/packages/jumpstarter-driver-network/pyproject.toml b/packages/jumpstarter-driver-network/pyproject.toml index e5fcb3600..5853a8372 100644 --- a/packages/jumpstarter-driver-network/pyproject.toml +++ b/packages/jumpstarter-driver-network/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "fabric>=3.2.2", "wsproto>=1.2.0", "asyncclick>=8.1.8", + "websockets>=15.0.1" ] [project.entry-points."jumpstarter.drivers"] @@ -32,6 +33,7 @@ Novnc = "jumpstarter_driver_network.adapters:NovncAdapter" [dependency-groups] dev = [ "pytest>=8.3.2", + "pytest-asyncio>=0.26.0", "pytest-cov>=5.0.0", "types-paramiko>=3.5.0.20240928", "types-pexpect>=4.9.0.20241208", diff --git a/uv.lock b/uv.lock index ed9d2fc7b..426439835 100644 --- a/uv.lock +++ b/uv.lock @@ -1227,13 +1227,15 @@ dependencies = [ { name = "asyncclick" }, { name = "jumpstarter" }, { name = "jumpstarter-driver-composite" }, + { name = "jumpstarter-driver-network" }, { name = "jumpstarter-driver-power" }, - { name = "pyserial" }, + { name = "jumpstarter-driver-pyserial" }, ] [package.dev-dependencies] dev = [ { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "pytest-cov" }, { name = "requests-mock" }, { name = "trio" }, @@ -1244,13 +1246,15 @@ requires-dist = [ { name = "asyncclick", specifier = ">=8.1.7.2" }, { name = "jumpstarter", editable = "packages/jumpstarter" }, { name = "jumpstarter-driver-composite", editable = "packages/jumpstarter-driver-composite" }, + { name = "jumpstarter-driver-network", editable = "packages/jumpstarter-driver-network" }, { name = "jumpstarter-driver-power", editable = "packages/jumpstarter-driver-power" }, - { name = "pyserial", specifier = ">=3.5" }, + { name = "jumpstarter-driver-pyserial", editable = "packages/jumpstarter-driver-pyserial" }, ] [package.metadata.requires-dev] dev = [ { name = "pytest", specifier = ">=8.3.2" }, + { name = "pytest-asyncio", specifier = ">=0.25.3" }, { name = "pytest-cov", specifier = ">=5.0.0" }, { name = "requests-mock" }, { name = "trio", specifier = ">=0.28.0" }, @@ -1379,12 +1383,14 @@ dependencies = [ { name = "fabric" }, { name = "jumpstarter" }, { name = "pexpect" }, + { name = "websockets" }, { name = "wsproto" }, ] [package.dev-dependencies] dev = [ { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "pytest-cov" }, { name = "types-paramiko" }, { name = "types-pexpect" }, @@ -1397,12 +1403,14 @@ requires-dist = [ { name = "fabric", specifier = ">=3.2.2" }, { name = "jumpstarter", editable = "packages/jumpstarter" }, { name = "pexpect", specifier = ">=4.9.0" }, + { name = "websockets", specifier = ">=15.0.1" }, { name = "wsproto", specifier = ">=1.2.0" }, ] [package.metadata.requires-dev] dev = [ { name = "pytest", specifier = ">=8.3.2" }, + { name = "pytest-asyncio", specifier = ">=0.26.0" }, { name = "pytest-cov", specifier = ">=5.0.0" }, { name = "types-paramiko", specifier = ">=3.5.0.20240928" }, { name = "types-pexpect", specifier = ">=4.9.0.20241208" },