Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use host and port from current transport config for sni #333

Merged
merged 2 commits into from
Sep 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@
import org.jetbrains.annotations.NotNull;

import javax.inject.Inject;
import javax.net.ssl.SSLException;
import java.net.URISyntaxException;

/**
* Initializes:
Expand Down Expand Up @@ -91,11 +89,11 @@ protected void initChannel(final @NotNull Channel channel) throws Exception {
final MqttClientTransportConfigImpl transportConfig = connAckFlow.getTransportConfig();
final MqttClientSslConfigImpl sslConfig = transportConfig.getRawSslConfig();
if (sslConfig != null) {
initSsl(channel, sslConfig);
SslUtil.initChannel(channel, sslConfig, transportConfig.getServerAddress());
}
final MqttWebSocketConfigImpl webSocketConfig = transportConfig.getRawWebSocketConfig();
if (webSocketConfig != null) {
initWebSocketMqtt(channel, webSocketConfig);
webSocketInitializer.get().initChannel(channel, webSocketConfig);
} else {
initMqtt(channel);
}
Expand All @@ -109,19 +107,6 @@ public void initMqtt(final @NotNull Channel channel) {
.addLast(MqttDisconnectHandler.NAME, disconnectHandler);
}

private void initWebSocketMqtt(
final @NotNull Channel channel, final @NotNull MqttWebSocketConfigImpl webSocketConfig)
throws URISyntaxException {

webSocketInitializer.get().initChannel(channel, webSocketConfig);
}

private void initSsl(final @NotNull Channel channel, final @NotNull MqttClientSslConfigImpl sslConfig)
throws SSLException {

SslUtil.initChannel(channel, sslConfig, clientConfig.getServerHost(), clientConfig.getServerPort());
}

@Override
public void exceptionCaught(final @NotNull ChannelHandlerContext ctx, final @NotNull Throwable cause) {
if (ctx.pipeline().get(MqttDisconnectHandler.NAME) != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import java.net.InetSocketAddress;

/**
* @author Christoph Schäbel
Expand All @@ -35,16 +36,16 @@ public final class SslUtil {

public static void initChannel(
final @NotNull Channel channel, final @NotNull MqttClientSslConfigImpl sslConfig,
final @NotNull String host, final int port) throws SSLException {
final @NotNull InetSocketAddress address) throws SSLException {

channel.pipeline().addFirst(SSL_HANDLER_NAME, createSslHandler(channel, sslConfig, host, port));
channel.pipeline().addFirst(SSL_HANDLER_NAME, createSslHandler(channel, sslConfig, address));
}

private static @NotNull SslHandler createSslHandler(
final @NotNull Channel channel, final @NotNull MqttClientSslConfigImpl sslConfig,
final @NotNull String host, final int port) throws SSLException {
final @NotNull InetSocketAddress address) throws SSLException {

return createSslContext(sslConfig).newHandler(channel.alloc(), host, port);
return createSslContext(sslConfig).newHandler(channel.alloc(), address.getHostString(), address.getPort());
}

static @NotNull SslContext createSslContext(final @NotNull MqttClientSslConfigImpl sslConfig) throws SSLException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import java.net.InetSocketAddress;

import static org.junit.Assert.assertNotNull;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
Expand All @@ -61,22 +62,24 @@ public class MqttChannelInitializerSslTest {
private MqttAuthHandler authHandler;
@Mock
private Lazy<MqttWebSocketInitializer> webSocketInitializer;
@Mock
private MqttClientTransportConfigImpl transportConfig;
@Mock
private MqttClientSslConfigImpl sslConfig;

private Channel channel;

@Before
public void before() {
MockitoAnnotations.initMocks(this);
channel = new EmbeddedChannel();
when(connAckFlow.getTransportConfig()).thenReturn(transportConfig);
when(transportConfig.getRawSslConfig()).thenReturn(sslConfig);
when(transportConfig.getServerAddress()).thenReturn(InetSocketAddress.createUnresolved("localhost", 1883));
}

@Test
public void test_initialize_default_ssldata() throws Exception {
final MqttClientTransportConfigImpl transportConfig = mock(MqttClientTransportConfigImpl.class);
final MqttClientSslConfigImpl sslConfig = mock(MqttClientSslConfigImpl.class);
when(connAckFlow.getTransportConfig()).thenReturn(transportConfig);
when(transportConfig.getRawSslConfig()).thenReturn(sslConfig);

final MqttChannelInitializer mqttChannelInitializer =
new MqttChannelInitializer(clientData, connect, connAckFlow, encoder, connectHandler, disconnectHandler,
authHandler, webSocketInitializer);
Expand All @@ -85,5 +88,4 @@ public void test_initialize_default_ssldata() throws Exception {

assertNotNull(channel.pipeline().get(SslHandler.class));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ private List<String> getEnabledProtocols() throws Exception {
return Arrays.asList(sslEngine.getEnabledProtocols());
}

static @NotNull SSLEngine createSslEngine(
private static @NotNull SSLEngine createSslEngine(
final @NotNull Channel channel, final @NotNull MqttClientSslConfigImpl sslConfig) throws SSLException {

return SslUtil.createSslContext(sslConfig).newEngine(channel.alloc());
Expand Down