Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion _doc/examples/plot_export_tiny_phi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@


with torch_export_patches(patch_transformers=True):

# Two unnecessary steps but useful in case of an error
# We check the cache is registered.
assert is_cache_dynamic_registered()
Expand Down
1 change: 0 additions & 1 deletion _unittests/ut_export/test_shape_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@


class TestShapeHelper(ExtTestCase):

@requires_transformers("4.52")
@requires_torch("2.7.99")
def test_all_dynamic_shape_from_cache(self):
Expand Down
1 change: 0 additions & 1 deletion _unittests/ut_helpers/test_onnx_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@


class TestOnnxHelper(ExtTestCase):

def _get_model(self):
model = oh.make_model(
oh.make_graph(
Expand Down
1 change: 0 additions & 1 deletion _unittests/ut_helpers/test_ort_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@


class TestOrtSession(ExtTestCase):

@classmethod
def _range(cls, *shape, bias: Optional[float] = None):
n = np.prod(shape)
Expand Down
1 change: 0 additions & 1 deletion _unittests/ut_helpers/test_ort_session_tinyllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@


class TestOrtSessionTinyLLM(ExtTestCase):

def test_ort_value(self):
val = np.array([30, 31, 32], dtype=np.int64)
ort = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(val, onnx.TensorProto.INT64)
Expand Down
1 change: 0 additions & 1 deletion _unittests/ut_helpers/test_torch_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@


class TestTorchTestHelper(ExtTestCase):

def test_is_torchdynamo_exporting(self):
self.assertFalse(is_torchdynamo_exporting())

Expand Down
1 change: 0 additions & 1 deletion _unittests/ut_reference/test_onnxruntime_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

class TestOnnxruntimeEvaluator(ExtTestCase):
def test_ort_eval_scan_cdist_add(self):

def dist(unused: torch.Tensor, x: torch.Tensor, samex: torch.Tensor):
sub = samex - x.reshape((1, -1))
sq = sub * sub
Expand Down
1 change: 0 additions & 1 deletion _unittests/ut_torch_export_patches/test_dynamic_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,6 @@ def unflatten_my_cache_78(values, context: TreeContext, output_type=None) -> MyC

@ignore_warnings(UserWarning)
def test_export_dynamic_cache_cat(self):

class ModelDynamicCache(torch.nn.Module):
def forward(self, x, dc):
y = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@


class TestOnnxExportErrors(ExtTestCase):

@classmethod
def setUp(cls):
register_patched_expressions()
Expand Down
3 changes: 0 additions & 3 deletions _unittests/ut_torch_export_patches/test_patch_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,13 @@


class TestOnnxExportErrors(ExtTestCase):

def test_patched_expressions(self):
res = list(_iterate_patched_expressions())
names = {_[0] for _ in res}
self.assertIn("float_arange", names)

@requires_torch("2.8")
def test_filter_position_ids(self):

def filter_position_ids(
patch_attention_mask: torch.Tensor,
position_ids: torch.Tensor,
Expand Down Expand Up @@ -57,7 +55,6 @@ def scan_filter_position_ids(
boundaries: torch.Tensor,
num_patches_per_side: int,
):

def body(p_attn_mask, position_ids_row):
h_len = torch.tensor(1) / p_attn_mask[:, 0].sum()
w_len = torch.tensor(1) / p_attn_mask[0].sum()
Expand Down
16 changes: 0 additions & 16 deletions _unittests/ut_torch_export_patches/test_patch_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def forward(self, x, y):
), f"Missing parent in {ast.dump(tree, indent=2)}"

def test_rewrite_test_in_forward_return1(self):

class Model(torch.nn.Module):
def forward(self, x, y):
if x.sum() > 0:
Expand All @@ -75,7 +74,6 @@ def forward(self, x, y):

@hide_stdout()
def test_rewrite_test_in_forward_return2(self):

class Model(torch.nn.Module):
def forward(self, x, y):
if x.sum() > 0:
Expand All @@ -101,7 +99,6 @@ def forward(self, x, y):
self.assertEqualAny(expected_, ep.module()(-x, y))

def test_rewrite_test_in_forward_assign1(self):

class Model(torch.nn.Module):
def forward(self, x, y):
if x.sum() > 0:
Expand All @@ -128,7 +125,6 @@ def forward(self, x, y):
self.assertEqualArray(expected_, ep.module()(-x, y))

def test_rewrite_test_in_forward_assign2(self):

class Model(torch.nn.Module):
def forward(self, x, y):
if x.sum() > 0:
Expand All @@ -155,10 +151,8 @@ def forward(self, x, y):
self.assertEqualAny(expected_, ep.module()(-x, y))

def test_check_syntax_assign_noelse(self):

class Model(torch.nn.Module):
def forward(self, x, y):

def branch_cond_then_1(x):
x = torch.abs(x) + 1
return x
Expand All @@ -179,7 +173,6 @@ def branch_cond_else_1(x):
self.assertEqualAny(expected_, ep.module()(-x, y))

def test_rewrite_test_in_forward_assign_noelse(self):

class Model(torch.nn.Module):
def forward(self, x, y):
if x.sum() > 0:
Expand All @@ -204,7 +197,6 @@ def forward(self, x, y):
self.assertEqualAny(expected_, ep.module()(-x, y))

def test_rewrite_test_in_forward_return_noelse(self):

class Model(torch.nn.Module):
def forward(self, x, y):
if x.sum() > 0:
Expand All @@ -216,7 +208,6 @@ def forward(self, x, y):
)

def test_rewrite_test_in_forward_assign2_in_2(self):

class Model(torch.nn.Module):
def forward(self, x, y):
if x.sum() > 0:
Expand Down Expand Up @@ -245,7 +236,6 @@ def forward(self, x, y):
self.assertEqualAny(expected_, ep.module()(-x, y))

def test_rewrite_test_in_forward_assign2_in_3(self):

class Model(torch.nn.Module):
def forward(self, x, y):
if x.sum() > 0:
Expand Down Expand Up @@ -277,13 +267,11 @@ def forward(self, x, y):
self.assertEqualAny(expected_, ep.module()(-x, y))

def test_assign_nested_check(self):

torch_cond = torch.cond

class Model(torch.nn.Module):
def forward(self, x, y):
def torch_cond_then_3(y, x):

def torch_cond_then_1(y, x):
w = x + y
z = x - y
Expand All @@ -301,7 +289,6 @@ def torch_cond_else_1(y, x):
return (w, z)

def torch_cond_else_3(y, x):

def torch_cond_then_2(y):
u = y + 1
return u
Expand All @@ -322,7 +309,6 @@ def torch_cond_else_2(y):
Model()(x, y)

def test_rewrite_test_in_forward_assign_nested(self):

class Model(torch.nn.Module):
def forward(self, x, y):
if x.sum() > 0:
Expand Down Expand Up @@ -371,7 +357,6 @@ def forward(self, x, y):
self.assertEqualAny(expected_1, ep.module()(-x, -y))

def test_rewrite_test_in_forward_none(self):

class Model(torch.nn.Module):
def forward(self, x, y):
if x is None:
Expand Down Expand Up @@ -513,7 +498,6 @@ def test__find_loop_vars(self):

@requires_torch("2.8")
def test_rewrite_loop(self):

class Model(torch.nn.Module):
def forward(self, x, y):
z = torch.empty((x.shape[0], y.shape[0]))
Expand Down
1 change: 0 additions & 1 deletion _unittests/ut_torch_models/test_hghub_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@


class TestHuggingFaceHubApi(ExtTestCase):

@requires_transformers("4.50") # we limit to some versions of the CI
@requires_torch("2.7")
@ignore_errors(OSError) # connectivity issues
Expand Down
1 change: 0 additions & 1 deletion _unittests/ut_torch_models/test_hghub_mode_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@


class TestHuggingFaceHubModelRewrite(ExtTestCase):

def test_code_needing_rewriting(self):
self.assertEqual(2, len(code_needing_rewriting("BartForConditionalGeneration")))

Expand Down
1 change: 0 additions & 1 deletion _unittests/ut_torch_models/test_tiny_llms_bypassed.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def test_export_tiny_llm_2_bypassed(self):
with torch_export_patches(
patch_torch=False, patch_transformers=True, catch_constraints=False, verbose=10
) as modificator:

for k in patched_DynamicCache._PATCHES_:
self.assertEqual(getattr(patched_DynamicCache, k), getattr(DynamicCache, k))

Expand Down
1 change: 0 additions & 1 deletion _unittests/ut_torch_onnx/test_sbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@


class TestSideBySide(ExtTestCase):

@hide_stdout()
@unittest.skipIf(to_onnx is None, "to_onnx not installed")
@ignore_errors(OSError) # connectivity issues
Expand Down
1 change: 0 additions & 1 deletion _unittests/ut_xrun_doc/test_doc_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@


class TestDocDoc(ExtTestCase):

def test_reset(self):
reset_torch_transformers(None, None)

Expand Down
2 changes: 1 addition & 1 deletion onnx_diagnostic/export/dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def _generic_walker_step(
if any(v is not None for v in value)
else None
)
assert type(inputs) is dict, f"Unexpected type for inputs {type(inputs)}"
assert isinstance(inputs, dict), f"Unexpected type for inputs {type(inputs)}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should keep type(inputs) is dict, isinstance(inputs, dict) is True for dict and output classes and I need to distinguish between the two.

Copy link
Collaborator Author

@titaiwangms titaiwangms Jul 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of them? Because diffusers use FrozenDict, and it works with isinstance(x, dict), but not is dict. What is the output class, I can rule it out.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem of checking FrozenDict is that we need to import diffusers, if that's fine with you? I suggest we rule out output class instead?

assert set(inputs) == set(ds), (
f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}, "
f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
Expand Down
6 changes: 3 additions & 3 deletions onnx_diagnostic/helpers/config_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ def update_config(config: Any, mkwargs: Dict[str, Any]):
setattr(config, k, v)
continue
existing = getattr(config, k)
if type(existing) is dict:
if isinstance(existing, dict):
existing.update(v)
else:
update_config(getattr(config, k), v)
continue
if type(config) is dict:
if isinstance(config, dict):
config[k] = v
else:
setattr(config, k, v)
Expand Down Expand Up @@ -76,7 +76,7 @@ def pick(config, name: str, default_value: Any) -> Any:
"""
if not config:
return default_value
if type(config) is dict:
if isinstance(config, dict):
return config.get(name, default_value)
return getattr(config, name, default_value)

Expand Down
2 changes: 1 addition & 1 deletion onnx_diagnostic/helpers/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def string_type(
print(f"[string_type] L:{type(obj)}")
return f"{{...}}#{len(obj)}" if with_shape else "{...}"
# dict
if isinstance(obj, dict) and type(obj) is dict:
if isinstance(obj, dict):
if len(obj) == 0:
if verbose:
print(f"[string_type] M:{type(obj)}")
Expand Down
2 changes: 1 addition & 1 deletion onnx_diagnostic/helpers/mini_onnx_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def _make(ty: type, res: Any) -> Any:
for k, v in res:
setattr(r, k, v)
return r
if ty is dict:
if isinstance(res, dict):
d = {}
for k, v in res:
if k.startswith("((") and k.endswith("))"):
Expand Down
1 change: 0 additions & 1 deletion onnx_diagnostic/helpers/ort_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@


class _InferenceSession:

@classmethod
def has_onnxruntime_training(cls):
"""Tells if onnxruntime_training is installed."""
Expand Down
1 change: 0 additions & 1 deletion onnx_diagnostic/helpers/rt_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def make_feeds(
if len(names) < len(flat) and (
isinstance(proto, onnx.ModelProto) or hasattr(proto, "get_inputs")
):

typed_names = (
[(i.name, i.type.tensor_type.elem_type) for i in proto.graph.input]
if isinstance(proto, onnx.ModelProto)
Expand Down
7 changes: 1 addition & 6 deletions onnx_diagnostic/helpers/torch_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,6 @@ def forward(self, x):
return word_emb + word_pe

class AttentionBlock(torch.nn.Module):

def __init__(self, embedding_dim: int = 16, context_size: int = 256):
super().__init__()
self.query = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
Expand Down Expand Up @@ -555,7 +554,6 @@ def forward(self, x):
return out

class MultiAttentionBlock(torch.nn.Module):

def __init__(
self, embedding_dim: int = 16, num_heads: int = 2, context_size: int = 256
):
Expand All @@ -573,7 +571,6 @@ def forward(self, x):
return x

class FeedForward(torch.nn.Module):

def __init__(self, embedding_dim: int = 16, ff_dim: int = 128):
super().__init__()
self.linear_1 = torch.nn.Linear(embedding_dim, ff_dim)
Expand All @@ -587,7 +584,6 @@ def forward(self, x):
return x

class DecoderLayer(torch.nn.Module):

def __init__(
self,
embedding_dim: int = 16,
Expand All @@ -613,7 +609,6 @@ def forward(self, x):
return ff

class LLM(torch.nn.Module):

def __init__(
self,
vocab_size: int = 1024,
Expand Down Expand Up @@ -717,7 +712,7 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
return tuple(to_any(t, to_value) for t in value)
if isinstance(value, set):
return {to_any(t, to_value) for t in value}
if type(value) is dict:
if isinstance(value, dict):
return {k: to_any(t, to_value) for k, t in value.items()}
if value.__class__.__name__ == "DynamicCache":
return make_dynamic_cache(
Expand Down
1 change: 0 additions & 1 deletion onnx_diagnostic/reference/ops/op_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@


class Scan(_Scan):

def need_context(self) -> bool:
"""Tells the runtime if this node needs the context
(all the results produced so far) as it may silently access
Expand Down
2 changes: 0 additions & 2 deletions onnx_diagnostic/torch_export_patches/eval/model_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,6 @@ class ControlFlowCondIdentity_153832(torch.nn.Module):
"""

def forward(self, x, y):

def branch_cond_then_1(x):
x = torch.abs(x) + 1
return x
Expand Down Expand Up @@ -789,7 +788,6 @@ def forward(self, x):


class CropLastDimensionWithTensorShape(torch.nn.Module):

def forward(self, x, y):
return x[..., : y.shape[0]]

Expand Down
Loading
Loading