Skip to content

Commit c9ab4b8

Browse files
authored
Add support for circuit instructions to stim.TableauSimulator.do (#387)
- Add `GateTarget::has_qubit_value` Fixes #285
1 parent 414a8c7 commit c9ab4b8

File tree

10 files changed

+120
-57
lines changed

10 files changed

+120
-57
lines changed

doc/python_api_reference_vDev.md

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8858,27 +8858,16 @@ def cz(
88588858
# stim.TableauSimulator.do
88598859

88608860
# (in class stim.TableauSimulator)
8861-
@overload
8862-
def do(
8863-
self,
8864-
circuit_or_pauli_string: stim.Circuit,
8865-
) -> None:
8866-
pass
8867-
@overload
8868-
def do(
8869-
self,
8870-
circuit_or_pauli_string: stim.PauliString,
8871-
) -> None:
8872-
pass
88738861
def do(
88748862
self,
8875-
circuit_or_pauli_string: object,
8863+
circuit_or_pauli_string: Union[stim.Circuit, stim.PauliString, stim.CircuitInstruction, stim.CircuitRepeatBlock],
88768864
) -> None:
88778865
"""Applies a circuit or pauli string to the simulator's state.
88788866
88798867
Args:
8880-
circuit_or_pauli_string: A stim.Circuit or a stim.PauliString containing
8881-
operations to apply to the simulator's state.
8868+
circuit_or_pauli_string: A stim.Circuit, stim.PauliString,
8869+
stim.CircuitInstruction, or stim.CircuitRepeatBlock
8870+
with operations to apply to the simulator's state.
88828871
88838872
Examples:
88848873
>>> import stim

doc/stim.pyi

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6783,27 +6783,16 @@ class TableauSimulator:
67836783
Applies the gate to the first two targets, then the next two targets,
67846784
and so forth. There must be an even number of targets.
67856785
"""
6786-
@overload
6787-
def do(
6788-
self,
6789-
circuit_or_pauli_string: stim.Circuit,
6790-
) -> None:
6791-
pass
6792-
@overload
6793-
def do(
6794-
self,
6795-
circuit_or_pauli_string: stim.PauliString,
6796-
) -> None:
6797-
pass
67986786
def do(
67996787
self,
6800-
circuit_or_pauli_string: object,
6788+
circuit_or_pauli_string: Union[stim.Circuit, stim.PauliString, stim.CircuitInstruction, stim.CircuitRepeatBlock],
68016789
) -> None:
68026790
"""Applies a circuit or pauli string to the simulator's state.
68036791
68046792
Args:
6805-
circuit_or_pauli_string: A stim.Circuit or a stim.PauliString containing
6806-
operations to apply to the simulator's state.
6793+
circuit_or_pauli_string: A stim.Circuit, stim.PauliString,
6794+
stim.CircuitInstruction, or stim.CircuitRepeatBlock
6795+
with operations to apply to the simulator's state.
68076796
68086797
Examples:
68096798
>>> import stim

glue/python/src/stim/__init__.pyi

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6783,27 +6783,16 @@ class TableauSimulator:
67836783
Applies the gate to the first two targets, then the next two targets,
67846784
and so forth. There must be an even number of targets.
67856785
"""
6786-
@overload
6787-
def do(
6788-
self,
6789-
circuit_or_pauli_string: stim.Circuit,
6790-
) -> None:
6791-
pass
6792-
@overload
6793-
def do(
6794-
self,
6795-
circuit_or_pauli_string: stim.PauliString,
6796-
) -> None:
6797-
pass
67986786
def do(
67996787
self,
6800-
circuit_or_pauli_string: object,
6788+
circuit_or_pauli_string: Union[stim.Circuit, stim.PauliString, stim.CircuitInstruction, stim.CircuitRepeatBlock],
68016789
) -> None:
68026790
"""Applies a circuit or pauli string to the simulator's state.
68036791
68046792
Args:
6805-
circuit_or_pauli_string: A stim.Circuit or a stim.PauliString containing
6806-
operations to apply to the simulator's state.
6793+
circuit_or_pauli_string: A stim.Circuit, stim.PauliString,
6794+
stim.CircuitInstruction, or stim.CircuitRepeatBlock
6795+
with operations to apply to the simulator's state.
68076796
68086797
Examples:
68096798
>>> import stim

src/stim/circuit/gate_target.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ bool GateTarget::is_inverted_result_target() const {
7979
bool GateTarget::is_measurement_record_target() const {
8080
return data & TARGET_RECORD_BIT;
8181
}
82+
bool GateTarget::has_qubit_value() const {
83+
return !(data & (TARGET_RECORD_BIT | TARGET_SWEEP_BIT | TARGET_COMBINER));
84+
}
8285
bool GateTarget::is_qubit_target() const {
8386
return !(data & (TARGET_PAULI_X_BIT | TARGET_PAULI_Z_BIT | TARGET_RECORD_BIT | TARGET_SWEEP_BIT | TARGET_COMBINER));
8487
}

src/stim/circuit/gate_target.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ struct GateTarget {
4343
static GateTarget sweep_bit(uint32_t index);
4444
static GateTarget combiner();
4545

46+
bool has_qubit_value() const;
4647
bool is_combiner() const;
4748
bool is_x_target() const;
4849
bool is_y_target() const;

src/stim/circuit/gate_target.test.cc

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616

1717
#include "gtest/gtest.h"
1818

19-
#include "stim/test_util.test.h"
20-
2119
using namespace stim;
2220

2321
TEST(gate_target, xyz) {
@@ -35,6 +33,8 @@ TEST(gate_target, xyz) {
3533
ASSERT_EQ(t.is_z_target(), false);
3634
ASSERT_EQ(t.str(), "stim.target_x(5)");
3735
ASSERT_EQ(t.value(), 5);
36+
ASSERT_TRUE(t.has_qubit_value());
37+
ASSERT_FALSE(t.is_sweep_bit_target());
3838

3939
t = GateTarget::x(7, true);
4040
ASSERT_EQ(t.is_combiner(), false);
@@ -46,6 +46,8 @@ TEST(gate_target, xyz) {
4646
ASSERT_EQ(t.is_z_target(), false);
4747
ASSERT_EQ(t.str(), "stim.target_x(7, invert=True)");
4848
ASSERT_EQ(t.value(), 7);
49+
ASSERT_TRUE(t.has_qubit_value());
50+
ASSERT_FALSE(t.is_sweep_bit_target());
4951

5052
t = GateTarget::y(11, false);
5153
ASSERT_EQ(t.is_combiner(), false);
@@ -57,6 +59,8 @@ TEST(gate_target, xyz) {
5759
ASSERT_EQ(t.is_z_target(), false);
5860
ASSERT_EQ(t.str(), "stim.target_y(11)");
5961
ASSERT_EQ(t.value(), 11);
62+
ASSERT_TRUE(t.has_qubit_value());
63+
ASSERT_FALSE(t.is_sweep_bit_target());
6064

6165
t = GateTarget::y(13, true);
6266
ASSERT_EQ(t.is_combiner(), false);
@@ -68,6 +72,7 @@ TEST(gate_target, xyz) {
6872
ASSERT_EQ(t.is_z_target(), false);
6973
ASSERT_EQ(t.str(), "stim.target_y(13, invert=True)");
7074
ASSERT_EQ(t.value(), 13);
75+
ASSERT_FALSE(t.is_sweep_bit_target());
7176

7277
t = GateTarget::z(17, false);
7378
ASSERT_EQ(t.is_combiner(), false);
@@ -79,6 +84,8 @@ TEST(gate_target, xyz) {
7984
ASSERT_EQ(t.is_z_target(), true);
8085
ASSERT_EQ(t.str(), "stim.target_z(17)");
8186
ASSERT_EQ(t.value(), 17);
87+
ASSERT_TRUE(t.has_qubit_value());
88+
ASSERT_FALSE(t.is_sweep_bit_target());
8289

8390
t = GateTarget::z(19, true);
8491
ASSERT_EQ(t.is_combiner(), false);
@@ -91,6 +98,8 @@ TEST(gate_target, xyz) {
9198
ASSERT_EQ(t.str(), "stim.target_z(19, invert=True)");
9299
ASSERT_EQ(t.value(), 19);
93100
ASSERT_EQ(t.qubit_value(), 19);
101+
ASSERT_TRUE(t.has_qubit_value());
102+
ASSERT_FALSE(t.is_sweep_bit_target());
94103
}
95104

96105
TEST(gate_target, qubit) {
@@ -107,6 +116,8 @@ TEST(gate_target, qubit) {
107116
ASSERT_EQ(t.str(), "5");
108117
ASSERT_EQ(t.value(), 5);
109118
ASSERT_EQ(t.qubit_value(), 5);
119+
ASSERT_TRUE(t.has_qubit_value());
120+
ASSERT_FALSE(t.is_sweep_bit_target());
110121

111122
t = GateTarget::qubit(7, true);
112123
ASSERT_EQ(t.is_combiner(), false);
@@ -118,6 +129,8 @@ TEST(gate_target, qubit) {
118129
ASSERT_EQ(t.is_z_target(), false);
119130
ASSERT_EQ(t.str(), "stim.target_inv(7)");
120131
ASSERT_EQ(t.value(), 7);
132+
ASSERT_TRUE(t.has_qubit_value());
133+
ASSERT_FALSE(t.is_sweep_bit_target());
121134
}
122135

123136
TEST(gate_target, record) {
@@ -136,6 +149,23 @@ TEST(gate_target, record) {
136149
ASSERT_EQ(t.str(), "stim.target_rec(-5)");
137150
ASSERT_EQ(t.value(), -5);
138151
ASSERT_EQ(t.qubit_value(), 5);
152+
ASSERT_FALSE(t.has_qubit_value());
153+
ASSERT_FALSE(t.is_sweep_bit_target());
154+
}
155+
156+
TEST(gate_target, sweep) {
157+
auto t = GateTarget::sweep_bit(5);
158+
ASSERT_EQ(t.is_combiner(), false);
159+
ASSERT_EQ(t.is_inverted_result_target(), false);
160+
ASSERT_EQ(t.is_measurement_record_target(), false);
161+
ASSERT_EQ(t.is_qubit_target(), false);
162+
ASSERT_EQ(t.is_x_target(), false);
163+
ASSERT_EQ(t.is_y_target(), false);
164+
ASSERT_EQ(t.is_z_target(), false);
165+
ASSERT_EQ(t.str(), "stim.target_sweep_bit(5)");
166+
ASSERT_EQ(t.value(), 5);
167+
ASSERT_FALSE(t.has_qubit_value());
168+
ASSERT_TRUE(t.is_sweep_bit_target());
139169
}
140170

141171
TEST(gate_target, combiner) {
@@ -149,6 +179,8 @@ TEST(gate_target, combiner) {
149179
ASSERT_EQ(t.is_z_target(), false);
150180
ASSERT_EQ(t.str(), "stim.GateTarget.combiner()");
151181
ASSERT_EQ(t.qubit_value(), 0);
182+
ASSERT_FALSE(t.has_qubit_value());
183+
ASSERT_FALSE(t.is_sweep_bit_target());
152184
}
153185

154186
TEST(gate_target, equality) {

src/stim/simulators/tableau_simulator.cc

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -951,11 +951,13 @@ void TableauSimulator::collapse_isolate_qubit_z(size_t target, TableauTransposed
951951
}
952952
}
953953

954-
void TableauSimulator::expand_do_circuit(const Circuit &circuit) {
954+
void TableauSimulator::expand_do_circuit(const Circuit &circuit, uint64_t reps) {
955955
ensure_large_enough_for_qubits(circuit.count_qubits());
956-
circuit.for_each_operation([&](const Operation &op) {
957-
((*this).*op.gate->tableau_simulator_function)(op.target_data);
958-
});
956+
for (uint64_t k = 0; k < reps; k++) {
957+
circuit.for_each_operation([&](const Operation &op) {
958+
((*this).*op.gate->tableau_simulator_function)(op.target_data);
959+
});
960+
}
959961
}
960962

961963
simd_bits<MAX_BITWORD_WIDTH> TableauSimulator::reference_sample_circuit(const Circuit &circuit) {
@@ -969,6 +971,17 @@ void TableauSimulator::paulis(const PauliString &paulis) {
969971
inv_state.xs.signs.word_range_ref(0, nw) ^= paulis.zs;
970972
}
971973

974+
void TableauSimulator::do_operation_ensure_size(const Operation &operation) {
975+
uint64_t n = 0;
976+
for (const auto &t : operation.target_data.targets) {
977+
if (t.has_qubit_value()) {
978+
n = std::max(n, (uint64_t)t.qubit_value() + 1);
979+
}
980+
}
981+
ensure_large_enough_for_qubits(n);
982+
((*this).*operation.gate->tableau_simulator_function)(operation.target_data);
983+
}
984+
972985
void TableauSimulator::set_num_qubits(size_t new_num_qubits) {
973986
if (new_num_qubits >= inv_state.num_qubits) {
974987
ensure_large_enough_for_qubits(new_num_qubits);

src/stim/simulators/tableau_simulator.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ struct TableauSimulator {
143143
/// Runs all of the operations in the given circuit.
144144
///
145145
/// Automatically expands the tableau simulator's state, if needed.
146-
void expand_do_circuit(const Circuit &circuit);
146+
void expand_do_circuit(const Circuit &circuit, uint64_t reps = 1);
147+
void do_operation_ensure_size(const Operation &operation);
147148

148149
void apply_tableau(const Tableau &tableau, const std::vector<size_t> &targets);
149150

src/stim/simulators/tableau_simulator.pybind.cc

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,36 @@
1919
#include "stim/simulators/tableau_simulator.h"
2020
#include "stim/stabilizers/conversions.h"
2121
#include "stim/stabilizers/pauli_string.pybind.h"
22+
#include "stim/circuit/circuit_instruction.pybind.h"
2223
#include "stim/stabilizers/tableau.h"
24+
#include "stim/circuit/circuit_repeat_block.pybind.h"
2325

2426
using namespace stim;
2527
using namespace stim_pybind;
2628

29+
void do_circuit_instruction(TableauSimulator &self, const CircuitInstruction &circuit_instruction) {
30+
self.do_operation_ensure_size(Operation{
31+
&circuit_instruction.gate,
32+
{
33+
circuit_instruction.gate_args,
34+
circuit_instruction.targets,
35+
}
36+
});
37+
}
38+
2739
void do_obj(TableauSimulator &self, const pybind11::object &obj) {
2840
if (pybind11::isinstance<Circuit>(obj)) {
2941
self.expand_do_circuit(pybind11::cast<Circuit>(obj));
42+
} else if (pybind11::isinstance<CircuitRepeatBlock>(obj)) {
43+
const CircuitRepeatBlock &block = pybind11::cast<CircuitRepeatBlock>(obj);
44+
self.expand_do_circuit(block.body, block.repeat_count);
3045
} else if (pybind11::isinstance<PyPauliString>(obj)) {
3146
const PyPauliString &pauli_string = pybind11::cast<PyPauliString>(obj);
3247
self.ensure_large_enough_for_qubits(pauli_string.value.num_qubits);
3348
self.paulis(pauli_string.value);
49+
} else if (pybind11::isinstance<CircuitInstruction>(obj)) {
50+
const CircuitInstruction &circuit_instruction = pybind11::cast<CircuitInstruction>(obj);
51+
do_circuit_instruction(self, circuit_instruction);
3452
} else {
3553
std::stringstream ss;
3654
ss << "Don't know how to handle ";
@@ -413,12 +431,12 @@ void stim_pybind::pybind_tableau_simulator_methods(pybind11::module &m, pybind11
413431
pybind11::arg("circuit_or_pauli_string"),
414432
clean_doc_string(u8R"DOC(
415433
Applies a circuit or pauli string to the simulator's state.
416-
@overload def do(self, circuit_or_pauli_string: stim.Circuit) -> None:
417-
@overload def do(self, circuit_or_pauli_string: stim.PauliString) -> None:
434+
@signature def do(self, circuit_or_pauli_string: Union[stim.Circuit, stim.PauliString, stim.CircuitInstruction, stim.CircuitRepeatBlock]) -> None:
418435
419436
Args:
420-
circuit_or_pauli_string: A stim.Circuit or a stim.PauliString containing
421-
operations to apply to the simulator's state.
437+
circuit_or_pauli_string: A stim.Circuit, stim.PauliString,
438+
stim.CircuitInstruction, or stim.CircuitRepeatBlock
439+
with operations to apply to the simulator's state.
422440
423441
Examples:
424442
>>> import stim
@@ -461,7 +479,9 @@ void stim_pybind::pybind_tableau_simulator_methods(pybind11::module &m, pybind11
461479

462480
c.def(
463481
"do_circuit",
464-
&TableauSimulator::expand_do_circuit,
482+
[](TableauSimulator &self, const Circuit &circuit) {
483+
self.expand_do_circuit(circuit);
484+
},
465485
pybind11::arg("circuit"),
466486
clean_doc_string(u8R"DOC(
467487
Applies a circuit to the simulator's state.

src/stim/simulators/tableau_simulator_pybind_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,3 +594,29 @@ def test_copy_with_explicit_copy_rng_and_seed():
594594
s = stim.TableauSimulator()
595595
with pytest.raises(ValueError, match='seed and copy_rng are incompatible'):
596596
_ = s.copy(copy_rng=True, seed=0)
597+
598+
599+
def test_do_circuit_instruction():
600+
s = stim.TableauSimulator()
601+
assert s.peek_z(0) == +1
602+
s.do(stim.Circuit("X 0")[0])
603+
assert s.peek_z(0) == -1
604+
605+
s.do(stim.Circuit("""
606+
REPEAT 100 {
607+
CNOT 0 1 1 2 2 3 3 4 4 5 5 6 6 7 7 0
608+
}
609+
""")[0])
610+
assert s.peek_z(0) == +1
611+
assert s.peek_z(1) == +1
612+
assert s.peek_z(2) == +1
613+
assert s.peek_z(3) == -1
614+
assert s.peek_z(4) == +1
615+
assert s.peek_z(5) == -1
616+
assert s.peek_z(6) == +1
617+
assert s.peek_z(7) == +1
618+
619+
s.do(stim.Circuit("X 500")[0])
620+
assert s.peek_z(499) == +1
621+
assert s.peek_z(500) == -1
622+
assert s.peek_z(501) == +1

0 commit comments

Comments
 (0)