Skip to content

Commit 492ecc5

Browse files
link tesseract with stim and add tests for the python wrapper of common.h (#32)
some classes in stim are not exposed directly to python (e.g. DemTarget) instead they are exposed through a child class (e.g. ExposedDemTarget) ... this forces us to do the conversion to be able to cross the python/C++ boundary. --- note that This PR doesn't actually link with stim instead the rule "@stim_py//:stim" rebuilds the python wrapper for stim in a way that allows us to use it. this is in order to unblock this project until I figure out how to link to the `stim.so` file generated by https://github.com/quantumlib/Stim/blob/f566b83c5da89ab94d7280178e2bd642350180c3/BUILD#L90 TODO: find the correct way to link to the `stim.so` file. --- part of #17
1 parent 0d24008 commit 492ecc5

File tree

9 files changed

+216
-31
lines changed

9 files changed

+216
-31
lines changed

WORKSPACE

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,11 @@ http_archive(
6565
urls = ["https://github.com/bazelbuild/platforms/archive/refs/tags/0.0.6.zip"],
6666
strip_prefix = "platforms-0.0.6",
6767
)
68+
69+
http_archive(
70+
name = "stim_py",
71+
build_file = "//external:stim_py.BUILD",
72+
sha256 = "95236006859d6754be99629d4fb44788e742e962ac8c59caad421ca088f7350e",
73+
strip_prefix = "stim-1.15.0",
74+
urls = ["https://github.com/quantumlib/Stim/releases/download/v1.15.0/stim-1.15.0.tar.gz"],
75+
)

external/stim_py.BUILD

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
load("@pybind11_bazel//:build_defs.bzl", "pybind_library")
2+
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
3+
4+
SOURCE_FILES_NO_MAIN = glob(
5+
[
6+
"src/**/*.cc",
7+
"src/**/*.h",
8+
"src/**/*.inl",
9+
],
10+
exclude = glob([
11+
"src/**/*.test.cc",
12+
"src/**/*.test.h",
13+
"src/**/*.perf.cc",
14+
"src/**/*.perf.h",
15+
"src/**/*.pybind.cc",
16+
"src/**/*.pybind.h",
17+
"src/**/main.cc",
18+
]),
19+
)
20+
21+
PYBIND_MODULES = [
22+
"src/stim/py/march.pybind.cc",
23+
"src/stim/py/stim.pybind.cc",
24+
]
25+
26+
PYBIND_FILES_WITHOUT_MODULES = glob(
27+
[
28+
"src/**/*.pybind.cc",
29+
"src/**/*.pybind.h",
30+
],
31+
exclude=PYBIND_MODULES,
32+
)
33+
34+
35+
36+
pybind_library(
37+
name = "stim_pybind_lib",
38+
srcs = SOURCE_FILES_NO_MAIN + PYBIND_FILES_WITHOUT_MODULES,
39+
copts = [
40+
"-O3",
41+
"-std=c++20",
42+
"-fvisibility=hidden",
43+
"-march=native",
44+
"-DVERSION_INFO=0.0.dev0",
45+
],
46+
includes = ["src/"],
47+
visibility = ["//visibility:public"],
48+
)
49+
50+
pybind_extension(
51+
name = "stim",
52+
srcs = PYBIND_MODULES,
53+
copts = [
54+
"-O3",
55+
"-std=c++20",
56+
"-fvisibility=hidden",
57+
"-march=native",
58+
"-DSTIM_PYBIND11_MODULE_NAME=stim",
59+
"-DVERSION_INFO=0.0.dev0",
60+
],
61+
deps=[":stim_pybind_lib"],
62+
includes = ["src/"],
63+
visibility = ["//visibility:public"],
64+
)

src/BUILD

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
# load("@benchmark//:benchmark.bzl", "cc_benchmark")
16-
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
16+
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension", "pybind_library")
1717
load("@rules_python//python:defs.bzl", "py_library")
1818

1919
package(default_visibility = ["//visibility:public"])
@@ -64,22 +64,35 @@ cc_library(
6464
)
6565

6666

67+
pybind_library(
68+
name = "tesseract_decoder_pybind",
69+
srcs = [
70+
"common.pybind.h",
71+
],
72+
deps = [
73+
":libcommon",
74+
"@stim_py//:stim_pybind_lib",
75+
],
76+
)
77+
6778
pybind_extension(
6879
name = "tesseract_decoder",
6980
srcs = [
70-
"common.pybind.h",
7181
"tesseract.pybind.cc",
7282
],
7383
deps = [
74-
":libcommon",
84+
":tesseract_decoder_pybind",
85+
"@stim_py//:stim",
7586
],
7687
)
7788

7889

7990
py_library(
8091
name="lib_tesseract_decoder",
81-
data=[":tesseract_decoder"],
8292
imports=["src"],
93+
deps=[
94+
":tesseract_decoder",
95+
],
8396
)
8497

8598

src/common.pybind.h

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,54 @@
11
#ifndef TESSERACT_COMMON_PY_H
22
#define TESSERACT_COMMON_PY_H
33

4+
#include <vector>
5+
46
#include <pybind11/operators.h>
57
#include <pybind11/pybind11.h>
68
#include <pybind11/stl.h>
79

8-
#include <vector>
10+
#include "src/stim/dem/dem_instruction.pybind.h"
11+
#include "stim/dem/detector_error_model_target.pybind.h"
912

1013
#include "common.h"
1114

1215
namespace py = pybind11;
1316

14-
void add_common_module(py::module &root) {
15-
auto m = root.def_submodule("common", "classes commonly used by the decoder");
16-
17-
// TODO: add as_dem_instruction_targets
18-
py::class_<common::Symptom>(m, "Symptom")
19-
.def(py::init<std::vector<int>, common::ObservablesMask>(),
20-
py::arg("detectors") = std::vector<int>(),
21-
py::arg("observables") = 0)
22-
.def_readwrite("detectors", &common::Symptom::detectors)
23-
.def_readwrite("observables", &common::Symptom::observables)
24-
.def("__str__", &common::Symptom::str)
25-
.def(py::self == py::self)
26-
.def(py::self != py::self);
27-
28-
// TODO: add constructor with stim::DemInstruction.
29-
py::class_<common::Error>(m, "Error")
30-
.def_readwrite("likelihood_cost", &common::Error::likelihood_cost)
31-
.def_readwrite("probability", &common::Error::probability)
32-
.def_readwrite("symptom", &common::Error::symptom)
33-
.def("__str__", &common::Error::str)
34-
.def(py::init<>())
35-
.def(py::init<double, std::vector<int> &, common::ObservablesMask,
36-
std::vector<bool> &>())
37-
.def(py::init<double, double, std::vector<int> &, common::ObservablesMask,
38-
std::vector<bool> &>());
17+
void add_common_module(py::module &root)
18+
{
19+
auto m = root.def_submodule("common", "classes commonly used by the decoder");
20+
21+
py::class_<common::Symptom>(m, "Symptom")
22+
.def(py::init<std::vector<int>, common::ObservablesMask>(),
23+
py::arg("detectors") = std::vector<int>(),
24+
py::arg("observables") = 0)
25+
.def_readwrite("detectors", &common::Symptom::detectors)
26+
.def_readwrite("observables", &common::Symptom::observables)
27+
.def("__str__", &common::Symptom::str)
28+
.def(py::self == py::self)
29+
.def(py::self != py::self)
30+
.def("as_dem_instruction_targets", [](common::Symptom s)
31+
{
32+
std::vector<stim_pybind::ExposedDemTarget> ret;
33+
for(auto & t : s.as_dem_instruction_targets()) ret.emplace_back(t);
34+
return ret; });
35+
36+
py::class_<common::Error>(m, "Error")
37+
.def_readwrite("likelihood_cost", &common::Error::likelihood_cost)
38+
.def_readwrite("probability", &common::Error::probability)
39+
.def_readwrite("symptom", &common::Error::symptom)
40+
.def("__str__", &common::Error::str)
41+
.def(py::init<>())
42+
.def(py::init<double, std::vector<int> &, common::ObservablesMask,
43+
std::vector<bool> &>())
44+
.def(py::init<double, double, std::vector<int> &, common::ObservablesMask,
45+
std::vector<bool> &>())
46+
.def(py::init([](stim_pybind::ExposedDemInstruction edi)
47+
{ return new common::Error(edi.as_dem_instruction()); }));
48+
49+
m.def("merge_identical_errors", &common::merge_identical_errors);
50+
m.def("remove_zero_probability_errors", &common::remove_zero_probability_errors);
51+
m.def("dem_from_counts", &common::dem_from_counts);
3952
}
4053

4154
#endif

src/py/BUILD

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
1+
load("@rules_python//python:py_test.bzl", "py_test")
12
load("@rules_python//python:pip.bzl", "compile_pip_requirements")
23

4+
py_test(
5+
name = "common_test",
6+
srcs = ["common_test.py"],
7+
visibility = ["//:__subpackages__"],
8+
deps = [
9+
"@pypi//pytest",
10+
"//src:lib_tesseract_decoder",
11+
],
12+
)
13+
314
compile_pip_requirements(
415
name = "requirements",
516
src = "requirements.in",

src/py/common_test.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import pytest
2+
import stim
3+
4+
from src import tesseract_decoder
5+
6+
7+
def test_as_dem_instruction_targets():
8+
s = tesseract_decoder.common.Symptom([1, 2], 4324)
9+
dits = s.as_dem_instruction_targets()
10+
assert dits == [
11+
stim.DemTarget("D1"),
12+
stim.DemTarget("D2"),
13+
stim.DemTarget("L2"),
14+
stim.DemTarget("L5"),
15+
stim.DemTarget("L6"),
16+
stim.DemTarget("L7"),
17+
stim.DemTarget("L12"),
18+
]
19+
20+
21+
def test_error_from_dem_instruction():
22+
di = stim.DemInstruction("error", [0.125], [stim.target_logical_observable_id(3)])
23+
error = tesseract_decoder.common.Error(di)
24+
25+
assert str(error) == "Error{cost=1.945910, symptom=Symptom{}}"
26+
27+
28+
def test_merge_identical_errors():
29+
dem = stim.DetectorErrorModel()
30+
assert isinstance(
31+
tesseract_decoder.common.merge_identical_errors(dem), stim.DetectorErrorModel
32+
)
33+
34+
35+
def test_remove_zero_probability_errors():
36+
dem = stim.DetectorErrorModel()
37+
assert isinstance(
38+
tesseract_decoder.common.remove_zero_probability_errors(dem),
39+
stim.DetectorErrorModel,
40+
)
41+
42+
43+
def test_dem_from_counts():
44+
dem = stim.DetectorErrorModel()
45+
assert isinstance(
46+
tesseract_decoder.common.dem_from_counts(dem, [], 3), stim.DetectorErrorModel
47+
)
48+
49+
50+
if __name__ == "__main__":
51+
raise SystemExit(pytest.main([__file__]))

src/py/requirements.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
stim
2+
pytest

src/py/requirements_lock.txt

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
#
55
# bazel run //src/py:requirements.update
66
#
7+
iniconfig==2.1.0 \
8+
--hash=sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7 \
9+
--hash=sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760
10+
# via pytest
711
numpy==2.2.6 \
812
--hash=sha256:038613e9fb8c72b0a41f025a7e4c3f0b7a1b5d768ece4796b674c8f3fe13efff \
913
--hash=sha256:0678000bb9ac1475cd454c6b8c799206af8107e310843532b04d49649c717a47 \
@@ -61,6 +65,22 @@ numpy==2.2.6 \
6165
--hash=sha256:fe27749d33bb772c80dcd84ae7e8df2adc920ae8297400dabec45f0dedb3f6de \
6266
--hash=sha256:fee4236c876c4e8369388054d02d0e9bb84821feb1a64dd59e137e6511a551f8
6367
# via stim
68+
packaging==25.0 \
69+
--hash=sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484 \
70+
--hash=sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f
71+
# via pytest
72+
pluggy==1.6.0 \
73+
--hash=sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3 \
74+
--hash=sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746
75+
# via pytest
76+
pygments==2.19.1 \
77+
--hash=sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f \
78+
--hash=sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c
79+
# via pytest
80+
pytest==8.4.0 \
81+
--hash=sha256:14d920b48472ea0dbf68e45b96cd1ffda4705f33307dcc86c676c1b5104838a6 \
82+
--hash=sha256:f40f825768ad76c0977cbacdf1fd37c6f7a468e460ea6a0636078f8972d4517e
83+
# via -r src/py/requirements.in
6484
stim==1.15.0 \
6585
--hash=sha256:0bb3757c69c9b16fd24ff7400b5cddb22017c4cae84fc4b7b73f84373cb03c00 \
6686
--hash=sha256:190c5a3c9cecdfae3302d02057d1ed6d9ce7910d2bcc2ff375807d8f8ec5494d \

src/tesseract.pybind.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,8 @@
33
#include "common.pybind.h"
44
#include "pybind11/detail/common.h"
55

6-
PYBIND11_MODULE(tesseract_py, m) { add_common_module(m); }
6+
PYBIND11_MODULE(tesseract_decoder, m)
7+
{
8+
py::module::import("stim");
9+
add_common_module(m);
10+
}

0 commit comments

Comments
 (0)