diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/.gitignore b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/.gitignore new file mode 100644 index 0000000000..f4daeb0de8 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/.gitignore @@ -0,0 +1,4 @@ +node_modules/ +package-lock.json +dist/ +.vite/ diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/README.md b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/README.md new file mode 100644 index 0000000000..cb3c16e12f --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/README.md @@ -0,0 +1,20 @@ +# Evo2 SAE Feature Explorer (front-end) + +Interactive dashboard for Evo2 SAE features — feature atlas, sequence inspector, and +generative steering. + +This directory is the **front-end only**. Its backend is the standalone +[`evo2_sae_infer`](../evo2_sae_infer) engine — the viz is just a UI over its +`serve` mode, so there is no model code here. + +```bash +# 1. Backend: loads Evo2 + the SAE and serves the HTTP API on :8001 +../scripts/launch_inference.sh serve # or: python -m evo2_sae_infer serve + +# 2. Front-end (this directory) +npm install && npm run dev # Vite dev server +``` + +The Vite dev server proxies `/api` → `http://localhost:8001` (see `vite.config.js`); +point it elsewhere with the `VITE_BACKEND` env var. Configure the backend (checkpoint, +SAE, layer, feature annotations) via the env vars documented in `launch_inference.sh`. diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/index.html b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/index.html new file mode 100644 index 0000000000..c1eacdeb66 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/index.html @@ -0,0 +1,104 @@ + + + + + + Evo 2 SAE Feature Explorer + + + +
+ + + diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/package.json b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/package.json new file mode 100644 index 0000000000..53674056a3 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/package.json @@ -0,0 +1,25 @@ +{ + "name": "evo2-dashboard-mockup", + "version": "0.1.0", + "private": true, + "type": "module", + "scripts": { + "dev": "vite", + "build": "vite build", + "preview": "vite preview" + }, + "dependencies": { + "@uwdata/mosaic-core": "^0.21.1", + "@uwdata/mosaic-sql": "^0.21.1", + "@uwdata/vgplot": "^0.21.1", + "embedding-atlas": "^0.16.1", + "lucide-react": "^0.577.0", + "react": "^18.2.0", + "react-dom": "^18.2.0", + "umap-js": "^1.4.0" + }, + "devDependencies": { + "@vitejs/plugin-react": "^4.2.0", + "vite": "^5.0.0" + } +} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/public/feature_examples.parquet b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/public/feature_examples.parquet new file mode 100644 index 0000000000..0c46e8ca8b Binary files /dev/null and b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/public/feature_examples.parquet differ diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/public/feature_metadata.parquet b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/public/feature_metadata.parquet new file mode 100644 index 0000000000..b4ae716182 Binary files /dev/null and b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/public/feature_metadata.parquet differ diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/public/features_atlas.parquet b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/public/features_atlas.parquet new file mode 100644 index 0000000000..825f08963c Binary files /dev/null and b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/public/features_atlas.parquet differ diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/src/App.jsx b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/src/App.jsx new file mode 100644 index 0000000000..524e803a15 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/src/App.jsx @@ -0,0 +1,1352 @@ +import React, { useState, useEffect, useRef, useMemo, useCallback } from 'react' +import * as vg from '@uwdata/vgplot' +import { wasmConnector, MosaicClient } from '@uwdata/mosaic-core' +import { Query, sql, literal } from '@uwdata/mosaic-sql' +import FeatureCard from './FeatureCard' +import FeatureList from './FeatureList' +import EmbeddingView from './EmbeddingView' +import Histogram from './Histogram' +import InfoButton from './InfoButton' +import { Sun, Moon } from 'lucide-react' +import { styles } from './styles' + +export default function App({ title = "Evo 2 SAE Feature Explorer", subtitle = "Real SAE features — Evo 2 1B, layer 19" }) { + const [darkMode, setDarkMode] = useState(true) + + // Toggle dark class on document root + useEffect(() => { + document.documentElement.classList.toggle('dark', darkMode) + }, [darkMode]) + + const [features, setFeatures] = useState([]) + const [loading, setLoading] = useState(true) + const [loadingProgress, setLoadingProgress] = useState({ step: 0, total: 4, message: 'Starting up...' }) + const [error, setError] = useState(null) + const [sortBy, setSortBy] = useState('frequency') + const [selectedFeatureIds, setSelectedFeatureIds] = useState(null) // null = all selected + const [mosaicReady, setMosaicReady] = useState(false) + const [categoryColumns, setCategoryColumns] = useState([]) + const [selectedCategory, setSelectedCategory] = useState('cluster_id') + const [hiddenCategories, setHiddenCategories] = useState(new Set()) + const [clickedFeatureId, setClickedFeatureId] = useState(null) + const [clusterLabels, setClusterLabels] = useState(null) + const [vocabLogits, setVocabLogits] = useState(null) + const [featureAnalysis, setFeatureAnalysis] = useState(null) + + const brushRef = useRef(null) + const [showGuideModal, setShowGuideModal] = useState(false) + const [showMetricsModal, setShowMetricsModal] = useState(false) + const [searchTerm, setSearchTerm] = useState('') + const [cardResetKey, setCardResetKey] = useState(0) + const [plotResetKey, setPlotResetKey] = useState(0) + const [viewportState, setViewportState] = useState(null) // null = let embedding-atlas auto-fit on first load + const [displayedCardCount, setDisplayedCardCount] = useState(20) // Pagination: start with 20 cards + const [showEditedOnly, setShowEditedOnly] = useState(false) // Filter for edited features only + const [histMetric1, setHistMetric1] = useState('log_frequency') + const [histMetric2, setHistMetric2] = useState('max_activation') + const [histMetric3, setHistMetric3] = useState('cluster_id') // tracks color-by selection + const featureRefs = useRef({}) + const featureListRef = useRef(null) + const endOfListRef = useRef(null) + const searchSource = useRef({ source: 'search' }) + const editedSource = useRef({ source: 'edited' }) + const legendSource = useRef({ source: 'legend' }) + const loadingMoreRef = useRef(false) + + // Lazy-load examples for a single feature from DuckDB (feature_examples VIEW) + const loadExamplesForFeature = useCallback(async (featureId) => { + const result = await vg.coordinator().query( + `SELECT * FROM feature_examples WHERE feature_id = ${featureId} ORDER BY example_rank` + ) + return result.toArray().map(row => ({ + sequence_id: row.sequence_id, + start: row.start, + end: row.end, + sequence: row.sequence, + activations: Array.from(row.activations), + max_activation: row.max_activation, + best_annotation: row.best_annotation, + })) + }, []) + + // Intersection Observer for infinite scroll pagination + useEffect(() => { + const sentinel = endOfListRef.current + const scrollContainer = featureListRef.current + if (!sentinel || !scrollContainer) return + + const observer = new IntersectionObserver( + entries => { + console.log('[scroll] sentinel intersecting:', entries[0].isIntersecting, 'loadingMore:', loadingMoreRef.current) + if (entries[0].isIntersecting && !loadingMoreRef.current) { + loadingMoreRef.current = true + setDisplayedCardCount(prev => prev + 20) + // Reset flag after a delay to allow next batch + setTimeout(() => { + loadingMoreRef.current = false + }, 300) + } + }, + { root: scrollContainer, threshold: 0.1, rootMargin: '200px' } + ) + + observer.observe(sentinel) + + return () => { + observer.disconnect() + } + }, [mosaicReady]) + + // Handle click on a feature in the UMAP (or null for empty canvas click) + const animationRef = useRef(null) + const currentViewportRef = useRef(null) + const initialViewportRef = useRef(null) + + // Handle viewport changes from the UMAP component + const handleViewportChange = useCallback((vp) => { + // Capture initial viewport on first report, slightly zoomed out so all points fit + if (!initialViewportRef.current && vp) { + initialViewportRef.current = { ...vp, scale: vp.scale * 0.5 } + setViewportState(initialViewportRef.current) + currentViewportRef.current = { ...initialViewportRef.current } + } + // Clamp zoom to max scale of 5 + if (vp && vp.scale > 5) { + const clamped = { ...vp, scale: 5 } + setViewportState(clamped) + currentViewportRef.current = clamped + return + } + // Always track current viewport (but not during our own animations) + if (!animationRef.current) { + currentViewportRef.current = vp + } + }, []) + + // Handle click on a feature in the UMAP (with coordinates for zooming) + const handleFeatureClick = useCallback((featureId, x, y) => { + + setClickedFeatureId(featureId) + + if (featureId == null) return + + // Scroll to the feature card + setTimeout(() => { + const ref = featureRefs.current[featureId] + if (ref) { + ref.scrollIntoView({ behavior: 'smooth', block: 'center' }) + } + }, 50) + }, []) + + // Handle click on a feature card (highlights point in UMAP, no zoom) + const handleCardClick = useCallback(async (featureId, isExpanding) => { + + if (!isExpanding) { + setClickedFeatureId(null) + return + } + + setClickedFeatureId(featureId) + }, []) + + // Initialize Mosaic and load data + useEffect(() => { + async function init() { + try { + // Step 1: Initialize DuckDB-WASM + setLoadingProgress({ step: 1, total: 4, message: 'Initializing database engine...' }) + const wasm = wasmConnector() + vg.coordinator().databaseConnector(wasm) + + // Step 2: Load parquet data + setLoadingProgress({ step: 2, total: 4, message: 'Loading embedding data...' }) + const urlParams = new URLSearchParams(window.location.search) + const dataPath = urlParams.get('data') || '/features_atlas.parquet' + const parquetUrl = dataPath.startsWith('http') + ? dataPath + : new URL(dataPath, window.location.origin).href + + + await vg.coordinator().exec(` + CREATE TABLE features AS + SELECT * FROM read_parquet('${parquetUrl}') + `) + + // HDBSCAN assigns -1 to noise points; embedding-atlas casts category + // columns to UTINYINT which can't hold negatives. Remap to NULL. + try { + await vg.coordinator().exec(` + UPDATE features SET cluster_id = NULL WHERE cluster_id < 0 + `) + } catch (e) { + // cluster_id column may not exist — that's fine + } + + // Step 3: Process columns and categories + setLoadingProgress({ step: 3, total: 4, message: 'Processing columns...' }) + const schemaResult = await vg.coordinator().query(` + SELECT column_name, column_type + FROM (DESCRIBE features) + `) + + const columns = schemaResult.toArray().map(row => ({ + name: row.column_name, + type: row.column_type + })) + + const detectedCategories = [] + const sequentialColumns = [] + + for (const col of columns) { + if (['x', 'y', 'feature_id', 'top_example_idx', 'logo_path'].includes(col.name)) continue + + if (col.type === 'VARCHAR') { + const isGsea = col.name.startsWith('gsea_') + const maxUnique = isGsea ? Infinity : 50 + const cardinalityResult = await vg.coordinator().query(` + SELECT COUNT(DISTINCT "${col.name}") as n_unique FROM features WHERE "${col.name}" IS NOT NULL AND "${col.name}" != 'unlabeled' + `) + const nUnique = cardinalityResult.toArray()[0]?.n_unique ?? 0 + if (nUnique > 0 && nUnique <= maxUnique) { + // For high-cardinality GSEA columns, collapse to top 20 + "other" + if (isGsea && nUnique > 20) { + await vg.coordinator().exec(` + CREATE OR REPLACE TABLE features AS + SELECT * REPLACE ( + CASE + WHEN "${col.name}" IS NULL OR "${col.name}" = 'unlabeled' THEN 'unlabeled' + WHEN "${col.name}" IN ( + SELECT "${col.name}" FROM features + WHERE "${col.name}" IS NOT NULL AND "${col.name}" != 'unlabeled' + GROUP BY "${col.name}" ORDER BY COUNT(*) DESC LIMIT 20 + ) THEN "${col.name}" + ELSE 'other' + END AS "${col.name}" + ) FROM features + `) + detectedCategories.push({ name: col.name, type: 'string', nUnique: 22 }) + } else { + detectedCategories.push({ name: col.name, type: 'string', nUnique }) + } + } + } else if (col.type === 'BIGINT' || col.type === 'INTEGER') { + if (col.name.includes('cluster') || col.name.includes('category') || col.name.includes('group')) { + const cardinalityResult = await vg.coordinator().query(` + SELECT COUNT(DISTINCT "${col.name}") as n_unique FROM features WHERE "${col.name}" IS NOT NULL + `) + const nUnique = cardinalityResult.toArray()[0]?.n_unique ?? 0 + if (nUnique > 0 && nUnique <= 50) { + detectedCategories.push({ name: col.name, type: 'integer', nUnique }) + } + } + } else if (col.type === 'DOUBLE' || col.type === 'FLOAT') { + // Numeric columns for sequential coloring + if (['log_frequency', 'max_activation', 'activation_freq', 'frequency', + 'mean_variant_1bcdwt', + 'high_score_fraction', 'clinvar_fraction', + 'mean_phylop', 'mean_variant_delta', 'mean_site_delta', 'mean_local_delta', + 'high_score_delta', 'low_score_delta', + 'gc_mean', 'gc_std', + 'trinuc_entropy', 'trinuc_dominant_frac', + 'pli_mean_pli', 'pli_frac_constrained', 'pli_max_pli', + 'codon_cai', 'codon_tai', 'codon_rscu', + 'gene_entropy', 'gene_n_unique', 'gene_dominant_frac', + ].includes(col.name)) { + sequentialColumns.push({ name: col.name, type: 'sequential' }) + } + } + } + + // Create integer-encoded versions of string category columns + for (const col of detectedCategories) { + if (col.type === 'string') { + await vg.coordinator().exec(` + CREATE OR REPLACE TABLE features AS + SELECT *, + CASE WHEN "${col.name}" IS NULL THEN NULL + ELSE DENSE_RANK() OVER (ORDER BY "${col.name}") - 1 + END AS "${col.name}_cat" + FROM features + `) + } + } + + // Create binned versions of sequential columns (10 bins) + const NUM_BINS = 10 + for (const col of sequentialColumns) { + await vg.coordinator().exec(` + CREATE OR REPLACE TABLE features AS + SELECT *, + CASE WHEN "${col.name}" IS NULL THEN NULL + ELSE LEAST(${NUM_BINS - 1}, CAST( + (("${col.name}" - (SELECT MIN("${col.name}") FROM features)) / + NULLIF((SELECT MAX("${col.name}") - MIN("${col.name}") FROM features), 0)) * ${NUM_BINS} + AS INTEGER)) + END AS "${col.name}_bin" + FROM features + `) + detectedCategories.push({ name: col.name, type: 'sequential', nUnique: NUM_BINS }) + } + + setCategoryColumns(detectedCategories) + + // Create crossfilter selection + brushRef.current = vg.Selection.crossfilter() + + + // Step 4: Load feature metadata from parquet via DuckDB + setLoadingProgress({ step: 4, total: 4, message: 'Loading feature metadata...' }) + const metaUrl = new URL('/feature_metadata.parquet', window.location.origin).href + const examplesUrl = new URL('/feature_examples.parquet', window.location.origin).href + + await vg.coordinator().exec(` + CREATE TABLE IF NOT EXISTS feature_metadata AS + SELECT * FROM read_parquet('${metaUrl}') + `) + await vg.coordinator().exec(` + CREATE VIEW IF NOT EXISTS feature_examples AS + SELECT * FROM read_parquet('${examplesUrl}') + `) + + // Load features from the features table (which has labels + category columns) + const categorySelectCols = detectedCategories + .filter(c => c.type === 'string' || c.type === 'integer') + .map(c => `"${c.name}"`) + .join(', ') + const extraSelect = categorySelectCols ? `, ${categorySelectCols}` : '' + // logo_path is optional — older parquets won't have it, so detect and + // include it only if the column exists. + const hasLogoPath = columns.some(c => c.name === 'logo_path') + const logoSelect = hasLogoPath ? ', logo_path' : '' + const featuresResult = await vg.coordinator().query(` + SELECT + feature_id, + label, + activation_freq, + max_activation, + x, + y + ${logoSelect} + ${extraSelect} + FROM features + ORDER BY feature_id + `) + const loadedFeatures = featuresResult.toArray().map(row => { + const f = { + feature_id: row.feature_id, + label: row.label, + description: row.label, + activation_freq: row.activation_freq, + max_activation: row.max_activation, + x: row.x, + y: row.y, + logo_path: row.logo_path, + } + for (const col of detectedCategories) { + if (col.type === 'string' || col.type === 'integer') { + f[col.name] = row[col.name] + } + } + return f + }) + setFeatures(loadedFeatures) + + // Generate cluster labels from DuckDB (non-fatal if cluster_id doesn't exist) + try { + const clusterResult = await vg.coordinator().query(` + SELECT + cluster_id, + AVG(x) as cx, + AVG(y) as cy, + MODE(label) as top_label, + COUNT(*) as n + FROM features + WHERE cluster_id IS NOT NULL + GROUP BY cluster_id + ORDER BY n DESC + `) + const labels = clusterResult.toArray() + .filter(row => row.top_label && !row.top_label.startsWith('Feature ')) + .map((row, i) => ({ + x: Number(row.cx), + y: Number(row.cy), + text: row.top_label.length > 40 ? row.top_label.slice(0, 40) + '...' : row.top_label, + priority: row.n, + level: 0, + })) + console.log('[cluster labels] generated:', labels.length, labels.slice(0, 5)) + if (labels.length > 0) { + setClusterLabels(labels) + } + } catch (e) { + console.log('[cluster labels] query failed:', e.message) + } + + // Load cluster labels from file (overrides computed ones if present) + try { + const labelsRes = await fetch('./cluster_labels.json') + if (labelsRes.ok) { + const labelsData = await labelsRes.json() + setClusterLabels(labelsData) + } + } catch (labelErr) { + } + + // Load vocab logits (non-fatal if missing) + try { + const logitsRes = await fetch('./vocab_logits.json') + if (logitsRes.ok) { + const logitsData = await logitsRes.json() + setVocabLogits(logitsData) + } + } catch (e) { + } + + // Load feature analysis (non-fatal if missing) + try { + const analysisRes = await fetch('./feature_analysis.json') + if (analysisRes.ok) { + const analysisData = await analysisRes.json() + setFeatureAnalysis(analysisData) + } + } catch (e) { + } + + setMosaicReady(true) + setLoading(false) + + } catch (err) { + console.error('Init error:', err) + setError(err.message) + setLoading(false) + } + } + + init() + }, []) + + // Create a Mosaic client that receives filtered feature IDs + useEffect(() => { + if (!mosaicReady || !brushRef.current) return + + const coordinator = vg.coordinator() + const selection = brushRef.current + const totalFeatures = features.length + + // Create a class that extends MosaicClient + class FeatureFilterClient extends MosaicClient { + constructor(filterBy) { + super(filterBy) + this._isConnected = true + } + + query(filter = []) { + // Use Mosaic's Query builder + const q = Query + .select({ feature_id: 'feature_id' }) + .distinct() + .from('features') + + // Apply filter if present + if (filter.length > 0) { + q.where(filter) + } + + return q + } + + queryResult(data) { + if (!this._isConnected) return + + try { + let ids = new Set() + if (data && typeof data.getChild === 'function') { + const col = data.getChild('feature_id') + if (col) { + for (let i = 0; i < col.length; i++) { + ids.add(col.get(i)) + } + } + } else if (data && data.toArray) { + ids = new Set(data.toArray().map(r => r.feature_id)) + } + setSelectedFeatureIds(ids.size > 0 && ids.size < totalFeatures ? ids : null) + } catch (err) { + console.error('Error processing result:', err) + } + } + + // Required by Mosaic for selection updates + update() { + return this + } + + queryError(err) { + if (this._isConnected) { + console.error('FeatureFilterClient error:', err) + } + } + + disconnect() { + this._isConnected = false + } + } + + const client = new FeatureFilterClient(selection) + + // Delay connection slightly to ensure Mosaic is fully ready + const timeoutId = setTimeout(() => { + try { + coordinator.connect(client) + } catch (err) { + console.warn('Error connecting FeatureFilterClient:', err) + } + }, 0) + + return () => { + clearTimeout(timeoutId) + try { + client.disconnect() + coordinator.disconnect(client) + } catch (err) { + // Ignore disconnect errors + } + } + }, [mosaicReady, features.length]) + + // Clear ALL selections (search, histograms, UMAP, clicked feature) + const handleClearSelection = useCallback(() => { + if (brushRef.current) { + const selection = brushRef.current + // Clear each clause by updating with null predicate for each source + const clauses = selection.clauses || [] + for (const clause of clauses) { + if (clause.source) { + try { + selection.update({ source: clause.source, predicate: null, value: null }) + } catch (e) { + // Ignore errors from clearing + } + } + } + // Also clear the search clause specifically + if (searchSource.current) { + try { + selection.update({ source: searchSource.current, predicate: null, value: null }) + } catch (e) { + // Ignore + } + } + } + setSelectedFeatureIds(null) + setSearchTerm('') + setClickedFeatureId(null) + setHiddenCategories(new Set()) + // Reset viewport to the auto-fit view captured on first load + if (initialViewportRef.current) { + setViewportState({ ...initialViewportRef.current }) + currentViewportRef.current = { ...initialViewportRef.current } + } else { + setViewportState(null) + currentViewportRef.current = null + } + // Reset all cards to collapsed state + setCardResetKey(k => k + 1) + // Reset histograms and UMAP to clear brush visuals + setPlotResetKey(k => k + 1) + }, []) + + // Export all edited features to CSV with full data + const handleExportEdited = useCallback(() => { + // Get all edited features + const editedFeatures = features.filter(f => localStorage.getItem(`featureTitle_${f.feature_id}`) !== null) + + if (editedFeatures.length === 0) { + alert('No edited features to export') + return + } + + const lines = [] + const escapeCsv = (str) => `"${(str || '').toString().replace(/"/g, '""')}"` + + // Codon mapping for amino acids + const CODON_AA = { + 'TTT':'F','TTC':'F','TTA':'L','TTG':'L','TCT':'S','TCC':'S','TCA':'S','TCG':'S', + 'TAT':'Y','TAC':'Y','TAA':'*','TAG':'*','TGT':'C','TGC':'C','TGA':'*','TGG':'W', + 'CTT':'L','CTC':'L','CTA':'L','CTG':'L','CCT':'P','CCC':'P','CCA':'P','CCG':'P', + 'CAT':'H','CAC':'H','CAA':'Q','CAG':'Q','CGT':'R','CGC':'R','CGA':'R','CGG':'R', + 'ATT':'I','ATC':'I','ATA':'I','ATG':'M','ACT':'T','ACC':'T','ACA':'T','ACG':'T', + 'AAT':'N','AAC':'N','AAA':'K','AAG':'K','AGT':'S','AGC':'S','AGA':'R','AGG':'R', + 'GTT':'V','GTC':'V','GTA':'V','GTG':'V','GCT':'A','GCC':'A','GCA':'A','GCG':'A', + 'GAT':'D','GAC':'D','GAA':'E','GAG':'E','GGT':'G','GGC':'G','GGA':'G','GGG':'G', + } + + editedFeatures.forEach((f, idx) => { + const userTitle = localStorage.getItem(`featureTitle_${f.feature_id}`) + const label = f.label || `Feature ${f.feature_id}` + + // Add separator for readability + if (idx > 0) lines.push('') + + // Feature metadata + lines.push(`=== FEATURE ${f.feature_id} ===`) + lines.push(`Feature ID,${f.feature_id}`) + lines.push(`Original Label,${escapeCsv(label)}`) + lines.push(`Your Title,${escapeCsv(userTitle)}`) + lines.push(`Activation Frequency,${(f.activation_freq || 0).toFixed(6)}`) + lines.push(`Max Activation,${(f.max_activation || 0).toFixed(4)}`) + lines.push('') + + // Vocab logits + const logits = vocabLogits?.[String(f.feature_id)] + if (logits) { + lines.push('TOP PROMOTED CODONS') + lines.push('Codon,Amino Acid,Logit Value') + ;(logits.top_positive || []).forEach(([codon, val]) => { + lines.push(`${codon},${CODON_AA[codon] || '?'},${val.toFixed(4)}`) + }) + lines.push('') + + lines.push('TOP SUPPRESSED CODONS') + lines.push('Codon,Amino Acid,Logit Value') + ;(logits.top_negative || []).forEach(([codon, val]) => { + lines.push(`${codon},${CODON_AA[codon] || '?'},${val.toFixed(4)}`) + }) + lines.push('') + } + + // Feature analysis + const analysis = featureAnalysis?.[String(f.feature_id)] + if (analysis?.codon_annotations) { + lines.push('CODON ANNOTATIONS') + const ann = analysis.codon_annotations + if (ann.amino_acid) { + lines.push(`Amino Acid,${ann.amino_acid.aa}`) + lines.push(`AA Frequency,${(ann.amino_acid.fraction * 100).toFixed(1)}%`) + } + if (ann.codon_usage) { + lines.push(`Codon Usage,${ann.codon_usage.bias}`) + } + if (ann.wobble) { + lines.push(`Wobble Position,${ann.wobble.preference}`) + } + if (ann.cpg) { + lines.push(`CpG Context,${ann.cpg.fraction}`) + } + lines.push('') + } + }) + + // Create and download file + const csv = lines.join('\n') + const blob = new Blob([csv], { type: 'text/csv' }) + const url = URL.createObjectURL(blob) + const a = document.createElement('a') + a.href = url + a.download = `edited_features_${new Date().toISOString().split('T')[0]}.csv` + document.body.appendChild(a) + a.click() + document.body.removeChild(a) + URL.revokeObjectURL(url) + }, [features, vocabLogits, featureAnalysis]) + + // Update Mosaic crossfilter when "Edited Only" toggle changes + useEffect(() => { + if (!brushRef.current || !mosaicReady) return + + const selection = brushRef.current + + if (showEditedOnly) { + // Get all edited feature IDs from localStorage + const editedIds = features + .filter(f => localStorage.getItem(`featureTitle_${f.feature_id}`) !== null) + .map(f => f.feature_id) + + if (editedIds.length > 0) { + // Create predicate: feature_id IN (id1, id2, id3, ...) + const idsStr = editedIds.join(',') + // Use raw SQL string, not literal() which would quote it as a string + const predicateSql = `feature_id IN (${idsStr})` + + try { + selection.update({ + source: editedSource.current, + predicate: predicateSql, + value: 'edited' + }) + } catch (err) { + console.warn('Error updating edited filter:', err) + } + } + } else { + // Clear the edited filter + try { + selection.update({ + source: editedSource.current, + predicate: null, + value: null + }) + } catch (err) { + console.warn('Error clearing edited filter:', err) + } + } + }, [showEditedOnly, mosaicReady, features]) + + // Update Mosaic crossfilter when legend selection changes + useEffect(() => { + if (!brushRef.current || !mosaicReady) return + + const selection = brushRef.current + + if (hiddenCategories.size > 0 && selectedCategory && selectedCategory !== 'none') { + const colInfo = categoryColumns.find(c => c.name === selectedCategory) + if (colInfo && (colInfo.type === 'string' || colInfo.type === 'integer')) { + const values = Array.from(hiddenCategories).map(v => `'${v.replace(/'/g, "''")}'`).join(',') + const predicateSql = `"${selectedCategory}" IN (${values})` + + try { + selection.update({ + source: legendSource.current, + predicate: predicateSql, + value: Array.from(hiddenCategories).join(',') + }) + } catch (err) { + console.warn('Legend filter update failed:', err) + } + } + } else { + try { + selection.update({ + source: legendSource.current, + predicate: null, + value: null + }) + } catch (err) { + // Ignore + } + } + }, [hiddenCategories, selectedCategory, mosaicReady, categoryColumns]) + + // Handle search - updates both Mosaic crossfilter (for UMAP/histograms) and local state (for cards) + const handleSearchChange = useCallback((e) => { + const term = e.target.value + setSearchTerm(term) + + // Also update Mosaic crossfilter so UMAP and histograms filter + if (brushRef.current) { + const selection = brushRef.current + + try { + if (term.trim()) { + // Build predicate using sql template - ILIKE for case-insensitive search + const pattern = literal('%' + term.trim() + '%') + const predicate = sql`label ILIKE ${pattern}` + + selection.update({ + source: searchSource.current, + predicate: predicate, + value: term.trim() + }) + } else { + // Clear search by removing the clause + selection.update({ + source: searchSource.current, + predicate: null, + value: null + }) + } + } catch (err) { + console.warn('Search update error:', err) + } + } + }, []) + + // Filter and sort features + const filteredFeatures = useMemo(() => { + let result = features + + // Filter by Mosaic selection (includes UMAP brush) + if (selectedFeatureIds !== null) { + result = result.filter(f => selectedFeatureIds.has(f.feature_id)) + } + + // Also filter by search term client-side (searches metadata fields) + if (searchTerm.trim()) { + const q = searchTerm.toLowerCase() + result = result.filter(f => + f.description?.toLowerCase().includes(q) || + f.feature_id.toString().includes(q) || + f.best_annotation?.toLowerCase().includes(q) + ) + } + + // Filter by edited features only + if (showEditedOnly) { + result = result.filter(f => localStorage.getItem(`featureTitle_${f.feature_id}`) !== null) + } + + // Helper: unlabeled features sort last + const isUnlabeled = (f) => { + const lbl = (f.label || f.description || '').toLowerCase() + return !lbl || lbl.startsWith('feature ') || lbl.includes('common codons') + } + + // Sort (labeled features first, then by chosen metric) + if (sortBy === 'frequency') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || (b.activation_freq || 0) - (a.activation_freq || 0)) + } else if (sortBy === 'max_activation') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || (b.max_activation || 0) - (a.max_activation || 0)) + } else if (sortBy === 'feature_id') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || a.feature_id - b.feature_id) + } else if (sortBy === 'high_score_fraction') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || (b.high_score_fraction || 0) - (a.high_score_fraction || 0)) + } else if (sortBy === 'mean_variant_delta') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || Math.abs(b.mean_variant_delta || 0) - Math.abs(a.mean_variant_delta || 0)) + } else if (sortBy === 'mean_site_delta') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || Math.abs(b.mean_site_delta || 0) - Math.abs(a.mean_site_delta || 0)) + } else if (sortBy === 'mean_local_delta') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || Math.abs(b.mean_local_delta || 0) - Math.abs(a.mean_local_delta || 0)) + } else if (sortBy === 'clinvar_fraction') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || (b.clinvar_fraction || 0) - (a.clinvar_fraction || 0)) + } else if (sortBy === 'mean_phylop') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || (b.mean_phylop || 0) - (a.mean_phylop || 0)) + } else if (sortBy === 'gc_mean') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || Math.abs((b.gc_mean || 0.5) - 0.5) - Math.abs((a.gc_mean || 0.5) - 0.5)) + } else if (sortBy === 'trinuc_entropy') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || (a.trinuc_entropy ?? 99) - (b.trinuc_entropy ?? 99)) + } else if (sortBy === 'gene_entropy') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || (a.gene_entropy ?? 99) - (b.gene_entropy ?? 99)) + } else if (sortBy === 'gene_n_unique') { + result = [...result].sort((a, b) => isUnlabeled(a) - isUnlabeled(b) || (a.gene_n_unique || 999) - (b.gene_n_unique || 999)) + } + + return result + }, [features, sortBy, selectedFeatureIds, searchTerm, showEditedOnly]) + + // Reset pagination when filters change + useEffect(() => { + setDisplayedCardCount(20) + loadingMoreRef.current = false + }, [searchTerm, sortBy, selectedFeatureIds, showEditedOnly]) + + if (loading) { + const pct = Math.round(((loadingProgress.step - 1) / loadingProgress.total) * 100) + return ( +
+
Loading dashboard...
+
+
+
+
{loadingProgress.message}
+
+ ) + } + + if (error) { + return ( +
+

Error: {error}

+

+ Make sure features_atlas.parquet, feature_metadata.parquet, and feature_examples.parquet exist in the public/ folder. +

+
+ ) + } + + return ( +
+
+
+

{subtitle}

+
+
+ + +
+
+ +
+
+
+
+ + Decoder UMAP + +
+ + + setShowMetricsModal(true)} + style={{ + display: 'inline-flex', alignItems: 'center', justifyContent: 'center', + width: '15px', height: '15px', borderRadius: '50%', border: '1px solid var(--border-input)', + fontSize: '10px', fontWeight: '600', color: 'var(--text-tertiary)', cursor: 'pointer', + userSelect: 'none', lineHeight: 1, flexShrink: 0, + }} + >i + +
+
+
+ {mosaicReady && ( + + )} + {selectedCategory && selectedCategory !== 'none' && (() => { + const colInfo = categoryColumns.find(c => c.name === selectedCategory) + if (!colInfo) return null + + if (colInfo.type === 'sequential') { + const colors = [ + "#c359ef", "#9525C6", "#0046a4", "#0074DF", "#3f8500", + "#76B900", "#ef9100", "#F9C500", "#ff8181", "#EF2020" + ] + const vals = features + .map(f => f[selectedCategory]) + .filter(v => v != null && !isNaN(v)) + const minVal = vals.length > 0 ? Math.min(...vals) : 0 + const maxVal = vals.length > 0 ? Math.max(...vals) : 1 + const fmt = (v) => Math.abs(v) >= 100 ? v.toFixed(0) : Math.abs(v) >= 1 ? v.toFixed(1) : v.toFixed(3) + return ( +
+ {fmt(maxVal)} +
+ {fmt(minVal)} + + {selectedCategory.replace(/_/g, ' ')} + +
+ ) + } + + if (colInfo.type === 'string' || colInfo.type === 'integer') { + const catColors = [ + "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", + "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf", + "#aec7e8", "#ffbb78", "#98df8a", "#ff9896", "#c5b0d5", + "#c49c94", "#f7b6d2", "#c7c7c7", "#dbdb8d", "#9edae5" + ] + // Count occurrences of each category value, sorted alphabetically + // (matching DENSE_RANK ORDER BY which is alphabetical) + const counts = {} + for (const f of features) { + const val = f[selectedCategory] + if (val != null && val !== '') { + counts[val] = (counts[val] || 0) + 1 + } + } + // Sort alphabetically to match dense_rank ordering + const sortedCategories = Object.keys(counts).sort() + return ( +
+
+ {selectedCategory.replace(/_/g, ' ').replace('gsea ', '')} +
+ {sortedCategories.map((cat, i) => { + const hasFilter = hiddenCategories.size > 0 + const isHidden = hasFilter && !hiddenCategories.has(cat) + return ( +
{ + if (e.metaKey || e.ctrlKey) { + // Cmd/Ctrl+click: toggle this category in the selection + setHiddenCategories(prev => { + const next = new Set(prev) + if (next.has(cat)) { + next.delete(cat) + // If nothing left selected, clear filter + return next.size === 0 ? new Set() : next + } else { + next.add(cat) + return next + } + }) + } else { + // Regular click: solo this category (or clear if already solo'd) + setHiddenCategories(prev => { + if (prev.size === 1 && prev.has(cat)) return new Set() + return new Set([cat]) + }) + } + }} + style={{ + display: 'flex', alignItems: 'center', gap: '5px', padding: '2px 0', + cursor: 'pointer', opacity: isHidden ? 0.15 : 1, + userSelect: 'none', + }} + > + + + {cat} + + + {counts[cat]} + +
+ ) + })} +
+ ) + } + + return null + })()} +
+
+ +
+ {[ + { value: histMetric1, setter: setHistMetric1 }, + { value: histMetric2, setter: setHistMetric2 }, + { value: histMetric3, setter: setHistMetric3 }, + ].map(({ value, setter }, i) => ( +
+
+ +
+ {mosaicReady && value && value !== 'none' && ( + + )} +
+ ))} +
+
+ +
+
+ + + +
+ +
+ + Showing {filteredFeatures.length} of {features.length} features + {selectedFeatureIds !== null && ` (${selectedFeatureIds.size} selected in UMAP)`} + + setShowGuideModal(true)} + style={{ + display: 'inline-flex', alignItems: 'center', justifyContent: 'center', + width: '15px', height: '15px', borderRadius: '50%', border: '1px solid #bbb', + fontSize: '10px', fontWeight: '600', color: '#888', cursor: 'pointer', + userSelect: 'none', lineHeight: 1, flexShrink: 0, + }} + >i +
+ + +
+
+ + {showGuideModal && ( +
setShowGuideModal(false)} + style={{ + position: 'fixed', inset: 0, background: 'rgba(0,0,0,0.45)', + display: 'flex', alignItems: 'center', justifyContent: 'center', zIndex: 1000, + }} + > +
e.stopPropagation()} + style={{ + background: 'var(--bg-card)', borderRadius: '10px', maxWidth: '560px', width: '90%', + maxHeight: '80vh', overflowY: 'auto', padding: '28px 32px', + boxShadow: '0 8px 30px rgba(0,0,0,0.2)', + }} + > +
+

