Skip to content

Commit 0b8f52b

Browse files
GH-2375 - Make LiteralReplacement-Cache threadsafe.
Fixes #2375.
1 parent 50e3d66 commit 0b8f52b

File tree

2 files changed

+109
-7
lines changed

2 files changed

+109
-7
lines changed

src/main/java/org/springframework/data/neo4j/repository/query/Neo4jSpelSupport.java

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.util.LinkedHashMap;
1919
import java.util.Locale;
2020
import java.util.Map;
21+
import java.util.concurrent.locks.StampedLock;
2122
import java.util.regex.Matcher;
2223
import java.util.regex.Pattern;
2324
import java.util.stream.Collectors;
@@ -123,12 +124,39 @@ protected boolean removeEldestEntry(Map.Entry<String, LiteralReplacement> eldest
123124
}
124125
};
125126

127+
private static final StampedLock LOCK = new StampedLock();
128+
126129
static LiteralReplacement withTargetAndValue(LiteralReplacement.Target target, @Nullable String value) {
127130

128131
String valueUsed = value == null ? "" : value;
129-
StringBuilder key = new StringBuilder(target.name()).append("_").append(valueUsed);
130-
131-
return INSTANCES.computeIfAbsent(key.toString(), k -> new StringBasedLiteralReplacement(target, valueUsed));
132+
String key = new StringBuilder(target.name()).append("_").append(valueUsed).toString();
133+
134+
long stamp = LOCK.tryOptimisticRead();
135+
if (LOCK.validate(stamp) && INSTANCES.containsKey(key)) {
136+
return INSTANCES.get(key);
137+
}
138+
try {
139+
stamp = LOCK.readLock();
140+
LiteralReplacement replacement = null;
141+
while (replacement == null) {
142+
if (INSTANCES.containsKey(key)) {
143+
replacement = INSTANCES.get(key);
144+
} else {
145+
long writeStamp = LOCK.tryConvertToWriteLock(stamp);
146+
if (LOCK.validate(writeStamp)) {
147+
replacement = new StringBasedLiteralReplacement(target, valueUsed);
148+
stamp = writeStamp;
149+
INSTANCES.put(key, replacement);
150+
} else {
151+
LOCK.unlockRead(stamp);
152+
stamp = LOCK.writeLock();
153+
}
154+
}
155+
}
156+
return replacement;
157+
} finally {
158+
LOCK.unlock(stamp);
159+
}
132160
}
133161

134162
private final Target target;

src/test/java/org/springframework/data/neo4j/repository/query/Neo4jSpelSupportTest.java

Lines changed: 78 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,21 @@
1717

1818
import static org.assertj.core.api.Assertions.assertThat;
1919
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
20+
import static org.assertj.core.api.Assumptions.assumeThat;
2021

22+
import java.lang.reflect.Field;
23+
import java.util.ArrayList;
24+
import java.util.Collection;
2125
import java.util.Collections;
26+
import java.util.IdentityHashMap;
27+
import java.util.Map;
28+
import java.util.concurrent.Callable;
29+
import java.util.concurrent.ExecutionException;
30+
import java.util.concurrent.ExecutorService;
31+
import java.util.concurrent.Executors;
32+
import java.util.concurrent.Future;
33+
import java.util.concurrent.atomic.AtomicBoolean;
34+
import java.util.concurrent.atomic.AtomicInteger;
2235

2336
import org.junit.jupiter.api.Test;
2437
import org.junit.jupiter.params.ParameterizedTest;
@@ -31,6 +44,7 @@
3144
import org.springframework.data.neo4j.repository.query.Neo4jSpelSupport.LiteralReplacement;
3245
import org.springframework.data.repository.core.EntityMetadata;
3346
import org.springframework.expression.spel.standard.SpelExpressionParser;
47+
import org.springframework.util.ReflectionUtils;
3448

3549
/**
3650
* @author Michael J. Simons
@@ -74,19 +88,79 @@ void orderByShouldWork() {
7488
.withMessageMatching(".+is not a valid order criteria.");
7589
}
7690

91+
private Map<?, ?> getCacheInstance() throws ClassNotFoundException, IllegalAccessException {
92+
Class<?> type = Class.forName(
93+
"org.springframework.data.neo4j.repository.query.Neo4jSpelSupport$StringBasedLiteralReplacement");
94+
Field cacheField = ReflectionUtils.findField(type, "INSTANCES");
95+
cacheField.setAccessible(true);
96+
return (Map<?, ?>) cacheField.get(null);
97+
}
98+
99+
private void flushLiteralCache() {
100+
try {
101+
Map<?, ?> cache = getCacheInstance();
102+
cache.clear();
103+
} catch (Exception e) {
104+
throw new RuntimeException(e);
105+
}
106+
}
107+
108+
private int getCacheSize() {
109+
try {
110+
Map<?, ?> cache = getCacheInstance();
111+
return cache.size();
112+
} catch (Exception e) {
113+
throw new RuntimeException(e);
114+
}
115+
}
116+
77117
@Test // DATAGRAPH-1454
78118
void cacheShouldWork() {
79119

80-
// Make sure we flush this before...
81-
for (int i = 0; i < 16; ++i) {
82-
LiteralReplacement literalReplacement = Neo4jSpelSupport.literal("y" + i);
83-
}
120+
flushLiteralCache();
84121

85122
LiteralReplacement literalReplacement1 = Neo4jSpelSupport.literal("x");
86123
LiteralReplacement literalReplacement2 = Neo4jSpelSupport.literal("x");
87124
assertThat(literalReplacement1).isSameAs(literalReplacement2);
88125
}
89126

127+
@Test // GH-2375
128+
void cacheShouldBeThreadSafe() throws ExecutionException, InterruptedException {
129+
130+
flushLiteralCache();
131+
132+
int numThreads = Runtime.getRuntime().availableProcessors();
133+
ExecutorService executor = Executors.newWorkStealingPool();
134+
135+
AtomicBoolean running = new AtomicBoolean();
136+
AtomicInteger overlaps = new AtomicInteger();
137+
138+
Collection<Callable<LiteralReplacement>> getReplacementCalls = new ArrayList<>();
139+
for (int t = 0; t < numThreads; ++t) {
140+
getReplacementCalls.add(() -> {
141+
if (!running.compareAndSet(false, true)) {
142+
overlaps.incrementAndGet();
143+
}
144+
Thread.sleep(100); // Make the chances of overlapping a bit bigger
145+
LiteralReplacement d = Neo4jSpelSupport.literal("x");
146+
running.compareAndSet(true, false);
147+
return d;
148+
});
149+
}
150+
151+
Map<LiteralReplacement, Integer> replacements = new IdentityHashMap<>();
152+
for (Future<LiteralReplacement> getDriverFuture : executor.invokeAll(getReplacementCalls)) {
153+
replacements.put(getDriverFuture.get(), 1);
154+
}
155+
executor.shutdown();
156+
157+
// Assume things actually had been concurrent
158+
assumeThat(overlaps.get()).isGreaterThan(0);
159+
160+
assertThat(getCacheSize()).isEqualTo(1);
161+
assertThat(replacements).hasSize(1);
162+
}
163+
90164
@ParameterizedTest // GH-2279
91165
@CsvSource({
92166
"MATCH (n:Something) WHERE n.name = ?#{#name}, MATCH (n:Something) WHERE n.name = ?__HASH__{#name}",

0 commit comments

Comments
 (0)