diff --git a/onnx_diagnostic/helpers/model_builder_helper.py b/onnx_diagnostic/helpers/model_builder_helper.py index 29839998..c87c28e1 100644 --- a/onnx_diagnostic/helpers/model_builder_helper.py +++ b/onnx_diagnostic/helpers/model_builder_helper.py @@ -220,6 +220,9 @@ def create_model_builder( """ assert cache_dir, "create_model_builder does not work without cache_dir." assert os.path.exists(cache_dir), f"cache_dir={cache_dir!r} does not exists" + precision = {"float32": "fp32", "float16": "fp16", "bfloat16": "bfp16"}.get( + precision, precision + ) download_model_builder_to_cache() builder = import_model_builder() io_dtype = builder.set_io_dtype(precision, execution_provider, extra_options)