Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 34 additions & 30 deletions src/pipelines/arterial_waveform_shape_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ class ArterialSegExample(ProcessPipeline):
v_raw_segment_input = (
"/Artery/VelocityPerBeat/Segments/VelocitySignalPerBeatPerSegment/value"
)
v_band_segment_input = "/Artery/VelocityPerBeat/Segments/VelocitySignalPerBeatPerSegmentBandLimited/value"
v_band_segment_input = (
"/Artery/VelocityPerBeat/Segments/VelocitySignalPerBeatPerSegmentBandLimited/value"
)

v_raw_global_input = "/Artery/VelocityPerBeat/VelocitySignalPerBeat/value"
v_band_global_input = (
Expand Down Expand Up @@ -305,14 +307,12 @@ def _rho_h_90_support_from_harmonics(self, V: np.ndarray) -> dict:
return out

w = power / s
C = np.cumsum(w) # C(1), ..., C(H)
C = np.cumsum(w)
h = np.arange(1, H + 1, dtype=float)

# stockage discret
out["harmonic_energy_cumsum"][:H] = C
out["harmonic_energy_cumsum_h"][:H] = h

# interpolation continue avec convention C(0)=0
C_full = np.concatenate(([0.0], C))
h_full = np.arange(0, H + 1, dtype=float)

Expand Down Expand Up @@ -788,7 +788,6 @@ def _compute_graphics_support_1d(self, v: np.ndarray, Tbeat: float) -> dict:
vmin = float(np.nanmin(vv))
vmean = float(np.nanmean(vv))

# Cumulative displacement geometry sampled on normalized phase
d_full = np.concatenate(
([0.0], np.cumsum(np.where(np.isfinite(vv), vv, 0.0)) / m0_sum)
)
Expand Down Expand Up @@ -820,8 +819,8 @@ def _compute_graphics_support_1d(self, v: np.ndarray, Tbeat: float) -> dict:
E_high = np.nan

if V is not None and H >= 0:
mags = np.abs(V[: H + 1]) # indices 0..H
power = mags**2 # |V_n|^2
mags = np.abs(V[: H + 1])
power = mags**2
harmonic_energies[: H + 1] = power
harmonic_magnitudes[: H + 1] = mags

Expand All @@ -839,11 +838,9 @@ def _compute_graphics_support_1d(self, v: np.ndarray, Tbeat: float) -> dict:
E_low = float(np.nansum(power[1 : self.H_LOW_MAX + 1]))
E_high = float(np.nansum(power[self.H_HIGH_MIN : self.H_HIGH_MAX + 1]))

# poids énergie : définis seulement sur n>=1
if np.isfinite(power_sum) and power_sum > 0:
harmonic_energy_weights[0:H] = power_h / (power_sum + self.eps)

# poids amplitude : définis seulement sur n>=1
if np.isfinite(mag_sum) and mag_sum > 0:
harmonic_weights[0:H] = mags_h / (mag_sum + self.eps)

Expand Down Expand Up @@ -1296,21 +1293,21 @@ def _metric_keys() -> list[list]:
def _compute_block_segment(self, v_block: np.ndarray, T: np.ndarray):
"""
v_block: (n_t, n_beats, n_branches, n_radii)

Returns:
per-segment arrays: (n_beats, n_segments)
per-branch arrays: (n_beats, n_branches) (median over radii)
global arrays: (n_beats,) (mean over all branches & radii)
per-segment arrays: (n_beats, n_branches, n_radii)
per-branch arrays: (n_beats, n_branches) (median over radii)
global arrays: (n_beats,) (median over all branch-radius values)
"""
if v_block.ndim != 4:
raise ValueError(
f"Expected (n_t,n_beats,n_branches,n_radii), got {v_block.shape}"
)

n_t, n_beats, n_branches, n_radii = v_block.shape
n_segments = n_branches * n_radii

seg = {
k[0]: np.full((n_beats, n_segments), np.nan, dtype=float)
k[0]: np.full((n_beats, n_branches, n_radii), np.nan, dtype=float)
for k in self._metric_keys()
}
br = {
Expand All @@ -1333,25 +1330,25 @@ def _compute_block_segment(self, v_block: np.ndarray, T: np.ndarray):
v = v_block[:, beat_idx, branch_idx, radius_idx]
m = self._compute_metrics_1d(v, Tbeat)

seg_idx = branch_idx * n_radii + radius_idx
for k in self._metric_keys():
seg[k[0]][beat_idx, seg_idx] = m[k[0]]
br_vals[k[0]].append(m[k[0]])
gl_vals[k[0]].append(m[k[0]])
key = k[0]
seg[key][beat_idx, branch_idx, radius_idx] = m[key]
br_vals[key].append(m[key])
gl_vals[key].append(m[key])

for k in self._metric_keys():
br[k[0]][beat_idx, branch_idx] = self._safe_nanmedian(
np.asarray(br_vals[k[0]], dtype=float)
key = k[0]
br[key][beat_idx, branch_idx] = self._safe_nanmedian(
np.asarray(br_vals[key], dtype=float)
)

for k in self._metric_keys():
gl[k[0]][beat_idx] = self._safe_nanmean(
np.asarray(gl_vals[k[0]], dtype=float)
key = k[0]
gl[key][beat_idx] = self._safe_nanmedian(
np.asarray(gl_vals[key], dtype=float)
)

seg_order_note = (
"seg_idx = branch_idx * n_radii + radius_idx (branch-major flattening)"
)
seg_order_note = "segment arrays are stored as (beat, branch, radius)"
return seg, br, gl, n_branches, n_radii, seg_order_note

def _compute_block_global(self, v_global: np.ndarray, T: np.ndarray):
Expand Down Expand Up @@ -1461,12 +1458,18 @@ def pack(prefix: str, d: dict, attrs_common: dict):
pack(
"by_segment/bandlimited_segment",
seg_b,
{"segment_indexing": [seg_note]},
{
"definition": ["per-segment metrics stored as (beat, branch, radius)"],
"segment_indexing": [seg_note],
},
)
pack(
"by_segment/raw_segment",
seg_r,
{"segment_indexing": [seg_note]},
{
"definition": ["per-segment metrics stored as (beat, branch, radius)"],
"segment_indexing": [seg_note],
},
)

pack(
Expand All @@ -1483,12 +1486,12 @@ def pack(prefix: str, d: dict, attrs_common: dict):
pack(
"by_segment/bandlimited_global",
gl_b,
{"definition": ["mean over branches and radii"]},
{"definition": ["median over all branch-radius segment values per beat"]},
)
pack(
"by_segment/raw_global",
gl_r,
{"definition": ["mean over branches and radii"]},
{"definition": ["median over all branch-radius segment values per beat"]},
)

metrics["by_segment/params/ratio_R_VTI"] = np.asarray(
Expand Down Expand Up @@ -1570,6 +1573,7 @@ def pack(prefix: str, d: dict, attrs_common: dict):
metrics["global/params/H_PHASE_RESIDUAL"] = np.asarray(
self.H_PHASE_RESIDUAL, dtype=int
)

graphics_raw = self._compute_graphics_support_block(v_raw_gl, T)
graphics_band = self._compute_graphics_support_block(v_band_gl, T)
for name, arr in graphics_raw.items():
Expand All @@ -1578,4 +1582,4 @@ def pack(prefix: str, d: dict, attrs_common: dict):
for name, arr in graphics_band.items():
metrics[f"global/bandlimited/{name}"] = arr

return ProcessResult(metrics=metrics)
return ProcessResult(metrics=metrics)
Loading