diff --git a/Tests/AWSLambdaRuntimeTests/LambdaRuntime+ServiceLifeCycle.swift b/Tests/AWSLambdaRuntimeTests/LambdaRuntime+ServiceLifeCycle.swift new file mode 100644 index 00000000..7103ea8d --- /dev/null +++ b/Tests/AWSLambdaRuntimeTests/LambdaRuntime+ServiceLifeCycle.swift @@ -0,0 +1,46 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftAWSLambdaRuntime open source project +// +// Copyright (c) 2024 Apple Inc. and the SwiftAWSLambdaRuntime project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if ServiceLifecycleSupport +@testable import AWSLambdaRuntime +import ServiceLifecycle +import Testing +import Logging + +@Suite +struct LambdaRuntimeServiceLifecycleTests { + @Test + func testLambdaRuntimeGracefulShutdown() async throws { + let runtime = LambdaRuntime { + (event: String, context: LambdaContext) in + "Hello \(event)" + } + + let serviceGroup = ServiceGroup( + services: [runtime], + gracefulShutdownSignals: [.sigterm, .sigint], + logger: Logger(label: "TestLambdaRuntimeGracefulShutdown") + ) + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await serviceGroup.run() + } + // wait a small amount to ensure we are waiting for continuation + try await Task.sleep(for: .milliseconds(100)) + + await serviceGroup.triggerGracefulShutdown() + } + } +} +#endif diff --git a/Tests/AWSLambdaRuntimeTests/LambdaRuntimeClientTests.swift b/Tests/AWSLambdaRuntimeTests/LambdaRuntimeClientTests.swift index cc901461..33ebde3f 100644 --- a/Tests/AWSLambdaRuntimeTests/LambdaRuntimeClientTests.swift +++ b/Tests/AWSLambdaRuntimeTests/LambdaRuntimeClientTests.swift @@ -15,7 +15,6 @@ import Logging import NIOCore import NIOPosix -import ServiceLifecycle import Testing import struct Foundation.UUID @@ -140,28 +139,4 @@ struct LambdaRuntimeClientTests { } } } - #if ServiceLifecycleSupport - @Test - func testLambdaRuntimeGracefulShutdown() async throws { - let runtime = LambdaRuntime { - (event: String, context: LambdaContext) in - "Hello \(event)" - } - - let serviceGroup = ServiceGroup( - services: [runtime], - gracefulShutdownSignals: [.sigterm, .sigint], - logger: Logger(label: "TestLambdaRuntimeGracefulShutdown") - ) - try await withThrowingTaskGroup(of: Void.self) { group in - group.addTask { - try await serviceGroup.run() - } - // wait a small amount to ensure we are waiting for continuation - try await Task.sleep(for: .milliseconds(100)) - - await serviceGroup.triggerGracefulShutdown() - } - } - #endif } diff --git a/Tests/AWSLambdaRuntimeTests/LambdaRuntimeTests.swift b/Tests/AWSLambdaRuntimeTests/LambdaRuntimeTests.swift index cd519d76..430fe70a 100644 --- a/Tests/AWSLambdaRuntimeTests/LambdaRuntimeTests.swift +++ b/Tests/AWSLambdaRuntimeTests/LambdaRuntimeTests.swift @@ -75,6 +75,53 @@ struct LambdaRuntimeTests { taskGroup.cancelAll() } } + @Test("run() must be cancellable") + func testLambdaRuntimeCancellable() async throws { + + let logger = Logger(label: "LambdaRuntimeTests.RuntimeCancellable") + // create a runtime + let runtime = LambdaRuntime( + handler: MockHandler(), + eventLoop: Lambda.defaultEventLoop, + logger: logger + ) + + // Running the runtime with structured concurrency + // Task group returns when all tasks are completed. + // Even cancelled tasks must cooperatlivly complete + await #expect(throws: Never.self) { + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + logger.trace("--- launching runtime ----") + try await runtime.run() + } + + // Add a timeout task to the group + taskGroup.addTask { + logger.trace("--- launching timeout task ----") + try await Task.sleep(for: .seconds(5)) + if Task.isCancelled { return } + logger.trace("--- throwing timeout error ----") + throw TestError.timeout // Fail the test if the timeout triggers + } + + do { + // Wait for the runtime to start + logger.trace("--- waiting for runtime to start ----") + try await Task.sleep(for: .seconds(1)) + + // Cancel all tasks, this should not throw an error + // and should allow the runtime to complete gracefully + logger.trace("--- cancel all tasks ----") + taskGroup.cancelAll() // Cancel all tasks + } catch { + logger.error("--- catch an error: \(error)") + throw error // Propagate the error to fail the test + } + } + } + + } } struct MockHandler: StreamingLambdaHandler { @@ -86,3 +133,15 @@ struct MockHandler: StreamingLambdaHandler { } } + +// Define a custom error for timeout +enum TestError: Error, CustomStringConvertible { + case timeout + + var description: String { + switch self { + case .timeout: + return "Test timed out waiting for the task to complete." + } + } +}