diff --git a/broker/src/main/java/io/moquette/broker/MQTTConnection.java b/broker/src/main/java/io/moquette/broker/MQTTConnection.java index 7aa8b6bb8..91f55af8b 100644 --- a/broker/src/main/java/io/moquette/broker/MQTTConnection.java +++ b/broker/src/main/java/io/moquette/broker/MQTTConnection.java @@ -528,7 +528,7 @@ PostOffice.RouteResult processDisconnect(MqttMessage msg) { LOG.debug("NOT processing disconnect {}, not bound.", clientID); return null; } - if (protocolVersion == MqttVersion.MQTT_5.protocolLevel()) { + if (isProtocolVersion5()) { MqttReasonCodeAndPropertiesVariableHeader disconnectHeader = (MqttReasonCodeAndPropertiesVariableHeader) msg.variableHeader(); if (disconnectHeader.reasonCode() != MqttReasonCodes.Disconnect.NORMAL_DISCONNECT.byteValue()) { // handle the will @@ -549,6 +549,10 @@ PostOffice.RouteResult processDisconnect(MqttMessage msg) { }); } + boolean isProtocolVersion5() { + return protocolVersion == MqttVersion.MQTT_5.protocolLevel(); + } + PostOffice.RouteResult processSubscribe(MqttSubscribeMessage msg) { final String clientID = NettyUtils.clientID(channel); if (!connected) { @@ -834,4 +838,21 @@ private boolean isSessionUnbound() { public void bindSession(Session session) { bindedSession = session; } + + /** + * Invoked internally by broker to disconnect a client and close the connection + * */ + void brokerDisconnect() { + final MqttMessage disconnectMsg = MqttMessageBuilders.disconnect().build(); + channel.writeAndFlush(disconnectMsg) + .addListener(ChannelFutureListener.CLOSE); + } + + void brokerDisconnect(MqttReasonCodes.Disconnect reasonCode) { + final MqttMessage disconnectMsg = MqttMessageBuilders.disconnect() + .reasonCode(reasonCode.byteValue()) + .build(); + channel.writeAndFlush(disconnectMsg) + .addListener(ChannelFutureListener.CLOSE); + } } diff --git a/broker/src/main/java/io/moquette/broker/PostOffice.java b/broker/src/main/java/io/moquette/broker/PostOffice.java index dcbec9412..86f30ea83 100644 --- a/broker/src/main/java/io/moquette/broker/PostOffice.java +++ b/broker/src/main/java/io/moquette/broker/PostOffice.java @@ -32,6 +32,7 @@ import io.netty.handler.codec.mqtt.MqttSubscribeMessage; import io.netty.handler.codec.mqtt.MqttTopicSubscription; import io.netty.util.ReferenceCountUtil; +import org.apache.commons.codec.binary.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -276,6 +277,21 @@ public void subscribeClientToTopics(MqttSubscribeMessage msg, String clientID, S MQTTConnection mqttConnection) { // verify which topics of the subscribe ongoing has read access permission int messageID = messageId(msg); + final Session session = sessionRegistry.retrieve(clientID); + + if (mqttConnection.isProtocolVersion5()) { + for (MqttTopicSubscription topicFilter : msg.payload().topicSubscriptions()) { + if (isSharedSubscription(topicFilter.topicName())) { + final String shareName = extractShareName(topicFilter.topicName()); + if (!validateShareName(shareName)) { + // this is a malformed packet, MQTT-4.13.1-1, disconnect it + LOG.info("{} used an invalid shared subscription name {}, disconnecting", clientID, shareName); + session.disconnectFromBroker(); + return; + } + } + } + } List ackTopics = authorizator.verifyTopicsReadAccess(clientID, username, msg); MqttSubAckMessage ackMessage = doAckMessageFromValidateFilters(ackTopics, messageID); @@ -292,7 +308,6 @@ public void subscribeClientToTopics(MqttSubscribeMessage msg, String clientID, S } // add the subscriptions to Session - Session session = sessionRegistry.retrieve(clientID); session.addSubscriptions(newSubscriptions); // send ack message @@ -305,6 +320,33 @@ public void subscribeClientToTopics(MqttSubscribeMessage msg, String clientID, S } } + /** + * @return the share name in the topic filter of format $share/{shareName}/{topicFilter} + * */ + // VisibleForTesting + protected static String extractShareName(String sharedTopicFilter) { + int afterShare = "$share/".length(); + int endOfShareName = sharedTopicFilter.indexOf('/', afterShare); + return sharedTopicFilter.substring(afterShare, endOfShareName); + } + + /** + * @return true if shareName is well formed, is at least one characted and doesn't contain wildcard matchers + * */ + private boolean validateShareName(String shareName) { + // MQTT-4.8.2-1 MQTT-4.8.2-2, must be longer than 1 char and do not contain + or # + Objects.requireNonNull(shareName); + return shareName.length() > 0 && !shareName.contains("+") && !shareName.contains("#"); + } + + /** + * @return true if topic filter is shared format + * */ + private static boolean isSharedSubscription(String topicFilter) { + Objects.requireNonNull(topicFilter, "topicFilter can't be null"); + return topicFilter.startsWith("$share/"); + } + private void publishRetainedMessagesForSubscriptions(String clientID, List newSubscriptions) { Session targetSession = this.sessionRegistry.retrieve(clientID); for (Subscription subscription : newSubscriptions) { diff --git a/broker/src/main/java/io/moquette/broker/Session.java b/broker/src/main/java/io/moquette/broker/Session.java index 4d0a7a937..8fde1cf75 100644 --- a/broker/src/main/java/io/moquette/broker/Session.java +++ b/broker/src/main/java/io/moquette/broker/Session.java @@ -22,13 +22,7 @@ import io.moquette.broker.subscriptions.Subscription; import io.moquette.broker.subscriptions.Topic; import io.netty.buffer.ByteBuf; -import io.netty.handler.codec.mqtt.MqttFixedHeader; -import io.netty.handler.codec.mqtt.MqttMessage; -import io.netty.handler.codec.mqtt.MqttMessageType; -import io.netty.handler.codec.mqtt.MqttPublishMessage; -import io.netty.handler.codec.mqtt.MqttPublishVariableHeader; -import io.netty.handler.codec.mqtt.MqttQoS; -import io.netty.handler.codec.mqtt.MqttVersion; +import io.netty.handler.codec.mqtt.*; import io.netty.util.ReferenceCountUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -516,6 +510,14 @@ ISessionsRepository.SessionData getSessionData() { return this.data; } + /** + * Disconnect the client from the broker, sending a disconnect and closing the connection + * */ + public void disconnectFromBroker() { + mqttConnection.brokerDisconnect(MqttReasonCodes.Disconnect.MALFORMED_PACKET); + disconnect(); + } + @Override public String toString() { return "Session{" + diff --git a/broker/src/test/java/io/moquette/broker/PostOfficeSubscribeTest.java b/broker/src/test/java/io/moquette/broker/PostOfficeSubscribeTest.java index 8d33bdb99..cfc095cf8 100644 --- a/broker/src/test/java/io/moquette/broker/PostOfficeSubscribeTest.java +++ b/broker/src/test/java/io/moquette/broker/PostOfficeSubscribeTest.java @@ -344,4 +344,11 @@ public void testLowerTheQosToTheRequestedBySubscription() { Subscription subQos2 = new Subscription("Sub B", new Topic("a/+"), EXACTLY_ONCE); assertEquals(EXACTLY_ONCE, PostOffice.lowerQosToTheSubscriptionDesired(subQos2, EXACTLY_ONCE)); } + + @Test + public void testExtractShareName() { + assertEquals("", PostOffice.extractShareName("$share//measures/+/1")); + assertEquals("myShared", PostOffice.extractShareName("$share/myShared/measures/+/1")); + assertEquals("#", PostOffice.extractShareName("$share/#/measures/+/1")); + } } diff --git a/broker/src/test/java/io/moquette/integration/mqtt5/ConnectTest.java b/broker/src/test/java/io/moquette/integration/mqtt5/ConnectTest.java index 83a7e1782..290030ee9 100644 --- a/broker/src/test/java/io/moquette/integration/mqtt5/ConnectTest.java +++ b/broker/src/test/java/io/moquette/integration/mqtt5/ConnectTest.java @@ -128,7 +128,7 @@ public void receiveInflightPublishesAfterAReconnect() { reconnectingSubscriber.disconnect(); } - private void assertConnectionAccepted(MqttConnAckMessage connAck, String message) { + public static void assertConnectionAccepted(MqttConnAckMessage connAck, String message) { assertEquals(MqttConnectReturnCode.CONNECTION_ACCEPTED, connAck.variableHeader().connectReturnCode(), message); } diff --git a/broker/src/test/java/io/moquette/integration/mqtt5/SharedSubscriptionTest.java b/broker/src/test/java/io/moquette/integration/mqtt5/SharedSubscriptionTest.java new file mode 100644 index 000000000..ad760edb3 --- /dev/null +++ b/broker/src/test/java/io/moquette/integration/mqtt5/SharedSubscriptionTest.java @@ -0,0 +1,47 @@ +package io.moquette.integration.mqtt5; + +import com.hivemq.client.mqtt.MqttClient; +import com.hivemq.client.mqtt.mqtt5.Mqtt5BlockingClient; +import com.hivemq.client.mqtt.mqtt5.message.connect.connack.Mqtt5ConnAckReasonCode; +import io.netty.handler.codec.mqtt.MqttConnAckMessage; +import io.netty.handler.codec.mqtt.MqttMessage; +import io.netty.handler.codec.mqtt.MqttQoS; +import io.netty.handler.codec.mqtt.MqttReasonCodeAndPropertiesVariableHeader; +import io.netty.handler.codec.mqtt.MqttReasonCodes; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.Test; + +import static io.moquette.integration.mqtt5.ConnectTest.assertConnectionAccepted; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class SharedSubscriptionTest extends AbstractServerIntegrationTest { + + @Override + public String clientName() { + return "subscriber"; + } + + @Test + public void givenAClientSendingBadlyFormattedSharedSubscriptionNameThenItIsDisconnected() { + MqttConnAckMessage connAck = lowLevelClient.connectV5(); + assertConnectionAccepted(connAck, "Connection must be accepted"); + + MqttMessage received = lowLevelClient.subscribeWithError("$share/+/measures/temp", MqttQoS.AT_LEAST_ONCE); + + // verify received is a disconnect with an error + MqttReasonCodeAndPropertiesVariableHeader disconnectHeader = (MqttReasonCodeAndPropertiesVariableHeader) received.variableHeader(); + assertEquals(MqttReasonCodes.Disconnect.MALFORMED_PACKET.byteValue(), disconnectHeader.reasonCode()); + } + + @NotNull + private Mqtt5BlockingClient createSubscriberClient() { + final Mqtt5BlockingClient client = MqttClient.builder() + .useMqttVersion5() + .identifier(clientName()) + .serverHost("localhost") + .serverPort(1883) + .buildBlocking(); + assertEquals(Mqtt5ConnAckReasonCode.SUCCESS, client.connect().getReasonCode(), "Subscriber connected"); + return client; + } +} diff --git a/broker/src/test/java/io/moquette/testclient/Client.java b/broker/src/test/java/io/moquette/testclient/Client.java index e98174c40..5cad95f89 100644 --- a/broker/src/test/java/io/moquette/testclient/Client.java +++ b/broker/src/test/java/io/moquette/testclient/Client.java @@ -218,6 +218,39 @@ public MqttSubAckMessage subscribe(String topic, MqttQoS qos) { return (MqttSubAckMessage) subAckMessage; } + public MqttMessage subscribeWithError(String topic, MqttQoS qos) { + final MqttSubscribeMessage subscribeMessage = MqttMessageBuilders.subscribe() + .messageId(1) + .addSubscription(qos, topic) + .build(); + + final CountDownLatch subscribeAckLatch = new CountDownLatch(1); + this.setCallback(msg -> { + receivedMsg.getAndSet(msg); + LOG.debug("Subscribe callback invocation, received message {}", msg.fixedHeader().messageType()); + subscribeAckLatch.countDown(); + + // clear the callback + setCallback(null); + }); + + LOG.debug("Sending SUBSCRIBE message"); + sendMessage(subscribeMessage); + LOG.debug("Sent SUBSCRIBE message"); + + boolean waitElapsed; + try { + waitElapsed = !subscribeAckLatch.await(200, TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + throw new RuntimeException("Interrupted while waiting", e); + } + + if (waitElapsed) { + throw new RuntimeException("Cannot receive SubscribeAck in 200 ms"); + } + return this.receivedMsg.get(); + } + public void disconnect() { final MqttMessage disconnectMessage = MqttMessageBuilders.disconnect().build(); sendMessage(disconnectMessage);