Skip to content

Commit 1922b8e

Browse files
GH-2369 - Propagate Default database name when using delegateTo.
Also unify both imperative and reactive clients in their flows (always passing a `DatabaseSelection` around). Closes #2369.
1 parent 3f0c9a1 commit 1922b8e

File tree

10 files changed

+169
-114
lines changed

10 files changed

+169
-114
lines changed

src/main/java/org/springframework/data/neo4j/core/DefaultNeo4jClient.java

Lines changed: 35 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,11 @@ class DefaultNeo4jClient implements Neo4jClient {
5555

5656
private final Driver driver;
5757
private final TypeSystem typeSystem;
58-
private final DatabaseSelectionProvider databaseSelectionProvider;
58+
private @Nullable final DatabaseSelectionProvider databaseSelectionProvider;
5959
private final ConversionService conversionService;
6060
private final Neo4jPersistenceExceptionTranslator persistenceExceptionTranslator = new Neo4jPersistenceExceptionTranslator();
6161

62-
DefaultNeo4jClient(Driver driver, DatabaseSelectionProvider databaseSelectionProvider) {
62+
DefaultNeo4jClient(Driver driver, @Nullable DatabaseSelectionProvider databaseSelectionProvider) {
6363

6464
this.driver = driver;
6565
this.typeSystem = driver.defaultTypeSystem();
@@ -69,8 +69,9 @@ class DefaultNeo4jClient implements Neo4jClient {
6969
new Neo4jConversions().registerConvertersIn((ConverterRegistry) conversionService);
7070
}
7171

72-
QueryRunner getQueryRunner(@Nullable final String targetDatabase) {
72+
QueryRunner getQueryRunner(DatabaseSelection databaseSelection) {
7373

74+
String targetDatabase = databaseSelection.getValue();
7475
QueryRunner queryRunner = Neo4jTransactionManager.retrieveTransaction(driver, targetDatabase);
7576
if (queryRunner == null) {
7677
queryRunner = driver.session(Neo4jTransactionUtils.defaultSessionConfig(targetDatabase));
@@ -193,21 +194,33 @@ private static RuntimeException potentiallyConvertRuntimeException(RuntimeExcept
193194
return resolved == null ? ex : resolved;
194195
}
195196

197+
private DatabaseSelection resolveTargetDatabaseName(@Nullable String parameterTargetDatabase) {
198+
199+
String value = Neo4jClient.verifyDatabaseName(parameterTargetDatabase);
200+
if (value != null) {
201+
return DatabaseSelection.byName(value);
202+
}
203+
if (databaseSelectionProvider != null) {
204+
return databaseSelectionProvider.getDatabaseSelection();
205+
}
206+
return DatabaseSelectionProvider.getDefaultSelectionProvider().getDatabaseSelection();
207+
}
208+
196209
class DefaultRunnableSpec implements RunnableSpec {
197210

198211
private final RunnableStatement runnableStatement;
199212

200-
private String targetDatabase;
213+
private DatabaseSelection databaseSelection;
201214

202215
DefaultRunnableSpec(Supplier<String> cypherSupplier) {
203-
this.targetDatabase = Neo4jClient.verifyDatabaseName(resolveTargetDatabaseName(null));
216+
this.databaseSelection = resolveTargetDatabaseName(null);
204217
this.runnableStatement = new RunnableStatement(cypherSupplier);
205218
}
206219

207220
@Override
208-
public RunnableSpecTightToDatabase in(@SuppressWarnings("HiddenField") String targetDatabase) {
221+
public RunnableSpecTightToDatabase in(String targetDatabase) {
209222

210-
this.targetDatabase = Neo4jClient.verifyDatabaseName(targetDatabase);
223+
this.databaseSelection = resolveTargetDatabaseName(targetDatabase);
211224
return this;
212225
}
213226

@@ -249,20 +262,20 @@ public RunnableSpecTightToDatabase bindAll(Map<String, Object> newParameters) {
249262
@Override
250263
public <T> MappingSpec<T> fetchAs(Class<T> targetClass) {
251264

252-
return new DefaultRecordFetchSpec<>(this.targetDatabase, this.runnableStatement,
265+
return new DefaultRecordFetchSpec<>(this.databaseSelection, this.runnableStatement,
253266
new SingleValueMappingFunction<>(conversionService, targetClass));
254267
}
255268

256269
@Override
257270
public RecordFetchSpec<Map<String, Object>> fetch() {
258271

259-
return new DefaultRecordFetchSpec<>(this.targetDatabase, this.runnableStatement, (t, r) -> r.asMap());
272+
return new DefaultRecordFetchSpec<>(this.databaseSelection, this.runnableStatement, (t, r) -> r.asMap());
260273
}
261274

262275
@Override
263276
public ResultSummary run() {
264277

265-
try (QueryRunner statementRunner = getQueryRunner(this.targetDatabase)) {
278+
try (QueryRunner statementRunner = getQueryRunner(this.databaseSelection)) {
266279
Result result = runnableStatement.runWith(statementRunner);
267280
return ResultSummaries.process(result.consume());
268281
} catch (RuntimeException e) {
@@ -271,33 +284,20 @@ public ResultSummary run() {
271284
throw new RuntimeException(e);
272285
}
273286
}
274-
275-
private String resolveTargetDatabaseName(@Nullable String parameterTargetDatabase) {
276-
if (parameterTargetDatabase != null) {
277-
return parameterTargetDatabase;
278-
}
279-
if (databaseSelectionProvider != null) {
280-
String databaseSelectionProviderValue = databaseSelectionProvider.getDatabaseSelection().getValue();
281-
if (databaseSelectionProviderValue != null) {
282-
return databaseSelectionProviderValue;
283-
}
284-
}
285-
return DatabaseSelectionProvider.getDefaultSelectionProvider().getDatabaseSelection().getValue();
286-
}
287287
}
288288

289289
class DefaultRecordFetchSpec<T> implements RecordFetchSpec<T>, MappingSpec<T> {
290290

291-
private final String targetDatabase;
291+
private final DatabaseSelection databaseSelection;
292292

293293
private final RunnableStatement runnableStatement;
294294

295295
private BiFunction<TypeSystem, Record, T> mappingFunction;
296296

297-
DefaultRecordFetchSpec(String parameterTargetDatabase, RunnableStatement runnableStatement,
297+
DefaultRecordFetchSpec(DatabaseSelection databaseSelection, RunnableStatement runnableStatement,
298298
BiFunction<TypeSystem, Record, T> mappingFunction) {
299299

300-
this.targetDatabase = parameterTargetDatabase;
300+
this.databaseSelection = databaseSelection;
301301
this.runnableStatement = runnableStatement;
302302
this.mappingFunction = mappingFunction;
303303
}
@@ -313,7 +313,7 @@ public RecordFetchSpec<T> mappedBy(
313313
@Override
314314
public Optional<T> one() {
315315

316-
try (QueryRunner statementRunner = getQueryRunner(this.targetDatabase)) {
316+
try (QueryRunner statementRunner = getQueryRunner(this.databaseSelection)) {
317317
Result result = runnableStatement.runWith(statementRunner);
318318
Optional<T> optionalValue = result.hasNext() ?
319319
Optional.ofNullable(mappingFunction.apply(typeSystem, result.single())) :
@@ -330,7 +330,7 @@ public Optional<T> one() {
330330
@Override
331331
public Optional<T> first() {
332332

333-
try (QueryRunner statementRunner = getQueryRunner(this.targetDatabase)) {
333+
try (QueryRunner statementRunner = getQueryRunner(this.databaseSelection)) {
334334
Result result = runnableStatement.runWith(statementRunner);
335335
Optional<T> optionalValue = result.stream().map(partialMappingFunction(typeSystem)).findFirst();
336336
ResultSummaries.process(result.consume());
@@ -345,7 +345,7 @@ public Optional<T> first() {
345345
@Override
346346
public Collection<T> all() {
347347

348-
try (QueryRunner statementRunner = getQueryRunner(this.targetDatabase)) {
348+
try (QueryRunner statementRunner = getQueryRunner(this.databaseSelection)) {
349349
Result result = runnableStatement.runWith(statementRunner);
350350
Collection<T> values = result.stream().map(partialMappingFunction(typeSystem)).collect(Collectors.toList());
351351
ResultSummaries.process(result.consume());
@@ -368,29 +368,25 @@ private Function<Record, T> partialMappingFunction(TypeSystem typeSystem) {
368368

369369
class DefaultRunnableDelegation<T> implements RunnableDelegation<T>, OngoingDelegation<T> {
370370

371-
private final Function<QueryRunner, Optional<T>> callback;
371+
private DatabaseSelection databaseSelection;
372372

373-
@Nullable private String targetDatabase;
373+
private final Function<QueryRunner, Optional<T>> callback;
374374

375375
DefaultRunnableDelegation(Function<QueryRunner, Optional<T>> callback) {
376-
this(callback, null);
377-
}
378-
379-
DefaultRunnableDelegation(Function<QueryRunner, Optional<T>> callback, @Nullable String targetDatabase) {
380376
this.callback = callback;
381-
this.targetDatabase = Neo4jClient.verifyDatabaseName(targetDatabase);
377+
this.databaseSelection = resolveTargetDatabaseName(null);
382378
}
383379

384380
@Override
385-
public RunnableDelegation<T> in(@Nullable @SuppressWarnings("HiddenField") String targetDatabase) {
381+
public RunnableDelegation<T> in(@Nullable String targetDatabase) {
386382

387-
this.targetDatabase = Neo4jClient.verifyDatabaseName(targetDatabase);
383+
this.databaseSelection = resolveTargetDatabaseName(targetDatabase);
388384
return this;
389385
}
390386

391387
@Override
392388
public Optional<T> run() {
393-
try (QueryRunner queryRunner = getQueryRunner(targetDatabase)) {
389+
try (QueryRunner queryRunner = getQueryRunner(databaseSelection)) {
394390
return callback.apply(queryRunner);
395391
} catch (RuntimeException e) {
396392
throw potentiallyConvertRuntimeException(e, persistenceExceptionTranslator);
@@ -399,5 +395,4 @@ public Optional<T> run() {
399395
}
400396
}
401397
}
402-
403398
}

0 commit comments

Comments
 (0)