Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 47 additions & 2 deletions acestep/engine/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,19 +198,49 @@ def load_trt_engine(self, engine_path):
if hasattr(trt, "bfloat16"):
_trt_dtype_map[trt.bfloat16] = torch.bfloat16

input_names = ("hidden_states", "timestep", "encoder_hidden_states", "context_latents")
engine_input_names = {
self._trt_engine.get_tensor_name(i)
for i in range(self._trt_engine.num_io_tensors)
}
# ``steering`` is present on spectral-prefixed engines only;
# older builds get _steering_num_layers == 0 and the streaming
# pipeline skips the buffer entirely.
has_steering = "steering" in engine_input_names
base_inputs = (
"hidden_states", "timestep", "encoder_hidden_states",
"context_latents",
)
input_names = base_inputs + (("steering",) if has_steering else ())
self._trt_input_dtypes = {
name: _trt_dtype_map.get(self._trt_engine.get_tensor_dtype(name), torch.float32)
for name in input_names
}
if has_steering:
steer_shape = tuple(self._trt_engine.get_tensor_shape("steering"))
if len(steer_shape) != 3:
raise RuntimeError(f"Unexpected steering input rank: {steer_shape}")
self._steering_num_layers = int(steer_shape[1])
self._steering_hidden_size = int(steer_shape[2])
else:
self._steering_num_layers = 0
self._steering_hidden_size = 0
logger.warning(
"TRT decoder engine has no 'steering' input — activation-"
"steering knobs will no-op on this session. Rebuild with "
"`python -m acestep.engine.trt.build --all --force-rebuild` "
"to enable."
)
self._trt_output_dtype = _trt_dtype_map.get(
self._trt_engine.get_tensor_dtype("velocity"), torch.float32
)
self._trt_io_dtype = self._trt_input_dtypes["hidden_states"]
logger.info(
"TRT decoder engine ready (input_dtypes={}, output_dtype={})",
"TRT decoder engine ready (input_dtypes={}, output_dtype={}, "
"steering L={}, D={})",
self._trt_input_dtypes,
self._trt_output_dtype,
self._steering_num_layers,
self._steering_hidden_size,
)

