Skip to content

Commit 841bf58

Browse files
authored
Merge pull request #112 from graphql-java-kickstart/bugfix/91
Fix Apollo keep alive implementation
2 parents 77526ec + a583b8b commit 841bf58

12 files changed

+203
-51
lines changed
Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,42 @@
11
package graphql.servlet;
22

3+
import java.time.Duration;
34
import java.util.Optional;
45

56
public interface ApolloSubscriptionConnectionListener extends SubscriptionConnectionListener {
67

8+
long KEEP_ALIVE_INTERVAL_SEC = 15;
9+
710
String CONNECT_RESULT_KEY = "CONNECT_RESULT";
811

912
default boolean isKeepAliveEnabled() {
10-
return false;
13+
return true;
1114
}
1215

13-
default Optional<Object> onConnect(Object payload) {
16+
default Optional<Object> onConnect(Object payload) throws SubscriptionException {
1417
return Optional.empty();
1518
}
1619

20+
default Duration getKeepAliveInterval() {
21+
return Duration.ofSeconds(KEEP_ALIVE_INTERVAL_SEC);
22+
}
23+
24+
static ApolloSubscriptionConnectionListener createWithKeepAliveDisabled() {
25+
return new ApolloSubscriptionConnectionListener() {
26+
@Override
27+
public boolean isKeepAliveEnabled() {
28+
return false;
29+
}
30+
};
31+
}
32+
33+
static ApolloSubscriptionConnectionListener createWithKeepAliveInterval(Duration interval) {
34+
return new ApolloSubscriptionConnectionListener() {
35+
@Override
36+
public Duration getKeepAliveInterval() {
37+
return interval;
38+
}
39+
};
40+
}
41+
1742
}

src/main/java/graphql/servlet/GraphQLWebsocketServlet.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,9 @@ public class GraphQLWebsocketServlet extends Endpoint {
3030
private static final CloseReason ERROR_CLOSE_REASON = new CloseReason(CloseReason.CloseCodes.UNEXPECTED_CONDITION, "Internal Server Error");
3131
private static final CloseReason SHUTDOWN_CLOSE_REASON = new CloseReason(CloseReason.CloseCodes.UNEXPECTED_CONDITION, "Server Shut Down");
3232

33-
private static final List<SubscriptionProtocolFactory> subscriptionProtocolFactories = Collections.singletonList(new ApolloSubscriptionProtocolFactory());
34-
private static final SubscriptionProtocolFactory fallbackSubscriptionProtocolFactory = new FallbackSubscriptionProtocolFactory();
35-
private static final List<String> allSubscriptionProtocols;
36-
37-
static {
38-
allSubscriptionProtocols = Stream.concat(subscriptionProtocolFactories.stream(), Stream.of(fallbackSubscriptionProtocolFactory))
39-
.map(SubscriptionProtocolFactory::getProtocol)
40-
.collect(Collectors.toList());
41-
}
33+
private final List<SubscriptionProtocolFactory> subscriptionProtocolFactories;
34+
private final SubscriptionProtocolFactory fallbackSubscriptionProtocolFactory;
35+
private final List<String> allSubscriptionProtocols;
4236

4337
private final Map<Session, WsSessionSubscriptions> sessionSubscriptionCache = new HashMap<>();
4438
private final SubscriptionHandlerInput subscriptionHandlerInput;
@@ -52,6 +46,12 @@ public GraphQLWebsocketServlet(GraphQLQueryInvoker queryInvoker, GraphQLInvocati
5246

5347
public GraphQLWebsocketServlet(GraphQLQueryInvoker queryInvoker, GraphQLInvocationInputFactory invocationInputFactory, GraphQLObjectMapper graphQLObjectMapper, SubscriptionConnectionListener subscriptionConnectionListener) {
5448
this.subscriptionHandlerInput = new SubscriptionHandlerInput(invocationInputFactory, queryInvoker, graphQLObjectMapper, subscriptionConnectionListener);
49+
50+
subscriptionProtocolFactories = Collections.singletonList(new ApolloSubscriptionProtocolFactory(subscriptionHandlerInput));
51+
fallbackSubscriptionProtocolFactory = new FallbackSubscriptionProtocolFactory(subscriptionHandlerInput);
52+
allSubscriptionProtocols = Stream.concat(subscriptionProtocolFactories.stream(), Stream.of(fallbackSubscriptionProtocolFactory))
53+
.map(SubscriptionProtocolFactory::getProtocol)
54+
.collect(Collectors.toList());
5555
}
5656

5757
@Override
@@ -119,7 +119,7 @@ public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request,
119119
}
120120

121121
SubscriptionProtocolFactory subscriptionProtocolFactory = getSubscriptionProtocolFactory(protocol);
122-
sec.getUserProperties().put(PROTOCOL_HANDLER_REQUEST_KEY, subscriptionProtocolFactory.createHandler(subscriptionHandlerInput));
122+
sec.getUserProperties().put(PROTOCOL_HANDLER_REQUEST_KEY, subscriptionProtocolFactory.createHandler());
123123

124124
if (request.getHeaders().get(HandshakeResponse.SEC_WEBSOCKET_ACCEPT) != null) {
125125
response.getHeaders().put(HandshakeResponse.SEC_WEBSOCKET_ACCEPT, allSubscriptionProtocols);
@@ -165,7 +165,7 @@ public boolean isShutDown() {
165165
return isShutDown.get();
166166
}
167167

168-
private static SubscriptionProtocolFactory getSubscriptionProtocolFactory(List<String> accept) {
168+
private SubscriptionProtocolFactory getSubscriptionProtocolFactory(List<String> accept) {
169169
for (String protocol : accept) {
170170
for (SubscriptionProtocolFactory subscriptionProtocolFactory : subscriptionProtocolFactories) {
171171
if (subscriptionProtocolFactory.getProtocol().equals(protocol)) {
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package graphql.servlet;
2+
3+
public class SubscriptionException extends Exception {
4+
5+
private Object payload;
6+
7+
public SubscriptionException() {
8+
}
9+
10+
public SubscriptionException(Object payload) {
11+
this.payload = payload;
12+
}
13+
14+
public Object getPayload() {
15+
return payload;
16+
}
17+
18+
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package graphql.servlet.internal;
2+
3+
import org.slf4j.Logger;
4+
import org.slf4j.LoggerFactory;
5+
6+
import javax.websocket.Session;
7+
import java.time.Duration;
8+
import java.util.Map;
9+
import java.util.Objects;
10+
import java.util.concurrent.ConcurrentHashMap;
11+
import java.util.concurrent.Executors;
12+
import java.util.concurrent.Future;
13+
import java.util.concurrent.ScheduledExecutorService;
14+
import java.util.concurrent.ScheduledFuture;
15+
import java.util.concurrent.TimeUnit;
16+
17+
class ApolloSubscriptionKeepAliveRunner {
18+
19+
private static final Logger LOG = LoggerFactory.getLogger(ApolloSubscriptionKeepAliveRunner.class);
20+
21+
private static final int EXECUTOR_POOL_SIZE = 10;
22+
23+
private final ScheduledExecutorService executor;
24+
private final SubscriptionSender sender;
25+
private final ApolloSubscriptionProtocolHandler.OperationMessage keepAliveMessage;
26+
private final Map<Session, Future<?>> futures;
27+
private final long keepAliveIntervalSeconds;
28+
29+
ApolloSubscriptionKeepAliveRunner(SubscriptionSender sender, Duration keepAliveInterval) {
30+
this.sender = Objects.requireNonNull(sender);
31+
this.keepAliveMessage = ApolloSubscriptionProtocolHandler.OperationMessage.newKeepAliveMessage();
32+
this.executor = Executors.newScheduledThreadPool(EXECUTOR_POOL_SIZE);
33+
this.futures = new ConcurrentHashMap<>();
34+
this.keepAliveIntervalSeconds = keepAliveInterval.getSeconds();
35+
}
36+
37+
void keepAlive(Session session) {
38+
futures.computeIfAbsent(session, this::startKeepAlive);
39+
}
40+
41+
private ScheduledFuture<?> startKeepAlive(Session session) {
42+
return executor.scheduleAtFixedRate(() -> {
43+
try {
44+
if (session.isOpen()) {
45+
sender.send(session, keepAliveMessage);
46+
} else {
47+
LOG.warn("Session appears to be closed. Aborting keep alive");
48+
abort(session);
49+
}
50+
} catch (Throwable t) {
51+
LOG.error("Cannot send keep alive message. Aborting keep alive", t);
52+
abort(session);
53+
}
54+
}, 0, keepAliveIntervalSeconds, TimeUnit.SECONDS);
55+
}
56+
57+
void abort(Session session) {
58+
Future<?> future = futures.remove(session);
59+
if (future != null) {
60+
future.cancel(true);
61+
}
62+
}
63+
64+
}
Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,30 @@
11
package graphql.servlet.internal;
22

3+
import graphql.servlet.ApolloSubscriptionConnectionListener;
4+
35
/**
46
* @author Andrew Potter
57
*/
68
public class ApolloSubscriptionProtocolFactory extends SubscriptionProtocolFactory {
7-
public ApolloSubscriptionProtocolFactory() {
9+
private final SubscriptionHandlerInput subscriptionHandlerInput;
10+
private final SubscriptionSender subscriptionSender;
11+
private final ApolloSubscriptionKeepAliveRunner keepAliveRunner;
12+
private final ApolloSubscriptionConnectionListener connectionListener;
13+
14+
public ApolloSubscriptionProtocolFactory(SubscriptionHandlerInput subscriptionHandlerInput) {
815
super("graphql-ws");
16+
this.subscriptionHandlerInput = subscriptionHandlerInput;
17+
this.connectionListener = subscriptionHandlerInput.getSubscriptionConnectionListener()
18+
.filter(ApolloSubscriptionConnectionListener.class::isInstance)
19+
.map(ApolloSubscriptionConnectionListener.class::cast)
20+
.orElse(new ApolloSubscriptionConnectionListener() {});
21+
subscriptionSender =
22+
new SubscriptionSender(subscriptionHandlerInput.getGraphQLObjectMapper().getJacksonMapper());
23+
keepAliveRunner = new ApolloSubscriptionKeepAliveRunner(subscriptionSender, connectionListener.getKeepAliveInterval());
924
}
1025

1126
@Override
12-
public SubscriptionProtocolHandler createHandler(SubscriptionHandlerInput subscriptionHandlerInput) {
13-
return new ApolloSubscriptionProtocolHandler(subscriptionHandlerInput);
27+
public SubscriptionProtocolHandler createHandler() {
28+
return new ApolloSubscriptionProtocolHandler(subscriptionHandlerInput, connectionListener, subscriptionSender, keepAliveRunner);
1429
}
1530
}

src/main/java/graphql/servlet/internal/ApolloSubscriptionProtocolHandler.java

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import graphql.ExecutionResult;
77
import graphql.servlet.ApolloSubscriptionConnectionListener;
88
import graphql.servlet.GraphQLSingleInvocationInput;
9+
import graphql.servlet.SubscriptionException;
910
import org.slf4j.Logger;
1011
import org.slf4j.LoggerFactory;
1112

@@ -33,14 +34,18 @@ public class ApolloSubscriptionProtocolHandler extends SubscriptionProtocolHandl
3334
private static final CloseReason TERMINATE_CLOSE_REASON = new CloseReason(CloseReason.CloseCodes.NORMAL_CLOSURE, "client requested " + GQL_CONNECTION_TERMINATE.getType());
3435

3536
private final SubscriptionHandlerInput input;
37+
private final SubscriptionSender sender;
38+
private final ApolloSubscriptionKeepAliveRunner keepAliveRunner;
3639
private final ApolloSubscriptionConnectionListener connectionListener;
3740

38-
public ApolloSubscriptionProtocolHandler(SubscriptionHandlerInput subscriptionHandlerInput) {
41+
public ApolloSubscriptionProtocolHandler(SubscriptionHandlerInput subscriptionHandlerInput,
42+
ApolloSubscriptionConnectionListener connectionListener,
43+
SubscriptionSender subscriptionSender,
44+
ApolloSubscriptionKeepAliveRunner keepAliveRunner) {
3945
this.input = subscriptionHandlerInput;
40-
this.connectionListener = subscriptionHandlerInput.getSubscriptionConnectionListener()
41-
.filter(ApolloSubscriptionConnectionListener.class::isInstance)
42-
.map(ApolloSubscriptionConnectionListener.class::cast)
43-
.orElse(new ApolloSubscriptionConnectionListener() {});
46+
this.connectionListener = connectionListener;
47+
this.sender = subscriptionSender;
48+
this.keepAliveRunner = keepAliveRunner;
4449
}
4550

4651
@Override
@@ -54,20 +59,20 @@ public void onMessage(HandshakeRequest request, Session session, WsSessionSubscr
5459
return;
5560
}
5661

57-
switch(message.getType()) {
62+
switch (message.getType()) {
5863
case GQL_CONNECTION_INIT:
5964
try {
6065
Optional<Object> connectionResponse = connectionListener.onConnect(message.getPayload());
6166
connectionResponse.ifPresent(it -> session.getUserProperties().put(ApolloSubscriptionConnectionListener.CONNECT_RESULT_KEY, it));
62-
} catch (Throwable t) {
63-
sendMessage(session, OperationMessage.Type.GQL_CONNECTION_ERROR, t.getMessage());
67+
} catch (SubscriptionException e) {
68+
sendMessage(session, OperationMessage.Type.GQL_CONNECTION_ERROR, message.getId(), e.getPayload());
6469
return;
6570
}
6671

6772
sendMessage(session, OperationMessage.Type.GQL_CONNECTION_ACK, message.getId());
6873

6974
if (connectionListener.isKeepAliveEnabled()) {
70-
sendMessage(session, OperationMessage.Type.GQL_CONNECTION_KEEP_ALIVE, message.getId());
75+
keepAliveRunner.keepAlive(session);
7176
}
7277
break;
7378

@@ -86,6 +91,7 @@ public void onMessage(HandshakeRequest request, Session session, WsSessionSubscr
8691
break;
8792

8893
case GQL_CONNECTION_TERMINATE:
94+
keepAliveRunner.abort(session);
8995
try {
9096
session.close(TERMINATE_CLOSE_REASON);
9197
} catch (IOException e) {
@@ -112,7 +118,7 @@ private GraphQLSingleInvocationInput createInvocationInput(Session session, Oper
112118
private void handleSubscriptionStart(Session session, WsSessionSubscriptions subscriptions, String id, ExecutionResult executionResult) {
113119
executionResult = input.getGraphQLObjectMapper().sanitizeErrors(executionResult);
114120

115-
if(input.getGraphQLObjectMapper().areErrorsPresent(executionResult)) {
121+
if (input.getGraphQLObjectMapper().areErrorsPresent(executionResult)) {
116122
sendMessage(session, OperationMessage.Type.GQL_ERROR, id, input.getGraphQLObjectMapper().convertSanitizedExecutionResult(executionResult, false));
117123
return;
118124
}
@@ -127,11 +133,13 @@ protected void sendDataMessage(Session session, String id, Object payload) {
127133

128134
@Override
129135
protected void sendErrorMessage(Session session, String id) {
136+
keepAliveRunner.abort(session);
130137
sendMessage(session, GQL_ERROR, id);
131138
}
132139

133140
@Override
134141
protected void sendCompleteMessage(Session session, String id) {
142+
keepAliveRunner.abort(session);
135143
sendMessage(session, GQL_COMPLETE, id);
136144
}
137145

@@ -140,13 +148,7 @@ private void sendMessage(Session session, OperationMessage.Type type, String id)
140148
}
141149

142150
private void sendMessage(Session session, OperationMessage.Type type, String id, Object payload) {
143-
try {
144-
session.getBasicRemote().sendText(input.getGraphQLObjectMapper().getJacksonMapper().writeValueAsString(
145-
new OperationMessage(type, id, payload)
146-
));
147-
} catch (IOException e) {
148-
throw new RuntimeException("Error sending subscription response", e);
149-
}
151+
sender.send(session, new OperationMessage(type, id, payload));
150152
}
151153

152154
@JsonInclude(JsonInclude.Include.NON_NULL)
@@ -164,6 +166,10 @@ public OperationMessage(Type type, String id, Object payload) {
164166
this.payload = payload;
165167
}
166168

169+
static OperationMessage newKeepAliveMessage() {
170+
return new OperationMessage(Type.GQL_CONNECTION_KEEP_ALIVE, null, null);
171+
}
172+
167173
public Type getType() {
168174
return type;
169175
}

src/main/java/graphql/servlet/internal/FallbackSubscriptionProtocolFactory.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
* @author Andrew Potter
55
*/
66
public class FallbackSubscriptionProtocolFactory extends SubscriptionProtocolFactory {
7-
public FallbackSubscriptionProtocolFactory() {
7+
private final SubscriptionHandlerInput subscriptionHandlerInput;
8+
9+
public FallbackSubscriptionProtocolFactory(SubscriptionHandlerInput subscriptionHandlerInput) {
810
super("");
11+
this.subscriptionHandlerInput = subscriptionHandlerInput;
912
}
1013

1114
@Override
12-
public SubscriptionProtocolHandler createHandler(SubscriptionHandlerInput subscriptionHandlerInput) {
15+
public SubscriptionProtocolHandler createHandler() {
1316
return new FallbackSubscriptionProtocolHandler(subscriptionHandlerInput);
1417
}
1518
}

src/main/java/graphql/servlet/internal/FallbackSubscriptionProtocolHandler.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
public class FallbackSubscriptionProtocolHandler extends SubscriptionProtocolHandler {
1212

1313
private final SubscriptionHandlerInput input;
14+
private final SubscriptionSender sender;
1415

1516
public FallbackSubscriptionProtocolHandler(SubscriptionHandlerInput subscriptionHandlerInput) {
1617
this.input = subscriptionHandlerInput;
18+
sender = new SubscriptionSender(subscriptionHandlerInput.getGraphQLObjectMapper().getJacksonMapper());
1719
}
1820

1921
@Override
@@ -32,11 +34,7 @@ public void onMessage(HandshakeRequest request, Session session, WsSessionSubscr
3234

3335
@Override
3436
protected void sendDataMessage(Session session, String id, Object payload) {
35-
try {
36-
session.getBasicRemote().sendText(input.getGraphQLObjectMapper().getJacksonMapper().writeValueAsString(payload));
37-
} catch (IOException e) {
38-
throw new RuntimeException("Error sending subscription response", e);
39-
}
37+
sender.send(session, payload);
4038
}
4139

4240
@Override

src/main/java/graphql/servlet/internal/SubscriptionProtocolFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,5 @@ public String getProtocol() {
1414
return protocol;
1515
}
1616

17-
public abstract SubscriptionProtocolHandler createHandler(SubscriptionHandlerInput subscriptionHandlerInput);
17+
public abstract SubscriptionProtocolHandler createHandler();
1818
}

src/main/java/graphql/servlet/internal/SubscriptionProtocolHandler.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,13 @@ public void onNext(ExecutionResult executionResult) {
5555
@Override
5656
public void onError(Throwable throwable) {
5757
log.error("Subscription error", throwable);
58-
subscriptions.cancel(id);
58+
unsubscribe(subscriptions, id);
5959
sendErrorMessage(session, id);
6060
}
6161

6262
@Override
6363
public void onComplete() {
64-
subscriptions.cancel(id);
64+
unsubscribe(subscriptions, id);
6565
sendCompleteMessage(session, id);
6666
}
6767
});

0 commit comments

Comments
 (0)