diff --git a/ci/Dockerfile b/ci/Dockerfile index 55a55e5e1..3f7115c70 100644 --- a/ci/Dockerfile +++ b/ci/Dockerfile @@ -19,7 +19,7 @@ RUN python -m pip install --upgrade pip setuptools wheel uv WORKDIR /workspace COPY ci/requirements.ci.txt /tmp/requirements.ci.txt -RUN python -m uv pip install --system -r /tmp/requirements.ci.txt +RUN python -m uv pip install -r /tmp/requirements.ci.txt # Working directory for the job diff --git a/ci/requirements.ci.txt b/ci/requirements.ci.txt index badd0d145..8d3c46816 100644 --- a/ci/requirements.ci.txt +++ b/ci/requirements.ci.txt @@ -10,4 +10,5 @@ attrs scipy cupy-cuda12x pandas -sympy \ No newline at end of file +sympy +cellmlmanip \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 65b3ad524..8c25c07aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,8 +41,8 @@ dev = [ "cupy-cuda12x", "pandas", "matplotlib", - "scipy" - + "scipy", + "cellmlmanip", ] cupy = [ "cupy-cuda12x", diff --git a/src/cubie/odesystems/symbolic/parsing/cellml.py b/src/cubie/odesystems/symbolic/parsing/cellml.py index 40779cc62..271e5c97f 100644 --- a/src/cubie/odesystems/symbolic/parsing/cellml.py +++ b/src/cubie/odesystems/symbolic/parsing/cellml.py @@ -1,12 +1,42 @@ """Minimal CellML parsing helpers using ``cellmlmanip``. -This wrapper is heavily inspired by -:mod:`chaste_codegen.model_with_conversions` from the chaste-codegen project -(MIT licence). Only a tiny subset required for basic model loading is -implemented here. -""" +This module provides functionality to import CellML models into CuBIE's +symbolic ODE framework. It wraps the cellmlmanip library to load +CellML files and convert them directly into SymbolicODE objects. + +The implementation is inspired by +:mod:`chaste_codegen.model_with_conversions` from the chaste-codegen +project (MIT licence). Only a minimal subset required for basic model +loading is implemented here. 
def _sanitize_symbol_name(name: str) -> str:
    """Turn a CellML variable name into a plain ASCII identifier.

    Every character outside ``[a-zA-Z0-9_]`` is replaced by ``_``; this
    covers the ``$`` and ``.`` separators found in CellML names. The
    result is then prefixed so that it never starts with a digit
    (``var_`` is prepended) nor with an underscore directly followed by
    a digit (``var`` is prepended).

    Parameters
    ----------
    name
        Raw symbol name as reported by the CellML model.

    Returns
    -------
    str
        Sanitised name. An empty input is returned unchanged.
    """
    # Replace invalid characters *before* inspecting the prefix. Doing it
    # afterwards let a name such as ``-1x`` slip through as ``_1x`` while
    # ``$1x`` became ``var_1x``; it also keeps ``str.isdigit`` below from
    # seeing non-ASCII digits that would be rewritten anyway.
    name = re.sub(r'[^a-zA-Z0-9_]', '_', name)

    # ``_<digit>...`` -> ``var_<digit>...``
    if name.startswith('_') and len(name) > 1 and name[1].isdigit():
        name = 'var' + name

    # ``<digit>...`` -> ``var_<digit>...``
    if name and name[0].isdigit():
        name = 'var_' + name

    return name


