Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
#!/bin/bash
# Launch the Evo2 SAE inference engine. One engine, three modes:
# Launch the Evo2 SAE inference engine. One engine, four modes:
#
# ./launch_inference.sh serve # live HTTP server on :8001 (viz backend)
# ./launch_inference.sh encode --sequence ATGC... # annotate ONE sequence -> top features
# ./launch_inference.sh batch --fasta in.fa --out out.parquet # MANY sequences -> parquet
# ./launch_inference.sh generate --prompt ATGC... --clamp 29244:300 # steer + generate DNA
#
# Steering loop: `encode` a sequence to find an active feature id, then
# `generate --clamp ID:STRENGTH` (strength ~2-3x the feature's max_activation; repeat --clamp).
#
# Config via env. Required: EVO2_CKPT_DIR, SAE_CKPT_PATH. Optional (have defaults):
# FEATURE_ANNOTATIONS, EMBEDDING_LAYER (26), DEVICE, PORT, CUDA_VISIBLE_DEVICES.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Evo2 SAE steering harness — clamp features and measure the causal effect on generation.

Uses ``sae.steering.clamp_hook`` (the shared delta-clamp) registered on the Evo2 decoder layer
the SAE was trained on. Workflow: encode a sequence to find its active features, then for a
**target** feature sweep the clamp strength (dose-response) and for **control** features apply
the same clamp (selectivity), each time comparing the steered continuation to the baseline.

GPU harness — run on an H100 with the inference engine available; this is not a CPU unit test.

python steer.py --evo2-ckpt-dir <mbridge> --sae-checkpoint <sae.pt> --layer 26 \
--sequence ATGGCC... --feature 29244 --controls 12345,54321 --strengths 0,50,100,200

Note: the clamp applied here runs on *every* forward (prefill + decode), so it steers the
prompt as well as the continuation. Continuation-only steering is what
``evo2_sae.core.Evo2SAE.generate`` does, by registering
``sae.steering.clamp_hook(..., decode_only=True)`` on the same decoder layer.
"""

from __future__ import annotations

import argparse
import sys
from contextlib import nullcontext
from pathlib import Path


_HERE = Path(__file__).resolve().parent
sys.path.insert(0, str(_HERE))
sys.path.insert(0, str(_HERE.parent / "src")) # recipes/evo2/src -> evo2_sae package
sys.path.insert(0, str(_HERE.parents[2] / "sae" / "src"))

from sae.steering import steer # noqa: E402


def _divergence(a: str, b: str):
"""Return (first differing index, fraction of differing chars) over the shared prefix length."""
n = min(len(a), len(b))
first = next((i for i in range(n) if a[i] != b[i]), n)
diff = sum(1 for i in range(n) if a[i] != b[i]) / max(1, n)
return first, diff


def main():
    """Encode a sequence, then steer a target feature (dose-response) + control features (selectivity).

    Steps:
      1. Parse and validate the CLI arguments (before any model is loaded, so bad input fails fast).
      2. Encode ``--sequence`` and print its top features; if ``--feature`` was not given, the
         first of those features that has a label becomes the target.
      3. Generate an unsteered baseline continuation.
      4. For every value in ``--strengths`` clamp the target feature and report how the
         continuation diverges from the baseline.
      5. Clamp every ``--controls`` feature to the last strength and report the same metrics.

    Exits through ``argparse`` (status 2) when ``--controls`` / ``--strengths`` cannot be parsed,
    when ``--strengths`` is empty, or when no target feature can be determined.
    """
    p = argparse.ArgumentParser(description="Evo2 SAE steering harness (clamp -> continuation effect).")
    p.add_argument("--evo2-ckpt-dir", required=True)
    p.add_argument("--sae-checkpoint", required=True)
    p.add_argument("--layer", type=int, required=True)
    p.add_argument("--sequence", required=True)
    p.add_argument("--organism", default="None (raw DNA)")
    p.add_argument("--feature", type=int, default=None, help="Target feature id (default: top labeled feature).")
    p.add_argument("--controls", default="", help="Comma-separated control feature ids (selectivity).")
    p.add_argument("--strengths", default="0,50,100,200", help="Comma-separated clamp strengths to sweep.")
    p.add_argument("--n-tokens", type=int, default=60)
    p.add_argument("--device", default="cuda")
    a = p.parse_args()

    # Validate the list-valued options up front: loading the engine is expensive, so a typo in
    # --strengths should not cost a full checkpoint load before it is reported.
    try:
        controls = [int(c) for c in a.controls.split(",") if c.strip()]
        strengths = [float(s) for s in a.strengths.split(",") if s.strip()]
    except ValueError as err:
        p.error(f"--controls / --strengths must be comma-separated numbers ({err})")
    if not strengths:
        p.error("--strengths must list at least one clamp strength")

    # Heavy, GPU-only dependencies are imported lazily so that --help and argument errors stay cheap.
    from bionemo.evo2.run import infer as INF  # noqa: E402, I001, RUF100
    from evo2_sae.core import Evo2SAE, clean_dna  # noqa: E402, RUF100
    from megatron.core.utils import unwrap_model  # noqa: E402, RUF100

    eng = Evo2SAE(a.evo2_ckpt_dir, a.sae_checkpoint, a.layer, device=a.device).load()

    # 1. Encode -> the sequence's most-active features (pick a target if not given).
    codes = eng.encode(a.sequence)
    per_feature_max = codes.max(0).values
    # Never ask topk for more entries than exist along the feature axis.
    vals, ids = per_feature_max.topk(min(10, per_feature_max.shape[-1]))
    print(f"top features on {a.sequence[:24]}...:")
    target = a.feature
    for v, i in zip(vals.tolist(), ids.tolist()):
        lab = eng.labels.get(int(i))
        print(f" feat {int(i):6d} {str(lab):18s} max_act {v:7.2f}")
        if target is None and lab:
            target = int(i)
    if target is None:
        # Without this guard the sweep below would call gen({None: strength}).
        p.error("none of the top features has a label; choose a target explicitly with --feature")

    # 2. The Evo2 decoder layer the SAE hooks + a clean (tag + DNA) prompt.
    comp = eng._ensure_engine()
    prompt = (eng.resolve_tag(a.organism, None) or "") + clean_dna(a.sequence)
    layer_mod = unwrap_model(comp.model).decoder.layers[a.layer]

    def gen(clamps):
        """Generate ``a.n_tokens`` tokens from ``prompt`` (temperature 0, top_k 1).

        ``clamps`` maps feature id -> strength; an empty mapping generates without steering.
        """
        ctx = steer(layer_mod, eng.sae, clamps) if clamps else nullcontext()
        with ctx:
            out = INF.generate(comp, [prompt], max_new_tokens=a.n_tokens, temperature=0.0, top_k=1)
        return clean_dna(INF._unwrap_result(out[0]).generated_text)

    base = gen({})
    print(f"\nbaseline: {base[:60]}")
    print(f"\n=== dose-response: feature {target} ({eng.labels.get(target)}) ===")
    for s in strengths:
        steered = gen({target: s})
        first, diff = _divergence(base, steered)
        print(f" strength {s:7.1f}: diverges@{first:3d} {diff:6.1%} changed {steered[:44]}")

    if controls:
        # Controls are compared at the last strength listed in --strengths.
        s = strengths[-1]
        print(f"\n=== selectivity: control features clamped to {s} ===")
        for c in controls:
            steered = gen({c: s})
            first, diff = _divergence(base, steered)
            print(f" control {c:6d} ({str(eng.labels.get(c)):16s}): diverges@{first:3d} {diff:6.1%} changed")


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evo2 SAE inference CLI — one engine, three modes.
"""Evo2 SAE inference CLI — one engine, four modes.

