Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 48 additions & 38 deletions tree/dataframe/inc/ROOT/RDF/RTreeColumnReader.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "RColumnReaderBase.hxx"
#include <ROOT/RVec.hxx>
#include "ROOT/RDF/Utils.hxx"
#include <Rtypes.h> // Long64_t, R__CLING_PTRCHECK
#include <TTreeReader.h>
#include <TTreeReaderValue.h>
Expand All @@ -22,6 +23,7 @@
#include <array>
#include <memory>
#include <string>
#include <cstddef>

namespace ROOT {
namespace Internal {
Expand All @@ -30,13 +32,14 @@ namespace RDF {
/// RTreeColumnReader specialization for TTree values read via TTreeReaderValues
template <typename T>
class R__CLING_PTRCHECK(off) RTreeColumnReader final : public ROOT::Detail::RDF::RColumnReaderBase {
std::unique_ptr<TTreeReaderValue<T>> fTreeValue;
std::unique_ptr<TTreeReaderUntypedValue> fTreeValue;

void *GetImpl(Long64_t) final { return fTreeValue->Get(); }
public:
/// Construct the RTreeColumnReader. Actual initialization is performed lazily by the Init method.
RTreeColumnReader(TTreeReader &r, const std::string &colName)
: fTreeValue(std::make_unique<TTreeReaderValue<T>>(r, colName.c_str()))
: fTreeValue(std::make_unique<TTreeReaderUntypedValue>(r, colName.c_str(),
ROOT::Internal::RDF::TypeID2TypeName(typeid(T))))
{
}
};
Expand All @@ -59,16 +62,21 @@ public:
/// TTreeReaderArrays are used whenever the RDF column type is RVec<T>.
template <typename T>
class R__CLING_PTRCHECK(off) RTreeColumnReader<RVec<T>> final : public ROOT::Detail::RDF::RColumnReaderBase {
std::unique_ptr<TTreeReaderArray<T>> fTreeArray;
std::unique_ptr<TTreeReaderUntypedArray> fTreeArray;

using Byte_t = std::byte;

/// We return a reference to this RVec to clients, to guarantee a stable address and contiguous memory layout.
RVec<T> fRVec;
RVec<Byte_t> fRVec;

Long64_t fLastEntry = -1;

/// Whether we already printed a warning about performing a copy of the TTreeReaderArray contents
bool fCopyWarningPrinted = false;

/// The size of the collection value type.
std::size_t fValueSize{};

void *GetImpl(Long64_t entry) final
{
if (entry == fLastEntry)
Expand All @@ -86,11 +94,10 @@ class R__CLING_PTRCHECK(off) RTreeColumnReader<RVec<T>> final : public ROOT::Det
// trigger loading of the contents of the TTreeReaderArray
// the address of the first element in the reader array is not necessarily equal to
// the address returned by the GetAddress method
auto readerArrayAddr = &readerArray.At(0);
RVec<T> rvec(readerArrayAddr, readerArraySize);
RVec<Byte_t> rvec(readerArray.At(0), readerArraySize);
swap(fRVec, rvec);
} else {
RVec<T> emptyVec{};
RVec<Byte_t> emptyVec{};
swap(fRVec, emptyVec);
}
} else {
Expand All @@ -107,10 +114,20 @@ class R__CLING_PTRCHECK(off) RTreeColumnReader<RVec<T>> final : public ROOT::Det
(void)fCopyWarningPrinted;
#endif
if (readerArraySize > 0) {
RVec<T> rvec(readerArray.begin(), readerArray.end());
swap(fRVec, rvec);
// Caching the value type size since GetValueSize might be expensive.
if (fValueSize == 0)
fValueSize = readerArray.GetValueSize();
assert(fValueSize > 0 && "Could not retrieve size of collection value type.");
// Array is not contiguous, make a full copy of it.
fRVec = RVec<Byte_t>();
fRVec.reserve(readerArraySize * fValueSize);
for (std::size_t i{0}; i < readerArraySize; i++) {
auto val = readerArray.At(i);
std::copy(val, val + fValueSize, std::back_inserter(fRVec));
}
fRVec.resize(readerArraySize);
} else {
RVec<T> emptyVec{};
RVec<Byte_t> emptyVec{};
swap(fRVec, emptyVec);
}
}
Expand All @@ -120,7 +137,8 @@ class R__CLING_PTRCHECK(off) RTreeColumnReader<RVec<T>> final : public ROOT::Det

public:
RTreeColumnReader(TTreeReader &r, const std::string &colName)
: fTreeArray(std::make_unique<TTreeReaderArray<T>>(r, colName.c_str()))
: fTreeArray(
std::make_unique<TTreeReaderUntypedArray>(r, colName, ROOT::Internal::RDF::TypeID2TypeName(typeid(T))))
{
}
};
Expand All @@ -131,10 +149,12 @@ public:
template <>
class R__CLING_PTRCHECK(off) RTreeColumnReader<RVec<bool>> final : public ROOT::Detail::RDF::RColumnReaderBase {

std::unique_ptr<TTreeReaderArray<bool>> fTreeArray;
using Byte_t = std::byte;

std::unique_ptr<TTreeReaderUntypedArray> fTreeArray;

/// We return a reference to this RVec to clients, to guarantee a stable address and contiguous memory layout
RVec<bool> fRVec;
RVec<Byte_t> fRVec;

// We always copy the contents of TTreeReaderArray<bool> into an RVec<bool> (never take a view into the memory
// buffer) because the underlying memory buffer might be the one of a std::vector<bool>, which is not a contiguous
Expand All @@ -146,19 +166,25 @@ class R__CLING_PTRCHECK(off) RTreeColumnReader<RVec<bool>> final : public ROOT::
auto &readerArray = *fTreeArray;
const auto readerArraySize = readerArray.GetSize();
if (readerArraySize > 0) {
// always perform a copy
RVec<bool> rvec(readerArray.begin(), readerArray.end());
swap(fRVec, rvec);
// Always perform a copy
fRVec = RVec<Byte_t>();
fRVec.reserve(readerArraySize * sizeof(bool));
for (std::size_t i{0}; i < readerArraySize; i++) {
auto val = readerArray.At(i);
std::copy(val, val + sizeof(bool), std::back_inserter(fRVec));
}
fRVec.resize(readerArraySize);
} else {
RVec<bool> emptyVec{};
RVec<Byte_t> emptyVec{};
swap(fRVec, emptyVec);
}
return &fRVec;
}

public:
RTreeColumnReader(TTreeReader &r, const std::string &colName)
: fTreeArray(std::make_unique<TTreeReaderArray<bool>>(r, colName.c_str()))
: fTreeArray(std::make_unique<TTreeReaderUntypedArray>(r, colName.c_str(),
ROOT::Internal::RDF::TypeID2TypeName(typeid(bool))))
{
}
};
Expand All @@ -168,32 +194,16 @@ public:
/// This specialization is used when the requested type for reading is std::array
template <typename T, std::size_t N>
class R__CLING_PTRCHECK(off) RTreeColumnReader<std::array<T, N>> final : public ROOT::Detail::RDF::RColumnReaderBase {
std::unique_ptr<TTreeReaderArray<T>> fTreeArray;

/// We return a reference to this RVec to clients, to guarantee a stable address and contiguous memory layout
RVec<T> fArray;
std::unique_ptr<TTreeReaderUntypedArray> fTreeArray;

Long64_t fLastEntry = -1;

void *GetImpl(Long64_t entry) final
{
if (entry == fLastEntry)
return fArray.data();

// This is a non-owning view on the contents of the TTreeReaderArray
RVec<T> view{&fTreeArray->At(0), fTreeArray->GetSize()};
swap(fArray, view);

fLastEntry = entry;
// The data member of this class is an RVec, to avoid an extra copy
// but we need to return the array buffer as the reader expects
// a std::array
return fArray.data();
}
void *GetImpl(Long64_t) final { return fTreeArray->At(0); }

public:
RTreeColumnReader(TTreeReader &r, const std::string &colName)
: fTreeArray(std::make_unique<TTreeReaderArray<T>>(r, colName.c_str()))
: fTreeArray(std::make_unique<TTreeReaderUntypedArray>(r, colName.c_str(),
ROOT::Internal::RDF::TypeID2TypeName(typeid(T))))
{
}
};
Expand Down
2 changes: 2 additions & 0 deletions tree/treeplayer/inc/TBranchProxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,8 @@ namespace Detail {
Int_t GetOffset() { return fOffset; }

bool GetSuppressErrorsForMissingBranch() const { return fSuppressMissingBranchError; }

Int_t GetStreamerElementSize() const;
};
} // namespace Detail

