Skip to content

Commit a4055ce

Browse files
BF: forgot to rename in Rcpp file, passing active and ever_active from R
1 parent 7f1e4a6 commit a4055ce

File tree

2 files changed

+22
-15
lines changed

2 files changed

+22
-15
lines changed

selectiveInference/R/funs.fixed.R

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -341,18 +341,24 @@ InverseLinftyOneRow <- function (Sigma, i, mu, maxiter=50, soln_result=NULL) {
341341

342342
# If soln_result is not Null, it is used as a warm start.
343343
# It should be a list
344-
# with entries "soln" and "Sigma_soln"
344+
# with entries "soln", "gradient", "ever_active", "nactive"
345345

346346
if (is.null(soln_result)) {
347347
soln = rep(0, nrow(Sigma))
348-
Sigma_soln = rep(0, nrow(Sigma))
348+
gradient = rep(0, nrow(Sigma))
349+
ever_active = rep(0, nrow(Sigma))
350+
ever_active[1] = i-1 # 0-based
351+
ever_active = as.integer(ever_active)
352+
nactive = as.integer(1)
349353
}
350354
else {
351355
soln = soln_result$soln
352-
Sigma_soln = soln_result$Sigma_soln
356+
gradient = soln_result$gradient
357+
ever_active = as.integer(soln_result$ever_active)
358+
nactive = as.integer(soln_result$nactive)
353359
}
354360

355-
result = find_one_row_debiasingM(Sigma, i-1, mu, maxiter, soln, Sigma_soln) # C function uses 0-based indexing
361+
result = find_one_row_debiasingM(Sigma, i-1, mu, maxiter, soln, gradient, ever_active, nactive) # C function uses 0-based indexing
356362

357363
# Check feasibility
358364

selectiveInference/src/Rcpp-debias.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,31 +7,30 @@ Rcpp::List find_one_row_debiasingM(Rcpp::NumericMatrix Sigma,
77
double bound,
88
int maxiter,
99
Rcpp::NumericVector theta,
10-
Rcpp::NumericVector Sigma_theta) {
10+
Rcpp::NumericVector gradient,
11+
Rcpp::IntegerVector ever_active,
12+
Rcpp::IntegerVector nactive
13+
) {
1114

1215
int nrow = Sigma.nrow(); // number of features
1316

1417
// Active set
1518

1619
int irow;
17-
Rcpp::IntegerVector nactive(1); // An array so we can easily modify it
18-
Rcpp::IntegerVector ever_active(1);
19-
int *ever_active_p = ever_active.begin();
20-
*ever_active_p = row;
2120

2221
// Extract the diagonal
2322
Rcpp::NumericVector Sigma_diag(nrow);
24-
double *sigma_p = Sigma_diag.begin();
23+
double *sigma_diag_p = Sigma_diag.begin();
2524

2625
for (irow=0; irow<nrow; irow++) {
27-
sigma_p[irow] = Sigma(irow, irow);
26+
sigma_diag_p[irow] = Sigma(irow, irow);
2827
}
2928

3029
// Now call our C function
3130

3231
int iter = find_one_row_((double *) Sigma.begin(),
3332
(double *) Sigma_diag.begin(),
34-
(double *) Sigma_theta.begin(),
33+
(double *) gradient.begin(),
3534
(int *) ever_active.begin(),
3635
(int *) nactive.begin(),
3736
nrow,
@@ -43,14 +42,16 @@ Rcpp::List find_one_row_debiasingM(Rcpp::NumericMatrix Sigma,
4342
// Check whether feasible
4443

4544
int kkt_check = check_KKT(theta.begin(),
46-
Sigma_theta.begin(),
45+
gradient.begin(),
4746
nrow,
4847
row,
4948
bound);
5049

5150
return(Rcpp::List::create(Rcpp::Named("soln") = theta,
52-
Rcpp::Named("Sigma_soln") = Sigma_theta,
51+
Rcpp::Named("gradient") = gradient,
5352
Rcpp::Named("iter") = iter,
54-
Rcpp::Named("kkt_check") = kkt_check));
53+
Rcpp::Named("kkt_check") = kkt_check,
54+
Rcpp::Named("ever_active") = ever_active,
55+
Rcpp::Named("nactive") = nactive));
5556

5657
}

0 commit comments

Comments
 (0)