Skip to content

Commit c32aee1

Browse files
dillitz authored and hvanhovell committed
[SPARK-55243][CONNECT] Allow setting binary headers via the -bin suffix in the Scala Connect client
### What changes were proposed in this pull request? Automatically use the `Metadata.BINARY_BYTE_MARSHALLER` for `-bin` suffixed header keys, assuming base64-encoded header value strings set through the Scala Spark Connect client builder. ### Why are the changes needed? The Scala Spark Connect client currently only allows setting `Metadata.ASCII_STRING_MARSHALLER` headers and fails if one tries to put a (binary) header with `-bin` key suffix: ``` [info] java.lang.IllegalArgumentException: ASCII header is named test-bin. Only binary headers may end with -bin [info] at com.google.common.base.Preconditions.checkArgument(Preconditions.java:445) [info] at io.grpc.Metadata$AsciiKey.<init>(Metadata.java:972) [info] at io.grpc.Metadata$AsciiKey.<init>(Metadata.java:966) [info] at io.grpc.Metadata$Key.of(Metadata.java:708) [info] at io.grpc.Metadata$Key.of(Metadata.java:704) [info] at org.apache.spark.sql.connect.client.SparkConnectClient$MetadataHeaderClientInterceptor$$anon$2.$anonfun$start$1(SparkConnectClient.scala:1112) [info] at org.apache.spark.sql.connect.client.SparkConnectClient$MetadataHeaderClientInterceptor$$anon$2.$anonfun$start$1$adapted(SparkConnectClient.scala:1106) [info] at scala.collection.immutable.Map$Map1.foreach(Map.scala:278) [info] at org.apache.spark.sql.connect.client.SparkConnectClient$MetadataHeaderClientInterceptor$$anon$2.start(SparkConnectClient.scala:1106) [info] at io.grpc.stub.ClientCalls.startCall(ClientCalls.java:435) ``` ### Does this PR introduce _any_ user-facing change? Current behaviour: Fails for all header key-value pairs if the key has the `-bin` suffix with an `IllegalArgumentException`. New behaviour: Adds a `Metadata.BINARY_BYTE_MARSHALLER` header if the key has a `-bin` suffix and the value string is base64-encoded. ### How was this patch tested? Added a test to `SparkConnectClientSuite`. ### Was this patch authored or co-authored using generative AI tooling? No. Closes apache#54016 from dillitz/fix-bin-header. 
Authored-by: Robert Dillitz <r.dillitz@gmail.com> Signed-off-by: Herman van Hövell <herman@databricks.com>
1 parent 5c320f4 commit c32aee1

File tree

2 files changed

+84
-7
lines changed

2 files changed

+84
-7
lines changed

sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala

Lines changed: 73 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@
1616
*/
1717
package org.apache.spark.sql.connect.client
1818

19-
import java.util.UUID
19+
import java.nio.charset.StandardCharsets.UTF_8
20+
import java.util.{Base64, UUID}
2021
import java.util.concurrent.TimeUnit
2122

2223
import scala.collection.mutable
2324
import scala.jdk.CollectionConverters._
2425

