diff --git a/acestep/engine/diffusion.py b/acestep/engine/diffusion.py index b3e32b49..02c8d0f8 100644 --- a/acestep/engine/diffusion.py +++ b/acestep/engine/diffusion.py @@ -198,19 +198,49 @@ def load_trt_engine(self, engine_path): if hasattr(trt, "bfloat16"): _trt_dtype_map[trt.bfloat16] = torch.bfloat16 - input_names = ("hidden_states", "timestep", "encoder_hidden_states", "context_latents") + engine_input_names = { + self._trt_engine.get_tensor_name(i) + for i in range(self._trt_engine.num_io_tensors) + } + # ``steering`` is present on spectral-prefixed engines only; + # older builds get _steering_num_layers == 0 and the streaming + # pipeline skips the buffer entirely. + has_steering = "steering" in engine_input_names + base_inputs = ( + "hidden_states", "timestep", "encoder_hidden_states", + "context_latents", + ) + input_names = base_inputs + (("steering",) if has_steering else ()) self._trt_input_dtypes = { name: _trt_dtype_map.get(self._trt_engine.get_tensor_dtype(name), torch.float32) for name in input_names } + if has_steering: + steer_shape = tuple(self._trt_engine.get_tensor_shape("steering")) + if len(steer_shape) != 3: + raise RuntimeError(f"Unexpected steering input rank: {steer_shape}") + self._steering_num_layers = int(steer_shape[1]) + self._steering_hidden_size = int(steer_shape[2]) + else: + self._steering_num_layers = 0 + self._steering_hidden_size = 0 + logger.warning( + "TRT decoder engine has no 'steering' input — activation-" + "steering knobs will no-op on this session. Rebuild with " + "`python -m acestep.engine.trt.build --all --force-rebuild` " + "to enable." 
+ ) self._trt_output_dtype = _trt_dtype_map.get( self._trt_engine.get_tensor_dtype("velocity"), torch.float32 ) self._trt_io_dtype = self._trt_input_dtypes["hidden_states"] logger.info( - "TRT decoder engine ready (input_dtypes={}, output_dtype={})", + "TRT decoder engine ready (input_dtypes={}, output_dtype={}, " + "steering L={}, D={})", self._trt_input_dtypes, self._trt_output_dtype, + self._steering_num_layers, + self._steering_hidden_size, ) # Try to initialize LoRA refit manager (requires REFIT-enabled engine) @@ -526,11 +556,15 @@ def _trt_decoder_step( timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, context_latents: torch.Tensor, + steering: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Run one decoder step through TRT with pre-allocated buffers. Handles odd-T padding. Caches buffers by shape for reuse across steps. Calls execute_async_v3 on the shared polygraphy stream. + + ``steering`` is optional ``[B, num_layers, hidden_size]``; + when omitted the cached zero buffer makes the per-layer adds a no-op. 
""" orig_T = hidden_states.shape[1] pad = orig_T % 2 == 1 @@ -548,6 +582,7 @@ def _trt_decoder_step( dev = hidden_states.device hs_shape, ts_shape, enc_shape, cl_shape = key in_dtypes = self._trt_input_dtypes + B = hs_shape[0] bufs = { "hidden_states": torch.empty(hs_shape, dtype=in_dtypes["hidden_states"], device=dev), @@ -555,6 +590,11 @@ def _trt_decoder_step( "encoder_hidden_states": torch.empty(enc_shape, dtype=in_dtypes["encoder_hidden_states"], device=dev), "context_latents": torch.empty(cl_shape, dtype=in_dtypes["context_latents"], device=dev), } + if self._steering_num_layers > 0: + bufs["steering"] = torch.zeros( + B, self._steering_num_layers, self._steering_hidden_size, + dtype=in_dtypes["steering"], device=dev, + ) for name, buf in bufs.items(): if not ctx.set_input_shape(name, tuple(buf.shape)): raise RuntimeError(f"TRT decoder rejected input shape for {name}: {tuple(buf.shape)}") @@ -589,6 +629,11 @@ def _trt_decoder_step( bufs["context_latents"].copy_(context_latents) bufs["timestep"].copy_(timestep) bufs["encoder_hidden_states"].copy_(encoder_hidden_states) + if "steering" in bufs: + if steering is not None: + bufs["steering"].copy_(steering) + else: + bufs["steering"].zero_() ctx = self._trt_ctx for name, buf in bufs.items(): diff --git a/acestep/engine/stream.py b/acestep/engine/stream.py index a4672494..839b45c2 100644 --- a/acestep/engine/stream.py +++ b/acestep/engine/stream.py @@ -15,7 +15,7 @@ import time from collections import OrderedDict from dataclasses import dataclass, field -from typing import Callable, Optional, List, Tuple, TYPE_CHECKING +from typing import Callable, Dict, NamedTuple, Optional, List, Tuple, TYPE_CHECKING from loguru import logger import torch @@ -29,6 +29,13 @@ from .masking import LatentNoiseMask +class _SteeringApply(NamedTuple): + """One pre-resolved activation-steering shift bound to a layer.""" + vector: torch.Tensor # 1-D [hidden_dim] + scale: float # alpha * magnitude + step: int # gate: only rows at this 
denoise step receive it + + @dataclass class SlotCondition: """One conditioning entry for multi-condition per-frame blending. @@ -285,6 +292,10 @@ def __init__( self._trt_io_dtype = getattr(engine, '_trt_io_dtype', torch.float32) self._trt_input_dtypes = getattr(engine, "_trt_input_dtypes", {}) or {} self._trt_output_dtype = getattr(engine, "_trt_output_dtype", self._trt_io_dtype) + # Steering shape (constants per engine); snapshotted so the + # per-tick buffer fill doesn't re-query TRT. + self._steering_num_layers = getattr(engine, "_steering_num_layers", 0) + self._steering_hidden_size = getattr(engine, "_steering_hidden_size", 0) # Currently-bound TRT I/O buffers (set by _ensure_trt_bufs to one # entry of _trt_bufs_cache). _trt_forward reads these directly. @@ -325,6 +336,15 @@ def __init__( # ``.to(...)`` is a no-op. self._channel_gain: Optional[torch.Tensor] = None + # Activation steering. Per-DiT-layer additive shift on the + # post-block residual, gated per-row by denoise step. + # ``_current_step_per_row`` is populated by _tick_complex_pt + # around each forward; empty means "skip injection" so a + # forward issued outside the rendezvous can't fire steering. + self._steering_by_layer: Dict[int, List[_SteeringApply]] = {} + self._steering_hooks_installed: bool = False + self._current_step_per_row: List[int] = [] + # Sentinel tensors for the "always-on multiply" idiom in the step # helpers. Built lazily once the first slot's device/dtype is known. # ``_ones_3d`` stands in for absent ``velocity_scale`` (vt * 1 = vt). 
@@ -402,9 +422,14 @@ def _on_engine_swapped(self) -> None: self._trt_output_dtype = getattr( engine, "_trt_output_dtype", self._trt_io_dtype ) + self._steering_num_layers = getattr(engine, "_steering_num_layers", 0) + self._steering_hidden_size = getattr(engine, "_steering_hidden_size", 0) self._trt_bufs = None self._trt_out_buf = None self._trt_bufs_cache.clear() + # Hooks live on the old decoder.layers; the new decoder needs a + # fresh install on the next non-empty set_steering call. + self._steering_hooks_installed = False def submit(self, request: SlotRequest) -> None: """Enqueue a generation request. @@ -712,6 +737,10 @@ def _trt_forward( else: bufs["context_latents"].copy_(ctx_io) + # Steering: absent on non-spectral engines. + if "steering" in bufs: + self._fill_trt_steering_buffer(bufs["steering"], B) + # Rebind and execute. ctx = self._trt_ctx for name, buf in bufs.items(): @@ -731,6 +760,39 @@ def _trt_forward( return out[:, :T, :].to(self._dtype) return out.to(self._dtype) + def _steering_row_mask(self, target_step: int, B: int) -> Optional[List[int]]: + """Rows whose slot is at ``target_step``, or None to skip. + + Returns None when no per-row step mapping exists or it + disagrees with ``B`` — the eager hook and TRT buffer fill both + skip injection in that case rather than firing blindly. + """ + row_steps = self._current_step_per_row + if not row_steps or len(row_steps) != B: + return None + mask = [i for i, s in enumerate(row_steps) if s == target_step] + return mask or None + + def _fill_trt_steering_buffer(self, buf: torch.Tensor, B: int) -> None: + """Populate the TRT steering buffer for one forward. + + ``buf`` is ``[B, num_layers, hidden_size]``; zeroed first so + previous-tick content doesn't leak. Rows with no matching shift + stay zero, which the engine adds as a no-op per layer. 
+ """ + buf.zero_() + if not self._steering_by_layer: + return + for layer_idx, applies in self._steering_by_layer.items(): + if layer_idx < 0 or layer_idx >= self._steering_num_layers: + continue + for apply in applies: + mask_rows = self._steering_row_mask(apply.step, B) + if mask_rows is None: + continue + v = apply.vector.to(device=buf.device, dtype=buf.dtype) + buf[mask_rows, layer_idx, :] += apply.scale * v + # ------------------------------------------------------------------ # TRT buffer management # ------------------------------------------------------------------ @@ -787,6 +849,15 @@ def _ensure_trt_bufs(self, B: int, T: int, max_L: int): device=device, ), } + if self._steering_num_layers > 0: + # Zeroed so a tick with no active configs is a true no-op + # (engine still adds zeros per layer); repopulated each + # forward by _trt_forward. + bufs["steering"] = torch.zeros( + B, self._steering_num_layers, self._steering_hidden_size, + dtype=in_dtypes.get("steering", io_dtype), + device=device, + ) for name, buf in bufs.items(): if not ctx.set_input_shape(name, tuple(buf.shape)): @@ -1033,7 +1104,13 @@ def _forward_pairs( for c in pos_conds_per_slot[si]: pos_pair_si.append(si) pos_pair_cond.append(c) - vt_pos_all = _forward_pairs(pos_pair_si, pos_pair_cond) + # Per-row step mapping for the steering hook; try/finally so a + # stale list never leaks into a forward outside this rendezvous. + self._current_step_per_row = [slots[si].step_idx for si in pos_pair_si] + try: + vt_pos_all = _forward_pairs(pos_pair_si, pos_pair_cond) + finally: + self._current_step_per_row = [] # --- Negative pass (CFG only): skipped when no slot has CFG. 
--- neg_pair_si: List[int] = [] @@ -1042,9 +1119,14 @@ def _forward_pairs( for c in neg_conds_per_slot[si]: neg_pair_si.append(si) neg_pair_cond.append(c) - vt_neg_all = ( - _forward_pairs(neg_pair_si, neg_pair_cond) if neg_pair_si else None - ) + if neg_pair_si: + self._current_step_per_row = [slots[si].step_idx for si in neg_pair_si] + try: + vt_neg_all = _forward_pairs(neg_pair_si, neg_pair_cond) + finally: + self._current_step_per_row = [] + else: + vt_neg_all = None # --- Per-slot: blend pos, blend neg (if CFG), APG-combine --- vt_per_slot: List[torch.Tensor] = [None] * len(slots) # type: ignore[list-item] @@ -1425,6 +1507,79 @@ def set_channel_gain_tensor(self, gain: Optional[torch.Tensor]) -> None: dt = self._dtype or torch.float16 self._channel_gain = gain.to(device=dev, dtype=dt) + def set_steering(self, configs: List[Dict[str, object]]) -> None: + """Set activation-steering configs. + + Each config dict has: + - ``layer``: int, index into ``decoder.layers`` + - ``step``: int, the denoise step at which to fire + - ``vector``: 1-D unit-norm tensor [hidden_dim] + - ``magnitude``: float, paired mean-diff scale + - ``alpha``: float, knob value; effective shift is + ``alpha * magnitude * vector`` + + Multiple configs may share a layer (additions sum, modulo the + per-row step gate). Zero-alpha entries drop. Pass ``[]`` to + clear. Eager: forward hooks on ``decoder.layers`` (installed + lazily). TRT: per-tick buffer fill in ``_trt_forward``. + """ + by_layer: Dict[int, List[_SteeringApply]] = {} + for c in configs: + alpha = float(c.get("alpha", 0.0)) + if alpha == 0.0: + continue + mag = float(c.get("magnitude", 1.0)) + li = int(c["layer"]) + by_layer.setdefault(li, []).append(_SteeringApply( + vector=c["vector"], # type: ignore[arg-type] + scale=alpha * mag, + step=int(c["step"]), + )) + self._steering_by_layer = by_layer + # Eager-path hooks only matter when the PyTorch decoder runs. + # With TRT active the decoder may be a stub. 
+ if ( + by_layer + and not self._steering_hooks_installed + and self._trt_engine is None + ): + self._install_steering_hooks() + self._steering_hooks_installed = True + + def _install_steering_hooks(self) -> None: + """Attach one forward hook per DiT layer. + + Reads ``_steering_by_layer`` + ``_current_step_per_row`` at call + time. Idempotent via ``_steering_hooks_installed`` (cleared on + engine swap so a new eager decoder gets fresh hooks). + """ + layers = self.decoder.layers + + def make_hook(layer_idx: int): + def _hook(_module, _inputs, output): + applies = self._steering_by_layer.get(layer_idx) + if not applies: + return output + hs = output[0] if isinstance(output, tuple) else output + B = hs.shape[0] + changed = False + for apply in applies: + mask_rows = self._steering_row_mask(apply.step, B) + if mask_rows is None: + continue + v = apply.vector.to(device=hs.device, dtype=hs.dtype) + hs[mask_rows] = hs[mask_rows] + apply.scale * v.view(1, 1, -1) + changed = True + if not changed: + return output + if isinstance(output, tuple): + return (hs,) + output[1:] + return hs + return _hook + + for li in range(len(layers)): + layers[li].register_forward_hook(make_hook(li)) + def stats(self) -> dict: return { "ticks": self.ticks, diff --git a/acestep/engine/trt/_engine_metadata.py b/acestep/engine/trt/_engine_metadata.py index c49097be..e3571c90 100644 --- a/acestep/engine/trt/_engine_metadata.py +++ b/acestep/engine/trt/_engine_metadata.py @@ -37,7 +37,10 @@ from loguru import logger -_ENGINE_METADATA_SCHEMA = 1 +# Schema 2: decoder ONNX/engines carry the ``steering`` input (the +# ``spectral_decoder_*`` family); pre-steering engine metadata is +# invalidated so the next build run refreshes it. 
+_ENGINE_METADATA_SCHEMA = 2 def _sha256_file(path: str | os.PathLike[str]) -> str: diff --git a/acestep/engine/trt/build.py b/acestep/engine/trt/build.py index 57159265..969c8369 100644 --- a/acestep/engine/trt/build.py +++ b/acestep/engine/trt/build.py @@ -899,12 +899,25 @@ def _print_matrix(durations, build_vae, build_decoder, output_dir, batch_max, checkpoint="acestep-v15-turbo", build_dreamvae=False): """Print the build matrix for --all mode, showing existing vs new.""" variant = _checkpoint_to_variant(checkpoint) - vtag = f"_{variant}" if variant != "turbo" else "" from acestep.paths import ( WINDOWED_VAE_DECODE_NAME, WINDOWED_DREAMVAE_DECODE_NAME, ) + # Defer to TRTBuildConfig.engine_filename so this preview can't + # drift from what _build_decoder_engine writes. + from .export import TRTBuildConfig + + def _decoder_dir_name(dur: int) -> str: + cfg = TRTBuildConfig( + fp16=True, + strongly_typed=True, + refit=True, + batch_max=batch_max, + seq_max=dur * 25, + variant=variant, + ) + return cfg.engine_filename().replace(".engine", "") # (label, engine_dir_name) pairs jobs = [] @@ -913,7 +926,7 @@ def _print_matrix(durations, build_vae, build_decoder, output_dir, batch_max, jobs.append((f"VAE decode {dur}s", f"vae_decode_fp16_{dur}s")) jobs.append((f"VAE encode {dur}s", f"vae_encode_fp16_{dur}s")) if build_decoder: - jobs.append((f"Decoder {variant} {dur}s, refit", f"decoder{vtag}_mixed_refit_b{batch_max}_{dur}s")) + jobs.append((f"Decoder {variant} {dur}s, refit", _decoder_dir_name(dur))) if build_dreamvae: jobs.append((f"DreamVAE decode {dur}s", f"dreamvae_decode_fp16_{dur}s")) diff --git a/acestep/engine/trt/export.py b/acestep/engine/trt/export.py index 873f48ac..7a8fda60 100644 --- a/acestep/engine/trt/export.py +++ b/acestep/engine/trt/export.py @@ -334,6 +334,7 @@ def _export_forward( encoder_hidden_states, encoder_attention_mask, context_latents, + steering, use_cache=None, past_key_values=None, cache_position=None, @@ -380,7 +381,10 @@ def 
_export_forward( ) sw_mask = sw_mask.unsqueeze(0).unsqueeze(0) # [1, 1, S, S] - # Layer loop: static branching on layer_types (config, not runtime) + # Layer loop: static branching on layer_types. The added + # ``steering[:, i, :]`` shift is the in-engine equivalent of + # StreamPipeline._install_steering_hooks; the host zeros rows + # that aren't at an active step so they get a no-op add. for i, layer_module in enumerate(self_dec.layers): attn_mask = sw_mask if layer_types[i] == "sliding_attention" else None layer_outputs = layer_module( @@ -396,7 +400,7 @@ def _export_forward( encoder_hidden_states, None, # encoder_attention_mask ) - hidden_states = layer_outputs[0] + hidden_states = layer_outputs[0] + steering[:, i, :].unsqueeze(1).type_as(layer_outputs[0]) # Output AdaLN + proj_out (ConvTranspose1d stride=2 doubles seq_len) shift, scale = (self_dec.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) @@ -415,6 +419,7 @@ def forward( timestep: torch.Tensor, # [B] encoder_hidden_states: torch.Tensor, # [B, L_enc, 2048] context_latents: torch.Tensor, # [B, T, 128] + steering: torch.Tensor, # [B, num_layers, hidden_size] ) -> torch.Tensor: outputs = self.decoder( hidden_states=hidden_states, @@ -424,6 +429,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=None, context_latents=context_latents, + steering=steering, use_cache=False, past_key_values=None, output_attentions=False, @@ -519,12 +525,17 @@ def export_decoder_onnx( B = config.batch_size T = config.seq_len L = config.enc_len + # Steering: one additive vector per (batch row, DiT layer). Sized + # off the live decoder so XL reuses this path without a config knob. 
+ num_layers = decoder.config.num_hidden_layers + hidden_size = decoder.config.hidden_size example_inputs = ( torch.randn(B, T, 64, device=device, dtype=trace_dtype), torch.full((B,), 0.5, device=device, dtype=ts_dtype), torch.randn(B, L, 2048, device=device, dtype=trace_dtype), torch.randn(B, T, 128, device=device, dtype=trace_dtype), + torch.zeros(B, num_layers, hidden_size, device=device, dtype=trace_dtype), ) input_names = [ @@ -532,6 +543,7 @@ def export_decoder_onnx( "timestep", "encoder_hidden_states", "context_latents", + "steering", ] output_names = ["velocity"] @@ -540,6 +552,8 @@ def export_decoder_onnx( "timestep": {0: "batch"}, "encoder_hidden_states": {0: "batch", 1: "enc_len"}, "context_latents": {0: "batch", 1: "seq_len"}, + # num_layers / hidden_size are static within a checkpoint family. + "steering": {0: "batch"}, "velocity": {0: "batch", 1: "seq_len"}, } @@ -590,6 +604,7 @@ def export_decoder_onnx( "timestep": {0: batch}, "encoder_hidden_states": {0: batch, 1: enc}, "context_latents": {0: batch, 1: seq}, + "steering": {0: batch}, } torch.onnx.export( wrapper, @@ -1064,11 +1079,12 @@ def max_duration_s(self) -> int: return self.seq_max // 25 def engine_filename(self) -> str: - """Generate a standardized engine filename from build config. + """Standardized engine filename. - Format: decoder_{variant}_{precision}[_refit]_b{batch_max}_{duration}s.engine - The variant tag is omitted for "turbo" (backward compat). - Uses seconds so naming is stable across frame rates. + Format: ``spectral_decoder_{variant}_{prec}[_refit]_b{batch_max}_{dur}s.engine``; + variant tag omitted for "turbo". The ``spectral_`` prefix marks + engines that carry the steering input; pre-steering engines used + ``decoder_*`` and the runtime rejects them. 
""" if self.strongly_typed: # fp8_mixed gets its own tag so FP8 engines never collide @@ -1089,7 +1105,7 @@ def engine_filename(self) -> str: dur = self.max_duration_s # Include variant in name for non-turbo models variant_tag = f"_{self.variant}" if self.variant != "turbo" else "" - return f"decoder{variant_tag}_{prec}{refit_tag}_b{self.batch_max}_{dur}s.engine" + return f"spectral_decoder{variant_tag}_{prec}{refit_tag}_b{self.batch_max}_{dur}s.engine" def build_trt_engine( @@ -1192,6 +1208,28 @@ def build_trt_engine( min=(Bmin, Smin, 128), opt=(Bopt, Sopt, 128), max=(Bmax, Smax, 128), ) + # Read L, D off the parsed network so this stays variant-agnostic. + steering_idx = None + for ti in range(network.num_inputs): + if network.get_input(ti).name == "steering": + steering_idx = ti + break + if steering_idx is None: + raise RuntimeError( + "Decoder ONNX is missing the 'steering' input; re-export with " + "the current export.py." + ) + steering_shape = tuple(network.get_input(steering_idx).shape) + if len(steering_shape) != 3: + raise RuntimeError(f"Unexpected steering input rank: {steering_shape}") + _, steer_L, steer_D = steering_shape + profile.set_shape( + "steering", + min=(Bmin, steer_L, steer_D), + opt=(Bopt, steer_L, steer_D), + max=(Bmax, steer_L, steer_D), + ) + profile_idx = build_config.add_optimization_profile(profile) if profile_idx < 0: raise RuntimeError("Failed to add TensorRT optimization profile") diff --git a/acestep/engine/trt/runtime.py b/acestep/engine/trt/runtime.py index 392fa8b9..beb9fc19 100644 --- a/acestep/engine/trt/runtime.py +++ b/acestep/engine/trt/runtime.py @@ -51,7 +51,10 @@ class TRTDecoder: ) """ - INPUT_NAMES = ("hidden_states", "timestep", "encoder_hidden_states", "context_latents") + INPUT_NAMES = ( + "hidden_states", "timestep", "encoder_hidden_states", + "context_latents", "steering", + ) OUTPUT_NAME = "velocity" def __init__( @@ -93,6 +96,13 @@ def __init__( out_trt_dtype = self.engine.get_tensor_dtype(self.OUTPUT_NAME) 
self._output_dtype = dtype_map.get(out_trt_dtype, torch.float32) + # Static L, D; B is dynamic per call. + steer_shape = tuple(self.engine.get_tensor_shape("steering")) + if len(steer_shape) != 3: + raise RuntimeError(f"Unexpected steering input rank: {steer_shape}") + self._steer_L = int(steer_shape[1]) + self._steer_D = int(steer_shape[2]) + # Per-shape buffer cache: shape_key -> {bufs, output} self._buf_cache: dict[tuple, dict] = {} @@ -115,11 +125,16 @@ def _get_bufs(self, hs_shape, ts_shape, enc_shape, cl_shape): dev = self.device in_dt = self._input_dtypes + B = hs_shape[0] bufs = { "hidden_states": torch.empty(hs_shape, dtype=in_dt["hidden_states"], device=dev), "timestep": torch.empty(ts_shape, dtype=in_dt["timestep"], device=dev), "encoder_hidden_states": torch.empty(enc_shape, dtype=in_dt["encoder_hidden_states"], device=dev), "context_latents": torch.empty(cl_shape, dtype=in_dt["context_latents"], device=dev), + "steering": torch.zeros( + B, self._steer_L, self._steer_D, + dtype=in_dt["steering"], device=dev, + ), } for name, buf in bufs.items(): @@ -148,12 +163,16 @@ def __call__( timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, context_latents: torch.Tensor, + steering: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Run one decoder step through TensorRT. Accepts any dtype; inputs are copied into pre-allocated fp32 buffers. Returns a view of the internal output buffer (caller must not hold references across calls with different shapes). + + ``steering`` is optional ``[B, num_layers, hidden_size]``; when + omitted the per-layer adds are no-ops. 
""" orig_T = hidden_states.shape[1] pad = orig_T % 2 == 1 @@ -178,6 +197,10 @@ def __call__( bufs["context_latents"].copy_(context_latents) bufs["timestep"].copy_(timestep) bufs["encoder_hidden_states"].copy_(encoder_hidden_states) + if steering is not None: + bufs["steering"].copy_(steering) + else: + bufs["steering"].zero_() # Bind addresses ctx = self.context diff --git a/acestep/paths.py b/acestep/paths.py index 3d2806ca..bae3b842 100644 --- a/acestep/paths.py +++ b/acestep/paths.py @@ -164,6 +164,43 @@ def user_uploads_dir() -> Path: return models_dir() / "user_uploads" +# Per-checkpoint probe-bundle subpath; shared between HF fetch +# (``steering_vectors/``) and the local cache. XL is absent +# because the 2B vectors don't transfer to a different hidden_size / +# layer count. +_STEERING_VECTORS_BY_CHECKPOINT: dict[str, str] = { + "acestep-v15-turbo": "v15-turbo/shift3.5_n8_seed1528", +} + + +def steering_bundle_subpath( + checkpoint: str | Path | None = None, +) -> str | None: + """Bundle subpath registered for ``checkpoint`` (pure lookup).""" + name = _checkpoint_name(checkpoint) + return _STEERING_VECTORS_BY_CHECKPOINT.get(name) + + +def steering_vectors_dir() -> Path: + """Root directory for cached steering-vector bundles.""" + return models_dir() / "steering_vectors" + + +def steering_vector_dir( + checkpoint: str | Path | None = None, +) -> Path | None: + """Local cache directory for a checkpoint's steering bundle. + + Pure: returns where vectors WOULD live; ``None`` when no bundle + is registered. Callers needing the dir populated should go through + ``acestep.steering.hub.ensure_steering_vectors`` first. + """ + subpath = steering_bundle_subpath(checkpoint) + if subpath is None: + return None + return steering_vectors_dir() / subpath + + def discover_loras(directory: Path | None = None) -> list[Path]: """List ``*.safetensors`` files recursively under ``directory`` (default: ``loras_dir()``). 
@@ -260,17 +297,17 @@ def trt_engine_path(engine_name: str) -> Path: # that fits the audio (see `select_trt_engines` and `available_trt_engines`). _TRT_ENGINE_PROFILES: dict[float, dict[str, str]] = { 60.0: { - "decoder": "decoder_mixed_refit_b8_60s", + "decoder": "spectral_decoder_mixed_refit_b8_60s", "vae_encode": "vae_encode_fp16_60s", "vae_decode": "vae_decode_fp16_60s", }, 120.0: { - "decoder": "decoder_mixed_refit_b8_120s", + "decoder": "spectral_decoder_mixed_refit_b8_120s", "vae_encode": "vae_encode_fp16_120s", "vae_decode": "vae_decode_fp16_120s", }, 240.0: { - "decoder": "decoder_mixed_refit_b8_240s", + "decoder": "spectral_decoder_mixed_refit_b8_240s", "vae_encode": "vae_encode_fp16_240s", "vae_decode": "vae_decode_fp16_240s", }, @@ -356,7 +393,7 @@ def trt_engine_profiles( def default_trt_engines( - decoder: str = "decoder_mixed_refit_b8_60s", + decoder: str = "spectral_decoder_mixed_refit_b8_60s", vae_encode: str = "vae_encode_fp16_60s", vae_decode: str = "vae_decode_fp16_60s", ) -> dict[str, str]: diff --git a/acestep/steering/__init__.py b/acestep/steering/__init__.py new file mode 100644 index 00000000..f817d1a1 --- /dev/null +++ b/acestep/steering/__init__.py @@ -0,0 +1,57 @@ +"""Activation-steering engine surface. + +Wire knob names: + - ``steer_bright`` / ``steer_warm`` / ``steer_rough`` / ``steer_density``: + the four verified auto-path axes. + - ``man_src_<N>`` / ``man_layer_<N>`` / ``man_step_<N>`` / ``man_alpha_<N>``: + LIFO-numbered manual slot quadruples (``<N>`` is the slot id, starting at 1). + +Auto-axis alpha is sign-corrected at config-build time; the on-disk +cache stays raw so the manual path (direction-agnostic by design) sees +the literal probe direction.
+""" + +from .catalog import enumerate_catalog, load_auto_vectors, load_vector +from .controller import SteeringController +from .hub import ensure_steering_vectors, upload_steering_vectors +from .policy import ( + AUTO_AXES, + MANUAL_CATALOG_AXES, + MANUAL_MAX_LAYER, + MANUAL_MAX_STEP, + MANUAL_SLOT_CAP, + MANUAL_SLOT_DEFAULT_COUNT, + PROBE_N, + fractional_inject_step, +) +from .types import ( + AutoAxis, + CapacityError, + CatalogEntry, + EmptyError, + KnobNames, + SteeringError, +) + +__all__ = [ + "AUTO_AXES", + "AutoAxis", + "CapacityError", + "CatalogEntry", + "EmptyError", + "KnobNames", + "MANUAL_CATALOG_AXES", + "MANUAL_MAX_LAYER", + "MANUAL_MAX_STEP", + "MANUAL_SLOT_CAP", + "MANUAL_SLOT_DEFAULT_COUNT", + "PROBE_N", + "SteeringController", + "SteeringError", + "ensure_steering_vectors", + "enumerate_catalog", + "fractional_inject_step", + "load_auto_vectors", + "load_vector", + "upload_steering_vectors", +] diff --git a/acestep/steering/catalog.py b/acestep/steering/catalog.py new file mode 100644 index 00000000..dde23eeb --- /dev/null +++ b/acestep/steering/catalog.py @@ -0,0 +1,109 @@ +"""Vector catalog: walk the on-disk probe dir into stable indexed entries. + +Split into a torch-free enumerator (filename metadata) and a +torch-paying loader. Stems look like ``brightness_l09_t3.pt`` — the +zero-padded layer lets alphabetical sort yield the intended +(l03..l18) x (t0, t3, ...) order naturally. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +from .policy import MANUAL_CATALOG_AXES +from .types import AutoAxis, CatalogEntry + +if TYPE_CHECKING: + import torch + + +def _parse_stem(stem: str, axis: str) -> tuple[int, int] | None: + """Extract (layer, step) from ``brightness_l09_t3`` style stems. + + Returns ``None`` if the stem doesn't parse against ``axis``. 
+ """ + prefix = f"{axis}_l" + if not stem.startswith(prefix): + return None + tail = stem[len(axis) + 1:] # "l09_t3" + try: + layer = int(tail.split("_")[0][1:]) + step = int(tail.split("_")[1][1:]) + except (IndexError, ValueError): + return None + return layer, step + + +def enumerate_catalog(vector_dir: Path) -> tuple[CatalogEntry, ...]: + """List every (axis, build_layer, build_step) cell on disk. + + Torch-free; reads filenames only. Empty tuple when the dir is + missing. + """ + if not vector_dir.exists(): + return () + per_axis: dict[str, list[Path]] = {a: [] for a in MANUAL_CATALOG_AXES} + for path in sorted(vector_dir.glob("*.pt")): + for axis in MANUAL_CATALOG_AXES: + if path.stem.startswith(f"{axis}_l"): + per_axis[axis].append(path) + break + out: list[CatalogEntry] = [] + for axis in MANUAL_CATALOG_AXES: + for path in per_axis[axis]: + parsed = _parse_stem(path.stem, axis) + if parsed is None: + continue + layer, step = parsed + out.append(CatalogEntry( + index=len(out), + axis=axis, + build_layer=layer, + build_step=step, + filename=path.name, + )) + return tuple(out) + + +def load_vector( + vector_dir: Path, + entry: CatalogEntry, +) -> tuple["torch.Tensor", float]: + """Load one catalog entry's (vector, magnitude) as CPU float32.""" + import torch as _t + + path = vector_dir / entry.filename + blob = _t.load(path, map_location="cpu", weights_only=False) + return blob["vector"].to(_t.float32), float(blob["magnitude"]) + + +def load_auto_vectors( + vector_dir: Path, + auto_axes: tuple[AutoAxis, ...], +) -> dict[str, dict]: + """Load the per-axis vector for each verified auto-path axis. + + Returns ``{axis.name: {"layer": int, "vector": Tensor, + "magnitude": float}}``. Missing files, and files that fail to load, are silently skipped — the + knob for that axis just never produces a config.
class SteeringController:
    """Per-session holder of activation-steering state.

    Building an instance enumerates the manual catalog and loads the
    auto-axis vectors from ``vector_dir``. With ``vector_dir=None`` both
    stay empty, so ``is_loaded`` is False, no slot is allocated and
    ``build_configs`` returns ``[]``. (A directory that does not exist
    is expected to degrade the same way; that depends on the catalog
    loaders this class calls.)

    Manual slots behave like a stack: ``add_slot`` hands out
    ``slot_count + 1`` and ``pop_slot`` removes the highest id, so the
    live ids are always the contiguous range ``1..slot_count``.
    Removing a slot from the middle is not offered.
    """

    auto_axes: tuple[AutoAxis, ...]
    catalog: tuple[CatalogEntry, ...]
    slot_cap: int = MANUAL_SLOT_CAP
    MANUAL_MAX_LAYER: int = MANUAL_MAX_LAYER
    MANUAL_MAX_STEP: int = MANUAL_MAX_STEP

    def __init__(
        self,
        vector_dir: Path | None,
        *,
        default_slot_count: int = MANUAL_SLOT_DEFAULT_COUNT,
    ) -> None:
        """Load the catalog and auto vectors, then size the slot stack.

        Args:
            vector_dir: Bundle directory, or None for an unloaded
                controller.
            default_slot_count: Slots allocated up front; clamped to
                ``[0, slot_cap]`` and ignored (0) when nothing loaded.
        """
        self._vector_dir = vector_dir
        self.auto_axes = AUTO_AXES
        # The probe schedule length is encoded in the bundle directory
        # name; PROBE_N is the fallback when there is no directory.
        if vector_dir:
            self._probe_n = parse_probe_n(vector_dir.name)
        else:
            self._probe_n = PROBE_N
        if vector_dir is not None:
            self.catalog = enumerate_catalog(vector_dir)
            self._auto_vectors = load_auto_vectors(vector_dir, AUTO_AXES)
        else:
            self.catalog = ()
            self._auto_vectors = {}
        # Manual vectors are loaded on demand, keyed by catalog index.
        # A stored ``None`` records a failed load so it is not retried.
        self._manual_vectors: dict[int, dict | None] = {}
        if self.is_loaded:
            requested = int(default_slot_count)
            self._slot_count = max(0, min(requested, self.slot_cap))
        else:
            self._slot_count = 0

    @property
    def is_loaded(self) -> bool:
        """Whether any auto vector or catalog entry is available."""
        if self._auto_vectors:
            return True
        return bool(self.catalog)

    @property
    def slot_count(self) -> int:
        """Number of manual slots currently allocated."""
        return self._slot_count

    def active_slots(self) -> tuple[int, ...]:
        """Return the allocated slot ids, ascending (``1..slot_count``)."""
        last_id = self._slot_count
        return tuple(range(1, last_id + 1))

    def add_slot(self) -> int:
        """Allocate and return the next slot id.

        Raises:
            CapacityError: when no vectors are loaded (a slot would have
                no catalog behind it) or the registry is at ``slot_cap``.
        """
        if not self.is_loaded:
            raise CapacityError("manual steering unavailable (no vectors loaded)")
        if self._slot_count >= self.slot_cap:
            raise CapacityError(
                f"manual steering at cap ({self.slot_cap})",
            )
        new_id = self._slot_count + 1
        self._slot_count = new_id
        return new_id

    def pop_slot(self) -> int:
        """Remove the highest-numbered slot and return its id.

        Raises:
            EmptyError: when no slot is allocated.
        """
        if self._slot_count <= 0:
            raise EmptyError("no manual steering slots to remove")
        removed_id = self._slot_count
        self._slot_count = removed_id - 1
        return removed_id

    @staticmethod
    def knob_names(slot_id: int) -> KnobNames:
        """Return the four wire knob names belonging to ``slot_id``."""
        suffix = f"{slot_id}"
        return KnobNames(
            src=_MAN_SRC + suffix,
            layer=_MAN_LAYER + suffix,
            step=_MAN_STEP + suffix,
            alpha=_MAN_ALPHA + suffix,
        )

    def snapshot_key(
        self,
        raw: Mapping[str, float],
        n: int,
    ) -> tuple:
        """Return a change-detection key for ``raw`` at schedule ``n``.

        The key lists, as floats: every auto-axis alpha in declaration
        order, then ``(src, layer, step, alpha)`` for each active slot,
        then the schedule length (at least 1). Missing knobs count as
        0.0. Because the length depends on ``slot_count``, adding or
        popping a slot always yields a different key.
        """
        key = [float(raw.get(axis.name, 0.0)) for axis in self.auto_axes]
        for slot_id in self.active_slots():
            knobs = self.knob_names(slot_id)
            key.extend(
                float(raw.get(knob, 0.0))
                for knob in (knobs.src, knobs.layer, knobs.step, knobs.alpha)
            )
        key.append(float(max(1, int(n))))
        return tuple(key)

    def _get_manual_vector(self, src_idx: int) -> dict | None:
        """Return ``{"vector", "magnitude"}`` for one catalog index.

        Results are memoised, including failures (stored as None), so a
        broken file is reported once rather than on every call. An index
        outside the catalog yields None without touching the cache.
        """
        if not 0 <= src_idx < len(self.catalog):
            return None
        remembered = self._manual_vectors.get(src_idx, _MISSING)
        if remembered is not _MISSING:
            return remembered
        entry = self.catalog[src_idx]
        try:
            vector, magnitude = load_vector(self._vector_dir, entry)  # type: ignore[arg-type]
        except Exception as exc:
            from loguru import logger
            logger.warning(
                "steering: failed to load manual vector {} (idx {}): {}",
                entry.filename, src_idx, exc,
            )
            loaded: dict | None = None
        else:
            loaded = {"vector": vector, "magnitude": magnitude}
        self._manual_vectors[src_idx] = loaded
        return loaded

    @staticmethod
    def _clamp_layer(layer: int) -> int:
        """Clamp a layer index into ``[0, MANUAL_MAX_LAYER]``."""
        return max(0, min(MANUAL_MAX_LAYER, layer))

    @staticmethod
    def _knob_int(raw: Mapping[str, float], knob: str) -> int:
        """Read ``knob`` from ``raw`` (default 0.0) rounded to an int."""
        return int(round(float(raw.get(knob, 0.0))))

    @staticmethod
    def _config(layer: int, step: int, blob: dict, alpha: float) -> dict:
        """Assemble one ``set_steering`` config entry."""
        return {
            "layer": layer,
            "step": step,
            "vector": blob["vector"],
            "magnitude": blob["magnitude"],
            "alpha": alpha,
        }

    def build_configs(
        self,
        raw: Mapping[str, float],
        n: int,
    ) -> list[dict]:
        """Turn the live knob values into ``set_steering`` configs.

        Auto axes: the probe step is rescaled to schedule ``n``, the
        layer is ``probe_layer + layer_offset`` clamped to the legal
        range, and alpha is multiplied by the axis sign. Manual slots:
        layer (clamped) and step are taken as given, alpha keeps its
        sign. Entries with alpha 0, an unknown auto vector, or an
        unloadable manual vector are left out. Returns ``[]`` when
        nothing is loaded.
        """
        if not self.is_loaded:
            return []
        schedule_n = max(1, int(n))
        configs: list[dict] = []

        for axis in self.auto_axes:
            strength = float(raw.get(axis.name, 0.0))
            if not strength:
                continue
            cached = self._auto_vectors.get(axis.name)
            if cached is None:
                continue
            step = fractional_inject_step(
                axis.probe_step, schedule_n, probe_n=self._probe_n,
            )
            layer = self._clamp_layer(axis.probe_layer + axis.layer_offset)
            configs.append(
                self._config(layer, step, cached, strength * axis.sign),
            )

        for slot_id in self.active_slots():
            knobs = self.knob_names(slot_id)
            strength = float(raw.get(knobs.alpha, 0.0))
            if not strength:
                continue
            picked = self._get_manual_vector(self._knob_int(raw, knobs.src))
            if picked is None:
                continue
            layer = self._clamp_layer(self._knob_int(raw, knobs.layer))
            step = self._knob_int(raw, knobs.step)
            configs.append(self._config(layer, step, picked, strength))

        return configs
def ensure_steering_vectors(
    checkpoint: Optional[str] = None,
    *,
    force_download: bool = False,
) -> Optional[Path]:
    """Make sure the checkpoint's steering bundle is cached locally.

    Args:
        checkpoint: Checkpoint name used to look up the bundle subpath.
        force_download: Skip the warm-cache shortcut and ask the hub for
            a fresh download.

    Returns:
        ``None`` if no bundle is registered for ``checkpoint``;
        otherwise the local cache directory. A missing
        ``huggingface_hub``, a failed fetch, or a snapshot without the
        bundle directory only logs a warning, and the directory that is
        returned may then be empty.
    """
    bundle = steering_bundle_subpath(checkpoint)
    if bundle is None:
        return None

    cache_dir = steering_vectors_dir() / bundle
    # The cache counts as warm as soon as it holds any .pt file; there
    # is no manifest comparison, so refreshing needs force_download.
    cache_is_warm = (
        not force_download
        and cache_dir.is_dir()
        and any(cache_dir.glob("*.pt"))
    )
    if cache_is_warm:
        return cache_dir

    cache_dir.mkdir(parents=True, exist_ok=True)
    remote_dir = hf_bundle_dir(bundle)

    try:
        from huggingface_hub import snapshot_download
    except ImportError as exc:
        logger.warning(
            "steering: huggingface_hub not available ({}); "
            "steering knobs will no-op until vectors are installed at {}.",
            exc, cache_dir,
        )
        return cache_dir

    logger.info(
        "Fetching steering bundle {!r} from HF: {}/{}",
        bundle, DEMON_STEERING_REPO, remote_dir,
    )
    try:
        downloaded = snapshot_download(
            repo_id=DEMON_STEERING_REPO,
            allow_patterns=[f"{remote_dir}/*"],
            force_download=force_download,
        )
        snapshot_root = Path(downloaded)
    except Exception as exc:
        logger.warning(
            "steering: HF fetch failed ({}); steering knobs will no-op. "
            "Cache dir: {}",
            exc, cache_dir,
        )
        return cache_dir

    bundle_in_snapshot = snapshot_root / remote_dir
    if not bundle_in_snapshot.is_dir():
        logger.warning(
            "steering: HF bundle dir {} not present in snapshot at {}; "
            "steering knobs will no-op.",
            remote_dir, snapshot_root,
        )
        return cache_dir

    # Flat copy of the bundle's regular files; subdirectories are skipped.
    n_copied = 0
    for item in bundle_in_snapshot.iterdir():
        if not item.is_file():
            continue
        shutil.copy2(item, cache_dir / item.name)
        n_copied += 1
    logger.info(
        "Steering bundle ready at {} ({} files)",
        cache_dir, n_copied,
    )
    return cache_dir
def parse_probe_n(subpath: str) -> int:
    """Extract the probe schedule length from a bundle subpath.

    A subpath such as ``v15-turbo/shift3.5_n8_seed1528`` carries the
    schedule as an ``_n`` token followed by digits (here 8). The first
    such token that is followed by ``_`` or the end of the string wins.
    When there is none, ``PROBE_N`` is returned.
    """
    found = _PROBE_N_RE.search(subpath)
    if found is None:
        return PROBE_N
    return int(found[1])
def fractional_inject_step(
    probe_step: int,
    inject_n: int,
    *,
    probe_n: int = PROBE_N,
) -> int:
    """Map a step of the probe schedule onto the inject schedule.

    The step keeps its relative position: the result is
    ``round(probe_step / probe_n * inject_n)`` (Python's ``round``,
    which resolves exact halves to the even neighbour), limited to
    ``[0, inject_n - 1]``. Schedule lengths below 1 are treated as 1.
    When both schedules have the same length the step maps to itself.
    """
    # Guard against degenerate schedule lengths before dividing.
    inject_n = max(inject_n, 1)
    probe_n = max(probe_n, 1)
    step = int(round(probe_step / probe_n * inject_n))
    if step < 0:
        return 0
    return step if step < inject_n else inject_n - 1
def steering_knob_specs(steering: "SteeringController") -> list:
    """Build the registry specs for a controller's current surface.

    Returns an empty list when the controller has nothing loaded.
    Otherwise the list holds one spec per auto axis (built by
    ``steering_axis_spec``), followed by the specs that
    ``manual_slot_specs`` produces for every active manual slot, in
    slot order. This function only supplies the axis and catalog
    metadata; the spec shapes belong to those two factories.
    """
    if not steering.is_loaded:
        return []

    top_layer = steering.MANUAL_MAX_LAYER
    result: list = [
        steering_axis_spec(
            axis.name,
            axis=axis.axis,
            # Same clamp the controller applies when it injects.
            inject_layer=max(
                0, min(top_layer, axis.probe_layer + axis.layer_offset),
            ),
            probe_step=axis.probe_step,
            probe_n=steering._probe_n,
            blurb=axis.blurb,
        )
        for axis in steering.auto_axes
    ]

    # Highest valid catalog index; 0 when the catalog is empty.
    last_index = max(0, len(steering.catalog) - 1)
    for slot_id in steering.active_slots():
        slot_knobs = manual_slot_specs(
            slot_id,
            src_max=last_index,
            catalog_len=len(steering.catalog),
            layer_max=steering.MANUAL_MAX_LAYER,
            step_max=steering.MANUAL_MAX_STEP,
        )
        result.extend(slot_knobs)
    return result
+ self.steering = ( + steering if steering is not None else SteeringController(None) + ) + # (pipeline, snapshot) change-detection key for _sync_steering; + # None forces a push on the first tick and after a + # steps_override-driven pipeline rebuild. + self._last_steering = None + # ----- per-tick translation state (the old run() locals) ----- self._last_latent = None # Previous fresh latent for the full-buffer MSE skip. Tracked @@ -275,6 +327,7 @@ def capabilities(self) -> Capabilities: depth=True, curves=True, notes_conditioning=False, + steering=self.steering.is_loaded, ) def geometry(self) -> AudioGeometry: @@ -291,11 +344,13 @@ def geometry(self) -> AudioGeometry: def knob_specs(self, lora_ids=()) -> list: """The ACE-family manifest: the shared registry's spec list, parameterized by this session's SDE mode and the enabled-LoRA - set the session passes in (see the protocol docstring).""" + set the session passes in (see the protocol docstring), plus + the activation-steering surface (auto axes + the live manual + slots) when this session's checkpoint has a vector bundle.""" return registry_knob_specs( self.use_sde, loras=list(lora_ids) if self.use_lora else [], - ) + ) + steering_knob_specs(self.steering) # ---- public hooks reachable from session ops --------------------------- @@ -377,6 +432,28 @@ def _sync_channel_guidance(self, raw: dict, last: list) -> list: self.stream.model.handler._channel_guidance = configs return ch_gains[:] + def _sync_steering(self, raw: dict, last): + """Push activation-steering configs when the snapshot changes. + + ``last`` is ``(pipeline, snapshot_tuple)`` or ``None``. Pipeline + identity is part of the key because ``steps_override`` rebuilds + the StreamPipeline (fresh, empty steering state) without + changing ``raw`` — without the identity check the new pipeline + would never receive ``set_steering``. 
+ """ + if not self.steering.is_loaded: + return last + pipe = self.stream.pipeline + if pipe is None: + return last + n = max(1, int(raw.get("steps_override", 8))) + snapshot = self.steering.snapshot_key(raw, n) + last_pipe, last_snapshot = last if last is not None else (None, None) + if pipe is last_pipe and snapshot == last_snapshot: + return last + pipe.set_steering(self.steering.build_configs(raw, n)) + return (pipe, snapshot) + # ---- GeneratorBackend hot loop ----------------------------------------- def sync_source(self, ctx: TickContext) -> None: @@ -712,6 +789,9 @@ def _prepare_tick(self, knobs: dict, ctx: TickContext) -> dict: self._last_channel_gains = self._sync_channel_guidance( raw, self._last_channel_gains, ) + self._last_steering = self._sync_steering( + raw, self._last_steering, + ) # Route every curve-capable parameter through the shared # mutable curve system so knob changes take effect on ALL diff --git a/acestep/streaming/events.py b/acestep/streaming/events.py index 9e86949b..a5a090d6 100644 --- a/acestep/streaming/events.py +++ b/acestep/streaming/events.py @@ -166,6 +166,15 @@ class DepthApplied: value: int +@dataclass(frozen=True) +class ManualSlotCount: + """Manual steering slot count after a manual_slot_add / manual_slot_pop + (published on success AND refusal so the client's +/- UI resyncs + either way). ``count`` is the controller's live slot count.""" + + count: int + + @dataclass(frozen=True) class SwapReady: """Source swap completed. 
Carries enough state for the transport to diff --git a/acestep/streaming/families.py b/acestep/streaming/families.py index 701d4d92..4adda768 100644 --- a/acestep/streaming/families.py +++ b/acestep/streaming/families.py @@ -20,8 +20,16 @@ def _make_acestep(ss): + from acestep.steering import SteeringController, ensure_steering_vectors from acestep.streaming.ace_backend import ACEStepBackend + # SteeringController is the source of truth for slot_count and the + # vector catalog; ensure_steering_vectors fetches/caches the + # checkpoint's probe bundle (None for checkpoints without one — XL, + # fetch failures — which degrades the controller to is_loaded=False + # and drops the steering capability/knobs for the session). + steering = SteeringController(ensure_steering_vectors(ss.checkpoint)) + return ACEStepBackend( ss.session, ss.stream, state=ss.state, @@ -34,6 +42,7 @@ def _make_acestep(ss): walk_window=ss.walk_window, walk_window_s=ss.walk_window_s, neg_conditioning=ss.cond_negative, + steering=steering, ) @@ -43,14 +52,48 @@ def _make_acestep(ss): def _acestep_knob_universe(): - from acestep.streaming.knobs import knob_specs + from acestep.steering.policy import ( + AUTO_AXES, + MANUAL_MAX_LAYER, + MANUAL_MAX_STEP, + PROBE_N, + ) + from acestep.streaming.knobs import ( + knob_specs, + manual_slot_specs, + steering_axis_spec, + ) # Every spec the family can ever expose: both SDE-mode variants plus # a representative LoRA-strength knob (the per-id specs all come from - # lora_strength_spec, so one placeholder id covers the pattern). + # lora_strength_spec, so one placeholder id covers the pattern), plus + # the steering surface — the four auto axes and one representative + # manual slot (per-slot specs all come from manual_slot_specs). + # Catalog geometry uses the canonical v15-turbo bundle's 144 cells; + # no network fetch happens here (policy tables only). 
def steering_axis_spec(
    name: str,
    *,
    axis: str = "",
    inject_layer: int = 0,
    probe_step: int = 0,
    probe_n: int = 8,
    blurb: str = "",
) -> KnobSpec:
    """Return the registry spec for one auto-path steering knob.

    The knob is a bipolar float in
    ``[-STEERING_ALPHA_MAX, STEERING_ALPHA_MAX]`` with default 0.0 in
    the ``"steering"`` group. Range and group are fixed here so every
    transport sees the same contract; the axis metadata only feeds the
    human-readable description.

    Args:
        name: Knob name, e.g. ``"steer_bright"``.
        axis: Axis tag shown in the description.
        inject_layer: DiT layer named in the description.
        probe_step: Probe-schedule step named in the description.
        probe_n: Probe-schedule length named in the description.
        blurb: Optional sentence describing the audible effect. It is
            left out entirely when empty, and a trailing full stop is
            not doubled.
    """
    # Normalise the blurb so it is terminated exactly once. Previously
    # an empty blurb produced "direction. . Useful ..." and a blurb that
    # already ended in "." produced "..".
    effect = blurb.strip().rstrip(".")
    effect_sentence = f" {effect}." if effect else ""
    return KnobSpec(
        name, default=0.0,
        min_val=-STEERING_ALPHA_MAX, max_val=STEERING_ALPHA_MAX,
        group="steering",
        description=(
            f"Activation-steering ({axis}) injected at DiT layer "
            f"{inject_layer}, step round({probe_step}/{probe_n} * inject_n) "
            "of the current schedule. 0 = off, negative inverts the axis "
            f"direction.{effect_sentence}"
            " Useful magnitude roughly 2..15 by ear; breakage above that."
        ),
    )
+ ), + ), + KnobSpec( + f"man_step_{slot_id}", default=0.0, min_val=0.0, + max_val=float(step_max), type="int", group="manual", + description=( + f"Manual slot {slot_id}: diffusion inject step " + f"(0..{step_max}). No fractional mapping. Values past the " + "current steps_override - 1 silently no-op (the engine only " + "fires when step equals the active diffusion step)." + ), + ), + KnobSpec( + f"man_alpha_{slot_id}", default=0.0, + min_val=-STEERING_ALPHA_MAX, max_val=STEERING_ALPHA_MAX, + group="manual", + description=( + f"Manual slot {slot_id}: injection strength. 0 = slot off. " + "Bipolar: negative alpha inverts the chosen vector's " + "direction at injection (no sign correction is applied). " + "Useful magnitude roughly 2..15 by ear; breakage above that." + ), + ), + ] + + def knob_catalog(sde: bool, loras=None) -> dict: """Project the full registry into a transport-agnostic catalog: ``name -> {type, default, min?, max, group, options?, description?, diff --git a/acestep/streaming/session.py b/acestep/streaming/session.py index a3b87df7..b3acaa06 100644 --- a/acestep/streaming/session.py +++ b/acestep/streaming/session.py @@ -73,12 +73,14 @@ from acestep.streaming.commands import CommandOrigin from acestep.streaming.config import SessionConfig from acestep.streaming.encode import blend_for_strength, encode_cond_pair +from acestep.steering import CapacityError, EmptyError from acestep.streaming.events import ( AudioReady, CommandFailed, DepthApplied, EventBus, LoraCatalogUpdate, + ManualSlotCount, ParamsEcho, PromptApplied, PromptBlendEcho, @@ -462,11 +464,20 @@ def __init__( # Cached {name: KnobSpec} map for hot-path validation in set_knobs. # knob_specs() rebuilds 34 dataclasses, so we never call it per tick; - # rebuilt only when the LoRA set changes (see _apply_lora_pending). + # rebuilt only when the LoRA set changes (see _apply_lora_pending) + # or a manual steering slot is added/popped. 
def steering_payload(self) -> dict:
    """Return the activation-steering fields sent to clients.

    The dict has three keys: ``manual_slot_count`` and
    ``manual_slot_cap`` (current and maximum number of manual slots)
    and ``steering_available`` (whether the controller has vectors
    loaded). A backend without a ``steering`` attribute, or with it set
    to None, reports ``0 / 0 / False``.
    """
    controller = getattr(self.backend, "steering", None)
    if controller is None:
        count, cap, available = 0, 0, False
    else:
        count = controller.slot_count
        cap = controller.slot_cap
        available = bool(controller.is_loaded)
    return {
        "manual_slot_count": count,
        "manual_slot_cap": cap,
        "steering_available": available,
    }
@requires_capability("steering", "manual_slot_add")
def manual_slot_add(
    self,
    *,
    origin: CommandOrigin = CommandOrigin.PRIMARY,
) -> None:
    """Add the next manual steering slot and announce the new count.

    The steering controller is asked for a slot first. On success the
    cached spec map is rebuilt and the slot's four knobs are seeded
    into ``virtual_knobs`` (only those the rebuilt map contains). When
    the controller refuses with ``CapacityError`` nothing else is
    changed. In both cases a ``ManualSlotCount`` event carrying the
    controller's current slot count is published, so clients can
    resync.
    """
    self.state.last_activity_ts = time.monotonic()
    controller = self.backend.steering
    try:
        slot_id = controller.add_slot()
    except CapacityError:
        slot_id = None
        logger.info(
            "manual_slot_add_refused origin={} cap={}",
            origin.value, controller.slot_cap,
        )
    if slot_id is not None:
        # Rebuild first so the lookup below sees the new slot's specs.
        self._rebuild_knob_specs(self._enabled_lora_ids())
        knobs = controller.knob_names(slot_id)
        for knob in (knobs.src, knobs.layer, knobs.step, knobs.alpha):
            spec = self._knob_specs_by_name.get(knob)
            if spec is None:
                continue
            self.virtual_knobs.add_knob(spec)
        logger.info(
            "manual_slot_added origin={} slot={}", origin.value, slot_id,
        )
    self.bus.publish(ManualSlotCount(count=controller.slot_count))
Refusal on an empty + registry still publishes :class:`ManualSlotCount`.""" + self.state.last_activity_ts = time.monotonic() + ctl = self.backend.steering + try: + popped = ctl.pop_slot() + except EmptyError: + logger.info("manual_slot_pop_refused origin={}", origin.value) + else: + names = ctl.knob_names(popped) + for name in (names.src, names.layer, names.step, names.alpha): + self.virtual_knobs.remove_knob(name) + self._rebuild_knob_specs(self._enabled_lora_ids()) + logger.info( + "manual_slot_popped origin={} slot={}", origin.value, popped, + ) + self.bus.publish(ManualSlotCount(count=ctl.slot_count)) + @requires_capability("timbre", "set_timbre_strength") def set_timbre_strength( self, diff --git a/demos/realtime_motion_graph_web/mcp_server.py b/demos/realtime_motion_graph_web/mcp_server.py index 96726793..56c2bf1f 100644 --- a/demos/realtime_motion_graph_web/mcp_server.py +++ b/demos/realtime_motion_graph_web/mcp_server.py @@ -47,14 +47,25 @@ from loguru import logger from mcp.server.fastmcp import FastMCP +from acestep.steering import ( + MANUAL_SLOT_CAP, + ensure_steering_vectors, + enumerate_catalog, +) from acestep.streaming.knobs import ( KNOB_SCHEMA_VERSION, + KnobSpec, coerce_knob_values, knob_catalog, knob_specs, ) from .protocol import coerce_command_payload, wire_contract +# MCP runs as a single global process, so we pre-fetch the canonical +# 2B turbo bundle at module init. Fetch failures leave the cache empty; +# the next streaming session retries. +_MANUAL_VECTOR_DIR = ensure_steering_vectors("acestep-v15-turbo") + # MCP wire protocol owns stdout — every log MUST go to stderr. 
Lazy # configure so this module stays importable without a hard dependency on @@ -239,6 +250,27 @@ def _waveform_to_audio_bytes(waveform: np.ndarray) -> bytes: return struct.pack(" list[dict]: + """Manual steering catalog flattened to wire-stable dicts.""" + if _MANUAL_VECTOR_DIR is None: + return [] + return [ + { + "index": entry.index, + "axis": entry.axis, + "build_layer": entry.build_layer, + "build_step": entry.build_step, + "filename": entry.filename, + } + for entry in enumerate_catalog(_MANUAL_VECTOR_DIR) + ] + + # --------------------------------------------------------------------------- # Tools — discovery # --------------------------------------------------------------------------- @@ -339,6 +371,19 @@ async def list_knobs(session_id: Optional[str] = None) -> dict: which LoRAs are currently enabled — pulled from the live snapshot. """ snap = await session_state(session_id) + # Prefer the snapshot's backend-owned manifest (Phase 2): it is the + # session's LIVE knob universe, including the backend-specific knobs + # the static registry projection can't reproduce (the steering + # steer_* axes and the per-slot man_*_ quadruples). + manifest = snap.get("knob_manifest") or {} + if isinstance(manifest, dict) and manifest.get("knobs"): + return { + "version": manifest.get("version", KNOB_SCHEMA_VERSION), + "knobs": manifest["knobs"], + "current": snap.get("knob_values") or {}, + } + # Older server snapshot without a manifest: re-derive from the + # shared registry (no steering surface on those servers anyway). sde, enabled = _session_knob_shape(snap) return { "version": KNOB_SCHEMA_VERSION, @@ -362,6 +407,33 @@ def _session_knob_shape(snap: dict) -> tuple[bool, list]: return sde, enabled +def _specs_from_snapshot(snap: dict) -> dict: + """``{name: KnobSpec}`` for a session snapshot. 
+ + Reconstructed from the snapshot's backend-owned ``knob_manifest`` + when present (so validation covers backend-specific knobs like the + steering surface); falls back to the shared registry for older + servers. The manifest is itself a registry projection + (``catalog_from_specs``), so this stays single-source.""" + manifest = (snap.get("knob_manifest") or {}).get("knobs") + if isinstance(manifest, dict) and manifest: + return { + name: KnobSpec( + name=name, + default=e.get("default", 0.0), + min_val=e.get("min"), + max_val=e.get("max", 1.0), + type=e.get("type", "float"), + options=tuple(e.get("options") or ()), + group=e.get("group", "core"), + bank=bool(e.get("bank", True)), + ) + for name, e in manifest.items() + } + sde, enabled = _session_knob_shape(snap) + return {s.name: s for s in knob_specs(sde=sde, loras=enabled)} + + async def _validate_against_session( raw: dict, session_id: Optional[str] ) -> dict: @@ -370,14 +442,53 @@ async def _validate_against_session( any value is out of range or not an allowed enum/bool option. Reuses the same coerce_knob_values the server enforces, so MCP can't drift.""" snap = await session_state(session_id) - sde, enabled = _session_knob_shape(snap) - specs = {s.name: s for s in knob_specs(sde=sde, loras=enabled)} - clean, errors = coerce_knob_values(raw, specs) + clean, errors = coerce_knob_values(raw, _specs_from_snapshot(snap)) if errors: raise ValueError("; ".join(errors)) return clean +@mcp.tool() +async def add_manual_slot(session_id: Optional[str] = None) -> dict: + """Spawn the next manual steering slot (LIFO; alpha defaults to 0). + + Refused (no-op echo) at MANUAL_SLOT_CAP. + """ + _send_cmd(session_id, {"type": "manual_slot_add"}) + snap = await session_state(session_id) + return { + "count": int(snap.get("manual_slot_count") or 0), + "cap": MANUAL_SLOT_CAP, + } + + +@mcp.tool() +async def pop_manual_slot(session_id: Optional[str] = None) -> dict: + """Remove the highest-numbered manual steering slot. 
+ + LIFO; interior deletion is not supported. Refused (no-op echo) on + an empty registry. + """ + _send_cmd(session_id, {"type": "manual_slot_pop"}) + snap = await session_state(session_id) + return { + "count": int(snap.get("manual_slot_count") or 0), + "cap": MANUAL_SLOT_CAP, + } + + +@mcp.tool() +async def list_manual_steering_vectors() -> dict: + """Catalog of pre-built steering vectors for the manual slots. + + Returns ``{"count": N, "vectors": [...]}``. Each entry's ``index`` + is the value to set on ``man_src_``. Order is stable across + sessions (axis-major, then build_layer asc, then build_step asc). + """ + cat = _enumerate_manual_catalog() + return {"count": len(cat), "vectors": cat} + + # --------------------------------------------------------------------------- # Tools — prompt # --------------------------------------------------------------------------- diff --git a/demos/realtime_motion_graph_web/protocol.py b/demos/realtime_motion_graph_web/protocol.py index 57c48c29..ec85cfd9 100644 --- a/demos/realtime_motion_graph_web/protocol.py +++ b/demos/realtime_motion_graph_web/protocol.py @@ -292,6 +292,22 @@ class EventSpec: requires="lora", description="Disable a LoRA and drop its lora_str_ knob.", ), + CommandSpec( + "manual_slot_add", + requires="steering", + description="Allocate the next manual steering slot (LIFO); " + "allocates its four man_*_ knobs. Echoed back as " + "manual_slot_count on success AND refusal (at " + "manual_slot_cap).", + ), + CommandSpec( + "manual_slot_pop", + requires="steering", + description="Remove the highest-numbered manual steering slot and " + "drop its man_*_ knobs (LIFO; interior deletion is " + "not supported). Echoed back as manual_slot_count on " + "success AND refusal (empty registry).", + ), CommandSpec( "set_timbre_strength", fields=(FieldSpec("value", "float", required=True, default=1.0, @@ -439,6 +455,25 @@ class EventSpec: "resolved (SDE mode, enabled lora_str_ " "knobs). 
/api/knobs remains the static " "pre-session probe."), + # Activation-steering surface. Wire-optional like the + # Phase-2 fields above: absent on backends without the + # steering capability, and the client hides the steering + # tiles when steering_available isn't explicitly true. + FieldSpec("manual_slot_count", "int", + description="Active manual steering slot count; " + "drives the client's man_*_ row " + "rendering. Updated live via the " + "manual_slot_count event."), + FieldSpec("manual_slot_cap", "int", + description="Server-imposed ceiling on manual " + "steering slots; gates the client's " + "+ button."), + FieldSpec("steering_available", "bool", + description="True when the session's checkpoint has " + "a reachable steering-vector bundle; " + "false hides the steering surface (the " + "steer_*/man_* knobs are absent from " + "the manifest too)."), ), binary_follow=True, description="First JSON after the upload handshake, followed by the " @@ -539,6 +574,15 @@ class EventSpec: description="The clamped applied depth."),), description="Ack for set_depth.", ), + EventSpec( + "manual_slot_count", + fields=(FieldSpec("count", "int", required=True, + description="The live manual steering slot " + "count after the command."),), + description="Ack for manual_slot_add / manual_slot_pop — emitted on " + "success and refusal alike so the client's +/- UI " + "resyncs either way.", + ), EventSpec( "timbre_set", fields=(FieldSpec("name", "str", required=True), diff --git a/demos/realtime_motion_graph_web/web/app/globals.css b/demos/realtime_motion_graph_web/web/app/globals.css index 466a2913..99445f6b 100644 --- a/demos/realtime_motion_graph_web/web/app/globals.css +++ b/demos/realtime_motion_graph_web/web/app/globals.css @@ -4193,15 +4193,15 @@ body.curve-open #install-video-area #graph { margin: 0; max-width: 70ch; } -/* Side-by-side columns for the two voice banks. 
Internal voices on the - left (8 wide), tuned morph on the right (6 wide), thin vertical - divider between them — the inShaper "shaper-1 / shaper-2" two-pane - layout in miniature. */ +/* Two-column grid shared by every row inside the voice tile (channels + on row 1, steering on row 2 when present). Grid (not flex) so the + columns auto-size to the widest content across rows — that's what + makes "manual steering" line up with "channel groups" above. */ .voice-sections-row { - display: flex; - flex-direction: row; - flex-wrap: wrap; - gap: 14px; + display: grid; + grid-template-columns: auto 1px auto; + row-gap: 14px; + column-gap: 14px; align-items: stretch; } .voice-section { @@ -4214,8 +4214,7 @@ body.curve-open #install-video-area #graph { width: 1px; align-self: stretch; background: var(--frame-line); - margin: 4px 4px; - flex: 0 0 auto; + margin: 4px 0; } .voice-section-label { font-family: var(--font-mono); @@ -4225,6 +4224,7 @@ body.curve-open #install-video-area #graph { color: var(--text-dim); } + /* A tile is a labeled card that groups related controls. Tiles are sized to their content (no fixed widths beyond what their inner channels and per-tile rules dictate), which lets them wrap densely. 
*/ diff --git a/demos/realtime_motion_graph_web/web/components/Performance/AdvancedDrawer.tsx b/demos/realtime_motion_graph_web/web/components/Performance/AdvancedDrawer.tsx index 7fc0d068..4ffcdbde 100644 --- a/demos/realtime_motion_graph_web/web/components/Performance/AdvancedDrawer.tsx +++ b/demos/realtime_motion_graph_web/web/components/Performance/AdvancedDrawer.tsx @@ -265,7 +265,7 @@ const SPREAD_SECTIONS: Array<{ id: DrawerTab; label: string }> = [ { id: "core", label: "Core" }, { id: "styles", label: "Styles" }, { id: "mod", label: "Mod" }, - { id: "voice", label: "Channels" }, + { id: "voice", label: "Experimental" }, { id: "config", label: "Config" }, ]; diff --git a/demos/realtime_motion_graph_web/web/components/Performance/DrawerTabs.tsx b/demos/realtime_motion_graph_web/web/components/Performance/DrawerTabs.tsx index fe6a99bc..2bebeedd 100644 --- a/demos/realtime_motion_graph_web/web/components/Performance/DrawerTabs.tsx +++ b/demos/realtime_motion_graph_web/web/components/Performance/DrawerTabs.tsx @@ -27,7 +27,7 @@ export type DrawerTab = (typeof DRAWER_TABS)[number]; const TAB_LABELS: Record = { core: "Core", mod: "Mod", - voice: "Channels", + voice: "Experimental", styles: "Styles", // Auto-generated control surface, rendered straight from the backend // /api/knobs manifest. Reference template for a re-skinned UI. 
diff --git a/demos/realtime_motion_graph_web/web/components/Performance/MobileFullSheet.tsx b/demos/realtime_motion_graph_web/web/components/Performance/MobileFullSheet.tsx index 084fc5ba..acc05387 100644 --- a/demos/realtime_motion_graph_web/web/components/Performance/MobileFullSheet.tsx +++ b/demos/realtime_motion_graph_web/web/components/Performance/MobileFullSheet.tsx @@ -29,7 +29,7 @@ const TABS: { id: Tab; label: string }[] = [ { id: "core", label: "Core" }, { id: "styles", label: "Styles" }, { id: "mod", label: "Mod" }, - { id: "voice", label: "Channels" }, + { id: "voice", label: "Experimental" }, { id: "saved", label: "Saved" }, { id: "config", label: "Config" }, ]; diff --git a/demos/realtime_motion_graph_web/web/components/Performance/ModTile.tsx b/demos/realtime_motion_graph_web/web/components/Performance/ModTile.tsx index 7f64d10f..ae3d4e97 100644 --- a/demos/realtime_motion_graph_web/web/components/Performance/ModTile.tsx +++ b/demos/realtime_motion_graph_web/web/components/Performance/ModTile.tsx @@ -2,7 +2,12 @@ import { usePerformanceStore } from "@/store/usePerformanceStore"; import { useSessionStore } from "@/store/useSessionStore"; -import { DCW_MODES, DCW_WAVELETS, RCFG_MODES, type RcfgMode } from "@/types/engine"; +import { + DCW_MODES, + DCW_WAVELETS, + RCFG_MODES, + type RcfgMode, +} from "@/types/engine"; import { Knob } from "./Knob"; import { defaultLabelFor, kbdHintFor } from "./SliderTile"; diff --git a/demos/realtime_motion_graph_web/web/components/Performance/SliderTile.tsx b/demos/realtime_motion_graph_web/web/components/Performance/SliderTile.tsx index 6d5c38e4..7aa2a7a1 100644 --- a/demos/realtime_motion_graph_web/web/components/Performance/SliderTile.tsx +++ b/demos/realtime_motion_graph_web/web/components/Performance/SliderTile.tsx @@ -63,6 +63,18 @@ const PARAM_TOOLTIPS: Record = { cfg_rescale: "After CFG, mix the guided velocity's magnitude back toward what the positive forward produced. 
0 keeps raw CFG; 1 fully snaps the magnitude. Pair with high guidance_scale to keep the prompt-push without the harshness that high CFG causes on its own.", + // ── Activation steering (auto path) ── + // Each tooltip names the underlying probe cell so the operator can + // recreate the effect on a manual slot. + steer_bright: + "Activation-steering: positive alpha shifts spectral centroid up (brighter, more highs). 0 = off; useful range 5-15 by ear. Recreate as a manual slot: vector brightness_l09_t3 at layer = 9, step = round(3/8 x steps_count).", + steer_warm: + "Activation-steering: positive alpha tilts the spectrum toward bass (warmer). The raw vector points the wrong way for this axis, so this knob folds in a -1 sign. 0 = off; useful range 5-15 by ear. Recreate as a manual slot: vector warmth_l15_t0 at layer = 15, step = 0, then INVERT alpha sign (manual mode is sign-agnostic).", + steer_rough: + "Activation-steering: positive alpha increases spectral flatness (grittier, noisier). Vector magnitude at this probe cell is small, so effect builds slowly. 0 = off; useful range 5-15 by ear. Recreate as a manual slot: vector roughness_l09_t3 at layer = 9, step = round(3/8 x steps_count).", + steer_density: + "Activation-steering: positive alpha thins the texture toward sparse/minimal. Inject layer is shifted 3 shallower than the probe layer (Phase-3 transfer finding). 0 = off; useful range 5-15 by ear. Recreate as a manual slot: vector density_l18_t3 at layer = 15 (probe 18 minus 3), step = round(3/8 x steps_count).", + // ── DCW ── dcw_scaler: "Experimental — adjusts the low-band strength of an internal correction the model applies to itself during generation (DCW). This scaler is active in the early part of the run. The exact audio mapping is still being explored — sweep it to discover what it does for your source. 
Extreme values can be unpredictable but cool.", @@ -97,6 +109,19 @@ export function tooltipFor(param: string): string | undefined { if (param === "lora_blend") { return "Crossfade between LoRA A and LoRA B. 0 = A only, 1 = B only, 0.5 = both at half strength. Use this to morph between two styles smoothly."; } + // Manual steering tooltips share copy across all slots. + if (param.startsWith("man_src_")) { + return "Catalog index of the steering vector this slot fires. The catalog enumerates every pre-built (axis, build_layer, build_step) cell on disk in stable axis-major order. Double-click the readout to type an exact index; query the MCP list_manual_steering_vectors tool for the full table. Has no effect until α is non-zero."; + } + if (param.startsWith("man_layer_")) { + return "DiT inject layer (0-23). The vector is added to this layer's post-block residual. Bypasses the auto path's density layer offset — the value lands exactly where you point it."; + } + if (param.startsWith("man_step_")) { + return "Diffusion inject step (0-15). Bypasses the auto path's fractional step mapping; the engine fires the injection only on the step that matches this value. If you pick a step past the current steps count - 1, the slot stays silent until you raise the step count."; + } + if (param.startsWith("man_alpha_")) { + return "Strength of this manual slot's injection. 0 disables the slot. Negative α inverts the vector's direction at injection time (no sign correction is applied; what you set is what the engine receives). 
Sweep range and breakage point mirror the perceptual steering knobs."; + } return PARAM_TOOLTIPS[param]; } diff --git a/demos/realtime_motion_graph_web/web/components/Performance/VoiceTile.tsx b/demos/realtime_motion_graph_web/web/components/Performance/VoiceTile.tsx index 7beeb478..17bf11ea 100644 --- a/demos/realtime_motion_graph_web/web/components/Performance/VoiceTile.tsx +++ b/demos/realtime_motion_graph_web/web/components/Performance/VoiceTile.tsx @@ -3,7 +3,9 @@ import { useEffect, useState } from "react"; import { useConfig } from "@/lib/config"; +import { useSessionStore } from "@/store/useSessionStore"; +import { Knob } from "./Knob"; import { SliderGroup } from "./SliderGroup"; import { defaultLabelFor, kbdHintFor } from "./SliderTile"; @@ -31,6 +33,15 @@ const MORPH = ["ch13", "ch14", "ch19", "ch23", "ch29", "ch56"]; export function VoiceTile() { const ranges = useConfig().channel_ranges; + const manualSlotCount = useSessionStore((s) => s.manualSlotCount); + const manualSlotCap = useSessionStore((s) => s.manualSlotCap); + const steeringAvailable = useSessionStore((s) => s.steeringAvailable); + const remote = useSessionStore((s) => s.remote); + const slotCount = manualSlotCount ?? 0; + const slotCap = manualSlotCap ?? 0; + const canAddSlot = remote !== null && slotCap > 0 && slotCount < slotCap; + const canPopSlot = remote !== null && slotCount > 0; + const showSteering = steeringAvailable === true; // Experimental-feature notice — dismissable, and the dismissal sticks // across reloads. Read after mount (not in the useState initializer) // so a localStorage read can't break SSR hydration. @@ -69,9 +80,12 @@ export function VoiceTile() {

)} + {/* Two-column grid shared by the channel rows and the steering + rows so column 1 (highlights / steering) and column 2 (groups + / manual steering) line up across rows. */}
-
Highlights
+
channel highlights
{MORPH.map((p) => { const r = ranges[p]; @@ -92,7 +106,7 @@ export function VoiceTile() {
); diff --git a/demos/realtime_motion_graph_web/web/hooks/useStartSession.ts b/demos/realtime_motion_graph_web/web/hooks/useStartSession.ts index 7afb0eb9..56fef257 100644 --- a/demos/realtime_motion_graph_web/web/hooks/useStartSession.ts +++ b/demos/realtime_motion_graph_web/web/hooks/useStartSession.ts @@ -193,12 +193,22 @@ function wireRemoteListeners( // ungated). The hand-coded swap/timbre/structure/LoRA panels read // it via useCapability. s.setCapabilities(remote.capabilities); + // Activation-steering surface (null on servers without it = tiles + // hidden). VoiceTile reads these to render the steering racks. + s.setManualSlotCount(remote.manualSlotCount); + s.setManualSlotCap(remote.manualSlotCap); + s.setSteeringAvailable(remote.steeringAvailable); }); remote.addEventListener("depth_applied", (e) => { useSessionStore .getState() .setPipelineDepth((e as CustomEvent).detail); }); + remote.addEventListener("manual_slot_count", (e) => { + useSessionStore + .getState() + .setManualSlotCount((e as CustomEvent).detail); + }); // WS startup telemetry. The SDK only observes its own connection // (wsTrace + the optional init_ack echo); persisting the latest trace // for the debug surface is the app's job. Attached before connect() diff --git a/demos/realtime_motion_graph_web/web/sdk/protocol.ts b/demos/realtime_motion_graph_web/web/sdk/protocol.ts index 86f6f5a2..ea2eb25a 100644 --- a/demos/realtime_motion_graph_web/web/sdk/protocol.ts +++ b/demos/realtime_motion_graph_web/web/sdk/protocol.ts @@ -38,6 +38,8 @@ import type { DisableLoraCommand, EnableLoraCommand, LoopBandCommand, + ManualSlotAddCommand, + ManualSlotPopCommand, ParamsCommand, PromptCommand, SetDepthCommand, @@ -221,6 +223,15 @@ export class RemoteBackend extends EventTarget { * on older servers / replays; `/api/knobs` remains the static * pre-session probe. 
*/ knobManifest: KnobManifestResponse | null = null; + /** Active manual steering slot count, mirrored from the server + * (`ready` + `manual_slot_count` echoes). Null until ready / on + * servers without the steering surface. */ + manualSlotCount: number | null = null; + /** Server-imposed cap on manual steering slots. Null until ready. */ + manualSlotCap: number | null = null; + /** Whether the session's checkpoint has steering vectors. The host + * hides the steering tiles when false. Null until ready. */ + steeringAvailable: boolean | null = null; /** Browser-observed WS lifecycle for this concrete connection attempt. */ wsTrace: WsTrace; /** Pod-side session id from the optional init_ack telemetry message. */ @@ -448,6 +459,20 @@ export class RemoteBackend extends EventTarget { (msg.capabilities as CapabilityMask | undefined) ?? null; this.knobManifest = (msg.knob_manifest as KnobManifestResponse | undefined) ?? null; + // Activation-steering surface. Wire-optional like the + // Phase-2 fields: null hides the steering tiles host-side. + this.manualSlotCount = + typeof msg.manual_slot_count === "number" + ? msg.manual_slot_count + : null; + this.manualSlotCap = + typeof msg.manual_slot_cap === "number" + ? msg.manual_slot_cap + : null; + this.steeringAvailable = + typeof msg.steering_available === "boolean" + ? msg.steering_available + : null; // Scale + depth bounds are exposed as instance fields; the host // app mirrors them into its own state from the "ready" event // listener (the SDK never writes app stores). @@ -651,6 +676,16 @@ export class RemoteBackend extends EventTarget { new CustomEvent("command_failed", { detail: msg }), ); break; + case "manual_slot_count": { + // Echoed after manual_slot_add / manual_slot_pop (success + // or refusal). The host mirrors it into its own state. + const v = typeof msg.count === "number" ? 
msg.count : null; + this.manualSlotCount = v; + this.dispatchEvent( + new CustomEvent("manual_slot_count", { detail: v }), + ); + break; + } default: this.dispatchEvent(new CustomEvent("json", { detail: msg })); } @@ -930,6 +965,25 @@ export class RemoteBackend extends EventTarget { } catch {} } + /** Add the next manual steering slot (LIFO). Server echoes + * ``manual_slot_count`` on success or refusal. */ + sendManualSlotAdd(): void { + if (this.ws?.readyState !== WebSocket.OPEN) return; + try { + const msg: ManualSlotAddCommand = { type: "manual_slot_add" }; + this.ws.send(JSON.stringify(msg)); + } catch {} + } + + /** Pop the highest-numbered manual steering slot. */ + sendManualSlotPop(): void { + if (this.ws?.readyState !== WebSocket.OPEN) return; + try { + const msg: ManualSlotPopCommand = { type: "manual_slot_pop" }; + this.ws.send(JSON.stringify(msg)); + } catch {} + } + /** * Live timbre-strength knob. Backend keeps a cached * (cond_silence, cond_full) pair and lerp-blends their encoder hidden diff --git a/demos/realtime_motion_graph_web/web/sdk/types/wireContract.gen.ts b/demos/realtime_motion_graph_web/web/sdk/types/wireContract.gen.ts index ef1ba721..cf1c9dc6 100644 --- a/demos/realtime_motion_graph_web/web/sdk/types/wireContract.gen.ts +++ b/demos/realtime_motion_graph_web/web/sdk/types/wireContract.gen.ts @@ -29,6 +29,8 @@ export type CommandName = | "set_depth" | "enable_lora" | "disable_lora" + | "manual_slot_add" + | "manual_slot_pop" | "set_timbre_strength" | "set_timbre_source" | "set_timbre_fixture" @@ -47,6 +49,8 @@ export const COMMAND_NAMES: readonly CommandName[] = [ "set_depth", "enable_lora", "disable_lora", + "manual_slot_add", + "manual_slot_pop", "set_timbre_strength", "set_timbre_source", "set_timbre_fixture", @@ -71,6 +75,7 @@ export type EventName = | "stem_assets" | "stem_failed" | "depth_applied" + | "manual_slot_count" | "timbre_set" | "timbre_cleared" | "timbre_failed" @@ -93,6 +98,7 @@ export const EVENT_NAMES: readonly 
EventName[] = [ "stem_assets", "stem_failed", "depth_applied", + "manual_slot_count", "timbre_set", "timbre_cleared", "timbre_failed", @@ -172,6 +178,14 @@ export interface DisableLoraCommand { id: string; } +export interface ManualSlotAddCommand { + type: "manual_slot_add"; +} + +export interface ManualSlotPopCommand { + type: "manual_slot_pop"; +} + export interface SetTimbreStrengthCommand { type: "set_timbre_strength"; /** 1.0 = full reference, 0.0 = silence baseline. Clamped to [0,1]. */ @@ -258,6 +272,12 @@ export interface ReadyEvent { capabilities?: Record; /** Per-session knob manifest: the same {version, knobs} envelope GET /api/knobs serves, but backend-owned and session-resolved (SDE mode, enabled lora_str_ knobs). /api/knobs remains the static pre-session probe. */ knob_manifest?: Record; + /** Active manual steering slot count; drives the client's man_*_ row rendering. Updated live via the manual_slot_count event. */ + manual_slot_count?: number; + /** Server-imposed ceiling on manual steering slots; gates the client's + button. */ + manual_slot_cap?: number; + /** True when the session's checkpoint has a reachable steering-vector bundle; false hides the steering surface (the steer_*\/man_* knobs are absent from the manifest too). */ + steering_available?: boolean; } export interface ErrorEvent { @@ -336,6 +356,12 @@ export interface DepthAppliedEvent { value: number; } +export interface ManualSlotCountEvent { + type: "manual_slot_count"; + /** The live manual steering slot count after the command. 
*/ + count: number; +} + export interface TimbreSetEvent { type: "timbre_set"; name: string; @@ -444,6 +470,8 @@ export type WireCommand = | SetDepthCommand | EnableLoraCommand | DisableLoraCommand + | ManualSlotAddCommand + | ManualSlotPopCommand | SetTimbreStrengthCommand | SetTimbreSourceCommand | SetTimbreFixtureCommand @@ -467,6 +495,7 @@ export type WireEvent = | StemAssetsEvent | StemFailedEvent | DepthAppliedEvent + | ManualSlotCountEvent | TimbreSetEvent | TimbreClearedEvent | TimbreFailedEvent diff --git a/demos/realtime_motion_graph_web/web/store/usePerformanceStore.ts b/demos/realtime_motion_graph_web/web/store/usePerformanceStore.ts index da967b59..7eb01c50 100644 --- a/demos/realtime_motion_graph_web/web/store/usePerformanceStore.ts +++ b/demos/realtime_motion_graph_web/web/store/usePerformanceStore.ts @@ -17,6 +17,11 @@ import { type TimeSignature, } from "@/types/engine"; +// Pre-seed defaults for every potential manual steering slot so reset / +// snapback / curve hooks never miss a key when the server allocates a +// new slot at runtime. Must match the prereg ceiling in engine.ts. +const MANUAL_SLOT_PREREG = 16; + // Top-level performance state. Mirrors app.js's module-level vars // (sliderValues, seedValue, blendValue, activeKey, prompts, fixture, mode, // kiosk). LoRA strength sliders live in useLoraStore. @@ -259,6 +264,11 @@ const DEFAULT_SLIDER_VALUES: Record = { ch23: 1.0, ch29: 1.0, ch56: 1.0, + steer_bright: 0.0, + steer_warm: 0.0, + steer_rough: 0.0, + steer_density: 0.0, + // Manual steering slot defaults are filled in below the object literal. 
dcw_scaler: 0.05, dcw_high_scaler: 0.02, dcw_mult_blend: 0.0, @@ -272,6 +282,13 @@ const DEFAULT_SLIDER_VALUES: Record = { steps_override: 8, }; +for (let slot = 1; slot <= MANUAL_SLOT_PREREG; slot++) { + DEFAULT_SLIDER_VALUES[`man_src_${slot}`] = 0; + DEFAULT_SLIDER_VALUES[`man_layer_${slot}`] = 9; + DEFAULT_SLIDER_VALUES[`man_step_${slot}`] = 0; + DEFAULT_SLIDER_VALUES[`man_alpha_${slot}`] = 0.0; +} + /** A re-applyable record of an active timbre / structure reference. * `timbreName` / `structName` are the server-acked DISPLAY name * (cleared whenever the session leaves "ready"); this is the client's diff --git a/demos/realtime_motion_graph_web/web/store/useSessionStore.ts b/demos/realtime_motion_graph_web/web/store/useSessionStore.ts index 42a125c7..3567def3 100644 --- a/demos/realtime_motion_graph_web/web/store/useSessionStore.ts +++ b/demos/realtime_motion_graph_web/web/store/useSessionStore.ts @@ -63,6 +63,15 @@ interface SessionState { lastBackendSessionId: string | null; /** Client id echoed by init_ack. */ lastBackendClientId: string | null; + /** Active manual steering slot count, mirrored from the server. + * Null until ready. */ + manualSlotCount: number | null; + /** Server-imposed cap on manual steering slots. Drives the +/- enable + * state in VoiceTile. Null until ready. */ + manualSlotCap: number | null; + /** Whether the session's checkpoint has steering vectors. When false, + * VoiceTile hides both steering tiles. Null until ready.
*/ + steeringAvailable: boolean | null; setStatus: (status: SessionStatus, message?: string) => void; setSession: (remote: RemoteBackend | null, player: AudioPlayer | null) => void; @@ -76,6 +85,9 @@ interface SessionState { setLastWsTrace: (trace: WsTrace | null) => void; setLastBackendSessionId: (id: string | null) => void; setLastBackendClientId: (id: string | null) => void; + setManualSlotCount: (count: number | null) => void; + setManualSlotCap: (cap: number | null) => void; + setSteeringAvailable: (available: boolean | null) => void; reset: () => void; } @@ -94,6 +106,9 @@ export const useSessionStore = create((set, get) => ({ lastWsTrace: null, lastBackendSessionId: null, lastBackendClientId: null, + manualSlotCount: null, + manualSlotCap: null, + steeringAvailable: null, setStatus: (status, message = "") => set({ status, message }), setSession: (remote, player) => set({ remote, player }), @@ -107,6 +122,9 @@ export const useSessionStore = create((set, get) => ({ setLastWsTrace: (trace) => set({ lastWsTrace: trace }), setLastBackendSessionId: (id) => set({ lastBackendSessionId: id }), setLastBackendClientId: (id) => set({ lastBackendClientId: id }), + setManualSlotCount: (count) => set({ manualSlotCount: count }), + setManualSlotCap: (cap) => set({ manualSlotCap: cap }), + setSteeringAvailable: (available) => set({ steeringAvailable: available }), reset: () => { try { get().monitor?.stop(); @@ -129,6 +147,9 @@ export const useSessionStore = create((set, get) => ({ lastWsTrace: null, lastBackendSessionId: null, lastBackendClientId: null, + manualSlotCount: null, + manualSlotCap: null, + steeringAvailable: null, // checkpointScale survives reset on purpose: the server's // checkpoint doesn't change across sessions, and pre-fetching // it from /api/loras lets the library filter render correctly diff --git a/demos/realtime_motion_graph_web/web/types/engine.ts b/demos/realtime_motion_graph_web/web/types/engine.ts index a4ad2bb1..6dc7469f 100644 --- 
a/demos/realtime_motion_graph_web/web/types/engine.ts +++ b/demos/realtime_motion_graph_web/web/types/engine.ts @@ -74,6 +74,16 @@ export const SLIDER_META: Record = { ch29: { max: 3.0, step: 0.15, pro: true }, ch56: { max: 3.0, step: 0.15, pro: true }, + // Activation-steering. ±30 range so the operator can invert the axis + // without leaving the surface. Useful magnitude ~5..15 by ear. + steer_bright: { min: -30.0, max: 30.0, step: 0.5, pro: true }, + steer_warm: { min: -30.0, max: 30.0, step: 0.5, pro: true }, + steer_rough: { min: -30.0, max: 30.0, step: 0.5, pro: true }, + steer_density: { min: -30.0, max: 30.0, step: 0.5, pro: true }, + + // Manual steering slots (man_*_) are pre-registered in the loop + // below so a runtime slot add doesn't have to mutate SLIDER_META. + // DCW (wavelet-domain post-step correction). Numeric knobs only; the // boolean ON/OFF + mode + wavelet choices live in their own panel state. // @@ -92,6 +102,19 @@ export const SLIDER_META: Record = { dcw_soft_thresh: { max: 0.3, step: 0.01, pro: true }, }; +// SLIDER_META pre-reg ceiling. Authoritative cap is the server's +// ``manual_slot_cap``; raising the server cap above this means bumping +// this number too. 
+const MANUAL_SLOT_PREREG = 16; +for (let slot = 1; slot <= MANUAL_SLOT_PREREG; slot++) { + SLIDER_META[`man_src_${slot}`] = { min: 0, max: 143, step: 1, pro: true }; + SLIDER_META[`man_layer_${slot}`] = { min: 0, max: 23, step: 1, pro: true }; + SLIDER_META[`man_step_${slot}`] = { min: 0, max: 15, step: 1, pro: true }; + SLIDER_META[`man_alpha_${slot}`] = { + min: -30.0, max: 30.0, step: 0.5, pro: true, + }; +} + export const DCW_MODES = ["low", "high", "double", "pix"] as const; export const DCW_WAVELETS = ["haar", "db4", "sym8", "db8"] as const; export type DcwMode = (typeof DCW_MODES)[number]; diff --git a/demos/realtime_motion_graph_web/ws_adapter.py b/demos/realtime_motion_graph_web/ws_adapter.py index aa6c1b98..f40f7169 100644 --- a/demos/realtime_motion_graph_web/ws_adapter.py +++ b/demos/realtime_motion_graph_web/ws_adapter.py @@ -60,6 +60,7 @@ CommandFailed, DepthApplied, LoraCatalogUpdate, + ManualSlotCount, ParamsEcho, PromptApplied, PromptBlendEcho, @@ -574,6 +575,8 @@ def on_event(event) -> None: _send_json({"type": "lora_catalog", "catalog": event.catalog}) elif isinstance(event, DepthApplied): _send_json({"type": "depth_applied", "value": event.value}) + elif isinstance(event, ManualSlotCount): + _send_json({"type": "manual_slot_count", "count": event.count}) elif isinstance(event, CommandFailed): _send_json({ "type": "command_failed", @@ -638,6 +641,9 @@ def on_event(event) -> None: "geometry": streaming.geometry_payload(), "capabilities": streaming.capabilities_payload(), "knob_manifest": streaming.knob_manifest_payload(), + # Activation-steering surface (manual_slot_count / + # manual_slot_cap / steering_available). 
+ **streaming.steering_payload(), })) ws.send(src_np.astype(np.float16).tobytes()) if streaming.initial_upload_stems is not None: @@ -788,6 +794,10 @@ def _dispatch_message( lid = data.get("id") if lid: streaming.disable_lora(str(lid), origin=origin) + elif mtype == "manual_slot_add": + streaming.manual_slot_add(origin=origin) + elif mtype == "manual_slot_pop": + streaming.manual_slot_pop(origin=origin) elif mtype == "set_timbre_strength": try: v = float(data.get("value", 1.0))