def load_cellml_model(
    path: "str | Path",
    precision: "PrecisionDType" = np.float32,
    name: Optional[str] = None,
    parameters: Optional[List[str]] = None,
    observables: Optional[List[str]] = None,
) -> "SymbolicODE":
    """Load a CellML file and build a ``SymbolicODE`` from its equations.

    The file is parsed with ``cellmlmanip.load_model``. Every
    ``sympy.Dummy`` found in the model equations is replaced by a
    ``sympy.Symbol`` with a sanitised name, the equations are rendered
    as ``"lhs = rhs"`` strings and handed to ``SymbolicODE.create``.

    Parameters
    ----------
    path : str or pathlib.Path
        Filesystem path to the CellML source file. The path must exist
        and end in ``.cellml``.
    precision : numpy dtype, optional
        Forwarded to ``SymbolicODE.create``. Default is ``np.float32``.
    name : str, optional
        Identifier for the generated system. Defaults to the file name
        without its extension.
    parameters : list of str, optional
        Symbol names forwarded to ``SymbolicODE.create`` as parameters.
    observables : list of str, optional
        Symbol names forwarded to ``SymbolicODE.create`` as observables.

    Returns
    -------
    SymbolicODE
        The object returned by ``SymbolicODE.create``. Initial values
        found on the CellML state variables are passed as ``states``;
        when no state carries one, ``states=None`` is passed instead.

    Raises
    ------
    ImportError
        If cellmlmanip is not installed (``pip install cellmlmanip``).
    TypeError
        If ``path`` is neither a string nor a ``pathlib.Path``.
    FileNotFoundError
        If ``path`` does not exist.
    ValueError
        If ``path`` does not end in ``.cellml``.

    Examples
    --------
    >>> from cubie.odesystems.symbolic.parsing.cellml import (
    ...     load_cellml_model
    ... )
    >>> ode_system = load_cellml_model(
    ...     "beeler_reuter_model_1977.cellml"
    ... )  # doctest: +SKIP
    >>> ode_system.num_states  # doctest: +SKIP
    8

    Notes
    -----
    - Equations whose left-hand side is one of
      ``model.get_derivatives()`` are emitted as ``"d<state> = rhs"``;
      all remaining equations are emitted as ``"lhs = rhs"`` after them.
    - Names are passed through ``_sanitize_symbol_name``, so entries of
      ``parameters`` and ``observables`` must use the sanitised
      spelling (for example ``sodium_current_i_Na``).
    """
    if cellmlmanip is None:  # pragma: no cover
        raise ImportError("cellmlmanip is required for CellML parsing")

    # Accept pathlib.Path as well as str; anything else is a caller bug.
    if not isinstance(path, (str, Path)):
        raise TypeError(
            "path must be a string or pathlib.Path, "
            f"got {type(path).__name__}"
        )

    path_obj = Path(path)
    if not path_obj.exists():
        raise FileNotFoundError(f"CellML file not found: {path}")

    if not str(path).endswith('.cellml'):
        raise ValueError(
            f"File must have .cellml extension, got: {path}"
        )

    if name is None:
        name = path_obj.stem

    model = cellmlmanip.load_model(str(path))
    # A set gives O(1) membership tests in the classification loop below.
    derivatives = set(model.get_derivatives())

    # Map each cellmlmanip Dummy to a regular Symbol, and collect the
    # initial values that the state variables carry.
    initial_values = {}
    dummy_to_symbol = {}
    for raw_state in model.get_state_variables():
        clean_name = _sanitize_symbol_name(raw_state.name)
        dummy_to_symbol[raw_state] = sp.Symbol(clean_name)

        initial_value = getattr(raw_state, 'initial_value', None)
        if initial_value is not None:
            initial_values[clean_name] = float(initial_value)

    # Every other Dummy appearing in an equation needs a Symbol too.
    # NOTE(review): this matches *any* Dummy subclass; if cellmlmanip
    # represents numeric literals that way they are turned into symbols
    # here as well -- confirm against the cellmlmanip version in use.
    for eq in model.equations:
        for atom in eq.atoms(sp.Dummy):
            if atom not in dummy_to_symbol:
                dummy_to_symbol[atom] = sp.Symbol(
                    _sanitize_symbol_name(atom.name)
                )

    # Render the equations as strings: differential equations first,
    # algebraic ones afterwards, each group in model order.
    dxdt_strings = []
    algebraic_strings = []
    for eq in model.equations:
        substituted = eq.subs(dummy_to_symbol)
        if eq.lhs in derivatives:
            # lhs is Derivative(state, time); args[0] is the state.
            state_var = substituted.lhs.args[0]
            dxdt_strings.append(
                f"d{state_var.name} = {substituted.rhs}"
            )
        else:
            algebraic_strings.append(
                f"{substituted.lhs} = {substituted.rhs}"
            )
    all_equations = dxdt_strings + algebraic_strings

    # Imported here rather than at module level: this module is imported
    # by parsing/__init__.py, which is itself imported while SymbolicODE
    # is being set up, so a top-level import would be circular.
    from cubie.odesystems.symbolic.symbolicODE import SymbolicODE

    return SymbolicODE.create(
        dxdt=all_equations,
        states=initial_values if initial_values else None,
        parameters=parameters,
        observables=observables,
        name=name,
        precision=precision,
        strict=False,
    )
