diff --git a/chemap/fingerprint_computation.py b/chemap/fingerprint_computation.py index 866f978..95478c9 100644 --- a/chemap/fingerprint_computation.py +++ b/chemap/fingerprint_computation.py @@ -339,7 +339,12 @@ def _rdkit_unfolded( if cfg.count: out: UnfoldedCount = [] - for s, mol in zip(smiles, mols): + for s, mol in tqdm( + zip(smiles, mols), + disable=not show_progress, + desc="Compute fingerprints", + total=len(mols) + ): if mol is None: _handle_invalid(cfg.invalid_policy, s) if cfg.invalid_policy == "keep": @@ -356,7 +361,12 @@ def _rdkit_unfolded( return out out: UnfoldedBinary = [] - for s, mol in zip(smiles, mols): + for s, mol in tqdm( + zip(smiles, mols), + disable=not show_progress, + desc="Compute fingerprints", + total=len(mols) + ): if mol is None: _handle_invalid(cfg.invalid_policy, s) if cfg.invalid_policy == "keep": @@ -386,7 +396,12 @@ def _rdkit_folded_dense( n_features: Optional[int] = None pending_invalid: List[int] = [] # indices in `rows` that need backfill after we learn D - for s, mol in zip(smiles, mols): + for s, mol in tqdm( + zip(smiles, mols), + disable=not show_progress, + desc="Compute fingerprints", + total=len(mols) + ): if mol is None: _handle_invalid(cfg.invalid_policy, s) if cfg.invalid_policy == "keep": @@ -450,7 +465,12 @@ def _rdkit_folded_csr( if cfg.folded_weights is not None: w = np.asarray(cfg.folded_weights, dtype=np.float32).ravel() - for s, mol in zip(smiles, mols): + for s, mol in tqdm( + zip(smiles, mols), + disable=not show_progress, + desc="Compute fingerprints", + total=len(mols) + ): if mol is None: _handle_invalid(cfg.invalid_policy, s) diff --git a/chemap/metrics.py b/chemap/metrics.py index fd43f8d..35b17b9 100644 --- a/chemap/metrics.py +++ b/chemap/metrics.py @@ -219,55 +219,55 @@ def tanimoto_similarity_matrix_dense(references: np.ndarray, queries: np.ndarray # This is O(R*Q*avg_nnz_merge) and can be expensive for large R,Q. 
# For huge datasets prefer ANN (PyNNDescent/UMAP) with `tanimoto_distance_sparse`.
@numba.njit(parallel=True, fastmath=True, cache=True)
def tanimoto_similarity_matrix_sparse_binary(fingerprints_1, fingerprints_2) -> np.ndarray:
    """
    Compute the full Tanimoto similarity matrix between two collections of
    unfolded / sparse binary fingerprints.

    Parameters
    ----------
    fingerprints_1
        List of 1D numpy arrays of sorted bit indices (unique); one array per
        fingerprint. These index the rows of the result.
    fingerprints_2
        List of 1D numpy arrays of sorted bit indices (unique); one array per
        fingerprint. These index the columns of the result.

    Returns
    -------
    np.ndarray
        float32 array of shape ``(len(fingerprints_1), len(fingerprints_2))``
        where entry ``[i, j]`` is
        ``tanimoto_similarity_sparse_binary(fingerprints_1[i], fingerprints_2[j])``.
    """
    n_rows = len(fingerprints_1)
    n_cols = len(fingerprints_2)
    similarities = np.empty((n_rows, n_cols), dtype=np.float32)
    # Rows are distributed via prange; each row is filled sequentially.
    for i in numba.prange(n_rows):
        row_fp = fingerprints_1[i]
        for j in range(n_cols):
            similarities[i, j] = tanimoto_similarity_sparse_binary(row_fp, fingerprints_2[j])
    return similarities


@numba.njit(parallel=True, fastmath=True, cache=True)
def tanimoto_similarity_matrix_sparse(
        fingerprints_1_bits,
        fingerprints_1_vals,
        fingerprints_2_bits,
        fingerprints_2_vals
        ) -> np.ndarray:
    """
    Compute the full generalized Tanimoto similarity matrix between two
    collections of unfolded count/weight fingerprints.

    Parameters
    ----------
    fingerprints_1_bits
        List of 1D numpy arrays of sorted bit indices (unique) for the first
        set of fingerprints (rows of the result).
    fingerprints_1_vals
        List of 1D numpy arrays of counts/weights for the first set of
        fingerprints; entry ``i`` is passed together with ``fingerprints_1_bits[i]``.
    fingerprints_2_bits
        List of 1D numpy arrays of sorted bit indices (unique) for the second
        set of fingerprints (columns of the result).
    fingerprints_2_vals
        List of 1D numpy arrays of counts/weights for the second set of
        fingerprints; entry ``j`` is passed together with ``fingerprints_2_bits[j]``.

    Returns
    -------
    np.ndarray
        float32 array of shape
        ``(len(fingerprints_1_bits), len(fingerprints_2_bits))``.
    """
    n_rows = len(fingerprints_1_bits)
    n_cols = len(fingerprints_2_bits)
    similarities = np.empty((n_rows, n_cols), dtype=np.float32)
    # Rows are distributed via prange; each row is filled sequentially.
    for i in numba.prange(n_rows):
        row_bits = fingerprints_1_bits[i]
        row_vals = fingerprints_1_vals[i]
        for j in range(n_cols):
            similarities[i, j] = tanimoto_similarity_sparse(
                row_bits, row_vals,
                fingerprints_2_bits[j], fingerprints_2_vals[j],
            )
    return similarities
@dataclass(frozen=True)
class ClevelandStyle:
    """Immutable bundle of default styling values for a Cleveland-ish dot plot."""

    # --- figure ---
    figsize: Tuple[float, float] = (9.0, 6.0)  # (width, height)
    dpi: int = 600

    # --- dots ---
    markersize: float = 7.0
    markeredgecolor: str = "white"

    # --- connector lines drawn between dots of one group ---
    connect_linestyle: str = ":"
    connect_linewidth: float = 2.0
    connect_alpha: float = 0.9

    # --- background grid (per axis) ---
    grid_x_linestyle: str = ":"
    grid_y_linestyle: str = "-"
    grid_alpha_x: float = 0.6
    grid_alpha_y: float = 0.6
def cleveland_dotplot(
    *,
    # Data in "tidy" arrays
    row: Sequence[str],
    x: Sequence[float],
    color_group: Optional[Sequence[str]] = None,
    marker_group: Optional[Sequence[str]] = None,
    connect_group: Optional[Sequence[str]] = None,
    marker_zorder: Optional[Mapping[str, float]] = None,

    # Ordering / labels
    row_order: Optional[Sequence[str]] = None,
    row_label_fn=None,

    # Mappings
    color_map: Optional[Dict[str, str]] = None,
    marker_map: Optional[Dict[str, str]] = None,

    # Figure/axes
    title: str = "",
    xlabel: str = "",
    ax: Optional[Axes] = None,

    # Behavior
    connect: bool = True,
    show_zero_line_if_needed: bool = True,

    # Range line
    row_range: bool = True,
    row_range_color: str = "black",
    row_range_linewidth: float = 1.5,
    row_range_alpha: float = 0.75,

    # Legends
    show_legends: bool = True,
    color_legend_title: str = "Setting",
    marker_legend_title: str = "Variant",

    style: ClevelandStyle = ClevelandStyle(),
) -> Tuple[Figure, Axes]:
    """
    Generic Cleveland-ish dot plot.

    Every point is placed on the horizontal row named by `row` at position `x`.
    Optional groupings control dot color, marker shape and which points are
    joined by a connector line.

    Parameters
    ----------
    row, x
        One entry per point. `row` defines which horizontal row the point belongs to
        (values are converted with `str`), `x` is the numeric x-position.
    color_group
        Category that controls dot color (e.g., binary/count/logcount).
        Defaults to a single placeholder group "_".
    marker_group
        Category that controls marker shape (e.g., dense/sparse/fixed).
        Defaults to a single placeholder group "_".
    connect_group
        Group id used to connect points on the same row (e.g., same setting).
        Connection happens within (row, connect_group).
    marker_zorder
        Optional zorder per marker group; groups not listed use 3.0.
    row_order
        Explicit top-to-bottom tick order of the rows (index 0 is y=0).
        Defaults to order of first appearance in `row`.
    row_label_fn
        Optional callable mapping a row name to its y tick label.
    color_map
        Color per color group. Groups without an entry get a color from the
        matplotlib property cycle ("C0", "C1", ...); the same color is used
        for all their dots and for their legend entry. Connector lines of
        groups without an entry stay "gray".
    marker_map
        Marker per marker group; groups not listed use "o".
    title, xlabel
        Axes title and x-axis label.
    ax
        Axes to draw into. If None a new figure is created whose width is
        `style.figsize[0]` and whose height is derived from the number of
        rows (`style.figsize[1]` is not used).
    connect
        If True, draw a line between min(x) and max(x) within each (row, connect_group).
    show_zero_line_if_needed
        If True, draw a dashed vertical line at x=0 whenever min(x) <= 0.
    row_range, row_range_color, row_range_linewidth, row_range_alpha
        If `row_range` is True, draw a line from min(x) to max(x) over ALL
        points of each row that has at least two points, using the given style.
    show_legends, color_legend_title, marker_legend_title
        Whether to add legends for the marker and color groups, and their titles.
        A group consisting only of the "_" placeholder gets no legend.
    style
        Styling defaults (marker size, line styles, grid, dpi, ...).

    Returns
    -------
    (fig, ax)
        The figure and axes that were drawn into.

    Raises
    ------
    ValueError
        If `x` or one of the group sequences differs in length from `row`.
    KeyError
        If `row_order` is given but lacks a label that occurs in `row`.
    """
    from collections import defaultdict  # function-scope import keeps the block self-contained

    # --- normalize inputs ---
    row = list(map(str, row))
    x = np.asarray(x, dtype=float)

    n = len(row)
    if x.shape[0] != n:
        raise ValueError("row and x must have the same length")

    def _as_labels(values, name):
        """Return `values` as a list of str of length n (placeholder group if None)."""
        if values is None:
            return ["_"] * n
        labels = list(map(str, values))
        if len(labels) != n:
            raise ValueError(f"{name} must match length of row/x")
        return labels

    color_group = _as_labels(color_group, "color_group")
    marker_group = _as_labels(marker_group, "marker_group")
    connect_group = _as_labels(connect_group, "connect_group")

    # Default mappings
    if color_map is None:
        color_map = {}
    if marker_map is None:
        marker_map = {"_": "o"}  # default to circle
    if marker_zorder is None:
        marker_zorder = {}

    # Row order: explicit, or stable order of first appearance.
    if row_order is None:
        row_order = list(dict.fromkeys(row))
    else:
        row_order = list(map(str, row_order))

    ypos = {r: i for i, r in enumerate(row_order)}
    missing_rows = [r for r in dict.fromkeys(row) if r not in ypos]
    if missing_rows:
        # Same exception type the bare dict lookup used to raise, but with context.
        raise KeyError(
            f"row_order is missing row label(s) present in `row`: {missing_rows}"
        )
    y = np.array([ypos[r] for r in row], dtype=float)

    # --- axes setup ---
    if ax is None:
        # Height scales with the number of rows so tick labels stay readable.
        fig_h = max(2.5, len(row_order) * 0.28)
        fig, ax = plt.subplots(figsize=(style.figsize[0], fig_h), dpi=style.dpi)
    else:
        fig = ax.figure

    # --- row-range indicator (min->max across ALL points in the row) ---
    if row_range:
        xs_by_row = defaultdict(list)
        for r, xv in zip(row, x):
            xs_by_row[r].append(float(xv))

        for r in row_order:
            xs = xs_by_row.get(r, [])
            if len(xs) < 2:
                continue
            y0 = ypos[r]
            ax.plot(
                [min(xs), max(xs)],
                [y0, y0],
                color=row_range_color,
                linewidth=row_range_linewidth,
                alpha=row_range_alpha,
                zorder=0,  # behind connectors + points
            )

    # --- optional connectors: min->max x within each (row, connect_group) ---
    if connect:
        idx_by_key = defaultdict(list)
        for i, key in enumerate(zip(row, connect_group)):
            idx_by_key[key].append(i)

        for (r, _cg), idxs in idx_by_key.items():
            if len(idxs) < 2:
                continue
            xs = x[idxs]
            y0 = ypos[r]
            # Connector color follows the first point of the group.
            col = color_map.get(color_group[idxs[0]], "gray")
            ax.plot(
                [float(xs.min()), float(xs.max())],
                [y0, y0],
                linestyle=style.connect_linestyle,
                linewidth=style.connect_linewidth,
                alpha=style.connect_alpha,
                color=col,
                zorder=1,
            )

    # --- plot dots ---
    uniq_marker = list(dict.fromkeys(marker_group))
    uniq_color = list(dict.fromkeys(color_group))

    # Resolve ONE concrete color per color group up front. Passing color=None
    # to every ax.plot call would advance the property cycle per call, giving
    # the same color group different colors across marker groups and making
    # the legend disagree with the dots.
    dot_color = {}
    n_auto = 0
    for cg in uniq_color:
        col = color_map.get(cg)
        if col is None:
            col = f"C{n_auto}"  # n-th color of the matplotlib property cycle
            n_auto += 1
        dot_color[cg] = col

    marker_arr = np.asarray(marker_group, dtype=object)
    color_arr = np.asarray(color_group, dtype=object)

    # Plot by (marker_group, color_group) so style is consistent & legend-friendly.
    for mg in uniq_marker:
        m = marker_map.get(mg, "o")
        z = float(marker_zorder.get(mg, 3.0))
        has_marker = marker_arr == mg
        for cg in uniq_color:
            mask = has_marker & (color_arr == cg)
            if not mask.any():
                continue
            ax.plot(
                x[mask], y[mask],
                linestyle="None",
                marker=m,
                color=dot_color[cg],
                markersize=style.markersize,
                markeredgecolor=style.markeredgecolor,
                zorder=z,
                label=None,  # legends are handled manually
            )

    # zero line if needed
    if show_zero_line_if_needed and n > 0 and float(np.min(x)) <= 0.0:
        ax.axvline(0, linestyle="--", linewidth=1.5, color="black")

    # axes / labels
    if row_label_fn is None:
        tick_labels = list(row_order)
    else:
        tick_labels = [row_label_fn(r) for r in row_order]
    ax.set_yticks(range(len(row_order)))
    ax.set_yticklabels(tick_labels)
    ax.set_xlabel(xlabel)
    ax.set_title(title)
    ax.grid(True, axis="x", linestyle=style.grid_x_linestyle, alpha=style.grid_alpha_x)
    ax.grid(True, axis="y", linestyle=style.grid_y_linestyle, alpha=style.grid_alpha_y)
    ax.set_axisbelow(True)

    # --- legends ---
    if show_legends:
        # A lone "_" placeholder group means the caller gave no grouping: no legend.
        only_placeholder_marker = uniq_marker == ["_"]
        only_placeholder_color = uniq_color == ["_"]

        marker_handles = [
            Line2D([0], [0], marker=marker_map.get(mg, "o"),
                   color="black", linestyle="None", label=str(mg))
            for mg in uniq_marker
            if not only_placeholder_marker
        ]
        color_handles = [
            Line2D([0], [0], marker="o", color=dot_color[cg],
                   linestyle="None", label=str(cg))
            for cg in uniq_color
            if not only_placeholder_color
        ]

        if marker_handles and color_handles:
            # ax.legend() replaces the current legend, so keep the first one
            # alive by re-adding it as a plain artist.
            leg1 = ax.legend(handles=marker_handles, loc="lower right",
                             title=marker_legend_title, frameon=True)
            ax.add_artist(leg1)
            ax.legend(handles=color_handles, loc="lower left",
                      title=color_legend_title, frameon=True)
        elif marker_handles:
            ax.legend(handles=marker_handles, loc="lower right",
                      title=marker_legend_title, frameon=True)
        elif color_handles:
            ax.legend(handles=color_handles, loc="lower left",
                      title=color_legend_title, frameon=True)

    return fig, ax
def test_row_range_indicator_drawn_per_row_with_2plus_points():
    """Only rows holding at least two points get a min->max range line."""
    # "A" contributes three points (range line expected), "B" a single one (none).
    rows = ["A", "A", "A", "B"]
    xs = [10, 20, 5, 7]

    _, ax = cleveland_dotplot(
        row=rows,
        x=xs,
        row_range=True,
        connect=False,
        show_legends=False,
        show_zero_line_if_needed=False,
    )

    # With connectors off, the range lines are the only artists at zorder 0.
    zorder0_lines = _find_lines(ax, zorder=0)
    assert len(zorder0_lines) == 1

    range_line = zorder0_lines[0]
    line_x = np.asarray(range_line.get_xdata(), dtype=float)
    line_y = np.asarray(range_line.get_ydata(), dtype=float)

    # Spans min(x)..max(x) of row "A".
    assert np.allclose(line_x, [5.0, 20.0])
    # "A" appears first, so with the default (appearance) order it sits at y == 0.
    assert np.allclose(line_y, [0.0, 0.0])
def test_marker_zorder_applied_per_marker_group():
    """Each marker group is drawn with its own marker and its configured zorder."""
    # marker group -> (marker symbol, zorder)
    expected = {"dense": ("o", 3.0), "sparse": ("^", 5.0)}

    _, ax = cleveland_dotplot(
        row=["A", "A"],
        x=[1.0, 1.0],
        marker_group=["dense", "sparse"],
        color_group=["binary", "binary"],
        marker_map={group: symbol for group, (symbol, _) in expected.items()},
        marker_zorder={group: z for group, (_, z) in expected.items()},
        connect=False,
        row_range=False,
        show_legends=False,
        show_zero_line_if_needed=False,
    )

    # Dot artists are Line2D objects with linestyle "None"; one per marker group here.
    for symbol, zorder in expected.values():
        dot_lines = _find_lines(ax, linestyle="None", marker=symbol)
        assert len(dot_lines) == 1
        assert dot_lines[0].get_zorder() == zorder
def test_legends_created_when_enabled_and_groups_present():
    """With real marker AND color groups, both legends must be on the axes."""
    marker_map = {"dense": "o", "sparse": "^"}
    color_map = {"binary": "crimson", "count": "teal"}

    fig, ax = cleveland_dotplot(
        row=["A", "A", "B", "B"],
        x=[1.0, 2.0, 3.0, 4.0],
        marker_group=["dense", "sparse", "dense", "sparse"],
        color_group=["binary", "binary", "count", "count"],
        marker_map=marker_map,
        color_map=color_map,
        connect=False,
        row_range=False,
        show_legends=True,
        show_zero_line_if_needed=False,
    )

    # When both marker and color handles exist:
    # - the marker legend is kept alive via ax.add_artist(...),
    # - the color legend is the "current" legend returned by ax.get_legend().
    # Asserting the exact count (not ">= 1") catches a regression where one of
    # the two legends gets silently replaced by the other.
    legends = [ch for ch in ax.get_children() if isinstance(ch, Legend)]
    assert len(legends) == 2

    titles = {leg.get_title().get_text() for leg in legends}
    assert titles == {"Variant", "Setting"}  # default legend titles

    current = ax.get_legend()
    assert current is not None
    assert current.get_title().get_text() == "Setting"  # color legend is created last