diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml
index 8c7148ff..265e165e 100644
--- a/.github/workflows/documentation.yml
+++ b/.github/workflows/documentation.yml
@@ -118,9 +118,9 @@ jobs:
             grep ERROR doc.txt | grep -v 'l-plot-tiny-llm-export'
             exit 1
           fi
-          if [[ $(grep WARNING doc.txt | grep -v 'l-plot-tiny-llm-export' | grep -v 'Inline emphasis start-string' | grep -v 'Definition list ends without a blank line' | grep -v 'Unexpected section title or transition' | grep -v 'Inline strong start-string') ]]; then
+          if [[ $(grep WARNING doc.txt | grep -v 'l-plot-tiny-llm-export' | grep -v 'Inline emphasis start-string' | grep -v 'Definition list ends without a blank line' | grep -v 'Unexpected section title or transition' | grep -v 'Inline strong start-string' | grep -v 'MambaCache') ]]; then
             echo "Documentation produces warnings."
-            grep WARNING doc.txt | grep -v 'l-plot-tiny-llm-export' | grep -v 'Inline emphasis start-string' | grep -v 'Definition list ends without a blank line' | grep -v 'Unexpected section title or transition' | grep -v 'Inline strong start-string'
+            grep WARNING doc.txt | grep -v 'l-plot-tiny-llm-export' | grep -v 'Inline emphasis start-string' | grep -v 'Definition list ends without a blank line' | grep -v 'Unexpected section title or transition' | grep -v 'Inline strong start-string' | grep -v 'MambaCache'
             exit 1
           fi
