@@ -15,17 +15,22 @@ limitations under the License.
1515
1616#include " stablehlo/integrations/python/StablehloApi.h"
1717
18+ #include < stdexcept>
1819#include < string>
1920#include < string_view>
2021
2122#include " llvm/Support/raw_ostream.h"
2223#include " mlir-c/BuiltinAttributes.h"
2324#include " mlir-c/IR.h"
2425#include " mlir-c/Support.h"
25- #include " mlir/Bindings/Python/PybindAdaptors.h"
26+ #include " mlir/Bindings/Python/NanobindAdaptors.h"
27+ #include " nanobind/nanobind.h"
28+ #include " nanobind/stl/string.h"
29+ #include " nanobind/stl/string_view.h"
30+ #include " nanobind/stl/vector.h"
2631#include " stablehlo/integrations/c/StablehloApi.h"
2732
28- namespace py = pybind11 ;
33+ namespace nb = nanobind ;
2934
3035namespace mlir {
3136namespace stablehlo {
@@ -63,14 +68,18 @@ static MlirStringRef toMlirStringRef(std::string_view s) {
6368 return mlirStringRefCreate (s.data (), s.size ());
6469}
6570
66- void AddStablehloApi (py::module &m) {
71+ static MlirStringRef toMlirStringRef (const nb::bytes &s) {
72+ return mlirStringRefCreate (static_cast <const char *>(s.data ()), s.size ());
73+ }
74+
75+ void AddStablehloApi (nb::module_ &m) {
6776 // Portable API is a subset of StableHLO API
6877 AddPortableApi (m);
6978
7079 //
7180 // Utility APIs.
7281 //
73- py ::enum_<MlirStablehloCompatibilityRequirement>(
82+ nb ::enum_<MlirStablehloCompatibilityRequirement>(
7483 m, " StablehloCompatibilityRequirement" )
7584 .value (" NONE" , MlirStablehloCompatibilityRequirement::NONE)
7685 .value (" WEEK_4" , MlirStablehloCompatibilityRequirement::WEEK_4)
@@ -79,48 +88,57 @@ void AddStablehloApi(py::module &m) {
7988
8089 m.def (
8190 " get_version_from_compatibility_requirement" ,
82- [](MlirStablehloCompatibilityRequirement requirement) -> py::str {
91+ [](MlirStablehloCompatibilityRequirement requirement) -> std::string {
8392 StringWriterHelper accumulator;
8493 stablehloVersionFromCompatibilityRequirement (
8594 requirement, accumulator.getMlirStringCallback (),
8695 accumulator.getUserData ());
8796 return accumulator.toString ();
8897 },
89- py ::arg (" requirement" ));
98+ nb ::arg (" requirement" ));
9099
91100 //
92101 // Serialization APIs.
93102 //
94103 m.def (
95104 " serialize_portable_artifact" ,
96- [](MlirModule module , std::string_view target) -> py ::bytes {
105+ [](MlirModule module , std::string_view target) -> nb ::bytes {
97106 StringWriterHelper accumulator;
98107 if (mlirLogicalResultIsFailure (
99108 stablehloSerializePortableArtifactFromModule (
100109 module , toMlirStringRef (target),
101110 accumulator.getMlirStringCallback (),
102111 accumulator.getUserData ()))) {
103- PyErr_SetString (PyExc_ValueError, " failed to serialize module" );
104- return " " ;
112+ throw nb::value_error (" failed to serialize module" );
105113 }
106114
107- return py::bytes (accumulator.toString ());
115+ std::string serialized = accumulator.toString ();
116+ return nb::bytes (serialized.data (), serialized.size ());
108117 },
109- py ::arg (" module" ), py ::arg (" target" ));
118+ nb ::arg (" module" ), nb ::arg (" target" ));
110119
111120 m.def (
112121 " deserialize_portable_artifact" ,
113122 [](MlirContext context, std::string_view artifact) -> MlirModule {
114123 auto module = stablehloDeserializePortableArtifactNoError (
115124 toMlirStringRef (artifact), context);
116125 if (mlirModuleIsNull (module )) {
117- PyErr_SetString (PyExc_ValueError, " failed to deserialize module" );
118- return {};
126+ throw nb::value_error (" failed to deserialize module" );
119127 }
120128 return module ;
121129 },
122- py::arg (" context" ), py::arg (" artifact" ));
123-
130+ nb::arg (" context" ), nb::arg (" artifact" ));
131+ m.def (
132+ " deserialize_portable_artifact" ,
133+ [](MlirContext context, nb::bytes artifact) -> MlirModule {
134+ auto module = stablehloDeserializePortableArtifactNoError (
135+ toMlirStringRef (artifact), context);
136+ if (mlirModuleIsNull (module )) {
137+ throw nb::value_error (" failed to deserialize module" );
138+ }
139+ return module ;
140+ },
141+ nb::arg (" context" ), nb::arg (" artifact" ));
124142 //
125143 // Reference APIs
126144 //
@@ -130,9 +148,7 @@ void AddStablehloApi(py::module &m) {
130148 std::vector<MlirAttribute> &args) -> std::vector<MlirAttribute> {
131149 for (auto arg : args) {
132150 if (!mlirAttributeIsADenseElements (arg)) {
133- PyErr_SetString (PyExc_ValueError,
134- " input args must be DenseElementsAttr" );
135- return {};
151+ throw nb::value_error (" input args must be DenseElementsAttr" );
136152 }
137153 }
138154
@@ -141,8 +157,7 @@ void AddStablehloApi(py::module &m) {
141157 stablehloEvalModule (module , args.size (), args.data (), &errorCode);
142158
143159 if (errorCode != 0 ) {
144- PyErr_SetString (PyExc_ValueError, " interpreter failed" );
145- return {};
160+ throw nb::value_error (" interpreter failed" );
146161 }
147162
148163 std::vector<MlirAttribute> pyResults;
@@ -151,39 +166,39 @@ void AddStablehloApi(py::module &m) {
151166 }
152167 return pyResults;
153168 },
154- py ::arg (" module" ), py ::arg (" args" ));
169+ nb ::arg (" module" ), nb ::arg (" args" ));
155170}
156171
157- void AddPortableApi (py:: module &m) {
172+ void AddPortableApi (nb::module_ &m) {
158173 //
159174 // Utility APIs.
160175 //
161176 m.def (" get_api_version" , []() { return stablehloGetApiVersion (); });
162177
163178 m.def (
164179 " get_smaller_version" ,
165- [](const std::string &version1, const std::string &version2) -> py::str {
180+ [](const std::string &version1,
181+ const std::string &version2) -> std::string {
166182 StringWriterHelper accumulator;
167183 if (mlirLogicalResultIsFailure (stablehloGetSmallerVersion (
168184 toMlirStringRef (version1), toMlirStringRef (version2),
169185 accumulator.getMlirStringCallback (),
170186 accumulator.getUserData ()))) {
171- PyErr_SetString (PyExc_ValueError,
172- " failed to convert version to stablehlo version" );
173- return " " ;
187+ throw nb::value_error (
188+ " failed to convert version to stablehlo version" );
174189 }
175190 return accumulator.toString ();
176191 },
177- py ::arg (" version1" ), py ::arg (" version2" ));
192+ nb ::arg (" version1" ), nb ::arg (" version2" ));
178193
179- m.def (" get_current_version" , []() -> py::str {
194+ m.def (" get_current_version" , []() -> std::string {
180195 StringWriterHelper accumulator;
181196 stablehloGetCurrentVersion (accumulator.getMlirStringCallback (),
182197 accumulator.getUserData ());
183198 return accumulator.toString ();
184199 });
185200
186- m.def (" get_minimum_version" , []() -> py::str {
201+ m.def (" get_minimum_version" , []() -> std::string {
187202 StringWriterHelper accumulator;
188203 stablehloGetMinimumVersion (accumulator.getMlirStringCallback (),
189204 accumulator.getUserData ());
@@ -196,34 +211,64 @@ void AddPortableApi(py::module &m) {
196211 m.def (
197212 " serialize_portable_artifact_str" ,
198213 [](std::string_view moduleStrOrBytecode,
199- std::string_view targetVersion) -> py::bytes {
214+ std::string_view targetVersion) -> nb::bytes {
215+ StringWriterHelper accumulator;
216+ if (mlirLogicalResultIsFailure (
217+ stablehloSerializePortableArtifactFromStringRef (
218+ toMlirStringRef (moduleStrOrBytecode),
219+ toMlirStringRef (targetVersion),
220+ accumulator.getMlirStringCallback (),
221+ accumulator.getUserData ()))) {
222+ throw nb::value_error (" failed to serialize module" );
223+ }
224+ std::string serialized = accumulator.toString ();
225+ return nb::bytes (serialized.data (), serialized.size ());
226+ },
227+ nb::arg (" module_str" ), nb::arg (" target_version" ));
228+ m.def (
229+ " serialize_portable_artifact_str" ,
230+ [](nb::bytes moduleStrOrBytecode,
231+ std::string_view targetVersion) -> nb::bytes {
200232 StringWriterHelper accumulator;
201233 if (mlirLogicalResultIsFailure (
202234 stablehloSerializePortableArtifactFromStringRef (
203235 toMlirStringRef (moduleStrOrBytecode),
204236 toMlirStringRef (targetVersion),
205237 accumulator.getMlirStringCallback (),
206238 accumulator.getUserData ()))) {
207- PyErr_SetString (PyExc_ValueError, " failed to serialize module" );
208- return " " ;
239+ throw nb::value_error (" failed to serialize module" );
209240 }
210- return py::bytes (accumulator.toString ());
241+ std::string serialized = accumulator.toString ();
242+ return nb::bytes (serialized.data (), serialized.size ());
211243 },
212- py ::arg (" module_str" ), py ::arg (" target_version" ));
244+ nb ::arg (" module_str" ), nb ::arg (" target_version" ));
213245
214246 m.def (
215247 " deserialize_portable_artifact_str" ,
216- [](std::string_view artifact) -> py::bytes {
248+ [](std::string_view artifact) -> nb::bytes {
249+ StringWriterHelper accumulator;
250+ if (mlirLogicalResultIsFailure (stablehloDeserializePortableArtifact (
251+ toMlirStringRef (artifact), accumulator.getMlirStringCallback (),
252+ accumulator.getUserData ()))) {
253+ throw nb::value_error (" failed to deserialize module" );
254+ }
255+ std::string serialized = accumulator.toString ();
256+ return nb::bytes (serialized.data (), serialized.size ());
257+ },
258+ nb::arg (" artifact_str" ));
259+ m.def (
260+ " deserialize_portable_artifact_str" ,
261+ [](const nb::bytes& artifact) -> nb::bytes {
217262 StringWriterHelper accumulator;
218263 if (mlirLogicalResultIsFailure (stablehloDeserializePortableArtifact (
219264 toMlirStringRef (artifact), accumulator.getMlirStringCallback (),
220265 accumulator.getUserData ()))) {
221- PyErr_SetString (PyExc_ValueError, " failed to deserialize module" );
222- return " " ;
266+ throw nb::value_error (" failed to deserialize module" );
223267 }
224- return py::bytes (accumulator.toString ());
268+ std::string serialized = accumulator.toString ();
269+ return nb::bytes (serialized.data (), serialized.size ());
225270 },
226- py ::arg (" artifact_str" ));
271+ nb ::arg (" artifact_str" ));
227272}
228273
229274} // namespace stablehlo
0 commit comments