Expand Down
21 changes: 21 additions & 0 deletions tree/treeplayer/inc/TTreeReaderArray.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// @(#)root/tree:$Id$
// Author: Axel Naumann, 2010-08-02
// Author: Vincenzo Eduardo Padulano CERN 02/2025

/*************************************************************************
* Copyright (C) 1995-2013, Rene Brun and Fons Rademakers. *
Expand All @@ -15,6 +16,7 @@
#include "TTreeReaderValue.h"
#include "TTreeReaderUtils.h"
#include <type_traits>
#include <cstddef>

namespace ROOT {
namespace Internal {
Expand All @@ -37,6 +39,9 @@ class TTreeReaderArrayBase : public TTreeReaderValueBase {

bool IsContiguous() const { return fImpl->IsContiguous(GetProxy()); }

/// Returns the `sizeof` of the collection value type. Returns 0 in case the value size could not be retrieved.
std::size_t GetValueSize() const { return fImpl ? fImpl->GetValueSize(GetProxy()) : 0; }

protected:
void *UntypedAt(std::size_t idx) const { return fImpl->At(GetProxy(), idx); }
void CreateProxy() override;
Expand All @@ -51,6 +56,22 @@ class TTreeReaderArrayBase : public TTreeReaderValueBase {
// ClassDefOverride(TTreeReaderArrayBase, 0);//Accessor to member of an object stored in a collection
};

class R__CLING_PTRCHECK(off) TTreeReaderUntypedArray final : public TTreeReaderArrayBase {
std::string fArrayElementTypeName;

public:
TTreeReaderUntypedArray(TTreeReader &tr, std::string_view branchName, std::string_view innerTypeName)
: TTreeReaderArrayBase(&tr, branchName.data(), TDictionary::GetDictionary(innerTypeName.data())),
fArrayElementTypeName(innerTypeName)
{
}

std::byte *At(std::size_t idx) const { return reinterpret_cast<std::byte *>(UntypedAt(idx)); }

protected:
const char *GetDerivedTypeName() const final { return fArrayElementTypeName.c_str(); }
};

} // namespace Internal
} // namespace ROOT

Expand Down
1 change: 1 addition & 0 deletions tree/treeplayer/inc/TTreeReaderUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ namespace Internal {
virtual size_t GetSize(Detail::TBranchProxy*) = 0;
virtual void* At(Detail::TBranchProxy*, size_t /*idx*/) = 0;
virtual bool IsContiguous(Detail::TBranchProxy *) = 0;
virtual std::size_t GetValueSize(Detail::TBranchProxy *) = 0;
};

}
Expand Down
24 changes: 24 additions & 0 deletions tree/treeplayer/inc/TTreeReaderValue.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,30 @@ class R__CLING_PTRCHECK(off) TTreeReaderOpaqueValue final : public ROOT::Interna
const char *GetDerivedTypeName() const { return ""; }
};

