
Commit 0bbaf4d

Support unet model
1 parent 65815ac commit 0bbaf4d

File tree

8 files changed: 46 additions & 28 deletions


onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 1 addition & 1 deletion
@@ -341,7 +341,7 @@ def _generic_walker_step(
         if any(v is not None for v in value)
         else None
     )
-    assert type(inputs) is dict, f"Unexpected type for inputs {type(inputs)}"
+    assert isinstance(inputs, dict), f"Unexpected type for inputs {type(inputs)}"
     assert set(inputs) == set(ds), (
         f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}, "
         f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"

onnx_diagnostic/helpers/config_helper.py

Lines changed: 3 additions & 3 deletions
@@ -38,12 +38,12 @@ def update_config(config: Any, mkwargs: Dict[str, Any]):
                 setattr(config, k, v)
                 continue
             existing = getattr(config, k)
-            if type(existing) is dict:
+            if isinstance(existing, dict):
                 existing.update(v)
             else:
                 update_config(getattr(config, k), v)
             continue
-        if type(config) is dict:
+        if isinstance(config, dict):
             config[k] = v
         else:
             setattr(config, k, v)
@@ -76,7 +76,7 @@ def pick(config, name: str, default_value: Any) -> Any:
     """
     if not config:
         return default_value
-    if type(config) is dict:
+    if isinstance(config, dict):
         return config.get(name, default_value)
     return getattr(config, name, default_value)
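
With these changes, update_config and pick handle any mapping the same way as a plain dict. A rough usage sketch based only on the branches visible in this hunk (SimpleNamespace stands in for an object-style config, the keys are arbitrary, and the sketch assumes no special-casing of these keys inside update_config):

from types import SimpleNamespace

from onnx_diagnostic.helpers.config_helper import pick, update_config

obj_cfg = SimpleNamespace(num_hidden_layers=12)   # object-style config (transformers-like)
dict_cfg = {"sample_size": 64}                    # mapping-style config (diffusers-like)

update_config(obj_cfg, {"num_hidden_layers": 2})  # setattr branch
update_config(dict_cfg, {"sample_size": 32})      # config[k] = v branch

print(pick(obj_cfg, "num_hidden_layers", 0))   # 2
print(pick(dict_cfg, "sample_size", 0))        # 32
print(pick(dict_cfg, "missing_key", -1))       # -1, the default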

onnx_diagnostic/helpers/helper.py

Lines changed: 1 addition & 1 deletion
@@ -280,7 +280,7 @@ def string_type(
             print(f"[string_type] L:{type(obj)}")
         return f"{{...}}#{len(obj)}" if with_shape else "{...}"
     # dict
-    if isinstance(obj, dict) and type(obj) is dict:
+    if isinstance(obj, dict):
         if len(obj) == 0:
             if verbose:
                 print(f"[string_type] M:{type(obj)}")

onnx_diagnostic/helpers/mini_onnx_builder.py

Lines changed: 1 addition & 1 deletion
@@ -498,7 +498,7 @@ def _make(ty: type, res: Any) -> Any:
         for k, v in res:
             setattr(r, k, v)
         return r
-    if ty is dict:
+    if isinstance(res, dict):
         d = {}
         for k, v in res:
             if k.startswith("((") and k.endswith("))"):

onnx_diagnostic/helpers/torch_helper.py

Lines changed: 1 addition & 1 deletion
@@ -717,7 +717,7 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
         return tuple(to_any(t, to_value) for t in value)
     if isinstance(value, set):
         return {to_any(t, to_value) for t in value}
-    if type(value) is dict:
+    if isinstance(value, dict):
         return {k: to_any(t, to_value) for k, t in value.items()}
     if value.__class__.__name__ == "DynamicCache":
         return make_dynamic_cache(

onnx_diagnostic/torch_models/hghub/hub_api.py

Lines changed: 18 additions & 5 deletions
@@ -10,6 +10,7 @@
 from ...helpers.config_helper import update_config
 from . import hub_data_cached_configs
 from .hub_data import __date__, __data_tasks__, load_architecture_task, __data_arch_values__
+import diffusers
 
 
 @functools.cache
@@ -119,9 +120,21 @@ def get_pretrained_config(
         # Diffusers uses a dictionayr.
         with open(config, "r") as f:
             return json.load(f)
-    return transformers.AutoConfig.from_pretrained(
-        model_id, trust_remote_code=trust_remote_code, **kwargs
-    )
+    try:
+        config = transformers.AutoConfig.from_pretrained(
+            model_id, trust_remote_code=trust_remote_code, **kwargs
+        )
+    except ValueError:
+        # The model might be from diffusers, not transformers.
+        try:
+            pipe = diffusers.DiffusionPipeline.from_pretrained(
+                model_id, trust_remote_code=trust_remote_code, **kwargs
+            )
+            config = pipe.unet.config
+        except Exception as exc:
+            raise ValueError(f"Unable to retrieve the configuration for {model_id!r}") from exc
+
+    return config
 
 
 def get_model_info(model_id) -> Any:
@@ -211,7 +224,7 @@ def task_from_id(
     data = load_architecture_task()
     if model_id in data:
         return data[model_id]
-    if type(config) is dict and "_class_name" in config:
+    if isinstance(config, dict) and "_class_name" in config:
         return task_from_arch(config["_class_name"], default_value=default_value)
     if not config.architectures or not config.architectures:
         # Some hardcoded values until a better solution is found.
@@ -362,7 +375,7 @@ def download_code_modelid(
     paths = set()
     for i, name in enumerate(pyfiles):
         if verbose:
-            print(f"[download_code_modelid] download file {i+1}/{len(pyfiles)}: {name!r}")
+            print(f"[download_code_modelid] download file {i + 1}/{len(pyfiles)}: {name!r}")
         r = hf_hub_download(repo_id=model_id, filename=name)
         p = os.path.split(r)[0]
         paths.add(p)
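
The get_pretrained_config change is the core of the unet support: when transformers.AutoConfig raises ValueError, the function now loads the checkpoint as a diffusers pipeline and returns pipe.unet.config, a dict-like object (hence the isinstance(config, dict) checks elsewhere in this commit). A hedged usage sketch; the model ids below are only placeholders for a small transformers checkpoint and a small diffusers pipeline, and whether a given diffusers repo actually reaches the fallback depends on the exception AutoConfig raises for it:

from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config

# Transformers model: resolved by AutoConfig exactly as before.
cfg = get_pretrained_config("arnir0/Tiny-LLM")

# Diffusers pipeline: AutoConfig fails, the fallback returns the unet configuration.
unet_cfg = get_pretrained_config("hf-internal-testing/tiny-stable-diffusion-torch")
print(unet_cfg["in_channels"], unet_cfg["sample_size"])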

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 20 additions & 15 deletions
@@ -7,6 +7,7 @@
 from ...helpers.config_helper import update_config
 from ...tasks import reduce_model_config, random_input_kwargs
 from .hub_api import task_from_arch, task_from_id, get_pretrained_config, download_code_modelid
+import diffusers
 
 
 def _code_needing_rewriting(model: Any) -> Any:
@@ -18,7 +19,7 @@ def _code_needing_rewriting(model: Any) -> Any:
 def get_untrained_model_with_inputs(
     model_id: str,
     config: Optional[Any] = None,
-    task: Optional[str] = "",
+    task: Optional[str] = None,
     inputs_kwargs: Optional[Dict[str, Any]] = None,
     model_kwargs: Optional[Dict[str, Any]] = None,
     verbose: int = 0,
@@ -88,14 +89,20 @@ def get_untrained_model_with_inputs(
         **(model_kwargs or {}),
     )
 
-    if hasattr(config, "architecture") and config.architecture:
-        archs = [config.architecture]
-    if type(config) is dict:
-        assert "_class_name" in config, f"Unable to get the architecture from config={config}"
-        archs = [config["_class_name"]]
+    # Extract architecture information from config
+    archs = None
+    if isinstance(config, dict):
+        if "_class_name" in config:
+            archs = [config["_class_name"]]
+        else:
+            raise ValueError(f"Unable to get the architecture from config={config}")
     else:
-        archs = config.architectures  # type: ignore
-        task = None
+        # Config is an object (e.g., transformers config)
+        if hasattr(config, "architecture") and config.architecture:
+            archs = [config.architecture]
+        elif hasattr(config, "architectures") and config.architectures:
+            archs = config.architectures
+
     if archs is None:
         task = task_from_id(model_id)
     assert task is not None or (archs is not None and len(archs) == 1), (
@@ -112,9 +119,9 @@ def get_untrained_model_with_inputs(
 
     # model kwagrs
     if dynamic_rope is not None:
-        assert (
-            type(config) is not dict
-        ), f"Unable to set dynamic_rope if the configuration is a dictionary\n{config}"
+        assert type(config) is not dict, (
+            f"Unable to set dynamic_rope if the configuration is a dictionary\n{config}"
+        )
         assert hasattr(config, "rope_scaling"), f"Missing 'rope_scaling' in\n{config}"
         config.rope_scaling = (
             {"rope_type": "dynamic", "factor": 10.0} if dynamic_rope else None
@@ -150,9 +157,7 @@ def get_untrained_model_with_inputs(
             f"{getattr(config, '_attn_implementation', '?')!r}"  # type: ignore[union-attr]
         )
 
-    if type(config) is dict and "_diffusers_version" in config:
-        import diffusers
-
+    if isinstance(config, dict) and "_diffusers_version" in config:
         package_source = diffusers
     else:
         package_source = transformers
@@ -206,7 +211,7 @@ def get_untrained_model_with_inputs(
     )
 
     try:
-        if type(config) is dict:
+        if isinstance(config, dict):
            model = cls_model(**config)
         else:
            model = cls_model(config)
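
For readability, the architecture-resolution branching added above can be restated as a standalone helper. The function below is not part of the repository; it only mirrors the logic of the new code so the two config shapes can be compared side by side:

from types import SimpleNamespace
from typing import Any, List, Optional

def extract_archs(config: Any) -> Optional[List[str]]:
    # Mirrors the branching in get_untrained_model_with_inputs (illustration only).
    if isinstance(config, dict):
        if "_class_name" in config:
            return [config["_class_name"]]
        raise ValueError(f"Unable to get the architecture from config={config}")
    if getattr(config, "architecture", None):
        return [config.architecture]
    if getattr(config, "architectures", None):
        return config.architectures
    return None

print(extract_archs({"_class_name": "UNet2DConditionModel", "sample_size": 64}))
print(extract_archs(SimpleNamespace(architectures=["LlamaForCausalLM"])))
print(extract_archs(SimpleNamespace()))  # None -> the caller falls back to task_from_id(model_id)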

onnx_diagnostic/torch_models/validate.py

Lines changed: 1 addition & 1 deletion
@@ -576,7 +576,7 @@ def validate_model(
     summary["model_config"] = str(
         shrink_config(
             data["configuration"]
-            if type(data["configuration"]) is dict
+            if isinstance(data["configuration"], dict)
             else data["configuration"].to_dict()
         )
     ).replace(" ", "")
