Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions Sources/AsyncHTTPClient/HTTPClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
case uncleanShutdown
case traceRequestWithBody
case invalidHeaderFieldNames([String])
case bodyLengthMismatch
}

private var code: Code
Expand Down Expand Up @@ -969,10 +970,12 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
public static let redirectLimitReached = HTTPClientError(code: .redirectLimitReached)
/// Redirect Cycle detected.
public static let redirectCycleDetected = HTTPClientError(code: .redirectCycleDetected)
/// Unclean shutdown
/// Unclean shutdown.
public static let uncleanShutdown = HTTPClientError(code: .uncleanShutdown)
/// A body was sent in a request with method TRACE
/// A body was sent in a request with method TRACE.
public static let traceRequestWithBody = HTTPClientError(code: .traceRequestWithBody)
/// Header field names contain invalid characters
/// Header field names contain invalid characters.
public static func invalidHeaderFieldNames(_ names: [String]) -> HTTPClientError { return HTTPClientError(code: .invalidHeaderFieldNames(names)) }
/// Body length is not equal to `Content-Length`.
public static let bodyLengthMismatch = HTTPClientError(code: .bodyLengthMismatch)
}
44 changes: 30 additions & 14 deletions Sources/AsyncHTTPClient/HTTPHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ internal class TaskHandler<Delegate: HTTPClientResponseDelegate>: RemovableChann
case head
case redirected(HTTPResponseHead, URL)
case body
case end
case endOrError
}

let task: HTTPClient.Task<Delegate.Response>
Expand All @@ -651,6 +651,8 @@ internal class TaskHandler<Delegate: HTTPClientResponseDelegate>: RemovableChann
let logger: Logger // We are okay to store the logger here because a TaskHandler is just for one request.

