Skip to content

Commit 6218db9

Browse files
committed
Backport of InMemoryWebSession changes
- hooks to check expired sessions in both create and retrieve. - maxSessions limit on the total number of sessions. - getSessions method for management purposes - removeExpiredSessions public API Issue: SPR-17020, SPR-16713
1 parent 7ea8ecb commit 6218db9

File tree

2 files changed

+149
-80
lines changed

2 files changed

+149
-80
lines changed

spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java

Lines changed: 110 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@
2020
import java.time.Duration;
2121
import java.time.Instant;
2222
import java.time.ZoneId;
23+
import java.time.temporal.ChronoUnit;
24+
import java.util.Collections;
2325
import java.util.Iterator;
2426
import java.util.Map;
2527
import java.util.concurrent.ConcurrentHashMap;
26-
import java.util.concurrent.ConcurrentMap;
2728
import java.util.concurrent.atomic.AtomicReference;
2829
import java.util.concurrent.locks.ReentrantLock;
2930

@@ -43,20 +44,37 @@
4344
*/
4445
public class InMemoryWebSessionStore implements WebSessionStore {
4546

46-
/** Minimum period between expiration checks */
47-
private static final Duration EXPIRATION_CHECK_PERIOD = Duration.ofSeconds(60);
48-
4947
private static final IdGenerator idGenerator = new JdkIdGenerator();
5048

5149

50+
private int maxSessions = 10000;
51+
5252
private Clock clock = Clock.system(ZoneId.of("GMT"));
5353

54-
private final ConcurrentMap<String, InMemoryWebSession> sessions = new ConcurrentHashMap<>();
54+
private final Map<String, InMemoryWebSession> sessions = new ConcurrentHashMap<>();
55+
56+
private final ExpiredSessionChecker expiredSessionChecker = new ExpiredSessionChecker();
5557

56-
private volatile Instant nextExpirationCheckTime = Instant.now(this.clock).plus(EXPIRATION_CHECK_PERIOD);
5758

58-
private final ReentrantLock expirationCheckLock = new ReentrantLock();
59+
/**
60+
* Set the maximum number of sessions that can be stored. Once the limit is
61+
* reached, any attempt to store an additional session will result in an
62+
* {@link IllegalStateException}.
63+
* <p>By default set to 10000.
64+
* @param maxSessions the maximum number of sessions
65+
* @since 5.1
66+
*/
67+
public void setMaxSessions(int maxSessions) {
68+
this.maxSessions = maxSessions;
69+
}
5970

71+
/**
72+
* Return the maximum number of sessions that can be stored.
73+
* @since 5.1
74+
*/
75+
public int getMaxSessions() {
76+
return this.maxSessions;
77+
}
6078

6179
/**
6280
* Configure the {@link Clock} to use to set lastAccessTime on every created
@@ -70,8 +88,7 @@ public class InMemoryWebSessionStore implements WebSessionStore {
7088
public void setClock(Clock clock) {
7189
Assert.notNull(clock, "Clock is required");
7290
this.clock = clock;
73-
// Force a check when clock changes..
74-
this.nextExpirationCheckTime = Instant.now(this.clock);
91+
removeExpiredSessions();
7592
}
7693

7794
/**
@@ -81,67 +98,67 @@ public Clock getClock() {
8198
return this.clock;
8299
}
83100

101+
/**
102+
* Return the map of sessions with an {@link Collections#unmodifiableMap
103+
* unmodifiable} wrapper. This could be used for management purposes, to
104+
* list active sessions, invalidate expired ones, etc.
105+
* @since 5.1
106+
*/
107+
public Map<String, InMemoryWebSession> getSessions() {
108+
return Collections.unmodifiableMap(this.sessions);
109+
}
110+
84111

85112
@Override
86113
public Mono<WebSession> createWebSession() {
87-
return Mono.fromSupplier(InMemoryWebSession::new);
114+
Instant now = this.clock.instant();
115+
this.expiredSessionChecker.checkIfNecessary(now);
116+
return Mono.fromSupplier(() -> new InMemoryWebSession(now));
88117
}
89118

90119
@Override
91120
public Mono<WebSession> retrieveSession(String id) {
92-
Instant currentTime = Instant.now(this.clock);
93-
if (!this.sessions.isEmpty() && !currentTime.isBefore(this.nextExpirationCheckTime)) {
94-
checkExpiredSessions(currentTime);
95-
}
96-
121+
Instant now = this.clock.instant();
122+
this.expiredSessionChecker.checkIfNecessary(now);
97123
InMemoryWebSession session = this.sessions.get(id);
98124
if (session == null) {
99125
return Mono.empty();
100126
}
101-
else if (session.isExpired(currentTime)) {
127+
else if (session.isExpired(now)) {
102128
this.sessions.remove(id);
103129
return Mono.empty();
104130
}
105131
else {
106-
session.updateLastAccessTime(currentTime);
132+
session.updateLastAccessTime(now);
107133
return Mono.just(session);
108134
}
109135
}
110136

111-
private void checkExpiredSessions(Instant currentTime) {
112-
if (this.expirationCheckLock.tryLock()) {
113-
try {
114-
Iterator<InMemoryWebSession> iterator = this.sessions.values().iterator();
115-
while (iterator.hasNext()) {
116-
InMemoryWebSession session = iterator.next();
117-
if (session.isExpired(currentTime)) {
118-
iterator.remove();
119-
session.invalidate();
120-
}
121-
}
122-
}
123-
finally {
124-
this.nextExpirationCheckTime = currentTime.plus(EXPIRATION_CHECK_PERIOD);
125-
this.expirationCheckLock.unlock();
126-
}
127-
}
128-
}
129-
130137
@Override
131138
public Mono<Void> removeSession(String id) {
132139
this.sessions.remove(id);
133140
return Mono.empty();
134141
}
135142

136-
public Mono<WebSession> updateLastAccessTime(WebSession webSession) {
143+
public Mono<WebSession> updateLastAccessTime(WebSession session) {
137144
return Mono.fromSupplier(() -> {
138-
Assert.isInstanceOf(InMemoryWebSession.class, webSession);
139-
InMemoryWebSession session = (InMemoryWebSession) webSession;
140-
session.updateLastAccessTime(Instant.now(getClock()));
145+
Assert.isInstanceOf(InMemoryWebSession.class, session);
146+
((InMemoryWebSession) session).updateLastAccessTime(this.clock.instant());
141147
return session;
142148
});
143149
}
144150

151+
/**
152+
* Check for expired sessions and remove them. Typically such checks are
153+
* kicked off lazily during calls to {@link #createWebSession() create} or
154+
* {@link #retrieveSession retrieve}, no less than 60 seconds apart.
155+
* This method can be called to force a check at a specific time.
156+
* @since 5.1
157+
*/
158+
public void removeExpiredSessions() {
159+
this.expiredSessionChecker.removeExpiredSessions(this.clock.instant());
160+
}
161+
145162

146163
private class InMemoryWebSession implements WebSession {
147164

@@ -157,8 +174,9 @@ private class InMemoryWebSession implements WebSession {
157174

158175
private final AtomicReference<State> state = new AtomicReference<>(State.NEW);
159176

160-
public InMemoryWebSession() {
161-
this.creationTime = Instant.now(getClock());
177+
178+
public InMemoryWebSession(Instant creationTime) {
179+
this.creationTime = creationTime;
162180
this.lastAccessTime = this.creationTime;
163181
}
164182

@@ -222,6 +240,12 @@ public Mono<Void> invalidate() {
222240

223241
@Override
224242
public Mono<Void> save() {
243+
if (sessions.size() >= maxSessions) {
244+
expiredSessionChecker.removeExpiredSessions(clock.instant());
245+
if (sessions.size() >= maxSessions) {
246+
return Mono.error(new IllegalStateException("Max sessions limit reached: " + sessions.size()));
247+
}
248+
}
225249
if (!getAttributes().isEmpty()) {
226250
this.state.compareAndSet(State.NEW, State.STARTED);
227251
}
@@ -231,14 +255,14 @@ public Mono<Void> save() {
231255

232256
@Override
233257
public boolean isExpired() {
234-
return isExpired(Instant.now(getClock()));
258+
return isExpired(clock.instant());
235259
}
236260

237-
private boolean isExpired(Instant currentTime) {
261+
private boolean isExpired(Instant now) {
238262
if (this.state.get().equals(State.EXPIRED)) {
239263
return true;
240264
}
241-
if (checkExpired(currentTime)) {
265+
if (checkExpired(now)) {
242266
this.state.set(State.EXPIRED);
243267
return true;
244268
}
@@ -256,6 +280,47 @@ private void updateLastAccessTime(Instant currentTime) {
256280
}
257281

258282

283+
private class ExpiredSessionChecker {
284+
285+
/** Max time between expiration checks. */
286+
private static final int CHECK_PERIOD = 60 * 1000;
287+
288+
289+
private final ReentrantLock lock = new ReentrantLock();
290+
291+
private Instant checkTime = clock.instant().plus(CHECK_PERIOD, ChronoUnit.MILLIS);
292+
293+
294+
public void checkIfNecessary(Instant now) {
295+
if (this.checkTime.isBefore(now)) {
296+
removeExpiredSessions(now);
297+
}
298+
}
299+
300+
public void removeExpiredSessions(Instant now) {
301+
if (sessions.isEmpty()) {
302+
return;
303+
}
304+
if (this.lock.tryLock()) {
305+
try {
306+
Iterator<InMemoryWebSession> iterator = sessions.values().iterator();
307+
while (iterator.hasNext()) {
308+
InMemoryWebSession session = iterator.next();
309+
if (session.isExpired(now)) {
310+
iterator.remove();
311+
session.invalidate();
312+
}
313+
}
314+
}
315+
finally {
316+
this.checkTime = now.plus(CHECK_PERIOD, ChronoUnit.MILLIS);
317+
this.lock.unlock();
318+
}
319+
}
320+
}
321+
}
322+
323+
259324
private enum State { NEW, STARTED, EXPIRED }
260325

261326
}

spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java

Lines changed: 39 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,20 @@
1818
import java.time.Clock;
1919
import java.time.Duration;
2020
import java.time.Instant;
21+
import java.util.Map;
22+
import java.util.stream.IntStream;
2123

2224
import org.junit.Test;
2325

26+
import org.springframework.beans.DirectFieldAccessor;
2427
import org.springframework.web.server.WebSession;
2528

2629
import static junit.framework.TestCase.assertSame;
30+
import static org.junit.Assert.assertEquals;
2731
import static org.junit.Assert.assertNotNull;
2832
import static org.junit.Assert.assertNull;
2933
import static org.junit.Assert.assertTrue;
34+
import static org.junit.Assert.fail;
3035

3136
/**
3237
* Unit tests for {@link InMemoryWebSessionStore}.
@@ -55,7 +60,7 @@ public void startsSessionImplicitly() {
5560
}
5661

5762
@Test
58-
public void retrieveExpiredSession() throws Exception {
63+
public void retrieveExpiredSession() {
5964
WebSession session = this.store.createWebSession().block();
6065
assertNotNull(session);
6166
session.getAttributes().put("foo", "bar");
@@ -73,7 +78,7 @@ public void retrieveExpiredSession() throws Exception {
7378
}
7479

7580
@Test
76-
public void lastAccessTimeIsUpdatedOnRetrieve() throws Exception {
81+
public void lastAccessTimeIsUpdatedOnRetrieve() {
7782
WebSession session1 = this.store.createWebSession().block();
7883
assertNotNull(session1);
7984
String id = session1.getId();
@@ -91,46 +96,45 @@ public void lastAccessTimeIsUpdatedOnRetrieve() throws Exception {
9196
}
9297

9398
@Test
94-
public void expirationChecks() throws Exception {
95-
// Create 3 sessions
96-
WebSession session1 = this.store.createWebSession().block();
97-
assertNotNull(session1);
98-
session1.start();
99-
session1.save().block();
99+
public void expirationCheckPeriod() {
100100

101-
WebSession session2 = this.store.createWebSession().block();
102-
assertNotNull(session2);
103-
session2.start();
104-
session2.save().block();
101+
DirectFieldAccessor accessor = new DirectFieldAccessor(this.store);
102+
Map<?,?> sessions = (Map<?, ?>) accessor.getPropertyValue("sessions");
103+
assertNotNull(sessions);
105104

106-
WebSession session3 = this.store.createWebSession().block();
107-
assertNotNull(session3);
108-
session3.start();
109-
session3.save().block();
105+
// Create 100 sessions
106+
IntStream.range(0, 100).forEach(i -> insertSession());
107+
assertEquals(100, sessions.size());
110108

111-
// Fast-forward 31 minutes
112-
this.store.setClock(Clock.offset(this.store.getClock(), Duration.ofMinutes(31)));
109+
// Force a new clock (31 min later), don't use setter which would clean expired sessions
110+
accessor.setPropertyValue("clock", Clock.offset(this.store.getClock(), Duration.ofMinutes(31)));
111+
assertEquals(100, sessions.size());
113112

114-
// Create 2 more sessions
115-
WebSession session4 = this.store.createWebSession().block();
116-
assertNotNull(session4);
117-
session4.start();
118-
session4.save().block();
113+
// Create 1 more which forces a time-based check (clock moved forward)
114+
insertSession();
115+
assertEquals(1, sessions.size());
116+
}
119117

120-
WebSession session5 = this.store.createWebSession().block();
121-
assertNotNull(session5);
122-
session5.start();
123-
session5.save().block();
118+
@Test
119+
public void maxSessions() {
124120

125-
// Retrieve, forcing cleanup of all expired..
126-
assertNull(this.store.retrieveSession(session1.getId()).block());
127-
assertNull(this.store.retrieveSession(session2.getId()).block());
128-
assertNull(this.store.retrieveSession(session3.getId()).block());
121+
IntStream.range(0, 10000).forEach(i -> insertSession());
129122

130-
assertNotNull(this.store.retrieveSession(session4.getId()).block());
131-
assertNotNull(this.store.retrieveSession(session5.getId()).block());
123+
try {
124+
insertSession();
125+
fail();
126+
}
127+
catch (IllegalStateException ex) {
128+
assertEquals("Max sessions limit reached: 10000", ex.getMessage());
129+
}
132130
}
133131

132+
private WebSession insertSession() {
133+
WebSession session = this.store.createWebSession().block();
134+
assertNotNull(session);
135+
session.start();
136+
session.save().block();
137+
return session;
138+
}
134139

135-
136-
}
140+
}

0 commit comments

Comments
 (0)