Skip to content

Commit 0f93228

Browse files
committed
Updating application tests
1 parent a29b19b commit 0f93228

File tree

2 files changed

+73
-58
lines changed

2 files changed

+73
-58
lines changed

conin/hmm/hmm_application.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def __init__(self, name="unknown"):
1414
self.name = name
1515

1616
self._hidden_markov_model = None
17+
self._simulations = None
1718

1819
# Applicaton data used to initialize the HMM from simulations
1920
self._transition_prior = (None,) # Nonzero values
@@ -29,6 +30,14 @@ def hidden_markov_model(self):
2930
def hidden_markov_model(self, hidden_markov_model):
3031
self._hidden_markov_model = hidden_markov_model
3132

33+
@property
34+
def simulations(self):
35+
return self._simulations
36+
37+
@simulations.setter
38+
def simulations(self, simulations):
39+
self._simulations = simulations
40+
3241
def create_chmm(self, constraint_type=None):
3342
chmm = ConstrainedHiddenMarkovModel(hmm=self.hidden_markov_model)
3443
if constraint_type == "oracle":
@@ -45,6 +54,7 @@ def initialize(self, *args, **kwargs):
4554
"""
4655
pass
4756

57+
# TODO - return an error if these methods are not defined
4858
def run_simulations(
4959
self, *, num=1, debug=False, with_observations=False, seed=None
5060
):
@@ -54,40 +64,31 @@ def run_simulations(
5464
5565
This method is defined by the application developer, and it provides a
5666
strategy for expressing domain knowledge regarding feasible hidden states.
67+
68+
This method returns the simulations generated
5769
"""
58-
return None
70+
pass
5971

6072
def initialize_hmm_from_simulations(
6173
self,
6274
*,
63-
num=100,
64-
debug=False,
65-
seed=None,
6675
start_tolerance=None,
6776
transition_tolerance=None,
6877
emission_tolerance=None,
69-
simulation_args=None,
78+
simulations=None,
7079
):
7180
assert (
7281
self._hidden_states is not None
7382
), "HMMApplication.create_hmm_from_simulations must be run after the initialize() method is executed"
74-
if simulation_args is None:
75-
simulation_args = {}
76-
simulation_args["num"] = num
77-
simulation_args["debug"] = debug
78-
simulation_args["seed"] = seed
79-
simulation_args["with_observations"] = True
80-
simulations = self.run_simulations(**simulation_args)
81-
if debug:
82-
for sim in simulations:
83-
print("TSIM", sim.observations, sim.hidden)
8483

84+
if simulations is not None:
85+
self.simulations = simulations
8586
assert (
86-
simulations is not None
87-
), f"HMMApplication.create_hmm_from_simulations - Method run_simulations() has not been defined for the {self.name} application"
87+
self.simulations is not None
88+
), f"HMMApplication.create_hmm_from_simulations - No simulations specified"
8889

