Skip to content

Commit 86a92fd

Browse files
WIP: writing a solver for wide X matrices
1 parent 41dc63a commit 86a92fd

File tree

5 files changed

+460
-0
lines changed

5 files changed

+460
-0
lines changed

selectiveInference/R/RcppExports.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ solve_QP <- function(Sigma, bound, maxiter, theta, linear_func, gradient, ever_a
55
.Call('_selectiveInference_solve_QP', PACKAGE = 'selectiveInference', Sigma, bound, maxiter, theta, linear_func, gradient, ever_active, nactive, kkt_tol, objective_tol, max_active)
66
}
77

8+
solve_QP_wide <- function(X, bound, maxiter, theta, linear_func, gradient, ever_active, nactive, kkt_tol, objective_tol, max_active) {
9+
.Call('_selectiveInference_solve_QP_wide', PACKAGE = 'selectiveInference', X, bound, maxiter, theta, linear_func, gradient, ever_active, nactive, kkt_tol, objective_tol, max_active)
10+
}
11+
812
update1_ <- function(Q2, w, m, k) {
913
.Call('_selectiveInference_update1_', PACKAGE = 'selectiveInference', Q2, w, m, k)
1014
}

selectiveInference/src/Rcpp-debias.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,75 @@ Rcpp::List solve_QP(Rcpp::NumericMatrix Sigma,
6868
Rcpp::Named("max_active_check") = max_active_check));
6969

7070
}
71+
72+
73+
// [[Rcpp::export]]
74+
Rcpp::List solve_QP_wide(Rcpp::NumericMatrix X,
75+
double bound,
76+
int maxiter,
77+
Rcpp::NumericVector theta,
78+
Rcpp::NumericVector linear_func,
79+
Rcpp::NumericVector gradient,
80+
Rcpp::IntegerVector ever_active,
81+
Rcpp::IntegerVector nactive,
82+
double kkt_tol,
83+
double objective_tol,
84+
int max_active
85+
) {
86+
87+
int nrow = X.nrow(); // number of cases
88+
int ncol = X.ncol(); // number of features
89+
90+
// Active set
91+
92+
int irow, icol;
93+
94+
// Extract the diagonal
95+
Rcpp::NumericVector X_diag(ncol);
96+
double *X_diag_p = X_diag.begin();
97+
98+
for (icol=0; icol<ncol; icol++) {
99+
X_diag_p[irow] = 0;
100+
for (irow=0; irow<nrow; irow++) {
101+
X_diag_p[irow] += X(irow, icol) * X(irow, icol);
102+
}
103+
X_diag_p[irow] = X_diag_p[irow] / nrow;
104+
}
105+
106+
// Now call our C function
107+
108+
int iter = solve_wide((double *) X.begin(),
109+
(double *) linear_func.begin(),
110+
(double *) X_diag.begin(),
111+
(double *) gradient.begin(),
112+
(int *) ever_active.begin(),
113+
(int *) nactive.begin(),
114+
nrow,
115+
ncol,
116+
bound,
117+
(double *) theta.begin(),
118+
maxiter,
119+
kkt_tol,
120+
objective_tol,
121+
max_active);
122+
123+
// Check whether feasible
124+
125+
int kkt_check = check_KKT_wide(theta.begin(), // This is the same function as check_KKT_qp at the moment!!
126+
gradient.begin(),
127+
nrow,
128+
bound,
129+
kkt_tol);
130+
131+
int max_active_check = (*(nactive.begin()) >= max_active);
132+
133+
return(Rcpp::List::create(Rcpp::Named("soln") = theta,
134+
Rcpp::Named("gradient") = gradient,
135+
Rcpp::Named("linear_func") = linear_func,
136+
Rcpp::Named("iter") = iter,
137+
Rcpp::Named("kkt_check") = kkt_check,
138+
Rcpp::Named("ever_active") = ever_active,
139+
Rcpp::Named("nactive") = nactive,
140+
Rcpp::Named("max_active_check") = max_active_check));
141+
142+
}

