diff --git a/src/predictor/array_tree_layout.h b/src/predictor/array_tree_layout.h
new file mode 100644
index 000000000000..6b24c3c64556
--- /dev/null
+++ b/src/predictor/array_tree_layout.h
@@ -0,0 +1,226 @@
+/**
+ * Copyright 2021-2025, XGBoost Contributors
+ * \file array_tree_layout.h
+ * \brief Implementation of the array tree layout -- a powerful inference optimization method.
+ */
+#ifndef XGBOOST_PREDICTOR_ARRAY_TREE_LAYOUT_H_
+#define XGBOOST_PREDICTOR_ARRAY_TREE_LAYOUT_H_
+
+#include <array>        // for array
+#include <limits>       // for numeric_limits
+#include <type_traits>  // for conditional_t
+
+#include "../common/categorical.h"  // for IsCat
+#include "xgboost/tree_model.h"     // for RegTree
+
+namespace xgboost::predictor {
+
+/**
+ * @brief The class holds the array-based representation of the top levels of a single tree.
+ *
+ * @tparam has_categorical whether the tree has categorical splits
+ *
+ * @tparam any_missing whether the class must be able to process missing values
+ *
+ * @tparam kNumDeepLevels number of tree levels unrolled into the array-based structure
+ */
+template <bool has_categorical, bool any_missing, int kNumDeepLevels>
+class ArrayTreeLayout {
+ private:
+  /* Number of nodes in the array-based representation of the top levels of the tree.
+   */
+  constexpr static size_t kNodesCount = (1u << kNumDeepLevels) - 1;
+
+  struct Empty {};
+  using DefaultLeftType =
+      typename std::conditional_t<any_missing, std::array<uint8_t, kNodesCount>, Empty>;
+  using IsCatType =
+      typename std::conditional_t<has_categorical, std::array<uint8_t, kNodesCount>, Empty>;
+  using CatSegmentType =
+      typename std::conditional_t<has_categorical,
+                                  std::array<common::Span<uint32_t const>, kNodesCount>, Empty>;
+
+  DefaultLeftType default_left_;
+  IsCatType is_cat_;
+  CatSegmentType cat_segment_;
+
+  std::array<bst_feature_t, kNodesCount> split_index_;
+  std::array<float, kNodesCount> split_cond_;
+  /* The nodes at tree levels 0, 1, ..., kNumDeepLevels - 1 are unrolled into an
+   * array-based structure. If the tree has additional levels, this array stores the
+   * node indices of the sub-trees rooted at level kNumDeepLevels. This is necessary to
+   * continue processing nodes that are not eligible for array-based unrolling. The
+   * number of sub-trees packed into this array is equal to the number of nodes at tree
+   * level kNumDeepLevels, which is (1u << kNumDeepLevels) == kNodesCount + 1.
+   */
+  // Mapping from the array node index to the RegTree node index.
+  std::array<bst_node_t, kNodesCount + 1> nidx_in_tree_;
+
+  /**
+   * @brief Traverse the top levels of the original tree and fill the internal arrays.
+   *
+   * @tparam depth the tree level being processed
+   *
+   * @param tree the original tree
+   * @param cats matrix of categorical splits
+   * @param nidx_array node index in the array layout
+   * @param nidx node index in the original tree
+   */
+  template <int depth>
+  void Populate(const RegTree& tree, RegTree::CategoricalSplitMatrix const& cats,
+                bst_node_t nidx_array = 0, bst_node_t nidx = 0) {
+    if constexpr (depth == kNumDeepLevels + 1) {
+      return;
+    } else if constexpr (depth == kNumDeepLevels) {
+      /* We store the node index in the original tree to ensure continued processing
+       * for nodes that are not eligible for the array layout optimization.
+       */
+      nidx_in_tree_[nidx_array - kNodesCount] = nidx;
+    } else {
+      if (tree.IsLeaf(nidx)) {
+        split_index_[nidx_array] = 0;
+
+        /*
+         * If the tree is not fully populated, we can reduce transfer costs.
+         * The values for the unpopulated parts of the tree are set to ensure
+         * that any move always proceeds in the "right" direction.
+         * This is achieved by exploiting the fact that comparisons with NaN
+         * always result in false.
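+         * For example: with split_cond_ set to NaN, GetDecision() below returns false,
+         * so traversal always takes the right child (2 * nidx_array + 2), and the leaf
+         * is simply replicated down the right spine of its unpopulated sub-tree.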
+         */
+        if constexpr (any_missing) default_left_[nidx_array] = 0;
+        if constexpr (has_categorical) is_cat_[nidx_array] = 0;
+        split_cond_[nidx_array] = std::numeric_limits<float>::quiet_NaN();
+
+        Populate<depth + 1>(tree, cats, 2 * nidx_array + 2, nidx);
+      } else {
+        if constexpr (any_missing) default_left_[nidx_array] = tree.DefaultLeft(nidx);
+        if constexpr (has_categorical) {
+          is_cat_[nidx_array] = common::IsCat(cats.split_type, nidx);
+          if (is_cat_[nidx_array]) {
+            cat_segment_[nidx_array] = cats.categories.subspan(cats.node_ptr[nidx].beg,
+                                                               cats.node_ptr[nidx].size);
+          }
+        }
+
+        split_index_[nidx_array] = tree.SplitIndex(nidx);
+        split_cond_[nidx_array] = tree.SplitCond(nidx);
+
+        /*
+         * LeftChild is used to determine whether a node is a leaf, so it is always a
+         * valid value. RightChild, however, can be invalid in some exotic cases.
+         * A tree with an invalid RightChild can still be processed correctly by the
+         * classical method as long as the split conditions are correct, but in an array
+         * layout an invalid RightChild, even if unreachable, can lead to memory
+         * corruption, so it is checked here.
+         */
+        Populate<depth + 1>(tree, cats, 2 * nidx_array + 1, tree.LeftChild(nidx));
+        bst_node_t right_child = tree.RightChild(nidx);
+        if (right_child != RegTree::kInvalidNodeId) {
+          Populate<depth + 1>(tree, cats, 2 * nidx_array + 2, right_child);
+        }
+      }
+    }
+  }
+
+  bool GetDecision(float fvalue, bst_node_t nidx) const {
+    if constexpr (has_categorical) {
+      if (is_cat_[nidx]) {
+        return common::Decision(cat_segment_[nidx], fvalue);
+      } else {
+        return fvalue < split_cond_[nidx];
+      }
+    } else {
+      return fvalue < split_cond_[nidx];
+    }
+  }
+
+ public:
+  /* Ad-hoc value. Increasing it doesn't lead to a perf gain, since the bottleneck
+   * is now at the gather instructions.
+   */
+  constexpr static int kMaxNumDeepLevels = 6;
+  static_assert(kNumDeepLevels <= kMaxNumDeepLevels);
+
+  ArrayTreeLayout(const RegTree& tree, RegTree::CategoricalSplitMatrix const& cats) {
+    Populate<0>(tree, cats);
+  }
+
+  const auto& SplitIndex() const {
+    return split_index_;
+  }
+
+  const auto& SplitCond() const {
+    return split_cond_;
+  }
+
+  const auto& DefaultLeft() const {
+    return default_left_;
+  }
+
+  const auto& NidxInTree() const {
+    return nidx_in_tree_;
+  }
+
+  /**
+   * @brief Traverse the top levels of the tree for the entire block_size.
+   *
+   * The array layout guarantees that if a node at the current level has the
+   * (level-relative) index nidx, then the left child at the next level has index 2*nidx
+   * and the right child has index 2*nidx+1. This greatly improves data locality.
+   *
+   * @param fvec_tloc buffer holding the feature values
+   * @param block_size size of the current block (1 < block_size <= 64)
+   * @param p_nidx pointer to the vector of node indices, one per sample, with size equal
+   *               to the block size. On return each value holds the index (in the
+   *               original tree) of the sample's node at the level right after
+   *               kNumDeepLevels.
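+   *
+   * For example, a sample whose level-relative index is idx at depth d occupies the
+   * absolute array slot (1 << d) - 1 + idx, and moves to relative index 2 * idx when
+   * the decision is "left" or 2 * idx + 1 when it is "right".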
+   */
+  void Process(common::Span<RegTree::FVec> fvec_tloc, std::size_t const block_size,
+               bst_node_t* p_nidx) {
+    for (int depth = 0; depth < kNumDeepLevels; ++depth) {
+      std::size_t first_node = (1u << depth) - 1;
+
+      for (std::size_t i = 0; i < block_size; ++i) {
+        bst_node_t idx = p_nidx[i];
+
+        const auto& feat = fvec_tloc[i];
+        bst_feature_t split = split_index_[first_node + idx];
+        auto fvalue = feat.GetFvalue(split);
+        if constexpr (any_missing) {
+          bool go_left = feat.IsMissing(split) ? default_left_[first_node + idx]
+                                               : GetDecision(fvalue, first_node + idx);
+          p_nidx[i] = 2 * idx + !go_left;
+        } else {
+          p_nidx[i] = 2 * idx + !GetDecision(fvalue, first_node + idx);
+        }
+      }
+    }
+    // Remap to the original index.
+    for (std::size_t i = 0; i < block_size; ++i) {
+      p_nidx[i] = nidx_in_tree_[p_nidx[i]];
+    }
+  }
+};
+
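+/**
+ * @brief Run the array layout optimization over the top levels of a single tree.
+ *
+ * The recursion below is a compile-time dispatch: it instantiates ArrayTreeLayout for
+ * num_deep_levels = 1, 2, ... and picks the smallest unroll depth that covers the
+ * whole tree, falling back to kMaxNumDeepLevels for deeper trees.
+ */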
+template <bool has_categorical, bool any_missing, int num_deep_levels = 1>
+void ProcessArrayTree(const RegTree& tree, RegTree::CategoricalSplitMatrix const& cats,
+                      common::Span<RegTree::FVec> fvec_tloc, std::size_t const block_size,
+                      bst_node_t* p_nidx, int tree_depth) {
+  constexpr int kMaxNumDeepLevels =
+      ArrayTreeLayout<has_categorical, any_missing, num_deep_levels>::kMaxNumDeepLevels;
+
+  // Fill the array tree, then output the predicted node indices.
+  if constexpr (num_deep_levels == kMaxNumDeepLevels) {
+    ArrayTreeLayout<has_categorical, any_missing, num_deep_levels> buffer(tree, cats);
+    buffer.Process(fvec_tloc, block_size, p_nidx);
+  } else {
+    if (tree_depth <= num_deep_levels) {
+      ArrayTreeLayout<has_categorical, any_missing, num_deep_levels> buffer(tree, cats);
+      buffer.Process(fvec_tloc, block_size, p_nidx);
+    } else {
+      ProcessArrayTree<has_categorical, any_missing, num_deep_levels + 1>
+          (tree, cats, fvec_tloc, block_size, p_nidx, tree_depth);
+    }
+  }
+}
+
+}  // namespace xgboost::predictor
+#endif  // XGBOOST_PREDICTOR_ARRAY_TREE_LAYOUT_H_
diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc
index 93c9fffea031..b99beb88c7b4 100644
--- a/src/predictor/cpu_predictor.cc
+++ b/src/predictor/cpu_predictor.cc
@@ -23,6 +23,7 @@
 #include "../gbm/gbtree_model.h"  // for GBTreeModel, GBTreeModelParam
+#include "array_tree_layout.h"    // for ProcessArrayTree
 #include "dmlc/registry.h"        // for DMLC_REGISTRY_FILE_TAG
 #include "predict_fn.h"           // for GetNextNode, GetNextNodeMulti
 #include "treeshap.h"             // for CalculateContributions
 #include "utils.h"                // for CheckProxyDMatrix
 #include "xgboost/base.h"         // for bst_float, bst_node_t, bst_omp_uint, bst_fe...
@@ -44,8 +45,7 @@ DMLC_REGISTRY_FILE_TAG(cpu_predictor);
 namespace scalar {
 template <bool has_missing, bool has_categorical>
 bst_node_t GetLeafIndex(RegTree const &tree, const RegTree::FVec &feat,
-                        RegTree::CategoricalSplitMatrix const &cats) {
-  bst_node_t nidx{0};
+                        RegTree::CategoricalSplitMatrix const &cats, bst_node_t nidx) {
   while (!tree[nidx].IsLeaf()) {
     bst_feature_t split_index = tree[nidx].SplitIndex();
     auto fvalue = feat.GetFvalue(split_index);
@@ -57,19 +57,47 @@ bst_node_t GetLeafIndex(RegTree const &tree, const RegTree::FVec &feat,
 template <bool has_categorical>
 [[nodiscard]] float PredValueByOneTree(const RegTree::FVec &p_feats, RegTree const &tree,
-                                       RegTree::CategoricalSplitMatrix const &cats) noexcept(true) {
+                                       RegTree::CategoricalSplitMatrix const &cats,
+                                       bst_node_t nidx) noexcept(true) {
   const bst_node_t leaf = p_feats.HasMissing()
-                              ? GetLeafIndex<true, has_categorical>(tree, p_feats, cats)
-                              : GetLeafIndex<false, has_categorical>(tree, p_feats, cats);
+                              ? GetLeafIndex<true, has_categorical>(tree, p_feats, cats, nidx)
+                              : GetLeafIndex<false, has_categorical>(tree, p_feats, cats, nidx);
   return tree[leaf].LeafValue();
 }
+
+template <bool use_array_tree_layout, bool any_missing, bool has_categorical>
+void PredValueByOneTree(const RegTree& tree,
+                        std::size_t const predict_offset,
+                        common::Span<RegTree::FVec> fvec_tloc,
+                        std::size_t const block_size,
+                        linalg::MatrixView<float> out_predt,
+                        bst_node_t* p_nidx, int depth, int gid) {
+  auto const &cats = tree.GetCategoriesMatrix();
+  if constexpr (use_array_tree_layout) {
+    ProcessArrayTree<has_categorical, any_missing>(tree, cats, fvec_tloc, block_size,
+                                                   p_nidx, depth);
+  }
+  for (std::size_t i = 0; i < block_size; ++i) {
+    bst_node_t nidx = 0;
+    /*
+     * If the array tree layout was used, we start processing from the nidx calculated
+     * with the array tree.
+     */
+    if constexpr (use_array_tree_layout) {
+      nidx = p_nidx[i];
+      p_nidx[i] = 0;
+    }
+    out_predt(predict_offset + i, gid) +=
+        PredValueByOneTree<has_categorical>(fvec_tloc[i], tree, cats, nidx);
+  }
+}
 }  // namespace scalar
 
 namespace multi {
 template <bool has_missing, bool has_categorical>
 bst_node_t GetLeafIndex(MultiTargetTree const &tree, const RegTree::FVec &feat,
-                        RegTree::CategoricalSplitMatrix const &cats) {
-  bst_node_t nidx{0};
+                        RegTree::CategoricalSplitMatrix const &cats,
+                        bst_node_t nidx) {
   while (!tree.IsLeaf(nidx)) {
     bst_feature_t split_index = tree.SplitIndex(nidx);
     auto fvalue = feat.GetFvalue(split_index);
@@ -82,61 +110,114 @@ bst_node_t GetLeafIndex(MultiTargetTree const &tree, const RegTree::FVec &feat,
 template <bool has_categorical>
 void PredValueByOneTree(RegTree::FVec const &p_feats, MultiTargetTree const &tree,
                         RegTree::CategoricalSplitMatrix const &cats,
-                        linalg::VectorView<float> out_predt) {
+                        linalg::VectorView<float> out_predt, bst_node_t nidx) {
   bst_node_t const leaf = p_feats.HasMissing()
-                              ? GetLeafIndex<true, has_categorical>(tree, p_feats, cats)
-                              : GetLeafIndex<false, has_categorical>(tree, p_feats, cats);
+                              ? GetLeafIndex<true, has_categorical>(tree, p_feats, cats, nidx)
+                              : GetLeafIndex<false, has_categorical>(tree, p_feats, cats, nidx);
   auto leaf_value = tree.LeafValue(leaf);
   assert(out_predt.Shape(0) == leaf_value.Shape(0) && "shape mismatch.");
   for (size_t i = 0; i < leaf_value.Size(); ++i) {
     out_predt(i) += leaf_value(i);
   }
 }
+
+template <bool use_array_tree_layout, bool any_missing, bool has_categorical>
+void PredValueByOneTree(const RegTree &tree, std::size_t const predict_offset,
+                        common::Span<RegTree::FVec> fvec_tloc, std::size_t const block_size,
+                        linalg::MatrixView<float> out_predt, bst_node_t *p_nidx, int depth) {
+  const auto &mt_tree = *(tree.GetMultiTargetTree());
+  auto const &cats = tree.GetCategoriesMatrix();
+  if constexpr (use_array_tree_layout) {
+    ProcessArrayTree<has_categorical, any_missing>(tree, cats, fvec_tloc, block_size,
+                                                   p_nidx, depth);
+  }
+  for (std::size_t i = 0; i < block_size; ++i) {
+    bst_node_t nidx = 0;
+    if constexpr (use_array_tree_layout) {
+      nidx = p_nidx[i];
+      p_nidx[i] = 0;
+    }
+    auto t_predts = out_predt.Slice(predict_offset + i, linalg::All());
+    PredValueByOneTree<has_categorical>(fvec_tloc[i], mt_tree, cats, t_predts, nidx);
+  }
+}
 }  // namespace multi
 
 namespace {
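+// Run every tree in [tree_begin, tree_end) over a block of rows. The nidx buffer is
+// shared across the trees: each tree consumes the node indices produced by the array
+// layout pass and resets them to zero for the next tree.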
+template <bool use_array_tree_layout, bool any_missing>
 void PredictBlockByAllTrees(gbm::GBTreeModel const &model, bst_tree_t const tree_begin,
                             bst_tree_t const tree_end, std::size_t const predict_offset,
                             common::Span<RegTree::FVec> fvec_tloc, std::size_t const block_size,
-                            linalg::MatrixView<float> out_predt) {
+                            linalg::MatrixView<float> out_predt,
+                            const std::vector<int>& tree_depth) {
+  std::vector<bst_node_t> nidx;
+  if constexpr (use_array_tree_layout) {
+    nidx.resize(block_size, 0);
+  }
   for (bst_tree_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
     auto const &tree = *model.trees.at(tree_id);
-    auto const &cats = tree.GetCategoriesMatrix();
     bool has_categorical = tree.HasCategoricalSplit();
+    int depth = use_array_tree_layout ? tree_depth[tree_id - tree_begin] : 0;
     if (tree.IsMultiTarget()) {
       if (has_categorical) {
-        for (std::size_t i = 0; i < block_size; ++i) {
-          auto t_predts = out_predt.Slice(predict_offset + i, linalg::All());
-          multi::PredValueByOneTree<true>(fvec_tloc[i], *tree.GetMultiTargetTree(), cats, t_predts);
-        }
+        multi::PredValueByOneTree<use_array_tree_layout, any_missing, true>
+            (tree, predict_offset, fvec_tloc, block_size, out_predt, nidx.data(), depth);
       } else {
-        for (std::size_t i = 0; i < block_size; ++i) {
-          auto t_predts = out_predt.Slice(predict_offset + i, linalg::All());
-          multi::PredValueByOneTree<false>(fvec_tloc[i], *tree.GetMultiTargetTree(), cats,
-                                           t_predts);
-        }
+        multi::PredValueByOneTree<use_array_tree_layout, any_missing, false>
+            (tree, predict_offset, fvec_tloc, block_size, out_predt, nidx.data(), depth);
       }
     } else {
       auto const gid = model.tree_info[tree_id];
       if (has_categorical) {
-        for (std::size_t i = 0; i < block_size; ++i) {
-          out_predt(predict_offset + i, gid) +=
-              scalar::PredValueByOneTree<true>(fvec_tloc[i], tree, cats);
-        }
+        scalar::PredValueByOneTree<use_array_tree_layout, any_missing, true>
+            (tree, predict_offset, fvec_tloc, block_size, out_predt, nidx.data(), depth, gid);
       } else {
-        for (std::size_t i = 0; i < block_size; ++i) {
-          out_predt(predict_offset + i, gid) +=
-              scalar::PredValueByOneTree<false>(fvec_tloc[i], tree, cats);
-        }
+        scalar::PredValueByOneTree<use_array_tree_layout, any_missing, false>
+            (tree, predict_offset, fvec_tloc, block_size, out_predt, nidx.data(), depth, gid);
       }
     }
   }
 }
+
+// Dispatch to the proper template instantiation of PredictBlockByAllTrees.
+void DispatchArrayLayout(gbm::GBTreeModel const &model, bst_tree_t const tree_begin,
+                         bst_tree_t const tree_end, std::size_t const predict_offset,
+                         common::Span<RegTree::FVec> fvec_tloc, std::size_t const block_size,
+                         linalg::MatrixView<float> out_predt, const std::vector<int> &tree_depth,
+                         bool any_missing) {
+  /*
+   * We transform trees to the array layout for each block of data to avoid memory
+   * overheads, which makes the array layout inefficient for block_size == 1.
+   */
+  const bool use_array_tree_layout = block_size > 1;
+  if (use_array_tree_layout) {
+    // Recheck if the current block has missing values.
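+    // Even if the DMatrix as a whole contains missing values, this particular block
+    // may be dense; in that case the cheaper any_missing == false instantiation is used.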
+    if (any_missing) {
+      any_missing = false;
+      for (std::size_t i = 0; i < block_size; ++i) {
+        any_missing |= fvec_tloc[i].HasMissing();
+        if (any_missing) {
+          break;
+        }
+      }
+    }
+    if (any_missing) {
+      PredictBlockByAllTrees<true, true>(model, tree_begin, tree_end, predict_offset,
+                                         fvec_tloc, block_size, out_predt, tree_depth);
+    } else {
+      PredictBlockByAllTrees<true, false>(model, tree_begin, tree_end, predict_offset,
+                                          fvec_tloc, block_size, out_predt, tree_depth);
+    }
+  } else {
+    PredictBlockByAllTrees<false, true>(model, tree_begin, tree_end, predict_offset,
+                                        fvec_tloc, block_size, out_predt, tree_depth);
+  }
+}
 
 bool ShouldUseBlock(DMatrix *p_fmat) {
   // Threshold to use block-based prediction.
-  constexpr double kDensityThresh = .5;
+  constexpr double kDensityThresh = .125;
   bst_idx_t n_samples = p_fmat->Info().num_row_;
   bst_idx_t total = std::max(n_samples * p_fmat->Info().num_col_, static_cast<bst_idx_t>(1));
   double density = static_cast<double>(p_fmat->Info().num_nonzero_) / static_cast<double>(total);
@@ -458,18 +540,32 @@ template <typename DataView, std::size_t kBlockOfRowsSize>
 void PredictBatchByBlockKernel(DataView const &batch, gbm::GBTreeModel const &model,
                                bst_tree_t tree_begin, bst_tree_t tree_end, ThreadTmp *p_fvec,
                                std::int32_t n_threads,
+                               bool any_missing,
                                linalg::TensorView<float, 2> out_predt) {
   auto &fvec = *p_fvec;
   // Parallel over local batches
   auto const n_samples = batch.Size();
   auto const n_features = model.learner_model_param->num_feature;
+  /* Precalculate the depth of each tree.
+   * These values are required only for the array layout optimization,
+   * so we don't need them if kBlockOfRowsSize == 1.
+   */
+  std::vector<int> tree_depth;
+  if constexpr (kBlockOfRowsSize > 1) {
+    tree_depth.resize(tree_end - tree_begin);
+    common::ParallelFor(tree_end - tree_begin, n_threads, [&](auto i) {
+      bst_tree_t tree_id = tree_begin + i;
+      tree_depth[i] = model.trees.at(tree_id)->MaxDepth();
+    });
+  }
+
   common::ParallelFor1d(n_samples, n_threads, [&](auto &&block) {
     auto fvec_tloc = fvec.ThreadBuffer(block.Size());
     batch.FVecFill(block, n_features, fvec_tloc);
-    PredictBlockByAllTrees(model, tree_begin, tree_end, block.begin() + batch.base_rowid, fvec_tloc,
-                           block.Size(), out_predt);
+    DispatchArrayLayout(model, tree_begin, tree_end, block.begin() + batch.base_rowid, fvec_tloc,
+                        block.Size(), out_predt, tree_depth, any_missing);
     batch.FVecDrop(fvec_tloc);
   });
 }
@@ -802,13 +898,15 @@ class CPUPredictor : public Predictor {
     bst_idx_t n_samples = p_fmat->Info().num_row_;
     CHECK_EQ(out_preds->size(), n_samples * n_groups);
     auto out_predt = linalg::MakeTensorView(ctx_, *out_preds, n_samples, n_groups);
+    bool any_missing = !(p_fmat->IsDense());
     LaunchPredict(this->ctx_, p_fmat, model, [&](auto &&policy) {
       using Policy = common::GetValueT<decltype(policy)>;
       ThreadTmp feat_vecs{n_threads};
       policy.ForEachBatch([&](auto &&batch) {
-        PredictBatchByBlockKernel(batch, model, tree_begin, tree_end,
-                                  &feat_vecs, n_threads, out_predt);
+        PredictBatchByBlockKernel(batch, model, tree_begin, tree_end,
+                                  &feat_vecs, n_threads, any_missing,
+                                  out_predt);
       });
     });
   }
@@ -895,6 +993,7 @@ class CPUPredictor : public Predictor {
     this->InitOutPredictions(p_m->Info(), &(out_preds->predictions), model);
     auto &predictions = out_preds->predictions.HostVector();
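+    // The nnz of the input is unknown here, so conservatively assume it may contain
+    // missing values.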
+    bool any_missing = true;
     auto const n_threads = this->ctx_->Threads();
     // Always use block as we don't know the nnz.
@@ -904,7 +1003,8 @@ class CPUPredictor : public Predictor {
     auto kernel = [&](auto &&view) {
       auto out_predt = linalg::MakeTensorView(ctx_, predictions, view.Size(), n_groups);
-      PredictBatchByBlockKernel(view, model, tree_begin, tree_end,
-                                &feat_vecs, n_threads, out_predt);
+      PredictBatchByBlockKernel(view, model, tree_begin, tree_end,
+                                &feat_vecs, n_threads, any_missing,
+                                out_predt);
     };
     auto dispatch = [&](auto x) {
       using AdapterT = typename decltype(x)::element_type;
@@ -963,12 +1063,12 @@ class CPUPredictor : public Predictor {
       for (bst_tree_t j = 0; j < ntree_limit; ++j) {
         auto const &tree = *model.trees[j];
         auto const &cats = tree.GetCategoriesMatrix();
-        bst_node_t nidx;
+        bst_node_t nidx = 0;
         if (tree.IsMultiTarget()) {
-          nidx = multi::GetLeafIndex<true, true>(*tree.GetMultiTargetTree(), fvec_tloc.front(),
-                                                 cats);
+          nidx = multi::GetLeafIndex<true, true>(*tree.GetMultiTargetTree(), fvec_tloc.front(),
+                                                 cats, nidx);
         } else {
-          nidx = scalar::GetLeafIndex<true, true>(tree, fvec_tloc.front(), cats);
+          nidx = scalar::GetLeafIndex<true, true>(tree, fvec_tloc.front(), cats, nidx);
         }
         preds[ridx * ntree_limit + j] = static_cast<bst_float>(nidx);
       }
diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc
index de9309c358af..bc9df71c7641 100644
--- a/tests/cpp/predictor/test_cpu_predictor.cc
+++ b/tests/cpp/predictor/test_cpu_predictor.cc
@@ -7,6 +7,7 @@
 #include "../../../src/collective/communicator-inl.h"
 #include "../../../src/data/adapter.h"
 #include "../../../src/data/proxy_dmatrix.h"
 #include "../../../src/gbm/gbtree.h"
 #include "../../../src/gbm/gbtree_model.h"
+#include "../../../src/predictor/array_tree_layout.h"
 #include "../collective/test_worker.h"  // for TestDistributedGlobal
@@ -22,6 +23,80 @@ TEST(CpuPredictor, Basic) {
   TestBasic(dmat.get(), &ctx);
 }
+
+template <typename ArrayLayoutT>
+void CheckArrayLayout(const RegTree& tree, const ArrayLayoutT& buffer, int max_depth, int depth,
+                      size_t nid, size_t nid_array) {
+  const auto& split_idx = buffer.SplitIndex();
+  const auto& split_cond = buffer.SplitCond();
+  const auto& default_left = buffer.DefaultLeft();
+  const auto& nidx_in_tree = buffer.NidxInTree();
+  const auto& nodes = tree.GetNodes();
+
+  if (depth == max_depth) {
+    // nid_array - ((1u << max_depth) - 1) is the offset into the sub-tree mapping.
+    ASSERT_EQ(nidx_in_tree[nid_array - (1u << max_depth) + 1], nid);
+    return;
+  }
+
+  if (nodes[nid].IsLeaf()) {
+    ASSERT_EQ(default_left[nid_array], 0);
+    ASSERT_TRUE(std::isnan(split_cond[nid_array]));
+
+    // A leaf is replicated down the right spine, see ArrayTreeLayout::Populate.
+    CheckArrayLayout(tree, buffer, max_depth, depth + 1, nid, 2 * nid_array + 2);
+  } else {
+    ASSERT_EQ(nodes[nid].SplitIndex(), split_idx[nid_array]);
+    ASSERT_EQ(nodes[nid].SplitCond(), split_cond[nid_array]);
+    ASSERT_EQ(nodes[nid].DefaultLeft(), default_left[nid_array]);
+
+    if (nodes[nid].LeftChild() != RegTree::kInvalidNodeId) {
+      CheckArrayLayout(tree, buffer, max_depth, depth + 1, nodes[nid].LeftChild(),
+                       2 * nid_array + 1);
+    }
+    if (nodes[nid].RightChild() != RegTree::kInvalidNodeId) {
+      CheckArrayLayout(tree, buffer, max_depth, depth + 1, nodes[nid].RightChild(),
+                       2 * nid_array + 2);
+    }
+  }
+}
+
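+// Build a small complete tree with distinct placeholder splits and verify that
+// ArrayTreeLayout reproduces the split indices, conditions, default directions, and
+// the sub-tree remapping for a range of unroll depths.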
+TEST(CpuPredictor, ArrayTreeLayout) {
+  Context ctx;
+
+  RegTree tree;
+  size_t n_nodes = 15;  // 2^4 - 1
+  for (size_t nid = 0; nid < n_nodes; ++nid) {
+    // Some place-holders
+    size_t split_index = nid + 1;
+    bst_float split_cond = nid + 2;
+    bool default_left = nid % 2 == 0;
+
+    tree.ExpandNode(nid, split_index, split_cond, default_left, 0, 0, 0, 0, 0, 0, 0);
+  }
+
+  {
+    constexpr int kDepth = 1;
+    predictor::ArrayTreeLayout<false, true, kDepth> buffer(tree, tree.GetCategoriesMatrix());
+    CheckArrayLayout(tree, buffer, kDepth, 0, 0, 0);
+  }
+  {
+    constexpr int kDepth = 2;
+    predictor::ArrayTreeLayout<false, true, kDepth> buffer(tree, tree.GetCategoriesMatrix());
+    CheckArrayLayout(tree, buffer, kDepth, 0, 0, 0);
+  }
+  {
+    constexpr int kDepth = 3;
+    predictor::ArrayTreeLayout<false, true, kDepth> buffer(tree, tree.GetCategoriesMatrix());
+    CheckArrayLayout(tree, buffer, kDepth, 0, 0, 0);
+  }
+  {
+    constexpr int kDepth = 4;
+    predictor::ArrayTreeLayout<false, true, kDepth> buffer(tree, tree.GetCategoriesMatrix());
+    CheckArrayLayout(tree, buffer, kDepth, 0, 0, 0);
+  }
+  {
+    constexpr int kDepth = 5;
+    predictor::ArrayTreeLayout<false, true, kDepth> buffer(tree, tree.GetCategoriesMatrix());
+    CheckArrayLayout(tree, buffer, kDepth, 0, 0, 0);
+  }
+}
+
 namespace {
 void TestColumnSplit() {
   Context ctx;
diff --git a/tests/python-sycl/test_sycl_training_continuation.py b/tests/python-sycl/test_sycl_training_continuation.py
index e2a11c987bb4..71d5965600e7 100644
--- a/tests/python-sycl/test_sycl_training_continuation.py
+++ b/tests/python-sycl/test_sycl_training_continuation.py
@@ -9,8 +9,8 @@ class TestSYCLTrainingContinuation:
     def run_training_continuation(self, use_json):
         kRows = 64
         kCols = 32
-        X = np.random.randn(kRows, kCols)
-        y = np.random.randn(kRows)
+        X = rng.randn(kRows, kCols)
+        y = rng.randn(kRows)
         dtrain = xgb.DMatrix(X, y)
         params = {
             "device": "sycl",