Skip to content

Commit

Permalink
Merge branch 'main' into harryb/move-protocol-to-server-protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
Harry Barber committed Sep 13, 2022
2 parents 765260e + f324240 commit c2a67d1
Show file tree
Hide file tree
Showing 8 changed files with 186 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,29 @@ interface Protocol {
* there are no response headers or statuses available to further inform the error parsing.
*/
fun parseEventStreamGenericError(operationShape: OperationShape): RuntimeType

/**
* Returns a writable for the `RequestSpec` for an operation.
*/
fun serverRouterRequestSpec(
operationShape: OperationShape,
operationName: String,
serviceName: String,
requestSpecModule: RuntimeType,
): Writable

/**
* Returns the name of the constructor to be used on the `Router` type, to instantiate a `Router` using this
* protocol.
*/
fun serverRouterRuntimeConstructor(): String

/**
* In some protocols, such as restJson1,
* when there is no modeled body input, content type must not be set and the body must be empty.
* Returns a boolean indicating whether to perform this check.
*/
fun serverContentTypeCheckNoModeledInput(): Boolean = false
}

typealias ProtocolMap<C> = Map<ShapeId, ProtocolGeneratorFactory<ProtocolGenerator, C>>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,17 @@ open class RestJson(val coreCodegenContext: CoreCodegenContext) : Protocol {
*errorScope,
)
}

override fun serverRouterRequestSpec(
operationShape: OperationShape,
operationName: String,
serviceName: String,
requestSpecModule: RuntimeType,
): Writable = RestRequestSpecGenerator(httpBindingResolver, requestSpecModule).generate(operationShape)

override fun serverRouterRuntimeConstructor() = "new_rest_json_router"

override fun serverContentTypeCheckNoModeledInput() = true
}

fun restJsonFieldName(member: MemberShape): String {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,17 @@ open class RestXml(val coreCodegenContext: CoreCodegenContext) : Protocol {
rust("#T::parse_generic_error(payload.as_ref())", restXmlErrors)
}
}

override fun serverRouterRequestSpec(
operationShape: OperationShape,
operationName: String,
serviceName: String,
requestSpecModule: RuntimeType,
): Writable = RestRequestSpecGenerator(httpBindingResolver, requestSpecModule).generate(operationShape)

override fun serverRouterRuntimeConstructor() = "new_rest_xml_router"

