Skip to content

Commit 6d71a5b

Browse files
committed
Work in progress - Initial support for SIMD in the Java module.
1 parent c5af1b6 commit 6d71a5b

File tree

5 files changed

+260
-26
lines changed

5 files changed

+260
-26
lines changed

Rakefile

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ if defined?(RUBY_ENGINE) and RUBY_ENGINE == 'jruby'
6868
classpath = (Dir['java/lib/*.jar'] << 'java/src' << JRUBY_JAR) * ':'
6969
obj = src.sub(/\.java\Z/, '.class')
7070
file obj => src do
71-
sh 'javac', '-classpath', classpath, '-source', '1.8', '-target', '1.8', src
71+
sh 'javac', '--enable-preview', '--add-modules', 'jdk.incubator.vector', '-classpath', classpath, '-source', '21', '-target', '21', src
7272
end
7373
JAVA_CLASSES << obj
7474
end
@@ -117,11 +117,14 @@ if defined?(RUBY_ENGINE) and RUBY_ENGINE == 'jruby'
117117
generator_classes = FileList[
118118
"json/ext/ByteList*.class",
119119
"json/ext/OptionsReader*.class",
120+
"json/ext/EscapeScanner*.class",
120121
"json/ext/Generator*.class",
121122
"json/ext/RuntimeInfo*.class",
122123
"json/ext/StringEncoder*.class",
123-
"json/ext/Utils*.class"
124+
"json/ext/Utils*.class",
125+
"json/ext/VectorizedEscapeScanner*.class"
124126
]
127+
puts "Creating generator jar with classes: #{generator_classes.join(', ')}"
125128
sh 'jar', 'cf', File.basename(JRUBY_GENERATOR_JAR), *generator_classes
126129
mv File.basename(JRUBY_GENERATOR_JAR), File.dirname(JRUBY_GENERATOR_JAR)
127130
end

java/src/json/ext/EscapeScanner.java

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package json.ext;
2+
3+
import java.lang.reflect.Constructor;
4+
import java.lang.reflect.InvocationTargetException;
5+
import java.util.Optional;
6+
7+
interface EscapeScanner {
8+
static class State {
9+
byte[] ptrBytes;
10+
int ptr;
11+
int len;
12+
int pos;
13+
int beg;
14+
int ch;
15+
}
16+
17+
static class VectorSupport {
18+
static Constructor<?> vectorizedEscapeScannerConstructor = null;
19+
20+
static {
21+
Optional<Module> vectorModule = ModuleLayer.boot().findModule("jdk.incubator.vector");
22+
if (vectorModule.isPresent()) {
23+
try {
24+
Class<?> vectorEscapeScannerClass = EscapeScanner.class.getClassLoader().loadClass("json.ext.VectorizedEscapeScanner");
25+
vectorizedEscapeScannerConstructor = vectorEscapeScannerClass.getDeclaredConstructor();
26+
} catch (ClassNotFoundException | NoSuchMethodException e) {
27+
// Fallback to the ScalarEscapeScanner if we cannot load the VectorizedEscapeScanner.
28+
System.err.println("Failed to load VectorizedEscapeScanner, falling back to ScalarEscapeScanner: " + e.getMessage());
29+
}
30+
}
31+
}
32+
}
33+
34+
boolean scan(EscapeScanner.State state) throws java.io.IOException;
35+
36+
public static EscapeScanner basicScanner() {
37+
if (VectorSupport.vectorizedEscapeScannerConstructor != null) {
38+
try {
39+
// Attempt to instantiate the vectorized escape scanner if available.
40+
return (EscapeScanner) VectorSupport.vectorizedEscapeScannerConstructor.newInstance();
41+
} catch (InstantiationException | IllegalAccessException | InvocationTargetException e) {
42+
System.err.println("Failed to instantiate VectorizedEscapeScanner, falling back to ScalarEscapeScanner: " + e.getMessage());
43+
}
44+
45+
}
46+
47+
return new ScalarEscapeScanner(StringEncoder.ESCAPE_TABLE);
48+
}
49+
50+
public static EscapeScanner create(byte[] escapeTable) {
51+
return new ScalarEscapeScanner(escapeTable);
52+
}
53+
54+
public static class ScalarEscapeScanner implements EscapeScanner {
55+
private final byte[] escapeTable;
56+
57+
public ScalarEscapeScanner(byte[] escapeTable) {
58+
this.escapeTable = escapeTable;
59+
}
60+
61+
@Override
62+
public boolean scan(EscapeScanner.State state) throws java.io.IOException {
63+
while (state.pos < state.len) {
64+
state.ch = Byte.toUnsignedInt(state.ptrBytes[state.ptr + state.pos]);
65+
int ch_len = escapeTable[state.ch];
66+
if (ch_len > 0) {
67+
return true;
68+
}
69+
state.pos++;
70+
}
71+
return false;
72+
}
73+
74+
}
75+
}

