Skip to content

Commit 150cc59

Browse files
committed
Fix #928:
The `synthesize_dual_basis()` method had a potential `NameError` because it calculated the number of rows after a for loop. If the `dual_basis` was empty, the loop would not run, and a variable used later would be undefined. The fix was to move the line that calculates the number of rows to before the loop, ensuring it is always defined.
1 parent fb7bc39 commit 150cc59

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

src/openfermion/contrib/representability/_multitensor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,15 @@ def synthesize_dual_basis(self):
102102
bias_data_values = []
103103
# this forms the c-vector of ax + b = c
104104
inner_prod_data_values = []
105+
n_rows = len(self.dual_basis.elements)
105106
for index, dual_element in enumerate(self.dual_basis):
106107
dcol, dval = self.synthesize_element(dual_element)
107108
dual_row_indices.extend([index] * len(dcol))
108109
dual_col_indices.extend(dcol)
109110
dual_data_values.extend(dval)
110111
inner_prod_data_values.append(float(dual_element.dual_scalar))
111112
bias_data_values.append(dual_element.constant_bias)
112-
n_rows = len(self.dual_basis.elements)
113+
113114
sparse_dual_operator = csr_matrix(
114115
(dual_data_values, (dual_row_indices, dual_col_indices)), [n_rows, self.vec_dim]
115116
)

src/openfermion/contrib/representability/_multitensor_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,20 @@ def test_cover_make_offset_dict():
185185
c = np.random.random((3, 3))
186186
with pytest.raises(TypeError):
187187
_ = MultiTensor.make_offset_dict([a, b, c])
188+
189+
def test_synthesize_dual_basis_empty():
190+
a = np.random.random((5, 5))
191+
b = np.random.random((4, 4))
192+
c = np.random.random((3, 3))
193+
at = Tensor(tensor=a, name='a')
194+
bt = Tensor(tensor=b, name='b')
195+
ct = Tensor(tensor=c, name='c')
196+
mt = MultiTensor([at, bt, ct], DualBasis(elements=[]))
197+
198+
A, c, b = mt.synthesize_dual_basis()
199+
assert isinstance(A, sp.sparse.csr_matrix)
200+
assert isinstance(c, sp.sparse.csr_matrix)
201+
assert isinstance(b, sp.sparse.csr_matrix)
202+
assert A.shape == (0, 50)
203+
assert b.shape == (0, 1)
204+
assert c.shape == (0, 1)

0 commit comments

Comments
 (0)