From eec0eed9a3586fd423c07958b0dba0d48adb6f01 Mon Sep 17 00:00:00 2001 From: Valentin Kovalenko Date: Thu, 6 Jan 2022 09:34:27 -0700 Subject: [PATCH] Merge changes from tls-channel for race condition manifested when closing async sockets right after creation (#851) This is a backport of https://github.com/mongodb/mongo-java-driver/pull/848 JAVA-4417 --- .../async/AsynchronousTlsChannelGroup.java | 161 ++++++++++++------ 1 file changed, 106 insertions(+), 55 deletions(-) diff --git a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/async/AsynchronousTlsChannelGroup.java b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/async/AsynchronousTlsChannelGroup.java index 23f18659ec8..27660c0babd 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/async/AsynchronousTlsChannelGroup.java +++ b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/async/AsynchronousTlsChannelGroup.java @@ -39,7 +39,7 @@ import java.nio.channels.SocketChannel; import java.nio.channels.WritePendingException; import java.util.Iterator; -import java.util.concurrent.CancellationException; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; @@ -100,19 +100,15 @@ class RegisteredSocket { /** Bitwise union of pending operation to be registered in the selector */ final AtomicInteger pendingOps = new AtomicInteger(); - RegisteredSocket(TlsChannel tlsChannel, SocketChannel socketChannel) - throws ClosedChannelException { + RegisteredSocket(TlsChannel tlsChannel, SocketChannel socketChannel) { this.tlsChannel = tlsChannel; this.socketChannel = socketChannel; } public void close() { - doCancelRead(this, null); - doCancelWrite(this, null); if (key != null) { key.cancel(); } - currentRegistrations.getAndDecrement(); /* * Actual de-registration from the selector will happen asynchronously. */ @@ -195,8 +191,7 @@ private enum Shutdown { private LongAdder cancelledReads = new LongAdder(); private LongAdder cancelledWrites = new LongAdder(); - // used for synchronization - private AtomicInteger currentRegistrations = new AtomicInteger(); + private final ConcurrentHashMap registrations = new ConcurrentHashMap<>(); private LongAdder currentReads = new LongAdder(); private LongAdder currentWrites = new LongAdder(); @@ -232,13 +227,11 @@ public AsynchronousTlsChannelGroup() { this(Runtime.getRuntime().availableProcessors()); } - RegisteredSocket registerSocket(TlsChannel reader, SocketChannel socketChannel) - throws ClosedChannelException { + RegisteredSocket registerSocket(TlsChannel reader, SocketChannel socketChannel) { if (shutdown != Shutdown.No) { throw new ShutdownChannelGroupException(); } RegisteredSocket socket = new RegisteredSocket(reader, socketChannel); - currentRegistrations.getAndIncrement(); pendingRegistrations.add(socket); selector.wakeup(); return socket; @@ -247,18 +240,13 @@ RegisteredSocket registerSocket(TlsChannel reader, SocketChannel socketChannel) boolean doCancelRead(RegisteredSocket socket, ReadOperation op) { socket.readLock.lock(); try { - // a null op means cancel any operation - if (op != null && socket.readOperation == op || op == null && socket.readOperation != null) { - if (op == null) { - socket.readOperation.onFailure.accept(new CancellationException()); - } - socket.readOperation = null; - cancelledReads.increment(); - currentReads.decrement(); - return true; - } else { + if (op != socket.readOperation) { return false; } + socket.readOperation = null; + cancelledReads.increment(); + currentReads.decrement(); + return true; } finally { socket.readLock.unlock(); } @@ -267,18 +255,13 @@ boolean doCancelRead(RegisteredSocket socket, ReadOperation op) { boolean doCancelWrite(RegisteredSocket socket, WriteOperation op) { socket.writeLock.lock(); try { - // a null op means cancel any operation - if (op != null && socket.writeOperation == op || op == null && socket.writeOperation != null) { - if (op == null) { - socket.writeOperation.onFailure.accept(new CancellationException()); - } - socket.writeOperation = null; - cancelledWrites.increment(); - currentWrites.decrement(); - return true; - } else { + if (op != socket.writeOperation) { return false; } + socket.writeOperation = null; + cancelledWrites.increment(); + currentWrites.decrement(); + return true; } finally { socket.writeLock.unlock(); } @@ -295,13 +278,23 @@ ReadOperation startRead( checkTerminated(); Util.assertTrue(buffer.hasRemaining()); waitForSocketRegistration(socket); - ReadOperation op; socket.readLock.lock(); try { if (socket.readOperation != null) { throw new ReadPendingException(); } - op = new ReadOperation(buffer, onSuccess, onFailure); + ReadOperation op = new ReadOperation(buffer, onSuccess, onFailure); + + startedReads.increment(); + currentReads.increment(); + + if (!registrations.containsKey(socket)) { + op.onFailure.accept(new ClosedChannelException()); + failedReads.increment(); + currentReads.decrement(); + return op; + } + /* * we do not try to outsmart the TLS state machine and register for both IO operations for each new socket * operation @@ -324,9 +317,7 @@ ReadOperation startRead( socket.readLock.unlock(); } selector.wakeup(); - startedReads.increment(); - currentReads.increment(); - return op; + return socket.readOperation; } WriteOperation startWrite( @@ -340,13 +331,23 @@ WriteOperation startWrite( checkTerminated(); Util.assertTrue(buffer.hasRemaining()); waitForSocketRegistration(socket); - WriteOperation op; socket.writeLock.lock(); try { if (socket.writeOperation != null) { throw new WritePendingException(); } - op = new WriteOperation(buffer, onSuccess, onFailure); + WriteOperation op = new WriteOperation(buffer, onSuccess, onFailure); + + startedWrites.increment(); + currentWrites.increment(); + + if (!registrations.containsKey(socket)) { + op.onFailure.accept(new ClosedChannelException()); + failedWrites.increment(); + currentWrites.decrement(); + return op; + } + /* * we do not try to outsmart the TLS state machine and register for both IO operations for each new socket * operation @@ -369,9 +370,7 @@ WriteOperation startWrite( socket.writeLock.unlock(); } selector.wakeup(); - startedWrites.increment(); - currentWrites.increment(); - return op; + return socket.writeOperation; } private void checkTerminated() { @@ -391,8 +390,11 @@ private void waitForSocketRegistration(RegisteredSocket socket) { private void loop() { try { while (shutdown == Shutdown.No - || shutdown == Shutdown.Wait && currentRegistrations.intValue() > 0) { - int c = selector.select(); // block + || shutdown == Shutdown.Wait + && (!pendingRegistrations.isEmpty() || !registrations.isEmpty())) { + // most state-changing operations will wake the selector up, however, asynchronous closings + // of the channels won't, so we have to timeout to allow checking those cases + int c = selector.select(100); // block selectionCount.increment(); // avoid unnecessary creation of iterator object if (c > 0) { @@ -413,6 +415,7 @@ private void loop() { } registerPendingSockets(); processPendingInterests(); + checkClosings(); } } catch (Throwable e) { LOGGER.error("error in selector loop", e); @@ -420,17 +423,12 @@ private void loop() { executor.shutdown(); // use shutdownNow to stop delayed tasks timeoutExecutor.shutdownNow(); - if (shutdown == Shutdown.Immediate) { - for (SelectionKey key : selector.keys()) { - RegisteredSocket socket = (RegisteredSocket) key.attachment(); - socket.close(); - } - } try { selector.close(); } catch (IOException e) { LOGGER.warn("error closing selector: " + e.getMessage()); } + checkClosings(); } } @@ -606,14 +604,67 @@ private long readHandlingTasks(RegisteredSocket socket, ReadOperation op) throws } } - private void registerPendingSockets() throws ClosedChannelException { + private void registerPendingSockets() { RegisteredSocket socket; while ((socket = pendingRegistrations.poll()) != null) { - socket.key = socket.socketChannel.register(selector, 0, socket); - if (LOGGER.isTraceEnabled()) { - LOGGER.trace("registered key: " + socket.key); + try { + socket.key = socket.socketChannel.register(selector, 0, socket); + registrations.put(socket, true); + } catch (ClosedChannelException e) { + // can happen when channels are closed right after creation + } finally { + // decrement the count of the latch even in case of exceptions, so the waiting thread + // is unlocked; it will have to check the result, though + socket.registered.countDown(); + } + } + } + + /** + * Channels that are closed asynchronously are silently removed from selectors. This method will + * check them using the internal catalog and do the proper cleanup. + */ + private void checkClosings() { + for (RegisteredSocket socket : registrations.keySet()) { + if (!socket.key.isValid() || shutdown == Shutdown.Immediate) { + registrations.remove(socket); + failCurrentRead(socket); + failCurrentWrite(socket); } - socket.registered.countDown(); + } + } + + private void failCurrentRead(RegisteredSocket socket) { + socket.readLock.lock(); + try { + if (socket.readOperation != null) { + socket.readOperation.onFailure.accept(new ClosedChannelException()); + if (socket.readOperation.timeoutFuture != null) { + socket.readOperation.timeoutFuture.cancel(false); + } + socket.readOperation = null; + failedReads.increment(); + currentReads.decrement(); + } + } finally { + socket.readLock.unlock(); + } + } + + private void failCurrentWrite(RegisteredSocket socket) { + socket.writeLock.lock(); + try { + if (socket.writeOperation != null) { + socket.writeOperation.onFailure.accept(new ClosedChannelException()); + if (socket.writeOperation.timeoutFuture != null) { + socket.writeOperation.timeoutFuture.cancel(false); + } + socket.writeOperation = null; + failedWrites.increment(); + currentWrites.decrement(); + } + } finally { + socket.writeLock.unlock(); } } @@ -769,6 +820,6 @@ public long getCurrentWriteCount() { * @return number of sockets */ public long getCurrentRegistrationCount() { - return currentRegistrations.longValue(); + return registrations.mappingCount(); } }