Skip to content

Commit f9cc7fb

Browse files
ulysses-you authored and srowen committed
[SPARK-36992][SQL] Improve byte array sort perf by unifying the getPrefix function of UTF8String and ByteArray
### What changes were proposed in this pull request? Unify the getPrefix function of `UTF8String` and `ByteArray`. ### Why are the changes needed? When executing the sort operator, we first compare the prefix. However, the getPrefix function of byte array is slow. We use the first 8 bytes as the prefix, so we call `Platform.getByte` up to 8 times, which is slower than a single call to `Platform.getInt` or `Platform.getLong`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass `org.apache.spark.util.collection.unsafe.sort.PrefixComparatorsSuite`. Closes apache#34267 from ulysses-you/binary-prefix. Authored-by: ulysses-you <[email protected]> Signed-off-by: Sean Owen <[email protected]>
1 parent ee2647e commit f9cc7fb

File tree

3 files changed

+81
-44
lines changed

3 files changed

+81
-44
lines changed

common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.unsafe.types;
1919

20+
import java.nio.ByteOrder;
2021
import java.util.Arrays;
2122

2223
import com.google.common.primitives.Ints;
@@ -26,6 +27,8 @@
2627
public final class ByteArray {
2728

2829
public static final byte[] EMPTY_BYTE = new byte[0];
30+
private static final boolean IS_LITTLE_ENDIAN =
31+
ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN;
2932

3033
/**
3134
* Writes the content of a byte array into a memory address, identified by an object and an
@@ -42,15 +45,34 @@ public static void writeToMemory(byte[] src, Object target, long targetOffset) {
4245
public static long getPrefix(byte[] bytes) {
4346
if (bytes == null) {
4447
return 0L;
48+
}
49+
return getPrefix(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length);
50+
}
51+
52+
static long getPrefix(Object base, long offset, int numBytes) {
53+
// Since JVMs are either 4-byte aligned or 8-byte aligned, we check the size of the bytes.
54+
// If size is 0, just return 0.
55+
// If size is between 1 and 4 (inclusive), assume data is 4-byte aligned under the hood and
56+
// use a getInt to fetch the prefix.
57+
// If size is greater than 4, assume we have at least 8 bytes of data to fetch.
58+
// After getting the data, we use a mask to mask out data that is not part of the bytes.
59+
final long p;
60+
final long mask;
61+
if (numBytes >= 8) {
62+
p = Platform.getLong(base, offset);
63+
mask = 0;
64+
} else if (numBytes > 4) {
65+
p = Platform.getLong(base, offset);
66+
mask = (1L << (8 - numBytes) * 8) - 1;
67+
} else if (numBytes > 0) {
68+
long pRaw = Platform.getInt(base, offset);
69+
p = IS_LITTLE_ENDIAN ? pRaw : (pRaw << 32);
70+
mask = (1L << (8 - numBytes) * 8) - 1;
4571
} else {
46-
final int minLen = Math.min(bytes.length, 8);
47-
long p = 0;
48-
for (int i = 0; i < minLen; ++i) {
49-
p |= ((long) Platform.getByte(bytes, Platform.BYTE_ARRAY_OFFSET + i) & 0xff)
50-
<< (56 - 8 * i);
51-
}
52-
return p;
72+
p = 0;
73+
mask = 0;
5374
}
75+
return (IS_LITTLE_ENDIAN ? java.lang.Long.reverseBytes(p) : p) & ~mask;
5476
}
5577

5678
public static byte[] subStringSQL(byte[] bytes, int pos, int len) {

common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -246,43 +246,7 @@ public int numChars() {
246246
* Returns a 64-bit integer that can be used as the prefix used in sorting.
247247
*/
248248
public long getPrefix() {
249-
// Since JVMs are either 4-byte aligned or 8-byte aligned, we check the size of the string.
250-
// If size is 0, just return 0.
251-
// If size is between 0 and 4 (inclusive), assume data is 4-byte aligned under the hood and
252-
// use a getInt to fetch the prefix.
253-
// If size is greater than 4, assume we have at least 8 bytes of data to fetch.
254-
// After getting the data, we use a mask to mask out data that is not part of the string.
255-
long p;
256-
long mask = 0;
257-
if (IS_LITTLE_ENDIAN) {
258-
if (numBytes >= 8) {
259-
p = Platform.getLong(base, offset);
260-
} else if (numBytes > 4) {
261-
p = Platform.getLong(base, offset);
262-
mask = (1L << (8 - numBytes) * 8) - 1;
263-
} else if (numBytes > 0) {
264-
p = (long) Platform.getInt(base, offset);
265-
mask = (1L << (8 - numBytes) * 8) - 1;
266-
} else {
267-
p = 0;
268-
}
269-
p = java.lang.Long.reverseBytes(p);
270-
} else {
271-
// byteOrder == ByteOrder.BIG_ENDIAN
272-
if (numBytes >= 8) {
273-
p = Platform.getLong(base, offset);
274-
} else if (numBytes > 4) {
275-
p = Platform.getLong(base, offset);
276-
mask = (1L << (8 - numBytes) * 8) - 1;
277-
} else if (numBytes > 0) {
278-
p = ((long) Platform.getInt(base, offset)) << 32;
279-
mask = (1L << (8 - numBytes) * 8) - 1;
280-
} else {
281-
p = 0;
282-
}
283-
}
284-
p &= ~mask;
285-
return p;
249+
return ByteArray.getPrefix(base, offset, numBytes);
286250
}
287251

288252
/**
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.unsafe.array;
19+
20+
import org.apache.spark.unsafe.Platform;
21+
import org.apache.spark.unsafe.types.ByteArray;
22+
import org.junit.Assert;
23+
import org.junit.Test;
24+
25+
public class ByteArraySuite {
26+
private long getPrefixByByte(byte[] bytes) {
27+
final int minLen = Math.min(bytes.length, 8);
28+
long p = 0;
29+
for (int i = 0; i < minLen; ++i) {
30+
p |= ((long) Platform.getByte(bytes, Platform.BYTE_ARRAY_OFFSET + i) & 0xff)
31+
<< (56 - 8 * i);
32+
}
33+
return p;
34+
}
35+
36+
@Test
37+
public void testGetPrefix() {
38+
for (int i = 0; i <= 9; i++) {
39+
byte[] bytes = new byte[i];
40+
int prefix = i - 1;
41+
while (prefix >= 0) {
42+
bytes[prefix] = (byte) prefix;
43+
prefix -= 1;
44+
}
45+
46+
long result = ByteArray.getPrefix(bytes);
47+
long expected = getPrefixByByte(bytes);
48+
Assert.assertEquals(result, expected);
49+
}
50+
}
51+
}

0 commit comments

Comments
 (0)