diff --git a/pyproject.toml b/pyproject.toml index 565715f..11de846 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "raidionicsseg" -version = "1.5.0" +version = "1.5.1" description = "Raidionics segmentation and classification back-end with ONNX runtime" readme = "README.md" license = { text = "BSD-2-Clause" } diff --git a/raidionicsseg/Utils/io.py b/raidionicsseg/Utils/io.py index 83b3c39..711adb8 100644 --- a/raidionicsseg/Utils/io.py +++ b/raidionicsseg/Utils/io.py @@ -53,16 +53,22 @@ def dump_predictions( try: naming_suffix = "pred" if parameters.predictions_reconstruction_method == "probabilities" else "labels" class_names = parameters.training_class_names - + modified_header = nib_volume.header.copy() + if parameters.predictions_reconstruction_method != "probabilities": + modified_header.set_data_dtype(np.uint8) + assert predictions.dtype == np.uint8 + else: + modified_header.set_data_dtype(np.float32) + assert predictions.dtype == np.float32 if len(predictions.shape) == 4: first_class = 0 if parameters.training_activation_layer_type == "sigmoid" else 1 for c in range(first_class, predictions.shape[-1]): - img = nib.Nifti1Image(predictions[..., c], affine=nib_volume.affine, header=nib_volume.header) + img = nib.Nifti1Image(predictions[..., c], affine=nib_volume.affine, header=modified_header) predictions_output_path = os.path.join(storage_path, naming_suffix + "_" + class_names[c] + ".nii.gz") os.makedirs(os.path.dirname(predictions_output_path), exist_ok=True) nib.save(img, predictions_output_path) else: - img = nib.Nifti1Image(predictions, affine=nib_volume.affine, header=nib_volume.header) + img = nib.Nifti1Image(predictions, affine=nib_volume.affine, header=modified_header) predictions_output_path = os.path.join(storage_path, naming_suffix + "_" + "argmax" + ".nii.gz") os.makedirs(os.path.dirname(predictions_output_path), exist_ok=True) nib.save(img, predictions_output_path) diff --git a/tests/generic_tests/test_inference_segmentation_reconstruction.py b/tests/generic_tests/test_inference_segmentation_reconstruction.py index 04522e9..6798ac2 100644 --- a/tests/generic_tests/test_inference_segmentation_reconstruction.py +++ b/tests/generic_tests/test_inference_segmentation_reconstruction.py @@ -148,6 +148,7 @@ def test_inference_segmentation_reconstruction_method(test_dir, tmp_path): segmentation_gt_nib.header.get_zooms()[0:3]) * 1e-3 logging.info(f"Volume difference: {abs(pred_volume - gt_volume)}\n") assert abs(pred_volume - gt_volume) < 0.1, "Ground truth and prediction volumes are very different" + assert segmentation_pred_nib.get_data_dtype() == np.float32, "Predictions is not of type float32" except Exception as e: logging.error(f"Error during inference Python package test with: {e} \n {traceback.format_exc()}.\n") if os.path.exists(tmp_test_input_fn): diff --git a/tests/generic_tests/test_inference_segmentation_simple.py b/tests/generic_tests/test_inference_segmentation_simple.py index 82af699..48ac419 100644 --- a/tests/generic_tests/test_inference_segmentation_simple.py +++ b/tests/generic_tests/test_inference_segmentation_simple.py @@ -87,6 +87,7 @@ def test_inference_cli(test_dir, tmp_path): segmentation_gt = nib.load(segmentation_gt_filename).get_fdata()[:] assert np.array_equal(segmentation_pred, segmentation_gt), "Ground truth and prediction arrays are not identical" + assert nib.load(segmentation_pred_filename).get_data_dtype() == np.uint8, "Tresholded predictions is not of type uint8" except Exception as e: logging.error(f"Error during inference CLI test with: {e}\n {traceback.format_exc()}.\n") raise ValueError("Error during inference CLI test.\n") @@ -155,6 +156,8 @@ def test_inference_package(test_dir, tmp_path): segmentation_pred = nib.load(segmentation_pred_filename).get_fdata()[:] segmentation_gt = nib.load(segmentation_gt_filename).get_fdata()[:] assert np.array_equal(segmentation_pred, segmentation_gt), "Ground truth and prediction arrays are not identical" + assert nib.load( + segmentation_pred_filename).get_data_dtype() == np.uint8, "Tresholded predictions is not of type uint8" except Exception as e: logging.error(f"Error during inference Python package test with: {e} \n {traceback.format_exc()}.\n") if os.path.exists(tmp_test_input_fn):