class R__CLING_PTRCHECK(off) TTreeReaderUntypedValue final : public TTreeReaderValueBase {
std::string fElementTypeName;

public:
TTreeReaderUntypedValue(TTreeReader &tr, std::string_view branchName, std::string_view typeName)
: TTreeReaderValueBase(&tr, branchName.data(), TDictionary::GetDictionary(typeName.data())),
fElementTypeName(typeName)
{
}

void *Get()
{
if (!fProxy) {
ErrorAboutMissingProxyIfNeeded();
return nullptr;
}
void *address = GetAddress(); // Needed to figure out if it's a pointer
return fProxy->IsaPointer() ? *(void **)address : (void *)address;
}

protected:
const char *GetDerivedTypeName() const final { return fElementTypeName.c_str(); }
};

} // namespace Internal
} // namespace ROOT

Expand Down
5 changes: 5 additions & 0 deletions tree/treeplayer/src/TBranchProxy.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -583,3 +583,8 @@ bool ROOT::Detail::TBranchProxy::Setup()
return false;
}
}

Int_t ROOT::Detail::TBranchProxy::GetStreamerElementSize() const
{
return fElement ? fElement->GetSize() : 0;
}
51 changes: 51 additions & 0 deletions tree/treeplayer/src/TTreeReaderArray.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

