diff --git a/.gitignore b/.gitignore index 5bf1196..162185b 100644 --- a/.gitignore +++ b/.gitignore @@ -7,5 +7,5 @@ fixed_wheels/ .vscode/ .DS_Store *.so -/**/*_ui.py Visualization_Results/ +/**/*_ui.py diff --git a/engines/ceus b/engines/ceus index 97d3d3a..fb0b427 160000 --- a/engines/ceus +++ b/engines/ceus @@ -1 +1 @@ -Subproject commit 97d3d3a8b03ee02bc5bda8dca2f1f316151d8210 +Subproject commit fb0b427fc6725d9bda6c485ff91b2a4055b2ffcb diff --git a/src/ceus/analysis_loading/analysis_loading_controller.py b/src/ceus/analysis_loading/analysis_loading_controller.py index 8ddbc27..98b0558 100644 --- a/src/ceus/analysis_loading/analysis_loading_controller.py +++ b/src/ceus/analysis_loading/analysis_loading_controller.py @@ -10,7 +10,8 @@ from ..mvc.base_controller import BaseController from .analysis_loading_view_coordinator import AnalysisLoadingViewCoordinator -from engines.ceus.src.data_objs import UltrasoundImage, CeusSeg +from engines.ceus.src.data_objs.image import UltrasoundImage +from engines.ceus.src.data_objs.seg import CeusSeg from engines.ceus.src.time_series_analysis.curves.framework import CurvesAnalysis @@ -64,20 +65,33 @@ def _connect_signals(self) -> None: def _setup_analysis_options(self) -> None: """Setup available analysis types and functions in the view.""" analysis_types, analysis_functions = self._model.get_analysis_types() + print(f"DEBUG: Available analysis types: {list(analysis_types.keys())}") - # Automatically select "Paramap" as the analysis type - paramap_type = "paramap" - if paramap_type in analysis_types: - self._selected_analysis_type = paramap_type - if self._model.set_analysis_type(paramap_type): - # Get available functions for Paramap analysis - available_functions = self._model.get_analysis_functions(paramap_type) + # Automatically select the best available analysis type + # Prefer curves_paramap, then curves, or just the first available one + selected_type = None + for preferred in ["curves_paramap", "curves", "paramap"]: + if preferred in analysis_types: + selected_type = preferred + break + if preferred in analysis_types: + selected_type = preferred + break + + if not selected_type and analysis_types: + selected_type = list(analysis_types.keys())[0] + + if selected_type: + self._selected_analysis_type = selected_type + if self._model.set_analysis_type(selected_type): + # Get available functions for selected analysis type + available_functions = self._model.get_analysis_functions(selected_type) # Skip analysis type selection and go directly to function selection self._view_coordinator.show_function_selection(available_functions) else: - self._view_coordinator.show_error("Failed to set Paramap analysis type") + self._view_coordinator.show_error(f"Failed to set {selected_type} analysis type") else: - self._view_coordinator.show_error("Paramap analysis type not available") + self._view_coordinator.show_error("No analysis types available") def _on_user_action(self, action_name: str, action_data: Any) -> None: """ @@ -98,7 +112,7 @@ def _on_user_action(self, action_name: str, action_data: Any) -> None: print(f"DEBUG: Controller received analysis_execution_started action") print(f"DEBUG: action_data = {action_data}") self._handle_analysis_execution(action_data) - elif action_name == "analysis_completed": + elif action_name == "analysis_loading_completed": self._handle_analysis_completion(action_data) else: # Forward unknown actions to application controller @@ -303,5 +317,13 @@ def analysis_data(self) -> Optional[CurvesAnalysis]: return self._analysis_data def cleanup(self) -> None: - """Clean up resources.""" - self.model.cleanup() + """Clean up resources and disconnect signals.""" + try: + self._model.analysis_completed.disconnect(self._on_analysis_completed) + except (TypeError, RuntimeError): + pass + try: + self._model.error_occurred.disconnect(self._on_analysis_error) + except (TypeError, RuntimeError): + pass + self._model.cleanup() diff --git a/src/ceus/analysis_loading/analysis_loading_view_coordinator.py b/src/ceus/analysis_loading/analysis_loading_view_coordinator.py index a87ade9..7e969ff 100644 --- a/src/ceus/analysis_loading/analysis_loading_view_coordinator.py +++ b/src/ceus/analysis_loading/analysis_loading_view_coordinator.py @@ -10,12 +10,13 @@ from PyQt6.QtWidgets import QWidget, QStackedWidget from PyQt6.QtCore import pyqtSignal -from quantus.gui.mvc.base_view import BaseViewMixin +from ..mvc.base_view import BaseViewMixin from .views.analysis_function_selection_widget import AnalysisFunctionSelectionWidget -from quantus.gui.config_loading.views.analysis_params_widget import AnalysisParamsWidget +from .views.analysis_params_widget import AnalysisParamsWidget from .views.analysis_execution_widget import AnalysisExecutionWidget -from quantus.data_objs import UltrasoundRfImage, BmodeSeg, RfAnalysisConfig -from quantus.analysis.paramap.framework import ParamapAnalysis +from engines.ceus.src.data_objs.image import UltrasoundImage +from engines.ceus.src.data_objs.seg import CeusSeg +from engines.ceus.src.time_series_analysis.curves.framework import CurvesAnalysis class AnalysisLoadingViewCoordinator(QStackedWidget, BaseViewMixin): @@ -40,7 +41,7 @@ class AnalysisLoadingViewCoordinator(QStackedWidget, BaseViewMixin): # ============================================================================ - def __init__(self, image_data: UltrasoundRfImage, seg_data: BmodeSeg, config_data: RfAnalysisConfig, parent: Optional[QWidget] = None): + def __init__(self, image_data: UltrasoundImage, seg_data: CeusSeg, config_data, parent: Optional[QWidget] = None): super().__init__(parent) self.__init_base_view__(parent) self._image_data = image_data @@ -48,11 +49,6 @@ def __init__(self, image_data: UltrasoundRfImage, seg_data: BmodeSeg, config_dat self._config_data = config_data print(f"DEBUG: AnalysisLoadingViewCoordinator - image_data = {image_data is not None}") - if image_data is not None: - print(f"DEBUG: AnalysisLoadingViewCoordinator - scan_name = {image_data.scan_name}") - print(f"DEBUG: AnalysisLoadingViewCoordinator - phantom_name = {image_data.phantom_name}") - else: - print(f"DEBUG: AnalysisLoadingViewCoordinator - image_data is None!") # Widget instances self._function_selection_widget: Optional[AnalysisFunctionSelectionWidget] = None @@ -63,7 +59,7 @@ def __init__(self, image_data: UltrasoundRfImage, seg_data: BmodeSeg, config_dat self._selected_analysis_type: Optional[str] = None self._selected_functions: List[str] = [] self._analysis_params: dict = {} - self._analysis_data: Optional[ParamapAnalysis] = None + self._analysis_data: Optional[CurvesAnalysis] = None # Note: Analysis type selection is now skipped - Paramap is automatically selected # The controller will call show_function_selection directly @@ -106,12 +102,16 @@ def show_error(self, error_message: str) -> None: error_message: Error message to display """ current_widget: BaseViewMixin = self.currentWidget() - current_widget.show_error(error_message) + if current_widget: + current_widget.show_error(error_message) + else: + print(f"ERROR (no active widget): {error_message}") def clear_error(self) -> None: """Clear error message in the current widget.""" current_widget: BaseViewMixin = self.currentWidget() - current_widget.clear_error() + if current_widget: + current_widget.clear_error() # ============================================================================ # NAVIGATION METHODS - Methods to show different widgets @@ -160,13 +160,19 @@ def show_params_configuration(self, required_params: List[str], selected_functio print(f"DEBUG: Creating AnalysisParamsWidget with image_data = {self._image_data is not None}") if self._image_data is not None: print(f"DEBUG: Passing scan_name = {self._image_data.scan_name}") - print(f"DEBUG: Passing phantom_name = {self._image_data.phantom_name}") + if hasattr(self._image_data, 'phantom_name'): + print(f"DEBUG: Passing phantom_name = {self._image_data.phantom_name}") self._params_widget = AnalysisParamsWidget(self._image_data, self._seg_data, self._config_data) self._params_widget.setup_ui() self._params_widget.connect_signals() self._params_widget.params_configured.connect(self._on_params_configured) self._params_widget.back_requested.connect(self._on_params_back) self.addWidget(self._params_widget) + else: + # Update data in existing widget + self._params_widget._image_data = self._image_data + self._params_widget._seg_data = self._seg_data + self._params_widget._config_data = self._config_data print(f"DEBUG: Calling set_required_params...") self._params_widget.set_required_params(required_params, selected_functions) @@ -198,6 +204,10 @@ def show_analysis_execution(self, execution_summary: Dict) -> None: print(f"DEBUG: AnalysisExecutionWidget created and added to stack") else: print(f"DEBUG: Using existing AnalysisExecutionWidget") + # Update data in existing widget + self._execution_widget._image_data = self._image_data + self._execution_widget._seg_data = self._seg_data + self._execution_widget._config_data = self._config_data print(f"DEBUG: Setting execution summary...") self._execution_widget.set_execution_summary(execution_summary) @@ -208,7 +218,7 @@ def show_analysis_execution(self, execution_summary: Dict) -> None: self._execution_widget.clear_error() print(f"DEBUG: show_analysis_execution completed - execution screen should be visible") - def show_analysis_results(self, analysis_data: ParamapAnalysis) -> None: + def show_analysis_results(self, analysis_data: CurvesAnalysis) -> None: """ Show analysis results in the execution widget. @@ -260,7 +270,7 @@ def _on_execution_started(self, execution_data: dict) -> None: self._emit_user_action("analysis_execution_started", execution_data) print(f"DEBUG: user_action signal emitted") - def _on_analysis_confirmed(self, analysis_data: ParamapAnalysis) -> None: + def _on_analysis_confirmed(self, analysis_data: CurvesAnalysis) -> None: """ Handle analysis completion confirmation. diff --git a/src/ceus/analysis_loading/ui/analysis_params.ui b/src/ceus/analysis_loading/ui/analysis_params.ui new file mode 100644 index 0000000..3aae63c --- /dev/null +++ b/src/ceus/analysis_loading/ui/analysis_params.ui @@ -0,0 +1,676 @@ + + + analysisParams + + + + 0 + 0 + 1284 + 803 + + + + + 0 + 0 + + + + Analysis Parameters Configuration + + + QWidget { + background: rgb(42, 42, 42); +} + + + + + 60 + 20 + 951 + 731 + + + + + + + 0 + + + QLayout::SetMaximumSize + + + + + + 341 + 601 + + + + + 241 + 601 + + + + <html><head/><body><p><br/></p></body></html> + + + QWidget { + background-color: rgb(28, 0, 101); +} + + + + + 0 + 0 + 341 + 121 + + + + + 341 + 121 + + + + + 341 + 121 + + + + QFrame { + background-color: rgb(99, 0, 174); + border: 1px solid black; +} + + + QFrame::StyledPanel + + + QFrame::Raised + + + + + 70 + 0 + 191 + 51 + + + + QLabel { + font-size: 21px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight: bold; +} + + + Image Selection: + + + Qt::AlignCenter + + + + + + -60 + 40 + 191 + 51 + + + + QLabel { + font-size: 16px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight: bold; +} + + + Image: + + + Qt::AlignCenter + + + + + + -50 + 70 + 191 + 51 + + + + QLabel { + font-size: 16px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight: bold +} + + + Phantom: + + + Qt::AlignCenter + + + + + + 100 + 40 + 241 + 51 + + + + QLabel { + font-size: 14px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; +} + + + Sample filename + + + Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter + + + + + + 100 + 70 + 241 + 51 + + + + QLabel { + font-size: 14px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; +} + + + Sample filename + + + Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter + + + + + + + 0 + 120 + 341 + 121 + + + + + 341 + 121 + + + + QFrame { + background-color: rgb(99, 0, 174); + border: 1px solid black; +} + + + QFrame::StyledPanel + + + QFrame::Raised + + + + + 0 + 40 + 341 + 51 + + + + QLabel { + font-size: 21px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight: bold; +} + + + Segmentation Selection + + + Qt::AlignCenter + + + + + + + 0 + 240 + 341 + 121 + + + + + 341 + 121 + + + + QFrame { + background-color: rgb(99, 0, 174); + border: 1px solid black; +} + + + QFrame::StyledPanel + + + QFrame::Raised + + + + + 0 + 30 + 341 + 51 + + + + QLabel { + font-size: 21px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight:bold; +} + + + Analysis Parameter Selection + + + Qt::AlignCenter + + + + + + + 0 + 360 + 341 + 121 + + + + + 341 + 121 + + + + + 341 + 121 + + + + QFrame { + background-color: rgb(99, 0, 174); + border: 1px solid black; +} + + + QFrame::StyledPanel + + + QFrame::Raised + + + + + 0 + 30 + 341 + 51 + + + + QLabel { + font-size: 21px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight: bold; +} + + + Radio Frequency Data Analysis + + + Qt::AlignCenter + + + + + + + 0 + 480 + 341 + 121 + + + + + 341 + 121 + + + + + 341 + 121 + + + + QFrame { + background-color: rgb(49, 0, 124); + border: 1px solid black; +} + + + QFrame::StyledPanel + + + QFrame::Raised + + + + + 20 + 30 + 301 + 51 + + + + QLabel { + font-size: 21px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight: bold; +} + + + Visualization / Export + + + Qt::AlignCenter + + + + + + + + + + 341 + 16777215 + + + + QFrame { + background-color: rgb(28, 0, 101); +} + + + + QLayout::SetMinAndMaxSize + + + 10 + + + 10 + + + 10 + + + 10 + + + + + Qt::Vertical + + + + 20 + 40 + + + + + + + + Qt::Horizontal + + + + 40 + 20 + + + + + + + + + 131 + 41 + + + + + 131 + 41 + + + + QPushButton { + color: white; + font-size: 16px; + background: rgb(90, 37, 255); + border-radius: 15px; +} + + + Back + + + + + + + + + + + + 50 + + + 30 + + + 10 + + + 30 + + + 10 + + + + + QLabel { + font-size: 29px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + Analysis in Progress... + + + Qt::PlainText + + + false + + + true + + + + + + + QLabel { + font-size: 29px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + Configure Analysis Parameters: + + + Qt::PlainText + + + false + + + true + + + + + + + true + + + + + 0 + 0 + 407 + 294 + + + + + + + + + + QLabel { + color: rgb(0, 255, 0); + font-size: 20px; + background-color: rgba(255, 255, 255, 0); +} + + + Running Analysis.... + + + + + + + + 131 + 41 + + + + + 131 + 41 + + + + QPushButton { + color: white; + font-size: 16px; + background: rgb(90, 37, 255); + border-radius: 15px; +} + + + Run Analysis + + + + + + + + 20 + 40 + + + + + + + + + + + + diff --git a/src/ceus/analysis_loading/views/analysis_execution_widget.py b/src/ceus/analysis_loading/views/analysis_execution_widget.py index 08d5b76..c5a408c 100644 --- a/src/ceus/analysis_loading/views/analysis_execution_widget.py +++ b/src/ceus/analysis_loading/views/analysis_execution_widget.py @@ -10,10 +10,11 @@ from PyQt6.QtCore import pyqtSignal, Qt, QTimer from PyQt6.QtGui import QFont -from quantus.gui.mvc.base_view import BaseViewMixin -from quantus.gui.analysis_loading.ui.analysis_execution_ui import Ui_analysisExecution -from quantus.data_objs import UltrasoundRfImage, BmodeSeg, RfAnalysisConfig -from quantus.analysis.paramap.framework import ParamapAnalysis +from ...mvc.base_view import BaseViewMixin +from ..ui.analysis_execution_ui import Ui_analysisExecution +from engines.ceus.src.data_objs.image import UltrasoundImage +from engines.ceus.src.data_objs.seg import CeusSeg +from engines.ceus.src.time_series_analysis.curves.framework import CurvesAnalysis class AnalysisExecutionWidget(QWidget, BaseViewMixin): @@ -26,11 +27,11 @@ class AnalysisExecutionWidget(QWidget, BaseViewMixin): # Signals for communicating with controller execution_started = pyqtSignal(dict) # execution_data - analysis_confirmed = pyqtSignal(object) # analysis_data (ParamapAnalysis) + analysis_confirmed = pyqtSignal(object) # analysis_data (CurvesAnalysis) close_requested = pyqtSignal() back_requested = pyqtSignal() - def __init__(self, image_data: UltrasoundRfImage, seg_data: BmodeSeg, config_data: RfAnalysisConfig, parent: Optional[QWidget] = None): + def __init__(self, image_data: UltrasoundImage, seg_data: CeusSeg, config_data, parent: Optional[QWidget] = None): QWidget.__init__(self, parent) self.__init_base_view__(parent) self._ui = Ui_analysisExecution() @@ -40,7 +41,7 @@ def __init__(self, image_data: UltrasoundRfImage, seg_data: BmodeSeg, config_dat # Current state self._execution_summary: Dict = {} - self._analysis_data: Optional[ParamapAnalysis] = None + self._analysis_data: Optional[CurvesAnalysis] = None self._is_executing = False self._results_shown = False # Track if results have been shown @@ -78,8 +79,8 @@ def setup_ui(self) -> None: # Update labels to reflect inputted image and phantom if self._image_data is not None: - self._ui.image_path_input.setText(self._image_data.scan_name or "No image loaded") - self._ui.phantom_path_input.setText(self._image_data.phantom_name or "No phantom loaded") + self._ui.image_path_input.setText(getattr(self._image_data, 'scan_name', "No image loaded")) + self._ui.phantom_path_input.setText(getattr(self._image_data, 'phantom_name', "No phantom loaded")) else: self._ui.image_path_input.setText("No image loaded") self._ui.phantom_path_input.setText("No phantom loaded") @@ -219,7 +220,7 @@ def _clear_summary_layout(self) -> None: if child.widget(): child.widget().deleteLater() - def show_results(self, analysis_data: ParamapAnalysis) -> None: + def show_results(self, analysis_data: CurvesAnalysis) -> None: """ Show analysis results. @@ -233,6 +234,12 @@ def show_results(self, analysis_data: ParamapAnalysis) -> None: self._ui.progress_bar.setValue(100) self._ui.progress_label.setText("Analysis completed successfully!") + # Add a message that the rest of the pipeline needs to be finished + info_label = QLabel("Note: The rest of the pipeline still needs to be finished.") + info_label.setStyleSheet("color: #FFD700; font-style: italic; font-size: 10px; margin-top: 5px;") + info_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self._ui.analysis_execution_layout.addWidget(info_label) + # Show finish button, hide execute button self._ui.execute_button.setVisible(False) self._ui.finish_button.setVisible(True) diff --git a/src/ceus/analysis_loading/views/analysis_function_selection_widget.py b/src/ceus/analysis_loading/views/analysis_function_selection_widget.py index f36643d..318544a 100644 --- a/src/ceus/analysis_loading/views/analysis_function_selection_widget.py +++ b/src/ceus/analysis_loading/views/analysis_function_selection_widget.py @@ -9,9 +9,10 @@ from PyQt6.QtWidgets import QWidget, QComboBox, QVBoxLayout, QHBoxLayout, QLabel, QSizePolicy from PyQt6.QtCore import pyqtSignal, Qt -from quantus.gui.mvc.base_view import BaseViewMixin -from quantus.gui.analysis_loading.ui.analysis_function_selection_ui import Ui_analysisFunctionSelection -from quantus.data_objs import UltrasoundRfImage, BmodeSeg, RfAnalysisConfig +from ...mvc.base_view import BaseViewMixin +from ..ui.analysis_function_selection_ui import Ui_analysisFunctionSelection +from engines.ceus.src.data_objs.image import UltrasoundImage +from engines.ceus.src.data_objs.seg import CeusSeg class AnalysisFunctionSelectionWidget(QWidget, BaseViewMixin): @@ -27,7 +28,7 @@ class AnalysisFunctionSelectionWidget(QWidget, BaseViewMixin): close_requested = pyqtSignal() back_requested = pyqtSignal() - def __init__(self, image_data: UltrasoundRfImage, seg_data: BmodeSeg, config_data: RfAnalysisConfig, parent: Optional[QWidget] = None): + def __init__(self, image_data: UltrasoundImage, seg_data: CeusSeg, config_data, parent: Optional[QWidget] = None): QWidget.__init__(self, parent) self.__init_base_view__(parent) self._ui = Ui_analysisFunctionSelection() @@ -56,8 +57,8 @@ def setup_ui(self) -> None: # Update labels to reflect inputted image and phantom if self._image_data is not None: - self._ui.image_path_input.setText(self._image_data.scan_name or "No image loaded") - self._ui.phantom_path_input.setText(self._image_data.phantom_name or "No phantom loaded") + self._ui.image_path_input.setText(getattr(self._image_data, 'scan_name', "No image loaded")) + self._ui.phantom_path_input.setText(getattr(self._image_data, 'phantom_name', "No phantom loaded")) else: self._ui.image_path_input.setText("No image loaded") self._ui.phantom_path_input.setText("No phantom loaded") diff --git a/src/ceus/analysis_loading/views/analysis_params_widget.py b/src/ceus/analysis_loading/views/analysis_params_widget.py new file mode 100644 index 0000000..4d25d2b --- /dev/null +++ b/src/ceus/analysis_loading/views/analysis_params_widget.py @@ -0,0 +1,152 @@ +""" +Analysis Parameters Widget for Analysis Loading + +This widget allows users to configure parameters required for the selected analysis functions. +It dynamically creates input fields based on the required parameters. +""" + +from typing import List, Optional, Dict, Any +from PyQt6.QtWidgets import (QWidget, QLabel, QLineEdit, QDoubleSpinBox, QSpinBox, + QCheckBox, QComboBox, QFormLayout, + QGroupBox, QTextEdit) +from PyQt6.QtCore import pyqtSignal, Qt, QTimer + +from ...mvc.base_view import BaseViewMixin +from ..ui.analysis_params_ui import Ui_analysisParams +from engines.ceus.src.data_objs.image import UltrasoundImage +from engines.ceus.src.data_objs.seg import CeusSeg + + +class AnalysisParamsWidget(QWidget, BaseViewMixin): + """ + Widget for configuring analysis parameters. + + This widget dynamically creates input fields based on the required parameters + for the selected analysis functions. + """ + + # Signals for communicating with controller + params_configured = pyqtSignal(dict) # analysis_params + close_requested = pyqtSignal() + back_requested = pyqtSignal() + + def __init__(self, image_data: UltrasoundImage, seg_data: CeusSeg, config_data, parent: Optional[QWidget] = None): + QWidget.__init__(self, parent) + self.__init_base_view__(parent) + self._ui = Ui_analysisParams() + self._image_data = image_data + self._seg_data = seg_data + self._config_data = config_data + + # Track parameter inputs + self._param_inputs: Dict[str, QWidget] = {} + self._required_params: List[str] = [] + self._selected_functions: List[str] = [] + + def setup_ui(self) -> None: + """Setup the user interface.""" + self._ui.setupUi(self) + + # Configure layout for parameters configuration (assuming similar structure to QUS) + self.setLayout(self._ui.full_screen_layout) + + # Configure stretch factors + self._ui.full_screen_layout.setStretchFactor(self._ui.side_bar_layout, 1) + self._ui.full_screen_layout.setStretchFactor(self._ui.analysis_params_layout, 10) + + # Update labels to reflect inputted image and phantom + self._ui.image_path_input.setText(getattr(self._image_data, 'scan_name', "No image loaded")) + self._ui.phantom_path_input.setText(getattr(self._image_data, 'phantom_name', "No phantom loaded")) + + def connect_signals(self) -> None: + """Connect UI signals to internal handlers.""" + self._ui.run_analysis_button.clicked.connect(self._on_run_analysis_clicked) + self._ui.back_button.clicked.connect(self._on_back_clicked) + + def set_required_params(self, required_params: List[str], selected_functions: List[str]) -> None: + """ + Set required parameters and create input fields. + + Args: + required_params: List of required parameter names + selected_functions: List of selected function names + """ + print(f"DEBUG: AnalysisParamsWidget.set_required_params called") + print(f"DEBUG: required_params = {required_params}") + print(f"DEBUG: selected_functions = {selected_functions}") + self._required_params = required_params + self._selected_functions = selected_functions + self._create_parameter_inputs() + + def _create_parameter_inputs(self) -> None: + """Create input fields for each required parameter.""" + print(f"DEBUG: AnalysisParamsWidget._create_parameter_inputs called") + + # Clear existing inputs + self._clear_params_layout() + self._param_inputs = {} + + # If no params required, provide a small delay and auto-transition? + # Or just show the screen with a "Continue" button + if not self._required_params: + print(f"DEBUG: No required params found") + if hasattr(self._ui, 'run_analysis_button'): + self._ui.run_analysis_button.setText("Continue to Execution") + self._ui.run_analysis_button.setVisible(True) + self._ui.run_analysis_button.setEnabled(True) + self._ui.analysis_running_label.hide() + self._ui.analysis_execution_label.hide() + return + + # Show normal parameter labels + self._ui.analysis_params_label.show() + self._ui.run_analysis_button.setText("Run Analysis") + self._ui.run_analysis_button.setVisible(True) + self._ui.run_analysis_button.setEnabled(True) + self._ui.analysis_running_label.hide() + self._ui.analysis_execution_label.hide() + + # Ideally would dynamically create inputs based on CEUS requirements + form_layout = QFormLayout() + for param_name in self._required_params: + label = QLabel(param_name.replace("_", " ").title() + ":") + label.setStyleSheet("color: white; font-size: 14px;") + + # Simple line edit for now + line_edit = QLineEdit() + line_edit.setStyleSheet("color: white; background-color: rgb(60, 60, 60); border: 1px solid gray; padding: 5px;") + + form_layout.addRow(label, line_edit) + self._param_inputs[param_name] = line_edit + + self._ui.params_layout.addLayout(form_layout) + + def _clear_params_layout(self) -> None: + """Clear all widgets from the params container.""" + if hasattr(self._ui, 'params_layout') and self._ui.params_layout is not None: + while self._ui.params_layout.count(): + child = self._ui.params_layout.takeAt(0) + if child.widget(): + child.widget().deleteLater() + elif child.layout(): + # Recursively clear sub-layouts + def clear_sub_layout(l): + while l.count(): + c = l.takeAt(0) + if c.widget(): c.widget().deleteLater() + elif c.layout(): clear_sub_layout(c.layout()) + clear_sub_layout(child.layout()) + + def _on_run_analysis_clicked(self) -> None: + """Handle run analysis button click.""" + print(f"DEBUG: AnalysisParamsWidget._on_run_analysis_clicked called") + # Collect parameters (simplified) + params = {} + # TODO: Collect actual values from dynamically created widgets + + print(f"DEBUG: Emitting params_configured with {params}") + self.params_configured.emit(params) + + def _on_back_clicked(self) -> None: + """Handle back button click.""" + self.back_requested.emit() diff --git a/src/ceus/application_controller.py b/src/ceus/application_controller.py index 86f6cbf..9326a6f 100644 --- a/src/ceus/application_controller.py +++ b/src/ceus/application_controller.py @@ -5,14 +5,16 @@ import sys import qdarktheme from typing import Optional -from PyQt6.QtWidgets import QApplication, QStackedWidget +from PyQt6.QtWidgets import QMessageBox from PyQt6.QtCore import QObject, pyqtSignal +from PyQt6.QtWidgets import QApplication, QStackedWidget from .application_model import ApplicationModel -from .image_loading.image_loading_view_coordinator import ImageLoadingViewCoordinator from .image_loading.image_loading_controller import ImageLoadingController from .seg_loading.seg_loading_controller import SegmentationLoadingController -from engines.ceus.src.data_objs import UltrasoundImage, CeusSeg +from .analysis_loading.analysis_loading_controller import AnalysisLoadingController +from engines.ceus.src.data_objs.image import UltrasoundImage +from engines.ceus.src.data_objs.seg import CeusSeg class ApplicationController(QObject): @@ -37,9 +39,14 @@ def __init__(self, app: QApplication): # Unified application model self._model = ApplicationModel() + # Current data + self._image_data: Optional[UltrasoundImage] = None + self._seg_data: Optional[CeusSeg] = None + # Controllers for different screens (using the same model) self._image_loading_controller: Optional[ImageLoadingController] = None self._segmentation_controller: Optional[SegmentationLoadingController] = None + self._analysis_loading_controller: Optional[AnalysisLoadingController] = None # Setup main widget self._setup_main_widget() @@ -59,7 +66,7 @@ def _connect_model_signals(self) -> None: """Connect unified model signals to application controller.""" self._model.image_loaded.connect(self._initialize_segmentation_loading) self._model.error_occurred.connect(self._on_model_error) - + def _initialize_image_loading(self) -> None: """Initialize the image loading screen.""" if self._image_loading_controller: @@ -82,6 +89,8 @@ def _initialize_segmentation_loading(self, image_data: UltrasoundImage) -> None: Args: image_data: Loaded image data from previous screen """ + self._image_data = image_data + if self._segmentation_controller: self._cleanup_segmentation_loading() @@ -104,7 +113,8 @@ def _on_model_error(self, error_message: str) -> None: error_message: Error message from model """ print(f"DEBUG: Application model error: {error_message}") - # The individual view controllers will handle displaying the error to the user + # Show error message to user + QMessageBox.critical(self._widget_stack, "Error", error_message) def _on_image_action(self, action_name: str, action_data) -> None: """ @@ -128,10 +138,58 @@ def _on_segmentation_action(self, action_name: str, action_data) -> None: """ if action_name == 'segmentation_confirmed': self._seg_data = self._segmentation_controller.get_loaded_segmentation() - # TODO: Navigate to analysis screen when implemented - print("Analysis screen coming soon...") - self._app.quit() + + # Use model data as source of truth + image_data = self._model.image_data if self._model.image_data else self._image_data + + self._initialize_analysis_loading(image_data, self._seg_data) + def _initialize_analysis_loading(self, image_data: UltrasoundImage, seg_data: CeusSeg) -> None: + """ + Initialize the analysis loading screen. + + Args: + image_data: Loaded image data + seg_data: Loaded segmentation data + """ + if self._analysis_loading_controller: + self._cleanup_analysis_loading() + + # Create controller with unified model + # Note: CEUS might need a config object, passing None for now if not available + self._analysis_loading_controller = AnalysisLoadingController(self._model, image_data, seg_data, None) + + # Connect signals + self._analysis_loading_controller.view.user_action.connect(self._on_analysis_action) + self._analysis_loading_controller.view.back_requested.connect(self._navigate_to_segmentation_loading) + + # Add to stack and show + self._widget_stack.addWidget(self._analysis_loading_controller.view) + self._widget_stack.setCurrentWidget(self._analysis_loading_controller.view) + + def _on_analysis_action(self, action_name: str, action_data) -> None: + """ + Handle actions from the analysis loading screen. + + Args: + action_name: Name of the action + action_data: Data associated with the action + """ + if action_name == 'analysis_loading_completed': + print("Analysis completed successfully!") + # Future: Navigate to visualization screen + self._app.quit() + + def _navigate_to_segmentation_loading(self) -> None: + """Navigate back to segmentation loading.""" + if self._analysis_loading_controller: + self._cleanup_analysis_loading() + + if self._segmentation_controller: + self._widget_stack.setCurrentWidget(self._segmentation_controller.view) + else: + self._initialize_segmentation_loading(self._image_data) + def _navigate_to_image_loading(self) -> None: """Navigate to image loading screen.""" # Reset image loading controller to initial state @@ -192,6 +250,15 @@ def _cleanup(self) -> None: """Clean up all resources before application exit.""" self._cleanup_image_loading() self._cleanup_segmentation_loading() + self._cleanup_analysis_loading() + + def _cleanup_analysis_loading(self) -> None: + """Clean up analysis loading controller resources.""" + if self._analysis_loading_controller: + self._widget_stack.removeWidget(self._analysis_loading_controller.view) + self._analysis_loading_controller.cleanup() + self._analysis_loading_controller.view.deleteLater() + self._analysis_loading_controller = None @property def image_data(self) -> Optional[UltrasoundImage]: diff --git a/src/ceus/application_model.py b/src/ceus/application_model.py index 2ae47fc..f6e6767 100644 --- a/src/ceus/application_model.py +++ b/src/ceus/application_model.py @@ -12,8 +12,13 @@ from .mvc.base_model import BaseModel from engines.ceus.src.image_loading.options import get_scan_loaders from engines.ceus.src.seg_loading.options import get_seg_loaders +from engines.ceus.src.time_series_analysis.options import get_analysis_types from engines.ceus.src.entrypoints import scan_loading_step, seg_loading_step -from engines.ceus.src.data_objs import UltrasoundImage, CeusSeg +from engines.ceus.src.data_objs.image import UltrasoundImage +from engines.ceus.src.data_objs.seg import CeusSeg +from engines.ceus.src.time_series_analysis.curves.framework import CurvesAnalysis +from engines.ceus.src.time_series_analysis.options import get_required_kwargs +from engines.ceus.src.image_preprocessing.options import get_required_im_preproc_kwargs class ScanLoadingWorker(QThread): @@ -35,6 +40,11 @@ def run(self): self.image_path, **self.scan_loader_kwargs ) + + if isinstance(image_data, int): + self.error_msg.emit(f"Error loading scan: Loader error code {image_data}") + return + self.finished.emit(image_data) except Exception as e: @@ -66,6 +76,10 @@ def run(self): **self.seg_loader_kwargs ) + if isinstance(seg_data, int): + self.error_msg.emit(f"Error loading segmentation: Loader error code {seg_data}") + return + self.finished.emit(seg_data) except Exception as e: @@ -75,6 +89,56 @@ def run(self): self.error_msg.emit(f"Error loading segmentation: {e}") +class AnalysisWorker(QThread): + """Worker thread for time-consuming analysis operations.""" + finished = pyqtSignal(object) + error_msg = pyqtSignal(str) + + def __init__(self, analysis_type: str, image_data: UltrasoundImage, + config_data: Any, seg_data: CeusSeg, + selected_functions: List[str], analysis_kwargs: Dict[str, Any]): + super().__init__() + self.analysis_type = analysis_type + self.image_data = image_data + self.config_data = config_data + self.seg_data = seg_data + self.selected_functions = selected_functions + self.analysis_kwargs = analysis_kwargs + + def run(self): + """Execute the analysis in background thread.""" + try: + from engines.ceus.src.time_series_analysis.options import get_analysis_types + all_types, _ = get_analysis_types() + + if self.analysis_type not in all_types: + self.error_msg.emit(f"Invalid analysis type: {self.analysis_type}") + return + + analysis_cls = all_types[self.analysis_type] + + # Initialize analysis + analysis_obj = analysis_cls( + self.image_data, + self.seg_data, + self.selected_functions, + **self.analysis_kwargs + ) + + # Execute analysis + if hasattr(analysis_obj, 'compute_curves'): + analysis_obj.compute_curves() + elif hasattr(analysis_obj, 'run'): + analysis_obj.run() + + self.finished.emit(analysis_obj) + + except Exception as e: + import traceback + traceback.print_exc() + self.error_msg.emit(f"Error during analysis: {e}") + + class ApplicationModel(BaseModel): """ Unified application model that manages all data and business logic for the QuantUS GUI. @@ -88,7 +152,9 @@ class ApplicationModel(BaseModel): # Additional signals for application-specific events image_loaded = pyqtSignal(UltrasoundImage) + preprocessing_complete = pyqtSignal(UltrasoundImage) segmentation_loaded = pyqtSignal(CeusSeg) + analysis_completed = pyqtSignal(object) # Emits CurvesAnalysis def __init__(self): super().__init__() @@ -105,9 +171,17 @@ def __init__(self): self._seg_data: Optional[CeusSeg] = None self._seg_worker: Optional[SegLoadingWorker] = None + # Analysis state + self._analysis_data: Optional[CurvesAnalysis] = None + self._analysis_types: Dict[str, Any] = {} + self._analysis_functions: Dict[str, Any] = {} + self._selected_analysis_type: Optional[str] = None + self._analysis_worker: Optional[AnalysisWorker] = None + # Initialize loaders self._load_scan_loaders() self._load_seg_loaders() + self._load_analysis_types() def _load_scan_loaders(self) -> None: """Load available scan loaders from backend.""" @@ -122,6 +196,15 @@ def _load_seg_loaders(self) -> None: self._seg_loaders = get_seg_loaders() except Exception as e: self._emit_error(f"Failed to load seg loaders: {e}") + + def _load_analysis_types(self) -> None: + """Load available analysis types from backend.""" + try: + self._analysis_types, self._analysis_functions = get_analysis_types() + except Exception as e: + print(f"Error loading analysis types: {e}") + self._analysis_types = {} + self._analysis_functions = {} # Image Loading Properties and Methods @property @@ -276,26 +359,23 @@ def get_preprocessing_kwargs_requirements(self, func_names: list) -> list: Returns: list: List of required keyword arguments """ - from engines.ceus.src.image_preprocessing.options import get_required_im_preproc_kwargs return get_required_im_preproc_kwargs(func_names) - def apply_preprocessing_preview(self, func_configs: List[Dict[str, Any]], image_data: Optional[UltrasoundImage] = None) -> UltrasoundImage: + def apply_preprocessing(self, func_configs: List[Dict[str, Any]]) -> None: """ - Apply preprocessing to the given UltrasoundImage. - This does not modify the image data in the model. + Apply preprocessing to the model's current image. + This modifies the image data in the model. Args: func_configs: List of dicts with 'name' and 'kwargs' for each function - image_data: Optional UltrasoundImage to preprocess (if None, uses current image) """ - if not image_data and not self._image_data: + if not self._image_data: self._emit_error("No image loaded to preprocess") return - - processed_image = image_data if image_data else self._image_data try: funcs = self.get_preprocessing_options() + processed_image = self._image_data for config in func_configs: name = config['name'] @@ -304,12 +384,12 @@ def apply_preprocessing_preview(self, func_configs: List[Dict[str, Any]], image_ processed_image = funcs[name](processed_image, **kwargs) else: print(f"WARNING: Preprocessing function {name} not found") - + + self._image_data = processed_image + self.preprocessing_complete.emit(self._image_data) except Exception as e: self._emit_error(f"Error during preprocessing: {e}") - return processed_image - def enhance_image(self, image: UltrasoundImage, func_configs: List[Dict[str, Any]]) -> UltrasoundImage: """ Enhance a given UltrasoundImage and return the result. @@ -337,6 +417,19 @@ def enhance_image(self, image: UltrasoundImage, func_configs: List[Dict[str, Any print(f"DEBUG: enhance_image error: {e}") return image + def apply_preprocessing_preview(self, func_configs: List[Dict[str, Any]], image: UltrasoundImage) -> UltrasoundImage: + """ + Alias for enhance_image to support legacy controller calls. + + Args: + func_configs: List of dicts with 'name' and 'kwargs' + image: UltrasoundImage object to enhance + + Returns: + UltrasoundImage: Enhanced image + """ + return self.enhance_image(image, func_configs) + def _validate_image_input(self, input_data: Dict[str, Any]) -> bool: """ Validate input data for scan loading. @@ -438,7 +531,7 @@ def set_seg_type(self, seg_type_display_name: str) -> bool: """ try: if seg_type_display_name == "Manual Segmentation": - self._selected_seg_type = "pkl_roi" + self._selected_seg_type = "nifti" return True # Convert display name back to internal key @@ -541,9 +634,24 @@ def _on_segmentation_loading_complete(self, seg_data: CeusSeg) -> None: print(f"-----------------------------------------------\n") self.segmentation_loaded.emit(seg_data) + # Automatically confirm if this was loaded (either from file or manual save) + # This allows the app controller to catch the completion else: print(f"DEBUG: Segmentation loading failed - invalid seg data") self._emit_error("Failed to load segmentation data") + + def set_manual_segmentation(self, seg_data: CeusSeg) -> None: + """ + Set manually drawn segmentation data. + + Args: + seg_data: Manually drawn segmentation data + """ + if seg_data and hasattr(seg_data, 'seg_mask') and seg_data.seg_mask is not None: + self._seg_data = seg_data + self.segmentation_loaded.emit(seg_data) + else: + self._emit_error("Invalid manual segmentation data") def cleanup(self) -> None: """Clean up resources.""" @@ -556,3 +664,105 @@ def cleanup(self) -> None: self._seg_worker.quit() self._seg_worker.wait() self._seg_worker = None + + if hasattr(self, '_analysis_worker') and self._analysis_worker and self._analysis_worker.isRunning(): + self._analysis_worker.quit() + self._analysis_worker.wait() + self._analysis_worker = None + + # ============================================================================ + # ANALYSIS METHODS + # ============================================================================ + + def get_analysis_types(self) -> tuple: + """Get available analysis types and functions.""" + return self._analysis_types, self._analysis_functions + + def set_analysis_type(self, analysis_type: str) -> bool: + """ + Set the selected analysis type. + + Args: + analysis_type: Analysis type to select + + Returns: + bool: True if successful + """ + if analysis_type in self._analysis_types: + self._selected_analysis_type = analysis_type + return True + else: + print(f"DEBUG: Invalid analysis type: {analysis_type}") + return False + + def get_analysis_functions(self, analysis_type: str) -> dict: + """ + Get available functions for an analysis type. + + Args: + analysis_type: Analysis type + + Returns: + dict: Available functions for the analysis type + """ + # In CEUS engine, analysis_functions is a flat dict of all available curve functions + # that are applicable to both 'curves' and 'curves_paramap' analysis types. + if analysis_type in self._analysis_functions and isinstance(self._analysis_functions[analysis_type], dict): + return self._analysis_functions[analysis_type] + + return self._analysis_functions + + def get_required_params(self, analysis_type: str, selected_functions: list) -> list: + """ + Get required parameters for the selected analysis. + + Args: + analysis_type: Key for the analysis type + selected_functions: List of selected function names + + Returns: + list: List of parameter names required + """ + try: + return get_required_kwargs(analysis_type, selected_functions) + except Exception as e: + print(f"Error getting required params: {e}") + return [] + + def set_analysis_data(self, analysis_data: CurvesAnalysis) -> None: + """ + Store completed analysis data. + + Args: + analysis_data: Completed analysis data + """ + self._analysis_data = analysis_data + # Signal that analysis is complete + self.analysis_completed.emit(analysis_data) + + def run_analysis(self, analysis_type: str, image_data: UltrasoundImage, + config_data: Any, seg_data: CeusSeg, + selected_functions: List[str], **kwargs) -> None: + """ + Run the analysis in a background thread. + """ + # Stop existing worker if running + if self._analysis_worker and self._analysis_worker.isRunning(): + self._analysis_worker.quit() + self._analysis_worker.wait() + + self._analysis_worker = AnalysisWorker( + analysis_type, image_data, config_data, seg_data, selected_functions, kwargs + ) + + self._analysis_worker.finished.connect(self._on_analysis_worker_finished) + self._analysis_worker.error_msg.connect(self._emit_error) + + self._set_loading(True) + self._analysis_worker.start() + + def _on_analysis_worker_finished(self, analysis_obj: Any) -> None: + """Handle analysis completion.""" + self._set_loading(False) + self._analysis_data = analysis_obj + self.analysis_completed.emit(analysis_obj) diff --git a/src/ceus/image_loading/views/file_selection_widget.py b/src/ceus/image_loading/views/file_selection_widget.py index 5b095c6..23f8d09 100644 --- a/src/ceus/image_loading/views/file_selection_widget.py +++ b/src/ceus/image_loading/views/file_selection_widget.py @@ -116,7 +116,8 @@ def _show_loading_message(self) -> None: def _on_choose_image_path(self) -> None: """Handle image file selection.""" - if self._file_extensions == ["FOLDER"]: + is_folder = any(ext.upper() == "FOLDER" for ext in self._file_extensions) + if is_folder: dir_name = QFileDialog.getExistingDirectory(self, "Select Directory") if dir_name: self._ui.image_path_input.setText(dir_name) @@ -133,12 +134,16 @@ def _on_generate_image(self) -> None: if not os.path.exists(image_path): self.show_error(f"Image file does not exist: {os.path.basename(image_path)}") return - if not image_path.endswith(tuple(self._file_extensions)) and self._file_extensions != ['FOLDER']: - self.show_error(f"Image file must have one of the following extensions: {', '.join(self._file_extensions)}") - return - if self._file_extensions == ["FOLDER"] and not os.path.isdir(image_path): - self.show_error("Input path must be a folder!") - return + + is_folder = any(ext.upper() == "FOLDER" for ext in self._file_extensions) + if not is_folder: + if not image_path.endswith(tuple(self._file_extensions)): + self.show_error(f"Image file must have one of the following extensions: {', '.join(self._file_extensions)}") + return + else: + if not os.path.isdir(image_path): + self.show_error("Input path must be a folder!") + return self.clear_error() diff --git a/src/ceus/seg_loading/seg_loading_controller.py b/src/ceus/seg_loading/seg_loading_controller.py index eab66f6..e756958 100644 --- a/src/ceus/seg_loading/seg_loading_controller.py +++ b/src/ceus/seg_loading/seg_loading_controller.py @@ -35,15 +35,15 @@ def __init__(self, model: Optional[ApplicationModel] = None, custom_view=None): super().__init__(model, view) - # # Connect to model signals for automatic view updates - # self._connect_model_signals() + # Connect to model signals for automatic view updates + self._connect_model_signals() # Initialize view with segmentation loaders self._initialize_view() - # def _connect_model_signals(self) -> None: - # """Connect to model signals for automatic view updates.""" - # self.model.segmentation_loaded.connect(self.view.show_segmentation_preview) + def _connect_model_signals(self) -> None: + """Connect to model signals for automatic view updates.""" + self.model.segmentation_loaded.connect(self.view.show_segmentation_preview) def _initialize_view(self) -> None: """Initialize the view with data from the model.""" @@ -67,7 +67,10 @@ def handle_user_action(self, action_name: str, action_data: Any) -> None: elif action_name == 'apply_preprocs_preview': self._handle_preprocs_preview(action_data) elif action_name == 'segmentation_confirmed': - pass # Handle confirmation action in the application controller + # Ensure the model has the confirmed segmentation data + # This is especially important for manually drawn segmentations + if action_data: + self.model.set_manual_segmentation(action_data) else: raise ValueError(f"Unknown action: {action_name}") @@ -132,5 +135,9 @@ def get_loaded_segmentation(self) -> CeusSeg: return self.model.seg_data def cleanup(self) -> None: - """Clean up resources.""" + """Clean up resources and disconnect signals.""" + try: + self.model.segmentation_loaded.disconnect(self.view.show_segmentation_preview) + except (TypeError, RuntimeError): + pass self.model.cleanup() diff --git a/src/ceus/seg_loading/seg_loading_view_coordinator.py b/src/ceus/seg_loading/seg_loading_view_coordinator.py index 4f6d9d5..709da9f 100644 --- a/src/ceus/seg_loading/seg_loading_view_coordinator.py +++ b/src/ceus/seg_loading/seg_loading_view_coordinator.py @@ -15,6 +15,8 @@ from .views.seg_file_selection_widget import SegFileSelectionWidget from .views.draw_roi_widget import DrawROIWidget from .views.draw_voi_widget import DrawVOIWidget +from .views.roi_preview_widget import ROIPreviewWidget +from .views.voi_preview_widget import VOIPreviewWidget from engines.ceus.src.data_objs import UltrasoundImage, CeusSeg @@ -48,10 +50,12 @@ def __init__(self, image_data: UltrasoundImage, parent: Optional[QWidget] = None self._seg_type_widget: Optional[SegTypeSelectionWidget] = None self._seg_file_widget: Optional[SegFileSelectionWidget] = None self._voi_drawing_widget: Optional[DrawVOIWidget] = None + self._roi_drawing_widget: Optional[DrawROIWidget] = None + self._roi_preview_widget: Optional[ROIPreviewWidget] = None + self._voi_preview_widget: Optional[VOIPreviewWidget] = None # Current state self._selected_seg_type: Optional[str] = None - self._seg_data: Optional[CeusSeg] = None # Start with segmentation type selection self.show_seg_type_selection() @@ -109,6 +113,9 @@ def reset_to_seg_type_selection(self) -> None: widgets_to_remove = [ self._seg_file_widget, self._voi_drawing_widget, + self._roi_drawing_widget, + self._roi_preview_widget, + self._voi_preview_widget ] for widget in widgets_to_remove: @@ -120,11 +127,11 @@ def reset_to_seg_type_selection(self) -> None: self._seg_file_widget = None self._roi_drawing_widget = None self._voi_drawing_widget = None - self._seg_preview_widget = None + self._roi_preview_widget = None + self._voi_preview_widget = None # Clear state self._selected_seg_type = None - self._seg_data = None # Return to seg type widget if self._seg_type_widget: @@ -181,6 +188,7 @@ def show_voi_drawing(self) -> None: self._voi_drawing_widget = DrawVOIWidget(self._image_data) # Connect signals to handle user actions + self._voi_drawing_widget.segmentation_saved.connect(self._on_segmentation_saved) self._voi_drawing_widget.back_requested.connect(self.reset_to_seg_type_selection) self._voi_drawing_widget.close_requested.connect(self.close_requested.emit) self._voi_drawing_widget.apply_preprocs_preview.connect(self._on_preprocs_preview_requested) @@ -191,8 +199,14 @@ def show_voi_drawing(self) -> None: def preview_modified_image(self, modified_image: UltrasoundImage, frame: int) -> None: """Show the preprocessed data in the VOI drawing widget.""" - if self._voi_drawing_widget: + if self._roi_preview_widget: + self._roi_preview_widget.update_enhancement_cache(modified_image.pixel_data, frame) + elif self._voi_preview_widget: + self._voi_preview_widget.update_enhancement_cache(modified_image.pixel_data, frame) + elif self._voi_drawing_widget: self._voi_drawing_widget.update_enhancement_cache(modified_image.pixel_data, frame) + elif self._roi_drawing_widget: + self._roi_drawing_widget.update_enhancement_cache(modified_image.pixel_data, frame) else: raise RuntimeError("VOI drawing widget not initialized") @@ -201,16 +215,77 @@ def show_roi_drawing(self) -> None: self._roi_drawing_widget = DrawROIWidget(self._image_data) # Connect signals to handle user actions + self._roi_drawing_widget.segmentation_saved.connect(self._on_segmentation_saved) self._roi_drawing_widget.back_requested.connect(self.reset_to_seg_type_selection) self._roi_drawing_widget.close_requested.connect(self.close_requested.emit) + self._roi_drawing_widget.apply_preprocs_preview.connect(self._on_preprocs_preview_requested) # Add to stack and show self.addWidget(self._roi_drawing_widget) self.setCurrentWidget(self._roi_drawing_widget) + def show_segmentation_preview(self, seg_data: CeusSeg) -> None: + """ + Show the segmentation preview widget. + + Args: + seg_data: Loaded segmentation data + """ + # Create and setup segmentation preview widget + preview_widget = None + if seg_data.seg_mask.ndim == 2: + self._roi_preview_widget = ROIPreviewWidget(self._image_data, seg_data) + preview_widget = self._roi_preview_widget + elif seg_data.seg_mask.ndim == 3: + self._voi_preview_widget = VOIPreviewWidget(self._image_data, seg_data) + preview_widget = self._voi_preview_widget + else: + raise NotImplementedError("Only 2D and 3D frames are supported") + + # Connect signals to handle user actions + preview_widget.segmentation_confirmed.connect( + lambda: self.user_action.emit('segmentation_confirmed', seg_data) + ) + preview_widget.back_requested.connect(self.back_from_preview) + preview_widget.close_requested.connect(self.close_requested.emit) + preview_widget.apply_preprocs_preview.connect(self._on_preprocs_preview_requested) + + # Add to stack and show + self.addWidget(preview_widget) + self.setCurrentWidget(preview_widget) + # ============================================================================ # USER ACTION HANDLING - Process user interactions and communicate with controller # ============================================================================ + + def back_from_preview(self): + if not self._roi_drawing_widget and not self._voi_drawing_widget: + self.reset_to_seg_type_selection() + elif self._roi_drawing_widget: + self.removeWidget(self._roi_preview_widget) + self._roi_preview_widget.deleteLater() + self._roi_preview_widget = None + self.setCurrentWidget(self._roi_drawing_widget) + elif self._voi_drawing_widget: + self.removeWidget(self._voi_preview_widget) + self._voi_preview_widget.deleteLater() + self._voi_preview_widget = None + self.setCurrentWidget(self._voi_drawing_widget) + else: + raise ValueError("Undefine state. Should never get here") + + def _on_segmentation_saved(self, file_path: str) -> None: + """ + Handle segmentation saved from the manual drawing widget. + + Args: + file_path: Path to the saved segmentation file + """ + file_data = { + 'seg_path': file_path, + 'seg_type': self._selected_seg_type + } + self._emit_user_action('load_segmentation', file_data) def _on_seg_type_selected(self, seg_type_name: str) -> None: """ diff --git a/src/ceus/seg_loading/ui/draw_voi.ui b/src/ceus/seg_loading/ui/draw_voi.ui index 15ec3cb..5416e96 100644 --- a/src/ceus/seg_loading/ui/draw_voi.ui +++ b/src/ceus/seg_loading/ui/draw_voi.ui @@ -1103,113 +1103,6 @@ - - - - - - QLabel { - font-size: 20px; - color: rgb(255, 255, 255); - background-color: rgba(255, 255, 255, 0); -} - - - VOI Alpha: - - - Qt::AutoText - - - false - - - Qt::AlignCenter - - - true - - - - - - - - - - 285 - 0 - - - - - 285 - 16777215 - - - - 24 - - - - - - - - 13 - - - - QSpinBox{ - background-color: white, -} - - - - - - - QLabel { - font-size: 17px; - color: rgb(255, 255, 255); - background-color: rgba(255, 255, 255, 0); -} - - - of - - - Qt::AlignCenter - - - Qt::NoTextInteraction - - - - - - - QLabel { - font-size: 17px; - color: rgb(255, 255, 255); - background-color: rgba(255, 255, 255, 0); -} - - - 255 - - - Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter - - - Qt::NoTextInteraction - - - - - - - diff --git a/src/ceus/seg_loading/ui/roi_preview.ui b/src/ceus/seg_loading/ui/roi_preview.ui new file mode 100644 index 0000000..a41d1c1 --- /dev/null +++ b/src/ceus/seg_loading/ui/roi_preview.ui @@ -0,0 +1,740 @@ + + + confirmRoi + + + + 0 + 0 + 1422 + 725 + + + + + 1400 + 662 + + + + Select Region of Interest + + + QWidget { + background: rgb(42, 42, 42); +} + + + + 0 + + + 0 + + + 0 + + + 0 + + + + + + + 0 + + + QLayout::SetMaximumSize + + + + + + 341 + 601 + + + + + 241 + 601 + + + + <html><head/><body><p><br/></p></body></html> + + + QWidget { + background-color: rgb(28, 0, 101); +} + + + + + 0 + 0 + 341 + 121 + + + + + 341 + 121 + + + + + 341 + 121 + + + + QFrame { + background-color: rgb(99, 0, 174); + border: 1px solid black; +} + + + QFrame::StyledPanel + + + QFrame::Raised + + + + + 70 + 0 + 191 + 51 + + + + QLabel { + font-size: 21px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight: bold; +} + + + Scan Selection: + + + Qt::AlignCenter + + + + + + -60 + 50 + 191 + 51 + + + + QLabel { + font-size: 16px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight: bold; +} + + + Scan: + + + Qt::AlignCenter + + + + + + 70 + 50 + 261 + 51 + + + + QLabel { + font-size: 14px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; +} + + + Sample filename + + + Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter + + + + + + + 0 + 120 + 341 + 121 + + + + + 341 + 121 + + + + QFrame { + background-color: rgb(99, 0, 174); + border: 1px solid black; +} + + + QFrame::StyledPanel + + + QFrame::Raised + + + + + 0 + 40 + 341 + 51 + + + + QLabel { + font-size: 21px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight: bold; +} + + + Segmentation Selection + + + Qt::AlignCenter + + + + + + + 0 + 360 + 341 + 121 + + + + + 341 + 121 + + + + + 341 + 121 + + + + QFrame { + background-color: rgb(49, 0, 124); + border: 1px solid black; +} + + + QFrame::StyledPanel + + + QFrame::Raised + + + + + 0 + 30 + 341 + 51 + + + + QLabel { + font-size: 21px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight: bold; +} + + + Results + + + Qt::AlignCenter + + + + + + + 0 + 480 + 341 + 121 + + + + + 341 + 121 + + + + + 341 + 121 + + + + QFrame { + background-color: rgb(49, 0, 124); + border: 1px solid black; +} + + + QFrame::StyledPanel + + + QFrame::Raised + + + + + 20 + 30 + 301 + 51 + + + + QLabel { + font-size: 21px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight: bold; +} + + + Visualizations + + + Qt::AlignCenter + + + + + + + 0 + 240 + 341 + 121 + + + + + 341 + 121 + + + + QFrame { + background-color: rgb(49, 0, 124); + border: 1px solid black; +} + + + QFrame::StyledPanel + + + QFrame::Raised + + + + + 0 + 30 + 341 + 51 + + + + QLabel { + font-size: 21px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight:bold; +} + + + Analysis Configuration + + + Qt::AlignCenter + + + + + + + + + + 341 + 0 + + + + + 341 + 16777215 + + + + QFrame { + background-color: rgb(28, 0, 101); +} + + + + QLayout::SetMinAndMaxSize + + + 10 + + + 10 + + + 10 + + + 10 + + + + + Qt::Vertical + + + + 20 + 40 + + + + + + + + + + + + + 10 + + + 30 + + + 10 + + + 30 + + + 10 + + + + + QLabel { + font-size: 29px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + Confirm Segmentation + + + Qt::AutoText + + + false + + + Qt::AlignCenter + + + true + + + + + + + + + + 241 + 41 + + + + + 241 + 41 + + + + QPushButton { + color: white; + font-size: 16px; + background: rgb(90, 37, 255); + border-radius: 15px; +} + + + + Back + + + true + + + false + + + + + + + + 241 + 41 + + + + + 241 + 41 + + + + QPushButton { + color: white; + font-size: 16px; + background: rgb(90, 37, 255); + border-radius: 15px; +} + + + + Confirm + + + true + + + false + + + + + + + + + + 10 + 10 + + + + + 501 + 321 + + + + + 16777215 + 16777215 + + + + QFrame::StyledPanel + + + QFrame::Raised + + + + + + + 10 + + + + + + 251 + 41 + + + + + 251 + 41 + + + + Qt::Horizontal + + + + + + + QLabel { + font-size: 15px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + 0 + + + Qt::AutoText + + + false + + + Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter + + + true + + + + + + + QLabel { + font-size: 15px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + of + + + Qt::AutoText + + + false + + + Qt::AlignCenter + + + true + + + + + + + QLabel { + font-size: 15px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + 0 + + + Qt::AutoText + + + false + + + Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter + + + true + + + + + + + QLabel { + font-size: 12px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + seconds + + + Qt::AutoText + + + false + + + Qt::AlignCenter + + + true + + + + + + + + + + + + + + diff --git a/src/ceus/seg_loading/ui/voi_preview.ui b/src/ceus/seg_loading/ui/voi_preview.ui new file mode 100644 index 0000000..c892fa2 --- /dev/null +++ b/src/ceus/seg_loading/ui/voi_preview.ui @@ -0,0 +1,1149 @@ + + + confirmVoi + + + + 0 + 0 + 1512 + 834 + + + + + 0 + 0 + + + + Confirm Segmentation + + + QWidget { + background: rgb(42, 42, 42); +} + + + + + 10 + 10 + 1351 + 951 + + + + + + + 0 + + + QLayout::SetMaximumSize + + + + + + 341 + 601 + + + + + 241 + 601 + + + + <html><head/><body><p><br/></p></body></html> + + + QWidget { + background-color: rgb(28, 0, 101); +} + + + + + 0 + 0 + 341 + 121 + + + + + 341 + 121 + + + + + 341 + 121 + + + + QFrame { + background-color: rgb(99, 0, 174); + border: 1px solid black; +} + + + QFrame::StyledPanel + + + QFrame::Raised + + + + + 70 + 0 + 191 + 51 + + + + QLabel { + font-size: 21px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight: bold; +} + + + Scan Selection: + + + Qt::AlignCenter + + + + + + -60 + 50 + 191 + 51 + + + + QLabel { + font-size: 16px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight: bold; +} + + + Scan: + + + Qt::AlignCenter + + + + + + 70 + 50 + 261 + 51 + + + + QLabel { + font-size: 14px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; +} + + + Sample filename + + + Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter + + + + + + + 0 + 120 + 341 + 121 + + + + + 341 + 121 + + + + QFrame { + background-color: rgb(99, 0, 174); + border: 1px solid black; +} + + + QFrame::StyledPanel + + + QFrame::Raised + + + + + 0 + 40 + 341 + 51 + + + + QLabel { + font-size: 21px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight: bold; +} + + + Segmentation Selection + + + Qt::AlignCenter + + + + + + + 0 + 360 + 341 + 121 + + + + + 341 + 121 + + + + + 341 + 121 + + + + QFrame { + background-color: rgb(49, 0, 124); + border: 1px solid black; +} + + + QFrame::StyledPanel + + + QFrame::Raised + + + + + 0 + 30 + 341 + 51 + + + + QLabel { + font-size: 21px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight: bold; +} + + + Results + + + Qt::AlignCenter + + + + + + + 0 + 480 + 341 + 121 + + + + + 341 + 121 + + + + + 341 + 121 + + + + QFrame { + background-color: rgb(49, 0, 124); + border: 1px solid black; +} + + + QFrame::StyledPanel + + + QFrame::Raised + + + + + 20 + 30 + 301 + 51 + + + + QLabel { + font-size: 21px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight: bold; +} + + + Visualizations + + + Qt::AlignCenter + + + + + + + 0 + 240 + 341 + 121 + + + + + 341 + 121 + + + + QFrame { + background-color: rgb(49, 0, 124); + border: 1px solid black; +} + + + QFrame::StyledPanel + + + QFrame::Raised + + + + + 0 + 30 + 341 + 51 + + + + QLabel { + font-size: 21px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight:bold; +} + + + Analysis Configuration + + + Qt::AlignCenter + + + + + + + + + + 341 + 0 + + + + + 341 + 16777215 + + + + QFrame { + background-color: rgb(28, 0, 101); +} + + + + QLayout::SetMinAndMaxSize + + + 10 + + + 10 + + + 10 + + + 10 + + + + + QLabel { background-color : rgb(42, 42, 42); color : red; } + + + Observing! + + + Qt::AlignCenter + + + + + + + QLabel { background-color : rgb(42, 42, 42); color : green; } + + + Navigating! + + + Qt::AlignCenter + + + + + + + Qt::Horizontal + + + + 40 + 20 + + + + + + + + Qt::Vertical + + + + 20 + 40 + + + + + + + + + 131 + 41 + + + + + 131 + 41 + + + + QPushButton { + color: white; + font-size: 16px; + background: rgb(90, 37, 255); + border-radius: 15px; +} + + + Show/Hide Cross + + + + + + + + 131 + 41 + + + + + 131 + 41 + + + + QPushButton { + color: white; + font-size: 16px; + background: rgb(90, 37, 255); + border-radius: 15px; +} + + + Back + + + + + + + + + + + + + + + + + 0 + 0 + + + + QLabel { + font-size: 18px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + Sagittal Plane + + + Qt::AlignCenter + + + Qt::NoTextInteraction + + + + + + + + 1 + 1 + + + + + 321 + 301 + + + + ArrowCursor + + + true + + + QFrame::Box + + + + + + + + + + 5 + + + + + + 0 + 0 + + + + QLabel { + font-size: 15px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + 0 + + + Qt::AlignCenter + + + Qt::NoTextInteraction + + + + + + + + 0 + 0 + + + + QLabel { + font-size: 15px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + of + + + Qt::AlignCenter + + + Qt::NoTextInteraction + + + + + + + + 0 + 0 + + + + QLabel { + font-size: 15px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + 0 + + + Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter + + + Qt::NoTextInteraction + + + + + + + + + + + + + + 0 + 0 + + + + QLabel { + font-size: 18px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + Axial Plane + + + Qt::AlignCenter + + + Qt::NoTextInteraction + + + + + + + + 1 + 1 + + + + + 321 + 301 + + + + ArrowCursor + + + true + + + QFrame::Box + + + + + + + + + + 5 + + + QLayout::SetDefaultConstraint + + + + + + 0 + 0 + + + + QLabel { + font-size: 15px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + 0 + + + Qt::AlignCenter + + + Qt::NoTextInteraction + + + + + + + + 0 + 0 + + + + QLabel { + font-size: 15px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + of + + + Qt::AlignCenter + + + Qt::NoTextInteraction + + + + + + + + 0 + 0 + + + + QLabel { + font-size: 15px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + 0 + + + Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter + + + Qt::NoTextInteraction + + + + + + + + + + + + + + 0 + 0 + + + + QLabel { + font-size: 18px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + Coronal Plane + + + Qt::AlignCenter + + + Qt::NoTextInteraction + + + + + + + + 1 + 1 + + + + + 321 + 301 + + + + ArrowCursor + + + true + + + QFrame::Box + + + + + + + + + + 5 + + + + + + 0 + 0 + + + + QLabel { + font-size: 15px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + 0 + + + Qt::AlignCenter + + + Qt::NoTextInteraction + + + + + + + + 0 + 0 + + + + QLabel { + font-size: 15px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + of + + + Qt::AlignCenter + + + Qt::NoTextInteraction + + + + + + + + 0 + 0 + + + + QLabel { + font-size: 15px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + 0 + + + Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter + + + Qt::NoTextInteraction + + + + + + + + + + + 10 + + + 10 + + + 20 + + + + + + + QLabel { + font-size: 20px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + Current Slice (in seconds): + + + Qt::AutoText + + + false + + + Qt::AlignCenter + + + true + + + + + + + + + + 285 + 0 + + + + + 285 + 16777215 + + + + QSlider { + color: white; +} + + + Qt::Horizontal + + + + + + + + 70 + 16777215 + + + + QDoubleSpinBox { + background: white; + color: black; +} + + + + + + + QLabel { + font-size: 17px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + of + + + Qt::AlignCenter + + + Qt::NoTextInteraction + + + + + + + QLabel { + font-size: 17px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + 0 + + + Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter + + + Qt::NoTextInteraction + + + + + + + + + + + 0 + + + 0 + + + + + + 0 + 36 + + + + QPushButton { + color: white; + font-size: 16px; + background: rgb(90, 37, 255); + border-radius: 15px; +} + + + Confirm Segmentation + + + false + + + + + + + + + + + + + + + diff --git a/src/ceus/seg_loading/views/draw_roi_widget.py b/src/ceus/seg_loading/views/draw_roi_widget.py index f72738f..95ee444 100644 --- a/src/ceus/seg_loading/views/draw_roi_widget.py +++ b/src/ceus/seg_loading/views/draw_roi_widget.py @@ -11,14 +11,24 @@ from scipy import interpolate from PIL import Image, ImageDraw from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas +from matplotlib.colors import LinearSegmentedColormap -from PyQt6.QtWidgets import QWidget, QHBoxLayout, QFileDialog +from PyQt6.QtWidgets import QWidget, QHBoxLayout, QFileDialog, QSlider, QVBoxLayout, QFrame, QCheckBox, QLabel from PyQt6.QtCore import pyqtSignal, Qt from ...mvc.base_view import BaseViewMixin from ..ui.draw_roi_ui import Ui_constructRoi from engines.ceus.src.data_objs import UltrasoundImage +# Philips CEUS Colormap: Grayscale -> Red -> Yellow +philips_colors = [ + (0.0, 0.0, 0.0), # 0% - Black + (0.4, 0.4, 0.4), # 40% - Gray + (0.8, 0.0, 0.0), # 80% - Red + (1.0, 1.0, 0.0) # 100% - Yellow +] +philips_cmap = LinearSegmentedColormap.from_list("philips_ceus", philips_colors) + class DrawROIWidget(QWidget, BaseViewMixin): """ @@ -33,6 +43,7 @@ class DrawROIWidget(QWidget, BaseViewMixin): segmentation_saved = pyqtSignal(str) # emit with saved file path back_requested = pyqtSignal() close_requested = pyqtSignal() + apply_preprocs_preview = pyqtSignal(list) # List of dicts with 'name' and 'kwargs' keys def __init__(self, image_data: UltrasoundImage, parent: Optional[QWidget] = None): QWidget.__init__(self, parent) @@ -56,11 +67,20 @@ def __init__(self, image_data: UltrasoundImage, parent: Optional[QWidget] = None self._im_artist = None # The image artist for fast updates self._roi_plot_artist = None # The ROI artist for fast updates self._roi_scatter_artist = None # The ROI scatter artist for fast updates - self._target_frame = 0 # Target frame for smooth transitions self._frame_update_pending = False + # Enhancement parameters + self._clahe_clip_limit = 1.2 + self._gamma = 1.5 self._width_scale = 1.0 + # Enhancement parameters + self._clahe_clip_limit = 1.2 + self._gamma = 1.5 + self._use_philips_ceus = False + self._enhanced_cache = None # Cache for enhanced current frame + self._enhanced_cache_idx = -1 + self._setup_ui() self._connect_signals() self._show_draw_type_selection() @@ -206,10 +226,7 @@ def _update_frame_animated(self, frame_num) -> list: if not self._frame_update_pending: return [self._im_artist, self._roi_plot_artist[0], self._roi_scatter_artist] - # Update to target frame - if self._frame != self._target_frame: - self._frame = self._target_frame - self._update_frame_display(self._frame) + self._update_frame_display(self._frame) self._update_roi_plot() self._update_roi_scatter() @@ -256,70 +273,156 @@ def _update_aspect_ratio(self) -> None: self._matplotlib_canvas.draw_idle() def _setup_enhancement_controls(self) -> None: - """Add enhancement sliders to the sidebar.""" - from PyQt6.QtWidgets import QVBoxLayout, QLabel, QSlider, QFrame - + """Add enhancement sliders beside the frame slider in a single horizontal line.""" + # Container frame for enhancement controls enh_group = QFrame() enh_group.setStyleSheet("background-color: rgba(255, 255, 255, 0); border: none;") - container_layout = QVBoxLayout(enh_group) - container_layout.setContentsMargins(0, 10, 0, 10) + + # Main horizontal layout for the enhancement section + container_layout = QHBoxLayout(enh_group) + container_layout.setContentsMargins(0, 0, 15, 0) container_layout.setSpacing(15) - def create_enh_column(label_text, min_val, max_val, current_val, callback): - col_widget = QWidget() - col_layout = QVBoxLayout(col_widget) - col_layout.setContentsMargins(0, 0, 0, 0) - col_layout.setSpacing(5) + def create_compact_control(label_text, min_val, max_val, current_val, callback): + # Widget to hold label, slider, and value in ONE line + ctrl_widget = QWidget() + ctrl_layout = QHBoxLayout(ctrl_widget) + ctrl_layout.setContentsMargins(0, 0, 0, 0) + ctrl_layout.setSpacing(5) lbl = QLabel(label_text) - lbl.setStyleSheet("font-size: 14px; color: white; font-weight: bold;") - lbl.setAlignment(Qt.AlignmentFlag.AlignCenter) - col_layout.addWidget(lbl) + lbl.setStyleSheet("font-size: 10px; color: white; font-weight: bold;") + ctrl_layout.addWidget(lbl) - row_layout = QHBoxLayout() + # Slider slider = QSlider(Qt.Orientation.Horizontal) slider.setRange(min_val, max_val) slider.setValue(current_val) - slider.setMinimumWidth(100) - slider.setMaximumWidth(120) + slider.setStyleSheet(self._ui.frame_slider.styleSheet()) + slider.setFixedWidth(70) + slider.setFixedHeight(12) slider.valueChanged.connect(callback) - + ctrl_layout.addWidget(slider) + val_lbl = QLabel(f"{current_val/10.0:.1f}") - val_lbl.setMinimumWidth(40) - val_lbl.setStyleSheet("color: #3498db; font-weight: bold; font-size: 14px;") + val_lbl.setStyleSheet("color: #3498db; font-weight: bold; font-size: 10px;") + val_lbl.setMinimumWidth(22) + val_lbl.setAlignment(Qt.AlignmentFlag.AlignLeft) + ctrl_layout.addWidget(val_lbl) - row_layout.addWidget(slider) - row_layout.addWidget(val_lbl) - col_layout.addLayout(row_layout) - - return col_widget, slider, val_lbl + return ctrl_widget, slider, val_lbl - width_col, self.width_slider, self.width_val_lbl = create_enh_column( + # Create controls + clahe_w, self.clahe_slider, self.clahe_val_lbl = create_compact_control( + "CLAHE", 1, 100, int(self._clahe_clip_limit * 10), self._on_clahe_changed + ) + gamma_w, self.gamma_slider, self.gamma_val_lbl = create_compact_control( + "GAMMA", 1, 40, int(self._gamma * 10), self._on_gamma_changed + ) + width_w, self.width_slider, self.width_val_lbl = create_compact_control( "WIDTH", 1, 50, int(self._width_scale * 10), self._on_width_changed ) - container_layout.addWidget(width_col) - - # Add to the layout below the frame slider - self._ui.side_bar_layout.addWidget(enh_group) + # Add to horizontal layout + container_layout.addWidget(clahe_w) + container_layout.addWidget(gamma_w) + container_layout.addWidget(width_w) + + # Pseudo coloring toggle nicely aligned + if not (self._image_data.pixel_data.ndim == 4 and self._image_data.pixel_data.shape[3] > 1): + # For RGB images, disable the Philips colormap option since it doesn't apply + self.philips_check = QCheckBox("Pseudo coloring") + self.philips_check.setStyleSheet("color: white; font-weight: bold; font-size: 11px;") + self.philips_check.stateChanged.connect(self._on_philips_toggled) + container_layout.addWidget(self.philips_check) + + # Add to the layout beside the frame slider (below the image) + self._ui.frameControlsLayout.insertWidget(0, enh_group) + + def _on_clahe_changed(self, value: int) -> None: + """Handle CLAHE clip limit change.""" + self._clahe_clip_limit = value / 10.0 + if hasattr(self, 'clahe_val_lbl'): + self.clahe_val_lbl.setText(f"{self._clahe_clip_limit:.1f}") + self._invalidate_enhancement_cache() + + def _on_gamma_changed(self, value: int) -> None: + """Handle gamma change.""" + self._gamma = value / 10.0 + if hasattr(self, 'gamma_val_lbl'): + self.gamma_val_lbl.setText(f"{self._gamma:.1f}") + self._invalidate_enhancement_cache() + + def _invalidate_enhancement_cache(self) -> None: + """Invalidate the enhancement cache (e.g. when parameters change).""" + self._enhanced_cache = None + self._enhanced_cache_idx = -1 + self._frame_update_pending = True # Trigger update to request new enhanced frame + + def _on_philips_toggled(self, state: int) -> None: + self._use_philips_ceus = state == Qt.CheckState.Checked.value + if self._im_artist: + new_cmap = philips_cmap if self._use_philips_ceus else 'gray' + self._im_artist.set_cmap(new_cmap) + + # # Force a call to set_array() to dirty the artist for the blitter + # self._update_frame_display(self._frame) + + # Flag the animation loop to blit the newly dirtied image on its next tick + self._frame_update_pending = True + + def _request_enhanced_frame(self, frame_2d: np.ndarray) -> np.ndarray: + """Enhance a 2D image frame using backend engine functions.""" + # Create a temporary UltrasoundImage for the current frame + temp_im = UltrasoundImage(self._image_data.scan_path) + temp_im.pixel_data = frame_2d.T[None].T.copy() # Add back time dimension for processing + temp_im.pixdim = self._image_data.pixdim + temp_im.frame_rate = self._image_data.frame_rate + + clahe_preproc_dict = { + 'name': 'enhance_clahe', + 'image_data': temp_im, + 'frame_ix': self._frame, + 'kwargs': { + 'clip_limit': self._clahe_clip_limit, + 'tile_grid_size': (8, 8), + } + } + + gamma_preproc_dict = { + 'name': 'enhance_gamma', + 'image_data': None, # signal to reuse the already CLAHE-enhanced image (all preprocs in the same batch share the same image input) + 'frame_ix': self._frame, + 'kwargs': { + 'gamma': self._gamma, + } + } + + preproc_dicts = [clahe_preproc_dict, gamma_preproc_dict] + self.apply_preprocs_preview.emit(preproc_dicts) # synchronous call to apply the enhancements and update the cache via the connected slot def _on_frame_changed(self, value: int) -> None: """Handle frame slider change with optimized performance.""" - self._target_frame = value + self._frame = value self._frame_update_pending = True - # Animation will handle the actual update efficiently + + def update_enhancement_cache(self, enhanced_frame: np.ndarray, frame: int) -> None: + """Receives enhanced frame from controller and stores it for display.""" + self._enhanced_cache = enhanced_frame.T[0].T # shape is (1, H, W) from the temp_im — take the single frame + self._enhanced_cache_idx = frame + self._frame_update_pending = True # Flag to update display on next animation tick def _update_frame_display(self, frame_index: int) -> None: - """Update the frame display with consistent parameters.""" if self._im_artist: - self._displayed_im = self._all_frames[frame_index] - self._im_artist.set_array(self._displayed_im) - self._ui.cur_frame_label.setText(str(np.round(frame_index*self._image_data.frame_rate, decimals=2))) - - def _force_frame_update(self) -> None: - """Force immediate frame update without animation (for initialization).""" - self._update_frame_display(self._frame) - self._matplotlib_canvas.draw_idle() + if self._enhanced_cache is None or self._enhanced_cache_idx != frame_index: + # synchronously update self._enhanced_cache with the new enhanced frame + # for the current index + self._request_enhanced_frame(self._all_frames[frame_index]) + self._im_artist.set_array(self._enhanced_cache) + + self._ui.cur_frame_label.setText( + str(np.round(frame_index * self._image_data.frame_rate, decimals=2)) + ) def _cleanup_animation(self): """Stop and clean up animation safely.""" @@ -351,13 +454,6 @@ def __del__(self): self._cleanup_animation() except: pass # Ignore errors during cleanup - - def _on_frame_selected(self) -> None: - """Handle frame selection confirmation.""" - # Make sure we're on the correct frame before confirming - if self._frame != self._target_frame: - self._frame = self._target_frame - self._force_frame_update() def _on_back_clicked(self) -> None: """Handle back button click.""" diff --git a/src/ceus/seg_loading/views/draw_voi_widget.py b/src/ceus/seg_loading/views/draw_voi_widget.py index 858a869..8e2b1c9 100644 --- a/src/ceus/seg_loading/views/draw_voi_widget.py +++ b/src/ceus/seg_loading/views/draw_voi_widget.py @@ -128,7 +128,7 @@ class DrawVOIWidget(QWidget, BaseViewMixin): """ # Signals for communicating with controller - file_selected = pyqtSignal(dict) # {'seg_path': str, 'seg_type': str} + segmentation_saved = pyqtSignal(str) # emit with saved file path back_requested = pyqtSignal() close_requested = pyqtSignal() apply_preprocs_preview = pyqtSignal(list) # List of dicts with 'name' and 'kwargs' keys @@ -146,6 +146,7 @@ def __init__(self, image_data: UltrasoundImage, parent: Optional[QWidget] = None self._width_scale_axial = 1.0 self._width_scale_sagittal = 1.0 self._width_scale_coronal = 1.0 + self._mask_alpha = 125 # Default alpha for mask overlay (0-255) self._use_philips_ceus = False # Cache for enhanced volume @@ -156,7 +157,6 @@ def __init__(self, image_data: UltrasoundImage, parent: Optional[QWidget] = None self._drawing_widgets = [] self._voi_decision_widgets = [] self._save_voi_widgets = [] - self._voi_alpha_widgets = [] # Crosshair / navigation state self._crosshair_active = False @@ -173,11 +173,12 @@ def __init__(self, image_data: UltrasoundImage, parent: Optional[QWidget] = None self._current_drawing_plane = None self._drawn_rois: List[Tuple[int, List[float], np.ndarray]] = [] # (plane_index, [roi_coords_xyz], roi_mask) self._roi_masks_overlap = np.zeros((self._x_len, self._y_len, self._z_len, 4), dtype=np.uint8) + self._roi_masks_overlap_1d = np.zeros((self._x_len, self._y_len, self._z_len, 4), dtype=np.uint8) # Per-plane resources (axial, sagittal, coronal) self._ax_sag_cor_matplotlib_canvases = [None, None, None] self._ax_sag_cor_planes = (None, None, None) - self._ax_sag_cor_index_maps = ((0, 1), (2, 1), (2, 0)) # dims that vary per plane + self._ax_sag_cor_index_maps = ((0, 1), (2, 1), (0, 2)) # (horiz_dim, vert_dim) self._ax_sag_cor_animations = [None, None, None] self._ax_sag_cor_plane_artists = [None, None, None] self._ax_sag_cor_crosshair_lines = [(None, None), (None, None), (None, None)] @@ -192,16 +193,17 @@ def __init__(self, image_data: UltrasoundImage, parent: Optional[QWidget] = None self._setup_ui() self._setup_matplotlib_canvases() self._initialize_plane_displays() + self._update_aspect_ratios() self._setup_all_plane_animations() self._connect_signals() self._connect_matplotlib_events() self.setFocusPolicy(Qt.FocusPolicy.StrongFocus) - self._update_scan_display() # Initial UI update - self._refresh_frames() # Mark all planes for first update + self._update_scan_display() # Initial UI update + self._refresh_frames() # Mark all planes for first update def update_enhancement_cache(self, enhanced_frame: np.ndarray, frame: int) -> None: """Update the displayed image data, e.g. after preprocessing.""" - assert enhanced_frame.shape[:-1] == self._pix_data.shape[:-1], "Enhanced pixel data must have the same shape as original" + assert enhanced_frame.shape[:-1] == self._pix_data.shape[:-1], f"Enhanced pixel data must have the same shape as original" self._enhanced_cache = enhanced_frame[:, :, :, 0] # Store only the current time frame in cache self._enhanced_cache_frame = frame self._refresh_frames() @@ -274,15 +276,11 @@ def _on_canvas_motion(self, event, plane_ix: int): # type: ignore params = {} if self._current_drawing_plane == None or self._current_drawing_plane == plane_ix+1: if self._crosshair_xyzt[dim_x] != new_xval: - if dim_x == 0: params['x'] = new_xval - elif dim_x == 1: params['y'] = new_xval - elif dim_x == 2: params['z'] = new_xval - elif dim_x == 3: params['t'] = new_xval + key = ['x', 'y', 'z', 't'][dim_x] + params[key] = new_xval if self._crosshair_xyzt[dim_y] != new_yval: - if dim_y == 0: params['x'] = new_yval - elif dim_y == 1: params['y'] = new_yval - elif dim_y == 2: params['z'] = new_yval - elif dim_y == 3: params['t'] = new_yval + key = ['x', 'y', 'z', 't'][dim_y] + params[key] = new_yval if params: self.set_crosshair(**params) @@ -324,13 +322,6 @@ def _setup_ui(self) -> None: self._ui.clear_save_folder_button, self._ui.export_voi_button, ] - self._voi_alpha_widgets = [ - self._ui.alpha_label, - self._ui.alpha_of_label, - self._ui.alpha_spin_box, - self._ui.alpha_status, - self._ui.alpha_total - ] self._ui.scan_name_input.setText(self._image_data.scan_name) self._ui.toggle_crosshair_visibility_button.setText('Hide Crosshair') @@ -339,7 +330,7 @@ def _setup_ui(self) -> None: self._ui.interp_loading_label.hide(); self._ui.saving_voi_label.hide() self._ui.navigating_label.hide(); self._ui.undo_last_roi_button.hide() self._hide_widget_lists([self._voi_decision_widgets, - self._save_voi_widgets, self._voi_alpha_widgets]) + self._save_voi_widgets]) # Setup enhancement controls in sidebar self._setup_enhancement_controls() @@ -412,10 +403,19 @@ def create_enh_column(label_text, min_val, max_val, current_val, callback): width_cor_col, self.width_cor_slider, self.width_cor_val_lbl = create_enh_column( "WIDTH (COR)", 1, 50, int(self._width_scale_coronal * 10), self._on_width_coronal_changed ) + # Add VOI Alpha slider + alpha_col, self.alpha_slider, self.alpha_val_lbl = create_enh_column( + "VOI ALPHA", 0, 2550, int(self._mask_alpha * 10), lambda v: self._on_alpha_changed(v // 10) + ) + # Fix label to show integer for alpha + self.alpha_val_lbl.setText(str(self._mask_alpha)) + self.alpha_slider.valueChanged.disconnect() + self.alpha_slider.valueChanged.connect(lambda v: (self._on_alpha_changed(v // 10), self.alpha_val_lbl.setText(str(v // 10)))) row1_layout.addWidget(clahe_col) row1_layout.addWidget(gamma_col) + # Philips CEUS Toggle (Pseudocoloring) - now in row 1 self.philips_check = QCheckBox("Pseudocoloring") self.philips_check.setStyleSheet("color: white; font-weight: bold; font-size: 14px;") @@ -426,6 +426,7 @@ def create_enh_column(label_text, min_val, max_val, current_val, callback): "Yellow → Peak Enhancement") self.philips_check.stateChanged.connect(self._on_philips_toggled) row1_layout.addWidget(self.philips_check) + row1_layout.addWidget(alpha_col) row2_layout.addWidget(width_ax_col) row2_layout.addWidget(width_sag_col) @@ -484,6 +485,16 @@ def _on_width_coronal_changed(self, value: int) -> None: self.width_cor_val_lbl.setText(f"{self._width_scale_coronal:.1f}") self._update_aspect_ratios() self._refresh_frames() + + def _on_alpha_changed(self, value: int) -> None: + """Handle alpha transparency change for the VOI mask.""" + self._mask_alpha = int(value) + + # Update the master overlap mask by blending all stored ROIs + self._roi_masks_overlap[self._roi_masks_overlap_1d, 3] = self._mask_alpha + + self._current_drawing_plane = None + self._refresh_frames() def _update_aspect_ratios(self) -> None: """Update the aspect ratios of the axes based on the plane-specific width scales.""" @@ -511,7 +522,7 @@ def _update_aspect_ratios(self) -> None: # Index 2: Coronal (Plane 2) if self._ax_sag_cor_matplotlib_canvases[2]: dx, dz = pix[0], pix[2] - aspect = (dx / dz if dz != 0 else 1.0) * self._width_scale_coronal + aspect = (dz / dx if dx != 0 else 1.0) * self._width_scale_coronal fig2 = self._ax_sag_cor_matplotlib_canvases[2].figure if fig2.axes: fig2.axes[0].set_aspect(aspect) @@ -537,21 +548,18 @@ def _invalidate_enhancement_cache(self) -> None: def _setup_matplotlib_canvases(self): """Setup matplotlib canvases for high-performance plane display.""" - for i in range(3): - fig = plt.figure() - fig.patch.set_facecolor((0, 0, 0, 0)) + for i, parent_label in enumerate(self._ax_sag_cor_planes): + fig, ax = plt.subplots(facecolor='black') fig.subplots_adjust(left=0, right=1, top=1, bottom=0) + ax.axis('off') canvas = FigureCanvas(fig) - canvas.figure.patch.set_facecolor((0, 0, 0, 0)) + canvas.setParent(parent_label) canvas.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + # Clear label text but keep the background/frame + parent_label.setText("") self._ax_sag_cor_matplotlib_canvases[i] = canvas - layout = QHBoxLayout(self._ax_sag_cor_planes[i]) - layout.setContentsMargins(0, 0, 0, 0) - layout.addWidget(canvas, stretch=1) - self._ax_sag_cor_planes[i].setLayout(layout) - # Make canvas expand to fill its QLabel container # Install event filter on parent label for resize handling - self._ax_sag_cor_planes[i].installEventFilter(self) + parent_label.installEventFilter(self) # Initial sizing pass self._resize_all_canvases() @@ -561,39 +569,23 @@ def _initialize_plane_displays(self) -> None: if not canvas: continue try: - fig = canvas.figure - if plane_ix == 0: # Axial - dy = self._image_data.pixdim[1] - dx = self._image_data.pixdim[0] - base_aspect = dy / dx if dx != 0 else 1 - aspect = base_aspect * self._width_scale_axial - elif plane_ix == 1: # Sagittal - dy = self._image_data.pixdim[1] - dz = self._image_data.pixdim[2] - base_aspect = dy / dz if dz != 0 else 1 - aspect = base_aspect * self._width_scale_sagittal - elif plane_ix == 2: # Coronal - dx = self._image_data.pixdim[0] - dz = self._image_data.pixdim[2] - base_aspect = dx / dz if dz != 0 else 1 - aspect = base_aspect * self._width_scale_coronal - else: - self.show_error(f"Invalid plane index: {plane_ix}") - fig.clear() - ax = fig.add_subplot(111) - ax.axis('off') + # Reuse the axes already created by _setup_matplotlib_canvases + # (do NOT call fig.clear() — that destroys the padding-free subplots_adjust) + ax = canvas.figure.axes[0] + # Get initial slice slice_arr = self._get_plane_slice(plane_ix, initializing=True) mask_arr = self._get_mask_slice(plane_ix) current_cmap = philips_cmap if self._use_philips_ceus else 'gray' - artist = ax.imshow(slice_arr, cmap=current_cmap, aspect=float(aspect), zorder=1, animated=True) # add vmin and vmax for the 0 - 255 show + artist = ax.imshow(slice_arr, cmap=current_cmap, interpolation='nearest', + zorder=1, vmin=0, vmax=255) v_line = ax.axvline(x=0, color='yellow', lw=0.8, animated=True, zorder=11) h_line = ax.axhline(y=0, color='yellow', lw=0.8, animated=True, zorder=11) - seg_mask = ax.imshow(mask_arr, zorder=8, aspect=float(aspect), animated=True) + seg_mask = ax.imshow(mask_arr, interpolation='nearest', zorder=8) roi_plot = ax.plot([], [], c='cyan', lw=1, zorder=9, animated=True) point_scatter = ax.scatter([], [], c='red', s=5, marker='o', zorder=10, animated=True) - + self._ax_sag_cor_plane_artists[plane_ix] = artist self._ax_sag_cor_crosshair_lines[plane_ix] = (v_line, h_line) self._ax_sag_cor_point_scatters[plane_ix] = point_scatter @@ -606,31 +598,33 @@ def _initialize_plane_displays(self) -> None: self.show_error(f"Error initializing plane display {plane_ix}: {e}") def _get_plane_slice(self, plane_ix: int, initializing=False): - """Return 2D numpy slice for given plane index based on current crosshair.""" - idx = self._get_plane_indices(plane_ix) - current_t = self._crosshair_xyzt[3] + """Extract a 2D image slice for the specified plane at current crosshair indices.""" + x, y, z, t = self._crosshair_xyzt # Check if we need to enhance a new frame - if not initializing and (self._enhanced_cache is None or self._enhanced_cache_frame != current_t): + if not initializing and (self._enhanced_cache is None or self._enhanced_cache_frame != t): # Get the 3D volume for current time frame - current_frame_3d = self._pix_data[:, :, :, current_t] + current_frame_3d = self._pix_data[:, :, :, t] # Enhance the entire 3D volume ONCE per frame self._enhance_volume(current_frame_3d) # performs enhancement SYNCHRONOUSLY - self._enhanced_cache_frame = current_t + self._enhanced_cache_frame = t elif initializing: - self._enhanced_cache = self._image_data.pixel_data[:, :, :, current_t] # Cache the initial frame without enhancement for faster startup - - # Extract the 2D slice from cached enhanced volume - slice_idx = list(idx[:3]) # Remove time dimension - arr = self._enhanced_cache[tuple(slice_idx)] + self._enhanced_cache = self._image_data.pixel_data[:, :, :, t] # Cache the initial frame without enhancement for faster startup + + vol = self._enhanced_cache - if arr.ndim != 2: - arr = arr.squeeze() - # Axial plane (index 0) needs transpose for correct orientation - if plane_ix == 0: - arr = arr.T - return arr + if plane_ix == 0: # Axial (XY) at Z -> show (Y, X) + return vol[:, :, z].T + elif plane_ix == 1: # Sagittal (YZ) at X -> show (Z, Y) then rotate 90 CW -> (Y, Z) + # Match DrawVOIWidget approach: arr.T then rot90(k=-1) + arr = vol[x, :, :] + arr_t = arr + return arr_t + elif plane_ix == 2: # Coronal (XZ) at Y -> show (Y, X) + # Mirror Axial for Coronal to match DrawVOI behavior + return vol[:, y, :].T + return np.zeros((10, 10)) def _enhance_volume(self, volume_3d: np.ndarray) -> None: """Enhance a 3D image volume using predefined enhancement methods in the backend engine.""" @@ -664,12 +658,18 @@ def _enhance_volume(self, volume_3d: np.ndarray) -> None: def _get_mask_slice(self, plane_ix: int): """Return RGBA numpy slice for the mask of the given plane index.""" - idx = self._get_plane_indices(plane_ix)[:-1] # no time dimension - arr = self._roi_masks_overlap[idx] - # Mask needs transpose for correct orientation to match the image slice - if plane_ix == 0: - arr = np.transpose(arr, (1, 0, 2)) # Transpose for axial plane - return arr + x, y, z, _ = self._crosshair_xyzt + if plane_ix == 0: # Axial (XY) at Z -> show (Y, X) + arr = self._roi_masks_overlap[:, :, z, :] + return np.transpose(arr, (1, 0, 2)) + elif plane_ix == 1: # Sagittal (YZ) at X -> show (Z, Y) then rotate 90 CW -> (Y, Z) + arr = self._roi_masks_overlap[x, :, :, :] + # arr_t = np.transpose(arr, (1, 0, 2)) + return arr + elif plane_ix == 2: # Coronal (XZ) at Y -> show (Y, X) + arr = self._roi_masks_overlap[:, y, :, :] + return np.transpose(arr, (1, 0, 2)) + return np.zeros((10, 10, 4), dtype=np.uint8) def _get_plane_indices(self, plane_ix: int) -> Tuple[int]: """Return a list of indices for the given plane.""" @@ -950,11 +950,11 @@ def _on_roi_close(self): elif plane_ix == 1: # Sagittal target_slice_mask[fixed_val, :, :] = mask_2d elif plane_ix == 2: # Coronal - target_slice_mask[:, fixed_val, :] = mask_2d + target_slice_mask[:, fixed_val, :] = mask_2d.T # Apply colors to the RGBA mask where the 3D mask is true current_roi_mask_rgba[target_slice_mask, 0] = 255 # Red - current_roi_mask_rgba[target_slice_mask, 3] = 128 # Alpha + current_roi_mask_rgba[target_slice_mask, 3] = self._mask_alpha # Alpha # Store the original points and the generated mask self._drawn_rois.append((self._current_drawing_plane, current_roi_pts, current_roi_mask_rgba)) @@ -964,8 +964,8 @@ def _on_roi_close(self): for _, _, roi_mask in self._drawn_rois: # Add color channels, clipping at 255 self._roi_masks_overlap[:,:,:,:3] = np.clip(self._roi_masks_overlap[:,:,:,:3].astype(np.uint16) + roi_mask[:,:,:,:3].astype(np.uint16), 0, 255).astype(np.uint8) - # Add alpha, clipping at a reasonable max to avoid full opacity - self._roi_masks_overlap[:,:,:,3] = np.clip(self._roi_masks_overlap[:,:,:,3].astype(np.uint16) + roi_mask[:,:,:,3].astype(np.uint16), 0, 128).astype(np.uint8) + self._roi_masks_overlap[:,:,:,3] = np.maximum(self._roi_masks_overlap[:,:,:,3].astype(np.uint16), roi_mask[:,:,:,3].astype(np.uint16)).astype(np.uint8) + self._roi_masks_overlap_1d = self._roi_masks_overlap[..., 0].astype(bool) # Clear points and hide the ROI plot for the next ROI self._plotted_pts.clear() @@ -995,6 +995,7 @@ def _on_undo_last_roi(self): self._roi_masks_overlap[:,:,:,:3] = np.clip(self._roi_masks_overlap[:,:,:,:3].astype(np.uint16) + roi_mask[:,:,:,:3].astype(np.uint16), 0, 255).astype(np.uint8) # Add alpha, clipping at a reasonable max to avoid full opacity self._roi_masks_overlap[:,:,:,3] = np.clip(self._roi_masks_overlap[:,:,:,3].astype(np.uint16) + roi_mask[:,:,:,3].astype(np.uint16), 0, 128).astype(np.uint8) + self._roi_masks_overlap_1d = self._roi_masks_overlap[..., 0].astype(bool) # Hide the button if no ROIs are left to undo if not self._drawn_rois: @@ -1016,6 +1017,7 @@ def _on_restart_voi(self): # Reset the drawing state self._drawn_rois.clear() self._roi_masks_overlap.fill(0) + self._roi_masks_overlap_1d.fill(0) self._plotted_pts.clear() self._current_drawing_plane = None @@ -1027,7 +1029,7 @@ def _on_restart_voi(self): def _on_save_voi_clicked(self): self._hide_widget_lists([self._voi_decision_widgets]) - self._show_widget_lists([self._save_voi_widgets, self._voi_alpha_widgets]) + self._show_widget_lists([self._save_voi_widgets]) self._refresh_frames() def _on_export_voi_clicked(self): @@ -1043,7 +1045,9 @@ def _on_export_voi_clicked(self): def _on_save_voi_finished(self, msg): self._ui.saving_voi_label.hide() self._show_widget_lists([self._save_voi_widgets]) - print(msg) + # print(msg) + if hasattr(self, '_last_saved_path'): + self.segmentation_saved.emit(str(self._last_saved_path)) def _on_save_voi_error(self, err): self._ui.saving_voi_label.hide() @@ -1181,7 +1185,11 @@ def _resize_canvas_for(self, label_widget: QLabel): if not canvas: return - canvas.figure.tight_layout(pad=0) + # Match canvas size to the QLabel/placeholder size so it fills edge-to-edge + canvas_width = label_widget.width() + canvas_height = label_widget.height() + canvas.setFixedSize(canvas_width, canvas_height) + canvas.move(0, 0) canvas.draw_idle() def _resize_all_canvases(self): @@ -1272,11 +1280,12 @@ def _save_voi(self): out_name = out_name + '.nii.gz' if not out_name.endswith('.nii.gz') else out_name out_path = Path(self._ui.save_folder_input.text()) / out_name + self._last_saved_path = out_path affine = np.eye(4) for i, res in enumerate(self._image_data.pixdim[:3]): affine[i, i] = res - voi_mask = np.array(self._roi_masks_overlap[:, :, :, 0] / 255.0).astype(np.uint8) + voi_mask = self._roi_masks_overlap_1d.astype(np.uint8) niiarray = nib.Nifti1Image(voi_mask, affine) niiarray.header["descrip"] = self._image_data.scan_name nib.save(niiarray, out_path) @@ -1296,7 +1305,8 @@ def _on_interpolation_finished(self, voi_mask: np.ndarray): # Update the master overlap mask with the new 3D VOI self._roi_masks_overlap.fill(0) self._roi_masks_overlap[voi_mask, 0] = 255 # Red - self._roi_masks_overlap[voi_mask, 3] = 128 # Alpha + self._roi_masks_overlap[voi_mask, 3] = self._mask_alpha # Alpha + self._roi_masks_overlap_1d = voi_mask self._set_interp_loading(False) self._refresh_frames() diff --git a/src/ceus/seg_loading/views/roi_preview_widget.py b/src/ceus/seg_loading/views/roi_preview_widget.py new file mode 100644 index 0000000..3d87a9b --- /dev/null +++ b/src/ceus/seg_loading/views/roi_preview_widget.py @@ -0,0 +1,418 @@ +""" +Segmentation Preview Widget for Segmentation Loading +""" + +from typing import Optional +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.animation as anim +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas +from matplotlib.colors import LinearSegmentedColormap + +from PyQt6.QtWidgets import QWidget, QHBoxLayout, QSlider, QFrame, QCheckBox, QLabel +from PyQt6.QtCore import pyqtSignal, Qt + +from ...mvc.base_view import BaseViewMixin +from ..ui.roi_preview_ui import Ui_confirmRoi +from engines.ceus.src.data_objs import UltrasoundImage, CeusSeg + +# Philips CEUS Colormap: Grayscale -> Red -> Yellow +philips_colors = [ + (0.0, 0.0, 0.0), # 0% - Black + (0.4, 0.4, 0.4), # 40% - Gray + (0.8, 0.0, 0.0), # 80% - Red + (1.0, 1.0, 0.0) # 100% - Yellow +] +philips_cmap = LinearSegmentedColormap.from_list("philips_ceus", philips_colors) + + +class ROIPreviewWidget(QWidget, BaseViewMixin): + """ + Widget for previewing and confirming segmentation. + + This is the final step in the segmentation loading process where users + can preview the loaded segmentation and confirm it before proceeding. + Designed to be used within the main application widget stack. + """ + + # Signals for communicating with controller + segmentation_confirmed = pyqtSignal(object) + back_requested = pyqtSignal() + close_requested = pyqtSignal() + apply_preprocs_preview = pyqtSignal(list) # List of dicts with 'name' and 'kwargs' keys + + def __init__(self, image_data: UltrasoundImage, seg_data: CeusSeg, + parent: Optional[QWidget] = None): + QWidget.__init__(self, parent) + self.__init_base_view__(parent) + self._ui = Ui_confirmRoi() + self._image_data = image_data + self._seg_data = seg_data + self._matplotlib_canvas: Optional[FigureCanvas] = None + self._frame = 0 + self._all_frames = self._image_data.pixel_data + + # Animation and performance variables + self._animation: Optional[anim.FuncAnimation] = None + self._im_artist = None # The image artist for fast updates + self._roi_mask_artist = None # The ROI artist for fast updates + self._frame_update_pending = False + + # Enhancement parameters + self._clahe_clip_limit = 1.2 + self._gamma = 1.5 + self._width_scale = 1.0 + + # Enhancement parameters + self._clahe_clip_limit = 1.2 + self._gamma = 1.5 + self._alpha = 125 + self._use_philips_ceus = False + self._enhanced_cache = None # Cache for enhanced current frame + self._enhanced_cache_idx = -1 + + self._setup_ui() + self._connect_signals() + + def _setup_ui(self) -> None: + """Setup the user interface.""" + self._ui.setupUi(self) + + # Configure layout for segmentation preview only - use the main layout + self.setLayout(self._ui.main_layout) + + # Configure stretch factors for confirmation + self._ui.full_screen_layout.setStretchFactor(self._ui.side_bar_layout, 1) + self._ui.full_screen_layout.setStretchFactor(self._ui.frame_preview_layout, 10) + + # Ensure the layout fills the entire widget + self._ui.main_layout.setContentsMargins(0, 0, 0, 0) + self._ui.main_layout.setSpacing(0) + self._ui.full_screen_layout.setContentsMargins(0, 0, 0, 0) + self._ui.full_screen_layout.setSpacing(0) + + # Update UI to reflect inputted image and frames + self._ui.scan_name_input.setText(self._image_data.scan_name) + self._ui.frame_slider.setRange(0, self._all_frames.shape[0] - 1) + self._ui.frame_slider.setValue(self._frame) + self._ui.cur_frame_label.setText(str(np.round(self._frame*self._image_data.frame_rate, decimals=2))) + self._ui.total_frames_label.setText(str(np.round(self._all_frames.shape[0]*self._image_data.frame_rate, decimals=2))) + + # Setup matplotlib canvas for frame preview + self._setup_matplotlib_canvas() + self._setup_enhancement_controls() + + # Display frame preview + self._initialize_frame_preview() + + def _setup_matplotlib_canvas(self) -> None: + """Setup matplotlib canvas for high-performance frame display.""" + # Create matplotlib figure and canvas with optimized settings + fig = plt.figure(figsize=(8, 6)) + self._matplotlib_canvas = FigureCanvas(fig) + self._matplotlib_canvas.figure.patch.set_facecolor((0, 0, 0, 0)) + self._matplotlib_canvas.draw() + + # Add canvas to the preview frame widget + layout = QHBoxLayout(self._ui.im_display_frame) + layout.addWidget(self._matplotlib_canvas) + self._ui.im_display_frame.setLayout(layout) + + def _connect_signals(self) -> None: + """Connect UI signals to internal handlers.""" + self._ui.frame_slider.valueChanged.connect(self._on_frame_changed) + self._ui.back_button.clicked.connect(self._on_back_clicked) + self._ui.confirm_roi_button.clicked.connect(self.segmentation_confirmed.emit) + + def _initialize_frame_preview(self) -> None: + """Initialize the frame preview with optimized matplotlib setup.""" + if not self._matplotlib_canvas: + return + + # Calculate aspect ratio + width = self._all_frames.shape[2] * self._image_data.pixdim[1] + height = self._all_frames.shape[1] * self._image_data.pixdim[0] + self.aspect = width / height + + try: + fig = self._matplotlib_canvas.figure + fig.clear() + self._ax = fig.add_subplot(111) + self._ax.set_position([0, 0, 1, 1]) + self._ax.axis("off") + + # Create the initial image artist - this will be reused for all frames + self._displayed_im = self._all_frames[self._frame] + self._seg_mask = np.zeros(self._all_frames.shape[1:-1] + (4,)) + self._seg_mask[..., 3] = self._seg_data.seg_mask * self._alpha + self._seg_mask[..., 0] = self._seg_data.seg_mask * 255 + self._seg_mask = self._seg_mask.astype(int) + self._im_artist = self._ax.imshow(self._displayed_im, cmap="gray", animated=True, zorder=1) + self._roi_mask_artist = self._ax.imshow(self._seg_mask, zorder=10) + + # Set proper aspect ratio + extent = self._im_artist.get_extent() + self._ax.set_aspect(abs((extent[1]-extent[0])/(extent[3]-extent[2]))/self.aspect) + + # Setup the animation for smooth frame updates + self._setup_frame_animation() + + # Initial draw + self._matplotlib_canvas.draw() + + except Exception as e: + self.show_error(f"Error displaying image: {e}") + + def _setup_frame_animation(self) -> None: + """Setup FuncAnimation for high-performance frame updates.""" + if self._animation: + self._animation.event_source.stop() + + def init(): + # Return all artists that will be animated + return [self._im_artist, self._roi_mask_artist] + + self._animation = anim.FuncAnimation( + self._matplotlib_canvas.figure, + self._update_frame_animated, + init_func=init, + interval=16, # ~60 FPS + blit=True, + repeat=False, + cache_frame_data=False + ) + + def _update_frame_animated(self, frame_num) -> list: + """Animation update function for smooth frame transitions.""" + if not self._frame_update_pending: + return [self._im_artist, self._roi_mask_artist] + + self._update_frame_display(self._frame) + self._frame_update_pending = False + return [self._im_artist, self._roi_mask_artist] + + def _on_width_changed(self, value: int) -> None: + """Handle width scale change.""" + self._width_scale = value / 10.0 + if hasattr(self, 'width_val_lbl'): + self.width_val_lbl.setText(f"{self._width_scale:.1f}") + self._update_aspect_ratio() + + def _update_aspect_ratio(self) -> None: + """Update the aspect ratio of the main axes based on width scale.""" + if not hasattr(self, '_ax') or self._ax is None: + return + + # Calculate base physical aspect ratio + width_phys = self._all_frames.shape[2] * self._image_data.pixdim[1] * self._width_scale + height_phys = self._all_frames.shape[1] * self._image_data.pixdim[0] + + if height_phys != 0: + new_aspect = width_phys / height_phys + extent = self._im_artist.get_extent() + self._ax.set_aspect(abs((extent[1]-extent[0])/(extent[3]-extent[2]))/new_aspect) + self._matplotlib_canvas.draw_idle() + + def _setup_enhancement_controls(self) -> None: + """Add enhancement sliders beside the frame slider in a single horizontal line.""" + # Container frame for enhancement controls + enh_group = QFrame() + enh_group.setStyleSheet("background-color: rgba(255, 255, 255, 0); border: none;") + + # Main horizontal layout for the enhancement section + container_layout = QHBoxLayout(enh_group) + container_layout.setContentsMargins(0, 0, 15, 0) + container_layout.setSpacing(15) + + def create_compact_control(label_text, min_val, max_val, current_val, callback): + # Widget to hold label, slider, and value in ONE line + ctrl_widget = QWidget() + ctrl_layout = QHBoxLayout(ctrl_widget) + ctrl_layout.setContentsMargins(0, 0, 0, 0) + ctrl_layout.setSpacing(5) + + lbl = QLabel(label_text) + lbl.setStyleSheet("font-size: 10px; color: white; font-weight: bold;") + ctrl_layout.addWidget(lbl) + + # Slider + slider = QSlider(Qt.Orientation.Horizontal) + slider.setRange(min_val, max_val) + slider.setValue(current_val) + slider.setStyleSheet(self._ui.frame_slider.styleSheet()) + slider.setFixedWidth(70) + slider.setFixedHeight(12) + slider.valueChanged.connect(callback) + ctrl_layout.addWidget(slider) + + val_lbl = QLabel(f"{current_val/10.0:.1f}") + val_lbl.setStyleSheet("color: #3498db; font-weight: bold; font-size: 10px;") + val_lbl.setMinimumWidth(22) + val_lbl.setAlignment(Qt.AlignmentFlag.AlignLeft) + ctrl_layout.addWidget(val_lbl) + + return ctrl_widget, slider, val_lbl + + # Create controls + clahe_w, self.clahe_slider, self.clahe_val_lbl = create_compact_control( + "CLAHE", 1, 100, int(self._clahe_clip_limit * 10), self._on_clahe_changed + ) + gamma_w, self.gamma_slider, self.gamma_val_lbl = create_compact_control( + "GAMMA", 1, 40, int(self._gamma * 10), self._on_gamma_changed + ) + width_w, self.width_slider, self.width_val_lbl = create_compact_control( + "WIDTH", 1, 50, int(self._width_scale * 10), self._on_width_changed + ) + alpha, self.alpha_slider, self.alpha_val_lbl = create_compact_control( + "ALPHA", 0, 255, 0, self._on_alpha_changed + ) + + self.alpha_slider.setRange(0, 255) + self.alpha_slider.setValue(self._alpha) + + # Add to horizontal layout + container_layout.addWidget(clahe_w) + container_layout.addWidget(gamma_w) + container_layout.addWidget(width_w) + container_layout.addWidget(alpha) + + # Pseudo coloring toggle nicely aligned + if not (self._image_data.pixel_data.ndim == 4 and self._image_data.pixel_data.shape[3] > 1): + # For RGB images, disable the Philips colormap option since it doesn't apply + self.philips_check = QCheckBox("Pseudo coloring") + self.philips_check.setStyleSheet("color: white; font-weight: bold; font-size: 11px;") + self.philips_check.stateChanged.connect(self._on_philips_toggled) + container_layout.addWidget(self.philips_check) + + # Add to the layout beside the frame slider (below the image) + self._ui.frameControlsLayout.insertWidget(0, enh_group) + + def _on_clahe_changed(self, value: int) -> None: + """Handle CLAHE clip limit change.""" + self._clahe_clip_limit = value / 10.0 + if hasattr(self, 'clahe_val_lbl'): + self.clahe_val_lbl.setText(f"{self._clahe_clip_limit:.1f}") + self._invalidate_enhancement_cache() + + def _on_gamma_changed(self, value: int) -> None: + """Handle gamma change.""" + self._gamma = value / 10.0 + if hasattr(self, 'gamma_val_lbl'): + self.gamma_val_lbl.setText(f"{self._gamma:.1f}") + self._invalidate_enhancement_cache() + + def _on_alpha_changed(self, value: int) -> None: + self._alpha = int(value) + if hasattr(self, 'alpha_val_lbl'): + self.alpha_val_lbl.setText(f"{self._alpha}") + if not hasattr(self, '_seg_mask'): + return + self._seg_mask[..., 3] = (self._seg_data.seg_mask * self._alpha).astype(int) + self._frame_update_pending = True + + def _invalidate_enhancement_cache(self) -> None: + """Invalidate the enhancement cache (e.g. when parameters change).""" + self._enhanced_cache = None + self._enhanced_cache_idx = -1 + self._frame_update_pending = True # Trigger update to request new enhanced frame + + def _on_philips_toggled(self, state: int) -> None: + self._use_philips_ceus = state == Qt.CheckState.Checked.value + if self._im_artist: + new_cmap = philips_cmap if self._use_philips_ceus else 'gray' + self._im_artist.set_cmap(new_cmap) + + # # Force a call to set_array() to dirty the artist for the blitter + # self._update_frame_display(self._frame) + + # Flag the animation loop to blit the newly dirtied image on its next tick + self._frame_update_pending = True + + def _request_enhanced_frame(self, frame_2d: np.ndarray) -> np.ndarray: + """Enhance a 2D image frame using backend engine functions.""" + # Create a temporary UltrasoundImage for the current frame + temp_im = UltrasoundImage(self._image_data.scan_path) + temp_im.pixel_data = frame_2d.T[None].T.copy() # Add back time dimension for processing + temp_im.pixdim = self._image_data.pixdim + temp_im.frame_rate = self._image_data.frame_rate + + clahe_preproc_dict = { + 'name': 'enhance_clahe', + 'image_data': temp_im, + 'frame_ix': self._frame, + 'kwargs': { + 'clip_limit': self._clahe_clip_limit, + 'tile_grid_size': (8, 8), + } + } + + gamma_preproc_dict = { + 'name': 'enhance_gamma', + 'image_data': None, # signal to reuse the already CLAHE-enhanced image (all preprocs in the same batch share the same image input) + 'frame_ix': self._frame, + 'kwargs': { + 'gamma': self._gamma, + } + } + + preproc_dicts = [clahe_preproc_dict, gamma_preproc_dict] + self.apply_preprocs_preview.emit(preproc_dicts) # synchronous call to apply the enhancements and update the cache via the connected slot + + def _on_frame_changed(self, value: int) -> None: + """Handle frame slider change with optimized performance.""" + self._frame = value + self._frame_update_pending = True + + def update_enhancement_cache(self, enhanced_frame: np.ndarray, frame: int) -> None: + """Receives enhanced frame from controller and stores it for display.""" + self._enhanced_cache = enhanced_frame.T[0].T # shape is (1, H, W) from the temp_im — take the single frame + self._enhanced_cache_idx = frame + self._frame_update_pending = True # Flag to update display on next animation tick + + def _update_frame_display(self, frame_index: int) -> None: + if self._im_artist: + if self._enhanced_cache is None or self._enhanced_cache_idx != frame_index: + # synchronously update self._enhanced_cache with the new enhanced frame + # for the current index + self._request_enhanced_frame(self._all_frames[frame_index]) + self._im_artist.set_array(self._enhanced_cache) + self._roi_mask_artist.set_array(self._seg_mask) + + self._ui.cur_frame_label.setText( + str(np.round(frame_index * self._image_data.frame_rate, decimals=2)) + ) + + def _cleanup_animation(self): + """Stop and clean up animation safely.""" + if self._animation: + try: + self._animation.event_source.stop() + self._animation = None + except: + # Ignore errors if already destroyed + self._animation = None + + def closeEvent(self, event) -> None: + """Clean up animation when widget is closed.""" + self._cleanup_animation() + super().closeEvent(event) + + def hideEvent(self, event): + """Clean up animation when widget is hidden.""" + self._cleanup_animation() + + def showEvent(self, event): + """Restart animation when widget is shown.""" + if self._im_artist and not self._animation: + self._setup_frame_animation() + + def __del__(self): + """Ensure animation is cleaned up when object is destroyed.""" + try: + self._cleanup_animation() + except: + pass # Ignore errors during cleanup + + def _on_back_clicked(self) -> None: + """Handle back button click.""" + self.back_requested.emit() diff --git a/src/ceus/seg_loading/views/spline.py b/src/ceus/seg_loading/views/spline.py index 65bf06e..e65263a 100644 --- a/src/ceus/seg_loading/views/spline.py +++ b/src/ceus/seg_loading/views/spline.py @@ -17,10 +17,25 @@ def calculateSpline(xpts, ypts, zpts=None): # 2D spline interpolation cv.append([xpts[i], ypts[i], zpts[i]]) else: cv.append([xpts[i], ypts[i]]) + + # Remove duplicate points which cause "ValueError: Invalid inputs" in splprep cv = np.array(cv) - if len(xpts) == 2: + if len(cv) > 1: + # Calculate distances between consecutive points + diffs = np.diff(cv, axis=0) + dists = np.sqrt(np.sum(diffs**2, axis=1)) + # Keep first point and points that are sufficiently far from the previous one + mask = np.concatenate(([True], dists > 1e-5)) + cv = cv[mask] + + if len(cv) < 2: + if zpts is not None: + return np.array([cv[0][0]]), np.array([cv[0][1]]), np.array([cv[0][2]]) + return np.array([cv[0][0]]), np.array([cv[0][1]]) + + if len(cv) == 2: tck, _ = interpolate.splprep(cv.T, s=0.0, k=1) - elif len(xpts) == 3: + elif len(cv) == 3: tck, _ = interpolate.splprep(cv.T, s=0.0, k=2) else: tck, _ = interpolate.splprep(cv.T, s=0.0, k=3) @@ -54,6 +69,11 @@ def ellipsoidFitLS(pos): def calculateSpline3D(points): + # If the points have a 4th dimension (time), we strip it as pyvista expects (x, y, z) + points = np.array(points) + if points.shape[1] == 4: + points = points[:, :3] + cloud = pv.PolyData(points, force_float=False) volume = cloud.delaunay_3d(alpha=100) shell = volume.extract_geometry() # type: ignore diff --git a/src/ceus/seg_loading/views/voi_preview_widget.py b/src/ceus/seg_loading/views/voi_preview_widget.py new file mode 100644 index 0000000..32443d3 --- /dev/null +++ b/src/ceus/seg_loading/views/voi_preview_widget.py @@ -0,0 +1,771 @@ +""" +Segmentation Preview Widget for CEUS +""" + +from typing import Optional, Tuple, List, Dict, Any +from pathlib import Path +from matplotlib.backends.backend_qtagg import FigureCanvas +from matplotlib.colors import LinearSegmentedColormap +from PyQt6.QtWidgets import QWidget, QLabel, QHBoxLayout, QSizePolicy, QSlider, QVBoxLayout, QFrame, QCheckBox, QPushButton, QFileDialog +from PyQt6.QtCore import QEvent, pyqtSignal, Qt + +import numpy as np +import nibabel as nib +import matplotlib.pyplot as plt +import matplotlib.animation as anim + +from ...mvc.base_view import BaseViewMixin +from ..ui.voi_preview_ui import Ui_confirmVoi +from engines.ceus.src.data_objs import UltrasoundImage, CeusSeg + +# Philips CEUS Colormap: Grayscale -> Red -> Yellow +philips_colors = [ + (0.0, 0.0, 0.0), # 0% - Black + (0.4, 0.4, 0.4), # 40% - Gray + (0.8, 0.0, 0.0), # 80% - Red + (1.0, 1.0, 0.0) # 100% - Yellow +] +philips_cmap = LinearSegmentedColormap.from_list("philips_ceus", philips_colors) + +class VOIPreviewWidget(QWidget, BaseViewMixin): + """ + Widget for previewing and confirming segmentation for CEUS. + Reuses UI components from VOI drawer but in a read-only preview mode. + """ + + # Signals for communicating with controller + segmentation_confirmed = pyqtSignal() + back_requested = pyqtSignal() + close_requested = pyqtSignal() + apply_preprocs_preview = pyqtSignal(list) # List of dicts with 'name' and 'kwargs' keys + + def __init__(self, image_data: UltrasoundImage, seg_data: CeusSeg, parent: Optional[QWidget] = None): + QWidget.__init__(self, parent) + self.__init_base_view__(parent) + self._ui = Ui_confirmVoi() + self._image_data = image_data + self._seg_data = seg_data + self._pix_data = image_data.pixel_data + + # Enhancement parameters (Inherited from seg_data) + self._clahe_clip_limit = getattr(seg_data, 'clahe_clip_limit', 1.2) + self._gamma = getattr(seg_data, 'gamma', 1.5) + self._width_scale_axial = getattr(seg_data, 'width_scale_axial', 1.0) + self._width_scale_sagittal = getattr(seg_data, 'width_scale_sagittal', 1.0) + self._width_scale_coronal = getattr(seg_data, 'width_scale_coronal', 1.0) + self._use_philips_ceus = getattr(seg_data, 'use_philips_ceus', False) + self._mask_alpha = 125 # Default alpha for mask overlay (0-255) + + # Cache for enhanced volume + self._enhanced_cache = None + self._enhanced_cache_frame = -1 + + # Crosshair / navigation state + self._crosshair_active = False + self._crosshair_visible = True + + # Dimensions: x, y, z, t + if self._pix_data.ndim == 4: + self._x_len, self._y_len, self._z_len, self._num_slices = self._pix_data.shape + else: + # Fallback if 3D + self._x_len, self._y_len, self._z_len = self._pix_data.shape + self._num_slices = 1 + self._pix_data = self._pix_data.reshape((self._x_len, self._y_len, self._z_len, 1)) + + self._crosshair_xyzt = [self._x_len // 2, self._y_len // 2, self._z_len // 2, 0] + + # Segmentation mask overlay + self._roi_masks_overlap = np.zeros((self._x_len, self._y_len, self._z_len, 4), dtype=np.uint8) + self._seg_mask_indices = None # Store binary mask indices for alpha updates + if hasattr(seg_data, 'seg_mask') and seg_data.seg_mask is not None: + # seg_mask should be same spatial shape (x, y, z) + mask = seg_data.seg_mask + + # Ensure spatial alignment with image data if dimensions are flipped or permuted + # If the mask was saved from DrawVOIWidget, it should match the pixel_data shape + if mask.shape == (self._x_len, self._y_len, self._z_len): + self._seg_mask_indices = np.where(mask > 0) + elif mask.ndim == 2 and (mask.shape == (self._y_len, self._z_len) or mask.shape == (self._x_len, self._y_len)): + # Handle 2D mask for 2D+time sequence (broadcasting across Time) + # If image is (759, 1472, 1962), x=759 (Time), y=1472 (H), z=1962 (W) + # Mask is (1472, 1962), which matches (y, z) + if mask.shape == (self._y_len, self._z_len): + temp_mask = np.repeat(mask[np.newaxis, :, :], self._x_len, axis=0) + self._seg_mask_indices = np.where(temp_mask > 0) + else: + # Original logic for (x, y) spatial dimensions + temp_mask = np.repeat(mask[:, :, np.newaxis], self._z_len, axis=2) + self._seg_mask_indices = np.where(temp_mask > 0) + elif mask.shape == (self._y_len, self._x_len, self._z_len): + # Handle common XY transpose if detected + self._seg_mask_indices = np.where(mask.transpose(1, 0, 2) > 0) + else: + # Log or handle shape mismatch more gracefully if needed + print(f"Warning: Mask shape {mask.shape} does not match image shape {(self._x_len, self._y_len, self._z_len)}") + # Try to fit the mask as much as possible if shapes match in 3D volume + try: + self._seg_mask_indices = np.where(mask[:self._x_len, :self._y_len, :self._z_len] > 0) + except Exception: + pass + + if self._seg_mask_indices is not None and len(self._seg_mask_indices[0]) > 0: + self._roi_masks_overlap[self._seg_mask_indices[0], self._seg_mask_indices[1], self._seg_mask_indices[2]] = [255, 0, 0, 125] + + # Jump crosshair to a point within the mask to show it immediately + mask_indices = np.where(self._roi_masks_overlap[..., 3] > 0) + if len(mask_indices[0]) > 0: + mid_idx = len(mask_indices[0]) // 2 + self._crosshair_xyzt = [ + mask_indices[0][mid_idx], + mask_indices[1][mid_idx], + mask_indices[2][mid_idx], + 0 + ] + else: + self._crosshair_xyzt = [self._x_len // 2, self._y_len // 2, self._z_len // 2, 0] + + # Per-plane resources (axial, sagittal, coronal) + self._ax_sag_cor_matplotlib_canvases = [None, None, None] + self._ax_sag_cor_planes = (None, None, None) + self._ax_sag_cor_index_maps = ((0, 1), (2, 1), (0, 2)) # (horiz_dim, vert_dim) + self._ax_sag_cor_animations = [None, None, None] + self._ax_sag_cor_plane_artists = [None, None, None] + self._ax_sag_cor_crosshair_lines = [(None, None), (None, None), (None, None)] + self._ax_sag_cor_pending = [False, False, False] + self._ax_sag_cor_seg_masks = [None, None, None] + + # UI & visualization setup + self._setup_ui() + self._setup_matplotlib_canvases() + self._initialize_plane_displays() + self._setup_all_plane_animations() + self._connect_signals() + self._connect_matplotlib_events() + self.setFocusPolicy(Qt.FocusPolicy.StrongFocus) + self._update_scan_display() + self._refresh_frames() + + def _setup_ui(self) -> None: + """Setup the user interface to match the segmentation menu style.""" + self._ui.setupUi(self) + + # Store QLabels as tags for layout mapping + self._ax_sag_cor_planes = (self._ui.ax_plane, self._ui.sag_plane, self._ui.cor_plane) + + # Configure layout + self.setLayout(self._ui.full_screen_layout) + self._ui.full_screen_layout.setStretchFactor(self._ui.side_bar_layout, 1) + self._ui.full_screen_layout.setStretchFactor(self._ui.voi_layout, 10) + + # Initial visibility + self._ui.scan_name_input.setText(self._image_data.scan_name) + self._ui.segSidebarLabel_2.setText("Segmentation Selection") + self._ui.toggle_crosshair_visibility_button.setText('Hide Crosshair') + self._ui.cur_slice_label.setText("Current Frame:") + + self._ui.navigating_label.hide() + self._ui.observing_label.show() + + # Setup enhancement controls + self._setup_enhancement_controls() + + # Update slider for frames + self._ui.cur_slice_slider.setMinimum(0) + self._ui.cur_slice_slider.setMaximum(self._num_slices - 1) + self._ui.cur_slice_slider.setValue(0) + self._ui.cur_slice_total.setText(str(self._num_slices)) + self._ui.cur_slice_spin_box.setRange(1, self._num_slices) + self._ui.cur_slice_spin_box.setValue(1) + + self._ui.ax_total_frames.setText(str(self._z_len)) + self._ui.sag_total_frames.setText(str(self._x_len)) + self._ui.cor_total_frames.setText(str(self._y_len)) + + # Install event filters + for label in self._ax_sag_cor_planes: + if label: + label.installEventFilter(self) + + def _cleanup_animations(self): + """Internal helper to stop animations safely.""" + for i in range(3): + if i < len(self._ax_sag_cor_animations) and self._ax_sag_cor_animations[i]: + try: + self._ax_sag_cor_animations[i].event_source.stop() + except Exception: + pass + self._ax_sag_cor_animations[i] = None + + # ============================================================================ + # UI SETUP & HELPERS + # ============================================================================ + + def _show_widget_lists(self, widget_lists: List[List[QWidget]]) -> None: + """Helper to show groups of widgets.""" + for widget_list in widget_lists: + for widget in widget_list: + widget.show() + + def _hide_widget_lists(self, widget_lists: List[List[QWidget]]) -> None: + """Helper to hide groups of widgets.""" + for widget_list in widget_lists: + for widget in widget_list: + widget.hide() + + def _setup_enhancement_controls(self) -> None: + """Add enhancement sliders to the sidebar, mirroring DrawVOIWidget style.""" + enh_group = QFrame() + enh_group.setStyleSheet("background-color: rgba(255, 255, 255, 0); border: none;") + + container_layout = QVBoxLayout(enh_group) + container_layout.setContentsMargins(0, 10, 0, 10) + container_layout.setSpacing(15) + + row1_layout = QHBoxLayout() + row2_layout = QHBoxLayout() + row1_layout.setSpacing(20) + row2_layout.setSpacing(20) + + def create_enh_column(label_text, min_val, max_val, current_val, callback): + col_widget = QWidget() + col_layout = QVBoxLayout(col_widget) + col_layout.setContentsMargins(0, 0, 0, 0) + col_layout.setSpacing(5) + + lbl = QLabel(label_text) + lbl.setStyleSheet("font-size: 14px; color: white; font-weight: bold;") + lbl.setAlignment(Qt.AlignmentFlag.AlignCenter) + col_layout.addWidget(lbl) + + row_layout = QHBoxLayout() + slider = QSlider(Qt.Orientation.Horizontal) + slider.setRange(min_val, max_val) + slider.setValue(current_val) + slider.setStyleSheet(self._ui.cur_slice_slider.styleSheet()) + slider.setMinimumWidth(80) + slider.setMaximumWidth(120) + slider.valueChanged.connect(callback) + + val_lbl = QLabel(f"{current_val/10.0:.1f}") + val_lbl.setMinimumWidth(40) + val_lbl.setStyleSheet("color: #3498db; font-weight: bold; font-size: 14px;") + + row_layout.addWidget(slider) + row_layout.addWidget(val_lbl) + col_layout.addLayout(row_layout) + return col_widget, slider, val_lbl + + # Sliders + clahe_col, self.clahe_slider, self.clahe_val_lbl = create_enh_column( + "CLAHE", 1, 100, int(self._clahe_clip_limit * 10), self._on_clahe_changed + ) + gamma_col, self.gamma_slider, self.gamma_val_lbl = create_enh_column( + "GAMMA", 1, 40, int(self._gamma * 10), self._on_gamma_changed + ) + width_ax_col, self.width_ax_slider, self.width_ax_val_lbl = create_enh_column( + "WIDTH (AX)", 1, 50, int(self._width_scale_axial * 10), self._on_width_axial_changed + ) + width_sag_col, self.width_sag_slider, self.width_sag_val_lbl = create_enh_column( + "WIDTH (SAG)", 1, 50, int(self._width_scale_sagittal * 10), self._on_width_sagittal_changed + ) + width_cor_col, self.width_cor_slider, self.width_cor_val_lbl = create_enh_column( + "WIDTH (COR)", 1, 50, int(self._width_scale_coronal * 10), self._on_width_coronal_changed + ) + + # Add VOI Alpha slider + alpha_col, self.alpha_slider, self.alpha_val_lbl = create_enh_column( + "VOI ALPHA", 0, 2550, int(self._mask_alpha * 10), lambda v: self._on_alpha_changed(v // 10) + ) + # Fix label to show integer for alpha + self.alpha_val_lbl.setText(str(self._mask_alpha)) + self.alpha_slider.valueChanged.disconnect() + self.alpha_slider.valueChanged.connect(lambda v: (self._on_alpha_changed(v // 10), self.alpha_val_lbl.setText(str(v // 10)))) + + row1_layout.addWidget(clahe_col) + row1_layout.addWidget(gamma_col) + + self.philips_check = QCheckBox("Pseudocoloring") + self.philips_check.setChecked(self._use_philips_ceus) + self.philips_check.setStyleSheet("color: white; font-weight: bold; font-size: 14px;") + self.philips_check.stateChanged.connect(self._on_philips_toggled) + row1_layout.addWidget(self.philips_check) + row1_layout.addWidget(alpha_col) + + row2_layout.addWidget(width_ax_col) + row2_layout.addWidget(width_sag_col) + row2_layout.addWidget(width_cor_col) + + container_layout.addLayout(row1_layout) + container_layout.addLayout(row2_layout) + self._ui.verticalLayout_2.addWidget(enh_group) + + def _invalidate_enhancement_cache(self) -> None: + """Clear cache when processing parameters change.""" + self._enhanced_cache = None + self._enhanced_cache_frame = -1 + self._refresh_frames() + + def _on_clahe_changed(self, value: int) -> None: + """Handle CLAHE change.""" + self._clahe_clip_limit = value / 10.0 + if hasattr(self, 'clahe_val_lbl'): + self.clahe_val_lbl.setText(f"{self._clahe_clip_limit:.1f}") + self._invalidate_enhancement_cache() + + def _on_gamma_changed(self, value: int) -> None: + """Handle gamma change.""" + self._gamma = value / 10.0 + if hasattr(self, 'gamma_val_lbl'): + self.gamma_val_lbl.setText(f"{self._gamma:.1f}") + self._invalidate_enhancement_cache() + + def _on_width_axial_changed(self, value: int) -> None: + """Handle axial aspect ratio.""" + self._width_scale_axial = value / 10.0 + if hasattr(self, 'width_ax_val_lbl'): + self.width_ax_val_lbl.setText(f"{self._width_scale_axial:.1f}") + self._update_aspect_ratios() + + def _on_width_sagittal_changed(self, value: int) -> None: + """Handle sagittal aspect ratio.""" + self._width_scale_sagittal = value / 10.0 + if hasattr(self, 'width_sag_val_lbl'): + self.width_sag_val_lbl.setText(f"{self._width_scale_sagittal:.1f}") + self._update_aspect_ratios() + + def _on_width_coronal_changed(self, value: int) -> None: + """Handle coronal aspect ratio.""" + self._width_scale_coronal = value / 10.0 + if hasattr(self, 'width_cor_val_lbl'): + self.width_cor_val_lbl.setText(f"{self._width_scale_coronal:.1f}") + self._update_aspect_ratios() + + def _on_alpha_changed(self, value: int) -> None: + """Handle alpha transparency change for the VOI mask.""" + self._mask_alpha = value + # Update the rgba mask transparency + # Use stored mask indices to ensure we can recover from alpha=0 + if self._seg_mask_indices is not None and len(self._seg_mask_indices[0]) > 0: + # Re-apply color (Red) and new alpha to the relevant indices + self._roi_masks_overlap[self._seg_mask_indices[0], self._seg_mask_indices[1], self._seg_mask_indices[2]] = [255, 0, 0, self._mask_alpha] + self._refresh_frames() + + def _update_aspect_ratios(self) -> None: + """Update artist aspect ratios based on physics (pixdim) and sliders.""" + if not hasattr(self, '_image_data') or self._image_data is None: + return + + try: + # Safely get pixdim or default to 1.0s + pix = getattr(self._image_data, 'pixdim', [1.0, 1.0, 1.0]) + if pix is None or len(pix) < 2: + pix = [1.0, 1.0, 1.0] + + # Pad with 1.0 if we have 2D data but need 3D for sagittal/coronal canvases + while len(pix) < 3: + pix.append(1.0) + + # Plane 0: Axial (XY) -> show (Y, X) -> Rows=Y, Cols=X -> dy / dx + if self._ax_sag_cor_matplotlib_canvases[0]: + dx, dy = pix[0], pix[1] + aspect_ax = (dy / dx if dx != 0 else 1.0) * self._width_scale_axial + self._ax_sag_cor_matplotlib_canvases[0].figure.gca().set_aspect(aspect_ax) + + # Plane 1: Sagittal (YZ) -> 90 CW Rotation -> show (Y, Z) -> Rows=Y, Cols=Z -> dy / dz + if self._ax_sag_cor_matplotlib_canvases[1]: + dy, dz = pix[1], pix[2] + aspect_sag = (dy / dz if dz != 0 else 1.0) * self._width_scale_sagittal + self._ax_sag_cor_matplotlib_canvases[1].figure.gca().set_aspect(aspect_sag) + + # Plane 2: Coronal (XZ) -> show (Z, X) -> Rows=Z, Cols=X -> dz / dx + if self._ax_sag_cor_matplotlib_canvases[2]: + dx, dz = pix[0], pix[2] + aspect_cor = (dz / dx if dx != 0 else 1.0) * self._width_scale_coronal + self._ax_sag_cor_matplotlib_canvases[2].figure.gca().set_aspect(aspect_cor) + + for canvas in self._ax_sag_cor_matplotlib_canvases: + if canvas: + canvas.draw_idle() + + self._refresh_frames() + except Exception as e: + print(f"Error updating aspect ratios: {e}") + import traceback + traceback.print_exc() + + def _on_philips_toggled(self, state: int) -> None: + """Handle Philips CEUS pseudocolor toggle.""" + self._use_philips_ceus = state == Qt.CheckState.Checked.value + new_cmap = philips_cmap if self._use_philips_ceus else 'gray' + for artist in self._ax_sag_cor_plane_artists: + if artist: + artist.set_cmap(new_cmap) + self._refresh_frames() + + def _setup_matplotlib_canvases(self) -> None: + """Initialize and embed matplotlib canvases into each plane's placeholder layout.""" + # Use the axial, sagittal, coronal labels themselves as the parent for the canvases + # This ensures they are inside their respective QFrame boxes and aligned correctly + for i, parent_label in enumerate(self._ax_sag_cor_planes): + fig, ax = plt.subplots(facecolor='black') + fig.subplots_adjust(left=0, right=1, top=1, bottom=0) + ax.axis('off') + canvas = FigureCanvas(fig) + canvas.setParent(parent_label) + + # Use Expanding policy + canvas.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + + # Hide the label's text but keep the background/frame + parent_label.setText("") + + self._ax_sag_cor_matplotlib_canvases[i] = canvas + + def _initialize_plane_displays(self) -> None: + """Initial render of each orthogonal plane with fixed intensity scaling.""" + for i, canvas in enumerate(self._ax_sag_cor_matplotlib_canvases): + if not canvas: continue + ax = canvas.figure.axes[0] + + # Initial data slice + slice_data = self._get_plane_slice(i) + # Use fixed vmin/vmax to prevent auto-scaling contrast per slice + artist = ax.imshow(slice_data, cmap='gray', interpolation='nearest', + zorder=1, vmin=0, vmax=255) + self._ax_sag_cor_plane_artists[i] = artist + + # Mask overlay + mask_slice = self._get_mask_slice(i) + # Match interpolation and zorder to DrawVOIWidget + mask_artist = ax.imshow(mask_slice, interpolation='nearest', + zorder=8) + self._ax_sag_cor_seg_masks[i] = mask_artist + + # Crosshair lines + # Get actual coordinate indices for current plane from maps + idx_x, idx_y = self._ax_sag_cor_index_maps[i] + v_line = ax.axvline(x=self._crosshair_xyzt[idx_x], color='yellow', lw=0.8, animated=True, zorder=11) + h_line = ax.axhline(y=self._crosshair_xyzt[idx_y], color='yellow', lw=0.8, animated=True, zorder=11) + self._ax_sag_cor_crosshair_lines[i] = (v_line, h_line) + + self._update_aspect_ratios() + + def _get_plane_slice(self, plane_ix: int) -> np.ndarray: + """Extract a 2D image slice for the specified plane at current crosshair indices.""" + x, y, z, t = self._crosshair_xyzt + vol = self._get_enhanced_volume(t) + + if plane_ix == 0: # Axial (XY) at Z -> show (Y, X) + return vol[:, :, z].T + elif plane_ix == 1: # Sagittal (YZ) at X -> show (Z, Y) then rotate 90 CW -> (Y, Z) + # Match DrawVOIWidget approach: arr.T then rot90(k=-1) + arr = vol[x, :, :] + return arr + elif plane_ix == 2: # Coronal (XZ) at Y -> show (Y, X) + # Mirror Axial for Coronal to match DrawVOI behavior + return vol[:, y, :].T + return np.zeros((10, 10)) + + def _get_mask_slice(self, plane_ix: int) -> np.ndarray: + """Extract a 2D mask slice for overlay.""" + x, y, z, _ = self._crosshair_xyzt + if plane_ix == 0: # Axial (XY) at Z -> show (Y, X) + # Match DrawVOIWidget approach: index mask then transpose (1, 0, 2) + arr = self._roi_masks_overlap[:, :, z, :] + return np.transpose(arr, (1, 0, 2)) + elif plane_ix == 1: # Sagittal (YZ) at X -> show (Z, Y) then rotate 90 CW -> (Y, Z) + arr = self._roi_masks_overlap[x, :, :, :] + return arr + elif plane_ix == 2: # Coronal (XZ) at Y -> show (Y, X) + # Mirror Axial for Coronal to match DrawVOI behavior + arr = self._roi_masks_overlap[:, y, :, :] + return np.transpose(arr, (1, 0, 2)) + return np.zeros((10, 10, 4), dtype=np.uint8) + + def _get_enhanced_volume(self, t: int) -> np.ndarray: + """Return the enhanced 3D volume at frame t, requesting enhancement via the controller if needed.""" + if self._enhanced_cache is not None and self._enhanced_cache_frame == t: + return self._enhanced_cache + + # Extract the 3D volume for current frame and wrap it in a temporary + # UltrasoundImage so the controller-side preprocessors receive the + # same data structure they expect. + vol_3d = self._pix_data[:, :, :, t] + temp_im = UltrasoundImage(self._image_data.scan_path) + temp_im.pixel_data = vol_3d + temp_im.pixdim = self._image_data.pixdim + temp_im.frame_rate = self._image_data.frame_rate + + clahe_preproc_dict = { + 'name': 'enhance_clahe', + 'image_data': temp_im, + 'frame_ix': t, + 'kwargs': { + 'clip_limit': self._clahe_clip_limit, + 'tile_grid_size': (8, 8), + } + } + + gamma_preproc_dict = { + 'name': 'enhance_gamma', + 'image_data': None, # signal to reuse the already CLAHE-enhanced image + 'frame_ix': t, + 'kwargs': { + 'gamma': self._gamma, + } + } + + # Delegate to the controller via signal (synchronous slot expected on the + # other end, matching the pattern in DrawROIWidget). + self.apply_preprocs_preview.emit([clahe_preproc_dict, gamma_preproc_dict]) + + # The cache will have been populated synchronously by update_enhancement_cache + # before this returns + if self._enhanced_cache_frame >= 0: assert self._enhanced_cache_frame == t + if self._enhanced_cache is not None: + return self._enhanced_cache + return vol_3d + + def update_enhancement_cache(self, enhanced_vol: np.ndarray, frame: int) -> None: + """Receive the processed 3D volume from the controller and store it. + + Args: + enhanced_vol: The enhanced spatial volume (x, y, z) for the given frame. + frame: The time index (t) this volume corresponds to. + """ + self._enhanced_cache = enhanced_vol + self._enhanced_cache_frame = frame + self._refresh_frames() + + def _setup_all_plane_animations(self) -> None: + """Setup refresh animations for each matplotlib canvas.""" + for i in range(3): + canvas = self._ax_sag_cor_matplotlib_canvases[i] + self._ax_sag_cor_animations[i] = anim.FuncAnimation( + canvas.figure, + lambda frame, p_ix=i: self._update_plane(p_ix), + interval=33, + blit=True, + cache_frame_data=False + ) + + def _update_plane(self, plane_ix: int): + """Update artist data for a single plane.""" + # Always return the list of artists for blitting + v_line, h_line = self._ax_sag_cor_crosshair_lines[plane_ix] + artist = self._ax_sag_cor_plane_artists[plane_ix] + mask_artist = self._ax_sag_cor_seg_masks[plane_ix] + + artists = [] + if artist: artists.append(artist) + if mask_artist: artists.append(mask_artist) + if v_line: + v_line.set_visible(self._crosshair_visible) + artists.append(v_line) + if h_line: + h_line.set_visible(self._crosshair_visible) + artists.append(h_line) + + if not self._ax_sag_cor_pending[plane_ix]: + return artists + + if artist: + artist.set_data(self._get_plane_slice(plane_ix)) + if mask_artist: + mask_artist.set_data(self._get_mask_slice(plane_ix)) + + if v_line and h_line: + idx_x, idx_y = self._ax_sag_cor_index_maps[plane_ix] + # When refreshing (e.g. slice changed), snap back to stored indices + v_line.set_xdata([self._crosshair_xyzt[idx_x]]) + h_line.set_ydata([self._crosshair_xyzt[idx_y]]) + v_line.set_visible(self._crosshair_visible) + h_line.set_visible(self._crosshair_visible) + + self._ax_sag_cor_pending[plane_ix] = False + return artists + + def _connect_signals(self) -> None: + """Connect UI signals to internal handlers, matching DrawVOIWidget patterns.""" + # Frame/Time Navigation + self._ui.cur_slice_slider.valueChanged.connect(self._on_slice_slider_changed) + self._ui.cur_slice_spin_box.valueChanged.connect(lambda v: self._ui.cur_slice_slider.setValue(int(v)-1)) + self._ui.toggle_crosshair_visibility_button.clicked.connect(self._on_toggle_crosshair) + self._ui.back_button.clicked.connect(self._on_back_requested) + self._ui.confirm_voi_button.clicked.connect(self.segmentation_confirmed.emit) + + def keyPressEvent(self, event): # type: ignore + """Handle key presses for quick actions (e.g., 'd' to toggle draw ROI).""" + if event.key() == Qt.Key.Key_H: + self._on_toggle_crosshair() + return + super().keyPressEvent(event) + + def _on_back_requested(self): + """Handle back request with cleanup.""" + self._cleanup_animations() + self.back_requested.emit() + + def closeEvent(self, event): + """Clean up animations and canvases before the widget is destroyed.""" + self._cleanup_animations() + + for i in range(3): + canvas = self._ax_sag_cor_matplotlib_canvases[i] + if canvas: + try: + plt.close(canvas.figure) + except Exception: + pass + self._ax_sag_cor_matplotlib_canvases[i] = None + + super().closeEvent(event) + + def _on_slice_slider_changed(self, value: int) -> None: + """Handle time-series frame change.""" + self._crosshair_xyzt[3] = value + self._ui.cur_slice_spin_box.blockSignals(True) + self._ui.cur_slice_spin_box.setValue(value + 1) + self._ui.cur_slice_spin_box.blockSignals(False) + self._refresh_frames() + + def _on_toggle_crosshair(self) -> None: + """Toggle crosshair visibility.""" + self._crosshair_visible = not self._crosshair_visible + self._ui.toggle_crosshair_visibility_button.setText( + 'Show Crosshair' if not self._crosshair_visible else 'Hide Crosshair' + ) + self._refresh_frames() + + def _refresh_frames(self) -> None: + """Mark all planes for refresh.""" + self._ax_sag_cor_pending = [True, True, True] + + def _update_scan_display(self) -> None: + """Sync UI labels with current crosshair indices. Using 1-based indexing for display.""" + self._ui.ax_frame_num.setText(str(self._crosshair_xyzt[2] + 1)) + self._ui.sag_frame_num.setText(str(self._crosshair_xyzt[0] + 1)) + self._ui.cor_frame_num.setText(str(self._crosshair_xyzt[1] + 1)) + + # Update spinbox for t + self._ui.cur_slice_spin_box.blockSignals(True) + self._ui.cur_slice_spin_box.setValue(self._crosshair_xyzt[3] + 1) + self._ui.cur_slice_spin_box.blockSignals(False) + + def set_crosshair(self, x=None, y=None, z=None, t=None): + """Update crosshair position and trigger refresh.""" + changed = False + if x is not None and 0 <= x < self._x_len: + self._crosshair_xyzt[0] = x; changed = True + if y is not None and 0 <= y < self._y_len: + self._crosshair_xyzt[1] = y; changed = True + if z is not None and 0 <= z < self._z_len: + self._crosshair_xyzt[2] = z; changed = True + if t is not None and 0 <= t < self._num_slices: + self._crosshair_xyzt[3] = t; changed = True + + if changed: + self._update_scan_display() + self._refresh_frames() + + # ======================= Resize Handling ================================= + def eventFilter(self, obj, event): # type: ignore + if event.type() == QEvent.Type.Resize and obj in self._ax_sag_cor_planes: + self._resize_canvas_for(obj) + return super().eventFilter(obj, event) + + def _resize_canvas_for(self, label_widget: QLabel): + try: + idx = self._ax_sag_cor_planes.index(label_widget) + except ValueError: + return + canvas = self._ax_sag_cor_matplotlib_canvases[idx] + if not canvas: + return + + # Match canvas size to the QLabel/placeholder size + # We ensure it fills the parent label precisely + canvas_width = label_widget.width() + canvas_height = label_widget.height() + canvas.setFixedSize(canvas_width, canvas_height) + canvas.move(0, 0) + canvas.draw_idle() + + def _resize_all_canvases(self): + """Force a resize of all embedded matplotlib canvases.""" + for label in self._ax_sag_cor_planes: + if label: + self._resize_canvas_for(label) + + def showEvent(self, event): + # Ensure canvases sized properly when shown + self._resize_all_canvases() + return super().showEvent(event) + + def _connect_matplotlib_events(self): + """Connect motion and click events on each plane's matplotlib canvas.""" + for plane_ix, canvas in enumerate(self._ax_sag_cor_matplotlib_canvases): + if not canvas: continue + canvas.mpl_connect('motion_notify_event', lambda e, p=plane_ix: self._on_canvas_motion(e, p)) + canvas.mpl_connect('button_press_event', lambda e, p=plane_ix: self._on_canvas_click(e, p)) + + def _on_canvas_click(self, event, plane_ix: int): + if event.inaxes is None: return + self._crosshair_active = not self._crosshair_active + if self._crosshair_active: + self._ui.navigating_label.show() + self._ui.observing_label.hide() + else: + self._ui.navigating_label.hide() + self._ui.observing_label.show() + self._on_canvas_motion(event, plane_ix) + + def _on_canvas_motion(self, event, plane_ix: int): # type: ignore + """Handle mouse movement over a plane and update crosshair indices. + + event.xdata maps to the first varying dimension of that plane slice, + event.ydata to the second. We clamp to valid ranges and call set_crosshair + only if the index meaningfully changed. + """ + if not self._crosshair_active: + return + if event.inaxes is None or event.xdata is None or event.ydata is None: + return + + vary_dims = self._ax_sag_cor_index_maps[plane_ix] + dim_x, dim_y = vary_dims[0], vary_dims[1] + + # Dimension lengths mapping + dim_lengths = [self._x_len, self._y_len, self._z_len, self._num_slices] + + # Proposed new indices (int rounding & clamp) + new_xval = int(round(event.xdata)) + new_yval = int(round(event.ydata)) + if new_xval < 0 or new_yval < 0: + return + if new_xval >= dim_lengths[dim_x] or new_yval >= dim_lengths[dim_y]: + return + + # Build kwargs for set_crosshair only for dims that change + params = {} + if self._crosshair_xyzt[dim_x] != new_xval: + key = ['x','y','z','t'][dim_x] + params[key] = new_xval + if self._crosshair_xyzt[dim_y] != new_yval: + key = ['x','y','z','t'][dim_y] + params[key] = new_yval + + if params: + self.set_crosshair(**params) + + def _update_hover_crosshair(self, x, y, plane_ix): + """Update crosshair lines to follow mouse hover.""" + v_line, h_line = self._ax_sag_cor_crosshair_lines[plane_ix] + if v_line and h_line: + v_line.set_xdata([x, x]) + h_line.set_ydata([y, y]) + v_line.set_visible(self._crosshair_visible) + h_line.set_visible(self._crosshair_visible) + # Only update the background for this ONE canvas + self._ax_sag_cor_matplotlib_canvases[plane_ix].draw_idle()