 import onnx
 import torch
 from .helper import string_type, flatten_object
-from .onnx_helper import dtype_to_tensor_dtype
 from .cache_helper import is_cache_dynamic_registered
 
 
@@ -23,6 +22,7 @@ def make_feeds(
     use_numpy: bool = False,
     copy: bool = False,
     check_flatten: bool = True,
+    is_modelbuilder: bool = False,
 ) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
     """
     Serializes the inputs to produce feeds expected
@@ -35,10 +35,15 @@ def make_feeds(
         by ``OrtValue``
     :param check_flatten: if True, checks the ``torch.utils._pytree.tree_flatten``
         returns the same number of outputs
+    :param is_modelbuilder: if True, the exporter is ModelBuilder; the past_key_values
+        inputs are reordered to match the expected order and position_ids is dropped
     :return: feeds dictionary
     """
-    # position_ids is a special case because ModelBuilder does not usually use it.
-    # We use types to detect the best inputs.
+    # NOTE: position_ids is a special case because ModelBuilder does not usually use it:
+    # it is fused into the rotary embedding in GQA.
+    if isinstance(inputs, dict):
+        inputs.pop("position_ids", None)  # Drop position_ids if present.
+
     flat = flatten_object(inputs, drop_keys=True)
     assert (
         not check_flatten
@@ -76,39 +81,6 @@ def make_feeds(
         f"\n-- inputs={string_type(inputs, with_shape=True)}"
         f"\n-- names={names}"
     )
-    if len(names) < len(flat) and (
-        isinstance(proto, onnx.ModelProto) or hasattr(proto, "get_inputs")
-    ):
-
-        typed_names = (
-            [(i.name, i.type.tensor_type.elem_type) for i in proto.graph.input]
-            if isinstance(proto, onnx.ModelProto)
-            else [(i.name, name_type_to_onnx_dtype(i.type)) for i in proto.get_inputs()]
-        )
-
-        new_flat = []
-        pos = 0
-        for _name, dtype in typed_names:
-            assert isinstance(
-                dtype, int
-            ), f"Unexpected value for dtype={dtype!r}, type(proto)={type(proto)}"
-            itype = dtype_to_tensor_dtype(flat[pos].dtype)
-            while dtype != itype:
-                pos += 1
-                if pos >= len(flat):
-                    break
-                itype = dtype_to_tensor_dtype(flat[pos].dtype)
-            if pos >= len(flat):
-                break
-            new_flat.append(flat[pos])
-            pos += 1
-        assert len(new_flat) == len(names), (
-            f"Unable to align expected input {names} with the given input, "
-            f"type(proto)={type(proto)}"
-            f"\n-- inputs: {string_type(inputs, with_shape=True)}"
-            f"\n-- typed_names: {typed_names}"
-        )
-        flat = new_flat
 
     if copy:
         flat = [t.copy() if hasattr(t, "copy") else t.clone() for t in flat]
@@ -122,4 +94,45 @@ def make_feeds(
         elif isinstance(i, float):
             i = np.array(i, dtype=np.float32)
         new_flat.append(i)
+
+    # NOTE: ModelBuilder uses a different order for past_key_values,
+    # so we need to reorder them to match the expected order.
+    if is_modelbuilder:
+        # We assume "past_key_values" appears in the names when the
+        # exporter is ModelBuilder.
+        non_past_kv_input_names = [n for n in names if "past_key_values" not in n]
+        past_kv_names = [n for n in names if "past_key_values" in n]
+        reorder_past_kv_names = reorder_modelbuilder_cache_to_torch(past_kv_names)
+        names = non_past_kv_input_names + reorder_past_kv_names
     return dict(zip(names, new_flat))
+
+
+def reorder_modelbuilder_cache_to_torch(past_kv: List[Any]) -> List[Any]:
+    """
+    Reorders the past_key_values produced for ModelBuilder to match the order
+    expected by PyTorch exported models.
+
+    NOTE: This function accepts either the names or the actual tensors,
+    as long as they are given as a list.
+
+    Conceptually:
+
+    From:
+        [past_key_values.0.key, past_key_values.0.value,
+         past_key_values.1.key, past_key_values.1.value, ...]
+    To:
+        [past_key_values.0.key, past_key_values.1.key, ...,
+         past_key_values.0.value, past_key_values.1.value, ...]
+
+    :param past_kv: list of flattened inputs (names or tensors)
+    :return: reordered list of flattened inputs
+    """
+    total_len = len(past_kv)
+    if total_len % 2 != 0:
+        raise ValueError("The length of past_key_values should be even.")
+    keys = []
+    values = []
+    for i in range(0, total_len, 2):
+        keys.append(past_kv[i])
+        values.append(past_kv[i + 1])
+    return keys + values
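
For reference, a minimal standalone sketch (not part of the diff) of what the reordering does to the interleaved ModelBuilder inputs; the names below are only illustrative:

from typing import Any, List

def reorder_modelbuilder_cache_to_torch(past_kv: List[Any]) -> List[Any]:
    # Same reordering as the helper added above: interleaved
    # (key, value) pairs become all keys followed by all values.
    if len(past_kv) % 2 != 0:
        raise ValueError("The length of past_key_values should be even.")
    keys, values = [], []
    for i in range(0, len(past_kv), 2):
        keys.append(past_kv[i])
        values.append(past_kv[i + 1])
    return keys + values

names = [
    "past_key_values.0.key", "past_key_values.0.value",
    "past_key_values.1.key", "past_key_values.1.value",
]
print(reorder_modelbuilder_cache_to_torch(names))
# ['past_key_values.0.key', 'past_key_values.1.key',
#  'past_key_values.0.value', 'past_key_values.1.value']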