Skip to content

Commit f2e6a09

Browse files
ws refactor
1 parent 1dc277a commit f2e6a09

File tree

3 files changed

+43
-47
lines changed

3 files changed

+43
-47
lines changed

src/main/kotlin/com/softeno/template/app/kafka/KafkaSampleHandler.kt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@ import com.softeno.template.app.kafka.config.KafkaApplicationProperties
88
import com.softeno.template.app.user.notification.CoroutineUserUpdateEmitter
99
import com.softeno.template.app.user.notification.ReactiveUserUpdateEmitter
1010
import com.softeno.template.sample.websocket.Message
11-
import com.softeno.template.sample.websocket.ReactiveMessageService
12-
import com.softeno.template.sample.websocket.WsMessageService
11+
import com.softeno.template.sample.websocket.WebSocketNotificationSender
1312
import io.micrometer.tracing.Span
1413
import io.micrometer.tracing.Tracer
1514
import kotlinx.coroutines.DelicateCoroutinesApi
@@ -35,7 +34,7 @@ class ReactiveKafkaSampleController(
3534
@Qualifier(value = "kafkaSampleConsumerTemplate") private val reactiveKafkaConsumerTemplate: ReactiveKafkaConsumerTemplate<String, JsonNode>,
3635
private val objectMapper: ObjectMapper,
3736
private val tracer: Tracer,
38-
private val wsMessageService: WsMessageService,
37+
private val ws: WebSocketNotificationSender,
3938
private val reactiveUserUpdateEmitter: ReactiveUserUpdateEmitter,
4039
private val userUpdateEmitter: CoroutineUserUpdateEmitter,
4140
) : CommandLineRunner {
@@ -66,7 +65,7 @@ class ReactiveKafkaSampleController(
6665
val span = tracer.nextSpan().name("kafka-consumer")
6766
tracer.withSpan(span.start()).use {
6867
log.info("[kafka] rx sample: $kafkaMessage")
69-
wsMessageService.broadcast(kafkaMessage.toMessage())
68+
ws.broadcast(kafkaMessage.toMessage())
7069
reactiveUserUpdateEmitter.broadcast(kafkaMessage.toMessage())
7170
userUpdateEmitter.broadcast(kafkaMessage.toMessage())
7271
}

src/main/kotlin/com/softeno/template/sample/websocket/CoroutineWebSocket.kt

Lines changed: 38 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,11 @@ class CoroutineWebSocketConfig(
8383
}
8484

8585
// Register session
86-
coroutineMessageService.registerSession(session.id)
86+
coroutineMessageService.registerSession(session)
8787

8888
// Send initial messages
89-
coroutineMessageService.send(Message("SYSTEM", session.id, "HANDSHAKE"), session.id)
90-
coroutineMessageService.send(Message("SYSTEM", session.id, userId), session.id)
89+
coroutineMessageService.send(Message("SYSTEM", session.id, "HANDSHAKE"), session)
90+
coroutineMessageService.send(Message("SYSTEM", session.id, userId), session)
9191

9292
// Start concurrent coroutines for sending and receiving
9393
coroutineScope {
@@ -136,7 +136,7 @@ class CoroutineWebSocketConfig(
136136

137137
// Cleanup
138138
log.info("ws: [chat] disconnect chat session: ${session.id}")
139-
coroutineMessageService.unregisterSession(session.id)
139+
coroutineMessageService.unregisterSession(session)
140140
}
141141

142142
private suspend fun handleWebSocketSessionError(context: CoroutineContext, throwable: Throwable, session: WebSocketSession) =
@@ -149,15 +149,15 @@ class CoroutineWebSocketConfig(
149149
}
150150

151151
private suspend fun closeSession(session: WebSocketSession) {
152-
coroutineMessageService.unregisterSession(session.id)
152+
coroutineMessageService.unregisterSession(session)
153153
session.close().awaitSingleOrNull()
154154
}
155155

156156
private suspend fun handleOutgoingMessages(
157157
session: WebSocketSession,
158158
objectMapper: ObjectMapper
159159
) = withContext(Dispatchers.IO + MDCContext()) {
160-
coroutineMessageService.getMessageFlow(session.id).collect { message ->
160+
coroutineMessageService.getMessageFlow(session).collect { message ->
161161
val json = message.toJson(objectMapper)
162162
log.info("ws: [chat] tx: $json")
163163
session.send(Mono.just(session.textMessage(json))).awaitSingleOrNull()
@@ -183,7 +183,7 @@ class CoroutineWebSocketConfig(
183183
when {
184184
message.content == "pong" && message.to == "SYSTEM" -> {
185185
log.debug("ws: [chat] received pong from session: ${session.id}")
186-
coroutineMessageService.updateLastPong(session.id)
186+
coroutineMessageService.updateLastPong(session)
187187
}
188188
else -> {
189189
coroutineMessageService.routeMessage(message)
@@ -205,10 +205,7 @@ class CoroutineWebSocketConfig(
205205
while (currentCoroutineContext().isActive) {
206206
delay(config.heartbeatIntervalSeconds.toLong().seconds)
207207
log.debug("ws: [chat] sending heartbeat to session: ${session.id}")
208-
coroutineMessageService.send(
209-
Message("SYSTEM", session.id, "ping"),
210-
session.id
211-
)
208+
coroutineMessageService.send(Message("SYSTEM", session.id, "ping"), session)
212209
}
213210
}
214211
}
@@ -221,11 +218,11 @@ class CoroutineWebSocketConfig(
221218
@Service
222219
class CoroutineMessageService(
223220
private val config: ChatConfigProperties
224-
) : WsMessageService {
221+
) : WebSocketNotificationSender {
225222
private val log = LogFactory.getLog(javaClass)
226223

227-
private val messageChannels = ConcurrentHashMap<String, Channel<Message>>()
228-
private val lastPongTimes = ConcurrentHashMap<String, Instant>()
224+
private val messageChannels = ConcurrentHashMap<WebSocketSession, Channel<Message>>()
225+
private val lastPongTimes = ConcurrentHashMap<WebSocketSession, Instant>()
229226

230227
// Lazy heartbeat monitoring
231228
@OptIn(DelicateCoroutinesApi::class)
@@ -238,28 +235,28 @@ class CoroutineMessageService(
238235
}
239236
}
240237

241-
suspend fun registerSession(sessionId: String) = withContext(Dispatchers.IO + MDCContext()) {
242-
messageChannels[sessionId] = Channel(capacity = Channel.UNLIMITED)
243-
lastPongTimes[sessionId] = Instant.now()
238+
suspend fun registerSession(session: WebSocketSession) = withContext(Dispatchers.IO + MDCContext()) {
239+
messageChannels[session] = Channel(capacity = Channel.UNLIMITED)
240+
lastPongTimes[session] = Instant.now()
244241

245242
// Start heartbeat monitoring when first session connects
246243
if (config.staleCheck) {
247244
heartbeatMonitoring
248245
}
249246

250-
log.info("ws: [chat] registered session: $sessionId")
247+
log.info("ws: [chat] registered session: ${session.id}")
251248
}
252249

253-
fun unregisterSession(sessionId: String) {
254-
messageChannels.remove(sessionId)?.close()
255-
lastPongTimes.remove(sessionId)
256-
log.info("ws: [chat] unregistered session: $sessionId")
250+
fun unregisterSession(session: WebSocketSession) {
251+
messageChannels.remove(session)?.close()
252+
lastPongTimes.remove(session)
253+
log.info("ws: [chat] unregistered session: ${session.id}")
257254
}
258255

259-
fun send(message: Message, sessionId: String): Message {
260-
messageChannels[sessionId]?.trySend(message)?.let { result ->
256+
fun send(message: Message, session: WebSocketSession): Message {
257+
messageChannels[session]?.trySend(message)?.let { result ->
261258
if (result.isFailure) {
262-
log.warn("ws: [chat] failed to send message to session: $sessionId")
259+
log.warn("ws: [chat] failed to send message to session: ${session.id}")
263260
}
264261
}
265262
return message
@@ -276,39 +273,39 @@ class CoroutineMessageService(
276273
when (message.to) {
277274
"ALL" -> broadcast(message)
278275
else -> {
279-
// Try to send to specific session first, then to user
280-
if (messageChannels.containsKey(message.to)) {
281-
send(message, message.to)
282-
} else {
283-
log.error("ws: [chat] failed to send message: $message to session: ${message.to} - unknown session")
284-
}
276+
val session = getSession(message.to)
277+
?: throw RuntimeException("ws: [chat] unknown session: ${message.to}")
278+
send(message, session)
285279
}
286280
}
287281
}
282+
fun getSession(sessionId: String): WebSocketSession? {
283+
return messageChannels.keys.firstOrNull { it.id == sessionId }
284+
}
288285

289-
fun getMessageFlow(sessionId: String): Flow<Message> {
290-
return messageChannels[sessionId]?.receiveAsFlow() ?: emptyFlow()
286+
fun getMessageFlow(session: WebSocketSession): Flow<Message> {
287+
return messageChannels[session]?.receiveAsFlow() ?: emptyFlow()
291288
}
292289

293-
fun updateLastPong(sessionId: String) {
294-
lastPongTimes[sessionId] = Instant.now()
290+
fun updateLastPong(session: WebSocketSession) {
291+
lastPongTimes[session] = Instant.now()
295292
}
296293

297294
private fun checkStaleConnections() {
298295
val staleThreshold = Instant.now().minusSeconds(config.staleCheckThresholdSeconds.toLong())
299-
val staleSessions = mutableListOf<String>()
296+
val staleSessions = mutableListOf<WebSocketSession>()
300297

301298
// First, identify stale sessions without modifying the collection
302-
lastPongTimes.forEach { (sessionId, lastPong) ->
299+
lastPongTimes.forEach { (session, lastPong) ->
303300
if (lastPong.isBefore(staleThreshold)) {
304-
staleSessions.add(sessionId)
301+
staleSessions.add(session)
305302
}
306303
}
307304

308305
// Then, clean up stale sessions properly in coroutine context
309-
staleSessions.forEach { sessionId ->
310-
log.warn("ws: [chat] removing stale connection: $sessionId (last pong: ${lastPongTimes[sessionId]})")
311-
unregisterSession(sessionId)
306+
staleSessions.forEach { session ->
307+
log.warn("ws: [chat] removing stale connection: ${session.id} (last pong: ${lastPongTimes[session]})")
308+
unregisterSession(session)
312309
}
313310
}
314311
}

src/main/kotlin/com/softeno/template/sample/websocket/ReactiveWebSocket.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ class WebSocketConfig(
117117
}
118118
}
119119

120-
interface WsMessageService {
120+
interface WebSocketNotificationSender {
121121
fun broadcast(message: Message): Message
122122
}
123123

@@ -130,7 +130,7 @@ interface WsMessageService {
130130
class ReactiveMessageService(
131131
private val objectMapper: ObjectMapper,
132132
private val config: ChatConfigProperties
133-
) : WsMessageService {
133+
) : WebSocketNotificationSender {
134134
private val log = LogFactory.getLog(javaClass)
135135

136136
private val sinks: MutableMap<WebSocketSession, Many<String>> = ConcurrentHashMap()

0 commit comments

Comments
 (0)