Skip to content

Commit 79a268c

Browse files
[BUG] fixed fp8 conflict with aqlm (#4307)
Fixes the fp8 interface which broke in the AQLM merge.
1 parent eace8bf commit 79a268c

File tree

3 files changed

+18
-4
lines changed

3 files changed

+18
-4
lines changed

.buildkite/test-pipeline.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ steps:
9696
- label: Metrics Test
9797
command: pytest -v -s metrics
9898

99+
- label: Quantization Test
100+
command: pytest -v -s quantization
101+
99102
- label: Benchmarks
100103
working_dir: "/vllm-workspace/.buildkite"
101104
commands:

vllm/model_executor/layers/linear.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,19 @@ def create_weights(self, layer: torch.nn.Module,
3434
output_partition_sizes: List[int], input_size: int,
3535
output_size: int, params_dtype: torch.dtype,
3636
**extra_weight_attrs):
37-
"""Create weights for a linear layer.
38-
39-
The weights will be set as attributes of the layer."""
37+
"""Create weights for a linear layer.
38+
The weights will be set as attributes of the layer.
39+
40+
Args:
41+
layer: The layer that is using the LinearMethodBase factory.
42+
input_size_per_partition: Size of the weight input dim on rank X.
43+
output_partition_sizes: Sizes of the output dim of each logical
44+
weight on rank X. E.g., output_partition_sizes for QKVLinear
45+
is a list containing the widths of Wq, Wk, Wv on rank X.
46+
input_size: Size of the input dim of the weight across all ranks.
47+
output_size: Size of the output dim of the weight across all ranks.
48+
params_dtype: Datatype of the parameters.
49+
"""
4050
raise NotImplementedError
4151

4252
@abstractmethod

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,13 @@ def create_weights(
6464
self,
6565
layer: torch.nn.Module,
6666
input_size_per_partition: int,
67-
output_size_per_partition: int,
67+
output_partition_sizes: List[int],
6868
input_size: int,
6969
output_size: int,
7070
params_dtype: torch.dtype,
7171
**extra_weight_attrs,
7272
):
73+
output_size_per_partition = sum(output_partition_sizes)
7374
weight = Parameter(torch.empty(output_size_per_partition,
7475
input_size_per_partition,
7576
dtype=params_dtype),

0 commit comments

Comments
 (0)