Skip to content

Commit c8fa853

Browse files
committed
Test script: Allow --eval_rows in wiki2 ppl test
1 parent 318435d commit c8fa853

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

test_inference.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
# (!!!) NOTE: These go on top of the engine arguments that can be found in `model_init.py` (!!!)
4545
parser = argparse.ArgumentParser(description = "Test inference on ExLlamaV2 model")
4646
parser.add_argument("-ed", "--eval_dataset", type = str, help = "Perplexity evaluation dataset (.parquet file)")
47-
parser.add_argument("-er", "--eval_rows", type = int, default = 128, help = "Number of rows to apply from dataset")
47+
parser.add_argument("-er", "--eval_rows", type = int, default = None, help = "Number of rows to apply from dataset (default 128)")
4848
parser.add_argument("-el", "--eval_length", type = int, default = 2048, help = "Max no. tokens per sample")
4949
parser.add_argument("-et", "--eval_token", action = "store_true", help = "Evaluate perplexity on token-by-token inference using cache")
5050
parser.add_argument("-e8", "--eval_token_8bit", action = "store_true", help = "Evaluate perplexity on token-by-token inference using 8-bit (FP8) cache")
@@ -267,13 +267,15 @@
267267
seqs.append(eval_tokens[:, a:b])
268268
eval_len.append(b if a == 0 else stride)
269269
a += stride
270+
if args.eval_rows and len(seqs) >= args.eval_rows:
271+
break
270272

271273
eval_tokens = torch.cat(seqs, dim = 0)
272274

273275
else:
274276

275277
eval_dataset = args.eval_dataset
276-
eval_rows = args.eval_rows
278+
eval_rows = args.eval_rows or 128
277279
eval_length = args.eval_length
278280

279281
print(f" -- Dataset: {eval_dataset}")

0 commit comments

Comments
 (0)