@@ -484,13 +484,18 @@ cdef class StanFit4Model:
484484
485485 # public methods
486486
487- def plot (self , pars = None ):
487+ def plot (self , pars = None , dtypes = None ):
488488 """ Visualize samples from posterior distributions
489489
490490 Parameters
491491 ---------
492492 pars : {str, sequence of str}
493493 parameter name(s); by default use all parameters of interest
494+ dtypes : dict
495+ datatype of parameter(s).
496+ If nothing is passed, np.float will be used for all parameters.
497+ If np.int is specified, the histogram will be visualized, not but
498+ kde.
494499
495500 Note
496501 ----
@@ -501,20 +506,25 @@ cdef class StanFit4Model:
501506 elif isinstance (pars, string_types):
502507 pars = [pars]
503508 pars = pystan.misc._remove_empty_pars(pars, self .sim[' pars_oi' ], self .sim[' dims_oi' ])
504- return pystan.plots.traceplot(self , pars)
509+ return pystan.plots.traceplot(self , pars, dtypes )
505510
506- def traceplot (self , pars = None ):
511+ def traceplot (self , pars = None , dtypes = None ):
507512 """ Visualize samples from posterior distributions
508513
509514 Parameters
510515 ---------
511516 pars : {str, sequence of str}, optional
512517 parameter name(s); by default use all parameters of interest
518+ dtypes : dict
519+ datatype of parameter(s).
520+ If nothing is passed, np.float will be used for all parameters.
521+ If np.int is specified, the histogram will be visualized, not but
522+ kde.
513523 """
514524 # FIXME: for now plot and traceplot do the same thing
515- return self .plot(pars)
525+ return self .plot(pars, dtypes = dtypes )
516526
517- def extract (self , pars = None , permuted = True , inc_warmup = False ):
527+ def extract (self , pars = None , permuted = True , inc_warmup = False , dtypes = None ):
518528 """ Extract samples in different forms for different parameters.
519529
520530 Parameters
@@ -528,6 +538,9 @@ cdef class StanFit4Model:
528538 inc_warmup : bool
529539 If True, warmup samples are kept; otherwise they are
530540 discarded. If `permuted` is True, `inc_warmup` is ignored.
541+ dtypes : dict
542+ datatype of parameter(s).
543+ If nothing is passed, np.float will be used for all parameters.
531544
532545 Returns
533546 -------
@@ -545,12 +558,16 @@ cdef class StanFit4Model:
545558 self ._verify_has_samples()
546559 if inc_warmup is True and permuted is True :
547560 logging.warn(" `inc_warmup` ignored when `permuted` is True." )
561+ if dtypes is None and permuted is False :
562+ logging.warn(" `dtypes` ignored when `permuted` is False." )
548563
549564 if pars is None :
550565 pars = self .sim[' pars_oi' ]
551566 elif isinstance (pars, string_types):
552567 pars = [pars]
553568 pars = pystan.misc._remove_empty_pars(pars, self .sim[' pars_oi' ], self .sim[' dims_oi' ])
569+ if dtypes is None :
570+ dtypes = {}
554571
555572 allpars = self .sim[' pars_oi' ] + self .sim[' fnames_oi' ]
556573 pystan.misc._check_pars(allpars, pars)
@@ -567,7 +584,10 @@ cdef class StanFit4Model:
567584 for par in pars:
568585 sss = [pystan.misc._get_kept_samples(p, self .sim)
569586 for p in tidx[par]]
570- s = {par: np.column_stack(sss)}
587+ ss = np.column_stack(sss)
588+ if par in dtypes.keys():
589+ ss = ss.astype(dtypes[par])
590+ s = {par: ss}
571591 extracted.update(s)
572592 par_idx = self .sim[' pars_oi' ].index(par)
573593 par_dim = self .sim[' dims_oi' ][par_idx]
0 commit comments