Skip to content

Commit 1ae3379

Browse files
author
shubhagr-quic
authored
Support for Prefix caching Feature in QNN Compilation Path. (quic#262)
Signed-off-by: Shubham Agrawal <quic_shubhagr@quicinc.com>
1 parent a4f3249 commit 1ae3379

File tree

7 files changed

+91
-9
lines changed

7 files changed

+91
-9
lines changed

QEfficient/base/modeling_qeff.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,7 @@ def _qnn_compile(
339339
mxfp6_matmul: bool = False,
340340
mxint8_kv_cache: bool = False,
341341
qnn_config: Optional[str] = None,
342+
kv_cache_batch_size: Optional[int] = None,
342343
) -> str:
343344
"""
344345
Interface for QNN compiler
@@ -356,6 +357,7 @@ def _qnn_compile(
356357
:mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to False``.
357358
:mxint8_kv_cache (bool, optional): Whether to use ``mxint8`` compression for KV cache. ``Defaults to False``.
358359
:qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
360+
:kv_cache_batch_size (int): KV cache batch size used for the prefix caching feature. ``Defaults to None.``
359361
"""
360362
if onnx_path is None and self.onnx_path is None:
361363
self.export()
@@ -415,6 +417,7 @@ def _qnn_compile(
415417
full_batch_size=full_batch_size,
416418
qnn_config=qnn_config,
417419
qnn_binary_dir=qpc_path,
420+
kv_cache_batch_size=kv_cache_batch_size,
418421
)
419422

420423
self.qpc_path = qpc_path

QEfficient/cloud/compile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989
default=False,
9090
help="Enables QNN. Optionally, a configuration file can be provided with [--enable_qnn CONFIG_FILE].\
9191
If not provided, the default configuration will be used.\
92-
Sample Config: QEfficient/cloud/compile/qnn_config.json",
92+
Sample Config: QEfficient/compile/qnn_config.json",
9393
)
9494
parser.add_argument(
9595
"qnn_config",

QEfficient/cloud/infer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def main(
223223
default=False,
224224
help="Enables QNN. Optionally, a configuration file can be provided with [--enable_qnn CONFIG_FILE].\
225225
If not provided, the default configuration will be used.\
226-
Sample Config: QEfficient/cloud/compile/qnn_config.json",
226+
Sample Config: QEfficient/compile/qnn_config.json",
227227
)
228228
parser.add_argument(
229229
"qnn_config",

QEfficient/compile/qnn_compiler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@ def compile(
338338
full_batch_size=None,
339339
qnn_config: Optional[str] = None,
340340
qnn_binary_dir: Optional[str] = None,
341+
kv_cache_batch_size: Optional[int] = None,
341342
**kwargs,
342343
) -> str:
343344
"""
@@ -362,6 +363,7 @@ def compile(
362363
:mxint8 (bool): Compress Present/Past KV to ``MXINT8`` using ``CustomIO`` config. ``Defaults to False.``
363364
:qnn_config (str): Path to ``qnn_config.json`` file (formatted as a string). ``Defaults to None.``
364365
:qnn_binary_dir (str): Path for saving qnn binaries.
366+
:kv_cache_batch_size (int): KV cache batch size used for the prefix caching feature. ``Defaults to None.``
365367
366368
Returns:
367369
:str: Path to compiled ``qpc`` package.
@@ -386,6 +388,7 @@ def compile(
386388
file_path=custom_io_file_path,
387389
full_batch_size=full_batch_size,
388390
kv_precision=kv_precision,
391+
kv_cache_batch_size=kv_cache_batch_size,
389392
)
390393

391394
if not os.path.isfile(custom_io_file_path):

QEfficient/transformers/models/modeling_auto.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1503,6 +1503,7 @@ def compile(
15031503
mxfp6_matmul=mxfp6_matmul,
15041504
mxint8_kv_cache=mxint8_kv_cache,
15051505
qnn_config=qnn_config,
1506+
kv_cache_batch_size=kv_cache_batch_size,
15061507
)
15071508
else:
15081509
# Custom IO

QEfficient/utils/generate_qnn_network_specialization_config.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,27 @@ def fetch_nodes_info(
2424
context_length: int,
2525
file_path: str = "custom_io_config.yaml",
2626
full_batch_size: Optional[int] = None,
27-
decode_only: Optional[bool] = False,
2827
kv_precision: Optional[str] = "float16",
28+
kv_cache_batch_size: Optional[int] = None,
2929
) -> None:
30+
"""
31+
Generates the network specialization custom IO config file for the convertor stage in QNN compilation.
32+
Reads the ONNX graph and creates a custom IO configuration file according to the passed parameters, then
33+
saves it as a YAML file at the path provided in the file_path argument.
34+
35+
``Mandatory`` Args:
36+
:onnx_graph_path (str): Generated ``ONNX`` Model Path.
37+
:batch_size (int): Batch size to compile the model for.
38+
:sequence_length (int): Sequence length for the model to compile.
39+
:context_length (int): Maximum context length to compile the model.
40+
41+
``Optional`` Args:
42+
:file_path (str): File path to save the generated custom IO config. ``Defaults to custom_io_config.yaml.``
43+
:full_batch_size (int): Set full batch size to enable continuous batching mode. ``Defaults to None.``
44+
:kv_precision (str): Sets kv precision for compilation. ``Defaults to float16.``
45+
:kv_cache_batch_size (int): KV cache batch size used for the prefix caching feature. ``Defaults to None.``
46+
"""
47+
3048
# Load the ONNX model
3149
onnx_model = onnx.load(onnx_graph_path)
3250

@@ -46,7 +64,9 @@ def fetch_nodes_info(
4664
if full_batch_size:
4765
input_info["Shape"] = f"(1, 1), ({full_batch_size}, 1)"
4866
else:
49-
input_info["Shape"] = "(1, 1)"
67+
raise AttributeError(
68+
"ERROR: Full batch size is required for populating batch_index in custom_io_config.yaml"
69+
)
5070
else:
5171
shapes = []
5272
for input_shape in node.type.tensor_type.shape.dim:
@@ -67,11 +87,14 @@ def fetch_nodes_info(
6787
for shape in shapes:
6888
if isinstance(shape, str):
6989
if "full_batch_size" in shape:
70-
if full_batch_size:
90+
if ("past_key" in node.name or "past_value" in node.name) and kv_cache_batch_size:
91+
shapeList.append(kv_cache_batch_size)
92+
elif full_batch_size:
7193
shapeList.append(full_batch_size)
7294
else:
73-
print("ERROR: Full batch size is required to generate custom_io_config.yaml")
74-
exit()
95+
raise AttributeError(
96+
"ERROR: Full batch size is required to generate custom_io_config.yaml"
97+
)
7598
elif "batch_size" in shape:
7699
shapeList.append(batch_size)
77100
elif shape in ["ctx_len", "max_context_len"]:
@@ -107,7 +130,7 @@ def fetch_nodes_info(
107130
.replace("[", "(")
108131
.replace("]", ")")
109132
)
110-
shape = shape_2 if decode_only else shape_1 + "," + shape_2
133+
shape = shape_1 + "," + shape_2
111134
elif ("batch_size" in shapes or "full_batch_size" in shapes) and (
112135
"ctx_len" in shapes or "max_context_len" in shapes
113136
):
@@ -153,6 +176,21 @@ def generate_data_format_config(
153176
model_dlc_name: Optional[str] = "model",
154177
file_path: str = "qnn_data_format_config.json",
155178
) -> None:
179+
"""
180+
Generates the data format config for the context binary generation stage of the QNN compilation path.
181+
It defines the tensor format for KV nodes when precision is set to mxint8.
182+
Reads the ONNX graph, creates a data format configuration file, and saves it as a JSON file at the path
183+
provided in the file_path argument.
184+
185+
``Mandatory`` Args:
186+
:onnx_graph_path (str): Generated ``ONNX`` Model Path.
187+
188+
``Optional`` Args:
189+
:data_format (str): Tensor format for KV nodes. ``Defaults to QNN_TENSOR_DATA_FORMAT_MX.``
190+
:model_dlc_name (str): DLC Name generated by the convertor stage in QNN Compilation. ``Defaults to model.``
191+
:file_path (str): File path to save the generated data format config. ``Defaults to qnn_data_format_config.json.``
192+
"""
193+
156194
# Load the ONNX model
157195
onnx_model = onnx.load(onnx_graph_path)
158196

tests/transformers/models/test_prefix_caching.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55
#
66
# -----------------------------------------------------------------------------
77

8+
import os
9+
810
import numpy as np
911
import pytest
1012
from transformers import AutoTokenizer
1113

1214
from QEfficient.generation.text_generation_inference import TextGeneration
1315
from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM
16+
from QEfficient.utils._utils import create_json
1417

1518
test_models = ["gpt2"]
1619

@@ -27,14 +30,48 @@ def test_simple_prefix_caching(model_name):
2730
kv_cache_batch_size=4,
2831
num_cores=14,
2932
)
33+
prefix_caching_inference(model_name=model_name, qpc_path=qeff_model.qpc_path)
34+
35+
36+
@pytest.mark.on_qaic
37+
@pytest.mark.qnn
38+
@pytest.mark.parametrize("model_name", test_models)
39+
def test_simple_prefix_caching_qnn(model_name):
40+
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_name, continuous_batching=True)
41+
qnn_config = {
42+
"convertor_args_extension": "",
43+
"context_binary_generator_args_extension": "--log_level debug",
44+
"qnn_compilation_backend": {
45+
"compiler_enable_depth_first": True,
46+
"compiler_printDDRStats": False,
47+
"compiler_printPerfMetrics": False,
48+
},
49+
"SKIP_QNN_CONVERTOR_STEP": False,
50+
}
51+
qnn_config_json_path = os.path.join(os.getcwd(), "qnn_config.json")
52+
create_json(qnn_config_json_path, qnn_config)
53+
54+
qeff_model.compile(
55+
prefill_seq_len=128,
56+
ctx_len=256,
57+
full_batch_size=2,
58+
kv_cache_batch_size=4,
59+
num_cores=14,
60+
enable_qnn=True,
61+
qnn_config=qnn_config_json_path,
62+
)
63+
prefix_caching_inference(model_name=model_name, qpc_path=qeff_model.qpc_path)
64+
os.remove(qnn_config_json_path)
65+
3066

67+
def prefix_caching_inference(model_name, qpc_path):
3168
prefixes = ["Once upon a time ", "Once upon a time "]
3269
suffixes1 = ["in a land far away", "there was a small village"]
3370
suffixes2 = ["a little girl", "in a bustling city"]
3471

3572
tokenizer = AutoTokenizer.from_pretrained(model_name)
3673

37-
generator = TextGeneration(tokenizer=tokenizer, qpc_path=qeff_model.qpc_path, full_batch_size=2, ctx_len=256)
74+
generator = TextGeneration(tokenizer=tokenizer, qpc_path=qpc_path, full_batch_size=2, ctx_len=256)
3875

3976
prompts = [pref + suff for pref, suff in zip(prefixes, suffixes1)]
4077

0 commit comments

Comments
 (0)