Skip to content

Commit 8e2781a

Browse files
Merge pull request #274 from MichaelBroughton/exclude_support
Add inclusion/exclusion support on bulksetampl.
2 parents 3824fb2 + f6d1444 commit 8e2781a

File tree

7 files changed

+158
-29
lines changed

7 files changed

+158
-29
lines changed

lib/statespace_avx.h

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -239,22 +239,31 @@ class StateSpaceAVX : public StateSpace<StateSpaceAVX<For>, For, float> {
239239
state.get()[k + 8] = im;
240240
}
241241

242-
// Sets state[i] = val where (i & mask) == bits
242+
// Sets state[i] = complex(re, im) where (i & mask) == bits.
243+
// if `exclude` is true then the criteria becomes (i & mask) != bits.
243244
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits,
244-
const std::complex<fp_type>& val) const {
245-
BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val));
245+
const std::complex<fp_type>& val,
246+
bool exclude = false) const {
247+
BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val), exclude);
246248
}
247249

248-
// Sets state[i] = complex(re, im) where (i & mask) == bits
250+
// Sets state[i] = complex(re, im) where (i & mask) == bits.
251+
// If `exclude` is true, the criterion becomes (i & mask) != bits.
249252
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re,
250-
fp_type im) const {
253+
fp_type im, bool exclude = false) const {
251254
__m256 re_reg = _mm256_set1_ps(re);
252255
__m256 im_reg = _mm256_set1_ps(im);
253256

257+
__m256i exclude_reg = _mm256_setzero_si256();
258+
if (exclude) {
259+
exclude_reg = _mm256_cmpeq_epi32(exclude_reg, exclude_reg);
260+
}
261+
254262
auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv,
255-
uint64_t bitsv, __m256 re_n, __m256 im_n, fp_type* p) {
256-
__m256 ml =
257-
_mm256_castsi256_ps(detail::GetZeroMaskAVX(8 * i, maskv, bitsv));
263+
uint64_t bitsv, __m256 re_n, __m256 im_n, __m256i exclude_n,
264+
fp_type* p) {
265+
__m256 ml = _mm256_castsi256_ps(_mm256_xor_si256(
266+
detail::GetZeroMaskAVX(8 * i, maskv, bitsv), exclude_n));
258267

259268
__m256 re = _mm256_load_ps(p + 16 * i);
260269
__m256 im = _mm256_load_ps(p + 16 * i + 8);
@@ -267,7 +276,7 @@ class StateSpaceAVX : public StateSpace<StateSpaceAVX<For>, For, float> {
267276
};
268277

269278
Base::for_.Run(MinSize(state.num_qubits()) / 16, f, mask, bits, re_reg,
270-
im_reg, state.get());
279+
im_reg, exclude_reg, state.get());
271280
}
272281

273282
// Does the equivalent of dest += src elementwise.

lib/statespace_basic.h

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,26 +96,30 @@ class StateSpaceBasic : public StateSpace<StateSpaceBasic<For, FP>, For, FP> {
9696
state.get()[p + 1] = im;
9797
}
9898

99-
// Sets state[i] = val where (i & mask) == bits
99+
// Sets state[i] = complex(re, im) where (i & mask) == bits.
100+
// if `exclude` is true then the criteria becomes (i & mask) != bits.
100101
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits,
101-
const std::complex<fp_type>& val) const {
102-
BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val));
102+
const std::complex<fp_type>& val,
103+
bool exclude = false) const {
104+
BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val), exclude);
103105
}
104106

105-
// Sets state[i] = complex(re, im) where (i & mask) == bits
107+
// Sets state[i] = complex(re, im) where (i & mask) == bits.
108+
// if `exclude` is true then the criteria becomes (i & mask) != bits.
106109
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re,
107-
fp_type im) const {
110+
fp_type im, bool exclude = false) const {
108111
auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv,
109-
uint64_t bitsv, fp_type re_n, fp_type im_n, fp_type* p) {
112+
uint64_t bitsv, fp_type re_n, fp_type im_n, bool excludev,
113+
fp_type* p) {
110114
auto s = p + 2 * i;
111115
bool in_mask = (i & maskv) == bitsv;
112-
116+
in_mask ^= excludev;
113117
s[0] = in_mask ? re_n : s[0];
114118
s[1] = in_mask ? im_n : s[1];
115119
};
116120

117121
Base::for_.Run(MinSize(state.num_qubits()) / 2, f, mask, bits, re, im,
118-
state.get());
122+
exclude, state.get());
119123
}
120124

121125
// Does the equivalent of dest += src elementwise.

