This repository was archived by the owner on Apr 7, 2022. It is now read-only.
-
-
Notifications
You must be signed in to change notification settings - Fork 64
Expand file tree
/
Copy pathHTTPServer.swift
More file actions
309 lines (280 loc) · 13.4 KB
/
HTTPServer.swift
File metadata and controls
309 lines (280 loc) · 13.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
/// Simple HTTP server generic on an HTTP responder
/// that will be used to generate responses to incoming requests.
///
///     let server = try HTTPServer.start(hostname: hostname, port: port, responder: EchoResponder(), on: group).wait()
///     try server.onClose.wait()
///
public final class HTTPServer {
    // MARK: Static

    /// Starts the server on the supplied hostname and port, using the supplied
    /// responder to generate HTTP responses for incoming requests.
    ///
    ///     let server = try HTTPServer.start(hostname: hostname, port: port, responder: EchoResponder(), on: group).wait()
    ///     try server.onClose.wait()
    ///
    /// - parameters:
    ///     - hostname: Socket hostname to bind to. Usually `localhost` or `::1`.
    ///     - port: Socket port to bind to. Usually `8080` for development and `80` for production.
    ///     - responder: Used to generate responses for incoming requests.
    ///     - maxBodySize: Requests with bodies larger than this maximum will be rejected.
    ///                    Streaming bodies, like chunked bodies, ignore this maximum.
    ///     - backlog: OS socket backlog size.
    ///     - reuseAddress: When `true`, can prevent errors re-binding to a socket after successive server restarts.
    ///     - tcpNoDelay: When `true`, OS will attempt to minimize TCP packet delay.
    ///     - supportCompression: When `true`, HTTP server will support gzip and deflate compression.
    ///     - serverName: If set, this name will be serialized as the `Server` header in outgoing responses.
    ///     - upgraders: An array of `HTTPProtocolUpgrader` to check for with each request.
    ///     - worker: `Worker` to perform async work on.
    ///     - onError: Any uncaught server or responder errors will go here.
    /// - returns: A `Future` that resolves to the running server once the socket is bound.
    public static func start<R>(
        hostname: String,
        port: Int,
        responder: R,
        maxBodySize: Int = 1_000_000,
        backlog: Int = 256,
        reuseAddress: Bool = true,
        tcpNoDelay: Bool = true,
        supportCompression: Bool = false,
        serverName: String? = nil,
        upgraders: [HTTPProtocolUpgrader] = [],
        on worker: Worker,
        onError: @escaping (Error) -> () = { _ in }
    ) -> Future<HTTPServer> where R: HTTPServerResponder {
        // FIX: construct the platform-specific bootstrap inside each conditional
        // branch. The previous `bootstrapType.init(group: worker as! NIOTSEventLoopGroup)`
        // performed the Darwin-only force cast unconditionally; `NIOTSEventLoopGroup`
        // (NIOTransportServices) does not exist on Linux, so the Linux build was broken.
        // The subsequent option/initializer chain compiles per-platform because both
        // bootstrap types expose the same configuration API.
        #if os(Linux)
        let bootstrap = ServerBootstrap(group: worker.eventLoop)
        #else
        // NOTE(review): assumes the supplied `Worker` is backed by a
        // `NIOTSEventLoopGroup` on Darwin — confirm at the call sites.
        let bootstrap = NIOTSListenerBootstrap(group: worker as! NIOTSEventLoopGroup)
        #endif
        let configuredBootstrap = bootstrap
            // Specify backlog and enable SO_REUSEADDR for the server itself
            .serverChannelOption(ChannelOptions.backlog, value: Int32(backlog))
            .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: reuseAddress ? SocketOptionValue(1) : SocketOptionValue(0))
            // Set the handlers that are applied to the accepted Channels
            .childChannelInitializer { channel in
                // create HTTPServerResponder-based handler
                let handler = HTTPServerHandler(responder: responder, maxBodySize: maxBodySize, serverHeader: serverName, onError: onError)
                // on successful protocol upgrade, remove the HTTP handler from the pipeline
                let upgrade: HTTPUpgradeConfiguration = (upgraders: upgraders, completionHandler: { ctx in
                    // shouldn't need to wait for this
                    _ = channel.pipeline.remove(handler: handler)
                })
                // configure the pipeline
                return channel.pipeline.configureHTTPServerPipeline(
                    withPipeliningAssistance: false,
                    withServerUpgrade: upgrade,
                    withErrorHandling: false
                ).then {
                    if supportCompression {
                        return channel.pipeline.addHandlers([HTTPResponseCompressor(), handler], first: false)
                    } else {
                        return channel.pipeline.add(handler: handler)
                    }
                }
            }
            // Enable TCP_NODELAY and SO_REUSEADDR for the accepted Channels
            .childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: tcpNoDelay ? SocketOptionValue(1) : SocketOptionValue(0))
            .childChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: reuseAddress ? SocketOptionValue(1) : SocketOptionValue(0))
            .childChannelOption(ChannelOptions.maxMessagesPerRead, value: 1)
            // .childChannelOption(ChannelOptions.allowRemoteHalfClosure, value: 1)
        return configuredBootstrap.bind(host: hostname, port: port).map(to: HTTPServer.self) { channel in
            return HTTPServer(channel: channel)
        }
    }

    // MARK: Properties

    /// A future that will be signaled when the server closes.
    public var onClose: Future<Void> {
        return channel.closeFuture
    }

    /// The running channel.
    private var channel: Channel

    /// Creates a new `HTTPServer`. Use the public static `.start` method.
    private init(channel: Channel) {
        self.channel = channel
    }

    // MARK: Methods

    /// Closes the server, terminating the bound channel and all accepted connections.
    public func close() -> Future<Void> {
        return channel.close(mode: .all)
    }
}
// MARK: Private

