@@ -55,11 +55,11 @@ class DefaultNeo4jClient implements Neo4jClient {
55
55
56
56
private final Driver driver ;
57
57
private final TypeSystem typeSystem ;
58
- private final DatabaseSelectionProvider databaseSelectionProvider ;
58
+ private @ Nullable final DatabaseSelectionProvider databaseSelectionProvider ;
59
59
private final ConversionService conversionService ;
60
60
private final Neo4jPersistenceExceptionTranslator persistenceExceptionTranslator = new Neo4jPersistenceExceptionTranslator ();
61
61
62
- DefaultNeo4jClient (Driver driver , DatabaseSelectionProvider databaseSelectionProvider ) {
62
+ DefaultNeo4jClient (Driver driver , @ Nullable DatabaseSelectionProvider databaseSelectionProvider ) {
63
63
64
64
this .driver = driver ;
65
65
this .typeSystem = driver .defaultTypeSystem ();
@@ -69,8 +69,9 @@ class DefaultNeo4jClient implements Neo4jClient {
69
69
new Neo4jConversions ().registerConvertersIn ((ConverterRegistry ) conversionService );
70
70
}
71
71
72
- QueryRunner getQueryRunner (@ Nullable final String targetDatabase ) {
72
+ QueryRunner getQueryRunner (DatabaseSelection databaseSelection ) {
73
73
74
+ String targetDatabase = databaseSelection .getValue ();
74
75
QueryRunner queryRunner = Neo4jTransactionManager .retrieveTransaction (driver , targetDatabase );
75
76
if (queryRunner == null ) {
76
77
queryRunner = driver .session (Neo4jTransactionUtils .defaultSessionConfig (targetDatabase ));
@@ -193,21 +194,33 @@ private static RuntimeException potentiallyConvertRuntimeException(RuntimeExcept
193
194
return resolved == null ? ex : resolved ;
194
195
}
195
196
197
+ private DatabaseSelection resolveTargetDatabaseName (@ Nullable String parameterTargetDatabase ) {
198
+
199
+ String value = Neo4jClient .verifyDatabaseName (parameterTargetDatabase );
200
+ if (value != null ) {
201
+ return DatabaseSelection .byName (value );
202
+ }
203
+ if (databaseSelectionProvider != null ) {
204
+ return databaseSelectionProvider .getDatabaseSelection ();
205
+ }
206
+ return DatabaseSelectionProvider .getDefaultSelectionProvider ().getDatabaseSelection ();
207
+ }
208
+
196
209
class DefaultRunnableSpec implements RunnableSpec {
197
210
198
211
private final RunnableStatement runnableStatement ;
199
212
200
- private String targetDatabase ;
213
+ private DatabaseSelection databaseSelection ;
201
214
202
215
DefaultRunnableSpec (Supplier <String > cypherSupplier ) {
203
- this .targetDatabase = Neo4jClient . verifyDatabaseName ( resolveTargetDatabaseName (null ) );
216
+ this .databaseSelection = resolveTargetDatabaseName (null );
204
217
this .runnableStatement = new RunnableStatement (cypherSupplier );
205
218
}
206
219
207
220
@ Override
208
- public RunnableSpecTightToDatabase in (@ SuppressWarnings ( "HiddenField" ) String targetDatabase ) {
221
+ public RunnableSpecTightToDatabase in (String targetDatabase ) {
209
222
210
- this .targetDatabase = Neo4jClient . verifyDatabaseName (targetDatabase );
223
+ this .databaseSelection = resolveTargetDatabaseName (targetDatabase );
211
224
return this ;
212
225
}
213
226
@@ -249,20 +262,20 @@ public RunnableSpecTightToDatabase bindAll(Map<String, Object> newParameters) {
249
262
@ Override
250
263
public <T > MappingSpec <T > fetchAs (Class <T > targetClass ) {
251
264
252
- return new DefaultRecordFetchSpec <>(this .targetDatabase , this .runnableStatement ,
265
+ return new DefaultRecordFetchSpec <>(this .databaseSelection , this .runnableStatement ,
253
266
new SingleValueMappingFunction <>(conversionService , targetClass ));
254
267
}
255
268
256
269
@ Override
257
270
public RecordFetchSpec <Map <String , Object >> fetch () {
258
271
259
- return new DefaultRecordFetchSpec <>(this .targetDatabase , this .runnableStatement , (t , r ) -> r .asMap ());
272
+ return new DefaultRecordFetchSpec <>(this .databaseSelection , this .runnableStatement , (t , r ) -> r .asMap ());
260
273
}
261
274
262
275
@ Override
263
276
public ResultSummary run () {
264
277
265
- try (QueryRunner statementRunner = getQueryRunner (this .targetDatabase )) {
278
+ try (QueryRunner statementRunner = getQueryRunner (this .databaseSelection )) {
266
279
Result result = runnableStatement .runWith (statementRunner );
267
280
return ResultSummaries .process (result .consume ());
268
281
} catch (RuntimeException e ) {
@@ -271,33 +284,20 @@ public ResultSummary run() {
271
284
throw new RuntimeException (e );
272
285
}
273
286
}
274
-
275
- private String resolveTargetDatabaseName (@ Nullable String parameterTargetDatabase ) {
276
- if (parameterTargetDatabase != null ) {
277
- return parameterTargetDatabase ;
278
- }
279
- if (databaseSelectionProvider != null ) {
280
- String databaseSelectionProviderValue = databaseSelectionProvider .getDatabaseSelection ().getValue ();
281
- if (databaseSelectionProviderValue != null ) {
282
- return databaseSelectionProviderValue ;
283
- }
284
- }
285
- return DatabaseSelectionProvider .getDefaultSelectionProvider ().getDatabaseSelection ().getValue ();
286
- }
287
287
}
288
288
289
289
class DefaultRecordFetchSpec <T > implements RecordFetchSpec <T >, MappingSpec <T > {
290
290
291
- private final String targetDatabase ;
291
+ private final DatabaseSelection databaseSelection ;
292
292
293
293
private final RunnableStatement runnableStatement ;
294
294
295
295
private BiFunction <TypeSystem , Record , T > mappingFunction ;
296
296
297
- DefaultRecordFetchSpec (String parameterTargetDatabase , RunnableStatement runnableStatement ,
297
+ DefaultRecordFetchSpec (DatabaseSelection databaseSelection , RunnableStatement runnableStatement ,
298
298
BiFunction <TypeSystem , Record , T > mappingFunction ) {
299
299
300
- this .targetDatabase = parameterTargetDatabase ;
300
+ this .databaseSelection = databaseSelection ;
301
301
this .runnableStatement = runnableStatement ;
302
302
this .mappingFunction = mappingFunction ;
303
303
}
@@ -313,7 +313,7 @@ public RecordFetchSpec<T> mappedBy(
313
313
@ Override
314
314
public Optional <T > one () {
315
315
316
- try (QueryRunner statementRunner = getQueryRunner (this .targetDatabase )) {
316
+ try (QueryRunner statementRunner = getQueryRunner (this .databaseSelection )) {
317
317
Result result = runnableStatement .runWith (statementRunner );
318
318
Optional <T > optionalValue = result .hasNext () ?
319
319
Optional .ofNullable (mappingFunction .apply (typeSystem , result .single ())) :
@@ -330,7 +330,7 @@ public Optional<T> one() {
330
330
@ Override
331
331
public Optional <T > first () {
332
332
333
- try (QueryRunner statementRunner = getQueryRunner (this .targetDatabase )) {
333
+ try (QueryRunner statementRunner = getQueryRunner (this .databaseSelection )) {
334
334
Result result = runnableStatement .runWith (statementRunner );
335
335
Optional <T > optionalValue = result .stream ().map (partialMappingFunction (typeSystem )).findFirst ();
336
336
ResultSummaries .process (result .consume ());
@@ -345,7 +345,7 @@ public Optional<T> first() {
345
345
@ Override
346
346
public Collection <T > all () {
347
347
348
- try (QueryRunner statementRunner = getQueryRunner (this .targetDatabase )) {
348
+ try (QueryRunner statementRunner = getQueryRunner (this .databaseSelection )) {
349
349
Result result = runnableStatement .runWith (statementRunner );
350
350
Collection <T > values = result .stream ().map (partialMappingFunction (typeSystem )).collect (Collectors .toList ());
351
351
ResultSummaries .process (result .consume ());
@@ -368,29 +368,25 @@ private Function<Record, T> partialMappingFunction(TypeSystem typeSystem) {
368
368
369
369
class DefaultRunnableDelegation <T > implements RunnableDelegation <T >, OngoingDelegation <T > {
370
370
371
- private final Function < QueryRunner , Optional < T >> callback ;
371
+ private DatabaseSelection databaseSelection ;
372
372
373
- @ Nullable private String targetDatabase ;
373
+ private final Function < QueryRunner , Optional < T >> callback ;
374
374
375
375
DefaultRunnableDelegation (Function <QueryRunner , Optional <T >> callback ) {
376
- this (callback , null );
377
- }
378
-
379
- DefaultRunnableDelegation (Function <QueryRunner , Optional <T >> callback , @ Nullable String targetDatabase ) {
380
376
this .callback = callback ;
381
- this .targetDatabase = Neo4jClient . verifyDatabaseName ( targetDatabase );
377
+ this .databaseSelection = resolveTargetDatabaseName ( null );
382
378
}
383
379
384
380
@ Override
385
- public RunnableDelegation <T > in (@ Nullable @ SuppressWarnings ( "HiddenField" ) String targetDatabase ) {
381
+ public RunnableDelegation <T > in (@ Nullable String targetDatabase ) {
386
382
387
- this .targetDatabase = Neo4jClient . verifyDatabaseName (targetDatabase );
383
+ this .databaseSelection = resolveTargetDatabaseName (targetDatabase );
388
384
return this ;
389
385
}
390
386
391
387
@ Override
392
388
public Optional <T > run () {
393
- try (QueryRunner queryRunner = getQueryRunner (targetDatabase )) {
389
+ try (QueryRunner queryRunner = getQueryRunner (databaseSelection )) {
394
390
return callback .apply (queryRunner );
395
391
} catch (RuntimeException e ) {
396
392
throw potentiallyConvertRuntimeException (e , persistenceExceptionTranslator );
@@ -399,5 +395,4 @@ public Optional<T> run() {
399
395
}
400
396
}
401
397
}
402
-
403
398
}
0 commit comments