Skip to content

Add support for local server graceful shutdown #519

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 73 additions & 53 deletions Sources/AWSLambdaRuntime/Lambda+LocalServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -166,17 +166,21 @@ internal struct LambdaHTTPServer {
// consumed by iterating the group or by exiting the group. Since, we are never consuming
// the results of the group we need the group to automatically discard them; otherwise, this
// would result in a memory leak over time.
try await withThrowingDiscardingTaskGroup { taskGroup in
try await channel.executeThenClose { inbound in
for try await connectionChannel in inbound {

taskGroup.addTask {
logger.trace("Handling a new connection")
await server.handleConnection(channel: connectionChannel, logger: logger)
logger.trace("Done handling the connection")
try await withTaskCancellationHandler {
try await withThrowingDiscardingTaskGroup { taskGroup in
try await channel.executeThenClose { inbound in
for try await connectionChannel in inbound {

taskGroup.addTask {
logger.trace("Handling a new connection")
await server.handleConnection(channel: connectionChannel, logger: logger)
logger.trace("Done handling the connection")
}
}
}
}
} onCancel: {
channel.channel.close(promise: nil)
}
return .serverReturned(.success(()))
} catch {
Expand Down Expand Up @@ -230,38 +234,42 @@ internal struct LambdaHTTPServer {
// Note that this method is non-throwing and we are catching any error.
// We do this since we don't want to tear down the whole server when a single connection
// encounters an error.
do {
try await channel.executeThenClose { inbound, outbound in
for try await inboundData in inbound {
switch inboundData {
case .head(let head):
requestHead = head

case .body(let body):
requestBody.setOrWriteImmutableBuffer(body)

case .end:
precondition(requestHead != nil, "Received .end without .head")
// process the request
let response = try await self.processRequest(
head: requestHead,
body: requestBody,
logger: logger
)
// send the responses
try await self.sendResponse(
response: response,
outbound: outbound,
logger: logger
)

requestHead = nil
requestBody = nil
await withTaskCancellationHandler {
do {
try await channel.executeThenClose { inbound, outbound in
for try await inboundData in inbound {
switch inboundData {
case .head(let head):
requestHead = head

case .body(let body):
requestBody.setOrWriteImmutableBuffer(body)

case .end:
precondition(requestHead != nil, "Received .end without .head")
// process the request
let response = try await self.processRequest(
head: requestHead,
body: requestBody,
logger: logger
)
// send the responses
try await self.sendResponse(
response: response,
outbound: outbound,
logger: logger
)

requestHead = nil
requestBody = nil
}
}
}
} catch {
logger.error("Hit error: \(error)")
}
} catch {
logger.error("Hit error: \(error)")
} onCancel: {
channel.channel.close(promise: nil)
}
}

Expand Down Expand Up @@ -462,26 +470,38 @@ internal struct LambdaHTTPServer {
return nil
}

return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<T, any Error>) in
let nextAction = self.lock.withLock { state -> T? in
switch consume state {
case .buffer(var buffer):
if let first = buffer.popFirst() {
state = .buffer(buffer)
return first
} else {
state = .continuation(continuation)
return nil
}
return try await withTaskCancellationHandler {
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<T, any Error>) in
let nextAction = self.lock.withLock { state -> T? in
switch consume state {
case .buffer(var buffer):
if let first = buffer.popFirst() {
state = .buffer(buffer)
return first
} else {
state = .continuation(continuation)
return nil
}

case .continuation:
fatalError("Concurrent invocations to next(). This is illegal.")
case .continuation:
fatalError("Concurrent invocations to next(). This is illegal.")
}
}
}

guard let nextAction else { return }
guard let nextAction else { return }

continuation.resume(returning: nextAction)
continuation.resume(returning: nextAction)
}
} onCancel: {
self.lock.withLock { state in
switch consume state {
case .buffer(let buffer):
state = .buffer(buffer)
case .continuation(let continuation):
continuation?.resume(throwing: CancellationError())
state = .buffer([])
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,11 @@
#if ServiceLifecycleSupport
import ServiceLifecycle

extension LambdaRuntime: Service {}
extension LambdaRuntime: Service {
public func run() async throws {
try await cancelWhenGracefulShutdown {
try await self._run()
}
}
}
#endif
9 changes: 8 additions & 1 deletion Sources/AWSLambdaRuntime/LambdaRuntime.swift
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,15 @@ public final class LambdaRuntime<Handler>: @unchecked Sendable where Handler: St
self.logger.debug("LambdaRuntime initialized")
}

#if !ServiceLifecycleSupport
@inlinable
public func run() async throws {
internal func run() async throws {
try await _run()
}
#endif

@inlinable
internal func _run() async throws {
let handler = self.handlerMutex.withLockedValue { handler in
let result = handler
handler = nil
Expand Down
26 changes: 26 additions & 0 deletions Tests/AWSLambdaRuntimeTests/LambdaRuntimeClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import Logging
import NIOCore
import NIOPosix
import ServiceLifecycle
import Testing

import struct Foundation.UUID
Expand Down Expand Up @@ -139,4 +140,29 @@ struct LambdaRuntimeClientTests {
}
}
}
#if ServiceLifecycleSupport
@Test
func testLambdaRuntimeGracefulShutdown() async throws {
let runtime = LambdaRuntime {
(event: String, context: LambdaContext) in
"Hello \(event)"
}
var logger = Logger(label: "LambdaRuntime")
logger.logLevel = .debug
let serviceGroup = ServiceGroup(
services: [runtime],
gracefulShutdownSignals: [.sigterm, .sigint],
logger: logger
)
try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
try await serviceGroup.run()
}
// wait a small amount to ensure we are waiting for continuation
try await Task.sleep(for: .milliseconds(100))

await serviceGroup.triggerGracefulShutdown()
}
}
#endif
}