11/* The MIT License
22
3- Copyright (c) 2014-2017 Genome Research Ltd.
3+ Copyright (c) 2014-2025 Genome Research Ltd.
44
55 Author: Petr Danecek <[email protected] > 66
2424
2525 */
2626
27- #include <stdio.h>
2827#include <stdlib.h>
2928#include <string.h>
3029#include <assert.h>
31- #include <htslib/hts.h>
3230#include "HMM.h"
3331
3432typedef struct
@@ -63,13 +61,42 @@ struct _hmm_t
6361 snapshot_t init , state ; // Initial and current state probs. Set state from snapshot if prev_snap_pos!=0 or from init otherwise
6462 snapshot_t * snapshot ; // snapshot->snap_at_pos .. request a snapshot at this position
6563 // hmm->state.snap_at_pos .. the current state comes from snapshot made at this position
64+ FILE * debug_fh ;
6665};
6766
6867uint8_t * hmm_get_viterbi_path (hmm_t * hmm ) { return hmm -> vpath ; }
6968double * hmm_get_tprob (hmm_t * hmm ) { return hmm -> tprob_arr ; }
7069int hmm_get_nstates (hmm_t * hmm ) { return hmm -> nstates ; }
7170double * hmm_get_fwd_bwd_prob (hmm_t * hmm ) { return hmm -> fwd ; }
7271
72+ int hmm_set (hmm_t * hmm , hmm_opt_t key , ...)
73+ {
74+ va_list args ;
75+ switch (key )
76+ {
77+ case DEBUG :
78+ va_start (args , key );
79+ hmm -> debug_fh = va_arg (args ,FILE * );
80+ va_end (args );
81+ return 0 ;
82+ default :
83+ fprintf (stderr ,"Todo: hmm_set key=%d" ,(int )key );
84+ return -1 ;
85+ break ;
86+ }
87+ return 0 ;
88+ }
89+ void * hmm_get (hmm_t * hmm , hmm_opt_t key , ...)
90+ {
91+ switch (key )
92+ {
93+ case DEBUG : return & hmm -> debug_fh ; break ;
94+ default : fprintf (stderr ,"Todo: hmm_get key=%d" ,(int )key ); return NULL ; break ;
95+ }
96+ return NULL ;
97+ }
98+
99+
73100static inline void multiply_matrix (int n , double * a , double * b , double * dst , double * tmp )
74101{
75102 double * out = dst ;
@@ -107,7 +134,7 @@ void hmm_init_states(hmm_t *hmm, double *probs)
107134 hmm -> state .fwd_prob = (double * ) malloc (sizeof (double )* hmm -> nstates );
108135 if ( !hmm -> state .bwd_prob )
109136 hmm -> state .bwd_prob = (double * ) malloc (sizeof (double )* hmm -> nstates );
110-
137+
111138 int i ;
112139 if ( probs )
113140 {
@@ -119,8 +146,8 @@ void hmm_init_states(hmm_t *hmm, double *probs)
119146 else
120147 for (i = 0 ; i < hmm -> nstates ; i ++ ) hmm -> init .vit_prob [i ] = 1. /hmm -> nstates ;
121148
149+ for (i = 0 ; i < hmm -> nstates ; i ++ ) hmm -> init .bwd_prob [i ] = 1 ;
122150 memcpy (hmm -> init .fwd_prob ,hmm -> init .vit_prob ,sizeof (double )* hmm -> nstates ); // these remain unchanged
123- memcpy (hmm -> init .bwd_prob ,hmm -> init .vit_prob ,sizeof (double )* hmm -> nstates );
124151 memcpy (hmm -> state .vit_prob ,hmm -> init .vit_prob ,sizeof (double )* hmm -> nstates ); // can be changed by snapshotting
125152 memcpy (hmm -> state .fwd_prob ,hmm -> init .fwd_prob ,sizeof (double )* hmm -> nstates );
126153 memcpy (hmm -> state .bwd_prob ,hmm -> init .bwd_prob ,sizeof (double )* hmm -> nstates );
@@ -164,7 +191,7 @@ void *hmm_snapshot(hmm_t *hmm, void *_snapshot, uint32_t pos)
164191void hmm_restore (hmm_t * hmm , void * _snapshot )
165192{
166193 snapshot_t * snapshot = (snapshot_t * ) _snapshot ;
167- if ( !snapshot || !snapshot -> snap_at_pos )
194+ if ( !snapshot || !snapshot -> snap_at_pos )
168195 {
169196 hmm -> state .snap_at_pos = 0 ;
170197 memcpy (hmm -> state .vit_prob ,hmm -> init .vit_prob ,sizeof (double )* hmm -> nstates );
@@ -238,7 +265,7 @@ void hmm_run_viterbi(hmm_t *hmm, int n, double *eprobs, uint32_t *sites)
238265 hmm -> vprob_tmp = (double * ) malloc (sizeof (double )* hmm -> nstates );
239266 }
240267
241- // Init all states with equal likelihood
268+ // Init states
242269 int i ,j , nstates = hmm -> nstates ;
243270 memcpy (hmm -> vprob , hmm -> state .vit_prob , sizeof (* hmm -> state .vit_prob )* nstates );
244271 uint32_t prev_pos = hmm -> state .snap_at_pos ? hmm -> state .snap_at_pos : sites [0 ];
@@ -268,24 +295,33 @@ void hmm_run_viterbi(hmm_t *hmm, int n, double *eprobs, uint32_t *sites)
268295 hmm -> vprob_tmp [j ] = vmax * eprob [j ];
269296 vnorm += hmm -> vprob_tmp [j ];
270297 }
298+
271299 for (j = 0 ; j < nstates ; j ++ ) hmm -> vprob_tmp [j ] /= vnorm ;
272300 double * tmp = hmm -> vprob ; hmm -> vprob = hmm -> vprob_tmp ; hmm -> vprob_tmp = tmp ;
273301
302+ if ( hmm -> debug_fh )
303+ {
304+ fprintf (hmm -> debug_fh ,"viterbi\t%d" ,i );
305+ for (j = 0 ; j < nstates ; j ++ ) fprintf (hmm -> debug_fh ,"\t%f" ,hmm -> vprob [j ]);
306+ fprintf (hmm -> debug_fh ,"\n" );
307+ }
308+
274309 if ( hmm -> snapshot && sites [i ]== hmm -> snapshot -> snap_at_pos )
275310 memcpy (hmm -> snapshot -> vit_prob , hmm -> vprob , sizeof (* hmm -> vprob )* nstates );
276311 }
277312
278313 // Find the most likely state
279314 int iptr = 0 ;
280- for (i = 1 ; i < nstates ; i ++ )
315+ for (i = 1 ; i < nstates ; i ++ )
281316 if ( hmm -> vprob [iptr ] < hmm -> vprob [i ] ) iptr = i ;
282317
283318 // Trace back the Viterbi path, we are reusing vpath for storing the states (vpath[i*nstates])
284319 for (i = n - 1 ; i >=0 ; i -- )
285320 {
286321 assert ( iptr < nstates && hmm -> vpath [i * nstates + iptr ]< nstates );
287- iptr = hmm -> vpath [i * nstates + iptr ];
322+ int iptr_prev = hmm -> vpath [i * nstates + iptr ];
288323 hmm -> vpath [i * nstates ] = iptr ; // reusing the array for different purpose here
324+ iptr = iptr_prev ;
289325 }
290326}
291327
@@ -309,7 +345,7 @@ void hmm_run_fwd_bwd(hmm_t *hmm, int n, double *eprobs, uint32_t *sites)
309345 memcpy (hmm -> bwd , hmm -> state .bwd_prob , sizeof (* hmm -> state .bwd_prob )* nstates );
310346 uint32_t prev_pos = hmm -> state .snap_at_pos ? hmm -> state .snap_at_pos : sites [0 ];
311347
312- // Run fwd
348+ // Run fwd
313349 for (i = 0 ; i < n ; i ++ )
314350 {
315351 double * fwd_prev = & hmm -> fwd [i * nstates ];
@@ -333,6 +369,13 @@ void hmm_run_fwd_bwd(hmm_t *hmm, int n, double *eprobs, uint32_t *sites)
333369 }
334370 for (j = 0 ; j < nstates ; j ++ ) fwd [j ] /= norm ;
335371
372+ if ( hmm -> debug_fh )
373+ {
374+ fprintf (hmm -> debug_fh ,"fwd\t%d" ,i );
375+ for (j = 0 ; j < nstates ; j ++ ) fprintf (hmm -> debug_fh ,"\t%f" ,fwd [j ]);
376+ fprintf (hmm -> debug_fh ,"\n" );
377+ }
378+
336379 if ( hmm -> snapshot && sites [i ]== hmm -> snapshot -> snap_at_pos )
337380 memcpy (hmm -> snapshot -> fwd_prob , fwd , sizeof (* fwd )* nstates );
338381 }
@@ -342,9 +385,9 @@ void hmm_run_fwd_bwd(hmm_t *hmm, int n, double *eprobs, uint32_t *sites)
342385 prev_pos = sites [n - 1 ];
343386 for (i = 0 ; i < n ; i ++ )
344387 {
345- double * fwd = & hmm -> fwd [(n - i )* nstates ];
388+ double * fwd = & hmm -> fwd [(n - i )* nstates ]; // the size of the fwd array is n+1
346389 double * eprob = & eprobs [(n - i - 1 )* nstates ];
347-
390+
348391 int pos_diff = sites [n - i - 1 ] == prev_pos ? 0 : prev_pos - sites [n - i - 1 ] - 1 ;
349392
350393 _set_tprob (hmm , pos_diff );
@@ -364,10 +407,21 @@ void hmm_run_fwd_bwd(hmm_t *hmm, int n, double *eprobs, uint32_t *sites)
364407 for (j = 0 ; j < nstates ; j ++ )
365408 {
366409 bwd_tmp [j ] /= bwd_norm ;
367- fwd [j ] *= bwd_tmp [j ]; // fwd now stores fwd*bwd
410+ fwd [j ] *= bwd [j ]; // fwd now stores fwd*bwd
368411 norm += fwd [j ];
369412 }
370413 for (j = 0 ; j < nstates ; j ++ ) fwd [j ] /= norm ;
414+
415+ if ( hmm -> debug_fh )
416+ {
417+ fprintf (hmm -> debug_fh ,"bwd\t%d" ,n - i - 1 );
418+ for (j = 0 ; j < nstates ; j ++ ) fprintf (hmm -> debug_fh ,"\t%f" ,bwd [j ]);
419+ fprintf (hmm -> debug_fh ,"\n" );
420+
421+ fprintf (hmm -> debug_fh ,"fwd_bwd\t%d" ,i );
422+ for (j = 0 ; j < nstates ; j ++ ) fprintf (hmm -> debug_fh ,"\t%f" ,fwd [j ]);
423+ fprintf (hmm -> debug_fh ,"\n" );
424+ }
371425 double * tmp = bwd_tmp ; bwd_tmp = bwd ; bwd = tmp ;
372426 }
373427}
@@ -397,7 +451,7 @@ double *hmm_run_baum_welch(hmm_t *hmm, int n, double *eprobs, uint32_t *sites)
397451 double * tmp_gamma = (double * ) calloc (nstates ,sizeof (double ));
398452 double * fwd_bwd = (double * ) malloc (sizeof (double )* nstates );
399453
400- // Run fwd
454+ // Run fwd
401455 for (i = 0 ; i < n ; i ++ )
402456 {
403457 double * fwd_prev = & hmm -> fwd [i * nstates ];
@@ -429,7 +483,7 @@ double *hmm_run_baum_welch(hmm_t *hmm, int n, double *eprobs, uint32_t *sites)
429483 {
430484 double * fwd = & hmm -> fwd [(n - i )* nstates ];
431485 double * eprob = & eprobs [(n - i - 1 )* nstates ];
432-
486+
433487 int pos_diff = sites [n - i - 1 ] == prev_pos ? 0 : prev_pos - sites [n - i - 1 ] - 1 ;
434488
435489 _set_tprob (hmm , pos_diff );
@@ -452,7 +506,7 @@ double *hmm_run_baum_welch(hmm_t *hmm, int n, double *eprobs, uint32_t *sites)
452506 fwd_bwd [j ] = fwd [j ]* bwd_tmp [j ];
453507 norm += fwd_bwd [j ];
454508 }
455- for (j = 0 ; j < nstates ; j ++ )
509+ for (j = 0 ; j < nstates ; j ++ )
456510 {
457511 fwd_bwd [j ] /= norm ;
458512 tmp_gamma [j ] += fwd_bwd [j ];
0 commit comments