Skip to content

Commit 2d8d61c

Browse files
add stable diffusion sample code for onnxwrapper
1 parent b17534d commit 2d8d61c

File tree

11 files changed

+1682
-2
lines changed

11 files changed

+1682
-2
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ We provide [QAI AppBuilder Launcher](tools/launcher/), enabling you to experienc
120120
### 4. ONNX2BIN
121121
[ONNX2BIN](./tools/convert/onnx2bin/) is a guide to help you convert the ONNX model format into the BIN format optimized for a specific platform.
122122

123-
### 5. ONNXRT2QNNRT
124-
[ONNXRT2QNNRT](./tools/onnxrt2qnnrt/) is a wrapper to run onnx inference code with qnn model, which will switch to qnn runtime automatically.
123+
### 5. ONNXWRAPPER
124+
[ONNXWRAPPER](./tools/onnxwrapper/) is a wrapper that runs ONNX inference code with QNN models, switching to the QNN runtime automatically.
125125

126126
## Models
127127
### Model Hub

script/qai_appbuilder/onnxwrapper.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,26 @@ def Inference(self, input_tensors: List[np.ndarray], input_names: Optional[List[
726726
return final_outs
727727

728728

729+
# -------------------- ORT enum shims --------------------
730+
class GraphOptimizationLevel:
    """Drop-in stand-in for ``onnxruntime.GraphOptimizationLevel``.

    onnxruntime expresses its optimization levels as the integers 0..3, and
    ``SessionOptions.graph_optimization_level`` in this wrapper already holds
    a plain int, so exposing the same names as integer class attributes lets
    existing onnxruntime scripts run unmodified.
    """

    ORT_DISABLE_ALL = 0       # level 0: no graph optimizations
    ORT_ENABLE_BASIC = 1      # level 1: basic optimizations
    ORT_ENABLE_EXTENDED = 2   # level 2: basic + extended optimizations
    ORT_ENABLE_ALL = 3        # level 3: all available optimizations
739+
740+
class ExecutionMode:
    """Drop-in stand-in for ``onnxruntime.ExecutionMode`` (integer values)."""

    ORT_SEQUENTIAL = 0  # matches onnxruntime's sequential execution mode
    ORT_PARALLEL = 1    # matches onnxruntime's parallel execution mode
744+
745+
class SessionOptionsMode:
    """Reserved shim with no members yet; kept so future mode enums can be
    added without breaking imports (forward compatibility)."""

    pass
748+
729749
# -------------------- SessionOptions --------------------
730750
class SessionOptions:
731751
def __init__(self):
@@ -734,6 +754,8 @@ def __init__(self):
734754
self.enable_profiling = False
735755
self.optimized_model_filepath = ""
736756
self.graph_optimization_level = 1
757+
self.enable_mem_pattern = True
758+
self.execution_mode = ExecutionMode.ORT_SEQUENTIAL
737759
self.qnn_runtime = Runtime.HTP
738760
self.qnn_libs_dir = ""
739761
self.qnn_profiling_level = ProfilingLevel.OFF
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# stable_diffusion_v1_5 Sample Code
2+
3+
## Introduction
4+
This is sample code for using AppBuilder to load Stable Diffusion 1.5 QNN models to HTP and execute inference to generate an image.
5+
6+
## Run the sample code
7+
```
8+
If you want to run the sample code with onnx models.
9+
python prepeare_stable_diffusion_onnx_models.py
10+
python stable_diffusion_1_5_onnx_infer.py --model_root models-onnx\modularai_stable-diffusion-1-5-onnx --provider cpu --out sd15_out.png
11+
12+
If you want to run the sample code with qnn models.
13+
python prepeare_stable_diffusion_qnn_models.py
14+
python onnxexec.py stable_diffusion_1_5_onnx_infer.py --model_root models-qnn --vae_scale 1.0
15+
16+
You can also add the following code at the beginning of stable_diffusion_1_5_onnx_infer.py.
17+
from qai_appbuilder import onnxwrapper
18+
Then run the following command.
19+
python stable_diffusion_1_5_onnx_infer.py --model_root models-qnn --vae_scale 1.0
20+
21+
```
22+
## Output
23+
The output image will be saved to sd15_out.png
24+
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
#=============================================================================
2+
#
3+
# Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved.
4+
#
5+
# SPDX-License-Identifier: BSD-3-Clause
6+
#
7+
#=============================================================================
8+
9+
# ---------------------------------------------------------------------
10+
# Download Stable Diffusion 1.5 ONNX models (with optional external data)
11+
#
12+
# Requirements:
13+
# pip install -U huggingface_hub onnx
14+
# ---------------------------------------------------------------------
15+
16+
import argparse
17+
from pathlib import Path
18+
19+
from huggingface_hub import snapshot_download
20+
21+
22+
# Hugging Face Hub repository that hosts the Stable Diffusion 1.5 ONNX export.
REPO_ID = "modularai/stable-diffusion-1.5-onnx"
23+
24+
25+
def _model_requires_external_data(onnx_path: Path) -> tuple[bool, set[str]]:
    """Report whether *onnx_path* depends on ONNX external-data files.

    Returns:
        ``(requires_external_data, locations)`` where ``locations`` is the
        set of external data filenames referenced by the model's tensors.
    """
    import onnx
    from onnx import TensorProto

    # Only the graph metadata is needed here, so skip loading the (possibly
    # very large) external weight payloads.
    model = onnx.load(str(onnx_path), load_external_data=False)

    locations: set[str] = set()

    def record(tensor) -> None:
        # An externally-stored tensor carries a repeated key/value list; the
        # entry with key "location" names the file holding its raw data.
        if not (isinstance(tensor, TensorProto)
                and tensor.data_location == TensorProto.EXTERNAL):
            return
        location = next(
            (kv.value for kv in tensor.external_data if kv.key == "location"),
            None,
        )
        if location:
            locations.add(location)

    # Initializers are where external weights normally live.
    for initializer in model.graph.initializer:
        record(initializer)

    # Constant nodes can also embed tensors (rare but possible).
    for node in model.graph.node:
        if node.op_type != "Constant":
            continue
        for attribute in node.attribute:
            if attribute.type == onnx.AttributeProto.TENSOR:
                record(attribute.t)

    return bool(locations), locations
62+
63+
64+
def download_onnx_models(out_root: Path):
    """Download the SD 1.5 ONNX repo into *out_root* and verify external data.

    Fetches the ONNX files (plus companion ``.onnx_data``/``.json``/``.txt``/
    ``.md`` files) from ``REPO_ID`` via ``snapshot_download``, then scans every
    downloaded ``.onnx`` file and confirms that each external-data file it
    references actually exists next to the model, as the ONNX external data
    format requires.

    Args:
        out_root: Directory the snapshot is downloaded into.

    Raises:
        RuntimeError: If any referenced external data file is missing.
    """
    print(f"[HF] Downloading ONNX models from: {REPO_ID}")
    print(f"[HF] Output directory: {out_root}")

    # Download ONNX + external data files (if they exist)
    allow_patterns = [
        "**/*.onnx",
        "**/*.onnx_data",
        "**/*.json",
        "**/*.txt",
        "**/*.md",
    ]

    snapshot_download(
        repo_id=REPO_ID,
        local_dir=str(out_root),
        local_dir_use_symlinks=False,
        allow_patterns=allow_patterns,
    )
    print("[HF] Download completed")

    # ------------------------------------------------------------
    # Post-check: only require external data when ONNX references it
    # ------------------------------------------------------------
    print("[CHECK] Verifying ONNX external data references...")

    missing = []  # absolute paths of every missing external-data file, all models
    checked = 0

    for onnx_path in sorted(out_root.rglob("*.onnx")):
        checked += 1
        requires_ext, locations = _model_requires_external_data(onnx_path)

        if not requires_ext:
            print(f"[OK] {onnx_path.relative_to(out_root)} (no external data)")
            continue

        # If the model references external data, ensure all referenced files exist.
        model_missing = []  # locations THIS model references but cannot find
        for loc in locations:
            data_path = (onnx_path.parent / loc).resolve()
            if not data_path.exists():
                model_missing.append(loc)
                missing.append(str(data_path))
        if not model_missing:
            locs = ", ".join(sorted(locations))
            print(f"[OK] {onnx_path.relative_to(out_root)} (external data: {locs})")
        else:
            # Bug fix: previously printed ALL referenced locations (as a raw
            # set repr); report only the files that are actually absent.
            locs = ", ".join(sorted(model_missing))
            print(f"[ERROR] {onnx_path.relative_to(out_root)} references missing external data: {locs}")

    print(f"[CHECK] Scanned {checked} ONNX files.")
    if missing:
        raise RuntimeError(
            "Some ONNX models reference external data files that are missing:\n"
            + "\n".join(missing)
            + "\n\nThis is required by ONNX external data format. "
            "Ensure the referenced files exist in the same folder as model.onnx. "
            "See ONNX external data spec for details."
        )

    print("[CHECK] All required external data files are present ✅")
125+
126+
127+
def main():
    """Command-line entry point: parse --out_dir, then download and verify."""
    arg_parser = argparse.ArgumentParser(
        description="Download Stable Diffusion 1.5 ONNX models (with optional external data)"
    )
    arg_parser.add_argument(
        "--out_dir",
        type=str,
        default="models-onnx/modularai_stable-diffusion-1-5-onnx",
        help="Output directory",
    )
    options = arg_parser.parse_args()

    # Resolve to an absolute path and make sure the tree exists up front.
    destination = Path(options.out_dir).resolve()
    destination.mkdir(parents=True, exist_ok=True)

    download_onnx_models(destination)

    print("\n[OK] ONNX model preparation finished.")
    print(f"Model root: {destination}")
146+
147+
148+
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)