Skip to content
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,11 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# macOS
.DS_Store
._*

# AI assistant context (local only)
/CLAUDE.md
/AGENTS.md
3 changes: 2 additions & 1 deletion __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
from .tiled_vae_decode import LTXVTiledVAEDecode
from .tricks import NODE_CLASS_MAPPINGS as TRICKS_NODE_CLASS_MAPPINGS
from .tricks import NODE_DISPLAY_NAME_MAPPINGS as TRICKS_NODE_DISPLAY_NAME_MAPPINGS
from .utiltily_nodes import FloatToInt, ImageToCPU
from .utiltily_nodes import FloatToInt, ImageToCPU, LTXVLoopingReferenceSchedule
from .vae_patcher import LTXVPatcherVAE
from .vanish_nodes import LTXVDilateVideoMask, LTXVInpaintPreprocess

Expand Down Expand Up @@ -96,6 +96,7 @@
"LTXVMultiPromptProvider": MultiPromptProvider,
"ImageToCPU": ImageToCPU,
"LTXFloatToInt": FloatToInt,
"LTXVLoopingReferenceSchedule": LTXVLoopingReferenceSchedule,
"LTXVStatNormLatent": LTXVStatNormLatent,
"LTXVPerStepStatNormPatcher": LTXVPerStepStatNormPatcher,
"LTXVGemmaCLIPModelLoader": LTXVGemmaCLIPModelLoader,
Expand Down
155 changes: 147 additions & 8 deletions easy_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,50 @@
from .nodes_registry import comfy_node


def _make_av_latent_dict(video_latent_dict, audio_tensor, audio_noise_mask=None):
"""Wrap video latent dict + audio tensor into AV latent dict with NestedTensor.

If audio_tensor is None, returns video_latent_dict unchanged.
Creates matching noise masks for both modalities when either is present.
"""
if audio_tensor is None:
return video_latent_dict
result = video_latent_dict.copy()
result["samples"] = NestedTensor([result["samples"], audio_tensor])
video_mask = result.get("noise_mask")
if video_mask is not None or audio_noise_mask is not None:
if video_mask is None:
vs = result["samples"].tensors[0]
video_mask = torch.ones(
vs.shape[0], 1, vs.shape[2], vs.shape[3], vs.shape[4],
device=vs.device, dtype=vs.dtype,
)
if audio_noise_mask is None:
audio_noise_mask = torch.ones(
audio_tensor.shape[0], 1, audio_tensor.shape[2], audio_tensor.shape[3],
device=audio_tensor.device, dtype=audio_tensor.dtype,
)
result["noise_mask"] = NestedTensor([video_mask, audio_noise_mask])
return result


def _split_av_latent_dict(latent_dict):
    """Separate an AV latent dict into ``(video_latent_dict, audio_tensor)``.

    If "samples" is not a NestedTensor with at least two tensors, the input
    dict is returned unchanged together with ``None``. Otherwise a shallow
    copy is returned whose "samples" is the first nested tensor, paired with
    the second nested tensor as the audio. A NestedTensor "noise_mask" is
    likewise reduced to its first tensor; any other mask value is left as is.
    """
    samples = latent_dict["samples"]
    has_audio = isinstance(samples, NestedTensor) and len(samples.tensors) >= 2
    if not has_audio:
        return latent_dict, None

    video_tensor = samples.tensors[0]
    audio_tensor = samples.tensors[1]

    video_dict = latent_dict.copy()
    video_dict["samples"] = video_tensor

    mask = video_dict.get("noise_mask")
    # isinstance() is False for None, so a missing mask is skipped here too.
    if isinstance(mask, NestedTensor):
        video_dict["noise_mask"] = mask.tensors[0]

    return video_dict, audio_tensor


