1+ #=============================================================================
2+ #
3+ # Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved.
4+ #
5+ # SPDX-License-Identifier: BSD-3-Clause
6+ #
7+ #=============================================================================
8+
9+ # ---------------------------------------------------------------------
10+ # Download Stable Diffusion 1.5 ONNX models (with optional external data)
11+ #
12+ # Requirements:
13+ # pip install -U huggingface_hub onnx
14+ # ---------------------------------------------------------------------
15+
16+ import argparse
17+ from pathlib import Path
18+
19+ from huggingface_hub import snapshot_download
20+
21+
22+ REPO_ID = "modularai/stable-diffusion-1.5-onnx"
23+
24+
25+ def _model_requires_external_data (onnx_path : Path ) -> tuple [bool , set [str ]]:
26+ """
27+ Return (requires_external_data, locations)
28+ locations: set of external data filenames referenced by the model.
29+ """
30+ import onnx
31+ from onnx import TensorProto
32+
33+ # Load without external data (we only need the metadata that says whether it uses external data)
34+ m = onnx .load (str (onnx_path ), load_external_data = False )
35+
36+ referenced = set ()
37+
38+ def collect_from_tensor (t ):
39+ # TensorProto has data_location + external_data entries when external
40+ if isinstance (t , TensorProto ) and t .data_location == TensorProto .EXTERNAL :
41+ # external_data is a repeated key/value pair list
42+ loc = None
43+ for kv in t .external_data :
44+ if kv .key == "location" :
45+ loc = kv .value
46+ break
47+ if loc :
48+ referenced .add (loc )
49+
50+ # Check initializers
51+ for init in m .graph .initializer :
52+ collect_from_tensor (init )
53+
54+ # Check constant nodes (rare but possible)
55+ for node in m .graph .node :
56+ if node .op_type == "Constant" :
57+ for attr in node .attribute :
58+ if attr .type == onnx .AttributeProto .TENSOR :
59+ collect_from_tensor (attr .t )
60+
61+ return (len (referenced ) > 0 ), referenced
62+
63+
64+ def download_onnx_models (out_root : Path ):
65+ print (f"[HF] Downloading ONNX models from: { REPO_ID } " )
66+ print (f"[HF] Output directory: { out_root } " )
67+
68+ # Download ONNX + external data files (if they exist)
69+ allow_patterns = [
70+ "**/*.onnx" ,
71+ "**/*.onnx_data" ,
72+ "**/*.json" ,
73+ "**/*.txt" ,
74+ "**/*.md" ,
75+ ]
76+
77+ snapshot_download (
78+ repo_id = REPO_ID ,
79+ local_dir = str (out_root ),
80+ local_dir_use_symlinks = False ,
81+ allow_patterns = allow_patterns ,
82+ )
83+ print ("[HF] Download completed" )
84+
85+ # ------------------------------------------------------------
86+ # Post-check: only require external data when ONNX references it
87+ # ------------------------------------------------------------
88+ print ("[CHECK] Verifying ONNX external data references..." )
89+
90+ missing = []
91+ checked = 0
92+
93+ for onnx_path in sorted (out_root .rglob ("*.onnx" )):
94+ checked += 1
95+ requires_ext , locations = _model_requires_external_data (onnx_path )
96+
97+ if not requires_ext :
98+ print (f"[OK] { onnx_path .relative_to (out_root )} (no external data)" )
99+ continue
100+
101+ # If the model references external data, ensure all referenced files exist
102+ all_ok = True
103+ for loc in locations :
104+ data_path = (onnx_path .parent / loc ).resolve ()
105+ if not data_path .exists ():
106+ all_ok = False
107+ missing .append (str (data_path ))
108+ if all_ok :
109+ locs = ", " .join (sorted (locations ))
110+ print (f"[OK] { onnx_path .relative_to (out_root )} (external data: { locs } )" )
111+ else :
112+ print (f"[ERROR] { onnx_path .relative_to (out_root )} references missing external data: { locations } " )
113+
114+ print (f"[CHECK] Scanned { checked } ONNX files." )
115+ if missing :
116+ raise RuntimeError (
117+ "Some ONNX models reference external data files that are missing:\n "
118+ + "\n " .join (missing )
119+ + "\n \n This is required by ONNX external data format. "
120+ "Ensure the referenced files exist in the same folder as model.onnx. "
121+ "See ONNX external data spec for details."
122+ )
123+
124+ print ("[CHECK] All required external data files are present ✅" )
125+
126+
127+ def main ():
128+ parser = argparse .ArgumentParser (
129+ description = "Download Stable Diffusion 1.5 ONNX models (with optional external data)"
130+ )
131+ parser .add_argument (
132+ "--out_dir" ,
133+ type = str ,
134+ default = "models-onnx/modularai_stable-diffusion-1-5-onnx" ,
135+ help = "Output directory" ,
136+ )
137+ args = parser .parse_args ()
138+
139+ out_root = Path (args .out_dir ).resolve ()
140+ out_root .mkdir (parents = True , exist_ok = True )
141+
142+ download_onnx_models (out_root )
143+
144+ print ("\n [OK] ONNX model preparation finished." )
145+ print (f"Model root: { out_root } " )
146+
147+
148+ if __name__ == "__main__" :
149+ main ()
0 commit comments