From a94cc85adfa414c53fbf81a796f4cad8e1e8f2bd Mon Sep 17 00:00:00 2001
From: Utkarsh Sharma
Date: Mon, 6 Oct 2025 06:14:42 +0000
Subject: [PATCH] Rename tpu_commons to tpu_inference

Signed-off-by: Utkarsh Sharma
---
 .../device_communicators/tpu_communicator.py       | 14 +++++++-------
 vllm/model_executor/model_loader/default_loader.py |  4 ++--
 vllm/platforms/__init__.py                         |  2 +-
 vllm/platforms/tpu.py                              | 10 +++++-----
 vllm/v1/attention/backends/pallas.py               |  2 +-
 vllm/v1/worker/tpu_worker.py                       | 12 ++++++------
 6 files changed, 22 insertions(+), 22 deletions(-)

diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py
index e0ac9df9a6af..b2faea512791 100644
--- a/vllm/distributed/device_communicators/tpu_communicator.py
+++ b/vllm/distributed/device_communicators/tpu_communicator.py
@@ -10,7 +10,7 @@
 from vllm.config import get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.platforms.tpu import USE_TPU_COMMONS
+from vllm.platforms.tpu import USE_TPU_INFERENCE
 
 from .base_device_communicator import DeviceCommunicatorBase
 
@@ -20,8 +20,8 @@
 
 logger = init_logger(__name__)
 
-if not USE_TPU_COMMONS:
-    logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
+if not USE_TPU_INFERENCE:
+    logger.info("tpu_inference not found, using vLLM's TpuCommunicator")
     if current_platform.is_tpu():
         import torch_xla
         import torch_xla.core.xla_model as xm
@@ -100,9 +100,9 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         return xm.all_gather(input_, dim=dim)
 
 
-if USE_TPU_COMMONS:
-    from tpu_commons.distributed.device_communicators import (
-        TpuCommunicator as TpuCommonsCommunicator,
+if USE_TPU_INFERENCE:
+    from tpu_inference.distributed.device_communicators import (
+        TpuCommunicator as TpuInferenceCommunicator,
     )
 
-    TpuCommunicator = TpuCommonsCommunicator  # type: ignore
+    TpuCommunicator = TpuInferenceCommunicator  # type: ignore
diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py
index 206b8244569f..4f2bfd89348e 100644
--- a/vllm/model_executor/model_loader/default_loader.py
+++ b/vllm/model_executor/model_loader/default_loader.py
@@ -222,9 +222,9 @@ def _get_weights_iterator(
         )
 
         if current_platform.is_tpu():
-            from vllm.platforms.tpu import USE_TPU_COMMONS
+            from vllm.platforms.tpu import USE_TPU_INFERENCE
 
-            if not USE_TPU_COMMONS:
+            if not USE_TPU_INFERENCE:
                 # In PyTorch XLA, we should call `torch_xla.sync`
                 # frequently so that not too many ops are accumulated
                 # in the XLA program.
diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py
index 5154b1cea782..962e1323b721 100644
--- a/vllm/platforms/__init__.py
+++ b/vllm/platforms/__init__.py
@@ -37,7 +37,7 @@ def tpu_platform_plugin() -> Optional[str]:
     # Check for Pathways TPU proxy
     if envs.VLLM_TPU_USING_PATHWAYS:
         logger.debug("Confirmed TPU platform is available via Pathways proxy.")
-        return "tpu_commons.platforms.tpu_jax.TpuPlatform"
+        return "tpu_inference.platforms.tpu_jax.TpuPlatform"
 
     # Check for libtpu installation
     try:
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 6be9ca1298a9..c0888247f593 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -26,7 +26,7 @@
 
 logger = init_logger(__name__)
 
-USE_TPU_COMMONS = False
+USE_TPU_INFERENCE = False
 
 
 class TpuPlatform(Platform):
@@ -254,10 +254,10 @@ def use_sync_weight_loader(cls) -> bool:
 
 
 try:
-    from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
+    from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform
 
-    TpuPlatform = TpuCommonsPlatform  # type: ignore
-    USE_TPU_COMMONS = True
+    TpuPlatform = TpuInferencePlatform  # type: ignore
+    USE_TPU_INFERENCE = True
 except ImportError:
-    logger.info("tpu_commons not found, using vLLM's TpuPlatform")
+    logger.info("tpu_inference not found, using vLLM's TpuPlatform")
     pass
diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index 7e83e7a681f4..1622f852a952 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -35,7 +35,7 @@
 }
 
 try:
-    import tpu_commons  # noqa: F401
+    import tpu_inference  # noqa: F401
 except ImportError:
     # Lazy import torch_xla
     import torch_xla.core.xla_builder as xb
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index 66515c7e5786..861d7ae737ee 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -23,7 +23,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
-from vllm.platforms.tpu import USE_TPU_COMMONS
+from vllm.platforms.tpu import USE_TPU_INFERENCE
 from vllm.tasks import SupportedTask
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -36,8 +36,8 @@
 
 _R = TypeVar("_R")
 
-if not USE_TPU_COMMONS:
-    logger.info("tpu_commons not found, using vLLM's TPUWorker.")
+if not USE_TPU_INFERENCE:
+    logger.info("tpu_inference not found, using vLLM's TPUWorker.")
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.profiler as xp
     import torch_xla.runtime as xr
@@ -346,7 +346,7 @@ def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
         return fn(self.get_model())
 
 
-if USE_TPU_COMMONS:
-    from tpu_commons.worker import TPUWorker as TPUCommonsWorker
+if USE_TPU_INFERENCE:
+    from tpu_inference.worker import TPUWorker as TpuInferenceWorker
 
-    TPUWorker = TPUCommonsWorker  # type: ignore
+    TPUWorker = TpuInferenceWorker  # type: ignore
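
Note for reviewers (not part of the patch): every file touched above uses the same optional-dependency fallback, i.e. try to import the out-of-tree package, alias its class over the in-tree one, and keep vLLM's implementation otherwise. The sketch below is illustrative only; it mirrors the shape of vllm/platforms/tpu.py, the in-tree TpuPlatform here is a placeholder stand-in, and the only assumption is what the patch itself imports (tpu_inference.platforms.TpuPlatform).

    # Illustrative sketch of the fallback pattern this rename touches.
    USE_TPU_INFERENCE = False


    class TpuPlatform:
        """Placeholder stand-in for vLLM's in-tree TpuPlatform."""


    try:
        # Prefer the out-of-tree implementation when it is installed.
        from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform

        TpuPlatform = TpuInferencePlatform  # type: ignore
        USE_TPU_INFERENCE = True
    except ImportError:
        # tpu_inference is absent; keep the in-tree class defined above.
        pass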