Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add shape deserializer overrides #512

Merged
merged 1 commit into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;
Expand Down Expand Up @@ -75,6 +77,7 @@ public abstract class HttpBindingProtocolGenerator implements ProtocolGenerator
private final Set<Shape> serializeDocumentBindingShapes = new TreeSet<>();
private final Set<Shape> deserializeDocumentBindingShapes = new TreeSet<>();
private final Set<StructureShape> deserializingErrorShapes = new TreeSet<>();
private final Map<ShapeId, Symbol> deserializerOverrides = new HashMap<>();

/**
* Creates a Http binding protocol generator.
Expand Down Expand Up @@ -1082,6 +1085,13 @@ private String conditionallyBase64Encode(

@Override
public void generateResponseDeserializers(GenerationContext context) {
deserializerOverrides.putAll(
context.getIntegrations().stream()
.flatMap(it -> it.getClientPlugins(context.getModel(), context.getService()).stream())
.flatMap(it -> it.getShapeDeserializers().entrySet().stream())
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))
);

EventStreamIndex streamIndex = EventStreamIndex.of(context.getModel());

for (OperationShape operation : getHttpBindingOperations(context)) {
Expand Down Expand Up @@ -1347,13 +1357,24 @@ private void writeHeaderDeserializerFunction(
) {
writer.openBlock("if headerValues := response.Header.Values($S); len(headerValues) != 0 {", "}",
binding.getLocationName(), () -> {
Shape targetShape = context.getModel().expectShape(memberShape.getTarget());
var target = memberShape.getTarget();
Shape targetShape = context.getModel().expectShape(target);

String operand = "headerValues";
operand = writeHeaderValueAccessor(context, writer, targetShape, binding, operand);

String value = generateHttpHeaderValue(context, writer, memberShape, binding,
operand);
if (deserializerOverrides.containsKey(target)) {
writer.write("""
deserOverride, err := $T($L)
if err != nil {
return err
}
v.$L = deserOverride
""", deserializerOverrides.get(target), operand, memberName);
return;
}

var value = generateHttpHeaderValue(context, writer, memberShape, binding, operand);
writer.write("v.$L = $L", memberName,
CodegenUtils.getAsPointerIfPointable(context.getModel(), writer,
GoPointableIndex.of(context.getModel()), memberShape, value));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.Optional;
import java.util.Set;
import java.util.function.BiPredicate;
import software.amazon.smithy.codegen.core.Symbol;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.auth.AuthParameter;
import software.amazon.smithy.go.codegen.auth.AuthParametersResolver;
Expand Down Expand Up @@ -54,6 +55,7 @@ public final class RuntimeClientPlugin implements ToSmithyBuilder<RuntimeClientP
private final MiddlewareRegistrar registerMiddleware;
private final Map<String, GoWriter.Writable> endpointBuiltinBindings;
private final Map<ShapeId, AuthSchemeDefinition> authSchemeDefinitions;
private final Map<ShapeId, Symbol> shapeDeserializers;

private RuntimeClientPlugin(Builder builder) {
operationPredicate = builder.operationPredicate;
Expand All @@ -67,6 +69,7 @@ private RuntimeClientPlugin(Builder builder) {
configFieldResolvers = builder.configFieldResolvers;
endpointBuiltinBindings = builder.endpointBuiltinBindings;
authSchemeDefinitions = builder.authSchemeDefinitions;
shapeDeserializers = builder.shapeDeserializers;
}

@FunctionalInterface
Expand Down Expand Up @@ -130,6 +133,14 @@ public Map<ShapeId, AuthSchemeDefinition> getAuthSchemeDefinitions() {
return authSchemeDefinitions;
}

/**
* Gets the registered shape deserializers.
* @return the deserializers.
*/
public Map<ShapeId, Symbol> getShapeDeserializers() {
return shapeDeserializers;
}

/**
* Gets the optionally present middleware registrar object that resolves to middleware registering function.
*
Expand Down Expand Up @@ -242,6 +253,7 @@ public static final class Builder implements SmithyBuilder<RuntimeClientPlugin>
private Map<String, GoWriter.Writable> endpointBuiltinBindings = new HashMap<>();
private MiddlewareRegistrar registerMiddleware;
private Map<ShapeId, AuthSchemeDefinition> authSchemeDefinitions = new HashMap<>();
private Map<ShapeId, Symbol> shapeDeserializers = new HashMap<>();

@Override
public RuntimeClientPlugin build() {
Expand Down Expand Up @@ -496,5 +508,18 @@ public Builder addAuthSchemeDefinition(ShapeId schemeId, AuthSchemeDefinition de
this.authSchemeDefinitions.put(schemeId, definition);
return this;
}

/**
* Registers a codegen definition for a custom shape deserializer. This feature is currently only supported for
* overriding deserialization in HTTP bindings.
* @param id The shape id.
* @param deserializer The deserializer symbol. The written code MUST be a function which accepts the
* corresponding type for the shape and returns (*type, error) accordingly.
* @return Returns the builder.
*/
public Builder addShapeDeserializer(ShapeId id, Symbol deserializer) {
this.shapeDeserializers.put(id, deserializer);
return this;
}
}
}
Loading