Skip to content

Commit ff598d0

Browse files
authored
Merge pull request #36 from jonathan-taylor/master
fix of debiasingMatrix path, beginning of randomized LASSO
2 parents 4e42c75 + 63e60df commit ff598d0

23 files changed

+2320
-290
lines changed

selectiveInference/DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ Maintainer: Rob Tibshirani <[email protected]>
99
Depends:
1010
glmnet,
1111
intervals,
12-
survival
12+
survival,
1313
Suggests:
1414
Rmpfr
1515
Description: New tools for post-selection inference, for use with forward

selectiveInference/NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,6 @@ importFrom("stats", dnorm, lsfit, pexp, pnorm, predict,
4343
qnorm, rnorm, sd, uniroot, dchisq, model.matrix, pchisq)
4444
importFrom("stats", "coef", "df", "lm", "pf")
4545
importFrom("stats", "glm", "residuals", "vcov")
46+
importFrom("stats", "rbinom", "rexp")
4647
importFrom("Rcpp", "sourceCpp")
4748

selectiveInference/R/RcppExports.R

Lines changed: 0 additions & 19 deletions
This file was deleted.

selectiveInference/R/funs.fixed.R

Lines changed: 67 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ fixedLassoInf <- function(x, y, beta,
66
lambda, family=c("gaussian","binomial","cox"),
77
intercept=TRUE, add.targets=NULL, status=NULL,
88
sigma=NULL, alpha=0.1,
9-
type=c("partial","full"), tol.beta=1e-5, tol.kkt=0.1,
9+
type=c("partial", "full"), tol.beta=1e-5, tol.kkt=0.1,
1010
gridrange=c(-100,100), bits=NULL, verbose=FALSE,
1111
linesearch.try=10) {
1212

@@ -150,7 +150,7 @@ fixedLassoInf <- function(x, y, beta,
150150
ci = tailarea = matrix(0,k,2)
151151

152152
if (type=="full" & p > n) {
153-
if (intercept == T) {
153+
if (intercept == TRUE) {
154154
pp=p+1
155155
Xint <- cbind(rep(1,n),x)
156156
# indices of selected predictors
@@ -189,8 +189,10 @@ fixedLassoInf <- function(x, y, beta,
189189
}
190190

191191
M <- (((htheta%*%t(Xordered))+ithetasigma%*%FS%*%hsigmaSinv%*%t(XS))/n)
192+
192193
# vector which is offset for testing debiased beta's
193194
null_value <- (((ithetasigma%*%FS%*%hsigmaSinv)%*%sign(hbetaS))*lambda/n)
195+
194196
if (intercept == T) {
195197
M = M[-1,] # remove intercept row
196198
null_value = null_value[-1] # remove intercept element
@@ -238,12 +240,23 @@ fixedLassoInf <- function(x, y, beta,
238240
tailarea[j,] = a$tailarea
239241
}
240242

241-
out = list(type=type,lambda=lambda,pv=pv,ci=ci,
242-
tailarea=tailarea,vlo=vlo,vup=vup,vmat=vmat,y=y,
243-
vars=vars,sign=sign_vars,sigma=sigma,alpha=alpha,
244-
sd=sigma*sqrt(rowSums(vmat^2)),
245-
coef0=vmat%*%y,
246-
call=this.call)
243+
out = list(type=type,
244+
lambda=lambda,
245+
pv=pv,
246+
ci=ci,
247+
tailarea=tailarea,
248+
vlo=vlo,
249+
vup=vup,
250+
vmat=vmat,
251+
y=y,
252+
vars=vars,
253+
sign=sign_vars,
254+
sigma=sigma,
255+
alpha=alpha,
256+
sd=sigma*sqrt(rowSums(vmat^2)),
257+
coef0=vmat%*%y,
258+
call=this.call)
259+
247260
class(out) = "fixedLassoInf"
248261
return(out)
249262
}
@@ -306,15 +319,19 @@ debiasingMatrix = function(Xinfo, # could be X or t(X) %*% X / n d
306319
nsample,
307320
rows,
308321
verbose=FALSE,
309-
mu=NULL, # starting value of mu
322+
bound=NULL, # starting value of bound
310323
linesearch=TRUE, # do a linesearch?
311324
scaling_factor=1.5, # multiplicative factor for linesearch
312325
max_active=NULL, # how big can active set get?
313326
max_try=10, # how many steps in linesearch?
314327
warn_kkt=FALSE, # warn if KKT does not seem to be satisfied?
315-
max_iter=100, # how many iterations for each optimization problem
328+
max_iter=50, # how many iterations for each optimization problem
329+
kkt_stop=TRUE, # stop based on KKT conditions?
330+
parameter_stop=TRUE, # stop based on relative convergence of parameter?
331+
objective_stop=TRUE, # stop based on relative decrease in objective?
316332
kkt_tol=1.e-4, # tolerance for the KKT conditions
317-
objective_tol=1.e-8 # tolerance for relative decrease in objective
333+
parameter_tol=1.e-4, # tolerance for relative convergence of parameter
334+
objective_tol=1.e-4 # tolerance for relative decrease in objective
318335
) {
319336

320337

@@ -325,8 +342,8 @@ debiasingMatrix = function(Xinfo, # could be X or t(X) %*% X / n d
325342
p = ncol(Xinfo);
326343
M = matrix(0, length(rows), p);
327344

328-
if (is.null(mu)) {
329-
mu = (1/sqrt(nsample)) * qnorm(1-(0.1/(p^2)))
345+
if (is.null(bound)) {
346+
bound = (1/sqrt(nsample)) * qnorm(1-(0.1/(p^2)))
330347
}
331348

332349
xperc = 0;
@@ -342,14 +359,18 @@ debiasingMatrix = function(Xinfo, # could be X or t(X) %*% X / n d
342359
output = debiasingRow(Xinfo, # could be X or t(X) %*% X / n depending on is_wide
343360
is_wide,
344361
row,
345-
mu,
362+
bound,
346363
linesearch=linesearch,
347364
scaling_factor=scaling_factor,
348365
max_active=max_active,
349366
max_try=max_try,
350367
warn_kkt=FALSE,
351368
max_iter=max_iter,
369+
kkt_stop=kkt_stop,
370+
parameter_stop=parameter_stop,
371+
objective_stop=objective_stop,
352372
kkt_tol=kkt_tol,
373+
parameter_tol=parameter_tol,
353374
objective_tol=objective_tol)
354375

355376
if (warn_kkt && (!output$kkt_check)) {
@@ -372,15 +393,19 @@ debiasingMatrix = function(Xinfo, # could be X or t(X) %*% X / n d
372393
debiasingRow = function (Xinfo, # could be X or t(X) %*% X / n depending on is_wide
373394
is_wide,
374395
row,
375-
mu,
396+
bound,
376397
linesearch=TRUE, # do a linesearch?
377-
scaling_factor=1.2, # multiplicative factor for linesearch
398+
scaling_factor=1.5, # multiplicative factor for linesearch
378399
max_active=NULL, # how big can active set get?
379400
max_try=10, # how many steps in linesearch?
380401
warn_kkt=FALSE, # warn if KKT does not seem to be satisfied?
381-
max_iter=100, # how many iterations for each optimization problem
402+
max_iter=50, # how many iterations for each optimization problem
403+
kkt_stop=TRUE, # stop based on KKT conditions?
404+
parameter_stop=TRUE, # stop based on relative convergence of parameter?
405+
objective_stop=TRUE, # stop based on relative decrease in objective?
382406
kkt_tol=1.e-4, # tolerance for the KKT conditions
383-
objective_tol=1.e-8 # tolerance for relative decrease in objective
407+
parameter_tol=1.e-4, # tolerance for relative convergence of parameter
408+
objective_tol=1.e-4 # tolerance for relative decrease in objective
384409
) {
385410

386411
p = ncol(Xinfo)
@@ -389,9 +414,11 @@ debiasingRow = function (Xinfo, # could be X or t(X) %*% X / n dep
389414
max_active = min(nrow(Xinfo), ncol(Xinfo))
390415
}
391416

417+
392418
# Initialize variables
393419

394420
soln = rep(0, p)
421+
soln = as.numeric(soln)
395422
ever_active = rep(0, p)
396423
ever_active[1] = row # 1-based
397424
ever_active = as.integer(ever_active)
@@ -407,11 +434,15 @@ debiasingRow = function (Xinfo, # could be X or t(X) %*% X / n dep
407434

408435
last_output = NULL
409436

437+
if (is_wide) {
438+
Xsoln = as.numeric(rep(0, nrow(Xinfo)))
439+
}
440+
410441
while (counter_idx < max_try) {
411442

412443
if (!is_wide) {
413444
result = solve_QP(Xinfo, # this is non-neg-def matrix
414-
mu,
445+
bound,
415446
max_iter,
416447
soln,
417448
linear_func,
@@ -420,11 +451,15 @@ debiasingRow = function (Xinfo, # could be X or t(X) %*% X / n dep
420451
nactive,
421452
kkt_tol,
422453
objective_tol,
423-
max_active)
454+
parameter_tol,
455+
max_active,
456+
kkt_stop,
457+
objective_stop,
458+
parameter_stop)
424459
} else {
425-
Xsoln = rep(0, nrow(Xinfo))
426-
result = solve_QP_wide(Xinfo, # this is a design matrix
427-
mu,
460+
result = solve_QP_wide(Xinfo, # this is a design matrix
461+
as.numeric(rep(bound, p)), # vector of Lagrange multipliers
462+
0, # ridge_term
428463
max_iter,
429464
soln,
430465
linear_func,
@@ -434,7 +469,11 @@ debiasingRow = function (Xinfo, # could be X or t(X) %*% X / n dep
434469
nactive,
435470
kkt_tol,
436471
objective_tol,
437-
max_active)
472+
parameter_tol,
473+
max_active,
474+
kkt_stop,
475+
objective_stop,
476+
parameter_stop)
438477

439478
}
440479

@@ -458,13 +497,13 @@ debiasingRow = function (Xinfo, # could be X or t(X) %*% X / n dep
458497
if ((iter < (max_iter+1)) && (counter_idx > 1)) {
459498
break; # we've found a feasible point and solved the problem
460499
}
461-
mu = mu * scaling_factor;
500+
bound = bound * scaling_factor;
462501
} else { # trying to drop the bound parameter further
463502
if ((iter == (max_iter + 1)) && (counter_idx > 1)) {
464503
result = last_output; # problem seems infeasible because we didn't solve it
465504
break; # so we revert to previously found solution
466505
}
467-
mu = mu / scaling_factor;
506+
bound = bound / scaling_factor;
468507
}
469508

470509
# If the active set has grown to a certain size
@@ -490,7 +529,8 @@ debiasingRow = function (Xinfo, # could be X or t(X) %*% X / n dep
490529
}
491530

492531
return(list(soln=result$soln,
493-
kkt_check=result$kkt_check))
532+
kkt_check=result$kkt_check,
533+
gradient=result$gradient))
494534

495535
}
496536

selectiveInference/R/funs.fixedCox.R

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ if( sum(status==0)+sum(status==1)!=length(y)) stop("status vector must have valu
2929
vars=which(m)
3030
if(sum(m)>0){
3131
bhat=beta[beta!=0] #penalized coefs just for active variables
32-
s2=sign(bhat)
32+
sign_bhat=sign(bhat)
3333

3434
#check KKT
3535

@@ -40,7 +40,7 @@ if(sum(m)>0){
4040
res=residuals(aaa,type="score")
4141
if(!is.matrix(res)) res=matrix(res,ncol=1)
4242
scor=colSums(res)
43-
g=(scor+lambda*s2)/(2*lambda)
43+
g=(scor+lambda*sign_bhat)/(2*lambda)
4444
# cat(c(g,lambda,tol.kkt),fill=T)
4545
if (any(abs(g) > 1+tol.kkt) )
4646
warning(paste("Solution beta does not satisfy the KKT conditions",
@@ -49,9 +49,9 @@ scor=colSums(res)
4949
# Hessian of partial likelihood at the LASSO solution
5050
MM=vcov(aaa)
5151

52-
bbar=(bhat+lambda*MM%*%s2)
53-
A1=-(mydiag(s2))
54-
b1= -(mydiag(s2)%*%MM)%*%s2*lambda
52+
bbar=(bhat+lambda*MM%*%sign_bhat)
53+
A1=-(mydiag(sign_bhat))
54+
b1= -(mydiag(sign_bhat)%*%MM)%*%sign_bhat*lambda
5555

5656
temp=max(A1%*%bbar-b1)
5757

@@ -63,7 +63,7 @@ b1= -(mydiag(s2)%*%MM)%*%s2*lambda
6363
# the one sided p-values are a bit off
6464

6565
for(jj in 1:length(bbar)){
66-
vj=rep(0,length(bbar));vj[jj]=s2[jj]
66+
vj=rep(0,length(bbar));vj[jj]=sign_bhat[jj]
6767

6868

6969
junk=TG.pvalue(bbar, A1, b1, vj,MM)
@@ -73,7 +73,7 @@ b1= -(mydiag(s2)%*%MM)%*%s2*lambda
7373
vup[jj]=junk$vup
7474
sd[jj]=junk$sd
7575

76-
junk2=TG.interval(bbar, A1, b1, vj, MM, alpha, flip=(s2[jj]==-1))
76+
junk2=TG.interval(bbar, A1, b1, vj, MM, alpha, flip=(sign_bhat[jj]==-1))
7777
ci[jj,]=junk2$int
7878
tailarea[jj,] = junk2$tailarea
7979

selectiveInference/R/funs.fixedLogit.R

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ fixedLogitLassoInf=function(x,y,beta,lambda,alpha=.1, type=c("partial"), tol.bet
3232
m=beta[-1]!=0 #active set
3333

3434
bhat=c(beta[1],beta[-1][beta[-1]!=0]) # intcpt plus active vars
35-
s2=sign(bhat)
35+
sign_bhat=sign(bhat)
3636
lam2m=diag(c(0,rep(lambda,sum(m))))
3737

3838

@@ -66,14 +66,14 @@ fixedLogitLassoInf=function(x,y,beta,lambda,alpha=.1, type=c("partial"), tol.bet
6666
# MM=solve(t(xxm)%*%w%*%xxm)
6767
MM=solve(scale(t(xxm),F,1/ww)%*%xxm)
6868
gm = c(0,-g[vars]*lambda) # gradient at LASSO solution, first entry is 0 because intercept is unpenalized
69-
# at exact LASSO solution it should be s2[-1]
69+
# at exact LASSO solution it should be sign_bhat[-1]
7070
dbeta = MM %*% gm
7171

72-
# bbar=(bhat+lam2m%*%MM%*%s2) # JT: this is wrong, shouldn't use sign of intercept anywhere...
72+
# bbar=(bhat+lam2m%*%MM%*%sign_bhat) # JT: this is wrong, shouldn't use sign of intercept anywhere...
7373
bbar = bhat - dbeta
7474

75-
A1=-(mydiag(s2))[-1,]
76-
b1= (s2 * dbeta)[-1]
75+
A1=-(mydiag(sign_bhat))[-1,]
76+
b1= (sign_bhat * dbeta)[-1]
7777

7878
tol.poly = 0.01
7979
if (max((A1 %*% bbar) - b1) > tol.poly)
@@ -87,7 +87,7 @@ fixedLogitLassoInf=function(x,y,beta,lambda,alpha=.1, type=c("partial"), tol.bet
8787

8888

8989
for(jj in 1:sum(m)){
90-
vj=c(rep(0,sum(m)+1));vj[jj+1]=s2[jj+1]
90+
vj=c(rep(0,sum(m)+1));vj[jj+1]=sign_bhat[jj+1]
9191
# compute p-values
9292
junk=TG.pvalue(bbar, A1, b1, vj, MM)
9393
pv[jj] = junk$pv
@@ -96,7 +96,7 @@ fixedLogitLassoInf=function(x,y,beta,lambda,alpha=.1, type=c("partial"), tol.bet
9696
vup[jj]=junk$vup
9797
sd[jj]=junk$sd
9898

99-
junk2=TG.interval(bbar, A1, b1, vj, MM,alpha=alpha, flip=(s2[jj+1]==-1))
99+
junk2=TG.interval(bbar, A1, b1, vj, MM,alpha=alpha, flip=(sign_bhat[jj+1]==-1))
100100

101101
ci[jj,]=junk2$int
102102
tailarea[jj,] = junk2$tailarea

0 commit comments

Comments
 (0)