@@ -24,9 +24,27 @@ def fetch_nodes_info(
2424 context_length : int ,
2525 file_path : str = "custom_io_config.yaml" ,
2626 full_batch_size : Optional [int ] = None ,
27- decode_only : Optional [bool ] = False ,
2827 kv_precision : Optional [str ] = "float16" ,
28+ kv_cache_batch_size : Optional [int ] = None ,
2929) -> None :
30+ """
31+ Generates the network specialization custom IO config file for the converter stage in QNN compilation.
32+ Reads the ONNX graph, creates a custom IO configuration according to the passed parameters, and
33+ saves it as a YAML file at the path given by the file_path argument.
34+
35+ ``Mandatory`` Args:
36+ :onnx_graph_path (str): Generated ``ONNX`` Model Path.
37+ :batch_size (int): Batch size to compile the model for.
38+ :sequence_length (int): Sequence length for the model to compile.
39+ :context_length (int): Maximum context length to compile the model.
40+
41+ ``Optional`` Args:
42+ :file_path (str): File path to save the generated custom IO config. ``Defaults to custom_io_config.yaml.``
43+ :full_batch_size (int): Set full batch size to enable continuous batching mode. ``Defaults to None.``
44+ :kv_precision (str): Sets kv precision for compilation. ``Defaults to float16.``
45+ :kv_cache_batch_size (int): kv_cache_batch_size for Prefix Caching. ``Defaults to None.``
46+ """
47+
3048 # Load the ONNX model
3149 onnx_model = onnx .load (onnx_graph_path )
3250
@@ -46,7 +64,9 @@ def fetch_nodes_info(
4664 if full_batch_size :
4765 input_info ["Shape" ] = f"(1, 1), ({ full_batch_size } , 1)"
4866 else :
49- input_info ["Shape" ] = "(1, 1)"
67+ raise AttributeError (
68+ "ERROR: Full batch size is required for populating batch_index in custom_io_config.yaml"
69+ )
5070 else :
5171 shapes = []
5272 for input_shape in node .type .tensor_type .shape .dim :
@@ -67,11 +87,14 @@ def fetch_nodes_info(
6787 for shape in shapes :
6888 if isinstance (shape , str ):
6989 if "full_batch_size" in shape :
70- if full_batch_size :
90+ if ("past_key" in node .name or "past_value" in node .name ) and kv_cache_batch_size :
91+ shapeList .append (kv_cache_batch_size )
92+ elif full_batch_size :
7193 shapeList .append (full_batch_size )
7294 else :
73- print ("ERROR: Full batch size is required to generate custom_io_config.yaml" )
74- exit ()
95+ raise AttributeError (
96+ "ERROR: Full batch size is required to generate custom_io_config.yaml"
97+ )
7598 elif "batch_size" in shape :
7699 shapeList .append (batch_size )
77100 elif shape in ["ctx_len" , "max_context_len" ]:
@@ -107,7 +130,7 @@ def fetch_nodes_info(
107130 .replace ("[" , "(" )
108131 .replace ("]" , ")" )
109132 )
110- shape = shape_2 if decode_only else shape_1 + "," + shape_2
133+ shape = shape_1 + "," + shape_2
111134 elif ("batch_size" in shapes or "full_batch_size" in shapes ) and (
112135 "ctx_len" in shapes or "max_context_len" in shapes
113136 ):
@@ -153,6 +176,21 @@ def generate_data_format_config(
153176 model_dlc_name : Optional [str ] = "model" ,
154177 file_path : str = "qnn_data_format_config.json" ,
155178) -> None :
179+ """
180+ Generates the data format config for the context-binary generation stage in the QNN compilation path.
181+ It defines the tensor format for KV nodes when precision is set to mxint8.
182+ Reads the ONNX graph, creates a data format configuration file, and saves it as a JSON file at the
183+ path given by the file_path argument.
184+
185+ ``Mandatory`` Args:
186+ :onnx_graph_path (str): Generated ``ONNX`` Model Path.
187+
188+ ``Optional`` Args:
189+ :data_format (str): Tensor format for KV nodes. ``Defaults to QNN_TENSOR_DATA_FORMAT_MX.``
190+ :model_dlc_name (str): DLC name generated by the converter stage in QNN compilation. ``Defaults to model.``
191+ :file_path (str): File path to save the generated data format config. ``Defaults to qnn_data_format_config.json.``
192+ """
193+
156194 # Load the ONNX model
157195 onnx_model = onnx .load (onnx_graph_path )
158196
0 commit comments