Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "raidionicsseg"
version = "1.5.0"
version = "1.5.1"
description = "Raidionics segmentation and classification back-end with ONNX runtime"
readme = "README.md"
license = { text = "BSD-2-Clause" }
Expand Down
12 changes: 9 additions & 3 deletions raidionicsseg/Utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,22 @@ def dump_predictions(
try:
naming_suffix = "pred" if parameters.predictions_reconstruction_method == "probabilities" else "labels"
class_names = parameters.training_class_names

modified_header = nib_volume.header.copy()
if parameters.predictions_reconstruction_method != "probabilities":
modified_header.set_data_dtype(np.uint8)
assert predictions.dtype == np.uint8
else:
modified_header.set_data_dtype(np.float32)
assert predictions.dtype == np.float32
if len(predictions.shape) == 4:
first_class = 0 if parameters.training_activation_layer_type == "sigmoid" else 1
for c in range(first_class, predictions.shape[-1]):
img = nib.Nifti1Image(predictions[..., c], affine=nib_volume.affine, header=nib_volume.header)
img = nib.Nifti1Image(predictions[..., c], affine=nib_volume.affine, header=modified_header)
predictions_output_path = os.path.join(storage_path, naming_suffix + "_" + class_names[c] + ".nii.gz")
os.makedirs(os.path.dirname(predictions_output_path), exist_ok=True)
nib.save(img, predictions_output_path)
else:
img = nib.Nifti1Image(predictions, affine=nib_volume.affine, header=nib_volume.header)
img = nib.Nifti1Image(predictions, affine=nib_volume.affine, header=modified_header)
predictions_output_path = os.path.join(storage_path, naming_suffix + "_" + "argmax" + ".nii.gz")
os.makedirs(os.path.dirname(predictions_output_path), exist_ok=True)
nib.save(img, predictions_output_path)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def test_inference_segmentation_reconstruction_method(test_dir, tmp_path):
segmentation_gt_nib.header.get_zooms()[0:3]) * 1e-3
logging.info(f"Volume difference: {abs(pred_volume - gt_volume)}\n")
assert abs(pred_volume - gt_volume) < 0.1, "Ground truth and prediction volumes are very different"
assert segmentation_pred_nib.get_data_dtype() == np.float32, "Predictions is not of type float32"
except Exception as e:
logging.error(f"Error during inference Python package test with: {e} \n {traceback.format_exc()}.\n")
if os.path.exists(tmp_test_input_fn):
Expand Down
3 changes: 3 additions & 0 deletions tests/generic_tests/test_inference_segmentation_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def test_inference_cli(test_dir, tmp_path):
segmentation_gt = nib.load(segmentation_gt_filename).get_fdata()[:]
assert np.array_equal(segmentation_pred,
segmentation_gt), "Ground truth and prediction arrays are not identical"
assert nib.load(segmentation_pred_filename).get_data_dtype() == np.uint8, "Tresholded predictions is not of type uint8"
except Exception as e:
logging.error(f"Error during inference CLI test with: {e}\n {traceback.format_exc()}.\n")
raise ValueError("Error during inference CLI test.\n")
Expand Down Expand Up @@ -155,6 +156,8 @@ def test_inference_package(test_dir, tmp_path):
segmentation_pred = nib.load(segmentation_pred_filename).get_fdata()[:]
segmentation_gt = nib.load(segmentation_gt_filename).get_fdata()[:]
assert np.array_equal(segmentation_pred, segmentation_gt), "Ground truth and prediction arrays are not identical"
assert nib.load(
segmentation_pred_filename).get_data_dtype() == np.uint8, "Tresholded predictions is not of type uint8"
except Exception as e:
logging.error(f"Error during inference Python package test with: {e} \n {traceback.format_exc()}.\n")
if os.path.exists(tmp_test_input_fn):
Expand Down