Skip to content

Commit

Permalink
Merge changes from tls-channel for race condition manifested when clo…
Browse files Browse the repository at this point in the history
…sing async sockets right after creation (#851)

This is a backport of #848

JAVA-4417
  • Loading branch information
stIncMale authored Jan 6, 2022
1 parent 3084f18 commit eec0eed
Showing 1 changed file with 106 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
import java.nio.channels.SocketChannel;
import java.nio.channels.WritePendingException;
import java.util.Iterator;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
Expand Down Expand Up @@ -100,19 +100,15 @@ class RegisteredSocket {
/** Bitwise union of pending operation to be registered in the selector */
final AtomicInteger pendingOps = new AtomicInteger();

RegisteredSocket(TlsChannel tlsChannel, SocketChannel socketChannel)
throws ClosedChannelException {
RegisteredSocket(TlsChannel tlsChannel, SocketChannel socketChannel) {
this.tlsChannel = tlsChannel;
this.socketChannel = socketChannel;
}

public void close() {
doCancelRead(this, null);
doCancelWrite(this, null);
if (key != null) {
key.cancel();
}
currentRegistrations.getAndDecrement();
/*
* Actual de-registration from the selector will happen asynchronously.
*/
Expand Down Expand Up @@ -195,8 +191,7 @@ private enum Shutdown {
private LongAdder cancelledReads = new LongAdder();
private LongAdder cancelledWrites = new LongAdder();

// used for synchronization
private AtomicInteger currentRegistrations = new AtomicInteger();
private final ConcurrentHashMap<RegisteredSocket, Boolean> registrations = new ConcurrentHashMap<>();

private LongAdder currentReads = new LongAdder();
private LongAdder currentWrites = new LongAdder();
Expand Down Expand Up @@ -232,13 +227,11 @@ public AsynchronousTlsChannelGroup() {
this(Runtime.getRuntime().availableProcessors());
}

RegisteredSocket registerSocket(TlsChannel reader, SocketChannel socketChannel)
throws ClosedChannelException {
RegisteredSocket registerSocket(TlsChannel reader, SocketChannel socketChannel) {
if (shutdown != Shutdown.No) {
throw new ShutdownChannelGroupException();
}
RegisteredSocket socket = new RegisteredSocket(reader, socketChannel);
currentRegistrations.getAndIncrement();
pendingRegistrations.add(socket);
selector.wakeup();
return socket;
Expand All @@ -247,18 +240,13 @@ RegisteredSocket registerSocket(TlsChannel reader, SocketChannel socketChannel)
boolean doCancelRead(RegisteredSocket socket, ReadOperation op) {
socket.readLock.lock();
try {
// a null op means cancel any operation
if (op != null && socket.readOperation == op || op == null && socket.readOperation != null) {
if (op == null) {
socket.readOperation.onFailure.accept(new CancellationException());
}
socket.readOperation = null;
cancelledReads.increment();
currentReads.decrement();
return true;
} else {
if (op != socket.readOperation) {
return false;
}
socket.readOperation = null;
cancelledReads.increment();
currentReads.decrement();
return true;
} finally {
socket.readLock.unlock();
}
Expand All @@ -267,18 +255,13 @@ boolean doCancelRead(RegisteredSocket socket, ReadOperation op) {
boolean doCancelWrite(RegisteredSocket socket, WriteOperation op) {
socket.writeLock.lock();
try {
// a null op means cancel any operation
if (op != null && socket.writeOperation == op || op == null && socket.writeOperation != null) {
if (op == null) {
socket.writeOperation.onFailure.accept(new CancellationException());
}
socket.writeOperation = null;
cancelledWrites.increment();
currentWrites.decrement();
return true;
} else {
if (op != socket.writeOperation) {
return false;
}
socket.writeOperation = null;
cancelledWrites.increment();
currentWrites.decrement();
return true;
} finally {
socket.writeLock.unlock();
}
Expand All @@ -295,13 +278,23 @@ ReadOperation startRead(
checkTerminated();
Util.assertTrue(buffer.hasRemaining());
waitForSocketRegistration(socket);
ReadOperation op;
socket.readLock.lock();
try {
if (socket.readOperation != null) {
throw new ReadPendingException();
}
op = new ReadOperation(buffer, onSuccess, onFailure);
ReadOperation op = new ReadOperation(buffer, onSuccess, onFailure);

startedReads.increment();
currentReads.increment();

if (!registrations.containsKey(socket)) {
op.onFailure.accept(new ClosedChannelException());
failedReads.increment();
currentReads.decrement();
return op;
}

/*
* we do not try to outsmart the TLS state machine and register for both IO operations for each new socket
* operation
Expand All @@ -324,9 +317,7 @@ ReadOperation startRead(
socket.readLock.unlock();
}
selector.wakeup();
startedReads.increment();
currentReads.increment();
return op;
return socket.readOperation;
}

WriteOperation startWrite(
Expand All @@ -340,13 +331,23 @@ WriteOperation startWrite(
checkTerminated();
Util.assertTrue(buffer.hasRemaining());
waitForSocketRegistration(socket);
WriteOperation op;
socket.writeLock.lock();
try {
if (socket.writeOperation != null) {
throw new WritePendingException();
}
op = new WriteOperation(buffer, onSuccess, onFailure);
WriteOperation op = new WriteOperation(buffer, onSuccess, onFailure);

startedWrites.increment();
currentWrites.increment();

if (!registrations.containsKey(socket)) {
op.onFailure.accept(new ClosedChannelException());
failedWrites.increment();
currentWrites.decrement();
return op;
}

/*
* we do not try to outsmart the TLS state machine and register for both IO operations for each new socket
* operation
Expand All @@ -369,9 +370,7 @@ WriteOperation startWrite(
socket.writeLock.unlock();
}
selector.wakeup();
startedWrites.increment();
currentWrites.increment();
return op;
return socket.writeOperation;
}

private void checkTerminated() {
Expand All @@ -391,8 +390,11 @@ private void waitForSocketRegistration(RegisteredSocket socket) {
private void loop() {
try {
while (shutdown == Shutdown.No
|| shutdown == Shutdown.Wait && currentRegistrations.intValue() > 0) {
int c = selector.select(); // block
|| shutdown == Shutdown.Wait
&& (!pendingRegistrations.isEmpty() || !registrations.isEmpty())) {
// most state-changing operations will wake the selector up, however, asynchronous closings
// of the channels won't, so we have to timeout to allow checking those cases
int c = selector.select(100); // block
selectionCount.increment();
// avoid unnecessary creation of iterator object
if (c > 0) {
Expand All @@ -413,24 +415,20 @@ private void loop() {
}
registerPendingSockets();
processPendingInterests();
checkClosings();
}
} catch (Throwable e) {
LOGGER.error("error in selector loop", e);
} finally {
executor.shutdown();
// use shutdownNow to stop delayed tasks
timeoutExecutor.shutdownNow();
if (shutdown == Shutdown.Immediate) {
for (SelectionKey key : selector.keys()) {
RegisteredSocket socket = (RegisteredSocket) key.attachment();
socket.close();
}
}
try {
selector.close();
} catch (IOException e) {
LOGGER.warn("error closing selector: " + e.getMessage());
}
checkClosings();
}
}

Expand Down Expand Up @@ -606,14 +604,67 @@ private long readHandlingTasks(RegisteredSocket socket, ReadOperation op) throws
}
}

private void registerPendingSockets() throws ClosedChannelException {
private void registerPendingSockets() {
RegisteredSocket socket;
while ((socket = pendingRegistrations.poll()) != null) {
socket.key = socket.socketChannel.register(selector, 0, socket);
if (LOGGER.isTraceEnabled()) {
LOGGER.trace("registered key: " + socket.key);
try {
socket.key = socket.socketChannel.register(selector, 0, socket);
registrations.put(socket, true);
} catch (ClosedChannelException e) {
// can happen when channels are closed right after creation
} finally {
// decrement the count of the latch even in case of exceptions, so the waiting thread
// is unlocked; it will have to check the result, though
socket.registered.countDown();
}
}
}

/**
* Channels that are closed asynchronously are silently removed from selectors. This method will
* check them using the internal catalog and do the proper cleanup.
*/
private void checkClosings() {
for (RegisteredSocket socket : registrations.keySet()) {
if (!socket.key.isValid() || shutdown == Shutdown.Immediate) {
registrations.remove(socket);
failCurrentRead(socket);
failCurrentWrite(socket);
}
socket.registered.countDown();
}
}

private void failCurrentRead(RegisteredSocket socket) {
socket.readLock.lock();
try {
if (socket.readOperation != null) {
socket.readOperation.onFailure.accept(new ClosedChannelException());
if (socket.readOperation.timeoutFuture != null) {
socket.readOperation.timeoutFuture.cancel(false);
}
socket.readOperation = null;
failedReads.increment();
currentReads.decrement();
}
} finally {
socket.readLock.unlock();
}
}

private void failCurrentWrite(RegisteredSocket socket) {
socket.writeLock.lock();
try {
if (socket.writeOperation != null) {
socket.writeOperation.onFailure.accept(new ClosedChannelException());
if (socket.writeOperation.timeoutFuture != null) {
socket.writeOperation.timeoutFuture.cancel(false);
}
socket.writeOperation = null;
failedWrites.increment();
currentWrites.decrement();
}
} finally {
socket.writeLock.unlock();
}
}

Expand Down Expand Up @@ -769,6 +820,6 @@ public long getCurrentWriteCount() {
* @return number of sockets
*/
public long getCurrentRegistrationCount() {
return currentRegistrations.longValue();
return registrations.mappingCount();
}
}

0 comments on commit eec0eed

Please sign in to comment.