diff --git a/.gitignore b/.gitignore index 5bf1196..264beac 100644 --- a/.gitignore +++ b/.gitignore @@ -7,5 +7,5 @@ fixed_wheels/ .vscode/ .DS_Store *.so -/**/*_ui.py Visualization_Results/ +/**/*_ui.py \ No newline at end of file diff --git a/engines/ceus b/engines/ceus index 97d3d3a..e5ae0e4 160000 --- a/engines/ceus +++ b/engines/ceus @@ -1 +1 @@ -Subproject commit 97d3d3a8b03ee02bc5bda8dca2f1f316151d8210 +Subproject commit e5ae0e467bc4e5608e6f1e086a1fa8905a5417c1 diff --git a/src/ceus/analysis_loading/analysis_loading_controller.py b/src/ceus/analysis_loading/analysis_loading_controller.py index 8ddbc27..cbcac1d 100644 --- a/src/ceus/analysis_loading/analysis_loading_controller.py +++ b/src/ceus/analysis_loading/analysis_loading_controller.py @@ -10,7 +10,8 @@ from ..mvc.base_controller import BaseController from .analysis_loading_view_coordinator import AnalysisLoadingViewCoordinator -from engines.ceus.src.data_objs import UltrasoundImage, CeusSeg +from engines.ceus.src.data_objs.image import UltrasoundImage +from engines.ceus.src.data_objs.seg import CeusSeg from engines.ceus.src.time_series_analysis.curves.framework import CurvesAnalysis @@ -64,20 +65,30 @@ def _connect_signals(self) -> None: def _setup_analysis_options(self) -> None: """Setup available analysis types and functions in the view.""" analysis_types, analysis_functions = self._model.get_analysis_types() + print(f"DEBUG: Available analysis types: {list(analysis_types.keys())}") - # Automatically select "Paramap" as the analysis type - paramap_type = "paramap" - if paramap_type in analysis_types: - self._selected_analysis_type = paramap_type - if self._model.set_analysis_type(paramap_type): - # Get available functions for Paramap analysis - available_functions = self._model.get_analysis_functions(paramap_type) + # Automatically select the best available analysis type + # Prefer curves_paramap, then curves, or just the first available one + selected_type = None + for preferred in ["curves_paramap", "curves", "paramap"]: + if preferred in analysis_types: + selected_type = preferred + break + + if not selected_type and analysis_types: + selected_type = list(analysis_types.keys())[0] + + if selected_type: + self._selected_analysis_type = selected_type + if self._model.set_analysis_type(selected_type): + # Get available functions for selected analysis type + available_functions = self._model.get_analysis_functions(selected_type) # Skip analysis type selection and go directly to function selection self._view_coordinator.show_function_selection(available_functions) else: - self._view_coordinator.show_error("Failed to set Paramap analysis type") + self._view_coordinator.show_error(f"Failed to set {selected_type} analysis type") else: - self._view_coordinator.show_error("Paramap analysis type not available") + self._view_coordinator.show_error("No analysis types available") def _on_user_action(self, action_name: str, action_data: Any) -> None: """ @@ -98,7 +109,7 @@ def _on_user_action(self, action_name: str, action_data: Any) -> None: print(f"DEBUG: Controller received analysis_execution_started action") print(f"DEBUG: action_data = {action_data}") self._handle_analysis_execution(action_data) - elif action_name == "analysis_completed": + elif action_name == "analysis_loading_completed": self._handle_analysis_completion(action_data) else: # Forward unknown actions to application controller diff --git a/src/ceus/analysis_loading/analysis_loading_view_coordinator.py b/src/ceus/analysis_loading/analysis_loading_view_coordinator.py index a87ade9..a754541 100644 --- a/src/ceus/analysis_loading/analysis_loading_view_coordinator.py +++ b/src/ceus/analysis_loading/analysis_loading_view_coordinator.py @@ -10,12 +10,13 @@ from PyQt6.QtWidgets import QWidget, QStackedWidget from PyQt6.QtCore import pyqtSignal -from quantus.gui.mvc.base_view import BaseViewMixin +from ..mvc.base_view import BaseViewMixin from .views.analysis_function_selection_widget import AnalysisFunctionSelectionWidget -from quantus.gui.config_loading.views.analysis_params_widget import AnalysisParamsWidget +from .views.analysis_params_widget import AnalysisParamsWidget from .views.analysis_execution_widget import AnalysisExecutionWidget -from quantus.data_objs import UltrasoundRfImage, BmodeSeg, RfAnalysisConfig -from quantus.analysis.paramap.framework import ParamapAnalysis +from engines.ceus.src.data_objs.image import UltrasoundImage +from engines.ceus.src.data_objs.seg import CeusSeg +from engines.ceus.src.time_series_analysis.curves.framework import CurvesAnalysis class AnalysisLoadingViewCoordinator(QStackedWidget, BaseViewMixin): @@ -40,7 +41,7 @@ class AnalysisLoadingViewCoordinator(QStackedWidget, BaseViewMixin): # ============================================================================ - def __init__(self, image_data: UltrasoundRfImage, seg_data: BmodeSeg, config_data: RfAnalysisConfig, parent: Optional[QWidget] = None): + def __init__(self, image_data: UltrasoundImage, seg_data: CeusSeg, config_data, parent: Optional[QWidget] = None): super().__init__(parent) self.__init_base_view__(parent) self._image_data = image_data @@ -48,11 +49,6 @@ def __init__(self, image_data: UltrasoundRfImage, seg_data: BmodeSeg, config_dat self._config_data = config_data print(f"DEBUG: AnalysisLoadingViewCoordinator - image_data = {image_data is not None}") - if image_data is not None: - print(f"DEBUG: AnalysisLoadingViewCoordinator - scan_name = {image_data.scan_name}") - print(f"DEBUG: AnalysisLoadingViewCoordinator - phantom_name = {image_data.phantom_name}") - else: - print(f"DEBUG: AnalysisLoadingViewCoordinator - image_data is None!") # Widget instances self._function_selection_widget: Optional[AnalysisFunctionSelectionWidget] = None @@ -63,7 +59,7 @@ def __init__(self, image_data: UltrasoundRfImage, seg_data: BmodeSeg, config_dat self._selected_analysis_type: Optional[str] = None self._selected_functions: List[str] = [] self._analysis_params: dict = {} - self._analysis_data: Optional[ParamapAnalysis] = None + self._analysis_data: Optional[CurvesAnalysis] = None # Note: Analysis type selection is now skipped - Paramap is automatically selected # The controller will call show_function_selection directly @@ -106,12 +102,16 @@ def show_error(self, error_message: str) -> None: error_message: Error message to display """ current_widget: BaseViewMixin = self.currentWidget() - current_widget.show_error(error_message) + if current_widget: + current_widget.show_error(error_message) + else: + print(f"ERROR (no active widget): {error_message}") def clear_error(self) -> None: """Clear error message in the current widget.""" current_widget: BaseViewMixin = self.currentWidget() - current_widget.clear_error() + if current_widget: + current_widget.clear_error() # ============================================================================ # NAVIGATION METHODS - Methods to show different widgets @@ -208,7 +208,7 @@ def show_analysis_execution(self, execution_summary: Dict) -> None: self._execution_widget.clear_error() print(f"DEBUG: show_analysis_execution completed - execution screen should be visible") - def show_analysis_results(self, analysis_data: ParamapAnalysis) -> None: + def show_analysis_results(self, analysis_data: CurvesAnalysis) -> None: """ Show analysis results in the execution widget. @@ -260,7 +260,7 @@ def _on_execution_started(self, execution_data: dict) -> None: self._emit_user_action("analysis_execution_started", execution_data) print(f"DEBUG: user_action signal emitted") - def _on_analysis_confirmed(self, analysis_data: ParamapAnalysis) -> None: + def _on_analysis_confirmed(self, analysis_data: CurvesAnalysis) -> None: """ Handle analysis completion confirmation. diff --git a/src/ceus/analysis_loading/ui/analysis_params.ui b/src/ceus/analysis_loading/ui/analysis_params.ui new file mode 100644 index 0000000..c373b58 --- /dev/null +++ b/src/ceus/analysis_loading/ui/analysis_params.ui @@ -0,0 +1,688 @@ + + + analysisParams + + + + 0 + 0 + 1284 + 803 + + + + + 0 + 0 + + + + Analysis Parameters Configuration + + + QWidget { + background: rgb(42, 42, 42); +} + + + + + 60 + 20 + 951 + 731 + + + + + + + 0 + + + QLayout::SizeConstraint::SetMaximumSize + + + + + + 341 + 601 + + + + + 241 + 601 + + + + <html><head/><body><p><br/></p></body></html> + + + QWidget { + background-color: rgb(28, 0, 101); +} + + + + + 0 + 0 + 341 + 121 + + + + + 341 + 121 + + + + + 341 + 121 + + + + QFrame { + background-color: rgb(99, 0, 174); + border: 1px solid black; +} + + + QFrame::Shape::StyledPanel + + + QFrame::Shadow::Raised + + + + + 70 + 0 + 191 + 51 + + + + QLabel { + font-size: 21px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight: bold; +} + + + Image Selection: + + + Qt::AlignmentFlag::AlignCenter + + + + + + -60 + 40 + 191 + 51 + + + + QLabel { + font-size: 16px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight: bold; +} + + + Image: + + + Qt::AlignmentFlag::AlignCenter + + + + + + -50 + 70 + 191 + 51 + + + + QLabel { + font-size: 16px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight: bold +} + + + Phantom: + + + Qt::AlignmentFlag::AlignCenter + + + + + + 100 + 40 + 241 + 51 + + + + QLabel { + font-size: 14px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; +} + + + Sample filename + + + Qt::AlignmentFlag::AlignLeading|Qt::AlignmentFlag::AlignLeft|Qt::AlignmentFlag::AlignVCenter + + + + + + 100 + 70 + 241 + 51 + + + + QLabel { + font-size: 14px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; +} + + + Sample filename + + + Qt::AlignmentFlag::AlignLeading|Qt::AlignmentFlag::AlignLeft|Qt::AlignmentFlag::AlignVCenter + + + + + + + 0 + 120 + 341 + 121 + + + + + 341 + 121 + + + + QFrame { + background-color: rgb(99, 0, 174); + border: 1px solid black; +} + + + QFrame::Shape::StyledPanel + + + QFrame::Shadow::Raised + + + + + 0 + 40 + 341 + 51 + + + + QLabel { + font-size: 21px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight: bold; +} + + + Segmentation Selection + + + Qt::AlignmentFlag::AlignCenter + + + + + + + 0 + 240 + 341 + 121 + + + + + 341 + 121 + + + + QFrame { + background-color: rgb(99, 0, 174); + border: 1px solid black; +} + + + QFrame::Shape::StyledPanel + + + QFrame::Shadow::Raised + + + + + 0 + 30 + 341 + 51 + + + + QLabel { + font-size: 21px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight:bold; +} + + + Analysis Parameter Selection + + + Qt::AlignmentFlag::AlignCenter + + + + + + + 0 + 360 + 341 + 121 + + + + + 341 + 121 + + + + + 341 + 121 + + + + QFrame { + background-color: rgb(99, 0, 174); + border: 1px solid black; +} + + + QFrame::Shape::StyledPanel + + + QFrame::Shadow::Raised + + + + + 0 + 30 + 341 + 51 + + + + QLabel { + font-size: 21px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight: bold; +} + + + CEUS Analysis + + + Qt::AlignmentFlag::AlignCenter + + + + + + + 0 + 480 + 341 + 121 + + + + + 341 + 121 + + + + + 341 + 121 + + + + QFrame { + background-color: rgb(49, 0, 124); + border: 1px solid black; +} + + + QFrame::Shape::StyledPanel + + + QFrame::Shadow::Raised + + + + + 20 + 30 + 301 + 51 + + + + QLabel { + font-size: 21px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); + border: 0px; + font-weight: bold; +} + + + Visualization / Export + + + Qt::AlignmentFlag::AlignCenter + + + + + + + + + + 341 + 16777215 + + + + QFrame { + background-color: rgb(28, 0, 101); +} + + + + QLayout::SizeConstraint::SetMinAndMaxSize + + + 10 + + + 10 + + + 10 + + + 10 + + + + + Qt::Orientation::Vertical + + + + 20 + 40 + + + + + + + + Qt::Orientation::Horizontal + + + + 40 + 20 + + + + + + + + + 131 + 41 + + + + + 131 + 41 + + + + QPushButton { + color: white; + font-size: 16px; + background: rgb(90, 37, 255); + border-radius: 15px; +} + + + Back + + + + + + + + + + + + 50 + + + 30 + + + 10 + + + 30 + + + 10 + + + + + QLabel { + font-size: 29px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + Analysis in Progress... + + + Qt::TextFormat::AutoText + + + false + + + Qt::AlignmentFlag::AlignCenter + + + true + + + + + + + QLabel { + font-size: 29px; + color: rgb(255, 255, 255); + background-color: rgba(255, 255, 255, 0); +} + + + Configure Analysis Parameters: + + + Qt::TextFormat::AutoText + + + false + + + Qt::AlignmentFlag::AlignCenter + + + true + + + + + + + true + + + + + 0 + 0 + 409 + 284 + + + + + + + + + + QLabel { + color: rgb(0, 255, 0); + font-size: 20px; + background-color: rgba(255, 255, 255, 0); +} + + + Running Analysis.... + + + Qt::AlignmentFlag::AlignCenter + + + + + + + + 131 + 41 + + + + + 131 + 41 + + + + QPushButton { + color: white; + font-size: 16px; + background: rgb(90, 37, 255); + border-radius: 15px; +} + + + Run Analysis + + + + + + + Qt::Orientation::Vertical + + + + 20 + 40 + + + + + + + + + + + + diff --git a/src/ceus/analysis_loading/views/analysis_execution_widget.py b/src/ceus/analysis_loading/views/analysis_execution_widget.py index 08d5b76..a302e51 100644 --- a/src/ceus/analysis_loading/views/analysis_execution_widget.py +++ b/src/ceus/analysis_loading/views/analysis_execution_widget.py @@ -10,10 +10,11 @@ from PyQt6.QtCore import pyqtSignal, Qt, QTimer from PyQt6.QtGui import QFont -from quantus.gui.mvc.base_view import BaseViewMixin -from quantus.gui.analysis_loading.ui.analysis_execution_ui import Ui_analysisExecution -from quantus.data_objs import UltrasoundRfImage, BmodeSeg, RfAnalysisConfig -from quantus.analysis.paramap.framework import ParamapAnalysis +from ...mvc.base_view import BaseViewMixin +from ..ui.analysis_execution_ui import Ui_analysisExecution +from engines.ceus.src.data_objs.image import UltrasoundImage +from engines.ceus.src.data_objs.seg import CeusSeg +from engines.ceus.src.time_series_analysis.curves.framework import CurvesAnalysis class AnalysisExecutionWidget(QWidget, BaseViewMixin): @@ -26,11 +27,11 @@ class AnalysisExecutionWidget(QWidget, BaseViewMixin): # Signals for communicating with controller execution_started = pyqtSignal(dict) # execution_data - analysis_confirmed = pyqtSignal(object) # analysis_data (ParamapAnalysis) + analysis_confirmed = pyqtSignal(object) # analysis_data (CurvesAnalysis) close_requested = pyqtSignal() back_requested = pyqtSignal() - def __init__(self, image_data: UltrasoundRfImage, seg_data: BmodeSeg, config_data: RfAnalysisConfig, parent: Optional[QWidget] = None): + def __init__(self, image_data: UltrasoundImage, seg_data: CeusSeg, config_data, parent: Optional[QWidget] = None): QWidget.__init__(self, parent) self.__init_base_view__(parent) self._ui = Ui_analysisExecution() @@ -40,7 +41,7 @@ def __init__(self, image_data: UltrasoundRfImage, seg_data: BmodeSeg, config_dat # Current state self._execution_summary: Dict = {} - self._analysis_data: Optional[ParamapAnalysis] = None + self._analysis_data: Optional[CurvesAnalysis] = None self._is_executing = False self._results_shown = False # Track if results have been shown @@ -78,8 +79,8 @@ def setup_ui(self) -> None: # Update labels to reflect inputted image and phantom if self._image_data is not None: - self._ui.image_path_input.setText(self._image_data.scan_name or "No image loaded") - self._ui.phantom_path_input.setText(self._image_data.phantom_name or "No phantom loaded") + self._ui.image_path_input.setText(getattr(self._image_data, 'scan_name', "No image loaded")) + self._ui.phantom_path_input.setText(getattr(self._image_data, 'phantom_name', "No phantom loaded")) else: self._ui.image_path_input.setText("No image loaded") self._ui.phantom_path_input.setText("No phantom loaded") @@ -219,7 +220,7 @@ def _clear_summary_layout(self) -> None: if child.widget(): child.widget().deleteLater() - def show_results(self, analysis_data: ParamapAnalysis) -> None: + def show_results(self, analysis_data: CurvesAnalysis) -> None: """ Show analysis results. diff --git a/src/ceus/analysis_loading/views/analysis_function_selection_widget.py b/src/ceus/analysis_loading/views/analysis_function_selection_widget.py index f36643d..318544a 100644 --- a/src/ceus/analysis_loading/views/analysis_function_selection_widget.py +++ b/src/ceus/analysis_loading/views/analysis_function_selection_widget.py @@ -9,9 +9,10 @@ from PyQt6.QtWidgets import QWidget, QComboBox, QVBoxLayout, QHBoxLayout, QLabel, QSizePolicy from PyQt6.QtCore import pyqtSignal, Qt -from quantus.gui.mvc.base_view import BaseViewMixin -from quantus.gui.analysis_loading.ui.analysis_function_selection_ui import Ui_analysisFunctionSelection -from quantus.data_objs import UltrasoundRfImage, BmodeSeg, RfAnalysisConfig +from ...mvc.base_view import BaseViewMixin +from ..ui.analysis_function_selection_ui import Ui_analysisFunctionSelection +from engines.ceus.src.data_objs.image import UltrasoundImage +from engines.ceus.src.data_objs.seg import CeusSeg class AnalysisFunctionSelectionWidget(QWidget, BaseViewMixin): @@ -27,7 +28,7 @@ class AnalysisFunctionSelectionWidget(QWidget, BaseViewMixin): close_requested = pyqtSignal() back_requested = pyqtSignal() - def __init__(self, image_data: UltrasoundRfImage, seg_data: BmodeSeg, config_data: RfAnalysisConfig, parent: Optional[QWidget] = None): + def __init__(self, image_data: UltrasoundImage, seg_data: CeusSeg, config_data, parent: Optional[QWidget] = None): QWidget.__init__(self, parent) self.__init_base_view__(parent) self._ui = Ui_analysisFunctionSelection() @@ -56,8 +57,8 @@ def setup_ui(self) -> None: # Update labels to reflect inputted image and phantom if self._image_data is not None: - self._ui.image_path_input.setText(self._image_data.scan_name or "No image loaded") - self._ui.phantom_path_input.setText(self._image_data.phantom_name or "No phantom loaded") + self._ui.image_path_input.setText(getattr(self._image_data, 'scan_name', "No image loaded")) + self._ui.phantom_path_input.setText(getattr(self._image_data, 'phantom_name', "No phantom loaded")) else: self._ui.image_path_input.setText("No image loaded") self._ui.phantom_path_input.setText("No phantom loaded") diff --git a/src/ceus/analysis_loading/views/analysis_params_widget.py b/src/ceus/analysis_loading/views/analysis_params_widget.py new file mode 100644 index 0000000..aea4c51 --- /dev/null +++ b/src/ceus/analysis_loading/views/analysis_params_widget.py @@ -0,0 +1,92 @@ +""" +Analysis Parameters Widget for Analysis Loading + +This widget allows users to configure parameters required for the selected analysis functions. +It dynamically creates input fields based on the required parameters. +""" + +from typing import List, Optional, Dict, Any +from PyQt6.QtWidgets import (QWidget, QLabel, QLineEdit, QDoubleSpinBox, QSpinBox, + QCheckBox, QComboBox, QFormLayout, + QGroupBox, QTextEdit) +from PyQt6.QtCore import pyqtSignal, Qt, QTimer + +from ...mvc.base_view import BaseViewMixin +from ..ui.analysis_params_ui import Ui_analysisParams +from engines.ceus.src.data_objs.image import UltrasoundImage +from engines.ceus.src.data_objs.seg import CeusSeg + + +class AnalysisParamsWidget(QWidget, BaseViewMixin): + """ + Widget for configuring analysis parameters. + + This widget dynamically creates input fields based on the required parameters + for the selected analysis functions. + """ + + # Signals for communicating with controller + params_configured = pyqtSignal(dict) # analysis_params + close_requested = pyqtSignal() + back_requested = pyqtSignal() + + def __init__(self, image_data: UltrasoundImage, seg_data: CeusSeg, config_data, parent: Optional[QWidget] = None): + QWidget.__init__(self, parent) + self.__init_base_view__(parent) + self._ui = Ui_analysisParams() + self._image_data = image_data + self._seg_data = seg_data + self._config_data = config_data + + # Track parameter inputs + self._param_inputs: Dict[str, QWidget] = {} + self._required_params: List[str] = [] + self._selected_functions: List[str] = [] + + def setup_ui(self) -> None: + """Setup the user interface.""" + self._ui.setupUi(self) + + # Configure layout for parameters configuration (assuming similar structure to QUS) + if hasattr(self._ui, 'full_screen_layout'): + self.setLayout(self._ui.full_screen_layout) + + # Update labels to reflect inputted image + if hasattr(self._ui, 'image_path_input') and self._image_data: + scan_name = getattr(self._image_data, 'scan_name', 'Unknown') + self._ui.image_path_input.setText(scan_name) + + def connect_signals(self) -> None: + """Connect UI signals to internal handlers.""" + if hasattr(self._ui, 'run_analysis_button'): + self._ui.run_analysis_button.clicked.connect(self._on_run_analysis_clicked) + if hasattr(self._ui, 'back_button'): + self._ui.back_button.clicked.connect(self._on_back_clicked) + + def set_required_params(self, required_params: List[str], selected_functions: List[str]) -> None: + """ + Set required parameters and create input fields. + + Args: + required_params: List of required parameter names + selected_functions: List of selected function names + """ + self._required_params = required_params + self._selected_functions = selected_functions + self._create_parameter_inputs() + + def _create_parameter_inputs(self) -> None: + """Create input fields for each required parameter.""" + # This implementation is simplified compared to QUS for now + # Ideally would dynamically create inputs based on CEUS requirements + pass + + def _on_run_analysis_clicked(self) -> None: + """Handle run analysis button click.""" + # Collect parameters (simplified) + params = {} + self.params_configured.emit(params) + + def _on_back_clicked(self) -> None: + """Handle back button click.""" + self.back_requested.emit() diff --git a/src/ceus/application_controller.py b/src/ceus/application_controller.py index 86f6cbf..3456bf1 100644 --- a/src/ceus/application_controller.py +++ b/src/ceus/application_controller.py @@ -12,7 +12,9 @@ from .image_loading.image_loading_view_coordinator import ImageLoadingViewCoordinator from .image_loading.image_loading_controller import ImageLoadingController from .seg_loading.seg_loading_controller import SegmentationLoadingController -from engines.ceus.src.data_objs import UltrasoundImage, CeusSeg +from .analysis_loading.analysis_loading_controller import AnalysisLoadingController +from engines.ceus.src.data_objs.image import UltrasoundImage +from engines.ceus.src.data_objs.seg import CeusSeg class ApplicationController(QObject): @@ -37,9 +39,14 @@ def __init__(self, app: QApplication): # Unified application model self._model = ApplicationModel() + # Current data + self._image_data: Optional[UltrasoundImage] = None + self._seg_data: Optional[CeusSeg] = None + # Controllers for different screens (using the same model) self._image_loading_controller: Optional[ImageLoadingController] = None self._segmentation_controller: Optional[SegmentationLoadingController] = None + self._analysis_loading_controller: Optional[AnalysisLoadingController] = None # Setup main widget self._setup_main_widget() @@ -82,6 +89,8 @@ def _initialize_segmentation_loading(self, image_data: UltrasoundImage) -> None: Args: image_data: Loaded image data from previous screen """ + self._image_data = image_data + if self._segmentation_controller: self._cleanup_segmentation_loading() @@ -128,10 +137,58 @@ def _on_segmentation_action(self, action_name: str, action_data) -> None: """ if action_name == 'segmentation_confirmed': self._seg_data = self._segmentation_controller.get_loaded_segmentation() - # TODO: Navigate to analysis screen when implemented - print("Analysis screen coming soon...") - self._app.quit() + + # Use model data as source of truth + image_data = self._model.image_data if self._model.image_data else self._image_data + + self._initialize_analysis_loading(image_data, self._seg_data) + def _initialize_analysis_loading(self, image_data: UltrasoundImage, seg_data: CeusSeg) -> None: + """ + Initialize the analysis loading screen. + + Args: + image_data: Loaded image data + seg_data: Loaded segmentation data + """ + if self._analysis_loading_controller: + self._cleanup_analysis_loading() + + # Create controller with unified model + # Note: CEUS might need a config object, passing None for now if not available + self._analysis_loading_controller = AnalysisLoadingController(self._model, image_data, seg_data, None) + + # Connect signals + self._analysis_loading_controller.view.user_action.connect(self._on_analysis_action) + self._analysis_loading_controller.view.back_requested.connect(self._navigate_to_segmentation_loading) + + # Add to stack and show + self._widget_stack.addWidget(self._analysis_loading_controller.view) + self._widget_stack.setCurrentWidget(self._analysis_loading_controller.view) + + def _on_analysis_action(self, action_name: str, action_data) -> None: + """ + Handle actions from the analysis loading screen. + + Args: + action_name: Name of the action + action_data: Data associated with the action + """ + if action_name == 'analysis_loading_completed': + print("Analysis completed successfully!") + # Future: Navigate to visualization screen + self._app.quit() + + def _navigate_to_segmentation_loading(self) -> None: + """Navigate back to segmentation loading.""" + if self._analysis_loading_controller: + self._cleanup_analysis_loading() + + if self._segmentation_controller: + self._widget_stack.setCurrentWidget(self._segmentation_controller.view) + else: + self._initialize_segmentation_loading(self._image_data) + def _navigate_to_image_loading(self) -> None: """Navigate to image loading screen.""" # Reset image loading controller to initial state @@ -192,6 +249,15 @@ def _cleanup(self) -> None: """Clean up all resources before application exit.""" self._cleanup_image_loading() self._cleanup_segmentation_loading() + self._cleanup_analysis_loading() + + def _cleanup_analysis_loading(self) -> None: + """Clean up analysis loading controller resources.""" + if self._analysis_loading_controller: + self._widget_stack.removeWidget(self._analysis_loading_controller.view) + self._analysis_loading_controller.cleanup() + self._analysis_loading_controller.view.deleteLater() + self._analysis_loading_controller = None @property def image_data(self) -> Optional[UltrasoundImage]: diff --git a/src/ceus/application_model.py b/src/ceus/application_model.py index 2ae47fc..25e62c0 100644 --- a/src/ceus/application_model.py +++ b/src/ceus/application_model.py @@ -12,8 +12,11 @@ from .mvc.base_model import BaseModel from engines.ceus.src.image_loading.options import get_scan_loaders from engines.ceus.src.seg_loading.options import get_seg_loaders +from engines.ceus.src.time_series_analysis.options import get_analysis_types from engines.ceus.src.entrypoints import scan_loading_step, seg_loading_step -from engines.ceus.src.data_objs import UltrasoundImage, CeusSeg +from engines.ceus.src.data_objs.image import UltrasoundImage +from engines.ceus.src.data_objs.seg import CeusSeg +from engines.ceus.src.time_series_analysis.curves.framework import CurvesAnalysis class ScanLoadingWorker(QThread): @@ -75,6 +78,56 @@ def run(self): self.error_msg.emit(f"Error loading segmentation: {e}") +class AnalysisWorker(QThread): + """Worker thread for time-consuming analysis operations.""" + finished = pyqtSignal(object) + error_msg = pyqtSignal(str) + + def __init__(self, analysis_type: str, image_data: UltrasoundImage, + config_data: Any, seg_data: CeusSeg, + selected_functions: List[str], analysis_kwargs: Dict[str, Any]): + super().__init__() + self.analysis_type = analysis_type + self.image_data = image_data + self.config_data = config_data + self.seg_data = seg_data + self.selected_functions = selected_functions + self.analysis_kwargs = analysis_kwargs + + def run(self): + """Execute the analysis in background thread.""" + try: + from engines.ceus.src.time_series_analysis.options import get_analysis_types + all_types, _ = get_analysis_types() + + if self.analysis_type not in all_types: + self.error_msg.emit(f"Invalid analysis type: {self.analysis_type}") + return + + analysis_cls = all_types[self.analysis_type] + + # Initialize analysis + analysis_obj = analysis_cls( + self.image_data, + self.seg_data, + self.selected_functions, + **self.analysis_kwargs + ) + + # Execute analysis + # Note: For CEUS, execution might happen during init or via a specific method + # In time_series_analysis/curves/framework.py, init does some setup but maybe not full execution + if hasattr(analysis_obj, 'run'): + analysis_obj.run() + + self.finished.emit(analysis_obj) + + except Exception as e: + import traceback + traceback.print_exc() + self.error_msg.emit(f"Error during analysis: {e}") + + class ApplicationModel(BaseModel): """ Unified application model that manages all data and business logic for the QuantUS GUI. @@ -88,7 +141,9 @@ class ApplicationModel(BaseModel): # Additional signals for application-specific events image_loaded = pyqtSignal(UltrasoundImage) + preprocessing_complete = pyqtSignal(UltrasoundImage) segmentation_loaded = pyqtSignal(CeusSeg) + analysis_completed = pyqtSignal(object) # Emits CurvesAnalysis def __init__(self): super().__init__() @@ -105,9 +160,17 @@ def __init__(self): self._seg_data: Optional[CeusSeg] = None self._seg_worker: Optional[SegLoadingWorker] = None + # Analysis state + self._analysis_data: Optional[CurvesAnalysis] = None + self._analysis_types: Dict[str, Any] = {} + self._analysis_functions: Dict[str, Any] = {} + self._selected_analysis_type: Optional[str] = None + self._analysis_worker: Optional[AnalysisWorker] = None + # Initialize loaders self._load_scan_loaders() self._load_seg_loaders() + self._load_analysis_types() def _load_scan_loaders(self) -> None: """Load available scan loaders from backend.""" @@ -122,6 +185,15 @@ def _load_seg_loaders(self) -> None: self._seg_loaders = get_seg_loaders() except Exception as e: self._emit_error(f"Failed to load seg loaders: {e}") + + def _load_analysis_types(self) -> None: + """Load available analysis types from backend.""" + try: + self._analysis_types, self._analysis_functions = get_analysis_types() + except Exception as e: + print(f"Error loading analysis types: {e}") + self._analysis_types = {} + self._analysis_functions = {} # Image Loading Properties and Methods @property @@ -279,23 +351,21 @@ def get_preprocessing_kwargs_requirements(self, func_names: list) -> list: from engines.ceus.src.image_preprocessing.options import get_required_im_preproc_kwargs return get_required_im_preproc_kwargs(func_names) - def apply_preprocessing_preview(self, func_configs: List[Dict[str, Any]], image_data: Optional[UltrasoundImage] = None) -> UltrasoundImage: + def apply_preprocessing(self, func_configs: List[Dict[str, Any]]) -> None: """ - Apply preprocessing to the given UltrasoundImage. - This does not modify the image data in the model. + Apply preprocessing to the model's current image. + This modifies the image data in the model. Args: func_configs: List of dicts with 'name' and 'kwargs' for each function - image_data: Optional UltrasoundImage to preprocess (if None, uses current image) """ - if not image_data and not self._image_data: + if not self._image_data: self._emit_error("No image loaded to preprocess") return - - processed_image = image_data if image_data else self._image_data try: funcs = self.get_preprocessing_options() + processed_image = self._image_data for config in func_configs: name = config['name'] @@ -304,12 +374,12 @@ def apply_preprocessing_preview(self, func_configs: List[Dict[str, Any]], image_ processed_image = funcs[name](processed_image, **kwargs) else: print(f"WARNING: Preprocessing function {name} not found") - + + self._image_data = processed_image + self.preprocessing_complete.emit(self._image_data) except Exception as e: self._emit_error(f"Error during preprocessing: {e}") - return processed_image - def enhance_image(self, image: UltrasoundImage, func_configs: List[Dict[str, Any]]) -> UltrasoundImage: """ Enhance a given UltrasoundImage and return the result. @@ -544,6 +614,19 @@ def _on_segmentation_loading_complete(self, seg_data: CeusSeg) -> None: else: print(f"DEBUG: Segmentation loading failed - invalid seg data") self._emit_error("Failed to load segmentation data") + + def set_manual_segmentation(self, seg_data: CeusSeg) -> None: + """ + Set manually drawn segmentation data. + + Args: + seg_data: Manually drawn segmentation data + """ + if seg_data and hasattr(seg_data, 'seg_mask') and seg_data.seg_mask is not None: + self._seg_data = seg_data + self.segmentation_loaded.emit(seg_data) + else: + self._emit_error("Invalid manual segmentation data") def cleanup(self) -> None: """Clean up resources.""" @@ -556,3 +639,83 @@ def cleanup(self) -> None: self._seg_worker.quit() self._seg_worker.wait() self._seg_worker = None + + # ============================================================================ + # ANALYSIS METHODS + # ============================================================================ + + def get_analysis_types(self) -> tuple: + """Get available analysis types and functions.""" + return self._analysis_types, self._analysis_functions + + def set_analysis_type(self, analysis_type: str) -> bool: + """ + Set the selected analysis type. + + Args: + analysis_type: Analysis type to select + + Returns: + bool: True if successful + """ + if analysis_type in self._analysis_types: + self._selected_analysis_type = analysis_type + return True + else: + print(f"DEBUG: Invalid analysis type: {analysis_type}") + return False + + def get_analysis_functions(self, analysis_type: str) -> dict: + """ + Get available functions for an analysis type. + + Args: + analysis_type: Analysis type + + Returns: + dict: Available functions for the analysis type + """ + # In CEUS engine, analysis_functions is a flat dict of all available curve functions + # that are applicable to both 'curves' and 'curves_paramap' analysis types. + if analysis_type in self._analysis_functions and isinstance(self._analysis_functions[analysis_type], dict): + return self._analysis_functions[analysis_type] + + return self._analysis_functions + + def set_analysis_data(self, analysis_data: CurvesAnalysis) -> None: + """ + Store completed analysis data. + + Args: + analysis_data: Completed analysis data + """ + self._analysis_data = analysis_data + # Signal that analysis is complete + self.analysis_completed.emit(analysis_data) + + def run_analysis(self, analysis_type: str, image_data: UltrasoundImage, + config_data: Any, seg_data: CeusSeg, + selected_functions: List[str], **kwargs) -> None: + """ + Run the analysis in a background thread. + """ + # Stop existing worker if running + if self._analysis_worker and self._analysis_worker.isRunning(): + self._analysis_worker.quit() + self._analysis_worker.wait() + + self._analysis_worker = AnalysisWorker( + analysis_type, image_data, config_data, seg_data, selected_functions, kwargs + ) + + self._analysis_worker.finished.connect(self._on_analysis_worker_finished) + self._analysis_worker.error_msg.connect(self._emit_error) + + self._set_loading(True) + self._analysis_worker.start() + + def _on_analysis_worker_finished(self, analysis_obj: Any) -> None: + """Handle analysis completion.""" + self._set_loading(False) + self._analysis_data = analysis_obj + self.analysis_completed.emit(analysis_obj) diff --git a/src/ceus/image_loading/views/file_selection_widget.py b/src/ceus/image_loading/views/file_selection_widget.py index 5b095c6..23f8d09 100644 --- a/src/ceus/image_loading/views/file_selection_widget.py +++ b/src/ceus/image_loading/views/file_selection_widget.py @@ -116,7 +116,8 @@ def _show_loading_message(self) -> None: def _on_choose_image_path(self) -> None: """Handle image file selection.""" - if self._file_extensions == ["FOLDER"]: + is_folder = any(ext.upper() == "FOLDER" for ext in self._file_extensions) + if is_folder: dir_name = QFileDialog.getExistingDirectory(self, "Select Directory") if dir_name: self._ui.image_path_input.setText(dir_name) @@ -133,12 +134,16 @@ def _on_generate_image(self) -> None: if not os.path.exists(image_path): self.show_error(f"Image file does not exist: {os.path.basename(image_path)}") return - if not image_path.endswith(tuple(self._file_extensions)) and self._file_extensions != ['FOLDER']: - self.show_error(f"Image file must have one of the following extensions: {', '.join(self._file_extensions)}") - return - if self._file_extensions == ["FOLDER"] and not os.path.isdir(image_path): - self.show_error("Input path must be a folder!") - return + + is_folder = any(ext.upper() == "FOLDER" for ext in self._file_extensions) + if not is_folder: + if not image_path.endswith(tuple(self._file_extensions)): + self.show_error(f"Image file must have one of the following extensions: {', '.join(self._file_extensions)}") + return + else: + if not os.path.isdir(image_path): + self.show_error("Input path must be a folder!") + return self.clear_error() diff --git a/src/ceus/seg_loading/__init__.py b/src/ceus/seg_loading/__init__.py index 9047816..120438d 100644 --- a/src/ceus/seg_loading/__init__.py +++ b/src/ceus/seg_loading/__init__.py @@ -8,6 +8,8 @@ # Individual widget components from .views.seg_type_selection_widget import SegTypeSelectionWidget from .views.seg_file_selection_widget import SegFileSelectionWidget +from .views.seg_preview_widget import SegPreviewWidget +from .views.draw_roi_widget import DrawROIWidget __all__ = [ 'SegmentationLoadingModel', diff --git a/src/ceus/seg_loading/seg_loading_controller.py b/src/ceus/seg_loading/seg_loading_controller.py index eab66f6..513dd3b 100644 --- a/src/ceus/seg_loading/seg_loading_controller.py +++ b/src/ceus/seg_loading/seg_loading_controller.py @@ -35,15 +35,15 @@ def __init__(self, model: Optional[ApplicationModel] = None, custom_view=None): super().__init__(model, view) - # # Connect to model signals for automatic view updates - # self._connect_model_signals() + # Connect to model signals for automatic view updates + self._connect_model_signals() # Initialize view with segmentation loaders self._initialize_view() - # def _connect_model_signals(self) -> None: - # """Connect to model signals for automatic view updates.""" - # self.model.segmentation_loaded.connect(self.view.show_segmentation_preview) + def _connect_model_signals(self) -> None: + """Connect to model signals for automatic view updates.""" + self.model.segmentation_loaded.connect(self.view.show_segmentation_preview) def _initialize_view(self) -> None: """Initialize the view with data from the model.""" @@ -67,7 +67,10 @@ def handle_user_action(self, action_name: str, action_data: Any) -> None: elif action_name == 'apply_preprocs_preview': self._handle_preprocs_preview(action_data) elif action_name == 'segmentation_confirmed': - pass # Handle confirmation action in the application controller + # Ensure the model has the confirmed segmentation data + # This is especially important for manually drawn segmentations + if action_data: + self.model.set_manual_segmentation(action_data) else: raise ValueError(f"Unknown action: {action_name}") diff --git a/src/ceus/seg_loading/seg_loading_view_coordinator.py b/src/ceus/seg_loading/seg_loading_view_coordinator.py index 4f6d9d5..0656042 100644 --- a/src/ceus/seg_loading/seg_loading_view_coordinator.py +++ b/src/ceus/seg_loading/seg_loading_view_coordinator.py @@ -15,6 +15,7 @@ from .views.seg_file_selection_widget import SegFileSelectionWidget from .views.draw_roi_widget import DrawROIWidget from .views.draw_voi_widget import DrawVOIWidget +from .views.seg_preview_widget import SegPreviewWidget from engines.ceus.src.data_objs import UltrasoundImage, CeusSeg @@ -48,6 +49,7 @@ def __init__(self, image_data: UltrasoundImage, parent: Optional[QWidget] = None self._seg_type_widget: Optional[SegTypeSelectionWidget] = None self._seg_file_widget: Optional[SegFileSelectionWidget] = None self._voi_drawing_widget: Optional[DrawVOIWidget] = None + self._seg_preview_widget: Optional[SegPreviewWidget] = None # Current state self._selected_seg_type: Optional[str] = None @@ -109,6 +111,7 @@ def reset_to_seg_type_selection(self) -> None: widgets_to_remove = [ self._seg_file_widget, self._voi_drawing_widget, + self._seg_preview_widget, ] for widget in widgets_to_remove: @@ -181,6 +184,7 @@ def show_voi_drawing(self) -> None: self._voi_drawing_widget = DrawVOIWidget(self._image_data) # Connect signals to handle user actions + self._voi_drawing_widget.segmentation_completed.connect(self.show_segmentation_preview) self._voi_drawing_widget.back_requested.connect(self.reset_to_seg_type_selection) self._voi_drawing_widget.close_requested.connect(self.close_requested.emit) self._voi_drawing_widget.apply_preprocs_preview.connect(self._on_preprocs_preview_requested) @@ -201,6 +205,7 @@ def show_roi_drawing(self) -> None: self._roi_drawing_widget = DrawROIWidget(self._image_data) # Connect signals to handle user actions + self._roi_drawing_widget.segmentation_completed.connect(self.show_segmentation_preview) self._roi_drawing_widget.back_requested.connect(self.reset_to_seg_type_selection) self._roi_drawing_widget.close_requested.connect(self.close_requested.emit) @@ -208,6 +213,34 @@ def show_roi_drawing(self) -> None: self.addWidget(self._roi_drawing_widget) self.setCurrentWidget(self._roi_drawing_widget) + def show_segmentation_preview(self, seg_data: CeusSeg) -> None: + """ + Show the segmentation preview widget. + + Args: + seg_data: Loaded segmentation data + """ + # Avoid redundant preview if already showing this data + if self._seg_preview_widget and self._seg_data is seg_data: + self.setCurrentWidget(self._seg_preview_widget) + return + + self._seg_data = seg_data + + # Create and setup segmentation preview widget + self._seg_preview_widget = SegPreviewWidget(self._image_data, seg_data) + + # Connect signals to handle user actions + self._seg_preview_widget.segmentation_confirmed.connect( + lambda: self.user_action.emit('segmentation_confirmed', seg_data) + ) + self._seg_preview_widget.back_requested.connect(self.reset_to_seg_type_selection) + self._seg_preview_widget.close_requested.connect(self.close_requested.emit) + + # Add to stack and show + self.addWidget(self._seg_preview_widget) + self.setCurrentWidget(self._seg_preview_widget) + # ============================================================================ # USER ACTION HANDLING - Process user interactions and communicate with controller # ============================================================================ diff --git a/src/ceus/seg_loading/views/draw_roi_widget.py b/src/ceus/seg_loading/views/draw_roi_widget.py index f72738f..5c2d891 100644 --- a/src/ceus/seg_loading/views/draw_roi_widget.py +++ b/src/ceus/seg_loading/views/draw_roi_widget.py @@ -11,13 +11,25 @@ from scipy import interpolate from PIL import Image, ImageDraw from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas +from matplotlib.colors import LinearSegmentedColormap -from PyQt6.QtWidgets import QWidget, QHBoxLayout, QFileDialog +from PyQt6.QtWidgets import QWidget, QHBoxLayout, QFileDialog, QSlider, QVBoxLayout, QFrame, QCheckBox, QLabel from PyQt6.QtCore import pyqtSignal, Qt from ...mvc.base_view import BaseViewMixin from ..ui.draw_roi_ui import Ui_constructRoi from engines.ceus.src.data_objs import UltrasoundImage +from engines.ceus.src.image_preprocessing.functions import enhance_clahe, enhance_gamma +from engines.ceus.src.data_objs.seg import CeusSeg + +# Philips CEUS Colormap: Grayscale -> Red -> Yellow +philips_colors = [ + (0.0, 0.0, 0.0), # 0% - Black + (0.4, 0.4, 0.4), # 40% - Gray + (0.8, 0.0, 0.0), # 80% - Red + (1.0, 1.0, 0.0) # 100% - Yellow +] +philips_cmap = LinearSegmentedColormap.from_list("philips_ceus", philips_colors) class DrawROIWidget(QWidget, BaseViewMixin): @@ -31,6 +43,7 @@ class DrawROIWidget(QWidget, BaseViewMixin): # Signals for communicating with controller segmentation_saved = pyqtSignal(str) # emit with saved file path + segmentation_completed = pyqtSignal(object) # CeusSeg object back_requested = pyqtSignal() close_requested = pyqtSignal() @@ -59,8 +72,18 @@ def __init__(self, image_data: UltrasoundImage, parent: Optional[QWidget] = None self._target_frame = 0 # Target frame for smooth transitions self._frame_update_pending = False + # Enhancement parameters + self._clahe_clip_limit = 1.2 + self._gamma = 1.5 self._width_scale = 1.0 + # Enhancement parameters + self._clahe_clip_limit = 1.2 + self._gamma = 1.5 + self._use_philips_ceus = False + self._enhanced_cache = None # Cache for enhanced current frame + self._enhanced_cache_idx = -1 + self._setup_ui() self._connect_signals() self._show_draw_type_selection() @@ -96,6 +119,8 @@ def _setup_ui(self) -> None: 'roi_name_label', 'save_name_input', 'save_roi_button', 'back_from_save_button', ] + self._save_objects = self._save_seg_menu_objects + self._draw_types_objects = [ 'draw_rect_drag_type_button', 'draw_freehand_drag_type_button', 'draw_pts_type_button', ] @@ -112,6 +137,17 @@ def _setup_ui(self) -> None: # Setup matplotlib canvas for frame preview self._setup_matplotlib_canvas() + + # Add a "Confirm & Review" button programmatically + from PyQt6.QtWidgets import QPushButton + self.confirm_review_button = QPushButton("Confirm && Review", parent=self) + self.confirm_review_button.setMinimumSize(self._ui.save_roi_button.minimumSize()) + self.confirm_review_button.setMaximumSize(self._ui.save_roi_button.maximumSize()) + self.confirm_review_button.setStyleSheet(self._ui.save_roi_button.styleSheet()) + # Position it next to the save button in the layout + self._ui.chooseImageButtonsLayout_4.addWidget(self.confirm_review_button) + self.confirm_review_button.hide() + self._setup_enhancement_controls() # Display frame preview @@ -145,6 +181,7 @@ def _connect_signals(self) -> None: self._ui.clear_save_folder_button.clicked.connect(self._ui.save_folder_input.clear) self._ui.back_from_save_button.clicked.connect(self._show_draw_type_selection) self._ui.save_roi_button.clicked.connect(self._on_save_roi) + self.confirm_review_button.clicked.connect(self._on_confirm_review_clicked) def _initialize_frame_preview(self) -> None: """Initialize the frame preview with optimized matplotlib setup.""" @@ -256,52 +293,112 @@ def _update_aspect_ratio(self) -> None: self._matplotlib_canvas.draw_idle() def _setup_enhancement_controls(self) -> None: - """Add enhancement sliders to the sidebar.""" - from PyQt6.QtWidgets import QVBoxLayout, QLabel, QSlider, QFrame - + """Add enhancement sliders beside the frame slider in a single horizontal line.""" + # Container frame for enhancement controls enh_group = QFrame() enh_group.setStyleSheet("background-color: rgba(255, 255, 255, 0); border: none;") - container_layout = QVBoxLayout(enh_group) - container_layout.setContentsMargins(0, 10, 0, 10) + + # Main horizontal layout for the enhancement section + container_layout = QHBoxLayout(enh_group) + container_layout.setContentsMargins(0, 0, 15, 0) container_layout.setSpacing(15) - def create_enh_column(label_text, min_val, max_val, current_val, callback): - col_widget = QWidget() - col_layout = QVBoxLayout(col_widget) - col_layout.setContentsMargins(0, 0, 0, 0) - col_layout.setSpacing(5) + def create_compact_control(label_text, min_val, max_val, current_val, callback): + # Widget to hold label, slider, and value in ONE line + ctrl_widget = QWidget() + ctrl_layout = QHBoxLayout(ctrl_widget) + ctrl_layout.setContentsMargins(0, 0, 0, 0) + ctrl_layout.setSpacing(5) lbl = QLabel(label_text) - lbl.setStyleSheet("font-size: 14px; color: white; font-weight: bold;") - lbl.setAlignment(Qt.AlignmentFlag.AlignCenter) - col_layout.addWidget(lbl) + lbl.setStyleSheet("font-size: 10px; color: white; font-weight: bold;") + ctrl_layout.addWidget(lbl) - row_layout = QHBoxLayout() + # Slider slider = QSlider(Qt.Orientation.Horizontal) slider.setRange(min_val, max_val) slider.setValue(current_val) - slider.setMinimumWidth(100) - slider.setMaximumWidth(120) + slider.setStyleSheet(self._ui.frame_slider.styleSheet()) + slider.setFixedWidth(70) + slider.setFixedHeight(12) slider.valueChanged.connect(callback) - + ctrl_layout.addWidget(slider) + val_lbl = QLabel(f"{current_val/10.0:.1f}") - val_lbl.setMinimumWidth(40) - val_lbl.setStyleSheet("color: #3498db; font-weight: bold; font-size: 14px;") - - row_layout.addWidget(slider) - row_layout.addWidget(val_lbl) - col_layout.addLayout(row_layout) + val_lbl.setStyleSheet("color: #3498db; font-weight: bold; font-size: 10px;") + val_lbl.setMinimumWidth(22) + val_lbl.setAlignment(Qt.AlignmentFlag.AlignLeft) + ctrl_layout.addWidget(val_lbl) - return col_widget, slider, val_lbl + return ctrl_widget, slider, val_lbl - width_col, self.width_slider, self.width_val_lbl = create_enh_column( + # Create controls + clahe_w, self.clahe_slider, self.clahe_val_lbl = create_compact_control( + "CLAHE", 1, 100, int(self._clahe_clip_limit * 10), self._on_clahe_changed + ) + gamma_w, self.gamma_slider, self.gamma_val_lbl = create_compact_control( + "GAMMA", 1, 40, int(self._gamma * 10), self._on_gamma_changed + ) + width_w, self.width_slider, self.width_val_lbl = create_compact_control( "WIDTH", 1, 50, int(self._width_scale * 10), self._on_width_changed ) - container_layout.addWidget(width_col) + # Pseudo colouring toggle nicely aligned + self.philips_check = QCheckBox("Pseudo colouring") + self.philips_check.setStyleSheet("color: white; font-weight: bold; font-size: 11px;") + self.philips_check.stateChanged.connect(self._on_philips_toggled) + + # Add to horizontal layout + container_layout.addWidget(clahe_w) + container_layout.addWidget(gamma_w) + container_layout.addWidget(width_w) + container_layout.addWidget(self.philips_check) + + # Add to the layout beside the frame slider (below the image) + self._ui.frameControlsLayout.insertWidget(0, enh_group) + + def _on_clahe_changed(self, value: int) -> None: + """Handle CLAHE clip limit change.""" + self._clahe_clip_limit = value / 10.0 + if hasattr(self, 'clahe_val_lbl'): + self.clahe_val_lbl.setText(f"{self._clahe_clip_limit:.1f}") + self._invalidate_enhancement_cache() + + def _on_gamma_changed(self, value: int) -> None: + """Handle gamma change.""" + self._gamma = value / 10.0 + if hasattr(self, 'gamma_val_lbl'): + self.gamma_val_lbl.setText(f"{self._gamma:.1f}") + self._invalidate_enhancement_cache() + + def _on_philips_toggled(self, state: int) -> None: + """Handle Philips CEUS pseudocolor toggle.""" + self._use_philips_ceus = state == Qt.CheckState.Checked.value + # Update colormap on artist + if self._im_artist: + new_cmap = philips_cmap if self._use_philips_ceus else 'gray' + self._im_artist.set_cmap(new_cmap) + self._matplotlib_canvas.draw_idle() - # Add to the layout below the frame slider - self._ui.side_bar_layout.addWidget(enh_group) + def _invalidate_enhancement_cache(self) -> None: + """Invalidate the enhancement cache and trigger display update.""" + self._enhanced_cache = None + self._enhanced_cache_idx = -1 + self._force_frame_update() + + def _enhance_frame(self, frame_2d: np.ndarray) -> np.ndarray: + """Enhance a 2D image frame using backend engine functions.""" + # Create a temporary UltrasoundImage for processing + temp_im = UltrasoundImage(self._image_data.scan_path) + temp_im.pixel_data = frame_2d + temp_im.pixdim = self._image_data.pixdim + temp_im.frame_rate = self._image_data.frame_rate + + # Apply enhancements + temp_im = enhance_clahe(temp_im, clip_limit=self._clahe_clip_limit) + temp_im = enhance_gamma(temp_im, gamma=self._gamma) + + return temp_im.pixel_data def _on_frame_changed(self, value: int) -> None: """Handle frame slider change with optimized performance.""" @@ -312,8 +409,18 @@ def _on_frame_changed(self, value: int) -> None: def _update_frame_display(self, frame_index: int) -> None: """Update the frame display with consistent parameters.""" if self._im_artist: - self._displayed_im = self._all_frames[frame_index] + # Update cache if needed + if self._enhanced_cache is None or self._enhanced_cache_idx != frame_index: + self._enhanced_cache = self._enhance_frame(self._all_frames[frame_index]) + self._enhanced_cache_idx = frame_index + + self._displayed_im = self._enhanced_cache self._im_artist.set_array(self._displayed_im) + + # Ensure correct colormap is applied (e.g. after initialization) + new_cmap = philips_cmap if self._use_philips_ceus else 'gray' + self._im_artist.set_cmap(new_cmap) + self._ui.cur_frame_label.setText(str(np.round(frame_index*self._image_data.frame_rate, decimals=2))) def _force_frame_update(self) -> None: @@ -649,6 +756,7 @@ def _select_dest_folder(self) -> None: def _hide_save_menu(self) -> None: """Hide the save menu.""" + self.confirm_review_button.hide() for obj_name in self._save_seg_menu_objects: widget = getattr(self._ui, obj_name, None) if widget: @@ -689,6 +797,7 @@ def _show_save_menu(self) -> None: widget.show() else: print(f"Warning: Widget '{obj_name}' not found in UI") + self.confirm_review_button.show() def _hide_draw_type_selection(self) -> None: """Hide the draw type selection layout.""" @@ -847,3 +956,48 @@ def _on_save_roi(self) -> None: self.segmentation_saved.emit(nii_path) print(f"Segmentation saved to: {nii_path}") + + def _on_confirm_review_clicked(self) -> None: + """Handle confirmation and transition to formal review screen.""" + + # Ensure there is a drawn ROI to confirm + if len(self._roi_plot_coords[0]) < 3: + self.show_error("Please draw a valid region of interest before confirming.") + return + + # Create binary mask from drawn ROI + spline = [(self._roi_plot_coords[0][i], self._roi_plot_coords[1][i]) for i in range(len(self._roi_plot_coords[0]))] + + # Note: self._all_frames shape is [t, y, x] (or similar) + # Based on _on_save_roi, it seems to be [t, y, x] + mask_2d = Image.new("L", (self._all_frames[self._frame].shape[1], self._all_frames[self._frame].shape[0]), 0) + ImageDraw.Draw(mask_2d).polygon(spline, outline=1, fill=1) + mask_2d = np.array(mask_2d, dtype=np.uint8) + + # Create CeusSeg object + seg_data = CeusSeg() + seg_data.seg_name = f"Manual_{self._image_data.scan_name}" + + # CEUS expects 3D mask (x, y, z) + # We need to create a 3D mask where this 2D ROI is on one slice or repeated. + # However, for consistency with DrawVOIWidget, we probably want a 3D volume. + # If DrawROIWidget is only for a single frame, z_len should match what is expected. + x_len, y_len, z_len = self._image_data.pixel_data.shape[:3] + seg_mask = np.zeros((x_len, y_len, z_len), dtype=np.uint8) + + # Translate 2D mask [y, x] to [x, y, z] slice + # Assuming the ROI was drawn on a specific slice? + # Actually DrawROIWidget seems to be for 2D images or a specific frame. + # If the image is 4D [x, y, z, t], maybe it was drawn on the central slice? + # Let's assume it was for 2D or we put it on the middle slice of 3D. + mid_z = z_len // 2 + + # Handle shape mismatch if any + if mask_2d.shape[1] == x_len and mask_2d.shape[0] == y_len: + seg_mask[:, :, mid_z] = mask_2d.T + + seg_data.seg_mask = seg_mask + seg_data.pixdim = self._image_data.pixdim[:3] + + # Emit signal to coordinator + self.segmentation_completed.emit(seg_data) diff --git a/src/ceus/seg_loading/views/draw_voi_widget.py b/src/ceus/seg_loading/views/draw_voi_widget.py index 858a869..0f77831 100644 --- a/src/ceus/seg_loading/views/draw_voi_widget.py +++ b/src/ceus/seg_loading/views/draw_voi_widget.py @@ -4,23 +4,26 @@ from pathlib import Path from typing import Optional, Tuple, List +from scipy.ndimage import binary_fill_holes, binary_erosion +from matplotlib.backends.backend_qtagg import FigureCanvas +from matplotlib.path import Path as Mpl_Path +from matplotlib.colors import LinearSegmentedColormap +from PyQt6.QtWidgets import QWidget, QLabel, QHBoxLayout, QSizePolicy, QFileDialog, QSlider, QVBoxLayout, QFrame, QCheckBox, QPushButton +from PyQt6.QtCore import QEvent, pyqtSignal, Qt, QThread + import numpy as np import nibabel as nib -from scipy.ndimage import binary_fill_holes import matplotlib.pyplot as plt import matplotlib.animation as anim -from matplotlib.backends.backend_qtagg import FigureCanvas -from matplotlib.path import Path as Mpl_Path -from matplotlib.colors import LinearSegmentedColormap import scipy.interpolate as interpolate from scipy.spatial import ConvexHull -from PyQt6.QtWidgets import QWidget, QLabel, QHBoxLayout, QSizePolicy, QFileDialog, QSlider, QVBoxLayout, QFrame, QCheckBox -from PyQt6.QtCore import QEvent, pyqtSignal, Qt, QThread +import traceback from ...mvc.base_view import BaseViewMixin from ..ui.draw_voi_ui import Ui_voi_drawer -from engines.ceus.src.data_objs import UltrasoundImage +from engines.ceus.src.data_objs import UltrasoundImage, CeusSeg from .spline import calculateSpline3D, calculateSpline +from engines.ceus.src.image_preprocessing.functions import enhance_clahe, enhance_gamma # Philips CEUS Colormap: Grayscale -> Red -> Yellow philips_colors = [ @@ -98,7 +101,6 @@ def run(self): self.finished.emit(voi_mask) except Exception as e: - import traceback traceback.print_exc() self.error_msg.emit(f"Error interpolating VOI: {e}") @@ -129,6 +131,7 @@ class DrawVOIWidget(QWidget, BaseViewMixin): # Signals for communicating with controller file_selected = pyqtSignal(dict) # {'seg_path': str, 'seg_type': str} + segmentation_completed = pyqtSignal(object) # CeusSeg object back_requested = pyqtSignal() close_requested = pyqtSignal() apply_preprocs_preview = pyqtSignal(list) # List of dicts with 'name' and 'kwargs' keys @@ -177,7 +180,7 @@ def __init__(self, image_data: UltrasoundImage, parent: Optional[QWidget] = None # Per-plane resources (axial, sagittal, coronal) self._ax_sag_cor_matplotlib_canvases = [None, None, None] self._ax_sag_cor_planes = (None, None, None) - self._ax_sag_cor_index_maps = ((0, 1), (2, 1), (2, 0)) # dims that vary per plane + self._ax_sag_cor_index_maps = ((0, 1), (2, 1), (0, 2)) # (horiz_dim, vert_dim) self._ax_sag_cor_animations = [None, None, None] self._ax_sag_cor_plane_artists = [None, None, None] self._ax_sag_cor_crosshair_lines = [(None, None), (None, None), (None, None)] @@ -196,8 +199,8 @@ def __init__(self, image_data: UltrasoundImage, parent: Optional[QWidget] = None self._connect_signals() self._connect_matplotlib_events() self.setFocusPolicy(Qt.FocusPolicy.StrongFocus) - self._update_scan_display() # Initial UI update - self._refresh_frames() # Mark all planes for first update + self._update_scan_display() # Initial UI update + self._refresh_frames() # Mark all planes for first update def update_enhancement_cache(self, enhanced_frame: np.ndarray, frame: int) -> None: """Update the displayed image data, e.g. after preprocessing.""" @@ -314,6 +317,17 @@ def _setup_ui(self) -> None: self._ui.restart_voi_button, self._ui.save_voi_button, ] + + # Add a "Confirm & Review" button programmatically + self.confirm_review_button = QPushButton("Confirm && Review", parent=self._ui.horizontalLayoutWidget_4) + self.confirm_review_button.setMinimumSize(self._ui.save_voi_button.minimumSize()) + self.confirm_review_button.setMaximumSize(self._ui.save_voi_button.maximumSize()) + self.confirm_review_button.setStyleSheet(self._ui.save_voi_button.styleSheet()) + # Move it to a reasonable position - maybe next to save button + self.confirm_review_button.setGeometry(self._ui.save_voi_button.geometry().translated(0, 50)) + self.confirm_review_button.hide() + self._voi_decision_widgets.append(self.confirm_review_button) + self._save_voi_widgets = [ self._ui.back_from_save_button, self._ui.dest_folder_label, @@ -492,29 +506,23 @@ def _update_aspect_ratios(self) -> None: try: pix = self._image_data.pixdim - # Index 0: Axial (Plane 0) + # Index 0: Axial (Plane 0) -> (Y, X) -> Rows=Y, Cols=X -> dy / dx if self._ax_sag_cor_matplotlib_canvases[0]: dx, dy = pix[0], pix[1] aspect = (dy / dx if dx != 0 else 1.0) * self._width_scale_axial - fig0 = self._ax_sag_cor_matplotlib_canvases[0].figure - if fig0.axes: - fig0.axes[0].set_aspect(aspect) + self._ax_sag_cor_matplotlib_canvases[0].figure.gca().set_aspect(aspect) - # Index 1: Sagittal (Plane 1) + # Index 1: Sagittal (Plane 1) -> 90 CW Rotation -> (Y, Z) -> Rows=Y, Cols=Z -> dy / dz if self._ax_sag_cor_matplotlib_canvases[1]: dy, dz = pix[1], pix[2] aspect = (dy / dz if dz != 0 else 1.0) * self._width_scale_sagittal - fig1 = self._ax_sag_cor_matplotlib_canvases[1].figure - if fig1.axes: - fig1.axes[0].set_aspect(aspect) + self._ax_sag_cor_matplotlib_canvases[1].figure.gca().set_aspect(aspect) - # Index 2: Coronal (Plane 2) + # Index 2: Coronal (Plane 2) -> (Z, X) -> Rows=Z, Cols=X -> dz / dx if self._ax_sag_cor_matplotlib_canvases[2]: dx, dz = pix[0], pix[2] - aspect = (dx / dz if dz != 0 else 1.0) * self._width_scale_coronal - fig2 = self._ax_sag_cor_matplotlib_canvases[2].figure - if fig2.axes: - fig2.axes[0].set_aspect(aspect) + aspect = (dz / dx if dx != 0 else 1.0) * self._width_scale_coronal + self._ax_sag_cor_matplotlib_canvases[2].figure.gca().set_aspect(aspect) for canvas in self._ax_sag_cor_matplotlib_canvases: if canvas: @@ -583,7 +591,7 @@ def _initialize_plane_displays(self) -> None: ax = fig.add_subplot(111) ax.axis('off') # Get initial slice - slice_arr = self._get_plane_slice(plane_ix, initializing=True) + slice_arr = self._get_plane_slice(plane_ix) mask_arr = self._get_mask_slice(plane_ix) current_cmap = philips_cmap if self._use_philips_ceus else 'gray' @@ -605,21 +613,19 @@ def _initialize_plane_displays(self) -> None: except Exception as e: self.show_error(f"Error initializing plane display {plane_ix}: {e}") - def _get_plane_slice(self, plane_ix: int, initializing=False): + def _get_plane_slice(self, plane_ix: int): """Return 2D numpy slice for given plane index based on current crosshair.""" idx = self._get_plane_indices(plane_ix) current_t = self._crosshair_xyzt[3] # Check if we need to enhance a new frame - if not initializing and (self._enhanced_cache is None or self._enhanced_cache_frame != current_t): + if self._enhanced_cache is None or self._enhanced_cache_frame != current_t: # Get the 3D volume for current time frame current_frame_3d = self._pix_data[:, :, :, current_t] # Enhance the entire 3D volume ONCE per frame - self._enhance_volume(current_frame_3d) # performs enhancement SYNCHRONOUSLY + self._enhanced_cache = self._enhance_volume(current_frame_3d) self._enhanced_cache_frame = current_t - elif initializing: - self._enhanced_cache = self._image_data.pixel_data[:, :, :, current_t] # Cache the initial frame without enhancement for faster startup # Extract the 2D slice from cached enhanced volume slice_idx = list(idx[:3]) # Remove time dimension @@ -627,49 +633,41 @@ def _get_plane_slice(self, plane_ix: int, initializing=False): if arr.ndim != 2: arr = arr.squeeze() - # Axial plane (index 0) needs transpose for correct orientation - if plane_ix == 0: - arr = arr.T - return arr + # All planes need transpose to match (Vertical, Horizontal) orientation. + # Sagittal (plane 1) specifically needs a 90 deg clockwise rotation. + arr_t = arr.T + if plane_ix == 1: + return np.rot90(arr_t, k=-1) + return arr_t - def _enhance_volume(self, volume_3d: np.ndarray) -> None: + def _enhance_volume(self, volume_3d: np.ndarray) -> np.ndarray: """Enhance a 3D image volume using predefined enhancement methods in the backend engine.""" # Create a temporary UltrasoundImage for the current frame temp_im = UltrasoundImage(self._image_data.scan_path) - temp_im.pixel_data = volume_3d.T[None].T.copy() # Add back time dimension for processing + temp_im.pixel_data = volume_3d temp_im.pixdim = self._image_data.pixdim temp_im.frame_rate = self._image_data.frame_rate - - clahe_preproc_dict = { - 'name': 'enhance_clahe', - 'image_data': temp_im, - 'frame_ix': self._crosshair_xyzt[3], - 'kwargs': { - 'clip_limit': self._clahe_clip_limit, - 'tile_grid_size': (8, 8), - } - } - - gamma_preproc_dict = { - 'name': 'enhance_gamma', - 'image_data': None, # signal to reuse the already CLAHE-enhanced image (all preprocs in the same batch share the same image input) - 'frame_ix': self._crosshair_xyzt[3], - 'kwargs': { - 'gamma': self._gamma, - } - } - - preproc_dicts = [clahe_preproc_dict, gamma_preproc_dict] - self.apply_preprocs_preview.emit(preproc_dicts) # synchronous call to apply the enhancements and update the cache via the connected slot + + # Apply backend engine functions directly on the UltrasoundImage object + temp_im = enhance_clahe(temp_im, clip_limit=self._clahe_clip_limit) + temp_im = enhance_gamma(temp_im, gamma=self._gamma) + + return temp_im.pixel_data def _get_mask_slice(self, plane_ix: int): """Return RGBA numpy slice for the mask of the given plane index.""" idx = self._get_plane_indices(plane_ix)[:-1] # no time dimension arr = self._roi_masks_overlap[idx] - # Mask needs transpose for correct orientation to match the image slice - if plane_ix == 0: - arr = np.transpose(arr, (1, 0, 2)) # Transpose for axial plane - return arr + + # Consistent mapping for all planes as areas: + # Axial: (X, Y) slice -> want (Y, X) for imshow -> transpose (1, 0, 2) + # Sagittal: (Y, Z) slice -> want (Z, Y) then rot -> transpose (1, 0, 2) + rot90 + # Coronal: (X, Z) slice -> want (Z, X) for imshow -> transpose (1, 0, 2) + + arr_reg = np.transpose(arr, (1, 0, 2)) + if plane_ix == 1: + return np.rot90(arr_reg, k=-1) + return arr_reg def _get_plane_indices(self, plane_ix: int) -> Tuple[int]: """Return a list of indices for the given plane.""" @@ -837,6 +835,7 @@ def _connect_signals(self) -> None: self._ui.back_from_save_button.clicked.connect(self._on_back_from_save) self._ui.toggle_crosshair_visibility_button.clicked.connect(self._on_toggle_crosshair_visibility) self._ui.save_voi_button.clicked.connect(self._on_save_voi_clicked) + self.confirm_review_button.clicked.connect(self._on_confirm_review_clicked) # Configure slice/time controls self._ui.cur_slice_slider.setMinimum(0) @@ -948,9 +947,16 @@ def _on_roi_close(self): if plane_ix == 0: # Axial target_slice_mask[:, :, fixed_val] = mask_2d.T elif plane_ix == 1: # Sagittal - target_slice_mask[fixed_val, :, :] = mask_2d + # mask_2d from meshgrid (Y_len, X_len) where X_idx=z, Y_idx=y + # In sagittal plane_ix=1: vary_x=z(2), vary_y=y(1). fixed=x(0) + # mask_2d shape is (y_len, z_len). + # Meshgrid with 'xy' returns (rows=Y, cols=X), so mask_2d is (Y, Z). + # The display uses rot90(arr.T, k=-1) which is (Z, Y). + # To match the display, we must rot90 back: rot90(mask_2d, k=1).T + target_slice_mask[fixed_val, :, :] = np.rot90(mask_2d, k=1).T elif plane_ix == 2: # Coronal - target_slice_mask[:, fixed_val, :] = mask_2d + # mask_2d from meshgrid (Y_len, X_len) where X_idx=x, Y_idx=z + target_slice_mask[:, fixed_val, :] = mask_2d.T # Apply colors to the RGBA mask where the 3D mask is true current_roi_mask_rgba[target_slice_mask, 0] = 255 # Red @@ -1030,6 +1036,27 @@ def _on_save_voi_clicked(self): self._show_widget_lists([self._save_voi_widgets, self._voi_alpha_widgets]) self._refresh_frames() + def _on_confirm_review_clicked(self): + """Handle confirmation and transition to formal review screen.""" + + # Create CeusSeg object from current mask + seg_data = CeusSeg() + seg_data.seg_name = f"Manual_{self._image_data.scan_name}" + # Extract the binary mask from the overlap RGBA buffer (red channel > 0) + seg_data.seg_mask = (self._roi_masks_overlap[:, :, :, 0] > 0).astype(np.uint8) + seg_data.pixdim = self._image_data.pixdim[:3] + + # Preserve current visualization parameters for the preview step + seg_data.clahe_clip_limit = self._clahe_clip_limit + seg_data.gamma = self._gamma + seg_data.width_scale_axial = self._width_scale_axial + seg_data.width_scale_sagittal = self._width_scale_sagittal + seg_data.width_scale_coronal = self._width_scale_coronal + seg_data.use_philips_ceus = self._use_philips_ceus + + # Emit signal to coordinator + self.segmentation_completed.emit(seg_data) + def _on_export_voi_clicked(self): # Show saving label, hide save widgets self._ui.saving_voi_label.show() @@ -1201,28 +1228,25 @@ def _remove_duplicates(self, points: List[List[float]]) -> List[List[float]]: def _on_interpolate_voi(self): """Handle VOI interpolation from the drawn 2D ROIs.""" - if len(self._drawn_rois) == 2 or not len(self._drawn_rois): - print("At least 3 ROIs on different planes or 1 ROI is required for 3D interpolation.") + if len(self._drawn_rois) < 1: + print("At least 1 ROI is required for 3D interpolation.") return # Combine all points from all drawn ROIs all_points = [] - for _, pts, _ in self._drawn_rois: - xyz_pts = np.array(pts)[:, :3].T - x_interp, y_interp, z_interp = calculateSpline(*xyz_pts) - all_points.extend(zip(x_interp, y_interp, z_interp)) + for plane_num, pts, _ in self._drawn_rois: + all_points.extend(pts) # Ensure no duplicate points are used for interpolation unique_points = self._remove_duplicates(all_points) - if len(unique_points) < 4: + if len(unique_points) < 3: self.show_error("Interpolation Error", "Not enough unique points for 3D spline interpolation.") return # Perform 3D spline interpolation - x_coords, y_coords, z_coords = zip(*unique_points) - coords = np.transpose([x_coords, y_coords, z_coords]) + coords = np.array(unique_points) - if len(self._drawn_rois) > 2: + if len(self._drawn_rois) > 1: # Stop any existing worker if self._voi_interpolation_worker and self._voi_interpolation_worker.isRunning(): self._voi_interpolation_worker.quit() @@ -1241,22 +1265,14 @@ def _on_interpolate_voi(self): self._set_interp_loading(True) self._voi_interpolation_worker.start() else: + # Single ROI Case - just fill the area of that single mesh voi_mask = np.zeros((self._x_len, self._y_len, self._z_len), dtype=bool) - # For simplicity, we'll mark the voxels the spline passes through. - # A more robust solution would involve filling the volume enclosed by the spline surface. - interp_points = np.round(np.array(list(coords))).astype(int) - - # Clamp points to be within bounds - interp_points[:, 0] = np.clip(interp_points[:, 0], 0, self._x_len - 1) - interp_points[:, 1] = np.clip(interp_points[:, 1], 0, self._y_len - 1) - interp_points[:, 2] = np.clip(interp_points[:, 2], 0, self._z_len - 1) - - voi_mask[interp_points[:, 0], interp_points[:, 1], interp_points[:, 2]] = True + # For a single ROI, we extract the boolean mask from the stored RGBA mask + # The red channel (index 0) is set to 255 for the drawn mask. + _, _, roi_mask_rgba = self._drawn_rois[0] + voi_mask = roi_mask_rgba[:, :, :, 0] > 0 - # Fill holes in the resulting mask to create a solid volume - voi_mask = _smooth_3d_mask(voi_mask) - self._hide_widget_lists([self._drawing_widgets]) self._on_interpolation_finished(voi_mask) def _save_voi(self): @@ -1273,13 +1289,18 @@ def _save_voi(self): out_path = Path(self._ui.save_folder_input.text()) / out_name + # Construct affine matching standard NIfTI orientation affine = np.eye(4) for i, res in enumerate(self._image_data.pixdim[:3]): affine[i, i] = res - voi_mask = np.array(self._roi_masks_overlap[:, :, :, 0] / 255.0).astype(np.uint8) - niiarray = nib.Nifti1Image(voi_mask, affine) - niiarray.header["descrip"] = self._image_data.scan_name - nib.save(niiarray, out_path) + + # Ensure binary mask is correctly extracted from the overlay buffer + voi_mask = (self._roi_masks_overlap[:, :, :, 0] > 0).astype(np.uint8) + + # Save as NIfTI image + nii_img = nib.Nifti1Image(voi_mask, affine) + nii_img.header["descrip"] = self._image_data.scan_name + nib.save(nii_img, out_path) def _set_interp_loading(self, loading_state: bool) -> None: """Set the interpolation loading state.""" diff --git a/src/ceus/seg_loading/views/seg_preview_widget.py b/src/ceus/seg_loading/views/seg_preview_widget.py new file mode 100644 index 0000000..509f966 --- /dev/null +++ b/src/ceus/seg_loading/views/seg_preview_widget.py @@ -0,0 +1,840 @@ +""" +Segmentation Preview Widget for CEUS +""" + +from typing import Optional, Tuple, List, Dict, Any +from pathlib import Path +from matplotlib.backends.backend_qtagg import FigureCanvas +from matplotlib.colors import LinearSegmentedColormap +from PyQt6.QtWidgets import QWidget, QLabel, QHBoxLayout, QSizePolicy, QSlider, QVBoxLayout, QFrame, QCheckBox, QPushButton, QFileDialog +from PyQt6.QtCore import QEvent, pyqtSignal, Qt + +import numpy as np +import nibabel as nib +import matplotlib.pyplot as plt +import matplotlib.animation as anim + +from ...mvc.base_view import BaseViewMixin +from ..ui.draw_voi_ui import Ui_voi_drawer +from engines.ceus.src.data_objs import UltrasoundImage, CeusSeg +from engines.ceus.src.image_preprocessing.functions import enhance_clahe, enhance_gamma + +# Philips CEUS Colormap: Grayscale -> Red -> Yellow +philips_colors = [ + (0.0, 0.0, 0.0), # 0% - Black + (0.4, 0.4, 0.4), # 40% - Gray + (0.8, 0.0, 0.0), # 80% - Red + (1.0, 1.0, 0.0) # 100% - Yellow +] +philips_cmap = LinearSegmentedColormap.from_list("philips_ceus", philips_colors) + +class SegPreviewWidget(QWidget, BaseViewMixin): + """ + Widget for previewing and confirming segmentation for CEUS. + Reuses UI components from VOI drawer but in a read-only preview mode. + """ + + # Signals for communicating with controller + segmentation_confirmed = pyqtSignal() + back_requested = pyqtSignal() + close_requested = pyqtSignal() + + def __init__(self, image_data: UltrasoundImage, seg_data: CeusSeg, parent: Optional[QWidget] = None): + QWidget.__init__(self, parent) + self.__init_base_view__(parent) + self._ui = Ui_voi_drawer() + self._image_data = image_data + self._seg_data = seg_data + self._pix_data = image_data.pixel_data + + # Enhancement parameters (Inherited from seg_data) + self._clahe_clip_limit = getattr(seg_data, 'clahe_clip_limit', 1.2) + self._gamma = getattr(seg_data, 'gamma', 1.5) + self._width_scale_axial = getattr(seg_data, 'width_scale_axial', 1.0) + self._width_scale_sagittal = getattr(seg_data, 'width_scale_sagittal', 1.0) + self._width_scale_coronal = getattr(seg_data, 'width_scale_coronal', 1.0) + self._use_philips_ceus = getattr(seg_data, 'use_philips_ceus', False) + self._mask_alpha = 125 # Default alpha for mask overlay (0-255) + + # Cache for enhanced volume + self._enhanced_cache = None + self._enhanced_cache_frame = -1 + + # Crosshair / navigation state + self._crosshair_active = False + self._crosshair_visible = True + + # Dimensions: x, y, z, t + if self._pix_data.ndim == 4: + self._x_len, self._y_len, self._z_len, self._num_slices = self._pix_data.shape + else: + # Fallback if 3D + self._x_len, self._y_len, self._z_len = self._pix_data.shape + self._num_slices = 1 + self._pix_data = self._pix_data.reshape((self._x_len, self._y_len, self._z_len, 1)) + + self._crosshair_xyzt = [self._x_len // 2, self._y_len // 2, self._z_len // 2, 0] + + # Segmentation mask overlay + self._roi_masks_overlap = np.zeros((self._x_len, self._y_len, self._z_len, 4), dtype=np.uint8) + self._seg_mask_indices = None # Store binary mask indices for alpha updates + if hasattr(seg_data, 'seg_mask') and seg_data.seg_mask is not None: + # seg_mask should be same spatial shape (x, y, z) + mask = seg_data.seg_mask + + # Ensure spatial alignment with image data if dimensions are flipped or permuted + # If the mask was saved from DrawVOIWidget, it should match the pixel_data shape + if mask.shape == (self._x_len, self._y_len, self._z_len): + self._seg_mask_indices = np.where(mask > 0) + elif mask.shape == (self._y_len, self._x_len, self._z_len): + # Handle common XY transpose if detected + self._seg_mask_indices = np.where(mask.transpose(1, 0, 2) > 0) + else: + # Log or handle shape mismatch more gracefully if needed + print(f"Warning: Mask shape {mask.shape} does not match image shape {(self._x_len, self._y_len, self._z_len)}") + # Try to fit the mask as much as possible if shapes match in 3D volume + try: + self._seg_mask_indices = np.where(mask[:self._x_len, :self._y_len, :self._z_len] > 0) + except Exception: + pass + + if self._seg_mask_indices is not None and len(self._seg_mask_indices[0]) > 0: + self._roi_masks_overlap[self._seg_mask_indices[0], self._seg_mask_indices[1], self._seg_mask_indices[2]] = [255, 0, 0, 125] + + # Jump crosshair to a point within the mask to show it immediately + mask_indices = np.where(self._roi_masks_overlap[..., 3] > 0) + if len(mask_indices[0]) > 0: + mid_idx = len(mask_indices[0]) // 2 + self._crosshair_xyzt = [ + mask_indices[0][mid_idx], + mask_indices[1][mid_idx], + mask_indices[2][mid_idx], + 0 + ] + else: + self._crosshair_xyzt = [self._x_len // 2, self._y_len // 2, self._z_len // 2, 0] + + # Per-plane resources (axial, sagittal, coronal) + self._ax_sag_cor_matplotlib_canvases = [None, None, None] + self._ax_sag_cor_planes = (None, None, None) + self._ax_sag_cor_index_maps = ((0, 1), (2, 1), (0, 2)) # (horiz_dim, vert_dim) + self._ax_sag_cor_animations = [None, None, None] + self._ax_sag_cor_plane_artists = [None, None, None] + self._ax_sag_cor_crosshair_lines = [(None, None), (None, None), (None, None)] + self._ax_sag_cor_pending = [False, False, False] + self._ax_sag_cor_seg_masks = [None, None, None] + + # UI & visualization setup + self._setup_ui() + self._setup_matplotlib_canvases() + self._initialize_plane_displays() + self._setup_all_plane_animations() + self._connect_signals() + self._connect_matplotlib_events() + self.setFocusPolicy(Qt.FocusPolicy.StrongFocus) + self._update_scan_display() + self._refresh_frames() + + def _setup_ui(self) -> None: + """Setup the user interface to match the segmentation menu style.""" + self._ui.setupUi(self) + + # Store QLabels as tags for layout mapping + self._ax_sag_cor_planes = (self._ui.ax_plane, self._ui.sag_plane, self._ui.cor_plane) + + # Configure layout + self.setLayout(self._ui.full_screen_layout) + self._ui.full_screen_layout.setStretchFactor(self._ui.side_bar_layout, 1) + self._ui.full_screen_layout.setStretchFactor(self._ui.voi_layout, 10) + + # Widget groups matching DrawVOIWidget for consistency + self._drawing_widgets = [ + self._ui.draw_roi_button, + self._ui.interpolate_voi_button, + self._ui.undo_last_pt_button, + self._ui.close_roi_button, + self._ui.undo_last_roi_button, + self._ui.construct_voi_label, + ] + + # Add a "Confirm & Analysis" button programmatically + self.confirm_review_button = QPushButton("Confirm && Analysis") + self.confirm_review_button.setMinimumSize(self._ui.save_voi_button.minimumSize()) + self.confirm_review_button.setMaximumSize(self._ui.save_voi_button.maximumSize()) + self.confirm_review_button.setStyleSheet(self._ui.save_voi_button.styleSheet()) + + # Insert it into the layout that has Restart and Save + self._ui.horizontalLayout_2.addWidget(self.confirm_review_button) + + # Update existing button texts for clarity in preview mode + self._ui.restart_voi_button.setText("Review / Redraw") + self._ui.save_voi_button.setText("Save Setup") + + self._voi_decision_widgets = [ + self._ui.restart_voi_button, + self._ui.save_voi_button, + self.confirm_review_button + ] + + self._save_voi_widgets = [ + self._ui.back_from_save_button, + self._ui.dest_folder_label, + self._ui.voi_name_label, + self._ui.save_folder_input, + self._ui.save_name_input, + self._ui.choose_save_folder_button, + self._ui.clear_save_folder_button, + self._ui.export_voi_button, + ] + + self._voi_alpha_widgets = [ + self._ui.alpha_label, + self._ui.alpha_of_label, + self._ui.alpha_spin_box, + self._ui.alpha_status, + self._ui.alpha_total + ] + + # Initial visibility + self._ui.scan_name_input.setText(self._image_data.scan_name) + self._ui.segSidebarLabel_2.setText("Segmentation Selection") + self._ui.toggle_crosshair_visibility_button.setText('Hide Crosshair') + self._ui.cur_slice_label.setText("Current Frame:") + + self._hide_widget_lists([self._drawing_widgets, self._save_voi_widgets, self._voi_alpha_widgets]) + self._show_widget_lists([self._voi_decision_widgets]) + + # Hide original plane labels (replaced by canvases) + # Note: We do NOT hide ax_plane, sag_plane, cor_plane here because they are needed as containers + for widget in [self._ui.interp_loading_label, self._ui.saving_voi_label]: + widget.hide() + + self._ui.navigating_label.hide() + self._ui.observing_label.show() + + # Setup enhancement controls + self._setup_enhancement_controls() + + # Update slider for frames + self._ui.cur_slice_slider.setMinimum(0) + self._ui.cur_slice_slider.setMaximum(self._num_slices - 1) + self._ui.cur_slice_slider.setValue(0) + self._ui.cur_slice_total.setText(str(self._num_slices)) + self._ui.cur_slice_spin_box.setRange(1, self._num_slices) + self._ui.cur_slice_spin_box.setValue(1) + + self._ui.ax_total_frames.setText(str(self._z_len)) + self._ui.sag_total_frames.setText(str(self._x_len)) + self._ui.cor_total_frames.setText(str(self._y_len)) + + # Install event filters + for label in self._ax_sag_cor_planes: + if label: + label.installEventFilter(self) + + def _cleanup_animations(self): + """Internal helper to stop animations safely.""" + for i in range(3): + if i < len(self._ax_sag_cor_animations) and self._ax_sag_cor_animations[i]: + try: + self._ax_sag_cor_animations[i].event_source.stop() + except Exception: + pass + self._ax_sag_cor_animations[i] = None + + # ============================================================================ + # UI SETUP & HELPERS + # ============================================================================ + + def _show_widget_lists(self, widget_lists: List[List[QWidget]]) -> None: + """Helper to show groups of widgets.""" + for widget_list in widget_lists: + for widget in widget_list: + widget.show() + + def _hide_widget_lists(self, widget_lists: List[List[QWidget]]) -> None: + """Helper to hide groups of widgets.""" + for widget_list in widget_lists: + for widget in widget_list: + widget.hide() + + def _setup_enhancement_controls(self) -> None: + """Add enhancement sliders to the sidebar, mirroring DrawVOIWidget style.""" + enh_group = QFrame() + enh_group.setStyleSheet("background-color: rgba(255, 255, 255, 0); border: none;") + + container_layout = QVBoxLayout(enh_group) + container_layout.setContentsMargins(0, 10, 0, 10) + container_layout.setSpacing(15) + + row1_layout = QHBoxLayout() + row2_layout = QHBoxLayout() + row1_layout.setSpacing(20) + row2_layout.setSpacing(20) + + def create_enh_column(label_text, min_val, max_val, current_val, callback): + col_widget = QWidget() + col_layout = QVBoxLayout(col_widget) + col_layout.setContentsMargins(0, 0, 0, 0) + col_layout.setSpacing(5) + + lbl = QLabel(label_text) + lbl.setStyleSheet("font-size: 14px; color: white; font-weight: bold;") + lbl.setAlignment(Qt.AlignmentFlag.AlignCenter) + col_layout.addWidget(lbl) + + row_layout = QHBoxLayout() + slider = QSlider(Qt.Orientation.Horizontal) + slider.setRange(min_val, max_val) + slider.setValue(current_val) + slider.setStyleSheet(self._ui.cur_slice_slider.styleSheet()) + slider.setMinimumWidth(80) + slider.setMaximumWidth(120) + slider.valueChanged.connect(callback) + + val_lbl = QLabel(f"{current_val/10.0:.1f}") + val_lbl.setMinimumWidth(40) + val_lbl.setStyleSheet("color: #3498db; font-weight: bold; font-size: 14px;") + + row_layout.addWidget(slider) + row_layout.addWidget(val_lbl) + col_layout.addLayout(row_layout) + return col_widget, slider, val_lbl + + # Sliders + clahe_col, self.clahe_slider, self.clahe_val_lbl = create_enh_column( + "CLAHE", 1, 100, int(self._clahe_clip_limit * 10), self._on_clahe_changed + ) + gamma_col, self.gamma_slider, self.gamma_val_lbl = create_enh_column( + "GAMMA", 1, 40, int(self._gamma * 10), self._on_gamma_changed + ) + width_ax_col, self.width_ax_slider, self.width_ax_val_lbl = create_enh_column( + "WIDTH (AX)", 1, 50, int(self._width_scale_axial * 10), self._on_width_axial_changed + ) + width_sag_col, self.width_sag_slider, self.width_sag_val_lbl = create_enh_column( + "WIDTH (SAG)", 1, 50, int(self._width_scale_sagittal * 10), self._on_width_sagittal_changed + ) + width_cor_col, self.width_cor_slider, self.width_cor_val_lbl = create_enh_column( + "WIDTH (COR)", 1, 50, int(self._width_scale_coronal * 10), self._on_width_coronal_changed + ) + + # Add VOI Alpha slider + alpha_col, self.alpha_slider, self.alpha_val_lbl = create_enh_column( + "VOI ALPHA", 0, 2550, int(self._mask_alpha * 10), lambda v: self._on_alpha_changed(v // 10) + ) + # Fix label to show integer for alpha + self.alpha_val_lbl.setText(str(self._mask_alpha)) + self.alpha_slider.valueChanged.disconnect() + self.alpha_slider.valueChanged.connect(lambda v: (self._on_alpha_changed(v // 10), self.alpha_val_lbl.setText(str(v // 10)))) + + row1_layout.addWidget(clahe_col) + row1_layout.addWidget(gamma_col) + + self.philips_check = QCheckBox("Pseudocoloring") + self.philips_check.setChecked(self._use_philips_ceus) + self.philips_check.setStyleSheet("color: white; font-weight: bold; font-size: 14px;") + self.philips_check.stateChanged.connect(self._on_philips_toggled) + row1_layout.addWidget(self.philips_check) + row1_layout.addWidget(alpha_col) + + row2_layout.addWidget(width_ax_col) + row2_layout.addWidget(width_sag_col) + row2_layout.addWidget(width_cor_col) + + container_layout.addLayout(row1_layout) + container_layout.addLayout(row2_layout) + self._ui.verticalLayout_2.addWidget(enh_group) + + def _invalidate_enhancement_cache(self) -> None: + """Clear cache when processing parameters change.""" + self._enhanced_cache = None + self._enhanced_cache_frame = -1 + self._refresh_frames() + + def _on_clahe_changed(self, value: int) -> None: + """Handle CLAHE change.""" + self._clahe_clip_limit = value / 10.0 + if hasattr(self, 'clahe_val_lbl'): + self.clahe_val_lbl.setText(f"{self._clahe_clip_limit:.1f}") + self._invalidate_enhancement_cache() + + def _on_gamma_changed(self, value: int) -> None: + """Handle gamma change.""" + self._gamma = value / 10.0 + if hasattr(self, 'gamma_val_lbl'): + self.gamma_val_lbl.setText(f"{self._gamma:.1f}") + self._invalidate_enhancement_cache() + + def _on_width_axial_changed(self, value: int) -> None: + """Handle axial aspect ratio.""" + self._width_scale_axial = value / 10.0 + if hasattr(self, 'width_ax_val_lbl'): + self.width_ax_val_lbl.setText(f"{self._width_scale_axial:.1f}") + self._update_aspect_ratios() + + def _on_width_sagittal_changed(self, value: int) -> None: + """Handle sagittal aspect ratio.""" + self._width_scale_sagittal = value / 10.0 + if hasattr(self, 'width_sag_val_lbl'): + self.width_sag_val_lbl.setText(f"{self._width_scale_sagittal:.1f}") + self._update_aspect_ratios() + + def _on_width_coronal_changed(self, value: int) -> None: + """Handle coronal aspect ratio.""" + self._width_scale_coronal = value / 10.0 + if hasattr(self, 'width_cor_val_lbl'): + self.width_cor_val_lbl.setText(f"{self._width_scale_coronal:.1f}") + self._update_aspect_ratios() + + def _on_alpha_changed(self, value: int) -> None: + """Handle alpha transparency change for the VOI mask.""" + self._mask_alpha = value + # Update the rgba mask transparency + # Use stored mask indices to ensure we can recover from alpha=0 + if self._seg_mask_indices is not None and len(self._seg_mask_indices[0]) > 0: + # Re-apply color (Red) and new alpha to the relevant indices + self._roi_masks_overlap[self._seg_mask_indices[0], self._seg_mask_indices[1], self._seg_mask_indices[2]] = [255, 0, 0, self._mask_alpha] + self._refresh_frames() + + def _update_aspect_ratios(self) -> None: + """Update artist aspect ratios based on physics (pixdim) and sliders.""" + if not hasattr(self, '_image_data') or self._image_data is None: + return + + try: + pix = self._image_data.pixdim + + # Plane 0: Axial (XY) -> show (Y, X) -> Rows=Y, Cols=X -> dy / dx + if self._ax_sag_cor_matplotlib_canvases[0]: + dx, dy = pix[0], pix[1] + aspect_ax = (dy / dx if dx != 0 else 1.0) * self._width_scale_axial + self._ax_sag_cor_matplotlib_canvases[0].figure.gca().set_aspect(aspect_ax) + + # Plane 1: Sagittal (YZ) -> 90 CW Rotation -> show (Y, Z) -> Rows=Y, Cols=Z -> dy / dz + if self._ax_sag_cor_matplotlib_canvases[1]: + dy, dz = pix[1], pix[2] + aspect_sag = (dy / dz if dz != 0 else 1.0) * self._width_scale_sagittal + self._ax_sag_cor_matplotlib_canvases[1].figure.gca().set_aspect(aspect_sag) + + # Plane 2: Coronal (XZ) -> show (Z, X) -> Rows=Z, Cols=X -> dz / dx + if self._ax_sag_cor_matplotlib_canvases[2]: + dx, dz = pix[0], pix[2] + aspect_cor = (dz / dx if dx != 0 else 1.0) * self._width_scale_coronal + self._ax_sag_cor_matplotlib_canvases[2].figure.gca().set_aspect(aspect_cor) + + for canvas in self._ax_sag_cor_matplotlib_canvases: + if canvas: + canvas.draw_idle() + + self._refresh_frames() + except Exception as e: + print(f"Error updating aspect ratios: {e}") + + def _on_philips_toggled(self, state: int) -> None: + """Handle Philips CEUS pseudocolor toggle.""" + self._use_philips_ceus = state == Qt.CheckState.Checked.value + new_cmap = philips_cmap if self._use_philips_ceus else 'gray' + for artist in self._ax_sag_cor_plane_artists: + if artist: + artist.set_cmap(new_cmap) + self._refresh_frames() + + def _setup_matplotlib_canvases(self) -> None: + """Initialize and embed matplotlib canvases into each plane's placeholder layout.""" + # Use the axial, sagittal, coronal labels themselves as the parent for the canvases + # This ensures they are inside their respective QFrame boxes and aligned correctly + for i, parent_label in enumerate(self._ax_sag_cor_planes): + fig, ax = plt.subplots(facecolor='black') + fig.subplots_adjust(left=0, right=1, top=1, bottom=0) + ax.axis('off') + canvas = FigureCanvas(fig) + canvas.setParent(parent_label) + + # Use Expanding policy + canvas.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + + # Hide the label's text but keep the background/frame + parent_label.setText("") + + self._ax_sag_cor_matplotlib_canvases[i] = canvas + + def _initialize_plane_displays(self) -> None: + """Initial render of each orthogonal plane with fixed intensity scaling.""" + for i, canvas in enumerate(self._ax_sag_cor_matplotlib_canvases): + if not canvas: continue + ax = canvas.figure.axes[0] + + # Initial data slice + slice_data = self._get_plane_slice(i) + # Use fixed vmin/vmax to prevent auto-scaling contrast per slice + artist = ax.imshow(slice_data, cmap='gray', interpolation='nearest', + zorder=1, vmin=0, vmax=255) + self._ax_sag_cor_plane_artists[i] = artist + + # Mask overlay + mask_slice = self._get_mask_slice(i) + # Match interpolation and zorder to DrawVOIWidget + mask_artist = ax.imshow(mask_slice, interpolation='nearest', + zorder=8) + self._ax_sag_cor_seg_masks[i] = mask_artist + + # Crosshair lines + # Get actual coordinate indices for current plane from maps + idx_x, idx_y = self._ax_sag_cor_index_maps[i] + v_line = ax.axvline(x=self._crosshair_xyzt[idx_x], color='yellow', lw=0.8, animated=True, zorder=11) + h_line = ax.axhline(y=self._crosshair_xyzt[idx_y], color='yellow', lw=0.8, animated=True, zorder=11) + self._ax_sag_cor_crosshair_lines[i] = (v_line, h_line) + + self._update_aspect_ratios() + + def _get_plane_slice(self, plane_ix: int) -> np.ndarray: + """Extract a 2D image slice for the specified plane at current crosshair indices.""" + x, y, z, t = self._crosshair_xyzt + vol = self._get_enhanced_volume(t) + + if plane_ix == 0: # Axial (XY) at Z -> show (Y, X) + return vol[:, :, z].T + elif plane_ix == 1: # Sagittal (YZ) at X -> show (Z, Y) then rotate 90 CW -> (Y, Z) + # Match DrawVOIWidget approach: arr.T then rot90(k=-1) + arr = vol[x, :, :] + arr_t = arr.T + return np.rot90(arr_t, k=-1) + elif plane_ix == 2: # Coronal (XZ) at Y -> show (Y, X) + # Mirror Axial for Coronal to match DrawVOI behavior + return vol[:, y, :].T + return np.zeros((10, 10)) + + def _get_mask_slice(self, plane_ix: int) -> np.ndarray: + """Extract a 2D mask slice for overlay.""" + x, y, z, _ = self._crosshair_xyzt + if plane_ix == 0: # Axial (XY) at Z -> show (Y, X) + # Match DrawVOIWidget approach: index mask then transpose (1, 0, 2) + arr = self._roi_masks_overlap[:, :, z, :] + return np.transpose(arr, (1, 0, 2)) + elif plane_ix == 1: # Sagittal (YZ) at X -> show (Z, Y) then rotate 90 CW -> (Y, Z) + arr = self._roi_masks_overlap[x, :, :, :] + # Consistent with DrawVOIWidget: (1, 0, 2) then rot90(k=-1) + arr_t = np.transpose(arr, (1, 0, 2)) + return np.rot90(arr_t, k=-1) + elif plane_ix == 2: # Coronal (XZ) at Y -> show (Y, X) + # Mirror Axial for Coronal to match DrawVOI behavior + arr = self._roi_masks_overlap[:, y, :, :] + return np.transpose(arr, (1, 0, 2)) + return np.zeros((10, 10, 4), dtype=np.uint8) + + def _get_enhanced_volume(self, t: int) -> np.ndarray: + """Apply image processing and return the 3D volume at frame t.""" + if self._enhanced_cache is not None and self._enhanced_cache_frame == t: + return self._enhanced_cache + + # Extract the 3D volume for current frame + vol_3d = self._pix_data[:, :, :, t] + + # Create a temporary UltrasoundImage for the engine preprocessors + temp_im = UltrasoundImage(self._image_data.scan_path) + temp_im.pixel_data = vol_3d + temp_im.pixdim = self._image_data.pixdim + temp_im.frame_rate = self._image_data.frame_rate + + # Apply backend engine functions + temp_im = enhance_clahe(temp_im, clip_limit=self._clahe_clip_limit) + temp_im = enhance_gamma(temp_im, gamma=self._gamma) + + self._enhanced_cache = temp_im.pixel_data + self._enhanced_cache_frame = t + return self._enhanced_cache + + def _setup_all_plane_animations(self) -> None: + """Setup refresh animations for each matplotlib canvas.""" + for i in range(3): + canvas = self._ax_sag_cor_matplotlib_canvases[i] + self._ax_sag_cor_animations[i] = anim.FuncAnimation( + canvas.figure, + lambda frame, p_ix=i: self._update_plane(p_ix), + interval=33, + blit=True, + cache_frame_data=False + ) + + def _update_plane(self, plane_ix: int): + """Update artist data for a single plane.""" + # Always return the list of artists for blitting + v_line, h_line = self._ax_sag_cor_crosshair_lines[plane_ix] + artist = self._ax_sag_cor_plane_artists[plane_ix] + mask_artist = self._ax_sag_cor_seg_masks[plane_ix] + + artists = [] + if artist: artists.append(artist) + if mask_artist: artists.append(mask_artist) + if v_line: + v_line.set_visible(self._crosshair_visible) + artists.append(v_line) + if h_line: + h_line.set_visible(self._crosshair_visible) + artists.append(h_line) + + if not self._ax_sag_cor_pending[plane_ix]: + return artists + + if artist: + artist.set_data(self._get_plane_slice(plane_ix)) + if mask_artist: + mask_artist.set_data(self._get_mask_slice(plane_ix)) + + if v_line and h_line: + idx_x, idx_y = self._ax_sag_cor_index_maps[plane_ix] + # When refreshing (e.g. slice changed), snap back to stored indices + v_line.set_xdata([self._crosshair_xyzt[idx_x]]) + h_line.set_ydata([self._crosshair_xyzt[idx_y]]) + v_line.set_visible(self._crosshair_visible) + h_line.set_visible(self._crosshair_visible) + + self._ax_sag_cor_pending[plane_ix] = False + return artists + + def _connect_signals(self) -> None: + """Connect UI signals to internal handlers, matching DrawVOIWidget patterns.""" + # Frame/Time Navigation + self._ui.cur_slice_slider.valueChanged.connect(self._on_slice_slider_changed) + self._ui.cur_slice_spin_box.valueChanged.connect(lambda v: self._ui.cur_slice_slider.setValue(int(v)-1)) + self._ui.toggle_crosshair_visibility_button.clicked.connect(self._on_toggle_crosshair) + self._ui.back_button.clicked.connect(self._on_back_requested) + + # Decision Buttons + self.confirm_review_button.clicked.connect(self._on_confirm_review) + self._ui.restart_voi_button.clicked.connect(self._on_back_requested) + self._ui.save_voi_button.clicked.connect(self._on_save_voi_clicked) + + # Save Form Actions + self._ui.back_from_save_button.clicked.connect(self._on_back_from_save_clicked) + self._ui.choose_save_folder_button.clicked.connect(self._on_choose_save_folder) + self._ui.clear_save_folder_button.clicked.connect(lambda: self._ui.save_folder_input.clear()) + self._ui.export_voi_button.clicked.connect(self._on_export_voi_clicked) + + def _on_back_requested(self): + """Handle back request with cleanup.""" + self._cleanup_animations() + self.back_requested.emit() + + def _on_confirm_review(self): + """Handle confirmation with cleanup.""" + self._cleanup_animations() + self.segmentation_confirmed.emit() + + def _on_save_voi_clicked(self) -> None: + """Switch to the save file configuration menu.""" + self._hide_widget_lists([self._voi_decision_widgets]) + self._show_widget_lists([self._save_voi_widgets, self._voi_alpha_widgets]) + # Default save name + self._ui.save_name_input.setText(f"{self._image_data.scan_name}_mask") + + def _on_back_from_save_clicked(self) -> None: + """Switch back from save menu to decision menu.""" + self._hide_widget_lists([self._save_voi_widgets, self._voi_alpha_widgets]) + self._show_widget_lists([self._voi_decision_widgets]) + + def closeEvent(self, event): + """Clean up animations and canvases before the widget is destroyed.""" + self._cleanup_animations() + + for i in range(3): + canvas = self._ax_sag_cor_matplotlib_canvases[i] + if canvas: + try: + plt.close(canvas.figure) + except Exception: + pass + self._ax_sag_cor_matplotlib_canvases[i] = None + + super().closeEvent(event) + + def _on_choose_save_folder(self) -> None: + """Open directory dialog for saving.""" + folder = QFileDialog.getExistingDirectory(self, "Select Save Directory") + if folder: + self._ui.save_folder_input.setText(folder) + + def _on_export_voi_clicked(self) -> None: + """Export the current 3D mask to NIfTI.""" + folder_path = self._ui.save_folder_input.text() + file_name = self._ui.save_name_input.text() + + if not folder_path or not Path(folder_path).is_dir(): + self.show_error("Please select a valid folder.") + return + if not file_name: + self.show_error("Please enter a file name.") + return + + if not file_name.endswith('.nii.gz'): + file_name += '.nii.gz' + + out_path = Path(folder_path) / file_name + + try: + self.show_loading() + # Construct affine + affine = np.eye(4) + for i, res in enumerate(self._image_data.pixdim[:3]): + affine[i, i] = res + + # The mask is stored in self._seg_data.seg_mask + mask = self._seg_data.seg_mask + nii_img = nib.Nifti1Image(mask, affine) + nii_img.header["descrip"] = self._image_data.scan_name + nib.save(nii_img, out_path) + + self.hide_loading() + # After export, show decision again + self._on_back_from_save_clicked() + except Exception as e: + self.hide_loading() + self.show_error(f"Export failed: {str(e)}") + + def _on_slice_slider_changed(self, value: int) -> None: + """Handle time-series frame change.""" + self._crosshair_xyzt[3] = value + self._ui.cur_slice_spin_box.blockSignals(True) + self._ui.cur_slice_spin_box.setValue(value + 1) + self._ui.cur_slice_spin_box.blockSignals(False) + self._refresh_frames() + + def _on_toggle_crosshair(self) -> None: + """Toggle crosshair visibility.""" + self._crosshair_visible = not self._crosshair_visible + self._ui.toggle_crosshair_visibility_button.setText( + 'Show Crosshair' if not self._crosshair_visible else 'Hide Crosshair' + ) + self._refresh_frames() + + def _refresh_frames(self) -> None: + """Mark all planes for refresh.""" + self._ax_sag_cor_pending = [True, True, True] + + def _update_scan_display(self) -> None: + """Sync UI labels with current crosshair indices. Using 1-based indexing for display.""" + self._ui.ax_frame_num.setText(str(self._crosshair_xyzt[2] + 1)) + self._ui.sag_frame_num.setText(str(self._crosshair_xyzt[0] + 1)) + self._ui.cor_frame_num.setText(str(self._crosshair_xyzt[1] + 1)) + + # Update spinbox for t + self._ui.cur_slice_spin_box.blockSignals(True) + self._ui.cur_slice_spin_box.setValue(self._crosshair_xyzt[3] + 1) + self._ui.cur_slice_spin_box.blockSignals(False) + + def set_crosshair(self, x=None, y=None, z=None, t=None): + """Update crosshair position and trigger refresh.""" + changed = False + if x is not None and 0 <= x < self._x_len: + self._crosshair_xyzt[0] = x; changed = True + if y is not None and 0 <= y < self._y_len: + self._crosshair_xyzt[1] = y; changed = True + if z is not None and 0 <= z < self._z_len: + self._crosshair_xyzt[2] = z; changed = True + if t is not None and 0 <= t < self._num_slices: + self._crosshair_xyzt[3] = t; changed = True + + if changed: + self._update_scan_display() + self._refresh_frames() + + # ======================= Resize Handling ================================= + def eventFilter(self, obj, event): # type: ignore + if event.type() == QEvent.Type.Resize and obj in self._ax_sag_cor_planes: + self._resize_canvas_for(obj) + return super().eventFilter(obj, event) + + def _resize_canvas_for(self, label_widget: QLabel): + try: + idx = self._ax_sag_cor_planes.index(label_widget) + except ValueError: + return + canvas = self._ax_sag_cor_matplotlib_canvases[idx] + if not canvas: + return + + # Match canvas size to the QLabel/placeholder size + # We ensure it fills the parent label precisely + canvas_width = label_widget.width() + canvas_height = label_widget.height() + canvas.setFixedSize(canvas_width, canvas_height) + canvas.move(0, 0) + canvas.draw_idle() + + def _resize_all_canvases(self): + """Force a resize of all embedded matplotlib canvases.""" + for label in self._ax_sag_cor_planes: + if label: + self._resize_canvas_for(label) + + def showEvent(self, event): + # Ensure canvases sized properly when shown + self._resize_all_canvases() + return super().showEvent(event) + + def _connect_matplotlib_events(self): + """Connect motion and click events on each plane's matplotlib canvas.""" + for plane_ix, canvas in enumerate(self._ax_sag_cor_matplotlib_canvases): + if not canvas: continue + canvas.mpl_connect('motion_notify_event', lambda e, p=plane_ix: self._on_canvas_motion(e, p)) + canvas.mpl_connect('button_press_event', lambda e, p=plane_ix: self._on_canvas_click(e, p)) + + def _on_canvas_click(self, event, plane_ix: int): + if event.inaxes is None: return + self._crosshair_active = not self._crosshair_active + if self._crosshair_active: + self._ui.navigating_label.show() + self._ui.observing_label.hide() + else: + self._ui.navigating_label.hide() + self._ui.observing_label.show() + self._on_canvas_motion(event, plane_ix) + + def _on_canvas_motion(self, event, plane_ix: int): # type: ignore + """Handle mouse movement over a plane and update crosshair indices. + + event.xdata maps to the first varying dimension of that plane slice, + event.ydata to the second. We clamp to valid ranges and call set_crosshair + only if the index meaningfully changed. + """ + if not self._crosshair_active: + return + if event.inaxes is None or event.xdata is None or event.ydata is None: + return + + vary_dims = self._ax_sag_cor_index_maps[plane_ix] + dim_x, dim_y = vary_dims[0], vary_dims[1] + + # Dimension lengths mapping + dim_lengths = [self._x_len, self._y_len, self._z_len, self._num_slices] + + # Proposed new indices (int rounding & clamp) + new_xval = int(round(event.xdata)) + new_yval = int(round(event.ydata)) + if new_xval < 0 or new_yval < 0: + return + if new_xval >= dim_lengths[dim_x] or new_yval >= dim_lengths[dim_y]: + return + + # Build kwargs for set_crosshair only for dims that change + params = {} + if self._crosshair_xyzt[dim_x] != new_xval: + key = ['x','y','z','t'][dim_x] + params[key] = new_xval + if self._crosshair_xyzt[dim_y] != new_yval: + key = ['x','y','z','t'][dim_y] + params[key] = new_yval + + if params: + self.set_crosshair(**params) + + def _update_hover_crosshair(self, x, y, plane_ix): + """Update crosshair lines to follow mouse hover.""" + v_line, h_line = self._ax_sag_cor_crosshair_lines[plane_ix] + if v_line and h_line: + v_line.set_xdata([x, x]) + h_line.set_ydata([y, y]) + v_line.set_visible(self._crosshair_visible) + h_line.set_visible(self._crosshair_visible) + # Only update the background for this ONE canvas + self._ax_sag_cor_matplotlib_canvases[plane_ix].draw_idle() diff --git a/src/ceus/seg_loading/views/spline.py b/src/ceus/seg_loading/views/spline.py index 65bf06e..e65263a 100644 --- a/src/ceus/seg_loading/views/spline.py +++ b/src/ceus/seg_loading/views/spline.py @@ -17,10 +17,25 @@ def calculateSpline(xpts, ypts, zpts=None): # 2D spline interpolation cv.append([xpts[i], ypts[i], zpts[i]]) else: cv.append([xpts[i], ypts[i]]) + + # Remove duplicate points which cause "ValueError: Invalid inputs" in splprep cv = np.array(cv) - if len(xpts) == 2: + if len(cv) > 1: + # Calculate distances between consecutive points + diffs = np.diff(cv, axis=0) + dists = np.sqrt(np.sum(diffs**2, axis=1)) + # Keep first point and points that are sufficiently far from the previous one + mask = np.concatenate(([True], dists > 1e-5)) + cv = cv[mask] + + if len(cv) < 2: + if zpts is not None: + return np.array([cv[0][0]]), np.array([cv[0][1]]), np.array([cv[0][2]]) + return np.array([cv[0][0]]), np.array([cv[0][1]]) + + if len(cv) == 2: tck, _ = interpolate.splprep(cv.T, s=0.0, k=1) - elif len(xpts) == 3: + elif len(cv) == 3: tck, _ = interpolate.splprep(cv.T, s=0.0, k=2) else: tck, _ = interpolate.splprep(cv.T, s=0.0, k=3) @@ -54,6 +69,11 @@ def ellipsoidFitLS(pos): def calculateSpline3D(points): + # If the points have a 4th dimension (time), we strip it as pyvista expects (x, y, z) + points = np.array(points) + if points.shape[1] == 4: + points = points[:, :3] + cloud = pv.PolyData(points, force_float=False) volume = cloud.delaunay_3d(alpha=100) shell = volume.extract_geometry() # type: ignore