Feature Card Guide

+ setShowGuideModal(false)} + style={{ cursor: 'pointer', fontSize: '20px', color: '#999', lineHeight: 1 }} + >× +
+ +
+

Decoder Logits

+

+ The decoder logits histogram shows the projection of each feature's learned decoder weight vector through the language model's prediction head, with the mean logit vector subtracted across all features. This mean-centering removes the model's shared baseline bias toward common codons (e.g. GCC), so values reflect what each feature specifically promotes or suppresses relative to the average feature. Each bar represents a codon. Green bars indicate codons the feature promotes above baseline; red bars indicate codons it suppresses below baseline. Gray bars have no feature-specific effect. This tells you what the feature pushes the model to output — not what activates it. Stop codons (TAA, TAG, TGA) are excluded because the model was trained on coding sequences where internal stops almost never appear, so all features uniformly suppress them. +

+ +

Top Activating Sequences

+

+ These are the protein-coding sequences where this feature fires most strongly. Each codon is colored by its activation value — brighter highlights mean the feature responds more strongly at that position. This shows what inputs trigger the feature, which is conceptually distinct from decoder logits. A feature can activate strongly on a particular codon (e.g., lysine codons) without promoting that same codon in the output — it may instead influence downstream or contextual predictions. +

+ +
+
+
+ )} + + {showMetricsModal && ( +
setShowMetricsModal(false)} + style={{ + position: 'fixed', inset: 0, background: 'rgba(0,0,0,0.45)', + display: 'flex', alignItems: 'center', justifyContent: 'center', zIndex: 1000, + }} + > +
e.stopPropagation()} + style={{ + background: 'var(--bg-card)', borderRadius: '10px', maxWidth: '620px', width: '90%', + maxHeight: '80vh', overflowY: 'auto', padding: '28px 32px', + boxShadow: '0 8px 30px rgba(0,0,0,0.2)', + }} + > +
+

