4 changes: 2 additions & 2 deletions .github/workflows/documentation.yml
@@ -118,9 +118,9 @@ jobs:
grep ERROR doc.txt | grep -v 'l-plot-tiny-llm-export'
exit 1
fi
if [[ $(grep WARNING doc.txt | grep -v 'l-plot-tiny-llm-export' | grep -v 'Inline emphasis start-string' | grep -v 'Definition list ends without a blank line' | grep -v 'Unexpected section title or transition' | grep -v 'Inline strong start-string') ]]; then
if [[ $(grep WARNING doc.txt | grep -v 'l-plot-tiny-llm-export' | grep -v 'Inline emphasis start-string' | grep -v 'Definition list ends without a blank line' | grep -v 'Unexpected section title or transition' | grep -v 'Inline strong start-string' | grep -v 'MambaCache') ]]; then
echo "Documentation produces warnings."
grep WARNING doc.txt | grep -v 'l-plot-tiny-llm-export' | grep -v 'Inline emphasis start-string' | grep -v 'Definition list ends without a blank line' | grep -v 'Unexpected section title or transition' | grep -v 'Inline strong start-string'
grep WARNING doc.txt | grep -v 'l-plot-tiny-llm-export' | grep -v 'Inline emphasis start-string' | grep -v 'Definition list ends without a blank line' | grep -v 'Unexpected section title or transition' | grep -v 'Inline strong start-string' | grep -v 'MambaCache'
exit 1
fi

131 changes: 131 additions & 0 deletions _scripts/compare_model_execution.py
@@ -0,0 +1,131 @@
"""
Compares two ONNX models.
"""

print("-- import onnx")
import onnx

print("-- import onnx.helper")
from onnx.helper import tensor_dtype_to_np_dtype

print("-- import onnxruntime")
import onnxruntime

print("-- import torch")
import torch

print("-- import transformers")
import transformers

print("-- import huggingface_hub")
import huggingface_hub

print("-- import onnx-diagnostic.helper")
from onnx_diagnostic.helpers.helper import flatten_object, string_type, max_diff, string_diff

print("-- import onnx-diagnostic.torch_models.hghub")
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

print("-- done")

model_id = "arnir0/Tiny-LLM"
onnx1 = (
"dump_test/arnir0_Tiny-LLM-custom-default-f16-cuda-op20/"
"arnir0_Tiny-LLM-custom-default-f16-cuda-op20.onnx"
)
onnx2 = (
"dump_test/arnir0_Tiny-LLM-custom-default-f16-cuda-op21/"
"arnir0_Tiny-LLM-custom-default-f16-cuda-op21.onnx"
)
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]

print(f"-- load {onnx1!r}")
onx1 = onnx.load(onnx1)
print(f"-- load {onnx2!r}")
onx2 = onnx.load(onnx2)

print(f"-- getting inputs for model_id {model_id!r}")
data = get_untrained_model_with_inputs(model_id)
inputs = data["inputs"]
print(f"-- inputs: {string_type(inputs, with_shape=True)}")
flatten_inputs = flatten_object(inputs, drop_keys=True)
print(f"-- flat inputs: {string_type(flatten_inputs, with_shape=True)}")

names = [i.name for i in onx1.graph.input]
itypes = [i.type.tensor_type.elem_type for i in onx1.graph.input]
assert names == [
i.name for i in onx2.graph.input
], f"Not the same names for both models {names} != {[i.name for i in onx2.graph.input]}"
feeds = {
n: t.numpy().astype(tensor_dtype_to_np_dtype(itype))
for n, itype, t in zip(names, itypes, flatten_inputs)
}
print(f"-- feeds: {string_type(feeds, with_shape=True)}")

print(f"-- creating session 1 from {onnx1!r}")
opts = onnxruntime.SessionOptions()
opts.optimized_model_filepath = "debug1_full.onnx"
opts.log_severity_level = 0
opts.log_verbosity_level = 0
sess1 = onnxruntime.InferenceSession(onnx1, opts, providers=providers)
print(f"-- creating session 2 from {onnx2!r}")
opts.optimized_model_filepath = "debug2_full.onnx"
opts.log_severity_level = 0
opts.log_verbosity_level = 0
sess2 = onnxruntime.InferenceSession(onnx2, opts, providers=providers)

print("-- run session1")
expected1 = sess1.run(None, feeds)
print(f"-- got {string_type(expected1, with_shape=True)}")
print("-- run session2")
expected2 = sess2.run(None, feeds)
print(f"-- got {string_type(expected2, with_shape=True)}")

print("-- compute differences")
diff = max_diff(expected1, expected2)
print(f"-- diff={string_diff(diff)}")


def get_names(onx: onnx.ModelProto) -> list[tuple[str, str, str]]:
names = []
for node in onx.graph.node:
for o in node.output:
names.append((o, node.op_type, node.name))
return names


if diff["abs"] > 0.1:
print("--")
print("-- import select_model_inputs_outputs")
from onnx_extended.tools.onnx_nodes import select_model_inputs_outputs

print("-- looking into intermediate results")
names1 = get_names(onx1)
names2 = get_names(onx2)
common = [n for n in names1 if n in (set(names1) & set(names2))]
print(f"-- {len(common)} names / {len(names1)}-{len(names2)}")
print(f"-- first names {common[:5]}")
for name, op_type, op_name in common:
x1 = select_model_inputs_outputs(onx1, [name])
x2 = select_model_inputs_outputs(onx2, [name])
s1 = onnxruntime.InferenceSession(x1.SerializeToString(), providers=providers)
s2 = onnxruntime.InferenceSession(x2.SerializeToString(), providers=providers)
e1 = s1.run(None, feeds)
e2 = s2.run(None, feeds)
diff = max_diff(e1, e2)
print(
f"-- name={name!r}: diff={string_diff(diff)} "
f"- op_type={op_type!r}, op_name={op_name!r}"
)
if diff["abs"] > 0.1:
opts = onnxruntime.SessionOptions()
opts.optimized_model_filepath = "debug1.onnx"
onnxruntime.InferenceSession(x1.SerializeToString(), opts, providers=providers)
opts.optimized_model_filepath = "debug2.onnx"
onnxruntime.InferenceSession(x2.SerializeToString(), opts, providers=providers)
print("--")
print("-- break here")
print(f"-- feeds {string_type(feeds, with_shape=True)}")
print(f"-- e1={string_type(e1, with_shape=True, with_min_max=True)}")
print(f"-- e2={string_type(e2, with_shape=True, with_min_max=True)}")
break
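
The loop above stops at the first intermediate result whose absolute difference exceeds 0.1. As a side note, a minimal sketch of the same search factored into a reusable helper, assuming select_model_inputs_outputs and max_diff behave exactly as they are used in the script above; the threshold and the early return are illustrative choices, not part of the repository:

import onnxruntime
from onnx_diagnostic.helpers.helper import max_diff
from onnx_extended.tools.onnx_nodes import select_model_inputs_outputs


def first_divergence(onx1, onx2, feeds, output_names, providers, threshold=0.1):
    """Returns the first output name whose absolute difference exceeds the threshold."""
    for name in output_names:
        # Truncate both models so that `name` becomes their only output.
        x1 = select_model_inputs_outputs(onx1, [name])
        x2 = select_model_inputs_outputs(onx2, [name])
        s1 = onnxruntime.InferenceSession(x1.SerializeToString(), providers=providers)
        s2 = onnxruntime.InferenceSession(x2.SerializeToString(), providers=providers)
        # max_diff returns a dictionary whose "abs" entry is the maximum absolute difference.
        diff = max_diff(s1.run(None, feeds), s2.run(None, feeds))
        if diff["abs"] > threshold:
            return name, diff
    return None, None
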
2 changes: 1 addition & 1 deletion onnx_diagnostic/_command_lines_parser.py
@@ -474,7 +474,7 @@ def get_parser_validate() -> ArgumentParser:
)
parser.add_argument(
"--runtime",
choices=["onnxruntime", "torch", "ref"],
choices=["onnxruntime", "torch", "ref", "orteval", "orteval10"],
default="onnxruntime",
help="onnx runtime to use, `onnxruntime` by default",
)
14 changes: 8 additions & 6 deletions onnx_diagnostic/helpers/cache_helper.py
@@ -4,11 +4,6 @@
import transformers
import transformers.cache_utils

try:
from transformers.models.mamba.modeling_mamba import MambaCache
except ImportError:
from transformers.cache_utils import MambaCache


class CacheKeyValue:
"""
@@ -354,8 +349,15 @@ def make_encoder_decoder_cache(
)


def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -> MambaCache:
def make_mamba_cache(
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
) -> "MambaCache": # noqa: F821
"Creates a ``MambaCache``."
# import is moved here because this part is slow.
try:
from transformers.models.mamba.modeling_mamba import MambaCache
except ImportError:
from transformers.cache_utils import MambaCache
dtype = key_value_pairs[0][0].dtype

class _config:
4 changes: 4 additions & 0 deletions onnx_diagnostic/torch_models/hghub/model_inputs.py
@@ -228,6 +228,8 @@ def get_untrained_model_with_inputs(
f"and use_pretrained=True."
)

seed = int(os.environ.get("SEED", "17"))
torch.manual_seed(seed)
try:
if type(config) is dict:
model = cls_model(**config)
Expand All @@ -239,6 +241,8 @@ def get_untrained_model_with_inputs(
) from e

# input kwargs
seed = int(os.environ.get("SEED", "17")) + 1
torch.manual_seed(seed)
kwargs, fct = random_input_kwargs(config, task) # type: ignore[arg-type]
if verbose:
print(f"[get_untrained_model_with_inputs] use fct={fct}")
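
The two torch.manual_seed calls added above make the untrained weights and the generated inputs reproducible across runs through a SEED environment variable. A minimal usage sketch, assuming SEED is read exactly as shown in the diff (the value 3 is only illustrative, the default being 17):

import os
os.environ["SEED"] = "3"  # illustrative value; the code falls back to 17 when SEED is unset

from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

# With the same SEED, repeated calls should produce identical untrained weights and inputs.
data_a = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
data_b = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
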
93 changes: 78 additions & 15 deletions onnx_diagnostic/torch_models/validate.py
@@ -7,8 +7,6 @@
import time
import numpy as np
import onnx
import onnxscript
import onnxscript.rewriter.ort_fusions as ort_fusions
import torch
from ..export import CoupleInputsDynamicShapes
from ..helpers import max_diff, string_type, string_diff
@@ -249,6 +247,7 @@ def _quiet_or_not_quiet(
summary[f"time_{suffix}_latency_std"] = a.std()
summary[f"time_{suffix}_latency_min"] = a.min()
summary[f"time_{suffix}_latency_min"] = a.max()
summary[f"time_{suffix}_n"] = len(a)
return res


@@ -337,7 +336,8 @@ def validate_model(
:param subfolder: version or subfolder to use when retrieving a model id
:param opset: onnx opset to use for the conversion
:param runtime: onnx runtime to use to check for discrepancies,
only if `do_run` is true
possible values ``onnxruntime``, ``torch``, ``orteval``,
``orteval10``, ``ref``; only used if `do_run` is true
:param repeat: number of times to measure the model
:param warmup: warmup the model first
:param inputs2: checks that the second set of inputs is running as well,
@@ -364,7 +364,13 @@

The default runtime, :epkg:`onnxruntime` is used to validate a model and check the
exported model returns the same outputs as the original one, otherwise,
:class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used.
:class:`onnx_diagnostic.reference.TorchOnnxEvaluator`
if ``runtime == 'torch'``,
:class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`
if ``runtime == 'orteval'``, or
:class:`onnx_diagnostic.reference.ExtendedReferenceEvaluator`
if ``runtime == 'ref'`` is used;
``orteval10`` increases the verbosity.
"""
if isinstance(patch, bool):
patch_kwargs = (
@@ -846,15 +852,24 @@ def node_iter(proto):
raise NotImplementedError(f"Unexpected type={type(proto)}")

counts: Dict[str, Union[float, int]] = {}
n_nodes = 0
n_nodes_nocst = 0
for proto in node_iter(onx):
if isinstance(proto, onnx.NodeProto):
key = f"n_node_{proto.op_type}"
n_nodes += 1
if proto.op_type != "Constant":
n_nodes_nocst += 1
else:
key = f"n_node_initializer_{proto.data_type}"

if key not in counts:
counts[key] = 0
counts[key] += 1

counts["n_node_nodes"] = n_nodes
counts["n_node_nodes_nocst"] = n_nodes_nocst
counts["n_node_functions"] = len(onx.functions)
return counts


@@ -1155,7 +1170,7 @@ def validate_onnx_model(
:param quiet: catch exception or not
:param verbose: verbosity
:param flavour: use a different version of the inputs
:param runtime: onnx runtime to use, onnxruntime or torch
:param runtime: onnx runtime to use: onnxruntime, torch, orteval, orteval10, ref
:param repeat: run the model that number of times
:param warmup: warmup the model
:param inputs2: to validate the model on the second input set
@@ -1202,23 +1217,66 @@ def _mk(key, flavour=flavour):
f"{providers}..., flavour={flavour!r}"
)

if runtime != "onnxruntime":
if runtime == "onnxruntime":
if os.environ.get("DUMPORTOPT", "") in ("1", "true", "True"):
opts = onnxruntime.SessionOptions()
opts.optimized_model_filepath = f"{data['onnx_filename']}.rtopt.onnx"
if verbose:
print(
f"[validate_onnx_model] saved optimized onnxruntime "
f"in {opts.optimized_model_filepath!r}"
)
onnxruntime.InferenceSession(data["onnx_filename"], opts, providers=providers)
if verbose:
print("[validate_onnx_model] -- done")

if verbose:
print("[validate_onnx_model] runtime is onnxruntime")
cls_runtime = lambda model, providers: onnxruntime.InferenceSession(
(model.SerializeToString() if isinstance(model, onnx.ModelProto) else model),
providers=providers,
)
elif runtime == "torch":
from ..reference import TorchOnnxEvaluator

cls_runtime = (
(
lambda model, providers: onnxruntime.InferenceSession(
(model.SerializeToString() if isinstance(model, onnx.ModelProto) else model),
providers=providers,
if verbose:
print("[validate_onnx_model] runtime is TorchOnnxEvaluator")
cls_runtime = (
lambda model, providers, _cls_=TorchOnnxEvaluator: _cls_( # type: ignore[misc]
model, providers=providers, verbose=max(verbose - 1, 0)
)
)
if runtime == "onnxruntime"
else (
lambda model, providers, _cls_=TorchOnnxEvaluator: _cls_( # type: ignore[misc]
elif runtime == "orteval":
from ..reference import OnnxruntimeEvaluator

if verbose:
print("[validate_onnx_model] runtime is OnnxruntimeEvaluator")
cls_runtime = (
lambda model, providers, _cls_=OnnxruntimeEvaluator: _cls_( # type: ignore[misc]
model, providers=providers, verbose=max(verbose - 1, 0)
)
)
)
elif runtime == "orteval10":
from ..reference import OnnxruntimeEvaluator

if verbose:
print("[validate_onnx_model] runtime is OnnxruntimeEvaluator(verbose=10)")
cls_runtime = (
lambda model, providers, _cls_=OnnxruntimeEvaluator: _cls_( # type: ignore[misc]
model, providers=providers, verbose=10
)
)
elif runtime == "ref":
from ..reference import ExtendedReferenceEvaluator

if verbose:
print("[validate_onnx_model] runtime is ExtendedReferenceEvaluator")
cls_runtime = lambda model, providers, _cls_=ExtendedReferenceEvaluator: _cls_( # type: ignore[misc]
model, verbose=max(verbose - 1, 0)
)
else:
raise ValueError(f"Unexpecteed runtime={runtime!r}")

sess = _quiet_or_not_quiet(
quiet,
_mk("create_onnx_ort"),
@@ -1399,6 +1457,8 @@ def call_torch_export_onnx(
if optimization == "ir":
label, f_optim = "export_onnx_opt_ir", (lambda epo=epo: epo.optimize())
else:
import onnxscript
import onnxscript.rewriter.ort_fusions as ort_fusions

def _os_ort_optim(epo):
onnxscript.optimizer.optimize_ir(epo.model)
@@ -1683,6 +1743,9 @@ def call_torch_export_custom(
print("[call_torch_export_custom] done (export)")

if os_ort:
import onnxscript
import onnxscript.rewriter.ort_fusions as ort_fusions

if verbose:
print("[call_torch_export_custom] conversion to IR...")
begin = time.perf_counter()
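
The new DUMPORTOPT switch (any of "1", "true", "True") makes validate_onnx_model dump the graph onnxruntime actually runs after its own optimizations, next to the exported model. A standalone sketch of the underlying onnxruntime pattern, with placeholder file names:

import onnxruntime

opts = onnxruntime.SessionOptions()
opts.optimized_model_filepath = "model.rtopt.onnx"  # placeholder output path
# Creating the session writes the optimized graph to the path above.
onnxruntime.InferenceSession("model.onnx", opts, providers=["CPUExecutionProvider"])
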
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -28,6 +28,10 @@ disable_error_code = ["arg-type", "assignment", "import-untyped", "misc", "name-
module = ["onnx_diagnostic.helpers.args_helper"]
disable_error_code = ["arg-type", "call-overload", "index"]

[[tool.mypy.overrides]]
module = ["onnx_diagnostic.helpers.cache_helper"]
disable_error_code = ["name-defined"]

[[tool.mypy.overrides]]
module = ["onnx_diagnostic.helpers.helper"]
disable_error_code = ["arg-type", "assignment", "attr-defined", "call-overload", "misc", "name-defined", "union-attr"]
@@ -123,6 +127,7 @@ select = [
"_doc/examples/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"]
"_doc/notebooks/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"]
"_doc/recipes/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"]
"_scripts/compare_model_execution.py" = ["E402", "F401"]
"_doc/technical/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"]
"_unittests/*/test*.py" = ["B008", "B904", "PIE808", "SIM117", "SIM105", "UP008"]
"onnx_diagnostic/export/__init__.py" = ["F401"]
Expand All @@ -131,6 +136,7 @@ select = [
"onnx_diagnostic/reference/torch_ops/__init__.py" = ["F401"]
"onnx_diagnostic/torch_models/hghub/__init__.py" = ["F401"]
"onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py" = ["PIE804"]
"onnx_diagnostic/torch_models/validate.py" = ["E731"]
"onnx_diagnostic/torch_export_patches/__init__.py" = ["F401"]
"onnx_diagnostic/torch_export_patches/patches/__init__.py" = ["F401"]
"onnx_diagnostic/torch_models/llms.py" = ["F401"]