serve : start the FastAPI server (one sequence at a time, interactive)
encode : annotate ONE sequence -> top features (stdout JSON)
batch : run a FASTA of MANY sequences -> parquet of per-sequence top features
generate: generate DNA, optionally steering SAE features (stdout JSON)

All three build the same `Evo2SAE` engine; config comes from flags or env
They all build the same `Evo2SAE` engine; config comes from flags or env
(EVO2_CKPT_DIR / SAE_CKPT_PATH / FEATURE_ANNOTATIONS / EMBEDDING_LAYER).
"""

Expand Down Expand Up @@ -73,9 +74,21 @@ def _engine(args):
)


def _parse_clamps(clamps: list[str]) -> list[dict]:
"""Parse repeated ``--clamp FEATURE_ID[:STRENGTH]`` args into [{feature_id, strength}].

Strength defaults to 1.0 if omitted (e.g. ``--clamp 29244:300`` or ``--clamp 29244``).
"""
specs = []
for c in clamps:
fid, sep, strength = c.partition(":")
specs.append({"feature_id": int(fid), "strength": float(strength) if (sep and strength) else 1.0})
return specs


def main():
"""Parse args and dispatch to the serve / encode / batch / generate subcommand."""
ap = argparse.ArgumentParser(description="Evo2 SAE inference (serve | encode | batch)")
ap = argparse.ArgumentParser(description="Evo2 SAE inference (serve | encode | batch | generate)")
sub = ap.add_subparsers(dest="cmd", required=True)

ps = sub.add_parser("serve", help="start the FastAPI inference server")
Expand All @@ -96,6 +109,23 @@ def main():
pb.add_argument("--top-k", type=int, default=16)
pb.add_argument("--batch-size", type=int, default=8)

pg = sub.add_parser("generate", help="generate DNA, optionally steering SAE features")
_add_common(pg)
pg.add_argument("--prompt", default="", help="DNA to seed; steering applies to the continuation")
pg.add_argument("--organism", default="None (raw DNA)")
pg.add_argument(
"--clamp",
action="append",
default=[],
metavar="FEATURE_ID[:STRENGTH]",
help="clamp a feature on the continuation; repeatable (e.g. --clamp 29244:300). "
"Find feature ids with `encode`.",
)
pg.add_argument("--n-tokens", type=int, default=120)
pg.add_argument("--temperature", type=float, default=1.0)
pg.add_argument("--top-k", type=int, default=0)
pg.add_argument("--compare-baseline", action="store_true", help="also generate unsteered, for comparison")

args = ap.parse_args()

if args.cmd == "serve":
Expand Down Expand Up @@ -141,6 +171,27 @@ def main():
df.to_parquet(args.out, index=False)
print(f"[batch] wrote {len(df)} rows for {len(seqs)} sequences -> {args.out}")

elif args.cmd == "generate":
out = eng.generate(
prompt=args.prompt,
organism=args.organism,
features=_parse_clamps(args.clamp),
n_tokens=args.n_tokens,
temperature=args.temperature,
top_k=args.top_k,
compare_baseline=args.compare_baseline,
)
result = {
"prompt": out["prompt"],
"organism": out["organism"],
"steered": out["steered"],
"features": out["features"],
"sequence": out["generation"]["sequence"],
}
if out.get("baseline"):
result["baseline_sequence"] = out["baseline"]["sequence"]
print(json.dumps(result, indent=2))


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
Expand Up @@ -294,30 +294,6 @@ def top_features(self, codes: torch.Tensor, tag_len: int = 0, k: int = 8) -> lis
]

# ------------------------------------------------------------------ generate
def _clamp_hook(self, specs, pre_bias):
"""Forward hook that clamps SAE features on the residual during DECODE steps only.

