Skip to content

Commit 20b2013

Browse files
committed
include patch_diffusers=True when using --patch
1 parent 4dd2258 commit 20b2013

File tree

4 files changed

+26
-9
lines changed

4 files changed

+26
-9
lines changed

onnx_diagnostic/torch_models/hghub/hub_api.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,18 @@ def _guess_task_from_config(config: Any) -> Optional[str]:
140140

141141
@functools.cache
142142
def task_from_arch(
143-
arch: str, default_value: Optional[str] = None, model_id: Optional[str] = None
143+
arch: str,
144+
default_value: Optional[str] = None,
145+
model_id: Optional[str] = None,
146+
subfolder: Optional[str] = None,
144147
) -> str:
145148
"""
146149
This function relies on stored information. That information needs to be refreshed.
147150
148151
:param arch: architecture name
149152
:param default_value: default value in case the task cannot be determined
150153
:param model_id: unused unless the architecture does not help.
154+
:param subfolder: subfolder
151155
:return: task
152156
153157
.. runpython::
@@ -162,7 +166,7 @@ def task_from_arch(
162166
data = load_architecture_task()
163167
if arch not in data and model_id:
164168
# Let's try with the model id.
165-
return task_from_id(model_id)
169+
return task_from_id(model_id, subfolder=subfolder)
166170
if default_value is not None:
167171
return data.get(arch, default_value)
168172
assert arch in data, (
@@ -178,6 +182,7 @@ def task_from_id(
178182
default_value: Optional[str] = None,
179183
pretrained: bool = False,
180184
fall_back_to_pretrained: bool = True,
185+
subfolder: Optional[str] = None,
181186
) -> str:
182187
"""
183188
Returns the task attached to a model id.
@@ -187,7 +192,7 @@ def task_from_id(
187192
if the task cannot be determined
188193
:param pretrained: uses the config
189194
:param fall_back_to_pretrained: falls back to pretrained config
190-
:param exc: raises an exception if True
195+
:param subfolder: subfolder
191196
:return: task
192197
"""
193198
if not pretrained:
@@ -196,7 +201,7 @@ def task_from_id(
196201
except RuntimeError:
197202
if not fall_back_to_pretrained:
198203
raise
199-
config = get_pretrained_config(model_id)
204+
config = get_pretrained_config(model_id, subfolder=subfolder)
200205
try:
201206
return config.pipeline_tag
202207
except AttributeError:
@@ -206,6 +211,8 @@ def task_from_id(
206211
data = load_architecture_task()
207212
if model_id in data:
208213
return data[model_id]
214+
if type(config) is dict and "_class_name" in config:
215+
return task_from_arch(config["_class_name"], default_value=default_value)
209216
if not config.architectures or not config.architectures:
210217
# Some hardcoded values until a better solution is found.
211218
if model_id.startswith("google/bert_"):

onnx_diagnostic/torch_models/hghub/hub_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
BlenderbotModel,feature-extraction
2323
BloomModel,feature-extraction
2424
CLIPModel,zero-shot-image-classification
25+
CLIPTextModel,feature-extraction
2526
CLIPVisionModel,feature-extraction
2627
CamembertModel,feature-extraction
2728
CodeGenModel,feature-extraction

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def get_untrained_model_with_inputs(
106106
print(f"[get_untrained_model_with_inputs] architectures={archs!r}")
107107
print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}")
108108
if task is None:
109-
task = task_from_arch(archs[0], model_id=model_id)
109+
task = task_from_arch(archs[0], model_id=model_id, subfolder=subfolder)
110110
if verbose:
111111
print(f"[get_untrained_model_with_inputs] task={task!r}")
112112

onnx_diagnostic/torch_models/validate.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def validate_model(
263263
use_pretrained: bool = False,
264264
optimization: Optional[str] = None,
265265
quiet: bool = False,
266-
patch: bool = False,
266+
patch: Union[bool, str, Dict[str, bool]] = False,
267267
rewrite: bool = False,
268268
stop_if_static: int = 1,
269269
dump_folder: Optional[str] = None,
@@ -301,8 +301,10 @@ def validate_model(
301301
:param optimization: optimization to apply to the exported model,
302302
depend on the the exporter
303303
:param quiet: if quiet, catches exception if any issue
304-
:param patch: applies patches (``patch_transformers=True``) before exporting,
305-
see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
304+
:param patch: applies patches (``patch_transformers=True, patch_diffusers=True``)
305+
if True before exporting
306+
see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`,
307+
a string can be used to specify only one of them
306308
:param rewrite: applies known rewriting (``patch_transformers=True``) before exporting,
307309
see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
308310
:param stop_if_static: stops if a dynamic dimension becomes static,
@@ -346,6 +348,13 @@ def validate_model(
346348
exported model returns the same outputs as the original one, otherwise,
347349
:class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used.
348350
"""
351+
if isinstance(patch, bool):
352+
patch = dict(patch_transformers=True, patch_diffusers=True) if patch else {}
353+
elif isinstance(patch, str):
354+
patch = {p: True for p in patch.split(",")} # noqa: C420
355+
else:
356+
assert isinstance(patch, dict), f"Unable to interpret patch={patch!r}"
357+
349358
assert not rewrite or patch, (
350359
f"rewrite={rewrite}, patch={patch}, patch must be True to enable rewriting, "
351360
f"if --no-patch was specified on the command line, --no-rewrite must be added."
@@ -580,11 +589,11 @@ def validate_model(
580589
f"stop_if_static={stop_if_static}"
581590
)
582591
with torch_export_patches( # type: ignore
583-
patch_transformers=True,
584592
stop_if_static=stop_if_static,
585593
verbose=max(0, verbose - 1),
586594
rewrite=data.get("rewrite", None),
587595
dump_rewriting=(os.path.join(dump_folder, "rewrite") if dump_folder else None),
596+
**patch,
588597
) as modificator:
589598
data["inputs_export"] = modificator(data["inputs"]) # type: ignore
590599

0 commit comments

Comments
 (0)