diff --git a/reachml/scoring.py b/reachml/scoring.py index 596b0ed..010925c 100644 --- a/reachml/scoring.py +++ b/reachml/scoring.py @@ -2,15 +2,22 @@ from abc import ABC, abstractmethod from collections import defaultdict +from typing import List +import matplotlib.pyplot as plt import numpy as np -from typing import List +import pandas as pd from scipy.sparse import csr_matrix from tqdm import tqdm from .action_set import ActionSet from .mip import EnumeratorMIP +RESP_BAR_COLOR = "#FFC000" + +# matplotlib font params +plt.rcParams["font.size"] = 15 + class ResponsivenessScorer(ABC): """Base scorer that computes per-feature responsiveness for a point.""" @@ -76,6 +83,30 @@ def _get_inter_key(self, x): """Return a hashable key for caching interventions for `x`.""" pass + def plot(self, score_lst=None, x_idx=None): + """Plot responsiveness scores given a point.""" + if score_lst is None and x_idx is None: + raise ValueError("Either scores or x_idx must be provided") + + if score_lst is None: + score_lst = self.scores[x_idx] + + # Sort the scores and names together + names = self.action_set.names + sorted_data = sorted( + zip(score_lst, names, strict=True), reverse=False + ) # Sort by score descending + sorted_scores, sorted_names = zip(*sorted_data, strict=True) + + fig, ax = plt.subplots(figsize=(8, 6)) + + ax.barh(sorted_names, sorted_scores, color=RESP_BAR_COLOR) + ax.set_xlabel("Responsiveness Score") + ax.set_yticks(range(len(sorted_scores))) + ax.set_xlim(0, 1) + + return fig + def __call__(self, X, clf, save=True, **kwargs): """Compute per-feature responsiveness scores for each row in `X`. @@ -88,6 +119,9 @@ def __call__(self, X, clf, save=True, **kwargs): Returns: Numpy array of shape (n_samples, n_features) with responsiveness scores. """ + if isinstance(X, pd.DataFrame): + X = X.to_numpy() + out = np.zeros((len(X), len(self.action_set))) for i, x in tqdm(enumerate(X), total=len(X)):