From eea100166e45b4f03257a235e7d00bb88c5dd7b0 Mon Sep 17 00:00:00 2001 From: Mike Grudic Date: Thu, 12 Feb 2026 15:54:55 -0500 Subject: [PATCH] walkthrough for converting between symbolic and numerical --- docs/index.rst | 1 + docs/source/jaco_to_numerical.rst | 136 ++++++++++++++++++++++++++++++ 2 files changed, 137 insertions(+) create mode 100644 docs/source/jaco_to_numerical.rst diff --git a/docs/index.rst b/docs/index.rst index fec4662..975e226 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -10,6 +10,7 @@ Welcome to jaco's documentation! source/installation source/Quickstart + source/jaco_to_numerical source/API source/community diff --git a/docs/source/jaco_to_numerical.rst b/docs/source/jaco_to_numerical.rst new file mode 100644 index 0000000..53f4040 --- /dev/null +++ b/docs/source/jaco_to_numerical.rst @@ -0,0 +1,136 @@ +Turning a jaco symbolic expression into a numerical function +============================================================ + +``jaco`` abstracts the rates of microphysical processes as symbolic +expressions, but to compute something we must interface this +representation with numerical functions. In this example, we will take +the rate of :math:`H_2` formation as implemented in the +``jaco.models.starforge`` model and create a JAX function to compute it +from array data containing the quantities in the expression. + +First we import stuff: + +.. code:: ipython3 + + from jaco.models.starforge.h2_chemistry.grain_formation import grain_formation + import sympy as sp + import numpy as np + +The rate per cm^-3 of the process is given by the ``rate`` attribute: + +.. code:: ipython3 + + grain_formation.rate + + + + +.. math:: + + \displaystyle \frac{3.0 \cdot 10^{-18} \sqrt{T} Z_{d} f_{d} n_{H}}{\left(1.0 + 10000.0 e^{- \frac{600.0}{Td}}\right) \left(8.0 \cdot 10^{-6} T^{2} + 0.002 T + 0.04 \sqrt{T + Td} + 1.0\right)} + + + +First we must inspect the symbols present in the symbolic object using +``free_symbols``: + +.. code:: ipython3 + + grain_formation.rate.free_symbols + + + + +.. parsed-literal:: + + {T, Td, Z_d, f_d, n_H} + + + +These are the quantities we must pass to a function to obtain numerical +values. + +We can then construct the numerical function using ``sympy.lambdify``, +specifying JAX as the backend. See the sympy docs for the complete list +of available backends, including e.g. numpy. + +.. code:: ipython3 + + rate = grain_formation.rate + symbols = list(sp.ordered(rate.free_symbols)) + rate_lambdified = sp.lambdify(symbols, rate, "jax") + ?rate_lambdified + + +.. parsed-literal:: + + Signature: rate_lambdified(T, Td, Z_d, f_d, n_H) + Docstring: + Created with lambdify. Signature: + + func(T, Td, Z_d, f_d, n_H) + + Expression: + + 3.0e-18*sqrt(T)*Z_d*f_d*n_H/((1.0 + 10000.0*exp(-600.0/Td))*(8.0e-6*T**2 +... + + Source code: + + def _lambdifygenerated(T, Td, Z_d, f_d, n_H): + return 3.0e-18*sqrt(T)*Z_d*f_d*n_H/((1.0 + 10000.0*exp(-600.0/Td))*(8.0e-6*T**2 + 0.002*T + 0.04*sqrt(T + Td) + 1.0)) + + + Imported modules: + + from jax.numpy import sqrt + from jax.numpy import exp + File: ~/code/starforge_tools/microphysics/ + Type: function + +We see that sympy has created a function whose arguments are the symbols +in the expression. This can be used directly, but we will wrap it in a +function with a ``**kwargs`` interface so that we don’t have to worry +about manually entering the arguments in the correct order: + +.. code:: ipython3 + + def H2_formation_rate(**kwargs): + """Returns the H2 formation rate in cm^-3 s^-1 + """ + return rate_lambdified(kwargs) + + + + + +.. parsed-literal:: + + Array([2.8339944e-18, 2.8339944e-18, 2.8339944e-18, 2.8339944e-18, + 2.8339944e-18, 2.8339944e-18, 2.8339944e-18, 2.8339944e-18, + 2.8339944e-18, 2.8339944e-18], dtype=float32) + + + +Finally, let’s create some example arrays and plug in the values: + +.. code:: ipython3 + + num = 10 + T=np.repeat(1.,num) + Td=np.repeat(1.,num) + Z_d = np.ones(num) + f_d = np.ones(num) + n_H = np.ones(num) + + H2_formation_rate(T=T,Td=Td,Z_d=Z_d,f_d=f_d,n_H=n_H) + + + + +.. parsed-literal:: + + Array([2.8339944e-18, 2.8339944e-18, 2.8339944e-18, 2.8339944e-18, + 2.8339944e-18, 2.8339944e-18, 2.8339944e-18, 2.8339944e-18, + 2.8339944e-18, 2.8339944e-18], dtype=float32) + +