Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions configs/1d/defaults.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,6 @@ parameters:
active: False
lb: -100.0
ub: 100.0
Va:
val: 0.0
angle: 0.0
active: False
lb: -20.5
ub: 20.5

other:
expandedions: False
Expand Down
69 changes: 63 additions & 6 deletions configs/1d/inputs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,6 @@ parameters:
num_grad_points: 1
ub: 15.0
val: 0.0
Va:
active: false
angle: 0.0
lb: -20.5
ub: 20.5
val: 0.0
amp1:
active: true
lb: 0.01
Expand Down Expand Up @@ -96,21 +90,84 @@ parameters:
lb: -20.0
ub: 20.0
val: 0.0

ion-1:
A:
active: false
val: 40.0
Ti:
active: True
lb: 0.01
same: false
ub: 1.0
val: 0.1
Z:
active: True
lb: 1.0
ub: 18.0
val: 8.0
Va:
active: True
angle: 0.0
lb: 0.0
ub: 40.5
val: 15.0
fract:
active: false
val: 1.0
type:
active: false
ion: null

ion-2:
A:
active: false
val: 40.0
Ti:
active: True
lb: 0.01
same: false
ub: 3.0
val: 0.2
Z:
active: True
lb: 1.0
ub: 18.0
val: 8.0
Va:
active: True
angle: 0.0
lb: 0.0
ub: 40.5
val: 20.0
fract:
active: false
val: 1.0
type:
active: false
ion: null

ion-3:
A:
active: false
val: 40.0
Ti:
active: True
lb: 0.01
same: false
ub: 1.0
val: 0.1
Z:
active: True
lb: 1.0
ub: 18.0
val: 8.0
Va:
active: True
angle: 0.0
lb: 0.0
ub: 40.5
val: 10.0
fract:
active: false
val: 1.0
12 changes: 6 additions & 6 deletions tests/configs/1d-defaults.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ parameters:
fract:
val: 0.1
active: False
Va:
val: 0.0
angle: 0.0
active: False
lb: -20.5
ub: 20.5

general:
amp1:
Expand Down Expand Up @@ -69,12 +75,6 @@ parameters:
active: False
lb: -100.0
ub: 100.0
Va:
val: 0.0
angle: 0.0
active: False
lb: -20.5
ub: 20.5

other:
expandedions: False
Expand Down
13 changes: 7 additions & 6 deletions tests/configs/1d-inputs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ parameters:
fract:
val: 1.
active: False
Va:
val: 0.0
angle: 0.0
active: False
lb: -20.5
ub: 20.5

general:
amp1:
val: 1.
Expand Down Expand Up @@ -77,12 +84,6 @@ parameters:
active: False
lb: -10.0
ub: 10.0
Va:
val: 0.0
angle: 0.0
active: False
lb: -20.5
ub: 20.5


other:
Expand Down
12 changes: 6 additions & 6 deletions tests/configs/epw_defaults.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ parameters:
fract:
val: 0.1
active: False
Va:
val: 0.0
angle: 0.0
active: False
lb: -20.5
ub: 20.5

general:
amp1:
Expand Down Expand Up @@ -77,12 +83,6 @@ parameters:
active: False
lb: -100.0
ub: 100.0
Va:
val: 0.0
angle: 0.0
active: False
lb: -20.5
ub: 20.5

other:
expandedions: False
Expand Down
13 changes: 7 additions & 6 deletions tests/configs/epw_inputs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ parameters:
val: 1.0
active: False

Va:
val: 0.0
angle: 0.0
active: True
lb: -20.5
ub: 20.5

general:
amp1:
val: 1.
Expand Down Expand Up @@ -78,12 +85,6 @@ parameters:
active: False
lb: -10.0
ub: 10.0
Va:
val: 0.0
angle: 0.0
active: True
lb: -20.5
ub: 20.5

data:
shotnum: 101675
Expand Down
12 changes: 6 additions & 6 deletions tests/configs/time_test_defaults.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ parameters:
fract:
val: 0.1
active: False
Va:
val: 0.0
angle: 0.0
active: False
lb: -20.5
ub: 20.5

general:
amp1:
Expand Down Expand Up @@ -75,12 +81,6 @@ parameters:
active: False
lb: -100.0
ub: 100.0
Va:
val: 0.0
angle: 0.0
active: False
lb: -20.5
ub: 20.5

other:
expandedions: False
Expand Down
12 changes: 6 additions & 6 deletions tests/configs/time_test_inputs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ parameters:
fract:
val: 1.0
active: False
Va:
val: 0.0
angle: 0.0
active: False
lb: -20.5
ub: 20.5
general:
amp1:
val: 1.
Expand Down Expand Up @@ -76,12 +82,6 @@ parameters:
active: False
lb: -10.0
ub: 10.0
Va:
val: 0.0
angle: 0.0
active: False
lb: -20.5
ub: 20.5


other:
Expand Down
28 changes: 18 additions & 10 deletions tsadar/core/modules/ts_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,19 +133,22 @@ def __call__(self):
class IonParams(eqx.Module):
normed_Ti: Array
normed_Z: Array
normed_Va: Array #SB
fract: Array
Ti_scale: float
Ti_shift: float
Z_scale: float
Z_shift: float
Va_scale: float # SB
Va_shift: float # SB
A: int
act_funs: Dict[str, Callable]
inv_act_funs: Dict[str, Callable]