java/src/json/ext/Generator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ public StringEncoder getStringEncoder(ThreadContext context) {
232232
GeneratorState state = getState(context);
233233
stringEncoder = state.asciiOnly() ?
234234
new StringEncoderAsciiOnly(state.scriptSafe()) :
235-
new StringEncoder(state.scriptSafe());
235+
state.scriptSafe() ? StringEncoder.scriptSafeEncoder() : StringEncoder.basicEncoder();
236236
}
237237
return stringEncoder;
238238
}

java/src/json/ext/StringEncoder.java

Lines changed: 122 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
*/
66
package json.ext;
77

8+
import java.io.IOException;
9+
import java.io.OutputStream;
10+
import java.nio.charset.StandardCharsets;
11+
812
import org.jcodings.Encoding;
913
import org.jcodings.specific.ASCIIEncoding;
1014
import org.jcodings.specific.USASCIIEncoding;
@@ -17,9 +21,9 @@
1721
import org.jruby.util.ByteList;
1822
import org.jruby.util.StringSupport;
1923

20-
import java.io.IOException;
21-
import java.io.OutputStream;
22-
import java.nio.charset.StandardCharsets;
24+
import jdk.incubator.vector.ByteVector;
25+
import jdk.incubator.vector.VectorSpecies;
26+
import json.ext.VectorizedEscapeScanner;
2327

2428
/**
2529
* An encoder that reads from the given source and outputs its representation
@@ -130,14 +134,22 @@ class StringEncoder extends ByteListTranscoder {
130134
new byte[] {'0', '1', '2', '3', '4', '5', '6', '7',
131135
'8', '9', 'a', 'b', 'c', 'd', 'e', 'f'};
132136

133-
StringEncoder(boolean scriptSafe) {
137+
private StringEncoder(boolean scriptSafe) {
134138
this(scriptSafe ? SCRIPT_SAFE_ESCAPE_TABLE : ESCAPE_TABLE);
135139
}
136140

137141
StringEncoder(byte[] escapeTable) {
138142
this.escapeTable = escapeTable;
139143
}
140144

145+
/**
 * Creates an encoder whose escape table is {@code SCRIPT_SAFE_ESCAPE_TABLE}.
 *
 * @return a new {@code StringEncoder} instance
 */
public static StringEncoder scriptSafeEncoder() {
146+
return new StringEncoder(SCRIPT_SAFE_ESCAPE_TABLE);
147+
}
148+
149+
/**
 * Creates an encoder whose escape table is {@code ESCAPE_TABLE}. Because
 * {@code encode} compares the table by identity against {@code ESCAPE_TABLE},
 * encoders built here take the {@code encodeBasic} path.
 *
 * @return a new {@code StringEncoder} instance
 */
