Skip to content

Commit 7364528

Browse files
removing gglasso dependence, bugfix in C code, randomized code
1 parent c35706b commit 7364528

File tree

5 files changed

+11
-58
lines changed

5 files changed

+11
-58
lines changed

C-software

selectiveInference/DESCRIPTION

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ Depends:
1010
glmnet,
1111
intervals,
1212
survival,
13-
adaptMCMC,
14-
gglasso
13+
adaptMCMC
1514
Suggests:
1615
Rmpfr
1716
Description: New tools for post-selection inference, for use with forward

selectiveInference/NAMESPACE

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,5 +63,3 @@ importFrom("stats", "glm", "residuals", "vcov")
6363
importFrom("stats", "rbinom", "rexp")
6464
importFrom("Rcpp", "sourceCpp")
6565
importFrom("adaptMCMC", "MCMC")
66-
importFrom("gglasso", "gglasso")
67-
importFrom("gglasso", "cv.gglasso")

selectiveInference/R/funs.ROSI.R

Lines changed: 8 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -24,57 +24,14 @@ solve_problem_glmnet = function(X, y, lambda_glmnet, penalty_factor, family){
2424
return(beta_hat[-1])
2525
}
2626

27-
# solves full group lasso problem via gglasso
28-
solve_problem_gglasso = function(X, y, groups, lambda_glmnet, penalty_factor, family){
29-
if (is.null(lambda_glmnet)){
30-
cv <- cv.gglasso(x=X,
31-
y=y,
32-
group=groups,
33-
loss=loss_label(family),
34-
pf=penalty_factor,
35-
intercept=FALSE,
36-
eps=1e-12)
37-
beta_hat = coef(cv, s="lambda.min")
38-
}
39-
else {
40-
# gglasso for logit loss needs the response to be in {-1,1}
41-
if (family == 'binomial') {
42-
y_pm1 = rep(y)
43-
y_pm1[which(y==0)]=-1
44-
} else if (family == 'gaussian'){
45-
y_pm1 = rep(y)
46-
}
47-
m = gglasso(x=X,
48-
y=y_pm1,
49-
group=groups,
50-
loss=loss_label(family),
51-
pf=penalty_factor,
52-
intercept=FALSE,
53-
eps=1e-20)
54-
beta_hat = coef(m, s=lambda_glmnet)
55-
}
56-
return(beta_hat[-1])
57-
}
58-
5927
# solves the restricted problem
60-
solve_restricted_problem = function(X, y, var, lambda_glmnet, penalty_factor, loss, solver){
61-
if (solver=="glmnet"){
62-
restricted_soln=rep(0, ncol(X))
63-
restricted_soln[-var] = solve_problem_glmnet(X[,-var],
64-
y,
65-
lambda_glmnet,
66-
penalty_factor[-var],
67-
family=family_label(loss))
68-
} else if (solver=="gglasso"){
69-
penalty_factor_rest = rep(penalty_factor)
70-
penalty_factor_rest[var] = 10^10
71-
restricted_soln = solve_problem_gglasso(X,
72-
y,
73-
1:ncol(X),
74-
lambda_glmnet,
75-
penalty_factor=penalty_factor_rest,
76-
family=family_label(loss))
77-
}
28+
solve_restricted_problem = function(X, y, var, lambda_glmnet, penalty_factor, loss){
29+
restricted_soln=rep(0, ncol(X))
30+
restricted_soln[-var] = solve_problem_glmnet(X[,-var],
31+
y,
32+
lambda_glmnet,
33+
penalty_factor[-var],
34+
family=family_label(loss))
7835
return(restricted_soln)
7936
}
8037

@@ -161,8 +118,7 @@ truncation_set = function(X,
161118
var,
162119
lambda_glmnet,
163120
penalty_factor=penalty_factor,
164-
loss=loss,
165-
solver=solver)
121+
loss=loss)
166122
}
167123

168124
n = nrow(X)

selectiveInference/R/funs.randomized.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ randomizedLassoInf = function(rand_lasso_soln,
492492
}
493493

494494
if (is.null(targets)){
495-
targets = compute_target(rand_lasso_soln, type="partial")
495+
targets = compute_target(rand_lasso_soln, type="selected")
496496
}
497497

498498
alternatives = targets$alternatives

0 commit comments

Comments
 (0)