 import torch
 import uvloop
-from benchmark_dataset import (BurstGPTDataset, ConversationDataset,
-                               InstructCoderDataset, RandomDataset,
-                               SampleRequest, ShareGPTDataset, SonnetDataset,
-                               VisionArenaDataset)
+from benchmark_dataset import (AIMODataset, BurstGPTDataset,
+                               ConversationDataset, InstructCoderDataset,
+                               RandomDataset, SampleRequest, ShareGPTDataset,
+                               SonnetDataset, VisionArenaDataset)
 from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
 from tqdm import tqdm
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
@@ -332,7 +332,10 @@ def get_requests(args, tokenizer):
             common_kwargs['dataset_subset'] = args.hf_subset
             common_kwargs['dataset_split'] = args.hf_split
             sample_kwargs["enable_multimodal_chat"] = True
-
+        elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
+            dataset_cls = AIMODataset
+            common_kwargs['dataset_subset'] = None
+            common_kwargs['dataset_split'] = "train"
     else:
         raise ValueError(f"Unknown dataset name: {args.dataset_name}")
     # Remove None values
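
The new branch above follows the same pattern as the existing hf-dataset branches: `get_requests` matches `--dataset-path` against each dataset class's `SUPPORTED_DATASET_PATHS` and, on a hit, selects that class and fills in its subset/split defaults. Below is a minimal sketch of the pattern this dispatch assumes; the class name, path, and `sample()` body are illustrative placeholders, not the real `benchmark_dataset.py` definitions.

```python
# Sketch of the dataset-class pattern that get_requests relies on.
# ExampleMathDataset and its path are hypothetical placeholders, not the
# actual AIMODataset definition in benchmark_dataset.py.
class ExampleMathDataset:
    SUPPORTED_DATASET_PATHS = {"example-org/example-math-dataset"}

    def __init__(self, dataset_path, dataset_subset=None,
                 dataset_split="train", **kwargs):
        self.dataset_path = dataset_path
        self.dataset_subset = dataset_subset
        self.dataset_split = dataset_split

    def sample(self, tokenizer=None, num_requests=0, **sample_kwargs):
        # The real classes return a list of SampleRequest objects here.
        return []


def pick_dataset_cls(dataset_path):
    # Mirrors the elif chain in get_requests: membership in
    # SUPPORTED_DATASET_PATHS decides which class handles the path.
    if dataset_path in ExampleMathDataset.SUPPORTED_DATASET_PATHS:
        return ExampleMathDataset
    raise ValueError(f"Unsupported dataset path: {dataset_path}")
```
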
@@ -467,12 +470,13 @@ def validate_args(args):
                 since --dataset-name is not 'hf'.",
                           stacklevel=2)
     elif args.dataset_name == "hf":
-        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
-            assert args.backend == "vllm-chat", "VisionArenaDataset needs to use vllm-chat as the backend."  #noqa: E501
-        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
-            assert args.backend == "vllm", "InstructCoder dataset needs to use vllm as the backend."  #noqa: E501
-        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
-            assert args.backend == "vllm-chat", "ConversationDataset needs to use vllm-chat as the backend."  #noqa: E501
+        if args.dataset_path in (
+                VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
+                | ConversationDataset.SUPPORTED_DATASET_PATHS):
+            assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend."  #noqa: E501
+        elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS
+                                   | AIMODataset.SUPPORTED_DATASET_PATHS):
+            assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend."  #noqa: E501
         else:
             raise ValueError(
                 f"{args.dataset_path} is not supported by hf dataset.")
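
One note on the rewritten check: the diff calls `.keys()` only on `VisionArenaDataset.SUPPORTED_DATASET_PATHS`, which suggests it is a mapping while the other classes expose plain sets; dict key views are set-like, so they union cleanly with `|` for a single membership test. A small standalone illustration of the idiom, with hypothetical dataset paths:

```python
# Standalone illustration of the membership idiom used in validate_args.
# The paths below are hypothetical placeholders.
vision_paths = {"example-org/vision-bench": None}        # mapping, like VisionArenaDataset
conversation_paths = {"example-org/conversation-bench"}  # plain set

dataset_path = "example-org/conversation-bench"

# dict.keys() returns a set-like view, so it can be unioned with a set
# and the result tested with `in` in one expression.
if dataset_path in (vision_paths.keys() | conversation_paths):
    backend = "vllm-chat"
    print(f"{dataset_path} -> {backend}")
```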