Skip to content

Commit b232b18

Browse files
authored
Normalizes state after every W gate in simulator
1 parent c2e90b3 commit b232b18

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

cirq/google/sim/xmon_stepper.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,13 @@ def simulate_w(self,
313313
# W gate is within a shard.
314314
self._pool.map(_w_within_shard, args)
315315

316+
# Normalize after every w.
317+
norm = np.sum(self._pool.map(_norm, args))
318+
args = self._shard_num_args({
319+
'norm': norm
320+
})
321+
self._pool.map(_renorm, args)
322+
316323
def simulate_measurement(self, index: int) -> bool:
317324
"""Simulates a single qubit measurement in the computational basis.
318325
@@ -507,6 +514,19 @@ def _one_prob_per_shard(args: Dict[str, Any]) -> float:
507514
return norm * norm
508515

509516

517+
def _norm(args: Dict[str, Any]) -> float:
518+
"""Returns the norm for each state shard."""
519+
state = _state_shard(args)
520+
return np.dot(state, np.conjugate(state))
521+
522+
523+
def _renorm(args: Dict[str, Any]):
524+
"""Renormalizes the state using the norm arg."""
525+
state = _state_shard(args)
526+
# If our gate is so bad that we have norm of zero, we have bigger problems.
527+
state /= args['norm']
528+
529+
510530
def _collapse_state(args: Dict[str, Any]):
511531
"""Projects state shards onto the appropriate post measurement state.
512532

cirq/google/sim/xmon_stepper_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,3 +594,33 @@ def test_shard_for_more_prefix_qubits_than_qubits():
594594
expected = np.zeros(2 ** 2, dtype=np.complex64)
595595
expected[0] = 1.0
596596
np.testing.assert_almost_equal(expected, s.current_state)
597+
598+
599+
@pytest.mark.parametrize('num_prefix_qubits', (0, 2))
600+
def test_precision(num_prefix_qubits):
601+
# 16 random W's followed by their inverses on five qubits.
602+
# The floating point epsilon for np.float32 is about 1e-7.
603+
# Floating point epsilon is about 1e-7.
604+
# Each qubits error will add across gates on that qubit, but it is like a
605+
# random walk, so error should be about sqrt(16) per qubit.
606+
# The total error should then be about 2e-6.
607+
with xmon_stepper.Stepper(num_qubits=5,
608+
num_prefix_qubits=num_prefix_qubits,
609+
min_qubits_before_shard=0) as s:
610+
half_turns_list = [np.random.rand() for _ in range(16)]
611+
axis_half_turns_list = [np.random.rand() for _ in range(16)]
612+
613+
for half_turns, axis_half_turns in zip(half_turns_list,
614+
axis_half_turns_list):
615+
for index in range(5):
616+
s.simulate_w(index=index, axis_half_turns=axis_half_turns,
617+
half_turns=half_turns)
618+
for half_turns, axis_half_turns in zip(half_turns_list[::-1],
619+
axis_half_turns_list[::-1]):
620+
for index in range(5):
621+
s.simulate_w(index=index, axis_half_turns=axis_half_turns,
622+
half_turns=-half_turns)
623+
expected = np.zeros(2 ** 5, dtype=np.complex64)
624+
expected[0] = 1.0
625+
# asserts that abs value of arrays is < 1.5 * 10^(-decimal)
626+
np.testing.assert_almost_equal(expected, s.current_state, decimal=6)

0 commit comments

Comments
 (0)