diff --git a/_scripts/compare_model_execution.py b/_scripts/compare_model_execution.py
new file mode 100644
index 00000000..c5a823b5
--- /dev/null
+++ b/_scripts/compare_model_execution.py
@@ -0,0 +1,131 @@
+"""
+Compares two ONNX models.
+"""
+
+print("-- import onnx")
+import onnx
+
+print("-- import onnx.helper")
+from onnx.helper import tensor_dtype_to_np_dtype
+
+print("-- import onnxruntime")
+import onnxruntime
+
+print("-- import torch")
+import torch
+
+print("-- import transformers")
+import transformers
+
+print("-- import huggingface_hub")
+import huggingface_hub
+
+print("-- import onnx-diagnostic.helper")
+from onnx_diagnostic.helpers.helper import flatten_object, string_type, max_diff, string_diff
+
+print("-- import onnx-diagnostic.torch_models.hghub")
+from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
+
+print("-- done")
+
+model_id = "arnir0/Tiny-LLM"
+onnx1 = (
+    "dump_test/arnir0_Tiny-LLM-custom-default-f16-cuda-op20/"
+    "arnir0_Tiny-LLM-custom-default-f16-cuda-op20.onnx"
+)
+onnx2 = (
+    "dump_test/arnir0_Tiny-LLM-custom-default-f16-cuda-op21/"
+    "arnir0_Tiny-LLM-custom-default-f16-cuda-op21.onnx"
+)
+providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+
+print(f"-- load {onnx1!r}")
+onx1 = onnx.load(onnx1)
+print(f"-- load {onnx2!r}")
+onx2 = onnx.load(onnx2)
+
+print(f"-- getting inputs for model_id {model_id!r}")
+data = get_untrained_model_with_inputs(model_id)
+inputs = data["inputs"]
+print(f"-- inputs: {string_type(inputs, with_shape=True)}")
+flatten_inputs = flatten_object(inputs, drop_keys=True)
+print(f"-- flat inputs: {string_type(flatten_inputs, with_shape=True)}")
+
+names = [i.name for i in onx1.graph.input]
+itypes = [i.type.tensor_type.elem_type for i in onx1.graph.input]
+assert names == [
+    i.name for i in onx2.graph.input
+], f"Not the same names for both models {names} != {[i.name for i in onx2.graph.input]}"
+feeds = {
+    n: t.numpy().astype(tensor_dtype_to_np_dtype(itype))
+    for n, itype, t in zip(names, itypes, flatten_inputs)
+}
+print(f"-- feeds: {string_type(feeds, with_shape=True)}")
+
+print(f"-- creating session 1 from {onnx1!r}")
+opts = onnxruntime.SessionOptions()
+opts.optimized_model_filepath = "debug1_full.onnx"
+opts.log_severity_level = 0
+opts.log_verbosity_level = 0
+sess1 = onnxruntime.InferenceSession(onnx1, opts, providers=providers)
+print(f"-- creating session 2 from {onnx2!r}")
+opts.optimized_model_filepath = "debug2_full.onnx"
+opts.log_severity_level = 0
+opts.log_verbosity_level = 0
+sess2 = onnxruntime.InferenceSession(onnx2, opts, providers=providers)
+
+print("-- run session1")
+expected1 = sess1.run(None, feeds)
+print(f"-- got {string_type(expected1, with_shape=True)}")
+print("-- run session2")
+expected2 = sess2.run(None, feeds)
+print(f"-- got {string_type(expected2, with_shape=True)}")
+
+print("-- compute differences")
+diff = max_diff(expected1, expected2)
+print(f"-- diff={string_diff(diff)}")
+
+
+def get_names(onx: onnx.ModelProto) -> list[tuple[str, str, str]]:
+    names = []
+    for node in onx.graph.node:
+        for o in node.output:
+            names.append((o, node.op_type, node.name))
+    return names
+
+
+if diff["abs"] > 0.1:
+    print("--")
+    print("-- import select_model_inputs_outputs")
+    from onnx_extended.tools.onnx_nodes import select_model_inputs_outputs
+
+    print("-- looking into intermediate results")
+    names1 = get_names(onx1)
+    names2 = get_names(onx2)
+    common = [n for n in names1 if n in (set(names1) & set(names2))]
+    print(f"-- {len(common)} names / {len(names1)}-{len(names2)}")
+    print(f"-- first names {common[:5]}")
+    for name, op_type, op_name in common:
+        x1 = select_model_inputs_outputs(onx1, [name])
+        x2 = select_model_inputs_outputs(onx2, [name])
+        s1 = onnxruntime.InferenceSession(x1.SerializeToString(), providers=providers)
+        s2 = onnxruntime.InferenceSession(x2.SerializeToString(), providers=providers)
+        e1 = s1.run(None, feeds)
+        e2 = s2.run(None, feeds)
+        diff = max_diff(e1, e2)
+        print(
+            f"-- name={name!r}: diff={string_diff(diff)} "
+            f"- op_type={op_type!r}, op_name={op_name!r}"
+        )
+        if diff["abs"] > 0.1:
+            opts = onnxruntime.SessionOptions()
+            opts.optimized_model_filepath = "debug1.onnx"
+            onnxruntime.InferenceSession(x1.SerializeToString(), opts, providers=providers)
+            opts.optimized_model_filepath = "debug2.onnx"
+            onnxruntime.InferenceSession(x2.SerializeToString(), opts, providers=providers)
+            print("--")
+            print("-- break here")
+            print(f"-- feeds {string_type(feeds, with_shape=True)}")
+            print(f"-- e1={string_type(e1, with_shape=True, with_min_max=True)}")
+            print(f"-- e2={string_type(e2, with_shape=True, with_min_max=True)}")
+            break
diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py
index 18d6b7dc..db488e46 100644
--- a/onnx_diagnostic/_command_lines_parser.py
+++ b/onnx_diagnostic/_command_lines_parser.py
@@ -474,7 +474,7 @@ def get_parser_validate() -> ArgumentParser:
     )
     parser.add_argument(
         "--runtime",
-        choices=["onnxruntime", "torch", "ref"],
+        choices=["onnxruntime", "torch", "ref", "orteval", "orteval10"],
         default="onnxruntime",
         help="onnx runtime to use, `onnxruntime` by default",
     )
diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py
index 9d3a52e0..dcd997c5 100644
--- a/onnx_diagnostic/helpers/cache_helper.py
+++ b/onnx_diagnostic/helpers/cache_helper.py
@@ -4,11 +4,6 @@
 import transformers
 import transformers.cache_utils
 
