From f28c4c3e5c989c0494c0d96b6f77aa5c42fed6b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Stormacq?= Date: Mon, 1 Sep 2025 12:32:24 +0200 Subject: [PATCH 1/2] Split LambdaRuntimeClient in two files for easier reading --- .../LambdaRuntimeClient+ChannelHandler.swift | 486 ++++++++++++++++++ .../LambdaRuntimeClient.swift | 472 +---------------- 2 files changed, 487 insertions(+), 471 deletions(-) create mode 100644 Sources/AWSLambdaRuntime/LambdaRuntimeClient+ChannelHandler.swift diff --git a/Sources/AWSLambdaRuntime/LambdaRuntimeClient+ChannelHandler.swift b/Sources/AWSLambdaRuntime/LambdaRuntimeClient+ChannelHandler.swift new file mode 100644 index 00000000..ca043bf3 --- /dev/null +++ b/Sources/AWSLambdaRuntime/LambdaRuntimeClient+ChannelHandler.swift @@ -0,0 +1,486 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftAWSLambdaRuntime open source project +// +// Copyright (c) 2025 Apple Inc. and the SwiftAWSLambdaRuntime project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIOCore +import NIOHTTP1 +import NIOPosix + +protocol LambdaChannelHandlerDelegate { + func connectionWillClose(channel: any Channel) + func connectionErrorHappened(_ error: any Error, channel: any Channel) +} + +final class LambdaChannelHandler { + let nextInvocationPath = Consts.invocationURLPrefix + Consts.getNextInvocationURLSuffix + + enum State { + case disconnected + case connected(ChannelHandlerContext, LambdaState) + case closing + + enum LambdaState { + /// this is the "normal" state. Transitions to `waitingForNextInvocation` + case idle + /// this is the state while we wait for an invocation. A next call is running. + /// Transitions to `waitingForResponse` + case waitingForNextInvocation(CheckedContinuation) + /// The invocation was forwarded to the handler and we wait for a response. + /// Transitions to `sendingResponse` or `sentResponse`. + case waitingForResponse + case sendingResponse + case sentResponse(CheckedContinuation) + } + } + + private var state: State = .disconnected + private var lastError: Error? + private var reusableErrorBuffer: ByteBuffer? 
+    private let logger: Logger
+    private let delegate: Delegate
+    private let configuration: LambdaRuntimeClient.Configuration
+
+    /// These are the default headers that must be sent along an invocation
+    let defaultHeaders: HTTPHeaders
+    /// These headers must be sent along an invocation or initialization error report
+    let errorHeaders: HTTPHeaders
+    /// These headers must be sent when streaming a large response
+    let largeResponseHeaders: HTTPHeaders
+    /// These headers must be sent when the handler streams its response
+    let streamingHeaders: HTTPHeaders
+
+    init(
+        delegate: Delegate,
+        logger: Logger,
+        configuration: LambdaRuntimeClient.Configuration
+    ) {
+        self.delegate = delegate
+        self.logger = logger
+        self.configuration = configuration
+        self.defaultHeaders = [
+            "host": "\(self.configuration.ip):\(self.configuration.port)",
+            "user-agent": .userAgent,
+        ]
+        self.errorHeaders = [
+            "host": "\(self.configuration.ip):\(self.configuration.port)",
+            "user-agent": .userAgent,
+            "lambda-runtime-function-error-type": "Unhandled",
+        ]
+        self.largeResponseHeaders = [
+            "host": "\(self.configuration.ip):\(self.configuration.port)",
+            "user-agent": .userAgent,
+            "transfer-encoding": "chunked",
+        ]
+        // https://docs.aws.amazon.com/lambda/latest/dg/runtimes-custom.html#runtimes-custom-response-streaming
+        // These are the headers returned by the Runtime to the Lambda Data plane.
+        // These are not the headers the Lambda Data plane sends to the caller of the Lambda function
+        // The developer of the function can set the caller's headers in the handler code.
+        self.streamingHeaders = [
+            "host": "\(self.configuration.ip):\(self.configuration.port)",
+            "user-agent": .userAgent,
+            "Lambda-Runtime-Function-Response-Mode": "streaming",
+            // these are not used by this runtime client at the moment
+            // FIXME: the error handling should inject these headers in the streamed response to report mid-stream errors
+            "Trailer": "Lambda-Runtime-Function-Error-Type, Lambda-Runtime-Function-Error-Body",
+        ]
+    }
+
+    func nextInvocation(isolation: isolated (any Actor)? = #isolation) async throws -> Invocation {
+        switch self.state {
+        case .connected(let context, .idle):
+            return try await withCheckedThrowingContinuation {
+                (continuation: CheckedContinuation) in
+                self.state = .connected(context, .waitingForNextInvocation(continuation))
+                self.sendNextRequest(context: context)
+            }
+
+        case .connected(_, .sendingResponse),
+            .connected(_, .sentResponse),
+            .connected(_, .waitingForNextInvocation),
+            .connected(_, .waitingForResponse),
+            .closing:
+            fatalError("Invalid state: \(self.state)")
+
+        case .disconnected:
+            throw LambdaRuntimeError(code: .connectionToControlPlaneLost)
+        }
+    }
+
+    func reportError(
+        isolation: isolated (any Actor)? 
= #isolation, + _ error: any Error, + requestID: String + ) async throws { + switch self.state { + case .connected(_, .waitingForNextInvocation): + fatalError("Invalid state: \(self.state)") + + case .connected(let context, .waitingForResponse): + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + self.state = .connected(context, .sentResponse(continuation)) + self.sendReportErrorRequest(requestID: requestID, error: error, context: context) + } + + case .connected(let context, .sendingResponse): + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + self.state = .connected(context, .sentResponse(continuation)) + self.sendResponseStreamingFailure(error: error, context: context) + } + + case .connected(_, .idle), + .connected(_, .sentResponse): + // The final response has already been sent. The only way to report the unhandled error + // now is to log it. Normally this library never logs higher than debug, we make an + // exception here, as there is no other way of reporting the error. + self.logger.error( + "Unhandled error after stream has finished", + metadata: [ + "lambda_request_id": "\(requestID)", + "lambda_error": "\(String(describing: error))", + ] + ) + + case .disconnected: + throw LambdaRuntimeError(code: .connectionToControlPlaneLost) + + case .closing: + throw LambdaRuntimeError(code: .connectionToControlPlaneGoingAway) + } + } + + func writeResponseBodyPart( + isolation: isolated (any Actor)? = #isolation, + _ byteBuffer: ByteBuffer, + requestID: String, + hasCustomHeaders: Bool + ) async throws { + switch self.state { + case .connected(_, .waitingForNextInvocation): + fatalError("Invalid state: \(self.state)") + + case .connected(let context, .waitingForResponse): + self.state = .connected(context, .sendingResponse) + try await self.sendResponseBodyPart( + byteBuffer, + sendHeadWithRequestID: requestID, + context: context, + hasCustomHeaders: hasCustomHeaders + ) + + case .connected(let context, .sendingResponse): + + precondition(!hasCustomHeaders, "Programming error: Custom headers should not be sent in this state") + + try await self.sendResponseBodyPart( + byteBuffer, + sendHeadWithRequestID: nil, + context: context, + hasCustomHeaders: hasCustomHeaders + ) + + case .connected(_, .idle), + .connected(_, .sentResponse): + throw LambdaRuntimeError(code: .writeAfterFinishHasBeenSent) + + case .disconnected: + throw LambdaRuntimeError(code: .connectionToControlPlaneLost) + + case .closing: + throw LambdaRuntimeError(code: .connectionToControlPlaneGoingAway) + } + } + + func finishResponseRequest( + isolation: isolated (any Actor)? 
= #isolation, + finalData: ByteBuffer?, + requestID: String + ) async throws { + switch self.state { + case .connected(_, .idle), + .connected(_, .waitingForNextInvocation): + fatalError("Invalid state: \(self.state)") + + case .connected(let context, .waitingForResponse): + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + self.state = .connected(context, .sentResponse(continuation)) + self.sendResponseFinish(finalData, sendHeadWithRequestID: requestID, context: context) + } + + case .connected(let context, .sendingResponse): + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + self.state = .connected(context, .sentResponse(continuation)) + self.sendResponseFinish(finalData, sendHeadWithRequestID: nil, context: context) + } + + case .connected(_, .sentResponse): + throw LambdaRuntimeError(code: .finishAfterFinishHasBeenSent) + + case .disconnected: + throw LambdaRuntimeError(code: .connectionToControlPlaneLost) + + case .closing: + throw LambdaRuntimeError(code: .connectionToControlPlaneGoingAway) + } + } + + private func sendResponseBodyPart( + isolation: isolated (any Actor)? = #isolation, + _ byteBuffer: ByteBuffer, + sendHeadWithRequestID: String?, + context: ChannelHandlerContext, + hasCustomHeaders: Bool + ) async throws { + + if let requestID = sendHeadWithRequestID { + // TODO: This feels super expensive. We should be able to make this cheaper. requestIDs are fixed length. + let url = Consts.invocationURLPrefix + "/" + requestID + Consts.postResponseURLSuffix + + var headers = self.streamingHeaders + if hasCustomHeaders { + // this header is required by Function URL when the user sends custom status code or headers + headers.add(name: "Content-Type", value: "application/vnd.awslambda.http-integration-response") + } + let httpRequest = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: url, + headers: headers + ) + + context.write(self.wrapOutboundOut(.head(httpRequest)), promise: nil) + } + + let future = context.write(self.wrapOutboundOut(.body(.byteBuffer(byteBuffer)))) + context.flush() + try await future.get() + } + + private func sendResponseFinish( + isolation: isolated (any Actor)? = #isolation, + _ byteBuffer: ByteBuffer?, + sendHeadWithRequestID: String?, + context: ChannelHandlerContext + ) { + if let requestID = sendHeadWithRequestID { + // TODO: This feels quite expensive. We should be able to make this cheaper. requestIDs are fixed length. + let url = "\(Consts.invocationURLPrefix)/\(requestID)\(Consts.postResponseURLSuffix)" + + // If we have less than 6MB, we don't want to use the streaming API. If we have more + // than 6MB, we must use the streaming mode. + var headers: HTTPHeaders! + if byteBuffer?.readableBytes ?? 0 < 6_000_000 { + headers = self.defaultHeaders + headers.add(name: "content-length", value: "\(byteBuffer?.readableBytes ?? 
0)") + } else { + headers = self.largeResponseHeaders + } + let httpRequest = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: url, + headers: headers + ) + + context.write(self.wrapOutboundOut(.head(httpRequest)), promise: nil) + } + + if let byteBuffer { + context.write(self.wrapOutboundOut(.body(.byteBuffer(byteBuffer))), promise: nil) + } + + context.write(self.wrapOutboundOut(.end(nil)), promise: nil) + context.flush() + } + + private func sendNextRequest(context: ChannelHandlerContext) { + let httpRequest = HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: self.nextInvocationPath, + headers: self.defaultHeaders + ) + + context.write(self.wrapOutboundOut(.head(httpRequest)), promise: nil) + context.write(self.wrapOutboundOut(.end(nil)), promise: nil) + context.flush() + } + + private func sendReportErrorRequest(requestID: String, error: any Error, context: ChannelHandlerContext) { + // TODO: This feels quite expensive. We should be able to make this cheaper. requestIDs are fixed length + let url = "\(Consts.invocationURLPrefix)/\(requestID)\(Consts.postErrorURLSuffix)" + + let httpRequest = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: url, + headers: self.errorHeaders + ) + + if self.reusableErrorBuffer == nil { + self.reusableErrorBuffer = context.channel.allocator.buffer(capacity: 1024) + } else { + self.reusableErrorBuffer!.clear() + } + + let errorResponse = ErrorResponse(errorType: Consts.functionError, errorMessage: "\(error)") + // TODO: Write this directly into our ByteBuffer + let bytes = errorResponse.toJSONBytes() + self.reusableErrorBuffer!.writeBytes(bytes) + + context.write(self.wrapOutboundOut(.head(httpRequest)), promise: nil) + context.write(self.wrapOutboundOut(.body(.byteBuffer(self.reusableErrorBuffer!))), promise: nil) + context.write(self.wrapOutboundOut(.end(nil)), promise: nil) + context.flush() + } + + private func sendResponseStreamingFailure(error: any Error, context: ChannelHandlerContext) { + let trailers: HTTPHeaders = [ + "Lambda-Runtime-Function-Error-Type": "Unhandled", + "Lambda-Runtime-Function-Error-Body": "Requires base64", + ] + + context.write(self.wrapOutboundOut(.end(trailers)), promise: nil) + context.flush() + } + + func cancelCurrentRequestAndCloseConnection() { + fatalError("Unimplemented") + } +} + +extension LambdaChannelHandler: ChannelInboundHandler { + typealias OutboundIn = Never + typealias InboundIn = NIOHTTPClientResponseFull + typealias OutboundOut = HTTPClientRequestPart + + func handlerAdded(context: ChannelHandlerContext) { + if context.channel.isActive { + self.state = .connected(context, .idle) + } + } + + func channelActive(context: ChannelHandlerContext) { + switch self.state { + case .disconnected: + self.state = .connected(context, .idle) + case .connected: + break + case .closing: + fatalError("Invalid state: \(self.state)") + } + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let response = unwrapInboundIn(data) + + // handle response content + + switch self.state { + case .connected(let context, .waitingForNextInvocation(let continuation)): + do { + let metadata = try InvocationMetadata(headers: response.head.headers) + self.state = .connected(context, .waitingForResponse) + continuation.resume(returning: Invocation(metadata: metadata, event: response.body ?? 
ByteBuffer()))
+            } catch {
+                self.state = .closing
+
+                self.delegate.connectionWillClose(channel: context.channel)
+                context.close(promise: nil)
+                continuation.resume(
+                    throwing: LambdaRuntimeError(code: .invocationMissingMetadata, underlying: error)
+                )
+            }
+
+        case .connected(let context, .sentResponse(let continuation)):
+            if response.head.status == .accepted {
+                self.state = .connected(context, .idle)
+                continuation.resume()
+            } else {
+                self.state = .connected(context, .idle)
+                continuation.resume(throwing: LambdaRuntimeError(code: .unexpectedStatusCodeForRequest))
+            }
+
+        case .disconnected, .closing, .connected(_, _):
+            break
+        }
+
+        // As defined in RFC 7230 Section 6.3:
+        // HTTP/1.1 defaults to the use of "persistent connections", allowing
+        // multiple requests and responses to be carried over a single
+        // connection. The "close" connection option is used to signal that a
+        // connection will not persist after the current request/response. HTTP
+        // implementations SHOULD support persistent connections.
+        //
+        // That's why we only assume the connection shall be closed if we receive
+        // a "connection = close" header.
+        let serverCloseConnection =
+            response.head.headers["connection"].contains(where: { $0.lowercased() == "close" })
+
+        let closeConnection = serverCloseConnection || response.head.version != .http1_1
+
+        if closeConnection {
+            // If we were succeeding the request promise here directly and closing the connection
+            // after succeeding the promise we may run into a race condition:
+            //
+            // The lambda runtime will ask for the next work item directly after a succeeded post
+            // response request. The desire for the next work item might be faster than the attempt
+            // to close the connection. This will lead to a situation where we try to close the connection
+            // but the next request has already been scheduled on the connection that we want to
+            // close. For this reason we postpone succeeding the promise until the connection has
+            // been closed. This codepath will only be hit in the very, very unlikely event of the
+            // Lambda control plane demanding to close the connection. (It's more or less only
+            // implemented to support http1.1 correctly.) This behavior is ensured with the test
+            // `LambdaTest.testNoKeepAliveServer`.
+            self.state = .closing
+            self.delegate.connectionWillClose(channel: context.channel)
+            context.close(promise: nil)
+        }
+    }
+
+    func errorCaught(context: ChannelHandlerContext, error: Error) {
+        self.logger.trace(
+            "Channel error caught",
+            metadata: [
+                "error": "\(error)"
+            ]
+        )
+        // pending responses will fail with lastError in channelInactive since we are calling context.close
+        self.delegate.connectionErrorHappened(error, channel: context.channel)
+
+        self.lastError = error
+        context.channel.close(promise: nil)
+    }
+
+    func channelInactive(context: ChannelHandlerContext) {
+        // fail any pending responses with last error or assume peer disconnected
+        switch self.state {
+        case .connected(_, let lambdaState):
+            switch lambdaState {
+            case .waitingForNextInvocation(let continuation):
+                continuation.resume(throwing: self.lastError ?? ChannelError.ioOnClosedChannel)
+            case .sentResponse(let continuation):
+                continuation.resume(throwing: self.lastError ?? 
ChannelError.ioOnClosedChannel) + case .idle, .sendingResponse, .waitingForResponse: + break + } + self.state = .disconnected + default: + break + } + + // we don't need to forward channelInactive to the delegate, as the delegate observes the + // closeFuture + context.fireChannelInactive() + } +} diff --git a/Sources/AWSLambdaRuntime/LambdaRuntimeClient.swift b/Sources/AWSLambdaRuntime/LambdaRuntimeClient.swift index 1d90e2c9..50efc614 100644 --- a/Sources/AWSLambdaRuntime/LambdaRuntimeClient.swift +++ b/Sources/AWSLambdaRuntime/LambdaRuntimeClient.swift @@ -405,9 +405,7 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol { } extension LambdaRuntimeClient: LambdaChannelHandlerDelegate { - nonisolated func connectionErrorHappened(_ error: any Error, channel: any Channel) { - - } + nonisolated func connectionErrorHappened(_ error: any Error, channel: any Channel) {} nonisolated func connectionWillClose(channel: any Channel) { self.assumeIsolated { isolated in @@ -439,471 +437,3 @@ extension LambdaRuntimeClient: LambdaChannelHandlerDelegate { } } } - -private protocol LambdaChannelHandlerDelegate { - func connectionWillClose(channel: any Channel) - func connectionErrorHappened(_ error: any Error, channel: any Channel) -} - -private final class LambdaChannelHandler { - let nextInvocationPath = Consts.invocationURLPrefix + Consts.getNextInvocationURLSuffix - - enum State { - case disconnected - case connected(ChannelHandlerContext, LambdaState) - case closing - - enum LambdaState { - /// this is the "normal" state. Transitions to `waitingForNextInvocation` - case idle - /// this is the state while we wait for an invocation. A next call is running. - /// Transitions to `waitingForResponse` - case waitingForNextInvocation(CheckedContinuation) - /// The invocation was forwarded to the handler and we wait for a response. - /// Transitions to `sendingResponse` or `sentResponse`. - case waitingForResponse - case sendingResponse - case sentResponse(CheckedContinuation) - } - } - - private var state: State = .disconnected - private var lastError: Error? - private var reusableErrorBuffer: ByteBuffer? - private let logger: Logger - private let delegate: Delegate - private let configuration: LambdaRuntimeClient.Configuration - - /// These are the default headers that must be sent along an invocation - let defaultHeaders: HTTPHeaders - /// These headers must be sent along an invocation or initialization error report - let errorHeaders: HTTPHeaders - /// These headers must be sent when streaming a large response - let largeResponseHeaders: HTTPHeaders - /// These headers must be sent when the handler streams its response - let streamingHeaders: HTTPHeaders - - init( - delegate: Delegate, - logger: Logger, - configuration: LambdaRuntimeClient.Configuration - ) { - self.delegate = delegate - self.logger = logger - self.configuration = configuration - self.defaultHeaders = [ - "host": "\(self.configuration.ip):\(self.configuration.port)", - "user-agent": .userAgent, - ] - self.errorHeaders = [ - "host": "\(self.configuration.ip):\(self.configuration.port)", - "user-agent": .userAgent, - "lambda-runtime-function-error-type": "Unhandled", - ] - self.largeResponseHeaders = [ - "host": "\(self.configuration.ip):\(self.configuration.port)", - "user-agent": .userAgent, - "transfer-encoding": "chunked", - ] - // https://docs.aws.amazon.com/lambda/latest/dg/runtimes-custom.html#runtimes-custom-response-streaming - // These are the headers returned by the Runtime to the Lambda Data plane. 
- // These are not the headers the Lambda Data plane sends to the caller of the Lambda function - // The developer of the function can set the caller's headers in the handler code. - self.streamingHeaders = [ - "host": "\(self.configuration.ip):\(self.configuration.port)", - "user-agent": .userAgent, - "Lambda-Runtime-Function-Response-Mode": "streaming", - // these are not used by this runtime client at the moment - // FIXME: the eror handling should inject these headers in the streamed response to report mid-stream errors - "Trailer": "Lambda-Runtime-Function-Error-Type, Lambda-Runtime-Function-Error-Body", - ] - } - - func nextInvocation(isolation: isolated (any Actor)? = #isolation) async throws -> Invocation { - switch self.state { - case .connected(let context, .idle): - return try await withCheckedThrowingContinuation { - (continuation: CheckedContinuation) in - self.state = .connected(context, .waitingForNextInvocation(continuation)) - self.sendNextRequest(context: context) - } - - case .connected(_, .sendingResponse), - .connected(_, .sentResponse), - .connected(_, .waitingForNextInvocation), - .connected(_, .waitingForResponse), - .closing: - fatalError("Invalid state: \(self.state)") - - case .disconnected: - throw LambdaRuntimeError(code: .connectionToControlPlaneLost) - } - } - - func reportError( - isolation: isolated (any Actor)? = #isolation, - _ error: any Error, - requestID: String - ) async throws { - switch self.state { - case .connected(_, .waitingForNextInvocation): - fatalError("Invalid state: \(self.state)") - - case .connected(let context, .waitingForResponse): - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - self.state = .connected(context, .sentResponse(continuation)) - self.sendReportErrorRequest(requestID: requestID, error: error, context: context) - } - - case .connected(let context, .sendingResponse): - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - self.state = .connected(context, .sentResponse(continuation)) - self.sendResponseStreamingFailure(error: error, context: context) - } - - case .connected(_, .idle), - .connected(_, .sentResponse): - // The final response has already been sent. The only way to report the unhandled error - // now is to log it. Normally this library never logs higher than debug, we make an - // exception here, as there is no other way of reporting the error. - self.logger.error( - "Unhandled error after stream has finished", - metadata: [ - "lambda_request_id": "\(requestID)", - "lambda_error": "\(String(describing: error))", - ] - ) - - case .disconnected: - throw LambdaRuntimeError(code: .connectionToControlPlaneLost) - - case .closing: - throw LambdaRuntimeError(code: .connectionToControlPlaneGoingAway) - } - } - - func writeResponseBodyPart( - isolation: isolated (any Actor)? 
= #isolation, - _ byteBuffer: ByteBuffer, - requestID: String, - hasCustomHeaders: Bool - ) async throws { - switch self.state { - case .connected(_, .waitingForNextInvocation): - fatalError("Invalid state: \(self.state)") - - case .connected(let context, .waitingForResponse): - self.state = .connected(context, .sendingResponse) - try await self.sendResponseBodyPart( - byteBuffer, - sendHeadWithRequestID: requestID, - context: context, - hasCustomHeaders: hasCustomHeaders - ) - - case .connected(let context, .sendingResponse): - - precondition(!hasCustomHeaders, "Programming error: Custom headers should not be sent in this state") - - try await self.sendResponseBodyPart( - byteBuffer, - sendHeadWithRequestID: nil, - context: context, - hasCustomHeaders: hasCustomHeaders - ) - - case .connected(_, .idle), - .connected(_, .sentResponse): - throw LambdaRuntimeError(code: .writeAfterFinishHasBeenSent) - - case .disconnected: - throw LambdaRuntimeError(code: .connectionToControlPlaneLost) - - case .closing: - throw LambdaRuntimeError(code: .connectionToControlPlaneGoingAway) - } - } - - func finishResponseRequest( - isolation: isolated (any Actor)? = #isolation, - finalData: ByteBuffer?, - requestID: String - ) async throws { - switch self.state { - case .connected(_, .idle), - .connected(_, .waitingForNextInvocation): - fatalError("Invalid state: \(self.state)") - - case .connected(let context, .waitingForResponse): - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - self.state = .connected(context, .sentResponse(continuation)) - self.sendResponseFinish(finalData, sendHeadWithRequestID: requestID, context: context) - } - - case .connected(let context, .sendingResponse): - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - self.state = .connected(context, .sentResponse(continuation)) - self.sendResponseFinish(finalData, sendHeadWithRequestID: nil, context: context) - } - - case .connected(_, .sentResponse): - throw LambdaRuntimeError(code: .finishAfterFinishHasBeenSent) - - case .disconnected: - throw LambdaRuntimeError(code: .connectionToControlPlaneLost) - - case .closing: - throw LambdaRuntimeError(code: .connectionToControlPlaneGoingAway) - } - } - - private func sendResponseBodyPart( - isolation: isolated (any Actor)? = #isolation, - _ byteBuffer: ByteBuffer, - sendHeadWithRequestID: String?, - context: ChannelHandlerContext, - hasCustomHeaders: Bool - ) async throws { - - if let requestID = sendHeadWithRequestID { - // TODO: This feels super expensive. We should be able to make this cheaper. requestIDs are fixed length. - let url = Consts.invocationURLPrefix + "/" + requestID + Consts.postResponseURLSuffix - - var headers = self.streamingHeaders - if hasCustomHeaders { - // this header is required by Function URL when the user sends custom status code or headers - headers.add(name: "Content-Type", value: "application/vnd.awslambda.http-integration-response") - } - let httpRequest = HTTPRequestHead( - version: .http1_1, - method: .POST, - uri: url, - headers: headers - ) - - context.write(self.wrapOutboundOut(.head(httpRequest)), promise: nil) - } - - let future = context.write(self.wrapOutboundOut(.body(.byteBuffer(byteBuffer)))) - context.flush() - try await future.get() - } - - private func sendResponseFinish( - isolation: isolated (any Actor)? 
= #isolation, - _ byteBuffer: ByteBuffer?, - sendHeadWithRequestID: String?, - context: ChannelHandlerContext - ) { - if let requestID = sendHeadWithRequestID { - // TODO: This feels quite expensive. We should be able to make this cheaper. requestIDs are fixed length. - let url = "\(Consts.invocationURLPrefix)/\(requestID)\(Consts.postResponseURLSuffix)" - - // If we have less than 6MB, we don't want to use the streaming API. If we have more - // than 6MB, we must use the streaming mode. - var headers: HTTPHeaders! - if byteBuffer?.readableBytes ?? 0 < 6_000_000 { - headers = self.defaultHeaders - headers.add(name: "content-length", value: "\(byteBuffer?.readableBytes ?? 0)") - } else { - headers = self.largeResponseHeaders - } - let httpRequest = HTTPRequestHead( - version: .http1_1, - method: .POST, - uri: url, - headers: headers - ) - - context.write(self.wrapOutboundOut(.head(httpRequest)), promise: nil) - } - - if let byteBuffer { - context.write(self.wrapOutboundOut(.body(.byteBuffer(byteBuffer))), promise: nil) - } - - context.write(self.wrapOutboundOut(.end(nil)), promise: nil) - context.flush() - } - - private func sendNextRequest(context: ChannelHandlerContext) { - let httpRequest = HTTPRequestHead( - version: .http1_1, - method: .GET, - uri: self.nextInvocationPath, - headers: self.defaultHeaders - ) - - context.write(self.wrapOutboundOut(.head(httpRequest)), promise: nil) - context.write(self.wrapOutboundOut(.end(nil)), promise: nil) - context.flush() - } - - private func sendReportErrorRequest(requestID: String, error: any Error, context: ChannelHandlerContext) { - // TODO: This feels quite expensive. We should be able to make this cheaper. requestIDs are fixed length - let url = "\(Consts.invocationURLPrefix)/\(requestID)\(Consts.postErrorURLSuffix)" - - let httpRequest = HTTPRequestHead( - version: .http1_1, - method: .POST, - uri: url, - headers: self.errorHeaders - ) - - if self.reusableErrorBuffer == nil { - self.reusableErrorBuffer = context.channel.allocator.buffer(capacity: 1024) - } else { - self.reusableErrorBuffer!.clear() - } - - let errorResponse = ErrorResponse(errorType: Consts.functionError, errorMessage: "\(error)") - // TODO: Write this directly into our ByteBuffer - let bytes = errorResponse.toJSONBytes() - self.reusableErrorBuffer!.writeBytes(bytes) - - context.write(self.wrapOutboundOut(.head(httpRequest)), promise: nil) - context.write(self.wrapOutboundOut(.body(.byteBuffer(self.reusableErrorBuffer!))), promise: nil) - context.write(self.wrapOutboundOut(.end(nil)), promise: nil) - context.flush() - } - - private func sendResponseStreamingFailure(error: any Error, context: ChannelHandlerContext) { - let trailers: HTTPHeaders = [ - "Lambda-Runtime-Function-Error-Type": "Unhandled", - "Lambda-Runtime-Function-Error-Body": "Requires base64", - ] - - context.write(self.wrapOutboundOut(.end(trailers)), promise: nil) - context.flush() - } - - func cancelCurrentRequestAndCloseConnection() { - fatalError("Unimplemented") - } -} - -extension LambdaChannelHandler: ChannelInboundHandler { - typealias OutboundIn = Never - typealias InboundIn = NIOHTTPClientResponseFull - typealias OutboundOut = HTTPClientRequestPart - - func handlerAdded(context: ChannelHandlerContext) { - if context.channel.isActive { - self.state = .connected(context, .idle) - } - } - - func channelActive(context: ChannelHandlerContext) { - switch self.state { - case .disconnected: - self.state = .connected(context, .idle) - case .connected: - break - case .closing: - fatalError("Invalid state: 
\(self.state)") - } - } - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let response = unwrapInboundIn(data) - - // handle response content - - switch self.state { - case .connected(let context, .waitingForNextInvocation(let continuation)): - do { - let metadata = try InvocationMetadata(headers: response.head.headers) - self.state = .connected(context, .waitingForResponse) - continuation.resume(returning: Invocation(metadata: metadata, event: response.body ?? ByteBuffer())) - } catch { - self.state = .closing - - self.delegate.connectionWillClose(channel: context.channel) - context.close(promise: nil) - continuation.resume( - throwing: LambdaRuntimeError(code: .invocationMissingMetadata, underlying: error) - ) - } - - case .connected(let context, .sentResponse(let continuation)): - if response.head.status == .accepted { - self.state = .connected(context, .idle) - continuation.resume() - } else { - self.state = .connected(context, .idle) - continuation.resume(throwing: LambdaRuntimeError(code: .unexpectedStatusCodeForRequest)) - } - - case .disconnected, .closing, .connected(_, _): - break - } - - // As defined in RFC 7230 Section 6.3: - // HTTP/1.1 defaults to the use of "persistent connections", allowing - // multiple requests and responses to be carried over a single - // connection. The "close" connection option is used to signal that a - // connection will not persist after the current request/response. HTTP - // implementations SHOULD support persistent connections. - // - // That's why we only assume the connection shall be closed if we receive - // a "connection = close" header. - let serverCloseConnection = - response.head.headers["connection"].contains(where: { $0.lowercased() == "close" }) - - let closeConnection = serverCloseConnection || response.head.version != .http1_1 - - if closeConnection { - // If we were succeeding the request promise here directly and closing the connection - // after succeeding the promise we may run into a race condition: - // - // The lambda runtime will ask for the next work item directly after a succeeded post - // response request. The desire for the next work item might be faster than the attempt - // to close the connection. This will lead to a situation where we try to the connection - // but the next request has already been scheduled on the connection that we want to - // close. For this reason we postpone succeeding the promise until the connection has - // been closed. This codepath will only be hit in the very, very unlikely event of the - // Lambda control plane demanding to close connection. (It's more or less only - // implemented to support http1.1 correctly.) This behavior is ensured with the test - // `LambdaTest.testNoKeepAliveServer`. 
- self.state = .closing - self.delegate.connectionWillClose(channel: context.channel) - context.close(promise: nil) - } - } - - func errorCaught(context: ChannelHandlerContext, error: Error) { - self.logger.trace( - "Channel error caught", - metadata: [ - "error": "\(error)" - ] - ) - // pending responses will fail with lastError in channelInactive since we are calling context.close - self.delegate.connectionErrorHappened(error, channel: context.channel) - - self.lastError = error - context.channel.close(promise: nil) - } - - func channelInactive(context: ChannelHandlerContext) { - // fail any pending responses with last error or assume peer disconnected - switch self.state { - case .connected(_, let lambdaState): - switch lambdaState { - case .waitingForNextInvocation(let continuation): - continuation.resume(throwing: self.lastError ?? ChannelError.ioOnClosedChannel) - case .sentResponse(let continuation): - continuation.resume(throwing: self.lastError ?? ChannelError.ioOnClosedChannel) - case .idle, .sendingResponse, .waitingForResponse: - break - } - self.state = .disconnected - default: - break - } - - // we don't need to forward channelInactive to the delegate, as the delegate observes the - // closeFuture - context.fireChannelInactive() - } -} From 29ae761038aea2b15e0b34bcec7d06ea58a9f558 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Stormacq?= Date: Mon, 1 Sep 2025 12:54:41 +0200 Subject: [PATCH 2/2] update visibility from package to internal --- .../AWSLambdaRuntime/LambdaRuntimeClient+ChannelHandler.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Sources/AWSLambdaRuntime/LambdaRuntimeClient+ChannelHandler.swift b/Sources/AWSLambdaRuntime/LambdaRuntimeClient+ChannelHandler.swift index ca043bf3..e69cb032 100644 --- a/Sources/AWSLambdaRuntime/LambdaRuntimeClient+ChannelHandler.swift +++ b/Sources/AWSLambdaRuntime/LambdaRuntimeClient+ChannelHandler.swift @@ -17,12 +17,12 @@ import NIOCore import NIOHTTP1 import NIOPosix -protocol LambdaChannelHandlerDelegate { +internal protocol LambdaChannelHandlerDelegate { func connectionWillClose(channel: any Channel) func connectionErrorHappened(_ error: any Error, channel: any Channel) } -final class LambdaChannelHandler { +internal final class LambdaChannelHandler { let nextInvocationPath = Consts.invocationURLPrefix + Consts.getNextInvocationURLSuffix enum State {