Commit c3d738d

Added truncateMode write configuration
To allow the overwrite save mode to keep collection options (SPARK-384). Original PR: mongodb#123; the recreate mode was removed due to its fragility.

Co-authored-by: Ross Lawley <[email protected]>
1 parent 55dea7d commit c3d738d
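
A minimal usage sketch of the option this commit adds (the dataset and namespace names below are illustrative; the "truncateMode" key and its "drop"/"truncate" values are the ones introduced here):

// Overwrite a collection without losing its indexes, collection options or sharding configuration.
// Assumes an existing Dataset<Row> `people` and a SparkSession already configured with a
// MongoDB write connection URI.
people.write()
    .format("mongodb")
    .mode("Overwrite")
    .option("database", "test")        // illustrative database name
    .option("collection", "people")    // illustrative collection name
    // "truncate" deletes the existing documents but keeps indexes, collection options and
    // any sharding configuration; the default, "drop", drops the collection before writing.
    .option("truncateMode", "truncate")
    .save();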

File tree: 5 files changed, +254 -5 lines changed

src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java

Lines changed: 13 additions & 3 deletions
@@ -25,6 +25,8 @@
 import com.mongodb.spark.sql.connector.beans.ComplexBean;
 import com.mongodb.spark.sql.connector.beans.DateTimeBean;
 import com.mongodb.spark.sql.connector.beans.PrimitiveBean;
+import com.mongodb.spark.sql.connector.config.WriteConfig;
+import com.mongodb.spark.sql.connector.config.WriteConfig.TruncateMode;
 import com.mongodb.spark.sql.connector.mongodb.MongoSparkConnectorTestCase;
 import java.sql.Date;
 import java.sql.Timestamp;
@@ -41,6 +43,8 @@
 import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.SparkSession;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.EnumSource;

 public class RoundTripTest extends MongoSparkConnectorTestCase {

@@ -68,8 +72,9 @@ void testPrimitiveBean() {
     assertIterableEquals(dataSetOriginal, dataSetMongo);
   }

-  @Test
-  void testBoxedBean() {
+  @ParameterizedTest
+  @EnumSource(TruncateMode.class)
+  void testBoxedBean(final TruncateMode mode) {
     // Given
     List<BoxedBean> dataSetOriginal =
         singletonList(new BoxedBean((byte) 1, (short) 2, 3, 4L, 5.0f, 6.0, true));
@@ -79,7 +84,12 @@ void testBoxedBean() {
     Encoder<BoxedBean> encoder = Encoders.bean(BoxedBean.class);

     Dataset<BoxedBean> dataset = spark.createDataset(dataSetOriginal, encoder);
-    dataset.write().format("mongodb").mode("Overwrite").save();
+    dataset
+        .write()
+        .format("mongodb")
+        .mode("Overwrite")
+        .option(WriteConfig.TRUNCATE_MODE_CONFIG, mode.name())
+        .save();

     // Then
     List<BoxedBean> dataSetMongo = spark

src/integrationTest/java/com/mongodb/spark/sql/connector/write/TruncateModesTest.java

Lines changed: 147 additions & 0 deletions (new file)
/*
 * Copyright 2008-present MongoDB, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */
package com.mongodb.spark.sql.connector.write;

import static com.mongodb.spark.sql.connector.config.WriteConfig.TRUNCATE_MODE_CONFIG;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertIterableEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoDatabase;
import com.mongodb.client.model.CreateCollectionOptions;
import com.mongodb.client.model.IndexOptions;
import com.mongodb.spark.sql.connector.beans.BoxedBean;
import com.mongodb.spark.sql.connector.config.WriteConfig;
import com.mongodb.spark.sql.connector.mongodb.MongoSparkConnectorTestCase;
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;
import org.bson.Document;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

public class TruncateModesTest extends MongoSparkConnectorTestCase {

  public static final String INT_FIELD_INDEX = "intFieldIndex";
  public static final String ID_INDEX = "_id_";

  @BeforeEach
  void setup() {
    MongoDatabase database = getDatabase();
    getCollection().drop();
    CreateCollectionOptions createCollectionOptions =
        new CreateCollectionOptions().capped(true).maxDocuments(1024).sizeInBytes(1024);
    database.createCollection(getCollectionName(), createCollectionOptions);
    MongoCollection<Document> collection = database.getCollection(getCollectionName());
    collection.insertOne(new Document().append("intField", null));
    collection.createIndex(
        new Document().append("intField", 1), new IndexOptions().name(INT_FIELD_INDEX));
  }

  @Test
  void testCollectionDroppedOnOverwrite() {
    // Given
    List<BoxedBean> dataSetOriginal = singletonList(getBoxedBean());

    // when
    SparkSession spark = getOrCreateSparkSession();
    Encoder<BoxedBean> encoder = Encoders.bean(BoxedBean.class);
    Dataset<BoxedBean> dataset = spark.createDataset(dataSetOriginal, encoder);
    dataset
        .write()
        .format("mongodb")
        .mode("Overwrite")
        .option(TRUNCATE_MODE_CONFIG, WriteConfig.TruncateMode.DROP.toString())
        .save();

    // Then
    List<BoxedBean> dataSetMongo = spark
        .read()
        .format("mongodb")
        .schema(encoder.schema())
        .load()
        .as(encoder)
        .collectAsList();
    assertIterableEquals(dataSetOriginal, dataSetMongo);

    List<String> indexes =
        getCollection().listIndexes().map(it -> it.getString("name")).into(new ArrayList<>());
    assertEquals(indexes, singletonList(ID_INDEX));
    Document options = getCollectionOptions();
    assertTrue(options.isEmpty());
  }

  @Test
  void testOptionKeepingOverwrites() {
    // Given
    List<BoxedBean> dataSetOriginal = singletonList(getBoxedBean());

    // when
    SparkSession spark = getOrCreateSparkSession();
    Encoder<BoxedBean> encoder = Encoders.bean(BoxedBean.class);
    Dataset<BoxedBean> dataset = spark.createDataset(dataSetOriginal, encoder);
    dataset
        .write()
        .format("mongodb")
        .mode("Overwrite")
        .option(TRUNCATE_MODE_CONFIG, WriteConfig.TruncateMode.TRUNCATE.toString())
        .save();

    // Then
    List<BoxedBean> dataSetMongo = spark
        .read()
        .format("mongodb")
        .schema(encoder.schema())
        .load()
        .as(encoder)
        .collectAsList();
    assertIterableEquals(dataSetOriginal, dataSetMongo);

    List<String> indexes =
        getCollection().listIndexes().map(it -> it.getString("name")).into(new ArrayList<>());
    assertEquals(indexes, asList(ID_INDEX, INT_FIELD_INDEX));

    Document options = getCollectionOptions();
    assertTrue(options.getBoolean("capped"));
  }

  private @NotNull BoxedBean getBoxedBean() {
    return new BoxedBean((byte) 1, (short) 2, 3, 4L, 5.0f, 6.0, true);
  }

  private Document getCollectionOptions() {
    Document getCollectionMeta = new Document()
        .append("listCollections", 1)
        .append("filter", new Document().append("name", getCollectionName()));

    Document foundMeta = getDatabase().runCommand(getCollectionMeta);
    Document cursor = foundMeta.get("cursor", Document.class);
    List<Document> firstBatch = cursor.getList("firstBatch", Document.class);
    if (firstBatch.isEmpty()) {
      return getCollectionMeta;
    }

    return firstBatch.get(0).get("options", Document.class);
  }
}

src/main/java/com/mongodb/spark/sql/connector/config/WriteConfig.java

Lines changed: 78 additions & 0 deletions
@@ -22,11 +22,13 @@

 import com.mongodb.MongoNamespace;
 import com.mongodb.WriteConcern;
+import com.mongodb.client.MongoCollection;
 import com.mongodb.spark.sql.connector.exceptions.ConfigException;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
+import org.bson.Document;
 import org.jetbrains.annotations.ApiStatus;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -116,6 +118,59 @@ public String toString() {
     }
   }

+  /**
+   * Determines how to truncate a collection when using {@link org.apache.spark.sql.SaveMode#Overwrite}
+   *
+   * @since 10.6
+   */
+  public enum TruncateMode {
+    /**
+     * Drops the collection
+     */
+    DROP("drop") {
+      @Override
+      public void truncate(final WriteConfig writeConfig) {
+        writeConfig.doWithCollection(MongoCollection::drop);
+      }
+    },
+    /**
+     * Deletes all entries in the collection preserving indexes, collection options and any sharding configuration
+     * <p><strong>Warning:</strong> This operation is currently much more expensive than doing a simple drop operation.</p>
+     */
+    TRUNCATE("truncate") {
+      @Override
+      public void truncate(final WriteConfig writeConfig) {
+        writeConfig.doWithCollection(collection -> collection.deleteMany(new Document()));
+      }
+    };
+
+    private final String value;
+
+    TruncateMode(final String value) {
+      this.value = value;
+    }
+
+    static TruncateMode fromString(final String truncateMode) {
+      for (TruncateMode truncateModeType : TruncateMode.values()) {
+        if (truncateMode.equalsIgnoreCase(truncateModeType.value)) {
+          return truncateModeType;
+        }
+      }
+      throw new ConfigException(format("'%s' is not a valid Truncate Mode", truncateMode));
+    }
+
+    /**
+     * The truncation implementation for each different truncation type
+     * @param writeConfig the write config
+     */
+    public abstract void truncate(WriteConfig writeConfig);
+
+    @Override
+    public String toString() {
+      return value;
+    }
+  }
+
   /**
    * The maximum batch size for the batch in the bulk operation.
    *
@@ -243,6 +298,21 @@

   private static final boolean IGNORE_NULL_VALUES_DEFAULT = false;

+  /**
+   * Truncate Mode
+   *
+   * <p>Configuration: {@value}
+   *
+   * <p>Default: {@code Drop}
+   *
+   * <p>Determines how to truncate a collection when using {@link org.apache.spark.sql.SaveMode#Overwrite}
+   *
+   * @since 10.6
+   */
+  public static final String TRUNCATE_MODE_CONFIG = "truncateMode";
+
+  private static final String TRUNCATE_MODE_DEFAULT = TruncateMode.DROP.value;
+
   private final WriteConcern writeConcern;
   private final OperationType operationType;

@@ -319,6 +389,14 @@ public boolean ignoreNullValues() {
     return getBoolean(IGNORE_NULL_VALUES_CONFIG, IGNORE_NULL_VALUES_DEFAULT);
   }

+  /**
+   * @return the truncate mode for use when overwriting collections
+   * @since 10.6
+   */
+  public TruncateMode truncateMode() {
+    return TruncateMode.fromString(getOrDefault(TRUNCATE_MODE_CONFIG, TRUNCATE_MODE_DEFAULT));
+  }
+
   @Override
   CollectionsConfig parseAndValidateCollectionsConfig() {
     CollectionsConfig collectionsConfig = super.parseAndValidateCollectionsConfig();
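
For reference, the two modes boil down to these MongoDB Java driver calls on the target collection (a sketch of what the enum constants above execute via doWithCollection; the collection handle is illustrative):

// DROP (the default): removes the collection entirely, so indexes and collection
// options are lost and the collection is recreated by the subsequent write.
collection.drop();

// TRUNCATE: removes every document but leaves indexes, collection options and any
// sharding configuration in place; currently more expensive than a simple drop.
collection.deleteMany(new org.bson.Document());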

src/main/java/com/mongodb/spark/sql/connector/write/MongoBatchWrite.java

Lines changed: 1 addition & 2 deletions
@@ -19,7 +19,6 @@

 import static java.lang.String.format;

-import com.mongodb.client.MongoCollection;
 import com.mongodb.spark.sql.connector.config.WriteConfig;
 import com.mongodb.spark.sql.connector.exceptions.DataException;
 import java.util.Arrays;
@@ -62,7 +61,7 @@ final class MongoBatchWrite implements BatchWrite {
   @Override
   public DataWriterFactory createBatchWriterFactory(final PhysicalWriteInfo physicalWriteInfo) {
     if (truncate) {
-      writeConfig.doWithCollection(MongoCollection::drop);
+      writeConfig.truncateMode().truncate(writeConfig);
     }
     return new MongoDataWriterFactory(info.schema(), writeConfig);
   }

src/test/java/com/mongodb/spark/sql/connector/config/MongoConfigTest.java

Lines changed: 15 additions & 0 deletions
@@ -159,6 +159,21 @@ void testWriteConfigConvertJson() {
         WriteConfig.ConvertJson.OBJECT_OR_ARRAY_ONLY);
   }

+  @Test
+  void testWriteConfigTruncateMode() {
+    WriteConfig writeConfig = MongoConfig.createConfig(CONFIG_MAP).toWriteConfig();
+    assertEquals(writeConfig.truncateMode(), WriteConfig.TruncateMode.DROP);
+    assertEquals(
+        writeConfig.withOption("TruncateMode", "truncate").truncateMode(),
+        WriteConfig.TruncateMode.TRUNCATE);
+    assertEquals(
+        writeConfig.withOption("TruncateMode", "Drop").truncateMode(),
+        WriteConfig.TruncateMode.DROP);
+    assertThrows(
+        ConfigException.class,
+        () -> writeConfig.withOption("TruncateMode", "RECREATE").truncateMode());
+  }
+
   @Test
   void testMongoConfigOptionsParsing() {
     MongoConfig mongoConfig = MongoConfig.readConfig(OPTIONS_CONFIG_MAP);
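
As the test shows, values are matched case-insensitively and the removed recreate mode now fails fast. A sketch of how that surfaces on the write path (the dataset is illustrative; the message is the one thrown by TruncateMode.fromString):

dataset.write()
    .format("mongodb")
    .mode("Overwrite")
    .option("truncateMode", "recreate") // no longer a supported mode
    .save();
// fails with ConfigException: 'recreate' is not a valid Truncate Mode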
