
Commit 29d321d

Added truncateMode write configuration
To allow the overwrite save mode to keep collection options.

SPARK-384
Original PR: mongodb#123

- removed recreate mode due to fragility

---------

Co-authored-by: Ross Lawley <[email protected]>
1 parent 55dea7d commit 29d321d
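For orientation, the sketch below shows roughly how the new option is used from Spark. It is a minimal, hypothetical example rather than code from this commit: the class name, the input path, and the spark.mongodb.write.connection.uri setting are illustrative assumptions; only the "mongodb" format, the Overwrite save mode, and the truncateMode option (default "drop", alternative "truncate") come from this change.

    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;

    public final class TruncateModeExample {
      public static void main(final String[] args) {
        // Illustrative session setup; the connection URI key and value are assumptions.
        SparkSession spark = SparkSession.builder()
            .appName("truncate-mode-example")
            .config("spark.mongodb.write.connection.uri", "mongodb://localhost:27017/test.people")
            .getOrCreate();

        Dataset<Row> dataset = spark.read().json("people.json"); // placeholder input

        dataset
            .write()
            .format("mongodb")
            .mode("Overwrite")                  // Overwrite is the save mode that triggers truncation
            .option("truncateMode", "truncate") // keeps indexes and collection options; default is "drop"
            .save();

        spark.stop();
      }
    }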

5 files changed: +260 −5 lines changed


src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java

Lines changed: 13 additions & 3 deletions
@@ -25,6 +25,8 @@
 import com.mongodb.spark.sql.connector.beans.ComplexBean;
 import com.mongodb.spark.sql.connector.beans.DateTimeBean;
 import com.mongodb.spark.sql.connector.beans.PrimitiveBean;
+import com.mongodb.spark.sql.connector.config.WriteConfig;
+import com.mongodb.spark.sql.connector.config.WriteConfig.TruncateMode;
 import com.mongodb.spark.sql.connector.mongodb.MongoSparkConnectorTestCase;
 import java.sql.Date;
 import java.sql.Timestamp;
@@ -41,6 +43,8 @@
 import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.SparkSession;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.EnumSource;
 
 public class RoundTripTest extends MongoSparkConnectorTestCase {
 
@@ -68,8 +72,9 @@ void testPrimitiveBean() {
     assertIterableEquals(dataSetOriginal, dataSetMongo);
   }
 
-  @Test
-  void testBoxedBean() {
+  @ParameterizedTest
+  @EnumSource(TruncateMode.class)
+  void testBoxedBean(final TruncateMode mode) {
     // Given
     List<BoxedBean> dataSetOriginal =
         singletonList(new BoxedBean((byte) 1, (short) 2, 3, 4L, 5.0f, 6.0, true));
@@ -79,7 +84,12 @@ void testBoxedBean() {
     Encoder<BoxedBean> encoder = Encoders.bean(BoxedBean.class);
 
     Dataset<BoxedBean> dataset = spark.createDataset(dataSetOriginal, encoder);
-    dataset.write().format("mongodb").mode("Overwrite").save();
+    dataset
+        .write()
+        .format("mongodb")
+        .mode("Overwrite")
+        .option(WriteConfig.TRUNCATE_MODE_CONFIG, mode.name())
+        .save();
 
     // Then
     List<BoxedBean> dataSetMongo = spark
src/integrationTest/java/com/mongodb/spark/sql/connector/write/TruncateModesTest.java

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
+/*
+ * Copyright 2008-present MongoDB, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+package com.mongodb.spark.sql.connector.write;
+
+import static com.mongodb.spark.sql.connector.config.WriteConfig.TRUNCATE_MODE_CONFIG;
+import static java.util.Arrays.asList;
+import static java.util.Collections.singletonList;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertIterableEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import com.mongodb.client.MongoCollection;
+import com.mongodb.client.MongoDatabase;
+import com.mongodb.client.model.Collation;
+import com.mongodb.client.model.CollationStrength;
+import com.mongodb.client.model.CreateCollectionOptions;
+import com.mongodb.client.model.IndexOptions;
+import com.mongodb.spark.sql.connector.beans.BoxedBean;
+import com.mongodb.spark.sql.connector.config.WriteConfig;
+import com.mongodb.spark.sql.connector.mongodb.MongoSparkConnectorTestCase;
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.SparkSession;
+import org.bson.Document;
+import org.jetbrains.annotations.NotNull;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+public class TruncateModesTest extends MongoSparkConnectorTestCase {
+
+  public static final String INT_FIELD_INDEX = "intFieldIndex";
+  public static final String ID_INDEX = "_id_";
+
+  @BeforeEach
+  void setup() {
+    MongoDatabase database = getDatabase();
+    getCollection().drop();
+    CreateCollectionOptions createCollectionOptions = new CreateCollectionOptions()
+        .collation(Collation.builder()
+            .locale("en")
+            .collationStrength(CollationStrength.SECONDARY)
+            .build());
+    database.createCollection(getCollectionName(), createCollectionOptions);
+    MongoCollection<Document> collection = database.getCollection(getCollectionName());
+    collection.insertOne(new Document().append("intField", null));
+    collection.createIndex(
+        new Document().append("intField", 1), new IndexOptions().name(INT_FIELD_INDEX));
+  }
+
+  @Test
+  void testCollectionDroppedOnOverwrite() {
+    // Given
+    List<BoxedBean> dataSetOriginal = singletonList(getBoxedBean());
+
+    // when
+    SparkSession spark = getOrCreateSparkSession();
+    Encoder<BoxedBean> encoder = Encoders.bean(BoxedBean.class);
+    Dataset<BoxedBean> dataset = spark.createDataset(dataSetOriginal, encoder);
+    dataset
+        .write()
+        .format("mongodb")
+        .mode("Overwrite")
+        .option(TRUNCATE_MODE_CONFIG, WriteConfig.TruncateMode.DROP.toString())
+        .save();
+
+    // Then
+    List<BoxedBean> dataSetMongo = spark
+        .read()
+        .format("mongodb")
+        .schema(encoder.schema())
+        .load()
+        .as(encoder)
+        .collectAsList();
+    assertIterableEquals(dataSetOriginal, dataSetMongo);
+
+    List<String> indexes =
+        getCollection().listIndexes().map(it -> it.getString("name")).into(new ArrayList<>());
+    assertEquals(indexes, singletonList(ID_INDEX));
+    Document options = getCollectionOptions();
+    assertTrue(options.isEmpty());
+  }
+
+  @Test
+  void testOptionKeepingOverwrites() {
+    // Given
+    List<BoxedBean> dataSetOriginal = singletonList(getBoxedBean());
+
+    // when
+    SparkSession spark = getOrCreateSparkSession();
+    Encoder<BoxedBean> encoder = Encoders.bean(BoxedBean.class);
+    Dataset<BoxedBean> dataset = spark.createDataset(dataSetOriginal, encoder);
+    dataset
+        .write()
+        .format("mongodb")
+        .mode("Overwrite")
+        .option(TRUNCATE_MODE_CONFIG, WriteConfig.TruncateMode.TRUNCATE.toString())
+        .save();
+
+    // Then
+    List<BoxedBean> dataSetMongo = spark
+        .read()
+        .format("mongodb")
+        .schema(encoder.schema())
+        .load()
+        .as(encoder)
+        .collectAsList();
+    assertIterableEquals(dataSetOriginal, dataSetMongo);
+
+    List<String> indexes =
+        getCollection().listIndexes().map(it -> it.getString("name")).into(new ArrayList<>());
+    assertEquals(indexes, asList(ID_INDEX, INT_FIELD_INDEX));
+
+    Document options = getCollectionOptions();
+    assertTrue(options.containsKey("collation"));
+    assertEquals("en", options.get("collation", new Document()).get("locale", "NA"), "en");
+  }
+
+  private @NotNull BoxedBean getBoxedBean() {
+    return new BoxedBean((byte) 1, (short) 2, 3, 4L, 5.0f, 6.0, true);
+  }
+
+  private Document getCollectionOptions() {
+    Document getCollectionMeta = new Document()
+        .append("listCollections", 1)
+        .append("filter", new Document().append("name", getCollectionName()));
+
+    Document foundMeta = getDatabase().runCommand(getCollectionMeta);
+    Document cursor = foundMeta.get("cursor", Document.class);
+    List<Document> firstBatch = cursor.getList("firstBatch", Document.class);
+    if (firstBatch.isEmpty()) {
+      return getCollectionMeta;
+    }
+
+    return firstBatch.get(0).get("options", Document.class);
+  }
+}
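Aside: the getCollectionOptions helper above issues the raw listCollections command. A hypothetical equivalent using the driver's ListCollectionsIterable would look like the sketch below; the helper class name and the use of Filters.eq are assumptions, not part of the commit.

    import com.mongodb.client.MongoDatabase;
    import com.mongodb.client.model.Filters;
    import org.bson.Document;

    // Hypothetical alternative to the raw runCommand call: the listCollections()
    // iterable exposes the same "options" sub-document per collection.
    final class CollectionOptionsHelper {
      static Document collectionOptions(final MongoDatabase database, final String collectionName) {
        Document meta =
            database.listCollections().filter(Filters.eq("name", collectionName)).first();
        return meta == null ? new Document() : meta.get("options", new Document());
      }
    }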

src/main/java/com/mongodb/spark/sql/connector/config/WriteConfig.java

Lines changed: 78 additions & 0 deletions
@@ -22,11 +22,13 @@
 
 import com.mongodb.MongoNamespace;
 import com.mongodb.WriteConcern;
+import com.mongodb.client.MongoCollection;
 import com.mongodb.spark.sql.connector.exceptions.ConfigException;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
+import org.bson.Document;
 import org.jetbrains.annotations.ApiStatus;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -116,6 +118,59 @@ public String toString() {
     }
   }
 
+  /**
+   * Determines how to truncate a collection when using {@link org.apache.spark.sql.SaveMode#Overwrite}
+   *
+   * @since 10.6
+   */
+  public enum TruncateMode {
+    /**
+     * Drops the collection
+     */
+    DROP("drop") {
+      @Override
+      public void truncate(final WriteConfig writeConfig) {
+        writeConfig.doWithCollection(MongoCollection::drop);
+      }
+    },
+    /**
+     * Deletes all entries in the collection preserving indexes, collection options and any sharding configuration
+     * <p><strong>Warning:</strong> This operation is currently much more expensive than doing a simple drop operation.</p>
+     */
+    TRUNCATE("truncate") {
+      @Override
+      public void truncate(final WriteConfig writeConfig) {
+        writeConfig.doWithCollection(collection -> collection.deleteMany(new Document()));
+      }
+    };
+
+    private final String value;
+
+    TruncateMode(final String value) {
+      this.value = value;
+    }
+
+    static TruncateMode fromString(final String truncateMode) {
+      for (TruncateMode truncateModeType : TruncateMode.values()) {
+        if (truncateMode.equalsIgnoreCase(truncateModeType.value)) {
+          return truncateModeType;
+        }
+      }
+      throw new ConfigException(format("'%s' is not a valid Truncate Mode", truncateMode));
+    }
+
+    /**
+     * The truncation implementation for each different truncation type
+     * @param writeConfig the write config
+     */
+    public abstract void truncate(WriteConfig writeConfig);
+
+    @Override
+    public String toString() {
+      return value;
+    }
+  }
+
   /**
    * The maximum batch size for the batch in the bulk operation.
    *
@@ -243,6 +298,21 @@ public String toString() {
 
   private static final boolean IGNORE_NULL_VALUES_DEFAULT = false;
 
+  /**
+   * Truncate Mode
+   *
+   * <p>Configuration: {@value}
+   *
+   * <p>Default: {@code Drop}
+   *
+   * <p>Determines how to truncate a collection when using {@link org.apache.spark.sql.SaveMode#Overwrite}
+   *
+   * @since 10.6
+   */
+  public static final String TRUNCATE_MODE_CONFIG = "truncateMode";
+
+  private static final String TRUNCATE_MODE_DEFAULT = TruncateMode.DROP.value;
+
   private final WriteConcern writeConcern;
   private final OperationType operationType;
 
@@ -319,6 +389,14 @@ public boolean ignoreNullValues() {
     return getBoolean(IGNORE_NULL_VALUES_CONFIG, IGNORE_NULL_VALUES_DEFAULT);
   }
 
+  /**
+   * @return the truncate mode for use when overwriting collections
+   * @since 10.6
+   */
+  public TruncateMode truncateMode() {
+    return TruncateMode.fromString(getOrDefault(TRUNCATE_MODE_CONFIG, TRUNCATE_MODE_DEFAULT));
+  }
+
   @Override
   CollectionsConfig parseAndValidateCollectionsConfig() {
     CollectionsConfig collectionsConfig = super.parseAndValidateCollectionsConfig();
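Putting the WriteConfig pieces together, the option resolves to the enum roughly as sketched below. This is an illustrative sketch, not test code from the commit: the option map values (URI, database, collection) are placeholders, while the MongoConfig.createConfig(...).toWriteConfig() chain, withOption, and TRUNCATE_MODE_CONFIG mirror usage found elsewhere in this repository.

    import com.mongodb.spark.sql.connector.config.MongoConfig;
    import com.mongodb.spark.sql.connector.config.WriteConfig;
    import java.util.HashMap;
    import java.util.Map;

    public final class TruncateModeResolution {
      public static void main(final String[] args) {
        Map<String, String> options = new HashMap<>();
        options.put("spark.mongodb.write.connection.uri", "mongodb://localhost:27017"); // placeholder
        options.put("spark.mongodb.write.database", "test");                            // placeholder
        options.put("spark.mongodb.write.collection", "coll");                          // placeholder

        WriteConfig writeConfig = MongoConfig.createConfig(options).toWriteConfig();

        // Without the option, truncateMode() falls back to TruncateMode.DROP;
        // unknown values throw ConfigException, and matching is case-insensitive.
        System.out.println(writeConfig.truncateMode()); // prints "drop"

        WriteConfig truncating = writeConfig.withOption(WriteConfig.TRUNCATE_MODE_CONFIG, "Truncate");
        System.out.println(truncating.truncateMode()); // prints "truncate"
      }
    }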

src/main/java/com/mongodb/spark/sql/connector/write/MongoBatchWrite.java

Lines changed: 1 addition & 2 deletions
@@ -19,7 +19,6 @@
 
 import static java.lang.String.format;
 
-import com.mongodb.client.MongoCollection;
 import com.mongodb.spark.sql.connector.config.WriteConfig;
 import com.mongodb.spark.sql.connector.exceptions.DataException;
 import java.util.Arrays;
@@ -62,7 +61,7 @@ final class MongoBatchWrite implements BatchWrite {
   @Override
   public DataWriterFactory createBatchWriterFactory(final PhysicalWriteInfo physicalWriteInfo) {
     if (truncate) {
-      writeConfig.doWithCollection(MongoCollection::drop);
+      writeConfig.truncateMode().truncate(writeConfig);
     }
     return new MongoDataWriterFactory(info.schema(), writeConfig);
   }
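With this change MongoBatchWrite no longer hard-codes a drop; it delegates to whichever TruncateMode is configured. In plain Java driver terms the two strategies behave as in the standalone sketch below; the URI, database, and collection names are placeholders and this is not connector code.

    import com.mongodb.client.MongoClient;
    import com.mongodb.client.MongoClients;
    import com.mongodb.client.MongoCollection;
    import org.bson.Document;

    public final class TruncateStrategies {
      public static void main(final String[] args) {
        try (MongoClient client = MongoClients.create("mongodb://localhost:27017")) {
          MongoCollection<Document> collection = client.getDatabase("test").getCollection("coll");

          // DROP: removes documents, indexes and collection options in one server-side drop.
          collection.drop();

          // TRUNCATE: removes every document but keeps indexes, collation/options and any
          // sharding configuration, at the cost of deleting documents one by one on the server.
          collection.deleteMany(new Document());
        }
      }
    }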

src/test/java/com/mongodb/spark/sql/connector/config/MongoConfigTest.java

Lines changed: 15 additions & 0 deletions
@@ -159,6 +159,21 @@ void testWriteConfigConvertJson() {
         WriteConfig.ConvertJson.OBJECT_OR_ARRAY_ONLY);
   }
 
+  @Test
+  void testWriteConfigTruncateMode() {
+    WriteConfig writeConfig = MongoConfig.createConfig(CONFIG_MAP).toWriteConfig();
+    assertEquals(writeConfig.truncateMode(), WriteConfig.TruncateMode.DROP);
+    assertEquals(
+        writeConfig.withOption("TruncateMode", "truncate").truncateMode(),
+        WriteConfig.TruncateMode.TRUNCATE);
+    assertEquals(
+        writeConfig.withOption("TruncateMode", "Drop").truncateMode(),
+        WriteConfig.TruncateMode.DROP);
+    assertThrows(
+        ConfigException.class,
+        () -> writeConfig.withOption("TruncateMode", "RECREATE").truncateMode());
+  }
+
   @Test
   void testMongoConfigOptionsParsing() {
     MongoConfig mongoConfig = MongoConfig.readConfig(OPTIONS_CONFIG_MAP);
