Skip to content

Commit b23031b

Browse files
committed
add pick
1 parent 4c6cebd commit b23031b

File tree

3 files changed

+26
-9
lines changed

3 files changed

+26
-9
lines changed

onnx_diagnostic/helpers/config_helper.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,18 @@ def _pick(config, *atts, exceptions: Optional[Dict[str, Callable]] = None):
6969
raise AssertionError(f"Unable to find any of these {atts!r} in {config}")
7070

7171

72+
def pick(config, name: str, default_value: Any) -> Any:
73+
"""
74+
Returns the vlaue of a attribute if config has it
75+
otherwise the default value.
76+
"""
77+
if not config:
78+
return default_value
79+
if type(config) is dict:
80+
return config.get(name, default_value)
81+
return getattr(config, name, default_value)
82+
83+
7284
@functools.cache
7385
def config_class_from_architecture(arch: str, exc: bool = False) -> Optional[type]:
7486
"""

onnx_diagnostic/tasks/text_to_image.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, Callable, Dict, Optional, Tuple
22
import torch
3-
from ..helpers.config_helper import update_config, check_hasattr
3+
from ..helpers.config_helper import update_config, check_hasattr, pick
44

55
__TASK__ = "text-to-image"
66

@@ -82,10 +82,10 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
8282
check_hasattr(config, "sample_size", "cross_attention_dim", "in_channels")
8383
kwargs = dict(
8484
batch_size=2,
85-
sequence_length=config["in_channels"],
85+
sequence_length=pick(config, "in_channels", 4),
8686
cache_length=77,
87-
in_channels=config["in_channels"],
88-
sample_size=config["sample_size"],
89-
cross_attention_dim=config["cross_attention_dim"],
87+
in_channels=pick(config, "in_channels", 4),
88+
sample_size=pick(config, "sample_size", 32),
89+
cross_attention_dim=pick(config, "cross_attention_dim", 64),
9090
)
9191
return kwargs, get_inputs

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,15 @@ def get_untrained_model_with_inputs(
200200
f"and use_pretrained=True."
201201
)
202202

203-
if type(config) is dict:
204-
model = cls_model(**config)
205-
else:
206-
model = cls_model(config)
203+
try:
204+
if type(config) is dict:
205+
model = cls_model(**config)
206+
else:
207+
model = cls_model(config)
208+
except RuntimeError as e:
209+
raise RuntimeError(
210+
f"Unable to instantiate class {cls_model.__name__} with\n{config}"
211+
) from e
207212

208213
# input kwargs
209214
kwargs, fct = random_input_kwargs(config, task)

0 commit comments

Comments
 (0)