Skip to content

Commit bc5b3d9

Browse files
Make the pulsar client creation factory pluggable (#159)
* fix build
* clean up
* increase test timeout
* use the client factory in the Pulsar sink
1 parent 3f2e904 commit bc5b3d9

File tree

7 files changed

+135
-11
lines changed

7 files changed

+135
-11
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
unittest:
2525
name: Run unit tests
2626
runs-on: ubuntu-latest
27-
timeout-minutes: 30
27+
timeout-minutes: 45
2828
steps:
2929
- uses: actions/checkout@v3
3030
- name: Setup Java
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package org.apache.spark.sql.pulsar
15+
16+
import java.{util => ju}
17+
18+
import org.apache.pulsar.client.impl.PulsarClientImpl
19+
20+
import org.apache.spark.SparkConf
21+
import org.apache.spark.util.Utils
22+
23+
/**
 * Pluggable factory for obtaining a [[PulsarClientImpl]] from a raw client
 * configuration map. The implementation class is selected at runtime via
 * `SparkConf` (see the companion object's `PulsarClientFactoryClassOption`),
 * which lets deployments substitute their own client-creation strategy.
 *
 * Implementations must have a public no-arg constructor — they are
 * instantiated reflectively by `PulsarClientFactory.getFactory`.
 */
trait PulsarClientFactory {
  /**
   * Returns a Pulsar client for the given configuration.
   *
   * @param params Pulsar client configuration as a Java map
   *               (key -> config value).
   * @return a new or cached client; caching policy is implementation-defined.
   */
  def getOrCreate(params: ju.Map[String, Object]): PulsarClientImpl
}
26+
27+
/**
 * Default [[PulsarClientFactory]] used when no custom factory class is
 * configured. Delegates to the process-wide [[CachedPulsarClient]] cache so
 * that identical configurations share a single client instance.
 */
class DefaultPulsarClientFactory extends PulsarClientFactory {

  /** Looks up (or creates and caches) a client for `params` via [[CachedPulsarClient]]. */
  override def getOrCreate(params: ju.Map[String, Object]): PulsarClientImpl =
    CachedPulsarClient.getOrCreate(params)
}
32+
33+
/**
 * Entry point for client creation. Resolves the concrete
 * [[PulsarClientFactory]] from `SparkConf` and delegates to it.
 */
object PulsarClientFactory {

  /** SparkConf key naming the fully-qualified factory class to instantiate. */
  val PulsarClientFactoryClassOption = "org.apache.spark.sql.pulsar.PulsarClientFactoryClass"

  /**
   * Obtains a Pulsar client for `params` using the factory configured in
   * `sparkConf`, falling back to [[DefaultPulsarClientFactory]] when no
   * factory class is set.
   */
  def getOrCreate(sparkConf: SparkConf, params: ju.Map[String, Object]): PulsarClientImpl =
    getFactory(sparkConf).getOrCreate(params)

  /**
   * Resolves the factory: reflectively instantiates the configured class via
   * its no-arg constructor, or returns the default factory when the option is
   * absent. NOTE(review): a fresh factory instance is built on every call —
   * presumably factories are cheap and stateless; confirm before caching.
   */
  private def getFactory(sparkConf: SparkConf): PulsarClientFactory =
    sparkConf
      .getOption(PulsarClientFactoryClassOption)
      .map { factoryClassName =>
        Utils
          .classForName(factoryClassName)
          .getConstructor()
          .newInstance()
          .asInstanceOf[PulsarClientFactory]
      }
      .getOrElse(new DefaultPulsarClientFactory())
}
48+
49+

src/main/scala/org/apache/spark/sql/pulsar/PulsarHelper.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import org.apache.pulsar.common.naming.TopicName
3232
import org.apache.pulsar.common.schema.SchemaInfo
3333
import org.apache.pulsar.shade.com.google.common.util.concurrent.Uninterruptibles
3434

35+
import org.apache.spark.SparkContext
3536
import org.apache.spark.internal.Logging
3637
import org.apache.spark.sql.connector.read.streaming
3738
import org.apache.spark.sql.connector.read.streaming.{ReadAllAvailable, ReadLimit}
@@ -52,13 +53,15 @@ private[pulsar] case class PulsarHelper(
5253
driverGroupIdPrefix: String,
5354
caseInsensitiveParameters: Map[String, String],
5455
allowDifferentTopicSchemas: Boolean,
55-
predefinedSubscription: Option[String])
56+
predefinedSubscription: Option[String],
57+
sparkContext: SparkContext)
5658
extends Closeable
5759
with Logging {
5860

5961
import scala.collection.JavaConverters._
6062

61-
protected var client: PulsarClientImpl = CachedPulsarClient.getOrCreate(clientConf)
63+
protected var client: PulsarClientImpl =
64+
PulsarClientFactory.getOrCreate(sparkContext.conf, clientConf)
6265

6366
private var topics: Seq[String] = _
6467
private var topicPartitions: Seq[String] = _

src/main/scala/org/apache/spark/sql/pulsar/PulsarProvider.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ private[pulsar] class PulsarProvider
6969
subscriptionNamePrefix,
7070
caseInsensitiveParams,
7171
getAllowDifferentTopicSchemas(parameters),
72-
getPredefinedSubscription(parameters))) { pulsarHelper =>
72+
getPredefinedSubscription(parameters),
73+
sqlContext.sparkContext)) { pulsarHelper =>
7374
pulsarHelper.getAndCheckCompatible(schema)
7475
}
7576

