14 changes: 7 additions & 7 deletions vllm/distributed/device_communicators/tpu_communicator.py
@@ -10,7 +10,7 @@
 from vllm.config import get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.platforms.tpu import USE_TPU_COMMONS
+from vllm.platforms.tpu import USE_TPU_INFERENCE
 
 from .base_device_communicator import DeviceCommunicatorBase
 
@@ -20,8 +20,8 @@
 
 logger = init_logger(__name__)
 
-if not USE_TPU_COMMONS:
-    logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
+if not USE_TPU_INFERENCE:
+    logger.info("tpu_inference not found, using vLLM's TpuCommunicator")
     if current_platform.is_tpu():
         import torch_xla
         import torch_xla.core.xla_model as xm
@@ -100,9 +100,9 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         return xm.all_gather(input_, dim=dim)
 
 
-if USE_TPU_COMMONS:
-    from tpu_commons.distributed.device_communicators import (
-        TpuCommunicator as TpuCommonsCommunicator,
+if USE_TPU_INFERENCE:
+    from tpu_inference.distributed.device_communicators import (
+        TpuCommunicator as TpuInferenceCommunicator,
     )
 
-    TpuCommunicator = TpuCommonsCommunicator  # type: ignore
+    TpuCommunicator = TpuInferenceCommunicator  # type: ignore
4 changes: 2 additions & 2 deletions vllm/model_executor/model_loader/default_loader.py
@@ -222,9 +222,9 @@ def _get_weights_iterator(
             )
 
         if current_platform.is_tpu():
-            from vllm.platforms.tpu import USE_TPU_COMMONS
+            from vllm.platforms.tpu import USE_TPU_INFERENCE
 
-            if not USE_TPU_COMMONS:
+            if not USE_TPU_INFERENCE:
                 # In PyTorch XLA, we should call `torch_xla.sync`
                 # frequently so that not too many ops are accumulated
                 # in the XLA program.
2 changes: 1 addition & 1 deletion vllm/platforms/__init__.py
@@ -37,7 +37,7 @@ def tpu_platform_plugin() -> Optional[str]:
     # Check for Pathways TPU proxy
     if envs.VLLM_TPU_USING_PATHWAYS:
         logger.debug("Confirmed TPU platform is available via Pathways proxy.")
-        return "tpu_commons.platforms.tpu_jax.TpuPlatform"
+        return "tpu_inference.platforms.tpu_jax.TpuPlatform"
 
     # Check for libtpu installation
     try:
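
Note on the hunk above: `tpu_platform_plugin()` returns the platform class as a dotted-path string rather than importing it eagerly, and the caller resolves that string only if this plugin is selected. A minimal sketch of resolving such a qualified name (the helper name here is illustrative, not vLLM's actual API):

import importlib


def resolve_qualname(qualname: str):
    """Resolve a dotted path like 'pkg.module.ClassName' into the object."""
    module_name, _, attr = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr)


# Example: resolve a stdlib class by its dotted path.
OrderedDict = resolve_qualname("collections.OrderedDict")
assert OrderedDict().__class__.__name__ == "OrderedDict"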
10 changes: 5 additions & 5 deletions vllm/platforms/tpu.py
@@ -26,7 +26,7 @@
 
 logger = init_logger(__name__)
 
-USE_TPU_COMMONS = False
+USE_TPU_INFERENCE = False
 
 
 class TpuPlatform(Platform):
@@ -254,10 +254,10 @@ def use_sync_weight_loader(cls) -> bool:
 
 
 try:
-    from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
+    from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform
 
-    TpuPlatform = TpuCommonsPlatform  # type: ignore
-    USE_TPU_COMMONS = True
+    TpuPlatform = TpuInferencePlatform  # type: ignore
+    USE_TPU_INFERENCE = True
 except ImportError:
-    logger.info("tpu_commons not found, using vLLM's TpuPlatform")
+    logger.info("tpu_inference not found, using vLLM's TpuPlatform")
     pass
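
The try/except at the bottom of tpu.py is the standard optional-dependency override: attempt the external import, alias its class over the in-tree one, and record a module-level flag for the rest of the codebase to check. A self-contained sketch of the same pattern, where `fast_impl` is a hypothetical stand-in for tpu_inference:

import logging

logger = logging.getLogger(__name__)

USE_FAST_IMPL = False


class Greeter:
    """In-tree default implementation."""

    def greet(self) -> str:
        return "hello from the built-in Greeter"


try:
    # `fast_impl` is a hypothetical optional package; if it is
    # installed, its class shadows the default at import time.
    from fast_impl import Greeter as FastGreeter

    Greeter = FastGreeter  # type: ignore
    USE_FAST_IMPL = True
except ImportError:
    logger.info("fast_impl not found, using the built-in Greeter")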
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/pallas.py
@@ -35,7 +35,7 @@
 }
 
 try:
-    import tpu_commons  # noqa: F401
+    import tpu_inference  # noqa: F401
 except ImportError:
     # Lazy import torch_xla
     import torch_xla.core.xla_builder as xb
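
pallas.py only needs to know whether the optional package is importable, hence the bare `import tpu_inference  # noqa: F401` probe above. An equivalent probe that avoids executing the package's own import-time side effects would be `importlib.util.find_spec`; a sketch under that assumption, not what the file currently does:

import importlib.util

# Probe for an optional package without importing the module itself;
# `some_pkg` is a placeholder name.
HAVE_SOME_PKG = importlib.util.find_spec("some_pkg") is not None

if not HAVE_SOME_PKG:
    # Fall back to the default backend here.
    pass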
12 changes: 6 additions & 6 deletions vllm/v1/worker/tpu_worker.py
@@ -23,7 +23,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
-from vllm.platforms.tpu import USE_TPU_COMMONS
+from vllm.platforms.tpu import USE_TPU_INFERENCE
 from vllm.tasks import SupportedTask
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -36,8 +36,8 @@
 
 _R = TypeVar("_R")
 
-if not USE_TPU_COMMONS:
-    logger.info("tpu_commons not found, using vLLM's TPUWorker.")
+if not USE_TPU_INFERENCE:
+    logger.info("tpu_inference not found, using vLLM's TPUWorker.")
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.profiler as xp
     import torch_xla.runtime as xr
@@ -346,7 +346,7 @@ def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
         return fn(self.get_model())
 
 
-if USE_TPU_COMMONS:
-    from tpu_commons.worker import TPUWorker as TPUCommonsWorker
+if USE_TPU_INFERENCE:
+    from tpu_inference.worker import TPUWorker as TpuInferenceWorker
 
-    TPUWorker = TPUCommonsWorker  # type: ignore
+    TPUWorker = TpuInferenceWorker  # type: ignore