@@ -113,9 +113,10 @@ void MergeJoin::initialize() {
113113 isSemiFilterJoin (joinType_)) {
114114 joinTracker_ = JoinTracker (outputBatchSize_, pool ());
115115 }
116- } else if (joinNode_->isAntiJoin ()) {
116+ } else if (joinNode_->isAntiJoin () || joinNode_-> isFullJoin () ) {
117117 // Anti join needs to track the left side rows that have no match on the
118- // right.
118+ // right. Full outer join needs to track the right side rows that have no
119+ // match on the left.
119120 joinTracker_ = JoinTracker (outputBatchSize_, pool ());
120121 }
121122
@@ -392,7 +393,8 @@ bool MergeJoin::tryAddOutputRow(
392393 const RowVectorPtr& leftBatch,
393394 vector_size_t leftRow,
394395 const RowVectorPtr& rightBatch,
395- vector_size_t rightRow) {
396+ vector_size_t rightRow,
397+ bool isRightJoinForFullOuter) {
396398 if (outputSize_ == outputBatchSize_) {
397399 return false ;
398400 }
@@ -426,12 +428,15 @@ bool MergeJoin::tryAddOutputRow(
426428 filterRightInputProjections_);
427429
428430 if (joinTracker_) {
429- if (isRightJoin (joinType_) || isRightSemiFilterJoin (joinType_)) {
431+ if (isRightJoin (joinType_) || isRightSemiFilterJoin (joinType_) ||
432+ (isFullJoin (joinType_) && isRightJoinForFullOuter)) {
430433 // Record right-side row with a match on the left-side.
431- joinTracker_->addMatch (rightBatch, rightRow, outputSize_);
434+ joinTracker_->addMatch (
435+ rightBatch, rightRow, outputSize_, isRightJoinForFullOuter);
432436 } else {
433437 // Record left-side row with a match on the right-side.
434- joinTracker_->addMatch (leftBatch, leftRow, outputSize_);
438+ joinTracker_->addMatch (
439+ leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
435440 }
436441 }
437442 }
@@ -441,7 +446,8 @@ bool MergeJoin::tryAddOutputRow(
441446 if (isAntiJoin (joinType_)) {
442447 VELOX_CHECK (joinTracker_.has_value ());
443448 // Record left-side row with a match on the right-side.
444- joinTracker_->addMatch (leftBatch, leftRow, outputSize_);
449+ joinTracker_->addMatch (
450+ leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
445451 }
446452
447453 ++outputSize_;
@@ -460,14 +466,15 @@ bool MergeJoin::prepareOutput(
460466 return true ;
461467 }
462468
463- if (isRightJoin (joinType_) && right != currentRight_) {
464- return true ;
465- }
466-
467469 // If there is a new right, we need to flatten the dictionary.
468470 if (!isRightFlattened_ && right && currentRight_ != right) {
469471 flattenRightProjections ();
470472 }
473+
474+ if (right != currentRight_) {
475+ return true ;
476+ }
477+
471478 return false ;
472479 }
473480
@@ -490,11 +497,10 @@ bool MergeJoin::prepareOutput(
490497 }
491498 } else {
492499 for (const auto & projection : leftProjections_) {
500+ auto column = left->childAt (projection.inputChannel );
501+ column->clearContainingLazyAndWrapped ();
493502 localColumns[projection.outputChannel ] = BaseVector::wrapInDictionary (
494- {},
495- leftOutputIndices_,
496- outputBatchSize_,
497- left->childAt (projection.inputChannel ));
503+ {}, leftOutputIndices_, outputBatchSize_, column);
498504 }
499505 }
500506 currentLeft_ = left;
@@ -510,11 +516,10 @@ bool MergeJoin::prepareOutput(
510516 isRightFlattened_ = true ;
511517 } else {
512518 for (const auto & projection : rightProjections_) {
519+ auto column = right->childAt (projection.inputChannel );
520+ column->clearContainingLazyAndWrapped ();
513521 localColumns[projection.outputChannel ] = BaseVector::wrapInDictionary (
514- {},
515- rightOutputIndices_,
516- outputBatchSize_,
517- right->childAt (projection.inputChannel ));
522+ {}, rightOutputIndices_, outputBatchSize_, column);
518523 }
519524 isRightFlattened_ = false ;
520525 }
@@ -579,6 +584,39 @@ bool MergeJoin::prepareOutput(
579584bool MergeJoin::addToOutput () {
580585 if (isRightJoin (joinType_) || isRightSemiFilterJoin (joinType_)) {
581586 return addToOutputForRightJoin ();
587+ } else if (isFullJoin (joinType_) && filter_) {
588+ if (!leftForRightJoinMatch_) {
589+ leftForRightJoinMatch_ = leftMatch_;
590+ rightForRightJoinMatch_ = rightMatch_;
591+ }
592+
593+ if (leftMatch_ && rightMatch_ && !leftJoinForFullFinished_) {
594+ auto left = addToOutputForLeftJoin ();
595+ if (!leftMatch_) {
596+ leftJoinForFullFinished_ = true ;
597+ }
598+ if (left) {
599+ if (!leftMatch_) {
600+ leftMatch_ = leftForRightJoinMatch_;
601+ rightMatch_ = rightForRightJoinMatch_;
602+ }
603+
604+ return true ;
605+ }
606+ }
607+
608+ if (!leftMatch_ && !rightJoinForFullFinished_) {
609+ leftMatch_ = leftForRightJoinMatch_;
610+ rightMatch_ = rightForRightJoinMatch_;
611+ rightJoinForFullFinished_ = true ;
612+ }
613+
614+ auto right = addToOutputForRightJoin ();
615+
616+ leftForRightJoinMatch_ = leftMatch_;
617+ rightForRightJoinMatch_ = rightMatch_;
618+
619+ return right;
582620 } else {
583621 return addToOutputForLeftJoin ();
584622 }
@@ -727,7 +765,9 @@ bool MergeJoin::addToOutputForRightJoin() {
727765 }
728766
729767 for (auto j = leftStartRow; j < leftEndRow; ++j) {
730- if (!tryAddOutputRow (leftBatch, j, rightBatch, i)) {
768+ const auto isRightJoinForFullOuter = isFullJoin (joinType_);
769+ if (!tryAddOutputRow (
770+ leftBatch, j, rightBatch, i, isRightJoinForFullOuter)) {
731771 // If we run out of space in the current output_, we will need to
732772 // produce a buffer and continue processing left later. In this
733773 // case, we cannot leave left as a lazy vector, since we cannot have
@@ -1141,7 +1181,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
11411181 isFullJoin (joinType_)) {
11421182 // If output_ is currently wrapping a different buffer, return it
11431183 // first.
1144- if (prepareOutput (input_, nullptr )) {
1184+ if (prepareOutput (input_, rightInput_ )) {
11451185 output_->resize (outputSize_);
11461186 return std::move (output_);
11471187 }
@@ -1166,7 +1206,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
11661206 if (isRightJoin (joinType_) || isFullJoin (joinType_)) {
11671207 // If output_ is currently wrapping a different buffer, return it
11681208 // first.
1169- if (prepareOutput (nullptr , rightInput_)) {
1209+ if (prepareOutput (input_ , rightInput_)) {
11701210 output_->resize (outputSize_);
11711211 return std::move (output_);
11721212 }
@@ -1218,6 +1258,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
12181258 endRightRow < rightInput_->size (),
12191259 std::nullopt };
12201260
1261+ leftJoinForFullFinished_ = false ;
1262+ rightJoinForFullFinished_ = false ;
12211263 if (!leftMatch_->complete || !rightMatch_->complete ) {
12221264 if (!leftMatch_->complete ) {
12231265 // Need to continue looking for the end of match.
@@ -1262,8 +1304,6 @@ RowVectorPtr MergeJoin::doGetOutput() {
12621304RowVectorPtr MergeJoin::applyFilter (const RowVectorPtr& output) {
12631305 const auto numRows = output->size ();
12641306
1265- RowVectorPtr fullOuterOutput = nullptr ;
1266-
12671307 BufferPtr indices = allocateIndices (numRows, pool ());
12681308 auto * rawIndices = indices->asMutable <vector_size_t >();
12691309 vector_size_t numPassed = 0 ;
@@ -1280,84 +1320,41 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
12801320
12811321 // If all matches for a given left-side row fail the filter, add a row to
12821322 // the output with nulls for the right-side columns.
1283- const auto onMiss = [&](auto row) {
1323+ const auto onMiss = [&](auto row, bool isRightJoinForFullOuter ) {
12841324 if (isSemiFilterJoin (joinType_)) {
12851325 return ;
12861326 }
12871327 rawIndices[numPassed++] = row;
12881328
1289- if (isFullJoin (joinType_)) {
1290- // For filtered rows, it is necessary to insert additional data
1291- // to ensure the result set is complete. Specifically, we
1292- // need to generate two records: one record containing the
1293- // columns from the left table along with nulls for the
1294- // right table, and another record containing the columns
1295- // from the right table along with nulls for the left table.
1296- // For instance, the current output is filtered based on the condition
1297- // t > 1.
1298-
1299- // 1, 1
1300- // 2, 2
1301- // 3, 3
1302-
1303- // In this scenario, we need to additionally insert a record 1, 1.
1304- // Subsequently, we will set the values of the columns on the left to
1305- // null and the values of the columns on the right to null as well. By
1306- // doing so, we will obtain the final result set.
1307-
1308- // 1, null
1309- // null, 1
1310- // 2, 2
1311- // 3, 3
1312- fullOuterOutput = BaseVector::create<RowVector>(
1313- output->type (), output->size () + 1 , pool ());
1314-
1315- for (auto i = 0 ; i < row + 1 ; ++i) {
1316- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1317- fullOuterOutput->childAt (j)->copy (
1318- output->childAt (j).get (), i, i, 1 );
1329+ if (!isRightJoin (joinType_)) {
1330+ if (isFullJoin (joinType_) && isRightJoinForFullOuter) {
1331+ for (auto & projection : leftProjections_) {
1332+ auto target = output->childAt (projection.outputChannel );
1333+ target->setNull (row, true );
13191334 }
1320- }
1321-
1322- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1323- fullOuterOutput->childAt (j)->copy (
1324- output->childAt (j).get (), row + 1 , row, 1 );
1325- }
1326-
1327- for (auto i = row + 1 ; i < output->size (); ++i) {
1328- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1329- fullOuterOutput->childAt (j)->copy (
1330- output->childAt (j).get (), i + 1 , i, 1 );
1335+ } else {
1336+ for (auto & projection : rightProjections_) {
1337+ auto target = output->childAt (projection.outputChannel );
1338+ target->setNull (row, true );
13311339 }
13321340 }
1333-
1334- for (auto & projection : leftProjections_) {
1335- auto & target = fullOuterOutput->childAt (projection.outputChannel );
1336- target->setNull (row, true );
1337- }
1338-
1339- for (auto & projection : rightProjections_) {
1340- auto & target = fullOuterOutput->childAt (projection.outputChannel );
1341- target->setNull (row + 1 , true );
1342- }
1343- } else if (!isRightJoin (joinType_)) {
1344- for (auto & projection : rightProjections_) {
1345- auto & target = output->childAt (projection.outputChannel );
1346- target->setNull (row, true );
1347- }
13481341 } else {
13491342 for (auto & projection : leftProjections_) {
1350- auto & target = output->childAt (projection.outputChannel );
1343+ auto target = output->childAt (projection.outputChannel );
13511344 target->setNull (row, true );
13521345 }
13531346 }
13541347 };
13551348
13561349 auto onMatch = [&](auto row, bool firstMatch) {
1357- const bool isNonSemiAntiJoin =
1358- !isSemiFilterJoin (joinType_) && !isAntiJoin (joinType_);
1350+ const bool isFullLeftJoin =
1351+ isFullJoin (joinType_) && !joinTracker_->isRightJoinForFullOuter (row);
1352+
1353+ const bool isNonSemiAntiFullJoin = !isSemiFilterJoin (joinType_) &&
1354+ !isAntiJoin (joinType_) && !isFullJoin (joinType_);
13591355
1360- if ((isSemiFilterJoin (joinType_) && firstMatch) || isNonSemiAntiJoin) {
1356+ if ((isSemiFilterJoin (joinType_) && firstMatch) ||
1357+ isNonSemiAntiFullJoin || isFullLeftJoin) {
13611358 rawIndices[numPassed++] = row;
13621359 }
13631360 };
@@ -1418,17 +1415,10 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
14181415
14191416 if (numPassed == numRows) {
14201417 // All rows passed.
1421- if (fullOuterOutput) {
1422- return fullOuterOutput;
1423- }
14241418 return output;
14251419 }
14261420
14271421 // Some, but not all rows passed.
1428- if (fullOuterOutput) {
1429- return wrap (numPassed, indices, fullOuterOutput);
1430- }
1431-
14321422 return wrap (numPassed, indices, output);
14331423}
14341424
0 commit comments