Skip to content

Commit 1d81f0b

Browse files
Expose all of the remaining parts of tesseract to python (#33)
with this PR all of tesseract* can be used in python through bazel run/test. I also added string representations to the config classes in the next PR I will add the ability to do `pip install` *all of tesseract except the batching method `decode_shots` because it uses `stim::SparseShot` which stim doesn't export to python --- part of #17
1 parent 492ecc5 commit 1d81f0b

15 files changed

+585
-43
lines changed

src/BUILD

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,15 @@ pybind_library(
6868
name = "tesseract_decoder_pybind",
6969
srcs = [
7070
"common.pybind.h",
71+
"utils.pybind.h",
72+
"simplex.pybind.h",
73+
"tesseract.pybind.h",
7174
],
7275
deps = [
7376
":libcommon",
77+
":libutils",
78+
":libsimplex",
79+
":libtesseract",
7480
"@stim_py//:stim_pybind_lib",
7581
],
7682
)

src/common.pybind.h

Lines changed: 60 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,75 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
115
#ifndef TESSERACT_COMMON_PY_H
216
#define TESSERACT_COMMON_PY_H
317

4-
#include <vector>
5-
618
#include <pybind11/operators.h>
719
#include <pybind11/pybind11.h>
820
#include <pybind11/stl.h>
921

10-
#include "src/stim/dem/dem_instruction.pybind.h"
11-
#include "stim/dem/detector_error_model_target.pybind.h"
22+
#include <vector>
1223

1324
#include "common.h"
25+
#include "src/stim/dem/dem_instruction.pybind.h"
26+
#include "stim/dem/detector_error_model_target.pybind.h"
1427

1528
namespace py = pybind11;
1629

17-
void add_common_module(py::module &root)
18-
{
19-
auto m = root.def_submodule("common", "classes commonly used by the decoder");
20-
21-
py::class_<common::Symptom>(m, "Symptom")
22-
.def(py::init<std::vector<int>, common::ObservablesMask>(),
23-
py::arg("detectors") = std::vector<int>(),
24-
py::arg("observables") = 0)
25-
.def_readwrite("detectors", &common::Symptom::detectors)
26-
.def_readwrite("observables", &common::Symptom::observables)
27-
.def("__str__", &common::Symptom::str)
28-
.def(py::self == py::self)
29-
.def(py::self != py::self)
30-
.def("as_dem_instruction_targets", [](common::Symptom s)
31-
{
32-
std::vector<stim_pybind::ExposedDemTarget> ret;
33-
for(auto & t : s.as_dem_instruction_targets()) ret.emplace_back(t);
34-
return ret; });
35-
36-
py::class_<common::Error>(m, "Error")
37-
.def_readwrite("likelihood_cost", &common::Error::likelihood_cost)
38-
.def_readwrite("probability", &common::Error::probability)
39-
.def_readwrite("symptom", &common::Error::symptom)
40-
.def("__str__", &common::Error::str)
41-
.def(py::init<>())
42-
.def(py::init<double, std::vector<int> &, common::ObservablesMask,
43-
std::vector<bool> &>())
44-
.def(py::init<double, double, std::vector<int> &, common::ObservablesMask,
45-
std::vector<bool> &>())
46-
.def(py::init([](stim_pybind::ExposedDemInstruction edi)
47-
{ return new common::Error(edi.as_dem_instruction()); }));
48-
49-
m.def("merge_identical_errors", &common::merge_identical_errors);
50-
m.def("remove_zero_probability_errors", &common::remove_zero_probability_errors);
51-
m.def("dem_from_counts", &common::dem_from_counts);
30+
void add_common_module(py::module &root) {
31+
auto m = root.def_submodule("common", "classes commonly used by the decoder");
32+
33+
py::class_<common::Symptom>(m, "Symptom")
34+
.def(py::init<std::vector<int>, common::ObservablesMask>(),
35+
py::arg("detectors") = std::vector<int>(),
36+
py::arg("observables") = 0)
37+
.def_readwrite("detectors", &common::Symptom::detectors)
38+
.def_readwrite("observables", &common::Symptom::observables)
39+
.def("__str__", &common::Symptom::str)
40+
.def(py::self == py::self)
41+
.def(py::self != py::self)
42+
.def("as_dem_instruction_targets", [](common::Symptom s) {
43+
std::vector<stim_pybind::ExposedDemTarget> ret;
44+
for (auto &t : s.as_dem_instruction_targets()) ret.emplace_back(t);
45+
return ret;
46+
});
47+
48+
py::class_<common::Error>(m, "Error")
49+
.def_readwrite("likelihood_cost", &common::Error::likelihood_cost)
50+
.def_readwrite("probability", &common::Error::probability)
51+
.def_readwrite("symptom", &common::Error::symptom)
52+
.def("__str__", &common::Error::str)
53+
.def(py::init<>())
54+
.def(py::init<double, std::vector<int> &, common::ObservablesMask,
55+
std::vector<bool> &>(),
56+
py::arg("likelihood_cost"), py::arg("detectors"),
57+
py::arg("observables"), py::arg("dets_array"))
58+
.def(py::init<double, double, std::vector<int> &, common::ObservablesMask,
59+
std::vector<bool> &>(),
60+
py::arg("likelihood_cost"), py::arg("probability"),
61+
py::arg("detectors"), py::arg("observables"), py::arg("dets_array"))
62+
.def(py::init([](stim_pybind::ExposedDemInstruction edi) {
63+
return new common::Error(edi.as_dem_instruction());
64+
}),
65+
py::arg("error"));
66+
67+
m.def("merge_identical_errors", &common::merge_identical_errors,
68+
py::arg("dem"));
69+
m.def("remove_zero_probability_errors",
70+
&common::remove_zero_probability_errors, py::arg("dem"));
71+
m.def("dem_from_counts", &common::dem_from_counts, py::arg("orig_dem"),
72+
py::arg("error_counts"), py::arg("num_shots"));
5273
}
5374

5475
#endif

src/py/BUILD

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,36 @@ py_test(
1111
],
1212
)
1313

14+
py_test(
15+
name = "utils_test",
16+
srcs = ["utils_test.py"],
17+
visibility = ["//:__subpackages__"],
18+
deps = [
19+
"@pypi//pytest",
20+
"//src:lib_tesseract_decoder",
21+
],
22+
)
23+
24+
py_test(
25+
name = "simplex_test",
26+
srcs = ["simplex_test.py"],
27+
visibility = ["//:__subpackages__"],
28+
deps = [
29+
"@pypi//pytest",
30+
"//src:lib_tesseract_decoder",
31+
],
32+
)
33+
34+
py_test(
35+
name = "tesseract_test",
36+
srcs = ["tesseract_test.py"],
37+
visibility = ["//:__subpackages__"],
38+
deps = [
39+
"@pypi//pytest",
40+
"//src:lib_tesseract_decoder",
41+
],
42+
)
43+
1444
compile_pip_requirements(
1545
name = "requirements",
1646
src = "requirements.in",

src/py/common_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http:#www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
115
import pytest
216
import stim
317

src/py/simplex_test.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http:#www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import math
16+
import pytest
17+
import stim
18+
19+
from src import tesseract_decoder
20+
21+
_DETECTOR_ERROR_MODEL = stim.DetectorErrorModel(
22+
"""
23+
error(0.125) D0
24+
error(0.375) D0 D1
25+
error(0.25) D1
26+
"""
27+
)
28+
29+
30+
def test_create_simplex_config():
31+
sc = tesseract_decoder.simplex.SimplexConfig(_DETECTOR_ERROR_MODEL, window_length=5)
32+
assert sc.dem == _DETECTOR_ERROR_MODEL
33+
assert sc.window_length == 5
34+
assert (
35+
str(sc)
36+
== "SimplexConfig(dem=DetectorErrorModel_Object, window_length=5, window_slide_length=0, verbose=0)"
37+
)
38+
39+
40+
def test_create_simplex_decoder():
41+
decoder = tesseract_decoder.simplex.SimplexDecoder(
42+
tesseract_decoder.simplex.SimplexConfig(_DETECTOR_ERROR_MODEL, window_length=5)
43+
)
44+
decoder.init_ilp()
45+
decoder.decode_to_errors([1])
46+
assert decoder.mask_from_errors([1]) == 0
47+
assert decoder.cost_from_errors([2]) == pytest.approx(1.0986123)
48+
assert decoder.decode([1, 2]) == 0
49+
50+
51+
if __name__ == "__main__":
52+
raise SystemExit(pytest.main([__file__]))

src/py/tesseract_test.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http:#www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import math
16+
import pytest
17+
import stim
18+
19+
from src import tesseract_decoder
20+
21+
_DETECTOR_ERROR_MODEL = stim.DetectorErrorModel(
22+
"""
23+
error(0.125) D0
24+
error(0.375) D0 D1
25+
error(0.25) D1
26+
"""
27+
)
28+
29+
30+
def test_create_config():
31+
assert (
32+
str(tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL))
33+
== "TesseractConfig(dem=DetectorErrorModel_Object, det_beam=65535, no_revisit_dets=0, at_most_two_errors_per_detector=0, verbose=0, pqlimit=18446744073709551615, det_orders=[], det_penalty=0)"
34+
)
35+
36+
37+
def test_create_node():
38+
node = tesseract_decoder.tesseract.Node(dets=["a"])
39+
assert node.dets == ["a"]
40+
41+
42+
def test_create_qnode():
43+
qnode = tesseract_decoder.tesseract.QNode(num_dets=5, errs=[42])
44+
assert qnode.num_dets == 5
45+
assert str(qnode) == "QNode(cost=0, num_dets=5, errs=[42])"
46+
47+
48+
def test_create_decoder():
49+
config = tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL)
50+
decoder = tesseract_decoder.tesseract.TesseractDecoder(config)
51+
decoder.decode_to_errors([0])
52+
decoder.decode_to_errors([0], 0)
53+
assert decoder.mask_from_errors([1]) == 0
54+
assert decoder.cost_from_errors([1]) == pytest.approx(1.609438)
55+
assert decoder.decode([0]) == 0
56+
57+
58+
if __name__ == "__main__":
59+
raise SystemExit(pytest.main([__file__]))

