Problem
Issue #30 demonstrated near-real-time capability (50 Hz), but the physics solver bottleneck prevents it from reaching 100+ Hz.
Current performance:
- Physics step: 18 ms/step → max 55 Hz
- Bottleneck: MHD solver RHS computation
- JAX JIT not enabled (ElsasserState is a plain dataclass, not a registered pytree)
Target: 100+ Hz sustained throughput for strict real-time control
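As a quick sanity check on the numbers above, the step time can be converted into a rate and a required speedup. This is only arithmetic on the physics step figure quoted in this issue; the helper names are hypothetical, and observation or control overhead is not included.

```python
def max_rate_hz(step_ms: float) -> float:
    """Upper bound on the step rate if every step takes step_ms milliseconds."""
    return 1000.0 / step_ms


def required_speedup(step_ms: float, target_hz: float) -> float:
    """Factor by which the step time must shrink to reach target_hz."""
    budget_ms = 1000.0 / target_hz  # time available per step at the target rate
    return step_ms / budget_ms


current_hz = max_rate_hz(18.0)           # physics step only
speedup = required_speedup(18.0, 100.0)  # 100 Hz leaves a 10 ms budget per step
print(f"{current_hz:.1f} Hz now, {speedup:.1f}x needed for 100 Hz")
```

In other words, the physics step alone has to drop from 18 ms to 10 ms or less, before any other per-step cost is counted.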
Proposed Solutions
1. JAX Pytree Registration (highest impact)
Convert ElsasserState to JAX-friendly structure:
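from dataclasses import dataclass
import jax
import jax.numpy as jnp

# Current (non-JIT): plain dataclass, not registered as a pytree
@dataclass
class ElsasserState:
    z_plus: jnp.ndarray
    z_minus: jnp.ndarray
    P: jnp.ndarray

# Proposed (JIT-compatible):
@jax.tree_util.register_pytree_node_class
class ElsasserState:
    # ... with tree_flatten/tree_unflatten methods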
Expected speedup: 2-5× → 100-150 Hz ✅
Effort: 4-6 hours (refactor dataclass + tests)
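A minimal sketch of what the pytree registration could look like, assuming the three fields listed above are all array leaves and there is no static metadata. `toy_rhs` is a hypothetical placeholder used only to show that a jitted function can take and return the state; it is not the MHD right-hand side.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclass
class ElsasserState:
    z_plus: jnp.ndarray
    z_minus: jnp.ndarray
    P: jnp.ndarray

    def tree_flatten(self):
        # Children are the array leaves JAX traces; no static aux data is assumed.
        return (self.z_plus, self.z_minus, self.P), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the state from the leaves in the same order as tree_flatten.
        return cls(*children)


@jax.jit
def toy_rhs(state: ElsasserState) -> ElsasserState:
    # Placeholder dynamics for illustration only (simple linear damping).
    return ElsasserState(-state.z_plus, -state.z_minus, jnp.zeros_like(state.P))


state = ElsasserState(jnp.ones((4, 4)), 2.0 * jnp.ones((4, 4)), jnp.zeros((4, 4)))
out = toy_rhs(state)  # compiles on first call, then reuses the compiled version
```

If the real state later gains non-array fields (grid sizes, flags), those would belong in the aux data returned by `tree_flatten` rather than among the children.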
2. GPU Acceleration
Enable JAX GPU backend:
# Select the GPU backend explicitly (requires a GPU-enabled JAX build)
import jax
jax.config.update('jax_platform_name', 'gpu')
Expected speedup: 5-10× (on top of JIT) → 250-500 Hz ✅
Requirements:
- GPU-enabled machine
- JAX GPU build
- Batch processing for efficiency
Effort: 2-3 hours (if JAX GPU available)
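The batching requirement above could be prototyped roughly as follows. This sketch assumes nothing about the real solver: `toy_step` is a hypothetical stand-in, and the batch and grid sizes are arbitrary. It runs on whatever backend the installed JAX build selects, so it also works on a CPU-only machine.

```python
import jax
import jax.numpy as jnp

backend = jax.default_backend()  # name of the backend JAX selected for this process
device = jax.devices()[0]        # first device on that backend


def toy_step(z):
    # Placeholder for one solver step on a single state (illustration only).
    return z - 0.01 * z


# vmap maps the step over a leading batch axis; jit compiles the batched version once.
batched_step = jax.jit(jax.vmap(toy_step))

batch = jax.device_put(jnp.ones((8, 16, 16)), device)  # 8 states of shape 16x16
out = batched_step(batch)
out.block_until_ready()  # wait for the asynchronous computation to finish
print(backend, out.shape)
```

Batching helps mainly when a single step is too small to keep the GPU busy, so the useful batch size would need to be measured on the target hardware.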
3. Adaptive Resolution
High-res only where needed:
- Fine grid near resonant surfaces
- Coarse grid elsewhere
- Dynamic refinement
Expected speedup: 2-3× with accuracy preservation
Effort: 1-2 weeks (complex implementation)
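One possible starting point for the fine/coarse split is a static radial grid clustered around a single resonant surface. The sketch below is an assumption-heavy illustration: the function name, the Gaussian density profile, and the parameters `width` and `boost` are all hypothetical, and it does not cover dynamic refinement or the non-uniform stencils the solver would then need.

```python
import jax.numpy as jnp


def clustered_grid(n, r_s, width=0.05, boost=9.0, n_fine=4001):
    """Return n points on [0, 1] that are denser near r_s (static refinement)."""
    r_fine = jnp.linspace(0.0, 1.0, n_fine)
    # Desired point density: 1 far from r_s, 1 + boost at r_s.
    density = 1.0 + boost * jnp.exp(-((r_fine - r_s) / width) ** 2)
    cdf = jnp.cumsum(density)
    cdf = (cdf - cdf[0]) / (cdf[-1] - cdf[0])  # normalise to [0, 1]
    # Invert the CDF: equal steps in cdf become small steps in r where density is high.
    return jnp.interp(jnp.linspace(0.0, 1.0, n), cdf, r_fine)


r = clustered_grid(64, r_s=0.5)
dr = jnp.diff(r)  # local grid spacing, smallest near r_s
```

The ratio between the largest and smallest spacing is controlled by `boost`, which makes it easy to check how much refinement is needed before accuracy degrades.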
4. Fast Observation Approximation
Avoid Poisson solve for every observation:
- Approximate m-mode extraction from z± directly
- Skip full ψ/φ reconstruction
- Only compute when needed
Expected speedup: Remove observation bottleneck (currently 521 ms)
Effort: 1 week
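A rough sketch of extracting a single poloidal mode amplitude from a field by FFT, with no Poisson solve. It assumes the field is stored as `field[r, theta]` on a uniform θ grid, which is not stated in this issue, and `mode_amplitude` is a hypothetical helper. Which combination of z± is the right input depends on the project's own definitions and is left open here.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="m")
def mode_amplitude(field, m):
    """Amplitude of poloidal mode m at each radius, for 0 < m < n_theta / 2."""
    n_theta = field.shape[-1]
    coeffs = jnp.fft.rfft(field, axis=-1)          # FFT along the theta axis
    return 2.0 * jnp.abs(coeffs[..., m]) / n_theta  # one-sided amplitude


# Synthetic check: a pure m=2 perturbation of amplitude 0.3 at 8 radii.
theta = jnp.linspace(0.0, 2.0 * jnp.pi, 64, endpoint=False)
z_plus = 0.3 * jnp.cos(2.0 * theta) * jnp.ones((8, 1))

amp_m2 = mode_amplitude(z_plus, m=2)
amp_m3 = mode_amplitude(z_plus, m=3)
```

Whether this approximation is accurate enough would have to be checked against the full ψ/φ reconstruction on representative states.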
Recommended Approach (v3.1)
Phase 1: JAX Pytree (quick win) ⚡
- Refactor ElsasserState → pytree
- Enable @jax.jit on RHS computation
- Target: 100-150 Hz
- Timeline: 1-2 days
Phase 2: GPU (if available)
- Enable GPU backend
- Optimize batch size
- Target: 250-500 Hz
- Timeline: 1 day
Phase 3: Fast diagnostics (polish)
- Approximate observation mode
- Reduce overhead further
- Timeline: 1 week
Success Criteria
- ✅ Sustained throughput >100 Hz
- ✅ Physics accuracy maintained (within 5%)
- ✅ GPU support (if hardware available)
- ✅ Backward compatibility with v3.0 API
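The throughput and accuracy criteria could be checked with a small harness along these lines. This is a sketch under assumptions: `sustained_hz` and `relative_error` are hypothetical helpers, `toy_step` stands in for the real physics step, and the L2-norm relative error is just one way to interpret "within 5%".

```python
import time

import jax
import jax.numpy as jnp


def sustained_hz(step_fn, state, n_steps=200, n_warmup=5):
    """Steps per second over n_steps, excluding JIT compilation."""
    for _ in range(n_warmup):        # warm-up triggers compilation
        state = step_fn(state)
    jax.block_until_ready(state)
    start = time.perf_counter()
    for _ in range(n_steps):
        state = step_fn(state)
    jax.block_until_ready(state)     # JAX dispatch is asynchronous
    return n_steps / (time.perf_counter() - start)


def relative_error(candidate, reference):
    """L2-norm relative difference between two arrays."""
    return float(jnp.linalg.norm(candidate - reference) / jnp.linalg.norm(reference))


@jax.jit
def toy_step(z):
    return 0.999 * z  # placeholder for the real physics step


hz = sustained_hz(toy_step, jnp.ones((64, 64)))
err = relative_error(1.04 * jnp.ones(10), jnp.ones(10))
print(f"{hz:.0f} Hz, relative error {err:.3f}")
```

Without the `block_until_ready` calls the loop would mostly time dispatch, not computation, and report an inflated rate.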
Priority
P2-medium - Not critical for v3.0 release (50 Hz sufficient). Good v3.1 enhancement.
Created by: 小A 🤖
Date: 2026-03-24
Context: Issue #30 completion, v3.1 planning