diff --git a/_doc/examples/plot_export_tiny_phi2.py b/_doc/examples/plot_export_tiny_phi2.py
index cc2d1df5..918b9481 100644
--- a/_doc/examples/plot_export_tiny_phi2.py
+++ b/_doc/examples/plot_export_tiny_phi2.py
@@ -89,7 +89,6 @@
 
 
 with torch_export_patches(patch_transformers=True):
-
     # Two unnecessary steps but useful in case of an error
     # We check the cache is registered.
     assert is_cache_dynamic_registered()
diff --git a/_unittests/ut_export/test_shape_helper.py b/_unittests/ut_export/test_shape_helper.py
index 2917fa7b..798867a1 100644
--- a/_unittests/ut_export/test_shape_helper.py
+++ b/_unittests/ut_export/test_shape_helper.py
@@ -16,7 +16,6 @@
 
 
 class TestShapeHelper(ExtTestCase):
-
     @requires_transformers("4.52")
     @requires_torch("2.7.99")
     def test_all_dynamic_shape_from_cache(self):
diff --git a/_unittests/ut_helpers/test_onnx_helper.py b/_unittests/ut_helpers/test_onnx_helper.py
index 5c22cae0..fc0704ca 100644
--- a/_unittests/ut_helpers/test_onnx_helper.py
+++ b/_unittests/ut_helpers/test_onnx_helper.py
@@ -26,7 +26,6 @@
 
 
 class TestOnnxHelper(ExtTestCase):