/// Private `ChannelInboundHandler` that converts `HTTPServerRequestPart` to `HTTPServerResponsePart`
/// by collecting (or streaming) the request body, asking the responder for a response,
/// and serializing that response back onto the channel.
private final class HTTPServerHandler<R>: ChannelInboundHandler where R: HTTPServerResponder {
    /// See `ChannelInboundHandler`.
    public typealias InboundIn = HTTPServerRequestPart

    /// See `ChannelInboundHandler`.
    public typealias OutboundOut = HTTPServerResponsePart

    /// The responder generating `HTTPResponse`s for incoming `HTTPRequest`s.
    public let responder: R

    /// Maximum collected body size allowed per request, in bytes.
    /// Streaming (chunked) bodies bypass this limit.
    private let maxBodySize: Int

    /// Handles any errors that may occur.
    private let errorHandler: (Error) -> ()

    /// Optional value serialized as the `server` header on outgoing responses.
    private let serverHeader: String?

    /// Current HTTP state.
    var state: HTTPServerState

    /// Create a new `HTTPServerHandler`.
    init(responder: R, maxBodySize: Int = 1_000_000, serverHeader: String?, onError: @escaping (Error) -> ()) {
        self.responder = responder
        self.maxBodySize = maxBodySize
        self.errorHandler = onError
        self.serverHeader = serverHeader
        self.state = .ready
    }

