diff --git a/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift b/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift index 895ea7d8..ccb45ae5 100644 --- a/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift +++ b/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift @@ -166,17 +166,21 @@ internal struct LambdaHTTPServer { // consumed by iterating the group or by exiting the group. Since, we are never consuming // the results of the group we need the group to automatically discard them; otherwise, this // would result in a memory leak over time. - try await withThrowingDiscardingTaskGroup { taskGroup in - try await channel.executeThenClose { inbound in - for try await connectionChannel in inbound { - - taskGroup.addTask { - logger.trace("Handling a new connection") - await server.handleConnection(channel: connectionChannel, logger: logger) - logger.trace("Done handling the connection") + try await withTaskCancellationHandler { + try await withThrowingDiscardingTaskGroup { taskGroup in + try await channel.executeThenClose { inbound in + for try await connectionChannel in inbound { + + taskGroup.addTask { + logger.trace("Handling a new connection") + await server.handleConnection(channel: connectionChannel, logger: logger) + logger.trace("Done handling the connection") + } } } } + } onCancel: { + channel.channel.close(promise: nil) } return .serverReturned(.success(())) } catch { @@ -230,38 +234,42 @@ internal struct LambdaHTTPServer { // Note that this method is non-throwing and we are catching any error. // We do this since we don't want to tear down the whole server when a single connection // encounters an error. - do { - try await channel.executeThenClose { inbound, outbound in - for try await inboundData in inbound { - switch inboundData { - case .head(let head): - requestHead = head - - case .body(let body): - requestBody.setOrWriteImmutableBuffer(body) - - case .end: - precondition(requestHead != nil, "Received .end without .head") - // process the request - let response = try await self.processRequest( - head: requestHead, - body: requestBody, - logger: logger - ) - // send the responses - try await self.sendResponse( - response: response, - outbound: outbound, - logger: logger - ) - - requestHead = nil - requestBody = nil + await withTaskCancellationHandler { + do { + try await channel.executeThenClose { inbound, outbound in + for try await inboundData in inbound { + switch inboundData { + case .head(let head): + requestHead = head + + case .body(let body): + requestBody.setOrWriteImmutableBuffer(body) + + case .end: + precondition(requestHead != nil, "Received .end without .head") + // process the request + let response = try await self.processRequest( + head: requestHead, + body: requestBody, + logger: logger + ) + // send the responses + try await self.sendResponse( + response: response, + outbound: outbound, + logger: logger + ) + + requestHead = nil + requestBody = nil + } } } + } catch { + logger.error("Hit error: \(error)") } - } catch { - logger.error("Hit error: \(error)") + } onCancel: { + channel.channel.close(promise: nil) } } @@ -462,26 +470,38 @@ internal struct LambdaHTTPServer { return nil } - return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - let nextAction = self.lock.withLock { state -> T? in - switch consume state { - case .buffer(var buffer): - if let first = buffer.popFirst() { - state = .buffer(buffer) - return first - } else { - state = .continuation(continuation) - return nil - } + return try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let nextAction = self.lock.withLock { state -> T? in + switch consume state { + case .buffer(var buffer): + if let first = buffer.popFirst() { + state = .buffer(buffer) + return first + } else { + state = .continuation(continuation) + return nil + } - case .continuation: - fatalError("Concurrent invocations to next(). This is illegal.") + case .continuation: + fatalError("Concurrent invocations to next(). This is illegal.") + } } - } - guard let nextAction else { return } + guard let nextAction else { return } - continuation.resume(returning: nextAction) + continuation.resume(returning: nextAction) + } + } onCancel: { + self.lock.withLock { state in + switch consume state { + case .buffer(let buffer): + state = .buffer(buffer) + case .continuation(let continuation): + continuation?.resume(throwing: CancellationError()) + state = .buffer([]) + } + } } } diff --git a/Sources/AWSLambdaRuntime/LambdaRuntime+ServiceLifecycle.swift b/Sources/AWSLambdaRuntime/LambdaRuntime+ServiceLifecycle.swift index 54ecb537..1b05b1c2 100644 --- a/Sources/AWSLambdaRuntime/LambdaRuntime+ServiceLifecycle.swift +++ b/Sources/AWSLambdaRuntime/LambdaRuntime+ServiceLifecycle.swift @@ -15,5 +15,11 @@ #if ServiceLifecycleSupport import ServiceLifecycle -extension LambdaRuntime: Service {} +extension LambdaRuntime: Service { + public func run() async throws { + try await cancelWhenGracefulShutdown { + try await self._run() + } + } +} #endif diff --git a/Sources/AWSLambdaRuntime/LambdaRuntime.swift b/Sources/AWSLambdaRuntime/LambdaRuntime.swift index 5ff0daff..7aba2812 100644 --- a/Sources/AWSLambdaRuntime/LambdaRuntime.swift +++ b/Sources/AWSLambdaRuntime/LambdaRuntime.swift @@ -51,8 +51,15 @@ public final class LambdaRuntime: @unchecked Sendable where Handler: St self.logger.debug("LambdaRuntime initialized") } + #if !ServiceLifecycleSupport @inlinable - public func run() async throws { + internal func run() async throws { + try await _run() + } + #endif + + @inlinable + internal func _run() async throws { let handler = self.handlerMutex.withLockedValue { handler in let result = handler handler = nil diff --git a/Tests/AWSLambdaRuntimeTests/LambdaRuntimeClientTests.swift b/Tests/AWSLambdaRuntimeTests/LambdaRuntimeClientTests.swift index 33ebde3f..62255d66 100644 --- a/Tests/AWSLambdaRuntimeTests/LambdaRuntimeClientTests.swift +++ b/Tests/AWSLambdaRuntimeTests/LambdaRuntimeClientTests.swift @@ -15,6 +15,7 @@ import Logging import NIOCore import NIOPosix +import ServiceLifecycle import Testing import struct Foundation.UUID @@ -139,4 +140,29 @@ struct LambdaRuntimeClientTests { } } } + #if ServiceLifecycleSupport + @Test + func testLambdaRuntimeGracefulShutdown() async throws { + let runtime = LambdaRuntime { + (event: String, context: LambdaContext) in + "Hello \(event)" + } + var logger = Logger(label: "LambdaRuntime") + logger.logLevel = .debug + let serviceGroup = ServiceGroup( + services: [runtime], + gracefulShutdownSignals: [.sigterm, .sigint], + logger: logger + ) + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await serviceGroup.run() + } + // wait a small amount to ensure we are waiting for continuation + try await Task.sleep(for: .milliseconds(100)) + + await serviceGroup.triggerGracefulShutdown() + } + } + #endif }