+ + + + + + + + alpha_j + + + + + 0.055 + + + + + + + 0.25 + + + + V + 78 + + + + + + + + + + + + + 0.2 + + + + V + 78 + + + + 1 + + + + + + beta_j + + + 0.3 + + + + + + + + + 0.1 + + + + V + 32 + + + + 1 + + + + + + + + + time + + j + + + + + + alpha_j + + + 1 + j + + + + + beta_j + j + + + + + + + + + + + + + + + + + + + + + + + + E_s + + + + + 82.3 + + + + 13.0287 + + + + + Cai + 0.001 + + + + + + + + i_s + + + g_s + d + f + + + V + E_s + + + + + + + + + time + + Cai + + + + + + + + + + 0.01 + + i_s + + 1 + + + + 0.07 + + + 0.0001 + Cai + + + + + + + + + + + + + + + + alpha_d + + + + + 0.095 + + + + + + + + + V + 5 + + + 100 + + + + + + 1 + + + + + + + + + V + 5 + + + 13.89 + + + + + + + + beta_d + + + + + 0.07 + + + + + + + + + V + 44 + + + 59 + + + + + + 1 + + + + + + + V + 44 + + 20 + + + + + + + + + + + time + + d + + + + + + alpha_d + + + 1 + d + + + + + beta_d + d + + + + + + + + + + + + + + + alpha_f + + + + + 0.012 + + + + + + + + + V + 28 + + + 125 + + + + + + 1 + + + + + + + V + 28 + + 6.67 + + + + + + + + beta_f + + + + + 0.0065 + + + + + + + + + V + 30 + + + 50 + + + + + + 1 + + + + + + + + + V + 30 + + + 5 + + + + + + + + + + + time + + f + + + + + + alpha_f + + + 1 + f + + + + + beta_f + f + + + + + + + + + + + + + + i_x1 + + + + + x1 + 8-3 + + + + + + + 0.04 + + + V + 77 + + + + 1 + + + + + + + 0.04 + + + V + 35 + + + + + + + + + + + + + + + + + alpha_x1 + + + + + 5-4 + + + + + + + V + 50 + + 12.1 + + + + + + 1 + + + + + + + V + 50 + + 17.5 + + + + + + + + beta_x1 + + + + + 0.0013 + + + + + + + + + V + 20 + + + 16.67 + + + + + + 1 + + + + + + + + + V + 20 + + + 25 + + + + + + + + + + + time + + x1 + + + + + + alpha_x1 + + + 1 + x1 + + + + + beta_x1 + x1 + + + + + + + + + + + + + i_K1 + + + 0.0035 + + + + + + + 4 + + + + + + + 0.04 + + + V + 85 + + + + 1 + + + + + + + + + 0.08 + + + V + 53 + + + + + + + + 0.04 + + + V + 53 + + + + + + + + + + 0.2 + + + V + 23 + + + + + 1 + + + + + + + 0.04 + + + + V + 23 + + + + + + + + + + + + + + + + + + + + + + 
# cellmlmanip is deliberately not guarded with ``pytest.importorskip``:
# a missing dependency should make these tests fail loudly, not skip.


@pytest.fixture
def cellml_fixtures_dir():
    """Return the directory holding the CellML test fixtures."""
    return Path(__file__).parent.parent.parent / "fixtures" / "cellml"


@pytest.fixture
def basic_model_path(cellml_fixtures_dir):
    """Return the path to the single-state ``basic_ode`` model."""
    return cellml_fixtures_dir / "basic_ode.cellml"


