Skip to content

Stable Audio 3 as a ModelAdapter family#231

Draft
ryanontheinside wants to merge 7 commits into
mainfrom
ryanontheinside/feat/models/sa3-on-backend
Draft

Stable Audio 3 as a ModelAdapter family#231
ryanontheinside wants to merge 7 commits into
mainfrom
ryanontheinside/feat/models/sa3-on-backend

Conversation

@ryanontheinside

@ryanontheinside ryanontheinside commented Jun 6, 2026

Copy link
Copy Markdown
Collaborator

DRAFT: Stable Audio 3 as a ModelAdapter family

Merge order

Draft, stacked: opened against ryanontheinside/feat/models/backend. Merge #224 -> #220 -> models/backend before this.
Stack: main <- #224 <- #220 <- models/backend <- this (draft). Sibling leaf: mrt2 (draft, same parent).
Heads-up for whoever merges the second leaf: families.py will conflict (both leaves register their family in the same dicts); streaming/session.py resolves clean apart from one comment line.

Stability AI's Stable Audio 3 (small-music / medium-music) as a diffusion family behind the ModelAdapter seam, sharing StreamPipeline and the runner with ACE. This is the rectified-flow proof of the seam: nothing forks, the adapter seats SA3's schedule/noise/forward inside the existing pipeline.

What's here (seven commits)

7e2275e SA3 family. sa3_adapter.py (ModelAdapter), sa3_context.py (one loaded model per model_id per process, shared across sessions; includes SA3SAMECodec: full SAME decode per fresh latent at ~11 ms, then one 44.1->48k resample, windows slice the cache), sa3_backend.py (capabilities: refines_audio only; geometry declares delivered 48 kHz), sa3_helpers.py (import bridge to the scripts/sa3 helpers, no duplication).

74b494d Serving layer. create_sa3_session in families.SESSION_CREATORS plus acestep/streaming/sa3_session.py: process-cached context, source SAME-encoded once as the audio-to-audio anchor (continuity mechanism: every emit is a partial-denoise cover of that latent), conditioning prepared once per create, prompt re-captures via handle_set_prompt. Checkpoint aliases (sa3-small -> small-music) via families.resolve_checkpoint. /sa3 web page on the @demon/client SDK, capability-gated.

a633f66 Seeded per-slot SDE renoise (opt-in). SlotRequest.sde_noise_seeded: the bare-SDE renoise draws from a per-slot generator seeded by the request seed, so the SDE trajectory replays exactly per (seed, denoise, schedule) and advancing playback windows splice into one coherent song. Default off. Note for reviewers: the unseeded path must stay randn_like, NOT randn(shape). xt is a transposed non-contiguous view, and the two arrange identical global-RNG draws differently, which silently changes every ACE SDE trajectory. This was caught by the parity rail during the restack and fixed in this commit; the parity suite now passes bit-identical.

