Skip to content

Commit d1f15ca

Browse files
committed
simplify code by passing hasCustomHeaders as function arg and not a state
1 parent 90501a0 commit d1f15ca

File tree

7 files changed

+53
-82
lines changed

7 files changed

+53
-82
lines changed

Sources/AWSLambdaRuntime/LambdaHandlers.swift

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,23 +49,14 @@ public protocol StreamingLambdaHandler: _Lambda_SendableMetatype {
4949
public protocol LambdaResponseStreamWriter {
5050
/// Write a response part into the stream. Bytes written are streamed continually.
5151
/// - Parameter buffer: The buffer to write.
52-
func write(_ buffer: ByteBuffer) async throws
52+
func write(_ buffer: ByteBuffer, hasCustomHeaders: Bool) async throws
5353

5454
/// End the response stream and the underlying HTTP response.
5555
func finish() async throws
5656

5757
/// Write a response part into the stream and then end the stream as well as the underlying HTTP response.
5858
/// - Parameter buffer: The buffer to write.
5959
func writeAndFinish(_ buffer: ByteBuffer) async throws
60-
61-
/// Write a response part into the stream.
62-
// In the context of streaming Lambda, this is used to allow the user
63-
// to send custom headers or statusCode.
64-
/// - Note: user should use the writeStatusAndHeaders(:StreamingLambdaStatusAndHeadersResponse)
65-
// function to write the status code and headers
66-
/// - Parameter buffer: The buffer corresponding to the status code and headers to write.
67-
func writeCustomHeader(_ buffer: NIOCore.ByteBuffer) async throws
68-
6960
}
7061

7162
/// This handler protocol is intended to serve the most common use-cases.

Sources/AWSLambdaRuntime/LambdaResponseStreamWriter+Headers.swift

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,14 @@ extension LambdaResponseStreamWriter {
7575
buffer.writeBytes([0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
7676

7777
// Write the JSON data and the separator
78-
try await writeCustomHeader(buffer)
78+
try await self.write(buffer, hasCustomHeaders: true)
79+
}
80+
81+
/// Write a response part into the stream. Bytes written are streamed continually.
82+
/// - Parameter buffer: The buffer to write.
83+
public func write(_ buffer: ByteBuffer) async throws {
84+
// Write the buffer to the response stream
85+
try await self.write(buffer, hasCustomHeaders: false)
7986
}
8087
}
8188

Sources/AWSLambdaRuntime/LambdaRuntimeClient.swift

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@ import Synchronization
2020

2121
@usableFromInline
2222
final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
23-
@usableFromInline
24-
var _hasStreamingCustomHeaders = false
25-
2623
@usableFromInline
2724
nonisolated let unownedExecutor: UnownedSerialExecutor
2825

@@ -47,13 +44,8 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
4744
}
4845

4946
@usableFromInline
50-
func writeCustomHeader(_ buffer: NIOCore.ByteBuffer) async throws {
51-
try await self.runtimeClient.writeCustomHeader(buffer)
52-
}
53-
54-
@usableFromInline
55-
func write(_ buffer: NIOCore.ByteBuffer) async throws {
56-
try await self.runtimeClient.write(buffer)
47+
func write(_ buffer: NIOCore.ByteBuffer, hasCustomHeaders: Bool = false) async throws {
48+
try await self.runtimeClient.write(buffer, hasCustomHeaders: hasCustomHeaders)
5749
}
5850

5951
@usableFromInline
@@ -197,11 +189,7 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
197189
}
198190
}
199191

