File tree Expand file tree Collapse file tree 13 files changed +56
-6
lines changed
Expand file tree Collapse file tree 13 files changed +56
-6
lines changed Original file line number Diff line number Diff line change @@ -132,6 +132,9 @@ def get_inputs(
132132 )
133133 res = dict (inputs = inputs , dynamic_shapes = shapes )
134134 if add_second_input :
135+ assert (
136+ add_second_input > 0
137+ ), f"Not implemented for add_second_input={ add_second_input } ."
135138 res ["inputs2" ] = get_inputs (
136139 model = model ,
137140 config = config ,
@@ -145,6 +148,7 @@ def get_inputs(
145148 head_dim = head_dim ,
146149 batch_size = batch_size + 1 ,
147150 sequence_length = sequence_length + add_second_input ,
151+ add_second_input = 0 ,
148152 ** kwargs ,
149153 )["inputs" ]
150154 return res
Original file line number Diff line number Diff line change @@ -52,12 +52,16 @@ def get_inputs(
5252 )
5353 res = dict (inputs = inputs , dynamic_shapes = shapes )
5454 if add_second_input :
55+ assert (
56+ add_second_input > 0
57+ ), f"Not implemented for add_second_input={ add_second_input } ."
5558 res ["inputs2" ] = get_inputs (
5659 model = model ,
5760 config = config ,
5861 batch_size = batch_size + 1 ,
5962 sequence_length = sequence_length + add_second_input ,
6063 dummy_max_token_id = dummy_max_token_id ,
64+ add_second_input = 0 ,
6165 ** kwargs ,
6266 )["inputs" ]
6367 return res
Original file line number Diff line number Diff line change @@ -54,12 +54,16 @@ def get_inputs(
5454 )
5555 res = dict (inputs = inputs , dynamic_shapes = shapes )
5656 if add_second_input :
57+ assert (
58+ add_second_input > 0
59+ ), f"Not implemented for add_second_input={ add_second_input } ."
5760 res ["inputs2" ] = get_inputs (
5861 model = model ,
5962 config = config ,
6063 batch_size = batch_size + 1 ,
6164 sequence_length = sequence_length + add_second_input ,
6265 dummy_max_token_id = dummy_max_token_id ,
66+ add_second_input = 0 ,
6367 ** kwargs ,
6468 )["inputs" ]
6569 return res
Original file line number Diff line number Diff line change @@ -75,6 +75,9 @@ def get_inputs(
7575 shapes ["interpolate_pos_encoding" ] = None # type: ignore[assignment]
7676 res = dict (inputs = inputs , dynamic_shapes = shapes )
7777 if add_second_input :
78+ assert (
79+ add_second_input > 0
80+ ), f"Not implemented for add_second_input={ add_second_input } ."
7881 res ["inputs2" ] = get_inputs (
7982 model = model ,
8083 config = config ,
@@ -83,6 +86,7 @@ def get_inputs(
8386 input_channels = input_channels ,
8487 batch_size = batch_size + 1 ,
8588 dynamic_rope = dynamic_rope ,
89+ add_second_input = 0 ,
8690 ** kwargs ,
8791 )["inputs" ]
8892 return res
Original file line number Diff line number Diff line change @@ -105,6 +105,9 @@ def get_inputs(
105105 )
106106 res = dict (inputs = inputs , dynamic_shapes = shapes )
107107 if add_second_input :
108+ assert (
109+ add_second_input > 0
110+ ), f"Not implemented for add_second_input={ add_second_input } ."
108111 res ["inputs2" ] = get_inputs (
109112 model = model ,
110113 config = config ,
@@ -117,9 +120,10 @@ def get_inputs(
117120 num_channels = num_channels ,
118121 batch_size = batch_size + 1 ,
119122 sequence_length = sequence_length + add_second_input ,
120- sequence_length2 = sequence_length2 + add_second_input ,
123+ sequence_length2 = sequence_length2 + 1 ,
121124 n_images = n_images + 1 ,
122125 dynamic_rope = dynamic_rope ,
126+ add_second_input = 0 ,
123127 ** kwargs ,
124128 )["inputs" ]
125129 return res
Original file line number Diff line number Diff line change @@ -65,6 +65,9 @@ def get_inputs(
6565 )
6666 res = dict (inputs = inputs , dynamic_shapes = shapes )
6767 if add_second_input :
68+ assert (
69+ add_second_input > 0
70+ ), f"Not implemented for add_second_input={ add_second_input } ."
6871 res ["inputs2" ] = get_inputs (
6972 model = model ,
7073 config = config ,
@@ -73,6 +76,7 @@ def get_inputs(
7376 input_channels = input_channels ,
7477 batch_size = batch_size + 1 ,
7578 dynamic_rope = dynamic_rope ,
79+ add_second_input = 0 ,
7680 ** kwargs ,
7781 )["inputs" ]
7882 return res
Original file line number Diff line number Diff line change @@ -54,12 +54,16 @@ def get_inputs(
5454 )
5555 res = dict (inputs = inputs , dynamic_shapes = shapes )
5656 if add_second_input :
57+ assert (
58+ add_second_input > 0
59+ ), f"Not implemented for add_second_input={ add_second_input } ."
5760 res ["inputs2" ] = get_inputs (
5861 model = model ,
5962 config = config ,
6063 batch_size = batch_size + 1 ,
6164 sequence_length = sequence_length + add_second_input ,
6265 dummy_max_token_id = dummy_max_token_id ,
66+ add_second_input = 0 ,
6367 ** kwargs ,
6468 )["inputs" ]
6569 return res
Original file line number Diff line number Diff line change @@ -144,6 +144,9 @@ def get_inputs(
144144 )
145145 res = dict (inputs = inputs , dynamic_shapes = shapes )
146146 if add_second_input :
147+ assert (
148+ add_second_input > 0
149+ ), f"Not implemented for add_second_input={ add_second_input } ."
147150 res ["inputs2" ] = get_inputs (
148151 model = model ,
149152 config = config ,
@@ -155,7 +158,8 @@ def get_inputs(
155158 head_dim_decoder = head_dim_decoder ,
156159 batch_size = batch_size + 1 ,
157160 sequence_length = sequence_length + add_second_input ,
158- sequence_length2 = sequence_length2 + add_second_input ,
161+ sequence_length2 = sequence_length2 + 1 ,
162+ add_second_input = 0 ,
159163 ** kwargs ,
160164 )["inputs" ]
161165 return res
Original file line number Diff line number Diff line change @@ -149,6 +149,9 @@ def get_inputs(
149149 )
150150 res = dict (inputs = inputs , dynamic_shapes = shapes )
151151 if add_second_input :
152+ assert (
153+ add_second_input > 0
154+ ), f"Not implemented for add_second_input={ add_second_input } ."
152155 res ["inputs2" ] = get_inputs (
153156 model = model ,
154157 config = config ,
@@ -161,7 +164,8 @@ def get_inputs(
161164 encoder_dim = encoder_dim ,
162165 batch_size = batch_size + 1 ,
163166 sequence_length = sequence_length + add_second_input ,
164- sequence_length2 = sequence_length2 + add_second_input ,
167+ sequence_length2 = sequence_length2 + 1 ,
168+ add_second_input = 0 ,
165169 ** kwargs ,
166170 )["inputs" ]
167171 return res
Original file line number Diff line number Diff line change @@ -54,12 +54,16 @@ def get_inputs(
5454 )
5555 res = dict (inputs = inputs , dynamic_shapes = shapes )
5656 if add_second_input :
57+ assert (
58+ add_second_input > 0
59+ ), f"Not implemented for add_second_input={ add_second_input } ."
5760 res ["inputs2" ] = get_inputs (
5861 model = model ,
5962 config = config ,
6063 batch_size = batch_size + 1 ,
6164 sequence_length = sequence_length + add_second_input ,
6265 dummy_max_token_id = dummy_max_token_id ,
66+ add_second_input = 0 ,
6367 ** kwargs ,
6468 )["inputs" ]
6569 return res
You can’t perform that action at this time.
0 commit comments