src/py/utils_test.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http:#www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import math
16+
import pytest
17+
import stim
18+
19+
from src import tesseract_decoder
20+
21+
22+
_DETECTOR_ERROR_MODEL = stim.DetectorErrorModel(
23+
"""
24+
error(0.125) D0
25+
error(0.375) D0 D1
26+
error(0.25) D1
27+
"""
28+
)
29+
30+
31+
def test_module_has_global_constants():
32+
assert tesseract_decoder.utils.EPSILON <= 1e-7
33+
assert not math.isfinite(tesseract_decoder.utils.INF)
34+
35+
36+
def test_get_detector_coords():
37+
assert tesseract_decoder.utils.get_detector_coords(_DETECTOR_ERROR_MODEL) == []
38+
39+
40+
def test_build_detector_graph():
41+
assert tesseract_decoder.utils.build_detector_graph(_DETECTOR_ERROR_MODEL) == [
42+
[1],
43+
[0],
44+
]
45+
46+
47+
def test_get_errors_from_dem():
48+
expected = "Error{cost=1.945910, symptom=Symptom{D0 }}, Error{cost=0.510826, symptom=Symptom{D0 D1 }}, Error{cost=1.098612, symptom=Symptom{D1 }}"
49+
assert (
50+
", ".join(
51+
map(str, tesseract_decoder.utils.get_errors_from_dem(_DETECTOR_ERROR_MODEL))
52+
)
53+
== expected
54+
)
55+
56+
57+
if __name__ == "__main__":
58+
raise SystemExit(pytest.main([__file__]))

src/simplex.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,17 @@
2121

2222
constexpr size_t T_COORD = 2;
2323

24+
std::string SimplexConfig::str() {
25+
auto & self = *this;
26+
std::stringstream ss;
27+
ss << "SimplexConfig(";
28+
ss << "dem=" << "DetectorErrorModel_Object" << ", ";
29+
ss << "window_length=" << self.window_length << ", ";
30+
ss << "window_slide_length=" << self.window_slide_length << ", ";
31+
ss << "verbose=" << self.verbose << ")";
32+
return ss.str();
33+
}
34+
2435
SimplexDecoder::SimplexDecoder(SimplexConfig _config) : config(_config) {
2536
config.dem = common::remove_zero_probability_errors(config.dem);
2637
std::vector<double> detector_t_coords(config.dem.count_detectors());

src/simplex.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ struct SimplexConfig {
3030
size_t window_slide_length = 0;
3131
bool verbose = false;
3232
bool windowing_enabled() { return (window_length != 0); }
33+
std::string str();
3334
};
3435

3536
struct SimplexDecoder {

0 commit comments

Comments
 (0)