Skip to content

Commit 156e6c7

Browse files
committed
make sure connectionToControlPlaneLost error is triggered by the test
1 parent 482e09e commit 156e6c7

File tree

2 files changed

+78
-32
lines changed

2 files changed

+78
-32
lines changed

Tests/AWSLambdaRuntimeTests/LambdaRuntimeClientTests.swift

Lines changed: 56 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ struct LambdaRuntimeClientTests {
4242
.success((self.requestId, self.event))
4343
}
4444

45-
func processResponse(requestId: String, response: String?) -> Result<Void, ProcessResponseError> {
45+
func processResponse(requestId: String, response: String?) -> Result<String?, ProcessResponseError> {
4646
#expect(self.requestId == requestId)
4747
#expect(self.event == response)
48-
return .success(())
48+
return .success(nil)
4949
}
5050

5151
func processError(requestId: String, error: ErrorResponse) -> Result<Void, ProcessErrorError> {
@@ -102,9 +102,9 @@ struct LambdaRuntimeClientTests {
102102
.success((self.requestId, self.event))
103103
}
104104

105-
func processResponse(requestId: String, response: String?) -> Result<Void, ProcessResponseError> {
105+
func processResponse(requestId: String, response: String?) -> Result<String?, ProcessResponseError> {
106106
#expect(self.requestId == requestId)
107-
return .success(())
107+
return .success(nil)
108108
}
109109

110110
mutating func captureHeaders(_ headers: HTTPHeaders) {
@@ -197,10 +197,10 @@ struct LambdaRuntimeClientTests {
197197
.success((self.requestId, self.event))
198198
}
199199

200-
func processResponse(requestId: String, response: String?) -> Result<Void, ProcessResponseError> {
200+
func processResponse(requestId: String, response: String?) -> Result<String?, ProcessResponseError> {
201201
#expect(self.requestId == requestId)
202202
#expect(self.event == response)
203-
return .success(())
203+
return .success(nil)
204204
}
205205

206206
func processError(requestId: String, error: ErrorResponse) -> Result<Void, ProcessErrorError> {
@@ -239,31 +239,56 @@ struct LambdaRuntimeClientTests {
239239
}
240240
}
241241

242-
@Test("Server closing the connection when waiting for next invocation throws an error")
243-
func testChannelCloseFutureWithWaitingForNextInvocation() async throws {
244-
struct DisconnectBehavior: LambdaServerBehavior {
245-
func getInvocation() -> GetInvocationResult {
246-
// Return "disconnect" to trigger server closing the connection
247-
.success(("disconnect", "0"))
248-
}
242+
struct DisconnectAfterSendingResponseBehavior: LambdaServerBehavior {
243+
func getInvocation() -> GetInvocationResult {
244+
.success((UUID().uuidString, "hello"))
245+
}
249246

250-
func processResponse(requestId: String, response: String?) -> Result<Void, ProcessResponseError> {
251-
Issue.record("should not process response")
252-
return .failure(.internalServerError)
253-
}
247+
func processResponse(requestId: String, response: String?) -> Result<String?, ProcessResponseError> {
248+
// Return "disconnect" to trigger server closing the connection
249+
// after having accepted a response
250+
.success("delayed-disconnect")
251+
}
254252

255-
func processError(requestId: String, error: ErrorResponse) -> Result<Void, ProcessErrorError> {
256-
Issue.record("should not report error")
257-
return .failure(.internalServerError)
258-
}
253+
func processError(requestId: String, error: ErrorResponse) -> Result<Void, ProcessErrorError> {
254+
Issue.record("should not report error")
255+
return .failure(.internalServerError)
256+
}
259257

260-
func processInitError(error: ErrorResponse) -> Result<Void, ProcessErrorError> {
261-
Issue.record("should not report init error")
262-
return .failure(.internalServerError)
263-
}
258+
func processInitError(error: ErrorResponse) -> Result<Void, ProcessErrorError> {
259+
Issue.record("should not report init error")
260+
return .failure(.internalServerError)
261+
}
262+
}
263+
264+
struct DisconnectBehavior: LambdaServerBehavior {
265+
func getInvocation() -> GetInvocationResult {
266+
// Return "disconnect" to trigger server closing the connection
267+
.success(("disconnect", "0"))
264268
}
265269

266-
try await withMockServer(behaviour: DisconnectBehavior()) { port in
270+
func processResponse(requestId: String, response: String?) -> Result<String?, ProcessResponseError> {
271+
Issue.record("should not process response")
272+
return .failure(.internalServerError)
273+
}
274+
275+
func processError(requestId: String, error: ErrorResponse) -> Result<Void, ProcessErrorError> {
276+
Issue.record("should not report error")
277+
return .failure(.internalServerError)
278+
}
279+
280+
func processInitError(error: ErrorResponse) -> Result<Void, ProcessErrorError> {
281+
Issue.record("should not report init error")
282+
return .failure(.internalServerError)
283+
}
284+
}
285+
286+
@Test(
287+
"Server closing the connection when waiting for next invocation throws an error",
288+
arguments: [DisconnectAfterSendingResponseBehavior(), DisconnectBehavior()] as [any LambdaServerBehavior]
289+
)
290+
func testChannelCloseFutureWithWaitingForNextInvocation(behavior: LambdaServerBehavior) async throws {
291+
try await withMockServer(behaviour: behavior) { port in
267292
let configuration = LambdaRuntimeClient.Configuration(ip: "127.0.0.1", port: port)
268293

269294
try await LambdaRuntimeClient.withRuntimeClient(
@@ -273,7 +298,12 @@ struct LambdaRuntimeClientTests {
273298
) { runtimeClient in
274299
do {
275300
// This should fail when server closes connection
301+
let (_, writer) = try await runtimeClient.nextInvocation()
302+
let response = ByteBuffer(string: "hello")
303+
try await writer.writeAndFinish(response)
304+
276305
let _ = try await runtimeClient.nextInvocation()
306+
277307
Issue.record("Expected connection error but got successful invocation")
278308

279309
} catch let error as LambdaRuntimeError {

Tests/AWSLambdaRuntimeTests/MockLambdaServer.swift

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ final class HTTPHandler: ChannelInboundHandler {
160160
var responseStatus: HTTPResponseStatus
161161
var responseBody: String?
162162
var responseHeaders: [(String, String)]?
163+
var disconnectAfterSend = false
163164

164165
// Handle post-init-error first to avoid matching the less specific post-error suffix.
165166
if request.head.uri.hasSuffix(Consts.postInitErrorURL) {
@@ -202,8 +203,11 @@ final class HTTPHandler: ChannelInboundHandler {
202203
behavior.captureHeaders(request.head.headers)
203204

204205
switch behavior.processResponse(requestId: String(requestId), response: requestBody) {
205-
case .success:
206+
case .success(let next):
206207
responseStatus = .accepted
208+
if next == "delayed-disconnect" {
209+
disconnectAfterSend = true
210+
}
207211
case .failure(let error):
208212
responseStatus = .init(statusCode: error.rawValue)
209213
}
@@ -223,14 +227,21 @@ final class HTTPHandler: ChannelInboundHandler {
223227
} else {
224228
responseStatus = .notFound
225229
}
226-
self.writeResponse(context: context, status: responseStatus, headers: responseHeaders, body: responseBody)
230+
self.writeResponse(
231+
context: context,
232+
status: responseStatus,
233+
headers: responseHeaders,
234+
body: responseBody,
235+
closeConnection: disconnectAfterSend
236+
)
227237
}
228238

229239
func writeResponse(
230240
context: ChannelHandlerContext,
231241
status: HTTPResponseStatus,
232242
headers: [(String, String)]? = nil,
233-
body: String? = nil
243+
body: String? = nil,
244+
closeConnection: Bool = false
234245
) {
235246
var headers = HTTPHeaders(headers ?? [])
236247
headers.add(name: "Content-Length", value: "\(body?.utf8.count ?? 0)")
@@ -253,14 +264,19 @@ final class HTTPHandler: ChannelInboundHandler {
253264
}
254265

255266
let loopBoundContext = NIOLoopBound(context, eventLoop: context.eventLoop)
256-
257267
let keepAlive = self.keepAlive
258268
context.writeAndFlush(wrapOutboundOut(.end(nil))).whenComplete { result in
269+
let context = loopBoundContext.value
270+
if closeConnection {
271+
context.close(promise: nil)
272+
return
273+
}
274+
259275
if case .failure(let error) = result {
260276
logger.error("write error \(error)")
261277
}
278+
262279
if !keepAlive {
263-
let context = loopBoundContext.value
264280
context.close().whenFailure { error in
265281
logger.error("close error \(error)")
266282
}
@@ -271,7 +287,7 @@ final class HTTPHandler: ChannelInboundHandler {
271287

272288
protocol LambdaServerBehavior: Sendable {
273289
func getInvocation() -> GetInvocationResult
274-
func processResponse(requestId: String, response: String?) -> Result<Void, ProcessResponseError>
290+
func processResponse(requestId: String, response: String?) -> Result<String?, ProcessResponseError>
275291
func processError(requestId: String, error: ErrorResponse) -> Result<Void, ProcessErrorError>
276292
func processInitError(error: ErrorResponse) -> Result<Void, ProcessErrorError>
277293

0 commit comments

Comments
 (0)