From 7f92e8ff1141fbe234a287589315fa8e993c0f8f Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Thu, 28 May 2026 09:36:09 -0400 Subject: [PATCH 1/3] feat: spectral control backend --- acestep/engine/diffusion.py | 49 ++++- acestep/engine/stream.py | 165 ++++++++++++++++- acestep/engine/trt/_engine_metadata.py | 5 +- acestep/engine/trt/build.py | 17 +- acestep/engine/trt/export.py | 52 +++++- acestep/engine/trt/runtime.py | 25 ++- acestep/paths.py | 45 ++++- acestep/steering/__init__.py | 57 ++++++ acestep/steering/catalog.py | 109 +++++++++++ acestep/steering/controller.py | 238 +++++++++++++++++++++++++ acestep/steering/hub.py | 155 ++++++++++++++++ acestep/steering/policy.py | 145 +++++++++++++++ acestep/steering/types.py | 61 +++++++ 13 files changed, 1101 insertions(+), 22 deletions(-) create mode 100644 acestep/steering/__init__.py create mode 100644 acestep/steering/catalog.py create mode 100644 acestep/steering/controller.py create mode 100644 acestep/steering/hub.py create mode 100644 acestep/steering/policy.py create mode 100644 acestep/steering/types.py diff --git a/acestep/engine/diffusion.py b/acestep/engine/diffusion.py index b3e32b49..02c8d0f8 100644 --- a/acestep/engine/diffusion.py +++ b/acestep/engine/diffusion.py @@ -198,19 +198,49 @@ def load_trt_engine(self, engine_path): if hasattr(trt, "bfloat16"): _trt_dtype_map[trt.bfloat16] = torch.bfloat16 - input_names = ("hidden_states", "timestep", "encoder_hidden_states", "context_latents") + engine_input_names = { + self._trt_engine.get_tensor_name(i) + for i in range(self._trt_engine.num_io_tensors) + } + # ``steering`` is present on spectral-prefixed engines only; + # older builds get _steering_num_layers == 0 and the streaming + # pipeline skips the buffer entirely. + has_steering = "steering" in engine_input_names + base_inputs = ( + "hidden_states", "timestep", "encoder_hidden_states", + "context_latents", + ) + input_names = base_inputs + (("steering",) if has_steering else ()) self._trt_input_dtypes = { name: _trt_dtype_map.get(self._trt_engine.get_tensor_dtype(name), torch.float32) for name in input_names } + if has_steering: + steer_shape = tuple(self._trt_engine.get_tensor_shape("steering")) + if len(steer_shape) != 3: + raise RuntimeError(f"Unexpected steering input rank: {steer_shape}") + self._steering_num_layers = int(steer_shape[1]) + self._steering_hidden_size = int(steer_shape[2]) + else: + self._steering_num_layers = 0 + self._steering_hidden_size = 0 + logger.warning( + "TRT decoder engine has no 'steering' input — activation-" + "steering knobs will no-op on this session. Rebuild with " + "`python -m acestep.engine.trt.build --all --force-rebuild` " + "to enable." + ) self._trt_output_dtype = _trt_dtype_map.get( self._trt_engine.get_tensor_dtype("velocity"), torch.float32 ) self._trt_io_dtype = self._trt_input_dtypes["hidden_states"] logger.info( - "TRT decoder engine ready (input_dtypes={}, output_dtype={})", + "TRT decoder engine ready (input_dtypes={}, output_dtype={}, " + "steering L={}, D={})", self._trt_input_dtypes, self._trt_output_dtype, + self._steering_num_layers, + self._steering_hidden_size, ) # Try to initialize LoRA refit manager (requires REFIT-enabled engine) @@ -526,11 +556,15 @@ def _trt_decoder_step( timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, context_latents: torch.Tensor, + steering: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Run one decoder step through TRT with pre-allocated buffers. Handles odd-T padding. Caches buffers by shape for reuse across steps. Calls execute_async_v3 on the shared polygraphy stream. + + ``steering`` is optional ``[B, num_layers, hidden_size]``; + when omitted the cached zero buffer makes the per-layer adds a no-op. """ orig_T = hidden_states.shape[1] pad = orig_T % 2 == 1 @@ -548,6 +582,7 @@ def _trt_decoder_step( dev = hidden_states.device hs_shape, ts_shape, enc_shape, cl_shape = key in_dtypes = self._trt_input_dtypes + B = hs_shape[0] bufs = { "hidden_states": torch.empty(hs_shape, dtype=in_dtypes["hidden_states"], device=dev), @@ -555,6 +590,11 @@ def _trt_decoder_step( "encoder_hidden_states": torch.empty(enc_shape, dtype=in_dtypes["encoder_hidden_states"], device=dev), "context_latents": torch.empty(cl_shape, dtype=in_dtypes["context_latents"], device=dev), } + if self._steering_num_layers > 0: + bufs["steering"] = torch.zeros( + B, self._steering_num_layers, self._steering_hidden_size, + dtype=in_dtypes["steering"], device=dev, + ) for name, buf in bufs.items(): if not ctx.set_input_shape(name, tuple(buf.shape)): raise RuntimeError(f"TRT decoder rejected input shape for {name}: {tuple(buf.shape)}") @@ -589,6 +629,11 @@ def _trt_decoder_step( bufs["context_latents"].copy_(context_latents) bufs["timestep"].copy_(timestep) bufs["encoder_hidden_states"].copy_(encoder_hidden_states) + if "steering" in bufs: + if steering is not None: + bufs["steering"].copy_(steering) + else: + bufs["steering"].zero_() ctx = self._trt_ctx for name, buf in bufs.items(): diff --git a/acestep/engine/stream.py b/acestep/engine/stream.py index a4672494..839b45c2 100644 --- a/acestep/engine/stream.py +++ b/acestep/engine/stream.py @@ -15,7 +15,7 @@ import time from collections import OrderedDict from dataclasses import dataclass, field -from typing import Callable, Optional, List, Tuple, TYPE_CHECKING +from typing import Callable, Dict, NamedTuple, Optional, List, Tuple, TYPE_CHECKING from loguru import logger import torch @@ -29,6 +29,13 @@ from .masking import LatentNoiseMask +class _SteeringApply(NamedTuple): + """One pre-resolved activation-steering shift bound to a layer.""" + vector: torch.Tensor # 1-D [hidden_dim] + scale: float # alpha * magnitude + step: int # gate: only rows at this denoise step receive it + + @dataclass class SlotCondition: """One conditioning entry for multi-condition per-frame blending. @@ -285,6 +292,10 @@ def __init__( self._trt_io_dtype = getattr(engine, '_trt_io_dtype', torch.float32) self._trt_input_dtypes = getattr(engine, "_trt_input_dtypes", {}) or {} self._trt_output_dtype = getattr(engine, "_trt_output_dtype", self._trt_io_dtype) + # Steering shape (constants per engine); snapshotted so the + # per-tick buffer fill doesn't re-query TRT. + self._steering_num_layers = getattr(engine, "_steering_num_layers", 0) + self._steering_hidden_size = getattr(engine, "_steering_hidden_size", 0) # Currently-bound TRT I/O buffers (set by _ensure_trt_bufs to one # entry of _trt_bufs_cache). _trt_forward reads these directly. @@ -325,6 +336,15 @@ def __init__( # ``.to(...)`` is a no-op. self._channel_gain: Optional[torch.Tensor] = None + # Activation steering. Per-DiT-layer additive shift on the + # post-block residual, gated per-row by denoise step. + # ``_current_step_per_row`` is populated by _tick_complex_pt + # around each forward; empty means "skip injection" so a + # forward issued outside the rendezvous can't fire steering. + self._steering_by_layer: Dict[int, List[_SteeringApply]] = {} + self._steering_hooks_installed: bool = False + self._current_step_per_row: List[int] = [] + # Sentinel tensors for the "always-on multiply" idiom in the step # helpers. Built lazily once the first slot's device/dtype is known. # ``_ones_3d`` stands in for absent ``velocity_scale`` (vt * 1 = vt). @@ -402,9 +422,14 @@ def _on_engine_swapped(self) -> None: self._trt_output_dtype = getattr( engine, "_trt_output_dtype", self._trt_io_dtype ) + self._steering_num_layers = getattr(engine, "_steering_num_layers", 0) + self._steering_hidden_size = getattr(engine, "_steering_hidden_size", 0) self._trt_bufs = None self._trt_out_buf = None self._trt_bufs_cache.clear() + # Hooks live on the old decoder.layers; the new decoder needs a + # fresh install on the next non-empty set_steering call. + self._steering_hooks_installed = False def submit(self, request: SlotRequest) -> None: """Enqueue a generation request. @@ -712,6 +737,10 @@ def _trt_forward( else: bufs["context_latents"].copy_(ctx_io) + # Steering: absent on non-spectral engines. + if "steering" in bufs: + self._fill_trt_steering_buffer(bufs["steering"], B) + # Rebind and execute. ctx = self._trt_ctx for name, buf in bufs.items(): @@ -731,6 +760,39 @@ def _trt_forward( return out[:, :T, :].to(self._dtype) return out.to(self._dtype) + def _steering_row_mask(self, target_step: int, B: int) -> Optional[List[int]]: + """Rows whose slot is at ``target_step``, or None to skip. + + Returns None when no per-row step mapping exists or it + disagrees with ``B`` — the eager hook and TRT buffer fill both + skip injection in that case rather than firing blindly. + """ + row_steps = self._current_step_per_row + if not row_steps or len(row_steps) != B: + return None + mask = [i for i, s in enumerate(row_steps) if s == target_step] + return mask or None + + def _fill_trt_steering_buffer(self, buf: torch.Tensor, B: int) -> None: + """Populate the TRT steering buffer for one forward. + + ``buf`` is ``[B, num_layers, hidden_size]``; zeroed first so + previous-tick content doesn't leak. Rows with no matching shift + stay zero, which the engine adds as a no-op per layer. + """ + buf.zero_() + if not self._steering_by_layer: + return + for layer_idx, applies in self._steering_by_layer.items(): + if layer_idx < 0 or layer_idx >= self._steering_num_layers: + continue + for apply in applies: + mask_rows = self._steering_row_mask(apply.step, B) + if mask_rows is None: + continue + v = apply.vector.to(device=buf.device, dtype=buf.dtype) + buf[mask_rows, layer_idx, :] += apply.scale * v + # ------------------------------------------------------------------ # TRT buffer management # ------------------------------------------------------------------ @@ -787,6 +849,15 @@ def _ensure_trt_bufs(self, B: int, T: int, max_L: int): device=device, ), } + if self._steering_num_layers > 0: + # Zeroed so a tick with no active configs is a true no-op + # (engine still adds zeros per layer); repopulated each + # forward by _trt_forward. + bufs["steering"] = torch.zeros( + B, self._steering_num_layers, self._steering_hidden_size, + dtype=in_dtypes.get("steering", io_dtype), + device=device, + ) for name, buf in bufs.items(): if not ctx.set_input_shape(name, tuple(buf.shape)): @@ -1033,7 +1104,13 @@ def _forward_pairs( for c in pos_conds_per_slot[si]: pos_pair_si.append(si) pos_pair_cond.append(c) - vt_pos_all = _forward_pairs(pos_pair_si, pos_pair_cond) + # Per-row step mapping for the steering hook; try/finally so a + # stale list never leaks into a forward outside this rendezvous. + self._current_step_per_row = [slots[si].step_idx for si in pos_pair_si] + try: + vt_pos_all = _forward_pairs(pos_pair_si, pos_pair_cond) + finally: + self._current_step_per_row = [] # --- Negative pass (CFG only): skipped when no slot has CFG. --- neg_pair_si: List[int] = [] @@ -1042,9 +1119,14 @@ def _forward_pairs( for c in neg_conds_per_slot[si]: neg_pair_si.append(si) neg_pair_cond.append(c) - vt_neg_all = ( - _forward_pairs(neg_pair_si, neg_pair_cond) if neg_pair_si else None - ) + if neg_pair_si: + self._current_step_per_row = [slots[si].step_idx for si in neg_pair_si] + try: + vt_neg_all = _forward_pairs(neg_pair_si, neg_pair_cond) + finally: + self._current_step_per_row = [] + else: + vt_neg_all = None # --- Per-slot: blend pos, blend neg (if CFG), APG-combine --- vt_per_slot: List[torch.Tensor] = [None] * len(slots) # type: ignore[list-item] @@ -1425,6 +1507,79 @@ def set_channel_gain_tensor(self, gain: Optional[torch.Tensor]) -> None: dt = self._dtype or torch.float16 self._channel_gain = gain.to(device=dev, dtype=dt) + def set_steering(self, configs: List[Dict[str, object]]) -> None: + """Set activation-steering configs. + + Each config dict has: + - ``layer``: int, index into ``decoder.layers`` + - ``step``: int, the denoise step at which to fire + - ``vector``: 1-D unit-norm tensor [hidden_dim] + - ``magnitude``: float, paired mean-diff scale + - ``alpha``: float, knob value; effective shift is + ``alpha * magnitude * vector`` + + Multiple configs may share a layer (additions sum, modulo the + per-row step gate). Zero-alpha entries drop. Pass ``[]`` to + clear. Eager: forward hooks on ``decoder.layers`` (installed + lazily). TRT: per-tick buffer fill in ``_trt_forward``. + """ + by_layer: Dict[int, List[_SteeringApply]] = {} + for c in configs: + alpha = float(c.get("alpha", 0.0)) + if alpha == 0.0: + continue + mag = float(c.get("magnitude", 1.0)) + li = int(c["layer"]) + by_layer.setdefault(li, []).append(_SteeringApply( + vector=c["vector"], # type: ignore[arg-type] + scale=alpha * mag, + step=int(c["step"]), + )) + self._steering_by_layer = by_layer + # Eager-path hooks only matter when the PyTorch decoder runs. + # With TRT active the decoder may be a stub. + if ( + by_layer + and not self._steering_hooks_installed + and self._trt_engine is None + ): + self._install_steering_hooks() + self._steering_hooks_installed = True + + def _install_steering_hooks(self) -> None: + """Attach one forward hook per DiT layer. + + Reads ``_steering_by_layer`` + ``_current_step_per_row`` at call + time. Idempotent via ``_steering_hooks_installed`` (cleared on + engine swap so a new eager decoder gets fresh hooks). + """ + layers = self.decoder.layers + + def make_hook(layer_idx: int): + def _hook(_module, _inputs, output): + applies = self._steering_by_layer.get(layer_idx) + if not applies: + return output + hs = output[0] if isinstance(output, tuple) else output + B = hs.shape[0] + changed = False + for apply in applies: + mask_rows = self._steering_row_mask(apply.step, B) + if mask_rows is None: + continue + v = apply.vector.to(device=hs.device, dtype=hs.dtype) + hs[mask_rows] = hs[mask_rows] + apply.scale * v.view(1, 1, -1) + changed = True + if not changed: + return output + if isinstance(output, tuple): + return (hs,) + output[1:] + return hs + return _hook + + for li in range(len(layers)): + layers[li].register_forward_hook(make_hook(li)) + def stats(self) -> dict: return { "ticks": self.ticks, diff --git a/acestep/engine/trt/_engine_metadata.py b/acestep/engine/trt/_engine_metadata.py index c49097be..e3571c90 100644 --- a/acestep/engine/trt/_engine_metadata.py +++ b/acestep/engine/trt/_engine_metadata.py @@ -37,7 +37,10 @@ from loguru import logger -_ENGINE_METADATA_SCHEMA = 1 +# Schema 2: decoder ONNX/engines carry the ``steering`` input (the +# ``spectral_decoder_*`` family); pre-steering engine metadata is +# invalidated so the next build run refreshes it. +_ENGINE_METADATA_SCHEMA = 2 def _sha256_file(path: str | os.PathLike[str]) -> str: diff --git a/acestep/engine/trt/build.py b/acestep/engine/trt/build.py index 57159265..969c8369 100644 --- a/acestep/engine/trt/build.py +++ b/acestep/engine/trt/build.py @@ -899,12 +899,25 @@ def _print_matrix(durations, build_vae, build_decoder, output_dir, batch_max, checkpoint="acestep-v15-turbo", build_dreamvae=False): """Print the build matrix for --all mode, showing existing vs new.""" variant = _checkpoint_to_variant(checkpoint) - vtag = f"_{variant}" if variant != "turbo" else "" from acestep.paths import ( WINDOWED_VAE_DECODE_NAME, WINDOWED_DREAMVAE_DECODE_NAME, ) + # Defer to TRTBuildConfig.engine_filename so this preview can't + # drift from what _build_decoder_engine writes. + from .export import TRTBuildConfig + + def _decoder_dir_name(dur: int) -> str: + cfg = TRTBuildConfig( + fp16=True, + strongly_typed=True, + refit=True, + batch_max=batch_max, + seq_max=dur * 25, + variant=variant, + ) + return cfg.engine_filename().replace(".engine", "") # (label, engine_dir_name) pairs jobs = [] @@ -913,7 +926,7 @@ def _print_matrix(durations, build_vae, build_decoder, output_dir, batch_max, jobs.append((f"VAE decode {dur}s", f"vae_decode_fp16_{dur}s")) jobs.append((f"VAE encode {dur}s", f"vae_encode_fp16_{dur}s")) if build_decoder: - jobs.append((f"Decoder {variant} {dur}s, refit", f"decoder{vtag}_mixed_refit_b{batch_max}_{dur}s")) + jobs.append((f"Decoder {variant} {dur}s, refit", _decoder_dir_name(dur))) if build_dreamvae: jobs.append((f"DreamVAE decode {dur}s", f"dreamvae_decode_fp16_{dur}s")) diff --git a/acestep/engine/trt/export.py b/acestep/engine/trt/export.py index 873f48ac..7a8fda60 100644 --- a/acestep/engine/trt/export.py +++ b/acestep/engine/trt/export.py @@ -334,6 +334,7 @@ def _export_forward( encoder_hidden_states, encoder_attention_mask, context_latents, + steering, use_cache=None, past_key_values=None, cache_position=None, @@ -380,7 +381,10 @@ def _export_forward( ) sw_mask = sw_mask.unsqueeze(0).unsqueeze(0) # [1, 1, S, S] - # Layer loop: static branching on layer_types (config, not runtime) + # Layer loop: static branching on layer_types. The added + # ``steering[:, i, :]`` shift is the in-engine equivalent of + # StreamPipeline._install_steering_hooks; the host zeros rows + # that aren't at an active step so they get a no-op add. for i, layer_module in enumerate(self_dec.layers): attn_mask = sw_mask if layer_types[i] == "sliding_attention" else None layer_outputs = layer_module( @@ -396,7 +400,7 @@ def _export_forward( encoder_hidden_states, None, # encoder_attention_mask ) - hidden_states = layer_outputs[0] + hidden_states = layer_outputs[0] + steering[:, i, :].unsqueeze(1).type_as(layer_outputs[0]) # Output AdaLN + proj_out (ConvTranspose1d stride=2 doubles seq_len) shift, scale = (self_dec.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) @@ -415,6 +419,7 @@ def forward( timestep: torch.Tensor, # [B] encoder_hidden_states: torch.Tensor, # [B, L_enc, 2048] context_latents: torch.Tensor, # [B, T, 128] + steering: torch.Tensor, # [B, num_layers, hidden_size] ) -> torch.Tensor: outputs = self.decoder( hidden_states=hidden_states, @@ -424,6 +429,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=None, context_latents=context_latents, + steering=steering, use_cache=False, past_key_values=None, output_attentions=False, @@ -519,12 +525,17 @@ def export_decoder_onnx( B = config.batch_size T = config.seq_len L = config.enc_len + # Steering: one additive vector per (batch row, DiT layer). Sized + # off the live decoder so XL reuses this path without a config knob. + num_layers = decoder.config.num_hidden_layers + hidden_size = decoder.config.hidden_size example_inputs = ( torch.randn(B, T, 64, device=device, dtype=trace_dtype), torch.full((B,), 0.5, device=device, dtype=ts_dtype), torch.randn(B, L, 2048, device=device, dtype=trace_dtype), torch.randn(B, T, 128, device=device, dtype=trace_dtype), + torch.zeros(B, num_layers, hidden_size, device=device, dtype=trace_dtype), ) input_names = [ @@ -532,6 +543,7 @@ def export_decoder_onnx( "timestep", "encoder_hidden_states", "context_latents", + "steering", ] output_names = ["velocity"] @@ -540,6 +552,8 @@ def export_decoder_onnx( "timestep": {0: "batch"}, "encoder_hidden_states": {0: "batch", 1: "enc_len"}, "context_latents": {0: "batch", 1: "seq_len"}, + # num_layers / hidden_size are static within a checkpoint family. + "steering": {0: "batch"}, "velocity": {0: "batch", 1: "seq_len"}, } @@ -590,6 +604,7 @@ def export_decoder_onnx( "timestep": {0: batch}, "encoder_hidden_states": {0: batch, 1: enc}, "context_latents": {0: batch, 1: seq}, + "steering": {0: batch}, } torch.onnx.export( wrapper, @@ -1064,11 +1079,12 @@ def max_duration_s(self) -> int: return self.seq_max // 25 def engine_filename(self) -> str: - """Generate a standardized engine filename from build config. + """Standardized engine filename. - Format: decoder_{variant}_{precision}[_refit]_b{batch_max}_{duration}s.engine - The variant tag is omitted for "turbo" (backward compat). - Uses seconds so naming is stable across frame rates. + Format: ``spectral_decoder_{variant}_{prec}[_refit]_b{batch_max}_{dur}s.engine``; + variant tag omitted for "turbo". The ``spectral_`` prefix marks + engines that carry the steering input; pre-steering engines used + ``decoder_*`` and the runtime rejects them. """ if self.strongly_typed: # fp8_mixed gets its own tag so FP8 engines never collide @@ -1089,7 +1105,7 @@ def engine_filename(self) -> str: dur = self.max_duration_s # Include variant in name for non-turbo models variant_tag = f"_{self.variant}" if self.variant != "turbo" else "" - return f"decoder{variant_tag}_{prec}{refit_tag}_b{self.batch_max}_{dur}s.engine" + return f"spectral_decoder{variant_tag}_{prec}{refit_tag}_b{self.batch_max}_{dur}s.engine" def build_trt_engine( @@ -1192,6 +1208,28 @@ def build_trt_engine( min=(Bmin, Smin, 128), opt=(Bopt, Sopt, 128), max=(Bmax, Smax, 128), ) + # Read L, D off the parsed network so this stays variant-agnostic. + steering_idx = None + for ti in range(network.num_inputs): + if network.get_input(ti).name == "steering": + steering_idx = ti + break + if steering_idx is None: + raise RuntimeError( + "Decoder ONNX is missing the 'steering' input; re-export with " + "the current export.py." + ) + steering_shape = tuple(network.get_input(steering_idx).shape) + if len(steering_shape) != 3: + raise RuntimeError(f"Unexpected steering input rank: {steering_shape}") + _, steer_L, steer_D = steering_shape + profile.set_shape( + "steering", + min=(Bmin, steer_L, steer_D), + opt=(Bopt, steer_L, steer_D), + max=(Bmax, steer_L, steer_D), + ) + profile_idx = build_config.add_optimization_profile(profile) if profile_idx < 0: raise RuntimeError("Failed to add TensorRT optimization profile") diff --git a/acestep/engine/trt/runtime.py b/acestep/engine/trt/runtime.py index 392fa8b9..beb9fc19 100644 --- a/acestep/engine/trt/runtime.py +++ b/acestep/engine/trt/runtime.py @@ -51,7 +51,10 @@ class TRTDecoder: ) """ - INPUT_NAMES = ("hidden_states", "timestep", "encoder_hidden_states", "context_latents") + INPUT_NAMES = ( + "hidden_states", "timestep", "encoder_hidden_states", + "context_latents", "steering", + ) OUTPUT_NAME = "velocity" def __init__( @@ -93,6 +96,13 @@ def __init__( out_trt_dtype = self.engine.get_tensor_dtype(self.OUTPUT_NAME) self._output_dtype = dtype_map.get(out_trt_dtype, torch.float32) + # Static L, D; B is dynamic per call. + steer_shape = tuple(self.engine.get_tensor_shape("steering")) + if len(steer_shape) != 3: + raise RuntimeError(f"Unexpected steering input rank: {steer_shape}") + self._steer_L = int(steer_shape[1]) + self._steer_D = int(steer_shape[2]) + # Per-shape buffer cache: shape_key -> {bufs, output} self._buf_cache: dict[tuple, dict] = {} @@ -115,11 +125,16 @@ def _get_bufs(self, hs_shape, ts_shape, enc_shape, cl_shape): dev = self.device in_dt = self._input_dtypes + B = hs_shape[0] bufs = { "hidden_states": torch.empty(hs_shape, dtype=in_dt["hidden_states"], device=dev), "timestep": torch.empty(ts_shape, dtype=in_dt["timestep"], device=dev), "encoder_hidden_states": torch.empty(enc_shape, dtype=in_dt["encoder_hidden_states"], device=dev), "context_latents": torch.empty(cl_shape, dtype=in_dt["context_latents"], device=dev), + "steering": torch.zeros( + B, self._steer_L, self._steer_D, + dtype=in_dt["steering"], device=dev, + ), } for name, buf in bufs.items(): @@ -148,12 +163,16 @@ def __call__( timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, context_latents: torch.Tensor, + steering: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Run one decoder step through TensorRT. Accepts any dtype; inputs are copied into pre-allocated fp32 buffers. Returns a view of the internal output buffer (caller must not hold references across calls with different shapes). + + ``steering`` is optional ``[B, num_layers, hidden_size]``; when + omitted the per-layer adds are no-ops. """ orig_T = hidden_states.shape[1] pad = orig_T % 2 == 1 @@ -178,6 +197,10 @@ def __call__( bufs["context_latents"].copy_(context_latents) bufs["timestep"].copy_(timestep) bufs["encoder_hidden_states"].copy_(encoder_hidden_states) + if steering is not None: + bufs["steering"].copy_(steering) + else: + bufs["steering"].zero_() # Bind addresses ctx = self.context diff --git a/acestep/paths.py b/acestep/paths.py index 3d2806ca..bae3b842 100644 --- a/acestep/paths.py +++ b/acestep/paths.py @@ -164,6 +164,43 @@ def user_uploads_dir() -> Path: return models_dir() / "user_uploads" +# Per-checkpoint probe-bundle subpath; shared between HF fetch +# (``steering_vectors/``) and the local cache. XL is absent +# because the 2B vectors don't transfer to a different hidden_size / +# layer count. +_STEERING_VECTORS_BY_CHECKPOINT: dict[str, str] = { + "acestep-v15-turbo": "v15-turbo/shift3.5_n8_seed1528", +} + + +def steering_bundle_subpath( + checkpoint: str | Path | None = None, +) -> str | None: + """Bundle subpath registered for ``checkpoint`` (pure lookup).""" + name = _checkpoint_name(checkpoint) + return _STEERING_VECTORS_BY_CHECKPOINT.get(name) + + +def steering_vectors_dir() -> Path: + """Root directory for cached steering-vector bundles.""" + return models_dir() / "steering_vectors" + + +def steering_vector_dir( + checkpoint: str | Path | None = None, +) -> Path | None: + """Local cache directory for a checkpoint's steering bundle. + + Pure: returns where vectors WOULD live; ``None`` when no bundle + is registered. Callers needing the dir populated should go through + ``acestep.steering.hub.ensure_steering_vectors`` first. + """ + subpath = steering_bundle_subpath(checkpoint) + if subpath is None: + return None + return steering_vectors_dir() / subpath + + def discover_loras(directory: Path | None = None) -> list[Path]: """List ``*.safetensors`` files recursively under ``directory`` (default: ``loras_dir()``). @@ -260,17 +297,17 @@ def trt_engine_path(engine_name: str) -> Path: # that fits the audio (see `select_trt_engines` and `available_trt_engines`). _TRT_ENGINE_PROFILES: dict[float, dict[str, str]] = { 60.0: { - "decoder": "decoder_mixed_refit_b8_60s", + "decoder": "spectral_decoder_mixed_refit_b8_60s", "vae_encode": "vae_encode_fp16_60s", "vae_decode": "vae_decode_fp16_60s", }, 120.0: { - "decoder": "decoder_mixed_refit_b8_120s", + "decoder": "spectral_decoder_mixed_refit_b8_120s", "vae_encode": "vae_encode_fp16_120s", "vae_decode": "vae_decode_fp16_120s", }, 240.0: { - "decoder": "decoder_mixed_refit_b8_240s", + "decoder": "spectral_decoder_mixed_refit_b8_240s", "vae_encode": "vae_encode_fp16_240s", "vae_decode": "vae_decode_fp16_240s", }, @@ -356,7 +393,7 @@ def trt_engine_profiles( def default_trt_engines( - decoder: str = "decoder_mixed_refit_b8_60s", + decoder: str = "spectral_decoder_mixed_refit_b8_60s", vae_encode: str = "vae_encode_fp16_60s", vae_decode: str = "vae_decode_fp16_60s", ) -> dict[str, str]: diff --git a/acestep/steering/__init__.py b/acestep/steering/__init__.py new file mode 100644 index 00000000..f817d1a1 --- /dev/null +++ b/acestep/steering/__init__.py @@ -0,0 +1,57 @@ +"""Activation-steering engine surface. + +Wire knob names: + - ``steer_bright`` / ``steer_warm`` / ``steer_rough`` / ``steer_density``: + the four verified auto-path axes. + - ``man_src_`` / ``man_layer_`` / ``man_step_`` / ``man_alpha_``: + LIFO-numbered manual slot quadruples. + +Auto-axis alpha is sign-corrected at config-build time; the on-disk +cache stays raw so the manual path (direction-agnostic by design) sees +the literal probe direction. +""" + +from .catalog import enumerate_catalog, load_auto_vectors, load_vector +from .controller import SteeringController +from .hub import ensure_steering_vectors, upload_steering_vectors +from .policy import ( + AUTO_AXES, + MANUAL_CATALOG_AXES, + MANUAL_MAX_LAYER, + MANUAL_MAX_STEP, + MANUAL_SLOT_CAP, + MANUAL_SLOT_DEFAULT_COUNT, + PROBE_N, + fractional_inject_step, +) +from .types import ( + AutoAxis, + CapacityError, + CatalogEntry, + EmptyError, + KnobNames, + SteeringError, +) + +__all__ = [ + "AUTO_AXES", + "AutoAxis", + "CapacityError", + "CatalogEntry", + "EmptyError", + "KnobNames", + "MANUAL_CATALOG_AXES", + "MANUAL_MAX_LAYER", + "MANUAL_MAX_STEP", + "MANUAL_SLOT_CAP", + "MANUAL_SLOT_DEFAULT_COUNT", + "PROBE_N", + "SteeringController", + "SteeringError", + "ensure_steering_vectors", + "enumerate_catalog", + "fractional_inject_step", + "load_auto_vectors", + "load_vector", + "upload_steering_vectors", +] diff --git a/acestep/steering/catalog.py b/acestep/steering/catalog.py new file mode 100644 index 00000000..dde23eeb --- /dev/null +++ b/acestep/steering/catalog.py @@ -0,0 +1,109 @@ +"""Vector catalog: walk the on-disk probe dir into stable indexed entries. + +Split into a torch-free enumerator (filename metadata) and a +torch-paying loader. Stems look like ``brightness_l09_t3.pt`` — the +zero-padded layer lets alphabetical sort yield the intended +(l03..l18) x (t0, t3, ...) order naturally. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +from .policy import MANUAL_CATALOG_AXES +from .types import AutoAxis, CatalogEntry + +if TYPE_CHECKING: + import torch + + +def _parse_stem(stem: str, axis: str) -> tuple[int, int] | None: + """Extract (layer, step) from ``brightness_l09_t3`` style stems. + + Returns ``None`` if the stem doesn't parse against ``axis``. + """ + prefix = f"{axis}_l" + if not stem.startswith(prefix): + return None + tail = stem[len(axis) + 1:] # "l09_t3" + try: + layer = int(tail.split("_")[0][1:]) + step = int(tail.split("_")[1][1:]) + except (IndexError, ValueError): + return None + return layer, step + + +def enumerate_catalog(vector_dir: Path) -> tuple[CatalogEntry, ...]: + """List every (axis, build_layer, build_step) cell on disk. + + Torch-free; reads filenames only. Empty tuple when the dir is + missing. + """ + if not vector_dir.exists(): + return () + per_axis: dict[str, list[Path]] = {a: [] for a in MANUAL_CATALOG_AXES} + for path in sorted(vector_dir.glob("*.pt")): + for axis in MANUAL_CATALOG_AXES: + if path.stem.startswith(f"{axis}_l"): + per_axis[axis].append(path) + break + out: list[CatalogEntry] = [] + for axis in MANUAL_CATALOG_AXES: + for path in per_axis[axis]: + parsed = _parse_stem(path.stem, axis) + if parsed is None: + continue + layer, step = parsed + out.append(CatalogEntry( + index=len(out), + axis=axis, + build_layer=layer, + build_step=step, + filename=path.name, + )) + return tuple(out) + + +def load_vector( + vector_dir: Path, + entry: CatalogEntry, +) -> tuple["torch.Tensor", float]: + """Load one catalog entry's (vector, magnitude) as CPU float32.""" + import torch as _t + + path = vector_dir / entry.filename + blob = _t.load(path, map_location="cpu", weights_only=False) + return blob["vector"].to(_t.float32), float(blob["magnitude"]) + + +def load_auto_vectors( + vector_dir: Path, + auto_axes: tuple[AutoAxis, ...], +) -> dict[str, dict]: + """Load the per-axis vector for each verified auto-path axis. + + Returns ``{axis.name: {"layer": int, "vector": Tensor, + "magnitude": float}}``. Missing files are silently skipped — the + knob for that axis just never produces a config. + """ + import torch as _t + + out: dict[str, dict] = {} + if not vector_dir.exists(): + return out + for ax in auto_axes: + path = vector_dir / f"{ax.axis}_l{ax.probe_layer:02d}_t{ax.probe_step}.pt" + if not path.exists(): + continue + try: + blob = _t.load(path, map_location="cpu", weights_only=False) + except Exception: + continue + out[ax.name] = { + "layer": ax.probe_layer, + "vector": blob["vector"].to(_t.float32), + "magnitude": float(blob["magnitude"]), + } + return out diff --git a/acestep/steering/controller.py b/acestep/steering/controller.py new file mode 100644 index 00000000..86e537c3 --- /dev/null +++ b/acestep/steering/controller.py @@ -0,0 +1,238 @@ +"""SteeringController: per-session entry point for activation steering. + +Owns the auto-axis vector cache, the manual catalog, the LIFO slot +registry, and the ``raw`` knob dict → ``set_steering`` config-list +translation. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Mapping + +from .catalog import enumerate_catalog, load_auto_vectors, load_vector +from .policy import ( + AUTO_AXES, + MANUAL_MAX_LAYER, + MANUAL_MAX_STEP, + MANUAL_SLOT_CAP, + MANUAL_SLOT_DEFAULT_COUNT, + PROBE_N, + fractional_inject_step, + parse_probe_n, +) +from .types import AutoAxis, CapacityError, CatalogEntry, EmptyError, KnobNames + + +# Wire knob name prefixes — agreed on by pipeline raw dict, demo +# KnobDef registry, and the MCP catalog. +_MAN_SRC = "man_src_" +_MAN_LAYER = "man_layer_" +_MAN_STEP = "man_step_" +_MAN_ALPHA = "man_alpha_" + + +class SteeringController: + """Engine-side steering state for one streaming session. + + Construction loads the catalog + auto-axis tensors. Degrades to + ``is_loaded == False`` (empty catalog, ``build_configs`` returns + ``[]``) when ``vector_dir`` is None or missing on disk. + + Slots are LIFO: ``add_slot`` allocates ``slot_count + 1`` so the + active set is always ``{1..count}`` with no holes. Interior + deletion is not supported, which keeps catalog indices stable + across edits. + """ + + auto_axes: tuple[AutoAxis, ...] + catalog: tuple[CatalogEntry, ...] + slot_cap: int = MANUAL_SLOT_CAP + MANUAL_MAX_LAYER: int = MANUAL_MAX_LAYER + MANUAL_MAX_STEP: int = MANUAL_MAX_STEP + + def __init__( + self, + vector_dir: Path | None, + *, + default_slot_count: int = MANUAL_SLOT_DEFAULT_COUNT, + ) -> None: + self._vector_dir = vector_dir + self.auto_axes = AUTO_AXES + # Probe schedule N comes from the bundle dir name + # (``shift3.5_n8_seed1528`` → 8); falls back to PROBE_N when the + # token is absent or no dir was passed. + self._probe_n = parse_probe_n(vector_dir.name) if vector_dir else PROBE_N + if vector_dir is None: + self.catalog = () + self._auto_vectors = {} + else: + self.catalog = enumerate_catalog(vector_dir) + self._auto_vectors = load_auto_vectors(vector_dir, AUTO_AXES) + # Lazy manual-vector cache keyed by catalog index. ``None`` is + # the cached load-failure sentinel so a corrupted .pt logs once. + self._manual_vectors: dict[int, dict | None] = {} + if not self.is_loaded: + self._slot_count = 0 + else: + self._slot_count = max(0, min(int(default_slot_count), self.slot_cap)) + + @property + def is_loaded(self) -> bool: + """True if at least one vector (auto or manual) is reachable.""" + return bool(self._auto_vectors) or bool(self.catalog) + + @property + def slot_count(self) -> int: + return self._slot_count + + def active_slots(self) -> tuple[int, ...]: + """Sorted tuple of currently allocated slot ids (1..slot_count).""" + return tuple(range(1, self._slot_count + 1)) + + def add_slot(self) -> int: + """Allocate the next LIFO slot id. Raises CapacityError at cap. + + Refused when no vectors are loaded: a slot with no catalog + behind it is just a dead knob quadruple. + """ + if not self.is_loaded: + raise CapacityError("manual steering unavailable (no vectors loaded)") + if self._slot_count >= self.slot_cap: + raise CapacityError( + f"manual steering at cap ({self.slot_cap})", + ) + self._slot_count += 1 + return self._slot_count + + def pop_slot(self) -> int: + """Pop the highest-numbered slot. Raises EmptyError when empty.""" + if self._slot_count <= 0: + raise EmptyError("no manual steering slots to remove") + popped = self._slot_count + self._slot_count -= 1 + return popped + + @staticmethod + def knob_names(slot_id: int) -> KnobNames: + """The four wire-protocol knob names for slot ``slot_id``.""" + return KnobNames( + src=f"{_MAN_SRC}{slot_id}", + layer=f"{_MAN_LAYER}{slot_id}", + step=f"{_MAN_STEP}{slot_id}", + alpha=f"{_MAN_ALPHA}{slot_id}", + ) + + def snapshot_key( + self, + raw: Mapping[str, float], + n: int, + ) -> tuple: + """Build the change-detection key for ``raw`` at schedule ``n``. + + Returns a tuple containing every input that ``build_configs`` + consults: the auto-axis alphas in declaration order, then per + active slot the (src, layer, step, alpha) quadruple, then the + schedule step count. Demo callers cache this and skip + ``set_steering`` when it's unchanged. + + Length is dynamic in ``slot_count``: an add or pop changes the + tuple length so equality fails and the next ``build_configs`` + actually fires. + """ + out: list[float] = [float(raw.get(ax.name, 0.0)) for ax in self.auto_axes] + for slot in self.active_slots(): + names = self.knob_names(slot) + out.append(float(raw.get(names.src, 0.0))) + out.append(float(raw.get(names.layer, 0.0))) + out.append(float(raw.get(names.step, 0.0))) + out.append(float(raw.get(names.alpha, 0.0))) + out.append(float(max(1, int(n)))) + return tuple(out) + + def _get_manual_vector(self, src_idx: int) -> dict | None: + """Lazy-load one catalog entry's ``{vector, magnitude}``. + + Caches successes and failures (failure as None) so a corrupted + .pt logs once instead of every tick. Returns None for out-of- + range indices or load failures. + """ + if src_idx < 0 or src_idx >= len(self.catalog): + return None + cached = self._manual_vectors.get(src_idx, _MISSING) + if cached is not _MISSING: + return cached + entry = self.catalog[src_idx] + try: + vec, mag = load_vector(self._vector_dir, entry) # type: ignore[arg-type] + blob: dict | None = {"vector": vec, "magnitude": mag} + except Exception as exc: + from loguru import logger + logger.warning( + "steering: failed to load manual vector {} (idx {}): {}", + entry.filename, src_idx, exc, + ) + blob = None + self._manual_vectors[src_idx] = blob + return blob + + def build_configs( + self, + raw: Mapping[str, float], + n: int, + ) -> list[dict]: + """Translate live ``raw`` knob state into ``set_steering`` configs. + + Auto axes get the fractional step mapping, per-axis layer + offset, and sign correction. Manual slots are verbatim: + (layer, step) lands as picked, alpha is consumed sign-as-given. + Zero-alpha entries are dropped. + """ + if not self.is_loaded: + return [] + n = max(1, int(n)) + configs: list[dict] = [] + for ax in self.auto_axes: + alpha = float(raw.get(ax.name, 0.0)) + if alpha == 0.0: + continue + blob = self._auto_vectors.get(ax.name) + if blob is None: + continue + inject_step = fractional_inject_step( + ax.probe_step, n, probe_n=self._probe_n, + ) + inject_layer = max( + 0, min(MANUAL_MAX_LAYER, ax.probe_layer + ax.layer_offset), + ) + configs.append({ + "layer": inject_layer, + "step": inject_step, + "vector": blob["vector"], + "magnitude": blob["magnitude"], + "alpha": alpha * ax.sign, + }) + for slot in self.active_slots(): + names = self.knob_names(slot) + alpha = float(raw.get(names.alpha, 0.0)) + if alpha == 0.0: + continue + src_idx = int(round(float(raw.get(names.src, 0.0)))) + blob = self._get_manual_vector(src_idx) + if blob is None: + continue + inject_layer = max( + 0, min(MANUAL_MAX_LAYER, int(round(float(raw.get(names.layer, 0.0))))), + ) + inject_step = int(round(float(raw.get(names.step, 0.0)))) + configs.append({ + "layer": inject_layer, + "step": inject_step, + "vector": blob["vector"], + "magnitude": blob["magnitude"], + "alpha": alpha, + }) + return configs + + +_MISSING = object() diff --git a/acestep/steering/hub.py b/acestep/steering/hub.py new file mode 100644 index 00000000..3ed3a783 --- /dev/null +++ b/acestep/steering/hub.py @@ -0,0 +1,155 @@ +"""HF source for activation-steering vector bundles. + +Mirrors :mod:`acestep.engine.trt.onnx_hub`. Local layout matches +:func:`acestep.paths.steering_vector_dir` so a fresh fetch and a +warm cache resolve to the same path. +""" + +from __future__ import annotations + +import shutil +from pathlib import Path +from typing import Optional + +from loguru import logger + +from acestep.engine.trt.onnx_hub import DEMON_ONNX_REPO +from acestep.paths import steering_bundle_subpath, steering_vectors_dir + + +# Steering bundles share the demon-onnx repo today; alias kept so a +# future split to ``daydreamlive/demon-steering`` is a one-line change. +DEMON_STEERING_REPO = DEMON_ONNX_REPO +_HF_PREFIX = "steering_vectors" + + +def hf_bundle_dir(subpath: str) -> str: + """The HF in-repo path for a bundle subpath (no trailing slash).""" + return f"{_HF_PREFIX}/{subpath}" + + +def ensure_steering_vectors( + checkpoint: Optional[str] = None, + *, + force_download: bool = False, +) -> Optional[Path]: + """Ensure the checkpoint's bundle is on disk; return its cache dir. + + Returns ``None`` when no bundle is registered for the checkpoint + (XL today). HF errors are best-effort — they log a warning and + return the (possibly empty) cache dir so the session keeps booting. + """ + subpath = steering_bundle_subpath(checkpoint) + if subpath is None: + return None + + target_dir = steering_vectors_dir() / subpath + # Cache-hit heuristic: directory has at least one .pt file. We + # don't compare against an HF manifest because the bundle is + # small and the operator can force_download=True to refresh. + if not force_download and target_dir.is_dir(): + if any(target_dir.glob("*.pt")): + return target_dir + + target_dir.mkdir(parents=True, exist_ok=True) + hf_dir = hf_bundle_dir(subpath) + + try: + from huggingface_hub import snapshot_download + except ImportError as exc: + logger.warning( + "steering: huggingface_hub not available ({}); " + "steering knobs will no-op until vectors are installed at {}.", + exc, target_dir, + ) + return target_dir + + logger.info( + "Fetching steering bundle {!r} from HF: {}/{}", + subpath, DEMON_STEERING_REPO, hf_dir, + ) + try: + snap_dir = Path(snapshot_download( + repo_id=DEMON_STEERING_REPO, + allow_patterns=[f"{hf_dir}/*"], + force_download=force_download, + )) + except Exception as exc: + logger.warning( + "steering: HF fetch failed ({}); steering knobs will no-op. " + "Cache dir: {}", + exc, target_dir, + ) + return target_dir + + src_dir = snap_dir / hf_dir + if not src_dir.is_dir(): + logger.warning( + "steering: HF bundle dir {} not present in snapshot at {}; " + "steering knobs will no-op.", + hf_dir, snap_dir, + ) + return target_dir + + copied = 0 + for f in src_dir.iterdir(): + if f.is_file(): + shutil.copy2(f, target_dir / f.name) + copied += 1 + logger.info( + "Steering bundle ready at {} ({} files)", + target_dir, copied, + ) + return target_dir + + +def upload_steering_vectors( + checkpoint: str, + *, + local_dir: Optional[Path] = None, + repo: str = DEMON_STEERING_REPO, + commit_message: Optional[str] = None, + dry_run: bool = False, +) -> None: + """Upload a local steering bundle to HF (operator-facing).""" + subpath = steering_bundle_subpath(checkpoint) + if subpath is None: + raise ValueError( + f"No steering bundle registered for checkpoint {checkpoint!r}. " + f"Add it to _STEERING_VECTORS_BY_CHECKPOINT in acestep/paths.py." + ) + + src = Path(local_dir) if local_dir is not None else steering_vectors_dir() / subpath + if not src.is_dir(): + raise FileNotFoundError(f"Local steering dir not found: {src}") + + files = sorted(p for p in src.iterdir() if p.is_file()) + if not files: + raise FileNotFoundError(f"Local steering dir is empty: {src}") + + total_mb = sum(p.stat().st_size for p in files) / (1 << 20) + hf_dir = hf_bundle_dir(subpath) + logger.info( + "Upload plan: {} -> {}/{} ({} files, {:.1f} MB)", + subpath, repo, hf_dir, len(files), total_mb, + ) + for f in files[:10]: + logger.info(" {}", f.name) + if len(files) > 10: + logger.info(" ... {} more", len(files) - 10) + + if dry_run: + logger.info("--dry-run: not uploading") + return + + from huggingface_hub import HfApi + msg = commit_message or ( + f"Upload steering bundle {subpath} ({len(files)} files, {total_mb:.0f} MB)" + ) + HfApi().upload_folder( + repo_id=repo, + folder_path=str(src), + path_in_repo=hf_dir, + commit_message=msg, + ) + logger.info("Uploaded steering bundle {} to {}/{}", subpath, repo, hf_dir) diff --git a/acestep/steering/policy.py b/acestep/steering/policy.py new file mode 100644 index 00000000..70be58a5 --- /dev/null +++ b/acestep/steering/policy.py @@ -0,0 +1,145 @@ +"""Activation-steering policy: which axes are exposed, where they +inject, and how schedules translate. + +Pure tables + one pure function. No torch, no disk I/O. The +SteeringController consumes this module; consumers that only need to +*describe* the surface (MCP, UI catalog builders) can import here +without torch. +""" + +from __future__ import annotations + +import re + +from .types import AutoAxis + + +# Default probe schedule. Used when the bundle subpath doesn't encode +# one (older bundles) or for callers that don't have a subpath handy. +PROBE_N: int = 8 + + +_PROBE_N_RE = re.compile(r"_n(\d+)(?:_|$)") + + +def parse_probe_n(subpath: str) -> int: + """Pull the probe schedule N out of a bundle subpath. + + Subpaths look like ``v15-turbo/shift3.5_n8_seed1528``; the ``_n_`` + token names the schedule. Falls back to ``PROBE_N`` when absent. + """ + m = _PROBE_N_RE.search(subpath) + return int(m.group(1)) if m else PROBE_N + + +# v1.5 turbo decoder has 24 DiT blocks → legal layers 0..23. +MANUAL_MAX_LAYER: int = 23 + +# Matches the demo's ``steps_override`` MIDI cap of 16 (0..15). +# Values past the live ``steps_override - 1`` silently no-op. +MANUAL_MAX_STEP: int = 15 + +MANUAL_SLOT_DEFAULT_COUNT: int = 1 + +# Soft cap on registered slots; sized so the UI doesn't collapse if +# someone stress-tests the surface (engine cost per slot is negligible). +MANUAL_SLOT_CAP: int = 16 + + +# Manual catalog index order: axis-major per this list, then +# build_layer asc, then build_step asc. Pins each (axis, layer, step) +# cell to a stable index across sessions. +MANUAL_CATALOG_AXES: tuple[str, ...] = ( + "brightness", + "warmth", + "roughness", + "density", + "attack", + "tonality", + "punch", + "bass_emphasis", +) + + +# Auto-path axes — only the four whose prompt-to-metric premise +# verified in PROMPT_BASELINE.md. The broken-premise four (attack, +# tonality, punch, bass_emphasis) stay reachable via the manual +# catalog only. +# +# warmth: sign=-1 flips the raw vector so positive alpha tilts warmer +# (raw probe direction is reversed for this axis). +# density: layer_offset=-3 from PHASE3_ANALYSIS.md (28/30 transfer +# pairs preferred 3 layers shallower than the probe). +AUTO_AXES: tuple[AutoAxis, ...] = ( + AutoAxis( + name="steer_bright", + axis="brightness", + probe_layer=9, + probe_step=3, + sign=1.0, + layer_offset=0, + blurb=( + "positive alpha shifts spectral centroid up " + "(brighter, more highs)" + ), + ), + AutoAxis( + name="steer_warm", + axis="warmth", + probe_layer=15, + probe_step=0, + sign=-1.0, + layer_offset=0, + blurb=( + "positive alpha tilts the spectrum toward bass (warmer); " + "vector is sign-corrected from the raw probe direction" + ), + ), + AutoAxis( + name="steer_rough", + axis="roughness", + probe_layer=9, + probe_step=3, + sign=1.0, + layer_offset=0, + blurb=( + "positive alpha increases spectral flatness " + "(grittier, noisier); magnitude is small at this probe cell" + ), + ), + AutoAxis( + name="steer_density", + axis="density", + probe_layer=18, + probe_step=3, + sign=1.0, + layer_offset=-3, + blurb=( + "positive alpha thins the texture toward sparse/minimal" + ), + ), +) + + +def fractional_inject_step( + probe_step: int, + inject_n: int, + *, + probe_n: int = PROBE_N, +) -> int: + """Translate a probe-schedule step into the inject schedule. + + Returns ``round(probe_step / probe_n * inject_n)`` clamped to + ``[0, inject_n - 1]``. Identity at same schedule + (``inject_n == probe_n`` returns ``probe_step`` exactly). + """ + if inject_n < 1: + inject_n = 1 + if probe_n < 1: + probe_n = 1 + s = int(round(probe_step / probe_n * inject_n)) + if s < 0: + return 0 + if s >= inject_n: + return inject_n - 1 + return s diff --git a/acestep/steering/types.py b/acestep/steering/types.py new file mode 100644 index 00000000..13cd63aa --- /dev/null +++ b/acestep/steering/types.py @@ -0,0 +1,61 @@ +"""Frozen dataclasses + exceptions for the activation-steering surface. + +Pure data; no torch / disk I/O so torch-free consumers (MCP catalog, +UI metadata) can import without paying for the controller's torch +dependency. +""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class AutoAxis: + """One verified auto-path steering axis. + + ``sign`` is folded into alpha at config-build time so the on-disk + cache stays raw (the manual path reads the same cache and is + sign-agnostic by design). + """ + + name: str # knob name, e.g. "steer_bright" + axis: str # disk-side axis tag, e.g. "brightness" + probe_layer: int + probe_step: int + sign: float # -1.0 flips the raw vector direction; 1.0 leaves it + layer_offset: int # added to probe_layer at inject time + blurb: str # operator-facing effect description + + +@dataclass(frozen=True) +class CatalogEntry: + """One pre-built (axis, build_layer, build_step) cell from disk.""" + + index: int + axis: str + build_layer: int + build_step: int + filename: str + + +@dataclass(frozen=True) +class KnobNames: + """Wire knob names for one manual steering slot.""" + + src: str # man_src_ + layer: str # man_layer_ + step: str # man_step_ + alpha: str # man_alpha_ + + +class SteeringError(Exception): + """Base class for slot-registry errors raised by SteeringController.""" + + +class CapacityError(SteeringError): + """``add_slot`` called when the registry is already at ``slot_cap``.""" + + +class EmptyError(SteeringError): + """``pop_slot`` called when the registry has no allocated slots.""" From 094bfbb7d760a5c11ec257a319fd283370435cd2 Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Thu, 28 May 2026 09:44:21 -0400 Subject: [PATCH 2/3] =?UTF-8?q?feat(rtmg):=20rename=20"Channels"=20tab=20?= =?UTF-8?q?=E2=86=92=20"Experimental"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Also relabels the two sections inside as "channel highlights" and "channel groups" so the section ↔ tab vocabulary stays honest. --- .../web/components/Performance/AdvancedDrawer.tsx | 2 +- .../web/components/Performance/DrawerTabs.tsx | 2 +- .../web/components/Performance/MobileFullSheet.tsx | 2 +- .../web/components/Performance/VoiceTile.tsx | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/demos/realtime_motion_graph_web/web/components/Performance/AdvancedDrawer.tsx b/demos/realtime_motion_graph_web/web/components/Performance/AdvancedDrawer.tsx index 7fc0d068..4ffcdbde 100644 --- a/demos/realtime_motion_graph_web/web/components/Performance/AdvancedDrawer.tsx +++ b/demos/realtime_motion_graph_web/web/components/Performance/AdvancedDrawer.tsx @@ -265,7 +265,7 @@ const SPREAD_SECTIONS: Array<{ id: DrawerTab; label: string }> = [ { id: "core", label: "Core" }, { id: "styles", label: "Styles" }, { id: "mod", label: "Mod" }, - { id: "voice", label: "Channels" }, + { id: "voice", label: "Experimental" }, { id: "config", label: "Config" }, ]; diff --git a/demos/realtime_motion_graph_web/web/components/Performance/DrawerTabs.tsx b/demos/realtime_motion_graph_web/web/components/Performance/DrawerTabs.tsx index fe6a99bc..2bebeedd 100644 --- a/demos/realtime_motion_graph_web/web/components/Performance/DrawerTabs.tsx +++ b/demos/realtime_motion_graph_web/web/components/Performance/DrawerTabs.tsx @@ -27,7 +27,7 @@ export type DrawerTab = (typeof DRAWER_TABS)[number]; const TAB_LABELS: Record = { core: "Core", mod: "Mod", - voice: "Channels", + voice: "Experimental", styles: "Styles", // Auto-generated control surface, rendered straight from the backend // /api/knobs manifest. Reference template for a re-skinned UI. diff --git a/demos/realtime_motion_graph_web/web/components/Performance/MobileFullSheet.tsx b/demos/realtime_motion_graph_web/web/components/Performance/MobileFullSheet.tsx index 084fc5ba..acc05387 100644 --- a/demos/realtime_motion_graph_web/web/components/Performance/MobileFullSheet.tsx +++ b/demos/realtime_motion_graph_web/web/components/Performance/MobileFullSheet.tsx @@ -29,7 +29,7 @@ const TABS: { id: Tab; label: string }[] = [ { id: "core", label: "Core" }, { id: "styles", label: "Styles" }, { id: "mod", label: "Mod" }, - { id: "voice", label: "Channels" }, + { id: "voice", label: "Experimental" }, { id: "saved", label: "Saved" }, { id: "config", label: "Config" }, ]; diff --git a/demos/realtime_motion_graph_web/web/components/Performance/VoiceTile.tsx b/demos/realtime_motion_graph_web/web/components/Performance/VoiceTile.tsx index 7beeb478..d7b3a8cc 100644 --- a/demos/realtime_motion_graph_web/web/components/Performance/VoiceTile.tsx +++ b/demos/realtime_motion_graph_web/web/components/Performance/VoiceTile.tsx @@ -71,7 +71,7 @@ export function VoiceTile() { )}
-
Highlights
+
channel highlights
{MORPH.map((p) => { const r = ranges[p]; @@ -92,7 +92,7 @@ export function VoiceTile() {