# Try to initialize LoRA refit manager (requires REFIT-enabled engine)
Expand Down Expand Up @@ -526,11 +556,15 @@ def _trt_decoder_step(
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
context_latents: torch.Tensor,
steering: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Run one decoder step through TRT with pre-allocated buffers.

Handles odd-T padding. Caches buffers by shape for reuse across
steps. Calls execute_async_v3 on the shared polygraphy stream.

``steering`` is optional ``[B, num_layers, hidden_size]``;
when omitted the cached zero buffer makes the per-layer adds a no-op.
"""
orig_T = hidden_states.shape[1]
pad = orig_T % 2 == 1
Expand All @@ -548,13 +582,19 @@ def _trt_decoder_step(
dev = hidden_states.device
hs_shape, ts_shape, enc_shape, cl_shape = key
in_dtypes = self._trt_input_dtypes
B = hs_shape[0]

bufs = {
"hidden_states": torch.empty(hs_shape, dtype=in_dtypes["hidden_states"], device=dev),
"timestep": torch.empty(ts_shape, dtype=in_dtypes["timestep"], device=dev),
"encoder_hidden_states": torch.empty(enc_shape, dtype=in_dtypes["encoder_hidden_states"], device=dev),
"context_latents": torch.empty(cl_shape, dtype=in_dtypes["context_latents"], device=dev),
}
if self._steering_num_layers > 0:
bufs["steering"] = torch.zeros(
B, self._steering_num_layers, self._steering_hidden_size,
dtype=in_dtypes["steering"], device=dev,
)
for name, buf in bufs.items():
if not ctx.set_input_shape(name, tuple(buf.shape)):
raise RuntimeError(f"TRT decoder rejected input shape for {name}: {tuple(buf.shape)}")
Expand Down Expand Up @@ -589,6 +629,11 @@ def _trt_decoder_step(
bufs["context_latents"].copy_(context_latents)
bufs["timestep"].copy_(timestep)
bufs["encoder_hidden_states"].copy_(encoder_hidden_states)
if "steering" in bufs:
if steering is not None:
bufs["steering"].copy_(steering)
else:
bufs["steering"].zero_()

ctx = self._trt_ctx
for name, buf in bufs.items():
Expand Down
165 changes: 160 additions & 5 deletions acestep/engine/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import time
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Callable, Optional, List, Tuple, TYPE_CHECKING
from typing import Callable, Dict, NamedTuple, Optional, List, Tuple, TYPE_CHECKING

from loguru import logger
import torch
Expand All @@ -29,6 +29,13 @@
from .masking import LatentNoiseMask


class _SteeringApply(NamedTuple):
    """A single activation-steering shift, already resolved for one layer.

    Instances are grouped per layer index by the caller; the tuple itself
    only carries what is needed to apply the shift: the direction, its
    effective strength, and the denoise step at which it is allowed to fire.
    """

    # Steering direction, 1-D ``[hidden_dim]``.
    vector: torch.Tensor
    # Effective strength of the shift: ``alpha * magnitude``.
    scale: float
    # Gate: only rows currently at this denoise step receive the shift.
    step: int


@dataclass
class SlotCondition:
"""One conditioning entry for multi-condition per-frame blending.
Expand Down Expand Up @@ -285,6 +292,10 @@ def __init__(
self._trt_io_dtype = getattr(engine, '_trt_io_dtype', torch.float32)
self._trt_input_dtypes = getattr(engine, "_trt_input_dtypes", {}) or {}
self._trt_output_dtype = getattr(engine, "_trt_output_dtype", self._trt_io_dtype)
# Steering shape (constants per engine); snapshotted so the
# per-tick buffer fill doesn't re-query TRT.
self._steering_num_layers = getattr(engine, "_steering_num_layers", 0)
self._steering_hidden_size = getattr(engine, "_steering_hidden_size", 0)

# Currently-bound TRT I/O buffers (set by _ensure_trt_bufs to one
# entry of _trt_bufs_cache). _trt_forward reads these directly.
Expand Down Expand Up @@ -325,6 +336,15 @@ def __init__(
# ``.to(...)`` is a no-op.
self._channel_gain: Optional[torch.Tensor] = None

# Activation steering. Per-DiT-layer additive shift on the
# post-block residual, gated per-row by denoise step.
# ``_current_step_per_row`` is populated by _tick_complex_pt
# around each forward; empty means "skip injection" so a
# forward issued outside the rendezvous can't fire steering.
self._steering_by_layer: Dict[int, List[_SteeringApply]] = {}
self._steering_hooks_installed: bool = False
self._current_step_per_row: List[int] = []

# Sentinel tensors for the "always-on multiply" idiom in the step
# helpers. Built lazily once the first slot's device/dtype is known.
# ``_ones_3d`` stands in for absent ``velocity_scale`` (vt * 1 = vt).
Expand Down Expand Up @@ -402,9 +422,14 @@ def _on_engine_swapped(self) -> None:
self._trt_output_dtype = getattr(
engine, "_trt_output_dtype", self._trt_io_dtype
)
self._steering_num_layers = getattr(engine, "_steering_num_layers", 0)
self._steering_hidden_size = getattr(engine, "_steering_hidden_size", 0)
self._trt_bufs = None
self._trt_out_buf = None
self._trt_bufs_cache.clear()
# Hooks live on the old decoder.layers; the new decoder needs a
# fresh install on the next non-empty set_steering call.
self._steering_hooks_installed = False

def submit(self, request: SlotRequest) -> None:
"""Enqueue a generation request.
Expand Down Expand Up @@ -712,6 +737,10 @@ def _trt_forward(
else:
bufs["context_latents"].copy_(ctx_io)

# Steering: absent on non-spectral engines.
if "steering" in bufs:
self._fill_trt_steering_buffer(bufs["steering"], B)

# Rebind and execute.
ctx = self._trt_ctx
for name, buf in bufs.items():
Expand All @@ -731,6 +760,39 @@ def _trt_forward(
return out[:, :T, :].to(self._dtype)
return out.to(self._dtype)

def _steering_row_mask(self, target_step: int, B: int) -> Optional[List[int]]:
    """Return the batch rows currently at ``target_step``.

    The per-row step list is read from ``self._current_step_per_row``.
    ``None`` is returned (meaning "do not inject") when that list is
    empty, when its length differs from ``B``, or when no row is at
    ``target_step``; otherwise the matching row indices are returned in
    ascending order.
    """
    per_row = self._current_step_per_row
    # No mapping published for this forward: never fire blindly.
    if not per_row:
        return None
    # Mapping belongs to a different batch layout than the one in flight.
    if len(per_row) != B:
        return None
    hits = []
    for row, step in enumerate(per_row):
        if step == target_step:
            hits.append(row)
    return hits if hits else None

def _fill_trt_steering_buffer(self, buf: torch.Tensor, B: int) -> None:
    """Populate the TRT steering buffer for one forward.

    ``buf`` is indexed as ``[row, layer, hidden]``. It is zeroed first so
    content from a previous tick cannot leak; rows with no matching shift
    therefore stay zero. Every configured shift whose layer index lies in
    ``[0, self._steering_num_layers)`` is then accumulated into the rows
    that ``self._steering_row_mask`` reports for the shift's step.

    Args:
        buf: Steering input buffer, modified in place.
        B: Batch size of the forward being prepared; forwarded to
            ``_steering_row_mask`` so a stale per-row mapping is rejected.
    """
    buf.zero_()
    by_layer = self._steering_by_layer
    if not by_layer:
        return

    layer_count = self._steering_num_layers
    # The row mask depends only on (step, B), not on the layer, so resolve
    # it once per distinct step instead of once per (layer, shift) pair.
    rows_by_step = {}
    for layer_idx, entries in by_layer.items():
        # Layers the engine does not expose are ignored.
        if not 0 <= layer_idx < layer_count:
            continue
        for entry in entries:
            gate = entry.step
            if gate not in rows_by_step:
                rows_by_step[gate] = self._steering_row_mask(gate, B)
            rows = rows_by_step[gate]
            if rows is None:
                continue
            shift = entry.vector.to(device=buf.device, dtype=buf.dtype)
            # Several shifts may target the same (row, layer); they sum.
            buf[rows, layer_idx, :] += entry.scale * shift

# ------------------------------------------------------------------
# TRT buffer management
# ------------------------------------------------------------------
Expand Down Expand Up @@ -787,6 +849,15 @@ def _ensure_trt_bufs(self, B: int, T: int, max_L: int):
device=device,
),
}
if self._steering_num_layers > 0:
# Zeroed so a tick with no active configs is a true no-op
# (engine still adds zeros per layer); repopulated each
# forward by _trt_forward.
bufs["steering"] = torch.zeros(
B, self._steering_num_layers, self._steering_hidden_size,
dtype=in_dtypes.get("steering", io_dtype),
device=device,
)

for name, buf in bufs.items():
if not ctx.set_input_shape(name, tuple(buf.shape)):
Expand Down Expand Up @@ -1033,7 +1104,13 @@ def _forward_pairs(
for c in pos_conds_per_slot[si]:
pos_pair_si.append(si)
pos_pair_cond.append(c)
vt_pos_all = _forward_pairs(pos_pair_si, pos_pair_cond)
# Per-row step mapping for the steering hook; try/finally so a
# stale list never leaks into a forward outside this rendezvous.
self._current_step_per_row = [slots[si].step_idx for si in pos_pair_si]
try:
vt_pos_all = _forward_pairs(pos_pair_si, pos_pair_cond)
finally:
self._current_step_per_row = []

# --- Negative pass (CFG only): skipped when no slot has CFG. ---
neg_pair_si: List[int] = []
Expand All @@ -1042,9 +1119,14 @@ def _forward_pairs(
for c in neg_conds_per_slot[si]:
neg_pair_si.append(si)
neg_pair_cond.append(c)
vt_neg_all = (
_forward_pairs(neg_pair_si, neg_pair_cond) if neg_pair_si else None
)
if neg_pair_si:
self._current_step_per_row = [slots[si].step_idx for si in neg_pair_si]
try:
vt_neg_all = _forward_pairs(neg_pair_si, neg_pair_cond)
finally:
self._current_step_per_row = []
else:
vt_neg_all = None

# --- Per-slot: blend pos, blend neg (if CFG), APG-combine ---
vt_per_slot: List[torch.Tensor] = [None] * len(slots) # type: ignore[list-item]
Expand Down Expand Up @@ -1425,6 +1507,79 @@ def set_channel_gain_tensor(self, gain: Optional[torch.Tensor]) -> None:
dt = self._dtype or torch.float16
self._channel_gain = gain.to(device=dev, dtype=dt)

def set_steering(self, configs: List[Dict[str, object]]) -> None:
    """Replace the active activation-steering configuration.

    Every entry of ``configs`` is a dict with:
      - ``layer``: int, index into ``decoder.layers``
      - ``step``: int, the denoise step at which the shift fires
      - ``vector``: 1-D unit-norm tensor [hidden_dim]
      - ``magnitude``: float, paired mean-diff scale (defaults to 1.0)
      - ``alpha``: float, knob value (defaults to 0.0); the effective
        shift is ``alpha * magnitude * vector``

    Entries whose ``alpha`` is zero are dropped. Several entries may
    target the same layer; they are kept in input order. Passing ``[]``
    clears steering. The new mapping replaces ``_steering_by_layer`` in a
    single assignment after all entries have been parsed.

    On the eager path the forward hooks on ``decoder.layers`` are
    installed lazily, the first time a non-empty configuration arrives
    while no TRT engine is bound. With TRT the shifts are written into
    the engine's steering buffer per tick by ``_trt_forward``.
    """
    grouped: Dict[int, List[_SteeringApply]] = {}
    for cfg in configs:
        knob = float(cfg.get("alpha", 0.0))
        if knob == 0.0:
            # A zero knob contributes nothing; keep it out of the hot path.
            continue
        magnitude = float(cfg.get("magnitude", 1.0))
        layer_index = int(cfg["layer"])
        entry = _SteeringApply(
            vector=cfg["vector"],  # type: ignore[arg-type]
            scale=knob * magnitude,
            step=int(cfg["step"]),
        )
        if layer_index in grouped:
            grouped[layer_index].append(entry)
        else:
            grouped[layer_index] = [entry]
    self._steering_by_layer = grouped

    # Hooks are only needed when the PyTorch decoder actually runs; with
    # TRT active the decoder may be a stub.
    if not grouped:
        return
    if self._steering_hooks_installed:
        return
    if self._trt_engine is not None:
        return
    self._install_steering_hooks()
    self._steering_hooks_installed = True

def _install_steering_hooks(self) -> None:
    """Attach one steering forward hook to every layer of ``self.decoder``.

    Each hook looks up ``self._steering_by_layer`` for its layer index at
    call time and, for every configured shift, adds
    ``scale * vector`` to the rows that ``self._steering_row_mask``
    reports for the shift's step. The layer output (or its first element
    when the output is a tuple) is modified in place.

    Hook handles are kept in ``self._steering_hook_handles``. Handles left
    over from an earlier install are removed first, so calling this method
    again never stacks a second set of hooks on layers that are still in
    use, which would apply every shift twice.
    """
    # Remove hooks from any previous install before adding new ones.
    for stale in getattr(self, "_steering_hook_handles", ()):
        stale.remove()

    def make_hook(layer_idx: int):
        # Factory so each closure binds its own layer index.
        def _hook(_module, _inputs, output):
            entries = self._steering_by_layer.get(layer_idx)
            if not entries:
                return output
            is_tuple = isinstance(output, tuple)
            hs = output[0] if is_tuple else output
            B = hs.shape[0]
            changed = False
            for entry in entries:
                rows = self._steering_row_mask(entry.step, B)
                if rows is None:
                    continue
                v = entry.vector.to(device=hs.device, dtype=hs.dtype)
                # The (1, 1, -1) view broadcasts over the leading two dims;
                # assumes hs is 3-D with hidden last -- TODO confirm.
                hs[rows] = hs[rows] + entry.scale * v.view(1, 1, -1)
                changed = True
            if not changed:
                # Hand back the untouched original object.
                return output
            if is_tuple:
                return (hs,) + output[1:]
            return hs
        return _hook

    self._steering_hook_handles = [
        layer.register_forward_hook(make_hook(idx))
        for idx, layer in enumerate(self.decoder.layers)
    ]

def stats(self) -> dict:
return {
"ticks": self.ticks,
Expand Down
5 changes: 4 additions & 1 deletion acestep/engine/trt/_engine_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@
from loguru import logger


_ENGINE_METADATA_SCHEMA = 1
# Schema 2: decoder ONNX/engines carry the ``steering`` input (the
# ``spectral_decoder_*`` family); pre-steering engine metadata is
# invalidated so the next build run refreshes it.
_ENGINE_METADATA_SCHEMA = 2


def _sha256_file(path: str | os.PathLike[str]) -> str:
Expand Down
17 changes: 15 additions & 2 deletions acestep/engine/trt/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,12 +899,25 @@ def _print_matrix(durations, build_vae, build_decoder, output_dir, batch_max,
checkpoint="acestep-v15-turbo", build_dreamvae=False):
"""Print the build matrix for --all mode, showing existing vs new."""
variant = _checkpoint_to_variant(checkpoint)
vtag = f"_{variant}" if variant != "turbo" else ""

from acestep.paths import (
WINDOWED_VAE_DECODE_NAME,
WINDOWED_DREAMVAE_DECODE_NAME,
)
# Defer to TRTBuildConfig.engine_filename so this preview can't
# drift from what _build_decoder_engine writes.
from .export import TRTBuildConfig

def _decoder_dir_name(dur: int) -> str:
cfg = TRTBuildConfig(
fp16=True,
strongly_typed=True,
refit=True,
batch_max=batch_max,
seq_max=dur * 25,
variant=variant,
)
return cfg.engine_filename().replace(".engine", "")

# (label, engine_dir_name) pairs
jobs = []
Expand All @@ -913,7 +926,7 @@ def _print_matrix(durations, build_vae, build_decoder, output_dir, batch_max,
jobs.append((f"VAE decode {dur}s", f"vae_decode_fp16_{dur}s"))
jobs.append((f"VAE encode {dur}s", f"vae_encode_fp16_{dur}s"))
if build_decoder:
jobs.append((f"Decoder {variant} {dur}s, refit", f"decoder{vtag}_mixed_refit_b{batch_max}_{dur}s"))
jobs.append((f"Decoder {variant} {dur}s, refit", _decoder_dir_name(dur)))
if build_dreamvae:
jobs.append((f"DreamVAE decode {dur}s", f"dreamvae_decode_fp16_{dur}s"))

Expand Down
Loading