200-
private func writeCustomHeader(_ buffer: NIOCore.ByteBuffer) async throws {
201-
_hasStreamingCustomHeaders = true
202-
try await self.write(buffer)
203-
}
204-
private func write(_ buffer: NIOCore.ByteBuffer) async throws {
192+
private func write(_ buffer: NIOCore.ByteBuffer, hasCustomHeaders: Bool = false) async throws {
205193
switch self.lambdaState {
206194
case .idle, .sentResponse:
207195
throw LambdaRuntimeError(code: .writeAfterFinishHasBeenSent)
@@ -218,12 +206,15 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
218206
guard case .sendingResponse(requestID) = self.lambdaState else {
219207
fatalError("Invalid state: \(self.lambdaState)")
220208
}
221-
return try await handler.writeResponseBodyPart(buffer, requestID: requestID)
209+
return try await handler.writeResponseBodyPart(
210+
buffer,
211+
requestID: requestID,
212+
hasCustomHeaders: hasCustomHeaders
213+
)
222214
}
223215
}
224216

225217
private func writeAndFinish(_ buffer: NIOCore.ByteBuffer?) async throws {
226-
_hasStreamingCustomHeaders = false
227218
switch self.lambdaState {
228219
case .idle, .sentResponse:
229220
throw LambdaRuntimeError(code: .finishAfterFinishHasBeenSent)
@@ -444,16 +435,11 @@ extension LambdaRuntimeClient: LambdaChannelHandlerDelegate {
444435
}
445436
}
446437
}
447-
448-
func hasStreamingCustomHeaders(isolation: isolated (any Actor)? = #isolation) async -> Bool {
449-
await self._hasStreamingCustomHeaders
450-
}
451438
}
452439

453440
private protocol LambdaChannelHandlerDelegate {
454441
func connectionWillClose(channel: any Channel)
455442
func connectionErrorHappened(_ error: any Error, channel: any Channel)
456-
func hasStreamingCustomHeaders(isolation: isolated (any Actor)?) async -> Bool
457443
}
458444

