diff --git a/.gitignore b/.gitignore index 39f40812..f36798b7 100644 --- a/.gitignore +++ b/.gitignore @@ -237,3 +237,4 @@ fabric.properties /mlruns .vscode/launch.json +myenv/ diff --git a/tests/test_forward/test_1d.py b/tests/test_forward/test_1d.py index 72faef65..6de3c85a 100644 --- a/tests/test_forward/test_1d.py +++ b/tests/test_forward/test_1d.py @@ -11,7 +11,6 @@ from tsadar.utils import misc from tsadar.core.thomson_diagnostic import ThomsonScatteringDiagnostic from tsadar.core.modules import ThomsonParams -from tsadar.utils.data_handling.calibration import get_scattering_angles def test_1d_forward_pass(): @@ -39,17 +38,6 @@ def test_1d_forward_pass(): defaults.update(flatten(inputs)) config = unflatten(defaults) - # get scattering angles and weights - config["other"]["lamrangE"] = [ - config["data"]["fit_rng"]["forward_epw_start"], - config["data"]["fit_rng"]["forward_epw_end"], - ] - config["other"]["lamrangI"] = [ - config["data"]["fit_rng"]["forward_iaw_start"], - config["data"]["fit_rng"]["forward_iaw_end"], - ] - config["other"]["npts"] = int(config["other"]["CCDsize"][1] * config["other"]["points_per_pixel"]) - sas = get_scattering_angles(config) dummy_batch = { "i_data": np.array([1]), @@ -60,7 +48,8 @@ def test_1d_forward_pass(): "i_amps": np.array([1]), } - ts_diag = ThomsonScatteringDiagnostic(config, scattering_angles=sas) + ts_diag = ThomsonScatteringDiagnostic(config, angular=False, cumulative=False) + config = ts_diag.get_cfg() ts_params = ThomsonParams(config["parameters"], num_params=1, batch=True, activate=True) ThryE, ThryI, lamAxisE, lamAxisI = ts_diag(ts_params, dummy_batch) diff --git a/tests/test_forward/test_angular_1d.py b/tests/test_forward/test_angular_1d.py index a9ab9370..3696ee07 100644 --- a/tests/test_forward/test_angular_1d.py +++ b/tests/test_forward/test_angular_1d.py @@ -11,8 +11,6 @@ from tsadar.core.thomson_diagnostic import ThomsonScatteringDiagnostic from tsadar.core.modules import ThomsonParams -from tsadar.utils.data_handling.calibration import get_scattering_angles, get_calibrations - def test_arts1d_forward_pass(): """ @@ -39,25 +37,6 @@ def test_arts1d_forward_pass(): defaults.update(flatten(inputs)) config = unflatten(defaults) - # get scattering angles and weights - config["other"]["lamrangE"] = [ - config["data"]["fit_rng"]["forward_epw_start"], - config["data"]["fit_rng"]["forward_epw_end"], - ] - config["other"]["lamrangI"] = [ - config["data"]["fit_rng"]["forward_iaw_start"], - config["data"]["fit_rng"]["forward_iaw_end"], - ] - config["other"]["npts"] = int(config["other"]["CCDsize"][1] * config["other"]["points_per_pixel"]) - sas = get_scattering_angles(config) - - [axisxE, _, _, _, _, _] = get_calibrations( - 104000, config["other"]["extraoptions"]["spectype"], 0.0, config["other"]["CCDsize"] - ) # shot number hardcoded to get calibration - config["other"]["extraoptions"]["spectype"] = "angular_full" - - sas["angAxis"] = axisxE - dummy_batch = { "i_data": np.ones((config["other"]["CCDsize"][0], config["other"]["CCDsize"][1])), "e_data": np.ones((config["other"]["CCDsize"][0], config["other"]["CCDsize"][1])), @@ -67,7 +46,8 @@ def test_arts1d_forward_pass(): "i_amps": np.array([1]), } - ts_diag = ThomsonScatteringDiagnostic(config, scattering_angles=sas) + ts_diag = ThomsonScatteringDiagnostic(config, angular=True, cumulative=False) + config = ts_diag.get_cfg() ts_params = ThomsonParams(config["parameters"], num_params=1, batch=False, activate=True) ThryE, ThryI, lamAxisE, lamAxisI = ts_diag(ts_params, dummy_batch) # np.save("tests/test_forward/ThryE-arts1d.npy", ThryE) diff --git a/tests/test_forward/test_angular_2d.py b/tests/test_forward/test_angular_2d.py index 47bda087..6fa65072 100644 --- a/tests/test_forward/test_angular_2d.py +++ b/tests/test_forward/test_angular_2d.py @@ -77,7 +77,7 @@ def test_arts2d_forward_pass(): "i_amps": np.array([1]), } - ts_diag = ThomsonScatteringDiagnostic(config, scattering_angles=sas) + ts_diag = ThomsonScatteringDiagnostic(config, angular=True, cumulative=False) ts_params = ThomsonParams(config["parameters"], num_params=1, batch=False) ThryE, ThryI, lamAxisE, lamAxisI = ts_diag(ts_params, dummy_batch) # np.save("tests/test_forward/ThryE-arts2d.npy", ThryE) diff --git a/tests/test_inverse/test_1d_random.py b/tests/test_inverse/test_1d_random.py index 2675f9a1..fe2d647d 100644 --- a/tests/test_inverse/test_1d_random.py +++ b/tests/test_inverse/test_1d_random.py @@ -14,7 +14,6 @@ from tsadar.utils import misc from tsadar.core.thomson_diagnostic import ThomsonScatteringDiagnostic from tsadar.core.modules import ThomsonParams, get_filter_spec -from tsadar.utils.data_handling.calibration import get_scattering_angles def _perturb_params_(rng, params): @@ -79,18 +78,6 @@ def test_1d_inverse(): defaults.update(flatten(inputs)) config = unflatten(defaults) - # get scattering angles and weights - config["other"]["lamrangE"] = [ - config["data"]["fit_rng"]["forward_epw_start"], - config["data"]["fit_rng"]["forward_epw_end"], - ] - config["other"]["lamrangI"] = [ - config["data"]["fit_rng"]["forward_iaw_start"], - config["data"]["fit_rng"]["forward_iaw_end"], - ] - config["other"]["npts"] = int(config["other"]["CCDsize"][1] * config["other"]["points_per_pixel"]) - sas = get_scattering_angles(config) - dummy_batch = { "i_data": np.array([1]), "e_data": np.array([1]), @@ -100,7 +87,8 @@ def test_1d_inverse(): "i_amps": np.array([1]), } rng = np.random.default_rng() - ts_diag = ThomsonScatteringDiagnostic(config, scattering_angles=sas) + ts_diag = ThomsonScatteringDiagnostic(config, angular=False, cumulative=False) + config = ts_diag.get_cfg() config["parameters"] = _perturb_params_(rng, config["parameters"]) misc.log_mlflow(config) ts_params_gt = ThomsonParams(config["parameters"], num_params=1, batch=True, activate=True) @@ -109,7 +97,7 @@ def test_1d_inverse(): loss = 1 while np.nan_to_num(loss, nan=1) > 1e-3: - ts_diag = ThomsonScatteringDiagnostic(config, scattering_angles=sas) + ts_diag = ThomsonScatteringDiagnostic(config, angular=False, cumulative=False) config["parameters"] = _perturb_params_(rng, config["parameters"]) ts_params_fit = ThomsonParams(config["parameters"], num_params=1, batch=True, activate=True) diff_params, static_params = eqx.partition( @@ -162,6 +150,10 @@ def scipy_vg_fn(diff_params_flat): ax.set_ylabel("Intensity (arb. units)") ax.set_title("Electron Spectrum") fig.savefig(os.path.join(td, "ThryE.png"), bbox_inches="tight") + + # Save ThryE and ground_truth as text files + np.savetxt(os.path.join(td, "ThryE.txt"), ThryE) + np.savetxt(os.path.join(td, "ground_truth_ThryE.txt"), ground_truth["ThryE"]) mlflow.log_artifacts(td) # np.testing.assert_allclose(ThryE, ground_truth["ThryE"], atol=0, rtol=0.2) diff --git a/tsadar/core/thomson_diagnostic.py b/tsadar/core/thomson_diagnostic.py index 83b9fcb0..7f32aa29 100644 --- a/tsadar/core/thomson_diagnostic.py +++ b/tsadar/core/thomson_diagnostic.py @@ -1,11 +1,18 @@ +import os + from jax import numpy as jnp, vmap +from tsadar.utils.data_handling.calibration import get_scattering_angles, get_calibrations +from tsadar.utils.data_handling.load_ts_data import loadData from .modules import ThomsonParams from .physics import irf from .physics.generate_spectra import FitModel + + + class ThomsonScatteringDiagnostic: """ The SpectrumCalculator class wraps the FitModel class adding instrumental effects to the calculated spectrum so it @@ -20,11 +27,13 @@ class ThomsonScatteringDiagnostic: weights of each of the scattering angles in the final spectrum """ - def __init__(self, cfg, scattering_angles): + def __init__(self, cfg, angular=False, cumulative=True): + super().__init__() - self.cfg = cfg - self.scattering_angles = scattering_angles - self.model = FitModel(cfg, scattering_angles) + + self.cfg, self.scattering_angles = self.initialize_scattering_angles(cfg, angular, cumulative) + + self.model = FitModel(cfg, self.scattering_angles) if ( "temporal" in cfg["other"]["extraoptions"]["spectype"] @@ -125,3 +134,59 @@ def __call__(self, ts_params: ThomsonParams, batch): ThryI = ThryI + batch["noise_i"] return ThryE, ThryI, lamAxisE, lamAxisI + + def initialize_scattering_angles(self, config, angular, cumulative): + """ + Initializes scattering angles and weights. + + Args: + config: Configuration dictionary + + Returns: + Updated configuration dictionary with scattering angles and weights + """ + if cumulative: + custom_path = None + if "filenames" in config["data"].keys(): + if config["data"]["filenames"]["epw"] is not None: + custom_path = os.path.dirname(config["data"]["filenames"]["epw-local"]) + + if config["data"]["filenames"]["iaw"] is not None: + custom_path = os.path.dirname(config["data"]["filenames"]["iaw-local"]) + + [elecData, ionData, xlab, t0, config["other"]["extraoptions"]["spectype"]] = loadData( + config["data"]["shotnum"], config["data"]["shotDay"], config["other"]["extraoptions"], custom_path=custom_path + ) + scattering_angles = get_scattering_angles(config) + + else: + config["other"]["lamrangE"] = [ + config["data"]["fit_rng"]["forward_epw_start"], + config["data"]["fit_rng"]["forward_epw_end"], + ] + config["other"]["lamrangI"] = [ + config["data"]["fit_rng"]["forward_iaw_start"], + config["data"]["fit_rng"]["forward_iaw_end"], + ] + config["other"]["npts"] = int(config["other"]["CCDsize"][1] * config["other"]["points_per_pixel"]) + scattering_angles = get_scattering_angles(config) + + if angular: + [axisxE, _, _, _, _, _] = get_calibrations( + 104000, config["other"]["extraoptions"]["spectype"], 0.0, config["other"]["CCDsize"] + ) # shot number hardcoded to get calibration + config["other"]["extraoptions"]["spectype"] = "angular_full" + + scattering_angles["angAxis"] = axisxE + + return config, scattering_angles + + def get_cfg(self): + """ + Getter method for the cfg attribute + + Returns: + The configuration dictionary + """ + + return self.cfg \ No newline at end of file diff --git a/tsadar/inverse/loss_function.py b/tsadar/inverse/loss_function.py index 50b962a8..46ece644 100644 --- a/tsadar/inverse/loss_function.py +++ b/tsadar/inverse/loss_function.py @@ -45,7 +45,7 @@ def __init__(self, cfg: Dict, scattering_angles, dummy_batch): ############ - self.ts_diag = ThomsonScatteringDiagnostic(cfg, scattering_angles=scattering_angles) + self.ts_diag = ThomsonScatteringDiagnostic(cfg, angular=False) self._loss_ = filter_jit(self.__loss__) self._vg_func_ = filter_jit(filter_value_and_grad(self.__loss__, has_aux=True)) diff --git a/tsadar/utils/data_handling/calibration.py b/tsadar/utils/data_handling/calibration.py index b46f502a..1ea8ee67 100644 --- a/tsadar/utils/data_handling/calibration.py +++ b/tsadar/utils/data_handling/calibration.py @@ -189,7 +189,7 @@ def sa_lookup(beam): ), ) else: - raise NotImplmentedError("Other probe geometrries are not yet supported") + raise NotImplementedError("Other probe geometrries are not yet supported") return sa