Skip to content

Commit be257f1

Browse files
small cleanup of fixedLassoPoly
1 parent 9d54c53 commit be257f1

File tree

1 file changed

+29
-20
lines changed

1 file changed

+29
-20
lines changed

selectiveInference/R/funs.fixed.R

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,9 @@ fixedLassoInf <- function(x, y, beta,
8080
warning(paste("Solution beta does not satisfy the KKT conditions",
8181
"(to within specified tolerances)"))
8282

83-
tol.coef = tol.beta * sqrt(n^2 / colSums(x^2))
84-
# print(tol.coef)
85-
vars = which(abs(beta) > tol.coef)
86-
# vars = abs(beta) > tol.coef
87-
# print(beta)
88-
# print(vars)
83+
tol.coef = tol.beta * sqrt(n / colSums(x^2))
84+
vars = which(abs(beta) > tol.coef)
85+
8986
if(sum(vars)==0){
9087
cat("Empty model",fill=T)
9188
return()
@@ -97,10 +94,17 @@ fixedLassoInf <- function(x, y, beta,
9794
"'thresh' parameter, for a more accurate convergence."))
9895

9996
# Get lasso polyhedral region, of form Gy >= u
100-
logical.vars=rep(FALSE,p)
101-
logical.vars[vars]=TRUE
102-
if (type == 'full') out = fixedLassoPoly(x,y,lambda,beta,logical.vars,inactive=TRUE)
103-
else out = fixedLassoPoly(x,y,lambda,beta,logical.vars)
97+
98+
logical.vars=rep(FALSE,p)
99+
logical.vars[vars]=TRUE
100+
101+
if (type == 'full') {
102+
out = fixedLassoPoly(x, y, lambda, beta, logical.vars, inactive=TRUE)
103+
}
104+
else {
105+
out = fixedLassoPoly(x, y, lambda, beta, logical.vars)
106+
}
107+
104108
A = out$A
105109
b = out$b
106110

@@ -233,34 +237,39 @@ logical.vars[vars]=TRUE
233237

234238
fixedLassoPoly =
235239
function(X, y, lambda, beta, active, inactive = FALSE) {
236-
Xa = X[,active,drop=F]
237-
Xac = X[,!active,drop=F]
240+
Xa = X[, active, drop=FALSE]
241+
Xac = X[, !active, drop=FALSE]
238242
Xai = pinv(crossprod(Xa))
239243
Xap = Xai %*% t(Xa)
240244

241245
za = sign(beta[active])
242246
if (length(za)>1) dz = diag(za)
243247
if (length(za)==1) dz = matrix(za,1,1)
244248

249+
if(length(lambda)>1) {
250+
lambdaA= lambda[active]
251+
lambdaI = lambda[!active]
252+
} else {
253+
lambdaA = rep(lambda, sum(active))
254+
lambdaI = rep(lambda, sum(!active))
255+
}
245256
if (inactive) { # should we include the inactive constraints?
246-
R = diag(1,nrow(Xa)) - Xa %*% Xap # R is residual forming matrix of selected model
257+
R = diag(rep(1, nrow(Xa))) - Xa %*% Xap # R is residual forming matrix of selected model
247258

248259
A = rbind(
249-
1/lambda * t(Xac) %*% R,
250-
-1/lambda * t(Xac) %*% R,
260+
1/lambdaI * t(Xac) %*% R,
261+
-1/lambdaI * t(Xac) %*% R,
251262
-dz %*% Xap
252263
)
253264
lambda2=lambda
254-
if(length(lambda)>1) lambda2=lambda[active]
265+
255266
b = c(
256267
1 - t(Xac) %*% t(Xap) %*% za,
257268
1 + t(Xac) %*% t(Xap) %*% za,
258-
-lambda2 * dz %*% Xai %*% za)
269+
-lambdaA * dz %*% Xai %*% za)
259270
} else {
260271
A = -dz %*% Xap
261-
lambda2=lambda
262-
if(length(lambda)>1) lambda2=lambda[active]
263-
b = -lambda2 * dz %*% Xai %*% za
272+
b = -lambdaA * dz %*% Xai %*% za
264273
}
265274

266275
return(list(A=A, b=b))

0 commit comments

Comments
 (0)