Skip to content

Commit 6b1736c

Browse files
committed
Add tests for new parsing functions
1 parent 5c792aa commit 6b1736c

File tree

1 file changed

+363
-0
lines changed

1 file changed

+363
-0
lines changed

test/test_stancsv.py

Lines changed: 363 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""testing stancsv parsing"""
22

3+
import io
34
import os
45
from pathlib import Path
56
from test import without_import
@@ -352,6 +353,17 @@ def test_config_parsing():
352353
assert config == expected
353354

354355

356+
def test_config_parsing_data_transforms():
357+
comments = [
358+
b"# bool_t = true\n",
359+
b"# bool_f = false\n",
360+
b"# float = 1.5\n",
361+
b"# int = 1\n",
362+
]
363+
expected = {"bool_t": 1, "bool_f": 0, "float": 1.5, "int": 1}
364+
assert stancsv.parse_config(comments) == expected
365+
366+
355367
def test_extract_header_line():
356368
assert stancsv.extract_header_line([b"a,b\n", b"1,2\n"]) == "a,b"
357369
with pytest.raises(ValueError):
@@ -389,3 +401,354 @@ def test_column_filter_non_consecutive_indexes():
389401
b"7,9\n",
390402
b"3,5\n",
391403
]
404+
405+
406+
def test_parse_header():
407+
header = (
408+
"lp__,accept_stat__,stepsize__,treedepth__"
409+
",n_leapfrog__,divergent__,energy__,theta.1"
410+
)
411+
parsed = stancsv.parse_header(header)
412+
expected = (
413+
"lp__",
414+
"accept_stat__",
415+
"stepsize__",
416+
"treedepth__",
417+
"n_leapfrog__",
418+
"divergent__",
419+
"energy__",
420+
"theta[1]",
421+
)
422+
assert parsed == expected
423+
424+
425+
def test_extract_config_and_header_info():
426+
comments = [b"# stan_version_major = 2\n"]
427+
draws = [b"lp__,theta.1\n"]
428+
out = stancsv.extract_config_and_header_info(comments, draws)
429+
assert out["stan_version_major"] == 2
430+
assert out["raw_header"] == "lp__,theta.1"
431+
assert out["column_names"] == ("lp__", "theta[1]")
432+
433+
434+
def test_parse_variational_eta():
435+
csv_path = os.path.join(DATAFILES_PATH, "variational", "eta_big_output.csv")
436+
comments, _ = stancsv.parse_stan_csv_comments_and_draws(csv_path)
437+
eta = stancsv.parse_variational_eta(comments)
438+
assert eta == 100.0
439+
440+
441+
def test_parse_variational_eta_no_block():
442+
comments = [
443+
b"# stanc_version = stanc3 v2.28.0\n",
444+
b"# stancflags = \n",
445+
b"lp__,log_p__,log_g__,mu.1,mu.2\n",
446+
b"0,0,0,311.545,532.801\n",
447+
b"0,-186118,-4.74553,311.545,353.503\n",
448+
b"0,-184982,-2.75303,311.545,587.377\n",
449+
]
450+
451+
with pytest.raises(ValueError):
452+
stancsv.parse_variational_eta(comments)
453+
454+
455+
def test_max_treedepth_and_divergence_counts():
456+
draws = [
457+
(
458+
b"lp__,accept_stat__,stepsize__,treedepth__,"
459+
b"n_leapfrog__,divergent__,energy__,theta\n"
460+
),
461+
b"-4.78686,0.986298,1.09169,1,3,0,5.29492,0.550024\n",
462+
b"-5.07942,0.676947,1.09169,10,3,0,6.44279,0.709113\n",
463+
b"-5.04922,1,1.09169,1,1,0,5.14176,0.702445\n",
464+
b"-5.09338,0.996111,1.09169,10,3,1,5.16083,0.712059\n",
465+
b"-4.78903,0.989798,1.09169,1,3,0,5.08116,0.546685\n",
466+
b"-5.36502,0.854345,1.09169,1,3,0,5.39311,0.369686\n",
467+
b"-5.13605,0.937837,1.09169,1,3,0,5.95811,0.720607\n",
468+
b"-4.80646,1,1.09169,2,3,0,5.0962,0.528418\n",
469+
]
470+
out = stancsv.extract_max_treedepth_and_divergence_counts(draws, 10, 0)
471+
assert out == (2, 1)
472+
473+
474+
def test_max_treedepth_and_divergence_counts_warmup_draws():
475+
draws = [
476+
(
477+
b"lp__,accept_stat__,stepsize__,treedepth__,"
478+
b"n_leapfrog__,divergent__,energy__,theta\n"
479+
),
480+
b"-4.78686,0.986298,1.09169,1,3,0,5.29492,0.550024\n",
481+
b"-5.07942,0.676947,1.09169,10,3,0,6.44279,0.709113\n",
482+
b"-5.04922,1,1.09169,1,1,0,5.14176,0.702445\n",
483+
b"-5.09338,0.996111,1.09169,10,3,1,5.16083,0.712059\n",
484+
b"-4.78903,0.989798,1.09169,1,3,0,5.08116,0.546685\n",
485+
b"-5.36502,0.854345,1.09169,1,3,0,5.39311,0.369686\n",
486+
b"-5.13605,0.937837,1.09169,1,3,0,5.95811,0.720607\n",
487+
b"-4.80646,1,1.09169,2,3,0,5.0962,0.528418\n",
488+
]
489+
out = stancsv.extract_max_treedepth_and_divergence_counts(draws, 10, 2)
490+
assert out == (1, 1)
491+
492+
493+
def test_max_treedepth_and_divergence_counts_no_draws():
494+
draws = [
495+
(
496+
b"lp__,accept_stat__,stepsize__,treedepth__,"
497+
b"n_leapfrog__,divergent__,energy__,theta\n"
498+
),
499+
]
500+
out = stancsv.extract_max_treedepth_and_divergence_counts(draws, 10, 0)
501+
assert out == (0, 0)
502+
503+
504+
def test_max_treedepth_and_divergence_invalid():
505+
draws = [
506+
b"lp__,accept_stat__,stepsize__,n_leapfrog__,energy__,theta\n",
507+
b"-4.78686,0.986298,1.09169,3,5.29492,0.550024\n",
508+
]
509+
assert stancsv.extract_max_treedepth_and_divergence_counts(
510+
draws, 10, 0
511+
) == (0, 0)
512+
513+
514+
def test_sneaky_fixed_param_check():
515+
sneaky_header = b"lp__,accept_stat__,N,y_sim.1"
516+
normal_header = (
517+
b"lp__,accept_stat__,stepsize__,treedepth__,"
518+
b"n_leapfrog__,divergent__,energy__,theta"
519+
)
520+
521+
assert stancsv.is_sneaky_fixed_param(sneaky_header)
522+
assert not stancsv.is_sneaky_fixed_param(normal_header)
523+
524+
525+
def test_warmup_sampling_draw_counts():
526+
csv_path = os.path.join(DATAFILES_PATH, "bernoulli_output_1.csv")
527+
assert stancsv.count_warmup_and_sampling_draws(csv_path) == (0, 10)
528+
529+
530+
def test_warmup_sampling_draw_counts_with_warmup():
531+
lines = [
532+
b"# algorithm = hmc (Default)\n",
533+
(
534+
b"lp__,accept_stat__,stepsize__,treedepth__,"
535+
b"n_leapfrog__,divergent__,energy__,theta\n"
536+
),
537+
b"-6.76206,1,0.787025,1,1,0,6.81411,0.229458\n",
538+
b"# Adaptation terminated\n",
539+
b"# Step size = 0.787025\n",
540+
b"# Diagonal elements of inverse mass matrix:\n",
541+
b"# 1\n",
542+
b"-6.76206,1,0.787025,1,1,0,6.81411,0.229458\n",
543+
b"# \n",
544+
b"# Elapsed Time: 0.001332 seconds (Warm-up)\n",
545+
]
546+
fio = io.BytesIO(b"".join(lines))
547+
assert stancsv.count_warmup_and_sampling_draws(fio) == (1, 1)
548+
549+
550+
def test_warmup_sampling_draw_counts_fixed_param():
551+
lines = [
552+
b"# algorithm = fixed_param\n",
553+
(
554+
b"lp__,accept_stat__,stepsize__,treedepth__,"
555+
b"n_leapfrog__,divergent__,energy__,theta\n"
556+
),
557+
b"-6.76206,1,0.787025,1,1,0,6.81411,0.229458\n",
558+
b"-6.76206,1,0.787025,1,1,0,6.81411,0.229458\n",
559+
b"# \n",
560+
b"# Elapsed Time: 0.001332 seconds (Warm-up)\n",
561+
]
562+
fio = io.BytesIO(b"".join(lines))
563+
assert stancsv.count_warmup_and_sampling_draws(fio) == (0, 2)
564+
565+
566+
def test_warmup_sampling_draw_counts_no_draws():
567+
lines = [
568+
b"# algorithm = fixed_param\n",
569+
(
570+
b"lp__,accept_stat__,stepsize__,treedepth__,"
571+
b"n_leapfrog__,divergent__,energy__,theta\n"
572+
),
573+
b"# Elapsed Time: 0.001332 seconds (Warm-up)\n",
574+
b"# 0.001332 seconds (Sampling)\n",
575+
]
576+
fio = io.BytesIO(b"".join(lines))
577+
assert stancsv.count_warmup_and_sampling_draws(fio) == (0, 0)
578+
579+
580+
def test_warmup_sampling_draw_counts_invalid():
581+
lines = [
582+
b"# algorithm = fixed_param\n",
583+
]
584+
fio = io.BytesIO(b"".join(lines))
585+
with pytest.raises(ValueError):
586+
stancsv.count_warmup_and_sampling_draws(fio)
587+
588+
589+
def test_inconsistent_draws_shape():
590+
draws = [b"a,b\n", b"0,1,2\n"]
591+
with pytest.raises(ValueError):
592+
stancsv.raise_on_inconsistent_draws_shape(draws)
593+
594+
595+
def test_inconsistent_draws_shape_empty():
596+
draws = []
597+
stancsv.raise_on_inconsistent_draws_shape(draws)
598+
599+
600+
def test_invalid_adaptation_block_good():
601+
csv_path = os.path.join(DATAFILES_PATH, "bernoulli_output_1.csv")
602+
comments, _ = stancsv.parse_stan_csv_comments_and_draws(csv_path)
603+
stancsv.raise_on_invalid_adaptation_block(comments)
604+
605+
606+
def test_invalid_adaptation_block_missing():
607+
lines = [
608+
b"# metric = diag_e (Default)\n",
609+
(
610+
b"lp__,accept_stat__,stepsize__,treedepth__,"
611+
b"n_leapfrog__,divergent__,energy__,theta\n"
612+
),
613+
b"-6.76206,1,0.787025,1,1,0,6.81411,0.229458\n",
614+
b"# \n",
615+
b"# Elapsed Time: 0.001332 seconds (Warm-up)\n",
616+
]
617+
with pytest.raises(ValueError, match="expecting metric"):
618+
stancsv.raise_on_invalid_adaptation_block(lines)
619+
620+
621+
def test_invalid_adaptation_block_no_metric():
622+
lines = [
623+
(
624+
b"lp__,accept_stat__,stepsize__,treedepth__,"
625+
b"n_leapfrog__,divergent__,energy__,theta\n"
626+
),
627+
b"# Adaptation terminated\n",
628+
b"# Step size = 0.787025\n",
629+
b"# Diagonal elements of inverse mass matrix:\n",
630+
b"# 1\n",
631+
]
632+
with pytest.raises(ValueError, match="No reported metric"):
633+
stancsv.raise_on_invalid_adaptation_block(lines)
634+
635+
636+
def test_invalid_adaptation_block_invalid_step_size():
637+
lines = [
638+
b"# metric = diag_e (Default)\n",
639+
(
640+
b"lp__,accept_stat__,stepsize__,treedepth__,"
641+
b"n_leapfrog__,divergent__,energy__,theta\n"
642+
),
643+
b"# Adaptation terminated\n",
644+
b"# Step size = bad\n",
645+
b"# Diagonal elements of inverse mass matrix:\n",
646+
b"# 1\n",
647+
]
648+
with pytest.raises(ValueError, match="invalid step size"):
649+
stancsv.raise_on_invalid_adaptation_block(lines)
650+
651+
652+
def test_invalid_adaptation_block_mismatched_structure():
653+
lines = [
654+
b"# metric = diag_e (Default)\n",
655+
(
656+
b"lp__,accept_stat__,stepsize__,treedepth__,"
657+
b"n_leapfrog__,divergent__,energy__,theta\n"
658+
),
659+
b"# Adaptation terminated\n",
660+
b"# Step size = 0.787025\n",
661+
b"# Elements of inverse mass matrix:\n",
662+
b"# 1\n",
663+
]
664+
with pytest.raises(ValueError, match="invalid or missing"):
665+
stancsv.raise_on_invalid_adaptation_block(lines)
666+
667+
668+
def test_invalid_adaptation_block_missing_step_size():
669+
lines = [
670+
b"# metric = diag_e (Default)\n",
671+
(
672+
b"lp__,accept_stat__,stepsize__,treedepth__,"
673+
b"n_leapfrog__,divergent__,energy__,theta\n"
674+
),
675+
b"# Adaptation terminated\n",
676+
b"# Diagonal elements of inverse mass matrix:\n",
677+
b"# 1\n",
678+
]
679+
with pytest.raises(ValueError, match="expecting step size"):
680+
stancsv.raise_on_invalid_adaptation_block(lines)
681+
682+
683+
def test_invalid_adaptation_block_unit_e():
684+
lines = [
685+
b"# metric = unit_e\n",
686+
(
687+
b"lp__,accept_stat__,stepsize__,treedepth__,"
688+
b"n_leapfrog__,divergent__,energy__,theta\n"
689+
),
690+
b"# Adaptation terminated\n",
691+
b"# Step size = 1.77497\n",
692+
b"# No free parameters for unit metric\n",
693+
]
694+
stancsv.raise_on_invalid_adaptation_block(lines)
695+
696+
697+
def test_invalid_adaptation_block_dense_e_valid():
698+
lines = [
699+
b"# metric = dense_e\n",
700+
(
701+
b"lp__,accept_stat__,stepsize__,treedepth__,"
702+
b"n_leapfrog__,divergent__,energy__,theta.1,theta.2,theta.3\n"
703+
),
704+
b"# Adaptation terminated\n",
705+
b"# Step size = 0.775147\n",
706+
b"# Elements of inverse mass matrix:\n",
707+
b"# 2.84091, 0.230843, 0.0509365\n",
708+
b"# 0.230843, 3.92459, 0.126989\n",
709+
b"# 0.0509365, 0.126989, 3.82718\n",
710+
]
711+
stancsv.raise_on_invalid_adaptation_block(lines)
712+
713+
714+
def test_invalid_adaptation_block_dense_e_invalid():
715+
lines = [
716+
b"# metric = dense_e\n",
717+
(
718+
b"lp__,accept_stat__,stepsize__,treedepth__,"
719+
b"n_leapfrog__,divergent__,energy__,theta.1,theta.2,theta.3\n"
720+
),
721+
b"# Adaptation terminated\n",
722+
b"# Step size = 0.775147\n",
723+
b"# Elements of inverse mass matrix:\n",
724+
b"# 2.84091, 0.230843, 0.0509365\n",
725+
b"# 2.84091, 0.230843\n",
726+
b"# 0.230843, 3.92459\n",
727+
]
728+
with pytest.raises(ValueError, match="invalid or missing"):
729+
stancsv.raise_on_invalid_adaptation_block(lines)
730+
731+
732+
def test_parsing_timing_lines():
733+
lines = [
734+
b"# \n",
735+
b"# Elapsed Time: 0.001332 seconds (Warm-up)\n",
736+
b"# 0.000249 seconds (Sampling)\n",
737+
b"# 0.001581 seconds (Total)\n",
738+
b"# \n",
739+
]
740+
out = stancsv.parse_timing_lines(lines)
741+
742+
assert len(out) == 3
743+
assert out['Warm-up'] == 0.001332
744+
assert out['Sampling'] == 0.000249
745+
assert out['Total'] == 0.001581
746+
747+
748+
def test_munge_varname():
749+
name1 = "a"
750+
name2 = "a:1"
751+
name3 = "a:1.2"
752+
assert stancsv.munge_varname(name1) == "a"
753+
assert stancsv.munge_varname(name2) == "a.1"
754+
assert stancsv.munge_varname(name3) == "a.1[2]"

0 commit comments

Comments
 (0)