1
1
# solves full lasso problem via glmnet
2
2
3
- solve_problem_glmnet = function (X , y , lambda , penalty_factor , family ){
4
- if (is.null(lambda )){
3
+ solve_problem_glmnet = function (X , y , lambda_glmnet , penalty_factor , family ){
4
+ if (is.null(lambda_glmnet )){
5
5
cv = cv.glmnet(x = X ,
6
6
y = y ,
7
7
family = family ,
@@ -19,14 +19,14 @@ solve_problem_glmnet = function(X, y, lambda, penalty_factor, family){
19
19
standardize = FALSE ,
20
20
intercept = FALSE ,
21
21
thresh = 1e-20 )
22
- beta_hat = coef(lasso , s = lambda )
22
+ beta_hat = coef(lasso , s = lambda_glmnet )
23
23
}
24
24
return (beta_hat [- 1 ])
25
25
}
26
26
27
27
# solves full group lasso problem via gglasso
28
- solve_problem_gglasso = function (X , y , groups , lambda , penalty_factor , family ){
29
- if (is.null(lambda )){
28
+ solve_problem_gglasso = function (X , y , groups , lambda_glmnet , penalty_factor , family ){
29
+ if (is.null(lambda_glmnet )){
30
30
cv <- cv.gglasso(x = X ,
31
31
y = y ,
32
32
group = groups ,
@@ -51,18 +51,18 @@ solve_problem_gglasso = function(X, y, groups, lambda, penalty_factor, family){
51
51
pf = penalty_factor ,
52
52
intercept = FALSE ,
53
53
eps = 1e-20 )
54
- beta_hat = coef(m , s = lambda )
54
+ beta_hat = coef(m , s = lambda_glmnet )
55
55
}
56
56
return (beta_hat [- 1 ])
57
57
}
58
58
59
59
# solves the restricted problem
60
- solve_restricted_problem = function (X , y , var , lambda , penalty_factor , loss , solver ){
60
+ solve_restricted_problem = function (X , y , var , lambda_glmnet , penalty_factor , loss , solver ){
61
61
if (solver == " glmnet" ){
62
62
restricted_soln = rep(0 , ncol(X ))
63
63
restricted_soln [- var ] = solve_problem_glmnet(X [,- var ],
64
64
y ,
65
- lambda ,
65
+ lambda_glmnet ,
66
66
penalty_factor [- var ],
67
67
family = family_label(loss ))
68
68
} else if (solver == " gglasso" ){
@@ -71,7 +71,7 @@ solve_restricted_problem = function(X, y, var, lambda, penalty_factor, loss, sol
71
71
restricted_soln = solve_problem_gglasso(X ,
72
72
y ,
73
73
1 : ncol(X ),
74
- lambda ,
74
+ lambda_glmnet ,
75
75
penalty_factor = penalty_factor_rest ,
76
76
family = family_label(loss ))
77
77
}
@@ -80,7 +80,7 @@ solve_restricted_problem = function(X, y, var, lambda, penalty_factor, loss, sol
80
80
81
81
solve_problem_Q = function (Q_sq ,
82
82
Qbeta_bar ,
83
- lambda ,
83
+ lambda_glmnet ,
84
84
penalty_factor ,
85
85
max_iter = 50 ,
86
86
kkt_tol = 1.e-4 ,
@@ -108,7 +108,7 @@ solve_problem_Q = function(Q_sq,
108
108
109
109
# solve_QP_wide solves n*slinear_func^T\beta+\beta^T Xinfo\beta+\sum\lambda_i|\beta_i|
110
110
result = solve_QP_wide(Xinfo , # this is a design matrix
111
- as.numeric(penalty_factor * lambda ), # vector of Lagrange multipliers
111
+ as.numeric(penalty_factor * lambda_glmnet ), # vector of Lagrange multipliers
112
112
0 , # ridge_term
113
113
max_iter ,
114
114
soln ,
@@ -143,7 +143,7 @@ truncation_set = function(X,
143
143
target_cov ,
144
144
var ,
145
145
active_set ,
146
- lambda ,
146
+ lambda_glmnet ,
147
147
penalty_factor ,
148
148
loss ,
149
149
solver ){
@@ -153,13 +153,13 @@ truncation_set = function(X,
153
153
penalty_factor_rest [var ] = 10 ^ 10
154
154
restricted_soln = solve_problem_Q(Q_sq ,
155
155
Qbeta_bar ,
156
- lambda ,
156
+ lambda_glmnet ,
157
157
penalty_factor = penalty_factor_rest )
158
158
} else {
159
159
restricted_soln = solve_restricted_problem(X ,
160
160
y ,
161
161
var ,
162
- lambda ,
162
+ lambda_glmnet ,
163
163
penalty_factor = penalty_factor ,
164
164
loss = loss ,
165
165
solver = solver )
@@ -170,7 +170,7 @@ truncation_set = function(X,
170
170
nuisance_res = (Qbeta_bar [var ] - # nuisance stat restricted to active vars
171
171
solve(target_cov ) %*% target_stat )/ n
172
172
center = nuisance_res - (QE [idx ,] %*% restricted_soln / n )
173
- radius = penalty_factor [var ]* lambda
173
+ radius = penalty_factor [var ]* lambda_glmnet
174
174
return (list (center = center * n , radius = radius * n ))
175
175
}
176
176
@@ -454,6 +454,8 @@ ROSI = function(X,
454
454
455
455
begin_TS = Sys.time()
456
456
457
+ n = nrow(X )
458
+ lambda_glmnet = lambda / n
457
459
TS = truncation_set(X = X ,
458
460
y = y ,
459
461
Qbeta_bar = Qbeta_bar ,
@@ -463,7 +465,7 @@ ROSI = function(X,
463
465
target_stat = target_stat ,
464
466
var = active_set [i ],
465
467
active_set = active_set ,
466
- lambda = lambda ,
468
+ lambda_glmnet = lambda_glmnet ,
467
469
penalty_factor = penalty_factor ,
468
470
loss = loss_label(family ),
469
471
solver = solver )
0 commit comments