Skip to content

Commit a20ebf0

Browse files
authored
all join kind tests for dq scalar join (#26354)
1 parent cf85134 commit a20ebf0

File tree

6 files changed

+265
-67
lines changed

6 files changed

+265
-67
lines changed

ydb/library/yql/dq/comp_nodes/dq_hash_join_table.cpp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,29 @@
22

33
namespace NKikimr::NMiniKQL::NJoinTable {
44
// Diff view of NeedToTrackUnusedRightTuples(EJoinKind).
// The line under the "5-" marker is the removed implementation: it tested the
// bit with value 4 of the enum's integer representation.
// The "+" lines are its replacement: an explicit switch that returns true for
// Exclusion, Full, Right, RightOnly and RightSemi, and false for any other kind.
// NOTE(review): whether the old bit test and the new list agree for every
// enumerator depends on the numeric values of EJoinKind, which are not shown here.
bool NeedToTrackUnusedRightTuples(EJoinKind kind) {
5-
return (static_cast<int>(kind)&4) == 4;
5+
switch (kind) {
6+
using enum NKikimr::NMiniKQL::EJoinKind;
7+
case Exclusion:
8+
case Full:
9+
case Right:
10+
case RightOnly:
11+
case RightSemi:
12+
return true;
13+
default:
14+
return false;
15+
}
616
}
717
// Diff view of NeedToTrackUnusedLeftTuples(EJoinKind).
// The line under the "8-" marker is the removed implementation. Because `==`
// binds tighter than `&`, it parses as `static_cast<int>(kind) & (1 == 1)`,
// i.e. a test of the lowest bit of the enum's integer value.
// The "+" lines replace it with an explicit switch that returns true for
// Exclusion, Full, Left and LeftOnly, and false for any other kind
// (LeftSemi is not in the list).
bool NeedToTrackUnusedLeftTuples(EJoinKind kind) {
8-
return static_cast<int>(kind)&1 == 1;
18+
switch (kind) {
19+
using enum NKikimr::NMiniKQL::EJoinKind;
20+
case Exclusion:
21+
case Full:
22+
case Left:
23+
case LeftOnly:
24+
return true;
25+
default:
26+
return false;
27+
}
928
}
1029

1130
}

ydb/library/yql/dq/comp_nodes/dq_hash_join_table.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ class TStdJoinTable {
5353
return TrackUnusedTuples;
5454
}
5555

56+
// Added accessor: returns the BuiltTable member by const reference, so callers
// can iterate the table without copying or modifying it.
const auto& MapView() const {
57+
return BuiltTable;
58+
}
5659
void ForEachUnused(std::function<void(TTuple)> produce) {
5760
MKQL_ENSURE(TrackUnusedTuples, "wasn't tracking tuples at all");
5861
for(auto& tuplesSameKey: BuiltTable) {

ydb/library/yql/dq/comp_nodes/dq_program_builder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ TRuntimeNode TDqProgramBuilder::DqBlockHashJoin(TRuntimeNode leftStream, TRuntim
163163
TRuntimeNode TDqProgramBuilder::DqScalarHashJoin(TRuntimeNode leftFlow, TRuntimeNode rightFlow, EJoinKind joinKind,
164164
const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& rightKeyColumns, TType* returnType) {
165165

166-
MKQL_ENSURE(joinKind == EJoinKind::Inner || joinKind == EJoinKind::Left || joinKind == EJoinKind::Right || joinKind == EJoinKind::Full, "Unsupported join kind");
166+
MKQL_ENSURE(joinKind != EJoinKind::Cross, "Unsupported join kind");
167167
MKQL_ENSURE(leftKeyColumns.size() == rightKeyColumns.size(), "Key column count mismatch");
168168
MKQL_ENSURE(!leftKeyColumns.empty(), "At least one key column must be specified");
169169

ydb/library/yql/dq/comp_nodes/dq_scalar_hash_join.cpp

Lines changed: 68 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,31 @@ TKeyTypes KeyTypesFromColumns(const std::vector<TType*>& types, const std::vecto
2121
return kt;
2222
}
2323

24+
// Added helper: returns true when `kind` is RightOnly, RightSemi, LeftOnly or
// LeftSemi, and false for every other join kind.
// FetchValues in this diff uses it to pick the output tuple width: one side's
// column count for these kinds, twice that otherwise.
bool SemiOrOnlyJoin(EJoinKind kind) {
25+
switch (kind) {
26+
using enum EJoinKind;
27+
case RightOnly:
28+
case RightSemi:
29+
case LeftOnly:
30+
case LeftSemi:
31+
return true;
32+
default:
33+
return false;
34+
}
35+
}
36+
37+
// Added helper: returns true for Inner, Full, Left and Right, and false for
// every other join kind.
// NOTE(review): the name is narrower than the behaviour - it is not limited to
// EJoinKind::Inner. The lookup callback in this diff calls it to decide whether
// a matched probe/build pair is appended to the output.
bool IsInner(EJoinKind kind) {
38+
switch (kind) {
39+
using enum EJoinKind;
40+
case Inner:
41+
case Full:
42+
case Left:
43+
case Right:
44+
return true;
45+
default:
46+
return false;
47+
}
48+
}
2449

2550
class TScalarHashJoinState : public TComputationValue<TScalarHashJoinState> {
2651
using TBase = TComputationValue<TScalarHashJoinState>;
@@ -31,38 +56,19 @@ class TScalarHashJoinState : public TComputationValue<TScalarHashJoinState> {
3156
IComputationWideFlowNode* ProbeSide() const {
3257
return LeftFinished_ ? nullptr : LeftFlow_;
3358
}
34-
// Diff view of AppendTuple: the removed ("-") and added ("+") implementations
// are interleaved below.
//
// Removed version (parameters probe/build): pushed the key columns first, taken
// from `probe` via LeftKeyColumns_ or, when `probe` was null, from `build` via
// RightKeyColumns_. It then pushed every non-key left column and every non-key
// right column, substituting a default-constructed TUnboxedValuePod for a
// missing side.
//
// Added version (parameters left/right): appends std::ssize(LeftColumnTypes_)
// values followed by std::ssize(RightColumnTypes_) values to `output`. Each
// group is copied from the corresponding tuple when it is non-null, otherwise
// from NullTuples. Key columns are no longer pulled out or de-duplicated.
//
// Both versions require at least one of the two tuples to be non-null
// (MKQL_ENSURE with "appending invalid tuple").
void AppendTuple(NJoinTable::TTuple probe, NJoinTable::TTuple build, std::vector<NUdf::TUnboxedValue>& output) {
35-
MKQL_ENSURE(probe || build,"appending invalid tuple");
36-
if (probe) {
37-
for (int index = 0; index < std::ssize(LeftKeyColumns_); ++index) {
38-
output.push_back(probe[LeftKeyColumns_[index]]);
39-
}
59+
void AppendTuple(NJoinTable::TTuple left, NJoinTable::TTuple right, std::vector<NUdf::TUnboxedValue>& output) {
60+
MKQL_ENSURE(left || right,"appending invalid tuple");
61+
auto outIt = std::back_inserter(output);
62+
if (left) {
63+
std::copy_n(left,std::ssize(LeftColumnTypes_), outIt);
4064
} else {
41-
for (int index = 0; index < std::ssize(RightKeyColumns_); ++index) {
42-
output.push_back(build[RightKeyColumns_[index]]);
43-
}
65+
std::copy_n(NullTuples.data(),std::ssize(LeftColumnTypes_), outIt);
4466
}
45-
46-
for (int index = 0; index < std::ssize(LeftColumnTypes_); ++index) {
47-
if (std::ranges::find(LeftKeyColumns_, index) == LeftKeyColumns_.end()) {
48-
if (probe) {
49-
output.push_back(probe[index]);
50-
} else {
51-
output.push_back(NYql::NUdf::TUnboxedValuePod{});
52-
}
53-
}
54-
}
55-
56-
for (int index = 0; index < std::ssize(RightColumnTypes_); ++index) {
57-
if (std::ranges::find(RightKeyColumns_, index) == RightKeyColumns_.end()) {
58-
if (build) {
59-
output.push_back(build[index]);
60-
} else {
61-
output.push_back(NYql::NUdf::TUnboxedValuePod{});
62-
}
63-
}
67+
if (right) {
68+
std::copy_n(right,std::ssize(RightColumnTypes_), outIt);
69+
} else {
70+
std::copy_n(NullTuples.data(),std::ssize(RightColumnTypes_), outIt);
6471
}
65-
6672
}
6773

6874
public:
@@ -93,7 +99,7 @@ class TScalarHashJoinState : public TComputationValue<TScalarHashJoinState> {
9399
, Output_()
94100
{
95101
MKQL_ENSURE(RightColumnTypes_.size() == LeftColumnTypes_.size(), "unimplemented");
96-
MKQL_ENSURE(joinKind == EJoinKind::Inner || joinKind == EJoinKind::Left || joinKind == EJoinKind::Right || joinKind == EJoinKind::Full, "Unsupported join kind");
102+
MKQL_ENSURE(joinKind != EJoinKind::Cross, "Unsupported join kind");
97103
Pointers_.resize(LeftColumnTypes_.size());
98104
for (int index = 0; index < std::ssize(LeftKeyColumns_); ++index) {
99105
Pointers_[LeftKeyColumns_[index]] = &Values_[index];
@@ -105,13 +111,19 @@ class TScalarHashJoinState : public TComputationValue<TScalarHashJoinState> {
105111
valuesIndex++;
106112
}
107113
}
108-
MKQL_ENSURE(std::ranges::is_permutation(Values_ | std::views::transform([](auto& value){return &value;}), Pointers_), "Pointers_ should be a permutation of Values_ addresses");
114+
MKQL_ENSURE(std::ranges::is_permutation(Values_ | std::views::transform([](auto& value) {return &value;}), Pointers_), "Pointers_ should be a permutation of Values_ addresses");
109115

110116
UDF_LOG(Logger_, LogComponent_, NUdf::ELogLevel::Debug, "TScalarHashJoinState created");
111117
}
112118

113119
EFetchResult FetchValues(TComputationContext& ctx, NUdf::TUnboxedValue* const* output) {
114-
const int outputTupleSize = std::ssize(RightColumnTypes_) + std::ssize(LeftColumnTypes_) - std::ssize(LeftKeyColumns_);
120+
const int outputTupleSize = [&] {
121+
if (SemiOrOnlyJoin(JoinKind_)) {
122+
return std::ssize(RightColumnTypes_);
123+
} else {
124+
return std::ssize(RightColumnTypes_) * 2;
125+
}
126+
}();
115127
if (auto* buildSide = BuildSide()) {
116128
auto res = buildSide->FetchValues(ctx, Pointers_.data());
117129
switch (res) {
@@ -144,15 +156,28 @@ class TScalarHashJoinState : public TComputationValue<TScalarHashJoinState> {
144156
return EFetchResult::One;
145157
}
146158
if (auto* probeSide = ProbeSide()) {
147-
auto result = LeftFlow_->FetchValues(ctx, Pointers_.data());
159+
auto result = probeSide->FetchValues(ctx, Pointers_.data());
148160
switch (result) {
149161
case EFetchResult::Finish: {
150162
LeftFinished_ = true;
151163
if (Table_.UnusedTrackingOn()) {
164+
for (auto& v : Table_.MapView()) {
165+
if (v.second.Used && JoinKind_ == EJoinKind::RightSemi ) {
166+
for( NJoinTable::TTuple used: v.second.Tuples ) {
167+
std::copy_n(used, std::ssize(RightColumnTypes_), std::back_inserter(Output_));
168+
}
169+
}
170+
}
152171
Table_.ForEachUnused([this](NJoinTable::TTuple unused) {
153-
AppendTuple(nullptr, unused, Output_);
172+
if (JoinKind_ == EJoinKind::RightOnly) {
173+
std::copy_n(unused, std::ssize(RightColumnTypes_), std::back_inserter(Output_));
174+
}
175+
if (JoinKind_ == EJoinKind::Exclusion || JoinKind_ == EJoinKind::Right || JoinKind_ == EJoinKind::Full) {
176+
AppendTuple(nullptr, unused, Output_);
177+
}
154178
});
155179
}
180+
156181
return EFetchResult::Yield;
157182
}
158183
case EFetchResult::Yield: {
@@ -161,10 +186,15 @@ class TScalarHashJoinState : public TComputationValue<TScalarHashJoinState> {
161186
case EFetchResult::One: {
162187
bool found = false;
163188
Table_.Lookup(Values_.data(), [this, &found](NJoinTable::TTuple matched) {
164-
AppendTuple(Values_.data(),matched,Output_);
189+
if (IsInner(JoinKind_)) {
190+
AppendTuple(Values_.data(),matched,Output_);
191+
}
165192
found = true;
166193
});
167-
if (!found && NJoinTable::NeedToTrackUnusedLeftTuples(JoinKind_)) {
194+
if (!found && JoinKind_ == EJoinKind::LeftOnly || found && JoinKind_ == EJoinKind::LeftSemi) {
195+
std::copy(Values_.data(), Values_.data() + std::ssize(LeftColumnTypes_), std::back_inserter(Output_));
196+
}
197+
if (!found && (JoinKind_ == EJoinKind::Exclusion || JoinKind_ == EJoinKind::Left || JoinKind_ == EJoinKind::Full)) {
168198
AppendTuple(Values_.data(), nullptr, Output_);
169199
}
170200
return EFetchResult::Yield;
@@ -189,6 +219,8 @@ class TScalarHashJoinState : public TComputationValue<TScalarHashJoinState> {
189219
const NUdf::TLogComponentId LogComponent_;
190220
const TKeyTypes KeyTypes_;
191221
const EJoinKind JoinKind_;
222+
const std::vector<NYql::NUdf::TUnboxedValuePod> NullTuples{std::max(std::size(LeftColumnTypes_), std::size(RightColumnTypes_)), NYql::NUdf::TUnboxedValuePod{}};
223+
192224
bool LeftFinished_ = false;
193225
bool RightFinished_ = false;
194226
NJoinTable::TStdJoinTable Table_;

0 commit comments

Comments
 (0)