Skip to content

Commit

Permalink
Validate shared subscription and closed connection if not valid (moqu…
Browse files Browse the repository at this point in the history
…ette-io#795)

During the processing of an MQTT5 SUBSCRIBE message search all topicFilters for shared subscriptions and check for validity respect to specs [MQTT-4.8.2-1] [MQTT-4.8.2-1]
  • Loading branch information
andsel authored Nov 18, 2023
1 parent 6a7bc49 commit 39b1688
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 10 deletions.
23 changes: 22 additions & 1 deletion broker/src/main/java/io/moquette/broker/MQTTConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ PostOffice.RouteResult processDisconnect(MqttMessage msg) {
LOG.debug("NOT processing disconnect {}, not bound.", clientID);
return null;
}
if (protocolVersion == MqttVersion.MQTT_5.protocolLevel()) {
if (isProtocolVersion5()) {
MqttReasonCodeAndPropertiesVariableHeader disconnectHeader = (MqttReasonCodeAndPropertiesVariableHeader) msg.variableHeader();
if (disconnectHeader.reasonCode() != MqttReasonCodes.Disconnect.NORMAL_DISCONNECT.byteValue()) {
// handle the will
Expand All @@ -549,6 +549,10 @@ PostOffice.RouteResult processDisconnect(MqttMessage msg) {
});
}

boolean isProtocolVersion5() {
return protocolVersion == MqttVersion.MQTT_5.protocolLevel();
}

PostOffice.RouteResult processSubscribe(MqttSubscribeMessage msg) {
final String clientID = NettyUtils.clientID(channel);
if (!connected) {
Expand Down Expand Up @@ -834,4 +838,21 @@ private boolean isSessionUnbound() {
public void bindSession(Session session) {
bindedSession = session;
}

/**
* Invoked internally by broker to disconnect a client and close the connection
* */
void brokerDisconnect() {
final MqttMessage disconnectMsg = MqttMessageBuilders.disconnect().build();
channel.writeAndFlush(disconnectMsg)
.addListener(ChannelFutureListener.CLOSE);
}

void brokerDisconnect(MqttReasonCodes.Disconnect reasonCode) {
final MqttMessage disconnectMsg = MqttMessageBuilders.disconnect()
.reasonCode(reasonCode.byteValue())
.build();
channel.writeAndFlush(disconnectMsg)
.addListener(ChannelFutureListener.CLOSE);
}
}
44 changes: 43 additions & 1 deletion broker/src/main/java/io/moquette/broker/PostOffice.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import io.netty.handler.codec.mqtt.MqttSubscribeMessage;
import io.netty.handler.codec.mqtt.MqttTopicSubscription;
import io.netty.util.ReferenceCountUtil;
import org.apache.commons.codec.binary.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -276,6 +277,21 @@ public void subscribeClientToTopics(MqttSubscribeMessage msg, String clientID, S
MQTTConnection mqttConnection) {
// verify which topics of the subscribe ongoing has read access permission
int messageID = messageId(msg);
final Session session = sessionRegistry.retrieve(clientID);

if (mqttConnection.isProtocolVersion5()) {
for (MqttTopicSubscription topicFilter : msg.payload().topicSubscriptions()) {
if (isSharedSubscription(topicFilter.topicName())) {
final String shareName = extractShareName(topicFilter.topicName());
if (!validateShareName(shareName)) {
// this is a malformed packet, MQTT-4.13.1-1, disconnect it
LOG.info("{} used an invalid shared subscription name {}, disconnecting", clientID, shareName);
session.disconnectFromBroker();
return;
}
}
}
}
List<MqttTopicSubscription> ackTopics = authorizator.verifyTopicsReadAccess(clientID, username, msg);
MqttSubAckMessage ackMessage = doAckMessageFromValidateFilters(ackTopics, messageID);

Expand All @@ -292,7 +308,6 @@ public void subscribeClientToTopics(MqttSubscribeMessage msg, String clientID, S
}

// add the subscriptions to Session
Session session = sessionRegistry.retrieve(clientID);
session.addSubscriptions(newSubscriptions);

// send ack message
Expand All @@ -305,6 +320,33 @@ public void subscribeClientToTopics(MqttSubscribeMessage msg, String clientID, S
}
}

/**
* @return the share name in the topic filter of format $share/{shareName}/{topicFilter}
* */
// VisibleForTesting
protected static String extractShareName(String sharedTopicFilter) {
int afterShare = "$share/".length();
int endOfShareName = sharedTopicFilter.indexOf('/', afterShare);
return sharedTopicFilter.substring(afterShare, endOfShareName);
}

/**
* @return true if shareName is well formed, is at least one characted and doesn't contain wildcard matchers
* */
private boolean validateShareName(String shareName) {
// MQTT-4.8.2-1 MQTT-4.8.2-2, must be longer than 1 char and do not contain + or #
Objects.requireNonNull(shareName);
return shareName.length() > 0 && !shareName.contains("+") && !shareName.contains("#");
}

/**
* @return true if topic filter is shared format
* */
private static boolean isSharedSubscription(String topicFilter) {
Objects.requireNonNull(topicFilter, "topicFilter can't be null");
return topicFilter.startsWith("$share/");
}

private void publishRetainedMessagesForSubscriptions(String clientID, List<Subscription> newSubscriptions) {
Session targetSession = this.sessionRegistry.retrieve(clientID);
for (Subscription subscription : newSubscriptions) {
Expand Down
16 changes: 9 additions & 7 deletions broker/src/main/java/io/moquette/broker/Session.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,7 @@
import io.moquette.broker.subscriptions.Subscription;
import io.moquette.broker.subscriptions.Topic;
import io.netty.buffer.ByteBuf;
import io.netty.handler.codec.mqtt.MqttFixedHeader;
import io.netty.handler.codec.mqtt.MqttMessage;
import io.netty.handler.codec.mqtt.MqttMessageType;
import io.netty.handler.codec.mqtt.MqttPublishMessage;
import io.netty.handler.codec.mqtt.MqttPublishVariableHeader;
import io.netty.handler.codec.mqtt.MqttQoS;
import io.netty.handler.codec.mqtt.MqttVersion;
import io.netty.handler.codec.mqtt.*;
import io.netty.util.ReferenceCountUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -516,6 +510,14 @@ ISessionsRepository.SessionData getSessionData() {
return this.data;
}

/**
* Disconnect the client from the broker, sending a disconnect and closing the connection
* */
public void disconnectFromBroker() {
mqttConnection.brokerDisconnect(MqttReasonCodes.Disconnect.MALFORMED_PACKET);
disconnect();
}

@Override
public String toString() {
return "Session{" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -344,4 +344,11 @@ public void testLowerTheQosToTheRequestedBySubscription() {
Subscription subQos2 = new Subscription("Sub B", new Topic("a/+"), EXACTLY_ONCE);
assertEquals(EXACTLY_ONCE, PostOffice.lowerQosToTheSubscriptionDesired(subQos2, EXACTLY_ONCE));
}

@Test
public void testExtractShareName() {
assertEquals("", PostOffice.extractShareName("$share//measures/+/1"));
assertEquals("myShared", PostOffice.extractShareName("$share/myShared/measures/+/1"));
assertEquals("#", PostOffice.extractShareName("$share/#/measures/+/1"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ public void receiveInflightPublishesAfterAReconnect() {
reconnectingSubscriber.disconnect();
}

private void assertConnectionAccepted(MqttConnAckMessage connAck, String message) {
public static void assertConnectionAccepted(MqttConnAckMessage connAck, String message) {
assertEquals(MqttConnectReturnCode.CONNECTION_ACCEPTED, connAck.variableHeader().connectReturnCode(), message);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package io.moquette.integration.mqtt5;

import com.hivemq.client.mqtt.MqttClient;
import com.hivemq.client.mqtt.mqtt5.Mqtt5BlockingClient;
import com.hivemq.client.mqtt.mqtt5.message.connect.connack.Mqtt5ConnAckReasonCode;
import io.netty.handler.codec.mqtt.MqttConnAckMessage;
import io.netty.handler.codec.mqtt.MqttMessage;
import io.netty.handler.codec.mqtt.MqttQoS;
import io.netty.handler.codec.mqtt.MqttReasonCodeAndPropertiesVariableHeader;
import io.netty.handler.codec.mqtt.MqttReasonCodes;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.Test;

import static io.moquette.integration.mqtt5.ConnectTest.assertConnectionAccepted;
import static org.junit.jupiter.api.Assertions.assertEquals;

public class SharedSubscriptionTest extends AbstractServerIntegrationTest {

@Override
public String clientName() {
return "subscriber";
}

@Test
public void givenAClientSendingBadlyFormattedSharedSubscriptionNameThenItIsDisconnected() {
MqttConnAckMessage connAck = lowLevelClient.connectV5();
assertConnectionAccepted(connAck, "Connection must be accepted");

MqttMessage received = lowLevelClient.subscribeWithError("$share/+/measures/temp", MqttQoS.AT_LEAST_ONCE);

// verify received is a disconnect with an error
MqttReasonCodeAndPropertiesVariableHeader disconnectHeader = (MqttReasonCodeAndPropertiesVariableHeader) received.variableHeader();
assertEquals(MqttReasonCodes.Disconnect.MALFORMED_PACKET.byteValue(), disconnectHeader.reasonCode());
}

@NotNull
private Mqtt5BlockingClient createSubscriberClient() {
final Mqtt5BlockingClient client = MqttClient.builder()
.useMqttVersion5()
.identifier(clientName())
.serverHost("localhost")
.serverPort(1883)
.buildBlocking();
assertEquals(Mqtt5ConnAckReasonCode.SUCCESS, client.connect().getReasonCode(), "Subscriber connected");
return client;
}
}
33 changes: 33 additions & 0 deletions broker/src/test/java/io/moquette/testclient/Client.java
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,39 @@ public MqttSubAckMessage subscribe(String topic, MqttQoS qos) {
return (MqttSubAckMessage) subAckMessage;
}

public MqttMessage subscribeWithError(String topic, MqttQoS qos) {
final MqttSubscribeMessage subscribeMessage = MqttMessageBuilders.subscribe()
.messageId(1)
.addSubscription(qos, topic)
.build();

final CountDownLatch subscribeAckLatch = new CountDownLatch(1);
this.setCallback(msg -> {
receivedMsg.getAndSet(msg);
LOG.debug("Subscribe callback invocation, received message {}", msg.fixedHeader().messageType());
subscribeAckLatch.countDown();

// clear the callback
setCallback(null);
});

LOG.debug("Sending SUBSCRIBE message");
sendMessage(subscribeMessage);
LOG.debug("Sent SUBSCRIBE message");

boolean waitElapsed;
try {
waitElapsed = !subscribeAckLatch.await(200, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
throw new RuntimeException("Interrupted while waiting", e);
}

if (waitElapsed) {
throw new RuntimeException("Cannot receive SubscribeAck in 200 ms");
}
return this.receivedMsg.get();
}

public void disconnect() {
final MqttMessage disconnectMessage = MqttMessageBuilders.disconnect().build();
sendMessage(disconnectMessage);
Expand Down

0 comments on commit 39b1688

Please sign in to comment.