@@ -41,7 +41,7 @@ def make_feeds(
4141 """
4242 # NOTE: position_ids is a special case because ModelBuilder does not usually use it,
4343 # because it's fued into rotary embedding in GQA.
44- if isinstance (inputs , dict ):
44+ if is_modelbuilder and isinstance (inputs , dict ):
4545 inputs .pop ("position_ids" , None ) # Ensure 'position_ids' absent before removing.
4646
4747 flat = flatten_object (inputs , drop_keys = True )
@@ -112,19 +112,23 @@ def reorder_modelbuilder_cache_to_torch(past_kv: List[Any]) -> List[Any]:
112112 Reorders the past_kvs for ModelBuilder to match the expected order
113113 by PyTorch exported models.
114114
115- NOTE: This function can take either the names or the actual tensors
116- as long as they are in a list.
115+ .. note::
116+ This function can take either the names or the actual tensors
117+ as long as they are in a list.
117118
118119 Conceptually,
119120
120- From:
121- [past_key_values.0.key, past_key_values.0.value,
122- past_key_values.1.key, past_key_values.1.value, ...]
123- To:
124- [past_key_values.0.key, past_key_values.1.key,
125- ..., past_key_values.0.value, past_key_values.1.value, ...]
121+ From::
126122
127- :param flat: list of flattened inputs
123+ [past_key_values.0.key, past_key_values.0.value,
124+ past_key_values.1.key, past_key_values.1.value, ...]
125+
126+ To::
127+
128+ [past_key_values.0.key, past_key_values.1.key,
129+ ..., past_key_values.0.value, past_key_values.1.value, ...]
130+
131+ :param past_kv: list of flattened inputs
128132 :return: reordered list of flattened inputs
129133 """
130134 total_len = len (past_kv )
0 commit comments