Variant Analysis Metrics

+ setShowMetricsModal(false)} + style={{ cursor: 'pointer', fontSize: '20px', color: '#999', lineHeight: 1 }} + >× +
+ +
+

Mean Variant Score (per model)

+

+ For each feature, the average model effect score across variant sequences where the feature fires. Computed for the 1b_cdwt model score column. A high value means the feature preferentially activates on variants that model predicts to be functionally impactful. +

+ +

High Score Fraction

+

+ Variants are split at the median model score. Among variants where a feature fires, what fraction are high-scoring? A value of 0.5 means no preference. Above 0.5 means the feature disproportionately fires on high-impact variants. Robust to outliers — measures distributional preference rather than average. +

+ +

ClinVar Fraction

+

+ Among variant sequences where the feature fires, the fraction from ClinVar vs COSMIC. ClinVar variants are germline (inherited, Mendelian disease). COSMIC variants are somatic (cancer mutations). High ClinVar fraction means the feature responds to germline disease patterns; low means it prefers somatic cancer mutation patterns. +

+ +

Mean PhyloP

+

+ Average evolutionary conservation score (PhyloP) across sequences where the feature fires. High values indicate conserved positions (functionally important). Negative values indicate rapidly evolving regions. Features with high mean PhyloP capture evolutionarily constrained patterns. +