8990
self.hidden_markov_model = learning.supervised_learning(
90-
simulations=simulations,
91+
simulations=self.simulations,
9192
hidden_states=self._hidden_states,
9293
observable_states=self._observable_states,
9394
start_tolerance=start_tolerance,
@@ -97,8 +98,10 @@ def initialize_hmm_from_simulations(
9798
emission_prior=self._emission_prior,
9899
)
99100

101+
# TODO - return an error if these methods are not defined
100102
def get_oracle_constraints(self):
101103
return []
102104

105+
# TODO - return an error if these methods are not defined
103106
def get_pyomo_constraints(self):
104107
return []

conin/hmm/tests/test_application_chmm.py

Lines changed: 52 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,10 @@ def initialize(
5050
transition_probs=transition_probs,
5151
emission_probs=emission_probs,
5252
)
53-
self.hmm = hmm
53+
self.hidden_markov_model = hmm
5454
else:
55-
self.hmm = hmm
55+
self.hidden_markov_model = hmm
5656
self.num_zeros = num_zeros
57-
self.generate_oracle_constraints()
5857
self._hidden_states = {0, 1}
5958
self._observable_states = {0, 1}
6059
self.time = time
@@ -65,32 +64,36 @@ def run_simulations(
6564
if seed is not None:
6665
random.seed(seed)
6766
output = []
67+
oracle_chmm = self.create_chmm("oracle")
68+
6869
for n in range(num):
6970
res = munch.Munch()
70-
hidden = self.oracle.generate_hidden(self.time)
71+
hidden = oracle_chmm.generate_hidden(self.time)
7172
if with_observations:
72-
observed = self.oracle.generate_observed_from_hidden(hidden)
73+
observed = oracle_chmm.generate_observed_from_hidden(hidden)
7374
res = munch.Munch(hidden=hidden, index=n)
7475
if with_observations:
7576
res.observed = observed
7677
output.append(res)
7778
return output
7879

79-
def generate_oracle_constraints(self):
80+
def get_oracle_constraints(self):
8081
constraint = has_exact_number_of_occurences_constraint(
8182
val=0, count=self.num_zeros
8283
)
83-
self.oracle.set_constraints([constraint])
84+
return [constraint]
8485

85-
def generate_pyomo_constraints(self, *, M):
86-
index_sets = list(M.hmm.x.index_set().subsets())
87-
T = list(index_sets[0])
86+
def get_pyomo_constraints(self):
87+
@pyomo_constraint_fn()
88+
def constraint(M, data):
89+
index_sets = list(M.hmm.x.index_set().subsets())
90+
T = list(index_sets[0])
8891

89-
M.num_zeros = pyo.Constraint(
90-
expr=sum(M.hmm.x[t, 0] for t in T) == self.num_zeros
91-
)
92+
M.num_zeros = pyo.Constraint(
93+
expr=sum(M.hmm.x[t, 0] for t in T) == data.num_zeros
94+
)
9295

93-
return M
96+
return [constraint]
9497

9598

9699
@pytest.fixture
@@ -111,24 +114,29 @@ def app():
111114
return app
112115

113116

114-
class XTest_Application_CHMM:
117+
@pytest.fixture
118+
def oracle_chmm(app):
119+
return app.create_chmm("oracle")
120+
121+
122+
class Test_Application_CHMM:
115123
def test_hmm(self, app):
116-
assert app.hmm.transition_mat == [[0.6, 0.4], [0.4, 0.6]]
117-
assert app.hmm.emission_mat == [[0.7, 0.3], [0.3, 0.7]]
118-
assert app.hmm.start_vec == [0.5, 0.5]
124+
assert app.hidden_markov_model.transition_mat == [[0.6, 0.4], [0.4, 0.6]]
125+
assert app.hidden_markov_model.emission_mat == [[0.7, 0.3], [0.3, 0.7]]
126+
assert app.hidden_markov_model.start_vec == [0.5, 0.5]
119127
assert app._hidden_states == {0, 1}
120128
assert app._observable_states == {0, 1}
121129
assert app.time == 20
122130

123131
def test_oracle_type(self, app):
124-
assert isinstance(app.oracle, type(Oracle_CHMM()))
132+
chmm = app.create_chmm("oracle")
133+
assert isinstance(chmm.chmm, Oracle_CHMM)
134+
assert app.hidden_markov_model == chmm.hidden_markov_model
125135

126136
def test_algebraic_type(self, app):
127-
assert isinstance(app.algebraic, type(PyomoAlgebraic_CHMM()))
128-
129-
def test_hmm_equality(self, app):
130-
assert app.hmm == app.oracle.hmm
131-
assert app.hmm == app.algebraic.hmm
137+
chmm = app.create_chmm("pyomo")
138+
assert isinstance(chmm.chmm, PyomoAlgebraic_CHMM)
139+
assert app.hidden_markov_model == chmm.hidden_markov_model
132140

133141
def test_hmm_equality_setter(self, app):
134142
hmm = HiddenMarkovModel()
@@ -147,45 +155,49 @@ def test_hmm_equality_setter(self, app):
147155
(1, 1): 0.6,
148156
},
149157
)
150-
app.hmm = hmm
151-
assert app.hmm == app.oracle.hmm
152-
assert app.hmm == app.algebraic.hmm
158+
app.hidden_markov_model = hmm
159+
chmm = app.create_chmm("oracle")
160+
assert app.hidden_markov_model == chmm.hidden_markov_model
161+
chmm = app.create_chmm("pyomo")
162+
assert app.hidden_markov_model == chmm.hidden_markov_model
153163

154164
def test_get_internal_hmm(self, app):
155-
assert app.get_internal_hmm().transition_mat == [
165+
repn = app.hidden_markov_model.repn
166+
assert repn.transition_mat == [
156167
[0.6, 0.4],
157168
[0.4, 0.6],
158169
]
159-
assert app.get_internal_hmm().emission_mat == [[0.7, 0.3], [0.3, 0.7]]
160-
assert app.get_internal_hmm().start_vec == [0.5, 0.5]
170+
assert repn.emission_mat == [[0.7, 0.3], [0.3, 0.7]]
171+
assert repn.start_vec == [0.5, 0.5]
161172

162-
# CLM: This is a random test. Is that okay? -- I could also set a seed
163173
def test_run_simulations(self, app):
164174
seed = 1
165175
num_simulations = 5
166176
simulations = app.run_simulations(
167177
num=num_simulations, with_observations=True, seed=seed
168178
)
179+
180+
oracle_chmm = app.create_chmm("oracle")
169181
assert len(simulations) == num_simulations
170182
for i in range(num_simulations):
171183
assert len(simulations[i].hidden) == app.time
172184
assert len(simulations[i].observed) == app.time
173-
assert app.oracle.is_feasible(simulations[i].hidden)
185+
assert oracle_chmm.is_feasible(simulations[i].hidden)
174186

175-
def test_oracle_is_feasible(self, app):
187+
def test_oracle_is_feasible(self, oracle_chmm):
176188
seq1 = [0] * 9
177189
seq2 = [0] * 10
178190
seq3 = [0] * 11
179-
assert not app.oracle.is_feasible(seq1)
180-
assert app.oracle.is_feasible(seq2)
181-
assert not app.oracle.is_feasible(seq3)
191+
assert not oracle_chmm.is_feasible(seq1)
192+
assert oracle_chmm.is_feasible(seq2)
193+
assert not oracle_chmm.is_feasible(seq3)
182194

183195
# This assumes that the internal logic is correct, and is really just
184196
def test_initalize_hmm_from_simulations(self, app):
185-
app.initialize_hmm_from_simulations(num=7)
186-
assert app.hmm == app.oracle.hmm
187-
assert app.hmm == app.algebraic.hmm
188-
assert app.hmm.transition_mat != [
197+
simulations = app.run_simulations(with_observations=True)
198+
app.initialize_hmm_from_simulations(simulations=simulations)
199+
repn = app.hidden_markov_model.repn
200+
assert repn.transition_mat != [
189201
[0.6, 0.4],
190202
[0.4, 0.6],
191203
] # Just checks that it's updated

0 commit comments

Comments
 (0)