@@ -55,11 +55,11 @@ class DefaultNeo4jClient implements Neo4jClient {
5555
5656 private final Driver driver ;
5757 private final TypeSystem typeSystem ;
58- private final DatabaseSelectionProvider databaseSelectionProvider ;
58+ private @ Nullable final DatabaseSelectionProvider databaseSelectionProvider ;
5959 private final ConversionService conversionService ;
6060 private final Neo4jPersistenceExceptionTranslator persistenceExceptionTranslator = new Neo4jPersistenceExceptionTranslator ();
6161
62- DefaultNeo4jClient (Driver driver , DatabaseSelectionProvider databaseSelectionProvider ) {
62+ DefaultNeo4jClient (Driver driver , @ Nullable DatabaseSelectionProvider databaseSelectionProvider ) {
6363
6464 this .driver = driver ;
6565 this .typeSystem = driver .defaultTypeSystem ();
@@ -69,8 +69,9 @@ class DefaultNeo4jClient implements Neo4jClient {
6969 new Neo4jConversions ().registerConvertersIn ((ConverterRegistry ) conversionService );
7070 }
7171
72- DelegatingQueryRunner getQueryRunner (@ Nullable final String targetDatabase ) {
72+ DelegatingQueryRunner getQueryRunner (DatabaseSelection databaseSelection ) {
7373
74+ String targetDatabase = databaseSelection .getValue ();
7475 QueryRunner queryRunner = Neo4jTransactionManager .retrieveTransaction (driver , targetDatabase );
7576 if (queryRunner == null ) {
7677 queryRunner = driver .session (Neo4jTransactionUtils .defaultSessionConfig (targetDatabase ));
@@ -193,21 +194,33 @@ private static RuntimeException potentiallyConvertRuntimeException(RuntimeExcept
193194 return resolved == null ? ex : resolved ;
194195 }
195196
197+ private DatabaseSelection resolveTargetDatabaseName (@ Nullable String parameterTargetDatabase ) {
198+
199+ String value = Neo4jClient .verifyDatabaseName (parameterTargetDatabase );
200+ if (value != null ) {
201+ return DatabaseSelection .byName (value );
202+ }
203+ if (databaseSelectionProvider != null ) {
204+ return databaseSelectionProvider .getDatabaseSelection ();
205+ }
206+ return DatabaseSelectionProvider .getDefaultSelectionProvider ().getDatabaseSelection ();
207+ }
208+
196209 class DefaultRunnableSpec implements RunnableSpec {
197210
198211 private RunnableStatement runnableStatement ;
199212
200- private String targetDatabase ;
213+ private DatabaseSelection databaseSelection ;
201214
202215 DefaultRunnableSpec (Supplier <String > cypherSupplier ) {
203- this .targetDatabase = Neo4jClient . verifyDatabaseName ( resolveTargetDatabaseName (targetDatabase ) );
216+ this .databaseSelection = resolveTargetDatabaseName (null );
204217 this .runnableStatement = new RunnableStatement (cypherSupplier );
205218 }
206219
207220 @ Override
208- public RunnableSpecTightToDatabase in (@ SuppressWarnings ( "HiddenField" ) String targetDatabase ) {
221+ public RunnableSpecTightToDatabase in (String targetDatabase ) {
209222
210- this .targetDatabase = Neo4jClient . verifyDatabaseName (targetDatabase );
223+ this .databaseSelection = resolveTargetDatabaseName (targetDatabase );
211224 return this ;
212225 }
213226
@@ -249,53 +262,40 @@ public RunnableSpecTightToDatabase bindAll(Map<String, Object> newParameters) {
249262 @ Override
250263 public <T > MappingSpec <T > fetchAs (Class <T > targetClass ) {
251264
252- return new DefaultRecordFetchSpec (this .targetDatabase , this .runnableStatement ,
265+ return new DefaultRecordFetchSpec <> (this .databaseSelection , this .runnableStatement ,
253266 new SingleValueMappingFunction (conversionService , targetClass ));
254267 }
255268
256269 @ Override
257270 public RecordFetchSpec <Map <String , Object >> fetch () {
258271
259- return new DefaultRecordFetchSpec <>(this .targetDatabase , this .runnableStatement , (t , r ) -> r .asMap ());
272+ return new DefaultRecordFetchSpec <>(this .databaseSelection , this .runnableStatement , (t , r ) -> r .asMap ());
260273 }
261274
262275 @ Override
263276 public ResultSummary run () {
264277
265- try (DelegatingQueryRunner statementRunner = getQueryRunner (this .targetDatabase )) {
278+ try (DelegatingQueryRunner statementRunner = getQueryRunner (this .databaseSelection )) {
266279 Result result = runnableStatement .runWith (statementRunner );
267280 return ResultSummaries .process (result .consume ());
268281 } catch (RuntimeException e ) {
269282 throw potentiallyConvertRuntimeException (e , persistenceExceptionTranslator );
270283 }
271284 }
272-
273- private String resolveTargetDatabaseName (@ Nullable String parameterTargetDatabase ) {
274- if (parameterTargetDatabase != null ) {
275- return parameterTargetDatabase ;
276- }
277- if (databaseSelectionProvider != null ) {
278- String databaseSelectionProviderValue = databaseSelectionProvider .getDatabaseSelection ().getValue ();
279- if (databaseSelectionProviderValue != null ) {
280- return databaseSelectionProviderValue ;
281- }
282- }
283- return DatabaseSelectionProvider .getDefaultSelectionProvider ().getDatabaseSelection ().getValue ();
284- }
285285 }
286286
287287 class DefaultRecordFetchSpec <T > implements RecordFetchSpec <T >, MappingSpec <T > {
288288
289- private final String targetDatabase ;
289+ private final DatabaseSelection databaseSelection ;
290290
291291 private final RunnableStatement runnableStatement ;
292292
293293 private BiFunction <TypeSystem , Record , T > mappingFunction ;
294294
295- DefaultRecordFetchSpec (String parameterTargetDatabase , RunnableStatement runnableStatement ,
295+ DefaultRecordFetchSpec (DatabaseSelection databaseSelection , RunnableStatement runnableStatement ,
296296 BiFunction <TypeSystem , Record , T > mappingFunction ) {
297297
298- this .targetDatabase = parameterTargetDatabase ;
298+ this .databaseSelection = databaseSelection ;
299299 this .runnableStatement = runnableStatement ;
300300 this .mappingFunction = mappingFunction ;
301301 }
@@ -311,7 +311,7 @@ public RecordFetchSpec<T> mappedBy(
311311 @ Override
312312 public Optional <T > one () {
313313
314- try (DelegatingQueryRunner statementRunner = getQueryRunner (this .targetDatabase )) {
314+ try (DelegatingQueryRunner statementRunner = getQueryRunner (this .databaseSelection )) {
315315 Result result = runnableStatement .runWith (statementRunner );
316316 Optional <T > optionalValue = result .hasNext () ?
317317 Optional .ofNullable (mappingFunction .apply (typeSystem , result .single ())) :
@@ -326,7 +326,7 @@ public Optional<T> one() {
326326 @ Override
327327 public Optional <T > first () {
328328
329- try (DelegatingQueryRunner statementRunner = getQueryRunner (this .targetDatabase )) {
329+ try (DelegatingQueryRunner statementRunner = getQueryRunner (this .databaseSelection )) {
330330 Result result = runnableStatement .runWith (statementRunner );
331331 Optional <T > optionalValue = result .stream ().map (partialMappingFunction (typeSystem )).findFirst ();
332332 ResultSummaries .process (result .consume ());
@@ -339,7 +339,7 @@ public Optional<T> first() {
339339 @ Override
340340 public Collection <T > all () {
341341
342- try (DelegatingQueryRunner statementRunner = getQueryRunner (this .targetDatabase )) {
342+ try (DelegatingQueryRunner statementRunner = getQueryRunner (this .databaseSelection )) {
343343 Result result = runnableStatement .runWith (statementRunner );
344344 Collection <T > values = result .stream ().map (partialMappingFunction (typeSystem )).collect (Collectors .toList ());
345345 ResultSummaries .process (result .consume ());
@@ -360,34 +360,29 @@ private Function<Record, T> partialMappingFunction(TypeSystem typeSystem) {
360360
361361 class DefaultRunnableDelegation <T > implements RunnableDelegation <T >, OngoingDelegation <T > {
362362
363- private final Function < QueryRunner , Optional < T >> callback ;
363+ private DatabaseSelection databaseSelection ;
364364
365- @ Nullable private String targetDatabase ;
365+ private final Function < QueryRunner , Optional < T >> callback ;
366366
367367 DefaultRunnableDelegation (Function <QueryRunner , Optional <T >> callback ) {
368- this (callback , null );
369- }
370-
371- DefaultRunnableDelegation (Function <QueryRunner , Optional <T >> callback , @ Nullable String targetDatabase ) {
372368 this .callback = callback ;
373- this .targetDatabase = Neo4jClient . verifyDatabaseName ( targetDatabase );
369+ this .databaseSelection = resolveTargetDatabaseName ( null );
374370 }
375371
376372 @ Override
377- public RunnableDelegation in (@ Nullable @ SuppressWarnings ( "HiddenField" ) String targetDatabase ) {
373+ public RunnableDelegation in (@ Nullable String targetDatabase ) {
378374
379- this .targetDatabase = Neo4jClient . verifyDatabaseName (targetDatabase );
375+ this .databaseSelection = resolveTargetDatabaseName (targetDatabase );
380376 return this ;
381377 }
382378
383379 @ Override
384380 public Optional <T > run () {
385- try (DelegatingQueryRunner queryRunner = getQueryRunner (targetDatabase )) {
381+ try (DelegatingQueryRunner queryRunner = getQueryRunner (databaseSelection )) {
386382 return callback .apply (queryRunner );
387383 } catch (RuntimeException e ) {
388384 throw potentiallyConvertRuntimeException (e , persistenceExceptionTranslator );
389385 }
390386 }
391387 }
392-
393388}
0 commit comments