Skip to content

Commit 6829118

Browse files
committed
Changes to memory ownership of jitted nodes
1 parent 0c07bea commit 6829118

File tree

10 files changed

+240
-270
lines changed

10 files changed

+240
-270
lines changed

tree/dataframe/inc/ROOT/RDF/InterfaceUtils.hxx

Lines changed: 45 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -367,23 +367,20 @@ void CheckForNoVariations(const std::string &where, std::string_view definedColV
367367

368368
std::string PrettyPrintAddr(const void *const addr);
369369

370-
std::shared_ptr<RJittedFilter> BookFilterJit(std::shared_ptr<RNodeBase> *prevNodeOnHeap, std::string_view name,
370+
std::shared_ptr<RJittedFilter> BookFilterJit(std::shared_ptr<RNodeBase> prevNode, std::string_view name,
371371
std::string_view expression, const RColumnRegister &colRegister,
372372
TTree *tree, RDataSource *ds);
373373

374374
std::shared_ptr<RJittedDefine> BookDefineJit(std::string_view name, std::string_view expression, RLoopManager &lm,
375-
RDataSource *ds, const RColumnRegister &colRegister,
376-
std::shared_ptr<RNodeBase> *prevNodeOnHeap);
375+
RDataSource *ds, const RColumnRegister &colRegister);
377376

378377
std::shared_ptr<RJittedDefine> BookDefinePerSampleJit(std::string_view name, std::string_view expression,
379-
RLoopManager &lm, const RColumnRegister &colRegister,
380-
std::shared_ptr<RNodeBase> *upcastNodeOnHeap);
378+
RLoopManager &lm, const RColumnRegister &colRegister);
381379

382380
std::shared_ptr<RJittedVariation>
383381
BookVariationJit(const std::vector<std::string> &colNames, std::string_view variationName,
384382
const std::vector<std::string> &variationTags, std::string_view expression, RLoopManager &lm,
385-
RDataSource *ds, const RColumnRegister &colRegister, std::shared_ptr<RNodeBase> *upcastNodeOnHeap,
386-
bool isSingleColumn);
383+
RDataSource *ds, const RColumnRegister &colRegister, bool isSingleColumn);
387384

