
Commit 35f0d3b

Utkarsh Sharma (utkarshsharma1) authored and committed
Rename tpu_commons to tpu_inference
Signed-off-by: Utkarsh Sharma <[email protected]>
1 parent e1098ce commit 35f0d3b

File tree

6 files changed (+22, −22 lines)

vllm/distributed/device_communicators/tpu_communicator.py

Lines changed: 7 additions & 7 deletions

@@ -10,7 +10,7 @@
 from vllm.config import get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.platforms.tpu import USE_TPU_COMMONS
+from vllm.platforms.tpu import USE_TPU_INFERENCE
 
 from .base_device_communicator import DeviceCommunicatorBase
 

@@ -20,8 +20,8 @@
 
 logger = init_logger(__name__)
 
-if not USE_TPU_COMMONS:
-    logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
+if not USE_TPU_INFERENCE:
+    logger.info("tpu_inference not found, using vLLM's TpuCommunicator")
     if current_platform.is_tpu():
         import torch_xla
         import torch_xla.core.xla_model as xm

@@ -100,9 +100,9 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         return xm.all_gather(input_, dim=dim)
 
 
-if USE_TPU_COMMONS:
-    from tpu_commons.distributed.device_communicators import (
-        TpuCommunicator as TpuCommonsCommunicator,
+if USE_TPU_INFERENCE:
+    from tpu_inference.distributed.device_communicators import (
+        TpuCommunicator as TpuInferenceCommunicator,
     )
 
-    TpuCommunicator = TpuCommonsCommunicator  # type: ignore
+    TpuCommunicator = TpuInferenceCommunicator  # type: ignore
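The hunks above follow the pattern this whole commit is renaming: vLLM binds its own `TpuCommunicator`, then rebinds the name when the optional plugin package imports cleanly. A minimal self-contained sketch of that alias-swap idiom, with hypothetical package and class names (not vLLM's actual API):

```python
# Sketch of the module-level alias swap used above; "accel_plugin"
# and both class names are hypothetical placeholders.
class DefaultCommunicator:
    """Fallback implementation shipped with the host project."""

    def all_gather(self, items: list) -> list:
        return items  # trivial stand-in for the real collective op


try:
    # If the optional plugin is installed, its class shadows the default.
    from accel_plugin import Communicator as _PluginCommunicator

    Communicator = _PluginCommunicator  # type: ignore
except ImportError:
    Communicator = DefaultCommunicator
```

Callers import `Communicator` without knowing which implementation they got, which is why the rename has to touch every module that re-exports the name.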

vllm/model_executor/model_loader/default_loader.py

Lines changed: 2 additions & 2 deletions

@@ -222,9 +222,9 @@ def _get_weights_iterator(
         )
 
         if current_platform.is_tpu():
-            from vllm.platforms.tpu import USE_TPU_COMMONS
+            from vllm.platforms.tpu import USE_TPU_INFERENCE
 
-            if not USE_TPU_COMMONS:
+            if not USE_TPU_INFERENCE:
                 # In PyTorch XLA, we should call `torch_xla.sync`
                 # frequently so that not too many ops are accumulated
                 # in the XLA program.
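The context lines preserved here explain the gate: under PyTorch XLA, each weight copy is traced lazily, so the loader flushes pending ops at intervals. A rough sketch of that idea, assuming torch_xla is installed; the interval of 32 and the helper name are illustrative, not vLLM's code:

```python
import torch_xla


def _iter_weights_with_sync(weights_iterator, interval: int = 32):
    """Yield weights, flushing the lazily traced XLA program periodically."""
    for i, (name, tensor) in enumerate(weights_iterator):
        yield name, tensor
        if (i + 1) % interval == 0:
            torch_xla.sync()  # keep the accumulated op graph small
```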

vllm/platforms/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -37,7 +37,7 @@ def tpu_platform_plugin() -> Optional[str]:
     # Check for Pathways TPU proxy
     if envs.VLLM_TPU_USING_PATHWAYS:
         logger.debug("Confirmed TPU platform is available via Pathways proxy.")
-        return "tpu_commons.platforms.tpu_jax.TpuPlatform"
+        return "tpu_inference.platforms.tpu_jax.TpuPlatform"
 
     # Check for libtpu installation
     try:
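Note that the plugin hook returns a dotted class path as a string rather than a class, so the heavy import happens only if the platform is actually selected. A generic sketch of resolving such a string (vLLM has its own resolver; this just shows the mechanism):

```python
import importlib


def resolve_qualname(qualname: str):
    """Import "pkg.module.ClassName" lazily and return the attribute."""
    module_name, _, attr = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


# resolve_qualname("tpu_inference.platforms.tpu_jax.TpuPlatform") would
# import tpu_inference.platforms.tpu_jax only at this point.
```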

vllm/platforms/tpu.py

Lines changed: 5 additions & 5 deletions

@@ -26,7 +26,7 @@
 
 logger = init_logger(__name__)
 
-USE_TPU_COMMONS = False
+USE_TPU_INFERENCE = False
 
 
 class TpuPlatform(Platform):

@@ -254,10 +254,10 @@ def use_sync_weight_loader(cls) -> bool:
 
 
 try:
-    from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
+    from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform
 
-    TpuPlatform = TpuCommonsPlatform  # type: ignore
-    USE_TPU_COMMONS = True
+    TpuPlatform = TpuInferencePlatform  # type: ignore
+    USE_TPU_INFERENCE = True
 except ImportError:
-    logger.info("tpu_commons not found, using vLLM's TpuPlatform")
+    logger.info("tpu_inference not found, using vLLM's TpuPlatform")
     pass
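This file is the single source of truth for the flag: it defaults to False and flips to True only if the plugin import succeeds, and every other file in the commit reads it via `from vllm.platforms.tpu import USE_TPU_INFERENCE`. A stripped-down sketch of the idiom with placeholder names:

```python
# Placeholder names; "some_optional_pkg" stands in for tpu_inference.
HAS_PLUGIN = False

try:
    import some_optional_pkg  # noqa: F401

    HAS_PLUGIN = True
except ImportError:
    pass

# Downstream modules do `from this_module import HAS_PLUGIN`; they see
# whatever value the probe settled on when this module first imported.
```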

vllm/v1/attention/backends/pallas.py

Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@
 }
 
 try:
-    import tpu_commons  # noqa: F401
+    import tpu_inference  # noqa: F401
 except ImportError:
     # Lazy import torch_xla
     import torch_xla.core.xla_builder as xb
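The `import tpu_inference  # noqa: F401` probe both detects the package and runs its module-level code. If only detection were needed, `importlib.util.find_spec` would avoid executing the package at all; a sketch of that alternative:

```python
import importlib.util

# True if tpu_inference is on sys.path; nothing is imported or executed.
HAS_TPU_INFERENCE = importlib.util.find_spec("tpu_inference") is not None
```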

vllm/v1/worker/tpu_worker.py

Lines changed: 6 additions & 6 deletions

@@ -23,7 +23,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
-from vllm.platforms.tpu import USE_TPU_COMMONS
+from vllm.platforms.tpu import USE_TPU_INFERENCE
 from vllm.tasks import SupportedTask
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
 from vllm.v1.core.sched.output import SchedulerOutput

@@ -36,8 +36,8 @@
 
 _R = TypeVar("_R")
 
-if not USE_TPU_COMMONS:
-    logger.info("tpu_commons not found, using vLLM's TPUWorker.")
+if not USE_TPU_INFERENCE:
+    logger.info("tpu_inference not found, using vLLM's TPUWorker.")
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.profiler as xp
     import torch_xla.runtime as xr

@@ -346,7 +346,7 @@ def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
         return fn(self.get_model())
 
 
-if USE_TPU_COMMONS:
-    from tpu_commons.worker import TPUWorker as TPUCommonsWorker
+if USE_TPU_INFERENCE:
+    from tpu_inference.worker import TPUWorker as TpuInferenceWorker
 
-    TPUWorker = TPUCommonsWorker  # type: ignore
+    TPUWorker = TpuInferenceWorker  # type: ignore
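After an alias swap like this, a quick way to confirm which implementation won is to inspect the rebound class's `__module__`. A small check, assuming vLLM is installed (the plugin module path in the comment is a guess):

```python
from vllm.v1.worker.tpu_worker import TPUWorker

# Reports where the class was defined: a tpu_inference module path if
# the plugin was picked up, otherwise vLLM's own tpu_worker module.
print(TPUWorker.__module__)
```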
