Skip to content

Commit 6cbdbd2

Browse files
committed
Update RooFit ATLAS benchmarks to run with AD by default
1 parent 2221b6f commit 6cbdbd2

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

root/roofit/atlas-benchmarks/roofitAtlasHiggsBenchmark.cxx

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ BenchmarkData &bmdata()
5252

5353
static void benchCreateNLL(benchmark::State &state)
5454
{
55-
const std::string batchMode = state.range(0) == 2 ? "codegen" : (state.range(0) == 1 ? "cpu" : "off");
55+
const std::string evalBackend = state.range(0) == 2 ? "codegen" : (state.range(0) == 1 ? "cpu" : "legacy");
5656
auto &nllPtr =
5757
state.range(0) == 2 ? bmdata().codegenNll : (state.range(0) == 1 ? bmdata().batchedNll : bmdata().nll);
5858

@@ -61,8 +61,8 @@ static void benchCreateNLL(benchmark::State &state)
6161
bmdata().data = bmdata().ws->data("toyData");
6262

6363
for (auto _ : state) {
64-
nllPtr = std::unique_ptr<RooAbsReal>{bmdata().pdf->createNLL(*bmdata().data, GlobalObservables(*globObs),
65-
Offset(true), Optimize(2), BatchMode(batchMode))};
64+
nllPtr = std::unique_ptr<RooAbsReal>{bmdata().pdf->createNLL(
65+
*bmdata().data, GlobalObservables(*globObs), Offset(true), Optimize(2), EvalBackend(evalBackend))};
6666
}
6767

6868
double val = nllPtr->getVal();
@@ -97,7 +97,7 @@ static void benchEvaluateNLL(benchmark::State &state)
9797
static void benchMinimizeNLL(benchmark::State &state)
9898
{
9999
auto &nllPtr =
100-
state.range(1) == 2 ? bmdata().codegenNll : (state.range(1) == 1 ? bmdata().batchedNll : bmdata().nll);
100+
state.range(0) == 2 ? bmdata().codegenNll : (state.range(0) == 1 ? bmdata().batchedNll : bmdata().nll);
101101

102102
RooArgSet parameters;
103103
nllPtr->getParameters(nullptr, parameters);
@@ -115,23 +115,23 @@ static void benchMinimizeNLL(benchmark::State &state)
115115
parameters.assign(initialParams);
116116
}
117117

118-
BENCHMARK(benchCreateNLL)->Name("createNLL")->Args({0})->Unit(kSecond)->Iterations(1);
119-
BENCHMARK(benchCreateNLL)->Name("createNLL_BatchMode")->Args({1})->Unit(kSecond)->Iterations(1);
118+
// BENCHMARK(benchCreateNLL)->Name("createNLL")->Args({0})->Unit(kSecond)->Iterations(1);
119+
// BENCHMARK(benchCreateNLL)->Name("createNLL_CPU")->Args({1})->Unit(kSecond)->Iterations(1);
120120
BENCHMARK(benchCreateNLL)->Name("createNLL_CodeGenAD")->Args({2})->Unit(kSecond)->Iterations(1);
121-
BENCHMARK(benchEvaluateNLL)->Name("evaluateNLL")->Args({1, 0})->Unit(kMillisecond);
122-
BENCHMARK(benchEvaluateNLL)->Name("evaluateNLL_BatchMode")->Args({1, 1})->Unit(kMillisecond);
123-
BENCHMARK(benchEvaluateNLL)->Name("evaluateNLL_CodeGenAD")->Args({1, 2})->Unit(kMillisecond);
124-
BENCHMARK(benchEvaluateNLL)->Name("evaluateNLL_SingleKick")->Args({0, 0})->Unit(kMillisecond);
125-
BENCHMARK(benchEvaluateNLL)->Name("evaluateNLL_BatchMode_SingleKick")->Args({0, 1})->Unit(kMillisecond);
126-
BENCHMARK(benchEvaluateNLL)->Name("evaluateNLL_CodeGenAD_SingleKick")->Args({0, 2})->Unit(kMillisecond);
121+
// BENCHMARK(benchEvaluateNLL)->Name("evaluateNLL")->Args({1, 0})->Unit(kMillisecond);
122+
// BENCHMARK(benchEvaluateNLL)->Name("evaluateNLL_CPU")->Args({1, 1})->Unit(kMillisecond);
123+
// BENCHMARK(benchEvaluateNLL)->Name("evaluateNLL_CodeGenAD")->Args({1, 2})->Unit(kMillisecond);
124+
// BENCHMARK(benchEvaluateNLL)->Name("evaluateNLL_SingleKick")->Args({0, 0})->Unit(kMillisecond);
125+
// BENCHMARK(benchEvaluateNLL)->Name("evaluateNLL_CPU_SingleKick")->Args({0, 1})->Unit(kMillisecond);
126+
// BENCHMARK(benchEvaluateNLL)->Name("evaluateNLL_CodeGenAD_SingleKick")->Args({0, 2})->Unit(kMillisecond);
127127
// BENCHMARK(benchMinimizeNLL)->Name("minimizeNLL")->Args({0})->Unit(kSecond)->Iterations(1);
128-
// BENCHMARK(benchMinimizeNLL)->Name("minimizeNLL_BatchMode")->Args({1})->Unit(kSecond)->Iterations(1);
128+
// BENCHMARK(benchMinimizeNLL)->Name("minimizeNLL_CPU")->Args({1})->Unit(kSecond)->Iterations(1);
129129
BENCHMARK(benchMinimizeNLL)->Name("minimizeNLL_CodeGenAD")->Args({2})->Unit(kSecond)->Iterations(1);
130130

131131
// The channels 221 to 231 inclusive of the full combination workspace are
132132
// unfortunately corrupt. They contain RooAddPdfs that are affected by the
133133
// notorious server-proxy-desyncing that can happen if the RooFit frameworks do
134-
// the wrong thing. Th new the BatchMode uses the client-server links to build
134+
// the wrong thing. Th new the EvalBackend uses the client-server links to build
135135
// the computation graph for evaluation, and the old RooFit uses the proxies
136136
// (the client-server links are only used to the dirty flag propagation in the
137137
// old RooFit). As the servers and proxies are out of sync in some channels, we
@@ -213,10 +213,10 @@ int main(int argc, char **argv)
213213
bmdata().ws = bmdata().tfile->Get<RooWorkspace>(workspaceNames[iWorkspace].c_str());
214214
auto mc = static_cast<RooStats::ModelConfig *>(bmdata().ws->obj("ModelConfig"));
215215

216-
// bmdata().pdf = mc->GetPdf();
216+
bmdata().pdf = mc->GetPdf();
217217
// Use this instead to create a new simultaneous pdf that only includes a
218218
// subset of the channels:
219-
bmdata().pdf = createSimPdfSubset(*bmdata().ws, "simPdfSubset", 0, 1);
219+
// bmdata().pdf = createSimPdfSubset(*bmdata().ws, "simPdfSubset", 0, 5);
220220

221221
// Mask broken channels of the full Higgs combination workspace.
222222
if (iWorkspace == 2) {

0 commit comments

Comments
 (0)