Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,4 @@ fabric.properties

/mlruns
.vscode/launch.json
myenv/
15 changes: 2 additions & 13 deletions tests/test_forward/test_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from tsadar.utils import misc
from tsadar.core.thomson_diagnostic import ThomsonScatteringDiagnostic
from tsadar.core.modules import ThomsonParams
from tsadar.utils.data_handling.calibration import get_scattering_angles


def test_1d_forward_pass():
Expand Down Expand Up @@ -39,17 +38,6 @@ def test_1d_forward_pass():
defaults.update(flatten(inputs))
config = unflatten(defaults)

# get scattering angles and weights
config["other"]["lamrangE"] = [
config["data"]["fit_rng"]["forward_epw_start"],
config["data"]["fit_rng"]["forward_epw_end"],
]
config["other"]["lamrangI"] = [
config["data"]["fit_rng"]["forward_iaw_start"],
config["data"]["fit_rng"]["forward_iaw_end"],
]
config["other"]["npts"] = int(config["other"]["CCDsize"][1] * config["other"]["points_per_pixel"])
sas = get_scattering_angles(config)

dummy_batch = {
"i_data": np.array([1]),
Expand All @@ -60,7 +48,8 @@ def test_1d_forward_pass():
"i_amps": np.array([1]),
}

ts_diag = ThomsonScatteringDiagnostic(config, scattering_angles=sas)
ts_diag = ThomsonScatteringDiagnostic(config, angular=False, cumulative=False)
config = ts_diag.get_cfg()
ts_params = ThomsonParams(config["parameters"], num_params=1, batch=True, activate=True)
ThryE, ThryI, lamAxisE, lamAxisI = ts_diag(ts_params, dummy_batch)

Expand Down
24 changes: 2 additions & 22 deletions tests/test_forward/test_angular_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@

from tsadar.core.thomson_diagnostic import ThomsonScatteringDiagnostic
from tsadar.core.modules import ThomsonParams
from tsadar.utils.data_handling.calibration import get_scattering_angles, get_calibrations


def test_arts1d_forward_pass():
"""
Expand All @@ -39,25 +37,6 @@ def test_arts1d_forward_pass():
defaults.update(flatten(inputs))
config = unflatten(defaults)

# get scattering angles and weights
config["other"]["lamrangE"] = [
config["data"]["fit_rng"]["forward_epw_start"],
config["data"]["fit_rng"]["forward_epw_end"],
]
config["other"]["lamrangI"] = [
config["data"]["fit_rng"]["forward_iaw_start"],
config["data"]["fit_rng"]["forward_iaw_end"],
]
config["other"]["npts"] = int(config["other"]["CCDsize"][1] * config["other"]["points_per_pixel"])
sas = get_scattering_angles(config)

[axisxE, _, _, _, _, _] = get_calibrations(
104000, config["other"]["extraoptions"]["spectype"], 0.0, config["other"]["CCDsize"]
) # shot number hardcoded to get calibration
config["other"]["extraoptions"]["spectype"] = "angular_full"

sas["angAxis"] = axisxE

dummy_batch = {
"i_data": np.ones((config["other"]["CCDsize"][0], config["other"]["CCDsize"][1])),
"e_data": np.ones((config["other"]["CCDsize"][0], config["other"]["CCDsize"][1])),
Expand All @@ -67,7 +46,8 @@ def test_arts1d_forward_pass():
"i_amps": np.array([1]),
}

ts_diag = ThomsonScatteringDiagnostic(config, scattering_angles=sas)
ts_diag = ThomsonScatteringDiagnostic(config, angular=True, cumulative=False)
config = ts_diag.get_cfg()
ts_params = ThomsonParams(config["parameters"], num_params=1, batch=False, activate=True)
ThryE, ThryI, lamAxisE, lamAxisI = ts_diag(ts_params, dummy_batch)
# np.save("tests/test_forward/ThryE-arts1d.npy", ThryE)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_forward/test_angular_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_arts2d_forward_pass():
"i_amps": np.array([1]),
}

ts_diag = ThomsonScatteringDiagnostic(config, scattering_angles=sas)
ts_diag = ThomsonScatteringDiagnostic(config, angular=True, cumulative=False)
ts_params = ThomsonParams(config["parameters"], num_params=1, batch=False)
ThryE, ThryI, lamAxisE, lamAxisI = ts_diag(ts_params, dummy_batch)
# np.save("tests/test_forward/ThryE-arts2d.npy", ThryE)
Expand Down
22 changes: 7 additions & 15 deletions tests/test_inverse/test_1d_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from tsadar.utils import misc
from tsadar.core.thomson_diagnostic import ThomsonScatteringDiagnostic
from tsadar.core.modules import ThomsonParams, get_filter_spec
from tsadar.utils.data_handling.calibration import get_scattering_angles


def _perturb_params_(rng, params):
Expand Down Expand Up @@ -79,18 +78,6 @@ def test_1d_inverse():
defaults.update(flatten(inputs))
config = unflatten(defaults)

# get scattering angles and weights
config["other"]["lamrangE"] = [
config["data"]["fit_rng"]["forward_epw_start"],
config["data"]["fit_rng"]["forward_epw_end"],
]
config["other"]["lamrangI"] = [
config["data"]["fit_rng"]["forward_iaw_start"],
config["data"]["fit_rng"]["forward_iaw_end"],
]
config["other"]["npts"] = int(config["other"]["CCDsize"][1] * config["other"]["points_per_pixel"])
sas = get_scattering_angles(config)

dummy_batch = {
"i_data": np.array([1]),
"e_data": np.array([1]),
Expand All @@ -100,7 +87,8 @@ def test_1d_inverse():
"i_amps": np.array([1]),
}
rng = np.random.default_rng()
ts_diag = ThomsonScatteringDiagnostic(config, scattering_angles=sas)
ts_diag = ThomsonScatteringDiagnostic(config, angular=False, cumulative=False)
config = ts_diag.get_cfg()
config["parameters"] = _perturb_params_(rng, config["parameters"])
misc.log_mlflow(config)
ts_params_gt = ThomsonParams(config["parameters"], num_params=1, batch=True, activate=True)
Expand All @@ -109,7 +97,7 @@ def test_1d_inverse():

loss = 1
while np.nan_to_num(loss, nan=1) > 1e-3:
ts_diag = ThomsonScatteringDiagnostic(config, scattering_angles=sas)
ts_diag = ThomsonScatteringDiagnostic(config, angular=False, cumulative=False)
config["parameters"] = _perturb_params_(rng, config["parameters"])
ts_params_fit = ThomsonParams(config["parameters"], num_params=1, batch=True, activate=True)
diff_params, static_params = eqx.partition(
Expand Down Expand Up @@ -162,6 +150,10 @@ def scipy_vg_fn(diff_params_flat):
ax.set_ylabel("Intensity (arb. units)")
ax.set_title("Electron Spectrum")
fig.savefig(os.path.join(td, "ThryE.png"), bbox_inches="tight")

# Save ThryE and ground_truth as text files
np.savetxt(os.path.join(td, "ThryE.txt"), ThryE)
np.savetxt(os.path.join(td, "ground_truth_ThryE.txt"), ground_truth["ThryE"])
mlflow.log_artifacts(td)

# np.testing.assert_allclose(ThryE, ground_truth["ThryE"], atol=0, rtol=0.2)
Expand Down
73 changes: 69 additions & 4 deletions tsadar/core/thomson_diagnostic.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import os

from jax import numpy as jnp, vmap

from tsadar.utils.data_handling.calibration import get_scattering_angles, get_calibrations
from tsadar.utils.data_handling.load_ts_data import loadData

from .modules import ThomsonParams
from .physics import irf
from .physics.generate_spectra import FitModel





class ThomsonScatteringDiagnostic:
"""
The SpectrumCalculator class wraps the FitModel class adding instrumental effects to the calculated spectrum so it
Expand All @@ -20,11 +27,13 @@ class ThomsonScatteringDiagnostic:
weights of each of the scattering angles in the final spectrum
"""

def __init__(self, cfg, scattering_angles):
def __init__(self, cfg, angular=False, cumulative=True):

super().__init__()
self.cfg = cfg
self.scattering_angles = scattering_angles
self.model = FitModel(cfg, scattering_angles)

self.cfg, self.scattering_angles = self.initialize_scattering_angles(cfg, angular, cumulative)

self.model = FitModel(cfg, self.scattering_angles)

if (
"temporal" in cfg["other"]["extraoptions"]["spectype"]
Expand Down Expand Up @@ -125,3 +134,59 @@ def __call__(self, ts_params: ThomsonParams, batch):
ThryI = ThryI + batch["noise_i"]

return ThryE, ThryI, lamAxisE, lamAxisI

def initialize_scattering_angles(self, config, angular, cumulative):
"""
Initializes scattering angles and weights.

