Skip to content

Commit 1298c67

Browse files
[FEAT] [Performance] Enable DP for ViT in Qwen2.5VL (#22742)
Signed-off-by: tjtanaa <[email protected]> Co-authored-by: DarkLight1337 <[email protected]> Co-authored-by: Cyrus Leung <[email protected]>
1 parent 4d9c619 commit 1298c67

File tree

5 files changed

+633
-48
lines changed

5 files changed

+633
-48
lines changed

tests/multimodal/test_utils.py

Lines changed: 324 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
import base64
5+
import math
56
import mimetypes
67
import os
78
from tempfile import NamedTemporaryFile, TemporaryDirectory
@@ -20,6 +21,8 @@
2021
from vllm.multimodal.image import convert_image_mode
2122
from vllm.multimodal.inputs import PlaceholderRange
2223
from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions,
24+
get_load_balance_assignment,
25+
run_dp_sharded_mrope_vision_model,
2326
run_dp_sharded_vision_model)
2427
from vllm.platforms import current_platform
2528
from vllm.utils import get_open_port, update_environment_variables
@@ -425,8 +428,8 @@ def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
425428
# Set random seed for reproducibility
426429
current_platform.seed_everything(0)
427430

428-
device = torch.device(f"cuda:{local_rank}")
429-
torch.cuda.set_device(device)
431+
device = f"{current_platform.device_name}:{local_rank}"
432+
current_platform.set_device(device)
430433
torch.set_default_device(device)
431434

