Skip to content

Commit e887312

Browse files
Merge pull request #21 from jonathan-taylor/general_pivot
General pivot -- exposes functions for TG and eliminates poly.* and mypoly.*
2 parents 61dd190 + 4a9f1cf commit e887312

File tree

10 files changed

+358
-95
lines changed

10 files changed

+358
-95
lines changed

selectiveInference/NAMESPACE

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ export(lar,fs,
1212
estimateSigma,
1313
manyMeans,print.manyMeans,
1414
groupfs,groupfsInf,
15-
scaleGroups,factorDesign
15+
scaleGroups,factorDesign,
16+
TG.pvalue,
17+
TG.limits,
18+
TG.interval
1619
)
1720

1821
S3method("coef", "lar")

selectiveInference/R/funs.fixed.R

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,14 +122,19 @@ else{
122122
vj = vj / mj # Standardize (divide by norm of vj)
123123
sign[j] = sign(sum(vj*y))
124124
vj = sign[j] * vj
125-
a = poly.pval(y,G,u,vj,sigma,bits)
125+
126+
limits.info = TG.limits(y, -G, -u, vj, Sigma=diag(rep(sigma^2, n)))
127+
a = TG.pvalue.base(limits.info, bits=bits)
126128
pv[j] = a$pv
127129
vlo[j] = a$vlo * mj # Unstandardize (mult by norm of vj)
128130
vup[j] = a$vup * mj # Unstandardize (mult by norm of vj)
129131
vmat[j,] = vj * mj * sign[j] # Unstandardize (mult by norm of vj)
130132

131-
a = poly.int(y,G,u,vj,sigma,alpha,gridrange=gridrange,
132-
flip=(sign[j]==-1),bits=bits)
133+
a = TG.interval.base(limits.info,
134+
alpha=alpha,
135+
gridrange=gridrange,
136+
flip=(sign[j]==-1),
137+
bits=bits)
133138
ci[j,] = a$int * mj # Unstandardize (mult by norm of vj)
134139
tailarea[j,] = a$tailarea
135140
}

selectiveInference/R/funs.fixedCox.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,14 @@ b1= -(mydiag(s2)%*%MM)%*%s2*lambda
6262
vj=rep(0,length(bbar));vj[jj]=s2[jj]
6363

6464

65-
junk=mypoly.pval.lee(bbar,A1,b1,vj,MM)
65+
junk=TG.pvalue(bbar, A1, b1, vj,MM)
6666

6767
pv[jj] = junk$pv
6868
vlo[jj]=junk$vlo
6969
vup[jj]=junk$vup
7070
sd[jj]=junk$sd
7171

72-
junk2=mypoly.int.lee(bbar,vj,vlo[jj],vup[jj],sd[jj],alpha)
72+
junk2=TG.interval(bbar, A1, b1, vj, MM, alpha)
7373
ci[jj,]=junk2$int
7474
tailarea[jj,] = junk2$tailarea
7575

selectiveInference/R/funs.fixedLogit.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,14 @@ fixedLogitLassoInf=function(x,y,beta,lambda,alpha=.1, type=c("partial"), tol.bet
8989
for(jj in 1:sum(m)){
9090
vj=c(rep(0,sum(m)+1));vj[jj+1]=s2[jj+1]
9191
# compute p-values
92-
junk=mypoly.pval.lee(bbar,A1,b1,vj,MM)
92+
junk=TG.pvalue(bbar, A1, b1, vj, MM)
9393
pv[jj] = junk$pv
9494

9595
vlo[jj]=junk$vlo
9696
vup[jj]=junk$vup
9797
sd[jj]=junk$sd
98-
# junk2=mypoly.int.lee(bbar[-1], A1, b1,vj,MM[-1,-1],alpha=.1)
99-
junk2=mypoly.int.lee(bbar,vj,vlo[jj],vup[jj],sd[jj],alpha=alpha)
98+
99+
junk2=TG.interval(bbar, A1, b1, vj, MM,alpha=alpha)
100100

101101
ci[jj,]=junk2$int
102102
tailarea[jj,] = junk2$tailarea

selectiveInference/R/funs.fs.R

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -299,15 +299,21 @@ fsInf <- function(obj, sigma=NULL, alpha=0.1, k=NULL, type=c("active","all","aic
299299
vj = vreg[j,]
300300
mj = sqrt(sum(vj^2))
301301
vj = vj / mj # Standardize (divide by norm of vj)
302-
a = poly.pval(y,Gj,uj,vj,sigma,bits)
302+
303+
limits.info = TG.limits(y, -Gj, -uj, vj, Sigma=diag(rep(sigma^2, n)))
304+
a = TG.pvalue.base(limits.info, bits=bits)
305+
303306
pv[j] = a$pv
304307
sxj = sx[vars[j]]
305308
vlo[j] = a$vlo * mj / sxj # Unstandardize (mult by norm of vj / sxj)
306309
vup[j] = a$vup * mj / sxj # Unstandardize (mult by norm of vj / sxj)
307310
vmat[j,] = vj * mj / sxj # Unstandardize (mult by norm of vj / sxj)
308311

309-
a = poly.int(y,Gj,uj,vj,sigma,alpha,gridrange=gridrange,
310-
flip=(sign[j]==-1),bits=bits)
312+
a = TG.interval.base(limits.info,
313+
alpha=alpha,
314+
gridrange=gridrange,
315+
flip=(sign[j]==-1),
316+
bits=bits)
311317
ci[j,] = a$int * mj / sxj # Unstandardize (mult by norm of vj / sxj)
312318
tailarea[j,] = a$tailarea
313319
}
@@ -349,15 +355,19 @@ fsInf <- function(obj, sigma=NULL, alpha=0.1, k=NULL, type=c("active","all","aic
349355
Gj = rbind(G,vj)
350356
uj = c(u,0)
351357

352-
a = poly.pval(y,Gj,uj,vj,sigma,bits)
358+
limits.info = TG.limits(y, -Gj, -uj, vj, Sigma=diag(rep(sigma^2, n)))
359+
a = TG.pvalue.base(limits.info, bits=bits)
353360
pv[j] = a$pv
354361
sxj = sx[vars[j]]
355362
vlo[j] = a$vlo * mj / sxj # Unstandardize (mult by norm of vj / sxj)
356363
vup[j] = a$vup * mj / sxj # Unstandardize (mult by norm of vj / sxj)
357364
vmat[j,] = vj * mj / sxj # Unstandardize (mult by norm of vj / sxj)
358365

359-
a = poly.int(y,Gj,uj,vj,sigma,alpha,gridrange=gridrange,
360-
flip=(sign[j]==-1),bits=bits)
366+
a = TG.interval.base(limits.info,
367+
alpha=alpha,
368+
gridrange=gridrange,
369+
flip=(sign[j]==-1),
370+
bits=bits)
361371
ci[j,] = a$int * mj / sxj # Unstandardize (mult by norm of vj / sxj)
362372
tailarea[j,] = a$tailarea
363373
}

selectiveInference/R/funs.inf.R

Lines changed: 76 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,3 @@
1-
# Main p-value function
2-
3-
poly.pval <- function(y, G, u, v, sigma, bits=NULL) {
4-
z = sum(v*y)
5-
vv = sum(v^2)
6-
sd = sigma*sqrt(vv)
7-
8-
rho = G %*% v / vv
9-
vec = (u - G %*% y + rho*z) / rho
10-
vlo = suppressWarnings(max(vec[rho>0]))
11-
vup = suppressWarnings(min(vec[rho<0]))
12-
13-
pv = tnorm.surv(z,0,sd,vlo,vup,bits)
14-
return(list(pv=pv,vlo=vlo,vup=vup))
15-
}
16-
17-
# Main confidence interval function
18-
19-
poly.int <- function(y, G, u, v, sigma, alpha, gridrange=c(-100,100),
20-
gridpts=100, griddepth=2, flip=FALSE, bits=NULL) {
21-
22-
z = sum(v*y)
23-
vv = sum(v^2)
24-
sd = sigma*sqrt(vv)
25-
26-
rho = G %*% v / vv
27-
vec = (u - G %*% y + rho*z) / rho
28-
vlo = suppressWarnings(max(vec[rho>0]))
29-
vup = suppressWarnings(min(vec[rho<0]))
30-
31-
xg = seq(gridrange[1]*sd,gridrange[2]*sd,length=gridpts)
32-
fun = function(x) { tnorm.surv(z,x,sd,vlo,vup,bits) }
33-
34-
int = grid.search(xg,fun,alpha/2,1-alpha/2,gridpts,griddepth)
35-
tailarea = c(fun(int[1]),1-fun(int[2]))
36-
37-
if (flip) {
38-
int = -int[2:1]
39-
tailarea = tailarea[2:1]
40-
}
41-
42-
return(list(int=int,tailarea=tailarea))
43-
}
44-
45-
##############################
461

472
# Assuming that grid is in sorted order from smallest to largest,
483
# and vals are monotonically increasing function values over the
@@ -247,48 +202,95 @@ aicStop <- function(x, y, action, df, sigma, mult=2, ntimes=2) {
247202

248203
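# The TG.* functions below compute truncated-Gaussian (polyhedral lemma) limits, p-values and intervals; they are shared by the fixed-lambda (Gaussian, binomial and Cox), forward-stepwise and LAR inference routines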
249204

250-
mypoly.pval.lee=
251-
function(y, A, b, eta, Sigma, bits=NULL) {
205+
# Compute the truncation interval and SD of the corresponding Gaussian
206+
207+
TG.limits = function(Z, A, b, eta, Sigma=NULL) {
208+
209+
target_estimate = sum(as.numeric(eta) * as.numeric(Z))
210+
211+
if (max(A %*% as.numeric(Z) - b) > 0) {
212+
warning('Contsraint not satisfied. A %*% Z should be elementwise less than or equal to b')
213+
}
214+
215+
if (is.null(Sigma)) {
216+
Sigma = diag(rep(1, n))
217+
}
218+
252219
# compute pvalues from poly lemma: full version from Lee et al for full matrix Sigma
253-
nn=length(y)
254-
eta=as.vector(eta)
255-
temp = sum(eta*y)
256-
vv=as.numeric(matrix(eta,nrow=1,ncol=nn)%*%Sigma%*%eta)
257-
cc = Sigma%*%eta/vv
258-
259-
z=(diag(nn)-matrix(cc,ncol=1)%*%eta)%*%y
260-
rho=A%*%cc
220+
221+
n = length(Z)
222+
eta = matrix(eta, ncol=1, nrow=n)
223+
b = as.vector(b)
224+
var_estimate = sum(matrix(eta, nrow=1, ncol=n) %*% (Sigma %*% matrix(eta, ncol=1, nrow=n)))
225+
cross_cov = Sigma %*% matrix(eta, ncol=1, nrow=n)
261226

262-
vec = (b- A %*% z)/rho
263-
vlo = suppressWarnings(max(vec[rho<0]))
264-
vup = suppressWarnings(min(vec[rho>0]))
265-
sd=sqrt(vv)
266-
pv = tnorm.surv(temp,0,sd,vlo,vup,bits)
267-
return(list(pv=pv,vlo=vlo,vup=vup,sd=sd))
227+
resid = (diag(n) - matrix(cross_cov / var_estimate, ncol=1, nrow=n) %*% matrix(eta, nrow=1, ncol=n)) %*% Z
228+
rho = A %*% cross_cov / var_estimate
229+
vec = (b - as.numeric(A %*% resid)) / rho
230+
231+
vlo = suppressWarnings(max(vec[rho < 0]))
232+
vup = suppressWarnings(min(vec[rho > 0]))
233+
234+
sd = sqrt(var_estimate)
235+
return(list(vlo=vlo, vup=vup, sd=sd, estimate=target_estimate))
236+
}
237+
238+
TG.pvalue = function(Z, A, b, eta, Sigma=NULL, null_value=0, bits=NULL) {
239+
240+
limits.info = TG.limits(Z, A, b, eta, Sigma)
241+
242+
return(TG.pvalue.base(limits.info, null_value=null_value, bits=bits))
268243
}
269244

245+
TG.interval = function(Z, A, b, eta, Sigma=NULL, alpha=0.1,
246+
gridrange=c(-100,100),
247+
gridpts=100,
248+
griddepth=2,
249+
flip=FALSE,
250+
bits=NULL) {
251+
252+
limits.info = TG.limits(Z, A, b, eta, Sigma)
253+
254+
return(TG.interval.base(limits.info,
255+
alpha=alpha,
256+
gridrange=gridrange,
257+
griddepth=griddepth,
258+
flip=flip,
259+
bits=bits))
260+
}
270261

262+
TG.interval.base = function(limits.info, alpha=0.1,
263+
gridrange=c(-100,100),
264+
gridpts=100,
265+
griddepth=2,
266+
flip=FALSE,
267+
bits=NULL) {
271268

272-
mypoly.int.lee=
273-
function(y,eta,vlo,vup,sd, alpha, gridrange=c(-100,100),gridpts=100, griddepth=2, flip=FALSE, bits=NULL) {
274269
# compute sel intervals from poly lemmma, full version from Lee et al for full matrix Sigma
275270

276-
temp = sum(eta*y)
277-
278-
xg = seq(gridrange[1]*sd,gridrange[2]*sd,length=gridpts)
279-
fun = function(x) { tnorm.surv(temp,x,sd,vlo,vup,bits) }
271+
param_grid = seq(gridrange[1] * limits.info$sd, gridrange[2] * limits.info$sd, length=gridpts)
280272

281-
int = grid.search(xg,fun,alpha/2,1-alpha/2,gridpts,griddepth)
282-
tailarea = c(fun(int[1]),1-fun(int[2]))
273+
pivot = function(param) {
274+
tnorm.surv(limits.info$estimate, param, limits.info$sd, limits.info$vlo, limits.info$vup, bits)
275+
}
283276

284-
if (flip) {
285-
int = -int[2:1]
286-
tailarea = tailarea[2:1]
287-
}
277+
interval = grid.search(param_grid, pivot, alpha/2, 1-alpha/2, gridpts, griddepth)
278+
tailarea = c(pivot(interval[1]), 1- pivot(interval[2]))
279+
280+
if (flip) {
281+
interval = -interval[2:1]
282+
tailarea = tailarea[2:1]
283+
}
288284

289-
return(list(int=int,tailarea=tailarea))
285+
# int is not a good variable name, synonymous with integer...
286+
return(list(int=interval,
287+
tailarea=tailarea))
290288
}
291289

290+
TG.pvalue.base = function(limits.info, null_value=0, bits=NULL) {
291+
pv = tnorm.surv(limits.info$estimate, null_value, limits.info$sd, limits.info$vlo, limits.info$vup, bits)
292+
return(list(pv=pv, vlo=limits.info$vlo, vup=limits.info$vup, sd=limits.info$sd))
293+
}
292294

293295

294296
mydiag=function(x){

selectiveInference/R/funs.lar.R

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -379,15 +379,20 @@ larInf <- function(obj, sigma=NULL, alpha=0.1, k=NULL, type=c("active","all","ai
379379
vj = vreg[j,]
380380
mj = sqrt(sum(vj^2))
381381
vj = vj / mj # Standardize (divide by norm of vj)
382-
a = poly.pval(y,Gj,uj,vj,sigma,bits)
382+
383+
limits.info = TG.limits(y, -Gj, -uj, vj, Sigma=diag(rep(sigma^2, n)))
384+
a = TG.pvalue.base(limits.info, bits=bits)
383385
pv[j] = a$pv
384386
sxj = sx[vars[j]]
385387
vlo[j] = a$vlo * mj / sxj # Unstandardize (mult by norm of vj / sxj)
386388
vup[j] = a$vup * mj / sxj # Unstandardize (mult by norm of vj)
387389
vmat[j,] = vj * mj / sxj # Unstandardize (mult by norm of vj / sxj)
388390

389-
a = poly.int(y,Gj,uj,vj,sigma,alpha,gridrange=gridrange,
390-
flip=(sign[j]==-1),bits=bits)
391+
a = TG.interval.base(limits.info,
392+
alpha=alpha,
393+
gridrange=gridrange,
394+
flip=(sign[j]==-1),
395+
bits=bits)
391396
ci[j,] = a$int * mj / sxj # Unstandardize (mult by norm of vj / sxj)
392397
tailarea[j,] = a$tailarea
393398

@@ -433,15 +438,20 @@ larInf <- function(obj, sigma=NULL, alpha=0.1, k=NULL, type=c("active","all","ai
433438
Gj = rbind(G,vj)
434439
uj = c(u,0)
435440

436-
a = poly.pval(y,Gj,uj,vj,sigma,bits)
441+
limits.info = TG.limits(y, -Gj, -uj, vj, Sigma=diag(rep(sigma^2, n)))
442+
a = TG.pvalue.base(limits.info, bits=bits)
443+
437444
pv[j] = a$pv
438445
sxj = sx[vars[j]]
439446
vlo[j] = a$vlo * mj / sxj # Unstandardize (mult by norm of vj / sxj)
440447
vup[j] = a$vup * mj / sxj # Unstandardize (mult by norm of vj / sxj)
441448
vmat[j,] = vj * mj / sxj # Unstandardize (mult by norm of vj / sxj)
442449

443-
a = poly.int(y,Gj,uj,vj,sigma,alpha,gridrange=gridrange,
444-
flip=(sign[j]==-1),bits=bits)
450+
a = TG.interval.base(limits.info,
451+
alpha=alpha,
452+
gridrange=gridrange,
453+
flip=(sign[j]==-1),
454+
bits=bits)
445455
ci[j,] = a$int * mj / sxj # Unstandardize (mult by norm of vj / sxj)
446456
tailarea[j,] = a$tailarea
447457
}

0 commit comments

Comments
 (0)