Skip to content

Commit 726b917

Browse files
authored
Merge pull request #31 from jonathan-taylor/wide_matrix_QP
Wide matrix qp
2 parents 48547ea + 76f3e7e commit 726b917

File tree

9 files changed

+908
-146
lines changed

9 files changed

+908
-146
lines changed

selectiveInference/R/RcppExports.R

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
1-
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
1+
# This file was generated by Rcpp::compileAttributes
22
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393
33

44
solve_QP <- function(Sigma, bound, maxiter, theta, linear_func, gradient, ever_active, nactive, kkt_tol, objective_tol, max_active) {
5-
.Call('_selectiveInference_solve_QP', PACKAGE = 'selectiveInference', Sigma, bound, maxiter, theta, linear_func, gradient, ever_active, nactive, kkt_tol, objective_tol, max_active)
5+
.Call('selectiveInference_solve_QP', PACKAGE = 'selectiveInference', Sigma, bound, maxiter, theta, linear_func, gradient, ever_active, nactive, kkt_tol, objective_tol, max_active)
6+
}
7+
8+
solve_QP_wide <- function(X, bound, maxiter, theta, linear_func, gradient, X_theta, ever_active, nactive, kkt_tol, objective_tol, max_active) {
9+
.Call('selectiveInference_solve_QP_wide', PACKAGE = 'selectiveInference', X, bound, maxiter, theta, linear_func, gradient, X_theta, ever_active, nactive, kkt_tol, objective_tol, max_active)
610
}
711

812
update1_ <- function(Q2, w, m, k) {
9-
.Call('_selectiveInference_update1_', PACKAGE = 'selectiveInference', Q2, w, m, k)
13+
.Call('selectiveInference_update1_', PACKAGE = 'selectiveInference', Q2, w, m, k)
1014
}
1115

1216
downdate1_ <- function(Q1, R, j0, m, n) {
13-
.Call('_selectiveInference_downdate1_', PACKAGE = 'selectiveInference', Q1, R, j0, m, n)
17+
.Call('selectiveInference_downdate1_', PACKAGE = 'selectiveInference', Q1, R, j0, m, n)
1418
}
1519

selectiveInference/R/funs.fixed.R

