From ca9dccbe3cf8cc5f0253bbd3bc1c08dab5fb5e30 Mon Sep 17 00:00:00 2001 From: Hagen Wierstorf Date: Mon, 10 Nov 2025 08:44:55 +0100 Subject: [PATCH 1/2] Add metric argument to confusion_matrix() --- audplot/core/api.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/audplot/core/api.py b/audplot/core/api.py index 9a28a12..76d5e8e 100644 --- a/audplot/core/api.py +++ b/audplot/core/api.py @@ -98,6 +98,7 @@ def confusion_matrix( *, labels: Sequence = None, label_aliases: dict = None, + metric: Callable = audmetric.confusion_matrix, percentage: bool = False, show_both: bool = False, ax: matplotlib.axes.Axes = None, @@ -112,6 +113,10 @@ def confusion_matrix( labels: labels to be included in confusion matrix label_aliases: mapping to alias names for labels to be presented in the plot + metric: calculator of confusion matrix. + The callable is expected + to have the two arguments ``truth`` and ``prediction``, + and the keyword arguments ``labels`` and ``normalize`` percentage: if ``True`` present the confusion matrix with percentage values instead of absolute numbers show_both: if ``True`` and percentage is ``True`` @@ -169,7 +174,7 @@ def confusion_matrix( if labels is None: labels = audmetric.utils.infer_labels(truth, prediction) - cm = audmetric.confusion_matrix( + cm = metric( truth, prediction, labels=labels, @@ -190,7 +195,7 @@ def confusion_matrix( # Add a second row of annotations if requested if show_both: - cm2 = audmetric.confusion_matrix( + cm2 = metric( truth, prediction, labels=labels, From e6ed93dfaf618fb83d7b8f1634b83601b6e16704 Mon Sep 17 00:00:00 2001 From: Hagen Wierstorf Date: Mon, 10 Nov 2025 08:48:54 +0100 Subject: [PATCH 2/2] Mention default in docstring --- audplot/core/api.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/audplot/core/api.py b/audplot/core/api.py index 76d5e8e..e2cf0ac 100644 --- a/audplot/core/api.py +++ b/audplot/core/api.py @@ -116,7 +116,8 @@ def confusion_matrix( metric: calculator of confusion matrix. The callable is expected to have the two arguments ``truth`` and ``prediction``, - and the keyword arguments ``labels`` and ``normalize`` + and the keyword arguments ``labels`` and ``normalize``. + Defaults to :func:`audmetric.confusion_matrix` percentage: if ``True`` present the confusion matrix with percentage values instead of absolute numbers show_both: if ``True`` and percentage is ``True``