Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/user-guides/getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -283,23 +283,23 @@ MuJoCo does not currently ship first-party Python typing stubs. To enable proper

> Thanks to the work by [@kevinzakka](https://github.com/google-deepmind/mujoco/issues/1292#issuecomment-1874138201) and [@mluogh-xdof](https://github.com/google-deepmind/mujoco/issues/1292#issuecomment-3208219200) for figuring all this out!

1. In your terminal:
1. From your project root, generate stubs into a local `typings/` directory:

```bash linenums="0"
pybind11-stubgen mujoco -o ~/typings/ --numpy-array-wrap-with-annotated
pybind11-stubgen mujoco -o typings/ --numpy-array-wrap-with-annotated
```

2. Recent MuJoCo builds compiled with newer pybind11 versions correctly expose enums as `SupportsInt`. If you encounter Pyright enum type errors, apply the compatibility patch:

```bash linenums="0"
python ~/typings/patch_mujoco_enums.py ~/typings/mujoco/_enums.pyi
python typings/patch_mujoco_enums.py typings/mujoco/_enums.pyi
```

3. Then in `pyproject.toml` (for me, VSCode already type hints correctly, but this should fix things if you use `pyright` with your pre-commit hooks):

```toml title="pyproject.toml"
[tool.pyright]
stubPath = "~/typings"
stubPath = "typings"
venvPath = "."
venv = ".venv"
```
Expand Down
24 changes: 24 additions & 0 deletions src/mujoco_mojo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,30 @@
mujoco_mojo is a collection of Python objects built to make working with MuJoCo via Python easier.

It provides vast bindings for all MJCF XML schema objects, tools to convert to XML, run MuJoCo simulations, and more.

MuJoCo Type Stubs:
------------------

MuJoCo does not ship Python type stubs. Generate them once with pybind11-stubgen:

.. code-block:: bash
pip install pybind11-stubgen
pybind11-stubgen mujoco -o typings/ --numpy-array-wrap-with-annotated

Run this from the project root. The ``typings/`` directory is gitignored.
Then add to pyproject.toml:

.. code-block:: toml
[tool.pyright]
stubPath = "typings"
venvPath = "."
venv = ".venv"

If you encounter Pyright enum errors on recent MuJoCo builds, apply the
compatibility patch that ships alongside the stubs:

.. code-block:: bash
python typings/patch_mujoco_enums.py typings/mujoco/_enums.pyi
"""

from mujoco_mojo.__about__ import __version__ # noqa: F401
Expand Down
43 changes: 31 additions & 12 deletions src/mujoco_mojo/utils/dataframe.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast
from typing import TYPE_CHECKING, Any, TypedDict, cast

if TYPE_CHECKING:
from typing import Self

import polars as pl
from scipy.spatial.transform import Rotation as R
Expand All @@ -28,13 +31,6 @@
logger = get_logger(__name__)


class MojoFrameProtocol(Protocol):
"""Protocol to tell type checkers that .mojo is available on the DataFrame."""

@property
def mojo(self) -> MojoNamespace: ...


class _MojoFrame(pl.DataFrame):
"""
Internal implementation of MojoFrame to house static loaders.
Expand Down Expand Up @@ -78,10 +74,30 @@ def from_dict(

if TYPE_CHECKING:
# MojoFrame is seen by the IDE as the combination of:
# 1. Our loaders (_MojoFrame)
# 1. loaders (_MojoFrame)
# 2. Polars methods (pl.DataFrame)
# 3. Our namespace (MojoFrameProtocol)
class MojoDataFrame(_MojoFrame, MojoFrameProtocol): ...
# 3. namespace (MojoFrameProtocol)
class MojoDataFrame(_MojoFrame):
@property
def mojo(self) -> MojoNamespace: ...

# override polars methods that return DataFrame so the type propagates
def select(self, *args: Any, **kwargs: Any) -> Self: ...
def filter(self, *args: Any, **kwargs: Any) -> Self: ...
def with_columns(self, *args: Any, **kwargs: Any) -> Self: ...
def sort(self, *args: Any, **kwargs: Any) -> Self: ...
def head(self, *args: Any, **kwargs: Any) -> Self: ...
def tail(self, *args: Any, **kwargs: Any) -> Self: ...
def limit(self, *args: Any, **kwargs: Any) -> Self: ...
def slice(self, *args: Any, **kwargs: Any) -> Self: ...
def rename(self, *args: Any, **kwargs: Any) -> Self: ...
def drop(self, *args: Any, **kwargs: Any) -> Self: ...
def drop_nulls(self, *args: Any, **kwargs: Any) -> Self: ...
def unique(self, *args: Any, **kwargs: Any) -> Self: ...
def sample(self, *args: Any, **kwargs: Any) -> Self: ...
def join(self, *args: Any, **kwargs: Any) -> Self: ...
def hstack(self, *args: Any, **kwargs: Any) -> Self: ...
def vstack(self, *args: Any, **kwargs: Any) -> Self: ...
else:
# At runtime, it's just our internal class
MojoDataFrame = _MojoFrame
Expand All @@ -93,7 +109,6 @@ class ColumnManifest(TypedDict):
available_quats: list[str]


@pl.api.register_dataframe_namespace("mojo")
class MojoNamespace:
"""
Enhanced Polars DataFrame for MuJoCo Mojo telemetry.
Expand Down Expand Up @@ -370,3 +385,7 @@ def with_filters(
target_cols = columns or self._df.columns
filter_map = {col: filters for col in target_cols}
return self.with_filter_map(filter_map, omit_time=omit_time)


if not TYPE_CHECKING:
pl.api.register_dataframe_namespace("mojo")(MojoNamespace)
41 changes: 38 additions & 3 deletions src/mujoco_mojo/utils/filters/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ class BaseFilter(ABC, BaseModel):
def apply(self, expr: pl.Expr) -> pl.Expr:
"""Applies the transformation to a Polars expression."""

def apply_with_context(
self, series: pl.Series, df: pl.DataFrame
) -> pl.Series | None:
"""
Override for filters that need access to other columns.
Receives the current (already-transformed) series and the original dataframe.
Return None to fall back to apply(expr).
"""
return None


class ScaleFilter(BaseFilter):
"""Applies a linear transformation: (value * factor) + offset."""
Expand Down Expand Up @@ -97,12 +107,25 @@ class DerivativeFilter(BaseFilter):
"""The discriminator type for Pydantic."""

dt: float = Field(default=0.001, gt=0)
"""The time step between samples in seconds."""
"""The time step between samples in seconds. Ignored when wrt_col is set."""

wrt_col: str | None = Field(default=None, json_schema_extra={"ui_type": "col"})
"""Optional column to differentiate with respect to instead of a fixed dt."""

def apply(self, expr: pl.Expr) -> pl.Expr:
# Backward difference: (x[n] - x[n-1]) / dt
return expr.diff().fill_null(0) / self.dt

def apply_with_context(
self, series: pl.Series, df: pl.DataFrame
) -> pl.Series | None:
if not self.wrt_col or self.wrt_col not in df.columns:
return None
wrt = df[self.wrt_col].cast(pl.Float64)
# Avoid divide-by-zero at the first sample
dx = wrt.diff().fill_null(strategy="forward").fill_null(1)
return series.cast(pl.Float64).diff().fill_null(0) / dx


class IntegralFilter(BaseFilter):
"""
Expand All @@ -114,12 +137,24 @@ class IntegralFilter(BaseFilter):
"""The discriminator type for Pydantic."""

dt: float = Field(default=0.001, gt=0)
"""The time step between samples in seconds."""
"""The time step between samples in seconds. Ignored when wrt_col is set."""

wrt_col: str | None = Field(default=None, json_schema_extra={"ui_type": "col"})
"""Optional column to integrate with respect to instead of a fixed dt."""

def apply(self, expr: pl.Expr) -> pl.Expr:
# Simple cumulative trapezoidal or rectangular integration
# Simple rectangular integration with fixed step
return expr.cum_sum() * self.dt

def apply_with_context(
self, series: pl.Series, df: pl.DataFrame
) -> pl.Series | None:
if not self.wrt_col or self.wrt_col not in df.columns:
return None
wrt = df[self.wrt_col].cast(pl.Float64)
dx = wrt.diff().fill_null(0)
return (series.cast(pl.Float64) * dx).cum_sum()


class LowPassFilter(BaseFilter):
"""
Expand Down
19 changes: 19 additions & 0 deletions src/mujoco_mojo/utils/layers/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,6 +921,25 @@ def init_project(
border_style="cyan",
)
)
if not Path("typings", "mujoco").exists():
console.print("\n[bold yellow]Type Hints Setup[/bold yellow]")
console.print(
"[bold white]MuJoCo does not ship Python type stubs.[/bold white] "
"Generate them once for Pylance/Pyright autocomplete:\n\n"
"[bold yellow]"
"pip install pybind11-stubgen\n"
"pybind11-stubgen mujoco -o typings/ --numpy-array-wrap-with-annotated"
"[/bold yellow]\n\n"
"Run from the project root. Then add to [bold cyan]pyproject.toml[/bold cyan]:\n\n"
"[dim]"
"\\[tool.pyright]\n"
'stubPath = "typings"\n'
'venvPath = "."\n'
'venv = ".venv"'
"[/dim]\n\n"
"[dim]Enum errors? Run: "
"python typings/patch_mujoco_enums.py typings/mujoco/_enums.pyi[/dim]"
)


@cli_app.command(name="reloaded")
Expand Down
13 changes: 12 additions & 1 deletion src/mujoco_mojo/utils/layers/dojo/plot_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ class ShapeType(StrEnum):
rect = "rect"


class PlotType(StrEnum):
cartesian = "cartesian"
polar = "polar"


# ---------------------------------------------------------------------------
# Composite models
# ---------------------------------------------------------------------------
Expand All @@ -98,6 +103,11 @@ class FilterEntry(BaseModel):
enabled: bool = True


class XAxisConfig(BaseModel):
col: str = "time"
filters: list[FilterEntry] = []


class YAxisConfig(BaseModel):
label: str
color: str
Expand Down Expand Up @@ -128,7 +138,7 @@ class Shape(BaseModel):
class PlotConfig(BaseModel):
"""Complete serialisable state of a trial-viewer plot."""

xAxis: str
xAxis: XAxisConfig = Field(default_factory=XAxisConfig)
yAxes: dict[str, YAxisConfig]
refFrame: str | None
grid: GridMode
Expand All @@ -146,6 +156,7 @@ class PlotConfig(BaseModel):
yScale: ScaleType
xLogBase: float | None = None
yLogBase: float | None = None
plotType: PlotType = PlotType.cartesian
vsEnabled: bool
vsRange: Annotated[tuple[float, float], Field()]
annotations: list[Annotation]
Expand Down
15 changes: 10 additions & 5 deletions src/mujoco_mojo/utils/layers/dojo/routers/mosaic.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ async def get_filter_schema():
from pydantic_core import PydanticUndefined

def _infer_type(prop: dict) -> str:
if prop.get("ui_type") == "col":
return "col"
if "anyOf" in prop:
non_null = [s for s in prop["anyOf"] if s.get("type") != "null"]
prop = non_null[0] if non_null else {}
Expand Down Expand Up @@ -506,12 +508,15 @@ async def get_trial_data(
if filter_list:
if series.dtype != pl.Float64:
series = series.cast(pl.Float64)
tmp = pl.DataFrame({col: series})
expr = pl.col(col)
for f in filter_list:
expr = f.apply(expr)
tmp = tmp.with_columns(expr.alias(col))
series = tmp[col]
# context-aware filters (e.g. derivative/integral wrt another col)
ctx = f.apply_with_context(series, df)
if ctx is not None:
series = ctx
else:
tmp = pl.DataFrame({col: series})
tmp = tmp.with_columns(f.apply(pl.col(col)).alias(col))
series = tmp[col]
data[col] = series.to_list()

return {
Expand Down
Loading
Loading