Skip to content

Commit 660512a

Browse files
committed
v0.37.0
See https://github.com/quic/ai-hub-models/releases/v0.37.0 for changelog. Signed-off-by: QAIHM Team <[email protected] >
1 parent 8cdeb11 commit 660512a

File tree

528 files changed

+30008
-20068
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

528 files changed

+30008
-20068
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ and many more.
171171
| [ResNet101](https://aihub.qualcomm.com/models/resnet101) | [qai_hub_models.models.resnet101](qai_hub_models/models/resnet101/README.md) |
172172
| [ResNet18](https://aihub.qualcomm.com/models/resnet18) | [qai_hub_models.models.resnet18](qai_hub_models/models/resnet18/README.md) |
173173
| [ResNet50](https://aihub.qualcomm.com/models/resnet50) | [qai_hub_models.models.resnet50](qai_hub_models/models/resnet50/README.md) |
174+
| [Sequencer2D](https://aihub.qualcomm.com/models/sequencer2d) | [qai_hub_models.models.sequencer2d](qai_hub_models/models/sequencer2d/README.md) |
174175
| [Shufflenet-v2](https://aihub.qualcomm.com/models/shufflenet_v2) | [qai_hub_models.models.shufflenet_v2](qai_hub_models/models/shufflenet_v2/README.md) |
175176
| [SqueezeNet-1.1](https://aihub.qualcomm.com/models/squeezenet1_1) | [qai_hub_models.models.squeezenet1_1](qai_hub_models/models/squeezenet1_1/README.md) |
176177
| [Swin-Base](https://aihub.qualcomm.com/models/swin_base) | [qai_hub_models.models.swin_base](qai_hub_models/models/swin_base/README.md) |
@@ -310,6 +311,7 @@ and many more.
310311
| **Text Generation**
311312
| [ALLaM-7B](https://aihub.qualcomm.com/models/allam_7b) | [qai_hub_models.models.allam_7b](qai_hub_models/models/allam_7b/README.md) |
312313
| [Baichuan2-7B](https://aihub.qualcomm.com/models/baichuan2_7b) | [qai_hub_models.models.baichuan2_7b](qai_hub_models/models/baichuan2_7b/README.md) |
314+
| [Falcon3-7B-Instruct](https://aihub.qualcomm.com/models/falcon_v3_7b_instruct) | [qai_hub_models.models.falcon_v3_7b_instruct](qai_hub_models/models/falcon_v3_7b_instruct/README.md) |
313315
| [IBM-Granite-v3.1-8B-Instruct](https://aihub.qualcomm.com/models/ibm_granite_v3_1_8b_instruct) | [qai_hub_models.models.ibm_granite_v3_1_8b_instruct](qai_hub_models/models/ibm_granite_v3_1_8b_instruct/README.md) |
314316
| [IndusQ-1.1B](https://aihub.qualcomm.com/models/indus_1b) | [qai_hub_models.models.indus_1b](qai_hub_models/models/indus_1b/README.md) |
315317
| [JAIS-6p7b-Chat](https://aihub.qualcomm.com/models/jais_6p7b_chat) | [qai_hub_models.models.jais_6p7b_chat](qai_hub_models/models/jais_6p7b_chat/README.md) |

qai_hub_models/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
# SPDX-License-Identifier: BSD-3-Clause
44
# ---------------------------------------------------------------------
55

6-
__version__ = "0.36.0"
6+
__version__ = "0.37.0"

qai_hub_models/configs/_info_yaml_enums.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ class MODEL_LICENSE(Enum):
115115
LLAMA2 = "llama2"
116116
LLAMA3 = "llama3"
117117
TAIDE = "taide"
118+
FALCON3 = "falcon3"
118119

119120
@property
120121
def is_copyleft(self) -> bool:
@@ -127,6 +128,7 @@ def is_copyleft(self) -> bool:
127128
MODEL_LICENSE.LLAMA2,
128129
MODEL_LICENSE.LLAMA3,
129130
MODEL_LICENSE.TAIDE,
131+
MODEL_LICENSE.FALCON3,
130132
]
131133

132134
@property
@@ -169,6 +171,8 @@ def url(self) -> str | None:
169171
return "https://github.com/facebookresearch/llama/blob/main/LICENSE"
170172
elif self == MODEL_LICENSE.TAIDE:
171173
return "https://en.taide.tw/download.html"
174+
elif self == MODEL_LICENSE.FALCON3:
175+
return "https://falconllm.tii.ae/falcon-terms-and-conditions.html"
172176
return None
173177

174178

@@ -206,6 +210,8 @@ class MODEL_USE_CASE(Enum):
206210
SUPER_RESOLUTION = "Super Resolution"
207211
SEMANTIC_SEGMENTATION = "Semantic Segmentation"
208212
DEPTH_ESTIMATION = "Depth Estimation"
213+
GAZE_ESTIMATION = "Gaze Estimation"
214+
209215
# Ex: OCR, image caption
210216
IMAGE_TO_TEXT = "Image To Text"
211217
OBJECT_DETECTION = "Object Detection"
@@ -234,6 +240,8 @@ def map_to_hf_pipeline_tag(self):
234240
return "image-segmentation"
235241
if self.name == "POSE_ESTIMATION":
236242
return "keypoint-detection"
243+
if self.name == "GAZE_ESTIMATION":
244+
return "gaze-estimation"
237245
if self.name == "AUDIO_ENHANCEMENT":
238246
return "audio-to-audio"
239247
if self.name == "VIDEO_GENERATION":

qai_hub_models/configs/code_gen_yaml.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ class QAIHMModelCodeGen(BaseQAIHMConfig):
9494
# ("AOT prepare") are enabled, both in CI and in Scorecard.
9595
requires_aot_prepare: bool = False
9696

97+
# Supported GenAI based runtimes.
98+
# If set, ONLY these runtimes will be supported. All others will be disabled.
99+
supported_genai_runtimes: list[TargetRuntime] = Field(default_factory=list)
100+
97101
# If set, disables generating `export.py`.
98102
skip_export: bool = False
99103

@@ -171,6 +175,12 @@ def failure_reason(
171175
"""
172176
Return the reason a model failed or None if the model did not fail.
173177
"""
178+
if self.supported_genai_runtimes:
179+
if runtime not in self.supported_genai_runtimes:
180+
return f"{runtime} is not supported for this GenAI model."
181+
elif runtime.is_exclusively_for_genai:
182+
return "GenAI runtimes are not supported by this model."
183+
174184
if self.is_precompiled and runtime != TargetRuntime.QNN_CONTEXT_BINARY:
175185
return "Precompiled models are only supported via the QNN path."
176186

@@ -180,7 +190,11 @@ def failure_reason(
180190
if self.requires_aot_prepare and not runtime.is_aot_compiled:
181191
return "Only runtimes that are compiled to context binary ahead of time are supported."
182192

183-
if not self.requires_aot_prepare and runtime.is_aot_compiled:
193+
if (
194+
not self.requires_aot_prepare
195+
and runtime.is_aot_compiled
196+
and not runtime.is_exclusively_for_genai
197+
):
184198
# Only the JIT path is tested if this model does not require AOT prepare.
185199
# All AOT paths will fail if QNN fails.
186200
runtime = TargetRuntime.QNN_DLC
@@ -281,6 +295,11 @@ def check_fields(self) -> QAIHMModelCodeGen:
281295
raise ValueError(
282296
"If pip_pre_build_reqs is set, global_requirements_incompatible must also be true."
283297
)
298+
for x in self.supported_genai_runtimes:
299+
if not x.is_exclusively_for_genai:
300+
raise ValueError(
301+
f"{x.value} is not a GenAI runtime, and should not be listed in supported_genai_runtimes."
302+
)
284303

285304
return self
286305

qai_hub_models/configs/tool_versions.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,7 @@
88
from qai_hub.client import JobType
99
from qai_hub.public_rest_api import get_job_results
1010

11-
from qai_hub_models.models.common import (
12-
InferenceEngine,
13-
Optional,
14-
QAIRTVersion,
15-
TargetRuntime,
16-
)
11+
from qai_hub_models.models.common import Optional, QAIRTVersion, TargetRuntime
1712
from qai_hub_models.utils.base_config import BaseQAIHMConfig
1813
from qai_hub_models.utils.qai_hub_helpers import extract_job_options
1914

@@ -51,7 +46,10 @@ def from_compiled_model(
5146
Raises:
5247
ValueError if the model was not compiled by AI Hub.
5348
"""
54-
if model.producer is None or not model.producer._job_type == JobType.COMPILE:
49+
if model.producer is None or model.producer._job_type not in [
50+
JobType.COMPILE,
51+
JobType.LINK,
52+
]:
5553
raise ValueError(
5654
"Model must be compiled with AI Hub to extract tool versions."
5755
)
@@ -117,12 +115,26 @@ def from_job(job: hub.Job, parse_version_tags: bool = False) -> "ToolVersions":
117115
ValueError if the job type is invalid.
118116
"""
119117
# Use job_type instead of isinstance to support test mocking.
120-
if job._job_type not in [JobType.COMPILE, JobType.PROFILE, JobType.INFERENCE]:
118+
if job._job_type not in [
119+
JobType.COMPILE,
120+
JobType.LINK,
121+
JobType.PROFILE,
122+
JobType.INFERENCE,
123+
]:
121124
raise ValueError(
122125
f"Cannot extract QAIRT SDK version from job of type {job.job_type}"
123126
)
124127

125128
if not job.get_status().success:
129+
if job._job_type == JobType.LINK:
130+
# Link jobs inherit their QAIRT version from input model files.
131+
models = cast(hub.LinkJob, job).models
132+
for model in models:
133+
if model.producer is not None:
134+
return ToolVersions.from_compiled_model(model)
135+
# None of the source models came from us, so we can't detect what QAIRT version to use.
136+
return ToolVersions()
137+
126138
# If the job is not successful, the only way to get the QAIRT version is to look at the job flags.
127139
job_options = extract_job_options(job)
128140
version: Optional[str] = None
@@ -135,10 +147,7 @@ def from_job(job: hub.Job, parse_version_tags: bool = False) -> "ToolVersions":
135147
# QAIRT is applicable for compile jobs only if the target runtime uses QAIRT converters.
136148
if x := job_options.get("target_runtime"):
137149
rts = [rt for rt in TargetRuntime if rt.value == x]
138-
if (
139-
len(rts) == 1
140-
and rts[0].inference_engine == InferenceEngine.QNN
141-
):
150+
if len(rts) == 1 and rts[0].qairt_version_changes_compilation:
142151
version = "default"
143152
else:
144153
version = "default"
@@ -157,7 +166,7 @@ def from_job(job: hub.Job, parse_version_tags: bool = False) -> "ToolVersions":
157166
qairt=QAIRTVersion(version, validate_exists_on_ai_hub=False)
158167
)
159168

160-
if job._job_type == JobType.COMPILE:
169+
if job._job_type == JobType.COMPILE or job._job_type == JobType.LINK:
161170
return ToolVersions.from_compiled_model(
162171
cast(hub.Model, cast(hub.CompileJob, job).get_target_model())
163172
)

qai_hub_models/datasets/__init__.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,25 +20,26 @@
2020
# We don't want to require a user to install requirements for all datasets just to
2121
# import the datasets folder. Therefore we only include the datasets that can
2222
# be imported.
23-
def _try_import_dataset(module_name: str, cls: str):
23+
def _try_import_dataset(module_name: str, cls: str, name: str | None = None):
2424
"""
2525
Import the dataset and add it to the DATASET_NAME_MAP, or pass
2626
if dependencies for the dataset aren't installed.
2727
"""
28+
if name is None:
29+
assert module_name[0] == "."
30+
name = module_name[1:]
2831
try:
2932
module = importlib.import_module(module_name, package="qai_hub_models.datasets")
3033
except NotImplementedError as e:
3134
if "AIMET-ONNX" in str(e):
3235
# stable diffusion dataset requires aimet-onnx
33-
_ALL_DATASETS_IMPORT_ERRORS[module_name] = e
36+
_ALL_DATASETS_IMPORT_ERRORS[name] = e
3437
return
3538
raise e
3639
except Exception as e:
37-
if module_name.startswith("."):
38-
module_name = module_name[1:]
3940
if (
4041
isinstance(e, ModuleNotFoundError)
41-
and str(e) == f"No module named 'qai_hub_models.datasets.{module_name}"
42+
and str(e) == f"No module named 'qai_hub_models.datasets{module_name}"
4243
):
4344
# this module legitimately does not exist
4445
raise e
@@ -48,12 +49,15 @@ def _try_import_dataset(module_name: str, cls: str):
4849
# By default, the name of the dataset is the name of its module.
4950
# We add it to this import errors list to hopefully raise the error
5051
# at a later time (when the user requests this dataset).
51-
_ALL_DATASETS_IMPORT_ERRORS[module_name] = e
52+
_ALL_DATASETS_IMPORT_ERRORS[name] = e
5253
return
5354

5455
if x := getattr(module, cls, None):
5556
xds = cast(type[BaseDataset], x)
56-
DATASET_NAME_MAP[xds.dataset_name()] = xds
57+
assert (
58+
name == xds.dataset_name()
59+
), f"Name is not consistent with call to dataset_name(): {name} vs. {xds.dataset_name()}"
60+
DATASET_NAME_MAP[name] = xds
5761
else:
5862
raise ValueError(
5963
f"Could not import {cls}. {cls} was not found in {module_name}"
@@ -70,7 +74,7 @@ def _try_import_dataset(module_name: str, cls: str):
7074
_try_import_dataset(".coco91class", "Coco91ClassDataset")
7175
_try_import_dataset(".coco_face", "CocoFaceDataset")
7276
_try_import_dataset(".human_faces", "HumanFacesDataset")
73-
_try_import_dataset(".human_faces", "HumanFaces192Dataset")
77+
_try_import_dataset(".human_faces", "HumanFaces192Dataset", name="human_faces_192")
7478
_try_import_dataset(".coco_panoptic_seg", "CocoPanopticSegmentationDataset")
7579
_try_import_dataset(".foot_track_dataset", "FootTrackDataset")
7680
_try_import_dataset(".gear_guard_dataset", "GearGuardDataset")
@@ -98,16 +102,47 @@ def _try_import_dataset(module_name: str, cls: str):
98102
_try_import_dataset(".eg1800", "eg1800SegmentationDataset")
99103
_try_import_dataset(".kitti", "KittiDataset")
100104
_try_import_dataset(".semantic_kitti", "SemanticKittiDataset")
101-
_try_import_dataset(".stable_diffusion_calib", "StableDiffusionCalibDatasetTextEncoder")
102-
_try_import_dataset(".stable_diffusion_calib", "StableDiffusionCalibDatasetUnet")
103-
_try_import_dataset(".stable_diffusion_calib", "StableDiffusionCalibDatasetVae")
104-
_try_import_dataset(".stable_diffusion_calib", "StableDiffusionCalibDatasetControlNet")
105+
_try_import_dataset(
106+
".stable_diffusion_calib",
107+
"StableDiffusionCalibDatasetTextEncoder",
108+
name="stable_diffusion_calib_text_encoder",
109+
)
110+
_try_import_dataset(
111+
".stable_diffusion_calib",
112+
"StableDiffusionCalibDatasetUnet",
113+
name="stable_diffusion_calib_unet",
114+
)
115+
_try_import_dataset(
116+
".stable_diffusion_calib",
117+
"StableDiffusionCalibDatasetVae",
118+
name="stable_diffusion_calib_vae",
119+
)
120+
_try_import_dataset(
121+
".stable_diffusion_calib",
122+
"StableDiffusionCalibDatasetControlNet",
123+
name="stable_diffusion_calib_controlnet",
124+
)
105125
_try_import_dataset(".celebahq", "CelebAHQDataset")
106126
_try_import_dataset(".wikitext", "WikiText")
107127
_try_import_dataset(".wikitext_ja", "WikiText_Japanese")
108128
_try_import_dataset(".tiny_mmlu", "TinyMMLU")
109129
_try_import_dataset(".mmlu", "MMLU")
110130
_try_import_dataset(".mmmlu", "MMMLU")
131+
_try_import_dataset(".mmmlu", "MMMLU_AR", name="mmmlu_ar")
132+
_try_import_dataset(".mmmlu", "MMMLU_BN", name="mmmlu_bn")
133+
_try_import_dataset(".mmmlu", "MMMLU_DE", name="mmmlu_de")
134+
_try_import_dataset(".mmmlu", "MMMLU_ES", name="mmmlu_es")
135+
_try_import_dataset(".mmmlu", "MMMLU_FR", name="mmmlu_fr")
136+
_try_import_dataset(".mmmlu", "MMMLU_HI", name="mmmlu_hi")
137+
_try_import_dataset(".mmmlu", "MMMLU_ID", name="mmmlu_id")
138+
_try_import_dataset(".mmmlu", "MMMLU_IT", name="mmmlu_it")
139+
_try_import_dataset(".mmmlu", "MMMLU_JA", name="mmmlu_ja")
140+
_try_import_dataset(".mmmlu", "MMMLU_KO", name="mmmlu_ko")
141+
_try_import_dataset(".mmmlu", "MMMLU_PT", name="mmmlu_pt")
142+
_try_import_dataset(".mmmlu", "MMMLU_SW", name="mmmlu_sw")
143+
_try_import_dataset(".mmmlu", "MMMLU_YO", name="mmmlu_yo")
144+
_try_import_dataset(".mmmlu", "MMMLU_ZH", name="mmmlu_zh")
145+
_try_import_dataset(".mpiigaze", "MPIIGazeDataset")
111146
_try_import_dataset(".libri_speech", "LibriSpeechDataset")
112147
_try_import_dataset(
113148
".amazon_counterfactual", "AmazonCounterfactualClassificationDataset"

qai_hub_models/datasets/common.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@
1010
import shutil
1111
from abc import ABC, abstractmethod
1212
from collections.abc import Sized
13+
from copy import copy
1314
from enum import Enum, unique
1415
from functools import cached_property
1516
from pathlib import Path
16-
from typing import NamedTuple, final
17+
from typing import Any, NamedTuple, final
1718

18-
from torch.utils.data import Dataset
19+
from torch.utils.data import Dataset, default_collate
1920

2021
from qai_hub_models.utils.asset_loaders import LOCAL_STORE_DEFAULT_PATH
2122
from qai_hub_models.utils.input_spec import InputSpec
@@ -34,6 +35,30 @@ class DatasetSplit(Enum):
3435
TEST = 2
3536

3637

38+
class AugmentedLabelDataset(Dataset):
39+
"""
40+
Augment labels to a dataset (making the label a tuple, if labels are
41+
already present).
42+
"""
43+
44+
def __init__(self, base_dataset, extra_data):
45+
self.base_dataset = base_dataset
46+
self.extra_data = extra_data
47+
self.extra_len = len(extra_data)
48+
49+
def __len__(self):
50+
return len(self.base_dataset)
51+
52+
def __getitem__(self, idx):
53+
item = copy(self.base_dataset[idx])
54+
extra_item = self.extra_data[idx % self.extra_len]
55+
if "label" in item:
56+
item["label"] = (item["label"], extra_item)
57+
else:
58+
item["label"] = extra_item
59+
return item
60+
61+
3762
class DatasetMetadata(NamedTuple):
3863
"""Metadata about the dataset to publish on the website."""
3964

@@ -80,6 +105,13 @@ def __init__(
80105
self.input_spec = input_spec
81106
self.download_data()
82107

108+
@staticmethod
109+
def collate_fn(batch: Any) -> Any:
110+
"""
111+
To be passed into DataLoader(..., collate_fn=...).
112+
"""
113+
return default_collate(batch)
114+
83115
@final
84116
def download_data(self) -> None:
85117
if self._validate_data():

0 commit comments

Comments
 (0)