Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions __templates__/driver/jumpstarter_driver/driver.py.tmpl
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import logging
from dataclasses import dataclass

from jumpstarter.driver import Driver, export

logger = logging.getLogger(__name__)

@dataclass(kw_only=True)
class ${DRIVER_CLASS}(Driver):
"""${DRIVE_NAME} driver for Jumpstarter"""
Expand All @@ -22,10 +19,10 @@ class ${DRIVER_CLASS}(Driver):

@export
def method1(self):
logger.info("Method1 called")
self.logger.info("Method1 called")
return "method1 response"

@export
def method2(self):
logger.info("Method2 called")
self.logger.info("Method2 called")
return "method2 response"
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import logging
import os
import time
from collections.abc import AsyncGenerator
Expand All @@ -19,8 +18,6 @@

from jumpstarter.driver import Driver, export

log = logging.getLogger(__name__)


@dataclass(kw_only=True)
class DutlinkConfig:
Expand All @@ -35,7 +32,7 @@ def __post_init__(self):
for dev in usb.core.find(idVendor=0x2B23, idProduct=0x1012, find_all=True):
serial = usb.util.get_string(dev, dev.iSerialNumber)
if serial == self.serial or self.serial is None:
log.debug(f"found dutlink board with serial {serial}")
self.logger.debug(f"found dutlink board with serial {serial}")

self.serial = serial
self.dev = dev
Expand Down Expand Up @@ -78,9 +75,7 @@ def control(self, direction, ty, actions, action, value):

if direction == usb.ENDPOINT_IN:
str_value = bytes(res).decode("utf-8")
log.debug(
"ctrl_transfer result: %s",
)
self.logger.debug("ctrl_transfer result: %s", str_value)
return str_value


Expand All @@ -101,7 +96,7 @@ class DutlinkPower(DutlinkConfig, PowerInterface, Driver):
last_action: str | None = field(default=None)

def control(self, action):
log.debug(f"power control: {action}")
self.logger.debug(f"power control: {action}")
if self.last_action == action:
return

Expand Down Expand Up @@ -160,7 +155,7 @@ class DutlinkStorageMux(DutlinkConfig, StorageMuxInterface, Driver):
storage_device: str

def control(self, action):
log.debug(f"storage control: {action}")
self.logger.debug(f"storage control: {action}")
return super().control(
usb.ENDPOINT_OUT,
0x02,
Expand Down Expand Up @@ -190,9 +185,9 @@ def off(self):
async def wait_for_storage_device(self):
with fail_after(20):
while True:
log.debug(f"waiting for storage device {self.storage_device}")
self.logger.debug(f"waiting for storage device {self.storage_device}")
if os.path.exists(self.storage_device):
log.debug(f"storage device {self.storage_device} is ready")
self.logger.debug(f"storage device {self.storage_device} is ready")
# https://stackoverflow.com/a/2774125
fd = os.open(self.storage_device, os.O_WRONLY)
try:
Expand All @@ -213,7 +208,7 @@ async def write(self, src: str):
async for chunk in res:
await stream.send(chunk)
if total_bytes > next_print:
log.debug(f"{self.storage_device} written {total_bytes / (1024 * 1024)} MB")
self.logger.debug(f"{self.storage_device} written {total_bytes / (1024 * 1024)} MB")
next_print += 50 * 1024 * 1024
total_bytes += len(chunk)

Expand Down Expand Up @@ -255,15 +250,16 @@ def __post_init__(self):
super().__post_init__()

self.children["power"] = DutlinkPower(serial=self.serial, timeout_s=self.timeout_s)
self.children["storage"] = DutlinkStorageMux(serial=self.serial, storage_device=self.storage_device,
timeout_s=self.timeout_s)
self.children["storage"] = DutlinkStorageMux(
serial=self.serial, storage_device=self.storage_device, timeout_s=self.timeout_s
)

# if an alternate serial port has been requested, use it
if self.alternate_console is not None:
try:
self.children["console"] = PySerial(url=self.alternate_console, baudrate=self.baudrate)
except SerialException:
log.info(
self.logger.info(
f"failed to open alternate console {self.alternate_console} but trying to power on the target once"
)
self.children["power"].on()
Expand Down
60 changes: 28 additions & 32 deletions packages/jumpstarter-driver-http/jumpstarter_driver_http/driver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
Expand All @@ -10,8 +9,6 @@

from jumpstarter.driver import Driver, export

logger = logging.getLogger(__name__)


class HttpServerError(Exception):
"""Base exception for HTTP server errors"""
Expand All @@ -21,26 +18,12 @@ class FileWriteError(HttpServerError):
"""Exception raised when file writing fails"""


def get_default_ip():
try:
import socket

s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80))
ip = s.getsockname()[0]
s.close()
return ip
except Exception:
logger.warning("Could not determine default IP address, falling back to 0.0.0.0")
return "0.0.0.0"


