|
9 | 9 |
|
10 | 10 | import os
|
11 | 11 | from datetime import datetime
|
12 |
| -from typing import TYPE_CHECKING, List, Optional, Union |
| 12 | +from typing import TYPE_CHECKING, Dict, List, Optional, Union |
13 | 13 |
|
14 | 14 | from loguru import logger
|
15 | 15 | from torch.utils.data import DataLoader
|
@@ -230,6 +230,7 @@ def oneshot(
|
230 | 230 | dataset: Optional[Union[str, "Dataset", "DatasetDict"]] = None,
|
231 | 231 | dataset_config_name: Optional[str] = None,
|
232 | 232 | dataset_path: Optional[str] = None,
|
| 233 | + splits: Optional[Union[str, List, Dict]] = None, |
233 | 234 | num_calibration_samples: int = 512,
|
234 | 235 | shuffle_calibration_samples: bool = True,
|
235 | 236 | max_seq_length: int = 384,
|
@@ -288,6 +289,7 @@ def oneshot(
|
288 | 289 | :param dataset_config_name: The configuration name of the dataset
|
289 | 290 | to use.
|
290 | 291 | :param dataset_path: Path to a custom dataset. Supports json, csv, dvc.
|
| 292 | + :param splits: Optional percentages of each split to download. |
291 | 293 | :param num_calibration_samples: Number of samples to use for one-shot
|
292 | 294 | calibration.
|
293 | 295 | :param shuffle_calibration_samples: Whether to shuffle the dataset before
|
|
0 commit comments