1- library(loo )
21set.seed(123 )
32
43LLarr <- example_loglik_array()
@@ -12,45 +11,63 @@ test_that("loo_compare throws appropriate errors", {
1211 w4 <- suppressWarnings(waic(LLarr [,, - (1 : 2 )]))
1312
1413 expect_error(loo_compare(2 , 3 ), " must be a list if not a 'loo' object" )
15- expect_error(loo_compare(w1 , w2 , x = list (w1 , w2 )),
16- " If 'x' is a list then '...' should not be specified" )
17- expect_error(loo_compare(w1 , list (1 ,2 ,3 )), " class 'loo'" )
14+ expect_error(
15+ loo_compare(w1 , w2 , x = list (w1 , w2 )),
16+ " If 'x' is a list then '...' should not be specified"
17+ )
18+ expect_error(loo_compare(w1 , list (1 , 2 , 3 )), " class 'loo'" )
1819 expect_error(loo_compare(w1 ), " requires at least two models" )
1920 expect_error(loo_compare(x = list (w1 )), " requires at least two models" )
2021 expect_error(loo_compare(w1 , w3 ), " same number of data points" )
2122 expect_error(loo_compare(w1 , w2 , w3 ), " same number of data points" )
2223})
2324
2425test_that(" loo_compare throws appropriate warnings" , {
25- w3 <- w1 ; w4 <- w2
26+ w3 <- w1
27+ w4 <- w2
2628 class(w3 ) <- class(w4 ) <- c(" kfold" , " loo" )
2729 attr(w3 , " K" ) <- 2
2830 attr(w4 , " K" ) <- 3
29- expect_warning(loo_compare(w3 , w4 ), " Not all kfold objects have the same K value" )
31+ expect_warning(
32+ loo_compare(w3 , w4 ),
33+ " Not all kfold objects have the same K value"
34+ )
3035
3136 class(w4 ) <- c(" psis_loo" , " loo" )
3237 attr(w4 , " K" ) <- NULL
3338 expect_warning(loo_compare(w3 , w4 ), " Comparing LOO-CV to K-fold-CV" )
3439
35- w3 <- w1 ; w4 <- w2
40+ w3 <- w1
41+ w4 <- w2
3642 attr(w3 , " yhash" ) <- " a"
3743 attr(w4 , " yhash" ) <- " b"
3844 expect_warning(loo_compare(w3 , w4 ), " Not all models have the same y variable" )
3945
4046 set.seed(123 )
41- w_list <- lapply(1 : 25 , function (x ) suppressWarnings(waic(LLarr + rnorm(1 , 0 , 0.1 ))))
42- expect_warning(loo_compare(w_list ),
43- " Difference in performance potentially due to chance" )
44-
45- w_list_short <- lapply(1 : 4 , function (x ) suppressWarnings(waic(LLarr + rnorm(1 , 0 , 0.1 ))))
47+ w_list <- lapply(1 : 25 , function (x ) {
48+ suppressWarnings(waic(LLarr + rnorm(1 , 0 , 0.1 )))
49+ })
50+ expect_warning(
51+ loo_compare(w_list ),
52+ " Difference in performance potentially due to chance"
53+ )
54+
55+ w_list_short <- lapply(1 : 4 , function (x ) {
56+ suppressWarnings(waic(LLarr + rnorm(1 , 0 , 0.1 )))
57+ })
4658 expect_no_warning(loo_compare(w_list_short ))
4759})
4860
4961
50-
5162comp_colnames <- c(
52- " elpd_diff" , " se_diff" , " elpd_waic" , " se_elpd_waic" ,
53- " p_waic" , " se_p_waic" , " waic" , " se_waic"
63+ " elpd_diff" ,
64+ " se_diff" ,
65+ " elpd_waic" ,
66+ " se_elpd_waic" ,
67+ " p_waic" ,
68+ " se_p_waic" ,
69+ " waic" ,
70+ " se_waic"
5471)
5572
5673test_that(" loo_compare returns expected results (2 models)" , {
@@ -59,15 +76,15 @@ test_that("loo_compare returns expected results (2 models)", {
5976 expect_equal(colnames(comp1 ), comp_colnames )
6077 expect_equal(rownames(comp1 ), c(" model1" , " model2" ))
6178 expect_output(print(comp1 ), " elpd_diff" )
62- expect_equal(comp1 [1 : 2 ,1 ], c(0 , 0 ), ignore_attr = TRUE )
63- expect_equal(comp1 [1 : 2 ,2 ], c(0 , 0 ), ignore_attr = TRUE )
79+ expect_equal(comp1 [1 : 2 , 1 ], c(0 , 0 ), ignore_attr = TRUE )
80+ expect_equal(comp1 [1 : 2 , 2 ], c(0 , 0 ), ignore_attr = TRUE )
6481
6582 comp2 <- loo_compare(w1 , w2 )
6683 expect_s3_class(comp2 , " compare.loo" )
6784 expect_equal(colnames(comp2 ), comp_colnames )
68-
85+
6986 expect_snapshot_value(comp2 , style = " serialize" )
70-
87+
7188 # specifying objects via ... and via arg x gives equal results
7289 expect_equal(comp2 , loo_compare(x = list (w1 , w2 )))
7390})
@@ -79,7 +96,7 @@ test_that("loo_compare returns expected result (3 models)", {
7996
8097 expect_equal(colnames(comp1 ), comp_colnames )
8198 expect_equal(rownames(comp1 ), c(" model1" , " model2" , " model3" ))
82- expect_equal(comp1 [1 ,1 ], 0 )
99+ expect_equal(comp1 [1 , 1 ], 0 )
83100 expect_s3_class(comp1 , " compare.loo" )
84101 expect_s3_class(comp1 , " matrix" )
85102
@@ -119,34 +136,53 @@ test_that("compare returns expected result (3 models)", {
119136 expect_equal(
120137 colnames(comp1 ),
121138 c(
122- " elpd_diff" , " se_diff" , " elpd_waic" , " se_elpd_waic" ,
123- " p_waic" , " se_p_waic" , " waic" , " se_waic"
124- ))
139+ " elpd_diff" ,
140+ " se_diff" ,
141+ " elpd_waic" ,
142+ " se_elpd_waic" ,
143+ " p_waic" ,
144+ " se_p_waic" ,
145+ " waic" ,
146+ " se_waic"
147+ )
148+ )
125149 expect_equal(rownames(comp1 ), c(" w1" , " w2" , " w3" ))
126- expect_equal(comp1 [1 ,1 ], 0 )
150+ expect_equal(comp1 [1 , 1 ], 0 )
127151 expect_s3_class(comp1 , " compare.loo" )
128152 expect_s3_class(comp1 , " matrix" )
129153 expect_snapshot_value(comp1 , style = " serialize" )
130154
131155 # specifying objects via '...' gives equivalent results (equal
132156 # except rownames) to using 'x' argument
133- expect_warning(comp_via_list <- loo :: compare(x = list (w1 , w2 , w3 )), " Deprecated" )
157+ expect_warning(
158+ comp_via_list <- loo :: compare(x = list (w1 , w2 , w3 )),
159+ " Deprecated"
160+ )
134161 expect_equal(comp1 , comp_via_list , ignore_attr = TRUE )
135162})
136163
137164test_that(" compare throws appropriate errors" , {
138- expect_error(suppressWarnings(loo :: compare(w1 , w2 , x = list (w1 , w2 ))),
139- " should not be specified" )
140- expect_error(suppressWarnings(loo :: compare(x = 2 )),
141- " must be a list" )
142- expect_error(suppressWarnings(loo :: compare(x = list (2 ))),
143- " should have class 'loo'" )
144- expect_error(suppressWarnings(loo :: compare(x = list (w1 ))),
145- " requires at least two models" )
146-
147- w3 <- suppressWarnings(waic(LLarr2 [,,- 1 ]))
148- expect_error(suppressWarnings(loo :: compare(x = list (w1 , w3 ))),
149- " same number of data points" )
150- expect_error(suppressWarnings(loo :: compare(x = list (w1 , w2 , w3 ))),
151- " same number of data points" )
165+ expect_error(
166+ suppressWarnings(loo :: compare(w1 , w2 , x = list (w1 , w2 ))),
167+ " should not be specified"
168+ )
169+ expect_error(suppressWarnings(loo :: compare(x = 2 )), " must be a list" )
170+ expect_error(
171+ suppressWarnings(loo :: compare(x = list (2 ))),
172+ " should have class 'loo'"
173+ )
174+ expect_error(
175+ suppressWarnings(loo :: compare(x = list (w1 ))),
176+ " requires at least two models"
177+ )
178+
179+ w3 <- suppressWarnings(waic(LLarr2 [,, - 1 ]))
180+ expect_error(
181+ suppressWarnings(loo :: compare(x = list (w1 , w3 ))),
182+ " same number of data points"
183+ )
184+ expect_error(
185+ suppressWarnings(loo :: compare(x = list (w1 , w2 , w3 ))),
186+ " same number of data points"
187+ )
152188})
0 commit comments