|
# Command-line arguments for the perplexity-evaluation test script.
# (!!!) NOTE: These go on top of the engine arguments that can be found in `model_init.py` (!!!)

parser = argparse.ArgumentParser(description = "Test inference on ExLlamaV2 model")

# Dataset and sampling controls for the perplexity evaluation.
parser.add_argument("-ed", "--eval_dataset", type = str, help = "Perplexity evaluation dataset (.parquet file)")
# Default is None (not 128) so downstream code can detect "not set" and
# substitute its own default (e.g. `args.eval_rows or 128`) or read all rows.
parser.add_argument("-er", "--eval_rows", type = int, default = None, help = "Number of rows to apply from dataset (default 128)")
parser.add_argument("-el", "--eval_length", type = int, default = 2048, help = "Max no. tokens per sample")
# Token-by-token evaluation modes: exercise the KV cache during inference.
parser.add_argument("-et", "--eval_token", action = "store_true", help = "Evaluate perplexity on token-by-token inference using cache")
parser.add_argument("-e8", "--eval_token_8bit", action = "store_true", help = "Evaluate perplexity on token-by-token inference using 8-bit (FP8) cache")
|
|
267 | 267 | seqs.append(eval_tokens[:, a:b])
|
268 | 268 | eval_len.append(b if a == 0 else stride)
|
269 | 269 | a += stride
|
| 270 | + if args.eval_rows and len(seqs) >= args.eval_rows: |
| 271 | + break |
270 | 272 |
|
271 | 273 | eval_tokens = torch.cat(seqs, dim = 0)
|
272 | 274 |
|
273 | 275 | else:
|
274 | 276 |
|
275 | 277 | eval_dataset = args.eval_dataset
|
276 |
| - eval_rows = args.eval_rows |
| 278 | + eval_rows = args.eval_rows or 128 |
277 | 279 | eval_length = args.eval_length
|
278 | 280 |
|
279 | 281 | print(f" -- Dataset: {eval_dataset}")
|
|
0 commit comments