+ +

Mean Variant Delta

+

+ For each gene, the difference in max feature activation between the variant and reference sequence: max_act(variant) − max_act(ref), averaged across all variant-ref pairs. Positive means the mutation increases feature activation; negative means it suppresses it. Near zero means the feature responds to the gene background, not the specific mutation. This controls for gene identity. +

+ +

Mean Site Delta

+

+ Like mean variant delta, but measured only at the exact codon position where the mutation occurs: activation_f(variant, pos) − activation_f(ref, pos). This captures direct effects — the feature responding to the changed codon itself. Compare with mean variant delta: a large variant delta but small site delta means the feature captures indirect/distal effects of the mutation (e.g., changes to predicted protein folding context), not the local codon change. +

+ +

Mean Local Delta

+

+ Like variant delta, but using the max activation within a 3-codon window around the variant site instead of the full sequence. Captures local effects of the mutation: max(window_variant) − max(window_ref). A large local delta with a small global delta means the mutation's effect is localized. Compare with site delta (exact position only) and variant delta (full sequence). +

+ +

GC Content (mean, std)

+

+ Mean and standard deviation of GC content across all sequences where the feature fires. Features with extreme GC mean (far from ~0.5) are GC-biased. Features with low GC std activate only on sequences with similar GC content — suggesting sensitivity to nucleotide composition rather than specific codon patterns. +

