Skip to content

Commit 3835256

Browse files
committed
- Remove batch_size=1 override from all accuracy test functions
- Update teacher forcing to replicate ground truth across batch dimension
- Add --accuracy-testing pytest option (usage: --accuracy-testing true)
- Remove default values from accuracy_testing parameters to allow fixture injection
1 parent a5c1547 commit 3835256

File tree

3 files changed

+40
-107
lines changed

3 files changed

+40
-107
lines changed

benchmark/tt-xla/conftest.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,13 @@ def pytest_addoption(parser):
162162
type=make_validator_boolean("--experimental-compile"),
163163
help="Enable experimental compile flag (true/false). Overrides config value.",
164164
)
165+
parser.addoption(
166+
"--accuracy-testing",
167+
action="store",
168+
default=None,
169+
type=make_validator_boolean("--accuracy-testing"),
170+
help="Enable accuracy testing mode (true/false). Uses reference data for TOP1/TOP5 accuracy.",
171+
)
165172

166173

167174
@pytest.fixture
@@ -217,3 +224,9 @@ def task(request):
217224
@pytest.fixture
218225
def experimental_compile(request):
219226
return request.config.getoption("--experimental-compile")
227+
228+
229+
@pytest.fixture
230+
def accuracy_testing(request):
231+
value = request.config.getoption("--accuracy-testing")
232+
return value if value is not None else False

benchmark/tt-xla/llm_benchmark.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def generate_and_benchmark(
226226
logits = read_logits_fn(output).to("cpu")
227227
output_logits.append(logits)
228228
next_token_ids = logits[:, -1].argmax(dim=-1)
229-
predicted_token = next_token_ids[0].item() # Assuming batch_size=1
229+
predicted_token = next_token_ids[0].item() # Extract from batch[0] (all items identical in accuracy mode)
230230
predicted_tokens.append(predicted_token)
231231

232232
output_text = [tokenizer.decode(token_id) for token_id in next_token_ids]
@@ -245,8 +245,12 @@ def generate_and_benchmark(
245245
# Update inputs for next iteration
246246
if ground_truth_tokens is not None:
247247
# Teacher forcing: use ground truth token as next input
248+
# Replicate ground truth token for all batch items (they're all identical)
249+
batch_size = input_args["input_ids"].shape[0]
248250
gt_token = ground_truth_tokens[step]
249-
input_args["input_ids"] = gt_token.unsqueeze(0).unsqueeze(0).to(device) # Shape: [1, 1]
251+
input_args["input_ids"] = (
252+
gt_token.unsqueeze(0).unsqueeze(0).expand(batch_size, 1).to(device)
253+
) # Shape: [batch_size, 1]
250254
else:
251255
# Standard generation: use predicted token as next input
252256
input_args["input_ids"] = next_token_ids.unsqueeze(-1).to(device)

0 commit comments

Comments (0)