diff --git a/Sources/FoundationNetworking/URLSession/WebSocket/WebSocketURLProtocol.swift b/Sources/FoundationNetworking/URLSession/WebSocket/WebSocketURLProtocol.swift index 8216f23d58..e35612ddae 100644 --- a/Sources/FoundationNetworking/URLSession/WebSocket/WebSocketURLProtocol.swift +++ b/Sources/FoundationNetworking/URLSession/WebSocket/WebSocketURLProtocol.swift @@ -17,6 +17,9 @@ import Foundation import Dispatch internal class _WebSocketURLProtocol: _HTTPURLProtocol { + + private var messageData = Data() + public required init(task: URLSessionTask, cachedResponse: CachedURLResponse?, client: URLProtocolClient?) { super.init(task: task, cachedResponse: nil, client: client) } @@ -118,14 +121,14 @@ internal class _WebSocketURLProtocol: _HTTPURLProtocol { lastRedirectBody = redirectBody } - let flags = easyHandle.getWebSocketFlags() + let (offset, bytesLeft, flags) = easyHandle.getWebSocketMeta() - notifyTask(aboutReceivedData: data, flags: flags) + notifyTask(aboutReceivedData: data, offset: offset, bytesLeft: bytesLeft, flags: flags) internalState = .transferInProgress(ts) return .proceed } - fileprivate func notifyTask(aboutReceivedData data: Data, flags: _EasyHandle.WebSocketFlags) { + fileprivate func notifyTask(aboutReceivedData data: Data, offset: Int64, bytesLeft: Int64, flags: _EasyHandle.WebSocketFlags) { guard let t = self.task else { fatalError("Cannot notify") } @@ -159,10 +162,21 @@ internal class _WebSocketURLProtocol: _HTTPURLProtocol { } else if flags.contains(.pong) { task.noteReceivedPong() } else if flags.contains(.binary) { - let message = URLSessionWebSocketTask.Message.data(data) + if bytesLeft > 0 || flags.contains(.cont) { + messageData.append(data) + return + } + messageData.append(data) + let message = URLSessionWebSocketTask.Message.data(messageData) task.appendReceivedMessage(message) + messageData = Data() // Reset for the next message } else if flags.contains(.text) { - guard let utf8 = String(data: data, encoding: .utf8) else { + if bytesLeft > 0 || flags.contains(.cont) { + messageData.append(data) + return + } + messageData.append(data) + guard let utf8 = String(data: messageData, encoding: .utf8) else { NSLog("Invalid utf8 message received from server \(data)") let error = NSError(domain: NSURLErrorDomain, code: NSURLErrorBadServerResponse, userInfo: [ @@ -175,6 +189,7 @@ internal class _WebSocketURLProtocol: _HTTPURLProtocol { } let message = URLSessionWebSocketTask.Message.string(utf8) task.appendReceivedMessage(message) + messageData = Data() // Reset for the next message } else { NSLog("Unexpected message received from server \(data) \(flags)") let error = NSError(domain: NSURLErrorDomain, code: NSURLErrorBadServerResponse, diff --git a/Sources/FoundationNetworking/URLSession/libcurl/EasyHandle.swift b/Sources/FoundationNetworking/URLSession/libcurl/EasyHandle.swift index e9b90a23f1..a891c9818e 100644 --- a/Sources/FoundationNetworking/URLSession/libcurl/EasyHandle.swift +++ b/Sources/FoundationNetworking/URLSession/libcurl/EasyHandle.swift @@ -375,10 +375,10 @@ extension _EasyHandle { } // Only valid to call within a didReceive(data:size:nmemb:) call - func getWebSocketFlags() -> WebSocketFlags { + func getWebSocketMeta() -> (Int64, Int64, WebSocketFlags) { let metadataPointer = CFURLSessionEasyHandleWebSocketsMetadata(rawHandle) let flags = WebSocketFlags(rawValue: metadataPointer.pointee.flags) - return flags + return (metadataPointer.pointee.offset, metadataPointer.pointee.bytesLeft, flags) } func receiveWebSocketsData() throws -> (Data, WebSocketFlags) { diff --git a/Tests/Foundation/HTTPServer.swift b/Tests/Foundation/HTTPServer.swift index a6457f285f..9922d32897 100644 --- a/Tests/Foundation/HTTPServer.swift +++ b/Tests/Foundation/HTTPServer.swift @@ -914,6 +914,8 @@ public class TestURLSessionServer: CustomStringConvertible { "Connection: Upgrade"] let expectFullRequestResponseTests: Bool + let bufferedSendingTests: Bool + let fragmentedTests: Bool let sendClosePacket: Bool let completeUpgrade: Bool @@ -921,14 +923,32 @@ public class TestURLSessionServer: CustomStringConvertible { switch uri { case "/web-socket": expectFullRequestResponseTests = true + bufferedSendingTests = false + fragmentedTests = false + completeUpgrade = true + sendClosePacket = true + case "/web-socket/buffered-sending": + expectFullRequestResponseTests = true + bufferedSendingTests = true + fragmentedTests = false + completeUpgrade = true + sendClosePacket = true + case "/web-socket/fragmented": + expectFullRequestResponseTests = true + bufferedSendingTests = false + fragmentedTests = true completeUpgrade = true sendClosePacket = true case "/web-socket/semi-abrupt-close": expectFullRequestResponseTests = false + bufferedSendingTests = false + fragmentedTests = false completeUpgrade = true sendClosePacket = false case "/web-socket/abrupt-close": expectFullRequestResponseTests = false + bufferedSendingTests = false + fragmentedTests = false completeUpgrade = false sendClosePacket = false default: @@ -944,6 +964,8 @@ public class TestURLSessionServer: CustomStringConvertible { } responseHeaders.append("Sec-WebSocket-Protocol: \(expectedProtocol)") expectFullRequestResponseTests = false + bufferedSendingTests = false + fragmentedTests = false completeUpgrade = true sendClosePacket = true } @@ -978,10 +1000,41 @@ public class TestURLSessionServer: CustomStringConvertible { NSLog("Invalid string frame") throw InternalServerError.badBody } - - // Send a string message - let sendStringFrame = Data([0x81, UInt8(stringPayload.count)]) + stringPayload - try httpServer.tcpSocket.writeRawData(sendStringFrame) + + if bufferedSendingTests { + // Send a string message in chunks of 2 bytes + let sendStringFrame = Data([0x81, UInt8(stringPayload.count)]) + stringPayload + let bufferSize = 2 // Let's assume the server has a buffer size of 2 bytes + for i in stride(from: 0, to: sendStringFrame.count, by: bufferSize) { + let end = min(i + bufferSize, sendStringFrame.count) + let chunk = sendStringFrame.subdata(in: i.. Void { + let url = try XCTUnwrap(URL(string: urlString)) + let request = URLRequest(url: url) + + let delegate = SessionDelegate(with: expectation(description: "\(urlString): Connect")) + let task = delegate.runWebSocketTask(with: request, timeoutInterval: 4) + + // We interleave sending and receiving, as the test HTTPServer implementation is barebones, and can't handle receiving more than one frame at a time. So, this back-and-forth acts as a gating mechanism + try await task.send(.string("Hello")) + + let stringMessage = try await task.receive() + switch stringMessage { + case .string(let str): + XCTAssert(str == "Hello") + default: + XCTFail("Unexpected String Message") + } + + try await task.send(.data(Data([0x20, 0x22, 0x10, 0x03]))) + + let dataMessage = try await task.receive() + switch dataMessage { + case .data(let data): + XCTAssert(data == Data([0x20, 0x22, 0x10, 0x03])) + default: + XCTFail("Unexpected Data Message") + } + + do { + try await task.sendPing() + // Server hasn't closed the connection yet + } catch { + // Server closed the connection before we could process the pong + let urlError = try XCTUnwrap(error as? URLError) + XCTAssertEqual(urlError._nsError.code, NSURLErrorNetworkConnectionLost) + } + + await fulfillment(of: [delegate.expectation], timeout: 50) + + do { + _ = try await task.receive() + XCTFail("Expected to throw when receiving on closed task") + } catch { + let urlError = try XCTUnwrap(error as? URLError) + XCTAssertEqual(urlError._nsError.code, NSURLErrorNetworkConnectionLost) + } + + let callbacks = [ "urlSession(_:webSocketTask:didOpenWithProtocol:)", + "urlSession(_:webSocketTask:didCloseWith:reason:)", + "urlSession(_:task:didCompleteWithError:)" ] + XCTAssertEqual(delegate.callbacks.count, callbacks.count) + XCTAssertEqual(delegate.callbacks, callbacks, "Callbacks for \(#function)") } - - let callbacks = [ "urlSession(_:webSocketTask:didOpenWithProtocol:)", - "urlSession(_:webSocketTask:didCloseWith:reason:)", - "urlSession(_:task:didCompleteWithError:)" ] - XCTAssertEqual(delegate.callbacks.count, callbacks.count) - XCTAssertEqual(delegate.callbacks, callbacks, "Callbacks for \(#function)") + + try await testWebSocket(withURL: "ws://127.0.0.1:\(TestURLSession.serverPort)/web-socket") + try await testWebSocket(withURL: "ws://127.0.0.1:\(TestURLSession.serverPort)/web-socket/buffered-sending") + try await testWebSocket(withURL: "ws://127.0.0.1:\(TestURLSession.serverPort)/web-socket/fragmented") } func test_webSocketShared() async throws {