Skip to content

Commit 1dedad3

Browse files
add merge_coint
1 parent 7c4bc5b commit 1dedad3

File tree

3 files changed

+51
-2
lines changed

3 files changed

+51
-2
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
- Add `jax_interface`
88

9+
- Add `merge_count` in `results` module
10+
911
## 1.1.0
1012

1113
### Added

tensorcircuit/results/counts.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
Tensor = Any
1111
ct = Dict[str, int]
1212

13-
# TODO(@refraction-ray): merge_count
14-
1513

1614
def reverse_count(count: ct) -> ct:
1715
ncount = {}
@@ -107,6 +105,22 @@ def expectation(
107105
return r / shots
108106

109107

108+
def merge_count(*counts: ct) -> ct:
109+
"""
110+
Merge multiple count dictionaries by summing up their counts
111+
112+
:param counts: Variable number of count dictionaries
113+
:type counts: ct
114+
:return: Merged count dictionary
115+
:rtype: ct
116+
"""
117+
merged: ct = {}
118+
for count in counts:
119+
for k, v in count.items():
120+
merged[k] = merged.get(k, 0) + v
121+
return merged
122+
123+
110124
def plot_histogram(data: Any, **kws: Any) -> Any:
111125
"""
112126
See ``qiskit.visualization.plot_histogram``:

tests/test_results.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,39 @@ def test_marginal_count():
1515
assert counts.marginal_count(d, [2, 1, 0])["001"] == 4
1616

1717

18+
def test_merge_count():
19+
20+
c1 = {"00": 10, "01": 20, "11": 30}
21+
c2 = {"00": 5, "10": 15, "11": 25}
22+
c3 = {"01": 10, "10": 20}
23+
24+
# Test merging two count dicts
25+
merged = counts.merge_count(c1, c2)
26+
assert merged["00"] == 15
27+
assert merged["01"] == 20
28+
assert merged["10"] == 15
29+
assert merged["11"] == 55
30+
31+
# Test merging three count dicts
32+
merged = counts.merge_count(c1, c2, c3)
33+
assert merged["00"] == 15
34+
assert merged["01"] == 30
35+
assert merged["10"] == 35
36+
assert merged["11"] == 55
37+
38+
# Test merging single count dict
39+
merged = counts.merge_count(c1)
40+
assert merged == c1
41+
42+
# Test merging empty dicts
43+
merged = counts.merge_count({}, {})
44+
assert merged == {}
45+
46+
# Test merging empty with non-empty
47+
merged = counts.merge_count({}, c1)
48+
assert merged == c1
49+
50+
1851
def test_count2vec():
1952
assert counts.vec2count(counts.count2vec(d, normalization=False), prune=True) == d
2053

0 commit comments

Comments
 (0)