Lines changed: 57 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -154,20 +154,24 @@ fixedLassoInf <- function(x, y, beta,
154154

155155
# Reorder so that active set S is first
156156
Xordered = Xint[,c(S,notS,recursive=T)]
157+
hsigmaS = 1/n*(t(XS)%*%XS) # hsigma[S,S]
158+
hsigmaSinv = solve(hsigmaS) # pinv(hsigmaS)
157159

158-
hsigma <- 1/n*(t(Xordered)%*%Xordered)
159-
hsigmaS <- 1/n*(t(XS)%*%XS) # hsigma[S,S]
160-
hsigmaSinv <- solve(hsigmaS) # pinv(hsigmaS)
160+
FS = rbind(diag(length(S)),matrix(0,pp-length(S),length(S)))
161+
GS = cbind(diag(length(S)),matrix(0,length(S),pp-length(S)))
161162

162-
# Approximate inverse covariance matrix for when (n < p) from lasso_Inference.R
163+
is_wide = n < (2 * p) # somewhat arbitrary decision -- it is really for when we don't want to form with pxp matrices
163164

164-
htheta = debiasingMatrix(hsigma, n, 1:length(S), verbose=FALSE, max_try=linesearch.try, warn_kkt=TRUE)
165+
# Approximate inverse covariance matrix for when (n < p) from lasso_Inference.R
166+
if (!is_wide) {
167+
hsigma = 1/n*(t(Xordered)%*%Xordered)
168+
htheta = debiasingMatrix(hsigma, is_wide, n, 1:length(S), verbose=FALSE, max_try=linesearch.try, warn_kkt=TRUE)
169+
ithetasigma = (GS-(htheta%*%hsigma))
170+
} else {
171+
htheta = debiasingMatrix(Xordered, is_wide, n, 1:length(S), verbose=FALSE, max_try=linesearch.try, warn_kkt=TRUE)
172+
ithetasigma = (GS-((htheta%*%t(Xordered)) %*% Xordered)/n)
173+
}
165174

166-
FS = rbind(diag(length(S)),matrix(0,pp-length(S),length(S)))
167-
GS = cbind(diag(length(S)),matrix(0,length(S),pp-length(S)))
168-
ithetasigma = (GS-(htheta%*%hsigma))
169-
# ithetasigma = (diag(pp) - (htheta%*%hsigma))
170-
171175
M <- (((htheta%*%t(Xordered))+ithetasigma%*%FS%*%hsigmaSinv%*%t(XS))/n)
172176
# vector which is offset for testing debiased beta's
173177
null_value <- (((ithetasigma%*%FS%*%hsigmaSinv)%*%sign(hbetaS))*lambda/n)
@@ -264,10 +268,11 @@ fixedLassoPoly =
264268
## Approximates inverse covariance matrix theta
265269
## using coordinate descent
266270

267-
debiasingMatrix = function(Sigma,
271+
debiasingMatrix = function(Xinfo, # could be X or t(X) %*% X / n depending on is_wide
272+
is_wide,
268273
nsample,
269274
rows,
270-
verbose=FALSE,
275+
verbose=FALSE,
271276
mu=NULL, # starting value of mu
272277
linesearch=TRUE, # do a linesearch?
273278
scaling_factor=1.5, # multiplicative factor for linesearch
@@ -284,7 +289,7 @@ debiasingMatrix = function(Sigma,
284289
max_active = max(50, 0.3 * nsample)
285290
}
286291

287-
p = nrow(Sigma);
292+
p = ncol(Xinfo);
288293
M = matrix(0, length(rows), p);
289294

290295
if (is.null(mu)) {
@@ -295,19 +300,19 @@ debiasingMatrix = function(Sigma,
295300
xp = round(p/10);
296301
idx = 1;
297302
for (row in rows) {
298-
299303
if ((idx %% xp)==0){
300304
xperc = xperc+10;
301305
if (verbose) {
302306
print(paste(xperc,"% done",sep="")); }
303307
}
304308

305-
output = debiasingRow(Sigma,
309+
output = debiasingRow(Xinfo, # could be X or t(X) %*% X / n depending on is_wide
310+
is_wide,
306311
row,
307312
mu,
308313
linesearch=linesearch,
309314
scaling_factor=scaling_factor,
310-
max_active=max_active,
315+
max_active=max_active,
311316
max_try=max_try,
312317
warn_kkt=FALSE,
313318
max_iter=max_iter,
@@ -329,31 +334,32 @@ debiasingMatrix = function(Sigma,
329334
return(M)
330335
}
331336

332-
# Find one row of the debiasing matrix
337+
# Find one row of the debiasing matrix -- assuming X^TX/n is not too large -- i.e. X is tall
333338

334-
debiasingRow = function (Sigma,
339+
debiasingRow = function (Xinfo, # could be X or t(X) %*% X / n depending on is_wide
340+
is_wide,
335341
row,
336342
mu,
337-
linesearch=TRUE, # do a linesearch?
343+
linesearch=TRUE, # do a linesearch?
338344
scaling_factor=1.2, # multiplicative factor for linesearch
339-
max_active=NULL, # how big can active set get?
345+
max_active=NULL, # how big can active set get?
340346
max_try=10, # how many steps in linesearch?
341347
warn_kkt=FALSE, # warn if KKT does not seem to be satisfied?
342348
max_iter=100, # how many iterations for each optimization problem
343349
kkt_tol=1.e-4, # tolerance for the KKT conditions
344350
objective_tol=1.e-8 # tolerance for relative decrease in objective
345351
) {
346352

347-
p = nrow(Sigma)
353+
p = ncol(Xinfo)
348354

349355
if (is.null(max_active)) {
350-
max_active = nrow(Sigma)
356+
max_active = min(nrow(Xinfo), ncol(Xinfo))
351357
}
352358

353359
# Initialize variables
354360

355361
soln = rep(0, p)
356-
362+
Xsoln = rep(0, n)
357363
ever_active = rep(0, p)
358364
ever_active[1] = row # 1-based
359365
ever_active = as.integer(ever_active)
@@ -371,17 +377,33 @@ debiasingRow = function (Sigma,
371377

372378
while (counter_idx < max_try) {
373379

374-
result = solve_QP(Sigma,
375-
mu,
376-
max_iter,
377-
soln,
378-
linear_func,
379-
gradient,
380-
ever_active,
381-
nactive,
382-
kkt_tol,
383-
objective_tol,
384-
max_active)
380+
if (!is_wide) {
381+
result = solve_QP(Xinfo, # this is non-neg-def matrix
382+
mu,
383+
max_iter,
384+
soln,
385+
linear_func,
386+
gradient,
387+
ever_active,
388+
nactive,
389+
kkt_tol,
390+
objective_tol,
391+
max_active)
392+
} else {
393+
result = solve_QP_wide(Xinfo, # this is a design matrix
394+
mu,
395+
max_iter,
396+
soln,
397+
linear_func,
398+
gradient,
399+
Xsoln,
400+
ever_active,
401+
nactive,
402+
kkt_tol,
403+
objective_tol,
404+
max_active)
405+
406+
}
385407

386408
iter = result$iter
387409

@@ -439,6 +461,7 @@ debiasingRow = function (Sigma,
439461

440462
}
441463

464+
442465
##############################
443466

444467
print.fixedLassoInf <- function(x, tailarea=TRUE, ...) {

selectiveInference/man/debiasingMatrix.Rd

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ Newton step from some consistent estimator (such as the LASSO)
1111
to find a debiased solution.
1212
}
1313
\usage{
14-
debiasingMatrix(Sigma,
14+
debiasingMatrix(Xinfo,
15+
is_wide,
1516
nsample,
1617
rows,
1718
verbose=FALSE,
@@ -26,8 +27,14 @@ debiasingMatrix(Sigma,
2627
objective_tol=1.e-8)
2728
}
2829
\arguments{
29-
\item{Sigma}{
30-
A symmetric non-negative definite matrix, often a cross-covariance matrix.
30+
\item{Xinfo}{
31+
Either a non-negative definite matrix S=t(X) %*% X / n or X itself. If
32+
is_wide is TRUE, then Xinfo should be X, otherwise it should be S.
33+
}
34+
\item{is_wide}{
35+
Are we solving for rows of the debiasing matrix assuming it is
36+
a wide matrix so that Xinfo=X and the non-negative definite
37+
matrix of interest is t(X) %*% X / nrow(X).
3138
}
3239
\item{nsample}{
3340
Number of samples used in forming the cross-covariance matrix.
@@ -101,8 +108,9 @@ set.seed(10)
101108
n = 50
102109
p = 100
103110
X = matrix(rnorm(n * p), n, p)
104-
S = t(X) \%*\% X / n
105-
M = debiasingMatrix(S, n, c(1,3,5))
106-
111+
S = t(X) %*% X / n
112+
M = debiasingMatrix(S, FALSE, n, c(1,3,5))
113+
M2 = debiasingMatrix(X, TRUE, n, c(1,3,5))
114+
max(M - M2)
107115
}
108116

selectiveInference/src/Makevars

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ PKG_CFLAGS= -I.
22
PKG_CPPFLAGS= -I.
33
PKG_LIBS=-L.
44

5-
$(SHLIB): Rcpp Rcpp-matrixcomps.o Rcpp-debias.o RcppExports.o quadratic_program.o
5+
$(SHLIB): Rcpp Rcpp-matrixcomps.o Rcpp-debias.o RcppExports.o quadratic_program.o quadratic_program_wide.o
66

77
clean:
88
rm -f *o

selectiveInference/src/Rcpp-debias.cpp

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include <Rcpp.h> // need to include the main Rcpp header file
2-
#include <debias.h> // where find_one_row_void is defined
2+
#include <debias.h> // where solve_QP, solve_QP_wide are defined
33

44
// Below, the gradient should be equal to Sigma * theta + linear_func!!
55
// No check is done on this.
@@ -68,3 +68,98 @@ Rcpp::List solve_QP(Rcpp::NumericMatrix Sigma,
6868
Rcpp::Named("max_active_check") = max_active_check));
6969

7070
}
71+
72+
73+
// [[Rcpp::export]]
74+
Rcpp::List solve_QP_wide(Rcpp::NumericMatrix X,
75+
double bound,
76+
int maxiter,
77+
Rcpp::NumericVector theta,
78+
Rcpp::NumericVector linear_func,
79+
Rcpp::NumericVector gradient,
80+
Rcpp::NumericVector X_theta,
81+
Rcpp::IntegerVector ever_active,
82+
Rcpp::IntegerVector nactive,
83+
double kkt_tol,
84+
double objective_tol,
85+
int max_active
86+
) {
87+
88+
int ncase = X.nrow(); // number of cases
89+
int nfeature = X.ncol(); // number of features
90+
91+
// Active set
92+
93+
int icase, ifeature;
94+
95+
// A vector to keep track of gradient updates
96+
97+
Rcpp::IntegerVector need_update(nfeature);
98+
99+
// Extract the diagonal
100+
Rcpp::NumericVector nndef_diag(nfeature);
101+
double *nndef_diag_p = nndef_diag.begin();
102+
103+
for (ifeature=0; ifeature<nfeature; ifeature++) {
104+
nndef_diag_p[ifeature] = 0;
105+
for (icase=0; icase<ncase; icase++) {
106+
nndef_diag_p[ifeature] += X(icase, ifeature) * X(icase, ifeature);
107+
}
108+
nndef_diag_p[ifeature] = nndef_diag_p[ifeature] / ncase;
109+
}
110+
111+
// Now call our C function
112+
113+
int iter = solve_wide((double *) X.begin(),
114+
(double *) X_theta.begin(),
115+
(double *) linear_func.begin(),
116+
(double *) nndef_diag.begin(),
117+
(double *) gradient.begin(),
118+
(int *) need_update.begin(),
119+
(int *) ever_active.begin(),
120+
(int *) nactive.begin(),
121+
ncase,
122+
nfeature,
123+
bound,
124+
(double *) theta.begin(),
125+
maxiter,
126+
kkt_tol,
127+
objective_tol,
128+
max_active);
129+
130+
// Check whether feasible
131+
132+
int kkt_check = check_KKT_wide((double *) theta.begin(),
133+
(double *) gradient.begin(),
134+
(double *) X_theta.begin(),
135+
(double *) X.begin(),
136+
(double *) linear_func.begin(),
137+
(int *) need_update.begin(),
138+
nfeature,
139+
ncase,
140+
bound,
141+
kkt_tol);
142+
143+
int max_active_check = (*(nactive.begin()) >= max_active);
144+
145+
// Make sure gradient is updated -- essentially a matrix multiply
146+
147+
update_gradient_wide((double *) gradient.begin(),
148+
(double *) X_theta.begin(),
149+
(double *) X.begin(),
150+
(double *) linear_func.begin(),
151+
(int *) need_update.begin(),
152+
nfeature,
153+
ncase);
154+
155+
return(Rcpp::List::create(Rcpp::Named("soln") = theta,
156+
Rcpp::Named("gradient") = gradient,
157+
Rcpp::Named("X_theta") = X_theta,
158+
Rcpp::Named("linear_func") = linear_func,
159+
Rcpp::Named("iter") = iter,
160+
Rcpp::Named("kkt_check") = kkt_check,
161+
Rcpp::Named("ever_active") = ever_active,
162+
Rcpp::Named("nactive") = nactive,
163+
Rcpp::Named("max_active_check") = max_active_check));
164+
165+
}

0 commit comments

Comments
 (0)