public static StringEncoder basicEncoder() {
150+
return new StringEncoder(ESCAPE_TABLE);
151+
}
152+
141153
// C: generate_json_string
142154
void generate(ThreadContext context, RubyString object, OutputStream buffer) throws IOException {
143155
object = ensureValidEncoding(context, object);
@@ -198,41 +210,89 @@ private static RubyString tryWeirdEncodings(ThreadContext context, RubyString st
198210
return str;
199211
}
200212

213+
boolean searchEscape(EscapeScanner.State state) throws IOException {
214+
byte[] escapeTable = StringEncoder.this.escapeTable;
215+
216+
while (state.pos < state.len) {
217+
state.ch = Byte.toUnsignedInt(state.ptrBytes[state.ptr + state.pos]);
218+
int ch_len = escapeTable[state.ch];
219+
220+
if (ch_len > 0) {
221+
return true;
222+
}
223+
224+
state.pos++;
225+
}
226+
227+
return false;
228+
}
229+
230+
void encodeBasic(ByteList src) throws IOException {
231+
EscapeScanner.State state = new EscapeScanner.State();
232+
state.ptrBytes = src.unsafeBytes();
233+
state.ptr = src.begin();
234+
state.len = src.realSize();
235+
state.beg = 0;
236+
state.pos = 0;
237+
238+
byte[] hexdig = HEX;
239+
byte[] scratch = aux;
240+
241+
EscapeScanner scanner = EscapeScanner.basicScanner();
242+
243+
while(scanner.scan(state)) {
244+
int ch = Byte.toUnsignedInt(state.ptrBytes[state.ptr + state.pos]);
245+
state.beg = state.pos = flushPos(state.pos, state.beg, state.ptrBytes, state.ptr, 1);
246+
escapeAscii(ch, scratch, hexdig);
247+
}
248+
249+
if (state.beg < state.len) {
250+
append(state.ptrBytes, state.ptr + state.beg, state.len - state.beg);
251+
}
252+
}
253+
201254
// C: convert_UTF8_to_JSON
202255
void encode(ByteList src) throws IOException {
256+
if (this.escapeTable == StringEncoder.ESCAPE_TABLE) {
257+
encodeBasic(src);
258+
return;
259+
}
260+
203261
byte[] hexdig = HEX;
204262
byte[] scratch = aux;
205263
byte[] escapeTable = this.escapeTable;
206264

207-
byte[] ptrBytes = src.unsafeBytes();
208-
int ptr = src.begin();
209-
int len = src.realSize();
210-
211-
int beg = 0;
212-
int pos = 0;
213-
214-
while (pos < len) {
215-
int ch = Byte.toUnsignedInt(ptrBytes[ptr + pos]);
265+
EscapeScanner.State state = new EscapeScanner.State();
266+
state.ptrBytes = src.unsafeBytes();
267+
state.ptr = src.begin();
268+
state.len = src.realSize();
269+
state.beg = 0;
270+
state.pos = 0;
271+
272+
while(searchEscape(state)) {
273+
// We found an escape character, so we need to flush up to this point
274+
// and then handle the escape character.
275+
state.beg = flushPos(state.pos, state.beg, state.ptrBytes, state.ptr, 0);
276+
int ch = Byte.toUnsignedInt(state.ptrBytes[state.ptr + state.pos]);
216277
int ch_len = escapeTable[ch];
217-
/* JSON encoding */
218278

219279
if (ch_len > 0) {
220280
switch (ch_len) {
221281
case 9: {
222-
beg = pos = flushPos(pos, beg, ptrBytes, ptr, 1);
282+
state.beg = state.pos = flushPos(state.pos, state.beg, state.ptrBytes, state.ptr, 1);
223283
escapeAscii(ch, scratch, hexdig);
224284
break;
225285
}
226286
case 11: {
227-
int b2 = Byte.toUnsignedInt(ptrBytes[ptr + pos + 1]);
287+
int b2 = Byte.toUnsignedInt(state.ptrBytes[state.ptr + state.pos + 1]);
228288
if (b2 == 0x80) {
229-
int b3 = Byte.toUnsignedInt(ptrBytes[ptr + pos + 2]);
289+
int b3 = Byte.toUnsignedInt(state.ptrBytes[state.ptr + state.pos + 2]);
230290
if (b3 == 0xA8) {
231-
beg = pos = flushPos(pos, beg, ptrBytes, ptr, 3);
291+
state.beg = state.pos = flushPos(state.pos, state.beg, state.ptrBytes, state.ptr, 3);
232292
append(BACKSLASH_U2028, 0, 6);
233293
break;
234294
} else if (b3 == 0xA9) {
235-
beg = pos = flushPos(pos, beg, ptrBytes, ptr, 3);
295+
state.beg = state.pos = flushPos(state.pos, state.beg, state.ptrBytes, state.ptr, 3);
236296
append(BACKSLASH_U2029, 0, 6);
237297
break;
238298
}
@@ -241,16 +301,55 @@ void encode(ByteList src) throws IOException {
241301
// fallthrough
242302
}
243303
default:
244-
pos += ch_len;
304+
state.pos += ch_len;
245305
break;
246306
}
247307
} else {
248-
pos++;
308+
// This should be unreachable.
309+
state.pos++;
249310
}
250311
}
251312

252-
if (beg < len) {
253-
append(ptrBytes, ptr + beg, len - beg);
313+
// while (state.pos < state.len) {
314+
// int ch = Byte.toUnsignedInt(state.ptrBytes[state.ptr + state.pos]);
315+
// int ch_len = escapeTable[ch];
316+
// /* JSON encoding */
317+
318+
// if (ch_len > 0) {
319+
// switch (ch_len) {
320+
// case 9: {
321+
// state.beg = state.pos = flushPos(state.pos, state.beg, state.ptrBytes, state.ptr, 1);
322+
// escapeAscii(ch, scratch, hexdig);
323+
// break;
324+
// }
325+
// case 11: {
326+
// int b2 = Byte.toUnsignedInt(state.ptrBytes[state.ptr + state.pos + 1]);
327+
// if (b2 == 0x80) {
328+
// int b3 = Byte.toUnsignedInt(state.ptrBytes[state.ptr + state.pos + 2]);
329+
// if (b3 == 0xA8) {
330+
// state.beg = state.pos = flushPos(state.pos, state.beg, state.ptrBytes, state.ptr, 3);
331+
// append(BACKSLASH_U2028, 0, 6);
332+
// break;
333+
// } else if (b3 == 0xA9) {
334+
// state.beg = state.pos = flushPos(state.pos, state.beg, state.ptrBytes, state.ptr, 3);
335+
// append(BACKSLASH_U2029, 0, 6);
336+
// break;
337+
// }
338+
// }
339+
// ch_len = 3;
340+
// // fallthrough
341+
// }
342+
// default:
343+
// state.pos += ch_len;
344+
// break;
345+
// }
346+
// } else {
347+
// state.pos++;
348+
// }
349+
// }
350+
351+
if (state.beg < state.len) {
352+
append(state.ptrBytes, state.ptr + state.beg, state.len - state.beg);
254353
}
255354
}
256355

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package json.ext;
2+
3+
import java.io.IOException;
4+
5+
import jdk.incubator.vector.ByteVector;
6+
import jdk.incubator.vector.VectorMask;
7+
import jdk.incubator.vector.VectorOperators;
8+
import jdk.incubator.vector.VectorSpecies;
9+
10+
public class VectorizedEscapeScanner implements EscapeScanner {
11+
public static EscapeScanner.ScalarEscapeScanner FALLBACK = new EscapeScanner.ScalarEscapeScanner(StringEncoder.ESCAPE_TABLE);
12+
13+
// private VectorMask<Byte> needsEscape = null;
14+
// private int chunkStart = 0;
15+
16+
@Override
17+
public boolean scan(State state) throws IOException {
18+
VectorSpecies<Byte> species = ByteVector.SPECIES_PREFERRED;
19+
20+
// if (needsEscape != null) {
21+
// if (needsEscape.anyTrue()) {
22+
// int firstEscapeIndex = needsEscape.firstTrue();
23+
// needsEscape = needsEscape.andNot(VectorMask.fromLong(species, 1L << firstEscapeIndex));
24+
// state.pos = chunkStart + firstEscapeIndex;
25+
// return true;
26+
// } else {
27+
// needsEscape = null;
28+
// }
29+
// }
30+
31+
while ((state.ptr + state.pos) + species.length() < state.len) {
32+
ByteVector chunk = ByteVector.fromArray(species, state.ptrBytes, state.ptr + state.pos);
33+
ByteVector zero = ByteVector.broadcast(species, 0);
34+
35+
// bytes are unsigned in java, so we need to check for negative values
36+
// to determine if we have a byte that is less than 0 (>= 128).
37+
VectorMask<Byte> negative = zero.lt(chunk);
38+
39+
VectorMask<Byte> tooLowOrDblQuote = chunk.lanewise(VectorOperators.XOR, ByteVector.broadcast(species, 2))
40+
.lt(ByteVector.broadcast(species, 33));
41+
42+
VectorMask<Byte> needsEscape = chunk.eq(ByteVector.broadcast(species, '\\')).or(tooLowOrDblQuote).and(negative);
43+
if (needsEscape.anyTrue()) {
44+
// chunkStart = state.ptr + state.pos;
45+
int firstEscapeIndex = needsEscape.firstTrue();
46+
// Clear the bit at firstEscapeIndex to avoid scanning the same byte again
47+
// needsEscape = needsEscape.andNot(VectorMask.fromLong(species, 1L << firstEscapeIndex));
48+
state.pos += firstEscapeIndex;
49+
return true;
50+
}
51+
52+
state.pos += species.length();
53+
}
54+
55+
return FALLBACK.scan(state);
56+
}
57+
}

0 commit comments

Comments
 (0)