Skip to content

Commit bf1fa5e

Browse files
committed
improved cvlar, but it doesn't give same results as glmnet
1 parent 7faeba8 commit bf1fa5e

File tree

1 file changed

+260
-52
lines changed
  • forLater/josh/selectiveInference/R

1 file changed

+260
-52
lines changed

forLater/josh/selectiveInference/R/cv.R

Lines changed: 260 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ cvMakeFolds <- function(x, nfolds = 5) {
1919
}
2020

2121
# To interface with glmnet
22-
foldid <- function(folds) {
22+
foldidglmnet <- function(folds) {
2323
n <- sum(sapply(folds, length))
2424
glmnetfoldid <- rep(0, n)
2525
for (ind in 1:length(folds)) {
@@ -28,6 +28,18 @@ foldid <- function(folds) {
2828
glmnetfoldid
2929
}
3030

31+
# cv.glmnet and estimateSigma mashup
32+
# Cross-validated glmnet fit (no intercept) augmented with a residual-based
# noise estimate evaluated at lambda.min.
#
# Arguments:
#   x      - predictor matrix, handed straight to cv.glmnet
#   y      - response vector
#   foldid - fold assignment vector, forwarded as cv.glmnet's foldid
#
# Returns the cv.glmnet object with two extra components:
#   $sigma - sqrt(RSS / (length(y) - df - 1)) at lambda.min
#   $df    - number of nonzero coefficients at lambda.min
cvglmnetlar <- function(x, y, foldid) {
  cv_result <- cv.glmnet(x, y, intercept = FALSE, foldid = foldid)
  best_lambda <- cv_result$lambda.min

  # Model size and in-sample residual sum of squares at the selected lambda
  n_nonzero <- sum(coef(cv_result, s = best_lambda) != 0)
  fitted_vals <- predict(cv_result, x, s = best_lambda)
  rss <- sum((y - fitted_vals)^2)

  cv_result$sigma <- sqrt(rss / (length(y) - n_nonzero - 1))
  cv_result$df <- n_nonzero
  cv_result
}
41+
42+
3143
#--------------------------------------
3244
# Functions for computing quadratic form for cv-error
3345
#--------------------------------------
@@ -59,6 +71,8 @@ cvProductHat <- function(folds, inds, finds, ginds, hat_matrices) {
5971
return(Reduce('+', terms))
6072
}
6173

74+
# This is too "clever," I can't easily understand it
75+
# simpler code is preferable for maintenance and forking etc
6276
cvRSSquad <- function(x, folds, active.sets) {
6377
hat_matrices <- cvHatMatrix(x, folds, active.sets)
6478
nfolds <- length(folds)
@@ -78,9 +92,28 @@ cvRSSquad <- function(x, folds, active.sets) {
7892
return(Q)
7993
}
8094

95+
# Choose the number of steps that minimises the cross-validated RSS.
#
# Arguments:
#   x           - predictor matrix
#   y           - response vector
#   maxsteps    - largest path length to consider (must be >= 1)
#   folds       - list of held-out index vectors, one per fold
#   active.sets - list (one per fold) of variables in order of entry
#
# Returns a list with
#   sstar    - the step count whose quadratic form t(y) Q y is smallest
#   RSSquads - the remaining quadratic forms, each expressed as a difference
#              from the selected one (the all-zero entry for sstar is dropped)
cvopt <- function(x, y, maxsteps, folds, active.sets) {
  if (maxsteps < 1) stop("maxsteps must be at least 1")

  # One quadratic form per candidate path length s.
  # Can this be optimized with smart updating of each model along each path?
  RSSquads <- lapply(seq_len(maxsteps), function(s) {
    initial.active <- lapply(active.sets, function(a) a[seq_len(s)])
    cvRSSquad(x, folds, initial.active)
  })

  RSSs <- vapply(RSSquads, function(Q) as.numeric(t(y) %*% Q %*% y), numeric(1))

  # Cross-validation selects the SMALLEST error; which.max picked the worst
  # model and disagreed with the earlier cvfs/cvlar implementations.
  sstar <- which.min(RSSs)
  quadstar <- RSSquads[[sstar]]

  RSSquads <- lapply(RSSquads, function(quad) quad - quadstar)
  RSSquads[[sstar]] <- NULL # remove the all zeroes case
  list(sstar = sstar, RSSquads = RSSquads)
}
112+
81113

82114
#--------------------------------------
83115
# Functions for forward stepwise
116+
# broke this while making cvlar
84117
#--------------------------------------
85118

86119
cvfs <- function(x, y, index = 1:ncol(x), maxsteps, sigma = NULL, intercept = TRUE, center = TRUE, normalize = TRUE, nfolds = 5) {
@@ -130,18 +163,7 @@ cvfs <- function(x, y, index = 1:ncol(x), maxsteps, sigma = NULL, intercept = TR
130163
}
131164
#projections <- do.call(c, projections)
132165

133-
RSSquads <- list()
134-
for (s in 1:maxsteps) {
135-
initial.active <- lapply(active.sets, function(a) a[1:s])
136-
RSSquads[[s]] <- cvRSSquad(X, folds, initial.active)
137-
}
138-
139-
RSSs <- lapply(RSSquads, function(Q) t(Y) %*% Q %*% Y)
140-
sstar <- which.min(RSSs)
141-
quadstar <- RSSquads[sstar][[1]]
142166

143-
RSSquads <- lapply(RSSquads, function(quad) quad - quadstar)
144-
RSSquads[[sstar]] <- NULL # remove the all zeroes case
145167

146168
fit <- groupfs(X, Y, index=index, maxsteps=sstar, sigma=sigma, intercept=intercept, center=center, normalize=normalize)
147169
fit$cvobj <- cvobj
@@ -157,54 +179,240 @@ cvfs <- function(x, y, index = 1:ncol(x), maxsteps, sigma = NULL, intercept = TR
157179
# Functions for lar
158180
#--------------------------------------
159181

160-
cvlar <- function(x, y, maxsteps) { # other args
161-
folds <- cvMakeFolds(x)
162-
models <- lapply(folds, function(fold) {
163-
x.train <- x
164-
y.train <- y
165-
x.train[fold,] <- 0
166-
y.train[fold] <- 0
167-
x.test <- x[fold,]
168-
y.test <- y[fold]
169-
larpath.train <- lar(x.train, y.train, maxsteps = maxsteps, intercept = F, normalize = F)
170-
return(larpath.train)
171-
})
182+
# Least angle regression with the number of steps chosen by cross-validation.
#
# Arguments:
#   x        - predictor matrix
#   y        - response vector
#   maxsteps - maximum number of LAR steps fitted within each fold
#   folds    - optional list of held-out index vectors; built with
#              cvMakeFolds(x) when NULL
#
# Returns the lar fit on the full data (run for sstar steps) with class
# "cvlar" and these extra components: ols, sigma, RSSquads, tallGamma, khat,
# folds, call.
cvlar <- function(x, y, maxsteps, folds = NULL) { # other args
  this.call <- match.call()
  if (is.null(folds)) folds <- cvMakeFolds(x)

  # Fit one path per fold; held-out rows are zeroed rather than dropped so
  # every fold's fit keeps the full n-row dimension.
  models <- lapply(folds, function(fold) {
    x.train <- x
    y.train <- y
    x.train[fold, ] <- 0
    y.train[fold] <- 0
    lar(x.train, y.train, maxsteps = maxsteps,
        intercept = FALSE, normalize = FALSE)
  })

  active.sets <- lapply(models, function(model) model$action)
  #lambdas <- lapply(models, function(model) model$lambda)
  #lmin <- min(unlist(lambdas))
  cvmin <- cvopt(x, y, maxsteps, folds, active.sets)
  sstar <- cvmin$sstar

  fit <- lar(x, y, maxsteps = sstar, intercept = FALSE, normalize = FALSE)

  # Refit the selected variables by least squares to estimate the noise level
  fit$ols <- lsfit(x[, fit$action, drop = FALSE], y, intercept = FALSE)
  names(fit$ols$coefficients) <- fit$action
  fit$sigma <- sqrt(sum((fit$ols$residuals)^2) /
                      (length(y) - length(fit$action) - 1))
  fit$RSSquads <- cvmin$RSSquads
  # tall Gamma encoding all cv-model paths
  fit$tallGamma <- do.call(rbind, lapply(models, function(model) model$Gamma))
  fit$khat <- sstar
  fit$folds <- folds
  fit$call <- this.call
  class(fit) <- "cvlar"
  # more to do here?
  fit
}
215+
216+
# cvlarInf <- function(obj, ...) {
217+
# pv.unadj <- larInf(obj, type = "all", k = obj$khat, verbose = T, ...)
218+
# obj$Gamma <- rbind(obj$Gamma, obj$tallGamma)
219+
# pv.adj <- larInf(obj, type = "all", k = obj$khat, verbose = T, ...)
220+
# return(list(pv.unadj = pv.unadj, pv.adj = pv.adj))
221+
# }
222+
223+
# Selective inference for a cvlar fit: p-values for the variables active at
# step obj$khat, computed with and without the extra constraints from the
# per-fold paths.
#
# Arguments:
#   obj       - object of class "cvlar" (as produced by cvlar)
#   sigma, k  - NOTE(review): accepted for interface compatibility but
#               currently overridden by obj$sigma and obj$khat, exactly as in
#               the original code; confirm whether callers should be able to
#               supply their own values.
#   alpha     - level stored in the result (used by the print method)
#   gridrange, mult, ntimes - currently unused (interval code is disabled)
#   bits      - precision bits for Rmpfr; reset to NULL if Rmpfr is missing
#   verbose   - print progress for each variable
#
# Returns a list of class "cvlarInf" with components type, k, khat, pv,
# pv.unadj, vlo, vup, vmat, y, vars, sign, sigma, alpha, call.
cvlarInf <- function (obj, sigma, alpha = 0.1,
                      k = NULL,
                      gridrange = c(-100, 100),
                      bits = NULL, mult = 2,
                      ntimes = 2, verbose = FALSE) {
  this.call <- match.call()
  #checkargs.misc(sigma = sigma, alpha = alpha, k = k, gridrange = gridrange, mult = mult, ntimes = ntimes)
  if (!inherits(obj, "cvlar"))
    stop("obj must be an object of class cvlar")
  if (!is.null(bits) && !requireNamespace("Rmpfr", quietly = TRUE)) {
    warning("Package Rmpfr is not installed, reverting to standard precision")
    bits <- NULL
  }
  x <- obj$x
  y <- obj$y
  n <- nrow(x)
  sx <- obj$sx
  k <- obj$khat
  sigma <- obj$sigma

  # Unadjusted p-values: run larInf on a copy carrying class "lar", so the
  # caller-visible object keeps its "cvlar" class throughout.
  lar_obj <- obj
  class(lar_obj) <- "lar"
  pv.unadj <- larInf(lar_obj, type = "all", sigma = sigma, k = k)

  # Adjusted p-values: stack the full-data constraints on top of the
  # constraints from every fold's path.
  G <- rbind(obj$Gamma, obj$tallGamma)
  u <- rep(0, nrow(G))
  kk <- k
  pv <- vlo <- vup <- numeric(kk)
  vmat <- matrix(0, kk, n)
  sgn <- numeric(kk)
  vars <- obj$action[seq_len(kk)]
  xa <- x[, vars, drop = FALSE]
  M <- pinv(crossprod(xa)) %*% t(xa)
  for (j in seq_len(kk)) {
    if (verbose)
      cat(sprintf("Inference for variable %i ...\n",
                  vars[j]))
    vj <- M[j, ]
    mj <- sqrt(sum(vj^2))
    vj <- vj / mj
    sgn[j] <- sign(sum(vj * y))
    vj <- sgn[j] * vj
    Gj <- rbind(G, vj)
    uj <- c(u, 0)
    a <- poly.pval(y, Gj, uj, vj, sigma, bits)
    pv[j] <- a$pv
    sxj <- sx[vars[j]]
    vlo[j] <- a$vlo * mj / sxj
    vup[j] <- a$vup * mj / sxj
    vmat[j, ] <- vj * mj / sxj

    #a = poly.int(y, Gj, uj, vj, sigma, alpha, gridrange = gridrange, flip = (sign[j] == -1), bits = bits)
    #ci[j, ] = a$int * mj/sxj
    #tailarea[j, ] = a$tailarea
  }
  # `type` and `khat` were undefined names here, which made every call fail;
  # the analysis is always of type "all" at step obj$khat.
  out <- list(type = "all", k = k, khat = obj$khat, pv = pv,
              pv.unadj = pv.unadj, vlo = vlo, vup = vup, vmat = vmat,
              y = y, vars = vars, sign = sgn, sigma = sigma,
              alpha = alpha, call = this.call)
  class(out) <- "cvlarInf"
  out
}
172290

173-
active.sets <- lapply(models, function(model) model$action)
174-
lambdas <- lapply(models, function(model) model$lambda)
175-
lmin <- min(unlist(lambdas))
176291

177-
# Interpolate lambda grid or parametrize by steps?
178-
# interpolation probably requires re-writing cvRSSquads for
179-
# penalized fits in order to make sense
180292

181-
# do steps for now just to have something that works?
293+
# Truncated-normal p-value for the contrast v'y given the polyhedral
# constraint encoded by G and u.
#
# Arguments:
#   y     - response vector
#   G, u  - constraint matrix and right-hand-side vector
#   v     - contrast vector
#   sigma - noise standard deviation
#   bits  - optional precision bits forwarded to tnorm.surv
#
# Returns list(pv, vlo, vup): the p-value and the lower/upper truncation
# limits for v'y.
poly.pval <- function(y, G, u, v, sigma, bits=NULL) {
  v_sq_norm <- sum(v^2)
  stat <- sum(v * y)

  rho <- G %*% v / v_sq_norm
  limits <- (u - G %*% y + rho * stat) / rho

  # An empty selection yields -Inf / Inf together with a warning, which is
  # deliberately silenced.
  is_pos <- rho > 0
  is_neg <- rho < 0
  vlo <- suppressWarnings(max(limits[is_pos]))
  vup <- suppressWarnings(min(limits[is_neg]))

  pv <- tnorm.surv(stat, 0, sigma * sqrt(v_sq_norm), vlo, vup, bits)
  list(pv = pv, vlo = vlo, vup = vup)
}
182306

183-
RSSquads <- list()
184-
for (s in 1:maxsteps) {
185-
initial.active <- lapply(active.sets, function(a) a[1:s])
186-
RSSquads[[s]] <- cvRSSquad(x, folds, initial.active)
187-
}
307+
# Pseudo-inverse of a square matrix via its eigendecomposition.
#
# Arguments:
#   A   - square matrix (only the real parts of its eigenvalues and
#         eigenvectors are used)
#   tol - eigenvalues <= tol are treated as zero
#
# Returns a matrix of the same dimension as A.
pinv <- function(A, tol=.Machine$double.eps) {
  e <- eigen(A)
  v <- Re(e$vectors)
  d <- Re(e$values)

  # Decide which eigenvalues to keep BEFORE inverting. Thresholding a second
  # time after inversion zeroed the reciprocals of very large eigenvalues
  # (1/d < tol) and left an eigenvalue exactly equal to tol un-inverted.
  keep <- d > tol
  d_inv <- numeric(length(d))
  d_inv[keep] <- 1 / d[keep]

  # Scaling the rows of t(v) equals diag(d_inv) %*% t(v), and also works for
  # the 1 x 1 case, where diag(d_inv) would not build the intended matrix.
  v %*% (d_inv * t(v))
}
188316

189-
RSSs <- lapply(RSSquads, function(Q) t(y) %*% Q %*% y)
190-
sstar <- which.min(RSSs)
191-
quadstar <- RSSquads[sstar][[1]]
317+
# Survival function of a normal distribution truncated to [a, b], evaluated
# at z for each entry of `mean`.
#
# Arguments:
#   z       - observed value (clamped into [a, b] before use)
#   mean    - vector of means; -Inf maps to 0 and Inf maps to 1
#   sd      - standard deviation
#   a, b    - lower and upper truncation limits
#   bits    - optional precision bits forwarded to mpfr.tnorm.surv
#
# Returns a numeric vector the same length as `mean`.
tnorm.surv <- function(z, mean, sd, a, b, bits=NULL) {
  clipped <- min(z, b)
  z <- max(clipped, a)

  # Degenerate means have a known answer
  surv <- numeric(length(mean))
  surv[mean == -Inf] <- 0
  surv[mean == Inf] <- 1

  # Direct calculation first (multi-precision when bits is given)
  finite_idx <- is.finite(mean)
  finite_means <- mean[finite_idx]
  direct <- mpfr.tnorm.surv(z, finite_means, sd, a, b, bits)

  # Wherever that produced NA, fall back to the approximation
  failed <- is.na(direct)
  if (any(failed)) {
    direct[failed] <- bryc.tnorm.surv(z, finite_means[failed], sd, a, b)
  }

  surv[finite_idx] <- direct
  surv
}
192337

193-
# Need to add these later?
194-
#RSSquads <- lapply(RSSquads, function(quad) quad - quadstar)
195-
#RSSquads[[sstar]] <- NULL # remove the all zeroes case
338+
# Truncated-normal survival probability
#   (pnorm(b) - pnorm(z)) / (pnorm(b) - pnorm(a))
# on the standardised scale, optionally in multi-precision arithmetic.
#
# Arguments:
#   z        - evaluation point
#   mean, sd - location and scale used to standardise z, a and b
#   a, b     - lower and upper truncation limits
#   bits     - when not NULL, the calculation uses Rmpfr with this many
#              precision bits. This will fail if Rmpfr is not installed; the
#              caller is expected to have checked that and reset bits to NULL.
mpfr.tnorm.surv <- function(z, mean=0, sd=1, a, b, bits=NULL) {
  standardize <- function(val) (val - mean) / sd

  if (is.null(bits)) {
    # Standard floating point calculation
    zs <- standardize(z)
    lo <- standardize(a)
    hi <- standardize(b)
    return((pnorm(hi) - pnorm(zs)) / (pnorm(hi) - pnorm(lo)))
  }

  # Multi-precision calculation
  to_mpfr <- function(val) Rmpfr::mpfr(standardize(val), precBits = bits)
  zs <- to_mpfr(z)
  lo <- to_mpfr(a)
  hi <- to_mpfr(b)
  numerator <- Rmpfr::pnorm(hi) - Rmpfr::pnorm(zs)
  denominator <- Rmpfr::pnorm(hi) - Rmpfr::pnorm(lo)
  as.numeric(numerator / denominator)
}
196358

197-
fit <- lar(x, y, maxsteps=sstar, intercept = F, normalize = F)
359+
# Approximate truncated-normal survival probability built from the rational
# function ff(); tnorm.surv uses it as a fallback when the direct calculation
# returns NA.
#
# Arguments:
#   z        - evaluation point
#   mean, sd - location and scale used to standardise z, a and b
#   a, b     - lower and upper truncation limits
#
# Returns the approximation clamped to [0, 1].
bryc.tnorm.surv <- function(z, mean=0, sd=1, a, b) {
  zs <- (z - mean) / sd
  lo <- (a - mean) / sd
  hi <- (b - mean) / sd
  count <- length(mean)

  # Lower-limit term: default value, replaced wherever the limit is finite
  lower_term <- exp(zs * zs)
  has_lo <- lo > -Inf
  lower_term[has_lo] <- ff(lo[has_lo]) * exp(-(lo[has_lo]^2 - zs[has_lo]^2) / 2)

  # Upper-limit term: zero unless the limit is finite
  upper_term <- rep(0, count)
  has_hi <- hi < Inf
  upper_term[has_hi] <- ff(hi[has_hi]) * exp(-(hi[has_hi]^2 - zs[has_hi]^2) / 2)

  approx <- (ff(zs) - upper_term) / (lower_term - upper_term)

  # The approximation can land outside [0, 1]; clamp it rather than return NA
  pmin(1, pmax(0, approx))
}
198379

199-
# Very tall Gamma encoding all cv-model paths
200-
Gamma <- do.call(rbind, lapply(models, function(model) return(model$Gamma)))
201-
fit$Gamma <- rbind(fit$Gamma, Gamma)
202-
fit$khat <- sstar
203-
fit$folds <- folds
204-
# more to do here
205-
return(fit)
380+
# Rational function of z (quadratic over cubic) used by bryc.tnorm.surv.
# Vectorised over z.
ff <- function(z) {
  numerator <- z^2 + 5.575192695 * z + 12.7743632
  denominator <- z^3 * sqrt(2 * pi) + 14.38718147 * z * z +
    31.53531977 * z + 2 * 12.77436324
  numerator / denominator
}
207384

208-
cvlarInf <- function(obj) {
209-
larInf(obj, type = "all", k = obj$khat)
385+
386+
387+
# Print method for "cvlar" objects: shows the original call followed by a
# Step / Var / Sign table built from x$action and x$sign.
# Returns NULL invisibly.
print.cvlar <- function(x, ...) {
  cat("\nCall:\n")
  dput(x$call)

  cat("\nSequence of LAR moves:\n")
  moves <- x$action
  move_table <- cbind(1:length(moves), moves, x$sign)
  # Blank row labels keep the printed table compact
  dimnames(move_table) <- list(rep("", nrow(move_table)),
                               c("Step", "Var", "Sign"))
  print(move_table)
  invisible()
}
399+
400+
# Print method for "cvlarInf" objects: shows the call, the noise level, the
# step and alpha, then a table with one row per tested variable
# (Var, Coef, Z-score, P-value). Returns NULL invisibly.
print.cvlarInf <- function(x, ...) {
  cat("\nCall:\n")
  dput(x$call)

  cat(sprintf("\nStandard deviation of noise (specified or estimated) sigma = %0.3f\n",
              x$sigma))

  cat(sprintf("\nTesting results at step = %i, with alpha = %0.3f\n", x$k, x$alpha))
  cat("", fill = TRUE)

  # Signed contrast values and the scale used to standardise them
  coef_est <- x$sign * x$vmat %*% x$y
  coef_scale <- x$sigma * sqrt(rowSums(x$vmat^2))

  result_table <- cbind(x$vars,
                        round(coef_est, 3),
                        round(coef_est / coef_scale, 3),
                        round(x$pv, 3))
  dimnames(result_table) <- list(rep("", nrow(result_table)),
                                 c("Var", "Coef", "Z-score", "P-value"))
  print(result_table)
  invisible()
}

0 commit comments

Comments
 (0)