diff --git a/ci/Dockerfile b/ci/Dockerfile
index 55a55e5e1..3f7115c70 100644
--- a/ci/Dockerfile
+++ b/ci/Dockerfile
@@ -19,7 +19,7 @@ RUN python -m pip install --upgrade pip setuptools wheel uv
WORKDIR /workspace
COPY ci/requirements.ci.txt /tmp/requirements.ci.txt
-RUN python -m uv pip install --system -r /tmp/requirements.ci.txt
+RUN python -m uv pip install -r /tmp/requirements.ci.txt
# Working directory for the job
diff --git a/ci/requirements.ci.txt b/ci/requirements.ci.txt
index badd0d145..8d3c46816 100644
--- a/ci/requirements.ci.txt
+++ b/ci/requirements.ci.txt
@@ -10,4 +10,5 @@ attrs
scipy
cupy-cuda12x
pandas
-sympy
\ No newline at end of file
+sympy
+cellmlmanip
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 65b3ad524..8c25c07aa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,8 +41,8 @@ dev = [
"cupy-cuda12x",
"pandas",
"matplotlib",
- "scipy"
-
+ "scipy",
+ "cellmlmanip",
]
cupy = [
"cupy-cuda12x",
diff --git a/src/cubie/odesystems/symbolic/parsing/cellml.py b/src/cubie/odesystems/symbolic/parsing/cellml.py
index 40779cc62..271e5c97f 100644
--- a/src/cubie/odesystems/symbolic/parsing/cellml.py
+++ b/src/cubie/odesystems/symbolic/parsing/cellml.py
@@ -1,12 +1,42 @@
"""Minimal CellML parsing helpers using ``cellmlmanip``.
-This wrapper is heavily inspired by
-:mod:`chaste_codegen.model_with_conversions` from the chaste-codegen project
-(MIT licence). Only a tiny subset required for basic model loading is
-implemented here.
-"""
+This module provides functionality to import CellML models into CuBIE's
+symbolic ODE framework. It wraps the cellmlmanip library to load
+CellML files and convert them directly into SymbolicODE objects.
+
+The implementation is inspired by
+:mod:`chaste_codegen.model_with_conversions` from the chaste-codegen
+project (MIT licence). Only a minimal subset required for basic model
+loading is implemented here.
+
+Examples
+--------
+Basic CellML model loading workflow:
+
+>>> from cubie.odesystems.symbolic.parsing.cellml import (
+... load_cellml_model
+... )
+>>>
+>>> # Load a CellML model file - returns initialized SymbolicODE
+>>> ode_system = load_cellml_model("cardiac_model.cellml")
+>>>
+>>> # The model is ready to use with solve_ivp
+>>> print(f"Model has {ode_system.num_states} states")
+>>> print(f"Model has {len(ode_system.indices.observables)} observables")
+
+Notes
+-----
+The cellmlmanip dependency is optional. Install with:
-from __future__ import annotations
+ pip install cellmlmanip
+
+CellML models can be obtained from the Physiome Model Repository:
+https://models.physiomeproject.org/
+
+See Also
+--------
+load_cellml_model : Main function for loading CellML files
+"""
try: # pragma: no cover - optional dependency
import cellmlmanip # type: ignore
@@ -14,25 +44,212 @@
cellmlmanip = None # type: ignore
import sympy as sp
+from pathlib import Path
+import numpy as np
+from typing import Optional, List
+import re
+from cubie._utils import PrecisionDType
-def load_cellml_model(path: str) -> tuple[list[sp.Symbol], list[sp.Eq]]:
- """Load a CellML model and extract states and derivatives.
+
+def _sanitize_symbol_name(name: str) -> str:
+    """Sanitize CellML symbol names for Python identifiers.
+
+    CellML uses ``$`` for namespacing and allows names starting with
+    ``_`` followed by digits. Every character outside ``[a-zA-Z0-9_]``
+    (including ``$`` and ``.``) is replaced with ``_`` first, so that
+    the leading-character rules below see the final spelling.
+
+    Parameters
+    ----------
+    name : str
+        Raw symbol name as reported by the CellML model.
+
+    Returns
+    -------
+    str
+        Name containing only ASCII letters, digits and underscores. A
+        name that would begin with a digit gains a ``var_`` prefix and
+        a name that would begin with ``_<digit>`` gains a ``var``
+        prefix. An empty input is returned unchanged.
+    """
+    # Replace every invalid character (this covers "$" and ".") before
+    # inspecting the prefix; otherwise an input such as "-1x" would
+    # become "_1x" after the prefix rules had already run and miss the
+    # "var" prefix that "_1x" itself receives.
+    name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
+
+    # "_<digit>..." -> "var_<digit>..."
+    if len(name) > 1 and name[0] == '_' and name[1].isdigit():
+        return 'var' + name
+
+    # "<digit>..." -> "var_<digit>..."
+    if name and name[0].isdigit():
+        return 'var_' + name
+
+    return name
+
+
+def load_cellml_model(
+    path: str,
+    precision: PrecisionDType = np.float32,
+    name: Optional[str] = None,
+    parameters: Optional[List[str]] = None,
+    observables: Optional[List[str]] = None,
+):
+    """Load a CellML model and return an initialized SymbolicODE system.
+
+    This function uses the cellmlmanip library to parse CellML files
+    and converts them into a ready-to-use SymbolicODE system with all
+    differential equations and algebraic constraints properly configured.
Parameters
----------
-    path
-        Filesystem path to the CellML source file.
+    path : str
+        Filesystem path to the CellML source file. Must have .cellml
+        extension and be a valid CellML 1.0 or 1.1 model file.
+    precision : numpy dtype, optional
+        Target floating-point precision for compiled kernels.
+        Default is np.float32.
+    name : str, optional
+        Identifier for the generated system. If None, uses the
+        filename without extension.
+    parameters : list of str, optional
+        List of symbol names to assign as parameters. Otherwise,
+        these symbols become constants or anonymous auxiliaries.
+    observables : list of str, optional
+        List of symbol names to assign as observables. Otherwise,
+        these symbols become anonymous auxiliaries.
Returns
-------
-    tuple[list[sympy.Symbol], list[sympy.Eq]]
-        States and differential equations defined by the model.
+    SymbolicODE
+        Fully initialized ODE system ready for use with solve_ivp.
+        State variables are configured with initial values from the
+        CellML model, and algebraic equations are set up according
+        to the parameters and observables specifications.
+
+    Raises
+    ------
+    ImportError
+        If cellmlmanip is not installed. Install with:
+        pip install cellmlmanip
+    TypeError
+        If path is not a string.
+    FileNotFoundError
+        If the specified CellML file does not exist or is not a
+        regular file (for example, a directory).
+    ValueError
+        If the file does not have .cellml extension.
+
+    Examples
+    --------
+    Load a CellML model and run a simulation:
+
+    >>> from cubie import load_cellml_model, solve_ivp
+    >>> import numpy as np
+    >>>
+    >>> # Load the model
+    >>> ode_system = load_cellml_model("beeler_reuter_model_1977.cellml")
+    >>>
+    >>> # Set up simulation
+    >>> t_span = (0.0, 100.0)
+    >>> initial_states = np.ones(ode_system.num_states, dtype=np.float32)
+    >>>
+    >>> # Run simulation
+    >>> result = solve_ivp(ode_system, t_span, initial_states)
+
+    Notes
+    -----
+    - Differential equations become state equations in the ODE system
+    - Algebraic equations become observables or anonymous auxiliaries
+    - State variables are converted from sympy.Dummy to sympy.Symbol
+    - Initial values from CellML are preserved in the ODE system
+    - Supports CellML 1.0 and 1.1 formats
+    - CellML models from Physiome repository are compatible
+    - The cellmlmanip library handles the complex CellML XML parsing
"""
if cellmlmanip is None:  # pragma: no cover
raise ImportError("cellmlmanip is required for CellML parsing")
+
+    # Validate input type
+    if not isinstance(path, str):
+        raise TypeError(
+            f"path must be a string, got {type(path).__name__}"
+        )
+
+    # Validate that the path names an existing regular file. Using
+    # is_file() rather than exists() stops a directory from slipping
+    # through to cellmlmanip.
+    path_obj = Path(path)
+    if not path_obj.is_file():
+        raise FileNotFoundError(f"CellML file not found: {path}")
+
+    # Validate file extension
+    if not path.endswith('.cellml'):
+        raise ValueError(
+            f"File must have .cellml extension, got: {path}"
+        )
+
+    # Use filename as default name if not provided
+    if name is None:
+        name = path_obj.stem
+
model = cellmlmanip.load_model(path)
-    states = list(model.get_state_variables())
-    derivatives = list(model.get_derivatives())
-    equations = [eq for eq in model.equations if eq.lhs in derivatives]
-    return states, equations
+    raw_states = list(model.get_state_variables())
+    # A set keeps the per-equation membership test below cheap.
+    raw_derivatives = set(model.get_derivatives())
+
+    # Initial values from the CellML model, keyed by sanitized name
+    initial_values = {}
+
+    # Convert Dummy symbols to regular Symbols with sanitized names
+    # cellmlmanip returns Dummy symbols but we need regular Symbols
+    dummy_to_symbol = {}
+    for raw_state in raw_states:
+        clean_name = _sanitize_symbol_name(raw_state.name)
+        dummy_to_symbol[raw_state] = sp.Symbol(clean_name)
+
+        # Get initial value if available
+        initial_value = getattr(raw_state, 'initial_value', None)
+        if initial_value is not None:
+            initial_values[clean_name] = float(initial_value)
+
+    # Also convert any other Dummy symbols in the model equations
+    for eq in model.equations:
+        for atom in eq.atoms(sp.Dummy):
+            if atom not in dummy_to_symbol:
+                clean_name = _sanitize_symbol_name(atom.name)
+                dummy_to_symbol[atom] = sp.Symbol(clean_name)
+
+    # Classify each equation and render it as a string for
+    # SymbolicODE.create(). Differential equations are collected
+    # separately so they precede the algebraic ones in the final list.
+    dxdt_strings = []
+    algebraic_strings = []
+    for eq in model.equations:
+        substituted = eq.subs(dummy_to_symbol)
+        if eq.lhs in raw_derivatives:
+            # The derivative's first argument is the state variable;
+            # format as "dstate_name = rhs" (no slash).
+            state_var = substituted.lhs.args[0]
+            dxdt_strings.append(f"d{state_var.name} = {substituted.rhs}")
+        else:
+            # Format as "lhs = rhs"
+            algebraic_strings.append(
+                f"{substituted.lhs} = {substituted.rhs}"
+            )
+    all_equations = dxdt_strings + algebraic_strings
+
+    # Import here to avoid circular import with codegen modules
+    # cellml is imported by parsing/__init__.py which is imported
+    # during SymbolicODE initialization, creating a circular dependency
+    from cubie.odesystems.symbolic.symbolicODE import SymbolicODE
+
+    # NOTE(review): states without an initial value are absent from
+    # initial_values; confirm SymbolicODE.create(strict=False) still
+    # registers them as states rather than dropping them.
+    return SymbolicODE.create(
+        dxdt=all_equations,
+        states=initial_values if initial_values else None,
+        parameters=parameters,
+        observables=observables,
+        name=name,
+        precision=precision,
+        strict=False,
+    )
diff --git a/tests/fixtures/cellml/basic_ode.cellml b/tests/fixtures/cellml/basic_ode.cellml
new file mode 100644
index 000000000..ff43dd15e
--- /dev/null
+++ b/tests/fixtures/cellml/basic_ode.cellml
@@ -0,0 +1,23 @@
+
+
+
+
+
+
+
+
+
diff --git a/tests/fixtures/cellml/beeler_reuter_model_1977.cellml b/tests/fixtures/cellml/beeler_reuter_model_1977.cellml
new file mode 100644
index 000000000..f79c8a5ea
--- /dev/null
+++ b/tests/fixtures/cellml/beeler_reuter_model_1977.cellml
@@ -0,0 +1,1246 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/tests/odesystems/symbolic/test_cellml.py b/tests/odesystems/symbolic/test_cellml.py
index fcee9652d..fe739f905 100644
--- a/tests/odesystems/symbolic/test_cellml.py
+++ b/tests/odesystems/symbolic/test_cellml.py
@@ -1,11 +1,155 @@
import pytest
+from pathlib import Path
+import numpy as np
from cubie.odesystems.symbolic.parsing.cellml import load_cellml_model
+from cubie._utils import is_devfunc
-#
-# def test_cellml_import_error():
-# """Missing dependency raises ImportError."""
-#
-# with pytest.raises(ImportError):
-# load_cellml_model("dummy.cellml")
+# Note: there is deliberately no guarded cellmlmanip import in this module.
+# If the dependency is missing the tests fail rather than being skipped, so
+# the missing dependency stays visible.
+
+
+@pytest.fixture
+def cellml_fixtures_dir():
+    """Locate the directory that holds the CellML fixture files.
+
+    The path is ``fixtures/cellml`` under the directory two levels
+    above the one containing this test module.
+    """
+    this_dir = Path(__file__).parent
+    tests_root = this_dir.parent.parent
+    return tests_root / "fixtures" / "cellml"
+
+
+@pytest.fixture
+def basic_model_path(cellml_fixtures_dir):
+    """Path of the ``basic_ode.cellml`` fixture model."""
+    model_file = cellml_fixtures_dir / "basic_ode.cellml"
+    return model_file
+
+
+@pytest.fixture
+def beeler_reuter_model_path(cellml_fixtures_dir):
+    """Path of the ``beeler_reuter_model_1977.cellml`` fixture model."""
+    model_file = cellml_fixtures_dir / "beeler_reuter_model_1977.cellml"
+    return model_file
+
+
+def test_load_simple_cellml_model(basic_model_path):
+    """The basic fixture loads into a one-state system."""
+    model_file = str(basic_model_path)
+    system = load_cellml_model(model_file)
+
+    # One state is expected, and the generated dxdt must pass is_devfunc.
+    assert system.num_states == 1
+    assert is_devfunc(system.dxdt_function)
+
+
+def test_load_complex_cellml_model(beeler_reuter_model_path):
+    """The Beeler-Reuter cardiac fixture loads with all of its states."""
+    model_file = str(beeler_reuter_model_path)
+    system = load_cellml_model(model_file)
+
+    # Beeler-Reuter has 8 state variables
+    expected_states = 8
+    assert system.num_states == expected_states
+    assert is_devfunc(system.dxdt_function)
+
+
+def test_ode_system_has_correct_attributes(basic_model_path):
+    """The loaded system exposes the expected SymbolicODE attributes."""
+    system = load_cellml_model(str(basic_model_path))
+
+    # Checked one at a time, in this order, so a failure names the
+    # first missing attribute.
+    for attribute in ('num_states', 'equations', 'indices'):
+        assert hasattr(system, attribute)
+
+
+def test_algebraic_equations_as_observables(beeler_reuter_model_path):
+    """Requested algebraic symbols end up in the observables index map."""
+    # Sanitized names of two algebraic symbols from the model
+    requested = ["sodium_current_i_Na", "sodium_current_m_gate_alpha_m"]
+    system = load_cellml_model(
+        str(beeler_reuter_model_path),
+        observables=requested,
+    )
+
+    # Something must have been registered as an observable
+    index_map = system.indices.observables.index_map
+    assert len(index_map) > 0
+
+    # The map is keyed by symbols, so compare on their string form
+    registered_names = [str(symbol) for symbol in index_map.keys()]
+    assert len(index_map) == 2
+    for wanted in requested:
+        assert wanted in registered_names
+
+
+def test_invalid_path_type():
+    """A non-string ``path`` is rejected with TypeError."""
+    not_a_path = 123
+    with pytest.raises(TypeError, match="path must be a string"):
+        load_cellml_model(not_a_path)
+
+
+def test_nonexistent_file():
+    """A path that does not exist is rejected with FileNotFoundError."""
+    missing_file = "/nonexistent/path/model.cellml"
+    with pytest.raises(FileNotFoundError, match="CellML file not found"):
+        load_cellml_model(missing_file)
+
+
+def test_invalid_extension(tmp_path):
+    """Verify ValueError raised for non-.cellml extension."""
+    # The file has to exist: load_cellml_model checks existence before
+    # the extension, so a missing file would raise FileNotFoundError
+    # instead. pytest removes tmp_path itself, so no manual cleanup.
+    wrong_extension = tmp_path / "model.xml"
+    wrong_extension.write_text("")
+
+    with pytest.raises(ValueError, match="must have .cellml extension"):
+        load_cellml_model(str(wrong_extension))
+
+
+def test_custom_precision(basic_model_path):
+    """A ``precision`` argument is reflected on the loaded system."""
+    requested = np.float64
+    system = load_cellml_model(str(basic_model_path), precision=requested)
+
+    assert system.precision == np.float64
+
+
+def test_custom_name(basic_model_path):
+    """A ``name`` argument is reflected on the loaded system."""
+    model_file = str(basic_model_path)
+    system = load_cellml_model(model_file, name="custom_model")
+
+    assert system.name == "custom_model"
+
+
+def test_integration_with_solve_ivp(basic_model_path):
+    """The loaded model builds; solve_ivp itself is not called here."""
+    # Use float64 to avoid dtype mismatch in cuda simulator
+    model_file = str(basic_model_path)
+    system = load_cellml_model(model_file, precision=np.float64)
+
+    # Building is the critical step: it verifies the model is
+    # properly structured for integration.
+    system.build()
+
+    # The built system has the necessary components
+    assert is_devfunc(system.dxdt_function)
+    assert system.num_states == 1
+
+    # Initial values are accessible through the state defaults
+    assert system.indices.states.defaults is not None
+    assert len(system.indices.states.defaults) == 1
+
+
+def test_initial_values_from_cellml(beeler_reuter_model_path):
+    """Initial values from the CellML model appear as state defaults."""
+    system = load_cellml_model(str(beeler_reuter_model_path))
+
+    # One default is expected per state
+    assert system.indices.states.defaults is not None
+    assert len(system.indices.states.defaults) == 8
+
+    # At least one initial value taken from the model should be non-zero
+    default_values = system.indices.states.defaults.values()
+    assert any(value != 0 for value in default_values)