2626import io .zonky .test .db .provider .EmbeddedDatabase ;
2727import io .zonky .test .db .provider .ProviderException ;
2828import org .apache .commons .lang3 .RandomStringUtils ;
29+ import org .apache .commons .lang3 .tuple .Pair ;
2930import org .slf4j .Logger ;
3031import org .slf4j .LoggerFactory ;
3132import org .springframework .core .env .Environment ;
3233import org .springframework .scheduling .concurrent .ThreadPoolTaskExecutor ;
3334import org .springframework .util .concurrent .ListenableFutureCallback ;
3435import org .springframework .util .concurrent .ListenableFutureTask ;
3536
37+ import java .util .Comparator ;
3638import java .util .List ;
3739import java .util .Objects ;
3840import java .util .Optional ;
@@ -64,6 +66,7 @@ public class PrefetchingDatabaseProvider implements DatabaseProvider {
6466
6567 protected static final ThreadPoolTaskExecutor taskExecutor = new PriorityThreadPoolTaskExecutor ();
6668 protected static final ConcurrentMap <PipelineKey , DatabasePipeline > pipelines = new ConcurrentHashMap <>();
69+ protected static final AtomicLong databaseCount = new AtomicLong ();
6770
6871 protected final int pipelineMaxCacheSize ;
6972
@@ -92,6 +95,7 @@ public PrefetchingDatabaseProvider(DatabaseProvider provider, Environment enviro
9295 public EmbeddedDatabase createDatabase (DatabasePreparer preparer ) throws ProviderException {
9396 Stopwatch stopwatch = Stopwatch .createStarted ();
9497 logger .trace ("Prefetching pipelines: {}" , pipelines .values ());
98+ databaseCount .decrementAndGet ();
9599
96100 PipelineKey key = new PipelineKey (provider , preparer );
97101 DatabasePipeline pipeline = pipelines .computeIfAbsent (key , k -> new DatabasePipeline ());
@@ -108,12 +112,13 @@ public EmbeddedDatabase createDatabase(DatabasePreparer preparer) throws Provide
108112 }
109113
110114 long invocationCount = pipeline .requests .incrementAndGet ();
111- if ( invocationCount > 1 ) {
112- if (invocationCount - 1 <= pipelineMaxCacheSize ) {
113- prepareDatabase ( key , - 1 );
114- }
115- reschedulePipeline (key );
115+ long databasesCount = pipeline . tasks . size () + pipeline . results . size ();
116+ if (result == null ) databasesCount --;
117+
118+ if ( databasesCount < invocationCount - 1 && databasesCount < pipelineMaxCacheSize ) {
119+ prepareDatabase (key , - 1 );
116120 }
121+ reschedulePipeline (key );
117122
118123 if (result == null ) {
119124 try {
@@ -141,6 +146,21 @@ protected PrefetchingTask prepareDatabase(PipelineKey key, int priority) {
141146 }
142147
143148 protected PrefetchingTask prepareNewDatabase (PipelineKey key , int priority ) {
149+ databaseCount .incrementAndGet ();
150+
151+ Pair <PipelineKey , EmbeddedDatabase > databaseToRemove = findDatabaseToRemove ().orElse (null );
152+ if (databaseToRemove != null ) {
153+ databaseCount .decrementAndGet ();
154+
155+ if (databaseToRemove .getKey ().equals (key )) {
156+ return executeTask (key , PrefetchingTask .withDatabase (databaseToRemove .getValue (), priority ));
157+ } else {
158+ databaseToRemove .getValue ().close ();
159+ DatabasePipeline pipeline = pipelines .get (databaseToRemove .getKey ());
160+ logger .trace ("Prepared database has been cleaned: {}" , pipeline .key );
161+ }
162+ }
163+
144164 return executeTask (key , PrefetchingTask .forPreparer (key .provider , key .preparer , priority ));
145165 }
146166
@@ -221,6 +241,35 @@ public void onFailure(Throwable error) {
221241 return task ;
222242 }
223243
244+ protected Optional <Pair <PipelineKey , EmbeddedDatabase >> findDatabaseToRemove () {
245+ while (databaseCount .get () > 35 ) {
246+ long timestampThreshold = System .currentTimeMillis () - 10_000 ;
247+
248+ PipelineKey key = pipelines .entrySet ().stream ()
249+ .map (e -> Pair .of (e .getKey (), e .getValue ().results .peek ()))
250+ .filter (e -> e .getValue () != null && e .getValue ().getTimestamp () < timestampThreshold )
251+ .min (Comparator .comparing (e -> e .getValue ().getTimestamp ()))
252+ .map (Pair ::getKey ).orElse (null );
253+
254+ if (key == null ) {
255+ return Optional .empty ();
256+ }
257+
258+ DatabasePipeline pipeline = pipelines .get (key );
259+ if (pipeline != null ) {
260+ PreparedResult result = pipeline .results .poll ();
261+ if (result != null ) {
262+ if (result .hasResult ()) {
263+ return Optional .of (Pair .of (key , result .get ()));
264+ } else {
265+ databaseCount .decrementAndGet ();
266+ }
267+ }
268+ }
269+ }
270+ return Optional .empty ();
271+ }
272+
224273 @ Override
225274 public boolean equals (Object o ) {
226275 if (this == o ) return true ;
@@ -288,6 +337,7 @@ protected enum State {
288337
289338 protected static class PreparedResult {
290339
340+ private final long timestamp = System .currentTimeMillis ();
291341 private final EmbeddedDatabase result ;
292342 private final Throwable error ;
293343
@@ -304,6 +354,14 @@ protected PreparedResult(EmbeddedDatabase result, Throwable error) {
304354 this .error = error ;
305355 }
306356
357+ public long getTimestamp () {
358+ return timestamp ;
359+ }
360+
361+ public boolean hasResult () {
362+ return result != null ;
363+ }
364+
307365 public EmbeddedDatabase get () throws ProviderException {
308366 if (result != null ) {
309367 return result ;
@@ -340,6 +398,10 @@ public static PrefetchingTask withDatabase(EmbeddedDatabase database, DatabasePr
340398 });
341399 }
342400
401+ public static PrefetchingTask withDatabase (EmbeddedDatabase database , int priority ) {
402+ return new PrefetchingTask (priority , EXISTING_DATABASE , () -> database );
403+ }
404+
343405 public static PrefetchingTask fromTask (PrefetchingTask task , int priority ) {
344406 return new PrefetchingTask (priority , task .type , task .action );
345407 }
0 commit comments