@@ -83,11 +83,11 @@ class CoroutineWebSocketConfig(
8383 }
8484
8585 // Register session
86- coroutineMessageService.registerSession(session.id )
86+ coroutineMessageService.registerSession(session)
8787
8888 // Send initial messages
89- coroutineMessageService.send(Message (" SYSTEM" , session.id, " HANDSHAKE" ), session.id )
90- coroutineMessageService.send(Message (" SYSTEM" , session.id, userId), session.id )
89+ coroutineMessageService.send(Message (" SYSTEM" , session.id, " HANDSHAKE" ), session)
90+ coroutineMessageService.send(Message (" SYSTEM" , session.id, userId), session)
9191
9292 // Start concurrent coroutines for sending and receiving
9393 coroutineScope {
@@ -136,7 +136,7 @@ class CoroutineWebSocketConfig(
136136
137137 // Cleanup
138138 log.info(" ws: [chat] disconnect chat session: ${session.id} " )
139- coroutineMessageService.unregisterSession(session.id )
139+ coroutineMessageService.unregisterSession(session)
140140 }
141141
142142 private suspend fun handleWebSocketSessionError (context : CoroutineContext , throwable : Throwable , session : WebSocketSession ) =
@@ -149,15 +149,15 @@ class CoroutineWebSocketConfig(
149149 }
150150
151151 private suspend fun closeSession (session : WebSocketSession ) {
152- coroutineMessageService.unregisterSession(session.id )
152+ coroutineMessageService.unregisterSession(session)
153153 session.close().awaitSingleOrNull()
154154 }
155155
156156 private suspend fun handleOutgoingMessages (
157157 session : WebSocketSession ,
158158 objectMapper : ObjectMapper
159159 ) = withContext(Dispatchers .IO + MDCContext ()) {
160- coroutineMessageService.getMessageFlow(session.id ).collect { message ->
160+ coroutineMessageService.getMessageFlow(session).collect { message ->
161161 val json = message.toJson(objectMapper)
162162 log.info(" ws: [chat] tx: $json " )
163163 session.send(Mono .just(session.textMessage(json))).awaitSingleOrNull()
@@ -183,7 +183,7 @@ class CoroutineWebSocketConfig(
183183 when {
184184 message.content == " pong" && message.to == " SYSTEM" -> {
185185 log.debug(" ws: [chat] received pong from session: ${session.id} " )
186- coroutineMessageService.updateLastPong(session.id )
186+ coroutineMessageService.updateLastPong(session)
187187 }
188188 else -> {
189189 coroutineMessageService.routeMessage(message)
@@ -205,10 +205,7 @@ class CoroutineWebSocketConfig(
205205 while (currentCoroutineContext().isActive) {
206206 delay(config.heartbeatIntervalSeconds.toLong().seconds)
207207 log.debug(" ws: [chat] sending heartbeat to session: ${session.id} " )
208- coroutineMessageService.send(
209- Message (" SYSTEM" , session.id, " ping" ),
210- session.id
211- )
208+ coroutineMessageService.send(Message (" SYSTEM" , session.id, " ping" ), session)
212209 }
213210 }
214211}
@@ -221,11 +218,11 @@ class CoroutineWebSocketConfig(
221218@Service
222219class CoroutineMessageService (
223220 private val config : ChatConfigProperties
224- ) : WsMessageService {
221+ ) : WebSocketNotificationSender {
225222 private val log = LogFactory .getLog(javaClass)
226223
227- private val messageChannels = ConcurrentHashMap <String , Channel <Message >>()
228- private val lastPongTimes = ConcurrentHashMap <String , Instant >()
224+ private val messageChannels = ConcurrentHashMap <WebSocketSession , Channel <Message >>()
225+ private val lastPongTimes = ConcurrentHashMap <WebSocketSession , Instant >()
229226
230227 // Lazy heartbeat monitoring
231228 @OptIn(DelicateCoroutinesApi ::class )
@@ -238,28 +235,28 @@ class CoroutineMessageService(
238235 }
239236 }
240237
241- suspend fun registerSession (sessionId : String ) = withContext(Dispatchers .IO + MDCContext ()) {
242- messageChannels[sessionId ] = Channel (capacity = Channel .UNLIMITED )
243- lastPongTimes[sessionId ] = Instant .now()
238+ suspend fun registerSession (session : WebSocketSession ) = withContext(Dispatchers .IO + MDCContext ()) {
239+ messageChannels[session ] = Channel (capacity = Channel .UNLIMITED )
240+ lastPongTimes[session ] = Instant .now()
244241
245242 // Start heartbeat monitoring when first session connects
246243 if (config.staleCheck) {
247244 heartbeatMonitoring
248245 }
249246
250- log.info(" ws: [chat] registered session: $sessionId " )
247+ log.info(" ws: [chat] registered session: ${session.id} " )
251248 }
252249
253- fun unregisterSession (sessionId : String ) {
254- messageChannels.remove(sessionId )?.close()
255- lastPongTimes.remove(sessionId )
256- log.info(" ws: [chat] unregistered session: $sessionId " )
250+ fun unregisterSession (session : WebSocketSession ) {
251+ messageChannels.remove(session )?.close()
252+ lastPongTimes.remove(session )
253+ log.info(" ws: [chat] unregistered session: ${session.id} " )
257254 }
258255
259- fun send (message : Message , sessionId : String ): Message {
260- messageChannels[sessionId ]?.trySend(message)?.let { result ->
256+ fun send (message : Message , session : WebSocketSession ): Message {
257+ messageChannels[session ]?.trySend(message)?.let { result ->
261258 if (result.isFailure) {
262- log.warn(" ws: [chat] failed to send message to session: $sessionId " )
259+ log.warn(" ws: [chat] failed to send message to session: ${session.id} " )
263260 }
264261 }
265262 return message
@@ -276,39 +273,39 @@ class CoroutineMessageService(
276273 when (message.to) {
277274 " ALL" -> broadcast(message)
278275 else -> {
279- // Try to send to specific session first, then to user
280- if (messageChannels.containsKey(message.to)) {
281- send(message, message.to)
282- } else {
283- log.error(" ws: [chat] failed to send message: $message to session: ${message.to} - unknown session" )
284- }
276+ val session = getSession(message.to)
277+ ? : throw RuntimeException (" ws: [chat] unknown session: ${message.to} " )
278+ send(message, session)
285279 }
286280 }
287281 }
282+ fun getSession (sessionId : String ): WebSocketSession ? {
283+ return messageChannels.keys.firstOrNull { it.id == sessionId }
284+ }
288285
289- fun getMessageFlow (sessionId : String ): Flow <Message > {
290- return messageChannels[sessionId ]?.receiveAsFlow() ? : emptyFlow()
286+ fun getMessageFlow (session : WebSocketSession ): Flow <Message > {
287+ return messageChannels[session ]?.receiveAsFlow() ? : emptyFlow()
291288 }
292289
293- fun updateLastPong (sessionId : String ) {
294- lastPongTimes[sessionId ] = Instant .now()
290+ fun updateLastPong (session : WebSocketSession ) {
291+ lastPongTimes[session ] = Instant .now()
295292 }
296293
297294 private fun checkStaleConnections () {
298295 val staleThreshold = Instant .now().minusSeconds(config.staleCheckThresholdSeconds.toLong())
299- val staleSessions = mutableListOf<String >()
296+ val staleSessions = mutableListOf<WebSocketSession >()
300297
301298 // First, identify stale sessions without modifying the collection
302- lastPongTimes.forEach { (sessionId , lastPong) ->
299+ lastPongTimes.forEach { (session , lastPong) ->
303300 if (lastPong.isBefore(staleThreshold)) {
304- staleSessions.add(sessionId )
301+ staleSessions.add(session )
305302 }
306303 }
307304
308305 // Then, clean up stale sessions properly in coroutine context
309- staleSessions.forEach { sessionId ->
310- log.warn(" ws: [chat] removing stale connection: $sessionId (last pong: ${lastPongTimes[sessionId ]} )" )
311- unregisterSession(sessionId )
306+ staleSessions.forEach { session ->
307+ log.warn(" ws: [chat] removing stale connection: ${session.id} (last pong: ${lastPongTimes[session ]} )" )
308+ unregisterSession(session )
312309 }
313310 }
314311}
0 commit comments