Skip to content

Commit 1dc277a

Browse files
add coroutine web socket chat and handle stale connections
1 parent 479cefa commit 1dc277a

File tree

6 files changed

+557
-157
lines changed

6 files changed

+557
-157
lines changed

src/main/kotlin/com/softeno/template/SoftenoReactiveMongoApp.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import reactor.core.publisher.Hooks
1515

1616
@SpringBootApplication
1717
@EnableConfigurationProperties
18-
@ConfigurationPropertiesScan("com.softeno")
18+
@ConfigurationPropertiesScan
1919
class SoftenoReactiveMongoApp
2020

2121
fun main(args: Array<String>) {

src/main/kotlin/com/softeno/template/app/kafka/KafkaSampleHandler.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import com.softeno.template.app.user.notification.CoroutineUserUpdateEmitter
99
import com.softeno.template.app.user.notification.ReactiveUserUpdateEmitter
1010
import com.softeno.template.sample.websocket.Message
1111
import com.softeno.template.sample.websocket.ReactiveMessageService
12+
import com.softeno.template.sample.websocket.WsMessageService
1213
import io.micrometer.tracing.Span
1314
import io.micrometer.tracing.Tracer
1415
import kotlinx.coroutines.DelicateCoroutinesApi
@@ -34,7 +35,7 @@ class ReactiveKafkaSampleController(
3435
@Qualifier(value = "kafkaSampleConsumerTemplate") private val reactiveKafkaConsumerTemplate: ReactiveKafkaConsumerTemplate<String, JsonNode>,
3536
private val objectMapper: ObjectMapper,
3637
private val tracer: Tracer,
37-
private val reactiveMessageService: ReactiveMessageService,
38+
private val wsMessageService: WsMessageService,
3839
private val reactiveUserUpdateEmitter: ReactiveUserUpdateEmitter,
3940
private val userUpdateEmitter: CoroutineUserUpdateEmitter,
4041
) : CommandLineRunner {
@@ -65,7 +66,7 @@ class ReactiveKafkaSampleController(
6566
val span = tracer.nextSpan().name("kafka-consumer")
6667
tracer.withSpan(span.start()).use {
6768
log.info("[kafka] rx sample: $kafkaMessage")
68-
reactiveMessageService.broadcast(kafkaMessage.toMessage())
69+
wsMessageService.broadcast(kafkaMessage.toMessage())
6970
reactiveUserUpdateEmitter.broadcast(kafkaMessage.toMessage())
7071
userUpdateEmitter.broadcast(kafkaMessage.toMessage())
7172
}
Lines changed: 314 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,314 @@
1+
package com.softeno.template.sample.websocket
2+
3+
import com.fasterxml.jackson.databind.ObjectMapper
4+
import kotlinx.coroutines.*
5+
import kotlinx.coroutines.channels.Channel
6+
import kotlinx.coroutines.flow.Flow
7+
import kotlinx.coroutines.flow.emptyFlow
8+
import kotlinx.coroutines.flow.receiveAsFlow
9+
import kotlinx.coroutines.reactor.awaitSingle
10+
import kotlinx.coroutines.reactor.awaitSingleOrNull
11+
import kotlinx.coroutines.reactor.mono
12+
import kotlinx.coroutines.selects.select
13+
import kotlinx.coroutines.slf4j.MDCContext
14+
import org.apache.commons.logging.LogFactory
15+
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty
16+
import org.springframework.context.annotation.Bean
17+
import org.springframework.context.annotation.Configuration
18+
import org.springframework.security.core.context.ReactiveSecurityContextHolder
19+
import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken
20+
import org.springframework.stereotype.Service
21+
import org.springframework.web.reactive.HandlerMapping
22+
import org.springframework.web.reactive.handler.SimpleUrlHandlerMapping
23+
import org.springframework.web.reactive.socket.WebSocketHandler
24+
import org.springframework.web.reactive.socket.WebSocketSession
25+
import org.springframework.web.reactive.socket.server.support.WebSocketHandlerAdapter
26+
import reactor.core.publisher.Mono
27+
import java.time.Instant
28+
import java.util.concurrent.ConcurrentHashMap
29+
import kotlin.coroutines.CoroutineContext
30+
import kotlin.time.Duration.Companion.seconds
31+
32+
@ConditionalOnProperty(
33+
name = ["com.softeno.ws.type"],
34+
havingValue = "coroutine",
35+
matchIfMissing = false
36+
)
37+
@Configuration
38+
class CoroutineWebSocketConfig(
39+
private val coroutineMessageService: CoroutineMessageService,
40+
private val config: ChatConfigProperties
41+
) {
42+
private val log = LogFactory.getLog(javaClass)
43+
44+
@Bean
45+
fun webSocketHandlerAdapter() = WebSocketHandlerAdapter()
46+
47+
@Bean
48+
fun handlerMapping(objectMapper: ObjectMapper): HandlerMapping {
49+
val simpleMapping = SimpleUrlHandlerMapping()
50+
simpleMapping.order = 10
51+
simpleMapping.urlMap = mapOf(
52+
"/ws" to webSocketHandler(objectMapper)
53+
)
54+
return simpleMapping
55+
}
56+
57+
@Bean
58+
fun webSocketHandler(objectMapper: ObjectMapper): WebSocketHandler {
59+
return WebSocketHandler { session ->
60+
log.info("ws: [chat] new session: ${session.id}")
61+
// Convert reactive to coroutine context
62+
mono( Dispatchers.IO + MDCContext()) {
63+
try {
64+
handleWebSocketSession(session, objectMapper)
65+
} catch (e: Exception) {
66+
log.error("ws: [chat] error in session: ${session.id}", e)
67+
}
68+
}.then()
69+
}
70+
}
71+
72+
private suspend fun handleWebSocketSession(
73+
session: WebSocketSession,
74+
objectMapper: ObjectMapper
75+
) = withContext(Dispatchers.IO + MDCContext()) {
76+
// Get user authentication
77+
78+
val userId = withContext(Dispatchers.IO + MDCContext()) {
79+
val authentication = ReactiveSecurityContextHolder.getContext().awaitSingle().authentication
80+
val token = (authentication as JwtAuthenticationToken).token
81+
val userId = token.claims["sub"] as String
82+
return@withContext userId
83+
}
84+
85+
// Register session
86+
coroutineMessageService.registerSession(session.id)
87+
88+
// Send initial messages
89+
coroutineMessageService.send(Message("SYSTEM", session.id, "HANDSHAKE"), session.id)
90+
coroutineMessageService.send(Message("SYSTEM", session.id, userId), session.id)
91+
92+
// Start concurrent coroutines for sending and receiving
93+
coroutineScope {
94+
// Coroutine for sending messages (including heartbeat)
95+
val sendingJob = launch(Dispatchers.IO + MDCContext() +
96+
SupervisorJob() + CoroutineExceptionHandler { context, throwable -> runBlocking(Dispatchers.IO + MDCContext()) {
97+
log.error("ws: [chat] failed to send message in session: $session.id", throwable)
98+
handleWebSocketSessionError(context, throwable, session)
99+
}
100+
}) {
101+
handleOutgoingMessages(session, objectMapper)
102+
}
103+
104+
// Coroutine for receiving messages
105+
val receivingJob = launch(Dispatchers.IO + MDCContext() +
106+
SupervisorJob() + CoroutineExceptionHandler { context, throwable -> runBlocking(Dispatchers.IO + MDCContext()) {
107+
log.error("ws: [chat] failed to receive message in session: ${session.id}", throwable)
108+
handleWebSocketSessionError(context, throwable, session)
109+
}
110+
}) {
111+
handleIncomingMessages(session, objectMapper)
112+
}
113+
114+
// Coroutine for heartbeat
115+
val heartbeatJob = launch(Dispatchers.IO + MDCContext() +
116+
SupervisorJob() + CoroutineExceptionHandler { context, throwable -> runBlocking(Dispatchers.IO + MDCContext()) {
117+
log.error("ws: [chat] failed to send heartbeat in session: ${session.id}", throwable)
118+
handleWebSocketSessionError(context, throwable, session)
119+
}
120+
}) {
121+
handleHeartbeat(session)
122+
}
123+
124+
// Wait for any job to complete (usually means disconnection)
125+
select {
126+
sendingJob.onJoin { }
127+
receivingJob.onJoin { }
128+
heartbeatJob.onJoin { }
129+
}
130+
131+
// Cancel remaining jobs
132+
sendingJob.cancelAndJoin()
133+
receivingJob.cancelAndJoin()
134+
heartbeatJob.cancelAndJoin()
135+
}
136+
137+
// Cleanup
138+
log.info("ws: [chat] disconnect chat session: ${session.id}")
139+
coroutineMessageService.unregisterSession(session.id)
140+
}
141+
142+
private suspend fun handleWebSocketSessionError(context: CoroutineContext, throwable: Throwable, session: WebSocketSession) =
143+
withContext(Dispatchers.IO + MDCContext()) {
144+
log.error("ws: [chat] error: ${throwable.message}", throwable)
145+
log.info("ws: [chat] closing session: ${session.id}")
146+
147+
closeSession(session)
148+
context.cancel()
149+
}
150+
151+
private suspend fun closeSession(session: WebSocketSession) {
152+
coroutineMessageService.unregisterSession(session.id)
153+
session.close().awaitSingleOrNull()
154+
}
155+
156+
private suspend fun handleOutgoingMessages(
157+
session: WebSocketSession,
158+
objectMapper: ObjectMapper
159+
) = withContext(Dispatchers.IO + MDCContext()) {
160+
coroutineMessageService.getMessageFlow(session.id).collect { message ->
161+
val json = message.toJson(objectMapper)
162+
log.info("ws: [chat] tx: $json")
163+
session.send(Mono.just(session.textMessage(json))).awaitSingleOrNull()
164+
}
165+
}
166+
167+
@OptIn(DelicateCoroutinesApi::class)
168+
private suspend fun handleIncomingMessages(
169+
session: WebSocketSession,
170+
objectMapper: ObjectMapper
171+
) = withContext(Dispatchers.IO + MDCContext()) {
172+
session.receive()
173+
.doOnNext { wsMessage ->
174+
// Process the message immediately within the reactive context
175+
try {
176+
val payloadText = wsMessage.payloadAsText
177+
val message = objectMapper.readValue(payloadText, Message::class.java)
178+
log.info("ws: [chat] rx: $message")
179+
180+
// Launch a coroutine to handle the message asynchronously
181+
GlobalScope.launch(Dispatchers.IO + MDCContext()) {
182+
try {
183+
when {
184+
message.content == "pong" && message.to == "SYSTEM" -> {
185+
log.debug("ws: [chat] received pong from session: ${session.id}")
186+
coroutineMessageService.updateLastPong(session.id)
187+
}
188+
else -> {
189+
coroutineMessageService.routeMessage(message)
190+
}
191+
}
192+
} catch (e: Exception) {
193+
log.error("ws: [chat] failed to route message in session: ${session.id}", e)
194+
}
195+
}
196+
} catch (e: Exception) {
197+
log.error("ws: [chat] failed to parse message in session: ${session.id}", e)
198+
}
199+
}
200+
.then()
201+
.awaitSingleOrNull()
202+
}
203+
204+
private suspend fun handleHeartbeat(session: WebSocketSession) = withContext(Dispatchers.IO + MDCContext()) {
205+
while (currentCoroutineContext().isActive) {
206+
delay(config.heartbeatIntervalSeconds.toLong().seconds)
207+
log.debug("ws: [chat] sending heartbeat to session: ${session.id}")
208+
coroutineMessageService.send(
209+
Message("SYSTEM", session.id, "ping"),
210+
session.id
211+
)
212+
}
213+
}
214+
}
215+
216+
@ConditionalOnProperty(
217+
name = ["com.softeno.ws.type"],
218+
havingValue = "coroutine",
219+
matchIfMissing = false
220+
)
221+
@Service
222+
class CoroutineMessageService(
223+
private val config: ChatConfigProperties
224+
) : WsMessageService {
225+
private val log = LogFactory.getLog(javaClass)
226+
227+
private val messageChannels = ConcurrentHashMap<String, Channel<Message>>()
228+
private val lastPongTimes = ConcurrentHashMap<String, Instant>()
229+
230+
// Lazy heartbeat monitoring
231+
@OptIn(DelicateCoroutinesApi::class)
232+
private val heartbeatMonitoring: Job by lazy {
233+
GlobalScope.launch {
234+
while (isActive) {
235+
delay(config.heartbeatIntervalSeconds.toLong().seconds)
236+
checkStaleConnections()
237+
}
238+
}
239+
}
240+
241+
suspend fun registerSession(sessionId: String) = withContext(Dispatchers.IO + MDCContext()) {
242+
messageChannels[sessionId] = Channel(capacity = Channel.UNLIMITED)
243+
lastPongTimes[sessionId] = Instant.now()
244+
245+
// Start heartbeat monitoring when first session connects
246+
if (config.staleCheck) {
247+
heartbeatMonitoring
248+
}
249+
250+
log.info("ws: [chat] registered session: $sessionId")
251+
}
252+
253+
fun unregisterSession(sessionId: String) {
254+
messageChannels.remove(sessionId)?.close()
255+
lastPongTimes.remove(sessionId)
256+
log.info("ws: [chat] unregistered session: $sessionId")
257+
}
258+
259+
fun send(message: Message, sessionId: String): Message {
260+
messageChannels[sessionId]?.trySend(message)?.let { result ->
261+
if (result.isFailure) {
262+
log.warn("ws: [chat] failed to send message to session: $sessionId")
263+
}
264+
}
265+
return message
266+
}
267+
268+
override fun broadcast(message: Message): Message {
269+
messageChannels.values.forEach { channel ->
270+
channel.trySend(message)
271+
}
272+
return message
273+
}
274+
275+
fun routeMessage(message: Message) {
276+
when (message.to) {
277+
"ALL" -> broadcast(message)
278+
else -> {
279+
// Try to send to specific session first, then to user
280+
if (messageChannels.containsKey(message.to)) {
281+
send(message, message.to)
282+
} else {
283+
log.error("ws: [chat] failed to send message: $message to session: ${message.to} - unknown session")
284+
}
285+
}
286+
}
287+
}
288+
289+
fun getMessageFlow(sessionId: String): Flow<Message> {
290+
return messageChannels[sessionId]?.receiveAsFlow() ?: emptyFlow()
291+
}
292+
293+
fun updateLastPong(sessionId: String) {
294+
lastPongTimes[sessionId] = Instant.now()
295+
}
296+
297+
private fun checkStaleConnections() {
298+
val staleThreshold = Instant.now().minusSeconds(config.staleCheckThresholdSeconds.toLong())
299+
val staleSessions = mutableListOf<String>()
300+
301+
// First, identify stale sessions without modifying the collection
302+
lastPongTimes.forEach { (sessionId, lastPong) ->
303+
if (lastPong.isBefore(staleThreshold)) {
304+
staleSessions.add(sessionId)
305+
}
306+
}
307+
308+
// Then, clean up stale sessions properly in coroutine context
309+
staleSessions.forEach { sessionId ->
310+
log.warn("ws: [chat] removing stale connection: $sessionId (last pong: ${lastPongTimes[sessionId]})")
311+
unregisterSession(sessionId)
312+
}
313+
}
314+
}

0 commit comments

Comments
 (0)