
Commit 4cf71cc

[TPU] Deprecate `xm.mark_step` in favor of `torch_xla.sync` (#25254)
Signed-off-by: NickLucche <[email protected]>
Co-authored-by: Ye (Charlotte) Qi <[email protected]>
1 parent a66d131 commit 4cf71cc
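
The migration itself is mechanical: each `xm.mark_step()` call site becomes `torch_xla.sync(...)`, and the `torch_xla.core.xla_model` import is kept only where other helpers such as `xm.wait_device_ops()` are still needed. Below is a minimal sketch of the before/after pattern, using placeholder tensors that are not taken from the vLLM code:

import torch
import torch_xla
import torch_xla.core.xla_model as xm  # still used for helpers like wait_device_ops()

device = xm.xla_device()
x = torch.randn(4, 4, device=device)
y = x @ x  # traced lazily; nothing has executed yet

# Old, deprecated spelling:
#   xm.mark_step()
# New spelling; wait=False keeps the flush non-blocking, as on this commit's hot paths:
torch_xla.sync(wait=False)

xm.wait_device_ops()  # block until the queued device work has finished
print(y.cpu())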

File tree

5 files changed, +31 -29 lines changed

tests/tpu/test_moe_pallas.py
tests/v1/tpu/test_topk_topp_sampler.py
vllm/lora/punica_wrapper/punica_tpu.py
vllm/model_executor/model_loader/default_loader.py
vllm/v1/worker/tpu_model_runner.py

tests/tpu/test_moe_pallas.py

Lines changed: 2 additions & 1 deletion

@@ -6,6 +6,7 @@
 """
 import pytest
 import torch
+import torch_xla

 # yapf conflicts with isort for this block
 # yapf: disable
@@ -77,7 +78,7 @@ def test_pallas_moe(
         expert_map=e_map,
         renormalize=False,
     )
-    xm.mark_step()
+    torch_xla.sync(wait=False)

     # Compare outputs
     torch.testing.assert_close(

tests/v1/tpu/test_topk_topp_sampler.py

Lines changed: 6 additions & 5 deletions

@@ -4,6 +4,7 @@

 import pytest
 import torch
+import torch_xla

 from vllm.platforms import current_platform
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
@@ -63,7 +64,7 @@ def test_topp_result_sums_past_p():
     probs.masked_fill_(logits_masked.isinf(), 0)
     masked_prob_sum = probs.sum(dim=-1)

-    xm.mark_step()
+    torch_xla.sync()

     # Perform assertion on CPU.
     assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu()))
@@ -82,7 +83,7 @@ def test_topp_basic():
                               k=torch.tensor([3, 3]),
                               p=torch.tensor([0.79, 0.79]))

-    xm.mark_step()
+    torch_xla.sync()

     # Expect the smallest elements to be dropped.
     expected_result = logits.clone().cpu()
@@ -104,7 +105,7 @@ def test_topp_select_all():
                               k=torch.tensor([3, 3]),
                               p=torch.tensor([1.0, 1.0]))

-    xm.mark_step()
+    torch_xla.sync()

     assert torch.allclose(logits.cpu(), result.cpu())

@@ -122,7 +123,7 @@ def test_topp_with_ties():
                               k=torch.tensor([4]),
                               p=torch.tensor([0.2]))

-    xm.mark_step()
+    torch_xla.sync()

     # All tie values are included in the top-p set. Tie breaking is left
     # to be done during final sampling (all tie tokens have equal
@@ -146,7 +147,7 @@ def test_both_topk_topp():
                               k=torch.tensor([1, 3]),
                               p=torch.tensor([0.79, 0.79]))

-    xm.mark_step()
+    torch_xla.sync()

     # Since for the first batch k=1, expect only the largest element gets
     # selected.
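
The sampler tests all follow the same shape: build the lazy computation on the XLA device, flush it with `torch_xla.sync()`, then move the result to the CPU before asserting. A condensed sketch of that flow, using made-up tensors rather than the actual test fixtures:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

def check_prob_sums() -> None:
    device = xm.xla_device()
    logits = torch.randn(2, 8, device=device)
    probs = torch.softmax(logits, dim=-1)
    masked_prob_sum = probs.sum(dim=-1)  # still a lazy XLA tensor

    # Materialize the pending graph (previously xm.mark_step()).
    torch_xla.sync()

    # Perform the assertion on CPU, mirroring the tests above.
    assert torch.all(masked_prob_sum.cpu() <= 1.0 + 1e-6)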

vllm/lora/punica_wrapper/punica_tpu.py

Lines changed: 2 additions & 2 deletions

@@ -6,7 +6,7 @@

 import torch
 import torch.nn.functional as F
-import torch_xla.core.xla_model as xm
+import torch_xla

 from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
 from vllm.lora.punica_wrapper.utils import convert_mapping
@@ -323,7 +323,7 @@ def _update_base_metadata(
         extra_vocab_size: int,
     ):
         # Make sure we don't accidentally collect outside operations
-        xm.mark_step()
+        torch_xla.sync()

         # Pad the prompt mapping to avoid running into recompiles on the TPU
         # TODO: Should this happen inside mapping internally? If so how can we
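
The comment in `_update_base_metadata` explains the intent: flushing first ensures the wrapper's metadata writes are not fused into whatever graph the caller happened to be tracing. A rough sketch of that barrier idea; `update_metadata` and its arguments are hypothetical stand-ins, not the vLLM API:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

def update_metadata(wrapper_state: dict, token_lora_indices: list[int]) -> None:
    """Hypothetical stand-in for the LoRA metadata bookkeeping."""
    # Flush whatever the caller has traced so far, so the tensor writes
    # below start from a clean graph (previously xm.mark_step()).
    torch_xla.sync()

    indices = torch.tensor(token_lora_indices, dtype=torch.long,
                           device=xm.xla_device())
    wrapper_state["token_lora_indices"] = indices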

vllm/model_executor/model_loader/default_loader.py

Lines changed: 4 additions & 5 deletions

@@ -211,16 +211,15 @@ def _get_weights_iterator(
         from vllm.platforms.tpu import USE_TPU_COMMONS

         if not USE_TPU_COMMONS:
-            # In PyTorch XLA, we should call `xm.mark_step`
+            # In PyTorch XLA, we should call `torch_xla.sync`
             # frequently so that not too many ops are accumulated
-            # in the XLA program. import torch_xla.core.xla_model
-            # as xm
-            import torch_xla.core.xla_model as xm
+            # in the XLA program.
+            import torch_xla

             def _xla_weights_iterator(iterator: Generator):
                 for weights in iterator:
                     yield weights
-                    xm.mark_step()
+                    torch_xla.sync(wait=False)

             weights_iterator = _xla_weights_iterator(weights_iterator)
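
The loader wraps the checkpoint iterator so the accumulated XLA program is flushed after every yielded weight instead of growing across the whole checkpoint. A standalone sketch of the same wrapper; the iterator element type is an assumption for illustration:

from typing import Iterator

import torch
import torch_xla

def xla_weights_iterator(iterator: Iterator[tuple[str, torch.Tensor]]):
    """Yield weights while flushing pending XLA ops after each item."""
    for name_and_weight in iterator:
        yield name_and_weight
        # Keep the lazily traced program small; wait=False avoids blocking
        # the loading loop (previously xm.mark_step()).
        torch_xla.sync(wait=False)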

vllm/v1/worker/tpu_model_runner.py

Lines changed: 17 additions & 16 deletions

@@ -10,6 +10,7 @@
 import torch
 import torch.nn as nn
 # TPU XLA related
+import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.spmd as xs
 import torch_xla.runtime as xr
@@ -846,10 +847,10 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
             # 2. A list or tuple (length: num_items) of tensors, each of shape
             # (feature_size, hidden_size) in case the feature size is dynamic
             # depending on the input multimodal items.
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             curr_group_outputs = self.model.get_multimodal_embeddings(
                 **mm_kwargs_group)
-            xm.mark_step()
+            torch_xla.sync(wait=False)

             sanity_check_mm_encoder_outputs(
                 curr_group_outputs,
@@ -952,7 +953,7 @@ def execute_model(
             mm_embeds = self._gather_mm_embeddings(scheduler_output)
         else:
             mm_embeds = []
-        xm.mark_step()
+        torch_xla.sync(wait=False)
         # Prepare inputs, the requests might be split into multiple
         # executions, combine the result of each execution.
         start_index = 0
@@ -969,7 +970,7 @@ def execute_model(
             end_index = self._prepare_inputs(scheduler_output, start_index)
             input_ids, inputs_embeds = self._get_model_inputs(
                 self.input_ids, mm_embeds)
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             # Run the decoder
             with set_forward_context(
                     attn_metadata,
@@ -1183,7 +1184,7 @@ def load_model(self) -> None:

         # Sync all pending XLA execution during model initialization and weight
         # loading.
-        xm.mark_step()
+        torch_xla.sync(wait=False)
         xm.wait_device_ops()
         if not hasattr(self, "model"):
             self.model = model
@@ -1267,10 +1268,10 @@ def _dummy_run(self, num_tokens: int, num_reqs: int,

     def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping,
                           lora_requests) -> None:
-        xm.mark_step()  # Captures input updates
+        torch_xla.sync(wait=False)  # Captures input updates
         super()._set_active_loras(prompt_lora_mapping, token_lora_mapping,
                                   lora_requests)
-        xm.mark_step()  # Captures metadata updates
+        torch_xla.sync(wait=False)  # Captures metadata updates

     def _precompile_mm_encoder(self) -> None:
         if not self.supports_mm_inputs:
@@ -1297,10 +1298,10 @@ def _precompile_mm_encoder(self) -> None:
                     num_items,
                 )
                 # Run multimodal encoder.
-                xm.mark_step()
+                torch_xla.sync(wait=False)
                 mm_embeds = self.model.get_multimodal_embeddings(
                     **batched_dummy_mm_inputs)
-                xm.mark_step()
+                torch_xla.sync(wait=False)
                 num_patches = mm_embeds[0].shape[0]
                 items_size = num_patches * num_items

@@ -1325,7 +1326,7 @@ def _precompile_mm_encoder(self) -> None:
                 a, b = self._get_model_inputs(placeholders_ids,
                                               [mm_embeds])
                 assert a is None
-                xm.mark_step()
+                torch_xla.sync(wait=False)

             # Pre-compile `get_input_embeddings` when mm_embeddings are not
             # present. Chunk is only made of text, no mm_placeholders.
@@ -1336,7 +1337,7 @@ def _precompile_mm_encoder(self) -> None:
             placeholders_ids = placeholders_ids.to(self.device)
             a, b = self._get_model_inputs(placeholders_ids, [])
             assert a is None
-            xm.mark_step()
+            torch_xla.sync(wait=False)

         xm.wait_device_ops()
         end = time.perf_counter()
@@ -1532,11 +1533,11 @@ def profile_run(
             # Isolate encoder graph from post-processing to minimize
             # impact of recompilation until it's fixed.
             start = time.perf_counter()
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             dummy_encoder_outputs = \
                 self.model.get_multimodal_embeddings(
                     **batched_dummy_mm_inputs)
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             xm.wait_device_ops()
             end = time.perf_counter()
             logger.info(
@@ -1559,7 +1560,7 @@ def profile_run(
         self._dummy_run(num_tokens, self.num_reqs_most_model_len,
                         self.num_blocks_per_most_len_req)

-        xm.mark_step()
+        torch_xla.sync(wait=False)
         xm.wait_device_ops()
         self.encoder_cache.clear()
         gc.collect()
@@ -1927,11 +1928,11 @@ def _tpu_set_lora(
             # to a tensor doesn't seem to work anymore. This might be fixed with a
             # later release of torch_xla.
             self._original_set_lora(index, lora_a, lora_b, embeddings_tensor, bias)
-            xm.mark_step()
+            torch_xla.sync(wait=False)

         def _tpu_reset_lora(self, index: int):
             self._original_reset_lora(index)
-            xm.mark_step()
+            torch_xla.sync(wait=False)

         for _, module in model.named_modules():
             if isinstance(module, BaseLayerWithLoRA):
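
Throughout the model runner the non-blocking flush is paired with `xm.wait_device_ops()` whenever the code needs everything to have actually finished, for example after weight loading or around the precompile and profile timings. A hedged sketch of that pairing; `run_and_time` and `step_fn` are illustrative names, not vLLM functions:

import time

import torch_xla
import torch_xla.core.xla_model as xm

def run_and_time(step_fn) -> float:
    """Run a traced step and report how long the device work took."""
    start = time.perf_counter()
    step_fn()
    # Launch the accumulated graph without blocking the host thread...
    torch_xla.sync(wait=False)
    # ...then wait for all outstanding device ops before reading the clock.
    xm.wait_device_ops()
    return time.perf_counter() - start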