-try:
-    from transformers.models.mamba.modeling_mamba import MambaCache
-except ImportError:
-    from transformers.cache_utils import MambaCache
-
 
 class CacheKeyValue:
     """
@@ -354,8 +349,15 @@ def make_encoder_decoder_cache(
     )
 
 
-def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -> MambaCache:
+def make_mamba_cache(
+    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+) -> "MambaCache":  # noqa: F821
     "Creates a ``MambaCache``."
+    # The import is done locally because it is slow.
+    try:
+        from transformers.models.mamba.modeling_mamba import MambaCache
+    except ImportError:
+        from transformers.cache_utils import MambaCache
     dtype = key_value_pairs[0][0].dtype
 
     class _config:
diff --git a/onnx_diagnostic/torch_models/hghub/model_inputs.py b/onnx_diagnostic/torch_models/hghub/model_inputs.py
index 91cc052a..730b659e 100644
--- a/onnx_diagnostic/torch_models/hghub/model_inputs.py
+++ b/onnx_diagnostic/torch_models/hghub/model_inputs.py
@@ -228,6 +228,8 @@ def get_untrained_model_with_inputs(
             f"and use_pretrained=True."
         )
 
+    seed = int(os.environ.get("SEED", "17"))
+    torch.manual_seed(seed)
     try:
         if type(config) is dict:
             model = cls_model(**config)
@@ -239,6 +241,8 @@ def get_untrained_model_with_inputs(
         ) from e
 
     # input kwargs
+    seed = int(os.environ.get("SEED", "17")) + 1
+    torch.manual_seed(seed)
     kwargs, fct = random_input_kwargs(config, task)  # type: ignore[arg-type]
     if verbose:
         print(f"[get_untrained_model_with_inputs] use fct={fct}")
diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py
index d6b3994f..b0b69e50 100644
--- a/onnx_diagnostic/torch_models/validate.py
+++ b/onnx_diagnostic/torch_models/validate.py
@@ -7,8 +7,6 @@ import time
 import numpy as np
 import onnx
-import onnxscript
-import onnxscript.rewriter.ort_fusions as ort_fusions
 import torch
 from ..export import CoupleInputsDynamicShapes
 from ..helpers import max_diff, string_type, string_diff
@@ -249,6 +247,7 @@ def _quiet_or_not_quiet(
         summary[f"time_{suffix}_latency_std"] = a.std()
         summary[f"time_{suffix}_latency_min"] = a.min()
         summary[f"time_{suffix}_latency_min"] = a.max()
+        summary[f"time_{suffix}_n"] = len(a)
     return res
@@ -337,7 +336,8 @@ def validate_model(
     :param subfolder: version or subfolders to uses when retrieving a model id
     :param opset: onnx opset to use for the conversion
     :param runtime: onnx runtime to use to check about discrepancies,
-        only if `do_run` is true
+        possible values are ``onnxruntime``, ``torch``, ``orteval``,
+        ``orteval10`` and ``ref``; only used if `do_run` is true
     :param repeat: number of time to measure the model
     :param warmup: warmup the model first
     :param inputs2: checks that the second set of inputs is reunning as well,
@@ -364,7 +364,13 @@ def validate_model(
     The default runtime, :epkg:`onnxruntime` is used to validate a model
     and check the exported model returns the same outputs as the original one,
     otherwise,
-    :class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used.
+    :class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used
+    if ``runtime == 'torch'``,
+    :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`
+    if ``runtime == 'orteval'``, or
+    :class:`onnx_diagnostic.reference.ExtendedReferenceEvaluator`
+    if ``runtime == 'ref'``.
+    ``orteval10`` is the same as ``orteval`` with the verbosity set to 10.
""" if isinstance(patch, bool): patch_kwargs = ( @@ -846,15 +852,24 @@ def node_iter(proto): raise NotImplementedError(f"Unexpected type={type(proto)}") counts: Dict[str, Union[float, int]] = {} + n_nodes = 0 + n_nodes_nocst = 0 for proto in node_iter(onx): if isinstance(proto, onnx.NodeProto): key = f"n_node_{proto.op_type}" + n_nodes += 1 + if proto.op_type != "Constant": + n_nodes_nocst += 1 else: key = f"n_node_initializer_{proto.data_type}" if key not in counts: counts[key] = 0 counts[key] += 1 + + counts["n_node_nodes"] = n_nodes + counts["n_node_nodes_nocst"] = n_nodes_nocst + counts["n_node_functions"] = len(onx.functions) return counts @@ -1155,7 +1170,7 @@ def validate_onnx_model( :param quiet: catch exception or not :param verbose: verbosity :param flavour: use a different version of the inputs - :param runtime: onnx runtime to use, onnxruntime or torch + :param runtime: onnx runtime to use, onnxruntime, torch, orteval, ref :param repeat: run that number of times the model :param warmup: warmup the model :param inputs2: to validate the model on the second input set @@ -1202,23 +1217,66 @@ def _mk(key, flavour=flavour): f"{providers}..., flavour={flavour!r}" ) - if runtime != "onnxruntime": + if runtime == "onnxruntime": + if os.environ.get("DUMPORTOPT", "") in ("1", "true", "True"): + opts = onnxruntime.SessionOptions() + opts.optimized_model_filepath = f"{data['onnx_filename']}.rtopt.onnx" + if verbose: + print( + f"[validate_onnx_model] saved optimized onnxruntime " + f"in {opts.optimized_model_filepath!r}" + ) + onnxruntime.InferenceSession(data["onnx_filename"], opts, providers=providers) + if verbose: + print("[validate_onnx_model] -- done") + + if verbose: + print("[validate_onnx_model] runtime is onnxruntime") + cls_runtime = lambda model, providers: onnxruntime.InferenceSession( + (model.SerializeToString() if isinstance(model, onnx.ModelProto) else model), + providers=providers, + ) + elif runtime == "torch": from ..reference import TorchOnnxEvaluator - cls_runtime = ( - ( - lambda model, providers: onnxruntime.InferenceSession( - (model.SerializeToString() if isinstance(model, onnx.ModelProto) else model), - providers=providers, + if verbose: + print("[validate_onnx_model] runtime is TorchOnnxEvaluator") + cls_runtime = ( + lambda model, providers, _cls_=TorchOnnxEvaluator: _cls_( # type: ignore[misc] + model, providers=providers, verbose=max(verbose - 1, 0) ) ) - if runtime == "onnxruntime" - else ( - lambda model, providers, _cls_=TorchOnnxEvaluator: _cls_( # type: ignore[misc] + elif runtime == "orteval": + from ..reference import OnnxruntimeEvaluator + + if verbose: + print("[validate_onnx_model] runtime is OnnxruntimeEvaluator") + cls_runtime = ( + lambda model, providers, _cls_=OnnxruntimeEvaluator: _cls_( # type: ignore[misc] model, providers=providers, verbose=max(verbose - 1, 0) ) ) - ) + elif runtime == "orteval10": + from ..reference import OnnxruntimeEvaluator + + if verbose: + print("[validate_onnx_model] runtime is OnnxruntimeEvaluator(verbose=10)") + cls_runtime = ( + lambda model, providers, _cls_=OnnxruntimeEvaluator: _cls_( # type: ignore[misc] + model, providers=providers, verbose=10 + ) + ) + elif runtime == "ref": + from ..reference import ExtendedReferenceEvaluator + + if verbose: + print("[validate_onnx_model] runtime is ExtendedReferenceEvaluator") + cls_runtime = lambda model, providers, _cls_=ExtendedReferenceEvaluator: _cls_( # type: ignore[misc] + model, verbose=max(verbose - 1, 0) + ) + else: + raise ValueError(f"Unexpecteed 
runtime={runtime!r}") + sess = _quiet_or_not_quiet( quiet, _mk("create_onnx_ort"), @@ -1399,6 +1457,8 @@ def call_torch_export_onnx( if optimization == "ir": label, f_optim = "export_onnx_opt_ir", (lambda epo=epo: epo.optimize()) else: + import onnxscript + import onnxscript.rewriter.ort_fusions as ort_fusions def _os_ort_optim(epo): onnxscript.optimizer.optimize_ir(epo.model) @@ -1683,6 +1743,9 @@ def call_torch_export_custom( print("[call_torch_export_custom] done (export)") if os_ort: + import onnxscript + import onnxscript.rewriter.ort_fusions as ort_fusions + if verbose: print("[call_torch_export_custom] conversion to IR...") begin = time.perf_counter() diff --git a/pyproject.toml b/pyproject.toml index a904f088..b1a90221 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,10 @@ disable_error_code = ["arg-type", "assignment", "import-untyped", "misc", "name- module = ["onnx_diagnostic.helpers.args_helper"] disable_error_code = ["arg-type", "call-overload", "index"] +[[tool.mypy.overrides]] +module = ["onnx_diagnostic.helpers.cache_helper"] +disable_error_code = ["name-defined"] + [[tool.mypy.overrides]] module = ["onnx_diagnostic.helpers.helper"] disable_error_code = ["arg-type", "assignment", "attr-defined", "call-overload", "misc", "name-defined", "union-attr"] @@ -123,6 +127,7 @@ select = [ "_doc/examples/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"] "_doc/notebooks/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"] "_doc/recipes/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"] +"_scripts/compare_model_execution.py" = ["E402", "F401"] "_doc/technical/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"] "_unittests/*/test*.py" = ["B008", "B904", "PIE808", "SIM117", "SIM105", "UP008"] "onnx_diagnostic/export/__init__.py" = ["F401"] @@ -131,6 +136,7 @@ select = [ "onnx_diagnostic/reference/torch_ops/__init__.py" = ["F401"] "onnx_diagnostic/torch_models/hghub/__init__.py" = ["F401"] "onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py" = ["PIE804"] +"onnx_diagnostic/torch_models/validate.py" = ["E731"] "onnx_diagnostic/torch_export_patches/__init__.py" = ["F401"] "onnx_diagnostic/torch_export_patches/patches/__init__.py" = ["F401"] "onnx_diagnostic/torch_models/llms.py" = ["F401"]