diff --git a/.gitignore b/.gitignore index 4458690..ce487ef 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ # pixi environments .pixi *.egg-info +site \ No newline at end of file diff --git a/docs/api/core.md b/docs/api/core.md index a7bbc07..dccc60d 100644 --- a/docs/api/core.md +++ b/docs/api/core.md @@ -2,13 +2,13 @@ ::: drone_controllers.core -The core module provides the foundational functionality for controller parametrization and registration. +The core module provides the foundational functionality for controller parametrization. ## Key Concepts ### Controller Parametrization -The `parametrize` function allows you to automatically configure controllers with parameters for specific drone models: +The `parametrize` function automatically configures a controller with parameters for a specific drone model by inspecting the function's keyword-only arguments and filling them from the corresponding TOML file: ```python from drone_controllers import parametrize @@ -18,38 +18,29 @@ from drone_controllers.mellinger import state2attitude controller = parametrize(state2attitude, "cf2x_L250") # Use the controller (all parameters are automatically filled in) -rpyt, pos_err = controller(pos, quat, vel, ang_vel, cmd) +rpyt, pos_err = controller(pos, quat, vel, cmd) ``` -### Parameter Registry +### Manual Parameter Loading -Controllers register their parameter types using the `@register_controller_parameters` decorator: +Use `load_params` to inspect or override parameters directly: ```python -@register_controller_parameters(MyControllerParams) -def my_controller(pos, vel, *, param1, param2, param3): - # Controller implementation - pass -``` - -### ControllerParams Protocol +from drone_controllers.core import load_params -All controller parameter classes must implement the `ControllerParams` protocol: +params = load_params("mellinger", "state2attitude", "cf2x_L250") +print(params["mass"]) # 0.029 +print(params["kp"]) # position gain array +``` -- `load(drone_model: str)` - Load parameters for a specific drone model -- `_asdict()` - Convert parameters to a dictionary +### Array Namespace Support -## Example Usage +Both `parametrize` and `load_params` accept an `xp` argument so that static parameters are placed in the correct array namespace before being bound to the function: ```python -from functools import partial -from drone_controllers.mellinger.params import StateParams - -# Manual parameter loading -params = StateParams.load("cf2x_L250") -controller = partial(state2attitude, **params._asdict()) - -# Equivalent to using parametrize +import jax.numpy as jnp from drone_controllers import parametrize -controller = parametrize(state2attitude, "cf2x_L250") +from drone_controllers.mellinger import state2attitude + +controller = parametrize(state2attitude, "cf2x_L250", xp=jnp) ``` diff --git a/docs/api/drones.md b/docs/api/drones.md new file mode 100644 index 0000000..ad77003 --- /dev/null +++ b/docs/api/drones.md @@ -0,0 +1,3 @@ +# Drones + +::: drone_controllers.drones diff --git a/docs/api/mellinger.md b/docs/api/mellinger.md index 8e4a744..68de824 100644 --- a/docs/api/mellinger.md +++ b/docs/api/mellinger.md @@ -17,7 +17,7 @@ from drone_controllers.mellinger import state2attitude controller = parametrize(state2attitude, "cf2x_L250") -rpyt, pos_err_i = controller(pos, quat, vel, ang_vel, cmd) +rpyt, pos_err_i = controller(pos, quat, vel, cmd) ``` ### attitude2force_torque @@ -33,7 +33,7 @@ from drone_controllers.mellinger import attitude2force_torque controller = parametrize(attitude2force_torque, "cf2x_L250") -force, torque, att_err_i = controller(pos, quat, vel, ang_vel, rpyt_cmd) +force, torque, att_err_i = controller(quat, ang_vel, rpyt_cmd) ``` ### force_torque2rotor_vel @@ -52,26 +52,6 @@ controller = parametrize(force_torque2rotor_vel, "cf2x_L250") rotor_speeds = controller(force, torque) ``` -## Parameter Classes - -### StateParams - -::: drone_controllers.mellinger.params.StateParams - -Parameters for the position control loop. - -### AttitudeParams - -::: drone_controllers.mellinger.params.AttitudeParams - -Parameters for the attitude control loop. - -### ForceTorqueParams - -::: drone_controllers.mellinger.params.ForceTorqueParams - -Parameters for the force/torque to rotor speed conversion. - ## Complete Controller Pipeline Here's how to use all three components together: @@ -100,8 +80,8 @@ ang_vel = np.array([0.0, 0.0, 0.0]) cmd = np.array([1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) # Run the complete pipeline -rpyt, pos_err_i = state_ctrl(pos, quat, vel, ang_vel, cmd) -force, torque, att_err_i = attitude_ctrl(pos, quat, vel, ang_vel, rpyt) +rpyt, pos_err_i = state_ctrl(pos, quat, vel, cmd) +force, torque, att_err_i = attitude_ctrl(quat, ang_vel, rpyt) rotor_speeds = rotor_ctrl(force, torque) print(f"Final rotor speeds: {rotor_speeds} rad/s") @@ -121,10 +101,10 @@ for step in range(100): # Pass previous integral errors ctrl_errors = (pos_err_i,) if pos_err_i is not None else None - rpyt, pos_err_i = state_ctrl(pos, quat, vel, ang_vel, cmd, ctrl_errors=ctrl_errors) + rpyt, pos_err_i = state_ctrl(pos, quat, vel, cmd, ctrl_errors=ctrl_errors) ctrl_errors = (att_err_i,) if att_err_i is not None else None - force, torque, att_err_i = attitude_ctrl(pos, quat, vel, ang_vel, rpyt, ctrl_errors=ctrl_errors) + force, torque, att_err_i = attitude_ctrl(quat, ang_vel, rpyt, ctrl_errors=ctrl_errors) rotor_speeds = rotor_ctrl(force, torque) ``` @@ -145,7 +125,7 @@ quat_jax = jnp.array([0.0, 0.0, 0.0, 1.0]) # JIT compile the controller jit_controller = jit(parametrize(state2attitude, "cf2x_L250")) -rpyt, pos_err_i = jit_controller(pos_jax, quat_jax, vel_jax, ang_vel_jax, cmd_jax) +rpyt, pos_err_i = jit_controller(pos_jax, quat_jax, vel_jax, cmd_jax) ``` # References diff --git a/docs/getting-started/quick-start.md b/docs/getting-started/quick-start.md index 264d230..02eb7c7 100644 --- a/docs/getting-started/quick-start.md +++ b/docs/getting-started/quick-start.md @@ -27,11 +27,11 @@ ang_vel = np.array([0.0, 0.0, 0.0]) # Current angular velocity [wx, wy, wz] cmd = np.array([1.0, 0.0, 1.0, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) # Step 1: State to attitude control -rpyt_cmd, pos_error_integral = state_ctrl(pos, quat, vel, ang_vel, cmd) +rpyt_cmd, pos_error_integral = state_ctrl(pos, quat, vel, cmd) print(f"Attitude command (R,P,Y,T): {rpyt_cmd}") # Step 2: Attitude to force/torque -force, torque, att_error_integral = attitude_ctrl(pos, quat, vel, ang_vel, rpyt_cmd) +force, torque, att_error_integral = attitude_ctrl(quat, ang_vel, rpyt_cmd) print(f"Desired force: {force[0]:.3f} N") print(f"Desired torque: {torque}") @@ -66,7 +66,7 @@ cmd_batch = np.zeros((*batch_shape, 13)) cmd_batch[..., :3] = pos_batch + np.random.randn(*batch_shape, 3) * 0.5 # Target positions # Process entire batch at once -rpyt_batch, pos_err_batch = controller(pos_batch, quat_batch, vel_batch, ang_vel_batch, cmd_batch) +rpyt_batch, pos_err_batch = controller(pos_batch, quat_batch, vel_batch, cmd_batch) print(f"Batch output shape: {rpyt_batch.shape}") # Should be (3, 5, 4) print(f"Per-drone commands: {rpyt_batch[0, 0, :]}") # First drone, first timestep @@ -74,31 +74,15 @@ print(f"Per-drone commands: {rpyt_batch[0, 0, :]}") # First drone, first timest ## Manual Parameter Loading -You can also load parameters manually without using the `parametrize` decorator: +You can inspect or override parameters using `load_params`: ```python -import numpy as np -from functools import partial -from drone_controllers.mellinger import state2attitude -from drone_controllers.mellinger.params import StateParams - -# Load parameters manually -params = StateParams.load("cf2x_L250") -print(f"Position gains: {params.kp}") -print(f"Velocity gains: {params.kd}") -print(f"Drone mass: {params.mass} kg") - -# Create controller with custom parameters -controller = partial(state2attitude, **params._asdict()) - -# Use as before -pos = np.array([0.0, 0.0, 1.0]) -quat = np.array([0.0, 0.0, 0.0, 1.0]) -vel = np.array([0.0, 0.0, 0.0]) -ang_vel = np.array([0.0, 0.0, 0.0]) -cmd = np.ones(13) +from drone_controllers.core import load_params -rpyt, pos_err = controller(pos, quat, vel, ang_vel, cmd, ctrl_freq=100) +params = load_params("mellinger", "state2attitude", "cf2x_L250") +print(f"Position gains: {params['kp']}") +print(f"Velocity gains: {params['kd']}") +print(f"Drone mass: {params['mass']} kg") ``` ## Array API Compatibility @@ -124,7 +108,7 @@ controller = parametrize(state2attitude, "cf2x_L250") from jax import jit jit_controller = jit(controller) -rpyt, pos_err = jit_controller(pos, quat, vel, ang_vel, cmd) +rpyt, pos_err = jit_controller(pos, quat, vel, cmd) print(f"Output type: {type(rpyt)}") # JAX array ``` @@ -143,16 +127,15 @@ controller = parametrize(state2attitude, "cf2x_L250") pos = np.array([0.0, 0.0, 0.5]) quat = np.array([0.0, 0.0, 0.0, 1.0]) vel = np.array([0.0, 0.0, 0.0]) -ang_vel = np.array([0.0, 0.0, 0.0]) # Target hover at 1m altitude cmd = np.array([0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) # First call - no integral error history -rpyt1, pos_err_i1 = controller(pos, quat, vel, ang_vel, cmd, ctrl_errors=None) +rpyt1, pos_err_i1 = controller(pos, quat, vel, cmd, ctrl_errors=None) # Subsequent calls - pass integral error from previous step -rpyt2, pos_err_i2 = controller(pos, quat, vel, ang_vel, cmd, ctrl_errors=(pos_err_i1,)) +rpyt2, pos_err_i2 = controller(pos, quat, vel, cmd, ctrl_errors=(pos_err_i1,)) print(f"Integral error evolution: {np.linalg.norm(pos_err_i1)} -> {np.linalg.norm(pos_err_i2)}") ``` @@ -176,7 +159,6 @@ for drone in Drones: Now that you've seen the basics, explore: -- **[Concepts](../concepts/overview.md)** - Understand the theory behind the controllers - **[API Reference](../api/core.md)** - Complete API documentation ## Common Issues diff --git a/docs/index.md b/docs/index.md index d7122af..4e553ab 100644 --- a/docs/index.md +++ b/docs/index.md @@ -40,13 +40,12 @@ controller = parametrize(state2attitude, "cf2x_L250") pos = np.array([0.0, 0.0, 1.0]) # position [x, y, z] quat = np.array([0.0, 0.0, 0.0, 1.0]) # quaternion [x, y, z, w] vel = np.array([0.0, 0.0, 0.0]) # velocity [vx, vy, vz] -ang_vel = np.array([0.0, 0.0, 0.0]) # angular velocity [wx, wy, wz] # Command: [x, y, z, vx, vy, vz, ax, ay, az, yaw, r_rate, p_rate, y_rate] cmd = np.array([1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) # Compute control output -rpyt, pos_err_i = controller(pos, quat, vel, ang_vel, cmd) +rpyt, pos_err_i = controller(pos, quat, vel, cmd) print(f"Roll-Pitch-Yaw-Thrust command: {rpyt}") ``` @@ -54,7 +53,7 @@ print(f"Roll-Pitch-Yaw-Thrust command: {rpyt}") ### Implemented Controllers -- **[Mellinger Controller](api/drone_controllers/mellinger/control.md)** — Geometric tracking controller based on the original Crazyflie implementation +- **[Mellinger Controller](api/mellinger.md)** — Geometric tracking controller based on the original Crazyflie implementation ### Supported Drone Models @@ -89,6 +88,5 @@ All controllers support the Python Array API standard, meaning you can use them ## Getting Help - Read the [Getting Started](getting-started/installation.md) guide -- Browse the [API Reference](api/core.md) -- Check out [Concepts](concepts/overview.md) for theory +- Browse the [API Reference](api/core.md) - Report issues on [GitHub](https://github.com/learnsyslab/drone-controllers/issues) diff --git a/drone_controllers/core.py b/drone_controllers/core.py index 3793154..2823ac7 100644 --- a/drone_controllers/core.py +++ b/drone_controllers/core.py @@ -1,27 +1,36 @@ -"""Core functionalities for controller parametrization and registration.""" +"""Core functionalities for controller parametrization.""" from __future__ import annotations +import inspect +import tomllib from functools import partial -from typing import Any, Callable, ParamSpec, Protocol, TypeVar, runtime_checkable +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, ParamSpec, TypeVar -P = ParamSpec("P") -R = TypeVar("R") +import numpy as np +if TYPE_CHECKING: + from types import ModuleType -controller_parameter_registry: dict[str, type[ControllerParams]] = {} +P = ParamSpec("P") +R = TypeVar("R") -def parametrize(fn: Callable[P, R], drone_model: str) -> Callable[P, R]: +def parametrize( + fn: Callable[P, R], drone_model: str, xp: ModuleType | None = None, device: str | None = None +) -> Callable[P, R]: """Parametrize a controller function with the default controller parameters for a drone model. Args: fn: The controller function to parametrize. drone_model: The drone model to use. + xp: The array API module to use. If not provided, numpy is used. + device: The device to use. If None, the device is inferred from the xp module. Example: - >>> from drone_models.controller import parametrize - >>> from drone_models.controller.mellinger import state2attitude + >>> from drone_controllers.core import parametrize + >>> from drone_controllers.mellinger import state2attitude >>> controller_fn = parametrize(state2attitude, drone_model="cf2x_L250") >>> command_rpyt, int_pos_err = controller_fn( ... pos=pos, @@ -36,54 +45,50 @@ def parametrize(fn: Callable[P, R], drone_model: str) -> Callable[P, R]: Returns: The parametrized controller function with all keyword argument only parameters filled in. """ - controller_id = fn.__module__ + "." + fn.__name__ + xp = np if xp is None else xp + controller = fn.__module__.split(".")[-2] + sig = inspect.signature(fn) + kwonly_params = { + name + for name, param in sig.parameters.items() + if param.kind == inspect.Parameter.KEYWORD_ONLY + } try: - params = controller_parameter_registry[controller_id].load(drone_model) + params = load_params(controller, fn.__name__, drone_model, xp=xp) except KeyError as e: raise KeyError( - f"Controller `{controller_id}` does not exist in the parameter registry" + f"Controller `{controller}.{fn.__name__}` not found for drone `{drone_model}`" ) from e - except ValueError as e: - raise ValueError(f"Drone model `{drone_model}` not supported for `{fn.__name__}`") from e - return partial(fn, **params._asdict()) - - -@runtime_checkable -class ControllerParams(Protocol): - """Protocol for controller parameters.""" + params = {k: xp.asarray(v, device=device) for k, v in params.items() if k in kwonly_params} + return partial(fn, **params) - @staticmethod - def load(drone_model: str) -> ControllerParams: - """Load the parameters from the config file.""" - def _asdict(self) -> dict[str, Any]: - """Convert the parameters to a dictionary.""" +def load_params( + controller: str, fn_name: str, drone_model: str, xp: ModuleType | None = None +) -> dict[str, Any]: + """Load and merge controller parameters for a specific function. - -def register_controller_parameters( - params: ControllerParams | type[ControllerParams], -) -> Callable[[Callable[P, R]], Callable[P, R]]: - """Register the default controller parameters for this controller. - - Warning: - The controller parameters **must** be a named tuple with a function `load` that takes in the - drone model name and returns an instance of itself, or a class that implements the - ControllerParams protocol. + Reads ``drone_controllers//params.toml`` and merges the + ``[drone_model.core]`` section with the ``[drone_model.]`` section, + with function-specific values taking precedence over core values. Args: - params: The controller parameter type. + controller: Name of the controller sub-package, e.g. ``"mellinger"``. + fn_name: Name of the controller function, e.g. ``"state2attitude"``. + drone_model: Name of the drone configuration, e.g. ``"cf2x_L250"``. + xp: The array API module to use. If not provided, numpy is used. Returns: - A decorator function that registers the parameters and returns the function unchanged. - """ - if not isinstance(params, ControllerParams): - raise ValueError(f"{params} does not implement the ControllerParams protocol") + A flat dict mapping parameter names to arrays in the requested array namespace. - def decorator(fn: Callable[P, R]) -> Callable[P, R]: - controller_id = fn.__module__ + "." + fn.__name__ - if controller_id in controller_parameter_registry: - raise ValueError(f"Controller `{controller_id}` already registered") - controller_parameter_registry[controller_id] = params - return fn - - return decorator + Raises: + KeyError: If ``drone_model`` is not found in the params.toml file. + """ + xp = np if xp is None else xp + with open(Path(__file__).parent / f"{controller}/params.toml", "rb") as f: + params = tomllib.load(f) + if drone_model not in params: + raise KeyError(f"Drone model `{drone_model}` not found in {controller}/params.toml") + model_params = params[drone_model] + merged = model_params.get("core", {}) | model_params.get(fn_name, {}) + return {k: xp.asarray(v) for k, v in merged.items()} diff --git a/drone_controllers/mellinger/control.py b/drone_controllers/mellinger/control.py index 9745ba3..af4a026 100644 --- a/drone_controllers/mellinger/control.py +++ b/drone_controllers/mellinger/control.py @@ -8,23 +8,18 @@ from array_api_compat import array_namespace from scipy.spatial.transform import Rotation as R -from drone_controllers.core import register_controller_parameters -from drone_controllers.mellinger.params import AttitudeParams, ForceTorqueParams, StateParams from drone_controllers.transform import force2pwm, motor_force2rotor_vel, pwm2force if TYPE_CHECKING: from drone_controllers._typing import Array # To be changed to array_api_typing later -@register_controller_parameters(StateParams) def state2attitude( pos: Array, quat: Array, vel: Array, - ang_vel: Array, cmd: Array, ctrl_errors: tuple[Array, ...] | None = None, - ctrl_info: tuple[Array, ...] | None = None, ctrl_freq: float = 100, *, mass: float, @@ -46,12 +41,10 @@ def state2attitude( pos: Drone position with shape (..., 3). quat: Drone orientation as xyzw quaternion with shape (..., 4). vel: Drone velocity with shape (..., 3). - ang_vel: Drone angular drone velocity in rad/s with shape (..., 3). cmd: Full state command in SI units and rad with shape (..., 13). The entries are [x, y, z, vx, vy, vz, ax, ay, az, yaw, roll_rate, pitch_rate, yaw_rate]. ctrl_errors: Tuple of integral errors. For state2attitude, the tuple contains a single array (..., 3) for the position integral error or is None. - ctrl_info: Tuple of arrays with additional data. Not used in state2attitude. ctrl_freq: Control frequency in Hz mass: Drone mass used for calculations in the controller in kg. kp: Proportional gain for the position controller with shape (3,). @@ -132,16 +125,12 @@ def state2attitude( return command_rpyt, int_pos_err -@register_controller_parameters(AttitudeParams) def attitude2force_torque( - pos: Array, quat: Array, - vel: Array, ang_vel: Array, cmd: Array, prev_ang_vel: Array | None = None, ctrl_errors: tuple[Array, ...] | None = None, - ctrl_info: tuple[Array, ...] | None = None, ctrl_freq: int = 500, *, kR: Array, @@ -164,14 +153,11 @@ def attitude2force_torque( compatible with the new frame of the Crazyflie 2.1. Args: - pos: Drone position with shape (..., 3). quat: Drone orientation as xyzw quaternion with shape (..., 4). - vel: Drone velocity with shape (..., 3). ang_vel: Drone angular drone velocity in rad/s with shape (..., 3). cmd: Commanded attitude (roll, pitch, yaw) and total thrust [rad, rad, rad, N]. ctrl_errors: Tuple of integral errors. For attitude2force_torque, the tuple contains a single array (..., 3) for the angular velocity integral error or is None. - ctrl_info: Tuple of arrays with additional data. Not used in attitude2force_torque. ctrl_freq: Control frequency in Hz kR: Proportional gain for the rotation error with shape (3,). kw: Proportional gain for the angular velocity error with shape (3,). @@ -182,9 +168,7 @@ def attitude2force_torque( thrust_max: Maximum thrust in N. pwm_min: Minimum PWM value. pwm_max: Maximum PWM value. - ang_vel_des: Desired angular velocity in rad/s. prev_ang_vel: Previous angular velocity in rad/s. - prev_ang_vel_des: Previous angular velocity command in rad/s. L: Distance from the center of the quadrotor to the center of the rotor in m. thrust2torque: Conversion factor (m). mixing_matrix: Mixing matrix for the motor forces with shape (4, 3). @@ -258,7 +242,6 @@ def force_torque_pwms2pwms(force_pwm: Array, torque_pwm: Array, mixing_matrix: A return force_pwm[..., None] + (torque_pwm @ mixing_matrix) -@register_controller_parameters(ForceTorqueParams) def force_torque2rotor_vel( force: Array, torque: Array, diff --git a/drone_controllers/mellinger/params.py b/drone_controllers/mellinger/params.py deleted file mode 100644 index 1ef155f..0000000 --- a/drone_controllers/mellinger/params.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Parameters for the Mellinger controller.""" - -from __future__ import annotations - -import tomllib -from pathlib import Path -from typing import TYPE_CHECKING, NamedTuple - -import numpy as np - -if TYPE_CHECKING: - from drone_controllers._typing import Array # To be changed to array_api_typing later - - -class StateParams(NamedTuple): - """Parameters for the Mellinger state controller.""" - - kp: Array - kd: Array - ki: Array - int_err_max: Array - mass: float - gravity_vec: Array - mass_thrust: float - thrust_max: float - pwm_max: float - - @staticmethod - def load(drone_model: str) -> StateParams: - """Load the parameters from the config file.""" - with open(Path(__file__).parent / "params.toml", "rb") as f: - params = tomllib.load(f) - if drone_model not in params: - raise KeyError(f"Drone model `{drone_model}` not found in params.toml") - params = params[drone_model]["state2attitude"] | params[drone_model]["core"] - params = {k: np.asarray(v) for k, v in params.items() if k in StateParams._fields} - return StateParams(**params) - - -class AttitudeParams(NamedTuple): - """Parameters for the Mellinger attitude controller.""" - - kR: Array - kw: Array - ki_m: Array - kd_omega: Array - int_err_max: Array - torque_pwm_max: Array - thrust_max: float - pwm_min: float - pwm_max: float - L: float - thrust2torque: float - mixing_matrix: Array - - @staticmethod - def load(drone_model: str) -> AttitudeParams: - """Load the parameters from the config file.""" - with open(Path(__file__).parent / "params.toml", "rb") as f: - params = tomllib.load(f) - if drone_model not in params: - raise KeyError(f"Drone model `{drone_model}` not found in params.toml") - params = params[drone_model]["attitude2force_torque"] | params[drone_model]["core"] - params = {k: np.asarray(v) for k, v in params.items() if k in AttitudeParams._fields} - return AttitudeParams(**params) - - -class ForceTorqueParams(NamedTuple): - """Parameters for the Mellinger force torque controller.""" - - thrust_min: float - thrust_max: float - L: float - rpm2thrust: Array - thrust2torque: float - mixing_matrix: Array - - @staticmethod - def load(drone_model: str) -> ForceTorqueParams: - """Load the parameters from the config file.""" - with open(Path(__file__).parent / "params.toml", "rb") as f: - params = tomllib.load(f) - if drone_model not in params: - raise KeyError(f"Drone model `{drone_model}` not found in params.toml") - params = params[drone_model]["core"] - params = {k: np.asarray(v) for k, v in params.items() if k in ForceTorqueParams._fields} - return ForceTorqueParams(**params) diff --git a/mkdocs.yml b/mkdocs.yml index ba8b308..f1a4e3b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -36,11 +36,6 @@ nav: - Getting Started: - Installation: getting-started/installation.md - Quick Start: getting-started/quick-start.md - - Concepts: - - Overview: concepts/overview.md - - Control Theory: concepts/control-theory.md - - Drone Dynamics: concepts/drone-dynamics.md - - Controller Design: concepts/controller-design.md - API Reference: - Core: api/core.md - Drones: api/drones.md @@ -49,11 +44,6 @@ nav: plugins: - search - - gen-files: - scripts: - - docs/gen_ref_pages.py - - literate-nav: - nav_file: SUMMARY.md - mkdocstrings: handlers: python: diff --git a/pixi.lock b/pixi.lock index ec331df..275b54c 100644 --- a/pixi.lock +++ b/pixi.lock @@ -137,7 +137,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zstandard-0.23.0-py313h07c4f96_3.conda - pypi: https://files.pythonhosted.org/packages/a0/d3/54cd560804a8c2b898824778e86c13c2a14600bc83532a9c4f69f2f469c3/array_api_compat-1.14.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/46/f7/9e14be985fd77ae26fee9136c9735e8987772e0ecf5f1f4e6e2b84cadc46/array_api_extra-0.10.1-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/fa/88/6764e7a109dd84294850741501145da90d13cdeac9d4e614929464a37420/build-1.4.4-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/0d/fe/6bea5c9162869c5beba5d9c8abbed835ec85bf1ec1fba05a3822325c45f3/build-1.5.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4c/ab/1e73cfc181afc3054a09e5e8f7753a8fba254592ff50b735d7456d197353/cryptography-46.0.0-cp311-abi3-manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/02/10/5da547df7a391dcde17f59520a231527b8571e6f46fc8efb02ccb370ab12/docutils-0.22.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/42/77/de194443bf38daed9452139e960c632b0ef9f9a5dd9ce605fdf18ca9f1b1/id-1.6.1-py3-none-any.whl @@ -146,7 +146,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/fd/c4/813bb09f0985cb21e959f21f2464169eca882656849adf727ac7bb7e1767/jaraco_functools-4.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b2/a3/e137168c9c44d18eff0376253da9f1e9234d0239e0ee230d2fee6cea8e55/jeepney-0.9.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/81/db/e655086b7f3a705df045bf0933bdd9c2f79bb3c97bfef1384598bb79a217/keyring-25.7.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/b3/81/4da04ced5a082363ecfa159c010d200ecbd959ae410c10c0264a38cac0f5/markdown_it_py-4.2.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/cb/98/6af411189d9413534c3eb691182bff1f5c6d44ed2f93f2edfe52a1bbceb8/more_itertools-11.0.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/1b/0e/bf298920729f216adcb002acf7ea01b90842603d2e4e2ce9b900d9ee8fab/nh3-0.3.5-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl @@ -204,14 +204,14 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda - pypi: https://files.pythonhosted.org/packages/a0/d3/54cd560804a8c2b898824778e86c13c2a14600bc83532a9c4f69f2f469c3/array_api_compat-1.14.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/46/f7/9e14be985fd77ae26fee9136c9735e8987772e0ecf5f1f4e6e2b84cadc46/array_api_extra-0.10.1-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/fa/88/6764e7a109dd84294850741501145da90d13cdeac9d4e614929464a37420/build-1.4.4-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/0d/fe/6bea5c9162869c5beba5d9c8abbed835ec85bf1ec1fba05a3822325c45f3/build-1.5.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/02/10/5da547df7a391dcde17f59520a231527b8571e6f46fc8efb02ccb370ab12/docutils-0.22.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/42/77/de194443bf38daed9452139e960c632b0ef9f9a5dd9ce605fdf18ca9f1b1/id-1.6.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/7f/66/b15ce62552d84bbfcec9a4873ab79d993a1dd4edb922cbfccae192bd5b5f/jaraco.classes-3.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f2/58/bc8954bda5fcda97bd7c19be11b85f91973d67a706ed4a3aec33e7de22db/jaraco_context-6.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fd/c4/813bb09f0985cb21e959f21f2464169eca882656849adf727ac7bb7e1767/jaraco_functools-4.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/81/db/e655086b7f3a705df045bf0933bdd9c2f79bb3c97bfef1384598bb79a217/keyring-25.7.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/b3/81/4da04ced5a082363ecfa159c010d200ecbd959ae410c10c0264a38cac0f5/markdown_it_py-4.2.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/cb/98/6af411189d9413534c3eb691182bff1f5c6d44ed2f93f2edfe52a1bbceb8/more_itertools-11.0.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/85/30/d162e99746a2fb1d98bb0ef23af3e201b156cf09f7de867c7390c8fe1c06/nh3-0.3.5-cp38-abi3-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl @@ -670,10 +670,10 @@ packages: - pkg:pypi/brotli?source=hash-mapping size: 359854 timestamp: 1764018178608 -- pypi: https://files.pythonhosted.org/packages/fa/88/6764e7a109dd84294850741501145da90d13cdeac9d4e614929464a37420/build-1.4.4-py3-none-any.whl +- pypi: https://files.pythonhosted.org/packages/0d/fe/6bea5c9162869c5beba5d9c8abbed835ec85bf1ec1fba05a3822325c45f3/build-1.5.0-py3-none-any.whl name: build - version: 1.4.4 - sha256: 8c3f48a6090b39edec1a273d2d57949aaf13723b01e02f9d518396887519f64d + version: 1.5.0 + sha256: 13f3eecb844759ab66efec90ca17639bbf14dc06cb2fdf37a9010322d9c50a6f requires_dist: - packaging>=24.0 - pyproject-hooks @@ -682,10 +682,9 @@ packages: - tomli>=1.1.0 ; python_full_version < '3.11' - keyring ; extra == 'keyring' - uv>=0.1.18 ; extra == 'uv' - - virtualenv>=20.11 ; python_full_version < '3.10' and extra == 'virtualenv' - virtualenv>=20.17 ; python_full_version >= '3.10' and python_full_version < '3.14' and extra == 'virtualenv' - virtualenv>=20.31 ; python_full_version >= '3.14' and extra == 'virtualenv' - requires_python: '>=3.9' + requires_python: '>=3.10' - conda: https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-h4bc722e_7.conda sha256: 5ced96500d945fb286c9c838e54fa759aa04a7129c59800f0846b4335cee770d md5: 62ee74e96c5ebb0af99386de58cf9553 @@ -843,13 +842,14 @@ packages: requires_python: '>=3.9' - pypi: ./ name: drone-controllers - version: 0.1.0 - sha256: 89f6ceceaeed050513bac3515ea219ed5ba8f70854b7d0f4f411330cff144dca + version: 0.2.0 + sha256: f97b3f74b3f3163b18ac4e2d4ac832ff714b92b3a7b9bbbd2144c695125ed1c5 requires_dist: - numpy>=2.0.0 - scipy>=1.17.0 - array-api-compat - array-api-extra + editable: true - conda: https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.3.0-pyhd8ed1ab_0.conda sha256: ce61f4f99401a4bd455b89909153b40b9c823276aefcbb06f2044618696009ca md5: 72e42d28960d875c7654614f8b50939a @@ -1638,10 +1638,10 @@ packages: - pkg:pypi/markdown?source=hash-mapping size: 80353 timestamp: 1750360406187 -- pypi: https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl +- pypi: https://files.pythonhosted.org/packages/b3/81/4da04ced5a082363ecfa159c010d200ecbd959ae410c10c0264a38cac0f5/markdown_it_py-4.2.0-py3-none-any.whl name: markdown-it-py - version: 4.0.0 - sha256: 87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147 + version: 4.2.0 + sha256: 9f7ebbcd14fe59494226453aed97c1070d83f8d24b6fc3a3bcf9a38092641c4a requires_dist: - mdurl~=0.1 - psutil ; extra == 'benchmarking' @@ -1669,6 +1669,7 @@ packages: - pytest ; extra == 'testing' - pytest-cov ; extra == 'testing' - pytest-regressions ; extra == 'testing' + - pytest-timeout ; extra == 'testing' - requests ; extra == 'testing' requires_python: '>=3.10' - conda: https://conda.anaconda.org/conda-forge/linux-64/markupsafe-3.0.2-py313h8060acc_1.conda diff --git a/pyproject.toml b/pyproject.toml index d9eaff5..b4fdd45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "drone_controllers" -version = "0.1.0" +version = "0.2.0" description = "Controllers for quadrotor drones." authors = [{ name = "Marcel Rath" }, { name = "Martin Schuck" }] readme = "README.md" diff --git a/tests/unit/test_core.py b/tests/unit/test_core.py new file mode 100644 index 0000000..b668063 --- /dev/null +++ b/tests/unit/test_core.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import inspect +import tomllib +from pathlib import Path +from typing import Any, Callable + +import array_api_strict +import pytest + +from drone_controllers.core import load_params, parametrize +from drone_controllers.drones import Drones +from drone_controllers.mellinger import ( + attitude2force_torque, + force_torque2rotor_vel, + state2attitude, +) + +_MELLINGER_FNS = [state2attitude, attitude2force_torque, force_torque2rotor_vel] + + +@pytest.mark.unit +@pytest.mark.parametrize("fn", _MELLINGER_FNS, ids=lambda fn: fn.__name__) +@pytest.mark.parametrize("drone_model", Drones) +def test_load_params_keys(fn: Callable[..., Any], drone_model: Drones) -> None: + params = load_params("mellinger", fn.__name__, drone_model) + kwonly = { + name + for name, p in inspect.signature(fn).parameters.items() + if p.kind == inspect.Parameter.KEYWORD_ONLY + } + assert kwonly <= set(params.keys()), f"Missing keys: {kwonly - set(params.keys())}" + + +@pytest.mark.unit +@pytest.mark.parametrize("drone_model", Drones) +def test_load_params_values(drone_model: Drones) -> None: + params = load_params("mellinger", "state2attitude", drone_model) + toml_path = Path(__file__).parents[2] / "drone_controllers/mellinger/params.toml" + with open(toml_path, "rb") as f: + raw = tomllib.load(f) + expected_mass = raw[drone_model.value]["core"]["mass"] + assert float(params["mass"]) == pytest.approx(expected_mass) + + +@pytest.mark.unit +def test_load_params_unknown_drone() -> None: + with pytest.raises(KeyError, match="nonexistent_drone"): + load_params("mellinger", "state2attitude", "nonexistent_drone") + + +@pytest.mark.unit +def test_load_params_unknown_controller() -> None: + with pytest.raises(OSError): + load_params("nonexistent", "state2attitude", "cf2x_L250") + + +@pytest.mark.unit +def test_parametrize_unknown_drone() -> None: + with pytest.raises(KeyError): + parametrize(state2attitude, "nonexistent_drone") + + +@pytest.mark.unit +@pytest.mark.parametrize("drone_model", Drones) +def test_parametrize_xp_namespace(drone_model: Drones) -> None: + controller = parametrize(state2attitude, drone_model, xp=array_api_strict) + xp_array_type = type(array_api_strict.asarray(0.0)) + assert all(isinstance(v, xp_array_type) for v in controller.keywords.values()) diff --git a/tests/unit/test_mellinger.py b/tests/unit/test_mellinger.py index e1d930c..b041b41 100644 --- a/tests/unit/test_mellinger.py +++ b/tests/unit/test_mellinger.py @@ -1,19 +1,18 @@ from __future__ import annotations -from functools import partial from typing import TYPE_CHECKING import numpy as np import pytest from drone_controllers import parametrize +from drone_controllers.core import load_params from drone_controllers.drones import Drones from drone_controllers.mellinger import ( attitude2force_torque, force_torque2rotor_vel, state2attitude, ) -from drone_controllers.mellinger.params import AttitudeParams, ForceTorqueParams, StateParams if TYPE_CHECKING: from drone_controllers._typing import Array # To be changed to array_api_typing later @@ -26,32 +25,28 @@ def create_rnd_states(shape: tuple[int, ...] = ()) -> tuple[Array, Array, Array, @pytest.mark.unit @pytest.mark.parametrize("drone_model", Drones) -def test_state2attitude(drone_model: Drones): - # Manually parametrize the controller - params = StateParams.load(drone_model) - controller = partial(state2attitude, ctrl_freq=100, **params._asdict()) +def test_state2attitude(drone_model: Drones) -> None: + controller = parametrize(state2attitude, drone_model) # Single input pos, quat, vel, ang_vel = create_rnd_states() - rpyt, pos_err_i = controller(pos, quat, vel, ang_vel, np.ones(13), ctrl_freq=100) + rpyt, pos_err_i = controller(pos, quat, vel, np.ones(13), ctrl_freq=100) assert rpyt.shape == (4,) assert pos_err_i.shape == (3,) # Batch input pos, quat, vel, ang_vel = create_rnd_states((5, 4)) - rpyt, pos_err_i = controller(pos, quat, vel, ang_vel, np.ones((5, 4, 13)), ctrl_freq=100) + rpyt, pos_err_i = controller(pos, quat, vel, np.ones((5, 4, 13)), ctrl_freq=100) assert rpyt.shape == (5, 4, 4) assert pos_err_i.shape == (5, 4, 3) @pytest.mark.unit @pytest.mark.parametrize("drone_model", Drones) -def test_attitude2force_torque(drone_model: Drones): - # Manually parametrize the controller - params = AttitudeParams.load(drone_model) - controller = partial(attitude2force_torque, ctrl_freq=500, **params._asdict()) +def test_attitude2force_torque(drone_model: Drones) -> None: + controller = parametrize(attitude2force_torque, drone_model) # Single input pos, quat, vel, ang_vel = create_rnd_states() rpyt_cmd = np.array([0.1, 0.1, 0.1, 1.0]) # roll, pitch, yaw, thrust command - force_des, torque_des, r_int_error = controller(pos, quat, vel, ang_vel, rpyt_cmd) + force_des, torque_des, r_int_error = controller(quat, ang_vel, rpyt_cmd) assert force_des.shape == (1,) assert torque_des.shape == (3,) assert r_int_error.shape == (3,) @@ -59,7 +54,7 @@ def test_attitude2force_torque(drone_model: Drones): pos, quat, vel, ang_vel = create_rnd_states((5, 4)) rpyt_cmd = np.random.randn(5, 4, 4) rpyt_cmd[..., 3] = np.abs(rpyt_cmd[..., 3]) # Ensure positive thrust - force_des, torque_des, r_int_error = controller(pos, quat, vel, ang_vel, rpyt_cmd) + force_des, torque_des, r_int_error = controller(quat, ang_vel, rpyt_cmd) assert force_des.shape == (5, 4, 1) assert torque_des.shape == (5, 4, 3) assert r_int_error.shape == (5, 4, 3) @@ -67,10 +62,8 @@ def test_attitude2force_torque(drone_model: Drones): @pytest.mark.unit @pytest.mark.parametrize("drone_model", Drones) -def test_force_torque2rotor_vel(drone_model: Drones): - # Manually parametrize the controller - params = ForceTorqueParams.load(drone_model) - controller = partial(force_torque2rotor_vel, **params._asdict()) +def test_force_torque2rotor_vel(drone_model: Drones) -> None: + controller = parametrize(force_torque2rotor_vel, drone_model) # Single input force = np.array([1.0]) torque = np.array([0.1, 0.1, 0.1]) @@ -83,57 +76,137 @@ def test_force_torque2rotor_vel(drone_model: Drones): assert rotor_vel.shape == (5, 4, 4) +# Correctness / physics + + @pytest.mark.unit @pytest.mark.parametrize("drone_model", Drones) -def test_state2attitude_parametrize(drone_model: Drones): - # Test the parametrize function with all available drones +def test_state2attitude_at_setpoint(drone_model: Drones) -> None: + # At setpoint with identity orientation and zero acc, RPY command should be + # [0, 0, 0] and thrust must be positive (hovering against gravity). controller = parametrize(state2attitude, drone_model) - # Single input test - pos, quat, vel, ang_vel = create_rnd_states() - rpyt, pos_err_i = controller(pos, quat, vel, ang_vel, np.ones(13)) - assert rpyt.shape == (4,) - assert pos_err_i.shape == (3,) - # Batch input test - pos, quat, vel, ang_vel = create_rnd_states((3, 2)) - rpyt, pos_err_i = controller(pos, quat, vel, ang_vel, np.ones((3, 2, 13))) - assert rpyt.shape == (3, 2, 4) - assert pos_err_i.shape == (3, 2, 3) + pos = np.zeros(3) + quat = np.array([0.0, 0.0, 0.0, 1.0]) + vel = np.zeros(3) + cmd = np.zeros(13) # setpoint at origin, zero vel/acc, yaw=0 + rpyt, _ = controller(pos, quat, vel, cmd) + assert np.allclose(rpyt[:3], 0.0, atol=1e-6), f"RPY at setpoint should be ~0, got {rpyt[:3]}" + assert rpyt[3] > 0.0, "Hovering thrust must be positive" + + +@pytest.mark.unit +@pytest.mark.parametrize("drone_model", Drones) +def test_state2attitude_integral_error_accumulation(drone_model: Drones) -> None: + # A constant position error must cause the integral error to accumulate + # linearly until it would exceed int_err_max (clipped by the controller). + controller = parametrize(state2attitude, drone_model) + params = load_params("mellinger", "state2attitude", drone_model) + pos = np.zeros(3) + quat = np.array([0.0, 0.0, 0.0, 1.0]) + vel = np.zeros(3) + cmd = np.zeros(13) + cmd[0] = 1.0 # 1 m setpoint error in x + ctrl_freq = 100.0 + dt = 1.0 / ctrl_freq + steps = 5 + + err = None + for _ in range(steps): + _, err_i = controller(pos, quat, vel, cmd, ctrl_errors=err, ctrl_freq=ctrl_freq) + err = (err_i,) + + expected = np.clip( + np.array([steps * dt, 0.0, 0.0]), -params["int_err_max"], params["int_err_max"] + ) + assert np.allclose(err[0], expected, atol=1e-6) @pytest.mark.unit @pytest.mark.parametrize("drone_model", Drones) -def test_attitude2force_torque_parametrize(drone_model: Drones): - # Test the parametrize function with all available drones +def test_attitude2force_torque_at_setpoint(drone_model: Drones) -> None: + # Identity orientation commanded → zero attitude error → zero corrective torque. controller = parametrize(attitude2force_torque, drone_model) - # Single input test - pos, quat, vel, ang_vel = create_rnd_states() - rpyt_cmd = np.array([0.1, 0.1, 0.1, 1.0]) # roll, pitch, yaw, thrust command - force_des, torque_des, r_int_error = controller(pos, quat, vel, ang_vel, rpyt_cmd) - assert force_des.shape == (1,) - assert torque_des.shape == (3,) - assert r_int_error.shape == (3,) - # Batch input test - pos, quat, vel, ang_vel = create_rnd_states((3, 2)) - rpyt_cmd = np.random.randn(3, 2, 4) - rpyt_cmd[..., 3] = np.abs(rpyt_cmd[..., 3]) # Ensure positive thrust - force_des, torque_des, r_int_error = controller(pos, quat, vel, ang_vel, rpyt_cmd) - assert force_des.shape == (3, 2, 1) - assert torque_des.shape == (3, 2, 3) - assert r_int_error.shape == (3, 2, 3) + quat = np.array([0.0, 0.0, 0.0, 1.0]) + ang_vel = np.zeros(3) + cmd = np.array([0.0, 0.0, 0.0, 0.5]) # RPY=0, positive thrust + force_des, torque_des, _ = controller(quat, ang_vel, cmd) + assert np.allclose(torque_des, 0.0, atol=1e-6), ( + f"Torque at setpoint should be ~0, got {torque_des}" + ) + assert force_des[0] > 0.0, "Force must be positive for positive thrust command" @pytest.mark.unit @pytest.mark.parametrize("drone_model", Drones) -def test_force_torque2rotor_vel_parametrize(drone_model: Drones): - # Test the parametrize function with all available drones +def test_attitude2force_torque_zero_thrust(drone_model: Drones): + # Zero thrust command → firmware zeros torque; outputs are all zero. + controller = parametrize(attitude2force_torque, drone_model) + quat = np.array([0.0, 0.0, 0.0, 1.0]) + ang_vel = np.zeros(3) + cmd = np.array([0.1, 0.1, 0.1, 0.0]) # non-zero RPY but zero thrust + force_des, torque_des, _ = controller(quat, ang_vel, cmd) + assert np.allclose(force_des, 0.0, atol=1e-6) + assert np.allclose(torque_des, 0.0, atol=1e-6) + + +# Batch consistency (batch result == sequential result) + + +@pytest.mark.unit +@pytest.mark.parametrize("drone_model", Drones) +def test_state2attitude_batch_consistency(drone_model: Drones): + controller = parametrize(state2attitude, drone_model) + batch = (3, 2) + pos, quat, vel, _ = create_rnd_states(batch) + cmd = np.random.randn(*batch, 13) + rpyt_batch, err_batch = controller(pos, quat, vel, cmd) + for i in range(batch[0]): + for j in range(batch[1]): + rpyt_s, err_s = controller(pos[i, j], quat[i, j], vel[i, j], cmd[i, j]) + assert np.allclose(rpyt_batch[i, j], rpyt_s, atol=1e-5) + assert np.allclose(err_batch[i, j], err_s, atol=1e-5) + + +@pytest.mark.unit +@pytest.mark.parametrize("drone_model", Drones) +def test_attitude2force_torque_batch_consistency(drone_model: Drones): + controller = parametrize(attitude2force_torque, drone_model) + batch = (3, 2) + _, quat, _, ang_vel = create_rnd_states(batch) + cmd = np.random.randn(*batch, 4) + cmd[..., 3] = np.abs(cmd[..., 3]) + force_batch, torque_batch, err_batch = controller(quat, ang_vel, cmd) + for i in range(batch[0]): + for j in range(batch[1]): + force_s, torque_s, err_s = controller(quat[i, j], ang_vel[i, j], cmd[i, j]) + assert np.allclose(force_batch[i, j], force_s, atol=1e-5) + assert np.allclose(torque_batch[i, j], torque_s, atol=1e-5) + assert np.allclose(err_batch[i, j], err_s, atol=1e-5) + + +@pytest.mark.unit +@pytest.mark.parametrize("drone_model", Drones) +def test_force_torque2rotor_vel_batch_consistency(drone_model: Drones): controller = parametrize(force_torque2rotor_vel, drone_model) - # Single input test - force = np.array([1.0]) - torque = np.array([0.1, 0.1, 0.1]) - rotor_vel = controller(force, torque) - assert rotor_vel.shape == (4,) - # Batch input test - force = np.ones((3, 2, 1)) - torque = np.random.randn(3, 2, 3) * 0.1 + batch = (3, 2) + force = np.abs(np.random.randn(*batch, 1)) * 0.05 + 0.05 + torque = np.random.randn(*batch, 3) * 0.001 + rpm_batch = controller(force, torque) + for i in range(batch[0]): + for j in range(batch[1]): + rpm_s = controller(force[i, j], torque[i, j]) + assert np.allclose(rpm_batch[i, j], rpm_s, atol=1e-5) + + +# Symmetric force check + + +@pytest.mark.unit +@pytest.mark.parametrize("drone_model", Drones) +def test_force_torque2rotor_vel_symmetric(drone_model: Drones): + # Pure vertical force with zero torque → X-frame symmetry → all 4 RPMs equal. + controller = parametrize(force_torque2rotor_vel, drone_model) + force = np.array([0.2]) # total thrust, split equally across 4 motors + torque = np.zeros(3) rotor_vel = controller(force, torque) - assert rotor_vel.shape == (3, 2, 4) + assert np.allclose(rotor_vel, rotor_vel[0], rtol=1e-5), f"RPMs not equal: {rotor_vel}" diff --git a/tests/unit/test_transform.py b/tests/unit/test_transform.py new file mode 100644 index 0000000..f8ef109 --- /dev/null +++ b/tests/unit/test_transform.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np +import pytest + +from drone_controllers.core import load_params +from drone_controllers.drones import Drones +from drone_controllers.transform import ( + force2pwm, + motor_force2rotor_vel, + pwm2force, + rotor_vel2body_force, +) + + +@pytest.fixture(scope="module") +def core_params() -> dict[str, Any]: + return load_params("mellinger", "force_torque2rotor_vel", Drones.cf2x_L250) + + +@pytest.mark.unit +def test_force2pwm_pwm2force_roundtrip(core_params: dict[str, Any]) -> None: + thrust_max = float(core_params["thrust_max"]) + pwm_max = float(core_params["pwm_max"]) + forces = np.array([0.0, thrust_max * 0.25, thrust_max * 0.5, thrust_max]) + assert np.allclose( + pwm2force(force2pwm(forces, thrust_max, pwm_max), thrust_max, pwm_max), forces + ) + + +@pytest.mark.unit +def test_force2pwm_boundary(core_params: dict[str, Any]) -> None: + thrust_max = float(core_params["thrust_max"]) + pwm_max = float(core_params["pwm_max"]) + assert force2pwm(0.0, thrust_max, pwm_max) == pytest.approx(0.0) + assert force2pwm(thrust_max, thrust_max, pwm_max) == pytest.approx(pwm_max) + + +@pytest.mark.unit +def test_motor_force2rotor_vel_shape(core_params: dict[str, Any]) -> None: + rpm2thrust = core_params["rpm2thrust"] + assert motor_force2rotor_vel(np.full(4, 0.05), rpm2thrust).shape == (4,) + assert motor_force2rotor_vel(np.full((3, 2, 4), 0.05), rpm2thrust).shape == (3, 2, 4) + + +@pytest.mark.unit +def test_motor_force2rotor_vel_positive(core_params: dict[str, Any]) -> None: + rpm2thrust = core_params["rpm2thrust"] + forces = np.linspace(0.02, 0.12, 10) + assert np.all(motor_force2rotor_vel(forces, rpm2thrust) > 0) + + +@pytest.mark.unit +def test_rotor_vel2body_force_shape(core_params: dict[str, Any]) -> None: + rpm2thrust = core_params["rpm2thrust"] + assert rotor_vel2body_force(np.full(4, 10_000.0), rpm2thrust).shape == (3,) + assert rotor_vel2body_force(np.full((3, 2, 4), 10_000.0), rpm2thrust).shape == (3, 2, 3) + + +@pytest.mark.unit +def test_rotor_vel2body_force_z_axis_only(core_params: dict[str, Any]) -> None: + rpm2thrust = core_params["rpm2thrust"] + body_force = rotor_vel2body_force(np.full(4, 10_000.0), rpm2thrust) + assert np.allclose(body_force[:2], 0.0) + assert body_force[2] > 0.0