Skip to content

Commit dd30abc

Browse files
WIP: builds now, but segfaulting
1 parent 6238a92 commit dd30abc

File tree

5 files changed

+302
-154
lines changed

5 files changed

+302
-154
lines changed

selectiveInference/R/RcppExports.R

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
1-
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
1+
# This file was generated by Rcpp::compileAttributes
22
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393
33

44
solve_QP <- function(Sigma, bound, maxiter, theta, linear_func, gradient, ever_active, nactive, kkt_tol, objective_tol, max_active) {
5-
.Call('_selectiveInference_solve_QP', PACKAGE = 'selectiveInference', Sigma, bound, maxiter, theta, linear_func, gradient, ever_active, nactive, kkt_tol, objective_tol, max_active)
5+
.Call('selectiveInference_solve_QP', PACKAGE = 'selectiveInference', Sigma, bound, maxiter, theta, linear_func, gradient, ever_active, nactive, kkt_tol, objective_tol, max_active)
66
}
77

8-
solve_QP_wide <- function(X, bound, maxiter, theta, linear_func, gradient, ever_active, nactive, kkt_tol, objective_tol, max_active) {
9-
.Call('_selectiveInference_solve_QP_wide', PACKAGE = 'selectiveInference', X, bound, maxiter, theta, linear_func, gradient, ever_active, nactive, kkt_tol, objective_tol, max_active)
8+
solve_QP_wide <- function(X, bound, maxiter, theta, linear_func, gradient, X_theta, ever_active, nactive, kkt_tol, objective_tol, max_active) {
9+
.Call('selectiveInference_solve_QP_wide', PACKAGE = 'selectiveInference', X, bound, maxiter, theta, linear_func, gradient, X_theta, ever_active, nactive, kkt_tol, objective_tol, max_active)
1010
}
1111

1212
update1_ <- function(Q2, w, m, k) {
13-
.Call('_selectiveInference_update1_', PACKAGE = 'selectiveInference', Q2, w, m, k)
13+
.Call('selectiveInference_update1_', PACKAGE = 'selectiveInference', Q2, w, m, k)
1414
}
1515

1616
downdate1_ <- function(Q1, R, j0, m, n) {
17-
.Call('_selectiveInference_downdate1_', PACKAGE = 'selectiveInference', Q1, R, j0, m, n)
17+
.Call('selectiveInference_downdate1_', PACKAGE = 'selectiveInference', Q1, R, j0, m, n)
1818
}
1919

selectiveInference/src/Rcpp-debias.cpp

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -77,42 +77,49 @@ Rcpp::List solve_QP_wide(Rcpp::NumericMatrix X,
7777
Rcpp::NumericVector theta,
7878
Rcpp::NumericVector linear_func,
7979
Rcpp::NumericVector gradient,
80+
Rcpp::NumericVector X_theta,
8081
Rcpp::IntegerVector ever_active,
8182
Rcpp::IntegerVector nactive,
8283
double kkt_tol,
8384
double objective_tol,
8485
int max_active
8586
) {
8687

87-
int nrow = X.nrow(); // number of cases
88-
int ncol = X.ncol(); // number of features
88+
int ncase = X.nrow(); // number of cases
89+
int nfeature = X.ncol(); // number of features
8990

9091
// Active set
9192

92-
int irow, icol;
93+
int icase, ifeature;
94+
95+
// A vector to keep track of gradient updates
96+
97+
Rcpp::IntegerVector need_update(nfeature);
9398

9499
// Extract the diagonal
95-
Rcpp::NumericVector X_diag(ncol);
100+
Rcpp::NumericVector X_diag(nfeature);
96101
double *X_diag_p = X_diag.begin();
97102

98-
for (icol=0; icol<ncol; icol++) {
99-
X_diag_p[irow] = 0;
100-
for (irow=0; irow<nrow; irow++) {
101-
X_diag_p[irow] += X(irow, icol) * X(irow, icol);
103+
for (icase=0; icase<ncase; icase++) {
104+
X_diag_p[icase] = 0;
105+
for (ifeature=0; ifeature<nfeature; ifeature++) {
106+
X_diag_p[icase] += X(icase, ifeature) * X(icase, ifeature);
102107
}
103-
X_diag_p[irow] = X_diag_p[irow] / nrow;
108+
X_diag_p[icase] = X_diag_p[icase] / ncase;
104109
}
105110

106111
// Now call our C function
107112

108113
int iter = solve_wide((double *) X.begin(),
114+
(double *) X_theta.begin(),
109115
(double *) linear_func.begin(),
110116
(double *) X_diag.begin(),
111117
(double *) gradient.begin(),
118+
(int *) need_update.begin(),
112119
(int *) ever_active.begin(),
113120
(int *) nactive.begin(),
114-
nrow,
115-
ncol,
121+
ncase,
122+
nfeature,
116123
bound,
117124
(double *) theta.begin(),
118125
maxiter,
@@ -122,16 +129,32 @@ Rcpp::List solve_QP_wide(Rcpp::NumericMatrix X,
122129

123130
// Check whether feasible
124131

125-
int kkt_check = check_KKT_wide(theta.begin(), // This is the same function as check_KKT_qp at the moment!!
126-
gradient.begin(),
127-
nrow,
132+
int kkt_check = check_KKT_wide((double *) theta.begin(),
133+
(double *) gradient.begin(),
134+
(double *) X_theta.begin(),
135+
(double *) X.begin(),
136+
(double *) linear_func.begin(),
137+
(int *) need_update.begin(),
138+
nfeature,
139+
ncase,
128140
bound,
129141
kkt_tol);
130142

131143
int max_active_check = (*(nactive.begin()) >= max_active);
132144

145+
// Make sure gradient is updated -- essentially a matrix multiply
146+
147+
update_gradient_wide((double *) gradient.begin(),
148+
(double *) X_theta.begin(),
149+
(double *) X.begin(),
150+
(double *) linear_func.begin(),
151+
(int *) need_update.begin(),
152+
nfeature,
153+
ncase);
154+
133155
return(Rcpp::List::create(Rcpp::Named("soln") = theta,
134156
Rcpp::Named("gradient") = gradient,
157+
Rcpp::Named("X_theta") = X_theta,
135158
Rcpp::Named("linear_func") = linear_func,
136159
Rcpp::Named("iter") = iter,
137160
Rcpp::Named("kkt_check") = kkt_check,
Lines changed: 23 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Generated by using Rcpp::compileAttributes() -> do not edit by hand
1+
// This file was generated by Rcpp::compileAttributes
22
// Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393
33

44
#include <Rcpp.h>
@@ -7,10 +7,10 @@ using namespace Rcpp;
77

88
// solve_QP
99
Rcpp::List solve_QP(Rcpp::NumericMatrix Sigma, double bound, int maxiter, Rcpp::NumericVector theta, Rcpp::NumericVector linear_func, Rcpp::NumericVector gradient, Rcpp::IntegerVector ever_active, Rcpp::IntegerVector nactive, double kkt_tol, double objective_tol, int max_active);
10-
RcppExport SEXP _selectiveInference_solve_QP(SEXP SigmaSEXP, SEXP boundSEXP, SEXP maxiterSEXP, SEXP thetaSEXP, SEXP linear_funcSEXP, SEXP gradientSEXP, SEXP ever_activeSEXP, SEXP nactiveSEXP, SEXP kkt_tolSEXP, SEXP objective_tolSEXP, SEXP max_activeSEXP) {
10+
RcppExport SEXP selectiveInference_solve_QP(SEXP SigmaSEXP, SEXP boundSEXP, SEXP maxiterSEXP, SEXP thetaSEXP, SEXP linear_funcSEXP, SEXP gradientSEXP, SEXP ever_activeSEXP, SEXP nactiveSEXP, SEXP kkt_tolSEXP, SEXP objective_tolSEXP, SEXP max_activeSEXP) {
1111
BEGIN_RCPP
12-
Rcpp::RObject rcpp_result_gen;
13-
Rcpp::RNGScope rcpp_rngScope_gen;
12+
Rcpp::RObject __result;
13+
Rcpp::RNGScope __rngScope;
1414
Rcpp::traits::input_parameter< Rcpp::NumericMatrix >::type Sigma(SigmaSEXP);
1515
Rcpp::traits::input_parameter< double >::type bound(boundSEXP);
1616
Rcpp::traits::input_parameter< int >::type maxiter(maxiterSEXP);
@@ -22,70 +22,58 @@ BEGIN_RCPP
2222
Rcpp::traits::input_parameter< double >::type kkt_tol(kkt_tolSEXP);
2323
Rcpp::traits::input_parameter< double >::type objective_tol(objective_tolSEXP);
2424
Rcpp::traits::input_parameter< int >::type max_active(max_activeSEXP);
25-
rcpp_result_gen = Rcpp::wrap(solve_QP(Sigma, bound, maxiter, theta, linear_func, gradient, ever_active, nactive, kkt_tol, objective_tol, max_active));
26-
return rcpp_result_gen;
25+
__result = Rcpp::wrap(solve_QP(Sigma, bound, maxiter, theta, linear_func, gradient, ever_active, nactive, kkt_tol, objective_tol, max_active));
26+
return __result;
2727
END_RCPP
2828
}
2929
// solve_QP_wide
30-
Rcpp::List solve_QP_wide(Rcpp::NumericMatrix X, double bound, int maxiter, Rcpp::NumericVector theta, Rcpp::NumericVector linear_func, Rcpp::NumericVector gradient, Rcpp::IntegerVector ever_active, Rcpp::IntegerVector nactive, double kkt_tol, double objective_tol, int max_active);
31-
RcppExport SEXP _selectiveInference_solve_QP_wide(SEXP XSEXP, SEXP boundSEXP, SEXP maxiterSEXP, SEXP thetaSEXP, SEXP linear_funcSEXP, SEXP gradientSEXP, SEXP ever_activeSEXP, SEXP nactiveSEXP, SEXP kkt_tolSEXP, SEXP objective_tolSEXP, SEXP max_activeSEXP) {
30+
Rcpp::List solve_QP_wide(Rcpp::NumericMatrix X, double bound, int maxiter, Rcpp::NumericVector theta, Rcpp::NumericVector linear_func, Rcpp::NumericVector gradient, Rcpp::NumericVector X_theta, Rcpp::IntegerVector ever_active, Rcpp::IntegerVector nactive, double kkt_tol, double objective_tol, int max_active);
31+
RcppExport SEXP selectiveInference_solve_QP_wide(SEXP XSEXP, SEXP boundSEXP, SEXP maxiterSEXP, SEXP thetaSEXP, SEXP linear_funcSEXP, SEXP gradientSEXP, SEXP X_thetaSEXP, SEXP ever_activeSEXP, SEXP nactiveSEXP, SEXP kkt_tolSEXP, SEXP objective_tolSEXP, SEXP max_activeSEXP) {
3232
BEGIN_RCPP
33-
Rcpp::RObject rcpp_result_gen;
34-
Rcpp::RNGScope rcpp_rngScope_gen;
33+
Rcpp::RObject __result;
34+
Rcpp::RNGScope __rngScope;
3535
Rcpp::traits::input_parameter< Rcpp::NumericMatrix >::type X(XSEXP);
3636
Rcpp::traits::input_parameter< double >::type bound(boundSEXP);
3737
Rcpp::traits::input_parameter< int >::type maxiter(maxiterSEXP);
3838
Rcpp::traits::input_parameter< Rcpp::NumericVector >::type theta(thetaSEXP);
3939
Rcpp::traits::input_parameter< Rcpp::NumericVector >::type linear_func(linear_funcSEXP);
4040
Rcpp::traits::input_parameter< Rcpp::NumericVector >::type gradient(gradientSEXP);
41+
Rcpp::traits::input_parameter< Rcpp::NumericVector >::type X_theta(X_thetaSEXP);
4142
Rcpp::traits::input_parameter< Rcpp::IntegerVector >::type ever_active(ever_activeSEXP);
4243
Rcpp::traits::input_parameter< Rcpp::IntegerVector >::type nactive(nactiveSEXP);
4344
Rcpp::traits::input_parameter< double >::type kkt_tol(kkt_tolSEXP);
4445
Rcpp::traits::input_parameter< double >::type objective_tol(objective_tolSEXP);
4546
Rcpp::traits::input_parameter< int >::type max_active(max_activeSEXP);
46-
rcpp_result_gen = Rcpp::wrap(solve_QP_wide(X, bound, maxiter, theta, linear_func, gradient, ever_active, nactive, kkt_tol, objective_tol, max_active));
47-
return rcpp_result_gen;
47+
__result = Rcpp::wrap(solve_QP_wide(X, bound, maxiter, theta, linear_func, gradient, X_theta, ever_active, nactive, kkt_tol, objective_tol, max_active));
48+
return __result;
4849
END_RCPP
4950
}
5051
// update1_
5152
Rcpp::List update1_(Rcpp::NumericMatrix Q2, Rcpp::NumericVector w, int m, int k);
52-
RcppExport SEXP _selectiveInference_update1_(SEXP Q2SEXP, SEXP wSEXP, SEXP mSEXP, SEXP kSEXP) {
53+
RcppExport SEXP selectiveInference_update1_(SEXP Q2SEXP, SEXP wSEXP, SEXP mSEXP, SEXP kSEXP) {
5354
BEGIN_RCPP
54-
Rcpp::RObject rcpp_result_gen;
55-
Rcpp::RNGScope rcpp_rngScope_gen;
55+
Rcpp::RObject __result;
56+
Rcpp::RNGScope __rngScope;
5657
Rcpp::traits::input_parameter< Rcpp::NumericMatrix >::type Q2(Q2SEXP);
5758
Rcpp::traits::input_parameter< Rcpp::NumericVector >::type w(wSEXP);
5859
Rcpp::traits::input_parameter< int >::type m(mSEXP);
5960
Rcpp::traits::input_parameter< int >::type k(kSEXP);
60-
rcpp_result_gen = Rcpp::wrap(update1_(Q2, w, m, k));
61-
return rcpp_result_gen;
61+
__result = Rcpp::wrap(update1_(Q2, w, m, k));
62+
return __result;
6263
END_RCPP
6364
}
6465
// downdate1_
6566
Rcpp::List downdate1_(Rcpp::NumericMatrix Q1, Rcpp::NumericMatrix R, int j0, int m, int n);
66-
RcppExport SEXP _selectiveInference_downdate1_(SEXP Q1SEXP, SEXP RSEXP, SEXP j0SEXP, SEXP mSEXP, SEXP nSEXP) {
67+
RcppExport SEXP selectiveInference_downdate1_(SEXP Q1SEXP, SEXP RSEXP, SEXP j0SEXP, SEXP mSEXP, SEXP nSEXP) {
6768
BEGIN_RCPP
68-
Rcpp::RObject rcpp_result_gen;
69-
Rcpp::RNGScope rcpp_rngScope_gen;
69+
Rcpp::RObject __result;
70+
Rcpp::RNGScope __rngScope;
7071
Rcpp::traits::input_parameter< Rcpp::NumericMatrix >::type Q1(Q1SEXP);
7172
Rcpp::traits::input_parameter< Rcpp::NumericMatrix >::type R(RSEXP);
7273
Rcpp::traits::input_parameter< int >::type j0(j0SEXP);
7374
Rcpp::traits::input_parameter< int >::type m(mSEXP);
7475
Rcpp::traits::input_parameter< int >::type n(nSEXP);
75-
rcpp_result_gen = Rcpp::wrap(downdate1_(Q1, R, j0, m, n));
76-
return rcpp_result_gen;
76+
__result = Rcpp::wrap(downdate1_(Q1, R, j0, m, n));
77+
return __result;
7778
END_RCPP
7879
}
79-
80-
static const R_CallMethodDef CallEntries[] = {
81-
{"_selectiveInference_solve_QP", (DL_FUNC) &_selectiveInference_solve_QP, 11},
82-
{"_selectiveInference_solve_QP_wide", (DL_FUNC) &_selectiveInference_solve_QP_wide, 11},
83-
{"_selectiveInference_update1_", (DL_FUNC) &_selectiveInference_update1_, 4},
84-
{"_selectiveInference_downdate1_", (DL_FUNC) &_selectiveInference_downdate1_, 5},
85-
{NULL, NULL, 0}
86-
};
87-
88-
RcppExport void R_init_selectiveInference(DllInfo *dll) {
89-
R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
90-
R_useDynamicSymbols(dll, FALSE);
91-
}

selectiveInference/src/debias.h

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,26 +24,41 @@ int check_KKT_qp(double *theta, /* current theta */
2424
double tol); /* precision for checking KKT conditions */
2525

2626
int solve_wide(double *X_ptr, /* A design matrix */
27+
double *X_theta_ptr, /* Fitted values */
2728
double *linear_func_ptr, /* Linear term in objective */
2829
double *X_diag_ptr, /* Diagonal entry of covariance matrix */
29-
double *gradient_ptr, /* Current gradient of quadratic loss */
30-
int *ever_active_ptr, /* Ever active set: 0-based */
30+
double *gradient_ptr, /* X times theta */
31+
int *need_update_ptr, /* Keeps track of updated gradient coords */
32+
int *ever_active_ptr, /* Ever active set: 1-based */
3133
int *nactive_ptr, /* Size of ever active set */
32-
int nrow, /* How many rows in X */
33-
int ncol, /* How many rows in X */
34+
int ncase, /* How many rows in X */
35+
int nfeature, /* How many columns in X */
3436
double bound, /* feasibility parameter */
35-
double *theta, /* current value */
36-
int maxiter, /* how many iterations */
37+
double *theta_ptr, /* current value */
38+
int maxiter, /* max number of iterations */
3739
double kkt_tol, /* precision for checking KKT conditions */
3840
double objective_tol, /* precision for checking relative decrease in objective value */
39-
int max_active); /* Upper limit for size of active set -- otherwise break */
41+
int max_active); /* Upper limit for size of active set -- otherwise break */
4042

41-
int check_KKT_wide(double *theta, /* current theta */
42-
double *gradient_ptr, /* Current gradient of quadratic loss */
43-
int nrow, /* how many rows in Sigma */
44-
double bound, /* Lagrange multipler for \ell_1 */
45-
double tol); /* precision for checking KKT conditions */
43+
int check_KKT_wide(double *theta_ptr, /* current theta */
44+
double *gradient_ptr, /* X^TX/n times theta */
45+
double *X_theta_ptr, /* Current fitted values */
46+
double *X_ptr, /* A design matrix */
47+
double *linear_func_ptr, /* Linear term in objective */
48+
int *need_update_ptr, /* Which coordinates need to be updated? */
49+
int nfeature, /* how many columns in X */
50+
int ncase, /* how many rows in X */
51+
double bound, /* Lagrange multipler for \ell_1 */
52+
double tol); /* precision for checking KKT conditions */
4653

54+
void update_gradient_wide(double *gradient_ptr, /* X^TX/n times theta */
55+
double *X_theta_ptr, /* Current fitted values */
56+
double *X_ptr, /* A design matrix */
57+
double *linear_func_ptr, /* Linear term in objective */
58+
int *need_update_ptr, /* Which coordinates need to be updated? */
59+
int nfeature, /* how many columns in X */
60+
int ncase); /* how many rows in X */
61+
4762
#ifdef __cplusplus
4863
} /* extern "C" */
4964
#endif /* __cplusplus */

0 commit comments

Comments
 (0)