|
| 1 | +package com.softeno.template.sample.websocket |
| 2 | + |
| 3 | +import com.fasterxml.jackson.databind.ObjectMapper |
| 4 | +import kotlinx.coroutines.* |
| 5 | +import kotlinx.coroutines.channels.Channel |
| 6 | +import kotlinx.coroutines.flow.Flow |
| 7 | +import kotlinx.coroutines.flow.emptyFlow |
| 8 | +import kotlinx.coroutines.flow.receiveAsFlow |
| 9 | +import kotlinx.coroutines.reactor.awaitSingle |
| 10 | +import kotlinx.coroutines.reactor.awaitSingleOrNull |
| 11 | +import kotlinx.coroutines.reactor.mono |
| 12 | +import kotlinx.coroutines.selects.select |
| 13 | +import kotlinx.coroutines.slf4j.MDCContext |
| 14 | +import org.apache.commons.logging.LogFactory |
| 15 | +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty |
| 16 | +import org.springframework.context.annotation.Bean |
| 17 | +import org.springframework.context.annotation.Configuration |
| 18 | +import org.springframework.security.core.context.ReactiveSecurityContextHolder |
| 19 | +import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken |
| 20 | +import org.springframework.stereotype.Service |
| 21 | +import org.springframework.web.reactive.HandlerMapping |
| 22 | +import org.springframework.web.reactive.handler.SimpleUrlHandlerMapping |
| 23 | +import org.springframework.web.reactive.socket.WebSocketHandler |
| 24 | +import org.springframework.web.reactive.socket.WebSocketSession |
| 25 | +import org.springframework.web.reactive.socket.server.support.WebSocketHandlerAdapter |
| 26 | +import reactor.core.publisher.Mono |
| 27 | +import java.time.Instant |
| 28 | +import java.util.concurrent.ConcurrentHashMap |
| 29 | +import kotlin.coroutines.CoroutineContext |
| 30 | +import kotlin.time.Duration.Companion.seconds |
| 31 | + |
| 32 | +@ConditionalOnProperty( |
| 33 | + name = ["com.softeno.ws.type"], |
| 34 | + havingValue = "coroutine", |
| 35 | + matchIfMissing = false |
| 36 | +) |
| 37 | +@Configuration |
| 38 | +class CoroutineWebSocketConfig( |
| 39 | + private val coroutineMessageService: CoroutineMessageService, |
| 40 | + private val config: ChatConfigProperties |
| 41 | +) { |
| 42 | + private val log = LogFactory.getLog(javaClass) |
| 43 | + |
| 44 | + @Bean |
| 45 | + fun webSocketHandlerAdapter() = WebSocketHandlerAdapter() |
| 46 | + |
| 47 | + @Bean |
| 48 | + fun handlerMapping(objectMapper: ObjectMapper): HandlerMapping { |
| 49 | + val simpleMapping = SimpleUrlHandlerMapping() |
| 50 | + simpleMapping.order = 10 |
| 51 | + simpleMapping.urlMap = mapOf( |
| 52 | + "/ws" to webSocketHandler(objectMapper) |
| 53 | + ) |
| 54 | + return simpleMapping |
| 55 | + } |
| 56 | + |
| 57 | + @Bean |
| 58 | + fun webSocketHandler(objectMapper: ObjectMapper): WebSocketHandler { |
| 59 | + return WebSocketHandler { session -> |
| 60 | + log.info("ws: [chat] new session: ${session.id}") |
| 61 | + // Convert reactive to coroutine context |
| 62 | + mono( Dispatchers.IO + MDCContext()) { |
| 63 | + try { |
| 64 | + handleWebSocketSession(session, objectMapper) |
| 65 | + } catch (e: Exception) { |
| 66 | + log.error("ws: [chat] error in session: ${session.id}", e) |
| 67 | + } |
| 68 | + }.then() |
| 69 | + } |
| 70 | + } |
| 71 | + |
| 72 | + private suspend fun handleWebSocketSession( |
| 73 | + session: WebSocketSession, |
| 74 | + objectMapper: ObjectMapper |
| 75 | + ) = withContext(Dispatchers.IO + MDCContext()) { |
| 76 | + // Get user authentication |
| 77 | + |
| 78 | + val userId = withContext(Dispatchers.IO + MDCContext()) { |
| 79 | + val authentication = ReactiveSecurityContextHolder.getContext().awaitSingle().authentication |
| 80 | + val token = (authentication as JwtAuthenticationToken).token |
| 81 | + val userId = token.claims["sub"] as String |
| 82 | + return@withContext userId |
| 83 | + } |
| 84 | + |
| 85 | + // Register session |
| 86 | + coroutineMessageService.registerSession(session.id) |
| 87 | + |
| 88 | + // Send initial messages |
| 89 | + coroutineMessageService.send(Message("SYSTEM", session.id, "HANDSHAKE"), session.id) |
| 90 | + coroutineMessageService.send(Message("SYSTEM", session.id, userId), session.id) |
| 91 | + |
| 92 | + // Start concurrent coroutines for sending and receiving |
| 93 | + coroutineScope { |
| 94 | + // Coroutine for sending messages (including heartbeat) |
| 95 | + val sendingJob = launch(Dispatchers.IO + MDCContext() + |
| 96 | + SupervisorJob() + CoroutineExceptionHandler { context, throwable -> runBlocking(Dispatchers.IO + MDCContext()) { |
| 97 | + log.error("ws: [chat] failed to send message in session: $session.id", throwable) |
| 98 | + handleWebSocketSessionError(context, throwable, session) |
| 99 | + } |
| 100 | + }) { |
| 101 | + handleOutgoingMessages(session, objectMapper) |
| 102 | + } |
| 103 | + |
| 104 | + // Coroutine for receiving messages |
| 105 | + val receivingJob = launch(Dispatchers.IO + MDCContext() + |
| 106 | + SupervisorJob() + CoroutineExceptionHandler { context, throwable -> runBlocking(Dispatchers.IO + MDCContext()) { |
| 107 | + log.error("ws: [chat] failed to receive message in session: ${session.id}", throwable) |
| 108 | + handleWebSocketSessionError(context, throwable, session) |
| 109 | + } |
| 110 | + }) { |
| 111 | + handleIncomingMessages(session, objectMapper) |
| 112 | + } |
| 113 | + |
| 114 | + // Coroutine for heartbeat |
| 115 | + val heartbeatJob = launch(Dispatchers.IO + MDCContext() + |
| 116 | + SupervisorJob() + CoroutineExceptionHandler { context, throwable -> runBlocking(Dispatchers.IO + MDCContext()) { |
| 117 | + log.error("ws: [chat] failed to send heartbeat in session: ${session.id}", throwable) |
| 118 | + handleWebSocketSessionError(context, throwable, session) |
| 119 | + } |
| 120 | + }) { |
| 121 | + handleHeartbeat(session) |
| 122 | + } |
| 123 | + |
| 124 | + // Wait for any job to complete (usually means disconnection) |
| 125 | + select { |
| 126 | + sendingJob.onJoin { } |
| 127 | + receivingJob.onJoin { } |
| 128 | + heartbeatJob.onJoin { } |
| 129 | + } |
| 130 | + |
| 131 | + // Cancel remaining jobs |
| 132 | + sendingJob.cancelAndJoin() |
| 133 | + receivingJob.cancelAndJoin() |
| 134 | + heartbeatJob.cancelAndJoin() |
| 135 | + } |
| 136 | + |
| 137 | + // Cleanup |
| 138 | + log.info("ws: [chat] disconnect chat session: ${session.id}") |
| 139 | + coroutineMessageService.unregisterSession(session.id) |
| 140 | + } |
| 141 | + |
| 142 | + private suspend fun handleWebSocketSessionError(context: CoroutineContext, throwable: Throwable, session: WebSocketSession) = |
| 143 | + withContext(Dispatchers.IO + MDCContext()) { |
| 144 | + log.error("ws: [chat] error: ${throwable.message}", throwable) |
| 145 | + log.info("ws: [chat] closing session: ${session.id}") |
| 146 | + |
| 147 | + closeSession(session) |
| 148 | + context.cancel() |
| 149 | + } |
| 150 | + |
| 151 | + private suspend fun closeSession(session: WebSocketSession) { |
| 152 | + coroutineMessageService.unregisterSession(session.id) |
| 153 | + session.close().awaitSingleOrNull() |
| 154 | + } |
| 155 | + |
| 156 | + private suspend fun handleOutgoingMessages( |
| 157 | + session: WebSocketSession, |
| 158 | + objectMapper: ObjectMapper |
| 159 | + ) = withContext(Dispatchers.IO + MDCContext()) { |
| 160 | + coroutineMessageService.getMessageFlow(session.id).collect { message -> |
| 161 | + val json = message.toJson(objectMapper) |
| 162 | + log.info("ws: [chat] tx: $json") |
| 163 | + session.send(Mono.just(session.textMessage(json))).awaitSingleOrNull() |
| 164 | + } |
| 165 | + } |
| 166 | + |
| 167 | + @OptIn(DelicateCoroutinesApi::class) |
| 168 | + private suspend fun handleIncomingMessages( |
| 169 | + session: WebSocketSession, |
| 170 | + objectMapper: ObjectMapper |
| 171 | + ) = withContext(Dispatchers.IO + MDCContext()) { |
| 172 | + session.receive() |
| 173 | + .doOnNext { wsMessage -> |
| 174 | + // Process the message immediately within the reactive context |
| 175 | + try { |
| 176 | + val payloadText = wsMessage.payloadAsText |
| 177 | + val message = objectMapper.readValue(payloadText, Message::class.java) |
| 178 | + log.info("ws: [chat] rx: $message") |
| 179 | + |
| 180 | + // Launch a coroutine to handle the message asynchronously |
| 181 | + GlobalScope.launch(Dispatchers.IO + MDCContext()) { |
| 182 | + try { |
| 183 | + when { |
| 184 | + message.content == "pong" && message.to == "SYSTEM" -> { |
| 185 | + log.debug("ws: [chat] received pong from session: ${session.id}") |
| 186 | + coroutineMessageService.updateLastPong(session.id) |
| 187 | + } |
| 188 | + else -> { |
| 189 | + coroutineMessageService.routeMessage(message) |
| 190 | + } |
| 191 | + } |
| 192 | + } catch (e: Exception) { |
| 193 | + log.error("ws: [chat] failed to route message in session: ${session.id}", e) |
| 194 | + } |
| 195 | + } |
| 196 | + } catch (e: Exception) { |
| 197 | + log.error("ws: [chat] failed to parse message in session: ${session.id}", e) |
| 198 | + } |
| 199 | + } |
| 200 | + .then() |
| 201 | + .awaitSingleOrNull() |
| 202 | + } |
| 203 | + |
| 204 | + private suspend fun handleHeartbeat(session: WebSocketSession) = withContext(Dispatchers.IO + MDCContext()) { |
| 205 | + while (currentCoroutineContext().isActive) { |
| 206 | + delay(config.heartbeatIntervalSeconds.toLong().seconds) |
| 207 | + log.debug("ws: [chat] sending heartbeat to session: ${session.id}") |
| 208 | + coroutineMessageService.send( |
| 209 | + Message("SYSTEM", session.id, "ping"), |
| 210 | + session.id |
| 211 | + ) |
| 212 | + } |
| 213 | + } |
| 214 | +} |
| 215 | + |
| 216 | +@ConditionalOnProperty( |
| 217 | + name = ["com.softeno.ws.type"], |
| 218 | + havingValue = "coroutine", |
| 219 | + matchIfMissing = false |
| 220 | +) |
| 221 | +@Service |
| 222 | +class CoroutineMessageService( |
| 223 | + private val config: ChatConfigProperties |
| 224 | +) : WsMessageService { |
| 225 | + private val log = LogFactory.getLog(javaClass) |
| 226 | + |
| 227 | + private val messageChannels = ConcurrentHashMap<String, Channel<Message>>() |
| 228 | + private val lastPongTimes = ConcurrentHashMap<String, Instant>() |
| 229 | + |
| 230 | + // Lazy heartbeat monitoring |
| 231 | + @OptIn(DelicateCoroutinesApi::class) |
| 232 | + private val heartbeatMonitoring: Job by lazy { |
| 233 | + GlobalScope.launch { |
| 234 | + while (isActive) { |
| 235 | + delay(config.heartbeatIntervalSeconds.toLong().seconds) |
| 236 | + checkStaleConnections() |
| 237 | + } |
| 238 | + } |
| 239 | + } |
| 240 | + |
| 241 | + suspend fun registerSession(sessionId: String) = withContext(Dispatchers.IO + MDCContext()) { |
| 242 | + messageChannels[sessionId] = Channel(capacity = Channel.UNLIMITED) |
| 243 | + lastPongTimes[sessionId] = Instant.now() |
| 244 | + |
| 245 | + // Start heartbeat monitoring when first session connects |
| 246 | + if (config.staleCheck) { |
| 247 | + heartbeatMonitoring |
| 248 | + } |
| 249 | + |
| 250 | + log.info("ws: [chat] registered session: $sessionId") |
| 251 | + } |
| 252 | + |
| 253 | + fun unregisterSession(sessionId: String) { |
| 254 | + messageChannels.remove(sessionId)?.close() |
| 255 | + lastPongTimes.remove(sessionId) |
| 256 | + log.info("ws: [chat] unregistered session: $sessionId") |
| 257 | + } |
| 258 | + |
| 259 | + fun send(message: Message, sessionId: String): Message { |
| 260 | + messageChannels[sessionId]?.trySend(message)?.let { result -> |
| 261 | + if (result.isFailure) { |
| 262 | + log.warn("ws: [chat] failed to send message to session: $sessionId") |
| 263 | + } |
| 264 | + } |
| 265 | + return message |
| 266 | + } |
| 267 | + |
| 268 | + override fun broadcast(message: Message): Message { |
| 269 | + messageChannels.values.forEach { channel -> |
| 270 | + channel.trySend(message) |
| 271 | + } |
| 272 | + return message |
| 273 | + } |
| 274 | + |
| 275 | + fun routeMessage(message: Message) { |
| 276 | + when (message.to) { |
| 277 | + "ALL" -> broadcast(message) |
| 278 | + else -> { |
| 279 | + // Try to send to specific session first, then to user |
| 280 | + if (messageChannels.containsKey(message.to)) { |
| 281 | + send(message, message.to) |
| 282 | + } else { |
| 283 | + log.error("ws: [chat] failed to send message: $message to session: ${message.to} - unknown session") |
| 284 | + } |
| 285 | + } |
| 286 | + } |
| 287 | + } |
| 288 | + |
| 289 | + fun getMessageFlow(sessionId: String): Flow<Message> { |
| 290 | + return messageChannels[sessionId]?.receiveAsFlow() ?: emptyFlow() |
| 291 | + } |
| 292 | + |
| 293 | + fun updateLastPong(sessionId: String) { |
| 294 | + lastPongTimes[sessionId] = Instant.now() |
| 295 | + } |
| 296 | + |
| 297 | + private fun checkStaleConnections() { |
| 298 | + val staleThreshold = Instant.now().minusSeconds(config.staleCheckThresholdSeconds.toLong()) |
| 299 | + val staleSessions = mutableListOf<String>() |
| 300 | + |
| 301 | + // First, identify stale sessions without modifying the collection |
| 302 | + lastPongTimes.forEach { (sessionId, lastPong) -> |
| 303 | + if (lastPong.isBefore(staleThreshold)) { |
| 304 | + staleSessions.add(sessionId) |
| 305 | + } |
| 306 | + } |
| 307 | + |
| 308 | + // Then, clean up stale sessions properly in coroutine context |
| 309 | + staleSessions.forEach { sessionId -> |
| 310 | + log.warn("ws: [chat] removing stale connection: $sessionId (last pong: ${lastPongTimes[sessionId]})") |
| 311 | + unregisterSession(sessionId) |
| 312 | + } |
| 313 | + } |
| 314 | +} |
0 commit comments