diff --git a/.github/workflows/build_cython_extensions.yml b/.github/workflows/build_cython_extensions.yml new file mode 100644 index 00000000..50c1a175 --- /dev/null +++ b/.github/workflows/build_cython_extensions.yml @@ -0,0 +1,73 @@ +name: Build Cython extensions + +permissions: + contents: write + +on: + push: + paths: + - "cellacdc/**/*.pyx" + - "setup.py" + workflow_dispatch: + +jobs: + build: + runs-on: ${{ matrix.os }} + defaults: + run: + working-directory: ${{ github.workspace }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ["3.10", "3.11", "3.12", "3.13"] + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Build extension + run: | + python -m pip install --upgrade "setuptools>=77" cython numpy + python "${{ github.workspace }}/setup.py" build_ext --inplace --build-temp "${{ github.workspace }}/build/temp" + + - name: Move to precompiled folder + shell: bash + run: | + mkdir -p cellacdc/precompiled + mv cellacdc/regionprops_helper.*.so cellacdc/precompiled/ 2>/dev/null || true + mv cellacdc/regionprops_helper.*.pyd cellacdc/precompiled/ 2>/dev/null || true + + - uses: actions/upload-artifact@v4 + with: + name: precompiled-${{ matrix.os }}-py${{ matrix.python-version }} + path: cellacdc/precompiled/ + + commit: + needs: build + runs-on: ubuntu-latest + defaults: + run: + working-directory: ${{ github.workspace }} + steps: + - uses: actions/checkout@v4 + with: + ref: ${{ github.ref_name }} + + - uses: actions/download-artifact@v4 + with: + pattern: precompiled-* + merge-multiple: true + path: cellacdc/precompiled/ + + - name: Commit precompiled binaries + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + git add cellacdc/precompiled/ + git diff --staged --quiet || git commit -m "ci: update precompiled Cython extensions" + 
!cellacdc/precompiled/**
self._rebuildSizes(bold=bold) + return + + sizes = getattr(self, sizes_attr) + for text in texts: + sizes[text] = int(np.round(self.fontSize*current_max_scale/scales[text])) def clearData(self): self.setData([], []) @@ -254,15 +293,28 @@ def initSymbols(self, allIDs, onlyIDs=False): self.createSymbols(annotTexts) def addSymbols(self, annotTexts, includeBold=True): - for text in annotTexts: - if includeBold: - self.symbolsBold[text] = self.getObjTextAnnotSymbol( - text, bold=True, initSizes=False + if includeBold: + missing_bold = [ + text for text in annotTexts if text not in self.symbolsBold + ] + if missing_bold: + symbolsBold, scalesBold = plot.texts_to_pg_scatter_symbols( + missing_bold, font=self.fontBold, return_scales=True ) - self.symbolsRegular[text] = self.getObjTextAnnotSymbol( - text, bold=True, initSizes=False + self.symbolsBold.update(symbolsBold) + self.scalesBold.update(scalesBold) + self._updateSizesForTexts(missing_bold, bold=True) + + missing_regular = [ + text for text in annotTexts if text not in self.symbolsRegular + ] + if missing_regular: + symbolsRegular, scalesRegular = plot.texts_to_pg_scatter_symbols( + missing_regular, font=self.fontRegular, return_scales=True ) - self.initSizes(includeBold=includeBold) + self.symbolsRegular.update(symbolsRegular) + self.scalesRegular.update(scalesRegular) + self._updateSizesForTexts(missing_regular, bold=False) def createSymbols(self, annotTexts, includeBold=True): if includeBold: @@ -281,12 +333,8 @@ def initSizes(self, includeBold=True): includeBold = False if includeBold: - self.sizesBold = plot.get_symbol_sizes( - self.scalesBold, self.symbolsBold, self.fontSize - ) - self.sizesRegular = plot.get_symbol_sizes( - self.scalesRegular, self.symbolsRegular, self.fontSize - ) + self._rebuildSizes(bold=True) + self._rebuildSizes(bold=False) def setColors(self, colors): self._colors = colors.copy() @@ -325,7 +373,7 @@ def getObjTextAnnotSymbol(self, text, bold=False, initSizes=True): symbols[text] = symbol 
scales[text] = scale if initSizes: - self.initSizes() + self._updateSizesForTexts([text], bold=bold) return symbol def grayOutAnnotations(self, IDsToSkip=None): @@ -346,7 +394,7 @@ def grayOutAnnotations(self, IDsToSkip=None): self.setBrush(brushes) self.setPen(pens) - def highlightObject(self, obj): + def highlightObject(self, obj, rp=None, getObjCentroidFunc=None): ID = obj.label objIdx = None for idx, objData in enumerate(self.data): @@ -357,7 +405,14 @@ def highlightObject(self, obj): objOpts = { 'text': str(ID), 'bold': True, 'color_name': 'new_object' } - yc, xc = obj.centroid[-2:] + if rp is not None: + centroid = rp.get_centroid(obj.label) + else: + centroid = obj.centroid + if getObjCentroidFunc is not None: + yc, xc = getObjCentroidFunc(centroid) + else: + yc, xc = centroid[-2:] pos = (int(xc), int(yc)) self.addObjAnnot(pos, draw=True, **objOpts) return @@ -504,13 +559,20 @@ def removeFromPlotItem(self, ax): if hasattr(self.item, 'highlighterItem'): ax.removeItem(self.item.highlighterItem) - def addObjAnnotation(self, obj, color_name, text, bold): + def addObjAnnotation(self, obj, color_name, text, bold, rp=None, getObjCentroidFunc=None): objOpts = { 'text': text, 'bold': bold, 'color_name': color_name, } - yc, xc = obj.centroid[-2:] + if rp is not None: + centroid = rp.get_centroid(obj.label) + else: + centroid = obj.centroid + if getObjCentroidFunc is not None: + yc, xc = getObjCentroidFunc(centroid) + else: + yc, xc = centroid[-2:] pos = (int(xc), int(yc)) objData = self.item.addObjAnnot(pos, draw=True, **objOpts) self.item.appendData(objData, objOpts['text']) @@ -563,7 +625,8 @@ def setAnnotations(self, **kwargs): isGenNumTreeAnnotation, posData.frame_i ) - yc, xc = getObjCentroidFunc(obj.centroid) + centroid = posData.rp.get_centroid(obj.label) + yc, xc = getObjCentroidFunc(centroid) try: rp_zslice = posData.zSlicesRp[currentZ] obj_2d = rp_zslice[obj.label] @@ -598,7 +661,8 @@ def setAnnotations(self, **kwargs): 'color_name': 'tracked_lost_object', 
self.item.highlightObject(obj, rp=rp, getObjCentroidFunc=getObjCentroidFunc)
import regionprops from ._types import ( ChannelsDict @@ -1273,8 +1275,12 @@ def cca_df_to_acdc_df(cca_df, rp, acdc_df=None): IDs.append(obj.label) is_cell_dead_li.append(0) is_cell_excluded_li.append(0) - xx_centroid.append(int(obj.centroid[1])) - yy_centroid.append(int(obj.centroid[0])) + if isinstance(rp, regionprops.acdcRegionprops): + centroid = rp.get_centroid(obj.label, exact=True) + else: + centroid = obj.centroid + xx_centroid.append(int(centroid[1])) + yy_centroid.append(int(centroid[0])) acdc_df = pd.DataFrame({ 'Cell_ID': IDs, 'is_cell_dead': is_cell_dead_li, @@ -3194,53 +3200,119 @@ def insert_missing_objects( return segm_dst -def process_lab(task): - i, lab = task - # Assuming this function processes each lab independently - data_dict = {} - rp = skimage.measure.regionprops(lab) - IDs = [obj.label for obj in rp] - data_dict['IDs'] = IDs - data_dict['regionprops'] = rp - data_dict['IDs_idxs'] = {ID: idx for idx, ID in enumerate(IDs)} +### out of date +# def process_lab(task): +# i, lab = task +# # Assuming this function processes each lab independently +# data_dict = {} +# rp = skimage.measure.regionprops(lab) +# IDs = [obj.label for obj in rp] +# data_dict['IDs'] = IDs +# data_dict['regionprops'] = rp +# data_dict['IDs_idxs'] = {ID: idx for idx, ID in enumerate(IDs)} + +# return i, data_dict, IDs # Return index, data_dict, and IDs + +# def parallel_count_objects(posData, logger_func): +# benchmark = True +# #futile attempt to use multiprocessing to speed things up +# logger_func('Counting total number of segmented objects...') + +# allIDs = set() +# seg_data = posData.segm_data + +# # Initialize empty data dictionary to avoid recalculating each time +# tasks = [(i, lab) for i, lab in enumerate(seg_data)] + +# if benchmark: +# t0 = time.perf_counter() +# # Process in batches to optimize memory usage and control parallelism +# with ThreadPoolExecutor() as executor: +# futures = [executor.submit(process_lab, task) for task in tasks] + +# # Process 
verify that both files exist and are within the allowed time proximity
posData.segm_npz_path, csv_filepath, max_seconds=120, logger_func=logger_func + ) + if not success: + return None + + return csv_filepath - if benchmark: - t0 = time.perf_counter() - # Process in batches to optimize memory usage and control parallelism - with ThreadPoolExecutor() as executor: - futures = [executor.submit(process_lab, task) for task in tasks] +def verify_add_data_segm_proximity(posData: load.loadData, logger_func=print): + segm_path = posData.segm_npz_path + segm_filename = os.path.basename(segm_path).replace('.npz', '') + add_data_folder = os.path.join(posData.images_path, segm_filename) + + centroids_path = os.path.join(add_data_folder, 'centroids.pkl') + # IDs_path = os.path.join(add_data_folder, 'IDs.pkl') + centroids_IDs_exact_path = os.path.join(add_data_folder, 'centroids_IDs_exact.pkl') + # ID_to_idx_path = os.path.join(add_data_folder, 'ID_to_idx.pkl') + + ok = [True] * 2 + for idx, file in enumerate([centroids_path, centroids_IDs_exact_path]): + success = check_file_time_proximity( + segm_path, file, max_seconds=120, logger_func=logger_func + ) + if not success: + ok[idx] = False + + return { + 'centroids': centroids_path if ok[0] else None, + # 'IDs': IDs_path if ok[1] else None, + 'centroids_IDs_exact': centroids_IDs_exact_path if ok[1] else None, + # 'ID_to_idx': ID_to_idx_path if ok[3] else None, + } - # Process results as they are completed - for future in tqdm(as_completed(futures), total=len(futures), ncols=100): - i, data_dict, IDs = future.result() - posData.allData_li[i] = myutils.get_empty_stored_data_dict() # or directly assign if it's mutable - posData.allData_li[i]['IDs'] = data_dict['IDs'] - posData.allData_li[i]['regionprops'] = data_dict['regionprops'] - posData.allData_li[i]['IDs_idxs'] = data_dict['IDs_idxs'] - allIDs.update(IDs) - if benchmark: - t1 = time.perf_counter() - logger_func(f'Counting objects took {(t1 - t0)*1000:.2f} ms') - - return allIDs, posData - -def count_objects(posData, logger_func): - benchmark = 
seems to be the rp creation (not even, for example, getting the IDs or centroids)
f: + centroids_IDs_exact_loaded = pickle.load(f) + # if add_data_dict['ID_to_idx'] is not None: + # with open(add_data_dict['ID_to_idx'], 'rb') as f: + # ID_to_idx_loaded = pickle.load(f) + # other ids: pbar = tqdm(total=len(segm_data), ncols=100) - if benchmark: - t0 = time.perf_counter() for i, lab in enumerate(segm_data): + # if i in frames_in_csv: + # continue # skip frames already processed with csv + # _centroids_loaded = centroids_loaded[i] if centroids_loaded is not None and i in centroids_loaded else None + # # _IDs_loaded = IDs_loaded[i] if IDs_loaded is not None and i in IDs_loaded else None + # _centroids_IDs_exact_loaded = centroids_IDs_exact_loaded[i] if centroids_IDs_exact_loaded is not None and i in centroids_IDs_exact_loaded else None + # _ID_to_idx_loaded = ID_to_idx_loaded[i] if ID_to_idx_loaded is not None and i in ID_to_idx_loaded else None posData.allData_li[i] = myutils.get_empty_stored_data_dict() - rp = skimage.measure.regionprops(lab) - IDs = [obj.label for obj in rp] - posData.allData_li[i]['IDs'] = IDs + rp = regionprops.acdcRegionprops(lab, + # centroids_loaded=_centroids_loaded, + # IDs_loaded=_IDs_loaded, + # centroids_IDs_exact_loaded=_centroids_IDs_exact_loaded, + # ID_to_idx_loaded=_ID_to_idx_loaded + ) + IDs = rp.IDs_set posData.allData_li[i]['regionprops'] = rp - posData.allData_li[i]['IDs_idxs'] = { # IDs_idxs[obj.label] = idx - ID: idx for idx, ID in enumerate(IDs) - } allIDs.update(IDs) pbar.update() pbar.close() - if benchmark: - t1 = time.perf_counter() - logger_func(f'Counting objects took {(t1 - t0)*1000:.2f} ms') return allIDs, posData def fix_sparse_directML(verbose=True): diff --git a/cellacdc/debugutils.py b/cellacdc/debugutils.py index b55d0eef..9f2ab330 100644 --- a/cellacdc/debugutils.py +++ b/cellacdc/debugutils.py @@ -1,9 +1,101 @@ import inspect, os, datetime, sys, traceback +import atexit +import linecache +from collections import defaultdict from . 
import cellacdc_path, myutils import gc import psutil +import time +import functools + +_LINE_BENCHMARK_TRACE_LIMIT = 10000 + +_LINE_BENCHMARK_STATS = defaultdict( + lambda: { + 'count': 0, + 'traced_count': 0, + 'untracked_count': 0, + 'total_time': 0.0, + 'min_time': float('inf'), + 'max_time': 0.0, + 'filename': None, + 'line_stats': defaultdict( + lambda: { + 'count': 0, + 'total_time': 0.0, + 'min_time': float('inf'), + 'max_time': 0.0, + } + ), + } +) + +def _get_benchmark_line_snippet(filename, lineno, max_chars=30): + if lineno == 'return': + return '' + if not filename: + return '' + + line = linecache.getline(filename, lineno).strip() + if not line: + return '' + + if len(line) <= max_chars: + # fill up to max_chars for better alignment + line = line.ljust(max_chars) + return line + return f'{line[:max_chars-3]}...' + +def _print_line_benchmark_session_stats(): + if not _LINE_BENCHMARK_STATS: + return + + print('\nLine benchmark session summary:') + for func_name, stats in sorted(_LINE_BENCHMARK_STATS.items()): + total_count = stats['count'] + traced_count = stats['traced_count'] + untracked_count = stats['untracked_count'] + if total_count == 0: + continue + + if traced_count: + mean_time = stats['total_time'] / traced_count + print( + f'{func_name}: n={total_count} | ' + f'traced={traced_count} | ' + f'untracked={untracked_count} | ' + f'mean={mean_time*1000:.3f} ms | ' + f'min={stats["min_time"]*1000:.3f} ms | ' + f'max={stats["max_time"]*1000:.3f} ms | ' + f'total={stats["total_time"]*1000:.3f} ms' + ) + else: + print( + f'{func_name}: n={total_count} | ' + f'traced=0 | ' + f'untracked={untracked_count}' + ) + + line_stats = stats['line_stats'] + top_lines = sorted( + line_stats.items(), + key=lambda item: item[1]['total_time'], + reverse=True + )[:10] + filename = stats['filename'] + for (start_line, end_line), line_stat in top_lines: + line_mean = line_stat['total_time'] / line_stat['count'] + line_snippet = _get_benchmark_line_snippet(filename, 
start_line) + print( + f' {line_snippet:<30} {start_line} -> {end_line}: ' + f'n={line_stat["count"]} | ' + f'mean={line_mean*1000:.3f} ms | ' + f'total={line_stat["total_time"]*1000:.3f} ms' + ) + +atexit.register(_print_line_benchmark_session_stats) def showRefGraph(object_str:str, debug:bool=True): """Save a reference graph of the given object type. @@ -206,3 +298,100 @@ def print_largest_classes(package_prefix="cellacdc", top_n=10, max_instances=100 # Example usage: # print_largest_classes("cellacdc", top_n=10) + +# Return a benchmark checkpoint with caller line information. +def return_timer_and_line(benchmarking=True): + if not benchmarking: + return None + timestamp = time.perf_counter() + line = inspect.currentframe().f_back.f_lineno # is super fast! + return (timestamp, line) + +def print_benchmarks(timers, benchmarking=True): + if not benchmarking: + return + checkpoints = [timer for timer in timers if timer is not None] + if len(checkpoints) < 2: + return + + print("Benchmarks:") + for (start_time, start_line), (end_time, end_line) in zip( + checkpoints, checkpoints[1:] + ): + duration = end_time - start_time + print( + f"Line {start_line} -> {end_line}: " + f"{duration:.6f} seconds" + ) + + total_duration = checkpoints[-1][0] - checkpoints[0][0] + print(f"Total: {total_duration:.6f} seconds") + +def line_benchmark(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + stats_key = f'{func.__module__}.{func.__qualname__}' + stats = _LINE_BENCHMARK_STATS[stats_key] + stats['count'] += 1 + + if stats['traced_count'] >= _LINE_BENCHMARK_TRACE_LIMIT: + stats['untracked_count'] += 1 + return func(*args, **kwargs) + + target_code = func.__code__ + filename = target_code.co_filename + checkpoints = [] + last_time = None + last_line = None + + def tracer(frame, event, arg): + nonlocal last_time, last_line + + if frame.f_code is not target_code: + return tracer + + now = time.perf_counter() + + if event == "call": + last_time = now + last_line = 
frame.f_lineno + return tracer + + if event == "line": + if last_time is not None and last_line is not None: + checkpoints.append((last_line, frame.f_lineno, now - last_time)) + last_time = now + last_line = frame.f_lineno + return tracer + + if event == "return": + if last_time is not None and last_line is not None: + checkpoints.append((last_line, "return", now - last_time)) + return tracer + + return tracer + + old_trace = sys.gettrace() + sys.settrace(tracer) + try: + result = func(*args, **kwargs) + finally: + sys.settrace(old_trace) + + total = sum(dt for _, _, dt in checkpoints) + stats['traced_count'] += 1 + stats['total_time'] += total + stats['min_time'] = min(stats['min_time'], total) + stats['max_time'] = max(stats['max_time'], total) + stats['filename'] = filename + + for start_line, end_line, dt in checkpoints: + line_stat = stats['line_stats'][(start_line, end_line)] + line_stat['count'] += 1 + line_stat['total_time'] += dt + line_stat['min_time'] = min(line_stat['min_time'], dt) + line_stat['max_time'] = max(line_stat['max_time'], dt) + + return result + + return wrapper \ No newline at end of file diff --git a/cellacdc/gui.py b/cellacdc/gui.py index 619cd2d0..942abb96 100755 --- a/cellacdc/gui.py +++ b/cellacdc/gui.py @@ -95,6 +95,7 @@ from .trackers.CellACDC_normal_division.CellACDC_normal_division_tracker import ( normal_division_lineage_tree)#, reorg_sister_cells_for_export) from . import debugutils +from . import regionprops from .plot import imshow from . import gui_utils @@ -114,6 +115,11 @@ GREEN_HEX = _palettes.green() +RP_OPT_NUM_CELLS_MIN = 0 # th for trying to do local updates to regionprops, rp becomes slow for high num of cells +RP_OPT_PERC_CUTOUT_MAX = 0.1 # th for trying to do local updates to regionprops, + # if region which we have to update is too large too + # many cells are probably inside and its not worth + # local updating (since we actually need to call RP twice!) 
custom_annot_path = os.path.join(settings_folderpath, 'custom_annotations.json') shortcut_filepath = os.path.join(settings_folderpath, 'shortcuts.ini') @@ -240,7 +246,7 @@ def __init__( self.original_df_lin_tree = None self.original_df_lin_tree_i = None - def setTooltips(self): #laoding tooltips for GUI from .\Cell_ACDC\docs\source\tooltips.rst + def setTooltips(self): #loading tooltips for GUI from .\Cell_ACDC\docs\source\tooltips.rst tooltips = load.get_tooltips_from_docs() for key, tooltip in tooltips.items(): @@ -2762,7 +2768,7 @@ def gui_createActions(self): 'Track current frame with real-time tracker...', self ) self.repeatTrackingMenuAction.setDisabled(True) - self.repeatTrackingMenuAction.setShortcut('Shift+T') + self.repeatTrackingMenuAction.setShortcut('Ctrl+T') self.repeatTrackingVideoAction = QAction( 'Select a tracker and track multiple frames...', self @@ -4683,7 +4689,7 @@ def _gui_createGraphicsItems(self): posData = self.data[self.pos_i] - allIDs, posData = core.count_objects(posData, self.logger.info) + allIDs, posData = core.count_objects_and_init_rps(posData, self.logger.info) self.highLowResAction.setChecked(True) numItems = len(allIDs) @@ -5125,7 +5131,9 @@ def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent): return else: ID = sepID_prompt.EntryID - y, x = posData.rp[posData.IDs_idxs[ID]].centroid[-2:] + + centroid = posData.rp.get_centroid(ID) + y, x = self.getObjCentroid(centroid) xdata, ydata = int(x), int(y) # Store undo state before modifying stuff @@ -5180,7 +5188,9 @@ def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent): self.storeManualSeparateDrawMode(manualSep.drawMode) # Update data (rp, etc) - self.update_rp() + bbox = self.update_rp_get_bbox(use_bbox=True, specific_IDs=ID) # use old ID to get bbox + specific_IDs = list(splittedIDs) + [ID] + self.update_rp(specific_IDs=specific_IDs, preloaded_bbox=bbox) # Repeat tracking self.trackSubsetIDs(splittedIDs) @@ -5223,13 +5233,15 @@ def 
posData.lab[self.getObjSlice(obj.slice)][localHull] = ID
# track of IDs_overwritten
# unclear when this can happen — investigate
if not posData.IDs: # empty segmentation mask @@ -6041,8 +6062,8 @@ def updatePropsWidget(self, ID, fromHover=False): propsQGBox = self.guiTabControl.propsQGBox - obj_idx = posData.IDs_idxs.get(ID) - if obj_idx is None: + obj = posData.rp.get_obj_from_ID(ID) + if obj is None: s = f'Object ID {int(ID):d} does not exist' propsQGBox.notExistingIDLabel.setText(s) return @@ -6058,8 +6079,6 @@ def updatePropsWidget(self, ID, fromHover=False): if doHighlight: self.highlightSearchedID(ID) - obj = posData.rp[obj_idx] - if self.isSegm3D: if self.zProjComboBox.currentText() == 'single z-slice': local_z = self.z_lab() - obj.bbox[0] @@ -6345,9 +6364,8 @@ def drawTempMothBudLine(self, event, posData): if ID == 0: self.BudMothTempLine.setData([x1, x2], [y1, y2]) else: - obj_idx = posData.IDs_idxs[ID] - obj = posData.rp[obj_idx] - y2, x2 = self.getObjCentroid(obj.centroid) + centroid = posData.rp.get_centroid(ID) + y2, x2 = self.getObjCentroid(centroid) self.BudMothTempLine.setData([x1, x2], [y1, y2]) def drawTempMergeObjsLine(self, event, posData, modifiers): @@ -6360,9 +6378,8 @@ def drawTempMergeObjsLine(self, event, posData, modifiers): y1, x1 = self.clickObjYc, self.clickObjXc ID = self.get_2Dlab(posData.lab)[ydata, xdata] if ID != 0: - obj_idx = posData.IDs_idxs[ID] - obj = posData.rp[obj_idx] - y2, x2 = self.getObjCentroid(obj.centroid) + centroid = posData.rp.get_centroid(ID) + y2, x2 = self.getObjCentroid(centroid) if modifier and ID > 0: self.mergeObjsTempLine.addPoint(x2, y2) @@ -6525,7 +6542,7 @@ def gui_setCursor(self, modifiers, event): def warnAddingPointWithExistingId(self, point_id, table_endname=''): posData = self.data[self.pos_i] - if not point_id in posData.IDs_idxs: + if not point_id in posData.IDs: return True msg = widgets.myMessageBox(wrapText=False) @@ -6782,7 +6799,7 @@ def gui_mouseReleaseEventImg2(self, event): self.isMovingLabel = False # Update data (rp, etc) - self.update_rp() + self.update_rp() # IDK can I do optimization here? 
# Repeat tracking self.tracking(enforce=True, assign_unique_new_IDs=False) @@ -6816,31 +6833,32 @@ def gui_mouseReleaseEventImg2(self, event): return else: ID = mergeID_prompt.EntryID - obj_idx = posData.IDs_idxs[ID] - obj = posData.rp[obj_idx] - y2, x2 = self.getObjCentroid(obj.centroid) - self.mergeObjsTempLine.addPoint(x2, y2) + centroid = posData.rp.get_centroid(ID) + ydata, xdata = self.getObjCentroid(centroid) + ydata, xdata = int(ydata), int(xdata) xx, yy = self.mergeObjsTempLine.getData() IDs_to_merge = lab2D[yy.astype(int), xx.astype(int)] for ID in IDs_to_merge: if ID == 0: continue - posData.lab[posData.lab==ID] = self.firstID + obj = posData.rp.get_obj_from_ID(ID) + + posData.lab[obj.slice][obj.image] = self.firstID self.mergeObjsTempLine.setData([], []) self.clickObjYc, self.clickObjXc = None, None - - # Update data (rp, etc) - self.update_rp() - + + bbox = self.update_rp_get_bbox(specific_IDs=IDs_to_merge,use_bbox=True) # use old IDs to get bbox + specific_IDs = list(IDs_to_merge) + [self.firstID] + self.update_rp(specific_IDs=specific_IDs,preloaded_bbox=bbox) # update with new IDs ask_back_prop = True if posData.frame_i == 0: ask_back_prop = False prev_IDs = [] else: - prev_IDs = posData.allData_li[posData.frame_i-1]['IDs'] + prev_IDs = posData.allData_li[posData.frame_i-1]['regionprops'].IDs if all(ID not in prev_IDs for ID in IDs_to_merge): ask_back_prop = False @@ -6909,8 +6927,9 @@ def gui_mouseReleaseEventImg1(self, event): if self.isRightClickDragImg1 and self.curvToolButton.isChecked(): self.isRightClickDragImg1 = False try: - self.curvToolSplineToObj(isRightClick=True) - self.update_rp() + mask, returnID = self.curvToolSplineToObj(isRightClick=True) + if mask is not None: + self.update_rp() # how can I optimize this? 
I think not possible tbh self.trackManuallyAddedObject(posData.brushID, True) if self.isSnapshot: self.fixCcaDfAfterEdit('Add new ID with curvature tool') @@ -6931,7 +6950,7 @@ def gui_mouseReleaseEventImg1(self, event): self.clearTempBrushImage() # Update data (rp, etc) - self.update_rp() + self.update_rp(use_curr_view=True) # only visible stuff can be deleted doUpdateImages = self.checkWarnDeletedIDwithEraser() @@ -6956,7 +6975,8 @@ def gui_mouseReleaseEventImg1(self, event): posData.lab[self.flood_mask] = posData.brushID # Update data (rp, etc) - self.update_rp() + # only visible stuff can be added, plus doesnt draw over eixisting + self.update_rp(use_curr_view=True, specific_IDs=posData.brushID) # Repeat tracking self.trackManuallyAddedObject(posData.brushID, self.isNewID) @@ -7037,7 +7057,7 @@ def gui_mouseReleaseEventImg1(self, event): self.isMovingLabel = False # Update data (rp, etc) - self.update_rp() + self.update_rp(use_curr_view=True) # only visible stuff can be moved # Repeat tracking self.tracking(enforce=True, assign_unique_new_IDs=False) @@ -7072,9 +7092,9 @@ def gui_mouseReleaseEventImg1(self, event): return else: ID = mothID_prompt.EntryID - obj_idx = posData.IDs.index(ID) - y, x = posData.rp[obj_idx].centroid - xdata, ydata = int(x), int(y) + centroid = posData.rp.get_centroid(ID) + ydata, xdata = self.getObjCentroid(centroid) + ydata, xdata = int(ydata), int(xdata) if self.isSnapshot: # Store undo state before modifying stuff @@ -7105,11 +7125,9 @@ def gui_mouseReleaseEventImg1(self, event): # on a mother budID = self.get_2Dlab(posData.lab)[self.yClickBud, self.xClickBud] new_mothID = self.get_2Dlab(posData.lab)[ydata, xdata] - bud_obj_idx = posData.IDs.index(budID) - new_moth_obj_idx = posData.IDs.index(new_mothID) - rp_budID = posData.rp[bud_obj_idx] - rp_new_mothID = posData.rp[new_moth_obj_idx] - if rp_budID.area >= rp_new_mothID.area: + bug_obj = posData.rp.get_obj_from_ID(budID) + new_mother_obj = posData.rp.get_obj_from_ID(new_mothID) + 
if bug_obj.area >= new_mother_obj.area: self.assignBudMothButton.setChecked(False) msg = widgets.myMessageBox() txt = ( @@ -7847,7 +7865,7 @@ def gui_mousePressEventImg1(self, event: QMouseEvent): elif right_click and copyContourON: hoverLostID = self.ax1_lostObjScatterItem.hoverLostID self.copyLostObjectContour(hoverLostID) - self.update_rp() + self.update_rp(use_curr_view=True) # only visible self.updateAllImages() self.store_data() @@ -7897,7 +7915,7 @@ def gui_mousePressEventImg1(self, event: QMouseEvent): if closeSpline: self.splineHoverON = False self.curvToolSplineToObj() - self.update_rp() + self.update_rp() # dont think I can optimize this self.trackManuallyAddedObject(posData.brushID, True) if self.isSnapshot: self.fixCcaDfAfterEdit('Add new ID with curvature tool') @@ -7973,21 +7991,27 @@ def gui_mousePressEventImg1(self, event: QMouseEvent): posData = self.data[self.pos_i] currentIDs = posData.IDs.copy() if manualTrackID in currentIDs: - tempID = max(currentIDs) + 1 - posData.lab[posData.lab == clickedID] = tempID - posData.lab[posData.lab == manualTrackID] = clickedID - posData.lab[posData.lab == tempID] = manualTrackID + clicked_obj = posData.rp.get_obj_from_ID(clickedID) + manual_track_obj = posData.rp.get_obj_from_ID(manualTrackID) + posData.lab[clicked_obj.slice][clicked_obj.image] = manualTrackID + posData.lab[manual_track_obj.slice][manual_track_obj.image] = clickedID self.manualTrackingToolbar.showWarning( f'The ID {manualTrackID} already exists --> ' f'ID {manualTrackID} has been swapped with {clickedID}' ) + assignments = {clickedID: manualTrackID, + manualTrackID: clickedID} else: - posData.lab[posData.lab == clickedID] = manualTrackID + clicked_obj = posData.rp.get_obj_from_ID(clickedID) + posData.lab[clicked_obj.slice][clicked_obj.image] = manualTrackID self.manualTrackingToolbar.showInfo( f'ID {clickedID} changed to {manualTrackID}.' 
) + assignments = {clickedID: manualTrackID} - self.update_rp() + # only ID change, so use assignments + # not 3D ready yet? Otherwise I must set assignments to None + self.update_rp(assignments=assignments) self.updateAllImages() elif right_click and manualBackgroundON: @@ -8051,9 +8075,9 @@ def gui_mousePressEventImg1(self, event: QMouseEvent): return else: ID = divID_prompt.EntryID - obj_idx = posData.IDs.index(ID) - y, x = posData.rp[obj_idx].centroid - xdata, ydata = int(x), int(y) + centroid = posData.rp.get_centroid(ID) + ydata, xdata = self.getObjCentroid(centroid) + ydata, xdata = int(ydata), int(xdata) if not self.isSnapshot: # Store undo state before modifying stuff @@ -8096,8 +8120,9 @@ def gui_mousePressEventImg1(self, event: QMouseEvent): ID = budID_prompt.EntryID obj_idx = posData.IDs.index(ID) - y, x = posData.rp[obj_idx].centroid - xdata, ydata = int(x), int(y) + centroid = posData.rp.get_centroid(ID) + ydata, xdata = self.getObjCentroid(centroid) + ydata, xdata = int(ydata), int(xdata) relationship = posData.cca_df.at[ID, 'relationship'] is_history_known = posData.cca_df.at[ID, 'is_history_known'] @@ -8143,9 +8168,9 @@ def gui_mousePressEventImg1(self, event: QMouseEvent): return else: ID = unknownID_prompt.EntryID - obj_idx = posData.IDs.index(ID) - y, x = posData.rp[obj_idx].centroid - xdata, ydata = int(x), int(y) + centroid = posData.rp.get_centroid(ID) + ydata, xdata = self.getObjCentroid(centroid) + ydata, xdata = int(ydata), int(xdata) self.annotateIsHistoryKnown(ID) if not self.setIsHistoryKnownButton.findChild(QAction).isChecked(): @@ -8172,9 +8197,9 @@ def gui_mousePressEventImg1(self, event: QMouseEvent): return else: ID = clickedBkgrDialog.EntryID - obj_idx = posData.IDs.index(ID) - y, x = posData.rp[obj_idx].centroid - xdata, ydata = int(x), int(y) + centroid = posData.rp.get_centroid(ID) + ydata, xdata = self.getObjCentroid(centroid) + ydata, xdata = int(ydata), int(xdata) button = self.doCustomAnnotation(ID) if button is None: @@ 
-8615,10 +8640,17 @@ def onSigStoreData( autosave=autosave, store_cca_df_copy=store_cca_df_copy) waitcond.wakeAll() - def onSigUpdateRP(self, waitcond, draw=True, debug=False, update_IDs=True, - wl_update=True, wl_track_og_curr=False): - self.update_rp(draw=draw, debug=debug, update_IDs=update_IDs, - wl_update=wl_update, wl_track_og_curr=wl_track_og_curr) + def onSigUpdateRP(self, waitcond, + draw=True, debug=False, # og stuff + assignments=None, deletionIDs=None, # very quick upates, rp labels are changed but rest is same + specific_IDs=None, use_curr_view=False, use_bbox=False, preloaded_bbox=None, # for local updates to PR + wl_update=True, wl_track_og_curr=False,wl_update_lab=False, # wl stuff + ): + self.update_rp(draw=True, debug=False, # og stuff + assignments=None, deletionIDs=None, # very quick upates, rp labels are changed but rest is same + specific_IDs=None, use_curr_view=False, use_bbox=False, preloaded_bbox=None, # for local updates to PR + wl_update=True, wl_track_og_curr=False,wl_update_lab=False, # wl stuff + ) waitcond.wakeAll() def onSigGetData(self, waitcond, debug=False): @@ -8627,7 +8659,7 @@ def onSigGetData(self, waitcond, debug=False): def SegForLostIDsWorkerFinished(self): self.updateAllImages() - self.update_rp() + self.update_rp() # will update when updating segoforlostIDs self.store_data(autosave=True) self.setFrameNavigationDisabled(disable=False, why='Segmentation for lost IDs') @@ -8977,14 +9009,13 @@ def searchIDworkerCallback(self, posData, searchedID): for frame_i in range(len(posData.segm_data)): if frame_i >= len(posData.allData_li): break - lab = posData.allData_li[frame_i]['labels'] - if lab is None: - rp = skimage.measure.regionprops(posData.segm_data[frame_i]) - IDs = set([obj.label for obj in rp]) - else: - IDs = posData.allData_li[frame_i]['IDs'] - if searchedID in IDs: + rp = posData.allData_li[frame_i]['regionprops'] + if rp is None: + lab = posData.segm_data[frame_i] + rp = regionprops.acdcRegionprops(lab) + 
posData.allData_li[frame_i]['regionprops'] = rp + if searchedID in rp.IDs: frame_i_found = frame_i break @@ -9001,8 +9032,7 @@ def warnIDnotFound(self, searchedID): def goToObjectID(self, ID): posData = self.data[self.pos_i] - objIdx = posData.IDs_idxs[ID] - obj = posData.rp[objIdx] + obj = posData.rp.get_obj_from_ID(ID) self.goToZsliceSearchedID(obj) self.highlightSearchedID(ID) @@ -9013,8 +9043,7 @@ def goToLostObjectID(self, lostID, color=(255, 165, 0, 255)): posData = self.data[self.pos_i] frame_i = posData.frame_i prev_rp = posData.allData_li[frame_i-1]['regionprops'] - prev_IDs_idxs = posData.allData_li[frame_i-1]['IDs_idxs'] - obj = prev_rp[prev_IDs_idxs[lostID]] + obj = prev_rp.get_obj_from_ID(lostID) self.goToZsliceSearchedID(obj) imageItem = self.getLostObjImageItem(0) @@ -9037,8 +9066,7 @@ def goToAcceptedLostObjectID(self, acceptedLostID): posData = self.data[self.pos_i] frame_i = posData.frame_i prev_rp = posData.allData_li[frame_i-1]['regionprops'] - prev_IDs_idxs = posData.allData_li[frame_i-1]['IDs_idxs'] - obj = prev_rp[prev_IDs_idxs[acceptedLostID]] + obj = prev_rp.get_obj_from_ID(acceptedLostID) self.goToZsliceSearchedID(obj) self.updateLostTrackedContoursImage(tracked_lost_IDs=[acceptedLostID]) @@ -9641,43 +9669,53 @@ def applyEditID( lab = posData.lab # Store undo state before modifying stuff + # no risk of merging IDs if we are working with rp and dont updaet in the middle... self.storeUndoRedoStates(UndoFutFrames) - maxID = max(posData.IDs, default=0) - for old_ID, new_ID in oldIDnewIDMapper: + # could this be chained??? 
If yes we have to "simplify" to least swops to since we keep RP stale + # oldIDnewIDMapper + assignments = {} + for old_ID, new_ID in oldIDnewIDMapper: if new_ID in currentIDs and not self.editIDmergeIDs: - tempID = maxID + 1 - lab[lab == old_ID] = maxID + 1 - lab[lab == new_ID] = old_ID - lab[lab == tempID] = new_ID - maxID += 1 - - old_ID_idx = currentIDs.index(old_ID) - new_ID_idx = currentIDs.index(new_ID) - - # Append information for replicating the edit in tracking - # List of tuples (y, x, replacing ID) - objo = posData.rp[old_ID_idx] - yo, xo = self.getObjCentroid(objo.centroid) - objn = posData.rp[new_ID_idx] - yn, xn = self.getObjCentroid(objn.centroid) - if not math.isnan(yo) and not math.isnan(yn): + objo = posData.rp.get_obj_from_ID(old_ID) + objn = posData.rp.get_obj_from_ID(new_ID) + + # Relabel old_ID to new ID, save since rp is "stale" + slc_o = objo.slice + mask_o = objo.image + lab[slc_o][mask_o] = new_ID + + # Relabel new_ID to old_ID + slc_n = objn.slice + mask_n = objn.image + lab[slc_n][mask_n] = old_ID + + + # ask Francesco what this does ¯\_(ツ)_/¯ + # we have to switch as we switched IDs and RP is stale + objn_centroid = posData.rp.get_centroid(old_ID, exact=True) # !!!This is actually the original RP still!!! 
+ # objo_centroid = posData.rp.get_centroid(new_ID, exact=True) + # yo, xo = self.getObjCentroid(objo_centroid) + yn, xn = self.getObjCentroid(objn_centroid) + # if not (math.isnan(yo) or math.isnan(yn)): + if not math.isnan(yn): yn, xn = int(yn), int(xn) posData.editID_info.append((yn, xn, new_ID)) yo, xo = int(clicked_y), int(clicked_x) posData.editID_info.append((yo, xo, old_ID)) + assignments[new_ID] = old_ID + assignments[old_ID] = new_ID else: - lab[lab == old_ID] = new_ID - if new_ID > maxID: - maxID = new_ID - old_ID_idx = posData.IDs.index(old_ID) - - # Append information for replicating the edit in tracking - # List of tuples (y, x, replacing ID) - obj = posData.rp[old_ID_idx] - y, x = self.getObjCentroid(obj.centroid) - if not math.isnan(y) and not math.isnan(y): + # Use regionprops for old_ID + obj = posData.rp.get_obj_from_ID(old_ID) + slc = obj.slice + mask = obj.image + lab[slc][mask] = new_ID + centroid = posData.rp.get_centroid(old_ID, exact=True) + y, x = self.getObjCentroid(centroid) + if not math.isnan(y) and not math.isnan(x): y, x = int(y), int(x) posData.editID_info.append((y, x, new_ID)) + assignments[old_ID] = new_ID self.updateAssignedObjsAcdcTrackerSecondStep(new_ID) @@ -9685,7 +9723,7 @@ def applyEditID( self.set_2Dlab(lab) # Update rps - self.update_rp() + self.update_rp(assignments = assignments if (shift and self.isSegm3D) else None) # Since we manually changed an ID we don't want to repeat tracking self.setAllTextAnnotations() @@ -10240,13 +10278,13 @@ def annotateIsHistoryKnown(self, ID): # we set the cca of it to the status it had BEFORE the assignment posData.cca_df.loc[relID] = relID_cca - # Update cell cycle info LabelItems - obj_idx = posData.IDs.index(ID) - rp_ID = posData.rp[obj_idx] + # Update cell cycle info LabelItems what was the function here? 
+ # obj_idx = posData.IDs.index(ID) + # rp_ID = posData.rp[obj_idx] - if relID in posData.IDs: - relObj_idx = posData.IDs.index(relID) - rp_relID = posData.rp[relObj_idx] + # if relID in posData.IDs: + # relObj_idx = posData.IDs.index(relID) + # rp_relID = posData.rp[relObj_idx] self.setAllTextAnnotations() self.drawAllMothBudLines() @@ -10468,10 +10506,10 @@ def undoBudMothAssignment(self, ID): posData.cca_df.at[relID, 'cell_cycle_stage'] = 'G1' posData.cca_df.at[relID, 'relationship'] = 'mother' - obj_idx = posData.IDs.index(ID) - relObj_idx = posData.IDs.index(relID) - rp_ID = posData.rp[obj_idx] - rp_relID = posData.rp[relObj_idx] + # obj_idx = posData.IDs.index(ID) what was the function of this? + # relObj_idx = posData.IDs.index(relID) + # rp_ID = posData.rp[obj_idx] + # rp_relID = posData.rp[relObj_idx] self.store_cca_df() @@ -11534,14 +11572,10 @@ def delBorderObj(self, checked): self.storeUndoRedoStates(False) posData = self.data[self.pos_i] - posData.lab = skimage.segmentation.clear_border( - posData.lab, buffer_size=1 - ) - oldIDs = posData.IDs.copy() - self.update_rp() - removedIDs = [ID for ID in oldIDs if ID not in posData.IDs] + edge_ids = myutils.clear_border(posData.lab, return_edge_ids=True) # modifies inplace + self.update_rp(deletionIDs=edge_ids) if posData.cca_df is not None: - posData.cca_df = posData.cca_df.drop(index=removedIDs) + posData.cca_df = posData.cca_df.drop(index=edge_ids) self.store_data() self.updateAllImages() @@ -11555,7 +11589,7 @@ def delNewObj(self, checked): if frame_i == 0: return - prev_IDs = posData.allData_li[frame_i-1]['IDs'] + prev_IDs = posData.allData_li[frame_i-1]['regionprops'].IDs curr_IDs = posData.IDs new_IDs = list(set(curr_IDs) - set(prev_IDs)) @@ -11564,7 +11598,7 @@ def delNewObj(self, checked): lab[del_mask] = 0 posData.lab = lab - self.update_rp() + self.update_rp(deletionIDs=new_IDs) if posData.cca_df is not None: posData.cca_df = posData.cca_df.drop(index=new_IDs) @@ -11586,13 +11620,16 @@ def 
brushReleased(self): self.fillHolesID(posData.brushID, sender='brush') # Update data (rp, etc) - self.update_rp(update_IDs=self.isNewID,) + + power_brush = self.isPowerBrush() + # we have to delay for a second + self.update_rp(use_curr_view=True, specific_IDs=posData.brushID if not power_brush else None) # Repeat tracking if self.autoIDcheckbox.isChecked(): self.trackManuallyAddedObject(posData.brushID, self.isNewID) - else: - self.update_rp(update_IDs=posData.brushID not in posData.IDs_idxs) + # else: I think not necessary + # self.update_rp(use_curr_view=True) # Update images if self.isNewID: @@ -11764,7 +11801,7 @@ def delROImoving(self, roi): def delROImovingFinished(self, roi: pg.ROI): roi.setPen(color='r') - self.update_rp() + self.update_rp() # get bbox of delROI old and new, run update_rp on both seperately self.updateAllImages() QTimer.singleShot( 300, partial(self.updateDelROIinFutureFrames, roi) @@ -11802,7 +11839,7 @@ def restoreAnnotDelROI(self, roi, enforce=True, draw=True): delROIs_info['delIDsROI'][idx] = delIDs - restoredIDs self.set_2Dlab(lab2D) - self.update_rp() + self.update_rp() # get bbox of delROI old and new, run update_rp on both seperately def restoreDelROIimg1(self, delMaskID, delID, ax=0): if ax == 0: @@ -13261,8 +13298,7 @@ def manualAnnotPast_cb(self, checked): ) self.editIDspinbox.setValue(hoverID) try: - obj_idx = posData.IDs_idxs[hoverID] - obj = posData.rp[obj_idx] + obj = posData.rp.get_obj_from_ID(hoverID) radius = 0.9 * obj.minor_axis_length / 2 # math.sqrt(obj.area/math.pi)*0.9 self.brushSizeSpinbox.setValue(round(radius)) except Exception as err: @@ -13373,17 +13409,16 @@ def copyAllLostObjectsWorkerCallback( posData.frame_i = frame_i self.get_data() - self.tracking(wl_update=False) - self.update_rp() + self.tracking() # we already update rp inside here + # self.update_rp() self.updateLostNewCurrentIDs() self.store_data(mainThread=False, autosave=False) # delROIsIDs = self.getDelRoisIDs() self.lostObjContoursImage[:] = 0 
prev_rp = posData.allData_li[frame_i-1]['regionprops'] - prev_IDs_idxs = posData.allData_li[frame_i-1]['IDs_idxs'] for lostID in posData.lost_IDs: - obj = prev_rp[prev_IDs_idxs[lostID]] + obj = prev_rp.get_obj_from_ID(lostID) self.addLostObjsToImage(obj, lostID, force=True) for lostObj in skimage.measure.regionprops(self.lostObjImage): @@ -13503,7 +13538,7 @@ def copyAllLostObjectsWorkerFinished(self, output): self.blinker.start() self.copyAllLostObjectsWorkerLoop.exit() - self.update_rp() + self.update_rp() # global op and obj added, no opt imo unless difference pic self.updateAllImages() self.store_data() @@ -13719,7 +13754,7 @@ def labelRoiDone(self, roiSegmData, isTimeLapse): frame_i = start_frame_i + i lab = posData.allData_li[frame_i]['labels'] store = True - if lab is None: + if lab is None: # no rp update here? if frame_i >= len(posData.segm_data): lab = np.zeros_like(posData.segm_data[0]) posData.segm_data = np.append( @@ -13735,6 +13770,7 @@ def labelRoiDone(self, roiSegmData, isTimeLapse): if store: posData.frame_i = frame_i posData.allData_li[frame_i]['labels'] = lab.copy() + # no rp update here? 
self.get_data() self.store_data(autosave=False) @@ -13747,7 +13783,7 @@ def labelRoiDone(self, roiSegmData, isTimeLapse): roiLab, self.labelRoiSlice, posData.lab, posData.brushID ) - self.update_rp() + self.update_rp() # get roi and set as bbox # Repeat tracking if self.autoIDcheckbox.isChecked(): @@ -13778,8 +13814,7 @@ def labelRoiDone(self, roiSegmData, isTimeLapse): def restoreHoverObjBrush(self): posData = self.data[self.pos_i] if self.ax1BrushHoverID in posData.IDs: - obj_idx = posData.IDs_idxs[self.ax1BrushHoverID] - obj = posData.rp[obj_idx] + obj = posData.rp.get_obj_from_ID(self.ax1BrushHoverID) if not self.isObjVisible(obj.bbox): return @@ -13871,15 +13906,18 @@ def setAllIDs(self, onlyVisited=False): for frame_i in range(len(posData.segm_data)): if frame_i >= len(posData.allData_li): break + lab = posData.allData_li[frame_i]['labels'] if lab is None and onlyVisited: break - if lab is None: - rp = skimage.measure.regionprops(posData.segm_data[frame_i]) - else: - rp = posData.allData_li[frame_i]['regionprops'] - posData.allIDs.update([obj.label for obj in rp]) + rp = posData.allData_li[frame_i]['regionprops'] + if rp is None: + lab = posData.segm_data[frame_i] + rp = regionprops.acdcRegionprops(lab) + posData.allData_li[frame_i]['regionprops'] = rp + + posData.allIDs.update(rp.IDs) def countObjectsTimelapse(self): if self.countObjsWindow is None: @@ -13932,7 +13970,7 @@ def countObjectsSnapshots(self): ) for pos_i, _posData in enumerate(self.data): - IDs = _posData.allData_li[0]['IDs'] + IDs = _posData.allData_li[0]['regionprops'].IDs if os.path.exists(_posData.acdc_output_csv_path): numObjectsVisitedPosPrevious += len(IDs) if IDs: @@ -14815,10 +14853,10 @@ def keyPressEvent(self, ev): if ev.key() == Qt.Key_Q and self.debug: try: from . import _q_debug - _q_debug.q_debug(self) + _q_debug.q_debug(self, ev) except Exception as err: printl(traceback.format_exc()) - printl('[ERROR]: Error with "_qdebug" module. 
See Traceback above.') + printl('[ERROR]: Error with "_q_debug" module. See Traceback above.') pass if not self.isDataLoaded: @@ -15160,7 +15198,7 @@ def propagateMergeObjsPast(self, IDs_to_merge): posData.frame_i = past_frame_i self.get_data() - IDs = posData.allData_li[past_frame_i]['IDs'] + IDs = posData.allData_li[past_frame_i]['regionprops'].IDs stop_loop = False for ID in IDs_to_merge: if ID not in IDs: @@ -15169,10 +15207,13 @@ def propagateMergeObjsPast(self, IDs_to_merge): if ID == 0: continue - posData.lab[posData.lab==ID] = self.firstID - self.update_rp() - - self.store_data(autosave=False) + obj = posData.rp.get_obj_from_ID(ID) + posData.lab[obj.slice][obj.image] = self.firstID + + preloaded_bbox = self.update_rp_get_bbox(specific_IDs=IDs_to_merge,use_bbox=True) # use old RP to get the correct bbox + specific_IDs = list(IDs_to_merge) + [self.firstID] + self.update_rp(preloaded_bbox=preloaded_bbox, specific_IDs=specific_IDs) + self.store_data(autosave=False) if stop_loop: break @@ -15695,8 +15736,8 @@ def warnTrackerInputNotValid(self, trackerName, warningText): def repeatTracking(self): posData = self.data[self.pos_i] - prev_lab = self.get_2Dlab(posData.lab).copy() - self.tracking(enforce=True, DoManualEdit=False) + tracked_lab, assignments = self.tracking(enforce=True, DoManualEdit=False, return_assignments=True, return_lab=True) + posData.lab = tracked_lab if posData.editID_info: editedIDsInfo = { posData.lab[y,x]:newID @@ -15727,18 +15768,25 @@ def repeatTracking(self): detailsText=editIDul ) if msg.cancel: + self.update_rp(assignments=assignments) # rp now stale as we return img return if msg.clickedButton == keepManualEditButton: - allIDs = [obj.label for obj in posData.rp] + allIDs = posData.rp.IDs lab2D = self.get_2Dlab(posData.lab) - self.manuallyEditTracking(lab2D, allIDs) - self.update_rp() + tracked_lab, assignments = self.manuallyEditTracking(lab2D, assignments) # here not use tracked lab? 
+ self.update_rp(assignments=assignments) # rp now stale as we return img self.setAllTextAnnotations() self.highlightLostNew() # self.checkIDsMultiContour() else: + self.update_rp(assignments=assignments) # rp now stale as we return img posData.editID_info = [] - if np.any(posData.lab != prev_lab): + else: + self.update_rp(assignments=assignments) + + # filter self assignments + assignments = {k: v for k, v in assignments.items() if k != v} + if assignments: if self.isSnapshot: self.fixCcaDfAfterEdit('Repeat tracking') self.updateAllImages() @@ -15810,8 +15858,7 @@ def initManualBackgroundObject(self, ID=None): self.manualBackgroundObjItem.clear() return - ID_idx = posData.IDs_idxs[ID] - self.manualBackgroundObj = posData.rp[ID_idx] + self.manualBackgroundObj = posData.rp.get_obj_from_ID(ID) self.manualBackgroundToolbar.clearInfoText() self.manualBackgroundObj.contour = self.getObjContours( @@ -16909,12 +16956,11 @@ def doCustomAnnotation(self, ID): xx, yy = [], [] for annotID in annotIDs_frame_i: - obj_idx = posData.IDs_idxs[annotID] - obj = posData.rp[obj_idx] + obj = posData.rp.get_obj_from_ID(annotID) acdc_df.at[annotID, state['name']] = 1 if not self.isObjVisible(obj.bbox): continue - y, x = self.getObjCentroid(obj.centroid) + y, x = self.getObjCentroid(posData.rp.get_centroid(annotID, exact=True)) xx.append(x) yy.append(y) @@ -18199,9 +18245,8 @@ def warnLostObjects(self, do_warn=True): posData.accepted_lost_IDs[frame_i].extend(posData.lost_IDs) # This section is adding the lost cells to tracked_lost_centroids... 
TBH I don't know why this wasn't done in the first place prev_rp = posData.allData_li[posData.frame_i-1]['regionprops'] - prev_IDs_idxs = posData.allData_li[posData.frame_i-1]['IDs_idxs'] accepted_lost_centroids = { - tuple(int(val) for val in prev_rp[prev_IDs_idxs[ID]].centroid) + tuple(int(val) for val in prev_rp.get_centroid(ID, exact=True)) for ID in posData.lost_IDs } try: @@ -18281,8 +18326,7 @@ def checkIfFutureFrameManualAnnotPastFrames(self): self.statusBarLabel.setText(f'

{warn_txt}

') return False - - # @exec_time + def next_frame(self, warn=True): proceed = self.checkIfFutureFrameManualAnnotPastFrames() if not proceed: @@ -18811,6 +18855,7 @@ def loadSelectedData(self, user_ch_file_paths, user_ch_name): create_new_segm=self.isNewFile, new_endname=self.newSegmEndName, end_filename_segm=selectedSegmEndName, + load_segm_info_ini=True ) self.selectedSegmEndName = selectedSegmEndName self.labelBoolSegm = posData.labelBoolSegm @@ -20232,7 +20277,7 @@ def curvToolSplineToObj(self, xxA=None, yyA=None, isRightClick=False): xxS, yyS = self.curvPlotItem.getData() if xxS is None: self.setUncheckedAllButtons() - return + return None, None N = len(xxS) self.smoothAutoContWithSpline(n=int(N*0.05)) @@ -20250,6 +20295,7 @@ def curvToolSplineToObj(self, xxA=None, yyA=None, isRightClick=False): newIDMask[self.currentLab2D!=0] = False self.currentLab2D[newIDMask] = curvToolID self.set_2Dlab(self.currentLab2D) + return newIDMask, curvToolID def addFluoChNameContextMenuAction(self, ch_name): posData = self.data[self.pos_i] @@ -20548,37 +20594,67 @@ def getStoredSegmData(self): segm_data.append(lab) return np.array(segm_data) - def trackNewIDtoNewIDsFutureFrame(self, newID, newIDmask): + def trackNewIDtoNewIDsFutureFrame(self, newID, obj, assignments): + # here RP is stale posData = self.data[self.pos_i] try: nextLab = posData.allData_li[posData.frame_i+1]['labels'] except IndexError: # This is last frame --> there are no future frames - return + return None, assignments if nextLab is None: - return + return None, assignments + + if obj is None: + return None, assignments + - newID_lab = np.zeros_like(posData.lab) - newID_lab[newIDmask] = newID - newLab_rp = [posData.rp[posData.IDs_idxs[newID]]] - newLab_IDs = [newID] nextRp = posData.allData_li[posData.frame_i+1]['regionprops'] + nextLab = posData.allData_li[posData.frame_i+1]['labels'] + reverse_assignments = {v:k for k, v in assignments.items()} - tracked_lab = self.trackFrame( - nextLab, nextRp, newID_lab, 
newLab_rp, newLab_IDs, - assign_unique_new_IDs=False + rp = posData.rp + lab = posData.lab + + # make rp remporarliy not stale anymore + rp.update_regionprops_via_assignments(assignments, lab) + tracked_lab, assignments_new = self.trackFrame( + nextLab, nextRp, lab, rp, rp.IDs, + assign_unique_new_IDs=False, return_assignments=True, + specific_IDs=[newID], ) - trackedID = tracked_lab[newID_lab>0][0] + # restore rp + posData.rp.update_regionprops_via_assignments(reverse_assignments, lab) + + # clear self assignments + assignments_new = { + k:v for k, v in assignments_new.items() if k != v + } + if not assignments_new: + return None, assignments + + trackedIDs = list(assignments_new.values()) + + trackedID = trackedIDs[0] if trackedID == newID: # Object does not exist in future frame --> do not track - return + return None, assignments - if posData.IDs_idxs.get(trackedID) is not None: + if posData.rp.get_obj_from_ID(trackedID, warn=False) is not None: # Tracked ID already exists --> do not track to avoid merging - return + return None, assignments - return trackedID + + + # update assignments + assignments = { + old_ID: tracked_ID for old_ID, tracked_ID in assignments.items() + if old_ID != newID + } + assignments[newID] = trackedID + + return trackedID, assignments def store_manual_annot_data( self, posData=None, data_frame_i=None @@ -20621,15 +20697,15 @@ def store_data( # self.lin_tree_ask_changes() allData_li = posData.allData_li[posData.frame_i] - allData_li['regionprops'] = posData.rp.copy() + allData_li['regionprops'] = posData.rp allData_li['labels'] = posData.lab.copy() - allData_li['IDs'] = posData.IDs.copy() + allData_li['regionprops'].IDs = posData.IDs allData_li['manualBackgroundLab'] = ( posData.manualBackgroundLab ) - allData_li['IDs_idxs'] = ( - posData.IDs_idxs.copy() - ) + # allData_li['IDs_idxs'] = ( + # posData.IDs_idxs.copy() + # ) if self.manualAnnotPastButton.isChecked(): self.store_manual_annot_data( posData=posData, data_frame_i=allData_li @@ 
-20651,13 +20727,17 @@ def store_data( is_cell_dead_li[i] = obj.dead is_cell_excluded_li[i] = obj.excluded IDs[i] = obj.label - try: - xx_centroid[i] = int(self.getObjCentroid(obj.centroid)[1]) - yy_centroid[i] = int(self.getObjCentroid(obj.centroid)[0]) - except Exception as err: - printl(obj, obj.centroid, obj.label, posData.frame_i) + centroid = posData.rp.get_centroid(obj.label, exact=True) + if centroid is None: + continue + if self.isSegm3D: - zz_centroid[i] = int(obj.centroid[0]) + zz_centroid[i] = int(centroid[0]) + xx_centroid[i] = int(centroid[2]) + yy_centroid[i] = int(centroid[1]) + else: + xx_centroid[i] = int(centroid[1]) + yy_centroid[i] = int(centroid[0]) if obj.label in editedNewIDs: areManuallyEdited[i] = 1 @@ -21436,16 +21516,16 @@ def assignNewIDfromClickedID( mapper = [(clickedID, newID)] self.applyEditID(clickedID, posData.IDs.copy(), mapper, x, y) - def get_2Drp(self, lab=None): - if self.isSegm3D: - if lab is None: - # self.currentLab2D is defined at self.setImageImg2() - lab = self.currentLab2D - lab = self.get_2Dlab(lab) - rp = skimage.measure.regionprops(lab) - return rp - else: - return self.data[self.pos_i].rp + # def get_2Drp(self, lab=None): Not in use + # if self.isSegm3D: + # if lab is None: + # # self.currentLab2D is defined at self.setImageImg2() + # lab = self.currentLab2D + # lab = self.get_2Dlab(lab) + # rp = skimage.measure.regionprops(lab) + # return rp + # else: + # return self.data[self.pos_i].rp def set_2Dlab(self, lab2D): posData = self.data[self.pos_i] @@ -21534,10 +21614,12 @@ def get_labels( def addYXcentroidToDf(self, df): posData = self.data[self.pos_i] for obj in posData.rp: - y_centroid = int(self.getObjCentroid(obj.centroid)[0]) - x_centroid = int(self.getObjCentroid(obj.centroid)[1]) - df.at[obj.label, 'y_centroid'] = y_centroid - df.at[obj.label, 'x_centroid'] = x_centroid + ID = obj.label + centroid = posData.rp.get_centroid(obj, exact=True) + y_centroid = int(self.getObjCentroid(centroid)[0]) + x_centroid = 
int(self.getObjCentroid(centroid)[1]) + df.at[ID, 'y_centroid'] = y_centroid + df.at[ID, 'x_centroid'] = x_centroid return df def _get_editID_info(self, df): @@ -21617,7 +21699,8 @@ def _get_data_unvisited(self, posData, debug=False, lin_tree_init=True,): posData.lab = self.apply_manual_edits_to_lab_if_needed( labels ) - posData.rp = skimage.measure.regionprops(posData.lab) + posData.rp = posData.allData_li[posData.frame_i]['regionprops'] + # get stored IDs self.setManualBackgroundLab() if posData.acdc_df is not None: @@ -21660,7 +21743,8 @@ def _get_data_visited(self, posData, debug=False, lin_tree_init=True,): # Requested frame was already visited. Load from RAM. never_visited = False posData.lab = self.get_labels(from_store=True) - posData.rp = skimage.measure.regionprops(posData.lab) + # posData.rp = skimage.measure.regionprops(posData.lab) + posData.rp = posData.allData_li[posData.frame_i]['regionprops'] df = posData.allData_li[posData.frame_i]['acdc_df'] if df is None: posData.binnedIDs = set() @@ -21700,6 +21784,7 @@ def get_data(self, debug=False, lin_tree_init=True): else: self.undoAction.setDisabled(True) self.UndoCount = 0 + # If stored labels is None then it is the first time we visit this frame if posData.allData_li[posData.frame_i]['labels'] is None: proceed_cca, never_visited = self._get_data_unvisited( @@ -21713,10 +21798,7 @@ def get_data(self, debug=False, lin_tree_init=True): ) self.update_rp_metadata(draw=False) - posData.IDs = [obj.label for obj in posData.rp] - posData.IDs_idxs = { - ID:i for ID, i in zip(posData.IDs, range(len(posData.IDs))) - } + posData.IDs = posData.rp.IDs self.get_zslices_rp() self.pointsLayerDfsToData(posData) return proceed_cca, never_visited @@ -21743,7 +21825,7 @@ def addIDBaseCca_df(self, posData, ID): def getBaseCca_df(self, with_tree_cols=False): posData = self.data[self.pos_i] - IDs = [obj.label for obj in posData.rp] + IDs = posData.rp.IDs cca_df = core.getBaseCca_df(IDs, with_tree_cols=with_tree_cols) return 
cca_df @@ -22305,6 +22387,32 @@ def get_cca_df(self, frame_i=None, return_df=False, debug=False): return cca_df else: posData.cca_df = cca_df + + def _changeIDhelper(self, lab, oldID, newID, rp, assignments): + did_find_newID = False + if newID in rp.IDs: # should here also self.editIDmergeIDs? + # Relabel old_ID to tempID, safe as RP is safe so no merging + objo = rp.get_obj_from_ID(oldID, warn=False) + if objo is not None: + slc_o = objo.slice + mask_o = objo.image + lab[slc_o][mask_o] = newID + assignments[oldID] = newID + # Relabel new_ID to old_ID + objn = rp.get_obj_from_ID(newID) # here warn, we check in the if if it should be there + objn_slice = objn.slice + objn_mask = objn.image + lab[objn_slice][objn_mask] = oldID + assignments[newID] = oldID + did_find_newID = True + else: + obj = rp.get_obj_from_ID(oldID, warn=False) + if obj is not None: + slc = obj.slice + mask = obj.image + lab[slc][mask] = newID + assignments[oldID] = newID + return did_find_newID def changeIDfutureFrames( self, endFrame_i, oldIDnewIDMapper, includeUnvisited, @@ -22321,6 +22429,7 @@ def changeIDfutureFrames( segmSizeT = len(posData.segm_data) for i in range(posData.frame_i+1, segmSizeT): + assignments = {} lab = posData.allData_li[i]['labels'] if lab is None and not includeUnvisited: self.enqAutosave() @@ -22334,27 +22443,20 @@ def changeIDfutureFrames( lab = self.get_2Dlab(posData.lab) else: lab = posData.lab - + if self.onlyTracking: self.tracking(enforce=True) elif not posData.IDs: continue else: - maxID = max(posData.IDs, default=0) + 1 for old_ID, new_ID in oldIDnewIDMapper: - if new_ID in lab: - tempID = maxID + 1 # lab.max() + 1 - lab[lab == old_ID] = tempID - lab[lab == new_ID] = old_ID - lab[lab == tempID] = new_ID - maxID += 1 - else: - lab[lab == old_ID] = new_ID - + self._changeIDhelper(lab, old_ID, new_ID, + posData.rp, assignments) + if shift and self.isSegm3D: self.set_2Dlab(lab) - - self.update_rp(draw=False) + + self.update_rp(draw=False,assignments=assignments if 
not (shift and self.isSegm3D) else None) self.store_data(autosave=i==endFrame_i) elif includeUnvisited: # Unvisited frame (includeUnvisited = True) @@ -22363,18 +22465,19 @@ def changeIDfutureFrames( lab = self.get_2Dlab(lab) else: lab = lab - + + # get rp from allData_li... Its already init in core.countObjects + assignments = {} + rp = posData.allData_li[i]['regionprops'] for old_ID, new_ID in oldIDnewIDMapper: - if new_ID in lab: - tempID = lab.max() + 1 - lab[lab == old_ID] = tempID - lab[lab == new_ID] = old_ID - lab[lab == tempID] = new_ID - else: - lab[lab == old_ID] = new_ID - + self._changeIDhelper(lab, old_ID, new_ID, + rp, assignments) + if shift and self.isSegm3D: posData.segm_data[i][self.z_lab()] = lab + rp.update_regionprops(lab) + else: + rp.update_regionprops_via_assignments(assignments, lab) # Back to current frame posData.frame_i = self.current_frame_i @@ -22646,14 +22749,11 @@ def drawObjMothBudLines(self, obj, posData, ax=0): scatterItem = self.getMothBudLineScatterItem(ax, isNew) relative_ID = cca_df_ID['relative_ID'] - try: - relative_rp_idx = posData.IDs_idxs[relative_ID] - except KeyError: - return - - relative_ID_obj = posData.rp[relative_rp_idx] - y1, x1 = self.getObjCentroid(obj.centroid) - y2, x2 = self.getObjCentroid(relative_ID_obj.centroid) + relative_ID_obj = posData.rp.get_obj_from_ID(relative_ID) + obj_centroid = posData.rp.get_centroid(ID) + rel_obj_centroid = posData.rp.get_centroid(relative_ID) + y1, x1 = self.getObjCentroid(obj_centroid) + y2, x2 = self.getObjCentroid(rel_obj_centroid) xx, yy = core.get_line(y1, x1, y2, x2, dashed=True) scatterItem.addPoints(xx, yy) @@ -22695,14 +22795,14 @@ def drawAllLineageTreeLines(self): continue for ID in new_cells: - curr_obj = myutils.get_obj_by_label(rp, ID) + curr_obj = rp.get_obj_from_ID(ID) lin_tree_df_ID = lin_tree_df.loc[ID] # lin_tree_df_mother_ID = lin_tree_df_prev.loc[lin_tree_df_ID["parent_ID_tree"]] if lin_tree_df_ID["parent_ID_tree"] == -1: # make sure that new obj where 
the parents are not known get skipped continue - mother_obj = myutils.get_obj_by_label(prev_rp, lin_tree_df_ID["parent_ID_tree"]) + mother_obj = prev_rp.get_obj_from_ID(lin_tree_df_ID["parent_ID_tree"]) emerg_frame_i = lin_tree_df_ID["emerg_frame_i"] isNew = emerg_frame_i == frame_i @@ -22738,9 +22838,15 @@ def drawObjLin_TreeMothBudLines(self, ax, obj, mother_obj, isNew, ID=None): return scatterItem = self.getMothBudLineScatterItem(ax, isNew) - - y1, x1 = self.getObjCentroid(obj.centroid) - y2, x2 = self.getObjCentroid(mother_obj.centroid) + + posData = self.data[self.pos_i] + prev_rp = posData.allData_li[posData.frame_i-1]['regionprops'] + rp = posData.rp + if ID is None: + ID = obj.label + ID_mother = mother_obj.label + y1, x1 = self.getObjCentroid(rp.get_centroid(ID)) + y2, x2 = self.getObjCentroid(prev_rp.get_centroid(ID_mother)) xx, yy = core.get_line(y1, x1, y2, x2, dashed=True) scatterItem.addPoints(xx, yy) @@ -22787,36 +22893,42 @@ def store_zslices_rp(self, force_update=False): posData.allData_li[posData.frame_i]['z_slices_rp'] = posData.zSlicesRp - def removeObjectFromRp(self, delID): + def removeObjectFromRp(self, delIDs): posData = self.data[self.pos_i] - rp = [] - IDs = [] - IDs_idxs = {} - idx = 0 - for obj in posData.rp: - if obj.label == delID: - continue - rp.append(obj) - IDs.append(obj.label) - IDs_idxs[obj.label] = idx - idx += 1 + if not isinstance(delIDs, (list, set)): + delIDs = [delIDs] + + posData.rp.update_regionprops_via_deletions(set(delIDs)) + posData.IDs = posData.rp.IDs + + # rp = [] + # IDs = [] + # IDs_idxs = {} + # idx = 0 + # for obj in posData.rp: + # if obj.label == delID: + # continue + # rp.append(obj) + # IDs.append(obj.label) + # IDs_idxs[obj.label] = idx + # idx += 1 - posData.rp = rp - posData.IDs = IDs - posData.IDs_idxs = IDs_idxs + # posData.rp = rp + # posData.IDs = IDs + # posData.IDs_idxs = IDs_idxs - if not self.isSegm3D: - return + # if not self.isSegm3D: + # return - zSlicesRp = {} - for z, zSliceRp in 
posData.zSlicesRp.items(): - if delID in zSliceRp: - continue + # zSlicesRp = {} + # for z, zSliceRp in posData.zSlicesRp.items(): + # if delID in zSliceRp: + # continue - zSlicesRp[z] = zSlicesRp + # zSlicesRp[z] = zSlicesRp - posData.zSlicesRp = zSlicesRp - self.store_zslices_rp(force_update=True) + # posData.zSlicesRp = zSlicesRp + # self.store_zslices_rp(force_update=True) def get_zslices_rp(self): if not self.isSegm3D: @@ -22834,7 +22946,7 @@ def _update_zslices_rp(self): posData = self.data[self.pos_i] posData.zSlicesRp = {} for z, lab2d in enumerate(posData.lab): - lab2d_rp = skimage.measure.regionprops(lab2d) + lab2d_rp = regionprops.acdcRegionprops(lab2d, precache_centroids=False) posData.zSlicesRp[z] = {obj.label:obj for obj in lab2d_rp} def instructHowDeleteID(self): @@ -22878,7 +22990,7 @@ def checkWarnDeletedIDwithEraser(self): for ID in self.erasedIDs: if ID == 0: continue - if ID in posData.IDs_idxs: + if posData.rp.get_obj_from_ID(ID, warn=False) is not None: continue self.instructHowDeleteID() @@ -22892,34 +23004,224 @@ def checkWarnDeletedIDwithEraser(self): return True return False + + def _get_entire_depth_axis_from_2D_cutout(self, cutout): + # cutout = (xl, xr), (yt, yb), z is always on the y position if depth axis is changed + # cutout is in the current view; return grouped ranges in the order + # expected by update_rp_get_bbox before conversion to NumPy bbox order. + posData = self.data[self.pos_i] + if self.isSegm3D: + depthAxes = self.switchPlaneCombobox.depthAxes() + if depthAxes == 'z': + # cutout is (x, y) and we prepend the full z range. + z_max = posData.SizeZ + return ((0, z_max), cutout[0], cutout[1]) + if depthAxes == 'y': + # cutout is (x, z); convert to (z, x, y). + y_max = posData.SizeY + return (cutout[1], cutout[0], (0, y_max)) + elif depthAxes == 'x': + # cutout is (y, z); convert to (z, x, y). 
+ x_max = posData.SizeX + return (cutout[1], (0, x_max), cutout[0]) + else: + return cutout + + def _cutout_to_bbox(self, cutout): + """ + Convert grouped view ranges into a flat bbox in NumPy array order. + 2D input: ((x_min, x_max), (y_min, y_max)) → (y_min, x_min, y_max, x_max) + 3D input: ((z_min, z_max), (y_min, y_max), (x_min, x_max)) → (z_min, y_min, x_min, z_max, y_max, x_max) + """ + cutout = tuple( + (min(r[0], r[1]), max(r[0], r[1])) for r in cutout + ) + if self.isSegm3D: + (z_min, z_max), (y_min, y_max), (x_min, x_max) = cutout + return (z_min, y_min, x_min, z_max, y_max, x_max) + else: + (x_min, x_max), (y_min, y_max) = cutout + return (y_min, x_min, y_max, x_max) + + def _get_perc_cutout_from_total_img(self, cutout): + posData = self.data[self.pos_i] + single_timepoint_segm_size = posData.getSingleTimepointSegmSize() + if self.isSegm3D: + size = (cutout[0][1] - cutout[0][0]) * (cutout[1][1] - cutout[1][0]) * (cutout[2][1] - cutout[2][0]) + else: + size = (cutout[0][1] - cutout[0][0]) * (cutout[1][1] - cutout[1][0]) + return size / single_timepoint_segm_size + + def update_rp_get_bbox(self, custom_bbox=None, use_bbox=False, use_curr_view=False, + specific_IDs=None, add_frac_custom_bbox=0.05): + """ + Returns an expanded bounding box (bbox) for the given IDs or custom_bbox. + Returns False if not enough cells or cutout is too large. 
+ """ + posData = self.data[self.pos_i] + if len(posData.rp.IDs) < RP_OPT_NUM_CELLS_MIN: + return False + if not isinstance(specific_IDs, (list, set, np.ndarray)) and specific_IDs is not None: + specific_IDs = [specific_IDs] + elif specific_IDs is not None and len(specific_IDs) == 0: + specific_IDs = None + + # Helper to merge bboxes + def merge_bbox(b1, b2): + if len(b1) == 4: + return ( + min(b1[0], b2[0]), min(b1[1], b2[1]), + max(b1[2], b2[2]), max(b1[3], b2[3]) + ) + else: + return ( + min(b1[0], b2[0]), min(b1[1], b2[1]), min(b1[2], b2[2]), + max(b1[3], b2[3]), max(b1[4], b2[4]), max(b1[5], b2[5]) + ) + + bbox = None + if custom_bbox or use_bbox: + if not custom_bbox and use_bbox and specific_IDs is not None: + rp_old = posData.rp + for ID in specific_IDs: + b = rp_old.get_obj_from_ID(ID).bbox + bbox = b if bbox is None else merge_bbox(bbox, b) + else: + bbox = custom_bbox + + if bbox is None: + return False + + elif use_curr_view: + cutout = self.ax1ViewRange(integers=True) + cutout = self._get_entire_depth_axis_from_2D_cutout(cutout) + if len(cutout)==2: + (xl, xr), (yt, yb) = cutout + else: + (z1, z2), (xl, xr), (yt, yb) = cutout + z_min = min(z1, z2) + z_max = max(z1, z2) + x_min = min(xl, xr) + x_max = max(xl, xr) + y_min = min(yt, yb) + y_max = max(yt, yb) + bbox = (y_min, x_min, y_max, x_max) if len(cutout)==2 else (z_min, y_min, x_min, z_max, y_max, x_max) + # Expand bbox by a fraction + else: + raise ValueError('''Either custom_bbox or use_bbox or use_curr_view must be provided as True''') + + if len(bbox) == 4: + y_min, x_min, y_max, x_max = bbox + offset_y = int((y_max - y_min) * add_frac_custom_bbox) + offset_x = int((x_max - x_min) * add_frac_custom_bbox) + offset_y = 1 if offset_y == 0 else offset_y + offset_x = 1 if offset_x == 0 else offset_x + size_y, size_x = posData.SizeY, posData.SizeX + cutout = ( + (max(0, x_min - offset_x), min(size_x, x_max + offset_x)), + (max(0, y_min - offset_y), min(size_y, y_max + offset_y)) + ) + else: + z_min, 
y_min, x_min, z_max, y_max, x_max = bbox + offset_z = int((z_max - z_min) * add_frac_custom_bbox) + offset_y = int((y_max - y_min) * add_frac_custom_bbox) + offset_x = int((x_max - x_min) * add_frac_custom_bbox) + offset_z = 1 if offset_z == 0 else offset_z + offset_y = 1 if offset_y == 0 else offset_y + offset_x = 1 if offset_x == 0 else offset_x + size_z, size_y, size_x = posData.SizeZ, posData.SizeY, posData.SizeX + cutout = ( + (max(0, z_min - offset_z), min(size_z, z_max + offset_z)), + (max(0, y_min - offset_y), min(size_y, y_max + offset_y)), + (max(0, x_min - offset_x), min(size_x, x_max + offset_x)) + ) + + perc_from_global = self._get_perc_cutout_from_total_img(cutout) + if perc_from_global > RP_OPT_PERC_CUTOUT_MAX: + return False + return self._cutout_to_bbox(cutout) @exception_handler def update_rp( - self, draw=True, debug=False, update_IDs=True, - wl_update=True, wl_track_og_curr=False,wl_update_lab=False + self, draw=True, debug=False, # og stuff + assignments=None, deletionIDs=None, # very quick upates, rp labels are changed but rest is same + specific_IDs=None, use_curr_view=False, use_bbox=False, preloaded_bbox=None, # for local updates to PR + wl_update=True, wl_track_og_curr=False,wl_update_lab=False, # wl stuff ): + """Updates posData.rp + Parameters + ---------- + + """ + #updating rp is very clostly, as it deletes all the cashed + if use_curr_view and use_bbox: + raise ValueError('''use_curr_view and use_bbox cannot be True at the + same time, as they are mutually exclusive''') + local_rp_update = bool(use_curr_view or use_bbox or preloaded_bbox) posData = self.data[self.pos_i] # Update rp for current posData.lab (e.g. 
after any change) - if wl_update: if self.whitelistOriginalIDs is None: - old_IDs = posData.allData_li[posData.frame_i]['IDs'].copy() # for whitelist stuff + old_IDs = posData.allData_li[posData.frame_i]['regionprops'].IDs.copy() # for whitelist stuff else: old_IDs = self.whitelistOriginalIDs.copy() self.whitelistOriginalIDs = None elif self.whitelistOriginalIDs is None: - self.whitelist_old_IDs = posData.allData_li[posData.frame_i]['IDs'].copy() - - posData.rp = skimage.measure.regionprops(posData.lab) - if update_IDs: - IDs = [] - IDs_idxs = {} - for idx, obj in enumerate(posData.rp): - IDs.append(obj.label) - IDs_idxs[obj.label] = idx - posData.IDs = IDs - posData.IDs_idxs = IDs_idxs + self.whitelist_old_IDs = posData.allData_li[posData.frame_i]['regionprops'].IDs.copy() + + # check if only one of assignments, deletionIDs or only_current_view is given + if sum([assignments is not None, + deletionIDs is not None, + local_rp_update, + ]) > 1: + print(assignments is not None, deletionIDs is not None, local_rp_update) + raise ValueError('Only one of assignments, deletionIDs, ' + 'use_curr_view or use_bbox, preloaded_bbox can be used ' + 'at a time') + + if not isinstance(specific_IDs, (list, set, np.ndarray)) and specific_IDs is not None: + specific_IDs = [specific_IDs] + elif specific_IDs is not None and len(specific_IDs) == 0: + specific_IDs = None + + + # posData.rp is an acdcRegionprops instance here. + # if rp is None (can sometimes happen appearantly???) + if posData.rp is None: + printl(f'''Warning: posData.rp is None for pos {self.pos_i}, + frame {posData.frame_i}. Recomputing rp from labels.''') + + posData.rp = regionprops.acdcRegionprops(posData.lab) + + if assignments is not None: + # {old_ID: new_ID, ...} + posData.rp.update_regionprops_via_assignments(assignments, posData.lab) + elif deletionIDs is not None: + # (delID1, delID2, ...) 
+ posData.rp.update_regionprops_via_deletions(deletionIDs) + elif local_rp_update: + # first get current view + if preloaded_bbox is None: + preloaded_bbox = self.update_rp_get_bbox(use_bbox=use_bbox, use_curr_view=use_curr_view, + specific_IDs=specific_IDs) + if preloaded_bbox is not False: + posData.rp.update_regionprops_via_cutout( + posData.lab, cutout_bbox=preloaded_bbox, specific_IDs=specific_IDs + ) + # if ID touches border but is not in specific_IDs, it will not be updated, + # so be careful! + else: + posData.rp.update_regionprops( + posData.lab + ) + else: + posData.rp.update_regionprops( + posData.lab, + specific_IDs_update_centroids=specific_IDs if preloaded_bbox is not False else None, # since sometimes I preload + ) + posData.IDs = posData.rp.IDs + self.update_rp_metadata(draw=draw) self.store_zslices_rp(force_update=True) @@ -23016,12 +23318,11 @@ def updateTempLayerKeepIDs(self): def highlightLabelID(self, ID, ax=0): posData = self.data[self.pos_i] - try: - obj = posData.rp[posData.IDs_idxs[ID]] - except KeyError: + obj = posData.rp.get_obj_from_ID(ID, warn=False) + if obj is None: return - self.textAnnot[ax].highlightObject(obj) + self.textAnnot[ax].highlightObject(obj, rp=posData.rp, getObjCentroidFunc=self.getObjCentroid) def _keepObjects(self, keepIDs=None, lab=None, rp=None): posData = self.data[self.pos_i] @@ -23051,7 +23352,7 @@ def removeHighlightLabelID(self, IDs=None, ax=0): IDs = posData.IDs for ID in IDs: - obj = posData.rp[posData.IDs_idxs[ID]] + obj = posData.rp.get_obj_from_ID(ID) self.textAnnot[ax].removeHighlightObject(obj) def updateKeepIDs(self, IDs): @@ -23093,7 +23394,7 @@ def applyKeepObjects(self): posData = self.data[self.pos_i] - self.update_rp() + self.update_rp() # why here? # Repeat tracking self.tracking(enforce=True, assign_unique_new_IDs=False) @@ -23148,6 +23449,7 @@ def applyKeepObjects(self): posData.frame_i = self.current_frame_i self.get_data() + # no rp update here? 
# Ask to propagate change to all future visited frames key = 'Keep ID' @@ -23287,7 +23589,8 @@ def annotate_rip_and_bin_IDs(self, updateLabel=False): continue if obj.excluded: - y, x = self.getObjCentroid(obj.centroid) + ID = obj.label + y, x = self.getObjCentroid(posData.rp.get_centroid(ID)) binnedIDs_xx.append(x) binnedIDs_yy.append(y) if updateLabel: @@ -23295,7 +23598,8 @@ def annotate_rip_and_bin_IDs(self, updateLabel=False): how = self.drawIDsContComboBox.currentText() if obj.dead: - y, x = self.getObjCentroid(obj.centroid) + ID = obj.label + y, x = self.getObjCentroid(posData.rp.get_centroid(ID)) ripIDs_xx.append(x) ripIDs_yy.append(y) if updateLabel: @@ -23936,10 +24240,8 @@ def zoomToObj(self, obj=None): posData = self.data[self.pos_i] if obj is None: ID = self.sender().value() - try: - ID_idx = posData.IDs_idxs[ID] - obj = obj = posData.rp[ID_idx] - except Exception as e: + obj = posData.rp.get_obj_from_ID(ID, warn=False) + if obj is None: self.logger.warning( f'ID {ID} does not exist (add points by clicking)' ) @@ -23991,7 +24293,7 @@ def pointsLayerAutoPilot(self, direction): return try: - ID_idx = posData.IDs_idxs[ID] + ID_idx = posData.rp.ID_to_idx[ID] if direction == 'next': nextID_idx = ID_idx + 1 else: @@ -24040,7 +24342,7 @@ def checkLoadedTableIds(self, toolbar): for posData in self.data: for tableEndName, df in posData.clickEntryPointsDfs.items(): for point_id in df['id'].values: - if point_id in posData.IDs_idxs: + if point_id in posData.rp.IDs: proceed = self.warnAddingPointWithExistingId( point_id, table_endname=tableEndName ) @@ -24260,10 +24562,10 @@ def setHoverCircleAddPoint(self, x, y): def isPointIdAlreadyNew(self, point_id, action): posData = self.data[self.pos_i] - if point_id in posData.IDs_idxs: + if point_id in posData.rp.IDs: return False - is_ID = point_id in posData.IDs_idxs + is_ID = point_id in posData.rp.IDs pointsDataPos = action.pointsData.get(self.pos_i) if pointsDataPos is None: return not is_ID @@ -24407,6 +24709,8 @@ 
def getCentroidsPointsData(self, action): # Centroids (either weighted or not) # NOTE: if user requested to draw from table we load that in # apps.AddPointsLayerDialog.ok_cb() + + # this does not have the updated centroid logic to avoid weird behaviours posData = self.data[self.pos_i] action.pointsData[self.pos_i] = {posData.frame_i: {}} if hasattr(action, 'weighingData'): @@ -25087,7 +25391,8 @@ def computeAllContours(self): rp = dataDict['regionprops'] if rp is None: - rp = skimage.measure.regionprops(lab) + rp = regionprops.acdcRegionprops(lab, precache_centroids=False) + dataDict['regionprops'] = rp dataDict['contours'] = {} for obj in rp: @@ -26430,7 +26735,8 @@ def updateCcaDfDeletedIDsTimelapse( else: for delID in deletedIDs: dataDict = posData.allData_li[fut_frame_i] - delIDexists = dataDict['IDs_idxs'].get(delID, False) + rp = dataDict['regionprops'] + delIDexists = delID in rp.IDs if not delIDexists: continue @@ -26467,7 +26773,8 @@ def updateCcaDfDeletedIDsTimelapse( else: for delID in deletedIDs: dataDict = posData.allData_li[past_frame_i] - delIDexists = dataDict['IDs_idxs'].get(delID, False) + rp = dataDict['regionprops'] + delIDexists = delID in rp.IDs if not delIDexists: continue @@ -26763,8 +27070,7 @@ def highlightHoverID(self, x, y, hoverID=None): return posData = self.data[self.pos_i] - objIdx = posData.IDs_idxs[hoverID] - obj = posData.rp[objIdx] + obj = posData.rp.get_obj_from_ID(hoverID) self.goToZsliceSearchedID(obj) self.highlightSearchedID(hoverID) @@ -26828,12 +27134,10 @@ def highlightHoverIDsKeptObj(self, x, y, hoverID=None): return posData = self.data[self.pos_i] - try: - objIdx = posData.IDs_idxs[hoverID] - except KeyError as err: - return + obj = posData.rp.get_obj_from_ID(hoverID, warn=False) + if obj is None: + return - obj = posData.rp[objIdx] self.goToZsliceSearchedID(obj) for ID in self.keptObjectsIDs: @@ -26898,11 +27202,10 @@ def highlightSearchedID(self, ID, force=False, greyOthers=True): self.highlightedID = ID 
self.highlightIDToolbar.setVisible(True) - objIdx = posData.IDs_idxs.get(ID) - if objIdx is None: + obj = posData.rp.get_obj_from_ID(ID, warn=False) + if obj is None: return - obj = posData.rp[objIdx] isObjVisible = self.isObjVisible(obj.bbox) if not isObjVisible: return @@ -27150,7 +27453,7 @@ def setManualBackgroundImage(self): def setManualBackgrounNextID(self): posData = self.data[self.pos_i] currentID = self.manualBackgroundObj.label - idx = posData.IDs_idxs[currentID] + idx = posData.rp.ID_to_idx[currentID] next_idx = idx + 1 if next_idx >= len(posData.IDs): return @@ -27226,7 +27529,9 @@ def updateContoursImage(self, ax, delROIsIDs=None, compute=True): self.contoursImage[:] = 0 contours = [] - for obj in skimage.measure.regionprops(self.currentLab2D): + lab = self.currentLab2D + rp = skimage.measure.regionprops(lab) # any chance we dont need to update here? + for obj in rp: obj_contours = self.getObjContours( obj, all_external=True, @@ -27245,13 +27550,10 @@ def setContoursImage(self, imageItem, contours, thickness, color): def getObjFromID(self, ID): posData = self.data[self.pos_i] - try: - idx = posData.IDs_idxs[ID] - except KeyError as e: + obj = posData.rp.get_obj_from_ID(ID, warn=False) + if obj is None: # Object already cleared return - - obj = posData.rp[idx] return obj def setLostObjectContour(self, obj): @@ -27291,7 +27593,6 @@ def updateLostContoursImage(self, ax, draw=True, delROIsIDs=None): posData = self.data[self.pos_i] prev_rp = posData.allData_li[posData.frame_i-1]['regionprops'] - prev_IDs_idxs = posData.allData_li[posData.frame_i-1]['IDs_idxs'] if posData.whitelist is not None and posData.whitelist.whitelistIDs is not None: whitelist = posData.whitelist.whitelistIDs[posData.frame_i-1] else: @@ -27302,7 +27603,7 @@ def updateLostContoursImage(self, ax, draw=True, delROIsIDs=None): if lostID in delROIsIDs or (whitelist is not None and lostID not in whitelist): continue - obj = prev_rp[prev_IDs_idxs[lostID]] + obj = 
prev_rp.get_obj_from_ID(lostID) if not self.isObjVisible(obj.bbox): continue @@ -27347,13 +27648,12 @@ def updateLostTrackedContoursImage( tracked_lost_IDs = self.getTrackedLostIDs() prev_rp = posData.allData_li[posData.frame_i-1]['regionprops'] - prev_IDs_idxs = posData.allData_li[posData.frame_i-1]['IDs_idxs'] contours = [] for tracked_lost_ID in tracked_lost_IDs: if tracked_lost_ID in delROIsIDs: continue - obj = prev_rp[prev_IDs_idxs[tracked_lost_ID]] + obj = prev_rp.get_obj_from_ID(tracked_lost_ID) if not self.isObjVisible(obj.bbox): continue @@ -27420,8 +27720,11 @@ def setCcaIssueContour(self, obj): xx = cont[:,0] + 0.5 yy = cont[:,1] + 0.5 self.ax1_lostObjScatterItem.addPoints(xx, yy) + + posData = self.data[self.pos_i] self.textAnnot[0].addObjAnnotation( - obj, 'lost_object', f'{obj.label}?', False + obj, 'lost_object', f'{obj.label}?', False, + rp=posData.rp, getObjCentroidFunc=self.getObjCentroid ) def isLastVisitedAgainCca(self, curr_df, enforceAll=False): @@ -27469,7 +27772,8 @@ def highlightNewCellNotEnoughG1cells(self, IDsCellsG1): yy = objContours[:,1] + 0.5 self.ccaFailedScatterItem.addPoints(xx, yy) self.textAnnot[0].addObjAnnotation( - obj, 'green', f'{obj.label}?', False + obj, 'green', f'{obj.label}?', False, + rp=posData.rp, getObjCentroidFunc=self.getObjCentroid ) def handleNoCellsInG1(self, numCellsG1, numNewCells): @@ -27568,7 +27872,7 @@ def setAllTextAnnotations(self, labelsToSkip=None): def setAllContoursImages(self, delROIsIDs=None, compute=True): if compute: self.computeAllContours() - self.updateContoursImage(ax=0, delROIsIDs=delROIsIDs, compute=compute) + self.updateContoursImage(ax=0, delROIsIDs=delROIsIDs, compute=compute) #almost all from here self.updateContoursImage(ax=1, delROIsIDs=delROIsIDs, compute=compute) def setAllLostObjContoursImage(self, delROIsIDs=None): @@ -27694,7 +27998,6 @@ def keyDownCallback( QAbstractSlider.SliderAction.SliderSingleStepSub ) - # @exec_time @exception_handler def updateAllImages( self, 
image=None, computePointsLayers=True, computeContours=True, @@ -27771,17 +28074,15 @@ def deleteIDFromLab(self, lab, delID, frame_i=None, delMask=None): if frame_i==posData.frame_i: rp = posData.rp - IDs_idxs = posData.IDs_idxs else: rp = posData.allData_li[frame_i]['regionprops'] - IDs_idxs = posData.allData_li[frame_i]['IDs_idxs'] if isinstance(delID, int): delID = [delID] is_any_id_present = False for _delID in delID: - if _delID in IDs_idxs: + if _delID in rp.IDs: is_any_id_present = True break @@ -27794,10 +28095,9 @@ def deleteIDFromLab(self, lab, delID, frame_i=None, delMask=None): delMask[:] = False for _delID in delID: - idx = IDs_idxs.get(_delID, None) - if idx is None: + if _delID not in rp.IDs: continue - obj = rp[idx] + obj = rp.get_obj_from_ID(_delID) delMask[obj.slice][obj.image] = True lab[delMask] = 0 return lab, delMask @@ -27910,7 +28210,7 @@ def setOverlayLabelsItems(self, specific=None): imageItem, contoursItem, gradItem = items contoursItem.clear() if drawMode == 'Draw contours': - for obj in skimage.measure.regionprops(ol_lab): + for obj in skimage.measure.regionprops(ol_lab): #TODO contour opt contours = self.getObjContours( obj, all_external=True ) @@ -28075,8 +28375,7 @@ def highlightHoverLostObj(self, modifiers, event): self.ax1_lostObjScatterItem.setData([], []) else: prev_rp = posData.allData_li[posData.frame_i-1]['regionprops'] - prev_IDs_idxs = posData.allData_li[posData.frame_i-1]['IDs_idxs'] - lostObj = prev_rp[prev_IDs_idxs[hoverLostID]] + lostObj = prev_rp.get_obj_from_ID(hoverLostID) obj_contours = self.getObjContours(lostObj, all_external=True) for cont in obj_contours: xx = cont[:,0] @@ -28097,10 +28396,12 @@ def getPrevFrameIDs(self, current_frame_i=None): if current_frame_i is None: return [] - prev_frame_i = current_frame_i - 1 - prevIDs = posData.allData_li[prev_frame_i]['IDs'] + if current_frame_i == 0: + return [] - if prevIDs: + prev_frame_i = current_frame_i - 1 + if posData.allData_li[prev_frame_i]['regionprops'] is 
not None: + prevIDs = posData.allData_li[prev_frame_i]['regionprops'].IDs return prevIDs # IDs in previous frame were not stored --> load prev lab from HDD @@ -28109,8 +28410,9 @@ def getPrevFrameIDs(self, current_frame_i=None): frame_i=prev_frame_i, return_copy=False ) - rp = skimage.measure.regionprops(prev_lab) - prevIDs = [obj.label for obj in rp] + rp = regionprops.acdcRegionprops(prev_lab) + posData.allData_li[prev_frame_i]['regionprops'] = rp + prevIDs = rp.IDs return prevIDs # @exec_time @@ -28325,105 +28627,86 @@ def trackManuallyAddedObject( added_IDs = [added_IDs] posData = self.data[self.pos_i] - tracked_lab = self.tracking( + tracked_lab, assignments = self.tracking( enforce=True, assign_unique_new_IDs=False, return_lab=True, - IDs=added_IDs + specific_IDs=added_IDs, return_assignments=True, + against_next=posData.frame_i==0 ) + + # RP not updated after tracking!!! self.clearAssignedObjsSecondStep() if tracked_lab is None: return # Track only new object - prevIDs = posData.allData_li[posData.frame_i-1]['IDs'] - - # mask = np.zeros(posData.lab.shape, dtype=bool) - update_rp = False - + prevIDs = posData.allData_li[posData.frame_i-1]['regionprops'].IDs + + # assignments_new = dict() + # self.update_rp(assignments=assignments) for added_ID in added_IDs: - # try: - # obj = posData.rp[added_ID] # ID not present - # mask[obj.slice][obj.image] = True - - # except IndexError as err: - mask = posData.lab == added_ID + + # check if added ID is already present + # here PR is "stale" so ID maps are not tracked + obj = posData.rp.get_obj_from_ID(added_ID, warn=False) + if obj is None: + continue try: - trackedID = tracked_lab[mask][0] + trackedID = tracked_lab[obj.slice][obj.image][0] except IndexError as err: # added_ID is not present continue isTrackedIDalreadyPresentAndNotNew = ( - posData.IDs_idxs.get(trackedID) is not None + posData.rp.ID_to_idx.get(trackedID) is not None and added_ID != trackedID ) if isTrackedIDalreadyPresentAndNotNew: + 
self.updatePointsLayerClickEntryTableEndname( + 'added obj already present', added_ID, trackedID + ) continue isTrackedIDinPrevIDs = trackedID in prevIDs if isTrackedIDinPrevIDs: - posData.lab[mask] = trackedID + posData.lab[obj.slice][obj.image] = trackedID else: # New object where we can try to track against next frame - trackedID = self.trackNewIDtoNewIDsFutureFrame(added_ID, mask) + trackedID, assignments = self.trackNewIDtoNewIDsFutureFrame(added_ID, obj, assignments) if trackedID is None: self.clearAssignedObjsSecondStep() continue - posData.lab[mask] = trackedID + posData.lab[obj.slice][obj.image] = trackedID self.keepOnlyNewIDAssignedObjsSecondStep(trackedID) - update_rp = True - if update_rp: - self.update_rp(wl_update=wl_update) - + self.update_rp(wl_update=wl_update, assignments=assignments) + def trackFrameCustomTracker( - self, prev_lab, currentLab, IDs=None, unique_ID=None + self, prev_lab, currentLab, specific_IDs=None, unique_ID=None, + return_assignments=True, dont_return_tracked_lab=False ): if unique_ID is None: unique_ID = self.setBrushID() - try: - tracked_result = self.realTimeTracker.track_frame( - prev_lab, currentLab, - unique_ID=unique_ID, - IDs=IDs, - **self.track_frame_params, - ) - except TypeError as err: - if str(err).find('an unexpected keyword argument \'unique_ID\'') != -1: - try: - tracked_result = self.realTimeTracker.track_frame( - prev_lab, currentLab, IDs=IDs, - **self.track_frame_params - ) - except TypeError as err: - if str(err).find('an unexpected keyword argument \'IDs\'') != -1: - tracked_result = self.realTimeTracker.track_frame( - prev_lab, currentLab, - **self.track_frame_params) - else: - raise err - elif str(err).find('an unexpected keyword argument \'IDs\'') != -1: - try: - tracked_result = self.realTimeTracker.track_frame( - prev_lab, currentLab, - unique_ID=unique_ID, - **self.track_frame_params - ) - except TypeError as err: - if str(err).find('an unexpected keyword argument \'unique_ID\'') != -1: - 
tracked_result = self.realTimeTracker.track_frame( - prev_lab, currentLab, - **self.track_frame_params - ) - else: - raise err - else: - raise err + + kwargs_total = { + 'unique_ID': unique_ID, + 'return_assignments': return_assignments, + 'dont_return_tracked_lab': dont_return_tracked_lab, + 'specific_IDs': specific_IDs + } + kwargs_total.update(self.track_frame_params) + + kwargs = {k: v for k, v in kwargs_total.items() if k in self.realTimeTracker_kwargs} + tracked_result = self.realTimeTracker.track_frame( + prev_lab, currentLab, + **kwargs, + ) return tracked_result def trackFrame( self, prev_lab, prev_rp, curr_lab, curr_rp, curr_IDs, - assign_unique_new_IDs=True, IDs=None, unique_ID=None + assign_unique_new_IDs=True, specific_IDs=None, unique_ID=None, + dont_return_tracked_lab=False, return_assignments=False, ): if self.trackWithAcdcAction.isChecked(): tracked_result = CellACDC_tracker.track_frame( @@ -28432,8 +28715,10 @@ def trackFrame( setBrushID_func=self.setBrushID, posData=self.data[self.pos_i], assign_unique_new_IDs=assign_unique_new_IDs, - IDs=IDs, - unique_ID=unique_ID + specific_IDs=specific_IDs, + unique_ID=unique_ID, + return_assignments=return_assignments, + dont_return_tracked_lab=dont_return_tracked_lab ) elif self.trackWithYeazAction.isChecked(): tracked_result = self.tracking_yeaz.correspondence( @@ -28442,17 +28727,39 @@ def trackFrame( ) else: tracked_result = self.trackFrameCustomTracker( - prev_lab, curr_lab, IDs=IDs, unique_ID=unique_ID + prev_lab, curr_lab, specific_IDs=specific_IDs, unique_ID=unique_ID, + dont_return_tracked_lab=dont_return_tracked_lab, return_assignments=return_assignments ) # Check if tracker also returns additional info + assignments = None if isinstance(tracked_result, tuple): - tracked_lab, tracked_lost_IDs = tracked_result - self.handleAdditionalInfoRealTimeTracker(prev_rp, tracked_lost_IDs) + tracked_lab, add_info = tracked_result + assignments = self.handleAdditionalInfoRealTimeTracker(prev_rp, add_info) + elif 
isinstance(tracked_result, dict) and dont_return_tracked_lab: + add_info = tracked_result + if 'assignments' in add_info: # if still entire add_info is returned + assignments = self.handleAdditionalInfoRealTimeTracker(prev_rp, add_info) + else: + assignments = add_info # its just assignements else: tracked_lab = tracked_result - return tracked_lab + printl(assignments) + if not return_assignments and not dont_return_tracked_lab: + return tracked_lab + + # get assignments + if assignments is None: + assignments = dict() + for obj in curr_rp: + old_lab = obj.label + new_lab = tracked_lab[obj.slice][obj.image][0] + assignments[old_lab] = new_lab + + if dont_return_tracked_lab: + return assignments + return tracked_lab, assignments def clearAssignedObjsSecondStep(self): posData = self.data[self.pos_i] @@ -28462,37 +28769,32 @@ def trackSubsetIDs(self, subsetIDs: Iterable[int]): posData = self.data[self.pos_i] if posData.frame_i == 0: return - - subsetLab = np.zeros_like(posData.lab) - for subsetID in subsetIDs: - subsetLab[posData.lab == subsetID] = subsetID prev_lab = posData.allData_li[posData.frame_i-1]['labels'] prev_rp = posData.allData_li[posData.frame_i-1]['regionprops'] - tracked_lab = self.trackFrame( + assignments = self.trackFrame( prev_lab, prev_rp, posData.lab, posData.rp, posData.IDs, - assign_unique_new_IDs=True - ) - doUpdateRp = False - for subsetID in subsetIDs: - subsetIDmask = posData.lab == subsetID - trackedID = tracked_lab[subsetIDmask][0] - if trackedID == subsetID: - continue - - is_manually_edited = False - for y, x, new_ID in posData.editID_info: - if new_ID == subsetID: + assign_unique_new_IDs=True, specific_IDs=subsetIDs, + dont_return_tracked_lab=True + ) + # I think assignments already avoids merging + assignments_new = dict() + for old_ID, new_ID in assignments.items(): + # get "old" id based on assignments + if old_ID == new_ID: + continue # nothing to do + + for y, x, editID in posData.editID_info: + if editID == old_ID or editID == 
new_ID: # Do not track because it was manually edited - break + continue - posData.lab[subsetIDmask] = tracked_lab[subsetIDmask] - doUpdateRp = True - - if not doUpdateRp: - return + + obj = posData.rp.get_obj_from_ID(old_ID) # pr is still old, so we need to get the old ID + posData.lab[obj.slice][obj.image] = new_ID + assignments_new[old_ID] = new_ID # old ID : new tracked ID - self.update_rp() + self.update_rp(assignments=assignments_new) def doSkipTracking(self, against_next: bool, enforce: bool): if self.isSnapshot: @@ -28541,13 +28843,13 @@ def tracking( storeUndo=False, prev_lab=None, prev_rp=None, return_lab=False, assign_unique_new_IDs=True, separateByLabel=True, wl_update=True, - IDs=None, against_next=False, + against_next=False, specific_IDs=None , return_assignments=False ): posData = self.data[self.pos_i] - + return_tuple = (None, None) if return_assignments and return_lab else None if self.doSkipTracking(against_next, enforce): self.setLostNewOldPrevIDs() - return + return return_tuple """Tracking starts here""" staturBarLabelText = self.statusBarLabel.text() @@ -28581,41 +28883,54 @@ def tracking( if posData.frame_i < self.get_last_tracked_i(): unique_ID = self.setBrushID(return_val=True) - tracked_lab = self.trackFrame( + tracked_lab, assignments = self.trackFrame( prev_lab, prev_rp, posData.lab, posData.rp, posData.IDs, - assign_unique_new_IDs=assign_unique_new_IDs, IDs=IDs, - unique_ID=unique_ID + assign_unique_new_IDs=assign_unique_new_IDs, + unique_ID=unique_ID, specific_IDs=specific_IDs, + return_assignments=True ) if DoManualEdit: # Correct tracking with manually changed IDs - rp = skimage.measure.regionprops(tracked_lab) - IDs = [obj.label for obj in rp] - self.manuallyEditTracking(tracked_lab, IDs) + tracked_lab, assignments = self.manuallyEditTracking(tracked_lab, assignments) if return_lab: QTimer.singleShot(50, partial( self.statusBarLabel.setText, staturBarLabelText )) + if return_assignments: + return tracked_lab, assignments return 
tracked_lab # Update labels, regionprops and determine new and lost IDs posData.lab = tracked_lab - self.update_rp(wl_update=wl_update, ) + self.update_rp(wl_update=wl_update, assignments=assignments) self.setAllTextAnnotations() QTimer.singleShot(50, partial( self.statusBarLabel.setText, staturBarLabelText )) + if return_assignments and return_lab: + return tracked_lab, assignments + elif return_assignments: + return assignments + elif return_lab: + return tracked_lab - def handleAdditionalInfoRealTimeTracker(self, prev_rp, *args): + def handleAdditionalInfoRealTimeTracker(self, prev_rp, add_info): + assignments = None if self._rtTrackerName == 'CellACDC_normal_division': - tracked_lost_IDs = args[0] + tracked_lost_IDs = add_info['mothers'] self.setTrackedLostCentroids(prev_rp, tracked_lost_IDs) + assignments = add_info['assignments'] elif self._rtTrackerName == 'CellACDC_2steps': - if args[0] is None: - return - posData = self.data[self.pos_i] - posData.acdcTracker2stepsAnnotInfo[posData.frame_i] = args[0] + assignments = add_info['assignments'] + if add_info['to_track_tracked_objs_2nd_step'] is not None: + posData = self.data[self.pos_i] + posData.acdcTracker2stepsAnnotInfo[posData.frame_i] = add_info['to_track_tracked_objs_2nd_step'] + elif self._rtTrackerName == 'Cell-ACDC': + assignments = add_info['assignments'] + + return assignments def keepOnlyNewIDAssignedObjsSecondStep(self, trackedID): posData = self.data[self.pos_i] @@ -28711,12 +29026,24 @@ def setTrackedLostCentroids(self, prev_rp, tracked_lost_IDs): """ posData = self.data[self.pos_i] frame_i = posData.frame_i + prev_lab = posData.allData_li[frame_i-1]['labels'] for obj in prev_rp: if obj.label not in tracked_lost_IDs: continue - - int_centroid = tuple([int(val) for val in obj.centroid]) + if isinstance(prev_rp, regionprops.acdcRegionprops): + ID = obj.ID + centroid = prev_rp.get_centroid(ID, exact=True) + else: + centroid = obj.centroid + int_centroid = tuple([int(val) for val in centroid]) + # 
check if centroid has right ID + if prev_lab[int_centroid] != ID: + # get closest point with the right ID + coords = obj.coords + distances = np.sqrt(np.sum((coords - centroid) ** 2, axis=1)) + closest_idx = np.argmin(distances) + int_centroid = tuple([int(val) for val in coords[closest_idx]]) try: posData.tracked_lost_centroids[frame_i].add(int_centroid) except KeyError: @@ -28770,28 +29097,59 @@ def getTrackedLostIDs(self, prev_lab=None, IDs_in_frames=None, frame_i=None): posData.trackedLostIDs = trackedLostIDs return trackedLostIDs - - def manuallyEditTracking(self, tracked_lab, allIDs): + + def manuallyEditTracking(self, tracked_lab, assignments): posData = self.data[self.pos_i] infoToRemove = [] - # Correct tracking with manually changed IDs - maxID = max(allIDs, default=1) - for y, x, new_ID in posData.editID_info: - old_ID = tracked_lab[y, x] - if old_ID == 0 or old_ID == new_ID: - infoToRemove.append((y, x, new_ID)) + + if not assignments: + return tracked_lab, assignments + + # !!! 
RP is stale so we need to reverse search for the ID + reversed_assignments = ( + {tracked_id: stale_id for stale_id, tracked_id in assignments.items()} + if assignments else {} + ) + stale_ids = set(posData.rp.IDs) + + covered_edited_IDs = set() + for y, x, edited_ID in posData.editID_info: + new_ID = assignments.get(edited_ID, edited_ID) # ID in tracked lab + if new_ID in covered_edited_IDs: + # This ID has already been edited by sawpping for example + continue + + if new_ID == 0 or new_ID == edited_ID: # edited ID is not tracked to a different ID + infoToRemove.append((y, x, edited_ID)) continue - if new_ID in allIDs: - tempID = maxID+1 - tracked_lab[tracked_lab == old_ID] = tempID - tracked_lab[tracked_lab == new_ID] = old_ID - tracked_lab[tracked_lab == tempID] = new_ID + + old_RP_ID = reversed_assignments.get(edited_ID, edited_ID) # ID pre tracking + old_obj = posData.rp.get_obj_from_ID(old_RP_ID) # obj pre tracking + + if edited_ID in stale_ids: + # a swap has been made by the user between an old ID (old_RP_ID) and a new ID (edited_ID) + new_obj = posData.rp.get_obj_from_ID(edited_ID) + tracked_lab[old_obj.slice][old_obj.image] = edited_ID + tracked_lab[new_obj.slice][new_obj.image] = old_RP_ID + # update assignemnets + assignments[old_RP_ID] = edited_ID + assignments[edited_ID] = old_RP_ID + # add the two swapped IDs + + covered_edited_IDs.add(edited_ID) + covered_edited_IDs.add(old_RP_ID) + else: - tracked_lab[tracked_lab == old_ID] = new_ID - if new_ID > maxID: - maxID = new_ID + tracked_lab[old_obj.slice][old_obj.image] = edited_ID + + assignments[old_RP_ID] = edited_ID + + covered_edited_IDs.add(edited_ID) + for info in infoToRemove: posData.editID_info.remove(info) + + return tracked_lab, assignments def warnReinitLastSegmFrame(self): current_frame_n = self.navigateScrollBar.value() @@ -29927,8 +30285,12 @@ def initRealTimeTracker(self, force=False): rtTracker = aliases[rtTracker] if rtTracker == 'Cell-ACDC': + self._rtTrackerName = 'Cell-ACDC' + 
self.realTimeTracker_kwargs = None # This is hard coded return if rtTracker == 'YeaZ': + self._rtTrackerName = 'YeaZ' + self.realTimeTracker_kwargs = None # This is hard coded return if self.isRealTimeTrackerInitialized and not force: @@ -29946,6 +30308,7 @@ def initRealTimeTracker(self, force=False): self.realTimeTracker = realTimeTracker self.track_frame_params = track_frame_params + self.realTimeTracker_kwargs = inspect.signature(self.realTimeTracker.track_frame).parameters self.logger.info(f'{rtTracker} tracker successfully initialized.') if 'image_channel_name' in self.track_frame_params: # Remove the channel name since it was already loaded in init_tracker diff --git a/cellacdc/load.py b/cellacdc/load.py index 89560f18..5466a6b1 100755 --- a/cellacdc/load.py +++ b/cellacdc/load.py @@ -20,6 +20,7 @@ import zipfile from natsort import natsorted import time +import pickle import skimage import skimage.io @@ -1300,6 +1301,7 @@ def __init__( self.loadLastEntriesMetadata() self.attempFixBasenameBug() self.non_aligned_ext = '.tif' + self.segmMetadata = None if filename_ext.endswith('aligned.npz'): for file in myutils.listdir(self.images_path): if file.endswith(f'{user_ch_name}.h5'): @@ -1640,7 +1642,10 @@ def countObjectsInSegmTimelapse(self, categories: set[str] | list[str]): for frame_i in range(len(self.segm_data)): lab = self.allData_li[frame_i]['labels'] if lab is not None: - IDsFrame = self.allData_li[frame_i]['IDs'] + if hasattr(self.allData_li[frame_i]['regionprops'], 'IDs'): + IDsFrame = self.allData_li[frame_i]['regionprops'].IDs + else: + IDsFrame = [obj.label for obj in self.allData_li[frame_i]['regionprops']] if uniqueIDsVisited is not None: uniqueIDsVisited.update(IDsFrame) @@ -1860,6 +1865,7 @@ def loadOtherFiles( new_endname='', labelBoolSegm=None, load_whitelistIDs=False, + load_segm_info_ini=False ): self.segmFound = False if load_segm_data else None self.acdc_df_found = False if load_acdc_df else None @@ -2063,6 +2069,9 @@ def loadOtherFiles( if 
load_whitelistIDs: self.loadWhitelist() + + if load_segm_info_ini: + self.readSegmMetadataIni() def checkAndFixZsliceSegmInfo(self): if not hasattr(self, 'segmInfo_df'): @@ -2316,14 +2325,14 @@ def fromTrackerToAcdcDf( rp = skimage.measure.regionprops(lab) for obj in rp: centroid = obj.centroid - yc, xc = obj.centroid[-2:] + yc, xc = centroid[-2:] acdc_df.at[(frame_i, obj.label), 'x_centroid'] = int(xc) acdc_df.at[(frame_i, obj.label), 'y_centroid'] = int(yc) if len(centroid) == 3: if 'z_centroid' not in acdc_df.columns: acdc_df['z_centroid'] = 0 - zc = obj.centroid[0] + zc = centroid[0] acdc_df.at[(frame_i, obj.label), 'z_centroid'] = int(zc) if not save: @@ -3053,6 +3062,7 @@ def buildPaths(self): self.raw_postproc_segm_path = f'{base_path}segm_raw_postproc' self.post_proc_mot_metrics = f'{base_path}post_proc_mot_metrics' self.segm_hyperparams_ini_path = f'{base_path}segm_hyperparams.ini' + self.segm_metadata_ini_path = f'{base_path}segm_metadata_data.ini' self.custom_annot_json_path = f'{base_path}custom_annot_params.json' self.custom_combine_metrics_path = ( f'{base_path}custom_combine_metrics.ini' @@ -3076,6 +3086,7 @@ def get_tracker_export_path(self, trackerName, ext): def setBlankSegmData(self, SizeT, SizeZ, SizeY, SizeX): if not hasattr(self, 'img_data'): self.segm_data = None + self.single_timepoint_size = None return Y, X = self.img_data.shape[-2:] @@ -3088,6 +3099,16 @@ def setBlankSegmData(self, SizeT, SizeZ, SizeY, SizeX): self.segm_data = np.zeros((SizeT, Y, X), int) else: self.segm_data = np.zeros((Y, X), int) + + + def getSingleTimepointSegmSize(self): + if hasattr(self, 'single_timepoint_size'): + return self.single_timepoint_size + if self.SizeT > 1: + self.single_timepoint_size = np.prod(self.segm_data.shape[1:]) + else: # not sure if time axis is present but would be 1 anyways + self.single_timepoint_size = np.prod(self.segm_data.shape) + return self.single_timepoint_size def loadAllImgPaths(self): tif_paths = [] @@ -3415,7 +3436,14 @@ def 
loadWhitelist(self): self.whitelist = whitelist.Whitelist( total_frames=self.SizeT, ) - whitelist_path = self.segm_npz_path.replace('.npz', '_whitelistIDs.json') + whitelist_path_legacy = self.segm_npz_path.replace('.npz', '_whitelistIDs.json') + segm_filename = os.path.basename(self.segm_npz_path).replace('.npz', '') + segm_add_data_folder = os.path.join(self.images_path, segm_filename) + os.makedirs(segm_add_data_folder, exist_ok=True) + whitelist_path = os.path.join(segm_add_data_folder, 'whitelistIDs.json') + if os.path.exists(whitelist_path_legacy): + # move to new path + shutil.move(whitelist_path_legacy, whitelist_path) new_centroids_path = self.segm_npz_path.replace('.npz', '_new_centroids.json') success = self.whitelist.load( whitelist_path, new_centroids_path, self.segm_data, self.allData_li, @@ -3426,7 +3454,127 @@ def loadWhitelist(self): if not success: self.whitelist = None - + def readSegmMetadataIni(self): + if not os.path.exists(self.segm_metadata_ini_path): + return None + + cp = config.ConfigParser() + cp.read(self.segm_metadata_ini_path) + # one entry for each segmentation file + self.segmMetadata = {} + for segm_file in cp.sections(): + sizeX = cp.getint(segm_file, 'sizeX', fallback=None) + sizeY = cp.getint(segm_file, 'sizeY', fallback=None) + sizeT = cp.getint(segm_file, 'SizeT', fallback=None) + sizeZ = cp.getint(segm_file, 'SizeZ', fallback=None) + is_3D = sizeZ > 1 if sizeZ is not None else False + last_modified_date = cp.get(segm_file, 'last_modified_date', fallback=None) + acdc_df_segm = cp.get(segm_file, 'acdc_df_segm', fallback=None) + acdc_df_save_date = cp.get(segm_file, 'acdc_df_save_date', fallback=None) + self.segmMetadata[segm_file] = { + 'SizeT': sizeT, + 'SizeZ': sizeZ, + 'is_3D': is_3D, + 'last_modified_date': last_modified_date, + 'acdc_df_segm': acdc_df_segm, + 'acdc_df_save_date': acdc_df_save_date, + 'sizeX': sizeX, + 'sizeY': sizeY, + } + + def saveSegmMetadataIni(self): + cp = config.ConfigParser() + for segm_file, 
metadata in self.segmMetadata.items(): + cp[segm_file] = {} + cp[segm_file]['SizeT'] = str(metadata.get('SizeT', '')) + cp[segm_file]['SizeZ'] = str(metadata.get('SizeZ', '')) + cp[segm_file]['last_modified_date'] = str(metadata.get('last_modified_date', '')) + cp[segm_file]['acdc_df_segm'] = str(metadata.get('acdc_df_segm', '')) + cp[segm_file]['sizeX'] = str(metadata.get('sizeX', '')) + cp[segm_file]['sizeY'] = str(metadata.get('sizeY', '')) + cp[segm_file]['acdc_df_save_date'] = str(metadata.get('acdc_df_save_date', '')) + + with open(self.segm_metadata_ini_path, 'w') as configfile: + cp.write(configfile) + + def updateSegmMetadata(self, segm_file=None, SizeT=None, SizeZ=None, + acdc_df_segm=None, last_modified_date=None, + sizeY=None, sizeX=None, all=False, acdc_df_save_date=None): + if segm_file is None: + segm_file = os.path.basename(self.segm_npz_path) + + if self.segmMetadata is None: + self.segmMetadata = {} + segm_metadata = self.segmMetadata.get(segm_file, {}) + if SizeT is not None or all: + if SizeT is True or SizeT is None: + SizeT = self.SizeT + segm_metadata['SizeT'] = SizeT + if SizeZ is not None or all: + if SizeZ is True or SizeZ is None: + SizeZ = self.SizeZ if self.isSegm3D else 1 + segm_metadata['SizeZ'] = SizeZ + segm_metadata['is_3D'] = SizeZ > 1 + if acdc_df_segm is not None or all: + if acdc_df_segm is True or acdc_df_segm is None: + acdc_df_segm = os.path.basename(self.acdc_output_csv_path) # for future if we allow multpiple outputs + # clear other segm metadata entries with acdc_df info to avoid confusion + for info in self.segmMetadata.values(): + if info.get('acdc_df_segm', '') == acdc_df_segm: + info['acdc_df_segm'] = None + segm_metadata['acdc_df_segm'] = acdc_df_segm + if last_modified_date is not None or all: + if last_modified_date is True or last_modified_date is None: # explicitly in this cane set curr datetime + last_modified_date = datetime.now() + segm_metadata['last_modified_date'] = last_modified_date + if sizeY is not None 
or all: + if sizeY is True or sizeY is None: + sizeY = self.SizeY + segm_metadata['sizeY'] = sizeY + if sizeX is not None or all: + if sizeX is True or sizeX is None: + sizeX = self.SizeX + segm_metadata['sizeX'] = sizeX + if acdc_df_save_date is not None or all: + if acdc_df_save_date is True or acdc_df_save_date is None: + acdc_df_save_date = datetime.now() + segm_metadata['acdc_df_save_date'] = acdc_df_save_date + self.segmMetadata[segm_file] = segm_metadata + + def saveCentroidsIDs(self): + centroids_mappers = dict() + centroids_IDs_exact = dict() + # IDs = dict() + # ID_to_idx = dict() + for i, data_dict in enumerate(self.allData_li): + rp = data_dict.get('regionprops', None) + if rp is None: + continue + centroids_mappers[i] = rp._centroid_mapper + centroids_IDs_exact[i] = rp._centroid_IDs_exact + # IDs[i] = rp.IDs + # ID_to_idx[i] = rp.ID_to_idx + + segm_filename = os.path.basename(self.segm_npz_path).replace('.npz', '') + segm_add_data_folder = os.path.join(self.images_path, segm_filename) + os.makedirs(segm_add_data_folder, exist_ok=True) + centroids_path = os.path.join(segm_add_data_folder, 'centroids.pkl') + # IDs_path = os.path.join(segm_add_data_folder, 'IDs.pkl') + centroids_IDs_exact_path = os.path.join(segm_add_data_folder, 'centroids_IDs_exact.pkl') + # ID_to_idx_path = os.path.join(segm_add_data_folder, 'ID_to_idx.pkl') + + with open(centroids_path, 'wb') as f: + pickle.dump(centroids_mappers, f) + + with open(centroids_IDs_exact_path, 'wb') as f: + pickle.dump(centroids_IDs_exact, f) + + # with open(IDs_path, 'wb') as f: + # pickle.dump(IDs, f) + + # with open(ID_to_idx_path, 'wb') as f: + # pickle.dump(ID_to_idx, f) + class select_exp_folder: def __init__(self): self.exp_path = None diff --git a/cellacdc/myutils.py b/cellacdc/myutils.py index 03f50f87..1d446d96 100644 --- a/cellacdc/myutils.py +++ b/cellacdc/myutils.py @@ -55,6 +55,7 @@ from . import urls from . import qrc_resources_path from . import settings_folderpath +from . 
import regionprops from .models._cellpose_base import min_target_versions_cp if GUI_INSTALLED: @@ -1101,22 +1102,74 @@ def get_chname_from_basename(filename, basename, remove_ext=True): chName = chName[:aligned_idx] return chName +def _edge_ids_2d(lab): + border_labels = np.r_[ + lab[0, :], + lab[-1, :], + lab[:, 0], + lab[:, -1], + ] + return np.unique(border_labels[border_labels != 0]) + +def _edge_ids_3d(lab): + face_labels = np.r_[ + lab[ 0, :, :].ravel(), # z min + lab[-1, :, :].ravel(), # z max + lab[:, 0, :].ravel(), # y min + lab[:, -1, :].ravel(), # y max + lab[:, :, 0].ravel(), # x min + lab[:, :, -1].ravel(), # x max + ] + ids = np.unique(face_labels) + return ids[ids != 0] + +def get_edge_ids(lab): + if lab.ndim == 2: + return _edge_ids_2d(lab) + elif lab.ndim == 3: + return _edge_ids_3d(lab) + else: + raise ValueError('Label array must be either 2D or 3D.') + +def clear_border(lab, return_edge_ids=False): + # probably faster than skimage since it avoids relabeling... + # assumes continous unique IDs, which we have. Modifies inplace! 
+ edge_ids = get_edge_ids(lab) + lab[np.isin(lab, edge_ids)] = 0 + if return_edge_ids: + return edge_ids + def getBaseAcdcDf(rp): zeros_list = [0]*len(rp) nones_list = [None]*len(rp) minus1_list = [-1]*len(rp) IDs = [] - xx_centroid = [] - yy_centroid = [] - zz_centroid = [] - for obj in rp: - xc, yc = obj.centroid[-2:] - IDs.append(obj.label) - xx_centroid.append(xc) - yy_centroid.append(yc) - if len(obj.centroid) == 3: - zc = obj.centroid[0] - zz_centroid.append(zc) + xx_centroid = [0]*len(rp) + yy_centroid = [0]*len(rp) + zz_centroid = [0]*len(rp) + + if isinstance(rp, regionprops.acdcRegionprops): + for obj in rp: + ID = obj.label + centroid = rp.get_centroid(ID, exact=True) + xc, yc = centroid[-2:] + IDs.append(ID) + xx_centroid.append(xc) + yy_centroid.append(yc) + if len(centroid) == 3: + zc = centroid[0] + zz_centroid.append(zc) + + else: + for obj in rp: + centroid = obj.centroid + xc, yc = centroid[-2:] + IDs.append(obj.label) + xx_centroid.append(xc) + yy_centroid.append(yc) + if len(centroid) == 3: + zc = centroid[0] + zz_centroid.append(zc) df = pd.DataFrame( { @@ -5032,7 +5085,6 @@ def get_empty_stored_data_dict(): 'delROIs_info': { 'rois': [], 'delMasks': [], 'delIDsROI': [], 'state': [] }, - 'IDs': [], 'manually_edited_lab': {'lab': {}, 'zoom_slice': None} } @@ -5414,7 +5466,7 @@ def find_distances_ID(rps, point=None, ID=None): if ID is not None and point is None: try: - point = [rp.centroid for rp in rps if rp.label == ID][0] + point = rps.get_centroid(ID) except IndexError: raise ValueError(f'ID {ID} not found in regionprops (list of cells).') @@ -5426,7 +5478,7 @@ def find_distances_ID(rps, point=None, ID=None): point = point[::-1] # rp are in (y, x) format (or (z, y, x) for 3D data) so I need to reverse order point = np.array([point]) - centroids = np.array([rp.centroid for rp in rps]) + centroids = np.array([rps.get_centroid(ID) for ID in rps.IDs]) diff = point[:, np.newaxis] - centroids dist_matrix = np.linalg.norm(diff, axis=2) return 
dist_matrix @@ -5463,7 +5515,7 @@ def sort_IDs_dist(rps, point=None, ID=None): """ if ID is not None and point is None: try: - point = [rp.centroid for rp in rps if rp.label == ID][0] + point = rps.get_centroid(ID) except IndexError: raise ValueError(f'ID {ID} not found in regionprops (list of cells).') @@ -5474,7 +5526,7 @@ def sort_IDs_dist(rps, point=None, ID=None): raise ValueError('Only one of ID or point must be provided.') - IDs = [rp.label for rp in rps] + IDs = rps.IDs if len(IDs) == 0: return [] elif len(IDs) == 1: diff --git a/cellacdc/plot.py b/cellacdc/plot.py index 3d33b3c2..652a61e9 100644 --- a/cellacdc/plot.py +++ b/cellacdc/plot.py @@ -15,6 +15,7 @@ from mpl_toolkits.axes_grid1 import make_axes_locatable import seaborn as sns +from . import debugutils from tqdm import tqdm from . import GUI_INSTALLED diff --git a/cellacdc/precompiled/__init__.py b/cellacdc/precompiled/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cellacdc/regionprops.py b/cellacdc/regionprops.py new file mode 100644 index 00000000..89eb007c --- /dev/null +++ b/cellacdc/regionprops.py @@ -0,0 +1,770 @@ +import numpy as np +from scipy import ndimage as ndi +import skimage.measure +from . import printl, debugutils +from skimage.measure._regionprops_utils import ( + _normalize_spacing, +) +import traceback as traceback + +try: + from .precompiled.regionprops_helper import find_all_objects_2D, find_all_objects_3D + _CYTHON_FIND_OBJECTS = True +except Exception: + _CYTHON_FIND_OBJECTS = False +# WARNING: Developers have already used +# 7 hrs +# to optimize this. +# In addition, implementing these optimizations in the codebase took +# 7 hrs +# Specifically the +# centroid (huge fain for 3D data) +# stuff was targeted. 
+# If you decide to try and optimize it further, please update this warning :) + +_RegionProperties = skimage.measure._regionprops.RegionProperties +_cached = skimage.measure._regionprops._cached + +# @debugutils.line_benchmark +def _acdc_regionprops_factory( + label_image, + intensity_image=None, + cache=True, + *, + extra_properties=None, + spacing=None, + offset=None, + ): + if label_image.ndim not in (2, 3): + raise TypeError('Only 2-D and 3-D images supported.') + + if not np.issubdtype(label_image.dtype, np.integer): + if np.issubdtype(label_image.dtype, bool): + raise TypeError( + 'Non-integer image types are ambiguous: ' + 'use skimage.measure.label to label the connected ' + 'components of label_image, ' + 'or label_image.astype(np.uint8) to interpret ' + 'the True values as a single label.' + ) + raise TypeError('Non-integer label_image types are ambiguous') + + if offset is None: + offset_arr = np.zeros((label_image.ndim,), dtype=int) + else: + offset_arr = np.asarray(offset) + if offset_arr.ndim != 1 or offset_arr.size != label_image.ndim: + raise ValueError( + 'Offset should be an array-like of integers ' + 'of shape (label_image.ndim,); ' + f'{offset} was provided.' 
+ ) + + regions = [] + if _CYTHON_FIND_OBJECTS: + img_uint32 = label_image.astype(np.uint32, copy=False) + if label_image.ndim == 2: + labels, bboxes = find_all_objects_2D(img_uint32) + for i in range(len(labels)): + sl = (slice(int(bboxes[i, 0]), int(bboxes[i, 1])), + slice(int(bboxes[i, 2]), int(bboxes[i, 3]))) + regions.append(acdcRegionProperties( + sl, int(labels[i]), label_image, intensity_image, cache, + spacing=spacing, extra_properties=extra_properties, + offset=offset_arr, + )) + else: + labels, bboxes = find_all_objects_3D(img_uint32) + for i in range(len(labels)): + sl = (slice(int(bboxes[i, 0]), int(bboxes[i, 1])), + slice(int(bboxes[i, 2]), int(bboxes[i, 3])), + slice(int(bboxes[i, 4]), int(bboxes[i, 5]))) + regions.append(acdcRegionProperties( + sl, int(labels[i]), label_image, intensity_image, cache, + spacing=spacing, extra_properties=extra_properties, + offset=offset_arr, + )) + else: + objects = ndi.find_objects(label_image) + for i, sl in enumerate(objects, start=1): + if sl is None: + continue + regions.append(acdcRegionProperties( + sl, i, label_image, intensity_image, cache, + spacing=spacing, extra_properties=extra_properties, + offset=offset_arr, + )) + return regions + + +# class acdcRegionProperties(_RegionProperties): +# def __init__( +# self, +# slice, +# label, +# label_image, +# intensity_image, +# cache_active, +# *, +# extra_properties=None, +# spacing=None, +# offset=None, +# ): +# if intensity_image is not None: +# ndim = label_image.ndim +# if not ( +# intensity_image.shape[:ndim] == label_image.shape +# and intensity_image.ndim in [ndim, ndim + 1] +# ): +# raise ValueError( +# 'Label and intensity image shapes must match,' +# ' except for channel (last) axis.' 
+# ) +# multichannel = label_image.shape < intensity_image.shape +# else: +# multichannel = False + +# self.label = label +# if offset is None: +# offset = np.zeros((label_image.ndim,), dtype=int) +# self._offset = np.array(offset) + +# self._slice = slice + +# self._label_image = label_image +# self._intensity_image = intensity_image + +# self._cache_active = cache_active +# self._cache = {} +# self._ndim = label_image.ndim +# self._multichannel = multichannel +# self._spatial_axes = tuple(range(self._ndim)) +# if spacing is None: +# spacing = np.full(self._ndim, 1.0) +# self._spacing = _normalize_spacing(spacing, self._ndim) +# self._pixel_area = np.prod(self._spacing) + +# self._extra_properties = {} +# if extra_properties is not None: +# for func in extra_properties: +# name = func.__name__ +# if hasattr(self, name): +# msg = ( +# f"Extra property '{name}' is shadowed by existing " +# f"property and will be inaccessible. Consider " +# f"renaming it." +# ) +# self._extra_properties = {func.__name__: func for func in extra_properties} +class acdcRegionProperties(_RegionProperties): + def __init__( + self, + slice, + label, + label_image, + intensity_image, + cache_active, + *, + extra_properties=None, + spacing=None, + offset=None, + ): + super().__init__( + slice, label, label_image, intensity_image, cache_active, + extra_properties=extra_properties, spacing=spacing, offset=offset + ) + # @property + # @_cached + # def slice(self): + # # scale slice with offset + # return tuple( + # slice(self._slice[i].start + self._offset[i], + # self._slice[i].stop + self._offset[i]) + # for i in range(self._ndim) + # ) + + @property + @_cached + def bbox(self): + """ + Returns + ------- + A tuple of the bounding box's start coordinates for each dimension, + followed by the end coordinates for each dimension. 
+ """ + return tuple( + [self.slice[i].start for i in range(self._ndim)] + + [self.slice[i].stop for i in range(self._ndim)] + ) + + @property + @_cached # slow for 3D data, better cache it + def centroid(self): + return super().centroid + + @property + @_cached + def contour(self): + pass + + # @property + # def centroid_weighted(self): + # ctr = self.centroid_weighted_local + # return tuple( + # idx + slc.start * spc + # for idx, slc, spc in zip(ctr, self._slice, self._spacing) + # ) + + # @property + # @_cached + # def image_intensity(self): + # if self._intensity_image is None: + # raise AttributeError('No intensity image specified.') + # image = ( + # self.image + # if not self._multichannel + # else np.expand_dims(self.image, self._ndim) + # ) + # return self._intensity_image[self._slice] * image + + # @property + # def coords(self): + # indices = np.argwhere(self.image) + # object_offset = np.array([self._slice[i].start for i in range(self._ndim)]) + # return object_offset + indices + self._offset + + # @property + # def coords_scaled(self): + # indices = np.argwhere(self.image) + # object_offset = np.array([self._slice[i].start for i in range(self._ndim)]) + # return (object_offset + indices) * self._spacing + self._offset + + +class acdcRegionprops: + def __init__( + self, + lab, + acdc_df=None, + centroids_loaded=None, + IDs_loaded=None, + centroids_IDs_exact_loaded=None, + ID_to_idx_loaded=None, + precache_centroids=True, + **kwargs, + ): + self.lab = lab + self.acdc_df = acdc_df + self._rp = _acdc_regionprops_factory(lab, **kwargs) + self.is3D = self.lab.ndim == 3 + self._centroid_mapper = {} + self._centroid_IDs_exact = set() + if IDs_loaded is None or ID_to_idx_loaded is None: + self.set_attributes(update_centroid_mapper=False) + else: + self.ID_to_idx = ID_to_idx_loaded + self.IDs_set = set(IDs_loaded) + self.IDs = list(self.IDs_set) + + if centroids_IDs_exact_loaded is not None and centroids_loaded is not None: + self._centroid_mapper = 
centroids_loaded + self._centroid_IDs_exact = set(centroids_IDs_exact_loaded) + elif precache_centroids: + self.precache_centroids() + + else: + self._centroid_mapper = dict() + + def __iter__(self): + return iter(self._rp) + + def __len__(self): + return len(self._rp) + + def __getitem__(self, idx): + return self._rp[idx] + + def __setitem__(self, idx, value): + self._rp[idx] = value + + def __repr__(self): + return repr(self._rp) + + def _get_centroid_df_from_df(self): + if self.acdc_df is None or len(self.acdc_df) == 0: + return {} + + centroid_cols = ['y_centroid', 'x_centroid'] + if self.is3D and 'z_centroid' in self.acdc_df.columns: + centroid_cols = ['z_centroid', 'y_centroid', 'x_centroid'] + + if not set(centroid_cols).issubset(self.acdc_df.columns): + return {} + + if 'Cell_ID' in self.acdc_df.columns: + centroid_df = self.acdc_df.set_index('Cell_ID')[centroid_cols] + elif 'ID' in self.acdc_df.columns: + centroid_df = self.acdc_df.set_index('ID')[centroid_cols] + else: + centroid_df = self.acdc_df[centroid_cols] + + return { + int(ID): tuple(values) + for ID, values in centroid_df.iterrows() + } + + def _get_bbox_centers_mapper( + self, objs=None, IDs_to_include=None, IDs_to_exclude=None + ): + if objs is None and not self._rp: + return {} + + if objs is None: + if IDs_to_include is None: + IDs_to_include = ( + self.IDs_set.difference(IDs_to_exclude) + if IDs_to_exclude is not None else self.IDs_set + ) + ids = set(IDs_to_include) + objs = [obj for obj in self._rp if obj.label in ids] + + if not objs: + return {} + + ndim = 2 if not self.is3D else 3 + labels = np.empty(len(objs), dtype=int) + bboxes = np.empty((len(objs), ndim * 2), dtype=float) + for i, obj in enumerate(objs): + labels[i] = obj.label + bboxes[i] = obj.bbox + + centers = (bboxes[:, :ndim] + bboxes[:, ndim:]) / 2.0 + return { + int(label): tuple(center) + for label, center in zip(labels, centers) + } + + def precache_centroids(self): + centroid_df = self._get_centroid_df_from_df() + 
IDs_from_df = set(centroid_df) + IDs_missing_centroid = self.IDs_set.difference(IDs_from_df) + bbox_centers_mapper = self._get_bbox_centers_mapper( + IDs_to_include=IDs_missing_centroid + ) + self._centroid_mapper = {**bbox_centers_mapper, **centroid_df} + self._centroid_IDs_exact = IDs_from_df + + def set_attributes(self, deleted_IDs=None, update_centroid_mapper=True): + self.ID_to_idx = {obj.label: idx for idx, obj in enumerate(self._rp)} + # Update IDs and IDs_set separately and explicitly + self.IDs_set = set(self.ID_to_idx) + self.IDs = list(self.IDs_set) + + + if not update_centroid_mapper: + return + if deleted_IDs is not None: + for ID in deleted_IDs: + self._centroid_mapper.pop(ID, None) + self._centroid_IDs_exact.discard(ID) + else: + self._centroid_mapper = { + ID: centroid + for ID, centroid in self._centroid_mapper.items() + if ID in self.IDs_set + } + self._centroid_IDs_exact.intersection_update(self.IDs_set) + + def get_obj_from_ID(self, ID, warn=True): + idx = self.ID_to_idx.get(ID, None) + if idx is not None: + return self._rp[idx] + else: + if warn: + # get caller info + debugutils.print_call_stack() + printl(f"Warning: Object with ID {ID} not found in regionprops.") + return None + + def delete_IDs(self, IDs_to_delete: set[int], update_other_attrs=True): + if not IDs_to_delete: + return + + self._rp = [ + obj for obj in self._rp if obj.label not in IDs_to_delete + ] + + if not update_other_attrs: + return + self.set_attributes(deleted_IDs=IDs_to_delete) + + def _get_IDs_to_update_centroids( + self, lab, objs, specific_IDs_update_centroids=None + ): + if specific_IDs_update_centroids is not None: + return set(specific_IDs_update_centroids) + + obj_to_update = set() + for obj in objs: + has_to_update = False + ID = obj.label + old_centroid = self._centroid_mapper.get(ID, None) + if old_centroid is not None: + rounded_centroid = tuple(np.round(old_centroid).astype(int)) + try: + ID_lab = lab[rounded_centroid] + except Exception: + has_to_update = 
True + else: + if ID_lab != ID: + has_to_update = True + else: + has_to_update = True + + if has_to_update: + obj_to_update.add(ID) + + return obj_to_update + + def update_regionprops( + self, lab, specific_IDs_update_centroids=None, + update_centroids=True + ): + old_rp_by_id = {obj.label: obj for obj in self._rp} + + new_rp = _acdc_regionprops_factory(lab) + + if update_centroids: + # Verify that the cached centroid is still inside the object mask. + obj_to_update = self._get_IDs_to_update_centroids( + lab, new_rp, + specific_IDs_update_centroids=specific_IDs_update_centroids + ) + + bbox_centers_mapper = self._get_bbox_centers_mapper( + objs=[obj for obj in new_rp if obj.label in obj_to_update] + ) + + # update centroids + self._centroid_mapper.update(bbox_centers_mapper) + + # remove from exact set if we updated the centroid + self._centroid_IDs_exact.difference_update(obj_to_update) + + for obj in new_rp: + self._copy_custom_rp_attributes(obj, old_rp_by_id.get(obj.label)) + + self._rp = new_rp + self.lab = lab + self.set_attributes() + + def _copy_custom_rp_attributes(self, new_obj, old_obj): + if old_obj is None: + return + new_obj.dead = getattr(old_obj, 'dead', False) + new_obj.excluded = getattr(old_obj, 'excluded', False) + + def _get_bbox_slices(self, bbox): + ndim = self.lab.ndim + if len(bbox) != ndim * 2: + raise ValueError( + f'Expected a bounding box with {ndim*2} values, ' + f'got {len(bbox)}.' 
+ ) + return tuple( + slice(int(bbox[dim]), int(bbox[dim+ndim])) for dim in range(ndim) + ) + + def _translate_cutout_regionprop(self, obj, offset, lab): + offset_arr = np.asarray(offset) + centroid = obj.centroid + translated_slice = tuple( + slice( + obj._slice[dim].start + offset_arr[dim], + obj._slice[dim].stop + offset_arr[dim], + ) + for dim in range(obj._ndim) + ) + translated_bbox = tuple( + [slc.start for slc in translated_slice] + + [slc.stop for slc in translated_slice] + ) + translated_centroid = tuple( + coord + offset_arr[dim] + for dim, coord in enumerate(centroid) + ) + + obj._label_image = lab + obj._slice = translated_slice + obj.slice = translated_slice + obj._offset = np.zeros_like(offset_arr) + obj._cache['slice'] = translated_slice + obj._cache['bbox'] = translated_bbox + obj._cache['centroid'] = translated_centroid + return obj + + def _get_separate_obj_regionprops(self, lab, IDs): + IDs = tuple(int(ID) for ID in IDs) + if not IDs: + return {} + + mask = np.isin(lab, IDs) + if not np.any(mask): + return {} + + isolated_lab = np.zeros_like(lab) + isolated_lab[mask] = lab[mask] + return { + obj.label: obj + for obj in _acdc_regionprops_factory(isolated_lab) + if obj.label in IDs + } + + def _is_bbox_touching_cutout_border(self, bbox, shape): + ndim = len(shape) + for dim in range(ndim): + if bbox[dim] == 0 or bbox[dim+ndim] == shape[dim]: + return True + return False + + def _obj_intersects_bbox(self, obj, bbox): + ndim = self.lab.ndim + obj_bbox = obj.bbox + for dim in range(ndim): + start = max(int(obj_bbox[dim]), int(bbox[dim])) + stop = min(int(obj_bbox[dim+ndim]), int(bbox[dim+ndim])) + if start >= stop: + return False + + return True + + def _get_old_cutout_IDs_from_rp(self, cutout_bbox): + return { + obj.label for obj in self._rp + if self._obj_intersects_bbox(obj, cutout_bbox) + } + + def _set_label_image(self, lab, objs=None, clear_cache=False): + self.lab = lab + if objs is None: + objs = self._rp + + for obj in objs: + 
            obj._label_image = lab
            if clear_cache:
                obj._cache.clear()

    def update_regionprops_via_assignments(
        self, assignments:dict[int, int], lab
    ):
        """If the lab is completely the same, but only ID changes/swaps have been made

        Parameters
        ----------
        assignments : dict[int, int]
            key: old ID,
            value: new ID
        lab : np.ndarray
            Updated label image. Regionprops objects are rebound to this image
            so properties such as ``image`` stay consistent after the ID remap.
        """
        active_assignments = {
            int(old_ID): int(new_ID)
            for old_ID, new_ID in assignments.items()
            if old_ID in self.IDs_set and old_ID != new_ID
        }
        if not active_assignments:
            self._set_label_image(lab)
            return

        # if not active_assignments:
        #     if lab is not None:
        #         self._set_label_image(lab)
        #     return

        # remapped_IDs = set()
        # for obj in self._rp:
        #     old_ID = obj.label
        #     new_ID = active_assignments.get(old_ID, old_ID)
        #     if new_ID in remapped_IDs:
        #         raise ValueError(
        #             'Assignments would create duplicate IDs in regionprops. '
        #             'Use a full regionprops recomputation for merges.'
+ # ) + # remapped_IDs.add(new_ID) + + centroid_mapper = { + active_assignments.get(ID, ID): centroid + for ID, centroid in self._centroid_mapper.items() + # if active_assignments.get(ID, ID) in remapped_IDs + } + centroid_IDs_exact = { + active_assignments.get(ID, ID) + for ID in self._centroid_IDs_exact + # if active_assignments.get(ID, ID) in remapped_IDs + } + + for obj in self._rp: + old_ID = obj.label + new_ID = active_assignments.get(old_ID, old_ID) + obj.label = new_ID + # if obj.area == 0: + # # if area is 0, centroid is not defined and we should not trust the cached one + # print("area 0...") + + self._set_label_image(lab, clear_cache=True) + + self._centroid_mapper = centroid_mapper + self._centroid_IDs_exact = centroid_IDs_exact + self.set_attributes(update_centroid_mapper=False) # update the mapper + + def update_regionprops_via_deletions( + self, IDs_to_delete: set[int] + ): + """If the lab is completely the same, but only some IDs have been deleted + + Parameters + ---------- + IDs_to_delete : set[int] + IDs to delete + """ + IDs_to_delete = set(IDs_to_delete).intersection(self.IDs_set) + if not IDs_to_delete: + return + self._rp = [obj for obj in self._rp if obj.label not in IDs_to_delete] + self.set_attributes(deleted_IDs=IDs_to_delete) # for updating the IDs to indx, centroid mapper + + def update_regionprops_via_cutout( + self, lab, cutout_bbox, specific_IDs=None, debug=True + ): + """Only relabels the regionprops of a specific cutout. + Is only faster for small cutouts. I dont have a number, but I would say + less than 30% of total image size. + + Parameters + ---------- + cutout_lab : np.ndarray + The labeled cutout image. + cutout_bbox : tuple[int, int, int, int] + The bounding box of the cutout in the format (min_row, min_col, max_row, max_col). 
+ """ + if specific_IDs is not None and not isinstance(specific_IDs, (list, set, np.ndarray, tuple)): + specific_IDs = {specific_IDs} + elif specific_IDs is not None: + specific_IDs = set(specific_IDs) + + self.lab = lab + cutout_slices = self._get_bbox_slices(cutout_bbox) + new_cutout = lab[cutout_slices] + old_cutout_IDs = self._get_old_cutout_IDs_from_rp(cutout_bbox) + rp_cutout_new = _acdc_regionprops_factory(new_cutout) + new_cutout_IDs = set(obj.label for obj in rp_cutout_new) + + if not old_cutout_IDs and not new_cutout_IDs: + return + + target_IDs = ( + old_cutout_IDs.union(new_cutout_IDs) + if specific_IDs is None + else old_cutout_IDs.union(new_cutout_IDs).intersection(specific_IDs) + ) + + deleted_target_IDs = old_cutout_IDs.difference(new_cutout_IDs).intersection( + target_IDs + ) + + refreshed_IDs = new_cutout_IDs.intersection(target_IDs) + + conflicting_IDs = refreshed_IDs.difference(old_cutout_IDs).intersection( + self.IDs_set.difference(old_cutout_IDs) + ) + if conflicting_IDs: + raise ValueError( + 'Cutout update would reuse IDs that already belong to objects ' + 'outside the cutout. Use a full regionprops recomputation.' 
+ ) + + old_rp_by_id = {obj.label: obj for obj in self._rp} + IDs_to_replace = old_cutout_IDs.intersection(target_IDs) + unaffected_rp = [obj for obj in self._rp if obj.label not in IDs_to_replace] + + offset = tuple(s.start for s in cutout_slices) + + border_touching_IDs = { + obj.label + for obj in rp_cutout_new + if obj.label in refreshed_IDs + and self._is_bbox_touching_cutout_border(obj.bbox, new_cutout.shape) + } + separate_objs = self._get_separate_obj_regionprops(lab, border_touching_IDs) + + new_objs = [] + updated_centroid_IDs = set() + for obj in rp_cutout_new: + ID = obj.label + if ID not in refreshed_IDs: + continue + if ID in border_touching_IDs: + # edge case: ID changed is outside the cutout + new_obj = separate_objs.get(ID) + if new_obj is None: + continue + else: + new_obj = self._translate_cutout_regionprop(obj, offset, lab) + + self._copy_custom_rp_attributes(new_obj, old_rp_by_id.get(ID)) + new_objs.append(new_obj) + updated_centroid_IDs.add(ID) + + for ID in deleted_target_IDs: + self._centroid_mapper.pop(ID, None) + self._centroid_IDs_exact.discard(ID) + + if updated_centroid_IDs: + obj_to_update = self._get_IDs_to_update_centroids( + lab, new_objs, + specific_IDs_update_centroids=updated_centroid_IDs + ) + + self._centroid_mapper.update( + self._get_bbox_centers_mapper( + objs=[obj for obj in new_objs if obj.label in obj_to_update] + ) + ) + self._centroid_IDs_exact.difference_update(obj_to_update) + + self._rp = unaffected_rp + new_objs + self._set_label_image(lab) + self.set_attributes(update_centroid_mapper=False) + + def get_centroid(self, ID, exact=False): + if exact and ID not in self._centroid_IDs_exact: + obj = self.get_obj_from_ID(ID) + centroid = obj.centroid + try: + int(centroid[0]) + except (TypeError, ValueError): + print(f"Warning: Centroid for ID {ID} is not a valid coordinate: {centroid}. " + f"Object size: {obj.bbox}. 
Returning None.") + return None + self._centroid_mapper[ID] = centroid + self._centroid_IDs_exact.add(ID) + return centroid + + centroid = self._centroid_mapper.get(ID, None) + if centroid is None: + # add centroid to mapper if not found + objs = [self.get_obj_from_ID(ID)] + bbox_centers_mapper = self._get_bbox_centers_mapper(objs=objs) + self._centroid_mapper.update(bbox_centers_mapper) + centroid = self._centroid_mapper.get(ID, None) + return centroid + + def copy(self): + new_instance = acdcRegionprops( + self.lab, precache_centroids=False + ) + new_instance._rp = [obj for obj in self._rp] + new_instance._centroid_mapper = self._centroid_mapper.copy() + new_instance._centroid_IDs_exact = self._centroid_IDs_exact.copy() + new_instance.set_attributes(update_centroid_mapper=False) + return new_instance \ No newline at end of file diff --git a/cellacdc/regionprops_helper.pyx b/cellacdc/regionprops_helper.pyx new file mode 100644 index 00000000..48a70bab --- /dev/null +++ b/cellacdc/regionprops_helper.pyx @@ -0,0 +1,118 @@ +# regionprops_helper.pyx +# cython: boundscheck=False, wraparound=False, cdivision=True +import numpy as np +cimport numpy as np +from libc.limits cimport UINT_MAX + +def find_all_objects_2D(np.uint32_t[:, :] label_img): + cdef Py_ssize_t n_rows = label_img.shape[0] + cdef Py_ssize_t n_cols = label_img.shape[1] + cdef Py_ssize_t i, j + cdef unsigned int label, max_label = 0 + + # First pass: find max label to allocate C arrays + for i in range(n_rows): + for j in range(n_cols): + label = label_img[i, j] + if label > max_label: + max_label = label + + if max_label == 0: + return [] + + cdef np.ndarray[np.uint32_t, ndim=1] _rs = np.full(max_label + 1, UINT_MAX, dtype=np.uint32) + cdef np.ndarray[np.uint32_t, ndim=1] _re = np.zeros(max_label + 1, dtype=np.uint32) + cdef np.ndarray[np.uint32_t, ndim=1] _cs = np.full(max_label + 1, UINT_MAX, dtype=np.uint32) + cdef np.ndarray[np.uint32_t, ndim=1] _ce = np.zeros(max_label + 1, dtype=np.uint32) + + cdef 
unsigned int[:] rs = _rs, re = _re, cs = _cs, ce = _ce + + # Second pass: compute bounding boxes without Python objects in the hot loop + for i in range(n_rows): + for j in range(n_cols): + label = label_img[i, j] + if label > 0: + if i < rs[label]: rs[label] = i + if i + 1 > re[label]: re[label] = (i + 1) + if j < cs[label]: cs[label] = j + if j + 1 > ce[label]: ce[label] = (j + 1) + + # Collect present labels into compact numpy arrays (avoids per-label tuple allocation) + cdef unsigned int n_labels = 0 + for lbl in range(1, max_label + 1): + if re[lbl] != 0: + n_labels += 1 + + cdef np.ndarray[np.uint32_t, ndim=1] out_labels = np.empty(n_labels, dtype=np.uint32) + cdef np.ndarray[np.uint32_t, ndim=2] out_bboxes = np.empty((n_labels, 4), dtype=np.uint32) + cdef unsigned int idx = 0 + for lbl in range(1, max_label + 1): + if re[lbl] != 0: + out_labels[idx] = lbl + out_bboxes[idx, 0] = rs[lbl] + out_bboxes[idx, 1] = re[lbl] + out_bboxes[idx, 2] = cs[lbl] + out_bboxes[idx, 3] = ce[lbl] + idx += 1 + return out_labels, out_bboxes + +def find_all_objects_3D(np.uint32_t[:, :, :] label_img): + cdef Py_ssize_t n_z = label_img.shape[0] + cdef Py_ssize_t n_rows = label_img.shape[1] + cdef Py_ssize_t n_cols = label_img.shape[2] + cdef Py_ssize_t i, j, k + cdef unsigned int label, max_label = 0 + + # First pass: find max label + for i in range(n_z): + for j in range(n_rows): + for k in range(n_cols): + label = label_img[i, j, k] + if label > max_label: + max_label = label + + if max_label == 0: + return [] + + cdef np.ndarray[np.uint32_t, ndim=1] _zs = np.full(max_label + 1, UINT_MAX, dtype=np.uint32) + cdef np.ndarray[np.uint32_t, ndim=1] _ze = np.zeros(max_label + 1, dtype=np.uint32) + cdef np.ndarray[np.uint32_t, ndim=1] _rs = np.full(max_label + 1, UINT_MAX, dtype=np.uint32) + cdef np.ndarray[np.uint32_t, ndim=1] _re = np.zeros(max_label + 1, dtype=np.uint32) + cdef np.ndarray[np.uint32_t, ndim=1] _cs = np.full(max_label + 1, UINT_MAX, dtype=np.uint32) + cdef 
np.ndarray[np.uint32_t, ndim=1] _ce = np.zeros(max_label + 1, dtype=np.uint32) + + cdef unsigned int[:] zs = _zs, ze = _ze, rs = _rs, re = _re, cs = _cs, ce = _ce + + # Second pass: compute bounding boxes + for i in range(n_z): + for j in range(n_rows): + for k in range(n_cols): + label = label_img[i, j, k] + if label > 0: + if i < zs[label]: zs[label] = i + if i + 1 > ze[label]: ze[label] = (i + 1) + if j < rs[label]: rs[label] = j + if j + 1 > re[label]: re[label] = (j + 1) + if k < cs[label]: cs[label] = k + if k + 1 > ce[label]: ce[label] = (k + 1) + + # Collect present labels into compact numpy arrays (avoids per-label tuple allocation) + cdef unsigned int n_labels = 0 + for lbl in range(1, max_label + 1): + if ze[lbl] != 0: + n_labels += 1 + + cdef np.ndarray[np.uint32_t, ndim=1] out_labels = np.empty(n_labels, dtype=np.uint32) + cdef np.ndarray[np.uint32_t, ndim=2] out_bboxes = np.empty((n_labels, 6), dtype=np.uint32) + cdef unsigned int idx = 0 + for lbl in range(1, max_label + 1): + if ze[lbl] != 0: + out_labels[idx] = lbl + out_bboxes[idx, 0] = zs[lbl] + out_bboxes[idx, 1] = ze[lbl] + out_bboxes[idx, 2] = rs[lbl] + out_bboxes[idx, 3] = re[lbl] + out_bboxes[idx, 4] = cs[lbl] + out_bboxes[idx, 5] = ce[lbl] + idx += 1 + return out_labels, out_bboxes \ No newline at end of file diff --git a/cellacdc/trackers/CellACDC/CellACDC_tracker.py b/cellacdc/trackers/CellACDC/CellACDC_tracker.py index 07cfe841..2af2f33b 100755 --- a/cellacdc/trackers/CellACDC/CellACDC_tracker.py +++ b/cellacdc/trackers/CellACDC/CellACDC_tracker.py @@ -7,20 +7,57 @@ from skimage.measure import regionprops from skimage.segmentation import relabel_sequential -from cellacdc import core, printl +from cellacdc import core, printl, debugutils DEBUG = False +def _normalize_specific_IDs(specific_IDs): + if specific_IDs is None: + return None + if isinstance(specific_IDs, (list, tuple, set, np.ndarray)): + return set(specific_IDs) + return {specific_IDs} + +def _filter_subset_assignments(old_IDs, 
tracked_IDs, all_curr_IDs, specific_IDs): + if specific_IDs is None: + return old_IDs, tracked_IDs + + selected_curr_IDs = set(specific_IDs) + other_curr_IDs = set(all_curr_IDs).difference(selected_curr_IDs) + filtered_old_IDs = [] + filtered_tracked_IDs = [] + for old_ID, tracked_ID in zip(old_IDs, tracked_IDs): + if tracked_ID in other_curr_IDs: + continue + filtered_old_IDs.append(old_ID) + filtered_tracked_IDs.append(tracked_ID) + + return filtered_old_IDs, filtered_tracked_IDs + def calc_Io_matrix(lab, prev_lab, rp, prev_rp, IDs_curr_untracked=None, - denom:str='area_prev', IDs=None): + specific_IDs=None, + denom:str='area_prev'): # maybe its faster to calculate IoU not via mask but via area1 / (area1 + area2 - intersection) + specific_IDs = _normalize_specific_IDs(specific_IDs) IDs_prev = [] if IDs_curr_untracked is None: IDs_curr_untracked = [obj.label for obj in rp] + elif not isinstance(IDs_curr_untracked, list): + IDs_curr_untracked = list(IDs_curr_untracked) + + if specific_IDs is not None: + IDs_curr_untracked = [ + ID for ID in IDs_curr_untracked if ID in specific_IDs + ] - IoA_matrix = np.zeros((len(rp), len(prev_rp))) + if not IDs_curr_untracked: + return np.zeros((0, len(prev_rp))), IDs_curr_untracked, [ + obj.label for obj in prev_rp + ] + + IoA_matrix = np.zeros((len(IDs_curr_untracked), len(prev_rp))) rp_mapper = {obj.label: obj for obj in rp} - idx_mapper = {obj.label: i for i, obj in enumerate(rp)} + idx_mapper = {ID: i for i, ID in enumerate(IDs_curr_untracked)} # For each ID in previous frame get IoA with all current IDs # Rows: IDs in current frame, columns: IDs in previous frame @@ -42,6 +79,7 @@ def calc_Io_matrix(lab, prev_lab, rp, prev_rp, IDs_curr_untracked=None, for j, obj_prev in enumerate(prev_rp): ID_prev = obj_prev.label IDs_prev.append(ID_prev) + # if IDs is not None and ID_prev not in IDs: # continue @@ -60,6 +98,8 @@ def calc_Io_matrix(lab, prev_lab, rp, prev_rp, IDs_curr_untracked=None, continue if denom == 'union': + if 
intersect_ID not in rp_mapper: + continue obj_curr = rp_mapper[intersect_ID] # temp_lab[obj_prev.slice][obj_prev.image] = True # temp_lab[obj_curr.slice][obj_curr.image] = True @@ -69,7 +109,9 @@ def calc_Io_matrix(lab, prev_lab, rp, prev_rp, IDs_curr_untracked=None, if denom_val == 0: continue - idx = idx_mapper[intersect_ID] + idx = idx_mapper.get(intersect_ID) + if idx is None: + continue IoA = I/denom_val IoA_matrix[idx, j] = IoA return IoA_matrix, IDs_curr_untracked, IDs_prev @@ -77,10 +119,11 @@ def calc_Io_matrix(lab, prev_lab, rp, prev_rp, IDs_curr_untracked=None, def assign( IoA_matrix, IDs_curr_untracked, IDs_prev, IoA_thresh=0.4, aggr_track=None, IoA_thresh_aggr=0.4, daughters_list=None, - IDs=None): + specific_IDs=None): # Determine max IoA between IDs and assign tracked ID if IoA >= IoA_thresh if IoA_matrix.size == 0: return [], [] + max_IoA_col_idx = IoA_matrix.argmax(axis=1) unique_col_idx, counts = np.unique(max_IoA_col_idx, return_counts=True) counts_dict = dict(zip(unique_col_idx, counts)) @@ -187,7 +230,10 @@ def indexAssignment( remove_untracked=False, assign_unique_new_IDs=True, return_assignments=False, - IDs=None + dont_return_tracked_lab=False, + specific_IDs=None, + all_curr_IDs=None, + IDs=None, ): """Replace `old_IDs` in `lab` with `tracked_IDs` while making sure to avoid merging IDs. @@ -229,15 +275,25 @@ def indexAssignment( assignments: dict Returned only if `return_assignments` is True. """ + specific_IDs = _normalize_specific_IDs(specific_IDs) log_debugging( 'start', IDs_curr_untracked=IDs_curr_untracked, old_IDs=old_IDs ) - # Replace untracked IDs with tracked IDs and new IDs with increasing num + if all_curr_IDs is None: + all_curr_IDs = list(IDs_curr_untracked) + old_IDs, tracked_IDs = _filter_subset_assignments( + old_IDs, tracked_IDs, all_curr_IDs, specific_IDs + ) + + # Replace untracked IDs with tracked IDs and new IDs with increasing num. + # When tracking only a subset of current IDs, leave unrelated labels untouched. 
new_untracked_IDs = [ID for ID in IDs_curr_untracked if ID not in old_IDs] - tracked_lab = lab + + if not dont_return_tracked_lab: + tracked_lab = lab assignments = {} log_debugging( 'assign_unique', @@ -251,9 +307,10 @@ def indexAssignment( new_tracked_IDs = [ uniqueID+i for i in range(len(new_untracked_IDs)) ] - core.lab_replace_values( - tracked_lab, rp, new_untracked_IDs, new_tracked_IDs - ) + if not dont_return_tracked_lab: + core.lab_replace_values( + tracked_lab, rp, new_untracked_IDs, new_tracked_IDs + ) assignments.update(dict(zip(new_untracked_IDs, new_tracked_IDs))) log_debugging( 'new_untracked_and_assign_unique', @@ -271,9 +328,10 @@ def indexAssignment( new_tracked_IDs = [ uniqueID+i for i in range(len(new_IDs_in_trackedIDs)) ] - core.lab_replace_values( - tracked_lab, rp, new_IDs_in_trackedIDs, new_tracked_IDs - ) + if not dont_return_tracked_lab: + core.lab_replace_values( + tracked_lab, rp, new_IDs_in_trackedIDs, new_tracked_IDs + ) assignments.update(dict(zip(new_IDs_in_trackedIDs, new_tracked_IDs))) log_debugging( 'new_untracked_and_tracked', @@ -283,10 +341,15 @@ def indexAssignment( new_tracked_IDs=new_tracked_IDs ) if tracked_IDs: - core.lab_replace_values( - tracked_lab, rp, old_IDs, tracked_IDs, in_place=True - ) - assignments.update(dict(zip(old_IDs, tracked_IDs))) + if not dont_return_tracked_lab: + core.lab_replace_values( + tracked_lab, rp, old_IDs, tracked_IDs, in_place=True + ) + assignments.update({ + old_ID: tracked_ID + for old_ID, tracked_ID in zip(old_IDs, tracked_IDs) + if old_ID != tracked_ID + }) log_debugging( 'tracked', tracked_IDs=tracked_IDs, @@ -295,6 +358,8 @@ def indexAssignment( if not return_assignments: return tracked_lab + elif dont_return_tracked_lab: + return assignments else: return tracked_lab, assignments @@ -305,17 +370,24 @@ def track_frame( return_all=False, aggr_track=None, IoA_matrix=None, IoA_thresh_aggr=None, IDs_prev=None, return_prev_IDs=False, mother_daughters=None, denom_overlap_matrix = 'area_prev', 
- IDs=None + return_assignments=False, specific_IDs=None, dont_return_tracked_lab=False ): if not np.any(lab): # Skip empty frames return lab + all_curr_IDs = ( + list(IDs_curr_untracked) + if IDs_curr_untracked is not None else [obj.label for obj in rp] + ) + if IoA_matrix is None: - IoA_matrix, IDs_curr_untracked, IDs_prev = calc_Io_matrix( + IoA_matrix, tracked_curr_IDs, IDs_prev = calc_Io_matrix( lab, prev_lab, rp, prev_rp, IDs_curr_untracked=IDs_curr_untracked, - denom=denom_overlap_matrix, IDs=IDs + denom=denom_overlap_matrix,specific_IDs=specific_IDs, ) + else: + tracked_curr_IDs = IDs_curr_untracked daughters_list = [] if mother_daughters: @@ -323,38 +395,61 @@ def track_frame( daughters_list.extend(daughters) old_IDs, tracked_IDs = assign( - IoA_matrix, IDs_curr_untracked, IDs_prev, + IoA_matrix, tracked_curr_IDs, IDs_prev, IoA_thresh=IoA_thresh, aggr_track=aggr_track, IoA_thresh_aggr=IoA_thresh_aggr, daughters_list=daughters_list, + specific_IDs=specific_IDs, ) if posData is None and unique_ID is None: unique_ID = max( - (max(IDs_prev, default=0), max(IDs_curr_untracked, default=0)) + (max(IDs_prev, default=0), max(all_curr_IDs, default=0)) ) + 1 elif unique_ID is None: # Compute starting unique ID setBrushID_func(useCurrentLab=True) unique_ID = posData.brushID+1 - if not return_all: + if not return_all and not return_assignments: tracked_lab = indexAssignment( - old_IDs, tracked_IDs, IDs_curr_untracked, + old_IDs, tracked_IDs, tracked_curr_IDs, lab.copy(), rp, unique_ID, assign_unique_new_IDs=assign_unique_new_IDs, + specific_IDs=specific_IDs, + all_curr_IDs=all_curr_IDs, + ) + elif dont_return_tracked_lab: + assignments = indexAssignment( + old_IDs, tracked_IDs, tracked_curr_IDs, + lab.copy(), rp, unique_ID, + assign_unique_new_IDs=assign_unique_new_IDs, + return_assignments=True, specific_IDs=specific_IDs, + dont_return_tracked_lab=True, + all_curr_IDs=all_curr_IDs, ) else: tracked_lab, assignments = indexAssignment( - old_IDs, tracked_IDs, 
IDs_curr_untracked, + old_IDs, tracked_IDs, tracked_curr_IDs, lab.copy(), rp, unique_ID, assign_unique_new_IDs=assign_unique_new_IDs, - return_assignments=return_all, + return_assignments=True, specific_IDs=specific_IDs, + all_curr_IDs=all_curr_IDs, ) # old_new_ids = dict(zip(old_IDs, tracked_IDs)) # for now not used, but could be useful in the future - if return_all: + if return_all and dont_return_tracked_lab: + # special case where we want to only get the assignments but need the rest too! + return IoA_matrix, assignments, tracked_IDs + elif return_all: return tracked_lab, IoA_matrix, assignments, tracked_IDs # remove tracked_IDs and change code in CellACDC_tracker.py if causing problems + elif dont_return_tracked_lab: + return assignments + elif return_assignments: + add_info = { + 'assignments': assignments, + } + return tracked_lab, add_info else: return tracked_lab diff --git a/cellacdc/trackers/CellACDC_2steps/CellACDC_2steps_tracker.py b/cellacdc/trackers/CellACDC_2steps/CellACDC_2steps_tracker.py index 4bef3ec4..35e8b469 100644 --- a/cellacdc/trackers/CellACDC_2steps/CellACDC_2steps_tracker.py +++ b/cellacdc/trackers/CellACDC_2steps/CellACDC_2steps_tracker.py @@ -13,6 +13,27 @@ import cellacdc.core from ..CellACDC import CellACDC_tracker +from ..CellACDC.CellACDC_tracker import _normalize_specific_IDs + +from cellacdc._types import NotGUIParam + +def _format_tracking_result( + tracked_lab, + assignments, + to_track_tracked_objs_2nd_step, + return_assignments, + dont_return_tracked_lab, + ): + add_info = { + 'assignments': assignments, + 'to_track_tracked_objs_2nd_step': to_track_tracked_objs_2nd_step, + } + + if dont_return_tracked_lab: + return add_info + + return tracked_lab, add_info # no harm returning the assignments + class SearchRangeUnits: values = ['pixels', 'micrometre'] @@ -114,7 +135,10 @@ def track_frame( overlap_threshold=0.4, search_range_unit: SearchRangeUnits='pixels', lost_IDs_search_range=10, - unique_ID: Integer=None + unique_ID: 
Integer=None, + specific_IDs: NotGUIParam=None, + dont_return_tracked_lab=False, + return_assignments=False, ): """Track two consecutive frames in two steps. First step based on `overlap_threshold` and second step tracks only lost objects to new @@ -148,20 +172,30 @@ def track_frame( If not None, uses this as starting ID for all the untracked objects. If None, this will be calculated based on the two input frames. """ + specific_IDs = _normalize_specific_IDs(specific_IDs) to_track_tracked_objs_2nd_step = None prev_rp = skimage.measure.regionprops(prev_frame_lab) curr_rp = skimage.measure.regionprops(current_frame_lab) - tracked_lab_1st_step = CellACDC_tracker.track_frame( + tracked_lab_1st_step, add_info = CellACDC_tracker.track_frame( prev_frame_lab, prev_rp, current_frame_lab, curr_rp, IoA_thresh=overlap_threshold, - return_prev_IDs=False, - unique_ID=unique_ID + return_prev_IDs=False, + unique_ID=unique_ID, + specific_IDs=specific_IDs, + return_assignments=True, ) + assignments_step_1 = add_info['assignments'] + selected_tracked_IDs = None + if specific_IDs is not None: + selected_tracked_IDs = { + assignments_step_1.get(curr_ID, curr_ID) + for curr_ID in specific_IDs + } prev_rp_mapper = {obj.label: obj for obj in prev_rp} @@ -176,27 +210,43 @@ def track_frame( } if not lost_rp_mapper: - return tracked_lab_1st_step, to_track_tracked_objs_2nd_step + return _format_tracking_result( + tracked_lab_1st_step, + assignments_step_1, + to_track_tracked_objs_2nd_step, + return_assignments, + dont_return_tracked_lab, + ) new_rp_mapper = { obj.label: obj for obj in tracked_rp_1st_step + if ( + selected_tracked_IDs is None + or obj.label in selected_tracked_IDs + ) if prev_rp_mapper.get(obj.label) is None } - + if not new_rp_mapper: - return tracked_lab_1st_step, to_track_tracked_objs_2nd_step + return _format_tracking_result( + tracked_lab_1st_step, + assignments_step_1, + to_track_tracked_objs_2nd_step, + return_assignments, + dont_return_tracked_lab, + ) ndim = 
current_frame_lab.ndim lost_IDs_coords = np.zeros((len(lost_rp_mapper), ndim)) lost_IDs_idx_to_obj_mapper = {} for lost_idx, lost_obj in enumerate(lost_rp_mapper.values()): - lost_IDs_coords[lost_idx] = lost_obj.centroid + lost_IDs_coords[lost_idx] = lost_obj.centroid # we have overwritten RP so its always cached lost_IDs_idx_to_obj_mapper[lost_idx] = lost_obj new_IDs_coords = np.zeros((len(new_rp_mapper), ndim)) new_IDs_idx_to_obj_mapper = {} for new_idx, new_obj in enumerate(new_rp_mapper.values()): - new_IDs_coords[new_idx] = new_obj.centroid + new_IDs_coords[new_idx] = new_obj.centroid # we have overwritten RP so its always cached new_IDs_idx_to_obj_mapper[new_idx] = new_obj if search_range_unit == 'micrometre': @@ -229,21 +279,45 @@ def track_frame( tracked_objs_2nd_step.append(lost_IDs_idx_to_obj_mapper[i]) if not IDs_to_track: - return tracked_lab_1st_step, to_track_tracked_objs_2nd_step + return _format_tracking_result( + tracked_lab_1st_step, + assignments_step_1, + to_track_tracked_objs_2nd_step, + return_assignments, + dont_return_tracked_lab, + ) - tracked_lab_2nd_step = cellacdc.core.lab_replace_values( - tracked_lab_1st_step, - tracked_rp_1st_step, - IDs_to_track, - tracked_IDs_2nd_step - ) + if not dont_return_tracked_lab: + tracked_lab_2nd_step = cellacdc.core.lab_replace_values( + tracked_lab_1st_step, + tracked_rp_1st_step, + IDs_to_track, + tracked_IDs_2nd_step + ) + else: + tracked_lab_2nd_step = None if self._annot_obj_2nd_step: to_track_tracked_objs_2nd_step = ( objs_to_track, tracked_objs_2nd_step ) - return tracked_lab_2nd_step, to_track_tracked_objs_2nd_step + assignments_step_2 = dict(zip(IDs_to_track, tracked_IDs_2nd_step)) + for current_ID, tracked_ID in list(assignments_step_1.items()): + final_tracked_ID = assignments_step_2.get(tracked_ID) + if final_tracked_ID is not None: + assignments_step_1[current_ID] = final_tracked_ID + + for current_ID, tracked_ID in assignments_step_2.items(): + assignments_step_1.setdefault(current_ID, 
tracked_ID) + + return _format_tracking_result( + tracked_lab_2nd_step, + assignments_step_1, + to_track_tracked_objs_2nd_step, + return_assignments, + dont_return_tracked_lab, + ) def updateGuiProgressBar(self, signals): if signals is None: diff --git a/cellacdc/trackers/CellACDC_normal_division/CellACDC_normal_division_tracker.py b/cellacdc/trackers/CellACDC_normal_division/CellACDC_normal_division_tracker.py index 8e47215a..877ebbba 100644 --- a/cellacdc/trackers/CellACDC_normal_division/CellACDC_normal_division_tracker.py +++ b/cellacdc/trackers/CellACDC_normal_division/CellACDC_normal_division_tracker.py @@ -4,33 +4,13 @@ from cellacdc.core import getBaseCca_df, printl from cellacdc.myutils import checked_reset_index, checked_reset_index_Cell_ID import numpy as np -from skimage.measure import regionprops from tqdm import tqdm import pandas as pd from cellacdc.myutils import exec_time from cellacdc._types import NotGUIParam import copy import cellacdc.debugutils as debugutils - -# def filter_cols(df): -# """ -# Filters the columns of a DataFrame based on a predefined set of column names. -# 'generation_num_tree', 'root_ID_tree', 'sister_ID_tree', 'parent_ID_tree', 'parent_ID_tree', 'emerg_frame_i', 'division_frame_i' -# plus any column that starts with 'sister_ID_tree' - -# Parameters: -# - df (pandas.DataFrame): The input DataFrame. - -# Returns: -# - pandas.DataFrame: The filtered DataFrame containing only the specified columns. 
-# """ -# lin_tree_cols = {'generation_num_tree', 'root_ID_tree', -# 'sister_ID_tree', 'parent_ID_tree', -# 'parent_ID_tree', 'emerg_frame_i', -# 'division_frame_i', 'is_history_known'} -# sis_cols = {col for col in df.columns if col.startswith('sister_ID_tree')} -# lin_tree_cols = lin_tree_cols | sis_cols -# return df[list(lin_tree_cols)] +from cellacdc.regionprops import acdcRegionprops as acdcRegionprops def reorg_sister_cells_for_export(lineage_tree_frame): """ @@ -60,45 +40,6 @@ def reorg_sister_cells_for_export(lineage_tree_frame): return lineage_tree_frame -# def reorg_sister_cells_inner_func(row): -# """ -# Reorganizes the sister cells in a row of a DataFrame. Used as an inner function for apply. - -# Parameters: -# - row (pandas.Series): The input row of the DataFrame (alredy filtered for the sister columns). -# Returns: -# - pandas.Series: The reorganized row with the sister cells. -# """ - -# values = [int(i) for i in row if i not in {0, -1} and not np.isnan(i)] or [-1] -# values = list(set(values)) -# return values - - -# def reorg_sister_cells_for_import(df): -# """ -# Reorganizes the sister cells for import. - -# This function takes a DataFrame `df` as input and performs the following steps: -# 1. Identifies the sister columns in the DataFrame. -# 2. Removes any values that are equal to 0 or -1 from the sister columns. (Which both represent no sister cell) -# 3. Converts the remaining values in the sister columns to a set. -# 4. Converts the set of values to a list if it is not empty, otherwise assigns [-1] to the sister column. (It actually shouldn't be empty, but just in case...) -# 5. Removes the sister columns from the DataFrame. And adds the list as the new 'sister_ID_tree' column. - -# Parameters: -# - df (pandas.DataFrame): The input DataFrame. - -# Returns: -# - df (pandas.DataFrame): The modified DataFrame with reorganized sister cells. 
-# """ -# sister_cols = [col for col in df.columns if col.startswith('sister_ID_tree')] # handling sister columns -# df.loc[:, 'sister_ID_tree'] = df[sister_cols].apply(reorg_sister_cells_inner_func, axis=1) -# sister_cols.remove('sister_ID_tree') -# df = df.drop(columns=sister_cols) -# df = checked_reset_index_Cell_ID(df) -# return df - def mother_daughter_assign(IoA_matrix, IoA_thresh_daughter, min_daughter, max_daughter, IoA_thresh_instant=None): """ Identifies cells that have not undergone division based on the input IoA matrix. @@ -151,13 +92,8 @@ def mother_daughter_assign(IoA_matrix, IoA_thresh_daughter, min_daughter, max_da else: should_remove_idx.append(False) - # printl(f'length of mother_daughters: {len(mother_daughters), len(should_remove_idx)}') mother_daughters = [mother_daughters[i] for i, remove in enumerate(should_remove_idx) if not remove] - # daughters_li = [] - # for _, daughters in mother_daughters: - # daughters_li.extend(daughters) - return aggr_track, mother_daughters def added_lineage_tree_to_cca_df(added_lineage_tree): @@ -241,33 +177,6 @@ def IoA_index_daughter_to_ID(daughters, assignments, IDs_curr_untracked): return daughter_IDs -# def update_fam_dynamically(families, fixed_df, Cell_IDs_fixed=None): -# if Cell_IDs_fixed is None: -# Cell_IDs_fixed = fixed_df.index -# for idx, family in enumerate(families): -# # Keep only cellinfos where cell_id is in Cell_IDs_fixed -# families[idx] = [cellinfo for cellinfo in family if cellinfo[0] not in Cell_IDs_fixed] - -# families = [family for family in families if family] # Remove empty families -# handled_cells = set() -# for family in families: -# root_ID = family[0][0] # The first cell in the family is the root -# try: -# relevant_cells = fixed_df.loc[fixed_df['root_ID_tree'] == root_ID] -# except: -# printl(fixed_df['root_ID_tree']) -# for relevant_cell in relevant_cells.index: -# # Update the family with the generation number and root ID -# family.append((relevant_cell, 
relevant_cells.loc[relevant_cell, 'generation_num_tree'])) -# handled_cells.update(relevant_cells.index) - -# for cell_id in Cell_IDs_fixed: -# if cell_id not in handled_cells: -# # If the cell is not handled, create a new family for it -# families.append([(cell_id, fixed_df.loc[cell_id, 'generation_num_tree'])]) - -# return families - class normal_division_tracker: """ A class that tracks cell divisions in a video sequence. The tracker uses the Intersection over Area (IoA) metric to track cells and identify daughter cells. @@ -323,7 +232,8 @@ def __init__(self, self.tracked_video[0] = segm_video[0] def track_frame(self, frame_i, lab=None, prev_lab=None, rp=None, prev_rp=None, - IDs=None, unique_ID=None): + IDs=None, unique_ID=None, + return_assignments=False, specific_IDs=None, dont_return_tracked_lab=False): """ Tracks a single frame in the video sequence. @@ -342,42 +252,75 @@ def track_frame(self, frame_i, lab=None, prev_lab=None, rp=None, prev_rp=None, prev_lab = self.tracked_video[frame_i-1] if rp is None: - self.rp = regionprops(lab.copy()) + self.rp = acdcRegionprops(lab.copy(), precache_centroids=False) else: self.rp = rp if prev_rp is None: - prev_rp = regionprops(prev_lab.copy()) - - IoA_matrix, self.IDs_curr_untracked, self.IDs_prev = calc_Io_matrix(lab, - prev_lab, - self.rp, - prev_rp, - IDs=IDs, - ) - self.aggr_track, self.mother_daughters = mother_daughter_assign(IoA_matrix, - IoA_thresh_daughter=self.IoA_thresh_daughter, - min_daughter=self.min_daughter, - max_daughter=self.max_daughter, - IoA_thresh_instant=self.IoA_thresh - ) - self.tracked_lab, IoA_matrix, self.assignments, _ = track_frame_base(prev_lab, - prev_rp, - lab, - self.rp, - IoA_thresh=self.IoA_thresh, - IoA_matrix=IoA_matrix, - aggr_track=self.aggr_track, - IoA_thresh_aggr=self.IoA_thresh_aggressive, - IDs_curr_untracked=self.IDs_curr_untracked, - IDs_prev=self.IDs_prev, - return_all=True, - mother_daughters=self.mother_daughters, - unique_ID=unique_ID - ) + prev_rp = 
acdcRegionprops(prev_lab.copy(), precache_centroids=False) + + full_IoA_matrix, full_curr_IDs, self.IDs_prev = calc_Io_matrix( + lab, + prev_lab, + self.rp, + prev_rp, + ) + IoA_matrix, self.IDs_curr_untracked, _ = calc_Io_matrix( + lab, + prev_lab, + self.rp, + prev_rp, + specific_IDs=specific_IDs, + ) + full_aggr_track, full_mother_daughters = mother_daughter_assign( + full_IoA_matrix, + IoA_thresh_daughter=self.IoA_thresh_daughter, + min_daughter=self.min_daughter, + max_daughter=self.max_daughter, + IoA_thresh_instant=self.IoA_thresh, + ) + + subset_idx_mapper = { + curr_ID: idx for idx, curr_ID in enumerate(self.IDs_curr_untracked) + } + self.aggr_track = [ + subset_idx_mapper[full_curr_IDs[idx]] + for idx in full_aggr_track + if full_curr_IDs[idx] in subset_idx_mapper + ] + self.mother_daughters = [] + for mother_idx, daughter_idxs in full_mother_daughters: + subset_daughter_idxs = [ + subset_idx_mapper[full_curr_IDs[idx]] + for idx in daughter_idxs + if full_curr_IDs[idx] in subset_idx_mapper + ] + if subset_daughter_idxs: + self.mother_daughters.append((mother_idx, subset_daughter_idxs)) - - self.tracked_video[frame_i] = self.tracked_lab + out = track_frame_base( + prev_lab, + prev_rp, + lab, + self.rp, + IoA_thresh=self.IoA_thresh, + IoA_matrix=IoA_matrix, + aggr_track=self.aggr_track, + IoA_thresh_aggr=self.IoA_thresh_aggressive, + IDs_curr_untracked=self.IDs_curr_untracked, + IDs_prev=self.IDs_prev, + return_all=True, + mother_daughters=self.mother_daughters, + unique_ID=unique_ID, + specific_IDs=specific_IDs, + return_assignments=return_assignments, + dont_return_tracked_lab=dont_return_tracked_lab, + ) + if dont_return_tracked_lab: + IoA_matrix, self.assignments, self.tracked_IDs = out + else: + self.tracked_lab, IoA_matrix, self.assignments, self.tracked_IDs = out + self.tracked_video[frame_i] = self.tracked_lab class normal_division_lineage_tree: """ @@ -594,8 +537,8 @@ def init_lineage_tree(self, lab=None, first_df=None, frame_i=None): if lab is not 
None: - rp = regionprops(lab) - labels = [obj.label for obj in rp] + rp = acdcRegionprops(lab, precache_centroids=False) + labels = rp.IDs cca_df = pd.DataFrame({ 'Cell_ID': labels, }) @@ -730,10 +673,10 @@ def real_time(self, frame_i, lab, prev_lab, rp=None, prev_rp=None): None """ if rp is None: - rp = regionprops(lab) + rp = acdcRegionprops(lab, precache_centroids=False) if prev_rp is None: - prev_rp = regionprops(prev_lab) + prev_rp = acdcRegionprops(prev_lab, precache_centroids=False) IoA_matrix, self.IDs_curr_untracked, self.IDs_prev = calc_Io_matrix(lab, prev_lab, rp, prev_rp) @@ -751,7 +694,7 @@ def real_time(self, frame_i, lab, prev_lab, rp=None, prev_rp=None): self.mother_daughters = filtered_mother_daughters curr_IDs = set(self.IDs_curr_untracked) - prev_IDs = {obj.label for obj in prev_rp} + prev_IDs = set(prev_rp.IDs) new_IDs = curr_IDs - prev_IDs self.frames_for_dfs.add(frame_i) self.add_new_frame(frame_i, self.mother_daughters, self.IDs_prev, self.IDs_curr_untracked, None, curr_IDs, new_IDs) @@ -842,80 +785,6 @@ def update_df_li_locally(self, df, frame_i): df_data.loc[ID] = cell_row - # This will probably be made obsolete by the gui_mode version - # def insert_lineage_df(self, lineage_df, frame_i, update_fams=True, - # consider_children=True, raw_input=False, propagate=True, - # relevant_cells=None): - # """ - # Insert or replace a lineage DataFrame at a given frame index, optionally updating families and propagating changes. - - # Args: - # lineage_df (pd.DataFrame): The lineage DataFrame to insert. - # frame_i (int): The index of the frame. - # update_fams (bool, optional): If True, update families based on the changes. Defaults to True. - # consider_children (bool, optional): If True, update children of the inserted frame. Defaults to True. 
- - # Returns: - # None - # """ - # if not self.gui_mode: - # printl("here") - # if not raw_input: - # lineage_df = reorg_sister_cells_for_import(lineage_df) - # lineage_df = filter_cols(lineage_df) - - # lineage_df = checked_reset_index_Cell_ID(lineage_df) - # len_lineage_list = len(self.lineage_list) - # if frame_i == len_lineage_list: - # self.lineage_list.append(lineage_df) - # self.frames_for_dfs.add(frame_i) - - # self.update_df_li_locally(lineage_df, frame_i) - - # if propagate: - # out = update_consistency(df_li=self.lineage_list, fixed_frame_i=frame_i, - # consider_children=consider_children, Cell_IDs_fixed=relevant_cells, - # families=self.families if update_fams else None) - # if update_fams: - # self.lineage_list, self.families = out - # else: - # self.lineage_list = out - - # elif frame_i < len_lineage_list: - # self.lineage_list[frame_i] = lineage_df - # self.update_df_li_locally(lineage_df, frame_i) - # if propagate: - # out = update_consistency(df_li=self.lineage_list, fixed_frame_i=frame_i, - # consider_children=consider_children, Cell_IDs_fixed=relevant_cells, - # families=self.families if update_fams else None) - # if update_fams: - # self.lineage_list, self.families = out - # else: - # self.lineage_list = out - - - # elif frame_i > len_lineage_list: - # printl(f'WARNING: Frame_i {frame_i} was inserted. 
The lineage list was only {len(self.lineage_list)} frames long, so the last known lineage tree was copy pasted up to frame_i {frame_i}') - - # original_length = len(self.lineage_list) - # self.lineage_list = self.lineage_list + [self.lineage_list[-1]] * (frame_i - len(self.lineage_list)) - - # self.generate_gen_df_from_df_li(self.lineage_list, force=True) - - # self.lineage_list.append(lineage_df) - - # frame_is = set(range(len(self.lineage_list)-original_length)) - # self.frames_for_dfs = self.frames_for_dfs | frame_is - - # self.update_df_li_locally(lineage_df, frame_i) - # if propagate: - # out = update_consistency(df_li=self.lineage_list, fixed_frame_i=frame_i, - # consider_children=consider_children, Cell_IDs_fixed=relevant_cells, - # families=self.families if update_fams else None) - # if update_fams: - # self.lineage_list, self.families = out - # else: - # self.lineage_list = out def _update_consistency(self, fixed_frame_i=None, fixed_df=None, Cell_IDs_fixed=None, consider_children=True): @@ -1043,63 +912,6 @@ def propagate(self, frame_i, relevant_cells=None): self._update_consistency(fixed_frame_i=frame_i, consider_children=True, Cell_IDs_fixed=relevant_cells) - # This will probably be made obsolete by the gui_mode version - # def load_lineage_df_list(self, df_li): - # """ - # Load a list of lineage DataFrames, reconstructing the lineage tree and families. - - # Args: - # df_li (list): List of acdc_df DataFrames. - - # Returns: - # None - # """ - # df_li = copy.deepcopy(df_li) # Ensure we don't modify the original list - # # Support for first_frame was removed since it is not necessary, just make the df_li correct... - # # Also the tree needs to be init before. 
Also if df_li does not contain any relevant dfs, nothing happens - # print('Loading lineage data...') - # df_li_new = [] - # families = [] - # families_root_IDs = [] - # added_IDs = set() - - # for i, df in enumerate(df_li): - # if df is None: - # continue - - # if 'generation_num_tree' not in df.columns: - # continue - - # mask = (df['generation_num_tree'].isnull() | - # df["generation_num_tree"].isna()) - - # if mask.any() or df["generation_num_tree"].empty: - # continue - - # df = checked_reset_index_Cell_ID(df) - - # df = filter_cols(df) - # df = reorg_sister_cells_for_import(df) - # self.frames_for_dfs.add(i) - # df_li_new.append(df) - - # df_filter = df.index.isin(added_IDs) - # for root_ID, group in df[df_filter].groupby('root_ID_tree'): - # if root_ID not in families_root_IDs: - # family = list(zip(group.index, group['generation_num_tree'])) - # families.append(family) - # families_root_IDs.append(root_ID) - # else: - # # If the root_ID is already in families, we just update the family with the new cells - # family_index = families_root_IDs.index(root_ID) - # families[family_index].extend(zip(group.index, group['generation_num_tree'])) - - # added_IDs.update(group.index) - - # if df_li_new: - # self.lineage_list = df_li_new - - # This will probably be made obsolete by the gui_mode version def export_df(self, frame_i): """ Export the lineage DataFrame for a specific frame, cleaning up auxiliary columns. 
@@ -1256,8 +1068,8 @@ def track(self, IoA_thresh_daughter=IoA_thresh_daughter ) pbar.update() - rp = regionprops(segm_video[0]) - prev_IDs = {obj.label for obj in rp} + rp = acdcRegionprops(segm_video[0], precache_centroids=False) + prev_IDs = rp.IDs_set prev_rp = rp continue @@ -1270,8 +1082,8 @@ def track(self, IDs_prev = tracker.IDs_prev assignments = tracker.assignments IDs_curr_untracked = tracker.IDs_curr_untracked - rp = regionprops(tracker.tracked_lab) - curr_IDs = {obj.label for obj in rp} + rp = acdcRegionprops(tracker.tracked_lab) + curr_IDs = rp.IDs_set new_IDs = curr_IDs - prev_IDs if record_lineage or return_tracked_lost_centroids: tree.add_new_frame( @@ -1289,7 +1101,7 @@ def track(self, found = True break if not found: - labels = [obj.label for obj in rp] + labels = rp.IDs printl(mother, mother_ID, IDs_curr_untracked, labels) raise ValueError('Something went wrong with the tracked lost centroids.') @@ -1328,6 +1140,9 @@ def track_frame(self, min_daughter:int = 2, max_daughter:int = 2, unique_ID: NotGUIParam =None, + return_assignments: NotGUIParam =False, + specific_IDs: NotGUIParam =None, + dont_return_tracked_lab: NotGUIParam =False, ): """ Tracks cell division in a single frame. 
(This is used for real time tracking in the GUI) @@ -1352,14 +1167,32 @@ def track_frame(self, segm_video = [previous_frame_labels, current_frame_labels] tracker = normal_division_tracker(segm_video, IoA_thresh_daughter, min_daughter, max_daughter, IoA_thresh, IoA_thresh_aggressive) - tracker.track_frame(1, IDs=IDs, unique_ID=unique_ID) - tracked_video = tracker.tracked_video + tracker.track_frame( + 1, + IDs=IDs, + unique_ID=unique_ID, + return_assignments=return_assignments, + specific_IDs=specific_IDs, + dont_return_tracked_lab=dont_return_tracked_lab, + ) mother_daughters_pairs = tracker.mother_daughters IDs_prev = tracker.IDs_prev mothers = {IDs_prev[pair[0]] for pair in mother_daughters_pairs} + assignments = tracker.assignments + + if dont_return_tracked_lab: + return assignments + + tracked_lab = tracker.tracked_video[-1] + if not return_assignments: + return tracked_lab - return tracked_video[-1], mothers + add_info = { + 'mothers': mothers, + 'assignments': assignments + } + return tracked_lab, add_info def updateGuiProgressBar(self, signals): """ diff --git a/cellacdc/whitelist.py b/cellacdc/whitelist.py index 0621f451..c2211b96 100644 --- a/cellacdc/whitelist.py +++ b/cellacdc/whitelist.py @@ -1,7 +1,7 @@ import os import numpy as np import skimage.measure -from . import printl, myutils +from . 
import printl, myutils, regionprops import json from typing import Set, List, Tuple import time @@ -222,14 +222,14 @@ def create_new_centroids(self, new_IDs = self.originalLabsIDs[i] - self.originalLabsIDs[i-1] - rp = None if frame_i==i and curr_rp is not None: rp = curr_rp else: - rp = skimage.measure.regionprops(self.originalLabs[i]) + rp = regionprops.acdcRegionprops(self.originalLabs[i], + precache_centroids=False) self.new_centroids.append({ - tuple(map(int, obj.centroid)) for obj in rp if obj.label in new_IDs + tuple(map(int, rp.get_centroid(label))) for label in new_IDs }) @@ -411,7 +411,7 @@ def IDsAccepted(self, printl('Using curr_lab') IDs_curr = {obj.label for obj in skimage.measure.regionprops(lab)} else: - IDs_curr = allData_li[frame_i]['IDs'] + IDs_curr = allData_li[frame_i]['regionprops'].IDs_set if self._debug: printl('Using allData_li') @@ -488,7 +488,7 @@ def makeOriginalLabsAndIDs(self, segm_data: np.ndarray, IDs = set(IDs_curr) elif allData_li is not None: try: - IDs = set(allData_li[i]['IDs']) + IDs = allData_li[i]['regionprops'].IDs_set except KeyError: pass if IDs is None: @@ -746,7 +746,7 @@ def propagateIDs(self, printl('Using index_lab_combo') IDs_curr = {obj.label for obj in skimage.measure.regionprops(lab)} elif curr_rp is not None: - IDs_curr = {obj.label for obj in curr_rp} + IDs_curr = curr_rp.IDs_set if self._debug: printl('Using rp') elif curr_lab is not None: @@ -755,7 +755,7 @@ def propagateIDs(self, printl('Using curr_lab') IDs_curr = {obj.label for obj in skimage.measure.regionprops(lab)} else: - IDs_curr = allData_li[frame_i]['IDs'] + IDs_curr = allData_li[frame_i]['regionprops'].IDs_set if self._debug: printl('Using allData_li') @@ -872,7 +872,7 @@ def propagateIDs(self, if frame_i == i: IDs_curr_loc = IDs_curr else: - IDs_curr_loc = set(allData_li[i]['IDs']) + IDs_curr_loc =allData_li[i]['regionprops'].IDs_set new_whitelist = self.get(i, try_create_new_whitelists).copy() old_whitelist = new_whitelist.copy() @@ -939,12 +939,12 
@@ def whitelistTrackOGagainstPreviousFrame_cb(self, signal_slot=None): if not self.whitelistCheckOriginalLabels(): return old_cell_IDs = posData.whitelist.originalLabsIDs[frame_i] - prev_cell_IDs = posData.allData_li[frame_i-1]['IDs'] + prev_cell_IDs = posData.allData_li[frame_i-1]['regionprops'].IDs_set self.whitelistTrackOGCurr(against_prev=True) new_cell_IDs = posData.whitelist.originalLabsIDs[frame_i] new_IDs = new_cell_IDs - old_cell_IDs - new_IDs = new_IDs & set(prev_cell_IDs) + new_IDs = new_IDs & prev_cell_IDs self.whitelistUpdateLab( track_og_curr=False, IDs_to_add=new_IDs, @@ -1066,7 +1066,7 @@ def whitelistViewOGIDs(self, checked:bool): self.store_data(autosave=False) if frame_i > 0: - missing_IDs = set(posData.IDs) - set(posData.allData_li[frame_i-1]['IDs']) + missing_IDs = posData.IDs_set - posData.allData_li[frame_i-1]['regionprops'].IDs_set self.trackManuallyAddedObject(missing_IDs,isNewID=True, wl_update=False) self.setAllTextAnnotations() @@ -1502,7 +1502,7 @@ def whitelistTrackOGCurr(self, frame_i:int=None, ### against what should I track? 
if lab is not None and not rp: - rp = skimage.measure.regionprops(lab) + rp = regionprops.acdcRegionprops(lab, precache_centroids=False) changed_frame = False if lab is None: @@ -1520,7 +1520,7 @@ def whitelistTrackOGCurr(self, frame_i:int=None, rp = posData.rp lab = posData.lab og_lab = posData.whitelist.originalLabs[frame_i] - og_rp = skimage.measure.regionprops(og_lab) + og_rp = regionprops.acdcRegionprops(og_lab, precache_centroids=False) # lab = lab.copy() denom_overlap_matrix = 'union' if not against_prev else 'area_prev' @@ -1530,7 +1530,6 @@ def whitelistTrackOGCurr(self, frame_i:int=None, denom_overlap_matrix=denom_overlap_matrix, posData = posData, setBrushID_func=self.setBrushID, - IDs=IDs, # assign_unique_new_IDs=False, ) @@ -1583,7 +1582,7 @@ def whitelistTrackCurrOG(self, frame_i:int=None, against_prev:bool=False): else: og_lab = posData.whitelist.originalLabs[frame_i] - og_rp = skimage.measure.regionprops(og_lab) + og_rp = regionprops.acdcRegionprops(og_lab, precache_centroids=False) denom_overlap_matrix = 'union' if not against_prev else 'area_prev' diff --git a/cellacdc/workers.py b/cellacdc/workers.py index 6a57ddd9..c59387ae 100755 --- a/cellacdc/workers.py +++ b/cellacdc/workers.py @@ -178,7 +178,7 @@ def run(self): for frame_i, data_dict in enumerate(self.posData.allData_li): lab = data_dict['labels'] rp = data_dict['regionprops'] - IDs = data_dict['IDs'] + IDs = data_dict['regionprops'].IDs if lab is None: lab = self.posData.segm_data[frame_i] rp = skimage.measure.regionprops(lab) @@ -365,7 +365,7 @@ def run(self): curr_img = self.guiWin.getDisplayedImg1() prev_lab = self.guiWin.get_2Dlab(posData.allData_li[frame_i-1]['labels']) - prev_IDs = set(posData.allData_li[frame_i-1]['IDs']) + prev_IDs = posData.allData_li[frame_i-1]['regionprops'].IDs_set # should probably not paly so much with posData.lab, instead handle stuff myself self.signals.initProgressBar.emit(2 * args_new['max_iterations']) @@ -5185,7 +5185,7 @@ def check(self, posData): # 
There are no annotations at frame_i --> stop break - IDs = data_dict['IDs'] + IDs = data_dict['regionprops'].IDs checker = core.CcaIntegrityChecker(cca_df, lab, IDs) for checkpoint in checkpoints: @@ -6155,7 +6155,7 @@ def saveAcdcDf(self, posData: load.loadData, end_i): last_cca_frame_i=self.mainWin.save_cca_until_frame_i ) - def saveSegmData(self, posData, end_i, saved_segm_data): + def saveSegmData(self, posData: load.loadData, end_i, saved_segm_data): self.progress.emit(f'Saving segmentation data for {posData.relPath}...') for frame_i, data_dict in enumerate(posData.allData_li[:end_i+1]): if self.saveWin.aborted: @@ -6181,6 +6181,14 @@ def saveSegmData(self, posData, end_i, saved_segm_data): io.savez_compressed( posData.segm_npz_path, np.squeeze(saved_segm_data) ) + + # save information about the segmention + posData.updateSegmMetadata(all=True) + posData.saveSegmMetadataIni() + + # save rp info about segm + self.progress.emit(f'Saving additional data for {posData.relPath}...') + posData.saveCentroidsIDs() posData.segm_data = saved_segm_data # Allow single 2D/3D image if posData.SizeT == 1: diff --git a/pyproject.toml b/pyproject.toml index 7b47c38a..c5e207df 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,9 @@ requires = [ "setuptools>=64", "wheel", - "setuptools_scm[toml]>=8" + "setuptools_scm[toml]>=8", + "cython", + "numpy", ] build-backend = "setuptools.build_meta" diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..bfb4ceb7 --- /dev/null +++ b/setup.py @@ -0,0 +1,28 @@ +# only needed for cython extensions, not needed to run normally +import sys +from setuptools import setup, Extension +from Cython.Build import cythonize +import numpy as np + +setup( + ext_modules=cythonize( + Extension( + "cellacdc.regionprops_helper", + sources=["cellacdc/regionprops_helper.pyx"], + include_dirs=[np.get_include()], + ), + annotate=True, + build_dir="build/cython", # .c and .html files go here + ) +) +# move compiled binary to precompiled/ 
+import shutil
+import os
+
+src_dir = "cellacdc"
+for filename in os.listdir(src_dir):
+    if filename.startswith("regionprops_helper") and (filename.endswith(".so") or filename.endswith(".pyd")):
+        target_path = os.path.join("cellacdc", "precompiled", filename)
+        shutil.move(os.path.join(src_dir, filename), target_path)
+        print(f"Moved {filename} to {target_path}")
+