Skip to content

Commit da42256

Browse files
committed
added input checking in GaussianMixture constructor
1 parent bfff99e commit da42256

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

pyest/gm/gm.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,10 @@ def __init__(self, w, m, cov, cov_type='full', Seig=None):
604604
self._Seig = Seig # directly write Seig here as to not overwrite P
605605
self._set_cov(cov, cov_type)
606606

607+
# check that equal numbers of weights, means, and covariances are provided
608+
if len(self.w) != len(self.m) or len(self.w) != len(self._cov):
609+
raise ValueError("Number of weights, means, and covariances must match.")
610+
607611
self.set_msize(self._cov[0].covariance.shape[-1])
608612

609613
def __getitem__(self, ind):

tests/test_gm.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,5 +525,36 @@ def test_pickle_gm():
525525
assert(p == ptest)
526526

527527

528+
def test_get_comp():
529+
p = gm.defaults.default_gm()
530+
# test with index
531+
comp = p.get_comp(0)
532+
npt.assert_array_equal(comp[0], p.w[0])
533+
npt.assert_array_equal(comp[1], p.m[0])
534+
npt.assert_array_equal(comp[2], p.P[0])
535+
536+
# test with index
537+
comp = p.get_comp(1)
538+
npt.assert_array_equal(comp[0], p.w[1])
539+
npt.assert_array_equal(comp[1], p.m[1])
540+
npt.assert_array_equal(comp[2], p.P[1])
541+
542+
543+
def test_init_mismatch_fail():
544+
545+
# try initializing a GaussianMixture with mismatched weights, means, and covariances
546+
w = np.array([0.4, 0.6])
547+
m = np.array([[30., 0.], [10., np.pi/2]])
548+
P = np.array([[1., 0.5], [0.5, 3.4]])
549+
550+
fail(gm.GaussianMixture, ValueError, w, m, P)
551+
552+
w = np.array([0.4])
553+
m = np.array([[30., 0.], [10., np.pi/2]])
554+
P = np.array([[1., 0.5], [0.5, 3.4]])
555+
556+
fail(gm.GaussianMixture, ValueError, w, m, P)
557+
558+
528559
if __name__ == '__main__':
529560
pytest.main([__file__])

0 commit comments

Comments
 (0)