#include <memory>
#include <optional>
#include <iostream>

// pin vtable
ROOT::Internal::TVirtualCollectionReader::~TVirtualCollectionReader() {}
Expand Down Expand Up @@ -72,6 +73,12 @@ class TClonesReader : public TVirtualCollectionReader {
}

bool IsContiguous(ROOT::Detail::TBranchProxy *) override { return false; }

std::size_t GetValueSize(ROOT::Detail::TBranchProxy *proxy) override
{
auto *ca = GetCA(proxy);
return ca ? ca->GetClass()->Size() : 0;
}
};

bool IsCPContiguous(const TVirtualCollectionProxy &cp)
Expand All @@ -86,6 +93,14 @@ bool IsCPContiguous(const TVirtualCollectionProxy &cp)
}
}

UInt_t GetCPValueSize(const TVirtualCollectionProxy &cp)
{
// This works only if the collection proxy value type is a fundamental type
auto &&eDataType = cp.GetType();
auto *tDataType = TDataType::GetDataType(eDataType);
return tDataType ? tDataType->Size() : 0;
}

// Reader interface for STL
class TSTLReader final : public TVirtualCollectionReader {
public:
Expand Down Expand Up @@ -131,6 +146,12 @@ class TSTLReader final : public TVirtualCollectionReader {
auto cp = GetCP(proxy);
return IsCPContiguous(*cp);
}

std::size_t GetValueSize(ROOT::Detail::TBranchProxy *proxy) override
{
auto cp = GetCP(proxy);
return GetCPValueSize(*cp);
}
};

class TCollectionLessSTLReader final : public TVirtualCollectionReader {
Expand Down Expand Up @@ -190,6 +211,12 @@ class TCollectionLessSTLReader final : public TVirtualCollectionReader {
auto cp = GetCP(proxy);
return IsCPContiguous(*cp);
}

std::size_t GetValueSize(ROOT::Detail::TBranchProxy *proxy) override
{
auto cp = GetCP(proxy);
return GetCPValueSize(*cp);
}
};

// Reader interface for leaf list
Expand Down Expand Up @@ -243,6 +270,12 @@ class TObjectArrayReader : public TVirtualCollectionReader {
void SetBasicTypeSize(Int_t size) { fBasicTypeSize = size; }

bool IsContiguous(ROOT::Detail::TBranchProxy *) override { return true; }

std::size_t GetValueSize(ROOT::Detail::TBranchProxy *proxy) override
{
auto cp = GetCP(proxy);
return GetCPValueSize(*cp);
}
};

template <class BASE>
Expand Down Expand Up @@ -387,6 +420,18 @@ class TBasicTypeArrayReader final : public TVirtualCollectionReader {
}

bool IsContiguous(ROOT::Detail::TBranchProxy *) override { return false; }

std::size_t GetValueSize(ROOT::Detail::TBranchProxy *proxy) override
{
if (!proxy->Read()) {
fReadStatus = TTreeReaderValueBase::kReadError;
if (!proxy->GetSuppressErrorsForMissingBranch())
Error("TBasicTypeArrayReader::GetValueSize()", "Read error in TBranchProxy.");
return 0;
}
fReadStatus = TTreeReaderValueBase::kReadSuccess;
return proxy->GetStreamerElementSize();
}
};

class TBasicTypeClonesReader final : public TClonesReader {
Expand Down Expand Up @@ -434,6 +479,12 @@ class TLeafReader : public TVirtualCollectionReader {

bool IsContiguous(ROOT::Detail::TBranchProxy *) override { return true; }

std::size_t GetValueSize(ROOT::Detail::TBranchProxy *) override
{
auto *leaf = fValueReader->GetLeaf();
return leaf ? leaf->GetLenType() : 0;
}

protected:
void ProxyRead() { fValueReader->ProxyRead(); }
};
Expand Down
Loading