 import torch
 import torch.nn as nn
 # TPU XLA related
+import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.spmd as xs
 import torch_xla.runtime as xr
@@ -846,10 +847,10 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
             # 2. A list or tuple (length: num_items) of tensors, each of shape
             # (feature_size, hidden_size) in case the feature size is dynamic
             # depending on the input multimodal items.
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             curr_group_outputs = self.model.get_multimodal_embeddings(
                 **mm_kwargs_group)
-            xm.mark_step()
+            torch_xla.sync(wait=False)

             sanity_check_mm_encoder_outputs(
                 curr_group_outputs,
@@ -952,7 +953,7 @@ def execute_model(
             mm_embeds = self._gather_mm_embeddings(scheduler_output)
         else:
             mm_embeds = []
-        xm.mark_step()
+        torch_xla.sync(wait=False)
         # Prepare inputs, the requests might be split into multiple
         # executions, combine the result of each execution.
         start_index = 0
@@ -969,7 +970,7 @@ def execute_model(
             end_index = self._prepare_inputs(scheduler_output, start_index)
             input_ids, inputs_embeds = self._get_model_inputs(
                 self.input_ids, mm_embeds)
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             # Run the decoder
             with set_forward_context(
                     attn_metadata,
@@ -1183,7 +1184,7 @@ def load_model(self) -> None:

         # Sync all pending XLA execution during model initialization and weight
         # loading.
-        xm.mark_step()
+        torch_xla.sync(wait=False)
         xm.wait_device_ops()
         if not hasattr(self, "model"):
             self.model = model
@@ -1267,10 +1268,10 @@ def _dummy_run(self, num_tokens: int, num_reqs: int,

     def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping,
                           lora_requests) -> None:
-        xm.mark_step()  # Captures input updates
+        torch_xla.sync(wait=False)  # Captures input updates
         super()._set_active_loras(prompt_lora_mapping, token_lora_mapping,
                                   lora_requests)
-        xm.mark_step()  # Captures metadata updates
+        torch_xla.sync(wait=False)  # Captures metadata updates

     def _precompile_mm_encoder(self) -> None:
         if not self.supports_mm_inputs:
@@ -1297,10 +1298,10 @@ def _precompile_mm_encoder(self) -> None:
                 num_items,
             )
             # Run multimodal encoder.
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             mm_embeds = self.model.get_multimodal_embeddings(
                 **batched_dummy_mm_inputs)
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             num_patches = mm_embeds[0].shape[0]
             items_size = num_patches * num_items

@@ -1325,7 +1326,7 @@ def _precompile_mm_encoder(self) -> None:
                 a, b = self._get_model_inputs(placeholders_ids,
                                               [mm_embeds])
                 assert a is None
-                xm.mark_step()
+                torch_xla.sync(wait=False)

             # Pre-compile `get_input_embeddings` when mm_embeddings are not
             # present. Chunk is only made of text, no mm_placeholders.
@@ -1336,7 +1337,7 @@ def _precompile_mm_encoder(self) -> None:
             placeholders_ids = placeholders_ids.to(self.device)
             a, b = self._get_model_inputs(placeholders_ids, [])
             assert a is None
-            xm.mark_step()
+            torch_xla.sync(wait=False)

         xm.wait_device_ops()
         end = time.perf_counter()
@@ -1532,11 +1533,11 @@ def profile_run(
             # Isolate encoder graph from post-processing to minimize
             # impact of recompilation until it's fixed.
             start = time.perf_counter()
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             dummy_encoder_outputs = \
                 self.model.get_multimodal_embeddings(
                     **batched_dummy_mm_inputs)
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             xm.wait_device_ops()
             end = time.perf_counter()
             logger.info(
@@ -1559,7 +1560,7 @@ def profile_run(
         self._dummy_run(num_tokens, self.num_reqs_most_model_len,
                         self.num_blocks_per_most_len_req)

-        xm.mark_step()
+        torch_xla.sync(wait=False)
         xm.wait_device_ops()
         self.encoder_cache.clear()
         gc.collect()
@@ -1927,11 +1928,11 @@ def _tpu_set_lora(
         # to a tensor doesn't seem to work anymore. This might be fixed with a
         # later release of torch_xla.
         self._original_set_lora(index, lora_a, lora_b, embeddings_tensor, bias)
-        xm.mark_step()
+        torch_xla.sync(wait=False)

     def _tpu_reset_lora(self, index: int):
         self._original_reset_lora(index)
-        xm.mark_step()
+        torch_xla.sync(wait=False)

     for _, module in model.named_modules():
         if isinstance(module, BaseLayerWithLoRA):
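Note: every hunk in this diff applies the same substitution, replacing the legacy xm.mark_step() with torch_xla.sync(wait=False), which cuts the lazily traced graph and dispatches it to the device without blocking the host. The snippet below is a minimal sketch of that pattern for context, not code from this PR; it assumes a torch_xla release that ships torch_xla.sync() and an attached XLA (TPU) device.

# Minimal sketch (illustration only, not part of this diff): old vs. new way
# to flush pending lazy-tensor work in torch_xla. Assumes torch_xla is
# installed and an XLA device such as a TPU is available.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.ones((8, 8), device=device)
y = x @ x  # recorded into the lazy trace; nothing has executed yet

# Old API: xm.mark_step() cut the trace here and sent it to the device.
# New API: torch_xla.sync(wait=False) does the same without waiting for the
# device to finish the dispatched work.
torch_xla.sync(wait=False)

# To block until all dispatched work completes, keep using
# xm.wait_device_ops(), as the surrounding code in this file still does.
xm.wait_device_ops()
print(y.cpu())  # materialize the result on the host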