@@ -79,6 +79,10 @@ def test_llm(
7979 required_pcc: Required PCC threshold
8080 accuracy_testing: Enable token accuracy testing with reference data
8181 """
82+ # Set default batch size if None
83+ if batch_size is None :
84+ batch_size = DEFAULT_BATCH_SIZE
85+
8286 model_loader = create_model_loader (ModelLoaderModule , num_layers = num_layers , variant = variant )
8387 if num_layers is not None and model_loader is None :
8488 pytest .fail ("num_layers override requested but ModelLoader does not support it." )
@@ -196,7 +200,7 @@ def test_llm_tp(ModelLoaderModule, variant, output_file, num_layers=None, reques
196200 )
197201
198202
199- def test_llama_3_2_1b (output_file , num_layers , request , accuracy_testing ):
203+ def test_llama_3_2_1b (output_file , num_layers , request , accuracy_testing , batch_size ):
200204 from third_party .tt_forge_models .llama .causal_lm .pytorch .loader import ModelLoader , ModelVariant
201205
202206 variant = ModelVariant .LLAMA_3_2_1B_INSTRUCT
@@ -207,10 +211,11 @@ def test_llama_3_2_1b(output_file, num_layers, request, accuracy_testing):
207211 num_layers = num_layers ,
208212 request = request ,
209213 accuracy_testing = accuracy_testing ,
214+ batch_size = batch_size ,
210215 )
211216
212217
213- def test_llama_3_2_3b (output_file , num_layers , request , accuracy_testing ):
218+ def test_llama_3_2_3b (output_file , num_layers , request , accuracy_testing , batch_size ):
214219 from third_party .tt_forge_models .llama .causal_lm .pytorch .loader import ModelLoader , ModelVariant
215220
216221 variant = ModelVariant .LLAMA_3_2_3B_INSTRUCT
@@ -221,10 +226,11 @@ def test_llama_3_2_3b(output_file, num_layers, request, accuracy_testing):
221226 num_layers = num_layers ,
222227 request = request ,
223228 accuracy_testing = accuracy_testing ,
229+ batch_size = batch_size ,
224230 )
225231
226232
227- def test_gemma_1_1_2b (output_file , num_layers , request , accuracy_testing ):
233+ def test_gemma_1_1_2b (output_file , num_layers , request , accuracy_testing , batch_size ):
228234 from third_party .tt_forge_models .gemma .pytorch .loader import ModelLoader , ModelVariant
229235
230236 variant = ModelVariant .GEMMA_1_1_2B_IT
@@ -237,10 +243,11 @@ def test_gemma_1_1_2b(output_file, num_layers, request, accuracy_testing):
237243 num_layers = num_layers ,
238244 request = request ,
239245 accuracy_testing = accuracy_testing ,
246+ batch_size = batch_size ,
240247 )
241248
242249
243- def test_gemma_2_2b (output_file , num_layers , request , accuracy_testing ):
250+ def test_gemma_2_2b (output_file , num_layers , request , accuracy_testing , batch_size ):
244251 from third_party .tt_forge_models .gemma .pytorch .loader import ModelLoader , ModelVariant
245252
246253 variant = ModelVariant .GEMMA_2_2B_IT
@@ -253,10 +260,11 @@ def test_gemma_2_2b(output_file, num_layers, request, accuracy_testing):
253260 num_layers = num_layers ,
254261 request = request ,
255262 accuracy_testing = accuracy_testing ,
263+ batch_size = batch_size ,
256264 )
257265
258266
259- def test_phi1 (output_file , num_layers , request , accuracy_testing ):
267+ def test_phi1 (output_file , num_layers , request , accuracy_testing , batch_size ):
260268 from third_party .tt_forge_models .phi1 .causal_lm .pytorch .loader import ModelLoader , ModelVariant
261269
262270 variant = ModelVariant .PHI1
@@ -267,10 +275,11 @@ def test_phi1(output_file, num_layers, request, accuracy_testing):
267275 num_layers = num_layers ,
268276 request = request ,
269277 accuracy_testing = accuracy_testing ,
278+ batch_size = batch_size ,
270279 )
271280
272281
273- def test_phi1_5 (output_file , num_layers , request , accuracy_testing ):
282+ def test_phi1_5 (output_file , num_layers , request , accuracy_testing , batch_size ):
274283 from third_party .tt_forge_models .phi1_5 .causal_lm .pytorch .loader import ModelLoader , ModelVariant
275284
276285 variant = ModelVariant .PHI1_5
@@ -281,10 +290,11 @@ def test_phi1_5(output_file, num_layers, request, accuracy_testing):
281290 num_layers = num_layers ,
282291 request = request ,
283292 accuracy_testing = accuracy_testing ,
293+ batch_size = batch_size ,
284294 )
285295
286296
287- def test_phi2 (output_file , num_layers , request , accuracy_testing ):
297+ def test_phi2 (output_file , num_layers , request , accuracy_testing , batch_size ):
288298 from third_party .tt_forge_models .phi2 .causal_lm .pytorch .loader import ModelLoader , ModelVariant
289299
290300 variant = ModelVariant .PHI2
@@ -295,10 +305,11 @@ def test_phi2(output_file, num_layers, request, accuracy_testing):
295305 num_layers = num_layers ,
296306 request = request ,
297307 accuracy_testing = accuracy_testing ,
308+ batch_size = batch_size ,
298309 )
299310
300311
301- def test_falcon3_1b (output_file , num_layers , request , accuracy_testing ):
312+ def test_falcon3_1b (output_file , num_layers , request , accuracy_testing , batch_size ):
302313 from third_party .tt_forge_models .falcon .pytorch .loader import ModelLoader , ModelVariant
303314
304315 variant = ModelVariant .FALCON_1B
@@ -312,10 +323,11 @@ def test_falcon3_1b(output_file, num_layers, request, accuracy_testing):
312323 num_layers = num_layers ,
313324 request = request ,
314325 accuracy_testing = accuracy_testing ,
326+ batch_size = batch_size ,
315327 )
316328
317329
318- def test_falcon3_3b (output_file , num_layers , request , accuracy_testing ):
330+ def test_falcon3_3b (output_file , num_layers , request , accuracy_testing , batch_size ):
319331 from third_party .tt_forge_models .falcon .pytorch .loader import ModelLoader , ModelVariant
320332
321333 variant = ModelVariant .FALCON_3B
@@ -329,10 +341,11 @@ def test_falcon3_3b(output_file, num_layers, request, accuracy_testing):
329341 num_layers = num_layers ,
330342 request = request ,
331343 accuracy_testing = accuracy_testing ,
344+ batch_size = batch_size ,
332345 )
333346
334347
335- def test_qwen_2_5_0_5b (output_file , num_layers , request , accuracy_testing ):
348+ def test_qwen_2_5_0_5b (output_file , num_layers , request , accuracy_testing , batch_size ):
336349 from third_party .tt_forge_models .qwen_2_5 .causal_lm .pytorch .loader import ModelLoader , ModelVariant
337350
338351 variant = ModelVariant .QWEN_2_5_0_5B_INSTRUCT
@@ -344,10 +357,11 @@ def test_qwen_2_5_0_5b(output_file, num_layers, request, accuracy_testing):
344357 num_layers = num_layers ,
345358 request = request ,
346359 accuracy_testing = accuracy_testing ,
360+ batch_size = batch_size ,
347361 )
348362
349363
350- def test_qwen_3_0_6b (output_file , num_layers , request , accuracy_testing ):
364+ def test_qwen_3_0_6b (output_file , num_layers , request , accuracy_testing , batch_size ):
351365 from third_party .tt_forge_models .qwen_3 .causal_lm .pytorch .loader import ModelLoader , ModelVariant
352366
353367 variant = ModelVariant .QWEN_3_0_6B
@@ -358,10 +372,11 @@ def test_qwen_3_0_6b(output_file, num_layers, request, accuracy_testing):
358372 num_layers = num_layers ,
359373 request = request ,
360374 accuracy_testing = accuracy_testing ,
375+ batch_size = batch_size ,
361376 )
362377
363378
364- def test_qwen_3_1_7b (output_file , num_layers , request , accuracy_testing ):
379+ def test_qwen_3_1_7b (output_file , num_layers , request , accuracy_testing , batch_size ):
365380 from third_party .tt_forge_models .qwen_3 .causal_lm .pytorch .loader import ModelLoader , ModelVariant
366381
367382 variant = ModelVariant .QWEN_3_1_7B
@@ -372,10 +387,11 @@ def test_qwen_3_1_7b(output_file, num_layers, request, accuracy_testing):
372387 num_layers = num_layers ,
373388 request = request ,
374389 accuracy_testing = accuracy_testing ,
390+ batch_size = batch_size ,
375391 )
376392
377393
378- def test_qwen_3_4b (output_file , num_layers , request , accuracy_testing ):
394+ def test_qwen_3_4b (output_file , num_layers , request , accuracy_testing , batch_size ):
379395 from third_party .tt_forge_models .qwen_3 .causal_lm .pytorch .loader import ModelLoader , ModelVariant
380396
381397 variant = ModelVariant .QWEN_3_4B
@@ -386,10 +402,11 @@ def test_qwen_3_4b(output_file, num_layers, request, accuracy_testing):
386402 num_layers = num_layers ,
387403 request = request ,
388404 accuracy_testing = accuracy_testing ,
405+ batch_size = batch_size ,
389406 )
390407
391408
392- def test_qwen_2_5_1_5b (output_file , num_layers , request , accuracy_testing ):
409+ def test_qwen_2_5_1_5b (output_file , num_layers , request , accuracy_testing , batch_size ):
393410 from third_party .tt_forge_models .qwen_2_5 .causal_lm .pytorch .loader import ModelLoader , ModelVariant
394411
395412 variant = ModelVariant .QWEN_2_5_1_5B_INSTRUCT
@@ -400,10 +417,11 @@ def test_qwen_2_5_1_5b(output_file, num_layers, request, accuracy_testing):
400417 num_layers = num_layers ,
401418 request = request ,
402419 accuracy_testing = accuracy_testing ,
420+ batch_size = batch_size ,
403421 )
404422
405423
406- def test_qwen_2_5_3b (output_file , num_layers , request , accuracy_testing ):
424+ def test_qwen_2_5_3b (output_file , num_layers , request , accuracy_testing , batch_size ):
407425 from third_party .tt_forge_models .qwen_2_5 .causal_lm .pytorch .loader import ModelLoader , ModelVariant
408426
409427 variant = ModelVariant .QWEN_2_5_3B_INSTRUCT
@@ -414,10 +432,11 @@ def test_qwen_2_5_3b(output_file, num_layers, request, accuracy_testing):
414432 num_layers = num_layers ,
415433 request = request ,
416434 accuracy_testing = accuracy_testing ,
435+ batch_size = batch_size ,
417436 )
418437
419438
420- def test_qwen_3_8b (output_file , num_layers , request , accuracy_testing ):
439+ def test_qwen_3_8b (output_file , num_layers , request , accuracy_testing , batch_size ):
421440 from third_party .tt_forge_models .qwen_3 .causal_lm .pytorch .loader import ModelLoader , ModelVariant
422441
423442 variant = ModelVariant .QWEN_3_8B
@@ -428,10 +447,11 @@ def test_qwen_3_8b(output_file, num_layers, request, accuracy_testing):
428447 num_layers = num_layers ,
429448 request = request ,
430449 accuracy_testing = accuracy_testing ,
450+ batch_size = batch_size ,
431451 )
432452
433453
434- def test_qwen_2_5_7b (output_file , num_layers , request , accuracy_testing ):
454+ def test_qwen_2_5_7b (output_file , num_layers , request , accuracy_testing , batch_size ):
435455 from third_party .tt_forge_models .qwen_2_5 .causal_lm .pytorch .loader import ModelLoader , ModelVariant
436456
437457 variant = ModelVariant .QWEN_2_5_7B_INSTRUCT
@@ -442,6 +462,7 @@ def test_qwen_2_5_7b(output_file, num_layers, request, accuracy_testing):
442462 num_layers = num_layers ,
443463 request = request ,
444464 accuracy_testing = accuracy_testing ,
465+ batch_size = batch_size ,
445466 )
446467
447468
@@ -485,7 +506,7 @@ def test_mamba_2_8b(output_file, num_layers, request):
485506 )
486507
487508
488- def test_falcon3_7b (output_file , num_layers , request , accuracy_testing ):
509+ def test_falcon3_7b (output_file , num_layers , request , accuracy_testing , batch_size ):
489510 from third_party .tt_forge_models .falcon .pytorch .loader import ModelLoader , ModelVariant
490511
491512 variant = ModelVariant .FALCON_7B
@@ -499,10 +520,11 @@ def test_falcon3_7b(output_file, num_layers, request, accuracy_testing):
499520 num_layers = num_layers ,
500521 request = request ,
501522 accuracy_testing = accuracy_testing ,
523+ batch_size = batch_size ,
502524 )
503525
504526
505- def test_mistral_7b (output_file , num_layers , request , accuracy_testing ):
527+ def test_mistral_7b (output_file , num_layers , request , accuracy_testing , batch_size ):
506528 from third_party .tt_forge_models .mistral .pytorch .loader import ModelLoader , ModelVariant
507529
508530 variant = ModelVariant .MISTRAL_7B_INSTRUCT_V03
@@ -513,10 +535,11 @@ def test_mistral_7b(output_file, num_layers, request, accuracy_testing):
513535 num_layers = num_layers ,
514536 request = request ,
515537 accuracy_testing = accuracy_testing ,
538+ batch_size = batch_size ,
516539 )
517540
518541
519- def test_ministral_8b (output_file , num_layers , request , accuracy_testing ):
542+ def test_ministral_8b (output_file , num_layers , request , accuracy_testing , batch_size ):
520543 from third_party .tt_forge_models .mistral .pytorch .loader import ModelLoader , ModelVariant
521544
522545 variant = ModelVariant .MINISTRAL_8B
@@ -528,10 +551,11 @@ def test_ministral_8b(output_file, num_layers, request, accuracy_testing):
528551 request = request ,
529552 fp32_dest_acc_en = False ,
530553 accuracy_testing = accuracy_testing ,
554+ batch_size = batch_size ,
531555 )
532556
533557
534- def test_llama_3_1_8b (output_file , num_layers , request , accuracy_testing ):
558+ def test_llama_3_1_8b (output_file , num_layers , request , accuracy_testing , batch_size ):
535559 from third_party .tt_forge_models .llama .causal_lm .pytorch .loader import ModelLoader , ModelVariant
536560
537561 variant = ModelVariant .LLAMA_3_1_8B_INSTRUCT
@@ -543,6 +567,7 @@ def test_llama_3_1_8b(output_file, num_layers, request, accuracy_testing):
543567 request = request ,
544568 fp32_dest_acc_en = False ,
545569 accuracy_testing = accuracy_testing ,
570+ batch_size = batch_size ,
546571 )
547572
548573
0 commit comments