
Commit 8aacd6c

VVBondarenko authored and rozza committed
Added MongoTable support for deletes
- added support for delete
- minor refactoring for filter transformations
- minor fix for catalog to load table with resolved schema instead of empty one

SPARK-414
Original PR: mongodb#124

Co-authored-by: Ross Lawley <[email protected]>
1 parent 29d321d commit 8aacd6c
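
As a quick illustration of what this change enables (a minimal sketch, not taken from the diff: the catalog name, connection URI, database, and collection are placeholder assumptions, and a reachable MongoDB deployment is required), a SQL DELETE issued against a catalog-registered table is now pushed down to MongoDB when the WHERE clause can be translated:

import org.apache.spark.sql.SparkSession;

public final class DeleteExample {
  public static void main(final String[] args) {
    // Register MongoCatalog under an arbitrary catalog name (here "mongo_catalog").
    SparkSession spark = SparkSession.builder()
        .master("local[*]")
        .config("spark.sql.catalog.mongo_catalog", "com.mongodb.spark.sql.connector.MongoCatalog")
        .config("spark.mongodb.connection.uri", "mongodb://localhost:27017") // placeholder URI
        .getOrCreate();

    // Rows matching the predicate are removed from the underlying collection;
    // the WHERE clause is translated to a MongoDB filter where possible.
    spark.sql("DELETE FROM mongo_catalog.test.coll WHERE booleanField = false AND intField > 3");

    spark.stop();
  }
}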

File tree: 6 files changed, +326 -194 lines changed


src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java

Lines changed: 31 additions & 0 deletions
@@ -17,8 +17,10 @@

 package com.mongodb.spark.sql.connector;

+import static com.mongodb.spark.sql.connector.mongodb.MongoSparkConnectorHelper.CATALOG;
 import static java.util.Arrays.asList;
 import static java.util.Collections.singletonList;
+import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertIterableEquals;

 import com.mongodb.spark.sql.connector.beans.BoxedBean;
@@ -41,6 +43,7 @@
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Encoder;
 import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
@@ -172,4 +175,32 @@ void testComplexBean() {
         .collectAsList();
     assertIterableEquals(dataSetOriginal, dataSetMongo);
   }
+
+  @Test
+  void testCatalogAccessAndDelete() {
+    List<BoxedBean> dataSetOriginal = asList(
+        new BoxedBean((byte) 1, (short) 2, 0, 4L, 5.0f, 6.0, true),
+        new BoxedBean((byte) 1, (short) 2, 1, 4L, 5.0f, 6.0, true),
+        new BoxedBean((byte) 1, (short) 2, 2, 4L, 5.0f, 6.0, true),
+        new BoxedBean((byte) 1, (short) 2, 3, 4L, 5.0f, 6.0, false),
+        new BoxedBean((byte) 1, (short) 2, 4, 4L, 5.0f, 6.0, false),
+        new BoxedBean((byte) 1, (short) 2, 5, 4L, 5.0f, 6.0, false));
+
+    SparkSession spark = getOrCreateSparkSession();
+    Encoder<BoxedBean> encoder = Encoders.bean(BoxedBean.class);
+    spark
+        .createDataset(dataSetOriginal, encoder)
+        .write()
+        .format("mongodb")
+        .mode("Overwrite")
+        .save();
+
+    String tableName = CATALOG + "." + HELPER.getDatabaseName() + "." + HELPER.getCollectionName();
+    List<Row> rows = spark.sql("select * from " + tableName).collectAsList();
+    assertEquals(6, rows.size());
+
+    spark.sql("delete from " + tableName + " where booleanField = false and intField > 3");
+    rows = spark.sql("select * from " + tableName).collectAsList();
+    assertEquals(4, rows.size());
+  }
 }

src/integrationTest/java/com/mongodb/spark/sql/connector/mongodb/MongoSparkConnectorHelper.java

Lines changed: 3 additions & 0 deletions
@@ -28,6 +28,7 @@
 import com.mongodb.client.model.UpdateOptions;
 import com.mongodb.client.model.Updates;
 import com.mongodb.connection.ClusterType;
+import com.mongodb.spark.sql.connector.MongoCatalog;
 import com.mongodb.spark.sql.connector.config.MongoConfig;
 import java.io.File;
 import java.io.IOException;
@@ -62,6 +63,7 @@ public class MongoSparkConnectorHelper
       "{_id: '%s', pk: '%s', dups: '%s', i: %d, s: '%s'}";
   private static final String COMPLEX_SAMPLE_DATA_TEMPLATE =
       "{_id: '%s', nested: {pk: '%s', dups: '%s', i: %d}, s: '%s'}";
+  public static final String CATALOG = "mongo_catalog";

   private static final Logger LOGGER = LoggerFactory.getLogger(MongoSparkConnectorHelper.class);

@@ -146,6 +148,7 @@ public SparkConf getSparkConf() {
        .set("spark.sql.streaming.checkpointLocation", getTempDirectory())
        .set("spark.sql.streaming.forceDeleteTempCheckpointLocation", "true")
        .set("spark.app.id", "MongoSparkConnector")
+       .set("spark.sql.catalog." + CATALOG, MongoCatalog.class.getCanonicalName())
        .set(
            MongoConfig.PREFIX + MongoConfig.CONNECTION_STRING_CONFIG,
            getConnectionString().getConnectionString())
src/main/java/com/mongodb/spark/sql/connector/ExpressionConverter.java

Lines changed: 253 additions & 0 deletions
@@ -0,0 +1,253 @@
+/*
+ * Copyright 2008-present MongoDB, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package com.mongodb.spark.sql.connector;
+
+import static com.mongodb.spark.sql.connector.schema.RowToBsonDocumentConverter.createObjectToBsonValue;
+import static java.lang.String.format;
+
+import com.mongodb.client.model.Filters;
+import com.mongodb.spark.sql.connector.assertions.Assertions;
+import com.mongodb.spark.sql.connector.config.WriteConfig;
+import com.mongodb.spark.sql.connector.schema.RowToBsonDocumentConverter;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+import org.apache.spark.sql.Column;
+import org.apache.spark.sql.sources.And;
+import org.apache.spark.sql.sources.EqualNullSafe;
+import org.apache.spark.sql.sources.EqualTo;
+import org.apache.spark.sql.sources.Filter;
+import org.apache.spark.sql.sources.GreaterThan;
+import org.apache.spark.sql.sources.GreaterThanOrEqual;
+import org.apache.spark.sql.sources.In;
+import org.apache.spark.sql.sources.IsNotNull;
+import org.apache.spark.sql.sources.IsNull;
+import org.apache.spark.sql.sources.LessThan;
+import org.apache.spark.sql.sources.LessThanOrEqual;
+import org.apache.spark.sql.sources.Not;
+import org.apache.spark.sql.sources.Or;
+import org.apache.spark.sql.sources.StringContains;
+import org.apache.spark.sql.sources.StringEndsWith;
+import org.apache.spark.sql.sources.StringStartsWith;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.bson.BsonValue;
+import org.bson.conversions.Bson;
+import org.jetbrains.annotations.Nullable;
+import org.jetbrains.annotations.VisibleForTesting;
+
+/**
+ * Utility class to convert {@link Filter} expressions into MongoDB aggregation pipelines
+ *
+ * @since 10.6
+ */
+public final class ExpressionConverter {
+  private final StructType schema;
+
+  /**
+   * Construct a new instance
+   * @param schema the schema for the data
+   */
+  public ExpressionConverter(final StructType schema) {
+    this.schema = schema;
+  }
+
+  /**
+   * Processes {@link Filter} into aggregation pipelines if possible
+   * @param filter the filter to translate
+   * @return the {@link FilterAndPipelineStage} representing the Filter and pipeline stage if conversion is possible
+   */
+  public FilterAndPipelineStage processFilter(final Filter filter) {
+    Assertions.ensureArgument(() -> filter != null, () -> "Invalid argument filter cannot be null");
+    if (filter instanceof And) {
+      And andFilter = (And) filter;
+      FilterAndPipelineStage eitherLeft = processFilter(andFilter.left());
+      FilterAndPipelineStage eitherRight = processFilter(andFilter.right());
+      if (eitherLeft.hasPipelineStage() && eitherRight.hasPipelineStage()) {
+        return new FilterAndPipelineStage(
+            filter, Filters.and(eitherLeft.getPipelineStage(), eitherRight.getPipelineStage()));
+      }
+    } else if (filter instanceof EqualNullSafe) {
+      EqualNullSafe equalNullSafe = (EqualNullSafe) filter;
+      String fieldName = unquoteFieldName(equalNullSafe.attribute());
+      return new FilterAndPipelineStage(
+          filter,
+          getBsonValue(fieldName, equalNullSafe.value())
+              .map(bsonValue -> Filters.eq(fieldName, bsonValue))
+              .orElse(null));
+    } else if (filter instanceof EqualTo) {
+      EqualTo equalTo = (EqualTo) filter;
+      String fieldName = unquoteFieldName(equalTo.attribute());
+      return new FilterAndPipelineStage(
+          filter,
+          getBsonValue(fieldName, equalTo.value())
+              .map(bsonValue -> Filters.eq(fieldName, bsonValue))
+              .orElse(null));
+    } else if (filter instanceof GreaterThan) {
+      GreaterThan greaterThan = (GreaterThan) filter;
+      String fieldName = unquoteFieldName(greaterThan.attribute());
+      return new FilterAndPipelineStage(
+          filter,
+          getBsonValue(fieldName, greaterThan.value())
+              .map(bsonValue -> Filters.gt(fieldName, bsonValue))
+              .orElse(null));
+    } else if (filter instanceof GreaterThanOrEqual) {
+      GreaterThanOrEqual greaterThanOrEqual = (GreaterThanOrEqual) filter;
+      String fieldName = unquoteFieldName(greaterThanOrEqual.attribute());
+      return new FilterAndPipelineStage(
+          filter,
+          getBsonValue(fieldName, greaterThanOrEqual.value())
+              .map(bsonValue -> Filters.gte(fieldName, bsonValue))
+              .orElse(null));
+    } else if (filter instanceof In) {
+      In inFilter = (In) filter;
+      String fieldName = unquoteFieldName(inFilter.attribute());
+      List<BsonValue> values = Arrays.stream(inFilter.values())
+          .map(v -> getBsonValue(fieldName, v))
+          .filter(Optional::isPresent)
+          .map(Optional::get)
+          .collect(Collectors.toList());
+
+      // Ensure all values were matched otherwise leave to Spark to filter.
+      Bson pipelineStage = null;
+      if (values.size() == inFilter.values().length) {
+        pipelineStage = Filters.in(fieldName, values);
+      }
+      return new FilterAndPipelineStage(filter, pipelineStage);
+    } else if (filter instanceof IsNull) {
+      IsNull isNullFilter = (IsNull) filter;
+      String fieldName = unquoteFieldName(isNullFilter.attribute());
+      return new FilterAndPipelineStage(filter, Filters.eq(fieldName, null));
+    } else if (filter instanceof IsNotNull) {
+      IsNotNull isNotNullFilter = (IsNotNull) filter;
+      String fieldName = unquoteFieldName(isNotNullFilter.attribute());
+      return new FilterAndPipelineStage(filter, Filters.ne(fieldName, null));
+    } else if (filter instanceof LessThan) {
+      LessThan lessThan = (LessThan) filter;
+      String fieldName = unquoteFieldName(lessThan.attribute());
+      return new FilterAndPipelineStage(
+          filter,
+          getBsonValue(fieldName, lessThan.value())
+              .map(bsonValue -> Filters.lt(fieldName, bsonValue))
+              .orElse(null));
+    } else if (filter instanceof LessThanOrEqual) {
+      LessThanOrEqual lessThanOrEqual = (LessThanOrEqual) filter;
+      String fieldName = unquoteFieldName(lessThanOrEqual.attribute());
+      return new FilterAndPipelineStage(
+          filter,
+          getBsonValue(fieldName, lessThanOrEqual.value())
+              .map(bsonValue -> Filters.lte(fieldName, bsonValue))
+              .orElse(null));
+    } else if (filter instanceof Not) {
+      Not notFilter = (Not) filter;
+      FilterAndPipelineStage notChild = processFilter(notFilter.child());
+      if (notChild.hasPipelineStage()) {
+        return new FilterAndPipelineStage(filter, Filters.not(notChild.pipelineStage));
+      }
+    } else if (filter instanceof Or) {
+      Or or = (Or) filter;
+      FilterAndPipelineStage eitherLeft = processFilter(or.left());
+      FilterAndPipelineStage eitherRight = processFilter(or.right());
+      if (eitherLeft.hasPipelineStage() && eitherRight.hasPipelineStage()) {
+        return new FilterAndPipelineStage(
+            filter, Filters.or(eitherLeft.getPipelineStage(), eitherRight.getPipelineStage()));
+      }
+    } else if (filter instanceof StringContains) {
+      StringContains stringContains = (StringContains) filter;
+      String fieldName = unquoteFieldName(stringContains.attribute());
+      return new FilterAndPipelineStage(
+          filter, Filters.regex(fieldName, format(".*%s.*", stringContains.value())));
+    } else if (filter instanceof StringEndsWith) {
+      StringEndsWith stringEndsWith = (StringEndsWith) filter;
+      String fieldName = unquoteFieldName(stringEndsWith.attribute());
+      return new FilterAndPipelineStage(
+          filter, Filters.regex(fieldName, format(".*%s$", stringEndsWith.value())));
+    } else if (filter instanceof StringStartsWith) {
+      StringStartsWith stringStartsWith = (StringStartsWith) filter;
+      String fieldName = unquoteFieldName(stringStartsWith.attribute());
+      return new FilterAndPipelineStage(
+          filter, Filters.regex(fieldName, format("^%s.*", stringStartsWith.value())));
+    }
+    return new FilterAndPipelineStage(filter, null);
+  }
+
+  @VisibleForTesting
+  static String unquoteFieldName(final String fieldName) {
+    // Spark automatically escapes hyphenated names using backticks
+    if (fieldName.contains("`")) {
+      return new Column(fieldName).toString();
+    }
+    return fieldName;
+  }
+
+  private Optional<BsonValue> getBsonValue(final String fieldName, final Object value) {
+    try {
+      StructType localSchema = schema;
+      DataType localDataType = localSchema;
+
+      for (String localFieldName : fieldName.split("\\.")) {
+        StructField localField = localSchema.apply(localFieldName);
+        localDataType = localField.dataType();
+        if (localField.dataType() instanceof StructType) {
+          localSchema = (StructType) localField.dataType();
+        }
+      }
+      RowToBsonDocumentConverter.ObjectToBsonValue objectToBsonValue =
+          createObjectToBsonValue(localDataType, WriteConfig.ConvertJson.FALSE, false);
+      return Optional.of(objectToBsonValue.apply(value));
+    } catch (Exception e) {
+      // ignore
+      return Optional.empty();
+    }
+  }
+
+  /** FilterAndPipelineStage - contains an optional pipeline stage for the filter. */
+  public static final class FilterAndPipelineStage {
+
+    private final Filter filter;
+    private final Bson pipelineStage;
+
+    private FilterAndPipelineStage(final Filter filter, @Nullable final Bson pipelineStage) {
+      this.filter = filter;
+      this.pipelineStage = pipelineStage;
+    }
+
+    /**
+     * @return the filter
+     */
+    public Filter getFilter() {
+      return filter;
+    }
+
+    /**
+     * @return the equivalent pipeline for the filter or {@code null} if translation for the filter wasn't possible
+     */
+    public Bson getPipelineStage() {
+      return pipelineStage;
+    }
+
+    /**
+     * @return true if the {@link Filter} could be converted into a pipeline stage
+     */
+    public boolean hasPipelineStage() {
+      return pipelineStage != null;
+    }
+  }
+}
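
A short usage sketch of the new converter (not part of the diff; the schema, field names, and expected output are illustrative assumptions):

import com.mongodb.spark.sql.connector.ExpressionConverter;
import com.mongodb.spark.sql.connector.ExpressionConverter.FilterAndPipelineStage;
import org.apache.spark.sql.sources.And;
import org.apache.spark.sql.sources.EqualTo;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.GreaterThan;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public final class ExpressionConverterExample {
  public static void main(final String[] args) {
    // Illustrative schema covering the fields referenced by the filter below.
    StructType schema = new StructType()
        .add("booleanField", DataTypes.BooleanType)
        .add("intField", DataTypes.IntegerType);

    // Equivalent of: WHERE booleanField = false AND intField > 3
    Filter filter = new And(new EqualTo("booleanField", false), new GreaterThan("intField", 3));

    FilterAndPipelineStage result = new ExpressionConverter(schema).processFilter(filter);
    if (result.hasPipelineStage()) {
      // Roughly: {"$and": [{"booleanField": false}, {"intField": {"$gt": 3}}]}
      System.out.println(result.getPipelineStage().toBsonDocument());
    } else {
      // Filters that cannot be converted fall back to Spark-side evaluation.
      System.out.println("Filter could not be converted: " + result.getFilter());
    }
  }
}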

src/main/java/com/mongodb/spark/sql/connector/MongoCatalog.java

Lines changed: 4 additions & 1 deletion
@@ -28,6 +28,7 @@
 import com.mongodb.spark.sql.connector.config.ReadConfig;
 import com.mongodb.spark.sql.connector.config.WriteConfig;
 import com.mongodb.spark.sql.connector.exceptions.MongoSparkException;
+import com.mongodb.spark.sql.connector.schema.InferSchema;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
@@ -239,7 +240,9 @@ public Table loadTable(final Identifier identifier) throws NoSuchTableException
    properties.put(
        MongoConfig.READ_PREFIX + MongoConfig.DATABASE_NAME_CONFIG, identifier.namespace()[0]);
    properties.put(MongoConfig.READ_PREFIX + MongoConfig.COLLECTION_NAME_CONFIG, identifier.name());
-    return new MongoTable(MongoConfig.readConfig(properties));
+    return new MongoTable(
+        InferSchema.inferSchema(new CaseInsensitiveStringMap(properties)),
+        MongoConfig.readConfig(properties));
  }

  /**
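
The loadTable change means a catalog-resolved table now reports the schema inferred from the collection instead of an empty one, which is what lets a DELETE's WHERE clause be analyzed against real columns. A minimal way to observe this (a sketch, not from the diff; catalog name, URI, database, and collection are placeholder assumptions):

import org.apache.spark.sql.SparkSession;

public final class CatalogSchemaExample {
  public static void main(final String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[*]")
        .config("spark.sql.catalog.mongo_catalog", "com.mongodb.spark.sql.connector.MongoCatalog")
        .config("spark.mongodb.connection.uri", "mongodb://localhost:27017") // placeholder URI
        .getOrCreate();

    // Before this fix the catalog returned a MongoTable with an empty schema;
    // now the printed schema reflects the fields inferred from the collection.
    spark.table("mongo_catalog.test.coll").printSchema();

    spark.stop();
  }
}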
