@@ -21,6 +21,31 @@ TKeyTypes KeyTypesFromColumns(const std::vector<TType*>& types, const std::vecto
21
21
return kt;
22
22
}
23
23
24
+ bool SemiOrOnlyJoin (EJoinKind kind) {
25
+ switch (kind) {
26
+ using enum EJoinKind;
27
+ case RightOnly:
28
+ case RightSemi:
29
+ case LeftOnly:
30
+ case LeftSemi:
31
+ return true ;
32
+ default :
33
+ return false ;
34
+ }
35
+ }
36
+
37
+ bool IsInner (EJoinKind kind) {
38
+ switch (kind) {
39
+ using enum EJoinKind;
40
+ case Inner:
41
+ case Full:
42
+ case Left:
43
+ case Right:
44
+ return true ;
45
+ default :
46
+ return false ;
47
+ }
48
+ }
24
49
25
50
class TScalarHashJoinState : public TComputationValue <TScalarHashJoinState> {
26
51
using TBase = TComputationValue<TScalarHashJoinState>;
@@ -31,38 +56,19 @@ class TScalarHashJoinState : public TComputationValue<TScalarHashJoinState> {
31
56
IComputationWideFlowNode* ProbeSide () const {
32
57
return LeftFinished_ ? nullptr : LeftFlow_;
33
58
}
34
- void AppendTuple (NJoinTable::TTuple probe, NJoinTable::TTuple build, std::vector<NUdf::TUnboxedValue>& output) {
35
- MKQL_ENSURE (probe || build," appending invalid tuple" );
36
- if (probe) {
37
- for (int index = 0 ; index < std::ssize (LeftKeyColumns_); ++index) {
38
- output.push_back (probe[LeftKeyColumns_[index]]);
39
- }
59
+ void AppendTuple (NJoinTable::TTuple left, NJoinTable::TTuple right, std::vector<NUdf::TUnboxedValue>& output) {
60
+ MKQL_ENSURE (left || right," appending invalid tuple" );
61
+ auto outIt = std::back_inserter (output);
62
+ if (left) {
63
+ std::copy_n (left,std::ssize (LeftColumnTypes_), outIt);
40
64
} else {
41
- for (int index = 0 ; index < std::ssize (RightKeyColumns_); ++index) {
42
- output.push_back (build[RightKeyColumns_[index]]);
43
- }
65
+ std::copy_n (NullTuples.data (),std::ssize (LeftColumnTypes_), outIt);
44
66
}
45
-
46
- for (int index = 0 ; index < std::ssize (LeftColumnTypes_); ++index) {
47
- if (std::ranges::find (LeftKeyColumns_, index) == LeftKeyColumns_.end ()) {
48
- if (probe) {
49
- output.push_back (probe[index]);
50
- } else {
51
- output.push_back (NYql::NUdf::TUnboxedValuePod{});
52
- }
53
- }
54
- }
55
-
56
- for (int index = 0 ; index < std::ssize (RightColumnTypes_); ++index) {
57
- if (std::ranges::find (RightKeyColumns_, index) == RightKeyColumns_.end ()) {
58
- if (build) {
59
- output.push_back (build[index]);
60
- } else {
61
- output.push_back (NYql::NUdf::TUnboxedValuePod{});
62
- }
63
- }
67
+ if (right) {
68
+ std::copy_n (right,std::ssize (RightColumnTypes_), outIt);
69
+ } else {
70
+ std::copy_n (NullTuples.data (),std::ssize (RightColumnTypes_), outIt);
64
71
}
65
-
66
72
}
67
73
68
74
public:
@@ -93,7 +99,7 @@ class TScalarHashJoinState : public TComputationValue<TScalarHashJoinState> {
93
99
, Output_()
94
100
{
95
101
MKQL_ENSURE (RightColumnTypes_.size () == LeftColumnTypes_.size (), " unimplemented" );
96
- MKQL_ENSURE (joinKind == EJoinKind::Inner || joinKind == EJoinKind::Left || joinKind == EJoinKind::Right || joinKind == EJoinKind::Full , " Unsupported join kind" );
102
+ MKQL_ENSURE (joinKind != EJoinKind::Cross , " Unsupported join kind" );
97
103
Pointers_.resize (LeftColumnTypes_.size ());
98
104
for (int index = 0 ; index < std::ssize (LeftKeyColumns_); ++index) {
99
105
Pointers_[LeftKeyColumns_[index]] = &Values_[index];
@@ -105,13 +111,19 @@ class TScalarHashJoinState : public TComputationValue<TScalarHashJoinState> {
105
111
valuesIndex++;
106
112
}
107
113
}
108
- MKQL_ENSURE (std::ranges::is_permutation (Values_ | std::views::transform ([](auto & value){return &value;}), Pointers_), " Pointers_ should be a permutation of Values_ addresses" );
114
+ MKQL_ENSURE (std::ranges::is_permutation (Values_ | std::views::transform ([](auto & value) {return &value;}), Pointers_), " Pointers_ should be a permutation of Values_ addresses" );
109
115
110
116
UDF_LOG (Logger_, LogComponent_, NUdf::ELogLevel::Debug, " TScalarHashJoinState created" );
111
117
}
112
118
113
119
EFetchResult FetchValues (TComputationContext& ctx, NUdf::TUnboxedValue* const * output) {
114
- const int outputTupleSize = std::ssize (RightColumnTypes_) + std::ssize (LeftColumnTypes_) - std::ssize (LeftKeyColumns_);
120
+ const int outputTupleSize = [&] {
121
+ if (SemiOrOnlyJoin (JoinKind_)) {
122
+ return std::ssize (RightColumnTypes_);
123
+ } else {
124
+ return std::ssize (RightColumnTypes_) * 2 ;
125
+ }
126
+ }();
115
127
if (auto * buildSide = BuildSide ()) {
116
128
auto res = buildSide->FetchValues (ctx, Pointers_.data ());
117
129
switch (res) {
@@ -144,15 +156,28 @@ class TScalarHashJoinState : public TComputationValue<TScalarHashJoinState> {
144
156
return EFetchResult::One;
145
157
}
146
158
if (auto * probeSide = ProbeSide ()) {
147
- auto result = LeftFlow_ ->FetchValues (ctx, Pointers_.data ());
159
+ auto result = probeSide ->FetchValues (ctx, Pointers_.data ());
148
160
switch (result) {
149
161
case EFetchResult::Finish: {
150
162
LeftFinished_ = true ;
151
163
if (Table_.UnusedTrackingOn ()) {
164
+ for (auto & v : Table_.MapView ()) {
165
+ if (v.second .Used && JoinKind_ == EJoinKind::RightSemi ) {
166
+ for ( NJoinTable::TTuple used: v.second .Tuples ) {
167
+ std::copy_n (used, std::ssize (RightColumnTypes_), std::back_inserter (Output_));
168
+ }
169
+ }
170
+ }
152
171
Table_.ForEachUnused ([this ](NJoinTable::TTuple unused) {
153
- AppendTuple (nullptr , unused, Output_);
172
+ if (JoinKind_ == EJoinKind::RightOnly) {
173
+ std::copy_n (unused, std::ssize (RightColumnTypes_), std::back_inserter (Output_));
174
+ }
175
+ if (JoinKind_ == EJoinKind::Exclusion || JoinKind_ == EJoinKind::Right || JoinKind_ == EJoinKind::Full) {
176
+ AppendTuple (nullptr , unused, Output_);
177
+ }
154
178
});
155
179
}
180
+
156
181
return EFetchResult::Yield;
157
182
}
158
183
case EFetchResult::Yield: {
@@ -161,10 +186,15 @@ class TScalarHashJoinState : public TComputationValue<TScalarHashJoinState> {
161
186
case EFetchResult::One: {
162
187
bool found = false ;
163
188
Table_.Lookup (Values_.data (), [this , &found](NJoinTable::TTuple matched) {
164
- AppendTuple (Values_.data (),matched,Output_);
189
+ if (IsInner (JoinKind_)) {
190
+ AppendTuple (Values_.data (),matched,Output_);
191
+ }
165
192
found = true ;
166
193
});
167
- if (!found && NJoinTable::NeedToTrackUnusedLeftTuples (JoinKind_)) {
194
+ if (!found && JoinKind_ == EJoinKind::LeftOnly || found && JoinKind_ == EJoinKind::LeftSemi) {
195
+ std::copy (Values_.data (), Values_.data () + std::ssize (LeftColumnTypes_), std::back_inserter (Output_));
196
+ }
197
+ if (!found && (JoinKind_ == EJoinKind::Exclusion || JoinKind_ == EJoinKind::Left || JoinKind_ == EJoinKind::Full)) {
168
198
AppendTuple (Values_.data (), nullptr , Output_);
169
199
}
170
200
return EFetchResult::Yield;
@@ -189,6 +219,8 @@ class TScalarHashJoinState : public TComputationValue<TScalarHashJoinState> {
189
219
const NUdf::TLogComponentId LogComponent_;
190
220
const TKeyTypes KeyTypes_;
191
221
const EJoinKind JoinKind_;
222
+ const std::vector<NYql::NUdf::TUnboxedValuePod> NullTuples{std::max (std::size (LeftColumnTypes_), std::size (RightColumnTypes_)), NYql::NUdf::TUnboxedValuePod{}};
223
+
192
224
bool LeftFinished_ = false ;
193
225
bool RightFinished_ = false ;
194
226
NJoinTable::TStdJoinTable Table_;
0 commit comments