@pytest.fixture
def beeler_reuter_model_path(cellml_fixtures_dir):
    """Return the path to the Beeler-Reuter (1977) model."""
    return cellml_fixtures_dir / "beeler_reuter_model_1977.cellml"


def test_load_simple_cellml_model(basic_model_path):
    """The basic model loads with one state and a device function."""
    ode_system = load_cellml_model(str(basic_model_path))

    assert ode_system.num_states == 1
    assert is_devfunc(ode_system.dxdt_function)


def test_load_complex_cellml_model(beeler_reuter_model_path):
    """The Beeler-Reuter model loads with its eight state variables."""
    ode_system = load_cellml_model(str(beeler_reuter_model_path))

    assert ode_system.num_states == 8
    assert is_devfunc(ode_system.dxdt_function)


def test_ode_system_has_correct_attributes(basic_model_path):
    """The returned object exposes the attributes the tests rely on."""
    ode_system = load_cellml_model(str(basic_model_path))

    assert hasattr(ode_system, 'num_states')
    assert hasattr(ode_system, 'equations')
    assert hasattr(ode_system, 'indices')


def test_algebraic_equations_as_observables(beeler_reuter_model_path):
    """Requested algebraic symbols end up as the only observables."""
    # Names use the sanitised spelling produced by the loader.
    observable_names = [
        "sodium_current_i_Na",
        "sodium_current_m_gate_alpha_m",
    ]
    ode_system = load_cellml_model(
        str(beeler_reuter_model_path),
        observables=observable_names,
    )

    obs_map = ode_system.indices.observables.index_map
    assert len(obs_map) == 2

    # Keys are symbols, so compare by their string form.
    obs_symbol_names = [str(key) for key in obs_map]
    for obs_name in observable_names:
        assert obs_name in obs_symbol_names


def test_invalid_path_type():
    """A non-path argument raises TypeError."""
    with pytest.raises(TypeError, match="path must be a string"):
        load_cellml_model(123)


def test_nonexistent_file():
    """A path that does not exist raises FileNotFoundError."""
    with pytest.raises(FileNotFoundError, match="CellML file not found"):
        load_cellml_model("/nonexistent/path/model.cellml")


def test_invalid_extension(tmp_path):
    """An existing file without the .cellml extension raises ValueError."""
    # tmp_path is cleaned up by pytest, so no manual unlink is needed.
    wrong_extension = tmp_path / "model.xml"
    wrong_extension.write_text("")

    with pytest.raises(ValueError, match=r"must have \.cellml extension"):
        load_cellml_model(str(wrong_extension))


def test_custom_precision(basic_model_path):
    """The requested precision is reported by the loaded system."""
    ode_system = load_cellml_model(
        str(basic_model_path),
        precision=np.float64,
    )

    assert ode_system.precision == np.float64


def test_custom_name(basic_model_path):
    """The requested name is reported by the loaded system."""
    ode_system = load_cellml_model(
        str(basic_model_path),
        name="custom_model",
    )

    assert ode_system.name == "custom_model"


def test_integration_with_solve_ivp(basic_model_path):
    """The loaded model builds; solve_ivp itself is not invoked here."""
    # float64 avoids a dtype mismatch in the cuda simulator.
    ode_system = load_cellml_model(
        str(basic_model_path),
        precision=np.float64,
    )

    # Building is the step that shows the generated equations are
    # structurally usable.
    ode_system.build()

    assert is_devfunc(ode_system.dxdt_function)
    assert ode_system.num_states == 1

    assert ode_system.indices.states.defaults is not None
    assert len(ode_system.indices.states.defaults) == 1


def test_initial_values_from_cellml(beeler_reuter_model_path):
    """Initial values stored in the CellML file reach the state defaults."""
    ode_system = load_cellml_model(str(beeler_reuter_model_path))

    defaults = ode_system.indices.states.defaults
    assert defaults is not None
    assert len(defaults) == 8

    # At least one initial value in the model is non-zero.
    assert any(value != 0 for value in defaults.values())