Skip to content

Commit 19fd060

Browse files
authored
Merge pull request #111 from graphql-java-kickstart/feature/100
Handle Apollo subscription onConnect
2 parents 4e91e67 + 547dc23 commit 19fd060

10 files changed

+108
-23
lines changed
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package graphql.servlet;
2+
3+
import java.util.Optional;
4+
5+
public interface ApolloSubscriptionConnectionListener extends SubscriptionConnectionListener {
6+
7+
String CONNECT_RESULT_KEY = "CONNECT_RESULT";
8+
9+
default boolean isKeepAliveEnabled() {
10+
return false;
11+
}
12+
13+
default Optional<Object> onConnect(Object payload) {
14+
return Optional.empty();
15+
}
16+
17+
}

src/main/java/graphql/servlet/DefaultGraphQLContextBuilder.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import javax.servlet.http.HttpServletRequest;
44
import javax.servlet.http.HttpServletResponse;
5+
import javax.websocket.Session;
56
import javax.websocket.server.HandshakeRequest;
67

78
public class DefaultGraphQLContextBuilder implements GraphQLContextBuilder {
@@ -12,8 +13,8 @@ public GraphQLContext build(HttpServletRequest httpServletRequest, HttpServletRe
1213
}
1314

1415
@Override
15-
public GraphQLContext build(HandshakeRequest handshakeRequest) {
16-
return new GraphQLContext(handshakeRequest);
16+
public GraphQLContext build(Session session, HandshakeRequest handshakeRequest) {
17+
return new GraphQLContext(session, handshakeRequest);
1718
}
1819

1920
@Override

src/main/java/graphql/servlet/GraphQLContext.java

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,28 @@
66
import javax.servlet.http.HttpServletRequest;
77
import javax.servlet.http.HttpServletResponse;
88
import javax.servlet.http.Part;
9+
import javax.websocket.Session;
910
import javax.websocket.server.HandshakeRequest;
1011
import java.util.List;
1112
import java.util.Map;
1213
import java.util.Optional;
1314

1415
public class GraphQLContext {
16+
1517
private HttpServletRequest httpServletRequest;
1618
private HttpServletResponse httpServletResponse;
19+
private Session session;
1720
private HandshakeRequest handshakeRequest;
1821

1922
private Subject subject;
2023
private Map<String, List<Part>> files;
2124

2225
private DataLoaderRegistry dataLoaderRegistry;
2326

24-
public GraphQLContext(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, HandshakeRequest handshakeRequest, Subject subject) {
27+
public GraphQLContext(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, Session session, HandshakeRequest handshakeRequest, Subject subject) {
2528
this.httpServletRequest = httpServletRequest;
2629
this.httpServletResponse = httpServletResponse;
30+
this.session = session;
2731
this.handshakeRequest = handshakeRequest;
2832
this.subject = subject;
2933
}
@@ -33,27 +37,40 @@ public GraphQLContext(HttpServletRequest httpServletRequest) {
3337
}
3438

3539
public GraphQLContext(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
36-
this(httpServletRequest, httpServletResponse, null, null);
40+
this(httpServletRequest, httpServletResponse, null, null, null);
3741
}
3842

39-
public GraphQLContext(HandshakeRequest handshakeRequest) {
40-
this(null, null, handshakeRequest, null);
43+
public GraphQLContext(Session session, HandshakeRequest handshakeRequest) {
44+
this(null, null, session, handshakeRequest, null);
4145
}
4246

4347
public GraphQLContext() {
44-
this(null, null, null, null);
48+
this(null, null, null, null, null);
4549
}
4650

4751
public Optional<HttpServletRequest> getHttpServletRequest() {
4852
return Optional.ofNullable(httpServletRequest);
4953
}
5054

51-
public Optional<HttpServletResponse> getHttpServletResponse() { return Optional.ofNullable(httpServletResponse); }
55+
public Optional<HttpServletResponse> getHttpServletResponse() {
56+
return Optional.ofNullable(httpServletResponse);
57+
}
5258

5359
public Optional<Subject> getSubject() {
5460
return Optional.ofNullable(subject);
5561
}
5662

63+
public Optional<Session> getSession() {
64+
return Optional.ofNullable(session);
65+
}
66+
67+
public Optional<Object> getConnectResult() {
68+
if (session != null) {
69+
return Optional.ofNullable(session.getUserProperties().get(ApolloSubscriptionConnectionListener.CONNECT_RESULT_KEY));
70+
}
71+
return Optional.empty();
72+
}
73+
5774
public Optional<HandshakeRequest> getHandshakeRequest() {
5875
return Optional.ofNullable(handshakeRequest);
5976
}

src/main/java/graphql/servlet/GraphQLContextBuilder.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22

33
import javax.servlet.http.HttpServletRequest;
44
import javax.servlet.http.HttpServletResponse;
5+
import javax.websocket.Session;
56
import javax.websocket.server.HandshakeRequest;
67

78
public interface GraphQLContextBuilder {
9+
810
GraphQLContext build(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse);
9-
GraphQLContext build(HandshakeRequest handshakeRequest);
11+
12+
GraphQLContext build(Session session, HandshakeRequest handshakeRequest);
1013

1114
/**
1215
* Only used for MBean calls.

src/main/java/graphql/servlet/GraphQLInvocationInputFactory.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import javax.servlet.http.HttpServletRequest;
77
import javax.servlet.http.HttpServletResponse;
8+
import javax.websocket.Session;
89
import javax.websocket.server.HandshakeRequest;
910
import java.util.List;
1011
import java.util.function.Supplier;
@@ -70,20 +71,20 @@ private GraphQLBatchedInvocationInput create(List<GraphQLRequest> graphQLRequest
7071
);
7172
}
7273

73-
public GraphQLSingleInvocationInput create(GraphQLRequest graphQLRequest, HandshakeRequest request) {
74+
public GraphQLSingleInvocationInput create(GraphQLRequest graphQLRequest, Session session, HandshakeRequest request) {
7475
return new GraphQLSingleInvocationInput(
7576
graphQLRequest,
7677
schemaProviderSupplier.get().getSchema(request),
77-
contextBuilderSupplier.get().build(request),
78+
contextBuilderSupplier.get().build(session, request),
7879
rootObjectBuilderSupplier.get().build(request)
7980
);
8081
}
8182

82-
public GraphQLBatchedInvocationInput create(List<GraphQLRequest> graphQLRequest, HandshakeRequest request) {
83+
public GraphQLBatchedInvocationInput create(List<GraphQLRequest> graphQLRequest, Session session, HandshakeRequest request) {
8384
return new GraphQLBatchedInvocationInput(
8485
graphQLRequest,
8586
schemaProviderSupplier.get().getSchema(request),
86-
contextBuilderSupplier.get().build(request),
87+
contextBuilderSupplier.get().build(session, request),
8788
rootObjectBuilderSupplier.get().build(request)
8889
);
8990
}

src/main/java/graphql/servlet/GraphQLWebsocketServlet.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import java.io.IOException;
1111
import java.util.Collections;
1212
import java.util.HashMap;
13-
import java.util.HashSet;
1413
import java.util.List;
1514
import java.util.Map;
1615
import java.util.concurrent.atomic.AtomicBoolean;
@@ -48,7 +47,11 @@ public class GraphQLWebsocketServlet extends Endpoint {
4847
private final Object cacheLock = new Object();
4948

5049
public GraphQLWebsocketServlet(GraphQLQueryInvoker queryInvoker, GraphQLInvocationInputFactory invocationInputFactory, GraphQLObjectMapper graphQLObjectMapper) {
51-
this.subscriptionHandlerInput = new SubscriptionHandlerInput(invocationInputFactory, queryInvoker, graphQLObjectMapper);
50+
this(queryInvoker, invocationInputFactory, graphQLObjectMapper, null);
51+
}
52+
53+
public GraphQLWebsocketServlet(GraphQLQueryInvoker queryInvoker, GraphQLInvocationInputFactory invocationInputFactory, GraphQLObjectMapper graphQLObjectMapper, SubscriptionConnectionListener subscriptionConnectionListener) {
54+
this.subscriptionHandlerInput = new SubscriptionHandlerInput(invocationInputFactory, queryInvoker, graphQLObjectMapper, subscriptionConnectionListener);
5255
}
5356

5457
@Override
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package graphql.servlet;
2+
3+
/**
4+
* Marker interface
5+
*/
6+
public interface SubscriptionConnectionListener {
7+
}

src/main/java/graphql/servlet/internal/ApolloSubscriptionProtocolHandler.java

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import com.fasterxml.jackson.annotation.JsonInclude;
55
import com.fasterxml.jackson.annotation.JsonValue;
66
import graphql.ExecutionResult;
7+
import graphql.servlet.ApolloSubscriptionConnectionListener;
8+
import graphql.servlet.GraphQLSingleInvocationInput;
79
import org.slf4j.Logger;
810
import org.slf4j.LoggerFactory;
911

@@ -13,6 +15,7 @@
1315
import java.io.IOException;
1416
import java.util.HashMap;
1517
import java.util.Map;
18+
import java.util.Optional;
1619

1720
import static graphql.servlet.internal.ApolloSubscriptionProtocolHandler.OperationMessage.Type.GQL_COMPLETE;
1821
import static graphql.servlet.internal.ApolloSubscriptionProtocolHandler.OperationMessage.Type.GQL_CONNECTION_TERMINATE;
@@ -30,9 +33,14 @@ public class ApolloSubscriptionProtocolHandler extends SubscriptionProtocolHandl
3033
private static final CloseReason TERMINATE_CLOSE_REASON = new CloseReason(CloseReason.CloseCodes.NORMAL_CLOSURE, "client requested " + GQL_CONNECTION_TERMINATE.getType());
3134

3235
private final SubscriptionHandlerInput input;
36+
private final ApolloSubscriptionConnectionListener connectionListener;
3337

3438
public ApolloSubscriptionProtocolHandler(SubscriptionHandlerInput subscriptionHandlerInput) {
3539
this.input = subscriptionHandlerInput;
40+
this.connectionListener = subscriptionHandlerInput.getSubscriptionConnectionListener()
41+
.filter(ApolloSubscriptionConnectionListener.class::isInstance)
42+
.map(ApolloSubscriptionConnectionListener.class::cast)
43+
.orElse(new ApolloSubscriptionConnectionListener() {});
3644
}
3745

3846
@Override
@@ -48,19 +56,28 @@ public void onMessage(HandshakeRequest request, Session session, WsSessionSubscr
4856

4957
switch(message.getType()) {
5058
case GQL_CONNECTION_INIT:
59+
try {
60+
Optional<Object> connectionResponse = connectionListener.onConnect(message.getPayload());
61+
connectionResponse.ifPresent(it -> session.getUserProperties().put(ApolloSubscriptionConnectionListener.CONNECT_RESULT_KEY, it));
62+
} catch (Throwable t) {
63+
sendMessage(session, OperationMessage.Type.GQL_CONNECTION_ERROR, t.getMessage());
64+
return;
65+
}
66+
5167
sendMessage(session, OperationMessage.Type.GQL_CONNECTION_ACK, message.getId());
52-
sendMessage(session, OperationMessage.Type.GQL_CONNECTION_KEEP_ALIVE, message.getId());
68+
69+
if (connectionListener.isKeepAliveEnabled()) {
70+
sendMessage(session, OperationMessage.Type.GQL_CONNECTION_KEEP_ALIVE, message.getId());
71+
}
5372
break;
5473

5574
case GQL_START:
75+
GraphQLSingleInvocationInput graphQLSingleInvocationInput = createInvocationInput(session, message);
5676
handleSubscriptionStart(
5777
session,
5878
subscriptions,
5979
message.id,
60-
input.getQueryInvoker().query(input.getInvocationInputFactory().create(
61-
input.getGraphQLObjectMapper().getJacksonMapper().convertValue(message.payload, GraphQLRequest.class),
62-
(HandshakeRequest) session.getUserProperties().get(HandshakeRequest.class.getName())
63-
))
80+
input.getQueryInvoker().query(graphQLSingleInvocationInput)
6481
);
6582
break;
6683

@@ -81,6 +98,16 @@ public void onMessage(HandshakeRequest request, Session session, WsSessionSubscr
8198
}
8299
}
83100

101+
private GraphQLSingleInvocationInput createInvocationInput(Session session, OperationMessage message) {
102+
GraphQLRequest graphQLRequest = input.getGraphQLObjectMapper()
103+
.getJacksonMapper()
104+
.convertValue(message.getPayload(), GraphQLRequest.class);
105+
HandshakeRequest handshakeRequest = (HandshakeRequest) session.getUserProperties()
106+
.get(HandshakeRequest.class.getName());
107+
108+
return input.getInvocationInputFactory().create(graphQLRequest, session, handshakeRequest);
109+
}
110+
84111
@SuppressWarnings("unchecked")
85112
private void handleSubscriptionStart(Session session, WsSessionSubscriptions subscriptions, String id, ExecutionResult executionResult) {
86113
executionResult = input.getGraphQLObjectMapper().sanitizeErrors(executionResult);

src/main/java/graphql/servlet/internal/SubscriptionHandlerInput.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,22 @@
33
import graphql.servlet.GraphQLInvocationInputFactory;
44
import graphql.servlet.GraphQLObjectMapper;
55
import graphql.servlet.GraphQLQueryInvoker;
6+
import graphql.servlet.SubscriptionConnectionListener;
7+
8+
import java.util.Optional;
69

710
public class SubscriptionHandlerInput {
811

912
private final GraphQLInvocationInputFactory invocationInputFactory;
1013
private final GraphQLQueryInvoker queryInvoker;
1114
private final GraphQLObjectMapper graphQLObjectMapper;
15+
private final SubscriptionConnectionListener subscriptionConnectionListener;
1216

13-
public SubscriptionHandlerInput(GraphQLInvocationInputFactory invocationInputFactory, GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQLObjectMapper) {
17+
public SubscriptionHandlerInput(GraphQLInvocationInputFactory invocationInputFactory, GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQLObjectMapper, SubscriptionConnectionListener subscriptionConnectionListener) {
1418
this.invocationInputFactory = invocationInputFactory;
1519
this.queryInvoker = queryInvoker;
1620
this.graphQLObjectMapper = graphQLObjectMapper;
21+
this.subscriptionConnectionListener = subscriptionConnectionListener;
1722
}
1823

1924
public GraphQLInvocationInputFactory getInvocationInputFactory() {
@@ -27,4 +32,8 @@ public GraphQLQueryInvoker getQueryInvoker() {
2732
public GraphQLObjectMapper getGraphQLObjectMapper() {
2833
return graphQLObjectMapper;
2934
}
35+
36+
public Optional<SubscriptionConnectionListener> getSubscriptionConnectionListener() {
37+
return Optional.ofNullable(subscriptionConnectionListener);
38+
}
3039
}

src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,7 +1021,7 @@ class AbstractGraphQLHttpServletSpec extends Specification {
10211021

10221022
setup:
10231023
Instrumentation expectedInstrumentation = Mock()
1024-
GraphQLContext context = new GraphQLContext(request, response, null, null)
1024+
GraphQLContext context = new GraphQLContext(request, response, null, null, null)
10251025
SimpleGraphQLHttpServlet simpleGraphQLServlet = SimpleGraphQLHttpServlet
10261026
.newBuilder(TestUtils.createGraphQlSchema())
10271027
.withQueryInvoker(GraphQLQueryInvoker.newBuilder().withInstrumentation(expectedInstrumentation).build())
@@ -1037,7 +1037,7 @@ class AbstractGraphQLHttpServletSpec extends Specification {
10371037
def "getInstrumentation returns the ChainedInstrumentation if DataLoader provided in context"() {
10381038
setup:
10391039
Instrumentation servletInstrumentation = Mock()
1040-
GraphQLContext context = new GraphQLContext(request, response, null, null)
1040+
GraphQLContext context = new GraphQLContext(request, response, null, null, null)
10411041
DataLoaderRegistry dlr = Mock()
10421042
context.setDataLoaderRegistry(dlr)
10431043
SimpleGraphQLHttpServlet simpleGraphQLServlet = SimpleGraphQLHttpServlet

0 commit comments

Comments
 (0)