Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 163 additions & 0 deletions tests/diffusion/distributed/test_hsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,169 @@ def test_from_dict_with_hsdp(self):
assert config.hsdp_shard_size == 2 # auto: 4 // 2


class TestStandaloneHSDPDetection:
    """Tests for standalone HSDP detection and dit_parallel_size calculation.

    Mirrors the decision logic used by initialize_model_parallel(): HSDP is
    considered "standalone" exactly when every non-HSDP parallelism dimension
    (DP, CFG, SP, PP, TP) equals 1 and sharding is actually enabled.
    """

    @staticmethod
    def compute_standalone_hsdp_params(
        data_parallel_size: int = 1,
        cfg_parallel_size: int = 1,
        sequence_parallel_size: int = 1,
        pipeline_parallel_size: int = 1,
        tensor_parallel_size: int = 1,
        fully_shard_degree: int = 1,
        hsdp_replicate_size: int = 1,
    ) -> dict:
        """Replicate the standalone-HSDP detection from initialize_model_parallel().

        Returns a dict with the pre-override DiT size, the standalone flag,
        and the effective DiT / data-parallel sizes after the override.
        """
        # Product of all non-HSDP parallelism dimensions.
        base_size = 1
        for dim in (
            data_parallel_size,
            cfg_parallel_size,
            sequence_parallel_size,
            pipeline_parallel_size,
            tensor_parallel_size,
        ):
            base_size *= dim

        # Total HSDP mesh size (shard x replicate).
        hsdp_world = fully_shard_degree * hsdp_replicate_size

        # Standalone only when nothing else is parallelized and sharding is on.
        standalone = base_size == 1 and fully_shard_degree > 1

        return {
            "original_dit_parallel_size": base_size,
            "is_standalone_hsdp": standalone,
            # Standalone mode sizes everything by the HSDP mesh; otherwise
            # the original values are kept untouched.
            "effective_dit_parallel_size": hsdp_world if standalone else base_size,
            "effective_dp_size": hsdp_world if standalone else data_parallel_size,
        }

    def test_standalone_hsdp_basic(self):
        """Pure sharding (shard=4, replicate=1) is detected as standalone."""
        params = self.compute_standalone_hsdp_params(
            fully_shard_degree=4,
            hsdp_replicate_size=1,
        )
        assert params == {
            "original_dit_parallel_size": 1,
            "is_standalone_hsdp": True,
            "effective_dit_parallel_size": 4,
            "effective_dp_size": 4,
        }

    def test_standalone_hsdp_with_replicate(self):
        """Sharding plus replication (shard=4, replicate=2) is standalone."""
        params = self.compute_standalone_hsdp_params(
            fully_shard_degree=4,
            hsdp_replicate_size=2,
        )
        assert params == {
            "original_dit_parallel_size": 1,
            "is_standalone_hsdp": True,
            "effective_dit_parallel_size": 8,  # 4 * 2
            "effective_dp_size": 8,
        }

    def test_combined_hsdp_sp_not_standalone(self):
        """HSDP combined with SP must NOT be treated as standalone.

        Regression guard: the old condition
        `dit_parallel_size == fully_shard_degree` wrongly matched combined
        configurations such as SP=4 + HSDP=4.
        """
        params = self.compute_standalone_hsdp_params(
            sequence_parallel_size=4,
            fully_shard_degree=4,
            hsdp_replicate_size=1,
        )
        assert params["original_dit_parallel_size"] == 4
        assert params["is_standalone_hsdp"] is False
        # Combined mode keeps the caller-supplied data_parallel_size.
        assert params["effective_dp_size"] == 1

    def test_combined_hsdp_cfg_not_standalone(self):
        """HSDP combined with CFG parallelism is not standalone."""
        params = self.compute_standalone_hsdp_params(
            cfg_parallel_size=2,
            fully_shard_degree=4,
            hsdp_replicate_size=1,
        )
        assert params["original_dit_parallel_size"] == 2
        assert params["is_standalone_hsdp"] is False
        assert params["effective_dp_size"] == 1

    def test_combined_hsdp_dp_not_standalone(self):
        """HSDP combined with DP is not standalone."""
        params = self.compute_standalone_hsdp_params(
            data_parallel_size=2,
            fully_shard_degree=4,
            hsdp_replicate_size=1,
        )
        assert params["original_dit_parallel_size"] == 2
        assert params["is_standalone_hsdp"] is False
        # Falls back to the original data_parallel_size.
        assert params["effective_dp_size"] == 2

    def test_combined_hsdp_pp_not_standalone(self):
        """HSDP combined with PP is not standalone."""
        params = self.compute_standalone_hsdp_params(
            pipeline_parallel_size=2,
            fully_shard_degree=4,
            hsdp_replicate_size=1,
        )
        assert params["original_dit_parallel_size"] == 2
        assert params["is_standalone_hsdp"] is False
        assert params["effective_dp_size"] == 1

    def test_no_hsdp_not_standalone(self):
        """With sharding disabled (fully_shard_degree=1) nothing is standalone."""
        params = self.compute_standalone_hsdp_params(fully_shard_degree=1)
        assert params["original_dit_parallel_size"] == 1
        assert params["is_standalone_hsdp"] is False
        assert params["effective_dp_size"] == 1

    def test_combined_multiple_parallelism_not_standalone(self):
        """HSDP stacked on several other dimensions is not standalone."""
        params = self.compute_standalone_hsdp_params(
            sequence_parallel_size=2,
            cfg_parallel_size=2,
            fully_shard_degree=4,
            hsdp_replicate_size=1,
        )
        assert params["original_dit_parallel_size"] == 4  # 2 * 2
        assert params["is_standalone_hsdp"] is False
        assert params["effective_dp_size"] == 1

    def test_standalone_hsdp_large_shard(self):
        """A large shard degree alone still qualifies as standalone."""
        params = self.compute_standalone_hsdp_params(
            fully_shard_degree=8,
            hsdp_replicate_size=1,
        )
        assert params["is_standalone_hsdp"] is True
        assert params["effective_dit_parallel_size"] == 8
        assert params["effective_dp_size"] == 8

    def test_standalone_hsdp_large_replicate(self):
        """A large replicate degree multiplies into the effective sizes."""
        params = self.compute_standalone_hsdp_params(
            fully_shard_degree=4,
            hsdp_replicate_size=4,
        )
        assert params["is_standalone_hsdp"] is True
        assert params["effective_dit_parallel_size"] == 16  # 4 * 4
        assert params["effective_dp_size"] == 16