    /// See `ChannelInboundHandler`.
    func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
        debugOnly { assert(ctx.channel.eventLoop.inEventLoop) }
        switch unwrapInboundIn(data) {
        case .head(let head):
            debugOnly {
                /// only perform this switch in debug mode
                switch state {
                case .ready: break
                default: assertionFailure("Unexpected state: \(state)")
                }
            }
            state = .awaitingBody(head)
        case .body(var chunk):
            switch state {
            case .ready: debugOnly { assertionFailure("Unexpected state: \(state)") }
            case .awaitingBody(let head):
                /// 1: check to see which kind of body we are parsing from the head
                ///
                /// short circuit on `contains(name:)` which is faster
                /// - note: for some reason using String instead of HTTPHeaderName is faster here...
                ///         this will be standardized when NIO gets header names
                if head.headers.contains(name: "Transfer-Encoding"), head.headers.firstValue(name: .transferEncoding) == "chunked" {
                    // chunked transfer: respond immediately with a streaming body
                    let stream = HTTPChunkedStream(on: ctx.eventLoop)
                    state = .streamingBody(stream)
                    respond(to: head, body: .init(chunked: stream), ctx: ctx)
                } else {
                    state = .collectingBody(head, nil)
                }
                /// 2: perform the actual body read now
                channelRead(ctx: ctx, data: data)
            case .collectingBody(let head, let existingBody):
                let body: ByteBuffer
                if var existing = existingBody {
                    if existing.readableBytes + chunk.readableBytes > self.maxBodySize {
                        ERROR("[HTTP] Request size exceeded maximum, connection closed.")
                        ctx.close(promise: nil)
                        // FIX: stop handling this read once the connection has been
                        // closed. Previously execution fell through, appending the
                        // oversized chunk and updating state anyway.
                        return
                    }
                    existing.write(buffer: &chunk)
                    body = existing
                } else {
                    body = chunk
                }
                state = .collectingBody(head, body)
            case .streamingBody(let stream): _ = stream.write(.chunk(chunk))
            }
        case .end(let tailHeaders):
            debugOnly { assert(tailHeaders == nil, "Tail headers are not supported.") }
            switch state {
            case .ready: debugOnly { assertionFailure("Unexpected state: \(state)") }
            case .awaitingBody(let head): respond(to: head, body: .empty, ctx: ctx)
            case .collectingBody(let head, let body):
                let body: HTTPBody = body.flatMap(HTTPBody.init(buffer:)) ?? .empty
                respond(to: head, body: body, ctx: ctx)
            case .streamingBody(let stream): _ = stream.write(.end)
            }
            state = .ready
        }
    }

    /// Requests an `HTTPResponse` from the responder and serializes it.
    private func respond(to head: HTTPRequestHead, body: HTTPBody, ctx: ChannelHandlerContext) {
        var req = HTTPRequest(head: head, body: body, channel: ctx.channel)
        switch head.method {
        // present HEAD requests to the responder as GET; `serialize` later skips
        // the body for the actual HEAD response
        case .HEAD: req.method = .GET
        default: break
        }
        let res = responder.respond(to: req, on: ctx.eventLoop)
        res.whenSuccess { res in
            debugOnly {
                switch body.storage {
                case .chunkedStream(let stream):
                    if !stream.isClosed {
                        ERROR("HTTPResponse sent while HTTPRequest had unconsumed chunked data.")
                    }
                default: break
                }
            }
            self.serialize(res, for: head, ctx: ctx)
        }
        res.whenFailure { error in
            self.errorHandler(error)
            ctx.close(promise: nil)
        }
    }

    /// Serializes the `HTTPResponse`.
    private func serialize(_ res: HTTPResponse, for reqhead: HTTPRequestHead, ctx: ChannelHandlerContext) {
        // add a RFC1123 timestamp to the Date header to make this
        // a valid response
        var reshead = res.head
        reshead.headers.add(name: "date", value: RFC1123DateCache.shared.currentTimestamp())
        if let server = serverHeader {
            reshead.headers.add(name: "server", value: server)
        }
        // begin serializing
        ctx.write(wrapOutboundOut(.head(reshead)), promise: nil)
        if reqhead.method == .HEAD {
            // skip sending the body for HEAD requests
            ctx.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
        } else {
            switch res.body.storage {
            case .none: ctx.writeAndFlush(wrapOutboundOut(.end(nil)), promise: nil)
            case .buffer(let buffer): writeAndFlush(buffer: buffer, ctx: ctx)
            case .string(let string):
                var buffer = ctx.channel.allocator.buffer(capacity: string.count)
                buffer.write(string: string)
                writeAndFlush(buffer: buffer, ctx: ctx)
            case .staticString(let string):
                var buffer = ctx.channel.allocator.buffer(capacity: string.count)
                buffer.write(staticString: string)
                writeAndFlush(buffer: buffer, ctx: ctx)
            case .data(let data):
                var buffer = ctx.channel.allocator.buffer(capacity: data.count)
                buffer.write(bytes: data)
                writeAndFlush(buffer: buffer, ctx: ctx)
            case .dispatchData(let data):
                var buffer = ctx.channel.allocator.buffer(capacity: data.count)
                buffer.write(bytes: data)
                writeAndFlush(buffer: buffer, ctx: ctx)
            case .chunkedStream(let stream):
                // each read callback forwards one chunk (or the terminator) downstream
                stream.read { result, stream in
                    switch result {
                    case .chunk(let buffer): return ctx.writeAndFlush(self.wrapOutboundOut(.body(.byteBuffer(buffer))))
                    case .end: return ctx.writeAndFlush(self.wrapOutboundOut(.end(nil)))
                    case .error(let error):
                        self.errorHandler(error)
                        return ctx.writeAndFlush(self.wrapOutboundOut(.end(nil)))
                    }
                }
            }
        }
    }

    /// Writes a `ByteBuffer` body part (if non-empty) followed by the end part, then flushes.
    /// (Renamed from `writeAndflush` to fix the capitalization typo; private, no external callers.)
    private func writeAndFlush(buffer: ByteBuffer, ctx: ChannelHandlerContext) {
        if buffer.readableBytes > 0 {
            ctx.write(wrapOutboundOut(.body(.byteBuffer(buffer))), promise: nil)
        }
        ctx.writeAndFlush(wrapOutboundOut(.end(nil)), promise: nil)
    }

    /// See `ChannelInboundHandler`.
    func errorCaught(ctx: ChannelHandlerContext, error: Error) {
        errorHandler(error)
    }
}
/// Tracks the current per-connection HTTP request parsing state.
private enum HTTPServerState {
    /// Waiting for request headers (no request in flight).
    case ready
    /// Headers received; waiting to learn whether a body follows.
    /// This allows for a performance optimization in case
    /// a body never comes (the response can be produced from the head alone).
    case awaitingBody(HTTPRequestHead)
    /// Collecting a fixed-length body; the buffer is `nil` until the first chunk arrives.
    case collectingBody(HTTPRequestHead, ByteBuffer?)
    /// Forwarding a chunked (streaming) body into the associated stream.
    case streamingBody(HTTPChunkedStream)
}