From 2c8529f70a8e2767176f5034503050dec8d8a010 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Sun, 29 Jun 2025 07:14:10 +0100 Subject: [PATCH 1/2] Add support for local server graceful shutdown --- .../AWSLambdaRuntime/Lambda+LocalServer.swift | 135 +++++++++++------- .../LambdaRuntime+ServiceLifecycle.swift | 8 +- Sources/AWSLambdaRuntime/LambdaRuntime.swift | 9 +- .../LambdaRuntimeClientTests.swift | 26 ++++ 4 files changed, 124 insertions(+), 54 deletions(-) diff --git a/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift b/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift index 895ea7d8..1f11913b 100644 --- a/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift +++ b/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift @@ -166,17 +166,21 @@ internal struct LambdaHTTPServer { // consumed by iterating the group or by exiting the group. Since, we are never consuming // the results of the group we need the group to automatically discard them; otherwise, this // would result in a memory leak over time. - try await withThrowingDiscardingTaskGroup { taskGroup in - try await channel.executeThenClose { inbound in - for try await connectionChannel in inbound { - - taskGroup.addTask { - logger.trace("Handling a new connection") - await server.handleConnection(channel: connectionChannel, logger: logger) - logger.trace("Done handling the connection") + try await withTaskCancellationHandler { + try await withThrowingDiscardingTaskGroup { taskGroup in + try await channel.executeThenClose { inbound in + for try await connectionChannel in inbound { + + taskGroup.addTask { + logger.trace("Handling a new connection") + await server.handleConnection(channel: connectionChannel, logger: logger) + logger.trace("Done handling the connection") + } } } } + } onCancel: { + channel.channel.close(promise: nil) } return .serverReturned(.success(())) } catch { @@ -230,38 +234,42 @@ internal struct LambdaHTTPServer { // Note that this method is non-throwing and we are catching any error. // We do this since we don't want to tear down the whole server when a single connection // encounters an error. - do { - try await channel.executeThenClose { inbound, outbound in - for try await inboundData in inbound { - switch inboundData { - case .head(let head): - requestHead = head - - case .body(let body): - requestBody.setOrWriteImmutableBuffer(body) - - case .end: - precondition(requestHead != nil, "Received .end without .head") - // process the request - let response = try await self.processRequest( - head: requestHead, - body: requestBody, - logger: logger - ) - // send the responses - try await self.sendResponse( - response: response, - outbound: outbound, - logger: logger - ) - - requestHead = nil - requestBody = nil + await withTaskCancellationHandler { + do { + try await channel.executeThenClose { inbound, outbound in + for try await inboundData in inbound { + switch inboundData { + case .head(let head): + requestHead = head + + case .body(let body): + requestBody.setOrWriteImmutableBuffer(body) + + case .end: + precondition(requestHead != nil, "Received .end without .head") + // process the request + let response = try await self.processRequest( + head: requestHead, + body: requestBody, + logger: logger + ) + // send the responses + try await self.sendResponse( + response: response, + outbound: outbound, + logger: logger + ) + + requestHead = nil + requestBody = nil + } } } + } catch { + logger.error("Hit error: \(error)") } - } catch { - logger.error("Hit error: \(error)") + } onCancel: { + channel.channel.close(promise: nil) } } @@ -432,6 +440,7 @@ internal struct LambdaHTTPServer { enum State: ~Copyable { case buffer(Deque) case continuation(CheckedContinuation?) + case cancelled } private let lock = Mutex(.buffer([])) @@ -450,6 +459,10 @@ internal struct LambdaHTTPServer { buffer.append(invocation) state = .buffer(buffer) return nil + + case .cancelled: + state = .cancelled + return nil } } @@ -462,26 +475,44 @@ internal struct LambdaHTTPServer { return nil } - return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - let nextAction = self.lock.withLock { state -> T? in - switch consume state { - case .buffer(var buffer): - if let first = buffer.popFirst() { - state = .buffer(buffer) - return first - } else { - state = .continuation(continuation) + return try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let nextAction = self.lock.withLock { state -> T? in + switch consume state { + case .buffer(var buffer): + if let first = buffer.popFirst() { + state = .buffer(buffer) + return first + } else { + state = .continuation(continuation) + return nil + } + + case .continuation: + fatalError("Concurrent invocations to next(). This is illegal.") + + case .cancelled: + state = .cancelled return nil } - - case .continuation: - fatalError("Concurrent invocations to next(). This is illegal.") } - } - guard let nextAction else { return } + guard let nextAction else { return } - continuation.resume(returning: nextAction) + continuation.resume(returning: nextAction) + } + } onCancel: { + self.lock.withLock { state in + switch consume state { + case .buffer(let buffer): + state = .buffer(buffer) + case .continuation(let continuation): + continuation?.resume(throwing: CancellationError()) + state = .continuation(continuation) + case .cancelled: + state = .cancelled + } + } } } diff --git a/Sources/AWSLambdaRuntime/LambdaRuntime+ServiceLifecycle.swift b/Sources/AWSLambdaRuntime/LambdaRuntime+ServiceLifecycle.swift index 54ecb537..1b05b1c2 100644 --- a/Sources/AWSLambdaRuntime/LambdaRuntime+ServiceLifecycle.swift +++ b/Sources/AWSLambdaRuntime/LambdaRuntime+ServiceLifecycle.swift @@ -15,5 +15,11 @@ #if ServiceLifecycleSupport import ServiceLifecycle -extension LambdaRuntime: Service {} +extension LambdaRuntime: Service { + public func run() async throws { + try await cancelWhenGracefulShutdown { + try await self._run() + } + } +} #endif diff --git a/Sources/AWSLambdaRuntime/LambdaRuntime.swift b/Sources/AWSLambdaRuntime/LambdaRuntime.swift index 5ff0daff..7aba2812 100644 --- a/Sources/AWSLambdaRuntime/LambdaRuntime.swift +++ b/Sources/AWSLambdaRuntime/LambdaRuntime.swift @@ -51,8 +51,15 @@ public final class LambdaRuntime: @unchecked Sendable where Handler: St self.logger.debug("LambdaRuntime initialized") } + #if !ServiceLifecycleSupport @inlinable - public func run() async throws { + internal func run() async throws { + try await _run() + } + #endif + + @inlinable + internal func _run() async throws { let handler = self.handlerMutex.withLockedValue { handler in let result = handler handler = nil diff --git a/Tests/AWSLambdaRuntimeTests/LambdaRuntimeClientTests.swift b/Tests/AWSLambdaRuntimeTests/LambdaRuntimeClientTests.swift index 33ebde3f..62255d66 100644 --- a/Tests/AWSLambdaRuntimeTests/LambdaRuntimeClientTests.swift +++ b/Tests/AWSLambdaRuntimeTests/LambdaRuntimeClientTests.swift @@ -15,6 +15,7 @@ import Logging import NIOCore import NIOPosix +import ServiceLifecycle import Testing import struct Foundation.UUID @@ -139,4 +140,29 @@ struct LambdaRuntimeClientTests { } } } + #if ServiceLifecycleSupport + @Test + func testLambdaRuntimeGracefulShutdown() async throws { + let runtime = LambdaRuntime { + (event: String, context: LambdaContext) in + "Hello \(event)" + } + var logger = Logger(label: "LambdaRuntime") + logger.logLevel = .debug + let serviceGroup = ServiceGroup( + services: [runtime], + gracefulShutdownSignals: [.sigterm, .sigint], + logger: logger + ) + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await serviceGroup.run() + } + // wait a small amount to ensure we are waiting for continuation + try await Task.sleep(for: .milliseconds(100)) + + await serviceGroup.triggerGracefulShutdown() + } + } + #endif } From 34b8544da54e69e7f6b3d035adddae9058b32a44 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Sun, 29 Jun 2025 07:37:33 +0100 Subject: [PATCH 2/2] Got rid of cancellation state as we can use .buffer([]) --- Sources/AWSLambdaRuntime/Lambda+LocalServer.swift | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift b/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift index 1f11913b..ccb45ae5 100644 --- a/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift +++ b/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift @@ -440,7 +440,6 @@ internal struct LambdaHTTPServer { enum State: ~Copyable { case buffer(Deque) case continuation(CheckedContinuation?) - case cancelled } private let lock = Mutex(.buffer([])) @@ -459,10 +458,6 @@ internal struct LambdaHTTPServer { buffer.append(invocation) state = .buffer(buffer) return nil - - case .cancelled: - state = .cancelled - return nil } } @@ -490,10 +485,6 @@ internal struct LambdaHTTPServer { case .continuation: fatalError("Concurrent invocations to next(). This is illegal.") - - case .cancelled: - state = .cancelled - return nil } } @@ -508,9 +499,7 @@ internal struct LambdaHTTPServer { state = .buffer(buffer) case .continuation(let continuation): continuation?.resume(throwing: CancellationError()) - state = .continuation(continuation) - case .cancelled: - state = .cancelled + state = .buffer([]) } } }