-
     def _get_model(self):
         model = oh.make_model(
             oh.make_graph(
diff --git a/_unittests/ut_helpers/test_ort_session.py b/_unittests/ut_helpers/test_ort_session.py
index dc899f2e..b41c7721 100644
--- a/_unittests/ut_helpers/test_ort_session.py
+++ b/_unittests/ut_helpers/test_ort_session.py
@@ -24,7 +24,6 @@
 
 
 class TestOrtSession(ExtTestCase):
-
     @classmethod
     def _range(cls, *shape, bias: Optional[float] = None):
         n = np.prod(shape)
diff --git a/_unittests/ut_helpers/test_ort_session_tinyllm.py b/_unittests/ut_helpers/test_ort_session_tinyllm.py
index 82a52ce5..0048bf1f 100644
--- a/_unittests/ut_helpers/test_ort_session_tinyllm.py
+++ b/_unittests/ut_helpers/test_ort_session_tinyllm.py
@@ -19,7 +19,6 @@
 
 
 class TestOrtSessionTinyLLM(ExtTestCase):
-
     def test_ort_value(self):
         val = np.array([30, 31, 32], dtype=np.int64)
         ort = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(val, onnx.TensorProto.INT64)
diff --git a/_unittests/ut_helpers/test_torch_helper.py b/_unittests/ut_helpers/test_torch_helper.py
index c71b6ea9..2ca4c7a3 100644
--- a/_unittests/ut_helpers/test_torch_helper.py
+++ b/_unittests/ut_helpers/test_torch_helper.py
@@ -32,7 +32,6 @@
 
 
 class TestTorchTestHelper(ExtTestCase):
-
     def test_is_torchdynamo_exporting(self):
         self.assertFalse(is_torchdynamo_exporting())
 
diff --git a/_unittests/ut_reference/test_onnxruntime_evaluator.py b/_unittests/ut_reference/test_onnxruntime_evaluator.py
index 738b0ea3..27f1c89f 100644
--- a/_unittests/ut_reference/test_onnxruntime_evaluator.py
+++ b/_unittests/ut_reference/test_onnxruntime_evaluator.py
@@ -23,7 +23,6 @@
 
 class TestOnnxruntimeEvaluator(ExtTestCase):
     def test_ort_eval_scan_cdist_add(self):
-
         def dist(unused: torch.Tensor, x: torch.Tensor, samex: torch.Tensor):
             sub = samex - x.reshape((1, -1))
             sq = sub * sub
diff --git a/_unittests/ut_torch_export_patches/test_dynamic_class.py b/_unittests/ut_torch_export_patches/test_dynamic_class.py
index 9396ac41..459124be 100644
--- a/_unittests/ut_torch_export_patches/test_dynamic_class.py
+++ b/_unittests/ut_torch_export_patches/test_dynamic_class.py
@@ -227,7 +227,6 @@ def unflatten_my_cache_78(values, context: TreeContext, output_type=None) -> MyC
 
     @ignore_warnings(UserWarning)
     def test_export_dynamic_cache_cat(self):
-
         class ModelDynamicCache(torch.nn.Module):
             def forward(self, x, dc):
                 y = (
diff --git a/_unittests/ut_torch_export_patches/test_patch_expressions.py b/_unittests/ut_torch_export_patches/test_patch_expressions.py
index b65908c0..2b1a98c0 100644
--- a/_unittests/ut_torch_export_patches/test_patch_expressions.py
+++ b/_unittests/ut_torch_export_patches/test_patch_expressions.py
@@ -11,7 +11,6 @@
 
 
 class TestOnnxExportErrors(ExtTestCase):
-
     @classmethod
     def setUp(cls):
         register_patched_expressions()
diff --git a/_unittests/ut_torch_export_patches/test_patch_loops.py b/_unittests/ut_torch_export_patches/test_patch_loops.py
index 4ec01c4a..cfb521e6 100644
--- a/_unittests/ut_torch_export_patches/test_patch_loops.py
+++ b/_unittests/ut_torch_export_patches/test_patch_loops.py
@@ -14,7 +14,6 @@
 
 
 class TestOnnxExportErrors(ExtTestCase):
-
     def test_patched_expressions(self):
         res = list(_iterate_patched_expressions())
         names = {_[0] for _ in res}
@@ -22,7 +21,6 @@ def test_patched_expressions(self):
 
     @requires_torch("2.8")
     def test_filter_position_ids(self):
-
         def filter_position_ids(
             patch_attention_mask: torch.Tensor,
             position_ids: torch.Tensor,
@@ -57,7 +55,6 @@ def scan_filter_position_ids(
             boundaries: torch.Tensor,
             num_patches_per_side: int,
         ):
-
             def body(p_attn_mask, position_ids_row):
                 h_len = torch.tensor(1) / p_attn_mask[:, 0].sum()
                 w_len = torch.tensor(1) / p_attn_mask[0].sum()
diff --git a/_unittests/ut_torch_export_patches/test_patch_module.py b/_unittests/ut_torch_export_patches/test_patch_module.py
index 457d471f..0fc4e108 100644
--- a/_unittests/ut_torch_export_patches/test_patch_module.py
+++ b/_unittests/ut_torch_export_patches/test_patch_module.py
@@ -48,7 +48,6 @@ def forward(self, x, y):
         ), f"Missing parent in {ast.dump(tree, indent=2)}"
 
     def test_rewrite_test_in_forward_return1(self):
-
         class Model(torch.nn.Module):
             def forward(self, x, y):
                 if x.sum() > 0:
@@ -75,7 +74,6 @@ def forward(self, x, y):
 
     @hide_stdout()
     def test_rewrite_test_in_forward_return2(self):
-
         class Model(torch.nn.Module):
             def forward(self, x, y):
                 if x.sum() > 0:
@@ -101,7 +99,6 @@ def forward(self, x, y):
         self.assertEqualAny(expected_, ep.module()(-x, y))
 
     def test_rewrite_test_in_forward_assign1(self):
-
         class Model(torch.nn.Module):
             def forward(self, x, y):
                 if x.sum() > 0:
@@ -128,7 +125,6 @@ def forward(self, x, y):
         self.assertEqualArray(expected_, ep.module()(-x, y))
 
     def test_rewrite_test_in_forward_assign2(self):
-
         class Model(torch.nn.Module):
             def forward(self, x, y):
                 if x.sum() > 0:
@@ -155,10 +151,8 @@ def forward(self, x, y):
         self.assertEqualAny(expected_, ep.module()(-x, y))
 
     def test_check_syntax_assign_noelse(self):
-
         class Model(torch.nn.Module):
             def forward(self, x, y):
-
                 def branch_cond_then_1(x):
                     x = torch.abs(x) + 1
                     return x
@@ -179,7 +173,6 @@ def branch_cond_else_1(x):
         self.assertEqualAny(expected_, ep.module()(-x, y))
 
     def test_rewrite_test_in_forward_assign_noelse(self):
-
         class Model(torch.nn.Module):
             def forward(self, x, y):
                 if x.sum() > 0:
@@ -204,7 +197,6 @@ def forward(self, x, y):
         self.assertEqualAny(expected_, ep.module()(-x, y))
 
     def test_rewrite_test_in_forward_return_noelse(self):
-
         class Model(torch.nn.Module):
             def forward(self, x, y):
                 if x.sum() > 0:
@@ -216,7 +208,6 @@ def forward(self, x, y):
         )
 
     def test_rewrite_test_in_forward_assign2_in_2(self):
-
         class Model(torch.nn.Module):
             def forward(self, x, y):
                 if x.sum() > 0:
@@ -245,7 +236,6 @@ def forward(self, x, y):
         self.assertEqualAny(expected_, ep.module()(-x, y))
 
     def test_rewrite_test_in_forward_assign2_in_3(self):
-
         class Model(torch.nn.Module):
             def forward(self, x, y):
                 if x.sum() > 0:
@@ -277,13 +267,11 @@ def forward(self, x, y):
         self.assertEqualAny(expected_, ep.module()(-x, y))
 
     def test_assign_nested_check(self):
-
         torch_cond = torch.cond
 
         class Model(torch.nn.Module):
             def forward(self, x, y):
                 def torch_cond_then_3(y, x):
-
                     def torch_cond_then_1(y, x):
                         w = x + y
                         z = x - y
@@ -301,7 +289,6 @@ def torch_cond_else_1(y, x):
                         return (w, z)
 
                 def torch_cond_else_3(y, x):
-
                     def torch_cond_then_2(y):
                         u = y + 1
                         return u
@@ -322,7 +309,6 @@ def torch_cond_else_2(y):
         Model()(x, y)
 
     def test_rewrite_test_in_forward_assign_nested(self):
-
         class Model(torch.nn.Module):
             def forward(self, x, y):
                 if x.sum() > 0:
@@ -371,7 +357,6 @@ def forward(self, x, y):
         self.assertEqualAny(expected_1, ep.module()(-x, -y))
 
     def test_rewrite_test_in_forward_none(self):
-
         class Model(torch.nn.Module):
             def forward(self, x, y):
                 if x is None:
@@ -513,7 +498,6 @@ def test__find_loop_vars(self):
 
     @requires_torch("2.8")
     def test_rewrite_loop(self):
-
         class Model(torch.nn.Module):
             def forward(self, x, y):
                 z = torch.empty((x.shape[0], y.shape[0]))
diff --git a/_unittests/ut_torch_models/test_hghub_api.py b/_unittests/ut_torch_models/test_hghub_api.py
index 10a9689e..31c299e9 100644
--- a/_unittests/ut_torch_models/test_hghub_api.py
+++ b/_unittests/ut_torch_models/test_hghub_api.py
@@ -27,7 +27,6 @@
 
 
 class TestHuggingFaceHubApi(ExtTestCase):
-
     @requires_transformers("4.50")  # we limit to some versions of the CI
     @requires_torch("2.7")
     @ignore_errors(OSError)  # connectivity issues
diff --git a/_unittests/ut_torch_models/test_hghub_mode_rewrite.py b/_unittests/ut_torch_models/test_hghub_mode_rewrite.py
index 1dcdca82..04d978b4 100644
--- a/_unittests/ut_torch_models/test_hghub_mode_rewrite.py
+++ b/_unittests/ut_torch_models/test_hghub_mode_rewrite.py
@@ -13,7 +13,6 @@
 
 
 class TestHuggingFaceHubModelRewrite(ExtTestCase):
-
     def test_code_needing_rewriting(self):
         self.assertEqual(2, len(code_needing_rewriting("BartForConditionalGeneration")))
 
diff --git a/_unittests/ut_torch_models/test_tiny_llms_bypassed.py b/_unittests/ut_torch_models/test_tiny_llms_bypassed.py
index 4d0d7b0c..1db73d76 100644
--- a/_unittests/ut_torch_models/test_tiny_llms_bypassed.py
+++ b/_unittests/ut_torch_models/test_tiny_llms_bypassed.py
@@ -28,7 +28,6 @@ def test_export_tiny_llm_2_bypassed(self):
         with torch_export_patches(
             patch_torch=False, patch_transformers=True, catch_constraints=False, verbose=10
         ) as modificator:
-
             for k in patched_DynamicCache._PATCHES_:
                 self.assertEqual(getattr(patched_DynamicCache, k), getattr(DynamicCache, k))
 
diff --git a/_unittests/ut_torch_onnx/test_sbs.py b/_unittests/ut_torch_onnx/test_sbs.py
index 485a97b6..a1f5c56a 100644
--- a/_unittests/ut_torch_onnx/test_sbs.py
+++ b/_unittests/ut_torch_onnx/test_sbs.py
@@ -15,7 +15,6 @@
 
 
 class TestSideBySide(ExtTestCase):
-
     @hide_stdout()
     @unittest.skipIf(to_onnx is None, "to_onnx not installed")
    @ignore_errors(OSError)  # connectivity issues
diff --git a/_unittests/ut_xrun_doc/test_doc_doc.py b/_unittests/ut_xrun_doc/test_doc_doc.py
index 986d9696..c55469e6 100644
--- a/_unittests/ut_xrun_doc/test_doc_doc.py
+++ b/_unittests/ut_xrun_doc/test_doc_doc.py
@@ -4,7 +4,6 @@
 
 
 class TestDocDoc(ExtTestCase):
-
     def test_reset(self):
         reset_torch_transformers(None, None)
 
diff --git a/onnx_diagnostic/export/dynamic_shapes.py b/onnx_diagnostic/export/dynamic_shapes.py
index c5d0695d..c2b6bd6f 100644
--- a/onnx_diagnostic/export/dynamic_shapes.py
+++ b/onnx_diagnostic/export/dynamic_shapes.py
@@ -341,7 +341,7 @@ def _generic_walker_step(
                 if any(v is not None for v in value)
                 else None
             )
-        assert type(inputs) is dict, f"Unexpected type for inputs {type(inputs)}"
+        assert isinstance(inputs, dict), f"Unexpected type for inputs {type(inputs)}"
         assert set(inputs) == set(ds), (
             f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}, "
             f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
diff --git a/onnx_diagnostic/helpers/config_helper.py b/onnx_diagnostic/helpers/config_helper.py
index bd4f987d..450b33fe 100644
--- a/onnx_diagnostic/helpers/config_helper.py
+++ b/onnx_diagnostic/helpers/config_helper.py
@@ -38,12 +38,12 @@ def update_config(config: Any, mkwargs: Dict[str, Any]):
                 setattr(config, k, v)
                 continue
             existing = getattr(config, k)
-            if type(existing) is dict:
+            if isinstance(existing, dict):
                 existing.update(v)
             else:
                 update_config(getattr(config, k), v)
             continue
-        if type(config) is dict:
+        if isinstance(config, dict):
             config[k] = v
         else:
             setattr(config, k, v)
@@ -76,7 +76,7 @@ def pick(config, name: str, default_value: Any) -> Any:
     """
     if not config:
         return default_value
-    if type(config) is dict:
+    if isinstance(config, dict):
         return config.get(name, default_value)
     return getattr(config, name, default_value)
 
diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py
index 7242806a..0e8601f3 100644
--- a/onnx_diagnostic/helpers/helper.py
+++ b/onnx_diagnostic/helpers/helper.py
@@ -280,7 +280,7 @@ def string_type(
             print(f"[string_type] L:{type(obj)}")
         return f"{{...}}#{len(obj)}" if with_shape else "{...}"
     # dict
-    if isinstance(obj, dict) and type(obj) is dict:
+    if isinstance(obj, dict):
         if len(obj) == 0:
             if verbose:
                 print(f"[string_type] M:{type(obj)}")
diff --git a/onnx_diagnostic/helpers/mini_onnx_builder.py b/onnx_diagnostic/helpers/mini_onnx_builder.py
index 8fe1ef7d..d8b61f74 100644
--- a/onnx_diagnostic/helpers/mini_onnx_builder.py
+++ b/onnx_diagnostic/helpers/mini_onnx_builder.py
@@ -498,7 +498,7 @@ def _make(ty: type, res: Any) -> Any:
         for k, v in res:
             setattr(r, k, v)
         return r
-    if ty is dict:
+    if isinstance(res, dict):
         d = {}
         for k, v in res:
             if k.startswith("((") and k.endswith("))"):
diff --git a/onnx_diagnostic/helpers/ort_session.py b/onnx_diagnostic/helpers/ort_session.py
index 8138f896..e78b3f1d 100644
--- a/onnx_diagnostic/helpers/ort_session.py
+++ b/onnx_diagnostic/helpers/ort_session.py
@@ -19,7 +19,6 @@
 
 
 class _InferenceSession:
-
     @classmethod
     def has_onnxruntime_training(cls):
         """Tells if onnxruntime_training is installed."""
diff --git a/onnx_diagnostic/helpers/rt_helper.py b/onnx_diagnostic/helpers/rt_helper.py
index ebd6e157..77ba608b 100644
--- a/onnx_diagnostic/helpers/rt_helper.py
+++ b/onnx_diagnostic/helpers/rt_helper.py
@@ -79,7 +79,6 @@ def make_feeds(
     if len(names) < len(flat) and (
         isinstance(proto, onnx.ModelProto) or hasattr(proto, "get_inputs")
     ):
-
         typed_names = (
             [(i.name, i.type.tensor_type.elem_type) for i in proto.graph.input]
             if isinstance(proto, onnx.ModelProto)
diff --git a/onnx_diagnostic/helpers/torch_helper.py b/onnx_diagnostic/helpers/torch_helper.py
index 27fe4081..5b19a309 100644
--- a/onnx_diagnostic/helpers/torch_helper.py
+++ b/onnx_diagnostic/helpers/torch_helper.py
@@ -526,7 +526,6 @@ def forward(self, x):
             return word_emb + word_pe
 
     class AttentionBlock(torch.nn.Module):
-
         def __init__(self, embedding_dim: int = 16, context_size: int = 256):
             super().__init__()
             self.query = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
@@ -555,7 +554,6 @@ def forward(self, x):
             return out
 
     class MultiAttentionBlock(torch.nn.Module):
-
         def __init__(
             self, embedding_dim: int = 16, num_heads: int = 2, context_size: int = 256
         ):
@@ -573,7 +571,6 @@ def forward(self, x):
             return x
 
     class FeedForward(torch.nn.Module):
-
         def __init__(self, embedding_dim: int = 16, ff_dim: int = 128):
             super().__init__()
             self.linear_1 = torch.nn.Linear(embedding_dim, ff_dim)
@@ -587,7 +584,6 @@ def forward(self, x):
             return x
 
     class DecoderLayer(torch.nn.Module):
-
         def __init__(
             self,
             embedding_dim: int = 16,
@@ -613,7 +609,6 @@ def forward(self, x):
             return ff
 
     class LLM(torch.nn.Module):
-
         def __init__(
             self,
             vocab_size: int = 1024,
@@ -717,7 +712,7 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
         return tuple(to_any(t, to_value) for t in value)
     if isinstance(value, set):
         return {to_any(t, to_value) for t in value}
-    if type(value) is dict:
+    if isinstance(value, dict):
         return {k: to_any(t, to_value) for k, t in value.items()}
     if value.__class__.__name__ == "DynamicCache":
         return make_dynamic_cache(
diff --git a/onnx_diagnostic/reference/ops/op_scan.py b/onnx_diagnostic/reference/ops/op_scan.py
index bcf80966..5762ed73 100644
--- a/onnx_diagnostic/reference/ops/op_scan.py
+++ b/onnx_diagnostic/reference/ops/op_scan.py
@@ -3,7 +3,6 @@
 
 
 class Scan(_Scan):
-
     def need_context(self) -> bool:
         """Tells the runtime if this node needs the context
         (all the results produced so far) as it may silently access
diff --git a/onnx_diagnostic/torch_export_patches/eval/model_cases.py b/onnx_diagnostic/torch_export_patches/eval/model_cases.py
index caf4c405..229ca2dc 100644
--- a/onnx_diagnostic/torch_export_patches/eval/model_cases.py
+++ b/onnx_diagnostic/torch_export_patches/eval/model_cases.py
@@ -358,7 +358,6 @@ class ControlFlowCondIdentity_153832(torch.nn.Module):
     """
 
     def forward(self, x, y):
-
         def branch_cond_then_1(x):
             x = torch.abs(x) + 1
             return x
@@ -789,7 +788,6 @@ def forward(self, x):
 
 
 class CropLastDimensionWithTensorShape(torch.nn.Module):
-
     def forward(self, x, y):
         return x[..., : y.shape[0]]
 
diff --git a/onnx_diagnostic/torch_export_patches/patch_module_helper.py b/onnx_diagnostic/torch_export_patches/patch_module_helper.py
index fe8c8b14..729e6cb1 100644
--- a/onnx_diagnostic/torch_export_patches/patch_module_helper.py
+++ b/onnx_diagnostic/torch_export_patches/patch_module_helper.py
@@ -22,7 +22,6 @@ def ast_or_into_bitor(node: "ast.Node") -> "ast.Node":
 
 @functools.lru_cache
 def _rewrite_forward_clamp_float16() -> Dict[str, List[type]]:
-
     import transformers
 
     _known = {
diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py
index b053aaef..5a1c005f 100644
--- a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py
+++ b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py
@@ -166,7 +166,6 @@ def patched__broadcast_shapes(*_shapes):
 
 
 class patched_ShapeEnv:
-
     def _check_frozen(
         self, expr: "sympy.Basic", concrete_val: "sympy.Basic"  # noqa: F821
     ) -> None:
diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
index 0a574909..18bea52b 100644
--- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
+++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -1103,9 +1103,11 @@ def forward(
                 .transpose(1, 2)
             )
         else:
-            _, kv_len, _ = (
-                key_value_states.size()
-            )  # Note that, in this case, `kv_len` == `kv_seq_len`
+            (
+                _,
+                kv_len,
+                _,
+            ) = key_value_states.size()  # Note that, in this case, `kv_len` == `kv_seq_len`
             key_states = (
                 self.k_proj(key_value_states)
                 .view(bsz, kv_len, self.num_heads, self.head_dim)
@@ -1127,10 +1129,11 @@
             torch.tensor(q_len, dtype=torch.int64),
         )
         cos, sin = self.rotary_emb(value_states, seq_len=rotary_length)
-        query_states, key_states = (
-            transformers.models.idefics.modeling_idefics.apply_rotary_pos_emb(
-                query_states, key_states, cos, sin, position_ids
-            )
+        (
+            query_states,
+            key_states,
+        ) = transformers.models.idefics.modeling_idefics.apply_rotary_pos_emb(
+            query_states, key_states, cos, sin, position_ids
         )
 
         # [bsz, nh, t, hd]
diff --git a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py
index 3b2dc899..d8d064f2 100644
--- a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py
+++ b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py
@@ -67,7 +67,9 @@ def __init__(self):
     return cache
 
 
-def flatten_with_keys_mamba_cache(cache: MambaCache) -> Tuple[
+def flatten_with_keys_mamba_cache(
+    cache: MambaCache,
+) -> Tuple[
     List[Tuple[torch.utils._pytree.KeyEntry, Any]],
     torch.utils._pytree.Context,
 ]:
@@ -224,7 +226,9 @@ def flatten_encoder_decoder_cache(
     return torch.utils._pytree._dict_flatten(dictionary)
 
 
-def flatten_with_keys_encoder_decoder_cache(ec_cache: EncoderDecoderCache) -> Tuple[
+def flatten_with_keys_encoder_decoder_cache(
+    ec_cache: EncoderDecoderCache,
+) -> Tuple[
     List[Tuple[torch.utils._pytree.KeyEntry, Any]],
     torch.utils._pytree.Context,
 ]:
diff --git a/onnx_diagnostic/torch_models/hghub/hub_api.py b/onnx_diagnostic/torch_models/hghub/hub_api.py
index da05f7c0..b020f795 100644
--- a/onnx_diagnostic/torch_models/hghub/hub_api.py
+++ b/onnx_diagnostic/torch_models/hghub/hub_api.py
@@ -119,9 +119,23 @@ def get_pretrained_config(
         # Diffusers uses a dictionayr.
         with open(config, "r") as f:
             return json.load(f)
-    return transformers.AutoConfig.from_pretrained(
-        model_id, trust_remote_code=trust_remote_code, **kwargs
-    )
+    try:
+        config = transformers.AutoConfig.from_pretrained(
+            model_id, trust_remote_code=trust_remote_code, **kwargs
+        )
+    except ValueError:
+        # The model might be from diffusers, not transformers.
+        try:
+            import diffusers
+
+            pipe = diffusers.DiffusionPipeline.from_pretrained(
+                model_id, trust_remote_code=trust_remote_code, **kwargs
+            )
+            config = pipe.unet.config
+        except Exception as exc:
+            raise ValueError(f"Unable to retrieve the configuration for {model_id!r}") from exc
+
+    return config
 
 
 def get_model_info(model_id) -> Any:
@@ -211,7 +225,7 @@ def task_from_id(
     data = load_architecture_task()
     if model_id in data:
         return data[model_id]
-    if type(config) is dict and "_class_name" in config:
+    if isinstance(config, dict) and "_class_name" in config:
         return task_from_arch(config["_class_name"], default_value=default_value)
     if not config.architectures or not config.architectures:
         # Some hardcoded values until a better solution is found.
@@ -362,7 +376,7 @@ def download_code_modelid(
     paths = set()
     for i, name in enumerate(pyfiles):
         if verbose:
-            print(f"[download_code_modelid] download file {i+1}/{len(pyfiles)}: {name!r}")
+            print(f"[download_code_modelid] download file {i + 1}/{len(pyfiles)}: {name!r}")
         r = hf_hub_download(repo_id=model_id, filename=name)
         p = os.path.split(r)[0]
         paths.add(p)
diff --git a/onnx_diagnostic/torch_models/hghub/model_inputs.py b/onnx_diagnostic/torch_models/hghub/model_inputs.py
index c9d97a99..b51a0d5e 100644
--- a/onnx_diagnostic/torch_models/hghub/model_inputs.py
+++ b/onnx_diagnostic/torch_models/hghub/model_inputs.py
@@ -7,6 +7,7 @@
 from ...helpers.config_helper import update_config
 from ...tasks import reduce_model_config, random_input_kwargs
 from .hub_api import task_from_arch, task_from_id, get_pretrained_config, download_code_modelid
+import diffusers
 
 
 def _code_needing_rewriting(model: Any) -> Any:
@@ -18,7 +19,7 @@ def _code_needing_rewriting(model: Any) -> Any:
 def get_untrained_model_with_inputs(
     model_id: str,
     config: Optional[Any] = None,
-    task: Optional[str] = "",
+    task: Optional[str] = None,
     inputs_kwargs: Optional[Dict[str, Any]] = None,
     model_kwargs: Optional[Dict[str, Any]] = None,
     verbose: int = 0,
@@ -88,14 +89,20 @@ def get_untrained_model_with_inputs(
         **(model_kwargs or {}),
     )
 
-    if hasattr(config, "architecture") and config.architecture:
-        archs = [config.architecture]
-    if type(config) is dict:
-        assert "_class_name" in config, f"Unable to get the architecture from config={config}"
-        archs = [config["_class_name"]]
+    # Extract architecture information from config
+    archs = None
+    if isinstance(config, dict):
+        if "_class_name" in config:
+            archs = [config["_class_name"]]
+        else:
+            raise ValueError(f"Unable to get the architecture from config={config}")
     else:
-        archs = config.architectures  # type: ignore
-        task = None
+        # Config is an object (e.g., transformers config)
+        if hasattr(config, "architecture") and config.architecture:
+            archs = [config.architecture]
+        elif hasattr(config, "architectures") and config.architectures:
+            archs = config.architectures
+    if archs is None:
         task = task_from_id(model_id)
 
     assert task is not None or (archs is not None and len(archs) == 1), (
@@ -106,6 +113,7 @@
         print(f"[get_untrained_model_with_inputs] architectures={archs!r}")
         print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}")
     if task is None:
+        assert archs is not None
         task = task_from_arch(archs[0], model_id=model_id, subfolder=subfolder)
     if verbose:
         print(f"[get_untrained_model_with_inputs] task={task!r}")
@@ -150,9 +158,7 @@
             f"{getattr(config, '_attn_implementation', '?')!r}"  # type: ignore[union-attr]
         )
 
-    if type(config) is dict and "_diffusers_version" in config:
-        import diffusers
-
+    if isinstance(config, dict) and "_diffusers_version" in config:
         package_source = diffusers
     else:
         package_source = transformers
@@ -206,7 +212,7 @@
         )
 
     try:
-        if type(config) is dict:
+        if isinstance(config, dict):
             model = cls_model(**config)
         else:
             model = cls_model(config)
diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py
index 8362ae9e..475803e0 100644
--- a/onnx_diagnostic/torch_models/validate.py
+++ b/onnx_diagnostic/torch_models/validate.py
@@ -252,7 +252,6 @@ def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
     """Shrinks the configuration before it gets added to the information to log."""
     new_cfg = {}
     for k, v in cfg.items():
-
         new_cfg[k] = (
             v
             if (not isinstance(v, (list, tuple, set, dict)) or len(v) < 50)
@@ -576,7 +575,7 @@ def validate_model(
         summary["model_config"] = str(
             shrink_config(
                 data["configuration"]
-                if type(data["configuration"]) is dict
+                if isinstance(data["configuration"], dict)
                 else data["configuration"].to_dict()
             )
         ).replace(" ", "")
@@ -827,7 +826,6 @@ def _validate_do_run_model(
 
 
 def _validate_do_run_exported_program(data, summary, verbose, quiet):
-
     # We run a second time the model to check the patch did not
     # introduce any discrepancies
     if verbose: