Motivation.
Pain points before the refactor:
- Denoising logic was split across large, model-specific stage classes.
- Adding new pipelines required more conditionals in shared stages.
- Denoising code was difficult to test in isolation.
- Features like CFG-parallel and SP sharding had no consistent integration
point.
The new engine + strategy design centralizes the denoising loop while allowing
per-model behavior in strategies. Hooks provide a clean extension path for
performance and distributed features.
Proposed Change.
Summary
This RFC formalizes the denoising refactor and captures the current
implementation state. The goal is a single shared denoising engine with small,
per-model strategy hooks, plus an engine-level hook system for cross-cutting
concerns (perf logging, CFG-parallel, SP sharding, cache-dit, scheduler client
patterns). Legacy denoising stage wrappers are removed in favor of
DenoisingStage(strategy_cls=...), with MatrixGame keeping its streaming
wrapper.
Goals
- One shared denoising loop for Standard, Cosmos, LongCat, DMD, and others.
- Block strategies for causal and MatrixGame (KV caches, streaming).
- Make per-model differences explicit and testable via strategies.
- Provide engine hooks for perf logging and future distributed features.
- Preserve behavior via parity tests and SSIM tests.
Non-goals
- Rewriting training-time denoising or loss computation.
- Changing model architecture internals.
- Full adoption of the scheduler adapter across all strategies in this RFC
(adapter is introduced and optional).
Design Overview
Engine and Strategy Protocols
The engine owns the loop; strategies implement model-specific details.
```python
class DenoisingStrategy(Protocol):
    # Per-model behavior invoked by the shared engine at each denoising step.
    def prepare(self, batch, args) -> StrategyState: ...
    def make_model_inputs(self, state, t, step_idx) -> ModelInputs: ...
    def forward(self, state, model_inputs) -> torch.Tensor: ...
    def cfg_combine(self, state, noise_pred) -> torch.Tensor: ...
    def scheduler_step(self, state, noise_pred, t) -> torch.Tensor: ...
    def postprocess(self, state) -> ForwardBatch: ...


class BlockDenoisingStrategy(DenoisingStrategy):
    # Extension for block-based denoising (causal and MatrixGame streaming).
    def block_plan(self, state) -> BlockPlan: ...
    def init_block_context(self, state, block_item, block_idx) -> BlockContext: ...
    def process_block(self, state, block_ctx, block_item) -> None: ...
    def update_context(self, state, block_ctx, block_item) -> None: ...
```

Stage Wrapper
Pipelines use a thin wrapper that selects the strategy:
```python
stage = DenoisingStage(
    transformer=...,
    scheduler=...,
    strategy_cls=StandardStrategy,
)
```

MatrixGame keeps a streaming wrapper but uses the same engine for blocks.
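For illustration, here is a rough sketch of how the engine's block loop might drive a BlockDenoisingStrategy; the actual control flow lives in denoising_engine.py and may differ, BlockPlan is assumed to be iterable over per-block work items, and hook calls are elided:

```python
# Illustrative sketch only; the real block loop in denoising_engine.py may differ.
def run_blocks(strategy: BlockDenoisingStrategy, batch, args):
    state = strategy.prepare(batch, args)
    for block_idx, block_item in enumerate(strategy.block_plan(state)):
        # Per-block setup, e.g. KV caches for causal / MatrixGame streaming.
        block_ctx = strategy.init_block_context(state, block_item, block_idx)
        strategy.process_block(state, block_ctx, block_item)   # denoise this block
        strategy.update_context(state, block_ctx, block_item)  # carry context forward
    return strategy.postprocess(state)
```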
Engine Hooks
Hooks are for cross-cutting concerns. A minimal perf hook is implemented:
```python
from fastvideo.pipelines.stages.denoising_engine_hooks import PerfLoggingHook

engine = DenoisingEngine(
    StandardStrategy(stage),
    hooks=[PerfLoggingHook()],
)
```

Hooks fire in both per-step loops and block loops, including MatrixGame
streaming via streaming_reset/streaming_clear.
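To make the call order concrete, here is a minimal sketch of how the engine's per-step loop could drive a strategy and fire hooks; the hook method names on_step_start/on_step_end are assumptions for illustration, not the actual hook API:

```python
# Illustrative sketch only; hook method names are assumed, not the real API.
def run_steps(strategy: DenoisingStrategy, hooks, batch, args):
    state = strategy.prepare(batch, args)
    for step_idx, t in enumerate(state.timesteps):
        for h in hooks:
            h.on_step_start(state, t, step_idx)           # assumed hook point
        inputs = strategy.make_model_inputs(state, t, step_idx)
        noise_pred = strategy.forward(state, inputs)      # model forward pass
        noise_pred = strategy.cfg_combine(state, noise_pred)
        state.latents = strategy.scheduler_step(state, noise_pred, t)
        for h in hooks:
            h.on_step_end(state, t, step_idx)             # assumed hook point
    return strategy.postprocess(state)
```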
Scheduler Adapter
fastvideo/models/schedulers/adapter.py defines a thin adapter interface to
normalize scheduler usage. The engine can attach an adapter into the
StrategyState.extra dict, and strategies can opt in to it. This allows a
progressive migration without breaking current behavior.
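As a sketch of the opt-in path, a strategy's scheduler_step could look for an adapter in state.extra and fall back to the raw scheduler; the adapter interface and the "scheduler_adapter" key below are assumptions, not the actual adapter.py API:

```python
# Illustrative sketch; the real interface lives in fastvideo/models/schedulers/adapter.py
# and may differ. The "scheduler_adapter" key is an assumed name.
def scheduler_step(self, state, noise_pred, t):
    adapter = state.extra.get("scheduler_adapter")
    if adapter is not None:
        # Opt-in path: normalized step signature via the adapter.
        return adapter.step(noise_pred, t, state.latents)
    # Fallback: call the underlying scheduler directly (current behavior).
    return self.stage.scheduler.step(noise_pred, t, state.latents, return_dict=False)[0]
```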
Examples
1) Standard Pipeline Wiring
```python
from fastvideo.pipelines.stages import DenoisingStage
from fastvideo.pipelines.stages.denoising_standard_strategy import StandardStrategy

self.add_stage(
    stage_name="denoising_stage",
    stage=DenoisingStage(
        transformer=self.get_module("transformer"),
        scheduler=self.get_module("scheduler"),
        strategy_cls=StandardStrategy,
    ),
)
```

2) Custom Strategy Skeleton
```python
class MyStrategy:
    def __init__(self, stage):
        self.stage = stage

    def prepare(self, batch, args):
        return StrategyState(
            latents=batch.latents,
            timesteps=batch.timesteps,
            num_inference_steps=batch.num_inference_steps,
            prompt_embeds=batch.prompt_embeds,
            negative_prompt_embeds=batch.negative_prompt_embeds,
            prompt_attention_mask=batch.prompt_attention_mask,
            negative_attention_mask=batch.negative_attention_mask,
            image_embeds=batch.image_embeds,
            guidance_scale=batch.guidance_scale,
            guidance_scale_2=batch.guidance_scale_2,
            guidance_rescale=batch.guidance_rescale,
            do_cfg=batch.do_classifier_free_guidance,
            extra={"batch": batch},  # keep the original batch for postprocess
        )

    def make_model_inputs(self, state, t, step_idx):
        return ModelInputs(
            latent_model_input=state.latents,
            timestep=t,
            prompt_embeds=state.prompt_embeds,
            prompt_attention_mask=state.prompt_attention_mask,
        )

    def forward(self, state, model_inputs):
        return self.stage.transformer(
            model_inputs.latent_model_input,
            model_inputs.prompt_embeds,
            model_inputs.timestep,
        )

    def cfg_combine(self, state, noise_pred):
        # No CFG combination in this skeleton; return the prediction unchanged.
        return noise_pred

    def scheduler_step(self, state, noise_pred, t):
        return self.stage.scheduler.step(
            noise_pred, t, state.latents, return_dict=False
        )[0]

    def postprocess(self, state):
        # Write the denoised latents back onto the original batch.
        state.extra["batch"].latents = state.latents
        return state.extra["batch"]
```

3) Perf Logging Hook (Opt-in)
```bash
export FASTVIDEO_DENOISING_PERF_LOGGING=1
```

Per-step and total timings are stored in ForwardBatch.logging_info, and a summary log is emitted on rank 0.
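For reference, a simplified sketch of how such a per-step timing hook could be structured; the actual PerfLoggingHook in denoising_engine_hooks.py may use different hook points and logging_info keys:

```python
import time

# Simplified illustration only; hook-point names and logging_info keys are assumed.
class TimingHookSketch:
    def __init__(self):
        self._step_start = None
        self.step_times = []

    def on_step_start(self, state, t, step_idx):   # assumed hook point
        self._step_start = time.perf_counter()

    def on_step_end(self, state, t, step_idx):     # assumed hook point
        self.step_times.append(time.perf_counter() - self._step_start)

    def finalize(self, batch):                     # assumed hook point
        # Stash per-step and total timings on the batch for later logging.
        batch.logging_info["denoising_step_times"] = self.step_times
        batch.logging_info["denoising_total_time"] = sum(self.step_times)
```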
4) MatrixGame Streaming
MatrixGame retains a streaming wrapper but uses the engine under the hood:
```python
denoiser = self._stage_name_mapping["denoising_stage"]
denoiser.streaming_reset(batch, args)  # initializes engine + hooks
denoiser.streaming_step(keyboard_action, mouse_action)
denoiser.streaming_clear()  # finalizes hook metrics
```

Current State
Implemented:
- DenoisingEngine + strategy protocols in fastvideo/pipelines/stages/denoising_engine.py and fastvideo/pipelines/stages/denoising_strategies.py.
- Strategies for Standard, Cosmos, LongCat, DMD, Causal, and MatrixGame.
- Engine hooks with per-step perf logging, gated by FASTVIDEO_DENOISING_PERF_LOGGING.
- Block hooks fire for run_blocks, including MatrixGame streaming.
- Per-pipeline wrappers removed (except the MatrixGame streaming wrapper).
- Parity tests expanded for I2V/V2V/boundary/DMD-I2V and block strategies.
Open Questions
- Should the engine own scheduler state and CFG policy fully, or keep it in strategies?
- How to expose deterministic RNG across strategies without coupling?
- What is the canonical behavior for CFG batching vs sequential CFG?
Future Work
- CFG-parallel hook (cross-rank split/merge).
- SP sharding hook (latent split/gather).
- Cache-dit hook and scheduler client patterns.
- Attention context builder to centralize STA/VSA/VMOBA metadata.
Feedback Period.
No response
CC List.
No response
Any Other Things.
No response
Before submitting a new issue...
- Make sure you already searched for relevant issues.