Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
#!/bin/bash
# Launch the Evo2 SAE inference engine. One engine, three modes:
# Launch the Evo2 SAE inference engine. One engine, four modes:
#
# ./launch_inference.sh serve # live HTTP server on :8001 (viz backend)
# ./launch_inference.sh encode --sequence ATGC... # annotate ONE sequence -> top features
# ./launch_inference.sh batch --fasta in.fa --out out.parquet # MANY sequences -> parquet
# ./launch_inference.sh generate --prompt ATGC... --clamp 29244:300 # steer + generate DNA
#
# Steering loop: `encode` a sequence to find an active feature id, then
# `generate --clamp ID:STRENGTH` (strength ~2-3x the feature's max_activation; repeat --clamp).
#
# Config via env. Required: EVO2_CKPT_DIR, SAE_CKPT_PATH. Optional (have defaults):
# FEATURE_ANNOTATIONS, EMBEDDING_LAYER (26), DEVICE, PORT, CUDA_VISIBLE_DEVICES.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evo2 SAE inference CLI — one engine, three modes.
"""Evo2 SAE inference CLI — one engine, four modes.

serve : start the FastAPI server (one sequence at a time, interactive)
encode : annotate ONE sequence -> top features (stdout JSON)
batch : run a FASTA of MANY sequences -> parquet of per-sequence top features
generate: generate DNA, optionally steering SAE features (stdout JSON)

All three build the same `Evo2SAE` engine; config comes from flags or env
They all build the same `Evo2SAE` engine; config comes from flags or env
(EVO2_CKPT_DIR / SAE_CKPT_PATH / FEATURE_ANNOTATIONS / EMBEDDING_LAYER).
"""

Expand Down Expand Up @@ -73,9 +74,21 @@ def _engine(args):
)


def _parse_clamps(clamps: list[str]) -> list[dict]:
"""Parse repeated ``--clamp FEATURE_ID[:STRENGTH]`` args into [{feature_id, strength}].

Strength defaults to 1.0 if omitted (e.g. ``--clamp 29244:300`` or ``--clamp 29244``).
"""
specs = []
for c in clamps:
fid, sep, strength = c.partition(":")
specs.append({"feature_id": int(fid), "strength": float(strength) if (sep and strength) else 1.0})
return specs


def main():
"""Parse args and dispatch to the serve / encode / batch subcommand."""
ap = argparse.ArgumentParser(description="Evo2 SAE inference (serve | encode | batch)")
ap = argparse.ArgumentParser(description="Evo2 SAE inference (serve | encode | batch | generate)")
sub = ap.add_subparsers(dest="cmd", required=True)

ps = sub.add_parser("serve", help="start the FastAPI inference server")
Expand All @@ -96,6 +109,23 @@ def main():
pb.add_argument("--top-k", type=int, default=16)
pb.add_argument("--batch-size", type=int, default=8)

pg = sub.add_parser("generate", help="generate DNA, optionally steering SAE features")
_add_common(pg)
pg.add_argument("--prompt", default="", help="DNA to seed; steering applies to the continuation")
pg.add_argument("--organism", default="None (raw DNA)")
pg.add_argument(
"--clamp",
action="append",
default=[],
metavar="FEATURE_ID[:STRENGTH]",
help="clamp a feature on the continuation; repeatable (e.g. --clamp 29244:300). "
"Find feature ids with `encode`.",
)
pg.add_argument("--n-tokens", type=int, default=120)
pg.add_argument("--temperature", type=float, default=1.0)
pg.add_argument("--top-k", type=int, default=0)
pg.add_argument("--compare-baseline", action="store_true", help="also generate unsteered, for comparison")

args = ap.parse_args()

if args.cmd == "serve":
Expand Down Expand Up @@ -141,6 +171,27 @@ def main():
df.to_parquet(args.out, index=False)
print(f"[batch] wrote {len(df)} rows for {len(seqs)} sequences -> {args.out}")

elif args.cmd == "generate":
out = eng.generate(
prompt=args.prompt,
organism=args.organism,
features=_parse_clamps(args.clamp),
n_tokens=args.n_tokens,
temperature=args.temperature,
top_k=args.top_k,
compare_baseline=args.compare_baseline,
)
result = {
"prompt": out["prompt"],
"organism": out["organism"],
"steered": out["steered"],
"features": out["features"],
"sequence": out["generation"]["sequence"],
}
if out.get("baseline"):
result["baseline_sequence"] = out["baseline"]["sequence"]
print(json.dumps(result, indent=2))


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""CPU test for the generate CLI's --clamp parsing (no model)."""

from evo2_sae.cli import _parse_clamps


def test_parse_clamps_id_and_strength():
assert _parse_clamps(["29244:300", "88:1.5"]) == [
{"feature_id": 29244, "strength": 300.0},
{"feature_id": 88, "strength": 1.5},
]


def test_parse_clamps_default_strength():
assert _parse_clamps(["29244"]) == [{"feature_id": 29244, "strength": 1.0}]


def test_parse_clamps_empty():
assert _parse_clamps([]) == []
Loading