@@ -2,7 +2,7 @@
 //
 // This source file is part of the SwiftAWSLambdaRuntime open source project
 //
-// Copyright (c) 2020 Apple Inc. and the SwiftAWSLambdaRuntime project authors
+// Copyright (c) 2025 Apple Inc. and the SwiftAWSLambdaRuntime project authors
 // Licensed under Apache License v2.0
 //
 // See LICENSE.txt for license information
@@ -76,7 +76,7 @@ extension Lambda {
 /// 1. POST /invoke - the client posts the event to the lambda function
 ///
 /// This server passes the data received from /invoke POST request to the lambda function (GET /next) and then forwards the response back to the client.
-private struct LambdaHTTPServer {
+internal struct LambdaHTTPServer {
     private let invocationEndpoint: String

     private let invocationPool = Pool<LocalServerInvocation>()
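For context on the protocol these endpoints implement: a client can exercise the local emulator with a plain HTTP POST and read the handler's output from the response body. The sketch below is illustrative only; the bind address 127.0.0.1:7000 and the /invoke path are assumed defaults (not taken from this diff), and it relies on Foundation's async URLSession API.

import Foundation
#if canImport(FoundationNetworking)
import FoundationNetworking
#endif

// POST an event to the local server; it hands the body to the lambda handler
// (which picks it up via GET /next) and returns the handler's output as the
// HTTP response body. Address and path below are assumed defaults.
func invokeLocally(event: Data) async throws -> Data {
    var request = URLRequest(url: URL(string: "http://127.0.0.1:7000/invoke")!)
    request.httpMethod = "POST"
    request.httpBody = event
    let (responseBody, _) = try await URLSession.shared.data(for: request)
    return responseBody
}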
@@ -166,17 +166,21 @@ private struct LambdaHTTPServer {
             // consumed by iterating the group or by exiting the group. Since, we are never consuming
             // the results of the group we need the group to automatically discard them; otherwise, this
             // would result in a memory leak over time.
-            try await withThrowingDiscardingTaskGroup { taskGroup in
-                try await channel.executeThenClose { inbound in
-                    for try await connectionChannel in inbound {
-
-                        taskGroup.addTask {
-                            logger.trace("Handling a new connection")
-                            await server.handleConnection(channel: connectionChannel, logger: logger)
-                            logger.trace("Done handling the connection")
+            try await withTaskCancellationHandler {
+                try await withThrowingDiscardingTaskGroup { taskGroup in
+                    try await channel.executeThenClose { inbound in
+                        for try await connectionChannel in inbound {
+
+                            taskGroup.addTask {
+                                logger.trace("Handling a new connection")
+                                await server.handleConnection(channel: connectionChannel, logger: logger)
+                                logger.trace("Done handling the connection")
+                            }
                         }
                     }
                 }
+            } onCancel: {
+                channel.channel.close(promise: nil)
             }
             return .serverReturned(.success(()))
         } catch {
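Shown in isolation, the shape of the change above: the accept loop runs inside withThrowingDiscardingTaskGroup (finished child tasks are discarded instead of accumulating), and the whole loop is wrapped in withTaskCancellationHandler so that cancelling the surrounding task closes the listening channel, which ends the inbound sequence and lets the loop and the group unwind. A minimal sketch, not the library code, with an AsyncStream standing in for NIO's inbound channel sequence:

// `connections` stands in for NIO's inbound channel sequence and
// `stopListening` for channel.channel.close(promise: nil).
func runAcceptLoop(
    connections: AsyncStream<Int>,
    stopListening: @escaping @Sendable () -> Void
) async throws {
    try await withTaskCancellationHandler {
        try await withThrowingDiscardingTaskGroup { group in
            for await connection in connections {
                group.addTask {
                    // handle one connection; the group discards each result as it finishes
                    print("handled connection #\(connection)")
                }
            }
        }
    } onCancel: {
        // runs as soon as the surrounding task is cancelled, even while the loop is suspended
        stopListening()
    }
}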
@@ -230,38 +234,42 @@ private struct LambdaHTTPServer {
         // Note that this method is non-throwing and we are catching any error.
         // We do this since we don't want to tear down the whole server when a single connection
         // encounters an error.
-        do {
-            try await channel.executeThenClose { inbound, outbound in
-                for try await inboundData in inbound {
-                    switch inboundData {
-                    case .head(let head):
-                        requestHead = head
-
-                    case .body(let body):
-                        requestBody.setOrWriteImmutableBuffer(body)
-
-                    case .end:
-                        precondition(requestHead != nil, "Received .end without .head")
-                        // process the request
-                        let response = try await self.processRequest(
-                            head: requestHead,
-                            body: requestBody,
-                            logger: logger
-                        )
-                        // send the responses
-                        try await self.sendResponse(
-                            response: response,
-                            outbound: outbound,
-                            logger: logger
-                        )
-
-                        requestHead = nil
-                        requestBody = nil
+        await withTaskCancellationHandler {
+            do {
+                try await channel.executeThenClose { inbound, outbound in
+                    for try await inboundData in inbound {
+                        switch inboundData {
+                        case .head(let head):
+                            requestHead = head
+
+                        case .body(let body):
+                            requestBody.setOrWriteImmutableBuffer(body)
+
+                        case .end:
+                            precondition(requestHead != nil, "Received .end without .head")
+                            // process the request
+                            let response = try await self.processRequest(
+                                head: requestHead,
+                                body: requestBody,
+                                logger: logger
+                            )
+                            // send the responses
+                            try await self.sendResponse(
+                                response: response,
+                                outbound: outbound,
+                                logger: logger
+                            )
+
+                            requestHead = nil
+                            requestBody = nil
+                        }
                     }
                 }
+            } catch {
+                logger.error("Hit error: \(error)")
             }
-        } catch {
-            logger.error("Hit error: \(error)")
+        } onCancel: {
+            channel.channel.close(promise: nil)
         }
     }

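Aside from the cancellation wrapper, the connection handler keeps its accumulate-then-process shape: buffer the .head and any .body parts, act only when .end arrives, then reset for the next request on the same connection. A stand-alone sketch of that pattern, using simple stand-ins for NIOHTTP1's request-part and buffer types:

// Stand-ins for NIOHTTP1's HTTPServerRequestPart: buffer .head and .body,
// act only on .end, then reset for the next request on the connection.
enum RequestPart {
    case head(String)   // stands in for HTTPRequestHead
    case body(String)   // stands in for ByteBuffer
    case end
}

func handle(parts: AsyncStream<RequestPart>) async {
    var head: String?
    var body = ""

    for await part in parts {
        switch part {
        case .head(let h):
            head = h
        case .body(let chunk):
            body += chunk
        case .end:
            guard let requestHead = head else {
                preconditionFailure("Received .end without .head")
            }
            print("request \(requestHead), body: \(body.utf8.count) bytes")
            // reset so the next request on this keep-alive connection starts clean
            head = nil
            body = ""
        }
    }
}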
@@ -426,7 +434,7 @@ private struct LambdaHTTPServer {
     /// A shared data structure to store the current invocation or response requests and the continuation objects.
     /// This data structure is shared between instances of the HTTPHandler
     /// (one instance to serve requests from the Lambda function and one instance to serve requests from the client invoking the lambda function).
-    private final class Pool<T>: AsyncSequence, AsyncIteratorProtocol, Sendable where T: Sendable {
+    internal final class Pool<T>: AsyncSequence, AsyncIteratorProtocol, Sendable where T: Sendable {
         typealias Element = T

         enum State: ~Copyable {
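Pool's State above is a non-copyable enum kept behind a lock; the idiom the next hunk relies on is to move the value out with consume, switch over it, and store a replacement before the withLock closure returns. A minimal sketch of that idiom, assuming a Synchronization.Mutex as the lock (the lock's actual declaration is not part of this diff):

import Synchronization

// A non-copyable state enum swapped in and out of a Mutex with `consume`,
// mirroring the switch-then-reassign idiom used by Pool.
enum BufferState: ~Copyable {
    case empty
    case buffered([Int])
}

let lock = Mutex<BufferState>(.empty)

func append(_ value: Int) {
    lock.withLock { state in
        switch consume state {
        case .empty:
            state = .buffered([value])
        case .buffered(var values):
            values.append(value)
            state = .buffered(values)
        }
    }
}

append(1)
append(2)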
@@ -462,26 +470,38 @@ private struct LambdaHTTPServer {
                 return nil
             }

-            return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<T, any Error>) in
-                let nextAction = self.lock.withLock { state -> T? in
-                    switch consume state {
-                    case .buffer(var buffer):
-                        if let first = buffer.popFirst() {
-                            state = .buffer(buffer)
-                            return first
-                        } else {
-                            state = .continuation(continuation)
-                            return nil
-                        }
+            return try await withTaskCancellationHandler {
+                try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<T, any Error>) in
+                    let nextAction = self.lock.withLock { state -> T? in
+                        switch consume state {
+                        case .buffer(var buffer):
+                            if let first = buffer.popFirst() {
+                                state = .buffer(buffer)
+                                return first
+                            } else {
+                                state = .continuation(continuation)
+                                return nil
+                            }

-                    case .continuation:
-                        fatalError("Concurrent invocations to next(). This is illegal.")
+                        case .continuation:
+                            fatalError("Concurrent invocations to next(). This is illegal.")
+                        }
                     }
-                }

-                guard let nextAction else { return }
+                    guard let nextAction else { return }

-                continuation.resume(returning: nextAction)
+                    continuation.resume(returning: nextAction)
+                }
+            } onCancel: {
+                self.lock.withLock { state in
+                    switch consume state {
+                    case .buffer(let buffer):
+                        state = .buffer(buffer)
+                    case .continuation(let continuation):
+                        continuation?.resume(throwing: CancellationError())
+                        state = .buffer([])
+                    }
+                }
             }
         }

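Boiled down, the new code parks the caller in withCheckedThrowingContinuation, stashes the continuation in the lock-protected state, and lets the onCancel closure resume it with CancellationError so a cancelled consumer is not left suspended. A minimal stand-alone sketch of that hand-off (assuming Synchronization.Mutex; the buffering and the concurrent-waiter check of the real Pool are omitted):

import Synchronization

// One pending waiter at a time; fulfill(_:) or task cancellation resumes it.
final class AsyncGate<T: Sendable>: Sendable {
    private let pending = Mutex<CheckedContinuation<T, any Error>?>(nil)

    // Atomically take the stored continuation so it can be resumed at most once.
    private func take() -> CheckedContinuation<T, any Error>? {
        self.pending.withLock { state in
            let continuation = state
            state = nil
            return continuation
        }
    }

    func wait() async throws -> T {
        try await withTaskCancellationHandler {
            try await withCheckedThrowingContinuation { continuation in
                self.pending.withLock { $0 = continuation }
            }
        } onCancel: {
            // Fail a parked waiter instead of leaving it suspended forever.
            // (A production version also has to handle cancellation that fires
            // before the continuation is stored.)
            self.take()?.resume(throwing: CancellationError())
        }
    }

    func fulfill(_ value: T) {
        self.take()?.resume(returning: value)
    }
}

Keeping the continuation in shared, lock-protected state is what makes this work: the onCancel closure runs outside the operation closure, so the parked continuation is only reachable through that shared state.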