388385
std::string JitBuildAction(const ColumnNames_t &bl, const std::type_info &art, const std::type_info &at, TTree *tree,
389386
const unsigned int nSlots, const RColumnRegister &colRegister, RDataSource *ds,
@@ -471,42 +468,32 @@ void AddDSColumns(const std::vector<std::string> &requiredCols, ROOT::Detail::RD
471468
ROOT::Internal::RDF::RColumnRegister &colRegister);
472469

473470
// this function is meant to be called by the jitted code generated by BookFilterJit
474-
template <typename F, typename PrevNode>
475-
void JitFilterHelper(F &&f, const ColumnNames_t &cols, std::string_view name,
476-
std::weak_ptr<RJittedFilter> *wkJittedFilter, std::shared_ptr<PrevNode> *prevNodeOnHeap,
477-
RColumnRegister *colRegister) noexcept
471+
template <typename F>
472+
void JitFilterHelper(F &&f, const ColumnNames_t &cols, std::string_view name, RColumnRegister &colRegister,
473+
ROOT::Detail::RDF::RLoopManager &lm, ROOT::Detail::RDF::RJittedFilter *jittedFilter) noexcept
478474
{
479-
if (wkJittedFilter->expired()) {
475+
if (!jittedFilter) {
480476
// The branch of the computation graph that needed this jitted code went out of scope between the time
481477
// jitting was booked and the time jitting actually happened. Nothing to do other than cleaning up.
482-
delete wkJittedFilter;
483-
delete colRegister;
484-
delete prevNodeOnHeap;
485478
return;
486479
}
487480

488-
const auto jittedFilter = wkJittedFilter->lock();
489-
490481
// mock Filter logic -- validity checks and Define-ition of RDataSource columns
491482
using Callable_t = std::decay_t<F>;
492-
using F_t = RFilter<Callable_t, PrevNode>;
483+
auto prevNode = jittedFilter->MoveOutPrevNode();
484+
using PrevNode_t = typename decltype(prevNode)::element_type;
485+
using F_t = RFilter<Callable_t, PrevNode_t>;
493486
using ColTypes_t = typename TTraits::CallableTraits<Callable_t>::arg_types;
494487
constexpr auto nColumns = ColTypes_t::list_size;
495488
CheckFilter(f);
496489

497-
auto &lm = *jittedFilter->GetLoopManagerUnchecked(); // RLoopManager must exist at this time
498490
auto ds = lm.GetDataSource();
499491

500-
if (ds != nullptr)
501-
AddDSColumns(cols, lm, *ds, ColTypes_t(), *colRegister);
492+
if (ds != nullptr && !cols.empty())
493+
AddDSColumns(cols, lm, *ds, ColTypes_t(), colRegister);
502494

503495
jittedFilter->SetFilter(
504-
std::unique_ptr<RFilterBase>(new F_t(std::forward<F>(f), cols, *prevNodeOnHeap, *colRegister, name)));
505-
// colRegister points to the columns structure in the heap, created before the jitted call so that the jitter can
506-
// share data after it has lazily compiled the code. Here the data has been used and the memory can be freed.
507-
delete colRegister;
508-
delete prevNodeOnHeap;
509-
delete wkJittedFilter;
496+
std::unique_ptr<RFilterBase>(new F_t(std::forward<F>(f), cols, prevNode, colRegister, name)));
510497
}
511498

512499
namespace DefineTypes {
@@ -534,124 +521,80 @@ auto MakeDefineNode(DefineTypes::RDefinePerSampleTag, std::string_view name, std
534521
// This function is meant to be called by jitted code right before starting the event loop.
535522
// The RDefineTypeTag template argument selects the node to build: with RDefinePerSampleTag, a RDefinePerSample (it has no input columns), otherwise a RDefine.
536523
template <typename RDefineTypeTag, typename F>
537-
void JitDefineHelper(F &&f, const ColumnNames_t &cols, std::string_view name, RLoopManager *lm,
538-
std::weak_ptr<RJittedDefine> *wkJittedDefine, RColumnRegister *colRegister,
539-
std::shared_ptr<RNodeBase> *prevNodeOnHeap) noexcept
524+
void JitDefineHelper(F &&f, const ColumnNames_t &cols, std::string_view name, RColumnRegister &colRegister,
525+
ROOT::Detail::RDF::RLoopManager &lm, ROOT::Detail::RDF::RJittedDefine *jittedDefine) noexcept
540526
{
541-
// a helper to delete objects allocated before jitting, so that the jitter can share data with lazily jitted code
542-
auto doDeletes = [&] {
543-
delete wkJittedDefine;
544-
delete colRegister;
545-
delete prevNodeOnHeap;
546-
};
547-
548-
if (wkJittedDefine->expired()) {
527+
528+
if (!jittedDefine) {
549529
// The branch of the computation graph that needed this jitted code went out of scope between the time
550530
// jitting was booked and the time jitting actually happened. Nothing to do other than cleaning up.
551-
doDeletes();
552531
return;
553532
}
554533

555-
auto jittedDefine = wkJittedDefine->lock();
556-
557534
using Callable_t = std::decay_t<F>;
558535
using ColTypes_t = typename TTraits::CallableTraits<Callable_t>::arg_types;
559536

560-
auto ds = lm->GetDataSource();
561-
if (ds != nullptr)
562-
AddDSColumns(cols, *lm, *ds, ColTypes_t(), *colRegister);
537+
auto ds = lm.GetDataSource();
538+
if (ds != nullptr && !cols.empty())
539+
AddDSColumns(cols, lm, *ds, ColTypes_t(), colRegister);
563540

564541
// will never actually be used (trumped by jittedDefine->GetTypeName()), but we set it to something meaningful
565542
// to help devs debugging
566543
const auto dummyType = "jittedCol_t";
567544
// use unique_ptr<RDefineBase> instead of make_unique<NewCol_t> to reduce jit/compile-times
568545
std::unique_ptr<RDefineBase> newCol{
569-
MakeDefineNode(RDefineTypeTag{}, name, dummyType, std::forward<F>(f), cols, *colRegister, *lm)};
546+
MakeDefineNode(RDefineTypeTag{}, name, dummyType, std::forward<F>(f), cols, colRegister, lm)};
570547
jittedDefine->SetDefine(std::move(newCol));
571-
572-
doDeletes();
573548
}
574549

575550
template <bool IsSingleColumn, typename F>
576-
void JitVariationHelper(F &&f, const ColumnNames_t &inputColNames, const ColumnNames_t &variedColNames,
577-
const char **variationTags, std::size_t variationTagsSize, std::string_view variationName,
578-
RLoopManager *lm, std::weak_ptr<RJittedVariation> *wkJittedVariation,
579-
RColumnRegister *colRegister, std::shared_ptr<RNodeBase> *prevNodeOnHeap) noexcept
551+
void JitVariationHelper(F &&f, const ColumnNames_t &inputColNames, std::string_view variationName,
552+
RColumnRegister &colRegister, ROOT::Detail::RDF::RLoopManager &lm,
553+
RJittedVariation *jittedVariation, const ColumnNames_t &variedColNames,
554+
const ColumnNames_t &variationTags) noexcept
580555
{
581-
// a helper to delete objects allocated before jitting, so that the jitter can share data with lazily jitted code
582-
auto doDeletes = [&] {
583-
delete[] variationTags;
584-
delete wkJittedVariation;
585-
delete colRegister;
586-
delete prevNodeOnHeap;
587-
};
588-
589-
if (wkJittedVariation->expired()) {
556+
557+
if (!jittedVariation) {
590558
// The branch of the computation graph that needed this jitted variation went out of scope between the time
591559
// jitting was booked and the time jitting actually happened. Nothing to do other than cleaning up.
592-
doDeletes();
593560
return;
594561
}
595562

596-
std::vector<std::string> tags(variationTags, variationTags + variationTagsSize);
597-
598-
auto jittedVariation = wkJittedVariation->lock();
599-
600563
using Callable_t = std::decay_t<F>;
601564
using ColTypes_t = typename TTraits::CallableTraits<Callable_t>::arg_types;
602565

603-
auto ds = lm->GetDataSource();
604-
if (ds != nullptr)
605-
AddDSColumns(inputColNames, *lm, *ds, ColTypes_t(), *colRegister);
566+
auto ds = lm.GetDataSource();
567+
if (ds != nullptr && !inputColNames.empty())
568+
AddDSColumns(inputColNames, lm, *ds, ColTypes_t(), colRegister);
606569

607570
// use unique_ptr<RVariationBase> instead of make_unique of the concrete RVariation type to reduce jit/compile-times
608-
std::unique_ptr<RVariationBase> newVariation{new RVariation<std::decay_t<F>, IsSingleColumn>(
609-
std::move(variedColNames), variationName, std::forward<F>(f), std::move(tags), jittedVariation->GetTypeName(),
610-
*colRegister, *lm, inputColNames)};
571+
std::unique_ptr<RVariationBase> newVariation{
572+
new RVariation<std::decay_t<F>, IsSingleColumn>(variedColNames, variationName, std::forward<F>(f), variationTags,
573+
jittedVariation->GetTypeName(), colRegister, lm, inputColNames)};
611574
jittedVariation->SetVariation(std::move(newVariation));
612-
613-
doDeletes();
614575
}
615576

616577
/// Convenience function invoked by jitted code to build action nodes at runtime
617-
template <typename ActionTag, typename... ColTypes, typename PrevNodeType, typename HelperArgType>
618-
void CallBuildAction(std::shared_ptr<PrevNodeType> *prevNodeOnHeap, const ColumnNames_t &cols,
619-
const unsigned int nSlots, std::shared_ptr<HelperArgType> *helperArgOnHeap,
620-
std::weak_ptr<RJittedAction> *wkJittedActionOnHeap, RColumnRegister *colRegister) noexcept
578+
template <typename ActionTag, typename... ColTypes, typename HelperArgType>
579+
void CallBuildAction(const ColumnNames_t &cols, RColumnRegister &colRegister, ROOT::Detail::RDF::RLoopManager &lm,
580+
RJittedAction *jittedAction, unsigned int nSlots,
581+
std::shared_ptr<HelperArgType> *helperArg) noexcept
621582
{
622-
// a helper to delete objects allocated before jitting, so that the jitter can share data with lazily jitted code
623-
auto doDeletes = [&] {
624-
delete helperArgOnHeap;
625-
delete wkJittedActionOnHeap;
626-
// colRegister must be deleted before prevNodeOnHeap because their dtor needs the RLoopManager to be alive
627-
// and prevNodeOnHeap is what keeps it alive if the rest of the computation graph is already out of scope
628-
delete colRegister;
629-
delete prevNodeOnHeap;
630-
};
631-
632-
if (wkJittedActionOnHeap->expired()) {
583+
if (!jittedAction) {
633584
// The branch of the computation graph that needed this jitted action went out of scope between the time
634585
// jitting was booked and the time jitting actually happened. Nothing to do other than cleaning up.
635-
doDeletes();
636586
return;
637587
}
638588

639-
auto jittedActionOnHeap = wkJittedActionOnHeap->lock();
640-
641-
// if we are here it means we are jitting, if we are jitting the loop manager must be alive
642-
auto &prevNodePtr = *prevNodeOnHeap;
643-
auto &loopManager = *prevNodePtr->GetLoopManagerUnchecked();
644589
using ColTypes_t = TypeList<ColTypes...>;
645590
constexpr auto nColumns = ColTypes_t::list_size;
646-
auto ds = loopManager.GetDataSource();
647-
if (ds != nullptr)
648-
AddDSColumns(cols, loopManager, *ds, ColTypes_t(), *colRegister);
649-
650-
auto actionPtr = BuildAction<ColTypes...>(cols, std::move(*helperArgOnHeap), nSlots, std::move(prevNodePtr),
651-
ActionTag{}, *colRegister);
652-
jittedActionOnHeap->SetAction(std::move(actionPtr));
591+
auto ds = lm.GetDataSource();
592+
if (ds != nullptr && !cols.empty())
593+
AddDSColumns(cols, lm, *ds, ColTypes_t(), colRegister);
653594

654-
doDeletes();
595+
auto actionPtr =
596+
BuildAction<ColTypes...>(cols, *helperArg, nSlots, jittedAction->MoveOutPrevNode(), ActionTag{}, colRegister);
597+
jittedAction->SetAction(std::move(actionPtr));
655598
}
656599

657600
/// The contained `type` alias is `double` if `T == RInferredType`, `U` if `T == std::container<U>`, `T` otherwise.

tree/dataframe/inc/ROOT/RDF/RInterface.hxx

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -293,12 +293,8 @@ public:
293293
/// ~~~
294294
RInterface<RDFDetail::RJittedFilter, DS_t> Filter(std::string_view expression, std::string_view name = "")
295295
{
296-
// deleted by the jitted call to JitFilterHelper
297-
auto upcastNodeOnHeap = RDFInternal::MakeSharedOnHeap(RDFInternal::UpcastNode(fProxiedPtr));
298-
using BaseNodeType_t = typename std::remove_pointer_t<decltype(upcastNodeOnHeap)>::element_type;
299-
RInterface<BaseNodeType_t> upcastInterface(*upcastNodeOnHeap, *fLoopManager, fColRegister);
300-
const auto jittedFilter =
301-
RDFInternal::BookFilterJit(upcastNodeOnHeap, name, expression, fColRegister, nullptr, GetDataSource());
296+
const auto jittedFilter = RDFInternal::BookFilterJit(RDFInternal::UpcastNode(fProxiedPtr), name, expression,
297+
fColRegister, nullptr, GetDataSource());
302298

303299
return RInterface<RDFDetail::RJittedFilter, DS_t>(std::move(jittedFilter), *fLoopManager, fColRegister);
304300
}
@@ -540,9 +536,7 @@ public:
540536
RDFInternal::CheckForRedefinition(where, name, fColRegister,
541537
GetDataSource() ? GetDataSource()->GetColumnNames() : ColumnNames_t{});
542538

543-
auto upcastNodeOnHeap = RDFInternal::MakeSharedOnHeap(RDFInternal::UpcastNode(fProxiedPtr));
544-
auto jittedDefine =
545-
RDFInternal::BookDefineJit(name, expression, *fLoopManager, GetDataSource(), fColRegister, upcastNodeOnHeap);
539+
auto jittedDefine = RDFInternal::BookDefineJit(name, expression, *fLoopManager, GetDataSource(), fColRegister);
546540

547541
RDFInternal::RColumnRegister newCols(fColRegister);
548542
newCols.AddDefine(std::move(jittedDefine));
@@ -630,9 +624,7 @@ public:
630624
GetDataSource() ? GetDataSource()->GetColumnNames() : ColumnNames_t{});
631625
RDFInternal::CheckForNoVariations(where, name, fColRegister);
632626

633-
auto upcastNodeOnHeap = RDFInternal::MakeSharedOnHeap(RDFInternal::UpcastNode(fProxiedPtr));
634-
auto jittedDefine =
635-
RDFInternal::BookDefineJit(name, expression, *fLoopManager, GetDataSource(), fColRegister, upcastNodeOnHeap);
627+
auto jittedDefine = RDFInternal::BookDefineJit(name, expression, *fLoopManager, GetDataSource(), fColRegister);
636628

637629
RDFInternal::RColumnRegister newCols(fColRegister);
638630
newCols.AddDefine(std::move(jittedDefine));
@@ -807,9 +799,7 @@ public:
807799
RDFInternal::CheckForRedefinition("DefinePerSample", name, fColRegister,
808800
GetDataSource() ? GetDataSource()->GetColumnNames() : ColumnNames_t{});
809801

810-
auto upcastNodeOnHeap = RDFInternal::MakeSharedOnHeap(RDFInternal::UpcastNode(fProxiedPtr));
811-
auto jittedDefine =
812-
RDFInternal::BookDefinePerSampleJit(name, expression, *fLoopManager, fColRegister, upcastNodeOnHeap);
802+
auto jittedDefine = RDFInternal::BookDefinePerSampleJit(name, expression, *fLoopManager, fColRegister);
813803

814804
RDFInternal::RColumnRegister newCols(fColRegister);
815805
newCols.AddDefine(std::move(jittedDefine));
@@ -3420,10 +3410,9 @@ private:
34203410
throw std::logic_error("A column name was passed to the same Vary invocation multiple times.");
34213411
}
34223412

3423-
auto upcastNodeOnHeap = RDFInternal::MakeSharedOnHeap(RDFInternal::UpcastNode(fProxiedPtr));
34243413
auto jittedVariation =
34253414
RDFInternal::BookVariationJit(colNames, variationName, variationTags, expression, *fLoopManager,
3426-
GetDataSource(), fColRegister, upcastNodeOnHeap, isSingleColumn);
3415+
GetDataSource(), fColRegister, isSingleColumn);
34273416

34283417
RDFInternal::RColumnRegister newColRegister(fColRegister);
34293418
newColRegister.AddVariation(std::move(jittedVariation));

tree/dataframe/inc/ROOT/RDF/RInterfaceBase.hxx

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -196,19 +196,14 @@ protected:
196196
const auto validColumnNames = GetValidatedColumnNames(realNColumns, columns);
197197
const unsigned int nSlots = fLoopManager->GetNSlots();
198198

199-
auto *helperArgOnHeap = RDFInternal::MakeSharedOnHeap(helperArg);
199+
const auto jittedAction = std::make_shared<RDFInternal::RJittedAction>(
200+
*fLoopManager, validColumnNames, fColRegister, proxiedPtr->GetVariations(), proxiedPtr);
200201

201-
auto upcastNodeOnHeap = RDFInternal::MakeSharedOnHeap(RDFInternal::UpcastNode(proxiedPtr));
202-
203-
const auto jittedAction = std::make_shared<RDFInternal::RJittedAction>(*fLoopManager, validColumnNames,
204-
fColRegister, proxiedPtr->GetVariations());
205-
auto jittedActionOnHeap = RDFInternal::MakeWeakOnHeap(jittedAction);
206-
207-
auto definesCopy = new RDFInternal::RColumnRegister(fColRegister); // deleted in jitted call
208202
auto funcBody = RDFInternal::JitBuildAction(validColumnNames, typeid(HelperArgType), typeid(ActionTag), nullptr,
209203
nSlots, fColRegister, GetDataSource(), vector2RVec);
210-
fLoopManager->RegisterJitHelperCall(funcBody, upcastNodeOnHeap, definesCopy, validColumnNames, jittedActionOnHeap,
211-
helperArgOnHeap);
204+
fLoopManager->RegisterJitHelperCall(funcBody,
205+
std::make_unique<ROOT::Internal::RDF::RColumnRegister>(fColRegister),
206+
validColumnNames, jittedAction, helperArg);
212207
return MakeResultPtr(r, *fLoopManager, std::move(jittedAction));
213208
}
214209

tree/dataframe/inc/ROOT/RDF/RJittedAction.hxx

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ namespace ROOT {
2323
namespace Detail {
2424
namespace RDF {
2525
class RMergeableValueBase;
26+
class RNodeBase;
2627
} // namespace RDF
2728
} // namespace Detail
2829
} // namespace ROOT
@@ -39,10 +40,12 @@ class GraphNode;
3940
class RJittedAction : public RActionBase {
4041
private:
4142
std::unique_ptr<RActionBase> fConcreteAction;
43+
std::shared_ptr<ROOT::Detail::RDF::RNodeBase> fPrevNode;
4244

4345
public:
4446
RJittedAction(RLoopManager &lm, const ROOT::RDF::ColumnNames_t &columns, const RColumnRegister &colRegister,
45-
const std::vector<std::string> &prevVariations);
47+
const std::vector<std::string> &prevVariations,
48+
std::shared_ptr<ROOT::Detail::RDF::RNodeBase> prevNode = nullptr);
4649
~RJittedAction();
4750

4851
void SetAction(std::unique_ptr<RActionBase> a) { fConcreteAction = std::move(a); }
@@ -67,6 +70,7 @@ public:
6770

6871
std::unique_ptr<RActionBase> MakeVariedAction(std::vector<void *> &&results) final;
6972
std::unique_ptr<ROOT::Internal::RDF::RActionBase> CloneAction(void *newResult) final;
73+
std::shared_ptr<ROOT::Detail::RDF::RNodeBase> MoveOutPrevNode();
7074
};
7175

