File tree (3 files changed: +40 −107 lines)
lines changed Original file line number Diff line number Diff line change @@ -162,6 +162,13 @@ def pytest_addoption(parser):
162162 type = make_validator_boolean ("--experimental-compile" ),
163163 help = "Enable experimental compile flag (true/false). Overrides config value." ,
164164 )
165+ parser .addoption (
166+ "--accuracy-testing" ,
167+ action = "store" ,
168+ default = None ,
169+ type = make_validator_boolean ("--accuracy-testing" ),
170+ help = "Enable accuracy testing mode (true/false). Uses reference data for TOP1/TOP5 accuracy." ,
171+ )
165172
166173
167174@pytest .fixture
@@ -217,3 +224,9 @@ def task(request):
@pytest.fixture
def experimental_compile(request):
    """Expose the value of the --experimental-compile CLI option to tests."""
    option_value = request.config.getoption("--experimental-compile")
    return option_value
227+
228+
@pytest.fixture
def accuracy_testing(request):
    """Report whether accuracy-testing mode was requested via --accuracy-testing.

    Returns False when the option was not supplied (its default is None),
    otherwise returns the parsed boolean value.
    """
    flag = request.config.getoption("--accuracy-testing")
    if flag is None:
        return False
    return flag
Original file line number Diff line number Diff line change @@ -226,7 +226,7 @@ def generate_and_benchmark(
226226 logits = read_logits_fn (output ).to ("cpu" )
227227 output_logits .append (logits )
228228 next_token_ids = logits [:, - 1 ].argmax (dim = - 1 )
229- predicted_token = next_token_ids [0 ].item () # Assuming batch_size=1
229+ predicted_token = next_token_ids [0 ].item () # Extract from batch[0] (all items identical in accuracy mode)
230230 predicted_tokens .append (predicted_token )
231231
232232 output_text = [tokenizer .decode (token_id ) for token_id in next_token_ids ]
@@ -245,8 +245,12 @@ def generate_and_benchmark(
245245 # Update inputs for next iteration
246246 if ground_truth_tokens is not None :
247247 # Teacher forcing: use ground truth token as next input
248+ # Replicate ground truth token for all batch items (they're all identical)
249+ batch_size = input_args ["input_ids" ].shape [0 ]
248250 gt_token = ground_truth_tokens [step ]
249- input_args ["input_ids" ] = gt_token .unsqueeze (0 ).unsqueeze (0 ).to (device ) # Shape: [1, 1]
251+ input_args ["input_ids" ] = (
252+ gt_token .unsqueeze (0 ).unsqueeze (0 ).expand (batch_size , 1 ).to (device )
253+ ) # Shape: [batch_size, 1]
250254 else :
251255 # Standard generation: use predicted token as next input
252256 input_args ["input_ids" ] = next_token_ids .unsqueeze (- 1 ).to (device )
You can’t perform that action at this time.
0 commit comments