From 7771d1d337e247d986168052abd7cdb07253a79a Mon Sep 17 00:00:00 2001 From: "liang.feng" Date: Mon, 29 Jun 2026 03:25:38 -0700 Subject: [PATCH] Add transfer control-CFG (cfg for control inputs) + smoke-test coverage Port the control-input CFG feature from i4 commit f11349b into the transfer inference path, reconciling with logic already synced into this repo: - omni_mot_model.py already carries the velocity_postprocess_builder hook, so no model change was needed. - transfer.py: add _build_no_control_inference_state and build_control_cfg_postprocess, and wire them through generate_samples_from_batch via velocity_postprocess_builder. Previously transfer.py passed control_guidance/control_guidance_interval directly, which were silently dropped by **kwargs (control-CFG was a no-op). - args.py: add emphasize_control_in_prompt (TransferDataArgs/Overrides + _TRANSFER_SAMPLE_DEFAULTS) to match the ported prompt-emphasis logic. Test: extend tests/nano_inference_smoke_test.py to cover transfer inference. The existing throughput run keeps t2vs + policy + forward_dynamics; transfer is run as a SEPARATE latency-preset call. Control-CFG runs an extra control-dropped forward each step, which under throughput (data-parallel over samples, FSDP-sharded) executes on only the transfer rank and deadlocks the cross-rank allgather -- so transfer must use the latency preset (context/CFG parallel, all ranks on one sample), matching the cookbook multi-GPU transfer recipe. The spec is built inline (_TRANSFER_SPEC, written to a temp file) and pulls the control video from the public NVIDIA/cosmos GitHub raw URL (same file the cookbook edge.json uses), downscaled for a fast smoke run. Validates transfer-specific attributes (edge control_path, control_guidance>1, guidance>1) and a non-degenerate output clip via the new _assert_video_has_content helper. 
Verified on a GB200 node: the README Nano edge transfer and the inline smoke spec both generate valid, non-degenerate video; the latency-preset transfer (4 ranks: cfgp=2, cp=2) completes and passes the test assertions, while the throughput+mixed path reproduces the deadlock. Co-Authored-By: Claude Opus 4.8 (1M context) --- cosmos_framework/inference/args.py | 6 + cosmos_framework/inference/transfer.py | 173 ++++++++++++++++++++++++- tests/nano_inference_smoke_test.py | 143 +++++++++++++++++++- 3 files changed, 313 insertions(+), 9 deletions(-) diff --git a/cosmos_framework/inference/args.py b/cosmos_framework/inference/args.py index 15240273..da8dea22 100644 --- a/cosmos_framework/inference/args.py +++ b/cosmos_framework/inference/args.py @@ -688,6 +688,7 @@ class TransferDataArgs(ArgsBase, _TransferDataBase): show_input: bool | None = None num_first_chunk_conditional_frames: pydantic.NonNegativeInt | None = None share_vision_temporal_positions: bool | None = None + emphasize_control_in_prompt: bool | None = None class TransferDataOverrides(OverridesBase, _TransferDataBase): @@ -726,6 +727,10 @@ class TransferDataOverrides(OverridesBase, _TransferDataBase): """Number of conditioning frames for the first chunk (defaults to ``num_conditional_frames``).""" share_vision_temporal_positions: bool | None = None """Share vision temporal position ids across autoregressive chunks.""" + emphasize_control_in_prompt: bool | None = None + """If True (default), auto-append a one-sentence directive to the user prompt that + names the active control modality (e.g. "Follow the edge control video precisely. + ..."). Set False for clean baselines / ablations. 
The system prompt is unchanged.""" @pydantic.model_validator(mode="after") def _validate_transfer_hints(self) -> Self: @@ -765,6 +770,7 @@ def download(self, output_dir: Path): "show_input": False, "num_first_chunk_conditional_frames": 0, "share_vision_temporal_positions": True, + "emphasize_control_in_prompt": True, } _TRANSFER_HINT_DEFAULTS: ClassVar[dict[TransferHintKey, dict[str, Any]]] = { TransferHintKey.EDGE: {"preset_edge_threshold": PresetEdgeThreshold.MEDIUM}, diff --git a/cosmos_framework/inference/transfer.py b/cosmos_framework/inference/transfer.py index d2bc182e..57760e7c 100644 --- a/cosmos_framework/inference/transfer.py +++ b/cosmos_framework/inference/transfer.py @@ -7,6 +7,7 @@ import random from dataclasses import dataclass from pathlib import Path +from typing import Callable, Optional import torch @@ -27,6 +28,7 @@ from cosmos_framework.utils import log from cosmos_framework.data.vfm.sequence_packing import SequencePlan from cosmos_framework.model.vfm.omni_mot_model import OmniMoTModel +from cosmos_framework.model.vfm.utils.data_and_condition import GenerationDataClean from cosmos_framework.model.vfm.vlm.qwen3_vl.utils import _SYSTEM_PROMPT_TRANSFER @@ -193,6 +195,159 @@ def build_transfer_batch( return batch +def _build_no_control_inference_state( + sequence_plans: list[SequencePlan], + gen_data_clean: GenerationDataClean, +) -> tuple[list[SequencePlan], GenerationDataClean, list[int]] | None: + """Build a target-only counterpart of ``(sequence_plans, gen_data_clean)`` for + control-CFG. Drops all but the last vision item per sample (the target). + + Returns ``None`` when no sample has multiple vision items (nothing to drop). + + Also returns ``ctrl_dims_per_sample`` — the flattened control-token dimension + per sample, used to slice ``noise_x`` and mix velocities. 
+ """ + num_items_per_sample = gen_data_clean.num_vision_items_per_sample + if num_items_per_sample is None or all(n <= 1 for n in num_items_per_sample): + return None + + assert gen_data_clean.x0_tokens_vision is not None + + new_x0_tokens_vision: list[torch.Tensor] = [] + new_raw_state_vision: list[torch.Tensor] | None = [] if gen_data_clean.raw_state_vision is not None else None + ctrl_dims_per_sample: list[int] = [] + vis_offset = 0 + for n_vis in num_items_per_sample: + ctrl_dim_i = 0 + for j in range(n_vis - 1): + sh = gen_data_clean.x0_tokens_vision[vis_offset + j].shape + ctrl_dim_i += math.prod(sh) + ctrl_dims_per_sample.append(ctrl_dim_i) + tgt_idx = vis_offset + n_vis - 1 + new_x0_tokens_vision.append(gen_data_clean.x0_tokens_vision[tgt_idx]) + if new_raw_state_vision is not None: + new_raw_state_vision.append(gen_data_clean.raw_state_vision[tgt_idx]) # type: ignore[index] + vis_offset += n_vis + + gdc_nc = GenerationDataClean( + batch_size=gen_data_clean.batch_size, + is_image_batch=gen_data_clean.is_image_batch, + raw_state_vision=new_raw_state_vision, + x0_tokens_vision=new_x0_tokens_vision, + fps_vision=gen_data_clean.fps_vision, + num_vision_items_per_sample=None, + raw_state_action=gen_data_clean.raw_state_action, + x0_tokens_action=gen_data_clean.x0_tokens_action, + action_domain_id=gen_data_clean.action_domain_id, + fps_action=gen_data_clean.fps_action, + raw_action_dim=gen_data_clean.raw_action_dim, + raw_state_sound=gen_data_clean.raw_state_sound, + x0_tokens_sound=gen_data_clean.x0_tokens_sound, + fps_sound=gen_data_clean.fps_sound, + ) + + sp_nc = [ + SequencePlan( + has_text=sp.has_text, + has_vision=sp.has_vision, + condition_frame_indexes_vision=sp.condition_frame_indexes_vision, + share_vision_temporal_positions=False, + has_action=sp.has_action, + condition_frame_indexes_action=sp.condition_frame_indexes_action, + action_start_frame_offset=sp.action_start_frame_offset, + has_sound=sp.has_sound, + 
condition_frame_indexes_sound=sp.condition_frame_indexes_sound, + ) + for sp in sequence_plans + ] + + return sp_nc, gdc_nc, ctrl_dims_per_sample + + +def build_control_cfg_postprocess( + *, + control_guidance: float, + control_guidance_interval: Optional[list[float]] = None, +) -> Optional[ + Callable[..., Optional[Callable[[list[torch.Tensor], list[torch.Tensor], torch.Tensor], list[torch.Tensor]]]] +]: + """Return a ``velocity_postprocess_builder`` that injects control-CFG. + + Pass the returned builder to ``OmniMoTModel.generate_samples_from_batch``. + The builder is invoked once at the start of sampling with the prepared + inference state; it builds the alternate (target-only) state and returns a + per-step closure that mixes the conditional velocity with an extra forward + pass that has all control items dropped. + + Returns ``None`` when control-CFG is a no-op (``control_guidance == 1.0``), + so the model takes its fast single-forward path. + """ + if control_guidance == 1.0: + return None + + def builder( + *, + model: OmniMoTModel, + net: torch.nn.Module | None = None, + cond_tokens: list[list[int]], + sequence_plans: list[SequencePlan], + gen_data_clean: GenerationDataClean, + ) -> Optional[Callable[[list[torch.Tensor], list[torch.Tensor], torch.Tensor], list[torch.Tensor]]]: + nc_state = _build_no_control_inference_state(sequence_plans, gen_data_clean) + if nc_state is None: + log.warning( + "control_guidance != 1.0 but no multi-vision sample found; falling back to single-branch inference." 
+ ) + return None + + if any(sp.has_action or sp.has_sound for sp in sequence_plans): + raise ValueError("control_guidance currently supports video transfer only, not action/sound generation.") + + sp_nc, gdc_nc, ctrl_dims = nc_state + control_guidance_bounds: tuple[float, float] | None = None + if control_guidance_interval is not None: + if len(control_guidance_interval) != 2: + raise ValueError(f"control_guidance_interval must be [lo, hi], got {control_guidance_interval}") + control_guidance_bounds = (control_guidance_interval[0], control_guidance_interval[1]) + + def postprocess( + cond_v_full: list[torch.Tensor], + noise_x: list[torch.Tensor], + timestep: torch.Tensor, + ) -> list[torch.Tensor]: + if control_guidance_bounds is not None: + if not (control_guidance_bounds[0] < timestep[0].item() < control_guidance_bounds[1]): + return cond_v_full + + noise_x_nc = [nx[c:] for nx, c in zip(noise_x, ctrl_dims, strict=True)] # [[N_target],...] + cond_v_nc = model._get_velocity( + net=net, + noise_x=noise_x_nc, + timestep=timestep, + text_tokens=cond_tokens, + sequence_plans=sp_nc, + gen_data_clean=gdc_nc, + skip_text_tokens=False, + ) + + # Mix only the suffix (target vision). The control-token portion + # of cond_v_full is already zeroed by the model's velocity mask + # (control items are fully conditioned), so leave it untouched. 
+ mixed: list[torch.Tensor] = [] + for v_full_i, v_nc_i, c in zip(cond_v_full, cond_v_nc, ctrl_dims, strict=True): + suffix_full = v_full_i[c:] # [N_target] + assert suffix_full.shape == v_nc_i.shape, ( + f"shape mismatch in control-CFG mix: full suffix {suffix_full.shape} vs no-control {v_nc_i.shape}" + ) + mixed_suffix = v_nc_i + control_guidance * (suffix_full - v_nc_i) # [N_target] + mixed.append(torch.cat([v_full_i[:c], mixed_suffix], dim=0)) # [N_full] + return mixed + + return postprocess + + return builder + + def generate_transfer_sample( sample_args: OmniSampleArgs, model: OmniMoTModel, @@ -282,6 +437,18 @@ def generate_transfer_sample( prompt = chunk_prompt_data[model.input_caption_key][0] negative_prompt = chunk_prompt_data.get("neg_" + model.input_caption_key, [None])[0] + # Optionally append a one-sentence control-adherence directive to the user prompt. + # Names the active hint modality (e.g. "edge", "depth, seg") so the VLM gets the + # exact control type. System prompt is untouched (training-distribution safe). + if sample_args.emphasize_control_in_prompt: + hint_names = ", ".join(k.value for k in hints.keys()) + prompt = ( + prompt.rstrip() + f" Follow the {hint_names} control video precisely: shape, contour, silhouette," + f" position, and motion of every visible structure must align with the {hint_names}" + f" signal at every frame." 
+ ) + log.info(f"[transfer] final user prompt: {prompt}") + model.eval() seed = sample_args.seed if sample_args.seed is not None else random.randint(0, 10000) for chunk_id in range(num_chunks): @@ -351,8 +518,10 @@ def generate_transfer_sample( sampler=sampler, guidance=guidance, guidance_interval=sample_args.guidance_interval, - control_guidance=sample_args.control_guidance, - control_guidance_interval=sample_args.control_guidance_interval, + velocity_postprocess_builder=build_control_cfg_postprocess( + control_guidance=sample_args.control_guidance, + control_guidance_interval=sample_args.control_guidance_interval, + ), seed=[seed + chunk_id], n_sample=1, has_negative_prompt=negative_prompt is not None, diff --git a/tests/nano_inference_smoke_test.py b/tests/nano_inference_smoke_test.py index a2f0e1b0..14defa15 100644 --- a/tests/nano_inference_smoke_test.py +++ b/tests/nano_inference_smoke_test.py @@ -3,9 +3,10 @@ """8-GPU multi-modality inference smoke test for Cosmos3-Nano. -Runs ONE ``cosmos_framework.scripts.inference`` call over three input samples of -different modalities (the ``-i`` flag takes a list of files) and validates each -sample's output: +Runs two ``cosmos_framework.scripts.inference`` calls and validates each output: + +1. A ``throughput`` call over three input samples of different modalities (the + ``-i`` flag takes a list of files): * ``inputs/omni/t2vs.json`` (text2video + sound) -> a ``vision.mp4`` whose muxed audio is real sound (finite, non-empty, non-silent, non-constant). @@ -15,8 +16,19 @@ * ``inputs/omni/action_policy_robot.json`` (policy) -> BOTH a ``vision.mp4`` and a finite, non-empty predicted ``action`` array in ``sample_outputs.json``. -All three samples produce a video; the policy sample additionally produces an -action and the t2vs sample an audio track. +2. 
A separate ``latency`` call for a video2video transfer spec (``_TRANSFER_SPEC``, + an edge control hint with ``control_guidance`` > 1.0, written to a temp file at + run time rather than committed under ``inputs/``) -> a non-degenerate + ``vision.mp4``. Exercises the transfer control-CFG path (the extra control-input + forward driven by ``control_guidance``). Transfer needs the ``latency`` preset: + under ``throughput`` (data-parallel over samples, FSDP-sharded) the extra + control forward runs on only the transfer rank and deadlocks the cross-rank + allgather, so it cannot share the call above — matching the cookbook's + multi-GPU transfer recipe, which is also ``latency``. + +All four samples produce a video; the policy sample additionally produces an +action, the t2vs sample an audio track, and the transfer sample exercises the +control-guidance branch. Smoke-level only (output validity, not numeric goldens). The checkpoint + its tokenizers download from the HF Hub on first run and are reused afterward. @@ -50,6 +62,45 @@ "inputs/omni/action_forward_dynamics_camera.json", ] +# Transfer (video2video, edge control) input, written to a temp file at run time +# rather than committed under inputs/. Mirrors the cookbook +# ``cookbooks/cosmos3/generator/transfer/specs/edge.json`` behavior — the edge +# control hint with guidance=3.0 + control_guidance=1.5, which selects the +# control-CFG path — but downscaled (480p / 10 steps / single 29-frame chunk) +# for a fast smoke run. The control video is the exact same file the cookbook +# uses, pulled from the public NVIDIA/cosmos GitHub raw URL; the prompt is a +# compact caption of that clip (the dense cookbook caption is not needed to +# exercise the path). 
+_TRANSFER_CONTROL_URL = ( + "https://github.com/NVIDIA/cosmos/raw/main/" + "cookbooks/cosmos3/generator/transfer/assets/edge/control_edge.mp4" +) +_TRANSFER_SPEC = { + "name": "transfer_edge", + "model_mode": "video2video", + "resolution": "480", + "aspect_ratio": "16,9", + "num_frames": 29, + "fps": 30, + "shift": 10.0, + "num_steps": 10, + "seed": 2026, + "num_video_frames_per_chunk": 29, + "max_frames": 29, + "num_conditional_frames": 1, + "num_first_chunk_conditional_frames": 0, + "share_vision_temporal_positions": True, + "guidance": 3.0, + "control_guidance": 1.5, + "prompt": ( + "A woman with blonde hair in a low ponytail, wearing a black sleeveless top and black " + "leggings, practices a dance routine in a brightly lit rehearsal studio with light wood " + "floors, a large red-framed window, and a black curtain." + ), + "negative_prompt": "blurry, distorted, deformed, low quality, flickering, artifacts", + "edge": {"control_path": _TRANSFER_CONTROL_URL, "preset_edge_threshold": "medium"}, +} + # Audio sanity thresholds for the muxed sound track. _RMS_SILENCE_FLOOR = 1e-4 # below this the track is effectively silence _PEAK_SANITY_CEIL = 1.5 # decoded float audio should sit within ~[-1, 1] @@ -145,6 +196,29 @@ def _assert_valid_video(mp4_path: Path) -> None: assert frames >= 1 and width > 0 and height > 0, f"no decodable video frame in {mp4_path}" +def _assert_video_has_content(mp4_path: Path, *, min_frames: int = 16) -> None: + """Assert ``mp4_path`` decodes to enough non-degenerate frames. + + Stronger than ``_assert_valid_video`` (which only inspects the first frame): + decodes the whole clip and checks the frame count plus real pixel variation, + so a run that produced a well-formed container but collapsed to a constant / + blank video (e.g. a broken control-CFG path) fails instead of passing. 
+ """ + import av + import numpy as np + + with av.open(str(mp4_path)) as container: + vstreams = container.streams.video + assert vstreams, f"no video stream in {mp4_path}" + frames = [frame.to_ndarray(format="rgb24") for frame in container.decode(vstreams[0])] + assert len(frames) >= min_frames, f"{mp4_path}: expected >= {min_frames} frames, got {len(frames)}" + arr = np.stack(frames).astype(np.float64) + assert np.all(np.isfinite(arr)), f"{mp4_path}: decoded video has non-finite pixels" + # Both spatial and temporal flatness collapse global std toward 0; a real + # generated clip sits well above this floor (typically tens on a 0-255 scale). + assert arr.std() > 3.0, f"{mp4_path}: degenerate/near-constant video (pixel std={arr.std():.3f})" + + def _assert_valid_action(content: dict, where: str) -> None: """Assert a policy sample's predicted ``action`` is a non-empty, all-finite array.""" import numpy as np @@ -177,7 +251,8 @@ def _require_8_gpus() -> None: @pytest.mark.level(2) @pytest.mark.gpus(8) def test_nano_inference_omni(tmp_path: Path) -> None: - """One Cosmos3-Nano inference call over t2vs + policy + forward_dynamics; check each output.""" + """Throughput run over t2vs + policy + forward_dynamics, plus a separate latency transfer run.""" + # --- 1) Throughput run: t2vs + policy + forward_dynamics ---------------- out_dir = tmp_path / "out" cmd = [ "torchrun", @@ -222,7 +297,61 @@ def test_nano_inference_omni(tmp_path: Path) -> None: n_action += 1 # Every sample produces a valid video (t2vs, forward_dynamics, policy); - # the policy sample additionally yields an action, t2vs an audio track. + # the policy sample additionally yields an action and t2vs an audio track. 
assert n_video == len(_INPUTS), f"expected every sample to produce a valid video, got {n_video}/{len(_INPUTS)}" assert n_sound >= 1, f"expected the t2vs sample's audio to be checked, got {n_sound}" assert n_action >= 1, f"expected the policy sample's action to be checked, got {n_action}" + + # --- 2) Transfer run (separate, latency preset) ------------------------- + # Control-CFG (control_guidance > 1.0) runs an extra control-dropped forward + # each step. Under the throughput preset (data-parallel over samples, FSDP- + # sharded) that extra forward executes on only the transfer rank and + # deadlocks the cross-rank allgather, so transfer cannot share the call + # above; it needs the latency preset (context/CFG parallel -- every rank + # runs the same sample together), matching the cookbook multi-GPU transfer + # recipe. The spec is generated here (not committed under inputs/) and the + # control video is pulled from the public NVIDIA/cosmos GitHub raw URL. + # 4 ranks -> cfgp=2, cp=2 (the cookbook Cosmos3-Super transfer layout). + transfer_spec = tmp_path / "transfer_edge.json" + transfer_spec.write_text(json.dumps(_TRANSFER_SPEC)) + transfer_out = tmp_path / "out_transfer" + transfer_cmd = [ + "torchrun", + "--nproc_per_node=4", + f"--master_port={_free_port()}", + "-m", + "cosmos_framework.scripts.inference", + "--parallelism-preset=latency", + "-i", + str(transfer_spec), + "-o", + str(transfer_out), + "--checkpoint-path", + "Cosmos3-Nano", + "--seed=0", + ] + _run(transfer_cmd, tmp_path / "inference_transfer.log") + + transfer_results = sorted(transfer_out.rglob("sample_outputs.json")) + assert len(transfer_results) == 1, ( + f"expected 1 transfer sample_outputs.json, found {[str(p) for p in transfer_results]}" + ) + so = transfer_results[0] + args = json.loads(so.read_text()).get("args", {}) + # Transfer-specific input attributes: the edge control hint + the CFG knobs + # that select the control-CFG path. 
+ edge = args.get("edge") or {} + assert edge.get("control_path"), f"transfer sample missing edge control_path ({so}); args keys={list(args)}" + assert args.get("control_guidance", 1.0) > 1.0, ( + f"expected control-CFG (control_guidance > 1.0), got {args.get('control_guidance')} ({so})" + ) + assert (args.get("guidance") or 1.0) > 1.0, ( + f"expected text-CFG (guidance > 1.0), got {args.get('guidance')} ({so})" + ) + # A valid, non-degenerate clip produced under control_guidance > 1.0 means the + # control-CFG branch ran to completion: a broken postprocess would raise + # mid-sampling, and a numerically broken one would collapse the output (caught + # by _assert_video_has_content). + transfer_video = so.parent / "vision.mp4" + assert transfer_video.is_file(), f"transfer run produced no vision.mp4 ({so})" + _assert_video_has_content(transfer_video)