459445
private final class LambdaChannelHandler<Delegate: LambdaChannelHandlerDelegate> {
@@ -596,18 +582,29 @@ private final class LambdaChannelHandler<Delegate: LambdaChannelHandlerDelegate>
596582
func writeResponseBodyPart(
597583
isolation: isolated (any Actor)? = #isolation,
598584
_ byteBuffer: ByteBuffer,
599-
requestID: String
585+
requestID: String,
586+
hasCustomHeaders: Bool
600587
) async throws {
601588
switch self.state {
602589
case .connected(_, .waitingForNextInvocation):
603590
fatalError("Invalid state: \(self.state)")
604591

605592
case .connected(let context, .waitingForResponse):
606593
self.state = .connected(context, .sendingResponse)
607-
try await self.sendResponseBodyPart(byteBuffer, sendHeadWithRequestID: requestID, context: context)
594+
try await self.sendResponseBodyPart(
595+
byteBuffer,
596+
sendHeadWithRequestID: requestID,
597+
context: context,
598+
hasCustomHeaders: hasCustomHeaders
599+
)
608600

609601
case .connected(let context, .sendingResponse):
610-
try await self.sendResponseBodyPart(byteBuffer, sendHeadWithRequestID: nil, context: context)
602+
try await self.sendResponseBodyPart(
603+
byteBuffer,
604+
sendHeadWithRequestID: nil,
605+
context: context,
606+
hasCustomHeaders: hasCustomHeaders
607+
)
611608

612609
case .connected(_, .idle),
613610
.connected(_, .sentResponse):
@@ -658,15 +655,16 @@ private final class LambdaChannelHandler<Delegate: LambdaChannelHandlerDelegate>
658655
isolation: isolated (any Actor)? = #isolation,
659656
_ byteBuffer: ByteBuffer,
660657
sendHeadWithRequestID: String?,
661-
context: ChannelHandlerContext
658+
context: ChannelHandlerContext,
659+
hasCustomHeaders: Bool
662660
) async throws {
663661

664662
if let requestID = sendHeadWithRequestID {
665663
// TODO: This feels super expensive. We should be able to make this cheaper. requestIDs are fixed length.
666664
let url = Consts.invocationURLPrefix + "/" + requestID + Consts.postResponseURLSuffix
667665

668666
var headers = self.streamingHeaders
669-
if await self.delegate.hasStreamingCustomHeaders(isolation: #isolation) {
667+
if hasCustomHeaders {
670668
// this header is required by Function URL when the user sends custom status code or headers
671669
headers.add(name: "Content-Type", value: "application/vnd.awslambda.http-integration-response")
672670
}
@@ -764,7 +762,6 @@ private final class LambdaChannelHandler<Delegate: LambdaChannelHandlerDelegate>
764762
}
765763

766764
private func sendResponseStreamingFailure(error: any Error, context: ChannelHandlerContext) {
767-
// TODO: Use base64 here
768765
let trailers: HTTPHeaders = [
769766
"Lambda-Runtime-Function-Error-Type": "Unhandled",
770767
"Lambda-Runtime-Function-Error-Body": "Requires base64",

Sources/AWSLambdaRuntime/LambdaRuntimeClientProtocol.swift

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,10 @@ import NIOCore
1616

1717
@usableFromInline
1818
package protocol LambdaRuntimeClientResponseStreamWriter: LambdaResponseStreamWriter {
19-
func write(_ buffer: ByteBuffer) async throws
19+
func write(_ buffer: ByteBuffer, hasCustomHeaders: Bool) async throws
2020
func finish() async throws
2121
func writeAndFinish(_ buffer: ByteBuffer) async throws
2222
func reportError(_ error: any Error) async throws
23-
func writeCustomHeader(_ buffer: NIOCore.ByteBuffer) async throws
2423
}
2524

2625
@usableFromInline

Tests/AWSLambdaRuntimeTests/Lambda+CodableTests.swift

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,18 +89,12 @@ struct JSONTests {
8989
self._buffer = buffer
9090
}
9191

92-
func write(_ buffer: ByteBuffer) async throws {
92+
func write(_ buffer: ByteBuffer, hasCustomHeaders: Bool = false) async throws {
9393
fatalError("Unexpected call")
9494
}
9595

9696
func finish() async throws {
9797
fatalError("Unexpected call")
9898
}
99-
100-
func writeCustomHeader(_ buffer: NIOCore.ByteBuffer) async throws {
101-
// This is a mock, so we don't actually write custom headers.
102-
// In a real implementation, this would handle writing custom headers.
103-
fatalError("Unexpected call to writeCustomHeader")
104-
}
10599
}
106100
}

Tests/AWSLambdaRuntimeTests/LambdaResponseStreamWriter+HeadersTests.swift

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -537,11 +537,12 @@ final class MockLambdaResponseStreamWriter: LambdaResponseStreamWriter {
537537
let nullBytes: [UInt8] = [0, 0, 0, 0, 0, 0, 0, 0]
538538
buffer.writeBytes(nullBytes)
539539

540-
try await self.writeCustomHeader(buffer)
540+
try await self.write(buffer, hasCustomHeaders: true)
541541
}
542542

543-
func write(_ buffer: ByteBuffer) async throws {
543+
func write(_ buffer: ByteBuffer, hasCustomHeaders: Bool = false) async throws {
544544
writtenBuffers.append(buffer)
545+
self.hasCustomHeaders = hasCustomHeaders
545546
}
546547

547548
func finish() async throws {
@@ -552,11 +553,6 @@ final class MockLambdaResponseStreamWriter: LambdaResponseStreamWriter {
552553
writtenBuffers.append(buffer)
553554
isFinished = true
554555
}
555-
556-
func writeCustomHeader(_ buffer: NIOCore.ByteBuffer) async throws {
557-
hasCustomHeaders = true
558-
try await self.write(buffer)
559-
}
560556
}
561557

562558
// MARK: - Error Handling Mock Implementations
@@ -579,11 +575,12 @@ final class FailingMockLambdaResponseStreamWriter: LambdaResponseStreamWriter {
579575
) async throws {
580576
var buffer = ByteBuffer()
581577
buffer.writeString("{\"statusCode\":200}")
582-
try await writeCustomHeader(buffer)
578+
try await write(buffer, hasCustomHeaders: true)
583579
}
584580

585-
func write(_ buffer: ByteBuffer) async throws {
581+
func write(_ buffer: ByteBuffer, hasCustomHeaders: Bool = false) async throws {
586582
writeCallCount += 1
583+
self.hasCustomHeaders = hasCustomHeaders
587584

588585
if writeCallCount == failOnWriteCall {
589586
throw TestWriteError()
@@ -601,10 +598,6 @@ final class FailingMockLambdaResponseStreamWriter: LambdaResponseStreamWriter {
601598
try await finish()
602599
}
603600

604-
func writeCustomHeader(_ buffer: NIOCore.ByteBuffer) async throws {
605-
hasCustomHeaders = true
606-
try await write(buffer)
607-
}
608601
}
609602

610603
// MARK: - Test Error Types
@@ -693,11 +686,12 @@ final class TrackingLambdaResponseStreamWriter: LambdaResponseStreamWriter {
693686
) async throws {
694687
var buffer = ByteBuffer()
695688
buffer.writeString("{\"statusCode\":200}")
696-
try await writeCustomHeader(buffer)
689+
try await write(buffer, hasCustomHeaders: true)
697690
}
698691

699-
func write(_ buffer: ByteBuffer) async throws {
692+
func write(_ buffer: ByteBuffer, hasCustomHeaders: Bool = false) async throws {
700693
writeCallCount += 1
694+
self.hasCustomHeaders = hasCustomHeaders
701695
writtenBuffers.append(buffer)
702696
}
703697

@@ -712,10 +706,6 @@ final class TrackingLambdaResponseStreamWriter: LambdaResponseStreamWriter {
712706
isFinished = true
713707
}
714708

715-
func writeCustomHeader(_ buffer: NIOCore.ByteBuffer) async throws {
716-
hasCustomHeaders = true
717-
try await write(buffer)
718-
}
719709
}
720710

721711
/// Mock implementation with custom behavior for integration testing
@@ -732,12 +722,13 @@ final class CustomBehaviorLambdaResponseStreamWriter: LambdaResponseStreamWriter
732722
customBehaviorTriggered = true
733723
var buffer = ByteBuffer()
734724
buffer.writeString("{\"statusCode\":200}")
735-
try await writeCustomHeader(buffer)
725+
try await write(buffer, hasCustomHeaders: true)
736726
}
737727

738-
func write(_ buffer: ByteBuffer) async throws {
728+
func write(_ buffer: ByteBuffer, hasCustomHeaders: Bool = false) async throws {
739729
// Trigger custom behavior on any write
740730
customBehaviorTriggered = true
731+
self.hasCustomHeaders = hasCustomHeaders
741732
writtenBuffers.append(buffer)
742733
}
743734

@@ -750,9 +741,4 @@ final class CustomBehaviorLambdaResponseStreamWriter: LambdaResponseStreamWriter
750741
writtenBuffers.append(buffer)
751742
isFinished = true
752743
}
753-
754-
func writeCustomHeader(_ buffer: NIOCore.ByteBuffer) async throws {
755-
hasCustomHeaders = true
756-
try await write(buffer)
757-
}
758744
}

Tests/AWSLambdaRuntimeTests/MockLambdaClient.swift

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ struct MockLambdaWriter: LambdaRuntimeClientResponseStreamWriter {
2929
self.underlying = underlying
3030
}
3131

32-
func write(_ buffer: ByteBuffer) async throws {
33-
try await self.underlying.write(buffer)
32+
func write(_ buffer: ByteBuffer, hasCustomHeaders: Bool = false) async throws {
33+
try await self.underlying.write(buffer, hasCustomHeaders: hasCustomHeaders)
3434
}
3535

3636
func finish() async throws {
@@ -45,9 +45,6 @@ struct MockLambdaWriter: LambdaRuntimeClientResponseStreamWriter {
4545
func reportError(_ error: any Error) async throws {
4646
await self.underlying.reportError(error)
4747
}
48-
49-
func writeCustomHeader(_ buffer: NIOCore.ByteBuffer) async throws {
50-
}
5148
}
5249

5350
enum LambdaError: Error, Equatable {
@@ -158,7 +155,7 @@ final actor MockLambdaClient: LambdaRuntimeClientProtocol {
158155
}
159156
}
160157

161-
mutating func writeResult(buffer: ByteBuffer) -> ResultAction {
158+
mutating func writeResult(buffer: ByteBuffer, hasCustomHeaders: Bool = false) -> ResultAction {
162159
switch self.state {
163160
case .handlerIsProcessing(var accumulatedResponse, let eventProcessedHandler):
164161
accumulatedResponse.append(buffer)
@@ -279,8 +276,8 @@ final actor MockLambdaClient: LambdaRuntimeClientProtocol {
279276
}
280277
}
281278

282-
func write(_ buffer: ByteBuffer) async throws {
283-
switch self.stateMachine.writeResult(buffer: buffer) {
279+
func write(_ buffer: ByteBuffer, hasCustomHeaders: Bool = false) async throws {
280+
switch self.stateMachine.writeResult(buffer: buffer, hasCustomHeaders: hasCustomHeaders) {
284281
case .readyForMore:
285282
break
286283
case .fail(let error):

0 commit comments

Comments
 (0)