Skip to content

Commit 1028ad3

Browse files
authored
Merge pull request #161 from sandialabs/160-remove-fixed-tolerance-in-hiddenmarkovmodelto_dict
Fixing #160
2 parents 0779bf0 + 9a617b8 commit 1028ad3

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

conin/hmm/hmm.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def __str__(self):
245245
"""
246246
Nice printing
247247
"""
248-
return pprint.pformat(self.to_dict(), indent=4, sort_dicts=True)
248+
return pprint.pformat(self.to_dict(tolerance=1e-3), indent=4, sort_dicts=True)
249249

250250
@property
251251
def repn(self):
@@ -464,31 +464,31 @@ def get_emission_probs(self):
464464
for o in range(self.num_observed_states)
465465
}
466466

467-
def to_dict(self):
467+
def to_dict(self, tolerance=0.0):
468468
"""
469469
Generate a dict representation of the model data.
470470
471471
Returns:
472472
dict: A dictionary representaiton of this statistical model.
473473
"""
474474

475-
start_probs = {
476-
self.hidden_to_external[i]: v
475+
start_probs = [
476+
(self.hidden_to_external[i], v)
477477
for i, v in enumerate(self.start_vec)
478-
if v > 1e-3
479-
}
480-
transition_probs = {
481-
(self.hidden_to_external[i], self.hidden_to_external[j]): v
478+
if v > tolerance
479+
]
480+
transition_probs = [
481+
((self.hidden_to_external[i], self.hidden_to_external[j]), v)
482482
for i, row in enumerate(self.transition_mat)
483483
for j, v in enumerate(row)
484-
if v > 1e-3
485-
}
486-
emission_probs = {
487-
(self.hidden_to_external[i], self.observed_to_external[o]): v
484+
if v > tolerance
485+
]
486+
emission_probs = [
487+
((self.hidden_to_external[i], self.observed_to_external[o]), v)
488488
for i, row in enumerate(self.emission_mat)
489489
for o, v in enumerate(row)
490-
if v > 1e-3
491-
}
490+
if v > tolerance
491+
]
492492

493493
return dict(
494494
start_probs=start_probs,

0 commit comments

Comments
 (0)