From 6c7eb17a6c9457cfa69c22bf952268f5e98c82ef Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Tue, 2 May 2023 10:29:46 -0700 Subject: [PATCH] Simplify event stream message signer configuration This PR creates a `DeferredSigner` implementation that allows for the event stream message signer to be wired up by the signing implementation later in the request lifecycle rather than by adding an event stream signer method to the config. Refactoring this brings the middleware client implementation closer to how the orchestrator implementation will work, which unblocks the work required to make event streams work in the orchestrator. --- .../aws-sig-auth/src/event_stream.rs | 98 ++++++------- .../aws-sig-auth/src/middleware.rs | 72 ++++++---- aws/rust-runtime/aws-sig-auth/src/signer.rs | 5 +- .../smithy/rustsdk/AwsPresigningDecorator.kt | 4 +- .../smithy/rustsdk/SigV4SigningDecorator.kt | 39 ++---- ...onTest.kt => SigV4SigningDecoratorTest.kt} | 2 +- .../client/smithy/RustClientCodegenPlugin.kt | 2 - .../NoOpEventStreamSigningDecorator.kt | 66 --------- .../config/EventStreamSigningConfig.kt | 46 ------- .../protocols/HttpBoundProtocolGenerator.kt | 35 ++++- .../client/smithy/ClientCodegenVisitorTest.kt | 2 - .../HttpVersionListGeneratorTest.kt | 85 ------------ .../protocol/ProtocolTestGeneratorTest.kt | 2 +- .../generators/protocol/ProtocolGenerator.kt | 4 +- .../HttpBoundProtocolPayloadGenerator.kt | 86 +++++------- .../ServerHttpBoundProtocolGenerator.kt | 32 ++++- .../aws-smithy-eventstream/src/frame.rs | 130 ++++++++++++++++++ 17 files changed, 330 insertions(+), 380 deletions(-) rename aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/{SigV4SigningCustomizationTest.kt => SigV4SigningDecoratorTest.kt} (96%) delete mode 100644 codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/NoOpEventStreamSigningDecorator.kt delete mode 100644 codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/EventStreamSigningConfig.kt diff --git a/aws/rust-runtime/aws-sig-auth/src/event_stream.rs b/aws/rust-runtime/aws-sig-auth/src/event_stream.rs index 74b8c65dd42..6a57677c6ea 100644 --- a/aws/rust-runtime/aws-sig-auth/src/event_stream.rs +++ b/aws/rust-runtime/aws-sig-auth/src/event_stream.rs @@ -3,12 +3,10 @@ * SPDX-License-Identifier: Apache-2.0 */ -use crate::middleware::Signature; use aws_credential_types::Credentials; use aws_sigv4::event_stream::{sign_empty_message, sign_message}; use aws_sigv4::SigningParams; use aws_smithy_eventstream::frame::{Message, SignMessage, SignMessageError}; -use aws_smithy_http::property_bag::{PropertyBag, SharedPropertyBag}; use aws_types::region::SigningRegion; use aws_types::SigningService; use std::time::SystemTime; @@ -16,73 +14,59 @@ use std::time::SystemTime; /// Event Stream SigV4 signing implementation. #[derive(Debug)] pub struct SigV4Signer { - properties: SharedPropertyBag, - last_signature: Option, + last_signature: String, + credentials: Credentials, + signing_region: SigningRegion, + signing_service: SigningService, + time: Option, } impl SigV4Signer { - pub fn new(properties: SharedPropertyBag) -> Self { + pub fn new( + last_signature: String, + credentials: Credentials, + signing_region: SigningRegion, + signing_service: SigningService, + time: Option, + ) -> Self { Self { - properties, - last_signature: None, + last_signature, + credentials, + signing_region, + signing_service, + time, } } - fn signing_params(properties: &PropertyBag) -> SigningParams<()> { - // Every single one of these values would have been retrieved during the initial request, - // so we can safely assume they all exist in the property bag at this point. - let credentials = properties.get::().unwrap(); - let region = properties.get::().unwrap(); - let signing_service = properties.get::().unwrap(); - let time = properties - .get::() - .copied() - .unwrap_or_else(SystemTime::now); + fn signing_params(&self) -> SigningParams<()> { let mut builder = SigningParams::builder() - .access_key(credentials.access_key_id()) - .secret_key(credentials.secret_access_key()) - .region(region.as_ref()) - .service_name(signing_service.as_ref()) - .time(time) + .access_key(self.credentials.access_key_id()) + .secret_key(self.credentials.secret_access_key()) + .region(self.signing_region.as_ref()) + .service_name(self.signing_service.as_ref()) + .time(self.time.unwrap_or_else(SystemTime::now)) .settings(()); - builder.set_security_token(credentials.session_token()); + builder.set_security_token(self.credentials.session_token()); builder.build().unwrap() } } impl SignMessage for SigV4Signer { fn sign(&mut self, message: Message) -> Result { - let properties = self.properties.acquire(); - if self.last_signature.is_none() { - // The Signature property should exist in the property bag for all Event Stream requests. - self.last_signature = Some( - properties - .get::() - .expect("property bag contains initial Signature") - .as_ref() - .into(), - ) - } - let (signed_message, signature) = { - let params = Self::signing_params(&properties); - sign_message(&message, self.last_signature.as_ref().unwrap(), ¶ms).into_parts() + let params = self.signing_params(); + sign_message(&message, &self.last_signature, ¶ms).into_parts() }; - self.last_signature = Some(signature); + self.last_signature = signature; Ok(signed_message) } fn sign_empty(&mut self) -> Option> { - let properties = self.properties.acquire(); - if self.last_signature.is_none() { - // The Signature property should exist in the property bag for all Event Stream requests. - self.last_signature = Some(properties.get::().unwrap().as_ref().into()) - } let (signed_message, signature) = { - let params = Self::signing_params(&properties); - sign_empty_message(self.last_signature.as_ref().unwrap(), ¶ms).into_parts() + let params = self.signing_params(); + sign_empty_message(&self.last_signature, ¶ms).into_parts() }; - self.last_signature = Some(signature); + self.last_signature = signature; Some(Ok(signed_message)) } } @@ -90,27 +74,27 @@ impl SignMessage for SigV4Signer { #[cfg(test)] mod tests { use crate::event_stream::SigV4Signer; - use crate::middleware::Signature; use aws_credential_types::Credentials; use aws_smithy_eventstream::frame::{HeaderValue, Message, SignMessage}; - use aws_smithy_http::property_bag::PropertyBag; use aws_types::region::Region; use aws_types::region::SigningRegion; use aws_types::SigningService; use std::time::{Duration, UNIX_EPOCH}; + fn check_send_sync(value: T) -> T { + value + } + #[test] fn sign_message() { let region = Region::new("us-east-1"); - let mut properties = PropertyBag::new(); - properties.insert(region.clone()); - properties.insert(UNIX_EPOCH + Duration::new(1611160427, 0)); - properties.insert(SigningService::from_static("transcribe")); - properties.insert(Credentials::for_tests()); - properties.insert(SigningRegion::from(region)); - properties.insert(Signature::new("initial-signature".into())); - - let mut signer = SigV4Signer::new(properties.into()); + let mut signer = check_send_sync(SigV4Signer::new( + "initial-signature".into(), + Credentials::for_tests(), + SigningRegion::from(region), + SigningService::from_static("transcribe"), + Some(UNIX_EPOCH + Duration::new(1611160427, 0)), + )); let mut signatures = Vec::new(); for _ in 0..5 { let signed = signer diff --git a/aws/rust-runtime/aws-sig-auth/src/middleware.rs b/aws/rust-runtime/aws-sig-auth/src/middleware.rs index d7ec53454c1..15d74857439 100644 --- a/aws/rust-runtime/aws-sig-auth/src/middleware.rs +++ b/aws/rust-runtime/aws-sig-auth/src/middleware.rs @@ -20,21 +20,10 @@ use crate::signer::{ OperationSigningConfig, RequestConfig, SigV4Signer, SigningError, SigningRequirements, }; -/// Container for the request signature for use in the property bag. -#[non_exhaustive] -pub struct Signature(String); - -impl Signature { - pub fn new(signature: String) -> Self { - Self(signature) - } -} - -impl AsRef for Signature { - fn as_ref(&self) -> &str { - &self.0 - } -} +#[cfg(feature = "sign-eventstream")] +use crate::event_stream::SigV4Signer as EventStreamSigV4Signer; +#[cfg(feature = "sign-eventstream")] +use aws_smithy_eventstream::frame::DeferredSignerSender; /// Middleware stage to sign requests with SigV4 /// @@ -177,11 +166,26 @@ impl MapRequest for SigV4SigningStage { SigningRequirements::Required => signing_config(config)?, }; - let signature = self + let _signature = self .signer .sign(operation_config, &request_config, &creds, &mut req) .map_err(SigningStageErrorKind::SigningFailure)?; - config.insert(signature); + + // If this is an event stream operation, set up the event stream signer + #[cfg(feature = "sign-eventstream")] + if let Some(signer_sender) = config.get::() { + let time_override = config.get::().copied(); + signer_sender + .send(Box::new(EventStreamSigV4Signer::new( + _signature, + creds, + request_config.region.clone(), + request_config.service.clone(), + time_override, + )) as _) + .expect("failed to send deferred signer"); + } + Ok(req) }) } @@ -202,13 +206,17 @@ mod test { use aws_types::region::{Region, SigningRegion}; use aws_types::SigningService; - use crate::middleware::{ - SigV4SigningStage, Signature, SigningStageError, SigningStageErrorKind, - }; + use crate::middleware::{SigV4SigningStage, SigningStageError, SigningStageErrorKind}; use crate::signer::{OperationSigningConfig, SigV4Signer}; + #[cfg(feature = "sign-eventstream")] #[test] - fn places_signature_in_property_bag() { + fn sends_event_stream_signer_for_event_stream_operations() { + use crate::event_stream::SigV4Signer as EventStreamSigV4Signer; + use aws_smithy_eventstream::frame::{DeferredSigner, SignMessage}; + use std::time::SystemTime; + + let (mut deferred_signer, deferred_signer_sender) = DeferredSigner::new(); let req = http::Request::builder() .uri("https://test-service.test-region.amazonaws.com/") .body(SdkBody::from("")) @@ -217,21 +225,31 @@ mod test { let req = operation::Request::new(req) .augment(|req, properties| { properties.insert(region.clone()); - properties.insert(UNIX_EPOCH + Duration::new(1611160427, 0)); + properties.insert::(UNIX_EPOCH + Duration::new(1611160427, 0)); properties.insert(SigningService::from_static("kinesis")); properties.insert(OperationSigningConfig::default_config()); properties.insert(Credentials::for_tests()); - properties.insert(SigningRegion::from(region)); + properties.insert(SigningRegion::from(region.clone())); + properties.insert(deferred_signer_sender); Result::<_, Infallible>::Ok(req) }) .expect("succeeds"); let signer = SigV4SigningStage::new(SigV4Signer::new()); - let req = signer.apply(req).unwrap(); + let _ = signer.apply(req).unwrap(); + + let mut signer_for_comparison = EventStreamSigV4Signer::new( + // This is the expected SigV4 signature for the HTTP request above + "abac477b4afabf5651079e7b9a0aa6a1a3e356a7418a81d974cdae9d4c8e5441".into(), + Credentials::for_tests(), + SigningRegion::from(region), + SigningService::from_static("kinesis"), + Some(UNIX_EPOCH + Duration::new(1611160427, 0)), + ); - let property_bag = req.properties(); - let signature = property_bag.get::(); - assert!(signature.is_some()); + let expected_signed_empty = signer_for_comparison.sign_empty().unwrap().unwrap(); + let actual_signed_empty = deferred_signer.sign_empty().unwrap().unwrap(); + assert_eq!(expected_signed_empty, actual_signed_empty); } // check that the endpoint middleware followed by signing middleware produce the expected result diff --git a/aws/rust-runtime/aws-sig-auth/src/signer.rs b/aws/rust-runtime/aws-sig-auth/src/signer.rs index a1d36c97cad..706ab39d85c 100644 --- a/aws/rust-runtime/aws-sig-auth/src/signer.rs +++ b/aws/rust-runtime/aws-sig-auth/src/signer.rs @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -use crate::middleware::Signature; use aws_credential_types::Credentials; use aws_sigv4::http_request::{ sign, PayloadChecksumKind, PercentEncodingMode, SessionTokenMode, SignableRequest, @@ -191,7 +190,7 @@ impl SigV4Signer { request_config: &RequestConfig<'_>, credentials: &Credentials, request: &mut http::Request, - ) -> Result { + ) -> Result { let settings = Self::settings(operation_config); let signing_params = Self::signing_params(settings, credentials, request_config); @@ -223,7 +222,7 @@ impl SigV4Signer { signing_instructions.apply_to_request(request); - Ok(Signature::new(signature)) + Ok(signature) } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt index 20d280ddf9e..144d3764a0a 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt @@ -21,6 +21,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegen import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientSection import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.MakeOperationGenerator +import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientHttpBoundProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.docs @@ -34,7 +35,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection -import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.util.cloneOperation import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait @@ -173,7 +173,7 @@ class AwsInputPresignedMethod( MakeOperationGenerator( codegenContext, protocol, - HttpBoundProtocolPayloadGenerator(codegenContext, protocol), + ClientHttpBoundProtocolPayloadGenerator(codegenContext, protocol), // Prefixed with underscore to avoid colliding with modeled functions functionName = makeOperationFn, public = false, diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt index 81400d65975..3c40b518ebc 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt @@ -16,7 +16,7 @@ import software.amazon.smithy.model.traits.OptionalAuthTrait import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization -import software.amazon.smithy.rust.codegen.client.smithy.generators.config.EventStreamSigningConfig +import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate @@ -77,17 +77,17 @@ class SigV4SigningDecorator : ClientCodegenDecorator { } class SigV4SigningConfig( - runtimeConfig: RuntimeConfig, + private val runtimeConfig: RuntimeConfig, private val serviceHasEventStream: Boolean, private val sigV4Trait: SigV4Trait, -) : EventStreamSigningConfig(runtimeConfig) { - private val codegenScope = arrayOf( - "SigV4Signer" to AwsRuntimeType.awsSigAuthEventStream(runtimeConfig).resolve("event_stream::SigV4Signer"), - ) - - override fun configImplSection(): Writable { - return writable { - rustTemplate( +) : ConfigCustomization() { + override fun section(section: ServiceConfig): Writable = writable { + if (section is ServiceConfig.ConfigImpl) { + if (serviceHasEventStream) { + // enable the aws-sig-auth `sign-eventstream` feature + addDependency(AwsRuntimeType.awsSigAuthEventStream(runtimeConfig).toSymbol()) + } + rust( """ /// The signature version 4 service signing name to use in the credential scope when signing requests. /// @@ -97,24 +97,7 @@ class SigV4SigningConfig( ${sigV4Trait.name.dq()} } """, - *codegenScope, ) - if (serviceHasEventStream) { - rustTemplate( - "#{signerFn:W}", - "signerFn" to - renderEventStreamSignerFn { propertiesName -> - writable { - rustTemplate( - """ - #{SigV4Signer}::new($propertiesName) - """, - *codegenScope, - ) - } - }, - ) - } } } } @@ -209,5 +192,3 @@ class SigV4SigningFeature( } } } - -fun RuntimeConfig.sigAuth() = awsRuntimeCrate("aws-sig-auth") diff --git a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4SigningCustomizationTest.kt b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecoratorTest.kt similarity index 96% rename from aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4SigningCustomizationTest.kt rename to aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecoratorTest.kt index 9de7c91f654..71bd5eaf6c9 100644 --- a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4SigningCustomizationTest.kt +++ b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecoratorTest.kt @@ -12,7 +12,7 @@ import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.unitTest -internal class SigV4SigningCustomizationTest { +internal class SigV4SigningDecoratorTest { @Test fun `generates a valid config`() { val project = stubConfigProject( diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt index faa3a01c5df..08ba90d1401 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt @@ -15,7 +15,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.customizations.HttpAuth import software.amazon.smithy.rust.codegen.client.smithy.customizations.HttpConnectorConfigDecorator import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator import software.amazon.smithy.rust.codegen.client.smithy.customize.CombinedClientCodegenDecorator -import software.amazon.smithy.rust.codegen.client.smithy.customize.NoOpEventStreamSigningDecorator import software.amazon.smithy.rust.codegen.client.smithy.customize.RequiredCustomizations import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointParamsDecorator import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointsDecorator @@ -62,7 +61,6 @@ class RustClientCodegenPlugin : ClientDecoratableBuildPlugin() { FluentClientDecorator(), EndpointsDecorator(), EndpointParamsDecorator(), - NoOpEventStreamSigningDecorator(), ApiKeyAuthDecorator(), HttpAuthDecorator(), HttpConnectorConfigDecorator(), diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/NoOpEventStreamSigningDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/NoOpEventStreamSigningDecorator.kt deleted file mode 100644 index 7d924ee75aa..00000000000 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/NoOpEventStreamSigningDecorator.kt +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.client.smithy.customize - -import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext -import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization -import software.amazon.smithy.rust.codegen.client.smithy.generators.config.EventStreamSigningConfig -import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.util.hasEventStreamOperations - -/** - * The NoOpEventStreamSigningDecorator: - * - adds a `new_event_stream_signer()` method to `config` to create an Event Stream NoOp signer - */ -open class NoOpEventStreamSigningDecorator : ClientCodegenDecorator { - override val name: String = "NoOpEventStreamSigning" - override val order: Byte = Byte.MIN_VALUE - - private fun applies(codegenContext: CodegenContext, baseCustomizations: List): Boolean = - codegenContext.serviceShape.hasEventStreamOperations(codegenContext.model) && - // and if there is no other `EventStreamSigningConfig`, apply this one - !baseCustomizations.any { it is EventStreamSigningConfig } - - override fun configCustomizations( - codegenContext: ClientCodegenContext, - baseCustomizations: List, - ): List { - if (!applies(codegenContext, baseCustomizations)) { - return baseCustomizations - } - return baseCustomizations + NoOpEventStreamSigningConfig( - codegenContext.serviceShape.hasEventStreamOperations(codegenContext.model), - codegenContext.runtimeConfig, - ) - } -} - -class NoOpEventStreamSigningConfig( - private val serviceHasEventStream: Boolean, - runtimeConfig: RuntimeConfig, -) : EventStreamSigningConfig(runtimeConfig) { - - private val codegenScope = arrayOf( - "NoOpSigner" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::NoOpSigner"), - ) - - override fun configImplSection() = renderEventStreamSignerFn { - writable { - if (serviceHasEventStream) { - rustTemplate( - """ - #{NoOpSigner}{} - """, - *codegenScope, - ) - } - } - } -} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/EventStreamSigningConfig.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/EventStreamSigningConfig.kt deleted file mode 100644 index 35da7d63b0e..00000000000 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/EventStreamSigningConfig.kt +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.client.smithy.generators.config - -import software.amazon.smithy.rust.codegen.core.rustlang.Writable -import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType - -open class EventStreamSigningConfig( - runtimeConfig: RuntimeConfig, -) : ConfigCustomization() { - private val codegenScope = arrayOf( - "SharedPropertyBag" to RuntimeType.smithyHttp(runtimeConfig).resolve("property_bag::SharedPropertyBag"), - "SignMessage" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::SignMessage"), - ) - - override fun section(section: ServiceConfig): Writable { - return when (section) { - is ServiceConfig.ConfigImpl -> configImplSection() - else -> emptySection - } - } - - open fun configImplSection(): Writable = emptySection - - fun renderEventStreamSignerFn(signerInstantiator: (String) -> Writable): Writable = writable { - rustTemplate( - """ - /// Creates a new Event Stream `SignMessage` implementor. - pub fn new_event_stream_signer( - &self, - _properties: #{SharedPropertyBag} - ) -> impl #{SignMessage} { - #{signer:W} - } - """, - *codegenScope, - "signer" to signerInstantiator("_properties"), - ) - } -} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt index e6f3e6e1f7b..166e32ca05a 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt @@ -13,6 +13,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.Cli import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.MakeOperationGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolParserGenerator import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate @@ -22,6 +23,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations +import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol @@ -29,11 +31,11 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctio import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember import software.amazon.smithy.rust.codegen.core.util.outputShape -// TODO(enableNewSmithyRuntime): Delete this class when cleaning up `enableNewSmithyRuntime` +// TODO(enableNewSmithyRuntime): Delete this class when cleaning up `enableNewSmithyRuntime` (replace with ClientProtocolGenerator) class HttpBoundProtocolGenerator( codegenContext: ClientCodegenContext, protocol: Protocol, - bodyGenerator: ProtocolPayloadGenerator = HttpBoundProtocolPayloadGenerator(codegenContext, protocol), + bodyGenerator: ProtocolPayloadGenerator = ClientHttpBoundProtocolPayloadGenerator(codegenContext, protocol), ) : ClientProtocolGenerator( codegenContext, protocol, @@ -48,6 +50,35 @@ class HttpBoundProtocolGenerator( HttpBoundProtocolTraitImplGenerator(codegenContext, protocol), ) +class ClientHttpBoundProtocolPayloadGenerator( + codegenContext: ClientCodegenContext, + protocol: Protocol, +) : ProtocolPayloadGenerator by HttpBoundProtocolPayloadGenerator( + codegenContext, protocol, HttpMessageType.REQUEST, + renderEventStreamBody = { writer, params -> + writer.rustTemplate( + """ + { + let error_marshaller = #{errorMarshallerConstructorFn}(); + let marshaller = #{marshallerConstructorFn}(); + let (signer, signer_sender) = #{DeferredSigner}::new(); + properties.acquire_mut().insert(signer_sender); + let adapter: #{aws_smithy_http}::event_stream::MessageStreamAdapter<_, _> = + ${params.outerName}.${params.memberName}.into_body_stream(marshaller, error_marshaller, signer); + let body: #{SdkBody} = #{hyper}::Body::wrap_stream(adapter).into(); + body + } + """, + "hyper" to CargoDependency.HyperWithStream.toType(), + "SdkBody" to RuntimeType.sdkBody(codegenContext.runtimeConfig), + "aws_smithy_http" to RuntimeType.smithyHttp(codegenContext.runtimeConfig), + "DeferredSigner" to RuntimeType.smithyEventStream(codegenContext.runtimeConfig).resolve("frame::DeferredSigner"), + "marshallerConstructorFn" to params.marshallerConstructorFn, + "errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn, + ) + }, +) + // TODO(enableNewSmithyRuntime): Delete this class when cleaning up `enableNewSmithyRuntime` open class HttpBoundProtocolTraitImplGenerator( codegenContext: ClientCodegenContext, diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitorTest.kt index 2cb21423fd4..f9cf52375b6 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitorTest.kt @@ -10,7 +10,6 @@ import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.client.smithy.customizations.ClientCustomizations import software.amazon.smithy.rust.codegen.client.smithy.customize.CombinedClientCodegenDecorator -import software.amazon.smithy.rust.codegen.client.smithy.customize.NoOpEventStreamSigningDecorator import software.amazon.smithy.rust.codegen.client.smithy.customize.RequiredCustomizations import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientDecorator import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel @@ -52,7 +51,6 @@ class ClientCodegenVisitorTest { ClientCustomizations(), RequiredCustomizations(), FluentClientDecorator(), - NoOpEventStreamSigningDecorator(), ) val visitor = ClientCodegenVisitor(ctx, codegenDecorator) val baselineModel = visitor.baselineTransform(model) diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpVersionListGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpVersionListGeneratorTest.kt index 3b5fc70a6f1..73633a603ea 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpVersionListGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpVersionListGeneratorTest.kt @@ -6,19 +6,9 @@ package software.amazon.smithy.rust.codegen.client.smithy.customizations import org.junit.jupiter.api.Test -import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext -import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator -import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization -import software.amazon.smithy.rust.codegen.client.smithy.generators.config.EventStreamSigningConfig -import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust -import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.integrationTest @@ -172,7 +162,6 @@ internal class HttpVersionListGeneratorTest { clientIntegrationTest( model, IntegrationTestParams(addModuleToEventStreamAllowList = true), - additionalDecorators = listOf(FakeSigningDecorator()), ) { clientCodegenContext, rustCrate -> val moduleName = clientCodegenContext.moduleUseName() rustCrate.integrationTest("validate_eventstream_http") { @@ -196,77 +185,3 @@ internal class HttpVersionListGeneratorTest { } } } - -class FakeSigningDecorator : ClientCodegenDecorator { - override val name: String = "fakesigning" - override val order: Byte = 0 - override fun classpathDiscoverable(): Boolean = false - override fun configCustomizations( - codegenContext: ClientCodegenContext, - baseCustomizations: List, - ): List { - return baseCustomizations.filterNot { - it is EventStreamSigningConfig - } + FakeSigningConfig(codegenContext.runtimeConfig) - } -} - -class FakeSigningConfig( - runtimeConfig: RuntimeConfig, -) : EventStreamSigningConfig(runtimeConfig) { - private val codegenScope = arrayOf( - "SharedPropertyBag" to RuntimeType.smithyHttp(runtimeConfig).resolve("property_bag::SharedPropertyBag"), - "SignMessageError" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::SignMessageError"), - "SignMessage" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::SignMessage"), - "Message" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::Message"), - ) - - override fun section(section: ServiceConfig): Writable { - return when (section) { - is ServiceConfig.ConfigImpl -> writable { - rustTemplate( - """ - /// Creates a new Event Stream `SignMessage` implementor. - pub fn new_event_stream_signer( - &self, - properties: #{SharedPropertyBag} - ) -> FakeSigner { - FakeSigner::new(properties) - } - """, - *codegenScope, - ) - } - - is ServiceConfig.Extras -> writable { - rustTemplate( - """ - /// Fake signing implementation. - ##[derive(Debug)] - pub struct FakeSigner; - - impl FakeSigner { - /// Create a real `FakeSigner` - pub fn new(_properties: #{SharedPropertyBag}) -> Self { - Self {} - } - } - - impl #{SignMessage} for FakeSigner { - fn sign(&mut self, message: #{Message}) -> Result<#{Message}, #{SignMessageError}> { - Ok(message) - } - - fn sign_empty(&mut self) -> Option> { - Some(Ok(#{Message}::new(Vec::new()))) - } - } - """, - *codegenScope, - ) - } - - else -> emptySection - } - } -} diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt index d0d271231ac..b03f374403a 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt @@ -38,7 +38,7 @@ private class TestProtocolPayloadGenerator(private val body: String) : ProtocolP override fun payloadMetadata(operationShape: OperationShape) = ProtocolPayloadGenerator.PayloadMetadata(takesOwnership = false) - override fun generatePayload(writer: RustWriter, self: String, operationShape: OperationShape) { + override fun generatePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) { writer.writeWithNoFormatting(body) } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolGenerator.kt index 4f410bf03ad..ceb385c6d17 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolGenerator.kt @@ -36,13 +36,13 @@ interface ProtocolPayloadGenerator { /** * Write the payload into [writer]. * - * [self] is the name of the variable binding for the Rust struct that is to be serialized into the payload. + * [shapeName] is the name of the variable binding for the Rust struct that is to be serialized into the payload. * * This should be an expression that returns bytes: * - a `Vec` for non-streaming operations; or * - a `ByteStream` for streaming operations. */ - fun generatePayload(writer: RustWriter, self: String, operationShape: OperationShape) + fun generatePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) } /** diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt index 6d4e7bf8502..1e17beca4e8 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt @@ -18,7 +18,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate -import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext @@ -42,10 +41,18 @@ import software.amazon.smithy.rust.codegen.core.util.isOutputEventStream import software.amazon.smithy.rust.codegen.core.util.isStreaming import software.amazon.smithy.rust.codegen.core.util.outputShape +data class EventStreamBodyParams( + val outerName: String, + val memberName: String, + val marshallerConstructorFn: RuntimeType, + val errorMarshallerConstructorFn: RuntimeType, +) + class HttpBoundProtocolPayloadGenerator( codegenContext: CodegenContext, private val protocol: Protocol, private val httpMessageType: HttpMessageType = HttpMessageType.REQUEST, + private val renderEventStreamBody: (RustWriter, EventStreamBodyParams) -> Unit, ) : ProtocolPayloadGenerator { private val symbolProvider = codegenContext.symbolProvider private val model = codegenContext.model @@ -91,38 +98,38 @@ class HttpBoundProtocolPayloadGenerator( } } - override fun generatePayload(writer: RustWriter, self: String, operationShape: OperationShape) { + override fun generatePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) { when (httpMessageType) { - HttpMessageType.RESPONSE -> generateResponsePayload(writer, self, operationShape) - HttpMessageType.REQUEST -> generateRequestPayload(writer, self, operationShape) + HttpMessageType.RESPONSE -> generateResponsePayload(writer, shapeName, operationShape) + HttpMessageType.REQUEST -> generateRequestPayload(writer, shapeName, operationShape) } } - private fun generateRequestPayload(writer: RustWriter, self: String, operationShape: OperationShape) { + private fun generateRequestPayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) { val payloadMemberName = httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName if (payloadMemberName == null) { val serializerGenerator = protocol.structuredDataSerializer() - generateStructureSerializer(writer, self, serializerGenerator.operationInputSerializer(operationShape)) + generateStructureSerializer(writer, shapeName, serializerGenerator.operationInputSerializer(operationShape)) } else { - generatePayloadMemberSerializer(writer, self, operationShape, payloadMemberName) + generatePayloadMemberSerializer(writer, shapeName, operationShape, payloadMemberName) } } - private fun generateResponsePayload(writer: RustWriter, self: String, operationShape: OperationShape) { + private fun generateResponsePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) { val payloadMemberName = httpBindingResolver.responseMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName if (payloadMemberName == null) { val serializerGenerator = protocol.structuredDataSerializer() - generateStructureSerializer(writer, self, serializerGenerator.operationOutputSerializer(operationShape)) + generateStructureSerializer(writer, shapeName, serializerGenerator.operationOutputSerializer(operationShape)) } else { - generatePayloadMemberSerializer(writer, self, operationShape, payloadMemberName) + generatePayloadMemberSerializer(writer, shapeName, operationShape, payloadMemberName) } } private fun generatePayloadMemberSerializer( writer: RustWriter, - self: String, + shapeName: String, operationShape: OperationShape, payloadMemberName: String, ) { @@ -131,7 +138,7 @@ class HttpBoundProtocolPayloadGenerator( if (operationShape.isEventStream(model)) { if (operationShape.isInputEventStream(model) && target == CodegenTarget.CLIENT) { val payloadMember = operationShape.inputShape(model).expectMember(payloadMemberName) - writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, "self") + writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, shapeName) } else if (operationShape.isOutputEventStream(model) && target == CodegenTarget.SERVER) { val payloadMember = operationShape.outputShape(model).expectMember(payloadMemberName) writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, "output") @@ -144,16 +151,16 @@ class HttpBoundProtocolPayloadGenerator( HttpMessageType.RESPONSE -> operationShape.outputShape(model).expectMember(payloadMemberName) HttpMessageType.REQUEST -> operationShape.inputShape(model).expectMember(payloadMemberName) } - writer.serializeViaPayload(bodyMetadata, self, payloadMember, serializerGenerator) + writer.serializeViaPayload(bodyMetadata, shapeName, payloadMember, serializerGenerator) } } - private fun generateStructureSerializer(writer: RustWriter, self: String, serializer: RuntimeType?) { + private fun generateStructureSerializer(writer: RustWriter, shapeName: String, serializer: RuntimeType?) { if (serializer == null) { writer.rust("\"\"") } else { writer.rust( - "#T(&$self)?", + "#T(&$shapeName)?", serializer, ) } @@ -193,47 +200,20 @@ class HttpBoundProtocolPayloadGenerator( // TODO(EventStream): [RPC] RPC protocols need to send an initial message with the // parameters that are not `@eventHeader` or `@eventPayload`. - when (target) { - CodegenTarget.CLIENT -> - rustTemplate( - """ - { - let error_marshaller = #{errorMarshallerConstructorFn}(); - let marshaller = #{marshallerConstructorFn}(); - let signer = _config.new_event_stream_signer(properties.clone()); - let adapter: #{SmithyHttp}::event_stream::MessageStreamAdapter<_, _> = - $outerName.$memberName.into_body_stream(marshaller, error_marshaller, signer); - let body: #{SdkBody} = #{hyper}::Body::wrap_stream(adapter).into(); - body - } - """, - *codegenScope, - "marshallerConstructorFn" to marshallerConstructorFn, - "errorMarshallerConstructorFn" to errorMarshallerConstructorFn, - ) - CodegenTarget.SERVER -> { - rustTemplate( - """ - { - let error_marshaller = #{errorMarshallerConstructorFn}(); - let marshaller = #{marshallerConstructorFn}(); - let signer = #{NoOpSigner}{}; - let adapter: #{SmithyHttp}::event_stream::MessageStreamAdapter<_, _> = - $outerName.$memberName.into_body_stream(marshaller, error_marshaller, signer); - adapter - } - """, - *codegenScope, - "marshallerConstructorFn" to marshallerConstructorFn, - "errorMarshallerConstructorFn" to errorMarshallerConstructorFn, - ) - } - } + renderEventStreamBody( + this, + EventStreamBodyParams( + outerName, + memberName, + marshallerConstructorFn, + errorMarshallerConstructorFn, + ), + ) } private fun RustWriter.serializeViaPayload( payloadMetadata: ProtocolPayloadGenerator.PayloadMetadata, - self: String, + shapeName: String, member: MemberShape, serializerGenerator: StructuredDataSerializerGenerator, ) { @@ -281,7 +261,7 @@ class HttpBoundProtocolPayloadGenerator( } } } - rust("#T($ref $self.${symbolProvider.toMemberName(member)})?", serializer) + rust("#T($ref $shapeName.${symbolProvider.toMemberName(member)})?", serializer) } private fun RustWriter.renderPayload( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 2530f044044..89accafd748 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -41,6 +41,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization @@ -48,12 +49,14 @@ import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustom import software.amazon.smithy.rust.codegen.core.smithy.customize.Section import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.mapRustType import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation +import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator @@ -114,6 +117,31 @@ class ServerHttpBoundProtocolGenerator( } } +class ServerHttpBoundProtocolPayloadGenerator( + codegenContext: CodegenContext, + protocol: Protocol, +) : ProtocolPayloadGenerator by HttpBoundProtocolPayloadGenerator( + codegenContext, protocol, HttpMessageType.RESPONSE, + renderEventStreamBody = { writer, params -> + writer.rustTemplate( + """ + { + let error_marshaller = #{errorMarshallerConstructorFn}(); + let marshaller = #{marshallerConstructorFn}(); + let signer = #{NoOpSigner}{}; + let adapter: #{aws_smithy_http}::event_stream::MessageStreamAdapter<_, _> = + ${params.outerName}.${params.memberName}.into_body_stream(marshaller, error_marshaller, signer); + adapter + } + """, + "aws_smithy_http" to RuntimeType.smithyHttp(codegenContext.runtimeConfig), + "NoOpSigner" to RuntimeType.smithyEventStream(codegenContext.runtimeConfig).resolve("frame::NoOpSigner"), + "marshallerConstructorFn" to params.marshallerConstructorFn, + "errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn, + ) + }, +) + /* * Generate all operation input parsers and output serializers for streaming and * non-streaming types. @@ -504,12 +532,12 @@ class ServerHttpBoundProtocolTraitImplGenerator( ?: serverRenderHttpResponseCode(httpTraitStatusCode)(this) operationShape.outputShape(model).findStreamingMember(model)?.let { - val payloadGenerator = HttpBoundProtocolPayloadGenerator(codegenContext, protocol, httpMessageType = HttpMessageType.RESPONSE) + val payloadGenerator = ServerHttpBoundProtocolPayloadGenerator(codegenContext, protocol) withBlockTemplate("let body = #{SmithyHttpServer}::body::boxed(#{SmithyHttpServer}::body::Body::wrap_stream(", "));", *codegenScope) { payloadGenerator.generatePayload(this, "output", operationShape) } } ?: run { - val payloadGenerator = HttpBoundProtocolPayloadGenerator(codegenContext, protocol, httpMessageType = HttpMessageType.RESPONSE) + val payloadGenerator = ServerHttpBoundProtocolPayloadGenerator(codegenContext, protocol) withBlockTemplate("let payload = ", ";") { payloadGenerator.generatePayload(this, "output", operationShape) } diff --git a/rust-runtime/aws-smithy-eventstream/src/frame.rs b/rust-runtime/aws-smithy-eventstream/src/frame.rs index a3d9c29a009..fc045b45003 100644 --- a/rust-runtime/aws-smithy-eventstream/src/frame.rs +++ b/rust-runtime/aws-smithy-eventstream/src/frame.rs @@ -14,6 +14,7 @@ use std::convert::{TryFrom, TryInto}; use std::error::Error as StdError; use std::fmt; use std::mem::size_of; +use std::sync::{mpsc, Mutex}; const PRELUDE_LENGTH_BYTES: u32 = 3 * size_of::() as u32; const PRELUDE_LENGTH_BYTES_USIZE: usize = PRELUDE_LENGTH_BYTES as usize; @@ -34,6 +35,88 @@ pub trait SignMessage: fmt::Debug { fn sign_empty(&mut self) -> Option>; } +/// A sender that gets placed in the request config to wire up an event stream signer after signing. +#[derive(Debug)] +#[non_exhaustive] +pub struct DeferredSignerSender(Mutex>>); + +impl DeferredSignerSender { + /// Creates a new `DeferredSignerSender` + fn new(tx: mpsc::Sender>) -> Self { + Self(Mutex::new(tx)) + } + + /// Sends a signer on the channel + pub fn send( + &self, + signer: Box, + ) -> Result<(), mpsc::SendError>> { + self.0.lock().unwrap().send(signer.into()) + } +} + +/// Deferred event stream signer to allow a signer to be wired up later. +/// +/// HTTP request signing takes place after serialization, and the event stream +/// message stream body is established during serialization. Since event stream +/// signing may need context from the initial HTTP signing operation, this +/// [`DeferredSigner`] is needed to wire up the signer later in the request lifecycle. +/// +/// This signer basically just establishes a MPSC channel so that the sender can +/// be placed in the request's config. Then the HTTP signer implementation can +/// retrieve the sender from that config and send an actual signing implementation +/// with all the context needed. +/// +/// When an event stream implementation needs to sign a message, the first call to +/// sign will acquire a signing implementation off of the channel and cache it +/// for the remainder of the operation. +#[derive(Debug)] +pub struct DeferredSigner { + rx: Option>>>, + signer: Option>, +} + +impl DeferredSigner { + pub fn new() -> (Self, DeferredSignerSender) { + let (tx, rx) = mpsc::channel(); + ( + Self { + rx: Some(Mutex::new(rx)), + signer: None, + }, + DeferredSignerSender::new(tx), + ) + } + + fn acquire(&mut self) -> &mut (dyn SignMessage + Send + Sync) { + // Can't use `if let Some(signer) = &mut self.signer` because the borrow checker isn't smart enough + if self.signer.is_some() { + return self.signer.as_mut().unwrap().as_mut(); + } else { + self.signer = Some( + self.rx + .take() + .expect("only taken once") + .lock() + .unwrap() + .try_recv() + .expect("signer must become available before first use"), + ); + self.acquire() + } + } +} + +impl SignMessage for DeferredSigner { + fn sign(&mut self, message: Message) -> Result { + self.acquire().sign(message) + } + + fn sign_empty(&mut self) -> Option> { + self.acquire().sign_empty() + } +} + #[derive(Debug)] pub struct NoOpSigner {} impl SignMessage for NoOpSigner { @@ -848,3 +931,50 @@ mod message_frame_decoder_tests { } } } + +#[cfg(test)] +mod deferred_signer_tests { + use crate::frame::{DeferredSigner, Header, HeaderValue, Message, SignMessage}; + use bytes::Bytes; + + fn check_send_sync(value: T) -> T { + value + } + + #[test] + fn deferred_signer() { + #[derive(Default, Debug)] + struct TestSigner { + call_num: i32, + } + impl SignMessage for TestSigner { + fn sign( + &mut self, + message: crate::frame::Message, + ) -> Result { + self.call_num += 1; + Ok(message.add_header(Header::new("call_num", HeaderValue::Int32(self.call_num)))) + } + + fn sign_empty( + &mut self, + ) -> Option> { + None + } + } + + let (mut signer, sender) = check_send_sync(DeferredSigner::new()); + + sender + .send(Box::new(TestSigner::default())) + .expect("success"); + + let message = signer.sign(Message::new(Bytes::new())).expect("success"); + assert_eq!(1, message.headers()[0].value().as_int32().unwrap()); + + let message = signer.sign(Message::new(Bytes::new())).expect("success"); + assert_eq!(2, message.headers()[0].value().as_int32().unwrap()); + + assert!(signer.sign_empty().is_none()); + } +}