diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7af3c171e8c6a..b829a11e6f08b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -50,6 +50,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 - [Remote Store] Add Segment download stats to remotestore stats API ([#8718](https://github.com/opensearch-project/OpenSearch/pull/8718))
 - [Remote Store] Add remote segment transfer stats on NodesStats API ([#9168](https://github.com/opensearch-project/OpenSearch/pull/9168))
 - Return 409 Conflict HTTP status instead of 503 on failure to concurrently execute snapshots ([#8986](https://github.com/opensearch-project/OpenSearch/pull/5855))
+- Fix memory leak when using Zstd Dictionary ([#9403](https://github.com/opensearch-project/OpenSearch/pull/9403))
 
 ### Deprecated
 
diff --git a/server/src/main/java/org/opensearch/index/codec/customcodecs/ZstdCompressionMode.java b/server/src/main/java/org/opensearch/index/codec/customcodecs/ZstdCompressionMode.java
index f8fb21df84320..05ff725933e1a 100644
--- a/server/src/main/java/org/opensearch/index/codec/customcodecs/ZstdCompressionMode.java
+++ b/server/src/main/java/org/opensearch/index/codec/customcodecs/ZstdCompressionMode.java
@@ -103,11 +103,13 @@ private void compress(byte[] bytes, int offset, int length, DataOutput out) thro
 
             // dictionary compression first
             doCompress(bytes, offset, dictLength, cctx, out);
-            cctx.loadDict(new ZstdDictCompress(bytes, offset, dictLength, compressionLevel));
+            try (ZstdDictCompress dictCompress = new ZstdDictCompress(bytes, offset, dictLength, compressionLevel)) {
+                cctx.loadDict(dictCompress);
 
-            for (int start = offset + dictLength; start < end; start += blockLength) {
-                int l = Math.min(blockLength, end - start);
-                doCompress(bytes, start, l, cctx, out);
+                for (int start = offset + dictLength; start < end; start += blockLength) {
+                    int l = Math.min(blockLength, end - start);
+                    doCompress(bytes, start, l, cctx, out);
+                }
             }
         }
     }
@@ -170,32 +172,33 @@ public void decompress(DataInput in, int originalLength, int offset, int length,
 
             // decompress dictionary first
             doDecompress(in, dctx, bytes, dictLength);
-
-            dctx.loadDict(new ZstdDictDecompress(bytes.bytes, 0, dictLength));
-
-            int offsetInBlock = dictLength;
-            int offsetInBytesRef = offset;
-
-            // Skip unneeded blocks
-            while (offsetInBlock + blockLength < offset) {
-                final int compressedLength = in.readVInt();
-                in.skipBytes(compressedLength);
-                offsetInBlock += blockLength;
-                offsetInBytesRef -= blockLength;
+            try (ZstdDictDecompress dictDecompress = new ZstdDictDecompress(bytes.bytes, 0, dictLength)) {
+                dctx.loadDict(dictDecompress);
+
+                int offsetInBlock = dictLength;
+                int offsetInBytesRef = offset;
+
+                // Skip unneeded blocks
+                while (offsetInBlock + blockLength < offset) {
+                    final int compressedLength = in.readVInt();
+                    in.skipBytes(compressedLength);
+                    offsetInBlock += blockLength;
+                    offsetInBytesRef -= blockLength;
+                }
+
+                // Read blocks that intersect with the interval we need
+                while (offsetInBlock < offset + length) {
+                    bytes.bytes = ArrayUtil.grow(bytes.bytes, bytes.length + blockLength);
+                    int l = Math.min(blockLength, originalLength - offsetInBlock);
+                    doDecompress(in, dctx, bytes, l);
+                    offsetInBlock += blockLength;
+                }
+
+                bytes.offset = offsetInBytesRef;
+                bytes.length = length;
+
+                assert bytes.isValid() : "decompression output is corrupted";
             }
-
-            // Read blocks that intersect with the interval we need
-            while (offsetInBlock < offset + length) {
-                bytes.bytes = ArrayUtil.grow(bytes.bytes, bytes.length + blockLength);
-                int l = Math.min(blockLength, originalLength - offsetInBlock);
-                doDecompress(in, dctx, bytes, l);
-                offsetInBlock += blockLength;
-            }
-
-            bytes.offset = offsetInBytesRef;
-            bytes.length = length;
-
-            assert bytes.isValid() : "decompression output is corrupted";
         }
     }
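
Note: the fix works because zstd-jni's ZstdDictCompress and ZstdDictDecompress wrap natively allocated digested dictionaries and implement Closeable, so try-with-resources releases that native memory as soon as the context is done with the dictionary, instead of leaking it until (or past) finalization. The standalone sketch below illustrates the same lifecycle outside the codec; the class name and sample buffers are illustrative only, and the zstd-jni constructors and methods shown (ZstdDictCompress(byte[], int level), ZstdDictDecompress(byte[]), ZstdCompressCtx.compress(byte[]), ZstdDecompressCtx.decompress(byte[], int)) are assumed to match the zstd-jni version on the classpath.

import com.github.luben.zstd.ZstdCompressCtx;
import com.github.luben.zstd.ZstdDecompressCtx;
import com.github.luben.zstd.ZstdDictCompress;
import com.github.luben.zstd.ZstdDictDecompress;

import java.nio.charset.StandardCharsets;

// Hypothetical demo class, not part of the OpenSearch codebase.
public class ZstdDictLifecycleDemo {
    public static void main(String[] args) {
        byte[] dict = "example dictionary bytes".getBytes(StandardCharsets.UTF_8);
        byte[] data = "payload compressed with a preset dictionary".getBytes(StandardCharsets.UTF_8);

        byte[] compressed;
        // Compression: the digested dictionary is scoped to the try block, so its
        // native allocation is freed deterministically once compression finishes.
        try (ZstdCompressCtx cctx = new ZstdCompressCtx();
             ZstdDictCompress dictCompress = new ZstdDictCompress(dict, 6)) {
            cctx.setLevel(6);
            cctx.loadDict(dictCompress);
            compressed = cctx.compress(data);
        }

        // Decompression mirrors the same pattern with ZstdDictDecompress.
        byte[] restored;
        try (ZstdDecompressCtx dctx = new ZstdDecompressCtx();
             ZstdDictDecompress dictDecompress = new ZstdDictDecompress(dict)) {
            dctx.loadDict(dictDecompress);
            restored = dctx.decompress(compressed, data.length);
        }

        System.out.println(new String(restored, StandardCharsets.UTF_8));
    }
}

The PR applies the same idea inside ZstdCompressionMode: the dictionary's try block is nested inside the compression/decompression context's scope, so the dictionary is closed right after the last block that needs it is processed.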