diff --git a/_doc/status/exported_program_dynamic.rst b/_doc/status/exported_program_dynamic.rst index b3cec3a0..a75c0d4e 100644 --- a/_doc/status/exported_program_dynamic.rst +++ b/_doc/status/exported_program_dynamic.rst @@ -53,6 +53,7 @@ with different options. This steps happens before converting into ONNX. for exporter in ( "export-strict", "export-nostrict", + "export-nostrict-oblivious", "export-nostrict-decall", "export-tracing", ): diff --git a/_unittests/ut_torch_export_patches/test_eval.py b/_unittests/ut_torch_export_patches/test_eval.py index fabb8f0c..902df00a 100644 --- a/_unittests/ut_torch_export_patches/test_eval.py +++ b/_unittests/ut_torch_export_patches/test_eval.py @@ -86,6 +86,22 @@ def test_run_exporter_custom_nested_cond(self): dynamic=False, ) + def test_run_exporter_dimension0(self): + evaluation( + cases="ExportWithDimension0", + exporters="export-nostrict-oblivious", + quiet=False, + dynamic=True, + ) + + def test_run_exporter_dimension1(self): + evaluation( + cases="ExportWithDimension1", + exporters="export-nostrict-oblivious", + quiet=False, + dynamic=True, + ) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/torch_export_patches/eval/__init__.py b/onnx_diagnostic/torch_export_patches/eval/__init__.py index 71194f4e..ab5ed574 100644 --- a/onnx_diagnostic/torch_export_patches/eval/__init__.py +++ b/onnx_diagnostic/torch_export_patches/eval/__init__.py @@ -39,6 +39,9 @@ def evaluation( "export-strict", "export-nostrict", "export-nostrict-decall", + "export-strict-oblivious", + "export-nostrict-oblivious", + "export-nostrict-decall-oblivious", ), dynamic: Tuple[bool] = (False, True), cases: Optional[Union[str, Dict[str, type]]] = None, @@ -105,9 +108,7 @@ def _loop(): def _flatten_inputs(x: Any) -> List["torch.Tensor"]: # noqa: F821 - """ - Flatten inputs. - """ + """Flatten inputs.""" if x is None: return x import torch @@ -173,6 +174,15 @@ def _clone(x): raise TypeError(f"Unable to clone type {type(x)}, x={x} into numpy") +def _wrap_torch_export(*args, backed_size_oblivious=False, **kwargs): + import torch + + if backed_size_oblivious: + with torch.fx.experimental._config.patch(backed_size_oblivious=True): + return torch.export.export(*args, **kwargs) + return torch.export.export(*args, **kwargs) + + def _make_exporter_export( exporter: str, model: "torch.nn.Module", # noqa: F821 @@ -183,76 +193,35 @@ def _make_exporter_export( ) -> Union[Dict, Callable]: import torch - if exporter == "export-strict": - try: - if verbose >= 2: - exported = torch.export.export( - model, inputs, dynamic_shapes=dynamic_shapes, strict=True - ) - else: - with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr( - io.StringIO() - ): - exported = torch.export.export( - model, inputs, dynamic_shapes=dynamic_shapes, strict=True - ) - except Exception as e: - if not quiet: - raise - return dict(error=str(e), success=0, error_step="export") - if verbose >= 9: - print("-- graph") - print(exported.graph) - return exported.module() - if exporter in ("export-strict-dec", "export-strict-decall"): - try: - if verbose >= 2: - exported = torch.export.export( - model, inputs, dynamic_shapes=dynamic_shapes, strict=True - ) - if verbose >= 9: - print("-- graph before decomposition") - print(exported.graph) - exported = ( - exported.run_decompositions() - if "decall" in exporter - else exported.run_decompositions({}) - ) - else: - with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr( - io.StringIO() - ): - exported = torch.export.export( - model, inputs, dynamic_shapes=dynamic_shapes, strict=True - ) - if verbose >= 9: - print("-- graph before decomposition") - print(exported.graph) - exported = ( - exported.run_decompositions() - if "decall" in exporter - else exported.run_decompositions({}) - ) - except Exception as e: - if not quiet: - raise - return dict(error=str(e), success=0, error_step="export") - if verbose >= 9: - print("-- graph after decomposition") - print(exported.graph) - return exported.module() - if exporter == "export-nostrict": + backed_size_oblivious = "-oblivious" in exporter + strict = "-nostrict" not in exporter + + if exporter in ( + "export-strict", + "export-strict-oblivious", + "export-nostrict", + "export-nostrict-oblivious", + "export-oblivious", + ): try: if verbose >= 2: - exported = torch.export.export( - model, inputs, dynamic_shapes=dynamic_shapes, strict=False + exported = _wrap_torch_export( + model, + inputs, + dynamic_shapes=dynamic_shapes, + strict=strict, + backed_size_oblivious=backed_size_oblivious, ) else: with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr( io.StringIO() ): - exported = torch.export.export( - model, inputs, dynamic_shapes=dynamic_shapes, strict=False + exported = _wrap_torch_export( + model, + inputs, + dynamic_shapes=dynamic_shapes, + strict=strict, + backed_size_oblivious=backed_size_oblivious, ) except Exception as e: if not quiet: @@ -262,11 +231,25 @@ def _make_exporter_export( print("-- graph") print(exported.graph) return exported.module() - if exporter in ("export-nostrict-dec", "export-nostrict-decall"): + + if exporter in ( + "export-strict-dec", + "export-strict-decall", + "export-strict-dec-oblivious", + "export-strict-decall-oblivious", + "export-nostrict-dec", + "export-nostrict-decall", + "export-nostrict-dec-oblivious", + "export-nostrict-decall-oblivious", + ): try: if verbose >= 2: - exported = torch.export.export( - model, inputs, dynamic_shapes=dynamic_shapes, strict=False + exported = _wrap_torch_export( + model, + inputs, + dynamic_shapes=dynamic_shapes, + strict=strict, + backed_size_oblivious=backed_size_oblivious, ) if verbose >= 9: print("-- graph before decomposition") @@ -280,8 +263,12 @@ def _make_exporter_export( with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr( io.StringIO() ): - exported = torch.export.export( - model, inputs, dynamic_shapes=dynamic_shapes, strict=False + exported = _wrap_torch_export( + model, + inputs, + dynamic_shapes=dynamic_shapes, + strict=strict, + backed_size_oblivious=backed_size_oblivious, ) if verbose >= 9: print("-- graph before decomposition") @@ -299,6 +286,7 @@ def _make_exporter_export( print("-- graph after decomposition") print(exported.graph) return exported.module() + if exporter == "export-tracing": from experimental_experiment.torch_interpreter.tracing import CustomTracer @@ -446,6 +434,74 @@ def _make_exporter_onnx( raise AssertionError(f"Unexpected exporter={exporter!r}") +def _compares_on_one_example( + model: Callable, inputs: Tuple[Any, ...], mod: Callable, verbose: int, quiet: bool +) -> Tuple[Any, Any, Dict]: + from onnx_diagnostic.helpers import max_diff, string_type + + try: + expected = model(*_clone(inputs)) + except Exception as e: + if not quiet: + raise RuntimeError( + f"eager mode failed=\n{string_type(inputs, with_shape=True)} " + f"\nmodel=\n{type(model)}" + ) from e + res = dict(error=str(e), success=0, error_step="eager") + return None, None, res + try: + got = mod(*inputs) + except Exception as e: + if not quiet: + raise RuntimeError( + f"onnxruntime failed, feeds=\n{string_type(inputs, with_shape=True)}" + ) from e + res = dict(error=str(e), success=0, error_step="run.0") + return expected, None, res + + try: + disc = max_diff(expected, got) + except Exception as e: + if not quiet: + raise + res = dict(error=str(e), success=0, error_step="discrepancy") + return expected, got, res + + if verbose >= 5 and np.isinf(disc["abs"]): + print("[run_exporter] comparison issues with") + print(f"-- inputs={string_type(inputs[0], with_shape=True, limit=20)}") + print(f"-- expected={string_type(expected, with_shape=True, limit=20)}") + print(f"-- got={string_type(got, with_shape=True, limit=20)}") + elif verbose >= 9: + print("[run_exporter] inputs and outputs") + print( + f"-- inputs=" + f"{string_type(inputs[0], with_shape=True, with_min_max=True, limit=20)}" + ) + print( + f"-- expected=" + f"{string_type(expected, with_shape=True, with_min_max=True, limit=20)}" + ) + print(f"-- got={string_type(got, with_shape=True, with_min_max=True, limit=20)}") + del disc["n"] + del disc["sum"] + disc.update( + dict( + success=1 if disc["abs"] < 0.1 else 0, + model_cls=model.__class__, # type: ignore[dict-item] + exported=mod, # type: ignore[dict-item] + ) + ) + if disc["abs"] >= 0.1: + disc["error"] = "diff.0" + disc["error_step"] = "diff.0" + if verbose >= 9: + max_diff(expected, got, verbose=verbose) + else: + disc["success"] = 1 + return expected, got, disc + + def run_exporter( exporter: str, cls_model: type, @@ -473,6 +529,7 @@ def run_exporter( model = cls_model() inputs = cls_model._inputs + valid = getattr(cls_model, "_valid", None) if isinstance(inputs, tuple): inputs = [inputs] if dynamic: @@ -566,74 +623,38 @@ def run_exporter( mod = lambda *args, names=names: sess.run(None, _make_feeds(names, args)) # noqa: E731 # we need to clone for models modifying the inputs - try: - expected = model(*_clone(inputs[0])) - except Exception as e: - if not quiet: - raise RuntimeError( - f"eager mode failed=\n{string_type(inputs[0], with_shape=True)} " - f"\nmodel=\n{type(model)}" - ) from e - res = dict(error=str(e), success=0, error_step="eager") - res.update(base) - return res - try: - got = mod(*inputs[0]) - except Exception as e: - if not quiet: - raise RuntimeError( - f"onnxruntime failed, feeds=\n{string_type(inputs[0], with_shape=True)} " - f"\nmodel=\n{pretty_onnx(onx)}" - ) from e - res = dict(error=str(e), success=0, error_step="run.0") - res.update(base) - return res - - base["expected"] = expected - base["obtained"] = got - - try: - disc = max_diff(expected, got) - except Exception as e: - if not quiet: - raise - res = dict(error=str(e), success=0, error_step="discrepancy") - res.update(base) - return res + expected, got, disc = _compares_on_one_example(model, inputs[0], mod, verbose, quiet) + if expected is not None: + base["expected"] = expected + if got is not None: + base["obtained"] = got + disc.update(base) + disc["onnx"] = onx # type: ignore[dict-item] - if verbose >= 5 and np.isinf(disc["abs"]): - print("[run_exporter] comparison issues with") - print(f"-- inputs={string_type(inputs[0], with_shape=True, limit=20)}") - print(f"-- expected={string_type(expected, with_shape=True, limit=20)}") - print(f"-- got={string_type(got, with_shape=True, limit=20)}") - elif verbose >= 9: - print("[run_exporter] inputs and outputs") - print( - f"-- inputs=" - f"{string_type(inputs[0], with_shape=True, with_min_max=True, limit=20)}" - ) - print( - f"-- expected=" - f"{string_type(expected, with_shape=True, with_min_max=True, limit=20)}" - ) - print(f"-- got={string_type(got, with_shape=True, with_min_max=True, limit=20)}") - del disc["n"] - del disc["sum"] - disc.update( - dict( - success=1 if disc["abs"] < 0.1 else 0, - model_cls=model.__class__, - exported=mod, # type: ignore[dict-item] - onnx=onx, # type: ignore[dict-item] - ) - ) - if disc["abs"] >= 0.1: - disc["error"] = "diff.0" - disc["error_step"] = "diff.0" - if verbose >= 9: - max_diff(expected, got, verbose=verbose) - else: - disc["success"] = 1 + if valid is not None: + for valid_inputs in valid: + expected, got, _disc = _compares_on_one_example( + model, valid_inputs, mod, verbose, quiet + ) + if "abs" not in disc and (np.isnan(disc["abs"]) or disc["abs"] > 1e-3): + _disc["issue-abs"] = disc["abs"] + _disc["issue-rel"] = disc["rel"] + _disc["issue-inputs"] = string_type( + valid_inputs, with_shape=True, with_min_max=True + ) + _disc["issue-expected"] = string_type( + expected, with_shape=True, with_min_max=True + ) + _disc["issue-obtained"] = string_type(got, with_shape=True, with_min_max=True) + if not quiet: + raise RuntimeError( + f"validation failed," + f"\n-- inputs=\n{string_type(_disc['issue-inputs'])} " + f"\n-- exporter={exporter!r}\n-- dynamic_shapes={dynamic_shapes}, " + f"\n-- expected={_disc['issue-expected']}" + f"\n-- obtained={_disc['issue-obtained']}" + ) + break if dynamic and onx is not None: ds = [] diff --git a/onnx_diagnostic/torch_export_patches/eval/model_cases.py b/onnx_diagnostic/torch_export_patches/eval/model_cases.py index 6541d5ff..08daf56a 100644 --- a/onnx_diagnostic/torch_export_patches/eval/model_cases.py +++ b/onnx_diagnostic/torch_export_patches/eval/model_cases.py @@ -881,3 +881,21 @@ def forward(self, x, y): _inputs = [(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([0.1, 0.2, 0.3]))] _dynamic = {"x": {0: DYN}, "y": {0: DYN}} + + +class ExportWithDimension0(torch.nn.Module): + def forward(self, x): + return x @ torch.arange(x.shape[1], dtype=torch.float32).reshape((-1, 1)) + + _inputs = [(torch.empty((0, 3), dtype=torch.float32),)] + _dynamic = {"x": {0: DYN, 1: DYN}} + _valid = [(torch.rand((2, 3), dtype=torch.float32),)] + + +class ExportWithDimension1(torch.nn.Module): + def forward(self, x): + return x @ torch.arange(x.shape[1], dtype=torch.float32).reshape((-1, 1)) + + _inputs = [(torch.zeros((1, 3), dtype=torch.float32),)] + _dynamic = {"x": {0: DYN, 1: DYN}} + _valid = [(torch.rand((2, 3), dtype=torch.float32),)]