diff --git a/glue/sample/src/sinter/_collection/_collection_test.py b/glue/sample/src/sinter/_collection/_collection_test.py index 58ae833e5..19b68bfce 100644 --- a/glue/sample/src/sinter/_collection/_collection_test.py +++ b/glue/sample/src/sinter/_collection/_collection_test.py @@ -1,4 +1,5 @@ import collections +from importlib.metadata import metadata import multiprocessing import pathlib import tempfile @@ -79,6 +80,54 @@ def test_collect(): assert d[0.03].errors <= 70 assert 1 <= d[0.04].errors <= 100 +def test_collect_postselection(): + def postselect_all_detectors_predicate(index: int, metadata: any, coords: tuple) -> bool: + return True + + tasks = [] + for p in [0.01, 0.02, 0.03, 0.04]: + circuit = stim.Circuit.generated( + 'repetition_code:memory', + rounds=3, + distance=3, + after_clifford_depolarization=p, + ) + mask = sinter._collection.post_selection_mask_from_predicate( + circuit_or_dem=circuit, + metadata={}, + postselected_detectors_predicate=postselect_all_detectors_predicate, + ) + tasks.append(sinter.Task( + circuit=circuit, + decoder='pymatching', + postselection_mask=mask, + json_metadata={'p': p}, + collection_options=sinter.CollectionOptions( + max_shots=1000, + max_errors=100, + start_batch_size=100, + max_batch_size=1000, + ), + )) + + results = sinter.collect( + num_workers=2, + tasks=tasks + ) + probabilities = [e.json_metadata['p'] for e in results] + assert len(probabilities) == len(set(probabilities)) + d = {e.json_metadata['p']: e for e in results} + print(d) + assert len(d) == 4 + for k, v in d.items(): + assert v.shots >= 1000 + assert v.errors <= 1 # there is some small probability for undetected logical error + assert d[0.01].discards <= 200 + assert d[0.02].discards <= 300 + assert d[0.03].discards <= 500 + assert 100 <= d[0.04].discards <= 1000 + + def test_collect_from_paths(): with tempfile.TemporaryDirectory() as d: diff --git a/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler.py b/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler.py index ea244b849..91e78fe96 100755 --- a/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler.py +++ b/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler.py @@ -197,8 +197,8 @@ def sample(self, max_shots: int) -> AnonTaskStats: raise ValueError("predictions.dtype != np.uint8") if len(predictions.shape) != 2: raise ValueError("len(predictions.shape) != 2") - if predictions.shape[0] != num_shots: - raise ValueError("predictions.shape[0] != num_shots") + if predictions.shape[0] != num_shots - num_discards_1: + raise ValueError("predictions.shape[0] != num_shots - num_discards_1") if predictions.shape[1] < actual_obs.shape[1]: raise ValueError("predictions.shape[1] < actual_obs.shape[1]") if predictions.shape[1] > actual_obs.shape[1] + 1: