Skip to content

Fix WebSocket buffered read and add support for fragmented messages #5255

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ import Foundation
import Dispatch

internal class _WebSocketURLProtocol: _HTTPURLProtocol {

private var messageData = Data()

public required init(task: URLSessionTask, cachedResponse: CachedURLResponse?, client: URLProtocolClient?) {
super.init(task: task, cachedResponse: nil, client: client)
}
Expand Down Expand Up @@ -118,14 +121,14 @@ internal class _WebSocketURLProtocol: _HTTPURLProtocol {
lastRedirectBody = redirectBody
}

let flags = easyHandle.getWebSocketFlags()
let (offset, bytesLeft, flags) = easyHandle.getWebSocketMeta()

notifyTask(aboutReceivedData: data, flags: flags)
notifyTask(aboutReceivedData: data, offset: offset, bytesLeft: bytesLeft, flags: flags)
internalState = .transferInProgress(ts)
return .proceed
}

fileprivate func notifyTask(aboutReceivedData data: Data, flags: _EasyHandle.WebSocketFlags) {
fileprivate func notifyTask(aboutReceivedData data: Data, offset: Int64, bytesLeft: Int64, flags: _EasyHandle.WebSocketFlags) {
guard let t = self.task else {
fatalError("Cannot notify")
}
Expand Down Expand Up @@ -159,10 +162,21 @@ internal class _WebSocketURLProtocol: _HTTPURLProtocol {
} else if flags.contains(.pong) {
task.noteReceivedPong()
} else if flags.contains(.binary) {
let message = URLSessionWebSocketTask.Message.data(data)
if bytesLeft > 0 || flags.contains(.cont) {
messageData.append(data)
return
}
messageData.append(data)
let message = URLSessionWebSocketTask.Message.data(messageData)
task.appendReceivedMessage(message)
messageData = Data() // Reset for the next message
} else if flags.contains(.text) {
guard let utf8 = String(data: data, encoding: .utf8) else {
if bytesLeft > 0 || flags.contains(.cont) {
messageData.append(data)
return
}
messageData.append(data)
guard let utf8 = String(data: messageData, encoding: .utf8) else {
NSLog("Invalid utf8 message received from server \(data)")
let error = NSError(domain: NSURLErrorDomain, code: NSURLErrorBadServerResponse,
userInfo: [
Expand All @@ -175,6 +189,7 @@ internal class _WebSocketURLProtocol: _HTTPURLProtocol {
}
let message = URLSessionWebSocketTask.Message.string(utf8)
task.appendReceivedMessage(message)
messageData = Data() // Reset for the next message
} else {
NSLog("Unexpected message received from server \(data) \(flags)")
let error = NSError(domain: NSURLErrorDomain, code: NSURLErrorBadServerResponse,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -375,10 +375,10 @@ extension _EasyHandle {
}

// Only valid to call within a didReceive(data:size:nmemb:) call
func getWebSocketFlags() -> WebSocketFlags {
func getWebSocketMeta() -> (Int64, Int64, WebSocketFlags) {
let metadataPointer = CFURLSessionEasyHandleWebSocketsMetadata(rawHandle)
let flags = WebSocketFlags(rawValue: metadataPointer.pointee.flags)
return flags
return (metadataPointer.pointee.offset, metadataPointer.pointee.bytesLeft, flags)
}

func receiveWebSocketsData() throws -> (Data, WebSocketFlags) {
Expand Down
99 changes: 91 additions & 8 deletions Tests/Foundation/HTTPServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -914,21 +914,41 @@ public class TestURLSessionServer: CustomStringConvertible {
"Connection: Upgrade"]

let expectFullRequestResponseTests: Bool
let bufferedSendingTests: Bool
let fragmentedTests: Bool
let sendClosePacket: Bool
let completeUpgrade: Bool

let uri = request.uri
switch uri {
case "/web-socket":
expectFullRequestResponseTests = true
bufferedSendingTests = false
fragmentedTests = false
completeUpgrade = true
sendClosePacket = true
case "/web-socket/buffered-sending":
expectFullRequestResponseTests = true
bufferedSendingTests = true
fragmentedTests = false
completeUpgrade = true
sendClosePacket = true
case "/web-socket/fragmented":
expectFullRequestResponseTests = true
bufferedSendingTests = false
fragmentedTests = true
completeUpgrade = true
sendClosePacket = true
case "/web-socket/semi-abrupt-close":
expectFullRequestResponseTests = false
bufferedSendingTests = false
fragmentedTests = false
completeUpgrade = true
sendClosePacket = false
case "/web-socket/abrupt-close":
expectFullRequestResponseTests = false
bufferedSendingTests = false
fragmentedTests = false
completeUpgrade = false
sendClosePacket = false
default:
Expand All @@ -944,6 +964,8 @@ public class TestURLSessionServer: CustomStringConvertible {
}
responseHeaders.append("Sec-WebSocket-Protocol: \(expectedProtocol)")
expectFullRequestResponseTests = false
bufferedSendingTests = false
fragmentedTests = false
completeUpgrade = true
sendClosePacket = true
}
Expand Down Expand Up @@ -978,10 +1000,41 @@ public class TestURLSessionServer: CustomStringConvertible {
NSLog("Invalid string frame")
throw InternalServerError.badBody
}

// Send a string message
let sendStringFrame = Data([0x81, UInt8(stringPayload.count)]) + stringPayload
try httpServer.tcpSocket.writeRawData(sendStringFrame)

if bufferedSendingTests {
// Send a string message in chunks of 2 bytes
let sendStringFrame = Data([0x81, UInt8(stringPayload.count)]) + stringPayload
let bufferSize = 2 // Let's assume the server has a buffer size of 2 bytes
for i in stride(from: 0, to: sendStringFrame.count, by: bufferSize) {
let end = min(i + bufferSize, sendStringFrame.count)
let chunk = sendStringFrame.subdata(in: i..<end)
try httpServer.tcpSocket.writeRawData(chunk)
Thread.sleep(forTimeInterval: 0.1) // Sleep to simulate buffered sending
}
}
else if fragmentedTests {
// Send a string message fragmented by 1 byte
for (i, byte) in stringPayload.enumerated() {
var frame = Data()
let isFirst = i == 0
let isLast = i == stringPayload.count - 1

let finBit: UInt8 = isLast ? 0x80 : 0x00
let opcode: UInt8 = isFirst ? 0x1 : 0x0 // 0x1 = text, 0x0 = continuation
let header: UInt8 = finBit | opcode

frame.append(header)
frame.append(0x01) // payload length 1, unmasked
frame.append(byte)

try httpServer.tcpSocket.writeRawData(frame)
}
}
else {
// Send a string message
let sendStringFrame = Data([0x81, UInt8(stringPayload.count)]) + stringPayload
try httpServer.tcpSocket.writeRawData(sendStringFrame)
}

// Receive a data message
guard let dataFrame = try httpServer.tcpSocket.readData(),
Expand All @@ -991,10 +1044,40 @@ public class TestURLSessionServer: CustomStringConvertible {
NSLog("Invalid data frame")
throw InternalServerError.badBody
}

// Send a data message
let sendDataFrame = Data([0x82, UInt8(dataPayload.count)]) + dataPayload
try httpServer.tcpSocket.writeRawData(sendDataFrame)

if bufferedSendingTests {
let sendDataFrame = Data([0x82, UInt8(dataPayload.count)]) + dataPayload
let bufferSize = 2 // Let's assume the server has a buffer size of 2 bytes
for i in stride(from: 0, to: sendDataFrame.count, by: bufferSize) {
let end = min(i + bufferSize, sendDataFrame.count)
let chunk = sendDataFrame.subdata(in: i..<end)
try httpServer.tcpSocket.writeRawData(chunk)
Thread.sleep(forTimeInterval: 0.1) // Sleep to simulate buffered sending
}
}
else if fragmentedTests {
// Send a data message fragmented by 1 byte
for (i, byte) in dataPayload.enumerated() {
var frame = Data()
let isFirst = i == 0
let isLast = i == dataPayload.count - 1

let finBit: UInt8 = isLast ? 0x80 : 0x00
let opcode: UInt8 = isFirst ? 0x2 : 0x0 // 0x2 = text, 0x0 = continuation
let header: UInt8 = finBit | opcode

frame.append(header)
frame.append(0x01) // payload length 1, unmasked
frame.append(byte)

try httpServer.tcpSocket.writeRawData(frame)
}
}
else {
// Send a data message
let sendDataFrame = Data([0x82, UInt8(dataPayload.count)]) + dataPayload
try httpServer.tcpSocket.writeRawData(sendDataFrame)
}

// Receive a ping
guard let pingFrame = try httpServer.tcpSocket.readData(),
Expand Down
107 changes: 56 additions & 51 deletions Tests/Foundation/TestURLSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2101,59 +2101,64 @@ final class TestURLSession: LoopbackServerTest, @unchecked Sendable {
print("libcurl lacks WebSockets support, skipping \(#function)")
return
}

let urlString = "ws://127.0.0.1:\(TestURLSession.serverPort)/web-socket"
let url = try XCTUnwrap(URL(string: urlString))
let request = URLRequest(url: url)

let delegate = SessionDelegate(with: expectation(description: "\(urlString): Connect"))
let task = delegate.runWebSocketTask(with: request, timeoutInterval: 4)

// We interleave sending and receiving, as the test HTTPServer implementation is barebones, and can't handle receiving more than one frame at a time. So, this back-and-forth acts as a gating mechanism
try await task.send(.string("Hello"))

let stringMessage = try await task.receive()
switch stringMessage {
case .string(let str):
XCTAssert(str == "Hello")
default:
XCTFail("Unexpected String Message")
}

try await task.send(.data(Data([0x20, 0x22, 0x10, 0x03])))

let dataMessage = try await task.receive()
switch dataMessage {
case .data(let data):
XCTAssert(data == Data([0x20, 0x22, 0x10, 0x03]))
default:
XCTFail("Unexpected Data Message")
}

do {
try await task.sendPing()
// Server hasn't closed the connection yet
} catch {
// Server closed the connection before we could process the pong
let urlError = try XCTUnwrap(error as? URLError)
XCTAssertEqual(urlError._nsError.code, NSURLErrorNetworkConnectionLost)
}

await fulfillment(of: [delegate.expectation], timeout: 50)

do {
_ = try await task.receive()
XCTFail("Expected to throw when receiving on closed task")
} catch {
let urlError = try XCTUnwrap(error as? URLError)
XCTAssertEqual(urlError._nsError.code, NSURLErrorNetworkConnectionLost)
func testWebSocket(withURL urlString: String) async throws -> Void {
let url = try XCTUnwrap(URL(string: urlString))
let request = URLRequest(url: url)

let delegate = SessionDelegate(with: expectation(description: "\(urlString): Connect"))
let task = delegate.runWebSocketTask(with: request, timeoutInterval: 4)

// We interleave sending and receiving, as the test HTTPServer implementation is barebones, and can't handle receiving more than one frame at a time. So, this back-and-forth acts as a gating mechanism
try await task.send(.string("Hello"))

let stringMessage = try await task.receive()
switch stringMessage {
case .string(let str):
XCTAssert(str == "Hello")
default:
XCTFail("Unexpected String Message")
}

try await task.send(.data(Data([0x20, 0x22, 0x10, 0x03])))

let dataMessage = try await task.receive()
switch dataMessage {
case .data(let data):
XCTAssert(data == Data([0x20, 0x22, 0x10, 0x03]))
default:
XCTFail("Unexpected Data Message")
}

do {
try await task.sendPing()
// Server hasn't closed the connection yet
} catch {
// Server closed the connection before we could process the pong
let urlError = try XCTUnwrap(error as? URLError)
XCTAssertEqual(urlError._nsError.code, NSURLErrorNetworkConnectionLost)
}

await fulfillment(of: [delegate.expectation], timeout: 50)

do {
_ = try await task.receive()
XCTFail("Expected to throw when receiving on closed task")
} catch {
let urlError = try XCTUnwrap(error as? URLError)
XCTAssertEqual(urlError._nsError.code, NSURLErrorNetworkConnectionLost)
}

let callbacks = [ "urlSession(_:webSocketTask:didOpenWithProtocol:)",
"urlSession(_:webSocketTask:didCloseWith:reason:)",
"urlSession(_:task:didCompleteWithError:)" ]
XCTAssertEqual(delegate.callbacks.count, callbacks.count)
XCTAssertEqual(delegate.callbacks, callbacks, "Callbacks for \(#function)")
}

let callbacks = [ "urlSession(_:webSocketTask:didOpenWithProtocol:)",
"urlSession(_:webSocketTask:didCloseWith:reason:)",
"urlSession(_:task:didCompleteWithError:)" ]
XCTAssertEqual(delegate.callbacks.count, callbacks.count)
XCTAssertEqual(delegate.callbacks, callbacks, "Callbacks for \(#function)")

try await testWebSocket(withURL: "ws://127.0.0.1:\(TestURLSession.serverPort)/web-socket")
try await testWebSocket(withURL: "ws://127.0.0.1:\(TestURLSession.serverPort)/web-socket/buffered-sending")
try await testWebSocket(withURL: "ws://127.0.0.1:\(TestURLSession.serverPort)/web-socket/fragmented")
}

func test_webSocketShared() async throws {
Expand Down