lib/statespace_sse.h

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -200,21 +200,30 @@ class StateSpaceSSE : public StateSpace<StateSpaceSSE<For>, For, float> {
200200
state.get()[p + 4] = im;
201201
}
202202

203-
// Sets state[i] = val where (i & mask) == bits
203+
// Sets state[i] = complex(re, im) where (i & mask) == bits.
204+
// if `exclude` is true then the criteria becomes (i & mask) != bits.
204205
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits,
205-
const std::complex<fp_type>& val) const {
206+
const std::complex<fp_type>& val,
207+
bool exclude = false) const {
206208
BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val));
207209
}
208210

209-
// Sets state[i] = complex(re, im) where (i & mask) == bits
211+
// Sets state[i] = complex(re, im) where (i & mask) == bits.
212+
// If `exclude` is true, the criterion becomes (i & mask) != bits.
210213
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re,
211-
fp_type im) const {
214+
fp_type im, bool exclude = false) const {
212215
__m128 re_reg = _mm_set1_ps(re);
213216
__m128 im_reg = _mm_set1_ps(im);
217+
__m128i exclude_reg = _mm_setzero_si128();
218+
if (exclude) {
219+
exclude_reg = _mm_cmpeq_epi32(exclude_reg, exclude_reg);
220+
}
214221

215222
auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv,
216-
uint64_t bitsv, __m128 re_n, __m128 im_n, fp_type* p) {
217-
__m128 ml = _mm_castsi128_ps(detail::GetZeroMaskSSE(4 * i, maskv, bitsv));
223+
uint64_t bitsv, __m128 re_n, __m128 im_n, __m128i exclude_n,
224+
fp_type* p) {
225+
__m128 ml = _mm_castsi128_ps(_mm_xor_si128(
226+
detail::GetZeroMaskSSE(4 * i, maskv, bitsv), exclude_n));
218227

219228
__m128 re = _mm_load_ps(p + 8 * i);
220229
__m128 im = _mm_load_ps(p + 8 * i + 4);
@@ -227,7 +236,7 @@ class StateSpaceSSE : public StateSpace<StateSpaceSSE<For>, For, float> {
227236
};
228237

229238
Base::for_.Run(MinSize(state.num_qubits()) / 8, f, mask, bits, re_reg,
230-
im_reg, state.get());
239+
im_reg, exclude_reg, state.get());
231240
}
232241

233242
// Does the equivalent of dest += src elementwise.

tests/statespace_avx_test.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,18 @@ TEST(StateSpaceAVXTest, InvalidStateSize) {
6262
TestInvalidStateSize<StateSpaceAVX<For>>();
6363
}
6464

65-
TEST(StateSpaceBasicTest, BulkSetAmpl) {
65+
TEST(StateSpaceAVXTest, BulkSetAmpl) {
6666
TestBulkSetAmplitude<StateSpaceAVX<For>>();
6767
}
6868

69+
TEST(StateSpaceAVXTest, BulkSetAmplExclude) {
70+
TestBulkSetAmplitudeExclusion<StateSpaceAVX<For>>();
71+
}
72+
73+
TEST(StateSpaceAVXTest, BulkSetAmplDefault) {
74+
TestBulkSetAmplitudeDefault<StateSpaceAVX<For>>();
75+
}
76+
6977
} // namespace qsim
7078

7179
int main(int argc, char** argv) {

tests/statespace_basic_test.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,14 @@ TEST(StateSpaceBasicTest, BulkSetAmpl) {
6666
TestBulkSetAmplitude<StateSpaceBasic<For, float>>();
6767
}
6868

69+
TEST(StateSpaceBasicTest, BulkSetAmplExclude) {
70+
TestBulkSetAmplitudeExclusion<StateSpaceBasic<For, float>>();
71+
}
72+
73+
TEST(StateSpaceBasicTest, BulkSetAmplDefault) {
74+
TestBulkSetAmplitudeDefault<StateSpaceBasic<For, float>>();
75+
}
76+
6977
} // namespace qsim
7078

