@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import base64
+import math
 import mimetypes
 import os
 from tempfile import NamedTemporaryFile, TemporaryDirectory
@@ -20,6 +21,8 @@
 from vllm.multimodal.image import convert_image_mode
 from vllm.multimodal.inputs import PlaceholderRange
 from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions,
+                                   get_load_balance_assignment,
+                                   run_dp_sharded_mrope_vision_model,
                                    run_dp_sharded_vision_model)
 from vllm.platforms import current_platform
 from vllm.utils import get_open_port, update_environment_variables
@@ -425,8 +428,8 @@ def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
     # Set random seed for reproducibility
     current_platform.seed_everything(0)
 
-    device = torch.device(f"cuda:{local_rank}")
-    torch.cuda.set_device(device)
+    device = f"{current_platform.device_name}:{local_rank}"
+    current_platform.set_device(device)
     torch.set_default_device(device)
 
     update_environment_variables({
@@ -463,3 +466,322 @@ def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
 
     # Check that the outputs are close (they should be identical)
     assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5)
+
+
+@pytest.mark.parametrize(
+    "sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts,"
+    "expected_grouped_sizes_per_gpu,test_description",
+    [
+        # Empty input
+        ([], 2, [], [0, 0], [0, 0], "empty input"),
+
+        # Fewer samples than GPUs
+        ([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0],
+         "fewer samples than GPUs"),
+
+        # Single GPU
+        ([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"),
+
+        # Balanced assignment
+        ([100, 100, 100, 100], 2, [0, 2, 1, 3], [2, 2], [200, 200],
+         "balanced assignment"),
+
+        # Unbalanced sizes - this one is trickier since the algorithm is greedy
+        ([1000, 100, 200, 50], 2, [0, 2, 1, 3], [1, 3], [1000, 350],
+         "unbalanced sizes"),
+    ],
+)
+def test_get_load_balance_assignment_cases(sizes, num_gpus,
+                                           expected_shuffle_indices,
+                                           expected_gpu_sample_counts,
+                                           expected_grouped_sizes_per_gpu,
+                                           test_description):
+    """Test get_load_balance_assignment with various input cases."""
+    result = get_load_balance_assignment(sizes, num_gpus=num_gpus)
+    (shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result
+
+    # Common assertions for all cases
+    assert len(shuffle_indices) == len(sizes)
+    assert len(gpu_sample_counts) == num_gpus
+    assert len(grouped_sizes_per_gpu) == num_gpus
+    assert sum(gpu_sample_counts) == len(sizes)
+
+    assert shuffle_indices == expected_shuffle_indices
+
+    assert gpu_sample_counts == expected_gpu_sample_counts
+    assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu
+
+
+class SimpleMRopeVisionModel(torch.nn.Module):
+    """A simple vision model for testing mrope functionality."""
+
+    def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64):
+        super().__init__()
+        self.spatial_merge_size = spatial_merge_size
+        self.out_hidden_size = out_hidden_size
+        self.linear = torch.nn.Linear(768, out_hidden_size)
+
+    def forward(self, pixel_values: torch.Tensor,
+                grid_thw_list: list[list[int]]):
+        """Simple forward pass that simulates spatial merging."""
+        # Apply linear transformation
+        embeddings = self.linear(pixel_values)
+
+        # Simulate spatial merging by reducing the number of patches
+        merge_factor = self.spatial_merge_size * self.spatial_merge_size
+
+        # Group patches and merge spatially
+        merged_embeddings = []
+        start_idx = 0
+
+        for grid_thw in grid_thw_list:
+            num_patches = math.prod(grid_thw)
+            end_idx = start_idx + num_patches
+
+            # Get patches for this image
+            image_patches = embeddings[start_idx:end_idx]
+
+            # Simulate spatial merging by averaging groups of patches
+            merged_patches = num_patches // merge_factor
+            if merged_patches > 0:
+                # Reshape and average to simulate merging
+                reshaped = image_patches[:merged_patches * merge_factor].view(
+                    merged_patches, merge_factor, -1)
+                merged = reshaped.mean(dim=1)
+                merged_embeddings.append(merged)
+
+            start_idx = end_idx
+
+        if merged_embeddings:
+            return torch.cat(merged_embeddings, dim=0)
+        else:
+            return torch.empty((0, self.out_hidden_size),
+                               device=pixel_values.device,
+                               dtype=pixel_values.dtype)
+
+
+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize(
+    "batch_size",
+    [
+        1,  # Single image
+        3,  # Small batch
+        5,  # Odd batch size (for testing padding)
+    ],
+)
+def test_run_dp_sharded_mrope_vision_model(batch_size: int):
+    world_size = 2
+    # Launch processes
+    mp.spawn(
+        run_dp_sharded_mrope_vision_model_vs_direct,
+        args=(
+            world_size,
+            batch_size,
+            get_open_port(),
+        ),
+        nprocs=world_size,
+    )
+
+
+def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int,
+                                                world_size: int,
+                                                batch_size: int,
+                                                master_port: int):
+    """
+    Test that run_dp_sharded_mrope_vision_model produces the same results as
+    calling the model directly.
+    """
+    # Set random seed for reproducibility
+    current_platform.seed_everything(0)
+    device = f"{current_platform.device_name}:{local_rank}"
+    current_platform.set_device(device)
+    torch.set_default_device(device)
+
+    update_environment_variables({
+        'RANK': str(local_rank),
+        'LOCAL_RANK': str(local_rank),
+        'WORLD_SIZE': str(world_size),
+        'MASTER_ADDR': 'localhost',
+        'MASTER_PORT': str(master_port),
+    })
+
+    # initialize distributed
+    init_distributed_environment()
+    initialize_model_parallel(tensor_model_parallel_size=world_size)
+
+    # Create test data
+    grid_thw_list = []
+    pixel_values_list = []
+
+    for i in range(batch_size):
+        # Varying image sizes for better testing
+        t, h, w = 1, 4 + i, 4 + i
+        grid_thw_list.append([t, h, w])
+
+        num_patches = t * h * w
+        # Create random pixel values for this image
+        image_pixels = torch.randn(num_patches, 768)
+        pixel_values_list.append(image_pixels)
+
+    # Concatenate all pixel values
+    pixel_values = torch.cat(pixel_values_list, dim=0)
+
+    # Create a simple mrope vision model
+    vision_model = SimpleMRopeVisionModel()
+
+    # Run the model directly on the full input (only on rank 0)
+    if local_rank == 0:
+        with torch.inference_mode():
+            direct_output = vision_model(pixel_values, grid_thw_list)
+
+    # Run the model through the sharded function
+    with torch.inference_mode():
+        sharded_output = run_dp_sharded_mrope_vision_model(
+            vision_model, pixel_values, grid_thw_list)
+        sharded_output = torch.cat(sharded_output, dim=0)
+
+    # Check that the world size is setup correctly
+    assert get_tensor_model_parallel_world_size() == world_size
+
+    # Compare outputs (only on rank 0)
+    if local_rank == 0:
+        # Check that the outputs have the same shape
+        assert direct_output.shape == sharded_output.shape
+        # Check that the outputs are close (they should be identical)
+        assert torch.allclose(direct_output,
+                              sharded_output,
+                              rtol=1e-5,
+                              atol=1e-5)
+
+
+@multi_gpu_test(num_gpus=2)
+def test_run_dp_sharded_mrope_vision_model_empty_input():
+    world_size = 2
+    mp.spawn(
+        run_dp_sharded_mrope_vision_model_empty_input_worker,
+        args=(world_size, get_open_port()),
+        nprocs=world_size,
+    )
+
+
+def run_dp_sharded_mrope_vision_model_empty_input_worker(
+        local_rank: int, world_size: int, master_port: int):
+    """Test run_dp_sharded_mrope_vision_model with empty input."""
+    # Set up distributed environment
+    device = f"{current_platform.device_name}:{local_rank}"
+    current_platform.set_device(device)
+    torch.set_default_device(device)
+
+    update_environment_variables({
+        'RANK': str(local_rank),
+        'LOCAL_RANK': str(local_rank),
+        'WORLD_SIZE': str(world_size),
+        'MASTER_ADDR': 'localhost',
+        'MASTER_PORT': str(master_port),
+    })
+
+    init_distributed_environment()
+    initialize_model_parallel(tensor_model_parallel_size=world_size)
+
+    # Create empty inputs
+    pixel_values = torch.empty((0, 768))
+    grid_thw_list: list[list[int]] = []
+
+    vision_model = SimpleMRopeVisionModel()
+
+    # Should handle empty input gracefully
+    with torch.inference_mode():
+        output = run_dp_sharded_mrope_vision_model(vision_model, pixel_values,
+                                                   grid_thw_list)
+
+    assert len(output) == 0
+
+
+@multi_gpu_test(num_gpus=4)
+def test_run_dp_sharded_mrope_vision_model_uneven_load():
+    world_size = 4
+    mp.spawn(
+        run_dp_sharded_mrope_vision_model_uneven_load_worker,
+        args=(world_size, get_open_port()),
+        nprocs=world_size,
+    )
+
+
+def run_dp_sharded_mrope_vision_model_uneven_load_worker(
+        local_rank: int, world_size: int, master_port: int):
+    """Test run_dp_sharded_mrope_vision_model with uneven load distribution."""
+    # Set up distributed environment
+    current_platform.seed_everything(123)
+    device = f"{current_platform.device_name}:{local_rank}"
+    current_platform.set_device(device)
+    torch.set_default_device(device)
+
+    update_environment_variables({
+        'RANK': str(local_rank),
+        'LOCAL_RANK': str(local_rank),
+        'WORLD_SIZE': str(world_size),
+        'MASTER_ADDR': 'localhost',
+        'MASTER_PORT': str(master_port),
+    })
+
+    init_distributed_environment()
+    initialize_model_parallel(tensor_model_parallel_size=world_size)
+
+    # Create images with very different sizes
+    grid_thw_list = [
+        [1, 2, 2],  # Small: 4 patches
+        [1, 8, 8],  # Large: 64 patches
+        [1, 3, 3],  # Medium: 9 patches
+    ]
+
+    pixel_values_list = []
+    for grid_thw in grid_thw_list:
+        num_patches = math.prod(grid_thw)
+        image_pixels = torch.randn(num_patches, 768)
+        pixel_values_list.append(image_pixels)
+
+    pixel_values = torch.cat(pixel_values_list, dim=0)
+    vision_model = SimpleMRopeVisionModel()
+
+    # Should handle uneven distribution without errors
+    with torch.inference_mode():
+        output_tuple = run_dp_sharded_mrope_vision_model(
+            vision_model, pixel_values, grid_thw_list)
+
+    # Verify output shape is reasonable
+    merge_factor = vision_model.spatial_merge_size**2
+    expected_output_patches = list(
+        math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list)
+
+    for i, output in enumerate(output_tuple):
+        assert output.shape[0] == expected_output_patches[i]
+        assert output.shape[1] == vision_model.out_hidden_size
+
+
+@pytest.mark.parametrize("spatial_merge_size", [2, 4])
+def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int):
+    """Test SimpleMRopeVisionModel with different spatial merge sizes."""
+    device = current_platform.device_type
+
+    grid_thw_list = [[1, 4, 4], [1, 6, 6]]  # Two images
+    pixel_values_list = []
+
+    for grid_thw in grid_thw_list:
+        num_patches = math.prod(grid_thw)
+        image_pixels = torch.randn(num_patches, 768, device=device)
+        pixel_values_list.append(image_pixels)
+
+    pixel_values = torch.cat(pixel_values_list, dim=0)
+    vision_model = SimpleMRopeVisionModel(
+        spatial_merge_size=spatial_merge_size).to(device)
+
+    with torch.inference_mode():
+        output = vision_model(pixel_values, grid_thw_list)
+
+    # Verify output dimensions based on spatial merging
+    total_patches = sum(math.prod(grid_thw) for grid_thw in grid_thw_list)
+    merge_factor = spatial_merge_size**2
+    expected_output_patches = total_patches // merge_factor
+
+    assert output.shape[0] == expected_output_patches
+    assert output.shape[1] == vision_model.out_hidden_size
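
For readers skimming the parametrized cases in `test_get_load_balance_assignment_cases`: the expected outputs are consistent with a greedy "largest sample to least-loaded GPU" policy. The sketch below only illustrates that expected behavior under this assumption; `greedy_load_balance` is a hypothetical stand-in, not vLLM's actual `get_load_balance_assignment` implementation.

```python
# Illustrative sketch (assumption: greedy longest-first assignment), not
# vLLM's implementation of get_load_balance_assignment.
def greedy_load_balance(sizes: list[int], num_gpus: int):
    gpu_loads = [0] * num_gpus  # total size assigned to each GPU
    gpu_samples: list[list[int]] = [[] for _ in range(num_gpus)]

    # Visit samples from largest to smallest; give each to the currently
    # least-loaded GPU (ties go to the lowest GPU index).
    for idx in sorted(range(len(sizes)), key=lambda i: sizes[i], reverse=True):
        gpu = min(range(num_gpus), key=lambda g: gpu_loads[g])
        gpu_samples[gpu].append(idx)
        gpu_loads[gpu] += sizes[idx]

    # Flatten per-GPU sample indices into the shuffle order, plus per-GPU
    # sample counts and per-GPU total sizes.
    shuffle_indices = [i for samples in gpu_samples for i in samples]
    gpu_sample_counts = [len(samples) for samples in gpu_samples]
    return shuffle_indices, gpu_sample_counts, gpu_loads


# Matches the "unbalanced sizes" case above.
assert greedy_load_balance([1000, 100, 200, 50], num_gpus=2) == (
    [0, 2, 1, 3], [1, 3], [1000, 350])
```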