@@ -50,11 +50,10 @@ def initialize(
5050 transition_probs = transition_probs ,
5151 emission_probs = emission_probs ,
5252 )
53- self .hmm = hmm
53+ self .hidden_markov_model = hmm
5454 else :
55- self .hmm = hmm
55+ self .hidden_markov_model = hmm
5656 self .num_zeros = num_zeros
57- self .generate_oracle_constraints ()
5857 self ._hidden_states = {0 , 1 }
5958 self ._observable_states = {0 , 1 }
6059 self .time = time
@@ -65,32 +64,36 @@ def run_simulations(
6564 if seed is not None :
6665 random .seed (seed )
6766 output = []
67+ oracle_chmm = self .create_chmm ("oracle" )
68+
6869 for n in range (num ):
6970 res = munch .Munch ()
70- hidden = self . oracle .generate_hidden (self .time )
71+ hidden = oracle_chmm .generate_hidden (self .time )
7172 if with_observations :
72- observed = self . oracle .generate_observed_from_hidden (hidden )
73+ observed = oracle_chmm .generate_observed_from_hidden (hidden )
7374 res = munch .Munch (hidden = hidden , index = n )
7475 if with_observations :
7576 res .observed = observed
7677 output .append (res )
7778 return output
7879
79- def generate_oracle_constraints (self ):
80+ def get_oracle_constraints (self ):
8081 constraint = has_exact_number_of_occurences_constraint (
8182 val = 0 , count = self .num_zeros
8283 )
83- self . oracle . set_constraints ( [constraint ])
84+ return [constraint ]
8485
85- def generate_pyomo_constraints (self , * , M ):
86- index_sets = list (M .hmm .x .index_set ().subsets ())
87- T = list (index_sets [0 ])
86+ def get_pyomo_constraints (self ):
87+ @pyomo_constraint_fn ()
88+ def constraint (M , data ):
89+ index_sets = list (M .hmm .x .index_set ().subsets ())
90+ T = list (index_sets [0 ])
8891
89- M .num_zeros = pyo .Constraint (
90- expr = sum (M .hmm .x [t , 0 ] for t in T ) == self .num_zeros
91- )
92+ M .num_zeros = pyo .Constraint (
93+ expr = sum (M .hmm .x [t , 0 ] for t in T ) == data .num_zeros
94+ )
9295
93- return M
96+ return [ constraint ]
9497
9598
9699@pytest .fixture
@@ -111,24 +114,29 @@ def app():
111114 return app
112115
113116
114- class XTest_Application_CHMM :
117+ @pytest .fixture
118+ def oracle_chmm (app ):
119+ return app .create_chmm ("oracle" )
120+
121+
122+ class Test_Application_CHMM :
115123 def test_hmm (self , app ):
116- assert app .hmm .transition_mat == [[0.6 , 0.4 ], [0.4 , 0.6 ]]
117- assert app .hmm .emission_mat == [[0.7 , 0.3 ], [0.3 , 0.7 ]]
118- assert app .hmm .start_vec == [0.5 , 0.5 ]
124+ assert app .hidden_markov_model .transition_mat == [[0.6 , 0.4 ], [0.4 , 0.6 ]]
125+ assert app .hidden_markov_model .emission_mat == [[0.7 , 0.3 ], [0.3 , 0.7 ]]
126+ assert app .hidden_markov_model .start_vec == [0.5 , 0.5 ]
119127 assert app ._hidden_states == {0 , 1 }
120128 assert app ._observable_states == {0 , 1 }
121129 assert app .time == 20
122130
123131 def test_oracle_type (self , app ):
124- assert isinstance (app .oracle , type (Oracle_CHMM ()))
132+ chmm = app .create_chmm ("oracle" )
133+ assert isinstance (chmm .chmm , Oracle_CHMM )
134+ assert app .hidden_markov_model == chmm .hidden_markov_model
125135
126136 def test_algebraic_type (self , app ):
127- assert isinstance (app .algebraic , type (PyomoAlgebraic_CHMM ()))
128-
129- def test_hmm_equality (self , app ):
130- assert app .hmm == app .oracle .hmm
131- assert app .hmm == app .algebraic .hmm
137+ chmm = app .create_chmm ("pyomo" )
138+ assert isinstance (chmm .chmm , PyomoAlgebraic_CHMM )
139+ assert app .hidden_markov_model == chmm .hidden_markov_model
132140
133141 def test_hmm_equality_setter (self , app ):
134142 hmm = HiddenMarkovModel ()
@@ -147,45 +155,49 @@ def test_hmm_equality_setter(self, app):
147155 (1 , 1 ): 0.6 ,
148156 },
149157 )
150- app .hmm = hmm
151- assert app .hmm == app .oracle .hmm
152- assert app .hmm == app .algebraic .hmm
158+ app .hidden_markov_model = hmm
159+ chmm = app .create_chmm ("oracle" )
160+ assert app .hidden_markov_model == chmm .hidden_markov_model
161+ chmm = app .create_chmm ("pyomo" )
162+ assert app .hidden_markov_model == chmm .hidden_markov_model
153163
154164 def test_get_internal_hmm (self , app ):
155- assert app .get_internal_hmm ().transition_mat == [
165+ repn = app .hidden_markov_model .repn
166+ assert repn .transition_mat == [
156167 [0.6 , 0.4 ],
157168 [0.4 , 0.6 ],
158169 ]
159- assert app . get_internal_hmm () .emission_mat == [[0.7 , 0.3 ], [0.3 , 0.7 ]]
160- assert app . get_internal_hmm () .start_vec == [0.5 , 0.5 ]
170+ assert repn .emission_mat == [[0.7 , 0.3 ], [0.3 , 0.7 ]]
171+ assert repn .start_vec == [0.5 , 0.5 ]
161172
162- # CLM: This is a random test. Is that okay? -- I could also set a seed
163173 def test_run_simulations (self , app ):
164174 seed = 1
165175 num_simulations = 5
166176 simulations = app .run_simulations (
167177 num = num_simulations , with_observations = True , seed = seed
168178 )
179+
180+ oracle_chmm = app .create_chmm ("oracle" )
169181 assert len (simulations ) == num_simulations
170182 for i in range (num_simulations ):
171183 assert len (simulations [i ].hidden ) == app .time
172184 assert len (simulations [i ].observed ) == app .time
173- assert app . oracle .is_feasible (simulations [i ].hidden )
185+ assert oracle_chmm .is_feasible (simulations [i ].hidden )
174186
175- def test_oracle_is_feasible (self , app ):
187+ def test_oracle_is_feasible (self , oracle_chmm ):
176188 seq1 = [0 ] * 9
177189 seq2 = [0 ] * 10
178190 seq3 = [0 ] * 11
179- assert not app . oracle .is_feasible (seq1 )
180- assert app . oracle .is_feasible (seq2 )
181- assert not app . oracle .is_feasible (seq3 )
191+ assert not oracle_chmm .is_feasible (seq1 )
192+ assert oracle_chmm .is_feasible (seq2 )
193+ assert not oracle_chmm .is_feasible (seq3 )
182194
183195 # This assumes that the internal logic is correct, and is really just
184196 def test_initalize_hmm_from_simulations (self , app ):
185- app .initialize_hmm_from_simulations ( num = 7 )
186- assert app .hmm == app . oracle . hmm
187- assert app . hmm == app .algebraic . hmm
188- assert app . hmm .transition_mat != [
197+ simulations = app .run_simulations ( with_observations = True )
198+ app .initialize_hmm_from_simulations ( simulations = simulations )
199+ repn = app .hidden_markov_model . repn
200+ assert repn .transition_mat != [
189201 [0.6 , 0.4 ],
190202 [0.4 , 0.6 ],
191203 ] # Just checks that it's updated
0 commit comments