@@ -47,6 +47,9 @@ const (
4747 defaultReadTimeout = 60 * time .Second
4848 defaultWriteTimeout = 10 * time .Second
4949 defaultMaxMessageSize = 1024 * 1024 // 1MB
50+
51+ // Error message format for wrapped errors
52+ errWrapFormat = "%w: %v"
5053)
5154
5255// WebSocketClient provides bidirectional DIDComm messaging over WebSocket.
@@ -130,7 +133,7 @@ func (c *WebSocketClient) Connect(ctx context.Context, endpoint string) error {
130133
131134 conn , resp , err := dialer .DialContext (ctx , endpoint , nil )
132135 if err != nil {
133- return fmt .Errorf ("%w: %v" , ErrConnectionFailed , err )
136+ return fmt .Errorf (errWrapFormat , ErrConnectionFailed , err )
134137 }
135138
136139 // Verify subprotocol was accepted
@@ -171,7 +174,7 @@ func (c *WebSocketClient) Send(ctx context.Context, message []byte, mediaType st
171174 // Some implementations wrap in an envelope; we send raw for simplicity.
172175 _ = conn .SetWriteDeadline (time .Now ().Add (c .writeTimeout ))
173176 if err := conn .WriteMessage (websocket .BinaryMessage , message ); err != nil {
174- return fmt .Errorf ("%w: %v" , ErrSendFailed , err )
177+ return fmt .Errorf (errWrapFormat , ErrSendFailed , err )
175178 }
176179
177180 return nil
@@ -199,7 +202,7 @@ func (c *WebSocketClient) SendWithEnvelope(ctx context.Context, message []byte,
199202
200203 _ = conn .SetWriteDeadline (time .Now ().Add (c .writeTimeout ))
201204 if err := conn .WriteJSON (envelope ); err != nil {
202- return fmt .Errorf ("%w: %v" , ErrSendFailed , err )
205+ return fmt .Errorf (errWrapFormat , ErrSendFailed , err )
203206 }
204207
205208 return nil
@@ -223,7 +226,7 @@ func (c *WebSocketClient) Receive(ctx context.Context) ([]byte, string, error) {
223226 if websocket .IsCloseError (err , websocket .CloseNormalClosure , websocket .CloseGoingAway ) {
224227 return nil , "" , ErrConnectionClosed
225228 }
226- return nil , "" , fmt .Errorf ("%w: %v" , ErrReceiveFailed , err )
229+ return nil , "" , fmt .Errorf (errWrapFormat , ErrReceiveFailed , err )
227230 }
228231
229232 // Determine media type based on message content
@@ -277,49 +280,63 @@ func (c *WebSocketClient) Close() error {
277280// readLoop reads incoming messages and processes them.
278281func (c * WebSocketClient ) readLoop () {
279282 for {
280- select {
281- case <- c .done :
283+ if c .shouldStopReading () {
282284 return
283- default :
284285 }
285286
286- c .mu .RLock ()
287- conn := c .conn
288- c .mu .RUnlock ()
289-
287+ conn := c .getConn ()
290288 if conn == nil {
291289 return
292290 }
293291
294- _ = conn .SetReadDeadline (time .Now ().Add (c .readTimeout ))
295- msgType , data , err := conn .ReadMessage ()
292+ msgType , data , err := c .readNextMessage (conn )
296293 if err != nil {
297- if ! websocket .IsCloseError (err , websocket .CloseNormalClosure , websocket .CloseGoingAway ) {
298- // Unexpected close
299- }
300294 c .Close ()
301295 return
302296 }
303297
304- // Determine media type
305- mediaType := c .detectMediaType (msgType , data )
298+ c .processIncomingMessage (msgType , data )
299+ }
300+ }
306301
307- // Process message
308- if c .processor != nil {
309- ctx := context .Background ()
310- response , responseMediaType , err := c .processor .ProcessMessage (ctx , data , mediaType )
311- if err != nil {
312- // Log error but continue processing
313- continue
314- }
302+ // shouldStopReading checks if the read loop should terminate.
303+ func (c * WebSocketClient ) shouldStopReading () bool {
304+ select {
305+ case <- c .done :
306+ return true
307+ default :
308+ return false
309+ }
310+ }
315311
316- // Send response if provided
317- if response != nil {
318- if sendErr := c .Send (ctx , response , responseMediaType ); sendErr != nil {
319- // Log error
320- }
321- }
322- }
312+ // getConn returns the current connection safely.
313+ func (c * WebSocketClient ) getConn () * websocket.Conn {
314+ c .mu .RLock ()
315+ defer c .mu .RUnlock ()
316+ return c .conn
317+ }
318+
319+ // readNextMessage reads the next message from the connection.
320+ func (c * WebSocketClient ) readNextMessage (conn * websocket.Conn ) (int , []byte , error ) {
321+ _ = conn .SetReadDeadline (time .Now ().Add (c .readTimeout ))
322+ return conn .ReadMessage ()
323+ }
324+
325+ // processIncomingMessage handles an incoming WebSocket message.
326+ func (c * WebSocketClient ) processIncomingMessage (msgType int , data []byte ) {
327+ if c .processor == nil {
328+ return
329+ }
330+
331+ mediaType := c .detectMediaType (msgType , data )
332+ ctx := context .Background ()
333+ response , responseMediaType , err := c .processor .ProcessMessage (ctx , data , mediaType )
334+ if err != nil {
335+ return
336+ }
337+
338+ if response != nil {
339+ _ = c .Send (ctx , response , responseMediaType )
323340 }
324341}
325342
@@ -418,53 +435,57 @@ func (h *WebSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
418435 }
419436 defer conn .Close ()
420437
421- // Verify subprotocol (gorilla sets this automatically if negotiated)
422- if conn .Subprotocol () != DIDCommSubprotocol {
423- // Client didn't request didcomm/v2 subprotocol - could allow plain WebSocket
424- // or reject. For now, we continue but this may indicate a non-DIDComm client.
425- }
438+ h .setupConnection (conn )
439+ h .connectionLoop (conn , r .Context ())
440+ }
426441
442+ // setupConnection configures the WebSocket connection.
443+ func (h * WebSocketHandler ) setupConnection (conn * websocket.Conn ) {
427444 conn .SetReadLimit (h .maxMessageSize )
428-
429- // Set up ping/pong handlers
430445 conn .SetPongHandler (func (string ) error {
431446 _ = conn .SetReadDeadline (time .Now ().Add (defaultReadTimeout ))
432447 return nil
433448 })
449+ }
434450
435- // Connection loop
451+ // connectionLoop handles the message read/write loop.
452+ func (h * WebSocketHandler ) connectionLoop (conn * websocket.Conn , ctx context.Context ) {
436453 for {
437454 _ = conn .SetReadDeadline (time .Now ().Add (defaultReadTimeout ))
438455 msgType , data , err := conn .ReadMessage ()
439456 if err != nil {
440- if websocket .IsUnexpectedCloseError (err , websocket .CloseGoingAway , websocket .CloseAbnormalClosure ) {
441- // Log unexpected close
442- }
443457 return
444458 }
445459
446- // Determine media type
447- mediaType := detectMediaTypeFromMessage (msgType , data )
448-
449- // Process message
450- response , responseMediaType , err := h .processor .ProcessMessage (r .Context (), data , mediaType )
451- if err != nil {
452- // Log but continue
453- continue
460+ if err := h .handleMessage (conn , ctx , msgType , data ); err != nil {
461+ return
454462 }
463+ }
464+ }
455465
456- // Send response if provided
457- if response != nil {
458- outType := websocket .BinaryMessage
459- if responseMediaType == didcomm .MediaTypePlaintext {
460- outType = websocket .TextMessage
461- }
462- _ = conn .SetWriteDeadline (time .Now ().Add (defaultWriteTimeout ))
463- if err := conn .WriteMessage (outType , response ); err != nil {
464- return
465- }
466- }
466+ // handleMessage processes a single message and sends response if needed.
467+ func (h * WebSocketHandler ) handleMessage (conn * websocket.Conn , ctx context.Context , msgType int , data []byte ) error {
468+ mediaType := detectMediaTypeFromMessage (msgType , data )
469+ response , responseMediaType , err := h .processor .ProcessMessage (ctx , data , mediaType )
470+ if err != nil {
471+ return nil // Log error but continue
472+ }
473+
474+ if response == nil {
475+ return nil
476+ }
477+
478+ return h .sendResponse (conn , response , responseMediaType )
479+ }
480+
481+ // sendResponse writes a response message to the connection.
482+ func (h * WebSocketHandler ) sendResponse (conn * websocket.Conn , response []byte , mediaType string ) error {
483+ outType := websocket .BinaryMessage
484+ if mediaType == didcomm .MediaTypePlaintext {
485+ outType = websocket .TextMessage
467486 }
487+ _ = conn .SetWriteDeadline (time .Now ().Add (defaultWriteTimeout ))
488+ return conn .WriteMessage (outType , response )
468489}
469490
470491// detectMediaTypeFromMessage determines the DIDComm media type from message content.
0 commit comments