Commit 43c4f3d (parent: 1712543)

[Misc] Begin deprecation of get_tensor_model_*_group (#22494)

Signed-off-by: DarkLight1337 <[email protected]>
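
For downstream callers, the migration is mechanical, as the test changes below illustrate. A minimal before/after sketch (assuming the vLLM distributed environment has already been initialized):

    # Before: emits a DeprecationWarning as of this commit
    from vllm.distributed.parallel_state import get_tensor_model_parallel_group
    group = get_tensor_model_parallel_group().device_group

    # After: the direct replacement
    from vllm.distributed.parallel_state import get_tp_group
    group = get_tp_group().device_group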

3 files changed: 16 additions, 10 deletions

tests/distributed/test_custom_all_reduce.py (2 additions, 3 deletions)

@@ -10,8 +10,7 @@
 
 from vllm.distributed.communication_op import (  # noqa
     tensor_model_parallel_all_reduce)
-from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
-                                             get_tp_group, graph_capture)
+from vllm.distributed.parallel_state import get_tp_group, graph_capture
 
 from ..utils import (ensure_model_parallel_initialized,
                      init_test_distributed_environment, multi_process_parallel)
@@ -37,7 +36,7 @@ def graph_allreduce(
     init_test_distributed_environment(tp_size, pp_size, rank,
                                       distributed_init_port)
     ensure_model_parallel_initialized(tp_size, pp_size)
-    group = get_tensor_model_parallel_group().device_group
+    group = get_tp_group().device_group
 
     # A small all_reduce for warmup.
     # this is needed because device communicators might be created lazily
tests/distributed/test_quick_all_reduce.py (2 additions, 3 deletions)

@@ -10,8 +10,7 @@
 
 from vllm.distributed.communication_op import (  # noqa
     tensor_model_parallel_all_reduce)
-from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
-                                             get_tp_group, graph_capture)
+from vllm.distributed.parallel_state import get_tp_group, graph_capture
 from vllm.platforms import current_platform
 
 from ..utils import (ensure_model_parallel_initialized,
@@ -42,7 +41,7 @@ def graph_quickreduce(
     init_test_distributed_environment(tp_size, pp_size, rank,
                                       distributed_init_port)
     ensure_model_parallel_initialized(tp_size, pp_size)
-    group = get_tensor_model_parallel_group().device_group
+    group = get_tp_group().device_group
 
     # A small all_reduce for warmup.
     # this is needed because device communicators might be created lazily

vllm/distributed/parallel_state.py (12 additions, 4 deletions)

@@ -36,6 +36,7 @@
 import torch
 import torch.distributed
 from torch.distributed import Backend, ProcessGroup
+from typing_extensions import deprecated
 
 import vllm.envs as envs
 from vllm.distributed.device_communicators.base_device_communicator import (
@@ -894,8 +895,12 @@ def get_tp_group() -> GroupCoordinator:
     return _TP
 
 
-# kept for backward compatibility
-get_tensor_model_parallel_group = get_tp_group
+@deprecated("`get_tensor_model_parallel_group` has been replaced with "
+            "`get_tp_group` and may be removed after v0.12. Please use "
+            "`get_tp_group` instead.")
+def get_tensor_model_parallel_group():
+    return get_tp_group()
+
 
 _PP: Optional[GroupCoordinator] = None
 
@@ -921,8 +926,11 @@ def get_pp_group() -> GroupCoordinator:
     return _PP
 
 
-# kept for backward compatibility
-get_pipeline_model_parallel_group = get_pp_group
+@deprecated("`get_pipeline_model_parallel_group` has been replaced with "
+            "`get_pp_group` and may be removed in v0.12. Please use "
+            "`get_pp_group` instead.")
+def get_pipeline_model_parallel_group():
+    return get_pp_group()
 
 
 @contextmanager
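
The deprecation relies on typing_extensions.deprecated (PEP 702), which both flags calls in static type checkers and emits a DeprecationWarning at runtime, so the old names keep working for now while warning callers. A standalone sketch of that behavior (illustrative names, not vLLM code):

    import warnings

    from typing_extensions import deprecated

    def new_name() -> int:
        return 42

    @deprecated("`old_name` has been replaced with `new_name`.")
    def old_name() -> int:
        return new_name()

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        old_name()  # the decorator issues a DeprecationWarning at call time

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)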