selectiveInference/src/RcppExports.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,27 @@ BEGIN_RCPP
2626
return rcpp_result_gen;
2727
END_RCPP
2828
}
29+
// solve_QP_wide
30+
Rcpp::List solve_QP_wide(Rcpp::NumericMatrix X, double bound, int maxiter, Rcpp::NumericVector theta, Rcpp::NumericVector linear_func, Rcpp::NumericVector gradient, Rcpp::IntegerVector ever_active, Rcpp::IntegerVector nactive, double kkt_tol, double objective_tol, int max_active);
31+
RcppExport SEXP _selectiveInference_solve_QP_wide(SEXP XSEXP, SEXP boundSEXP, SEXP maxiterSEXP, SEXP thetaSEXP, SEXP linear_funcSEXP, SEXP gradientSEXP, SEXP ever_activeSEXP, SEXP nactiveSEXP, SEXP kkt_tolSEXP, SEXP objective_tolSEXP, SEXP max_activeSEXP) {
32+
BEGIN_RCPP
33+
Rcpp::RObject rcpp_result_gen;
34+
Rcpp::RNGScope rcpp_rngScope_gen;
35+
Rcpp::traits::input_parameter< Rcpp::NumericMatrix >::type X(XSEXP);
36+
Rcpp::traits::input_parameter< double >::type bound(boundSEXP);
37+
Rcpp::traits::input_parameter< int >::type maxiter(maxiterSEXP);
38+
Rcpp::traits::input_parameter< Rcpp::NumericVector >::type theta(thetaSEXP);
39+
Rcpp::traits::input_parameter< Rcpp::NumericVector >::type linear_func(linear_funcSEXP);
40+
Rcpp::traits::input_parameter< Rcpp::NumericVector >::type gradient(gradientSEXP);
41+
Rcpp::traits::input_parameter< Rcpp::IntegerVector >::type ever_active(ever_activeSEXP);
42+
Rcpp::traits::input_parameter< Rcpp::IntegerVector >::type nactive(nactiveSEXP);
43+
Rcpp::traits::input_parameter< double >::type kkt_tol(kkt_tolSEXP);
44+
Rcpp::traits::input_parameter< double >::type objective_tol(objective_tolSEXP);
45+
Rcpp::traits::input_parameter< int >::type max_active(max_activeSEXP);
46+
rcpp_result_gen = Rcpp::wrap(solve_QP_wide(X, bound, maxiter, theta, linear_func, gradient, ever_active, nactive, kkt_tol, objective_tol, max_active));
47+
return rcpp_result_gen;
48+
END_RCPP
49+
}
2950
// update1_
3051
Rcpp::List update1_(Rcpp::NumericMatrix Q2, Rcpp::NumericVector w, int m, int k);
3152
RcppExport SEXP _selectiveInference_update1_(SEXP Q2SEXP, SEXP wSEXP, SEXP mSEXP, SEXP kSEXP) {
@@ -58,6 +79,7 @@ END_RCPP
5879

5980
static const R_CallMethodDef CallEntries[] = {
6081
{"_selectiveInference_solve_QP", (DL_FUNC) &_selectiveInference_solve_QP, 11},
82+
{"_selectiveInference_solve_QP_wide", (DL_FUNC) &_selectiveInference_solve_QP_wide, 11},
6183
{"_selectiveInference_update1_", (DL_FUNC) &_selectiveInference_update1_, 4},
6284
{"_selectiveInference_downdate1_", (DL_FUNC) &_selectiveInference_downdate1_, 5},
6385
{NULL, NULL, 0}

selectiveInference/src/debias.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,27 @@ int check_KKT_qp(double *theta, /* current theta */
2323
double bound, /* Lagrange multipler for \ell_1 */
2424
double tol); /* precision for checking KKT conditions */
2525

26+
int solve_wide(double *X_ptr, /* A design matrix */
27+
double *linear_func_ptr, /* Linear term in objective */
28+
double *X_diag_ptr, /* Diagonal entry of covariance matrix */
29+
double *gradient_ptr, /* Current gradient of quadratic loss */
30+
int *ever_active_ptr, /* Ever active set: 0-based */
31+
int *nactive_ptr, /* Size of ever active set */
32+
int nrow, /* How many rows in X */
33+
int ncol, /* How many rows in X */
34+
double bound, /* feasibility parameter */
35+
double *theta, /* current value */
36+
int maxiter, /* how many iterations */
37+
double kkt_tol, /* precision for checking KKT conditions */
38+
double objective_tol, /* precision for checking relative decrease in objective value */
39+
int max_active); /* Upper limit for size of active set -- otherwise break */
40+
41+
int check_KKT_wide(double *theta, /* current theta */
42+
double *gradient_ptr, /* Current gradient of quadratic loss */
43+
int nrow, /* how many rows in Sigma */
44+
double bound, /* Lagrange multipler for \ell_1 */
45+
double tol); /* precision for checking KKT conditions */
46+
2647
#ifdef __cplusplus
2748
} /* extern "C" */
2849
#endif /* __cplusplus */

0 commit comments

Comments
 (0)