def __init__(self, cfg, batch_size, batch=True, activate=False):
super().__init__()
self.act_funs, self.inv_act_funs = {}, {}
for param in ["Ti", "Z"]:
for param in ["Ti", "Z", "Va"]: #SB
setattr(self, param + "_scale", cfg[param]["ub"] - cfg[param]["lb"])
setattr(self, param + "_shift", cfg[param]["lb"])
self.act_funs[param], self.inv_act_funs[param] = get_act_and_inv_act(cfg[param], activate)
Expand All @@ -159,11 +162,15 @@ def __init__(self, cfg, batch_size, batch=True, activate=False):
self.normed_Z = self.inv_act_funs["Z"](
jnp.full(batch_size, (cfg["Z"]["val"] - self.Z_shift) / self.Z_scale)
)
self.normed_Va = self.inv_act_funs["Va"](
jnp.full(batch_size, (cfg["Va"]["val"] - self.Va_shift) / self.Va_scale) # SB
)
self.A = jnp.full(batch_size, cfg["A"]["val"])
self.fract = self.inv_act_funs["fract"](jnp.full(batch_size, cfg["fract"]["val"]))
else:
self.normed_Ti = self.inv_act_funs["Ti"]((cfg["Ti"]["val"] - self.Ti_shift) / self.Ti_scale)
self.normed_Z = self.inv_act_funs["Z"]((cfg["Z"]["val"] - self.Z_shift) / self.Z_scale)
self.normed_Va = self.inv_act_funs["Va"]((cfg["Va"]["val"] - self.Va_shift) / self.Va_scale) # SB
self.A = cfg["A"]["val"]
self.fract = float(self.inv_act_funs["fract"](cfg["fract"]["val"]))

Expand All @@ -177,6 +184,7 @@ def __call__(self):
"fract": self.act_funs["fract"](self.fract),
"Ti": self.act_funs["Ti"](self.normed_Ti) * self.Ti_scale + self.Ti_shift,
"Z": self.act_funs["Z"](self.normed_Z) * self.Z_scale + self.Z_shift,
"Va": self.act_funs["Va"](self.normed_Va) * self.Va_scale + self.Va_shift,
}


Expand Down Expand Up @@ -213,7 +221,7 @@ class GeneralParams(eqx.Module):
normed_ne_gradient: Array
normed_Te_gradient: Array
normed_ud: Array
normed_Va: Array
#normed_Va: Array # SB
lam_scale: float
lam_shift: float
amp1_scale: float
Expand All @@ -228,24 +236,24 @@ class GeneralParams(eqx.Module):
Te_gradient_shift: float
ud_scale: float
ud_shift: float
Va_scale: float
Va_shift: float
#Va_scale: float # SB
#Va_shift: float # SB
act_funs: Dict[str, Callable]

def __init__(self, cfg, batch_size: int, batch=True, activate=False):
super().__init__()

# this is all a bit ugly but we use setattr instead of = to be able to use the for loop
self.act_funs, inv_act_funs = {}, {}
for param in ["lam", "amp1", "amp2", "amp3", "ne_gradient", "Te_gradient", "ud", "Va"]:
for param in ["lam", "amp1", "amp2", "amp3", "ne_gradient", "Te_gradient", "ud"]: # SB removed Va
self.act_funs[param], inv_act_funs[param] = get_act_and_inv_act(cfg[param], activate)
setattr(self, param + "_scale", cfg[param]["ub"] - cfg[param]["lb"])
setattr(self, param + "_shift", cfg[param]["lb"])

# this is where the linear and nonlinear transformations are applied i.e.
# the rescaling and the activation function
if batch:
for param in ["lam", "amp1", "amp2", "amp3", "ne_gradient", "Te_gradient", "ud", "Va"]:
for param in ["lam", "amp1", "amp2", "amp3", "ne_gradient", "Te_gradient", "ud"]: # SB removed Va
setattr(
self,
"normed_" + param,
Expand All @@ -257,7 +265,7 @@ def __init__(self, cfg, batch_size: int, batch=True, activate=False):
),
)
else:
for param in ["lam", "amp1", "amp2", "amp3", "ne_gradient", "Te_gradient", "ud", "Va"]:
for param in ["lam", "amp1", "amp2", "amp3", "ne_gradient", "Te_gradient", "ud"]: # SB removed Va
setattr(
self,
"normed_" + param,
Expand All @@ -281,7 +289,7 @@ def __call__(self):
self.act_funs["Te_gradient"](self.normed_Te_gradient) * self.Te_gradient_scale + self.Te_gradient_shift
)
unnormed_ud = self.act_funs["ud"](self.normed_ud) * self.ud_scale + self.ud_shift
unnormed_Va = self.act_funs["Va"](self.normed_Va) * self.Va_scale + self.Va_shift
# unnormed_Va = self.act_funs["Va"](self.normed_Va) * self.Va_scale + self.Va_shift # SB

return {
"lam": unnormed_lam,
Expand All @@ -291,7 +299,7 @@ def __call__(self):
"ne_gradient": unnormed_ne_gradient,
"Te_gradient": unnormed_Te_gradient,
"ud": unnormed_ud,
"Va": unnormed_Va,
#"Va": unnormed_Va, # SB
}


Expand Down Expand Up @@ -375,7 +383,7 @@ def get_filter_spec(cfg_params: Dict, ts_params: ThomsonParams) -> Dict:
if key == "fe":
filter_spec = get_distribution_filter_spec(filter_spec, dist_params=_params)
else:
nkey = f"normed_{key}"
nkey = f"normed_{key}"# if key!="fract" else f"{key}" # SB treat fractions differently
if "ion" in species:
filter_spec = eqx.tree_at(
lambda tree: getattr(getattr(tree, "ions")[ion_num - 1], nkey),
Expand Down
Loading
Loading