+ +

Trinuc Entropy

+

+ Shannon entropy (in bits) of the trinucleotide context distribution among variant sequences where the feature fires. Low entropy means the feature concentrates on specific mutation contexts (e.g., C[C>T]G for CpG transitions). High entropy means it fires across diverse mutation types. The dominant fraction shows what fraction of activations come from the most common trinuc context. +

+ +

Gene Distribution

+

+ Shannon entropy of the gene distribution among sequences where the feature fires. Low entropy means the feature is gene-specific — it concentrates on a few genes. High entropy means it fires broadly. gene_n_unique is the number of distinct genes. gene_dominant_frac is the fraction from the most common gene. A feature with low entropy and high dominant fraction has learned something specific to one gene family. +

+ +

High Score Delta

+

+ Same as mean variant delta, but averaged only over variants with model scores above the median. Shows how the feature responds specifically to high-impact mutations. Compare with low score delta: if high_score_delta >> low_score_delta, the feature selectively detects impactful mutations. +

+ +

Low Score Delta

+

+ Same as mean variant delta, but averaged only over variants with model scores below the median. Features where high score delta and low score delta differ significantly have learned to discriminate mutation severity. Features where both are similar just detect that a mutation occurred without distinguishing impact. +

+
+
+
+ )} +
+ ) +} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/src/Dashboard.jsx b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/src/Dashboard.jsx new file mode 100644 index 0000000000..80e495f35b --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/src/Dashboard.jsx @@ -0,0 +1,68 @@ +import React, { useEffect, useState } from 'react' +import App from './App' +import GenerativeSteering from './GenerativeSteering' +import SequenceInspector from './SequenceInspector' +import { Sun, Moon } from 'lucide-react' + +// Three-tab shell. The Feature atlas is the static-parquet explorer (works +// offline); Generative steering and Sequence inspector both talk to the live +// backend (steering_server.py) through the /api proxy. +const TABS = [ + { id: 'atlas', label: 'Feature atlas' }, + { id: 'steering', label: 'Generative steering' }, + { id: 'inspector', label: 'Sequence inspector' }, +] + +export default function Dashboard() { + const [tab, setTab] = useState('atlas') + const [dark, setDark] = useState(true) + + useEffect(() => { + document.documentElement.classList.toggle('dark', dark) + }, [dark]) + + return ( +
+
+ Evo 2 SAE Feature Explorer + {TABS.map((t) => ( + + ))} + +
+ +
+ {tab === 'atlas' && } + {tab === 'steering' && } + {tab === 'inspector' && } +
+
+ ) +} + +const S = { + shell: { height: '100vh', display: 'flex', flexDirection: 'column', background: 'var(--bg)', color: 'var(--text)' }, + tabBar: { + display: 'flex', alignItems: 'center', gap: '6px', padding: '8px 16px', + background: 'var(--bg-card)', borderBottom: '1px solid var(--border)', flexShrink: 0, + }, + brand: { fontSize: '13px', fontWeight: 700, color: 'var(--text-heading)', marginRight: '14px' }, + tabOn: { + padding: '6px 14px', border: '1px solid var(--accent)', background: 'var(--bg-card-expanded)', + color: 'var(--accent)', borderRadius: '5px', cursor: 'pointer', fontSize: '12px', fontWeight: 600, + }, + tabOff: { + padding: '6px 14px', border: '1px solid var(--border)', background: 'transparent', + color: 'var(--text-secondary)', borderRadius: '5px', cursor: 'pointer', fontSize: '12px', + }, + theme: { + marginLeft: 'auto', display: 'inline-flex', alignItems: 'center', justifyContent: 'center', + width: '30px', height: '30px', border: '1px solid var(--border)', background: 'transparent', + color: 'var(--text-secondary)', borderRadius: '5px', cursor: 'pointer', + }, + content: { flex: 1, minHeight: 0 }, +} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/src/EmbeddingView.jsx b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/src/EmbeddingView.jsx new file mode 100644 index 0000000000..bc14226257 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/src/EmbeddingView.jsx @@ -0,0 +1,334 @@ +import React, { useEffect, useRef } from 'react' +import { EmbeddingViewMosaic } from 'embedding-atlas' + +// Color palette for categories (D3 category10 + extended) +const CATEGORY_COLORS = [ + "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", + "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf", + "#aec7e8", "#ffbb78", "#98df8a", "#ff9896", "#c5b0d5", + "#c49c94", "#f7b6d2", "#c7c7c7", "#dbdb8d", "#9edae5" +] + +// Sequential color palette (NVIDIA brand) +const SEQUENTIAL_COLORS = [ + "#c359ef", "#9525C6", "#0046a4", "#0074DF", "#3f8500", + "#76B900", "#ef9100", "#F9C500", "#ff8181", "#EF2020" +] + +// Default color for uniform coloring (NVIDIA green) +const DEFAULT_COLOR = "#76b900" + +// Custom tooltip renderer +class FeatureTooltip { + constructor(node, props) { + this.node = node + this.inner = document.createElement("div") + this.inner.style.cssText = ` + background: var(--bg-card); + border: 1px solid var(--border); + border-radius: 4px; + padding: 8px 12px; + font-family: 'NVIDIA Sans', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; + font-size: 13px; + box-shadow: 0 2px 8px rgba(0,0,0,0.25); + max-width: 300px; + color: var(--text); + ` + this.node.appendChild(this.inner) + this.update(props) + } + + update(props) { + const { tooltip } = props + if (!tooltip) { + this.inner.innerHTML = "" + return + } + const featureId = tooltip.identifier ?? "" + const label = tooltip.fields?.label ?? tooltip.text ?? "" + const logFreq = tooltip.fields?.log_frequency + const maxAct = tooltip.fields?.max_activation + const colorField = tooltip.fields?.color_field + + this.inner.innerHTML = ` +
Feature #${featureId}
+
${label}
+ ${colorField ? `
Category: ${colorField}
` : ""} + ${logFreq !== undefined ? `
Log Frequency: ${logFreq.toFixed(3)}
` : ""} + ${maxAct !== undefined ? `
Max Activation: ${maxAct.toFixed(2)}
` : ""} + ` + } + + destroy() { + this.inner.remove() + } +} + +export default function EmbeddingView({ brush, categoryColumn, categoryColumns, onFeatureClick, highlightedFeatureId, viewportState, onViewportChange, labels, features, selectedCategory, darkMode, hiddenCategories }) { + const containerRef = useRef(null) + const viewRef = useRef(null) + const onFeatureClickRef = useRef(onFeatureClick) + const onViewportChangeRef = useRef(onViewportChange) + + // Keep the callback refs updated + useEffect(() => { + onFeatureClickRef.current = onFeatureClick + }, [onFeatureClick]) + + useEffect(() => { + onViewportChangeRef.current = onViewportChange + }, [onViewportChange]) + + // Update selection and tooltip when highlightedFeatureId changes + useEffect(() => { + if (viewRef.current && highlightedFeatureId != null) { + // Find the feature data + const feature = features?.find(f => f.feature_id === highlightedFeatureId) + + // Build tooltip fields + const fields = { + label: feature?.label || `Feature ${highlightedFeatureId}`, + log_frequency: feature?.log_frequency || feature?.activation_freq || 0, + max_activation: feature?.max_activation || 0, + color_field: null + } + + // Add selected category metric if available + if (selectedCategory && selectedCategory !== 'none' && feature) { + const metricName = selectedCategory.replace(/_/g, ' ') + const metricValue = feature[selectedCategory] + if (metricValue !== undefined && metricValue !== null) { + fields.color_field = `${metricName}: ${typeof metricValue === 'number' ? metricValue.toFixed(3) : metricValue}` + } + } + + // Construct tooltip object with feature data + const tooltipObj = { + identifier: highlightedFeatureId, + text: `Feature #${highlightedFeatureId}`, + x: feature?.x, + y: feature?.y, + fields: fields + } + // Clear previous selection first to avoid animated transition + viewRef.current.update({ + selection: null, + tooltip: null + }) + viewRef.current.update({ + selection: [highlightedFeatureId], + tooltip: tooltipObj + }) + } else if (viewRef.current && highlightedFeatureId == null) { + viewRef.current.update({ + selection: null, + tooltip: null + }) + } + }, [highlightedFeatureId, features, selectedCategory]) + + // Update viewport when viewportState changes (skip null to let auto-fit persist) + useEffect(() => { + if (viewRef.current && viewportState != null) { + viewRef.current.update({ + viewportState: viewportState + }) + } + }, [viewportState]) + + // Update color scheme when dark mode changes + useEffect(() => { + if (viewRef.current) { + viewRef.current.update({ + config: { colorScheme: darkMode ? "dark" : "light" } + }) + } + }, [darkMode]) + + // Update labels when they change + useEffect(() => { + if (viewRef.current && labels) { + console.log('[EmbeddingView] updating labels:', labels.length, labels.slice(0, 2)) + viewRef.current.update({ + labels: labels + }) + } + }, [labels]) + + useEffect(() => { + if (!containerRef.current || !brush) return + + // Clear previous view + if (viewRef.current) { + containerRef.current.innerHTML = '' + } + + // Determine category column and colors + let categoryColName = null + let colors = Array(50).fill(DEFAULT_COLOR) + let additionalFields = { + label: "label", + log_frequency: "log_frequency", + max_activation: "max_activation", + } + + if (categoryColumn && categoryColumn !== "none") { + const colInfo = categoryColumns?.find(c => c.name === categoryColumn) + if (colInfo) { + if (colInfo.type === 'sequential') { + // Sequential column - use binned version and sequential colors + categoryColName = `${categoryColumn}_bin` + colors = SEQUENTIAL_COLORS + } else if (colInfo.type === 'string') { + // Categorical string column + categoryColName = `${categoryColumn}_cat` + colors = CATEGORY_COLORS.slice(0, Math.max(colInfo.nUnique, 10)) + } else { + // Integer categorical column + categoryColName = categoryColumn + colors = CATEGORY_COLORS.slice(0, Math.max(colInfo.nUnique, 10)) + } + additionalFields.color_field = categoryColumn + } + } + + const width = containerRef.current.clientWidth + const height = containerRef.current.clientHeight + + try { + viewRef.current = new EmbeddingViewMosaic( + containerRef.current, + { + table: "features", + x: "x", + y: "y", + category: categoryColName, + text: "label", + identifier: "feature_id", + filter: brush, + rangeSelection: brush, + selection: highlightedFeatureId != null ? [highlightedFeatureId] : null, + viewportState: viewportState, + categoryColors: colors, + width: width, + height: height, + labels: labels || null, + config: { + mode: "points", + colorScheme: document.documentElement.classList.contains('dark') ? "dark" : "light", + autoLabelEnabled: false, + }, + theme: { + brandingLink: { + text: "NVIDIA BioNeMo", + href: "https://github.com/NVIDIA/bionemo-framework", + }, + }, + additionalFields: additionalFields, + customTooltip: FeatureTooltip, + onSelection: (selection) => { + // selection is DataPoint[] | null + if (!onFeatureClickRef.current) return + + if (selection && selection.length > 0) { + // Get the last clicked point (most recent selection) + const lastPoint = selection[selection.length - 1] + const featureId = lastPoint?.identifier ?? lastPoint + const x = lastPoint?.x + const y = lastPoint?.y + if (featureId != null) { + onFeatureClickRef.current(featureId, x, y) + } + } else { + // Clicked on empty canvas - clear selection + onFeatureClickRef.current(null) + } + }, + onViewportState: (vp) => { + if (onViewportChangeRef.current && vp) { + onViewportChangeRef.current(vp) + } + }, + } + ) + } catch (err) { + console.warn('Error creating EmbeddingViewMosaic:', err) + } + + return () => { + if (containerRef.current) { + containerRef.current.innerHTML = '' + } + } + }, [brush]) + + // Update category coloring in-place (without recreating the view) + useEffect(() => { + if (!viewRef.current) return + + let categoryColName = null + const HIDDEN_COLOR = darkMode ? "#0a0a0a" : "#fafafa" + let colors = Array(50).fill(HIDDEN_COLOR) + + if (categoryColumn && categoryColumn !== "none") { + const colInfo = categoryColumns?.find(c => c.name === categoryColumn) + if (colInfo) { + if (colInfo.type === 'sequential') { + categoryColName = `${categoryColumn}_bin` + colors = SEQUENTIAL_COLORS + } else if (colInfo.type === 'string') { + categoryColName = `${categoryColumn}_cat` + colors = CATEGORY_COLORS.slice(0, Math.max(colInfo.nUnique, 10)) + // Map colors to match DENSE_RANK order, dim non-selected when filtering + if (hiddenCategories && hiddenCategories.size > 0 && features) { + const allCatNames = [...new Set( + features.map(f => f[categoryColumn]).filter(v => v != null) + )].sort() + colors = colors.map((c, i) => { + const name = allCatNames[i] + if (!name) return c + return !hiddenCategories.has(name) ? HIDDEN_COLOR : c + }) + } + } else { + categoryColName = categoryColumn + colors = CATEGORY_COLORS.slice(0, Math.max(colInfo.nUnique, 10)) + } + } + } + + viewRef.current.update({ + category: categoryColName, + categoryColors: colors, + selection: null, + tooltip: null, + }) + }, [categoryColumn, categoryColumns, hiddenCategories]) + + // Handle resize + useEffect(() => { + const handleResize = () => { + if (viewRef.current && containerRef.current) { + const width = containerRef.current.clientWidth + const height = containerRef.current.clientHeight + viewRef.current.update({ width, height }) + } + } + + const resizeObserver = new ResizeObserver(handleResize) + if (containerRef.current) { + resizeObserver.observe(containerRef.current) + } + + return () => { + resizeObserver.disconnect() + } + }, []) + + return ( +
+ ) +} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/src/FeatureCard.jsx b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/src/FeatureCard.jsx new file mode 100644 index 0000000000..d45bc121df --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/src/FeatureCard.jsx @@ -0,0 +1,518 @@ +import React, { useState, useEffect, useRef, forwardRef } from 'react' +import SequenceView, { computeAlignInfo } from './SequenceView' +import FeatureDetailPage from './FeatureDetailPage' +import { getRegionLabel } from './utils' + +const styles = { + card: { + background: 'var(--bg-card)', + borderRadius: '8px', + border: '1px solid var(--border)', + flexShrink: 0, + }, + cardHighlighted: { + background: 'var(--bg-card)', + borderRadius: '8px', + border: '2px solid var(--highlight-border)', + flexShrink: 0, + boxShadow: '0 2px 8px var(--highlight-shadow)', + }, + header: { + padding: '12px 14px', + borderBottom: '1px solid var(--border-light)', + cursor: 'pointer', + display: 'flex', + justifyContent: 'space-between', + alignItems: 'flex-start', + gap: '10px', + }, + headerLeft: { + flex: 1, + minWidth: 0, + }, + featureId: { + fontSize: '11px', + color: 'var(--text-tertiary)', + fontFamily: 'monospace', + marginBottom: '2px', + }, + description: { + fontSize: '13px', + fontWeight: '500', + wordBreak: 'break-word', + lineHeight: '1.4', + color: 'var(--text)', + }, + userTitle: { + fontSize: '13px', + fontWeight: '500', + wordBreak: 'break-word', + lineHeight: '1.4', + color: 'var(--accent)', + fontStyle: 'italic', + }, + stats: { + display: 'flex', + gap: '12px', + fontSize: '11px', + color: 'var(--text-secondary)', + flexShrink: 0, + }, + stat: { + display: 'flex', + flexDirection: 'column', + alignItems: 'flex-end', + }, + statLabel: { + color: 'var(--text-muted)', + fontSize: '9px', + textTransform: 'uppercase', + }, + statValue: { + fontFamily: 'monospace', + fontWeight: '500', + }, + expandIcon: { + color: 'var(--text-muted)', + fontSize: '10px', + marginLeft: '6px', + }, + expandedContent: { + padding: '10px 14px', + background: 'var(--bg-card-expanded)', + maxHeight: '900px', + overflowY: 'auto', + }, + sectionHeader: { + fontSize: '10px', + color: 'var(--text-tertiary)', + textTransform: 'uppercase', + marginBottom: '8px', + fontWeight: '500', + }, + example: { + marginBottom: '8px', + padding: '8px 10px', + background: 'var(--bg-example)', + borderRadius: '4px', + border: '1px solid var(--border-light)', + }, + exampleMeta: { + fontSize: '10px', + color: 'var(--text-muted)', + marginBottom: '4px', + fontFamily: 'monospace', + display: 'flex', + justifyContent: 'space-between', + alignItems: 'center', + }, + proteinId: { + color: 'var(--text-heading)', + fontWeight: '700', + }, + annotation: { + color: 'var(--text-secondary)', + fontStyle: 'italic', + marginLeft: '8px', + }, + uniprotLink: { + color: 'var(--link)', + textDecoration: 'none', + fontSize: '11px', + marginLeft: '4px', + opacity: 0.6, + }, + noExamples: { + color: 'var(--text-muted)', + fontSize: '12px', + fontStyle: 'italic', + }, + densityBar: { + width: '50px', + height: '3px', + background: 'var(--density-bar-bg)', + borderRadius: '2px', + overflow: 'hidden', + marginTop: '3px', + }, + densityFill: { + height: '100%', + background: '#76b900', + borderRadius: '2px', + }, + alignBar: { + display: 'flex', + alignItems: 'center', + gap: '6px', + marginBottom: '10px', + fontSize: '10px', + color: '#888', + }, + alignLabel: { + textTransform: 'uppercase', + fontWeight: '500', + }, + alignBtn: { + padding: '2px 8px', + border: '1px solid #ddd', + borderRadius: '3px', + background: '#fff', + cursor: 'pointer', + fontSize: '10px', + color: '#555', + }, + alignBtnActive: { + padding: '2px 8px', + border: '1px solid #76b900', + borderRadius: '3px', + background: '#f0f9e0', + cursor: 'pointer', + fontSize: '10px', + color: '#333', + fontWeight: '600', + }, +} + +const FeatureCard = forwardRef(function FeatureCard({ feature, isHighlighted, forceExpanded, onClick, loadExamples }, ref) { + const [expanded, setExpanded] = useState(false) + const [showDetailPage, setShowDetailPage] = useState(false) + const [examples, setExamples] = useState([]) + const [loadingExamples, setLoadingExamples] = useState(false) + const examplesCacheRef = useRef(null) + const [alignMode, setAlignMode] = useState('start') + const scrollGroupRef = useRef([]) + const [editingTitle, setEditingTitle] = useState(false) + const [userTitle, setUserTitle] = useState('') + const inputRef = useRef(null) + + // Load user-provided title from localStorage + useEffect(() => { + const stored = localStorage.getItem(`featureTitle_${feature.feature_id}`) + if (stored) { + setUserTitle(stored) + } + }, [feature.feature_id]) + + // Focus input when editing starts + useEffect(() => { + if (editingTitle && inputRef.current) { + inputRef.current.focus() + inputRef.current.select() + } + }, [editingTitle]) + + // Reset scroll group when alignment changes + useEffect(() => { scrollGroupRef.current = [] }, [alignMode]) + + // If forceExpanded changes to true, expand the card + useEffect(() => { + if (forceExpanded) { + setExpanded(true) + } + }, [forceExpanded]) + + // Lazy-load examples from DuckDB when card is expanded + useEffect(() => { + if (!expanded || !loadExamples || examplesCacheRef.current) return + let cancelled = false + setLoadingExamples(true) + loadExamples(feature.feature_id).then(result => { + if (cancelled) return + examplesCacheRef.current = result + setExamples(result) + setLoadingExamples(false) + }).catch(err => { + if (cancelled) return + console.error('Error loading examples for feature', feature.feature_id, err) + setLoadingExamples(false) + }) + return () => { cancelled = true } + }, [expanded, loadExamples, feature.feature_id]) + + const freq = feature.activation_freq || 0 + const maxAct = feature.max_activation || 0 + const rawDesc = feature.label || feature.description || `Feature ${feature.feature_id}` + const description = rawDesc.toLowerCase().includes('common codons') ? 'Unidentified Feature' : rawDesc + + + const handleClick = () => { + const willExpand = !expanded + // Update UMAP highlight immediately, defer card expansion so it doesn't block + if (onClick) { + onClick(feature.feature_id, willExpand) + } + requestAnimationFrame(() => { + setExpanded(willExpand) + }) + } + + const handleSaveTitle = () => { + if (userTitle.trim()) { + localStorage.setItem(`featureTitle_${feature.feature_id}`, userTitle.trim()) + } else { + localStorage.removeItem(`featureTitle_${feature.feature_id}`) + setUserTitle('') + } + setEditingTitle(false) + } + + const handleCancelEdit = () => { + const stored = localStorage.getItem(`featureTitle_${feature.feature_id}`) + setUserTitle(stored || '') + setEditingTitle(false) + } + + const displayTitle = userTitle || description + + const handleTitleKeyDown = (e) => { + if (e.key === 'Enter') { + handleSaveTitle() + } else if (e.key === 'Escape') { + handleCancelEdit() + } + } + + const exportToCSV = () => { + const lines = [] + + // Feature metadata section + lines.push('=== FEATURE METADATA ===') + lines.push(`Feature ID,${feature.feature_id}`) + lines.push(`Label,${displayTitle}`) + if (userTitle) { + lines.push(`User Title,${userTitle}`) + } + lines.push(`Activation Frequency,${(freq * 100).toFixed(2)}%`) + lines.push(`Max Activation,${maxAct.toFixed(4)}`) + lines.push('') + + // Examples section + if (examples && examples.length > 0) { + lines.push('=== ACTIVATION EXAMPLES ===') + lines.push('Rank,Region,Max Activation,Sequence') + examples.forEach((ex, i) => { + lines.push(`${i + 1},${getRegionLabel(ex) || ''},${ex.max_activation?.toFixed(4) || ''},${ex.sequence || ''}`) + }) + } + + // Generate CSV + const csv = lines.join('\n') + + // Create download link + const filename = `feature_${feature.feature_id}_${displayTitle.replace(/[^a-z0-9]/gi, '_').substring(0, 20)}.csv` + const blob = new Blob([csv], { type: 'text/csv;charset=utf-8;' }) + const link = document.createElement('a') + link.setAttribute('href', URL.createObjectURL(blob)) + link.setAttribute('download', filename) + link.style.visibility = 'hidden' + document.body.appendChild(link) + link.click() + document.body.removeChild(link) + } + + return ( +
+
+
+
Feature #{feature.feature_id}
+ {editingTitle ? ( +
+ setUserTitle(e.target.value)} + onKeyDown={handleTitleKeyDown} + onClick={(e) => e.stopPropagation()} + style={{ + fontSize: '13px', + fontWeight: '500', + padding: '4px 8px', + border: '1px solid #76b900', + borderRadius: '4px', + flex: 1, + }} + /> + + +
+ ) : ( +
+
{displayTitle}
+ { e.stopPropagation(); setEditingTitle(true) }} + style={{ + fontSize: '11px', + color: '#999', + cursor: 'pointer', + padding: '2px 4px', + borderRadius: '3px', + userSelect: 'none', + }} + title="Click to edit title" + > + ✎ + +
+ )} +
+
+
+ Freq + {(freq * 100).toFixed(1)}% +
+
+
+
+
+ Max + {maxAct.toFixed(1)} +
+ {/* v2 roadmap placeholders — populated when real eval pipeline lands. */} +
+ Annotation + +
+
+ Sensitivity + +
+
+ Recon Δ + +
+ {expanded ? '▼' : '▶'} +
+
+ + {/* Details and export buttons - shown when expanded */} + {expanded && ( +
+ + +
+ )} + + {expanded && ( +
+ {feature.logo_path && ( +
+
Sequence Logo
+ {`Sequence +
+ )} + {/* Sequence examples */} +
+
Top Activating Sequences
+
+ Align by: + {['start', 'first_activation', 'max_activation'].map(mode => ( + + ))} +
+
+ {loadingExamples ? ( +
+ Loading examples... +
+ ) : examples.length > 0 ? ( + <> + {(() => { + const visibleExamples = examples.slice(0, 6) + const { anchor: alignAnchor, totalLength } = computeAlignInfo(visibleExamples, alignMode) + return visibleExamples.map((ex, i) => ( +
+
+ + {getRegionLabel(ex)} + {ex.best_annotation && ( + {ex.best_annotation} + )} + + max: {ex.max_activation?.toFixed(3) || 'N/A'} +
+ +
+ )) + })()} + + + ) : ( +
No examples available
+ )} +
+ )} + + {showDetailPage && ( + setShowDetailPage(false)} + /> + )} +
+ ) +}) + +export default FeatureCard diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/src/FeatureDetailPage.jsx b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/src/FeatureDetailPage.jsx new file mode 100644 index 0000000000..b70fdcdbde --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/src/FeatureDetailPage.jsx @@ -0,0 +1,198 @@ +import React, { useState, useEffect, useRef } from 'react' +import SequenceView, { computeAlignInfo } from './SequenceView' +import { getRegionLabel } from './utils' + +const styles = { + overlay: { + position: 'fixed', + inset: 0, + background: 'rgba(0, 0, 0, 0.5)', + zIndex: 2000, + overflowY: 'auto', + }, + page: { + maxWidth: '960px', + margin: '20px auto', + background: 'var(--bg-card)', + borderRadius: '8px', + boxShadow: '0 4px 24px rgba(0,0,0,0.2)', + color: 'var(--text)', + }, + header: { + padding: '12px 20px', + borderBottom: '1px solid var(--border-light)', + display: 'flex', + justifyContent: 'space-between', + alignItems: 'center', + }, + title: { + fontSize: '14px', + fontWeight: '700', + color: 'var(--text-heading)', + }, + closeBtn: { + background: 'none', + border: '1px solid var(--border-input)', + borderRadius: '4px', + padding: '3px 10px', + cursor: 'pointer', + fontSize: '11px', + color: 'var(--text-secondary)', + }, + section: { + padding: '10px 20px', + borderBottom: '1px solid var(--border-light)', + }, + sectionTitle: { + fontSize: '11px', + fontWeight: '600', + marginBottom: '6px', + color: 'var(--text-heading)', + textTransform: 'uppercase', + }, + example: { + marginBottom: '6px', + padding: '6px 8px', + background: 'var(--bg-example)', + borderRadius: '4px', + border: '1px solid var(--border-light)', + }, + exampleMeta: { + fontSize: '10px', + color: 'var(--text-secondary)', + marginBottom: '4px', + fontFamily: 'monospace', + display: 'flex', + justifyContent: 'space-between', + }, + placeholder: { + border: '1px dashed var(--border)', + borderRadius: '6px', + padding: '24px', + textAlign: 'center', + color: 'var(--text-muted)', + fontSize: '12px', + fontStyle: 'italic', + }, + placeholderLabel: { + fontSize: '13px', + fontWeight: '500', + color: 'var(--text-muted)', + marginBottom: '8px', + }, +} + +export default function FeatureDetailPage({ feature, examples, onClose }) { + const [alignMode, setAlignMode] = useState('max_activation') + const scrollGroupRef = useRef(null) + + const freq = feature.activation_freq || 0 + const maxAct = feature.max_activation || 0 + const description = feature.description || feature.label || `Feature ${feature.feature_id}` + + useEffect(() => { + const handleKey = (e) => { if (e.key === 'Escape') onClose() } + document.addEventListener('keydown', handleKey) + return () => document.removeEventListener('keydown', handleKey) + }, [onClose]) + + const visibleExamples = (examples || []).slice(0, 30) + const { anchor: alignAnchor, totalLength } = computeAlignInfo(visibleExamples.slice(0, 6), alignMode) + + return ( +
{ if (e.target === e.currentTarget) onClose() }}> +
+ +
+
+
+ Feature #{feature.feature_id} + + {description} + +
+
+
+
+ freq: {(freq * 100).toFixed(1)}% + max: {maxAct.toFixed(1)} +
+ +
+
+ + {feature.logo_path && ( +
+
Sequence Logo
+ {`Sequence +
+ )} + +
+
+
Top Activating Sequences
+
+ {['start', 'first_activation', 'max_activation'].map(mode => ( + + ))} +
+
+ + {visibleExamples.length > 0 ? ( + visibleExamples.map((ex, i) => ( +
+
+ {getRegionLabel(ex)} + max: {ex.max_activation?.toFixed(3)} +
+ +
+ )) + ) : ( +
No examples loaded
+ )} +
+ + {/* v2 roadmap placeholders — populated when annotation + conservation pipelines land. */} +
+
Annotations
+
+ Annotation overlay (RefSeq, Rfam, JASPAR) — coming in v2 +
+
+ +
+
Conservation
+
+ Conservation track (phyloP) — coming in v2 +
+
+ +
+
+ ) +} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/src/FeatureList.jsx b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/src/FeatureList.jsx new file mode 100644 index 0000000000..26cd6c2457 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/src/FeatureList.jsx @@ -0,0 +1,83 @@ +import React, { memo } from 'react' +import FeatureCard from './FeatureCard' + +const styles = { + featureList: { + flex: 1, + overflowY: 'auto', + overflowX: 'hidden', + display: 'flex', + flexDirection: 'column', + gap: '10px', + paddingRight: '8px', + minHeight: 0, + }, +} + +function FeatureListComponent({ + filteredFeatures, + displayedCardCount, + clickedFeatureId, + features, + cardResetKey, + handleCardClick, + loadExamples, + vocabLogits, + featureAnalysis, + featureListRef, + endOfListRef, + featureRefs, +}) { + const visibleFeatures = filteredFeatures.slice(0, displayedCardCount) + const clickedIsVisible = clickedFeatureId != null && + visibleFeatures.some(f => Number(f.feature_id) === Number(clickedFeatureId)) + const clickedFeature = clickedFeatureId != null && !clickedIsVisible + ? features.find(f => Number(f.feature_id) === Number(clickedFeatureId)) + : null + + return ( +
+ {/* Only render clicked feature at top if NOT already in visible list */} + {clickedFeature && ( + { featureRefs.current[clickedFeature.feature_id] = el }} + feature={clickedFeature} + isHighlighted={true} + forceExpanded={true} + onClick={handleCardClick} + loadExamples={loadExamples} + vocabLogits={vocabLogits} + featureAnalysis={featureAnalysis} + /> + )} + {visibleFeatures.map(feature => ( + { featureRefs.current[feature.feature_id] = el }} + feature={feature} + isHighlighted={Number(clickedFeatureId) === Number(feature.feature_id)} + forceExpanded={Number(clickedFeatureId) === Number(feature.feature_id)} + onClick={handleCardClick} + loadExamples={loadExamples} + vocabLogits={vocabLogits} + featureAnalysis={featureAnalysis} + /> + ))} + {/* Sentinel element for infinite scroll detection */} +
+ {displayedCardCount < filteredFeatures.length && ( +
+ Scroll to load more... ({visibleFeatures.length} of {filteredFeatures.length}) +
+ )} + {filteredFeatures.length === 0 && clickedFeatureId == null && ( +
+ No features match your selection. +
+ )} +
+ ) +} + +export default memo(FeatureListComponent) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/src/GenerativeSteering.jsx b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/src/GenerativeSteering.jsx new file mode 100644 index 0000000000..2cfd468bf0 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/feature_explorer/src/GenerativeSteering.jsx @@ -0,0 +1,210 @@ +import React, { useEffect, useMemo, useState } from 'react' +import { useHealth, postJSON, getJSON, cleanDNA } from './backend' +import { BackendBanner, OrganismField, FeaturePicker, resolveFeatureId, Row } from './SequenceInspector' + +// Generative steering: autoregressively generate DNA from Evo2 while ADDITIVELY +// clamping one or more SAE features (picked by name) on the generated +// continuation only. Real model + real SAE via backend /generate. + +const BASES_PER_LINE = 80 + +export default function GenerativeSteering() { + const health = useHealth() + const organismTags = health.info?.organism_tags + + const [catalog, setCatalog] = useState([]) + const [organism, setOrganism] = useState('Human') + const [tag, setTag] = useState(null) + const [prompt, setPrompt] = useState('') + const [rows, setRows] = useState([{ q: '', strength: 4 }]) + const [nTokens, setNTokens] = useState(120) + const [temperature, setTemperature] = useState(1.0) + const [compareBaseline, setCompareBaseline] = useState(false) + + const [result, setResult] = useState(null) + const [busy, setBusy] = useState(false) + const [error, setError] = useState(null) + + useEffect(() => { + if (health.status !== 'ready') return + if (tag === null && organismTags) setTag(organismTags[organism] ?? '') + if (!catalog.length) getJSON('/features').then(setCatalog).catch(() => {}) + }, [health.status, organismTags]) + + const clamps = rows + .map((r) => ({ id: resolveFeatureId(catalog, r.q), strength: Number(r.strength) })) + .filter((c) => c.id != null) + + const generate = async () => { + setBusy(true) + setError(null) + try { + const body = { + prompt: cleanDNA(prompt), + organism, + tag: tag ?? (organismTags?.[organism] ?? ''), + features: clamps.map((c) => ({ feature_id: c.id, strength: c.strength })), + n_tokens: Number(nTokens), + temperature: Number(temperature), + compare_baseline: compareBaseline, + } + setResult(await postJSON('/generate', body)) + } catch (e) { + setError(String(e.message || e)) + setResult(null) + } finally { + setBusy(false) + } + } + + const canRun = health.status === 'ready' && !busy // clamps optional — [] = plain generation + + return ( +
+ + + +
+ + + +
+