Skip to content

Commit 29c9c3b

Browse files
committed
Set batch-size to mitigate OOM issues caused by input sequences larger than those tested in the perf benchmark
1 parent cd3a018 commit 29c9c3b

File tree

3 files changed

+57
-27
lines changed

3 files changed

+57
-27
lines changed

.github/workflows/call-perf-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ jobs:
197197
python benchmark/benchmark.py -p ${{ matrix.build.project}} -m ${{ matrix.build.name }} -bs ${{ matrix.build.bs }} -df ${{ matrix.build.df }} -lp ${{ matrix.build.lp }} ${{ matrix.build.input_sequence_length && format('-isl {0}', matrix.build.input_sequence_length) }} -ts ${{ matrix.build.ts }} -o ${{ steps.strings.outputs.perf_report_json_file }} ${{ inputs.run_id_source && format('-r {0}', inputs.run_id_source) }}
198198
else
199199
# Run with pytest
200-
pytest -svv "${{ matrix.build.pytest }}" ${{ matrix.build.accuracy-testing && '--accuracy-testing true' || '' }} --output-file=${{ steps.strings.outputs.perf_report_json_file }}
200+
pytest -svv "${{ matrix.build.pytest }}" ${{ matrix.build.accuracy-testing && '--accuracy-testing true' || '' }} ${{ matrix.build['batch-size'] && format('--batch-size {0}', matrix.build['batch-size']) || '' }} --output-file=${{ steps.strings.outputs.perf_report_json_file }}
201201
fi
202202
203203
- name: Dump stablehlo to report

.github/workflows/perf-bench-matrix.json

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -322,13 +322,15 @@
322322
"name": "llama_3_1_8b_instruct_accuracy",
323323
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 tqdm transformers==4.57.1",
324324
"pytest": "benchmark/tt-xla/test_llms.py::test_llama_3_1_8b",
325-
"accuracy-testing": true
325+
"accuracy-testing": true,
326+
"batch-size": 16
326327
},
327328
{
328329
"name": "mistral_7b_accuracy",
329330
"pyreq": "datasets loguru pytest requests torch==2.9.0 tqdm transformers==4.57.1 protobuf sentencepiece",
330331
"pytest": "benchmark/tt-xla/test_llms.py::test_mistral_7b",
331-
"accuracy-testing": true
332+
"accuracy-testing": true,
333+
"batch-size": 8
332334
},
333335
{
334336
"name": "qwen_2_5_7b_instruct_accuracy",
@@ -382,7 +384,8 @@
382384
"name": "tiiuae_falcon3-7b-base_accuracy",
383385
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
384386
"pytest": "benchmark/tt-xla/test_llms.py::test_falcon3_7b",
385-
"accuracy-testing": true
387+
"accuracy-testing": true,
388+
"batch-size": 4
386389
},
387390
{
388391
"name": "qwen_2_5_0_5b_instruct_accuracy",
@@ -400,7 +403,8 @@
400403
"name": "qwen_2_5_3b_instruct_accuracy",
401404
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
402405
"pytest": "benchmark/tt-xla/test_llms.py::test_qwen_2_5_3b",
403-
"accuracy-testing": true
406+
"accuracy-testing": true,
407+
"batch-size": 16
404408
},
405409
{
406410
"name": "qwen_3_0_6b_accuracy",
@@ -430,7 +434,8 @@
430434
"name": "ministral_8b_accuracy",
431435
"pyreq": "datasets loguru pytest requests torch==2.9.0 tqdm transformers==4.57.1",
432436
"pytest": "benchmark/tt-xla/test_llms.py::test_ministral_8b",
433-
"accuracy-testing": true
437+
"accuracy-testing": true,
438+
"batch-size": 16
434439
}
435440
]
436441
}