c5c336c medium-music. TRT runtimes for the DiT and the SAME-L windowed decoder, plus defaults from live debugging: sampler is pingpong (SA3 checkpoints are rf_denoiser; Euler at full denoise is audibly broken), SA3 wire slice is 3.0 s (never ACE's 0.36 vae_window). At this commit the TRT DiT was opt-in via DEMON_SA3_TRT_DIT=1 behind an undiagnosed real-conditioning parity gap (cos 0.80-0.97/step vs eager); the two commits below diagnose and remove it.

9df7292 Convention alignment. Adopts the mrt2 branch's integration mechanics with textually identical session.py hunks (SESSION_CREATORS dispatch, prompt hooks, cond_pair=None sentinel, optional engine_session) so the two leaves merge cleanly. The create factory uses the same ExitStack/pop_all shape as the ACE path; the process-cached SA3Context deliberately does NOT register on it (closing it on a failed create would break the cache for every later session).

f00557b fp16mixed TRT DiT + builder into the package. The parity gap was the BF16 build recipe upstream explicitly rejects (its quantization error compounds over the 8 pingpong steps); rebuilding from the pre-surgered dit_fp16mixed.onnx (FP16 trunk, FP32 islands around RMSNorm / softmax / RoPE) as a STRONGLY_TYPED network restores per-step cos >= 0.9998 vs eager on real conditioning, matching upstream's documented numbers. The bounded residual (cos 0.991-0.999/step at durations whose padded window has masked tail frames) is the engine's missing padding_mask input, structural and documented in the parity script, not numerics. Engine creation moves to acestep.engine.trt.sa3_build, holding the ACE builder's shape: shared preflight, metadata-sidecar skip/rebuild gates, canonical --all matrix, build report; the scripts/sa3 build scripts become thin forwarders.

b593d82 Standard accel params. With parity closed, the DEMON_SA3_TRT_DIT opt-in is gone: SA3 sessions consume the same decoder_backend / vae_backend values as ACE (server --accel). tensorrt selects the built DiT engine, the TRT SAME-L window codec, and the engine-cap duration clamp; compile degrades loudly to eager (SA3 has no torch.compile path); eager stays fully eager. Threaded create_sa3_session -> backend_init -> families._make_sa3 -> SA3Backend.from_context(dit_backend, codec_backend) -> SA3Context.make_dit / make_codec.

Status / why draft

Audio confirmed by ear on small and medium through the real web app; TRT DiT parity diagnosed and closed (see f00557b). Remaining before ready-for-review: golden-harness scenarios for the family and pod deployment defaults.

Validation (2026-06-06, restacked tip)

169 unit tests passing at this tip (SA3 adapter/backend/stream-pipeline suites, the ACE bit-identity parity rail, all inherited guards), web typecheck clean.

SA3 rides the shared StreamPipeline through the Tier-2 seam, built on
the validated spike-branch helpers (ryanontheinside/feat/stable-audio-3,
ported additively into scripts/sa3/ and reused via the sa3_helpers
import bridge — no duplication):

- SA3Adapter: one batched dit(x, t, **stacked_cond) per tick (the
  spike's stack_sa3_cond_bundles), [B,T,C] <-> [B,C,T] transpose at the
  adapter boundary, SA3 build_schedule wiring (sigma_max = denoise).
- SA3Context: load_local_model (bundled-t5gemma local loader) +
  per-prompt prepare_sa3_conditioning, private to the context.
  SA3SAMECodec decodes the FULL latent per fresh generation (SAME-S is
  ~11 ms flat to 60 s) so window renders slice one cached buffer.
- SA3Backend (DiffusionBackend subclass): v1 surface = prompt, fixed
  duration, seed, steps_override, sa3_denoise (init_noise_level; the
  prefix is load-bearing — ACE's denoise is a different control, the
  homonym guard enforces the rule). Capabilities: refines_audio only.
  Delivery resamples 44.1k -> 48k at the decode boundary (decision 2);
  geometry declares the delivered rate, 48000, in v1.
- families: sa3 registered in FAMILIES + FAMILY_KNOB_UNIVERSES (homonym
  test now bites across two real families); checkpoint aliases resolve
  to (backend, model_id) — "xl" unchanged, "sa3-small" -> ("sa3",
  "small-music"); server threads the resolved family into the session
  config default.
- SessionConfig.sa3_duration_s (flat prefixed family field); wire types
  regenerated. StreamingSession.create rejects non-acestep families
  before any model loading — the per-family serving path is canonical
  plan Phase 3.

Validated in-process on the 5090: small-music streams end to end
through the shared pipeline (emit every 2nd tick at depth 4 / steps 8,
~55 ms/tick eager fp16); seam mechanics covered by mock-DiT unit tests
(engine-layout emission, cross-attn padding, source-anchored partial
denoise, hand-rolled Euler equivalence, loud failures); ported spike
unit tests skip their vendor-parity cases when the vendored
stable_audio_3 tree is absent (DEMON_SA3_SRC overrides).
…3 web page

The serving path for sa3 sessions (canonical plan Phase 3), file-disjoint
from the engine seam underneath:

- sa3_session.create_sa3_session: builds a StreamingSession with NO ACE
  engine stack (no engine Session / stream / TRT profile manager / LoRA
  state) — SessionState carries neutral fields, AudioEngine seeds from
  the decoded source anchor, and backend_init stashes the per-family
  construction payload (context, precomputed cond, encoded source) for
  the registry factory. SA3Context is process-cached per model_id (one
  load per process; concurrent first sessions wait on the lock).
- StreamingSession: create() dispatches by family; close() tolerates the
  None engine/stream of non-ACE families; set_prompt dispatches to a
  backend-owned set_prompt when the backend exposes one (SA3 re-runs
  prepare_cond and swaps the bundle; ACE path unchanged).
- families: _make_sa3 registry factory + warmup_policy — the synthetic
  startup warmup is ACE-shaped, so server.py runs it only for
  policy == "ace_trt" and sa3 pays its one-time cost on first session.
- sa3_helpers.require_sa3_vendor: actionable ImportError (DEMON_SA3_SRC
  remedy) instead of a deep ModuleNotFoundError when the vendored
  stable_audio_3 tree is absent; SA3Context calls it at construction.
- web: /sa3 demo page (panel + session hook + styles) over the existing
  ready/slice protocol.
SlotRequest.sde_noise_seeded: when set (and seed is an int), the bare-SDE
renoise draws from a per-slot torch.Generator seeded by the request seed
instead of the global RNG, so the SDE trajectory replays exactly per
(seed, denoise, schedule) — identical consecutive requests emit identical
latents and advancing playback windows splice into one coherent song.
This is the SA3 spike pipeline's per-slot-generator semantics, needed
because rf_denoiser checkpoints must be sampled with pingpong (see the
sa3 commit on top).

Default off; generator=None falls through to the global RNG, so ACE
behavior is byte-identical.
… from live debugging

Medium joins the family: TRT engine wrappers for the spike-built
artifacts, the windowed codec medium needs (eager full decode is ~80 ms
per call — too slow per render tick), and the production defaults the
live web-app debugging session settled.

- sa3_trt.py: SA3TRTDit (batch-1 DiT engine, per-session fixed L +
  duration, persistent bound buffers, cond bundle staged on identity
  change) + SameLWindowTRTDecoder (samel::diff_attn_swa plugin,
  pretransform-scaled latents, int16-PCM engine flavor) + engine
  discovery over <MODELS_DIR>/sa3/trt_engines. Cond staging strips the
  trailing seconds_total token: cross_attn_cond is the
  ["prompt", "seconds_total"] concat (257 tokens) while the engine
  takes 256 raw T5Gemma tokens + the seconds scalar and rebuilds the
  token in-graph.
- SA3Adapter: batch-1 TRT engines loop ring-buffer slots through
  step_bundle (trt_batch1 marker) instead of one stacked forward;
  output buffer materialized per slot.
- SA3Context.make_dit/make_codec: per-session component selection.
  SA3SAMEWindowCodec: 2 s context each side, slice_align=1, engine
  profile floor growth near song edges, eager decode_sa3_latent_window
  fallback. clamp_duration_for_trt keeps medium sessions on the engine
  fast path. families: "sa3-medium" alias.
- Sampler default is PINGPONG with seeded per-slot renoise
  (sde_noise_seeded): SA3 checkpoints are diffusion_objective
  rf_denoiser — upstream samples them with pingpong only, and 8-step
  euler at sa3_denoise=1.0 is audibly degraded (the web-app
  "sounds horrible" root cause). ODE remains as a debug mode.
- SA3_VAE_WINDOW_S = 3.0 wire slices (the reference web demo's value);
  ACE's 0.36 SessionConfig default must not leak into sa3 sessions.
- TRT DiT is opt-in (DEMON_SA3_TRT_DIT=1) for now: it is not
  step-parity-identical to the eager DiT on real conditioning
  (cos 0.80-0.97/step, scripts/sa3/sa3_trt_dit_cond_parity.py —
  the spike only ever speed-benchmarked random tensors); eager is the
  validated-by-ear default until that is diagnosed. The SAME-L window
  decoder stays on.
- scripts/sa3: the engine build scripts + the real-cond DiT parity
  rail. The spike's one-off benches/profilers/sweeps stay untracked.
Functional no-op. The parallel mrt2 branch
(ryanontheinside/feat/models/mrt2-on-backend) grew the cleaner
session.py / families.py integration mechanics; this adopts them on the
sa3 side with hunks textually identical to that branch's, so the merge
auto-resolves instead of leaving two names for every concept:

- families.SESSION_CREATORS registry replaces the hardcoded
  'if family == "sa3"' create dispatch; create_sa3_session takes the
  creator contract signature (cls, *, audio, config, checkpoint,
  session_id, **rest).
- Backend prompt hook renamed set_prompt -> handle_set_prompt
  (+ handle_set_prompt_blend dispatch in set_prompt_blend); the
  session now calls the handler OUTSIDE state._lock — the sa3 rebuild
  is a GPU T5Gemma encode and was previously run under the lock.
- cond_pair sentinel: None (with the _refresh_conditioning guard)
  instead of (None, None) — ungated incidental callers
  (set_interp_method, prompt_blend) are now safe no-ops on sa3
  sessions instead of a latent blend_for_strength crash.
- close(): backend-owned close() hook before stream/session teardown.
- engine_session: Session | None annotation.

The only intentional divergence from the mrt2 text is one comment line
citing sa3 as the hook example instead of mrt2.
…o the package

The medium TRT DiT real-conditioning gap (cos 0.80-0.97/step) was the
BF16 build recipe upstream explicitly rejected: its quantization error
compounds over the 8 pingpong steps. Rebuilding from the pre-surgered
dit_fp16mixed.onnx (FP16 trunk, FP32 islands around RMSNorm / softmax /
RoPE) as a STRONGLY_TYPED network restores cos >= 0.9998/step on real
conditioning, matching upstream's documented numbers. The residual
0.99x at short durations is the engine's missing padding_mask input
(structural, documented in the parity script), not numerics.

Engine creation moves to acestep.engine.trt.sa3_build, holding the
ACE builder's shape: shared preflight, metadata-sidecar skip/rebuild
gates, canonical --all matrix (DiT latent profiles 324 + 646, SAME-L
window t32_56_96), build_report.csv. The scripts/sa3 build scripts
become thin forwarders into it.
With TRT-vs-eager parity closed, SA3 sessions consume the same
decoder_backend / vae_backend acceleration values as ACE (server
--accel): tensorrt selects the built DiT engine, the TRT SAME-L window
codec, and the engine-cap duration clamp; compile degrades loudly to
eager (SA3 has no torch.compile path); eager stays fully eager.

Threading mirrors the ACE path: create_sa3_session resolves the values
and stashes them in backend_init; families._make_sa3 hands them to
SA3Backend.from_context(dit_backend, codec_backend), which maps them
onto SA3Context.make_dit / make_codec. Verified live on a loaded
medium context: all three accel values resolve to the right
components.
Base automatically changed from ryanontheinside/feat/models/backend to main June 8, 2026 13:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant