Skip to content

Commit 44ddc8b

Browse files
Trevörartemredkin
authored andcommitted
Transfer scheme and host to new request when redirected (#20)
* Transfer scheme and host to new request when redirected This solves a bug that could cause infinite redirections (eg: requesting example.com, being redirected to www.example.com, but requesting again example.com as host hadn't been correctly specified) * Check that redirect correctly modifies Host to what's specified by the server response * Move test to new test case, use Decodable inline struct instead of JSONSerialization * Run swiftformat tool * Use NIOFoundationCompat helper functions * Stop using httpbin.org and use the local HttpBin server for the testHttpHostRedirect test * Reformat files * Check that redirectURL contains a host and scheme * Run swiftformat * Replace assertionFailure with preconditionFailure
1 parent b9c5535 commit 44ddc8b

File tree

8 files changed

+80
-29
lines changed

8 files changed

+80
-29
lines changed

Package.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ let package = Package(
3131
),
3232
.testTarget(
3333
name: "NIOHTTPClientTests",
34-
dependencies: ["NIOHTTPClient"]
34+
dependencies: ["NIOHTTPClient", "NIOFoundationCompat"]
3535
),
3636
]
3737
)

Sources/NIOHTTPClient/HTTPClientProxyHandler.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ internal final class HTTPClientProxyHandler: ChannelDuplexHandler, RemovableChan
7474
switch res {
7575
case .head(let head):
7676
switch head.status.code {
77-
case 200..<300:
77+
case 200 ..< 300:
7878
// Any 2xx (Successful) response indicates that the sender (and all
7979
// inbound proxies) will switch to tunnel mode immediately after the
8080
// blank line that concludes the successful response's header section
@@ -116,7 +116,7 @@ internal final class HTTPClientProxyHandler: ChannelDuplexHandler, RemovableChan
116116
private func handleConnect(context: ChannelHandlerContext) -> EventLoopFuture<Void> {
117117
return self.onConnect(context.channel).flatMap {
118118
self.readState = .connected
119-
119+
120120
// forward any buffered reads
121121
while !self.readBuffer.isEmpty {
122122
context.fireChannelRead(self.readBuffer.removeFirst())

Sources/NIOHTTPClient/HTTPCookie.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ public struct HTTPCookie {
6868
formatter.locale = Locale(identifier: "en_US")
6969
formatter.timeZone = TimeZone(identifier: "GMT")
7070
formatter.dateFormat = "EEE, dd MMM yyyy HH:mm:ss z"
71-
self.expires = parseComponentValue(component).flatMap { formatter.date(from: $0) }
71+
self.expires = self.parseComponentValue(component).flatMap { formatter.date(from: $0) }
7272
continue
7373
}
7474

Sources/NIOHTTPClient/HTTPHandler.swift

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ internal class ResponseAccumulator: HTTPClientResponseDelegate {
140140
case .body(let head, var body):
141141
var part = part
142142
body.writeBuffer(&part)
143-
state = .body(head, body)
143+
self.state = .body(head, body)
144144
case .end:
145145
preconditionFailure("request already processed")
146146
case .error:
@@ -171,7 +171,7 @@ internal class ResponseAccumulator: HTTPClientResponseDelegate {
171171
/// This delegate is strongly held by the HTTPTaskHandler
172172
/// for the duration of the HTTPRequest processing and will be
173173
/// released together with the HTTPTaskHandler when channel is closed
174-
public protocol HTTPClientResponseDelegate: class {
174+
public protocol HTTPClientResponseDelegate: AnyObject {
175175
associatedtype Response
176176

177177
func didTransmitRequestBody(task: HTTPClient.Task<Response>)
@@ -361,13 +361,13 @@ internal class TaskHandler<T: HTTPClientResponseDelegate>: ChannelInboundHandler
361361
if (event as? IdleStateHandler.IdleStateEvent) == .read {
362362
self.state = .end
363363
let error = HTTPClientError.readTimeout
364-
delegate.didReceiveError(task: self.task, error)
365-
promise.fail(error)
364+
self.delegate.didReceiveError(task: self.task, error)
365+
self.promise.fail(error)
366366
} else if (event as? TaskCancelEvent) != nil {
367367
self.state = .end
368368
let error = HTTPClientError.cancelled
369-
delegate.didReceiveError(task: self.task, error)
370-
promise.fail(error)
369+
self.delegate.didReceiveError(task: self.task, error)
370+
self.promise.fail(error)
371371
} else {
372372
context.fireUserInboundEventTriggered(event)
373373
}
@@ -380,8 +380,8 @@ internal class TaskHandler<T: HTTPClientResponseDelegate>: ChannelInboundHandler
380380
default:
381381
self.state = .end
382382
let error = HTTPClientError.remoteConnectionClosed
383-
delegate.didReceiveError(task: self.task, error)
384-
promise.fail(error)
383+
self.delegate.didReceiveError(task: self.task, error)
384+
self.promise.fail(error)
385385
}
386386
}
387387

@@ -408,7 +408,7 @@ internal class TaskHandler<T: HTTPClientResponseDelegate>: ChannelInboundHandler
408408

409409
internal struct RedirectHandler<T> {
410410
let request: HTTPClient.Request
411-
let execute: ((HTTPClient.Request) -> HTTPClient.Task<T>)
411+
let execute: (HTTPClient.Request) -> HTTPClient.Task<T>
412412

413413
func redirectTarget(status: HTTPResponseStatus, headers: HTTPHeaders) -> URL? {
414414
switch status {
@@ -443,6 +443,18 @@ internal struct RedirectHandler<T> {
443443
var request = self.request
444444
request.url = redirectURL
445445

446+
if let redirectHost = redirectURL.host {
447+
request.host = redirectHost
448+
} else {
449+
preconditionFailure("redirectURL doesn't contain a host")
450+
}
451+
452+
if let redirectScheme = redirectURL.scheme {
453+
request.scheme = redirectScheme
454+
} else {
455+
preconditionFailure("redirectURL doesn't contain a scheme")
456+
}
457+
446458
var convertToGet = false
447459
if status == .seeOther, request.method != .HEAD {
448460
convertToGet = true

Sources/NIOHTTPClient/SwiftNIOHTTP.swift

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ public class HTTPClient {
115115

116116
public func execute<T: HTTPClientResponseDelegate>(request: Request, delegate: T, timeout: Timeout? = nil) -> Task<T.Response> {
117117
let timeout = timeout ?? configuration.timeout
118-
119118
let promise: EventLoopPromise<T.Response> = self.eventLoopGroup.next().makePromise()
120119

121120
let redirectHandler: RedirectHandler<T.Response>?
@@ -151,12 +150,12 @@ public class HTTPClient {
151150
let taskHandler = TaskHandler(task: task, delegate: delegate, promise: promise, redirectHandler: redirectHandler)
152151
return channel.pipeline.addHandler(taskHandler)
153152
}
154-
}
153+
}
155154

156155
if let connectTimeout = timeout.connect {
157156
bootstrap = bootstrap.connectTimeout(connectTimeout)
158157
}
159-
158+
160159
let address = self.resolveAddress(request: request, proxy: self.configuration.proxy)
161160
bootstrap.connect(host: address.host, port: address.port)
162161
.map { channel in
@@ -172,7 +171,7 @@ public class HTTPClient {
172171
return task
173172
}
174173

175-
private func resolveAddress(request: Request, proxy: Proxy?) -> (host: String, port: Int) {
174+
private func resolveAddress(request: Request, proxy: Proxy?) -> (host: String, port: Int) {
176175
switch self.configuration.proxy {
177176
case .none:
178177
return (request.host, request.port)
@@ -216,7 +215,7 @@ public class HTTPClient {
216215
private extension ChannelPipeline {
217216
func addProxyHandler(for request: HTTPClient.Request, decoder: ByteToMessageHandler<HTTPResponseDecoder>, encoder: HTTPRequestEncoder, tlsConfiguration: TLSConfiguration?) -> EventLoopFuture<Void> {
218217
let handler = HTTPClientProxyHandler(host: request.host, port: request.port, onConnect: { channel in
219-
return channel.pipeline.removeHandler(decoder).flatMap {
218+
channel.pipeline.removeHandler(decoder).flatMap {
220219
return channel.pipeline.addHandler(
221220
ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)),
222221
position: .after(encoder)

Tests/NIOHTTPClientTests/HTTPClientTestUtils.swift

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class TestHTTPDelegate: HTTPClientResponseDelegate {
3434
case .body(let head, var body):
3535
var buffer = buffer
3636
body.writeBuffer(&buffer)
37-
state = .body(head, body)
37+
self.state = .body(head, body)
3838
default:
3939
preconditionFailure("expecting head or body")
4040
}
@@ -94,11 +94,11 @@ internal class HttpBin {
9494
}
9595

9696
init(ssl: Bool = false, simulateProxy: HTTPProxySimulator.Option? = nil) {
97-
self.serverChannel = try! ServerBootstrap(group: group)
97+
self.serverChannel = try! ServerBootstrap(group: self.group)
9898
.serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
9999
.childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1)
100100
.childChannelInitializer { channel in
101-
return channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, withErrorHandling: true).flatMap {
101+
channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, withErrorHandling: true).flatMap {
102102
if let simulateProxy = simulateProxy {
103103
return channel.pipeline.addHandler(HTTPProxySimulator(option: simulateProxy), position: .first)
104104
} else {
@@ -216,13 +216,27 @@ internal final class HttpBinHandler: ChannelInboundHandler {
216216
case "/redirect/302":
217217
var headers = HTTPHeaders()
218218
headers.add(name: "Location", value: "/ok")
219-
resps.append(HTTPResponseBuilder(status: .found, headers: headers))
219+
self.resps.append(HTTPResponseBuilder(status: .found, headers: headers))
220220
return
221221
case "/redirect/https":
222-
let port = value(for: "port", from: url.query!)
222+
let port = self.value(for: "port", from: url.query!)
223223
var headers = HTTPHeaders()
224224
headers.add(name: "Location", value: "https://localhost:\(port)/ok")
225-
resps.append(HTTPResponseBuilder(status: .found, headers: headers))
225+
self.resps.append(HTTPResponseBuilder(status: .found, headers: headers))
226+
return
227+
case "/redirect/loopback":
228+
let port = self.value(for: "port", from: url.query!)
229+
var headers = HTTPHeaders()
230+
headers.add(name: "Location", value: "http://127.0.0.1:\(port)/echohostheader")
231+
self.resps.append(HTTPResponseBuilder(status: .found, headers: headers))
232+
return
233+
case "/echohostheader":
234+
var builder = HTTPResponseBuilder(status: .ok)
235+
let hostValue = req.headers["Host"].first ?? ""
236+
var buff = context.channel.allocator.buffer(capacity: hostValue.utf8.count)
237+
buff.writeString(hostValue)
238+
builder.add(buff)
239+
self.resps.append(builder)
226240
return
227241
case "/wait":
228242
return
@@ -246,14 +260,14 @@ internal final class HttpBinHandler: ChannelInboundHandler {
246260
return
247261
}
248262
case .body(let body):
249-
var response = resps.removeFirst()
263+
var response = self.resps.removeFirst()
250264
response.add(body)
251-
resps.prepend(response)
265+
self.resps.prepend(response)
252266
case .end:
253267
if self.resps.isEmpty {
254268
return
255269
}
256-
let response = resps.removeFirst()
270+
let response = self.resps.removeFirst()
257271
context.write(wrapOutboundOut(.head(response.head)), promise: nil)
258272
if let body = response.body {
259273
let data = body.withUnsafeReadableBytes {
@@ -288,7 +302,7 @@ internal final class HttpBinHandler: ChannelInboundHandler {
288302
}
289303
}
290304

291-
fileprivate let cert = """
305+
private let cert = """
292306
-----BEGIN CERTIFICATE-----
293307
MIICmDCCAYACCQCPC8JDqMh1zzANBgkqhkiG9w0BAQsFADANMQswCQYDVQQGEwJ1
294308
czAgFw0xODEwMzExNTU1MjJaGA8yMTE4MTAwNzE1NTUyMlowDTELMAkGA1UEBhMC
@@ -307,7 +321,7 @@ Au4LoEYwT730QKC/VQxxEVZobjn9/sTrq9CZlbPYHxX4fz6e00sX7H9i49vk9zQ5
307321
-----END CERTIFICATE-----
308322
"""
309323

310-
fileprivate let key = """
324+
private let key = """
311325
-----BEGIN PRIVATE KEY-----
312326
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDiC+TGmbSP/nWW
313327
N1tjyNfnWCU5ATjtIOfdtP6ycx8JSeqkvyNXG21kNUn14jTTU8BglGL2hfVpCbMi

Tests/NIOHTTPClientTests/SwiftNIOHTTPTests+XCTest.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ extension SwiftHTTPTests {
3333
("testGetHttps", testGetHttps),
3434
("testPostHttps", testPostHttps),
3535
("testHttpRedirect", testHttpRedirect),
36+
("testHttpHostRedirect", testHttpHostRedirect),
3637
("testMultipleContentLengthHeaders", testMultipleContentLengthHeaders),
3738
("testStreaming", testStreaming),
3839
("testRemoteClose", testRemoteClose),

Tests/NIOHTTPClientTests/SwiftNIOHTTPTests.swift

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import Foundation
1616
import NIO
17+
import NIOFoundationCompat
1718
@testable import NIOHTTP1
1819
@testable import NIOHTTPClient
1920
import NIOSSL
@@ -171,6 +172,30 @@ class SwiftHTTPTests: XCTestCase {
171172
XCTAssertEqual(response.status, .ok)
172173
}
173174

175+
func testHttpHostRedirect() throws {
176+
let httpBin = HttpBin(ssl: false)
177+
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew,
178+
configuration: HTTPClient.Configuration(certificateVerification: .none, followRedirects: true))
179+
180+
defer {
181+
try! httpClient.syncShutdown()
182+
httpBin.shutdown()
183+
}
184+
185+
let response = try httpClient.get(url: "http://localhost:\(httpBin.port)/redirect/loopback?port=\(httpBin.port)").wait()
186+
guard var body = response.body else {
187+
XCTFail("The target page should have a body containing the value of the Host header")
188+
return
189+
}
190+
guard let responseData = body.readData(length: body.readableBytes) else {
191+
XCTFail("Read data shouldn't be nil since we passed body.readableBytes to body.readData")
192+
return
193+
}
194+
let decoder = JSONDecoder()
195+
let hostName = try decoder.decode([String: String].self, from: responseData)["data"]
196+
XCTAssert(hostName == "127.0.0.1")
197+
}
198+
174199
func testMultipleContentLengthHeaders() throws {
175200
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew)
176201
defer {

0 commit comments

Comments
 (0)