Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion reachml/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,22 @@

from abc import ABC, abstractmethod
from collections import defaultdict
from typing import List

import matplotlib.pyplot as plt
import numpy as np
from typing import List
import pandas as pd
from scipy.sparse import csr_matrix
from tqdm import tqdm

from .action_set import ActionSet
from .mip import EnumeratorMIP

RESP_BAR_COLOR = "#FFC000"

# matplotlib font params
plt.rcParams["font.size"] = 15


class ResponsivenessScorer(ABC):
"""Base scorer that computes per-feature responsiveness for a point."""
Expand Down Expand Up @@ -76,6 +83,30 @@ def _get_inter_key(self, x):
"""Return a hashable key for caching interventions for `x`."""
pass

def plot(self, score_lst=None, x_idx=None):
"""Plot responsiveness scores given a point."""
if score_lst is None and x_idx is None:
raise ValueError("Either scores or x_idx must be provided")

if score_lst is None:
score_lst = self.scores[x_idx]

# Sort the scores and names together
names = self.action_set.names
sorted_data = sorted(
zip(score_lst, names, strict=True), reverse=False
) # Sort by score descending
sorted_scores, sorted_names = zip(*sorted_data, strict=True)

fig, ax = plt.subplots(figsize=(8, 6))

ax.barh(sorted_names, sorted_scores, color=RESP_BAR_COLOR)
ax.set_xlabel("Responsiveness Score")
ax.set_yticks(range(len(sorted_scores)))
ax.set_xlim(0, 1)

return fig

def __call__(self, X, clf, save=True, **kwargs):
"""Compute per-feature responsiveness scores for each row in `X`.

Expand All @@ -88,6 +119,9 @@ def __call__(self, X, clf, save=True, **kwargs):
Returns:
Numpy array of shape (n_samples, n_features) with responsiveness scores.
"""
if isinstance(X, pd.DataFrame):
X = X.to_numpy()

out = np.zeros((len(X), len(self.action_set)))

for i, x in tqdm(enumerate(X), total=len(X)):
Expand Down