Skip to content

Commit 516bb19

Browse files
committed
improve multcomp_pmt
1 parent 92c405f commit 516bb19

File tree

7 files changed

+65
-82
lines changed

7 files changed

+65
-82
lines changed

R/CDF.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ CDF <- R6Class(
6363
)
6464
},
6565

66-
.compile = function() NULL,
6766
.calculate_statistic = function() NULL,
6867
.calculate_side = function() NULL,
6968
.calculate_p = function() NULL,

R/MultipleComparison.R

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
#' @aliases class.multcomp
66
#'
77
#' @importFrom R6 R6Class
8-
#' @importFrom compiler cmpfun
98
#' @importFrom graphics par layout mtext hist.default abline
109

1110

@@ -27,14 +26,8 @@ MultipleComparison <- R6Class(
2726
)
2827
},
2928

30-
.compile = function() {
31-
private$.statistic_func <- cmpfun(private$.statistic_func)
32-
},
33-
3429
.calculate_statistic = function() {
3530
private$.statistic <- multcomp_pmt(
36-
private$.group_ij$i,
37-
private$.group_ij$j,
3831
private$.data,
3932
attr(private$.data, "group"),
4033
private$.statistic_func,

R/RcppExports.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ ksample_pmt <- function(data, group, statistic_func, n_permu, progress) {
99
.Call(`_LearnNonparam_ksample_pmt`, data, group, statistic_func, n_permu, progress)
1010
}
1111

12-
multcomp_pmt <- function(group_i, group_j, data, group, statistic_func, n_permu, progress) {
13-
.Call(`_LearnNonparam_multcomp_pmt`, group_i, group_j, data, group, statistic_func, n_permu, progress)
12+
multcomp_pmt <- function(data, group, statistic_func, n_permu, progress) {
13+
.Call(`_LearnNonparam_multcomp_pmt`, data, group, statistic_func, n_permu, progress)
1414
}
1515

1616
paired_pmt <- function(x, y, statistic_func, n_permu, progress) {

R/Studentized.R

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -59,33 +59,35 @@ Studentized <- R6Class(
5959
.name = "Multiple Comparison Based on Studentized Statistic",
6060

6161
.define = function() {
62-
inv_lengths <- 1 / tabulate(attr(private$.data, "group"))
63-
sum_inv_lengths <- outer(inv_lengths, inv_lengths, `+`)
62+
private$.statistic_func <- function(data, group) {
63+
inv_lengths <- 1 / tabulate(group)
64+
sum_inv_lengths <- outer(inv_lengths, inv_lengths, `+`)
6465

65-
if (private$.scoring == "none") {
66-
N <- length(private$.data)
67-
k <- attr(private$.data, "group")[N]
66+
if (private$.scoring == "none") {
67+
N <- length(data)
68+
k <- group[N]
6869

69-
private$.statistic_func <- function(data, group) {
70-
means <- rowsum.default(data, group) * inv_lengths
71-
mse <- sum((data - means[group])^2) / (N - k)
70+
function(data, group) {
71+
means <- rowsum.default(data, group) * inv_lengths
72+
mse <- sum((data - means[group])^2) / (N - k)
7273

73-
function(i, j) {
74-
(means[i] - means[j]) / sqrt(
75-
mse * sum_inv_lengths[i, j]
76-
)
74+
function(i, j) {
75+
(means[i] - means[j]) / sqrt(
76+
mse * sum_inv_lengths[i, j]
77+
)
78+
}
7779
}
78-
}
79-
} else {
80-
var <- var(private$.data)
80+
} else {
81+
var <- var(data)
8182

82-
private$.statistic_func <- function(data, group) {
83-
means <- rowsum.default(data, group) * inv_lengths
83+
function(data, group) {
84+
means <- rowsum.default(data, group) * inv_lengths
8485

85-
function(i, j) {
86-
(means[i] - means[j]) / sqrt(
87-
var * sum_inv_lengths[i, j]
88-
)
86+
function(i, j) {
87+
(means[i] - means[j]) / sqrt(
88+
var * sum_inv_lengths[i, j]
89+
)
90+
}
8991
}
9092
}
9193
}

inst/include/pmt/impl_multcomp_pmt.hpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,35 @@
11
template <bool progress, typename T>
22
RObject impl_multcomp_pmt(
3-
const IntegerVector group_i,
4-
const IntegerVector group_j,
53
const NumericVector data,
64
IntegerVector group,
75
const T& statistic_func,
86
const double n_permu)
97
{
108
Stat<progress> statistic_container;
119

12-
R_xlen_t K = group_i.size();
10+
int k = *(group.end() - 1);
1311

14-
auto multcomp_update = [&statistic_container, &statistic_func, group_i, group_j, data, group, K]() {
15-
auto statistic_closure = statistic_func(data, group);
12+
auto multcomp_update = [&statistic_container, statistic_closure = statistic_func(data, group), data, group, k]() {
13+
auto pairwise_closure = statistic_closure(data, group);
1614

17-
for (R_xlen_t k = 0; k < K - 1; k++) {
18-
statistic_container << statistic_closure(group_i[k], group_j[k]);
15+
for (int i = 1; i < k - 1; i++) {
16+
for (int j = i + 1; j <= k; j++)
17+
statistic_container << pairwise_closure(i, j);
1918
}
2019

21-
return statistic_container << statistic_closure(group_i[K - 1], group_j[K - 1]);
20+
return statistic_container << pairwise_closure(k - 1, k);
2221
};
2322

2423
if (std::isnan(n_permu)) {
25-
statistic_container.init(multcomp_update, K);
24+
statistic_container.init(multcomp_update, C(k, 2));
2625
} else if (n_permu == 0) {
27-
statistic_container.init(multcomp_update, K, n_permutation(group));
26+
statistic_container.init(multcomp_update, C(k, 2), n_permutation(group));
2827

2928
do {
3029
multcomp_update();
3130
} while (next_permutation(group));
3231
} else {
33-
statistic_container.init(multcomp_update, K, n_permu);
32+
statistic_container.init(multcomp_update, C(k, 2), n_permu);
3433

3534
do {
3635
random_shuffle(group);

src/RcppExports.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,19 +41,17 @@ BEGIN_RCPP
4141
END_RCPP
4242
}
4343
// multcomp_pmt
44-
SEXP multcomp_pmt(const SEXP group_i, const SEXP group_j, const SEXP data, const SEXP group, const SEXP statistic_func, const double n_permu, const bool progress);
45-
RcppExport SEXP _LearnNonparam_multcomp_pmt(SEXP group_iSEXP, SEXP group_jSEXP, SEXP dataSEXP, SEXP groupSEXP, SEXP statistic_funcSEXP, SEXP n_permuSEXP, SEXP progressSEXP) {
44+
SEXP multcomp_pmt(const SEXP data, const SEXP group, const SEXP statistic_func, const double n_permu, const bool progress);
45+
RcppExport SEXP _LearnNonparam_multcomp_pmt(SEXP dataSEXP, SEXP groupSEXP, SEXP statistic_funcSEXP, SEXP n_permuSEXP, SEXP progressSEXP) {
4646
BEGIN_RCPP
4747
Rcpp::RObject rcpp_result_gen;
4848
Rcpp::RNGScope rcpp_rngScope_gen;
49-
Rcpp::traits::input_parameter< const SEXP >::type group_i(group_iSEXP);
50-
Rcpp::traits::input_parameter< const SEXP >::type group_j(group_jSEXP);
5149
Rcpp::traits::input_parameter< const SEXP >::type data(dataSEXP);
5250
Rcpp::traits::input_parameter< const SEXP >::type group(groupSEXP);
5351
Rcpp::traits::input_parameter< const SEXP >::type statistic_func(statistic_funcSEXP);
5452
Rcpp::traits::input_parameter< const double >::type n_permu(n_permuSEXP);
5553
Rcpp::traits::input_parameter< const bool >::type progress(progressSEXP);
56-
rcpp_result_gen = Rcpp::wrap(multcomp_pmt(group_i, group_j, data, group, statistic_func, n_permu, progress));
54+
rcpp_result_gen = Rcpp::wrap(multcomp_pmt(data, group, statistic_func, n_permu, progress));
5755
return rcpp_result_gen;
5856
END_RCPP
5957
}
@@ -120,7 +118,7 @@ END_RCPP
120118
static const R_CallMethodDef CallEntries[] = {
121119
{"_LearnNonparam_twosample_pmt", (DL_FUNC) &_LearnNonparam_twosample_pmt, 5},
122120
{"_LearnNonparam_ksample_pmt", (DL_FUNC) &_LearnNonparam_ksample_pmt, 5},
123-
{"_LearnNonparam_multcomp_pmt", (DL_FUNC) &_LearnNonparam_multcomp_pmt, 7},
121+
{"_LearnNonparam_multcomp_pmt", (DL_FUNC) &_LearnNonparam_multcomp_pmt, 5},
124122
{"_LearnNonparam_paired_pmt", (DL_FUNC) &_LearnNonparam_paired_pmt, 5},
125123
{"_LearnNonparam_rcbd_pmt", (DL_FUNC) &_LearnNonparam_rcbd_pmt, 4},
126124
{"_LearnNonparam_association_pmt", (DL_FUNC) &_LearnNonparam_association_pmt, 5},

src/pmt_interface.cpp

Lines changed: 26 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -11,35 +11,29 @@ constexpr auto Rf_lang<2> = Rf_lang2;
1111
template <>
1212
constexpr auto Rf_lang<3> = Rf_lang3;
1313

14-
template <bool sharing_args>
15-
class StatFunc : public Function {
14+
template <typename T>
15+
class CachedFunc : public Function {
1616
public:
1717
using Function::Function;
1818

1919
template <typename... Args>
2020
auto operator()(Args&&... args) const
2121
{
22-
return _invoke(Tag<sharing_args>(), std::forward<Args>(args)...);
23-
}
24-
25-
private:
26-
template <bool>
27-
struct Tag { };
28-
29-
template <typename... Args>
30-
auto _invoke(Tag<false>, Args&&... args) const
31-
{
32-
return [R_closure = Function(Function::operator()(std::forward<Args>(args)...))](auto&&... args) {
33-
return as<double>(R_closure(std::forward<decltype(args)>(args)...));
22+
return [R_call = Shield<SEXP>(Rf_lang<sizeof...(args) + 1>(Function::operator()(std::forward<Args>(args)...), std::forward<Args>(args)...))](auto&&...) {
23+
return as<T>(Rcpp_fast_eval(R_call, R_GlobalEnv));
3424
};
3525
}
26+
};
27+
28+
template <typename T>
29+
class TypedFunc : public Function {
30+
public:
31+
using Function::Function;
3632

3733
template <typename... Args>
38-
auto _invoke(Tag<true>, Args&&... args) const
34+
T operator()(Args&&... args) const
3935
{
40-
return [R_call = Shield<SEXP>(Rf_lang<sizeof...(args) + 1>(Function::operator()(std::forward<Args>(args)...), std::forward<Args>(args)...))](auto&&...) {
41-
return as<double>(Rcpp_fast_eval(R_call, R_GlobalEnv));
42-
};
36+
return as<T>(Function::operator()(std::forward<Args>(args)...));
4337
}
4438
};
4539

@@ -57,8 +51,8 @@ SEXP twosample_pmt(
5751
const bool progress)
5852
{
5953
return progress ?
60-
impl_twosample_pmt<true, StatFunc<true>>(clone(x), clone(y), statistic_func, n_permu) :
61-
impl_twosample_pmt<false, StatFunc<true>>(clone(x), clone(y), statistic_func, n_permu);
54+
impl_twosample_pmt<true, CachedFunc<double>>(clone(x), clone(y), statistic_func, n_permu) :
55+
impl_twosample_pmt<false, CachedFunc<double>>(clone(x), clone(y), statistic_func, n_permu);
6256
}
6357

6458
#include "pmt/impl_ksample_pmt.hpp"
@@ -72,25 +66,23 @@ SEXP ksample_pmt(
7266
const bool progress)
7367
{
7468
return progress ?
75-
impl_ksample_pmt<true, StatFunc<true>>(data, clone(group), statistic_func, n_permu) :
76-
impl_ksample_pmt<false, StatFunc<true>>(data, clone(group), statistic_func, n_permu);
69+
impl_ksample_pmt<true, CachedFunc<double>>(data, clone(group), statistic_func, n_permu) :
70+
impl_ksample_pmt<false, CachedFunc<double>>(data, clone(group), statistic_func, n_permu);
7771
}
7872

7973
#include "pmt/impl_multcomp_pmt.hpp"
8074

8175
// [[Rcpp::export]]
8276
SEXP multcomp_pmt(
83-
const SEXP group_i,
84-
const SEXP group_j,
8577
const SEXP data,
8678
const SEXP group,
8779
const SEXP statistic_func,
8880
const double n_permu,
8981
const bool progress)
9082
{
9183
return progress ?
92-
impl_multcomp_pmt<true, StatFunc<false>>(group_i, group_j, data, clone(group), statistic_func, n_permu) :
93-
impl_multcomp_pmt<false, StatFunc<false>>(group_i, group_j, data, clone(group), statistic_func, n_permu);
84+
impl_multcomp_pmt<true, CachedFunc<TypedFunc<double>>>(data, clone(group), statistic_func, n_permu) :
85+
impl_multcomp_pmt<false, CachedFunc<TypedFunc<double>>>(data, clone(group), statistic_func, n_permu);
9486
}
9587

9688
#include "pmt/impl_paired_pmt.hpp"
@@ -104,8 +96,8 @@ SEXP paired_pmt(
10496
const bool progress)
10597
{
10698
return progress ?
107-
impl_paired_pmt<true, StatFunc<true>>(clone(x), clone(y), statistic_func, n_permu) :
108-
impl_paired_pmt<false, StatFunc<true>>(clone(x), clone(y), statistic_func, n_permu);
99+
impl_paired_pmt<true, CachedFunc<double>>(clone(x), clone(y), statistic_func, n_permu) :
100+
impl_paired_pmt<false, CachedFunc<double>>(clone(x), clone(y), statistic_func, n_permu);
109101
}
110102

111103
#include "pmt/impl_rcbd_pmt.hpp"
@@ -118,8 +110,8 @@ SEXP rcbd_pmt(
118110
const bool progress)
119111
{
120112
return progress ?
121-
impl_rcbd_pmt<true, StatFunc<true>>(clone(data), statistic_func, n_permu) :
122-
impl_rcbd_pmt<false, StatFunc<true>>(clone(data), statistic_func, n_permu);
113+
impl_rcbd_pmt<true, CachedFunc<double>>(clone(data), statistic_func, n_permu) :
114+
impl_rcbd_pmt<false, CachedFunc<double>>(clone(data), statistic_func, n_permu);
123115
}
124116

125117
#include "pmt/impl_association_pmt.hpp"
@@ -133,8 +125,8 @@ SEXP association_pmt(
133125
const bool progress)
134126
{
135127
return progress ?
136-
impl_association_pmt<true, StatFunc<true>>(clone(x), clone(y), statistic_func, n_permu) :
137-
impl_association_pmt<false, StatFunc<true>>(clone(x), clone(y), statistic_func, n_permu);
128+
impl_association_pmt<true, CachedFunc<double>>(clone(x), clone(y), statistic_func, n_permu) :
129+
impl_association_pmt<false, CachedFunc<double>>(clone(x), clone(y), statistic_func, n_permu);
138130
}
139131

140132
#include "pmt/impl_table_pmt.hpp"
@@ -148,6 +140,6 @@ SEXP table_pmt(
148140
const bool progress)
149141
{
150142
return progress ?
151-
impl_table_pmt<true, StatFunc<true>>(clone(row), clone(col), statistic_func, n_permu) :
152-
impl_table_pmt<false, StatFunc<true>>(clone(row), clone(col), statistic_func, n_permu);
143+
impl_table_pmt<true, CachedFunc<double>>(clone(row), clone(col), statistic_func, n_permu) :
144+
impl_table_pmt<false, CachedFunc<double>>(clone(row), clone(col), statistic_func, n_permu);
153145
}

0 commit comments

Comments
 (0)