diff --git a/glue/sample/src/sinter/_decoding/_decoding.py b/glue/sample/src/sinter/_decoding/_decoding.py index 1e54f87ef..666d3cd7e 100644 --- a/glue/sample/src/sinter/_decoding/_decoding.py +++ b/glue/sample/src/sinter/_decoding/_decoding.py @@ -180,6 +180,14 @@ def sample_decode(*, decoder: The name of the decoder to use. Allowed values are: "pymatching": Use pymatching min-weight-perfect-match decoder. + "correlated_pymatching": + Use two-pass correlated pymatching decoder. + "fusion_blossom": + Use fusion blossom min-weight-perfect-match decoder. + "hypergraph_union_find": + Use weighted hypergraph union-find decoder. + "mw_parity_factor": + Use mwpf min-weight-parity-factor decoder. "internal": Use internal decoder with uncorrelated decoding. "internal_correlated": diff --git a/glue/sample/src/sinter/_decoding/_decoding_all_built_in_decoders.py b/glue/sample/src/sinter/_decoding/_decoding_all_built_in_decoders.py index 92d8d49dd..6cd40d7d8 100644 --- a/glue/sample/src/sinter/_decoding/_decoding_all_built_in_decoders.py +++ b/glue/sample/src/sinter/_decoding/_decoding_all_built_in_decoders.py @@ -1,6 +1,7 @@ from typing import Dict from typing import Union +from sinter._decoding._decoding_correlated_pymatching import CorrelatedPyMatchingDecoder from sinter._decoding._decoding_decoder_class import Decoder from sinter._decoding._decoding_fusion_blossom import FusionBlossomDecoder from sinter._decoding._decoding_pymatching import PyMatchingDecoder @@ -12,6 +13,7 @@ BUILT_IN_DECODERS: Dict[str, Decoder] = { 'vacuous': VacuousDecoder(), 'pymatching': PyMatchingDecoder(), + "correlated_pymatching": CorrelatedPyMatchingDecoder(), 'fusion_blossom': FusionBlossomDecoder(), # an implementation of (weighted) hypergraph UF decoder (https://arxiv.org/abs/2103.08049) 'hypergraph_union_find': HyperUFDecoder(), diff --git a/glue/sample/src/sinter/_decoding/_decoding_correlated_pymatching.py b/glue/sample/src/sinter/_decoding/_decoding_correlated_pymatching.py new file mode 100644 index 000000000..d3536ee99 --- /dev/null +++ b/glue/sample/src/sinter/_decoding/_decoding_correlated_pymatching.py @@ -0,0 +1,106 @@ +from packaging import version + +from sinter._decoding._decoding_decoder_class import Decoder, CompiledDecoder + + +class CorrelatedPyMatchingCompiledDecoder(CompiledDecoder): + def __init__(self, matcher: "pymatching.Matching"): + self.matcher = matcher + + def decode_shots_bit_packed( + self, + *, + bit_packed_detection_event_data: "np.ndarray", + ) -> "np.ndarray": + return self.matcher.decode_batch( + shots=bit_packed_detection_event_data, + bit_packed_shots=True, + bit_packed_predictions=True, + return_weights=False, + ) + + +class CorrelatedPyMatchingDecoder(Decoder): + """Use correlated pymatching to predict observables from detection events.""" + + def compile_decoder_for_dem( + self, *, dem: "stim.DetectorErrorModel" + ) -> CompiledDecoder: + try: + import pymatching + except ImportError as ex: + raise ImportError( + "The decoder 'pymatching' isn't installed\n" + "To fix this, install the python package 'pymatching' into your environment.\n" + "For example, if you are using pip, run `pip install pymatching`.\n" + ) from ex + + # correlated matching requires pymatching 2.3.1 or later + if version.parse(pymatching.__version__) < version.parse("2.3.1"): + raise ValueError(""" +The correlated pymatching decoder requires pymatching 2.3.1 or later. + +If you're using pip to install packages, this can be fixed by running +``` +pip install "pymatching~=2.3.1" --upgrade +``` +""") + + return CorrelatedPyMatchingCompiledDecoder( + pymatching.Matching.from_detector_error_model(dem, enable_correlations=True) + ) + + def decode_via_files( + self, + *, + num_shots: int, + num_dets: int, + num_obs: int, + dem_path: "pathlib.Path", + dets_b8_in_path: "pathlib.Path", + obs_predictions_b8_out_path: "pathlib.Path", + tmp_dir: "pathlib.Path", + ) -> None: + try: + import pymatching + except ImportError as ex: + raise ImportError( + "The decoder 'pymatching' isn't installed\n" + "To fix this, install the python package 'pymatching' into your environment.\n" + "For example, if you are using pip, run `pip install pymatching`.\n" + ) from ex + + # correlated matching requires pymatching 2.3.1 or later + if version.parse(pymatching.__version__) < version.parse("2.3.1"): + raise ValueError(""" +The correlated pymatching decoder requires pymatching 2.3.1 or later. + +If you're using pip to install packages, this can be fixed by running +``` +pip install "pymatching~=2.3.1" --upgrade +``` +""") + + if num_dets == 0: + with open(obs_predictions_b8_out_path, "wb") as f: + f.write(b"\0" * (num_obs * num_shots)) + return + + result = pymatching.cli( + command_line_args=[ + "predict", + "--dem", + str(dem_path), + "--in", + str(dets_b8_in_path), + "--in_format", + "b8", + "--out", + str(obs_predictions_b8_out_path), + "--out_format", + "b8", + "--enable_correlations", + ] + ) + if result: + raise ValueError("pymatching.cli returned a non-zero exit code") diff --git a/glue/sample/src/sinter/_decoding/_decoding_test.py b/glue/sample/src/sinter/_decoding/_decoding_test.py index cd4e28d0d..7e20083ae 100644 --- a/glue/sample/src/sinter/_decoding/_decoding_test.py +++ b/glue/sample/src/sinter/_decoding/_decoding_test.py @@ -22,6 +22,11 @@ def get_test_decoders() -> Tuple[List[str], Dict[str, sinter.Decoder]]: custom_decoders = {} try: import pymatching + from packaging import version + + if version.parse(pymatching.__version__) < version.parse("2.3.0"): + available_decoders.remove('correlated_pymatching') + except ImportError: available_decoders.remove('pymatching') try: @@ -234,7 +239,7 @@ def test_no_detectors_with_post_mask(decoder: str, force_streaming: Optional[boo @pytest.mark.parametrize('decoder,force_streaming', DECODER_CASES) def test_post_selection(decoder: str, force_streaming: Optional[bool]): circuit = stim.Circuit(""" - X_ERROR(0.6) 0 + X_ERROR(0.4) 0 M 0 DETECTOR(2, 0, 0, 1) rec[-1] OBSERVABLE_INCLUDE(0) rec[-1] @@ -243,7 +248,7 @@ def test_post_selection(decoder: str, force_streaming: Optional[bool]): M 1 DETECTOR(1, 0, 0) rec[-1] OBSERVABLE_INCLUDE(0) rec[-1] - + X_ERROR(0.1) 2 M 2 OBSERVABLE_INCLUDE(0) rec[-1] @@ -259,9 +264,9 @@ def test_post_selection(decoder: str, force_streaming: Optional[bool]): __private__unstable__force_decode_on_disk=force_streaming, custom_decoders=TEST_CUSTOM_DECODERS, ) - assert 1050 <= result.discards <= 1350 + assert 650 <= result.discards <= 950 if 'vacuous' not in decoder: - assert 40 <= result.errors <= 160 + assert 60 <= result.errors <= 240 @pytest.mark.parametrize('decoder,force_streaming', DECODER_CASES)