432435
update_environment_variables({
@@ -463,3 +466,322 @@ def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
463466

464467
# Check that the outputs are close (they should be identical)
465468
assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5)
469+
470+
471+
@pytest.mark.parametrize(
472+
"sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts,"
473+
"expected_grouped_sizes_per_gpu,test_description",
474+
[
475+
# Empty input
476+
([], 2, [], [0, 0], [0, 0], "empty input"),
477+
478+
# Fewer samples than GPUs
479+
([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0
480+
], "fewer samples than GPUs"),
481+
482+
# Single GPU
483+
([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"),
484+
485+
# Balanced assignment
486+
([100, 100, 100, 100
487+
], 2, [0, 2, 1, 3], [2, 2], [200, 200], "balanced assignment"),
488+
489+
# Unbalanced sizes - this one is trickier since the algorithm is greedy
490+
([1000, 100, 200, 50], 2, [0, 2, 1, 3
491+
], [1, 3], [1000, 350], "unbalanced sizes"),
492+
],
493+
)
494+
def test_get_load_balance_assignment_cases(sizes, num_gpus,
495+
expected_shuffle_indices,
496+
expected_gpu_sample_counts,
497+
expected_grouped_sizes_per_gpu,
498+
test_description):
499+
"""Test get_load_balance_assignment with various input cases."""
500+
result = get_load_balance_assignment(sizes, num_gpus=num_gpus)
501+
(shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result
502+
503+
# Common assertions for all cases
504+
assert len(shuffle_indices) == len(sizes)
505+
assert len(gpu_sample_counts) == num_gpus
506+
assert len(grouped_sizes_per_gpu) == num_gpus
507+
assert sum(gpu_sample_counts) == len(sizes)
508+
509+
assert shuffle_indices == expected_shuffle_indices
510+
511+
assert gpu_sample_counts == expected_gpu_sample_counts
512+
assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu
513+
514+
515+
class SimpleMRopeVisionModel(torch.nn.Module):
516+
"""A simple vision model for testing mrope functionality."""
517+
518+
def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64):
519+
super().__init__()
520+
self.spatial_merge_size = spatial_merge_size
521+
self.out_hidden_size = out_hidden_size
522+
self.linear = torch.nn.Linear(768, out_hidden_size)
523+
524+
def forward(self, pixel_values: torch.Tensor,
525+
grid_thw_list: list[list[int]]):
526+
"""Simple forward pass that simulates spatial merging."""
527+
# Apply linear transformation
528+
embeddings = self.linear(pixel_values)
529+
530+
# Simulate spatial merging by reducing the number of patches
531+
merge_factor = self.spatial_merge_size * self.spatial_merge_size
532+
533+
# Group patches and merge spatially
534+
merged_embeddings = []
535+
start_idx = 0
536+
537+
for grid_thw in grid_thw_list:
538+
num_patches = math.prod(grid_thw)
539+
end_idx = start_idx + num_patches
540+
541+
# Get patches for this image
542+
image_patches = embeddings[start_idx:end_idx]
543+
544+
# Simulate spatial merging by averaging groups of patches
545+
merged_patches = num_patches // merge_factor
546+
if merged_patches > 0:
547+
# Reshape and average to simulate merging
548+
reshaped = image_patches[:merged_patches * merge_factor].view(
549+
merged_patches, merge_factor, -1)
550+
merged = reshaped.mean(dim=1)
551+
merged_embeddings.append(merged)
552+
553+
start_idx = end_idx
554+
555+
if merged_embeddings:
556+
return torch.cat(merged_embeddings, dim=0)
557+
else:
558+
return torch.empty((0, self.out_hidden_size),
559+
device=pixel_values.device,
560+
dtype=pixel_values.dtype)
561+
562+
563+
@multi_gpu_test(num_gpus=2)
564+
@pytest.mark.parametrize(
565+
"batch_size",
566+
[
567+
1, # Single image
568+
3, # Small batch
569+
5, # Odd batch size (for testing padding)
570+
],
571+
)
572+
def test_run_dp_sharded_mrope_vision_model(batch_size: int):
573+
world_size = 2
574+
# Launch processes
575+
mp.spawn(
576+
run_dp_sharded_mrope_vision_model_vs_direct,
577+
args=(
578+
world_size,
579+
batch_size,
580+
get_open_port(),
581+
),
582+
nprocs=world_size,
583+
)
584+
585+
586+
def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int,
587+
world_size: int,
588+
batch_size: int,
589+
master_port: int):
590+
"""
591+
Test that run_dp_sharded_mrope_vision_model produces the same results as
592+
calling the model directly.
593+
"""
594+
# Set random seed for reproducibility
595+
current_platform.seed_everything(0)
596+
device = f"{current_platform.device_name}:{local_rank}"
597+
current_platform.set_device(device)
598+
torch.set_default_device(device)
599+
600+
update_environment_variables({
601+
'RANK': str(local_rank),
602+
'LOCAL_RANK': str(local_rank),
603+
'WORLD_SIZE': str(world_size),
604+
'MASTER_ADDR': 'localhost',
605+
'MASTER_PORT': str(master_port),
606+
})
607+
608+
# initialize distributed
609+
init_distributed_environment()
610+
initialize_model_parallel(tensor_model_parallel_size=world_size)
611+
612+
# Create test data
613+
grid_thw_list = []
614+
pixel_values_list = []
615+
616+
for i in range(batch_size):
617+
# Varying image sizes for better testing
618+
t, h, w = 1, 4 + i, 4 + i
619+
grid_thw_list.append([t, h, w])
620+
621+
num_patches = t * h * w
622+
# Create random pixel values for this image
623+
image_pixels = torch.randn(num_patches, 768)
624+
pixel_values_list.append(image_pixels)
625+
626+
# Concatenate all pixel values
627+
pixel_values = torch.cat(pixel_values_list, dim=0)
628+
629+
# Create a simple mrope vision model
630+
vision_model = SimpleMRopeVisionModel()
631+
632+
# Run the model directly on the full input (only on rank 0)
633+
if local_rank == 0:
634+
with torch.inference_mode():
635+
direct_output = vision_model(pixel_values, grid_thw_list)
636+
637+
# Run the model through the sharded function
638+
with torch.inference_mode():
639+
sharded_output = run_dp_sharded_mrope_vision_model(
640+
vision_model, pixel_values, grid_thw_list)
641+
sharded_output = torch.cat(sharded_output, dim=0)
642+
643+
# Check that the world size is setup correctly
644+
assert get_tensor_model_parallel_world_size() == world_size
645+
646+
# Compare outputs (only on rank 0)
647+
if local_rank == 0:
648+
# Check that the outputs have the same shape
649+
assert direct_output.shape == sharded_output.shape
650+
# Check that the outputs are close (they should be identical)
651+
assert torch.allclose(direct_output,
652+
sharded_output,
653+
rtol=1e-5,
654+
atol=1e-5)
655+
656+
657+
@multi_gpu_test(num_gpus=2)
658+
def test_run_dp_sharded_mrope_vision_model_empty_input():
659+
world_size = 2
660+
mp.spawn(
661+
run_dp_sharded_mrope_vision_model_empty_input_worker,
662+
args=(world_size, get_open_port()),
663+
nprocs=world_size,
664+
)
665+
666+
667+
def run_dp_sharded_mrope_vision_model_empty_input_worker(
668+
local_rank: int, world_size: int, master_port: int):
669+
"""Test run_dp_sharded_mrope_vision_model with empty input."""
670+
# Set up distributed environment
671+
device = f"{current_platform.device_name}:{local_rank}"
672+
current_platform.set_device(device)
673+
torch.set_default_device(device)
674+
675+
update_environment_variables({
676+
'RANK': str(local_rank),
677+
'LOCAL_RANK': str(local_rank),
678+
'WORLD_SIZE': str(world_size),
679+
'MASTER_ADDR': 'localhost',
680+
'MASTER_PORT': str(master_port),
681+
})
682+
683+
init_distributed_environment()
684+
initialize_model_parallel(tensor_model_parallel_size=world_size)
685+
686+
# Create empty inputs
687+
pixel_values = torch.empty((0, 768))
688+
grid_thw_list: list[list[int]] = []
689+
690+
vision_model = SimpleMRopeVisionModel()
691+
692+
# Should handle empty input gracefully
693+
with torch.inference_mode():
694+
output = run_dp_sharded_mrope_vision_model(vision_model, pixel_values,
695+
grid_thw_list)
696+
697+
assert len(output) == 0
698+
699+
700+
@multi_gpu_test(num_gpus=4)
701+
def test_run_dp_sharded_mrope_vision_model_uneven_load():
702+
world_size = 4
703+
mp.spawn(
704+
run_dp_sharded_mrope_vision_model_uneven_load_worker,
705+
args=(world_size, get_open_port()),
706+
nprocs=world_size,
707+
)
708+
709+
710+
def run_dp_sharded_mrope_vision_model_uneven_load_worker(
711+
local_rank: int, world_size: int, master_port: int):
712+
"""Test run_dp_sharded_mrope_vision_model with uneven load distribution."""
713+
# Set up distributed environment
714+
current_platform.seed_everything(123)
715+
device = f"{current_platform.device_name}:{local_rank}"
716+
current_platform.set_device(device)
717+
torch.set_default_device(device)
718+
719+
update_environment_variables({
720+
'RANK': str(local_rank),
721+
'LOCAL_RANK': str(local_rank),
722+
'WORLD_SIZE': str(world_size),
723+
'MASTER_ADDR': 'localhost',
724+
'MASTER_PORT': str(master_port),
725+
})
726+
727+
init_distributed_environment()
728+
initialize_model_parallel(tensor_model_parallel_size=world_size)
729+
730+
# Create images with very different sizes
731+
grid_thw_list = [
732+
[1, 2, 2], # Small: 4 patches
733+
[1, 8, 8], # Large: 64 patches
734+
[1, 3, 3], # Medium: 9 patches
735+
]
736+
737+
pixel_values_list = []
738+
for grid_thw in grid_thw_list:
739+
num_patches = math.prod(grid_thw)
740+
image_pixels = torch.randn(num_patches, 768)
741+
pixel_values_list.append(image_pixels)
742+
743+
pixel_values = torch.cat(pixel_values_list, dim=0)
744+
vision_model = SimpleMRopeVisionModel()
745+
746+
# Should handle uneven distribution without errors
747+
with torch.inference_mode():
748+
output_tuple = run_dp_sharded_mrope_vision_model(
749+
vision_model, pixel_values, grid_thw_list)
750+
751+
# Verify output shape is reasonable
752+
merge_factor = vision_model.spatial_merge_size**2
753+
expected_output_patches = list(
754+
math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list)
755+
756+
for i, output in enumerate(output_tuple):
757+
assert output.shape[0] == expected_output_patches[i]
758+
assert output.shape[1] == vision_model.out_hidden_size
759+
760+
761+
@pytest.mark.parametrize("spatial_merge_size", [2, 4])
762+
def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int):
763+
"""Test SimpleMRopeVisionModel with different spatial merge sizes."""
764+
device = current_platform.device_type
765+
766+
grid_thw_list = [[1, 4, 4], [1, 6, 6]] # Two images
767+
pixel_values_list = []
768+
769+
for grid_thw in grid_thw_list:
770+
num_patches = math.prod(grid_thw)
771+
image_pixels = torch.randn(num_patches, 768, device=device)
772+
pixel_values_list.append(image_pixels)
773+
774+
pixel_values = torch.cat(pixel_values_list, dim=0)
775+
vision_model = SimpleMRopeVisionModel(
776+
spatial_merge_size=spatial_merge_size).to(device)
777+
778+
with torch.inference_mode():
779+
output = vision_model(pixel_values, grid_thw_list)
780+
781+
# Verify output dimensions based on spatial merging
782+
total_patches = sum(math.prod(grid_thw) for grid_thw in grid_thw_list)
783+
merge_factor = spatial_merge_size**2
784+
expected_output_patches = total_patches // merge_factor
785+
786+
assert output.shape[0] == expected_output_patches
787+
assert output.shape[1] == vision_model.out_hidden_size

vllm/model_executor/layers/linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ def weight_loader(self,
437437
shard_offset = sum(self.output_sizes[:loaded_shard_id])
438438
shard_size = self.output_sizes[loaded_shard_id]
439439

440-
param[shard_offset:shard_offset + shard_size] = loaded_weight
440+
param.data[shard_offset:shard_offset + shard_size] = loaded_weight
441441

442442

443443
@CustomOp.register("column_parallel_linear")

0 commit comments

Comments
 (0)