Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cosmos_framework/inference/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,7 @@ class TransferDataArgs(ArgsBase, _TransferDataBase):
show_input: bool | None = None
num_first_chunk_conditional_frames: pydantic.NonNegativeInt | None = None
share_vision_temporal_positions: bool | None = None
emphasize_control_in_prompt: bool | None = None


class TransferDataOverrides(OverridesBase, _TransferDataBase):
Expand Down Expand Up @@ -726,6 +727,10 @@ class TransferDataOverrides(OverridesBase, _TransferDataBase):
"""Number of conditioning frames for the first chunk (defaults to ``num_conditional_frames``)."""
share_vision_temporal_positions: bool | None = None
"""Share vision temporal position ids across autoregressive chunks."""
emphasize_control_in_prompt: bool | None = None
"""If True (default), auto-append a one-sentence directive to the user prompt that
names the active control modality (e.g. "Follow the edge control video precisely.
..."). Set False for clean baselines / ablations. The system prompt is unchanged."""

@pydantic.model_validator(mode="after")
def _validate_transfer_hints(self) -> Self:
Expand Down Expand Up @@ -765,6 +770,7 @@ def download(self, output_dir: Path):
"show_input": False,
"num_first_chunk_conditional_frames": 0,
"share_vision_temporal_positions": True,
"emphasize_control_in_prompt": True,
}
_TRANSFER_HINT_DEFAULTS: ClassVar[dict[TransferHintKey, dict[str, Any]]] = {
TransferHintKey.EDGE: {"preset_edge_threshold": PresetEdgeThreshold.MEDIUM},
Expand Down
173 changes: 171 additions & 2 deletions cosmos_framework/inference/transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Optional

import torch

Expand All @@ -27,6 +28,7 @@
from cosmos_framework.utils import log
from cosmos_framework.data.vfm.sequence_packing import SequencePlan
from cosmos_framework.model.vfm.omni_mot_model import OmniMoTModel
from cosmos_framework.model.vfm.utils.data_and_condition import GenerationDataClean
from cosmos_framework.model.vfm.vlm.qwen3_vl.utils import _SYSTEM_PROMPT_TRANSFER


Expand Down Expand Up @@ -193,6 +195,159 @@ def build_transfer_batch(
return batch


def _build_no_control_inference_state(
    sequence_plans: list[SequencePlan],
    gen_data_clean: GenerationDataClean,
) -> tuple[list[SequencePlan], GenerationDataClean, list[int]] | None:
    """Derive the target-only ("no control") inference state used by control-CFG.

    For every sample, only the last vision item (the target) is kept; all
    preceding vision items of that sample are dropped.

    Args:
        sequence_plans: Per-sample sequence plans of the full (with-control) state.
        gen_data_clean: Clean generation data of the full state. Its
            ``x0_tokens_vision`` (and ``raw_state_vision`` if present) are read as
            flat lists indexed sample after sample, following
            ``num_vision_items_per_sample``.

    Returns:
        ``None`` if ``num_vision_items_per_sample`` is ``None`` or no sample has
        more than one vision item (nothing to drop). Otherwise a tuple
        ``(sequence_plans_nc, gen_data_clean_nc, ctrl_dims_per_sample)`` where
        ``ctrl_dims_per_sample[i]`` is the total element count of the dropped
        vision-token tensors of sample ``i``; callers use it to slice the
        flattened per-sample tensors.
    """
    items_per_sample = gen_data_clean.num_vision_items_per_sample
    if items_per_sample is None:
        return None
    if not any(count > 1 for count in items_per_sample):
        return None

    vision_tokens = gen_data_clean.x0_tokens_vision
    assert vision_tokens is not None
    raw_vision = gen_data_clean.raw_state_vision

    target_tokens: list[torch.Tensor] = []
    # Mirror the input: only produce a raw-state list when the input has one.
    target_raw: list[torch.Tensor] | None = None if raw_vision is None else []
    ctrl_dims_per_sample: list[int] = []

    start = 0  # index of the current sample's first vision item in the flat lists
    for count in items_per_sample:
        target_index = start + count - 1  # the last item of the sample is the target
        # Everything before the target is control; record its flattened size.
        ctrl_dims_per_sample.append(
            sum(math.prod(vision_tokens[idx].shape) for idx in range(start, target_index))
        )
        target_tokens.append(vision_tokens[target_index])
        if target_raw is not None:
            target_raw.append(raw_vision[target_index])  # type: ignore[index]
        start += count

    # Vision fields are replaced by their target-only versions; action and sound
    # fields are forwarded untouched.
    no_control_data = GenerationDataClean(
        batch_size=gen_data_clean.batch_size,
        is_image_batch=gen_data_clean.is_image_batch,
        raw_state_vision=target_raw,
        x0_tokens_vision=target_tokens,
        fps_vision=gen_data_clean.fps_vision,
        num_vision_items_per_sample=None,
        raw_state_action=gen_data_clean.raw_state_action,
        x0_tokens_action=gen_data_clean.x0_tokens_action,
        action_domain_id=gen_data_clean.action_domain_id,
        fps_action=gen_data_clean.fps_action,
        raw_action_dim=gen_data_clean.raw_action_dim,
        raw_state_sound=gen_data_clean.raw_state_sound,
        x0_tokens_sound=gen_data_clean.x0_tokens_sound,
        fps_sound=gen_data_clean.fps_sound,
    )

    # Copy each plan field by field; ``share_vision_temporal_positions`` is always
    # set to False in the no-control plans regardless of the source plan.
    no_control_plans: list[SequencePlan] = []
    for plan in sequence_plans:
        no_control_plans.append(
            SequencePlan(
                has_text=plan.has_text,
                has_vision=plan.has_vision,
                condition_frame_indexes_vision=plan.condition_frame_indexes_vision,
                share_vision_temporal_positions=False,
                has_action=plan.has_action,
                condition_frame_indexes_action=plan.condition_frame_indexes_action,
                action_start_frame_offset=plan.action_start_frame_offset,
                has_sound=plan.has_sound,
                condition_frame_indexes_sound=plan.condition_frame_indexes_sound,
            )
        )

    return no_control_plans, no_control_data, ctrl_dims_per_sample


def build_control_cfg_postprocess(
    *,
    control_guidance: float,
    control_guidance_interval: Optional[list[float]] = None,
) -> Optional[
    Callable[..., Optional[Callable[[list[torch.Tensor], list[torch.Tensor], torch.Tensor], list[torch.Tensor]]]]
]:
    """Return a ``velocity_postprocess_builder`` that injects control-CFG.

    Pass the returned builder to ``OmniMoTModel.generate_samples_from_batch``.
    The builder is called with the prepared inference state; it builds the
    alternate (target-only) state and returns a per-step closure that mixes the
    conditional velocity with an extra forward pass that has all control items
    dropped: ``v = v_nc + control_guidance * (v_full - v_nc)``.

    Args:
        control_guidance: Control-CFG scale. ``1.0`` disables control-CFG.
        control_guidance_interval: Optional ``[lo, hi]`` pair. When given, mixing
            is applied only on steps whose first timestep value ``t`` satisfies
            ``lo < t < hi`` (both bounds exclusive); other steps return the
            conditional velocity unchanged.

    Returns:
        ``None`` when control-CFG is a no-op (``control_guidance == 1.0``), so the
        model takes its single-forward path. Otherwise a keyword-only builder
        ``builder(model=, net=, cond_tokens=, sequence_plans=, gen_data_clean=)``
        that returns the per-step ``postprocess(cond_v_full, noise_x, timestep)``
        closure, or ``None`` (with a warning) if no sample has multiple vision
        items.

    Raises:
        ValueError: If control-CFG is active and ``control_guidance_interval`` is
            not a 2-element sequence. Raised immediately, not from inside the
            builder. The builder itself raises ``ValueError`` if any sequence plan
            has action or sound enabled.
    """
    if control_guidance == 1.0:
        return None

    # Validate and snapshot the interval once, up front: a malformed interval is a
    # configuration error and should fail before any inference state is prepared,
    # not on each builder invocation in the middle of sampling.
    control_guidance_bounds: tuple[float, float] | None = None
    if control_guidance_interval is not None:
        if len(control_guidance_interval) != 2:
            raise ValueError(f"control_guidance_interval must be [lo, hi], got {control_guidance_interval}")
        control_guidance_bounds = (control_guidance_interval[0], control_guidance_interval[1])

    def builder(
        *,
        model: OmniMoTModel,
        net: torch.nn.Module | None = None,
        cond_tokens: list[list[int]],
        sequence_plans: list[SequencePlan],
        gen_data_clean: GenerationDataClean,
    ) -> Optional[Callable[[list[torch.Tensor], list[torch.Tensor], torch.Tensor], list[torch.Tensor]]]:
        nc_state = _build_no_control_inference_state(sequence_plans, gen_data_clean)
        if nc_state is None:
            log.warning(
                "control_guidance != 1.0 but no multi-vision sample found; falling back to single-branch inference."
            )
            return None

        if any(sp.has_action or sp.has_sound for sp in sequence_plans):
            raise ValueError("control_guidance currently supports video transfer only, not action/sound generation.")

        sp_nc, gdc_nc, ctrl_dims = nc_state

        def postprocess(
            cond_v_full: list[torch.Tensor],
            noise_x: list[torch.Tensor],
            timestep: torch.Tensor,
        ) -> list[torch.Tensor]:
            # Outside the (exclusive) guidance interval, skip the extra forward pass.
            if control_guidance_bounds is not None:
                lo, hi = control_guidance_bounds
                if not (lo < timestep[0].item() < hi):
                    return cond_v_full

            # Drop the leading control tokens of every sample so the noisy input
            # matches the target-only state.
            noise_x_nc = [nx[c:] for nx, c in zip(noise_x, ctrl_dims, strict=True)]  # [[N_target],...]
            cond_v_nc = model._get_velocity(
                net=net,
                noise_x=noise_x_nc,
                timestep=timestep,
                text_tokens=cond_tokens,
                sequence_plans=sp_nc,
                gen_data_clean=gdc_nc,
                skip_text_tokens=False,
            )

            # Mix only the suffix (target vision). The control-token portion
            # of cond_v_full is already zeroed by the model's velocity mask
            # (control items are fully conditioned), so leave it untouched.
            mixed: list[torch.Tensor] = []
            for v_full_i, v_nc_i, c in zip(cond_v_full, cond_v_nc, ctrl_dims, strict=True):
                suffix_full = v_full_i[c:]  # [N_target]
                assert suffix_full.shape == v_nc_i.shape, (
                    f"shape mismatch in control-CFG mix: full suffix {suffix_full.shape} vs no-control {v_nc_i.shape}"
                )
                mixed_suffix = v_nc_i + control_guidance * (suffix_full - v_nc_i)  # [N_target]
                mixed.append(torch.cat([v_full_i[:c], mixed_suffix], dim=0))  # [N_full]
            return mixed

        return postprocess

    return builder


def generate_transfer_sample(
sample_args: OmniSampleArgs,
model: OmniMoTModel,
Expand Down Expand Up @@ -282,6 +437,18 @@ def generate_transfer_sample(
prompt = chunk_prompt_data[model.input_caption_key][0]
negative_prompt = chunk_prompt_data.get("neg_" + model.input_caption_key, [None])[0]

# Optionally append a one-sentence control-adherence directive to the user prompt.
# Names the active hint modality (e.g. "edge", "depth, seg") so the VLM gets the
# exact control type. System prompt is untouched (training-distribution safe).
if sample_args.emphasize_control_in_prompt:
hint_names = ", ".join(k.value for k in hints.keys())
prompt = (
prompt.rstrip() + f" Follow the {hint_names} control video precisely: shape, contour, silhouette,"
f" position, and motion of every visible structure must align with the {hint_names}"
f" signal at every frame."
)
log.info(f"[transfer] final user prompt: {prompt}")

model.eval()
seed = sample_args.seed if sample_args.seed is not None else random.randint(0, 10000)
for chunk_id in range(num_chunks):
Expand Down Expand Up @@ -351,8 +518,10 @@ def generate_transfer_sample(
sampler=sampler,
guidance=guidance,
guidance_interval=sample_args.guidance_interval,
control_guidance=sample_args.control_guidance,
control_guidance_interval=sample_args.control_guidance_interval,
velocity_postprocess_builder=build_control_cfg_postprocess(
control_guidance=sample_args.control_guidance,
control_guidance_interval=sample_args.control_guidance_interval,
),
seed=[seed + chunk_id],
n_sample=1,
has_negative_prompt=negative_prompt is not None,
Expand Down
Loading