@dataclass(kw_only=True)
class HttpServer(Driver):
"""HTTP Server driver for Jumpstarter"""

root_dir: str = "/var/www"
host: str = field(default_factory=get_default_ip)
host: str = field(default=None)
port: int = 8080
app: web.Application = field(init=False, default_factory=web.Application)
runner: Optional[web.AppRunner] = field(init=False, default=None)
Expand All @@ -53,6 +36,19 @@ def __post_init__(self):
web.get("/{filename}", self.get_file),
]
)
if self.host is None:
self.host = self.get_default_ip()

def get_default_ip(self):
try:
import socket

with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(("8.8.8.8", 80))
return s.getsockname()[0]
except Exception:
self.logger.warning("Could not determine default IP address, falling back to 0.0.0.0")
return "0.0.0.0"

@classmethod
def client(cls) -> str:
Expand Down Expand Up @@ -86,11 +82,11 @@ async def put_file(self, filename: str, src_stream) -> str:
async for chunk in src:
await dst.send(chunk)

logger.info(f"File '{filename}' written to '{file_path}'")
self.logger.info(f"File '{filename}' written to '{file_path}'")
return f"{self.get_url()}/{filename}"

except Exception as e:
logger.error(f"Failed to upload file '{filename}': {e}")
self.logger.error(f"Failed to upload file '{filename}': {e}")
raise FileWriteError(f"Failed to upload file '{filename}': {e}") from e

@export
Expand All @@ -112,10 +108,10 @@ async def delete_file(self, filename: str) -> str:
raise HttpServerError(f"File '{filename}' does not exist.")
try:
file_path.unlink()
logger.info(f"File '{filename}' has been deleted.")
self.logger.info(f"File '{filename}' has been deleted.")
return filename
except Exception as e:
logger.error(f"Failed to delete file '{filename}': {e}")
self.logger.error(f"Failed to delete file '{filename}': {e}")
raise HttpServerError(f"Failed to delete file '{filename}': {e}") from e

async def get_file(self, request) -> web.FileResponse:
Expand All @@ -134,9 +130,9 @@ async def get_file(self, request) -> web.FileResponse:
filename = request.match_info["filename"]
file_path = os.path.join(self.root_dir, filename)
if not os.path.isfile(file_path):
logger.warning(f"File not found: {file_path}")
self.logger.warning(f"File not found: {file_path}")
raise web.HTTPNotFound(text=f"File '{filename}' not found.")
logger.info(f"Serving file: {file_path}")
self.logger.info(f"Serving file: {file_path}")
return web.FileResponse(file_path)

@export
Expand All @@ -155,7 +151,7 @@ def list_files(self) -> list[str]:
files = [f for f in files if os.path.isfile(os.path.join(self.root_dir, f))]
return files
except Exception as e:
logger.error(f"Failed to list files: {e}")
self.logger.error(f"Failed to list files: {e}")
raise HttpServerError(f"Failed to list files: {e}") from e

@export
Expand All @@ -167,7 +163,7 @@ async def start(self):
HttpServerError: If the server fails to start.
"""
if self.runner is not None:
logger.warning("HTTP server is already running.")
self.logger.warning("HTTP server is already running.")
return

self.runner = web.AppRunner(self.app)
Expand All @@ -176,7 +172,7 @@ async def start(self):

site = web.TCPSite(self.runner, self.host, self.port)
await site.start()
logger.info(f"HTTP server started at http://{self.host}:{self.port}")
self.logger.info(f"HTTP server started at http://{self.host}:{self.port}")

@export
async def stop(self):
Expand All @@ -187,11 +183,11 @@ async def stop(self):
HttpServerError: If the server fails to stop.
"""
if self.runner is None:
logger.warning("HTTP server is not running.")
self.logger.warning("HTTP server is not running.")
return

