From b2549c3fbcca4cc68d249239eb5072505fc2ee71 Mon Sep 17 00:00:00 2001 From: wang55 Date: Fri, 22 Aug 2025 01:53:40 +0200 Subject: [PATCH] allow expert_parallel wrapper to handle kwargs --- torchtitan/distributed/expert_parallel.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py index 384d9e33f..25f9db066 100644 --- a/torchtitan/distributed/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -322,6 +322,7 @@ def wrapper( w3: torch.Tensor, x: torch.Tensor, num_tokens_per_expert: torch.Tensor, + **kwargs, ) -> torch.Tensor: global TOKEN_GROUP_ALIGN_SIZE_M if isinstance(w1, DTensor): @@ -351,7 +352,7 @@ def wrapper( input_shape = x.shape x = x[permuted_indices, :] - out = func(w1, w2, w3, x, num_tokens_per_expert) + out = func(w1, w2, w3, x, num_tokens_per_expert, **kwargs) out_unpermuted = out.new_empty(input_shape) out_unpermuted[permuted_indices, :] = out