Args:
config: Configuration dictionary

Returns:
Updated configuration dictionary with scattering angles and weights
"""
if cumulative:
custom_path = None
if "filenames" in config["data"].keys():
if config["data"]["filenames"]["epw"] is not None:
custom_path = os.path.dirname(config["data"]["filenames"]["epw-local"])

if config["data"]["filenames"]["iaw"] is not None:
custom_path = os.path.dirname(config["data"]["filenames"]["iaw-local"])

[elecData, ionData, xlab, t0, config["other"]["extraoptions"]["spectype"]] = loadData(
config["data"]["shotnum"], config["data"]["shotDay"], config["other"]["extraoptions"], custom_path=custom_path
)
scattering_angles = get_scattering_angles(config)

else:
config["other"]["lamrangE"] = [
config["data"]["fit_rng"]["forward_epw_start"],
config["data"]["fit_rng"]["forward_epw_end"],
]
config["other"]["lamrangI"] = [
config["data"]["fit_rng"]["forward_iaw_start"],
config["data"]["fit_rng"]["forward_iaw_end"],
]
config["other"]["npts"] = int(config["other"]["CCDsize"][1] * config["other"]["points_per_pixel"])
scattering_angles = get_scattering_angles(config)

if angular:
[axisxE, _, _, _, _, _] = get_calibrations(
104000, config["other"]["extraoptions"]["spectype"], 0.0, config["other"]["CCDsize"]
) # shot number hardcoded to get calibration
config["other"]["extraoptions"]["spectype"] = "angular_full"

scattering_angles["angAxis"] = axisxE

return config, scattering_angles

def get_cfg(self):
"""
Getter method for the cfg attribute

Returns:
The configuration dictionary
"""

return self.cfg
2 changes: 1 addition & 1 deletion tsadar/inverse/loss_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, cfg: Dict, scattering_angles, dummy_batch):

############

self.ts_diag = ThomsonScatteringDiagnostic(cfg, scattering_angles=scattering_angles)
self.ts_diag = ThomsonScatteringDiagnostic(cfg, angular=False)

self._loss_ = filter_jit(self.__loss__)
self._vg_func_ = filter_jit(filter_value_and_grad(self.__loss__, has_aux=True))
Expand Down
2 changes: 1 addition & 1 deletion tsadar/utils/data_handling/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def sa_lookup(beam):
),
)
else:
raise NotImplmentedError("Other probe geometrries are not yet supported")
raise NotImplementedError("Other probe geometrries are not yet supported")

return sa

Expand Down
Loading