Skip to content

Commit a6fcd6e

Browse files
fix of signs for full
1 parent 151c21d commit a6fcd6e

File tree

1 file changed

+29
-12
lines changed

1 file changed

+29
-12
lines changed

selectiveInference/R/funs.fixed.R

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -82,17 +82,20 @@ fixedLassoInf <- function(x, y, beta,
8282

8383
tol.coef = tol.beta * sqrt(n / colSums(x^2))
8484
vars = which(abs(beta) > tol.coef)
85+
sign_vars = sign(beta[vars])
8586

8687
if(sum(vars)==0){
8788
cat("Empty model",fill=T)
8889
return()
8990
}
90-
if (any(sign(g[vars]) != sign(beta[vars])))
91+
92+
if (any(sign(g[vars]) != sign_vars)) {
9193
warning(paste("Solution beta does not satisfy the KKT conditions",
9294
"(to within specified tolerances). You might try rerunning",
9395
"glmnet with a lower setting of the",
9496
"'thresh' parameter, for a more accurate convergence."))
95-
97+
}
98+
9699
# Get lasso polyhedral region, of form Gy >= u
97100

98101
logical.vars=rep(FALSE,p)
@@ -132,13 +135,19 @@ fixedLassoInf <- function(x, y, beta,
132135
}
133136

134137
# add additional targets for inference if provided
135-
if (!is.null(add.targets)) vars = sort(unique(c(vars,add.targets,recursive=T)))
136-
137-
k = length(vars)
138+
if (!is.null(add.targets)) {
139+
# vars is boolean...
140+
old_vars = vars & TRUE
141+
vars[add.targets] = TRUE
142+
sign_vars = sign(beta[vars])
143+
sign_vars[!old_vars] = NA
144+
stop("`add.targets` not fully implemented yet")
145+
}
146+
147+
k = length(vars)
138148
pv = vlo = vup = numeric(k)
139149
vmat = matrix(0,k,n)
140150
ci = tailarea = matrix(0,k,2)
141-
sign = numeric(k)
142151

143152
if (type=="full" & p > n) {
144153
if (intercept == T) {
@@ -202,28 +211,36 @@ fixedLassoInf <- function(x, y, beta,
202211
vj = M[j,]
203212
mj = sqrt(sum(vj^2))
204213
vj = vj / mj # Standardize (divide by norm of vj)
205-
sign[j] = sign(sum(vj*y))
206-
vj = sign[j] * vj
214+
215+
if (!is.na(sign_vars[j])) {
216+
vj = sign_vars[j] * vj
217+
}
207218

208219
limits.info = TG.limits(y, A, b, vj, Sigma=diag(rep(sigma^2, n)))
209220
a = TG.pvalue.base(limits.info, null_value=null_value[j], bits=bits)
210221
pv[j] = a$pv
222+
if (is.na(sign_vars[j])) { # for variables not in the active set, report 2-sided pvalue
223+
pv[j] = 2 * min(pv[j], 1 - pv[j])
224+
}
211225
vlo[j] = a$vlo * mj # Unstandardize (mult by norm of vj)
212226
vup[j] = a$vup * mj # Unstandardize (mult by norm of vj)
213-
vmat[j,] = vj * mj * sign[j] # Unstandardize (mult by norm of vj)
214-
227+
if (!is.na(sign_vars[j])) {
228+
vmat[j,] = vj * mj * sign_vars[j] # Unstandardize (mult by norm of vj) and fix sign
229+
} else {
230+
vmat[j,] = vj * mj # Unstandardize (mult by norm of vj)
231+
}
215232
a = TG.interval.base(limits.info,
216233
alpha=alpha,
217234
gridrange=gridrange,
218-
flip=(sign[j]==-1),
235+
flip=(sign_vars[j]==-1),
219236
bits=bits)
220237
ci[j,] = (a$int-null_value[j]) * mj # Unstandardize (mult by norm of vj)
221238
tailarea[j,] = a$tailarea
222239
}
223240

224241
out = list(type=type,lambda=lambda,pv=pv,ci=ci,
225242
tailarea=tailarea,vlo=vlo,vup=vup,vmat=vmat,y=y,
226-
vars=vars,sign=sign,sigma=sigma,alpha=alpha,
243+
vars=vars,sign=sign_vars,sigma=sigma,alpha=alpha,
227244
sd=sigma*sqrt(rowSums(vmat^2)),
228245
coef0=vmat%*%y,
229246
call=this.call)

0 commit comments

Comments
 (0)