def _get_raw_conds_from_guider(guider):
if not hasattr(guider, "raw_conds"):
if "negative" not in guider.original_conds:
Expand Down Expand Up @@ -148,6 +192,7 @@ def sample(
optional_initialization_latents=None,
guiding_start_step=0,
guiding_end_step=1000,
_audio_tile=None,
):
guider = copy.copy(guider)
guider.original_conds = copy.deepcopy(guider.original_conds)
Expand Down Expand Up @@ -262,13 +307,15 @@ def sample(

# Denoise the latent video
print("Denoising with conditioning on sigmas: ", middle_sigmas)
_av = _make_av_latent_dict(latents, _audio_tile)
(output_latents, denoised_output_latents) = SamplerCustomAdvanced().sample(
noise=noise,
guider=guider,
sampler=sampler,
sigmas=middle_sigmas,
latent_image=latents,
latent_image=_av,
)
denoised_output_latents, _audio_tile = _split_av_latent_dict(denoised_output_latents)

# Clean up guides if image conditioning was used
positive, negative, denoised_output_latents = LTXVCropGuides.execute(
Expand All @@ -284,13 +331,18 @@ def sample(
"Denoising with no conditioning but with classical i2v noise mask on sigmas: ",
low_sigmas,
)
_av = _make_av_latent_dict(denoised_output_latents, _audio_tile)
(_, denoised_output_latents) = SamplerCustomAdvanced().sample(
noise=noise,
guider=guider,
sampler=sampler,
sigmas=low_sigmas,
latent_image=denoised_output_latents,
latent_image=_av,
)
denoised_output_latents, _audio_tile = _split_av_latent_dict(denoised_output_latents)

if _audio_tile is not None:
denoised_output_latents["_audio"] = _audio_tile

return (denoised_output_latents, positive, negative)

Expand Down Expand Up @@ -399,6 +451,8 @@ def sample(
guiding_start_step=0,
guiding_end_step=1000,
normalize_per_frame=False,
_audio_tile=None,
_audio_new_init=None,
):
guider = copy.copy(guider)
guider.original_conds = copy.deepcopy(guider.original_conds)
Expand All @@ -412,7 +466,20 @@ def sample(

positive, negative = _get_raw_conds_from_guider(guider)

# Handle AV latents (standalone mode)
_standalone_av = False
_accumulated_audio = _audio_tile
samples = latents["samples"]
if isinstance(samples, NestedTensor) and len(samples.tensors) == 2:
if _accumulated_audio is None:
_accumulated_audio = samples.tensors[1]
_standalone_av = True
latents = latents.copy()
latents["samples"] = samples.tensors[0]
if "noise_mask" in latents and isinstance(latents["noise_mask"], NestedTensor):
latents["noise_mask"] = latents["noise_mask"].tensors[0]
samples = latents["samples"]

batch, channels, frames, height, width = samples.shape
time_scale_factor, width_scale_factor, height_scale_factor = (
vae.downscale_index_formula
Expand All @@ -428,6 +495,52 @@ def sample(
latents, -overlap, -1
)

# Set up audio extend tile if audio is available
_audio_extend_tile = None
_audio_noise_mask = None
_audio_overlap = 0
if _accumulated_audio is not None:
audio_T = _accumulated_audio.shape[2]
video_T = frames
audio_ratio = audio_T / max(video_T, 1)
_audio_overlap = max(1, round(overlap * audio_ratio))
video_new_latent_frames = num_new_frames // time_scale_factor
audio_new_frames = max(1, round(video_new_latent_frames * audio_ratio))

# Build audio tile: overlap (already denoised) + new frames.
# If _audio_new_init is provided (stage-2 refinement), use it
# as initialization for the new frames instead of zeros.
audio_overlap_data = _accumulated_audio[:, :, -_audio_overlap:]
if _audio_new_init is not None:
available = min(audio_new_frames, _audio_new_init.shape[2])
audio_new_data = _audio_new_init[:, :, :available].clone()
if available < audio_new_frames:
pad = torch.zeros(
_accumulated_audio.shape[0], _accumulated_audio.shape[1],
audio_new_frames - available, _accumulated_audio.shape[3],
device=_accumulated_audio.device, dtype=_accumulated_audio.dtype,
)
audio_new_data = torch.cat([audio_new_data, pad], dim=2)
else:
audio_new_data = torch.zeros(
_accumulated_audio.shape[0], _accumulated_audio.shape[1],
audio_new_frames, _accumulated_audio.shape[3],
device=_accumulated_audio.device, dtype=_accumulated_audio.dtype,
)
_audio_extend_tile = torch.cat([audio_overlap_data, audio_new_data], dim=2)

# Audio noise mask: preserve overlap, denoise new
_audio_noise_mask = torch.ones(
_audio_extend_tile.shape[0], 1,
_audio_extend_tile.shape[2], _audio_extend_tile.shape[3],
device=_audio_extend_tile.device, dtype=_audio_extend_tile.dtype,
)
_audio_noise_mask[:, :, :_audio_overlap] = 1.0 - strength
print(
f"[ExtendSampler] Audio extend tile: overlap={_audio_overlap}, "
f"new={audio_new_frames}, total={_audio_extend_tile.shape[2]}"
)

if optional_initialization_latents is None:
new_latents = EmptyLTXVLatentVideo.execute(
width=width * width_scale_factor,
Expand Down Expand Up @@ -488,13 +601,15 @@ def sample(
if len(high_sigmas) > 1:
guider.set_conds(positive, negative)
print("Denoising with overlap conditioning only on sigmas: ", high_sigmas)
_av = _make_av_latent_dict(new_latents, _audio_extend_tile, _audio_noise_mask)
(_, new_latents) = SamplerCustomAdvanced().sample(
noise=noise,
guider=guider,
sampler=sampler,
sigmas=high_sigmas,
latent_image=new_latents,
latent_image=_av,
)
new_latents, _audio_extend_tile = _split_av_latent_dict(new_latents)

if optional_guiding_latents is not None:
optional_guiding_latents = LTXVSelectLatents().select_latents(
Expand Down Expand Up @@ -533,13 +648,15 @@ def sample(

# Denoise the latent video
print("Denoising with full conditioning on sigmas: ", middle_sigmas)
_av = _make_av_latent_dict(new_latents, _audio_extend_tile, _audio_noise_mask)
(output_latents, denoised_output_latents) = SamplerCustomAdvanced().sample(
noise=noise,
guider=guider,
sampler=sampler,
sigmas=middle_sigmas,
latent_image=new_latents,
latent_image=_av,
)
denoised_output_latents, _audio_extend_tile = _split_av_latent_dict(denoised_output_latents)

positive, negative, denoised_output_latents = LTXVCropGuides.execute(
positive=positive,
Expand Down Expand Up @@ -591,13 +708,15 @@ def sample(
"Denoising with overlap + keyframes conditioning only on sigmas: ",
low_sigmas,
)
_av = _make_av_latent_dict(denoised_output_latents, _audio_extend_tile, _audio_noise_mask)
(_, denoised_output_latents) = SamplerCustomAdvanced().sample(
noise=noise,
guider=guider,
sampler=sampler,
sigmas=low_sigmas,
latent_image=denoised_output_latents,
latent_image=_av,
)
denoised_output_latents, _audio_extend_tile = _split_av_latent_dict(denoised_output_latents)
positive, negative, denoised_output_latents = LTXVCropGuides.execute(
positive=positive,
negative=negative,
Expand All @@ -621,6 +740,16 @@ def sample(
(latents,) = LinearOverlapLatentTransition().process(
latents, truncated_denoised_output_latents, overlap - 1, axis=2
)

# Accumulate audio: append new (non-overlap) audio frames
if _accumulated_audio is not None and _audio_extend_tile is not None:
new_audio = _audio_extend_tile[:, :, _audio_overlap:]
accumulated_audio_out = torch.cat([_accumulated_audio, new_audio], dim=2)
if _standalone_av:
latents["samples"] = NestedTensor([latents["samples"], accumulated_audio_out])
else:
latents["_audio"] = accumulated_audio_out

return (latents, positive, negative)


Expand Down Expand Up @@ -692,6 +821,7 @@ def sample(
guiding_strength=1.0,
guiding_start_step=0,
guiding_end_step=1000,
_audio_tile=None,
):
guider = copy.copy(guider)
guider.original_conds = copy.deepcopy(guider.original_conds)
Expand Down Expand Up @@ -735,13 +865,15 @@ def sample(
"Denoising with keyframes only [if available] on sigmas: ",
high_sigmas,
)
_av = _make_av_latent_dict(new_latents, _audio_tile)
(_, new_latents) = SamplerCustomAdvanced().sample(
noise=noise,
guider=guider,
sampler=sampler,
sigmas=high_sigmas,
latent_image=new_latents,
latent_image=_av,
)
new_latents, _audio_tile = _split_av_latent_dict(new_latents)

if optional_cond_indices is not None and 0 in optional_cond_indices:
guiding_latents = LTXVSelectLatents().select_latents(
Expand Down Expand Up @@ -806,13 +938,15 @@ def sample(

# Denoise the latent video
print("Denoising with full conditioning on sigmas: ", middle_sigmas)
_av = _make_av_latent_dict(new_latents, _audio_tile)
(_, denoised_output_latents) = SamplerCustomAdvanced().sample(
noise=noise,
guider=guider,
sampler=sampler,
sigmas=middle_sigmas,
latent_image=new_latents,
latent_image=_av,
)
denoised_output_latents, _audio_tile = _split_av_latent_dict(denoised_output_latents)

# Clean up guides if image conditioning was used
positive, negative, denoised_output_latents = LTXVCropGuides.execute(
Expand All @@ -827,19 +961,24 @@ def sample(
"Denoising with keyframes only [if available] conditioning on sigmas: ",
low_sigmas,
)
_av = _make_av_latent_dict(denoised_output_latents, _audio_tile)
(_, denoised_output_latents) = SamplerCustomAdvanced().sample(
noise=noise,
guider=guider,
sampler=sampler,
sigmas=low_sigmas,
latent_image=denoised_output_latents,
latent_image=_av,
)
denoised_output_latents, _audio_tile = _split_av_latent_dict(denoised_output_latents)
positive, negative, denoised_output_latents = LTXVCropGuides.execute(
positive=positive,
negative=negative,
latent=denoised_output_latents,
)

if _audio_tile is not None:
denoised_output_latents["_audio"] = _audio_tile

return (denoised_output_latents, positive, negative)


Expand Down
Loading