Commit 7440cb4

Author: Varun Sundar Rabindranath
Commit message: fix get_expert_weights
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
1 parent 26b9aaa commit 7440cb4

File tree

1 file changed: +21 −15 lines changed
  • vllm/model_executor/layers/fused_moe


vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 21 additions & 15 deletions

@@ -1225,29 +1225,35 @@ def load_weights(
             yield param_name

     def get_expert_weights(self) -> Iterable[torch.Tensor]:
-        def maybe_make_contiguous(name: str, p: torch.nn.Parameter) -> torch.nn.Parameter:
+        def _maybe_make_contiguous(
+            name: str, p: torch.nn.Parameter
+        ) -> torch.nn.Parameter:
             """
-            Expert weight-scales are transposed and are represented
-            in column-major. This function transposes the tensor back
-            so the tensor is contiguous().
+            In some cases, the last 2 dimensions (the non-expert dimensions)
+            of the weight-scale tensor are transposed. This function
+            transposes the tensor back so the tensor is contiguous().
+            Example: a scale tensor `x` of shape (E, 32, 16) and stride
+            (512, 1, 32) is transposed to `xt` of shape (E, 16, 32) and
+            stride (512, 32, 1).
+            Note that we specifically use torch.transpose(), so `xt` refers
+            to the same underlying memory. Because `x` and `xt` share the
+            same memory, this transformation is safe in the context of
+            EPLB: it is the same memory, only the view differs.
+            Note: this function handles the "weight_scale" tensors
+            specifically, but could be generalized to similar tensors.
             """
-            if p.is_contiguous():
-                return p
-            if "weight_scale" not in name:
+            # Check if the last 2 dimensions are transposed.
+            is_transposed = p.stride(1) == 1 and p.stride(2) != 1
+            if p.is_contiguous() or not is_transposed or "weight_scale" not in name:
                 # do nothing.
                 return p
             assert p.ndim == 3
-            # Check if the tensor is transposed.
-            is_colmajor = p.size(1) == 1 and p.size(2) != 1
-            p = torch.transpose(p, 1, 2)
-            assert p.is_contiguous()
+            p.data = torch.transpose(p.data, 1, 2)
             return p

         weights = list(self.named_parameters())
-        weights = [ (name, maybe_make_contiguous(name, p)) for name, p in weights]
-
-        # for name, weight in weights:
-        #     print(f"{name} is_contiguous() ? {weight.is_contiguous()}")
+        weights = [(name, _maybe_make_contiguous(name, p)) for name, p in weights]

         assert all(
             weight.is_contiguous()
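The crux of the fix is twofold: the old code tested `p.size(1) == 1` (shape, which says nothing about memory layout) where the new code tests `p.stride(1) == 1` (layout), and the new code reassigns `p.data` to a `torch.transpose()` view, which shares storage with the original tensor. A minimal standalone sketch of that behavior in plain PyTorch (not vLLM code; the helper body mirrors the commit, but the `w13_weight_scale` parameter name is illustrative):

```python
import torch


def maybe_make_contiguous(name: str, p: torch.nn.Parameter) -> torch.nn.Parameter:
    # The last two dims are transposed when dim 1 is the fastest-moving axis.
    # Note: this checks stride (memory layout), not size (shape).
    is_transposed = p.stride(1) == 1 and p.stride(2) != 1
    if p.is_contiguous() or not is_transposed or "weight_scale" not in name:
        return p
    assert p.ndim == 3
    # transpose() returns a view: same storage, different strides.
    p.data = torch.transpose(p.data, 1, 2)
    return p


E = 4
# Build a scale tensor whose last two dims are transposed:
# shape (E, 32, 16), stride (512, 1, 32).
x = torch.empty(E, 16, 32).transpose(1, 2)
p = torch.nn.Parameter(x, requires_grad=False)
assert not p.is_contiguous()

storage_before = p.data.data_ptr()
p = maybe_make_contiguous("w13_weight_scale", p)
assert p.is_contiguous()                    # layout is now contiguous
assert tuple(p.shape) == (E, 16, 32)        # dims swapped back
assert p.data.data_ptr() == storage_before  # same underlying memory
```

The last assertion is what makes the transformation safe for EPLB-style in-place weight rearrangement: no copy is made, so any other reference to the same storage still observes the data.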