class TestHSDPShardConditions:
"""Tests for _hsdp_shard_conditions matching logic."""

Expand Down
15 changes: 14 additions & 1 deletion vllm_omni/diffusion/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,7 @@ def initialize_model_parallel(
tensor_parallel_size: int = 1,
pipeline_parallel_size: int = 1,
fully_shard_degree: int = 1,
hsdp_replicate_size: int = 1,
backend: str | None = None,
) -> None:
if backend is None:
Expand Down Expand Up @@ -740,6 +741,14 @@ def initialize_model_parallel(
data_parallel_size * cfg_parallel_size * sequence_parallel_size * pipeline_parallel_size * tensor_parallel_size
)

# Check for standalone HSDP: all non-HSDP parallelism dimensions are 1
is_standalone_hsdp = dit_parallel_size == 1 and fully_shard_degree > 1

# For standalone HSDP: use (fully_shard_degree * hsdp_replicate_size) as dit_parallel_size
# This ensures orthogonal rank generation works correctly for all HSDP workers
if is_standalone_hsdp:
dit_parallel_size = fully_shard_degree * hsdp_replicate_size

if world_size < dit_parallel_size:
raise RuntimeError(
f"world_size ({world_size}) is less than "
Expand All @@ -751,12 +760,16 @@ def initialize_model_parallel(
f"data_parallel_size ({data_parallel_size})"
)

# For standalone HSDP, use (fully_shard_degree * hsdp_replicate_size) as data_parallel_size
# so that RankGenerator.world_size matches the actual number of workers
effective_dp_size = (fully_shard_degree * hsdp_replicate_size) if is_standalone_hsdp else data_parallel_size

rank_generator: RankGenerator = RankGenerator(
tensor_parallel_size,
sequence_parallel_size,
pipeline_parallel_size,
cfg_parallel_size,
data_parallel_size,
effective_dp_size,
fs=fully_shard_degree,
order="tp-sp-pp-cfg-dp",
)
Expand Down
1 change: 1 addition & 0 deletions vllm_omni/diffusion/worker/diffusion_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def init_device(self) -> None:
tensor_parallel_size=parallel_config.tensor_parallel_size,
pipeline_parallel_size=parallel_config.pipeline_parallel_size,
fully_shard_degree=parallel_config.hsdp_shard_size if parallel_config.use_hsdp else 1,
hsdp_replicate_size=parallel_config.hsdp_replicate_size if parallel_config.use_hsdp else 1,
)
init_workspace_manager(self.device)

Expand Down