@@ -6,6 +6,7 @@
 
 
 import re
+from threading import local
 
 from typing import Any, Dict
 
 import torch
@@ -276,54 +277,15 @@ def _get_local_experts_weights(
             local_expert_tensors[expert_key] = expert_dtensor
 
         return local_expert_tensors
-
-    def _chunk_local_expert_weights(
-        self,
-        local_tensor: torch.Tensor,
-        dtensor_placements: tuple,
-        dtensor_shape: tuple,
-        device_mesh: DeviceMesh,
-    ):
-        """
-        Chunk the local individual experts weight, assemble back to GroupedExperts weights DTensor.
-
-        This method is a placeholder for future implementation of expert weight concatenation.
-
-        Args:
-            local_tensor: Concatenated local individual expert weights
-        """
-
-        # Calculate the index range on dim-i to chunk
-        for i in range(1, len(dtensor_placements)):
-            dim_size = dtensor_shape[i]
-            start_index, end_index = self._caculate_indices_from_placements(
-                dim=i,
-                dim_size=dim_size,
-                dtensor_placements=dtensor_placements,
-                device_mesh=device_mesh,
-            )
-            # No need to chunk on current dimension
-            if start_index is None or end_index is None:
-                continue
-
-            # Chunk local_tensor on dim-i
-            local_tensor = local_tensor.narrow(i, start_index, end_index - start_index)
-
-        # Assemble DTensor
-        grouped_expert_weights = DTensor.from_local(
-            local_tensor, device_mesh, dtensor_placements, run_check=False
-        )
-
-        return grouped_expert_weights
-
+
     def _concatenate_local_expert_weights(
         self,
         expert_weights_by_layer: dict[str, Any],
         abstract_key: str,
         device_mesh: DeviceMesh,
     ) -> torch.Tensor:
         """
-        Concatenate the weights of separate experts into GroupedExperts weights.
+        Try to concatenate the weights of separate experts into GroupedExperts weights.
         """
         for layer in expert_weights_by_layer.keys():
             # If we have all the experts for this abstract_key, concatenate them
@@ -335,20 +297,15 @@ def _concatenate_local_expert_weights(
             if len(experts) == expected_n_experts:
                 sorted_expert_ids = sorted(experts.keys())
                 sorted_experts = [experts[i] for i in sorted_expert_ids]
-                local_tensor = torch.stack(sorted_experts, dim=0)
-
+                local_tensor = torch.stack(sorted_experts, dim=0)._local_tensor
+
                 assert (
                     abstract_key in self.grouped_expert_weight_placements
                     and abstract_key in self.grouped_expert_weight_shape
                 ), f"GroupedExperts weight metadata {self.grouped_expert_weight_placements} {self.grouped_expert_weight_shape} can not be None!"
 
-                stacked_dtensor = self._chunk_local_expert_weights(
-                    local_tensor,
-                    dtensor_placements=self.grouped_expert_weight_placements[
-                        abstract_key
-                    ],
-                    dtensor_shape=self.grouped_expert_weight_shape[abstract_key],
-                    device_mesh=device_mesh,
+                stacked_dtensor = DTensor.from_local(
+                    local_tensor, device_mesh, self.grouped_expert_weight_placements[abstract_key], run_check=False
                 )
 
                 # Remove these experts from the tracking dict to free memory
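The simplification works because `torch.stack` over a list of per-expert DTensors dispatches through DTensor and returns a stacked DTensor, so its `_local_tensor` is already this rank's shard of the grouped weight; re-wrapping that shard with the recorded placements via `DTensor.from_local(..., run_check=False)` reproduces what the removed narrow-based chunking helper computed by hand. Below is a minimal sketch of that round trip, assuming recent PyTorch, a single-process gloo mesh, and made-up expert shapes; nothing here comes from this PR except the stack / `_local_tensor` / `from_local` pattern:

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, distribute_tensor

# Single-process process group so the sketch runs on one CPU host.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

# Two hypothetical per-expert weight DTensors (shapes are illustrative only).
experts = [
    distribute_tensor(torch.randn(4, 8), mesh, [Replicate()]) for _ in range(2)
]

# torch.stack dispatches through DTensor: the result's local shard is the
# stacked local shards, shaped (n_experts, *expert_shape) on this rank.
local_tensor = torch.stack(experts, dim=0)._local_tensor

# Re-wrap the local shard under the desired GroupedExperts placements,
# skipping the collective metadata check, as the new code path does.
grouped = DTensor.from_local(local_tensor, mesh, [Replicate()], run_check=False)
assert grouped.shape == torch.Size([2, 4, 8])

dist.destroy_process_group()
```

With genuinely sharded placements the local shard is only this rank's slice of the stacked weight, which is why the helper's explicit `narrow` bookkeeping is no longer needed.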