[RFC]: Unified Denoising Engine + Strategy Model #1012

@SolitaryThinker

Motivation.

Pain points before the refactor:

  • Denoising logic was split across large, model-specific stage classes.
  • Adding new pipelines required more conditionals in shared stages.
  • Denoising code was difficult to test in isolation.
  • Cross-cutting features such as CFG-parallel and sequence-parallel (SP)
    sharding had no consistent integration point.

The new engine + strategy design centralizes the denoising loop while allowing
per-model behavior in strategies. Hooks provide a clean extension path for
performance and distributed features.

Proposed Change.

Summary

This RFC formalizes the denoising refactor and captures the current
implementation state. The goal is a single shared denoising engine with small,
per-model strategy hooks, plus an engine-level hook system for cross-cutting
concerns (perf logging, CFG-parallel, SP sharding, cache-dit, scheduler client
patterns). Legacy denoising stage wrappers are removed in favor of
DenoisingStage(strategy_cls=...), with MatrixGame keeping its streaming
wrapper.

Goals

  • One shared denoising loop for Standard, Cosmos, LongCat, DMD, and others.
  • Block strategies for causal and MatrixGame (KV caches, streaming).
  • Make per-model differences explicit and testable via strategies.
  • Provide engine hooks for perf logging and future distributed features.
  • Preserve behavior via parity tests and SSIM tests.

Non-goals

  • Rewriting training-time denoising or loss computation.
  • Changing model architecture internals.
  • Full adoption of the scheduler adapter across all strategies in this RFC
    (adapter is introduced and optional).

Design Overview

Engine and Strategy Protocols

The engine owns the loop; strategies implement model-specific details.

from typing import Protocol

import torch

# StrategyState, ModelInputs, ForwardBatch, BlockPlan, and BlockContext are
# the engine's data containers; their definitions are omitted here.

class DenoisingStrategy(Protocol):
    def prepare(self, batch, args) -> StrategyState: ...
    def make_model_inputs(self, state, t, step_idx) -> ModelInputs: ...
    def forward(self, state, model_inputs) -> torch.Tensor: ...
    def cfg_combine(self, state, noise_pred) -> torch.Tensor: ...
    def scheduler_step(self, state, noise_pred, t) -> torch.Tensor: ...
    def postprocess(self, state) -> ForwardBatch: ...

# Listing Protocol again keeps the subclass a structural protocol.
class BlockDenoisingStrategy(DenoisingStrategy, Protocol):
    def block_plan(self, state) -> BlockPlan: ...
    def init_block_context(self, state, block_item, block_idx) -> BlockContext: ...
    def process_block(self, state, block_ctx, block_item) -> None: ...
    def update_context(self, state, block_ctx, block_item) -> None: ...
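
To make the division of labor concrete, here is a minimal sketch of how a per-step loop could drive the six DenoisingStrategy methods. The call order, the ToyState container, the ToyStrategy class, and the run_steps function are assumptions for illustration (the real engine lives in denoising_engine.py and operates on tensors, not scalars).

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ToyState:
    """Stand-in for StrategyState with only the fields this sketch needs."""
    latents: float
    timesteps: list[float]
    extra: dict[str, Any] = field(default_factory=dict)


class ToyStrategy:
    """Hypothetical strategy over a scalar latent, for illustration only."""

    def prepare(self, batch, args):
        return ToyState(latents=batch["latents"], timesteps=batch["timesteps"])

    def make_model_inputs(self, state, t, step_idx):
        return {"latent": state.latents, "timestep": t}

    def forward(self, state, model_inputs):
        # Toy model: predicts the whole latent as noise.
        return model_inputs["latent"]

    def cfg_combine(self, state, noise_pred):
        return noise_pred  # no guidance in this toy

    def scheduler_step(self, state, noise_pred, t):
        # Toy scheduler: remove a fraction t of the predicted noise.
        return state.latents - t * noise_pred

    def postprocess(self, state):
        return {"latents": state.latents}


def run_steps(strategy, batch, args=None):
    """Assumed call order of the shared per-step loop."""
    state = strategy.prepare(batch, args)
    for step_idx, t in enumerate(state.timesteps):
        model_inputs = strategy.make_model_inputs(state, t, step_idx)
        noise_pred = strategy.forward(state, model_inputs)
        noise_pred = strategy.cfg_combine(state, noise_pred)
        state.latents = strategy.scheduler_step(state, noise_pred, t)
    return strategy.postprocess(state)


result = run_steps(ToyStrategy(), {"latents": 8.0, "timesteps": [0.5, 0.5, 0.5]})
print(result)  # {'latents': 1.0}
```

The loop itself contains no model-specific branches; everything that differs between models sits behind the strategy methods, which is what makes each strategy testable on its own.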

Stage Wrapper

Pipelines use a thin wrapper that selects the strategy:

stage = DenoisingStage(
    transformer=...,
    scheduler=...,
    strategy_cls=StandardStrategy,
)

MatrixGame keeps a streaming wrapper but uses the same engine for blocks.
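
For the block path, the following is a rough sketch of how run_blocks could walk the BlockDenoisingStrategy methods. The call order is an assumption, and ToyBlockStrategy, the dict-based state, and the "cache" key are hypothetical stand-ins (the cache loosely plays the role of a KV cache).

```python
class ToyBlockStrategy:
    """Hypothetical block strategy that denoises a list of latents in chunks."""

    def __init__(self, block_size):
        self.block_size = block_size

    def prepare(self, batch, args):
        # A plain dict stands in for StrategyState.
        return {"latents": list(batch["latents"]), "cache": []}

    def block_plan(self, state):
        # One (start, end) index pair per block.
        n = len(state["latents"])
        return [(s, min(s + self.block_size, n))
                for s in range(0, n, self.block_size)]

    def init_block_context(self, state, block_item, block_idx):
        # Record how much history is visible to this block.
        return {"block_idx": block_idx, "history_len": len(state["cache"])}

    def process_block(self, state, block_ctx, block_item):
        start, end = block_item
        for i in range(start, end):
            state["latents"][i] *= 0.5  # toy "denoising" of this block

    def update_context(self, state, block_ctx, block_item):
        # Append the finished block so later blocks can condition on it.
        start, end = block_item
        state["cache"].extend(state["latents"][start:end])

    def postprocess(self, state):
        return state


def run_blocks(strategy, batch, args=None):
    """Assumed call order of the block loop."""
    state = strategy.prepare(batch, args)
    for block_idx, block_item in enumerate(strategy.block_plan(state)):
        block_ctx = strategy.init_block_context(state, block_item, block_idx)
        strategy.process_block(state, block_ctx, block_item)
        strategy.update_context(state, block_ctx, block_item)
    return strategy.postprocess(state)


out = run_blocks(ToyBlockStrategy(block_size=2), {"latents": [2.0, 4.0, 6.0]})
print(out["latents"])  # [1.0, 2.0, 3.0]
```

A streaming wrapper can reuse the same loop by feeding it one block at a time instead of a full plan.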

Engine Hooks

Hooks handle cross-cutting concerns. So far, a minimal perf-logging hook is
implemented:

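from fastvideo.pipelines.stages.denoising_engine import DenoisingEngine
from fastvideo.pipelines.stages.denoising_engine_hooks import PerfLoggingHook
from fastvideo.pipelines.stages.denoising_standard_strategy import StandardStrategy

# `stage` is the DenoisingStage that owns the transformer and scheduler.
engine = DenoisingEngine(
    StandardStrategy(stage),
    hooks=[PerfLoggingHook()],
)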

Hooks fire in both the per-step loop and the block loop. MatrixGame streaming
is covered as well: streaming_reset initializes the engine and its hooks, and
streaming_clear finalizes the hook metrics.
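
The hook interface itself is not spelled out in this RFC. As a rough sketch only, assuming hypothetical on_step_start/on_step_end callbacks (these names are not taken from denoising_engine_hooks), an engine could wrap each step like this:

```python
import time


class StepTimingHook:
    """Hypothetical hook; the callback names are assumptions."""

    def __init__(self):
        self.step_seconds = []
        self._start = None

    def on_step_start(self, step_idx):
        self._start = time.perf_counter()

    def on_step_end(self, step_idx):
        self.step_seconds.append(time.perf_counter() - self._start)


class StepCounterHook:
    """Second hypothetical hook, to show that hooks compose."""

    def __init__(self):
        self.steps = 0

    def on_step_start(self, step_idx):
        pass

    def on_step_end(self, step_idx):
        self.steps += 1


def run_with_hooks(step_fn, num_steps, hooks):
    """Call every hook around each step; the step body knows nothing of hooks."""
    for step_idx in range(num_steps):
        for hook in hooks:
            hook.on_step_start(step_idx)
        step_fn(step_idx)
        # End callbacks run in reverse order so hooks nest like context managers.
        for hook in reversed(hooks):
            hook.on_step_end(step_idx)


timing = StepTimingHook()
counter = StepCounterHook()
run_with_hooks(lambda step_idx: None, num_steps=3, hooks=[timing, counter])
print(counter.steps, len(timing.step_seconds))  # 3 3
```

Because the step body never references a hook, features such as CFG-parallel or SP sharding could in principle be added the same way without touching strategy code.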

Scheduler Adapter

fastvideo/models/schedulers/adapter.py defines a thin adapter interface that
normalizes scheduler usage. The engine can attach an adapter to the
StrategyState.extra dict, and strategies can opt in to it. This allows an
incremental migration without breaking current behavior.
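
The opt-in pattern could look roughly like the sketch below. The SchedulerAdapter class, its step signature, the ToyScheduler, and the "scheduler_adapter" key in the extra dict are all hypothetical; the actual interface in adapter.py may differ.

```python
class ToyScheduler:
    """Stand-in scheduler with a diffusers-style step(..., return_dict=...)."""

    def step(self, noise_pred, t, latents, return_dict=True):
        prev_sample = latents - t * noise_pred
        if return_dict:
            return {"prev_sample": prev_sample}
        return (prev_sample,)


class SchedulerAdapter:
    """Hypothetical adapter exposing one normalized step() that returns latents."""

    def __init__(self, scheduler):
        self._scheduler = scheduler

    def step(self, noise_pred, t, latents):
        return self._scheduler.step(noise_pred, t, latents, return_dict=False)[0]


def scheduler_step(extra, scheduler, noise_pred, t, latents):
    """Use the adapter when the engine attached one, else call the scheduler."""
    adapter = extra.get("scheduler_adapter")  # hypothetical key name
    if adapter is not None:
        return adapter.step(noise_pred, t, latents)
    return scheduler.step(noise_pred, t, latents, return_dict=False)[0]


sched = ToyScheduler()
direct = scheduler_step({}, sched, noise_pred=2.0, t=0.5, latents=4.0)
adapted = scheduler_step(
    {"scheduler_adapter": SchedulerAdapter(sched)},
    sched, noise_pred=2.0, t=0.5, latents=4.0,
)
print(direct, adapted)  # 3.0 3.0
```

Strategies that have not migrated keep calling the scheduler directly, so both paths can coexist during the transition.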

Examples

1) Standard Pipeline Wiring

from fastvideo.pipelines.stages import DenoisingStage
from fastvideo.pipelines.stages.denoising_standard_strategy import StandardStrategy

self.add_stage(
    stage_name="denoising_stage",
    stage=DenoisingStage(
        transformer=self.get_module("transformer"),
        scheduler=self.get_module("scheduler"),
        strategy_cls=StandardStrategy,
    ),
)

2) Custom Strategy Skeleton

class MyStrategy:
    def __init__(self, stage):
        self.stage = stage

    def prepare(self, batch, args):
        return StrategyState(
            latents=batch.latents,
            timesteps=batch.timesteps,
            num_inference_steps=batch.num_inference_steps,
            prompt_embeds=batch.prompt_embeds,
            negative_prompt_embeds=batch.negative_prompt_embeds,
            prompt_attention_mask=batch.prompt_attention_mask,
            negative_attention_mask=batch.negative_attention_mask,
            image_embeds=batch.image_embeds,
            guidance_scale=batch.guidance_scale,
            guidance_scale_2=batch.guidance_scale_2,
            guidance_rescale=batch.guidance_rescale,
            do_cfg=batch.do_classifier_free_guidance,
            extra={"batch": batch},
        )

    def make_model_inputs(self, state, t, step_idx):
        return ModelInputs(
            latent_model_input=state.latents,
            timestep=t,
            prompt_embeds=state.prompt_embeds,
            prompt_attention_mask=state.prompt_attention_mask,
        )

    def forward(self, state, model_inputs):
        return self.stage.transformer(
            model_inputs.latent_model_input,
            model_inputs.prompt_embeds,
            model_inputs.timestep,
        )

    def cfg_combine(self, state, noise_pred):
        # No guidance in this skeleton: pass the prediction through unchanged.
        return noise_pred

    def scheduler_step(self, state, noise_pred, t):
        return self.stage.scheduler.step(
            noise_pred, t, state.latents, return_dict=False
        )[0]

    def postprocess(self, state):
        state.extra["batch"].latents = state.latents
        return state.extra["batch"]
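
The skeleton's cfg_combine is a pass-through. For a strategy that does apply classifier-free guidance, the standard combination is uncond + scale * (cond - uncond). The sketch below shows that formula on plain lists; how a real strategy obtains the conditional and unconditional predictions (batched or sequential forward passes) is left open here, and the standalone function signature is an assumption.

```python
def cfg_combine(noise_cond, noise_uncond, guidance_scale):
    """Classifier-free guidance: push the prediction away from the
    unconditional branch, toward the conditional one."""
    return [
        uncond + guidance_scale * (cond - uncond)
        for cond, uncond in zip(noise_cond, noise_uncond)
    ]


guided = cfg_combine([1.0, 2.0], [0.0, 1.0], guidance_scale=3.0)
print(guided)  # [3.0, 4.0]

# With guidance_scale == 1.0 the result equals the conditional prediction.
print(cfg_combine([1.0, 2.0], [0.0, 1.0], guidance_scale=1.0))  # [1.0, 2.0]
```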

3) Perf Logging Hook (Opt-in)

export FASTVIDEO_DENOISING_PERF_LOGGING=1

Per-step and total timings are stored in ForwardBatch.logging_info, and
a summary log is emitted on rank 0.
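
As an illustration of the gating, the sketch below checks the environment variable and stores timings in a dict standing in for ForwardBatch.logging_info. The parsing rule (enabled only when the value is "1"), the helper names, and the dictionary keys are assumptions, not the actual implementation.

```python
import os

ENV_FLAG = "FASTVIDEO_DENOISING_PERF_LOGGING"


def perf_logging_enabled(environ=None):
    """Return True when the flag is set to "1" (assumed parsing rule)."""
    env = os.environ if environ is None else environ
    return env.get(ENV_FLAG, "0") == "1"


def record_timings(logging_info, step_seconds, rank=0):
    """Store per-step and total timings; the key names are hypothetical."""
    logging_info["denoise_step_seconds"] = list(step_seconds)
    logging_info["denoise_total_seconds"] = sum(step_seconds)
    if rank == 0:  # only rank 0 emits the summary line
        total = logging_info["denoise_total_seconds"]
        print(f"denoising: {len(step_seconds)} steps, {total:.3f}s total")
    return logging_info


logging_info = {}
if perf_logging_enabled({ENV_FLAG: "1"}):
    record_timings(logging_info, [0.25, 0.5, 0.25])
```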

4) MatrixGame Streaming

MatrixGame retains a streaming wrapper but uses the engine under the hood:

denoiser = self._stage_name_mapping["denoising_stage"]
denoiser.streaming_reset(batch, args)   # initializes engine + hooks
denoiser.streaming_step(keyboard_action, mouse_action)
denoiser.streaming_clear()              # finalizes hook metrics

Current State

Implemented:

  • DenoisingEngine + strategy protocols in
    fastvideo/pipelines/stages/denoising_engine.py and
    fastvideo/pipelines/stages/denoising_strategies.py.
  • Strategies for Standard, Cosmos, LongCat, DMD, Causal, and MatrixGame.
  • Engine hooks with per-step perf logging, gated by
    FASTVIDEO_DENOISING_PERF_LOGGING.
  • Block hooks fire for run_blocks, including MatrixGame streaming.
  • Per-pipeline wrappers removed (except MatrixGame streaming wrapper).
  • Parity tests expanded for I2V/V2V/boundary/DMD-I2V and block strategies.
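
A parity test in this sense runs the legacy path and the engine path on identical inputs and checks that the outputs agree within a tolerance. The sketch below shows the shape of such a test with toy scalar math; legacy_denoise, engine_denoise, and the tolerances are hypothetical stand-ins, not the project's test code.

```python
import math


def legacy_denoise(latents, timesteps):
    """Stand-in for a legacy model-specific loop with inlined math."""
    for t in timesteps:
        latents = [x - t * x for x in latents]
    return latents


def engine_denoise(latents, timesteps):
    """Stand-in for the refactored path: same math, split into hooks."""
    def forward(x):
        return x  # toy model predicts the latent itself as noise

    def scheduler_step(x, noise_pred, t):
        return x - t * noise_pred

    for t in timesteps:
        latents = [scheduler_step(x, forward(x), t) for x in latents]
    return latents


def assert_parity(expected, actual, rel_tol=1e-6, abs_tol=1e-9):
    """Fail if the two outputs differ in length or beyond the tolerance."""
    assert len(expected) == len(actual)
    for e, a in zip(expected, actual):
        assert math.isclose(e, a, rel_tol=rel_tol, abs_tol=abs_tol), (e, a)


inputs = [8.0, -4.0, 0.5]
steps = [0.5, 0.25]
assert_parity(legacy_denoise(inputs, steps), engine_denoise(inputs, steps))
```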

Open Questions

  • Should the engine fully own scheduler state and CFG policy, or should both
    stay in the strategies?
  • How should deterministic RNG be exposed across strategies without coupling
    them?
  • What is the canonical behavior: batched CFG or sequential CFG?

Future Work

  • CFG-parallel hook (cross-rank split/merge).
  • SP sharding hook (latent split/gather).
  • Cache-dit hook and scheduler client patterns.
  • Attention context builder to centralize STA/VSA/VMOBA metadata.
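
To illustrate the "latent split/gather" idea behind the SP sharding item, here is a purely local sketch that splits a sequence into per-rank shards and reassembles it. No communication is involved, and the function names and the contiguous-chunk layout are assumptions; a real hook would use collective operations across ranks.

```python
def split_latents(latents, world_size):
    """Split a sequence into world_size contiguous shards.

    Trailing shards may be shorter (or empty) when the length is not
    divisible by world_size.
    """
    shard_len = -(-len(latents) // world_size)  # ceiling division
    return [latents[r * shard_len:(r + 1) * shard_len]
            for r in range(world_size)]


def gather_latents(shards):
    """Concatenate shards back in rank order."""
    return [x for shard in shards for x in shard]


shards = split_latents([0, 1, 2, 3, 4], world_size=2)
print(shards)                  # [[0, 1, 2], [3, 4]]
print(gather_latents(shards))  # [0, 1, 2, 3, 4]
```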