A decode step processes a single new token (sequence dim == 1); the prompt prefill
(sequence dim > 1) is left untouched, giving continuation-only steering through
`infer.generate`: h <- h + Σ_f (t_f - a_f(h)) · d_f
`specs` = list of (enc_f [H], b_f float, dec_f [H], target float).
"""

def hook(_module, _inp, output):
hs = output[0] if isinstance(output, tuple) else output # [S, B, H]
if hs.shape[0] != 1: # prefill (whole prompt) — leave untouched
return output
x = hs.float()
xc = x - pre_bias
add = torch.zeros_like(x)
for enc_f, b_f, dec_f, target in specs:
a = torch.relu(torch.matmul(xc, enc_f) + b_f)
add = add + (target - a).unsqueeze(-1) * dec_f
new = (x + add).to(hs.dtype)
return (new, *output[1:]) if isinstance(output, tuple) else new

return hook

def generate(
self,
prompt="",
Expand Down Expand Up @@ -354,23 +330,16 @@ def generate(
with self._lock:
comp = self._ensure_engine()
hook_layer = unwrap_model(comp.model).decoder.layers[self.layer]
pre_bias = self.sae.pre_bias.detach().float().to(self.device)
specs, feat_meta = [], []
for f in features:
fid = int(f["feature_id"])
specs.append(
(
self.sae.encoder.weight[fid].detach().float().to(self.device),
float(self.sae.latent_bias[fid].detach()),
self.sae.decoder.weight[:, fid].detach().float().to(self.device),
float(f.get("strength", 1.0)),
)
)
feat_meta.append({"id": fid, "label": self.labels.get(fid), "strength": float(f.get("strength", 1.0))})
from sae.steering import clamp_hook

clamps = {int(f["feature_id"]): float(f.get("strength", 1.0)) for f in features}
feat_meta = [{"id": fid, "label": self.labels.get(fid), "strength": s} for fid, s in clamps.items()]

def _run(steer: bool) -> str:
handle = (
hook_layer.register_forward_hook(self._clamp_hook(specs, pre_bias)) if (steer and specs) else None
hook_layer.register_forward_hook(clamp_hook(self.sae, clamps, decode_only=True))
if (steer and clamps)
else None
)
try:
out = INF.generate(
Expand All @@ -382,7 +351,7 @@ def _run(steer: bool) -> str:
handle.remove()

main_dna = _run(steer=True)
base_dna = _run(steer=False) if (compare_baseline and specs) else None
base_dna = _run(steer=False) if (compare_baseline and clamps) else None

resp = {
"prompt": dna,
Expand All @@ -391,7 +360,7 @@ def _run(steer: bool) -> str:
"tag_len": len(resolved_tag),
"n_tokens": n_tokens,
"features": feat_meta,
"steered": bool(specs),
"steered": bool(clamps),
"generation": {"sequence": main_dna, "activations": self.feature_tracks(main_dna, fids)},
"baseline": None,
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""CPU test for the generate CLI's --clamp parsing (no model)."""

from evo2_sae.cli import _parse_clamps


def test_parse_clamps_id_and_strength():
    """``ID:STRENGTH`` values become an int feature id and a float strength, in input order."""
    expected = [
        {"feature_id": 29244, "strength": 300.0},
        {"feature_id": 88, "strength": 1.5},
    ]
    parsed = _parse_clamps(["29244:300", "88:1.5"])
    assert parsed == expected


def test_parse_clamps_default_strength():
    """A bare feature id (no ``:STRENGTH``) gets the default strength of 1.0."""
    parsed = _parse_clamps(["29244"])
    assert parsed == [{"feature_id": 29244, "strength": 1.0}]


def test_parse_clamps_empty():
    """No ``--clamp`` arguments yields an empty spec list."""
    parsed = _parse_clamps([])
    assert parsed == []
Loading
Loading