5151
5252namespace facebook ::velox::cudf_velox {
5353
54+ namespace {
55+
56+ // / Creates extended table view by appending precomputed columns
57+ cudf::table_view createExtendedTableView (
58+ cudf::table_view originalView,
59+ std::vector<ColumnOrView>& precomputedColumns) {
60+ if (precomputedColumns.empty ()) {
61+ return originalView;
62+ }
63+
64+ std::vector<cudf::column_view> allViews;
65+ allViews.reserve (originalView.num_columns () + precomputedColumns.size ());
66+
67+ for (cudf::size_type i = 0 ; i < originalView.num_columns (); ++i) {
68+ allViews.push_back (originalView.column (i));
69+ }
70+ for (auto & col : precomputedColumns) {
71+ allViews.push_back (asView (col));
72+ }
73+
74+ return cudf::table_view (allViews);
75+ }
76+
77+ } // namespace
78+
5479void CudfHashJoinProbe::close () {
5580 Operator::close ();
5681 filterEvaluator_.reset ();
@@ -221,10 +246,10 @@ void CudfHashJoinBuild::noMoreInput() {
221246 }
222247 }
223248
224- auto buildType = joinNode_->sources ()[1 ]->outputType ();
225249 auto rightKeys = joinNode_->rightKeys ();
226250
227251 auto buildKeyIndices = std::vector<cudf::size_type>(rightKeys.size ());
252+ auto buildType = joinNode_->sources ()[1 ]->outputType ();
228253 for (size_t i = 0 ; i < buildKeyIndices.size (); i++) {
229254 buildKeyIndices[i] = static_cast <cudf::size_type>(
230255 buildType->getChildIdx (rightKeys[i]->name ()));
@@ -301,22 +326,22 @@ CudfHashJoinProbe::CudfHashJoinProbe(
301326 operatorId,
302327 fmt::format (" [{}]" , joinNode->id ())),
303328 joinNode_(joinNode),
329+ probeType_(joinNode_->sources ()[0]->outputType()),
330+ buildType_(joinNode_->sources ()[1]->outputType()),
304331 cudaEvent_(std::make_unique<CudaEvent>(cudaEventDisableTiming)) {
305332 if (CudfConfig::getInstance ().debugEnabled ) {
306333 VLOG (2 ) << " CudfHashJoinProbe constructor" ;
307334 }
308- auto probeType = joinNode_->sources ()[0 ]->outputType ();
309- auto buildType = joinNode_->sources ()[1 ]->outputType ();
310335 auto const & leftKeys = joinNode_->leftKeys (); // probe keys
311336 auto const & rightKeys = joinNode_->rightKeys (); // build keys
312337
313338 if (CudfConfig::getInstance ().debugEnabled ) {
314- for (int i = 0 ; i < probeType ->names ().size (); i++) {
315- VLOG (1 ) << " Left column " << i << " : " << probeType ->names ()[i];
339+ for (int i = 0 ; i < probeType_ ->names ().size (); i++) {
340+ VLOG (1 ) << " Left column " << i << " : " << probeType_ ->names ()[i];
316341 }
317342
318- for (int i = 0 ; i < buildType ->names ().size (); i++) {
319- VLOG (1 ) << " Right column " << i << " : " << buildType ->names ()[i];
343+ for (int i = 0 ; i < buildType_ ->names ().size (); i++) {
344+ VLOG (1 ) << " Right column " << i << " : " << buildType_ ->names ()[i];
320345 }
321346
322347 for (int i = 0 ; i < leftKeys.size (); i++) {
@@ -330,18 +355,18 @@ CudfHashJoinProbe::CudfHashJoinProbe(
330355 }
331356 }
332357
333- auto const probeTableNumColumns = probeType ->size ();
358+ auto const probeTableNumColumns = probeType_ ->size ();
334359 leftKeyIndices_ = std::vector<cudf::size_type>(leftKeys.size ());
335360 for (size_t i = 0 ; i < leftKeyIndices_.size (); i++) {
336361 leftKeyIndices_[i] = static_cast <cudf::size_type>(
337- probeType ->getChildIdx (leftKeys[i]->name ()));
362+ probeType_ ->getChildIdx (leftKeys[i]->name ()));
338363 VELOX_CHECK_LT (leftKeyIndices_[i], probeTableNumColumns);
339364 }
340- auto const buildTableNumColumns = buildType ->size ();
365+ auto const buildTableNumColumns = buildType_ ->size ();
341366 rightKeyIndices_ = std::vector<cudf::size_type>(rightKeys.size ());
342367 for (size_t i = 0 ; i < rightKeyIndices_.size (); i++) {
343368 rightKeyIndices_[i] = static_cast <cudf::size_type>(
344- buildType ->getChildIdx (rightKeys[i]->name ()));
369+ buildType_ ->getChildIdx (rightKeys[i]->name ()));
345370 VELOX_CHECK_LT (rightKeyIndices_[i], buildTableNumColumns);
346371 }
347372
@@ -355,14 +380,14 @@ CudfHashJoinProbe::CudfHashJoinProbe(
355380 if (CudfConfig::getInstance ().debugEnabled ) {
356381 VLOG (1 ) << " Output column " << i << " : " << outputName;
357382 }
358- auto channel = probeType ->getChildIdxIfExists (outputName);
383+ auto channel = probeType_ ->getChildIdxIfExists (outputName);
359384 if (channel.has_value ()) {
360385 leftColumnIndicesToGather_.push_back (
361386 static_cast <cudf::size_type>(channel.value ()));
362387 leftColumnOutputIndices_.push_back (i);
363388 continue ;
364389 }
365- channel = buildType ->getChildIdxIfExists (outputName);
390+ channel = buildType_ ->getChildIdxIfExists (outputName);
366391 if (channel.has_value ()) {
367392 rightColumnIndicesToGather_.push_back (
368393 static_cast <cudf::size_type>(channel.value ()));
@@ -394,7 +419,7 @@ CudfHashJoinProbe::CudfHashJoinProbe(
394419 // Create a reusable evaluator for the filter column. This is expensive to
395420 // build, and the expression + input schema are stable for the lifetime of
396421 // the operator instance.
397- std::vector<velox::RowTypePtr> filterRowTypes{probeType, buildType };
422+ std::vector<velox::RowTypePtr> filterRowTypes{probeType_, buildType_ };
398423 filterEvaluator_ = createCudfExpression (
399424 exprs.exprs ()[0 ],
400425 facebook::velox::type::concatRowTypes (filterRowTypes));
@@ -406,33 +431,24 @@ CudfHashJoinProbe::CudfHashJoinProbe(
406431 // in whole tables
407432
408433 // create ast tree
409- std::vector<PrecomputeInstruction> rightPrecomputeInstructions;
410- std::vector<PrecomputeInstruction> leftPrecomputeInstructions;
411- static constexpr bool kAllowPureAstOnly = true ;
412434 if (joinNode_->isRightJoin () || joinNode_->isRightSemiFilterJoin ()) {
413435 createAstTree (
414436 exprs.exprs ()[0 ],
415437 tree_,
416438 scalars_,
417- buildType,
418- probeType,
419- rightPrecomputeInstructions,
420- leftPrecomputeInstructions,
421- kAllowPureAstOnly );
439+ buildType_,
440+ probeType_,
441+ rightPrecomputeInstructions_,
442+ leftPrecomputeInstructions_);
422443 } else {
423444 createAstTree (
424445 exprs.exprs ()[0 ],
425446 tree_,
426447 scalars_,
427- probeType,
428- buildType,
429- leftPrecomputeInstructions,
430- rightPrecomputeInstructions,
431- kAllowPureAstOnly );
432- }
433- if (leftPrecomputeInstructions.size () > 0 ||
434- rightPrecomputeInstructions.size () > 0 ) {
435- VELOX_NYI (" Filters that require precomputation are not yet supported" );
448+ probeType_,
449+ buildType_,
450+ leftPrecomputeInstructions_,
451+ rightPrecomputeInstructions_);
436452 }
437453 }
438454}
@@ -632,8 +648,13 @@ std::unique_ptr<cudf::table> CudfHashJoinProbe::filteredOutput(
632648 VELOX_CHECK_NOT_NULL (
633649 filterEvaluator_,
634650 " Join filter evaluator must be initialized before filteredOutput()" );
651+ std::vector<cudf::column_view> joinedColViews;
652+ joinedColViews.reserve (joinedCols.size ());
653+ for (const auto & col : joinedCols) {
654+ joinedColViews.push_back (col->view ());
655+ }
635656 auto filterColumns = filterEvaluator_->eval (
636- joinedCols , stream, cudf::get_current_device_resource_ref ());
657+ joinedColViews , stream, cudf::get_current_device_resource_ref ());
637658 auto filterColumn = asView (filterColumns);
638659
639660 joinedCols = func (std::move (joinedCols), filterColumn);
@@ -662,12 +683,15 @@ std::unique_ptr<cudf::table> CudfHashJoinProbe::filteredOutputIndices(
662683 cudf::column_view leftIndicesCol,
663684 cudf::table_view rightTableView,
664685 cudf::column_view rightIndicesCol,
686+ cudf::table_view extendedLeftView,
687+ cudf::table_view extendedRightView,
665688 cudf::join_kind joinKind,
666689 rmm::cuda_stream_view stream) {
690+ // Use extended views (with precomputed columns) for filter evaluation
667691 auto [filteredLeftJoinIndices, filteredRightJoinIndices] =
668692 cudf::filter_join_indices (
669- leftTableView ,
670- rightTableView ,
693+ extendedLeftView ,
694+ extendedRightView ,
671695 leftIndicesCol,
672696 rightIndicesCol,
673697 tree_.back (),
@@ -680,6 +704,7 @@ std::unique_ptr<cudf::table> CudfHashJoinProbe::filteredOutputIndices(
680704 cudf::device_span<cudf::size_type const >{*filteredRightJoinIndices};
681705 auto filteredLeftIndicesCol = cudf::column_view{filteredLeftIndicesSpan};
682706 auto filteredRightIndicesCol = cudf::column_view{filteredRightIndicesSpan};
707+ // Use original views (without precomputed columns) for gathering output
683708 return unfilteredOutput (
684709 leftTableView,
685710 filteredLeftIndicesCol,
@@ -695,10 +720,31 @@ std::vector<std::unique_ptr<cudf::table>> CudfHashJoinProbe::innerJoin(
695720
696721 auto & rightTables = hashObject_.value ().first ;
697722 auto & hbs = hashObject_.value ().second ;
723+
724+ // Precompute left (probe) table columns if needed (once, outside loop)
725+ std::vector<ColumnOrView> leftPrecomputed;
726+ cudf::table_view extendedLeftView = leftTableView;
727+ if (joinNode_->filter () && !leftPrecomputeInstructions_.empty ()) {
728+ auto leftColumnViews = tableViewToColumnViews (leftTableView);
729+ leftPrecomputed = precomputeSubexpressions (
730+ leftColumnViews,
731+ leftPrecomputeInstructions_,
732+ scalars_,
733+ probeType_,
734+ stream);
735+ extendedLeftView = createExtendedTableView (leftTableView, leftPrecomputed);
736+ }
737+
698738 for (auto i = 0 ; i < rightTables.size (); i++) {
699739 auto rightTableView = rightTables[i]->view ();
700740 auto & hb = hbs[i];
701741
742+ // Use cached precomputed columns for right (build) table
743+ cudf::table_view extendedRightView =
744+ (joinNode_->filter () && !rightPrecomputeInstructions_.empty ())
745+ ? cachedExtendedRightViews_[i]
746+ : rightTableView;
747+
702748 // left = probe, right = build
703749 VELOX_CHECK_NOT_NULL (hb);
704750 if (buildStream_.has_value ()) {
@@ -728,6 +774,8 @@ std::vector<std::unique_ptr<cudf::table>> CudfHashJoinProbe::innerJoin(
728774 leftIndicesCol,
729775 rightTableView,
730776 rightIndicesCol,
777+ extendedLeftView,
778+ extendedRightView,
731779 cudf::join_kind::INNER_JOIN,
732780 stream));
733781 } else {
@@ -749,10 +797,31 @@ std::vector<std::unique_ptr<cudf::table>> CudfHashJoinProbe::leftJoin(
749797
750798 auto & rightTables = hashObject_.value ().first ;
751799 auto & hbs = hashObject_.value ().second ;
800+
801+ // Precompute left (probe) table columns if needed (once, outside loop)
802+ std::vector<ColumnOrView> leftPrecomputed;
803+ cudf::table_view extendedLeftView = leftTableView;
804+ if (joinNode_->filter () && !leftPrecomputeInstructions_.empty ()) {
805+ auto leftColumnViews = tableViewToColumnViews (leftTableView);
806+ leftPrecomputed = precomputeSubexpressions (
807+ leftColumnViews,
808+ leftPrecomputeInstructions_,
809+ scalars_,
810+ probeType_,
811+ stream);
812+ extendedLeftView = createExtendedTableView (leftTableView, leftPrecomputed);
813+ }
814+
752815 for (auto i = 0 ; i < rightTables.size (); i++) {
753816 auto rightTableView = rightTables[i]->view ();
754817 auto & hb = hbs[i];
755818
819+ // Use cached precomputed columns for right (build) table
820+ cudf::table_view extendedRightView =
821+ (joinNode_->filter () && !rightPrecomputeInstructions_.empty ())
822+ ? cachedExtendedRightViews_[i]
823+ : rightTableView;
824+
756825 VELOX_CHECK_NOT_NULL (hb);
757826 if (buildStream_.has_value ()) {
758827 cudaEvent_->recordFrom (stream).waitOn (buildStream_.value ());
@@ -779,6 +848,8 @@ std::vector<std::unique_ptr<cudf::table>> CudfHashJoinProbe::leftJoin(
779848 leftIndicesCol,
780849 rightTableView,
781850 rightIndicesCol,
851+ extendedLeftView,
852+ extendedRightView,
782853 cudf::join_kind::LEFT_JOIN,
783854 stream));
784855 } else {
@@ -1271,12 +1342,11 @@ RowVectorPtr CudfHashJoinProbe::getOutput() {
12711342 std::vector<std::unique_ptr<cudf::column>> outCols (outputType_->size ());
12721343 // Left side nulls (types derive from probe schema at the matching
12731344 // channel indices)
1274- auto probeType = joinNode_->sources ()[0 ]->outputType ();
12751345 for (size_t li = 0 ; li < leftColumnOutputIndices_.size (); ++li) {
12761346 auto outIdx = leftColumnOutputIndices_[li];
12771347 auto probeChannel = leftColumnIndicesToGather_[li];
12781348 auto leftCudfType =
1279- veloxToCudfTypeId (probeType ->childAt (probeChannel));
1349+ veloxToCudfTypeId (probeType_ ->childAt (probeChannel));
12801350 auto nullScalar = cudf::make_default_constructed_scalar (
12811351 cudf::data_type{leftCudfType});
12821352 outCols[outIdx] = cudf::make_column_from_scalar (
@@ -1445,6 +1515,33 @@ exec::BlockingReason CudfHashJoinProbe::isBlocked(ContinueFuture* future) {
14451515 }
14461516 initStream.synchronize ();
14471517 }
1518+
1519+ // Precompute right table columns if filter exists (once when build is done)
1520+ if (joinNode_->filter () && !rightPrecomputeInstructions_.empty ()) {
1521+ auto & rightTablesInit = hashObject_.value ().first ;
1522+ cachedRightPrecomputed_.clear ();
1523+ cachedExtendedRightViews_.clear ();
1524+ cachedRightPrecomputed_.reserve (rightTablesInit.size ());
1525+ cachedExtendedRightViews_.reserve (rightTablesInit.size ());
1526+
1527+ auto initStream = cudfGlobalStreamPool ().get_stream ();
1528+ for (auto & rt : rightTablesInit) {
1529+ auto rightTableView = rt->view ();
1530+ auto rightColumnViews = tableViewToColumnViews (rightTableView);
1531+ auto rightPrecomputed = precomputeSubexpressions (
1532+ rightColumnViews,
1533+ rightPrecomputeInstructions_,
1534+ scalars_,
1535+ buildType_,
1536+ initStream);
1537+ auto extendedView =
1538+ createExtendedTableView (rightTableView, rightPrecomputed);
1539+ cachedRightPrecomputed_.push_back (std::move (rightPrecomputed));
1540+ cachedExtendedRightViews_.push_back (extendedView);
1541+ }
1542+ initStream.synchronize ();
1543+ }
1544+
14481545 auto & rightTables = hashObject_.value ().first ;
14491546 // should be rightTable->numDistinct() but it needs compute,
14501547 // so we use num_rows()
0 commit comments