Skip to content

Commit

Permalink
fix: S3OutputStream write/close checks should be thread-safe. (#721)
Browse files Browse the repository at this point in the history
Co-authored-by: Jason.Song-宋金泽 <Jason.Song@fanruan.com>
  • Loading branch information
songjinze and Jason.Song-宋金泽 committed Jul 25, 2023
1 parent 4cbffaa commit 64c5ed8
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -77,7 +78,7 @@ public final class S3OutputStream
/**
* Indicates if the stream has been closed.
*/
private volatile boolean closed;
private volatile AtomicBoolean closed = new AtomicBoolean(false);

/**
* Internal buffer. May be {@code null} if no bytes are buffered.
Expand Down Expand Up @@ -220,12 +221,24 @@ public void write(final int bytes)
write(new byte[]{ (byte) bytes });
}

@Override
public void write(byte[] bytes) throws IOException
{
write(bytes, 0, bytes.length);
}

@Override
public void write(final byte[] bytes,
final int offset,
final int length)
throws IOException
{

if (closed.get())
{
throw new StreamAlreadyClosedException();
}

if ((offset < 0) || (offset > bytes.length) || (length < 0) || ((offset + length) > bytes.length) ||
((offset + length) < 0))
{
Expand All @@ -237,11 +250,6 @@ public void write(final byte[] bytes,
return;
}

if (closed)
{
throw new IOException("Already closed");
}

synchronized (this)
{
if (uploadId != null && partETags.size() >= MAX_ALLOWED_UPLOAD_PARTS)
Expand All @@ -267,11 +275,19 @@ public void write(final byte[] bytes,
}
}

/**
* @return True if the stream has been closed, false if the stream is still open.
*/
public boolean isClosed()
{
return this.closed.get();
}

@Override
public void close()
throws IOException
{
if (closed)
if (closed.get())
{
return;
}
Expand All @@ -292,7 +308,7 @@ public void close()
completeMultipartUpload();
}

closed = true;
closed.set(true);
}
}

Expand Down Expand Up @@ -374,13 +390,14 @@ private void uploadPart(final long contentLength,
{
if (!success)
{
closed = true;
closed.set(true);
abortMultipartUpload();
}
}

if (partNumber >= MAX_ALLOWED_UPLOAD_PARTS)
{
LOGGER.warn("Uploaded part is out of max allowed parts, stream closed.");
close();
}
}
Expand Down Expand Up @@ -525,4 +542,12 @@ private String getValueFromMetadata(final String key)

return null;
}

public static class StreamAlreadyClosedException extends IOException
{
public StreamAlreadyClosedException() {
super("Stream has already been closed.");
}
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.carlspring.cloud.storage.s3fs;

import org.assertj.core.api.Assertions;
import org.carlspring.cloud.storage.s3fs.util.S3ClientMock;
import org.carlspring.cloud.storage.s3fs.util.S3MockFactory;
import org.junit.jupiter.api.Test;
Expand All @@ -22,10 +23,16 @@
import java.util.HashMap;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static java.util.UUID.randomUUID;
import static org.assertj.core.api.Assertions.assertThat;
import static org.carlspring.cloud.storage.s3fs.S3OutputStream.MAX_ALLOWED_UPLOAD_PARTS;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
Expand Down Expand Up @@ -76,6 +83,108 @@ void openAndCloseProducesEmptyObject()
assertThatBytesHaveBeenPut(client, data);
}

@Test
void writeToClosedStreamShouldProduceExceptionThreadSafe()
throws ExecutionException, InterruptedException, IOException
{
//given
final String key = getTestBasePath() + "/" + randomUUID();
final S3ClientMock client = S3MockFactory.getS3ClientMock();
client.bucket(BUCKET_NAME).file(key);
final S3ObjectId objectId = S3ObjectId.builder().bucket(BUCKET_NAME).key(key).build();

final S3OutputStream outputStream = new S3OutputStream(client, objectId);

// Simulate closing the outputStream from another thread.
Runnable closeStreamRunnable = () -> {
try
{
outputStream.close();
}
catch (IOException e)
{
throw new RuntimeException(e);
}
};
closeStreamRunnable.run();

CountDownLatch count = new CountDownLatch(1);
AtomicBoolean alreadyClosedException = new AtomicBoolean(false);
Runnable runnable = () -> {
count.countDown();
try
{
outputStream.write(new byte[0]);
}
catch (S3OutputStream.StreamAlreadyClosedException e)
{
alreadyClosedException.set(true);
}
catch (Exception e)
{
throw new RuntimeException(e);
}
};

CompletableFuture[] futures = new CompletableFuture[] {
CompletableFuture.runAsync(runnable),
CompletableFuture.runAsync(runnable),
CompletableFuture.runAsync(runnable),
CompletableFuture.runAsync(runnable),
};
count.countDown();

CompletableFuture.allOf(futures).get();
assertTrue(alreadyClosedException.get());
}

@Test
void closingStreamShouldBeThreadSafe()
throws ExecutionException, InterruptedException, IOException
{
//given
final String key = getTestBasePath() + "/" + randomUUID();
final S3ClientMock client = S3MockFactory.getS3ClientMock();
client.bucket(BUCKET_NAME).file(key);
final S3ObjectId objectId = S3ObjectId.builder().bucket(BUCKET_NAME).key(key).build();

final S3OutputStream outputStream = new S3OutputStream(client, objectId);

// Simulate closing the outputStream from another thread.
Runnable closeStreamRunnable = () -> {
try
{
outputStream.close();
}
catch (IOException e)
{
throw new RuntimeException(e);
}
};
closeStreamRunnable.run();

CountDownLatch latch = new CountDownLatch(1);
AtomicInteger counter = new AtomicInteger();
Runnable runnable = () -> {
latch.countDown();
if(outputStream.isClosed()) {
counter.incrementAndGet();
}
};

CompletableFuture[] futures = new CompletableFuture[] {
CompletableFuture.runAsync(runnable),
CompletableFuture.runAsync(runnable),
CompletableFuture.runAsync(runnable),
CompletableFuture.runAsync(runnable),
CompletableFuture.runAsync(runnable),
};
latch.countDown();

CompletableFuture.allOf(futures).get();
assertThat(counter).hasValue(5);
}

@Test
void zeroBytesWrittenProduceEmptyObject()
throws IOException
Expand Down

0 comments on commit 64c5ed8

Please sign in to comment.