From 4d64dab7342ae538977daec58633600f396945b8 Mon Sep 17 00:00:00 2001 From: sfc-gh-astachowski Date: Fri, 23 Aug 2024 15:34:12 +0200 Subject: [PATCH 01/21] Initial version --- .../client/core/SFArrowResultSet.java | 115 +++++++----- .../client/core/SFBaseResultSet.java | 6 + .../AbstractArrowFullVectorConverter.java | 165 ++++++++++++++++++ .../ArrowFullVectorConverter.java | 11 ++ .../BigIntVectorConverter.java | 33 ++++ .../DecimalVectorConverter.java | 37 ++++ .../IntVectorConverter.java | 33 ++++ .../SimpleArrowFullVectorConverter.java | 35 ++++ .../SmallIntVectorConverter.java | 33 ++++ .../TinyIntVectorConverter.java | 33 ++++ .../client/jdbc/ArrowResultChunk.java | 37 ++++ .../client/jdbc/SFAsyncResultSet.java | 5 + .../client/jdbc/SnowflakeResultSet.java | 2 + .../client/jdbc/SnowflakeResultSetV1.java | 5 + 14 files changed, 511 insertions(+), 39 deletions(-) create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/AbstractArrowFullVectorConverter.java create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BigIntVectorConverter.java create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DecimalVectorConverter.java create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/IntVectorConverter.java create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SmallIntVectorConverter.java create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TinyIntVectorConverter.java diff --git a/src/main/java/net/snowflake/client/core/SFArrowResultSet.java b/src/main/java/net/snowflake/client/core/SFArrowResultSet.java index 14f21e8d1..f95f9c0f6 100644 --- a/src/main/java/net/snowflake/client/core/SFArrowResultSet.java +++ b/src/main/java/net/snowflake/client/core/SFArrowResultSet.java @@ -30,14 +30,8 @@ import net.snowflake.client.core.arrow.VarCharConverter; import net.snowflake.client.core.arrow.VectorTypeConverter; import net.snowflake.client.core.json.Converters; -import net.snowflake.client.jdbc.ArrowResultChunk; +import net.snowflake.client.jdbc.*; import net.snowflake.client.jdbc.ArrowResultChunk.ArrowChunkIterator; -import net.snowflake.client.jdbc.ErrorCode; -import net.snowflake.client.jdbc.FieldMetadata; -import net.snowflake.client.jdbc.SnowflakeResultSetSerializableV1; -import net.snowflake.client.jdbc.SnowflakeSQLException; -import net.snowflake.client.jdbc.SnowflakeSQLLoggedException; -import net.snowflake.client.jdbc.SnowflakeUtil; import net.snowflake.client.jdbc.telemetry.Telemetry; import net.snowflake.client.jdbc.telemetry.TelemetryData; import net.snowflake.client.jdbc.telemetry.TelemetryField; @@ -112,6 +106,9 @@ public class SFArrowResultSet extends SFBaseResultSet implements DataConversionC */ private boolean formatDateWithTimezone; + private boolean arrowBatchesMode = false; + private boolean rowIteratorMode = false; + @SnowflakeJdbcInternalApi protected Converters converters; /** @@ -246,6 +243,31 @@ private boolean fetchNextRow() throws SnowflakeSQLException { } } + private ArrowResultChunk fetchNextChunk() throws SnowflakeSQLException { + try { + eventHandler.triggerStateTransition( + BasicEvent.QueryState.CONSUMING_RESULT, + String.format( + 
BasicEvent.QueryState.CONSUMING_RESULT.getArgString(), queryId, nextChunkIndex)); + + ArrowResultChunk nextChunk = (ArrowResultChunk) chunkDownloader.getNextChunkToConsume(); + + if (nextChunk == null) { + throw new SnowflakeSQLLoggedException( + queryId, + session, + ErrorCode.INTERNAL_ERROR.getMessageCode(), + SqlState.INTERNAL_ERROR, + "Expect chunk but got null for chunk index " + nextChunkIndex); + } + + return nextChunk; + } catch (InterruptedException ex) { + throw new SnowflakeSQLLoggedException( + queryId, session, ErrorCode.INTERRUPTED.getMessageCode(), SqlState.QUERY_CANCELED); + } + } + /** * Goto next row. If end of current chunk, update currentChunkIterator to the beginning of next * chunk, if any chunk not being consumed yet. @@ -259,40 +281,21 @@ private boolean fetchNextRowUnsorted() throws SnowflakeSQLException { return true; } else { if (nextChunkIndex < chunkCount) { - try { - eventHandler.triggerStateTransition( - BasicEvent.QueryState.CONSUMING_RESULT, - String.format( - BasicEvent.QueryState.CONSUMING_RESULT.getArgString(), queryId, nextChunkIndex)); + ArrowResultChunk nextChunk = fetchNextChunk(); - ArrowResultChunk nextChunk = (ArrowResultChunk) chunkDownloader.getNextChunkToConsume(); + currentChunkIterator.getChunk().freeData(); + currentChunkIterator = nextChunk.getIterator(this); + if (currentChunkIterator.next()) { - if (nextChunk == null) { - throw new SnowflakeSQLLoggedException( - queryId, - session, - ErrorCode.INTERNAL_ERROR.getMessageCode(), - SqlState.INTERNAL_ERROR, - "Expect chunk but got null for chunk index " + nextChunkIndex); - } + logger.debug( + "Moving to chunk index: {}, row count: {}", + nextChunkIndex, + nextChunk.getRowCount()); - currentChunkIterator.getChunk().freeData(); - currentChunkIterator = nextChunk.getIterator(this); - if (currentChunkIterator.next()) { - - logger.debug( - "Moving to chunk index: {}, row count: {}", - nextChunkIndex, - nextChunk.getRowCount()); - - nextChunkIndex++; - return true; - } else { - return false; - } - } catch (InterruptedException ex) { - throw new SnowflakeSQLLoggedException( - queryId, session, ErrorCode.INTERRUPTED.getMessageCode(), SqlState.QUERY_CANCELED); + nextChunkIndex++; + return true; + } else { + return false; } } else { // always free current chunk @@ -428,9 +431,10 @@ public Timestamp convertToTimestamp( */ @Override public boolean next() throws SFException, SnowflakeSQLException { - if (isClosed()) { + if (isClosed() || arrowBatchesMode) { return false; } + rowIteratorMode = true; // otherwise try to fetch again if (fetchNextRow()) { @@ -763,6 +767,39 @@ public BigDecimal getBigDecimal(int columnIndex, int scale) throws SFException { return bigDec == null ? 
null : bigDec.setScale(scale, RoundingMode.HALF_UP); } + public ArrowBatches getArrowBatches() { + if (rowIteratorMode) { + return null; + } + arrowBatchesMode = true; + return new SFArrowBatchesIterator(); + } + + private class SFArrowBatchesIterator implements ArrowBatches { + private boolean firstFetched = false; + + @Override + public long getRowCount() { + return 0; + } + + @Override + public boolean hasNext() { + return nextChunkIndex < chunkCount || !firstFetched; + } + + @Override + public ArrowBatch next() throws SQLException { + if (!firstFetched) { + firstFetched = true; + return currentChunkIterator.getChunk().getArrowBatch(SFArrowResultSet.this); + } else { + nextChunkIndex++; + return fetchNextChunk().getArrowBatch(SFArrowResultSet.this); + } + } + } + @Override public boolean isLast() { return nextChunkIndex == chunkCount && currentChunkIterator.isLast(); diff --git a/src/main/java/net/snowflake/client/core/SFBaseResultSet.java b/src/main/java/net/snowflake/client/core/SFBaseResultSet.java index 71e56a515..f9fcf5c01 100644 --- a/src/main/java/net/snowflake/client/core/SFBaseResultSet.java +++ b/src/main/java/net/snowflake/client/core/SFBaseResultSet.java @@ -30,8 +30,10 @@ import java.util.stream.Stream; import java.util.stream.StreamSupport; import net.snowflake.client.core.json.Converters; +import net.snowflake.client.jdbc.ArrowBatches; import net.snowflake.client.jdbc.ErrorCode; import net.snowflake.client.jdbc.FieldMetadata; +import net.snowflake.client.jdbc.SnowflakeLoggedFeatureNotSupportedException; import net.snowflake.client.jdbc.SnowflakeResultSetSerializable; import net.snowflake.client.jdbc.SnowflakeResultSetSerializableV1; import net.snowflake.client.jdbc.SnowflakeSQLException; @@ -137,6 +139,10 @@ public SFBaseSession getSession() { return this.session; } + public ArrowBatches getArrowBatches() throws SnowflakeLoggedFeatureNotSupportedException { + throw new SnowflakeLoggedFeatureNotSupportedException(session); + } + // default implementation public boolean next() throws SFException, SnowflakeSQLException { logger.trace("boolean next()", false); diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/AbstractArrowFullVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/AbstractArrowFullVectorConverter.java new file mode 100644 index 000000000..65526ca7e --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/AbstractArrowFullVectorConverter.java @@ -0,0 +1,165 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.arrow.ArrayConverter; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import net.snowflake.client.core.arrow.BigIntToFixedConverter; +import net.snowflake.client.core.arrow.BigIntToScaledFixedConverter; +import net.snowflake.client.core.arrow.BigIntToTimeConverter; +import net.snowflake.client.core.arrow.BigIntToTimestampLTZConverter; +import net.snowflake.client.core.arrow.BigIntToTimestampNTZConverter; +import net.snowflake.client.core.arrow.BitToBooleanConverter; +import net.snowflake.client.core.arrow.DateConverter; +import net.snowflake.client.core.arrow.DecimalToScaledFixedConverter; +import net.snowflake.client.core.arrow.DoubleToRealConverter; +import net.snowflake.client.core.arrow.IntToFixedConverter; +import 
net.snowflake.client.core.arrow.IntToScaledFixedConverter; +import net.snowflake.client.core.arrow.IntToTimeConverter; +import net.snowflake.client.core.arrow.MapConverter; +import net.snowflake.client.core.arrow.SmallIntToFixedConverter; +import net.snowflake.client.core.arrow.SmallIntToScaledFixedConverter; +import net.snowflake.client.core.arrow.StructConverter; +import net.snowflake.client.core.arrow.ThreeFieldStructToTimestampTZConverter; +import net.snowflake.client.core.arrow.TinyIntToFixedConverter; +import net.snowflake.client.core.arrow.TinyIntToScaledFixedConverter; +import net.snowflake.client.core.arrow.TwoFieldStructToTimestampLTZConverter; +import net.snowflake.client.core.arrow.TwoFieldStructToTimestampNTZConverter; +import net.snowflake.client.core.arrow.TwoFieldStructToTimestampTZConverter; +import net.snowflake.client.core.arrow.VarBinaryToBinaryConverter; +import net.snowflake.client.core.arrow.VarCharConverter; +import net.snowflake.client.core.arrow.VectorTypeConverter; +import net.snowflake.client.jdbc.ErrorCode; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import net.snowflake.client.jdbc.SnowflakeSQLLoggedException; +import net.snowflake.client.jdbc.SnowflakeType; +import net.snowflake.common.core.SqlState; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.types.Types; + +import java.util.Map; + +public abstract class AbstractArrowFullVectorConverter implements ArrowFullVectorConverter { + protected RootAllocator allocator; + protected ValueVector vector; + protected DataConversionContext context; + protected SFBaseSession session; + protected int idx; + + public AbstractArrowFullVectorConverter(RootAllocator allocator, ValueVector vector, DataConversionContext context, + SFBaseSession session, int idx) { + this.allocator = allocator; + this.vector = vector; + this.context = context; + this.session = session; + this.idx = idx; + } + + private static Types.MinorType deduceType(ValueVector vector) { + Types.MinorType type = Types.getMinorTypeForArrowType(vector.getField().getType()); + // each column's metadata + Map customMeta = vector.getField().getMetadata(); + if (type == Types.MinorType.DECIMAL) { + // Note: Decimal vector is different from others + return Types.MinorType.DECIMAL; + } else if (!customMeta.isEmpty()) { + SnowflakeType st = SnowflakeType.valueOf(customMeta.get("logicalType")); + switch (st) { + case FIXED: + { + String scaleStr = vector.getField().getMetadata().get("scale"); + int sfScale = Integer.parseInt(scaleStr); + if (sfScale != 0) { + return Types.MinorType.DECIMAL; + } + break; + } + case TIME: + return Types.MinorType.TIMEMILLI; + case TIMESTAMP_LTZ: + { + String scaleStr = vector.getField().getMetadata().get("scale"); + int sfScale = Integer.parseInt(scaleStr); + switch (sfScale) { + case 0: + return Types.MinorType.TIMESTAMPSECTZ; + case 3: + return Types.MinorType.TIMESTAMPMILLITZ; + case 6: + return Types.MinorType.TIMESTAMPMICROTZ; + case 9: + return Types.MinorType.TIMESTAMPNANOTZ; + } + break; + } + case TIMESTAMP_TZ: + { + String scaleStr = vector.getField().getMetadata().get("scale"); + int sfScale = Integer.parseInt(scaleStr); + switch (sfScale) { + case 0: + return 
Types.MinorType.TIMESTAMPSECTZ; + case 3: + return Types.MinorType.TIMESTAMPMILLITZ; + case 6: + return Types.MinorType.TIMESTAMPMICROTZ; + case 9: + return Types.MinorType.TIMESTAMPNANOTZ; + } + break; + } + case TIMESTAMP_NTZ: + { + String scaleStr = vector.getField().getMetadata().get("scale"); + int sfScale = Integer.parseInt(scaleStr); + switch (sfScale) { + case 0: + return Types.MinorType.TIMESTAMPSEC; + case 3: + return Types.MinorType.TIMESTAMPMILLI; + case 6: + return Types.MinorType.TIMESTAMPMICRO; + case 9: + return Types.MinorType.TIMESTAMPNANO; + } + break; + } + } + } + return type; + } + + public static FieldVector convert(RootAllocator allocator, ValueVector vector, DataConversionContext context, + SFBaseSession session, int idx, Object targetType) throws SnowflakeSQLException { + try { + if (targetType == null) { + targetType = deduceType(vector); + } + if (targetType instanceof Types.MinorType) { + switch ((Types.MinorType) targetType) { + case TINYINT: + return new TinyIntVectorConverter(allocator, vector, context, session, idx).convert(); + case SMALLINT: + return new SmallIntVectorConverter(allocator, vector, context, session, idx).convert(); + case INT: + return new IntVectorConverter(allocator, vector, context, session, idx).convert(); + case BIGINT: + return new BigIntVectorConverter(allocator, vector, context, session, idx).convert(); + case DECIMAL: + return new DecimalVectorConverter(allocator, vector, context, session, idx).convert(); + } + } + } catch (SFException ex) { + throw new SnowflakeSQLException( + ex.getCause(), ex.getSqlState(), ex.getVendorCode(), ex.getParams()); + } + return null; + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java new file mode 100644 index 000000000..13a351e64 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java @@ -0,0 +1,11 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; + +public interface ArrowFullVectorConverter { + FieldVector convert() throws SFException, SnowflakeSQLException; +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BigIntVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BigIntVectorConverter.java new file mode 100644 index 000000000..b3c8df4d7 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BigIntVectorConverter.java @@ -0,0 +1,33 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.ValueVector; + +public class BigIntVectorConverter extends SimpleArrowFullVectorConverter { + + public BigIntVectorConverter(RootAllocator allocator, ValueVector vector, DataConversionContext context, SFBaseSession session, int idx) { + super(allocator, vector, context, session, 
idx); + } + + @Override + protected boolean matchingType() { + return (vector instanceof BigIntVector); + } + + @Override + protected BigIntVector initVector() { + BigIntVector resultVector = new BigIntVector(vector.getName(), allocator); + resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } + + @Override + protected void convertValue(ArrowVectorConverter from, BigIntVector to, int idx) throws SFException { + to.set(idx, from.toLong(idx)); + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DecimalVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DecimalVectorConverter.java new file mode 100644 index 000000000..d48c18144 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DecimalVectorConverter.java @@ -0,0 +1,37 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.ValueVector; + +public class DecimalVectorConverter extends SimpleArrowFullVectorConverter { + + public DecimalVectorConverter(RootAllocator allocator, ValueVector vector, DataConversionContext context, SFBaseSession session, int idx) { + super(allocator, vector, context, session, idx); + } + + @Override + protected boolean matchingType() { + return (vector instanceof DecimalVector); + } + + @Override + protected DecimalVector initVector() { + String scaleString = vector.getField().getMetadata().get("scale"); + String precisionString = vector.getField().getMetadata().get("precision"); + int scale = Integer.parseInt(scaleString); + int precision = Integer.parseInt(precisionString); + DecimalVector resultVector = new DecimalVector(vector.getName(), allocator, precision, scale); + resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } + + @Override + protected void convertValue(ArrowVectorConverter from, DecimalVector to, int idx) throws SFException { + to.set(idx, from.toBigDecimal(idx)); + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/IntVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/IntVectorConverter.java new file mode 100644 index 000000000..0780962d9 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/IntVectorConverter.java @@ -0,0 +1,33 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.ValueVector; + +public class IntVectorConverter extends SimpleArrowFullVectorConverter { + + public IntVectorConverter(RootAllocator allocator, ValueVector vector, DataConversionContext context, SFBaseSession session, int idx) { + super(allocator, vector, context, session, idx); + } + + @Override + protected boolean matchingType() { + return (vector instanceof IntVector); + } + + @Override + protected IntVector initVector() { + IntVector resultVector = new 
IntVector(vector.getName(), allocator); + resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } + + @Override + protected void convertValue(ArrowVectorConverter from, IntVector to, int idx) throws SFException { + to.set(idx, from.toInt(idx)); + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java new file mode 100644 index 000000000..0f6ecafd5 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java @@ -0,0 +1,35 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; + +public abstract class SimpleArrowFullVectorConverter extends AbstractArrowFullVectorConverter { + public SimpleArrowFullVectorConverter(RootAllocator allocator, ValueVector vector, DataConversionContext context, SFBaseSession session, int idx) { + super(allocator, vector, context, session, idx); + } + + + abstract protected boolean matchingType(); + abstract protected T initVector(); + + abstract protected void convertValue(ArrowVectorConverter from, T to, int idx) throws SFException; + +@Override + public FieldVector convert() throws SFException, SnowflakeSQLException { + if (matchingType()) {return (FieldVector) vector;} + int size = vector.getValueCount(); + T converted = initVector(); + ArrowVectorConverter converter = ArrowVectorConverter.initConverter(vector, context, session, idx); + for (int i = 0; i < size; i++) { + convertValue(converter, converted, i); + } + converted.setValueCount(size); + return converted; + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SmallIntVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SmallIntVectorConverter.java new file mode 100644 index 000000000..073d3268c --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SmallIntVectorConverter.java @@ -0,0 +1,33 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.ValueVector; + +public class SmallIntVectorConverter extends SimpleArrowFullVectorConverter { + + public SmallIntVectorConverter(RootAllocator allocator, ValueVector vector, DataConversionContext context, SFBaseSession session, int idx) { + super(allocator, vector, context, session, idx); + } + + @Override + protected boolean matchingType() { + return (vector instanceof SmallIntVector); + } + + @Override + protected SmallIntVector initVector() { + SmallIntVector resultVector = new SmallIntVector(vector.getName(), allocator); + resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } + + @Override + protected void 
convertValue(ArrowVectorConverter from, SmallIntVector to, int idx) throws SFException { + to.set(idx, from.toShort(idx)); + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TinyIntVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TinyIntVectorConverter.java new file mode 100644 index 000000000..0e298c18e --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TinyIntVectorConverter.java @@ -0,0 +1,33 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.ValueVector; + +public class TinyIntVectorConverter extends SimpleArrowFullVectorConverter { + + public TinyIntVectorConverter(RootAllocator allocator, ValueVector vector, DataConversionContext context, SFBaseSession session, int idx) { + super(allocator, vector, context, session, idx); + } + + @Override + protected boolean matchingType() { + return (vector instanceof TinyIntVector); + } + + @Override + protected TinyIntVector initVector() { + TinyIntVector resultVector = new TinyIntVector(vector.getName(), allocator); + resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } + + @Override + protected void convertValue(ArrowVectorConverter from, TinyIntVector to, int idx) throws SFException { + to.set(idx, from.toByte(idx)); + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java index dca895464..aafba5bee 100644 --- a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java +++ b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java @@ -9,12 +9,14 @@ import java.io.InputStream; import java.nio.channels.ClosedByInterruptException; import java.util.ArrayList; +import java.util.Iterator; import java.util.List; import net.snowflake.client.core.DataConversionContext; import net.snowflake.client.core.SFBaseSession; import net.snowflake.client.core.SFException; import net.snowflake.client.core.arrow.ArrowResultChunkIndexSorter; import net.snowflake.client.core.arrow.ArrowVectorConverter; +import net.snowflake.client.core.arrow.fullvectorconverters.AbstractArrowFullVectorConverter; import net.snowflake.client.log.SFLogger; import net.snowflake.client.log.SFLoggerFactory; import net.snowflake.common.core.SqlState; @@ -55,6 +57,7 @@ public class ArrowResultChunk extends SnowflakeResultChunk { private IntVector firstResultChunkSortedIndices; private VectorSchemaRoot root; private SFBaseSession session; + private boolean batchesMode = false; public ArrowResultChunk( String url, @@ -126,6 +129,9 @@ public long computeNeededChunkMemory() { @Override public void freeData() { + if (batchesMode) { + return; + } batchOfVectors.forEach(list -> list.forEach(ValueVector::close)); this.batchOfVectors.clear(); if (firstResultChunkSortedIndices != null) { @@ -494,6 +500,37 @@ private void sortFirstResultChunk(List converters) } } + public ArrowBatch getArrowBatch(DataConversionContext context) { + batchesMode = true; + return new ArrowResultBatch(context); + } + + public class ArrowResultBatch implements ArrowBatch { + private DataConversionContext context; + + 
ArrowResultBatch(DataConversionContext context){ + this.context = context; + } + + public List fetch() throws SnowflakeSQLException { + List result = new ArrayList<>(); + for (List record: batchOfVectors){ + List convertedVectors = new ArrayList<>(); + for (int i = 0; i < record.size(); i++){ + ValueVector vector = record.get(i); + convertedVectors.add(AbstractArrowFullVectorConverter.convert(rootAllocator, vector, context, session, i, null)); + } + result.add(new VectorSchemaRoot(convertedVectors)); + } + return result; + } + + @Override + public long getRowCount() { + return rowCount; + } + } + private boolean sortFirstResultChunkEnabled() { return enableSortFirstResultChunk; } diff --git a/src/main/java/net/snowflake/client/jdbc/SFAsyncResultSet.java b/src/main/java/net/snowflake/client/jdbc/SFAsyncResultSet.java index 0bafbf12d..51cc14a04 100644 --- a/src/main/java/net/snowflake/client/jdbc/SFAsyncResultSet.java +++ b/src/main/java/net/snowflake/client/jdbc/SFAsyncResultSet.java @@ -402,4 +402,9 @@ public List getResultSetSerializables(long maxSi .unwrap(SnowflakeResultSet.class) .getResultSetSerializables(maxSizeInBytes); } + + @Override + public ArrowBatches getArrowBatches() throws SQLException { + throw new SnowflakeLoggedFeatureNotSupportedException(session); + } } diff --git a/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSet.java b/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSet.java index 1fc7ff9ee..03171a599 100644 --- a/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSet.java +++ b/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSet.java @@ -62,4 +62,6 @@ public interface SnowflakeResultSet { */ List getResultSetSerializables(long maxSizeInBytes) throws SQLException; + + ArrowBatches getArrowBatches() throws SQLException; } diff --git a/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetV1.java b/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetV1.java index 49c8c8546..085c8bcc3 100644 --- a/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetV1.java +++ b/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetV1.java @@ -397,6 +397,11 @@ public List getResultSetSerializables(long maxSi return sfBaseResultSet.getResultSetSerializables(maxSizeInBytes); } + @Override + public ArrowBatches getArrowBatches() throws SQLException { + return sfBaseResultSet.getArrowBatches(); + } + /** Empty result set */ static class EmptyResultSet implements ResultSet { private boolean isClosed; From 986967de617c0c626b837bf597e120c71f6d3351 Mon Sep 17 00:00:00 2001 From: sfc-gh-astachowski Date: Mon, 26 Aug 2024 10:00:34 +0200 Subject: [PATCH 02/21] Initial tests --- .../client/core/SFArrowResultSet.java | 5 + .../client/jdbc/ArrowBatchesTest.java | 228 ++++++++++++++++++ 2 files changed, 233 insertions(+) create mode 100644 src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java diff --git a/src/main/java/net/snowflake/client/core/SFArrowResultSet.java b/src/main/java/net/snowflake/client/core/SFArrowResultSet.java index f95f9c0f6..2ac116cb2 100644 --- a/src/main/java/net/snowflake/client/core/SFArrowResultSet.java +++ b/src/main/java/net/snowflake/client/core/SFArrowResultSet.java @@ -235,6 +235,11 @@ public SFArrowResultSet( } } + @SnowflakeJdbcInternalApi + public long getAllocatedMemory() { + return rootAllocator.getAllocatedMemory(); + } + private boolean fetchNextRow() throws SnowflakeSQLException { if (sortResult) { return fetchNextRowSorted(); diff --git a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java 
b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java new file mode 100644 index 000000000..4334447b3 --- /dev/null +++ b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java @@ -0,0 +1,228 @@ +package net.snowflake.client.jdbc; + +import net.snowflake.client.core.SFArrowResultSet; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.junit.Before; +import org.junit.Test; + +import java.math.BigDecimal; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class ArrowBatchesTest extends BaseJDBCWithSharedConnectionIT { + @Before + public void setUp() throws Exception { + try (Statement statement = connection.createStatement()) { + statement.execute("alter session set jdbc_query_result_format = 'arrow'"); + } + } + + private static void assertNoMemoryLeaks(ResultSet rs) throws SQLException { + assertEquals(((SFArrowResultSet) rs.unwrap(SnowflakeResultSetV1.class).sfBaseResultSet).getAllocatedMemory(), 0); + } + + @Test + public void testMultipleBatches() throws Exception { + Statement statement = connection.createStatement();; + ResultSet rs = statement.executeQuery("select seq1(), seq2(), seq4(), seq8() from TABLE (generator(rowcount => 30000))"); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + int totalRows = 0; + ArrayList allRoots = new ArrayList<>(); + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + allRoots.add(root); + assertTrue(root.getVector(0) instanceof TinyIntVector); + assertTrue(root.getVector(1) instanceof SmallIntVector); + assertTrue(root.getVector(2) instanceof IntVector); + assertTrue(root.getVector(3) instanceof BigIntVector); + } + } + + rs.close(); + + // The memory should not be freed when closing the result set. 
+ for (VectorSchemaRoot root : allRoots) { + assertTrue(root.getVector(0).getValueCount() > 0); + root.close(); + } + assertNoMemoryLeaks(rs); + assertEquals(30000, totalRows); + + } + + @Test + public void testTinyIntBatch() throws Exception { + Statement statement = connection.createStatement();; + ResultSet rs = statement.executeQuery("select 1 union select 2 union select 3;"); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + int totalRows = 0; + List values = new ArrayList<>(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof TinyIntVector); + TinyIntVector vector = (TinyIntVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.get(i)); + } + root.close(); + } + } + rs.close(); + + // All expected values are present + for(byte i = 1; i < 4; i++) { + assertTrue(values.contains(i)); + } + + assertEquals(3, totalRows); + } + + @Test + public void testSmallIntBatch() throws Exception { + Statement statement = connection.createStatement();; + ResultSet rs = statement.executeQuery("select 129 union select 130 union select 131;"); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + int totalRows = 0; + List values = new ArrayList<>(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof SmallIntVector); + SmallIntVector vector = (SmallIntVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.get(i)); + } + root.close(); + } + } + rs.close(); + + // All expected values are present + for(short i = 129; i < 132; i++) { + assertTrue(values.contains(i)); + } + + assertEquals(3, totalRows); + } + + @Test + public void testIntBatch() throws Exception { + Statement statement = connection.createStatement();; + ResultSet rs = statement.executeQuery("select 100000 union select 100001 union select 100002;"); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + int totalRows = 0; + List values = new ArrayList<>(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof IntVector); + IntVector vector = (IntVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.get(i)); + } + root.close(); + } + } + rs.close(); + + // All expected values are present + for(int i = 100000; i < 100003; i++) { + assertTrue(values.contains(i)); + } + + assertEquals(3, totalRows); + } + + @Test + public void testBigIntBatch() throws Exception { + Statement statement = connection.createStatement();; + ResultSet rs = statement.executeQuery("select 10000000000 union select 10000000001 union select 10000000002;"); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + int totalRows = 0; + List values = new ArrayList<>(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof BigIntVector); + BigIntVector vector = (BigIntVector) root.getVector(0); + 
for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.get(i)); + } + root.close(); + } + } + rs.close(); + + // All expected values are present + for(long i = 10000000000L; i < 10000000003L; i++) { + assertTrue(values.contains(i)); + } + + assertEquals(3, totalRows); + } + + @Test + public void testDecimalBatch() throws Exception { + Statement statement = connection.createStatement();; + ResultSet rs = statement.executeQuery("select 1.1 union select 1.2 union select 1.3;"); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + int totalRows = 0; + List values = new ArrayList<>(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof DecimalVector); + DecimalVector vector = (DecimalVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.getObject(i)); + } + root.close(); + } + } + + rs.close(); + + // All expected values are present + for(int i = 1; i < 4; i++) { + assertTrue(values.contains(new BigDecimal("1."+ i))); + } + + assertEquals(3, totalRows); + } +} From c9f62878cbaad8497729c3f93c203939bb531b5a Mon Sep 17 00:00:00 2001 From: sfc-gh-astachowski Date: Mon, 26 Aug 2024 10:04:18 +0200 Subject: [PATCH 03/21] Formatting --- .../client/core/SFArrowResultSet.java | 22 +- .../AbstractArrowFullVectorConverter.java | 266 ++++++------ .../ArrowFullVectorConverter.java | 4 +- .../BigIntVectorConverter.java | 40 +- .../DecimalVectorConverter.java | 48 ++- .../IntVectorConverter.java | 39 +- .../SimpleArrowFullVectorConverter.java | 45 +- .../SmallIntVectorConverter.java | 40 +- .../TinyIntVectorConverter.java | 40 +- .../client/jdbc/ArrowResultChunk.java | 11 +- .../client/jdbc/ArrowBatchesTest.java | 388 +++++++++--------- 11 files changed, 482 insertions(+), 461 deletions(-) diff --git a/src/main/java/net/snowflake/client/core/SFArrowResultSet.java b/src/main/java/net/snowflake/client/core/SFArrowResultSet.java index 2ac116cb2..50580ecee 100644 --- a/src/main/java/net/snowflake/client/core/SFArrowResultSet.java +++ b/src/main/java/net/snowflake/client/core/SFArrowResultSet.java @@ -251,25 +251,25 @@ private boolean fetchNextRow() throws SnowflakeSQLException { private ArrowResultChunk fetchNextChunk() throws SnowflakeSQLException { try { eventHandler.triggerStateTransition( - BasicEvent.QueryState.CONSUMING_RESULT, - String.format( - BasicEvent.QueryState.CONSUMING_RESULT.getArgString(), queryId, nextChunkIndex)); + BasicEvent.QueryState.CONSUMING_RESULT, + String.format( + BasicEvent.QueryState.CONSUMING_RESULT.getArgString(), queryId, nextChunkIndex)); ArrowResultChunk nextChunk = (ArrowResultChunk) chunkDownloader.getNextChunkToConsume(); if (nextChunk == null) { throw new SnowflakeSQLLoggedException( - queryId, - session, - ErrorCode.INTERNAL_ERROR.getMessageCode(), - SqlState.INTERNAL_ERROR, - "Expect chunk but got null for chunk index " + nextChunkIndex); + queryId, + session, + ErrorCode.INTERNAL_ERROR.getMessageCode(), + SqlState.INTERNAL_ERROR, + "Expect chunk but got null for chunk index " + nextChunkIndex); } return nextChunk; } catch (InterruptedException ex) { throw new SnowflakeSQLLoggedException( - queryId, session, ErrorCode.INTERRUPTED.getMessageCode(), SqlState.QUERY_CANCELED); + queryId, session, ErrorCode.INTERRUPTED.getMessageCode(), SqlState.QUERY_CANCELED); } } @@ -293,9 +293,7 @@ private boolean fetchNextRowUnsorted() throws 
SnowflakeSQLException { if (currentChunkIterator.next()) { logger.debug( - "Moving to chunk index: {}, row count: {}", - nextChunkIndex, - nextChunk.getRowCount()); + "Moving to chunk index: {}, row count: {}", nextChunkIndex, nextChunk.getRowCount()); nextChunkIndex++; return true; diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/AbstractArrowFullVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/AbstractArrowFullVectorConverter.java index 65526ca7e..8d580a8af 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/AbstractArrowFullVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/AbstractArrowFullVectorConverter.java @@ -1,165 +1,141 @@ package net.snowflake.client.core.arrow.fullvectorconverters; +import java.util.Map; import net.snowflake.client.core.DataConversionContext; import net.snowflake.client.core.SFBaseSession; import net.snowflake.client.core.SFException; -import net.snowflake.client.core.arrow.ArrayConverter; -import net.snowflake.client.core.arrow.ArrowVectorConverter; -import net.snowflake.client.core.arrow.BigIntToFixedConverter; -import net.snowflake.client.core.arrow.BigIntToScaledFixedConverter; -import net.snowflake.client.core.arrow.BigIntToTimeConverter; -import net.snowflake.client.core.arrow.BigIntToTimestampLTZConverter; -import net.snowflake.client.core.arrow.BigIntToTimestampNTZConverter; -import net.snowflake.client.core.arrow.BitToBooleanConverter; -import net.snowflake.client.core.arrow.DateConverter; -import net.snowflake.client.core.arrow.DecimalToScaledFixedConverter; -import net.snowflake.client.core.arrow.DoubleToRealConverter; -import net.snowflake.client.core.arrow.IntToFixedConverter; -import net.snowflake.client.core.arrow.IntToScaledFixedConverter; -import net.snowflake.client.core.arrow.IntToTimeConverter; -import net.snowflake.client.core.arrow.MapConverter; -import net.snowflake.client.core.arrow.SmallIntToFixedConverter; -import net.snowflake.client.core.arrow.SmallIntToScaledFixedConverter; -import net.snowflake.client.core.arrow.StructConverter; -import net.snowflake.client.core.arrow.ThreeFieldStructToTimestampTZConverter; -import net.snowflake.client.core.arrow.TinyIntToFixedConverter; -import net.snowflake.client.core.arrow.TinyIntToScaledFixedConverter; -import net.snowflake.client.core.arrow.TwoFieldStructToTimestampLTZConverter; -import net.snowflake.client.core.arrow.TwoFieldStructToTimestampNTZConverter; -import net.snowflake.client.core.arrow.TwoFieldStructToTimestampTZConverter; -import net.snowflake.client.core.arrow.VarBinaryToBinaryConverter; -import net.snowflake.client.core.arrow.VarCharConverter; -import net.snowflake.client.core.arrow.VectorTypeConverter; -import net.snowflake.client.jdbc.ErrorCode; import net.snowflake.client.jdbc.SnowflakeSQLException; -import net.snowflake.client.jdbc.SnowflakeSQLLoggedException; import net.snowflake.client.jdbc.SnowflakeType; -import net.snowflake.common.core.SqlState; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; -import org.apache.arrow.vector.complex.FixedSizeListVector; -import org.apache.arrow.vector.complex.ListVector; -import org.apache.arrow.vector.complex.MapVector; -import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.types.Types; -import java.util.Map; +public abstract class AbstractArrowFullVectorConverter + implements 
ArrowFullVectorConverter { + protected RootAllocator allocator; + protected ValueVector vector; + protected DataConversionContext context; + protected SFBaseSession session; + protected int idx; -public abstract class AbstractArrowFullVectorConverter implements ArrowFullVectorConverter { - protected RootAllocator allocator; - protected ValueVector vector; - protected DataConversionContext context; - protected SFBaseSession session; - protected int idx; + public AbstractArrowFullVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + int idx) { + this.allocator = allocator; + this.vector = vector; + this.context = context; + this.session = session; + this.idx = idx; + } - public AbstractArrowFullVectorConverter(RootAllocator allocator, ValueVector vector, DataConversionContext context, - SFBaseSession session, int idx) { - this.allocator = allocator; - this.vector = vector; - this.context = context; - this.session = session; - this.idx = idx; - } - - private static Types.MinorType deduceType(ValueVector vector) { - Types.MinorType type = Types.getMinorTypeForArrowType(vector.getField().getType()); - // each column's metadata - Map customMeta = vector.getField().getMetadata(); - if (type == Types.MinorType.DECIMAL) { - // Note: Decimal vector is different from others - return Types.MinorType.DECIMAL; - } else if (!customMeta.isEmpty()) { - SnowflakeType st = SnowflakeType.valueOf(customMeta.get("logicalType")); - switch (st) { - case FIXED: - { - String scaleStr = vector.getField().getMetadata().get("scale"); - int sfScale = Integer.parseInt(scaleStr); - if (sfScale != 0) { - return Types.MinorType.DECIMAL; - } - break; - } - case TIME: - return Types.MinorType.TIMEMILLI; - case TIMESTAMP_LTZ: - { - String scaleStr = vector.getField().getMetadata().get("scale"); - int sfScale = Integer.parseInt(scaleStr); - switch (sfScale) { - case 0: - return Types.MinorType.TIMESTAMPSECTZ; - case 3: - return Types.MinorType.TIMESTAMPMILLITZ; - case 6: - return Types.MinorType.TIMESTAMPMICROTZ; - case 9: - return Types.MinorType.TIMESTAMPNANOTZ; - } - break; - } - case TIMESTAMP_TZ: - { - String scaleStr = vector.getField().getMetadata().get("scale"); - int sfScale = Integer.parseInt(scaleStr); - switch (sfScale) { - case 0: - return Types.MinorType.TIMESTAMPSECTZ; - case 3: - return Types.MinorType.TIMESTAMPMILLITZ; - case 6: - return Types.MinorType.TIMESTAMPMICROTZ; - case 9: - return Types.MinorType.TIMESTAMPNANOTZ; - } - break; - } - case TIMESTAMP_NTZ: - { - String scaleStr = vector.getField().getMetadata().get("scale"); - int sfScale = Integer.parseInt(scaleStr); - switch (sfScale) { - case 0: - return Types.MinorType.TIMESTAMPSEC; - case 3: - return Types.MinorType.TIMESTAMPMILLI; - case 6: - return Types.MinorType.TIMESTAMPMICRO; - case 9: - return Types.MinorType.TIMESTAMPNANO; - } - break; - } + private static Types.MinorType deduceType(ValueVector vector) { + Types.MinorType type = Types.getMinorTypeForArrowType(vector.getField().getType()); + // each column's metadata + Map customMeta = vector.getField().getMetadata(); + if (type == Types.MinorType.DECIMAL) { + // Note: Decimal vector is different from others + return Types.MinorType.DECIMAL; + } else if (!customMeta.isEmpty()) { + SnowflakeType st = SnowflakeType.valueOf(customMeta.get("logicalType")); + switch (st) { + case FIXED: + { + String scaleStr = vector.getField().getMetadata().get("scale"); + int sfScale = Integer.parseInt(scaleStr); + if (sfScale != 0) { + return 
Types.MinorType.DECIMAL; } - } - return type; - } - - public static FieldVector convert(RootAllocator allocator, ValueVector vector, DataConversionContext context, - SFBaseSession session, int idx, Object targetType) throws SnowflakeSQLException { - try { - if (targetType == null) { - targetType = deduceType(vector); + break; + } + case TIME: + return Types.MinorType.TIMEMILLI; + case TIMESTAMP_LTZ: + { + String scaleStr = vector.getField().getMetadata().get("scale"); + int sfScale = Integer.parseInt(scaleStr); + switch (sfScale) { + case 0: + return Types.MinorType.TIMESTAMPSECTZ; + case 3: + return Types.MinorType.TIMESTAMPMILLITZ; + case 6: + return Types.MinorType.TIMESTAMPMICROTZ; + case 9: + return Types.MinorType.TIMESTAMPNANOTZ; + } + break; + } + case TIMESTAMP_TZ: + { + String scaleStr = vector.getField().getMetadata().get("scale"); + int sfScale = Integer.parseInt(scaleStr); + switch (sfScale) { + case 0: + return Types.MinorType.TIMESTAMPSECTZ; + case 3: + return Types.MinorType.TIMESTAMPMILLITZ; + case 6: + return Types.MinorType.TIMESTAMPMICROTZ; + case 9: + return Types.MinorType.TIMESTAMPNANOTZ; } - if (targetType instanceof Types.MinorType) { - switch ((Types.MinorType) targetType) { - case TINYINT: - return new TinyIntVectorConverter(allocator, vector, context, session, idx).convert(); - case SMALLINT: - return new SmallIntVectorConverter(allocator, vector, context, session, idx).convert(); - case INT: - return new IntVectorConverter(allocator, vector, context, session, idx).convert(); - case BIGINT: - return new BigIntVectorConverter(allocator, vector, context, session, idx).convert(); - case DECIMAL: - return new DecimalVectorConverter(allocator, vector, context, session, idx).convert(); - } + break; + } + case TIMESTAMP_NTZ: + { + String scaleStr = vector.getField().getMetadata().get("scale"); + int sfScale = Integer.parseInt(scaleStr); + switch (sfScale) { + case 0: + return Types.MinorType.TIMESTAMPSEC; + case 3: + return Types.MinorType.TIMESTAMPMILLI; + case 6: + return Types.MinorType.TIMESTAMPMICRO; + case 9: + return Types.MinorType.TIMESTAMPNANO; } - } catch (SFException ex) { - throw new SnowflakeSQLException( - ex.getCause(), ex.getSqlState(), ex.getVendorCode(), ex.getParams()); + break; + } + } + } + return type; + } + + public static FieldVector convert( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + int idx, + Object targetType) + throws SnowflakeSQLException { + try { + if (targetType == null) { + targetType = deduceType(vector); + } + if (targetType instanceof Types.MinorType) { + switch ((Types.MinorType) targetType) { + case TINYINT: + return new TinyIntVectorConverter(allocator, vector, context, session, idx).convert(); + case SMALLINT: + return new SmallIntVectorConverter(allocator, vector, context, session, idx).convert(); + case INT: + return new IntVectorConverter(allocator, vector, context, session, idx).convert(); + case BIGINT: + return new BigIntVectorConverter(allocator, vector, context, session, idx).convert(); + case DECIMAL: + return new DecimalVectorConverter(allocator, vector, context, session, idx).convert(); } - return null; + } + } catch (SFException ex) { + throw new SnowflakeSQLException( + ex.getCause(), ex.getSqlState(), ex.getVendorCode(), ex.getParams()); } + return null; + } } diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java 
b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java index 13a351e64..c8638ba19 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java @@ -1,11 +1,9 @@ package net.snowflake.client.core.arrow.fullvectorconverters; import net.snowflake.client.core.SFException; -import net.snowflake.client.core.arrow.ArrowVectorConverter; import net.snowflake.client.jdbc.SnowflakeSQLException; import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.ValueVector; public interface ArrowFullVectorConverter { - FieldVector convert() throws SFException, SnowflakeSQLException; + FieldVector convert() throws SFException, SnowflakeSQLException; } diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BigIntVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BigIntVectorConverter.java index b3c8df4d7..cc1325e6b 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BigIntVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BigIntVectorConverter.java @@ -10,24 +10,30 @@ public class BigIntVectorConverter extends SimpleArrowFullVectorConverter { - public BigIntVectorConverter(RootAllocator allocator, ValueVector vector, DataConversionContext context, SFBaseSession session, int idx) { - super(allocator, vector, context, session, idx); - } + public BigIntVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + int idx) { + super(allocator, vector, context, session, idx); + } - @Override - protected boolean matchingType() { - return (vector instanceof BigIntVector); - } + @Override + protected boolean matchingType() { + return (vector instanceof BigIntVector); + } - @Override - protected BigIntVector initVector() { - BigIntVector resultVector = new BigIntVector(vector.getName(), allocator); - resultVector.allocateNew(vector.getValueCount()); - return resultVector; - } + @Override + protected BigIntVector initVector() { + BigIntVector resultVector = new BigIntVector(vector.getName(), allocator); + resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } - @Override - protected void convertValue(ArrowVectorConverter from, BigIntVector to, int idx) throws SFException { - to.set(idx, from.toLong(idx)); - } + @Override + protected void convertValue(ArrowVectorConverter from, BigIntVector to, int idx) + throws SFException { + to.set(idx, from.toLong(idx)); + } } diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DecimalVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DecimalVectorConverter.java index d48c18144..cf824b526 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DecimalVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DecimalVectorConverter.java @@ -10,28 +10,34 @@ public class DecimalVectorConverter extends SimpleArrowFullVectorConverter { - public DecimalVectorConverter(RootAllocator allocator, ValueVector vector, DataConversionContext context, SFBaseSession session, int idx) { - super(allocator, vector, context, session, idx); - } + public DecimalVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, 
+ SFBaseSession session, + int idx) { + super(allocator, vector, context, session, idx); + } - @Override - protected boolean matchingType() { - return (vector instanceof DecimalVector); - } + @Override + protected boolean matchingType() { + return (vector instanceof DecimalVector); + } - @Override - protected DecimalVector initVector() { - String scaleString = vector.getField().getMetadata().get("scale"); - String precisionString = vector.getField().getMetadata().get("precision"); - int scale = Integer.parseInt(scaleString); - int precision = Integer.parseInt(precisionString); - DecimalVector resultVector = new DecimalVector(vector.getName(), allocator, precision, scale); - resultVector.allocateNew(vector.getValueCount()); - return resultVector; - } + @Override + protected DecimalVector initVector() { + String scaleString = vector.getField().getMetadata().get("scale"); + String precisionString = vector.getField().getMetadata().get("precision"); + int scale = Integer.parseInt(scaleString); + int precision = Integer.parseInt(precisionString); + DecimalVector resultVector = new DecimalVector(vector.getName(), allocator, precision, scale); + resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } - @Override - protected void convertValue(ArrowVectorConverter from, DecimalVector to, int idx) throws SFException { - to.set(idx, from.toBigDecimal(idx)); - } + @Override + protected void convertValue(ArrowVectorConverter from, DecimalVector to, int idx) + throws SFException { + to.set(idx, from.toBigDecimal(idx)); + } } diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/IntVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/IntVectorConverter.java index 0780962d9..e18ac2b5d 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/IntVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/IntVectorConverter.java @@ -10,24 +10,29 @@ public class IntVectorConverter extends SimpleArrowFullVectorConverter { - public IntVectorConverter(RootAllocator allocator, ValueVector vector, DataConversionContext context, SFBaseSession session, int idx) { - super(allocator, vector, context, session, idx); - } + public IntVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + int idx) { + super(allocator, vector, context, session, idx); + } - @Override - protected boolean matchingType() { - return (vector instanceof IntVector); - } + @Override + protected boolean matchingType() { + return (vector instanceof IntVector); + } - @Override - protected IntVector initVector() { - IntVector resultVector = new IntVector(vector.getName(), allocator); - resultVector.allocateNew(vector.getValueCount()); - return resultVector; - } + @Override + protected IntVector initVector() { + IntVector resultVector = new IntVector(vector.getName(), allocator); + resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } - @Override - protected void convertValue(ArrowVectorConverter from, IntVector to, int idx) throws SFException { - to.set(idx, from.toInt(idx)); - } + @Override + protected void convertValue(ArrowVectorConverter from, IntVector to, int idx) throws SFException { + to.set(idx, from.toInt(idx)); + } } diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java 
b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java index 0f6ecafd5..7a39418bf 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java @@ -9,27 +9,36 @@ import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; -public abstract class SimpleArrowFullVectorConverter extends AbstractArrowFullVectorConverter { - public SimpleArrowFullVectorConverter(RootAllocator allocator, ValueVector vector, DataConversionContext context, SFBaseSession session, int idx) { - super(allocator, vector, context, session, idx); - } +public abstract class SimpleArrowFullVectorConverter + extends AbstractArrowFullVectorConverter { + public SimpleArrowFullVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + int idx) { + super(allocator, vector, context, session, idx); + } + protected abstract boolean matchingType(); - abstract protected boolean matchingType(); - abstract protected T initVector(); + protected abstract T initVector(); - abstract protected void convertValue(ArrowVectorConverter from, T to, int idx) throws SFException; + protected abstract void convertValue(ArrowVectorConverter from, T to, int idx) throws SFException; -@Override - public FieldVector convert() throws SFException, SnowflakeSQLException { - if (matchingType()) {return (FieldVector) vector;} - int size = vector.getValueCount(); - T converted = initVector(); - ArrowVectorConverter converter = ArrowVectorConverter.initConverter(vector, context, session, idx); - for (int i = 0; i < size; i++) { - convertValue(converter, converted, i); - } - converted.setValueCount(size); - return converted; + @Override + public FieldVector convert() throws SFException, SnowflakeSQLException { + if (matchingType()) { + return (FieldVector) vector; + } + int size = vector.getValueCount(); + T converted = initVector(); + ArrowVectorConverter converter = + ArrowVectorConverter.initConverter(vector, context, session, idx); + for (int i = 0; i < size; i++) { + convertValue(converter, converted, i); } + converted.setValueCount(size); + return converted; + } } diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SmallIntVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SmallIntVectorConverter.java index 073d3268c..bdc71ceee 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SmallIntVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SmallIntVectorConverter.java @@ -10,24 +10,30 @@ public class SmallIntVectorConverter extends SimpleArrowFullVectorConverter { - public SmallIntVectorConverter(RootAllocator allocator, ValueVector vector, DataConversionContext context, SFBaseSession session, int idx) { - super(allocator, vector, context, session, idx); - } + public SmallIntVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + int idx) { + super(allocator, vector, context, session, idx); + } - @Override - protected boolean matchingType() { - return (vector instanceof SmallIntVector); - } + @Override + protected boolean matchingType() { + return (vector instanceof SmallIntVector); + } - @Override - protected SmallIntVector initVector() { - 
SmallIntVector resultVector = new SmallIntVector(vector.getName(), allocator); - resultVector.allocateNew(vector.getValueCount()); - return resultVector; - } + @Override + protected SmallIntVector initVector() { + SmallIntVector resultVector = new SmallIntVector(vector.getName(), allocator); + resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } - @Override - protected void convertValue(ArrowVectorConverter from, SmallIntVector to, int idx) throws SFException { - to.set(idx, from.toShort(idx)); - } + @Override + protected void convertValue(ArrowVectorConverter from, SmallIntVector to, int idx) + throws SFException { + to.set(idx, from.toShort(idx)); + } } diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TinyIntVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TinyIntVectorConverter.java index 0e298c18e..cc070c9de 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TinyIntVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TinyIntVectorConverter.java @@ -10,24 +10,30 @@ public class TinyIntVectorConverter extends SimpleArrowFullVectorConverter { - public TinyIntVectorConverter(RootAllocator allocator, ValueVector vector, DataConversionContext context, SFBaseSession session, int idx) { - super(allocator, vector, context, session, idx); - } + public TinyIntVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + int idx) { + super(allocator, vector, context, session, idx); + } - @Override - protected boolean matchingType() { - return (vector instanceof TinyIntVector); - } + @Override + protected boolean matchingType() { + return (vector instanceof TinyIntVector); + } - @Override - protected TinyIntVector initVector() { - TinyIntVector resultVector = new TinyIntVector(vector.getName(), allocator); - resultVector.allocateNew(vector.getValueCount()); - return resultVector; - } + @Override + protected TinyIntVector initVector() { + TinyIntVector resultVector = new TinyIntVector(vector.getName(), allocator); + resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } - @Override - protected void convertValue(ArrowVectorConverter from, TinyIntVector to, int idx) throws SFException { - to.set(idx, from.toByte(idx)); - } + @Override + protected void convertValue(ArrowVectorConverter from, TinyIntVector to, int idx) + throws SFException { + to.set(idx, from.toByte(idx)); + } } diff --git a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java index aafba5bee..b08c5417f 100644 --- a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java +++ b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java @@ -9,7 +9,6 @@ import java.io.InputStream; import java.nio.channels.ClosedByInterruptException; import java.util.ArrayList; -import java.util.Iterator; import java.util.List; import net.snowflake.client.core.DataConversionContext; import net.snowflake.client.core.SFBaseSession; @@ -508,17 +507,19 @@ public ArrowBatch getArrowBatch(DataConversionContext context) { public class ArrowResultBatch implements ArrowBatch { private DataConversionContext context; - ArrowResultBatch(DataConversionContext context){ + ArrowResultBatch(DataConversionContext context) { this.context = context; } public List fetch() throws SnowflakeSQLException { List result = new ArrayList<>(); - for (List 
record: batchOfVectors){ + for (List record : batchOfVectors) { List convertedVectors = new ArrayList<>(); - for (int i = 0; i < record.size(); i++){ + for (int i = 0; i < record.size(); i++) { ValueVector vector = record.get(i); - convertedVectors.add(AbstractArrowFullVectorConverter.convert(rootAllocator, vector, context, session, i, null)); + convertedVectors.add( + AbstractArrowFullVectorConverter.convert( + rootAllocator, vector, context, session, i, null)); } result.add(new VectorSchemaRoot(convertedVectors)); } diff --git a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java index 4334447b3..90bfde7c5 100644 --- a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java +++ b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java @@ -1,5 +1,14 @@ package net.snowflake.client.jdbc; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.math.BigDecimal; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.List; import net.snowflake.client.core.SFArrowResultSet; import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.DecimalVector; @@ -10,219 +19,220 @@ import org.junit.Before; import org.junit.Test; -import java.math.BigDecimal; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.ArrayList; -import java.util.List; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - public class ArrowBatchesTest extends BaseJDBCWithSharedConnectionIT { - @Before - public void setUp() throws Exception { - try (Statement statement = connection.createStatement()) { - statement.execute("alter session set jdbc_query_result_format = 'arrow'"); - } + @Before + public void setUp() throws Exception { + try (Statement statement = connection.createStatement()) { + statement.execute("alter session set jdbc_query_result_format = 'arrow'"); } - - private static void assertNoMemoryLeaks(ResultSet rs) throws SQLException { - assertEquals(((SFArrowResultSet) rs.unwrap(SnowflakeResultSetV1.class).sfBaseResultSet).getAllocatedMemory(), 0); + } + + private static void assertNoMemoryLeaks(ResultSet rs) throws SQLException { + assertEquals( + ((SFArrowResultSet) rs.unwrap(SnowflakeResultSetV1.class).sfBaseResultSet) + .getAllocatedMemory(), + 0); + } + + @Test + public void testMultipleBatches() throws Exception { + Statement statement = connection.createStatement(); + ; + ResultSet rs = + statement.executeQuery( + "select seq1(), seq2(), seq4(), seq8() from TABLE (generator(rowcount => 30000))"); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + int totalRows = 0; + ArrayList allRoots = new ArrayList<>(); + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + allRoots.add(root); + assertTrue(root.getVector(0) instanceof TinyIntVector); + assertTrue(root.getVector(1) instanceof SmallIntVector); + assertTrue(root.getVector(2) instanceof IntVector); + assertTrue(root.getVector(3) instanceof BigIntVector); + } } - @Test - public void testMultipleBatches() throws Exception { - Statement statement = connection.createStatement();; - ResultSet rs = statement.executeQuery("select seq1(), seq2(), seq4(), seq8() 
from TABLE (generator(rowcount => 30000))"); - ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); - int totalRows = 0; - ArrayList allRoots = new ArrayList<>(); - while (batches.hasNext()) { - ArrowBatch batch = batches.next(); - List roots = batch.fetch(); - for (VectorSchemaRoot root : roots) { - totalRows += root.getRowCount(); - allRoots.add(root); - assertTrue(root.getVector(0) instanceof TinyIntVector); - assertTrue(root.getVector(1) instanceof SmallIntVector); - assertTrue(root.getVector(2) instanceof IntVector); - assertTrue(root.getVector(3) instanceof BigIntVector); - } - } - - rs.close(); - - // The memory should not be freed when closing the result set. - for (VectorSchemaRoot root : allRoots) { - assertTrue(root.getVector(0).getValueCount() > 0); - root.close(); - } - assertNoMemoryLeaks(rs); - assertEquals(30000, totalRows); + rs.close(); + // The memory should not be freed when closing the result set. + for (VectorSchemaRoot root : allRoots) { + assertTrue(root.getVector(0).getValueCount() > 0); + root.close(); } - - @Test - public void testTinyIntBatch() throws Exception { - Statement statement = connection.createStatement();; - ResultSet rs = statement.executeQuery("select 1 union select 2 union select 3;"); - ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); - - int totalRows = 0; - List values = new ArrayList<>(); - - while (batches.hasNext()) { - ArrowBatch batch = batches.next(); - List roots = batch.fetch(); - for (VectorSchemaRoot root : roots) { - totalRows += root.getRowCount(); - assertTrue(root.getVector(0) instanceof TinyIntVector); - TinyIntVector vector = (TinyIntVector) root.getVector(0); - for (int i = 0; i < root.getRowCount(); i++) { - values.add(vector.get(i)); - } - root.close(); - } + assertNoMemoryLeaks(rs); + assertEquals(30000, totalRows); + } + + @Test + public void testTinyIntBatch() throws Exception { + Statement statement = connection.createStatement(); + ; + ResultSet rs = statement.executeQuery("select 1 union select 2 union select 3;"); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + int totalRows = 0; + List values = new ArrayList<>(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof TinyIntVector); + TinyIntVector vector = (TinyIntVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.get(i)); } - rs.close(); - - // All expected values are present - for(byte i = 1; i < 4; i++) { - assertTrue(values.contains(i)); - } - - assertEquals(3, totalRows); + root.close(); + } } + rs.close(); - @Test - public void testSmallIntBatch() throws Exception { - Statement statement = connection.createStatement();; - ResultSet rs = statement.executeQuery("select 129 union select 130 union select 131;"); - ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); - - int totalRows = 0; - List values = new ArrayList<>(); - - while (batches.hasNext()) { - ArrowBatch batch = batches.next(); - List roots = batch.fetch(); - for (VectorSchemaRoot root : roots) { - totalRows += root.getRowCount(); - assertTrue(root.getVector(0) instanceof SmallIntVector); - SmallIntVector vector = (SmallIntVector) root.getVector(0); - for (int i = 0; i < root.getRowCount(); i++) { - values.add(vector.get(i)); - } - root.close(); - } - } - rs.close(); + // All expected 
values are present + for (byte i = 1; i < 4; i++) { + assertTrue(values.contains(i)); + } - // All expected values are present - for(short i = 129; i < 132; i++) { - assertTrue(values.contains(i)); + assertEquals(3, totalRows); + } + + @Test + public void testSmallIntBatch() throws Exception { + Statement statement = connection.createStatement(); + ; + ResultSet rs = statement.executeQuery("select 129 union select 130 union select 131;"); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + int totalRows = 0; + List values = new ArrayList<>(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof SmallIntVector); + SmallIntVector vector = (SmallIntVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.get(i)); } - - assertEquals(3, totalRows); + root.close(); + } } + rs.close(); - @Test - public void testIntBatch() throws Exception { - Statement statement = connection.createStatement();; - ResultSet rs = statement.executeQuery("select 100000 union select 100001 union select 100002;"); - ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); - - int totalRows = 0; - List values = new ArrayList<>(); - - while (batches.hasNext()) { - ArrowBatch batch = batches.next(); - List roots = batch.fetch(); - for (VectorSchemaRoot root : roots) { - totalRows += root.getRowCount(); - assertTrue(root.getVector(0) instanceof IntVector); - IntVector vector = (IntVector) root.getVector(0); - for (int i = 0; i < root.getRowCount(); i++) { - values.add(vector.get(i)); - } - root.close(); - } - } - rs.close(); + // All expected values are present + for (short i = 129; i < 132; i++) { + assertTrue(values.contains(i)); + } - // All expected values are present - for(int i = 100000; i < 100003; i++) { - assertTrue(values.contains(i)); + assertEquals(3, totalRows); + } + + @Test + public void testIntBatch() throws Exception { + Statement statement = connection.createStatement(); + ; + ResultSet rs = statement.executeQuery("select 100000 union select 100001 union select 100002;"); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + int totalRows = 0; + List values = new ArrayList<>(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof IntVector); + IntVector vector = (IntVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.get(i)); } - - assertEquals(3, totalRows); + root.close(); + } } + rs.close(); - @Test - public void testBigIntBatch() throws Exception { - Statement statement = connection.createStatement();; - ResultSet rs = statement.executeQuery("select 10000000000 union select 10000000001 union select 10000000002;"); - ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); - - int totalRows = 0; - List values = new ArrayList<>(); - - while (batches.hasNext()) { - ArrowBatch batch = batches.next(); - List roots = batch.fetch(); - for (VectorSchemaRoot root : roots) { - totalRows += root.getRowCount(); - assertTrue(root.getVector(0) instanceof BigIntVector); - BigIntVector vector = (BigIntVector) root.getVector(0); - for (int i = 0; i < root.getRowCount(); i++) { - values.add(vector.get(i)); - } - 
root.close(); - } - } - rs.close(); + // All expected values are present + for (int i = 100000; i < 100003; i++) { + assertTrue(values.contains(i)); + } - // All expected values are present - for(long i = 10000000000L; i < 10000000003L; i++) { - assertTrue(values.contains(i)); + assertEquals(3, totalRows); + } + + @Test + public void testBigIntBatch() throws Exception { + Statement statement = connection.createStatement(); + ; + ResultSet rs = + statement.executeQuery( + "select 10000000000 union select 10000000001 union select 10000000002;"); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + int totalRows = 0; + List values = new ArrayList<>(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof BigIntVector); + BigIntVector vector = (BigIntVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.get(i)); } + root.close(); + } + } + rs.close(); - assertEquals(3, totalRows); + // All expected values are present + for (long i = 10000000000L; i < 10000000003L; i++) { + assertTrue(values.contains(i)); } - @Test - public void testDecimalBatch() throws Exception { - Statement statement = connection.createStatement();; - ResultSet rs = statement.executeQuery("select 1.1 union select 1.2 union select 1.3;"); - ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); - - int totalRows = 0; - List values = new ArrayList<>(); - - while (batches.hasNext()) { - ArrowBatch batch = batches.next(); - List roots = batch.fetch(); - for (VectorSchemaRoot root : roots) { - totalRows += root.getRowCount(); - assertTrue(root.getVector(0) instanceof DecimalVector); - DecimalVector vector = (DecimalVector) root.getVector(0); - for (int i = 0; i < root.getRowCount(); i++) { - values.add(vector.getObject(i)); - } - root.close(); - } + assertEquals(3, totalRows); + } + + @Test + public void testDecimalBatch() throws Exception { + Statement statement = connection.createStatement(); + ; + ResultSet rs = statement.executeQuery("select 1.1 union select 1.2 union select 1.3;"); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + int totalRows = 0; + List values = new ArrayList<>(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof DecimalVector); + DecimalVector vector = (DecimalVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.getObject(i)); } + root.close(); + } + } - rs.close(); + rs.close(); - // All expected values are present - for(int i = 1; i < 4; i++) { - assertTrue(values.contains(new BigDecimal("1."+ i))); - } - - assertEquals(3, totalRows); + // All expected values are present + for (int i = 1; i < 4; i++) { + assertTrue(values.contains(new BigDecimal("1." 
+ i))); } + + assertEquals(3, totalRows); + } } From 0fc7b5ab4d1716ef94b97c9d465738de441bf8ce Mon Sep 17 00:00:00 2001 From: sfc-gh-astachowski Date: Mon, 26 Aug 2024 10:42:15 +0200 Subject: [PATCH 04/21] Import formatting --- .../net/snowflake/client/core/SFArrowResultSet.java | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/main/java/net/snowflake/client/core/SFArrowResultSet.java b/src/main/java/net/snowflake/client/core/SFArrowResultSet.java index 50580ecee..3e3c195e5 100644 --- a/src/main/java/net/snowflake/client/core/SFArrowResultSet.java +++ b/src/main/java/net/snowflake/client/core/SFArrowResultSet.java @@ -30,8 +30,16 @@ import net.snowflake.client.core.arrow.VarCharConverter; import net.snowflake.client.core.arrow.VectorTypeConverter; import net.snowflake.client.core.json.Converters; -import net.snowflake.client.jdbc.*; +import net.snowflake.client.jdbc.ArrowBatch; +import net.snowflake.client.jdbc.ArrowBatches; +import net.snowflake.client.jdbc.ArrowResultChunk; import net.snowflake.client.jdbc.ArrowResultChunk.ArrowChunkIterator; +import net.snowflake.client.jdbc.ErrorCode; +import net.snowflake.client.jdbc.FieldMetadata; +import net.snowflake.client.jdbc.SnowflakeResultSetSerializableV1; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import net.snowflake.client.jdbc.SnowflakeSQLLoggedException; +import net.snowflake.client.jdbc.SnowflakeUtil; import net.snowflake.client.jdbc.telemetry.Telemetry; import net.snowflake.client.jdbc.telemetry.TelemetryData; import net.snowflake.client.jdbc.telemetry.TelemetryField; From b409510e57201d43825aad3ed6de754e31ab3490 Mon Sep 17 00:00:00 2001 From: sfc-gh-astachowski Date: Mon, 26 Aug 2024 13:17:54 +0200 Subject: [PATCH 05/21] Added missing interface definitions --- .../java/net/snowflake/client/jdbc/ArrowBatch.java | 10 ++++++++++ .../java/net/snowflake/client/jdbc/ArrowBatches.java | 11 +++++++++++ 2 files changed, 21 insertions(+) create mode 100644 src/main/java/net/snowflake/client/jdbc/ArrowBatch.java create mode 100644 src/main/java/net/snowflake/client/jdbc/ArrowBatches.java diff --git a/src/main/java/net/snowflake/client/jdbc/ArrowBatch.java b/src/main/java/net/snowflake/client/jdbc/ArrowBatch.java new file mode 100644 index 000000000..ef55f4b8b --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/ArrowBatch.java @@ -0,0 +1,10 @@ +package net.snowflake.client.jdbc; + +import java.util.List; +import org.apache.arrow.vector.VectorSchemaRoot; + +public interface ArrowBatch { + List fetch() throws SnowflakeSQLException; + + long getRowCount(); +} diff --git a/src/main/java/net/snowflake/client/jdbc/ArrowBatches.java b/src/main/java/net/snowflake/client/jdbc/ArrowBatches.java new file mode 100644 index 000000000..c6fef3545 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/ArrowBatches.java @@ -0,0 +1,11 @@ +package net.snowflake.client.jdbc; + +import java.sql.SQLException; + +public interface ArrowBatches { + boolean hasNext(); + + ArrowBatch next() throws SQLException; + + long getRowCount(); +} From 1de2c394ba0e041c8397e789defcc340d7197e50 Mon Sep 17 00:00:00 2001 From: sfc-gh-astachowski Date: Tue, 27 Aug 2024 11:44:07 +0200 Subject: [PATCH 06/21] Implemented review feedback --- .../client/core/SFArrowResultSet.java | 31 ++-- .../AbstractArrowFullVectorConverter.java | 141 ------------------ .../ArrowFullVectorConverter.java | 116 ++++++++++++++ .../BigIntVectorConverter.java | 2 + .../DecimalVectorConverter.java | 2 + .../IntVectorConverter.java | 2 + 
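Editor's note (not part of the patch): a minimal consumer sketch for the ArrowBatch / ArrowBatches interfaces introduced in PATCH 05 above, modeled on the usage in ArrowBatchesTest. The connection setup, query text, and class name are illustrative assumptions, not driver code; the session is assumed to be using the Arrow result format (the tests set jdbc_query_result_format = 'arrow').

// Illustrative sketch only -- mirrors ArrowBatchesTest, not shipped with the driver.
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.Statement;
import java.util.List;
import net.snowflake.client.jdbc.ArrowBatch;
import net.snowflake.client.jdbc.ArrowBatches;
import net.snowflake.client.jdbc.SnowflakeResultSet;
import org.apache.arrow.vector.VectorSchemaRoot;

class ArrowBatchesUsageSketch {
  static long consume(Connection connection) throws Exception {
    long rows = 0;
    try (Statement statement = connection.createStatement();
        ResultSet rs =
            statement.executeQuery(
                "select seq8() from table(generator(rowcount => 100000))")) {
      // Switch from row iteration to batch iteration; after this call next() returns false.
      ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches();
      while (batches.hasNext()) {
        ArrowBatch batch = batches.next();
        List<VectorSchemaRoot> roots = batch.fetch();
        for (VectorSchemaRoot root : roots) {
          rows += root.getRowCount();
          // The converted vectors are owned by the caller (they survive rs.close()),
          // so each root must be released explicitly.
          root.close();
        }
      }
    }
    return rows;
  }
}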
.../SimpleArrowFullVectorConverter.java | 18 ++- .../SmallIntVectorConverter.java | 2 + .../TinyIntVectorConverter.java | 2 + .../snowflake/client/jdbc/ArrowBatches.java | 2 +- .../client/jdbc/ArrowResultChunk.java | 5 +- .../client/jdbc/SnowflakeChunkDownloader.java | 5 +- .../client/jdbc/ArrowBatchesTest.java | 14 +- 13 files changed, 179 insertions(+), 163 deletions(-) delete mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/AbstractArrowFullVectorConverter.java diff --git a/src/main/java/net/snowflake/client/core/SFArrowResultSet.java b/src/main/java/net/snowflake/client/core/SFArrowResultSet.java index 3e3c195e5..ea818e33d 100644 --- a/src/main/java/net/snowflake/client/core/SFArrowResultSet.java +++ b/src/main/java/net/snowflake/client/core/SFArrowResultSet.java @@ -114,8 +114,14 @@ public class SFArrowResultSet extends SFBaseResultSet implements DataConversionC */ private boolean formatDateWithTimezone; - private boolean arrowBatchesMode = false; - private boolean rowIteratorMode = false; + /** The result set should be read either only as rows or only as batches */ + private enum ReadingMode { + UNSPECIFIED, + ROW_MODE, + BATCH_MODE + } + + private ReadingMode readingMode = ReadingMode.UNSPECIFIED; @SnowflakeJdbcInternalApi protected Converters converters; @@ -258,11 +264,11 @@ private boolean fetchNextRow() throws SnowflakeSQLException { private ArrowResultChunk fetchNextChunk() throws SnowflakeSQLException { try { + logger.debug("Fetching chunk number " + nextChunkIndex); eventHandler.triggerStateTransition( BasicEvent.QueryState.CONSUMING_RESULT, String.format( BasicEvent.QueryState.CONSUMING_RESULT.getArgString(), queryId, nextChunkIndex)); - ArrowResultChunk nextChunk = (ArrowResultChunk) chunkDownloader.getNextChunkToConsume(); if (nextChunk == null) { @@ -273,7 +279,7 @@ private ArrowResultChunk fetchNextChunk() throws SnowflakeSQLException { SqlState.INTERNAL_ERROR, "Expect chunk but got null for chunk index " + nextChunkIndex); } - + logger.debug("Chunk fetched successfully."); return nextChunk; } catch (InterruptedException ex) { throw new SnowflakeSQLLoggedException( @@ -442,10 +448,14 @@ public Timestamp convertToTimestamp( */ @Override public boolean next() throws SFException, SnowflakeSQLException { - if (isClosed() || arrowBatchesMode) { + if (isClosed()) { + return false; + } + if (readingMode == ReadingMode.BATCH_MODE) { + logger.warn("Cannot read rows after getArrowBatches() was called."); return false; } - rowIteratorMode = true; + readingMode = ReadingMode.ROW_MODE; // otherwise try to fetch again if (fetchNextRow()) { @@ -779,10 +789,11 @@ public BigDecimal getBigDecimal(int columnIndex, int scale) throws SFException { } public ArrowBatches getArrowBatches() { - if (rowIteratorMode) { + if (readingMode == ReadingMode.ROW_MODE) { + logger.warn("Cannot read arrow batches after next() was called."); return null; } - arrowBatchesMode = true; + readingMode = ReadingMode.BATCH_MODE; return new SFArrowBatchesIterator(); } @@ -790,8 +801,8 @@ private class SFArrowBatchesIterator implements ArrowBatches { private boolean firstFetched = false; @Override - public long getRowCount() { - return 0; + public long getRowCount() throws SQLException { + return resultSetSerializable.getRowCount(); } @Override diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/AbstractArrowFullVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/AbstractArrowFullVectorConverter.java deleted file mode 
100644 index 8d580a8af..000000000 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/AbstractArrowFullVectorConverter.java +++ /dev/null @@ -1,141 +0,0 @@ -package net.snowflake.client.core.arrow.fullvectorconverters; - -import java.util.Map; -import net.snowflake.client.core.DataConversionContext; -import net.snowflake.client.core.SFBaseSession; -import net.snowflake.client.core.SFException; -import net.snowflake.client.jdbc.SnowflakeSQLException; -import net.snowflake.client.jdbc.SnowflakeType; -import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.ValueVector; -import org.apache.arrow.vector.types.Types; - -public abstract class AbstractArrowFullVectorConverter - implements ArrowFullVectorConverter { - protected RootAllocator allocator; - protected ValueVector vector; - protected DataConversionContext context; - protected SFBaseSession session; - protected int idx; - - public AbstractArrowFullVectorConverter( - RootAllocator allocator, - ValueVector vector, - DataConversionContext context, - SFBaseSession session, - int idx) { - this.allocator = allocator; - this.vector = vector; - this.context = context; - this.session = session; - this.idx = idx; - } - - private static Types.MinorType deduceType(ValueVector vector) { - Types.MinorType type = Types.getMinorTypeForArrowType(vector.getField().getType()); - // each column's metadata - Map customMeta = vector.getField().getMetadata(); - if (type == Types.MinorType.DECIMAL) { - // Note: Decimal vector is different from others - return Types.MinorType.DECIMAL; - } else if (!customMeta.isEmpty()) { - SnowflakeType st = SnowflakeType.valueOf(customMeta.get("logicalType")); - switch (st) { - case FIXED: - { - String scaleStr = vector.getField().getMetadata().get("scale"); - int sfScale = Integer.parseInt(scaleStr); - if (sfScale != 0) { - return Types.MinorType.DECIMAL; - } - break; - } - case TIME: - return Types.MinorType.TIMEMILLI; - case TIMESTAMP_LTZ: - { - String scaleStr = vector.getField().getMetadata().get("scale"); - int sfScale = Integer.parseInt(scaleStr); - switch (sfScale) { - case 0: - return Types.MinorType.TIMESTAMPSECTZ; - case 3: - return Types.MinorType.TIMESTAMPMILLITZ; - case 6: - return Types.MinorType.TIMESTAMPMICROTZ; - case 9: - return Types.MinorType.TIMESTAMPNANOTZ; - } - break; - } - case TIMESTAMP_TZ: - { - String scaleStr = vector.getField().getMetadata().get("scale"); - int sfScale = Integer.parseInt(scaleStr); - switch (sfScale) { - case 0: - return Types.MinorType.TIMESTAMPSECTZ; - case 3: - return Types.MinorType.TIMESTAMPMILLITZ; - case 6: - return Types.MinorType.TIMESTAMPMICROTZ; - case 9: - return Types.MinorType.TIMESTAMPNANOTZ; - } - break; - } - case TIMESTAMP_NTZ: - { - String scaleStr = vector.getField().getMetadata().get("scale"); - int sfScale = Integer.parseInt(scaleStr); - switch (sfScale) { - case 0: - return Types.MinorType.TIMESTAMPSEC; - case 3: - return Types.MinorType.TIMESTAMPMILLI; - case 6: - return Types.MinorType.TIMESTAMPMICRO; - case 9: - return Types.MinorType.TIMESTAMPNANO; - } - break; - } - } - } - return type; - } - - public static FieldVector convert( - RootAllocator allocator, - ValueVector vector, - DataConversionContext context, - SFBaseSession session, - int idx, - Object targetType) - throws SnowflakeSQLException { - try { - if (targetType == null) { - targetType = deduceType(vector); - } - if (targetType instanceof Types.MinorType) { - switch ((Types.MinorType) targetType) { - case 
TINYINT: - return new TinyIntVectorConverter(allocator, vector, context, session, idx).convert(); - case SMALLINT: - return new SmallIntVectorConverter(allocator, vector, context, session, idx).convert(); - case INT: - return new IntVectorConverter(allocator, vector, context, session, idx).convert(); - case BIGINT: - return new BigIntVectorConverter(allocator, vector, context, session, idx).convert(); - case DECIMAL: - return new DecimalVectorConverter(allocator, vector, context, session, idx).convert(); - } - } - } catch (SFException ex) { - throw new SnowflakeSQLException( - ex.getCause(), ex.getSqlState(), ex.getVendorCode(), ex.getParams()); - } - return null; - } -} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java index c8638ba19..98261096f 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java @@ -1,9 +1,125 @@ package net.snowflake.client.core.arrow.fullvectorconverters; +import java.util.Map; +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; import net.snowflake.client.jdbc.SnowflakeSQLException; +import net.snowflake.client.jdbc.SnowflakeType; +import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.types.Types; +@SnowflakeJdbcInternalApi public interface ArrowFullVectorConverter { + static Types.MinorType deduceType(ValueVector vector) { + Types.MinorType type = Types.getMinorTypeForArrowType(vector.getField().getType()); + // each column's metadata + Map customMeta = vector.getField().getMetadata(); + if (type == Types.MinorType.DECIMAL) { + // Note: Decimal vector is different from others + return Types.MinorType.DECIMAL; + } else if (!customMeta.isEmpty()) { + SnowflakeType st = SnowflakeType.valueOf(customMeta.get("logicalType")); + switch (st) { + case FIXED: + { + String scaleStr = vector.getField().getMetadata().get("scale"); + int sfScale = Integer.parseInt(scaleStr); + if (sfScale != 0) { + return Types.MinorType.DECIMAL; + } + break; + } + case TIME: + return Types.MinorType.TIMEMILLI; + case TIMESTAMP_LTZ: + { + String scaleStr = vector.getField().getMetadata().get("scale"); + int sfScale = Integer.parseInt(scaleStr); + switch (sfScale) { + case 0: + return Types.MinorType.TIMESTAMPSECTZ; + case 3: + return Types.MinorType.TIMESTAMPMILLITZ; + case 6: + return Types.MinorType.TIMESTAMPMICROTZ; + case 9: + return Types.MinorType.TIMESTAMPNANOTZ; + } + break; + } + case TIMESTAMP_TZ: + { + String scaleStr = vector.getField().getMetadata().get("scale"); + int sfScale = Integer.parseInt(scaleStr); + switch (sfScale) { + case 0: + return Types.MinorType.TIMESTAMPSECTZ; + case 3: + return Types.MinorType.TIMESTAMPMILLITZ; + case 6: + return Types.MinorType.TIMESTAMPMICROTZ; + case 9: + return Types.MinorType.TIMESTAMPNANOTZ; + } + break; + } + case TIMESTAMP_NTZ: + { + String scaleStr = vector.getField().getMetadata().get("scale"); + int sfScale = Integer.parseInt(scaleStr); + switch (sfScale) { + case 0: + return Types.MinorType.TIMESTAMPSEC; + case 3: + return Types.MinorType.TIMESTAMPMILLI; + 
case 6: + return Types.MinorType.TIMESTAMPMICRO; + case 9: + return Types.MinorType.TIMESTAMPNANO; + } + break; + } + } + } + return type; + } + + static FieldVector convert( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + int idx, + Object targetType) + throws SnowflakeSQLException { + try { + if (targetType == null) { + targetType = deduceType(vector); + } + if (targetType instanceof Types.MinorType) { + switch ((Types.MinorType) targetType) { + case TINYINT: + return new TinyIntVectorConverter(allocator, vector, context, session, idx).convert(); + case SMALLINT: + return new SmallIntVectorConverter(allocator, vector, context, session, idx).convert(); + case INT: + return new IntVectorConverter(allocator, vector, context, session, idx).convert(); + case BIGINT: + return new BigIntVectorConverter(allocator, vector, context, session, idx).convert(); + case DECIMAL: + return new DecimalVectorConverter(allocator, vector, context, session, idx).convert(); + } + } + } catch (SFException ex) { + throw new SnowflakeSQLException( + ex.getCause(), ex.getSqlState(), ex.getVendorCode(), ex.getParams()); + } + return null; + } + FieldVector convert() throws SFException, SnowflakeSQLException; } diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BigIntVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BigIntVectorConverter.java index cc1325e6b..04e90e1a4 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BigIntVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BigIntVectorConverter.java @@ -3,11 +3,13 @@ import net.snowflake.client.core.DataConversionContext; import net.snowflake.client.core.SFBaseSession; import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; import net.snowflake.client.core.arrow.ArrowVectorConverter; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.ValueVector; +@SnowflakeJdbcInternalApi public class BigIntVectorConverter extends SimpleArrowFullVectorConverter { public BigIntVectorConverter( diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DecimalVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DecimalVectorConverter.java index cf824b526..d7421f858 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DecimalVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DecimalVectorConverter.java @@ -3,11 +3,13 @@ import net.snowflake.client.core.DataConversionContext; import net.snowflake.client.core.SFBaseSession; import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; import net.snowflake.client.core.arrow.ArrowVectorConverter; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.ValueVector; +@SnowflakeJdbcInternalApi public class DecimalVectorConverter extends SimpleArrowFullVectorConverter { public DecimalVectorConverter( diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/IntVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/IntVectorConverter.java index e18ac2b5d..db199e703 100644 --- 
a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/IntVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/IntVectorConverter.java @@ -3,11 +3,13 @@ import net.snowflake.client.core.DataConversionContext; import net.snowflake.client.core.SFBaseSession; import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; import net.snowflake.client.core.arrow.ArrowVectorConverter; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.ValueVector; +@SnowflakeJdbcInternalApi public class IntVectorConverter extends SimpleArrowFullVectorConverter { public IntVectorConverter( diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java index 7a39418bf..9534b3cca 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java @@ -3,21 +3,34 @@ import net.snowflake.client.core.DataConversionContext; import net.snowflake.client.core.SFBaseSession; import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; import net.snowflake.client.core.arrow.ArrowVectorConverter; import net.snowflake.client.jdbc.SnowflakeSQLException; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; +@SnowflakeJdbcInternalApi public abstract class SimpleArrowFullVectorConverter - extends AbstractArrowFullVectorConverter { + implements ArrowFullVectorConverter { + + protected RootAllocator allocator; + protected ValueVector vector; + protected DataConversionContext context; + protected SFBaseSession session; + protected int idx; + public SimpleArrowFullVectorConverter( RootAllocator allocator, ValueVector vector, DataConversionContext context, SFBaseSession session, int idx) { - super(allocator, vector, context, session, idx); + this.allocator = allocator; + this.vector = vector; + this.context = context; + this.session = session; + this.idx = idx; } protected abstract boolean matchingType(); @@ -26,7 +39,6 @@ public SimpleArrowFullVectorConverter( protected abstract void convertValue(ArrowVectorConverter from, T to, int idx) throws SFException; - @Override public FieldVector convert() throws SFException, SnowflakeSQLException { if (matchingType()) { return (FieldVector) vector; diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SmallIntVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SmallIntVectorConverter.java index bdc71ceee..f15a027ef 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SmallIntVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SmallIntVectorConverter.java @@ -3,11 +3,13 @@ import net.snowflake.client.core.DataConversionContext; import net.snowflake.client.core.SFBaseSession; import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; import net.snowflake.client.core.arrow.ArrowVectorConverter; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.SmallIntVector; import 
org.apache.arrow.vector.ValueVector; +@SnowflakeJdbcInternalApi public class SmallIntVectorConverter extends SimpleArrowFullVectorConverter { public SmallIntVectorConverter( diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TinyIntVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TinyIntVectorConverter.java index cc070c9de..a4c7bdb22 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TinyIntVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TinyIntVectorConverter.java @@ -3,11 +3,13 @@ import net.snowflake.client.core.DataConversionContext; import net.snowflake.client.core.SFBaseSession; import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; import net.snowflake.client.core.arrow.ArrowVectorConverter; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.TinyIntVector; import org.apache.arrow.vector.ValueVector; +@SnowflakeJdbcInternalApi public class TinyIntVectorConverter extends SimpleArrowFullVectorConverter { public TinyIntVectorConverter( diff --git a/src/main/java/net/snowflake/client/jdbc/ArrowBatches.java b/src/main/java/net/snowflake/client/jdbc/ArrowBatches.java index c6fef3545..fba1d8d3e 100644 --- a/src/main/java/net/snowflake/client/jdbc/ArrowBatches.java +++ b/src/main/java/net/snowflake/client/jdbc/ArrowBatches.java @@ -7,5 +7,5 @@ public interface ArrowBatches { ArrowBatch next() throws SQLException; - long getRowCount(); + long getRowCount() throws SQLException; } diff --git a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java index b08c5417f..cecc0d61b 100644 --- a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java +++ b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java @@ -15,7 +15,7 @@ import net.snowflake.client.core.SFException; import net.snowflake.client.core.arrow.ArrowResultChunkIndexSorter; import net.snowflake.client.core.arrow.ArrowVectorConverter; -import net.snowflake.client.core.arrow.fullvectorconverters.AbstractArrowFullVectorConverter; +import net.snowflake.client.core.arrow.fullvectorconverters.ArrowFullVectorConverter; import net.snowflake.client.log.SFLogger; import net.snowflake.client.log.SFLoggerFactory; import net.snowflake.common.core.SqlState; @@ -518,8 +518,7 @@ public List fetch() throws SnowflakeSQLException { for (int i = 0; i < record.size(); i++) { ValueVector vector = record.get(i); convertedVectors.add( - AbstractArrowFullVectorConverter.convert( - rootAllocator, vector, context, session, i, null)); + ArrowFullVectorConverter.convert(rootAllocator, vector, context, session, i, null)); } result.add(new VectorSchemaRoot(convertedVectors)); } diff --git a/src/main/java/net/snowflake/client/jdbc/SnowflakeChunkDownloader.java b/src/main/java/net/snowflake/client/jdbc/SnowflakeChunkDownloader.java index 8f29f5702..6db9aede0 100644 --- a/src/main/java/net/snowflake/client/jdbc/SnowflakeChunkDownloader.java +++ b/src/main/java/net/snowflake/client/jdbc/SnowflakeChunkDownloader.java @@ -844,7 +844,7 @@ public DownloaderMetrics terminate() throws InterruptedException { logger.info( "Completed processing {} {} chunks for query {} in {} ms. Download took {} ms (average: {} ms)," + " parsing took {} ms (average: {} ms). 
Chunks uncompressed size: {} MB (average: {} MB)," - + " rows in chunks: {} (total: {}, average in chunk: {}), total memory used: {} MB", + + " rows in chunks: {} (total: {}, average in chunk: {}), total memory used: {} MB, free memory {} MB", chunksSize, queryResultFormat == QueryResultFormat.ARROW ? "ARROW" : "JSON", queryId, @@ -858,7 +858,8 @@ public DownloaderMetrics terminate() throws InterruptedException { rowsInChunks, firstChunkRowCount + rowsInChunks, rowsInChunks / chunksSize, - Runtime.getRuntime().totalMemory() / MB); + Runtime.getRuntime().totalMemory() / MB, + Runtime.getRuntime().freeMemory() / MB); return new DownloaderMetrics( numberMillisWaitingForChunks, diff --git a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java index 90bfde7c5..86e7b1cad 100644 --- a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java +++ b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java @@ -16,6 +16,7 @@ import org.apache.arrow.vector.SmallIntVector; import org.apache.arrow.vector.TinyIntVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -27,6 +28,13 @@ public void setUp() throws Exception { } } + @After + public void tearDown() throws Exception { + try (Statement statement = connection.createStatement()) { + statement.execute("alter session unset jdbc_query_result_format"); + } + } + private static void assertNoMemoryLeaks(ResultSet rs) throws SQLException { assertEquals( ((SFArrowResultSet) rs.unwrap(SnowflakeResultSetV1.class).sfBaseResultSet) @@ -37,11 +45,11 @@ private static void assertNoMemoryLeaks(ResultSet rs) throws SQLException { @Test public void testMultipleBatches() throws Exception { Statement statement = connection.createStatement(); - ; ResultSet rs = statement.executeQuery( - "select seq1(), seq2(), seq4(), seq8() from TABLE (generator(rowcount => 30000))"); + "select seq1(), seq2(), seq4(), seq8() from TABLE (generator(rowcount => 300000))"); ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + assertEquals(batches.getRowCount(), 300000); int totalRows = 0; ArrayList allRoots = new ArrayList<>(); while (batches.hasNext()) { @@ -65,7 +73,7 @@ public void testMultipleBatches() throws Exception { root.close(); } assertNoMemoryLeaks(rs); - assertEquals(30000, totalRows); + assertEquals(300000, totalRows); } @Test From 969c59cabc340bb8e24c19f4f90b7dc8734df72e Mon Sep 17 00:00:00 2001 From: sfc-gh-astachowski Date: Wed, 28 Aug 2024 08:39:09 +0200 Subject: [PATCH 07/21] Further review feedback --- .../client/core/SFArrowResultSet.java | 2 +- .../client/jdbc/ArrowResultChunk.java | 48 +++++++++---------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/main/java/net/snowflake/client/core/SFArrowResultSet.java b/src/main/java/net/snowflake/client/core/SFArrowResultSet.java index ea818e33d..d14b96177 100644 --- a/src/main/java/net/snowflake/client/core/SFArrowResultSet.java +++ b/src/main/java/net/snowflake/client/core/SFArrowResultSet.java @@ -279,7 +279,7 @@ private ArrowResultChunk fetchNextChunk() throws SnowflakeSQLException { SqlState.INTERNAL_ERROR, "Expect chunk but got null for chunk index " + nextChunkIndex); } - logger.debug("Chunk fetched successfully."); + logger.debug("Chunk number " + nextChunkIndex + " fetched successfully."); return nextChunk; } catch (InterruptedException ex) { throw new SnowflakeSQLLoggedException( diff --git 
a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java index cecc0d61b..a1905ec0d 100644 --- a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java +++ b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java @@ -504,6 +504,30 @@ public ArrowBatch getArrowBatch(DataConversionContext context) { return new ArrowResultBatch(context); } + private boolean sortFirstResultChunkEnabled() { + return enableSortFirstResultChunk; + } + + /** + * Empty arrow result chunk implementation. Used when rowset from server is null or empty or in + * testing + */ + private static class EmptyArrowResultChunk extends ArrowResultChunk { + EmptyArrowResultChunk() { + super("", 0, 0, 0, null, null); + } + + @Override + public final long computeNeededChunkMemory() { + return 0; + } + + @Override + public final void freeData() { + // do nothing + } + } + public class ArrowResultBatch implements ArrowBatch { private DataConversionContext context; @@ -530,28 +554,4 @@ public long getRowCount() { return rowCount; } } - - private boolean sortFirstResultChunkEnabled() { - return enableSortFirstResultChunk; - } - - /** - * Empty arrow result chunk implementation. Used when rowset from server is null or empty or in - * testing - */ - private static class EmptyArrowResultChunk extends ArrowResultChunk { - EmptyArrowResultChunk() { - super("", 0, 0, 0, null, null); - } - - @Override - public final long computeNeededChunkMemory() { - return 0; - } - - @Override - public final void freeData() { - // do nothing - } - } } From a1cfe08075749cadba39014c8cbb965296817269 Mon Sep 17 00:00:00 2001 From: sfc-gh-astachowski Date: Tue, 3 Sep 2024 14:54:18 +0200 Subject: [PATCH 08/21] Added handling of remaining types --- .../client/core/SFArrowResultSet.java | 7 +- .../ArrowFullVectorConverter.java | 81 ++--- .../BinaryVectorConverter.java | 40 +++ .../BitVectorConverter.java | 40 +++ .../DateVectorConverter.java | 51 +++ .../FloatVectorConverter.java | 41 +++ .../SimpleArrowFullVectorConverter.java | 3 + .../TimeMicroVectorConverter.java | 29 ++ .../TimeMilliVectorConverter.java | 28 ++ .../TimeNanoVectorConverter.java | 29 ++ .../TimeSecVectorConverter.java | 28 ++ .../TimeVectorConverter.java | 44 +++ .../TimestampVectorConverter.java | 165 +++++++++ .../net/snowflake/client/jdbc/ArrowBatch.java | 4 + .../client/jdbc/ArrowResultChunk.java | 18 +- .../client/jdbc/ArrowBatchesTest.java | 340 +++++++++++++++++- 16 files changed, 899 insertions(+), 49 deletions(-) create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BinaryVectorConverter.java create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BitVectorConverter.java create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DateVectorConverter.java create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/FloatVectorConverter.java create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeMicroVectorConverter.java create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeMilliVectorConverter.java create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeNanoVectorConverter.java create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeSecVectorConverter.java create mode 100644 
src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeVectorConverter.java create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java diff --git a/src/main/java/net/snowflake/client/core/SFArrowResultSet.java b/src/main/java/net/snowflake/client/core/SFArrowResultSet.java index d14b96177..9c278a873 100644 --- a/src/main/java/net/snowflake/client/core/SFArrowResultSet.java +++ b/src/main/java/net/snowflake/client/core/SFArrowResultSet.java @@ -814,10 +814,13 @@ public boolean hasNext() { public ArrowBatch next() throws SQLException { if (!firstFetched) { firstFetched = true; - return currentChunkIterator.getChunk().getArrowBatch(SFArrowResultSet.this); + return currentChunkIterator + .getChunk() + .getArrowBatch(SFArrowResultSet.this, useSessionTimezone ? sessionTimeZone : null); } else { nextChunkIndex++; - return fetchNextChunk().getArrowBatch(SFArrowResultSet.this); + return fetchNextChunk() + .getArrowBatch(SFArrowResultSet.this, useSessionTimezone ? sessionTimeZone : null); } } } diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java index 98261096f..d2edf29c9 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java @@ -1,6 +1,7 @@ package net.snowflake.client.core.arrow.fullvectorconverters; import java.util.Map; +import java.util.TimeZone; import net.snowflake.client.core.DataConversionContext; import net.snowflake.client.core.SFBaseSession; import net.snowflake.client.core.SFException; @@ -34,55 +35,27 @@ static Types.MinorType deduceType(ValueVector vector) { break; } case TIME: - return Types.MinorType.TIMEMILLI; - case TIMESTAMP_LTZ: { String scaleStr = vector.getField().getMetadata().get("scale"); int sfScale = Integer.parseInt(scaleStr); - switch (sfScale) { - case 0: - return Types.MinorType.TIMESTAMPSECTZ; - case 3: - return Types.MinorType.TIMESTAMPMILLITZ; - case 6: - return Types.MinorType.TIMESTAMPMICROTZ; - case 9: - return Types.MinorType.TIMESTAMPNANOTZ; + if (sfScale == 0) { + return Types.MinorType.TIMESEC; } - break; - } - case TIMESTAMP_TZ: - { - String scaleStr = vector.getField().getMetadata().get("scale"); - int sfScale = Integer.parseInt(scaleStr); - switch (sfScale) { - case 0: - return Types.MinorType.TIMESTAMPSECTZ; - case 3: - return Types.MinorType.TIMESTAMPMILLITZ; - case 6: - return Types.MinorType.TIMESTAMPMICROTZ; - case 9: - return Types.MinorType.TIMESTAMPNANOTZ; + if (sfScale <= 3) { + return Types.MinorType.TIMEMILLI; } - break; - } - case TIMESTAMP_NTZ: - { - String scaleStr = vector.getField().getMetadata().get("scale"); - int sfScale = Integer.parseInt(scaleStr); - switch (sfScale) { - case 0: - return Types.MinorType.TIMESTAMPSEC; - case 3: - return Types.MinorType.TIMESTAMPMILLI; - case 6: - return Types.MinorType.TIMESTAMPMICRO; - case 9: - return Types.MinorType.TIMESTAMPNANO; + if (sfScale <= 6) { + return Types.MinorType.TIMEMICRO; + } + if (sfScale <= 9) { + return Types.MinorType.TIMENANO; } - break; } + case TIMESTAMP_NTZ: + return Types.MinorType.TIMESTAMPNANO; + case TIMESTAMP_LTZ: + case TIMESTAMP_TZ: + return Types.MinorType.TIMESTAMPNANOTZ; } } return type; @@ -93,6 +66,7 @@ static FieldVector convert( ValueVector vector, DataConversionContext context, 
SFBaseSession session, + TimeZone timeZoneToUse, int idx, Object targetType) throws SnowflakeSQLException { @@ -112,6 +86,29 @@ static FieldVector convert( return new BigIntVectorConverter(allocator, vector, context, session, idx).convert(); case DECIMAL: return new DecimalVectorConverter(allocator, vector, context, session, idx).convert(); + case FLOAT8: + return new FloatVectorConverter(allocator, vector, context, session, idx).convert(); + case BIT: + return new BitVectorConverter(allocator, vector, context, session, idx).convert(); + case VARBINARY: + return new BinaryVectorConverter(allocator, vector, context, session, idx).convert(); + case TIMESTAMPNANOTZ: + return new TimestampVectorConverter(allocator, vector, context, timeZoneToUse, false) + .convert(); + case TIMESTAMPNANO: + return new TimestampVectorConverter(allocator, vector, context, timeZoneToUse, true) + .convert(); + case DATEDAY: + return new DateVectorConverter(allocator, vector, context, session, idx, timeZoneToUse) + .convert(); + case TIMESEC: + return new TimeSecVectorConverter(allocator, vector).convert(); + case TIMEMILLI: + return new TimeMilliVectorConverter(allocator, vector).convert(); + case TIMEMICRO: + return new TimeMicroVectorConverter(allocator, vector).convert(); + case TIMENANO: + return new TimeNanoVectorConverter(allocator, vector).convert(); } } } catch (SFException ex) { diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BinaryVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BinaryVectorConverter.java new file mode 100644 index 000000000..8cee6d3f5 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BinaryVectorConverter.java @@ -0,0 +1,40 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarBinaryVector; + +@SnowflakeJdbcInternalApi +public class BinaryVectorConverter extends SimpleArrowFullVectorConverter { + public BinaryVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + int idx) { + super(allocator, vector, context, session, idx); + } + + @Override + protected boolean matchingType() { + return vector instanceof VarBinaryVector; + } + + @Override + protected VarBinaryVector initVector() { + VarBinaryVector resultVector = new VarBinaryVector(vector.getName(), allocator); + resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } + + @Override + protected void convertValue(ArrowVectorConverter from, VarBinaryVector to, int idx) + throws SFException { + to.set(idx, from.toBytes(idx)); + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BitVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BitVectorConverter.java new file mode 100644 index 000000000..76701800f --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BitVectorConverter.java @@ -0,0 +1,40 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.DataConversionContext; +import 
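// Editor's note (not part of the patch): net effect of the revised deduceType(...) in the
// ArrowFullVectorConverter hunk above, as this editor reads it -- a quick reference, not code:
//   - Arrow DECIMAL vectors, and Snowflake FIXED columns with a non-zero scale, map to DECIMAL;
//   - TIME maps by scale: 0 -> TIMESEC, <= 3 -> TIMEMILLI, <= 6 -> TIMEMICRO, <= 9 -> TIMENANO;
//   - TIMESTAMP_NTZ maps to TIMESTAMPNANO; TIMESTAMP_LTZ and TIMESTAMP_TZ map to TIMESTAMPNANOTZ;
//   - anything else keeps the minor type of the incoming vector.
// convert(...) then dispatches to the matching full-vector converter (including the new FLOAT8,
// BIT, VARBINARY, DATEDAY and TIME*/TIMESTAMP* cases) and still returns null for unhandled types.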
net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.ValueVector; + +@SnowflakeJdbcInternalApi +public class BitVectorConverter extends SimpleArrowFullVectorConverter { + + public BitVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + int idx) { + super(allocator, vector, context, session, idx); + } + + @Override + protected boolean matchingType() { + return vector instanceof BitVector; + } + + @Override + protected BitVector initVector() { + BitVector resultVector = new BitVector(vector.getName(), allocator); + resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } + + @Override + protected void convertValue(ArrowVectorConverter from, BitVector to, int idx) throws SFException { + to.set(idx, from.toBoolean(idx) ? 1 : 0); + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DateVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DateVectorConverter.java new file mode 100644 index 000000000..38b04283a --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DateVectorConverter.java @@ -0,0 +1,51 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import java.util.TimeZone; +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.ValueVector; + +@SnowflakeJdbcInternalApi +public class DateVectorConverter extends SimpleArrowFullVectorConverter { + private TimeZone timeZone; + + public DateVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + int idx, + TimeZone timeZone) { + super(allocator, vector, context, session, idx); + this.timeZone = timeZone; + } + + @Override + protected boolean matchingType() { + return vector instanceof DateDayVector; + } + + @Override + protected DateDayVector initVector() { + DateDayVector resultVector = new DateDayVector(vector.getName(), allocator); + resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } + + @Override + protected void additionalConverterInit(ArrowVectorConverter converter) { + converter.setSessionTimeZone(timeZone); + converter.setUseSessionTimezone(true); + } + + @Override + protected void convertValue(ArrowVectorConverter from, DateDayVector to, int idx) + throws SFException { + to.set(idx, (int) (from.toDate(idx, null, false).getTime() / (1000 * 3600 * 24))); + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/FloatVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/FloatVectorConverter.java new file mode 100644 index 000000000..e47079293 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/FloatVectorConverter.java @@ -0,0 +1,41 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import 
net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.ValueVector; + +@SnowflakeJdbcInternalApi +public class FloatVectorConverter extends SimpleArrowFullVectorConverter { + + public FloatVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + int idx) { + super(allocator, vector, context, session, idx); + } + + @Override + protected boolean matchingType() { + return vector instanceof Float8Vector; + } + + @Override + protected Float8Vector initVector() { + Float8Vector resultVector = new Float8Vector(vector.getName(), allocator); + resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } + + @Override + protected void convertValue(ArrowVectorConverter from, Float8Vector to, int idx) + throws SFException { + to.set(idx, from.toDouble(idx)); + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java index 9534b3cca..001145658 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java @@ -39,6 +39,8 @@ public SimpleArrowFullVectorConverter( protected abstract void convertValue(ArrowVectorConverter from, T to, int idx) throws SFException; + protected void additionalConverterInit(ArrowVectorConverter converter) {} + public FieldVector convert() throws SFException, SnowflakeSQLException { if (matchingType()) { return (FieldVector) vector; @@ -47,6 +49,7 @@ public FieldVector convert() throws SFException, SnowflakeSQLException { T converted = initVector(); ArrowVectorConverter converter = ArrowVectorConverter.initConverter(vector, context, session, idx); + additionalConverterInit(converter); for (int i = 0; i < size; i++) { convertValue(converter, converted, i); } diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeMicroVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeMicroVectorConverter.java new file mode 100644 index 000000000..93bc6318e --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeMicroVectorConverter.java @@ -0,0 +1,29 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.ValueVector; + +@SnowflakeJdbcInternalApi +public class TimeMicroVectorConverter extends TimeVectorConverter { + + public TimeMicroVectorConverter(RootAllocator allocator, ValueVector vector) { + super(allocator, vector); + } + + @Override + protected TimeMicroVector initVector() { + return new TimeMicroVector(vector.getName(), allocator); + } + + @Override + protected void convertValue(TimeMicroVector dstVector, int idx, long value) { + dstVector.set(idx, value); + } + + @Override + protected int targetScale() { + return 6; + } +} diff --git 
a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeMilliVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeMilliVectorConverter.java new file mode 100644 index 000000000..63a56c73c --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeMilliVectorConverter.java @@ -0,0 +1,28 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.ValueVector; + +@SnowflakeJdbcInternalApi +public class TimeMilliVectorConverter extends TimeVectorConverter { + public TimeMilliVectorConverter(RootAllocator allocator, ValueVector vector) { + super(allocator, vector); + } + + @Override + protected TimeMilliVector initVector() { + return new TimeMilliVector(vector.getName(), allocator); + } + + @Override + protected void convertValue(TimeMilliVector dstVector, int idx, long value) { + dstVector.set(idx, (int) value); + } + + @Override + protected int targetScale() { + return 3; + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeNanoVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeNanoVectorConverter.java new file mode 100644 index 000000000..ad91e7a67 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeNanoVectorConverter.java @@ -0,0 +1,29 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.TimeNanoVector; +import org.apache.arrow.vector.ValueVector; + +@SnowflakeJdbcInternalApi +public class TimeNanoVectorConverter extends TimeVectorConverter { + + public TimeNanoVectorConverter(RootAllocator allocator, ValueVector vector) { + super(allocator, vector); + } + + @Override + protected TimeNanoVector initVector() { + return new TimeNanoVector(vector.getName(), allocator); + } + + @Override + protected void convertValue(TimeNanoVector dstVector, int idx, long value) { + dstVector.set(idx, value); + } + + @Override + protected int targetScale() { + return 9; + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeSecVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeSecVectorConverter.java new file mode 100644 index 000000000..64498c715 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeSecVectorConverter.java @@ -0,0 +1,28 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.ValueVector; + +@SnowflakeJdbcInternalApi +public class TimeSecVectorConverter extends TimeVectorConverter { + public TimeSecVectorConverter(RootAllocator allocator, ValueVector vector) { + super(allocator, vector); + } + + @Override + protected TimeSecVector initVector() { + return new TimeSecVector(vector.getName(), allocator); + } + + @Override + protected void convertValue(TimeSecVector dstVector, int idx, long value) { + dstVector.set(idx, (int) value); + } + + @Override + protected int targetScale() { + return 0; + } +} diff --git 
a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeVectorConverter.java new file mode 100644 index 000000000..6f41dfc07 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeVectorConverter.java @@ -0,0 +1,44 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.core.arrow.ArrowResultUtil; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BaseFixedWidthVector; +import org.apache.arrow.vector.BaseIntVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; + +@SnowflakeJdbcInternalApi +public abstract class TimeVectorConverter + implements ArrowFullVectorConverter { + protected RootAllocator allocator; + protected ValueVector vector; + + public TimeVectorConverter(RootAllocator allocator, ValueVector vector) { + this.allocator = allocator; + this.vector = vector; + } + + protected abstract T initVector(); + + protected abstract void convertValue(T dstVector, int idx, long value); + + protected abstract int targetScale(); + + @Override + public FieldVector convert() throws SFException, SnowflakeSQLException { + int size = vector.getValueCount(); + T converted = initVector(); + converted.allocateNew(size); + BaseIntVector srcVector = (BaseIntVector) vector; + int scale = Integer.parseInt(vector.getField().getMetadata().get("scale")); + long scalingFactor = ArrowResultUtil.powerOfTen(targetScale() - scale); + for (int i = 0; i < size; i++) { + convertValue(converted, i, srcVector.getValueAsLong(i) * scalingFactor); + } + converted.setValueCount(size); + return converted; + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java new file mode 100644 index 000000000..448a7f63f --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java @@ -0,0 +1,165 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import java.util.ArrayList; +import java.util.List; +import java.util.TimeZone; +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.core.arrow.ArrowResultUtil; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import net.snowflake.client.util.SFPair; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.types.pojo.Field; + +@SnowflakeJdbcInternalApi +public class TimestampVectorConverter implements ArrowFullVectorConverter { + private RootAllocator allocator; + private ValueVector vector; + private DataConversionContext context; + private TimeZone timeZoneToUse; + private boolean isNTZ; + + /** Field names of the struct vectors used by timestamp */ + private static final String FIELD_NAME_EPOCH = "epoch"; // seconds since epoch + + private static 
final String FIELD_NAME_TIME_ZONE_INDEX = "timezone"; // time zone index + private static final String FIELD_NAME_FRACTION = "fraction"; // fraction in nanoseconds + + public TimestampVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + TimeZone timeZoneToUse, + boolean isNTZ) { + this.allocator = allocator; + this.vector = vector; + this.context = context; + this.timeZoneToUse = timeZoneToUse; + this.isNTZ = isNTZ; + } + + private IntVector makeVectorOfZeroes(int length) { + IntVector vector = new IntVector(FIELD_NAME_FRACTION, allocator); + vector.allocateNew(length); + vector.zeroVector(); + vector.setValueCount(length); + return vector; + } + + private IntVector makeVectorOfUTCOffsets(int length) { + IntVector vector = new IntVector(FIELD_NAME_TIME_ZONE_INDEX, allocator); + vector.allocateNew(length); + vector.setValueCount(length); + for (int i = 0; i < length; i++) { + vector.set(i, 1440); + } + return vector; + } + + private SFPair normalizeTimeSinceEpoch(BigIntVector vector) { + int length = vector.getValueCount(); + int scale = Integer.parseInt(vector.getField().getMetadata().get("scale")); + if (scale == 0) { + IntVector fractions = makeVectorOfZeroes(length); + fractions + .getValidityBuffer() + .setBytes(0L, vector.getValidityBuffer(), 0L, fractions.getValidityBuffer().capacity()); + return SFPair.of(vector, fractions); + } + long scaleFactor = ArrowResultUtil.powerOfTen(scale); + long fractionScaleFactor = ArrowResultUtil.powerOfTen(9 - scale); + BigIntVector epoch = new BigIntVector(FIELD_NAME_EPOCH, allocator); + epoch.allocateNew(length); + epoch.setValueCount(length); + IntVector fractions = new IntVector(FIELD_NAME_FRACTION, allocator); + fractions.allocateNew(length); + fractions.setValueCount(length); + for (int i = 0; i < length; i++) { + epoch.set(i, vector.get(i) / scaleFactor); + fractions.set(i, (int) ((vector.get(i) % scaleFactor) * fractionScaleFactor)); + } + return SFPair.of(vector, fractions); + } + + private IntVector makeTimeZoneOffsets( + BigIntVector seconds, IntVector fractions, TimeZone timeZone) { + IntVector offsets = new IntVector(FIELD_NAME_TIME_ZONE_INDEX, allocator); + offsets.allocateNew(vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + offsets.set( + i, + 1440 + + timeZone.getOffset(seconds.get(i) * 1000 + fractions.get(i) / 1000000) + / (1000 * 60)); + } + return offsets; + } + + private StructVector pack(BigIntVector seconds, IntVector fractions, IntVector offsets) { + StructVector result = StructVector.empty(vector.getName(), allocator); + List fields = + new ArrayList() { + { + add(seconds.getField()); + add(fractions.getField()); + add(offsets.getField()); + } + }; + result.setInitialCapacity(seconds.getValueCount()); + result.initializeChildrenFromFields(fields); + seconds.makeTransferPair(result.getChild(FIELD_NAME_EPOCH)).transfer(); + fractions.makeTransferPair(result.getChild(FIELD_NAME_FRACTION)).transfer(); + offsets.makeTransferPair(result.getChild(FIELD_NAME_TIME_ZONE_INDEX)).transfer(); + result.setValueCount(vector.getValueCount()); + result + .getValidityBuffer() + .setBytes(0L, vector.getValidityBuffer(), 0L, vector.getValidityBuffer().capacity()); + return result; + } + + @Override + public FieldVector convert() throws SFException, SnowflakeSQLException { + BigIntVector seconds; + IntVector fractions; + IntVector timeZoneIndices = null; + if (vector instanceof BigIntVector) { + SFPair normalized = normalizeTimeSinceEpoch((BigIntVector) vector); + 
seconds = normalized.left; + fractions = normalized.right; + } else { + StructVector structVector = (StructVector) vector; + if (structVector.getChildrenFromFields().size() == 3) { + return structVector; + } + if (structVector.getChild(FIELD_NAME_FRACTION) == null) { + SFPair normalized = + normalizeTimeSinceEpoch(structVector.getChild(FIELD_NAME_EPOCH, BigIntVector.class)); + seconds = normalized.left; + fractions = normalized.right; + } else { + seconds = structVector.getChild(FIELD_NAME_EPOCH, BigIntVector.class); + fractions = structVector.getChild(FIELD_NAME_FRACTION, IntVector.class); + } + timeZoneIndices = structVector.getChild(FIELD_NAME_TIME_ZONE_INDEX, IntVector.class); + } + if (timeZoneIndices == null) { + if (isNTZ && context.getHonorClientTZForTimestampNTZ()) { + timeZoneIndices = makeTimeZoneOffsets(seconds, fractions, TimeZone.getDefault()); + for (int i = 0; i < vector.getValueCount(); i++) { + seconds.set(i, seconds.get(i) - (timeZoneIndices.get(i) - 1440) * 60L); + } + } else if (isNTZ || timeZoneToUse == null) { + timeZoneIndices = makeVectorOfUTCOffsets(vector.getValueCount()); + } else { + timeZoneIndices = makeTimeZoneOffsets(seconds, fractions, timeZoneToUse); + } + } + return pack(seconds, fractions, timeZoneIndices); + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/ArrowBatch.java b/src/main/java/net/snowflake/client/jdbc/ArrowBatch.java index ef55f4b8b..c9dd11c12 100644 --- a/src/main/java/net/snowflake/client/jdbc/ArrowBatch.java +++ b/src/main/java/net/snowflake/client/jdbc/ArrowBatch.java @@ -1,10 +1,14 @@ package net.snowflake.client.jdbc; import java.util.List; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorSchemaRoot; public interface ArrowBatch { List fetch() throws SnowflakeSQLException; + ArrowVectorConverter getTimestampConverter(FieldVector vector, int colIdx); + long getRowCount(); } diff --git a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java index a1905ec0d..72b74f27c 100644 --- a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java +++ b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java @@ -10,11 +10,13 @@ import java.nio.channels.ClosedByInterruptException; import java.util.ArrayList; import java.util.List; +import java.util.TimeZone; import net.snowflake.client.core.DataConversionContext; import net.snowflake.client.core.SFBaseSession; import net.snowflake.client.core.SFException; import net.snowflake.client.core.arrow.ArrowResultChunkIndexSorter; import net.snowflake.client.core.arrow.ArrowVectorConverter; +import net.snowflake.client.core.arrow.ThreeFieldStructToTimestampTZConverter; import net.snowflake.client.core.arrow.fullvectorconverters.ArrowFullVectorConverter; import net.snowflake.client.log.SFLogger; import net.snowflake.client.log.SFLoggerFactory; @@ -499,9 +501,9 @@ private void sortFirstResultChunk(List converters) } } - public ArrowBatch getArrowBatch(DataConversionContext context) { + public ArrowBatch getArrowBatch(DataConversionContext context, TimeZone timeZoneToUse) { batchesMode = true; - return new ArrowResultBatch(context); + return new ArrowResultBatch(context, timeZoneToUse); } private boolean sortFirstResultChunkEnabled() { @@ -530,9 +532,11 @@ public final void freeData() { public class ArrowResultBatch implements ArrowBatch { private DataConversionContext context; + private TimeZone timeZoneToUse; - 
ArrowResultBatch(DataConversionContext context) { + ArrowResultBatch(DataConversionContext context, TimeZone timeZoneToUse) { this.context = context; + this.timeZoneToUse = timeZoneToUse; } public List fetch() throws SnowflakeSQLException { @@ -542,13 +546,19 @@ public List fetch() throws SnowflakeSQLException { for (int i = 0; i < record.size(); i++) { ValueVector vector = record.get(i); convertedVectors.add( - ArrowFullVectorConverter.convert(rootAllocator, vector, context, session, i, null)); + ArrowFullVectorConverter.convert( + rootAllocator, vector, context, session, timeZoneToUse, i, null)); } result.add(new VectorSchemaRoot(convertedVectors)); } return result; } + @Override + public ArrowVectorConverter getTimestampConverter(FieldVector vector, int colIdx) { + return new ThreeFieldStructToTimestampTZConverter(vector, colIdx, context); + } + @Override public long getRowCount() { return rowCount; diff --git a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java index 86e7b1cad..9d0b4f5af 100644 --- a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java +++ b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java @@ -7,15 +7,28 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; +import java.sql.Timestamp; +import java.time.LocalDate; +import java.time.LocalTime; import java.util.ArrayList; import java.util.List; import net.snowflake.client.core.SFArrowResultSet; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.arrow.ArrowVectorConverter; import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeNanoVector; +import org.apache.arrow.vector.TimeSecVector; import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.StructVector; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -213,7 +226,6 @@ public void testBigIntBatch() throws Exception { @Test public void testDecimalBatch() throws Exception { Statement statement = connection.createStatement(); - ; ResultSet rs = statement.executeQuery("select 1.1 union select 1.2 union select 1.3;"); ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); @@ -243,4 +255,330 @@ public void testDecimalBatch() throws Exception { assertEquals(3, totalRows); } + + @Test + public void testBitBatch() throws Exception { + Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery( + "select true union all select false union all select true union all select false" + + " union all select true union all select false union all select true"); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + int trueCount = 0; + int falseCount = 0; + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + assertTrue(root.getVector(0) instanceof BitVector); + BitVector vector = (BitVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + if 
(vector.getObject(i)) { + trueCount++; + } else { + falseCount++; + } + } + root.close(); + } + } + + assertEquals(4, trueCount); + assertEquals(3, falseCount); + } + + @Test + public void testBinaryBatch() throws Exception { + Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery("select TO_BINARY('546AB0') union select TO_BINARY('018E3271')"); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + int totalRows = 0; + List> values = new ArrayList<>(); + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + assertTrue(root.getVector(0) instanceof VarBinaryVector); + VarBinaryVector vector = (VarBinaryVector) root.getVector(0); + totalRows += root.getRowCount(); + for (int i = 0; i < root.getRowCount(); i++) { + byte[] bytes = vector.getObject(i); + ArrayList byteArrayList = + new ArrayList() { + { + for (byte aByte : bytes) { + add(aByte); + } + } + }; + values.add(byteArrayList); + } + root.close(); + } + } + + List> expected = + new ArrayList>() { + { + add( + new ArrayList() { + { + add((byte) 0x54); + add((byte) 0x6A); + add((byte) 0xB0); + } + }); + add( + new ArrayList() { + { + add((byte) 0x01); + add((byte) 0x8E); + add((byte) 0x32); + add((byte) 0x71); + } + }); + } + }; + + assertEquals(2, totalRows); + assertTrue(values.containsAll(expected)); + } + + private void testTimestampBase(String query) throws Exception, SFException { + Statement statement = connection.createStatement(); + ResultSet rs = statement.executeQuery(query); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + ArrowBatch batch = batches.next(); + VectorSchemaRoot root = batch.fetch().get(0); + assertTrue(root.getVector(0) instanceof StructVector); + ArrowVectorConverter converter = batch.getTimestampConverter(root.getVector(0), 1); + Timestamp tsFromBatch = converter.toTimestamp(0, null); + + rs = statement.executeQuery(query); + rs.next(); + Timestamp tsFromRow = rs.getTimestamp(1); + + assertTrue(tsFromBatch.equals(tsFromRow)); + root.close(); + } + + @Test + public void testTimestampTZBatch() throws Exception, SFException { + testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_TZ"); + } + + @Test + public void testTimestampLTZUseSessionTimezoneBatch() throws Exception, SFException { + Statement statement = connection.createStatement(); + statement.execute("alter session set JDBC_USE_SESSION_TIMEZONE=true"); + testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_LTZ"); + statement.execute("alter session unset JDBC_USE_SESSION_TIMEZONE"); + } + + @Test + public void testTimestampLTZBatch() throws Exception, SFException { + testTimestampBase("select '2020-04-05 12:22:12+0700'::TIMESTAMP_LTZ"); + } + + @Test + public void testTimestampNTZBatch() throws Exception, SFException { + testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_NTZ"); + } + + @Test + public void testDateBatch() throws Exception, SFException { + Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery("select '1119-02-01'::DATE union select '2021-09-11'::DATE"); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + int totalRows = 0; + List values = new ArrayList<>(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) 
instanceof DateDayVector); + DateDayVector vector = (DateDayVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(LocalDate.ofEpochDay(vector.get(i))); + } + root.close(); + } + } + + rs.close(); + + List expected = + new ArrayList() { + { + add(LocalDate.of(1119, 2, 1)); + add(LocalDate.of(2021, 9, 11)); + } + }; + + assertEquals(2, totalRows); + assertTrue(values.containsAll(expected)); + } + + @Test + public void testTimeSecBatch() throws Exception, SFException { + Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery("select '11:32:54'::TIME(0) union select '8:11:25'::TIME(0)"); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + int totalRows = 0; + List values = new ArrayList<>(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof TimeSecVector); + TimeSecVector vector = (TimeSecVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(LocalTime.ofSecondOfDay(vector.get(i))); + } + root.close(); + } + } + + rs.close(); + + List expected = + new ArrayList() { + { + add(LocalTime.of(11, 32, 54)); + add(LocalTime.of(8, 11, 25)); + } + }; + + assertEquals(2, totalRows); + assertTrue(values.containsAll(expected)); + } + + @Test + public void testTimeMilliBatch() throws Exception, SFException { + Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery("select '11:32:54.13'::TIME(2) union select '8:11:25.91'::TIME(2)"); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + int totalRows = 0; + List values = new ArrayList<>(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof TimeMilliVector); + TimeMilliVector vector = (TimeMilliVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.getObject(i).toLocalTime()); + } + root.close(); + } + } + + rs.close(); + + List expected = + new ArrayList() { + { + add(LocalTime.of(11, 32, 54, 130 * 1000 * 1000)); + add(LocalTime.of(8, 11, 25, 910 * 1000 * 1000)); + } + }; + + assertEquals(2, totalRows); + assertTrue(values.containsAll(expected)); + } + + @Test + public void testTimeMicroBatch() throws Exception, SFException { + Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery( + "select '11:32:54.139901'::TIME(6) union select '8:11:25.911765'::TIME(6)"); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + int totalRows = 0; + List values = new ArrayList<>(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof TimeMicroVector); + TimeMicroVector vector = (TimeMicroVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(LocalTime.ofNanoOfDay(vector.get(i) * 1000)); + } + root.close(); + } + } + + rs.close(); + + List expected = + new ArrayList() { + { + add(LocalTime.of(11, 32, 54, 139901 * 1000)); + add(LocalTime.of(8, 11, 25, 911765 * 1000)); + } + }; + + assertEquals(2, totalRows); + 
assertTrue(values.containsAll(expected)); + } + + @Test + public void testTimeNanoBatch() throws Exception, SFException { + Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery( + "select '11:32:54.1399013'::TIME(7) union select '8:11:25.9117654'::TIME(7)"); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + int totalRows = 0; + List values = new ArrayList<>(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof TimeNanoVector); + TimeNanoVector vector = (TimeNanoVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(LocalTime.ofNanoOfDay(vector.get(i))); + } + root.close(); + } + } + + rs.close(); + + List expected = + new ArrayList() { + { + add(LocalTime.of(11, 32, 54, 139901300)); + add(LocalTime.of(8, 11, 25, 911765400)); + } + }; + + assertEquals(2, totalRows); + assertTrue(values.containsAll(expected)); + } } From 0fd7d0f0816a261122a947927e711be4e8bf7de9 Mon Sep 17 00:00:00 2001 From: sfc-gh-astachowski Date: Tue, 3 Sep 2024 15:35:10 +0200 Subject: [PATCH 09/21] Add null time zone check --- .../arrow/fullvectorconverters/DateVectorConverter.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DateVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DateVectorConverter.java index 38b04283a..c509af685 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DateVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DateVectorConverter.java @@ -39,8 +39,10 @@ protected DateDayVector initVector() { @Override protected void additionalConverterInit(ArrowVectorConverter converter) { - converter.setSessionTimeZone(timeZone); - converter.setUseSessionTimezone(true); + if (timeZone != null) { + converter.setSessionTimeZone(timeZone); + converter.setUseSessionTimezone(true); + } } @Override From 73d6b4d8c0ce305cd11ae548670fcf1a6c477a6b Mon Sep 17 00:00:00 2001 From: sfc-gh-astachowski Date: Tue, 3 Sep 2024 15:38:15 +0200 Subject: [PATCH 10/21] Removed timestamp support --- .../ArrowFullVectorConverter.java | 6 - .../TimestampVectorConverter.java | 165 ------------------ .../client/jdbc/ArrowBatchesTest.java | 42 ----- 3 files changed, 213 deletions(-) delete mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java index d2edf29c9..0fdc6142b 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java @@ -92,12 +92,6 @@ static FieldVector convert( return new BitVectorConverter(allocator, vector, context, session, idx).convert(); case VARBINARY: return new BinaryVectorConverter(allocator, vector, context, session, idx).convert(); - case TIMESTAMPNANOTZ: - return new TimestampVectorConverter(allocator, vector, context, timeZoneToUse, false) - .convert(); - case TIMESTAMPNANO: - return new TimestampVectorConverter(allocator, vector, 
context, timeZoneToUse, true) - .convert(); case DATEDAY: return new DateVectorConverter(allocator, vector, context, session, idx, timeZoneToUse) .convert(); diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java deleted file mode 100644 index 448a7f63f..000000000 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java +++ /dev/null @@ -1,165 +0,0 @@ -package net.snowflake.client.core.arrow.fullvectorconverters; - -import java.util.ArrayList; -import java.util.List; -import java.util.TimeZone; -import net.snowflake.client.core.DataConversionContext; -import net.snowflake.client.core.SFException; -import net.snowflake.client.core.SnowflakeJdbcInternalApi; -import net.snowflake.client.core.arrow.ArrowResultUtil; -import net.snowflake.client.jdbc.SnowflakeSQLException; -import net.snowflake.client.util.SFPair; -import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.BigIntVector; -import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.IntVector; -import org.apache.arrow.vector.ValueVector; -import org.apache.arrow.vector.complex.StructVector; -import org.apache.arrow.vector.types.pojo.Field; - -@SnowflakeJdbcInternalApi -public class TimestampVectorConverter implements ArrowFullVectorConverter { - private RootAllocator allocator; - private ValueVector vector; - private DataConversionContext context; - private TimeZone timeZoneToUse; - private boolean isNTZ; - - /** Field names of the struct vectors used by timestamp */ - private static final String FIELD_NAME_EPOCH = "epoch"; // seconds since epoch - - private static final String FIELD_NAME_TIME_ZONE_INDEX = "timezone"; // time zone index - private static final String FIELD_NAME_FRACTION = "fraction"; // fraction in nanoseconds - - public TimestampVectorConverter( - RootAllocator allocator, - ValueVector vector, - DataConversionContext context, - TimeZone timeZoneToUse, - boolean isNTZ) { - this.allocator = allocator; - this.vector = vector; - this.context = context; - this.timeZoneToUse = timeZoneToUse; - this.isNTZ = isNTZ; - } - - private IntVector makeVectorOfZeroes(int length) { - IntVector vector = new IntVector(FIELD_NAME_FRACTION, allocator); - vector.allocateNew(length); - vector.zeroVector(); - vector.setValueCount(length); - return vector; - } - - private IntVector makeVectorOfUTCOffsets(int length) { - IntVector vector = new IntVector(FIELD_NAME_TIME_ZONE_INDEX, allocator); - vector.allocateNew(length); - vector.setValueCount(length); - for (int i = 0; i < length; i++) { - vector.set(i, 1440); - } - return vector; - } - - private SFPair normalizeTimeSinceEpoch(BigIntVector vector) { - int length = vector.getValueCount(); - int scale = Integer.parseInt(vector.getField().getMetadata().get("scale")); - if (scale == 0) { - IntVector fractions = makeVectorOfZeroes(length); - fractions - .getValidityBuffer() - .setBytes(0L, vector.getValidityBuffer(), 0L, fractions.getValidityBuffer().capacity()); - return SFPair.of(vector, fractions); - } - long scaleFactor = ArrowResultUtil.powerOfTen(scale); - long fractionScaleFactor = ArrowResultUtil.powerOfTen(9 - scale); - BigIntVector epoch = new BigIntVector(FIELD_NAME_EPOCH, allocator); - epoch.allocateNew(length); - epoch.setValueCount(length); - IntVector fractions = new IntVector(FIELD_NAME_FRACTION, allocator); - fractions.allocateNew(length); - 
fractions.setValueCount(length); - for (int i = 0; i < length; i++) { - epoch.set(i, vector.get(i) / scaleFactor); - fractions.set(i, (int) ((vector.get(i) % scaleFactor) * fractionScaleFactor)); - } - return SFPair.of(vector, fractions); - } - - private IntVector makeTimeZoneOffsets( - BigIntVector seconds, IntVector fractions, TimeZone timeZone) { - IntVector offsets = new IntVector(FIELD_NAME_TIME_ZONE_INDEX, allocator); - offsets.allocateNew(vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - offsets.set( - i, - 1440 - + timeZone.getOffset(seconds.get(i) * 1000 + fractions.get(i) / 1000000) - / (1000 * 60)); - } - return offsets; - } - - private StructVector pack(BigIntVector seconds, IntVector fractions, IntVector offsets) { - StructVector result = StructVector.empty(vector.getName(), allocator); - List fields = - new ArrayList() { - { - add(seconds.getField()); - add(fractions.getField()); - add(offsets.getField()); - } - }; - result.setInitialCapacity(seconds.getValueCount()); - result.initializeChildrenFromFields(fields); - seconds.makeTransferPair(result.getChild(FIELD_NAME_EPOCH)).transfer(); - fractions.makeTransferPair(result.getChild(FIELD_NAME_FRACTION)).transfer(); - offsets.makeTransferPair(result.getChild(FIELD_NAME_TIME_ZONE_INDEX)).transfer(); - result.setValueCount(vector.getValueCount()); - result - .getValidityBuffer() - .setBytes(0L, vector.getValidityBuffer(), 0L, vector.getValidityBuffer().capacity()); - return result; - } - - @Override - public FieldVector convert() throws SFException, SnowflakeSQLException { - BigIntVector seconds; - IntVector fractions; - IntVector timeZoneIndices = null; - if (vector instanceof BigIntVector) { - SFPair normalized = normalizeTimeSinceEpoch((BigIntVector) vector); - seconds = normalized.left; - fractions = normalized.right; - } else { - StructVector structVector = (StructVector) vector; - if (structVector.getChildrenFromFields().size() == 3) { - return structVector; - } - if (structVector.getChild(FIELD_NAME_FRACTION) == null) { - SFPair normalized = - normalizeTimeSinceEpoch(structVector.getChild(FIELD_NAME_EPOCH, BigIntVector.class)); - seconds = normalized.left; - fractions = normalized.right; - } else { - seconds = structVector.getChild(FIELD_NAME_EPOCH, BigIntVector.class); - fractions = structVector.getChild(FIELD_NAME_FRACTION, IntVector.class); - } - timeZoneIndices = structVector.getChild(FIELD_NAME_TIME_ZONE_INDEX, IntVector.class); - } - if (timeZoneIndices == null) { - if (isNTZ && context.getHonorClientTZForTimestampNTZ()) { - timeZoneIndices = makeTimeZoneOffsets(seconds, fractions, TimeZone.getDefault()); - for (int i = 0; i < vector.getValueCount(); i++) { - seconds.set(i, seconds.get(i) - (timeZoneIndices.get(i) - 1440) * 60L); - } - } else if (isNTZ || timeZoneToUse == null) { - timeZoneIndices = makeVectorOfUTCOffsets(vector.getValueCount()); - } else { - timeZoneIndices = makeTimeZoneOffsets(seconds, fractions, timeZoneToUse); - } - } - return pack(seconds, fractions, timeZoneIndices); - } -} diff --git a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java index 2e10e1958..256051793 100644 --- a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java +++ b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java @@ -344,48 +344,6 @@ public void testBinaryBatch() throws Exception { assertTrue(values.containsAll(expected)); } - private void testTimestampBase(String query) throws Exception, SFException 
{ - Statement statement = connection.createStatement(); - ResultSet rs = statement.executeQuery(query); - ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); - - ArrowBatch batch = batches.next(); - VectorSchemaRoot root = batch.fetch().get(0); - assertTrue(root.getVector(0) instanceof StructVector); - ArrowVectorConverter converter = batch.getTimestampConverter(root.getVector(0), 1); - Timestamp tsFromBatch = converter.toTimestamp(0, null); - - rs = statement.executeQuery(query); - rs.next(); - Timestamp tsFromRow = rs.getTimestamp(1); - - assertTrue(tsFromBatch.equals(tsFromRow)); - root.close(); - } - - @Test - public void testTimestampTZBatch() throws Exception, SFException { - testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_TZ"); - } - - @Test - public void testTimestampLTZUseSessionTimezoneBatch() throws Exception, SFException { - Statement statement = connection.createStatement(); - statement.execute("alter session set JDBC_USE_SESSION_TIMEZONE=true"); - testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_LTZ"); - statement.execute("alter session unset JDBC_USE_SESSION_TIMEZONE"); - } - - @Test - public void testTimestampLTZBatch() throws Exception, SFException { - testTimestampBase("select '2020-04-05 12:22:12+0700'::TIMESTAMP_LTZ"); - } - - @Test - public void testTimestampNTZBatch() throws Exception, SFException { - testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_NTZ"); - } - @Test public void testDateBatch() throws Exception, SFException { Statement statement = connection.createStatement(); From 7424acf3cf64d80f1d246ffdf3eb7fa36b2abae9 Mon Sep 17 00:00:00 2001 From: sfc-gh-astachowski Date: Tue, 3 Sep 2024 15:40:46 +0200 Subject: [PATCH 11/21] Added timestamp support --- .../ArrowFullVectorConverter.java | 6 + .../TimestampVectorConverter.java | 225 ++++++++++++++++++ .../client/jdbc/ArrowBatchesTest.java | 42 ++++ 3 files changed, 273 insertions(+) create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java index 0fdc6142b..f0be5abad 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java @@ -103,6 +103,12 @@ static FieldVector convert( return new TimeMicroVectorConverter(allocator, vector).convert(); case TIMENANO: return new TimeNanoVectorConverter(allocator, vector).convert(); + case TIMESTAMPNANOTZ: + return new TimestampVectorConverter(allocator, vector, context, timeZoneToUse, false) + .convert(); + case TIMESTAMPNANO: + return new TimestampVectorConverter(allocator, vector, context, timeZoneToUse, true) + .convert(); } } } catch (SFException ex) { diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java new file mode 100644 index 000000000..bad0d8329 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java @@ -0,0 +1,225 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import java.sql.ResultSet; +import java.sql.Statement; +import java.sql.Timestamp; +import 
java.util.ArrayList; +import java.util.List; +import java.util.TimeZone; +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.core.arrow.ArrowResultUtil; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import net.snowflake.client.jdbc.ArrowBatch; +import net.snowflake.client.jdbc.ArrowBatches; +import net.snowflake.client.jdbc.SnowflakeResultSet; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import net.snowflake.client.util.SFPair; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.types.pojo.Field; + +@SnowflakeJdbcInternalApi +public class TimestampVectorConverter implements ArrowFullVectorConverter { + private RootAllocator allocator; + private ValueVector vector; + private DataConversionContext context; + private TimeZone timeZoneToUse; + private boolean isNTZ; + + /** Field names of the struct vectors used by timestamp */ + private static final String FIELD_NAME_EPOCH = "epoch"; // seconds since epoch + + private static final String FIELD_NAME_TIME_ZONE_INDEX = "timezone"; // time zone index + private static final String FIELD_NAME_FRACTION = "fraction"; // fraction in nanoseconds + + public TimestampVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + TimeZone timeZoneToUse, + boolean isNTZ) { + this.allocator = allocator; + this.vector = vector; + this.context = context; + this.timeZoneToUse = timeZoneToUse; + this.isNTZ = isNTZ; + } + + private IntVector makeVectorOfZeroes(int length) { + IntVector vector = new IntVector(FIELD_NAME_FRACTION, allocator); + vector.allocateNew(length); + vector.zeroVector(); + vector.setValueCount(length); + return vector; + } + + private IntVector makeVectorOfUTCOffsets(int length) { + IntVector vector = new IntVector(FIELD_NAME_TIME_ZONE_INDEX, allocator); + vector.allocateNew(length); + vector.setValueCount(length); + for (int i = 0; i < length; i++) { + vector.set(i, 1440); + } + return vector; + } + + private SFPair normalizeTimeSinceEpoch(BigIntVector vector) { + int length = vector.getValueCount(); + int scale = Integer.parseInt(vector.getField().getMetadata().get("scale")); + if (scale == 0) { + IntVector fractions = makeVectorOfZeroes(length); + fractions + .getValidityBuffer() + .setBytes(0L, vector.getValidityBuffer(), 0L, fractions.getValidityBuffer().capacity()); + return SFPair.of(vector, fractions); + } + long scaleFactor = ArrowResultUtil.powerOfTen(scale); + long fractionScaleFactor = ArrowResultUtil.powerOfTen(9 - scale); + BigIntVector epoch = new BigIntVector(FIELD_NAME_EPOCH, allocator); + epoch.allocateNew(length); + epoch.setValueCount(length); + IntVector fractions = new IntVector(FIELD_NAME_FRACTION, allocator); + fractions.allocateNew(length); + fractions.setValueCount(length); + for (int i = 0; i < length; i++) { + epoch.set(i, vector.get(i) / scaleFactor); + fractions.set(i, (int) ((vector.get(i) % scaleFactor) * fractionScaleFactor)); + } + return SFPair.of(vector, fractions); + } + + private IntVector makeTimeZoneOffsets( + BigIntVector seconds, IntVector fractions, TimeZone timeZone) { + IntVector offsets 
= new IntVector(FIELD_NAME_TIME_ZONE_INDEX, allocator); + offsets.allocateNew(vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + offsets.set( + i, + 1440 + + timeZone.getOffset(seconds.get(i) * 1000 + fractions.get(i) / 1000000) + / (1000 * 60)); + } + return offsets; + } + + private StructVector pack(BigIntVector seconds, IntVector fractions, IntVector offsets) { + StructVector result = StructVector.empty(vector.getName(), allocator); + List fields = + new ArrayList() { + { + add(seconds.getField()); + add(fractions.getField()); + add(offsets.getField()); + } + }; + result.setInitialCapacity(seconds.getValueCount()); + result.initializeChildrenFromFields(fields); + seconds.makeTransferPair(result.getChild(FIELD_NAME_EPOCH)).transfer(); + fractions.makeTransferPair(result.getChild(FIELD_NAME_FRACTION)).transfer(); + offsets.makeTransferPair(result.getChild(FIELD_NAME_TIME_ZONE_INDEX)).transfer(); + result.setValueCount(vector.getValueCount()); + result + .getValidityBuffer() + .setBytes(0L, vector.getValidityBuffer(), 0L, vector.getValidityBuffer().capacity()); + return result; + } + + @Override + public FieldVector convert() throws SFException, SnowflakeSQLException { + BigIntVector seconds; + IntVector fractions; + IntVector timeZoneIndices = null; + if (vector instanceof BigIntVector) { + SFPair normalized = normalizeTimeSinceEpoch((BigIntVector) vector); + seconds = normalized.left; + fractions = normalized.right; + } else { + StructVector structVector = (StructVector) vector; + if (structVector.getChildrenFromFields().size() == 3) { + return structVector; + } + if (structVector.getChild(FIELD_NAME_FRACTION) == null) { + SFPair normalized = + normalizeTimeSinceEpoch(structVector.getChild(FIELD_NAME_EPOCH, BigIntVector.class)); + seconds = normalized.left; + fractions = normalized.right; + } else { + seconds = structVector.getChild(FIELD_NAME_EPOCH, BigIntVector.class); + fractions = structVector.getChild(FIELD_NAME_FRACTION, IntVector.class); + } + timeZoneIndices = structVector.getChild(FIELD_NAME_TIME_ZONE_INDEX, IntVector.class); + } + if (timeZoneIndices == null) { + if (isNTZ && context.getHonorClientTZForTimestampNTZ()) { + timeZoneIndices = makeTimeZoneOffsets(seconds, fractions, TimeZone.getDefault()); + for (int i = 0; i < vector.getValueCount(); i++) { + seconds.set(i, seconds.get(i) - (timeZoneIndices.get(i) - 1440) * 60L); + } + } else if (isNTZ || timeZoneToUse == null) { + timeZoneIndices = makeVectorOfUTCOffsets(vector.getValueCount()); + } else { + timeZoneIndices = makeTimeZoneOffsets(seconds, fractions, timeZoneToUse); + } + } + return pack(seconds, fractions, timeZoneIndices); + } + + +} + + + /* case TIMESTAMPNANOTZ: + return new TimestampVectorConverter(allocator, vector, context, timeZoneToUse, false) + .convert(); + case TIMESTAMPNANO: + return new TimestampVectorConverter(allocator, vector, context, timeZoneToUse, true) + .convert(); + +private void testTimestampBase(String query) throws Exception, SFException { + Statement statement = connection.createStatement(); + ResultSet rs = statement.executeQuery(query); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + ArrowBatch batch = batches.next(); + VectorSchemaRoot root = batch.fetch().get(0); + assertTrue(root.getVector(0) instanceof StructVector); + ArrowVectorConverter converter = batch.getTimestampConverter(root.getVector(0), 1); + Timestamp tsFromBatch = converter.toTimestamp(0, null); + + rs = statement.executeQuery(query); + rs.next(); + 
Timestamp tsFromRow = rs.getTimestamp(1); + + assertTrue(tsFromBatch.equals(tsFromRow)); + root.close(); +} + +@Test +public void testTimestampTZBatch() throws Exception, SFException { + testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_TZ"); +} + +@Test +public void testTimestampLTZUseSessionTimezoneBatch() throws Exception, SFException { + Statement statement = connection.createStatement(); + statement.execute("alter session set JDBC_USE_SESSION_TIMEZONE=true"); + testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_LTZ"); + statement.execute("alter session unset JDBC_USE_SESSION_TIMEZONE"); +} + +@Test +public void testTimestampLTZBatch() throws Exception, SFException { + testTimestampBase("select '2020-04-05 12:22:12+0700'::TIMESTAMP_LTZ"); +} + +@Test +public void testTimestampNTZBatch() throws Exception, SFException { + testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_NTZ"); +}*/ \ No newline at end of file diff --git a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java index 256051793..2ff47a66f 100644 --- a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java +++ b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java @@ -535,4 +535,46 @@ public void testTimeNanoBatch() throws Exception, SFException { assertEquals(2, totalRows); assertTrue(values.containsAll(expected)); } + + private void testTimestampBase(String query) throws Exception, SFException { + Statement statement = connection.createStatement(); + ResultSet rs = statement.executeQuery(query); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + ArrowBatch batch = batches.next(); + VectorSchemaRoot root = batch.fetch().get(0); + assertTrue(root.getVector(0) instanceof StructVector); + ArrowVectorConverter converter = batch.getTimestampConverter(root.getVector(0), 1); + Timestamp tsFromBatch = converter.toTimestamp(0, null); + + rs = statement.executeQuery(query); + rs.next(); + Timestamp tsFromRow = rs.getTimestamp(1); + + assertTrue(tsFromBatch.equals(tsFromRow)); + root.close(); + } + + @Test + public void testTimestampTZBatch() throws Exception, SFException { + testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_TZ"); + } + + @Test + public void testTimestampLTZUseSessionTimezoneBatch() throws Exception, SFException { + Statement statement = connection.createStatement(); + statement.execute("alter session set JDBC_USE_SESSION_TIMEZONE=true"); + testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_LTZ"); + statement.execute("alter session unset JDBC_USE_SESSION_TIMEZONE"); + } + + @Test + public void testTimestampLTZBatch() throws Exception, SFException { + testTimestampBase("select '2020-04-05 12:22:12+0700'::TIMESTAMP_LTZ"); + } + + @Test + public void testTimestampNTZBatch() throws Exception, SFException { + testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_NTZ"); + } } From 3dd20a9953b4ef0e7fc528342a54c2b2d735d07d Mon Sep 17 00:00:00 2001 From: sfc-gh-astachowski Date: Tue, 3 Sep 2024 15:42:31 +0200 Subject: [PATCH 12/21] Formatting --- .../ArrowFullVectorConverter.java | 4 +- .../TimestampVectorConverter.java | 377 +++++++++--------- .../net/snowflake/client/jdbc/ArrowBatch.java | 1 - 3 files changed, 185 insertions(+), 197 deletions(-) diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java index 
f0be5abad..cd656e924 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java @@ -105,10 +105,10 @@ static FieldVector convert( return new TimeNanoVectorConverter(allocator, vector).convert(); case TIMESTAMPNANOTZ: return new TimestampVectorConverter(allocator, vector, context, timeZoneToUse, false) - .convert(); + .convert(); case TIMESTAMPNANO: return new TimestampVectorConverter(allocator, vector, context, timeZoneToUse, true) - .convert(); + .convert(); } } } catch (SFException ex) { diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java index bad0d8329..f9ce28e33 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java @@ -1,8 +1,5 @@ package net.snowflake.client.core.arrow.fullvectorconverters; -import java.sql.ResultSet; -import java.sql.Statement; -import java.sql.Timestamp; import java.util.ArrayList; import java.util.List; import java.util.TimeZone; @@ -10,10 +7,6 @@ import net.snowflake.client.core.SFException; import net.snowflake.client.core.SnowflakeJdbcInternalApi; import net.snowflake.client.core.arrow.ArrowResultUtil; -import net.snowflake.client.core.arrow.ArrowVectorConverter; -import net.snowflake.client.jdbc.ArrowBatch; -import net.snowflake.client.jdbc.ArrowBatches; -import net.snowflake.client.jdbc.SnowflakeResultSet; import net.snowflake.client.jdbc.SnowflakeSQLException; import net.snowflake.client.util.SFPair; import org.apache.arrow.memory.RootAllocator; @@ -21,205 +14,201 @@ import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.ValueVector; -import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.types.pojo.Field; @SnowflakeJdbcInternalApi public class TimestampVectorConverter implements ArrowFullVectorConverter { - private RootAllocator allocator; - private ValueVector vector; - private DataConversionContext context; - private TimeZone timeZoneToUse; - private boolean isNTZ; - - /** Field names of the struct vectors used by timestamp */ - private static final String FIELD_NAME_EPOCH = "epoch"; // seconds since epoch - - private static final String FIELD_NAME_TIME_ZONE_INDEX = "timezone"; // time zone index - private static final String FIELD_NAME_FRACTION = "fraction"; // fraction in nanoseconds - - public TimestampVectorConverter( - RootAllocator allocator, - ValueVector vector, - DataConversionContext context, - TimeZone timeZoneToUse, - boolean isNTZ) { - this.allocator = allocator; - this.vector = vector; - this.context = context; - this.timeZoneToUse = timeZoneToUse; - this.isNTZ = isNTZ; + private RootAllocator allocator; + private ValueVector vector; + private DataConversionContext context; + private TimeZone timeZoneToUse; + private boolean isNTZ; + + /** Field names of the struct vectors used by timestamp */ + private static final String FIELD_NAME_EPOCH = "epoch"; // seconds since epoch + + private static final String FIELD_NAME_TIME_ZONE_INDEX = "timezone"; // time zone index + private static final String FIELD_NAME_FRACTION = "fraction"; // fraction in 
nanoseconds + + public TimestampVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + TimeZone timeZoneToUse, + boolean isNTZ) { + this.allocator = allocator; + this.vector = vector; + this.context = context; + this.timeZoneToUse = timeZoneToUse; + this.isNTZ = isNTZ; + } + + private IntVector makeVectorOfZeroes(int length) { + IntVector vector = new IntVector(FIELD_NAME_FRACTION, allocator); + vector.allocateNew(length); + vector.zeroVector(); + vector.setValueCount(length); + return vector; + } + + private IntVector makeVectorOfUTCOffsets(int length) { + IntVector vector = new IntVector(FIELD_NAME_TIME_ZONE_INDEX, allocator); + vector.allocateNew(length); + vector.setValueCount(length); + for (int i = 0; i < length; i++) { + vector.set(i, 1440); } - - private IntVector makeVectorOfZeroes(int length) { - IntVector vector = new IntVector(FIELD_NAME_FRACTION, allocator); - vector.allocateNew(length); - vector.zeroVector(); - vector.setValueCount(length); - return vector; + return vector; + } + + private SFPair normalizeTimeSinceEpoch(BigIntVector vector) { + int length = vector.getValueCount(); + int scale = Integer.parseInt(vector.getField().getMetadata().get("scale")); + if (scale == 0) { + IntVector fractions = makeVectorOfZeroes(length); + fractions + .getValidityBuffer() + .setBytes(0L, vector.getValidityBuffer(), 0L, fractions.getValidityBuffer().capacity()); + return SFPair.of(vector, fractions); } - - private IntVector makeVectorOfUTCOffsets(int length) { - IntVector vector = new IntVector(FIELD_NAME_TIME_ZONE_INDEX, allocator); - vector.allocateNew(length); - vector.setValueCount(length); - for (int i = 0; i < length; i++) { - vector.set(i, 1440); - } - return vector; + long scaleFactor = ArrowResultUtil.powerOfTen(scale); + long fractionScaleFactor = ArrowResultUtil.powerOfTen(9 - scale); + BigIntVector epoch = new BigIntVector(FIELD_NAME_EPOCH, allocator); + epoch.allocateNew(length); + epoch.setValueCount(length); + IntVector fractions = new IntVector(FIELD_NAME_FRACTION, allocator); + fractions.allocateNew(length); + fractions.setValueCount(length); + for (int i = 0; i < length; i++) { + epoch.set(i, vector.get(i) / scaleFactor); + fractions.set(i, (int) ((vector.get(i) % scaleFactor) * fractionScaleFactor)); } - - private SFPair normalizeTimeSinceEpoch(BigIntVector vector) { - int length = vector.getValueCount(); - int scale = Integer.parseInt(vector.getField().getMetadata().get("scale")); - if (scale == 0) { - IntVector fractions = makeVectorOfZeroes(length); - fractions - .getValidityBuffer() - .setBytes(0L, vector.getValidityBuffer(), 0L, fractions.getValidityBuffer().capacity()); - return SFPair.of(vector, fractions); - } - long scaleFactor = ArrowResultUtil.powerOfTen(scale); - long fractionScaleFactor = ArrowResultUtil.powerOfTen(9 - scale); - BigIntVector epoch = new BigIntVector(FIELD_NAME_EPOCH, allocator); - epoch.allocateNew(length); - epoch.setValueCount(length); - IntVector fractions = new IntVector(FIELD_NAME_FRACTION, allocator); - fractions.allocateNew(length); - fractions.setValueCount(length); - for (int i = 0; i < length; i++) { - epoch.set(i, vector.get(i) / scaleFactor); - fractions.set(i, (int) ((vector.get(i) % scaleFactor) * fractionScaleFactor)); - } - return SFPair.of(vector, fractions); + return SFPair.of(vector, fractions); + } + + private IntVector makeTimeZoneOffsets( + BigIntVector seconds, IntVector fractions, TimeZone timeZone) { + IntVector offsets = new IntVector(FIELD_NAME_TIME_ZONE_INDEX, 
allocator); + offsets.allocateNew(vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + offsets.set( + i, + 1440 + + timeZone.getOffset(seconds.get(i) * 1000 + fractions.get(i) / 1000000) + / (1000 * 60)); } - - private IntVector makeTimeZoneOffsets( - BigIntVector seconds, IntVector fractions, TimeZone timeZone) { - IntVector offsets = new IntVector(FIELD_NAME_TIME_ZONE_INDEX, allocator); - offsets.allocateNew(vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - offsets.set( - i, - 1440 - + timeZone.getOffset(seconds.get(i) * 1000 + fractions.get(i) / 1000000) - / (1000 * 60)); - } - return offsets; - } - - private StructVector pack(BigIntVector seconds, IntVector fractions, IntVector offsets) { - StructVector result = StructVector.empty(vector.getName(), allocator); - List fields = - new ArrayList() { - { - add(seconds.getField()); - add(fractions.getField()); - add(offsets.getField()); - } - }; - result.setInitialCapacity(seconds.getValueCount()); - result.initializeChildrenFromFields(fields); - seconds.makeTransferPair(result.getChild(FIELD_NAME_EPOCH)).transfer(); - fractions.makeTransferPair(result.getChild(FIELD_NAME_FRACTION)).transfer(); - offsets.makeTransferPair(result.getChild(FIELD_NAME_TIME_ZONE_INDEX)).transfer(); - result.setValueCount(vector.getValueCount()); - result - .getValidityBuffer() - .setBytes(0L, vector.getValidityBuffer(), 0L, vector.getValidityBuffer().capacity()); - return result; + return offsets; + } + + private StructVector pack(BigIntVector seconds, IntVector fractions, IntVector offsets) { + StructVector result = StructVector.empty(vector.getName(), allocator); + List fields = + new ArrayList() { + { + add(seconds.getField()); + add(fractions.getField()); + add(offsets.getField()); + } + }; + result.setInitialCapacity(seconds.getValueCount()); + result.initializeChildrenFromFields(fields); + seconds.makeTransferPair(result.getChild(FIELD_NAME_EPOCH)).transfer(); + fractions.makeTransferPair(result.getChild(FIELD_NAME_FRACTION)).transfer(); + offsets.makeTransferPair(result.getChild(FIELD_NAME_TIME_ZONE_INDEX)).transfer(); + result.setValueCount(vector.getValueCount()); + result + .getValidityBuffer() + .setBytes(0L, vector.getValidityBuffer(), 0L, vector.getValidityBuffer().capacity()); + return result; + } + + @Override + public FieldVector convert() throws SFException, SnowflakeSQLException { + BigIntVector seconds; + IntVector fractions; + IntVector timeZoneIndices = null; + if (vector instanceof BigIntVector) { + SFPair normalized = normalizeTimeSinceEpoch((BigIntVector) vector); + seconds = normalized.left; + fractions = normalized.right; + } else { + StructVector structVector = (StructVector) vector; + if (structVector.getChildrenFromFields().size() == 3) { + return structVector; + } + if (structVector.getChild(FIELD_NAME_FRACTION) == null) { + SFPair normalized = + normalizeTimeSinceEpoch(structVector.getChild(FIELD_NAME_EPOCH, BigIntVector.class)); + seconds = normalized.left; + fractions = normalized.right; + } else { + seconds = structVector.getChild(FIELD_NAME_EPOCH, BigIntVector.class); + fractions = structVector.getChild(FIELD_NAME_FRACTION, IntVector.class); + } + timeZoneIndices = structVector.getChild(FIELD_NAME_TIME_ZONE_INDEX, IntVector.class); } - - @Override - public FieldVector convert() throws SFException, SnowflakeSQLException { - BigIntVector seconds; - IntVector fractions; - IntVector timeZoneIndices = null; - if (vector instanceof BigIntVector) { - SFPair normalized = 
normalizeTimeSinceEpoch((BigIntVector) vector); - seconds = normalized.left; - fractions = normalized.right; - } else { - StructVector structVector = (StructVector) vector; - if (structVector.getChildrenFromFields().size() == 3) { - return structVector; - } - if (structVector.getChild(FIELD_NAME_FRACTION) == null) { - SFPair normalized = - normalizeTimeSinceEpoch(structVector.getChild(FIELD_NAME_EPOCH, BigIntVector.class)); - seconds = normalized.left; - fractions = normalized.right; - } else { - seconds = structVector.getChild(FIELD_NAME_EPOCH, BigIntVector.class); - fractions = structVector.getChild(FIELD_NAME_FRACTION, IntVector.class); - } - timeZoneIndices = structVector.getChild(FIELD_NAME_TIME_ZONE_INDEX, IntVector.class); - } - if (timeZoneIndices == null) { - if (isNTZ && context.getHonorClientTZForTimestampNTZ()) { - timeZoneIndices = makeTimeZoneOffsets(seconds, fractions, TimeZone.getDefault()); - for (int i = 0; i < vector.getValueCount(); i++) { - seconds.set(i, seconds.get(i) - (timeZoneIndices.get(i) - 1440) * 60L); - } - } else if (isNTZ || timeZoneToUse == null) { - timeZoneIndices = makeVectorOfUTCOffsets(vector.getValueCount()); - } else { - timeZoneIndices = makeTimeZoneOffsets(seconds, fractions, timeZoneToUse); - } + if (timeZoneIndices == null) { + if (isNTZ && context.getHonorClientTZForTimestampNTZ()) { + timeZoneIndices = makeTimeZoneOffsets(seconds, fractions, TimeZone.getDefault()); + for (int i = 0; i < vector.getValueCount(); i++) { + seconds.set(i, seconds.get(i) - (timeZoneIndices.get(i) - 1440) * 60L); } - return pack(seconds, fractions, timeZoneIndices); + } else if (isNTZ || timeZoneToUse == null) { + timeZoneIndices = makeVectorOfUTCOffsets(vector.getValueCount()); + } else { + timeZoneIndices = makeTimeZoneOffsets(seconds, fractions, timeZoneToUse); + } } - - + return pack(seconds, fractions, timeZoneIndices); + } } - /* case TIMESTAMPNANOTZ: - return new TimestampVectorConverter(allocator, vector, context, timeZoneToUse, false) - .convert(); - case TIMESTAMPNANO: - return new TimestampVectorConverter(allocator, vector, context, timeZoneToUse, true) - .convert(); - -private void testTimestampBase(String query) throws Exception, SFException { - Statement statement = connection.createStatement(); - ResultSet rs = statement.executeQuery(query); - ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); - - ArrowBatch batch = batches.next(); - VectorSchemaRoot root = batch.fetch().get(0); - assertTrue(root.getVector(0) instanceof StructVector); - ArrowVectorConverter converter = batch.getTimestampConverter(root.getVector(0), 1); - Timestamp tsFromBatch = converter.toTimestamp(0, null); - - rs = statement.executeQuery(query); - rs.next(); - Timestamp tsFromRow = rs.getTimestamp(1); - - assertTrue(tsFromBatch.equals(tsFromRow)); - root.close(); -} - -@Test -public void testTimestampTZBatch() throws Exception, SFException { - testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_TZ"); -} - -@Test -public void testTimestampLTZUseSessionTimezoneBatch() throws Exception, SFException { - Statement statement = connection.createStatement(); - statement.execute("alter session set JDBC_USE_SESSION_TIMEZONE=true"); - testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_LTZ"); - statement.execute("alter session unset JDBC_USE_SESSION_TIMEZONE"); -} - -@Test -public void testTimestampLTZBatch() throws Exception, SFException { - testTimestampBase("select '2020-04-05 12:22:12+0700'::TIMESTAMP_LTZ"); -} - -@Test -public void 
testTimestampNTZBatch() throws Exception, SFException { - testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_NTZ"); -}*/ \ No newline at end of file + return new TimestampVectorConverter(allocator, vector, context, timeZoneToUse, false) + .convert(); + case TIMESTAMPNANO: + return new TimestampVectorConverter(allocator, vector, context, timeZoneToUse, true) + .convert(); + + private void testTimestampBase(String query) throws Exception, SFException { + Statement statement = connection.createStatement(); + ResultSet rs = statement.executeQuery(query); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + ArrowBatch batch = batches.next(); + VectorSchemaRoot root = batch.fetch().get(0); + assertTrue(root.getVector(0) instanceof StructVector); + ArrowVectorConverter converter = batch.getTimestampConverter(root.getVector(0), 1); + Timestamp tsFromBatch = converter.toTimestamp(0, null); + + rs = statement.executeQuery(query); + rs.next(); + Timestamp tsFromRow = rs.getTimestamp(1); + + assertTrue(tsFromBatch.equals(tsFromRow)); + root.close(); + } + + @Test + public void testTimestampTZBatch() throws Exception, SFException { + testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_TZ"); + } + + @Test + public void testTimestampLTZUseSessionTimezoneBatch() throws Exception, SFException { + Statement statement = connection.createStatement(); + statement.execute("alter session set JDBC_USE_SESSION_TIMEZONE=true"); + testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_LTZ"); + statement.execute("alter session unset JDBC_USE_SESSION_TIMEZONE"); + } + + @Test + public void testTimestampLTZBatch() throws Exception, SFException { + testTimestampBase("select '2020-04-05 12:22:12+0700'::TIMESTAMP_LTZ"); + } + + @Test + public void testTimestampNTZBatch() throws Exception, SFException { + testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_NTZ"); + }*/ diff --git a/src/main/java/net/snowflake/client/jdbc/ArrowBatch.java b/src/main/java/net/snowflake/client/jdbc/ArrowBatch.java index 5ef92ed16..c9dd11c12 100644 --- a/src/main/java/net/snowflake/client/jdbc/ArrowBatch.java +++ b/src/main/java/net/snowflake/client/jdbc/ArrowBatch.java @@ -3,7 +3,6 @@ import java.util.List; import net.snowflake.client.core.arrow.ArrowVectorConverter; import org.apache.arrow.vector.FieldVector; - import org.apache.arrow.vector.VectorSchemaRoot; public interface ArrowBatch { From 28fb57adef136435ccf97900e319e741936cfb43 Mon Sep 17 00:00:00 2001 From: sfc-gh-astachowski Date: Tue, 3 Sep 2024 15:45:24 +0200 Subject: [PATCH 13/21] Removed old comments --- .../TimestampVectorConverter.java | 49 ------------------- 1 file changed, 49 deletions(-) diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java index f9ce28e33..448a7f63f 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java @@ -163,52 +163,3 @@ public FieldVector convert() throws SFException, SnowflakeSQLException { return pack(seconds, fractions, timeZoneIndices); } } - - /* case TIMESTAMPNANOTZ: - return new TimestampVectorConverter(allocator, vector, context, timeZoneToUse, false) - .convert(); - case TIMESTAMPNANO: - return new TimestampVectorConverter(allocator, vector, context, timeZoneToUse, true) - 
.convert(); - - private void testTimestampBase(String query) throws Exception, SFException { - Statement statement = connection.createStatement(); - ResultSet rs = statement.executeQuery(query); - ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); - - ArrowBatch batch = batches.next(); - VectorSchemaRoot root = batch.fetch().get(0); - assertTrue(root.getVector(0) instanceof StructVector); - ArrowVectorConverter converter = batch.getTimestampConverter(root.getVector(0), 1); - Timestamp tsFromBatch = converter.toTimestamp(0, null); - - rs = statement.executeQuery(query); - rs.next(); - Timestamp tsFromRow = rs.getTimestamp(1); - - assertTrue(tsFromBatch.equals(tsFromRow)); - root.close(); - } - - @Test - public void testTimestampTZBatch() throws Exception, SFException { - testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_TZ"); - } - - @Test - public void testTimestampLTZUseSessionTimezoneBatch() throws Exception, SFException { - Statement statement = connection.createStatement(); - statement.execute("alter session set JDBC_USE_SESSION_TIMEZONE=true"); - testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_LTZ"); - statement.execute("alter session unset JDBC_USE_SESSION_TIMEZONE"); - } - - @Test - public void testTimestampLTZBatch() throws Exception, SFException { - testTimestampBase("select '2020-04-05 12:22:12+0700'::TIMESTAMP_LTZ"); - } - - @Test - public void testTimestampNTZBatch() throws Exception, SFException { - testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_NTZ"); - }*/ From afa11422de4010bc10ca7c7492c9ff0c765d00d0 Mon Sep 17 00:00:00 2001 From: sfc-gh-astachowski Date: Tue, 3 Sep 2024 15:55:38 +0200 Subject: [PATCH 14/21] Fixed memory leak and added assertion of no leaks in tests. --- .../fullvectorconverters/TimestampVectorConverter.java | 3 ++- .../java/net/snowflake/client/jdbc/ArrowBatchesTest.java | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java index 448a7f63f..ee47a2834 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java @@ -84,7 +84,7 @@ private SFPair normalizeTimeSinceEpoch(BigIntVector vec epoch.set(i, vector.get(i) / scaleFactor); fractions.set(i, (int) ((vector.get(i) % scaleFactor) * fractionScaleFactor)); } - return SFPair.of(vector, fractions); + return SFPair.of(epoch, fractions); } private IntVector makeTimeZoneOffsets( @@ -120,6 +120,7 @@ private StructVector pack(BigIntVector seconds, IntVector fractions, IntVector o result .getValidityBuffer() .setBytes(0L, vector.getValidityBuffer(), 0L, vector.getValidityBuffer().capacity()); + vector.close(); return result; } diff --git a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java index 2ff47a66f..c4172a9ba 100644 --- a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java +++ b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java @@ -49,10 +49,9 @@ public void tearDown() throws Exception { } private static void assertNoMemoryLeaks(ResultSet rs) throws SQLException { - assertEquals( + assertEquals(0, ((SFArrowResultSet) rs.unwrap(SnowflakeResultSetV1.class).sfBaseResultSet) - .getAllocatedMemory(), 
- 0); + .getAllocatedMemory()); } @Test @@ -546,13 +545,14 @@ private void testTimestampBase(String query) throws Exception, SFException { assertTrue(root.getVector(0) instanceof StructVector); ArrowVectorConverter converter = batch.getTimestampConverter(root.getVector(0), 1); Timestamp tsFromBatch = converter.toTimestamp(0, null); + root.close(); + assertNoMemoryLeaks(rs); rs = statement.executeQuery(query); rs.next(); Timestamp tsFromRow = rs.getTimestamp(1); assertTrue(tsFromBatch.equals(tsFromRow)); - root.close(); } @Test From 7c9ab7aeef82919283ee651bcb90a92a7805d58a Mon Sep 17 00:00:00 2001 From: sfc-gh-astachowski Date: Tue, 3 Sep 2024 16:00:44 +0200 Subject: [PATCH 15/21] Fixed memory leaks and added assertions of no memory leaks. --- .../SimpleArrowFullVectorConverter.java | 1 + .../TimeVectorConverter.java | 1 + .../net/snowflake/client/jdbc/ArrowBatch.java | 1 - .../client/jdbc/ArrowBatchesTest.java | 23 +++++++++++-------- 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java index 001145658..c8e1405e4 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java @@ -54,6 +54,7 @@ public FieldVector convert() throws SFException, SnowflakeSQLException { convertValue(converter, converted, i); } converted.setValueCount(size); + vector.close(); return converted; } } diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeVectorConverter.java index 6f41dfc07..baba5931a 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeVectorConverter.java @@ -39,6 +39,7 @@ public FieldVector convert() throws SFException, SnowflakeSQLException { convertValue(converted, i, srcVector.getValueAsLong(i) * scalingFactor); } converted.setValueCount(size); + vector.close(); return converted; } } diff --git a/src/main/java/net/snowflake/client/jdbc/ArrowBatch.java b/src/main/java/net/snowflake/client/jdbc/ArrowBatch.java index 5ef92ed16..c9dd11c12 100644 --- a/src/main/java/net/snowflake/client/jdbc/ArrowBatch.java +++ b/src/main/java/net/snowflake/client/jdbc/ArrowBatch.java @@ -3,7 +3,6 @@ import java.util.List; import net.snowflake.client.core.arrow.ArrowVectorConverter; import org.apache.arrow.vector.FieldVector; - import org.apache.arrow.vector.VectorSchemaRoot; public interface ArrowBatch { diff --git a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java index 256051793..647e8046f 100644 --- a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java +++ b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java @@ -7,14 +7,12 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; -import java.sql.Timestamp; import java.time.LocalDate; import java.time.LocalTime; import java.util.ArrayList; import java.util.List; import net.snowflake.client.core.SFArrowResultSet; import net.snowflake.client.core.SFException; -import 
net.snowflake.client.core.arrow.ArrowVectorConverter; import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.BitVector; import org.apache.arrow.vector.DateDayVector; @@ -28,7 +26,6 @@ import org.apache.arrow.vector.TinyIntVector; import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.complex.StructVector; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -111,6 +108,7 @@ public void testTinyIntBatch() throws Exception { root.close(); } } + assertNoMemoryLeaks(rs); rs.close(); // All expected values are present @@ -143,6 +141,7 @@ public void testSmallIntBatch() throws Exception { root.close(); } } + assertNoMemoryLeaks(rs); rs.close(); // All expected values are present @@ -175,6 +174,7 @@ public void testIntBatch() throws Exception { root.close(); } } + assertNoMemoryLeaks(rs); rs.close(); // All expected values are present @@ -209,6 +209,7 @@ public void testBigIntBatch() throws Exception { root.close(); } } + assertNoMemoryLeaks(rs); rs.close(); // All expected values are present @@ -241,7 +242,7 @@ public void testDecimalBatch() throws Exception { root.close(); } } - + assertNoMemoryLeaks(rs); rs.close(); // All expected values are present @@ -280,6 +281,8 @@ public void testBitBatch() throws Exception { root.close(); } } + assertNoMemoryLeaks(rs); + rs.close(); assertEquals(4, trueCount); assertEquals(3, falseCount); @@ -316,6 +319,8 @@ public void testBinaryBatch() throws Exception { root.close(); } } + assertNoMemoryLeaks(rs); + rs.close(); List> expected = new ArrayList>() { @@ -367,7 +372,7 @@ public void testDateBatch() throws Exception, SFException { root.close(); } } - + assertNoMemoryLeaks(rs); rs.close(); List expected = @@ -405,7 +410,7 @@ public void testTimeSecBatch() throws Exception, SFException { root.close(); } } - + assertNoMemoryLeaks(rs); rs.close(); List expected = @@ -443,7 +448,7 @@ public void testTimeMilliBatch() throws Exception, SFException { root.close(); } } - + assertNoMemoryLeaks(rs); rs.close(); List expected = @@ -482,7 +487,7 @@ public void testTimeMicroBatch() throws Exception, SFException { root.close(); } } - + assertNoMemoryLeaks(rs); rs.close(); List expected = @@ -521,7 +526,7 @@ public void testTimeNanoBatch() throws Exception, SFException { root.close(); } } - + assertNoMemoryLeaks(rs); rs.close(); List expected = From 69cfe592502f1624dd8cf099d9e1502e40df7bc4 Mon Sep 17 00:00:00 2001 From: sfc-gh-astachowski Date: Fri, 6 Sep 2024 07:32:47 +0200 Subject: [PATCH 16/21] Merge fixes --- src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java index fb9de9745..03195a4f8 100644 --- a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java +++ b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java @@ -7,6 +7,7 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; +import java.sql.Timestamp; import java.time.LocalDate; import java.time.LocalTime; import java.util.ArrayList; @@ -15,6 +16,7 @@ import net.snowflake.client.category.TestCategoryArrow; import net.snowflake.client.core.SFArrowResultSet; import net.snowflake.client.core.SFException; +import net.snowflake.client.core.arrow.ArrowVectorConverter; import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.BitVector; import 
org.apache.arrow.vector.DateDayVector; @@ -28,6 +30,7 @@ import org.apache.arrow.vector.TinyIntVector; import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.StructVector; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Rule; From 1c34bb9fe156c94e6ac54fd1da3dd77cea585ce0 Mon Sep 17 00:00:00 2001 From: sfc-gh-astachowski Date: Fri, 6 Sep 2024 07:37:30 +0200 Subject: [PATCH 17/21] Added null check --- .../fullvectorconverters/SimpleArrowFullVectorConverter.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java index 3c7c55e90..f2d5c1d27 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java @@ -52,7 +52,9 @@ public FieldVector convert() throws SFException, SnowflakeSQLException { ArrowVectorConverterUtil.initConverter(vector, context, session, idx); additionalConverterInit(converter); for (int i = 0; i < size; i++) { - convertValue(converter, converted, i); + if (!vector.isNull(i)) { + convertValue(converter, converted, i); + } } converted.setValueCount(size); vector.close(); From ef5238c084cfea843889446f18aca682157e3bbc Mon Sep 17 00:00:00 2001 From: sfc-gh-astachowski Date: Fri, 6 Sep 2024 07:44:41 +0200 Subject: [PATCH 18/21] Formatting --- src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java index 03195a4f8..bc0058132 100644 --- a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java +++ b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesTest.java @@ -58,7 +58,8 @@ public static void tearDown() throws Exception { } private static void assertNoMemoryLeaks(ResultSet rs) throws SQLException { - assertEquals(0, + assertEquals( + 0, ((SFArrowResultSet) rs.unwrap(SnowflakeResultSetV1.class).sfBaseResultSet) .getAllocatedMemory()); } From 44e608d4ac33495a62325fe10653c513bb280d9c Mon Sep 17 00:00:00 2001 From: sfc-gh-astachowski Date: Fri, 20 Sep 2024 09:45:06 +0200 Subject: [PATCH 19/21] Merge fixes --- .../ArrowFullVectorConverter.java | 97 ------------------- .../ArrowFullVectorConverterUtil.java | 84 ++++++++-------- .../client/jdbc/ArrowResultChunk.java | 2 +- 3 files changed, 40 insertions(+), 143 deletions(-) diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java index fa6a1c9f1..929dcdc1e 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java @@ -7,102 +7,5 @@ @SnowflakeJdbcInternalApi public interface ArrowFullVectorConverter { - static Types.MinorType deduceType(ValueVector vector) { - Types.MinorType type = Types.getMinorTypeForArrowType(vector.getField().getType()); - // each column's metadata - Map customMeta = vector.getField().getMetadata(); - if (type == 
Types.MinorType.DECIMAL) { - // Note: Decimal vector is different from others - return Types.MinorType.DECIMAL; - } else if (!customMeta.isEmpty()) { - SnowflakeType st = SnowflakeType.valueOf(customMeta.get("logicalType")); - switch (st) { - case FIXED: - { - String scaleStr = vector.getField().getMetadata().get("scale"); - int sfScale = Integer.parseInt(scaleStr); - if (sfScale != 0) { - return Types.MinorType.DECIMAL; - } - break; - } - case TIME: - { - String scaleStr = vector.getField().getMetadata().get("scale"); - int sfScale = Integer.parseInt(scaleStr); - if (sfScale == 0) { - return Types.MinorType.TIMESEC; - } - if (sfScale <= 3) { - return Types.MinorType.TIMEMILLI; - } - if (sfScale <= 6) { - return Types.MinorType.TIMEMICRO; - } - if (sfScale <= 9) { - return Types.MinorType.TIMENANO; - } - } - case TIMESTAMP_NTZ: - return Types.MinorType.TIMESTAMPNANO; - case TIMESTAMP_LTZ: - case TIMESTAMP_TZ: - return Types.MinorType.TIMESTAMPNANOTZ; - } - } - return type; - } - - static FieldVector convert( - RootAllocator allocator, - ValueVector vector, - DataConversionContext context, - SFBaseSession session, - TimeZone timeZoneToUse, - int idx, - Object targetType) - throws SnowflakeSQLException { - try { - if (targetType == null) { - targetType = deduceType(vector); - } - if (targetType instanceof Types.MinorType) { - switch ((Types.MinorType) targetType) { - case TINYINT: - return new TinyIntVectorConverter(allocator, vector, context, session, idx).convert(); - case SMALLINT: - return new SmallIntVectorConverter(allocator, vector, context, session, idx).convert(); - case INT: - return new IntVectorConverter(allocator, vector, context, session, idx).convert(); - case BIGINT: - return new BigIntVectorConverter(allocator, vector, context, session, idx).convert(); - case DECIMAL: - return new DecimalVectorConverter(allocator, vector, context, session, idx).convert(); - case FLOAT8: - return new FloatVectorConverter(allocator, vector, context, session, idx).convert(); - case BIT: - return new BitVectorConverter(allocator, vector, context, session, idx).convert(); - case VARBINARY: - return new BinaryVectorConverter(allocator, vector, context, session, idx).convert(); - case DATEDAY: - return new DateVectorConverter(allocator, vector, context, session, idx, timeZoneToUse) - .convert(); - case TIMESEC: - return new TimeSecVectorConverter(allocator, vector).convert(); - case TIMEMILLI: - return new TimeMilliVectorConverter(allocator, vector).convert(); - case TIMEMICRO: - return new TimeMicroVectorConverter(allocator, vector).convert(); - case TIMENANO: - return new TimeNanoVectorConverter(allocator, vector).convert(); - } - } - } catch (SFException ex) { - throw new SnowflakeSQLException( - ex.getCause(), ex.getSqlState(), ex.getVendorCode(), ex.getParams()); - } - return null; - } - FieldVector convert() throws SFException, SnowflakeSQLException; } diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverterUtil.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverterUtil.java index f8b7952eb..6b5c3d9c1 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverterUtil.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverterUtil.java @@ -3,6 +3,8 @@ import static net.snowflake.client.core.arrow.ArrowVectorConverterUtil.getScale; import java.util.Map; +import java.util.TimeZone; + import 
net.snowflake.client.core.DataConversionContext; import net.snowflake.client.core.SFBaseSession; import net.snowflake.client.core.SFException; @@ -19,8 +21,7 @@ public class ArrowFullVectorConverterUtil { private ArrowFullVectorConverterUtil() {} - public static Types.MinorType deduceType(ValueVector vector, SFBaseSession session) - throws SnowflakeSQLLoggedException { + static Types.MinorType deduceType(ValueVector vector, SFBaseSession session) throws SnowflakeSQLLoggedException { Types.MinorType type = Types.getMinorTypeForArrowType(vector.getField().getType()); // each column's metadata Map customMeta = vector.getField().getMetadata(); @@ -39,52 +40,27 @@ public static Types.MinorType deduceType(ValueVector vector, SFBaseSession sessi break; } case TIME: - return Types.MinorType.TIMEMILLI; - case TIMESTAMP_LTZ: - { - int sfScale = getScale(vector, session); - switch (sfScale) { - case 0: - return Types.MinorType.TIMESTAMPSECTZ; - case 3: - return Types.MinorType.TIMESTAMPMILLITZ; - case 6: - return Types.MinorType.TIMESTAMPMICROTZ; - case 9: - return Types.MinorType.TIMESTAMPNANOTZ; - } - break; + { + String scaleStr = vector.getField().getMetadata().get("scale"); + int sfScale = Integer.parseInt(scaleStr); + if (sfScale == 0) { + return Types.MinorType.TIMESEC; } - case TIMESTAMP_TZ: - { - int sfScale = getScale(vector, session); - switch (sfScale) { - case 0: - return Types.MinorType.TIMESTAMPSECTZ; - case 3: - return Types.MinorType.TIMESTAMPMILLITZ; - case 6: - return Types.MinorType.TIMESTAMPMICROTZ; - case 9: - return Types.MinorType.TIMESTAMPNANOTZ; - } - break; + if (sfScale <= 3) { + return Types.MinorType.TIMEMILLI; } - case TIMESTAMP_NTZ: - { - int sfScale = getScale(vector, session); - switch (sfScale) { - case 0: - return Types.MinorType.TIMESTAMPSEC; - case 3: - return Types.MinorType.TIMESTAMPMILLI; - case 6: - return Types.MinorType.TIMESTAMPMICRO; - case 9: - return Types.MinorType.TIMESTAMPNANO; - } - break; + if (sfScale <= 6) { + return Types.MinorType.TIMEMICRO; } + if (sfScale <= 9) { + return Types.MinorType.TIMENANO; + } + } + case TIMESTAMP_NTZ: + return Types.MinorType.TIMESTAMPNANO; + case TIMESTAMP_LTZ: + case TIMESTAMP_TZ: + return Types.MinorType.TIMESTAMPNANOTZ; } } return type; @@ -95,6 +71,7 @@ public static FieldVector convert( ValueVector vector, DataConversionContext context, SFBaseSession session, + TimeZone timeZoneToUse, int idx, Object targetType) throws SnowflakeSQLException { @@ -114,6 +91,23 @@ public static FieldVector convert( return new BigIntVectorConverter(allocator, vector, context, session, idx).convert(); case DECIMAL: return new DecimalVectorConverter(allocator, vector, context, session, idx).convert(); + case FLOAT8: + return new FloatVectorConverter(allocator, vector, context, session, idx).convert(); + case BIT: + return new BitVectorConverter(allocator, vector, context, session, idx).convert(); + case VARBINARY: + return new BinaryVectorConverter(allocator, vector, context, session, idx).convert(); + case DATEDAY: + return new DateVectorConverter(allocator, vector, context, session, idx, timeZoneToUse) + .convert(); + case TIMESEC: + return new TimeSecVectorConverter(allocator, vector).convert(); + case TIMEMILLI: + return new TimeMilliVectorConverter(allocator, vector).convert(); + case TIMEMICRO: + return new TimeMicroVectorConverter(allocator, vector).convert(); + case TIMENANO: + return new TimeNanoVectorConverter(allocator, vector).convert(); default: throw new SnowflakeSQLLoggedException( session, diff --git 
a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java index 8ea5a6076..1d66cb38e 100644 --- a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java +++ b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java @@ -546,7 +546,7 @@ public List fetch() throws SnowflakeSQLException { for (int i = 0; i < record.size(); i++) { ValueVector vector = record.get(i); convertedVectors.add( - ArrowFullVectorConverterUtil.convert(rootAllocator, vector, context, session, i, null)); + ArrowFullVectorConverterUtil.convert(rootAllocator, vector, context, session, timeZoneToUse, i, null)); } result.add(new VectorSchemaRoot(convertedVectors)); } From 400fd1328e22e8e19658c92c123ccfabf08e06ff Mon Sep 17 00:00:00 2001 From: sfc-gh-astachowski Date: Fri, 20 Sep 2024 09:50:59 +0200 Subject: [PATCH 20/21] Added try-with-resources statements --- .../snowflake/client/jdbc/ArrowBatchesIT.java | 267 +++++++++--------- 1 file changed, 133 insertions(+), 134 deletions(-) diff --git a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesIT.java b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesIT.java index 950993d55..6bfeb7f44 100644 --- a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesIT.java +++ b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesIT.java @@ -259,34 +259,34 @@ public void testDecimalBatch() throws Exception { @Test public void testBitBatch() throws Exception { - Statement statement = connection.createStatement(); - ResultSet rs = - statement.executeQuery( - "select true union all select false union all select true union all select false" - + " union all select true union all select false union all select true"); - ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); - int trueCount = 0; int falseCount = 0; - while (batches.hasNext()) { - ArrowBatch batch = batches.next(); - List roots = batch.fetch(); - for (VectorSchemaRoot root : roots) { - assertTrue(root.getVector(0) instanceof BitVector); - BitVector vector = (BitVector) root.getVector(0); - for (int i = 0; i < root.getRowCount(); i++) { - if (vector.getObject(i)) { - trueCount++; - } else { - falseCount++; + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery( + "select true union all select false union all select true union all select false" + + " union all select true union all select false union all select true")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + assertTrue(root.getVector(0) instanceof BitVector); + BitVector vector = (BitVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + if (vector.getObject(i)) { + trueCount++; + } else { + falseCount++; + } } + root.close(); } - root.close(); } + assertNoMemoryLeaks(rs); } - assertNoMemoryLeaks(rs); - rs.close(); assertEquals(4, trueCount); assertEquals(3, falseCount); @@ -294,37 +294,38 @@ public void testBitBatch() throws Exception { @Test public void testBinaryBatch() throws Exception { - Statement statement = connection.createStatement(); - ResultSet rs = - statement.executeQuery("select TO_BINARY('546AB0') union select TO_BINARY('018E3271')"); - ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); - int totalRows = 0; List> values = new ArrayList<>(); - while (batches.hasNext()) { - ArrowBatch batch 
= batches.next(); - List roots = batch.fetch(); - for (VectorSchemaRoot root : roots) { - assertTrue(root.getVector(0) instanceof VarBinaryVector); - VarBinaryVector vector = (VarBinaryVector) root.getVector(0); - totalRows += root.getRowCount(); - for (int i = 0; i < root.getRowCount(); i++) { - byte[] bytes = vector.getObject(i); - ArrayList byteArrayList = - new ArrayList() { - { - for (byte aByte : bytes) { - add(aByte); - } - } - }; - values.add(byteArrayList); + + try(Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery("select TO_BINARY('546AB0') union select TO_BINARY('018E3271')")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + assertTrue(root.getVector(0) instanceof VarBinaryVector); + VarBinaryVector vector = (VarBinaryVector) root.getVector(0); + totalRows += root.getRowCount(); + for (int i = 0; i < root.getRowCount(); i++) { + byte[] bytes = vector.getObject(i); + ArrayList byteArrayList = + new ArrayList() { + { + for (byte aByte : bytes) { + add(aByte); + } + } + }; + values.add(byteArrayList); + } + root.close(); } - root.close(); } + assertNoMemoryLeaks(rs); } - assertNoMemoryLeaks(rs); - rs.close(); List> expected = new ArrayList>() { @@ -355,29 +356,29 @@ public void testBinaryBatch() throws Exception { @Test public void testDateBatch() throws Exception, SFException { - Statement statement = connection.createStatement(); - ResultSet rs = - statement.executeQuery("select '1119-02-01'::DATE union select '2021-09-11'::DATE"); - ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); - int totalRows = 0; List values = new ArrayList<>(); - while (batches.hasNext()) { - ArrowBatch batch = batches.next(); - List roots = batch.fetch(); - for (VectorSchemaRoot root : roots) { - totalRows += root.getRowCount(); - assertTrue(root.getVector(0) instanceof DateDayVector); - DateDayVector vector = (DateDayVector) root.getVector(0); - for (int i = 0; i < root.getRowCount(); i++) { - values.add(LocalDate.ofEpochDay(vector.get(i))); + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery("select '1119-02-01'::DATE union select '2021-09-11'::DATE")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof DateDayVector); + DateDayVector vector = (DateDayVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(LocalDate.ofEpochDay(vector.get(i))); + } + root.close(); } - root.close(); } + assertNoMemoryLeaks(rs); } - assertNoMemoryLeaks(rs); - rs.close(); List expected = new ArrayList() { @@ -393,29 +394,29 @@ public void testDateBatch() throws Exception, SFException { @Test public void testTimeSecBatch() throws Exception, SFException { - Statement statement = connection.createStatement(); - ResultSet rs = - statement.executeQuery("select '11:32:54'::TIME(0) union select '8:11:25'::TIME(0)"); - ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); - int totalRows = 0; List values = new ArrayList<>(); - while (batches.hasNext()) { - ArrowBatch batch = batches.next(); - List roots = batch.fetch(); - for (VectorSchemaRoot 
root : roots) { - totalRows += root.getRowCount(); - assertTrue(root.getVector(0) instanceof TimeSecVector); - TimeSecVector vector = (TimeSecVector) root.getVector(0); - for (int i = 0; i < root.getRowCount(); i++) { - values.add(LocalTime.ofSecondOfDay(vector.get(i))); + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery("select '11:32:54'::TIME(0) union select '8:11:25'::TIME(0)")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof TimeSecVector); + TimeSecVector vector = (TimeSecVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(LocalTime.ofSecondOfDay(vector.get(i))); + } + root.close(); } - root.close(); } + assertNoMemoryLeaks(rs); } - assertNoMemoryLeaks(rs); - rs.close(); List expected = new ArrayList() { @@ -431,29 +432,29 @@ public void testTimeSecBatch() throws Exception, SFException { @Test public void testTimeMilliBatch() throws Exception, SFException { - Statement statement = connection.createStatement(); - ResultSet rs = - statement.executeQuery("select '11:32:54.13'::TIME(2) union select '8:11:25.91'::TIME(2)"); - ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); - int totalRows = 0; List values = new ArrayList<>(); - while (batches.hasNext()) { - ArrowBatch batch = batches.next(); - List roots = batch.fetch(); - for (VectorSchemaRoot root : roots) { - totalRows += root.getRowCount(); - assertTrue(root.getVector(0) instanceof TimeMilliVector); - TimeMilliVector vector = (TimeMilliVector) root.getVector(0); - for (int i = 0; i < root.getRowCount(); i++) { - values.add(vector.getObject(i).toLocalTime()); + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery("select '11:32:54.13'::TIME(2) union select '8:11:25.91'::TIME(2)")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof TimeMilliVector); + TimeMilliVector vector = (TimeMilliVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.getObject(i).toLocalTime()); + } + root.close(); } - root.close(); } + assertNoMemoryLeaks(rs); } - assertNoMemoryLeaks(rs); - rs.close(); List expected = new ArrayList() { @@ -469,30 +470,29 @@ public void testTimeMilliBatch() throws Exception, SFException { @Test public void testTimeMicroBatch() throws Exception, SFException { - Statement statement = connection.createStatement(); - ResultSet rs = - statement.executeQuery( - "select '11:32:54.139901'::TIME(6) union select '8:11:25.911765'::TIME(6)"); - ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); - int totalRows = 0; List values = new ArrayList<>(); - while (batches.hasNext()) { - ArrowBatch batch = batches.next(); - List roots = batch.fetch(); - for (VectorSchemaRoot root : roots) { - totalRows += root.getRowCount(); - assertTrue(root.getVector(0) instanceof TimeMicroVector); - TimeMicroVector vector = (TimeMicroVector) root.getVector(0); - for (int i = 0; i < root.getRowCount(); i++) { - 
values.add(LocalTime.ofNanoOfDay(vector.get(i) * 1000)); + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery("select '11:32:54.139901'::TIME(6) union select '8:11:25.911765'::TIME(6)")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof TimeMicroVector); + TimeMicroVector vector = (TimeMicroVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(LocalTime.ofNanoOfDay(vector.get(i) * 1000)); + } + root.close(); } - root.close(); } + assertNoMemoryLeaks(rs); } - assertNoMemoryLeaks(rs); - rs.close(); List expected = new ArrayList() { @@ -508,30 +508,29 @@ public void testTimeMicroBatch() throws Exception, SFException { @Test public void testTimeNanoBatch() throws Exception, SFException { - Statement statement = connection.createStatement(); - ResultSet rs = - statement.executeQuery( - "select '11:32:54.1399013'::TIME(7) union select '8:11:25.9117654'::TIME(7)"); - ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); - int totalRows = 0; List values = new ArrayList<>(); - while (batches.hasNext()) { - ArrowBatch batch = batches.next(); - List roots = batch.fetch(); - for (VectorSchemaRoot root : roots) { - totalRows += root.getRowCount(); - assertTrue(root.getVector(0) instanceof TimeNanoVector); - TimeNanoVector vector = (TimeNanoVector) root.getVector(0); - for (int i = 0; i < root.getRowCount(); i++) { - values.add(LocalTime.ofNanoOfDay(vector.get(i))); + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery("select '11:32:54.1399013'::TIME(7) union select '8:11:25.9117654'::TIME(7)")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof TimeNanoVector); + TimeNanoVector vector = (TimeNanoVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(LocalTime.ofNanoOfDay(vector.get(i))); + } + root.close(); } - root.close(); } + assertNoMemoryLeaks(rs); } - assertNoMemoryLeaks(rs); - rs.close(); List expected = new ArrayList() { From 318bc4f84ec8959decf311da203f123099be93d1 Mon Sep 17 00:00:00 2001 From: sfc-gh-astachowski Date: Tue, 24 Sep 2024 14:59:58 +0200 Subject: [PATCH 21/21] Merge fixes and introduced constants --- .../ArrowFullVectorConverterUtil.java | 42 ++++++++------ .../TimestampVectorConverter.java | 18 ++++-- .../client/jdbc/ArrowResultChunk.java | 3 +- .../snowflake/client/jdbc/ArrowBatchesIT.java | 56 +++++++++++-------- 4 files changed, 73 insertions(+), 46 deletions(-) diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverterUtil.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverterUtil.java index 6b5c3d9c1..6f3f14d57 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverterUtil.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverterUtil.java @@ -4,7 +4,6 @@ import java.util.Map; import java.util.TimeZone; - import 
net.snowflake.client.core.DataConversionContext; import net.snowflake.client.core.SFBaseSession; import net.snowflake.client.core.SFException; @@ -21,7 +20,8 @@ public class ArrowFullVectorConverterUtil { private ArrowFullVectorConverterUtil() {} - static Types.MinorType deduceType(ValueVector vector, SFBaseSession session) throws SnowflakeSQLLoggedException { + static Types.MinorType deduceType(ValueVector vector, SFBaseSession session) + throws SnowflakeSQLLoggedException { Types.MinorType type = Types.getMinorTypeForArrowType(vector.getField().getType()); // each column's metadata Map customMeta = vector.getField().getMetadata(); @@ -40,22 +40,22 @@ static Types.MinorType deduceType(ValueVector vector, SFBaseSession session) thr break; } case TIME: - { - String scaleStr = vector.getField().getMetadata().get("scale"); - int sfScale = Integer.parseInt(scaleStr); - if (sfScale == 0) { - return Types.MinorType.TIMESEC; - } - if (sfScale <= 3) { - return Types.MinorType.TIMEMILLI; - } - if (sfScale <= 6) { - return Types.MinorType.TIMEMICRO; - } - if (sfScale <= 9) { - return Types.MinorType.TIMENANO; + { + String scaleStr = vector.getField().getMetadata().get("scale"); + int sfScale = Integer.parseInt(scaleStr); + if (sfScale == 0) { + return Types.MinorType.TIMESEC; + } + if (sfScale <= 3) { + return Types.MinorType.TIMEMILLI; + } + if (sfScale <= 6) { + return Types.MinorType.TIMEMICRO; + } + if (sfScale <= 9) { + return Types.MinorType.TIMENANO; + } } - } case TIMESTAMP_NTZ: return Types.MinorType.TIMESTAMPNANO; case TIMESTAMP_LTZ: @@ -99,7 +99,7 @@ public static FieldVector convert( return new BinaryVectorConverter(allocator, vector, context, session, idx).convert(); case DATEDAY: return new DateVectorConverter(allocator, vector, context, session, idx, timeZoneToUse) - .convert(); + .convert(); case TIMESEC: return new TimeSecVectorConverter(allocator, vector).convert(); case TIMEMILLI: @@ -108,6 +108,12 @@ public static FieldVector convert( return new TimeMicroVectorConverter(allocator, vector).convert(); case TIMENANO: return new TimeNanoVectorConverter(allocator, vector).convert(); + case TIMESTAMPNANOTZ: + return new TimestampVectorConverter(allocator, vector, context, timeZoneToUse, false) + .convert(); + case TIMESTAMPNANO: + return new TimestampVectorConverter(allocator, vector, context, timeZoneToUse, true) + .convert(); default: throw new SnowflakeSQLLoggedException( session, diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java index ee47a2834..df344b065 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java @@ -30,6 +30,10 @@ public class TimestampVectorConverter implements ArrowFullVectorConverter { private static final String FIELD_NAME_TIME_ZONE_INDEX = "timezone"; // time zone index private static final String FIELD_NAME_FRACTION = "fraction"; // fraction in nanoseconds + private static final int UTC_OFFSET = 1440; + private static final long NANOS_PER_MILLI = 1000000L; + private static final int MILLIS_PER_SECOND = 1000; + private static final int SECONDS_PER_MINUTE = 60; public TimestampVectorConverter( RootAllocator allocator, @@ -57,7 +61,7 @@ private IntVector makeVectorOfUTCOffsets(int length) { vector.allocateNew(length); vector.setValueCount(length); for 
(int i = 0; i < length; i++) { - vector.set(i, 1440); + vector.set(i, UTC_OFFSET); } return vector; } @@ -91,12 +95,14 @@ private IntVector makeTimeZoneOffsets( BigIntVector seconds, IntVector fractions, TimeZone timeZone) { IntVector offsets = new IntVector(FIELD_NAME_TIME_ZONE_INDEX, allocator); offsets.allocateNew(vector.getValueCount()); + offsets.setValueCount(vector.getValueCount()); for (int i = 0; i < vector.getValueCount(); i++) { offsets.set( i, - 1440 - + timeZone.getOffset(seconds.get(i) * 1000 + fractions.get(i) / 1000000) - / (1000 * 60)); + UTC_OFFSET + + timeZone.getOffset( + seconds.get(i) * MILLIS_PER_SECOND + fractions.get(i) / NANOS_PER_MILLI) + / (MILLIS_PER_SECOND * SECONDS_PER_MINUTE)); } return offsets; } @@ -153,7 +159,9 @@ public FieldVector convert() throws SFException, SnowflakeSQLException { if (isNTZ && context.getHonorClientTZForTimestampNTZ()) { timeZoneIndices = makeTimeZoneOffsets(seconds, fractions, TimeZone.getDefault()); for (int i = 0; i < vector.getValueCount(); i++) { - seconds.set(i, seconds.get(i) - (timeZoneIndices.get(i) - 1440) * 60L); + seconds.set( + i, + seconds.get(i) - (long) (timeZoneIndices.get(i) - UTC_OFFSET) * SECONDS_PER_MINUTE); } } else if (isNTZ || timeZoneToUse == null) { timeZoneIndices = makeVectorOfUTCOffsets(vector.getValueCount()); diff --git a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java index 1d66cb38e..57dcb38ae 100644 --- a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java +++ b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java @@ -546,7 +546,8 @@ public List fetch() throws SnowflakeSQLException { for (int i = 0; i < record.size(); i++) { ValueVector vector = record.get(i); convertedVectors.add( - ArrowFullVectorConverterUtil.convert(rootAllocator, vector, context, session, timeZoneToUse, i, null)); + ArrowFullVectorConverterUtil.convert( + rootAllocator, vector, context, session, timeZoneToUse, i, null)); } result.add(new VectorSchemaRoot(convertedVectors)); } diff --git a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesIT.java b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesIT.java index 8f0a9c410..62032dcf5 100644 --- a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesIT.java +++ b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesIT.java @@ -267,9 +267,9 @@ public void testBitBatch() throws Exception { try (Statement statement = connection.createStatement(); ResultSet rs = - statement.executeQuery( - "select true union all select false union all select true union all select false" - + " union all select true union all select false union all select true")) { + statement.executeQuery( + "select true union all select false union all select true union all select false" + + " union all select true union all select false union all select true")) { ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); while (batches.hasNext()) { @@ -300,9 +300,10 @@ public void testBinaryBatch() throws Exception { int totalRows = 0; List> values = new ArrayList<>(); - try(Statement statement = connection.createStatement(); + try (Statement statement = connection.createStatement(); ResultSet rs = - statement.executeQuery("select TO_BINARY('546AB0') union select TO_BINARY('018E3271')")) { + statement.executeQuery( + "select TO_BINARY('546AB0') union select TO_BINARY('018E3271')")) { ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); while (batches.hasNext()) { @@ -315,13 
+316,13 @@ public void testBinaryBatch() throws Exception { for (int i = 0; i < root.getRowCount(); i++) { byte[] bytes = vector.getObject(i); ArrayList byteArrayList = - new ArrayList() { - { - for (byte aByte : bytes) { - add(aByte); - } - } - }; + new ArrayList() { + { + for (byte aByte : bytes) { + add(aByte); + } + } + }; values.add(byteArrayList); } root.close(); @@ -364,7 +365,7 @@ public void testDateBatch() throws Exception, SFException { try (Statement statement = connection.createStatement(); ResultSet rs = - statement.executeQuery("select '1119-02-01'::DATE union select '2021-09-11'::DATE")) { + statement.executeQuery("select '1119-02-01'::DATE union select '2021-09-11'::DATE")) { ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); while (batches.hasNext()) { @@ -402,7 +403,7 @@ public void testTimeSecBatch() throws Exception, SFException { try (Statement statement = connection.createStatement(); ResultSet rs = - statement.executeQuery("select '11:32:54'::TIME(0) union select '8:11:25'::TIME(0)")) { + statement.executeQuery("select '11:32:54'::TIME(0) union select '8:11:25'::TIME(0)")) { ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); while (batches.hasNext()) { @@ -439,8 +440,9 @@ public void testTimeMilliBatch() throws Exception, SFException { List values = new ArrayList<>(); try (Statement statement = connection.createStatement(); - ResultSet rs = - statement.executeQuery("select '11:32:54.13'::TIME(2) union select '8:11:25.91'::TIME(2)")) { + ResultSet rs = + statement.executeQuery( + "select '11:32:54.13'::TIME(2) union select '8:11:25.91'::TIME(2)")) { ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); while (batches.hasNext()) { @@ -477,8 +479,9 @@ public void testTimeMicroBatch() throws Exception, SFException { List values = new ArrayList<>(); try (Statement statement = connection.createStatement(); - ResultSet rs = - statement.executeQuery("select '11:32:54.139901'::TIME(6) union select '8:11:25.911765'::TIME(6)")) { + ResultSet rs = + statement.executeQuery( + "select '11:32:54.139901'::TIME(6) union select '8:11:25.911765'::TIME(6)")) { ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); while (batches.hasNext()) { @@ -515,8 +518,9 @@ public void testTimeNanoBatch() throws Exception, SFException { List values = new ArrayList<>(); try (Statement statement = connection.createStatement(); - ResultSet rs = - statement.executeQuery("select '11:32:54.1399013'::TIME(7) union select '8:11:25.9117654'::TIME(7)")) { + ResultSet rs = + statement.executeQuery( + "select '11:32:54.1399013'::TIME(7) union select '8:11:25.9117654'::TIME(7)")) { ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); while (batches.hasNext()) { @@ -569,7 +573,7 @@ private void testTimestampBase(String query) throws Exception, SFException { @Test public void testTimestampTZBatch() throws Exception, SFException { - testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_TZ"); + testTimestampBase("select '2020-04-05 12:22:12+0700'::TIMESTAMP_TZ"); } @Test @@ -582,11 +586,19 @@ public void testTimestampLTZUseSessionTimezoneBatch() throws Exception, SFExcept @Test public void testTimestampLTZBatch() throws Exception, SFException { - testTimestampBase("select '2020-04-05 12:22:12+0700'::TIMESTAMP_LTZ"); + testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_LTZ"); } @Test public void testTimestampNTZBatch() throws Exception, SFException { testTimestampBase("select 
'2020-04-05 12:22:12'::TIMESTAMP_NTZ"); } + + @Test + public void testTimestampNTZDontHonorClientTimezone() throws Exception, SFException { + Statement statement = connection.createStatement(); + statement.execute("alter session set CLIENT_HONOR_CLIENT_TZ_FOR_TIMESTAMP_NTZ=false"); + testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_NTZ"); + statement.execute("alter session unset CLIENT_HONOR_CLIENT_TZ_FOR_TIMESTAMP_NTZ"); + } }
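
For reference, the consumption pattern exercised by the ArrowBatchesIT tests above boils down to: unwrap the result set to SnowflakeResultSet, iterate its ArrowBatches, fetch each batch into VectorSchemaRoots, and close every root so the result set's allocator reports zero allocated memory. The following is a minimal sketch of that pattern, assuming a live Connection and an arbitrary query; the class and method names here are illustrative only, while the unwrap/fetch/close flow mirrors the test code in the patches.

// Minimal consumption sketch for the ArrowBatches API added in this series.
// Assumptions: a live Snowflake Connection and query are supplied by the caller.
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.Statement;
import java.util.List;
import net.snowflake.client.jdbc.ArrowBatch;
import net.snowflake.client.jdbc.ArrowBatches;
import net.snowflake.client.jdbc.SnowflakeResultSet;
import org.apache.arrow.vector.VectorSchemaRoot;

class ArrowBatchesUsageSketch {
  static void consume(Connection connection, String query) throws Exception {
    try (Statement statement = connection.createStatement();
        ResultSet rs = statement.executeQuery(query)) {
      // Switch the result set into Arrow batches mode instead of row iteration.
      ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches();
      while (batches.hasNext()) {
        ArrowBatch batch = batches.next();
        // fetch() materializes the chunk as fully converted Arrow vectors.
        List<VectorSchemaRoot> roots = batch.fetch();
        for (VectorSchemaRoot root : roots) {
          // Read root.getVector(i) / root.getRowCount() here as needed.
          root.close(); // freeing every root is what keeps getAllocatedMemory() at zero
        }
      }
    }
  }
}

Timestamps come back as three-field struct vectors (epoch, fraction, timezone index), which is why testTimestampBase obtains a column converter via batch.getTimestampConverter(root.getVector(0), 1) and then calls toTimestamp on it rather than reading the struct directly.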