7276
} // ns RDF

tree/dataframe/inc/ROOT/RDF/RJittedFilter.hxx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,11 @@ namespace RDFGraphDrawing = ROOT::Internal::RDF::GraphDrawing;
3838
/// at a later time, from jitted code.
3939
class RJittedFilter final : public RFilterBase {
4040
std::unique_ptr<RFilterBase> fConcreteFilter = nullptr;
41+
std::shared_ptr<RNodeBase> fPrevNode;
4142

4243
public:
43-
RJittedFilter(RLoopManager *lm, std::string_view name, const std::vector<std::string> &variations);
44+
RJittedFilter(RLoopManager *lm, std::string_view name, const std::vector<std::string> &variations,
45+
std::shared_ptr<ROOT::Detail::RDF::RNodeBase> prevNode = nullptr);
4446

4547
// Rule of five
4648

@@ -68,6 +70,7 @@ public:
6870
std::shared_ptr<RDFGraphDrawing::GraphNode>
6971
GetGraph(std::unordered_map<void *, std::shared_ptr<RDFGraphDrawing::GraphNode>> &visitedMap) final;
7072
std::shared_ptr<RNodeBase> GetVariedFilter(const std::string &variationName) final;
73+
std::shared_ptr<RNodeBase> MoveOutPrevNode();
7174
};
7275

7376
} // ns RDF

0 commit comments

Comments
 (0)