benchmark/tt-xla/test_llms.py

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ def test_llm(
7979
required_pcc: Required PCC threshold
8080
accuracy_testing: Enable token accuracy testing with reference data
8181
"""
82+
# Set default batch size if None
83+
if batch_size is None:
84+
batch_size = DEFAULT_BATCH_SIZE
85+
8286
model_loader = create_model_loader(ModelLoaderModule, num_layers=num_layers, variant=variant)
8387
if num_layers is not None and model_loader is None:
8488
pytest.fail("num_layers override requested but ModelLoader does not support it.")
@@ -196,7 +200,7 @@ def test_llm_tp(ModelLoaderModule, variant, output_file, num_layers=None, reques
196200
)
197201

198202

199-
def test_llama_3_2_1b(output_file, num_layers, request, accuracy_testing):
203+
def test_llama_3_2_1b(output_file, num_layers, request, accuracy_testing, batch_size):
200204
from third_party.tt_forge_models.llama.causal_lm.pytorch.loader import ModelLoader, ModelVariant
201205

202206
variant = ModelVariant.LLAMA_3_2_1B_INSTRUCT
@@ -207,10 +211,11 @@ def test_llama_3_2_1b(output_file, num_layers, request, accuracy_testing):
207211
num_layers=num_layers,
208212
request=request,
209213
accuracy_testing=accuracy_testing,
214+
batch_size=batch_size,
210215
)
211216

212217

213-
def test_llama_3_2_3b(output_file, num_layers, request, accuracy_testing):
218+
def test_llama_3_2_3b(output_file, num_layers, request, accuracy_testing, batch_size):
214219
from third_party.tt_forge_models.llama.causal_lm.pytorch.loader import ModelLoader, ModelVariant
215220

216221
variant = ModelVariant.LLAMA_3_2_3B_INSTRUCT
@@ -221,10 +226,11 @@ def test_llama_3_2_3b(output_file, num_layers, request, accuracy_testing):
221226
num_layers=num_layers,
222227
request=request,
223228
accuracy_testing=accuracy_testing,
229+
batch_size=batch_size,
224230
)
225231

226232

227-
def test_gemma_1_1_2b(output_file, num_layers, request, accuracy_testing):
233+
def test_gemma_1_1_2b(output_file, num_layers, request, accuracy_testing, batch_size):
228234
from third_party.tt_forge_models.gemma.pytorch.loader import ModelLoader, ModelVariant
229235

230236
variant = ModelVariant.GEMMA_1_1_2B_IT
@@ -237,10 +243,11 @@ def test_gemma_1_1_2b(output_file, num_layers, request, accuracy_testing):
237243
num_layers=num_layers,
238244
request=request,
239245
accuracy_testing=accuracy_testing,
246+
batch_size=batch_size,
240247
)
241248

242249

243-
def test_gemma_2_2b(output_file, num_layers, request, accuracy_testing):
250+
def test_gemma_2_2b(output_file, num_layers, request, accuracy_testing, batch_size):
244251
from third_party.tt_forge_models.gemma.pytorch.loader import ModelLoader, ModelVariant
245252

246253
variant = ModelVariant.GEMMA_2_2B_IT
@@ -253,10 +260,11 @@ def test_gemma_2_2b(output_file, num_layers, request, accuracy_testing):
253260
num_layers=num_layers,
254261
request=request,
255262
accuracy_testing=accuracy_testing,
263+
batch_size=batch_size,
256264
)
257265

258266

259-
def test_phi1(output_file, num_layers, request, accuracy_testing):
267+
def test_phi1(output_file, num_layers, request, accuracy_testing, batch_size):
260268
from third_party.tt_forge_models.phi1.causal_lm.pytorch.loader import ModelLoader, ModelVariant
261269

262270
variant = ModelVariant.PHI1
@@ -267,10 +275,11 @@ def test_phi1(output_file, num_layers, request, accuracy_testing):
267275
num_layers=num_layers,
268276
request=request,
269277
accuracy_testing=accuracy_testing,
278+
batch_size=batch_size,
270279
)
271280

272281

273-
def test_phi1_5(output_file, num_layers, request, accuracy_testing):
282+
def test_phi1_5(output_file, num_layers, request, accuracy_testing, batch_size):
274283
from third_party.tt_forge_models.phi1_5.causal_lm.pytorch.loader import ModelLoader, ModelVariant
275284

276285
variant = ModelVariant.PHI1_5
@@ -281,10 +290,11 @@ def test_phi1_5(output_file, num_layers, request, accuracy_testing):
281290
num_layers=num_layers,
282291
request=request,
283292
accuracy_testing=accuracy_testing,
293+
batch_size=batch_size,
284294
)
285295

286296

287-
def test_phi2(output_file, num_layers, request, accuracy_testing):
297+
def test_phi2(output_file, num_layers, request, accuracy_testing, batch_size):
288298
from third_party.tt_forge_models.phi2.causal_lm.pytorch.loader import ModelLoader, ModelVariant
289299

290300
variant = ModelVariant.PHI2
@@ -295,10 +305,11 @@ def test_phi2(output_file, num_layers, request, accuracy_testing):
295305
num_layers=num_layers,
296306
request=request,
297307
accuracy_testing=accuracy_testing,
308+
batch_size=batch_size,
298309
)
299310

300311

301-
def test_falcon3_1b(output_file, num_layers, request, accuracy_testing):
312+
def test_falcon3_1b(output_file, num_layers, request, accuracy_testing, batch_size):
302313
from third_party.tt_forge_models.falcon.pytorch.loader import ModelLoader, ModelVariant
303314

304315
variant = ModelVariant.FALCON_1B
@@ -312,10 +323,11 @@ def test_falcon3_1b(output_file, num_layers, request, accuracy_testing):
312323
num_layers=num_layers,
313324
request=request,
314325
accuracy_testing=accuracy_testing,
326+
batch_size=batch_size,
315327
)
316328

317329

318-
def test_falcon3_3b(output_file, num_layers, request, accuracy_testing):
330+
def test_falcon3_3b(output_file, num_layers, request, accuracy_testing, batch_size):
319331
from third_party.tt_forge_models.falcon.pytorch.loader import ModelLoader, ModelVariant
320332

321333
variant = ModelVariant.FALCON_3B
@@ -329,10 +341,11 @@ def test_falcon3_3b(output_file, num_layers, request, accuracy_testing):
329341
num_layers=num_layers,
330342
request=request,
331343
accuracy_testing=accuracy_testing,
344+
batch_size=batch_size,
332345
)
333346

334347

335-
def test_qwen_2_5_0_5b(output_file, num_layers, request, accuracy_testing):
348+
def test_qwen_2_5_0_5b(output_file, num_layers, request, accuracy_testing, batch_size):
336349
from third_party.tt_forge_models.qwen_2_5.causal_lm.pytorch.loader import ModelLoader, ModelVariant
337350

338351
variant = ModelVariant.QWEN_2_5_0_5B_INSTRUCT
@@ -344,10 +357,11 @@ def test_qwen_2_5_0_5b(output_file, num_layers, request, accuracy_testing):
344357
num_layers=num_layers,
345358
request=request,
346359
accuracy_testing=accuracy_testing,
360+
batch_size=batch_size,
347361
)
348362

349363

350-
def test_qwen_3_0_6b(output_file, num_layers, request, accuracy_testing):
364+
def test_qwen_3_0_6b(output_file, num_layers, request, accuracy_testing, batch_size):
351365
from third_party.tt_forge_models.qwen_3.causal_lm.pytorch.loader import ModelLoader, ModelVariant
352366

353367
variant = ModelVariant.QWEN_3_0_6B
@@ -358,10 +372,11 @@ def test_qwen_3_0_6b(output_file, num_layers, request, accuracy_testing):
358372
num_layers=num_layers,
359373
request=request,
360374
accuracy_testing=accuracy_testing,
375+
batch_size=batch_size,
361376
)
362377

363378

364-
def test_qwen_3_1_7b(output_file, num_layers, request, accuracy_testing):
379+
def test_qwen_3_1_7b(output_file, num_layers, request, accuracy_testing, batch_size):
365380
from third_party.tt_forge_models.qwen_3.causal_lm.pytorch.loader import ModelLoader, ModelVariant
366381

367382
variant = ModelVariant.QWEN_3_1_7B
@@ -372,10 +387,11 @@ def test_qwen_3_1_7b(output_file, num_layers, request, accuracy_testing):
372387
num_layers=num_layers,
373388
request=request,
374389
accuracy_testing=accuracy_testing,
390+
batch_size=batch_size,
375391
)
376392

377393

378-
def test_qwen_3_4b(output_file, num_layers, request, accuracy_testing):
394+
def test_qwen_3_4b(output_file, num_layers, request, accuracy_testing, batch_size):
379395
from third_party.tt_forge_models.qwen_3.causal_lm.pytorch.loader import ModelLoader, ModelVariant
380396

381397
variant = ModelVariant.QWEN_3_4B
@@ -386,10 +402,11 @@ def test_qwen_3_4b(output_file, num_layers, request, accuracy_testing):
386402
num_layers=num_layers,
387403
request=request,
388404
accuracy_testing=accuracy_testing,
405+
batch_size=batch_size,
389406
)
390407

391408

392-
def test_qwen_2_5_1_5b(output_file, num_layers, request, accuracy_testing):
409+
def test_qwen_2_5_1_5b(output_file, num_layers, request, accuracy_testing, batch_size):
393410
from third_party.tt_forge_models.qwen_2_5.causal_lm.pytorch.loader import ModelLoader, ModelVariant
394411

395412
variant = ModelVariant.QWEN_2_5_1_5B_INSTRUCT
@@ -400,10 +417,11 @@ def test_qwen_2_5_1_5b(output_file, num_layers, request, accuracy_testing):
400417
num_layers=num_layers,
401418
request=request,
402419
accuracy_testing=accuracy_testing,
420+
batch_size=batch_size,
403421
)
404422

405423

406-
def test_qwen_2_5_3b(output_file, num_layers, request, accuracy_testing):
424+
def test_qwen_2_5_3b(output_file, num_layers, request, accuracy_testing, batch_size):
407425
from third_party.tt_forge_models.qwen_2_5.causal_lm.pytorch.loader import ModelLoader, ModelVariant
408426

409427
variant = ModelVariant.QWEN_2_5_3B_INSTRUCT
@@ -414,10 +432,11 @@ def test_qwen_2_5_3b(output_file, num_layers, request, accuracy_testing):
414432
num_layers=num_layers,
415433
request=request,
416434
accuracy_testing=accuracy_testing,
435+
batch_size=batch_size,
417436
)
418437

419438

420-
def test_qwen_3_8b(output_file, num_layers, request, accuracy_testing):
439+
def test_qwen_3_8b(output_file, num_layers, request, accuracy_testing, batch_size):
421440
from third_party.tt_forge_models.qwen_3.causal_lm.pytorch.loader import ModelLoader, ModelVariant
422441

423442
variant = ModelVariant.QWEN_3_8B
@@ -428,10 +447,11 @@ def test_qwen_3_8b(output_file, num_layers, request, accuracy_testing):
428447
num_layers=num_layers,
429448
request=request,
430449
accuracy_testing=accuracy_testing,
450+
batch_size=batch_size,
431451
)
432452

433453

434-
def test_qwen_2_5_7b(output_file, num_layers, request, accuracy_testing):
454+
def test_qwen_2_5_7b(output_file, num_layers, request, accuracy_testing, batch_size):
435455
from third_party.tt_forge_models.qwen_2_5.causal_lm.pytorch.loader import ModelLoader, ModelVariant
436456

437457
variant = ModelVariant.QWEN_2_5_7B_INSTRUCT
@@ -442,6 +462,7 @@ def test_qwen_2_5_7b(output_file, num_layers, request, accuracy_testing):
442462
num_layers=num_layers,
443463
request=request,
444464
accuracy_testing=accuracy_testing,
465+
batch_size=batch_size,
445466
)
446467

447468

@@ -485,7 +506,7 @@ def test_mamba_2_8b(output_file, num_layers, request):
485506
)
486507

487508

488-
def test_falcon3_7b(output_file, num_layers, request, accuracy_testing):
509+
def test_falcon3_7b(output_file, num_layers, request, accuracy_testing, batch_size):
489510
from third_party.tt_forge_models.falcon.pytorch.loader import ModelLoader, ModelVariant
490511

491512
variant = ModelVariant.FALCON_7B
@@ -499,10 +520,11 @@ def test_falcon3_7b(output_file, num_layers, request, accuracy_testing):
499520
num_layers=num_layers,
500521
request=request,
501522
accuracy_testing=accuracy_testing,
523+
batch_size=batch_size,
502524
)
503525

504526

505-
def test_mistral_7b(output_file, num_layers, request, accuracy_testing):
527+
def test_mistral_7b(output_file, num_layers, request, accuracy_testing, batch_size):
506528
from third_party.tt_forge_models.mistral.pytorch.loader import ModelLoader, ModelVariant
507529

508530
variant = ModelVariant.MISTRAL_7B_INSTRUCT_V03
@@ -513,10 +535,11 @@ def test_mistral_7b(output_file, num_layers, request, accuracy_testing):
513535
num_layers=num_layers,
514536
request=request,
515537
accuracy_testing=accuracy_testing,
538+
batch_size=batch_size,
516539
)
517540

518541

519-
def test_ministral_8b(output_file, num_layers, request, accuracy_testing):
542+
def test_ministral_8b(output_file, num_layers, request, accuracy_testing, batch_size):
520543
from third_party.tt_forge_models.mistral.pytorch.loader import ModelLoader, ModelVariant
521544

522545
variant = ModelVariant.MINISTRAL_8B
@@ -528,10 +551,11 @@ def test_ministral_8b(output_file, num_layers, request, accuracy_testing):
528551
request=request,
529552
fp32_dest_acc_en=False,
530553
accuracy_testing=accuracy_testing,
554+
batch_size=batch_size,
531555
)
532556

533557

534-
def test_llama_3_1_8b(output_file, num_layers, request, accuracy_testing):
558+
def test_llama_3_1_8b(output_file, num_layers, request, accuracy_testing, batch_size):
535559
from third_party.tt_forge_models.llama.causal_lm.pytorch.loader import ModelLoader, ModelVariant
536560

537561
variant = ModelVariant.LLAMA_3_1_8B_INSTRUCT
@@ -543,6 +567,7 @@ def test_llama_3_1_8b(output_file, num_layers, request, accuracy_testing):
543567
request=request,
544568
fp32_dest_acc_en=False,
545569
accuracy_testing=accuracy_testing,
570+
batch_size=batch_size,
546571
)
547572

548573

0 commit comments

Comments
 (0)