From a8296fbac9fbed5dc8a7d53d9f0ac932659b5043 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Thu, 28 May 2026 11:13:16 +0000 Subject: [PATCH 01/15] Add Cosmos3 action generation support --- examples/cosmos3/inference_cosmos3.py | 106 +++- .../transformers/transformer_cosmos3.py | 132 +++- .../pipelines/cosmos/pipeline_cosmos3_omni.py | 584 ++++++++++++++++-- 3 files changed, 755 insertions(+), 67 deletions(-) diff --git a/examples/cosmos3/inference_cosmos3.py b/examples/cosmos3/inference_cosmos3.py index fd0d0537cb0e..675ead892c2f 100644 --- a/examples/cosmos3/inference_cosmos3.py +++ b/examples/cosmos3/inference_cosmos3.py @@ -23,13 +23,15 @@ """ import argparse +import json import pathlib +import urllib.request import torch from huggingface_hub import snapshot_download from diffusers import Cosmos3OmniPipeline -from diffusers.utils import encode_video, export_to_video, load_image +from diffusers.utils import encode_video, export_to_video, load_image, load_video HF_REPOS = { @@ -38,6 +40,22 @@ } +def _load_action(path: str | None): + if path is None: + raise ValueError("--action-path is required for forward_dynamics mode.") + if path.startswith(("http://", "https://")): + with urllib.request.urlopen(path) as response: + action = json.loads(response.read().decode("utf-8")) + else: + action = json.loads(pathlib.Path(path).read_text()) + tensor = torch.as_tensor(action, dtype=torch.float32) + if tensor.ndim == 3 and tensor.shape[0] == 1: + tensor = tensor.squeeze(0) + if tensor.ndim != 2: + raise ValueError(f"Cosmos3 action must have shape [T, D], got {tuple(tensor.shape)}.") + return tensor + + def main(): parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) parser.add_argument("--prompt", required=True, help="Text prompt.") @@ -50,7 +68,7 @@ def main(): parser.add_argument( "--vision-path", default=None, - help="Optional URL or local path for an image-conditioning frame (enables image-to-video).", + 
help="Optional URL or local path for an image-conditioning frame, or an action conditioning video.", ) parser.add_argument("--output", default=".", help="Directory to save generated video/image/audio files.") parser.add_argument("--height", type=int, default=720) @@ -62,12 +80,26 @@ def main(): help="Number of frames to generate. Use 1 for text-to-image; defaults to 189 for video (≈ 7.9s @ 24 FPS).", ) parser.add_argument("--fps", type=float, default=24.0) + parser.add_argument("--guidance-scale", type=float, default=6.0, help="Classifier-free guidance scale.") + parser.add_argument("--num-inference-steps", type=int, default=35, help="Number of denoising steps.") + parser.add_argument("--flow-shift", type=float, default=None, help="Scheduler flow shift.") + parser.add_argument("--seed", type=int, default=None, help="Random seed for latent initialization.") parser.add_argument( "--enable-sound", action="store_true", default=False, help="Generate sound alongside video (requires a sound-capable checkpoint).", ) + parser.add_argument( + "--action-mode", + choices=["forward_dynamics", "inverse_dynamics", "policy"], + default=None, + help="Enable Cosmos3 action generation with a loaded conditioning video.", + ) + parser.add_argument("--action-path", default=None, help="JSON action path for forward_dynamics mode.") + parser.add_argument("--action-chunk-size", type=int, default=None, help="Number of action tokens to generate/use.") + parser.add_argument("--domain-name", default=None, help="Cosmos3 action embodiment domain name.") + parser.add_argument("--raw-action-dim", type=int, default=None, help="Slice predicted action output to this size.") parser.add_argument( "--no-duration-template", dest="add_duration_template", @@ -110,21 +142,54 @@ def main(): output_dir = pathlib.Path(args.output) output_dir.mkdir(parents=True, exist_ok=True) - - image = load_image(args.vision_path) if args.vision_path is not None else None - - result = pipeline( - prompt=args.prompt, - 
image=image, - num_frames=args.num_frames, - height=args.height, - width=args.width, - fps=args.fps, - enable_sound=args.enable_sound, - add_resolution_template=args.add_resolution_template, - add_duration_template=args.add_duration_template, - enable_safety_check=not args.no_safety_check, - ) + generator = torch.Generator().manual_seed(args.seed) if args.seed is not None else None + + if args.action_mode is not None: + if args.vision_path is None: + raise ValueError("--vision-path must point to a video for action modes.") + if args.action_chunk_size is None: + raise ValueError("--action-chunk-size is required for action modes.") + video = load_video(args.vision_path) + action = _load_action(args.action_path) if args.action_mode == "forward_dynamics" else None + result = pipeline( + prompt=args.prompt, + video=video, + num_frames=args.action_chunk_size + 1, + height=args.height, + width=args.width, + fps=args.fps, + num_inference_steps=args.num_inference_steps, + flow_shift=args.flow_shift, + action_mode=args.action_mode, + action=action, + action_chunk_size=args.action_chunk_size, + domain_name=args.domain_name, + raw_action_dim=args.raw_action_dim, + guidance_scale=args.guidance_scale, + generator=generator, + use_system_prompt=False, + add_resolution_template=args.add_resolution_template, + add_duration_template=args.add_duration_template, + enable_safety_check=not args.no_safety_check, + ) + else: + image = load_image(args.vision_path) if args.vision_path is not None else None + result = pipeline( + prompt=args.prompt, + image=image, + num_frames=args.num_frames, + height=args.height, + width=args.width, + fps=args.fps, + num_inference_steps=args.num_inference_steps, + flow_shift=args.flow_shift, + enable_sound=args.enable_sound, + guidance_scale=args.guidance_scale, + generator=generator, + add_resolution_template=args.add_resolution_template, + add_duration_template=args.add_duration_template, + enable_safety_check=not args.no_safety_check, + ) if 
args.num_frames == 1: save_path = output_dir / "sample.jpg" @@ -145,6 +210,13 @@ def main(): export_to_video(result.video, str(save_path), fps=int(args.fps), quality=10, macro_block_size=1) print(f"Saved: {save_path}") + if result.action is not None: + for action in result.action: + action_path = output_dir / "sample_action.json" + with open(action_path, "w") as f: + json.dump(action.tolist(), f) + print(f"Saved: {action_path}") + if __name__ == "__main__": main() diff --git a/src/diffusers/models/transformers/transformer_cosmos3.py b/src/diffusers/models/transformers/transformer_cosmos3.py index 822d4f279e28..54fbe066ac33 100644 --- a/src/diffusers/models/transformers/transformer_cosmos3.py +++ b/src/diffusers/models/transformers/transformer_cosmos3.py @@ -146,6 +146,39 @@ def forward(self, x): return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) +class DomainAwareLinear(nn.Module): + """Linear projection with one weight/bias pair per embodiment domain.""" + + def __init__(self, input_size: int, output_size: int, num_domains: int) -> None: + super().__init__() + self.input_size = int(input_size) + self.output_size = int(output_size) + self.num_domains = int(num_domains) + self.fc = nn.Embedding(self.num_domains, self.output_size * self.input_size) + self.bias = nn.Embedding(self.num_domains, self.output_size) + nn.init.xavier_uniform_(self.fc.weight) + nn.init.zeros_(self.bias.weight) + + def forward(self, x: torch.Tensor, domain_id: torch.Tensor) -> torch.Tensor: + if domain_id.ndim == 0: + domain_id = domain_id.unsqueeze(0) + domain_id = domain_id.to(device=x.device, dtype=torch.long).reshape(-1) + if x.shape[0] != domain_id.shape[0]: + raise ValueError( + "Cosmos3 action domain_id batch size must match action tokens: " + f"tokens={x.shape[0]}, domain_id={domain_id.shape[0]}." 
+ ) + if torch.any((domain_id < 0) | (domain_id >= self.num_domains)): + raise ValueError(f"Cosmos3 action domain_id must be in [0, {self.num_domains}), got {domain_id.tolist()}.") + weight = self.fc(domain_id).view(domain_id.shape[0], self.input_size, self.output_size) + bias = self.bias(domain_id).view(domain_id.shape[0], self.output_size) + if x.ndim == 2: + return torch.bmm(x.unsqueeze(1), weight).squeeze(1) + bias + if x.ndim == 3: + return torch.bmm(x, weight) + bias.unsqueeze(1) + raise ValueError(f"Cosmos3 DomainAwareLinear expected rank-2 or rank-3 input, got {tuple(x.shape)}.") + + class Cosmos3PackedMoTAttention(nn.Module, AttentionModuleMixin): """Dual-pathway packed attention for Qwen3VL MoT — separate projections for understanding (causal) and generation (full) token streams.""" @@ -291,6 +324,9 @@ def __init__( rms_norm_eps: float = 1e-6, rope_scaling: dict | None = None, rope_theta: float = 5000000.0, + action_dim: int | None = None, + action_gen: bool = False, + num_embodiment_domains: int = 32, sound_dim: int | None = None, sound_gen: bool = False, sound_latent_fps: float = 25.0, @@ -333,6 +369,13 @@ def __init__( self.proj_out = nn.Linear(hidden_size, patch_latent_dim, bias=True) self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) self.time_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=hidden_size) + self.action_gen = action_gen + self.action_dim = int(32 if action_dim is None else action_dim) + self.num_embodiment_domains = int(num_embodiment_domains) + if action_gen: + self.action_proj_in = DomainAwareLinear(self.action_dim, hidden_size, self.num_embodiment_domains) + self.action_proj_out = DomainAwareLinear(hidden_size, self.action_dim, self.num_embodiment_domains) + self.action_modality_embed = nn.Parameter(torch.zeros(hidden_size)) if sound_gen: if sound_dim is None: raise ValueError("`sound_dim` must be provided when `sound_gen=True`.") @@ -464,9 +507,43 @@ def _unpack_sound_latents( 
unpacked.append(output) return unpacked + def _pack_action_latents( + self, + tokens_action: list[torch.Tensor], + token_shapes_action: list[tuple[int, int, int]], + domain_ids_action: list[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + """List of ``[T, D]`` tensors → packed ``[total_T, D]`` plus per-token domain ids.""" + packed: list[torch.Tensor] = [] + domain_ids: list[torch.Tensor] = [] + for action, shape, domain_id in zip(tokens_action, token_shapes_action, domain_ids_action): + token_count = shape[0] + packed.append(action[:token_count]) + domain_ids.append(domain_id.reshape(1).expand(token_count)) + return torch.cat(packed, dim=0), torch.cat(domain_ids, dim=0) + + def _unpack_action_latents( + self, + packed_preds: torch.Tensor, + token_shapes_action: list[tuple[int, int, int]], + noisy_frame_indexes_action: list[torch.Tensor], + ) -> list[torch.Tensor]: + """Packed ``[total_noisy_T, D]`` predictions → list of ``[T, D]`` tensors.""" + unpacked: list[torch.Tensor] = [] + start_idx = 0 + for shape, noisy_idxs in zip(token_shapes_action, noisy_frame_indexes_action): + T = shape[0] + output = torch.zeros((T, self.action_dim), device=packed_preds.device, dtype=packed_preds.dtype) + t_n = len(noisy_idxs) + if t_n > 0: + output[noisy_idxs] = packed_preds[start_idx : start_idx + t_n] + start_idx += t_n + unpacked.append(output) + return unpacked + # ------------------------------------------------------------------------- - # forward: full per-step pass — encode text/vision/sound → run layers → - # decode vision/sound. Pipeline calls this once per CFG pass. + # forward: full per-step pass — encode text/vision/sound/action → run layers → + # decode vision/sound/action. Pipeline calls this once per CFG pass. 
# ------------------------------------------------------------------------- def forward( @@ -488,7 +565,14 @@ def forward( sound_mse_loss_indexes: torch.Tensor | None = None, sound_timesteps: torch.Tensor | None = None, sound_noisy_frame_indexes: list[torch.Tensor] | None = None, - ) -> tuple[list[torch.Tensor], list[torch.Tensor] | None]: + action_tokens: list[torch.Tensor] | None = None, + action_token_shapes: list[tuple[int, int, int]] | None = None, + action_sequence_indexes: torch.Tensor | None = None, + action_mse_loss_indexes: torch.Tensor | None = None, + action_timesteps: torch.Tensor | None = None, + action_noisy_frame_indexes: list[torch.Tensor] | None = None, + action_domain_ids: list[torch.Tensor] | None = None, + ) -> tuple[list[torch.Tensor], list[torch.Tensor] | None, list[torch.Tensor] | None]: """Run a full denoising-step forward pass. Args: @@ -511,10 +595,11 @@ def forward( sound_noisy_frame_indexes: Optional noisy frame indices per sound item. Returns: - ``(preds_vision, preds_sound)`` — list of per-modality latents (``preds_sound`` is ``None`` when the model - has no sound branch or sound inputs are omitted). + ``(preds_vision, preds_sound, preds_action)`` — lists of per-modality predictions. Optional modalities + return ``None`` when their inputs are omitted. """ has_sound = sound_tokens is not None and sound_sequence_indexes is not None + has_action = action_tokens is not None and action_sequence_indexes is not None # Embed text tokens into the joint hidden_states buffer at their sequence positions. packed_text_embedding = self.embed_tokens(input_ids) @@ -551,6 +636,27 @@ def forward( ) hidden_states[sound_sequence_indexes] = packed_tokens_sound + # Pack + project action latents (when present). Domain ids select the action head weights. 
+ if has_action: + packed_tokens_action, per_token_domain_ids = self._pack_action_latents( + action_tokens, action_token_shapes, action_domain_ids + ) + packed_tokens_action = packed_tokens_action.to(target_dtype) + per_token_domain_ids = per_token_domain_ids.to(device=packed_tokens_action.device) + packed_tokens_action = self.action_proj_in(packed_tokens_action, per_token_domain_ids) + packed_tokens_action = packed_tokens_action + self.action_modality_embed + if action_mse_loss_indexes.numel() > 0: + timesteps_action = action_timesteps * self.config.timestep_scale + packed_timestep_embeds_action = self.time_embedder(self.time_proj(timesteps_action)) + packed_timestep_embeds_action = packed_timestep_embeds_action.to(target_dtype) + packed_tokens_action = self._apply_timestep_embeds_to_noisy_tokens( + packed_tokens=packed_tokens_action, + packed_timestep_embeds=packed_timestep_embeds_action, + noisy_frame_indexes=action_noisy_frame_indexes, + token_shapes=action_token_shapes, + ) + hidden_states[action_sequence_indexes] = packed_tokens_action + # Compute rotary embeddings once for the joint sequence, then slice into und/gen halves. 
_meta_tensor = torch.tensor([], dtype=hidden_states.dtype, device=hidden_states.device) cos, sin = self.rotary_emb( @@ -590,4 +696,18 @@ def forward( preds_sound_packed = self.audio_proj_out(last_hidden_state[sound_mse_loss_indexes]) preds_sound = self._unpack_sound_latents(preds_sound_packed, sound_token_shapes, sound_noisy_frame_indexes) - return preds_vision, preds_sound + preds_action: list[torch.Tensor] | None = None + if has_action: + per_noisy_domain_ids = [ + domain_id.reshape(1).expand(len(noisy_idxs)) + for domain_id, noisy_idxs in zip(action_domain_ids, action_noisy_frame_indexes) + ] + per_noisy_domain_ids = torch.cat(per_noisy_domain_ids, dim=0).to(device=last_hidden_state.device) + preds_action_packed = self.action_proj_out( + last_hidden_state[action_mse_loss_indexes], per_noisy_domain_ids + ) + preds_action = self._unpack_action_latents( + preds_action_packed, action_token_shapes, action_noisy_frame_indexes + ) + + return preds_vision, preds_sound, preds_action diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 7225cce6ac9b..2bf831c7dd6d 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -19,6 +19,7 @@ import numpy as np import torch +import torch.nn.functional as F from PIL import Image from transformers import AutoTokenizer, BatchEncoding @@ -130,6 +131,62 @@ def get_3d_mrope_ids_vae_tokens( _SYSTEM_PROMPT_IMAGE = "You are a helpful assistant who will generate images from a give prompt." _SYSTEM_PROMPT_VIDEO = "You are a helpful assistant who will generate videos from a give prompt." 
+_ACTION_MODE_FORWARD_DYNAMICS = "forward_dynamics" +_ACTION_MODE_INVERSE_DYNAMICS = "inverse_dynamics" +_ACTION_MODE_POLICY = "policy" +_ACTION_MODES = {_ACTION_MODE_FORWARD_DYNAMICS, _ACTION_MODE_INVERSE_DYNAMICS, _ACTION_MODE_POLICY} + +_ACTION_RESOLUTION_BINS = { + "256": { + "1.0": (256, 256), + "0.8": (256, 320), + "1.25": (320, 256), + "0.6": (192, 320), + "1.6666666666666667": (320, 192), + }, + "480": { + "1.0": (640, 640), + "0.7391304347826086": (544, 736), + "1.3529411764705883": (736, 544), + "0.5769230769230769": (480, 832), + "1.7333333333333334": (832, 480), + }, + "704": { + "1.0": (960, 960), + "0.7647058823529411": (832, 1088), + "1.3076923076923077": (1088, 832), + "0.55": (704, 1280), + "1.8181818181818181": (1280, 704), + }, + "720": { + "1.0": (960, 960), + "0.7536231884057971": (832, 1104), + "1.3269230769230769": (1104, 832), + "0.5625": (720, 1280), + "1.7777777777777777": (1280, 720), + }, +} + +_EMBODIMENT_TO_DOMAIN_ID = { + "no_action": 0, + "av": 1, + "camera_pose": 2, + "hand_pose": 3, + "pusht": 4, + "libero": 5, + "umi": 6, + "bridge_orig_lerobot": 7, + "droid_lerobot": 8, + "robomind-franka": 8, + "galbot": 9, + "robomind-franka-dual": 12, + "robomind-ur": 13, + "agibotworld": 15, + "agibot_gear_gripper": 15, + "agibot_gear_gripper_ext": 15, + "fractal": 20, +} + @dataclass class Cosmos3OmniPipelineOutput(BaseOutput): @@ -142,10 +199,12 @@ class Cosmos3OmniPipelineOutput(BaseOutput): when ``output_type="latent"``. sound: Decoded audio waveform of shape ``[C, N]``. ``None`` when ``enable_sound=False``. + action: Predicted action tokens. ``None`` unless an action mode predicts actions. 
""" video: Any sound: torch.Tensor | None = None + action: list[torch.Tensor] | None = None # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents @@ -308,6 +367,7 @@ def _prepare_vision_segment( vision_fps: float | None, curr: int, device: torch.device | str, + condition_frame_indexes: list[int] | None = None, ) -> dict[str, Any]: """Build the static portion of the vision segment of the joint sequence. @@ -322,12 +382,16 @@ def _prepare_vision_segment( patch_w = math.ceil(latent_w / latent_patch_size) num_vision_tokens = latent_t * patch_h * patch_w - noisy_start = 1 if has_image_condition else 0 - noisy_frame_indexes = torch.arange(noisy_start, latent_t, device=device, dtype=torch.long) + if condition_frame_indexes is None: + condition_frame_indexes = [0] if has_image_condition else [] + cond_frames = {idx for idx in condition_frame_indexes if 0 <= idx < latent_t} + noisy_frame_indexes = torch.tensor( + [idx for idx in range(latent_t) if idx not in cond_frames], device=device, dtype=torch.long + ) frame_token_stride = patch_h * patch_w mse_loss_indexes: list[int] = [] - for frame_idx in range(noisy_start, latent_t): + for frame_idx in noisy_frame_indexes.tolist(): frame_start = curr + frame_idx * frame_token_stride mse_loss_indexes.extend(range(frame_start, frame_start + frame_token_stride)) @@ -352,7 +416,7 @@ def _prepare_vision_segment( # Assembly helpers (consumed inline before the transformer call). 
"vision_mrope_ids": vision_mrope_ids.to(device), "num_vision_tokens": num_vision_tokens, - "num_noisy_vision_tokens": (latent_t - noisy_start) * frame_token_stride, + "num_noisy_vision_tokens": len(noisy_frame_indexes) * frame_token_stride, } def _prepare_sound_segment( @@ -396,37 +460,163 @@ def _prepare_sound_segment( "sound_len": sound_len, } + def _pack_action_tokens( + self, + input_action_tokens: torch.Tensor, + condition_frame_indexes: list[int], + mrope_offset: int | float, + action_fps: float | None, + curr: int, + device: torch.device | str, + ) -> dict[str, Any]: + """Build the static action segment; per-step tokens/timesteps are spliced in the denoising loop.""" + config = self.transformer.config + action_len = input_action_tokens.shape[0] + cond_frames = {idx for idx in condition_frame_indexes if 0 <= idx < action_len} + noisy_frame_indexes = torch.tensor( + [idx for idx in range(action_len) if idx not in cond_frames], device=device, dtype=torch.long + ) + + effective_fps = action_fps if config.enable_fps_modulation else None + action_mrope_ids, _ = get_3d_mrope_ids_vae_tokens( + grid_t=action_len, + grid_h=1, + grid_w=1, + temporal_offset=mrope_offset, + reset_spatial_indices=config.unified_3d_mrope_reset_spatial_ids, + fps=effective_fps, + base_fps=float(config.base_fps), + temporal_compression_factor=1, + start_frame_offset=1, + ) + + sequence_indexes = torch.arange(curr, curr + action_len, dtype=torch.long, device=device) + return { + "action_token_shapes": [(action_len, 1, 1)], + "action_sequence_indexes": sequence_indexes, + "action_mse_loss_indexes": sequence_indexes[noisy_frame_indexes], + "action_noisy_frame_indexes": [noisy_frame_indexes], + "action_mrope_ids": action_mrope_ids.to(device), + "action_len": action_len, + "num_noisy_action_tokens": len(noisy_frame_indexes), + } + + def _get_action_target_size( + self, + source_height: int, + source_width: int, + requested_height: int, + requested_width: int, + ) -> tuple[int, int]: + 
resolution_key = str(min(requested_height, requested_width)) + if resolution_key not in _ACTION_RESOLUTION_BINS: + raise ValueError( + f"Cosmos3 action resolution binning only supports {sorted(_ACTION_RESOLUTION_BINS)}, " + f"got height={requested_height}, width={requested_width}." + ) + return self.video_processor.classify_height_width_bin( + source_height, + source_width, + ratios=_ACTION_RESOLUTION_BINS[resolution_key], + ) + + def _prepare_action_video_conditioning( + self, + video: Any, + height: int, + width: int, + num_frames: int, + device: torch.device | str, + dtype: torch.dtype, + ) -> tuple[torch.Tensor, torch.Tensor, int, int]: + frames = self.video_processor.preprocess_video(video).to(device=device, dtype=dtype) + source_h, source_w = frames.shape[-2:] + target_h, target_w = self._get_action_target_size(source_h, source_w, height, width) + + if frames.shape[2] < num_frames: + frames = torch.cat([frames, frames[:, :, -1:].expand(-1, -1, num_frames - frames.shape[2], -1, -1)], dim=2) + else: + frames = frames[:, :, :num_frames] + + _, _, _, frame_h, frame_w = frames.shape + scale = min(target_w / frame_w, target_h / frame_h, 1.0) + content_h = max(1, int(scale * frame_h + 0.5)) + content_w = max(1, int(scale * frame_w + 0.5)) + + frames_t = frames.permute(0, 2, 1, 3, 4).reshape(-1, frames.shape[1], frame_h, frame_w) + if content_h != frame_h or content_w != frame_w: + frames_t = F.interpolate( + frames_t, + size=(content_h, content_w), + mode="bicubic", + align_corners=False, + antialias=True, + ) + pad_right = target_w - content_w + pad_bottom = target_h - content_h + if pad_right or pad_bottom: + pad_mode = "replicate" if pad_right >= content_w or pad_bottom >= content_h else "reflect" + frames_t = F.pad(frames_t, (0, pad_right, 0, pad_bottom), mode=pad_mode) + frames = frames_t.reshape(frames.shape[0], num_frames, frames.shape[1], target_h, target_w).permute( + 0, 2, 1, 3, 4 + ) + image_size = torch.tensor([target_h, target_w, content_h, content_w], 
device=device, dtype=torch.float32) + return frames.to(dtype=dtype), image_size, target_h, target_w + + def _remove_action_video_padding_from_latent( + self, latents: torch.Tensor, image_size: torch.Tensor + ) -> torch.Tensor: + content_h = int(image_size[2].item()) + content_w = int(image_size[3].item()) + content_h_latent = max(content_h // self.vae_scale_factor_spatial, 1) + content_w_latent = max(content_w // self.vae_scale_factor_spatial, 1) + return latents[:, :, :, :content_h_latent, :content_w_latent].contiguous() + + def _remove_action_video_padding_from_video(self, video: torch.Tensor, image_size: torch.Tensor) -> torch.Tensor: + content_h = int(image_size[2].item()) + content_w = int(image_size[3].item()) + return video[:, :, :, :content_h, :content_w].contiguous() + def prepare_latents( self, image: torch.Tensor | None = None, + video: Any | None = None, num_frames: int = 189, height: int = 720, width: int = 1280, fps: float = 24.0, latents: torch.Tensor | None = None, sound_latents: torch.Tensor | None = None, + action_latents: torch.Tensor | None = None, generator: torch.Generator | None = None, device: str = "cuda", dtype: torch.dtype = torch.bfloat16, enable_sound: bool = False, + action_mode: str | None = None, + action: torch.Tensor | None = None, + action_chunk_size: int | None = None, + domain_name: str | None = None, + raw_action_dim: int | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, + torch.Tensor | None, float, float | None, torch.Tensor, torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + int | None, ]: """Build conditioning + initial noise for a single sample. Returns: - ``(vision_latents, sound_latents, fps_vision, fps_sound)``. ``vision_latents`` is the noisy vision tensor; - ``sound_latents`` is the noisy sound tensor (``None`` unless ``enable_sound`` was set). 
The FPS scalars - feed the per-step :meth:`_prepare_vision_segment` / :meth:`_prepare_sound_segment` calls in the denoising - loop. + Initial noisy tensors plus condition masks/metadata for vision, sound, and optional action modalities. """ is_image = num_frames == 1 - has_image_condition = image is not None and not is_image + has_image_condition = (image is not None and not is_image) or action_mode is not None # video_processor.preprocess handles PIL/np/tensor → [1, 3, H, W] in [-1, 1], resized to (height, width). conditioning_frame_2d: torch.Tensor | None = None @@ -435,8 +625,41 @@ def prepare_latents( device=device, dtype=dtype ) + action_domain_id: torch.Tensor | None = None + action_condition_mask: torch.Tensor | None = None + raw_action_dim_resolved: int | None = int(raw_action_dim) if raw_action_dim is not None else None + action_condition_frames: list[int] = [] + action_condition_frame_indexes: list[int] = [] + action_image_size: torch.Tensor | None = None + vision_condition_frames: list[int] | None = None + # Build the vision conditioning tensor (always [1, 3, T, H, W], in [-1, 1], on device). - if is_image: + if action_mode is not None: + if action_chunk_size is None: + raise ValueError("action_mode requires action_chunk_size.") + if video is None: + raise ValueError(f"action_mode={action_mode!r} requires loaded video conditioning.") + target_frames = action_chunk_size + 1 + if num_frames != target_frames: + raise ValueError( + "Action runs require num_frames to equal action_chunk_size + 1; " + f"got num_frames={num_frames}, action_chunk_size={action_chunk_size}." 
+ ) + vision_tensor, action_image_size, height, width = self._prepare_action_video_conditioning( + video, height, width, target_frames, device=device, dtype=dtype + ) + if action_mode == _ACTION_MODE_FORWARD_DYNAMICS: + vision_condition_frames = [0] + action_condition_frames = list(range(action_chunk_size)) + elif action_mode == _ACTION_MODE_POLICY: + vision_condition_frames = [0] + elif action_mode == _ACTION_MODE_INVERSE_DYNAMICS: + latent_frames = (target_frames - 1) // self.vae.config.scale_factor_temporal + 1 + vision_condition_frames = list(range(latent_frames)) + else: + raise ValueError(f"Unsupported action_mode={action_mode!r}; expected one of {sorted(_ACTION_MODES)}.") + action_condition_frame_indexes = action_condition_frames + elif is_image: vision_tensor = ( conditioning_frame_2d.unsqueeze(2) # [1, 3, 1, H, W] if conditioning_frame_2d is not None @@ -451,6 +674,8 @@ def prepare_latents( vision_tensor[:, :, 1:] = conditioning_frame_2d.unsqueeze(2).expand(-1, -1, num_frames - 1, -1, -1) x0_tokens_vision = self._encode_video(vision_tensor).contiguous().float() + if action_image_size is not None: + x0_tokens_vision = self._remove_action_video_padding_from_latent(x0_tokens_vision, action_image_size) vision_shape = tuple(x0_tokens_vision.shape) x0_tokens_sound: torch.Tensor | None = None @@ -463,9 +688,60 @@ def prepare_latents( T_sound = (n_audio_samples + hop_size - 1) // hop_size x0_tokens_sound = torch.zeros(sound_dim, T_sound, device=device, dtype=dtype) + x0_tokens_action: torch.Tensor | None = None + if action_mode is not None: + assert action_chunk_size is not None + action_dim = self.transformer.action_dim + if action_mode == "forward_dynamics": + if action is None: + raise ValueError("action_mode='forward_dynamics' requires an action tensor.") + action = action.to(device=device, dtype=dtype) + if action.shape[0] == 0: + raise ValueError("action_mode='forward_dynamics' requires at least one action token.") + + # Action chunks describe transitions, 
so action length must match action_chunk_size + # while the paired video has action_chunk_size + 1 frames. Short inputs repeat the last action. + if action.shape[0] < action_chunk_size: + action = torch.cat( + [action, action[-1:].expand(action_chunk_size - action.shape[0], -1)], + dim=0, + ) + action = action[:action_chunk_size] + + # The model action head has a fixed action_dim; pad raw domain actions with zeros on the channel axis. + if action.shape[-1] > action_dim: + raise ValueError( + f"Cosmos3 action dimension {action.shape[-1]} exceeds model action_dim={action_dim}." + ) + if action.shape[-1] < action_dim: + action_padding = torch.zeros( + action.shape[0], + action_dim - action.shape[-1], + dtype=action.dtype, + device=action.device, + ) + action = torch.cat([action, action_padding], dim=-1) + x0_tokens_action = action + else: + x0_tokens_action = torch.zeros(action_chunk_size, action_dim, device=device, dtype=dtype) + if domain_name not in _EMBODIMENT_TO_DOMAIN_ID: + raise ValueError( + f"Unknown Cosmos3 action domain_name={domain_name!r}; " + f"expected one of {sorted(_EMBODIMENT_TO_DOMAIN_ID)}." + ) + action_domain_id = torch.tensor( + [_EMBODIMENT_TO_DOMAIN_ID[domain_name]], + dtype=torch.long, + device=device, + ) + # Vision conditioning mask [latent_t, 1, 1]: frame 0 anchored when image-conditioning, rest noisy. 
vision_condition_mask = torch.zeros((x0_tokens_vision.shape[2], 1, 1), device=device, dtype=dtype) - if has_image_condition: + if vision_condition_frames is not None: + for frame_idx in vision_condition_frames: + if 0 <= frame_idx < vision_condition_mask.shape[0]: + vision_condition_mask[frame_idx, 0, 0] = 1.0 + elif has_image_condition: vision_condition_mask[0, 0, 0] = 1.0 if latents is None: @@ -491,17 +767,55 @@ def prepare_latents( else: sound_latents = sound_latents.to(device=device, dtype=dtype) - return latents, sound_latents, fps, fps_sound, vision_condition_mask, sound_condition_mask + if action_mode is not None and x0_tokens_action is not None: + action_condition_mask = torch.zeros((x0_tokens_action.shape[0], 1), device=device, dtype=dtype) + for frame_idx in action_condition_frames: + if 0 <= frame_idx < action_condition_mask.shape[0]: + action_condition_mask[frame_idx, 0] = 1.0 + if action_latents is None: + pure_noise_action = randn_tensor( + tuple(x0_tokens_action.shape), generator=generator, device=device, dtype=dtype + ) + action_latents = ( + action_condition_mask * x0_tokens_action + (1.0 - action_condition_mask) * pure_noise_action + ) + if raw_action_dim_resolved is not None: + action_latents[:, raw_action_dim_resolved:] = 0 + else: + action_latents = action_latents.to(device=device, dtype=dtype) + + return ( + latents, + sound_latents, + action_latents, + fps, + fps_sound, + vision_condition_mask, + sound_condition_mask, + action_condition_mask, + action_domain_id, + action_image_size, + raw_action_dim_resolved, + action_condition_frame_indexes, + ) def check_inputs( self, prompt, negative_prompt, + image, + video, height: int, width: int, num_frames: int, + guidance_scale: float, enable_sound: bool, callback_on_step_end_tensor_inputs: list[str], + action_mode: str | None, + action: torch.Tensor | None, + action_chunk_size: int | None, + domain_name: str | None, + raw_action_dim: int | None, ) -> None: if not isinstance(prompt, (str, list)) or 
( isinstance(prompt, list) and not all(isinstance(p, str) for p in prompt) @@ -526,6 +840,31 @@ def check_inputs( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) + if action_mode is not None: + if not getattr(self.transformer.config, "action_gen", False): + raise ValueError("action_mode requires a transformer trained with action_gen=True.") + if image is not None: + raise ValueError("Use `video`, not `image`, for Cosmos3 action conditioning.") + if video is None: + raise ValueError(f"action_mode={action_mode!r} requires a loaded conditioning video.") + if action_chunk_size is None: + raise ValueError("action_mode requires action_chunk_size.") + if num_frames != action_chunk_size + 1: + raise ValueError( + "Action runs require num_frames to equal action_chunk_size + 1; " + f"got num_frames={num_frames}, action_chunk_size={action_chunk_size}." + ) + if domain_name is None: + raise ValueError("action_mode requires domain_name.") + if domain_name not in _EMBODIMENT_TO_DOMAIN_ID: + raise ValueError( + f"Unknown Cosmos3 action domain_name={domain_name!r}; " + f"expected one of {sorted(_EMBODIMENT_TO_DOMAIN_ID)}." + ) + if action_mode == "forward_dynamics" and action is None: + raise ValueError("action_mode='forward_dynamics' requires an action tensor.") + if action_mode in {"inverse_dynamics", "policy"} and raw_action_dim is None: + raise ValueError(f"action_mode={action_mode!r} requires raw_action_dim for output slicing.") def tokenize_prompt( self, @@ -538,6 +877,7 @@ def tokenize_prompt( use_system_prompt: bool = True, add_resolution_template: bool = True, add_duration_template: bool = True, + action_mode: str | None = None, ) -> tuple[list[int], list[int]]: """Apply prompt-augmentation templates and tokenize cond/uncond prompts via the Qwen2 chat template. 
@@ -606,7 +946,10 @@ def _mask_velocity_predictions( preds_sound: list[torch.Tensor] | None, vision_condition_mask: list[torch.Tensor], sound_condition_mask: list[torch.Tensor] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + preds_action: list[torch.Tensor] | None = None, + action_condition_mask: list[torch.Tensor] | None = None, + raw_action_dim: int | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: """Zero out conditioning positions in the transformer's velocity predictions. ``preds_vision`` / ``preds_sound`` are returned per-sample by the transformer; the pipeline runs batch=1, so we @@ -625,7 +968,16 @@ def _mask_velocity_predictions( noisy_mask_s = (1.0 - cond_mask_s).T.to(dtype=pred_s.dtype, device=pred_s.device) velocity_sound = pred_s * noisy_mask_s if noisy_mask_s.sum() > 0 else torch.zeros_like(pred_s) - return velocity_vision, velocity_sound + velocity_action: torch.Tensor | None = None + if preds_action is not None and action_condition_mask is not None: + pred_a = preds_action[0] + cond_mask_a = action_condition_mask[0] + noisy_mask_a = (1.0 - cond_mask_a).to(dtype=pred_a.dtype, device=pred_a.device) + velocity_action = pred_a * noisy_mask_a if noisy_mask_a.sum() > 0 else torch.zeros_like(pred_a) + if raw_action_dim is not None: + velocity_action[:, raw_action_dim:] = 0 + + return velocity_vision, velocity_sound, velocity_action def _apply_video_safety_check(self, video: Any, output_type: str, device: torch.device) -> Any: """Run the Cosmos video guardrail on a postprocessed video and return it in the same format. 
@@ -676,16 +1028,24 @@ def __call__( prompt: str | list[str], negative_prompt: str | list[str] | None = None, image: torch.Tensor | None = None, + video: Any | None = None, num_frames: int = 189, height: int = 720, width: int = 1280, fps: float = 24.0, num_inference_steps: int = 35, guidance_scale: float = 6.0, + flow_shift: float | None = None, enable_sound: bool = False, generator: torch.Generator | None = None, latents: torch.Tensor | None = None, sound_latents: torch.Tensor | None = None, + action_latents: torch.Tensor | None = None, + action_mode: str | None = None, + action: torch.Tensor | None = None, + action_chunk_size: int | None = None, + domain_name: str | None = None, + raw_action_dim: int | None = None, output_type: str = "pil", return_dict: bool = True, use_system_prompt: bool = True, @@ -770,9 +1130,28 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if action_mode is not None and action_mode not in _ACTION_MODES: + raise ValueError(f"Unsupported action_mode={action_mode!r}; expected one of {sorted(_ACTION_MODES)}.") + if action_mode is not None and action_chunk_size is not None: + num_frames = action_chunk_size + 1 + # 1. Check inputs self.check_inputs( - prompt, negative_prompt, height, width, num_frames, enable_sound, callback_on_step_end_tensor_inputs + prompt, + negative_prompt, + image, + video, + height, + width, + num_frames, + guidance_scale, + enable_sound, + callback_on_step_end_tensor_inputs, + action_mode, + action, + action_chunk_size, + domain_name, + raw_action_dim, ) self._current_timestep = None @@ -809,6 +1188,7 @@ def __call__( use_system_prompt=use_system_prompt, add_resolution_template=add_resolution_template, add_duration_template=add_duration_template, + action_mode=action_mode, ) # 3. 
Pre-pack the text segment for each prompt — text packing is invariant @@ -817,22 +1197,42 @@ def __call__( uncond_text_segment = self._prepare_text_segment(uncond_input_ids, device=device) # 4. Prepare latents (initial noise per modality + pack metadata) - has_image_condition = image is not None and num_frames > 1 - latents, sound_latents, fps_vision, fps_sound, vision_condition_mask, sound_condition_mask = ( - self.prepare_latents( - image=image, - num_frames=num_frames, - height=height, - width=width, - fps=fps, - latents=latents, - sound_latents=sound_latents, - generator=generator, - device=device, - dtype=dtype, - enable_sound=enable_sound, - ) + ( + latents, + sound_latents, + action_latents, + fps_vision, + fps_sound, + vision_condition_mask, + sound_condition_mask, + action_condition_mask, + action_domain_id, + action_image_size, + raw_action_dim_resolved, + action_condition_frame_indexes, + ) = self.prepare_latents( + image=image, + video=video, + num_frames=num_frames, + height=height, + width=width, + fps=fps, + latents=latents, + sound_latents=sound_latents, + action_latents=action_latents, + generator=generator, + device=device, + dtype=dtype, + enable_sound=enable_sound, + action_mode=action_mode, + action=action, + action_chunk_size=action_chunk_size, + domain_name=domain_name, + raw_action_dim=raw_action_dim, ) + vision_condition_indexes_for_pack = torch.nonzero(vision_condition_mask[:, 0, 0] > 0, as_tuple=False).flatten() + vision_condition_indexes_for_pack = [int(idx.item()) for idx in vision_condition_indexes_for_pack] + has_image_condition = bool(vision_condition_indexes_for_pack) # 5. Pre-pack the static per-prompt vision / sound sequence segments. 
The only # fields that vary across denoising steps are the modality token tensors and the @@ -846,6 +1246,7 @@ def __call__( vision_fps=fps_vision, curr=cond_text_segment["und_len"], device=device, + condition_frame_indexes=vision_condition_indexes_for_pack, ) cond_sound_segment: dict[str, Any] = {} if sound_latents is not None: @@ -856,17 +1257,33 @@ def __call__( curr=cond_text_segment["und_len"] + cond_vision_segment["num_vision_tokens"], device=device, ) + cond_action_segment: dict[str, Any] = {} + if action_latents is not None: + cond_action_segment = self._pack_action_tokens( + input_action_tokens=action_latents, + condition_frame_indexes=action_condition_frame_indexes, + mrope_offset=cond_text_segment["vision_start_temporal_offset"], + action_fps=fps_vision, + curr=cond_text_segment["und_len"] + + cond_vision_segment["num_vision_tokens"] + + cond_sound_segment.get("sound_len", 0), + device=device, + ) cond_mrope_segments = [cond_text_segment["text_mrope_ids"], cond_vision_segment["vision_mrope_ids"]] if cond_sound_segment: cond_mrope_segments.append(cond_sound_segment["sound_mrope_ids"]) + if cond_action_segment: + cond_mrope_segments.append(cond_action_segment["action_mrope_ids"]) cond_packed_static = { **cond_text_segment, **cond_vision_segment, **cond_sound_segment, + **cond_action_segment, "position_ids": torch.cat(cond_mrope_segments, dim=1), "sequence_length": cond_text_segment["und_len"] + cond_vision_segment["num_vision_tokens"] - + cond_sound_segment.get("sound_len", 0), + + cond_sound_segment.get("sound_len", 0) + + cond_action_segment.get("action_len", 0), } uncond_vision_segment = self._prepare_vision_segment( @@ -876,6 +1293,7 @@ def __call__( vision_fps=fps_vision, curr=uncond_text_segment["und_len"], device=device, + condition_frame_indexes=vision_condition_indexes_for_pack, ) uncond_sound_segment: dict[str, Any] = {} if sound_latents is not None: @@ -886,29 +1304,58 @@ def __call__( curr=uncond_text_segment["und_len"] + 
uncond_vision_segment["num_vision_tokens"], device=device, ) + uncond_action_segment: dict[str, Any] = {} + if action_latents is not None: + uncond_action_segment = self._pack_action_tokens( + input_action_tokens=action_latents, + condition_frame_indexes=action_condition_frame_indexes, + mrope_offset=uncond_text_segment["vision_start_temporal_offset"], + action_fps=fps_vision, + curr=uncond_text_segment["und_len"] + + uncond_vision_segment["num_vision_tokens"] + + uncond_sound_segment.get("sound_len", 0), + device=device, + ) uncond_mrope_segments = [uncond_text_segment["text_mrope_ids"], uncond_vision_segment["vision_mrope_ids"]] if uncond_sound_segment: uncond_mrope_segments.append(uncond_sound_segment["sound_mrope_ids"]) + if uncond_action_segment: + uncond_mrope_segments.append(uncond_action_segment["action_mrope_ids"]) uncond_packed_static = { **uncond_text_segment, **uncond_vision_segment, **uncond_sound_segment, + **uncond_action_segment, "position_ids": torch.cat(uncond_mrope_segments, dim=1), "sequence_length": uncond_text_segment["und_len"] + uncond_vision_segment["num_vision_tokens"] - + uncond_sound_segment.get("sound_len", 0), + + uncond_sound_segment.get("sound_len", 0) + + uncond_action_segment.get("action_len", 0), } num_noisy_vision_tokens = cond_vision_segment["num_noisy_vision_tokens"] sound_len = cond_sound_segment.get("sound_len") + action_noisy_len = cond_action_segment.get("num_noisy_action_tokens") # 6. Set timesteps. UniPCMultistepScheduler keeps per-step state (_step_index, - # model_outputs history) on the instance, so audio gets its own copy. - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - sound_scheduler = copy.deepcopy(self.scheduler) if sound_latents is not None else None + # model_outputs history) on the instance, so audio/action gets its own copy. 
+ inference_scheduler = copy.deepcopy(self.scheduler) + if flow_shift is not None: + inference_scheduler.register_to_config( + use_flow_sigmas=True, + use_karras_sigmas=False, + use_exponential_sigmas=False, + use_beta_sigmas=False, + flow_shift=flow_shift, + shift_terminal=None, + final_sigmas_type="zero", + ) + inference_scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = inference_scheduler.timesteps + sound_scheduler = copy.deepcopy(inference_scheduler) if sound_latents is not None else None + action_scheduler = copy.deepcopy(inference_scheduler) if action_latents is not None else None # 7. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + num_warmup_steps = len(timesteps) - num_inference_steps * inference_scheduler.order self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -922,15 +1369,19 @@ def __call__( # noisy tokens before packing so the modality tokens enter the model in the right dtype. vision_tokens = latents.to(device=device, dtype=dtype) sound_tokens = sound_latents.to(device=device, dtype=dtype) if sound_latents is not None else None + action_tokens = action_latents.to(device=device, dtype=dtype) if action_latents is not None else None # The static packs both report the same num_noisy_vision_tokens / sound_len, so a # single per-step timestep tensor per modality is shared by the cond / uncond passes. 
vision_timesteps = torch.full((num_noisy_vision_tokens,), timestep, device=device) sound_timesteps = ( torch.full((sound_len,), timestep, device=device) if sound_tokens is not None else None ) + action_timesteps = ( + torch.full((action_noisy_len,), timestep, device=device) if action_tokens is not None else None + ) # --- Conditional pass --- - preds_vision, preds_sound = self.transformer( + preds_vision, preds_sound, preds_action = self.transformer( input_ids=cond_packed_static["input_ids"], text_indexes=cond_packed_static["text_indexes"], position_ids=cond_packed_static["position_ids"], @@ -948,17 +1399,28 @@ def __call__( sound_mse_loss_indexes=cond_packed_static.get("sound_mse_loss_indexes"), sound_timesteps=sound_timesteps, sound_noisy_frame_indexes=cond_packed_static.get("sound_noisy_frame_indexes"), + action_tokens=[action_tokens] if action_tokens is not None else None, + action_token_shapes=cond_packed_static.get("action_token_shapes"), + action_sequence_indexes=cond_packed_static.get("action_sequence_indexes"), + action_mse_loss_indexes=cond_packed_static.get("action_mse_loss_indexes"), + action_timesteps=action_timesteps, + action_noisy_frame_indexes=cond_packed_static.get("action_noisy_frame_indexes"), + action_domain_ids=[action_domain_id] if action_domain_id is not None else None, ) - cond_v_vision, cond_v_sound = self._mask_velocity_predictions( + cond_v_vision, cond_v_sound, cond_v_action = self._mask_velocity_predictions( preds_vision, preds_sound, vision_condition_mask=[vision_condition_mask], sound_condition_mask=[sound_condition_mask] if sound_condition_mask is not None else None, + preds_action=preds_action, + action_condition_mask=[action_condition_mask] if action_condition_mask is not None else None, + raw_action_dim=raw_action_dim_resolved, ) # --- Unconditional pass (Skip if not using CFG) --- + uncond_v_vision = uncond_v_sound = uncond_v_action = None if guidance_scale != 1.0: - preds_vision, preds_sound = self.transformer( + preds_vision, 
preds_sound, preds_action = self.transformer( input_ids=uncond_packed_static["input_ids"], text_indexes=uncond_packed_static["text_indexes"], position_ids=uncond_packed_static["position_ids"], @@ -976,12 +1438,22 @@ def __call__( sound_mse_loss_indexes=uncond_packed_static.get("sound_mse_loss_indexes"), sound_timesteps=sound_timesteps, sound_noisy_frame_indexes=uncond_packed_static.get("sound_noisy_frame_indexes"), + action_tokens=[action_tokens] if action_tokens is not None else None, + action_token_shapes=uncond_packed_static.get("action_token_shapes"), + action_sequence_indexes=uncond_packed_static.get("action_sequence_indexes"), + action_mse_loss_indexes=uncond_packed_static.get("action_mse_loss_indexes"), + action_timesteps=action_timesteps, + action_noisy_frame_indexes=uncond_packed_static.get("action_noisy_frame_indexes"), + action_domain_ids=[action_domain_id] if action_domain_id is not None else None, ) - uncond_v_vision, uncond_v_sound = self._mask_velocity_predictions( + uncond_v_vision, uncond_v_sound, uncond_v_action = self._mask_velocity_predictions( preds_vision, preds_sound, vision_condition_mask=[vision_condition_mask], sound_condition_mask=[sound_condition_mask] if sound_condition_mask is not None else None, + preds_action=preds_action, + action_condition_mask=[action_condition_mask] if action_condition_mask is not None else None, + raw_action_dim=raw_action_dim_resolved, ) # --- CFG combine + per-modality scheduler step --- @@ -994,7 +1466,7 @@ def __call__( else: velocity_vision = cond_v_vision - latents = self.scheduler.step( + latents = inference_scheduler.step( velocity_vision.unsqueeze(0), t, latents.unsqueeze(0), return_dict=False )[0].squeeze(0) @@ -1008,18 +1480,40 @@ def __call__( velocity_sound.unsqueeze(0), t, sound_latents.unsqueeze(0), return_dict=False )[0].squeeze(0) + has_noisy_action = ( + action_condition_mask is not None and action_condition_mask.sum() < action_condition_mask.numel() + ) + if action_scheduler is not None and 
has_noisy_action and cond_v_action is not None: + if guidance_scale != 1.0: + velocity_action = uncond_v_action + guidance_scale * (cond_v_action - uncond_v_action) + else: + velocity_action = cond_v_action + action_latents = action_scheduler.step( + velocity_action.unsqueeze(0), t, action_latents.unsqueeze(0), return_dict=False + )[0].squeeze(0) + if raw_action_dim_resolved is not None: + action_latents[:, raw_action_dim_resolved:] = 0 + if callback_on_step_end is not None: callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs} callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % inference_scheduler.order == 0 + ): progress_bar.update() self._current_timestep = None # 8. Postprocess + decode sound = self.decode_sound(sound_latents) if sound_latents is not None else None + action_output = None + if action_mode in {"inverse_dynamics", "policy"} and action_latents is not None: + action_output = action_latents + if raw_action_dim_resolved is not None: + action_output = action_output[:, :raw_action_dim_resolved] + action_output = [action_output.detach().cpu()] if output_type == "latent": video = latents else: @@ -1037,5 +1531,7 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: + if action_mode is not None: + return (video, sound, action_output) return (video, sound) - return Cosmos3OmniPipelineOutput(video=video, sound=sound) + return Cosmos3OmniPipelineOutput(video=video, sound=sound, action=action_output) From 2fcef5b47c01ef7569691f037604dd8b1b3966b9 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Thu, 28 May 2026 11:40:25 +0000 Subject: [PATCH 02/15] Add README action examples --- examples/cosmos3/README.md | 103 +++++++++++++++++++++++++++++++++++++ 1 file changed, 103 
insertions(+) diff --git a/examples/cosmos3/README.md b/examples/cosmos3/README.md index 7a4cb277aa07..98cf30eac6d9 100644 --- a/examples/cosmos3/README.md +++ b/examples/cosmos3/README.md @@ -48,6 +48,104 @@ python examples/cosmos3/inference_cosmos3.py \ --enable-sound ``` +Action forward dynamics, robot domain (predict video from an observation video and a provided action chunk): + +```bash +python examples/cosmos3/inference_cosmos3.py \ + --model nano \ + --prompt "Put the pot to the left of the purple item. This video is captured from a first-person perspective looking at the scene." \ + --vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_0.mp4" \ + --action-mode forward_dynamics \ + --action-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_0.json" \ + --action-chunk-size 16 \ + --domain-name bridge_orig_lerobot \ + --height 480 --width 832 --fps 5 \ + --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 5.0 --seed 0 \ + --output results/cosmos3_forward_dynamics_robot +``` + +Action forward dynamics, autonomous-vehicle domain: + +```bash +python examples/cosmos3/inference_cosmos3.py \ + --model nano \ + --prompt "You are an autonomous vehicle planning system. This video is captured from a first-person perspective looking at the scene." 
\ + --vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/av_vision_25_73d01c91-51f0-46cf-9b76-5682a76fb349.mp4" \ + --action-mode forward_dynamics \ + --action-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/av_action_25.json" \ + --action-chunk-size 60 \ + --domain-name av \ + --height 480 --width 832 --fps 10 \ + --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 5.0 --seed 0 \ + --output results/cosmos3_forward_dynamics_av +``` + +Action inverse dynamics, robot domain (predict actions from an observed video): + +```bash +python examples/cosmos3/inference_cosmos3.py \ + --model nano \ + --prompt "Put the pot to the left of the purple item. This video is captured from a first-person perspective looking at the scene." \ + --vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_0.mp4" \ + --action-mode inverse_dynamics \ + --action-chunk-size 16 \ + --raw-action-dim 10 \ + --domain-name bridge_orig_lerobot \ + --height 480 --width 832 --fps 5 \ + --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 5.0 --seed 0 \ + --output results/cosmos3_inverse_dynamics_robot +``` + +Action inverse dynamics, autonomous-vehicle domain: + +```bash +python examples/cosmos3/inference_cosmos3.py \ + --model nano \ + --prompt "You are an autonomous vehicle planning system. This video is captured from a first-person perspective looking at the scene." 
\ + --vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/av_vision_25_73d01c91-51f0-46cf-9b76-5682a76fb349.mp4" \ + --action-mode inverse_dynamics \ + --action-chunk-size 60 \ + --raw-action-dim 9 \ + --domain-name av \ + --height 480 --width 832 --fps 10 \ + --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 5.0 --seed 0 \ + --output results/cosmos3_inverse_dynamics_av +``` + +Action policy, robot domain (predict both future video and actions from the first observation frame): + +```bash +python examples/cosmos3/inference_cosmos3.py \ + --model nano \ + --prompt "Put the pot to the left of the purple item. This video is captured from a first-person perspective looking at the scene." \ + --vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_0.mp4" \ + --action-mode policy \ + --action-chunk-size 16 \ + --raw-action-dim 10 \ + --domain-name bridge_orig_lerobot \ + --height 480 --width 832 --fps 5 \ + --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 5.0 --seed 0 \ + --output results/cosmos3_policy_robot +``` + +Action policy, autonomous-vehicle domain: + +```bash +python examples/cosmos3/inference_cosmos3.py \ + --model nano \ + --prompt "You are an autonomous vehicle planning system. Please go backward. This video is captured from a first-person perspective looking at the scene." \ + --vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/av_vision_25_73d01c91-51f0-46cf-9b76-5682a76fb349.mp4" \ + --action-mode policy \ + --action-chunk-size 60 \ + --raw-action-dim 9 \ + --domain-name av \ + --height 480 --width 832 --fps 10 \ + --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 5.0 --seed 0 \ + --output results/cosmos3_policy_av +``` + +Action modes use `action_chunk_size + 1` video frames. 
`forward_dynamics` consumes `--action-path`; `inverse_dynamics` and `policy` write predicted actions to `sample-*_action.json` in model-normalized action space. The upstream camera-pose forward-dynamics sample uses a still image (`mountain_720.png`), while this wrapper currently expects `--vision-path` to load as video for action modes. + ### Useful flags | Flag | Default | Description | @@ -58,6 +156,11 @@ python examples/cosmos3/inference_cosmos3.py \ | `--height` / `--width` | `720` / `1280` | Output resolution (must be a multiple of the VAE spatial scale factor). | | `--fps` | `24.0` | Frame rate of the generated video. | | `--enable-sound` | off | Generate a synchronized audio track. | +| `--action-mode` | `None` | Enable action conditioning/generation. One of `forward_dynamics`, `inverse_dynamics`, or `policy`. | +| `--action-path` | `None` | URL or local JSON action path for `forward_dynamics`. | +| `--action-chunk-size` | `None` | Number of action tokens. Action runs generate/use `action_chunk_size + 1` video frames. | +| `--domain-name` | `None` | Action embodiment domain, for example `bridge_orig_lerobot` or `av`. | +| `--raw-action-dim` | `None` | Slice predicted action output to the unpadded action dimension. Required for `inverse_dynamics` and `policy`. | | `--no-duration-template` | off | Skip the duration metadata sentence appended to the prompt and negative prompt. Ignored for `--num-frames 1`. | | `--no-resolution-template` | off | Skip the resolution metadata sentence appended to the prompt and negative prompt. | | `--output` | `.` | Directory to write `sample.jpg` or `sample.mp4`. 
| From 40ea9732ce41b52f27b9a0f49ea04131c9f1d682 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Thu, 28 May 2026 11:59:01 +0000 Subject: [PATCH 03/15] Use do_classifier_free_guidance property --- .../pipelines/cosmos/pipeline_cosmos3_omni.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 2bf831c7dd6d..460b0786e4a1 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -1022,6 +1022,10 @@ def current_timestep(self): def interrupt(self): return self._interrupt + @property + def do_classifier_free_guidance(self): + return self._guidance_scale != 1.0 + @torch.no_grad() def __call__( self, @@ -1156,6 +1160,7 @@ def __call__( self._current_timestep = None self._interrupt = False + self._guidance_scale = guidance_scale # Pipeline supports a single sample at a time; collapse list-style inputs to a single string. if isinstance(prompt, list): @@ -1419,7 +1424,7 @@ def __call__( # --- Unconditional pass (Skip if not using CFG) --- uncond_v_vision = uncond_v_sound = uncond_v_action = None - if guidance_scale != 1.0: + if self.do_classifier_free_guidance: preds_vision, preds_sound, preds_action = self.transformer( input_ids=uncond_packed_static["input_ids"], text_indexes=uncond_packed_static["text_indexes"], @@ -1461,7 +1466,7 @@ def __call__( # to carry a batch dim; per-modality latents have no batch axis, so wrap for the step. 
# Skip CFG for 1.0 guidance scale - if guidance_scale != 1.0: + if self.do_classifier_free_guidance: velocity_vision = uncond_v_vision + guidance_scale * (cond_v_vision - uncond_v_vision) else: velocity_vision = cond_v_vision @@ -1472,7 +1477,7 @@ def __call__( if sound_scheduler is not None and cond_v_sound is not None: # Skip CFG for 1.0 guidance scale - if guidance_scale != 1.0: + if self.do_classifier_free_guidance: velocity_sound = uncond_v_sound + guidance_scale * (cond_v_sound - uncond_v_sound) else: velocity_sound = cond_v_sound @@ -1484,7 +1489,7 @@ def __call__( action_condition_mask is not None and action_condition_mask.sum() < action_condition_mask.numel() ) if action_scheduler is not None and has_noisy_action and cond_v_action is not None: - if guidance_scale != 1.0: + if self.do_classifier_free_guidance: velocity_action = uncond_v_action + guidance_scale * (cond_v_action - uncond_v_action) else: velocity_action = cond_v_action From 591cd4d062597331b982686ec446029a20f81c20 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Thu, 28 May 2026 12:03:12 +0000 Subject: [PATCH 04/15] Remove unused method --- src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 460b0786e4a1..bf9edba8a265 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -572,11 +572,6 @@ def _remove_action_video_padding_from_latent( content_w_latent = max(content_w // self.vae_scale_factor_spatial, 1) return latents[:, :, :, :content_h_latent, :content_w_latent].contiguous() - def _remove_action_video_padding_from_video(self, video: torch.Tensor, image_size: torch.Tensor) -> torch.Tensor: - content_h = int(image_size[2].item()) - content_w = int(image_size[3].item()) - return video[:, :, :, :content_h, :content_w].contiguous() - 
def prepare_latents( self, image: torch.Tensor | None = None, From 04efd90e0f9b82ba92291907a9f8dc7ddd4d62ea Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Thu, 28 May 2026 16:24:14 +0000 Subject: [PATCH 05/15] Add action policy example to pipelines doc --- docs/source/en/api/pipelines/cosmos3.md | 48 +++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/docs/source/en/api/pipelines/cosmos3.md b/docs/source/en/api/pipelines/cosmos3.md index ce26ee0c36ef..291c1dab75e4 100644 --- a/docs/source/en/api/pipelines/cosmos3.md +++ b/docs/source/en/api/pipelines/cosmos3.md @@ -459,6 +459,54 @@ encode_video( +## Action policy + +Action policy generation predicts future video and action tokens from the first observation frame, text prompt, and action domain metadata. The example below uses the Bridge robot domain and writes the predicted action chunk to JSON in model-normalized action space. + +```python +import json + +import torch +from diffusers import Cosmos3OmniPipeline +from diffusers.utils import export_to_video, load_video + +pipe = Cosmos3OmniPipeline.from_pretrained( + "nvidia/Cosmos3-Nano", torch_dtype=torch.bfloat16, device_map="cuda" +) + +prompt = ( + "Put the pot to the left of the purple item. This video is captured from a first-person perspective looking " + "at the scene." +) +video = load_video( + "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_0.mp4" +) + +result = pipe( + prompt=prompt, + video=video, + num_frames=17, + height=480, + width=832, + fps=5, + num_inference_steps=30, + guidance_scale=1.0, + flow_shift=5.0, + action_mode="policy", + action_chunk_size=16, + raw_action_dim=10, + domain_name="bridge_orig_lerobot", + use_system_prompt=False, +) + +# macro_block_size=1 allows arbitrary frame sizes (Cosmos3 outputs are not always divisible by 16). 
+export_to_video(result.video, "sample.mp4", fps=5, macro_block_size=1) + +if result.action is not None: + with open("sample_action.json", "w") as f: + json.dump(result.action[0].tolist(), f) +``` + ## Metadata templates `tokenize_prompt` appends short metadata sentences inside the user message so the LLM sees the conditioning the model was trained with. The positive prompt gets sentences like *"The video is 7.9 seconds long and is of 24 FPS."* and *"This video is of 720x1280 resolution."*; the negative prompt gets the inverse (*"… is not …"*). From 362b6ebcb5251e53826339474fd37f1156bcd320 Mon Sep 17 00:00:00 2001 From: Atharva Joshi Date: Thu, 28 May 2026 11:45:46 -0700 Subject: [PATCH 06/15] Adding model selection for action example doc. --- docs/source/en/api/pipelines/cosmos3.md | 53 +++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/docs/source/en/api/pipelines/cosmos3.md b/docs/source/en/api/pipelines/cosmos3.md index 291c1dab75e4..e58c7a95f796 100644 --- a/docs/source/en/api/pipelines/cosmos3.md +++ b/docs/source/en/api/pipelines/cosmos3.md @@ -463,6 +463,9 @@ encode_video( Action policy generation predicts future video and action tokens from the first observation frame, text prompt, and action domain metadata. The example below uses the Bridge robot domain and writes the predicted action chunk to JSON in model-normalized action space. + + + ```python import json @@ -507,6 +510,56 @@ if result.action is not None: json.dump(result.action[0].tolist(), f) ``` + + + +```python +import json + +import torch +from diffusers import Cosmos3OmniPipeline +from diffusers.utils import export_to_video, load_video + +pipe = Cosmos3OmniPipeline.from_pretrained( + "nvidia/Cosmos3-Super", torch_dtype=torch.bfloat16, device_map="cuda" +) + +prompt = ( + "Put the pot to the left of the purple item. This video is captured from a first-person perspective looking " + "at the scene." 
+) +video = load_video( + "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_0.mp4" +) + +result = pipe( + prompt=prompt, + video=video, + num_frames=17, + height=480, + width=832, + fps=5, + num_inference_steps=30, + guidance_scale=1.0, + flow_shift=5.0, + action_mode="policy", + action_chunk_size=16, + raw_action_dim=10, + domain_name="bridge_orig_lerobot", + use_system_prompt=False, +) + +# macro_block_size=1 allows arbitrary frame sizes (Cosmos3 outputs are not always divisible by 16). +export_to_video(result.video, "sample.mp4", fps=5, macro_block_size=1) + +if result.action is not None: + with open("sample_action.json", "w") as f: + json.dump(result.action[0].tolist(), f) +``` + + + + ## Metadata templates `tokenize_prompt` appends short metadata sentences inside the user message so the LLM sees the conditioning the model was trained with. The positive prompt gets sentences like *"The video is 7.9 seconds long and is of 24 FPS."* and *"This video is of 720x1280 resolution."*; the negative prompt gets the inverse (*"… is not …"*). 
From 5ff2ea92b495c3320b97516134bf4394d18aff22 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Fri, 29 May 2026 12:29:00 +0000 Subject: [PATCH 07/15] Remove redundant casts --- .../models/transformers/transformer_cosmos3.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_cosmos3.py b/src/diffusers/models/transformers/transformer_cosmos3.py index 54fbe066ac33..29cfc127d253 100644 --- a/src/diffusers/models/transformers/transformer_cosmos3.py +++ b/src/diffusers/models/transformers/transformer_cosmos3.py @@ -151,9 +151,9 @@ class DomainAwareLinear(nn.Module): def __init__(self, input_size: int, output_size: int, num_domains: int) -> None: super().__init__() - self.input_size = int(input_size) - self.output_size = int(output_size) - self.num_domains = int(num_domains) + self.input_size = input_size + self.output_size = output_size + self.num_domains = num_domains self.fc = nn.Embedding(self.num_domains, self.output_size * self.input_size) self.bias = nn.Embedding(self.num_domains, self.output_size) nn.init.xavier_uniform_(self.fc.weight) @@ -370,8 +370,8 @@ def __init__( self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) self.time_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=hidden_size) self.action_gen = action_gen - self.action_dim = int(32 if action_dim is None else action_dim) - self.num_embodiment_domains = int(num_embodiment_domains) + self.action_dim = 32 if action_dim is None else action_dim + self.num_embodiment_domains = num_embodiment_domains if action_gen: self.action_proj_in = DomainAwareLinear(self.action_dim, hidden_size, self.num_embodiment_domains) self.action_proj_out = DomainAwareLinear(hidden_size, self.action_dim, self.num_embodiment_domains) From a01c1c908fe9fa9d2094cbd7747aef6670550c92 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Fri, 29 May 2026 12:30:53 +0000 Subject: [PATCH 08/15] Rename 
_pack_action_tokens to _prepare_action_segment --- src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index bf9edba8a265..6a8771a2c197 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -460,7 +460,7 @@ def _prepare_sound_segment( "sound_len": sound_len, } - def _pack_action_tokens( + def _prepare_action_segment( self, input_action_tokens: torch.Tensor, condition_frame_indexes: list[int], @@ -1259,7 +1259,7 @@ def __call__( ) cond_action_segment: dict[str, Any] = {} if action_latents is not None: - cond_action_segment = self._pack_action_tokens( + cond_action_segment = self._prepare_action_segment( input_action_tokens=action_latents, condition_frame_indexes=action_condition_frame_indexes, mrope_offset=cond_text_segment["vision_start_temporal_offset"], @@ -1306,7 +1306,7 @@ def __call__( ) uncond_action_segment: dict[str, Any] = {} if action_latents is not None: - uncond_action_segment = self._pack_action_tokens( + uncond_action_segment = self._prepare_action_segment( input_action_tokens=action_latents, condition_frame_indexes=action_condition_frame_indexes, mrope_offset=uncond_text_segment["vision_start_temporal_offset"], From c12e6b162b8542d7a4ef3f84822798a8f3c22ef0 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Fri, 29 May 2026 12:35:39 +0000 Subject: [PATCH 09/15] Move validation checks to check_inputs --- .../pipelines/cosmos/pipeline_cosmos3_omni.py | 28 ++++++++----------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 6a8771a2c197..f556a6810598 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ 
b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -630,16 +630,8 @@ def prepare_latents( # Build the vision conditioning tensor (always [1, 3, T, H, W], in [-1, 1], on device). if action_mode is not None: - if action_chunk_size is None: - raise ValueError("action_mode requires action_chunk_size.") - if video is None: - raise ValueError(f"action_mode={action_mode!r} requires loaded video conditioning.") + assert action_chunk_size is not None target_frames = action_chunk_size + 1 - if num_frames != target_frames: - raise ValueError( - "Action runs require num_frames to equal action_chunk_size + 1; " - f"got num_frames={num_frames}, action_chunk_size={action_chunk_size}." - ) vision_tensor, action_image_size, height, width = self._prepare_action_video_conditioning( video, height, width, target_frames, device=device, dtype=dtype ) @@ -691,8 +683,6 @@ def prepare_latents( if action is None: raise ValueError("action_mode='forward_dynamics' requires an action tensor.") action = action.to(device=device, dtype=dtype) - if action.shape[0] == 0: - raise ValueError("action_mode='forward_dynamics' requires at least one action token.") # Action chunks describe transitions, so action length must match action_chunk_size # while the paired video has action_chunk_size + 1 frames. Short inputs repeat the last action. @@ -704,10 +694,6 @@ def prepare_latents( action = action[:action_chunk_size] # The model action head has a fixed action_dim; pad raw domain actions with zeros on the channel axis. - if action.shape[-1] > action_dim: - raise ValueError( - f"Cosmos3 action dimension {action.shape[-1]} exceeds model action_dim={action_dim}." - ) if action.shape[-1] < action_dim: action_padding = torch.zeros( action.shape[0], @@ -856,8 +842,16 @@ def check_inputs( f"Unknown Cosmos3 action domain_name={domain_name!r}; " f"expected one of {sorted(_EMBODIMENT_TO_DOMAIN_ID)}." 
) - if action_mode == "forward_dynamics" and action is None: - raise ValueError("action_mode='forward_dynamics' requires an action tensor.") + if action_mode == "forward_dynamics": + if action is None: + raise ValueError("action_mode='forward_dynamics' requires an action tensor.") + if action.shape[0] == 0: + raise ValueError("action_mode='forward_dynamics' requires at least one action token.") + action_dim = self.transformer.action_dim + if action.shape[-1] > action_dim: + raise ValueError( + f"Cosmos3 action dimension {action.shape[-1]} exceeds model action_dim={action_dim}." + ) if action_mode in {"inverse_dynamics", "policy"} and raw_action_dim is None: raise ValueError(f"action_mode={action_mode!r} requires raw_action_dim for output slicing.") From fbcd0777f4118c8fec4f2be306c936334ab862f1 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Fri, 29 May 2026 12:37:52 +0000 Subject: [PATCH 10/15] Add action arguments in the __call__ docstring --- .../pipelines/cosmos/pipeline_cosmos3_omni.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index f556a6810598..b0b6ca8485c1 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -1088,6 +1088,28 @@ def __call__( sound_latents (`torch.Tensor`, *optional*): Pre-generated sound latents to start denoising from. Only consulted when `enable_sound=True`; when `None`, fresh Gaussian noise is sampled. + action_latents (`torch.Tensor`, *optional*): + Pre-generated action latents to start the action stream's denoising from. Only consulted when an action + run is configured via `action_mode`; when `None`, fresh Gaussian noise is sampled for the action tokens. + action_mode (`str`, *optional*): + Selects the action-conditioned generation task and requires a transformer trained with + `action_gen=True`. 
One of `"forward_dynamics"` (predict the future video from an initial frame and a + given `action` sequence), `"inverse_dynamics"` (infer the actions connecting the conditioning frames), + or `"policy"` (jointly roll out future video and actions from the first frame). When set, conditioning + must be supplied via `video` (not `image`) and `num_frames` is forced to `action_chunk_size + 1`. + action (`torch.Tensor`, *optional*): + Raw action tokens of shape `[T, action_dim]` driving `action_mode="forward_dynamics"`. Sequences shorter + than `action_chunk_size` repeat the last action; longer ones are truncated. Channels beyond the model's + `action_dim` are rejected, and narrower inputs are zero-padded up to `action_dim`. + action_chunk_size (`int`, *optional*): + Number of action transition steps in the chunk. Required for every `action_mode`; the paired video has + `action_chunk_size + 1` frames and `num_frames` is overwritten accordingly. + domain_name (`str`, *optional*): + Embodiment domain that selects the domain-aware action projection weights. Required for action runs and + must be one of the registered Cosmos 3 embodiment domains. + raw_action_dim (`int`, *optional*): + Number of meaningful (unpadded) action channels to keep when slicing predicted actions. Required for + `action_mode="inverse_dynamics"` and `action_mode="policy"`. output_type (`str`, *optional*, defaults to `"pil"`): Output format for the video. One of `"pil"` (list of `PIL.Image.Image`), `"np"` (`np.ndarray`, `[T, H, W, C]`), `"pt"` (`torch.Tensor`, `[T, C, H, W]`), or `"latent"` (raw vision latents). 
From 7c4e2f488dac6de9a0aa1df98c160042c8193c13 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Fri, 29 May 2026 12:40:18 +0000 Subject: [PATCH 11/15] Move action mode check to check_inputs --- src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index b0b6ca8485c1..09b5c52fee44 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -822,6 +822,8 @@ def check_inputs( f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if action_mode is not None: + if action_mode not in _ACTION_MODES: + raise ValueError(f"Unsupported action_mode={action_mode!r}; expected one of {sorted(_ACTION_MODES)}.") if not getattr(self.transformer.config, "action_gen", False): raise ValueError("action_mode requires a transformer trained with action_gen=True.") if image is not None: @@ -1145,8 +1147,6 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - if action_mode is not None and action_mode not in _ACTION_MODES: - raise ValueError(f"Unsupported action_mode={action_mode!r}; expected one of {sorted(_ACTION_MODES)}.") if action_mode is not None and action_chunk_size is not None: num_frames = action_chunk_size + 1 From 57d3d07d43d13ffbdea926db6ac598ddc0cb9bb5 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Fri, 29 May 2026 12:42:45 +0000 Subject: [PATCH 12/15] Rename action to action_tokens --- .../pipelines/cosmos/pipeline_cosmos3_omni.py | 53 ++++++++++--------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 
09b5c52fee44..1f31b378e5f8 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -588,7 +588,7 @@ def prepare_latents( dtype: torch.dtype = torch.bfloat16, enable_sound: bool = False, action_mode: str | None = None, - action: torch.Tensor | None = None, + action_tokens: torch.Tensor | None = None, action_chunk_size: int | None = None, domain_name: str | None = None, raw_action_dim: int | None = None, @@ -680,29 +680,29 @@ def prepare_latents( assert action_chunk_size is not None action_dim = self.transformer.action_dim if action_mode == "forward_dynamics": - if action is None: + if action_tokens is None: raise ValueError("action_mode='forward_dynamics' requires an action tensor.") - action = action.to(device=device, dtype=dtype) + action_tokens = action_tokens.to(device=device, dtype=dtype) # Action chunks describe transitions, so action length must match action_chunk_size # while the paired video has action_chunk_size + 1 frames. Short inputs repeat the last action. - if action.shape[0] < action_chunk_size: - action = torch.cat( - [action, action[-1:].expand(action_chunk_size - action.shape[0], -1)], + if action_tokens.shape[0] < action_chunk_size: + action_tokens = torch.cat( + [action_tokens, action_tokens[-1:].expand(action_chunk_size - action_tokens.shape[0], -1)], dim=0, ) - action = action[:action_chunk_size] + action_tokens = action_tokens[:action_chunk_size] # The model action head has a fixed action_dim; pad raw domain actions with zeros on the channel axis. 
- if action.shape[-1] < action_dim: + if action_tokens.shape[-1] < action_dim: action_padding = torch.zeros( - action.shape[0], - action_dim - action.shape[-1], - dtype=action.dtype, - device=action.device, + action_tokens.shape[0], + action_dim - action_tokens.shape[-1], + dtype=action_tokens.dtype, + device=action_tokens.device, ) - action = torch.cat([action, action_padding], dim=-1) - x0_tokens_action = action + action_tokens = torch.cat([action_tokens, action_padding], dim=-1) + x0_tokens_action = action_tokens else: x0_tokens_action = torch.zeros(action_chunk_size, action_dim, device=device, dtype=dtype) if domain_name not in _EMBODIMENT_TO_DOMAIN_ID: @@ -793,7 +793,7 @@ def check_inputs( enable_sound: bool, callback_on_step_end_tensor_inputs: list[str], action_mode: str | None, - action: torch.Tensor | None, + action_tokens: torch.Tensor | None, action_chunk_size: int | None, domain_name: str | None, raw_action_dim: int | None, @@ -845,14 +845,14 @@ def check_inputs( f"expected one of {sorted(_EMBODIMENT_TO_DOMAIN_ID)}." ) if action_mode == "forward_dynamics": - if action is None: + if action_tokens is None: raise ValueError("action_mode='forward_dynamics' requires an action tensor.") - if action.shape[0] == 0: + if action_tokens.shape[0] == 0: raise ValueError("action_mode='forward_dynamics' requires at least one action token.") action_dim = self.transformer.action_dim - if action.shape[-1] > action_dim: + if action_tokens.shape[-1] > action_dim: raise ValueError( - f"Cosmos3 action dimension {action.shape[-1]} exceeds model action_dim={action_dim}." + f"Cosmos3 action dimension {action_tokens.shape[-1]} exceeds model action_dim={action_dim}." 
) if action_mode in {"inverse_dynamics", "policy"} and raw_action_dim is None: raise ValueError(f"action_mode={action_mode!r} requires raw_action_dim for output slicing.") @@ -1037,7 +1037,7 @@ def __call__( sound_latents: torch.Tensor | None = None, action_latents: torch.Tensor | None = None, action_mode: str | None = None, - action: torch.Tensor | None = None, + action_tokens: torch.Tensor | None = None, action_chunk_size: int | None = None, domain_name: str | None = None, raw_action_dim: int | None = None, @@ -1096,10 +1096,11 @@ def __call__( action_mode (`str`, *optional*): Selects the action-conditioned generation task and requires a transformer trained with `action_gen=True`. One of `"forward_dynamics"` (predict the future video from an initial frame and a - given `action` sequence), `"inverse_dynamics"` (infer the actions connecting the conditioning frames), - or `"policy"` (jointly roll out future video and actions from the first frame). When set, conditioning - must be supplied via `video` (not `image`) and `num_frames` is forced to `action_chunk_size + 1`. - action (`torch.Tensor`, *optional*): + given `action_tokens` sequence), `"inverse_dynamics"` (infer the actions connecting the conditioning + frames), or `"policy"` (jointly roll out future video and actions from the first frame). When set, + conditioning must be supplied via `video` (not `image`) and `num_frames` is forced to + `action_chunk_size + 1`. + action_tokens (`torch.Tensor`, *optional*): Raw action tokens of shape `[T, action_dim]` driving `action_mode="forward_dynamics"`. Sequences shorter than `action_chunk_size` repeat the last action; longer ones are truncated. Channels beyond the model's `action_dim` are rejected, and narrower inputs are zero-padded up to `action_dim`. 
@@ -1163,7 +1164,7 @@ def __call__( enable_sound, callback_on_step_end_tensor_inputs, action_mode, - action, + action_tokens, action_chunk_size, domain_name, raw_action_dim, @@ -1241,7 +1242,7 @@ def __call__( dtype=dtype, enable_sound=enable_sound, action_mode=action_mode, - action=action, + action_tokens=action_tokens, action_chunk_size=action_chunk_size, domain_name=domain_name, raw_action_dim=raw_action_dim, From a6e204011fa1cd5e13871244611d7ccdf9407c6b Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Fri, 29 May 2026 12:55:34 +0000 Subject: [PATCH 13/15] Add warning for num_frames overwrite attempt --- .../pipelines/cosmos/pipeline_cosmos3_omni.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 1f31b378e5f8..158e68642272 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -30,12 +30,15 @@ Cosmos3OmniTransformer, ) from ...schedulers import UniPCMultistepScheduler -from ...utils import BaseOutput, is_cosmos_guardrail_available +from ...utils import BaseOutput, is_cosmos_guardrail_available, logging from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + if is_cosmos_guardrail_available(): from cosmos_guardrail import CosmosSafetyChecker else: @@ -1149,7 +1152,13 @@ def __call__( callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs if action_mode is not None and action_chunk_size is not None: - num_frames = action_chunk_size + 1 + target_num_frames = action_chunk_size + 1 + if num_frames != target_num_frames: + logger.warning( + f"`num_frames={num_frames}` is ignored for action runs and overwritten to " + f"`action_chunk_size + 1 = 
{target_num_frames}`." + ) + num_frames = target_num_frames # 1. Check inputs self.check_inputs( From 913c24f35d2f0af80d46405822856c3ff0f4c1f2 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Fri, 29 May 2026 13:07:30 +0000 Subject: [PATCH 14/15] Rename action_tokens to raw_actions --- examples/cosmos3/inference_cosmos3.py | 2 +- .../pipelines/cosmos/pipeline_cosmos3_omni.py | 54 +++++++++---------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/examples/cosmos3/inference_cosmos3.py b/examples/cosmos3/inference_cosmos3.py index 675ead892c2f..737b57d30df5 100644 --- a/examples/cosmos3/inference_cosmos3.py +++ b/examples/cosmos3/inference_cosmos3.py @@ -161,7 +161,7 @@ def main(): num_inference_steps=args.num_inference_steps, flow_shift=args.flow_shift, action_mode=args.action_mode, - action=action, + raw_actions=action, action_chunk_size=args.action_chunk_size, domain_name=args.domain_name, raw_action_dim=args.raw_action_dim, diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 158e68642272..b7d6e00226b5 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -591,7 +591,7 @@ def prepare_latents( dtype: torch.dtype = torch.bfloat16, enable_sound: bool = False, action_mode: str | None = None, - action_tokens: torch.Tensor | None = None, + raw_actions: torch.Tensor | None = None, action_chunk_size: int | None = None, domain_name: str | None = None, raw_action_dim: int | None = None, @@ -683,29 +683,29 @@ def prepare_latents( assert action_chunk_size is not None action_dim = self.transformer.action_dim if action_mode == "forward_dynamics": - if action_tokens is None: + if raw_actions is None: raise ValueError("action_mode='forward_dynamics' requires an action tensor.") - action_tokens = action_tokens.to(device=device, dtype=dtype) + raw_actions = raw_actions.to(device=device, dtype=dtype) 
# Action chunks describe transitions, so action length must match action_chunk_size # while the paired video has action_chunk_size + 1 frames. Short inputs repeat the last action. - if action_tokens.shape[0] < action_chunk_size: - action_tokens = torch.cat( - [action_tokens, action_tokens[-1:].expand(action_chunk_size - action_tokens.shape[0], -1)], + if raw_actions.shape[0] < action_chunk_size: + raw_actions = torch.cat( + [raw_actions, raw_actions[-1:].expand(action_chunk_size - raw_actions.shape[0], -1)], dim=0, ) - action_tokens = action_tokens[:action_chunk_size] + raw_actions = raw_actions[:action_chunk_size] # The model action head has a fixed action_dim; pad raw domain actions with zeros on the channel axis. - if action_tokens.shape[-1] < action_dim: + if raw_actions.shape[-1] < action_dim: action_padding = torch.zeros( - action_tokens.shape[0], - action_dim - action_tokens.shape[-1], - dtype=action_tokens.dtype, - device=action_tokens.device, + raw_actions.shape[0], + action_dim - raw_actions.shape[-1], + dtype=raw_actions.dtype, + device=raw_actions.device, ) - action_tokens = torch.cat([action_tokens, action_padding], dim=-1) - x0_tokens_action = action_tokens + raw_actions = torch.cat([raw_actions, action_padding], dim=-1) + x0_tokens_action = raw_actions else: x0_tokens_action = torch.zeros(action_chunk_size, action_dim, device=device, dtype=dtype) if domain_name not in _EMBODIMENT_TO_DOMAIN_ID: @@ -796,7 +796,7 @@ def check_inputs( enable_sound: bool, callback_on_step_end_tensor_inputs: list[str], action_mode: str | None, - action_tokens: torch.Tensor | None, + raw_actions: torch.Tensor | None, action_chunk_size: int | None, domain_name: str | None, raw_action_dim: int | None, @@ -848,14 +848,14 @@ def check_inputs( f"expected one of {sorted(_EMBODIMENT_TO_DOMAIN_ID)}." 
) if action_mode == "forward_dynamics": - if action_tokens is None: + if raw_actions is None: raise ValueError("action_mode='forward_dynamics' requires an action tensor.") - if action_tokens.shape[0] == 0: + if raw_actions.shape[0] == 0: raise ValueError("action_mode='forward_dynamics' requires at least one action token.") action_dim = self.transformer.action_dim - if action_tokens.shape[-1] > action_dim: + if raw_actions.shape[-1] > action_dim: raise ValueError( - f"Cosmos3 action dimension {action_tokens.shape[-1]} exceeds model action_dim={action_dim}." + f"Cosmos3 action dimension {raw_actions.shape[-1]} exceeds model action_dim={action_dim}." ) if action_mode in {"inverse_dynamics", "policy"} and raw_action_dim is None: raise ValueError(f"action_mode={action_mode!r} requires raw_action_dim for output slicing.") @@ -1040,7 +1040,7 @@ def __call__( sound_latents: torch.Tensor | None = None, action_latents: torch.Tensor | None = None, action_mode: str | None = None, - action_tokens: torch.Tensor | None = None, + raw_actions: torch.Tensor | None = None, action_chunk_size: int | None = None, domain_name: str | None = None, raw_action_dim: int | None = None, @@ -1099,14 +1099,14 @@ def __call__( action_mode (`str`, *optional*): Selects the action-conditioned generation task and requires a transformer trained with `action_gen=True`. One of `"forward_dynamics"` (predict the future video from an initial frame and a - given `action_tokens` sequence), `"inverse_dynamics"` (infer the actions connecting the conditioning + given `raw_actions` sequence), `"inverse_dynamics"` (infer the actions connecting the conditioning frames), or `"policy"` (jointly roll out future video and actions from the first frame). When set, conditioning must be supplied via `video` (not `image`) and `num_frames` is forced to `action_chunk_size + 1`. - action_tokens (`torch.Tensor`, *optional*): - Raw action tokens of shape `[T, action_dim]` driving `action_mode="forward_dynamics"`. 
Sequences shorter - than `action_chunk_size` repeat the last action; longer ones are truncated. Channels beyond the model's - `action_dim` are rejected, and narrower inputs are zero-padded up to `action_dim`. + raw_actions (`torch.Tensor`, *optional*): + Raw domain action vectors of shape `[T, raw_action_dim]` driving `action_mode="forward_dynamics"`. + Sequences shorter than `action_chunk_size` repeat the last action; longer ones are truncated. Channels + beyond the model's `action_dim` are rejected, and narrower inputs are zero-padded up to `action_dim`. action_chunk_size (`int`, *optional*): Number of action transition steps in the chunk. Required for every `action_mode`; the paired video has `action_chunk_size + 1` frames and `num_frames` is overwritten accordingly. @@ -1173,7 +1173,7 @@ def __call__( enable_sound, callback_on_step_end_tensor_inputs, action_mode, - action_tokens, + raw_actions, action_chunk_size, domain_name, raw_action_dim, @@ -1251,7 +1251,7 @@ def __call__( dtype=dtype, enable_sound=enable_sound, action_mode=action_mode, - action_tokens=action_tokens, + raw_actions=raw_actions, action_chunk_size=action_chunk_size, domain_name=domain_name, raw_action_dim=raw_action_dim, From 3bce946d550952798c13874a5ae5c91eec2616de Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Fri, 29 May 2026 13:33:36 +0000 Subject: [PATCH 15/15] Remove scheduler config override --- docs/source/en/api/pipelines/cosmos3.md | 2 -- examples/cosmos3/inference_cosmos3.py | 3 -- .../pipelines/cosmos/pipeline_cosmos3_omni.py | 30 +++++-------------- 3 files changed, 8 insertions(+), 27 deletions(-) diff --git a/docs/source/en/api/pipelines/cosmos3.md b/docs/source/en/api/pipelines/cosmos3.md index e58c7a95f796..301c78f58e62 100644 --- a/docs/source/en/api/pipelines/cosmos3.md +++ b/docs/source/en/api/pipelines/cosmos3.md @@ -494,7 +494,6 @@ result = pipe( fps=5, num_inference_steps=30, guidance_scale=1.0, - flow_shift=5.0, action_mode="policy", action_chunk_size=16, 
raw_action_dim=10, @@ -541,7 +540,6 @@ result = pipe( fps=5, num_inference_steps=30, guidance_scale=1.0, - flow_shift=5.0, action_mode="policy", action_chunk_size=16, raw_action_dim=10, diff --git a/examples/cosmos3/inference_cosmos3.py b/examples/cosmos3/inference_cosmos3.py index 737b57d30df5..18297927e09d 100644 --- a/examples/cosmos3/inference_cosmos3.py +++ b/examples/cosmos3/inference_cosmos3.py @@ -82,7 +82,6 @@ def main(): parser.add_argument("--fps", type=float, default=24.0) parser.add_argument("--guidance-scale", type=float, default=6.0, help="Classifier-free guidance scale.") parser.add_argument("--num-inference-steps", type=int, default=35, help="Number of denoising steps.") - parser.add_argument("--flow-shift", type=float, default=None, help="Scheduler flow shift.") parser.add_argument("--seed", type=int, default=None, help="Random seed for latent initialization.") parser.add_argument( "--enable-sound", @@ -159,7 +158,6 @@ def main(): width=args.width, fps=args.fps, num_inference_steps=args.num_inference_steps, - flow_shift=args.flow_shift, action_mode=args.action_mode, raw_actions=action, action_chunk_size=args.action_chunk_size, @@ -182,7 +180,6 @@ def main(): width=args.width, fps=args.fps, num_inference_steps=args.num_inference_steps, - flow_shift=args.flow_shift, enable_sound=args.enable_sound, guidance_scale=args.guidance_scale, generator=generator, diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index b7d6e00226b5..e47062e4a0f4 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -1033,7 +1033,6 @@ def __call__( fps: float = 24.0, num_inference_steps: int = 35, guidance_scale: float = 6.0, - flow_shift: float | None = None, enable_sound: bool = False, generator: torch.Generator | None = None, latents: torch.Tensor | None = None, @@ -1363,25 +1362,14 @@ def __call__( action_noisy_len = 
cond_action_segment.get("num_noisy_action_tokens") # 6. Set timesteps. UniPCMultistepScheduler keeps per-step state (_step_index, - # model_outputs history) on the instance, so audio/action gets its own copy. - inference_scheduler = copy.deepcopy(self.scheduler) - if flow_shift is not None: - inference_scheduler.register_to_config( - use_flow_sigmas=True, - use_karras_sigmas=False, - use_exponential_sigmas=False, - use_beta_sigmas=False, - flow_shift=flow_shift, - shift_terminal=None, - final_sigmas_type="zero", - ) - inference_scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = inference_scheduler.timesteps - sound_scheduler = copy.deepcopy(inference_scheduler) if sound_latents is not None else None - action_scheduler = copy.deepcopy(inference_scheduler) if action_latents is not None else None + # model_outputs history) on the instance, so sound/action each get their own copy. + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + sound_scheduler = copy.deepcopy(self.scheduler) if sound_latents is not None else None + action_scheduler = copy.deepcopy(self.scheduler) if action_latents is not None else None # 7. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * inference_scheduler.order + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -1492,7 +1480,7 @@ def __call__( else: velocity_vision = cond_v_vision - latents = inference_scheduler.step( + latents = self.scheduler.step( velocity_vision.unsqueeze(0), t, latents.unsqueeze(0), return_dict=False )[0].squeeze(0) @@ -1525,9 +1513,7 @@ def __call__( callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % inference_scheduler.order == 0 - ): + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() self._current_timestep = None