Skip to content

Commit ae1031b

Browse files
fixed print.ROSI doc issue, made ROSI take lambda on same scale as fixedLassoInf
1 parent fe56810 commit ae1031b

File tree

3 files changed

+26
-24
lines changed

3 files changed

+26
-24
lines changed

selectiveInference/R/funs.ROSI.R

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# solves full lasso problem via glmnet
22

3-
solve_problem_glmnet = function(X, y, lambda, penalty_factor, family){
4-
if (is.null(lambda)){
3+
solve_problem_glmnet = function(X, y, lambda_glmnet, penalty_factor, family){
4+
if (is.null(lambda_glmnet)){
55
cv = cv.glmnet(x=X,
66
y=y,
77
family=family,
@@ -19,14 +19,14 @@ solve_problem_glmnet = function(X, y, lambda, penalty_factor, family){
1919
standardize=FALSE,
2020
intercept=FALSE,
2121
thresh=1e-20)
22-
beta_hat = coef(lasso, s=lambda)
22+
beta_hat = coef(lasso, s=lambda_glmnet)
2323
}
2424
return(beta_hat[-1])
2525
}
2626

2727
# solves full group lasso problem via gglasso
28-
solve_problem_gglasso = function(X, y, groups, lambda, penalty_factor, family){
29-
if (is.null(lambda)){
28+
solve_problem_gglasso = function(X, y, groups, lambda_glmnet, penalty_factor, family){
29+
if (is.null(lambda_glmnet)){
3030
cv <- cv.gglasso(x=X,
3131
y=y,
3232
group=groups,
@@ -51,18 +51,18 @@ solve_problem_gglasso = function(X, y, groups, lambda, penalty_factor, family){
5151
pf=penalty_factor,
5252
intercept=FALSE,
5353
eps=1e-20)
54-
beta_hat = coef(m, s=lambda)
54+
beta_hat = coef(m, s=lambda_glmnet)
5555
}
5656
return(beta_hat[-1])
5757
}
5858

5959
# solves the restricted problem
60-
solve_restricted_problem = function(X, y, var, lambda, penalty_factor, loss, solver){
60+
solve_restricted_problem = function(X, y, var, lambda_glmnet, penalty_factor, loss, solver){
6161
if (solver=="glmnet"){
6262
restricted_soln=rep(0, ncol(X))
6363
restricted_soln[-var] = solve_problem_glmnet(X[,-var],
6464
y,
65-
lambda,
65+
lambda_glmnet,
6666
penalty_factor[-var],
6767
family=family_label(loss))
6868
} else if (solver=="gglasso"){
@@ -71,7 +71,7 @@ solve_restricted_problem = function(X, y, var, lambda, penalty_factor, loss, sol
7171
restricted_soln = solve_problem_gglasso(X,
7272
y,
7373
1:ncol(X),
74-
lambda,
74+
lambda_glmnet,
7575
penalty_factor=penalty_factor_rest,
7676
family=family_label(loss))
7777
}
@@ -80,7 +80,7 @@ solve_restricted_problem = function(X, y, var, lambda, penalty_factor, loss, sol
8080

8181
solve_problem_Q = function(Q_sq,
8282
Qbeta_bar,
83-
lambda,
83+
lambda_glmnet,
8484
penalty_factor,
8585
max_iter=50,
8686
kkt_tol=1.e-4,
@@ -108,7 +108,7 @@ solve_problem_Q = function(Q_sq,
108108

109109
#solve_QP_wide solves n*slinear_func^T\beta+\beta^T Xinfo\beta+\sum\lambda_i|\beta_i|
110110
result = solve_QP_wide(Xinfo, # this is a design matrix
111-
as.numeric(penalty_factor*lambda), # vector of Lagrange multipliers
111+
as.numeric(penalty_factor*lambda_glmnet), # vector of Lagrange multipliers
112112
0, # ridge_term
113113
max_iter,
114114
soln,
@@ -143,7 +143,7 @@ truncation_set = function(X,
143143
target_cov,
144144
var,
145145
active_set,
146-
lambda,
146+
lambda_glmnet,
147147
penalty_factor,
148148
loss,
149149
solver){
@@ -153,13 +153,13 @@ truncation_set = function(X,
153153
penalty_factor_rest[var] = 10^10
154154
restricted_soln = solve_problem_Q(Q_sq,
155155
Qbeta_bar,
156-
lambda,
156+
lambda_glmnet,
157157
penalty_factor=penalty_factor_rest)
158158
} else {
159159
restricted_soln = solve_restricted_problem(X,
160160
y,
161161
var,
162-
lambda,
162+
lambda_glmnet,
163163
penalty_factor=penalty_factor,
164164
loss=loss,
165165
solver=solver)
@@ -170,7 +170,7 @@ truncation_set = function(X,
170170
nuisance_res = (Qbeta_bar[var] - # nuisance stat restricted to active vars
171171
solve(target_cov) %*% target_stat)/n
172172
center = nuisance_res - (QE[idx,] %*% restricted_soln/n)
173-
radius = penalty_factor[var]*lambda
173+
radius = penalty_factor[var]*lambda_glmnet
174174
return(list(center=center*n, radius=radius*n))
175175
}
176176

@@ -454,6 +454,8 @@ ROSI = function(X,
454454

455455
begin_TS = Sys.time()
456456

457+
n = nrow(X)
458+
lambda_glmnet = lambda / n
457459
TS = truncation_set(X=X,
458460
y=y,
459461
Qbeta_bar=Qbeta_bar,
@@ -463,7 +465,7 @@ ROSI = function(X,
463465
target_stat=target_stat,
464466
var=active_set[i],
465467
active_set=active_set,
466-
lambda=lambda,
468+
lambda_glmnet=lambda_glmnet,
467469
penalty_factor=penalty_factor,
468470
loss=loss_label(family),
469471
solver=solver)

selectiveInference/man/ROSI.Rd

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,11 @@ gfit = glmnet(x,y,standardize=FALSE)
138138

139139
# extract coef for a given lambda; note the 1/n factor!
140140
# (and we don't save the intercept term)
141-
lambda = 4 / sqrt(n)
141+
lambda = 4 * sqrt(n)
142+
lambda_glmnet = 4 / sqrt(n)
142143
beta = selectiveInference:::solve_problem_glmnet(x,
143144
y,
144-
lambda,
145+
lambda_glmnet,
145146
penalty_factor=rep(1,p),
146147
family="gaussian")
147148
# compute fixed lambda p-values and selection intervals

selectiveInference/man/selectiveInference-internal.Rd

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
\name{selectiveInference-internal}
22
\title{Internal PMA functions}
33
\alias{print.fixedLassoInf}
4-
\alias{print.fs}
4+
\alias{print.fs}
55
\alias{print.fsInf}
66
\alias{print.larInf}
77
\alias{print.lar}
8-
\alias{print.manyMeans}
9-
10-
8+
\alias{print.manyMeans}
9+
\alias{print.ROSI}
1110

1211
\description{Internal selectiveInference functions}
1312
\usage{
@@ -16,8 +15,8 @@
1615
\method{print}{lar}(x,...)
1716
\method{print}{larInf}(x, tailarea = TRUE, ...)
1817
\method{print}{fixedLassoInf}(x, tailarea = TRUE, ...)
19-
\method{print}{manyMeans}(x,...)
20-
18+
\method{print}{manyMeans}(x, ...)
19+
\method{print}{ROSI}(x, ...)
2120
}
2221
\author{Ryan Tibshirani, Rob Tibshirani, Jonathan Taylor, Joshua Loftus, Stephen Reid}
2322
\keyword{internal}

0 commit comments

Comments (0)