Skip to content

Commit 431a4f6

Browse files
GH-2369 - Propagate Default database name when using delegateTo.
Also unify both imperative and reactive clients in their flows (always passing a `DatabaseSelection` around). Closes #2369.
1 parent 2301ad8 commit 431a4f6

File tree

10 files changed

+169
-114
lines changed

10 files changed

+169
-114
lines changed

src/main/java/org/springframework/data/neo4j/core/DefaultNeo4jClient.java

Lines changed: 35 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,11 @@ class DefaultNeo4jClient implements Neo4jClient {
5555

5656
private final Driver driver;
5757
private final TypeSystem typeSystem;
58-
private final DatabaseSelectionProvider databaseSelectionProvider;
58+
private @Nullable final DatabaseSelectionProvider databaseSelectionProvider;
5959
private final ConversionService conversionService;
6060
private final Neo4jPersistenceExceptionTranslator persistenceExceptionTranslator = new Neo4jPersistenceExceptionTranslator();
6161

62-
DefaultNeo4jClient(Driver driver, DatabaseSelectionProvider databaseSelectionProvider) {
62+
DefaultNeo4jClient(Driver driver, @Nullable DatabaseSelectionProvider databaseSelectionProvider) {
6363

6464
this.driver = driver;
6565
this.typeSystem = driver.defaultTypeSystem();
@@ -69,8 +69,9 @@ class DefaultNeo4jClient implements Neo4jClient {
6969
new Neo4jConversions().registerConvertersIn((ConverterRegistry) conversionService);
7070
}
7171

72-
DelegatingQueryRunner getQueryRunner(@Nullable final String targetDatabase) {
72+
DelegatingQueryRunner getQueryRunner(DatabaseSelection databaseSelection) {
7373

74+
String targetDatabase = databaseSelection.getValue();
7475
QueryRunner queryRunner = Neo4jTransactionManager.retrieveTransaction(driver, targetDatabase);
7576
if (queryRunner == null) {
7677
queryRunner = driver.session(Neo4jTransactionUtils.defaultSessionConfig(targetDatabase));
@@ -193,21 +194,33 @@ private static RuntimeException potentiallyConvertRuntimeException(RuntimeExcept
193194
return resolved == null ? ex : resolved;
194195
}
195196

197+
private DatabaseSelection resolveTargetDatabaseName(@Nullable String parameterTargetDatabase) {
198+
199+
String value = Neo4jClient.verifyDatabaseName(parameterTargetDatabase);
200+
if (value != null) {
201+
return DatabaseSelection.byName(value);
202+
}
203+
if (databaseSelectionProvider != null) {
204+
return databaseSelectionProvider.getDatabaseSelection();
205+
}
206+
return DatabaseSelectionProvider.getDefaultSelectionProvider().getDatabaseSelection();
207+
}
208+
196209
class DefaultRunnableSpec implements RunnableSpec {
197210

198211
private RunnableStatement runnableStatement;
199212

200-
private String targetDatabase;
213+
private DatabaseSelection databaseSelection;
201214

202215
DefaultRunnableSpec(Supplier<String> cypherSupplier) {
203-
this.targetDatabase = Neo4jClient.verifyDatabaseName(resolveTargetDatabaseName(targetDatabase));
216+
this.databaseSelection = resolveTargetDatabaseName(null);
204217
this.runnableStatement = new RunnableStatement(cypherSupplier);
205218
}
206219

207220
@Override
208-
public RunnableSpecTightToDatabase in(@SuppressWarnings("HiddenField") String targetDatabase) {
221+
public RunnableSpecTightToDatabase in(String targetDatabase) {
209222

210-
this.targetDatabase = Neo4jClient.verifyDatabaseName(targetDatabase);
223+
this.databaseSelection = resolveTargetDatabaseName(targetDatabase);
211224
return this;
212225
}
213226

@@ -249,53 +262,40 @@ public RunnableSpecTightToDatabase bindAll(Map<String, Object> newParameters) {
249262
@Override
250263
public <T> MappingSpec<T> fetchAs(Class<T> targetClass) {
251264

252-
return new DefaultRecordFetchSpec(this.targetDatabase, this.runnableStatement,
265+
return new DefaultRecordFetchSpec<>(this.databaseSelection, this.runnableStatement,
253266
new SingleValueMappingFunction(conversionService, targetClass));
254267
}
255268

256269
@Override
257270
public RecordFetchSpec<Map<String, Object>> fetch() {
258271

259-
return new DefaultRecordFetchSpec<>(this.targetDatabase, this.runnableStatement, (t, r) -> r.asMap());
272+
return new DefaultRecordFetchSpec<>(this.databaseSelection, this.runnableStatement, (t, r) -> r.asMap());
260273
}
261274

262275
@Override
263276
public ResultSummary run() {
264277

265-
try (DelegatingQueryRunner statementRunner = getQueryRunner(this.targetDatabase)) {
278+
try (DelegatingQueryRunner statementRunner = getQueryRunner(this.databaseSelection)) {
266279
Result result = runnableStatement.runWith(statementRunner);
267280
return ResultSummaries.process(result.consume());
268281
} catch (RuntimeException e) {
269282
throw potentiallyConvertRuntimeException(e, persistenceExceptionTranslator);
270283
}
271284
}
272-
273-
private String resolveTargetDatabaseName(@Nullable String parameterTargetDatabase) {
274-
if (parameterTargetDatabase != null) {
275-
return parameterTargetDatabase;
276-
}
277-
if (databaseSelectionProvider != null) {
278-
String databaseSelectionProviderValue = databaseSelectionProvider.getDatabaseSelection().getValue();
279-
if (databaseSelectionProviderValue != null) {
280-
return databaseSelectionProviderValue;
281-
}
282-
}
283-
return DatabaseSelectionProvider.getDefaultSelectionProvider().getDatabaseSelection().getValue();
284-
}
285285
}
286286

287287
class DefaultRecordFetchSpec<T> implements RecordFetchSpec<T>, MappingSpec<T> {
288288

289-
private final String targetDatabase;
289+
private final DatabaseSelection databaseSelection;
290290

291291
private final RunnableStatement runnableStatement;
292292

293293
private BiFunction<TypeSystem, Record, T> mappingFunction;
294294

295-
DefaultRecordFetchSpec(String parameterTargetDatabase, RunnableStatement runnableStatement,
295+
DefaultRecordFetchSpec(DatabaseSelection databaseSelection, RunnableStatement runnableStatement,
296296
BiFunction<TypeSystem, Record, T> mappingFunction) {
297297

298-
this.targetDatabase = parameterTargetDatabase;
298+
this.databaseSelection = databaseSelection;
299299
this.runnableStatement = runnableStatement;
300300
this.mappingFunction = mappingFunction;
301301
}
@@ -311,7 +311,7 @@ public RecordFetchSpec<T> mappedBy(
311311
@Override
312312
public Optional<T> one() {
313313

314-
try (DelegatingQueryRunner statementRunner = getQueryRunner(this.targetDatabase)) {
314+
try (DelegatingQueryRunner statementRunner = getQueryRunner(this.databaseSelection)) {
315315
Result result = runnableStatement.runWith(statementRunner);
316316
Optional<T> optionalValue = result.hasNext() ?
317317
Optional.ofNullable(mappingFunction.apply(typeSystem, result.single())) :
@@ -326,7 +326,7 @@ public Optional<T> one() {
326326
@Override
327327
public Optional<T> first() {
328328

329-
try (DelegatingQueryRunner statementRunner = getQueryRunner(this.targetDatabase)) {
329+
try (DelegatingQueryRunner statementRunner = getQueryRunner(this.databaseSelection)) {
330330
Result result = runnableStatement.runWith(statementRunner);
331331
Optional<T> optionalValue = result.stream().map(partialMappingFunction(typeSystem)).findFirst();
332332
ResultSummaries.process(result.consume());
@@ -339,7 +339,7 @@ public Optional<T> first() {
339339
@Override
340340
public Collection<T> all() {
341341

342-
try (DelegatingQueryRunner statementRunner = getQueryRunner(this.targetDatabase)) {
342+
try (DelegatingQueryRunner statementRunner = getQueryRunner(this.databaseSelection)) {
343343
Result result = runnableStatement.runWith(statementRunner);
344344
Collection<T> values = result.stream().map(partialMappingFunction(typeSystem)).collect(Collectors.toList());
345345
ResultSummaries.process(result.consume());
@@ -360,34 +360,29 @@ private Function<Record, T> partialMappingFunction(TypeSystem typeSystem) {
360360

361361
class DefaultRunnableDelegation<T> implements RunnableDelegation<T>, OngoingDelegation<T> {
362362

363-
private final Function<QueryRunner, Optional<T>> callback;
363+
private DatabaseSelection databaseSelection;
364364

365-
@Nullable private String targetDatabase;
365+
private final Function<QueryRunner, Optional<T>> callback;
366366

367367
DefaultRunnableDelegation(Function<QueryRunner, Optional<T>> callback) {
368-
this(callback, null);
369-
}
370-
371-
DefaultRunnableDelegation(Function<QueryRunner, Optional<T>> callback, @Nullable String targetDatabase) {
372368
this.callback = callback;
373-
this.targetDatabase = Neo4jClient.verifyDatabaseName(targetDatabase);
369+
this.databaseSelection = resolveTargetDatabaseName(null);
374370
}
375371

376372
@Override
377-
public RunnableDelegation in(@Nullable @SuppressWarnings("HiddenField") String targetDatabase) {
373+
public RunnableDelegation in(@Nullable String targetDatabase) {
378374

379-
this.targetDatabase = Neo4jClient.verifyDatabaseName(targetDatabase);
375+
this.databaseSelection = resolveTargetDatabaseName(targetDatabase);
380376
return this;
381377
}
382378

383379
@Override
384380
public Optional<T> run() {
385-
try (DelegatingQueryRunner queryRunner = getQueryRunner(targetDatabase)) {
381+
try (DelegatingQueryRunner queryRunner = getQueryRunner(databaseSelection)) {
386382
return callback.apply(queryRunner);
387383
} catch (RuntimeException e) {
388384
throw potentiallyConvertRuntimeException(e, persistenceExceptionTranslator);
389385
}
390386
}
391387
}
392-
393388
}

0 commit comments

Comments
 (0)