diff --git a/audplot/core/api.py b/audplot/core/api.py index 9a28a12..e2cf0ac 100644 --- a/audplot/core/api.py +++ b/audplot/core/api.py @@ -98,6 +98,7 @@ def confusion_matrix( *, labels: Sequence = None, label_aliases: dict = None, + metric: Callable = audmetric.confusion_matrix, percentage: bool = False, show_both: bool = False, ax: matplotlib.axes.Axes = None, @@ -112,6 +113,11 @@ def confusion_matrix( labels: labels to be included in confusion matrix label_aliases: mapping to alias names for labels to be presented in the plot + metric: calculator of confusion matrix. + The callable is expected + to have the two arguments ``truth`` and ``prediction``, + and the keyword arguments ``labels`` and ``normalize``. + Defaults to :func:`audmetric.confusion_matrix` percentage: if ``True`` present the confusion matrix with percentage values instead of absolute numbers show_both: if ``True`` and percentage is ``True`` @@ -169,7 +175,7 @@ def confusion_matrix( if labels is None: labels = audmetric.utils.infer_labels(truth, prediction) - cm = audmetric.confusion_matrix( + cm = metric( truth, prediction, labels=labels, @@ -190,7 +196,7 @@ def confusion_matrix( # Add a second row of annotations if requested if show_both: - cm2 = audmetric.confusion_matrix( + cm2 = metric( truth, prediction, labels=labels,