diff --git a/docs/user-guides/getting-started.md b/docs/user-guides/getting-started.md index 1fb42c7a..4d7f42cd 100644 --- a/docs/user-guides/getting-started.md +++ b/docs/user-guides/getting-started.md @@ -283,23 +283,23 @@ MuJoCo does not currently ship first-party Python typing stubs. To enable proper > Thanks to the work by [@kevinzakka](https://github.com/google-deepmind/mujoco/issues/1292#issuecomment-1874138201) and [@mluogh-xdof](https://github.com/google-deepmind/mujoco/issues/1292#issuecomment-3208219200) for figuring all this out! -1. In your terminal: +1. From your project root, generate stubs into a local `typings/` directory: ```bash linenums="0" - pybind11-stubgen mujoco -o ~/typings/ --numpy-array-wrap-with-annotated + pybind11-stubgen mujoco -o typings/ --numpy-array-wrap-with-annotated ``` 2. Recent MuJoCo builds compiled with newer pybind11 versions correctly expose enums as `SupportsInt`. If you encounter Pyright enum type errors, apply the compatibility patch: ```bash linenums="0" - python ~/typings/patch_mujoco_enums.py ~/typings/mujoco/_enums.pyi + python typings/patch_mujoco_enums.py typings/mujoco/_enums.pyi ``` 3. Then in `pyproject.toml` (for me, VSCode already type hints correctly, but this should fix things if you use `pyright` with your pre-commit hooks): ```toml title="pyproject.toml" [tool.pyright] - stubPath = "~/typings" + stubPath = "typings" venvPath = "." venv = ".venv" ``` diff --git a/src/mujoco_mojo/__init__.py b/src/mujoco_mojo/__init__.py index c8cf36a6..2d029fd6 100644 --- a/src/mujoco_mojo/__init__.py +++ b/src/mujoco_mojo/__init__.py @@ -2,6 +2,30 @@ mujoco_mojo is a collection of Python objects built to make working with MuJoCo via Python easier. It provides vast bindings for all MJCF XML schema objects, tools to convert to XML, run MuJoCo simulations, and more. + +MuJoCo Type Stubs: +------------------ + +MuJoCo does not ship Python type stubs. Generate them once with pybind11-stubgen: + +.. code-block:: bash + pip install pybind11-stubgen + pybind11-stubgen mujoco -o typings/ --numpy-array-wrap-with-annotated + +Run this from the project root. The ``typings/`` directory is gitignored. +Then add to pyproject.toml: + +.. code-block:: toml + [tool.pyright] + stubPath = "typings" + venvPath = "." + venv = ".venv" + +If you encounter Pyright enum errors on recent MuJoCo builds, apply the +compatibility patch that ships alongside the stubs: + +.. code-block:: bash + python typings/patch_mujoco_enums.py typings/mujoco/_enums.pyi """ from mujoco_mojo.__about__ import __version__ # noqa: F401 diff --git a/src/mujoco_mojo/utils/dataframe.py b/src/mujoco_mojo/utils/dataframe.py index 8f605133..8c790a8a 100644 --- a/src/mujoco_mojo/utils/dataframe.py +++ b/src/mujoco_mojo/utils/dataframe.py @@ -1,7 +1,10 @@ from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast +from typing import TYPE_CHECKING, Any, TypedDict, cast + +if TYPE_CHECKING: + from typing import Self import polars as pl from scipy.spatial.transform import Rotation as R @@ -28,13 +31,6 @@ logger = get_logger(__name__) -class MojoFrameProtocol(Protocol): - """Protocol to tell type checkers that .mojo is available on the DataFrame.""" - - @property - def mojo(self) -> MojoNamespace: ... - - class _MojoFrame(pl.DataFrame): """ Internal implementation of MojoFrame to house static loaders. @@ -78,10 +74,30 @@ def from_dict( if TYPE_CHECKING: # MojoFrame is seen by the IDE as the combination of: - # 1. Our loaders (_MojoFrame) + # 1. loaders (_MojoFrame) # 2. Polars methods (pl.DataFrame) - # 3. Our namespace (MojoFrameProtocol) - class MojoDataFrame(_MojoFrame, MojoFrameProtocol): ... + # 3. namespace (MojoFrameProtocol) + class MojoDataFrame(_MojoFrame): + @property + def mojo(self) -> MojoNamespace: ... + + # override polars methods that return DataFrame so the type propagates + def select(self, *args: Any, **kwargs: Any) -> Self: ... + def filter(self, *args: Any, **kwargs: Any) -> Self: ... + def with_columns(self, *args: Any, **kwargs: Any) -> Self: ... + def sort(self, *args: Any, **kwargs: Any) -> Self: ... + def head(self, *args: Any, **kwargs: Any) -> Self: ... + def tail(self, *args: Any, **kwargs: Any) -> Self: ... + def limit(self, *args: Any, **kwargs: Any) -> Self: ... + def slice(self, *args: Any, **kwargs: Any) -> Self: ... + def rename(self, *args: Any, **kwargs: Any) -> Self: ... + def drop(self, *args: Any, **kwargs: Any) -> Self: ... + def drop_nulls(self, *args: Any, **kwargs: Any) -> Self: ... + def unique(self, *args: Any, **kwargs: Any) -> Self: ... + def sample(self, *args: Any, **kwargs: Any) -> Self: ... + def join(self, *args: Any, **kwargs: Any) -> Self: ... + def hstack(self, *args: Any, **kwargs: Any) -> Self: ... + def vstack(self, *args: Any, **kwargs: Any) -> Self: ... else: # At runtime, it's just our internal class MojoDataFrame = _MojoFrame @@ -93,7 +109,6 @@ class ColumnManifest(TypedDict): available_quats: list[str] -@pl.api.register_dataframe_namespace("mojo") class MojoNamespace: """ Enhanced Polars DataFrame for MuJoCo Mojo telemetry. @@ -370,3 +385,7 @@ def with_filters( target_cols = columns or self._df.columns filter_map = {col: filters for col in target_cols} return self.with_filter_map(filter_map, omit_time=omit_time) + + +if not TYPE_CHECKING: + pl.api.register_dataframe_namespace("mojo")(MojoNamespace) diff --git a/src/mujoco_mojo/utils/filters/filters.py b/src/mujoco_mojo/utils/filters/filters.py index 685a03d6..4a880bb7 100644 --- a/src/mujoco_mojo/utils/filters/filters.py +++ b/src/mujoco_mojo/utils/filters/filters.py @@ -60,6 +60,16 @@ class BaseFilter(ABC, BaseModel): def apply(self, expr: pl.Expr) -> pl.Expr: """Applies the transformation to a Polars expression.""" + def apply_with_context( + self, series: pl.Series, df: pl.DataFrame + ) -> pl.Series | None: + """ + Override for filters that need access to other columns. + Receives the current (already-transformed) series and the original dataframe. + Return None to fall back to apply(expr). + """ + return None + class ScaleFilter(BaseFilter): """Applies a linear transformation: (value * factor) + offset.""" @@ -97,12 +107,25 @@ class DerivativeFilter(BaseFilter): """The discriminator type for Pydantic.""" dt: float = Field(default=0.001, gt=0) - """The time step between samples in seconds.""" + """The time step between samples in seconds. Ignored when wrt_col is set.""" + + wrt_col: str | None = Field(default=None, json_schema_extra={"ui_type": "col"}) + """Optional column to differentiate with respect to instead of a fixed dt.""" def apply(self, expr: pl.Expr) -> pl.Expr: # Backward difference: (x[n] - x[n-1]) / dt return expr.diff().fill_null(0) / self.dt + def apply_with_context( + self, series: pl.Series, df: pl.DataFrame + ) -> pl.Series | None: + if not self.wrt_col or self.wrt_col not in df.columns: + return None + wrt = df[self.wrt_col].cast(pl.Float64) + # Avoid divide-by-zero at the first sample + dx = wrt.diff().fill_null(strategy="forward").fill_null(1) + return series.cast(pl.Float64).diff().fill_null(0) / dx + class IntegralFilter(BaseFilter): """ @@ -114,12 +137,24 @@ class IntegralFilter(BaseFilter): """The discriminator type for Pydantic.""" dt: float = Field(default=0.001, gt=0) - """The time step between samples in seconds.""" + """The time step between samples in seconds. Ignored when wrt_col is set.""" + + wrt_col: str | None = Field(default=None, json_schema_extra={"ui_type": "col"}) + """Optional column to integrate with respect to instead of a fixed dt.""" def apply(self, expr: pl.Expr) -> pl.Expr: - # Simple cumulative trapezoidal or rectangular integration + # Simple rectangular integration with fixed step return expr.cum_sum() * self.dt + def apply_with_context( + self, series: pl.Series, df: pl.DataFrame + ) -> pl.Series | None: + if not self.wrt_col or self.wrt_col not in df.columns: + return None + wrt = df[self.wrt_col].cast(pl.Float64) + dx = wrt.diff().fill_null(0) + return (series.cast(pl.Float64) * dx).cum_sum() + class LowPassFilter(BaseFilter): """ diff --git a/src/mujoco_mojo/utils/layers/cli.py b/src/mujoco_mojo/utils/layers/cli.py index 48a049a0..988b1ef3 100644 --- a/src/mujoco_mojo/utils/layers/cli.py +++ b/src/mujoco_mojo/utils/layers/cli.py @@ -921,6 +921,25 @@ def init_project( border_style="cyan", ) ) + if not Path("typings", "mujoco").exists(): + console.print("\n[bold yellow]Type Hints Setup[/bold yellow]") + console.print( + "[bold white]MuJoCo does not ship Python type stubs.[/bold white] " + "Generate them once for Pylance/Pyright autocomplete:\n\n" + "[bold yellow]" + "pip install pybind11-stubgen\n" + "pybind11-stubgen mujoco -o typings/ --numpy-array-wrap-with-annotated" + "[/bold yellow]\n\n" + "Run from the project root. Then add to [bold cyan]pyproject.toml[/bold cyan]:\n\n" + "[dim]" + "\\[tool.pyright]\n" + 'stubPath = "typings"\n' + 'venvPath = "."\n' + 'venv = ".venv"' + "[/dim]\n\n" + "[dim]Enum errors? Run: " + "python typings/patch_mujoco_enums.py typings/mujoco/_enums.pyi[/dim]" + ) @cli_app.command(name="reloaded") diff --git a/src/mujoco_mojo/utils/layers/dojo/plot_config.py b/src/mujoco_mojo/utils/layers/dojo/plot_config.py index 8efa02b3..47f58a2c 100644 --- a/src/mujoco_mojo/utils/layers/dojo/plot_config.py +++ b/src/mujoco_mojo/utils/layers/dojo/plot_config.py @@ -84,6 +84,11 @@ class ShapeType(StrEnum): rect = "rect" +class PlotType(StrEnum): + cartesian = "cartesian" + polar = "polar" + + # --------------------------------------------------------------------------- # Composite models # --------------------------------------------------------------------------- @@ -98,6 +103,11 @@ class FilterEntry(BaseModel): enabled: bool = True +class XAxisConfig(BaseModel): + col: str = "time" + filters: list[FilterEntry] = [] + + class YAxisConfig(BaseModel): label: str color: str @@ -128,7 +138,7 @@ class Shape(BaseModel): class PlotConfig(BaseModel): """Complete serialisable state of a trial-viewer plot.""" - xAxis: str + xAxis: XAxisConfig = Field(default_factory=XAxisConfig) yAxes: dict[str, YAxisConfig] refFrame: str | None grid: GridMode @@ -146,6 +156,7 @@ class PlotConfig(BaseModel): yScale: ScaleType xLogBase: float | None = None yLogBase: float | None = None + plotType: PlotType = PlotType.cartesian vsEnabled: bool vsRange: Annotated[tuple[float, float], Field()] annotations: list[Annotation] diff --git a/src/mujoco_mojo/utils/layers/dojo/routers/mosaic.py b/src/mujoco_mojo/utils/layers/dojo/routers/mosaic.py index 53642a5c..03105e6f 100644 --- a/src/mujoco_mojo/utils/layers/dojo/routers/mosaic.py +++ b/src/mujoco_mojo/utils/layers/dojo/routers/mosaic.py @@ -216,6 +216,8 @@ async def get_filter_schema(): from pydantic_core import PydanticUndefined def _infer_type(prop: dict) -> str: + if prop.get("ui_type") == "col": + return "col" if "anyOf" in prop: non_null = [s for s in prop["anyOf"] if s.get("type") != "null"] prop = non_null[0] if non_null else {} @@ -506,12 +508,15 @@ async def get_trial_data( if filter_list: if series.dtype != pl.Float64: series = series.cast(pl.Float64) - tmp = pl.DataFrame({col: series}) - expr = pl.col(col) for f in filter_list: - expr = f.apply(expr) - tmp = tmp.with_columns(expr.alias(col)) - series = tmp[col] + # context-aware filters (e.g. derivative/integral wrt another col) + ctx = f.apply_with_context(series, df) + if ctx is not None: + series = ctx + else: + tmp = pl.DataFrame({col: series}) + tmp = tmp.with_columns(f.apply(pl.col(col)).alias(col)) + series = tmp[col] data[col] = series.to_list() return { diff --git a/src/mujoco_mojo/utils/layers/dojo/templates/base.html b/src/mujoco_mojo/utils/layers/dojo/templates/base.html index d9917000..0f8e3094 100644 --- a/src/mujoco_mojo/utils/layers/dojo/templates/base.html +++ b/src/mujoco_mojo/utils/layers/dojo/templates/base.html @@ -101,6 +101,11 @@ } } + @@ -50,31 +39,6 @@
- -