
Commit a34a161

Authored by kylesayrs, dsikka, and HDCharles
[Examples] Correct out-of-date warning for kv cache examples (vllm-project#2209)
## Purpose ##

* As of the attention refactor, CT inference with kv cache quantization is supported. Fix the incorrect information.

## Changes ##

* Remove note about CT inference not being supported
* Standardize sample generation code
* Remove note about gemma in transformers==4.49.0 (out of supported versions)

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
Co-authored-by: HDCharles <39544797+HDCharles@users.noreply.github.com>
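The "standardize sample generation" change replaces the old `input_ids`-only call with a pattern that moves every tensor the tokenizer returns (both `input_ids` and `attention_mask`) to the model's device before unpacking the dict into `generate()`. A minimal stdlib-only sketch of that device-move pattern, using stub classes in place of torch/transformers objects so it runs without dependencies:

```python
# Sketch of the standardized sample-generation pattern from the updated
# examples. FakeTensor and FakeModel are hypothetical stand-ins for
# torch.Tensor and a transformers model; only the dict-move logic is real.

class FakeTensor:
    def __init__(self, data, device="cpu"):
        self.data = data
        self.device = device

    def to(self, device):
        # Mimics torch.Tensor.to(): returns a copy on the target device.
        return FakeTensor(self.data, device)


class FakeModel:
    device = "cuda:0"

    def generate(self, input_ids=None, attention_mask=None, max_new_tokens=100):
        # A real model would decode autoregressively; here we just check
        # that the inputs arrived on the model's device and echo them.
        assert input_ids.device == self.device, "inputs must be on the model device"
        return [input_ids]


def fake_tokenize(text):
    # Stands in for tokenizer(text, return_tensors="pt"), which returns a
    # dict-like BatchEncoding of CPU tensors.
    return {
        "input_ids": FakeTensor([101, 102]),
        "attention_mask": FakeTensor([1, 1]),
    }


model = FakeModel()

# The pattern used in all three updated examples: tokenize, move every
# tensor to the model's device, then unpack the dict into generate().
sample = fake_tokenize("Hello my name is")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(output[0].device)  # prints "cuda:0"
```

The dict comprehension is the key difference from the old code: passing only `input_ids` silently drops the attention mask, whereas moving and unpacking the whole encoding forwards it too.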
1 parent f3f14af commit a34a161

File tree

3 files changed: +10 additions, −32 deletions


examples/quantization_kv_cache/gemma2_fp8_kv_example.py

Lines changed: 4 additions & 14 deletions
@@ -78,23 +78,13 @@ def process_and_tokenize(example):
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
 )
 
-print(
-    "Note: Inference with the quantized kv_cache is not supported. ",
-    "Please use vLLM for inference with the quantized kv_cache.",
-)
 # Confirm generations of the quantized model look sane.
-
-# NOTE: transformers 4.49.0 results in a generation error with gemma2.
-# Consider either downgrading your transformers version to a previous version
-# or use vLLM for sample generation.
-# Note: compile is disabled: https://github.com/huggingface/transformers/issues/38333
 print("\n\n")
-dispatch_for_generation(model)
 print("========== SAMPLE GENERATION ==============")
-input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
-    model.device
-)
-output = model.generate(input_ids, max_new_tokens=100, disable_compile=True)
+dispatch_for_generation(model)
+sample = tokenizer("Hello my name is", return_tensors="pt")
+sample = {key: value.to(model.device) for key, value in sample.items()}
+output = model.generate(**sample, max_new_tokens=100)
 print(tokenizer.decode(output[0]))
 print("==========================================\n\n")

examples/quantization_kv_cache/llama3_fp8_kv_example.py

Lines changed: 3 additions & 10 deletions
@@ -1,5 +1,4 @@
 from datasets import load_dataset
-from loguru import logger
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from llmcompressor import oneshot
@@ -79,19 +78,13 @@ def process_and_tokenize(example):
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
 )
 
-logger.info(
-    "Running sample generation. ",
-    "Note: Inference with the quantized kv_cache is not supported. ",
-    "Please use vLLM for inference with the quantized kv_cache.",
-)
 # Confirm generations of the quantized model look sane.
 print("\n\n")
 print("========== SAMPLE GENERATION ==============")
 dispatch_for_generation(model)
-input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
-    model.device
-)
-output = model.generate(input_ids, max_new_tokens=100)
+sample = tokenizer("Hello my name is", return_tensors="pt")
+sample = {key: value.to(model.device) for key, value in sample.items()}
+output = model.generate(**sample, max_new_tokens=100)
 print(tokenizer.decode(output[0]))
 print("==========================================\n\n")

examples/quantization_kv_cache/phi3.5_fp8_kv_example.py

Lines changed: 3 additions & 8 deletions
@@ -80,18 +80,13 @@ def process_and_tokenize(example):
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
 )
 
-print(
-    "Note: Inference with the quantized kv_cache is not supported. ",
-    "Please use vLLM for inference with the quantized kv_cache.",
-)
 # Confirm generations of the quantized model look sane.
 print("\n\n")
 print("========== SAMPLE GENERATION ==============")
 dispatch_for_generation(model)
-input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
-    model.device
-)
-output = model.generate(input_ids, max_new_tokens=100)
+sample = tokenizer("Hello my name is", return_tensors="pt")
+sample = {key: value.to(model.device) for key, value in sample.items()}
+output = model.generate(**sample, max_new_tokens=100)
 print(tokenizer.decode(output[0]))
 print("==========================================\n\n")
