@@ -5,7 +5,7 @@ library(glmnet)
5
5
6
6
# testing Liu et al type=full in high dimensional settings -- uses debiasing matrix
7
7
8
- test_liu_full = function (seed = 1 , outfile = NULL , loss = " ls " , lambda_frac = 0.7 ,
8
+ test_liu_full = function (seed = 1 , outfile = NULL , family = " gaussian " , lambda_frac = 0.7 ,
9
9
nrep = 50 , n = 200 , p = 500 , s = 20 , rho = 0 . ){
10
10
11
11
snr = sqrt(2 * log(p )/ n )
@@ -28,12 +28,14 @@ test_liu_full = function(seed=1, outfile=NULL, loss="ls", lambda_frac=0.7,
28
28
29
29
for (i in 1 : nrep ){
30
30
31
- if (loss == " ls " ){
31
+ if (family == " gaussian " ){
32
32
sigma = 1
33
33
data = selectiveInference ::: gaussian_instance(n = n , p = p , s = s , rho = rho , sigma = sigma , snr = snr )
34
- } else if (loss == " logit" ){
34
+ loss = ' ls'
35
+ } else if (family == ' binomial' ){
35
36
sigma = 1
36
37
data = selectiveInference ::: logistic_instance(n = n , p = p , s = s , rho = rho , snr = snr )
38
+ loss = ' logit'
37
39
}
38
40
39
41
X = data $ X
@@ -50,26 +52,32 @@ test_liu_full = function(seed=1, outfile=NULL, loss="ls", lambda_frac=0.7,
50
52
lambda = lambda_frac * selectiveInference ::: theoretical.lambda(X , loss , sigma_est ) # theoretical lambda
51
53
print(c(" lambda" , lambda ))
52
54
53
- soln = selectiveInference ::: solve_problem_glmnet(X , y , lambda , penalty_factor = penalty_factor , loss = loss )
54
- PVS = selectiveInference ::: ROSI(X ,
55
- y ,
56
- soln ,
57
- lambda = lambda ,
58
- penalty_factor = penalty_factor ,
59
- sigma_est ,
60
- loss = loss ,
61
- algo = " Q " ,
62
- construct_ci = construct_ci ,
63
- debias_mat = " JM" ,
64
- verbose = TRUE )
55
+ soln = selectiveInference ::: solve_problem_glmnet(X , y , lambda , penalty_factor = penalty_factor , family = family )
56
+ PVS = ROSI(X ,
57
+ y ,
58
+ soln ,
59
+ lambda = lambda ,
60
+ penalty_factor = penalty_factor ,
61
+ dispersion = sigma_est ^ 2 ,
62
+ family = family ,
63
+ solver = " QP " ,
64
+ construct_ci = construct_ci ,
65
+ debiasing_method = " JM" ,
66
+ verbose = TRUE )
65
67
66
- active_vars = PVS $ active_vars
68
+ active_vars = PVS $ active_set
67
69
cat(" active_vars:" ,active_vars ," \n " )
68
70
pvalues = c(pvalues , PVS $ pvalues )
69
- naive_pvalues = c(naive_pvalues , PVS $ naive_pvalues )
70
- sel_intervals = cbind(sel_intervals , PVS $ sel_intervals ) # matrix with two rows
71
- naive_intervals = cbind(naive_intervals , PVS $ naive_intervals )
72
-
71
+ if (family == ' gaussian' ) {
72
+ glm_Xy = glm(y ~ X [,active_vars ] - 1 )
73
+ } else {
74
+ glm_Xy = glm(y ~ X [,active_vars ] - 1 , family = binomial )
75
+ }
76
+ naive_pvalues = c(naive_pvalues , summary(glm_Xy )$ coef [,4 ])
77
+ sel_intervals = rbind(sel_intervals , PVS $ intervals ) # matrix with two rows
78
+ naive_int = confint(glm_Xy , level = 0.9 )
79
+ naive_intervals = rbind(naive_intervals , naive_int )
80
+ print(naive_intervals )
73
81
if (length(pvalues )> 0 ){
74
82
plot(ecdf(pvalues ))
75
83
lines(ecdf(naive_pvalues ), col = " red" )
@@ -78,10 +86,11 @@ test_liu_full = function(seed=1, outfile=NULL, loss="ls", lambda_frac=0.7,
78
86
79
87
if (construct_ci && length(active_vars )> 0 ){
80
88
81
- sel_coverages = c(sel_coverages , selectiveInference ::: compute_coverage(PVS $ sel_intervals , beta [active_vars ]))
82
- naive_coverages = c(naive_coverages , selectiveInference ::: compute_coverage(PVS $ naive_intervals , beta [active_vars ]))
83
- sel_lengths = c(sel_lengths , as.vector(PVS $ sel_intervals [2 ,]- PVS $ sel_intervals [1 ,]))
84
- naive_lengths = c(naive_lengths , as.vector(PVS $ naive_intervals [2 ,]- PVS $ naive_intervals [1 ,]))
89
+
90
+ sel_coverages = c(sel_coverages , selectiveInference ::: compute_coverage(PVS $ intervals , beta [active_vars ]))
91
+ naive_coverages = c(naive_coverages , selectiveInference ::: compute_coverage(naive_int , beta [active_vars ]))
92
+ sel_lengths = c(sel_lengths , as.vector(naive_int [,2 ]- naive_int [,1 ]))
93
+ naive_lengths = c(naive_lengths , as.vector(PVS $ naive_intervals [,2 ]- PVS $ naive_intervals [,1 ]))
85
94
# cat("sel cov", sel_coverages, "\n")
86
95
print(c(" selective coverage:" , mean(sel_coverages )))
87
96
print(c(" naive coverage:" , mean(naive_coverages )))
0 commit comments