await self.runner.cleanup()
logger.info("HTTP server stopped.")
self.logger.info("HTTP server stopped.")
self.runner = None

@export
Expand Down Expand Up @@ -230,7 +226,7 @@ def close(self):
if anyio.get_current_task():
anyio.from_thread.run(self._async_cleanup)
except Exception as e:
logger.warning(f"HTTP server cleanup failed synchronously: {e}")
self.logger.warning(f"HTTP server cleanup failed synchronously: {e}")
self.runner = None
super().close()

Expand All @@ -239,6 +235,6 @@ async def _async_cleanup(self):
if self.runner:
await self.runner.shutdown()
await self.runner.cleanup()
logger.info("HTTP server cleanup completed asynchronously.")
self.logger.info("HTTP server cleanup completed asynchronously.")
except Exception as e:
logger.error(f"HTTP server cleanup failed asynchronously: {e}")
self.logger.error(f"HTTP server cleanup failed asynchronously: {e}")
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
from abc import ABCMeta, abstractmethod
from contextlib import asynccontextmanager
from dataclasses import dataclass
Expand All @@ -13,8 +12,6 @@

from jumpstarter.driver import Driver, exportstream

logger = logging.getLogger(__name__)


class NetworkInterface(metaclass=ABCMeta):
@classmethod
Expand All @@ -34,7 +31,7 @@ class TcpNetwork(NetworkInterface, Driver):
@exportstream
@asynccontextmanager
async def connect(self):
logger.debug("Connecting TCP host=%s port=%d", self.host, self.port)
self.logger.debug("Connecting TCP host=%s port=%d", self.host, self.port)
async with await connect_tcp(remote_host=self.host, remote_port=self.port) as stream:
yield stream

Expand All @@ -47,7 +44,7 @@ class UdpNetwork(NetworkInterface, Driver):
@exportstream
@asynccontextmanager
async def connect(self):
logger.debug("Connecting UDP host=%s port=%d", self.host, self.port)
self.logger.debug("Connecting UDP host=%s port=%d", self.host, self.port)
async with await create_connected_udp_socket(remote_host=self.host, remote_port=self.port) as stream:
yield stream

Expand All @@ -59,7 +56,7 @@ class UnixNetwork(NetworkInterface, Driver):
@exportstream
@asynccontextmanager
async def connect(self):
logger.debug("Connecting UDS path=%s", self.path)
self.logger.debug("Connecting UDS path=%s", self.path)
async with await connect_unix(path=self.path) as stream:
yield stream

Expand All @@ -69,6 +66,6 @@ class EchoNetwork(NetworkInterface, Driver):
@asynccontextmanager
async def connect(self):
tx, rx = create_memory_object_stream[bytes](32)
logger.debug("Connecting Echo")
self.logger.debug("Connecting Echo")
async with StapledObjectStream(tx, rx) as stream:
yield stream
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
from contextlib import asynccontextmanager
from dataclasses import dataclass, field

Expand All @@ -8,8 +7,6 @@

from jumpstarter.driver import Driver, exportstream

log = logging.getLogger(__name__)


@dataclass(kw_only=True)
class AsyncSerial(ObjectStream):
Expand Down Expand Up @@ -46,8 +43,8 @@ def client(cls) -> str:
@exportstream
@asynccontextmanager
async def connect(self):
log.info("Connecting to %s, baudrate: %d", self.url, self.baudrate)
self.logger.info("Connecting to %s, baudrate: %d", self.url, self.baudrate)
device = await run_sync(serial_for_url, self.url, self.baudrate)
async with AsyncSerial(device=device) as stream:
yield stream
log.info("Disconnected from %s", self.url)
self.logger.info("Disconnected from %s", self.url)
Loading