@@ -102,7 +103,8 @@ private[pulsar] class PulsarProvider
102103
subscriptionNamePrefix,
103104
caseInsensitiveParams,
104105
getAllowDifferentTopicSchemas(parameters),
105-
getPredefinedSubscription(parameters))
106+
getPredefinedSubscription(parameters),
107+
sqlContext.sparkContext)
106108

107109
val pSchema = pulsarHelper.getAndCheckCompatible(schema)
108110
logDebug(s"Schema from Spark: $schema; Schema from Pulsar: ${pSchema}")
@@ -151,7 +153,8 @@ private[pulsar] class PulsarProvider
151153
subscriptionNamePrefix,
152154
caseInsensitiveParams,
153155
getAllowDifferentTopicSchemas(parameters),
154-
getPredefinedSubscription(parameters))) { pulsarHelper =>
156+
getPredefinedSubscription(parameters),
157+
sqlContext.sparkContext)) { pulsarHelper =>
155158
val perTopicStarts =
156159
pulsarHelper.offsetForEachTopic(caseInsensitiveParams, EarliestOffset, StartOptionKey)
157160
val startingOffset = SpecificPulsarOffset(

src/main/scala/org/apache/spark/sql/pulsar/PulsarSinks.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import scala.util.control.NonFatal
2020

2121
import org.apache.pulsar.client.api.{Producer, PulsarClientException, Schema}
2222

23+
import org.apache.spark.SparkEnv
2324
import org.apache.spark.internal.Logging
2425
import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession, SQLContext}
2526
import org.apache.spark.sql.catalyst.expressions
@@ -161,8 +162,8 @@ private[pulsar] object PulsarSinks extends Logging {
161162
schema: Schema[T]): Producer[T] = {
162163

163164
try {
164-
CachedPulsarClient
165-
.getOrCreate(clientConf)
165+
PulsarClientFactory
166+
.getOrCreate(SparkEnv.get.conf, clientConf)
166167
.newProducer(schema)
167168
.topic(topic)
168169
.loadConf(producerConf)

src/main/scala/org/apache/spark/sql/pulsar/PulsarSourceRDD.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import java.util.concurrent.TimeUnit
1919
import org.apache.pulsar.client.api.{Message, MessageId, PulsarClientException, Schema}
2020
import org.apache.pulsar.client.impl.{BatchMessageIdImpl, MessageIdImpl}
2121

22-
import org.apache.spark.{Partition, SparkContext, TaskContext}
22+
import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext}
2323
import org.apache.spark.rdd.RDD
2424
import org.apache.spark.sql.catalyst.InternalRow
2525
import org.apache.spark.sql.catalyst.json.JSONOptionsInRead
@@ -58,8 +58,8 @@ private[pulsar] abstract class PulsarSourceRDDBase(
5858
val deserializer = new PulsarDeserializer(schemaInfo.si, jsonOptions)
5959
val schema: Schema[_] = SchemaUtils.getPSchema(schemaInfo.si)
6060

61-
lazy val reader = CachedPulsarClient
62-
.getOrCreate(clientConf)
61+
lazy val reader = PulsarClientFactory
62+
.getOrCreate(SparkEnv.get.conf, clientConf)
6363
.newReader(schema)
6464
.subscriptionRolePrefix(subscriptionNamePrefix)
6565
.topic(topic)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package org.apache.spark.sql.pulsar
2+
3+
import org.apache.pulsar.client.impl.PulsarClientImpl
4+
import org.apache.spark.sql.pulsar.PulsarOptions.{ServiceUrlOptionKey, TopicPattern}
5+
import org.apache.spark.sql.pulsar.TestPulsarClientFactory.counter
6+
7+
import java.{util => ju}
8+
9+
/**
 * Test-only [[PulsarClientFactory]] that counts how many times it is asked
 * for a client, then delegates to the default factory. The count lives in the
 * companion object so the suite can observe it across reflective instantiations.
 */
class TestPulsarClientFactory extends PulsarClientFactory {

  override def getOrCreate(params: ju.Map[String, Object]): PulsarClientImpl = {
    // Record the invocation, then hand off to the real (default) factory.
    TestPulsarClientFactory.counter += 1
    new DefaultPulsarClientFactory().getOrCreate(params)
  }
}
15+
16+
object TestPulsarClientFactory {
  // Number of getOrCreate calls observed by TestPulsarClientFactory instances.
  // NOTE(review): plain var, not synchronized — assumes client creation in
  // these tests is effectively single-threaded; confirm if tests parallelize.
  var counter = 0
}
19+
20+
/**
 * Verifies that the Pulsar client factory is pluggable via SparkConf:
 * when `PulsarClientFactoryClassOption` is set, the configured class is used
 * to create clients; when it is unset, the default factory is used.
 */
class PulsarClientFactorySuite extends PulsarSourceTest {

  test("Set Pulsar client factory class") {
    // Point client creation at the counting test factory.
    sparkContext.conf.set(
      PulsarClientFactory.PulsarClientFactoryClassOption,
      "org.apache.spark.sql.pulsar.TestPulsarClientFactory")
    try {
      val topic = newTopic()
      sendMessages(topic, (101 to 105).map(_.toString).toArray)

      val reader = spark.readStream
        .format("pulsar")
        .option(ServiceUrlOptionKey, serviceUrl)
        .option(TopicPattern, s"$topic.*")

      val pulsar = reader
        .load()
        .selectExpr("CAST(__key AS STRING)", "CAST(value AS STRING)")

      testStream(pulsar)(
        makeSureGetOffsetCalled,
        StopStream
      )
      // The stream must have created its client through TestPulsarClientFactory.
      assert(TestPulsarClientFactory.counter > 0)
    } finally {
      // Restore the shared SparkConf so this test cannot leak the custom
      // factory into later tests (previously this relied on the next test
      // removing the key, making the suite order-dependent).
      sparkContext.conf.remove(PulsarClientFactory.PulsarClientFactoryClassOption)
    }
  }

  test("Unset Pulsar client factory class") {
    // Ensure no custom factory is configured (idempotent if already absent).
    sparkContext.conf.remove(PulsarClientFactory.PulsarClientFactoryClassOption)
    val oldCount = TestPulsarClientFactory.counter
    val topic = newTopic()
    sendMessages(topic, (101 to 105).map(_.toString).toArray)

    val reader = spark.readStream
      .format("pulsar")
      .option(ServiceUrlOptionKey, serviceUrl)
      .option(TopicPattern, s"$topic.*")

    val pulsar = reader
      .load()
      .selectExpr("CAST(__key AS STRING)", "CAST(value AS STRING)")

    testStream(pulsar)(
      makeSureGetOffsetCalled,
      StopStream
    )

    val newCount = TestPulsarClientFactory.counter
    // The count doesn't change because we are using the default factory.
    assert(oldCount == newCount)
  }
}

0 commit comments

Comments
 (0)