
[Feat][Diffusion]: Implement Component-Level VRAM Quota and Resource Domain Isolation#1582

Open
Flink-ddd wants to merge 3 commits into vllm-project:main from Flink-ddd:feat/omni-diffusion-memory-calculation

Conversation


@Flink-ddd Flink-ddd commented Mar 1, 2026

Purpose

This PR resolves the Stage-0 initialization deadlock identified in Issue #1574.

Root Cause Resolved: Previously, the Diffusion Worker (Stage-1) silently allocated ~27GB of VRAM via global torch calls without reporting its budget to the Orchestrator. The LLM (Stage-0) therefore hit a "Memory Blind Spot" during its profiling phase, failing with ValueError: No available memory (0.0 GiB reported).

Key Innovations:

Heuristic Budget Pre-audit: Introduced a static method, predict_resource_usage, in DiffusionWorker that calculates the VRAM footprint from model metadata (parameters, dtype, resolution) before the worker is spawned.
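
The pre-audit idea can be sketched compactly. The following is a minimal, hypothetical illustration, not the PR's actual implementation: the `ModelMeta` container, the `bytes_per_pixel` constant, and the `overhead_ratio` default are all assumptions.

```python
from dataclasses import dataclass

@dataclass
class ModelMeta:
    num_params: float   # total parameter count, e.g. 7e9 for a 7B model
    dtype_bytes: int    # bytes per element, e.g. 2 for bf16
    height: int = 1024  # target output resolution
    width: int = 1024

def predict_resource_usage(meta: ModelMeta, overhead_ratio: float = 0.10) -> float:
    """Estimate a diffusion worker's VRAM footprint in GiB before spawning it.

    Weights are counted exactly (params * bytes per param); activations are
    approximated as linear in the output pixel count; a fixed overhead ratio
    covers the CUDA context and allocator fragmentation.
    """
    gib = 1024 ** 3
    weights_gb = meta.num_params * meta.dtype_bytes / gib
    bytes_per_pixel = 2048  # illustrative tuned constant, not taken from the PR
    activations_gb = meta.height * meta.width * bytes_per_pixel / gib
    return (weights_gb + activations_gb) * (1.0 + overhead_ratio)
```

For a 7B bf16 model at 1024x1024 this toy model yields roughly 16.5 GiB; the PR's real predictor evidently accounts for additional state (VAE, text encoders, CFG buffers), since it reports 27.65 GiB for Bagel-7B-MoT.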

Resource Domain Isolation: Implemented an Orchestrator-level coordinator that adjusts Stage-0's gpu_memory_utilization using a Dynamic Utilization Boost algorithm to compensate for concurrent modal loads.
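
One plausible form of the boost, sketched under the assumption that it simply converts the diffusion stage's reserved GiB into an extra fraction of total device memory (the function name and the 0.95 cap are hypothetical; the PR's exact formula may differ):

```python
def boost_utilization(base_util: float, reserved_gb: float,
                      total_gb: float, cap: float = 0.95) -> float:
    """Raise Stage-0's gpu_memory_utilization to compensate for VRAM that a
    co-located diffusion stage has already reserved.

    gpu_memory_utilization is a fraction of *total* device memory, so when
    another stage holds reserved_gb, the LLM must claim a proportionally
    larger fraction to keep its absolute budget unchanged.
    """
    boosted = base_util + reserved_gb / total_gb
    return min(boosted, cap)  # never claim the whole device
```

On an 80 GB A100 with 27.65 GiB reserved this gives 0.35 -> ~0.696, close to the 0.35 -> 0.699 boost reported in the logs below.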

RFC #1316 Alignment: This serves as a critical prerequisite for the Sleep Mode ACK mechanism. It provides the "Logical Baseline" required to audit physical VRAM reclamation during diffusion Level 1/2 sleep transitions.

Test Plan

Validated using the Bagel (7B-MoT) multi-modal pipeline on an NVIDIA A100-80GB PCIe GPU.

Test Command:

pytest tests/e2e/offline_inference/test_bagel_text2img.py::test_bagel_text2img_shared_memory_connector -s 2>&1 | tee test_results.log

Test Result

VRAM Resource Model Comparison

| Stage | Resource Model (Main) | Resource Model (This PR) |
| --- | --- | --- |
| Diffusion | Global Grab (~27 GiB) | Fixed Quota (27.65 GiB) |
| LLM KV Cache | Blind Claim (up to 72 GiB) | Residual Scaling (~35 GiB) |
| Outcome | OOM / Profiling Failure (0.0 GiB) | Deterministic Success (8.89 GiB) |

Integration with Sleep Mode ACK (RFC #1316)

This PR establishes the foundation for the upcoming Sleep Mode ACK PR:

- Logical Audit: The predict_resource_usage method provides the expected_freed_gb value used in ACK signals.
- Dynamic Ledger: By tracking total_reserved_gb in the Orchestrator, we enable a "Logic vs. Physical" dual-audit mechanism. When an ACK confirms a physical release (Level 2 Sleep), the Orchestrator can dynamically decrease the reserved budget, enabling deterministic KV Cache expansion for the active LLM worker.
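
A minimal sketch of such a ledger (the class and method names here are hypothetical, not the PR's API):

```python
class ReservationLedger:
    """Orchestrator-side ledger of logical VRAM reservations per device.

    An ACK's expected_freed_gb is checked against this logical baseline; on a
    confirmed physical release (Level 2 sleep) the reserved budget shrinks,
    creating headroom for KV cache growth.
    """
    def __init__(self) -> None:
        self._reserved: dict[int, float] = {}

    def reserve(self, device: int, gb: float) -> None:
        """Record a stage's predicted budget against a device."""
        self._reserved[device] = self._reserved.get(device, 0.0) + gb

    def on_sleep_ack(self, device: int, expected_freed_gb: float) -> float:
        """Release a confirmed reservation; return the new reserved total."""
        current = self._reserved.get(device, 0.0)
        # Never release more than was logically reserved.
        freed = min(expected_freed_gb, current)
        self._reserved[device] = current - freed
        return self._reserved[device]

    def total_reserved_gb(self, device: int) -> float:
        return self._reserved.get(device, 0.0)
```

The clamp in on_sleep_ack keeps the ledger consistent even if an ACK over-reports the physically freed amount.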

Verified Logs (Success Case)


Success Highlights:

  1. Device-Specific Reservation: Stage-1 (Diffusion) predicted a 27.65 GiB budget for its assigned device [0].
  2. Context-Aware Boost: LLM Stage-0 applied its dynamic boost (0.35 -> 0.699) only to its target Device 0.
  3. Memory Safety: Initialized a 162,976-token KV Cache from 8.7 GiB of available memory (vs. the previous 0.0 GiB crash).
  4. Deterministic Success: End-to-end generation completed at 2.07 img/s.

Full Test Logs:

============================= test session starts ==============================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /workspace/vllm-omni
configfile: pyproject.toml
plugins: anyio-4.11.0
collected 1 item

tests/e2e/offline_inference/test_bagel_text2img.py 
=== PRE-TEST GPU CLEANUP ===
Pre-test GPU status:
[GPU Memory Monitor] Waiting for GPU 0 to free memory, Condition: Memory usage ratio ≤ 5.0%
[GPU Memory Status] Current usage:
  GPU 0: 0.7GiB/80.0GiB (0.9%)
[GPU Memory Freed] Devices 0 meet memory condition
   Condition: Memory usage ratio ≤ 5.0%
   Wait time: 0.0 seconds (0.0 minutes)
   Final status:
     GPU 0: 0.7GiB/80.0GiB (0.9%)
INFO 03-01 18:16:24 [scheduler.py:224] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 03-01 18:16:24 [vllm.py:689] Asynchronous scheduling is enabled.
--- Running test: test_bagel_text2img_shared_memory_connector
INFO 03-01 18:16:24 [weight_utils.py:50] Using model weights format ['*']
INFO 03-01 18:16:25 [omni.py:181] Initializing stages for model: ByteDance-Seed/BAGEL-7B-MoT
INFO 03-01 18:16:25 [omni.py:358] No omni_master_address provided, defaulting to localhost (127.0.0.1)
INFO 03-01 18:16:27 [omni.py:321] [Coordinator] Stage-1 (Diffusion) on devices [0] predicted budget: 27.65 GiB
INFO 03-01 18:16:27 [omni.py:341] [Coordinator] LLM Stage-0 on Device 0 dynamic boost: 0.35 -> 0.699 (Compensating 0.35 ratio for resource domain isolation)
INFO 03-01 18:16:27 [initialization.py:233] Auto-configuring SharedMemoryConnector for edge ('0', '1')
INFO 03-01 18:16:27 [initialization.py:270] Loaded OmniTransferConfig with 1 connector configurations
INFO 03-01 18:16:27 [factory.py:46] Created connector: SharedMemoryConnector
INFO 03-01 18:16:27 [initialization.py:60] Created connector for 0 -> 1: SharedMemoryConnector
INFO 03-01 18:16:27 [omni.py:394] [Orchestrator] Loaded 2 stages
INFO 03-01 18:16:27 [omni.py:505] [Orchestrator] Waiting for 2 stages to initialize (timeout: 300s)
[Stage-0] INFO 03-01 18:16:39 [omni_stage.py:679] Starting stage worker with model: ByteDance-Seed/BAGEL-7B-MoT
[Stage-0] INFO 03-01 18:16:39 [omni_stage.py:725] [Stage-0] ZMQ transport detected; disabling SHM IPC (shm_threshold_bytes set to maxsize)
[Stage-0] INFO 03-01 18:16:39 [omni_stage.py:85] Using sequential init locks (nvml_available=True, pid_host=False)
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
[Stage-0] INFO 03-01 18:16:39 [initialization.py:233] Auto-configuring SharedMemoryConnector for edge ('0', '1')
[Stage-0] INFO 03-01 18:16:39 [initialization.py:270] Loaded OmniTransferConfig with 1 connector configurations
[Stage-0] INFO 03-01 18:16:39 [factory.py:46] Created connector: SharedMemoryConnector
[Stage-0] INFO 03-01 18:16:39 [initialization.py:60] Created connector for 0 -> 1: SharedMemoryConnector
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
[Stage-0] INFO 03-01 18:16:40 [model.py:529] Resolved architecture: BagelForConditionalGeneration
[Stage-0] INFO 03-01 18:16:40 [model.py:1549] Using max model len 32768
[Stage-0] INFO 03-01 18:16:40 [scheduler.py:224] Chunked prefill is enabled with max_num_batched_tokens=32768.
[Stage-0] INFO 03-01 18:16:40 [vllm.py:689] Asynchronous scheduling is enabled.
[Stage-0] WARNING 03-01 18:16:40 [vllm.py:727] Enforce eager set, overriding optimization level to -O0
[Stage-0] INFO 03-01 18:16:40 [vllm.py:845] Cudagraph is disabled under eager mode
/usr/local/lib/python3.12/dist-packages/transformers/models/auto/image_processing_auto.py:647: FutureWarning: The image_processor_class argument is deprecated and will be removed in v4.42. Please use `slow_image_processor_class`, or `fast_image_processor_class` instead
  warnings.warn(
[Stage-1] INFO 03-01 18:16:40 [omni_stage.py:679] Starting stage worker with model: ByteDance-Seed/BAGEL-7B-MoT
[Stage-1] INFO 03-01 18:16:40 [omni_stage.py:725] [Stage-1] ZMQ transport detected; disabling SHM IPC (shm_threshold_bytes set to maxsize)
[Stage-1] INFO 03-01 18:16:40 [omni_stage.py:85] Using sequential init locks (nvml_available=True, pid_host=False)
(EngineCore_DP0 pid=9730) [Stage-0] INFO 03-01 18:16:54 [core.py:97] Initializing a V1 LLM engine (v0.16.0) with config: model='ByteDance-Seed/BAGEL-7B-MoT', speculative_config=None, tokenizer='ByteDance-Seed/BAGEL-7B-MoT', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, enable_return_routed_experts=False, kv_cache_dtype=auto, device_config=cuda, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, kv_cache_metrics=False, kv_cache_metrics_sample=0.01, cudagraph_metrics=False, enable_layerwise_nvtx_tracing=False, enable_mfu_metrics=False, enable_mm_processor_stats=False, enable_logging_iteration_details=False), seed=0, served_model_name=ByteDance-Seed/BAGEL-7B-MoT, enable_prefix_caching=False, enable_chunked_prefill=True, pooler_config=None, compilation_config={'level': None, 'mode': <CompilationMode.NONE: 0>, 'debug_dump_path': None, 'cache_dir': '', 'compile_cache_save_format': 'binary', 'backend': 'inductor', 'custom_ops': ['all'], 'splitting_ops': [], 'compile_mm_encoder': False, 'compile_sizes': [], 'compile_ranges_split_points': [32768], 'inductor_compile_config': {'enable_auto_functionalized_v2': False, 'combo_kernels': True, 'benchmark_combo_kernel': True}, 'inductor_passes': {}, 'cudagraph_mode': <CUDAGraphMode.NONE: 0>, 'cudagraph_num_of_warmups': 0, 'cudagraph_capture_sizes': [], 'cudagraph_copy_inputs': False, 'cudagraph_specialize_lora': True, 
'use_inductor_graph_partition': False, 'pass_config': {'fuse_norm_quant': False, 'fuse_act_quant': False, 'fuse_attn_quant': False, 'eliminate_noops': False, 'enable_sp': False, 'fuse_gemm_comms': False, 'fuse_allreduce_rms': False, 'fuse_act_padding': False}, 'max_cudagraph_capture_size': 0, 'dynamic_shapes_config': {'type': <DynamicShapesType.BACKED: 'backed'>, 'evaluate_guards': False, 'assume_32_bit_indexing': False}, 'local_cache_dir': None, 'fast_moe_cold_start': True, 'static_all_moe_layers': []}
(EngineCore_DP0 pid=9730) [Stage-0] WARNING 03-01 18:16:54 [multiproc_executor.py:921] Reducing Torch parallelism from 252 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
[Stage-1] WARNING 03-01 18:17:00 [omni_stage.py:173] Timeout waiting for device 0 initialization lock, proceeding anyway
[Stage-1] INFO 03-01 18:17:02 [multiproc_executor.py:74] Starting server...
/usr/local/lib/python3.12/dist-packages/transformers/models/auto/image_processing_auto.py:647: FutureWarning: The image_processor_class argument is deprecated and will be removed in v4.42. Please use `slow_image_processor_class`, or `fast_image_processor_class` instead
  warnings.warn(
[Stage-0] INFO 03-01 18:17:07 [parallel_state.py:1234] world_size=1 rank=0 local_rank=0 distributed_init_method=tcp://127.0.0.1:51383 backend=nccl
[Stage-0] INFO 03-01 18:17:07 [parallel_state.py:1445] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, PCP rank 0, TP rank 0, EP rank N/A
(Worker pid=10064) [Stage-0] INFO 03-01 18:17:08 [gpu_model_runner.py:4124] Starting to load model ByteDance-Seed/BAGEL-7B-MoT...
(Worker pid=10064) [Stage-0] INFO 03-01 18:17:10 [vllm.py:689] Asynchronous scheduling is enabled.
(Worker pid=10064) [Stage-0] WARNING 03-01 18:17:10 [vllm.py:734] Inductor compilation was disabled by user settings, optimizations settings that are only active during inductor compilation will be ignored.
(Worker pid=10064) [Stage-0] INFO 03-01 18:17:10 [vllm.py:845] Cudagraph is disabled under eager mode
(Worker pid=10064) [Stage-0] INFO 03-01 18:17:11 [cuda.py:367] Using FLASH_ATTN attention backend out of potential backends: ['FLASH_ATTN', 'FLASHINFER', 'TRITON_ATTN', 'FLEX_ATTENTION'].
(Worker pid=10064) [Stage-0] WARNING 03-01 18:17:11 [bagel.py:391] Overriding vit_config.num_hidden_layers from 27 to 26 to match the Bagel model checkpoint.
(Worker pid=10064) [Stage-0] WARNING 03-01 18:17:11 [bagel.py:397] Setting vit_config.vision_use_head to False as it is not present in the Bagel model checkpoint.
(Worker pid=10064) [Stage-0] INFO 03-01 18:17:11 [mm_encoder_attention.py:77] Using AttentionBackendEnum.FLASH_ATTN for MMEncoderAttention.
(Worker pid=10064) 
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
(Worker pid=10064) 
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:00<00:00, 22.09it/s]
(Worker pid=10064) 
[Stage-1] INFO 03-01 18:17:16 [diffusion_worker.py:388] Worker 0 created result MessageQueue
[Stage-1] INFO 03-01 18:17:17 [scheduler.py:224] Chunked prefill is enabled with max_num_batched_tokens=2048.
[Stage-1] INFO 03-01 18:17:17 [vllm.py:689] Asynchronous scheduling is enabled.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Stage-1] INFO 03-01 18:17:17 [diffusion_worker.py:153] Worker 0: Initialized device and distributed environment.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Stage-1] INFO 03-01 18:17:17 [parallel_state.py:575] Building SP subgroups from explicit sp_group_ranks (sp_size=1, ulysses=1, ring=1, use_ulysses_low=True).
[Stage-1] INFO 03-01 18:17:17 [parallel_state.py:617] SP group details for rank 0: sp_group=[0], ulysses_group=[0], ring_group=[0]
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Stage-1] INFO 03-01 18:17:17 [weight_utils.py:50] Using model weights format ['*']
(Worker pid=10064) [Stage-0] INFO 03-01 18:17:18 [default_loader.py:293] Loading weights took 5.44 seconds
(Worker pid=10064) [Stage-0] INFO 03-01 18:17:19 [gpu_model_runner.py:4221] Model loading took 15.04 GiB memory and 7.672688 seconds
[Stage-1] INFO 03-01 18:17:19 [weight_utils.py:579] No diffusion_pytorch_model.safetensors.index.json found in remote.

Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
(Worker pid=10064) [Stage-0] INFO 03-01 18:17:21 [gpu_model_runner.py:5140] Encoder cache will be initialized with a budget of 32768 tokens, and profiled with 6 image items of the maximum feature size.
(Worker pid=10064) Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.

Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:09<00:09,  9.89s/it]
(Worker pid=10064) [Stage-0] INFO 03-01 18:17:33 [base.py:102] Available KV cache memory: 8.7 GiB (profiling fallback)
(EngineCore_DP0 pid=9730) [Stage-0] INFO 03-01 18:17:33 [kv_cache_utils.py:1307] GPU KV cache size: 162,976 tokens
(EngineCore_DP0 pid=9730) [Stage-0] INFO 03-01 18:17:33 [kv_cache_utils.py:1312] Maximum concurrency for 32,768 tokens per request: 4.97x
(Worker pid=10064) [Stage-0] INFO 03-01 18:17:33 [kernel_warmup.py:44] Skipping FlashInfer autotune because it is disabled.
(EngineCore_DP0 pid=9730) [Stage-0] INFO 03-01 18:17:33 [core.py:278] init engine (profile, create kv cache, warmup model) took 14.06 seconds

Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:13<00:00,  6.24s/it]

Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:13<00:00,  6.79s/it]

[Stage-1] INFO 03-01 18:17:33 [pipeline_bagel.py:716] BagelPipeline weight filter kept 1466/1467 tensors (shape mismatches seen: 0)
(EngineCore_DP0 pid=9730) [Stage-0] WARNING 03-01 18:17:33 [scheduler.py:166] Using custom scheduler class vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler. This scheduler interface is not public and compatibility may not be maintained.
(EngineCore_DP0 pid=9730) /usr/local/lib/python3.12/dist-packages/transformers/models/auto/image_processing_auto.py:647: FutureWarning: The image_processor_class argument is deprecated and will be removed in v4.42. Please use `slow_image_processor_class`, or `fast_image_processor_class` instead
(EngineCore_DP0 pid=9730)   warnings.warn(
[Stage-1] INFO 03-01 18:17:34 [diffusers_loader.py:301] Loading weights took 14.66 seconds
(EngineCore_DP0 pid=9730) [Stage-0] INFO 03-01 18:17:34 [vllm.py:689] Asynchronous scheduling is enabled.
(EngineCore_DP0 pid=9730) [Stage-0] WARNING 03-01 18:17:34 [vllm.py:734] Inductor compilation was disabled by user settings, optimizations settings that are only active during inductor compilation will be ignored.
(EngineCore_DP0 pid=9730) [Stage-0] INFO 03-01 18:17:34 [vllm.py:845] Cudagraph is disabled under eager mode
[Stage-0] INFO 03-01 18:17:34 [omni_llm.py:173] Supported_tasks: ['generate']
[Stage-0] INFO 03-01 18:17:34 [initialization.py:324] [Stage-0] Initializing OmniConnectors with config keys: ['to_stage_1']
[Stage-0] INFO 03-01 18:17:34 [omni_stage.py:794] Max batch size: 1
INFO 03-01 18:17:34 [omni.py:495] [Orchestrator] Stage-0 reported ready
[Stage-1] INFO 03-01 18:17:35 [diffusion_model_runner.py:118] Model loading took 26.4738 GiB and 18.129542 seconds
[Stage-1] INFO 03-01 18:17:35 [diffusion_model_runner.py:123] Model runner: Model loaded successfully.
[Stage-1] INFO 03-01 18:17:35 [diffusion_model_runner.py:163] Model runner: Initialization complete.
[Stage-1] INFO 03-01 18:17:35 [diffusion_worker.py:181] Worker 0: Process-scoped GPU memory after model loading: 0.00 GiB.
[Stage-1] INFO 03-01 18:17:35 [manager.py:91] Initializing DiffusionLoRAManager: device=cuda:0, dtype=torch.bfloat16, max_cached_adapters=1, static_lora_path=None
[Stage-1] INFO 03-01 18:17:35 [diffusion_worker.py:123] Worker 0: Initialization complete.
[Stage-1] INFO 03-01 18:17:35 [diffusion_worker.py:522] Worker 0: Scheduler loop started.
[Stage-1] INFO 03-01 18:17:35 [diffusion_worker.py:445] Worker 0 ready to receive requests via shared memory
[Stage-1] INFO 03-01 18:17:35 [scheduler.py:41] SyncScheduler initialized result MessageQueue
[Stage-1] INFO 03-01 18:17:35 [diffusion_engine.py:341] dummy run to warm up the model
[Stage-1] INFO 03-01 18:17:35 [manager.py:566] Deactivating all adapters: 0 layers
[Stage-1] WARNING 03-01 18:17:35 [kv_transfer_manager.py:517] Request has no ID, cannot receive KV cache
[Stage-1] INFO 03-01 18:17:36 [initialization.py:324] [Stage-1] Initializing OmniConnectors with config keys: ['from_stage_0']
[Stage-1] INFO 03-01 18:17:36 [factory.py:46] Created connector: SharedMemoryConnector
[Stage-1] INFO 03-01 18:17:36 [initialization.py:60] Created connector for 0 -> 1: SharedMemoryConnector
[Stage-1] INFO 03-01 18:17:36 [omni_stage.py:794] Max batch size: 1
INFO 03-01 18:17:36 [omni.py:495] [Orchestrator] Stage-1 reported ready
INFO 03-01 18:17:36 [omni.py:524] [Orchestrator] All stages initialized successfully

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 unit/s, output: 0.00 unit/s]
(Worker pid=10064) [Stage-0] INFO 03-01 18:17:36 [kv_transfer_manager.py:142] Initializing OmniConnector with config: {'type': 'SharedMemoryConnector', 'shm_threshold_bytes': 65536, 'role': 'sender'}
(Worker pid=10064) [Stage-0] INFO 03-01 18:17:36 [factory.py:46] Created connector: SharedMemoryConnector
[Stage-1] INFO 03-01 18:17:36 [manager.py:566] Deactivating all adapters: 0 layers
[Stage-1] INFO 03-01 18:17:36 [kv_transfer_manager.py:142] Initializing OmniConnector with config: {'type': 'SharedMemoryConnector', 'shm_threshold_bytes': 65536, 'role': 'receiver'}
[Stage-1] INFO 03-01 18:17:36 [factory.py:46] Created connector: SharedMemoryConnector
[Stage-1] INFO 03-01 18:17:36 [kv_transfer_manager.py:437] Wait for KV cache for request 0_e8f88933-70ac-4ac9-9ebd-547a932328bf from stage 0 to 1...
(Worker pid=10064) [Stage-0] INFO 03-01 18:17:36 [kv_transfer_manager.py:361] KV transfer OK: 0_e8f88933-70ac-4ac9-9ebd-547a932328bf, 690920 bytes
[Stage-1] INFO 03-01 18:17:36 [kv_transfer_manager.py:450] Successfully received KV cache for 0_e8f88933-70ac-4ac9-9ebd-547a932328bf, 690920 bytes
[Stage-1] INFO 03-01 18:17:36 [pipeline_bagel.py:326] Using injected KV Cache (direct)
[Stage-1] WARNING 03-01 18:17:36 [pipeline_bagel.py:333] CFG is disabled when using injected KV Cache
[Stage-1] INFO 03-01 18:17:42 [diffusion_engine.py:80] Generation completed successfully.
[Stage-1] INFO 03-01 18:17:42 [diffusion_engine.py:98] Post-processing completed in 0.0000 seconds


Processed prompts: 100%|██████████| 1/1 [00:06<00:00,  6.29s/img, est. speed stage-1 img/s: 2.07, avg e2e_lat: 0.0ms]

Adding requests:   0%|          | 0/1 [00:06<?, ?it/s]
[Stage-0] INFO 03-01 18:17:42 [omni_stage.py:842] Received shutdown signal
[Stage-1] INFO 03-01 18:17:42 [omni_stage.py:842] Received shutdown signal
WARNING 03-01 18:17:42 [omni_stage.py:542] Failed to send shutdown to in_q: Socket operation on non-socket
(Worker pid=10064) [Stage-0] INFO 03-01 18:17:42 [multiproc_executor.py:732] Parent process exited, terminating worker
(Worker pid=10064) [Stage-0] INFO 03-01 18:17:42 [multiproc_executor.py:785] WorkerProc shutting down.
[Stage-1] INFO 03-01 18:17:42 [diffusion_worker.py:474] Worker 0: Received shutdown message
[Stage-1] INFO 03-01 18:17:42 [diffusion_worker.py:495] event loop terminated.
[Stage-1] INFO 03-01 18:17:42 [diffusion_worker.py:530] Worker 0: Shutdown complete.
WARNING 03-01 18:17:47 [omni_stage.py:542] Failed to send shutdown to in_q: Socket operation on non-socket
.Post-test GPU status:

================================================================================
NVIDIA GPU Information (nvidia-smi)
================================================================================
Sun Mar  1 18:17:48 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.57.01              Driver Version: 565.57.01      CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:00:0C.0 Off |                    0 |
| N/A   39C    P0             45W /  300W |       4MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

================================================================================
Detailed GPU Processes (nvidia-smi pmon)
================================================================================
# gpu         pid   type     sm    mem    enc    dec    jpg    ofa    command 
# Idx           #    C/G      %      %      %      %      %      %    name 
    0          -     -      -      -      -      -      -      -    -              


================================================================================
System Processes with GPU keywords
================================================================================


=============================== warnings summary ===============================
<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

../../usr/local/lib/python3.12/dist-packages/_pytest/config/__init__.py:1428
  /usr/local/lib/python3.12/dist-packages/_pytest/config/__init__.py:1428: PytestConfigWarning: Unknown config option: asyncio_mode
  
    self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=================== 1 passed, 3 warnings in 84.04s (0:01:24) ===================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute


Signed-off-by: vensen <vensenmu@gmail.com>
Signed-off-by: vensen <vensenmu@gmail.com>
@Flink-ddd Flink-ddd marked this pull request as ready for review March 1, 2026 17:51
@Flink-ddd Flink-ddd requested a review from hsliuustc0106 as a code owner March 1, 2026 17:51

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c19643c896

Signed-off-by: vensen <vensenmu@gmail.com>
@princepride princepride changed the title [feat][diffusion]: Implement Component-Level VRAM Quota and Resource Domain Isolation [Feat][Diffusion]: Implement Component-Level VRAM Quota and Resource Domain Isolation Mar 2, 2026