25-
import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor, Server, Status, StatusRuntimeException}
26+
import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, Metadata, MethodDescriptor, Server, ServerCall, ServerCallHandler, ServerInterceptor, Status, StatusRuntimeException}
2627
import io.grpc.netty.NettyServerBuilder
2728
import io.grpc.stub.StreamObserver
2829
import org.scalatest.concurrent.Eventually
@@ -42,12 +43,13 @@ class SparkConnectClientSuite extends ConnectFunSuite {
4243
private var service: DummySparkConnectService = _
4344
private var server: Server = _
4445

45-
// Boots the in-process dummy Connect service on the given port (0 picks a free
// one). Any supplied gRPC server interceptors are registered on the server so
// tests can observe incoming calls/headers.
private def startDummyServer(port: Int, interceptors: Seq[ServerInterceptor] = Seq()): Unit = {
  service = new DummySparkConnectService
  val builder = NettyServerBuilder.forPort(port).addService(service)
  interceptors.foreach(builder.intercept)
  server = builder.build()
  server.start()
}
5355

@@ -622,6 +624,72 @@ class SparkConnectClientSuite extends ConnectFunSuite {
622624
// The client should try to fetch the config only once.
623625
assert(service.getAndClearLatestConfigRequests().size == 1)
624626
}
627+
628+
test("SPARK-55243: Binary headers use the correct marshaller") {
  // Server-side interceptor that captures the metadata of the last call so the
  // test can inspect which headers (and marshaller types) actually arrived.
  class HeadersInterceptor extends ServerInterceptor {
    var headers: Option[Metadata] = None

    override def interceptCall[ReqT, RespT](
        call: ServerCall[ReqT, RespT],
        headers: Metadata,
        next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
      this.headers = Some(headers)
      next.startCall(call, headers)
    }
  }

  // Builds a fresh client against the running dummy server carrying a single
  // custom header option.
  def buildClientWithHeader(key: String, value: String): SparkConnectClient = {
    SparkConnectClient
      .builder()
      .connectionString(s"sc://localhost:${server.getPort}")
      .option(key, value)
      .build()
  }

  val capture = new HeadersInterceptor()
  startDummyServer(0, Seq(capture))

  val binKeyName = "test-bin"
  val binKey = Metadata.Key.of(binKeyName, Metadata.BINARY_BYTE_MARSHALLER)
  val rawValue = "test-binary-data"
  val encodedValue = Base64.getEncoder.encodeToString(rawValue.getBytes(UTF_8))

  val plan = buildPlan("select * from range(10)")

  // A base64-encoded value under a -bin key must reach the server as the
  // decoded bytes (i.e. the client used the binary marshaller).
  client = buildClientWithHeader(binKeyName, encodedValue)
  client.execute(plan)

  Eventually.eventually(timeout(5.seconds)) {
    val seen = capture.headers
    assert(seen.exists(_.containsKey(binKey)))
    assert(new String(seen.get.get(binKey), UTF_8) == rawValue)
  }

  // A -bin key whose value is not valid base64 must be rejected.
  client = buildClientWithHeader(binKeyName, rawValue)
  assertThrows[IllegalArgumentException] {
    client.execute(plan)
  }

  // A key without the -bin suffix must still travel as a plain ASCII header,
  // with its value passed through untouched and no binary twin created.
  val asciiKeyName = "test"
  val asciiKey = Metadata.Key.of(asciiKeyName, Metadata.ASCII_STRING_MARSHALLER)

  capture.headers = None // Reset captured headers.

  client = buildClientWithHeader(asciiKeyName, encodedValue)
  client.execute(plan)

  Eventually.eventually(timeout(5.seconds)) {
    val seen = capture.headers
    assert(seen.exists(_.containsKey(asciiKey)))
    assert(seen.get.get(asciiKey) == encodedValue)
    // No BINARY_BYTE_MARSHALLER header.
    assert(!seen.exists(_.containsKey(binKey)))
  }
}
625693
}
626694

627695
class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase {

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
package org.apache.spark.sql.connect.client
1919

2020
import java.net.URI
21-
import java.util.{Locale, UUID}
21+
import java.nio.charset.StandardCharsets.UTF_8
22+
import java.util.{Base64, Locale, UUID}
2223
import java.util.concurrent.Executor
2324

2425
import scala.collection.mutable
@@ -1093,6 +1094,8 @@ object SparkConnectClient {
10931094
*/
10941095
private[client] class MetadataHeaderClientInterceptor(metadata: Map[String, String])
10951096
extends ClientInterceptor {
1097+
metadata.foreach { case (key, value) => assert(key != null && value != null) }
1098+
10961099
override def interceptCall[ReqT, RespT](
10971100
method: MethodDescriptor[ReqT, RespT],
10981101
callOptions: CallOptions,
@@ -1103,7 +1106,13 @@ object SparkConnectClient {
11031106
responseListener: ClientCall.Listener[RespT],
11041107
headers: Metadata): Unit = {
11051108
metadata.foreach { case (key, value) =>
1106-
headers.put(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER), value)
1109+
if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
1110+
// Expects a base64-encoded value string.
1111+
val valueByteArray = Base64.getDecoder.decode(value.getBytes(UTF_8))
1112+
headers.put(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER), valueByteArray)
1113+
} else {
1114+
headers.put(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER), value)
1115+
}
11071116
}
11081117
super.start(responseListener, headers)
11091118
}

0 commit comments

Comments
 (0)