override fun serverContentTypeCheckNoModeledInput() = true
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -841,9 +841,6 @@ class ServerProtocolTestGenerator(
FailingTest(RestJson, "RestJsonBodyMalformedBlobInvalidBase64_case1", TestType.MalformedRequest),
FailingTest(RestJson, "RestJsonBodyMalformedBlobInvalidBase64_case2", TestType.MalformedRequest),
FailingTest(RestJson, "RestJsonWithBodyExpectsApplicationJsonContentType", TestType.MalformedRequest),
FailingTest(RestJson, "RestJsonWithPayloadExpectsImpliedContentType", TestType.MalformedRequest),
FailingTest(RestJson, "RestJsonWithPayloadExpectsModeledContentType", TestType.MalformedRequest),
FailingTest(RestJson, "RestJsonWithoutBodyExpectsEmptyContentType", TestType.MalformedRequest),
FailingTest(RestJson, "RestJsonBodyMalformedListNullItem", TestType.MalformedRequest),
FailingTest(RestJson, "RestJsonBodyMalformedMapNullValue", TestType.MalformedRequest),
FailingTest(RestJson, "RestJsonMalformedSetDuplicateItems", TestType.MalformedRequest),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,17 @@ import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.traits.HttpErrorTrait
import software.amazon.smithy.model.traits.HttpPayloadTrait
import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.model.traits.MediaTypeTrait
import software.amazon.smithy.rust.codegen.client.rustlang.Attribute
import software.amazon.smithy.rust.codegen.client.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.client.rustlang.RustModule
import software.amazon.smithy.rust.codegen.client.rustlang.RustType
import software.amazon.smithy.rust.codegen.client.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.client.rustlang.Writable
import software.amazon.smithy.rust.codegen.client.rustlang.asType
import software.amazon.smithy.rust.codegen.client.rustlang.conditionalBlock
import software.amazon.smithy.rust.codegen.client.rustlang.render
import software.amazon.smithy.rust.codegen.client.rustlang.rust
import software.amazon.smithy.rust.codegen.client.rustlang.rustBlock
Expand Down Expand Up @@ -58,13 +61,15 @@ import software.amazon.smithy.rust.codegen.client.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.client.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.client.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.client.smithy.toOptional
import software.amazon.smithy.rust.codegen.client.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.client.smithy.transformers.operationErrors
import software.amazon.smithy.rust.codegen.client.smithy.wrapOptional
import software.amazon.smithy.rust.codegen.client.util.dq
import software.amazon.smithy.rust.codegen.client.util.expectTrait
import software.amazon.smithy.rust.codegen.client.util.findStreamingMember
import software.amazon.smithy.rust.codegen.client.util.getTrait
import software.amazon.smithy.rust.codegen.client.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.client.util.hasTrait
import software.amazon.smithy.rust.codegen.client.util.inputShape
import software.amazon.smithy.rust.codegen.client.util.isStreaming
import software.amazon.smithy.rust.codegen.client.util.outputShape
Expand Down Expand Up @@ -168,7 +173,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
val operationName = symbolProvider.toSymbol(operationShape).name
val inputName = "${operationName}${ServerHttpBoundProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}"

val verifyResponseContentType = writable {
val verifyAcceptHeader = writable {
httpBindingResolver.responseContentType(operationShape)?.also { contentType ->
rustTemplate(
"""
Expand All @@ -183,6 +188,30 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
)
}
}
val verifyRequestContentTypeHeader = writable {
operationShape
.inputShape(model)
.members()
.find { it.hasTrait<HttpPayloadTrait>() }
?.let { payload ->
val target = model.expectShape(payload.target)
if (!target.isBlobShape || target.hasTrait<MediaTypeTrait>()) {
val expectedRequestContentType = httpBindingResolver.requestContentType(operationShape)
?.let { "Some(${it.dq()})" } ?: "None"
rustTemplate(
"""
if #{SmithyHttpServer}::protocols::content_type_header_classifier(req, $expectedRequestContentType).is_err() {
return Err(#{RuntimeError} {
protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()},
kind: #{SmithyHttpServer}::runtime_error::RuntimeErrorKind::UnsupportedMediaType,
})
}
""",
*codegenScope,
)
}
}
}

// Implement `from_request` trait for input types.
rustTemplate(
Expand All @@ -197,7 +226,8 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
B::Data: Send,
#{RequestRejection} : From<<B as #{SmithyHttpServer}::body::HttpBody>::Error>
{
#{verify_response_content_type:W}
#{verifyAcceptHeader:W}
#{verifyRequestContentTypeHeader:W}
#{parse_request}(req)
.await
.map($inputName)
Expand Down Expand Up @@ -235,7 +265,8 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
"I" to inputSymbol,
"Marker" to serverProtocol.markerStruct(),
"parse_request" to serverParseRequest(operationShape),
"verify_response_content_type" to verifyResponseContentType,
"verifyAcceptHeader" to verifyAcceptHeader,
"verifyRequestContentTypeHeader" to verifyRequestContentTypeHeader,
)

// Implement `into_response` for output types.
Expand Down Expand Up @@ -711,16 +742,13 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
Attribute.AllowUnusedMut.render(this)
rust("let mut input = #T::default();", inputShape.builderSymbol(symbolProvider))
val parser = structuredDataParser.serverInputParser(operationShape)
val noInputs = model.expectShape(operationShape.inputShape).expectTrait<SyntheticInputTrait>().originalId == null
if (parser != null) {
val expectedRequestContentType = httpBindingResolver.requestContentType(operationShape)
rustTemplate(
"""
let body = request.take_body().ok_or(#{RequestRejection}::BodyAlreadyExtracted)?;
let bytes = #{Hyper}::body::to_bytes(body).await?;
if !bytes.is_empty() {
static EXPECTED_CONTENT_TYPE: #{OnceCell}::sync::Lazy<#{Mime}::Mime> =
#{OnceCell}::sync::Lazy::new(|| "$expectedRequestContentType".parse::<#{Mime}::Mime>().unwrap());
#{SmithyHttpServer}::protocols::check_content_type(request, &EXPECTED_CONTENT_TYPE)?;
input = #{parser}(bytes.as_ref(), input)?;
}
""",
Expand All @@ -740,6 +768,16 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
serverRenderUriPathParser(this, operationShape)
serverRenderQueryStringParser(this, operationShape)

if (noInputs && protocol.serverContentTypeCheckNoModeledInput()) {
conditionalBlock("if body.is_empty() {", "}", conditional = parser != null) {
rustTemplate(
"""
#{SmithyHttpServer}::protocols::content_type_header_empty_body_no_modeled_input(request)?;
""",
*codegenScope,
)
}
}
val err = if (StructureGenerator.fallibleBuilder(inputShape, symbolProvider)) {
"?"
} else ""
Expand Down
115 changes: 87 additions & 28 deletions rust-runtime/aws-smithy-http-server/src/protocols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,35 +20,81 @@ pub struct AwsJson10;
pub struct AwsJson11;

/// Supported protocols.
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Protocol {
RestJson1,
RestXml,
AwsJson10,
AwsJson11,
}

pub fn check_content_type<B>(
/// When there are no modeled inputs,
/// a request body is empty and the content-type request header must not be set
pub fn content_type_header_empty_body_no_modeled_input<B>(
req: &RequestParts<B>,
expected_mime: &'static mime::Mime,
) -> Result<(), MissingContentTypeReason> {
let found_mime = req
.headers()
.ok_or(MissingContentTypeReason::HeadersTakenByAnotherExtractor)?
if req.headers().is_none() {
return Ok(());
}
let headers = req.headers().unwrap();
if headers.contains_key(http::header::CONTENT_TYPE) {
let found_mime = headers
.get(http::header::CONTENT_TYPE)
.unwrap() // The header is present, `unwrap` will not panic.
.to_str()
.map_err(MissingContentTypeReason::ToStrError)?
.parse::<mime::Mime>()
.map_err(MissingContentTypeReason::MimeParseError)?;
Err(MissingContentTypeReason::UnexpectedMimeType {
expected_mime: None,
found_mime: Some(found_mime),
})
} else {
Ok(())
}
}

/// Checks that the content-type in request headers is valid
pub fn content_type_header_classifier<B>(
req: &RequestParts<B>,
expected_content_type: Option<&'static str>,
) -> Result<(), MissingContentTypeReason> {
// Allow no CONTENT-TYPE header
if req.headers().is_none() {
return Ok(());
}
let headers = req.headers().unwrap(); // Headers are present, `unwrap` will not panic.
if !headers.contains_key(http::header::CONTENT_TYPE) {
return Ok(());
}
let client_type = headers
.get(http::header::CONTENT_TYPE)
.ok_or(MissingContentTypeReason::NoContentTypeHeader)?
.unwrap() // The header is present, `unwrap` will not panic.
.to_str()
.map_err(MissingContentTypeReason::ToStrError)?
.parse::<mime::Mime>()
.map_err(MissingContentTypeReason::MimeParseError)?;
if &found_mime == expected_mime {
Ok(())
// There is a content-type header
// If there is an implied content type, they must match
if let Some(expected_content_type) = expected_content_type {
let content_type = expected_content_type
.parse::<mime::Mime>()
// `expected_content_type` comes from the codegen.
.expect("BUG: MIME parsing failed, expected_content_type is not valid. Please file a bug report under https://github.com/awslabs/smithy-rs/issues");
if expected_content_type != client_type {
return Err(MissingContentTypeReason::UnexpectedMimeType {
expected_mime: Some(content_type),
found_mime: Some(client_type),
});
}
} else {
Err(MissingContentTypeReason::UnexpectedMimeType {
expected_mime,
found_mime,
})
// Content-type header and no modeled input (mismatch)
return Err(MissingContentTypeReason::UnexpectedMimeType {
expected_mime: None,
found_mime: Some(client_type),
});
}
Ok(())
}

pub fn accept_header_classifier<B>(req: &RequestParts<B>, content_type: &'static str) -> bool {
Expand Down Expand Up @@ -112,21 +158,34 @@ mod tests {
RequestParts::new(request)
}

static EXPECTED_MIME_APPLICATION_JSON: once_cell::sync::Lazy<mime::Mime> =
once_cell::sync::Lazy::new(|| "application/json".parse::<mime::Mime>().unwrap());
const EXPECTED_MIME_APPLICATION_JSON: Option<&'static str> = Some("application/json");

#[test]
fn check_content_type_header_empty_body_no_modeled_input() {
let request = Request::builder().body("").unwrap();
let request = RequestParts::new(request);
assert!(content_type_header_empty_body_no_modeled_input(&request).is_ok());
}

#[test]
fn check_valid_content_type() {
fn check_invalid_content_type_header_empty_body_no_modeled_input() {
let valid_request = req_content_type("application/json");
assert!(check_content_type(&valid_request, &EXPECTED_MIME_APPLICATION_JSON).is_ok());
let result = content_type_header_empty_body_no_modeled_input(&valid_request).unwrap_err();
assert!(matches!(
result,
MissingContentTypeReason::UnexpectedMimeType {
expected_mime: None,
found_mime: Some(_)
}
));
}

#[test]
fn check_invalid_content_type() {
let invalid = vec!["application/ajson", "text/xml"];
for invalid_mime in invalid {
let request = req_content_type(invalid_mime);
let result = check_content_type(&request, &EXPECTED_MIME_APPLICATION_JSON);
let result = content_type_header_classifier(&request, EXPECTED_MIME_APPLICATION_JSON);

// Validates the rejection type since we cannot implement `PartialEq`
// for `MissingContentTypeReason`.
Expand All @@ -137,8 +196,11 @@ mod tests {
expected_mime,
found_mime,
} => {
assert_eq!(expected_mime, &"application/json".parse::<mime::Mime>().unwrap());
assert_eq!(found_mime, invalid_mime);
assert_eq!(
expected_mime.unwrap(),
"application/json".parse::<mime::Mime>().unwrap()
);
assert_eq!(found_mime, invalid_mime.parse::<mime::Mime>().ok());
}
_ => panic!("Unexpected `MissingContentTypeReason`: {}", e.to_string()),
},
Expand All @@ -147,19 +209,16 @@ mod tests {
}

#[test]
fn check_missing_content_type() {
fn check_missing_content_type_is_allowed() {
let request = RequestParts::new(Request::builder().body("").unwrap());
let result = check_content_type(&request, &EXPECTED_MIME_APPLICATION_JSON);
assert!(matches!(
result.unwrap_err(),
MissingContentTypeReason::NoContentTypeHeader
));
let result = content_type_header_classifier(&request, EXPECTED_MIME_APPLICATION_JSON);
assert!(result.is_ok());
}

#[test]
fn check_not_parsable_content_type() {
let request = req_content_type("123");
let result = check_content_type(&request, &EXPECTED_MIME_APPLICATION_JSON);
let result = content_type_header_classifier(&request, EXPECTED_MIME_APPLICATION_JSON);
assert!(matches!(
result.unwrap_err(),
MissingContentTypeReason::MimeParseError(_)
Expand All @@ -169,7 +228,7 @@ mod tests {
#[test]
fn check_non_ascii_visible_characters_content_type() {
let request = req_content_type("application/💩");
let result = check_content_type(&request, &EXPECTED_MIME_APPLICATION_JSON);
let result = content_type_header_classifier(&request, EXPECTED_MIME_APPLICATION_JSON);
assert!(matches!(result.unwrap_err(), MissingContentTypeReason::ToStrError(_)));
}

Expand Down
4 changes: 2 additions & 2 deletions rust-runtime/aws-smithy-http-server/src/rejection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,8 @@ pub enum MissingContentTypeReason {
ToStrError(http::header::ToStrError),
MimeParseError(mime::FromStrError),
UnexpectedMimeType {
expected_mime: &'static mime::Mime,
found_mime: mime::Mime,
expected_mime: Option<mime::Mime>,
found_mime: Option<mime::Mime>,
},
}

Expand Down
Loading

0 comments on commit c2a67d1

Please sign in to comment.