var state: State = .idle
var expectedBodyLength: Int?
var actualBodyLength: Int = 0
var pendingRead = false
var mayRead = true
var closing = false {
Expand Down Expand Up @@ -780,7 +782,7 @@ extension TaskHandler: ChannelDuplexHandler {
} catch {
promise?.fail(error)
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
self.state = .end
self.state = .endOrError
return
}

Expand All @@ -794,12 +796,24 @@ extension TaskHandler: ChannelDuplexHandler {
assert(head.version == HTTPVersion(major: 1, minor: 1),
"Sending a request in HTTP version \(head.version) which is unsupported by the above `if`")


let contentLengths = head.headers[canonicalForm: "content-length"]
assert(contentLengths.count <= 1)

self.expectedBodyLength = contentLengths.first.flatMap { Int($0) }

context.write(wrapOutboundOut(.head(head))).map {
self.callOutToDelegateFireAndForget(value: head, self.delegate.didSendRequestHead)
}.flatMap {
self.writeBody(request: request, context: context)
}.flatMap {
context.eventLoop.assertInEventLoop()
if let expectedBodyLength = self.expectedBodyLength, expectedBodyLength != self.actualBodyLength {
self.state = .endOrError
let error = HTTPClientError.bodyLengthMismatch
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
return context.eventLoop.makeFailedFuture(error)
}
return context.writeAndFlush(self.wrapOutboundOut(.end(nil)))
}.map {
context.eventLoop.assertInEventLoop()
Expand All @@ -808,10 +822,10 @@ extension TaskHandler: ChannelDuplexHandler {
}.flatMapErrorThrowing { error in
context.eventLoop.assertInEventLoop()
switch self.state {
case .end:
case .endOrError:
break
default:
self.state = .end
self.state = .endOrError
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
}
throw error
Expand All @@ -829,9 +843,11 @@ extension TaskHandler: ChannelDuplexHandler {
// All writes have to be switched to the channel EL if channel and task ELs differ
if context.eventLoop.inEventLoop {
context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: promise)
self.actualBodyLength += part.readableBytes
} else {
context.eventLoop.execute {
context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: promise)
self.actualBodyLength += part.readableBytes
}
}

Expand Down Expand Up @@ -893,12 +909,12 @@ extension TaskHandler: ChannelDuplexHandler {
case .end:
switch self.state {
case .redirected(let head, let redirectURL):
self.state = .end
self.state = .endOrError
self.task.releaseAssociatedConnection(delegateType: Delegate.self, closing: self.closing).whenSuccess {
self.redirectHandler?.redirect(status: head.status, to: redirectURL, promise: self.task.promise)
}
default:
self.state = .end
self.state = .endOrError
self.callOutToDelegate(promise: self.task.promise, self.delegate.didFinishRequest)
}
}
Expand All @@ -913,14 +929,14 @@ extension TaskHandler: ChannelDuplexHandler {
context.read()
}
case .failure(let error):
self.state = .end
self.state = .endOrError
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
}
}

func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
if (event as? IdleStateHandler.IdleStateEvent) == .read {
self.state = .end
self.state = .endOrError
let error = HTTPClientError.readTimeout
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
} else {
Expand All @@ -930,7 +946,7 @@ extension TaskHandler: ChannelDuplexHandler {

func triggerUserOutboundEvent(context: ChannelHandlerContext, event: Any, promise: EventLoopPromise<Void>?) {
if (event as? TaskCancelEvent) != nil {
self.state = .end
self.state = .endOrError
let error = HTTPClientError.cancelled
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
promise?.succeed(())
Expand All @@ -941,10 +957,10 @@ extension TaskHandler: ChannelDuplexHandler {

func channelInactive(context: ChannelHandlerContext) {
switch self.state {
case .end:
case .endOrError:
break
case .body, .head, .idle, .redirected, .sent:
self.state = .end
self.state = .endOrError
let error = HTTPClientError.remoteConnectionClosed
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
}
Expand All @@ -955,7 +971,7 @@ extension TaskHandler: ChannelDuplexHandler {
switch error {
case NIOSSLError.uncleanShutdown:
switch self.state {
case .end:
case .endOrError:
/// Some HTTP Servers can 'forget' to respond with CloseNotify when client is closing connection,
/// this could lead to incomplete SSL shutdown. But since request is already processed, we can ignore this error.
break
Expand All @@ -964,11 +980,11 @@ extension TaskHandler: ChannelDuplexHandler {
/// We can also ignore this error like `.end`.
break
default:
self.state = .end
self.state = .endOrError
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
}
default:
self.state = .end
self.state = .endOrError
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
}
}
Expand Down
33 changes: 22 additions & 11 deletions Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ internal final class HTTPBin {
let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
let serverChannel: Channel
let isShutdown: NIOAtomic<Bool> = .makeAtomic(value: false)
var connections: NIOAtomic<Int>
var connectionCount: NIOAtomic<Int> = .makeAtomic(value: 0)
private let activeConnCounterHandler: CountActiveConnectionsHandler
var activeConnections: Int {
Expand Down Expand Up @@ -233,6 +234,9 @@ internal final class HTTPBin {
let activeConnCounterHandler = CountActiveConnectionsHandler()
self.activeConnCounterHandler = activeConnCounterHandler

let connections = NIOAtomic.makeAtomic(value: 0)
self.connections = connections

self.serverChannel = try! ServerBootstrap(group: self.group)
.serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.serverChannelInitializer { channel in
Expand Down Expand Up @@ -261,10 +265,10 @@ internal final class HTTPBin {
}.flatMap {
if ssl {
return HTTPBin.configureTLS(channel: channel).flatMap {
channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise, maxChannelAge: maxChannelAge))
channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise, maxChannelAge: maxChannelAge, connectionId: connections.add(1)))
}
} else {
return channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise, maxChannelAge: maxChannelAge))
return channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise, maxChannelAge: maxChannelAge, connectionId: connections.add(1)))
}
}
}
Expand Down Expand Up @@ -357,8 +361,8 @@ internal struct HTTPResponseBuilder {
}
}

let globalRequestCounter = NIOAtomic<Int>.makeAtomic(value: 0)
let globalConnectionCounter = NIOAtomic<Int>.makeAtomic(value: 0)
//let globalRequestCounter = NIOAtomic<Int>.makeAtomic(value: 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comments

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops, fixed, thanks!

//let globalConnectionCounter = NIOAtomic<Int>.makeAtomic(value: 0)

internal struct RequestInfo: Codable {
var data: String
Expand All @@ -378,13 +382,13 @@ internal final class HttpBinHandler: ChannelInboundHandler {
let maxChannelAge: TimeAmount?
var shouldClose = false
var isServingRequest = false
let myConnectionNumber: Int
var currentRequestNumber: Int = -1
let connectionId: Int
var requestId: Int = 0

init(channelPromise: EventLoopPromise<Channel>? = nil, maxChannelAge: TimeAmount? = nil) {
init(channelPromise: EventLoopPromise<Channel>? = nil, maxChannelAge: TimeAmount? = nil, connectionId: Int) {
self.channelPromise = channelPromise
self.maxChannelAge = maxChannelAge
self.myConnectionNumber = globalConnectionCounter.add(1)
self.connectionId = connectionId
}

func handlerAdded(context: ChannelHandlerContext) {
Expand Down Expand Up @@ -424,7 +428,7 @@ internal final class HttpBinHandler: ChannelInboundHandler {
switch self.unwrapInboundIn(data) {
case .head(let req):
self.responseHeaders = HTTPHeaders()
self.currentRequestNumber = globalRequestCounter.add(1)
self.requestId += 1
self.parseAndSetOptions(from: req)
let urlComponents = URLComponents(string: req.uri)!
switch urlComponents.percentEncodedPath {
Expand Down Expand Up @@ -552,8 +556,15 @@ internal final class HttpBinHandler: ChannelInboundHandler {
context.write(wrapOutboundOut(.head(response.head)), promise: nil)
if let body = response.body {
let requestInfo = RequestInfo(data: String(buffer: body),
requestNumber: self.currentRequestNumber,
connectionNumber: self.myConnectionNumber)
requestNumber: self.requestId,
connectionNumber: self.connectionId)
let responseBody = try! JSONEncoder().encodeAsByteBuffer(requestInfo,
allocator: context.channel.allocator)
context.write(wrapOutboundOut(.body(.byteBuffer(responseBody))), promise: nil)
} else {
let requestInfo = RequestInfo(data: "",
requestNumber: self.requestId,
connectionNumber: self.connectionId)
let responseBody = try! JSONEncoder().encodeAsByteBuffer(requestInfo,
allocator: context.channel.allocator)
context.write(wrapOutboundOut(.body(.byteBuffer(responseBody))), promise: nil)
Expand Down
2 changes: 2 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ extension HTTPClientTests {
("testAllMethodsLog", testAllMethodsLog),
("testClosingIdleConnectionsInPoolLogsInTheBackground", testClosingIdleConnectionsInPoolLogsInTheBackground),
("testDelegateCallinsTolerateRandomEL", testDelegateCallinsTolerateRandomEL),
("testContentLengthTooLongFails", testContentLengthTooLongFails),
("testContentLengthTooShortFails", testContentLengthTooShortFails),
]
}
}
51 changes: 47 additions & 4 deletions Tests/AsyncHTTPClientTests/HTTPClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1713,7 +1713,8 @@ class HTTPClientTests: XCTestCase {

// req 1 and 2 cannot share the same connection (close header)
XCTAssertEqual(stats1.connectionNumber + 1, stats2.connectionNumber)
XCTAssertEqual(stats1.requestNumber + 1, stats2.requestNumber)
XCTAssertEqual(stats1.requestNumber, 1)
XCTAssertEqual(stats2.requestNumber, 1)

// req 2 and 3 should share the same connection (keep-alive is default)
XCTAssertEqual(stats2.requestNumber + 1, stats3.requestNumber)
Expand Down Expand Up @@ -1742,7 +1743,8 @@ class HTTPClientTests: XCTestCase {

// req 1 and 2 cannot share the same connection (close header)
XCTAssertEqual(stats1.connectionNumber + 1, stats2.connectionNumber)
XCTAssertEqual(stats1.requestNumber + 1, stats2.requestNumber)
XCTAssertEqual(stats1.requestNumber, 1)
XCTAssertEqual(stats2.requestNumber, 1)

// req 2 and 3 should share the same connection (keep-alive is default)
XCTAssertEqual(stats2.requestNumber + 1, stats3.requestNumber)
Expand Down Expand Up @@ -1773,7 +1775,7 @@ class HTTPClientTests: XCTestCase {

// req 1 and 2 cannot share the same connection (close header)
XCTAssertEqual(stats1.connectionNumber + 1, stats2.connectionNumber)
XCTAssertEqual(stats1.requestNumber + 1, stats2.requestNumber)
XCTAssertEqual(stats2.requestNumber, 1)

// req 2 and 3 should share the same connection (keep-alive is default)
XCTAssertEqual(stats2.requestNumber + 1, stats3.requestNumber)
Expand Down Expand Up @@ -1805,7 +1807,7 @@ class HTTPClientTests: XCTestCase {

// req 1 and 2 cannot share the same connection (close header)
XCTAssertEqual(stats1.connectionNumber + 1, stats2.connectionNumber)
XCTAssertEqual(stats1.requestNumber + 1, stats2.requestNumber)
XCTAssertEqual(stats2.requestNumber, 1)

// req 2 and 3 should share the same connection (keep-alive is default)
XCTAssertEqual(stats2.requestNumber + 1, stats3.requestNumber)
Expand Down Expand Up @@ -2051,4 +2053,45 @@ class HTTPClientTests: XCTestCase {

XCTAssertNoThrow(try future.wait())
}

func testContentLengthTooLongFails() throws {
let url = self.defaultHTTPBinURLPrefix + "/post"
XCTAssertThrowsError(
try self.defaultClient.execute(request:
Request(url: url,
body: .stream(length: 10) { streamWriter in
let promise = self.defaultClient.eventLoopGroup.next().makePromise(of: Void.self)
DispatchQueue(label: "content-length-test").async {
streamWriter.write(.byteBuffer(ByteBuffer(string: "1"))).cascade(to: promise)
}
return promise.futureResult
})).wait()) { error in
XCTAssertEqual(error as! HTTPClientError, HTTPClientError.bodyLengthMismatch)
}
// Quickly try another request and check that it works.
var response = try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "get").wait()
let info = try response.body!.readJSONDecodable(RequestInfo.self, length: response.body!.readableBytes)
XCTAssertEqual(info!.connectionNumber, 1)
XCTAssertEqual(info!.requestNumber, 1)
}

// currently gets stuck because of #250 the server just never replies
func testContentLengthTooShortFails() throws {
let url = self.defaultHTTPBinURLPrefix + "/post"
let tooLong = "XBAD BAD BAD NOT HTTP/1.1\r\n\r\n"
XCTAssertThrowsError(
try self.defaultClient.execute(request:
Request(url: url,
body: .stream(length: 1) { streamWriter in
streamWriter.write(.byteBuffer(ByteBuffer(string: tooLong)))
})).wait()) { error in
XCTAssertEqual(error as! HTTPClientError, HTTPClientError.bodyLengthMismatch)
}
// Quickly try another request and check that it works. If we by accident wrote some extra bytes into the
// stream (and reuse the connection) that could cause problems.
var response = try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "get").wait()
let info = try response.body!.readJSONDecodable(RequestInfo.self, length: response.body!.readableBytes)
XCTAssertEqual(info!.connectionNumber, 1)
XCTAssertEqual(info!.requestNumber, 1)
}
}
10 changes: 10 additions & 0 deletions Tests/AsyncHTTPClientTests/RequestValidationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,14 @@ class RequestValidationTests: XCTestCase {

XCTAssertNoThrow(try headers.validate(method: .GET, body: nil))
}

func testMultipleContentLengthOnNilStreamLength() {
var headers = HTTPHeaders([("Content-Length", "1"), ("Content-Length", "2")])
var buffer = ByteBufferAllocator().buffer(capacity: 10)
buffer.writeBytes([UInt8](repeating: 12, count: 10))
let body: HTTPClient.Body = .stream() { writer in
writer.write(.byteBuffer(buffer))
}
XCTAssertThrowsError(try headers.validate(method: .PUT, body: body))
}
}