Skip to content

Commit 8cef6e0

Browse files
authored
[Misc] add w8a8 asym models (#11075)
1 parent b866cdb commit 8cef6e0

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

tests/quantization/test_compressed_tensors.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,19 +79,23 @@ def zp_valid(zp: Optional[torch.Tensor]):
7979
assert output
8080

8181

82-
@pytest.mark.parametrize(
83-
"model_path",
84-
[
85-
"neuralmagic/Llama-3.2-1B-quantized.w8a8"
86-
# TODO static & asymmetric
87-
])
82+
@pytest.mark.parametrize("model_path", [
83+
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
84+
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym",
85+
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym",
86+
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
87+
])
8888
@pytest.mark.parametrize("max_tokens", [32])
8989
@pytest.mark.parametrize("num_logprobs", [10])
9090
def test_compressed_tensors_w8a8_logprobs(hf_runner, vllm_runner,
9191
example_prompts, model_path,
9292
max_tokens, num_logprobs):
9393
dtype = "bfloat16"
9494

95+
# skip language translation prompt for the static per tensor asym model
96+
if model_path == "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym": # noqa: E501
97+
example_prompts = example_prompts[0:-1]
98+
9599
with hf_runner(model_path, dtype=dtype) as hf_model:
96100
hf_outputs = hf_model.generate_greedy_logprobs_limit(
97101
example_prompts, max_tokens, num_logprobs)

0 commit comments

Comments (0)