7179
int main(int argc, char** argv) {

tests/statespace_sse_test.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,18 @@ TEST(StateSpaceSSETest, InvalidStateSize) {
6262
TestInvalidStateSize<StateSpaceSSE<For>>();
6363
}
6464

65-
TEST(StateSpaceBasicTest, BulkSetAmpl) {
65+
TEST(StateSpaceSSETest, BulkSetAmpl) {
6666
TestBulkSetAmplitude<StateSpaceSSE<For>>();
6767
}
6868

69+
TEST(StateSpaceSSETest, BulkSetAmplExclude) {
70+
TestBulkSetAmplitudeExclusion<StateSpaceSSE<For>>();
71+
}
72+
73+
TEST(StateSpaceSSETest, BulkSetAmplDefault) {
74+
TestBulkSetAmplitudeDefault<StateSpaceSSE<For>>();
75+
}
76+
6977
} // namespace qsim
7078

7179
int main(int argc, char** argv) {

tests/statespace_testfixture.h

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -820,7 +820,7 @@ void TestBulkSetAmplitude() {
820820
for(int i = 0; i < 8; i++) {
821821
state_space.SetAmpl(state, i, 1, 1);
822822
}
823-
state_space.BulkSetAmpl(state, 1, 0, 0, 0);
823+
state_space.BulkSetAmpl(state, 1, 0, 0, 0, false);
824824
EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex<float>(0, 0));
825825
EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex<float>(1, 1));
826826
EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex<float>(0, 0));
@@ -833,7 +833,7 @@ void TestBulkSetAmplitude() {
833833
for(int i = 0; i < 8; i++) {
834834
state_space.SetAmpl(state, i, 1, 1);
835835
}
836-
state_space.BulkSetAmpl(state, 2, 0, 0, 0);
836+
state_space.BulkSetAmpl(state, 2, 0, 0, 0, false);
837837
EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex<float>(0, 0));
838838
EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex<float>(0, 0));
839839
EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex<float>(1, 1));
@@ -846,7 +846,7 @@ void TestBulkSetAmplitude() {
846846
for(int i = 0; i < 8; i++) {
847847
state_space.SetAmpl(state, i, 1, 1);
848848
}
849-
state_space.BulkSetAmpl(state, 4, 0, 0, 0);
849+
state_space.BulkSetAmpl(state, 4, 0, 0, 0, false);
850850
EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex<float>(0, 0));
851851
EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex<float>(0, 0));
852852
EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex<float>(0, 0));
@@ -856,6 +856,89 @@ void TestBulkSetAmplitude() {
856856
EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex<float>(1, 1));
857857
EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex<float>(1, 1));
858858

859+
for(int i = 0; i < 8; i++) {
860+
state_space.SetAmpl(state, i, 1, 1);
861+
}
862+
state_space.BulkSetAmpl(state, 4 | 1, 4, 0, 0, false);
863+
EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex<float>(1, 1));
864+
EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex<float>(1, 1));
865+
EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex<float>(1, 1));
866+
EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex<float>(1, 1));
867+
EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex<float>(0, 0));
868+
EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex<float>(1, 1));
869+
EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex<float>(0, 0));
870+
EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex<float>(1, 1));
871+
}
872+
// Checks BulkSetAmpl with exclude == true on a 3-qubit (8-amplitude) state.
// For each (mask, bits) case every amplitude is first reset to (1, 1), then
// BulkSetAmpl(state, mask, bits, 0, 0, true) is applied and each amplitude is
// compared against the expected value listed in the table below.
template <typename StateSpace>
void TestBulkSetAmplitudeExclusion() {
  using State = typename StateSpace::State;

  // One scenario: the selection passed to BulkSetAmpl and, per index, whether
  // the amplitude is expected to have been overwritten with (0, 0).
  struct ExclusionCase {
    int mask;
    int bits;
    bool zeroed[8];
  };

  const ExclusionCase cases[] = {
      {1, 0, {false, true, false, true, false, true, false, true}},
      {2, 0, {false, false, true, true, false, false, true, true}},
      {4, 0, {false, false, false, false, true, true, true, true}},
      {4 | 1, 4, {true, true, true, true, false, true, false, true}},
  };

  unsigned num_qubits = 3;

  StateSpace state_space(1);

  State state = state_space.Create(num_qubits);

  for (const ExclusionCase& c : cases) {
    // Start every scenario from the same all-(1, 1) state.
    for (int i = 0; i < 8; i++) {
      state_space.SetAmpl(state, i, 1, 1);
    }

    state_space.BulkSetAmpl(state, c.mask, c.bits, 0, 0, true);

    for (int i = 0; i < 8; i++) {
      const std::complex<float> expected =
          c.zeroed[i] ? std::complex<float>(0, 0) : std::complex<float>(1, 1);
      EXPECT_EQ(state_space.GetAmpl(state, i), expected);
    }
  }
}
933+
934+
template <typename StateSpace>
935+
void TestBulkSetAmplitudeDefault() {
936+
using State = typename StateSpace::State;
937+
unsigned num_qubits = 3;
938+
939+
StateSpace state_space(1);
940+
941+
State state = state_space.Create(num_qubits);
859942
for(int i = 0; i < 8; i++) {
860943
state_space.SetAmpl(state, i, 1, 1);
861944
}

0 commit comments

Comments
 (0)